diff --git a/.github/workflows/deps.yaml b/.github/workflows/deps.yaml index e644dfc5..dbb2105f 100644 --- a/.github/workflows/deps.yaml +++ b/.github/workflows/deps.yaml @@ -17,12 +17,7 @@ jobs: with: ref: ${{ github.head_ref }} - shell: bash - run: make deps.source - - uses: actions/upload-artifact@v3 - with: - name: duckdb - path: duckdb/ - retention-days: 1 + run: make deps.header - uses: actions/upload-artifact@v3 with: name: duckdb_h @@ -35,10 +30,6 @@ jobs: - uses: actions/checkout@v3 with: ref: ${{ github.head_ref }} - - uses: actions/download-artifact@v3 - with: - name: duckdb - path: duckdb/ - shell: bash run: make deps.darwin.amd64 - uses: actions/upload-artifact@v3 @@ -53,10 +44,6 @@ jobs: - uses: actions/checkout@v3 with: ref: ${{ github.head_ref }} - - uses: actions/download-artifact@v3 - with: - name: duckdb - path: duckdb/ - shell: bash run: make deps.darwin.arm64 - uses: actions/upload-artifact@v3 @@ -71,10 +58,6 @@ jobs: - uses: actions/checkout@v3 with: ref: ${{ github.head_ref }} - - uses: actions/download-artifact@v3 - with: - name: duckdb - path: duckdb/ - shell: bash run: make deps.linux.amd64 - uses: actions/upload-artifact@v3 @@ -89,10 +72,6 @@ jobs: - uses: actions/checkout@v3 with: ref: ${{ github.head_ref }} - - uses: actions/download-artifact@v3 - with: - name: duckdb - path: duckdb/ - name: Install cross compile toolchain shell: bash run: | diff --git a/.gitignore b/.gitignore index 4edda059..150db584 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .vscode .DS_Store run.sh -duckdb/ +duckdb-*/ +duckdb.zip .idea diff --git a/Makefile b/Makefile index 6bbe9b5e..7370147b 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ SRC_DIR := duckdb/src/amalgamation FILES := $(wildcard $(SRC_DIR)/*) .PHONY: deps.source -deps.source: +deps.header: git clone -b v${DUCKDB_VERSION} --depth 1 https://github.com/duckdb/duckdb.git cp duckdb/src/include/duckdb.h duckdb.h @@ -24,6 +24,7 @@ deps.source: deps.darwin.amd64: if [ "$(shell uname -s | 
tr '[:upper:]' '[:lower:]')" != "darwin" ]; then echo "Error: must run build on darwin"; false; fi + git clone -b v${DUCKDB_VERSION} --depth 1 https://github.com/duckdb/duckdb.git cd duckdb && \ make -j 8 && \ mkdir -p lib && \ @@ -37,6 +38,7 @@ deps.darwin.amd64: deps.darwin.arm64: if [ "$(shell uname -s | tr '[:upper:]' '[:lower:]')" != "darwin" ]; then echo "Error: must run build on darwin"; false; fi + git clone -b v${DUCKDB_VERSION} --depth 1 https://github.com/duckdb/duckdb.git cd duckdb && \ make -j 8 && \ mkdir -p lib && \ @@ -50,6 +52,7 @@ deps.darwin.arm64: deps.linux.amd64: if [ "$(shell uname -s | tr '[:upper:]' '[:lower:]')" != "linux" ]; then echo "Error: must run build on linux"; false; fi + git clone -b v${DUCKDB_VERSION} --depth 1 https://github.com/duckdb/duckdb.git cd duckdb && \ make -j 8 && \ mkdir -p lib && \ @@ -63,6 +66,7 @@ deps.linux.amd64: deps.linux.arm64: if [ "$(shell uname -s | tr '[:upper:]' '[:lower:]')" != "linux" ]; then echo "Error: must run build on linux"; false; fi + git clone -b v${DUCKDB_VERSION} --depth 1 https://github.com/duckdb/duckdb.git cd duckdb && \ make -j 8 && \ mkdir -p lib && \ diff --git a/lib/duckdb-1.cpp b/lib/duckdb-1.cpp deleted file mode 100644 index b94047ed..00000000 --- a/lib/duckdb-1.cpp +++ /dev/null @@ -1,21629 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -Catalog::Catalog(AttachedDatabase &db) : db(db) { -} - -Catalog::~Catalog() { -} - -DatabaseInstance &Catalog::GetDatabase() { - return db.GetDatabase(); -} - -AttachedDatabase &Catalog::GetAttached() { - return db; -} - -const string &Catalog::GetName() { - return GetAttached().GetName(); -} - -idx_t Catalog::GetOid() { - return 
GetAttached().oid; -} - -Catalog &Catalog::GetSystemCatalog(ClientContext &context) { - return Catalog::GetSystemCatalog(*context.db); -} - -optional_ptr Catalog::GetCatalogEntry(ClientContext &context, const string &catalog_name) { - auto &db_manager = DatabaseManager::Get(context); - if (catalog_name == TEMP_CATALOG) { - return &ClientData::Get(context).temporary_objects->GetCatalog(); - } - if (catalog_name == SYSTEM_CATALOG) { - return &GetSystemCatalog(context); - } - auto entry = db_manager.GetDatabase( - context, IsInvalidCatalog(catalog_name) ? DatabaseManager::GetDefaultDatabase(context) : catalog_name); - if (!entry) { - return nullptr; - } - return &entry->GetCatalog(); -} - -Catalog &Catalog::GetCatalog(ClientContext &context, const string &catalog_name) { - auto catalog = Catalog::GetCatalogEntry(context, catalog_name); - if (!catalog) { - throw BinderException("Catalog \"%s\" does not exist!", catalog_name); - } - return *catalog; -} - -//===--------------------------------------------------------------------===// -// Schema -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateSchema(ClientContext &context, CreateSchemaInfo &info) { - return CreateSchema(GetCatalogTransaction(context), info); -} - -CatalogTransaction Catalog::GetCatalogTransaction(ClientContext &context) { - return CatalogTransaction(*this, context); -} - -//===--------------------------------------------------------------------===// -// Table -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateTable(ClientContext &context, BoundCreateTableInfo &info) { - return CreateTable(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateTable(ClientContext &context, unique_ptr info) { - auto binder = Binder::CreateBinder(context); - auto bound_info = binder->BindCreateTableInfo(std::move(info)); - return CreateTable(context, *bound_info); -} - -optional_ptr 
Catalog::CreateTable(CatalogTransaction transaction, SchemaCatalogEntry &schema, - BoundCreateTableInfo &info) { - return schema.CreateTable(transaction, info); -} - -optional_ptr Catalog::CreateTable(CatalogTransaction transaction, BoundCreateTableInfo &info) { - auto &schema = GetSchema(transaction, info.base->schema); - return CreateTable(transaction, schema, info); -} - -//===--------------------------------------------------------------------===// -// View -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateView(CatalogTransaction transaction, CreateViewInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateView(transaction, schema, info); -} - -optional_ptr Catalog::CreateView(ClientContext &context, CreateViewInfo &info) { - return CreateView(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateView(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateViewInfo &info) { - return schema.CreateView(transaction, info); -} - -//===--------------------------------------------------------------------===// -// Sequence -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateSequence(CatalogTransaction transaction, CreateSequenceInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateSequence(transaction, schema, info); -} - -optional_ptr Catalog::CreateSequence(ClientContext &context, CreateSequenceInfo &info) { - return CreateSequence(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateSequence(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateSequenceInfo &info) { - return schema.CreateSequence(transaction, info); -} - -//===--------------------------------------------------------------------===// -// Type -//===--------------------------------------------------------------------===// -optional_ptr 
Catalog::CreateType(CatalogTransaction transaction, CreateTypeInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateType(transaction, schema, info); -} - -optional_ptr Catalog::CreateType(ClientContext &context, CreateTypeInfo &info) { - return CreateType(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateType(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateTypeInfo &info) { - return schema.CreateType(transaction, info); -} - -//===--------------------------------------------------------------------===// -// Table Function -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateTableFunction(CatalogTransaction transaction, CreateTableFunctionInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateTableFunction(transaction, schema, info); -} - -optional_ptr Catalog::CreateTableFunction(ClientContext &context, CreateTableFunctionInfo &info) { - return CreateTableFunction(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateTableFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateTableFunctionInfo &info) { - return schema.CreateTableFunction(transaction, info); -} - -optional_ptr Catalog::CreateTableFunction(ClientContext &context, - optional_ptr info) { - return CreateTableFunction(context, *info); -} - -//===--------------------------------------------------------------------===// -// Copy Function -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateCopyFunction(CatalogTransaction transaction, CreateCopyFunctionInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateCopyFunction(transaction, schema, info); -} - -optional_ptr Catalog::CreateCopyFunction(ClientContext &context, CreateCopyFunctionInfo &info) { - return CreateCopyFunction(GetCatalogTransaction(context), info); -} - -optional_ptr 
Catalog::CreateCopyFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateCopyFunctionInfo &info) { - return schema.CreateCopyFunction(transaction, info); -} - -//===--------------------------------------------------------------------===// -// Pragma Function -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreatePragmaFunction(CatalogTransaction transaction, - CreatePragmaFunctionInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreatePragmaFunction(transaction, schema, info); -} - -optional_ptr Catalog::CreatePragmaFunction(ClientContext &context, CreatePragmaFunctionInfo &info) { - return CreatePragmaFunction(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreatePragmaFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreatePragmaFunctionInfo &info) { - return schema.CreatePragmaFunction(transaction, info); -} - -//===--------------------------------------------------------------------===// -// Function -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateFunction(CatalogTransaction transaction, CreateFunctionInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateFunction(transaction, schema, info); -} - -optional_ptr Catalog::CreateFunction(ClientContext &context, CreateFunctionInfo &info) { - return CreateFunction(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateFunctionInfo &info) { - return schema.CreateFunction(transaction, info); -} - -optional_ptr Catalog::AddFunction(ClientContext &context, CreateFunctionInfo &info) { - info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; - return CreateFunction(context, info); -} - -//===--------------------------------------------------------------------===// -// Collation 
-//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateCollation(CatalogTransaction transaction, CreateCollationInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateCollation(transaction, schema, info); -} - -optional_ptr Catalog::CreateCollation(ClientContext &context, CreateCollationInfo &info) { - return CreateCollation(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateCollation(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateCollationInfo &info) { - return schema.CreateCollation(transaction, info); -} - -//===--------------------------------------------------------------------===// -// Index -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateIndex(CatalogTransaction transaction, CreateIndexInfo &info) { - auto &context = transaction.GetContext(); - return CreateIndex(context, info); -} - -optional_ptr Catalog::CreateIndex(ClientContext &context, CreateIndexInfo &info) { - auto &schema = GetSchema(context, info.schema); - auto &table = GetEntry(context, schema.name, info.table); - return schema.CreateIndex(context, info, table); -} - -//===--------------------------------------------------------------------===// -// Lookup Structures -//===--------------------------------------------------------------------===// -struct CatalogLookup { - CatalogLookup(Catalog &catalog, string schema_p) : catalog(catalog), schema(std::move(schema_p)) { - } - - Catalog &catalog; - string schema; -}; - -//! 
Return value of Catalog::LookupEntry -struct CatalogEntryLookup { - optional_ptr schema; - optional_ptr entry; - PreservedError error; - - DUCKDB_API bool Found() const { - return entry; - } -}; - -//===--------------------------------------------------------------------===// -// Generic -//===--------------------------------------------------------------------===// -void Catalog::DropEntry(ClientContext &context, DropInfo &info) { - ModifyCatalog(); - if (info.type == CatalogType::SCHEMA_ENTRY) { - // DROP SCHEMA - DropSchema(context, info); - return; - } - - auto lookup = LookupEntry(context, info.type, info.schema, info.name, info.if_not_found); - - if (!lookup.Found()) { - return; - } - - lookup.schema->DropEntry(context, info); -} - -SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const string &name, QueryErrorContext error_context) { - return *Catalog::GetSchema(context, name, OnEntryNotFound::THROW_EXCEPTION, error_context); -} - -optional_ptr Catalog::GetSchema(ClientContext &context, const string &schema_name, - OnEntryNotFound if_not_found, QueryErrorContext error_context) { - return GetSchema(GetCatalogTransaction(context), schema_name, if_not_found, error_context); -} - -SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const string &catalog_name, const string &schema_name, - QueryErrorContext error_context) { - return *Catalog::GetSchema(context, catalog_name, schema_name, OnEntryNotFound::THROW_EXCEPTION, error_context); -} - -SchemaCatalogEntry &Catalog::GetSchema(CatalogTransaction transaction, const string &name, - QueryErrorContext error_context) { - return *GetSchema(transaction, name, OnEntryNotFound::THROW_EXCEPTION, error_context); -} - -//===--------------------------------------------------------------------===// -// Lookup -//===--------------------------------------------------------------------===// -SimilarCatalogEntry Catalog::SimilarEntryInSchemas(ClientContext &context, const string &entry_name, CatalogType 
type, - const reference_set_t &schemas) { - SimilarCatalogEntry result; - for (auto schema_ref : schemas) { - auto &schema = schema_ref.get(); - auto transaction = schema.catalog.GetCatalogTransaction(context); - auto entry = schema.GetSimilarEntry(transaction, type, entry_name); - if (!entry.Found()) { - // no similar entry found - continue; - } - if (!result.Found() || result.distance > entry.distance) { - result = entry; - result.schema = &schema; - } - } - return result; -} - -vector GetCatalogEntries(ClientContext &context, const string &catalog, const string &schema) { - vector entries; - auto &search_path = *context.client_data->catalog_search_path; - if (IsInvalidCatalog(catalog) && IsInvalidSchema(schema)) { - // no catalog or schema provided - scan the entire search path - entries = search_path.Get(); - } else if (IsInvalidCatalog(catalog)) { - auto catalogs = search_path.GetCatalogsForSchema(schema); - for (auto &catalog_name : catalogs) { - entries.emplace_back(catalog_name, schema); - } - if (entries.empty()) { - entries.emplace_back(DatabaseManager::GetDefaultDatabase(context), schema); - } - } else if (IsInvalidSchema(schema)) { - auto schemas = search_path.GetSchemasForCatalog(catalog); - for (auto &schema_name : schemas) { - entries.emplace_back(catalog, schema_name); - } - if (entries.empty()) { - entries.emplace_back(catalog, DEFAULT_SCHEMA); - } - } else { - // specific catalog and schema provided - entries.emplace_back(catalog, schema); - } - return entries; -} - -void FindMinimalQualification(ClientContext &context, const string &catalog_name, const string &schema_name, - bool &qualify_database, bool &qualify_schema) { - // check if we can we qualify ONLY the schema - bool found = false; - auto entries = GetCatalogEntries(context, INVALID_CATALOG, schema_name); - for (auto &entry : entries) { - if (entry.catalog == catalog_name && entry.schema == schema_name) { - found = true; - break; - } - } - if (found) { - qualify_database = false; - 
qualify_schema = true; - return; - } - // check if we can qualify ONLY the catalog - found = false; - entries = GetCatalogEntries(context, catalog_name, INVALID_SCHEMA); - for (auto &entry : entries) { - if (entry.catalog == catalog_name && entry.schema == schema_name) { - found = true; - break; - } - } - if (found) { - qualify_database = true; - qualify_schema = false; - return; - } - // need to qualify both catalog and schema - qualify_database = true; - qualify_schema = true; -} - -bool Catalog::TryAutoLoad(ClientContext &context, const string &original_name) noexcept { - string extension_name = ExtensionHelper::ApplyExtensionAlias(original_name); - if (context.db->ExtensionIsLoaded(extension_name)) { - return true; - } -#ifndef DUCKDB_DISABLE_EXTENSION_LOAD - auto &dbconfig = DBConfig::GetConfig(context); - if (!dbconfig.options.autoload_known_extensions) { - return false; - } - try { - if (ExtensionHelper::CanAutoloadExtension(extension_name)) { - return ExtensionHelper::TryAutoLoadExtension(context, extension_name); - } - } catch (...) 
{ - return false; - } -#endif - return false; -} - -void Catalog::AutoloadExtensionByConfigName(ClientContext &context, const string &configuration_name) { -#ifndef DUCKDB_DISABLE_EXTENSION_LOAD - auto &dbconfig = DBConfig::GetConfig(context); - if (dbconfig.options.autoload_known_extensions) { - auto extension_name = ExtensionHelper::FindExtensionInEntries(configuration_name, EXTENSION_SETTINGS); - if (ExtensionHelper::CanAutoloadExtension(extension_name)) { - ExtensionHelper::AutoLoadExtension(context, extension_name); - return; - } - } -#endif - - throw Catalog::UnrecognizedConfigurationError(context, configuration_name); -} - -bool Catalog::AutoLoadExtensionByCatalogEntry(ClientContext &context, CatalogType type, const string &entry_name) { -#ifndef DUCKDB_DISABLE_EXTENSION_LOAD - auto &dbconfig = DBConfig::GetConfig(context); - if (dbconfig.options.autoload_known_extensions) { - string extension_name; - if (type == CatalogType::TABLE_FUNCTION_ENTRY || type == CatalogType::SCALAR_FUNCTION_ENTRY || - type == CatalogType::AGGREGATE_FUNCTION_ENTRY || type == CatalogType::PRAGMA_FUNCTION_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_FUNCTIONS); - } else if (type == CatalogType::COPY_FUNCTION_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_COPY_FUNCTIONS); - } else if (type == CatalogType::TYPE_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_TYPES); - } else if (type == CatalogType::COLLATION_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_COLLATIONS); - } - - if (!extension_name.empty() && ExtensionHelper::CanAutoloadExtension(extension_name)) { - ExtensionHelper::AutoLoadExtension(context, extension_name); - return true; - } - } -#endif - - return false; -} - -CatalogException Catalog::UnrecognizedConfigurationError(ClientContext &context, const string &name) { - // check if the setting exists in 
any extensions - auto extension_name = ExtensionHelper::FindExtensionInEntries(name, EXTENSION_SETTINGS); - if (!extension_name.empty()) { - auto error_message = "Setting with name \"" + name + "\" is not in the catalog, but it exists in the " + - extension_name + " extension."; - error_message = ExtensionHelper::AddExtensionInstallHintToErrorMsg(context, error_message, extension_name); - return CatalogException(error_message); - } - // the setting is not in an extension - // get a list of all options - vector potential_names = DBConfig::GetOptionNames(); - for (auto &entry : DBConfig::GetConfig(context).extension_parameters) { - potential_names.push_back(entry.first); - } - - throw CatalogException("unrecognized configuration parameter \"%s\"\n%s", name, - StringUtil::CandidatesErrorMessage(potential_names, name, "Did you mean")); -} - -CatalogException Catalog::CreateMissingEntryException(ClientContext &context, const string &entry_name, - CatalogType type, - const reference_set_t &schemas, - QueryErrorContext error_context) { - auto entry = SimilarEntryInSchemas(context, entry_name, type, schemas); - - reference_set_t unseen_schemas; - auto &db_manager = DatabaseManager::Get(context); - auto databases = db_manager.GetDatabases(context); - for (auto database : databases) { - auto &catalog = database.get().GetCatalog(); - auto current_schemas = catalog.GetAllSchemas(context); - for (auto ¤t_schema : current_schemas) { - unseen_schemas.insert(current_schema.get()); - } - } - // check if the entry exists in any extension - string extension_name; - if (type == CatalogType::TABLE_FUNCTION_ENTRY || type == CatalogType::SCALAR_FUNCTION_ENTRY || - type == CatalogType::AGGREGATE_FUNCTION_ENTRY || type == CatalogType::PRAGMA_FUNCTION_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_FUNCTIONS); - } else if (type == CatalogType::TYPE_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_TYPES); - } 
else if (type == CatalogType::COPY_FUNCTION_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_COPY_FUNCTIONS); - } else if (type == CatalogType::COLLATION_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_COLLATIONS); - } - - // if we found an extension that can handle this catalog entry, create an error hinting the user - if (!extension_name.empty()) { - auto error_message = CatalogTypeToString(type) + " with name \"" + entry_name + - "\" is not in the catalog, but it exists in the " + extension_name + " extension."; - error_message = ExtensionHelper::AddExtensionInstallHintToErrorMsg(context, error_message, extension_name); - return CatalogException(error_message); - } - - auto unseen_entry = SimilarEntryInSchemas(context, entry_name, type, unseen_schemas); - string did_you_mean; - if (unseen_entry.Found() && unseen_entry.distance < entry.distance) { - // the closest matching entry requires qualification as it is not in the default search path - // check how to minimally qualify this entry - auto catalog_name = unseen_entry.schema->catalog.GetName(); - auto schema_name = unseen_entry.schema->name; - bool qualify_database; - bool qualify_schema; - FindMinimalQualification(context, catalog_name, schema_name, qualify_database, qualify_schema); - did_you_mean = "\nDid you mean \"" + unseen_entry.GetQualifiedName(qualify_database, qualify_schema) + "\"?"; - } else if (entry.Found()) { - did_you_mean = "\nDid you mean \"" + entry.name + "\"?"; - } - - return CatalogException(error_context.FormatError("%s with name %s does not exist!%s", CatalogTypeToString(type), - entry_name, did_you_mean)); -} - -CatalogEntryLookup Catalog::TryLookupEntryInternal(CatalogTransaction transaction, CatalogType type, - const string &schema, const string &name) { - auto schema_entry = GetSchema(transaction, schema, OnEntryNotFound::RETURN_NULL); - if (!schema_entry) { - return {nullptr, nullptr, 
PreservedError()}; - } - auto entry = schema_entry->GetEntry(transaction, type, name); - if (!entry) { - return {schema_entry, nullptr, PreservedError()}; - } - return {schema_entry, entry, PreservedError()}; -} - -CatalogEntryLookup Catalog::TryLookupEntry(ClientContext &context, CatalogType type, const string &schema, - const string &name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - reference_set_t schemas; - if (IsInvalidSchema(schema)) { - // try all schemas for this catalog - auto entries = GetCatalogEntries(context, GetName(), INVALID_SCHEMA); - for (auto &entry : entries) { - auto &candidate_schema = entry.schema; - auto transaction = GetCatalogTransaction(context); - auto result = TryLookupEntryInternal(transaction, type, candidate_schema, name); - if (result.Found()) { - return result; - } - if (result.schema) { - schemas.insert(*result.schema); - } - } - } else { - auto transaction = GetCatalogTransaction(context); - auto result = TryLookupEntryInternal(transaction, type, schema, name); - if (result.Found()) { - return result; - } - if (result.schema) { - schemas.insert(*result.schema); - } - } - - if (if_not_found == OnEntryNotFound::RETURN_NULL) { - return {nullptr, nullptr, PreservedError()}; - } else { - auto except = CreateMissingEntryException(context, name, type, schemas, error_context); - return {nullptr, nullptr, PreservedError(except)}; - } -} - -CatalogEntryLookup Catalog::LookupEntry(ClientContext &context, CatalogType type, const string &schema, - const string &name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - auto res = TryLookupEntry(context, type, schema, name, if_not_found, error_context); - - if (res.error) { - res.error.Throw(); - } - - return res; -} - -CatalogEntryLookup Catalog::TryLookupEntry(ClientContext &context, vector &lookups, CatalogType type, - const string &name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - reference_set_t schemas; - for (auto &lookup 
: lookups) { - auto transaction = lookup.catalog.GetCatalogTransaction(context); - auto result = lookup.catalog.TryLookupEntryInternal(transaction, type, lookup.schema, name); - if (result.Found()) { - return result; - } - if (result.schema) { - schemas.insert(*result.schema); - } - } - - if (if_not_found == OnEntryNotFound::RETURN_NULL) { - return {nullptr, nullptr, PreservedError()}; - } else { - auto except = CreateMissingEntryException(context, name, type, schemas, error_context); - return {nullptr, nullptr, PreservedError(except)}; - } -} - -CatalogEntryLookup Catalog::TryLookupEntry(ClientContext &context, CatalogType type, const string &catalog, - const string &schema, const string &name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - auto entries = GetCatalogEntries(context, catalog, schema); - vector lookups; - lookups.reserve(entries.size()); - for (auto &entry : entries) { - if (if_not_found == OnEntryNotFound::RETURN_NULL) { - auto catalog_entry = Catalog::GetCatalogEntry(context, entry.catalog); - if (!catalog_entry) { - return {nullptr, nullptr, PreservedError()}; - } - lookups.emplace_back(*catalog_entry, entry.schema); - } else { - lookups.emplace_back(Catalog::GetCatalog(context, entry.catalog), entry.schema); - } - } - return Catalog::TryLookupEntry(context, lookups, type, name, if_not_found, error_context); -} - -CatalogEntry &Catalog::GetEntry(ClientContext &context, const string &schema, const string &name) { - vector entry_types {CatalogType::TABLE_ENTRY, CatalogType::SEQUENCE_ENTRY}; - - for (auto entry_type : entry_types) { - auto result = GetEntry(context, entry_type, schema, name, OnEntryNotFound::RETURN_NULL); - if (result) { - return *result; - } - } - - throw CatalogException("CatalogElement \"%s.%s\" does not exist!", schema, name); -} - -optional_ptr Catalog::GetEntry(ClientContext &context, CatalogType type, const string &schema_name, - const string &name, OnEntryNotFound if_not_found, - QueryErrorContext 
error_context) { - auto lookup_entry = TryLookupEntry(context, type, schema_name, name, if_not_found, error_context); - - // Try autoloading extension to resolve lookup - if (!lookup_entry.Found()) { - if (AutoLoadExtensionByCatalogEntry(context, type, name)) { - lookup_entry = TryLookupEntry(context, type, schema_name, name, if_not_found, error_context); - } - } - - if (lookup_entry.error) { - lookup_entry.error.Throw(); - } - - return lookup_entry.entry.get(); -} - -CatalogEntry &Catalog::GetEntry(ClientContext &context, CatalogType type, const string &schema, const string &name, - QueryErrorContext error_context) { - return *Catalog::GetEntry(context, type, schema, name, OnEntryNotFound::THROW_EXCEPTION, error_context); -} - -optional_ptr Catalog::GetEntry(ClientContext &context, CatalogType type, const string &catalog, - const string &schema, const string &name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - auto result = TryLookupEntry(context, type, catalog, schema, name, if_not_found, error_context); - - // Try autoloading extension to resolve lookup - if (!result.Found()) { - if (AutoLoadExtensionByCatalogEntry(context, type, name)) { - result = TryLookupEntry(context, type, catalog, schema, name, if_not_found, error_context); - } - } - - if (result.error) { - result.error.Throw(); - } - - if (!result.Found()) { - D_ASSERT(if_not_found == OnEntryNotFound::RETURN_NULL); - return nullptr; - } - return result.entry.get(); -} - -CatalogEntry &Catalog::GetEntry(ClientContext &context, CatalogType type, const string &catalog, const string &schema, - const string &name, QueryErrorContext error_context) { - return *Catalog::GetEntry(context, type, catalog, schema, name, OnEntryNotFound::THROW_EXCEPTION, error_context); -} - -optional_ptr Catalog::GetSchema(ClientContext &context, const string &catalog_name, - const string &schema_name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - auto entries = 
GetCatalogEntries(context, catalog_name, schema_name); - for (idx_t i = 0; i < entries.size(); i++) { - auto on_not_found = i + 1 == entries.size() ? if_not_found : OnEntryNotFound::RETURN_NULL; - auto &catalog = Catalog::GetCatalog(context, entries[i].catalog); - auto result = catalog.GetSchema(context, schema_name, on_not_found, error_context); - if (result) { - return result; - } - } - return nullptr; -} - -LogicalType Catalog::GetType(ClientContext &context, const string &schema, const string &name, - OnEntryNotFound if_not_found) { - auto type_entry = GetEntry(context, schema, name, if_not_found); - if (!type_entry) { - return LogicalType::INVALID; - } - return type_entry->user_type; -} - -LogicalType Catalog::GetType(ClientContext &context, const string &catalog_name, const string &schema, - const string &name) { - auto &type_entry = Catalog::GetEntry(context, catalog_name, schema, name); - return type_entry.user_type; -} - -vector> Catalog::GetSchemas(ClientContext &context) { - vector> schemas; - ScanSchemas(context, [&](SchemaCatalogEntry &entry) { schemas.push_back(entry); }); - return schemas; -} - -vector> Catalog::GetSchemas(ClientContext &context, const string &catalog_name) { - vector> catalogs; - if (IsInvalidCatalog(catalog_name)) { - reference_set_t inserted_catalogs; - - auto &search_path = *context.client_data->catalog_search_path; - for (auto &entry : search_path.Get()) { - auto &catalog = Catalog::GetCatalog(context, entry.catalog); - if (inserted_catalogs.find(catalog) != inserted_catalogs.end()) { - continue; - } - inserted_catalogs.insert(catalog); - catalogs.push_back(catalog); - } - } else { - catalogs.push_back(Catalog::GetCatalog(context, catalog_name)); - } - vector> result; - for (auto catalog : catalogs) { - auto schemas = catalog.get().GetSchemas(context); - result.insert(result.end(), schemas.begin(), schemas.end()); - } - return result; -} - -vector> Catalog::GetAllSchemas(ClientContext &context) { - vector> result; - - auto 
&db_manager = DatabaseManager::Get(context); - auto databases = db_manager.GetDatabases(context); - for (auto database : databases) { - auto &catalog = database.get().GetCatalog(); - auto new_schemas = catalog.GetSchemas(context); - result.insert(result.end(), new_schemas.begin(), new_schemas.end()); - } - sort(result.begin(), result.end(), - [&](reference left_p, reference right_p) { - auto &left = left_p.get(); - auto &right = right_p.get(); - if (left.catalog.GetName() < right.catalog.GetName()) { - return true; - } - if (left.catalog.GetName() == right.catalog.GetName()) { - return left.name < right.name; - } - return false; - }); - - return result; -} - -void Catalog::Alter(ClientContext &context, AlterInfo &info) { - ModifyCatalog(); - auto lookup = LookupEntry(context, info.GetCatalogType(), info.schema, info.name, info.if_not_found); - - if (!lookup.Found()) { - return; - } - return lookup.schema->Alter(context, info); -} - -vector Catalog::GetMetadataInfo(ClientContext &context) { - return vector(); -} - -void Catalog::Verify() { -} - -//===--------------------------------------------------------------------===// -// Catalog Version -//===--------------------------------------------------------------------===// -idx_t Catalog::GetCatalogVersion() { - return GetDatabase().GetDatabaseManager().catalog_version; -} - -idx_t Catalog::ModifyCatalog() { - return GetDatabase().GetDatabaseManager().ModifyCatalog(); -} - -bool Catalog::IsSystemCatalog() const { - return db.IsSystem(); -} - -bool Catalog::IsTemporaryCatalog() const { - return db.IsTemporary(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -ColumnDependencyManager::ColumnDependencyManager() { -} - -ColumnDependencyManager::~ColumnDependencyManager() { -} - -void ColumnDependencyManager::AddGeneratedColumn(const ColumnDefinition &column, const ColumnList &list) { - D_ASSERT(column.Generated()); - vector referenced_columns; - column.GetListOfDependencies(referenced_columns); - vector 
indices; - for (auto &col : referenced_columns) { - if (!list.ColumnExists(col)) { - throw BinderException("Column \"%s\" referenced by generated column does not exist", col); - } - auto &entry = list.GetColumn(col); - indices.push_back(entry.Logical()); - } - return AddGeneratedColumn(column.Logical(), indices); -} - -void ColumnDependencyManager::AddGeneratedColumn(LogicalIndex index, const vector &indices, bool root) { - if (indices.empty()) { - return; - } - auto &list = dependents_map[index]; - // Create a link between the dependencies - for (auto &dep : indices) { - // Add this column as a dependency of the new column - list.insert(dep); - // Add the new column as a dependent of the column - dependencies_map[dep].insert(index); - // Inherit the dependencies - if (HasDependencies(dep)) { - auto &inherited_deps = dependents_map[dep]; - D_ASSERT(!inherited_deps.empty()); - for (auto &inherited_dep : inherited_deps) { - list.insert(inherited_dep); - dependencies_map[inherited_dep].insert(index); - } - } - if (!root) { - continue; - } - direct_dependencies[index].insert(dep); - } - if (!HasDependents(index)) { - return; - } - auto &dependents = dependencies_map[index]; - if (dependents.count(index)) { - throw InvalidInputException("Circular dependency encountered when resolving generated column expressions"); - } - // Also let the dependents of this generated column inherit the dependencies - for (auto &dependent : dependents) { - AddGeneratedColumn(dependent, indices, false); - } -} - -vector ColumnDependencyManager::RemoveColumn(LogicalIndex index, idx_t column_amount) { - // Always add the initial column - deleted_columns.insert(index); - - RemoveGeneratedColumn(index); - RemoveStandardColumn(index); - - // Clean up the internal list - vector new_indices = CleanupInternals(column_amount); - D_ASSERT(deleted_columns.empty()); - return new_indices; -} - -bool ColumnDependencyManager::IsDependencyOf(LogicalIndex gcol, LogicalIndex col) const { - auto entry = 
dependents_map.find(gcol); - if (entry == dependents_map.end()) { - return false; - } - auto &list = entry->second; - return list.count(col); -} - -bool ColumnDependencyManager::HasDependencies(LogicalIndex index) const { - auto entry = dependents_map.find(index); - if (entry == dependents_map.end()) { - return false; - } - return true; -} - -const logical_index_set_t &ColumnDependencyManager::GetDependencies(LogicalIndex index) const { - auto entry = dependents_map.find(index); - D_ASSERT(entry != dependents_map.end()); - return entry->second; -} - -bool ColumnDependencyManager::HasDependents(LogicalIndex index) const { - auto entry = dependencies_map.find(index); - if (entry == dependencies_map.end()) { - return false; - } - return true; -} - -const logical_index_set_t &ColumnDependencyManager::GetDependents(LogicalIndex index) const { - auto entry = dependencies_map.find(index); - D_ASSERT(entry != dependencies_map.end()); - return entry->second; -} - -void ColumnDependencyManager::RemoveStandardColumn(LogicalIndex index) { - if (!HasDependents(index)) { - return; - } - auto dependents = dependencies_map[index]; - for (auto &gcol : dependents) { - // If index is a direct dependency of gcol, remove it from the list - if (direct_dependencies.find(gcol) != direct_dependencies.end()) { - direct_dependencies[gcol].erase(index); - } - RemoveGeneratedColumn(gcol); - } - // Remove this column from the dependencies map - dependencies_map.erase(index); -} - -void ColumnDependencyManager::RemoveGeneratedColumn(LogicalIndex index) { - deleted_columns.insert(index); - if (!HasDependencies(index)) { - return; - } - auto &dependencies = dependents_map[index]; - for (auto &col : dependencies) { - // Remove this generated column from the list of this column - auto &col_dependents = dependencies_map[col]; - D_ASSERT(col_dependents.count(index)); - col_dependents.erase(index); - // If the resulting list is empty, remove the column from the dependencies map altogether - if 
(col_dependents.empty()) { - dependencies_map.erase(col); - } - } - // Remove this column from the dependents_map map - dependents_map.erase(index); -} - -void ColumnDependencyManager::AdjustSingle(LogicalIndex idx, idx_t offset) { - D_ASSERT(idx.index >= offset); - LogicalIndex new_idx = LogicalIndex(idx.index - offset); - // Adjust this index in the dependents of this column - bool has_dependents = HasDependents(idx); - bool has_dependencies = HasDependencies(idx); - - if (has_dependents) { - auto &dependents = GetDependents(idx); - for (auto &dep : dependents) { - auto &dep_dependencies = dependents_map[dep]; - dep_dependencies.erase(idx); - D_ASSERT(!dep_dependencies.count(new_idx)); - dep_dependencies.insert(new_idx); - } - } - if (has_dependencies) { - auto &dependencies = GetDependencies(idx); - for (auto &dep : dependencies) { - auto &dep_dependents = dependencies_map[dep]; - dep_dependents.erase(idx); - D_ASSERT(!dep_dependents.count(new_idx)); - dep_dependents.insert(new_idx); - } - } - if (has_dependents) { - D_ASSERT(!dependencies_map.count(new_idx)); - dependencies_map[new_idx] = std::move(dependencies_map[idx]); - dependencies_map.erase(idx); - } - if (has_dependencies) { - D_ASSERT(!dependents_map.count(new_idx)); - dependents_map[new_idx] = std::move(dependents_map[idx]); - dependents_map.erase(idx); - } -} - -vector ColumnDependencyManager::CleanupInternals(idx_t column_amount) { - vector to_adjust; - D_ASSERT(!deleted_columns.empty()); - // Get the lowest index that was deleted - vector new_indices(column_amount, LogicalIndex(DConstants::INVALID_INDEX)); - idx_t threshold = deleted_columns.begin()->index; - - idx_t offset = 0; - for (idx_t i = 0; i < column_amount; i++) { - auto current_index = LogicalIndex(i); - auto new_index = LogicalIndex(i - offset); - new_indices[i] = new_index; - if (deleted_columns.count(current_index)) { - offset++; - continue; - } - if (i > threshold && (HasDependencies(current_index) || HasDependents(current_index))) { 
- to_adjust.push_back(current_index); - } - } - - // Adjust all indices inside the dependency managers internal mappings - for (auto &col : to_adjust) { - auto offset = col.index - new_indices[col.index].index; - AdjustSingle(col, offset); - } - deleted_columns.clear(); - return new_indices; -} - -stack ColumnDependencyManager::GetBindOrder(const ColumnList &columns) { - stack bind_order; - queue to_visit; - logical_index_set_t visited; - - for (auto &entry : direct_dependencies) { - auto dependent = entry.first; - //! Skip the dependents that are also dependencies - if (dependencies_map.find(dependent) != dependencies_map.end()) { - continue; - } - bind_order.push(dependent); - visited.insert(dependent); - for (auto &dependency : direct_dependencies[dependent]) { - to_visit.push(dependency); - } - } - - while (!to_visit.empty()) { - auto column = to_visit.front(); - to_visit.pop(); - - //! If this column does not have dependencies, the queue stops getting filled - if (direct_dependencies.find(column) == direct_dependencies.end()) { - continue; - } - bind_order.push(column); - visited.insert(column); - - for (auto &dependency : direct_dependencies[column]) { - to_visit.push(dependency); - } - } - - // Add generated columns that have no dependencies, but still might need to have their type resolved - for (auto &col : columns.Logical()) { - // Not a generated column - if (!col.Generated()) { - continue; - } - // Already added to the bind_order stack - if (visited.count(col.Logical())) { - continue; - } - bind_order.push(col.Logical()); - } - - return bind_order; -} - -} // namespace duckdb - - - -namespace duckdb { - -CopyFunctionCatalogEntry::CopyFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, - CreateCopyFunctionInfo &info) - : StandardEntry(CatalogType::COPY_FUNCTION_ENTRY, schema, catalog, info.name), function(info.function) { -} - -} // namespace duckdb - - - - -namespace duckdb { - -DuckIndexEntry::DuckIndexEntry(Catalog &catalog, 
SchemaCatalogEntry &schema, CreateIndexInfo &info) - : IndexCatalogEntry(catalog, schema, info) { -} - -DuckIndexEntry::~DuckIndexEntry() { - // remove the associated index from the info - if (!info || !index) { - return; - } - info->indexes.RemoveIndex(*index); -} - -string DuckIndexEntry::GetSchemaName() const { - return info->schema; -} - -string DuckIndexEntry::GetTableName() const { - return info->table; -} - -void DuckIndexEntry::CommitDrop() { - D_ASSERT(info && index); - index->CommitDrop(); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -void FindForeignKeyInformation(CatalogEntry &entry, AlterForeignKeyType alter_fk_type, - vector> &fk_arrays) { - if (entry.type != CatalogType::TABLE_ENTRY) { - return; - } - auto &table_entry = entry.Cast(); - auto &constraints = table_entry.GetConstraints(); - for (idx_t i = 0; i < constraints.size(); i++) { - auto &cond = constraints[i]; - if (cond->type != ConstraintType::FOREIGN_KEY) { - continue; - } - auto &fk = cond->Cast(); - if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { - AlterEntryData alter_data(entry.ParentCatalog().GetName(), fk.info.schema, fk.info.table, - OnEntryNotFound::THROW_EXCEPTION); - fk_arrays.push_back(make_uniq(std::move(alter_data), entry.name, fk.pk_columns, - fk.fk_columns, fk.info.pk_keys, fk.info.fk_keys, - alter_fk_type)); - } else if (fk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE && - alter_fk_type == AlterForeignKeyType::AFT_DELETE) { - throw CatalogException("Could not drop the table because this table is main key table of the table \"%s\"", - fk.info.table); - } - } -} - -DuckSchemaEntry::DuckSchemaEntry(Catalog &catalog, string name_p, bool is_internal) - : SchemaCatalogEntry(catalog, std::move(name_p), is_internal), - tables(catalog, make_uniq(catalog, *this)), indexes(catalog), table_functions(catalog), - copy_functions(catalog), pragma_functions(catalog), - functions(catalog, 
make_uniq(catalog, *this)), sequences(catalog), collations(catalog), - types(catalog, make_uniq(catalog, *this)) { -} - -optional_ptr DuckSchemaEntry::AddEntryInternal(CatalogTransaction transaction, - unique_ptr entry, - OnCreateConflict on_conflict, - DependencyList dependencies) { - auto entry_name = entry->name; - auto entry_type = entry->type; - auto result = entry.get(); - - // first find the set for this entry - auto &set = GetCatalogSet(entry_type); - dependencies.AddDependency(*this); - if (on_conflict == OnCreateConflict::REPLACE_ON_CONFLICT) { - // CREATE OR REPLACE: first try to drop the entry - auto old_entry = set.GetEntry(transaction, entry_name); - if (old_entry) { - if (old_entry->type != entry_type) { - throw CatalogException("Existing object %s is of type %s, trying to replace with type %s", entry_name, - CatalogTypeToString(old_entry->type), CatalogTypeToString(entry_type)); - } - (void)set.DropEntry(transaction, entry_name, false, entry->internal); - } - } - // now try to add the entry - if (!set.CreateEntry(transaction, entry_name, std::move(entry), dependencies)) { - // entry already exists! 
- if (on_conflict == OnCreateConflict::ERROR_ON_CONFLICT) { - throw CatalogException("%s with name \"%s\" already exists!", CatalogTypeToString(entry_type), entry_name); - } else { - return nullptr; - } - } - return result; -} - -optional_ptr DuckSchemaEntry::CreateTable(CatalogTransaction transaction, BoundCreateTableInfo &info) { - auto table = make_uniq(catalog, *this, info); - auto &storage = table->GetStorage(); - storage.info->cardinality = storage.GetTotalRows(); - - auto entry = AddEntryInternal(transaction, std::move(table), info.Base().on_conflict, info.dependencies); - if (!entry) { - return nullptr; - } - - // add a foreign key constraint in main key table if there is a foreign key constraint - vector> fk_arrays; - FindForeignKeyInformation(*entry, AlterForeignKeyType::AFT_ADD, fk_arrays); - for (idx_t i = 0; i < fk_arrays.size(); i++) { - // alter primary key table - auto &fk_info = *fk_arrays[i]; - catalog.Alter(transaction.GetContext(), fk_info); - - // make a dependency between this table and referenced table - auto &set = GetCatalogSet(CatalogType::TABLE_ENTRY); - info.dependencies.AddDependency(*set.GetEntry(transaction, fk_info.name)); - } - return entry; -} - -optional_ptr DuckSchemaEntry::CreateFunction(CatalogTransaction transaction, CreateFunctionInfo &info) { - if (info.on_conflict == OnCreateConflict::ALTER_ON_CONFLICT) { - // check if the original entry exists - auto &catalog_set = GetCatalogSet(info.type); - auto current_entry = catalog_set.GetEntry(transaction, info.name); - if (current_entry) { - // the current entry exists - alter it instead - auto alter_info = info.GetAlterInfo(); - Alter(transaction.GetContext(), *alter_info); - return nullptr; - } - } - unique_ptr function; - switch (info.type) { - case CatalogType::SCALAR_FUNCTION_ENTRY: - function = make_uniq_base(catalog, *this, - info.Cast()); - break; - case CatalogType::TABLE_FUNCTION_ENTRY: - function = make_uniq_base(catalog, *this, - info.Cast()); - break; - case 
CatalogType::MACRO_ENTRY: - // create a macro function - function = make_uniq_base(catalog, *this, info.Cast()); - break; - - case CatalogType::TABLE_MACRO_ENTRY: - // create a macro table function - function = make_uniq_base(catalog, *this, info.Cast()); - break; - case CatalogType::AGGREGATE_FUNCTION_ENTRY: - D_ASSERT(info.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); - // create an aggregate function - function = make_uniq_base( - catalog, *this, info.Cast()); - break; - default: - throw InternalException("Unknown function type \"%s\"", CatalogTypeToString(info.type)); - } - function->internal = info.internal; - return AddEntry(transaction, std::move(function), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::AddEntry(CatalogTransaction transaction, unique_ptr entry, - OnCreateConflict on_conflict) { - DependencyList dependencies; - return AddEntryInternal(transaction, std::move(entry), on_conflict, dependencies); -} - -optional_ptr DuckSchemaEntry::CreateSequence(CatalogTransaction transaction, CreateSequenceInfo &info) { - auto sequence = make_uniq(catalog, *this, info); - return AddEntry(transaction, std::move(sequence), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::CreateType(CatalogTransaction transaction, CreateTypeInfo &info) { - auto type_entry = make_uniq(catalog, *this, info); - return AddEntry(transaction, std::move(type_entry), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::CreateView(CatalogTransaction transaction, CreateViewInfo &info) { - auto view = make_uniq(catalog, *this, info); - return AddEntry(transaction, std::move(view), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::CreateIndex(ClientContext &context, CreateIndexInfo &info, - TableCatalogEntry &table) { - DependencyList dependencies; - dependencies.AddDependency(table); - auto index = make_uniq(catalog, *this, info); - return AddEntryInternal(GetCatalogTransaction(context), std::move(index), info.on_conflict, dependencies); -} - -optional_ptr 
DuckSchemaEntry::CreateCollation(CatalogTransaction transaction, CreateCollationInfo &info) { - auto collation = make_uniq(catalog, *this, info); - collation->internal = info.internal; - return AddEntry(transaction, std::move(collation), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::CreateTableFunction(CatalogTransaction transaction, - CreateTableFunctionInfo &info) { - auto table_function = make_uniq(catalog, *this, info); - table_function->internal = info.internal; - return AddEntry(transaction, std::move(table_function), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::CreateCopyFunction(CatalogTransaction transaction, - CreateCopyFunctionInfo &info) { - auto copy_function = make_uniq(catalog, *this, info); - copy_function->internal = info.internal; - return AddEntry(transaction, std::move(copy_function), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::CreatePragmaFunction(CatalogTransaction transaction, - CreatePragmaFunctionInfo &info) { - auto pragma_function = make_uniq(catalog, *this, info); - pragma_function->internal = info.internal; - return AddEntry(transaction, std::move(pragma_function), info.on_conflict); -} - -void DuckSchemaEntry::Alter(ClientContext &context, AlterInfo &info) { - CatalogType type = info.GetCatalogType(); - auto &set = GetCatalogSet(type); - auto transaction = GetCatalogTransaction(context); - if (info.type == AlterType::CHANGE_OWNERSHIP) { - if (!set.AlterOwnership(transaction, info.Cast())) { - throw CatalogException("Couldn't change ownership!"); - } - } else { - string name = info.name; - if (!set.AlterEntry(transaction, name, info)) { - throw CatalogException("Entry with name \"%s\" does not exist!", name); - } - } -} - -void DuckSchemaEntry::Scan(ClientContext &context, CatalogType type, - const std::function &callback) { - auto &set = GetCatalogSet(type); - set.Scan(GetCatalogTransaction(context), callback); -} - -void DuckSchemaEntry::Scan(CatalogType type, const std::function &callback) { - auto &set 
= GetCatalogSet(type); - set.Scan(callback); -} - -void DuckSchemaEntry::DropEntry(ClientContext &context, DropInfo &info) { - auto &set = GetCatalogSet(info.type); - - // first find the entry - auto transaction = GetCatalogTransaction(context); - auto existing_entry = set.GetEntry(transaction, info.name); - if (!existing_entry) { - throw InternalException("Failed to drop entry \"%s\" - entry could not be found", info.name); - } - if (existing_entry->type != info.type) { - throw CatalogException("Existing object %s is of type %s, trying to replace with type %s", info.name, - CatalogTypeToString(existing_entry->type), CatalogTypeToString(info.type)); - } - - // if there is a foreign key constraint, get that information - vector> fk_arrays; - FindForeignKeyInformation(*existing_entry, AlterForeignKeyType::AFT_DELETE, fk_arrays); - - if (!set.DropEntry(transaction, info.name, info.cascade, info.allow_drop_internal)) { - throw InternalException("Could not drop element because of an internal error"); - } - - // remove the foreign key constraint in main key table if main key table's name is valid - for (idx_t i = 0; i < fk_arrays.size(); i++) { - // alter primary key table - catalog.Alter(context, *fk_arrays[i]); - } -} - -optional_ptr DuckSchemaEntry::GetEntry(CatalogTransaction transaction, CatalogType type, - const string &name) { - return GetCatalogSet(type).GetEntry(transaction, name); -} - -SimilarCatalogEntry DuckSchemaEntry::GetSimilarEntry(CatalogTransaction transaction, CatalogType type, - const string &name) { - return GetCatalogSet(type).SimilarEntry(transaction, name); -} - -CatalogSet &DuckSchemaEntry::GetCatalogSet(CatalogType type) { - switch (type) { - case CatalogType::VIEW_ENTRY: - case CatalogType::TABLE_ENTRY: - return tables; - case CatalogType::INDEX_ENTRY: - return indexes; - case CatalogType::TABLE_FUNCTION_ENTRY: - case CatalogType::TABLE_MACRO_ENTRY: - return table_functions; - case CatalogType::COPY_FUNCTION_ENTRY: - return copy_functions; - 
case CatalogType::PRAGMA_FUNCTION_ENTRY: - return pragma_functions; - case CatalogType::AGGREGATE_FUNCTION_ENTRY: - case CatalogType::SCALAR_FUNCTION_ENTRY: - case CatalogType::MACRO_ENTRY: - return functions; - case CatalogType::SEQUENCE_ENTRY: - return sequences; - case CatalogType::COLLATION_ENTRY: - return collations; - case CatalogType::TYPE_ENTRY: - return types; - default: - throw InternalException("Unsupported catalog type in schema"); - } -} - -void DuckSchemaEntry::Verify(Catalog &catalog) { - InCatalogEntry::Verify(catalog); - - tables.Verify(catalog); - indexes.Verify(catalog); - table_functions.Verify(catalog); - copy_functions.Verify(catalog); - pragma_functions.Verify(catalog); - functions.Verify(catalog); - sequences.Verify(catalog); - collations.Verify(catalog); - types.Verify(catalog); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -void AddDataTableIndex(DataTable &storage, const ColumnList &columns, const vector &keys, - IndexConstraintType constraint_type, BlockPointer index_block = BlockPointer()) { - // fetch types and create expressions for the index from the columns - vector column_ids; - vector> unbound_expressions; - vector> bound_expressions; - idx_t key_nr = 0; - column_ids.reserve(keys.size()); - for (auto &physical_key : keys) { - auto &column = columns.GetColumn(physical_key); - D_ASSERT(!column.Generated()); - unbound_expressions.push_back( - make_uniq(column.Name(), column.Type(), ColumnBinding(0, column_ids.size()))); - - bound_expressions.push_back(make_uniq(column.Type(), key_nr++)); - column_ids.push_back(column.StorageOid()); - } - unique_ptr art; - // create an adaptive radix tree around the expressions - if (index_block.IsValid()) { - art = make_uniq(column_ids, TableIOManager::Get(storage), std::move(unbound_expressions), constraint_type, - storage.db, nullptr, index_block); - } else { - art = make_uniq(column_ids, TableIOManager::Get(storage), std::move(unbound_expressions), 
constraint_type, - storage.db); - if (!storage.IsRoot()) { - throw TransactionException("Transaction conflict: cannot add an index to a table that has been altered!"); - } - } - storage.info->indexes.AddIndex(std::move(art)); -} - -void AddDataTableIndex(DataTable &storage, const ColumnList &columns, vector &keys, - IndexConstraintType constraint_type, BlockPointer index_block = BlockPointer()) { - vector new_keys; - new_keys.reserve(keys.size()); - for (auto &logical_key : keys) { - new_keys.push_back(columns.LogicalToPhysical(logical_key)); - } - AddDataTableIndex(storage, columns, new_keys, constraint_type, index_block); -} - -DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, BoundCreateTableInfo &info, - std::shared_ptr inherited_storage) - : TableCatalogEntry(catalog, schema, info.Base()), storage(std::move(inherited_storage)), - bound_constraints(std::move(info.bound_constraints)), - column_dependency_manager(std::move(info.column_dependency_manager)) { - if (!storage) { - // create the physical storage - vector storage_columns; - for (auto &col_def : columns.Physical()) { - storage_columns.push_back(col_def.Copy()); - } - storage = make_shared(catalog.GetAttached(), StorageManager::Get(catalog).GetTableIOManager(&info), - schema.name, name, std::move(storage_columns), std::move(info.data)); - - // create the unique indexes for the UNIQUE and PRIMARY KEY and FOREIGN KEY constraints - idx_t indexes_idx = 0; - for (idx_t i = 0; i < bound_constraints.size(); i++) { - auto &constraint = bound_constraints[i]; - if (constraint->type == ConstraintType::UNIQUE) { - // unique constraint: create a unique index - auto &unique = constraint->Cast(); - IndexConstraintType constraint_type = IndexConstraintType::UNIQUE; - if (unique.is_primary_key) { - constraint_type = IndexConstraintType::PRIMARY; - } - if (info.indexes.empty()) { - AddDataTableIndex(*storage, columns, unique.keys, constraint_type); - } else { - AddDataTableIndex(*storage, 
columns, unique.keys, constraint_type, info.indexes[indexes_idx++]); - } - } else if (constraint->type == ConstraintType::FOREIGN_KEY) { - // foreign key constraint: create a foreign key index - auto &bfk = constraint->Cast(); - if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || - bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - if (info.indexes.empty()) { - AddDataTableIndex(*storage, columns, bfk.info.fk_keys, IndexConstraintType::FOREIGN); - } else { - AddDataTableIndex(*storage, columns, bfk.info.fk_keys, IndexConstraintType::FOREIGN, - info.indexes[indexes_idx++]); - } - } - } - } - } -} - -unique_ptr DuckTableEntry::GetStatistics(ClientContext &context, column_t column_id) { - if (column_id == COLUMN_IDENTIFIER_ROW_ID) { - return nullptr; - } - auto &column = columns.GetColumn(LogicalIndex(column_id)); - if (column.Generated()) { - return nullptr; - } - return storage->GetStatistics(context, column.StorageOid()); -} - -unique_ptr DuckTableEntry::AlterEntry(ClientContext &context, AlterInfo &info) { - D_ASSERT(!internal); - if (info.type != AlterType::ALTER_TABLE) { - throw CatalogException("Can only modify table with ALTER TABLE statement"); - } - auto &table_info = info.Cast(); - switch (table_info.alter_table_type) { - case AlterTableType::RENAME_COLUMN: { - auto &rename_info = table_info.Cast(); - return RenameColumn(context, rename_info); - } - case AlterTableType::RENAME_TABLE: { - auto &rename_info = table_info.Cast(); - auto copied_table = Copy(context); - copied_table->name = rename_info.new_table_name; - storage->info->table = rename_info.new_table_name; - return copied_table; - } - case AlterTableType::ADD_COLUMN: { - auto &add_info = table_info.Cast(); - return AddColumn(context, add_info); - } - case AlterTableType::REMOVE_COLUMN: { - auto &remove_info = table_info.Cast(); - return RemoveColumn(context, remove_info); - } - case AlterTableType::SET_DEFAULT: { - auto &set_default_info = table_info.Cast(); - return 
SetDefault(context, set_default_info); - } - case AlterTableType::ALTER_COLUMN_TYPE: { - auto &change_type_info = table_info.Cast(); - return ChangeColumnType(context, change_type_info); - } - case AlterTableType::FOREIGN_KEY_CONSTRAINT: { - auto &foreign_key_constraint_info = table_info.Cast(); - if (foreign_key_constraint_info.type == AlterForeignKeyType::AFT_ADD) { - return AddForeignKeyConstraint(context, foreign_key_constraint_info); - } else { - return DropForeignKeyConstraint(context, foreign_key_constraint_info); - } - } - case AlterTableType::SET_NOT_NULL: { - auto &set_not_null_info = table_info.Cast(); - return SetNotNull(context, set_not_null_info); - } - case AlterTableType::DROP_NOT_NULL: { - auto &drop_not_null_info = table_info.Cast(); - return DropNotNull(context, drop_not_null_info); - } - default: - throw InternalException("Unrecognized alter table type!"); - } -} - -void DuckTableEntry::UndoAlter(ClientContext &context, AlterInfo &info) { - D_ASSERT(!internal); - D_ASSERT(info.type == AlterType::ALTER_TABLE); - auto &table_info = info.Cast(); - switch (table_info.alter_table_type) { - case AlterTableType::RENAME_TABLE: { - storage->info->table = this->name; - break; - default: - break; - } - } -} - -static void RenameExpression(ParsedExpression &expr, RenameColumnInfo &info) { - if (expr.type == ExpressionType::COLUMN_REF) { - auto &colref = expr.Cast(); - if (colref.column_names.back() == info.old_name) { - colref.column_names.back() = info.new_name; - } - } - ParsedExpressionIterator::EnumerateChildren( - expr, [&](const ParsedExpression &child) { RenameExpression((ParsedExpression &)child, info); }); -} - -unique_ptr DuckTableEntry::RenameColumn(ClientContext &context, RenameColumnInfo &info) { - auto rename_idx = GetColumnIndex(info.old_name); - if (rename_idx.index == COLUMN_IDENTIFIER_ROW_ID) { - throw CatalogException("Cannot rename rowid column"); - } - auto create_info = make_uniq(schema, name); - create_info->temporary = temporary; - 
for (auto &col : columns.Logical()) { - auto copy = col.Copy(); - if (rename_idx == col.Logical()) { - copy.SetName(info.new_name); - } - if (col.Generated() && column_dependency_manager.IsDependencyOf(col.Logical(), rename_idx)) { - RenameExpression(copy.GeneratedExpressionMutable(), info); - } - create_info->columns.AddColumn(std::move(copy)); - } - for (idx_t c_idx = 0; c_idx < constraints.size(); c_idx++) { - auto copy = constraints[c_idx]->Copy(); - switch (copy->type) { - case ConstraintType::NOT_NULL: - // NOT NULL constraint: no adjustments necessary - break; - case ConstraintType::CHECK: { - // CHECK constraint: need to rename column references that refer to the renamed column - auto &check = copy->Cast(); - RenameExpression(*check.expression, info); - break; - } - case ConstraintType::UNIQUE: { - // UNIQUE constraint: possibly need to rename columns - auto &unique = copy->Cast(); - for (idx_t i = 0; i < unique.columns.size(); i++) { - if (unique.columns[i] == info.old_name) { - unique.columns[i] = info.new_name; - } - } - break; - } - case ConstraintType::FOREIGN_KEY: { - // FOREIGN KEY constraint: possibly need to rename columns - auto &fk = copy->Cast(); - vector columns = fk.pk_columns; - if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { - columns = fk.fk_columns; - } else if (fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - for (idx_t i = 0; i < fk.fk_columns.size(); i++) { - columns.push_back(fk.fk_columns[i]); - } - } - for (idx_t i = 0; i < columns.size(); i++) { - if (columns[i] == info.old_name) { - throw CatalogException( - "Cannot rename column \"%s\" because this is involved in the foreign key constraint", - info.old_name); - } - } - break; - } - default: - throw InternalException("Unsupported constraint for entry!"); - } - create_info->constraints.push_back(std::move(copy)); - } - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); - 
return make_uniq(catalog, schema, *bound_create_info, storage); -} - -unique_ptr DuckTableEntry::AddColumn(ClientContext &context, AddColumnInfo &info) { - auto col_name = info.new_column.GetName(); - - // We're checking for the opposite condition (ADD COLUMN IF _NOT_ EXISTS ...). - if (info.if_column_not_exists && ColumnExists(col_name)) { - return nullptr; - } - - auto create_info = make_uniq(schema, name); - create_info->temporary = temporary; - - for (auto &col : columns.Logical()) { - create_info->columns.AddColumn(col.Copy()); - } - for (auto &constraint : constraints) { - create_info->constraints.push_back(constraint->Copy()); - } - Binder::BindLogicalType(context, info.new_column.TypeMutable(), &catalog, schema.name); - info.new_column.SetOid(columns.LogicalColumnCount()); - info.new_column.SetStorageOid(columns.PhysicalColumnCount()); - auto col = info.new_column.Copy(); - - create_info->columns.AddColumn(std::move(col)); - - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); - auto new_storage = - make_shared(context, *storage, info.new_column, *bound_create_info->bound_defaults.back()); - return make_uniq(catalog, schema, *bound_create_info, new_storage); -} - -void DuckTableEntry::UpdateConstraintsOnColumnDrop(const LogicalIndex &removed_index, - const vector &adjusted_indices, - const RemoveColumnInfo &info, CreateTableInfo &create_info, - bool is_generated) { - // handle constraints for the new table - D_ASSERT(constraints.size() == bound_constraints.size()); - - for (idx_t constr_idx = 0; constr_idx < constraints.size(); constr_idx++) { - auto &constraint = constraints[constr_idx]; - auto &bound_constraint = bound_constraints[constr_idx]; - switch (constraint->type) { - case ConstraintType::NOT_NULL: { - auto ¬_null_constraint = bound_constraint->Cast(); - auto not_null_index = columns.PhysicalToLogical(not_null_constraint.index); - if (not_null_index != removed_index) { - // 
the constraint is not about this column: we need to copy it - // we might need to shift the index back by one though, to account for the removed column - auto new_index = adjusted_indices[not_null_index.index]; - create_info.constraints.push_back(make_uniq(new_index)); - } - break; - } - case ConstraintType::CHECK: { - // Generated columns can not be part of an index - // CHECK constraint - auto &bound_check = bound_constraint->Cast(); - // check if the removed column is part of the check constraint - if (is_generated) { - // generated columns can not be referenced by constraints, we can just add the constraint back - create_info.constraints.push_back(constraint->Copy()); - break; - } - auto physical_index = columns.LogicalToPhysical(removed_index); - if (bound_check.bound_columns.find(physical_index) != bound_check.bound_columns.end()) { - if (bound_check.bound_columns.size() > 1) { - // CHECK constraint that concerns mult - throw CatalogException( - "Cannot drop column \"%s\" because there is a CHECK constraint that depends on it", - info.removed_column); - } else { - // CHECK constraint that ONLY concerns this column, strip the constraint - } - } else { - // check constraint does not concern the removed column: simply re-add it - create_info.constraints.push_back(constraint->Copy()); - } - break; - } - case ConstraintType::UNIQUE: { - auto copy = constraint->Copy(); - auto &unique = copy->Cast(); - if (unique.index.index != DConstants::INVALID_INDEX) { - if (unique.index == removed_index) { - throw CatalogException( - "Cannot drop column \"%s\" because there is a UNIQUE constraint that depends on it", - info.removed_column); - } - unique.index = adjusted_indices[unique.index.index]; - } - create_info.constraints.push_back(std::move(copy)); - break; - } - case ConstraintType::FOREIGN_KEY: { - auto copy = constraint->Copy(); - auto &fk = copy->Cast(); - vector columns = fk.pk_columns; - if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { - columns = 
fk.fk_columns; - } else if (fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - for (idx_t i = 0; i < fk.fk_columns.size(); i++) { - columns.push_back(fk.fk_columns[i]); - } - } - for (idx_t i = 0; i < columns.size(); i++) { - if (columns[i] == info.removed_column) { - throw CatalogException( - "Cannot drop column \"%s\" because there is a FOREIGN KEY constraint that depends on it", - info.removed_column); - } - } - create_info.constraints.push_back(std::move(copy)); - break; - } - default: - throw InternalException("Unsupported constraint for entry!"); - } - } -} - -unique_ptr DuckTableEntry::RemoveColumn(ClientContext &context, RemoveColumnInfo &info) { - auto removed_index = GetColumnIndex(info.removed_column, info.if_column_exists); - if (!removed_index.IsValid()) { - if (!info.if_column_exists) { - throw CatalogException("Cannot drop column: rowid column cannot be dropped"); - } - return nullptr; - } - - auto create_info = make_uniq(schema, name); - create_info->temporary = temporary; - - logical_index_set_t removed_columns; - if (column_dependency_manager.HasDependents(removed_index)) { - removed_columns = column_dependency_manager.GetDependents(removed_index); - } - if (!removed_columns.empty() && !info.cascade) { - throw CatalogException("Cannot drop column: column is a dependency of 1 or more generated column(s)"); - } - bool dropped_column_is_generated = false; - for (auto &col : columns.Logical()) { - if (col.Logical() == removed_index || removed_columns.count(col.Logical())) { - if (col.Generated()) { - dropped_column_is_generated = true; - } - continue; - } - create_info->columns.AddColumn(col.Copy()); - } - if (create_info->columns.empty()) { - throw CatalogException("Cannot drop column: table only has one column remaining!"); - } - auto adjusted_indices = column_dependency_manager.RemoveColumn(removed_index, columns.LogicalColumnCount()); - - UpdateConstraintsOnColumnDrop(removed_index, adjusted_indices, info, *create_info, 
dropped_column_is_generated); - - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); - if (columns.GetColumn(LogicalIndex(removed_index)).Generated()) { - return make_uniq(catalog, schema, *bound_create_info, storage); - } - auto new_storage = - make_shared(context, *storage, columns.LogicalToPhysical(LogicalIndex(removed_index)).index); - return make_uniq(catalog, schema, *bound_create_info, new_storage); -} - -unique_ptr DuckTableEntry::SetDefault(ClientContext &context, SetDefaultInfo &info) { - auto create_info = make_uniq(schema, name); - auto default_idx = GetColumnIndex(info.column_name); - if (default_idx.index == COLUMN_IDENTIFIER_ROW_ID) { - throw CatalogException("Cannot SET DEFAULT for rowid column"); - } - - // Copy all the columns, changing the value of the one that was specified by 'column_name' - for (auto &col : columns.Logical()) { - auto copy = col.Copy(); - if (default_idx == col.Logical()) { - // set the default value of this column - if (copy.Generated()) { - throw BinderException("Cannot SET DEFAULT for generated column \"%s\"", col.Name()); - } - copy.SetDefaultValue(info.expression ? 
info.expression->Copy() : nullptr); - } - create_info->columns.AddColumn(std::move(copy)); - } - // Copy all the constraints - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - create_info->constraints.push_back(std::move(constraint)); - } - - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); - return make_uniq(catalog, schema, *bound_create_info, storage); -} - -unique_ptr DuckTableEntry::SetNotNull(ClientContext &context, SetNotNullInfo &info) { - - auto create_info = make_uniq(schema, name); - create_info->columns = columns.Copy(); - - auto not_null_idx = GetColumnIndex(info.column_name); - if (columns.GetColumn(LogicalIndex(not_null_idx)).Generated()) { - throw BinderException("Unsupported constraint for generated column!"); - } - bool has_not_null = false; - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - if (constraint->type == ConstraintType::NOT_NULL) { - auto ¬_null = constraint->Cast(); - if (not_null.index == not_null_idx) { - has_not_null = true; - } - } - create_info->constraints.push_back(std::move(constraint)); - } - if (!has_not_null) { - create_info->constraints.push_back(make_uniq(not_null_idx)); - } - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); - - // Early return - if (has_not_null) { - return make_uniq(catalog, schema, *bound_create_info, storage); - } - - // Return with new storage info. Note that we need the bound column index here. 
- auto new_storage = make_shared( - context, *storage, make_uniq(columns.LogicalToPhysical(LogicalIndex(not_null_idx)))); - return make_uniq(catalog, schema, *bound_create_info, new_storage); -} - -unique_ptr DuckTableEntry::DropNotNull(ClientContext &context, DropNotNullInfo &info) { - auto create_info = make_uniq(schema, name); - create_info->columns = columns.Copy(); - - auto not_null_idx = GetColumnIndex(info.column_name); - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - // Skip/drop not_null - if (constraint->type == ConstraintType::NOT_NULL) { - auto ¬_null = constraint->Cast(); - if (not_null.index == not_null_idx) { - continue; - } - } - create_info->constraints.push_back(std::move(constraint)); - } - - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); - return make_uniq(catalog, schema, *bound_create_info, storage); -} - -unique_ptr DuckTableEntry::ChangeColumnType(ClientContext &context, ChangeColumnTypeInfo &info) { - Binder::BindLogicalType(context, info.target_type, &catalog, schema.name); - auto change_idx = GetColumnIndex(info.column_name); - auto create_info = make_uniq(schema, name); - create_info->temporary = temporary; - - for (auto &col : columns.Logical()) { - auto copy = col.Copy(); - if (change_idx == col.Logical()) { - // set the type of this column - if (copy.Generated()) { - throw NotImplementedException("Changing types of generated columns is not supported yet"); - } - copy.SetType(info.target_type); - } - // TODO: check if the generated_expression breaks, only delete it if it does - if (copy.Generated() && column_dependency_manager.IsDependencyOf(col.Logical(), change_idx)) { - throw BinderException( - "This column is referenced by the generated column \"%s\", so its type can not be changed", - copy.Name()); - } - create_info->columns.AddColumn(std::move(copy)); - } - - for (idx_t i = 0; i < constraints.size(); i++) 
{ - auto constraint = constraints[i]->Copy(); - switch (constraint->type) { - case ConstraintType::CHECK: { - auto &bound_check = bound_constraints[i]->Cast(); - auto physical_index = columns.LogicalToPhysical(change_idx); - if (bound_check.bound_columns.find(physical_index) != bound_check.bound_columns.end()) { - throw BinderException("Cannot change the type of a column that has a CHECK constraint specified"); - } - break; - } - case ConstraintType::NOT_NULL: - break; - case ConstraintType::UNIQUE: { - auto &bound_unique = bound_constraints[i]->Cast(); - if (bound_unique.key_set.find(change_idx) != bound_unique.key_set.end()) { - throw BinderException( - "Cannot change the type of a column that has a UNIQUE or PRIMARY KEY constraint specified"); - } - break; - } - case ConstraintType::FOREIGN_KEY: { - auto &bfk = bound_constraints[i]->Cast(); - auto key_set = bfk.pk_key_set; - if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { - key_set = bfk.fk_key_set; - } else if (bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - for (idx_t i = 0; i < bfk.info.fk_keys.size(); i++) { - key_set.insert(bfk.info.fk_keys[i]); - } - } - if (key_set.find(columns.LogicalToPhysical(change_idx)) != key_set.end()) { - throw BinderException("Cannot change the type of a column that has a FOREIGN KEY constraint specified"); - } - break; - } - default: - throw InternalException("Unsupported constraint for entry!"); - } - create_info->constraints.push_back(std::move(constraint)); - } - - auto binder = Binder::CreateBinder(context); - // bind the specified expression - vector bound_columns; - AlterBinder expr_binder(*binder, context, *this, bound_columns, info.target_type); - auto expression = info.expression->Copy(); - auto bound_expression = expr_binder.Bind(expression); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); - vector storage_oids; - for (idx_t i = 0; i < bound_columns.size(); i++) { - 
storage_oids.push_back(columns.LogicalToPhysical(bound_columns[i]).index); - } - if (storage_oids.empty()) { - storage_oids.push_back(COLUMN_IDENTIFIER_ROW_ID); - } - - auto new_storage = - make_shared(context, *storage, columns.LogicalToPhysical(LogicalIndex(change_idx)).index, - info.target_type, std::move(storage_oids), *bound_expression); - auto result = make_uniq(catalog, schema, *bound_create_info, new_storage); - return std::move(result); -} - -unique_ptr DuckTableEntry::AddForeignKeyConstraint(ClientContext &context, AlterForeignKeyInfo &info) { - D_ASSERT(info.type == AlterForeignKeyType::AFT_ADD); - auto create_info = make_uniq(schema, name); - create_info->temporary = temporary; - - create_info->columns = columns.Copy(); - for (idx_t i = 0; i < constraints.size(); i++) { - create_info->constraints.push_back(constraints[i]->Copy()); - } - ForeignKeyInfo fk_info; - fk_info.type = ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE; - fk_info.schema = info.schema; - fk_info.table = info.fk_table; - fk_info.pk_keys = info.pk_keys; - fk_info.fk_keys = info.fk_keys; - create_info->constraints.push_back( - make_uniq(info.pk_columns, info.fk_columns, std::move(fk_info))); - - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); - - return make_uniq(catalog, schema, *bound_create_info, storage); -} - -unique_ptr DuckTableEntry::DropForeignKeyConstraint(ClientContext &context, AlterForeignKeyInfo &info) { - D_ASSERT(info.type == AlterForeignKeyType::AFT_DELETE); - auto create_info = make_uniq(schema, name); - create_info->temporary = temporary; - - create_info->columns = columns.Copy(); - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - if (constraint->type == ConstraintType::FOREIGN_KEY) { - ForeignKeyConstraint &fk = constraint->Cast(); - if (fk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE && fk.info.table == info.fk_table) { - continue; - } - } 
- create_info->constraints.push_back(std::move(constraint)); - } - - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); - - return make_uniq(catalog, schema, *bound_create_info, storage); -} - -unique_ptr DuckTableEntry::Copy(ClientContext &context) const { - auto create_info = make_uniq(schema, name); - create_info->columns = columns.Copy(); - - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - create_info->constraints.push_back(std::move(constraint)); - } - - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); - return make_uniq(catalog, schema, *bound_create_info, storage); -} - -void DuckTableEntry::SetAsRoot() { - storage->SetAsRoot(); - storage->info->table = name; -} - -void DuckTableEntry::CommitAlter(string &column_name) { - D_ASSERT(!column_name.empty()); - idx_t removed_index = DConstants::INVALID_INDEX; - for (auto &col : columns.Logical()) { - if (col.Name() == column_name) { - // No need to alter storage, removed column is generated column - if (col.Generated()) { - return; - } - removed_index = col.Oid(); - break; - } - } - D_ASSERT(removed_index != DConstants::INVALID_INDEX); - storage->CommitDropColumn(columns.LogicalToPhysical(LogicalIndex(removed_index)).index); -} - -void DuckTableEntry::CommitDrop() { - storage->CommitDropTable(); -} - -DataTable &DuckTableEntry::GetStorage() { - return *storage; -} - -const vector> &DuckTableEntry::GetBoundConstraints() { - return bound_constraints; -} - -TableFunction DuckTableEntry::GetScanFunction(ClientContext &context, unique_ptr &bind_data) { - bind_data = make_uniq(*this); - return TableScanFunction::GetFunction(); -} - -vector DuckTableEntry::GetColumnSegmentInfo() { - return storage->GetColumnSegmentInfo(); -} - -TableStorageInfo DuckTableEntry::GetStorageInfo(ClientContext &context) { - TableStorageInfo result; 
- result.cardinality = storage->info->cardinality.load(); - storage->info->indexes.Scan([&](Index &index) { - IndexInfo info; - info.is_primary = index.IsPrimary(); - info.is_unique = index.IsUnique() || info.is_primary; - info.is_foreign = index.IsForeign(); - info.column_set = index.column_id_set; - result.index_info.push_back(std::move(info)); - return false; - }); - return result; -} - -} // namespace duckdb - - - -namespace duckdb { - -IndexCatalogEntry::IndexCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &info) - : StandardEntry(CatalogType::INDEX_ENTRY, schema, catalog, info.index_name), index(nullptr), sql(info.sql) { - this->temporary = info.temporary; -} - -string IndexCatalogEntry::ToSQL() const { - if (sql.empty()) { - return sql; - } - if (sql[sql.size() - 1] != ';') { - return sql + ";"; - } - return sql; -} - -unique_ptr IndexCatalogEntry::GetInfo() const { - auto result = make_uniq(); - result->schema = GetSchemaName(); - result->table = GetTableName(); - result->index_name = name; - result->sql = sql; - result->index_type = index->type; - result->constraint_type = index->constraint_type; - for (auto &expr : expressions) { - result->expressions.push_back(expr->Copy()); - } - for (auto &expr : parsed_expressions) { - result->parsed_expressions.push_back(expr->Copy()); - } - result->column_ids = index->column_ids; - result->temporary = temporary; - return std::move(result); -} - -} // namespace duckdb - - - - -namespace duckdb { - -MacroCatalogEntry::MacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info) - : FunctionEntry( - (info.function->type == MacroType::SCALAR_MACRO ? 
CatalogType::MACRO_ENTRY : CatalogType::TABLE_MACRO_ENTRY), - catalog, schema, info), - function(std::move(info.function)) { - this->temporary = info.temporary; - this->internal = info.internal; -} - -ScalarMacroCatalogEntry::ScalarMacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info) - : MacroCatalogEntry(catalog, schema, info) { -} - -TableMacroCatalogEntry::TableMacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info) - : MacroCatalogEntry(catalog, schema, info) { -} - -unique_ptr MacroCatalogEntry::GetInfo() const { - auto info = make_uniq(type); - info->catalog = catalog.GetName(); - info->schema = schema.name; - info->name = name; - info->function = function->Copy(); - return std::move(info); -} - -} // namespace duckdb - - - -namespace duckdb { - -PragmaFunctionCatalogEntry::PragmaFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, - CreatePragmaFunctionInfo &info) - : FunctionEntry(CatalogType::PRAGMA_FUNCTION_ENTRY, catalog, schema, info), functions(std::move(info.functions)) { -} - -} // namespace duckdb - - - -namespace duckdb { - -ScalarFunctionCatalogEntry::ScalarFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, - CreateScalarFunctionInfo &info) - : FunctionEntry(CatalogType::SCALAR_FUNCTION_ENTRY, catalog, schema, info), functions(info.functions) { -} - -unique_ptr ScalarFunctionCatalogEntry::AlterEntry(ClientContext &context, AlterInfo &info) { - if (info.type != AlterType::ALTER_SCALAR_FUNCTION) { - throw InternalException("Attempting to alter ScalarFunctionCatalogEntry with unsupported alter type"); - } - auto &function_info = info.Cast(); - if (function_info.alter_scalar_function_type != AlterScalarFunctionType::ADD_FUNCTION_OVERLOADS) { - throw InternalException( - "Attempting to alter ScalarFunctionCatalogEntry with unsupported alter scalar function type"); - } - auto &add_overloads = function_info.Cast(); - - ScalarFunctionSet new_set = functions; - 
if (!new_set.MergeFunctionSet(add_overloads.new_overloads)) { - throw BinderException("Failed to add new function overloads to function \"%s\": function already exists", name); - } - CreateScalarFunctionInfo new_info(std::move(new_set)); - return make_uniq(catalog, schema, new_info); -} - -} // namespace duckdb - - - - - - - - -#include - -namespace duckdb { - -SchemaCatalogEntry::SchemaCatalogEntry(Catalog &catalog, string name_p, bool internal) - : InCatalogEntry(CatalogType::SCHEMA_ENTRY, catalog, std::move(name_p)) { - this->internal = internal; -} - -CatalogTransaction SchemaCatalogEntry::GetCatalogTransaction(ClientContext &context) { - return CatalogTransaction(catalog, context); -} - -SimilarCatalogEntry SchemaCatalogEntry::GetSimilarEntry(CatalogTransaction transaction, CatalogType type, - const string &name) { - SimilarCatalogEntry result; - Scan(transaction.GetContext(), type, [&](CatalogEntry &entry) { - auto ldist = StringUtil::SimilarityScore(entry.name, name); - if (ldist < result.distance) { - result.distance = ldist; - result.name = entry.name; - } - }); - return result; -} - -unique_ptr SchemaCatalogEntry::GetInfo() const { - auto result = make_uniq(); - result->schema = name; - return std::move(result); -} - -string SchemaCatalogEntry::ToSQL() const { - std::stringstream ss; - ss << "CREATE SCHEMA " << name << ";"; - return ss.str(); -} - -} // namespace duckdb - - - - - - - -#include -#include - -namespace duckdb { - -SequenceCatalogEntry::SequenceCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateSequenceInfo &info) - : StandardEntry(CatalogType::SEQUENCE_ENTRY, schema, catalog, info.name), usage_count(info.usage_count), - counter(info.start_value), increment(info.increment), start_value(info.start_value), min_value(info.min_value), - max_value(info.max_value), cycle(info.cycle) { - this->temporary = info.temporary; -} - -unique_ptr SequenceCatalogEntry::GetInfo() const { - auto result = make_uniq(); - result->schema = 
schema.name; - result->name = name; - result->usage_count = usage_count; - result->increment = increment; - result->min_value = min_value; - result->max_value = max_value; - result->start_value = counter; - result->cycle = cycle; - return std::move(result); -} - -string SequenceCatalogEntry::ToSQL() const { - std::stringstream ss; - ss << "CREATE SEQUENCE "; - ss << name; - ss << " INCREMENT BY " << increment; - ss << " MINVALUE " << min_value; - ss << " MAXVALUE " << max_value; - ss << " START " << counter; - ss << " " << (cycle ? "CYCLE" : "NO CYCLE") << ";"; - return ss.str(); -} -} // namespace duckdb - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -TableCatalogEntry::TableCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableInfo &info) - : StandardEntry(CatalogType::TABLE_ENTRY, schema, catalog, info.table), columns(std::move(info.columns)), - constraints(std::move(info.constraints)) { - this->temporary = info.temporary; -} - -bool TableCatalogEntry::HasGeneratedColumns() const { - return columns.LogicalColumnCount() != columns.PhysicalColumnCount(); -} - -LogicalIndex TableCatalogEntry::GetColumnIndex(string &column_name, bool if_exists) { - auto entry = columns.GetColumnIndex(column_name); - if (!entry.IsValid()) { - if (if_exists) { - return entry; - } - throw BinderException("Table \"%s\" does not have a column with name \"%s\"", name, column_name); - } - return entry; -} - -bool TableCatalogEntry::ColumnExists(const string &name) { - return columns.ColumnExists(name); -} - -const ColumnDefinition &TableCatalogEntry::GetColumn(const string &name) { - return columns.GetColumn(name); -} - -vector TableCatalogEntry::GetTypes() { - vector types; - for (auto &col : columns.Physical()) { - types.push_back(col.Type()); - } - return types; -} - -unique_ptr TableCatalogEntry::GetInfo() const { - auto result = make_uniq(); - result->catalog = catalog.GetName(); - result->schema = schema.name; - result->table = name; - result->columns 
= columns.Copy(); - result->constraints.reserve(constraints.size()); - std::for_each(constraints.begin(), constraints.end(), - [&result](const unique_ptr &c) { result->constraints.emplace_back(c->Copy()); }); - return std::move(result); -} - -string TableCatalogEntry::ColumnsToSQL(const ColumnList &columns, const vector> &constraints) { - std::stringstream ss; - - ss << "("; - - // find all columns that have NOT NULL specified, but are NOT primary key columns - logical_index_set_t not_null_columns; - logical_index_set_t unique_columns; - logical_index_set_t pk_columns; - unordered_set multi_key_pks; - vector extra_constraints; - for (auto &constraint : constraints) { - if (constraint->type == ConstraintType::NOT_NULL) { - auto ¬_null = constraint->Cast(); - not_null_columns.insert(not_null.index); - } else if (constraint->type == ConstraintType::UNIQUE) { - auto &pk = constraint->Cast(); - vector constraint_columns = pk.columns; - if (pk.index.index != DConstants::INVALID_INDEX) { - // no columns specified: single column constraint - if (pk.is_primary_key) { - pk_columns.insert(pk.index); - } else { - unique_columns.insert(pk.index); - } - } else { - // multi-column constraint, this constraint needs to go at the end after all columns - if (pk.is_primary_key) { - // multi key pk column: insert set of columns into multi_key_pks - for (auto &col : pk.columns) { - multi_key_pks.insert(col); - } - } - extra_constraints.push_back(constraint->ToString()); - } - } else if (constraint->type == ConstraintType::FOREIGN_KEY) { - auto &fk = constraint->Cast(); - if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || - fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - extra_constraints.push_back(constraint->ToString()); - } - } else { - extra_constraints.push_back(constraint->ToString()); - } - } - - for (auto &column : columns.Logical()) { - if (column.Oid() > 0) { - ss << ", "; - } - ss << KeywordHelper::WriteOptionallyQuoted(column.Name()) << " "; - 
ss << column.Type().ToString(); - bool not_null = not_null_columns.find(column.Logical()) != not_null_columns.end(); - bool is_single_key_pk = pk_columns.find(column.Logical()) != pk_columns.end(); - bool is_multi_key_pk = multi_key_pks.find(column.Name()) != multi_key_pks.end(); - bool is_unique = unique_columns.find(column.Logical()) != unique_columns.end(); - if (not_null && !is_single_key_pk && !is_multi_key_pk) { - // NOT NULL but not a primary key column - ss << " NOT NULL"; - } - if (is_single_key_pk) { - // single column pk: insert constraint here - ss << " PRIMARY KEY"; - } - if (is_unique) { - // single column unique: insert constraint here - ss << " UNIQUE"; - } - if (column.Generated()) { - ss << " GENERATED ALWAYS AS(" << column.GeneratedExpression().ToString() << ")"; - } else if (column.DefaultValue()) { - ss << " DEFAULT(" << column.DefaultValue()->ToString() << ")"; - } - } - // print any extra constraints that still need to be printed - for (auto &extra_constraint : extra_constraints) { - ss << ", "; - ss << extra_constraint; - } - - ss << ")"; - return ss.str(); -} - -string TableCatalogEntry::ToSQL() const { - std::stringstream ss; - - ss << "CREATE TABLE "; - - if (schema.name != DEFAULT_SCHEMA) { - ss << KeywordHelper::WriteOptionallyQuoted(schema.name) << "."; - } - - ss << KeywordHelper::WriteOptionallyQuoted(name); - ss << ColumnsToSQL(columns, constraints); - ss << ";"; - - return ss.str(); -} - -const ColumnList &TableCatalogEntry::GetColumns() const { - return columns; -} - -const ColumnDefinition &TableCatalogEntry::GetColumn(LogicalIndex idx) { - return columns.GetColumn(idx); -} - -const vector> &TableCatalogEntry::GetConstraints() { - return constraints; -} - -// LCOV_EXCL_START -DataTable &TableCatalogEntry::GetStorage() { - throw InternalException("Calling GetStorage on a TableCatalogEntry that is not a DuckTableEntry"); -} - -const vector> &TableCatalogEntry::GetBoundConstraints() { - throw InternalException("Calling 
GetBoundConstraints on a TableCatalogEntry that is not a DuckTableEntry"); -} - -// LCOV_EXCL_STOP - -static void BindExtraColumns(TableCatalogEntry &table, LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, - physical_index_set_t &bound_columns) { - if (bound_columns.size() <= 1) { - return; - } - idx_t found_column_count = 0; - physical_index_set_t found_columns; - for (idx_t i = 0; i < update.columns.size(); i++) { - if (bound_columns.find(update.columns[i]) != bound_columns.end()) { - // this column is referenced in the CHECK constraint - found_column_count++; - found_columns.insert(update.columns[i]); - } - } - if (found_column_count > 0 && found_column_count != bound_columns.size()) { - // columns in this CHECK constraint were referenced, but not all were part of the UPDATE - // add them to the scan and update set - for (auto &check_column_id : bound_columns) { - if (found_columns.find(check_column_id) != found_columns.end()) { - // column is already projected - continue; - } - // column is not projected yet: project it by adding the clause "i=i" to the set of updated columns - auto &column = table.GetColumns().GetColumn(check_column_id); - update.expressions.push_back(make_uniq( - column.Type(), ColumnBinding(proj.table_index, proj.expressions.size()))); - proj.expressions.push_back(make_uniq( - column.Type(), ColumnBinding(get.table_index, get.column_ids.size()))); - get.column_ids.push_back(check_column_id.index); - update.columns.push_back(check_column_id); - } - } -} - -static bool TypeSupportsRegularUpdate(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::LIST: - case LogicalTypeId::MAP: - case LogicalTypeId::UNION: - // lists and maps and unions don't support updates directly - return false; - case LogicalTypeId::STRUCT: { - auto &child_types = StructType::GetChildTypes(type); - for (auto &entry : child_types) { - if (!TypeSupportsRegularUpdate(entry.second)) { - return false; - } - } - return true; - } - default: - 
return true; - } -} - -vector TableCatalogEntry::GetColumnSegmentInfo() { - return {}; -} - -void TableCatalogEntry::BindUpdateConstraints(LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, - ClientContext &context) { - // check the constraints and indexes of the table to see if we need to project any additional columns - // we do this for indexes with multiple columns and CHECK constraints in the UPDATE clause - // suppose we have a constraint CHECK(i + j < 10); now we need both i and j to check the constraint - // if we are only updating one of the two columns we add the other one to the UPDATE set - // with a "useless" update (i.e. i=i) so we can verify that the CHECK constraint is not violated - for (auto &constraint : GetBoundConstraints()) { - if (constraint->type == ConstraintType::CHECK) { - auto &check = constraint->Cast(); - // check constraint! check if we need to add any extra columns to the UPDATE clause - BindExtraColumns(*this, get, proj, update, check.bound_columns); - } - } - if (update.return_chunk) { - physical_index_set_t all_columns; - for (auto &column : GetColumns().Physical()) { - all_columns.insert(column.Physical()); - } - BindExtraColumns(*this, get, proj, update, all_columns); - } - // for index updates we always turn any update into an insert and a delete - // we thus need all the columns to be available, hence we check if the update touches any index columns - // If the returning keyword is used, we need access to the whole row in case the user requests it. - // Therefore switch the update to a delete and insert. 
- update.update_is_del_and_insert = false; - TableStorageInfo table_storage_info = GetStorageInfo(context); - for (auto index : table_storage_info.index_info) { - for (auto &column : update.columns) { - if (index.column_set.find(column.index) != index.column_set.end()) { - update.update_is_del_and_insert = true; - break; - } - } - }; - - // we also convert any updates on LIST columns into delete + insert - for (auto &col_index : update.columns) { - auto &column = GetColumns().GetColumn(col_index); - if (!TypeSupportsRegularUpdate(column.Type())) { - update.update_is_del_and_insert = true; - break; - } - } - - if (update.update_is_del_and_insert) { - // the update updates a column required by an index or requires returning the updated rows, - // push projections for all columns - physical_index_set_t all_columns; - for (auto &column : GetColumns().Physical()) { - all_columns.insert(column.Physical()); - } - BindExtraColumns(*this, get, proj, update, all_columns); - } -} - -} // namespace duckdb - - - -namespace duckdb { - -TableFunctionCatalogEntry::TableFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, - CreateTableFunctionInfo &info) - : FunctionEntry(CatalogType::TABLE_FUNCTION_ENTRY, catalog, schema, info), functions(std::move(info.functions)) { - D_ASSERT(this->functions.Size() > 0); -} - -unique_ptr TableFunctionCatalogEntry::AlterEntry(ClientContext &context, AlterInfo &info) { - if (info.type != AlterType::ALTER_TABLE_FUNCTION) { - throw InternalException("Attempting to alter TableFunctionCatalogEntry with unsupported alter type"); - } - auto &function_info = info.Cast(); - if (function_info.alter_table_function_type != AlterTableFunctionType::ADD_FUNCTION_OVERLOADS) { - throw InternalException( - "Attempting to alter TableFunctionCatalogEntry with unsupported alter table function type"); - } - auto &add_overloads = function_info.Cast(); - - TableFunctionSet new_set = functions; - if (!new_set.MergeFunctionSet(add_overloads.new_overloads)) { 
- throw BinderException("Failed to add new function overloads to function \"%s\": function already exists", name); - } - CreateTableFunctionInfo new_info(std::move(new_set)); - return make_uniq(catalog, schema, new_info); -} - -} // namespace duckdb - - - - - - -#include -#include - -namespace duckdb { - -TypeCatalogEntry::TypeCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTypeInfo &info) - : StandardEntry(CatalogType::TYPE_ENTRY, schema, catalog, info.name), user_type(info.type) { - this->temporary = info.temporary; - this->internal = info.internal; -} - -unique_ptr TypeCatalogEntry::GetInfo() const { - auto result = make_uniq(); - result->catalog = catalog.GetName(); - result->schema = schema.name; - result->name = name; - result->type = user_type; - return std::move(result); -} - -string TypeCatalogEntry::ToSQL() const { - std::stringstream ss; - switch (user_type.id()) { - case (LogicalTypeId::ENUM): { - auto &values_insert_order = EnumType::GetValuesInsertOrder(user_type); - idx_t size = EnumType::GetSize(user_type); - ss << "CREATE TYPE "; - ss << KeywordHelper::WriteOptionallyQuoted(name); - ss << " AS ENUM ( "; - - for (idx_t i = 0; i < size; i++) { - ss << "'" << values_insert_order.GetValue(i).ToString() << "'"; - if (i != size - 1) { - ss << ", "; - } - } - ss << ");"; - break; - } - default: - throw InternalException("Logical Type can't be used as a User Defined Type"); - } - - return ss.str(); -} - -} // namespace duckdb - - - - - - - - -#include - -namespace duckdb { - -void ViewCatalogEntry::Initialize(CreateViewInfo &info) { - query = std::move(info.query); - this->aliases = info.aliases; - this->types = info.types; - this->temporary = info.temporary; - this->sql = info.sql; - this->internal = info.internal; -} - -ViewCatalogEntry::ViewCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateViewInfo &info) - : StandardEntry(CatalogType::VIEW_ENTRY, schema, catalog, info.view_name) { - Initialize(info); -} - -unique_ptr 
ViewCatalogEntry::GetInfo() const { - auto result = make_uniq(); - result->schema = schema.name; - result->view_name = name; - result->sql = sql; - result->query = unique_ptr_cast(query->Copy()); - result->aliases = aliases; - result->types = types; - return std::move(result); -} - -unique_ptr ViewCatalogEntry::AlterEntry(ClientContext &context, AlterInfo &info) { - D_ASSERT(!internal); - if (info.type != AlterType::ALTER_VIEW) { - throw CatalogException("Can only modify view with ALTER VIEW statement"); - } - auto &view_info = info.Cast(); - switch (view_info.alter_view_type) { - case AlterViewType::RENAME_VIEW: { - auto &rename_info = view_info.Cast(); - auto copied_view = Copy(context); - copied_view->name = rename_info.new_view_name; - return copied_view; - } - default: - throw InternalException("Unrecognized alter view type!"); - } -} - -string ViewCatalogEntry::ToSQL() const { - if (sql.empty()) { - //! Return empty sql with view name so pragma view_tables don't complain - return sql; - } - return sql + "\n;"; -} - -unique_ptr ViewCatalogEntry::Copy(ClientContext &context) const { - D_ASSERT(!internal); - CreateViewInfo create_info(schema, name); - create_info.query = unique_ptr_cast(query->Copy()); - for (idx_t i = 0; i < aliases.size(); i++) { - create_info.aliases.push_back(aliases[i]); - } - for (idx_t i = 0; i < types.size(); i++) { - create_info.types.push_back(types[i]); - } - create_info.temporary = temporary; - create_info.sql = sql; - - return make_uniq(catalog, schema, create_info); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -CatalogEntry::CatalogEntry(CatalogType type, string name_p, idx_t oid) - : oid(oid), type(type), set(nullptr), name(std::move(name_p)), deleted(false), temporary(false), internal(false), - parent(nullptr) { -} - -CatalogEntry::CatalogEntry(CatalogType type, Catalog &catalog, string name_p) - : CatalogEntry(type, std::move(name_p), catalog.ModifyCatalog()) { -} - -CatalogEntry::~CatalogEntry() { -} - -void 
CatalogEntry::SetAsRoot() { -} - -// LCOV_EXCL_START -unique_ptr CatalogEntry::AlterEntry(ClientContext &context, AlterInfo &info) { - throw InternalException("Unsupported alter type for catalog entry!"); -} - -void CatalogEntry::UndoAlter(ClientContext &context, AlterInfo &info) { -} - -unique_ptr CatalogEntry::Copy(ClientContext &context) const { - throw InternalException("Unsupported copy type for catalog entry!"); -} - -unique_ptr CatalogEntry::GetInfo() const { - throw InternalException("Unsupported type for CatalogEntry::GetInfo!"); -} - -string CatalogEntry::ToSQL() const { - throw InternalException("Unsupported catalog type for ToSQL()"); -} - -Catalog &CatalogEntry::ParentCatalog() { - throw InternalException("CatalogEntry::ParentCatalog called on catalog entry without catalog"); -} - -SchemaCatalogEntry &CatalogEntry::ParentSchema() { - throw InternalException("CatalogEntry::ParentSchema called on catalog entry without schema"); -} -// LCOV_EXCL_STOP - -void CatalogEntry::Serialize(Serializer &serializer) const { - const auto info = GetInfo(); - info->Serialize(serializer); -} - -unique_ptr CatalogEntry::Deserialize(Deserializer &deserializer) { - return CreateInfo::Deserialize(deserializer); -} - -void CatalogEntry::Verify(Catalog &catalog_p) { -} - -InCatalogEntry::InCatalogEntry(CatalogType type, Catalog &catalog, string name) - : CatalogEntry(type, catalog, std::move(name)), catalog(catalog) { -} - -InCatalogEntry::~InCatalogEntry() { -} - -void InCatalogEntry::Verify(Catalog &catalog_p) { - D_ASSERT(&catalog_p == &catalog); -} -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -CatalogSearchEntry::CatalogSearchEntry(string catalog_p, string schema_p) - : catalog(std::move(catalog_p)), schema(std::move(schema_p)) { -} - -string CatalogSearchEntry::ToString() const { - if (catalog.empty()) { - return WriteOptionallyQuoted(schema); - } else { - return WriteOptionallyQuoted(catalog) + "." 
+ WriteOptionallyQuoted(schema); - } -} - -string CatalogSearchEntry::WriteOptionallyQuoted(const string &input) { - for (idx_t i = 0; i < input.size(); i++) { - if (input[i] == '.' || input[i] == ',') { - return "\"" + input + "\""; - } - } - return input; -} - -string CatalogSearchEntry::ListToString(const vector &input) { - string result; - for (auto &entry : input) { - if (!result.empty()) { - result += ","; - } - result += entry.ToString(); - } - return result; -} - -CatalogSearchEntry CatalogSearchEntry::ParseInternal(const string &input, idx_t &idx) { - string catalog; - string schema; - string entry; - bool finished = false; -normal: - for (; idx < input.size(); idx++) { - if (input[idx] == '"') { - idx++; - goto quoted; - } else if (input[idx] == '.') { - goto separator; - } else if (input[idx] == ',') { - finished = true; - goto separator; - } - entry += input[idx]; - } - finished = true; - goto separator; -quoted: - //! look for another quote - for (; idx < input.size(); idx++) { - if (input[idx] == '"') { - //! 
unquote - idx++; - goto normal; - } - entry += input[idx]; - } - throw ParserException("Unterminated quote in qualified name!"); -separator: - if (entry.empty()) { - throw ParserException("Unexpected dot - empty CatalogSearchEntry"); - } - if (schema.empty()) { - // if we parse one entry it is the schema - schema = std::move(entry); - } else if (catalog.empty()) { - // if we parse two entries it is [catalog.schema] - catalog = std::move(schema); - schema = std::move(entry); - } else { - throw ParserException("Too many dots - expected [schema] or [catalog.schema] for CatalogSearchEntry"); - } - entry = ""; - idx++; - if (finished) { - goto final; - } - goto normal; -final: - if (schema.empty()) { - throw ParserException("Unexpected end of entry - empty CatalogSearchEntry"); - } - return CatalogSearchEntry(std::move(catalog), std::move(schema)); -} - -CatalogSearchEntry CatalogSearchEntry::Parse(const string &input) { - idx_t pos = 0; - auto result = ParseInternal(input, pos); - if (pos < input.size()) { - throw ParserException("Failed to convert entry \"%s\" to CatalogSearchEntry - expected a single entry", input); - } - return result; -} - -vector CatalogSearchEntry::ParseList(const string &input) { - idx_t pos = 0; - vector result; - while (pos < input.size()) { - auto entry = ParseInternal(input, pos); - result.push_back(entry); - } - return result; -} - -CatalogSearchPath::CatalogSearchPath(ClientContext &context_p) : context(context_p) { - Reset(); -} - -void CatalogSearchPath::Reset() { - vector empty; - SetPaths(empty); -} - -string CatalogSearchPath::GetSetName(CatalogSetPathType set_type) { - switch (set_type) { - case CatalogSetPathType::SET_SCHEMA: - return "SET schema"; - case CatalogSetPathType::SET_SCHEMAS: - return "SET search_path"; - default: - throw InternalException("Unrecognized CatalogSetPathType"); - } -} - -void CatalogSearchPath::Set(vector new_paths, CatalogSetPathType set_type) { - if (set_type != CatalogSetPathType::SET_SCHEMAS && 
new_paths.size() != 1) { - throw CatalogException("%s can set only 1 schema. This has %d", GetSetName(set_type), new_paths.size()); - } - for (auto &path : new_paths) { - auto schema_entry = Catalog::GetSchema(context, path.catalog, path.schema, OnEntryNotFound::RETURN_NULL); - if (schema_entry) { - // we are setting a schema - update the catalog and schema - if (path.catalog.empty()) { - path.catalog = GetDefault().catalog; - } - continue; - } - // only schema supplied - check if this is a catalog instead - if (path.catalog.empty()) { - auto catalog = Catalog::GetCatalogEntry(context, path.schema); - if (catalog) { - auto schema = catalog->GetSchema(context, DEFAULT_SCHEMA, OnEntryNotFound::RETURN_NULL); - if (schema) { - path.catalog = std::move(path.schema); - path.schema = schema->name; - continue; - } - } - } - throw CatalogException("%s: No catalog + schema named \"%s\" found.", GetSetName(set_type), path.ToString()); - } - if (set_type == CatalogSetPathType::SET_SCHEMA) { - if (new_paths[0].catalog == TEMP_CATALOG || new_paths[0].catalog == SYSTEM_CATALOG) { - throw CatalogException("%s cannot be set to internal schema \"%s\"", GetSetName(set_type), - new_paths[0].catalog); - } - } - this->set_paths = std::move(new_paths); - SetPaths(set_paths); -} - -void CatalogSearchPath::Set(CatalogSearchEntry new_value, CatalogSetPathType set_type) { - vector new_paths {std::move(new_value)}; - Set(std::move(new_paths), set_type); -} - -const vector &CatalogSearchPath::Get() { - return paths; -} - -string CatalogSearchPath::GetDefaultSchema(const string &catalog) { - for (auto &path : paths) { - if (path.catalog == TEMP_CATALOG) { - continue; - } - if (StringUtil::CIEquals(path.catalog, catalog)) { - return path.schema; - } - } - return DEFAULT_SCHEMA; -} - -string CatalogSearchPath::GetDefaultCatalog(const string &schema) { - for (auto &path : paths) { - if (path.catalog == TEMP_CATALOG) { - continue; - } - if (StringUtil::CIEquals(path.schema, schema)) { - return 
path.catalog; - } - } - return INVALID_CATALOG; -} - -vector CatalogSearchPath::GetCatalogsForSchema(const string &schema) { - vector schemas; - for (auto &path : paths) { - if (StringUtil::CIEquals(path.schema, schema)) { - schemas.push_back(path.catalog); - } - } - return schemas; -} - -vector CatalogSearchPath::GetSchemasForCatalog(const string &catalog) { - vector schemas; - for (auto &path : paths) { - if (StringUtil::CIEquals(path.catalog, catalog)) { - schemas.push_back(path.schema); - } - } - return schemas; -} - -const CatalogSearchEntry &CatalogSearchPath::GetDefault() { - const auto &paths = Get(); - D_ASSERT(paths.size() >= 2); - return paths[1]; -} - -void CatalogSearchPath::SetPaths(vector new_paths) { - paths.clear(); - paths.reserve(new_paths.size() + 3); - paths.emplace_back(TEMP_CATALOG, DEFAULT_SCHEMA); - for (auto &path : new_paths) { - paths.push_back(std::move(path)); - } - paths.emplace_back(INVALID_CATALOG, DEFAULT_SCHEMA); - paths.emplace_back(SYSTEM_CATALOG, DEFAULT_SCHEMA); - paths.emplace_back(SYSTEM_CATALOG, "pg_catalog"); -} - -bool CatalogSearchPath::SchemaInSearchPath(ClientContext &context, const string &catalog_name, - const string &schema_name) { - for (auto &path : paths) { - if (!StringUtil::CIEquals(path.schema, schema_name)) { - continue; - } - if (StringUtil::CIEquals(path.catalog, catalog_name)) { - return true; - } - if (IsInvalidCatalog(path.catalog) && - StringUtil::CIEquals(catalog_name, DatabaseManager::GetDefaultDatabase(context))) { - return true; - } - } - return false; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - -namespace duckdb { - -//! Class responsible to keep track of state when removing entries from the catalog. -//! When deleting, many types of errors can be thrown, since we want to avoid try/catch blocks -//! this class makes sure that whatever elements were modified are returned to a correct state -//! when exceptions are thrown. -//! 
The idea here is to use RAII (Resource acquisition is initialization) to mimic a try/catch/finally block. -//! If any exception is raised when this object exists, then its destructor will be called -//! and the entry will return to its previous state during deconstruction. -class EntryDropper { -public: - //! Both constructor and destructor are privates because they should only be called by DropEntryDependencies - explicit EntryDropper(EntryIndex &entry_index_p) : entry_index(entry_index_p) { - old_deleted = entry_index.GetEntry()->deleted; - } - - ~EntryDropper() { - entry_index.GetEntry()->deleted = old_deleted; - } - -private: - //! Keeps track of the state of the entry before starting the delete - bool old_deleted; - //! Index of entry to be deleted - EntryIndex &entry_index; -}; - -CatalogSet::CatalogSet(Catalog &catalog_p, unique_ptr defaults) - : catalog(catalog_p.Cast()), defaults(std::move(defaults)) { - D_ASSERT(catalog_p.IsDuckCatalog()); -} -CatalogSet::~CatalogSet() { -} - -EntryIndex CatalogSet::PutEntry(idx_t entry_index, unique_ptr entry) { - if (entries.find(entry_index) != entries.end()) { - throw InternalException("Entry with entry index \"%llu\" already exists", entry_index); - } - entries.insert(make_pair(entry_index, EntryValue(std::move(entry)))); - return EntryIndex(*this, entry_index); -} - -void CatalogSet::PutEntry(EntryIndex index, unique_ptr catalog_entry) { - auto entry = entries.find(index.GetIndex()); - if (entry == entries.end()) { - throw InternalException("Entry with entry index \"%llu\" does not exist", index.GetIndex()); - } - catalog_entry->child = std::move(entry->second.entry); - catalog_entry->child->parent = catalog_entry.get(); - entry->second.entry = std::move(catalog_entry); -} - -bool CatalogSet::CreateEntry(CatalogTransaction transaction, const string &name, unique_ptr value, - DependencyList &dependencies) { - if (value->internal && !catalog.IsSystemCatalog() && name != DEFAULT_SCHEMA) { - throw 
InternalException("Attempting to create internal entry \"%s\" in non-system catalog - internal entries " - "can only be created in the system catalog", - name); - } - if (!value->internal) { - if (!value->temporary && catalog.IsSystemCatalog()) { - throw InternalException( - "Attempting to create non-internal entry \"%s\" in system catalog - the system catalog " - "can only contain internal entries", - name); - } - if (value->temporary && !catalog.IsTemporaryCatalog()) { - throw InternalException("Attempting to create temporary entry \"%s\" in non-temporary catalog", name); - } - if (!value->temporary && catalog.IsTemporaryCatalog() && name != DEFAULT_SCHEMA) { - throw InvalidInputException("Cannot create non-temporary entry \"%s\" in temporary catalog", name); - } - } - // lock the catalog for writing - lock_guard write_lock(catalog.GetWriteLock()); - // lock this catalog set to disallow reading - unique_lock read_lock(catalog_lock); - - // first check if the entry exists in the unordered set - idx_t index; - auto mapping_value = GetMapping(transaction, name); - if (mapping_value == nullptr || mapping_value->deleted) { - // if it does not: entry has never been created - - // check if there is a default entry - auto entry = CreateDefaultEntry(transaction, name, read_lock); - if (entry) { - return false; - } - - // first create a dummy deleted entry for this entry - // so transactions started before the commit of this transaction don't - // see it yet - auto dummy_node = make_uniq(CatalogType::INVALID, value->ParentCatalog(), name); - dummy_node->timestamp = 0; - dummy_node->deleted = true; - dummy_node->set = this; - - auto entry_index = PutEntry(current_entry++, std::move(dummy_node)); - index = entry_index.GetIndex(); - PutMapping(transaction, name, std::move(entry_index)); - } else { - index = mapping_value->index.GetIndex(); - auto ¤t = *mapping_value->index.GetEntry(); - // if it does, we have to check version numbers - if (HasConflict(transaction, 
current.timestamp)) { - // current version has been written to by a currently active - // transaction - throw TransactionException("Catalog write-write conflict on create with \"%s\"", current.name); - } - // there is a current version that has been committed - // if it has not been deleted there is a conflict - if (!current.deleted) { - return false; - } - } - // create a new entry and replace the currently stored one - // set the timestamp to the timestamp of the current transaction - // and point it at the dummy node - value->timestamp = transaction.transaction_id; - value->set = this; - - // now add the dependency set of this object to the dependency manager - catalog.GetDependencyManager().AddObject(transaction, *value, dependencies); - - auto value_ptr = value.get(); - EntryIndex entry_index(*this, index); - PutEntry(std::move(entry_index), std::move(value)); - // push the old entry in the undo buffer for this transaction - if (transaction.transaction) { - auto &dtransaction = transaction.transaction->Cast(); - dtransaction.PushCatalogEntry(*value_ptr->child); - } - return true; -} - -bool CatalogSet::CreateEntry(ClientContext &context, const string &name, unique_ptr value, - DependencyList &dependencies) { - return CreateEntry(catalog.GetCatalogTransaction(context), name, std::move(value), dependencies); -} - -optional_ptr CatalogSet::GetEntryInternal(CatalogTransaction transaction, EntryIndex &entry_index) { - auto &catalog_entry = *entry_index.GetEntry(); - // if it does: we have to retrieve the entry and to check version numbers - if (HasConflict(transaction, catalog_entry.timestamp)) { - // current version has been written to by a currently active - // transaction - throw TransactionException("Catalog write-write conflict on alter with \"%s\"", catalog_entry.name); - } - // there is a current version that has been committed by this transaction - if (catalog_entry.deleted) { - // if the entry was already deleted, it now does not exist anymore - // so we 
return that we could not find it - return nullptr; - } - return &catalog_entry; -} - -optional_ptr CatalogSet::GetEntryInternal(CatalogTransaction transaction, const string &name, - EntryIndex *entry_index) { - auto mapping_value = GetMapping(transaction, name); - if (mapping_value == nullptr || mapping_value->deleted) { - // the entry does not exist, check if we can create a default entry - return nullptr; - } - if (entry_index) { - *entry_index = mapping_value->index.Copy(); - } - return GetEntryInternal(transaction, mapping_value->index); -} - -bool CatalogSet::AlterOwnership(CatalogTransaction transaction, ChangeOwnershipInfo &info) { - auto entry = GetEntryInternal(transaction, info.name, nullptr); - if (!entry) { - return false; - } - - auto &owner_entry = catalog.GetEntry(transaction.GetContext(), info.owner_schema, info.owner_name); - catalog.GetDependencyManager().AddOwnership(transaction, owner_entry, *entry); - return true; -} - -bool CatalogSet::AlterEntry(CatalogTransaction transaction, const string &name, AlterInfo &alter_info) { - // lock the catalog for writing - lock_guard write_lock(catalog.GetWriteLock()); - - // first check if the entry exists in the unordered set - EntryIndex entry_index; - auto entry = GetEntryInternal(transaction, name, &entry_index); - if (!entry) { - return false; - } - if (!alter_info.allow_internal && entry->internal) { - throw CatalogException("Cannot alter entry \"%s\" because it is an internal system entry", entry->name); - } - - // lock this catalog set to disallow reading - lock_guard read_lock(catalog_lock); - - // create a new entry and replace the currently stored one - // set the timestamp to the timestamp of the current transaction - // and point it to the updated table node - string original_name = entry->name; - if (!transaction.context) { - throw InternalException("Cannot AlterEntry without client context"); - } - auto &context = *transaction.context; - auto value = entry->AlterEntry(context, alter_info); - 
if (!value) { - // alter failed, but did not result in an error - return true; - } - - if (value->name != original_name) { - auto mapping_value = GetMapping(transaction, value->name); - if (mapping_value && !mapping_value->deleted) { - auto &original_entry = GetEntryForTransaction(transaction, *mapping_value->index.GetEntry()); - if (!original_entry.deleted) { - entry->UndoAlter(context, alter_info); - string rename_err_msg = - "Could not rename \"%s\" to \"%s\": another entry with this name already exists!"; - throw CatalogException(rename_err_msg, original_name, value->name); - } - } - } - - if (value->name != original_name) { - // Do PutMapping and DeleteMapping after dependency check - PutMapping(transaction, value->name, entry_index.Copy()); - DeleteMapping(transaction, original_name); - } - - value->timestamp = transaction.transaction_id; - value->set = this; - auto new_entry = value.get(); - PutEntry(std::move(entry_index), std::move(value)); - - // serialize the AlterInfo into a temporary buffer - MemoryStream stream; - BinarySerializer serializer(stream); - serializer.Begin(); - serializer.WriteProperty(100, "column_name", alter_info.GetColumnName()); - serializer.WriteProperty(101, "alter_info", &alter_info); - serializer.End(); - - // push the old entry in the undo buffer for this transaction - if (transaction.transaction) { - auto &dtransaction = transaction.transaction->Cast(); - dtransaction.PushCatalogEntry(*new_entry->child, stream.GetData(), stream.GetPosition()); - } - - // Check the dependency manager to verify that there are no conflicting dependencies with this alter - // Note that we do this AFTER the new entry has been entirely set up in the catalog set - // that is because in case the alter fails because of a dependency conflict, we need to be able to cleanly roll back - // to the old entry. 
- catalog.GetDependencyManager().AlterObject(transaction, *entry, *new_entry); - - return true; -} - -void CatalogSet::DropEntryDependencies(CatalogTransaction transaction, EntryIndex &entry_index, CatalogEntry &entry, - bool cascade) { - // Stores the deleted value of the entry before starting the process - EntryDropper dropper(entry_index); - - // To correctly delete the object and its dependencies, it temporarily is set to deleted. - entry_index.GetEntry()->deleted = true; - - // check any dependencies of this object - D_ASSERT(entry.ParentCatalog().IsDuckCatalog()); - auto &duck_catalog = entry.ParentCatalog().Cast(); - duck_catalog.GetDependencyManager().DropObject(transaction, entry, cascade); - - // dropper destructor is called here - // the destructor makes sure to return the value to the previous state - // dropper.~EntryDropper() -} - -void CatalogSet::DropEntryInternal(CatalogTransaction transaction, EntryIndex entry_index, CatalogEntry &entry, - bool cascade) { - DropEntryDependencies(transaction, entry_index, entry, cascade); - - // create a new entry and replace the currently stored one - // set the timestamp to the timestamp of the current transaction - // and point it at the dummy node - auto value = make_uniq(CatalogType::DELETED_ENTRY, entry.ParentCatalog(), entry.name); - value->timestamp = transaction.transaction_id; - value->set = this; - value->deleted = true; - auto value_ptr = value.get(); - PutEntry(std::move(entry_index), std::move(value)); - - // push the old entry in the undo buffer for this transaction - if (transaction.transaction) { - auto &dtransaction = transaction.transaction->Cast(); - dtransaction.PushCatalogEntry(*value_ptr->child); - } -} - -bool CatalogSet::DropEntry(CatalogTransaction transaction, const string &name, bool cascade, bool allow_drop_internal) { - // lock the catalog for writing - lock_guard write_lock(catalog.GetWriteLock()); - // we can only delete an entry that exists - EntryIndex entry_index; - auto entry = 
GetEntryInternal(transaction, name, &entry_index); - if (!entry) { - return false; - } - if (entry->internal && !allow_drop_internal) { - throw CatalogException("Cannot drop entry \"%s\" because it is an internal system entry", entry->name); - } - - lock_guard read_lock(catalog_lock); - DropEntryInternal(transaction, std::move(entry_index), *entry, cascade); - return true; -} - -bool CatalogSet::DropEntry(ClientContext &context, const string &name, bool cascade, bool allow_drop_internal) { - return DropEntry(catalog.GetCatalogTransaction(context), name, cascade, allow_drop_internal); -} - -DuckCatalog &CatalogSet::GetCatalog() { - return catalog; -} - -void CatalogSet::CleanupEntry(CatalogEntry &catalog_entry) { - // destroy the backed up entry: it is no longer required - D_ASSERT(catalog_entry.parent); - if (catalog_entry.parent->type != CatalogType::UPDATED_ENTRY) { - lock_guard write_lock(catalog.GetWriteLock()); - lock_guard lock(catalog_lock); - if (!catalog_entry.deleted) { - // delete the entry from the dependency manager, if it is not deleted yet - D_ASSERT(catalog_entry.ParentCatalog().IsDuckCatalog()); - catalog_entry.ParentCatalog().Cast().GetDependencyManager().EraseObject(catalog_entry); - } - auto parent = catalog_entry.parent; - parent->child = std::move(catalog_entry.child); - if (parent->deleted && !parent->child && !parent->parent) { - auto mapping_entry = mapping.find(parent->name); - D_ASSERT(mapping_entry != mapping.end()); - auto &entry = mapping_entry->second->index.GetEntry(); - D_ASSERT(entry); - if (entry.get() == parent.get()) { - mapping.erase(mapping_entry); - } - } - } -} - -bool CatalogSet::HasConflict(CatalogTransaction transaction, transaction_t timestamp) { - return (timestamp >= TRANSACTION_ID_START && timestamp != transaction.transaction_id) || - (timestamp < TRANSACTION_ID_START && timestamp > transaction.start_time); -} - -optional_ptr CatalogSet::GetMapping(CatalogTransaction transaction, const string &name, bool get_latest) { 
- optional_ptr mapping_value; - auto entry = mapping.find(name); - if (entry != mapping.end()) { - mapping_value = entry->second.get(); - } else { - - return nullptr; - } - if (get_latest) { - return mapping_value; - } - while (mapping_value->child) { - if (UseTimestamp(transaction, mapping_value->timestamp)) { - break; - } - mapping_value = mapping_value->child.get(); - D_ASSERT(mapping_value); - } - return mapping_value; -} - -void CatalogSet::PutMapping(CatalogTransaction transaction, const string &name, EntryIndex entry_index) { - auto entry = mapping.find(name); - auto new_value = make_uniq(std::move(entry_index)); - new_value->timestamp = transaction.transaction_id; - if (entry != mapping.end()) { - if (HasConflict(transaction, entry->second->timestamp)) { - throw TransactionException("Catalog write-write conflict on name \"%s\"", name); - } - new_value->child = std::move(entry->second); - new_value->child->parent = new_value.get(); - } - mapping[name] = std::move(new_value); -} - -void CatalogSet::DeleteMapping(CatalogTransaction transaction, const string &name) { - auto entry = mapping.find(name); - D_ASSERT(entry != mapping.end()); - auto delete_marker = make_uniq(entry->second->index.Copy()); - delete_marker->deleted = true; - delete_marker->timestamp = transaction.transaction_id; - delete_marker->child = std::move(entry->second); - delete_marker->child->parent = delete_marker.get(); - mapping[name] = std::move(delete_marker); -} - -bool CatalogSet::UseTimestamp(CatalogTransaction transaction, transaction_t timestamp) { - if (timestamp == transaction.transaction_id) { - // we created this version - return true; - } - if (timestamp < transaction.start_time) { - // this version was commited before we started the transaction - return true; - } - return false; -} - -CatalogEntry &CatalogSet::GetEntryForTransaction(CatalogTransaction transaction, CatalogEntry ¤t) { - reference entry(current); - while (entry.get().child) { - if (UseTimestamp(transaction, 
entry.get().timestamp)) { - break; - } - entry = *entry.get().child; - } - return entry.get(); -} - -CatalogEntry &CatalogSet::GetCommittedEntry(CatalogEntry ¤t) { - reference entry(current); - while (entry.get().child) { - if (entry.get().timestamp < TRANSACTION_ID_START) { - // this entry is committed: use it - break; - } - entry = *entry.get().child; - } - return entry.get(); -} - -SimilarCatalogEntry CatalogSet::SimilarEntry(CatalogTransaction transaction, const string &name) { - unique_lock lock(catalog_lock); - CreateDefaultEntries(transaction, lock); - - SimilarCatalogEntry result; - for (auto &kv : mapping) { - auto mapping_value = GetMapping(transaction, kv.first); - if (mapping_value && !mapping_value->deleted) { - auto ldist = StringUtil::SimilarityScore(kv.first, name); - if (ldist < result.distance) { - result.distance = ldist; - result.name = kv.first; - } - } - } - return result; -} - -optional_ptr CatalogSet::CreateEntryInternal(CatalogTransaction transaction, - unique_ptr entry) { - if (mapping.find(entry->name) != mapping.end()) { - return nullptr; - } - auto &name = entry->name; - auto catalog_entry = entry.get(); - - entry->set = this; - entry->timestamp = 0; - - auto entry_index = PutEntry(current_entry++, std::move(entry)); - PutMapping(transaction, name, std::move(entry_index)); - mapping[name]->timestamp = 0; - return catalog_entry; -} - -optional_ptr CatalogSet::CreateDefaultEntry(CatalogTransaction transaction, const string &name, - unique_lock &lock) { - // no entry found with this name, check for defaults - if (!defaults || defaults->created_all_entries) { - // no defaults either: return null - return nullptr; - } - // this catalog set has a default map defined - // check if there is a default entry that we can create with this name - if (!transaction.context) { - // no context - cannot create default entry - return nullptr; - } - lock.unlock(); - auto entry = defaults->CreateDefaultEntry(*transaction.context, name); - - lock.lock(); - 
if (!entry) { - // no default entry - return nullptr; - } - // there is a default entry! create it - auto result = CreateEntryInternal(transaction, std::move(entry)); - if (result) { - return result; - } - // we found a default entry, but failed - // this means somebody else created the entry first - // just retry? - lock.unlock(); - return GetEntry(transaction, name); -} - -optional_ptr CatalogSet::GetEntry(CatalogTransaction transaction, const string &name) { - unique_lock lock(catalog_lock); - auto mapping_value = GetMapping(transaction, name); - if (mapping_value != nullptr && !mapping_value->deleted) { - // we found an entry for this name - // check the version numbers - - auto &catalog_entry = *mapping_value->index.GetEntry(); - auto ¤t = GetEntryForTransaction(transaction, catalog_entry); - if (current.deleted || (current.name != name && !UseTimestamp(transaction, mapping_value->timestamp))) { - return nullptr; - } - return ¤t; - } - return CreateDefaultEntry(transaction, name, lock); -} - -optional_ptr CatalogSet::GetEntry(ClientContext &context, const string &name) { - return GetEntry(catalog.GetCatalogTransaction(context), name); -} - -void CatalogSet::UpdateTimestamp(CatalogEntry &entry, transaction_t timestamp) { - entry.timestamp = timestamp; - mapping[entry.name]->timestamp = timestamp; -} - -void CatalogSet::Undo(CatalogEntry &entry) { - lock_guard write_lock(catalog.GetWriteLock()); - lock_guard lock(catalog_lock); - - // entry has to be restored - // and entry->parent has to be removed ("rolled back") - - // i.e. 
we have to place (entry) as (entry->parent) again - auto &to_be_removed_node = *entry.parent; - - if (!to_be_removed_node.deleted) { - // delete the entry from the dependency manager as well - auto &dependency_manager = catalog.GetDependencyManager(); - dependency_manager.EraseObject(to_be_removed_node); - } - if (!StringUtil::CIEquals(entry.name, to_be_removed_node.name)) { - // rename: clean up the new name when the rename is rolled back - auto removed_entry = mapping.find(to_be_removed_node.name); - if (removed_entry->second->child) { - removed_entry->second->child->parent = nullptr; - mapping[to_be_removed_node.name] = std::move(removed_entry->second->child); - } else { - mapping.erase(removed_entry); - } - } - if (to_be_removed_node.parent) { - // if the to be removed node has a parent, set the child pointer to the - // to be restored node - to_be_removed_node.parent->child = std::move(to_be_removed_node.child); - entry.parent = to_be_removed_node.parent; - } else { - // otherwise we need to update the base entry tables - auto &name = entry.name; - to_be_removed_node.child->SetAsRoot(); - mapping[name]->index.GetEntry() = std::move(to_be_removed_node.child); - entry.parent = nullptr; - } - - // restore the name if it was deleted - auto restored_entry = mapping.find(entry.name); - if (restored_entry->second->deleted || entry.type == CatalogType::INVALID) { - if (restored_entry->second->child) { - restored_entry->second->child->parent = nullptr; - mapping[entry.name] = std::move(restored_entry->second->child); - } else { - mapping.erase(restored_entry); - } - } - // we mark the catalog as being modified, since this action can lead to e.g. 
tables being dropped - catalog.ModifyCatalog(); -} - -void CatalogSet::CreateDefaultEntries(CatalogTransaction transaction, unique_lock &lock) { - if (!defaults || defaults->created_all_entries || !transaction.context) { - return; - } - // this catalog set has a default set defined: - auto default_entries = defaults->GetDefaultEntries(); - for (auto &default_entry : default_entries) { - auto map_entry = mapping.find(default_entry); - if (map_entry == mapping.end()) { - // we unlock during the CreateEntry, since it might reference other catalog sets... - // specifically for views this can happen since the view will be bound - lock.unlock(); - auto entry = defaults->CreateDefaultEntry(*transaction.context, default_entry); - if (!entry) { - throw InternalException("Failed to create default entry for %s", default_entry); - } - - lock.lock(); - CreateEntryInternal(transaction, std::move(entry)); - } - } - defaults->created_all_entries = true; -} - -void CatalogSet::Scan(CatalogTransaction transaction, const std::function &callback) { - // lock the catalog set - unique_lock lock(catalog_lock); - CreateDefaultEntries(transaction, lock); - - for (auto &kv : entries) { - auto &entry = *kv.second.entry.get(); - auto &entry_for_transaction = GetEntryForTransaction(transaction, entry); - if (!entry_for_transaction.deleted) { - callback(entry_for_transaction); - } - } -} - -void CatalogSet::Scan(ClientContext &context, const std::function &callback) { - Scan(catalog.GetCatalogTransaction(context), callback); -} - -void CatalogSet::Scan(const std::function &callback) { - // lock the catalog set - lock_guard lock(catalog_lock); - for (auto &kv : entries) { - auto entry = kv.second.entry.get(); - auto &commited_entry = GetCommittedEntry(*entry); - if (!commited_entry.deleted) { - callback(commited_entry); - } - } -} - -void CatalogSet::Verify(Catalog &catalog_p) { - D_ASSERT(&catalog_p == &catalog); - vector> entries; - Scan([&](CatalogEntry &entry) { entries.push_back(entry); }); 
- for (auto &entry : entries) { - entry.get().Verify(catalog_p); - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -CatalogTransaction::CatalogTransaction(Catalog &catalog, ClientContext &context) { - auto &transaction = Transaction::Get(context, catalog); - this->db = &DatabaseInstance::GetDatabase(context); - if (!transaction.IsDuckTransaction()) { - this->transaction_id = transaction_t(-1); - this->start_time = transaction_t(-1); - } else { - auto &dtransaction = transaction.Cast(); - this->transaction_id = dtransaction.transaction_id; - this->start_time = dtransaction.start_time; - } - this->transaction = &transaction; - this->context = &context; -} - -CatalogTransaction::CatalogTransaction(DatabaseInstance &db, transaction_t transaction_id_p, transaction_t start_time_p) - : db(&db), context(nullptr), transaction(nullptr), transaction_id(transaction_id_p), start_time(start_time_p) { -} - -ClientContext &CatalogTransaction::GetContext() { - if (!context) { - throw InternalException("Attempting to get a context in a CatalogTransaction without a context"); - } - return *context; -} - -CatalogTransaction CatalogTransaction::GetSystemTransaction(DatabaseInstance &db) { - return CatalogTransaction(db, 1, 1); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -static DefaultMacro internal_macros[] = { - {DEFAULT_SCHEMA, "current_role", {nullptr}, "'duckdb'"}, // user name of current execution context - {DEFAULT_SCHEMA, "current_user", {nullptr}, "'duckdb'"}, // user name of current execution context - {DEFAULT_SCHEMA, "current_catalog", {nullptr}, "current_database()"}, // name of current database (called "catalog" in the SQL standard) - {DEFAULT_SCHEMA, "user", {nullptr}, "current_user"}, // equivalent to current_user - {DEFAULT_SCHEMA, "session_user", {nullptr}, "'duckdb'"}, // session user name - {"pg_catalog", "inet_client_addr", {nullptr}, "NULL"}, // address of the remote connection - {"pg_catalog", "inet_client_port", {nullptr}, 
"NULL"}, // port of the remote connection - {"pg_catalog", "inet_server_addr", {nullptr}, "NULL"}, // address of the local connection - {"pg_catalog", "inet_server_port", {nullptr}, "NULL"}, // port of the local connection - {"pg_catalog", "pg_my_temp_schema", {nullptr}, "0"}, // OID of session's temporary schema, or 0 if none - {"pg_catalog", "pg_is_other_temp_schema", {"schema_id", nullptr}, "false"}, // is schema another session's temporary schema? - - {"pg_catalog", "pg_conf_load_time", {nullptr}, "current_timestamp"}, // configuration load time - {"pg_catalog", "pg_postmaster_start_time", {nullptr}, "current_timestamp"}, // server start time - - {"pg_catalog", "pg_typeof", {"expression", nullptr}, "lower(typeof(expression))"}, // get the data type of any value - - // privilege functions - // {"has_any_column_privilege", {"user", "table", "privilege", nullptr}, "true"}, //boolean //does user have privilege for any column of table - {"pg_catalog", "has_any_column_privilege", {"table", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for any column of table - // {"has_column_privilege", {"user", "table", "column", "privilege", nullptr}, "true"}, //boolean //does user have privilege for column - {"pg_catalog", "has_column_privilege", {"table", "column", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for column - // {"has_database_privilege", {"user", "database", "privilege", nullptr}, "true"}, //boolean //does user have privilege for database - {"pg_catalog", "has_database_privilege", {"database", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for database - // {"has_foreign_data_wrapper_privilege", {"user", "fdw", "privilege", nullptr}, "true"}, //boolean //does user have privilege for foreign-data wrapper - {"pg_catalog", "has_foreign_data_wrapper_privilege", {"fdw", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for foreign-data wrapper - // 
{"has_function_privilege", {"user", "function", "privilege", nullptr}, "true"}, //boolean //does user have privilege for function - {"pg_catalog", "has_function_privilege", {"function", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for function - // {"has_language_privilege", {"user", "language", "privilege", nullptr}, "true"}, //boolean //does user have privilege for language - {"pg_catalog", "has_language_privilege", {"language", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for language - // {"has_schema_privilege", {"user", "schema, privilege", nullptr}, "true"}, //boolean //does user have privilege for schema - {"pg_catalog", "has_schema_privilege", {"schema", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for schema - // {"has_sequence_privilege", {"user", "sequence", "privilege", nullptr}, "true"}, //boolean //does user have privilege for sequence - {"pg_catalog", "has_sequence_privilege", {"sequence", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for sequence - // {"has_server_privilege", {"user", "server", "privilege", nullptr}, "true"}, //boolean //does user have privilege for foreign server - {"pg_catalog", "has_server_privilege", {"server", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for foreign server - // {"has_table_privilege", {"user", "table", "privilege", nullptr}, "true"}, //boolean //does user have privilege for table - {"pg_catalog", "has_table_privilege", {"table", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for table - // {"has_tablespace_privilege", {"user", "tablespace", "privilege", nullptr}, "true"}, //boolean //does user have privilege for tablespace - {"pg_catalog", "has_tablespace_privilege", {"tablespace", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for tablespace - - // various postgres system functions - 
{"pg_catalog", "pg_get_viewdef", {"oid", nullptr}, "(select sql from duckdb_views() v where v.view_oid=oid)"}, - {"pg_catalog", "pg_get_constraintdef", {"constraint_oid", "pretty_bool", nullptr}, "(select constraint_text from duckdb_constraints() d_constraint where d_constraint.table_oid=constraint_oid//1000000 and d_constraint.constraint_index=constraint_oid%1000000)"}, - {"pg_catalog", "pg_get_expr", {"pg_node_tree", "relation_oid", nullptr}, "pg_node_tree"}, - {"pg_catalog", "format_pg_type", {"type_name", nullptr}, "case when logical_type='FLOAT' then 'real' when logical_type='DOUBLE' then 'double precision' when logical_type='DECIMAL' then 'numeric' when logical_type='ENUM' then lower(type_name) when logical_type='VARCHAR' then 'character varying' when logical_type='BLOB' then 'bytea' when logical_type='TIMESTAMP' then 'timestamp without time zone' when logical_type='TIME' then 'time without time zone' else lower(logical_type) end"}, - {"pg_catalog", "format_type", {"type_oid", "typemod", nullptr}, "(select format_pg_type(type_name) from duckdb_types() t where t.type_oid=type_oid) || case when typemod>0 then concat('(', typemod//1000, ',', typemod%1000, ')') else '' end"}, - - {"pg_catalog", "pg_has_role", {"user", "role", "privilege", nullptr}, "true"}, //boolean //does user have privilege for role - {"pg_catalog", "pg_has_role", {"role", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for role - - {"pg_catalog", "col_description", {"table_oid", "column_number", nullptr}, "NULL"}, // get comment for a table column - {"pg_catalog", "obj_description", {"object_oid", "catalog_name", nullptr}, "NULL"}, // get comment for a database object - {"pg_catalog", "shobj_description", {"object_oid", "catalog_name", nullptr}, "NULL"}, // get comment for a shared database object - - // visibility functions - {"pg_catalog", "pg_collation_is_visible", {"collation_oid", nullptr}, "true"}, - {"pg_catalog", "pg_conversion_is_visible", 
{"conversion_oid", nullptr}, "true"}, - {"pg_catalog", "pg_function_is_visible", {"function_oid", nullptr}, "true"}, - {"pg_catalog", "pg_opclass_is_visible", {"opclass_oid", nullptr}, "true"}, - {"pg_catalog", "pg_operator_is_visible", {"operator_oid", nullptr}, "true"}, - {"pg_catalog", "pg_opfamily_is_visible", {"opclass_oid", nullptr}, "true"}, - {"pg_catalog", "pg_table_is_visible", {"table_oid", nullptr}, "true"}, - {"pg_catalog", "pg_ts_config_is_visible", {"config_oid", nullptr}, "true"}, - {"pg_catalog", "pg_ts_dict_is_visible", {"dict_oid", nullptr}, "true"}, - {"pg_catalog", "pg_ts_parser_is_visible", {"parser_oid", nullptr}, "true"}, - {"pg_catalog", "pg_ts_template_is_visible", {"template_oid", nullptr}, "true"}, - {"pg_catalog", "pg_type_is_visible", {"type_oid", nullptr}, "true"}, - - {"pg_catalog", "pg_size_pretty", {"bytes", nullptr}, "format_bytes(bytes)"}, - - {DEFAULT_SCHEMA, "round_even", {"x", "n", nullptr}, "CASE ((abs(x) * power(10, n+1)) % 10) WHEN 5 THEN round(x/2, n) * 2 ELSE round(x, n) END"}, - {DEFAULT_SCHEMA, "roundbankers", {"x", "n", nullptr}, "round_even(x, n)"}, - {DEFAULT_SCHEMA, "nullif", {"a", "b", nullptr}, "CASE WHEN a=b THEN NULL ELSE a END"}, - {DEFAULT_SCHEMA, "list_append", {"l", "e", nullptr}, "list_concat(l, list_value(e))"}, - {DEFAULT_SCHEMA, "array_append", {"arr", "el", nullptr}, "list_append(arr, el)"}, - {DEFAULT_SCHEMA, "list_prepend", {"e", "l", nullptr}, "list_concat(list_value(e), l)"}, - {DEFAULT_SCHEMA, "array_prepend", {"el", "arr", nullptr}, "list_prepend(el, arr)"}, - {DEFAULT_SCHEMA, "array_pop_back", {"arr", nullptr}, "arr[:LEN(arr)-1]"}, - {DEFAULT_SCHEMA, "array_pop_front", {"arr", nullptr}, "arr[2:]"}, - {DEFAULT_SCHEMA, "array_push_back", {"arr", "e", nullptr}, "list_concat(arr, list_value(e))"}, - {DEFAULT_SCHEMA, "array_push_front", {"arr", "e", nullptr}, "list_concat(list_value(e), arr)"}, - {DEFAULT_SCHEMA, "array_to_string", {"arr", "sep", nullptr}, "list_aggr(arr, 'string_agg', sep)"}, - 
{DEFAULT_SCHEMA, "generate_subscripts", {"arr", "dim", nullptr}, "unnest(generate_series(1, array_length(arr, dim)))"}, - {DEFAULT_SCHEMA, "fdiv", {"x", "y", nullptr}, "floor(x/y)"}, - {DEFAULT_SCHEMA, "fmod", {"x", "y", nullptr}, "(x-y*floor(x/y))"}, - {DEFAULT_SCHEMA, "count_if", {"l", nullptr}, "sum(if(l, 1, 0))"}, - {DEFAULT_SCHEMA, "split_part", {"string", "delimiter", "position", nullptr}, "coalesce(string_split(string, delimiter)[position],'')"}, - {DEFAULT_SCHEMA, "geomean", {"x", nullptr}, "exp(avg(ln(x)))"}, - {DEFAULT_SCHEMA, "geometric_mean", {"x", nullptr}, "geomean(x)"}, - - {DEFAULT_SCHEMA, "list_reverse", {"l", nullptr}, "l[:-:-1]"}, - {DEFAULT_SCHEMA, "array_reverse", {"l", nullptr}, "list_reverse(l)"}, - - // FIXME implement as actual function if we encounter a lot of performance issues. Complexity now: n * m, with hashing possibly n + m - {DEFAULT_SCHEMA, "list_intersect", {"l1", "l2", nullptr}, "list_filter(l1, (x) -> list_contains(l2, x))"}, - {DEFAULT_SCHEMA, "array_intersect", {"l1", "l2", nullptr}, "list_intersect(l1, l2)"}, - - {DEFAULT_SCHEMA, "list_has_any", {"l1", "l2", nullptr}, "CASE WHEN l1 IS NULL THEN NULL WHEN l2 IS NULL THEN NULL WHEN len(list_intersect(l1, l2)) > 0 THEN true ELSE false END"}, - {DEFAULT_SCHEMA, "array_has_any", {"l1", "l2", nullptr}, "list_has_any(l1, l2)" }, - {DEFAULT_SCHEMA, "&&", {"l1", "l2", nullptr}, "list_has_any(l1, l2)" }, // "&&" is the operator for "list_has_any - - {DEFAULT_SCHEMA, "list_has_all", {"l1", "l2", nullptr}, "CASE WHEN l1 IS NULL THEN NULL WHEN l2 IS NULL THEN NULL WHEN len(list_intersect(l2, l1)) = len(list_filter(l2, x -> x IS NOT NULL)) THEN true ELSE false END"}, - {DEFAULT_SCHEMA, "array_has_all", {"l1", "l2", nullptr}, "list_has_all(l1, l2)" }, - {DEFAULT_SCHEMA, "@>", {"l1", "l2", nullptr}, "list_has_all(l1, l2)" }, // "@>" is the operator for "list_has_all - {DEFAULT_SCHEMA, "<@", {"l1", "l2", nullptr}, "list_has_all(l2, l1)" }, // "<@" is the operator for "list_has_all - - // 
algebraic list aggregates - {DEFAULT_SCHEMA, "list_avg", {"l", nullptr}, "list_aggr(l, 'avg')"}, - {DEFAULT_SCHEMA, "list_var_samp", {"l", nullptr}, "list_aggr(l, 'var_samp')"}, - {DEFAULT_SCHEMA, "list_var_pop", {"l", nullptr}, "list_aggr(l, 'var_pop')"}, - {DEFAULT_SCHEMA, "list_stddev_pop", {"l", nullptr}, "list_aggr(l, 'stddev_pop')"}, - {DEFAULT_SCHEMA, "list_stddev_samp", {"l", nullptr}, "list_aggr(l, 'stddev_samp')"}, - {DEFAULT_SCHEMA, "list_sem", {"l", nullptr}, "list_aggr(l, 'sem')"}, - - // distributive list aggregates - {DEFAULT_SCHEMA, "list_approx_count_distinct", {"l", nullptr}, "list_aggr(l, 'approx_count_distinct')"}, - {DEFAULT_SCHEMA, "list_bit_xor", {"l", nullptr}, "list_aggr(l, 'bit_xor')"}, - {DEFAULT_SCHEMA, "list_bit_or", {"l", nullptr}, "list_aggr(l, 'bit_or')"}, - {DEFAULT_SCHEMA, "list_bit_and", {"l", nullptr}, "list_aggr(l, 'bit_and')"}, - {DEFAULT_SCHEMA, "list_bool_and", {"l", nullptr}, "list_aggr(l, 'bool_and')"}, - {DEFAULT_SCHEMA, "list_bool_or", {"l", nullptr}, "list_aggr(l, 'bool_or')"}, - {DEFAULT_SCHEMA, "list_count", {"l", nullptr}, "list_aggr(l, 'count')"}, - {DEFAULT_SCHEMA, "list_entropy", {"l", nullptr}, "list_aggr(l, 'entropy')"}, - {DEFAULT_SCHEMA, "list_last", {"l", nullptr}, "list_aggr(l, 'last')"}, - {DEFAULT_SCHEMA, "list_first", {"l", nullptr}, "list_aggr(l, 'first')"}, - {DEFAULT_SCHEMA, "list_any_value", {"l", nullptr}, "list_aggr(l, 'any_value')"}, - {DEFAULT_SCHEMA, "list_kurtosis", {"l", nullptr}, "list_aggr(l, 'kurtosis')"}, - {DEFAULT_SCHEMA, "list_min", {"l", nullptr}, "list_aggr(l, 'min')"}, - {DEFAULT_SCHEMA, "list_max", {"l", nullptr}, "list_aggr(l, 'max')"}, - {DEFAULT_SCHEMA, "list_product", {"l", nullptr}, "list_aggr(l, 'product')"}, - {DEFAULT_SCHEMA, "list_skewness", {"l", nullptr}, "list_aggr(l, 'skewness')"}, - {DEFAULT_SCHEMA, "list_sum", {"l", nullptr}, "list_aggr(l, 'sum')"}, - {DEFAULT_SCHEMA, "list_string_agg", {"l", nullptr}, "list_aggr(l, 'string_agg')"}, - - // holistic list aggregates - 
{DEFAULT_SCHEMA, "list_mode", {"l", nullptr}, "list_aggr(l, 'mode')"}, - {DEFAULT_SCHEMA, "list_median", {"l", nullptr}, "list_aggr(l, 'median')"}, - {DEFAULT_SCHEMA, "list_mad", {"l", nullptr}, "list_aggr(l, 'mad')"}, - - // nested list aggregates - {DEFAULT_SCHEMA, "list_histogram", {"l", nullptr}, "list_aggr(l, 'histogram')"}, - - // date functions - {DEFAULT_SCHEMA, "date_add", {"date", "interval", nullptr}, "date + interval"}, - - {nullptr, nullptr, {nullptr}, nullptr} - }; - -unique_ptr DefaultFunctionGenerator::CreateInternalTableMacroInfo(DefaultMacro &default_macro, unique_ptr function) { - for (idx_t param_idx = 0; default_macro.parameters[param_idx] != nullptr; param_idx++) { - function->parameters.push_back( - make_uniq(default_macro.parameters[param_idx])); - } - - auto type = function->type == MacroType::TABLE_MACRO ? CatalogType::TABLE_MACRO_ENTRY : CatalogType::MACRO_ENTRY; - auto bind_info = make_uniq(type); - bind_info->schema = default_macro.schema; - bind_info->name = default_macro.name; - bind_info->temporary = true; - bind_info->internal = true; - bind_info->function = std::move(function); - return bind_info; - -} - -unique_ptr DefaultFunctionGenerator::CreateInternalMacroInfo(DefaultMacro &default_macro) { - // parse the expression - auto expressions = Parser::ParseExpressionList(default_macro.macro); - D_ASSERT(expressions.size() == 1); - - auto result = make_uniq(std::move(expressions[0])); - return CreateInternalTableMacroInfo(default_macro, std::move(result)); -} - -unique_ptr DefaultFunctionGenerator::CreateInternalTableMacroInfo(DefaultMacro &default_macro) { - Parser parser; - parser.ParseQuery(default_macro.macro); - D_ASSERT(parser.statements.size() == 1); - D_ASSERT(parser.statements[0]->type == StatementType::SELECT_STATEMENT); - - auto &select = parser.statements[0]->Cast(); - auto result = make_uniq(std::move(select.node)); - return CreateInternalTableMacroInfo(default_macro, std::move(result)); -} - -static unique_ptr 
GetDefaultFunction(const string &input_schema, const string &input_name) { - auto schema = StringUtil::Lower(input_schema); - auto name = StringUtil::Lower(input_name); - for (idx_t index = 0; internal_macros[index].name != nullptr; index++) { - if (internal_macros[index].schema == schema && internal_macros[index].name == name) { - return DefaultFunctionGenerator::CreateInternalMacroInfo(internal_macros[index]); - } - } - return nullptr; -} - -DefaultFunctionGenerator::DefaultFunctionGenerator(Catalog &catalog, SchemaCatalogEntry &schema) - : DefaultGenerator(catalog), schema(schema) { -} - -unique_ptr DefaultFunctionGenerator::CreateDefaultEntry(ClientContext &context, - const string &entry_name) { - auto info = GetDefaultFunction(schema.name, entry_name); - if (info) { - return make_uniq_base(catalog, schema, info->Cast()); - } - return nullptr; -} - -vector DefaultFunctionGenerator::GetDefaultEntries() { - vector result; - for (idx_t index = 0; internal_macros[index].name != nullptr; index++) { - if (StringUtil::Lower(internal_macros[index].name) != internal_macros[index].name) { - throw InternalException("Default macro name %s should be lowercase", internal_macros[index].name); - } - if (internal_macros[index].schema == schema.name) { - result.emplace_back(internal_macros[index].name); - } - } - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -struct DefaultSchema { - const char *name; -}; - -static DefaultSchema internal_schemas[] = {{"information_schema"}, {"pg_catalog"}, {nullptr}}; - -static bool GetDefaultSchema(const string &input_schema) { - auto schema = StringUtil::Lower(input_schema); - for (idx_t index = 0; internal_schemas[index].name != nullptr; index++) { - if (internal_schemas[index].name == schema) { - return true; - } - } - return false; -} - -DefaultSchemaGenerator::DefaultSchemaGenerator(Catalog &catalog) : DefaultGenerator(catalog) { -} - -unique_ptr DefaultSchemaGenerator::CreateDefaultEntry(ClientContext &context, 
const string &entry_name) { - if (GetDefaultSchema(entry_name)) { - return make_uniq_base(catalog, StringUtil::Lower(entry_name), true); - } - return nullptr; -} - -vector DefaultSchemaGenerator::GetDefaultEntries() { - vector result; - for (idx_t index = 0; internal_schemas[index].name != nullptr; index++) { - result.emplace_back(internal_schemas[index].name); - } - return result; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -LogicalTypeId DefaultTypeGenerator::GetDefaultType(const string &name) { - auto &internal_types = BUILTIN_TYPES; - for (auto &type : internal_types) { - if (StringUtil::CIEquals(name, type.name)) { - return type.type; - } - } - return LogicalType::INVALID; -} - -DefaultTypeGenerator::DefaultTypeGenerator(Catalog &catalog, SchemaCatalogEntry &schema) - : DefaultGenerator(catalog), schema(schema) { -} - -unique_ptr DefaultTypeGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) { - if (schema.name != DEFAULT_SCHEMA) { - return nullptr; - } - auto type_id = GetDefaultType(entry_name); - if (type_id == LogicalTypeId::INVALID) { - return nullptr; - } - CreateTypeInfo info; - info.name = entry_name; - info.type = LogicalType(type_id); - info.internal = true; - info.temporary = true; - return make_uniq_base(catalog, schema, info); -} - -vector DefaultTypeGenerator::GetDefaultEntries() { - vector result; - if (schema.name != DEFAULT_SCHEMA) { - return result; - } - auto &internal_types = BUILTIN_TYPES; - for (auto &type : internal_types) { - result.emplace_back(StringUtil::Lower(type.name)); - } - return result; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct DefaultView { - const char *schema; - const char *name; - const char *sql; -}; - -static DefaultView internal_views[] = { - {DEFAULT_SCHEMA, "pragma_database_list", "SELECT database_oid AS seq, database_name AS name, path AS file FROM duckdb_databases() WHERE NOT internal ORDER BY 1"}, - {DEFAULT_SCHEMA, "sqlite_master", 
"select 'table' \"type\", table_name \"name\", table_name \"tbl_name\", 0 rootpage, sql from duckdb_tables union all select 'view' \"type\", view_name \"name\", view_name \"tbl_name\", 0 rootpage, sql from duckdb_views union all select 'index' \"type\", index_name \"name\", table_name \"tbl_name\", 0 rootpage, sql from duckdb_indexes;"}, - {DEFAULT_SCHEMA, "sqlite_schema", "SELECT * FROM sqlite_master"}, - {DEFAULT_SCHEMA, "sqlite_temp_master", "SELECT * FROM sqlite_master"}, - {DEFAULT_SCHEMA, "sqlite_temp_schema", "SELECT * FROM sqlite_master"}, - {DEFAULT_SCHEMA, "duckdb_constraints", "SELECT * FROM duckdb_constraints()"}, - {DEFAULT_SCHEMA, "duckdb_columns", "SELECT * FROM duckdb_columns() WHERE NOT internal"}, - {DEFAULT_SCHEMA, "duckdb_databases", "SELECT * FROM duckdb_databases() WHERE NOT internal"}, - {DEFAULT_SCHEMA, "duckdb_indexes", "SELECT * FROM duckdb_indexes()"}, - {DEFAULT_SCHEMA, "duckdb_schemas", "SELECT * FROM duckdb_schemas() WHERE NOT internal"}, - {DEFAULT_SCHEMA, "duckdb_tables", "SELECT * FROM duckdb_tables() WHERE NOT internal"}, - {DEFAULT_SCHEMA, "duckdb_types", "SELECT * FROM duckdb_types()"}, - {DEFAULT_SCHEMA, "duckdb_views", "SELECT * FROM duckdb_views() WHERE NOT internal"}, - {"pg_catalog", "pg_am", "SELECT 0 oid, 'art' amname, NULL amhandler, 'i' amtype"}, - {"pg_catalog", "pg_attribute", "SELECT table_oid attrelid, column_name attname, data_type_id atttypid, 0 attstattarget, NULL attlen, column_index attnum, 0 attndims, -1 attcacheoff, case when data_type ilike '%decimal%' then numeric_precision*1000+numeric_scale else -1 end atttypmod, false attbyval, NULL attstorage, NULL attalign, NOT is_nullable attnotnull, column_default IS NOT NULL atthasdef, false atthasmissing, '' attidentity, '' attgenerated, false attisdropped, true attislocal, 0 attinhcount, 0 attcollation, NULL attcompression, NULL attacl, NULL attoptions, NULL attfdwoptions, NULL attmissingval FROM duckdb_columns()"}, - {"pg_catalog", "pg_attrdef", "SELECT 
column_index oid, table_oid adrelid, column_index adnum, column_default adbin from duckdb_columns() where column_default is not null;"}, - {"pg_catalog", "pg_class", "SELECT table_oid oid, table_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, estimated_size::real reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, index_count > 0 relhasindex, false relisshared, case when temporary then 't' else 'p' end relpersistence, 'r' relkind, column_count relnatts, check_constraint_count relchecks, false relhasoids, has_primary_key relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_tables() UNION ALL SELECT view_oid oid, view_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, case when temporary then 't' else 'p' end relpersistence, 'v' relkind, column_count relnatts, 0 relchecks, false relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_views() UNION ALL SELECT sequence_oid oid, sequence_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, case when temporary then 't' else 'p' end relpersistence, 'S' relkind, 0 relnatts, 0 relchecks, false relhasoids, false relhaspkey, false relhasrules, false 
relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_sequences() UNION ALL SELECT index_oid oid, index_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, 't' relpersistence, 'i' relkind, NULL relnatts, 0 relchecks, false relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_indexes()"}, - {"pg_catalog", "pg_constraint", "SELECT table_oid*1000000+constraint_index oid, constraint_text conname, schema_oid connamespace, CASE constraint_type WHEN 'CHECK' then 'c' WHEN 'UNIQUE' then 'u' WHEN 'PRIMARY KEY' THEN 'p' WHEN 'FOREIGN KEY' THEN 'f' ELSE 'x' END contype, false condeferrable, false condeferred, true convalidated, table_oid conrelid, 0 contypid, 0 conindid, 0 conparentid, 0 confrelid, NULL confupdtype, NULL confdeltype, NULL confmatchtype, true conislocal, 0 coninhcount, false connoinherit, constraint_column_indexes conkey, NULL confkey, NULL conpfeqop, NULL conppeqop, NULL conffeqop, NULL conexclop, expression conbin FROM duckdb_constraints()"}, - {"pg_catalog", "pg_database", "SELECT database_oid oid, database_name datname FROM duckdb_databases()"}, - {"pg_catalog", "pg_depend", "SELECT * FROM duckdb_dependencies()"}, - {"pg_catalog", "pg_description", "SELECT NULL objoid, NULL classoid, NULL objsubid, NULL description WHERE 1=0"}, - {"pg_catalog", "pg_enum", "SELECT NULL oid, a.type_oid enumtypid, list_position(b.labels, a.elabel) enumsortorder, a.elabel enumlabel FROM (SELECT 
UNNEST(labels) elabel, type_oid FROM duckdb_types() WHERE logical_type='ENUM') a JOIN duckdb_types() b ON a.type_oid=b.type_oid;"}, - {"pg_catalog", "pg_index", "SELECT index_oid indexrelid, table_oid indrelid, 0 indnatts, 0 indnkeyatts, is_unique indisunique, is_primary indisprimary, false indisexclusion, true indimmediate, false indisclustered, true indisvalid, false indcheckxmin, true indisready, true indislive, false indisreplident, NULL::INT[] indkey, NULL::OID[] indcollation, NULL::OID[] indclass, NULL::INT[] indoption, expressions indexprs, NULL indpred FROM duckdb_indexes()"}, - {"pg_catalog", "pg_indexes", "SELECT schema_name schemaname, table_name tablename, index_name indexname, NULL \"tablespace\", sql indexdef FROM duckdb_indexes()"}, - {"pg_catalog", "pg_namespace", "SELECT oid, schema_name nspname, 0 nspowner, NULL nspacl FROM duckdb_schemas()"}, - {"pg_catalog", "pg_proc", "SELECT f.function_oid oid, function_name proname, s.oid pronamespace, varargs provariadic, function_type = 'aggregate' proisagg, function_type = 'table' proretset, return_type prorettype, parameter_types proargtypes, parameters proargnames FROM duckdb_functions() f LEFT JOIN duckdb_schemas() s USING (database_name, schema_name)"}, - {"pg_catalog", "pg_sequence", "SELECT sequence_oid seqrelid, 0 seqtypid, start_value seqstart, increment_by seqincrement, max_value seqmax, min_value seqmin, 0 seqcache, cycle seqcycle FROM duckdb_sequences()"}, - {"pg_catalog", "pg_sequences", "SELECT schema_name schemaname, sequence_name sequencename, 'duckdb' sequenceowner, 0 data_type, start_value, min_value, max_value, increment_by, cycle, 0 cache_size, last_value FROM duckdb_sequences()"}, - {"pg_catalog", "pg_settings", "SELECT name, value setting, description short_desc, CASE WHEN input_type = 'VARCHAR' THEN 'string' WHEN input_type = 'BOOLEAN' THEN 'bool' WHEN input_type IN ('BIGINT', 'UBIGINT') THEN 'integer' ELSE input_type END vartype FROM duckdb_settings()"}, - {"pg_catalog", "pg_tables", 
"SELECT schema_name schemaname, table_name tablename, 'duckdb' tableowner, NULL \"tablespace\", index_count > 0 hasindexes, false hasrules, false hastriggers FROM duckdb_tables()"}, - {"pg_catalog", "pg_tablespace", "SELECT 0 oid, 'pg_default' spcname, 0 spcowner, NULL spcacl, NULL spcoptions"}, - {"pg_catalog", "pg_type", "SELECT type_oid oid, format_pg_type(type_name) typname, schema_oid typnamespace, 0 typowner, type_size typlen, false typbyval, CASE WHEN logical_type='ENUM' THEN 'e' else 'b' end typtype, CASE WHEN type_category='NUMERIC' THEN 'N' WHEN type_category='STRING' THEN 'S' WHEN type_category='DATETIME' THEN 'D' WHEN type_category='BOOLEAN' THEN 'B' WHEN type_category='COMPOSITE' THEN 'C' WHEN type_category='USER' THEN 'U' ELSE 'X' END typcategory, false typispreferred, true typisdefined, NULL typdelim, NULL typrelid, NULL typsubscript, NULL typelem, NULL typarray, NULL typinput, NULL typoutput, NULL typreceive, NULL typsend, NULL typmodin, NULL typmodout, NULL typanalyze, 'd' typalign, 'p' typstorage, NULL typnotnull, NULL typbasetype, NULL typtypmod, NULL typndims, NULL typcollation, NULL typdefaultbin, NULL typdefault, NULL typacl FROM duckdb_types() WHERE type_size IS NOT NULL;"}, - {"pg_catalog", "pg_views", "SELECT schema_name schemaname, view_name viewname, 'duckdb' viewowner, sql definition FROM duckdb_views()"}, - {"information_schema", "columns", "SELECT database_name table_catalog, schema_name table_schema, table_name, column_name, column_index ordinal_position, column_default, CASE WHEN is_nullable THEN 'YES' ELSE 'NO' END is_nullable, data_type, character_maximum_length, NULL character_octet_length, numeric_precision, numeric_precision_radix, numeric_scale, NULL datetime_precision, NULL interval_type, NULL interval_precision, NULL character_set_catalog, NULL character_set_schema, NULL character_set_name, NULL collation_catalog, NULL collation_schema, NULL collation_name, NULL domain_catalog, NULL domain_schema, NULL domain_name, NULL 
udt_catalog, NULL udt_schema, NULL udt_name, NULL scope_catalog, NULL scope_schema, NULL scope_name, NULL maximum_cardinality, NULL dtd_identifier, NULL is_self_referencing, NULL is_identity, NULL identity_generation, NULL identity_start, NULL identity_increment, NULL identity_maximum, NULL identity_minimum, NULL identity_cycle, NULL is_generated, NULL generation_expression, NULL is_updatable FROM duckdb_columns;"}, - {"information_schema", "schemata", "SELECT database_name catalog_name, schema_name, 'duckdb' schema_owner, NULL default_character_set_catalog, NULL default_character_set_schema, NULL default_character_set_name, sql sql_path FROM duckdb_schemas()"}, - {"information_schema", "tables", "SELECT database_name table_catalog, schema_name table_schema, table_name, CASE WHEN temporary THEN 'LOCAL TEMPORARY' ELSE 'BASE TABLE' END table_type, NULL self_referencing_column_name, NULL reference_generation, NULL user_defined_type_catalog, NULL user_defined_type_schema, NULL user_defined_type_name, 'YES' is_insertable_into, 'NO' is_typed, CASE WHEN temporary THEN 'PRESERVE' ELSE NULL END commit_action FROM duckdb_tables() UNION ALL SELECT database_name table_catalog, schema_name table_schema, view_name table_name, 'VIEW' table_type, NULL self_referencing_column_name, NULL reference_generation, NULL user_defined_type_catalog, NULL user_defined_type_schema, NULL user_defined_type_name, 'NO' is_insertable_into, 'NO' is_typed, NULL commit_action FROM duckdb_views;"}, - {nullptr, nullptr, nullptr}}; - -static unique_ptr GetDefaultView(ClientContext &context, const string &input_schema, const string &input_name) { - auto schema = StringUtil::Lower(input_schema); - auto name = StringUtil::Lower(input_name); - for (idx_t index = 0; internal_views[index].name != nullptr; index++) { - if (internal_views[index].schema == schema && internal_views[index].name == name) { - auto result = make_uniq(); - result->schema = schema; - result->view_name = name; - result->sql = 
internal_views[index].sql; - result->temporary = true; - result->internal = true; - - return CreateViewInfo::FromSelect(context, std::move(result)); - } - } - return nullptr; -} - -DefaultViewGenerator::DefaultViewGenerator(Catalog &catalog, SchemaCatalogEntry &schema) - : DefaultGenerator(catalog), schema(schema) { -} - -unique_ptr DefaultViewGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) { - auto info = GetDefaultView(context, schema.name, entry_name); - if (info) { - return make_uniq_base(catalog, schema, *info); - } - return nullptr; -} - -vector DefaultViewGenerator::GetDefaultEntries() { - vector result; - for (idx_t index = 0; internal_views[index].name != nullptr; index++) { - if (internal_views[index].schema == schema.name) { - result.emplace_back(internal_views[index].name); - } - } - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -void DependencyList::AddDependency(CatalogEntry &entry) { - if (entry.internal) { - return; - } - set.insert(entry); -} - -void DependencyList::VerifyDependencies(Catalog &catalog, const string &name) { - for (auto &dep_entry : set) { - auto &dep = dep_entry.get(); - if (&dep.ParentCatalog() != &catalog) { - throw DependencyException( - "Error adding dependency for object \"%s\" - dependency \"%s\" is in catalog " - "\"%s\", which does not match the catalog \"%s\".\nCross catalog dependencies are not supported.", - name, dep.name, dep.ParentCatalog().GetName(), catalog.GetName()); - } - } -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -DependencyManager::DependencyManager(DuckCatalog &catalog) : catalog(catalog) { -} - -void DependencyManager::AddObject(CatalogTransaction transaction, CatalogEntry &object, DependencyList &dependencies) { - // check for each object in the sources if they were not deleted yet - for (auto &dep : dependencies.set) { - auto &dependency = dep.get(); - if (&dependency.ParentCatalog() != &object.ParentCatalog()) { - 
throw DependencyException( - "Error adding dependency for object \"%s\" - dependency \"%s\" is in catalog " - "\"%s\", which does not match the catalog \"%s\".\nCross catalog dependencies are not supported.", - object.name, dependency.name, dependency.ParentCatalog().GetName(), object.ParentCatalog().GetName()); - } - if (!dependency.set) { - throw InternalException("Dependency has no set"); - } - auto catalog_entry = dependency.set->GetEntryInternal(transaction, dependency.name, nullptr); - if (!catalog_entry) { - throw InternalException("Dependency has already been deleted?"); - } - } - // indexes do not require CASCADE to be dropped, they are simply always dropped along with the table - auto dependency_type = object.type == CatalogType::INDEX_ENTRY ? DependencyType::DEPENDENCY_AUTOMATIC - : DependencyType::DEPENDENCY_REGULAR; - // add the object to the dependents_map of each object that it depends on - for (auto &dependency : dependencies.set) { - auto &set = dependents_map[dependency]; - set.insert(Dependency(object, dependency_type)); - } - // create the dependents map for this object: it starts out empty - dependents_map[object] = dependency_set_t(); - dependencies_map[object] = dependencies.set; -} - -void DependencyManager::DropObject(CatalogTransaction transaction, CatalogEntry &object, bool cascade) { - D_ASSERT(dependents_map.find(object) != dependents_map.end()); - - // first check the objects that depend on this object - auto &dependent_objects = dependents_map[object]; - for (auto &dep : dependent_objects) { - // look up the entry in the catalog set - auto &entry = dep.entry.get(); - auto &catalog_set = *entry.set; - auto mapping_value = catalog_set.GetMapping(transaction, entry.name, true /* get_latest */); - if (mapping_value == nullptr) { - continue; - } - auto dependency_entry = catalog_set.GetEntryInternal(transaction, mapping_value->index); - if (!dependency_entry) { - // the dependent object was already deleted, no conflict - continue; - } - // 
conflict: attempting to delete this object but the dependent object still exists - if (cascade || dep.dependency_type == DependencyType::DEPENDENCY_AUTOMATIC || - dep.dependency_type == DependencyType::DEPENDENCY_OWNS) { - // cascade: drop the dependent object - catalog_set.DropEntryInternal(transaction, mapping_value->index.Copy(), *dependency_entry, cascade); - } else { - // no cascade and there are objects that depend on this object: throw error - throw DependencyException("Cannot drop entry \"%s\" because there are entries that " - "depend on it. Use DROP...CASCADE to drop all dependents.", - object.name); - } - } -} - -void DependencyManager::AlterObject(CatalogTransaction transaction, CatalogEntry &old_obj, CatalogEntry &new_obj) { - D_ASSERT(dependents_map.find(old_obj) != dependents_map.end()); - D_ASSERT(dependencies_map.find(old_obj) != dependencies_map.end()); - - // first check the objects that depend on this object - catalog_entry_vector_t owned_objects_to_add; - auto &dependent_objects = dependents_map[old_obj]; - for (auto &dep : dependent_objects) { - // look up the entry in the catalog set - auto &entry = dep.entry.get(); - auto &catalog_set = *entry.set; - auto dependency_entry = catalog_set.GetEntryInternal(transaction, entry.name, nullptr); - if (!dependency_entry) { - // the dependent object was already deleted, no conflict - continue; - } - if (dep.dependency_type == DependencyType::DEPENDENCY_OWNS) { - // the dependent object is owned by the current object - owned_objects_to_add.push_back(dep.entry); - continue; - } - // conflict: attempting to alter this object but the dependent object still exists - // no cascade and there are objects that depend on this object: throw error - throw DependencyException("Cannot alter entry \"%s\" because there are entries that " - "depend on it.", - old_obj.name); - } - // add the new object to the dependents_map of each object that it depends on - auto &old_dependencies = dependencies_map[old_obj]; - for 
(auto &dep : old_dependencies) { - auto &dependency = dep.get(); - dependents_map[dependency].insert(new_obj); - } - - // We might have to add a type dependency - // add the new object to the dependency manager - dependents_map[new_obj] = dependency_set_t(); - dependencies_map[new_obj] = old_dependencies; - - for (auto &dependency : owned_objects_to_add) { - dependents_map[new_obj].insert(Dependency(dependency, DependencyType::DEPENDENCY_OWNS)); - dependents_map[dependency].insert(Dependency(new_obj, DependencyType::DEPENDENCY_OWNED_BY)); - dependencies_map[new_obj].insert(dependency); - } -} - -void DependencyManager::EraseObject(CatalogEntry &object) { - // obtain the writing lock - EraseObjectInternal(object); -} - -void DependencyManager::EraseObjectInternal(CatalogEntry &object) { - if (dependents_map.find(object) == dependents_map.end()) { - // dependencies already removed - return; - } - D_ASSERT(dependents_map.find(object) != dependents_map.end()); - D_ASSERT(dependencies_map.find(object) != dependencies_map.end()); - // now for each of the dependencies, erase the entries from the dependents_map - for (auto &dependency : dependencies_map[object]) { - auto entry = dependents_map.find(dependency); - if (entry != dependents_map.end()) { - D_ASSERT(entry->second.find(object) != entry->second.end()); - entry->second.erase(object); - } - } - // erase the dependents and dependencies for this object - dependents_map.erase(object); - dependencies_map.erase(object); -} - -void DependencyManager::Scan(const std::function &callback) { - lock_guard write_lock(catalog.GetWriteLock()); - for (auto &entry : dependents_map) { - for (auto &dependent : entry.second) { - callback(entry.first, dependent.entry, dependent.dependency_type); - } - } -} - -void DependencyManager::AddOwnership(CatalogTransaction transaction, CatalogEntry &owner, CatalogEntry &entry) { - // lock the catalog for writing - lock_guard write_lock(catalog.GetWriteLock()); - - // If the owner is already 
owned by something else, throw an error - for (auto &dep : dependents_map[owner]) { - if (dep.dependency_type == DependencyType::DEPENDENCY_OWNED_BY) { - throw DependencyException(owner.name + " already owned by " + dep.entry.get().name); - } - } - - // If the entry is already owned, throw an error - for (auto &dep : dependents_map[entry]) { - // if the entry is already owned, throw error - if (&dep.entry.get() != &owner) { - throw DependencyException(entry.name + " already depends on " + dep.entry.get().name); - } - // if the entry owns the owner, throw error - if (&dep.entry.get() == &owner && dep.dependency_type == DependencyType::DEPENDENCY_OWNS) { - throw DependencyException(entry.name + " already owns " + owner.name + - ". Cannot have circular dependencies"); - } - } - - // Emplace guarantees that the same object cannot be inserted twice in the unordered_set - // In the case AddOwnership is called twice, because of emplace, the object will not be repeated in the set. - // We use an automatic dependency because if the Owner gets deleted, then the owned objects are also deleted - dependents_map[owner].emplace(entry, DependencyType::DEPENDENCY_OWNS); - dependents_map[entry].emplace(owner, DependencyType::DEPENDENCY_OWNED_BY); - dependencies_map[owner].emplace(entry); -} - -} // namespace duckdb - - - - - - - - - -#ifndef DISABLE_CORE_FUNCTIONS_EXTENSION - -#endif - -namespace duckdb { - -DuckCatalog::DuckCatalog(AttachedDatabase &db) - : Catalog(db), dependency_manager(make_uniq(*this)), - schemas(make_uniq(*this, make_uniq(*this))) { -} - -DuckCatalog::~DuckCatalog() { -} - -void DuckCatalog::Initialize(bool load_builtin) { - // first initialize the base system catalogs - // these are never written to the WAL - // we start these at 1 because deleted entries default to 0 - auto data = CatalogTransaction::GetSystemTransaction(GetDatabase()); - - // create the default schema - CreateSchemaInfo info; - info.schema = DEFAULT_SCHEMA; - info.internal = true; - 
CreateSchema(data, info); - - if (load_builtin) { - // initialize default functions - BuiltinFunctions builtin(data, *this); - builtin.Initialize(); - -#ifndef DISABLE_CORE_FUNCTIONS_EXTENSION - CoreFunctions::RegisterFunctions(*this, data); -#endif - } - - Verify(); -} - -bool DuckCatalog::IsDuckCatalog() { - return true; -} - -//===--------------------------------------------------------------------===// -// Schema -//===--------------------------------------------------------------------===// -optional_ptr DuckCatalog::CreateSchemaInternal(CatalogTransaction transaction, CreateSchemaInfo &info) { - DependencyList dependencies; - auto entry = make_uniq(*this, info.schema, info.internal); - auto result = entry.get(); - if (!schemas->CreateEntry(transaction, info.schema, std::move(entry), dependencies)) { - return nullptr; - } - return (CatalogEntry *)result; -} - -optional_ptr DuckCatalog::CreateSchema(CatalogTransaction transaction, CreateSchemaInfo &info) { - D_ASSERT(!info.schema.empty()); - auto result = CreateSchemaInternal(transaction, info); - if (!result) { - switch (info.on_conflict) { - case OnCreateConflict::ERROR_ON_CONFLICT: - throw CatalogException("Schema with name %s already exists!", info.schema); - case OnCreateConflict::REPLACE_ON_CONFLICT: { - DropInfo drop_info; - drop_info.type = CatalogType::SCHEMA_ENTRY; - drop_info.catalog = info.catalog; - drop_info.name = info.schema; - DropSchema(transaction, drop_info); - result = CreateSchemaInternal(transaction, info); - if (!result) { - throw InternalException("Failed to create schema entry in CREATE_OR_REPLACE"); - } - break; - } - case OnCreateConflict::IGNORE_ON_CONFLICT: - break; - default: - throw InternalException("Unsupported OnCreateConflict for CreateSchema"); - } - return nullptr; - } - return result; -} - -void DuckCatalog::DropSchema(CatalogTransaction transaction, DropInfo &info) { - D_ASSERT(!info.name.empty()); - ModifyCatalog(); - if (!schemas->DropEntry(transaction, info.name, 
info.cascade)) { - if (info.if_not_found == OnEntryNotFound::THROW_EXCEPTION) { - throw CatalogException("Schema with name \"%s\" does not exist!", info.name); - } - } -} - -void DuckCatalog::DropSchema(ClientContext &context, DropInfo &info) { - DropSchema(GetCatalogTransaction(context), info); -} - -void DuckCatalog::ScanSchemas(ClientContext &context, std::function callback) { - schemas->Scan(GetCatalogTransaction(context), - [&](CatalogEntry &entry) { callback(entry.Cast()); }); -} - -void DuckCatalog::ScanSchemas(std::function callback) { - schemas->Scan([&](CatalogEntry &entry) { callback(entry.Cast()); }); -} - -optional_ptr DuckCatalog::GetSchema(CatalogTransaction transaction, const string &schema_name, - OnEntryNotFound if_not_found, QueryErrorContext error_context) { - D_ASSERT(!schema_name.empty()); - auto entry = schemas->GetEntry(transaction, schema_name); - if (!entry) { - if (if_not_found == OnEntryNotFound::THROW_EXCEPTION) { - throw CatalogException(error_context.FormatError("Schema with name %s does not exist!", schema_name)); - } - return nullptr; - } - return &entry->Cast(); -} - -DatabaseSize DuckCatalog::GetDatabaseSize(ClientContext &context) { - return db.GetStorageManager().GetDatabaseSize(); -} - -vector DuckCatalog::GetMetadataInfo(ClientContext &context) { - return db.GetStorageManager().GetMetadataInfo(); -} - -bool DuckCatalog::InMemory() { - return db.GetStorageManager().InMemory(); -} - -string DuckCatalog::GetDBPath() { - return db.GetStorageManager().GetDBPath(); -} - -void DuckCatalog::Verify() { -#ifdef DEBUG - Catalog::Verify(); - schemas->Verify(*this); -#endif -} - -} // namespace duckdb - - - - -namespace duckdb { - -string SimilarCatalogEntry::GetQualifiedName(bool qualify_catalog, bool qualify_schema) const { - D_ASSERT(Found()); - string result; - if (qualify_catalog) { - result += schema->catalog.GetName(); - } - if (qualify_schema) { - if (!result.empty()) { - result += "."; - } - result += schema->name; - } - if 
(!result.empty()) { - result += "."; - } - result += name; - return result; -} - -} // namespace duckdb - - - - - - - - - - -#ifndef DUCKDB_AMALGAMATION - -#endif - - - -#include -#include - -// We must leak the symbols of the init function -duckdb_adbc::AdbcStatusCode duckdb_adbc_init(size_t count, struct duckdb_adbc::AdbcDriver *driver, - struct duckdb_adbc::AdbcError *error) { - if (!driver) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - - driver->DatabaseNew = duckdb_adbc::DatabaseNew; - driver->DatabaseSetOption = duckdb_adbc::DatabaseSetOption; - driver->DatabaseInit = duckdb_adbc::DatabaseInit; - driver->DatabaseRelease = duckdb_adbc::DatabaseRelease; - driver->ConnectionNew = duckdb_adbc::ConnectionNew; - driver->ConnectionSetOption = duckdb_adbc::ConnectionSetOption; - driver->ConnectionInit = duckdb_adbc::ConnectionInit; - driver->ConnectionRelease = duckdb_adbc::ConnectionRelease; - driver->ConnectionGetTableTypes = duckdb_adbc::ConnectionGetTableTypes; - driver->StatementNew = duckdb_adbc::StatementNew; - driver->StatementRelease = duckdb_adbc::StatementRelease; - driver->StatementBind = duckdb_adbc::StatementBind; - driver->StatementBindStream = duckdb_adbc::StatementBindStream; - driver->StatementExecuteQuery = duckdb_adbc::StatementExecuteQuery; - driver->StatementPrepare = duckdb_adbc::StatementPrepare; - driver->StatementSetOption = duckdb_adbc::StatementSetOption; - driver->StatementSetSqlQuery = duckdb_adbc::StatementSetSqlQuery; - driver->ConnectionGetObjects = duckdb_adbc::ConnectionGetObjects; - driver->ConnectionCommit = duckdb_adbc::ConnectionCommit; - driver->ConnectionRollback = duckdb_adbc::ConnectionRollback; - driver->ConnectionReadPartition = duckdb_adbc::ConnectionReadPartition; - driver->StatementExecutePartitions = duckdb_adbc::StatementExecutePartitions; - driver->ConnectionGetInfo = duckdb_adbc::ConnectionGetInfo; - driver->StatementGetParameterSchema = duckdb_adbc::StatementGetParameterSchema; - driver->ConnectionGetTableSchema = 
duckdb_adbc::ConnectionGetTableSchema; - driver->StatementSetSubstraitPlan = duckdb_adbc::StatementSetSubstraitPlan; - - driver->ConnectionGetInfo = duckdb_adbc::ConnectionGetInfo; - driver->StatementGetParameterSchema = duckdb_adbc::StatementGetParameterSchema; - return ADBC_STATUS_OK; -} - -namespace duckdb_adbc { - -enum class IngestionMode { CREATE = 0, APPEND = 1 }; -struct DuckDBAdbcStatementWrapper { - ::duckdb_connection connection; - ::duckdb_arrow result; - ::duckdb_prepared_statement statement; - char *ingestion_table_name; - ArrowArrayStream ingestion_stream; - IngestionMode ingestion_mode = IngestionMode::CREATE; -}; - -static AdbcStatusCode QueryInternal(struct AdbcConnection *connection, struct ArrowArrayStream *out, const char *query, - struct AdbcError *error) { - AdbcStatement statement; - - auto status = StatementNew(connection, &statement, error); - if (status != ADBC_STATUS_OK) { - SetError(error, "unable to initialize statement"); - return status; - } - status = StatementSetSqlQuery(&statement, query, error); - if (status != ADBC_STATUS_OK) { - SetError(error, "unable to initialize statement"); - return status; - } - status = StatementExecuteQuery(&statement, out, nullptr, error); - if (status != ADBC_STATUS_OK) { - SetError(error, "unable to initialize statement"); - return status; - } - - return ADBC_STATUS_OK; -} - -struct DuckDBAdbcDatabaseWrapper { - //! The DuckDB Database Configuration - ::duckdb_config config; - //! The DuckDB Database - ::duckdb_database database; - //! Path of Disk-Based Database or :memory: database - std::string path; -}; - -static void EmptyErrorRelease(AdbcError *error) { - // The object is valid but doesn't contain any data that needs to be cleaned up - // Just set the release to nullptr to indicate that it's no longer valid. 
- error->release = nullptr; - return; -} - -void InitializeADBCError(AdbcError *error) { - if (!error) { - return; - } - error->message = nullptr; - // Don't set to nullptr, as that indicates that it's invalid - error->release = EmptyErrorRelease; - std::memset(error->sqlstate, '\0', sizeof(error->sqlstate)); - error->vendor_code = -1; -} - -AdbcStatusCode CheckResult(duckdb_state &res, AdbcError *error, const char *error_msg) { - if (!error) { - // Error should be a non-null pointer - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (res != DuckDBSuccess) { - duckdb_adbc::SetError(error, error_msg); - return ADBC_STATUS_INTERNAL; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode DatabaseNew(struct AdbcDatabase *database, struct AdbcError *error) { - if (!database) { - SetError(error, "Missing database object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - database->private_data = nullptr; - // you can't malloc a struct with a non-trivial C++ constructor - // and std::string has a non-trivial constructor. so we need - // to use new and delete rather than malloc and free. 
- auto wrapper = new (std::nothrow) DuckDBAdbcDatabaseWrapper; - if (!wrapper) { - SetError(error, "Allocation error"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - database->private_data = wrapper; - auto res = duckdb_create_config(&wrapper->config); - return CheckResult(res, error, "Failed to allocate"); -} - -AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement *statement, const uint8_t *plan, size_t length, - struct AdbcError *error) { - if (!statement) { - SetError(error, "Statement is not set"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!plan) { - SetError(error, "Substrait Plan is not set"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (length == 0) { - SetError(error, "Can't execute plan with size = 0"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto wrapper = reinterpret_cast(statement->private_data); - auto plan_str = std::string(reinterpret_cast(plan), length); - auto query = "CALL from_substrait('" + plan_str + "'::BLOB)"; - auto res = duckdb_prepare(wrapper->connection, query.c_str(), &wrapper->statement); - auto error_msg = duckdb_prepare_error(wrapper->statement); - return CheckResult(res, error, error_msg); -} - -AdbcStatusCode DatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, - struct AdbcError *error) { - if (!database) { - SetError(error, "Missing database object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!key) { - SetError(error, "Missing key"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - auto wrapper = (DuckDBAdbcDatabaseWrapper *)database->private_data; - if (strcmp(key, "path") == 0) { - wrapper->path = value; - return ADBC_STATUS_OK; - } - auto res = duckdb_set_config(wrapper->config, key, value); - - return CheckResult(res, error, "Failed to set configuration option"); -} - -AdbcStatusCode DatabaseInit(struct AdbcDatabase *database, struct AdbcError *error) { - if (!error) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!database) { - duckdb_adbc::SetError(error, "ADBC 
Database has an invalid pointer"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - char *errormsg; - // TODO can we set the database path via option, too? Does not look like it... - auto wrapper = (DuckDBAdbcDatabaseWrapper *)database->private_data; - auto res = duckdb_open_ext(wrapper->path.c_str(), &wrapper->database, wrapper->config, &errormsg); - return CheckResult(res, error, errormsg); -} - -AdbcStatusCode DatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) { - - if (database && database->private_data) { - auto wrapper = (DuckDBAdbcDatabaseWrapper *)database->private_data; - - duckdb_close(&wrapper->database); - duckdb_destroy_config(&wrapper->config); - delete wrapper; - database->private_data = nullptr; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection *connection, const char *catalog, const char *db_schema, - const char *table_name, struct ArrowSchema *schema, struct AdbcError *error) { - if (!connection) { - SetError(error, "Connection is not set"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (db_schema == nullptr) { - // if schema is not set, we use the default schema - db_schema = "main"; - } - if (catalog != nullptr && strlen(catalog) > 0) { - // In DuckDB this is the name of the database, not sure what's the expected functionality here, so for now, - // scream. - SetError(error, "Catalog Name is not used in DuckDB. 
It must be set to nullptr or an empty string"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } else if (table_name == nullptr) { - SetError(error, "AdbcConnectionGetTableSchema: must provide table_name"); - return ADBC_STATUS_INVALID_ARGUMENT; - } else if (strlen(table_name) == 0) { - SetError(error, "AdbcConnectionGetTableSchema: must provide table_name"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - ArrowArrayStream arrow_stream; - - std::string query = "SELECT * FROM "; - if (strlen(db_schema) > 0) { - query += std::string(db_schema) + "."; - } - query += std::string(table_name) + " LIMIT 0;"; - - auto success = QueryInternal(connection, &arrow_stream, query.c_str(), error); - if (success != ADBC_STATUS_OK) { - return success; - } - arrow_stream.get_schema(&arrow_stream, schema); - arrow_stream.release(&arrow_stream); - return ADBC_STATUS_OK; -} - -AdbcStatusCode ConnectionNew(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection) { - SetError(error, "Missing connection object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - connection->private_data = nullptr; - return ADBC_STATUS_OK; -} - -AdbcStatusCode ExecuteQuery(duckdb::Connection *conn, const char *query, struct AdbcError *error) { - auto res = conn->Query(query); - if (res->HasError()) { - auto error_message = "Failed to execute query \"" + std::string(query) + "\": " + res->GetError(); - SetError(error, error_message); - return ADBC_STATUS_INTERNAL; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode ConnectionSetOption(struct AdbcConnection *connection, const char *key, const char *value, - struct AdbcError *error) { - if (!connection) { - SetError(error, "Connection is not set"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto conn = (duckdb::Connection *)connection->private_data; - if (strcmp(key, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) { - if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { - if (conn->HasActiveTransaction()) { - AdbcStatusCode status = ExecuteQuery(conn, 
"COMMIT", error); - if (status != ADBC_STATUS_OK) { - return status; - } - } else { - // no-op - } - } else if (strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) { - if (conn->HasActiveTransaction()) { - // no-op - } else { - // begin - AdbcStatusCode status = ExecuteQuery(conn, "START TRANSACTION", error); - if (status != ADBC_STATUS_OK) { - return status; - } - } - } else { - auto error_message = "Invalid connection option value " + std::string(key) + "=" + std::string(value); - SetError(error, error_message); - return ADBC_STATUS_INVALID_ARGUMENT; - } - return ADBC_STATUS_OK; - } - auto error_message = - "Unknown connection option " + std::string(key) + "=" + (value ? std::string(value) : "(NULL)"); - SetError(error, error_message); - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, - size_t serialized_length, struct ArrowArrayStream *out, - struct AdbcError *error) { - SetError(error, "Read Partitions are not supported in DuckDB"); - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementExecutePartitions(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcPartitions *partitions, int64_t *rows_affected, - struct AdbcError *error) { - SetError(error, "Execute Partitions are not supported in DuckDB"); - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionCommit(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection) { - SetError(error, "Connection is not set"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto conn = (duckdb::Connection *)connection->private_data; - if (!conn->HasActiveTransaction()) { - SetError(error, "No active transaction, cannot commit"); - return ADBC_STATUS_INVALID_STATE; - } - - AdbcStatusCode status = ExecuteQuery(conn, "COMMIT", error); - if (status != ADBC_STATUS_OK) { - return status; - } - return ExecuteQuery(conn, "START TRANSACTION", error); -} - 
-AdbcStatusCode ConnectionRollback(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection) { - SetError(error, "Connection is not set"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto conn = (duckdb::Connection *)connection->private_data; - if (!conn->HasActiveTransaction()) { - SetError(error, "No active transaction, cannot rollback"); - return ADBC_STATUS_INVALID_STATE; - } - - AdbcStatusCode status = ExecuteQuery(conn, "ROLLBACK", error); - if (status != ADBC_STATUS_OK) { - return status; - } - return ExecuteQuery(conn, "START TRANSACTION", error); -} - -enum class AdbcInfoCode : uint32_t { - VENDOR_NAME, - VENDOR_VERSION, - DRIVER_NAME, - DRIVER_VERSION, - DRIVER_ARROW_VERSION, - UNRECOGNIZED // always the last entry of the enum -}; - -static AdbcInfoCode ConvertToInfoCode(uint32_t info_code) { - switch (info_code) { - case 0: - return AdbcInfoCode::VENDOR_NAME; - case 1: - return AdbcInfoCode::VENDOR_VERSION; - case 2: - return AdbcInfoCode::DRIVER_NAME; - case 3: - return AdbcInfoCode::DRIVER_VERSION; - case 4: - return AdbcInfoCode::DRIVER_ARROW_VERSION; - default: - return AdbcInfoCode::UNRECOGNIZED; - } -} - -AdbcStatusCode ConnectionGetInfo(struct AdbcConnection *connection, uint32_t *info_codes, size_t info_codes_length, - struct ArrowArrayStream *out, struct AdbcError *error) { - if (!connection) { - SetError(error, "Missing connection object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_data) { - SetError(error, "Connection is invalid"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!out) { - SetError(error, "Output parameter was not provided"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - // If 'info_codes' is NULL, we should output all the info codes we recognize - size_t length = info_codes ? 
info_codes_length : (size_t)AdbcInfoCode::UNRECOGNIZED; - - duckdb::string q = R"EOF( - select - name::UINTEGER as info_name, - info::UNION( - string_value VARCHAR, - bool_value BOOL, - int64_value BIGINT, - int32_bitmask INTEGER, - string_list VARCHAR[], - int32_to_int32_list_map MAP(INTEGER, INTEGER[]) - ) as info_value from values - )EOF"; - - duckdb::string results = ""; - - for (size_t i = 0; i < length; i++) { - uint32_t code = info_codes ? info_codes[i] : i; - auto info_code = ConvertToInfoCode(code); - switch (info_code) { - case AdbcInfoCode::VENDOR_NAME: { - results += "(0, 'duckdb'),"; - break; - } - case AdbcInfoCode::VENDOR_VERSION: { - results += duckdb::StringUtil::Format("(1, '%s'),", duckdb_library_version()); - break; - } - case AdbcInfoCode::DRIVER_NAME: { - results += "(2, 'ADBC DuckDB Driver'),"; - break; - } - case AdbcInfoCode::DRIVER_VERSION: { - // TODO: fill in driver version - results += "(3, '(unknown)'),"; - break; - } - case AdbcInfoCode::DRIVER_ARROW_VERSION: { - // TODO: fill in arrow version - results += "(4, '(unknown)'),"; - break; - } - case AdbcInfoCode::UNRECOGNIZED: { - // Unrecognized codes are not an error, just ignored - continue; - } - default: { - // Codes that we have implemented but not handled here are a developer error - SetError(error, "Info code recognized but not handled"); - return ADBC_STATUS_INTERNAL; - } - } - } - if (results.empty()) { - // Add a group of values so the query parses - q += "(NULL, NULL)"; - } else { - q += results; - } - q += " tbl(name, info)"; - if (results.empty()) { - // Add an impossible where clause to return an empty result set - q += " where true = false"; - } - return QueryInternal(connection, out, q.c_str(), error); -} - -AdbcStatusCode ConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database, - struct AdbcError *error) { - if (!database) { - SetError(error, "Missing database object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!database->private_data) { - 
SetError(error, "Invalid database"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection) { - SetError(error, "Missing connection object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto database_wrapper = (DuckDBAdbcDatabaseWrapper *)database->private_data; - - connection->private_data = nullptr; - auto res = duckdb_connect(database_wrapper->database, (duckdb_connection *)&connection->private_data); - return CheckResult(res, error, "Failed to connect to Database"); -} - -AdbcStatusCode ConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error) { - if (connection && connection->private_data) { - duckdb_disconnect((duckdb_connection *)&connection->private_data); - connection->private_data = nullptr; - } - return ADBC_STATUS_OK; -} - -// some stream callbacks - -static int get_schema(struct ArrowArrayStream *stream, struct ArrowSchema *out) { - if (!stream || !stream->private_data || !out) { - return DuckDBError; - } - return duckdb_query_arrow_schema((duckdb_arrow)stream->private_data, (duckdb_arrow_schema *)&out); -} - -static int get_next(struct ArrowArrayStream *stream, struct ArrowArray *out) { - if (!stream || !stream->private_data || !out) { - return DuckDBError; - } - out->release = nullptr; - - return duckdb_query_arrow_array((duckdb_arrow)stream->private_data, (duckdb_arrow_array *)&out); -} - -void release(struct ArrowArrayStream *stream) { - if (!stream || !stream->release) { - return; - } - if (stream->private_data) { - duckdb_destroy_arrow((duckdb_arrow *)&stream->private_data); - stream->private_data = nullptr; - } - stream->release = nullptr; -} - -const char *get_last_error(struct ArrowArrayStream *stream) { - if (!stream) { - return nullptr; - } - return nullptr; - // return duckdb_query_arrow_error(stream); -} - -// this is an evil hack, normally we would need a stream factory here, but its probably much easier if the adbc clients -// just hand over a stream - -duckdb::unique_ptr -stream_produce(uintptr_t factory_ptr, 
- std::pair, std::vector> &project_columns, - duckdb::TableFilterSet *filters) { - - // TODO this will ignore any projections or filters but since we don't expose the scan it should be sort of fine - auto res = duckdb::make_uniq(); - res->arrow_array_stream = *(ArrowArrayStream *)factory_ptr; - return res; -} - -void stream_schema(uintptr_t factory_ptr, duckdb::ArrowSchemaWrapper &schema) { - auto stream = (ArrowArrayStream *)factory_ptr; - get_schema(stream, &schema.arrow_schema); -} - -AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, struct ArrowArrayStream *input, - struct AdbcError *error, IngestionMode ingestion_mode) { - - if (!connection) { - SetError(error, "Missing connection object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!input) { - SetError(error, "Missing input arrow stream pointer"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!table_name) { - SetError(error, "Missing database object name"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - auto cconn = (duckdb::Connection *)connection; - - auto arrow_scan = cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER((uintptr_t)input), - duckdb::Value::POINTER((uintptr_t)stream_produce), - duckdb::Value::POINTER((uintptr_t)input->get_schema)}); - try { - if (ingestion_mode == IngestionMode::CREATE) { - // We create the table based on an Arrow Scanner - arrow_scan->Create(table_name); - } else { - arrow_scan->CreateView("temp_adbc_view", true, true); - auto query = duckdb::StringUtil::Format("insert into \"%s\" select * from temp_adbc_view", table_name); - auto result = cconn->Query(query); - } - // After creating a table, the arrow array stream is released. Hence we must set it as released to avoid - // double-releasing it - input->release = nullptr; - } catch (std::exception &ex) { - if (error) { - error->message = strdup(ex.what()); - } - return ADBC_STATUS_INTERNAL; - } catch (...) 
{ - return ADBC_STATUS_INTERNAL; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode StatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement, - struct AdbcError *error) { - if (!connection) { - SetError(error, "Missing connection object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_data) { - SetError(error, "Invalid connection object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - statement->private_data = nullptr; - - auto statement_wrapper = (DuckDBAdbcStatementWrapper *)malloc(sizeof(DuckDBAdbcStatementWrapper)); - if (!statement_wrapper) { - SetError(error, "Allocation error"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - statement->private_data = statement_wrapper; - statement_wrapper->connection = (duckdb_connection)connection->private_data; - statement_wrapper->statement = nullptr; - statement_wrapper->result = nullptr; - statement_wrapper->ingestion_stream.release = nullptr; - statement_wrapper->ingestion_table_name = nullptr; - statement_wrapper->ingestion_mode = IngestionMode::CREATE; - return ADBC_STATUS_OK; -} - -AdbcStatusCode StatementRelease(struct AdbcStatement *statement, struct AdbcError *error) { - if (!statement || !statement->private_data) { - return ADBC_STATUS_OK; - } - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; - if (wrapper->statement) { - duckdb_destroy_prepare(&wrapper->statement); - wrapper->statement = nullptr; - } - if (wrapper->result) { - duckdb_destroy_arrow(&wrapper->result); - wrapper->result = nullptr; - } - if (wrapper->ingestion_stream.release) { - wrapper->ingestion_stream.release(&wrapper->ingestion_stream); - wrapper->ingestion_stream.release = nullptr; - } - if (wrapper->ingestion_table_name) { - free(wrapper->ingestion_table_name); - wrapper->ingestion_table_name = nullptr; - } - free(statement->private_data); - statement->private_data 
= nullptr; - return ADBC_STATUS_OK; -} - -AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!schema) { - SetError(error, "Missing schema object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; - // TODO: we might want to cache this, but then we need to return a deep copy anyways.., so I'm not sure if that - // would be worth the extra management - auto res = duckdb_prepared_arrow_schema(wrapper->statement, (duckdb_arrow_schema *)&schema); - if (res != DuckDBSuccess) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode GetPreparedParameters(duckdb_connection connection, duckdb::unique_ptr &result, - ArrowArrayStream *input, AdbcError *error) { - - auto cconn = (duckdb::Connection *)connection; - - try { - auto arrow_scan = cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER((uintptr_t)input), - duckdb::Value::POINTER((uintptr_t)stream_produce), - duckdb::Value::POINTER((uintptr_t)input->get_schema)}); - result = arrow_scan->Execute(); - // After creating a table, the arrow array stream is released. Hence we must set it as released to avoid - // double-releasing it - input->release = nullptr; - } catch (std::exception &ex) { - if (error) { - error->message = strdup(ex.what()); - } - return ADBC_STATUS_INTERNAL; - } catch (...) 
{ - return ADBC_STATUS_INTERNAL; - } - return ADBC_STATUS_OK; -} - -static AdbcStatusCode IngestToTableFromBoundStream(DuckDBAdbcStatementWrapper *statement, AdbcError *error) { - // See ADBC_INGEST_OPTION_TARGET_TABLE - D_ASSERT(statement->ingestion_stream.release); - D_ASSERT(statement->ingestion_table_name); - - // Take the input stream from the statement - auto stream = statement->ingestion_stream; - statement->ingestion_stream.release = nullptr; - - // Ingest into a table from the bound stream - return Ingest(statement->connection, statement->ingestion_table_name, &stream, error, statement->ingestion_mode); -} - -AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, - int64_t *rows_affected, struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; - - // TODO: Set affected rows, careful with early return - if (rows_affected) { - *rows_affected = 0; - } - - const auto has_stream = wrapper->ingestion_stream.release != nullptr; - const auto to_table = wrapper->ingestion_table_name != nullptr; - - if (has_stream && to_table) { - return IngestToTableFromBoundStream(wrapper, error); - } - - if (has_stream) { - // A stream was bound to the statement, use that to bind parameters - duckdb::unique_ptr result; - ArrowArrayStream stream = wrapper->ingestion_stream; - wrapper->ingestion_stream.release = nullptr; - auto adbc_res = GetPreparedParameters(wrapper->connection, result, &stream, error); - if (adbc_res != ADBC_STATUS_OK) { - return adbc_res; - } - if (!result) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - duckdb::unique_ptr chunk; - while ((chunk = result->Fetch()) != nullptr) { - if (chunk->size() == 0) { - SetError(error, "Please provide a 
non-empty chunk to be bound"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (chunk->size() != 1) { - // TODO: add support for binding multiple rows - SetError(error, "Binding multiple rows at once is not supported yet"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - duckdb_clear_bindings(wrapper->statement); - for (idx_t col_idx = 0; col_idx < chunk->ColumnCount(); col_idx++) { - auto val = chunk->GetValue(col_idx, 0); - auto duck_val = (duckdb_value)&val; - auto res = duckdb_bind_value(wrapper->statement, 1 + col_idx, duck_val); - if (res != DuckDBSuccess) { - SetError(error, duckdb_prepare_error(wrapper->statement)); - return ADBC_STATUS_INVALID_ARGUMENT; - } - } - - auto res = duckdb_execute_prepared_arrow(wrapper->statement, &wrapper->result); - if (res != DuckDBSuccess) { - SetError(error, duckdb_query_arrow_error(wrapper->result)); - return ADBC_STATUS_INVALID_ARGUMENT; - } - } - } else { - auto res = duckdb_execute_prepared_arrow(wrapper->statement, &wrapper->result); - if (res != DuckDBSuccess) { - SetError(error, duckdb_query_arrow_error(wrapper->result)); - return ADBC_STATUS_INVALID_ARGUMENT; - } - } - - if (out) { - out->private_data = wrapper->result; - out->get_schema = get_schema; - out->get_next = get_next; - out->release = release; - out->get_last_error = get_last_error; - - // because we handed out the stream pointer its no longer our responsibility to destroy it in - // AdbcStatementRelease, this is now done in release() - wrapper->result = nullptr; - } - - return ADBC_STATUS_OK; -} - -// this is a nop for us -AdbcStatusCode StatementPrepare(struct AdbcStatement *statement, struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement *statement, const char 
*query, struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!query) { - SetError(error, "Missing query"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; - auto res = duckdb_prepare(wrapper->connection, query, &wrapper->statement); - auto error_msg = duckdb_prepare_error(wrapper->statement); - return CheckResult(res, error, error_msg); -} - -AdbcStatusCode StatementBind(struct AdbcStatement *statement, struct ArrowArray *values, struct ArrowSchema *schemas, - struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!values) { - SetError(error, "Missing values object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!schemas) { - SetError(error, "Invalid schemas object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; - if (wrapper->ingestion_stream.release) { - // Free the stream that was previously bound - wrapper->ingestion_stream.release(&wrapper->ingestion_stream); - } - auto status = BatchToArrayStream(values, schemas, &wrapper->ingestion_stream, error); - return status; -} - -AdbcStatusCode StatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *values, - struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!values) { - SetError(error, "Missing values object"); - return 
ADBC_STATUS_INVALID_ARGUMENT; - } - - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; - if (wrapper->ingestion_stream.release) { - // Release any resources currently held by the ingestion stream before we overwrite it - wrapper->ingestion_stream.release(&wrapper->ingestion_stream); - } - wrapper->ingestion_stream = *values; - values->release = nullptr; - return ADBC_STATUS_OK; -} - -AdbcStatusCode StatementSetOption(struct AdbcStatement *statement, const char *key, const char *value, - struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!key) { - SetError(error, "Missing key object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; - - if (strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) { - wrapper->ingestion_table_name = strdup(value); - return ADBC_STATUS_OK; - } - if (strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) { - if (strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE) == 0) { - wrapper->ingestion_mode = IngestionMode::CREATE; - return ADBC_STATUS_OK; - } else if (strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) { - wrapper->ingestion_mode = IngestionMode::APPEND; - return ADBC_STATUS_OK; - } else { - SetError(error, "Invalid ingestion mode"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - } - return ADBC_STATUS_INVALID_ARGUMENT; -} - -AdbcStatusCode ConnectionGetObjects(struct AdbcConnection *connection, int depth, const char *catalog, - const char *db_schema, const char *table_name, const char **table_type, - const char *column_name, struct ArrowArrayStream *out, struct AdbcError *error) { - if (catalog != nullptr) { - if (strcmp(catalog, "duckdb") == 0) { - SetError(error, "catalog must be NULL or 'duckdb'"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - } - - 
if (table_type != nullptr) { - SetError(error, "Table types parameter not yet supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - std::string query; - switch (depth) { - case ADBC_OBJECT_DEPTH_CATALOGS: - SetError(error, "ADBC_OBJECT_DEPTH_CATALOGS not yet supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - case ADBC_OBJECT_DEPTH_DB_SCHEMAS: - // Return metadata on catalogs and schemas. - query = duckdb::StringUtil::Format(R"( - SELECT table_schema db_schema_name - FROM information_schema.columns - WHERE table_schema LIKE '%s' AND table_name LIKE '%s' AND column_name LIKE '%s' ; - )", - db_schema ? db_schema : "%", table_name ? table_name : "%", - column_name ? column_name : "%"); - break; - case ADBC_OBJECT_DEPTH_TABLES: - // Return metadata on catalogs, schemas, and tables. - query = duckdb::StringUtil::Format(R"( - SELECT table_schema db_schema_name, LIST(table_schema_list) db_schema_tables - FROM ( - SELECT table_schema, { table_name : table_name} table_schema_list - FROM information_schema.columns - WHERE table_schema LIKE '%s' AND table_name LIKE '%s' AND column_name LIKE '%s' GROUP BY table_schema, table_name - ) GROUP BY table_schema; - )", - db_schema ? db_schema : "%", table_name ? table_name : "%", - column_name ? column_name : "%"); - break; - case ADBC_OBJECT_DEPTH_COLUMNS: - // Return metadata on catalogs, schemas, tables, and columns. - query = duckdb::StringUtil::Format(R"( - SELECT table_schema db_schema_name, LIST(table_schema_list) db_schema_tables - FROM ( - SELECT table_schema, { table_name : table_name, table_columns : LIST({column_name : column_name, ordinal_position : ordinal_position + 1, remarks : ''})} table_schema_list - FROM information_schema.columns - WHERE table_schema LIKE '%s' AND table_name LIKE '%s' AND column_name LIKE '%s' GROUP BY table_schema, table_name - ) GROUP BY table_schema; - )", - db_schema ? db_schema : "%", table_name ? table_name : "%", - column_name ? 
column_name : "%"); - break; - default: - SetError(error, "Invalid value of Depth"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - return QueryInternal(connection, out, query.c_str(), error); -} - -AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *out, - struct AdbcError *error) { - const char *q = "SELECT DISTINCT table_type FROM information_schema.tables ORDER BY table_type"; - return QueryInternal(connection, out, q, error); -} - -} // namespace duckdb_adbc -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - - - - - -#include -#include -#include -#include -#include - -#if defined(_WIN32) -#include // Must come first - -#include -#include -#else -#include -#endif // defined(_WIN32) - -namespace duckdb_adbc { - -// Platform-specific helpers - -#if defined(_WIN32) -/// Append a description of the Windows error to the buffer. 
-void GetWinError(std::string *buffer) { - DWORD rc = GetLastError(); - LPVOID message; - - FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - /*lpSource=*/nullptr, rc, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - reinterpret_cast(&message), /*nSize=*/0, /*Arguments=*/nullptr); - - (*buffer) += '('; - (*buffer) += std::to_string(rc); - (*buffer) += ") "; - (*buffer) += reinterpret_cast(message); - LocalFree(message); -} - -#endif // defined(_WIN32) - -// Temporary state while the database is being configured. -struct TempDatabase { - std::unordered_map options; - std::string driver; - // Default name (see adbc.h) - std::string entrypoint = "AdbcDriverInit"; - AdbcDriverInitFunc init_func = nullptr; -}; - -// Error handling - -void ReleaseError(struct AdbcError *error) { - if (error) { - if (error->message) { - delete[] error->message; - } - error->message = nullptr; - error->release = nullptr; - } -} - -void SetError(struct AdbcError *error, const std::string &message) { - if (!error) { - return; - } - if (error->message) { - // Append - std::string buffer = error->message; - buffer.reserve(buffer.size() + message.size() + 1); - buffer += '\n'; - buffer += message; - error->release(error); - - error->message = new char[buffer.size() + 1]; - buffer.copy(error->message, buffer.size()); - error->message[buffer.size()] = '\0'; - } else { - error->message = new char[message.size() + 1]; - message.copy(error->message, message.size()); - error->message[message.size()] = '\0'; - } - error->release = ReleaseError; -} - -void SetError(struct AdbcError *error, const char *message_p) { - if (!message_p) { - message_p = ""; - } - std::string message(message_p); - SetError(error, message); -} - -// Driver state - -/// Hold the driver DLL and the driver release callback in the driver struct. 
-struct ManagerDriverState { - // The original release callback - AdbcStatusCode (*driver_release)(struct AdbcDriver *driver, struct AdbcError *error); - -#if defined(_WIN32) - // The loaded DLL - HMODULE handle; -#endif // defined(_WIN32) -}; - -/// Unload the driver DLL. -static AdbcStatusCode ReleaseDriver(struct AdbcDriver *driver, struct AdbcError *error) { - AdbcStatusCode status = ADBC_STATUS_OK; - - if (!driver->private_manager) { - return status; - } - ManagerDriverState *state = reinterpret_cast(driver->private_manager); - - if (state->driver_release) { - status = state->driver_release(driver, error); - } - -#if defined(_WIN32) - // TODO(apache/arrow-adbc#204): causes tests to segfault - // if (!FreeLibrary(state->handle)) { - // std::string message = "FreeLibrary() failed: "; - // GetWinError(&message); - // SetError(error, message); - // } -#endif // defined(_WIN32) - - driver->private_manager = nullptr; - delete state; - return status; -} - -/// Temporary state while the database is being configured. 
-struct TempConnection { - std::unordered_map options; -}; - -// Direct implementations of API methods - -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase *database, struct AdbcError *error) { - // Allocate a temporary structure to store options pre-Init - database->private_data = new TempDatabase(); - database->private_driver = nullptr; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, - struct AdbcError *error) { - if (!database) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (database->private_driver) { - return database->private_driver->DatabaseSetOption(database, key, value, error); - } - - TempDatabase *args = reinterpret_cast(database->private_data); - if (std::strcmp(key, "driver") == 0) { - args->driver = value; - } else if (std::strcmp(key, "entrypoint") == 0) { - args->entrypoint = value; - } else { - args->options[key] = value; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase *database, AdbcDriverInitFunc init_func, - struct AdbcError *error) { - if (!database) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (database->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - - TempDatabase *args = reinterpret_cast(database->private_data); - args->init_func = init_func; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase *database, struct AdbcError *error) { - if (!database->private_data) { - SetError(error, "Must call AdbcDatabaseNew first"); - return ADBC_STATUS_INVALID_STATE; - } - TempDatabase *args = reinterpret_cast(database->private_data); - if (args->init_func) { - // Do nothing - } else if (args->driver.empty()) { - SetError(error, "Must provide 'driver' parameter"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - database->private_driver = new AdbcDriver; - std::memset(database->private_driver, 0, sizeof(AdbcDriver)); - AdbcStatusCode status; - // So we don't 
confuse a driver into thinking it's initialized already - database->private_data = nullptr; - if (args->init_func) { - status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_0_0, database->private_driver, error); - } else { - status = AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), ADBC_VERSION_1_0_0, - database->private_driver, error); - } - if (status != ADBC_STATUS_OK) { - // Restore private_data so it will be released by AdbcDatabaseRelease - database->private_data = args; - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - return status; - } - status = database->private_driver->DatabaseNew(database, error); - if (status != ADBC_STATUS_OK) { - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - return status; - } - for (const auto &option : args->options) { - status = - database->private_driver->DatabaseSetOption(database, option.first.c_str(), option.second.c_str(), error); - if (status != ADBC_STATUS_OK) { - delete args; - // Release the database - std::ignore = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - // Should be redundant, but ensure that AdbcDatabaseRelease - // below doesn't think that it contains a TempDatabase - database->private_data = nullptr; - return status; - } - } - delete args; - return database->private_driver->DatabaseInit(database, error); -} - -AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) { - if (!database->private_driver) { - if (database->private_data) { - TempDatabase *args = 
reinterpret_cast(database->private_data); - delete args; - database->private_data = nullptr; - return ADBC_STATUS_OK; - } - return ADBC_STATUS_INVALID_STATE; - } - auto status = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_data = nullptr; - database->private_driver = nullptr; - return status; -} - -AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return connection->private_driver->ConnectionCommit(connection, error); -} - -AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection *connection, uint32_t *info_codes, size_t info_codes_length, - struct ArrowArrayStream *out, struct AdbcError *error) { - if (!connection) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return connection->private_driver->ConnectionGetInfo(connection, info_codes, info_codes_length, out, error); -} - -AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection *connection, int depth, const char *catalog, - const char *db_schema, const char *table_name, const char **table_types, - const char *column_name, struct ArrowArrayStream *stream, - struct AdbcError *error) { - if (!connection) { - SetError(error, "connection can't be null"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_data) { - SetError(error, "connection must be initialized"); - return ADBC_STATUS_INVALID_STATE; - } - return connection->private_driver->ConnectionGetObjects(connection, depth, catalog, db_schema, table_name, - table_types, column_name, stream, error); -} - -AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection *connection, const char 
*catalog, - const char *db_schema, const char *table_name, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!connection) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return connection->private_driver->ConnectionGetTableSchema(connection, catalog, db_schema, table_name, schema, - error); -} - -AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *stream, - struct AdbcError *error) { - if (!connection) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return connection->private_driver->ConnectionGetTableTypes(connection, stream, error); -} - -AdbcStatusCode AdbcConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database, - struct AdbcError *error) { - if (!connection) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_data) { - SetError(error, "Must call AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } else if (!database->private_driver) { - SetError(error, "Database is not initialized"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - TempConnection *args = reinterpret_cast(connection->private_data); - connection->private_data = nullptr; - std::unordered_map options = std::move(args->options); - delete args; - - auto status = database->private_driver->ConnectionNew(connection, error); - if (status != ADBC_STATUS_OK) { - return status; - } - connection->private_driver = database->private_driver; - - for (const auto &option : options) { - status = database->private_driver->ConnectionSetOption(connection, option.first.c_str(), option.second.c_str(), - error); - if (status != ADBC_STATUS_OK) { - return status; - } - } - return connection->private_driver->ConnectionInit(connection, database, error); -} - -AdbcStatusCode AdbcConnectionNew(struct AdbcConnection *connection, struct AdbcError *error) { - // 
Allocate a temporary structure to store options pre-Init, because - // we don't get access to the database (and hence the driver - // function table) until then - connection->private_data = new TempConnection; - connection->private_driver = nullptr; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, - size_t serialized_length, struct ArrowArrayStream *out, - struct AdbcError *error) { - if (!connection) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return connection->private_driver->ConnectionReadPartition(connection, serialized_partition, serialized_length, out, - error); -} - -AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_driver) { - if (connection->private_data) { - TempConnection *args = reinterpret_cast(connection->private_data); - delete args; - connection->private_data = nullptr; - return ADBC_STATUS_OK; - } - return ADBC_STATUS_INVALID_STATE; - } - auto status = connection->private_driver->ConnectionRelease(connection, error); - connection->private_driver = nullptr; - return status; -} - -AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return connection->private_driver->ConnectionRollback(connection, error); -} - -AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection *connection, const char *key, const char *value, - struct AdbcError *error) { - if (!connection) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_data) { - SetError(error, "AdbcConnectionSetOption: must AdbcConnectionNew first"); - return 
ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, save the option - TempConnection *args = reinterpret_cast(connection->private_data); - args->options[key] = value; - return ADBC_STATUS_OK; - } - return connection->private_driver->ConnectionSetOption(connection, key, value, error); -} - -AdbcStatusCode AdbcStatementBind(struct AdbcStatement *statement, struct ArrowArray *values, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!statement) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return statement->private_driver->StatementBind(statement, values, schema, error); -} - -AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *stream, - struct AdbcError *error) { - if (!statement) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return statement->private_driver->StatementBindStream(statement, stream, error); -} - -// XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' -AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement *statement, ArrowSchema *schema, - struct AdbcPartitions *partitions, int64_t *rows_affected, - struct AdbcError *error) { - if (!statement) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return statement->private_driver->StatementExecutePartitions(statement, schema, partitions, rows_affected, error); -} - -AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, - int64_t *rows_affected, struct AdbcError *error) { - if (!statement) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, error); 
-} - -AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!statement) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return statement->private_driver->StatementGetParameterSchema(statement, schema, error); -} - -AdbcStatusCode AdbcStatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement, - struct AdbcError *error) { - if (!connection) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - auto status = connection->private_driver->StatementNew(connection, statement, error); - statement->private_driver = connection->private_driver; - return status; -} - -AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement *statement, struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_STATE; - } - return statement->private_driver->StatementPrepare(statement, error); -} - -AdbcStatusCode AdbcStatementRelease(struct AdbcStatement *statement, struct AdbcError *error) { - if (!statement) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - auto status = statement->private_driver->StatementRelease(statement, error); - statement->private_driver = nullptr; - return status; -} - -AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement *statement, const char *key, const char *value, - struct AdbcError *error) { - if (!statement) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return statement->private_driver->StatementSetOption(statement, key, value, error); -} - -AdbcStatusCode 
AdbcStatementSetSqlQuery(struct AdbcStatement *statement, const char *query, struct AdbcError *error) { - if (!statement) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return statement->private_driver->StatementSetSqlQuery(statement, query, error); -} - -AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement *statement, const uint8_t *plan, size_t length, - struct AdbcError *error) { - if (!statement) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, error); -} - -const char *AdbcStatusCodeMessage(AdbcStatusCode code) { -#define STRINGIFY(s) #s -#define STRINGIFY_VALUE(s) STRINGIFY(s) -#define CASE(CONSTANT) \ - case CONSTANT: \ - return #CONSTANT " (" STRINGIFY_VALUE(CONSTANT) ")"; - - switch (code) { - CASE(ADBC_STATUS_OK); - CASE(ADBC_STATUS_UNKNOWN); - CASE(ADBC_STATUS_NOT_IMPLEMENTED); - CASE(ADBC_STATUS_NOT_FOUND); - CASE(ADBC_STATUS_ALREADY_EXISTS); - CASE(ADBC_STATUS_INVALID_ARGUMENT); - CASE(ADBC_STATUS_INVALID_STATE); - CASE(ADBC_STATUS_INVALID_DATA); - CASE(ADBC_STATUS_INTEGRITY); - CASE(ADBC_STATUS_INTERNAL); - CASE(ADBC_STATUS_IO); - CASE(ADBC_STATUS_CANCELLED); - CASE(ADBC_STATUS_TIMEOUT); - CASE(ADBC_STATUS_UNAUTHENTICATED); - CASE(ADBC_STATUS_UNAUTHORIZED); - default: - return "(invalid code)"; - } -#undef CASE -#undef STRINGIFY_VALUE -#undef STRINGIFY -} - -AdbcStatusCode AdbcLoadDriver(const char *driver_name, const char *entrypoint, int version, void *raw_driver, - struct AdbcError *error) { - AdbcDriverInitFunc init_func; - std::string error_message; - - if (version != ADBC_VERSION_1_0_0) { - SetError(error, "Only ADBC 1.0.0 is supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - auto *driver = reinterpret_cast(raw_driver); - - if (!entrypoint) { - // Default entrypoint (see adbc.h) - entrypoint = 
"AdbcDriverInit"; - } - -#if defined(_WIN32) - - HMODULE handle = LoadLibraryExA(driver_name, NULL, 0); - if (!handle) { - error_message += driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - - std::string full_driver_name = driver_name; - full_driver_name += ".lib"; - handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); - if (!handle) { - error_message += '\n'; - error_message += full_driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - } - } - if (!handle) { - SetError(error, error_message); - return ADBC_STATUS_INTERNAL; - } - - void *load_handle = reinterpret_cast(GetProcAddress(handle, entrypoint)); - init_func = reinterpret_cast(load_handle); - if (!init_func) { - std::string message = "GetProcAddress("; - message += entrypoint; - message += ") failed: "; - GetWinError(&message); - if (!FreeLibrary(handle)) { - message += "\nFreeLibrary() failed: "; - GetWinError(&message); - } - SetError(error, message); - return ADBC_STATUS_INTERNAL; - } - -#else - -#if defined(__APPLE__) - const std::string kPlatformLibraryPrefix = "lib"; - const std::string kPlatformLibrarySuffix = ".dylib"; -#else - const std::string kPlatformLibraryPrefix = "lib"; - const std::string kPlatformLibrarySuffix = ".so"; -#endif // defined(__APPLE__) - - void *handle = dlopen(driver_name, RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message = "dlopen() failed: "; - error_message += dlerror(); - - // If applicable, append the shared library prefix/extension and - // try again (this way you don't have to hardcode driver names by - // platform in the application) - const std::string driver_str = driver_name; - - std::string full_driver_name; - if (driver_str.size() < kPlatformLibraryPrefix.size() || - driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != 0) { - full_driver_name += kPlatformLibraryPrefix; - } - full_driver_name += driver_name; - if (driver_str.size() < 
kPlatformLibrarySuffix.size() || - driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), kPlatformLibrarySuffix.size(), - kPlatformLibrarySuffix) != 0) { - full_driver_name += kPlatformLibrarySuffix; - } - handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message += "\ndlopen() failed: "; - error_message += dlerror(); - } - } - if (!handle) { - SetError(error, error_message); - // AdbcDatabaseInit tries to call this if set - driver->release = nullptr; - return ADBC_STATUS_INTERNAL; - } - - void *load_handle = dlsym(handle, entrypoint); - if (!load_handle) { - std::string message = "dlsym("; - message += entrypoint; - message += ") failed: "; - message += dlerror(); - SetError(error, message); - return ADBC_STATUS_INTERNAL; - } - init_func = reinterpret_cast(load_handle); - -#endif // defined(_WIN32) - - AdbcStatusCode status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); - if (status == ADBC_STATUS_OK) { - ManagerDriverState *state = new ManagerDriverState; - state->driver_release = driver->release; -#if defined(_WIN32) - state->handle = handle; -#endif // defined(_WIN32) - driver->release = &ReleaseDriver; - driver->private_manager = state; - } else { -#if defined(_WIN32) - if (!FreeLibrary(handle)) { - std::string message = "FreeLibrary() failed: "; - GetWinError(&message); - SetError(error, message); - } -#endif // defined(_WIN32) - } - return status; -} - -AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void *raw_driver, - struct AdbcError *error) { -#define FILL_DEFAULT(DRIVER, STUB) \ - if (!DRIVER->STUB) { \ - DRIVER->STUB = &STUB; \ - } -#define CHECK_REQUIRED(DRIVER, STUB) \ - if (!DRIVER->STUB) { \ - SetError(error, "Driver does not implement required function Adbc" #STUB); \ - return ADBC_STATUS_INTERNAL; \ - } - - auto result = init_func(version, raw_driver, error); - if (result != ADBC_STATUS_OK) { - return result; - } - - if (version 
== ADBC_VERSION_1_0_0) { - auto *driver = reinterpret_cast(raw_driver); - CHECK_REQUIRED(driver, DatabaseNew); - CHECK_REQUIRED(driver, DatabaseInit); - CHECK_REQUIRED(driver, DatabaseRelease); - FILL_DEFAULT(driver, DatabaseSetOption); - - CHECK_REQUIRED(driver, ConnectionNew); - CHECK_REQUIRED(driver, ConnectionInit); - CHECK_REQUIRED(driver, ConnectionRelease); - FILL_DEFAULT(driver, ConnectionCommit); - FILL_DEFAULT(driver, ConnectionGetInfo); - FILL_DEFAULT(driver, ConnectionGetObjects); - FILL_DEFAULT(driver, ConnectionGetTableSchema); - FILL_DEFAULT(driver, ConnectionGetTableTypes); - FILL_DEFAULT(driver, ConnectionReadPartition); - FILL_DEFAULT(driver, ConnectionRollback); - FILL_DEFAULT(driver, ConnectionSetOption); - - FILL_DEFAULT(driver, StatementExecutePartitions); - CHECK_REQUIRED(driver, StatementExecuteQuery); - CHECK_REQUIRED(driver, StatementNew); - CHECK_REQUIRED(driver, StatementRelease); - FILL_DEFAULT(driver, StatementBind); - FILL_DEFAULT(driver, StatementGetParameterSchema); - FILL_DEFAULT(driver, StatementPrepare); - FILL_DEFAULT(driver, StatementSetOption); - FILL_DEFAULT(driver, StatementSetSqlQuery); - FILL_DEFAULT(driver, StatementSetSubstraitPlan); - } - - return ADBC_STATUS_OK; - -#undef FILL_DEFAULT -#undef CHECK_REQUIRED -} -} // namespace duckdb_adbc -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include - - - -namespace duckdb_nanoarrow { - -void *ArrowMalloc(int64_t size) { - return malloc(size); -} - -void *ArrowRealloc(void *ptr, int64_t size) { - return realloc(ptr, size); -} - -void ArrowFree(void *ptr) { - free(ptr); -} - -static uint8_t *ArrowBufferAllocatorMallocAllocate(struct ArrowBufferAllocator *allocator, int64_t size) { - return (uint8_t *)ArrowMalloc(size); -} - -static uint8_t *ArrowBufferAllocatorMallocReallocate(struct ArrowBufferAllocator *allocator, uint8_t *ptr, - int64_t old_size, int64_t new_size) { - return (uint8_t *)ArrowRealloc(ptr, new_size); -} - -static void ArrowBufferAllocatorMallocFree(struct ArrowBufferAllocator *allocator, uint8_t *ptr, int64_t size) { - ArrowFree(ptr); -} - -static struct ArrowBufferAllocator ArrowBufferAllocatorMalloc = { - &ArrowBufferAllocatorMallocAllocate, &ArrowBufferAllocatorMallocReallocate, &ArrowBufferAllocatorMallocFree, NULL}; - -struct ArrowBufferAllocator *ArrowBufferAllocatorDefault() { - return &ArrowBufferAllocatorMalloc; -} - -} // namespace duckdb_nanoarrow -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include -#include - - - -namespace duckdb_nanoarrow { - -ArrowErrorCode ArrowMetadataReaderInit(struct ArrowMetadataReader *reader, const char *metadata) { - reader->metadata = metadata; - - if (reader->metadata == NULL) { - reader->offset = 0; - reader->remaining_keys = 0; - } else { - memcpy(&reader->remaining_keys, reader->metadata, sizeof(int32_t)); - reader->offset = sizeof(int32_t); - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowMetadataReaderRead(struct ArrowMetadataReader *reader, struct ArrowStringView *key_out, - struct ArrowStringView *value_out) { - if (reader->remaining_keys <= 0) { - return EINVAL; - } - - int64_t pos = 0; - - int32_t key_size; - memcpy(&key_size, reader->metadata + reader->offset + pos, sizeof(int32_t)); - pos += sizeof(int32_t); - - key_out->data = reader->metadata + reader->offset + pos; - key_out->n_bytes = key_size; - pos += key_size; - - int32_t value_size; - memcpy(&value_size, reader->metadata + reader->offset + pos, sizeof(int32_t)); - pos += sizeof(int32_t); - - value_out->data = reader->metadata + reader->offset + pos; - value_out->n_bytes = value_size; - pos += value_size; - - reader->offset += pos; - reader->remaining_keys--; - return NANOARROW_OK; -} - -int64_t ArrowMetadataSizeOf(const char *metadata) { - if (metadata == NULL) { - return 0; - } - - struct ArrowMetadataReader reader; - struct ArrowStringView key; - struct ArrowStringView value; - ArrowMetadataReaderInit(&reader, metadata); - - int64_t size = sizeof(int32_t); - while (ArrowMetadataReaderRead(&reader, &key, &value) == 
NANOARROW_OK) { - size += sizeof(int32_t) + key.n_bytes + sizeof(int32_t) + value.n_bytes; - } - - return size; -} - -ArrowErrorCode ArrowMetadataGetValue(const char *metadata, const char *key, const char *default_value, - struct ArrowStringView *value_out) { - struct ArrowStringView target_key_view = {key, static_cast(strlen(key))}; - value_out->data = default_value; - if (default_value != NULL) { - value_out->n_bytes = strlen(default_value); - } else { - value_out->n_bytes = 0; - } - - struct ArrowMetadataReader reader; - struct ArrowStringView key_view; - struct ArrowStringView value; - ArrowMetadataReaderInit(&reader, metadata); - - while (ArrowMetadataReaderRead(&reader, &key_view, &value) == NANOARROW_OK) { - int key_equal = target_key_view.n_bytes == key_view.n_bytes && - strncmp(target_key_view.data, key_view.data, key_view.n_bytes) == 0; - if (key_equal) { - value_out->data = value.data; - value_out->n_bytes = value.n_bytes; - break; - } - } - - return NANOARROW_OK; -} - -char ArrowMetadataHasKey(const char *metadata, const char *key) { - struct ArrowStringView value; - ArrowMetadataGetValue(metadata, key, NULL, &value); - return value.data != NULL; -} - -} // namespace duckdb_nanoarrow -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include -#include -#include - - - -namespace duckdb_nanoarrow { - -void ArrowSchemaRelease(struct ArrowSchema *schema) { - if (schema->format != NULL) - ArrowFree((void *)schema->format); - if (schema->name != NULL) - ArrowFree((void *)schema->name); - if (schema->metadata != NULL) - ArrowFree((void *)schema->metadata); - - // This object owns the memory for all the children, but those - // children may have been generated elsewhere and might have - // their own release() callback. - if (schema->children != NULL) { - for (int64_t i = 0; i < schema->n_children; i++) { - if (schema->children[i] != NULL) { - if (schema->children[i]->release != NULL) { - schema->children[i]->release(schema->children[i]); - } - - ArrowFree(schema->children[i]); - } - } - - ArrowFree(schema->children); - } - - // This object owns the memory for the dictionary but it - // may have been generated somewhere else and have its own - // release() callback. 
- if (schema->dictionary != NULL) { - if (schema->dictionary->release != NULL) { - schema->dictionary->release(schema->dictionary); - } - - ArrowFree(schema->dictionary); - } - - // private data not currently used - if (schema->private_data != NULL) { - ArrowFree(schema->private_data); - } - - schema->release = NULL; -} - -const char *ArrowSchemaFormatTemplate(enum ArrowType data_type) { - switch (data_type) { - case NANOARROW_TYPE_UNINITIALIZED: - return NULL; - case NANOARROW_TYPE_NA: - return "n"; - case NANOARROW_TYPE_BOOL: - return "b"; - - case NANOARROW_TYPE_UINT8: - return "C"; - case NANOARROW_TYPE_INT8: - return "c"; - case NANOARROW_TYPE_UINT16: - return "S"; - case NANOARROW_TYPE_INT16: - return "s"; - case NANOARROW_TYPE_UINT32: - return "I"; - case NANOARROW_TYPE_INT32: - return "i"; - case NANOARROW_TYPE_UINT64: - return "L"; - case NANOARROW_TYPE_INT64: - return "l"; - - case NANOARROW_TYPE_HALF_FLOAT: - return "e"; - case NANOARROW_TYPE_FLOAT: - return "f"; - case NANOARROW_TYPE_DOUBLE: - return "g"; - - case NANOARROW_TYPE_STRING: - return "u"; - case NANOARROW_TYPE_LARGE_STRING: - return "U"; - case NANOARROW_TYPE_BINARY: - return "z"; - case NANOARROW_TYPE_LARGE_BINARY: - return "Z"; - - case NANOARROW_TYPE_DATE32: - return "tdD"; - case NANOARROW_TYPE_DATE64: - return "tdm"; - case NANOARROW_TYPE_INTERVAL_MONTHS: - return "tiM"; - case NANOARROW_TYPE_INTERVAL_DAY_TIME: - return "tiD"; - case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: - return "tin"; - - case NANOARROW_TYPE_LIST: - return "+l"; - case NANOARROW_TYPE_LARGE_LIST: - return "+L"; - case NANOARROW_TYPE_STRUCT: - return "+s"; - case NANOARROW_TYPE_MAP: - return "+m"; - - default: - return NULL; - } -} - -ArrowErrorCode ArrowSchemaInit(struct ArrowSchema *schema, enum ArrowType data_type) { - schema->format = NULL; - schema->name = NULL; - schema->metadata = NULL; - schema->flags = ARROW_FLAG_NULLABLE; - schema->n_children = 0; - schema->children = NULL; - schema->dictionary = NULL; - 
schema->private_data = NULL; - schema->release = &ArrowSchemaRelease; - - // We don't allocate the dictionary because it has to be nullptr - // for non-dictionary-encoded arrays. - - // Set the format to a valid format string for data_type - const char *template_format = ArrowSchemaFormatTemplate(data_type); - - // If data_type isn't recognized and not explicitly unset - if (template_format == NULL && data_type != NANOARROW_TYPE_UNINITIALIZED) { - schema->release(schema); - return EINVAL; - } - - int result = ArrowSchemaSetFormat(schema, template_format); - if (result != NANOARROW_OK) { - schema->release(schema); - return result; - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowSchemaInitFixedSize(struct ArrowSchema *schema, enum ArrowType data_type, int32_t fixed_size) { - int result = ArrowSchemaInit(schema, NANOARROW_TYPE_UNINITIALIZED); - if (result != NANOARROW_OK) { - return result; - } - - if (fixed_size <= 0) { - schema->release(schema); - return EINVAL; - } - - char buffer[64]; - int n_chars; - switch (data_type) { - case NANOARROW_TYPE_FIXED_SIZE_BINARY: - n_chars = snprintf(buffer, sizeof(buffer), "w:%d", (int)fixed_size); - break; - case NANOARROW_TYPE_FIXED_SIZE_LIST: - n_chars = snprintf(buffer, sizeof(buffer), "+w:%d", (int)fixed_size); - break; - default: - schema->release(schema); - return EINVAL; - } - - buffer[n_chars] = '\0'; - result = ArrowSchemaSetFormat(schema, buffer); - if (result != NANOARROW_OK) { - schema->release(schema); - } - - return result; -} - -ArrowErrorCode ArrowSchemaInitDecimal(struct ArrowSchema *schema, enum ArrowType data_type, int32_t decimal_precision, - int32_t decimal_scale) { - int result = ArrowSchemaInit(schema, NANOARROW_TYPE_UNINITIALIZED); - if (result != NANOARROW_OK) { - return result; - } - - if (decimal_precision <= 0) { - schema->release(schema); - return EINVAL; - } - - char buffer[64]; - int n_chars; - switch (data_type) { - case NANOARROW_TYPE_DECIMAL128: - n_chars = snprintf(buffer, sizeof(buffer), 
"d:%d,%d", decimal_precision, decimal_scale); - break; - case NANOARROW_TYPE_DECIMAL256: - n_chars = snprintf(buffer, sizeof(buffer), "d:%d,%d,256", decimal_precision, decimal_scale); - break; - default: - schema->release(schema); - return EINVAL; - } - - buffer[n_chars] = '\0'; - - result = ArrowSchemaSetFormat(schema, buffer); - if (result != NANOARROW_OK) { - schema->release(schema); - return result; - } - - return NANOARROW_OK; -} - -static const char *ArrowTimeUnitString(enum ArrowTimeUnit time_unit) { - switch (time_unit) { - case NANOARROW_TIME_UNIT_SECOND: - return "s"; - case NANOARROW_TIME_UNIT_MILLI: - return "m"; - case NANOARROW_TIME_UNIT_MICRO: - return "u"; - case NANOARROW_TIME_UNIT_NANO: - return "n"; - default: - return NULL; - } -} - -ArrowErrorCode ArrowSchemaInitDateTime(struct ArrowSchema *schema, enum ArrowType data_type, - enum ArrowTimeUnit time_unit, const char *timezone) { - int result = ArrowSchemaInit(schema, NANOARROW_TYPE_UNINITIALIZED); - if (result != NANOARROW_OK) { - return result; - } - - const char *time_unit_str = ArrowTimeUnitString(time_unit); - if (time_unit_str == NULL) { - schema->release(schema); - return EINVAL; - } - - char buffer[128]; - int n_chars; - switch (data_type) { - case NANOARROW_TYPE_TIME32: - case NANOARROW_TYPE_TIME64: - if (timezone != NULL) { - schema->release(schema); - return EINVAL; - } - n_chars = snprintf(buffer, sizeof(buffer), "tt%s", time_unit_str); - break; - case NANOARROW_TYPE_TIMESTAMP: - if (timezone == NULL) { - timezone = ""; - } - n_chars = snprintf(buffer, sizeof(buffer), "ts%s:%s", time_unit_str, timezone); - break; - case NANOARROW_TYPE_DURATION: - if (timezone != NULL) { - schema->release(schema); - return EINVAL; - } - n_chars = snprintf(buffer, sizeof(buffer), "tD%s", time_unit_str); - break; - default: - schema->release(schema); - return EINVAL; - } - - if (static_cast(n_chars) >= sizeof(buffer)) { - schema->release(schema); - return ERANGE; - } - - buffer[n_chars] = '\0'; - - 
result = ArrowSchemaSetFormat(schema, buffer); - if (result != NANOARROW_OK) { - schema->release(schema); - return result; - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowSchemaSetFormat(struct ArrowSchema *schema, const char *format) { - if (schema->format != NULL) { - ArrowFree((void *)schema->format); - } - - if (format != NULL) { - size_t format_size = strlen(format) + 1; - schema->format = (const char *)ArrowMalloc(format_size); - if (schema->format == NULL) { - return ENOMEM; - } - - memcpy((void *)schema->format, format, format_size); - } else { - schema->format = NULL; - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowSchemaSetName(struct ArrowSchema *schema, const char *name) { - if (schema->name != NULL) { - ArrowFree((void *)schema->name); - } - - if (name != NULL) { - size_t name_size = strlen(name) + 1; - schema->name = (const char *)ArrowMalloc(name_size); - if (schema->name == NULL) { - return ENOMEM; - } - - memcpy((void *)schema->name, name, name_size); - } else { - schema->name = NULL; - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowSchemaSetMetadata(struct ArrowSchema *schema, const char *metadata) { - if (schema->metadata != NULL) { - ArrowFree((void *)schema->metadata); - } - - if (metadata != NULL) { - size_t metadata_size = ArrowMetadataSizeOf(metadata); - schema->metadata = (const char *)ArrowMalloc(metadata_size); - if (schema->metadata == NULL) { - return ENOMEM; - } - - memcpy((void *)schema->metadata, metadata, metadata_size); - } else { - schema->metadata = NULL; - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowSchemaAllocateChildren(struct ArrowSchema *schema, int64_t n_children) { - if (schema->children != NULL) { - return EEXIST; - } - - if (n_children > 0) { - schema->children = (struct ArrowSchema **)ArrowMalloc(n_children * sizeof(struct ArrowSchema *)); - - if (schema->children == NULL) { - return ENOMEM; - } - - schema->n_children = n_children; - - memset(schema->children, 0, n_children * sizeof(struct 
ArrowSchema *)); - - for (int64_t i = 0; i < n_children; i++) { - schema->children[i] = (struct ArrowSchema *)ArrowMalloc(sizeof(struct ArrowSchema)); - - if (schema->children[i] == NULL) { - return ENOMEM; - } - - schema->children[i]->release = NULL; - } - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowSchemaAllocateDictionary(struct ArrowSchema *schema) { - if (schema->dictionary != NULL) { - return EEXIST; - } - - schema->dictionary = (struct ArrowSchema *)ArrowMalloc(sizeof(struct ArrowSchema)); - if (schema->dictionary == NULL) { - return ENOMEM; - } - - schema->dictionary->release = NULL; - return NANOARROW_OK; -} - -int ArrowSchemaDeepCopy(struct ArrowSchema *schema, struct ArrowSchema *schema_out) { - int result; - result = ArrowSchemaInit(schema_out, NANOARROW_TYPE_NA); - if (result != NANOARROW_OK) { - return result; - } - - result = ArrowSchemaSetFormat(schema_out, schema->format); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - - result = ArrowSchemaSetName(schema_out, schema->name); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - - result = ArrowSchemaSetMetadata(schema_out, schema->metadata); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - - result = ArrowSchemaAllocateChildren(schema_out, schema->n_children); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - - for (int64_t i = 0; i < schema->n_children; i++) { - result = ArrowSchemaDeepCopy(schema->children[i], schema_out->children[i]); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - } - - if (schema->dictionary != NULL) { - result = ArrowSchemaAllocateDictionary(schema_out); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - - result = ArrowSchemaDeepCopy(schema->dictionary, schema_out->dictionary); - if (result != NANOARROW_OK) { - 
schema_out->release(schema_out); - return result; - } - } - - return NANOARROW_OK; -} - -} // namespace duckdb_nanoarrow - - - - - - - - -#include -#include -#include -#include -#include - -namespace duckdb_adbc { - -using duckdb_nanoarrow::ArrowSchemaDeepCopy; - -static const char *SingleBatchArrayStreamGetLastError(struct ArrowArrayStream *stream) { - return NULL; -} - -static int SingleBatchArrayStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *batch) { - if (!stream || !stream->private_data) { - return EINVAL; - } - struct SingleBatchArrayStream *impl = (struct SingleBatchArrayStream *)stream->private_data; - - memcpy(batch, &impl->batch, sizeof(*batch)); - memset(&impl->batch, 0, sizeof(*batch)); - return 0; -} - -static int SingleBatchArrayStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *schema) { - if (!stream || !stream->private_data) { - return EINVAL; - } - struct SingleBatchArrayStream *impl = (struct SingleBatchArrayStream *)stream->private_data; - - return ArrowSchemaDeepCopy(&impl->schema, schema); -} - -static void SingleBatchArrayStreamRelease(struct ArrowArrayStream *stream) { - if (!stream || !stream->private_data) { - return; - } - struct SingleBatchArrayStream *impl = (struct SingleBatchArrayStream *)stream->private_data; - impl->schema.release(&impl->schema); - if (impl->batch.release) { - impl->batch.release(&impl->batch); - } - free(impl); - - memset(stream, 0, sizeof(*stream)); -} - -AdbcStatusCode BatchToArrayStream(struct ArrowArray *values, struct ArrowSchema *schema, - struct ArrowArrayStream *stream, struct AdbcError *error) { - if (!values->release) { - SetError(error, "ArrowArray is not initialized"); - return ADBC_STATUS_INTERNAL; - } else if (!schema->release) { - SetError(error, "ArrowSchema is not initialized"); - return ADBC_STATUS_INTERNAL; - } else if (stream->release) { - SetError(error, "ArrowArrayStream is already initialized"); - return ADBC_STATUS_INTERNAL; - } - - struct 
SingleBatchArrayStream *impl = (struct SingleBatchArrayStream *)malloc(sizeof(*impl)); - memcpy(&impl->schema, schema, sizeof(*schema)); - memcpy(&impl->batch, values, sizeof(*values)); - memset(schema, 0, sizeof(*schema)); - memset(values, 0, sizeof(*values)); - stream->private_data = impl; - stream->get_last_error = SingleBatchArrayStreamGetLastError; - stream->get_next = SingleBatchArrayStreamGetNext; - stream->get_schema = SingleBatchArrayStreamGetSchema; - stream->release = SingleBatchArrayStreamRelease; - - return ADBC_STATUS_OK; -} - -} // namespace duckdb_adbc - - - - - - -#include - -#ifdef DUCKDB_DEBUG_ALLOCATION - - - - -#include -#endif - -#ifndef USE_JEMALLOC -#if defined(DUCKDB_EXTENSION_JEMALLOC_LINKED) && DUCKDB_EXTENSION_JEMALLOC_LINKED && !defined(WIN32) -#define USE_JEMALLOC -#endif -#endif - -#ifdef USE_JEMALLOC -#include "jemalloc_extension.hpp" -#endif - -namespace duckdb { - -AllocatedData::AllocatedData() : allocator(nullptr), pointer(nullptr), allocated_size(0) { -} - -AllocatedData::AllocatedData(Allocator &allocator, data_ptr_t pointer, idx_t allocated_size) - : allocator(&allocator), pointer(pointer), allocated_size(allocated_size) { - if (!pointer) { - throw InternalException("AllocatedData object constructed with nullptr"); - } -} -AllocatedData::~AllocatedData() { - Reset(); -} - -AllocatedData::AllocatedData(AllocatedData &&other) noexcept - : allocator(other.allocator), pointer(nullptr), allocated_size(0) { - std::swap(pointer, other.pointer); - std::swap(allocated_size, other.allocated_size); -} - -AllocatedData &AllocatedData::operator=(AllocatedData &&other) noexcept { - std::swap(allocator, other.allocator); - std::swap(pointer, other.pointer); - std::swap(allocated_size, other.allocated_size); - return *this; -} - -void AllocatedData::Reset() { - if (!pointer) { - return; - } - D_ASSERT(allocator); - allocator->FreeData(pointer, allocated_size); - allocated_size = 0; - pointer = nullptr; -} - 
-//===--------------------------------------------------------------------===// -// Debug Info -//===--------------------------------------------------------------------===// -struct AllocatorDebugInfo { -#ifdef DEBUG - AllocatorDebugInfo(); - ~AllocatorDebugInfo(); - - void AllocateData(data_ptr_t pointer, idx_t size); - void FreeData(data_ptr_t pointer, idx_t size); - void ReallocateData(data_ptr_t pointer, data_ptr_t new_pointer, idx_t old_size, idx_t new_size); - -private: - //! The number of bytes that are outstanding (i.e. that have been allocated - but not freed) - //! Used for debug purposes - atomic allocation_count; -#ifdef DUCKDB_DEBUG_ALLOCATION - mutex pointer_lock; - //! Set of active outstanding pointers together with stack traces - unordered_map> pointers; -#endif -#endif -}; - -PrivateAllocatorData::PrivateAllocatorData() { -} - -PrivateAllocatorData::~PrivateAllocatorData() { -} - -//===--------------------------------------------------------------------===// -// Allocator -//===--------------------------------------------------------------------===// -#ifdef USE_JEMALLOC -Allocator::Allocator() - : Allocator(JemallocExtension::Allocate, JemallocExtension::Free, JemallocExtension::Reallocate, nullptr) { -} -#else -Allocator::Allocator() - : Allocator(Allocator::DefaultAllocate, Allocator::DefaultFree, Allocator::DefaultReallocate, nullptr) { -} -#endif - -Allocator::Allocator(allocate_function_ptr_t allocate_function_p, free_function_ptr_t free_function_p, - reallocate_function_ptr_t reallocate_function_p, unique_ptr private_data_p) - : allocate_function(allocate_function_p), free_function(free_function_p), - reallocate_function(reallocate_function_p), private_data(std::move(private_data_p)) { - D_ASSERT(allocate_function); - D_ASSERT(free_function); - D_ASSERT(reallocate_function); -#ifdef DEBUG - if (!private_data) { - private_data = make_uniq(); - } - private_data->debug_info = make_uniq(); -#endif -} - -Allocator::~Allocator() { -} - 
-data_ptr_t Allocator::AllocateData(idx_t size) { - D_ASSERT(size > 0); - if (size >= MAXIMUM_ALLOC_SIZE) { - D_ASSERT(false); - throw InternalException("Requested allocation size of %llu is out of range - maximum allocation size is %llu", - size, MAXIMUM_ALLOC_SIZE); - } - auto result = allocate_function(private_data.get(), size); -#ifdef DEBUG - D_ASSERT(private_data); - private_data->debug_info->AllocateData(result, size); -#endif - if (!result) { - throw OutOfMemoryException("Failed to allocate block of %llu bytes", size); - } - return result; -} - -void Allocator::FreeData(data_ptr_t pointer, idx_t size) { - if (!pointer) { - return; - } - D_ASSERT(size > 0); -#ifdef DEBUG - D_ASSERT(private_data); - private_data->debug_info->FreeData(pointer, size); -#endif - free_function(private_data.get(), pointer, size); -} - -data_ptr_t Allocator::ReallocateData(data_ptr_t pointer, idx_t old_size, idx_t size) { - if (!pointer) { - return nullptr; - } - if (size >= MAXIMUM_ALLOC_SIZE) { - D_ASSERT(false); - throw InternalException( - "Requested re-allocation size of %llu is out of range - maximum allocation size is %llu", size, - MAXIMUM_ALLOC_SIZE); - } - auto new_pointer = reallocate_function(private_data.get(), pointer, old_size, size); -#ifdef DEBUG - D_ASSERT(private_data); - private_data->debug_info->ReallocateData(pointer, new_pointer, old_size, size); -#endif - if (!new_pointer) { - throw OutOfMemoryException("Failed to re-allocate block of %llu bytes", size); - } - return new_pointer; -} - -shared_ptr &Allocator::DefaultAllocatorReference() { - static shared_ptr DEFAULT_ALLOCATOR = make_shared(); - return DEFAULT_ALLOCATOR; -} - -Allocator &Allocator::DefaultAllocator() { - return *DefaultAllocatorReference(); -} - -void Allocator::ThreadFlush(idx_t threshold) { -#ifdef USE_JEMALLOC - JemallocExtension::ThreadFlush(threshold); -#endif -} - -//===--------------------------------------------------------------------===// -// Debug Info (extended) 
-//===--------------------------------------------------------------------===// -#ifdef DEBUG -AllocatorDebugInfo::AllocatorDebugInfo() { - allocation_count = 0; -} -AllocatorDebugInfo::~AllocatorDebugInfo() { -#ifdef DUCKDB_DEBUG_ALLOCATION - if (allocation_count != 0) { - printf("Outstanding allocations found for Allocator\n"); - for (auto &entry : pointers) { - printf("Allocation of size %llu at address %p\n", entry.second.first, (void *)entry.first); - printf("Stack trace:\n%s\n", entry.second.second.c_str()); - printf("\n"); - } - } -#endif - //! Verify that there is no outstanding memory still associated with the batched allocator - //! Only works for access to the batched allocator through the batched allocator interface - //! If this assertion triggers, enable DUCKDB_DEBUG_ALLOCATION for more information about the allocations - D_ASSERT(allocation_count == 0); -} - -void AllocatorDebugInfo::AllocateData(data_ptr_t pointer, idx_t size) { - allocation_count += size; -#ifdef DUCKDB_DEBUG_ALLOCATION - lock_guard l(pointer_lock); - pointers[pointer] = make_pair(size, Exception::GetStackTrace()); -#endif -} - -void AllocatorDebugInfo::FreeData(data_ptr_t pointer, idx_t size) { - D_ASSERT(allocation_count >= size); - allocation_count -= size; -#ifdef DUCKDB_DEBUG_ALLOCATION - lock_guard l(pointer_lock); - // verify that the pointer exists - D_ASSERT(pointers.find(pointer) != pointers.end()); - // verify that the stored size matches the passed in size - D_ASSERT(pointers[pointer].first == size); - // erase the pointer - pointers.erase(pointer); -#endif -} - -void AllocatorDebugInfo::ReallocateData(data_ptr_t pointer, data_ptr_t new_pointer, idx_t old_size, idx_t new_size) { - FreeData(pointer, old_size); - AllocateData(new_pointer, new_size); -} - -#endif - -} // namespace duckdb - - - -namespace duckdb { - -void ArrowBoolData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - auto byte_count = (capacity + 7) / 8; - 
result.main_buffer.reserve(byte_count); -} - -void ArrowBoolData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { - idx_t size = to - from; - UnifiedVectorFormat format; - input.ToUnifiedFormat(input_size, format); - - // we initialize both the validity and the bit set to 1's - ResizeValidity(append_data.validity, append_data.row_count + size); - ResizeValidity(append_data.main_buffer, append_data.row_count + size); - auto data = UnifiedVectorFormat::GetData(format); - - auto result_data = append_data.main_buffer.GetData(); - auto validity_data = append_data.validity.GetData(); - uint8_t current_bit; - idx_t current_byte; - GetBitPosition(append_data.row_count, current_byte, current_bit); - for (idx_t i = from; i < to; i++) { - auto source_idx = format.sel->get_index(i); - // append the validity mask - if (!format.validity.RowIsValid(source_idx)) { - SetNull(append_data, validity_data, current_byte, current_bit); - } else if (!data[source_idx]) { - UnsetBit(result_data, current_byte, current_bit); - } - NextBit(current_byte, current_bit); - } - append_data.row_count += size; -} - -void ArrowBoolData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { - result->n_buffers = 2; - result->buffers[1] = append_data.main_buffer.data(); -} - -} // namespace duckdb - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Lists -//===--------------------------------------------------------------------===// -void ArrowListData::AppendOffsets(ArrowAppendData &append_data, UnifiedVectorFormat &format, idx_t from, idx_t to, - vector &child_sel) { - // resize the offset buffer - the offset buffer holds the offsets into the child array - idx_t size = to - from; - append_data.main_buffer.resize(append_data.main_buffer.size() + sizeof(uint32_t) * (size + 1)); - auto data = UnifiedVectorFormat::GetData(format); - auto offset_data = 
append_data.main_buffer.GetData(); - if (append_data.row_count == 0) { - // first entry - offset_data[0] = 0; - } - // set up the offsets using the list entries - auto last_offset = offset_data[append_data.row_count]; - for (idx_t i = from; i < to; i++) { - auto source_idx = format.sel->get_index(i); - auto offset_idx = append_data.row_count + i + 1 - from; - - if (!format.validity.RowIsValid(source_idx)) { - offset_data[offset_idx] = last_offset; - continue; - } - - // append the offset data - auto list_length = data[source_idx].length; - last_offset += list_length; - offset_data[offset_idx] = last_offset; - - for (idx_t k = 0; k < list_length; k++) { - child_sel.push_back(data[source_idx].offset + k); - } - } -} - -void ArrowListData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - auto &child_type = ListType::GetChildType(type); - result.main_buffer.reserve((capacity + 1) * sizeof(uint32_t)); - auto child_buffer = ArrowAppender::InitializeChild(child_type, capacity, result.options); - result.child_data.push_back(std::move(child_buffer)); -} - -void ArrowListData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { - UnifiedVectorFormat format; - input.ToUnifiedFormat(input_size, format); - idx_t size = to - from; - vector child_indices; - AppendValidity(append_data, format, from, to); - ArrowListData::AppendOffsets(append_data, format, from, to, child_indices); - - // append the child vector of the list - SelectionVector child_sel(child_indices.data()); - auto &child = ListVector::GetEntry(input); - auto child_size = child_indices.size(); - Vector child_copy(child.GetType()); - child_copy.Slice(child, child_sel, child_size); - append_data.child_data[0]->append_vector(*append_data.child_data[0], child_copy, 0, child_size, child_size); - append_data.row_count += size; -} - -void ArrowListData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { - 
result->n_buffers = 2; - result->buffers[1] = append_data.main_buffer.data(); - - auto &child_type = ListType::GetChildType(type); - append_data.child_pointers.resize(1); - result->children = append_data.child_pointers.data(); - result->n_children = 1; - append_data.child_pointers[0] = ArrowAppender::FinalizeChild(child_type, *append_data.child_data[0]); -} - -} // namespace duckdb - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Maps -//===--------------------------------------------------------------------===// -void ArrowMapData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - // map types are stored in a (too) clever way - // the main buffer holds the null values and the offsets - // then we have a single child, which is a struct of the map_type, and the key_type - result.main_buffer.reserve((capacity + 1) * sizeof(uint32_t)); - - auto &key_type = MapType::KeyType(type); - auto &value_type = MapType::ValueType(type); - auto internal_struct = make_uniq(result.options); - internal_struct->child_data.push_back(ArrowAppender::InitializeChild(key_type, capacity, result.options)); - internal_struct->child_data.push_back(ArrowAppender::InitializeChild(value_type, capacity, result.options)); - - result.child_data.push_back(std::move(internal_struct)); -} - -void ArrowMapData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { - UnifiedVectorFormat format; - input.ToUnifiedFormat(input_size, format); - idx_t size = to - from; - AppendValidity(append_data, format, from, to); - vector child_indices; - ArrowListData::AppendOffsets(append_data, format, from, to, child_indices); - - SelectionVector child_sel(child_indices.data()); - auto &key_vector = MapVector::GetKeys(input); - auto &value_vector = MapVector::GetValues(input); - auto list_size = child_indices.size(); - - auto &struct_data = *append_data.child_data[0]; - auto &key_data 
= *struct_data.child_data[0]; - auto &value_data = *struct_data.child_data[1]; - - Vector key_vector_copy(key_vector.GetType()); - key_vector_copy.Slice(key_vector, child_sel, list_size); - Vector value_vector_copy(value_vector.GetType()); - value_vector_copy.Slice(value_vector, child_sel, list_size); - key_data.append_vector(key_data, key_vector_copy, 0, list_size, list_size); - value_data.append_vector(value_data, value_vector_copy, 0, list_size, list_size); - - append_data.row_count += size; - struct_data.row_count += size; -} - -void ArrowMapData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { - // set up the main map buffer - result->n_buffers = 2; - result->buffers[1] = append_data.main_buffer.data(); - - // the main map buffer has a single child: a struct - append_data.child_pointers.resize(1); - result->children = append_data.child_pointers.data(); - result->n_children = 1; - append_data.child_pointers[0] = ArrowAppender::FinalizeChild(type, *append_data.child_data[0]); - - // now that struct has two children: the key and the value type - auto &struct_data = *append_data.child_data[0]; - auto &struct_result = append_data.child_pointers[0]; - struct_data.child_pointers.resize(2); - struct_result->n_buffers = 1; - struct_result->n_children = 2; - struct_result->length = struct_data.child_data[0]->row_count; - struct_result->children = struct_data.child_pointers.data(); - - D_ASSERT(struct_data.child_data[0]->row_count == struct_data.child_data[1]->row_count); - - auto &key_type = MapType::KeyType(type); - auto &value_type = MapType::ValueType(type); - struct_data.child_pointers[0] = ArrowAppender::FinalizeChild(key_type, *struct_data.child_data[0]); - struct_data.child_pointers[1] = ArrowAppender::FinalizeChild(value_type, *struct_data.child_data[1]); - - // keys cannot have null values - if (struct_data.child_pointers[0]->null_count > 0) { - throw std::runtime_error("Arrow doesn't accept NULL keys on Maps"); - } -} - -} 
// namespace duckdb - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Structs -//===--------------------------------------------------------------------===// -void ArrowStructData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - auto &children = StructType::GetChildTypes(type); - for (auto &child : children) { - auto child_buffer = ArrowAppender::InitializeChild(child.second, capacity, result.options); - result.child_data.push_back(std::move(child_buffer)); - } -} - -void ArrowStructData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { - UnifiedVectorFormat format; - input.ToUnifiedFormat(input_size, format); - idx_t size = to - from; - AppendValidity(append_data, format, from, to); - // append the children of the struct - auto &children = StructVector::GetEntries(input); - for (idx_t child_idx = 0; child_idx < children.size(); child_idx++) { - auto &child = children[child_idx]; - auto &child_data = *append_data.child_data[child_idx]; - child_data.append_vector(child_data, *child, from, to, size); - } - append_data.row_count += size; -} - -void ArrowStructData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { - result->n_buffers = 1; - - auto &child_types = StructType::GetChildTypes(type); - append_data.child_pointers.resize(child_types.size()); - result->children = append_data.child_pointers.data(); - result->n_children = child_types.size(); - for (idx_t i = 0; i < child_types.size(); i++) { - auto &child_type = child_types[i].second; - append_data.child_pointers[i] = ArrowAppender::FinalizeChild(child_type, *append_data.child_data[i]); - } -} - -} // namespace duckdb - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Unions -//===--------------------------------------------------------------------===// -void 
ArrowUnionData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - result.main_buffer.reserve(capacity * sizeof(int8_t)); - - for (auto &child : UnionType::CopyMemberTypes(type)) { - auto child_buffer = ArrowAppender::InitializeChild(child.second, capacity, result.options); - result.child_data.push_back(std::move(child_buffer)); - } -} - -void ArrowUnionData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { - UnifiedVectorFormat format; - input.ToUnifiedFormat(input_size, format); - idx_t size = to - from; - - auto &types_buffer = append_data.main_buffer; - - duckdb::vector child_vectors; - for (const auto &child : UnionType::CopyMemberTypes(input.GetType())) { - child_vectors.emplace_back(child.second); - } - - for (idx_t input_idx = from; input_idx < to; input_idx++) { - const auto &val = input.GetValue(input_idx); - - idx_t tag = 0; - Value resolved_value(nullptr); - if (!val.IsNull()) { - tag = UnionValue::GetTag(val); - - resolved_value = UnionValue::GetValue(val); - } - - for (idx_t child_idx = 0; child_idx < child_vectors.size(); child_idx++) { - child_vectors[child_idx].SetValue(input_idx, child_idx == tag ? 
resolved_value : Value(nullptr)); - } - - types_buffer.data()[input_idx] = tag; - } - - for (idx_t child_idx = 0; child_idx < child_vectors.size(); child_idx++) { - auto &child_buffer = append_data.child_data[child_idx]; - auto &child = child_vectors[child_idx]; - child_buffer->append_vector(*child_buffer, child, from, to, size); - } - append_data.row_count += size; -} - -void ArrowUnionData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { - result->n_buffers = 2; - result->buffers[1] = append_data.main_buffer.data(); - - auto &child_types = UnionType::CopyMemberTypes(type); - append_data.child_pointers.resize(child_types.size()); - result->children = append_data.child_pointers.data(); - result->n_children = child_types.size(); - for (idx_t i = 0; i < child_types.size(); i++) { - auto &child_type = child_types[i].second; - append_data.child_pointers[i] = ArrowAppender::FinalizeChild(child_type, *append_data.child_data[i]); - } -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// ArrowAppender -//===--------------------------------------------------------------------===// - -ArrowAppender::ArrowAppender(vector types_p, idx_t initial_capacity, ClientProperties options) - : types(std::move(types_p)) { - for (auto &type : types) { - auto entry = ArrowAppender::InitializeChild(type, initial_capacity, options); - root_data.push_back(std::move(entry)); - } -} - -ArrowAppender::~ArrowAppender() { -} - -//! 
Append a data chunk to the underlying arrow array -void ArrowAppender::Append(DataChunk &input, idx_t from, idx_t to, idx_t input_size) { - D_ASSERT(types == input.GetTypes()); - D_ASSERT(to >= from); - for (idx_t i = 0; i < input.ColumnCount(); i++) { - root_data[i]->append_vector(*root_data[i], input.data[i], from, to, input_size); - } - row_count += to - from; -} - -void ArrowAppender::ReleaseArray(ArrowArray *array) { - if (!array || !array->release) { - return; - } - array->release = nullptr; - auto holder = static_cast(array->private_data); - delete holder; -} - -//===--------------------------------------------------------------------===// -// Finalize Arrow Child -//===--------------------------------------------------------------------===// -ArrowArray *ArrowAppender::FinalizeChild(const LogicalType &type, ArrowAppendData &append_data) { - auto result = make_uniq(); - - result->private_data = nullptr; - result->release = ArrowAppender::ReleaseArray; - result->n_children = 0; - result->null_count = 0; - result->offset = 0; - result->dictionary = nullptr; - result->buffers = append_data.buffers.data(); - result->null_count = append_data.null_count; - result->length = append_data.row_count; - result->buffers[0] = append_data.validity.data(); - - if (append_data.finalize) { - append_data.finalize(append_data, type, result.get()); - } - - append_data.array = std::move(result); - return append_data.array.get(); -} - -//! 
Returns the underlying arrow array -ArrowArray ArrowAppender::Finalize() { - D_ASSERT(root_data.size() == types.size()); - auto root_holder = make_uniq(options); - - ArrowArray result; - root_holder->child_pointers.resize(types.size()); - result.children = root_holder->child_pointers.data(); - result.n_children = types.size(); - - // Configure root array - result.length = row_count; - result.n_buffers = 1; - result.buffers = root_holder->buffers.data(); // there is no actual buffer there since we don't have NULLs - result.offset = 0; - result.null_count = 0; // needs to be 0 - result.dictionary = nullptr; - root_holder->child_data = std::move(root_data); - - // FIXME: this violates a property of the arrow format, if root owns all the child memory then consumers can't move - // child arrays https://arrow.apache.org/docs/format/CDataInterface.html#moving-child-arrays - for (idx_t i = 0; i < root_holder->child_data.size(); i++) { - root_holder->child_pointers[i] = ArrowAppender::FinalizeChild(types[i], *root_holder->child_data[i]); - } - - // Release ownership to caller - result.private_data = root_holder.release(); - result.release = ArrowAppender::ReleaseArray; - return result; -} - -//===--------------------------------------------------------------------===// -// Initialize Arrow Child -//===--------------------------------------------------------------------===// - -template -static void InitializeAppenderForType(ArrowAppendData &append_data) { - append_data.initialize = OP::Initialize; - append_data.append_vector = OP::Append; - append_data.finalize = OP::Finalize; -} - -static void InitializeFunctionPointers(ArrowAppendData &append_data, const LogicalType &type) { - // handle special logical types - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - InitializeAppenderForType(append_data); - break; - case LogicalTypeId::TINYINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::SMALLINT: - InitializeAppenderForType>(append_data); - 
break; - case LogicalTypeId::DATE: - case LogicalTypeId::INTEGER: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::TIME: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::BIGINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::HUGEINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::UTINYINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::USMALLINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::UINTEGER: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::UBIGINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::FLOAT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::DOUBLE: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - InitializeAppenderForType>(append_data); - break; - case PhysicalType::INT32: - InitializeAppenderForType>(append_data); - break; - case PhysicalType::INT64: - InitializeAppenderForType>(append_data); - break; - case PhysicalType::INT128: - InitializeAppenderForType>(append_data); - break; - default: - throw InternalException("Unsupported internal decimal type"); - } - break; - case LogicalTypeId::VARCHAR: - case LogicalTypeId::BLOB: - case LogicalTypeId::BIT: - if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { - InitializeAppenderForType>(append_data); - } else { - InitializeAppenderForType>(append_data); - } - break; - case LogicalTypeId::UUID: - if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { - InitializeAppenderForType>(append_data); - } else { - InitializeAppenderForType>(append_data); - } - break; - case LogicalTypeId::ENUM: 
- switch (type.InternalType()) { - case PhysicalType::UINT8: - InitializeAppenderForType>(append_data); - break; - case PhysicalType::UINT16: - InitializeAppenderForType>(append_data); - break; - case PhysicalType::UINT32: - InitializeAppenderForType>(append_data); - break; - default: - throw InternalException("Unsupported internal enum type"); - } - break; - case LogicalTypeId::INTERVAL: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::UNION: - InitializeAppenderForType(append_data); - break; - case LogicalTypeId::STRUCT: - InitializeAppenderForType(append_data); - break; - case LogicalTypeId::LIST: - InitializeAppenderForType(append_data); - break; - case LogicalTypeId::MAP: - InitializeAppenderForType(append_data); - break; - default: - throw NotImplementedException("Unsupported type in DuckDB -> Arrow Conversion: %s\n", type.ToString()); - } -} - -unique_ptr ArrowAppender::InitializeChild(const LogicalType &type, idx_t capacity, - ClientProperties &options) { - auto result = make_uniq(options); - InitializeFunctionPointers(*result, type); - - auto byte_count = (capacity + 7) / 8; - result->validity.reserve(byte_count); - result->initialize(*result, type, capacity); - return result; -} - -} // namespace duckdb - - - - - - - - - - - -#include - - -namespace duckdb { - -void ArrowConverter::ToArrowArray(DataChunk &input, ArrowArray *out_array, ClientProperties options) { - ArrowAppender appender(input.GetTypes(), input.size(), std::move(options)); - appender.Append(input, 0, input.size(), input.size()); - *out_array = appender.Finalize(); -} - -unsafe_unique_array AddName(const string &name) { - auto name_ptr = make_unsafe_uniq_array(name.size() + 1); - for (size_t i = 0; i < name.size(); i++) { - name_ptr[i] = name[i]; - } - name_ptr[name.size()] = '\0'; - return name_ptr; -} - -//===--------------------------------------------------------------------===// -// Arrow Schema 
-//===--------------------------------------------------------------------===// -struct DuckDBArrowSchemaHolder { - // unused in children - vector children; - // unused in children - vector children_ptrs; - //! used for nested structures - std::list> nested_children; - std::list> nested_children_ptr; - //! This holds strings created to represent decimal types - vector> owned_type_names; - vector> owned_column_names; -}; - -static void ReleaseDuckDBArrowSchema(ArrowSchema *schema) { - if (!schema || !schema->release) { - return; - } - schema->release = nullptr; - auto holder = static_cast(schema->private_data); - delete holder; -} - -void InitializeChild(ArrowSchema &child, DuckDBArrowSchemaHolder &root_holder, const string &name = "") { - //! Child is cleaned up by parent - child.private_data = nullptr; - child.release = ReleaseDuckDBArrowSchema; - - // Store the child schema - child.flags = ARROW_FLAG_NULLABLE; - root_holder.owned_type_names.push_back(AddName(name)); - - child.name = root_holder.owned_type_names.back().get(); - child.n_children = 0; - child.children = nullptr; - child.metadata = nullptr; - child.dictionary = nullptr; -} - -void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, const LogicalType &type, - const ClientProperties &options); - -void SetArrowMapFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, const LogicalType &type, - const ClientProperties &options) { - child.format = "+m"; - //! 
Map has one child which is a struct - child.n_children = 1; - root_holder.nested_children.emplace_back(); - root_holder.nested_children.back().resize(1); - root_holder.nested_children_ptr.emplace_back(); - root_holder.nested_children_ptr.back().push_back(&root_holder.nested_children.back()[0]); - InitializeChild(root_holder.nested_children.back()[0], root_holder); - child.children = &root_holder.nested_children_ptr.back()[0]; - child.children[0]->name = "entries"; - SetArrowFormat(root_holder, **child.children, ListType::GetChildType(type), options); -} - -void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, const LogicalType &type, - const ClientProperties &options) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - child.format = "b"; - break; - case LogicalTypeId::TINYINT: - child.format = "c"; - break; - case LogicalTypeId::SMALLINT: - child.format = "s"; - break; - case LogicalTypeId::INTEGER: - child.format = "i"; - break; - case LogicalTypeId::BIGINT: - child.format = "l"; - break; - case LogicalTypeId::UTINYINT: - child.format = "C"; - break; - case LogicalTypeId::USMALLINT: - child.format = "S"; - break; - case LogicalTypeId::UINTEGER: - child.format = "I"; - break; - case LogicalTypeId::UBIGINT: - child.format = "L"; - break; - case LogicalTypeId::FLOAT: - child.format = "f"; - break; - case LogicalTypeId::HUGEINT: - child.format = "d:38,0"; - break; - case LogicalTypeId::DOUBLE: - child.format = "g"; - break; - case LogicalTypeId::UUID: - case LogicalTypeId::VARCHAR: - if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { - child.format = "U"; - } else { - child.format = "u"; - } - break; - case LogicalTypeId::DATE: - child.format = "tdD"; - break; -#ifdef DUCKDB_WASM - case LogicalTypeId::TIME_TZ: -#endif - case LogicalTypeId::TIME: - child.format = "ttu"; - break; - case LogicalTypeId::TIMESTAMP: - child.format = "tsu:"; - break; - case LogicalTypeId::TIMESTAMP_TZ: { - string format = "tsu:" + options.time_zone; - 
root_holder.owned_type_names.push_back(AddName(format)); - child.format = root_holder.owned_type_names.back().get(); - break; - } - case LogicalTypeId::TIMESTAMP_SEC: - child.format = "tss:"; - break; - case LogicalTypeId::TIMESTAMP_NS: - child.format = "tsn:"; - break; - case LogicalTypeId::TIMESTAMP_MS: - child.format = "tsm:"; - break; - case LogicalTypeId::INTERVAL: - child.format = "tin"; - break; - case LogicalTypeId::DECIMAL: { - uint8_t width, scale; - type.GetDecimalProperties(width, scale); - string format = "d:" + to_string(width) + "," + to_string(scale); - root_holder.owned_type_names.push_back(AddName(format)); - child.format = root_holder.owned_type_names.back().get(); - break; - } - case LogicalTypeId::SQLNULL: { - child.format = "n"; - break; - } - case LogicalTypeId::BLOB: - case LogicalTypeId::BIT: { - if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { - child.format = "Z"; - } else { - child.format = "z"; - } - break; - } - case LogicalTypeId::LIST: { - child.format = "+l"; - child.n_children = 1; - root_holder.nested_children.emplace_back(); - root_holder.nested_children.back().resize(1); - root_holder.nested_children_ptr.emplace_back(); - root_holder.nested_children_ptr.back().push_back(&root_holder.nested_children.back()[0]); - InitializeChild(root_holder.nested_children.back()[0], root_holder); - child.children = &root_holder.nested_children_ptr.back()[0]; - child.children[0]->name = "l"; - SetArrowFormat(root_holder, **child.children, ListType::GetChildType(type), options); - break; - } - case LogicalTypeId::STRUCT: { - child.format = "+s"; - auto &child_types = StructType::GetChildTypes(type); - child.n_children = child_types.size(); - root_holder.nested_children.emplace_back(); - root_holder.nested_children.back().resize(child_types.size()); - root_holder.nested_children_ptr.emplace_back(); - root_holder.nested_children_ptr.back().resize(child_types.size()); - for (idx_t type_idx = 0; type_idx < child_types.size(); type_idx++) { - 
root_holder.nested_children_ptr.back()[type_idx] = &root_holder.nested_children.back()[type_idx]; - } - child.children = &root_holder.nested_children_ptr.back()[0]; - for (size_t type_idx = 0; type_idx < child_types.size(); type_idx++) { - - InitializeChild(*child.children[type_idx], root_holder); - - root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first)); - - child.children[type_idx]->name = root_holder.owned_type_names.back().get(); - SetArrowFormat(root_holder, *child.children[type_idx], child_types[type_idx].second, options); - } - break; - } - case LogicalTypeId::MAP: { - SetArrowMapFormat(root_holder, child, type, options); - break; - } - case LogicalTypeId::UNION: { - std::string format = "+us:"; - - auto &child_types = UnionType::CopyMemberTypes(type); - child.n_children = child_types.size(); - root_holder.nested_children.emplace_back(); - root_holder.nested_children.back().resize(child_types.size()); - root_holder.nested_children_ptr.emplace_back(); - root_holder.nested_children_ptr.back().resize(child_types.size()); - for (idx_t type_idx = 0; type_idx < child_types.size(); type_idx++) { - root_holder.nested_children_ptr.back()[type_idx] = &root_holder.nested_children.back()[type_idx]; - } - child.children = &root_holder.nested_children_ptr.back()[0]; - for (size_t type_idx = 0; type_idx < child_types.size(); type_idx++) { - - InitializeChild(*child.children[type_idx], root_holder); - - root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first)); - - child.children[type_idx]->name = root_holder.owned_type_names.back().get(); - SetArrowFormat(root_holder, *child.children[type_idx], child_types[type_idx].second, options); - - format += to_string(type_idx) + ","; - } - - format.pop_back(); - - root_holder.owned_type_names.push_back(AddName(format)); - child.format = root_holder.owned_type_names.back().get(); - - break; - } - case LogicalTypeId::ENUM: { - // TODO what do we do with pointer enums here? 
- switch (EnumType::GetPhysicalType(type)) { - case PhysicalType::UINT8: - child.format = "C"; - break; - case PhysicalType::UINT16: - child.format = "S"; - break; - case PhysicalType::UINT32: - child.format = "I"; - break; - default: - throw InternalException("Unsupported Enum Internal Type"); - } - root_holder.nested_children.emplace_back(); - root_holder.nested_children.back().resize(1); - root_holder.nested_children_ptr.emplace_back(); - root_holder.nested_children_ptr.back().push_back(&root_holder.nested_children.back()[0]); - InitializeChild(root_holder.nested_children.back()[0], root_holder); - child.dictionary = root_holder.nested_children_ptr.back()[0]; - child.dictionary->format = "u"; - break; - } - default: - throw NotImplementedException("Unsupported Arrow type " + type.ToString()); - } -} - -void ArrowConverter::ToArrowSchema(ArrowSchema *out_schema, const vector &types, - const vector &names, const ClientProperties &options) { - D_ASSERT(out_schema); - D_ASSERT(types.size() == names.size()); - idx_t column_count = types.size(); - // Allocate as unique_ptr first to cleanup properly on error - auto root_holder = make_uniq(); - - // Allocate the children - root_holder->children.resize(column_count); - root_holder->children_ptrs.resize(column_count, nullptr); - for (size_t i = 0; i < column_count; ++i) { - root_holder->children_ptrs[i] = &root_holder->children[i]; - } - out_schema->children = root_holder->children_ptrs.data(); - out_schema->n_children = column_count; - - // Store the schema - out_schema->format = "+s"; // struct apparently - out_schema->flags = 0; - out_schema->metadata = nullptr; - out_schema->name = "duckdb_query_result"; - out_schema->dictionary = nullptr; - - // Configure all child schemas - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - root_holder->owned_column_names.push_back(AddName(names[col_idx])); - auto &child = root_holder->children[col_idx]; - InitializeChild(child, *root_holder, names[col_idx]); - 
SetArrowFormat(*root_holder, child, types[col_idx], options); - } - - // Release ownership to caller - out_schema->private_data = root_holder.release(); - out_schema->release = ReleaseDuckDBArrowSchema; -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -ArrowSchemaWrapper::~ArrowSchemaWrapper() { - if (arrow_schema.release) { - arrow_schema.release(&arrow_schema); - arrow_schema.release = nullptr; - } -} - -ArrowArrayWrapper::~ArrowArrayWrapper() { - if (arrow_array.release) { - arrow_array.release(&arrow_array); - arrow_array.release = nullptr; - } -} - -ArrowArrayStreamWrapper::~ArrowArrayStreamWrapper() { - if (arrow_array_stream.release) { - arrow_array_stream.release(&arrow_array_stream); - arrow_array_stream.release = nullptr; - } -} - -void ArrowArrayStreamWrapper::GetSchema(ArrowSchemaWrapper &schema) { - D_ASSERT(arrow_array_stream.get_schema); - // LCOV_EXCL_START - if (arrow_array_stream.get_schema(&arrow_array_stream, &schema.arrow_schema)) { - throw InvalidInputException("arrow_scan: get_schema failed(): %s", string(GetError())); - } - if (!schema.arrow_schema.release) { - throw InvalidInputException("arrow_scan: released schema passed"); - } - if (schema.arrow_schema.n_children < 1) { - throw InvalidInputException("arrow_scan: empty schema passed"); - } - // LCOV_EXCL_STOP -} - -shared_ptr ArrowArrayStreamWrapper::GetNextChunk() { - auto current_chunk = make_shared(); - if (arrow_array_stream.get_next(&arrow_array_stream, ¤t_chunk->arrow_array)) { // LCOV_EXCL_START - throw InvalidInputException("arrow_scan: get_next failed(): %s", string(GetError())); - } // LCOV_EXCL_STOP - - return current_chunk; -} - -const char *ArrowArrayStreamWrapper::GetError() { // LCOV_EXCL_START - return arrow_array_stream.get_last_error(&arrow_array_stream); -} // LCOV_EXCL_STOP - -int ResultArrowArrayStreamWrapper::MyStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out) { - if (!stream->release) { - return -1; - } - auto 
my_stream = reinterpret_cast(stream->private_data); - if (!my_stream->column_types.empty()) { - ArrowConverter::ToArrowSchema(out, my_stream->column_types, my_stream->column_names, - my_stream->result->client_properties); - return 0; - } - - auto &result = *my_stream->result; - if (result.HasError()) { - my_stream->last_error = result.GetErrorObject(); - return -1; - } - if (result.type == QueryResultType::STREAM_RESULT) { - auto &stream_result = result.Cast(); - if (!stream_result.IsOpen()) { - my_stream->last_error = PreservedError("Query Stream is closed"); - return -1; - } - } - if (my_stream->column_types.empty()) { - my_stream->column_types = result.types; - my_stream->column_names = result.names; - } - ArrowConverter::ToArrowSchema(out, my_stream->column_types, my_stream->column_names, - my_stream->result->client_properties); - return 0; -} - -int ResultArrowArrayStreamWrapper::MyStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *out) { - if (!stream->release) { - return -1; - } - auto my_stream = reinterpret_cast(stream->private_data); - auto &result = *my_stream->result; - auto &scan_state = *my_stream->scan_state; - if (result.HasError()) { - my_stream->last_error = result.GetErrorObject(); - return -1; - } - if (result.type == QueryResultType::STREAM_RESULT) { - auto &stream_result = result.Cast(); - if (!stream_result.IsOpen()) { - // Nothing to output - out->release = nullptr; - return 0; - } - } - if (my_stream->column_types.empty()) { - my_stream->column_types = result.types; - my_stream->column_names = result.names; - } - idx_t result_count; - PreservedError error; - if (!ArrowUtil::TryFetchChunk(scan_state, result.client_properties, my_stream->batch_size, out, result_count, - error)) { - D_ASSERT(error); - my_stream->last_error = error; - return -1; - } - if (result_count == 0) { - // Nothing to output - out->release = nullptr; - } - return 0; -} - -void ResultArrowArrayStreamWrapper::MyStreamRelease(struct ArrowArrayStream *stream) 
{ - if (!stream || !stream->release) { - return; - } - stream->release = nullptr; - delete reinterpret_cast(stream->private_data); -} - -const char *ResultArrowArrayStreamWrapper::MyStreamGetLastError(struct ArrowArrayStream *stream) { - if (!stream->release) { - return "stream was released"; - } - D_ASSERT(stream->private_data); - auto my_stream = reinterpret_cast(stream->private_data); - return my_stream->last_error.Message().c_str(); -} - -ResultArrowArrayStreamWrapper::ResultArrowArrayStreamWrapper(unique_ptr result_p, idx_t batch_size_p) - : result(std::move(result_p)), scan_state(make_uniq(*result)) { - //! We first initialize the private data of the stream - stream.private_data = this; - //! Ceil Approx_Batch_Size/STANDARD_VECTOR_SIZE - if (batch_size_p == 0) { - throw std::runtime_error("Approximate Batch Size of Record Batch MUST be higher than 0"); - } - batch_size = batch_size_p; - //! We initialize the stream functions - stream.get_schema = ResultArrowArrayStreamWrapper::MyStreamGetSchema; - stream.get_next = ResultArrowArrayStreamWrapper::MyStreamGetNext; - stream.release = ResultArrowArrayStreamWrapper::MyStreamRelease; - stream.get_last_error = ResultArrowArrayStreamWrapper::MyStreamGetLastError; -} - -bool ArrowUtil::TryFetchChunk(ChunkScanState &scan_state, ClientProperties options, idx_t batch_size, ArrowArray *out, - idx_t &count, PreservedError &error) { - count = 0; - ArrowAppender appender(scan_state.Types(), batch_size, std::move(options)); - auto remaining_tuples_in_chunk = scan_state.RemainingInChunk(); - if (remaining_tuples_in_chunk) { - // We start by scanning the non-finished current chunk - idx_t cur_consumption = MinValue(remaining_tuples_in_chunk, batch_size); - count += cur_consumption; - auto ¤t_chunk = scan_state.CurrentChunk(); - appender.Append(current_chunk, scan_state.CurrentOffset(), scan_state.CurrentOffset() + cur_consumption, - current_chunk.size()); - scan_state.IncreaseOffset(cur_consumption); - } - while (count < 
batch_size) { - if (!scan_state.LoadNextChunk(error)) { - if (scan_state.HasError()) { - error = scan_state.GetError(); - } - return false; - } - if (scan_state.ChunkIsEmpty()) { - // The scan was successful, but an empty chunk was returned - break; - } - auto ¤t_chunk = scan_state.CurrentChunk(); - if (scan_state.Finished() || current_chunk.size() == 0) { - break; - } - // The amount we still need to append into this chunk - auto remaining = batch_size - count; - - // The amount remaining, capped by the amount left in the current chunk - auto to_append_to_batch = MinValue(remaining, scan_state.RemainingInChunk()); - appender.Append(current_chunk, 0, to_append_to_batch, current_chunk.size()); - count += to_append_to_batch; - scan_state.IncreaseOffset(to_append_to_batch); - } - if (count > 0) { - *out = appender.Finalize(); - } - return true; -} - -idx_t ArrowUtil::FetchChunk(ChunkScanState &scan_state, ClientProperties options, idx_t chunk_size, ArrowArray *out) { - PreservedError error; - idx_t result_count; - if (!TryFetchChunk(scan_state, std::move(options), chunk_size, out, result_count, error)) { - error.Throw(); - } - return result_count; -} - -} // namespace duckdb - - - -namespace duckdb { - -void DuckDBAssertInternal(bool condition, const char *condition_name, const char *file, int linenr) { -#ifdef DISABLE_ASSERTIONS - return; -#endif - if (condition) { - return; - } - throw InternalException("Assertion triggered in file \"%s\" on line %d: %s%s", file, linenr, condition_name, - Exception::GetStackTrace()); -} - -} // namespace duckdb - - - - - - -#include - -namespace duckdb { - -Value ConvertVectorToValue(vector set) { - if (set.empty()) { - return Value::EMPTYLIST(LogicalType::BOOLEAN); - } - return Value::LIST(std::move(set)); -} - -vector ParseColumnList(const vector &set, vector &names, const string &loption) { - vector result; - - if (set.empty()) { - throw BinderException("\"%s\" expects a column list or * as parameter", loption); - } - // list of 
options: parse the list - case_insensitive_map_t option_map; - for (idx_t i = 0; i < set.size(); i++) { - option_map[set[i].ToString()] = false; - } - result.resize(names.size(), false); - for (idx_t i = 0; i < names.size(); i++) { - auto entry = option_map.find(names[i]); - if (entry != option_map.end()) { - result[i] = true; - entry->second = true; - } - } - for (auto &entry : option_map) { - if (!entry.second) { - throw BinderException("\"%s\" expected to find %s, but it was not found in the table", loption, - entry.first.c_str()); - } - } - return result; -} - -vector ParseColumnList(const Value &value, vector &names, const string &loption) { - vector result; - - // Only accept a list of arguments - if (value.type().id() != LogicalTypeId::LIST) { - // Support a single argument if it's '*' - if (value.type().id() == LogicalTypeId::VARCHAR && value.GetValue() == "*") { - result.resize(names.size(), true); - return result; - } - throw BinderException("\"%s\" expects a column list or * as parameter", loption); - } - auto &children = ListValue::GetChildren(value); - // accept '*' as single argument - if (children.size() == 1 && children[0].type().id() == LogicalTypeId::VARCHAR && - children[0].GetValue() == "*") { - result.resize(names.size(), true); - return result; - } - return ParseColumnList(children, names, loption); -} - -vector ParseColumnsOrdered(const vector &set, vector &names, const string &loption) { - vector result; - - if (set.empty()) { - throw BinderException("\"%s\" expects a column list or * as parameter", loption); - } - - // Maps option to bool indicating if its found and the index in the original set - case_insensitive_map_t> option_map; - for (idx_t i = 0; i < set.size(); i++) { - option_map[set[i].ToString()] = {false, i}; - } - result.resize(option_map.size()); - - for (idx_t i = 0; i < names.size(); i++) { - auto entry = option_map.find(names[i]); - if (entry != option_map.end()) { - result[entry->second.second] = i; - entry->second.first = 
true; - } - } - for (auto &entry : option_map) { - if (!entry.second.first) { - throw BinderException("\"%s\" expected to find %s, but it was not found in the table", loption, - entry.first.c_str()); - } - } - return result; -} - -vector ParseColumnsOrdered(const Value &value, vector &names, const string &loption) { - vector result; - - // Only accept a list of arguments - if (value.type().id() != LogicalTypeId::LIST) { - // Support a single argument if it's '*' - if (value.type().id() == LogicalTypeId::VARCHAR && value.GetValue() == "*") { - result.resize(names.size(), 0); - std::iota(std::begin(result), std::end(result), 0); - return result; - } - throw BinderException("\"%s\" expects a column list or * as parameter", loption); - } - auto &children = ListValue::GetChildren(value); - // accept '*' as single argument - if (children.size() == 1 && children[0].type().id() == LogicalTypeId::VARCHAR && - children[0].GetValue() == "*") { - result.resize(names.size(), 0); - std::iota(std::begin(result), std::end(result), 0); - return result; - } - return ParseColumnsOrdered(children, names, loption); -} - -} // namespace duckdb - - - - - - - -#include - -namespace duckdb { - -const idx_t BoxRenderer::SPLIT_COLUMN = idx_t(-1); - -BoxRenderer::BoxRenderer(BoxRendererConfig config_p) : config(std::move(config_p)) { -} - -string BoxRenderer::ToString(ClientContext &context, const vector &names, const ColumnDataCollection &result) { - std::stringstream ss; - Render(context, names, result, ss); - return ss.str(); -} - -void BoxRenderer::Print(ClientContext &context, const vector &names, const ColumnDataCollection &result) { - Printer::Print(ToString(context, names, result)); -} - -void BoxRenderer::RenderValue(std::ostream &ss, const string &value, idx_t column_width, - ValueRenderAlignment alignment) { - auto render_width = Utf8Proc::RenderWidth(value); - - const string *render_value = &value; - string small_value; - if (render_width > column_width) { - // the string is too 
large to fit in this column! - // the size of this column must have been reduced - // figure out how much of this value we can render - idx_t pos = 0; - idx_t current_render_width = config.DOTDOTDOT_LENGTH; - while (pos < value.size()) { - // check if this character fits... - auto char_size = Utf8Proc::RenderWidth(value.c_str(), value.size(), pos); - if (current_render_width + char_size >= column_width) { - // it doesn't! stop - break; - } - // it does! move to the next character - current_render_width += char_size; - pos = Utf8Proc::NextGraphemeCluster(value.c_str(), value.size(), pos); - } - small_value = value.substr(0, pos) + config.DOTDOTDOT; - render_value = &small_value; - render_width = current_render_width; - } - auto padding_count = (column_width - render_width) + 2; - idx_t lpadding; - idx_t rpadding; - switch (alignment) { - case ValueRenderAlignment::LEFT: - lpadding = 1; - rpadding = padding_count - 1; - break; - case ValueRenderAlignment::MIDDLE: - lpadding = padding_count / 2; - rpadding = padding_count - lpadding; - break; - case ValueRenderAlignment::RIGHT: - lpadding = padding_count - 1; - rpadding = 1; - break; - default: - throw InternalException("Unrecognized value renderer alignment"); - } - ss << config.VERTICAL; - ss << string(lpadding, ' '); - ss << *render_value; - ss << string(rpadding, ' '); -} - -string BoxRenderer::RenderType(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return "int8"; - case LogicalTypeId::SMALLINT: - return "int16"; - case LogicalTypeId::INTEGER: - return "int32"; - case LogicalTypeId::BIGINT: - return "int64"; - case LogicalTypeId::HUGEINT: - return "int128"; - case LogicalTypeId::UTINYINT: - return "uint8"; - case LogicalTypeId::USMALLINT: - return "uint16"; - case LogicalTypeId::UINTEGER: - return "uint32"; - case LogicalTypeId::UBIGINT: - return "uint64"; - case LogicalTypeId::LIST: { - auto child = RenderType(ListType::GetChildType(type)); - return child + "[]"; - } - 
default: - return StringUtil::Lower(type.ToString()); - } -} - -ValueRenderAlignment BoxRenderer::TypeAlignment(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::DECIMAL: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - return ValueRenderAlignment::RIGHT; - default: - return ValueRenderAlignment::LEFT; - } -} - -list BoxRenderer::FetchRenderCollections(ClientContext &context, - const ColumnDataCollection &result, idx_t top_rows, - idx_t bottom_rows) { - auto column_count = result.ColumnCount(); - vector varchar_types; - for (idx_t c = 0; c < column_count; c++) { - varchar_types.emplace_back(LogicalType::VARCHAR); - } - std::list collections; - collections.emplace_back(context, varchar_types); - collections.emplace_back(context, varchar_types); - - auto &top_collection = collections.front(); - auto &bottom_collection = collections.back(); - - DataChunk fetch_result; - fetch_result.Initialize(context, result.Types()); - - DataChunk insert_result; - insert_result.Initialize(context, varchar_types); - - // fetch the top rows from the ColumnDataCollection - idx_t chunk_idx = 0; - idx_t row_idx = 0; - while (row_idx < top_rows) { - fetch_result.Reset(); - insert_result.Reset(); - // fetch the next chunk - result.FetchChunk(chunk_idx, fetch_result); - idx_t insert_count = MinValue(fetch_result.size(), top_rows - row_idx); - - // cast all columns to varchar - for (idx_t c = 0; c < column_count; c++) { - VectorOperations::Cast(context, fetch_result.data[c], insert_result.data[c], insert_count); - } - insert_result.SetCardinality(insert_count); - - // construct the render collection - top_collection.Append(insert_result); - - chunk_idx++; - row_idx += 
fetch_result.size(); - } - - // fetch the bottom rows from the ColumnDataCollection - row_idx = 0; - chunk_idx = result.ChunkCount() - 1; - while (row_idx < bottom_rows) { - fetch_result.Reset(); - insert_result.Reset(); - // fetch the next chunk - result.FetchChunk(chunk_idx, fetch_result); - idx_t insert_count = MinValue(fetch_result.size(), bottom_rows - row_idx); - - // invert the rows - SelectionVector inverted_sel(insert_count); - for (idx_t r = 0; r < insert_count; r++) { - inverted_sel.set_index(r, fetch_result.size() - r - 1); - } - - for (idx_t c = 0; c < column_count; c++) { - Vector slice(fetch_result.data[c], inverted_sel, insert_count); - VectorOperations::Cast(context, slice, insert_result.data[c], insert_count); - } - insert_result.SetCardinality(insert_count); - // construct the render collection - bottom_collection.Append(insert_result); - - chunk_idx--; - row_idx += fetch_result.size(); - } - return collections; -} - -list BoxRenderer::PivotCollections(ClientContext &context, list input, - vector &column_names, - vector &result_types, idx_t row_count) { - auto &top = input.front(); - auto &bottom = input.back(); - - vector varchar_types; - vector new_names; - new_names.emplace_back("Column"); - new_names.emplace_back("Type"); - varchar_types.emplace_back(LogicalType::VARCHAR); - varchar_types.emplace_back(LogicalType::VARCHAR); - for (idx_t r = 0; r < top.Count(); r++) { - new_names.emplace_back("Row " + to_string(r + 1)); - varchar_types.emplace_back(LogicalType::VARCHAR); - } - for (idx_t r = 0; r < bottom.Count(); r++) { - auto row_index = row_count - bottom.Count() + r + 1; - new_names.emplace_back("Row " + to_string(row_index)); - varchar_types.emplace_back(LogicalType::VARCHAR); - } - // - DataChunk row_chunk; - row_chunk.Initialize(Allocator::DefaultAllocator(), varchar_types); - std::list result; - result.emplace_back(context, varchar_types); - result.emplace_back(context, varchar_types); - auto &res_coll = result.front(); - 
ColumnDataAppendState append_state; - res_coll.InitializeAppend(append_state); - for (idx_t c = 0; c < top.ColumnCount(); c++) { - vector column_ids {c}; - auto row_index = row_chunk.size(); - idx_t current_index = 0; - row_chunk.SetValue(current_index++, row_index, column_names[c]); - row_chunk.SetValue(current_index++, row_index, RenderType(result_types[c])); - for (auto &collection : input) { - for (auto &chunk : collection.Chunks(column_ids)) { - for (idx_t r = 0; r < chunk.size(); r++) { - row_chunk.SetValue(current_index++, row_index, chunk.GetValue(0, r)); - } - } - } - row_chunk.SetCardinality(row_chunk.size() + 1); - if (row_chunk.size() == STANDARD_VECTOR_SIZE || c + 1 == top.ColumnCount()) { - res_coll.Append(append_state, row_chunk); - row_chunk.Reset(); - } - } - column_names = std::move(new_names); - result_types = std::move(varchar_types); - return result; -} - -string ConvertRenderValue(const string &input) { - return StringUtil::Replace(StringUtil::Replace(input, "\n", "\\n"), string("\0", 1), "\\0"); -} - -string BoxRenderer::GetRenderValue(ColumnDataRowCollection &rows, idx_t c, idx_t r) { - try { - auto row = rows.GetValue(c, r); - if (row.IsNull()) { - return config.null_value; - } - return ConvertRenderValue(StringValue::Get(row)); - } catch (std::exception &ex) { - return "????INVALID VALUE - " + string(ex.what()) + "?????"; - } -} - -vector BoxRenderer::ComputeRenderWidths(const vector &names, const vector &result_types, - list &collections, idx_t min_width, - idx_t max_width, vector &column_map, idx_t &total_length) { - auto column_count = result_types.size(); - - vector widths; - widths.reserve(column_count); - for (idx_t c = 0; c < column_count; c++) { - auto name_width = Utf8Proc::RenderWidth(ConvertRenderValue(names[c])); - auto type_width = Utf8Proc::RenderWidth(RenderType(result_types[c])); - widths.push_back(MaxValue(name_width, type_width)); - } - - // now iterate over the data in the render collection and find out the true max 
width - for (auto &collection : collections) { - for (auto &chunk : collection.Chunks()) { - for (idx_t c = 0; c < column_count; c++) { - auto string_data = FlatVector::GetData(chunk.data[c]); - for (idx_t r = 0; r < chunk.size(); r++) { - string render_value; - if (FlatVector::IsNull(chunk.data[c], r)) { - render_value = config.null_value; - } else { - render_value = ConvertRenderValue(string_data[r].GetString()); - } - auto render_width = Utf8Proc::RenderWidth(render_value); - widths[c] = MaxValue(render_width, widths[c]); - } - } - } - } - - // figure out the total length - // we start off with a pipe (|) - total_length = 1; - for (idx_t c = 0; c < widths.size(); c++) { - // each column has a space at the beginning, and a space plus a pipe (|) at the end - // hence + 3 - total_length += widths[c] + 3; - } - if (total_length < min_width) { - // if there are hidden rows we should always display that - // stretch up the first column until we have space to show the row count - widths[0] += min_width - total_length; - total_length = min_width; - } - // now we need to constrain the length - unordered_set pruned_columns; - if (total_length > max_width) { - // before we remove columns, check if we can just reduce the size of columns - for (auto &w : widths) { - if (w > config.max_col_width) { - auto max_diff = w - config.max_col_width; - if (total_length - max_diff <= max_width) { - // if we reduce the size of this column we fit within the limits! - // reduce the width exactly enough so that the box fits - w -= total_length - max_width; - total_length = max_width; - break; - } else { - // reducing the width of this column does not make the result fit - // reduce the column width by the maximum amount anyway - w = config.max_col_width; - total_length -= max_diff; - } - } - } - - if (total_length > max_width) { - // the total length is still too large - // we need to remove columns! 
- // first, we add 6 characters to the total length - // this is what we need to add the "..." in the middle - total_length += 3 + config.DOTDOTDOT_LENGTH; - // now select columns to prune - // we select columns in zig-zag order starting from the middle - // e.g. if we have 10 columns, we remove #5, then #4, then #6, then #3, then #7, etc - int64_t offset = 0; - while (total_length > max_width) { - idx_t c = column_count / 2 + offset; - total_length -= widths[c] + 3; - pruned_columns.insert(c); - if (offset >= 0) { - offset = -offset - 1; - } else { - offset = -offset; - } - } - } - } - - bool added_split_column = false; - vector new_widths; - for (idx_t c = 0; c < column_count; c++) { - if (pruned_columns.find(c) == pruned_columns.end()) { - column_map.push_back(c); - new_widths.push_back(widths[c]); - } else { - if (!added_split_column) { - // "..." - column_map.push_back(SPLIT_COLUMN); - new_widths.push_back(config.DOTDOTDOT_LENGTH); - added_split_column = true; - } - } - } - return new_widths; -} - -void BoxRenderer::RenderHeader(const vector &names, const vector &result_types, - const vector &column_map, const vector &widths, - const vector &boundaries, idx_t total_length, bool has_results, - std::ostream &ss) { - auto column_count = column_map.size(); - // render the top line - ss << config.LTCORNER; - idx_t column_index = 0; - for (idx_t k = 0; k < total_length - 2; k++) { - if (column_index + 1 < column_count && k == boundaries[column_index]) { - ss << config.TMIDDLE; - column_index++; - } else { - ss << config.HORIZONTAL; - } - } - ss << config.RTCORNER; - ss << std::endl; - - // render the header names - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string name; - if (column_idx == SPLIT_COLUMN) { - name = config.DOTDOTDOT; - } else { - name = ConvertRenderValue(names[column_idx]); - } - RenderValue(ss, name, widths[c]); - } - ss << config.VERTICAL; - ss << std::endl; - - // render the types - if (config.render_mode == 
RenderMode::ROWS) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - auto type = column_idx == SPLIT_COLUMN ? "" : RenderType(result_types[column_idx]); - RenderValue(ss, type, widths[c]); - } - ss << config.VERTICAL; - ss << std::endl; - } - - // render the line under the header - ss << config.LMIDDLE; - column_index = 0; - for (idx_t k = 0; k < total_length - 2; k++) { - if (has_results && column_index + 1 < column_count && k == boundaries[column_index]) { - ss << config.MIDDLE; - column_index++; - } else { - ss << config.HORIZONTAL; - } - } - ss << config.RMIDDLE; - ss << std::endl; -} - -void BoxRenderer::RenderValues(const list &collections, const vector &column_map, - const vector &widths, const vector &result_types, std::ostream &ss) { - auto &top_collection = collections.front(); - auto &bottom_collection = collections.back(); - // render the top rows - auto top_rows = top_collection.Count(); - auto bottom_rows = bottom_collection.Count(); - auto column_count = column_map.size(); - - vector alignments; - if (config.render_mode == RenderMode::ROWS) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - if (column_idx == SPLIT_COLUMN) { - alignments.push_back(ValueRenderAlignment::MIDDLE); - } else { - alignments.push_back(TypeAlignment(result_types[column_idx])); - } - } - } - - auto rows = top_collection.GetRows(); - for (idx_t r = 0; r < top_rows; r++) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string str; - if (column_idx == SPLIT_COLUMN) { - str = config.DOTDOTDOT; - } else { - str = GetRenderValue(rows, column_idx, r); - } - ValueRenderAlignment alignment; - if (config.render_mode == RenderMode::ROWS) { - alignment = alignments[c]; - } else { - if (c < 2) { - alignment = ValueRenderAlignment::LEFT; - } else if (c == SPLIT_COLUMN) { - alignment = ValueRenderAlignment::MIDDLE; - } else { - alignment = ValueRenderAlignment::RIGHT; - } - } - 
RenderValue(ss, str, widths[c], alignment); - } - ss << config.VERTICAL; - ss << std::endl; - } - - if (bottom_rows > 0) { - if (config.render_mode == RenderMode::COLUMNS) { - throw InternalException("Columns render mode does not support bottom rows"); - } - // render the bottom rows - // first render the divider - auto brows = bottom_collection.GetRows(); - for (idx_t k = 0; k < 3; k++) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string str; - auto alignment = alignments[c]; - if (alignment == ValueRenderAlignment::MIDDLE || column_idx == SPLIT_COLUMN) { - str = config.DOT; - } else { - // align the dots in the center of the column - auto top_value = GetRenderValue(rows, column_idx, top_rows - 1); - auto bottom_value = GetRenderValue(brows, column_idx, bottom_rows - 1); - auto top_length = MinValue(widths[c], Utf8Proc::RenderWidth(top_value)); - auto bottom_length = MinValue(widths[c], Utf8Proc::RenderWidth(bottom_value)); - auto dot_length = MinValue(top_length, bottom_length); - if (top_length == 0) { - dot_length = bottom_length; - } else if (bottom_length == 0) { - dot_length = top_length; - } - if (dot_length > 1) { - auto padding = dot_length - 1; - idx_t left_padding, right_padding; - switch (alignment) { - case ValueRenderAlignment::LEFT: - left_padding = padding / 2; - right_padding = padding - left_padding; - break; - case ValueRenderAlignment::RIGHT: - right_padding = padding / 2; - left_padding = padding - right_padding; - break; - default: - throw InternalException("Unrecognized value renderer alignment"); - } - str = string(left_padding, ' ') + config.DOT + string(right_padding, ' '); - } else { - if (dot_length == 0) { - // everything is empty - alignment = ValueRenderAlignment::MIDDLE; - } - str = config.DOT; - } - } - RenderValue(ss, str, widths[c], alignment); - } - ss << config.VERTICAL; - ss << std::endl; - } - // note that the bottom rows are in reverse order - for (idx_t r = 0; r < bottom_rows; r++) { - 
for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string str; - if (column_idx == SPLIT_COLUMN) { - str = config.DOTDOTDOT; - } else { - str = GetRenderValue(brows, column_idx, bottom_rows - r - 1); - } - RenderValue(ss, str, widths[c], alignments[c]); - } - ss << config.VERTICAL; - ss << std::endl; - } - } -} - -void BoxRenderer::RenderRowCount(string row_count_str, string shown_str, const string &column_count_str, - const vector &boundaries, bool has_hidden_rows, bool has_hidden_columns, - idx_t total_length, idx_t row_count, idx_t column_count, idx_t minimum_row_length, - std::ostream &ss) { - // check if we can merge the row_count_str and the shown_str - bool display_shown_separately = has_hidden_rows; - if (has_hidden_rows && total_length >= row_count_str.size() + shown_str.size() + 5) { - // we can! - row_count_str += " " + shown_str; - shown_str = string(); - display_shown_separately = false; - minimum_row_length = row_count_str.size() + 4; - } - auto minimum_length = row_count_str.size() + column_count_str.size() + 6; - bool render_rows_and_columns = total_length >= minimum_length && - ((has_hidden_columns && row_count > 0) || (row_count >= 10 && column_count > 1)); - bool render_rows = total_length >= minimum_row_length && (row_count == 0 || row_count >= 10); - bool render_anything = true; - if (!render_rows && !render_rows_and_columns) { - render_anything = false; - } - // render the bottom of the result values, if there are any - if (row_count > 0) { - ss << (render_anything ? config.LMIDDLE : config.LDCORNER); - idx_t column_index = 0; - for (idx_t k = 0; k < total_length - 2; k++) { - if (column_index + 1 < boundaries.size() && k == boundaries[column_index]) { - ss << config.DMIDDLE; - column_index++; - } else { - ss << config.HORIZONTAL; - } - } - ss << (render_anything ? 
config.RMIDDLE : config.RDCORNER); - ss << std::endl; - } - if (!render_anything) { - return; - } - - if (render_rows_and_columns) { - ss << config.VERTICAL; - ss << " "; - ss << row_count_str; - ss << string(total_length - row_count_str.size() - column_count_str.size() - 4, ' '); - ss << column_count_str; - ss << " "; - ss << config.VERTICAL; - ss << std::endl; - } else if (render_rows) { - RenderValue(ss, row_count_str, total_length - 4); - ss << config.VERTICAL; - ss << std::endl; - - if (display_shown_separately) { - RenderValue(ss, shown_str, total_length - 4); - ss << config.VERTICAL; - ss << std::endl; - } - } - // render the bottom line - ss << config.LDCORNER; - for (idx_t k = 0; k < total_length - 2; k++) { - ss << config.HORIZONTAL; - } - ss << config.RDCORNER; - ss << std::endl; -} - -void BoxRenderer::Render(ClientContext &context, const vector &names, const ColumnDataCollection &result, - std::ostream &ss) { - if (result.ColumnCount() != names.size()) { - throw InternalException("Error in BoxRenderer::Render - unaligned columns and names"); - } - auto max_width = config.max_width; - if (max_width == 0) { - if (Printer::IsTerminal(OutputStream::STREAM_STDOUT)) { - max_width = Printer::TerminalWidth(); - } else { - max_width = 120; - } - } - // we do not support max widths under 80 - max_width = MaxValue(80, max_width); - - // figure out how many/which rows to render - idx_t row_count = result.Count(); - idx_t rows_to_render = MinValue(row_count, config.max_rows); - if (row_count <= config.max_rows + 3) { - // hiding rows adds 3 extra rows - // so hiding rows makes no sense if we are only slightly over the limit - // if we are 1 row over the limit hiding rows will actually increase the number of lines we display! 
- // in this case render all the rows - rows_to_render = row_count; - } - idx_t top_rows; - idx_t bottom_rows; - if (rows_to_render == row_count) { - top_rows = row_count; - bottom_rows = 0; - } else { - top_rows = rows_to_render / 2 + (rows_to_render % 2 != 0 ? 1 : 0); - bottom_rows = rows_to_render - top_rows; - } - auto row_count_str = to_string(row_count) + " rows"; - bool has_limited_rows = config.limit > 0 && row_count == config.limit; - if (has_limited_rows) { - row_count_str = "? rows"; - } - string shown_str; - bool has_hidden_rows = top_rows < row_count; - if (has_hidden_rows) { - shown_str = "("; - if (has_limited_rows) { - shown_str += ">" + to_string(config.limit - 1) + " rows, "; - } - shown_str += to_string(top_rows + bottom_rows) + " shown)"; - } - auto minimum_row_length = MaxValue(row_count_str.size(), shown_str.size()) + 4; - - // fetch the top and bottom render collections from the result - auto collections = FetchRenderCollections(context, result, top_rows, bottom_rows); - auto column_names = names; - auto result_types = result.Types(); - if (config.render_mode == RenderMode::COLUMNS) { - collections = PivotCollections(context, std::move(collections), column_names, result_types, row_count); - } - - // for each column, figure out the width - // start off by figuring out the name of the header by looking at the column name and column type - idx_t min_width = has_hidden_rows || row_count == 0 ? 
minimum_row_length : 0; - vector column_map; - idx_t total_length; - auto widths = - ComputeRenderWidths(column_names, result_types, collections, min_width, max_width, column_map, total_length); - - // render boundaries for the individual columns - vector boundaries; - for (idx_t c = 0; c < widths.size(); c++) { - idx_t render_boundary; - if (c == 0) { - render_boundary = widths[c] + 2; - } else { - render_boundary = boundaries[c - 1] + widths[c] + 3; - } - boundaries.push_back(render_boundary); - } - - // now begin rendering - // first render the header - RenderHeader(column_names, result_types, column_map, widths, boundaries, total_length, row_count > 0, ss); - - // render the values, if there are any - RenderValues(collections, column_map, widths, result_types, ss); - - // render the row count and column count - auto column_count_str = to_string(result.ColumnCount()) + " column"; - if (result.ColumnCount() > 1) { - column_count_str += "s"; - } - bool has_hidden_columns = false; - for (auto entry : column_map) { - if (entry == SPLIT_COLUMN) { - has_hidden_columns = true; - break; - } - } - idx_t column_count = column_map.size(); - if (config.render_mode == RenderMode::COLUMNS) { - if (has_hidden_columns) { - has_hidden_rows = true; - shown_str = " (" + to_string(column_count - 3) + " shown)"; - } else { - shown_str = string(); - } - } else { - if (has_hidden_columns) { - column_count--; - column_count_str += " (" + to_string(column_count) + " shown)"; - } - } - - RenderRowCount(std::move(row_count_str), std::move(shown_str), column_count_str, boundaries, has_hidden_rows, - has_hidden_columns, total_length, row_count, column_count, minimum_row_length, ss); -} - -} // namespace duckdb - - - -namespace duckdb { - -hash_t Checksum(uint64_t x) { - return x * UINT64_C(0xbf58476d1ce4e5b9); -} - -uint64_t Checksum(uint8_t *buffer, size_t size) { - uint64_t result = 5381; - uint64_t *ptr = reinterpret_cast(buffer); - size_t i; - // for efficiency, we first checksum 
uint64_t values - for (i = 0; i < size / 8; i++) { - result ^= Checksum(ptr[i]); - } - if (size - i * 8 > 0) { - // the remaining 0-7 bytes we hash using a string hash - result ^= Hash(buffer + i * 8, size - i * 8); - } - return result; -} - -} // namespace duckdb - - -namespace duckdb { - -StreamWrapper::~StreamWrapper() { -} - -CompressedFile::CompressedFile(CompressedFileSystem &fs, unique_ptr child_handle_p, const string &path) - : FileHandle(fs, path), compressed_fs(fs), child_handle(std::move(child_handle_p)) { -} - -CompressedFile::~CompressedFile() { - CompressedFile::Close(); -} - -void CompressedFile::Initialize(bool write) { - Close(); - - this->write = write; - stream_data.in_buf_size = compressed_fs.InBufferSize(); - stream_data.out_buf_size = compressed_fs.OutBufferSize(); - stream_data.in_buff = make_unsafe_uniq_array(stream_data.in_buf_size); - stream_data.in_buff_start = stream_data.in_buff.get(); - stream_data.in_buff_end = stream_data.in_buff.get(); - stream_data.out_buff = make_unsafe_uniq_array(stream_data.out_buf_size); - stream_data.out_buff_start = stream_data.out_buff.get(); - stream_data.out_buff_end = stream_data.out_buff.get(); - - stream_wrapper = compressed_fs.CreateStream(); - stream_wrapper->Initialize(*this, write); -} - -int64_t CompressedFile::ReadData(void *buffer, int64_t remaining) { - idx_t total_read = 0; - while (true) { - // first check if there are input bytes available in the output buffers - if (stream_data.out_buff_start != stream_data.out_buff_end) { - // there is! copy it into the output buffer - idx_t available = MinValue(remaining, stream_data.out_buff_end - stream_data.out_buff_start); - memcpy(data_ptr_t(buffer) + total_read, stream_data.out_buff_start, available); - - // increment the total read variables as required - stream_data.out_buff_start += available; - total_read += available; - remaining -= available; - if (remaining == 0) { - // done! 
read enough - return total_read; - } - } - if (!stream_wrapper) { - return total_read; - } - - // ran out of buffer: read more data from the child stream - stream_data.out_buff_start = stream_data.out_buff.get(); - stream_data.out_buff_end = stream_data.out_buff.get(); - D_ASSERT(stream_data.in_buff_start <= stream_data.in_buff_end); - D_ASSERT(stream_data.in_buff_end <= stream_data.in_buff_start + stream_data.in_buf_size); - - // read more input when requested and still data in the input stream - if (stream_data.refresh && (stream_data.in_buff_end == stream_data.in_buff.get() + stream_data.in_buf_size)) { - auto bufrem = stream_data.in_buff_end - stream_data.in_buff_start; - // buffer not empty, move remaining bytes to the beginning - memmove(stream_data.in_buff.get(), stream_data.in_buff_start, bufrem); - stream_data.in_buff_start = stream_data.in_buff.get(); - // refill the rest of input buffer - auto sz = child_handle->Read(stream_data.in_buff_start + bufrem, stream_data.in_buf_size - bufrem); - stream_data.in_buff_end = stream_data.in_buff_start + bufrem + sz; - if (sz <= 0) { - stream_wrapper.reset(); - break; - } - } - - // read more input if none available - if (stream_data.in_buff_start == stream_data.in_buff_end) { - // empty input buffer: refill from the start - stream_data.in_buff_start = stream_data.in_buff.get(); - stream_data.in_buff_end = stream_data.in_buff_start; - auto sz = child_handle->Read(stream_data.in_buff.get(), stream_data.in_buf_size); - if (sz <= 0) { - stream_wrapper.reset(); - break; - } - stream_data.in_buff_end = stream_data.in_buff_start + sz; - } - - auto finished = stream_wrapper->Read(stream_data); - if (finished) { - stream_wrapper.reset(); - } - } - return total_read; -} - -int64_t CompressedFile::WriteData(data_ptr_t buffer, int64_t nr_bytes) { - stream_wrapper->Write(*this, stream_data, buffer, nr_bytes); - return nr_bytes; -} - -void CompressedFile::Close() { - if (stream_wrapper) { - stream_wrapper->Close(); - 
stream_wrapper.reset(); - } - stream_data.in_buff.reset(); - stream_data.out_buff.reset(); - stream_data.out_buff_start = nullptr; - stream_data.out_buff_end = nullptr; - stream_data.in_buff_start = nullptr; - stream_data.in_buff_end = nullptr; - stream_data.in_buf_size = 0; - stream_data.out_buf_size = 0; -} - -int64_t CompressedFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - auto &compressed_file = handle.Cast(); - return compressed_file.ReadData(buffer, nr_bytes); -} - -int64_t CompressedFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { - auto &compressed_file = handle.Cast(); - return compressed_file.WriteData(data_ptr_cast(buffer), nr_bytes); -} - -void CompressedFileSystem::Reset(FileHandle &handle) { - auto &compressed_file = handle.Cast(); - compressed_file.child_handle->Reset(); - compressed_file.Initialize(compressed_file.write); -} - -int64_t CompressedFileSystem::GetFileSize(FileHandle &handle) { - auto &compressed_file = handle.Cast(); - return compressed_file.child_handle->GetFileSize(); -} - -bool CompressedFileSystem::OnDiskFile(FileHandle &handle) { - auto &compressed_file = handle.Cast(); - return compressed_file.child_handle->OnDiskFile(); -} - -bool CompressedFileSystem::CanSeek() { - return false; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -constexpr const idx_t DConstants::INVALID_INDEX; -const row_t MAX_ROW_ID = 36028797018960000ULL; // 2^55 -const row_t MAX_ROW_ID_LOCAL = 72057594037920000ULL; // 2^56 -const column_t COLUMN_IDENTIFIER_ROW_ID = (column_t)-1; -const sel_t ZERO_VECTOR[STANDARD_VECTOR_SIZE] = {0}; -const double PI = 3.141592653589793; - -const transaction_t TRANSACTION_ID_START = 4611686018427388000ULL; // 2^62 -const transaction_t MAX_TRANSACTION_ID = NumericLimits::Maximum(); // 2^63 -const transaction_t NOT_DELETED_ID = NumericLimits::Maximum() - 1; // 2^64 - 1 -const transaction_t MAXIMUM_QUERY_ID = NumericLimits::Maximum(); // 2^64 - -bool IsPowerOfTwo(uint64_t 
v) { - return (v & (v - 1)) == 0; -} - -uint64_t NextPowerOfTwo(uint64_t v) { - v--; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - v |= v >> 32; - v++; - return v; -} - -uint64_t PreviousPowerOfTwo(uint64_t v) { - return NextPowerOfTwo((v / 2) + 1); -} - -bool IsInvalidSchema(const string &str) { - return str.empty(); -} - -bool IsInvalidCatalog(const string &str) { - return str.empty(); -} - -bool IsRowIdColumnId(column_t column_id) { - return column_id == COLUMN_IDENTIFIER_ROW_ID; -} - -} // namespace duckdb -/* -** This code taken from the SQLite test library. Originally found on -** the internet. The original header comment follows this comment. -** The code is largerly unchanged, but there have been some modifications. -*/ -/* - * This code implements the MD5 message-digest algorithm. - * The algorithm is due to Ron Rivest. This code was - * written by Colin Plumb in 1993, no copyright is claimed. - * This code is in the public domain; do with it what you wish. - * - * Equivalent code is available from RSA Data Security, Inc. - * This code has been tested against that, and is equivalent, - * except that you don't need to include two pages of legalese - * with every copy. - * - * To compute the message digest of a chunk of bytes, declare an - * MD5Context structure, pass it to MD5Init, call MD5Update as - * needed on buffers full of bytes, and then call MD5Final, which - * will fill a supplied 16-byte array with the digest. - */ - - - -namespace duckdb { - -/* - * Note: this code is harmless on little-endian machines. 
- */ -static void ByteReverse(unsigned char *buf, unsigned longs) { - uint32_t t; - do { - t = (uint32_t)((unsigned)buf[3] << 8 | buf[2]) << 16 | ((unsigned)buf[1] << 8 | buf[0]); - *reinterpret_cast(buf) = t; - buf += 4; - } while (--longs); -} -/* The four core functions - F1 is optimized somewhat */ - -/* #define F1(x, y, z) (x & y | ~x & z) */ -#define F1(x, y, z) ((z) ^ ((x) & ((y) ^ (z)))) -#define F2(x, y, z) F1(z, x, y) -#define F3(x, y, z) ((x) ^ (y) ^ (z)) -#define F4(x, y, z) ((y) ^ ((x) | ~(z))) - -/* This is the central step in the MD5 algorithm. */ -#define MD5STEP(f, w, x, y, z, data, s) ((w) += f(x, y, z) + (data), (w) = (w) << (s) | (w) >> (32 - (s)), (w) += (x)) - -/* - * The core of the MD5 algorithm, this alters an existing MD5 hash to - * reflect the addition of 16 longwords of new data. MD5Update blocks - * the data and converts bytes into longwords for this routine. - */ -static void MD5Transform(uint32_t buf[4], const uint32_t in[16]) { - uint32_t a, b, c, d; - - a = buf[0]; - b = buf[1]; - c = buf[2]; - d = buf[3]; - - MD5STEP(F1, a, b, c, d, in[0] + 0xd76aa478, 7); - MD5STEP(F1, d, a, b, c, in[1] + 0xe8c7b756, 12); - MD5STEP(F1, c, d, a, b, in[2] + 0x242070db, 17); - MD5STEP(F1, b, c, d, a, in[3] + 0xc1bdceee, 22); - MD5STEP(F1, a, b, c, d, in[4] + 0xf57c0faf, 7); - MD5STEP(F1, d, a, b, c, in[5] + 0x4787c62a, 12); - MD5STEP(F1, c, d, a, b, in[6] + 0xa8304613, 17); - MD5STEP(F1, b, c, d, a, in[7] + 0xfd469501, 22); - MD5STEP(F1, a, b, c, d, in[8] + 0x698098d8, 7); - MD5STEP(F1, d, a, b, c, in[9] + 0x8b44f7af, 12); - MD5STEP(F1, c, d, a, b, in[10] + 0xffff5bb1, 17); - MD5STEP(F1, b, c, d, a, in[11] + 0x895cd7be, 22); - MD5STEP(F1, a, b, c, d, in[12] + 0x6b901122, 7); - MD5STEP(F1, d, a, b, c, in[13] + 0xfd987193, 12); - MD5STEP(F1, c, d, a, b, in[14] + 0xa679438e, 17); - MD5STEP(F1, b, c, d, a, in[15] + 0x49b40821, 22); - - MD5STEP(F2, a, b, c, d, in[1] + 0xf61e2562, 5); - MD5STEP(F2, d, a, b, c, in[6] + 0xc040b340, 9); - MD5STEP(F2, c, d, 
a, b, in[11] + 0x265e5a51, 14); - MD5STEP(F2, b, c, d, a, in[0] + 0xe9b6c7aa, 20); - MD5STEP(F2, a, b, c, d, in[5] + 0xd62f105d, 5); - MD5STEP(F2, d, a, b, c, in[10] + 0x02441453, 9); - MD5STEP(F2, c, d, a, b, in[15] + 0xd8a1e681, 14); - MD5STEP(F2, b, c, d, a, in[4] + 0xe7d3fbc8, 20); - MD5STEP(F2, a, b, c, d, in[9] + 0x21e1cde6, 5); - MD5STEP(F2, d, a, b, c, in[14] + 0xc33707d6, 9); - MD5STEP(F2, c, d, a, b, in[3] + 0xf4d50d87, 14); - MD5STEP(F2, b, c, d, a, in[8] + 0x455a14ed, 20); - MD5STEP(F2, a, b, c, d, in[13] + 0xa9e3e905, 5); - MD5STEP(F2, d, a, b, c, in[2] + 0xfcefa3f8, 9); - MD5STEP(F2, c, d, a, b, in[7] + 0x676f02d9, 14); - MD5STEP(F2, b, c, d, a, in[12] + 0x8d2a4c8a, 20); - - MD5STEP(F3, a, b, c, d, in[5] + 0xfffa3942, 4); - MD5STEP(F3, d, a, b, c, in[8] + 0x8771f681, 11); - MD5STEP(F3, c, d, a, b, in[11] + 0x6d9d6122, 16); - MD5STEP(F3, b, c, d, a, in[14] + 0xfde5380c, 23); - MD5STEP(F3, a, b, c, d, in[1] + 0xa4beea44, 4); - MD5STEP(F3, d, a, b, c, in[4] + 0x4bdecfa9, 11); - MD5STEP(F3, c, d, a, b, in[7] + 0xf6bb4b60, 16); - MD5STEP(F3, b, c, d, a, in[10] + 0xbebfbc70, 23); - MD5STEP(F3, a, b, c, d, in[13] + 0x289b7ec6, 4); - MD5STEP(F3, d, a, b, c, in[0] + 0xeaa127fa, 11); - MD5STEP(F3, c, d, a, b, in[3] + 0xd4ef3085, 16); - MD5STEP(F3, b, c, d, a, in[6] + 0x04881d05, 23); - MD5STEP(F3, a, b, c, d, in[9] + 0xd9d4d039, 4); - MD5STEP(F3, d, a, b, c, in[12] + 0xe6db99e5, 11); - MD5STEP(F3, c, d, a, b, in[15] + 0x1fa27cf8, 16); - MD5STEP(F3, b, c, d, a, in[2] + 0xc4ac5665, 23); - - MD5STEP(F4, a, b, c, d, in[0] + 0xf4292244, 6); - MD5STEP(F4, d, a, b, c, in[7] + 0x432aff97, 10); - MD5STEP(F4, c, d, a, b, in[14] + 0xab9423a7, 15); - MD5STEP(F4, b, c, d, a, in[5] + 0xfc93a039, 21); - MD5STEP(F4, a, b, c, d, in[12] + 0x655b59c3, 6); - MD5STEP(F4, d, a, b, c, in[3] + 0x8f0ccc92, 10); - MD5STEP(F4, c, d, a, b, in[10] + 0xffeff47d, 15); - MD5STEP(F4, b, c, d, a, in[1] + 0x85845dd1, 21); - MD5STEP(F4, a, b, c, d, in[8] + 0x6fa87e4f, 6); - MD5STEP(F4, d, a, b, 
c, in[15] + 0xfe2ce6e0, 10); - MD5STEP(F4, c, d, a, b, in[6] + 0xa3014314, 15); - MD5STEP(F4, b, c, d, a, in[13] + 0x4e0811a1, 21); - MD5STEP(F4, a, b, c, d, in[4] + 0xf7537e82, 6); - MD5STEP(F4, d, a, b, c, in[11] + 0xbd3af235, 10); - MD5STEP(F4, c, d, a, b, in[2] + 0x2ad7d2bb, 15); - MD5STEP(F4, b, c, d, a, in[9] + 0xeb86d391, 21); - - buf[0] += a; - buf[1] += b; - buf[2] += c; - buf[3] += d; -} - -/* - * Start MD5 accumulation. Set bit count to 0 and buffer to mysterious - * initialization constants. - */ -MD5Context::MD5Context() { - buf[0] = 0x67452301; - buf[1] = 0xefcdab89; - buf[2] = 0x98badcfe; - buf[3] = 0x10325476; - bits[0] = 0; - bits[1] = 0; -} - -/* - * Update context to reflect the concatenation of another buffer full - * of bytes. - */ -void MD5Context::MD5Update(const_data_ptr_t input, idx_t len) { - uint32_t t; - - /* Update bitcount */ - - t = bits[0]; - if ((bits[0] = t + ((uint32_t)len << 3)) < t) { - bits[1]++; /* Carry from low to high */ - } - bits[1] += len >> 29; - - t = (t >> 3) & 0x3f; /* Bytes already in shsInfo->data */ - - /* Handle any leading odd-sized chunks */ - - if (t) { - unsigned char *p = (unsigned char *)in + t; - - t = 64 - t; - if (len < t) { - memcpy(p, input, len); - return; - } - memcpy(p, input, t); - ByteReverse(in, 16); - MD5Transform(buf, reinterpret_cast(in)); - input += t; - len -= t; - } - - /* Process data in 64-byte chunks */ - - while (len >= 64) { - memcpy(in, input, 64); - ByteReverse(in, 16); - MD5Transform(buf, reinterpret_cast(in)); - input += 64; - len -= 64; - } - - /* Handle any remaining bytes of data. */ - memcpy(in, input, len); -} - -/* - * Final wrapup - pad to 64-byte boundary with the bit pattern - * 1 0* (64-bit count of bits processed, MSB-first) - */ -void MD5Context::Finish(data_ptr_t out_digest) { - unsigned count; - unsigned char *p; - - /* Compute number of bytes mod 64 */ - count = (bits[0] >> 3) & 0x3F; - - /* Set the first char of padding to 0x80. 
This is safe since there is - always at least one byte free */ - p = in + count; - *p++ = 0x80; - - /* Bytes of padding needed to make 64 bytes */ - count = 64 - 1 - count; - - /* Pad out to 56 mod 64 */ - if (count < 8) { - /* Two lots of padding: Pad the first block to 64 bytes */ - memset(p, 0, count); - ByteReverse(in, 16); - MD5Transform(buf, reinterpret_cast(in)); - - /* Now fill the next block with 56 bytes */ - memset(in, 0, 56); - } else { - /* Pad block to 56 bytes */ - memset(p, 0, count - 8); - } - ByteReverse(in, 14); - - /* Append length in bits and transform */ - (reinterpret_cast(in))[14] = bits[0]; - (reinterpret_cast(in))[15] = bits[1]; - - MD5Transform(buf, reinterpret_cast(in)); - ByteReverse(reinterpret_cast(buf), 4); - memcpy(out_digest, buf, 16); -} - -void MD5Context::FinishHex(char *out_digest) { - data_t digest[MD5_HASH_LENGTH_BINARY]; - Finish(digest); - duckdb_mbedtls::MbedTlsWrapper::ToBase16(reinterpret_cast(digest), out_digest, MD5_HASH_LENGTH_BINARY); -} - -string MD5Context::FinishHex() { - char digest[MD5_HASH_LENGTH_TEXT]; - FinishHex(digest); - return string(digest, MD5_HASH_LENGTH_TEXT); -} - -void MD5Context::Add(const char *data) { - MD5Update(const_data_ptr_cast(data), strlen(data)); -} - -} // namespace duckdb -// This file is licensed under Apache License 2.0 -// Source code taken from https://github.com/google/benchmark -// It is highly modified - - - - -namespace duckdb { - -inline uint64_t ChronoNow() { - return std::chrono::duration_cast( - std::chrono::time_point_cast(std::chrono::high_resolution_clock::now()) - .time_since_epoch()) - .count(); -} - -inline uint64_t Now() { -#if defined(RDTSC) -#if defined(__i386__) - uint64_t ret; - __asm__ volatile("rdtsc" : "=A"(ret)); - return ret; -#elif defined(__x86_64__) || defined(__amd64__) - uint64_t low, high; - __asm__ volatile("rdtsc" : "=a"(low), "=d"(high)); - return (high << 32) | low; -#elif defined(__powerpc__) || defined(__ppc__) - uint64_t tbl, tbu0, tbu1; - 
asm("mftbu %0" : "=r"(tbu0)); - asm("mftb %0" : "=r"(tbl)); - asm("mftbu %0" : "=r"(tbu1)); - tbl &= -static_cast(tbu0 == tbu1); - return (tbu1 << 32) | tbl; -#elif defined(__sparc__) - uint64_t tick; - asm(".byte 0x83, 0x41, 0x00, 0x00"); - asm("mov %%g1, %0" : "=r"(tick)); - return tick; -#elif defined(__ia64__) - uint64_t itc; - asm("mov %0 = ar.itc" : "=r"(itc)); - return itc; -#elif defined(COMPILER_MSVC) && defined(_M_IX86) - _asm rdtsc -#elif defined(COMPILER_MSVC) - return __rdtsc(); -#elif defined(__aarch64__) - uint64_t virtual_timer_value; - asm volatile("mrs %0, cntvct_el0" : "=r"(virtual_timer_value)); - return virtual_timer_value; -#elif defined(__ARM_ARCH) -#if (__ARM_ARCH >= 6) - uint32_t pmccntr; - uint32_t pmuseren; - uint32_t pmcntenset; - asm volatile("mrc p15, 0, %0, c9, c14, 0" : "=r"(pmuseren)); - if (pmuseren & 1) { // Allows reading perfmon counters for user mode code. - asm volatile("mrc p15, 0, %0, c9, c12, 1" : "=r"(pmcntenset)); - if (pmcntenset & 0x80000000ul) { // Is it counting? - asm volatile("mrc p15, 0, %0, c9, c13, 0" : "=r"(pmccntr)); - return static_cast(pmccntr) * 64; // Should optimize to << 6 - } - } -#endif - return ChronoNow(); -#else - return ChronoNow(); -#endif -#else - return ChronoNow(); -#endif // defined(RDTSC) -} -uint64_t CycleCounter::Tick() const { - return Now(); -} -} // namespace duckdb -//------------------------------------------------------------------------- -// This file is automatically generated by scripts/generate_enum_util.py -// Do not edit this file manually, your changes will be overwritten -// If you want to exclude an enum from serialization, add it to the blacklist in the script -// -// Note: The generated code will only work properly if the enum is a top level item in the duckdb namespace -// If the enum is nested in a class, or in another namespace, the generated code will not compile. 
-// You should move the enum to the duckdb namespace, manually write a specialization or add it to the blacklist -//------------------------------------------------------------------------- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -template<> -const char* EnumUtil::ToChars(AccessMode value) { - switch(value) { - case AccessMode::UNDEFINED: - return "UNDEFINED"; - case AccessMode::AUTOMATIC: - return "AUTOMATIC"; - case AccessMode::READ_ONLY: - return "READ_ONLY"; - case AccessMode::READ_WRITE: - return "READ_WRITE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -AccessMode EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "UNDEFINED")) { - return AccessMode::UNDEFINED; - } - if (StringUtil::Equals(value, "AUTOMATIC")) { - return AccessMode::AUTOMATIC; - } - if (StringUtil::Equals(value, "READ_ONLY")) { - return AccessMode::READ_ONLY; - } - if (StringUtil::Equals(value, "READ_WRITE")) { - return AccessMode::READ_WRITE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(AggregateHandling value) { - switch(value) { - case AggregateHandling::STANDARD_HANDLING: - return "STANDARD_HANDLING"; - case AggregateHandling::NO_AGGREGATES_ALLOWED: - return "NO_AGGREGATES_ALLOWED"; - case AggregateHandling::FORCE_AGGREGATES: - return "FORCE_AGGREGATES"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -AggregateHandling EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "STANDARD_HANDLING")) { - return AggregateHandling::STANDARD_HANDLING; - } - if (StringUtil::Equals(value, 
"NO_AGGREGATES_ALLOWED")) { - return AggregateHandling::NO_AGGREGATES_ALLOWED; - } - if (StringUtil::Equals(value, "FORCE_AGGREGATES")) { - return AggregateHandling::FORCE_AGGREGATES; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(AggregateOrderDependent value) { - switch(value) { - case AggregateOrderDependent::ORDER_DEPENDENT: - return "ORDER_DEPENDENT"; - case AggregateOrderDependent::NOT_ORDER_DEPENDENT: - return "NOT_ORDER_DEPENDENT"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -AggregateOrderDependent EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "ORDER_DEPENDENT")) { - return AggregateOrderDependent::ORDER_DEPENDENT; - } - if (StringUtil::Equals(value, "NOT_ORDER_DEPENDENT")) { - return AggregateOrderDependent::NOT_ORDER_DEPENDENT; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(AggregateType value) { - switch(value) { - case AggregateType::NON_DISTINCT: - return "NON_DISTINCT"; - case AggregateType::DISTINCT: - return "DISTINCT"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -AggregateType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NON_DISTINCT")) { - return AggregateType::NON_DISTINCT; - } - if (StringUtil::Equals(value, "DISTINCT")) { - return AggregateType::DISTINCT; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(AlterForeignKeyType value) { - switch(value) { - case AlterForeignKeyType::AFT_ADD: - return "AFT_ADD"; - case AlterForeignKeyType::AFT_DELETE: - return "AFT_DELETE"; - default: - throw 
NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -AlterForeignKeyType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "AFT_ADD")) { - return AlterForeignKeyType::AFT_ADD; - } - if (StringUtil::Equals(value, "AFT_DELETE")) { - return AlterForeignKeyType::AFT_DELETE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(AlterScalarFunctionType value) { - switch(value) { - case AlterScalarFunctionType::INVALID: - return "INVALID"; - case AlterScalarFunctionType::ADD_FUNCTION_OVERLOADS: - return "ADD_FUNCTION_OVERLOADS"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -AlterScalarFunctionType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return AlterScalarFunctionType::INVALID; - } - if (StringUtil::Equals(value, "ADD_FUNCTION_OVERLOADS")) { - return AlterScalarFunctionType::ADD_FUNCTION_OVERLOADS; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(AlterTableFunctionType value) { - switch(value) { - case AlterTableFunctionType::INVALID: - return "INVALID"; - case AlterTableFunctionType::ADD_FUNCTION_OVERLOADS: - return "ADD_FUNCTION_OVERLOADS"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -AlterTableFunctionType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return AlterTableFunctionType::INVALID; - } - if (StringUtil::Equals(value, "ADD_FUNCTION_OVERLOADS")) { - return AlterTableFunctionType::ADD_FUNCTION_OVERLOADS; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* 
EnumUtil::ToChars(AlterTableType value) { - switch(value) { - case AlterTableType::INVALID: - return "INVALID"; - case AlterTableType::RENAME_COLUMN: - return "RENAME_COLUMN"; - case AlterTableType::RENAME_TABLE: - return "RENAME_TABLE"; - case AlterTableType::ADD_COLUMN: - return "ADD_COLUMN"; - case AlterTableType::REMOVE_COLUMN: - return "REMOVE_COLUMN"; - case AlterTableType::ALTER_COLUMN_TYPE: - return "ALTER_COLUMN_TYPE"; - case AlterTableType::SET_DEFAULT: - return "SET_DEFAULT"; - case AlterTableType::FOREIGN_KEY_CONSTRAINT: - return "FOREIGN_KEY_CONSTRAINT"; - case AlterTableType::SET_NOT_NULL: - return "SET_NOT_NULL"; - case AlterTableType::DROP_NOT_NULL: - return "DROP_NOT_NULL"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -AlterTableType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return AlterTableType::INVALID; - } - if (StringUtil::Equals(value, "RENAME_COLUMN")) { - return AlterTableType::RENAME_COLUMN; - } - if (StringUtil::Equals(value, "RENAME_TABLE")) { - return AlterTableType::RENAME_TABLE; - } - if (StringUtil::Equals(value, "ADD_COLUMN")) { - return AlterTableType::ADD_COLUMN; - } - if (StringUtil::Equals(value, "REMOVE_COLUMN")) { - return AlterTableType::REMOVE_COLUMN; - } - if (StringUtil::Equals(value, "ALTER_COLUMN_TYPE")) { - return AlterTableType::ALTER_COLUMN_TYPE; - } - if (StringUtil::Equals(value, "SET_DEFAULT")) { - return AlterTableType::SET_DEFAULT; - } - if (StringUtil::Equals(value, "FOREIGN_KEY_CONSTRAINT")) { - return AlterTableType::FOREIGN_KEY_CONSTRAINT; - } - if (StringUtil::Equals(value, "SET_NOT_NULL")) { - return AlterTableType::SET_NOT_NULL; - } - if (StringUtil::Equals(value, "DROP_NOT_NULL")) { - return AlterTableType::DROP_NOT_NULL; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(AlterType 
value) { - switch(value) { - case AlterType::INVALID: - return "INVALID"; - case AlterType::ALTER_TABLE: - return "ALTER_TABLE"; - case AlterType::ALTER_VIEW: - return "ALTER_VIEW"; - case AlterType::ALTER_SEQUENCE: - return "ALTER_SEQUENCE"; - case AlterType::CHANGE_OWNERSHIP: - return "CHANGE_OWNERSHIP"; - case AlterType::ALTER_SCALAR_FUNCTION: - return "ALTER_SCALAR_FUNCTION"; - case AlterType::ALTER_TABLE_FUNCTION: - return "ALTER_TABLE_FUNCTION"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -AlterType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return AlterType::INVALID; - } - if (StringUtil::Equals(value, "ALTER_TABLE")) { - return AlterType::ALTER_TABLE; - } - if (StringUtil::Equals(value, "ALTER_VIEW")) { - return AlterType::ALTER_VIEW; - } - if (StringUtil::Equals(value, "ALTER_SEQUENCE")) { - return AlterType::ALTER_SEQUENCE; - } - if (StringUtil::Equals(value, "CHANGE_OWNERSHIP")) { - return AlterType::CHANGE_OWNERSHIP; - } - if (StringUtil::Equals(value, "ALTER_SCALAR_FUNCTION")) { - return AlterType::ALTER_SCALAR_FUNCTION; - } - if (StringUtil::Equals(value, "ALTER_TABLE_FUNCTION")) { - return AlterType::ALTER_TABLE_FUNCTION; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(AlterViewType value) { - switch(value) { - case AlterViewType::INVALID: - return "INVALID"; - case AlterViewType::RENAME_VIEW: - return "RENAME_VIEW"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -AlterViewType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return AlterViewType::INVALID; - } - if (StringUtil::Equals(value, "RENAME_VIEW")) { - return AlterViewType::RENAME_VIEW; - } - throw 
NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(AppenderType value) { - switch(value) { - case AppenderType::LOGICAL: - return "LOGICAL"; - case AppenderType::PHYSICAL: - return "PHYSICAL"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -AppenderType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "LOGICAL")) { - return AppenderType::LOGICAL; - } - if (StringUtil::Equals(value, "PHYSICAL")) { - return AppenderType::PHYSICAL; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ArrowDateTimeType value) { - switch(value) { - case ArrowDateTimeType::MILLISECONDS: - return "MILLISECONDS"; - case ArrowDateTimeType::MICROSECONDS: - return "MICROSECONDS"; - case ArrowDateTimeType::NANOSECONDS: - return "NANOSECONDS"; - case ArrowDateTimeType::SECONDS: - return "SECONDS"; - case ArrowDateTimeType::DAYS: - return "DAYS"; - case ArrowDateTimeType::MONTHS: - return "MONTHS"; - case ArrowDateTimeType::MONTH_DAY_NANO: - return "MONTH_DAY_NANO"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ArrowDateTimeType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "MILLISECONDS")) { - return ArrowDateTimeType::MILLISECONDS; - } - if (StringUtil::Equals(value, "MICROSECONDS")) { - return ArrowDateTimeType::MICROSECONDS; - } - if (StringUtil::Equals(value, "NANOSECONDS")) { - return ArrowDateTimeType::NANOSECONDS; - } - if (StringUtil::Equals(value, "SECONDS")) { - return ArrowDateTimeType::SECONDS; - } - if (StringUtil::Equals(value, "DAYS")) { - return ArrowDateTimeType::DAYS; - } - if (StringUtil::Equals(value, "MONTHS")) { - return ArrowDateTimeType::MONTHS; - } - if (StringUtil::Equals(value, 
"MONTH_DAY_NANO")) { - return ArrowDateTimeType::MONTH_DAY_NANO; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ArrowVariableSizeType value) { - switch(value) { - case ArrowVariableSizeType::FIXED_SIZE: - return "FIXED_SIZE"; - case ArrowVariableSizeType::NORMAL: - return "NORMAL"; - case ArrowVariableSizeType::SUPER_SIZE: - return "SUPER_SIZE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ArrowVariableSizeType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "FIXED_SIZE")) { - return ArrowVariableSizeType::FIXED_SIZE; - } - if (StringUtil::Equals(value, "NORMAL")) { - return ArrowVariableSizeType::NORMAL; - } - if (StringUtil::Equals(value, "SUPER_SIZE")) { - return ArrowVariableSizeType::SUPER_SIZE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(BindingMode value) { - switch(value) { - case BindingMode::STANDARD_BINDING: - return "STANDARD_BINDING"; - case BindingMode::EXTRACT_NAMES: - return "EXTRACT_NAMES"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -BindingMode EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "STANDARD_BINDING")) { - return BindingMode::STANDARD_BINDING; - } - if (StringUtil::Equals(value, "EXTRACT_NAMES")) { - return BindingMode::EXTRACT_NAMES; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(BitpackingMode value) { - switch(value) { - case BitpackingMode::INVALID: - return "INVALID"; - case BitpackingMode::AUTO: - return "AUTO"; - case BitpackingMode::CONSTANT: - return "CONSTANT"; - case BitpackingMode::CONSTANT_DELTA: - return 
"CONSTANT_DELTA"; - case BitpackingMode::DELTA_FOR: - return "DELTA_FOR"; - case BitpackingMode::FOR: - return "FOR"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -BitpackingMode EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return BitpackingMode::INVALID; - } - if (StringUtil::Equals(value, "AUTO")) { - return BitpackingMode::AUTO; - } - if (StringUtil::Equals(value, "CONSTANT")) { - return BitpackingMode::CONSTANT; - } - if (StringUtil::Equals(value, "CONSTANT_DELTA")) { - return BitpackingMode::CONSTANT_DELTA; - } - if (StringUtil::Equals(value, "DELTA_FOR")) { - return BitpackingMode::DELTA_FOR; - } - if (StringUtil::Equals(value, "FOR")) { - return BitpackingMode::FOR; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(BlockState value) { - switch(value) { - case BlockState::BLOCK_UNLOADED: - return "BLOCK_UNLOADED"; - case BlockState::BLOCK_LOADED: - return "BLOCK_LOADED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -BlockState EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "BLOCK_UNLOADED")) { - return BlockState::BLOCK_UNLOADED; - } - if (StringUtil::Equals(value, "BLOCK_LOADED")) { - return BlockState::BLOCK_LOADED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(CAPIResultSetType value) { - switch(value) { - case CAPIResultSetType::CAPI_RESULT_TYPE_NONE: - return "CAPI_RESULT_TYPE_NONE"; - case CAPIResultSetType::CAPI_RESULT_TYPE_MATERIALIZED: - return "CAPI_RESULT_TYPE_MATERIALIZED"; - case CAPIResultSetType::CAPI_RESULT_TYPE_STREAMING: - return "CAPI_RESULT_TYPE_STREAMING"; - case 
CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED: - return "CAPI_RESULT_TYPE_DEPRECATED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -CAPIResultSetType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "CAPI_RESULT_TYPE_NONE")) { - return CAPIResultSetType::CAPI_RESULT_TYPE_NONE; - } - if (StringUtil::Equals(value, "CAPI_RESULT_TYPE_MATERIALIZED")) { - return CAPIResultSetType::CAPI_RESULT_TYPE_MATERIALIZED; - } - if (StringUtil::Equals(value, "CAPI_RESULT_TYPE_STREAMING")) { - return CAPIResultSetType::CAPI_RESULT_TYPE_STREAMING; - } - if (StringUtil::Equals(value, "CAPI_RESULT_TYPE_DEPRECATED")) { - return CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(CSVState value) { - switch(value) { - case CSVState::STANDARD: - return "STANDARD"; - case CSVState::DELIMITER: - return "DELIMITER"; - case CSVState::RECORD_SEPARATOR: - return "RECORD_SEPARATOR"; - case CSVState::CARRIAGE_RETURN: - return "CARRIAGE_RETURN"; - case CSVState::QUOTED: - return "QUOTED"; - case CSVState::UNQUOTED: - return "UNQUOTED"; - case CSVState::ESCAPE: - return "ESCAPE"; - case CSVState::EMPTY_LINE: - return "EMPTY_LINE"; - case CSVState::INVALID: - return "INVALID"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -CSVState EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "STANDARD")) { - return CSVState::STANDARD; - } - if (StringUtil::Equals(value, "DELIMITER")) { - return CSVState::DELIMITER; - } - if (StringUtil::Equals(value, "RECORD_SEPARATOR")) { - return CSVState::RECORD_SEPARATOR; - } - if (StringUtil::Equals(value, "CARRIAGE_RETURN")) { - return CSVState::CARRIAGE_RETURN; - } - if (StringUtil::Equals(value, "QUOTED")) { - return 
CSVState::QUOTED; - } - if (StringUtil::Equals(value, "UNQUOTED")) { - return CSVState::UNQUOTED; - } - if (StringUtil::Equals(value, "ESCAPE")) { - return CSVState::ESCAPE; - } - if (StringUtil::Equals(value, "EMPTY_LINE")) { - return CSVState::EMPTY_LINE; - } - if (StringUtil::Equals(value, "INVALID")) { - return CSVState::INVALID; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(CTEMaterialize value) { - switch(value) { - case CTEMaterialize::CTE_MATERIALIZE_DEFAULT: - return "CTE_MATERIALIZE_DEFAULT"; - case CTEMaterialize::CTE_MATERIALIZE_ALWAYS: - return "CTE_MATERIALIZE_ALWAYS"; - case CTEMaterialize::CTE_MATERIALIZE_NEVER: - return "CTE_MATERIALIZE_NEVER"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -CTEMaterialize EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "CTE_MATERIALIZE_DEFAULT")) { - return CTEMaterialize::CTE_MATERIALIZE_DEFAULT; - } - if (StringUtil::Equals(value, "CTE_MATERIALIZE_ALWAYS")) { - return CTEMaterialize::CTE_MATERIALIZE_ALWAYS; - } - if (StringUtil::Equals(value, "CTE_MATERIALIZE_NEVER")) { - return CTEMaterialize::CTE_MATERIALIZE_NEVER; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(CatalogType value) { - switch(value) { - case CatalogType::INVALID: - return "INVALID"; - case CatalogType::TABLE_ENTRY: - return "TABLE_ENTRY"; - case CatalogType::SCHEMA_ENTRY: - return "SCHEMA_ENTRY"; - case CatalogType::VIEW_ENTRY: - return "VIEW_ENTRY"; - case CatalogType::INDEX_ENTRY: - return "INDEX_ENTRY"; - case CatalogType::PREPARED_STATEMENT: - return "PREPARED_STATEMENT"; - case CatalogType::SEQUENCE_ENTRY: - return "SEQUENCE_ENTRY"; - case CatalogType::COLLATION_ENTRY: - return "COLLATION_ENTRY"; - case CatalogType::TYPE_ENTRY: - return 
"TYPE_ENTRY"; - case CatalogType::DATABASE_ENTRY: - return "DATABASE_ENTRY"; - case CatalogType::TABLE_FUNCTION_ENTRY: - return "TABLE_FUNCTION_ENTRY"; - case CatalogType::SCALAR_FUNCTION_ENTRY: - return "SCALAR_FUNCTION_ENTRY"; - case CatalogType::AGGREGATE_FUNCTION_ENTRY: - return "AGGREGATE_FUNCTION_ENTRY"; - case CatalogType::PRAGMA_FUNCTION_ENTRY: - return "PRAGMA_FUNCTION_ENTRY"; - case CatalogType::COPY_FUNCTION_ENTRY: - return "COPY_FUNCTION_ENTRY"; - case CatalogType::MACRO_ENTRY: - return "MACRO_ENTRY"; - case CatalogType::TABLE_MACRO_ENTRY: - return "TABLE_MACRO_ENTRY"; - case CatalogType::UPDATED_ENTRY: - return "UPDATED_ENTRY"; - case CatalogType::DELETED_ENTRY: - return "DELETED_ENTRY"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -CatalogType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return CatalogType::INVALID; - } - if (StringUtil::Equals(value, "TABLE_ENTRY")) { - return CatalogType::TABLE_ENTRY; - } - if (StringUtil::Equals(value, "SCHEMA_ENTRY")) { - return CatalogType::SCHEMA_ENTRY; - } - if (StringUtil::Equals(value, "VIEW_ENTRY")) { - return CatalogType::VIEW_ENTRY; - } - if (StringUtil::Equals(value, "INDEX_ENTRY")) { - return CatalogType::INDEX_ENTRY; - } - if (StringUtil::Equals(value, "PREPARED_STATEMENT")) { - return CatalogType::PREPARED_STATEMENT; - } - if (StringUtil::Equals(value, "SEQUENCE_ENTRY")) { - return CatalogType::SEQUENCE_ENTRY; - } - if (StringUtil::Equals(value, "COLLATION_ENTRY")) { - return CatalogType::COLLATION_ENTRY; - } - if (StringUtil::Equals(value, "TYPE_ENTRY")) { - return CatalogType::TYPE_ENTRY; - } - if (StringUtil::Equals(value, "DATABASE_ENTRY")) { - return CatalogType::DATABASE_ENTRY; - } - if (StringUtil::Equals(value, "TABLE_FUNCTION_ENTRY")) { - return CatalogType::TABLE_FUNCTION_ENTRY; - } - if (StringUtil::Equals(value, "SCALAR_FUNCTION_ENTRY")) { - return 
CatalogType::SCALAR_FUNCTION_ENTRY; - } - if (StringUtil::Equals(value, "AGGREGATE_FUNCTION_ENTRY")) { - return CatalogType::AGGREGATE_FUNCTION_ENTRY; - } - if (StringUtil::Equals(value, "PRAGMA_FUNCTION_ENTRY")) { - return CatalogType::PRAGMA_FUNCTION_ENTRY; - } - if (StringUtil::Equals(value, "COPY_FUNCTION_ENTRY")) { - return CatalogType::COPY_FUNCTION_ENTRY; - } - if (StringUtil::Equals(value, "MACRO_ENTRY")) { - return CatalogType::MACRO_ENTRY; - } - if (StringUtil::Equals(value, "TABLE_MACRO_ENTRY")) { - return CatalogType::TABLE_MACRO_ENTRY; - } - if (StringUtil::Equals(value, "UPDATED_ENTRY")) { - return CatalogType::UPDATED_ENTRY; - } - if (StringUtil::Equals(value, "DELETED_ENTRY")) { - return CatalogType::DELETED_ENTRY; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(CheckpointAbort value) { - switch(value) { - case CheckpointAbort::NO_ABORT: - return "NO_ABORT"; - case CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE: - return "DEBUG_ABORT_BEFORE_TRUNCATE"; - case CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER: - return "DEBUG_ABORT_BEFORE_HEADER"; - case CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE: - return "DEBUG_ABORT_AFTER_FREE_LIST_WRITE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -CheckpointAbort EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NO_ABORT")) { - return CheckpointAbort::NO_ABORT; - } - if (StringUtil::Equals(value, "DEBUG_ABORT_BEFORE_TRUNCATE")) { - return CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE; - } - if (StringUtil::Equals(value, "DEBUG_ABORT_BEFORE_HEADER")) { - return CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER; - } - if (StringUtil::Equals(value, "DEBUG_ABORT_AFTER_FREE_LIST_WRITE")) { - return CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' 
not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ChunkInfoType value) { - switch(value) { - case ChunkInfoType::CONSTANT_INFO: - return "CONSTANT_INFO"; - case ChunkInfoType::VECTOR_INFO: - return "VECTOR_INFO"; - case ChunkInfoType::EMPTY_INFO: - return "EMPTY_INFO"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ChunkInfoType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "CONSTANT_INFO")) { - return ChunkInfoType::CONSTANT_INFO; - } - if (StringUtil::Equals(value, "VECTOR_INFO")) { - return ChunkInfoType::VECTOR_INFO; - } - if (StringUtil::Equals(value, "EMPTY_INFO")) { - return ChunkInfoType::EMPTY_INFO; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ColumnDataAllocatorType value) { - switch(value) { - case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: - return "BUFFER_MANAGER_ALLOCATOR"; - case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: - return "IN_MEMORY_ALLOCATOR"; - case ColumnDataAllocatorType::HYBRID: - return "HYBRID"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ColumnDataAllocatorType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "BUFFER_MANAGER_ALLOCATOR")) { - return ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR; - } - if (StringUtil::Equals(value, "IN_MEMORY_ALLOCATOR")) { - return ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR; - } - if (StringUtil::Equals(value, "HYBRID")) { - return ColumnDataAllocatorType::HYBRID; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ColumnDataScanProperties value) { - switch(value) { - case ColumnDataScanProperties::INVALID: - return "INVALID"; - case 
ColumnDataScanProperties::ALLOW_ZERO_COPY: - return "ALLOW_ZERO_COPY"; - case ColumnDataScanProperties::DISALLOW_ZERO_COPY: - return "DISALLOW_ZERO_COPY"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ColumnDataScanProperties EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return ColumnDataScanProperties::INVALID; - } - if (StringUtil::Equals(value, "ALLOW_ZERO_COPY")) { - return ColumnDataScanProperties::ALLOW_ZERO_COPY; - } - if (StringUtil::Equals(value, "DISALLOW_ZERO_COPY")) { - return ColumnDataScanProperties::DISALLOW_ZERO_COPY; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ColumnSegmentType value) { - switch(value) { - case ColumnSegmentType::TRANSIENT: - return "TRANSIENT"; - case ColumnSegmentType::PERSISTENT: - return "PERSISTENT"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ColumnSegmentType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "TRANSIENT")) { - return ColumnSegmentType::TRANSIENT; - } - if (StringUtil::Equals(value, "PERSISTENT")) { - return ColumnSegmentType::PERSISTENT; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(CompressedMaterializationDirection value) { - switch(value) { - case CompressedMaterializationDirection::INVALID: - return "INVALID"; - case CompressedMaterializationDirection::COMPRESS: - return "COMPRESS"; - case CompressedMaterializationDirection::DECOMPRESS: - return "DECOMPRESS"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -CompressedMaterializationDirection EnumUtil::FromString(const char *value) { - if 
(StringUtil::Equals(value, "INVALID")) { - return CompressedMaterializationDirection::INVALID; - } - if (StringUtil::Equals(value, "COMPRESS")) { - return CompressedMaterializationDirection::COMPRESS; - } - if (StringUtil::Equals(value, "DECOMPRESS")) { - return CompressedMaterializationDirection::DECOMPRESS; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(CompressionType value) { - switch(value) { - case CompressionType::COMPRESSION_AUTO: - return "COMPRESSION_AUTO"; - case CompressionType::COMPRESSION_UNCOMPRESSED: - return "COMPRESSION_UNCOMPRESSED"; - case CompressionType::COMPRESSION_CONSTANT: - return "COMPRESSION_CONSTANT"; - case CompressionType::COMPRESSION_RLE: - return "COMPRESSION_RLE"; - case CompressionType::COMPRESSION_DICTIONARY: - return "COMPRESSION_DICTIONARY"; - case CompressionType::COMPRESSION_PFOR_DELTA: - return "COMPRESSION_PFOR_DELTA"; - case CompressionType::COMPRESSION_BITPACKING: - return "COMPRESSION_BITPACKING"; - case CompressionType::COMPRESSION_FSST: - return "COMPRESSION_FSST"; - case CompressionType::COMPRESSION_CHIMP: - return "COMPRESSION_CHIMP"; - case CompressionType::COMPRESSION_PATAS: - return "COMPRESSION_PATAS"; - case CompressionType::COMPRESSION_COUNT: - return "COMPRESSION_COUNT"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -CompressionType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "COMPRESSION_AUTO")) { - return CompressionType::COMPRESSION_AUTO; - } - if (StringUtil::Equals(value, "COMPRESSION_UNCOMPRESSED")) { - return CompressionType::COMPRESSION_UNCOMPRESSED; - } - if (StringUtil::Equals(value, "COMPRESSION_CONSTANT")) { - return CompressionType::COMPRESSION_CONSTANT; - } - if (StringUtil::Equals(value, "COMPRESSION_RLE")) { - return CompressionType::COMPRESSION_RLE; - } - if (StringUtil::Equals(value, 
"COMPRESSION_DICTIONARY")) { - return CompressionType::COMPRESSION_DICTIONARY; - } - if (StringUtil::Equals(value, "COMPRESSION_PFOR_DELTA")) { - return CompressionType::COMPRESSION_PFOR_DELTA; - } - if (StringUtil::Equals(value, "COMPRESSION_BITPACKING")) { - return CompressionType::COMPRESSION_BITPACKING; - } - if (StringUtil::Equals(value, "COMPRESSION_FSST")) { - return CompressionType::COMPRESSION_FSST; - } - if (StringUtil::Equals(value, "COMPRESSION_CHIMP")) { - return CompressionType::COMPRESSION_CHIMP; - } - if (StringUtil::Equals(value, "COMPRESSION_PATAS")) { - return CompressionType::COMPRESSION_PATAS; - } - if (StringUtil::Equals(value, "COMPRESSION_COUNT")) { - return CompressionType::COMPRESSION_COUNT; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ConflictManagerMode value) { - switch(value) { - case ConflictManagerMode::SCAN: - return "SCAN"; - case ConflictManagerMode::THROW: - return "THROW"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ConflictManagerMode EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "SCAN")) { - return ConflictManagerMode::SCAN; - } - if (StringUtil::Equals(value, "THROW")) { - return ConflictManagerMode::THROW; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ConstraintType value) { - switch(value) { - case ConstraintType::INVALID: - return "INVALID"; - case ConstraintType::NOT_NULL: - return "NOT_NULL"; - case ConstraintType::CHECK: - return "CHECK"; - case ConstraintType::UNIQUE: - return "UNIQUE"; - case ConstraintType::FOREIGN_KEY: - return "FOREIGN_KEY"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ConstraintType 
EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return ConstraintType::INVALID; - } - if (StringUtil::Equals(value, "NOT_NULL")) { - return ConstraintType::NOT_NULL; - } - if (StringUtil::Equals(value, "CHECK")) { - return ConstraintType::CHECK; - } - if (StringUtil::Equals(value, "UNIQUE")) { - return ConstraintType::UNIQUE; - } - if (StringUtil::Equals(value, "FOREIGN_KEY")) { - return ConstraintType::FOREIGN_KEY; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(DataFileType value) { - switch(value) { - case DataFileType::FILE_DOES_NOT_EXIST: - return "FILE_DOES_NOT_EXIST"; - case DataFileType::DUCKDB_FILE: - return "DUCKDB_FILE"; - case DataFileType::SQLITE_FILE: - return "SQLITE_FILE"; - case DataFileType::PARQUET_FILE: - return "PARQUET_FILE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -DataFileType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "FILE_DOES_NOT_EXIST")) { - return DataFileType::FILE_DOES_NOT_EXIST; - } - if (StringUtil::Equals(value, "DUCKDB_FILE")) { - return DataFileType::DUCKDB_FILE; - } - if (StringUtil::Equals(value, "SQLITE_FILE")) { - return DataFileType::SQLITE_FILE; - } - if (StringUtil::Equals(value, "PARQUET_FILE")) { - return DataFileType::PARQUET_FILE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(DatePartSpecifier value) { - switch(value) { - case DatePartSpecifier::YEAR: - return "YEAR"; - case DatePartSpecifier::MONTH: - return "MONTH"; - case DatePartSpecifier::DAY: - return "DAY"; - case DatePartSpecifier::DECADE: - return "DECADE"; - case DatePartSpecifier::CENTURY: - return "CENTURY"; - case DatePartSpecifier::MILLENNIUM: - return "MILLENNIUM"; - case 
DatePartSpecifier::MICROSECONDS: - return "MICROSECONDS"; - case DatePartSpecifier::MILLISECONDS: - return "MILLISECONDS"; - case DatePartSpecifier::SECOND: - return "SECOND"; - case DatePartSpecifier::MINUTE: - return "MINUTE"; - case DatePartSpecifier::HOUR: - return "HOUR"; - case DatePartSpecifier::DOW: - return "DOW"; - case DatePartSpecifier::ISODOW: - return "ISODOW"; - case DatePartSpecifier::WEEK: - return "WEEK"; - case DatePartSpecifier::ISOYEAR: - return "ISOYEAR"; - case DatePartSpecifier::QUARTER: - return "QUARTER"; - case DatePartSpecifier::DOY: - return "DOY"; - case DatePartSpecifier::YEARWEEK: - return "YEARWEEK"; - case DatePartSpecifier::ERA: - return "ERA"; - case DatePartSpecifier::TIMEZONE: - return "TIMEZONE"; - case DatePartSpecifier::TIMEZONE_HOUR: - return "TIMEZONE_HOUR"; - case DatePartSpecifier::TIMEZONE_MINUTE: - return "TIMEZONE_MINUTE"; - case DatePartSpecifier::EPOCH: - return "EPOCH"; - case DatePartSpecifier::JULIAN_DAY: - return "JULIAN_DAY"; - case DatePartSpecifier::INVALID: - return "INVALID"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -DatePartSpecifier EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "YEAR")) { - return DatePartSpecifier::YEAR; - } - if (StringUtil::Equals(value, "MONTH")) { - return DatePartSpecifier::MONTH; - } - if (StringUtil::Equals(value, "DAY")) { - return DatePartSpecifier::DAY; - } - if (StringUtil::Equals(value, "DECADE")) { - return DatePartSpecifier::DECADE; - } - if (StringUtil::Equals(value, "CENTURY")) { - return DatePartSpecifier::CENTURY; - } - if (StringUtil::Equals(value, "MILLENNIUM")) { - return DatePartSpecifier::MILLENNIUM; - } - if (StringUtil::Equals(value, "MICROSECONDS")) { - return DatePartSpecifier::MICROSECONDS; - } - if (StringUtil::Equals(value, "MILLISECONDS")) { - return DatePartSpecifier::MILLISECONDS; - } - if (StringUtil::Equals(value, "SECOND")) { - return 
DatePartSpecifier::SECOND; - } - if (StringUtil::Equals(value, "MINUTE")) { - return DatePartSpecifier::MINUTE; - } - if (StringUtil::Equals(value, "HOUR")) { - return DatePartSpecifier::HOUR; - } - if (StringUtil::Equals(value, "DOW")) { - return DatePartSpecifier::DOW; - } - if (StringUtil::Equals(value, "ISODOW")) { - return DatePartSpecifier::ISODOW; - } - if (StringUtil::Equals(value, "WEEK")) { - return DatePartSpecifier::WEEK; - } - if (StringUtil::Equals(value, "ISOYEAR")) { - return DatePartSpecifier::ISOYEAR; - } - if (StringUtil::Equals(value, "QUARTER")) { - return DatePartSpecifier::QUARTER; - } - if (StringUtil::Equals(value, "DOY")) { - return DatePartSpecifier::DOY; - } - if (StringUtil::Equals(value, "YEARWEEK")) { - return DatePartSpecifier::YEARWEEK; - } - if (StringUtil::Equals(value, "ERA")) { - return DatePartSpecifier::ERA; - } - if (StringUtil::Equals(value, "TIMEZONE")) { - return DatePartSpecifier::TIMEZONE; - } - if (StringUtil::Equals(value, "TIMEZONE_HOUR")) { - return DatePartSpecifier::TIMEZONE_HOUR; - } - if (StringUtil::Equals(value, "TIMEZONE_MINUTE")) { - return DatePartSpecifier::TIMEZONE_MINUTE; - } - if (StringUtil::Equals(value, "EPOCH")) { - return DatePartSpecifier::EPOCH; - } - if (StringUtil::Equals(value, "JULIAN_DAY")) { - return DatePartSpecifier::JULIAN_DAY; - } - if (StringUtil::Equals(value, "INVALID")) { - return DatePartSpecifier::INVALID; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(DebugInitialize value) { - switch(value) { - case DebugInitialize::NO_INITIALIZE: - return "NO_INITIALIZE"; - case DebugInitialize::DEBUG_ZERO_INITIALIZE: - return "DEBUG_ZERO_INITIALIZE"; - case DebugInitialize::DEBUG_ONE_INITIALIZE: - return "DEBUG_ONE_INITIALIZE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -DebugInitialize 
EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NO_INITIALIZE")) { - return DebugInitialize::NO_INITIALIZE; - } - if (StringUtil::Equals(value, "DEBUG_ZERO_INITIALIZE")) { - return DebugInitialize::DEBUG_ZERO_INITIALIZE; - } - if (StringUtil::Equals(value, "DEBUG_ONE_INITIALIZE")) { - return DebugInitialize::DEBUG_ONE_INITIALIZE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(DefaultOrderByNullType value) { - switch(value) { - case DefaultOrderByNullType::INVALID: - return "INVALID"; - case DefaultOrderByNullType::NULLS_FIRST: - return "NULLS_FIRST"; - case DefaultOrderByNullType::NULLS_LAST: - return "NULLS_LAST"; - case DefaultOrderByNullType::NULLS_FIRST_ON_ASC_LAST_ON_DESC: - return "NULLS_FIRST_ON_ASC_LAST_ON_DESC"; - case DefaultOrderByNullType::NULLS_LAST_ON_ASC_FIRST_ON_DESC: - return "NULLS_LAST_ON_ASC_FIRST_ON_DESC"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -DefaultOrderByNullType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return DefaultOrderByNullType::INVALID; - } - if (StringUtil::Equals(value, "NULLS_FIRST")) { - return DefaultOrderByNullType::NULLS_FIRST; - } - if (StringUtil::Equals(value, "NULLS_LAST")) { - return DefaultOrderByNullType::NULLS_LAST; - } - if (StringUtil::Equals(value, "NULLS_FIRST_ON_ASC_LAST_ON_DESC")) { - return DefaultOrderByNullType::NULLS_FIRST_ON_ASC_LAST_ON_DESC; - } - if (StringUtil::Equals(value, "NULLS_LAST_ON_ASC_FIRST_ON_DESC")) { - return DefaultOrderByNullType::NULLS_LAST_ON_ASC_FIRST_ON_DESC; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(DistinctType value) { - switch(value) { - case DistinctType::DISTINCT: - return "DISTINCT"; - case 
DistinctType::DISTINCT_ON: - return "DISTINCT_ON"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -DistinctType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "DISTINCT")) { - return DistinctType::DISTINCT; - } - if (StringUtil::Equals(value, "DISTINCT_ON")) { - return DistinctType::DISTINCT_ON; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ErrorType value) { - switch(value) { - case ErrorType::UNSIGNED_EXTENSION: - return "UNSIGNED_EXTENSION"; - case ErrorType::INVALIDATED_TRANSACTION: - return "INVALIDATED_TRANSACTION"; - case ErrorType::INVALIDATED_DATABASE: - return "INVALIDATED_DATABASE"; - case ErrorType::ERROR_COUNT: - return "ERROR_COUNT"; - case ErrorType::INVALID: - return "INVALID"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ErrorType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "UNSIGNED_EXTENSION")) { - return ErrorType::UNSIGNED_EXTENSION; - } - if (StringUtil::Equals(value, "INVALIDATED_TRANSACTION")) { - return ErrorType::INVALIDATED_TRANSACTION; - } - if (StringUtil::Equals(value, "INVALIDATED_DATABASE")) { - return ErrorType::INVALIDATED_DATABASE; - } - if (StringUtil::Equals(value, "ERROR_COUNT")) { - return ErrorType::ERROR_COUNT; - } - if (StringUtil::Equals(value, "INVALID")) { - return ErrorType::INVALID; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ExceptionFormatValueType value) { - switch(value) { - case ExceptionFormatValueType::FORMAT_VALUE_TYPE_DOUBLE: - return "FORMAT_VALUE_TYPE_DOUBLE"; - case ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER: - return "FORMAT_VALUE_TYPE_INTEGER"; - case 
ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING: - return "FORMAT_VALUE_TYPE_STRING"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ExceptionFormatValueType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "FORMAT_VALUE_TYPE_DOUBLE")) { - return ExceptionFormatValueType::FORMAT_VALUE_TYPE_DOUBLE; - } - if (StringUtil::Equals(value, "FORMAT_VALUE_TYPE_INTEGER")) { - return ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER; - } - if (StringUtil::Equals(value, "FORMAT_VALUE_TYPE_STRING")) { - return ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ExplainOutputType value) { - switch(value) { - case ExplainOutputType::ALL: - return "ALL"; - case ExplainOutputType::OPTIMIZED_ONLY: - return "OPTIMIZED_ONLY"; - case ExplainOutputType::PHYSICAL_ONLY: - return "PHYSICAL_ONLY"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ExplainOutputType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "ALL")) { - return ExplainOutputType::ALL; - } - if (StringUtil::Equals(value, "OPTIMIZED_ONLY")) { - return ExplainOutputType::OPTIMIZED_ONLY; - } - if (StringUtil::Equals(value, "PHYSICAL_ONLY")) { - return ExplainOutputType::PHYSICAL_ONLY; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ExplainType value) { - switch(value) { - case ExplainType::EXPLAIN_STANDARD: - return "EXPLAIN_STANDARD"; - case ExplainType::EXPLAIN_ANALYZE: - return "EXPLAIN_ANALYZE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ExplainType EnumUtil::FromString(const char 
*value) { - if (StringUtil::Equals(value, "EXPLAIN_STANDARD")) { - return ExplainType::EXPLAIN_STANDARD; - } - if (StringUtil::Equals(value, "EXPLAIN_ANALYZE")) { - return ExplainType::EXPLAIN_ANALYZE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ExpressionClass value) { - switch(value) { - case ExpressionClass::INVALID: - return "INVALID"; - case ExpressionClass::AGGREGATE: - return "AGGREGATE"; - case ExpressionClass::CASE: - return "CASE"; - case ExpressionClass::CAST: - return "CAST"; - case ExpressionClass::COLUMN_REF: - return "COLUMN_REF"; - case ExpressionClass::COMPARISON: - return "COMPARISON"; - case ExpressionClass::CONJUNCTION: - return "CONJUNCTION"; - case ExpressionClass::CONSTANT: - return "CONSTANT"; - case ExpressionClass::DEFAULT: - return "DEFAULT"; - case ExpressionClass::FUNCTION: - return "FUNCTION"; - case ExpressionClass::OPERATOR: - return "OPERATOR"; - case ExpressionClass::STAR: - return "STAR"; - case ExpressionClass::SUBQUERY: - return "SUBQUERY"; - case ExpressionClass::WINDOW: - return "WINDOW"; - case ExpressionClass::PARAMETER: - return "PARAMETER"; - case ExpressionClass::COLLATE: - return "COLLATE"; - case ExpressionClass::LAMBDA: - return "LAMBDA"; - case ExpressionClass::POSITIONAL_REFERENCE: - return "POSITIONAL_REFERENCE"; - case ExpressionClass::BETWEEN: - return "BETWEEN"; - case ExpressionClass::BOUND_AGGREGATE: - return "BOUND_AGGREGATE"; - case ExpressionClass::BOUND_CASE: - return "BOUND_CASE"; - case ExpressionClass::BOUND_CAST: - return "BOUND_CAST"; - case ExpressionClass::BOUND_COLUMN_REF: - return "BOUND_COLUMN_REF"; - case ExpressionClass::BOUND_COMPARISON: - return "BOUND_COMPARISON"; - case ExpressionClass::BOUND_CONJUNCTION: - return "BOUND_CONJUNCTION"; - case ExpressionClass::BOUND_CONSTANT: - return "BOUND_CONSTANT"; - case ExpressionClass::BOUND_DEFAULT: - return "BOUND_DEFAULT"; - case 
ExpressionClass::BOUND_FUNCTION: - return "BOUND_FUNCTION"; - case ExpressionClass::BOUND_OPERATOR: - return "BOUND_OPERATOR"; - case ExpressionClass::BOUND_PARAMETER: - return "BOUND_PARAMETER"; - case ExpressionClass::BOUND_REF: - return "BOUND_REF"; - case ExpressionClass::BOUND_SUBQUERY: - return "BOUND_SUBQUERY"; - case ExpressionClass::BOUND_WINDOW: - return "BOUND_WINDOW"; - case ExpressionClass::BOUND_BETWEEN: - return "BOUND_BETWEEN"; - case ExpressionClass::BOUND_UNNEST: - return "BOUND_UNNEST"; - case ExpressionClass::BOUND_LAMBDA: - return "BOUND_LAMBDA"; - case ExpressionClass::BOUND_LAMBDA_REF: - return "BOUND_LAMBDA_REF"; - case ExpressionClass::BOUND_EXPRESSION: - return "BOUND_EXPRESSION"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ExpressionClass EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return ExpressionClass::INVALID; - } - if (StringUtil::Equals(value, "AGGREGATE")) { - return ExpressionClass::AGGREGATE; - } - if (StringUtil::Equals(value, "CASE")) { - return ExpressionClass::CASE; - } - if (StringUtil::Equals(value, "CAST")) { - return ExpressionClass::CAST; - } - if (StringUtil::Equals(value, "COLUMN_REF")) { - return ExpressionClass::COLUMN_REF; - } - if (StringUtil::Equals(value, "COMPARISON")) { - return ExpressionClass::COMPARISON; - } - if (StringUtil::Equals(value, "CONJUNCTION")) { - return ExpressionClass::CONJUNCTION; - } - if (StringUtil::Equals(value, "CONSTANT")) { - return ExpressionClass::CONSTANT; - } - if (StringUtil::Equals(value, "DEFAULT")) { - return ExpressionClass::DEFAULT; - } - if (StringUtil::Equals(value, "FUNCTION")) { - return ExpressionClass::FUNCTION; - } - if (StringUtil::Equals(value, "OPERATOR")) { - return ExpressionClass::OPERATOR; - } - if (StringUtil::Equals(value, "STAR")) { - return ExpressionClass::STAR; - } - if (StringUtil::Equals(value, "SUBQUERY")) { - return 
ExpressionClass::SUBQUERY; - } - if (StringUtil::Equals(value, "WINDOW")) { - return ExpressionClass::WINDOW; - } - if (StringUtil::Equals(value, "PARAMETER")) { - return ExpressionClass::PARAMETER; - } - if (StringUtil::Equals(value, "COLLATE")) { - return ExpressionClass::COLLATE; - } - if (StringUtil::Equals(value, "LAMBDA")) { - return ExpressionClass::LAMBDA; - } - if (StringUtil::Equals(value, "POSITIONAL_REFERENCE")) { - return ExpressionClass::POSITIONAL_REFERENCE; - } - if (StringUtil::Equals(value, "BETWEEN")) { - return ExpressionClass::BETWEEN; - } - if (StringUtil::Equals(value, "BOUND_AGGREGATE")) { - return ExpressionClass::BOUND_AGGREGATE; - } - if (StringUtil::Equals(value, "BOUND_CASE")) { - return ExpressionClass::BOUND_CASE; - } - if (StringUtil::Equals(value, "BOUND_CAST")) { - return ExpressionClass::BOUND_CAST; - } - if (StringUtil::Equals(value, "BOUND_COLUMN_REF")) { - return ExpressionClass::BOUND_COLUMN_REF; - } - if (StringUtil::Equals(value, "BOUND_COMPARISON")) { - return ExpressionClass::BOUND_COMPARISON; - } - if (StringUtil::Equals(value, "BOUND_CONJUNCTION")) { - return ExpressionClass::BOUND_CONJUNCTION; - } - if (StringUtil::Equals(value, "BOUND_CONSTANT")) { - return ExpressionClass::BOUND_CONSTANT; - } - if (StringUtil::Equals(value, "BOUND_DEFAULT")) { - return ExpressionClass::BOUND_DEFAULT; - } - if (StringUtil::Equals(value, "BOUND_FUNCTION")) { - return ExpressionClass::BOUND_FUNCTION; - } - if (StringUtil::Equals(value, "BOUND_OPERATOR")) { - return ExpressionClass::BOUND_OPERATOR; - } - if (StringUtil::Equals(value, "BOUND_PARAMETER")) { - return ExpressionClass::BOUND_PARAMETER; - } - if (StringUtil::Equals(value, "BOUND_REF")) { - return ExpressionClass::BOUND_REF; - } - if (StringUtil::Equals(value, "BOUND_SUBQUERY")) { - return ExpressionClass::BOUND_SUBQUERY; - } - if (StringUtil::Equals(value, "BOUND_WINDOW")) { - return ExpressionClass::BOUND_WINDOW; - } - if (StringUtil::Equals(value, "BOUND_BETWEEN")) { - return 
ExpressionClass::BOUND_BETWEEN; - } - if (StringUtil::Equals(value, "BOUND_UNNEST")) { - return ExpressionClass::BOUND_UNNEST; - } - if (StringUtil::Equals(value, "BOUND_LAMBDA")) { - return ExpressionClass::BOUND_LAMBDA; - } - if (StringUtil::Equals(value, "BOUND_LAMBDA_REF")) { - return ExpressionClass::BOUND_LAMBDA_REF; - } - if (StringUtil::Equals(value, "BOUND_EXPRESSION")) { - return ExpressionClass::BOUND_EXPRESSION; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ExpressionType value) { - switch(value) { - case ExpressionType::INVALID: - return "INVALID"; - case ExpressionType::OPERATOR_CAST: - return "OPERATOR_CAST"; - case ExpressionType::OPERATOR_NOT: - return "OPERATOR_NOT"; - case ExpressionType::OPERATOR_IS_NULL: - return "OPERATOR_IS_NULL"; - case ExpressionType::OPERATOR_IS_NOT_NULL: - return "OPERATOR_IS_NOT_NULL"; - case ExpressionType::COMPARE_EQUAL: - return "COMPARE_EQUAL"; - case ExpressionType::COMPARE_NOTEQUAL: - return "COMPARE_NOTEQUAL"; - case ExpressionType::COMPARE_LESSTHAN: - return "COMPARE_LESSTHAN"; - case ExpressionType::COMPARE_GREATERTHAN: - return "COMPARE_GREATERTHAN"; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return "COMPARE_LESSTHANOREQUALTO"; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return "COMPARE_GREATERTHANOREQUALTO"; - case ExpressionType::COMPARE_IN: - return "COMPARE_IN"; - case ExpressionType::COMPARE_NOT_IN: - return "COMPARE_NOT_IN"; - case ExpressionType::COMPARE_DISTINCT_FROM: - return "COMPARE_DISTINCT_FROM"; - case ExpressionType::COMPARE_BETWEEN: - return "COMPARE_BETWEEN"; - case ExpressionType::COMPARE_NOT_BETWEEN: - return "COMPARE_NOT_BETWEEN"; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - return "COMPARE_NOT_DISTINCT_FROM"; - case ExpressionType::CONJUNCTION_AND: - return "CONJUNCTION_AND"; - case ExpressionType::CONJUNCTION_OR: - return "CONJUNCTION_OR"; - case 
ExpressionType::VALUE_CONSTANT: - return "VALUE_CONSTANT"; - case ExpressionType::VALUE_PARAMETER: - return "VALUE_PARAMETER"; - case ExpressionType::VALUE_TUPLE: - return "VALUE_TUPLE"; - case ExpressionType::VALUE_TUPLE_ADDRESS: - return "VALUE_TUPLE_ADDRESS"; - case ExpressionType::VALUE_NULL: - return "VALUE_NULL"; - case ExpressionType::VALUE_VECTOR: - return "VALUE_VECTOR"; - case ExpressionType::VALUE_SCALAR: - return "VALUE_SCALAR"; - case ExpressionType::VALUE_DEFAULT: - return "VALUE_DEFAULT"; - case ExpressionType::AGGREGATE: - return "AGGREGATE"; - case ExpressionType::BOUND_AGGREGATE: - return "BOUND_AGGREGATE"; - case ExpressionType::GROUPING_FUNCTION: - return "GROUPING_FUNCTION"; - case ExpressionType::WINDOW_AGGREGATE: - return "WINDOW_AGGREGATE"; - case ExpressionType::WINDOW_RANK: - return "WINDOW_RANK"; - case ExpressionType::WINDOW_RANK_DENSE: - return "WINDOW_RANK_DENSE"; - case ExpressionType::WINDOW_NTILE: - return "WINDOW_NTILE"; - case ExpressionType::WINDOW_PERCENT_RANK: - return "WINDOW_PERCENT_RANK"; - case ExpressionType::WINDOW_CUME_DIST: - return "WINDOW_CUME_DIST"; - case ExpressionType::WINDOW_ROW_NUMBER: - return "WINDOW_ROW_NUMBER"; - case ExpressionType::WINDOW_FIRST_VALUE: - return "WINDOW_FIRST_VALUE"; - case ExpressionType::WINDOW_LAST_VALUE: - return "WINDOW_LAST_VALUE"; - case ExpressionType::WINDOW_LEAD: - return "WINDOW_LEAD"; - case ExpressionType::WINDOW_LAG: - return "WINDOW_LAG"; - case ExpressionType::WINDOW_NTH_VALUE: - return "WINDOW_NTH_VALUE"; - case ExpressionType::FUNCTION: - return "FUNCTION"; - case ExpressionType::BOUND_FUNCTION: - return "BOUND_FUNCTION"; - case ExpressionType::CASE_EXPR: - return "CASE_EXPR"; - case ExpressionType::OPERATOR_NULLIF: - return "OPERATOR_NULLIF"; - case ExpressionType::OPERATOR_COALESCE: - return "OPERATOR_COALESCE"; - case ExpressionType::ARRAY_EXTRACT: - return "ARRAY_EXTRACT"; - case ExpressionType::ARRAY_SLICE: - return "ARRAY_SLICE"; - case ExpressionType::STRUCT_EXTRACT: 
- return "STRUCT_EXTRACT"; - case ExpressionType::ARRAY_CONSTRUCTOR: - return "ARRAY_CONSTRUCTOR"; - case ExpressionType::ARROW: - return "ARROW"; - case ExpressionType::SUBQUERY: - return "SUBQUERY"; - case ExpressionType::STAR: - return "STAR"; - case ExpressionType::TABLE_STAR: - return "TABLE_STAR"; - case ExpressionType::PLACEHOLDER: - return "PLACEHOLDER"; - case ExpressionType::COLUMN_REF: - return "COLUMN_REF"; - case ExpressionType::FUNCTION_REF: - return "FUNCTION_REF"; - case ExpressionType::TABLE_REF: - return "TABLE_REF"; - case ExpressionType::CAST: - return "CAST"; - case ExpressionType::BOUND_REF: - return "BOUND_REF"; - case ExpressionType::BOUND_COLUMN_REF: - return "BOUND_COLUMN_REF"; - case ExpressionType::BOUND_UNNEST: - return "BOUND_UNNEST"; - case ExpressionType::COLLATE: - return "COLLATE"; - case ExpressionType::LAMBDA: - return "LAMBDA"; - case ExpressionType::POSITIONAL_REFERENCE: - return "POSITIONAL_REFERENCE"; - case ExpressionType::BOUND_LAMBDA_REF: - return "BOUND_LAMBDA_REF"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ExpressionType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return ExpressionType::INVALID; - } - if (StringUtil::Equals(value, "OPERATOR_CAST")) { - return ExpressionType::OPERATOR_CAST; - } - if (StringUtil::Equals(value, "OPERATOR_NOT")) { - return ExpressionType::OPERATOR_NOT; - } - if (StringUtil::Equals(value, "OPERATOR_IS_NULL")) { - return ExpressionType::OPERATOR_IS_NULL; - } - if (StringUtil::Equals(value, "OPERATOR_IS_NOT_NULL")) { - return ExpressionType::OPERATOR_IS_NOT_NULL; - } - if (StringUtil::Equals(value, "COMPARE_EQUAL")) { - return ExpressionType::COMPARE_EQUAL; - } - if (StringUtil::Equals(value, "COMPARE_NOTEQUAL")) { - return ExpressionType::COMPARE_NOTEQUAL; - } - if (StringUtil::Equals(value, "COMPARE_LESSTHAN")) { - return ExpressionType::COMPARE_LESSTHAN; - } - 
if (StringUtil::Equals(value, "COMPARE_GREATERTHAN")) { - return ExpressionType::COMPARE_GREATERTHAN; - } - if (StringUtil::Equals(value, "COMPARE_LESSTHANOREQUALTO")) { - return ExpressionType::COMPARE_LESSTHANOREQUALTO; - } - if (StringUtil::Equals(value, "COMPARE_GREATERTHANOREQUALTO")) { - return ExpressionType::COMPARE_GREATERTHANOREQUALTO; - } - if (StringUtil::Equals(value, "COMPARE_IN")) { - return ExpressionType::COMPARE_IN; - } - if (StringUtil::Equals(value, "COMPARE_NOT_IN")) { - return ExpressionType::COMPARE_NOT_IN; - } - if (StringUtil::Equals(value, "COMPARE_DISTINCT_FROM")) { - return ExpressionType::COMPARE_DISTINCT_FROM; - } - if (StringUtil::Equals(value, "COMPARE_BETWEEN")) { - return ExpressionType::COMPARE_BETWEEN; - } - if (StringUtil::Equals(value, "COMPARE_NOT_BETWEEN")) { - return ExpressionType::COMPARE_NOT_BETWEEN; - } - if (StringUtil::Equals(value, "COMPARE_NOT_DISTINCT_FROM")) { - return ExpressionType::COMPARE_NOT_DISTINCT_FROM; - } - if (StringUtil::Equals(value, "CONJUNCTION_AND")) { - return ExpressionType::CONJUNCTION_AND; - } - if (StringUtil::Equals(value, "CONJUNCTION_OR")) { - return ExpressionType::CONJUNCTION_OR; - } - if (StringUtil::Equals(value, "VALUE_CONSTANT")) { - return ExpressionType::VALUE_CONSTANT; - } - if (StringUtil::Equals(value, "VALUE_PARAMETER")) { - return ExpressionType::VALUE_PARAMETER; - } - if (StringUtil::Equals(value, "VALUE_TUPLE")) { - return ExpressionType::VALUE_TUPLE; - } - if (StringUtil::Equals(value, "VALUE_TUPLE_ADDRESS")) { - return ExpressionType::VALUE_TUPLE_ADDRESS; - } - if (StringUtil::Equals(value, "VALUE_NULL")) { - return ExpressionType::VALUE_NULL; - } - if (StringUtil::Equals(value, "VALUE_VECTOR")) { - return ExpressionType::VALUE_VECTOR; - } - if (StringUtil::Equals(value, "VALUE_SCALAR")) { - return ExpressionType::VALUE_SCALAR; - } - if (StringUtil::Equals(value, "VALUE_DEFAULT")) { - return ExpressionType::VALUE_DEFAULT; - } - if (StringUtil::Equals(value, "AGGREGATE")) { - 
return ExpressionType::AGGREGATE; - } - if (StringUtil::Equals(value, "BOUND_AGGREGATE")) { - return ExpressionType::BOUND_AGGREGATE; - } - if (StringUtil::Equals(value, "GROUPING_FUNCTION")) { - return ExpressionType::GROUPING_FUNCTION; - } - if (StringUtil::Equals(value, "WINDOW_AGGREGATE")) { - return ExpressionType::WINDOW_AGGREGATE; - } - if (StringUtil::Equals(value, "WINDOW_RANK")) { - return ExpressionType::WINDOW_RANK; - } - if (StringUtil::Equals(value, "WINDOW_RANK_DENSE")) { - return ExpressionType::WINDOW_RANK_DENSE; - } - if (StringUtil::Equals(value, "WINDOW_NTILE")) { - return ExpressionType::WINDOW_NTILE; - } - if (StringUtil::Equals(value, "WINDOW_PERCENT_RANK")) { - return ExpressionType::WINDOW_PERCENT_RANK; - } - if (StringUtil::Equals(value, "WINDOW_CUME_DIST")) { - return ExpressionType::WINDOW_CUME_DIST; - } - if (StringUtil::Equals(value, "WINDOW_ROW_NUMBER")) { - return ExpressionType::WINDOW_ROW_NUMBER; - } - if (StringUtil::Equals(value, "WINDOW_FIRST_VALUE")) { - return ExpressionType::WINDOW_FIRST_VALUE; - } - if (StringUtil::Equals(value, "WINDOW_LAST_VALUE")) { - return ExpressionType::WINDOW_LAST_VALUE; - } - if (StringUtil::Equals(value, "WINDOW_LEAD")) { - return ExpressionType::WINDOW_LEAD; - } - if (StringUtil::Equals(value, "WINDOW_LAG")) { - return ExpressionType::WINDOW_LAG; - } - if (StringUtil::Equals(value, "WINDOW_NTH_VALUE")) { - return ExpressionType::WINDOW_NTH_VALUE; - } - if (StringUtil::Equals(value, "FUNCTION")) { - return ExpressionType::FUNCTION; - } - if (StringUtil::Equals(value, "BOUND_FUNCTION")) { - return ExpressionType::BOUND_FUNCTION; - } - if (StringUtil::Equals(value, "CASE_EXPR")) { - return ExpressionType::CASE_EXPR; - } - if (StringUtil::Equals(value, "OPERATOR_NULLIF")) { - return ExpressionType::OPERATOR_NULLIF; - } - if (StringUtil::Equals(value, "OPERATOR_COALESCE")) { - return ExpressionType::OPERATOR_COALESCE; - } - if (StringUtil::Equals(value, "ARRAY_EXTRACT")) { - return 
ExpressionType::ARRAY_EXTRACT; - } - if (StringUtil::Equals(value, "ARRAY_SLICE")) { - return ExpressionType::ARRAY_SLICE; - } - if (StringUtil::Equals(value, "STRUCT_EXTRACT")) { - return ExpressionType::STRUCT_EXTRACT; - } - if (StringUtil::Equals(value, "ARRAY_CONSTRUCTOR")) { - return ExpressionType::ARRAY_CONSTRUCTOR; - } - if (StringUtil::Equals(value, "ARROW")) { - return ExpressionType::ARROW; - } - if (StringUtil::Equals(value, "SUBQUERY")) { - return ExpressionType::SUBQUERY; - } - if (StringUtil::Equals(value, "STAR")) { - return ExpressionType::STAR; - } - if (StringUtil::Equals(value, "TABLE_STAR")) { - return ExpressionType::TABLE_STAR; - } - if (StringUtil::Equals(value, "PLACEHOLDER")) { - return ExpressionType::PLACEHOLDER; - } - if (StringUtil::Equals(value, "COLUMN_REF")) { - return ExpressionType::COLUMN_REF; - } - if (StringUtil::Equals(value, "FUNCTION_REF")) { - return ExpressionType::FUNCTION_REF; - } - if (StringUtil::Equals(value, "TABLE_REF")) { - return ExpressionType::TABLE_REF; - } - if (StringUtil::Equals(value, "CAST")) { - return ExpressionType::CAST; - } - if (StringUtil::Equals(value, "BOUND_REF")) { - return ExpressionType::BOUND_REF; - } - if (StringUtil::Equals(value, "BOUND_COLUMN_REF")) { - return ExpressionType::BOUND_COLUMN_REF; - } - if (StringUtil::Equals(value, "BOUND_UNNEST")) { - return ExpressionType::BOUND_UNNEST; - } - if (StringUtil::Equals(value, "COLLATE")) { - return ExpressionType::COLLATE; - } - if (StringUtil::Equals(value, "LAMBDA")) { - return ExpressionType::LAMBDA; - } - if (StringUtil::Equals(value, "POSITIONAL_REFERENCE")) { - return ExpressionType::POSITIONAL_REFERENCE; - } - if (StringUtil::Equals(value, "BOUND_LAMBDA_REF")) { - return ExpressionType::BOUND_LAMBDA_REF; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ExtensionLoadResult value) { - switch(value) { - case 
ExtensionLoadResult::LOADED_EXTENSION: - return "LOADED_EXTENSION"; - case ExtensionLoadResult::EXTENSION_UNKNOWN: - return "EXTENSION_UNKNOWN"; - case ExtensionLoadResult::NOT_LOADED: - return "NOT_LOADED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ExtensionLoadResult EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "LOADED_EXTENSION")) { - return ExtensionLoadResult::LOADED_EXTENSION; - } - if (StringUtil::Equals(value, "EXTENSION_UNKNOWN")) { - return ExtensionLoadResult::EXTENSION_UNKNOWN; - } - if (StringUtil::Equals(value, "NOT_LOADED")) { - return ExtensionLoadResult::NOT_LOADED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ExtraTypeInfoType value) { - switch(value) { - case ExtraTypeInfoType::INVALID_TYPE_INFO: - return "INVALID_TYPE_INFO"; - case ExtraTypeInfoType::GENERIC_TYPE_INFO: - return "GENERIC_TYPE_INFO"; - case ExtraTypeInfoType::DECIMAL_TYPE_INFO: - return "DECIMAL_TYPE_INFO"; - case ExtraTypeInfoType::STRING_TYPE_INFO: - return "STRING_TYPE_INFO"; - case ExtraTypeInfoType::LIST_TYPE_INFO: - return "LIST_TYPE_INFO"; - case ExtraTypeInfoType::STRUCT_TYPE_INFO: - return "STRUCT_TYPE_INFO"; - case ExtraTypeInfoType::ENUM_TYPE_INFO: - return "ENUM_TYPE_INFO"; - case ExtraTypeInfoType::USER_TYPE_INFO: - return "USER_TYPE_INFO"; - case ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO: - return "AGGREGATE_STATE_TYPE_INFO"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ExtraTypeInfoType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID_TYPE_INFO")) { - return ExtraTypeInfoType::INVALID_TYPE_INFO; - } - if (StringUtil::Equals(value, "GENERIC_TYPE_INFO")) { - return ExtraTypeInfoType::GENERIC_TYPE_INFO; - } - if 
(StringUtil::Equals(value, "DECIMAL_TYPE_INFO")) { - return ExtraTypeInfoType::DECIMAL_TYPE_INFO; - } - if (StringUtil::Equals(value, "STRING_TYPE_INFO")) { - return ExtraTypeInfoType::STRING_TYPE_INFO; - } - if (StringUtil::Equals(value, "LIST_TYPE_INFO")) { - return ExtraTypeInfoType::LIST_TYPE_INFO; - } - if (StringUtil::Equals(value, "STRUCT_TYPE_INFO")) { - return ExtraTypeInfoType::STRUCT_TYPE_INFO; - } - if (StringUtil::Equals(value, "ENUM_TYPE_INFO")) { - return ExtraTypeInfoType::ENUM_TYPE_INFO; - } - if (StringUtil::Equals(value, "USER_TYPE_INFO")) { - return ExtraTypeInfoType::USER_TYPE_INFO; - } - if (StringUtil::Equals(value, "AGGREGATE_STATE_TYPE_INFO")) { - return ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(FileBufferType value) { - switch(value) { - case FileBufferType::BLOCK: - return "BLOCK"; - case FileBufferType::MANAGED_BUFFER: - return "MANAGED_BUFFER"; - case FileBufferType::TINY_BUFFER: - return "TINY_BUFFER"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -FileBufferType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "BLOCK")) { - return FileBufferType::BLOCK; - } - if (StringUtil::Equals(value, "MANAGED_BUFFER")) { - return FileBufferType::MANAGED_BUFFER; - } - if (StringUtil::Equals(value, "TINY_BUFFER")) { - return FileBufferType::TINY_BUFFER; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(FileCompressionType value) { - switch(value) { - case FileCompressionType::AUTO_DETECT: - return "AUTO_DETECT"; - case FileCompressionType::UNCOMPRESSED: - return "UNCOMPRESSED"; - case FileCompressionType::GZIP: - return "GZIP"; - case FileCompressionType::ZSTD: - return "ZSTD"; - default: - throw 
NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -FileCompressionType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "AUTO_DETECT")) { - return FileCompressionType::AUTO_DETECT; - } - if (StringUtil::Equals(value, "UNCOMPRESSED")) { - return FileCompressionType::UNCOMPRESSED; - } - if (StringUtil::Equals(value, "GZIP")) { - return FileCompressionType::GZIP; - } - if (StringUtil::Equals(value, "ZSTD")) { - return FileCompressionType::ZSTD; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(FileGlobOptions value) { - switch(value) { - case FileGlobOptions::DISALLOW_EMPTY: - return "DISALLOW_EMPTY"; - case FileGlobOptions::ALLOW_EMPTY: - return "ALLOW_EMPTY"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -FileGlobOptions EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "DISALLOW_EMPTY")) { - return FileGlobOptions::DISALLOW_EMPTY; - } - if (StringUtil::Equals(value, "ALLOW_EMPTY")) { - return FileGlobOptions::ALLOW_EMPTY; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(FileLockType value) { - switch(value) { - case FileLockType::NO_LOCK: - return "NO_LOCK"; - case FileLockType::READ_LOCK: - return "READ_LOCK"; - case FileLockType::WRITE_LOCK: - return "WRITE_LOCK"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -FileLockType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NO_LOCK")) { - return FileLockType::NO_LOCK; - } - if (StringUtil::Equals(value, "READ_LOCK")) { - return FileLockType::READ_LOCK; - } - if (StringUtil::Equals(value, "WRITE_LOCK")) { - return 
FileLockType::WRITE_LOCK; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(FilterPropagateResult value) { - switch(value) { - case FilterPropagateResult::NO_PRUNING_POSSIBLE: - return "NO_PRUNING_POSSIBLE"; - case FilterPropagateResult::FILTER_ALWAYS_TRUE: - return "FILTER_ALWAYS_TRUE"; - case FilterPropagateResult::FILTER_ALWAYS_FALSE: - return "FILTER_ALWAYS_FALSE"; - case FilterPropagateResult::FILTER_TRUE_OR_NULL: - return "FILTER_TRUE_OR_NULL"; - case FilterPropagateResult::FILTER_FALSE_OR_NULL: - return "FILTER_FALSE_OR_NULL"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -FilterPropagateResult EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NO_PRUNING_POSSIBLE")) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - if (StringUtil::Equals(value, "FILTER_ALWAYS_TRUE")) { - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - if (StringUtil::Equals(value, "FILTER_ALWAYS_FALSE")) { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - if (StringUtil::Equals(value, "FILTER_TRUE_OR_NULL")) { - return FilterPropagateResult::FILTER_TRUE_OR_NULL; - } - if (StringUtil::Equals(value, "FILTER_FALSE_OR_NULL")) { - return FilterPropagateResult::FILTER_FALSE_OR_NULL; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ForeignKeyType value) { - switch(value) { - case ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE: - return "FK_TYPE_PRIMARY_KEY_TABLE"; - case ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE: - return "FK_TYPE_FOREIGN_KEY_TABLE"; - case ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE: - return "FK_TYPE_SELF_REFERENCE_TABLE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> 
-ForeignKeyType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "FK_TYPE_PRIMARY_KEY_TABLE")) { - return ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE; - } - if (StringUtil::Equals(value, "FK_TYPE_FOREIGN_KEY_TABLE")) { - return ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE; - } - if (StringUtil::Equals(value, "FK_TYPE_SELF_REFERENCE_TABLE")) { - return ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(FunctionNullHandling value) { - switch(value) { - case FunctionNullHandling::DEFAULT_NULL_HANDLING: - return "DEFAULT_NULL_HANDLING"; - case FunctionNullHandling::SPECIAL_HANDLING: - return "SPECIAL_HANDLING"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -FunctionNullHandling EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "DEFAULT_NULL_HANDLING")) { - return FunctionNullHandling::DEFAULT_NULL_HANDLING; - } - if (StringUtil::Equals(value, "SPECIAL_HANDLING")) { - return FunctionNullHandling::SPECIAL_HANDLING; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(FunctionSideEffects value) { - switch(value) { - case FunctionSideEffects::NO_SIDE_EFFECTS: - return "NO_SIDE_EFFECTS"; - case FunctionSideEffects::HAS_SIDE_EFFECTS: - return "HAS_SIDE_EFFECTS"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -FunctionSideEffects EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NO_SIDE_EFFECTS")) { - return FunctionSideEffects::NO_SIDE_EFFECTS; - } - if (StringUtil::Equals(value, "HAS_SIDE_EFFECTS")) { - return FunctionSideEffects::HAS_SIDE_EFFECTS; - } - throw NotImplementedException(StringUtil::Format("Enum 
value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(HLLStorageType value) { - switch(value) { - case HLLStorageType::UNCOMPRESSED: - return "UNCOMPRESSED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -HLLStorageType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "UNCOMPRESSED")) { - return HLLStorageType::UNCOMPRESSED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(IndexConstraintType value) { - switch(value) { - case IndexConstraintType::NONE: - return "NONE"; - case IndexConstraintType::UNIQUE: - return "UNIQUE"; - case IndexConstraintType::PRIMARY: - return "PRIMARY"; - case IndexConstraintType::FOREIGN: - return "FOREIGN"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -IndexConstraintType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NONE")) { - return IndexConstraintType::NONE; - } - if (StringUtil::Equals(value, "UNIQUE")) { - return IndexConstraintType::UNIQUE; - } - if (StringUtil::Equals(value, "PRIMARY")) { - return IndexConstraintType::PRIMARY; - } - if (StringUtil::Equals(value, "FOREIGN")) { - return IndexConstraintType::FOREIGN; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(IndexType value) { - switch(value) { - case IndexType::INVALID: - return "INVALID"; - case IndexType::ART: - return "ART"; - case IndexType::EXTENSION: - return "EXTENSION"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -IndexType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return IndexType::INVALID; - } 
- if (StringUtil::Equals(value, "ART")) { - return IndexType::ART; - } - if (StringUtil::Equals(value, "EXTENSION")) { - return IndexType::EXTENSION; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(InsertColumnOrder value) { - switch(value) { - case InsertColumnOrder::INSERT_BY_POSITION: - return "INSERT_BY_POSITION"; - case InsertColumnOrder::INSERT_BY_NAME: - return "INSERT_BY_NAME"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -InsertColumnOrder EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INSERT_BY_POSITION")) { - return InsertColumnOrder::INSERT_BY_POSITION; - } - if (StringUtil::Equals(value, "INSERT_BY_NAME")) { - return InsertColumnOrder::INSERT_BY_NAME; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(InterruptMode value) { - switch(value) { - case InterruptMode::NO_INTERRUPTS: - return "NO_INTERRUPTS"; - case InterruptMode::TASK: - return "TASK"; - case InterruptMode::BLOCKING: - return "BLOCKING"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -InterruptMode EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NO_INTERRUPTS")) { - return InterruptMode::NO_INTERRUPTS; - } - if (StringUtil::Equals(value, "TASK")) { - return InterruptMode::TASK; - } - if (StringUtil::Equals(value, "BLOCKING")) { - return InterruptMode::BLOCKING; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(JoinRefType value) { - switch(value) { - case JoinRefType::REGULAR: - return "REGULAR"; - case JoinRefType::NATURAL: - return "NATURAL"; - case JoinRefType::CROSS: - return 
"CROSS"; - case JoinRefType::POSITIONAL: - return "POSITIONAL"; - case JoinRefType::ASOF: - return "ASOF"; - case JoinRefType::DEPENDENT: - return "DEPENDENT"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -JoinRefType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "REGULAR")) { - return JoinRefType::REGULAR; - } - if (StringUtil::Equals(value, "NATURAL")) { - return JoinRefType::NATURAL; - } - if (StringUtil::Equals(value, "CROSS")) { - return JoinRefType::CROSS; - } - if (StringUtil::Equals(value, "POSITIONAL")) { - return JoinRefType::POSITIONAL; - } - if (StringUtil::Equals(value, "ASOF")) { - return JoinRefType::ASOF; - } - if (StringUtil::Equals(value, "DEPENDENT")) { - return JoinRefType::DEPENDENT; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(JoinType value) { - switch(value) { - case JoinType::INVALID: - return "INVALID"; - case JoinType::LEFT: - return "LEFT"; - case JoinType::RIGHT: - return "RIGHT"; - case JoinType::INNER: - return "INNER"; - case JoinType::OUTER: - return "FULL"; - case JoinType::SEMI: - return "SEMI"; - case JoinType::ANTI: - return "ANTI"; - case JoinType::MARK: - return "MARK"; - case JoinType::SINGLE: - return "SINGLE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -JoinType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return JoinType::INVALID; - } - if (StringUtil::Equals(value, "LEFT")) { - return JoinType::LEFT; - } - if (StringUtil::Equals(value, "RIGHT")) { - return JoinType::RIGHT; - } - if (StringUtil::Equals(value, "INNER")) { - return JoinType::INNER; - } - if (StringUtil::Equals(value, "FULL")) { - return JoinType::OUTER; - } - if (StringUtil::Equals(value, "SEMI")) { - return 
JoinType::SEMI; - } - if (StringUtil::Equals(value, "ANTI")) { - return JoinType::ANTI; - } - if (StringUtil::Equals(value, "MARK")) { - return JoinType::MARK; - } - if (StringUtil::Equals(value, "SINGLE")) { - return JoinType::SINGLE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(KeywordCategory value) { - switch(value) { - case KeywordCategory::KEYWORD_RESERVED: - return "KEYWORD_RESERVED"; - case KeywordCategory::KEYWORD_UNRESERVED: - return "KEYWORD_UNRESERVED"; - case KeywordCategory::KEYWORD_TYPE_FUNC: - return "KEYWORD_TYPE_FUNC"; - case KeywordCategory::KEYWORD_COL_NAME: - return "KEYWORD_COL_NAME"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -KeywordCategory EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "KEYWORD_RESERVED")) { - return KeywordCategory::KEYWORD_RESERVED; - } - if (StringUtil::Equals(value, "KEYWORD_UNRESERVED")) { - return KeywordCategory::KEYWORD_UNRESERVED; - } - if (StringUtil::Equals(value, "KEYWORD_TYPE_FUNC")) { - return KeywordCategory::KEYWORD_TYPE_FUNC; - } - if (StringUtil::Equals(value, "KEYWORD_COL_NAME")) { - return KeywordCategory::KEYWORD_COL_NAME; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(LoadType value) { - switch(value) { - case LoadType::LOAD: - return "LOAD"; - case LoadType::INSTALL: - return "INSTALL"; - case LoadType::FORCE_INSTALL: - return "FORCE_INSTALL"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -LoadType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "LOAD")) { - return LoadType::LOAD; - } - if (StringUtil::Equals(value, "INSTALL")) { - return LoadType::INSTALL; - } - if 
(StringUtil::Equals(value, "FORCE_INSTALL")) { - return LoadType::FORCE_INSTALL; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(LogicalOperatorType value) { - switch(value) { - case LogicalOperatorType::LOGICAL_INVALID: - return "LOGICAL_INVALID"; - case LogicalOperatorType::LOGICAL_PROJECTION: - return "LOGICAL_PROJECTION"; - case LogicalOperatorType::LOGICAL_FILTER: - return "LOGICAL_FILTER"; - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - return "LOGICAL_AGGREGATE_AND_GROUP_BY"; - case LogicalOperatorType::LOGICAL_WINDOW: - return "LOGICAL_WINDOW"; - case LogicalOperatorType::LOGICAL_UNNEST: - return "LOGICAL_UNNEST"; - case LogicalOperatorType::LOGICAL_LIMIT: - return "LOGICAL_LIMIT"; - case LogicalOperatorType::LOGICAL_ORDER_BY: - return "LOGICAL_ORDER_BY"; - case LogicalOperatorType::LOGICAL_TOP_N: - return "LOGICAL_TOP_N"; - case LogicalOperatorType::LOGICAL_COPY_TO_FILE: - return "LOGICAL_COPY_TO_FILE"; - case LogicalOperatorType::LOGICAL_DISTINCT: - return "LOGICAL_DISTINCT"; - case LogicalOperatorType::LOGICAL_SAMPLE: - return "LOGICAL_SAMPLE"; - case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: - return "LOGICAL_LIMIT_PERCENT"; - case LogicalOperatorType::LOGICAL_PIVOT: - return "LOGICAL_PIVOT"; - case LogicalOperatorType::LOGICAL_GET: - return "LOGICAL_GET"; - case LogicalOperatorType::LOGICAL_CHUNK_GET: - return "LOGICAL_CHUNK_GET"; - case LogicalOperatorType::LOGICAL_DELIM_GET: - return "LOGICAL_DELIM_GET"; - case LogicalOperatorType::LOGICAL_EXPRESSION_GET: - return "LOGICAL_EXPRESSION_GET"; - case LogicalOperatorType::LOGICAL_DUMMY_SCAN: - return "LOGICAL_DUMMY_SCAN"; - case LogicalOperatorType::LOGICAL_EMPTY_RESULT: - return "LOGICAL_EMPTY_RESULT"; - case LogicalOperatorType::LOGICAL_CTE_REF: - return "LOGICAL_CTE_REF"; - case LogicalOperatorType::LOGICAL_JOIN: - return "LOGICAL_JOIN"; - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - 
return "LOGICAL_DELIM_JOIN"; - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - return "LOGICAL_COMPARISON_JOIN"; - case LogicalOperatorType::LOGICAL_ANY_JOIN: - return "LOGICAL_ANY_JOIN"; - case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: - return "LOGICAL_CROSS_PRODUCT"; - case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: - return "LOGICAL_POSITIONAL_JOIN"; - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - return "LOGICAL_ASOF_JOIN"; - case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: - return "LOGICAL_DEPENDENT_JOIN"; - case LogicalOperatorType::LOGICAL_UNION: - return "LOGICAL_UNION"; - case LogicalOperatorType::LOGICAL_EXCEPT: - return "LOGICAL_EXCEPT"; - case LogicalOperatorType::LOGICAL_INTERSECT: - return "LOGICAL_INTERSECT"; - case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: - return "LOGICAL_RECURSIVE_CTE"; - case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: - return "LOGICAL_MATERIALIZED_CTE"; - case LogicalOperatorType::LOGICAL_INSERT: - return "LOGICAL_INSERT"; - case LogicalOperatorType::LOGICAL_DELETE: - return "LOGICAL_DELETE"; - case LogicalOperatorType::LOGICAL_UPDATE: - return "LOGICAL_UPDATE"; - case LogicalOperatorType::LOGICAL_ALTER: - return "LOGICAL_ALTER"; - case LogicalOperatorType::LOGICAL_CREATE_TABLE: - return "LOGICAL_CREATE_TABLE"; - case LogicalOperatorType::LOGICAL_CREATE_INDEX: - return "LOGICAL_CREATE_INDEX"; - case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: - return "LOGICAL_CREATE_SEQUENCE"; - case LogicalOperatorType::LOGICAL_CREATE_VIEW: - return "LOGICAL_CREATE_VIEW"; - case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: - return "LOGICAL_CREATE_SCHEMA"; - case LogicalOperatorType::LOGICAL_CREATE_MACRO: - return "LOGICAL_CREATE_MACRO"; - case LogicalOperatorType::LOGICAL_DROP: - return "LOGICAL_DROP"; - case LogicalOperatorType::LOGICAL_PRAGMA: - return "LOGICAL_PRAGMA"; - case LogicalOperatorType::LOGICAL_TRANSACTION: - return "LOGICAL_TRANSACTION"; - case LogicalOperatorType::LOGICAL_CREATE_TYPE: - return 
"LOGICAL_CREATE_TYPE"; - case LogicalOperatorType::LOGICAL_ATTACH: - return "LOGICAL_ATTACH"; - case LogicalOperatorType::LOGICAL_DETACH: - return "LOGICAL_DETACH"; - case LogicalOperatorType::LOGICAL_EXPLAIN: - return "LOGICAL_EXPLAIN"; - case LogicalOperatorType::LOGICAL_SHOW: - return "LOGICAL_SHOW"; - case LogicalOperatorType::LOGICAL_PREPARE: - return "LOGICAL_PREPARE"; - case LogicalOperatorType::LOGICAL_EXECUTE: - return "LOGICAL_EXECUTE"; - case LogicalOperatorType::LOGICAL_EXPORT: - return "LOGICAL_EXPORT"; - case LogicalOperatorType::LOGICAL_VACUUM: - return "LOGICAL_VACUUM"; - case LogicalOperatorType::LOGICAL_SET: - return "LOGICAL_SET"; - case LogicalOperatorType::LOGICAL_LOAD: - return "LOGICAL_LOAD"; - case LogicalOperatorType::LOGICAL_RESET: - return "LOGICAL_RESET"; - case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: - return "LOGICAL_EXTENSION_OPERATOR"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -LogicalOperatorType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "LOGICAL_INVALID")) { - return LogicalOperatorType::LOGICAL_INVALID; - } - if (StringUtil::Equals(value, "LOGICAL_PROJECTION")) { - return LogicalOperatorType::LOGICAL_PROJECTION; - } - if (StringUtil::Equals(value, "LOGICAL_FILTER")) { - return LogicalOperatorType::LOGICAL_FILTER; - } - if (StringUtil::Equals(value, "LOGICAL_AGGREGATE_AND_GROUP_BY")) { - return LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY; - } - if (StringUtil::Equals(value, "LOGICAL_WINDOW")) { - return LogicalOperatorType::LOGICAL_WINDOW; - } - if (StringUtil::Equals(value, "LOGICAL_UNNEST")) { - return LogicalOperatorType::LOGICAL_UNNEST; - } - if (StringUtil::Equals(value, "LOGICAL_LIMIT")) { - return LogicalOperatorType::LOGICAL_LIMIT; - } - if (StringUtil::Equals(value, "LOGICAL_ORDER_BY")) { - return LogicalOperatorType::LOGICAL_ORDER_BY; - } - if (StringUtil::Equals(value, "LOGICAL_TOP_N")) { 
- return LogicalOperatorType::LOGICAL_TOP_N; - } - if (StringUtil::Equals(value, "LOGICAL_COPY_TO_FILE")) { - return LogicalOperatorType::LOGICAL_COPY_TO_FILE; - } - if (StringUtil::Equals(value, "LOGICAL_DISTINCT")) { - return LogicalOperatorType::LOGICAL_DISTINCT; - } - if (StringUtil::Equals(value, "LOGICAL_SAMPLE")) { - return LogicalOperatorType::LOGICAL_SAMPLE; - } - if (StringUtil::Equals(value, "LOGICAL_LIMIT_PERCENT")) { - return LogicalOperatorType::LOGICAL_LIMIT_PERCENT; - } - if (StringUtil::Equals(value, "LOGICAL_PIVOT")) { - return LogicalOperatorType::LOGICAL_PIVOT; - } - if (StringUtil::Equals(value, "LOGICAL_GET")) { - return LogicalOperatorType::LOGICAL_GET; - } - if (StringUtil::Equals(value, "LOGICAL_CHUNK_GET")) { - return LogicalOperatorType::LOGICAL_CHUNK_GET; - } - if (StringUtil::Equals(value, "LOGICAL_DELIM_GET")) { - return LogicalOperatorType::LOGICAL_DELIM_GET; - } - if (StringUtil::Equals(value, "LOGICAL_EXPRESSION_GET")) { - return LogicalOperatorType::LOGICAL_EXPRESSION_GET; - } - if (StringUtil::Equals(value, "LOGICAL_DUMMY_SCAN")) { - return LogicalOperatorType::LOGICAL_DUMMY_SCAN; - } - if (StringUtil::Equals(value, "LOGICAL_EMPTY_RESULT")) { - return LogicalOperatorType::LOGICAL_EMPTY_RESULT; - } - if (StringUtil::Equals(value, "LOGICAL_CTE_REF")) { - return LogicalOperatorType::LOGICAL_CTE_REF; - } - if (StringUtil::Equals(value, "LOGICAL_JOIN")) { - return LogicalOperatorType::LOGICAL_JOIN; - } - if (StringUtil::Equals(value, "LOGICAL_DELIM_JOIN")) { - return LogicalOperatorType::LOGICAL_DELIM_JOIN; - } - if (StringUtil::Equals(value, "LOGICAL_COMPARISON_JOIN")) { - return LogicalOperatorType::LOGICAL_COMPARISON_JOIN; - } - if (StringUtil::Equals(value, "LOGICAL_ANY_JOIN")) { - return LogicalOperatorType::LOGICAL_ANY_JOIN; - } - if (StringUtil::Equals(value, "LOGICAL_CROSS_PRODUCT")) { - return LogicalOperatorType::LOGICAL_CROSS_PRODUCT; - } - if (StringUtil::Equals(value, "LOGICAL_POSITIONAL_JOIN")) { - return 
LogicalOperatorType::LOGICAL_POSITIONAL_JOIN; - } - if (StringUtil::Equals(value, "LOGICAL_ASOF_JOIN")) { - return LogicalOperatorType::LOGICAL_ASOF_JOIN; - } - if (StringUtil::Equals(value, "LOGICAL_DEPENDENT_JOIN")) { - return LogicalOperatorType::LOGICAL_DEPENDENT_JOIN; - } - if (StringUtil::Equals(value, "LOGICAL_UNION")) { - return LogicalOperatorType::LOGICAL_UNION; - } - if (StringUtil::Equals(value, "LOGICAL_EXCEPT")) { - return LogicalOperatorType::LOGICAL_EXCEPT; - } - if (StringUtil::Equals(value, "LOGICAL_INTERSECT")) { - return LogicalOperatorType::LOGICAL_INTERSECT; - } - if (StringUtil::Equals(value, "LOGICAL_RECURSIVE_CTE")) { - return LogicalOperatorType::LOGICAL_RECURSIVE_CTE; - } - if (StringUtil::Equals(value, "LOGICAL_MATERIALIZED_CTE")) { - return LogicalOperatorType::LOGICAL_MATERIALIZED_CTE; - } - if (StringUtil::Equals(value, "LOGICAL_INSERT")) { - return LogicalOperatorType::LOGICAL_INSERT; - } - if (StringUtil::Equals(value, "LOGICAL_DELETE")) { - return LogicalOperatorType::LOGICAL_DELETE; - } - if (StringUtil::Equals(value, "LOGICAL_UPDATE")) { - return LogicalOperatorType::LOGICAL_UPDATE; - } - if (StringUtil::Equals(value, "LOGICAL_ALTER")) { - return LogicalOperatorType::LOGICAL_ALTER; - } - if (StringUtil::Equals(value, "LOGICAL_CREATE_TABLE")) { - return LogicalOperatorType::LOGICAL_CREATE_TABLE; - } - if (StringUtil::Equals(value, "LOGICAL_CREATE_INDEX")) { - return LogicalOperatorType::LOGICAL_CREATE_INDEX; - } - if (StringUtil::Equals(value, "LOGICAL_CREATE_SEQUENCE")) { - return LogicalOperatorType::LOGICAL_CREATE_SEQUENCE; - } - if (StringUtil::Equals(value, "LOGICAL_CREATE_VIEW")) { - return LogicalOperatorType::LOGICAL_CREATE_VIEW; - } - if (StringUtil::Equals(value, "LOGICAL_CREATE_SCHEMA")) { - return LogicalOperatorType::LOGICAL_CREATE_SCHEMA; - } - if (StringUtil::Equals(value, "LOGICAL_CREATE_MACRO")) { - return LogicalOperatorType::LOGICAL_CREATE_MACRO; - } - if (StringUtil::Equals(value, "LOGICAL_DROP")) { - return 
LogicalOperatorType::LOGICAL_DROP; - } - if (StringUtil::Equals(value, "LOGICAL_PRAGMA")) { - return LogicalOperatorType::LOGICAL_PRAGMA; - } - if (StringUtil::Equals(value, "LOGICAL_TRANSACTION")) { - return LogicalOperatorType::LOGICAL_TRANSACTION; - } - if (StringUtil::Equals(value, "LOGICAL_CREATE_TYPE")) { - return LogicalOperatorType::LOGICAL_CREATE_TYPE; - } - if (StringUtil::Equals(value, "LOGICAL_ATTACH")) { - return LogicalOperatorType::LOGICAL_ATTACH; - } - if (StringUtil::Equals(value, "LOGICAL_DETACH")) { - return LogicalOperatorType::LOGICAL_DETACH; - } - if (StringUtil::Equals(value, "LOGICAL_EXPLAIN")) { - return LogicalOperatorType::LOGICAL_EXPLAIN; - } - if (StringUtil::Equals(value, "LOGICAL_SHOW")) { - return LogicalOperatorType::LOGICAL_SHOW; - } - if (StringUtil::Equals(value, "LOGICAL_PREPARE")) { - return LogicalOperatorType::LOGICAL_PREPARE; - } - if (StringUtil::Equals(value, "LOGICAL_EXECUTE")) { - return LogicalOperatorType::LOGICAL_EXECUTE; - } - if (StringUtil::Equals(value, "LOGICAL_EXPORT")) { - return LogicalOperatorType::LOGICAL_EXPORT; - } - if (StringUtil::Equals(value, "LOGICAL_VACUUM")) { - return LogicalOperatorType::LOGICAL_VACUUM; - } - if (StringUtil::Equals(value, "LOGICAL_SET")) { - return LogicalOperatorType::LOGICAL_SET; - } - if (StringUtil::Equals(value, "LOGICAL_LOAD")) { - return LogicalOperatorType::LOGICAL_LOAD; - } - if (StringUtil::Equals(value, "LOGICAL_RESET")) { - return LogicalOperatorType::LOGICAL_RESET; - } - if (StringUtil::Equals(value, "LOGICAL_EXTENSION_OPERATOR")) { - return LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(LogicalTypeId value) { - switch(value) { - case LogicalTypeId::INVALID: - return "INVALID"; - case LogicalTypeId::SQLNULL: - return "NULL"; - case LogicalTypeId::UNKNOWN: - return "UNKNOWN"; - case LogicalTypeId::ANY: - return "ANY"; - 
case LogicalTypeId::USER: - return "USER"; - case LogicalTypeId::BOOLEAN: - return "BOOLEAN"; - case LogicalTypeId::TINYINT: - return "TINYINT"; - case LogicalTypeId::SMALLINT: - return "SMALLINT"; - case LogicalTypeId::INTEGER: - return "INTEGER"; - case LogicalTypeId::BIGINT: - return "BIGINT"; - case LogicalTypeId::DATE: - return "DATE"; - case LogicalTypeId::TIME: - return "TIME"; - case LogicalTypeId::TIMESTAMP_SEC: - return "TIMESTAMP_S"; - case LogicalTypeId::TIMESTAMP_MS: - return "TIMESTAMP_MS"; - case LogicalTypeId::TIMESTAMP: - return "TIMESTAMP"; - case LogicalTypeId::TIMESTAMP_NS: - return "TIMESTAMP_NS"; - case LogicalTypeId::DECIMAL: - return "DECIMAL"; - case LogicalTypeId::FLOAT: - return "FLOAT"; - case LogicalTypeId::DOUBLE: - return "DOUBLE"; - case LogicalTypeId::CHAR: - return "CHAR"; - case LogicalTypeId::VARCHAR: - return "VARCHAR"; - case LogicalTypeId::BLOB: - return "BLOB"; - case LogicalTypeId::INTERVAL: - return "INTERVAL"; - case LogicalTypeId::UTINYINT: - return "UTINYINT"; - case LogicalTypeId::USMALLINT: - return "USMALLINT"; - case LogicalTypeId::UINTEGER: - return "UINTEGER"; - case LogicalTypeId::UBIGINT: - return "UBIGINT"; - case LogicalTypeId::TIMESTAMP_TZ: - return "TIMESTAMP WITH TIME ZONE"; - case LogicalTypeId::TIME_TZ: - return "TIME WITH TIME ZONE"; - case LogicalTypeId::BIT: - return "BIT"; - case LogicalTypeId::HUGEINT: - return "HUGEINT"; - case LogicalTypeId::POINTER: - return "POINTER"; - case LogicalTypeId::VALIDITY: - return "VALIDITY"; - case LogicalTypeId::UUID: - return "UUID"; - case LogicalTypeId::STRUCT: - return "STRUCT"; - case LogicalTypeId::LIST: - return "LIST"; - case LogicalTypeId::MAP: - return "MAP"; - case LogicalTypeId::TABLE: - return "TABLE"; - case LogicalTypeId::ENUM: - return "ENUM"; - case LogicalTypeId::AGGREGATE_STATE: - return "AGGREGATE_STATE"; - case LogicalTypeId::LAMBDA: - return "LAMBDA"; - case LogicalTypeId::UNION: - return "UNION"; - default: - throw 
NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -LogicalTypeId EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return LogicalTypeId::INVALID; - } - if (StringUtil::Equals(value, "NULL")) { - return LogicalTypeId::SQLNULL; - } - if (StringUtil::Equals(value, "UNKNOWN")) { - return LogicalTypeId::UNKNOWN; - } - if (StringUtil::Equals(value, "ANY")) { - return LogicalTypeId::ANY; - } - if (StringUtil::Equals(value, "USER")) { - return LogicalTypeId::USER; - } - if (StringUtil::Equals(value, "BOOLEAN")) { - return LogicalTypeId::BOOLEAN; - } - if (StringUtil::Equals(value, "TINYINT")) { - return LogicalTypeId::TINYINT; - } - if (StringUtil::Equals(value, "SMALLINT")) { - return LogicalTypeId::SMALLINT; - } - if (StringUtil::Equals(value, "INTEGER")) { - return LogicalTypeId::INTEGER; - } - if (StringUtil::Equals(value, "BIGINT")) { - return LogicalTypeId::BIGINT; - } - if (StringUtil::Equals(value, "DATE")) { - return LogicalTypeId::DATE; - } - if (StringUtil::Equals(value, "TIME")) { - return LogicalTypeId::TIME; - } - if (StringUtil::Equals(value, "TIMESTAMP_S")) { - return LogicalTypeId::TIMESTAMP_SEC; - } - if (StringUtil::Equals(value, "TIMESTAMP_MS")) { - return LogicalTypeId::TIMESTAMP_MS; - } - if (StringUtil::Equals(value, "TIMESTAMP")) { - return LogicalTypeId::TIMESTAMP; - } - if (StringUtil::Equals(value, "TIMESTAMP_NS")) { - return LogicalTypeId::TIMESTAMP_NS; - } - if (StringUtil::Equals(value, "DECIMAL")) { - return LogicalTypeId::DECIMAL; - } - if (StringUtil::Equals(value, "FLOAT")) { - return LogicalTypeId::FLOAT; - } - if (StringUtil::Equals(value, "DOUBLE")) { - return LogicalTypeId::DOUBLE; - } - if (StringUtil::Equals(value, "CHAR")) { - return LogicalTypeId::CHAR; - } - if (StringUtil::Equals(value, "VARCHAR")) { - return LogicalTypeId::VARCHAR; - } - if (StringUtil::Equals(value, "BLOB")) { - return LogicalTypeId::BLOB; - } - if 
(StringUtil::Equals(value, "INTERVAL")) { - return LogicalTypeId::INTERVAL; - } - if (StringUtil::Equals(value, "UTINYINT")) { - return LogicalTypeId::UTINYINT; - } - if (StringUtil::Equals(value, "USMALLINT")) { - return LogicalTypeId::USMALLINT; - } - if (StringUtil::Equals(value, "UINTEGER")) { - return LogicalTypeId::UINTEGER; - } - if (StringUtil::Equals(value, "UBIGINT")) { - return LogicalTypeId::UBIGINT; - } - if (StringUtil::Equals(value, "TIMESTAMP WITH TIME ZONE")) { - return LogicalTypeId::TIMESTAMP_TZ; - } - if (StringUtil::Equals(value, "TIME WITH TIME ZONE")) { - return LogicalTypeId::TIME_TZ; - } - if (StringUtil::Equals(value, "BIT")) { - return LogicalTypeId::BIT; - } - if (StringUtil::Equals(value, "HUGEINT")) { - return LogicalTypeId::HUGEINT; - } - if (StringUtil::Equals(value, "POINTER")) { - return LogicalTypeId::POINTER; - } - if (StringUtil::Equals(value, "VALIDITY")) { - return LogicalTypeId::VALIDITY; - } - if (StringUtil::Equals(value, "UUID")) { - return LogicalTypeId::UUID; - } - if (StringUtil::Equals(value, "STRUCT")) { - return LogicalTypeId::STRUCT; - } - if (StringUtil::Equals(value, "LIST")) { - return LogicalTypeId::LIST; - } - if (StringUtil::Equals(value, "MAP")) { - return LogicalTypeId::MAP; - } - if (StringUtil::Equals(value, "TABLE")) { - return LogicalTypeId::TABLE; - } - if (StringUtil::Equals(value, "ENUM")) { - return LogicalTypeId::ENUM; - } - if (StringUtil::Equals(value, "AGGREGATE_STATE")) { - return LogicalTypeId::AGGREGATE_STATE; - } - if (StringUtil::Equals(value, "LAMBDA")) { - return LogicalTypeId::LAMBDA; - } - if (StringUtil::Equals(value, "UNION")) { - return LogicalTypeId::UNION; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(LookupResultType value) { - switch(value) { - case LookupResultType::LOOKUP_MISS: - return "LOOKUP_MISS"; - case LookupResultType::LOOKUP_HIT: - return "LOOKUP_HIT"; - case 
LookupResultType::LOOKUP_NULL: - return "LOOKUP_NULL"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -LookupResultType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "LOOKUP_MISS")) { - return LookupResultType::LOOKUP_MISS; - } - if (StringUtil::Equals(value, "LOOKUP_HIT")) { - return LookupResultType::LOOKUP_HIT; - } - if (StringUtil::Equals(value, "LOOKUP_NULL")) { - return LookupResultType::LOOKUP_NULL; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(MacroType value) { - switch(value) { - case MacroType::VOID_MACRO: - return "VOID_MACRO"; - case MacroType::TABLE_MACRO: - return "TABLE_MACRO"; - case MacroType::SCALAR_MACRO: - return "SCALAR_MACRO"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -MacroType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "VOID_MACRO")) { - return MacroType::VOID_MACRO; - } - if (StringUtil::Equals(value, "TABLE_MACRO")) { - return MacroType::TABLE_MACRO; - } - if (StringUtil::Equals(value, "SCALAR_MACRO")) { - return MacroType::SCALAR_MACRO; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(MapInvalidReason value) { - switch(value) { - case MapInvalidReason::VALID: - return "VALID"; - case MapInvalidReason::NULL_KEY_LIST: - return "NULL_KEY_LIST"; - case MapInvalidReason::NULL_KEY: - return "NULL_KEY"; - case MapInvalidReason::DUPLICATE_KEY: - return "DUPLICATE_KEY"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -MapInvalidReason EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "VALID")) { - return 
MapInvalidReason::VALID; - } - if (StringUtil::Equals(value, "NULL_KEY_LIST")) { - return MapInvalidReason::NULL_KEY_LIST; - } - if (StringUtil::Equals(value, "NULL_KEY")) { - return MapInvalidReason::NULL_KEY; - } - if (StringUtil::Equals(value, "DUPLICATE_KEY")) { - return MapInvalidReason::DUPLICATE_KEY; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(NType value) { - switch(value) { - case NType::PREFIX: - return "PREFIX"; - case NType::LEAF: - return "LEAF"; - case NType::NODE_4: - return "NODE_4"; - case NType::NODE_16: - return "NODE_16"; - case NType::NODE_48: - return "NODE_48"; - case NType::NODE_256: - return "NODE_256"; - case NType::LEAF_INLINED: - return "LEAF_INLINED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -NType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "PREFIX")) { - return NType::PREFIX; - } - if (StringUtil::Equals(value, "LEAF")) { - return NType::LEAF; - } - if (StringUtil::Equals(value, "NODE_4")) { - return NType::NODE_4; - } - if (StringUtil::Equals(value, "NODE_16")) { - return NType::NODE_16; - } - if (StringUtil::Equals(value, "NODE_48")) { - return NType::NODE_48; - } - if (StringUtil::Equals(value, "NODE_256")) { - return NType::NODE_256; - } - if (StringUtil::Equals(value, "LEAF_INLINED")) { - return NType::LEAF_INLINED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(NewLineIdentifier value) { - switch(value) { - case NewLineIdentifier::SINGLE: - return "SINGLE"; - case NewLineIdentifier::CARRY_ON: - return "CARRY_ON"; - case NewLineIdentifier::MIX: - return "MIX"; - case NewLineIdentifier::NOT_SET: - return "NOT_SET"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", 
value)); - } -} - -template<> -NewLineIdentifier EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "SINGLE")) { - return NewLineIdentifier::SINGLE; - } - if (StringUtil::Equals(value, "CARRY_ON")) { - return NewLineIdentifier::CARRY_ON; - } - if (StringUtil::Equals(value, "MIX")) { - return NewLineIdentifier::MIX; - } - if (StringUtil::Equals(value, "NOT_SET")) { - return NewLineIdentifier::NOT_SET; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(OnConflictAction value) { - switch(value) { - case OnConflictAction::THROW: - return "THROW"; - case OnConflictAction::NOTHING: - return "NOTHING"; - case OnConflictAction::UPDATE: - return "UPDATE"; - case OnConflictAction::REPLACE: - return "REPLACE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -OnConflictAction EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "THROW")) { - return OnConflictAction::THROW; - } - if (StringUtil::Equals(value, "NOTHING")) { - return OnConflictAction::NOTHING; - } - if (StringUtil::Equals(value, "UPDATE")) { - return OnConflictAction::UPDATE; - } - if (StringUtil::Equals(value, "REPLACE")) { - return OnConflictAction::REPLACE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(OnCreateConflict value) { - switch(value) { - case OnCreateConflict::ERROR_ON_CONFLICT: - return "ERROR_ON_CONFLICT"; - case OnCreateConflict::IGNORE_ON_CONFLICT: - return "IGNORE_ON_CONFLICT"; - case OnCreateConflict::REPLACE_ON_CONFLICT: - return "REPLACE_ON_CONFLICT"; - case OnCreateConflict::ALTER_ON_CONFLICT: - return "ALTER_ON_CONFLICT"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -OnCreateConflict 
EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "ERROR_ON_CONFLICT")) { - return OnCreateConflict::ERROR_ON_CONFLICT; - } - if (StringUtil::Equals(value, "IGNORE_ON_CONFLICT")) { - return OnCreateConflict::IGNORE_ON_CONFLICT; - } - if (StringUtil::Equals(value, "REPLACE_ON_CONFLICT")) { - return OnCreateConflict::REPLACE_ON_CONFLICT; - } - if (StringUtil::Equals(value, "ALTER_ON_CONFLICT")) { - return OnCreateConflict::ALTER_ON_CONFLICT; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(OnEntryNotFound value) { - switch(value) { - case OnEntryNotFound::THROW_EXCEPTION: - return "THROW_EXCEPTION"; - case OnEntryNotFound::RETURN_NULL: - return "RETURN_NULL"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -OnEntryNotFound EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "THROW_EXCEPTION")) { - return OnEntryNotFound::THROW_EXCEPTION; - } - if (StringUtil::Equals(value, "RETURN_NULL")) { - return OnEntryNotFound::RETURN_NULL; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(OperatorFinalizeResultType value) { - switch(value) { - case OperatorFinalizeResultType::HAVE_MORE_OUTPUT: - return "HAVE_MORE_OUTPUT"; - case OperatorFinalizeResultType::FINISHED: - return "FINISHED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -OperatorFinalizeResultType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "HAVE_MORE_OUTPUT")) { - return OperatorFinalizeResultType::HAVE_MORE_OUTPUT; - } - if (StringUtil::Equals(value, "FINISHED")) { - return OperatorFinalizeResultType::FINISHED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not 
implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(OperatorResultType value) { - switch(value) { - case OperatorResultType::NEED_MORE_INPUT: - return "NEED_MORE_INPUT"; - case OperatorResultType::HAVE_MORE_OUTPUT: - return "HAVE_MORE_OUTPUT"; - case OperatorResultType::FINISHED: - return "FINISHED"; - case OperatorResultType::BLOCKED: - return "BLOCKED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -OperatorResultType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NEED_MORE_INPUT")) { - return OperatorResultType::NEED_MORE_INPUT; - } - if (StringUtil::Equals(value, "HAVE_MORE_OUTPUT")) { - return OperatorResultType::HAVE_MORE_OUTPUT; - } - if (StringUtil::Equals(value, "FINISHED")) { - return OperatorResultType::FINISHED; - } - if (StringUtil::Equals(value, "BLOCKED")) { - return OperatorResultType::BLOCKED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(OptimizerType value) { - switch(value) { - case OptimizerType::INVALID: - return "INVALID"; - case OptimizerType::EXPRESSION_REWRITER: - return "EXPRESSION_REWRITER"; - case OptimizerType::FILTER_PULLUP: - return "FILTER_PULLUP"; - case OptimizerType::FILTER_PUSHDOWN: - return "FILTER_PUSHDOWN"; - case OptimizerType::REGEX_RANGE: - return "REGEX_RANGE"; - case OptimizerType::IN_CLAUSE: - return "IN_CLAUSE"; - case OptimizerType::JOIN_ORDER: - return "JOIN_ORDER"; - case OptimizerType::DELIMINATOR: - return "DELIMINATOR"; - case OptimizerType::UNNEST_REWRITER: - return "UNNEST_REWRITER"; - case OptimizerType::UNUSED_COLUMNS: - return "UNUSED_COLUMNS"; - case OptimizerType::STATISTICS_PROPAGATION: - return "STATISTICS_PROPAGATION"; - case OptimizerType::COMMON_SUBEXPRESSIONS: - return "COMMON_SUBEXPRESSIONS"; - case OptimizerType::COMMON_AGGREGATE: - return "COMMON_AGGREGATE"; - case 
OptimizerType::COLUMN_LIFETIME: - return "COLUMN_LIFETIME"; - case OptimizerType::TOP_N: - return "TOP_N"; - case OptimizerType::COMPRESSED_MATERIALIZATION: - return "COMPRESSED_MATERIALIZATION"; - case OptimizerType::DUPLICATE_GROUPS: - return "DUPLICATE_GROUPS"; - case OptimizerType::REORDER_FILTER: - return "REORDER_FILTER"; - case OptimizerType::EXTENSION: - return "EXTENSION"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -OptimizerType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return OptimizerType::INVALID; - } - if (StringUtil::Equals(value, "EXPRESSION_REWRITER")) { - return OptimizerType::EXPRESSION_REWRITER; - } - if (StringUtil::Equals(value, "FILTER_PULLUP")) { - return OptimizerType::FILTER_PULLUP; - } - if (StringUtil::Equals(value, "FILTER_PUSHDOWN")) { - return OptimizerType::FILTER_PUSHDOWN; - } - if (StringUtil::Equals(value, "REGEX_RANGE")) { - return OptimizerType::REGEX_RANGE; - } - if (StringUtil::Equals(value, "IN_CLAUSE")) { - return OptimizerType::IN_CLAUSE; - } - if (StringUtil::Equals(value, "JOIN_ORDER")) { - return OptimizerType::JOIN_ORDER; - } - if (StringUtil::Equals(value, "DELIMINATOR")) { - return OptimizerType::DELIMINATOR; - } - if (StringUtil::Equals(value, "UNNEST_REWRITER")) { - return OptimizerType::UNNEST_REWRITER; - } - if (StringUtil::Equals(value, "UNUSED_COLUMNS")) { - return OptimizerType::UNUSED_COLUMNS; - } - if (StringUtil::Equals(value, "STATISTICS_PROPAGATION")) { - return OptimizerType::STATISTICS_PROPAGATION; - } - if (StringUtil::Equals(value, "COMMON_SUBEXPRESSIONS")) { - return OptimizerType::COMMON_SUBEXPRESSIONS; - } - if (StringUtil::Equals(value, "COMMON_AGGREGATE")) { - return OptimizerType::COMMON_AGGREGATE; - } - if (StringUtil::Equals(value, "COLUMN_LIFETIME")) { - return OptimizerType::COLUMN_LIFETIME; - } - if (StringUtil::Equals(value, "TOP_N")) { - return 
OptimizerType::TOP_N; - } - if (StringUtil::Equals(value, "COMPRESSED_MATERIALIZATION")) { - return OptimizerType::COMPRESSED_MATERIALIZATION; - } - if (StringUtil::Equals(value, "DUPLICATE_GROUPS")) { - return OptimizerType::DUPLICATE_GROUPS; - } - if (StringUtil::Equals(value, "REORDER_FILTER")) { - return OptimizerType::REORDER_FILTER; - } - if (StringUtil::Equals(value, "EXTENSION")) { - return OptimizerType::EXTENSION; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(OrderByNullType value) { - switch(value) { - case OrderByNullType::INVALID: - return "INVALID"; - case OrderByNullType::ORDER_DEFAULT: - return "ORDER_DEFAULT"; - case OrderByNullType::NULLS_FIRST: - return "NULLS_FIRST"; - case OrderByNullType::NULLS_LAST: - return "NULLS_LAST"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -OrderByNullType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return OrderByNullType::INVALID; - } - if (StringUtil::Equals(value, "ORDER_DEFAULT") || StringUtil::Equals(value, "DEFAULT")) { - return OrderByNullType::ORDER_DEFAULT; - } - if (StringUtil::Equals(value, "NULLS_FIRST") || StringUtil::Equals(value, "NULLS FIRST")) { - return OrderByNullType::NULLS_FIRST; - } - if (StringUtil::Equals(value, "NULLS_LAST") || StringUtil::Equals(value, "NULLS LAST")) { - return OrderByNullType::NULLS_LAST; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(OrderPreservationType value) { - switch(value) { - case OrderPreservationType::NO_ORDER: - return "NO_ORDER"; - case OrderPreservationType::INSERTION_ORDER: - return "INSERTION_ORDER"; - case OrderPreservationType::FIXED_ORDER: - return "FIXED_ORDER"; - default: - throw 
NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -OrderPreservationType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NO_ORDER")) { - return OrderPreservationType::NO_ORDER; - } - if (StringUtil::Equals(value, "INSERTION_ORDER")) { - return OrderPreservationType::INSERTION_ORDER; - } - if (StringUtil::Equals(value, "FIXED_ORDER")) { - return OrderPreservationType::FIXED_ORDER; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(OrderType value) { - switch(value) { - case OrderType::INVALID: - return "INVALID"; - case OrderType::ORDER_DEFAULT: - return "ORDER_DEFAULT"; - case OrderType::ASCENDING: - return "ASCENDING"; - case OrderType::DESCENDING: - return "DESCENDING"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -OrderType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return OrderType::INVALID; - } - if (StringUtil::Equals(value, "ORDER_DEFAULT") || StringUtil::Equals(value, "DEFAULT")) { - return OrderType::ORDER_DEFAULT; - } - if (StringUtil::Equals(value, "ASCENDING") || StringUtil::Equals(value, "ASC")) { - return OrderType::ASCENDING; - } - if (StringUtil::Equals(value, "DESCENDING") || StringUtil::Equals(value, "DESC")) { - return OrderType::DESCENDING; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(OutputStream value) { - switch(value) { - case OutputStream::STREAM_STDOUT: - return "STREAM_STDOUT"; - case OutputStream::STREAM_STDERR: - return "STREAM_STDERR"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -OutputStream EnumUtil::FromString(const char *value) { - if 
(StringUtil::Equals(value, "STREAM_STDOUT")) { - return OutputStream::STREAM_STDOUT; - } - if (StringUtil::Equals(value, "STREAM_STDERR")) { - return OutputStream::STREAM_STDERR; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ParseInfoType value) { - switch(value) { - case ParseInfoType::ALTER_INFO: - return "ALTER_INFO"; - case ParseInfoType::ATTACH_INFO: - return "ATTACH_INFO"; - case ParseInfoType::COPY_INFO: - return "COPY_INFO"; - case ParseInfoType::CREATE_INFO: - return "CREATE_INFO"; - case ParseInfoType::DETACH_INFO: - return "DETACH_INFO"; - case ParseInfoType::DROP_INFO: - return "DROP_INFO"; - case ParseInfoType::BOUND_EXPORT_DATA: - return "BOUND_EXPORT_DATA"; - case ParseInfoType::LOAD_INFO: - return "LOAD_INFO"; - case ParseInfoType::PRAGMA_INFO: - return "PRAGMA_INFO"; - case ParseInfoType::SHOW_SELECT_INFO: - return "SHOW_SELECT_INFO"; - case ParseInfoType::TRANSACTION_INFO: - return "TRANSACTION_INFO"; - case ParseInfoType::VACUUM_INFO: - return "VACUUM_INFO"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ParseInfoType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "ALTER_INFO")) { - return ParseInfoType::ALTER_INFO; - } - if (StringUtil::Equals(value, "ATTACH_INFO")) { - return ParseInfoType::ATTACH_INFO; - } - if (StringUtil::Equals(value, "COPY_INFO")) { - return ParseInfoType::COPY_INFO; - } - if (StringUtil::Equals(value, "CREATE_INFO")) { - return ParseInfoType::CREATE_INFO; - } - if (StringUtil::Equals(value, "DETACH_INFO")) { - return ParseInfoType::DETACH_INFO; - } - if (StringUtil::Equals(value, "DROP_INFO")) { - return ParseInfoType::DROP_INFO; - } - if (StringUtil::Equals(value, "BOUND_EXPORT_DATA")) { - return ParseInfoType::BOUND_EXPORT_DATA; - } - if (StringUtil::Equals(value, "LOAD_INFO")) { - return 
ParseInfoType::LOAD_INFO; - } - if (StringUtil::Equals(value, "PRAGMA_INFO")) { - return ParseInfoType::PRAGMA_INFO; - } - if (StringUtil::Equals(value, "SHOW_SELECT_INFO")) { - return ParseInfoType::SHOW_SELECT_INFO; - } - if (StringUtil::Equals(value, "TRANSACTION_INFO")) { - return ParseInfoType::TRANSACTION_INFO; - } - if (StringUtil::Equals(value, "VACUUM_INFO")) { - return ParseInfoType::VACUUM_INFO; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ParserExtensionResultType value) { - switch(value) { - case ParserExtensionResultType::PARSE_SUCCESSFUL: - return "PARSE_SUCCESSFUL"; - case ParserExtensionResultType::DISPLAY_ORIGINAL_ERROR: - return "DISPLAY_ORIGINAL_ERROR"; - case ParserExtensionResultType::DISPLAY_EXTENSION_ERROR: - return "DISPLAY_EXTENSION_ERROR"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ParserExtensionResultType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "PARSE_SUCCESSFUL")) { - return ParserExtensionResultType::PARSE_SUCCESSFUL; - } - if (StringUtil::Equals(value, "DISPLAY_ORIGINAL_ERROR")) { - return ParserExtensionResultType::DISPLAY_ORIGINAL_ERROR; - } - if (StringUtil::Equals(value, "DISPLAY_EXTENSION_ERROR")) { - return ParserExtensionResultType::DISPLAY_EXTENSION_ERROR; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ParserMode value) { - switch(value) { - case ParserMode::PARSING: - return "PARSING"; - case ParserMode::SNIFFING_DATATYPES: - return "SNIFFING_DATATYPES"; - case ParserMode::PARSING_HEADER: - return "PARSING_HEADER"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ParserMode EnumUtil::FromString(const char *value) { - if 
(StringUtil::Equals(value, "PARSING")) { - return ParserMode::PARSING; - } - if (StringUtil::Equals(value, "SNIFFING_DATATYPES")) { - return ParserMode::SNIFFING_DATATYPES; - } - if (StringUtil::Equals(value, "PARSING_HEADER")) { - return ParserMode::PARSING_HEADER; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(PartitionSortStage value) { - switch(value) { - case PartitionSortStage::INIT: - return "INIT"; - case PartitionSortStage::SCAN: - return "SCAN"; - case PartitionSortStage::PREPARE: - return "PREPARE"; - case PartitionSortStage::MERGE: - return "MERGE"; - case PartitionSortStage::SORTED: - return "SORTED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -PartitionSortStage EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INIT")) { - return PartitionSortStage::INIT; - } - if (StringUtil::Equals(value, "SCAN")) { - return PartitionSortStage::SCAN; - } - if (StringUtil::Equals(value, "PREPARE")) { - return PartitionSortStage::PREPARE; - } - if (StringUtil::Equals(value, "MERGE")) { - return PartitionSortStage::MERGE; - } - if (StringUtil::Equals(value, "SORTED")) { - return PartitionSortStage::SORTED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(PartitionedColumnDataType value) { - switch(value) { - case PartitionedColumnDataType::INVALID: - return "INVALID"; - case PartitionedColumnDataType::RADIX: - return "RADIX"; - case PartitionedColumnDataType::HIVE: - return "HIVE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -PartitionedColumnDataType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return PartitionedColumnDataType::INVALID; 
- } - if (StringUtil::Equals(value, "RADIX")) { - return PartitionedColumnDataType::RADIX; - } - if (StringUtil::Equals(value, "HIVE")) { - return PartitionedColumnDataType::HIVE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(PartitionedTupleDataType value) { - switch(value) { - case PartitionedTupleDataType::INVALID: - return "INVALID"; - case PartitionedTupleDataType::RADIX: - return "RADIX"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -PartitionedTupleDataType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return PartitionedTupleDataType::INVALID; - } - if (StringUtil::Equals(value, "RADIX")) { - return PartitionedTupleDataType::RADIX; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(PendingExecutionResult value) { - switch(value) { - case PendingExecutionResult::RESULT_READY: - return "RESULT_READY"; - case PendingExecutionResult::RESULT_NOT_READY: - return "RESULT_NOT_READY"; - case PendingExecutionResult::EXECUTION_ERROR: - return "EXECUTION_ERROR"; - case PendingExecutionResult::NO_TASKS_AVAILABLE: - return "NO_TASKS_AVAILABLE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -PendingExecutionResult EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "RESULT_READY")) { - return PendingExecutionResult::RESULT_READY; - } - if (StringUtil::Equals(value, "RESULT_NOT_READY")) { - return PendingExecutionResult::RESULT_NOT_READY; - } - if (StringUtil::Equals(value, "EXECUTION_ERROR")) { - return PendingExecutionResult::EXECUTION_ERROR; - } - if (StringUtil::Equals(value, "NO_TASKS_AVAILABLE")) { - return 
PendingExecutionResult::NO_TASKS_AVAILABLE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(PhysicalOperatorType value) { - switch(value) { - case PhysicalOperatorType::INVALID: - return "INVALID"; - case PhysicalOperatorType::ORDER_BY: - return "ORDER_BY"; - case PhysicalOperatorType::LIMIT: - return "LIMIT"; - case PhysicalOperatorType::STREAMING_LIMIT: - return "STREAMING_LIMIT"; - case PhysicalOperatorType::LIMIT_PERCENT: - return "LIMIT_PERCENT"; - case PhysicalOperatorType::TOP_N: - return "TOP_N"; - case PhysicalOperatorType::WINDOW: - return "WINDOW"; - case PhysicalOperatorType::UNNEST: - return "UNNEST"; - case PhysicalOperatorType::UNGROUPED_AGGREGATE: - return "UNGROUPED_AGGREGATE"; - case PhysicalOperatorType::HASH_GROUP_BY: - return "HASH_GROUP_BY"; - case PhysicalOperatorType::PERFECT_HASH_GROUP_BY: - return "PERFECT_HASH_GROUP_BY"; - case PhysicalOperatorType::FILTER: - return "FILTER"; - case PhysicalOperatorType::PROJECTION: - return "PROJECTION"; - case PhysicalOperatorType::COPY_TO_FILE: - return "COPY_TO_FILE"; - case PhysicalOperatorType::BATCH_COPY_TO_FILE: - return "BATCH_COPY_TO_FILE"; - case PhysicalOperatorType::FIXED_BATCH_COPY_TO_FILE: - return "FIXED_BATCH_COPY_TO_FILE"; - case PhysicalOperatorType::RESERVOIR_SAMPLE: - return "RESERVOIR_SAMPLE"; - case PhysicalOperatorType::STREAMING_SAMPLE: - return "STREAMING_SAMPLE"; - case PhysicalOperatorType::STREAMING_WINDOW: - return "STREAMING_WINDOW"; - case PhysicalOperatorType::PIVOT: - return "PIVOT"; - case PhysicalOperatorType::TABLE_SCAN: - return "TABLE_SCAN"; - case PhysicalOperatorType::DUMMY_SCAN: - return "DUMMY_SCAN"; - case PhysicalOperatorType::COLUMN_DATA_SCAN: - return "COLUMN_DATA_SCAN"; - case PhysicalOperatorType::CHUNK_SCAN: - return "CHUNK_SCAN"; - case PhysicalOperatorType::RECURSIVE_CTE_SCAN: - return "RECURSIVE_CTE_SCAN"; - case PhysicalOperatorType::CTE_SCAN: - return 
"CTE_SCAN"; - case PhysicalOperatorType::DELIM_SCAN: - return "DELIM_SCAN"; - case PhysicalOperatorType::EXPRESSION_SCAN: - return "EXPRESSION_SCAN"; - case PhysicalOperatorType::POSITIONAL_SCAN: - return "POSITIONAL_SCAN"; - case PhysicalOperatorType::BLOCKWISE_NL_JOIN: - return "BLOCKWISE_NL_JOIN"; - case PhysicalOperatorType::NESTED_LOOP_JOIN: - return "NESTED_LOOP_JOIN"; - case PhysicalOperatorType::HASH_JOIN: - return "HASH_JOIN"; - case PhysicalOperatorType::CROSS_PRODUCT: - return "CROSS_PRODUCT"; - case PhysicalOperatorType::PIECEWISE_MERGE_JOIN: - return "PIECEWISE_MERGE_JOIN"; - case PhysicalOperatorType::IE_JOIN: - return "IE_JOIN"; - case PhysicalOperatorType::DELIM_JOIN: - return "DELIM_JOIN"; - case PhysicalOperatorType::INDEX_JOIN: - return "INDEX_JOIN"; - case PhysicalOperatorType::POSITIONAL_JOIN: - return "POSITIONAL_JOIN"; - case PhysicalOperatorType::ASOF_JOIN: - return "ASOF_JOIN"; - case PhysicalOperatorType::UNION: - return "UNION"; - case PhysicalOperatorType::RECURSIVE_CTE: - return "RECURSIVE_CTE"; - case PhysicalOperatorType::CTE: - return "CTE"; - case PhysicalOperatorType::INSERT: - return "INSERT"; - case PhysicalOperatorType::BATCH_INSERT: - return "BATCH_INSERT"; - case PhysicalOperatorType::DELETE_OPERATOR: - return "DELETE_OPERATOR"; - case PhysicalOperatorType::UPDATE: - return "UPDATE"; - case PhysicalOperatorType::CREATE_TABLE: - return "CREATE_TABLE"; - case PhysicalOperatorType::CREATE_TABLE_AS: - return "CREATE_TABLE_AS"; - case PhysicalOperatorType::BATCH_CREATE_TABLE_AS: - return "BATCH_CREATE_TABLE_AS"; - case PhysicalOperatorType::CREATE_INDEX: - return "CREATE_INDEX"; - case PhysicalOperatorType::ALTER: - return "ALTER"; - case PhysicalOperatorType::CREATE_SEQUENCE: - return "CREATE_SEQUENCE"; - case PhysicalOperatorType::CREATE_VIEW: - return "CREATE_VIEW"; - case PhysicalOperatorType::CREATE_SCHEMA: - return "CREATE_SCHEMA"; - case PhysicalOperatorType::CREATE_MACRO: - return "CREATE_MACRO"; - case 
PhysicalOperatorType::DROP: - return "DROP"; - case PhysicalOperatorType::PRAGMA: - return "PRAGMA"; - case PhysicalOperatorType::TRANSACTION: - return "TRANSACTION"; - case PhysicalOperatorType::CREATE_TYPE: - return "CREATE_TYPE"; - case PhysicalOperatorType::ATTACH: - return "ATTACH"; - case PhysicalOperatorType::DETACH: - return "DETACH"; - case PhysicalOperatorType::EXPLAIN: - return "EXPLAIN"; - case PhysicalOperatorType::EXPLAIN_ANALYZE: - return "EXPLAIN_ANALYZE"; - case PhysicalOperatorType::EMPTY_RESULT: - return "EMPTY_RESULT"; - case PhysicalOperatorType::EXECUTE: - return "EXECUTE"; - case PhysicalOperatorType::PREPARE: - return "PREPARE"; - case PhysicalOperatorType::VACUUM: - return "VACUUM"; - case PhysicalOperatorType::EXPORT: - return "EXPORT"; - case PhysicalOperatorType::SET: - return "SET"; - case PhysicalOperatorType::LOAD: - return "LOAD"; - case PhysicalOperatorType::INOUT_FUNCTION: - return "INOUT_FUNCTION"; - case PhysicalOperatorType::RESULT_COLLECTOR: - return "RESULT_COLLECTOR"; - case PhysicalOperatorType::RESET: - return "RESET"; - case PhysicalOperatorType::EXTENSION: - return "EXTENSION"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -PhysicalOperatorType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return PhysicalOperatorType::INVALID; - } - if (StringUtil::Equals(value, "ORDER_BY")) { - return PhysicalOperatorType::ORDER_BY; - } - if (StringUtil::Equals(value, "LIMIT")) { - return PhysicalOperatorType::LIMIT; - } - if (StringUtil::Equals(value, "STREAMING_LIMIT")) { - return PhysicalOperatorType::STREAMING_LIMIT; - } - if (StringUtil::Equals(value, "LIMIT_PERCENT")) { - return PhysicalOperatorType::LIMIT_PERCENT; - } - if (StringUtil::Equals(value, "TOP_N")) { - return PhysicalOperatorType::TOP_N; - } - if (StringUtil::Equals(value, "WINDOW")) { - return PhysicalOperatorType::WINDOW; - } - if 
(StringUtil::Equals(value, "UNNEST")) { - return PhysicalOperatorType::UNNEST; - } - if (StringUtil::Equals(value, "UNGROUPED_AGGREGATE")) { - return PhysicalOperatorType::UNGROUPED_AGGREGATE; - } - if (StringUtil::Equals(value, "HASH_GROUP_BY")) { - return PhysicalOperatorType::HASH_GROUP_BY; - } - if (StringUtil::Equals(value, "PERFECT_HASH_GROUP_BY")) { - return PhysicalOperatorType::PERFECT_HASH_GROUP_BY; - } - if (StringUtil::Equals(value, "FILTER")) { - return PhysicalOperatorType::FILTER; - } - if (StringUtil::Equals(value, "PROJECTION")) { - return PhysicalOperatorType::PROJECTION; - } - if (StringUtil::Equals(value, "COPY_TO_FILE")) { - return PhysicalOperatorType::COPY_TO_FILE; - } - if (StringUtil::Equals(value, "BATCH_COPY_TO_FILE")) { - return PhysicalOperatorType::BATCH_COPY_TO_FILE; - } - if (StringUtil::Equals(value, "FIXED_BATCH_COPY_TO_FILE")) { - return PhysicalOperatorType::FIXED_BATCH_COPY_TO_FILE; - } - if (StringUtil::Equals(value, "RESERVOIR_SAMPLE")) { - return PhysicalOperatorType::RESERVOIR_SAMPLE; - } - if (StringUtil::Equals(value, "STREAMING_SAMPLE")) { - return PhysicalOperatorType::STREAMING_SAMPLE; - } - if (StringUtil::Equals(value, "STREAMING_WINDOW")) { - return PhysicalOperatorType::STREAMING_WINDOW; - } - if (StringUtil::Equals(value, "PIVOT")) { - return PhysicalOperatorType::PIVOT; - } - if (StringUtil::Equals(value, "TABLE_SCAN")) { - return PhysicalOperatorType::TABLE_SCAN; - } - if (StringUtil::Equals(value, "DUMMY_SCAN")) { - return PhysicalOperatorType::DUMMY_SCAN; - } - if (StringUtil::Equals(value, "COLUMN_DATA_SCAN")) { - return PhysicalOperatorType::COLUMN_DATA_SCAN; - } - if (StringUtil::Equals(value, "CHUNK_SCAN")) { - return PhysicalOperatorType::CHUNK_SCAN; - } - if (StringUtil::Equals(value, "RECURSIVE_CTE_SCAN")) { - return PhysicalOperatorType::RECURSIVE_CTE_SCAN; - } - if (StringUtil::Equals(value, "CTE_SCAN")) { - return PhysicalOperatorType::CTE_SCAN; - } - if (StringUtil::Equals(value, "DELIM_SCAN")) { - 
return PhysicalOperatorType::DELIM_SCAN; - } - if (StringUtil::Equals(value, "EXPRESSION_SCAN")) { - return PhysicalOperatorType::EXPRESSION_SCAN; - } - if (StringUtil::Equals(value, "POSITIONAL_SCAN")) { - return PhysicalOperatorType::POSITIONAL_SCAN; - } - if (StringUtil::Equals(value, "BLOCKWISE_NL_JOIN")) { - return PhysicalOperatorType::BLOCKWISE_NL_JOIN; - } - if (StringUtil::Equals(value, "NESTED_LOOP_JOIN")) { - return PhysicalOperatorType::NESTED_LOOP_JOIN; - } - if (StringUtil::Equals(value, "HASH_JOIN")) { - return PhysicalOperatorType::HASH_JOIN; - } - if (StringUtil::Equals(value, "CROSS_PRODUCT")) { - return PhysicalOperatorType::CROSS_PRODUCT; - } - if (StringUtil::Equals(value, "PIECEWISE_MERGE_JOIN")) { - return PhysicalOperatorType::PIECEWISE_MERGE_JOIN; - } - if (StringUtil::Equals(value, "IE_JOIN")) { - return PhysicalOperatorType::IE_JOIN; - } - if (StringUtil::Equals(value, "DELIM_JOIN")) { - return PhysicalOperatorType::DELIM_JOIN; - } - if (StringUtil::Equals(value, "INDEX_JOIN")) { - return PhysicalOperatorType::INDEX_JOIN; - } - if (StringUtil::Equals(value, "POSITIONAL_JOIN")) { - return PhysicalOperatorType::POSITIONAL_JOIN; - } - if (StringUtil::Equals(value, "ASOF_JOIN")) { - return PhysicalOperatorType::ASOF_JOIN; - } - if (StringUtil::Equals(value, "UNION")) { - return PhysicalOperatorType::UNION; - } - if (StringUtil::Equals(value, "RECURSIVE_CTE")) { - return PhysicalOperatorType::RECURSIVE_CTE; - } - if (StringUtil::Equals(value, "CTE")) { - return PhysicalOperatorType::CTE; - } - if (StringUtil::Equals(value, "INSERT")) { - return PhysicalOperatorType::INSERT; - } - if (StringUtil::Equals(value, "BATCH_INSERT")) { - return PhysicalOperatorType::BATCH_INSERT; - } - if (StringUtil::Equals(value, "DELETE_OPERATOR")) { - return PhysicalOperatorType::DELETE_OPERATOR; - } - if (StringUtil::Equals(value, "UPDATE")) { - return PhysicalOperatorType::UPDATE; - } - if (StringUtil::Equals(value, "CREATE_TABLE")) { - return 
PhysicalOperatorType::CREATE_TABLE; - } - if (StringUtil::Equals(value, "CREATE_TABLE_AS")) { - return PhysicalOperatorType::CREATE_TABLE_AS; - } - if (StringUtil::Equals(value, "BATCH_CREATE_TABLE_AS")) { - return PhysicalOperatorType::BATCH_CREATE_TABLE_AS; - } - if (StringUtil::Equals(value, "CREATE_INDEX")) { - return PhysicalOperatorType::CREATE_INDEX; - } - if (StringUtil::Equals(value, "ALTER")) { - return PhysicalOperatorType::ALTER; - } - if (StringUtil::Equals(value, "CREATE_SEQUENCE")) { - return PhysicalOperatorType::CREATE_SEQUENCE; - } - if (StringUtil::Equals(value, "CREATE_VIEW")) { - return PhysicalOperatorType::CREATE_VIEW; - } - if (StringUtil::Equals(value, "CREATE_SCHEMA")) { - return PhysicalOperatorType::CREATE_SCHEMA; - } - if (StringUtil::Equals(value, "CREATE_MACRO")) { - return PhysicalOperatorType::CREATE_MACRO; - } - if (StringUtil::Equals(value, "DROP")) { - return PhysicalOperatorType::DROP; - } - if (StringUtil::Equals(value, "PRAGMA")) { - return PhysicalOperatorType::PRAGMA; - } - if (StringUtil::Equals(value, "TRANSACTION")) { - return PhysicalOperatorType::TRANSACTION; - } - if (StringUtil::Equals(value, "CREATE_TYPE")) { - return PhysicalOperatorType::CREATE_TYPE; - } - if (StringUtil::Equals(value, "ATTACH")) { - return PhysicalOperatorType::ATTACH; - } - if (StringUtil::Equals(value, "DETACH")) { - return PhysicalOperatorType::DETACH; - } - if (StringUtil::Equals(value, "EXPLAIN")) { - return PhysicalOperatorType::EXPLAIN; - } - if (StringUtil::Equals(value, "EXPLAIN_ANALYZE")) { - return PhysicalOperatorType::EXPLAIN_ANALYZE; - } - if (StringUtil::Equals(value, "EMPTY_RESULT")) { - return PhysicalOperatorType::EMPTY_RESULT; - } - if (StringUtil::Equals(value, "EXECUTE")) { - return PhysicalOperatorType::EXECUTE; - } - if (StringUtil::Equals(value, "PREPARE")) { - return PhysicalOperatorType::PREPARE; - } - if (StringUtil::Equals(value, "VACUUM")) { - return PhysicalOperatorType::VACUUM; - } - if (StringUtil::Equals(value, 
"EXPORT")) { - return PhysicalOperatorType::EXPORT; - } - if (StringUtil::Equals(value, "SET")) { - return PhysicalOperatorType::SET; - } - if (StringUtil::Equals(value, "LOAD")) { - return PhysicalOperatorType::LOAD; - } - if (StringUtil::Equals(value, "INOUT_FUNCTION")) { - return PhysicalOperatorType::INOUT_FUNCTION; - } - if (StringUtil::Equals(value, "RESULT_COLLECTOR")) { - return PhysicalOperatorType::RESULT_COLLECTOR; - } - if (StringUtil::Equals(value, "RESET")) { - return PhysicalOperatorType::RESET; - } - if (StringUtil::Equals(value, "EXTENSION")) { - return PhysicalOperatorType::EXTENSION; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(PhysicalType value) { - switch(value) { - case PhysicalType::BOOL: - return "BOOL"; - case PhysicalType::UINT8: - return "UINT8"; - case PhysicalType::INT8: - return "INT8"; - case PhysicalType::UINT16: - return "UINT16"; - case PhysicalType::INT16: - return "INT16"; - case PhysicalType::UINT32: - return "UINT32"; - case PhysicalType::INT32: - return "INT32"; - case PhysicalType::UINT64: - return "UINT64"; - case PhysicalType::INT64: - return "INT64"; - case PhysicalType::FLOAT: - return "FLOAT"; - case PhysicalType::DOUBLE: - return "DOUBLE"; - case PhysicalType::INTERVAL: - return "INTERVAL"; - case PhysicalType::LIST: - return "LIST"; - case PhysicalType::STRUCT: - return "STRUCT"; - case PhysicalType::VARCHAR: - return "VARCHAR"; - case PhysicalType::INT128: - return "INT128"; - case PhysicalType::UNKNOWN: - return "UNKNOWN"; - case PhysicalType::BIT: - return "BIT"; - case PhysicalType::INVALID: - return "INVALID"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -PhysicalType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "BOOL")) { - return PhysicalType::BOOL; - } - if (StringUtil::Equals(value, "UINT8")) { - 
return PhysicalType::UINT8; - } - if (StringUtil::Equals(value, "INT8")) { - return PhysicalType::INT8; - } - if (StringUtil::Equals(value, "UINT16")) { - return PhysicalType::UINT16; - } - if (StringUtil::Equals(value, "INT16")) { - return PhysicalType::INT16; - } - if (StringUtil::Equals(value, "UINT32")) { - return PhysicalType::UINT32; - } - if (StringUtil::Equals(value, "INT32")) { - return PhysicalType::INT32; - } - if (StringUtil::Equals(value, "UINT64")) { - return PhysicalType::UINT64; - } - if (StringUtil::Equals(value, "INT64")) { - return PhysicalType::INT64; - } - if (StringUtil::Equals(value, "FLOAT")) { - return PhysicalType::FLOAT; - } - if (StringUtil::Equals(value, "DOUBLE")) { - return PhysicalType::DOUBLE; - } - if (StringUtil::Equals(value, "INTERVAL")) { - return PhysicalType::INTERVAL; - } - if (StringUtil::Equals(value, "LIST")) { - return PhysicalType::LIST; - } - if (StringUtil::Equals(value, "STRUCT")) { - return PhysicalType::STRUCT; - } - if (StringUtil::Equals(value, "VARCHAR")) { - return PhysicalType::VARCHAR; - } - if (StringUtil::Equals(value, "INT128")) { - return PhysicalType::INT128; - } - if (StringUtil::Equals(value, "UNKNOWN")) { - return PhysicalType::UNKNOWN; - } - if (StringUtil::Equals(value, "BIT")) { - return PhysicalType::BIT; - } - if (StringUtil::Equals(value, "INVALID")) { - return PhysicalType::INVALID; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(PragmaType value) { - switch(value) { - case PragmaType::PRAGMA_STATEMENT: - return "PRAGMA_STATEMENT"; - case PragmaType::PRAGMA_CALL: - return "PRAGMA_CALL"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -PragmaType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "PRAGMA_STATEMENT")) { - return PragmaType::PRAGMA_STATEMENT; - } - if (StringUtil::Equals(value, 
"PRAGMA_CALL")) { - return PragmaType::PRAGMA_CALL; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(PreparedParamType value) { - switch(value) { - case PreparedParamType::AUTO_INCREMENT: - return "AUTO_INCREMENT"; - case PreparedParamType::POSITIONAL: - return "POSITIONAL"; - case PreparedParamType::NAMED: - return "NAMED"; - case PreparedParamType::INVALID: - return "INVALID"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -PreparedParamType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "AUTO_INCREMENT")) { - return PreparedParamType::AUTO_INCREMENT; - } - if (StringUtil::Equals(value, "POSITIONAL")) { - return PreparedParamType::POSITIONAL; - } - if (StringUtil::Equals(value, "NAMED")) { - return PreparedParamType::NAMED; - } - if (StringUtil::Equals(value, "INVALID")) { - return PreparedParamType::INVALID; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ProfilerPrintFormat value) { - switch(value) { - case ProfilerPrintFormat::QUERY_TREE: - return "QUERY_TREE"; - case ProfilerPrintFormat::JSON: - return "JSON"; - case ProfilerPrintFormat::QUERY_TREE_OPTIMIZER: - return "QUERY_TREE_OPTIMIZER"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ProfilerPrintFormat EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "QUERY_TREE")) { - return ProfilerPrintFormat::QUERY_TREE; - } - if (StringUtil::Equals(value, "JSON")) { - return ProfilerPrintFormat::JSON; - } - if (StringUtil::Equals(value, "QUERY_TREE_OPTIMIZER")) { - return ProfilerPrintFormat::QUERY_TREE_OPTIMIZER; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", 
value)); -} - -template<> -const char* EnumUtil::ToChars(QueryNodeType value) { - switch(value) { - case QueryNodeType::SELECT_NODE: - return "SELECT_NODE"; - case QueryNodeType::SET_OPERATION_NODE: - return "SET_OPERATION_NODE"; - case QueryNodeType::BOUND_SUBQUERY_NODE: - return "BOUND_SUBQUERY_NODE"; - case QueryNodeType::RECURSIVE_CTE_NODE: - return "RECURSIVE_CTE_NODE"; - case QueryNodeType::CTE_NODE: - return "CTE_NODE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -QueryNodeType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "SELECT_NODE")) { - return QueryNodeType::SELECT_NODE; - } - if (StringUtil::Equals(value, "SET_OPERATION_NODE")) { - return QueryNodeType::SET_OPERATION_NODE; - } - if (StringUtil::Equals(value, "BOUND_SUBQUERY_NODE")) { - return QueryNodeType::BOUND_SUBQUERY_NODE; - } - if (StringUtil::Equals(value, "RECURSIVE_CTE_NODE")) { - return QueryNodeType::RECURSIVE_CTE_NODE; - } - if (StringUtil::Equals(value, "CTE_NODE")) { - return QueryNodeType::CTE_NODE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(QueryResultType value) { - switch(value) { - case QueryResultType::MATERIALIZED_RESULT: - return "MATERIALIZED_RESULT"; - case QueryResultType::STREAM_RESULT: - return "STREAM_RESULT"; - case QueryResultType::PENDING_RESULT: - return "PENDING_RESULT"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -QueryResultType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "MATERIALIZED_RESULT")) { - return QueryResultType::MATERIALIZED_RESULT; - } - if (StringUtil::Equals(value, "STREAM_RESULT")) { - return QueryResultType::STREAM_RESULT; - } - if (StringUtil::Equals(value, "PENDING_RESULT")) { - return QueryResultType::PENDING_RESULT; - } - 
throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(QuoteRule value) { - switch(value) { - case QuoteRule::QUOTES_RFC: - return "QUOTES_RFC"; - case QuoteRule::QUOTES_OTHER: - return "QUOTES_OTHER"; - case QuoteRule::NO_QUOTES: - return "NO_QUOTES"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -QuoteRule EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "QUOTES_RFC")) { - return QuoteRule::QUOTES_RFC; - } - if (StringUtil::Equals(value, "QUOTES_OTHER")) { - return QuoteRule::QUOTES_OTHER; - } - if (StringUtil::Equals(value, "NO_QUOTES")) { - return QuoteRule::NO_QUOTES; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(RelationType value) { - switch(value) { - case RelationType::INVALID_RELATION: - return "INVALID_RELATION"; - case RelationType::TABLE_RELATION: - return "TABLE_RELATION"; - case RelationType::PROJECTION_RELATION: - return "PROJECTION_RELATION"; - case RelationType::FILTER_RELATION: - return "FILTER_RELATION"; - case RelationType::EXPLAIN_RELATION: - return "EXPLAIN_RELATION"; - case RelationType::CROSS_PRODUCT_RELATION: - return "CROSS_PRODUCT_RELATION"; - case RelationType::JOIN_RELATION: - return "JOIN_RELATION"; - case RelationType::AGGREGATE_RELATION: - return "AGGREGATE_RELATION"; - case RelationType::SET_OPERATION_RELATION: - return "SET_OPERATION_RELATION"; - case RelationType::DISTINCT_RELATION: - return "DISTINCT_RELATION"; - case RelationType::LIMIT_RELATION: - return "LIMIT_RELATION"; - case RelationType::ORDER_RELATION: - return "ORDER_RELATION"; - case RelationType::CREATE_VIEW_RELATION: - return "CREATE_VIEW_RELATION"; - case RelationType::CREATE_TABLE_RELATION: - return "CREATE_TABLE_RELATION"; - case RelationType::INSERT_RELATION: - return 
"INSERT_RELATION"; - case RelationType::VALUE_LIST_RELATION: - return "VALUE_LIST_RELATION"; - case RelationType::DELETE_RELATION: - return "DELETE_RELATION"; - case RelationType::UPDATE_RELATION: - return "UPDATE_RELATION"; - case RelationType::WRITE_CSV_RELATION: - return "WRITE_CSV_RELATION"; - case RelationType::WRITE_PARQUET_RELATION: - return "WRITE_PARQUET_RELATION"; - case RelationType::READ_CSV_RELATION: - return "READ_CSV_RELATION"; - case RelationType::SUBQUERY_RELATION: - return "SUBQUERY_RELATION"; - case RelationType::TABLE_FUNCTION_RELATION: - return "TABLE_FUNCTION_RELATION"; - case RelationType::VIEW_RELATION: - return "VIEW_RELATION"; - case RelationType::QUERY_RELATION: - return "QUERY_RELATION"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -RelationType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID_RELATION")) { - return RelationType::INVALID_RELATION; - } - if (StringUtil::Equals(value, "TABLE_RELATION")) { - return RelationType::TABLE_RELATION; - } - if (StringUtil::Equals(value, "PROJECTION_RELATION")) { - return RelationType::PROJECTION_RELATION; - } - if (StringUtil::Equals(value, "FILTER_RELATION")) { - return RelationType::FILTER_RELATION; - } - if (StringUtil::Equals(value, "EXPLAIN_RELATION")) { - return RelationType::EXPLAIN_RELATION; - } - if (StringUtil::Equals(value, "CROSS_PRODUCT_RELATION")) { - return RelationType::CROSS_PRODUCT_RELATION; - } - if (StringUtil::Equals(value, "JOIN_RELATION")) { - return RelationType::JOIN_RELATION; - } - if (StringUtil::Equals(value, "AGGREGATE_RELATION")) { - return RelationType::AGGREGATE_RELATION; - } - if (StringUtil::Equals(value, "SET_OPERATION_RELATION")) { - return RelationType::SET_OPERATION_RELATION; - } - if (StringUtil::Equals(value, "DISTINCT_RELATION")) { - return RelationType::DISTINCT_RELATION; - } - if (StringUtil::Equals(value, "LIMIT_RELATION")) { - return 
RelationType::LIMIT_RELATION; - } - if (StringUtil::Equals(value, "ORDER_RELATION")) { - return RelationType::ORDER_RELATION; - } - if (StringUtil::Equals(value, "CREATE_VIEW_RELATION")) { - return RelationType::CREATE_VIEW_RELATION; - } - if (StringUtil::Equals(value, "CREATE_TABLE_RELATION")) { - return RelationType::CREATE_TABLE_RELATION; - } - if (StringUtil::Equals(value, "INSERT_RELATION")) { - return RelationType::INSERT_RELATION; - } - if (StringUtil::Equals(value, "VALUE_LIST_RELATION")) { - return RelationType::VALUE_LIST_RELATION; - } - if (StringUtil::Equals(value, "DELETE_RELATION")) { - return RelationType::DELETE_RELATION; - } - if (StringUtil::Equals(value, "UPDATE_RELATION")) { - return RelationType::UPDATE_RELATION; - } - if (StringUtil::Equals(value, "WRITE_CSV_RELATION")) { - return RelationType::WRITE_CSV_RELATION; - } - if (StringUtil::Equals(value, "WRITE_PARQUET_RELATION")) { - return RelationType::WRITE_PARQUET_RELATION; - } - if (StringUtil::Equals(value, "READ_CSV_RELATION")) { - return RelationType::READ_CSV_RELATION; - } - if (StringUtil::Equals(value, "SUBQUERY_RELATION")) { - return RelationType::SUBQUERY_RELATION; - } - if (StringUtil::Equals(value, "TABLE_FUNCTION_RELATION")) { - return RelationType::TABLE_FUNCTION_RELATION; - } - if (StringUtil::Equals(value, "VIEW_RELATION")) { - return RelationType::VIEW_RELATION; - } - if (StringUtil::Equals(value, "QUERY_RELATION")) { - return RelationType::QUERY_RELATION; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(RenderMode value) { - switch(value) { - case RenderMode::ROWS: - return "ROWS"; - case RenderMode::COLUMNS: - return "COLUMNS"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -RenderMode EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "ROWS")) { - return RenderMode::ROWS; - } 
- if (StringUtil::Equals(value, "COLUMNS")) { - return RenderMode::COLUMNS; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(ResultModifierType value) { - switch(value) { - case ResultModifierType::LIMIT_MODIFIER: - return "LIMIT_MODIFIER"; - case ResultModifierType::ORDER_MODIFIER: - return "ORDER_MODIFIER"; - case ResultModifierType::DISTINCT_MODIFIER: - return "DISTINCT_MODIFIER"; - case ResultModifierType::LIMIT_PERCENT_MODIFIER: - return "LIMIT_PERCENT_MODIFIER"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -ResultModifierType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "LIMIT_MODIFIER")) { - return ResultModifierType::LIMIT_MODIFIER; - } - if (StringUtil::Equals(value, "ORDER_MODIFIER")) { - return ResultModifierType::ORDER_MODIFIER; - } - if (StringUtil::Equals(value, "DISTINCT_MODIFIER")) { - return ResultModifierType::DISTINCT_MODIFIER; - } - if (StringUtil::Equals(value, "LIMIT_PERCENT_MODIFIER")) { - return ResultModifierType::LIMIT_PERCENT_MODIFIER; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(SampleMethod value) { - switch(value) { - case SampleMethod::SYSTEM_SAMPLE: - return "System"; - case SampleMethod::BERNOULLI_SAMPLE: - return "Bernoulli"; - case SampleMethod::RESERVOIR_SAMPLE: - return "Reservoir"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -SampleMethod EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "System")) { - return SampleMethod::SYSTEM_SAMPLE; - } - if (StringUtil::Equals(value, "Bernoulli")) { - return SampleMethod::BERNOULLI_SAMPLE; - } - if (StringUtil::Equals(value, "Reservoir")) { - return 
SampleMethod::RESERVOIR_SAMPLE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(SequenceInfo value) { - switch(value) { - case SequenceInfo::SEQ_START: - return "SEQ_START"; - case SequenceInfo::SEQ_INC: - return "SEQ_INC"; - case SequenceInfo::SEQ_MIN: - return "SEQ_MIN"; - case SequenceInfo::SEQ_MAX: - return "SEQ_MAX"; - case SequenceInfo::SEQ_CYCLE: - return "SEQ_CYCLE"; - case SequenceInfo::SEQ_OWN: - return "SEQ_OWN"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -SequenceInfo EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "SEQ_START")) { - return SequenceInfo::SEQ_START; - } - if (StringUtil::Equals(value, "SEQ_INC")) { - return SequenceInfo::SEQ_INC; - } - if (StringUtil::Equals(value, "SEQ_MIN")) { - return SequenceInfo::SEQ_MIN; - } - if (StringUtil::Equals(value, "SEQ_MAX")) { - return SequenceInfo::SEQ_MAX; - } - if (StringUtil::Equals(value, "SEQ_CYCLE")) { - return SequenceInfo::SEQ_CYCLE; - } - if (StringUtil::Equals(value, "SEQ_OWN")) { - return SequenceInfo::SEQ_OWN; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(SetOperationType value) { - switch(value) { - case SetOperationType::NONE: - return "NONE"; - case SetOperationType::UNION: - return "UNION"; - case SetOperationType::EXCEPT: - return "EXCEPT"; - case SetOperationType::INTERSECT: - return "INTERSECT"; - case SetOperationType::UNION_BY_NAME: - return "UNION_BY_NAME"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -SetOperationType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NONE")) { - return SetOperationType::NONE; - } - if (StringUtil::Equals(value, "UNION")) { - return 
SetOperationType::UNION; - } - if (StringUtil::Equals(value, "EXCEPT")) { - return SetOperationType::EXCEPT; - } - if (StringUtil::Equals(value, "INTERSECT")) { - return SetOperationType::INTERSECT; - } - if (StringUtil::Equals(value, "UNION_BY_NAME")) { - return SetOperationType::UNION_BY_NAME; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(SetScope value) { - switch(value) { - case SetScope::AUTOMATIC: - return "AUTOMATIC"; - case SetScope::LOCAL: - return "LOCAL"; - case SetScope::SESSION: - return "SESSION"; - case SetScope::GLOBAL: - return "GLOBAL"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -SetScope EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "AUTOMATIC")) { - return SetScope::AUTOMATIC; - } - if (StringUtil::Equals(value, "LOCAL")) { - return SetScope::LOCAL; - } - if (StringUtil::Equals(value, "SESSION")) { - return SetScope::SESSION; - } - if (StringUtil::Equals(value, "GLOBAL")) { - return SetScope::GLOBAL; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(SetType value) { - switch(value) { - case SetType::SET: - return "SET"; - case SetType::RESET: - return "RESET"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -SetType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "SET")) { - return SetType::SET; - } - if (StringUtil::Equals(value, "RESET")) { - return SetType::RESET; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(SimplifiedTokenType value) { - switch(value) { - case SimplifiedTokenType::SIMPLIFIED_TOKEN_IDENTIFIER: - return 
"SIMPLIFIED_TOKEN_IDENTIFIER"; - case SimplifiedTokenType::SIMPLIFIED_TOKEN_NUMERIC_CONSTANT: - return "SIMPLIFIED_TOKEN_NUMERIC_CONSTANT"; - case SimplifiedTokenType::SIMPLIFIED_TOKEN_STRING_CONSTANT: - return "SIMPLIFIED_TOKEN_STRING_CONSTANT"; - case SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR: - return "SIMPLIFIED_TOKEN_OPERATOR"; - case SimplifiedTokenType::SIMPLIFIED_TOKEN_KEYWORD: - return "SIMPLIFIED_TOKEN_KEYWORD"; - case SimplifiedTokenType::SIMPLIFIED_TOKEN_COMMENT: - return "SIMPLIFIED_TOKEN_COMMENT"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -SimplifiedTokenType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "SIMPLIFIED_TOKEN_IDENTIFIER")) { - return SimplifiedTokenType::SIMPLIFIED_TOKEN_IDENTIFIER; - } - if (StringUtil::Equals(value, "SIMPLIFIED_TOKEN_NUMERIC_CONSTANT")) { - return SimplifiedTokenType::SIMPLIFIED_TOKEN_NUMERIC_CONSTANT; - } - if (StringUtil::Equals(value, "SIMPLIFIED_TOKEN_STRING_CONSTANT")) { - return SimplifiedTokenType::SIMPLIFIED_TOKEN_STRING_CONSTANT; - } - if (StringUtil::Equals(value, "SIMPLIFIED_TOKEN_OPERATOR")) { - return SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR; - } - if (StringUtil::Equals(value, "SIMPLIFIED_TOKEN_KEYWORD")) { - return SimplifiedTokenType::SIMPLIFIED_TOKEN_KEYWORD; - } - if (StringUtil::Equals(value, "SIMPLIFIED_TOKEN_COMMENT")) { - return SimplifiedTokenType::SIMPLIFIED_TOKEN_COMMENT; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(SinkCombineResultType value) { - switch(value) { - case SinkCombineResultType::FINISHED: - return "FINISHED"; - case SinkCombineResultType::BLOCKED: - return "BLOCKED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -SinkCombineResultType EnumUtil::FromString(const char *value) { 
- if (StringUtil::Equals(value, "FINISHED")) { - return SinkCombineResultType::FINISHED; - } - if (StringUtil::Equals(value, "BLOCKED")) { - return SinkCombineResultType::BLOCKED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(SinkFinalizeType value) { - switch(value) { - case SinkFinalizeType::READY: - return "READY"; - case SinkFinalizeType::NO_OUTPUT_POSSIBLE: - return "NO_OUTPUT_POSSIBLE"; - case SinkFinalizeType::BLOCKED: - return "BLOCKED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -SinkFinalizeType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "READY")) { - return SinkFinalizeType::READY; - } - if (StringUtil::Equals(value, "NO_OUTPUT_POSSIBLE")) { - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - if (StringUtil::Equals(value, "BLOCKED")) { - return SinkFinalizeType::BLOCKED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(SinkResultType value) { - switch(value) { - case SinkResultType::NEED_MORE_INPUT: - return "NEED_MORE_INPUT"; - case SinkResultType::FINISHED: - return "FINISHED"; - case SinkResultType::BLOCKED: - return "BLOCKED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -SinkResultType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NEED_MORE_INPUT")) { - return SinkResultType::NEED_MORE_INPUT; - } - if (StringUtil::Equals(value, "FINISHED")) { - return SinkResultType::FINISHED; - } - if (StringUtil::Equals(value, "BLOCKED")) { - return SinkResultType::BLOCKED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(SourceResultType value) 
{ - switch(value) { - case SourceResultType::HAVE_MORE_OUTPUT: - return "HAVE_MORE_OUTPUT"; - case SourceResultType::FINISHED: - return "FINISHED"; - case SourceResultType::BLOCKED: - return "BLOCKED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -SourceResultType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "HAVE_MORE_OUTPUT")) { - return SourceResultType::HAVE_MORE_OUTPUT; - } - if (StringUtil::Equals(value, "FINISHED")) { - return SourceResultType::FINISHED; - } - if (StringUtil::Equals(value, "BLOCKED")) { - return SourceResultType::BLOCKED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(StatementReturnType value) { - switch(value) { - case StatementReturnType::QUERY_RESULT: - return "QUERY_RESULT"; - case StatementReturnType::CHANGED_ROWS: - return "CHANGED_ROWS"; - case StatementReturnType::NOTHING: - return "NOTHING"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -StatementReturnType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "QUERY_RESULT")) { - return StatementReturnType::QUERY_RESULT; - } - if (StringUtil::Equals(value, "CHANGED_ROWS")) { - return StatementReturnType::CHANGED_ROWS; - } - if (StringUtil::Equals(value, "NOTHING")) { - return StatementReturnType::NOTHING; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(StatementType value) { - switch(value) { - case StatementType::INVALID_STATEMENT: - return "INVALID_STATEMENT"; - case StatementType::SELECT_STATEMENT: - return "SELECT_STATEMENT"; - case StatementType::INSERT_STATEMENT: - return "INSERT_STATEMENT"; - case StatementType::UPDATE_STATEMENT: - return "UPDATE_STATEMENT"; - case 
StatementType::CREATE_STATEMENT: - return "CREATE_STATEMENT"; - case StatementType::DELETE_STATEMENT: - return "DELETE_STATEMENT"; - case StatementType::PREPARE_STATEMENT: - return "PREPARE_STATEMENT"; - case StatementType::EXECUTE_STATEMENT: - return "EXECUTE_STATEMENT"; - case StatementType::ALTER_STATEMENT: - return "ALTER_STATEMENT"; - case StatementType::TRANSACTION_STATEMENT: - return "TRANSACTION_STATEMENT"; - case StatementType::COPY_STATEMENT: - return "COPY_STATEMENT"; - case StatementType::ANALYZE_STATEMENT: - return "ANALYZE_STATEMENT"; - case StatementType::VARIABLE_SET_STATEMENT: - return "VARIABLE_SET_STATEMENT"; - case StatementType::CREATE_FUNC_STATEMENT: - return "CREATE_FUNC_STATEMENT"; - case StatementType::EXPLAIN_STATEMENT: - return "EXPLAIN_STATEMENT"; - case StatementType::DROP_STATEMENT: - return "DROP_STATEMENT"; - case StatementType::EXPORT_STATEMENT: - return "EXPORT_STATEMENT"; - case StatementType::PRAGMA_STATEMENT: - return "PRAGMA_STATEMENT"; - case StatementType::SHOW_STATEMENT: - return "SHOW_STATEMENT"; - case StatementType::VACUUM_STATEMENT: - return "VACUUM_STATEMENT"; - case StatementType::CALL_STATEMENT: - return "CALL_STATEMENT"; - case StatementType::SET_STATEMENT: - return "SET_STATEMENT"; - case StatementType::LOAD_STATEMENT: - return "LOAD_STATEMENT"; - case StatementType::RELATION_STATEMENT: - return "RELATION_STATEMENT"; - case StatementType::EXTENSION_STATEMENT: - return "EXTENSION_STATEMENT"; - case StatementType::LOGICAL_PLAN_STATEMENT: - return "LOGICAL_PLAN_STATEMENT"; - case StatementType::ATTACH_STATEMENT: - return "ATTACH_STATEMENT"; - case StatementType::DETACH_STATEMENT: - return "DETACH_STATEMENT"; - case StatementType::MULTI_STATEMENT: - return "MULTI_STATEMENT"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -StatementType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID_STATEMENT")) { - return 
StatementType::INVALID_STATEMENT; - } - if (StringUtil::Equals(value, "SELECT_STATEMENT")) { - return StatementType::SELECT_STATEMENT; - } - if (StringUtil::Equals(value, "INSERT_STATEMENT")) { - return StatementType::INSERT_STATEMENT; - } - if (StringUtil::Equals(value, "UPDATE_STATEMENT")) { - return StatementType::UPDATE_STATEMENT; - } - if (StringUtil::Equals(value, "CREATE_STATEMENT")) { - return StatementType::CREATE_STATEMENT; - } - if (StringUtil::Equals(value, "DELETE_STATEMENT")) { - return StatementType::DELETE_STATEMENT; - } - if (StringUtil::Equals(value, "PREPARE_STATEMENT")) { - return StatementType::PREPARE_STATEMENT; - } - if (StringUtil::Equals(value, "EXECUTE_STATEMENT")) { - return StatementType::EXECUTE_STATEMENT; - } - if (StringUtil::Equals(value, "ALTER_STATEMENT")) { - return StatementType::ALTER_STATEMENT; - } - if (StringUtil::Equals(value, "TRANSACTION_STATEMENT")) { - return StatementType::TRANSACTION_STATEMENT; - } - if (StringUtil::Equals(value, "COPY_STATEMENT")) { - return StatementType::COPY_STATEMENT; - } - if (StringUtil::Equals(value, "ANALYZE_STATEMENT")) { - return StatementType::ANALYZE_STATEMENT; - } - if (StringUtil::Equals(value, "VARIABLE_SET_STATEMENT")) { - return StatementType::VARIABLE_SET_STATEMENT; - } - if (StringUtil::Equals(value, "CREATE_FUNC_STATEMENT")) { - return StatementType::CREATE_FUNC_STATEMENT; - } - if (StringUtil::Equals(value, "EXPLAIN_STATEMENT")) { - return StatementType::EXPLAIN_STATEMENT; - } - if (StringUtil::Equals(value, "DROP_STATEMENT")) { - return StatementType::DROP_STATEMENT; - } - if (StringUtil::Equals(value, "EXPORT_STATEMENT")) { - return StatementType::EXPORT_STATEMENT; - } - if (StringUtil::Equals(value, "PRAGMA_STATEMENT")) { - return StatementType::PRAGMA_STATEMENT; - } - if (StringUtil::Equals(value, "SHOW_STATEMENT")) { - return StatementType::SHOW_STATEMENT; - } - if (StringUtil::Equals(value, "VACUUM_STATEMENT")) { - return StatementType::VACUUM_STATEMENT; - } - if 
(StringUtil::Equals(value, "CALL_STATEMENT")) { - return StatementType::CALL_STATEMENT; - } - if (StringUtil::Equals(value, "SET_STATEMENT")) { - return StatementType::SET_STATEMENT; - } - if (StringUtil::Equals(value, "LOAD_STATEMENT")) { - return StatementType::LOAD_STATEMENT; - } - if (StringUtil::Equals(value, "RELATION_STATEMENT")) { - return StatementType::RELATION_STATEMENT; - } - if (StringUtil::Equals(value, "EXTENSION_STATEMENT")) { - return StatementType::EXTENSION_STATEMENT; - } - if (StringUtil::Equals(value, "LOGICAL_PLAN_STATEMENT")) { - return StatementType::LOGICAL_PLAN_STATEMENT; - } - if (StringUtil::Equals(value, "ATTACH_STATEMENT")) { - return StatementType::ATTACH_STATEMENT; - } - if (StringUtil::Equals(value, "DETACH_STATEMENT")) { - return StatementType::DETACH_STATEMENT; - } - if (StringUtil::Equals(value, "MULTI_STATEMENT")) { - return StatementType::MULTI_STATEMENT; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(StatisticsType value) { - switch(value) { - case StatisticsType::NUMERIC_STATS: - return "NUMERIC_STATS"; - case StatisticsType::STRING_STATS: - return "STRING_STATS"; - case StatisticsType::LIST_STATS: - return "LIST_STATS"; - case StatisticsType::STRUCT_STATS: - return "STRUCT_STATS"; - case StatisticsType::BASE_STATS: - return "BASE_STATS"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -StatisticsType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "NUMERIC_STATS")) { - return StatisticsType::NUMERIC_STATS; - } - if (StringUtil::Equals(value, "STRING_STATS")) { - return StatisticsType::STRING_STATS; - } - if (StringUtil::Equals(value, "LIST_STATS")) { - return StatisticsType::LIST_STATS; - } - if (StringUtil::Equals(value, "STRUCT_STATS")) { - return StatisticsType::STRUCT_STATS; - } - if (StringUtil::Equals(value, 
"BASE_STATS")) { - return StatisticsType::BASE_STATS; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(StatsInfo value) { - switch(value) { - case StatsInfo::CAN_HAVE_NULL_VALUES: - return "CAN_HAVE_NULL_VALUES"; - case StatsInfo::CANNOT_HAVE_NULL_VALUES: - return "CANNOT_HAVE_NULL_VALUES"; - case StatsInfo::CAN_HAVE_VALID_VALUES: - return "CAN_HAVE_VALID_VALUES"; - case StatsInfo::CANNOT_HAVE_VALID_VALUES: - return "CANNOT_HAVE_VALID_VALUES"; - case StatsInfo::CAN_HAVE_NULL_AND_VALID_VALUES: - return "CAN_HAVE_NULL_AND_VALID_VALUES"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -StatsInfo EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "CAN_HAVE_NULL_VALUES")) { - return StatsInfo::CAN_HAVE_NULL_VALUES; - } - if (StringUtil::Equals(value, "CANNOT_HAVE_NULL_VALUES")) { - return StatsInfo::CANNOT_HAVE_NULL_VALUES; - } - if (StringUtil::Equals(value, "CAN_HAVE_VALID_VALUES")) { - return StatsInfo::CAN_HAVE_VALID_VALUES; - } - if (StringUtil::Equals(value, "CANNOT_HAVE_VALID_VALUES")) { - return StatsInfo::CANNOT_HAVE_VALID_VALUES; - } - if (StringUtil::Equals(value, "CAN_HAVE_NULL_AND_VALID_VALUES")) { - return StatsInfo::CAN_HAVE_NULL_AND_VALID_VALUES; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(StrTimeSpecifier value) { - switch(value) { - case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: - return "ABBREVIATED_WEEKDAY_NAME"; - case StrTimeSpecifier::FULL_WEEKDAY_NAME: - return "FULL_WEEKDAY_NAME"; - case StrTimeSpecifier::WEEKDAY_DECIMAL: - return "WEEKDAY_DECIMAL"; - case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - return "DAY_OF_MONTH_PADDED"; - case StrTimeSpecifier::DAY_OF_MONTH: - return "DAY_OF_MONTH"; - case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: - 
return "ABBREVIATED_MONTH_NAME"; - case StrTimeSpecifier::FULL_MONTH_NAME: - return "FULL_MONTH_NAME"; - case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - return "MONTH_DECIMAL_PADDED"; - case StrTimeSpecifier::MONTH_DECIMAL: - return "MONTH_DECIMAL"; - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - return "YEAR_WITHOUT_CENTURY_PADDED"; - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - return "YEAR_WITHOUT_CENTURY"; - case StrTimeSpecifier::YEAR_DECIMAL: - return "YEAR_DECIMAL"; - case StrTimeSpecifier::HOUR_24_PADDED: - return "HOUR_24_PADDED"; - case StrTimeSpecifier::HOUR_24_DECIMAL: - return "HOUR_24_DECIMAL"; - case StrTimeSpecifier::HOUR_12_PADDED: - return "HOUR_12_PADDED"; - case StrTimeSpecifier::HOUR_12_DECIMAL: - return "HOUR_12_DECIMAL"; - case StrTimeSpecifier::AM_PM: - return "AM_PM"; - case StrTimeSpecifier::MINUTE_PADDED: - return "MINUTE_PADDED"; - case StrTimeSpecifier::MINUTE_DECIMAL: - return "MINUTE_DECIMAL"; - case StrTimeSpecifier::SECOND_PADDED: - return "SECOND_PADDED"; - case StrTimeSpecifier::SECOND_DECIMAL: - return "SECOND_DECIMAL"; - case StrTimeSpecifier::MICROSECOND_PADDED: - return "MICROSECOND_PADDED"; - case StrTimeSpecifier::MILLISECOND_PADDED: - return "MILLISECOND_PADDED"; - case StrTimeSpecifier::UTC_OFFSET: - return "UTC_OFFSET"; - case StrTimeSpecifier::TZ_NAME: - return "TZ_NAME"; - case StrTimeSpecifier::DAY_OF_YEAR_PADDED: - return "DAY_OF_YEAR_PADDED"; - case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: - return "DAY_OF_YEAR_DECIMAL"; - case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: - return "WEEK_NUMBER_PADDED_SUN_FIRST"; - case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: - return "WEEK_NUMBER_PADDED_MON_FIRST"; - case StrTimeSpecifier::LOCALE_APPROPRIATE_DATE_AND_TIME: - return "LOCALE_APPROPRIATE_DATE_AND_TIME"; - case StrTimeSpecifier::LOCALE_APPROPRIATE_DATE: - return "LOCALE_APPROPRIATE_DATE"; - case StrTimeSpecifier::LOCALE_APPROPRIATE_TIME: - return "LOCALE_APPROPRIATE_TIME"; - case 
StrTimeSpecifier::NANOSECOND_PADDED: - return "NANOSECOND_PADDED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -StrTimeSpecifier EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "ABBREVIATED_WEEKDAY_NAME")) { - return StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME; - } - if (StringUtil::Equals(value, "FULL_WEEKDAY_NAME")) { - return StrTimeSpecifier::FULL_WEEKDAY_NAME; - } - if (StringUtil::Equals(value, "WEEKDAY_DECIMAL")) { - return StrTimeSpecifier::WEEKDAY_DECIMAL; - } - if (StringUtil::Equals(value, "DAY_OF_MONTH_PADDED")) { - return StrTimeSpecifier::DAY_OF_MONTH_PADDED; - } - if (StringUtil::Equals(value, "DAY_OF_MONTH")) { - return StrTimeSpecifier::DAY_OF_MONTH; - } - if (StringUtil::Equals(value, "ABBREVIATED_MONTH_NAME")) { - return StrTimeSpecifier::ABBREVIATED_MONTH_NAME; - } - if (StringUtil::Equals(value, "FULL_MONTH_NAME")) { - return StrTimeSpecifier::FULL_MONTH_NAME; - } - if (StringUtil::Equals(value, "MONTH_DECIMAL_PADDED")) { - return StrTimeSpecifier::MONTH_DECIMAL_PADDED; - } - if (StringUtil::Equals(value, "MONTH_DECIMAL")) { - return StrTimeSpecifier::MONTH_DECIMAL; - } - if (StringUtil::Equals(value, "YEAR_WITHOUT_CENTURY_PADDED")) { - return StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED; - } - if (StringUtil::Equals(value, "YEAR_WITHOUT_CENTURY")) { - return StrTimeSpecifier::YEAR_WITHOUT_CENTURY; - } - if (StringUtil::Equals(value, "YEAR_DECIMAL")) { - return StrTimeSpecifier::YEAR_DECIMAL; - } - if (StringUtil::Equals(value, "HOUR_24_PADDED")) { - return StrTimeSpecifier::HOUR_24_PADDED; - } - if (StringUtil::Equals(value, "HOUR_24_DECIMAL")) { - return StrTimeSpecifier::HOUR_24_DECIMAL; - } - if (StringUtil::Equals(value, "HOUR_12_PADDED")) { - return StrTimeSpecifier::HOUR_12_PADDED; - } - if (StringUtil::Equals(value, "HOUR_12_DECIMAL")) { - return StrTimeSpecifier::HOUR_12_DECIMAL; - } - if (StringUtil::Equals(value, "AM_PM")) { 
- return StrTimeSpecifier::AM_PM; - } - if (StringUtil::Equals(value, "MINUTE_PADDED")) { - return StrTimeSpecifier::MINUTE_PADDED; - } - if (StringUtil::Equals(value, "MINUTE_DECIMAL")) { - return StrTimeSpecifier::MINUTE_DECIMAL; - } - if (StringUtil::Equals(value, "SECOND_PADDED")) { - return StrTimeSpecifier::SECOND_PADDED; - } - if (StringUtil::Equals(value, "SECOND_DECIMAL")) { - return StrTimeSpecifier::SECOND_DECIMAL; - } - if (StringUtil::Equals(value, "MICROSECOND_PADDED")) { - return StrTimeSpecifier::MICROSECOND_PADDED; - } - if (StringUtil::Equals(value, "MILLISECOND_PADDED")) { - return StrTimeSpecifier::MILLISECOND_PADDED; - } - if (StringUtil::Equals(value, "UTC_OFFSET")) { - return StrTimeSpecifier::UTC_OFFSET; - } - if (StringUtil::Equals(value, "TZ_NAME")) { - return StrTimeSpecifier::TZ_NAME; - } - if (StringUtil::Equals(value, "DAY_OF_YEAR_PADDED")) { - return StrTimeSpecifier::DAY_OF_YEAR_PADDED; - } - if (StringUtil::Equals(value, "DAY_OF_YEAR_DECIMAL")) { - return StrTimeSpecifier::DAY_OF_YEAR_DECIMAL; - } - if (StringUtil::Equals(value, "WEEK_NUMBER_PADDED_SUN_FIRST")) { - return StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST; - } - if (StringUtil::Equals(value, "WEEK_NUMBER_PADDED_MON_FIRST")) { - return StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST; - } - if (StringUtil::Equals(value, "LOCALE_APPROPRIATE_DATE_AND_TIME")) { - return StrTimeSpecifier::LOCALE_APPROPRIATE_DATE_AND_TIME; - } - if (StringUtil::Equals(value, "LOCALE_APPROPRIATE_DATE")) { - return StrTimeSpecifier::LOCALE_APPROPRIATE_DATE; - } - if (StringUtil::Equals(value, "LOCALE_APPROPRIATE_TIME")) { - return StrTimeSpecifier::LOCALE_APPROPRIATE_TIME; - } - if (StringUtil::Equals(value, "NANOSECOND_PADDED")) { - return StrTimeSpecifier::NANOSECOND_PADDED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(SubqueryType value) { - switch(value) { - case SubqueryType::INVALID: - 
return "INVALID"; - case SubqueryType::SCALAR: - return "SCALAR"; - case SubqueryType::EXISTS: - return "EXISTS"; - case SubqueryType::NOT_EXISTS: - return "NOT_EXISTS"; - case SubqueryType::ANY: - return "ANY"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -SubqueryType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return SubqueryType::INVALID; - } - if (StringUtil::Equals(value, "SCALAR")) { - return SubqueryType::SCALAR; - } - if (StringUtil::Equals(value, "EXISTS")) { - return SubqueryType::EXISTS; - } - if (StringUtil::Equals(value, "NOT_EXISTS")) { - return SubqueryType::NOT_EXISTS; - } - if (StringUtil::Equals(value, "ANY")) { - return SubqueryType::ANY; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(TableColumnType value) { - switch(value) { - case TableColumnType::STANDARD: - return "STANDARD"; - case TableColumnType::GENERATED: - return "GENERATED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -TableColumnType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "STANDARD")) { - return TableColumnType::STANDARD; - } - if (StringUtil::Equals(value, "GENERATED")) { - return TableColumnType::GENERATED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(TableFilterType value) { - switch(value) { - case TableFilterType::CONSTANT_COMPARISON: - return "CONSTANT_COMPARISON"; - case TableFilterType::IS_NULL: - return "IS_NULL"; - case TableFilterType::IS_NOT_NULL: - return "IS_NOT_NULL"; - case TableFilterType::CONJUNCTION_OR: - return "CONJUNCTION_OR"; - case TableFilterType::CONJUNCTION_AND: - return "CONJUNCTION_AND"; - default: - throw 
NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -TableFilterType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "CONSTANT_COMPARISON")) { - return TableFilterType::CONSTANT_COMPARISON; - } - if (StringUtil::Equals(value, "IS_NULL")) { - return TableFilterType::IS_NULL; - } - if (StringUtil::Equals(value, "IS_NOT_NULL")) { - return TableFilterType::IS_NOT_NULL; - } - if (StringUtil::Equals(value, "CONJUNCTION_OR")) { - return TableFilterType::CONJUNCTION_OR; - } - if (StringUtil::Equals(value, "CONJUNCTION_AND")) { - return TableFilterType::CONJUNCTION_AND; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(TableReferenceType value) { - switch(value) { - case TableReferenceType::INVALID: - return "INVALID"; - case TableReferenceType::BASE_TABLE: - return "BASE_TABLE"; - case TableReferenceType::SUBQUERY: - return "SUBQUERY"; - case TableReferenceType::JOIN: - return "JOIN"; - case TableReferenceType::TABLE_FUNCTION: - return "TABLE_FUNCTION"; - case TableReferenceType::EXPRESSION_LIST: - return "EXPRESSION_LIST"; - case TableReferenceType::CTE: - return "CTE"; - case TableReferenceType::EMPTY: - return "EMPTY"; - case TableReferenceType::PIVOT: - return "PIVOT"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -TableReferenceType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return TableReferenceType::INVALID; - } - if (StringUtil::Equals(value, "BASE_TABLE")) { - return TableReferenceType::BASE_TABLE; - } - if (StringUtil::Equals(value, "SUBQUERY")) { - return TableReferenceType::SUBQUERY; - } - if (StringUtil::Equals(value, "JOIN")) { - return TableReferenceType::JOIN; - } - if (StringUtil::Equals(value, "TABLE_FUNCTION")) { - return 
TableReferenceType::TABLE_FUNCTION; - } - if (StringUtil::Equals(value, "EXPRESSION_LIST")) { - return TableReferenceType::EXPRESSION_LIST; - } - if (StringUtil::Equals(value, "CTE")) { - return TableReferenceType::CTE; - } - if (StringUtil::Equals(value, "EMPTY")) { - return TableReferenceType::EMPTY; - } - if (StringUtil::Equals(value, "PIVOT")) { - return TableReferenceType::PIVOT; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(TableScanType value) { - switch(value) { - case TableScanType::TABLE_SCAN_REGULAR: - return "TABLE_SCAN_REGULAR"; - case TableScanType::TABLE_SCAN_COMMITTED_ROWS: - return "TABLE_SCAN_COMMITTED_ROWS"; - case TableScanType::TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES: - return "TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES"; - case TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED: - return "TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -TableScanType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "TABLE_SCAN_REGULAR")) { - return TableScanType::TABLE_SCAN_REGULAR; - } - if (StringUtil::Equals(value, "TABLE_SCAN_COMMITTED_ROWS")) { - return TableScanType::TABLE_SCAN_COMMITTED_ROWS; - } - if (StringUtil::Equals(value, "TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES")) { - return TableScanType::TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES; - } - if (StringUtil::Equals(value, "TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED")) { - return TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(TaskExecutionMode value) { - switch(value) { - case TaskExecutionMode::PROCESS_ALL: - return "PROCESS_ALL"; - case 
TaskExecutionMode::PROCESS_PARTIAL: - return "PROCESS_PARTIAL"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -TaskExecutionMode EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "PROCESS_ALL")) { - return TaskExecutionMode::PROCESS_ALL; - } - if (StringUtil::Equals(value, "PROCESS_PARTIAL")) { - return TaskExecutionMode::PROCESS_PARTIAL; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(TaskExecutionResult value) { - switch(value) { - case TaskExecutionResult::TASK_FINISHED: - return "TASK_FINISHED"; - case TaskExecutionResult::TASK_NOT_FINISHED: - return "TASK_NOT_FINISHED"; - case TaskExecutionResult::TASK_ERROR: - return "TASK_ERROR"; - case TaskExecutionResult::TASK_BLOCKED: - return "TASK_BLOCKED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -TaskExecutionResult EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "TASK_FINISHED")) { - return TaskExecutionResult::TASK_FINISHED; - } - if (StringUtil::Equals(value, "TASK_NOT_FINISHED")) { - return TaskExecutionResult::TASK_NOT_FINISHED; - } - if (StringUtil::Equals(value, "TASK_ERROR")) { - return TaskExecutionResult::TASK_ERROR; - } - if (StringUtil::Equals(value, "TASK_BLOCKED")) { - return TaskExecutionResult::TASK_BLOCKED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(TimestampCastResult value) { - switch(value) { - case TimestampCastResult::SUCCESS: - return "SUCCESS"; - case TimestampCastResult::ERROR_INCORRECT_FORMAT: - return "ERROR_INCORRECT_FORMAT"; - case TimestampCastResult::ERROR_NON_UTC_TIMEZONE: - return "ERROR_NON_UTC_TIMEZONE"; - default: - throw 
NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -TimestampCastResult EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "SUCCESS")) { - return TimestampCastResult::SUCCESS; - } - if (StringUtil::Equals(value, "ERROR_INCORRECT_FORMAT")) { - return TimestampCastResult::ERROR_INCORRECT_FORMAT; - } - if (StringUtil::Equals(value, "ERROR_NON_UTC_TIMEZONE")) { - return TimestampCastResult::ERROR_NON_UTC_TIMEZONE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(TransactionType value) { - switch(value) { - case TransactionType::INVALID: - return "INVALID"; - case TransactionType::BEGIN_TRANSACTION: - return "BEGIN_TRANSACTION"; - case TransactionType::COMMIT: - return "COMMIT"; - case TransactionType::ROLLBACK: - return "ROLLBACK"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -TransactionType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return TransactionType::INVALID; - } - if (StringUtil::Equals(value, "BEGIN_TRANSACTION")) { - return TransactionType::BEGIN_TRANSACTION; - } - if (StringUtil::Equals(value, "COMMIT")) { - return TransactionType::COMMIT; - } - if (StringUtil::Equals(value, "ROLLBACK")) { - return TransactionType::ROLLBACK; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(TupleDataPinProperties value) { - switch(value) { - case TupleDataPinProperties::INVALID: - return "INVALID"; - case TupleDataPinProperties::KEEP_EVERYTHING_PINNED: - return "KEEP_EVERYTHING_PINNED"; - case TupleDataPinProperties::UNPIN_AFTER_DONE: - return "UNPIN_AFTER_DONE"; - case TupleDataPinProperties::DESTROY_AFTER_DONE: - return "DESTROY_AFTER_DONE"; - case 
TupleDataPinProperties::ALREADY_PINNED: - return "ALREADY_PINNED"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -TupleDataPinProperties EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return TupleDataPinProperties::INVALID; - } - if (StringUtil::Equals(value, "KEEP_EVERYTHING_PINNED")) { - return TupleDataPinProperties::KEEP_EVERYTHING_PINNED; - } - if (StringUtil::Equals(value, "UNPIN_AFTER_DONE")) { - return TupleDataPinProperties::UNPIN_AFTER_DONE; - } - if (StringUtil::Equals(value, "DESTROY_AFTER_DONE")) { - return TupleDataPinProperties::DESTROY_AFTER_DONE; - } - if (StringUtil::Equals(value, "ALREADY_PINNED")) { - return TupleDataPinProperties::ALREADY_PINNED; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(UndoFlags value) { - switch(value) { - case UndoFlags::EMPTY_ENTRY: - return "EMPTY_ENTRY"; - case UndoFlags::CATALOG_ENTRY: - return "CATALOG_ENTRY"; - case UndoFlags::INSERT_TUPLE: - return "INSERT_TUPLE"; - case UndoFlags::DELETE_TUPLE: - return "DELETE_TUPLE"; - case UndoFlags::UPDATE_TUPLE: - return "UPDATE_TUPLE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -UndoFlags EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "EMPTY_ENTRY")) { - return UndoFlags::EMPTY_ENTRY; - } - if (StringUtil::Equals(value, "CATALOG_ENTRY")) { - return UndoFlags::CATALOG_ENTRY; - } - if (StringUtil::Equals(value, "INSERT_TUPLE")) { - return UndoFlags::INSERT_TUPLE; - } - if (StringUtil::Equals(value, "DELETE_TUPLE")) { - return UndoFlags::DELETE_TUPLE; - } - if (StringUtil::Equals(value, "UPDATE_TUPLE")) { - return UndoFlags::UPDATE_TUPLE; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - 
-template<> -const char* EnumUtil::ToChars(UnionInvalidReason value) { - switch(value) { - case UnionInvalidReason::VALID: - return "VALID"; - case UnionInvalidReason::TAG_OUT_OF_RANGE: - return "TAG_OUT_OF_RANGE"; - case UnionInvalidReason::NO_MEMBERS: - return "NO_MEMBERS"; - case UnionInvalidReason::VALIDITY_OVERLAP: - return "VALIDITY_OVERLAP"; - case UnionInvalidReason::TAG_MISMATCH: - return "TAG_MISMATCH"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -UnionInvalidReason EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "VALID")) { - return UnionInvalidReason::VALID; - } - if (StringUtil::Equals(value, "TAG_OUT_OF_RANGE")) { - return UnionInvalidReason::TAG_OUT_OF_RANGE; - } - if (StringUtil::Equals(value, "NO_MEMBERS")) { - return UnionInvalidReason::NO_MEMBERS; - } - if (StringUtil::Equals(value, "VALIDITY_OVERLAP")) { - return UnionInvalidReason::VALIDITY_OVERLAP; - } - if (StringUtil::Equals(value, "TAG_MISMATCH")) { - return UnionInvalidReason::TAG_MISMATCH; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(VectorAuxiliaryDataType value) { - switch(value) { - case VectorAuxiliaryDataType::ARROW_AUXILIARY: - return "ARROW_AUXILIARY"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -VectorAuxiliaryDataType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "ARROW_AUXILIARY")) { - return VectorAuxiliaryDataType::ARROW_AUXILIARY; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(VectorBufferType value) { - switch(value) { - case VectorBufferType::STANDARD_BUFFER: - return "STANDARD_BUFFER"; - case VectorBufferType::DICTIONARY_BUFFER: - return 
"DICTIONARY_BUFFER"; - case VectorBufferType::VECTOR_CHILD_BUFFER: - return "VECTOR_CHILD_BUFFER"; - case VectorBufferType::STRING_BUFFER: - return "STRING_BUFFER"; - case VectorBufferType::FSST_BUFFER: - return "FSST_BUFFER"; - case VectorBufferType::STRUCT_BUFFER: - return "STRUCT_BUFFER"; - case VectorBufferType::LIST_BUFFER: - return "LIST_BUFFER"; - case VectorBufferType::MANAGED_BUFFER: - return "MANAGED_BUFFER"; - case VectorBufferType::OPAQUE_BUFFER: - return "OPAQUE_BUFFER"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -VectorBufferType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "STANDARD_BUFFER")) { - return VectorBufferType::STANDARD_BUFFER; - } - if (StringUtil::Equals(value, "DICTIONARY_BUFFER")) { - return VectorBufferType::DICTIONARY_BUFFER; - } - if (StringUtil::Equals(value, "VECTOR_CHILD_BUFFER")) { - return VectorBufferType::VECTOR_CHILD_BUFFER; - } - if (StringUtil::Equals(value, "STRING_BUFFER")) { - return VectorBufferType::STRING_BUFFER; - } - if (StringUtil::Equals(value, "FSST_BUFFER")) { - return VectorBufferType::FSST_BUFFER; - } - if (StringUtil::Equals(value, "STRUCT_BUFFER")) { - return VectorBufferType::STRUCT_BUFFER; - } - if (StringUtil::Equals(value, "LIST_BUFFER")) { - return VectorBufferType::LIST_BUFFER; - } - if (StringUtil::Equals(value, "MANAGED_BUFFER")) { - return VectorBufferType::MANAGED_BUFFER; - } - if (StringUtil::Equals(value, "OPAQUE_BUFFER")) { - return VectorBufferType::OPAQUE_BUFFER; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(VectorType value) { - switch(value) { - case VectorType::FLAT_VECTOR: - return "FLAT_VECTOR"; - case VectorType::FSST_VECTOR: - return "FSST_VECTOR"; - case VectorType::CONSTANT_VECTOR: - return "CONSTANT_VECTOR"; - case VectorType::DICTIONARY_VECTOR: - return 
"DICTIONARY_VECTOR"; - case VectorType::SEQUENCE_VECTOR: - return "SEQUENCE_VECTOR"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -VectorType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "FLAT_VECTOR")) { - return VectorType::FLAT_VECTOR; - } - if (StringUtil::Equals(value, "FSST_VECTOR")) { - return VectorType::FSST_VECTOR; - } - if (StringUtil::Equals(value, "CONSTANT_VECTOR")) { - return VectorType::CONSTANT_VECTOR; - } - if (StringUtil::Equals(value, "DICTIONARY_VECTOR")) { - return VectorType::DICTIONARY_VECTOR; - } - if (StringUtil::Equals(value, "SEQUENCE_VECTOR")) { - return VectorType::SEQUENCE_VECTOR; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(VerificationType value) { - switch(value) { - case VerificationType::ORIGINAL: - return "ORIGINAL"; - case VerificationType::COPIED: - return "COPIED"; - case VerificationType::DESERIALIZED: - return "DESERIALIZED"; - case VerificationType::PARSED: - return "PARSED"; - case VerificationType::UNOPTIMIZED: - return "UNOPTIMIZED"; - case VerificationType::NO_OPERATOR_CACHING: - return "NO_OPERATOR_CACHING"; - case VerificationType::PREPARED: - return "PREPARED"; - case VerificationType::EXTERNAL: - return "EXTERNAL"; - case VerificationType::INVALID: - return "INVALID"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -VerificationType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "ORIGINAL")) { - return VerificationType::ORIGINAL; - } - if (StringUtil::Equals(value, "COPIED")) { - return VerificationType::COPIED; - } - if (StringUtil::Equals(value, "DESERIALIZED")) { - return VerificationType::DESERIALIZED; - } - if (StringUtil::Equals(value, "PARSED")) { - return VerificationType::PARSED; - } - if 
(StringUtil::Equals(value, "UNOPTIMIZED")) { - return VerificationType::UNOPTIMIZED; - } - if (StringUtil::Equals(value, "NO_OPERATOR_CACHING")) { - return VerificationType::NO_OPERATOR_CACHING; - } - if (StringUtil::Equals(value, "PREPARED")) { - return VerificationType::PREPARED; - } - if (StringUtil::Equals(value, "EXTERNAL")) { - return VerificationType::EXTERNAL; - } - if (StringUtil::Equals(value, "INVALID")) { - return VerificationType::INVALID; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(VerifyExistenceType value) { - switch(value) { - case VerifyExistenceType::APPEND: - return "APPEND"; - case VerifyExistenceType::APPEND_FK: - return "APPEND_FK"; - case VerifyExistenceType::DELETE_FK: - return "DELETE_FK"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -VerifyExistenceType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "APPEND")) { - return VerifyExistenceType::APPEND; - } - if (StringUtil::Equals(value, "APPEND_FK")) { - return VerifyExistenceType::APPEND_FK; - } - if (StringUtil::Equals(value, "DELETE_FK")) { - return VerifyExistenceType::DELETE_FK; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(WALType value) { - switch(value) { - case WALType::INVALID: - return "INVALID"; - case WALType::CREATE_TABLE: - return "CREATE_TABLE"; - case WALType::DROP_TABLE: - return "DROP_TABLE"; - case WALType::CREATE_SCHEMA: - return "CREATE_SCHEMA"; - case WALType::DROP_SCHEMA: - return "DROP_SCHEMA"; - case WALType::CREATE_VIEW: - return "CREATE_VIEW"; - case WALType::DROP_VIEW: - return "DROP_VIEW"; - case WALType::CREATE_SEQUENCE: - return "CREATE_SEQUENCE"; - case WALType::DROP_SEQUENCE: - return "DROP_SEQUENCE"; - case WALType::SEQUENCE_VALUE: - return 
"SEQUENCE_VALUE"; - case WALType::CREATE_MACRO: - return "CREATE_MACRO"; - case WALType::DROP_MACRO: - return "DROP_MACRO"; - case WALType::CREATE_TYPE: - return "CREATE_TYPE"; - case WALType::DROP_TYPE: - return "DROP_TYPE"; - case WALType::ALTER_INFO: - return "ALTER_INFO"; - case WALType::CREATE_TABLE_MACRO: - return "CREATE_TABLE_MACRO"; - case WALType::DROP_TABLE_MACRO: - return "DROP_TABLE_MACRO"; - case WALType::CREATE_INDEX: - return "CREATE_INDEX"; - case WALType::DROP_INDEX: - return "DROP_INDEX"; - case WALType::USE_TABLE: - return "USE_TABLE"; - case WALType::INSERT_TUPLE: - return "INSERT_TUPLE"; - case WALType::DELETE_TUPLE: - return "DELETE_TUPLE"; - case WALType::UPDATE_TUPLE: - return "UPDATE_TUPLE"; - case WALType::CHECKPOINT: - return "CHECKPOINT"; - case WALType::WAL_FLUSH: - return "WAL_FLUSH"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -WALType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return WALType::INVALID; - } - if (StringUtil::Equals(value, "CREATE_TABLE")) { - return WALType::CREATE_TABLE; - } - if (StringUtil::Equals(value, "DROP_TABLE")) { - return WALType::DROP_TABLE; - } - if (StringUtil::Equals(value, "CREATE_SCHEMA")) { - return WALType::CREATE_SCHEMA; - } - if (StringUtil::Equals(value, "DROP_SCHEMA")) { - return WALType::DROP_SCHEMA; - } - if (StringUtil::Equals(value, "CREATE_VIEW")) { - return WALType::CREATE_VIEW; - } - if (StringUtil::Equals(value, "DROP_VIEW")) { - return WALType::DROP_VIEW; - } - if (StringUtil::Equals(value, "CREATE_SEQUENCE")) { - return WALType::CREATE_SEQUENCE; - } - if (StringUtil::Equals(value, "DROP_SEQUENCE")) { - return WALType::DROP_SEQUENCE; - } - if (StringUtil::Equals(value, "SEQUENCE_VALUE")) { - return WALType::SEQUENCE_VALUE; - } - if (StringUtil::Equals(value, "CREATE_MACRO")) { - return WALType::CREATE_MACRO; - } - if (StringUtil::Equals(value, "DROP_MACRO")) 
{ - return WALType::DROP_MACRO; - } - if (StringUtil::Equals(value, "CREATE_TYPE")) { - return WALType::CREATE_TYPE; - } - if (StringUtil::Equals(value, "DROP_TYPE")) { - return WALType::DROP_TYPE; - } - if (StringUtil::Equals(value, "ALTER_INFO")) { - return WALType::ALTER_INFO; - } - if (StringUtil::Equals(value, "CREATE_TABLE_MACRO")) { - return WALType::CREATE_TABLE_MACRO; - } - if (StringUtil::Equals(value, "DROP_TABLE_MACRO")) { - return WALType::DROP_TABLE_MACRO; - } - if (StringUtil::Equals(value, "CREATE_INDEX")) { - return WALType::CREATE_INDEX; - } - if (StringUtil::Equals(value, "DROP_INDEX")) { - return WALType::DROP_INDEX; - } - if (StringUtil::Equals(value, "USE_TABLE")) { - return WALType::USE_TABLE; - } - if (StringUtil::Equals(value, "INSERT_TUPLE")) { - return WALType::INSERT_TUPLE; - } - if (StringUtil::Equals(value, "DELETE_TUPLE")) { - return WALType::DELETE_TUPLE; - } - if (StringUtil::Equals(value, "UPDATE_TUPLE")) { - return WALType::UPDATE_TUPLE; - } - if (StringUtil::Equals(value, "CHECKPOINT")) { - return WALType::CHECKPOINT; - } - if (StringUtil::Equals(value, "WAL_FLUSH")) { - return WALType::WAL_FLUSH; - } - throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(WindowAggregationMode value) { - switch(value) { - case WindowAggregationMode::WINDOW: - return "WINDOW"; - case WindowAggregationMode::COMBINE: - return "COMBINE"; - case WindowAggregationMode::SEPARATE: - return "SEPARATE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -WindowAggregationMode EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "WINDOW")) { - return WindowAggregationMode::WINDOW; - } - if (StringUtil::Equals(value, "COMBINE")) { - return WindowAggregationMode::COMBINE; - } - if (StringUtil::Equals(value, "SEPARATE")) { - return WindowAggregationMode::SEPARATE; - } - throw 
NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -template<> -const char* EnumUtil::ToChars(WindowBoundary value) { - switch(value) { - case WindowBoundary::INVALID: - return "INVALID"; - case WindowBoundary::UNBOUNDED_PRECEDING: - return "UNBOUNDED_PRECEDING"; - case WindowBoundary::UNBOUNDED_FOLLOWING: - return "UNBOUNDED_FOLLOWING"; - case WindowBoundary::CURRENT_ROW_RANGE: - return "CURRENT_ROW_RANGE"; - case WindowBoundary::CURRENT_ROW_ROWS: - return "CURRENT_ROW_ROWS"; - case WindowBoundary::EXPR_PRECEDING_ROWS: - return "EXPR_PRECEDING_ROWS"; - case WindowBoundary::EXPR_FOLLOWING_ROWS: - return "EXPR_FOLLOWING_ROWS"; - case WindowBoundary::EXPR_PRECEDING_RANGE: - return "EXPR_PRECEDING_RANGE"; - case WindowBoundary::EXPR_FOLLOWING_RANGE: - return "EXPR_FOLLOWING_RANGE"; - default: - throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); - } -} - -template<> -WindowBoundary EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "INVALID")) { - return WindowBoundary::INVALID; - } - if (StringUtil::Equals(value, "UNBOUNDED_PRECEDING")) { - return WindowBoundary::UNBOUNDED_PRECEDING; - } - if (StringUtil::Equals(value, "UNBOUNDED_FOLLOWING")) { - return WindowBoundary::UNBOUNDED_FOLLOWING; - } - if (StringUtil::Equals(value, "CURRENT_ROW_RANGE")) { - return WindowBoundary::CURRENT_ROW_RANGE; - } - if (StringUtil::Equals(value, "CURRENT_ROW_ROWS")) { - return WindowBoundary::CURRENT_ROW_ROWS; - } - if (StringUtil::Equals(value, "EXPR_PRECEDING_ROWS")) { - return WindowBoundary::EXPR_PRECEDING_ROWS; - } - if (StringUtil::Equals(value, "EXPR_FOLLOWING_ROWS")) { - return WindowBoundary::EXPR_FOLLOWING_ROWS; - } - if (StringUtil::Equals(value, "EXPR_PRECEDING_RANGE")) { - return WindowBoundary::EXPR_PRECEDING_RANGE; - } - if (StringUtil::Equals(value, "EXPR_FOLLOWING_RANGE")) { - return WindowBoundary::EXPR_FOLLOWING_RANGE; - } - throw 
NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); -} - -} - - - - - -namespace duckdb { - -// LCOV_EXCL_START -string CatalogTypeToString(CatalogType type) { - switch (type) { - case CatalogType::COLLATION_ENTRY: - return "Collation"; - case CatalogType::TYPE_ENTRY: - return "Type"; - case CatalogType::TABLE_ENTRY: - return "Table"; - case CatalogType::SCHEMA_ENTRY: - return "Schema"; - case CatalogType::DATABASE_ENTRY: - return "Database"; - case CatalogType::TABLE_FUNCTION_ENTRY: - return "Table Function"; - case CatalogType::SCALAR_FUNCTION_ENTRY: - return "Scalar Function"; - case CatalogType::AGGREGATE_FUNCTION_ENTRY: - return "Aggregate Function"; - case CatalogType::COPY_FUNCTION_ENTRY: - return "Copy Function"; - case CatalogType::PRAGMA_FUNCTION_ENTRY: - return "Pragma Function"; - case CatalogType::MACRO_ENTRY: - return "Macro Function"; - case CatalogType::TABLE_MACRO_ENTRY: - return "Table Macro Function"; - case CatalogType::VIEW_ENTRY: - return "View"; - case CatalogType::INDEX_ENTRY: - return "Index"; - case CatalogType::PREPARED_STATEMENT: - return "Prepared Statement"; - case CatalogType::SEQUENCE_ENTRY: - return "Sequence"; - case CatalogType::INVALID: - case CatalogType::DELETED_ENTRY: - case CatalogType::UPDATED_ENTRY: - break; - } - return "INVALID"; -} -// LCOV_EXCL_STOP - -} // namespace duckdb - - - - -namespace duckdb { - -// LCOV_EXCL_START - -vector ListCompressionTypes(void) { - vector compression_types; - uint8_t amount_of_compression_options = (uint8_t)CompressionType::COMPRESSION_COUNT; - compression_types.reserve(amount_of_compression_options); - for (uint8_t i = 0; i < amount_of_compression_options; i++) { - compression_types.push_back(CompressionTypeToString((CompressionType)i)); - } - return compression_types; -} - -CompressionType CompressionTypeFromString(const string &str) { - auto compression = StringUtil::Lower(str); - if (compression == "uncompressed") { - return 
CompressionType::COMPRESSION_UNCOMPRESSED; - } else if (compression == "rle") { - return CompressionType::COMPRESSION_RLE; - } else if (compression == "dictionary") { - return CompressionType::COMPRESSION_DICTIONARY; - } else if (compression == "pfor") { - return CompressionType::COMPRESSION_PFOR_DELTA; - } else if (compression == "bitpacking") { - return CompressionType::COMPRESSION_BITPACKING; - } else if (compression == "fsst") { - return CompressionType::COMPRESSION_FSST; - } else if (compression == "chimp") { - return CompressionType::COMPRESSION_CHIMP; - } else if (compression == "patas") { - return CompressionType::COMPRESSION_PATAS; - } else { - return CompressionType::COMPRESSION_AUTO; - } -} - -string CompressionTypeToString(CompressionType type) { - switch (type) { - case CompressionType::COMPRESSION_AUTO: - return "Auto"; - case CompressionType::COMPRESSION_UNCOMPRESSED: - return "Uncompressed"; - case CompressionType::COMPRESSION_CONSTANT: - return "Constant"; - case CompressionType::COMPRESSION_RLE: - return "RLE"; - case CompressionType::COMPRESSION_DICTIONARY: - return "Dictionary"; - case CompressionType::COMPRESSION_PFOR_DELTA: - return "PFOR"; - case CompressionType::COMPRESSION_BITPACKING: - return "BitPacking"; - case CompressionType::COMPRESSION_FSST: - return "FSST"; - case CompressionType::COMPRESSION_CHIMP: - return "Chimp"; - case CompressionType::COMPRESSION_PATAS: - return "Patas"; - default: - throw InternalException("Unrecognized compression type!"); - } -} -// LCOV_EXCL_STOP - -} // namespace duckdb - - - -namespace duckdb { - -bool TryGetDatePartSpecifier(const string &specifier_p, DatePartSpecifier &result) { - auto specifier = StringUtil::Lower(specifier_p); - if (specifier == "year" || specifier == "yr" || specifier == "y" || specifier == "years" || specifier == "yrs") { - result = DatePartSpecifier::YEAR; - } else if (specifier == "month" || specifier == "mon" || specifier == "months" || specifier == "mons") { - result = 
DatePartSpecifier::MONTH; - } else if (specifier == "day" || specifier == "days" || specifier == "d" || specifier == "dayofmonth") { - result = DatePartSpecifier::DAY; - } else if (specifier == "decade" || specifier == "dec" || specifier == "decades" || specifier == "decs") { - result = DatePartSpecifier::DECADE; - } else if (specifier == "century" || specifier == "cent" || specifier == "centuries" || specifier == "c") { - result = DatePartSpecifier::CENTURY; - } else if (specifier == "millennium" || specifier == "mil" || specifier == "millenniums" || - specifier == "millennia" || specifier == "mils" || specifier == "millenium") { - result = DatePartSpecifier::MILLENNIUM; - } else if (specifier == "microseconds" || specifier == "microsecond" || specifier == "us" || specifier == "usec" || - specifier == "usecs" || specifier == "usecond" || specifier == "useconds") { - result = DatePartSpecifier::MICROSECONDS; - } else if (specifier == "milliseconds" || specifier == "millisecond" || specifier == "ms" || specifier == "msec" || - specifier == "msecs" || specifier == "msecond" || specifier == "mseconds") { - result = DatePartSpecifier::MILLISECONDS; - } else if (specifier == "second" || specifier == "sec" || specifier == "seconds" || specifier == "secs" || - specifier == "s") { - result = DatePartSpecifier::SECOND; - } else if (specifier == "minute" || specifier == "min" || specifier == "minutes" || specifier == "mins" || - specifier == "m") { - result = DatePartSpecifier::MINUTE; - } else if (specifier == "hour" || specifier == "hr" || specifier == "hours" || specifier == "hrs" || - specifier == "h") { - result = DatePartSpecifier::HOUR; - } else if (specifier == "epoch") { - // seconds since 1970-01-01 - result = DatePartSpecifier::EPOCH; - } else if (specifier == "dow" || specifier == "dayofweek" || specifier == "weekday") { - // day of the week (Sunday = 0, Saturday = 6) - result = DatePartSpecifier::DOW; - } else if (specifier == "isodow") { - // isodow (Monday = 
1, Sunday = 7) - result = DatePartSpecifier::ISODOW; - } else if (specifier == "week" || specifier == "weeks" || specifier == "w" || specifier == "weekofyear") { - // ISO week number - result = DatePartSpecifier::WEEK; - } else if (specifier == "doy" || specifier == "dayofyear") { - // day of the year (1-365/366) - result = DatePartSpecifier::DOY; - } else if (specifier == "quarter" || specifier == "quarters") { - // quarter of the year (1-4) - result = DatePartSpecifier::QUARTER; - } else if (specifier == "yearweek") { - // Combined isoyear and isoweek YYYYWW - result = DatePartSpecifier::YEARWEEK; - } else if (specifier == "isoyear") { - // ISO year (first week of the year may be in previous year) - result = DatePartSpecifier::ISOYEAR; - } else if (specifier == "era") { - result = DatePartSpecifier::ERA; - } else if (specifier == "timezone") { - result = DatePartSpecifier::TIMEZONE; - } else if (specifier == "timezone_hour") { - result = DatePartSpecifier::TIMEZONE_HOUR; - } else if (specifier == "timezone_minute") { - result = DatePartSpecifier::TIMEZONE_MINUTE; - } else if (specifier == "julian" || specifier == "jd") { - result = DatePartSpecifier::JULIAN_DAY; - } else { - return false; - } - return true; -} - -DatePartSpecifier GetDatePartSpecifier(const string &specifier) { - DatePartSpecifier result; - if (!TryGetDatePartSpecifier(specifier, result)) { - throw ConversionException("extract specifier \"%s\" not recognized", specifier); - } - return result; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -string ExpressionTypeToString(ExpressionType type) { - switch (type) { - case ExpressionType::OPERATOR_CAST: - return "CAST"; - case ExpressionType::OPERATOR_NOT: - return "NOT"; - case ExpressionType::OPERATOR_IS_NULL: - return "IS_NULL"; - case ExpressionType::OPERATOR_IS_NOT_NULL: - return "IS_NOT_NULL"; - case ExpressionType::COMPARE_EQUAL: - return "EQUAL"; - case ExpressionType::COMPARE_NOTEQUAL: - return "NOTEQUAL"; - case 
ExpressionType::COMPARE_LESSTHAN: - return "LESSTHAN"; - case ExpressionType::COMPARE_GREATERTHAN: - return "GREATERTHAN"; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return "LESSTHANOREQUALTO"; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return "GREATERTHANOREQUALTO"; - case ExpressionType::COMPARE_IN: - return "IN"; - case ExpressionType::COMPARE_DISTINCT_FROM: - return "DISTINCT_FROM"; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - return "NOT_DISTINCT_FROM"; - case ExpressionType::CONJUNCTION_AND: - return "AND"; - case ExpressionType::CONJUNCTION_OR: - return "OR"; - case ExpressionType::VALUE_CONSTANT: - return "CONSTANT"; - case ExpressionType::VALUE_PARAMETER: - return "PARAMETER"; - case ExpressionType::VALUE_TUPLE: - return "TUPLE"; - case ExpressionType::VALUE_TUPLE_ADDRESS: - return "TUPLE_ADDRESS"; - case ExpressionType::VALUE_NULL: - return "NULL"; - case ExpressionType::VALUE_VECTOR: - return "VECTOR"; - case ExpressionType::VALUE_SCALAR: - return "SCALAR"; - case ExpressionType::AGGREGATE: - return "AGGREGATE"; - case ExpressionType::WINDOW_AGGREGATE: - return "WINDOW_AGGREGATE"; - case ExpressionType::WINDOW_RANK: - return "RANK"; - case ExpressionType::WINDOW_RANK_DENSE: - return "RANK_DENSE"; - case ExpressionType::WINDOW_PERCENT_RANK: - return "PERCENT_RANK"; - case ExpressionType::WINDOW_ROW_NUMBER: - return "ROW_NUMBER"; - case ExpressionType::WINDOW_FIRST_VALUE: - return "FIRST_VALUE"; - case ExpressionType::WINDOW_LAST_VALUE: - return "LAST_VALUE"; - case ExpressionType::WINDOW_NTH_VALUE: - return "NTH_VALUE"; - case ExpressionType::WINDOW_CUME_DIST: - return "CUME_DIST"; - case ExpressionType::WINDOW_LEAD: - return "LEAD"; - case ExpressionType::WINDOW_LAG: - return "LAG"; - case ExpressionType::WINDOW_NTILE: - return "NTILE"; - case ExpressionType::FUNCTION: - return "FUNCTION"; - case ExpressionType::CASE_EXPR: - return "CASE"; - case ExpressionType::OPERATOR_NULLIF: - return "NULLIF"; - case 
ExpressionType::OPERATOR_COALESCE: - return "COALESCE"; - case ExpressionType::ARRAY_EXTRACT: - return "ARRAY_EXTRACT"; - case ExpressionType::ARRAY_SLICE: - return "ARRAY_SLICE"; - case ExpressionType::STRUCT_EXTRACT: - return "STRUCT_EXTRACT"; - case ExpressionType::SUBQUERY: - return "SUBQUERY"; - case ExpressionType::STAR: - return "STAR"; - case ExpressionType::PLACEHOLDER: - return "PLACEHOLDER"; - case ExpressionType::COLUMN_REF: - return "COLUMN_REF"; - case ExpressionType::FUNCTION_REF: - return "FUNCTION_REF"; - case ExpressionType::TABLE_REF: - return "TABLE_REF"; - case ExpressionType::CAST: - return "CAST"; - case ExpressionType::COMPARE_NOT_IN: - return "COMPARE_NOT_IN"; - case ExpressionType::COMPARE_BETWEEN: - return "COMPARE_BETWEEN"; - case ExpressionType::COMPARE_NOT_BETWEEN: - return "COMPARE_NOT_BETWEEN"; - case ExpressionType::VALUE_DEFAULT: - return "VALUE_DEFAULT"; - case ExpressionType::BOUND_REF: - return "BOUND_REF"; - case ExpressionType::BOUND_COLUMN_REF: - return "BOUND_COLUMN_REF"; - case ExpressionType::BOUND_FUNCTION: - return "BOUND_FUNCTION"; - case ExpressionType::BOUND_AGGREGATE: - return "BOUND_AGGREGATE"; - case ExpressionType::GROUPING_FUNCTION: - return "GROUPING"; - case ExpressionType::ARRAY_CONSTRUCTOR: - return "ARRAY_CONSTRUCTOR"; - case ExpressionType::TABLE_STAR: - return "TABLE_STAR"; - case ExpressionType::BOUND_UNNEST: - return "BOUND_UNNEST"; - case ExpressionType::COLLATE: - return "COLLATE"; - case ExpressionType::POSITIONAL_REFERENCE: - return "POSITIONAL_REFERENCE"; - case ExpressionType::BOUND_LAMBDA_REF: - return "BOUND_LAMBDA_REF"; - case ExpressionType::LAMBDA: - return "LAMBDA"; - case ExpressionType::ARROW: - return "ARROW"; - case ExpressionType::INVALID: - break; - } - return "INVALID"; -} -string ExpressionClassToString(ExpressionClass type) { - switch (type) { - case ExpressionClass::INVALID: - return "INVALID"; - case ExpressionClass::AGGREGATE: - return "AGGREGATE"; - case ExpressionClass::CASE: - 
return "CASE"; - case ExpressionClass::CAST: - return "CAST"; - case ExpressionClass::COLUMN_REF: - return "COLUMN_REF"; - case ExpressionClass::COMPARISON: - return "COMPARISON"; - case ExpressionClass::CONJUNCTION: - return "CONJUNCTION"; - case ExpressionClass::CONSTANT: - return "CONSTANT"; - case ExpressionClass::DEFAULT: - return "DEFAULT"; - case ExpressionClass::FUNCTION: - return "FUNCTION"; - case ExpressionClass::OPERATOR: - return "OPERATOR"; - case ExpressionClass::STAR: - return "STAR"; - case ExpressionClass::SUBQUERY: - return "SUBQUERY"; - case ExpressionClass::WINDOW: - return "WINDOW"; - case ExpressionClass::PARAMETER: - return "PARAMETER"; - case ExpressionClass::COLLATE: - return "COLLATE"; - case ExpressionClass::LAMBDA: - return "LAMBDA"; - case ExpressionClass::POSITIONAL_REFERENCE: - return "POSITIONAL_REFERENCE"; - case ExpressionClass::BETWEEN: - return "BETWEEN"; - case ExpressionClass::BOUND_AGGREGATE: - return "BOUND_AGGREGATE"; - case ExpressionClass::BOUND_CASE: - return "BOUND_CASE"; - case ExpressionClass::BOUND_CAST: - return "BOUND_CAST"; - case ExpressionClass::BOUND_COLUMN_REF: - return "BOUND_COLUMN_REF"; - case ExpressionClass::BOUND_COMPARISON: - return "BOUND_COMPARISON"; - case ExpressionClass::BOUND_CONJUNCTION: - return "BOUND_CONJUNCTION"; - case ExpressionClass::BOUND_CONSTANT: - return "BOUND_CONSTANT"; - case ExpressionClass::BOUND_DEFAULT: - return "BOUND_DEFAULT"; - case ExpressionClass::BOUND_FUNCTION: - return "BOUND_FUNCTION"; - case ExpressionClass::BOUND_OPERATOR: - return "BOUND_OPERATOR"; - case ExpressionClass::BOUND_PARAMETER: - return "BOUND_PARAMETER"; - case ExpressionClass::BOUND_REF: - return "BOUND_REF"; - case ExpressionClass::BOUND_SUBQUERY: - return "BOUND_SUBQUERY"; - case ExpressionClass::BOUND_WINDOW: - return "BOUND_WINDOW"; - case ExpressionClass::BOUND_BETWEEN: - return "BOUND_BETWEEN"; - case ExpressionClass::BOUND_UNNEST: - return "BOUND_UNNEST"; - case ExpressionClass::BOUND_LAMBDA: - 
return "BOUND_LAMBDA"; - case ExpressionClass::BOUND_EXPRESSION: - return "BOUND_EXPRESSION"; - default: - return "ExpressionClass::!!UNIMPLEMENTED_CASE!!"; - } -} - -string ExpressionTypeToOperator(ExpressionType type) { - switch (type) { - case ExpressionType::COMPARE_EQUAL: - return "="; - case ExpressionType::COMPARE_NOTEQUAL: - return "!="; - case ExpressionType::COMPARE_LESSTHAN: - return "<"; - case ExpressionType::COMPARE_GREATERTHAN: - return ">"; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return "<="; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return ">="; - case ExpressionType::COMPARE_DISTINCT_FROM: - return "IS DISTINCT FROM"; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - return "IS NOT DISTINCT FROM"; - case ExpressionType::CONJUNCTION_AND: - return "AND"; - case ExpressionType::CONJUNCTION_OR: - return "OR"; - default: - return ""; - } -} - -ExpressionType NegateComparisonExpression(ExpressionType type) { - ExpressionType negated_type = ExpressionType::INVALID; - switch (type) { - case ExpressionType::COMPARE_EQUAL: - negated_type = ExpressionType::COMPARE_NOTEQUAL; - break; - case ExpressionType::COMPARE_NOTEQUAL: - negated_type = ExpressionType::COMPARE_EQUAL; - break; - case ExpressionType::COMPARE_LESSTHAN: - negated_type = ExpressionType::COMPARE_GREATERTHANOREQUALTO; - break; - case ExpressionType::COMPARE_GREATERTHAN: - negated_type = ExpressionType::COMPARE_LESSTHANOREQUALTO; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - negated_type = ExpressionType::COMPARE_GREATERTHAN; - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - negated_type = ExpressionType::COMPARE_LESSTHAN; - break; - default: - throw InternalException("Unsupported comparison type in negation"); - } - return negated_type; -} - -ExpressionType FlipComparisonExpression(ExpressionType type) { - ExpressionType flipped_type = ExpressionType::INVALID; - switch (type) { - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - case 
ExpressionType::COMPARE_DISTINCT_FROM: - case ExpressionType::COMPARE_NOTEQUAL: - case ExpressionType::COMPARE_EQUAL: - flipped_type = type; - break; - case ExpressionType::COMPARE_LESSTHAN: - flipped_type = ExpressionType::COMPARE_GREATERTHAN; - break; - case ExpressionType::COMPARE_GREATERTHAN: - flipped_type = ExpressionType::COMPARE_LESSTHAN; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - flipped_type = ExpressionType::COMPARE_GREATERTHANOREQUALTO; - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - flipped_type = ExpressionType::COMPARE_LESSTHANOREQUALTO; - break; - default: - throw InternalException("Unsupported comparison type in flip"); - } - return flipped_type; -} - -ExpressionType OperatorToExpressionType(const string &op) { - if (op == "=" || op == "==") { - return ExpressionType::COMPARE_EQUAL; - } else if (op == "!=" || op == "<>") { - return ExpressionType::COMPARE_NOTEQUAL; - } else if (op == "<") { - return ExpressionType::COMPARE_LESSTHAN; - } else if (op == ">") { - return ExpressionType::COMPARE_GREATERTHAN; - } else if (op == "<=") { - return ExpressionType::COMPARE_LESSTHANOREQUALTO; - } else if (op == ">=") { - return ExpressionType::COMPARE_GREATERTHANOREQUALTO; - } - return ExpressionType::INVALID; -} - -} // namespace duckdb - - - -namespace duckdb { - -FileCompressionType FileCompressionTypeFromString(const string &input) { - auto parameter = StringUtil::Lower(input); - if (parameter == "infer" || parameter == "auto") { - return FileCompressionType::AUTO_DETECT; - } else if (parameter == "gzip") { - return FileCompressionType::GZIP; - } else if (parameter == "zstd") { - return FileCompressionType::ZSTD; - } else if (parameter == "uncompressed" || parameter == "none" || parameter.empty()) { - return FileCompressionType::UNCOMPRESSED; - } else { - throw ParserException("Unrecognized file compression type \"%s\"", input); - } -} - -} // namespace duckdb - - - -namespace duckdb { - -bool IsLeftOuterJoin(JoinType 
type) { - return type == JoinType::LEFT || type == JoinType::OUTER; -} - -bool IsRightOuterJoin(JoinType type) { - return type == JoinType::OUTER || type == JoinType::RIGHT; -} - -// **DEPRECATED**: Use EnumUtil directly instead. -string JoinTypeToString(JoinType type) { - return EnumUtil::ToString(type); -} - -} // namespace duckdb - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Value <--> String Utilities -//===--------------------------------------------------------------------===// -// LCOV_EXCL_START -string LogicalOperatorToString(LogicalOperatorType type) { - switch (type) { - case LogicalOperatorType::LOGICAL_GET: - return "GET"; - case LogicalOperatorType::LOGICAL_CHUNK_GET: - return "CHUNK_GET"; - case LogicalOperatorType::LOGICAL_DELIM_GET: - return "DELIM_GET"; - case LogicalOperatorType::LOGICAL_EMPTY_RESULT: - return "EMPTY_RESULT"; - case LogicalOperatorType::LOGICAL_EXPRESSION_GET: - return "EXPRESSION_GET"; - case LogicalOperatorType::LOGICAL_ANY_JOIN: - return "ANY_JOIN"; - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - return "ASOF_JOIN"; - case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: - return "DEPENDENT_JOIN"; - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - return "COMPARISON_JOIN"; - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - return "DELIM_JOIN"; - case LogicalOperatorType::LOGICAL_PROJECTION: - return "PROJECTION"; - case LogicalOperatorType::LOGICAL_FILTER: - return "FILTER"; - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - return "AGGREGATE"; - case LogicalOperatorType::LOGICAL_WINDOW: - return "WINDOW"; - case LogicalOperatorType::LOGICAL_UNNEST: - return "UNNEST"; - case LogicalOperatorType::LOGICAL_LIMIT: - return "LIMIT"; - case LogicalOperatorType::LOGICAL_ORDER_BY: - return "ORDER_BY"; - case LogicalOperatorType::LOGICAL_TOP_N: - return "TOP_N"; - case LogicalOperatorType::LOGICAL_SAMPLE: - return "SAMPLE"; - case 
LogicalOperatorType::LOGICAL_LIMIT_PERCENT: - return "LIMIT_PERCENT"; - case LogicalOperatorType::LOGICAL_COPY_TO_FILE: - return "COPY_TO_FILE"; - case LogicalOperatorType::LOGICAL_JOIN: - return "JOIN"; - case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: - return "CROSS_PRODUCT"; - case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: - return "POSITIONAL_JOIN"; - case LogicalOperatorType::LOGICAL_UNION: - return "UNION"; - case LogicalOperatorType::LOGICAL_EXCEPT: - return "EXCEPT"; - case LogicalOperatorType::LOGICAL_INTERSECT: - return "INTERSECT"; - case LogicalOperatorType::LOGICAL_INSERT: - return "INSERT"; - case LogicalOperatorType::LOGICAL_DISTINCT: - return "DISTINCT"; - case LogicalOperatorType::LOGICAL_DELETE: - return "DELETE"; - case LogicalOperatorType::LOGICAL_UPDATE: - return "UPDATE"; - case LogicalOperatorType::LOGICAL_PREPARE: - return "PREPARE"; - case LogicalOperatorType::LOGICAL_DUMMY_SCAN: - return "DUMMY_SCAN"; - case LogicalOperatorType::LOGICAL_CREATE_INDEX: - return "CREATE_INDEX"; - case LogicalOperatorType::LOGICAL_CREATE_TABLE: - return "CREATE_TABLE"; - case LogicalOperatorType::LOGICAL_CREATE_MACRO: - return "CREATE_MACRO"; - case LogicalOperatorType::LOGICAL_EXPLAIN: - return "EXPLAIN"; - case LogicalOperatorType::LOGICAL_EXECUTE: - return "EXECUTE"; - case LogicalOperatorType::LOGICAL_VACUUM: - return "VACUUM"; - case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: - return "REC_CTE"; - case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: - return "CTE"; - case LogicalOperatorType::LOGICAL_CTE_REF: - return "CTE_SCAN"; - case LogicalOperatorType::LOGICAL_SHOW: - return "SHOW"; - case LogicalOperatorType::LOGICAL_ALTER: - return "ALTER"; - case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: - return "CREATE_SEQUENCE"; - case LogicalOperatorType::LOGICAL_CREATE_TYPE: - return "CREATE_TYPE"; - case LogicalOperatorType::LOGICAL_CREATE_VIEW: - return "CREATE_VIEW"; - case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: - return "CREATE_SCHEMA"; - 
case LogicalOperatorType::LOGICAL_ATTACH: - return "ATTACH"; - case LogicalOperatorType::LOGICAL_DETACH: - return "ATTACH"; - case LogicalOperatorType::LOGICAL_DROP: - return "DROP"; - case LogicalOperatorType::LOGICAL_PRAGMA: - return "PRAGMA"; - case LogicalOperatorType::LOGICAL_TRANSACTION: - return "TRANSACTION"; - case LogicalOperatorType::LOGICAL_EXPORT: - return "EXPORT"; - case LogicalOperatorType::LOGICAL_SET: - return "SET"; - case LogicalOperatorType::LOGICAL_RESET: - return "RESET"; - case LogicalOperatorType::LOGICAL_LOAD: - return "LOAD"; - case LogicalOperatorType::LOGICAL_INVALID: - break; - case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: - return "CUSTOM_OP"; - case LogicalOperatorType::LOGICAL_PIVOT: - return "PIVOT"; - } - return "INVALID"; -} -// LCOV_EXCL_STOP - -} // namespace duckdb - - - - - -namespace duckdb { - -struct DefaultOptimizerType { - const char *name; - OptimizerType type; -}; - -static DefaultOptimizerType internal_optimizer_types[] = { - {"expression_rewriter", OptimizerType::EXPRESSION_REWRITER}, - {"filter_pullup", OptimizerType::FILTER_PULLUP}, - {"filter_pushdown", OptimizerType::FILTER_PUSHDOWN}, - {"regex_range", OptimizerType::REGEX_RANGE}, - {"in_clause", OptimizerType::IN_CLAUSE}, - {"join_order", OptimizerType::JOIN_ORDER}, - {"deliminator", OptimizerType::DELIMINATOR}, - {"unnest_rewriter", OptimizerType::UNNEST_REWRITER}, - {"unused_columns", OptimizerType::UNUSED_COLUMNS}, - {"statistics_propagation", OptimizerType::STATISTICS_PROPAGATION}, - {"common_subexpressions", OptimizerType::COMMON_SUBEXPRESSIONS}, - {"common_aggregate", OptimizerType::COMMON_AGGREGATE}, - {"column_lifetime", OptimizerType::COLUMN_LIFETIME}, - {"top_n", OptimizerType::TOP_N}, - {"compressed_materialization", OptimizerType::COMPRESSED_MATERIALIZATION}, - {"duplicate_groups", OptimizerType::DUPLICATE_GROUPS}, - {"reorder_filter", OptimizerType::REORDER_FILTER}, - {"extension", OptimizerType::EXTENSION}, - {nullptr, 
OptimizerType::INVALID}}; - -string OptimizerTypeToString(OptimizerType type) { - for (idx_t i = 0; internal_optimizer_types[i].name; i++) { - if (internal_optimizer_types[i].type == type) { - return internal_optimizer_types[i].name; - } - } - throw InternalException("Invalid optimizer type"); -} - -OptimizerType OptimizerTypeFromString(const string &str) { - for (idx_t i = 0; internal_optimizer_types[i].name; i++) { - if (internal_optimizer_types[i].name == str) { - return internal_optimizer_types[i].type; - } - } - // optimizer not found, construct candidate list - vector optimizer_names; - for (idx_t i = 0; internal_optimizer_types[i].name; i++) { - optimizer_names.emplace_back(internal_optimizer_types[i].name); - } - throw ParserException("Optimizer type \"%s\" not recognized\n%s", str, - StringUtil::CandidatesErrorMessage(optimizer_names, str, "Candidate optimizers")); -} - -} // namespace duckdb - - -namespace duckdb { - -// LCOV_EXCL_START -string PhysicalOperatorToString(PhysicalOperatorType type) { - switch (type) { - case PhysicalOperatorType::TABLE_SCAN: - return "TABLE_SCAN"; - case PhysicalOperatorType::DUMMY_SCAN: - return "DUMMY_SCAN"; - case PhysicalOperatorType::CHUNK_SCAN: - return "CHUNK_SCAN"; - case PhysicalOperatorType::COLUMN_DATA_SCAN: - return "COLUMN_DATA_SCAN"; - case PhysicalOperatorType::DELIM_SCAN: - return "DELIM_SCAN"; - case PhysicalOperatorType::ORDER_BY: - return "ORDER_BY"; - case PhysicalOperatorType::LIMIT: - return "LIMIT"; - case PhysicalOperatorType::LIMIT_PERCENT: - return "LIMIT_PERCENT"; - case PhysicalOperatorType::STREAMING_LIMIT: - return "STREAMING_LIMIT"; - case PhysicalOperatorType::RESERVOIR_SAMPLE: - return "RESERVOIR_SAMPLE"; - case PhysicalOperatorType::STREAMING_SAMPLE: - return "STREAMING_SAMPLE"; - case PhysicalOperatorType::TOP_N: - return "TOP_N"; - case PhysicalOperatorType::WINDOW: - return "WINDOW"; - case PhysicalOperatorType::STREAMING_WINDOW: - return "STREAMING_WINDOW"; - case 
PhysicalOperatorType::UNNEST: - return "UNNEST"; - case PhysicalOperatorType::UNGROUPED_AGGREGATE: - return "UNGROUPED_AGGREGATE"; - case PhysicalOperatorType::HASH_GROUP_BY: - return "HASH_GROUP_BY"; - case PhysicalOperatorType::PERFECT_HASH_GROUP_BY: - return "PERFECT_HASH_GROUP_BY"; - case PhysicalOperatorType::FILTER: - return "FILTER"; - case PhysicalOperatorType::PROJECTION: - return "PROJECTION"; - case PhysicalOperatorType::COPY_TO_FILE: - return "COPY_TO_FILE"; - case PhysicalOperatorType::BATCH_COPY_TO_FILE: - return "BATCH_COPY_TO_FILE"; - case PhysicalOperatorType::FIXED_BATCH_COPY_TO_FILE: - return "FIXED_BATCH_COPY_TO_FILE"; - case PhysicalOperatorType::DELIM_JOIN: - return "DELIM_JOIN"; - case PhysicalOperatorType::BLOCKWISE_NL_JOIN: - return "BLOCKWISE_NL_JOIN"; - case PhysicalOperatorType::NESTED_LOOP_JOIN: - return "NESTED_LOOP_JOIN"; - case PhysicalOperatorType::HASH_JOIN: - return "HASH_JOIN"; - case PhysicalOperatorType::INDEX_JOIN: - return "INDEX_JOIN"; - case PhysicalOperatorType::PIECEWISE_MERGE_JOIN: - return "PIECEWISE_MERGE_JOIN"; - case PhysicalOperatorType::IE_JOIN: - return "IE_JOIN"; - case PhysicalOperatorType::ASOF_JOIN: - return "ASOF_JOIN"; - case PhysicalOperatorType::CROSS_PRODUCT: - return "CROSS_PRODUCT"; - case PhysicalOperatorType::POSITIONAL_JOIN: - return "POSITIONAL_JOIN"; - case PhysicalOperatorType::POSITIONAL_SCAN: - return "POSITIONAL_SCAN"; - case PhysicalOperatorType::UNION: - return "UNION"; - case PhysicalOperatorType::INSERT: - return "INSERT"; - case PhysicalOperatorType::BATCH_INSERT: - return "BATCH_INSERT"; - case PhysicalOperatorType::DELETE_OPERATOR: - return "DELETE"; - case PhysicalOperatorType::UPDATE: - return "UPDATE"; - case PhysicalOperatorType::EMPTY_RESULT: - return "EMPTY_RESULT"; - case PhysicalOperatorType::CREATE_TABLE: - return "CREATE_TABLE"; - case PhysicalOperatorType::CREATE_TABLE_AS: - return "CREATE_TABLE_AS"; - case PhysicalOperatorType::BATCH_CREATE_TABLE_AS: - return 
"BATCH_CREATE_TABLE_AS"; - case PhysicalOperatorType::CREATE_INDEX: - return "CREATE_INDEX"; - case PhysicalOperatorType::EXPLAIN: - return "EXPLAIN"; - case PhysicalOperatorType::EXPLAIN_ANALYZE: - return "EXPLAIN_ANALYZE"; - case PhysicalOperatorType::EXECUTE: - return "EXECUTE"; - case PhysicalOperatorType::VACUUM: - return "VACUUM"; - case PhysicalOperatorType::RECURSIVE_CTE: - return "REC_CTE"; - case PhysicalOperatorType::CTE: - return "CTE"; - case PhysicalOperatorType::RECURSIVE_CTE_SCAN: - return "REC_CTE_SCAN"; - case PhysicalOperatorType::CTE_SCAN: - return "CTE_SCAN"; - case PhysicalOperatorType::EXPRESSION_SCAN: - return "EXPRESSION_SCAN"; - case PhysicalOperatorType::ALTER: - return "ALTER"; - case PhysicalOperatorType::CREATE_SEQUENCE: - return "CREATE_SEQUENCE"; - case PhysicalOperatorType::CREATE_VIEW: - return "CREATE_VIEW"; - case PhysicalOperatorType::CREATE_SCHEMA: - return "CREATE_SCHEMA"; - case PhysicalOperatorType::CREATE_MACRO: - return "CREATE_MACRO"; - case PhysicalOperatorType::DROP: - return "DROP"; - case PhysicalOperatorType::PRAGMA: - return "PRAGMA"; - case PhysicalOperatorType::TRANSACTION: - return "TRANSACTION"; - case PhysicalOperatorType::PREPARE: - return "PREPARE"; - case PhysicalOperatorType::EXPORT: - return "EXPORT"; - case PhysicalOperatorType::SET: - return "SET"; - case PhysicalOperatorType::RESET: - return "RESET"; - case PhysicalOperatorType::LOAD: - return "LOAD"; - case PhysicalOperatorType::INOUT_FUNCTION: - return "INOUT_FUNCTION"; - case PhysicalOperatorType::CREATE_TYPE: - return "CREATE_TYPE"; - case PhysicalOperatorType::ATTACH: - return "ATTACH"; - case PhysicalOperatorType::DETACH: - return "DETACH"; - case PhysicalOperatorType::RESULT_COLLECTOR: - return "RESULT_COLLECTOR"; - case PhysicalOperatorType::EXTENSION: - return "EXTENSION"; - case PhysicalOperatorType::PIVOT: - return "PIVOT"; - case PhysicalOperatorType::INVALID: - break; - } - return "INVALID"; -} -// LCOV_EXCL_STOP - -} // namespace duckdb - 
- - - -namespace duckdb { - -// LCOV_EXCL_START -string RelationTypeToString(RelationType type) { - switch (type) { - case RelationType::TABLE_RELATION: - return "TABLE_RELATION"; - case RelationType::PROJECTION_RELATION: - return "PROJECTION_RELATION"; - case RelationType::FILTER_RELATION: - return "FILTER_RELATION"; - case RelationType::EXPLAIN_RELATION: - return "EXPLAIN_RELATION"; - case RelationType::CROSS_PRODUCT_RELATION: - return "CROSS_PRODUCT_RELATION"; - case RelationType::JOIN_RELATION: - return "JOIN_RELATION"; - case RelationType::AGGREGATE_RELATION: - return "AGGREGATE_RELATION"; - case RelationType::SET_OPERATION_RELATION: - return "SET_OPERATION_RELATION"; - case RelationType::DISTINCT_RELATION: - return "DISTINCT_RELATION"; - case RelationType::LIMIT_RELATION: - return "LIMIT_RELATION"; - case RelationType::ORDER_RELATION: - return "ORDER_RELATION"; - case RelationType::CREATE_VIEW_RELATION: - return "CREATE_VIEW_RELATION"; - case RelationType::CREATE_TABLE_RELATION: - return "CREATE_TABLE_RELATION"; - case RelationType::INSERT_RELATION: - return "INSERT_RELATION"; - case RelationType::VALUE_LIST_RELATION: - return "VALUE_LIST_RELATION"; - case RelationType::DELETE_RELATION: - return "DELETE_RELATION"; - case RelationType::UPDATE_RELATION: - return "UPDATE_RELATION"; - case RelationType::WRITE_CSV_RELATION: - return "WRITE_CSV_RELATION"; - case RelationType::WRITE_PARQUET_RELATION: - return "WRITE_PARQUET_RELATION"; - case RelationType::READ_CSV_RELATION: - return "READ_CSV_RELATION"; - case RelationType::SUBQUERY_RELATION: - return "SUBQUERY_RELATION"; - case RelationType::TABLE_FUNCTION_RELATION: - return "TABLE_FUNCTION_RELATION"; - case RelationType::VIEW_RELATION: - return "VIEW_RELATION"; - case RelationType::QUERY_RELATION: - return "QUERY_RELATION"; - case RelationType::INVALID_RELATION: - break; - } - return "INVALID_RELATION"; -} -// LCOV_EXCL_STOP - -} // namespace duckdb - - -namespace duckdb { - -// LCOV_EXCL_START -string 
StatementTypeToString(StatementType type) { - switch (type) { - case StatementType::SELECT_STATEMENT: - return "SELECT"; - case StatementType::INSERT_STATEMENT: - return "INSERT"; - case StatementType::UPDATE_STATEMENT: - return "UPDATE"; - case StatementType::DELETE_STATEMENT: - return "DELETE"; - case StatementType::PREPARE_STATEMENT: - return "PREPARE"; - case StatementType::EXECUTE_STATEMENT: - return "EXECUTE"; - case StatementType::ALTER_STATEMENT: - return "ALTER"; - case StatementType::TRANSACTION_STATEMENT: - return "TRANSACTION"; - case StatementType::COPY_STATEMENT: - return "COPY"; - case StatementType::ANALYZE_STATEMENT: - return "ANALYZE"; - case StatementType::VARIABLE_SET_STATEMENT: - return "VARIABLE_SET"; - case StatementType::CREATE_FUNC_STATEMENT: - return "CREATE_FUNC"; - case StatementType::EXPLAIN_STATEMENT: - return "EXPLAIN"; - case StatementType::CREATE_STATEMENT: - return "CREATE"; - case StatementType::DROP_STATEMENT: - return "DROP"; - case StatementType::PRAGMA_STATEMENT: - return "PRAGMA"; - case StatementType::SHOW_STATEMENT: - return "SHOW"; - case StatementType::VACUUM_STATEMENT: - return "VACUUM"; - case StatementType::RELATION_STATEMENT: - return "RELATION"; - case StatementType::EXPORT_STATEMENT: - return "EXPORT"; - case StatementType::CALL_STATEMENT: - return "CALL"; - case StatementType::SET_STATEMENT: - return "SET"; - case StatementType::LOAD_STATEMENT: - return "LOAD"; - case StatementType::EXTENSION_STATEMENT: - return "EXTENSION"; - case StatementType::LOGICAL_PLAN_STATEMENT: - return "LOGICAL_PLAN"; - case StatementType::ATTACH_STATEMENT: - return "ATTACH"; - case StatementType::DETACH_STATEMENT: - return "DETACH"; - case StatementType::MULTI_STATEMENT: - return "MULTI"; - case StatementType::INVALID_STATEMENT: - break; - } - return "INVALID"; -} - -string StatementReturnTypeToString(StatementReturnType type) { - switch (type) { - case StatementReturnType::QUERY_RESULT: - return "QUERY_RESULT"; - case 
StatementReturnType::CHANGED_ROWS: - return "CHANGED_ROWS"; - case StatementReturnType::NOTHING: - return "NOTHING"; - } - return "INVALID"; -} -// LCOV_EXCL_STOP - -} // namespace duckdb - - - - - - -#ifdef DUCKDB_CRASH_ON_ASSERT - -#include -#include -#endif -#ifdef DUCKDB_DEBUG_STACKTRACE -#include -#endif - -namespace duckdb { - -Exception::Exception(const string &msg) : std::exception(), type(ExceptionType::INVALID), raw_message_(msg) { - exception_message_ = msg; -} - -Exception::Exception(ExceptionType exception_type, const string &message) - : std::exception(), type(exception_type), raw_message_(message) { - exception_message_ = ExceptionTypeToString(exception_type) + " Error: " + message; -} - -const char *Exception::what() const noexcept { - return exception_message_.c_str(); -} - -const string &Exception::RawMessage() const { - return raw_message_; -} - -bool Exception::UncaughtException() { -#if __cplusplus >= 201703L - return std::uncaught_exceptions() > 0; -#else - return std::uncaught_exception(); -#endif -} - -string Exception::GetStackTrace(int max_depth) { -#ifdef DUCKDB_DEBUG_STACKTRACE - string result; - auto callstack = unique_ptr(new void *[max_depth]); - int frames = backtrace(callstack.get(), max_depth); - char **strs = backtrace_symbols(callstack.get(), frames); - for (int i = 0; i < frames; i++) { - result += strs[i]; - result += "\n"; - } - free(strs); - return "\n" + result; -#else - // Stack trace not available. Toggle DUCKDB_DEBUG_STACKTRACE in exception.cpp to enable stack traces. 
- return ""; -#endif -} - -string Exception::ConstructMessageRecursive(const string &msg, std::vector &values) { -#ifdef DEBUG - // Verify that we have the required amount of values for the message - idx_t parameter_count = 0; - for (idx_t i = 0; i + 1 < msg.size(); i++) { - if (msg[i] != '%') { - continue; - } - if (msg[i + 1] == '%') { - i++; - continue; - } - parameter_count++; - } - if (parameter_count != values.size()) { - throw InternalException("Primary exception: %s\nSecondary exception in ConstructMessageRecursive: Expected %d " - "parameters, received %d", - msg.c_str(), parameter_count, values.size()); - } - -#endif - return ExceptionFormatValue::Format(msg, values); -} - -string Exception::ExceptionTypeToString(ExceptionType type) { - switch (type) { - case ExceptionType::INVALID: - return "Invalid"; - case ExceptionType::OUT_OF_RANGE: - return "Out of Range"; - case ExceptionType::CONVERSION: - return "Conversion"; - case ExceptionType::UNKNOWN_TYPE: - return "Unknown Type"; - case ExceptionType::DECIMAL: - return "Decimal"; - case ExceptionType::MISMATCH_TYPE: - return "Mismatch Type"; - case ExceptionType::DIVIDE_BY_ZERO: - return "Divide by Zero"; - case ExceptionType::OBJECT_SIZE: - return "Object Size"; - case ExceptionType::INVALID_TYPE: - return "Invalid type"; - case ExceptionType::SERIALIZATION: - return "Serialization"; - case ExceptionType::TRANSACTION: - return "TransactionContext"; - case ExceptionType::NOT_IMPLEMENTED: - return "Not implemented"; - case ExceptionType::EXPRESSION: - return "Expression"; - case ExceptionType::CATALOG: - return "Catalog"; - case ExceptionType::PARSER: - return "Parser"; - case ExceptionType::BINDER: - return "Binder"; - case ExceptionType::PLANNER: - return "Planner"; - case ExceptionType::SCHEDULER: - return "Scheduler"; - case ExceptionType::EXECUTOR: - return "Executor"; - case ExceptionType::CONSTRAINT: - return "Constraint"; - case ExceptionType::INDEX: - return "Index"; - case ExceptionType::STAT: - 
return "Stat"; - case ExceptionType::CONNECTION: - return "Connection"; - case ExceptionType::SYNTAX: - return "Syntax"; - case ExceptionType::SETTINGS: - return "Settings"; - case ExceptionType::OPTIMIZER: - return "Optimizer"; - case ExceptionType::NULL_POINTER: - return "NullPointer"; - case ExceptionType::IO: - return "IO"; - case ExceptionType::INTERRUPT: - return "INTERRUPT"; - case ExceptionType::FATAL: - return "FATAL"; - case ExceptionType::INTERNAL: - return "INTERNAL"; - case ExceptionType::INVALID_INPUT: - return "Invalid Input"; - case ExceptionType::OUT_OF_MEMORY: - return "Out of Memory"; - case ExceptionType::PERMISSION: - return "Permission"; - case ExceptionType::PARAMETER_NOT_RESOLVED: - return "Parameter Not Resolved"; - case ExceptionType::PARAMETER_NOT_ALLOWED: - return "Parameter Not Allowed"; - case ExceptionType::DEPENDENCY: - return "Dependency"; - case ExceptionType::MISSING_EXTENSION: - return "Missing Extension"; - case ExceptionType::HTTP: - return "HTTP"; - case ExceptionType::AUTOLOAD: - return "Extension Autoloading"; - default: - return "Unknown"; - } -} - -const HTTPException &Exception::AsHTTPException() const { - D_ASSERT(type == ExceptionType::HTTP); - const auto &e = static_cast(this); - D_ASSERT(e->GetStatusCode() != 0); - D_ASSERT(e->GetHeaders().size() > 0); - return *e; -} - -void Exception::ThrowAsTypeWithMessage(ExceptionType type, const string &message, - const std::shared_ptr &original) { - switch (type) { - case ExceptionType::OUT_OF_RANGE: - throw OutOfRangeException(message); - case ExceptionType::CONVERSION: - throw ConversionException(message); // FIXME: make a separation between Conversion/Cast exception? 
- case ExceptionType::INVALID_TYPE: - throw InvalidTypeException(message); - case ExceptionType::MISMATCH_TYPE: - throw TypeMismatchException(message); - case ExceptionType::TRANSACTION: - throw TransactionException(message); - case ExceptionType::NOT_IMPLEMENTED: - throw NotImplementedException(message); - case ExceptionType::CATALOG: - throw CatalogException(message); - case ExceptionType::CONNECTION: - throw ConnectionException(message); - case ExceptionType::PARSER: - throw ParserException(message); - case ExceptionType::PERMISSION: - throw PermissionException(message); - case ExceptionType::SYNTAX: - throw SyntaxException(message); - case ExceptionType::CONSTRAINT: - throw ConstraintException(message); - case ExceptionType::BINDER: - throw BinderException(message); - case ExceptionType::IO: - throw IOException(message); - case ExceptionType::SERIALIZATION: - throw SerializationException(message); - case ExceptionType::INTERRUPT: - throw InterruptException(); - case ExceptionType::INTERNAL: - throw InternalException(message); - case ExceptionType::INVALID_INPUT: - throw InvalidInputException(message); - case ExceptionType::OUT_OF_MEMORY: - throw OutOfMemoryException(message); - case ExceptionType::PARAMETER_NOT_ALLOWED: - throw ParameterNotAllowedException(message); - case ExceptionType::PARAMETER_NOT_RESOLVED: - throw ParameterNotResolvedException(); - case ExceptionType::FATAL: - throw FatalException(message); - case ExceptionType::DEPENDENCY: - throw DependencyException(message); - case ExceptionType::HTTP: { - original->AsHTTPException().Throw(); - } - case ExceptionType::MISSING_EXTENSION: - throw MissingExtensionException(message); - default: - throw Exception(type, message); - } -} - -StandardException::StandardException(ExceptionType exception_type, const string &message) - : Exception(exception_type, message) { -} - -CastException::CastException(const PhysicalType orig_type, const PhysicalType new_type) - : Exception(ExceptionType::CONVERSION, - "Type 
" + TypeIdToString(orig_type) + " can't be cast as " + TypeIdToString(new_type)) { -} - -CastException::CastException(const LogicalType &orig_type, const LogicalType &new_type) - : Exception(ExceptionType::CONVERSION, - "Type " + orig_type.ToString() + " can't be cast as " + new_type.ToString()) { -} - -CastException::CastException(const string &msg) : Exception(ExceptionType::CONVERSION, msg) { -} - -ValueOutOfRangeException::ValueOutOfRangeException(const int64_t value, const PhysicalType orig_type, - const PhysicalType new_type) - : Exception(ExceptionType::CONVERSION, "Type " + TypeIdToString(orig_type) + " with value " + - to_string((intmax_t)value) + - " can't be cast because the value is out of range " - "for the destination type " + - TypeIdToString(new_type)) { -} - -ValueOutOfRangeException::ValueOutOfRangeException(const double value, const PhysicalType orig_type, - const PhysicalType new_type) - : Exception(ExceptionType::CONVERSION, "Type " + TypeIdToString(orig_type) + " with value " + to_string(value) + - " can't be cast because the value is out of range " - "for the destination type " + - TypeIdToString(new_type)) { -} - -ValueOutOfRangeException::ValueOutOfRangeException(const hugeint_t value, const PhysicalType orig_type, - const PhysicalType new_type) - : Exception(ExceptionType::CONVERSION, "Type " + TypeIdToString(orig_type) + " with value " + value.ToString() + - " can't be cast because the value is out of range " - "for the destination type " + - TypeIdToString(new_type)) { -} - -ValueOutOfRangeException::ValueOutOfRangeException(const PhysicalType var_type, const idx_t length) - : Exception(ExceptionType::OUT_OF_RANGE, - "The value is too long to fit into type " + TypeIdToString(var_type) + "(" + to_string(length) + ")") { -} - -ValueOutOfRangeException::ValueOutOfRangeException(const string &msg) : Exception(ExceptionType::OUT_OF_RANGE, msg) { -} - -ConversionException::ConversionException(const string &msg) : 
Exception(ExceptionType::CONVERSION, msg) { -} - -InvalidTypeException::InvalidTypeException(PhysicalType type, const string &msg) - : Exception(ExceptionType::INVALID_TYPE, "Invalid Type [" + TypeIdToString(type) + "]: " + msg) { -} - -InvalidTypeException::InvalidTypeException(const LogicalType &type, const string &msg) - : Exception(ExceptionType::INVALID_TYPE, "Invalid Type [" + type.ToString() + "]: " + msg) { -} - -InvalidTypeException::InvalidTypeException(const string &msg) : Exception(ExceptionType::INVALID_TYPE, msg) { -} - -TypeMismatchException::TypeMismatchException(const PhysicalType type_1, const PhysicalType type_2, const string &msg) - : Exception(ExceptionType::MISMATCH_TYPE, - "Type " + TypeIdToString(type_1) + " does not match with " + TypeIdToString(type_2) + ". " + msg) { -} - -TypeMismatchException::TypeMismatchException(const LogicalType &type_1, const LogicalType &type_2, const string &msg) - : Exception(ExceptionType::MISMATCH_TYPE, - "Type " + type_1.ToString() + " does not match with " + type_2.ToString() + ". 
" + msg) { -} - -TypeMismatchException::TypeMismatchException(const string &msg) : Exception(ExceptionType::MISMATCH_TYPE, msg) { -} - -TransactionException::TransactionException(const string &msg) : Exception(ExceptionType::TRANSACTION, msg) { -} - -NotImplementedException::NotImplementedException(const string &msg) : Exception(ExceptionType::NOT_IMPLEMENTED, msg) { -} - -OutOfRangeException::OutOfRangeException(const string &msg) : Exception(ExceptionType::OUT_OF_RANGE, msg) { -} - -CatalogException::CatalogException(const string &msg) : StandardException(ExceptionType::CATALOG, msg) { -} - -ConnectionException::ConnectionException(const string &msg) : StandardException(ExceptionType::CONNECTION, msg) { -} - -ParserException::ParserException(const string &msg) : StandardException(ExceptionType::PARSER, msg) { -} - -PermissionException::PermissionException(const string &msg) : StandardException(ExceptionType::PERMISSION, msg) { -} - -SyntaxException::SyntaxException(const string &msg) : Exception(ExceptionType::SYNTAX, msg) { -} - -ConstraintException::ConstraintException(const string &msg) : Exception(ExceptionType::CONSTRAINT, msg) { -} - -DependencyException::DependencyException(const string &msg) : Exception(ExceptionType::DEPENDENCY, msg) { -} - -BinderException::BinderException(const string &msg) : StandardException(ExceptionType::BINDER, msg) { -} - -IOException::IOException(const string &msg) : Exception(ExceptionType::IO, msg) { -} - -MissingExtensionException::MissingExtensionException(const string &msg) - : Exception(ExceptionType::MISSING_EXTENSION, msg) { -} - -AutoloadException::AutoloadException(const string &extension_name, Exception &e) - : Exception(ExceptionType::AUTOLOAD, - "An error occurred while trying to automatically install the required extension '" + extension_name + - "':\n" + e.RawMessage()), - wrapped_exception(e) { -} - -SerializationException::SerializationException(const string &msg) : Exception(ExceptionType::SERIALIZATION, msg) { 
-} - -SequenceException::SequenceException(const string &msg) : Exception(ExceptionType::SERIALIZATION, msg) { -} - -InterruptException::InterruptException() : Exception(ExceptionType::INTERRUPT, "Interrupted!") { -} - -FatalException::FatalException(ExceptionType type, const string &msg) : Exception(type, msg) { -} - -InternalException::InternalException(const string &msg) : FatalException(ExceptionType::INTERNAL, msg) { -#ifdef DUCKDB_CRASH_ON_ASSERT - Printer::Print("ABORT THROWN BY INTERNAL EXCEPTION: " + msg); - abort(); -#endif -} - -InvalidInputException::InvalidInputException(const string &msg) : Exception(ExceptionType::INVALID_INPUT, msg) { -} - -OutOfMemoryException::OutOfMemoryException(const string &msg) : Exception(ExceptionType::OUT_OF_MEMORY, msg) { -} - -ParameterNotAllowedException::ParameterNotAllowedException(const string &msg) - : StandardException(ExceptionType::PARAMETER_NOT_ALLOWED, msg) { -} - -ParameterNotResolvedException::ParameterNotResolvedException() - : Exception(ExceptionType::PARAMETER_NOT_RESOLVED, "Parameter types could not be resolved") { -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -ExceptionFormatValue::ExceptionFormatValue(double dbl_val) - : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_DOUBLE), dbl_val(dbl_val) { -} -ExceptionFormatValue::ExceptionFormatValue(int64_t int_val) - : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER), int_val(int_val) { -} -ExceptionFormatValue::ExceptionFormatValue(hugeint_t huge_val) - : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(Hugeint::ToString(huge_val)) { -} -ExceptionFormatValue::ExceptionFormatValue(string str_val) - : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(std::move(str_val)) { -} - -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(PhysicalType value) { - return ExceptionFormatValue(TypeIdToString(value)); -} -template <> -ExceptionFormatValue 
-ExceptionFormatValue::CreateFormatValue(LogicalType value) { // NOLINT: templating requires us to copy value here - return ExceptionFormatValue(value.ToString()); -} -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(float value) { - return ExceptionFormatValue(double(value)); -} -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(double value) { - return ExceptionFormatValue(double(value)); -} -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(string value) { - return ExceptionFormatValue(std::move(value)); -} - -template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(SQLString value) { // NOLINT: templating requires us to copy value here - return KeywordHelper::WriteQuoted(value.raw_string, '\''); -} - -template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(SQLIdentifier value) { // NOLINT: templating requires us to copy value here - return KeywordHelper::WriteOptionallyQuoted(value.raw_string, '"'); -} - -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *value) { - return ExceptionFormatValue(string(value)); -} -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *value) { - return ExceptionFormatValue(string(value)); -} -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(hugeint_t value) { - return ExceptionFormatValue(value); -} - -string ExceptionFormatValue::Format(const string &msg, std::vector &values) { - try { - std::vector> format_args; - for (auto &val : values) { - switch (val.type) { - case ExceptionFormatValueType::FORMAT_VALUE_TYPE_DOUBLE: - format_args.push_back(duckdb_fmt::internal::make_arg(val.dbl_val)); - break; - case ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER: - format_args.push_back(duckdb_fmt::internal::make_arg(val.int_val)); - break; - case ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING: - 
format_args.push_back(duckdb_fmt::internal::make_arg(val.str_val)); - break; - } - } - return duckdb_fmt::vsprintf(msg, duckdb_fmt::basic_format_args( - format_args.data(), static_cast(format_args.size()))); - } catch (std::exception &ex) { // LCOV_EXCL_START - // work-around for oss-fuzz limiting memory which causes issues here - if (StringUtil::Contains(ex.what(), "fuzz mode")) { - throw Exception(msg); - } - throw InternalException(std::string("Primary exception: ") + msg + - "\nSecondary exception in ExceptionFormatValue: " + ex.what()); - } // LCOV_EXCL_STOP -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Extra Type Info -//===--------------------------------------------------------------------===// -ExtraTypeInfo::ExtraTypeInfo(ExtraTypeInfoType type) : type(type) { -} -ExtraTypeInfo::ExtraTypeInfo(ExtraTypeInfoType type, string alias) : type(type), alias(std::move(alias)) { -} -ExtraTypeInfo::~ExtraTypeInfo() { -} - -bool ExtraTypeInfo::Equals(ExtraTypeInfo *other_p) const { - if (type == ExtraTypeInfoType::INVALID_TYPE_INFO || type == ExtraTypeInfoType::STRING_TYPE_INFO || - type == ExtraTypeInfoType::GENERIC_TYPE_INFO) { - if (!other_p) { - if (!alias.empty()) { - return false; - } - //! 
We only need to compare aliases when both types have them in this case - return true; - } - if (alias != other_p->alias) { - return false; - } - return true; - } - if (!other_p) { - return false; - } - if (type != other_p->type) { - return false; - } - return alias == other_p->alias && EqualsInternal(other_p); -} - -bool ExtraTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - // Do nothing - return true; -} - -//===--------------------------------------------------------------------===// -// Decimal Type Info -//===--------------------------------------------------------------------===// -DecimalTypeInfo::DecimalTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::DECIMAL_TYPE_INFO) { -} - -DecimalTypeInfo::DecimalTypeInfo(uint8_t width_p, uint8_t scale_p) - : ExtraTypeInfo(ExtraTypeInfoType::DECIMAL_TYPE_INFO), width(width_p), scale(scale_p) { - D_ASSERT(width_p >= scale_p); -} - -bool DecimalTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return width == other.width && scale == other.scale; -} - -//===--------------------------------------------------------------------===// -// String Type Info -//===--------------------------------------------------------------------===// -StringTypeInfo::StringTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::STRING_TYPE_INFO) { -} - -StringTypeInfo::StringTypeInfo(string collation_p) - : ExtraTypeInfo(ExtraTypeInfoType::STRING_TYPE_INFO), collation(std::move(collation_p)) { -} - -bool StringTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - // collation info has no impact on equality - return true; -} - -//===--------------------------------------------------------------------===// -// List Type Info -//===--------------------------------------------------------------------===// -ListTypeInfo::ListTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::LIST_TYPE_INFO) { -} - -ListTypeInfo::ListTypeInfo(LogicalType child_type_p) - : ExtraTypeInfo(ExtraTypeInfoType::LIST_TYPE_INFO), 
child_type(std::move(child_type_p)) { -} - -bool ListTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return child_type == other.child_type; -} - -//===--------------------------------------------------------------------===// -// Struct Type Info -//===--------------------------------------------------------------------===// -StructTypeInfo::StructTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::STRUCT_TYPE_INFO) { -} - -StructTypeInfo::StructTypeInfo(child_list_t child_types_p) - : ExtraTypeInfo(ExtraTypeInfoType::STRUCT_TYPE_INFO), child_types(std::move(child_types_p)) { -} - -bool StructTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return child_types == other.child_types; -} - -//===--------------------------------------------------------------------===// -// Aggregate State Type Info -//===--------------------------------------------------------------------===// -AggregateStateTypeInfo::AggregateStateTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO) { -} - -AggregateStateTypeInfo::AggregateStateTypeInfo(aggregate_state_t state_type_p) - : ExtraTypeInfo(ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO), state_type(std::move(state_type_p)) { -} - -bool AggregateStateTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return state_type.function_name == other.state_type.function_name && - state_type.return_type == other.state_type.return_type && - state_type.bound_argument_types == other.state_type.bound_argument_types; -} - -//===--------------------------------------------------------------------===// -// User Type Info -//===--------------------------------------------------------------------===// -UserTypeInfo::UserTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::USER_TYPE_INFO) { -} - -UserTypeInfo::UserTypeInfo(string name_p) - : ExtraTypeInfo(ExtraTypeInfoType::USER_TYPE_INFO), user_type_name(std::move(name_p)) { -} - 
-bool UserTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return other.user_type_name == user_type_name; -} - -//===--------------------------------------------------------------------===// -// Enum Type Info -//===--------------------------------------------------------------------===// -PhysicalType EnumTypeInfo::DictType(idx_t size) { - if (size <= NumericLimits::Maximum()) { - return PhysicalType::UINT8; - } else if (size <= NumericLimits::Maximum()) { - return PhysicalType::UINT16; - } else if (size <= NumericLimits::Maximum()) { - return PhysicalType::UINT32; - } else { - throw InternalException("Enum size must be lower than " + std::to_string(NumericLimits::Maximum())); - } -} - -template -struct EnumTypeInfoTemplated : public EnumTypeInfo { - explicit EnumTypeInfoTemplated(Vector &values_insert_order_p, idx_t size_p) - : EnumTypeInfo(values_insert_order_p, size_p) { - D_ASSERT(values_insert_order_p.GetType().InternalType() == PhysicalType::VARCHAR); - - UnifiedVectorFormat vdata; - values_insert_order.ToUnifiedFormat(size_p, vdata); - - auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < size_p; i++) { - auto idx = vdata.sel->get_index(i); - if (!vdata.validity.RowIsValid(idx)) { - throw InternalException("Attempted to create ENUM type with NULL value"); - } - if (values.count(data[idx]) > 0) { - throw InvalidInputException("Attempted to create ENUM type with duplicate value %s", - data[idx].GetString()); - } - values[data[idx]] = i; - } - } - - static shared_ptr Deserialize(Deserializer &deserializer, uint32_t size) { - Vector values_insert_order(LogicalType::VARCHAR, size); - auto strings = FlatVector::GetData(values_insert_order); - - deserializer.ReadList(201, "values", [&](Deserializer::List &list, idx_t i) { - strings[i] = StringVector::AddStringOrBlob(values_insert_order, list.ReadElement()); - }); - return make_shared(values_insert_order, size); - } - - const string_map_t &GetValues() 
const { - return values; - } - - EnumTypeInfoTemplated(const EnumTypeInfoTemplated &) = delete; - EnumTypeInfoTemplated &operator=(const EnumTypeInfoTemplated &) = delete; - -private: - string_map_t values; -}; - -EnumTypeInfo::EnumTypeInfo(Vector &values_insert_order_p, idx_t dict_size_p) - : ExtraTypeInfo(ExtraTypeInfoType::ENUM_TYPE_INFO), values_insert_order(values_insert_order_p), - dict_type(EnumDictType::VECTOR_DICT), dict_size(dict_size_p) { -} - -const EnumDictType &EnumTypeInfo::GetEnumDictType() const { - return dict_type; -} - -const Vector &EnumTypeInfo::GetValuesInsertOrder() const { - return values_insert_order; -} - -const idx_t &EnumTypeInfo::GetDictSize() const { - return dict_size; -} - -LogicalType EnumTypeInfo::CreateType(Vector &ordered_data, idx_t size) { - // Generate EnumTypeInfo - shared_ptr info; - auto enum_internal_type = EnumTypeInfo::DictType(size); - switch (enum_internal_type) { - case PhysicalType::UINT8: - info = make_shared>(ordered_data, size); - break; - case PhysicalType::UINT16: - info = make_shared>(ordered_data, size); - break; - case PhysicalType::UINT32: - info = make_shared>(ordered_data, size); - break; - default: - throw InternalException("Invalid Physical Type for ENUMs"); - } - // Generate Actual Enum Type - return LogicalType(LogicalTypeId::ENUM, info); -} - -template -int64_t TemplatedGetPos(const string_map_t &map, const string_t &key) { - auto it = map.find(key); - if (it == map.end()) { - return -1; - } - return it->second; -} - -int64_t EnumType::GetPos(const LogicalType &type, const string_t &key) { - auto info = type.AuxInfo(); - switch (type.InternalType()) { - case PhysicalType::UINT8: - return TemplatedGetPos(info->Cast>().GetValues(), key); - case PhysicalType::UINT16: - return TemplatedGetPos(info->Cast>().GetValues(), key); - case PhysicalType::UINT32: - return TemplatedGetPos(info->Cast>().GetValues(), key); - default: - throw InternalException("ENUM can only have unsigned integers (except UINT64) as 
physical types"); - } -} - -string_t EnumType::GetString(const LogicalType &type, idx_t pos) { - D_ASSERT(pos < EnumType::GetSize(type)); - return FlatVector::GetData(EnumType::GetValuesInsertOrder(type))[pos]; -} - -shared_ptr EnumTypeInfo::Deserialize(Deserializer &deserializer) { - auto values_count = deserializer.ReadProperty(200, "values_count"); - auto enum_internal_type = EnumTypeInfo::DictType(values_count); - switch (enum_internal_type) { - case PhysicalType::UINT8: - return EnumTypeInfoTemplated::Deserialize(deserializer, values_count); - case PhysicalType::UINT16: - return EnumTypeInfoTemplated::Deserialize(deserializer, values_count); - case PhysicalType::UINT32: - return EnumTypeInfoTemplated::Deserialize(deserializer, values_count); - default: - throw InternalException("Invalid Physical Type for ENUMs"); - } -} - -// Equalities are only used in enums with different catalog entries -bool EnumTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - if (dict_type != other.dict_type) { - return false; - } - D_ASSERT(dict_type == EnumDictType::VECTOR_DICT); - // We must check if both enums have the same size - if (other.dict_size != dict_size) { - return false; - } - auto other_vector_ptr = FlatVector::GetData(other.values_insert_order); - auto this_vector_ptr = FlatVector::GetData(values_insert_order); - - // Now we must check if all strings are the same - for (idx_t i = 0; i < dict_size; i++) { - if (!Equals::Operation(other_vector_ptr[i], this_vector_ptr[i])) { - return false; - } - } - return true; -} - -void EnumTypeInfo::Serialize(Serializer &serializer) const { - ExtraTypeInfo::Serialize(serializer); - - // Enums are special in that we serialize their values as a list instead of dumping the whole vector - auto strings = FlatVector::GetData(values_insert_order); - serializer.WriteProperty(200, "values_count", dict_size); - serializer.WriteList(201, "values", dict_size, - [&](Serializer::List &list, idx_t i) { 
list.WriteElement(strings[i]); }); -} - -} // namespace duckdb - - - - - - - - -#include - -namespace duckdb { - -FileBuffer::FileBuffer(Allocator &allocator, FileBufferType type, uint64_t user_size) - : allocator(allocator), type(type) { - Init(); - if (user_size) { - Resize(user_size); - } -} - -void FileBuffer::Init() { - buffer = nullptr; - size = 0; - internal_buffer = nullptr; - internal_size = 0; -} - -FileBuffer::FileBuffer(FileBuffer &source, FileBufferType type_p) : allocator(source.allocator), type(type_p) { - // take over the structures of the source buffer - buffer = source.buffer; - size = source.size; - internal_buffer = source.internal_buffer; - internal_size = source.internal_size; - - source.Init(); -} - -FileBuffer::~FileBuffer() { - if (!internal_buffer) { - return; - } - allocator.FreeData(internal_buffer, internal_size); -} - -void FileBuffer::ReallocBuffer(size_t new_size) { - data_ptr_t new_buffer; - if (internal_buffer) { - new_buffer = allocator.ReallocateData(internal_buffer, internal_size, new_size); - } else { - new_buffer = allocator.AllocateData(new_size); - } - if (!new_buffer) { - throw std::bad_alloc(); - } - internal_buffer = new_buffer; - internal_size = new_size; - // Caller must update these. - buffer = nullptr; - size = 0; -} - -FileBuffer::MemoryRequirement FileBuffer::CalculateMemory(uint64_t user_size) { - FileBuffer::MemoryRequirement result; - - if (type == FileBufferType::TINY_BUFFER) { - // We never do IO on tiny buffers, so there's no need to add a header or sector-align. 
- result.header_size = 0; - result.alloc_size = user_size; - } else { - result.header_size = Storage::BLOCK_HEADER_SIZE; - result.alloc_size = AlignValue(result.header_size + user_size); - } - return result; -} - -void FileBuffer::Resize(uint64_t new_size) { - auto req = CalculateMemory(new_size); - ReallocBuffer(req.alloc_size); - - if (new_size > 0) { - buffer = internal_buffer + req.header_size; - size = internal_size - req.header_size; - } -} - -void FileBuffer::Read(FileHandle &handle, uint64_t location) { - D_ASSERT(type != FileBufferType::TINY_BUFFER); - handle.Read(internal_buffer, internal_size, location); -} - -void FileBuffer::Write(FileHandle &handle, uint64_t location) { - D_ASSERT(type != FileBufferType::TINY_BUFFER); - handle.Write(internal_buffer, internal_size, location); -} - -void FileBuffer::Clear() { - memset(internal_buffer, 0, internal_size); -} - -void FileBuffer::Initialize(DebugInitialize initialize) { - if (initialize == DebugInitialize::NO_INITIALIZE) { - return; - } - uint8_t value = initialize == DebugInitialize::DEBUG_ZERO_INITIALIZE ? 
0 : 0xFF; - memset(internal_buffer, value, internal_size); -} - -} // namespace duckdb - - - - - - - - - - - - - - - -#include -#include - -#ifndef _WIN32 -#include -#include -#include -#include -#include -#include - -#ifdef __MVS__ -#define _XOPEN_SOURCE_EXTENDED 1 -#include -// enjoy - https://reviews.llvm.org/D92110 -#define PATH_MAX _XOPEN_PATH_MAX -#endif - -#else -#include -#include - -#ifdef __MINGW32__ -// need to manually define this for mingw -extern "C" WINBASEAPI BOOL WINAPI GetPhysicallyInstalledSystemMemory(PULONGLONG); -#endif - -#undef FILE_CREATE // woo mingw -#endif - -namespace duckdb { - -FileSystem::~FileSystem() { -} - -FileSystem &FileSystem::GetFileSystem(ClientContext &context) { - auto &client_data = ClientData::Get(context); - return *client_data.client_file_system; -} - -bool PathMatched(const string &path, const string &sub_path) { - if (path.rfind(sub_path, 0) == 0) { - return true; - } - return false; -} - -#ifndef _WIN32 - -string FileSystem::GetEnvVariable(const string &name) { - const char *env = getenv(name.c_str()); - if (!env) { - return string(); - } - return env; -} - -bool FileSystem::IsPathAbsolute(const string &path) { - auto path_separator = PathSeparator(path); - return PathMatched(path, path_separator); -} - -string FileSystem::PathSeparator(const string &path) { - return "/"; -} - -void FileSystem::SetWorkingDirectory(const string &path) { - if (chdir(path.c_str()) != 0) { - throw IOException("Could not change working directory!"); - } -} - -idx_t FileSystem::GetAvailableMemory() { - errno = 0; - -#ifdef __MVS__ - struct rlimit limit; - int rlim_rc = getrlimit(RLIMIT_AS, &limit); - idx_t max_memory = MinValue(limit.rlim_max, UINTPTR_MAX); -#else - idx_t max_memory = MinValue((idx_t)sysconf(_SC_PHYS_PAGES) * (idx_t)sysconf(_SC_PAGESIZE), UINTPTR_MAX); -#endif - if (errno != 0) { - return DConstants::INVALID_INDEX; - } - return max_memory; -} - -string FileSystem::GetWorkingDirectory() { - auto buffer = 
make_unsafe_uniq_array(PATH_MAX); - char *ret = getcwd(buffer.get(), PATH_MAX); - if (!ret) { - throw IOException("Could not get working directory!"); - } - return string(buffer.get()); -} - -string FileSystem::NormalizeAbsolutePath(const string &path) { - D_ASSERT(IsPathAbsolute(path)); - return path; -} - -#else - -string FileSystem::GetEnvVariable(const string &env) { - // first convert the environment variable name to the correct encoding - auto env_w = WindowsUtil::UTF8ToUnicode(env.c_str()); - // use _wgetenv to get the value - auto res_w = _wgetenv(env_w.c_str()); - if (!res_w) { - // no environment variable of this name found - return string(); - } - return WindowsUtil::UnicodeToUTF8(res_w); -} - -static bool StartsWithSingleBackslash(const string &path) { - if (path.size() < 2) { - return false; - } - if (path[0] != '/' && path[0] != '\\') { - return false; - } - if (path[1] == '/' || path[1] == '\\') { - return false; - } - return true; -} - -bool FileSystem::IsPathAbsolute(const string &path) { - // 1) A single backslash or forward-slash - if (StartsWithSingleBackslash(path)) { - return true; - } - // 2) A disk designator with a backslash (e.g., C:\ or C:/) - auto path_aux = path; - path_aux.erase(0, 1); - if (PathMatched(path_aux, ":\\") || PathMatched(path_aux, ":/")) { - return true; - } - return false; -} - -string FileSystem::NormalizeAbsolutePath(const string &path) { - D_ASSERT(IsPathAbsolute(path)); - auto result = StringUtil::Lower(FileSystem::ConvertSeparators(path)); - if (StartsWithSingleBackslash(result)) { - // Path starts with a single backslash or forward slash - // prepend drive letter - return GetWorkingDirectory().substr(0, 2) + result; - } - return result; -} - -string FileSystem::PathSeparator(const string &path) { - return "\\"; -} - -void FileSystem::SetWorkingDirectory(const string &path) { - auto unicode_path = WindowsUtil::UTF8ToUnicode(path.c_str()); - if (!SetCurrentDirectoryW(unicode_path.c_str())) { - throw 
IOException("Could not change working directory to \"%s\"", path); - } -} - -idx_t FileSystem::GetAvailableMemory() { - ULONGLONG available_memory_kb; - if (GetPhysicallyInstalledSystemMemory(&available_memory_kb)) { - return MinValue(available_memory_kb * 1000, UINTPTR_MAX); - } - // fallback: try GlobalMemoryStatusEx - MEMORYSTATUSEX mem_state; - mem_state.dwLength = sizeof(MEMORYSTATUSEX); - - if (GlobalMemoryStatusEx(&mem_state)) { - return MinValue(mem_state.ullTotalPhys, UINTPTR_MAX); - } - return DConstants::INVALID_INDEX; -} - -string FileSystem::GetWorkingDirectory() { - idx_t count = GetCurrentDirectoryW(0, nullptr); - if (count == 0) { - throw IOException("Could not get working directory!"); - } - auto buffer = make_unsafe_uniq_array(count); - idx_t ret = GetCurrentDirectoryW(count, buffer.get()); - if (count != ret + 1) { - throw IOException("Could not get working directory!"); - } - return WindowsUtil::UnicodeToUTF8(buffer.get()); -} - -#endif - -string FileSystem::JoinPath(const string &a, const string &b) { - // FIXME: sanitize paths - return a + PathSeparator(a) + b; -} - -string FileSystem::ConvertSeparators(const string &path) { - auto separator_str = PathSeparator(path); - char separator = separator_str[0]; - if (separator == '/') { - // on unix-based systems we only accept / as a separator - return path; - } - // on windows-based systems we accept both - return StringUtil::Replace(path, "/", separator_str); -} - -string FileSystem::ExtractName(const string &path) { - if (path.empty()) { - return string(); - } - auto normalized_path = ConvertSeparators(path); - auto sep = PathSeparator(path); - auto splits = StringUtil::Split(normalized_path, sep); - D_ASSERT(!splits.empty()); - return splits.back(); -} - -string FileSystem::ExtractBaseName(const string &path) { - if (path.empty()) { - return string(); - } - auto vec = StringUtil::Split(ExtractName(path), "."); - D_ASSERT(!vec.empty()); - return vec[0]; -} - -string 
FileSystem::GetHomeDirectory(optional_ptr opener) { - // read the home_directory setting first, if it is set - if (opener) { - Value result; - if (opener->TryGetCurrentSetting("home_directory", result)) { - if (!result.IsNull() && !result.ToString().empty()) { - return result.ToString(); - } - } - } - // fallback to the default home directories for the specified system -#ifdef DUCKDB_WINDOWS - return FileSystem::GetEnvVariable("USERPROFILE"); -#else - return FileSystem::GetEnvVariable("HOME"); -#endif -} - -string FileSystem::GetHomeDirectory() { - return GetHomeDirectory(nullptr); -} - -string FileSystem::ExpandPath(const string &path, optional_ptr opener) { - if (path.empty()) { - return path; - } - if (path[0] == '~') { - return GetHomeDirectory(opener) + path.substr(1); - } - return path; -} - -string FileSystem::ExpandPath(const string &path) { - return FileSystem::ExpandPath(path, nullptr); -} - -// LCOV_EXCL_START -unique_ptr FileSystem::OpenFile(const string &path, uint8_t flags, FileLockType lock, - FileCompressionType compression, FileOpener *opener) { - throw NotImplementedException("%s: OpenFile is not implemented!", GetName()); -} - -void FileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - throw NotImplementedException("%s: Read (with location) is not implemented!", GetName()); -} - -void FileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - throw NotImplementedException("%s: Write (with location) is not implemented!", GetName()); -} - -int64_t FileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - throw NotImplementedException("%s: Read is not implemented!", GetName()); -} - -int64_t FileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { - throw NotImplementedException("%s: Write is not implemented!", GetName()); -} - -int64_t FileSystem::GetFileSize(FileHandle &handle) { - throw NotImplementedException("%s: GetFileSize is not implemented!", 
GetName()); -} - -time_t FileSystem::GetLastModifiedTime(FileHandle &handle) { - throw NotImplementedException("%s: GetLastModifiedTime is not implemented!", GetName()); -} - -FileType FileSystem::GetFileType(FileHandle &handle) { - return FileType::FILE_TYPE_INVALID; -} - -void FileSystem::Truncate(FileHandle &handle, int64_t new_size) { - throw NotImplementedException("%s: Truncate is not implemented!", GetName()); -} - -bool FileSystem::DirectoryExists(const string &directory) { - throw NotImplementedException("%s: DirectoryExists is not implemented!", GetName()); -} - -void FileSystem::CreateDirectory(const string &directory) { - throw NotImplementedException("%s: CreateDirectory is not implemented!", GetName()); -} - -void FileSystem::RemoveDirectory(const string &directory) { - throw NotImplementedException("%s: RemoveDirectory is not implemented!", GetName()); -} - -bool FileSystem::ListFiles(const string &directory, const std::function &callback, - FileOpener *opener) { - throw NotImplementedException("%s: ListFiles is not implemented!", GetName()); -} - -void FileSystem::MoveFile(const string &source, const string &target) { - throw NotImplementedException("%s: MoveFile is not implemented!", GetName()); -} - -bool FileSystem::FileExists(const string &filename) { - throw NotImplementedException("%s: FileExists is not implemented!", GetName()); -} - -bool FileSystem::IsPipe(const string &filename) { - throw NotImplementedException("%s: IsPipe is not implemented!", GetName()); -} - -void FileSystem::RemoveFile(const string &filename) { - throw NotImplementedException("%s: RemoveFile is not implemented!", GetName()); -} - -void FileSystem::FileSync(FileHandle &handle) { - throw NotImplementedException("%s: FileSync is not implemented!", GetName()); -} - -bool FileSystem::HasGlob(const string &str) { - for (idx_t i = 0; i < str.size(); i++) { - switch (str[i]) { - case '*': - case '?': - case '[': - return true; - default: - break; - } - } - return false; -} - 
-vector FileSystem::Glob(const string &path, FileOpener *opener) { - throw NotImplementedException("%s: Glob is not implemented!", GetName()); -} - -void FileSystem::RegisterSubSystem(unique_ptr sub_fs) { - throw NotImplementedException("%s: Can't register a sub system on a non-virtual file system", GetName()); -} - -void FileSystem::RegisterSubSystem(FileCompressionType compression_type, unique_ptr sub_fs) { - throw NotImplementedException("%s: Can't register a sub system on a non-virtual file system", GetName()); -} - -void FileSystem::UnregisterSubSystem(const string &name) { - throw NotImplementedException("%s: Can't unregister a sub system on a non-virtual file system", GetName()); -} - -void FileSystem::SetDisabledFileSystems(const vector &names) { - throw NotImplementedException("%s: Can't disable file systems on a non-virtual file system", GetName()); -} - -vector FileSystem::ListSubSystems() { - throw NotImplementedException("%s: Can't list sub systems on a non-virtual file system", GetName()); -} - -bool FileSystem::CanHandleFile(const string &fpath) { - throw NotImplementedException("%s: CanHandleFile is not implemented!", GetName()); -} - -static string LookupExtensionForPattern(const string &pattern) { - for (const auto &entry : EXTENSION_FILE_PREFIXES) { - if (StringUtil::StartsWith(pattern, entry.name)) { - return entry.extension; - } - } - return ""; -} - -vector FileSystem::GlobFiles(const string &pattern, ClientContext &context, FileGlobOptions options) { - auto result = Glob(pattern); - if (result.empty()) { - string required_extension = LookupExtensionForPattern(pattern); - if (!required_extension.empty() && !context.db->ExtensionIsLoaded(required_extension)) { - auto &dbconfig = DBConfig::GetConfig(context); - if (!ExtensionHelper::CanAutoloadExtension(required_extension) || - !dbconfig.options.autoload_known_extensions) { - auto error_message = - "File " + pattern + " requires the extension " + required_extension + " to be loaded"; - 
error_message = - ExtensionHelper::AddExtensionInstallHintToErrorMsg(context, error_message, required_extension); - throw MissingExtensionException(error_message); - } - // an extension is required to read this file, but it is not loaded - try to load it - ExtensionHelper::AutoLoadExtension(context, required_extension); - // success! glob again - // check the extension is loaded just in case to prevent an infinite loop here - if (!context.db->ExtensionIsLoaded(required_extension)) { - throw InternalException("Extension load \"%s\" did not throw but somehow the extension was not loaded", - required_extension); - } - return GlobFiles(pattern, context, options); - } - if (options == FileGlobOptions::DISALLOW_EMPTY) { - throw IOException("No files found that match the pattern \"%s\"", pattern); - } - } - return result; -} - -void FileSystem::Seek(FileHandle &handle, idx_t location) { - throw NotImplementedException("%s: Seek is not implemented!", GetName()); -} - -void FileSystem::Reset(FileHandle &handle) { - handle.Seek(0); -} - -idx_t FileSystem::SeekPosition(FileHandle &handle) { - throw NotImplementedException("%s: SeekPosition is not implemented!", GetName()); -} - -bool FileSystem::CanSeek() { - throw NotImplementedException("%s: CanSeek is not implemented!", GetName()); -} - -unique_ptr FileSystem::OpenCompressedFile(unique_ptr handle, bool write) { - throw NotImplementedException("%s: OpenCompressedFile is not implemented!", GetName()); -} - -bool FileSystem::OnDiskFile(FileHandle &handle) { - throw NotImplementedException("%s: OnDiskFile is not implemented!", GetName()); -} -// LCOV_EXCL_STOP - -FileHandle::FileHandle(FileSystem &file_system, string path_p) : file_system(file_system), path(std::move(path_p)) { -} - -FileHandle::~FileHandle() { -} - -int64_t FileHandle::Read(void *buffer, idx_t nr_bytes) { - return file_system.Read(*this, buffer, nr_bytes); -} - -int64_t FileHandle::Write(void *buffer, idx_t nr_bytes) { - return file_system.Write(*this, 
buffer, nr_bytes); -} - -void FileHandle::Read(void *buffer, idx_t nr_bytes, idx_t location) { - file_system.Read(*this, buffer, nr_bytes, location); -} - -void FileHandle::Write(void *buffer, idx_t nr_bytes, idx_t location) { - file_system.Write(*this, buffer, nr_bytes, location); -} - -void FileHandle::Seek(idx_t location) { - file_system.Seek(*this, location); -} - -void FileHandle::Reset() { - file_system.Reset(*this); -} - -idx_t FileHandle::SeekPosition() { - return file_system.SeekPosition(*this); -} - -bool FileHandle::CanSeek() { - return file_system.CanSeek(); -} - -string FileHandle::ReadLine() { - string result; - char buffer[1]; - while (true) { - idx_t tuples_read = Read(buffer, 1); - if (tuples_read == 0 || buffer[0] == '\n') { - return result; - } - if (buffer[0] != '\r') { - result += buffer[0]; - } - } -} - -bool FileHandle::OnDiskFile() { - return file_system.OnDiskFile(*this); -} - -idx_t FileHandle::GetFileSize() { - return file_system.GetFileSize(*this); -} - -void FileHandle::Sync() { - file_system.FileSync(*this); -} - -void FileHandle::Truncate(int64_t new_size) { - file_system.Truncate(*this, new_size); -} - -FileType FileHandle::GetType() { - return file_system.GetFileType(*this); -} - -bool FileSystem::IsRemoteFile(const string &path) { - const string prefixes[] = {"http://", "https://", "s3://"}; - for (auto &prefix : prefixes) { - if (StringUtil::StartsWith(path, prefix)) { - return true; - } - } - return false; -} - -} // namespace duckdb - - - -namespace duckdb { - -void FilenamePattern::SetFilenamePattern(const string &pattern) { - const string id_format {"{i}"}; - const string uuid_format {"{uuid}"}; - - _base = pattern; - - _pos = _base.find(id_format); - if (_pos != string::npos) { - _base = StringUtil::Replace(_base, id_format, ""); - _uuid = false; - } - - _pos = _base.find(uuid_format); - if (_pos != string::npos) { - _base = StringUtil::Replace(_base, uuid_format, ""); - _uuid = true; - } - - _pos = std::min(_pos, 
(idx_t)_base.length()); -} - -string FilenamePattern::CreateFilename(FileSystem &fs, const string &path, const string &extension, - idx_t offset) const { - string result(_base); - string replacement; - - if (_uuid) { - replacement = UUID::ToString(UUID::GenerateRandomUUID()); - } else { - replacement = std::to_string(offset); - } - result.insert(_pos, replacement); - return fs.JoinPath(path, result + "." + extension); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -string_t FSSTPrimitives::DecompressValue(void *duckdb_fsst_decoder, Vector &result, const char *compressed_string, - idx_t compressed_string_len) { - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - unsigned char decompress_buffer[StringUncompressed::STRING_BLOCK_LIMIT + 1]; - auto fsst_decoder = reinterpret_cast(duckdb_fsst_decoder); - auto compressed_string_ptr = (unsigned char *)compressed_string; // NOLINT - auto decompressed_string_size = - duckdb_fsst_decompress(fsst_decoder, compressed_string_len, compressed_string_ptr, - StringUncompressed::STRING_BLOCK_LIMIT + 1, &decompress_buffer[0]); - D_ASSERT(decompressed_string_size <= StringUncompressed::STRING_BLOCK_LIMIT); - - return StringVector::AddStringOrBlob(result, const_char_ptr_cast(decompress_buffer), decompressed_string_size); -} - -Value FSSTPrimitives::DecompressValue(void *duckdb_fsst_decoder, const char *compressed_string, - idx_t compressed_string_len) { - unsigned char decompress_buffer[StringUncompressed::STRING_BLOCK_LIMIT + 1]; - auto compressed_string_ptr = (unsigned char *)compressed_string; // NOLINT - auto fsst_decoder = reinterpret_cast(duckdb_fsst_decoder); - auto decompressed_string_size = - duckdb_fsst_decompress(fsst_decoder, compressed_string_len, compressed_string_ptr, - StringUncompressed::STRING_BLOCK_LIMIT + 1, &decompress_buffer[0]); - D_ASSERT(decompressed_string_size <= StringUncompressed::STRING_BLOCK_LIMIT); - - return Value(string(char_ptr_cast(decompress_buffer), 
decompressed_string_size)); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -/* - - 0 2 bytes magic header 0x1f, 0x8b (\037 \213) - 2 1 byte compression method - 0: store (copied) - 1: compress - 2: pack - 3: lzh - 4..7: reserved - 8: deflate - 3 1 byte flags - bit 0 set: file probably ascii text - bit 1 set: continuation of multi-part gzip file, part number present - bit 2 set: extra field present - bit 3 set: original file name present - bit 4 set: file comment present - bit 5 set: file is encrypted, encryption header present - bit 6,7: reserved - 4 4 bytes file modification time in Unix format - 8 1 byte extra flags (depend on compression method) - 9 1 byte OS type -[ - 2 bytes optional part number (second part=1) -]? -[ - 2 bytes optional extra field length (e) - (e)bytes optional extra field -]? -[ - bytes optional original file name, zero terminated -]? -[ - bytes optional file comment, zero terminated -]? -[ - 12 bytes optional encryption header -]? - bytes compressed data - 4 bytes crc32 - 4 bytes uncompressed input size modulo 2^32 - - */ - -static idx_t GZipConsumeString(FileHandle &input) { - idx_t size = 1; // terminator - char buffer[1]; - while (input.Read(buffer, 1) == 1) { - if (buffer[0] == '\0') { - break; - } - size++; - } - return size; -} - -struct MiniZStreamWrapper : public StreamWrapper { - ~MiniZStreamWrapper() override; - - CompressedFile *file = nullptr; - duckdb_miniz::mz_stream *mz_stream_ptr = nullptr; - bool writing = false; - duckdb_miniz::mz_ulong crc; - idx_t total_size; - -public: - void Initialize(CompressedFile &file, bool write) override; - - bool Read(StreamData &stream_data) override; - void Write(CompressedFile &file, StreamData &stream_data, data_ptr_t buffer, int64_t nr_bytes) override; - - void Close() override; - - void FlushStream(); -}; - -MiniZStreamWrapper::~MiniZStreamWrapper() { - // avoid closing if destroyed during stack unwinding - if (Exception::UncaughtException()) { - return; - } - try { - 
MiniZStreamWrapper::Close(); - } catch (...) { - } -} - -void MiniZStreamWrapper::Initialize(CompressedFile &file, bool write) { - Close(); - this->file = &file; - mz_stream_ptr = new duckdb_miniz::mz_stream(); - memset(mz_stream_ptr, 0, sizeof(duckdb_miniz::mz_stream)); - this->writing = write; - - // TODO use custom alloc/free methods in miniz to throw exceptions on OOM - uint8_t gzip_hdr[GZIP_HEADER_MINSIZE]; - if (write) { - crc = MZ_CRC32_INIT; - total_size = 0; - - MiniZStream::InitializeGZIPHeader(gzip_hdr); - file.child_handle->Write(gzip_hdr, GZIP_HEADER_MINSIZE); - - auto ret = mz_deflateInit2((duckdb_miniz::mz_streamp)mz_stream_ptr, duckdb_miniz::MZ_DEFAULT_LEVEL, MZ_DEFLATED, - -MZ_DEFAULT_WINDOW_BITS, 1, 0); - if (ret != duckdb_miniz::MZ_OK) { - throw InternalException("Failed to initialize miniz"); - } - } else { - idx_t data_start = GZIP_HEADER_MINSIZE; - auto read_count = file.child_handle->Read(gzip_hdr, GZIP_HEADER_MINSIZE); - GZipFileSystem::VerifyGZIPHeader(gzip_hdr, read_count); - // Skip over the extra field if necessary - if (gzip_hdr[3] & GZIP_FLAG_EXTRA) { - uint8_t gzip_xlen[2]; - file.child_handle->Seek(data_start); - file.child_handle->Read(gzip_xlen, 2); - idx_t xlen = (uint8_t)gzip_xlen[0] | (uint8_t)gzip_xlen[1] << 8; - data_start += xlen + 2; - } - // Skip over the file name if necessary - if (gzip_hdr[3] & GZIP_FLAG_NAME) { - file.child_handle->Seek(data_start); - data_start += GZipConsumeString(*file.child_handle); - } - file.child_handle->Seek(data_start); - // stream is now set to beginning of payload data - auto ret = duckdb_miniz::mz_inflateInit2((duckdb_miniz::mz_streamp)mz_stream_ptr, -MZ_DEFAULT_WINDOW_BITS); - if (ret != duckdb_miniz::MZ_OK) { - throw InternalException("Failed to initialize miniz"); - } - } -} - -bool MiniZStreamWrapper::Read(StreamData &sd) { - // Handling for the concatenated files - if (sd.refresh) { - auto available = (uint32_t)(sd.in_buff_end - sd.in_buff_start); - if (available <= GZIP_FOOTER_SIZE) { 
- // Only footer is available so we just close and return finished - Close(); - return true; - } - - sd.refresh = false; - auto body_ptr = sd.in_buff_start + GZIP_FOOTER_SIZE; - uint8_t gzip_hdr[GZIP_HEADER_MINSIZE]; - memcpy(gzip_hdr, body_ptr, GZIP_HEADER_MINSIZE); - GZipFileSystem::VerifyGZIPHeader(gzip_hdr, GZIP_HEADER_MINSIZE); - body_ptr += GZIP_HEADER_MINSIZE; - if (gzip_hdr[3] & GZIP_FLAG_EXTRA) { - idx_t xlen = (uint8_t)*body_ptr | (uint8_t) * (body_ptr + 1) << 8; - body_ptr += xlen + 2; - if (GZIP_FOOTER_SIZE + GZIP_HEADER_MINSIZE + 2 + xlen >= GZIP_HEADER_MAXSIZE) { - throw InternalException("Extra field resulting in GZIP header larger than defined maximum (%d)", - GZIP_HEADER_MAXSIZE); - } - } - if (gzip_hdr[3] & GZIP_FLAG_NAME) { - char c; - do { - c = *body_ptr; - body_ptr++; - } while (c != '\0' && body_ptr < sd.in_buff_end); - if ((idx_t)(body_ptr - sd.in_buff_start) >= GZIP_HEADER_MAXSIZE) { - throw InternalException("Filename resulting in GZIP header larger than defined maximum (%d)", - GZIP_HEADER_MAXSIZE); - } - } - sd.in_buff_start = body_ptr; - if (sd.in_buff_end - sd.in_buff_start < 1) { - Close(); - return true; - } - duckdb_miniz::mz_inflateEnd(mz_stream_ptr); - auto sta = duckdb_miniz::mz_inflateInit2((duckdb_miniz::mz_streamp)mz_stream_ptr, -MZ_DEFAULT_WINDOW_BITS); - if (sta != duckdb_miniz::MZ_OK) { - throw InternalException("Failed to initialize miniz"); - } - } - - // actually decompress - mz_stream_ptr->next_in = sd.in_buff_start; - D_ASSERT(sd.in_buff_end - sd.in_buff_start < NumericLimits::Maximum()); - mz_stream_ptr->avail_in = (uint32_t)(sd.in_buff_end - sd.in_buff_start); - mz_stream_ptr->next_out = data_ptr_cast(sd.out_buff_end); - mz_stream_ptr->avail_out = (uint32_t)((sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_end); - auto ret = duckdb_miniz::mz_inflate(mz_stream_ptr, duckdb_miniz::MZ_NO_FLUSH); - if (ret != duckdb_miniz::MZ_OK && ret != duckdb_miniz::MZ_STREAM_END) { - throw IOException("Failed to decode gzip stream: 
%s", duckdb_miniz::mz_error(ret)); - } - // update pointers following inflate() - sd.in_buff_start = (data_ptr_t)mz_stream_ptr->next_in; // NOLINT - sd.in_buff_end = sd.in_buff_start + mz_stream_ptr->avail_in; - sd.out_buff_end = data_ptr_cast(mz_stream_ptr->next_out); - D_ASSERT(sd.out_buff_end + mz_stream_ptr->avail_out == sd.out_buff.get() + sd.out_buf_size); - - // if stream ended, deallocate inflator - if (ret == duckdb_miniz::MZ_STREAM_END) { - // Concatenated GZIP potentially coming up - refresh input buffer - sd.refresh = true; - } - return false; -} - -void MiniZStreamWrapper::Write(CompressedFile &file, StreamData &sd, data_ptr_t uncompressed_data, - int64_t uncompressed_size) { - // update the src and the total size - crc = duckdb_miniz::mz_crc32(crc, reinterpret_cast(uncompressed_data), uncompressed_size); - total_size += uncompressed_size; - - auto remaining = uncompressed_size; - while (remaining > 0) { - idx_t output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start; - - mz_stream_ptr->next_in = reinterpret_cast(uncompressed_data); - mz_stream_ptr->avail_in = remaining; - mz_stream_ptr->next_out = sd.out_buff_start; - mz_stream_ptr->avail_out = output_remaining; - - auto res = mz_deflate(mz_stream_ptr, duckdb_miniz::MZ_NO_FLUSH); - if (res != duckdb_miniz::MZ_OK) { - D_ASSERT(res != duckdb_miniz::MZ_STREAM_END); - throw InternalException("Failed to compress GZIP block"); - } - sd.out_buff_start += output_remaining - mz_stream_ptr->avail_out; - if (mz_stream_ptr->avail_out == 0) { - // no more output buffer available: flush - file.child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get()); - sd.out_buff_start = sd.out_buff.get(); - } - idx_t written = remaining - mz_stream_ptr->avail_in; - uncompressed_data += written; - remaining = mz_stream_ptr->avail_in; - } -} - -void MiniZStreamWrapper::FlushStream() { - auto &sd = file->stream_data; - mz_stream_ptr->next_in = nullptr; - mz_stream_ptr->avail_in = 0; - while 
(true) { - auto output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start; - mz_stream_ptr->next_out = sd.out_buff_start; - mz_stream_ptr->avail_out = output_remaining; - - auto res = mz_deflate(mz_stream_ptr, duckdb_miniz::MZ_FINISH); - sd.out_buff_start += (output_remaining - mz_stream_ptr->avail_out); - if (sd.out_buff_start > sd.out_buff.get()) { - file->child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get()); - sd.out_buff_start = sd.out_buff.get(); - } - if (res == duckdb_miniz::MZ_STREAM_END) { - break; - } - if (res != duckdb_miniz::MZ_OK) { - throw InternalException("Failed to compress GZIP block"); - } - } -} - -void MiniZStreamWrapper::Close() { - if (!mz_stream_ptr) { - return; - } - if (writing) { - // flush anything remaining in the stream - FlushStream(); - - // write the footer - unsigned char gzip_footer[MiniZStream::GZIP_FOOTER_SIZE]; - MiniZStream::InitializeGZIPFooter(gzip_footer, crc, total_size); - file->child_handle->Write(gzip_footer, MiniZStream::GZIP_FOOTER_SIZE); - - duckdb_miniz::mz_deflateEnd(mz_stream_ptr); - } else { - duckdb_miniz::mz_inflateEnd(mz_stream_ptr); - } - delete mz_stream_ptr; - mz_stream_ptr = nullptr; - file = nullptr; -} - -class GZipFile : public CompressedFile { -public: - GZipFile(unique_ptr child_handle_p, const string &path, bool write) - : CompressedFile(gzip_fs, std::move(child_handle_p), path) { - Initialize(write); - } - - GZipFileSystem gzip_fs; -}; - -void GZipFileSystem::VerifyGZIPHeader(uint8_t gzip_hdr[], idx_t read_count) { - // check for incorrectly formatted files - if (read_count != GZIP_HEADER_MINSIZE) { - throw IOException("Input is not a GZIP stream"); - } - if (gzip_hdr[0] != 0x1F || gzip_hdr[1] != 0x8B) { // magic header - throw IOException("Input is not a GZIP stream"); - } - if (gzip_hdr[2] != GZIP_COMPRESSION_DEFLATE) { // compression method - throw IOException("Unsupported GZIP compression method"); - } - if (gzip_hdr[3] & GZIP_FLAG_UNSUPPORTED) { - 
throw IOException("Unsupported GZIP archive"); - } -} - -string GZipFileSystem::UncompressGZIPString(const string &in) { - // decompress file - auto body_ptr = in.data(); - - auto mz_stream_ptr = new duckdb_miniz::mz_stream(); - memset(mz_stream_ptr, 0, sizeof(duckdb_miniz::mz_stream)); - - uint8_t gzip_hdr[GZIP_HEADER_MINSIZE]; - - // check for incorrectly formatted files - - // TODO this is mostly the same as gzip_file_system.cpp - if (in.size() < GZIP_HEADER_MINSIZE) { - throw IOException("Input is not a GZIP stream"); - } - memcpy(gzip_hdr, body_ptr, GZIP_HEADER_MINSIZE); - body_ptr += GZIP_HEADER_MINSIZE; - GZipFileSystem::VerifyGZIPHeader(gzip_hdr, GZIP_HEADER_MINSIZE); - - if (gzip_hdr[3] & GZIP_FLAG_EXTRA) { - throw IOException("Extra field in a GZIP stream unsupported"); - } - - if (gzip_hdr[3] & GZIP_FLAG_NAME) { - char c; - do { - c = *body_ptr; - body_ptr++; - } while (c != '\0' && (idx_t)(body_ptr - in.data()) < in.size()); - } - - // stream is now set to beginning of payload data - auto status = duckdb_miniz::mz_inflateInit2(mz_stream_ptr, -MZ_DEFAULT_WINDOW_BITS); - if (status != duckdb_miniz::MZ_OK) { - throw InternalException("Failed to initialize miniz"); - } - - auto bytes_remaining = in.size() - (body_ptr - in.data()); - mz_stream_ptr->next_in = const_uchar_ptr_cast(body_ptr); - mz_stream_ptr->avail_in = bytes_remaining; - - unsigned char decompress_buffer[BUFSIZ]; - string decompressed; - - while (status == duckdb_miniz::MZ_OK) { - mz_stream_ptr->next_out = decompress_buffer; - mz_stream_ptr->avail_out = sizeof(decompress_buffer); - status = mz_inflate(mz_stream_ptr, duckdb_miniz::MZ_NO_FLUSH); - if (status != duckdb_miniz::MZ_STREAM_END && status != duckdb_miniz::MZ_OK) { - throw IOException("Failed to uncompress"); - } - decompressed.append(char_ptr_cast(decompress_buffer), mz_stream_ptr->total_out - decompressed.size()); - } - duckdb_miniz::mz_inflateEnd(mz_stream_ptr); - if (decompressed.empty()) { - throw IOException("Failed to 
uncompress"); - } - return decompressed; -} - -unique_ptr GZipFileSystem::OpenCompressedFile(unique_ptr handle, bool write) { - auto path = handle->path; - return make_uniq(std::move(handle), path, write); -} - -unique_ptr GZipFileSystem::CreateStream() { - return make_uniq(); -} - -idx_t GZipFileSystem::InBufferSize() { - return BUFFER_SIZE; -} - -idx_t GZipFileSystem::OutBufferSize() { - return BUFFER_SIZE; -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -static unordered_map GetKnownColumnValues(string &filename, - unordered_map &column_map, - duckdb_re2::RE2 &compiled_regex, bool filename_col, - bool hive_partition_cols) { - unordered_map result; - - if (filename_col) { - auto lookup_column_id = column_map.find("filename"); - if (lookup_column_id != column_map.end()) { - result[lookup_column_id->second] = filename; - } - } - - if (hive_partition_cols) { - auto partitions = HivePartitioning::Parse(filename, compiled_regex); - for (auto &partition : partitions) { - auto lookup_column_id = column_map.find(partition.first); - if (lookup_column_id != column_map.end()) { - result[lookup_column_id->second] = partition.second; - } - } - } - - return result; -} - -// Takes an expression and converts a list of known column_refs to constants -static void ConvertKnownColRefToConstants(unique_ptr &expr, - unordered_map &known_column_values, idx_t table_index) { - if (expr->type == ExpressionType::BOUND_COLUMN_REF) { - auto &bound_colref = expr->Cast(); - - // This bound column ref is for another table - if (table_index != bound_colref.binding.table_index) { - return; - } - - auto lookup = known_column_values.find(bound_colref.binding.column_index); - if (lookup != known_column_values.end()) { - expr = make_uniq(Value(lookup->second).DefaultCastAs(bound_colref.return_type)); - } - } else { - ExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child) { - ConvertKnownColRefToConstants(child, known_column_values, table_index); - }); - } -} - 
-// matches hive partitions in file name. For example: -// - s3://bucket/var1=value1/bla/bla/var2=value2 -// - http(s)://domain(:port)/lala/kasdl/var1=value1/?not-a-var=not-a-value -// - folder/folder/folder/../var1=value1/etc/.//var2=value2 -const string HivePartitioning::REGEX_STRING = "[\\/\\\\]([^\\/\\?\\\\]+)=([^\\/\\n\\?\\\\]+)"; - -std::map HivePartitioning::Parse(const string &filename, duckdb_re2::RE2 ®ex) { - std::map result; - duckdb_re2::StringPiece input(filename); // Wrap a StringPiece around it - - string var; - string value; - while (RE2::FindAndConsume(&input, regex, &var, &value)) { - result.insert(std::pair(var, value)); - } - return result; -} - -std::map HivePartitioning::Parse(const string &filename) { - duckdb_re2::RE2 regex(REGEX_STRING); - return Parse(filename, regex); -} - -// TODO: this can still be improved by removing the parts of filter expressions that are true for all remaining files. -// currently, only expressions that cannot be evaluated during pushdown are removed. 
-void HivePartitioning::ApplyFiltersToFileList(ClientContext &context, vector &files, - vector> &filters, - unordered_map &column_map, LogicalGet &get, - bool hive_enabled, bool filename_enabled) { - - vector pruned_files; - vector have_preserved_filter(filters.size(), false); - vector> pruned_filters; - unordered_set filters_applied_to_files; - duckdb_re2::RE2 regex(REGEX_STRING); - auto table_index = get.table_index; - - if ((!filename_enabled && !hive_enabled) || filters.empty()) { - return; - } - - for (idx_t i = 0; i < files.size(); i++) { - auto &file = files[i]; - bool should_prune_file = false; - auto known_values = GetKnownColumnValues(file, column_map, regex, filename_enabled, hive_enabled); - - FilterCombiner combiner(context); - - for (idx_t j = 0; j < filters.size(); j++) { - auto &filter = filters[j]; - unique_ptr filter_copy = filter->Copy(); - ConvertKnownColRefToConstants(filter_copy, known_values, table_index); - // Evaluate the filter, if it can be evaluated here, we can not prune this filter - Value result_value; - - if (!filter_copy->IsScalar() || !filter_copy->IsFoldable() || - !ExpressionExecutor::TryEvaluateScalar(context, *filter_copy, result_value)) { - // can not be evaluated only with the filename/hive columns added, we can not prune this filter - if (!have_preserved_filter[j]) { - pruned_filters.emplace_back(filter->Copy()); - have_preserved_filter[j] = true; - } - } else if (!result_value.GetValue()) { - // filter evaluates to false - should_prune_file = true; - // convert the filter to a table filter. 
- if (filters_applied_to_files.find(j) == filters_applied_to_files.end()) { - get.extra_info.file_filters += filter->ToString(); - filters_applied_to_files.insert(j); - } - } - } - - if (!should_prune_file) { - pruned_files.push_back(file); - } - } - - D_ASSERT(filters.size() >= pruned_filters.size()); - - filters = std::move(pruned_filters); - files = std::move(pruned_files); -} - -HivePartitionedColumnData::HivePartitionedColumnData(const HivePartitionedColumnData &other) - : PartitionedColumnData(other), hashes_v(LogicalType::HASH) { - // Synchronize to ensure consistency of shared partition map - if (other.global_state) { - global_state = other.global_state; - unique_lock lck(global_state->lock); - SynchronizeLocalMap(); - } - InitializeKeys(); -} - -void HivePartitionedColumnData::InitializeKeys() { - keys.resize(STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { - keys[i].values.resize(group_by_columns.size()); - } -} - -template -static inline Value GetHiveKeyValue(const T &val) { - return Value::CreateValue(val); -} - -template -static inline Value GetHiveKeyValue(const T &val, const LogicalType &type) { - auto result = GetHiveKeyValue(val); - result.Reinterpret(type); - return result; -} - -static inline Value GetHiveKeyNullValue(const LogicalType &type) { - Value result; - result.Reinterpret(type); - return result; -} - -template -static void TemplatedGetHivePartitionValues(Vector &input, vector &keys, const idx_t col_idx, - const idx_t count) { - UnifiedVectorFormat format; - input.ToUnifiedFormat(count, format); - - const auto &sel = *format.sel; - const auto data = UnifiedVectorFormat::GetData(format); - const auto &validity = format.validity; - - const auto &type = input.GetType(); - - const auto reinterpret = Value::CreateValue(data[0]).GetTypeMutable() != type; - if (reinterpret) { - for (idx_t i = 0; i < count; i++) { - auto &key = keys[i]; - const auto idx = sel.get_index(i); - if (validity.RowIsValid(idx)) { - 
key.values[col_idx] = GetHiveKeyValue(data[idx], type); - } else { - key.values[col_idx] = GetHiveKeyNullValue(type); - } - } - } else { - for (idx_t i = 0; i < count; i++) { - auto &key = keys[i]; - const auto idx = sel.get_index(i); - if (validity.RowIsValid(idx)) { - key.values[col_idx] = GetHiveKeyValue(data[idx]); - } else { - key.values[col_idx] = GetHiveKeyNullValue(type); - } - } - } -} - -static void GetNestedHivePartitionValues(Vector &input, vector &keys, const idx_t col_idx, - const idx_t count) { - for (idx_t i = 0; i < count; i++) { - auto &key = keys[i]; - key.values[col_idx] = input.GetValue(i); - } -} - -static void GetHivePartitionValuesTypeSwitch(Vector &input, vector &keys, const idx_t col_idx, - const idx_t count) { - const auto &type = input.GetType(); - switch (type.InternalType()) { - case PhysicalType::BOOL: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::INT8: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::INT16: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::INT32: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::INT64: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::INT128: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::UINT8: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::UINT16: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::UINT32: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::UINT64: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::FLOAT: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::DOUBLE: - 
TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::INTERVAL: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::VARCHAR: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::STRUCT: - case PhysicalType::LIST: - GetNestedHivePartitionValues(input, keys, col_idx, count); - break; - default: - throw InternalException("Unsupported type for HivePartitionedColumnData::ComputePartitionIndices"); - } -} - -void HivePartitionedColumnData::ComputePartitionIndices(PartitionedColumnDataAppendState &state, DataChunk &input) { - const auto count = input.size(); - - input.Hash(group_by_columns, hashes_v); - hashes_v.Flatten(count); - - for (idx_t col_idx = 0; col_idx < group_by_columns.size(); col_idx++) { - auto &group_by_col = input.data[group_by_columns[col_idx]]; - GetHivePartitionValuesTypeSwitch(group_by_col, keys, col_idx, count); - } - - const auto hashes = FlatVector::GetData(hashes_v); - const auto partition_indices = FlatVector::GetData(state.partition_indices); - for (idx_t i = 0; i < count; i++) { - auto &key = keys[i]; - key.hash = hashes[i]; - auto lookup = local_partition_map.find(key); - if (lookup == local_partition_map.end()) { - idx_t new_partition_id = RegisterNewPartition(key, state); - partition_indices[i] = new_partition_id; - } else { - partition_indices[i] = lookup->second; - } - } -} - -std::map HivePartitionedColumnData::GetReverseMap() { - std::map ret; - for (const auto &pair : local_partition_map) { - ret[pair.second] = &(pair.first); - } - return ret; -} - -void HivePartitionedColumnData::GrowAllocators() { - unique_lock lck_gstate(allocators->lock); - - idx_t current_allocator_size = allocators->allocators.size(); - idx_t required_allocators = local_partition_map.size(); - - allocators->allocators.reserve(current_allocator_size); - for (idx_t i = current_allocator_size; i < required_allocators; i++) { - CreateAllocator(); - 
} - - D_ASSERT(allocators->allocators.size() == local_partition_map.size()); -} - -void HivePartitionedColumnData::GrowAppendState(PartitionedColumnDataAppendState &state) { - idx_t current_append_state_size = state.partition_append_states.size(); - idx_t required_append_state_size = local_partition_map.size(); - - for (idx_t i = current_append_state_size; i < required_append_state_size; i++) { - state.partition_append_states.emplace_back(make_uniq()); - state.partition_buffers.emplace_back(CreatePartitionBuffer()); - } -} - -void HivePartitionedColumnData::GrowPartitions(PartitionedColumnDataAppendState &state) { - idx_t current_partitions = partitions.size(); - idx_t required_partitions = local_partition_map.size(); - - D_ASSERT(allocators->allocators.size() == required_partitions); - - for (idx_t i = current_partitions; i < required_partitions; i++) { - partitions.emplace_back(CreatePartitionCollection(i)); - partitions[i]->InitializeAppend(*state.partition_append_states[i]); - } - D_ASSERT(partitions.size() == local_partition_map.size()); -} - -void HivePartitionedColumnData::SynchronizeLocalMap() { - // Synchronise global map into local, may contain changes from other threads too - for (auto it = global_state->partitions.begin() + local_partition_map.size(); it < global_state->partitions.end(); - it++) { - local_partition_map[(*it)->first] = (*it)->second; - } -} - -idx_t HivePartitionedColumnData::RegisterNewPartition(HivePartitionKey key, PartitionedColumnDataAppendState &state) { - if (global_state) { - idx_t partition_id; - - // Synchronize Global state with our local state with the newly discoveren partition - { - unique_lock lck_gstate(global_state->lock); - - // Insert into global map, or return partition if already present - auto res = - global_state->partition_map.emplace(std::make_pair(std::move(key), global_state->partition_map.size())); - auto it = res.first; - partition_id = it->second; - - // Add iterator to vector to allow incrementally updating 
local states from global state - global_state->partitions.emplace_back(it); - SynchronizeLocalMap(); - } - - // After synchronizing with the global state, we need to grow the shared allocators to support - // the number of partitions, which guarantees that there's always enough allocators available to each thread - GrowAllocators(); - - // Grow local partition data - GrowAppendState(state); - GrowPartitions(state); - - return partition_id; - } else { - return local_partition_map.emplace(std::make_pair(std::move(key), local_partition_map.size())).first->second; - } -} - -} // namespace duckdb - - -namespace duckdb { - -CachedFileHandle::CachedFileHandle(shared_ptr &file_p) { - // If the file was not yet initialized, we need to grab a lock. - if (!file_p->initialized) { - lock = make_uniq>(file_p->lock); - } - file = file_p; -} - -void CachedFileHandle::SetInitialized() { - if (file->initialized) { - throw InternalException("Cannot set initialized on cached file that was already initialized"); - } - if (!lock) { - throw InternalException("Cannot set initialized on cached file without lock"); - } - file->initialized = true; - lock = nullptr; -} - -void CachedFileHandle::AllocateBuffer(idx_t size) { - if (file->initialized) { - throw InternalException("Cannot allocate a buffer for a cached file that was already initialized"); - } - file->data = std::shared_ptr(new char[size], std::default_delete()); - file->capacity = size; -} - -void CachedFileHandle::GrowBuffer(idx_t new_capacity, idx_t bytes_to_copy) { - // copy shared ptr to old data - auto old_data = file->data; - // allocate new buffer that can hold the new capacity - AllocateBuffer(new_capacity); - // copy the old data - Write(old_data.get(), bytes_to_copy); -} - -void CachedFileHandle::Write(const char *buffer, idx_t length, idx_t offset) { - //! 
Only write to non-initialized files with a lock; - D_ASSERT(!file->initialized && lock); - memcpy(file->data.get() + offset, buffer, length); -} - -void HTTPState::Reset() { - // Reset Counters - head_count = 0; - get_count = 0; - put_count = 0; - post_count = 0; - total_bytes_received = 0; - total_bytes_sent = 0; - - // Reset cached files - cached_files.clear(); -} - -shared_ptr HTTPState::TryGetState(FileOpener *opener) { - auto client_context = FileOpener::TryGetClientContext(opener); - if (client_context) { - return client_context->client_data->http_state; - } - return nullptr; -} - -//! Get cache entry, create if not exists -shared_ptr &HTTPState::GetCachedFile(const string &path) { - lock_guard lock(cached_files_mutex); - auto &cache_entry_ref = cached_files[path]; - if (!cache_entry_ref) { - cache_entry_ref = make_shared(); - } - return cache_entry_ref; -} - -} // namespace duckdb - - - - - - - - - - - - -#include -#include -#include - -#ifndef _WIN32 -#include -#include -#include -#include -#include -#else - - -#include -#include - -#ifdef __MINGW32__ -// need to manually define this for mingw -extern "C" WINBASEAPI BOOL WINAPI GetPhysicallyInstalledSystemMemory(PULONGLONG); -#endif - -#undef FILE_CREATE // woo mingw -#endif - -namespace duckdb { - -static void AssertValidFileFlags(uint8_t flags) { -#ifdef DEBUG - bool is_read = flags & FileFlags::FILE_FLAGS_READ; - bool is_write = flags & FileFlags::FILE_FLAGS_WRITE; - // require either READ or WRITE (or both) - D_ASSERT(is_read || is_write); - // CREATE/Append flags require writing - D_ASSERT(is_write || !(flags & FileFlags::FILE_FLAGS_APPEND)); - D_ASSERT(is_write || !(flags & FileFlags::FILE_FLAGS_FILE_CREATE)); - D_ASSERT(is_write || !(flags & FileFlags::FILE_FLAGS_FILE_CREATE_NEW)); - // cannot combine CREATE and CREATE_NEW flags - D_ASSERT(!(flags & FileFlags::FILE_FLAGS_FILE_CREATE && flags & FileFlags::FILE_FLAGS_FILE_CREATE_NEW)); -#endif -} - -#ifndef _WIN32 -bool 
LocalFileSystem::FileExists(const string &filename) { - if (!filename.empty()) { - if (access(filename.c_str(), 0) == 0) { - struct stat status; - stat(filename.c_str(), &status); - if (S_ISREG(status.st_mode)) { - return true; - } - } - } - // if any condition fails - return false; -} - -bool LocalFileSystem::IsPipe(const string &filename) { - if (!filename.empty()) { - if (access(filename.c_str(), 0) == 0) { - struct stat status; - stat(filename.c_str(), &status); - if (S_ISFIFO(status.st_mode)) { - return true; - } - } - } - // if any condition fails - return false; -} - -#else -bool LocalFileSystem::FileExists(const string &filename) { - auto unicode_path = WindowsUtil::UTF8ToUnicode(filename.c_str()); - const wchar_t *wpath = unicode_path.c_str(); - if (_waccess(wpath, 0) == 0) { - struct _stati64 status; - _wstati64(wpath, &status); - if (status.st_mode & S_IFREG) { - return true; - } - } - return false; -} -bool LocalFileSystem::IsPipe(const string &filename) { - auto unicode_path = WindowsUtil::UTF8ToUnicode(filename.c_str()); - const wchar_t *wpath = unicode_path.c_str(); - if (_waccess(wpath, 0) == 0) { - struct _stati64 status; - _wstati64(wpath, &status); - if (status.st_mode & _S_IFCHR) { - return true; - } - } - return false; -} -#endif - -#ifndef _WIN32 -// somehow sometimes this is missing -#ifndef O_CLOEXEC -#define O_CLOEXEC 0 -#endif - -// Solaris -#ifndef O_DIRECT -#define O_DIRECT 0 -#endif - -struct UnixFileHandle : public FileHandle { -public: - UnixFileHandle(FileSystem &file_system, string path, int fd) : FileHandle(file_system, std::move(path)), fd(fd) { - } - ~UnixFileHandle() override { - UnixFileHandle::Close(); - } - - int fd; - -public: - void Close() override { - if (fd != -1) { - close(fd); - fd = -1; - } - }; -}; - -static FileType GetFileTypeInternal(int fd) { // LCOV_EXCL_START - struct stat s; - if (fstat(fd, &s) == -1) { - return FileType::FILE_TYPE_INVALID; - } - switch (s.st_mode & S_IFMT) { - case S_IFBLK: - return 
FileType::FILE_TYPE_BLOCKDEV; - case S_IFCHR: - return FileType::FILE_TYPE_CHARDEV; - case S_IFIFO: - return FileType::FILE_TYPE_FIFO; - case S_IFDIR: - return FileType::FILE_TYPE_DIR; - case S_IFLNK: - return FileType::FILE_TYPE_LINK; - case S_IFREG: - return FileType::FILE_TYPE_REGULAR; - case S_IFSOCK: - return FileType::FILE_TYPE_SOCKET; - default: - return FileType::FILE_TYPE_INVALID; - } -} // LCOV_EXCL_STOP - -unique_ptr LocalFileSystem::OpenFile(const string &path_p, uint8_t flags, FileLockType lock_type, - FileCompressionType compression, FileOpener *opener) { - auto path = FileSystem::ExpandPath(path_p, opener); - if (compression != FileCompressionType::UNCOMPRESSED) { - throw NotImplementedException("Unsupported compression type for default file system"); - } - - AssertValidFileFlags(flags); - - int open_flags = 0; - int rc; - bool open_read = flags & FileFlags::FILE_FLAGS_READ; - bool open_write = flags & FileFlags::FILE_FLAGS_WRITE; - if (open_read && open_write) { - open_flags = O_RDWR; - } else if (open_read) { - open_flags = O_RDONLY; - } else if (open_write) { - open_flags = O_WRONLY; - } else { - throw InternalException("READ, WRITE or both should be specified when opening a file"); - } - if (open_write) { - // need Read or Write - D_ASSERT(flags & FileFlags::FILE_FLAGS_WRITE); - open_flags |= O_CLOEXEC; - if (flags & FileFlags::FILE_FLAGS_FILE_CREATE) { - open_flags |= O_CREAT; - } else if (flags & FileFlags::FILE_FLAGS_FILE_CREATE_NEW) { - open_flags |= O_CREAT | O_TRUNC; - } - if (flags & FileFlags::FILE_FLAGS_APPEND) { - open_flags |= O_APPEND; - } - } - if (flags & FileFlags::FILE_FLAGS_DIRECT_IO) { -#if defined(__sun) && defined(__SVR4) - throw Exception("DIRECT_IO not supported on Solaris"); -#endif -#if defined(__DARWIN__) || defined(__APPLE__) || defined(__OpenBSD__) - // OSX does not have O_DIRECT, instead we need to use fcntl afterwards to support direct IO - open_flags |= O_SYNC; -#else - open_flags |= O_DIRECT | O_SYNC; -#endif - } - 
int fd = open(path.c_str(), open_flags, 0666); - if (fd == -1) { - throw IOException("Cannot open file \"%s\": %s", path, strerror(errno)); - } - // #if defined(__DARWIN__) || defined(__APPLE__) - // if (flags & FileFlags::FILE_FLAGS_DIRECT_IO) { - // // OSX requires fcntl for Direct IO - // rc = fcntl(fd, F_NOCACHE, 1); - // if (fd == -1) { - // throw IOException("Could not enable direct IO for file \"%s\": %s", path, strerror(errno)); - // } - // } - // #endif - if (lock_type != FileLockType::NO_LOCK) { - // set lock on file - // but only if it is not an input/output stream - auto file_type = GetFileTypeInternal(fd); - if (file_type != FileType::FILE_TYPE_FIFO && file_type != FileType::FILE_TYPE_SOCKET) { - struct flock fl; - memset(&fl, 0, sizeof fl); - fl.l_type = lock_type == FileLockType::READ_LOCK ? F_RDLCK : F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - fl.l_len = 0; - rc = fcntl(fd, F_SETLK, &fl); - if (rc == -1) { - throw IOException("Could not set lock on file \"%s\": %s", path, strerror(errno)); - } - } - } - return make_uniq(*this, path, fd); -} - -void LocalFileSystem::SetFilePointer(FileHandle &handle, idx_t location) { - int fd = handle.Cast().fd; - off_t offset = lseek(fd, location, SEEK_SET); - if (offset == (off_t)-1) { - throw IOException("Could not seek to location %lld for file \"%s\": %s", location, handle.path, - strerror(errno)); - } -} - -idx_t LocalFileSystem::GetFilePointer(FileHandle &handle) { - int fd = handle.Cast().fd; - off_t position = lseek(fd, 0, SEEK_CUR); - if (position == (off_t)-1) { - throw IOException("Could not get file position file \"%s\": %s", handle.path, strerror(errno)); - } - return position; -} - -void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - int fd = handle.Cast().fd; - auto read_buffer = char_ptr_cast(buffer); - while (nr_bytes > 0) { - int64_t bytes_read = pread(fd, read_buffer, nr_bytes, location); - if (bytes_read == -1) { - throw 
IOException("Could not read from file \"%s\": %s", handle.path, strerror(errno)); - } - if (bytes_read == 0) { - throw IOException( - "Could not read enough bytes from file \"%s\": attempted to read %llu bytes from location %llu", - handle.path, nr_bytes, location); - } - read_buffer += bytes_read; - nr_bytes -= bytes_read; - } -} - -int64_t LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - int fd = handle.Cast().fd; - int64_t bytes_read = read(fd, buffer, nr_bytes); - if (bytes_read == -1) { - throw IOException("Could not read from file \"%s\": %s", handle.path, strerror(errno)); - } - return bytes_read; -} - -void LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - int fd = handle.Cast().fd; - auto write_buffer = char_ptr_cast(buffer); - while (nr_bytes > 0) { - int64_t bytes_written = pwrite(fd, write_buffer, nr_bytes, location); - if (bytes_written < 0) { - throw IOException("Could not write file \"%s\": %s", handle.path, strerror(errno)); - } - D_ASSERT(bytes_written >= 0 && bytes_written); - write_buffer += bytes_written; - nr_bytes -= bytes_written; - } -} - -int64_t LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { - int fd = handle.Cast().fd; - int64_t bytes_written = write(fd, buffer, nr_bytes); - if (bytes_written == -1) { - throw IOException("Could not write file \"%s\": %s", handle.path, strerror(errno)); - } - return bytes_written; -} - -int64_t LocalFileSystem::GetFileSize(FileHandle &handle) { - int fd = handle.Cast().fd; - struct stat s; - if (fstat(fd, &s) == -1) { - return -1; - } - return s.st_size; -} - -time_t LocalFileSystem::GetLastModifiedTime(FileHandle &handle) { - int fd = handle.Cast().fd; - struct stat s; - if (fstat(fd, &s) == -1) { - return -1; - } - return s.st_mtime; -} - -FileType LocalFileSystem::GetFileType(FileHandle &handle) { - int fd = handle.Cast().fd; - return GetFileTypeInternal(fd); -} - -void 
LocalFileSystem::Truncate(FileHandle &handle, int64_t new_size) { - int fd = handle.Cast().fd; - if (ftruncate(fd, new_size) != 0) { - throw IOException("Could not truncate file \"%s\": %s", handle.path, strerror(errno)); - } -} - -bool LocalFileSystem::DirectoryExists(const string &directory) { - if (!directory.empty()) { - if (access(directory.c_str(), 0) == 0) { - struct stat status; - stat(directory.c_str(), &status); - if (status.st_mode & S_IFDIR) { - return true; - } - } - } - // if any condition fails - return false; -} - -void LocalFileSystem::CreateDirectory(const string &directory) { - struct stat st; - - if (stat(directory.c_str(), &st) != 0) { - /* Directory does not exist. EEXIST for race condition */ - if (mkdir(directory.c_str(), 0755) != 0 && errno != EEXIST) { - throw IOException("Failed to create directory \"%s\"!", directory); - } - } else if (!S_ISDIR(st.st_mode)) { - throw IOException("Failed to create directory \"%s\": path exists but is not a directory!", directory); - } -} - -int RemoveDirectoryRecursive(const char *path) { - DIR *d = opendir(path); - idx_t path_len = (idx_t)strlen(path); - int r = -1; - - if (d) { - struct dirent *p; - r = 0; - while (!r && (p = readdir(d))) { - int r2 = -1; - char *buf; - idx_t len; - /* Skip the names "." and ".." as we don't want to recurse on them. 
*/ - if (!strcmp(p->d_name, ".") || !strcmp(p->d_name, "..")) { - continue; - } - len = path_len + (idx_t)strlen(p->d_name) + 2; - buf = new (std::nothrow) char[len]; - if (buf) { - struct stat statbuf; - snprintf(buf, len, "%s/%s", path, p->d_name); - if (!stat(buf, &statbuf)) { - if (S_ISDIR(statbuf.st_mode)) { - r2 = RemoveDirectoryRecursive(buf); - } else { - r2 = unlink(buf); - } - } - delete[] buf; - } - r = r2; - } - closedir(d); - } - if (!r) { - r = rmdir(path); - } - return r; -} - -void LocalFileSystem::RemoveDirectory(const string &directory) { - RemoveDirectoryRecursive(directory.c_str()); -} - -void LocalFileSystem::RemoveFile(const string &filename) { - if (std::remove(filename.c_str()) != 0) { - throw IOException("Could not remove file \"%s\": %s", filename, strerror(errno)); - } -} - -bool LocalFileSystem::ListFiles(const string &directory, const std::function &callback, - FileOpener *opener) { - if (!DirectoryExists(directory)) { - return false; - } - DIR *dir = opendir(directory.c_str()); - if (!dir) { - return false; - } - struct dirent *ent; - // loop over all files in the directory - while ((ent = readdir(dir)) != nullptr) { - string name = string(ent->d_name); - // skip . .. and empty files - if (name.empty() || name == "." 
|| name == "..") { - continue; - } - // now stat the file to figure out if it is a regular file or directory - string full_path = JoinPath(directory, name); - if (access(full_path.c_str(), 0) != 0) { - continue; - } - struct stat status; - stat(full_path.c_str(), &status); - if (!(status.st_mode & S_IFREG) && !(status.st_mode & S_IFDIR)) { - // not a file or directory: skip - continue; - } - // invoke callback - callback(name, status.st_mode & S_IFDIR); - } - closedir(dir); - return true; -} - -void LocalFileSystem::FileSync(FileHandle &handle) { - int fd = handle.Cast().fd; - if (fsync(fd) != 0) { - throw FatalException("fsync failed!"); - } -} - -void LocalFileSystem::MoveFile(const string &source, const string &target) { - //! FIXME: rename does not guarantee atomicity or overwriting target file if it exists - if (rename(source.c_str(), target.c_str()) != 0) { - throw IOException("Could not rename file!"); - } -} - -std::string LocalFileSystem::GetLastErrorAsString() { - return string(); -} - -#else - -constexpr char PIPE_PREFIX[] = "\\\\.\\pipe\\"; - -// Returns the last Win32 error, in string format. Returns an empty string if there is no error. -std::string LocalFileSystem::GetLastErrorAsString() { - // Get the error message, if any. - DWORD errorMessageID = GetLastError(); - if (errorMessageID == 0) - return std::string(); // No error message has been recorded - - LPSTR messageBuffer = nullptr; - idx_t size = - FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, errorMessageID, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); - - std::string message(messageBuffer, size); - - // Free the buffer. 
- LocalFree(messageBuffer); - - return message; -} - -struct WindowsFileHandle : public FileHandle { -public: - WindowsFileHandle(FileSystem &file_system, string path, HANDLE fd) - : FileHandle(file_system, path), position(0), fd(fd) { - } - ~WindowsFileHandle() override { - Close(); - } - - idx_t position; - HANDLE fd; - -public: - void Close() override { - if (!fd) { - return; - } - CloseHandle(fd); - fd = nullptr; - }; -}; - -unique_ptr LocalFileSystem::OpenFile(const string &path_p, uint8_t flags, FileLockType lock_type, - FileCompressionType compression, FileOpener *opener) { - auto path = FileSystem::ExpandPath(path_p, opener); - if (compression != FileCompressionType::UNCOMPRESSED) { - throw NotImplementedException("Unsupported compression type for default file system"); - } - AssertValidFileFlags(flags); - - DWORD desired_access; - DWORD share_mode; - DWORD creation_disposition = OPEN_EXISTING; - DWORD flags_and_attributes = FILE_ATTRIBUTE_NORMAL; - bool open_read = flags & FileFlags::FILE_FLAGS_READ; - bool open_write = flags & FileFlags::FILE_FLAGS_WRITE; - if (open_read && open_write) { - desired_access = GENERIC_READ | GENERIC_WRITE; - share_mode = 0; - } else if (open_read) { - desired_access = GENERIC_READ; - share_mode = FILE_SHARE_READ; - } else if (open_write) { - desired_access = GENERIC_WRITE; - share_mode = 0; - } else { - throw InternalException("READ, WRITE or both should be specified when opening a file"); - } - if (open_write) { - if (flags & FileFlags::FILE_FLAGS_FILE_CREATE) { - creation_disposition = OPEN_ALWAYS; - } else if (flags & FileFlags::FILE_FLAGS_FILE_CREATE_NEW) { - creation_disposition = CREATE_ALWAYS; - } - } - if (flags & FileFlags::FILE_FLAGS_DIRECT_IO) { - flags_and_attributes |= FILE_FLAG_NO_BUFFERING; - } - auto unicode_path = WindowsUtil::UTF8ToUnicode(path.c_str()); - HANDLE hFile = CreateFileW(unicode_path.c_str(), desired_access, share_mode, NULL, creation_disposition, - flags_and_attributes, NULL); - if (hFile == 
INVALID_HANDLE_VALUE) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Cannot open file \"%s\": %s", path.c_str(), error); - } - auto handle = make_uniq(*this, path.c_str(), hFile); - if (flags & FileFlags::FILE_FLAGS_APPEND) { - auto file_size = GetFileSize(*handle); - SetFilePointer(*handle, file_size); - } - return std::move(handle); -} - -void LocalFileSystem::SetFilePointer(FileHandle &handle, idx_t location) { - auto &whandle = handle.Cast(); - whandle.position = location; - LARGE_INTEGER wlocation; - wlocation.QuadPart = location; - SetFilePointerEx(whandle.fd, wlocation, NULL, FILE_BEGIN); -} - -idx_t LocalFileSystem::GetFilePointer(FileHandle &handle) { - return handle.Cast().position; -} - -static DWORD FSInternalRead(FileHandle &handle, HANDLE hFile, void *buffer, int64_t nr_bytes, idx_t location) { - DWORD bytes_read = 0; - OVERLAPPED ov = {}; - ov.Internal = 0; - ov.InternalHigh = 0; - ov.Offset = location & 0xFFFFFFFF; - ov.OffsetHigh = location >> 32; - ov.hEvent = 0; - auto rc = ReadFile(hFile, buffer, (DWORD)nr_bytes, &bytes_read, &ov); - if (!rc) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Could not read file \"%s\" (error in ReadFile(location: %llu, nr_bytes: %lld)): %s", - handle.path, location, nr_bytes, error); - } - return bytes_read; -} - -void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - HANDLE hFile = ((WindowsFileHandle &)handle).fd; - auto bytes_read = FSInternalRead(handle, hFile, buffer, nr_bytes, location); - if (bytes_read != nr_bytes) { - throw IOException("Could not read all bytes from file \"%s\": wanted=%lld read=%lld", handle.path, nr_bytes, - bytes_read); - } -} - -int64_t LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - HANDLE hFile = handle.Cast().fd; - auto &pos = handle.Cast().position; - auto n = std::min(std::max(GetFileSize(handle), pos) - pos, nr_bytes); - auto bytes_read = 
FSInternalRead(handle, hFile, buffer, n, pos); - pos += bytes_read; - return bytes_read; -} - -static DWORD FSInternalWrite(FileHandle &handle, HANDLE hFile, void *buffer, int64_t nr_bytes, idx_t location) { - DWORD bytes_written = 0; - OVERLAPPED ov = {}; - ov.Internal = 0; - ov.InternalHigh = 0; - ov.Offset = location & 0xFFFFFFFF; - ov.OffsetHigh = location >> 32; - ov.hEvent = 0; - auto rc = WriteFile(hFile, buffer, (DWORD)nr_bytes, &bytes_written, &ov); - if (!rc) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Could not write file \"%s\" (error in WriteFile): %s", handle.path, error); - } - return bytes_written; -} - -void LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - HANDLE hFile = handle.Cast().fd; - auto bytes_written = FSInternalWrite(handle, hFile, buffer, nr_bytes, location); - if (bytes_written != nr_bytes) { - throw IOException("Could not write all bytes from file \"%s\": wanted=%lld wrote=%lld", handle.path, nr_bytes, - bytes_written); - } -} - -int64_t LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { - HANDLE hFile = handle.Cast().fd; - auto &pos = handle.Cast().position; - auto bytes_written = FSInternalWrite(handle, hFile, buffer, nr_bytes, pos); - pos += bytes_written; - return bytes_written; -} - -int64_t LocalFileSystem::GetFileSize(FileHandle &handle) { - HANDLE hFile = handle.Cast().fd; - LARGE_INTEGER result; - if (!GetFileSizeEx(hFile, &result)) { - return -1; - } - return result.QuadPart; -} - -time_t LocalFileSystem::GetLastModifiedTime(FileHandle &handle) { - HANDLE hFile = handle.Cast().fd; - - // https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getfiletime - FILETIME last_write; - if (GetFileTime(hFile, nullptr, nullptr, &last_write) == 0) { - return -1; - } - - // https://stackoverflow.com/questions/29266743/what-is-dwlowdatetime-and-dwhighdatetime - ULARGE_INTEGER ul; - ul.LowPart = 
last_write.dwLowDateTime; - ul.HighPart = last_write.dwHighDateTime; - int64_t fileTime64 = ul.QuadPart; - - // fileTime64 contains a 64-bit value representing the number of - // 100-nanosecond intervals since January 1, 1601 (UTC). - // https://docs.microsoft.com/en-us/windows/win32/api/minwinbase/ns-minwinbase-filetime - - // Adapted from: https://stackoverflow.com/questions/6161776/convert-windows-filetime-to-second-in-unix-linux - const auto WINDOWS_TICK = 10000000; - const auto SEC_TO_UNIX_EPOCH = 11644473600LL; - time_t result = (fileTime64 / WINDOWS_TICK - SEC_TO_UNIX_EPOCH); - return result; -} - -void LocalFileSystem::Truncate(FileHandle &handle, int64_t new_size) { - HANDLE hFile = handle.Cast().fd; - // seek to the location - SetFilePointer(handle, new_size); - // now set the end of file position - if (!SetEndOfFile(hFile)) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Failure in SetEndOfFile call on file \"%s\": %s", handle.path, error); - } -} - -static DWORD WindowsGetFileAttributes(const string &filename) { - auto unicode_path = WindowsUtil::UTF8ToUnicode(filename.c_str()); - return GetFileAttributesW(unicode_path.c_str()); -} - -bool LocalFileSystem::DirectoryExists(const string &directory) { - DWORD attrs = WindowsGetFileAttributes(directory); - return (attrs != INVALID_FILE_ATTRIBUTES && (attrs & FILE_ATTRIBUTE_DIRECTORY)); -} - -void LocalFileSystem::CreateDirectory(const string &directory) { - if (DirectoryExists(directory)) { - return; - } - auto unicode_path = WindowsUtil::UTF8ToUnicode(directory.c_str()); - if (directory.empty() || !CreateDirectoryW(unicode_path.c_str(), NULL) || !DirectoryExists(directory)) { - throw IOException("Could not create directory: \'%s\'", directory.c_str()); - } -} - -static void DeleteDirectoryRecursive(FileSystem &fs, string directory) { - fs.ListFiles(directory, [&](const string &fname, bool is_directory) { - if (is_directory) { - DeleteDirectoryRecursive(fs, 
fs.JoinPath(directory, fname)); - } else { - fs.RemoveFile(fs.JoinPath(directory, fname)); - } - }); - auto unicode_path = WindowsUtil::UTF8ToUnicode(directory.c_str()); - if (!RemoveDirectoryW(unicode_path.c_str())) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Failed to delete directory \"%s\": %s", directory, error); - } -} - -void LocalFileSystem::RemoveDirectory(const string &directory) { - if (FileExists(directory)) { - throw IOException("Attempting to delete directory \"%s\", but it is a file and not a directory!", directory); - } - if (!DirectoryExists(directory)) { - return; - } - DeleteDirectoryRecursive(*this, directory.c_str()); -} - -void LocalFileSystem::RemoveFile(const string &filename) { - auto unicode_path = WindowsUtil::UTF8ToUnicode(filename.c_str()); - if (!DeleteFileW(unicode_path.c_str())) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Failed to delete file \"%s\": %s", filename, error); - } -} - -bool LocalFileSystem::ListFiles(const string &directory, const std::function &callback, - FileOpener *opener) { - string search_dir = JoinPath(directory, "*"); - - auto unicode_path = WindowsUtil::UTF8ToUnicode(search_dir.c_str()); - - WIN32_FIND_DATAW ffd; - HANDLE hFind = FindFirstFileW(unicode_path.c_str(), &ffd); - if (hFind == INVALID_HANDLE_VALUE) { - return false; - } - do { - string cFileName = WindowsUtil::UnicodeToUTF8(ffd.cFileName); - if (cFileName == "." 
|| cFileName == "..") { - continue; - } - callback(cFileName, ffd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY); - } while (FindNextFileW(hFind, &ffd) != 0); - - DWORD dwError = GetLastError(); - if (dwError != ERROR_NO_MORE_FILES) { - FindClose(hFind); - return false; - } - - FindClose(hFind); - return true; -} - -void LocalFileSystem::FileSync(FileHandle &handle) { - HANDLE hFile = handle.Cast().fd; - if (FlushFileBuffers(hFile) == 0) { - throw IOException("Could not flush file handle to disk!"); - } -} - -void LocalFileSystem::MoveFile(const string &source, const string &target) { - auto source_unicode = WindowsUtil::UTF8ToUnicode(source.c_str()); - auto target_unicode = WindowsUtil::UTF8ToUnicode(target.c_str()); - if (!MoveFileW(source_unicode.c_str(), target_unicode.c_str())) { - throw IOException("Could not move file: %s", GetLastErrorAsString()); - } -} - -FileType LocalFileSystem::GetFileType(FileHandle &handle) { - auto path = handle.Cast().path; - // pipes in windows are just files in '\\.\pipe\' folder - if (strncmp(path.c_str(), PIPE_PREFIX, strlen(PIPE_PREFIX)) == 0) { - return FileType::FILE_TYPE_FIFO; - } - DWORD attrs = WindowsGetFileAttributes(path.c_str()); - if (attrs != INVALID_FILE_ATTRIBUTES) { - if (attrs & FILE_ATTRIBUTE_DIRECTORY) { - return FileType::FILE_TYPE_DIR; - } else { - return FileType::FILE_TYPE_REGULAR; - } - } - return FileType::FILE_TYPE_INVALID; -} -#endif - -bool LocalFileSystem::CanSeek() { - return true; -} - -bool LocalFileSystem::OnDiskFile(FileHandle &handle) { - return true; -} - -void LocalFileSystem::Seek(FileHandle &handle, idx_t location) { - if (!CanSeek()) { - throw IOException("Cannot seek in files of this type"); - } - SetFilePointer(handle, location); -} - -idx_t LocalFileSystem::SeekPosition(FileHandle &handle) { - if (!CanSeek()) { - throw IOException("Cannot seek in files of this type"); - } - return GetFilePointer(handle); -} - -static bool IsCrawl(const string &glob) { - // glob must match exactly - return 
glob == "**"; -} -static bool HasMultipleCrawl(const vector &splits) { - return std::count(splits.begin(), splits.end(), "**") > 1; -} -static bool IsSymbolicLink(const string &path) { -#ifndef _WIN32 - struct stat status; - return (lstat(path.c_str(), &status) != -1 && S_ISLNK(status.st_mode)); -#else - auto attributes = WindowsGetFileAttributes(path); - if (attributes == INVALID_FILE_ATTRIBUTES) - return false; - return attributes & FILE_ATTRIBUTE_REPARSE_POINT; -#endif -} - -static void RecursiveGlobDirectories(FileSystem &fs, const string &path, vector &result, bool match_directory, - bool join_path) { - - fs.ListFiles(path, [&](const string &fname, bool is_directory) { - string concat; - if (join_path) { - concat = fs.JoinPath(path, fname); - } else { - concat = fname; - } - if (IsSymbolicLink(concat)) { - return; - } - if (is_directory == match_directory) { - result.push_back(concat); - } - if (is_directory) { - RecursiveGlobDirectories(fs, concat, result, match_directory, true); - } - }); -} - -static void GlobFilesInternal(FileSystem &fs, const string &path, const string &glob, bool match_directory, - vector &result, bool join_path) { - fs.ListFiles(path, [&](const string &fname, bool is_directory) { - if (is_directory != match_directory) { - return; - } - if (LikeFun::Glob(fname.c_str(), fname.size(), glob.c_str(), glob.size())) { - if (join_path) { - result.push_back(fs.JoinPath(path, fname)); - } else { - result.push_back(fname); - } - } - }); -} - -vector LocalFileSystem::FetchFileWithoutGlob(const string &path, FileOpener *opener, bool absolute_path) { - vector result; - if (FileExists(path) || IsPipe(path)) { - result.push_back(path); - } else if (!absolute_path) { - Value value; - if (opener && opener->TryGetCurrentSetting("file_search_path", value)) { - auto search_paths_str = value.ToString(); - vector search_paths = StringUtil::Split(search_paths_str, ','); - for (const auto &search_path : search_paths) { - auto joined_path = JoinPath(search_path, 
path); - if (FileExists(joined_path) || IsPipe(joined_path)) { - result.push_back(joined_path); - } - } - } - } - return result; -} - -vector LocalFileSystem::Glob(const string &path, FileOpener *opener) { - if (path.empty()) { - return vector(); - } - // split up the path into separate chunks - vector splits; - idx_t last_pos = 0; - for (idx_t i = 0; i < path.size(); i++) { - if (path[i] == '\\' || path[i] == '/') { - if (i == last_pos) { - // empty: skip this position - last_pos = i + 1; - continue; - } - if (splits.empty()) { - splits.push_back(path.substr(0, i)); - } else { - splits.push_back(path.substr(last_pos, i - last_pos)); - } - last_pos = i + 1; - } - } - splits.push_back(path.substr(last_pos, path.size() - last_pos)); - // handle absolute paths - bool absolute_path = false; - if (path[0] == '/') { - // first character is a slash - unix absolute path - absolute_path = true; - } else if (StringUtil::Contains(splits[0], ":")) { - // first split has a colon - windows absolute path - absolute_path = true; - } else if (splits[0] == "~") { - // starts with home directory - auto home_directory = GetHomeDirectory(opener); - if (!home_directory.empty()) { - absolute_path = true; - splits[0] = home_directory; - D_ASSERT(path[0] == '~'); - if (!HasGlob(path)) { - return Glob(home_directory + path.substr(1)); - } - } - } - // Check if the path has a glob at all - if (!HasGlob(path)) { - // no glob: return only the file (if it exists or is a pipe) - return FetchFileWithoutGlob(path, opener, absolute_path); - } - vector previous_directories; - if (absolute_path) { - // for absolute paths, we don't start by scanning the current directory - previous_directories.push_back(splits[0]); - } else { - // If file_search_path is set, use those paths as the first glob elements - Value value; - if (opener && opener->TryGetCurrentSetting("file_search_path", value)) { - auto search_paths_str = value.ToString(); - vector search_paths = StringUtil::Split(search_paths_str, ','); - 
for (const auto &search_path : search_paths) { - previous_directories.push_back(search_path); - } - } - } - - if (HasMultipleCrawl(splits)) { - throw IOException("Cannot use multiple \'**\' in one path"); - } - - for (idx_t i = absolute_path ? 1 : 0; i < splits.size(); i++) { - bool is_last_chunk = i + 1 == splits.size(); - bool has_glob = HasGlob(splits[i]); - // if it's the last chunk we need to find files, otherwise we find directories - // not the last chunk: gather a list of all directories that match the glob pattern - vector result; - if (!has_glob) { - // no glob, just append as-is - if (previous_directories.empty()) { - result.push_back(splits[i]); - } else { - if (is_last_chunk) { - for (auto &prev_directory : previous_directories) { - const string filename = JoinPath(prev_directory, splits[i]); - if (FileExists(filename) || DirectoryExists(filename)) { - result.push_back(filename); - } - } - } else { - for (auto &prev_directory : previous_directories) { - result.push_back(JoinPath(prev_directory, splits[i])); - } - } - } - } else { - if (IsCrawl(splits[i])) { - if (!is_last_chunk) { - result = previous_directories; - } - if (previous_directories.empty()) { - RecursiveGlobDirectories(*this, ".", result, !is_last_chunk, false); - } else { - for (auto &prev_dir : previous_directories) { - RecursiveGlobDirectories(*this, prev_dir, result, !is_last_chunk, true); - } - } - } else { - if (previous_directories.empty()) { - // no previous directories: list in the current path - GlobFilesInternal(*this, ".", splits[i], !is_last_chunk, result, false); - } else { - // previous directories - // we iterate over each of the previous directories, and apply the glob of the current directory - for (auto &prev_directory : previous_directories) { - GlobFilesInternal(*this, prev_directory, splits[i], !is_last_chunk, result, true); - } - } - } - } - if (result.empty()) { - // no result found that matches the glob - // last ditch effort: search the path as a string literal - 
return FetchFileWithoutGlob(path, opener, absolute_path); - } - if (is_last_chunk) { - return result; - } - previous_directories = std::move(result); - } - return vector(); -} - -unique_ptr FileSystem::CreateLocal() { - return make_uniq(); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -void MultiFileReader::AddParameters(TableFunction &table_function) { - table_function.named_parameters["filename"] = LogicalType::BOOLEAN; - table_function.named_parameters["hive_partitioning"] = LogicalType::BOOLEAN; - table_function.named_parameters["union_by_name"] = LogicalType::BOOLEAN; - table_function.named_parameters["hive_types"] = LogicalType::ANY; - table_function.named_parameters["hive_types_autocast"] = LogicalType::BOOLEAN; -} - -vector MultiFileReader::GetFileList(ClientContext &context, const Value &input, const string &name, - FileGlobOptions options) { - auto &config = DBConfig::GetConfig(context); - if (!config.options.enable_external_access) { - throw PermissionException("Scanning %s files is disabled through configuration", name); - } - if (input.IsNull()) { - throw ParserException("%s reader cannot take NULL list as parameter", name); - } - FileSystem &fs = FileSystem::GetFileSystem(context); - vector files; - if (input.type().id() == LogicalTypeId::VARCHAR) { - auto file_name = StringValue::Get(input); - files = fs.GlobFiles(file_name, context, options); - } else if (input.type().id() == LogicalTypeId::LIST) { - for (auto &val : ListValue::GetChildren(input)) { - if (val.IsNull()) { - throw ParserException("%s reader cannot take NULL input as parameter", name); - } - if (val.type().id() != LogicalTypeId::VARCHAR) { - throw ParserException("%s reader can only take a list of strings as a parameter", name); - } - auto glob_files = fs.GlobFiles(StringValue::Get(val), context, options); - files.insert(files.end(), glob_files.begin(), glob_files.end()); - } - } else { - throw InternalException("Unsupported type for 
MultiFileReader::GetFileList"); - } - if (files.empty() && options == FileGlobOptions::DISALLOW_EMPTY) { - throw IOException("%s reader needs at least one file to read", name); - } - return files; -} - -bool MultiFileReader::ParseOption(const string &key, const Value &val, MultiFileReaderOptions &options, - ClientContext &context) { - auto loption = StringUtil::Lower(key); - if (loption == "filename") { - options.filename = BooleanValue::Get(val); - } else if (loption == "hive_partitioning") { - options.hive_partitioning = BooleanValue::Get(val); - options.auto_detect_hive_partitioning = false; - } else if (loption == "union_by_name") { - options.union_by_name = BooleanValue::Get(val); - } else if (loption == "hive_types_autocast" || loption == "hive_type_autocast") { - options.hive_types_autocast = BooleanValue::Get(val); - } else if (loption == "hive_types" || loption == "hive_type") { - if (val.type().id() != LogicalTypeId::STRUCT) { - throw InvalidInputException( - "'hive_types' only accepts a STRUCT('name':VARCHAR, ...), but '%s' was provided", - val.type().ToString()); - } - // verify that that all the children of the struct value are VARCHAR - auto &children = StructValue::GetChildren(val); - for (idx_t i = 0; i < children.size(); i++) { - const Value &child = children[i]; - if (child.type().id() != LogicalType::VARCHAR) { - throw InvalidInputException("hive_types: '%s' must be a VARCHAR, instead: '%s' was provided", - StructType::GetChildName(val.type(), i), child.type().ToString()); - } - // for every child of the struct, get the logical type - LogicalType transformed_type = TransformStringToLogicalType(child.ToString(), context); - const string &name = StructType::GetChildName(val.type(), i); - options.hive_types_schema[name] = transformed_type; - } - D_ASSERT(!options.hive_types_schema.empty()); - } else { - return false; - } - return true; -} - -bool MultiFileReader::ComplexFilterPushdown(ClientContext &context, vector &files, - const 
MultiFileReaderOptions &options, LogicalGet &get, - vector> &filters) { - if (files.empty()) { - return false; - } - if (!options.hive_partitioning && !options.filename) { - return false; - } - - unordered_map column_map; - for (idx_t i = 0; i < get.column_ids.size(); i++) { - column_map.insert({get.names[get.column_ids[i]], i}); - } - - auto start_files = files.size(); - HivePartitioning::ApplyFiltersToFileList(context, files, filters, column_map, get, options.hive_partitioning, - options.filename); - - if (files.size() != start_files) { - // we have pruned files - return true; - } - return false; -} - -MultiFileReaderBindData MultiFileReader::BindOptions(MultiFileReaderOptions &options, const vector &files, - vector &return_types, vector &names) { - MultiFileReaderBindData bind_data; - // Add generated constant column for filename - if (options.filename) { - if (std::find(names.begin(), names.end(), "filename") != names.end()) { - throw BinderException("Using filename option on file with column named filename is not supported"); - } - bind_data.filename_idx = names.size(); - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("filename"); - } - - // Add generated constant columns from hive partitioning scheme - if (options.hive_partitioning) { - D_ASSERT(!files.empty()); - auto partitions = HivePartitioning::Parse(files[0]); - // verify that all files have the same hive partitioning scheme - for (auto &f : files) { - auto file_partitions = HivePartitioning::Parse(f); - for (auto &part_info : partitions) { - if (file_partitions.find(part_info.first) == file_partitions.end()) { - string error = "Hive partition mismatch between file \"%s\" and \"%s\": key \"%s\" not found"; - if (options.auto_detect_hive_partitioning == true) { - throw InternalException(error + "(hive partitioning was autodetected)", files[0], f, - part_info.first); - } - throw BinderException(error.c_str(), files[0], f, part_info.first); - } - } - if (partitions.size() != 
file_partitions.size()) { - string error_msg = "Hive partition mismatch between file \"%s\" and \"%s\""; - if (options.auto_detect_hive_partitioning == true) { - throw InternalException(error_msg + "(hive partitioning was autodetected)", files[0], f); - } - throw BinderException(error_msg.c_str(), files[0], f); - } - } - - if (!options.hive_types_schema.empty()) { - // verify that all hive_types are existing partitions - options.VerifyHiveTypesArePartitions(partitions); - } - - for (auto &part : partitions) { - idx_t hive_partitioning_index = DConstants::INVALID_INDEX; - auto lookup = std::find(names.begin(), names.end(), part.first); - if (lookup != names.end()) { - // hive partitioning column also exists in file - override - auto idx = lookup - names.begin(); - hive_partitioning_index = idx; - return_types[idx] = options.GetHiveLogicalType(part.first); - } else { - // hive partitioning column does not exist in file - add a new column containing the key - hive_partitioning_index = names.size(); - return_types.emplace_back(options.GetHiveLogicalType(part.first)); - names.emplace_back(part.first); - } - bind_data.hive_partitioning_indexes.emplace_back(part.first, hive_partitioning_index); - } - } - return bind_data; -} - -void MultiFileReader::FinalizeBind(const MultiFileReaderOptions &file_options, const MultiFileReaderBindData &options, - const string &filename, const vector &local_names, - const vector &global_types, const vector &global_names, - const vector &global_column_ids, MultiFileReaderData &reader_data, - ClientContext &context) { - - // create a map of name -> column index - case_insensitive_map_t name_map; - if (file_options.union_by_name) { - for (idx_t col_idx = 0; col_idx < local_names.size(); col_idx++) { - name_map[local_names[col_idx]] = col_idx; - } - } - for (idx_t i = 0; i < global_column_ids.size(); i++) { - auto column_id = global_column_ids[i]; - if (IsRowIdColumnId(column_id)) { - // row-id - reader_data.constant_map.emplace_back(i, 
Value::BIGINT(42)); - continue; - } - if (column_id == options.filename_idx) { - // filename - reader_data.constant_map.emplace_back(i, Value(filename)); - continue; - } - if (!options.hive_partitioning_indexes.empty()) { - // hive partition constants - auto partitions = HivePartitioning::Parse(filename); - D_ASSERT(partitions.size() == options.hive_partitioning_indexes.size()); - bool found_partition = false; - for (auto &entry : options.hive_partitioning_indexes) { - if (column_id == entry.index) { - Value value = file_options.GetHivePartitionValue(partitions[entry.value], entry.value, context); - reader_data.constant_map.emplace_back(i, value); - found_partition = true; - break; - } - } - if (found_partition) { - continue; - } - } - if (file_options.union_by_name) { - auto &global_name = global_names[column_id]; - auto entry = name_map.find(global_name); - bool not_present_in_file = entry == name_map.end(); - if (not_present_in_file) { - // we need to project a column with name \"global_name\" - but it does not exist in the current file - // push a NULL value of the specified type - reader_data.constant_map.emplace_back(i, Value(global_types[column_id])); - continue; - } - } - } -} - -void MultiFileReader::CreateNameMapping(const string &file_name, const vector &local_types, - const vector &local_names, const vector &global_types, - const vector &global_names, const vector &global_column_ids, - MultiFileReaderData &reader_data, const string &initial_file) { - D_ASSERT(global_types.size() == global_names.size()); - D_ASSERT(local_types.size() == local_names.size()); - // we have expected types: create a map of name -> column index - case_insensitive_map_t name_map; - for (idx_t col_idx = 0; col_idx < local_names.size(); col_idx++) { - name_map[local_names[col_idx]] = col_idx; - } - for (idx_t i = 0; i < global_column_ids.size(); i++) { - // check if this is a constant column - bool constant = false; - for (auto &entry : reader_data.constant_map) { - if 
(entry.column_id == i) { - constant = true; - break; - } - } - if (constant) { - // this column is constant for this file - continue; - } - // not constant - look up the column in the name map - auto global_id = global_column_ids[i]; - if (global_id >= global_types.size()) { - throw InternalException( - "MultiFileReader::CreatePositionalMapping - global_id is out of range in global_types for this file"); - } - auto &global_name = global_names[global_id]; - auto entry = name_map.find(global_name); - if (entry == name_map.end()) { - string candidate_names; - for (auto &local_name : local_names) { - if (!candidate_names.empty()) { - candidate_names += ", "; - } - candidate_names += local_name; - } - throw IOException( - StringUtil::Format("Failed to read file \"%s\": schema mismatch in glob: column \"%s\" was read from " - "the original file \"%s\", but could not be found in file \"%s\".\nCandidate names: " - "%s\nIf you are trying to " - "read files with different schemas, try setting union_by_name=True", - file_name, global_name, initial_file, file_name, candidate_names)); - } - // we found the column in the local file - check if the types are the same - auto local_id = entry->second; - D_ASSERT(global_id < global_types.size()); - D_ASSERT(local_id < local_types.size()); - auto &global_type = global_types[global_id]; - auto &local_type = local_types[local_id]; - if (global_type != local_type) { - reader_data.cast_map[local_id] = global_type; - } - // the types are the same - create the mapping - reader_data.column_mapping.push_back(i); - reader_data.column_ids.push_back(local_id); - } - reader_data.empty_columns = reader_data.column_ids.empty(); -} - -void MultiFileReader::CreateMapping(const string &file_name, const vector &local_types, - const vector &local_names, const vector &global_types, - const vector &global_names, const vector &global_column_ids, - optional_ptr filters, MultiFileReaderData &reader_data, - const string &initial_file) { - 
CreateNameMapping(file_name, local_types, local_names, global_types, global_names, global_column_ids, reader_data, - initial_file); - if (filters) { - reader_data.filter_map.resize(global_types.size()); - for (idx_t c = 0; c < reader_data.column_mapping.size(); c++) { - auto map_index = reader_data.column_mapping[c]; - reader_data.filter_map[map_index].index = c; - reader_data.filter_map[map_index].is_constant = false; - } - for (idx_t c = 0; c < reader_data.constant_map.size(); c++) { - auto constant_index = reader_data.constant_map[c].column_id; - reader_data.filter_map[constant_index].index = c; - reader_data.filter_map[constant_index].is_constant = true; - } - } -} - -void MultiFileReader::FinalizeChunk(const MultiFileReaderBindData &bind_data, const MultiFileReaderData &reader_data, - DataChunk &chunk) { - // reference all the constants set up in MultiFileReader::FinalizeBind - for (auto &entry : reader_data.constant_map) { - chunk.data[entry.column_id].Reference(entry.value); - } - chunk.Verify(); -} - -TableFunctionSet MultiFileReader::CreateFunctionSet(TableFunction table_function) { - TableFunctionSet function_set(table_function.name); - function_set.AddFunction(table_function); - D_ASSERT(table_function.arguments.size() == 1 && table_function.arguments[0] == LogicalType::VARCHAR); - table_function.arguments[0] = LogicalType::LIST(LogicalType::VARCHAR); - function_set.AddFunction(std::move(table_function)); - return function_set; -} - -HivePartitioningIndex::HivePartitioningIndex(string value_p, idx_t index) : value(std::move(value_p)), index(index) { -} - -void MultiFileReaderOptions::AddBatchInfo(BindInfo &bind_info) const { - bind_info.InsertOption("filename", Value::BOOLEAN(filename)); - bind_info.InsertOption("hive_partitioning", Value::BOOLEAN(hive_partitioning)); - bind_info.InsertOption("auto_detect_hive_partitioning", Value::BOOLEAN(auto_detect_hive_partitioning)); - bind_info.InsertOption("union_by_name", Value::BOOLEAN(union_by_name)); - 
bind_info.InsertOption("hive_types_autocast", Value::BOOLEAN(hive_types_autocast)); -} - -void UnionByName::CombineUnionTypes(const vector &col_names, const vector &sql_types, - vector &union_col_types, vector &union_col_names, - case_insensitive_map_t &union_names_map) { - D_ASSERT(col_names.size() == sql_types.size()); - - for (idx_t col = 0; col < col_names.size(); ++col) { - auto union_find = union_names_map.find(col_names[col]); - - if (union_find != union_names_map.end()) { - // given same name , union_col's type must compatible with col's type - auto ¤t_type = union_col_types[union_find->second]; - LogicalType compatible_type; - compatible_type = LogicalType::MaxLogicalType(current_type, sql_types[col]); - union_col_types[union_find->second] = compatible_type; - } else { - union_names_map[col_names[col]] = union_col_names.size(); - union_col_names.emplace_back(col_names[col]); - union_col_types.emplace_back(sql_types[col]); - } - } -} - -bool MultiFileReaderOptions::AutoDetectHivePartitioningInternal(const vector &files, ClientContext &context) { - std::unordered_set partitions; - auto &fs = FileSystem::GetFileSystem(context); - - auto splits_first_file = StringUtil::Split(files.front(), fs.PathSeparator(files.front())); - if (splits_first_file.size() < 2) { - return false; - } - for (auto it = splits_first_file.begin(); it != splits_first_file.end(); it++) { - auto partition = StringUtil::Split(*it, "="); - if (partition.size() == 2) { - partitions.insert(partition.front()); - } - } - if (partitions.empty()) { - return false; - } - for (auto &file : files) { - auto splits = StringUtil::Split(file, fs.PathSeparator(file)); - if (splits.size() != splits_first_file.size()) { - return false; - } - for (auto it = splits.begin(); it != std::prev(splits.end()); it++) { - auto part = StringUtil::Split(*it, "="); - if (part.size() != 2) { - continue; - } - if (partitions.find(part.front()) == partitions.end()) { - return false; - } - } - } - return true; -} -void 
MultiFileReaderOptions::AutoDetectHiveTypesInternal(const string &file, ClientContext &context) { - auto &fs = FileSystem::GetFileSystem(context); - - std::map partitions; - auto splits = StringUtil::Split(file, fs.PathSeparator(file)); - if (splits.size() < 2) { - return; - } - for (auto it = splits.begin(); it != std::prev(splits.end()); it++) { - auto part = StringUtil::Split(*it, "="); - if (part.size() == 2) { - partitions[part.front()] = part.back(); - } - } - if (partitions.empty()) { - return; - } - - const LogicalType candidates[] = {LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::BIGINT}; - for (auto &part : partitions) { - const string &name = part.first; - if (hive_types_schema.find(name) != hive_types_schema.end()) { - continue; - } - Value value(part.second); - for (auto &candidate : candidates) { - const bool success = value.TryCastAs(context, candidate); - if (success) { - hive_types_schema[name] = candidate; - break; - } - } - } -} -void MultiFileReaderOptions::AutoDetectHivePartitioning(const vector &files, ClientContext &context) { - D_ASSERT(!files.empty()); - const bool hp_explicitly_disabled = !auto_detect_hive_partitioning && !hive_partitioning; - const bool ht_enabled = !hive_types_schema.empty(); - if (hp_explicitly_disabled && ht_enabled) { - throw InvalidInputException("cannot disable hive_partitioning when hive_types is enabled"); - } - if (ht_enabled && auto_detect_hive_partitioning && !hive_partitioning) { - // hive_types flag implies hive_partitioning - hive_partitioning = true; - auto_detect_hive_partitioning = false; - } - if (auto_detect_hive_partitioning) { - hive_partitioning = AutoDetectHivePartitioningInternal(files, context); - } - if (hive_partitioning && hive_types_autocast) { - AutoDetectHiveTypesInternal(files.front(), context); - } -} -void MultiFileReaderOptions::VerifyHiveTypesArePartitions(const std::map &partitions) const { - for (auto &hive_type : hive_types_schema) { - if (partitions.find(hive_type.first) == 
partitions.end()) { - throw InvalidInputException("Unknown hive_type: \"%s\" does not appear to be a partition", hive_type.first); - } - } -} -LogicalType MultiFileReaderOptions::GetHiveLogicalType(const string &hive_partition_column) const { - if (!hive_types_schema.empty()) { - auto it = hive_types_schema.find(hive_partition_column); - if (it != hive_types_schema.end()) { - return it->second; - } - } - return LogicalType::VARCHAR; -} -Value MultiFileReaderOptions::GetHivePartitionValue(const string &base, const string &entry, - ClientContext &context) const { - Value value(base); - auto it = hive_types_schema.find(entry); - if (it == hive_types_schema.end()) { - return value; - } - - // Handle nulls - if (base.empty() || StringUtil::CIEquals(base, "NULL")) { - return Value(it->second); - } - - if (!value.TryCastAs(context, it->second)) { - throw InvalidInputException("Unable to cast '%s' (from hive partition column '%s') to: '%s'", value.ToString(), - StringUtil::Upper(it->first), it->second.ToString()); - } - return value; -} - -} // namespace duckdb - -#endif diff --git a/lib/duckdb-10.cpp b/lib/duckdb-10.cpp deleted file mode 100644 index de5fec03..00000000 --- a/lib/duckdb-10.cpp +++ /dev/null @@ -1,17321 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - - - - - - -namespace duckdb { - -Index::Index(AttachedDatabase &db, IndexType type, TableIOManager &table_io_manager, - const vector &column_ids_p, const vector> &unbound_expressions, - IndexConstraintType constraint_type_p) - - : type(type), table_io_manager(table_io_manager), column_ids(column_ids_p), constraint_type(constraint_type_p), - db(db) { - - for (auto &expr : unbound_expressions) { - types.push_back(expr->return_type.InternalType()); - logical_types.push_back(expr->return_type); - auto 
unbound_expression = expr->Copy(); - bound_expressions.push_back(BindExpression(unbound_expression->Copy())); - this->unbound_expressions.emplace_back(std::move(unbound_expression)); - } - for (auto &bound_expr : bound_expressions) { - executor.AddExpression(*bound_expr); - } - - // create the column id set - column_id_set.insert(column_ids.begin(), column_ids.end()); -} - -void Index::InitializeLock(IndexLock &state) { - state.index_lock = unique_lock(lock); -} - -PreservedError Index::Append(DataChunk &entries, Vector &row_identifiers) { - IndexLock state; - InitializeLock(state); - return Append(state, entries, row_identifiers); -} - -void Index::CommitDrop() { - IndexLock index_lock; - InitializeLock(index_lock); - CommitDrop(index_lock); -} - -void Index::Delete(DataChunk &entries, Vector &row_identifiers) { - IndexLock state; - InitializeLock(state); - Delete(state, entries, row_identifiers); -} - -bool Index::MergeIndexes(Index &other_index) { - IndexLock state; - InitializeLock(state); - return MergeIndexes(state, other_index); -} - -string Index::VerifyAndToString(const bool only_verify) { - IndexLock state; - InitializeLock(state); - return VerifyAndToString(state, only_verify); -} - -void Index::Vacuum() { - IndexLock state; - InitializeLock(state); - Vacuum(state); -} - -void Index::ExecuteExpressions(DataChunk &input, DataChunk &result) { - executor.Execute(input, result); -} - -unique_ptr Index::BindExpression(unique_ptr expr) { - if (expr->type == ExpressionType::BOUND_COLUMN_REF) { - auto &bound_colref = expr->Cast(); - return make_uniq(expr->return_type, column_ids[bound_colref.binding.column_index]); - } - ExpressionIterator::EnumerateChildren( - *expr, [this](unique_ptr &expr) { expr = BindExpression(std::move(expr)); }); - return expr; -} - -bool Index::IndexIsUpdated(const vector &column_ids) const { - for (auto &column : column_ids) { - if (column_id_set.find(column.index) != column_id_set.end()) { - return true; - } - } - return false; -} - 
-BlockPointer Index::Serialize(MetadataWriter &writer) { - throw NotImplementedException("The implementation of this index serialization does not exist."); -} - -string Index::AppendRowError(DataChunk &input, idx_t index) { - string error; - for (idx_t c = 0; c < input.ColumnCount(); c++) { - if (c > 0) { - error += ", "; - } - error += input.GetValue(c, index).ToString(); - } - return error; -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -LocalTableStorage::LocalTableStorage(DataTable &table) - : table_ref(table), allocator(Allocator::Get(table.db)), deleted_rows(0), optimistic_writer(table), - merged_storage(false) { - auto types = table.GetTypes(); - row_groups = make_shared(table.info, TableIOManager::Get(table).GetBlockManagerForRowData(), - types, MAX_ROW_ID, 0); - row_groups->InitializeEmpty(); - - table.info->indexes.Scan([&](Index &index) { - D_ASSERT(index.type == IndexType::ART); - auto &art = index.Cast(); - if (art.constraint_type != IndexConstraintType::NONE) { - // unique index: create a local ART index that maintains the same unique constraint - vector> unbound_expressions; - unbound_expressions.reserve(art.unbound_expressions.size()); - for (auto &expr : art.unbound_expressions) { - unbound_expressions.push_back(expr->Copy()); - } - indexes.AddIndex(make_uniq(art.column_ids, art.table_io_manager, std::move(unbound_expressions), - art.constraint_type, art.db)); - } - return false; - }); -} - -LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &new_dt, LocalTableStorage &parent, - idx_t changed_idx, const LogicalType &target_type, - const vector &bound_columns, Expression &cast_expr) - : table_ref(new_dt), allocator(Allocator::Get(new_dt.db)), deleted_rows(parent.deleted_rows), - optimistic_writer(new_dt, parent.optimistic_writer), optimistic_writers(std::move(parent.optimistic_writers)), - merged_storage(parent.merged_storage) { - row_groups = parent.row_groups->AlterType(context, changed_idx, 
target_type, bound_columns, cast_expr); - parent.row_groups.reset(); - indexes.Move(parent.indexes); -} - -LocalTableStorage::LocalTableStorage(DataTable &new_dt, LocalTableStorage &parent, idx_t drop_idx) - : table_ref(new_dt), allocator(Allocator::Get(new_dt.db)), deleted_rows(parent.deleted_rows), - optimistic_writer(new_dt, parent.optimistic_writer), optimistic_writers(std::move(parent.optimistic_writers)), - merged_storage(parent.merged_storage) { - row_groups = parent.row_groups->RemoveColumn(drop_idx); - parent.row_groups.reset(); - indexes.Move(parent.indexes); -} - -LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &new_dt, LocalTableStorage &parent, - ColumnDefinition &new_column, Expression &default_value) - : table_ref(new_dt), allocator(Allocator::Get(new_dt.db)), deleted_rows(parent.deleted_rows), - optimistic_writer(new_dt, parent.optimistic_writer), optimistic_writers(std::move(parent.optimistic_writers)), - merged_storage(parent.merged_storage) { - row_groups = parent.row_groups->AddColumn(context, new_column, default_value); - parent.row_groups.reset(); - indexes.Move(parent.indexes); -} - -LocalTableStorage::~LocalTableStorage() { -} - -void LocalTableStorage::InitializeScan(CollectionScanState &state, optional_ptr table_filters) { - if (row_groups->GetTotalRows() == 0) { - throw InternalException("No rows in LocalTableStorage row group for scan"); - } - row_groups->InitializeScan(state, state.GetColumnIds(), table_filters.get()); -} - -idx_t LocalTableStorage::EstimatedSize() { - idx_t appended_rows = row_groups->GetTotalRows() - deleted_rows; - idx_t row_size = 0; - auto &types = row_groups->GetTypes(); - for (auto &type : types) { - row_size += GetTypeIdSize(type.InternalType()); - } - return appended_rows * row_size; -} - -void LocalTableStorage::WriteNewRowGroup() { - if (deleted_rows != 0) { - // we have deletes - we cannot merge row groups - return; - } - optimistic_writer.WriteNewRowGroup(*row_groups); -} - -void 
LocalTableStorage::FlushBlocks() { - if (!merged_storage && row_groups->GetTotalRows() > Storage::ROW_GROUP_SIZE) { - optimistic_writer.WriteLastRowGroup(*row_groups); - } - optimistic_writer.FinalFlush(); -} - -PreservedError LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, RowGroupCollection &source, - TableIndexList &index_list, const vector &table_types, - row_t &start_row) { - // only need to scan for index append - // figure out which columns we need to scan for the set of indexes - auto columns = index_list.GetRequiredColumns(); - // create an empty mock chunk that contains all the correct types for the table - DataChunk mock_chunk; - mock_chunk.InitializeEmpty(table_types); - PreservedError error; - source.Scan(transaction, columns, [&](DataChunk &chunk) -> bool { - // construct the mock chunk by referencing the required columns - for (idx_t i = 0; i < columns.size(); i++) { - mock_chunk.data[columns[i]].Reference(chunk.data[i]); - } - mock_chunk.SetCardinality(chunk); - // append this chunk to the indexes of the table - error = DataTable::AppendToIndexes(index_list, mock_chunk, start_row); - if (error) { - return false; - } - start_row += chunk.size(); - return true; - }); - return error; -} - -void LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, TableAppendState &append_state, - idx_t append_count, bool append_to_table) { - auto &table = table_ref.get(); - if (append_to_table) { - table.InitializeAppend(transaction, append_state, append_count); - } - PreservedError error; - if (append_to_table) { - // appending: need to scan entire - row_groups->Scan(transaction, [&](DataChunk &chunk) -> bool { - // append this chunk to the indexes of the table - error = table.AppendToIndexes(chunk, append_state.current_row); - if (error) { - return false; - } - // append to base table - table.Append(chunk, append_state); - return true; - }); - } else { - error = - AppendToIndexes(transaction, *row_groups, table.info->indexes, 
table.GetTypes(), append_state.current_row); - } - if (error) { - // need to revert all appended row ids - row_t current_row = append_state.row_start; - // remove the data from the indexes, if there are any indexes - row_groups->Scan(transaction, [&](DataChunk &chunk) -> bool { - // append this chunk to the indexes of the table - try { - table.RemoveFromIndexes(append_state, chunk, current_row); - } catch (Exception &ex) { - error = PreservedError(ex); - return false; - } catch (std::exception &ex) { // LCOV_EXCL_START - error = PreservedError(ex); - return false; - } // LCOV_EXCL_STOP - - current_row += chunk.size(); - if (current_row >= append_state.current_row) { - // finished deleting all rows from the index: abort now - return false; - } - return true; - }); - if (append_to_table) { - table.RevertAppendInternal(append_state.row_start); - } - - // we need to vacuum the indexes to remove any buffers that are now empty - // due to reverting the appends - table.info->indexes.Scan([&](Index &index) { - index.Vacuum(); - return false; - }); - error.Throw(); - } -} - -OptimisticDataWriter &LocalTableStorage::CreateOptimisticWriter() { - auto writer = make_uniq(table_ref.get()); - optimistic_writers.push_back(std::move(writer)); - return *optimistic_writers.back(); -} - -void LocalTableStorage::FinalizeOptimisticWriter(OptimisticDataWriter &writer) { - // remove the writer from the set of optimistic writers - unique_ptr owned_writer; - for (idx_t i = 0; i < optimistic_writers.size(); i++) { - if (optimistic_writers[i].get() == &writer) { - owned_writer = std::move(optimistic_writers[i]); - optimistic_writers.erase(optimistic_writers.begin() + i); - break; - } - } - if (!owned_writer) { - throw InternalException("Error in FinalizeOptimisticWriter - could not find writer"); - } - optimistic_writer.Merge(*owned_writer); -} - -void LocalTableStorage::Rollback() { - for (auto &writer : optimistic_writers) { - writer->Rollback(); - } - optimistic_writers.clear(); - 
optimistic_writer.Rollback(); -} - -//===--------------------------------------------------------------------===// -// LocalTableManager -//===--------------------------------------------------------------------===// -optional_ptr LocalTableManager::GetStorage(DataTable &table) { - lock_guard l(table_storage_lock); - auto entry = table_storage.find(table); - return entry == table_storage.end() ? nullptr : entry->second.get(); -} - -LocalTableStorage &LocalTableManager::GetOrCreateStorage(DataTable &table) { - lock_guard l(table_storage_lock); - auto entry = table_storage.find(table); - if (entry == table_storage.end()) { - auto new_storage = make_shared(table); - auto storage = new_storage.get(); - table_storage.insert(make_pair(reference(table), std::move(new_storage))); - return *storage; - } else { - return *entry->second.get(); - } -} - -bool LocalTableManager::IsEmpty() { - lock_guard l(table_storage_lock); - return table_storage.empty(); -} - -shared_ptr LocalTableManager::MoveEntry(DataTable &table) { - lock_guard l(table_storage_lock); - auto entry = table_storage.find(table); - if (entry == table_storage.end()) { - return nullptr; - } - auto storage_entry = std::move(entry->second); - table_storage.erase(entry); - return storage_entry; -} - -reference_map_t> LocalTableManager::MoveEntries() { - lock_guard l(table_storage_lock); - return std::move(table_storage); -} - -idx_t LocalTableManager::EstimatedSize() { - lock_guard l(table_storage_lock); - idx_t estimated_size = 0; - for (auto &storage : table_storage) { - estimated_size += storage.second->EstimatedSize(); - } - return estimated_size; -} - -void LocalTableManager::InsertEntry(DataTable &table, shared_ptr entry) { - lock_guard l(table_storage_lock); - D_ASSERT(table_storage.find(table) == table_storage.end()); - table_storage[table] = std::move(entry); -} - -//===--------------------------------------------------------------------===// -// LocalStorage 
-//===--------------------------------------------------------------------===// -LocalStorage::LocalStorage(ClientContext &context, DuckTransaction &transaction) - : context(context), transaction(transaction) { -} - -LocalStorage::CommitState::CommitState() { -} - -LocalStorage::CommitState::~CommitState() { -} - -LocalStorage &LocalStorage::Get(DuckTransaction &transaction) { - return transaction.GetLocalStorage(); -} - -LocalStorage &LocalStorage::Get(ClientContext &context, AttachedDatabase &db) { - return DuckTransaction::Get(context, db).GetLocalStorage(); -} - -LocalStorage &LocalStorage::Get(ClientContext &context, Catalog &catalog) { - return LocalStorage::Get(context, catalog.GetAttached()); -} - -void LocalStorage::InitializeScan(DataTable &table, CollectionScanState &state, - optional_ptr table_filters) { - auto storage = table_manager.GetStorage(table); - if (storage == nullptr) { - return; - } - storage->InitializeScan(state, table_filters); -} - -void LocalStorage::Scan(CollectionScanState &state, const vector &column_ids, DataChunk &result) { - state.Scan(transaction, result); -} - -void LocalStorage::InitializeParallelScan(DataTable &table, ParallelCollectionScanState &state) { - auto storage = table_manager.GetStorage(table); - if (!storage) { - state.max_row = 0; - state.vector_index = 0; - state.current_row_group = nullptr; - } else { - storage->row_groups->InitializeParallelScan(state); - } -} - -bool LocalStorage::NextParallelScan(ClientContext &context, DataTable &table, ParallelCollectionScanState &state, - CollectionScanState &scan_state) { - auto storage = table_manager.GetStorage(table); - if (!storage) { - return false; - } - return storage->row_groups->NextParallelScan(context, state, scan_state); -} - -void LocalStorage::InitializeAppend(LocalAppendState &state, DataTable &table) { - state.storage = &table_manager.GetOrCreateStorage(table); - state.storage->row_groups->InitializeAppend(TransactionData(transaction), state.append_state, 
0); -} - -void LocalStorage::Append(LocalAppendState &state, DataChunk &chunk) { - // append to unique indices (if any) - auto storage = state.storage; - idx_t base_id = MAX_ROW_ID + storage->row_groups->GetTotalRows() + state.append_state.total_append_count; - auto error = DataTable::AppendToIndexes(storage->indexes, chunk, base_id); - if (error) { - error.Throw(); - } - - //! Append the chunk to the local storage - auto new_row_group = storage->row_groups->Append(chunk, state.append_state); - //! Check if we should pre-emptively flush blocks to disk - if (new_row_group) { - storage->WriteNewRowGroup(); - } -} - -void LocalStorage::FinalizeAppend(LocalAppendState &state) { - state.storage->row_groups->FinalizeAppend(state.append_state.transaction, state.append_state); -} - -void LocalStorage::LocalMerge(DataTable &table, RowGroupCollection &collection) { - auto &storage = table_manager.GetOrCreateStorage(table); - if (!storage.indexes.Empty()) { - // append data to indexes if required - row_t base_id = MAX_ROW_ID + storage.row_groups->GetTotalRows(); - auto error = storage.AppendToIndexes(transaction, collection, storage.indexes, table.GetTypes(), base_id); - if (error) { - error.Throw(); - } - } - storage.row_groups->MergeStorage(collection); - storage.merged_storage = true; -} - -OptimisticDataWriter &LocalStorage::CreateOptimisticWriter(DataTable &table) { - auto &storage = table_manager.GetOrCreateStorage(table); - return storage.CreateOptimisticWriter(); -} - -void LocalStorage::FinalizeOptimisticWriter(DataTable &table, OptimisticDataWriter &writer) { - auto &storage = table_manager.GetOrCreateStorage(table); - storage.FinalizeOptimisticWriter(writer); -} - -bool LocalStorage::ChangesMade() noexcept { - return !table_manager.IsEmpty(); -} - -bool LocalStorage::Find(DataTable &table) { - return table_manager.GetStorage(table) != nullptr; -} - -idx_t LocalStorage::EstimatedSize() { - return table_manager.EstimatedSize(); -} - -idx_t 
LocalStorage::Delete(DataTable &table, Vector &row_ids, idx_t count) { - auto storage = table_manager.GetStorage(table); - D_ASSERT(storage); - - // delete from unique indices (if any) - if (!storage->indexes.Empty()) { - storage->row_groups->RemoveFromIndexes(storage->indexes, row_ids, count); - } - - auto ids = FlatVector::GetData(row_ids); - idx_t delete_count = storage->row_groups->Delete(TransactionData(0, 0), table, ids, count); - storage->deleted_rows += delete_count; - return delete_count; -} - -void LocalStorage::Update(DataTable &table, Vector &row_ids, const vector &column_ids, - DataChunk &updates) { - auto storage = table_manager.GetStorage(table); - D_ASSERT(storage); - - auto ids = FlatVector::GetData(row_ids); - storage->row_groups->Update(TransactionData(0, 0), ids, column_ids, updates); -} - -void LocalStorage::Flush(DataTable &table, LocalTableStorage &storage) { - if (storage.row_groups->GetTotalRows() <= storage.deleted_rows) { - return; - } - idx_t append_count = storage.row_groups->GetTotalRows() - storage.deleted_rows; - - TableAppendState append_state; - table.AppendLock(append_state); - transaction.PushAppend(table, append_state.row_start, append_count); - if ((append_state.row_start == 0 || storage.row_groups->GetTotalRows() >= MERGE_THRESHOLD) && - storage.deleted_rows == 0) { - // table is currently empty OR we are bulk appending: move over the storage directly - // first flush any outstanding blocks - storage.FlushBlocks(); - // now append to the indexes (if there are any) - // FIXME: we should be able to merge the transaction-local index directly into the main table index - // as long we just rewrite some row-ids - if (!table.info->indexes.Empty()) { - storage.AppendToIndexes(transaction, append_state, append_count, false); - } - // finally move over the row groups - table.MergeStorage(*storage.row_groups, storage.indexes); - } else { - // check if we have written data - // if we have, we cannot merge to disk after all - // so we need 
to revert the data we have already written - storage.Rollback(); - // append to the indexes and append to the base table - storage.AppendToIndexes(transaction, append_state, append_count, true); - } - - // possibly vacuum any excess index data - table.info->indexes.Scan([&](Index &index) { - index.Vacuum(); - return false; - }); -} - -void LocalStorage::Commit(LocalStorage::CommitState &commit_state, DuckTransaction &transaction) { - // commit local storage - // iterate over all entries in the table storage map and commit them - // after this, the local storage is no longer required and can be cleared - auto table_storage = table_manager.MoveEntries(); - for (auto &entry : table_storage) { - auto table = entry.first; - auto storage = entry.second.get(); - Flush(table, *storage); - entry.second.reset(); - } -} - -void LocalStorage::Rollback() { - // rollback local storage - // after this, the local storage is no longer required and can be cleared - auto table_storage = table_manager.MoveEntries(); - for (auto &entry : table_storage) { - auto storage = entry.second.get(); - if (!storage) { - continue; - } - storage->Rollback(); - - entry.second.reset(); - } -} - -idx_t LocalStorage::AddedRows(DataTable &table) { - auto storage = table_manager.GetStorage(table); - if (!storage) { - return 0; - } - return storage->row_groups->GetTotalRows() - storage->deleted_rows; -} - -void LocalStorage::MoveStorage(DataTable &old_dt, DataTable &new_dt) { - // check if there are any pending appends for the old version of the table - auto new_storage = table_manager.MoveEntry(old_dt); - if (!new_storage) { - return; - } - // take over the storage from the old entry - new_storage->table_ref = new_dt; - table_manager.InsertEntry(new_dt, std::move(new_storage)); -} - -void LocalStorage::AddColumn(DataTable &old_dt, DataTable &new_dt, ColumnDefinition &new_column, - Expression &default_value) { - // check if there are any pending appends for the old version of the table - auto storage = 
table_manager.MoveEntry(old_dt); - if (!storage) { - return; - } - auto new_storage = make_shared(context, new_dt, *storage, new_column, default_value); - table_manager.InsertEntry(new_dt, std::move(new_storage)); -} - -void LocalStorage::DropColumn(DataTable &old_dt, DataTable &new_dt, idx_t removed_column) { - // check if there are any pending appends for the old version of the table - auto storage = table_manager.MoveEntry(old_dt); - if (!storage) { - return; - } - auto new_storage = make_shared(new_dt, *storage, removed_column); - table_manager.InsertEntry(new_dt, std::move(new_storage)); -} - -void LocalStorage::ChangeType(DataTable &old_dt, DataTable &new_dt, idx_t changed_idx, const LogicalType &target_type, - const vector &bound_columns, Expression &cast_expr) { - // check if there are any pending appends for the old version of the table - auto storage = table_manager.MoveEntry(old_dt); - if (!storage) { - return; - } - auto new_storage = - make_shared(context, new_dt, *storage, changed_idx, target_type, bound_columns, cast_expr); - table_manager.InsertEntry(new_dt, std::move(new_storage)); -} - -void LocalStorage::FetchChunk(DataTable &table, Vector &row_ids, idx_t count, const vector &col_ids, - DataChunk &chunk, ColumnFetchState &fetch_state) { - auto storage = table_manager.GetStorage(table); - if (!storage) { - throw InternalException("LocalStorage::FetchChunk - local storage not found"); - } - - storage->row_groups->Fetch(transaction, chunk, col_ids, row_ids, count, fetch_state); -} - -TableIndexList &LocalStorage::GetIndexes(DataTable &table) { - auto storage = table_manager.GetStorage(table); - if (!storage) { - throw InternalException("LocalStorage::GetIndexes - local storage not found"); - } - return storage->indexes; -} - -void LocalStorage::VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint) { - auto storage = table_manager.GetStorage(parent); - if (!storage) { - return; - } - 
storage->row_groups->VerifyNewConstraint(parent, constraint); -} - -} // namespace duckdb - - - - -namespace duckdb { - -DataFileType MagicBytes::CheckMagicBytes(FileSystem *fs_p, const string &path) { - LocalFileSystem lfs; - FileSystem &fs = fs_p ? *fs_p : lfs; - if (!fs.FileExists(path)) { - return DataFileType::FILE_DOES_NOT_EXIST; - } - auto handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ); - - constexpr const idx_t MAGIC_BYTES_READ_SIZE = 16; - char buffer[MAGIC_BYTES_READ_SIZE]; - - handle->Read(buffer, MAGIC_BYTES_READ_SIZE); - if (memcmp(buffer, "SQLite format 3\0", 16) == 0) { - return DataFileType::SQLITE_FILE; - } - if (memcmp(buffer, "PAR1", 4) == 0) { - return DataFileType::PARQUET_FILE; - } - if (memcmp(buffer + MainHeader::MAGIC_BYTE_OFFSET, MainHeader::MAGIC_BYTES, MainHeader::MAGIC_BYTE_SIZE) == 0) { - return DataFileType::DUCKDB_FILE; - } - return DataFileType::FILE_DOES_NOT_EXIST; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -MetadataManager::MetadataManager(BlockManager &block_manager, BufferManager &buffer_manager) - : block_manager(block_manager), buffer_manager(buffer_manager) { -} - -MetadataManager::~MetadataManager() { -} - -MetadataHandle MetadataManager::AllocateHandle() { - // check if there is any free space left in an existing block - // if not allocate a new block - block_id_t free_block = INVALID_BLOCK; - for (auto &kv : blocks) { - auto &block = kv.second; - D_ASSERT(kv.first == block.block_id); - if (!block.free_blocks.empty()) { - free_block = kv.first; - break; - } - } - if (free_block == INVALID_BLOCK) { - free_block = AllocateNewBlock(); - } - D_ASSERT(free_block != INVALID_BLOCK); - - // select the first free metadata block we can find - MetadataPointer pointer; - pointer.block_index = free_block; - auto &block = blocks[free_block]; - if (block.block->BlockId() < MAXIMUM_BLOCK) { - // this block is a disk-backed block, yet we are planning to write to it - // we need to convert it into a transient 
block before we can write to it - ConvertToTransient(block); - D_ASSERT(block.block->BlockId() >= MAXIMUM_BLOCK); - } - D_ASSERT(!block.free_blocks.empty()); - pointer.index = block.free_blocks.back(); - // mark the block as used - block.free_blocks.pop_back(); - D_ASSERT(pointer.index < METADATA_BLOCK_COUNT); - // pin the block - return Pin(pointer); -} - -MetadataHandle MetadataManager::Pin(MetadataPointer pointer) { - D_ASSERT(pointer.index < METADATA_BLOCK_COUNT); - auto &block = blocks[pointer.block_index]; - - MetadataHandle handle; - handle.pointer.block_index = pointer.block_index; - handle.pointer.index = pointer.index; - handle.handle = buffer_manager.Pin(block.block); - return handle; -} - -void MetadataManager::ConvertToTransient(MetadataBlock &block) { - // pin the old block - auto old_buffer = buffer_manager.Pin(block.block); - - // allocate a new transient block to replace it - shared_ptr new_block; - auto new_buffer = buffer_manager.Allocate(Storage::BLOCK_SIZE, false, &new_block); - - // copy the data to the transient block - memcpy(new_buffer.Ptr(), old_buffer.Ptr(), Storage::BLOCK_SIZE); - - block.block = std::move(new_block); - - // unregister the old block - block_manager.UnregisterBlock(block.block_id, false); -} - -block_id_t MetadataManager::AllocateNewBlock() { - auto new_block_id = GetNextBlockId(); - - MetadataBlock new_block; - auto handle = buffer_manager.Allocate(Storage::BLOCK_SIZE, false, &new_block.block); - new_block.block_id = new_block_id; - for (idx_t i = 0; i < METADATA_BLOCK_COUNT; i++) { - new_block.free_blocks.push_back(METADATA_BLOCK_COUNT - i - 1); - } - // zero-initialize the handle - memset(handle.Ptr(), 0, Storage::BLOCK_SIZE); - AddBlock(std::move(new_block)); - return new_block_id; -} - -void MetadataManager::AddBlock(MetadataBlock new_block, bool if_exists) { - if (blocks.find(new_block.block_id) != blocks.end()) { - if (if_exists) { - return; - } - throw InternalException("Block id with id %llu already exists", 
new_block.block_id); - } - blocks[new_block.block_id] = std::move(new_block); -} - -void MetadataManager::AddAndRegisterBlock(MetadataBlock block) { - if (block.block) { - throw InternalException("Calling AddAndRegisterBlock on block that already exists"); - } - block.block = block_manager.RegisterBlock(block.block_id); - AddBlock(std::move(block), true); -} - -MetaBlockPointer MetadataManager::GetDiskPointer(MetadataPointer pointer, uint32_t offset) { - idx_t block_pointer = idx_t(pointer.block_index); - block_pointer |= idx_t(pointer.index) << 56ULL; - return MetaBlockPointer(block_pointer, offset); -} - -block_id_t MetaBlockPointer::GetBlockId() const { - return block_id_t(block_pointer & ~(idx_t(0xFF) << 56ULL)); -} - -uint32_t MetaBlockPointer::GetBlockIndex() const { - return block_pointer >> 56ULL; -} - -MetadataPointer MetadataManager::FromDiskPointer(MetaBlockPointer pointer) { - auto block_id = pointer.GetBlockId(); - auto index = pointer.GetBlockIndex(); - auto entry = blocks.find(block_id); - if (entry == blocks.end()) { // LCOV_EXCL_START - throw InternalException("Failed to load metadata pointer (id %llu, idx %llu, ptr %llu)\n", block_id, index, - pointer.block_pointer); - } // LCOV_EXCL_STOP - MetadataPointer result; - result.block_index = block_id; - result.index = index; - return result; -} - -MetadataPointer MetadataManager::RegisterDiskPointer(MetaBlockPointer pointer) { - auto block_id = pointer.GetBlockId(); - MetadataBlock block; - block.block_id = block_id; - AddAndRegisterBlock(block); - return FromDiskPointer(pointer); -} - -BlockPointer MetadataManager::ToBlockPointer(MetaBlockPointer meta_pointer) { - BlockPointer result; - result.block_id = meta_pointer.GetBlockId(); - result.offset = meta_pointer.GetBlockIndex() * MetadataManager::METADATA_BLOCK_SIZE + meta_pointer.offset; - D_ASSERT(result.offset < MetadataManager::METADATA_BLOCK_SIZE * MetadataManager::METADATA_BLOCK_COUNT); - return result; -} - -MetaBlockPointer 
MetadataManager::FromBlockPointer(BlockPointer block_pointer) { - if (!block_pointer.IsValid()) { - return MetaBlockPointer(); - } - idx_t index = block_pointer.offset / MetadataManager::METADATA_BLOCK_SIZE; - auto offset = block_pointer.offset % MetadataManager::METADATA_BLOCK_SIZE; - D_ASSERT(index < MetadataManager::METADATA_BLOCK_COUNT); - D_ASSERT(offset < MetadataManager::METADATA_BLOCK_SIZE); - MetaBlockPointer result; - result.block_pointer = idx_t(block_pointer.block_id) | index << 56ULL; - result.offset = offset; - return result; -} - -idx_t MetadataManager::BlockCount() { - return blocks.size(); -} - -void MetadataManager::Flush() { - const idx_t total_metadata_size = MetadataManager::METADATA_BLOCK_SIZE * MetadataManager::METADATA_BLOCK_COUNT; - // write the blocks of the metadata manager to disk - for (auto &kv : blocks) { - auto &block = kv.second; - auto handle = buffer_manager.Pin(block.block); - // there are a few bytes left-over at the end of the block, zero-initialize them - memset(handle.Ptr() + total_metadata_size, 0, Storage::BLOCK_SIZE - total_metadata_size); - D_ASSERT(kv.first == block.block_id); - if (block.block->BlockId() >= MAXIMUM_BLOCK) { - // temporary block - convert to persistent - block.block = block_manager.ConvertToPersistent(kv.first, std::move(block.block)); - } else { - // already a persistent block - only need to write it - D_ASSERT(block.block->BlockId() == block.block_id); - block_manager.Write(handle.GetFileBuffer(), block.block_id); - } - } -} - -void MetadataManager::Write(WriteStream &sink) { - sink.Write(blocks.size()); - for (auto &kv : blocks) { - kv.second.Write(sink); - } -} - -void MetadataManager::Read(ReadStream &source) { - auto block_count = source.Read(); - for (idx_t i = 0; i < block_count; i++) { - auto block = MetadataBlock::Read(source); - auto entry = blocks.find(block.block_id); - if (entry == blocks.end()) { - // block does not exist yet - AddAndRegisterBlock(std::move(block)); - } else { - // block 
was already created - only copy over the free list - entry->second.free_blocks = std::move(block.free_blocks); - } - } -} - -void MetadataBlock::Write(WriteStream &sink) { - sink.Write(block_id); - sink.Write(FreeBlocksToInteger()); -} - -MetadataBlock MetadataBlock::Read(ReadStream &source) { - MetadataBlock result; - result.block_id = source.Read(); - auto free_list = source.Read(); - result.FreeBlocksFromInteger(free_list); - return result; -} - -idx_t MetadataBlock::FreeBlocksToInteger() { - idx_t result = 0; - for (idx_t i = 0; i < free_blocks.size(); i++) { - D_ASSERT(free_blocks[i] < idx_t(64)); - idx_t mask = idx_t(1) << idx_t(free_blocks[i]); - result |= mask; - } - return result; -} - -void MetadataBlock::FreeBlocksFromInteger(idx_t free_list) { - free_blocks.clear(); - if (free_list == 0) { - return; - } - for (idx_t i = 64; i > 0; i--) { - auto index = i - 1; - idx_t mask = idx_t(1) << index; - if (free_list & mask) { - free_blocks.push_back(index); - } - } -} - -void MetadataManager::MarkBlocksAsModified() { - // for any blocks that were modified in the last checkpoint - set them to free blocks currently - for (auto &kv : modified_blocks) { - auto block_id = kv.first; - idx_t modified_list = kv.second; - auto entry = blocks.find(block_id); - D_ASSERT(entry != blocks.end()); - auto &block = entry->second; - idx_t current_free_blocks = block.FreeBlocksToInteger(); - // merge the current set of free blocks with the modified blocks - idx_t new_free_blocks = current_free_blocks | modified_list; - if (new_free_blocks == NumericLimits::Maximum()) { - // if new free_blocks is all blocks - mark entire block as modified - blocks.erase(entry); - block_manager.MarkBlockAsModified(block_id); - } else { - // set the new set of free blocks - block.FreeBlocksFromInteger(new_free_blocks); - } - } - - modified_blocks.clear(); - for (auto &kv : blocks) { - auto &block = kv.second; - idx_t free_list = block.FreeBlocksToInteger(); - idx_t occupied_list = ~free_list; - 
modified_blocks[block.block_id] = occupied_list; - } -} - -void MetadataManager::ClearModifiedBlocks(const vector &pointers) { - for (auto &pointer : pointers) { - auto block_id = pointer.GetBlockId(); - auto block_index = pointer.GetBlockIndex(); - auto entry = modified_blocks.find(block_id); - if (entry == modified_blocks.end()) { - throw InternalException("ClearModifiedBlocks - Block id %llu not found in modified_blocks", block_id); - } - auto &modified_list = entry->second; - // verify the block has been modified - D_ASSERT(modified_list && (1ULL << block_index)); - // unset the bit - modified_list &= ~(1ULL << block_index); - } -} - -vector MetadataManager::GetMetadataInfo() const { - vector result; - for (auto &block : blocks) { - MetadataBlockInfo block_info; - block_info.block_id = block.second.block_id; - block_info.total_blocks = MetadataManager::METADATA_BLOCK_COUNT; - for (auto free_block : block.second.free_blocks) { - block_info.free_list.push_back(free_block); - } - std::sort(block_info.free_list.begin(), block_info.free_list.end()); - result.push_back(std::move(block_info)); - } - std::sort(result.begin(), result.end(), - [](const MetadataBlockInfo &a, const MetadataBlockInfo &b) { return a.block_id < b.block_id; }); - return result; -} - -block_id_t MetadataManager::GetNextBlockId() { - return block_manager.GetFreeBlockId(); -} - -} // namespace duckdb - - -namespace duckdb { - -MetadataReader::MetadataReader(MetadataManager &manager, MetaBlockPointer pointer, - optional_ptr> read_pointers_p, BlockReaderType type) - : manager(manager), type(type), next_pointer(FromDiskPointer(pointer)), has_next_block(true), - read_pointers(read_pointers_p), index(0), offset(0), next_offset(pointer.offset), capacity(0) { - if (read_pointers) { - D_ASSERT(read_pointers->empty()); - read_pointers->push_back(pointer); - } -} - -MetadataReader::MetadataReader(MetadataManager &manager, BlockPointer pointer) - : MetadataReader(manager, 
MetadataManager::FromBlockPointer(pointer)) { -} - -MetadataPointer MetadataReader::FromDiskPointer(MetaBlockPointer pointer) { - if (type == BlockReaderType::EXISTING_BLOCKS) { - return manager.FromDiskPointer(pointer); - } else { - return manager.RegisterDiskPointer(pointer); - } -} - -MetadataReader::~MetadataReader() { -} - -void MetadataReader::ReadData(data_ptr_t buffer, idx_t read_size) { - while (offset + read_size > capacity) { - // cannot read entire entry from block - // first read what we can from this block - idx_t to_read = capacity - offset; - if (to_read > 0) { - memcpy(buffer, Ptr(), to_read); - read_size -= to_read; - buffer += to_read; - offset += read_size; - } - // then move to the next block - ReadNextBlock(); - } - // we have enough left in this block to read from the buffer - memcpy(buffer, Ptr(), read_size); - offset += read_size; -} - -MetaBlockPointer MetadataReader::GetMetaBlockPointer() { - return manager.GetDiskPointer(block.pointer, offset); -} - -void MetadataReader::ReadNextBlock() { - if (!has_next_block) { - throw IOException("No more data remaining in MetadataReader"); - } - block = manager.Pin(next_pointer); - index = next_pointer.index; - - idx_t next_block = Load(BasePtr()); - if (next_block == idx_t(-1)) { - has_next_block = false; - } else { - next_pointer = FromDiskPointer(MetaBlockPointer(next_block, 0)); - MetaBlockPointer next_block_pointer(next_block, 0); - if (read_pointers) { - read_pointers->push_back(next_block_pointer); - } - } - if (next_offset < sizeof(block_id_t)) { - next_offset = sizeof(block_id_t); - } - if (next_offset > MetadataManager::METADATA_BLOCK_SIZE) { - throw InternalException("next_offset cannot be bigger than block size"); - } - offset = next_offset; - next_offset = sizeof(block_id_t); - capacity = MetadataManager::METADATA_BLOCK_SIZE; -} - -data_ptr_t MetadataReader::BasePtr() { - return block.handle.Ptr() + index * MetadataManager::METADATA_BLOCK_SIZE; -} - -data_ptr_t MetadataReader::Ptr() { - 
return BasePtr() + offset; -} - -} // namespace duckdb - - - -namespace duckdb { - -MetadataWriter::MetadataWriter(MetadataManager &manager, optional_ptr> written_pointers_p) - : manager(manager), written_pointers(written_pointers_p), capacity(0), offset(0) { - D_ASSERT(!written_pointers || written_pointers->empty()); -} - -MetadataWriter::~MetadataWriter() { - // If there's an exception during checkpoint, this can get destroyed without - // flushing the data...which is fine, because none of the unwritten data - // will be referenced. - // - // Otherwise, we should have explicitly flushed (and thereby nulled the block). - D_ASSERT(!block.handle.IsValid() || Exception::UncaughtException()); -} - -BlockPointer MetadataWriter::GetBlockPointer() { - return MetadataManager::ToBlockPointer(GetMetaBlockPointer()); -} - -MetaBlockPointer MetadataWriter::GetMetaBlockPointer() { - if (offset >= capacity) { - // at the end of the block - fetch the next block - NextBlock(); - D_ASSERT(capacity > 0); - } - return manager.GetDiskPointer(block.pointer, offset); -} - -MetadataHandle MetadataWriter::NextHandle() { - return manager.AllocateHandle(); -} - -void MetadataWriter::NextBlock() { - // now we need to get a new block id - auto new_handle = NextHandle(); - - // write the block id of the new block to the start of the current block - if (capacity > 0) { - auto disk_block = manager.GetDiskPointer(new_handle.pointer); - Store(disk_block.block_pointer, BasePtr()); - } - // now update the block id of the block - block = std::move(new_handle); - current_pointer = block.pointer; - offset = sizeof(idx_t); - capacity = MetadataManager::METADATA_BLOCK_SIZE; - Store(-1, BasePtr()); - if (written_pointers) { - written_pointers->push_back(manager.GetDiskPointer(current_pointer)); - } -} - -void MetadataWriter::WriteData(const_data_ptr_t buffer, idx_t write_size) { - while (offset + write_size > capacity) { - // we need to make a new block - // first copy what we can - D_ASSERT(offset <= 
capacity); - idx_t copy_amount = capacity - offset; - if (copy_amount > 0) { - memcpy(Ptr(), buffer, copy_amount); - buffer += copy_amount; - offset += copy_amount; - write_size -= copy_amount; - } - // move forward to the next block - NextBlock(); - } - memcpy(Ptr(), buffer, write_size); - offset += write_size; -} - -void MetadataWriter::Flush() { - if (offset < capacity) { - // clear remaining bytes of block (if any) - memset(Ptr(), 0, capacity - offset); - } - block.handle.Destroy(); -} - -data_ptr_t MetadataWriter::BasePtr() { - return block.handle.Ptr() + current_pointer.index * MetadataManager::METADATA_BLOCK_SIZE; -} - -data_ptr_t MetadataWriter::Ptr() { - return BasePtr() + offset; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -OptimisticDataWriter::OptimisticDataWriter(DataTable &table) : table(table) { -} - -OptimisticDataWriter::OptimisticDataWriter(DataTable &table, OptimisticDataWriter &parent) : table(table) { - if (parent.partial_manager) { - parent.partial_manager->ClearBlocks(); - } -} - -OptimisticDataWriter::~OptimisticDataWriter() { -} - -bool OptimisticDataWriter::PrepareWrite() { - // check if we should pre-emptively write the table to disk - if (table.info->IsTemporary() || StorageManager::Get(table.info->db).InMemory()) { - return false; - } - // we should! 
write the second-to-last row group to disk - // allocate the partial block-manager if none is allocated yet - if (!partial_manager) { - auto &block_manager = table.info->table_io_manager->GetBlockManagerForRowData(); - partial_manager = make_uniq(block_manager, CheckpointType::APPEND_TO_TABLE); - } - return true; -} - -void OptimisticDataWriter::WriteNewRowGroup(RowGroupCollection &row_groups) { - // we finished writing a complete row group - if (!PrepareWrite()) { - return; - } - // flush second-to-last row group - auto row_group = row_groups.GetRowGroup(-2); - FlushToDisk(row_group); -} - -void OptimisticDataWriter::WriteLastRowGroup(RowGroupCollection &row_groups) { - // we finished writing a complete row group - if (!PrepareWrite()) { - return; - } - // flush second-to-last row group - auto row_group = row_groups.GetRowGroup(-1); - if (!row_group) { - return; - } - FlushToDisk(row_group); -} - -void OptimisticDataWriter::FlushToDisk(RowGroup *row_group) { - if (!row_group) { - throw InternalException("FlushToDisk called without a RowGroup"); - } - //! 
The set of column compression types (if any) - vector compression_types; - D_ASSERT(compression_types.empty()); - for (auto &column : table.column_definitions) { - compression_types.push_back(column.CompressionType()); - } - row_group->WriteToDisk(*partial_manager, compression_types); -} - -void OptimisticDataWriter::Merge(OptimisticDataWriter &other) { - if (!other.partial_manager) { - return; - } - if (!partial_manager) { - partial_manager = std::move(other.partial_manager); - return; - } - partial_manager->Merge(*other.partial_manager); - other.partial_manager.reset(); -} - -void OptimisticDataWriter::FinalFlush() { - if (partial_manager) { - partial_manager->FlushPartialBlocks(); - partial_manager.reset(); - } -} - -void OptimisticDataWriter::Rollback() { - if (partial_manager) { - partial_manager->Rollback(); - partial_manager.reset(); - } -} - -} // namespace duckdb - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// PartialBlock -//===--------------------------------------------------------------------===// - -PartialBlock::PartialBlock(PartialBlockState state, BlockManager &block_manager, - const shared_ptr &block_handle) - : state(state), block_manager(block_manager), block_handle(block_handle) { -} - -void PartialBlock::AddUninitializedRegion(idx_t start, idx_t end) { - uninitialized_regions.push_back({start, end}); -} - -void PartialBlock::FlushInternal(const idx_t free_space_left) { - - // ensure that we do not leak any data - if (free_space_left > 0 || !uninitialized_regions.empty()) { - auto buffer_handle = block_manager.buffer_manager.Pin(block_handle); - - // memset any uninitialized regions - for (auto &uninitialized : uninitialized_regions) { - memset(buffer_handle.Ptr() + uninitialized.start, 0, uninitialized.end - uninitialized.start); - } - // memset any free space at the end of the block to 0 prior to writing to disk - memset(buffer_handle.Ptr() + Storage::BLOCK_SIZE - free_space_left, 
0, free_space_left); - } -} - -//===--------------------------------------------------------------------===// -// PartialBlockManager -//===--------------------------------------------------------------------===// - -PartialBlockManager::PartialBlockManager(BlockManager &block_manager, CheckpointType checkpoint_type, - uint32_t max_partial_block_size, uint32_t max_use_count) - : block_manager(block_manager), checkpoint_type(checkpoint_type), max_partial_block_size(max_partial_block_size), - max_use_count(max_use_count) { -} -PartialBlockManager::~PartialBlockManager() { -} - -PartialBlockAllocation PartialBlockManager::GetBlockAllocation(uint32_t segment_size) { - PartialBlockAllocation allocation; - allocation.block_manager = &block_manager; - allocation.allocation_size = segment_size; - - // if the block is less than 80% full, we consider it a "partial block" - // which means we will try to fit it with other blocks - // check if there is a partial block available we can write to - if (segment_size <= max_partial_block_size && GetPartialBlock(segment_size, allocation.partial_block)) { - //! there is! 
increase the reference count of this block - allocation.partial_block->state.block_use_count += 1; - allocation.state = allocation.partial_block->state; - if (checkpoint_type == CheckpointType::FULL_CHECKPOINT) { - block_manager.IncreaseBlockReferenceCount(allocation.state.block_id); - } - } else { - // full block: get a free block to write to - AllocateBlock(allocation.state, segment_size); - } - return allocation; -} - -bool PartialBlockManager::HasBlockAllocation(uint32_t segment_size) { - return segment_size <= max_partial_block_size && - partially_filled_blocks.lower_bound(segment_size) != partially_filled_blocks.end(); -} - -void PartialBlockManager::AllocateBlock(PartialBlockState &state, uint32_t segment_size) { - D_ASSERT(segment_size <= Storage::BLOCK_SIZE); - if (checkpoint_type == CheckpointType::FULL_CHECKPOINT) { - state.block_id = block_manager.GetFreeBlockId(); - } else { - state.block_id = INVALID_BLOCK; - } - state.block_size = Storage::BLOCK_SIZE; - state.offset = 0; - state.block_use_count = 1; -} - -bool PartialBlockManager::GetPartialBlock(idx_t segment_size, unique_ptr &partial_block) { - auto entry = partially_filled_blocks.lower_bound(segment_size); - if (entry == partially_filled_blocks.end()) { - return false; - } - // found a partially filled block! 
fill in the info - partial_block = std::move(entry->second); - partially_filled_blocks.erase(entry); - - D_ASSERT(partial_block->state.offset > 0); - D_ASSERT(ValueIsAligned(partial_block->state.offset)); - return true; -} - -void PartialBlockManager::RegisterPartialBlock(PartialBlockAllocation &&allocation) { - auto &state = allocation.partial_block->state; - D_ASSERT(checkpoint_type != CheckpointType::FULL_CHECKPOINT || state.block_id >= 0); - if (state.block_use_count < max_use_count) { - auto unaligned_size = allocation.allocation_size + state.offset; - auto new_size = AlignValue(unaligned_size); - if (new_size != unaligned_size) { - // register the uninitialized region so we can correctly initialize it before writing to disk - allocation.partial_block->AddUninitializedRegion(unaligned_size, new_size); - } - state.offset = new_size; - auto new_space_left = state.block_size - new_size; - // check if the block is STILL partially filled after adding the segment_size - if (new_space_left >= Storage::BLOCK_SIZE - max_partial_block_size) { - // the block is still partially filled: add it to the partially_filled_blocks list - partially_filled_blocks.insert(make_pair(new_space_left, std::move(allocation.partial_block))); - } - } - idx_t free_space = state.block_size - state.offset; - auto block_to_free = std::move(allocation.partial_block); - if (!block_to_free && partially_filled_blocks.size() > MAX_BLOCK_MAP_SIZE) { - // Free the page with the least space free. - auto itr = partially_filled_blocks.begin(); - block_to_free = std::move(itr->second); - free_space = state.block_size - itr->first; - partially_filled_blocks.erase(itr); - } - // Flush any block that we're not going to reuse. 
- if (block_to_free) { - block_to_free->Flush(free_space); - AddWrittenBlock(block_to_free->state.block_id); - } -} - -void PartialBlockManager::Merge(PartialBlockManager &other) { - if (&other == this) { - throw InternalException("Cannot merge into itself"); - } - // for each partially filled block in the other manager, check if we can merge it into an existing block in this - // manager - for (auto &e : other.partially_filled_blocks) { - if (!e.second) { - throw InternalException("Empty partially filled block found"); - } - auto used_space = Storage::BLOCK_SIZE - e.first; - if (HasBlockAllocation(used_space)) { - // we can merge this block into an existing block - merge them - // merge blocks - auto allocation = GetBlockAllocation(used_space); - allocation.partial_block->Merge(*e.second, allocation.state.offset, used_space); - - // re-register the partial block - allocation.state.offset += used_space; - RegisterPartialBlock(std::move(allocation)); - } else { - // we cannot merge this block - append it directly to the current block manager - partially_filled_blocks.insert(make_pair(e.first, std::move(e.second))); - } - } - // copy over the written blocks - for (auto &block_id : other.written_blocks) { - AddWrittenBlock(block_id); - } - other.written_blocks.clear(); - other.partially_filled_blocks.clear(); -} - -void PartialBlockManager::AddWrittenBlock(block_id_t block) { - auto entry = written_blocks.insert(block); - if (!entry.second) { - throw InternalException("Written block already exists"); - } -} - -void PartialBlockManager::ClearBlocks() { - for (auto &e : partially_filled_blocks) { - e.second->Clear(); - } - partially_filled_blocks.clear(); -} - -void PartialBlockManager::FlushPartialBlocks() { - for (auto &e : partially_filled_blocks) { - e.second->Flush(e.first); - } - partially_filled_blocks.clear(); -} - -void PartialBlockManager::Rollback() { - ClearBlocks(); - for (auto &block_id : written_blocks) { - block_manager.MarkBlockAsFree(block_id); - } -} 
- -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -void Constraint::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "type", type); -} - -unique_ptr Constraint::Deserialize(Deserializer &deserializer) { - auto type = deserializer.ReadProperty(100, "type"); - unique_ptr result; - switch (type) { - case ConstraintType::CHECK: - result = CheckConstraint::Deserialize(deserializer); - break; - case ConstraintType::FOREIGN_KEY: - result = ForeignKeyConstraint::Deserialize(deserializer); - break; - case ConstraintType::NOT_NULL: - result = NotNullConstraint::Deserialize(deserializer); - break; - case ConstraintType::UNIQUE: - result = UniqueConstraint::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of Constraint!"); - } - return result; -} - -void CheckConstraint::Serialize(Serializer &serializer) const { - Constraint::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "expression", expression); -} - -unique_ptr CheckConstraint::Deserialize(Deserializer &deserializer) { - auto expression = deserializer.ReadPropertyWithDefault>(200, "expression"); - auto result = duckdb::unique_ptr(new CheckConstraint(std::move(expression))); - return std::move(result); -} - -void ForeignKeyConstraint::Serialize(Serializer &serializer) const { - Constraint::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "pk_columns", pk_columns); - serializer.WritePropertyWithDefault>(201, "fk_columns", fk_columns); - serializer.WriteProperty(202, "fk_type", info.type); - serializer.WritePropertyWithDefault(203, "schema", info.schema); - 
serializer.WritePropertyWithDefault(204, "table", info.table); - serializer.WritePropertyWithDefault>(205, "pk_keys", info.pk_keys); - serializer.WritePropertyWithDefault>(206, "fk_keys", info.fk_keys); -} - -unique_ptr ForeignKeyConstraint::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ForeignKeyConstraint()); - deserializer.ReadPropertyWithDefault>(200, "pk_columns", result->pk_columns); - deserializer.ReadPropertyWithDefault>(201, "fk_columns", result->fk_columns); - deserializer.ReadProperty(202, "fk_type", result->info.type); - deserializer.ReadPropertyWithDefault(203, "schema", result->info.schema); - deserializer.ReadPropertyWithDefault(204, "table", result->info.table); - deserializer.ReadPropertyWithDefault>(205, "pk_keys", result->info.pk_keys); - deserializer.ReadPropertyWithDefault>(206, "fk_keys", result->info.fk_keys); - return std::move(result); -} - -void NotNullConstraint::Serialize(Serializer &serializer) const { - Constraint::Serialize(serializer); - serializer.WriteProperty(200, "index", index); -} - -unique_ptr NotNullConstraint::Deserialize(Deserializer &deserializer) { - auto index = deserializer.ReadProperty(200, "index"); - auto result = duckdb::unique_ptr(new NotNullConstraint(index)); - return std::move(result); -} - -void UniqueConstraint::Serialize(Serializer &serializer) const { - Constraint::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "is_primary_key", is_primary_key); - serializer.WriteProperty(201, "index", index); - serializer.WritePropertyWithDefault>(202, "columns", columns); -} - -unique_ptr UniqueConstraint::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new UniqueConstraint()); - deserializer.ReadPropertyWithDefault(200, "is_primary_key", result->is_primary_key); - deserializer.ReadProperty(201, "index", result->index); - deserializer.ReadPropertyWithDefault>(202, "columns", result->columns); - return std::move(result); -} - -} // namespace 
duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - - - - - - - - -namespace duckdb { - -void CreateInfo::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "type", type); - serializer.WritePropertyWithDefault(101, "catalog", catalog); - serializer.WritePropertyWithDefault(102, "schema", schema); - serializer.WritePropertyWithDefault(103, "temporary", temporary); - serializer.WritePropertyWithDefault(104, "internal", internal); - serializer.WriteProperty(105, "on_conflict", on_conflict); - serializer.WritePropertyWithDefault(106, "sql", sql); -} - -unique_ptr CreateInfo::Deserialize(Deserializer &deserializer) { - auto type = deserializer.ReadProperty(100, "type"); - auto catalog = deserializer.ReadPropertyWithDefault(101, "catalog"); - auto schema = deserializer.ReadPropertyWithDefault(102, "schema"); - auto temporary = deserializer.ReadPropertyWithDefault(103, "temporary"); - auto internal = deserializer.ReadPropertyWithDefault(104, "internal"); - auto on_conflict = deserializer.ReadProperty(105, "on_conflict"); - auto sql = deserializer.ReadPropertyWithDefault(106, "sql"); - deserializer.Set(type); - unique_ptr result; - switch (type) { - case CatalogType::INDEX_ENTRY: - result = CreateIndexInfo::Deserialize(deserializer); - break; - case CatalogType::MACRO_ENTRY: - result = CreateMacroInfo::Deserialize(deserializer); - break; - case CatalogType::SCHEMA_ENTRY: - result = CreateSchemaInfo::Deserialize(deserializer); - break; - case CatalogType::SEQUENCE_ENTRY: - result = CreateSequenceInfo::Deserialize(deserializer); - break; - case CatalogType::TABLE_ENTRY: - result = CreateTableInfo::Deserialize(deserializer); - break; - case CatalogType::TABLE_MACRO_ENTRY: - result = 
CreateMacroInfo::Deserialize(deserializer); - break; - case CatalogType::TYPE_ENTRY: - result = CreateTypeInfo::Deserialize(deserializer); - break; - case CatalogType::VIEW_ENTRY: - result = CreateViewInfo::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of CreateInfo!"); - } - deserializer.Unset(); - result->catalog = std::move(catalog); - result->schema = std::move(schema); - result->temporary = temporary; - result->internal = internal; - result->on_conflict = on_conflict; - result->sql = std::move(sql); - return result; -} - -void CreateIndexInfo::Serialize(Serializer &serializer) const { - CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", index_name); - serializer.WritePropertyWithDefault(201, "table", table); - serializer.WriteProperty(202, "index_type", index_type); - serializer.WriteProperty(203, "constraint_type", constraint_type); - serializer.WritePropertyWithDefault>>(204, "parsed_expressions", parsed_expressions); - serializer.WritePropertyWithDefault>(205, "scan_types", scan_types); - serializer.WritePropertyWithDefault>(206, "names", names); - serializer.WritePropertyWithDefault>(207, "column_ids", column_ids); - serializer.WritePropertyWithDefault>(208, "options", options); - serializer.WritePropertyWithDefault(209, "index_type_name", index_type_name); -} - -unique_ptr CreateIndexInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CreateIndexInfo()); - deserializer.ReadPropertyWithDefault(200, "name", result->index_name); - deserializer.ReadPropertyWithDefault(201, "table", result->table); - deserializer.ReadProperty(202, "index_type", result->index_type); - deserializer.ReadProperty(203, "constraint_type", result->constraint_type); - deserializer.ReadPropertyWithDefault>>(204, "parsed_expressions", result->parsed_expressions); - deserializer.ReadPropertyWithDefault>(205, "scan_types", result->scan_types); - 
deserializer.ReadPropertyWithDefault>(206, "names", result->names); - deserializer.ReadPropertyWithDefault>(207, "column_ids", result->column_ids); - deserializer.ReadPropertyWithDefault>(208, "options", result->options); - deserializer.ReadPropertyWithDefault(209, "index_type_name", result->index_type_name); - return std::move(result); -} - -void CreateMacroInfo::Serialize(Serializer &serializer) const { - CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); - serializer.WritePropertyWithDefault>(201, "function", function); -} - -unique_ptr CreateMacroInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CreateMacroInfo(deserializer.Get())); - deserializer.ReadPropertyWithDefault(200, "name", result->name); - deserializer.ReadPropertyWithDefault>(201, "function", result->function); - return std::move(result); -} - -void CreateSchemaInfo::Serialize(Serializer &serializer) const { - CreateInfo::Serialize(serializer); -} - -unique_ptr CreateSchemaInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CreateSchemaInfo()); - return std::move(result); -} - -void CreateSequenceInfo::Serialize(Serializer &serializer) const { - CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); - serializer.WritePropertyWithDefault(201, "usage_count", usage_count); - serializer.WritePropertyWithDefault(202, "increment", increment); - serializer.WritePropertyWithDefault(203, "min_value", min_value); - serializer.WritePropertyWithDefault(204, "max_value", max_value); - serializer.WritePropertyWithDefault(205, "start_value", start_value); - serializer.WritePropertyWithDefault(206, "cycle", cycle); -} - -unique_ptr CreateSequenceInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CreateSequenceInfo()); - deserializer.ReadPropertyWithDefault(200, "name", result->name); - deserializer.ReadPropertyWithDefault(201, 
"usage_count", result->usage_count); - deserializer.ReadPropertyWithDefault(202, "increment", result->increment); - deserializer.ReadPropertyWithDefault(203, "min_value", result->min_value); - deserializer.ReadPropertyWithDefault(204, "max_value", result->max_value); - deserializer.ReadPropertyWithDefault(205, "start_value", result->start_value); - deserializer.ReadPropertyWithDefault(206, "cycle", result->cycle); - return std::move(result); -} - -void CreateTableInfo::Serialize(Serializer &serializer) const { - CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "table", table); - serializer.WriteProperty(201, "columns", columns); - serializer.WritePropertyWithDefault>>(202, "constraints", constraints); - serializer.WritePropertyWithDefault>(203, "query", query); -} - -unique_ptr CreateTableInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CreateTableInfo()); - deserializer.ReadPropertyWithDefault(200, "table", result->table); - deserializer.ReadProperty(201, "columns", result->columns); - deserializer.ReadPropertyWithDefault>>(202, "constraints", result->constraints); - deserializer.ReadPropertyWithDefault>(203, "query", result->query); - return std::move(result); -} - -void CreateTypeInfo::Serialize(Serializer &serializer) const { - CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); - serializer.WriteProperty(201, "logical_type", type); -} - -unique_ptr CreateTypeInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CreateTypeInfo()); - deserializer.ReadPropertyWithDefault(200, "name", result->name); - deserializer.ReadProperty(201, "logical_type", result->type); - return std::move(result); -} - -void CreateViewInfo::Serialize(Serializer &serializer) const { - CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "view_name", view_name); - serializer.WritePropertyWithDefault>(201, "aliases", aliases); - 
serializer.WritePropertyWithDefault>(202, "types", types); - serializer.WritePropertyWithDefault>(203, "query", query); -} - -unique_ptr CreateViewInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CreateViewInfo()); - deserializer.ReadPropertyWithDefault(200, "view_name", result->view_name); - deserializer.ReadPropertyWithDefault>(201, "aliases", result->aliases); - deserializer.ReadPropertyWithDefault>(202, "types", result->types); - deserializer.ReadPropertyWithDefault>(203, "query", result->query); - return std::move(result); -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -void Expression::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "expression_class", expression_class); - serializer.WriteProperty(101, "type", type); - serializer.WritePropertyWithDefault(102, "alias", alias); -} - -unique_ptr Expression::Deserialize(Deserializer &deserializer) { - auto expression_class = deserializer.ReadProperty(100, "expression_class"); - auto type = deserializer.ReadProperty(101, "type"); - auto alias = deserializer.ReadPropertyWithDefault(102, "alias"); - deserializer.Set(type); - unique_ptr result; - switch (expression_class) { - case ExpressionClass::BOUND_AGGREGATE: - result = BoundAggregateExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_BETWEEN: - result = BoundBetweenExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_CASE: - result = BoundCaseExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_CAST: - result = BoundCastExpression::Deserialize(deserializer); - break; - case 
ExpressionClass::BOUND_COLUMN_REF: - result = BoundColumnRefExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_COMPARISON: - result = BoundComparisonExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_CONJUNCTION: - result = BoundConjunctionExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_CONSTANT: - result = BoundConstantExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_DEFAULT: - result = BoundDefaultExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_FUNCTION: - result = BoundFunctionExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_LAMBDA: - result = BoundLambdaExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_LAMBDA_REF: - result = BoundLambdaRefExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_OPERATOR: - result = BoundOperatorExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_PARAMETER: - result = BoundParameterExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_REF: - result = BoundReferenceExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_UNNEST: - result = BoundUnnestExpression::Deserialize(deserializer); - break; - case ExpressionClass::BOUND_WINDOW: - result = BoundWindowExpression::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of Expression!"); - } - deserializer.Unset(); - result->alias = std::move(alias); - return result; -} - -void BoundBetweenExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "input", input); - serializer.WritePropertyWithDefault>(201, "lower", lower); - serializer.WritePropertyWithDefault>(202, "upper", upper); - serializer.WritePropertyWithDefault(203, "lower_inclusive", 
lower_inclusive); - serializer.WritePropertyWithDefault(204, "upper_inclusive", upper_inclusive); -} - -unique_ptr BoundBetweenExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new BoundBetweenExpression()); - deserializer.ReadPropertyWithDefault>(200, "input", result->input); - deserializer.ReadPropertyWithDefault>(201, "lower", result->lower); - deserializer.ReadPropertyWithDefault>(202, "upper", result->upper); - deserializer.ReadPropertyWithDefault(203, "lower_inclusive", result->lower_inclusive); - deserializer.ReadPropertyWithDefault(204, "upper_inclusive", result->upper_inclusive); - return std::move(result); -} - -void BoundCaseExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WriteProperty(200, "return_type", return_type); - serializer.WritePropertyWithDefault>(201, "case_checks", case_checks); - serializer.WritePropertyWithDefault>(202, "else_expr", else_expr); -} - -unique_ptr BoundCaseExpression::Deserialize(Deserializer &deserializer) { - auto return_type = deserializer.ReadProperty(200, "return_type"); - auto result = duckdb::unique_ptr(new BoundCaseExpression(std::move(return_type))); - deserializer.ReadPropertyWithDefault>(201, "case_checks", result->case_checks); - deserializer.ReadPropertyWithDefault>(202, "else_expr", result->else_expr); - return std::move(result); -} - -void BoundCastExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "child", child); - serializer.WriteProperty(201, "return_type", return_type); - serializer.WritePropertyWithDefault(202, "try_cast", try_cast); -} - -unique_ptr BoundCastExpression::Deserialize(Deserializer &deserializer) { - auto child = deserializer.ReadPropertyWithDefault>(200, "child"); - auto return_type = deserializer.ReadProperty(201, "return_type"); - auto result = duckdb::unique_ptr(new BoundCastExpression(deserializer.Get(), 
std::move(child), std::move(return_type))); - deserializer.ReadPropertyWithDefault(202, "try_cast", result->try_cast); - return std::move(result); -} - -void BoundColumnRefExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WriteProperty(200, "return_type", return_type); - serializer.WriteProperty(201, "binding", binding); - serializer.WritePropertyWithDefault(202, "depth", depth); -} - -unique_ptr BoundColumnRefExpression::Deserialize(Deserializer &deserializer) { - auto return_type = deserializer.ReadProperty(200, "return_type"); - auto binding = deserializer.ReadProperty(201, "binding"); - auto depth = deserializer.ReadPropertyWithDefault(202, "depth"); - auto result = duckdb::unique_ptr(new BoundColumnRefExpression(std::move(return_type), binding, depth)); - return std::move(result); -} - -void BoundComparisonExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "left", left); - serializer.WritePropertyWithDefault>(201, "right", right); -} - -unique_ptr BoundComparisonExpression::Deserialize(Deserializer &deserializer) { - auto left = deserializer.ReadPropertyWithDefault>(200, "left"); - auto right = deserializer.ReadPropertyWithDefault>(201, "right"); - auto result = duckdb::unique_ptr(new BoundComparisonExpression(deserializer.Get(), std::move(left), std::move(right))); - return std::move(result); -} - -void BoundConjunctionExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WritePropertyWithDefault>>(200, "children", children); -} - -unique_ptr BoundConjunctionExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new BoundConjunctionExpression(deserializer.Get())); - deserializer.ReadPropertyWithDefault>>(200, "children", result->children); - return std::move(result); -} - -void BoundConstantExpression::Serialize(Serializer &serializer) const 
{ - Expression::Serialize(serializer); - serializer.WriteProperty(200, "value", value); -} - -unique_ptr BoundConstantExpression::Deserialize(Deserializer &deserializer) { - auto value = deserializer.ReadProperty(200, "value"); - auto result = duckdb::unique_ptr(new BoundConstantExpression(value)); - return std::move(result); -} - -void BoundDefaultExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WriteProperty(200, "return_type", return_type); -} - -unique_ptr BoundDefaultExpression::Deserialize(Deserializer &deserializer) { - auto return_type = deserializer.ReadProperty(200, "return_type"); - auto result = duckdb::unique_ptr(new BoundDefaultExpression(std::move(return_type))); - return std::move(result); -} - -void BoundLambdaExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WriteProperty(200, "return_type", return_type); - serializer.WritePropertyWithDefault>(201, "lambda_expr", lambda_expr); - serializer.WritePropertyWithDefault>>(202, "captures", captures); - serializer.WritePropertyWithDefault(203, "parameter_count", parameter_count); -} - -unique_ptr BoundLambdaExpression::Deserialize(Deserializer &deserializer) { - auto return_type = deserializer.ReadProperty(200, "return_type"); - auto lambda_expr = deserializer.ReadPropertyWithDefault>(201, "lambda_expr"); - auto captures = deserializer.ReadPropertyWithDefault>>(202, "captures"); - auto parameter_count = deserializer.ReadPropertyWithDefault(203, "parameter_count"); - auto result = duckdb::unique_ptr(new BoundLambdaExpression(deserializer.Get(), std::move(return_type), std::move(lambda_expr), parameter_count)); - result->captures = std::move(captures); - return std::move(result); -} - -void BoundLambdaRefExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WriteProperty(200, "return_type", return_type); - serializer.WriteProperty(201, "binding", 
binding); - serializer.WritePropertyWithDefault(202, "lambda_index", lambda_index); - serializer.WritePropertyWithDefault(203, "depth", depth); -} - -unique_ptr BoundLambdaRefExpression::Deserialize(Deserializer &deserializer) { - auto return_type = deserializer.ReadProperty(200, "return_type"); - auto binding = deserializer.ReadProperty(201, "binding"); - auto lambda_index = deserializer.ReadPropertyWithDefault(202, "lambda_index"); - auto depth = deserializer.ReadPropertyWithDefault(203, "depth"); - auto result = duckdb::unique_ptr(new BoundLambdaRefExpression(std::move(return_type), binding, lambda_index, depth)); - return std::move(result); -} - -void BoundOperatorExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WriteProperty(200, "return_type", return_type); - serializer.WritePropertyWithDefault>>(201, "children", children); -} - -unique_ptr BoundOperatorExpression::Deserialize(Deserializer &deserializer) { - auto return_type = deserializer.ReadProperty(200, "return_type"); - auto result = duckdb::unique_ptr(new BoundOperatorExpression(deserializer.Get(), std::move(return_type))); - deserializer.ReadPropertyWithDefault>>(201, "children", result->children); - return std::move(result); -} - -void BoundParameterExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "identifier", identifier); - serializer.WriteProperty(201, "return_type", return_type); - serializer.WritePropertyWithDefault>(202, "parameter_data", parameter_data); -} - -unique_ptr BoundParameterExpression::Deserialize(Deserializer &deserializer) { - auto identifier = deserializer.ReadPropertyWithDefault(200, "identifier"); - auto return_type = deserializer.ReadProperty(201, "return_type"); - auto parameter_data = deserializer.ReadPropertyWithDefault>(202, "parameter_data"); - auto result = duckdb::unique_ptr(new BoundParameterExpression(deserializer.Get(), 
std::move(identifier), std::move(return_type), std::move(parameter_data))); - return std::move(result); -} - -void BoundReferenceExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WriteProperty(200, "return_type", return_type); - serializer.WritePropertyWithDefault(201, "index", index); -} - -unique_ptr BoundReferenceExpression::Deserialize(Deserializer &deserializer) { - auto return_type = deserializer.ReadProperty(200, "return_type"); - auto index = deserializer.ReadPropertyWithDefault(201, "index"); - auto result = duckdb::unique_ptr(new BoundReferenceExpression(std::move(return_type), index)); - return std::move(result); -} - -void BoundUnnestExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WriteProperty(200, "return_type", return_type); - serializer.WritePropertyWithDefault>(201, "child", child); -} - -unique_ptr BoundUnnestExpression::Deserialize(Deserializer &deserializer) { - auto return_type = deserializer.ReadProperty(200, "return_type"); - auto result = duckdb::unique_ptr(new BoundUnnestExpression(std::move(return_type))); - deserializer.ReadPropertyWithDefault>(201, "child", result->child); - return std::move(result); -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -void LogicalOperator::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "type", type); - serializer.WritePropertyWithDefault>>(101, "children", children); -} - -unique_ptr LogicalOperator::Deserialize(Deserializer &deserializer) { - auto type = deserializer.ReadProperty(100, "type"); - auto children = 
deserializer.ReadPropertyWithDefault>>(101, "children"); - deserializer.Set(type); - unique_ptr result; - switch (type) { - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - result = LogicalAggregate::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_ALTER: - result = LogicalSimple::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_ANY_JOIN: - result = LogicalAnyJoin::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - result = LogicalComparisonJoin::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_ATTACH: - result = LogicalSimple::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_CHUNK_GET: - result = LogicalColumnDataGet::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - result = LogicalComparisonJoin::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_COPY_TO_FILE: - result = LogicalCopyToFile::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_CREATE_INDEX: - result = LogicalCreateIndex::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_CREATE_MACRO: - result = LogicalCreate::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: - result = LogicalCreate::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: - result = LogicalCreate::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_CREATE_TABLE: - result = LogicalCreateTable::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_CREATE_TYPE: - result = LogicalCreate::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_CREATE_VIEW: - result = LogicalCreate::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: - result = LogicalCrossProduct::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_CTE_REF: 
- result = LogicalCTERef::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_DELETE: - result = LogicalDelete::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_DELIM_GET: - result = LogicalDelimGet::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - result = LogicalComparisonJoin::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_DETACH: - result = LogicalSimple::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_DISTINCT: - result = LogicalDistinct::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_DROP: - result = LogicalSimple::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_DUMMY_SCAN: - result = LogicalDummyScan::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_EMPTY_RESULT: - result = LogicalEmptyResult::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_EXCEPT: - result = LogicalSetOperation::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_EXPLAIN: - result = LogicalExplain::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_EXPRESSION_GET: - result = LogicalExpressionGet::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: - result = LogicalExtensionOperator::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_FILTER: - result = LogicalFilter::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_GET: - result = LogicalGet::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_INSERT: - result = LogicalInsert::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_INTERSECT: - result = LogicalSetOperation::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_LIMIT: - result = LogicalLimit::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: 
- result = LogicalLimitPercent::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_LOAD: - result = LogicalSimple::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: - result = LogicalMaterializedCTE::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_ORDER_BY: - result = LogicalOrder::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_PIVOT: - result = LogicalPivot::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: - result = LogicalPositionalJoin::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_PROJECTION: - result = LogicalProjection::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: - result = LogicalRecursiveCTE::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_RESET: - result = LogicalReset::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_SAMPLE: - result = LogicalSample::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_SET: - result = LogicalSet::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_SHOW: - result = LogicalShow::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_TOP_N: - result = LogicalTopN::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_TRANSACTION: - result = LogicalSimple::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_UNION: - result = LogicalSetOperation::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_UNNEST: - result = LogicalUnnest::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_UPDATE: - result = LogicalUpdate::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_VACUUM: - result = LogicalSimple::Deserialize(deserializer); - break; - case LogicalOperatorType::LOGICAL_WINDOW: - result = 
LogicalWindow::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of LogicalOperator!"); - } - deserializer.Unset(); - result->children = std::move(children); - return result; -} - -void LogicalAggregate::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>>(200, "expressions", expressions); - serializer.WritePropertyWithDefault(201, "group_index", group_index); - serializer.WritePropertyWithDefault(202, "aggregate_index", aggregate_index); - serializer.WritePropertyWithDefault(203, "groupings_index", groupings_index); - serializer.WritePropertyWithDefault>>(204, "groups", groups); - serializer.WritePropertyWithDefault>(205, "grouping_sets", grouping_sets); - serializer.WritePropertyWithDefault>>(206, "grouping_functions", grouping_functions); -} - -unique_ptr LogicalAggregate::Deserialize(Deserializer &deserializer) { - auto expressions = deserializer.ReadPropertyWithDefault>>(200, "expressions"); - auto group_index = deserializer.ReadPropertyWithDefault(201, "group_index"); - auto aggregate_index = deserializer.ReadPropertyWithDefault(202, "aggregate_index"); - auto result = duckdb::unique_ptr(new LogicalAggregate(group_index, aggregate_index, std::move(expressions))); - deserializer.ReadPropertyWithDefault(203, "groupings_index", result->groupings_index); - deserializer.ReadPropertyWithDefault>>(204, "groups", result->groups); - deserializer.ReadPropertyWithDefault>(205, "grouping_sets", result->grouping_sets); - deserializer.ReadPropertyWithDefault>>(206, "grouping_functions", result->grouping_functions); - return std::move(result); -} - -void LogicalAnyJoin::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WriteProperty(200, "join_type", join_type); - serializer.WritePropertyWithDefault(201, "mark_index", mark_index); - serializer.WritePropertyWithDefault>(202, 
"left_projection_map", left_projection_map); - serializer.WritePropertyWithDefault>(203, "right_projection_map", right_projection_map); - serializer.WritePropertyWithDefault>(204, "condition", condition); -} - -unique_ptr LogicalAnyJoin::Deserialize(Deserializer &deserializer) { - auto join_type = deserializer.ReadProperty(200, "join_type"); - auto result = duckdb::unique_ptr(new LogicalAnyJoin(join_type)); - deserializer.ReadPropertyWithDefault(201, "mark_index", result->mark_index); - deserializer.ReadPropertyWithDefault>(202, "left_projection_map", result->left_projection_map); - deserializer.ReadPropertyWithDefault>(203, "right_projection_map", result->right_projection_map); - deserializer.ReadPropertyWithDefault>(204, "condition", result->condition); - return std::move(result); -} - -void LogicalCTERef::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "table_index", table_index); - serializer.WritePropertyWithDefault(201, "cte_index", cte_index); - serializer.WritePropertyWithDefault>(202, "chunk_types", chunk_types); - serializer.WritePropertyWithDefault>(203, "bound_columns", bound_columns); - serializer.WriteProperty(204, "materialized_cte", materialized_cte); -} - -unique_ptr LogicalCTERef::Deserialize(Deserializer &deserializer) { - auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); - auto cte_index = deserializer.ReadPropertyWithDefault(201, "cte_index"); - auto chunk_types = deserializer.ReadPropertyWithDefault>(202, "chunk_types"); - auto bound_columns = deserializer.ReadPropertyWithDefault>(203, "bound_columns"); - auto materialized_cte = deserializer.ReadProperty(204, "materialized_cte"); - auto result = duckdb::unique_ptr(new LogicalCTERef(table_index, cte_index, std::move(chunk_types), std::move(bound_columns), materialized_cte)); - return std::move(result); -} - -void LogicalColumnDataGet::Serialize(Serializer &serializer) const { - 
LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "table_index", table_index); - serializer.WritePropertyWithDefault>(201, "chunk_types", chunk_types); - serializer.WritePropertyWithDefault>(202, "collection", collection); -} - -unique_ptr LogicalColumnDataGet::Deserialize(Deserializer &deserializer) { - auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); - auto chunk_types = deserializer.ReadPropertyWithDefault>(201, "chunk_types"); - auto collection = deserializer.ReadPropertyWithDefault>(202, "collection"); - auto result = duckdb::unique_ptr(new LogicalColumnDataGet(table_index, std::move(chunk_types), std::move(collection))); - return std::move(result); -} - -void LogicalComparisonJoin::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WriteProperty(200, "join_type", join_type); - serializer.WritePropertyWithDefault(201, "mark_index", mark_index); - serializer.WritePropertyWithDefault>(202, "left_projection_map", left_projection_map); - serializer.WritePropertyWithDefault>(203, "right_projection_map", right_projection_map); - serializer.WritePropertyWithDefault>(204, "conditions", conditions); - serializer.WritePropertyWithDefault>(205, "mark_types", mark_types); - serializer.WritePropertyWithDefault>>(206, "duplicate_eliminated_columns", duplicate_eliminated_columns); -} - -unique_ptr LogicalComparisonJoin::Deserialize(Deserializer &deserializer) { - auto join_type = deserializer.ReadProperty(200, "join_type"); - auto result = duckdb::unique_ptr(new LogicalComparisonJoin(join_type, deserializer.Get())); - deserializer.ReadPropertyWithDefault(201, "mark_index", result->mark_index); - deserializer.ReadPropertyWithDefault>(202, "left_projection_map", result->left_projection_map); - deserializer.ReadPropertyWithDefault>(203, "right_projection_map", result->right_projection_map); - deserializer.ReadPropertyWithDefault>(204, "conditions", result->conditions); - 
deserializer.ReadPropertyWithDefault>(205, "mark_types", result->mark_types); - deserializer.ReadPropertyWithDefault>>(206, "duplicate_eliminated_columns", result->duplicate_eliminated_columns); - return std::move(result); -} - -void LogicalCreate::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "info", info); -} - -unique_ptr LogicalCreate::Deserialize(Deserializer &deserializer) { - auto info = deserializer.ReadPropertyWithDefault>(200, "info"); - auto result = duckdb::unique_ptr(new LogicalCreate(deserializer.Get(), deserializer.Get(), std::move(info))); - return std::move(result); -} - -void LogicalCreateIndex::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "info", info); - serializer.WritePropertyWithDefault>>(201, "unbound_expressions", unbound_expressions); -} - -unique_ptr LogicalCreateIndex::Deserialize(Deserializer &deserializer) { - auto info = deserializer.ReadPropertyWithDefault>(200, "info"); - auto unbound_expressions = deserializer.ReadPropertyWithDefault>>(201, "unbound_expressions"); - auto result = duckdb::unique_ptr(new LogicalCreateIndex(deserializer.Get(), std::move(info), std::move(unbound_expressions))); - return std::move(result); -} - -void LogicalCreateTable::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "info", info->base); -} - -unique_ptr LogicalCreateTable::Deserialize(Deserializer &deserializer) { - auto info = deserializer.ReadPropertyWithDefault>(200, "info"); - auto result = duckdb::unique_ptr(new LogicalCreateTable(deserializer.Get(), std::move(info))); - return std::move(result); -} - -void LogicalCrossProduct::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); -} - -unique_ptr LogicalCrossProduct::Deserialize(Deserializer &deserializer) { - auto result = 
duckdb::unique_ptr(new LogicalCrossProduct()); - return std::move(result); -} - -void LogicalDelete::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "table_info", table.GetInfo()); - serializer.WritePropertyWithDefault(201, "table_index", table_index); - serializer.WritePropertyWithDefault(202, "return_chunk", return_chunk); - serializer.WritePropertyWithDefault>>(203, "expressions", expressions); -} - -unique_ptr LogicalDelete::Deserialize(Deserializer &deserializer) { - auto table_info = deserializer.ReadPropertyWithDefault>(200, "table_info"); - auto result = duckdb::unique_ptr(new LogicalDelete(deserializer.Get(), table_info)); - deserializer.ReadPropertyWithDefault(201, "table_index", result->table_index); - deserializer.ReadPropertyWithDefault(202, "return_chunk", result->return_chunk); - deserializer.ReadPropertyWithDefault>>(203, "expressions", result->expressions); - return std::move(result); -} - -void LogicalDelimGet::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "table_index", table_index); - serializer.WritePropertyWithDefault>(201, "chunk_types", chunk_types); -} - -unique_ptr LogicalDelimGet::Deserialize(Deserializer &deserializer) { - auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); - auto chunk_types = deserializer.ReadPropertyWithDefault>(201, "chunk_types"); - auto result = duckdb::unique_ptr(new LogicalDelimGet(table_index, std::move(chunk_types))); - return std::move(result); -} - -void LogicalDistinct::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WriteProperty(200, "distinct_type", distinct_type); - serializer.WritePropertyWithDefault>>(201, "distinct_targets", distinct_targets); - serializer.WritePropertyWithDefault>(202, "order_by", order_by); -} - -unique_ptr LogicalDistinct::Deserialize(Deserializer 
&deserializer) { - auto distinct_type = deserializer.ReadProperty(200, "distinct_type"); - auto distinct_targets = deserializer.ReadPropertyWithDefault>>(201, "distinct_targets"); - auto result = duckdb::unique_ptr(new LogicalDistinct(std::move(distinct_targets), distinct_type)); - deserializer.ReadPropertyWithDefault>(202, "order_by", result->order_by); - return std::move(result); -} - -void LogicalDummyScan::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "table_index", table_index); -} - -unique_ptr LogicalDummyScan::Deserialize(Deserializer &deserializer) { - auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); - auto result = duckdb::unique_ptr(new LogicalDummyScan(table_index)); - return std::move(result); -} - -void LogicalEmptyResult::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "return_types", return_types); - serializer.WritePropertyWithDefault>(201, "bindings", bindings); -} - -unique_ptr LogicalEmptyResult::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LogicalEmptyResult()); - deserializer.ReadPropertyWithDefault>(200, "return_types", result->return_types); - deserializer.ReadPropertyWithDefault>(201, "bindings", result->bindings); - return std::move(result); -} - -void LogicalExplain::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WriteProperty(200, "explain_type", explain_type); - serializer.WritePropertyWithDefault(201, "physical_plan", physical_plan); - serializer.WritePropertyWithDefault(202, "logical_plan_unopt", logical_plan_unopt); - serializer.WritePropertyWithDefault(203, "logical_plan_opt", logical_plan_opt); -} - -unique_ptr LogicalExplain::Deserialize(Deserializer &deserializer) { - auto explain_type = deserializer.ReadProperty(200, "explain_type"); - auto result = 
duckdb::unique_ptr(new LogicalExplain(explain_type)); - deserializer.ReadPropertyWithDefault(201, "physical_plan", result->physical_plan); - deserializer.ReadPropertyWithDefault(202, "logical_plan_unopt", result->logical_plan_unopt); - deserializer.ReadPropertyWithDefault(203, "logical_plan_opt", result->logical_plan_opt); - return std::move(result); -} - -void LogicalExpressionGet::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "table_index", table_index); - serializer.WritePropertyWithDefault>(201, "expr_types", expr_types); - serializer.WritePropertyWithDefault>>>(202, "expressions", expressions); -} - -unique_ptr LogicalExpressionGet::Deserialize(Deserializer &deserializer) { - auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); - auto expr_types = deserializer.ReadPropertyWithDefault>(201, "expr_types"); - auto expressions = deserializer.ReadPropertyWithDefault>>>(202, "expressions"); - auto result = duckdb::unique_ptr(new LogicalExpressionGet(table_index, std::move(expr_types), std::move(expressions))); - return std::move(result); -} - -void LogicalFilter::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>>(200, "expressions", expressions); - serializer.WritePropertyWithDefault>(201, "projection_map", projection_map); -} - -unique_ptr LogicalFilter::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LogicalFilter()); - deserializer.ReadPropertyWithDefault>>(200, "expressions", result->expressions); - deserializer.ReadPropertyWithDefault>(201, "projection_map", result->projection_map); - return std::move(result); -} - -void LogicalInsert::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "table_info", table.GetInfo()); - serializer.WritePropertyWithDefault>>>(201, "insert_values", 
insert_values); - serializer.WriteProperty>(202, "column_index_map", column_index_map); - serializer.WritePropertyWithDefault>(203, "expected_types", expected_types); - serializer.WritePropertyWithDefault(204, "table_index", table_index); - serializer.WritePropertyWithDefault(205, "return_chunk", return_chunk); - serializer.WritePropertyWithDefault>>(206, "bound_defaults", bound_defaults); - serializer.WriteProperty(207, "action_type", action_type); - serializer.WritePropertyWithDefault>(208, "expected_set_types", expected_set_types); - serializer.WritePropertyWithDefault>(209, "on_conflict_filter", on_conflict_filter); - serializer.WritePropertyWithDefault>(210, "on_conflict_condition", on_conflict_condition); - serializer.WritePropertyWithDefault>(211, "do_update_condition", do_update_condition); - serializer.WritePropertyWithDefault>(212, "set_columns", set_columns); - serializer.WritePropertyWithDefault>(213, "set_types", set_types); - serializer.WritePropertyWithDefault(214, "excluded_table_index", excluded_table_index); - serializer.WritePropertyWithDefault>(215, "columns_to_fetch", columns_to_fetch); - serializer.WritePropertyWithDefault>(216, "source_columns", source_columns); - serializer.WritePropertyWithDefault>>(217, "expressions", expressions); -} - -unique_ptr LogicalInsert::Deserialize(Deserializer &deserializer) { - auto table_info = deserializer.ReadPropertyWithDefault>(200, "table_info"); - auto result = duckdb::unique_ptr(new LogicalInsert(deserializer.Get(), std::move(table_info))); - deserializer.ReadPropertyWithDefault>>>(201, "insert_values", result->insert_values); - deserializer.ReadProperty>(202, "column_index_map", result->column_index_map); - deserializer.ReadPropertyWithDefault>(203, "expected_types", result->expected_types); - deserializer.ReadPropertyWithDefault(204, "table_index", result->table_index); - deserializer.ReadPropertyWithDefault(205, "return_chunk", result->return_chunk); - deserializer.ReadPropertyWithDefault>>(206, 
"bound_defaults", result->bound_defaults); - deserializer.ReadProperty(207, "action_type", result->action_type); - deserializer.ReadPropertyWithDefault>(208, "expected_set_types", result->expected_set_types); - deserializer.ReadPropertyWithDefault>(209, "on_conflict_filter", result->on_conflict_filter); - deserializer.ReadPropertyWithDefault>(210, "on_conflict_condition", result->on_conflict_condition); - deserializer.ReadPropertyWithDefault>(211, "do_update_condition", result->do_update_condition); - deserializer.ReadPropertyWithDefault>(212, "set_columns", result->set_columns); - deserializer.ReadPropertyWithDefault>(213, "set_types", result->set_types); - deserializer.ReadPropertyWithDefault(214, "excluded_table_index", result->excluded_table_index); - deserializer.ReadPropertyWithDefault>(215, "columns_to_fetch", result->columns_to_fetch); - deserializer.ReadPropertyWithDefault>(216, "source_columns", result->source_columns); - deserializer.ReadPropertyWithDefault>>(217, "expressions", result->expressions); - return std::move(result); -} - -void LogicalLimit::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "limit_val", limit_val); - serializer.WritePropertyWithDefault(201, "offset_val", offset_val); - serializer.WritePropertyWithDefault>(202, "limit", limit); - serializer.WritePropertyWithDefault>(203, "offset", offset); -} - -unique_ptr LogicalLimit::Deserialize(Deserializer &deserializer) { - auto limit_val = deserializer.ReadPropertyWithDefault(200, "limit_val"); - auto offset_val = deserializer.ReadPropertyWithDefault(201, "offset_val"); - auto limit = deserializer.ReadPropertyWithDefault>(202, "limit"); - auto offset = deserializer.ReadPropertyWithDefault>(203, "offset"); - auto result = duckdb::unique_ptr(new LogicalLimit(limit_val, offset_val, std::move(limit), std::move(offset))); - return std::move(result); -} - -void LogicalLimitPercent::Serialize(Serializer &serializer) 
const { - LogicalOperator::Serialize(serializer); - serializer.WriteProperty(200, "limit_percent", limit_percent); - serializer.WritePropertyWithDefault(201, "offset_val", offset_val); - serializer.WritePropertyWithDefault>(202, "limit", limit); - serializer.WritePropertyWithDefault>(203, "offset", offset); -} - -unique_ptr LogicalLimitPercent::Deserialize(Deserializer &deserializer) { - auto limit_percent = deserializer.ReadProperty(200, "limit_percent"); - auto offset_val = deserializer.ReadPropertyWithDefault(201, "offset_val"); - auto limit = deserializer.ReadPropertyWithDefault>(202, "limit"); - auto offset = deserializer.ReadPropertyWithDefault>(203, "offset"); - auto result = duckdb::unique_ptr(new LogicalLimitPercent(limit_percent, offset_val, std::move(limit), std::move(offset))); - return std::move(result); -} - -void LogicalMaterializedCTE::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "table_index", table_index); - serializer.WritePropertyWithDefault(201, "column_count", column_count); - serializer.WritePropertyWithDefault(202, "ctename", ctename); -} - -unique_ptr LogicalMaterializedCTE::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LogicalMaterializedCTE()); - deserializer.ReadPropertyWithDefault(200, "table_index", result->table_index); - deserializer.ReadPropertyWithDefault(201, "column_count", result->column_count); - deserializer.ReadPropertyWithDefault(202, "ctename", result->ctename); - return std::move(result); -} - -void LogicalOrder::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "orders", orders); - serializer.WritePropertyWithDefault>(201, "projections", projections); -} - -unique_ptr LogicalOrder::Deserialize(Deserializer &deserializer) { - auto orders = deserializer.ReadPropertyWithDefault>(200, "orders"); - auto result = duckdb::unique_ptr(new 
LogicalOrder(std::move(orders))); - deserializer.ReadPropertyWithDefault>(201, "projections", result->projections); - return std::move(result); -} - -void LogicalPivot::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "pivot_index", pivot_index); - serializer.WriteProperty(201, "bound_pivot", bound_pivot); -} - -unique_ptr LogicalPivot::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LogicalPivot()); - deserializer.ReadPropertyWithDefault(200, "pivot_index", result->pivot_index); - deserializer.ReadProperty(201, "bound_pivot", result->bound_pivot); - return std::move(result); -} - -void LogicalPositionalJoin::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); -} - -unique_ptr LogicalPositionalJoin::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LogicalPositionalJoin()); - return std::move(result); -} - -void LogicalProjection::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "table_index", table_index); - serializer.WritePropertyWithDefault>>(201, "expressions", expressions); -} - -unique_ptr LogicalProjection::Deserialize(Deserializer &deserializer) { - auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); - auto expressions = deserializer.ReadPropertyWithDefault>>(201, "expressions"); - auto result = duckdb::unique_ptr(new LogicalProjection(table_index, std::move(expressions))); - return std::move(result); -} - -void LogicalRecursiveCTE::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "union_all", union_all); - serializer.WritePropertyWithDefault(201, "ctename", ctename); - serializer.WritePropertyWithDefault(202, "table_index", table_index); - serializer.WritePropertyWithDefault(203, "column_count", 
column_count); -} - -unique_ptr LogicalRecursiveCTE::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LogicalRecursiveCTE()); - deserializer.ReadPropertyWithDefault(200, "union_all", result->union_all); - deserializer.ReadPropertyWithDefault(201, "ctename", result->ctename); - deserializer.ReadPropertyWithDefault(202, "table_index", result->table_index); - deserializer.ReadPropertyWithDefault(203, "column_count", result->column_count); - return std::move(result); -} - -void LogicalReset::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); - serializer.WriteProperty(201, "scope", scope); -} - -unique_ptr LogicalReset::Deserialize(Deserializer &deserializer) { - auto name = deserializer.ReadPropertyWithDefault(200, "name"); - auto scope = deserializer.ReadProperty(201, "scope"); - auto result = duckdb::unique_ptr(new LogicalReset(std::move(name), scope)); - return std::move(result); -} - -void LogicalSample::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "sample_options", sample_options); -} - -unique_ptr LogicalSample::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LogicalSample()); - deserializer.ReadPropertyWithDefault>(200, "sample_options", result->sample_options); - return std::move(result); -} - -void LogicalSet::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); - serializer.WriteProperty(201, "value", value); - serializer.WriteProperty(202, "scope", scope); -} - -unique_ptr LogicalSet::Deserialize(Deserializer &deserializer) { - auto name = deserializer.ReadPropertyWithDefault(200, "name"); - auto value = deserializer.ReadProperty(201, "value"); - auto scope = deserializer.ReadProperty(202, "scope"); - auto result = 
duckdb::unique_ptr(new LogicalSet(std::move(name), value, scope)); - return std::move(result); -} - -void LogicalSetOperation::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "table_index", table_index); - serializer.WritePropertyWithDefault(201, "column_count", column_count); -} - -unique_ptr LogicalSetOperation::Deserialize(Deserializer &deserializer) { - auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); - auto column_count = deserializer.ReadPropertyWithDefault(201, "column_count"); - auto result = duckdb::unique_ptr(new LogicalSetOperation(table_index, column_count, deserializer.Get())); - return std::move(result); -} - -void LogicalShow::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "types_select", types_select); - serializer.WritePropertyWithDefault>(201, "aliases", aliases); -} - -unique_ptr LogicalShow::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LogicalShow()); - deserializer.ReadPropertyWithDefault>(200, "types_select", result->types_select); - deserializer.ReadPropertyWithDefault>(201, "aliases", result->aliases); - return std::move(result); -} - -void LogicalSimple::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "info", info); -} - -unique_ptr LogicalSimple::Deserialize(Deserializer &deserializer) { - auto info = deserializer.ReadPropertyWithDefault>(200, "info"); - auto result = duckdb::unique_ptr(new LogicalSimple(deserializer.Get(), std::move(info))); - return std::move(result); -} - -void LogicalTopN::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "orders", orders); - serializer.WritePropertyWithDefault(201, "limit", limit); - serializer.WritePropertyWithDefault(202, 
"offset", offset); -} - -unique_ptr LogicalTopN::Deserialize(Deserializer &deserializer) { - auto orders = deserializer.ReadPropertyWithDefault>(200, "orders"); - auto limit = deserializer.ReadPropertyWithDefault(201, "limit"); - auto offset = deserializer.ReadPropertyWithDefault(202, "offset"); - auto result = duckdb::unique_ptr(new LogicalTopN(std::move(orders), limit, offset)); - return std::move(result); -} - -void LogicalUnnest::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "unnest_index", unnest_index); - serializer.WritePropertyWithDefault>>(201, "expressions", expressions); -} - -unique_ptr LogicalUnnest::Deserialize(Deserializer &deserializer) { - auto unnest_index = deserializer.ReadPropertyWithDefault(200, "unnest_index"); - auto result = duckdb::unique_ptr(new LogicalUnnest(unnest_index)); - deserializer.ReadPropertyWithDefault>>(201, "expressions", result->expressions); - return std::move(result); -} - -void LogicalUpdate::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "table_info", table.GetInfo()); - serializer.WritePropertyWithDefault(201, "table_index", table_index); - serializer.WritePropertyWithDefault(202, "return_chunk", return_chunk); - serializer.WritePropertyWithDefault>>(203, "expressions", expressions); - serializer.WritePropertyWithDefault>(204, "columns", columns); - serializer.WritePropertyWithDefault>>(205, "bound_defaults", bound_defaults); - serializer.WritePropertyWithDefault(206, "update_is_del_and_insert", update_is_del_and_insert); -} - -unique_ptr LogicalUpdate::Deserialize(Deserializer &deserializer) { - auto table_info = deserializer.ReadPropertyWithDefault>(200, "table_info"); - auto result = duckdb::unique_ptr(new LogicalUpdate(deserializer.Get(), table_info)); - deserializer.ReadPropertyWithDefault(201, "table_index", result->table_index); - 
deserializer.ReadPropertyWithDefault(202, "return_chunk", result->return_chunk); - deserializer.ReadPropertyWithDefault>>(203, "expressions", result->expressions); - deserializer.ReadPropertyWithDefault>(204, "columns", result->columns); - deserializer.ReadPropertyWithDefault>>(205, "bound_defaults", result->bound_defaults); - deserializer.ReadPropertyWithDefault(206, "update_is_del_and_insert", result->update_is_del_and_insert); - return std::move(result); -} - -void LogicalWindow::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "window_index", window_index); - serializer.WritePropertyWithDefault>>(201, "expressions", expressions); -} - -unique_ptr LogicalWindow::Deserialize(Deserializer &deserializer) { - auto window_index = deserializer.ReadPropertyWithDefault(200, "window_index"); - auto result = duckdb::unique_ptr(new LogicalWindow(window_index)); - deserializer.ReadPropertyWithDefault>>(201, "expressions", result->expressions); - return std::move(result); -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -void MacroFunction::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "type", type); - serializer.WritePropertyWithDefault>>(101, "parameters", parameters); - serializer.WritePropertyWithDefault>>(102, "default_parameters", default_parameters); -} - -unique_ptr MacroFunction::Deserialize(Deserializer &deserializer) { - auto type = deserializer.ReadProperty(100, "type"); - auto parameters = deserializer.ReadPropertyWithDefault>>(101, "parameters"); - auto default_parameters = deserializer.ReadPropertyWithDefault>>(102, 
"default_parameters"); - unique_ptr result; - switch (type) { - case MacroType::SCALAR_MACRO: - result = ScalarMacroFunction::Deserialize(deserializer); - break; - case MacroType::TABLE_MACRO: - result = TableMacroFunction::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of MacroFunction!"); - } - result->parameters = std::move(parameters); - result->default_parameters = std::move(default_parameters); - return result; -} - -void ScalarMacroFunction::Serialize(Serializer &serializer) const { - MacroFunction::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "expression", expression); -} - -unique_ptr ScalarMacroFunction::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ScalarMacroFunction()); - deserializer.ReadPropertyWithDefault>(200, "expression", result->expression); - return std::move(result); -} - -void TableMacroFunction::Serialize(Serializer &serializer) const { - MacroFunction::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "query_node", query_node); -} - -unique_ptr TableMacroFunction::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new TableMacroFunction()); - deserializer.ReadPropertyWithDefault>(200, "query_node", result->query_node); - return std::move(result); -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -void BoundCaseCheck::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "when_expr", when_expr); - serializer.WritePropertyWithDefault>(101, "then_expr", then_expr); -} - -BoundCaseCheck 
BoundCaseCheck::Deserialize(Deserializer &deserializer) { - BoundCaseCheck result; - deserializer.ReadPropertyWithDefault>(100, "when_expr", result.when_expr); - deserializer.ReadPropertyWithDefault>(101, "then_expr", result.then_expr); - return result; -} - -void BoundOrderByNode::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "type", type); - serializer.WriteProperty(101, "null_order", null_order); - serializer.WritePropertyWithDefault>(102, "expression", expression); -} - -BoundOrderByNode BoundOrderByNode::Deserialize(Deserializer &deserializer) { - auto type = deserializer.ReadProperty(100, "type"); - auto null_order = deserializer.ReadProperty(101, "null_order"); - auto expression = deserializer.ReadPropertyWithDefault>(102, "expression"); - BoundOrderByNode result(type, null_order, std::move(expression)); - return result; -} - -void BoundParameterData::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "value", value); - serializer.WriteProperty(101, "return_type", return_type); -} - -shared_ptr BoundParameterData::Deserialize(Deserializer &deserializer) { - auto value = deserializer.ReadProperty(100, "value"); - auto result = duckdb::shared_ptr(new BoundParameterData(value)); - deserializer.ReadProperty(101, "return_type", result->return_type); - return result; -} - -void BoundPivotInfo::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "group_count", group_count); - serializer.WritePropertyWithDefault>(101, "types", types); - serializer.WritePropertyWithDefault>(102, "pivot_values", pivot_values); - serializer.WritePropertyWithDefault>>(103, "aggregates", aggregates); -} - -BoundPivotInfo BoundPivotInfo::Deserialize(Deserializer &deserializer) { - BoundPivotInfo result; - deserializer.ReadPropertyWithDefault(100, "group_count", result.group_count); - deserializer.ReadPropertyWithDefault>(101, "types", result.types); - deserializer.ReadPropertyWithDefault>(102, 
"pivot_values", result.pivot_values); - deserializer.ReadPropertyWithDefault>>(103, "aggregates", result.aggregates); - return result; -} - -void CSVReaderOptions::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "has_delimiter", has_delimiter); - serializer.WritePropertyWithDefault(101, "has_quote", has_quote); - serializer.WritePropertyWithDefault(102, "has_escape", has_escape); - serializer.WritePropertyWithDefault(103, "has_header", has_header); - serializer.WritePropertyWithDefault(104, "ignore_errors", ignore_errors); - serializer.WritePropertyWithDefault(105, "buffer_sample_size", buffer_sample_size); - serializer.WritePropertyWithDefault(106, "null_str", null_str); - serializer.WriteProperty(107, "compression", compression); - serializer.WritePropertyWithDefault(108, "allow_quoted_nulls", allow_quoted_nulls); - serializer.WritePropertyWithDefault(109, "skip_rows_set", skip_rows_set); - serializer.WritePropertyWithDefault(110, "maximum_line_size", maximum_line_size); - serializer.WritePropertyWithDefault(111, "normalize_names", normalize_names); - serializer.WritePropertyWithDefault>(112, "force_not_null", force_not_null); - serializer.WritePropertyWithDefault(113, "all_varchar", all_varchar); - serializer.WritePropertyWithDefault(114, "sample_size_chunks", sample_size_chunks); - serializer.WritePropertyWithDefault(115, "auto_detect", auto_detect); - serializer.WritePropertyWithDefault(116, "file_path", file_path); - serializer.WritePropertyWithDefault(117, "decimal_separator", decimal_separator); - serializer.WritePropertyWithDefault(118, "null_padding", null_padding); - serializer.WritePropertyWithDefault(119, "buffer_size", buffer_size); - serializer.WriteProperty(120, "file_options", file_options); - serializer.WritePropertyWithDefault>(121, "force_quote", force_quote); - serializer.WritePropertyWithDefault(122, "rejects_table_name", rejects_table_name); - serializer.WritePropertyWithDefault(123, "rejects_limit", 
rejects_limit); - serializer.WritePropertyWithDefault>(124, "rejects_recovery_columns", rejects_recovery_columns); - serializer.WritePropertyWithDefault>(125, "rejects_recovery_column_ids", rejects_recovery_column_ids); - serializer.WriteProperty(126, "dialect_options.state_machine_options.delimiter", dialect_options.state_machine_options.delimiter); - serializer.WriteProperty(127, "dialect_options.state_machine_options.quote", dialect_options.state_machine_options.quote); - serializer.WriteProperty(128, "dialect_options.state_machine_options.escape", dialect_options.state_machine_options.escape); - serializer.WritePropertyWithDefault(129, "dialect_options.header", dialect_options.header); - serializer.WritePropertyWithDefault(130, "dialect_options.num_cols", dialect_options.num_cols); - serializer.WriteProperty(131, "dialect_options.new_line", dialect_options.new_line); - serializer.WritePropertyWithDefault(132, "dialect_options.skip_rows", dialect_options.skip_rows); - serializer.WritePropertyWithDefault>(133, "dialect_options.date_format", dialect_options.date_format); - serializer.WritePropertyWithDefault>(134, "dialect_options.has_format", dialect_options.has_format); -} - -CSVReaderOptions CSVReaderOptions::Deserialize(Deserializer &deserializer) { - CSVReaderOptions result; - deserializer.ReadPropertyWithDefault(100, "has_delimiter", result.has_delimiter); - deserializer.ReadPropertyWithDefault(101, "has_quote", result.has_quote); - deserializer.ReadPropertyWithDefault(102, "has_escape", result.has_escape); - deserializer.ReadPropertyWithDefault(103, "has_header", result.has_header); - deserializer.ReadPropertyWithDefault(104, "ignore_errors", result.ignore_errors); - deserializer.ReadPropertyWithDefault(105, "buffer_sample_size", result.buffer_sample_size); - deserializer.ReadPropertyWithDefault(106, "null_str", result.null_str); - deserializer.ReadProperty(107, "compression", result.compression); - deserializer.ReadPropertyWithDefault(108, 
"allow_quoted_nulls", result.allow_quoted_nulls); - deserializer.ReadPropertyWithDefault(109, "skip_rows_set", result.skip_rows_set); - deserializer.ReadPropertyWithDefault(110, "maximum_line_size", result.maximum_line_size); - deserializer.ReadPropertyWithDefault(111, "normalize_names", result.normalize_names); - deserializer.ReadPropertyWithDefault>(112, "force_not_null", result.force_not_null); - deserializer.ReadPropertyWithDefault(113, "all_varchar", result.all_varchar); - deserializer.ReadPropertyWithDefault(114, "sample_size_chunks", result.sample_size_chunks); - deserializer.ReadPropertyWithDefault(115, "auto_detect", result.auto_detect); - deserializer.ReadPropertyWithDefault(116, "file_path", result.file_path); - deserializer.ReadPropertyWithDefault(117, "decimal_separator", result.decimal_separator); - deserializer.ReadPropertyWithDefault(118, "null_padding", result.null_padding); - deserializer.ReadPropertyWithDefault(119, "buffer_size", result.buffer_size); - deserializer.ReadProperty(120, "file_options", result.file_options); - deserializer.ReadPropertyWithDefault>(121, "force_quote", result.force_quote); - deserializer.ReadPropertyWithDefault(122, "rejects_table_name", result.rejects_table_name); - deserializer.ReadPropertyWithDefault(123, "rejects_limit", result.rejects_limit); - deserializer.ReadPropertyWithDefault>(124, "rejects_recovery_columns", result.rejects_recovery_columns); - deserializer.ReadPropertyWithDefault>(125, "rejects_recovery_column_ids", result.rejects_recovery_column_ids); - deserializer.ReadProperty(126, "dialect_options.state_machine_options.delimiter", result.dialect_options.state_machine_options.delimiter); - deserializer.ReadProperty(127, "dialect_options.state_machine_options.quote", result.dialect_options.state_machine_options.quote); - deserializer.ReadProperty(128, "dialect_options.state_machine_options.escape", result.dialect_options.state_machine_options.escape); - deserializer.ReadPropertyWithDefault(129, 
"dialect_options.header", result.dialect_options.header); - deserializer.ReadPropertyWithDefault(130, "dialect_options.num_cols", result.dialect_options.num_cols); - deserializer.ReadProperty(131, "dialect_options.new_line", result.dialect_options.new_line); - deserializer.ReadPropertyWithDefault(132, "dialect_options.skip_rows", result.dialect_options.skip_rows); - deserializer.ReadPropertyWithDefault>(133, "dialect_options.date_format", result.dialect_options.date_format); - deserializer.ReadPropertyWithDefault>(134, "dialect_options.has_format", result.dialect_options.has_format); - return result; -} - -void CaseCheck::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "when_expr", when_expr); - serializer.WritePropertyWithDefault>(101, "then_expr", then_expr); -} - -CaseCheck CaseCheck::Deserialize(Deserializer &deserializer) { - CaseCheck result; - deserializer.ReadPropertyWithDefault>(100, "when_expr", result.when_expr); - deserializer.ReadPropertyWithDefault>(101, "then_expr", result.then_expr); - return result; -} - -void ColumnBinding::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "table_index", table_index); - serializer.WritePropertyWithDefault(101, "column_index", column_index); -} - -ColumnBinding ColumnBinding::Deserialize(Deserializer &deserializer) { - ColumnBinding result; - deserializer.ReadPropertyWithDefault(100, "table_index", result.table_index); - deserializer.ReadPropertyWithDefault(101, "column_index", result.column_index); - return result; -} - -void ColumnDefinition::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "name", name); - serializer.WriteProperty(101, "type", type); - serializer.WritePropertyWithDefault>(102, "expression", expression); - serializer.WriteProperty(103, "category", category); - serializer.WriteProperty(104, "compression_type", compression_type); -} - -ColumnDefinition 
ColumnDefinition::Deserialize(Deserializer &deserializer) { - auto name = deserializer.ReadPropertyWithDefault(100, "name"); - auto type = deserializer.ReadProperty(101, "type"); - auto expression = deserializer.ReadPropertyWithDefault>(102, "expression"); - auto category = deserializer.ReadProperty(103, "category"); - ColumnDefinition result(std::move(name), std::move(type), std::move(expression), category); - deserializer.ReadProperty(104, "compression_type", result.compression_type); - return result; -} - -void ColumnInfo::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "names", names); - serializer.WritePropertyWithDefault>(101, "types", types); -} - -ColumnInfo ColumnInfo::Deserialize(Deserializer &deserializer) { - ColumnInfo result; - deserializer.ReadPropertyWithDefault>(100, "names", result.names); - deserializer.ReadPropertyWithDefault>(101, "types", result.types); - return result; -} - -void ColumnList::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "columns", columns); -} - -ColumnList ColumnList::Deserialize(Deserializer &deserializer) { - auto columns = deserializer.ReadPropertyWithDefault>(100, "columns"); - ColumnList result(std::move(columns)); - return result; -} - -void CommonTableExpressionInfo::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "aliases", aliases); - serializer.WritePropertyWithDefault>(101, "query", query); - serializer.WriteProperty(102, "materialized", materialized); -} - -unique_ptr CommonTableExpressionInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CommonTableExpressionInfo()); - deserializer.ReadPropertyWithDefault>(100, "aliases", result->aliases); - deserializer.ReadPropertyWithDefault>(101, "query", result->query); - deserializer.ReadProperty(102, "materialized", result->materialized); - return result; -} - -void CommonTableExpressionMap::Serialize(Serializer 
&serializer) const { - serializer.WritePropertyWithDefault>>(100, "map", map); -} - -CommonTableExpressionMap CommonTableExpressionMap::Deserialize(Deserializer &deserializer) { - CommonTableExpressionMap result; - deserializer.ReadPropertyWithDefault>>(100, "map", result.map); - return result; -} - -void HivePartitioningIndex::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "value", value); - serializer.WritePropertyWithDefault(101, "index", index); -} - -HivePartitioningIndex HivePartitioningIndex::Deserialize(Deserializer &deserializer) { - auto value = deserializer.ReadPropertyWithDefault(100, "value"); - auto index = deserializer.ReadPropertyWithDefault(101, "index"); - HivePartitioningIndex result(std::move(value), index); - return result; -} - -void JoinCondition::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "left", left); - serializer.WritePropertyWithDefault>(101, "right", right); - serializer.WriteProperty(102, "comparison", comparison); -} - -JoinCondition JoinCondition::Deserialize(Deserializer &deserializer) { - JoinCondition result; - deserializer.ReadPropertyWithDefault>(100, "left", result.left); - deserializer.ReadPropertyWithDefault>(101, "right", result.right); - deserializer.ReadProperty(102, "comparison", result.comparison); - return result; -} - -void LogicalType::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "id", id_); - serializer.WritePropertyWithDefault>(101, "type_info", type_info_); -} - -LogicalType LogicalType::Deserialize(Deserializer &deserializer) { - auto id = deserializer.ReadProperty(100, "id"); - auto type_info = deserializer.ReadPropertyWithDefault>(101, "type_info"); - LogicalType result(id, std::move(type_info)); - return result; -} - -void MultiFileReaderBindData::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "filename_idx", filename_idx); - 
serializer.WritePropertyWithDefault>(101, "hive_partitioning_indexes", hive_partitioning_indexes); -} - -MultiFileReaderBindData MultiFileReaderBindData::Deserialize(Deserializer &deserializer) { - MultiFileReaderBindData result; - deserializer.ReadPropertyWithDefault(100, "filename_idx", result.filename_idx); - deserializer.ReadPropertyWithDefault>(101, "hive_partitioning_indexes", result.hive_partitioning_indexes); - return result; -} - -void MultiFileReaderOptions::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "filename", filename); - serializer.WritePropertyWithDefault(101, "hive_partitioning", hive_partitioning); - serializer.WritePropertyWithDefault(102, "auto_detect_hive_partitioning", auto_detect_hive_partitioning); - serializer.WritePropertyWithDefault(103, "union_by_name", union_by_name); - serializer.WritePropertyWithDefault(104, "hive_types_autocast", hive_types_autocast); - serializer.WritePropertyWithDefault>(105, "hive_types_schema", hive_types_schema); -} - -MultiFileReaderOptions MultiFileReaderOptions::Deserialize(Deserializer &deserializer) { - MultiFileReaderOptions result; - deserializer.ReadPropertyWithDefault(100, "filename", result.filename); - deserializer.ReadPropertyWithDefault(101, "hive_partitioning", result.hive_partitioning); - deserializer.ReadPropertyWithDefault(102, "auto_detect_hive_partitioning", result.auto_detect_hive_partitioning); - deserializer.ReadPropertyWithDefault(103, "union_by_name", result.union_by_name); - deserializer.ReadPropertyWithDefault(104, "hive_types_autocast", result.hive_types_autocast); - deserializer.ReadPropertyWithDefault>(105, "hive_types_schema", result.hive_types_schema); - return result; -} - -void OrderByNode::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "type", type); - serializer.WriteProperty(101, "null_order", null_order); - serializer.WritePropertyWithDefault>(102, "expression", expression); -} - -OrderByNode 
OrderByNode::Deserialize(Deserializer &deserializer) { - auto type = deserializer.ReadProperty(100, "type"); - auto null_order = deserializer.ReadProperty(101, "null_order"); - auto expression = deserializer.ReadPropertyWithDefault>(102, "expression"); - OrderByNode result(type, null_order, std::move(expression)); - return result; -} - -void PivotColumn::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>>(100, "pivot_expressions", pivot_expressions); - serializer.WritePropertyWithDefault>(101, "unpivot_names", unpivot_names); - serializer.WritePropertyWithDefault>(102, "entries", entries); - serializer.WritePropertyWithDefault(103, "pivot_enum", pivot_enum); -} - -PivotColumn PivotColumn::Deserialize(Deserializer &deserializer) { - PivotColumn result; - deserializer.ReadPropertyWithDefault>>(100, "pivot_expressions", result.pivot_expressions); - deserializer.ReadPropertyWithDefault>(101, "unpivot_names", result.unpivot_names); - deserializer.ReadPropertyWithDefault>(102, "entries", result.entries); - deserializer.ReadPropertyWithDefault(103, "pivot_enum", result.pivot_enum); - return result; -} - -void PivotColumnEntry::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "values", values); - serializer.WritePropertyWithDefault>(101, "star_expr", star_expr); - serializer.WritePropertyWithDefault(102, "alias", alias); -} - -PivotColumnEntry PivotColumnEntry::Deserialize(Deserializer &deserializer) { - PivotColumnEntry result; - deserializer.ReadPropertyWithDefault>(100, "values", result.values); - deserializer.ReadPropertyWithDefault>(101, "star_expr", result.star_expr); - deserializer.ReadPropertyWithDefault(102, "alias", result.alias); - return result; -} - -void ReadCSVData::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "files", files); - serializer.WritePropertyWithDefault>(101, "csv_types", csv_types); - serializer.WritePropertyWithDefault>(102, "csv_names", 
csv_names); - serializer.WritePropertyWithDefault>(103, "return_types", return_types); - serializer.WritePropertyWithDefault>(104, "return_names", return_names); - serializer.WritePropertyWithDefault(105, "filename_col_idx", filename_col_idx); - serializer.WriteProperty(106, "options", options); - serializer.WritePropertyWithDefault(107, "single_threaded", single_threaded); - serializer.WriteProperty(108, "reader_bind", reader_bind); - serializer.WritePropertyWithDefault>(109, "column_info", column_info); -} - -unique_ptr ReadCSVData::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ReadCSVData()); - deserializer.ReadPropertyWithDefault>(100, "files", result->files); - deserializer.ReadPropertyWithDefault>(101, "csv_types", result->csv_types); - deserializer.ReadPropertyWithDefault>(102, "csv_names", result->csv_names); - deserializer.ReadPropertyWithDefault>(103, "return_types", result->return_types); - deserializer.ReadPropertyWithDefault>(104, "return_names", result->return_names); - deserializer.ReadPropertyWithDefault(105, "filename_col_idx", result->filename_col_idx); - deserializer.ReadProperty(106, "options", result->options); - deserializer.ReadPropertyWithDefault(107, "single_threaded", result->single_threaded); - deserializer.ReadProperty(108, "reader_bind", result->reader_bind); - deserializer.ReadPropertyWithDefault>(109, "column_info", result->column_info); - return result; -} - -void SampleOptions::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "sample_size", sample_size); - serializer.WritePropertyWithDefault(101, "is_percentage", is_percentage); - serializer.WriteProperty(102, "method", method); - serializer.WritePropertyWithDefault(103, "seed", seed); -} - -unique_ptr SampleOptions::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new SampleOptions()); - deserializer.ReadProperty(100, "sample_size", result->sample_size); - 
deserializer.ReadPropertyWithDefault(101, "is_percentage", result->is_percentage); - deserializer.ReadProperty(102, "method", result->method); - deserializer.ReadPropertyWithDefault(103, "seed", result->seed); - return result; -} - -void StrpTimeFormat::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "format_specifier", format_specifier); -} - -StrpTimeFormat StrpTimeFormat::Deserialize(Deserializer &deserializer) { - auto format_specifier = deserializer.ReadPropertyWithDefault(100, "format_specifier"); - StrpTimeFormat result(format_specifier); - return result; -} - -void TableFilterSet::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>>(100, "filters", filters); -} - -TableFilterSet TableFilterSet::Deserialize(Deserializer &deserializer) { - TableFilterSet result; - deserializer.ReadPropertyWithDefault>>(100, "filters", result.filters); - return result; -} - -void VacuumOptions::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "vacuum", vacuum); - serializer.WritePropertyWithDefault(101, "analyze", analyze); -} - -VacuumOptions VacuumOptions::Deserialize(Deserializer &deserializer) { - VacuumOptions result; - deserializer.ReadPropertyWithDefault(100, "vacuum", result.vacuum); - deserializer.ReadPropertyWithDefault(101, "analyze", result.analyze); - return result; -} - -void interval_t::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(1, "months", months); - serializer.WritePropertyWithDefault(2, "days", days); - serializer.WritePropertyWithDefault(3, "micros", micros); -} - -interval_t interval_t::Deserialize(Deserializer &deserializer) { - interval_t result; - deserializer.ReadPropertyWithDefault(1, "months", result.months); - deserializer.ReadPropertyWithDefault(2, "days", result.days); - deserializer.ReadPropertyWithDefault(3, "micros", result.micros); - return result; -} - -} // namespace duckdb 
-//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - - - - - - - - - - - -namespace duckdb { - -void ParseInfo::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "info_type", info_type); -} - -unique_ptr ParseInfo::Deserialize(Deserializer &deserializer) { - auto info_type = deserializer.ReadProperty(100, "info_type"); - unique_ptr result; - switch (info_type) { - case ParseInfoType::ALTER_INFO: - result = AlterInfo::Deserialize(deserializer); - break; - case ParseInfoType::ATTACH_INFO: - result = AttachInfo::Deserialize(deserializer); - break; - case ParseInfoType::COPY_INFO: - result = CopyInfo::Deserialize(deserializer); - break; - case ParseInfoType::DETACH_INFO: - result = DetachInfo::Deserialize(deserializer); - break; - case ParseInfoType::DROP_INFO: - result = DropInfo::Deserialize(deserializer); - break; - case ParseInfoType::LOAD_INFO: - result = LoadInfo::Deserialize(deserializer); - break; - case ParseInfoType::PRAGMA_INFO: - result = PragmaInfo::Deserialize(deserializer); - break; - case ParseInfoType::TRANSACTION_INFO: - result = TransactionInfo::Deserialize(deserializer); - break; - case ParseInfoType::VACUUM_INFO: - result = VacuumInfo::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of ParseInfo!"); - } - return result; -} - -void AlterInfo::Serialize(Serializer &serializer) const { - ParseInfo::Serialize(serializer); - serializer.WriteProperty(200, "type", type); - serializer.WritePropertyWithDefault(201, "catalog", catalog); - serializer.WritePropertyWithDefault(202, "schema", schema); - serializer.WritePropertyWithDefault(203, "name", name); - serializer.WriteProperty(204, "if_not_found", 
if_not_found); - serializer.WritePropertyWithDefault(205, "allow_internal", allow_internal); -} - -unique_ptr AlterInfo::Deserialize(Deserializer &deserializer) { - auto type = deserializer.ReadProperty(200, "type"); - auto catalog = deserializer.ReadPropertyWithDefault(201, "catalog"); - auto schema = deserializer.ReadPropertyWithDefault(202, "schema"); - auto name = deserializer.ReadPropertyWithDefault(203, "name"); - auto if_not_found = deserializer.ReadProperty(204, "if_not_found"); - auto allow_internal = deserializer.ReadPropertyWithDefault(205, "allow_internal"); - unique_ptr result; - switch (type) { - case AlterType::ALTER_TABLE: - result = AlterTableInfo::Deserialize(deserializer); - break; - case AlterType::ALTER_VIEW: - result = AlterViewInfo::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of AlterInfo!"); - } - result->catalog = std::move(catalog); - result->schema = std::move(schema); - result->name = std::move(name); - result->if_not_found = if_not_found; - result->allow_internal = allow_internal; - return std::move(result); -} - -void AlterTableInfo::Serialize(Serializer &serializer) const { - AlterInfo::Serialize(serializer); - serializer.WriteProperty(300, "alter_table_type", alter_table_type); -} - -unique_ptr AlterTableInfo::Deserialize(Deserializer &deserializer) { - auto alter_table_type = deserializer.ReadProperty(300, "alter_table_type"); - unique_ptr result; - switch (alter_table_type) { - case AlterTableType::ADD_COLUMN: - result = AddColumnInfo::Deserialize(deserializer); - break; - case AlterTableType::ALTER_COLUMN_TYPE: - result = ChangeColumnTypeInfo::Deserialize(deserializer); - break; - case AlterTableType::DROP_NOT_NULL: - result = DropNotNullInfo::Deserialize(deserializer); - break; - case AlterTableType::FOREIGN_KEY_CONSTRAINT: - result = AlterForeignKeyInfo::Deserialize(deserializer); - break; - case AlterTableType::REMOVE_COLUMN: - result = 
RemoveColumnInfo::Deserialize(deserializer); - break; - case AlterTableType::RENAME_COLUMN: - result = RenameColumnInfo::Deserialize(deserializer); - break; - case AlterTableType::RENAME_TABLE: - result = RenameTableInfo::Deserialize(deserializer); - break; - case AlterTableType::SET_DEFAULT: - result = SetDefaultInfo::Deserialize(deserializer); - break; - case AlterTableType::SET_NOT_NULL: - result = SetNotNullInfo::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of AlterTableInfo!"); - } - return std::move(result); -} - -void AlterViewInfo::Serialize(Serializer &serializer) const { - AlterInfo::Serialize(serializer); - serializer.WriteProperty(300, "alter_view_type", alter_view_type); -} - -unique_ptr AlterViewInfo::Deserialize(Deserializer &deserializer) { - auto alter_view_type = deserializer.ReadProperty(300, "alter_view_type"); - unique_ptr result; - switch (alter_view_type) { - case AlterViewType::RENAME_VIEW: - result = RenameViewInfo::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of AlterViewInfo!"); - } - return std::move(result); -} - -void AddColumnInfo::Serialize(Serializer &serializer) const { - AlterTableInfo::Serialize(serializer); - serializer.WriteProperty(400, "new_column", new_column); - serializer.WritePropertyWithDefault(401, "if_column_not_exists", if_column_not_exists); -} - -unique_ptr AddColumnInfo::Deserialize(Deserializer &deserializer) { - auto new_column = deserializer.ReadProperty(400, "new_column"); - auto result = duckdb::unique_ptr(new AddColumnInfo(std::move(new_column))); - deserializer.ReadPropertyWithDefault(401, "if_column_not_exists", result->if_column_not_exists); - return std::move(result); -} - -void AlterForeignKeyInfo::Serialize(Serializer &serializer) const { - AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "fk_table", fk_table); - 
serializer.WritePropertyWithDefault>(401, "pk_columns", pk_columns); - serializer.WritePropertyWithDefault>(402, "fk_columns", fk_columns); - serializer.WritePropertyWithDefault>(403, "pk_keys", pk_keys); - serializer.WritePropertyWithDefault>(404, "fk_keys", fk_keys); - serializer.WriteProperty(405, "alter_fk_type", type); -} - -unique_ptr AlterForeignKeyInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new AlterForeignKeyInfo()); - deserializer.ReadPropertyWithDefault(400, "fk_table", result->fk_table); - deserializer.ReadPropertyWithDefault>(401, "pk_columns", result->pk_columns); - deserializer.ReadPropertyWithDefault>(402, "fk_columns", result->fk_columns); - deserializer.ReadPropertyWithDefault>(403, "pk_keys", result->pk_keys); - deserializer.ReadPropertyWithDefault>(404, "fk_keys", result->fk_keys); - deserializer.ReadProperty(405, "alter_fk_type", result->type); - return std::move(result); -} - -void AttachInfo::Serialize(Serializer &serializer) const { - ParseInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); - serializer.WritePropertyWithDefault(201, "path", path); - serializer.WritePropertyWithDefault>(202, "options", options); -} - -unique_ptr AttachInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new AttachInfo()); - deserializer.ReadPropertyWithDefault(200, "name", result->name); - deserializer.ReadPropertyWithDefault(201, "path", result->path); - deserializer.ReadPropertyWithDefault>(202, "options", result->options); - return std::move(result); -} - -void ChangeColumnTypeInfo::Serialize(Serializer &serializer) const { - AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "column_name", column_name); - serializer.WriteProperty(401, "target_type", target_type); - serializer.WritePropertyWithDefault>(402, "expression", expression); -} - -unique_ptr ChangeColumnTypeInfo::Deserialize(Deserializer &deserializer) { - auto 
result = duckdb::unique_ptr(new ChangeColumnTypeInfo()); - deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); - deserializer.ReadProperty(401, "target_type", result->target_type); - deserializer.ReadPropertyWithDefault>(402, "expression", result->expression); - return std::move(result); -} - -void CopyInfo::Serialize(Serializer &serializer) const { - ParseInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "catalog", catalog); - serializer.WritePropertyWithDefault(201, "schema", schema); - serializer.WritePropertyWithDefault(202, "table", table); - serializer.WritePropertyWithDefault>(203, "select_list", select_list); - serializer.WritePropertyWithDefault(204, "is_from", is_from); - serializer.WritePropertyWithDefault(205, "format", format); - serializer.WritePropertyWithDefault(206, "file_path", file_path); - serializer.WritePropertyWithDefault>>(207, "options", options); -} - -unique_ptr CopyInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CopyInfo()); - deserializer.ReadPropertyWithDefault(200, "catalog", result->catalog); - deserializer.ReadPropertyWithDefault(201, "schema", result->schema); - deserializer.ReadPropertyWithDefault(202, "table", result->table); - deserializer.ReadPropertyWithDefault>(203, "select_list", result->select_list); - deserializer.ReadPropertyWithDefault(204, "is_from", result->is_from); - deserializer.ReadPropertyWithDefault(205, "format", result->format); - deserializer.ReadPropertyWithDefault(206, "file_path", result->file_path); - deserializer.ReadPropertyWithDefault>>(207, "options", result->options); - return std::move(result); -} - -void DetachInfo::Serialize(Serializer &serializer) const { - ParseInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); - serializer.WriteProperty(201, "if_not_found", if_not_found); -} - -unique_ptr DetachInfo::Deserialize(Deserializer &deserializer) { - auto result = 
duckdb::unique_ptr(new DetachInfo()); - deserializer.ReadPropertyWithDefault(200, "name", result->name); - deserializer.ReadProperty(201, "if_not_found", result->if_not_found); - return std::move(result); -} - -void DropInfo::Serialize(Serializer &serializer) const { - ParseInfo::Serialize(serializer); - serializer.WriteProperty(200, "type", type); - serializer.WritePropertyWithDefault(201, "catalog", catalog); - serializer.WritePropertyWithDefault(202, "schema", schema); - serializer.WritePropertyWithDefault(203, "name", name); - serializer.WriteProperty(204, "if_not_found", if_not_found); - serializer.WritePropertyWithDefault(205, "cascade", cascade); - serializer.WritePropertyWithDefault(206, "allow_drop_internal", allow_drop_internal); -} - -unique_ptr DropInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new DropInfo()); - deserializer.ReadProperty(200, "type", result->type); - deserializer.ReadPropertyWithDefault(201, "catalog", result->catalog); - deserializer.ReadPropertyWithDefault(202, "schema", result->schema); - deserializer.ReadPropertyWithDefault(203, "name", result->name); - deserializer.ReadProperty(204, "if_not_found", result->if_not_found); - deserializer.ReadPropertyWithDefault(205, "cascade", result->cascade); - deserializer.ReadPropertyWithDefault(206, "allow_drop_internal", result->allow_drop_internal); - return std::move(result); -} - -void DropNotNullInfo::Serialize(Serializer &serializer) const { - AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "column_name", column_name); -} - -unique_ptr DropNotNullInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new DropNotNullInfo()); - deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); - return std::move(result); -} - -void LoadInfo::Serialize(Serializer &serializer) const { - ParseInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "filename", filename); 
- serializer.WriteProperty(201, "load_type", load_type); - serializer.WritePropertyWithDefault(202, "repository", repository); -} - -unique_ptr LoadInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LoadInfo()); - deserializer.ReadPropertyWithDefault(200, "filename", result->filename); - deserializer.ReadProperty(201, "load_type", result->load_type); - deserializer.ReadPropertyWithDefault(202, "repository", result->repository); - return std::move(result); -} - -void PragmaInfo::Serialize(Serializer &serializer) const { - ParseInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); - serializer.WritePropertyWithDefault>(201, "parameters", parameters); - serializer.WriteProperty(202, "named_parameters", named_parameters); -} - -unique_ptr PragmaInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new PragmaInfo()); - deserializer.ReadPropertyWithDefault(200, "name", result->name); - deserializer.ReadPropertyWithDefault>(201, "parameters", result->parameters); - deserializer.ReadProperty(202, "named_parameters", result->named_parameters); - return std::move(result); -} - -void RemoveColumnInfo::Serialize(Serializer &serializer) const { - AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "removed_column", removed_column); - serializer.WritePropertyWithDefault(401, "if_column_exists", if_column_exists); - serializer.WritePropertyWithDefault(402, "cascade", cascade); -} - -unique_ptr RemoveColumnInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new RemoveColumnInfo()); - deserializer.ReadPropertyWithDefault(400, "removed_column", result->removed_column); - deserializer.ReadPropertyWithDefault(401, "if_column_exists", result->if_column_exists); - deserializer.ReadPropertyWithDefault(402, "cascade", result->cascade); - return std::move(result); -} - -void RenameColumnInfo::Serialize(Serializer &serializer) const { 
- AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "old_name", old_name); - serializer.WritePropertyWithDefault(401, "new_name", new_name); -} - -unique_ptr RenameColumnInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new RenameColumnInfo()); - deserializer.ReadPropertyWithDefault(400, "old_name", result->old_name); - deserializer.ReadPropertyWithDefault(401, "new_name", result->new_name); - return std::move(result); -} - -void RenameTableInfo::Serialize(Serializer &serializer) const { - AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "new_table_name", new_table_name); -} - -unique_ptr RenameTableInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new RenameTableInfo()); - deserializer.ReadPropertyWithDefault(400, "new_table_name", result->new_table_name); - return std::move(result); -} - -void RenameViewInfo::Serialize(Serializer &serializer) const { - AlterViewInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "new_view_name", new_view_name); -} - -unique_ptr RenameViewInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new RenameViewInfo()); - deserializer.ReadPropertyWithDefault(400, "new_view_name", result->new_view_name); - return std::move(result); -} - -void SetDefaultInfo::Serialize(Serializer &serializer) const { - AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "column_name", column_name); - serializer.WritePropertyWithDefault>(401, "expression", expression); -} - -unique_ptr SetDefaultInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new SetDefaultInfo()); - deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); - deserializer.ReadPropertyWithDefault>(401, "expression", result->expression); - return std::move(result); -} - -void SetNotNullInfo::Serialize(Serializer &serializer) const 
{ - AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "column_name", column_name); -} - -unique_ptr SetNotNullInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new SetNotNullInfo()); - deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); - return std::move(result); -} - -void TransactionInfo::Serialize(Serializer &serializer) const { - ParseInfo::Serialize(serializer); - serializer.WriteProperty(200, "type", type); -} - -unique_ptr TransactionInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new TransactionInfo()); - deserializer.ReadProperty(200, "type", result->type); - return std::move(result); -} - -void VacuumInfo::Serialize(Serializer &serializer) const { - ParseInfo::Serialize(serializer); - serializer.WriteProperty(200, "options", options); -} - -unique_ptr VacuumInfo::Deserialize(Deserializer &deserializer) { - auto options = deserializer.ReadProperty(200, "options"); - auto result = duckdb::unique_ptr(new VacuumInfo(options)); - return std::move(result); -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -void ParsedExpression::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "class", expression_class); - serializer.WriteProperty(101, "type", type); - serializer.WritePropertyWithDefault(102, "alias", alias); -} - -unique_ptr ParsedExpression::Deserialize(Deserializer &deserializer) { - auto expression_class = deserializer.ReadProperty(100, "class"); - auto type = deserializer.ReadProperty(101, "type"); - auto alias = deserializer.ReadPropertyWithDefault(102, "alias"); - deserializer.Set(type); - 
unique_ptr result; - switch (expression_class) { - case ExpressionClass::BETWEEN: - result = BetweenExpression::Deserialize(deserializer); - break; - case ExpressionClass::CASE: - result = CaseExpression::Deserialize(deserializer); - break; - case ExpressionClass::CAST: - result = CastExpression::Deserialize(deserializer); - break; - case ExpressionClass::COLLATE: - result = CollateExpression::Deserialize(deserializer); - break; - case ExpressionClass::COLUMN_REF: - result = ColumnRefExpression::Deserialize(deserializer); - break; - case ExpressionClass::COMPARISON: - result = ComparisonExpression::Deserialize(deserializer); - break; - case ExpressionClass::CONJUNCTION: - result = ConjunctionExpression::Deserialize(deserializer); - break; - case ExpressionClass::CONSTANT: - result = ConstantExpression::Deserialize(deserializer); - break; - case ExpressionClass::DEFAULT: - result = DefaultExpression::Deserialize(deserializer); - break; - case ExpressionClass::FUNCTION: - result = FunctionExpression::Deserialize(deserializer); - break; - case ExpressionClass::LAMBDA: - result = LambdaExpression::Deserialize(deserializer); - break; - case ExpressionClass::OPERATOR: - result = OperatorExpression::Deserialize(deserializer); - break; - case ExpressionClass::PARAMETER: - result = ParameterExpression::Deserialize(deserializer); - break; - case ExpressionClass::POSITIONAL_REFERENCE: - result = PositionalReferenceExpression::Deserialize(deserializer); - break; - case ExpressionClass::STAR: - result = StarExpression::Deserialize(deserializer); - break; - case ExpressionClass::SUBQUERY: - result = SubqueryExpression::Deserialize(deserializer); - break; - case ExpressionClass::WINDOW: - result = WindowExpression::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of ParsedExpression!"); - } - deserializer.Unset(); - result->alias = std::move(alias); - return result; -} - -void 
BetweenExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "input", input); - serializer.WritePropertyWithDefault>(201, "lower", lower); - serializer.WritePropertyWithDefault>(202, "upper", upper); -} - -unique_ptr BetweenExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new BetweenExpression()); - deserializer.ReadPropertyWithDefault>(200, "input", result->input); - deserializer.ReadPropertyWithDefault>(201, "lower", result->lower); - deserializer.ReadPropertyWithDefault>(202, "upper", result->upper); - return std::move(result); -} - -void CaseExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "case_checks", case_checks); - serializer.WritePropertyWithDefault>(201, "else_expr", else_expr); -} - -unique_ptr CaseExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CaseExpression()); - deserializer.ReadPropertyWithDefault>(200, "case_checks", result->case_checks); - deserializer.ReadPropertyWithDefault>(201, "else_expr", result->else_expr); - return std::move(result); -} - -void CastExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "child", child); - serializer.WriteProperty(201, "cast_type", cast_type); - serializer.WritePropertyWithDefault(202, "try_cast", try_cast); -} - -unique_ptr CastExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CastExpression()); - deserializer.ReadPropertyWithDefault>(200, "child", result->child); - deserializer.ReadProperty(201, "cast_type", result->cast_type); - deserializer.ReadPropertyWithDefault(202, "try_cast", result->try_cast); - return std::move(result); -} - -void CollateExpression::Serialize(Serializer &serializer) const { - 
ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "child", child); - serializer.WritePropertyWithDefault(201, "collation", collation); -} - -unique_ptr CollateExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CollateExpression()); - deserializer.ReadPropertyWithDefault>(200, "child", result->child); - deserializer.ReadPropertyWithDefault(201, "collation", result->collation); - return std::move(result); -} - -void ColumnRefExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "column_names", column_names); -} - -unique_ptr ColumnRefExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ColumnRefExpression()); - deserializer.ReadPropertyWithDefault>(200, "column_names", result->column_names); - return std::move(result); -} - -void ComparisonExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "left", left); - serializer.WritePropertyWithDefault>(201, "right", right); -} - -unique_ptr ComparisonExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ComparisonExpression(deserializer.Get())); - deserializer.ReadPropertyWithDefault>(200, "left", result->left); - deserializer.ReadPropertyWithDefault>(201, "right", result->right); - return std::move(result); -} - -void ConjunctionExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault>>(200, "children", children); -} - -unique_ptr ConjunctionExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ConjunctionExpression(deserializer.Get())); - deserializer.ReadPropertyWithDefault>>(200, "children", result->children); - return std::move(result); -} - -void 
ConstantExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WriteProperty(200, "value", value); -} - -unique_ptr ConstantExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ConstantExpression()); - deserializer.ReadProperty(200, "value", result->value); - return std::move(result); -} - -void DefaultExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); -} - -unique_ptr DefaultExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new DefaultExpression()); - return std::move(result); -} - -void FunctionExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "function_name", function_name); - serializer.WritePropertyWithDefault(201, "schema", schema); - serializer.WritePropertyWithDefault>>(202, "children", children); - serializer.WritePropertyWithDefault>(203, "filter", filter); - serializer.WritePropertyWithDefault>(204, "order_bys", order_bys); - serializer.WritePropertyWithDefault(205, "distinct", distinct); - serializer.WritePropertyWithDefault(206, "is_operator", is_operator); - serializer.WritePropertyWithDefault(207, "export_state", export_state); - serializer.WritePropertyWithDefault(208, "catalog", catalog); -} - -unique_ptr FunctionExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new FunctionExpression()); - deserializer.ReadPropertyWithDefault(200, "function_name", result->function_name); - deserializer.ReadPropertyWithDefault(201, "schema", result->schema); - deserializer.ReadPropertyWithDefault>>(202, "children", result->children); - deserializer.ReadPropertyWithDefault>(203, "filter", result->filter); - auto order_bys = deserializer.ReadPropertyWithDefault>(204, "order_bys"); - result->order_bys = unique_ptr_cast(std::move(order_bys)); - 
deserializer.ReadPropertyWithDefault(205, "distinct", result->distinct); - deserializer.ReadPropertyWithDefault(206, "is_operator", result->is_operator); - deserializer.ReadPropertyWithDefault(207, "export_state", result->export_state); - deserializer.ReadPropertyWithDefault(208, "catalog", result->catalog); - return std::move(result); -} - -void LambdaExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "lhs", lhs); - serializer.WritePropertyWithDefault>(201, "expr", expr); -} - -unique_ptr LambdaExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LambdaExpression()); - deserializer.ReadPropertyWithDefault>(200, "lhs", result->lhs); - deserializer.ReadPropertyWithDefault>(201, "expr", result->expr); - return std::move(result); -} - -void OperatorExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault>>(200, "children", children); -} - -unique_ptr OperatorExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new OperatorExpression(deserializer.Get())); - deserializer.ReadPropertyWithDefault>>(200, "children", result->children); - return std::move(result); -} - -void ParameterExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "identifier", identifier); -} - -unique_ptr ParameterExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ParameterExpression()); - deserializer.ReadPropertyWithDefault(200, "identifier", result->identifier); - return std::move(result); -} - -void PositionalReferenceExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "index", index); -} - -unique_ptr 
PositionalReferenceExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new PositionalReferenceExpression()); - deserializer.ReadPropertyWithDefault(200, "index", result->index); - return std::move(result); -} - -void StarExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "relation_name", relation_name); - serializer.WriteProperty(201, "exclude_list", exclude_list); - serializer.WritePropertyWithDefault>>(202, "replace_list", replace_list); - serializer.WritePropertyWithDefault(203, "columns", columns); - serializer.WritePropertyWithDefault>(204, "expr", expr); -} - -unique_ptr StarExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new StarExpression()); - deserializer.ReadPropertyWithDefault(200, "relation_name", result->relation_name); - deserializer.ReadProperty(201, "exclude_list", result->exclude_list); - deserializer.ReadPropertyWithDefault>>(202, "replace_list", result->replace_list); - deserializer.ReadPropertyWithDefault(203, "columns", result->columns); - deserializer.ReadPropertyWithDefault>(204, "expr", result->expr); - return std::move(result); -} - -void SubqueryExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WriteProperty(200, "subquery_type", subquery_type); - serializer.WritePropertyWithDefault>(201, "subquery", subquery); - serializer.WritePropertyWithDefault>(202, "child", child); - serializer.WriteProperty(203, "comparison_type", comparison_type); -} - -unique_ptr SubqueryExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new SubqueryExpression()); - deserializer.ReadProperty(200, "subquery_type", result->subquery_type); - deserializer.ReadPropertyWithDefault>(201, "subquery", result->subquery); - deserializer.ReadPropertyWithDefault>(202, "child", result->child); - 
deserializer.ReadProperty(203, "comparison_type", result->comparison_type); - return std::move(result); -} - -void WindowExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "function_name", function_name); - serializer.WritePropertyWithDefault(201, "schema", schema); - serializer.WritePropertyWithDefault(202, "catalog", catalog); - serializer.WritePropertyWithDefault>>(203, "children", children); - serializer.WritePropertyWithDefault>>(204, "partitions", partitions); - serializer.WritePropertyWithDefault>(205, "orders", orders); - serializer.WriteProperty(206, "start", start); - serializer.WriteProperty(207, "end", end); - serializer.WritePropertyWithDefault>(208, "start_expr", start_expr); - serializer.WritePropertyWithDefault>(209, "end_expr", end_expr); - serializer.WritePropertyWithDefault>(210, "offset_expr", offset_expr); - serializer.WritePropertyWithDefault>(211, "default_expr", default_expr); - serializer.WritePropertyWithDefault(212, "ignore_nulls", ignore_nulls); - serializer.WritePropertyWithDefault>(213, "filter_expr", filter_expr); -} - -unique_ptr WindowExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new WindowExpression(deserializer.Get())); - deserializer.ReadPropertyWithDefault(200, "function_name", result->function_name); - deserializer.ReadPropertyWithDefault(201, "schema", result->schema); - deserializer.ReadPropertyWithDefault(202, "catalog", result->catalog); - deserializer.ReadPropertyWithDefault>>(203, "children", result->children); - deserializer.ReadPropertyWithDefault>>(204, "partitions", result->partitions); - deserializer.ReadPropertyWithDefault>(205, "orders", result->orders); - deserializer.ReadProperty(206, "start", result->start); - deserializer.ReadProperty(207, "end", result->end); - deserializer.ReadPropertyWithDefault>(208, "start_expr", result->start_expr); - 
deserializer.ReadPropertyWithDefault>(209, "end_expr", result->end_expr); - deserializer.ReadPropertyWithDefault>(210, "offset_expr", result->offset_expr); - deserializer.ReadPropertyWithDefault>(211, "default_expr", result->default_expr); - deserializer.ReadPropertyWithDefault(212, "ignore_nulls", result->ignore_nulls); - deserializer.ReadPropertyWithDefault>(213, "filter_expr", result->filter_expr); - return std::move(result); -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -void QueryNode::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "type", type); - serializer.WritePropertyWithDefault>>(101, "modifiers", modifiers); - serializer.WriteProperty(102, "cte_map", cte_map); -} - -unique_ptr QueryNode::Deserialize(Deserializer &deserializer) { - auto type = deserializer.ReadProperty(100, "type"); - auto modifiers = deserializer.ReadPropertyWithDefault>>(101, "modifiers"); - auto cte_map = deserializer.ReadProperty(102, "cte_map"); - unique_ptr result; - switch (type) { - case QueryNodeType::CTE_NODE: - result = CTENode::Deserialize(deserializer); - break; - case QueryNodeType::RECURSIVE_CTE_NODE: - result = RecursiveCTENode::Deserialize(deserializer); - break; - case QueryNodeType::SELECT_NODE: - result = SelectNode::Deserialize(deserializer); - break; - case QueryNodeType::SET_OPERATION_NODE: - result = SetOperationNode::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of QueryNode!"); - } - result->modifiers = std::move(modifiers); - result->cte_map = std::move(cte_map); - return result; -} - -void CTENode::Serialize(Serializer &serializer) const { - 
QueryNode::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "cte_name", ctename); - serializer.WritePropertyWithDefault>(201, "query", query); - serializer.WritePropertyWithDefault>(202, "child", child); - serializer.WritePropertyWithDefault>(203, "aliases", aliases); -} - -unique_ptr CTENode::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new CTENode()); - deserializer.ReadPropertyWithDefault(200, "cte_name", result->ctename); - deserializer.ReadPropertyWithDefault>(201, "query", result->query); - deserializer.ReadPropertyWithDefault>(202, "child", result->child); - deserializer.ReadPropertyWithDefault>(203, "aliases", result->aliases); - return std::move(result); -} - -void RecursiveCTENode::Serialize(Serializer &serializer) const { - QueryNode::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "cte_name", ctename); - serializer.WritePropertyWithDefault(201, "union_all", union_all, false); - serializer.WritePropertyWithDefault>(202, "left", left); - serializer.WritePropertyWithDefault>(203, "right", right); - serializer.WritePropertyWithDefault>(204, "aliases", aliases); -} - -unique_ptr RecursiveCTENode::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new RecursiveCTENode()); - deserializer.ReadPropertyWithDefault(200, "cte_name", result->ctename); - deserializer.ReadPropertyWithDefault(201, "union_all", result->union_all, false); - deserializer.ReadPropertyWithDefault>(202, "left", result->left); - deserializer.ReadPropertyWithDefault>(203, "right", result->right); - deserializer.ReadPropertyWithDefault>(204, "aliases", result->aliases); - return std::move(result); -} - -void SelectNode::Serialize(Serializer &serializer) const { - QueryNode::Serialize(serializer); - serializer.WritePropertyWithDefault>>(200, "select_list", select_list); - serializer.WritePropertyWithDefault>(201, "from_table", from_table); - serializer.WritePropertyWithDefault>(202, "where_clause", 
where_clause); - serializer.WritePropertyWithDefault>>(203, "group_expressions", groups.group_expressions); - serializer.WritePropertyWithDefault>(204, "group_sets", groups.grouping_sets); - serializer.WriteProperty(205, "aggregate_handling", aggregate_handling); - serializer.WritePropertyWithDefault>(206, "having", having); - serializer.WritePropertyWithDefault>(207, "sample", sample); - serializer.WritePropertyWithDefault>(208, "qualify", qualify); -} - -unique_ptr SelectNode::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new SelectNode()); - deserializer.ReadPropertyWithDefault>>(200, "select_list", result->select_list); - deserializer.ReadPropertyWithDefault>(201, "from_table", result->from_table); - deserializer.ReadPropertyWithDefault>(202, "where_clause", result->where_clause); - deserializer.ReadPropertyWithDefault>>(203, "group_expressions", result->groups.group_expressions); - deserializer.ReadPropertyWithDefault>(204, "group_sets", result->groups.grouping_sets); - deserializer.ReadProperty(205, "aggregate_handling", result->aggregate_handling); - deserializer.ReadPropertyWithDefault>(206, "having", result->having); - deserializer.ReadPropertyWithDefault>(207, "sample", result->sample); - deserializer.ReadPropertyWithDefault>(208, "qualify", result->qualify); - return std::move(result); -} - -void SetOperationNode::Serialize(Serializer &serializer) const { - QueryNode::Serialize(serializer); - serializer.WriteProperty(200, "setop_type", setop_type); - serializer.WritePropertyWithDefault>(201, "left", left); - serializer.WritePropertyWithDefault>(202, "right", right); -} - -unique_ptr SetOperationNode::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new SetOperationNode()); - deserializer.ReadProperty(200, "setop_type", result->setop_type); - deserializer.ReadPropertyWithDefault>(201, "left", result->left); - deserializer.ReadPropertyWithDefault>(202, "right", result->right); - return 
std::move(result); -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { - -void ResultModifier::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "type", type); -} - -unique_ptr ResultModifier::Deserialize(Deserializer &deserializer) { - auto type = deserializer.ReadProperty(100, "type"); - unique_ptr result; - switch (type) { - case ResultModifierType::DISTINCT_MODIFIER: - result = DistinctModifier::Deserialize(deserializer); - break; - case ResultModifierType::LIMIT_MODIFIER: - result = LimitModifier::Deserialize(deserializer); - break; - case ResultModifierType::LIMIT_PERCENT_MODIFIER: - result = LimitPercentModifier::Deserialize(deserializer); - break; - case ResultModifierType::ORDER_MODIFIER: - result = OrderModifier::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of ResultModifier!"); - } - return result; -} - -void BoundOrderModifier::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "orders", orders); -} - -unique_ptr BoundOrderModifier::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new BoundOrderModifier()); - deserializer.ReadPropertyWithDefault>(100, "orders", result->orders); - return result; -} - -void DistinctModifier::Serialize(Serializer &serializer) const { - ResultModifier::Serialize(serializer); - serializer.WritePropertyWithDefault>>(200, "distinct_on_targets", distinct_on_targets); -} - -unique_ptr DistinctModifier::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new DistinctModifier()); - deserializer.ReadPropertyWithDefault>>(200, 
"distinct_on_targets", result->distinct_on_targets); - return std::move(result); -} - -void LimitModifier::Serialize(Serializer &serializer) const { - ResultModifier::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "limit", limit); - serializer.WritePropertyWithDefault>(201, "offset", offset); -} - -unique_ptr LimitModifier::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LimitModifier()); - deserializer.ReadPropertyWithDefault>(200, "limit", result->limit); - deserializer.ReadPropertyWithDefault>(201, "offset", result->offset); - return std::move(result); -} - -void LimitPercentModifier::Serialize(Serializer &serializer) const { - ResultModifier::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "limit", limit); - serializer.WritePropertyWithDefault>(201, "offset", offset); -} - -unique_ptr LimitPercentModifier::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LimitPercentModifier()); - deserializer.ReadPropertyWithDefault>(200, "limit", result->limit); - deserializer.ReadPropertyWithDefault>(201, "offset", result->offset); - return std::move(result); -} - -void OrderModifier::Serialize(Serializer &serializer) const { - ResultModifier::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "orders", orders); -} - -unique_ptr OrderModifier::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new OrderModifier()); - deserializer.ReadPropertyWithDefault>(200, "orders", result->orders); - return std::move(result); -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -void SelectStatement::Serialize(Serializer 
&serializer) const { - serializer.WritePropertyWithDefault>(100, "node", node); -} - -unique_ptr SelectStatement::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new SelectStatement()); - deserializer.ReadPropertyWithDefault>(100, "node", result->node); - return result; -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -void BlockPointer::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "block_id", block_id); - serializer.WritePropertyWithDefault(101, "offset", offset); -} - -BlockPointer BlockPointer::Deserialize(Deserializer &deserializer) { - auto block_id = deserializer.ReadProperty(100, "block_id"); - auto offset = deserializer.ReadPropertyWithDefault(101, "offset"); - BlockPointer result(block_id, offset); - return result; -} - -void DataPointer::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "row_start", row_start); - serializer.WritePropertyWithDefault(101, "tuple_count", tuple_count); - serializer.WriteProperty(102, "block_pointer", block_pointer); - serializer.WriteProperty(103, "compression_type", compression_type); - serializer.WriteProperty(104, "statistics", statistics); - serializer.WritePropertyWithDefault>(105, "segment_state", segment_state); -} - -DataPointer DataPointer::Deserialize(Deserializer &deserializer) { - auto row_start = deserializer.ReadPropertyWithDefault(100, "row_start"); - auto tuple_count = deserializer.ReadPropertyWithDefault(101, "tuple_count"); - auto block_pointer = deserializer.ReadProperty(102, "block_pointer"); - auto compression_type = deserializer.ReadProperty(103, "compression_type"); - auto statistics = 
deserializer.ReadProperty(104, "statistics"); - DataPointer result(std::move(statistics)); - result.row_start = row_start; - result.tuple_count = tuple_count; - result.block_pointer = block_pointer; - result.compression_type = compression_type; - deserializer.Set(compression_type); - deserializer.ReadPropertyWithDefault>(105, "segment_state", result.segment_state); - deserializer.Unset(); - return result; -} - -void DistinctStatistics::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "sample_count", sample_count); - serializer.WritePropertyWithDefault(101, "total_count", total_count); - serializer.WritePropertyWithDefault>(102, "log", log); -} - -unique_ptr DistinctStatistics::Deserialize(Deserializer &deserializer) { - auto sample_count = deserializer.ReadPropertyWithDefault(100, "sample_count"); - auto total_count = deserializer.ReadPropertyWithDefault(101, "total_count"); - auto log = deserializer.ReadPropertyWithDefault>(102, "log"); - auto result = duckdb::unique_ptr(new DistinctStatistics(std::move(log), sample_count, total_count)); - return result; -} - -void MetaBlockPointer::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "block_pointer", block_pointer); - serializer.WritePropertyWithDefault(101, "offset", offset); -} - -MetaBlockPointer MetaBlockPointer::Deserialize(Deserializer &deserializer) { - auto block_pointer = deserializer.ReadPropertyWithDefault(100, "block_pointer"); - auto offset = deserializer.ReadPropertyWithDefault(101, "offset"); - MetaBlockPointer result(block_pointer, offset); - return result; -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - - - - -namespace duckdb { - 
-void TableFilter::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "filter_type", filter_type); -} - -unique_ptr TableFilter::Deserialize(Deserializer &deserializer) { - auto filter_type = deserializer.ReadProperty(100, "filter_type"); - unique_ptr result; - switch (filter_type) { - case TableFilterType::CONJUNCTION_AND: - result = ConjunctionAndFilter::Deserialize(deserializer); - break; - case TableFilterType::CONJUNCTION_OR: - result = ConjunctionOrFilter::Deserialize(deserializer); - break; - case TableFilterType::CONSTANT_COMPARISON: - result = ConstantFilter::Deserialize(deserializer); - break; - case TableFilterType::IS_NOT_NULL: - result = IsNotNullFilter::Deserialize(deserializer); - break; - case TableFilterType::IS_NULL: - result = IsNullFilter::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of TableFilter!"); - } - return result; -} - -void ConjunctionAndFilter::Serialize(Serializer &serializer) const { - TableFilter::Serialize(serializer); - serializer.WritePropertyWithDefault>>(200, "child_filters", child_filters); -} - -unique_ptr ConjunctionAndFilter::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ConjunctionAndFilter()); - deserializer.ReadPropertyWithDefault>>(200, "child_filters", result->child_filters); - return std::move(result); -} - -void ConjunctionOrFilter::Serialize(Serializer &serializer) const { - TableFilter::Serialize(serializer); - serializer.WritePropertyWithDefault>>(200, "child_filters", child_filters); -} - -unique_ptr ConjunctionOrFilter::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ConjunctionOrFilter()); - deserializer.ReadPropertyWithDefault>>(200, "child_filters", result->child_filters); - return std::move(result); -} - -void ConstantFilter::Serialize(Serializer &serializer) const { - TableFilter::Serialize(serializer); - serializer.WriteProperty(200, 
"comparison_type", comparison_type); - serializer.WriteProperty(201, "constant", constant); -} - -unique_ptr ConstantFilter::Deserialize(Deserializer &deserializer) { - auto comparison_type = deserializer.ReadProperty(200, "comparison_type"); - auto constant = deserializer.ReadProperty(201, "constant"); - auto result = duckdb::unique_ptr(new ConstantFilter(comparison_type, constant)); - return std::move(result); -} - -void IsNotNullFilter::Serialize(Serializer &serializer) const { - TableFilter::Serialize(serializer); -} - -unique_ptr IsNotNullFilter::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new IsNotNullFilter()); - return std::move(result); -} - -void IsNullFilter::Serialize(Serializer &serializer) const { - TableFilter::Serialize(serializer); -} - -unique_ptr IsNullFilter::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new IsNullFilter()); - return std::move(result); -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -void TableRef::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "type", type); - serializer.WritePropertyWithDefault(101, "alias", alias); - serializer.WritePropertyWithDefault>(102, "sample", sample); -} - -unique_ptr TableRef::Deserialize(Deserializer &deserializer) { - auto type = deserializer.ReadProperty(100, "type"); - auto alias = deserializer.ReadPropertyWithDefault(101, "alias"); - auto sample = deserializer.ReadPropertyWithDefault>(102, "sample"); - unique_ptr result; - switch (type) { - case TableReferenceType::BASE_TABLE: - result = BaseTableRef::Deserialize(deserializer); - break; - case TableReferenceType::EMPTY: - result = 
EmptyTableRef::Deserialize(deserializer); - break; - case TableReferenceType::EXPRESSION_LIST: - result = ExpressionListRef::Deserialize(deserializer); - break; - case TableReferenceType::JOIN: - result = JoinRef::Deserialize(deserializer); - break; - case TableReferenceType::PIVOT: - result = PivotRef::Deserialize(deserializer); - break; - case TableReferenceType::SUBQUERY: - result = SubqueryRef::Deserialize(deserializer); - break; - case TableReferenceType::TABLE_FUNCTION: - result = TableFunctionRef::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of TableRef!"); - } - result->alias = std::move(alias); - result->sample = std::move(sample); - return result; -} - -void BaseTableRef::Serialize(Serializer &serializer) const { - TableRef::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "schema_name", schema_name); - serializer.WritePropertyWithDefault(201, "table_name", table_name); - serializer.WritePropertyWithDefault>(202, "column_name_alias", column_name_alias); - serializer.WritePropertyWithDefault(203, "catalog_name", catalog_name); -} - -unique_ptr BaseTableRef::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new BaseTableRef()); - deserializer.ReadPropertyWithDefault(200, "schema_name", result->schema_name); - deserializer.ReadPropertyWithDefault(201, "table_name", result->table_name); - deserializer.ReadPropertyWithDefault>(202, "column_name_alias", result->column_name_alias); - deserializer.ReadPropertyWithDefault(203, "catalog_name", result->catalog_name); - return std::move(result); -} - -void EmptyTableRef::Serialize(Serializer &serializer) const { - TableRef::Serialize(serializer); -} - -unique_ptr EmptyTableRef::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new EmptyTableRef()); - return std::move(result); -} - -void ExpressionListRef::Serialize(Serializer &serializer) const { - 
TableRef::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "expected_names", expected_names); - serializer.WritePropertyWithDefault>(201, "expected_types", expected_types); - serializer.WritePropertyWithDefault>>>(202, "values", values); -} - -unique_ptr ExpressionListRef::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ExpressionListRef()); - deserializer.ReadPropertyWithDefault>(200, "expected_names", result->expected_names); - deserializer.ReadPropertyWithDefault>(201, "expected_types", result->expected_types); - deserializer.ReadPropertyWithDefault>>>(202, "values", result->values); - return std::move(result); -} - -void JoinRef::Serialize(Serializer &serializer) const { - TableRef::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "left", left); - serializer.WritePropertyWithDefault>(201, "right", right); - serializer.WritePropertyWithDefault>(202, "condition", condition); - serializer.WriteProperty(203, "join_type", type); - serializer.WriteProperty(204, "ref_type", ref_type); - serializer.WritePropertyWithDefault>(205, "using_columns", using_columns); -} - -unique_ptr JoinRef::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new JoinRef()); - deserializer.ReadPropertyWithDefault>(200, "left", result->left); - deserializer.ReadPropertyWithDefault>(201, "right", result->right); - deserializer.ReadPropertyWithDefault>(202, "condition", result->condition); - deserializer.ReadProperty(203, "join_type", result->type); - deserializer.ReadProperty(204, "ref_type", result->ref_type); - deserializer.ReadPropertyWithDefault>(205, "using_columns", result->using_columns); - return std::move(result); -} - -void PivotRef::Serialize(Serializer &serializer) const { - TableRef::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "source", source); - serializer.WritePropertyWithDefault>>(201, "aggregates", aggregates); - serializer.WritePropertyWithDefault>(202, 
"unpivot_names", unpivot_names); - serializer.WritePropertyWithDefault>(203, "pivots", pivots); - serializer.WritePropertyWithDefault>(204, "groups", groups); - serializer.WritePropertyWithDefault>(205, "column_name_alias", column_name_alias); - serializer.WritePropertyWithDefault(206, "include_nulls", include_nulls); -} - -unique_ptr PivotRef::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new PivotRef()); - deserializer.ReadPropertyWithDefault>(200, "source", result->source); - deserializer.ReadPropertyWithDefault>>(201, "aggregates", result->aggregates); - deserializer.ReadPropertyWithDefault>(202, "unpivot_names", result->unpivot_names); - deserializer.ReadPropertyWithDefault>(203, "pivots", result->pivots); - deserializer.ReadPropertyWithDefault>(204, "groups", result->groups); - deserializer.ReadPropertyWithDefault>(205, "column_name_alias", result->column_name_alias); - deserializer.ReadPropertyWithDefault(206, "include_nulls", result->include_nulls); - return std::move(result); -} - -void SubqueryRef::Serialize(Serializer &serializer) const { - TableRef::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "subquery", subquery); - serializer.WritePropertyWithDefault>(201, "column_name_alias", column_name_alias); -} - -unique_ptr SubqueryRef::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new SubqueryRef()); - deserializer.ReadPropertyWithDefault>(200, "subquery", result->subquery); - deserializer.ReadPropertyWithDefault>(201, "column_name_alias", result->column_name_alias); - return std::move(result); -} - -void TableFunctionRef::Serialize(Serializer &serializer) const { - TableRef::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "function", function); - serializer.WritePropertyWithDefault>(201, "column_name_alias", column_name_alias); -} - -unique_ptr TableFunctionRef::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new 
TableFunctionRef()); - deserializer.ReadPropertyWithDefault>(200, "function", result->function); - deserializer.ReadPropertyWithDefault>(201, "column_name_alias", result->column_name_alias); - return std::move(result); -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -void ExtraTypeInfo::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "type", type); - serializer.WritePropertyWithDefault(101, "alias", alias); -} - -shared_ptr ExtraTypeInfo::Deserialize(Deserializer &deserializer) { - auto type = deserializer.ReadProperty(100, "type"); - auto alias = deserializer.ReadPropertyWithDefault(101, "alias"); - shared_ptr result; - switch (type) { - case ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO: - result = AggregateStateTypeInfo::Deserialize(deserializer); - break; - case ExtraTypeInfoType::DECIMAL_TYPE_INFO: - result = DecimalTypeInfo::Deserialize(deserializer); - break; - case ExtraTypeInfoType::ENUM_TYPE_INFO: - result = EnumTypeInfo::Deserialize(deserializer); - break; - case ExtraTypeInfoType::GENERIC_TYPE_INFO: - result = make_shared(type); - break; - case ExtraTypeInfoType::INVALID_TYPE_INFO: - return nullptr; - case ExtraTypeInfoType::LIST_TYPE_INFO: - result = ListTypeInfo::Deserialize(deserializer); - break; - case ExtraTypeInfoType::STRING_TYPE_INFO: - result = StringTypeInfo::Deserialize(deserializer); - break; - case ExtraTypeInfoType::STRUCT_TYPE_INFO: - result = StructTypeInfo::Deserialize(deserializer); - break; - case ExtraTypeInfoType::USER_TYPE_INFO: - result = UserTypeInfo::Deserialize(deserializer); - break; - default: - throw SerializationException("Unsupported type for deserialization of ExtraTypeInfo!"); - 
} - result->alias = std::move(alias); - return result; -} - -void AggregateStateTypeInfo::Serialize(Serializer &serializer) const { - ExtraTypeInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "function_name", state_type.function_name); - serializer.WriteProperty(201, "return_type", state_type.return_type); - serializer.WritePropertyWithDefault>(202, "bound_argument_types", state_type.bound_argument_types); -} - -shared_ptr AggregateStateTypeInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::shared_ptr(new AggregateStateTypeInfo()); - deserializer.ReadPropertyWithDefault(200, "function_name", result->state_type.function_name); - deserializer.ReadProperty(201, "return_type", result->state_type.return_type); - deserializer.ReadPropertyWithDefault>(202, "bound_argument_types", result->state_type.bound_argument_types); - return std::move(result); -} - -void DecimalTypeInfo::Serialize(Serializer &serializer) const { - ExtraTypeInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "width", width); - serializer.WritePropertyWithDefault(201, "scale", scale); -} - -shared_ptr DecimalTypeInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::shared_ptr(new DecimalTypeInfo()); - deserializer.ReadPropertyWithDefault(200, "width", result->width); - deserializer.ReadPropertyWithDefault(201, "scale", result->scale); - return std::move(result); -} - -void ListTypeInfo::Serialize(Serializer &serializer) const { - ExtraTypeInfo::Serialize(serializer); - serializer.WriteProperty(200, "child_type", child_type); -} - -shared_ptr ListTypeInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::shared_ptr(new ListTypeInfo()); - deserializer.ReadProperty(200, "child_type", result->child_type); - return std::move(result); -} - -void StringTypeInfo::Serialize(Serializer &serializer) const { - ExtraTypeInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "collation", collation); -} 
- -shared_ptr StringTypeInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::shared_ptr(new StringTypeInfo()); - deserializer.ReadPropertyWithDefault(200, "collation", result->collation); - return std::move(result); -} - -void StructTypeInfo::Serialize(Serializer &serializer) const { - ExtraTypeInfo::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "child_types", child_types); -} - -shared_ptr StructTypeInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::shared_ptr(new StructTypeInfo()); - deserializer.ReadPropertyWithDefault>(200, "child_types", result->child_types); - return std::move(result); -} - -void UserTypeInfo::Serialize(Serializer &serializer) const { - ExtraTypeInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "user_type_name", user_type_name); -} - -shared_ptr UserTypeInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::shared_ptr(new UserTypeInfo()); - deserializer.ReadPropertyWithDefault(200, "user_type_name", result->user_type_name); - return std::move(result); -} - -} // namespace duckdb - - - - - - - - - - - - -#include -#include - -namespace duckdb { - -const char MainHeader::MAGIC_BYTES[] = "DUCK"; - -void SerializeVersionNumber(WriteStream &ser, const string &version_str) { - constexpr const idx_t MAX_VERSION_SIZE = 32; - data_t version[MAX_VERSION_SIZE]; - memset(version, 0, MAX_VERSION_SIZE); - memcpy(version, version_str.c_str(), MinValue(version_str.size(), MAX_VERSION_SIZE)); - ser.WriteData(version, MAX_VERSION_SIZE); -} - -void MainHeader::Write(WriteStream &ser) { - ser.WriteData(const_data_ptr_cast(MAGIC_BYTES), MAGIC_BYTE_SIZE); - ser.Write(version_number); - for (idx_t i = 0; i < FLAG_COUNT; i++) { - ser.Write(flags[i]); - } - SerializeVersionNumber(ser, DuckDB::LibraryVersion()); - SerializeVersionNumber(ser, DuckDB::SourceID()); -} - -void MainHeader::CheckMagicBytes(FileHandle &handle) { - data_t magic_bytes[MAGIC_BYTE_SIZE]; - if 
(handle.GetFileSize() < MainHeader::MAGIC_BYTE_SIZE + MainHeader::MAGIC_BYTE_OFFSET) { - throw IOException("The file \"%s\" exists, but it is not a valid DuckDB database file!", handle.path); - } - handle.Read(magic_bytes, MainHeader::MAGIC_BYTE_SIZE, MainHeader::MAGIC_BYTE_OFFSET); - if (memcmp(magic_bytes, MainHeader::MAGIC_BYTES, MainHeader::MAGIC_BYTE_SIZE) != 0) { - throw IOException("The file \"%s\" exists, but it is not a valid DuckDB database file!", handle.path); - } -} - -MainHeader MainHeader::Read(ReadStream &source) { - data_t magic_bytes[MAGIC_BYTE_SIZE]; - MainHeader header; - source.ReadData(magic_bytes, MainHeader::MAGIC_BYTE_SIZE); - if (memcmp(magic_bytes, MainHeader::MAGIC_BYTES, MainHeader::MAGIC_BYTE_SIZE) != 0) { - throw IOException("The file is not a valid DuckDB database file!"); - } - header.version_number = source.Read(); - // check the version number - if (header.version_number != VERSION_NUMBER) { - auto version = GetDuckDBVersion(header.version_number); - string version_text; - if (version) { - // known version - version_text = "DuckDB version " + string(version); - } else { - version_text = string("an ") + (VERSION_NUMBER > header.version_number ? 
"older development" : "newer") + - string(" version of DuckDB"); - } - throw IOException( - "Trying to read a database file with version number %lld, but we can only read version %lld.\n" - "The database file was created with %s.\n\n" - "The storage of DuckDB is not yet stable; newer versions of DuckDB cannot read old database files and " - "vice versa.\n" - "The storage will be stabilized when version 1.0 releases.\n\n" - "For now, we recommend that you load the database file in a supported version of DuckDB, and use the " - "EXPORT DATABASE command " - "followed by IMPORT DATABASE on the current version of DuckDB.\n\n" - "See the storage page for more information: https://duckdb.org/internals/storage", - header.version_number, VERSION_NUMBER, version_text); - } - // read the flags - for (idx_t i = 0; i < FLAG_COUNT; i++) { - header.flags[i] = source.Read(); - } - return header; -} - -void DatabaseHeader::Write(WriteStream &ser) { - ser.Write(iteration); - ser.Write(meta_block); - ser.Write(free_list); - ser.Write(block_count); -} - -DatabaseHeader DatabaseHeader::Read(ReadStream &source) { - DatabaseHeader header; - header.iteration = source.Read(); - header.meta_block = source.Read(); - header.free_list = source.Read(); - header.block_count = source.Read(); - return header; -} - -template -void SerializeHeaderStructure(T header, data_ptr_t ptr) { - MemoryStream ser(ptr, Storage::FILE_HEADER_SIZE); - header.Write(ser); -} - -template -T DeserializeHeaderStructure(data_ptr_t ptr) { - MemoryStream source(ptr, Storage::FILE_HEADER_SIZE); - return T::Read(source); -} - -SingleFileBlockManager::SingleFileBlockManager(AttachedDatabase &db, string path_p, StorageManagerOptions options) - : BlockManager(BufferManager::GetBufferManager(db)), db(db), path(std::move(path_p)), - header_buffer(Allocator::Get(db), FileBufferType::MANAGED_BUFFER, - Storage::FILE_HEADER_SIZE - Storage::BLOCK_HEADER_SIZE), - iteration_count(0), options(options) { -} - -void 
SingleFileBlockManager::GetFileFlags(uint8_t &flags, FileLockType &lock, bool create_new) { - if (options.read_only) { - D_ASSERT(!create_new); - flags = FileFlags::FILE_FLAGS_READ; - lock = FileLockType::READ_LOCK; - } else { - flags = FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_READ; - lock = FileLockType::WRITE_LOCK; - if (create_new) { - flags |= FileFlags::FILE_FLAGS_FILE_CREATE; - } - } - if (options.use_direct_io) { - flags |= FileFlags::FILE_FLAGS_DIRECT_IO; - } -} - -void SingleFileBlockManager::CreateNewDatabase() { - uint8_t flags; - FileLockType lock; - GetFileFlags(flags, lock, true); - - // open the RDBMS handle - auto &fs = FileSystem::Get(db); - handle = fs.OpenFile(path, flags, lock); - - // if we create a new file, we fill the metadata of the file - // first fill in the new header - header_buffer.Clear(); - - MainHeader main_header; - main_header.version_number = VERSION_NUMBER; - memset(main_header.flags, 0, sizeof(uint64_t) * 4); - - SerializeHeaderStructure(main_header, header_buffer.buffer); - // now write the header to the file - ChecksumAndWrite(header_buffer, 0); - header_buffer.Clear(); - - // write the database headers - // initialize meta_block and free_list to INVALID_BLOCK because the database file does not contain any actual - // content yet - DatabaseHeader h1, h2; - // header 1 - h1.iteration = 0; - h1.meta_block = INVALID_BLOCK; - h1.free_list = INVALID_BLOCK; - h1.block_count = 0; - SerializeHeaderStructure(h1, header_buffer.buffer); - ChecksumAndWrite(header_buffer, Storage::FILE_HEADER_SIZE); - // header 2 - h2.iteration = 0; - h2.meta_block = INVALID_BLOCK; - h2.free_list = INVALID_BLOCK; - h2.block_count = 0; - SerializeHeaderStructure(h2, header_buffer.buffer); - ChecksumAndWrite(header_buffer, Storage::FILE_HEADER_SIZE * 2ULL); - // ensure that writing to disk is completed before returning - handle->Sync(); - // we start with h2 as active_header, this way our initial write will be in h1 - iteration_count = 0; - 
active_header = 1; - max_block = 0; -} - -void SingleFileBlockManager::LoadExistingDatabase() { - uint8_t flags; - FileLockType lock; - GetFileFlags(flags, lock, false); - - // open the RDBMS handle - auto &fs = FileSystem::Get(db); - handle = fs.OpenFile(path, flags, lock); - - MainHeader::CheckMagicBytes(*handle); - // otherwise, we check the metadata of the file - ReadAndChecksum(header_buffer, 0); - DeserializeHeaderStructure(header_buffer.buffer); - - // read the database headers from disk - DatabaseHeader h1, h2; - ReadAndChecksum(header_buffer, Storage::FILE_HEADER_SIZE); - h1 = DeserializeHeaderStructure(header_buffer.buffer); - ReadAndChecksum(header_buffer, Storage::FILE_HEADER_SIZE * 2ULL); - h2 = DeserializeHeaderStructure(header_buffer.buffer); - // check the header with the highest iteration count - if (h1.iteration > h2.iteration) { - // h1 is active header - active_header = 0; - Initialize(h1); - } else { - // h2 is active header - active_header = 1; - Initialize(h2); - } - LoadFreeList(); -} - -void SingleFileBlockManager::ReadAndChecksum(FileBuffer &block, uint64_t location) const { - // read the buffer from disk - block.Read(*handle, location); - // compute the checksum - auto stored_checksum = Load(block.InternalBuffer()); - uint64_t computed_checksum = Checksum(block.buffer, block.size); - // verify the checksum - if (stored_checksum != computed_checksum) { - throw IOException("Corrupt database file: computed checksum %llu does not match stored checksum %llu in block", - computed_checksum, stored_checksum); - } -} - -void SingleFileBlockManager::ChecksumAndWrite(FileBuffer &block, uint64_t location) const { - // compute the checksum and write it to the start of the buffer (if not temp buffer) - uint64_t checksum = Checksum(block.buffer, block.size); - Store(checksum, block.InternalBuffer()); - // now write the buffer - block.Write(*handle, location); -} - -void SingleFileBlockManager::Initialize(DatabaseHeader &header) { - free_list_id = 
header.free_list; - meta_block = header.meta_block; - iteration_count = header.iteration; - max_block = header.block_count; -} - -void SingleFileBlockManager::LoadFreeList() { - MetaBlockPointer free_pointer(free_list_id, 0); - if (!free_pointer.IsValid()) { - // no free list - return; - } - MetadataReader reader(GetMetadataManager(), free_pointer, nullptr, BlockReaderType::REGISTER_BLOCKS); - auto free_list_count = reader.Read(); - free_list.clear(); - for (idx_t i = 0; i < free_list_count; i++) { - free_list.insert(reader.Read()); - } - auto multi_use_blocks_count = reader.Read(); - multi_use_blocks.clear(); - for (idx_t i = 0; i < multi_use_blocks_count; i++) { - auto block_id = reader.Read(); - auto usage_count = reader.Read(); - multi_use_blocks[block_id] = usage_count; - } - GetMetadataManager().Read(reader); - GetMetadataManager().MarkBlocksAsModified(); -} - -bool SingleFileBlockManager::IsRootBlock(MetaBlockPointer root) { - return root.block_pointer == meta_block; -} - -block_id_t SingleFileBlockManager::GetFreeBlockId() { - lock_guard lock(block_lock); - block_id_t block; - if (!free_list.empty()) { - // free list is non empty - // take an entry from the free list - block = *free_list.begin(); - // erase the entry from the free list again - free_list.erase(free_list.begin()); - } else { - block = max_block++; - } - return block; -} - -void SingleFileBlockManager::MarkBlockAsFree(block_id_t block_id) { - lock_guard lock(block_lock); - D_ASSERT(block_id >= 0); - D_ASSERT(block_id < max_block); - if (free_list.find(block_id) != free_list.end()) { - throw InternalException("MarkBlockAsFree called but block %llu was already freed!", block_id); - } - multi_use_blocks.erase(block_id); - free_list.insert(block_id); -} - -void SingleFileBlockManager::MarkBlockAsModified(block_id_t block_id) { - lock_guard lock(block_lock); - D_ASSERT(block_id >= 0); - D_ASSERT(block_id < max_block); - - // check if the block is a multi-use block - auto entry = 
multi_use_blocks.find(block_id); - if (entry != multi_use_blocks.end()) { - // it is! reduce the reference count of the block - entry->second--; - // check the reference count: is the block still a multi-use block? - if (entry->second <= 1) { - // no longer a multi-use block! - multi_use_blocks.erase(entry); - } - return; - } - // Check for multi-free - // TODO: Fix the bug that causes this assert to fire, then uncomment it. - // D_ASSERT(modified_blocks.find(block_id) == modified_blocks.end()); - D_ASSERT(free_list.find(block_id) == free_list.end()); - modified_blocks.insert(block_id); -} - -void SingleFileBlockManager::IncreaseBlockReferenceCount(block_id_t block_id) { - lock_guard lock(block_lock); - D_ASSERT(block_id >= 0); - D_ASSERT(block_id < max_block); - D_ASSERT(free_list.find(block_id) == free_list.end()); - auto entry = multi_use_blocks.find(block_id); - if (entry != multi_use_blocks.end()) { - entry->second++; - } else { - multi_use_blocks[block_id] = 2; - } -} - -idx_t SingleFileBlockManager::GetMetaBlock() { - return meta_block; -} - -idx_t SingleFileBlockManager::TotalBlocks() { - lock_guard lock(block_lock); - return max_block; -} - -idx_t SingleFileBlockManager::FreeBlocks() { - lock_guard lock(block_lock); - return free_list.size(); -} - -unique_ptr SingleFileBlockManager::ConvertBlock(block_id_t block_id, FileBuffer &source_buffer) { - D_ASSERT(source_buffer.AllocSize() == Storage::BLOCK_ALLOC_SIZE); - return make_uniq(source_buffer, block_id); -} - -unique_ptr SingleFileBlockManager::CreateBlock(block_id_t block_id, FileBuffer *source_buffer) { - unique_ptr result; - if (source_buffer) { - result = ConvertBlock(block_id, *source_buffer); - } else { - result = make_uniq(Allocator::Get(db), block_id); - } - result->Initialize(options.debug_initialize); - return result; -} - -void SingleFileBlockManager::Read(Block &block) { - D_ASSERT(block.id >= 0); - D_ASSERT(std::find(free_list.begin(), free_list.end(), block.id) == free_list.end()); - 
ReadAndChecksum(block, BLOCK_START + block.id * Storage::BLOCK_ALLOC_SIZE); -} - -void SingleFileBlockManager::Write(FileBuffer &buffer, block_id_t block_id) { - D_ASSERT(block_id >= 0); - ChecksumAndWrite(buffer, BLOCK_START + block_id * Storage::BLOCK_ALLOC_SIZE); -} - -void SingleFileBlockManager::Truncate() { - BlockManager::Truncate(); - idx_t blocks_to_truncate = 0; - // reverse iterate over the free-list - for (auto entry = free_list.rbegin(); entry != free_list.rend(); entry++) { - auto block_id = *entry; - if (block_id + 1 != max_block) { - break; - } - blocks_to_truncate++; - max_block--; - } - if (blocks_to_truncate == 0) { - // nothing to truncate - return; - } - // truncate the file - for (idx_t i = 0; i < blocks_to_truncate; i++) { - free_list.erase(max_block + i); - } - handle->Truncate(BLOCK_START + max_block * Storage::BLOCK_ALLOC_SIZE); -} - -vector SingleFileBlockManager::GetFreeListBlocks() { - vector free_list_blocks; - - // reserve all blocks that we are going to write the free list to - // since these blocks are no longer free we cannot just include them in the free list! 
- auto block_size = MetadataManager::METADATA_BLOCK_SIZE - sizeof(idx_t); - idx_t allocated_size = 0; - while (true) { - auto free_list_size = sizeof(uint64_t) + sizeof(block_id_t) * (free_list.size() + modified_blocks.size()); - auto multi_use_blocks_size = - sizeof(uint64_t) + (sizeof(block_id_t) + sizeof(uint32_t)) * multi_use_blocks.size(); - auto metadata_blocks = - sizeof(uint64_t) + (sizeof(block_id_t) + sizeof(idx_t)) * GetMetadataManager().BlockCount(); - auto total_size = free_list_size + multi_use_blocks_size + metadata_blocks; - if (total_size < allocated_size) { - break; - } - auto free_list_handle = GetMetadataManager().AllocateHandle(); - free_list_blocks.push_back(std::move(free_list_handle)); - allocated_size += block_size; - } - - return free_list_blocks; -} - -class FreeListBlockWriter : public MetadataWriter { -public: - FreeListBlockWriter(MetadataManager &manager, vector free_list_blocks_p) - : MetadataWriter(manager), free_list_blocks(std::move(free_list_blocks_p)), index(0) { - } - - vector free_list_blocks; - idx_t index; - -protected: - MetadataHandle NextHandle() override { - if (index >= free_list_blocks.size()) { - throw InternalException( - "Free List Block Writer ran out of blocks, this means not enough blocks were allocated up front"); - } - return std::move(free_list_blocks[index++]); - } -}; - -void SingleFileBlockManager::WriteHeader(DatabaseHeader header) { - // set the iteration count - header.iteration = ++iteration_count; - - auto free_list_blocks = GetFreeListBlocks(); - - // now handle the free list - auto &metadata_manager = GetMetadataManager(); - // add all modified blocks to the free list: they can now be written to again - metadata_manager.MarkBlocksAsModified(); - for (auto &block : modified_blocks) { - free_list.insert(block); - } - modified_blocks.clear(); - - if (!free_list_blocks.empty()) { - // there are blocks to write, either in the free_list or in the modified_blocks - // we write these blocks specifically to 
the free_list_blocks - // a normal MetadataWriter will fetch blocks to use from the free_list - // but since we are WRITING the free_list, this behavior is sub-optimal - FreeListBlockWriter writer(metadata_manager, std::move(free_list_blocks)); - - auto ptr = writer.GetMetaBlockPointer(); - header.free_list = ptr.block_pointer; - - writer.Write(free_list.size()); - for (auto &block_id : free_list) { - writer.Write(block_id); - } - writer.Write(multi_use_blocks.size()); - for (auto &entry : multi_use_blocks) { - writer.Write(entry.first); - writer.Write(entry.second); - } - GetMetadataManager().Write(writer); - writer.Flush(); - } else { - // no blocks in the free list - header.free_list = DConstants::INVALID_INDEX; - } - metadata_manager.Flush(); - header.block_count = max_block; - - auto &config = DBConfig::Get(db); - if (config.options.checkpoint_abort == CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE) { - throw FatalException("Checkpoint aborted after free list write because of PRAGMA checkpoint_abort flag"); - } - - if (!options.use_direct_io) { - // if we are not using Direct IO we need to fsync BEFORE we write the header to ensure that all the previous - // blocks are written as well - handle->Sync(); - } - // set the header inside the buffer - header_buffer.Clear(); - MemoryStream serializer; - header.Write(serializer); - memcpy(header_buffer.buffer, serializer.GetData(), serializer.GetPosition()); - // now write the header to the file, active_header determines whether we write to h1 or h2 - // note that if active_header is h1 we write to h2, and vice versa - ChecksumAndWrite(header_buffer, active_header == 1 ? Storage::FILE_HEADER_SIZE : Storage::FILE_HEADER_SIZE * 2); - // switch active header to the other header - active_header = 1 - active_header; - //! 
Ensure the header write ends up on disk - handle->Sync(); -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -struct BufferAllocatorData : PrivateAllocatorData { - explicit BufferAllocatorData(StandardBufferManager &manager) : manager(manager) { - } - - StandardBufferManager &manager; -}; - -unique_ptr StandardBufferManager::ConstructManagedBuffer(idx_t size, unique_ptr &&source, - FileBufferType type) { - unique_ptr result; - if (source) { - auto tmp = std::move(source); - D_ASSERT(tmp->AllocSize() == BufferManager::GetAllocSize(size)); - result = make_uniq(*tmp, type); - } else { - // no re-usable buffer: allocate a new buffer - result = make_uniq(Allocator::Get(db), type, size); - } - result->Initialize(DBConfig::GetConfig(db).options.debug_initialize); - return result; -} - -class TemporaryFileManager; - -class TemporaryDirectoryHandle { -public: - TemporaryDirectoryHandle(DatabaseInstance &db, string path_p); - ~TemporaryDirectoryHandle(); - - TemporaryFileManager &GetTempFile(); - -private: - DatabaseInstance &db; - string temp_directory; - bool created_directory = false; - unique_ptr temp_file; -}; - -void StandardBufferManager::SetTemporaryDirectory(const string &new_dir) { - if (temp_directory_handle) { - throw NotImplementedException("Cannot switch temporary directory after the current one has been used"); - } - this->temp_directory = new_dir; -} - -StandardBufferManager::StandardBufferManager(DatabaseInstance &db, string tmp) - : BufferManager(), db(db), buffer_pool(db.GetBufferPool()), temp_directory(std::move(tmp)), - temporary_id(MAXIMUM_BLOCK), buffer_allocator(BufferAllocatorAllocate, BufferAllocatorFree, - BufferAllocatorRealloc, make_uniq(*this)) { - temp_block_manager = make_uniq(*this); -} - -StandardBufferManager::~StandardBufferManager() { -} - -BufferPool &StandardBufferManager::GetBufferPool() { - return buffer_pool; -} - -idx_t StandardBufferManager::GetUsedMemory() const { - return buffer_pool.GetUsedMemory(); -} 
-idx_t StandardBufferManager::GetMaxMemory() const { - return buffer_pool.GetMaxMemory(); -} - -template -TempBufferPoolReservation StandardBufferManager::EvictBlocksOrThrow(idx_t memory_delta, unique_ptr *buffer, - ARGS... args) { - auto r = buffer_pool.EvictBlocks(memory_delta, buffer_pool.maximum_memory, buffer); - if (!r.success) { - string extra_text = StringUtil::Format(" (%s/%s used)", StringUtil::BytesToHumanReadableString(GetUsedMemory()), - StringUtil::BytesToHumanReadableString(GetMaxMemory())); - extra_text += InMemoryWarning(); - throw OutOfMemoryException(args..., extra_text); - } - return std::move(r.reservation); -} - -shared_ptr StandardBufferManager::RegisterSmallMemory(idx_t block_size) { - D_ASSERT(block_size < Storage::BLOCK_SIZE); - auto res = EvictBlocksOrThrow(block_size, nullptr, "could not allocate block of size %s%s", - StringUtil::BytesToHumanReadableString(block_size)); - - auto buffer = ConstructManagedBuffer(block_size, nullptr, FileBufferType::TINY_BUFFER); - - // create a new block pointer for this block - return make_shared(*temp_block_manager, ++temporary_id, std::move(buffer), false, block_size, - std::move(res)); -} - -shared_ptr StandardBufferManager::RegisterMemory(idx_t block_size, bool can_destroy) { - D_ASSERT(block_size >= Storage::BLOCK_SIZE); - auto alloc_size = GetAllocSize(block_size); - // first evict blocks until we have enough memory to store this buffer - unique_ptr reusable_buffer; - auto res = EvictBlocksOrThrow(alloc_size, &reusable_buffer, "could not allocate block of size %s%s", - StringUtil::BytesToHumanReadableString(alloc_size)); - - auto buffer = ConstructManagedBuffer(block_size, std::move(reusable_buffer)); - - // create a new block pointer for this block - return make_shared(*temp_block_manager, ++temporary_id, std::move(buffer), can_destroy, alloc_size, - std::move(res)); -} - -BufferHandle StandardBufferManager::Allocate(idx_t block_size, bool can_destroy, shared_ptr *block) { - shared_ptr 
local_block; - auto block_ptr = block ? block : &local_block; - *block_ptr = RegisterMemory(block_size, can_destroy); - return Pin(*block_ptr); -} - -void StandardBufferManager::ReAllocate(shared_ptr &handle, idx_t block_size) { - D_ASSERT(block_size >= Storage::BLOCK_SIZE); - lock_guard lock(handle->lock); - D_ASSERT(handle->state == BlockState::BLOCK_LOADED); - D_ASSERT(handle->memory_usage == handle->buffer->AllocSize()); - D_ASSERT(handle->memory_usage == handle->memory_charge.size); - - auto req = handle->buffer->CalculateMemory(block_size); - int64_t memory_delta = (int64_t)req.alloc_size - handle->memory_usage; - - if (memory_delta == 0) { - return; - } else if (memory_delta > 0) { - // evict blocks until we have space to resize this block - auto reservation = EvictBlocksOrThrow(memory_delta, nullptr, "failed to resize block from %s to %s%s", - StringUtil::BytesToHumanReadableString(handle->memory_usage), - StringUtil::BytesToHumanReadableString(req.alloc_size)); - // EvictBlocks decrements 'current_memory' for us. - handle->memory_charge.Merge(std::move(reservation)); - } else { - // no need to evict blocks, but we do need to decrement 'current_memory'. 
- handle->memory_charge.Resize(req.alloc_size); - } - - handle->ResizeBuffer(block_size, memory_delta); -} - -BufferHandle StandardBufferManager::Pin(shared_ptr &handle) { - idx_t required_memory; - { - // lock the block - lock_guard lock(handle->lock); - // check if the block is already loaded - if (handle->state == BlockState::BLOCK_LOADED) { - // the block is loaded, increment the reader count and return a pointer to the handle - handle->readers++; - return handle->Load(handle); - } - required_memory = handle->memory_usage; - } - // evict blocks until we have space for the current block - unique_ptr reusable_buffer; - auto reservation = EvictBlocksOrThrow(required_memory, &reusable_buffer, "failed to pin block of size %s%s", - StringUtil::BytesToHumanReadableString(required_memory)); - // lock the handle again and repeat the check (in case anybody loaded in the mean time) - lock_guard lock(handle->lock); - // check if the block is already loaded - if (handle->state == BlockState::BLOCK_LOADED) { - // the block is loaded, increment the reader count and return a pointer to the handle - handle->readers++; - reservation.Resize(0); - return handle->Load(handle); - } - // now we can actually load the current block - D_ASSERT(handle->readers == 0); - handle->readers = 1; - auto buf = handle->Load(handle, std::move(reusable_buffer)); - handle->memory_charge = std::move(reservation); - // In the case of a variable sized block, the buffer may be smaller than a full block. 
- int64_t delta = handle->buffer->AllocSize() - handle->memory_usage; - if (delta) { - D_ASSERT(delta < 0); - handle->memory_usage += delta; - handle->memory_charge.Resize(handle->memory_usage); - } - D_ASSERT(handle->memory_usage == handle->buffer->AllocSize()); - return buf; -} - -void StandardBufferManager::PurgeQueue() { - buffer_pool.PurgeQueue(); -} - -void StandardBufferManager::AddToEvictionQueue(shared_ptr &handle) { - buffer_pool.AddToEvictionQueue(handle); -} - -void StandardBufferManager::VerifyZeroReaders(shared_ptr &handle) { -#ifdef DUCKDB_DEBUG_DESTROY_BLOCKS - auto replacement_buffer = make_uniq(Allocator::Get(db), handle->buffer->type, - handle->memory_usage - Storage::BLOCK_HEADER_SIZE); - memcpy(replacement_buffer->buffer, handle->buffer->buffer, handle->buffer->size); - memset(handle->buffer->buffer, 0xa5, handle->buffer->size); // 0xa5 is default memory in debug mode - handle->buffer = std::move(replacement_buffer); -#endif -} - -void StandardBufferManager::Unpin(shared_ptr &handle) { - lock_guard lock(handle->lock); - if (!handle->buffer || handle->buffer->type == FileBufferType::TINY_BUFFER) { - return; - } - D_ASSERT(handle->readers > 0); - handle->readers--; - if (handle->readers == 0) { - VerifyZeroReaders(handle); - buffer_pool.AddToEvictionQueue(handle); - } -} - -void StandardBufferManager::SetLimit(idx_t limit) { - buffer_pool.SetLimit(limit, InMemoryWarning()); -} - -//===--------------------------------------------------------------------===// -// Temporary File Management -//===--------------------------------------------------------------------===// -unique_ptr ReadTemporaryBufferInternal(BufferManager &buffer_manager, FileHandle &handle, idx_t position, - idx_t size, block_id_t id, unique_ptr reusable_buffer) { - auto buffer = buffer_manager.ConstructManagedBuffer(size, std::move(reusable_buffer)); - buffer->Read(handle, position); - return buffer; -} - -struct TemporaryFileIndex { - explicit TemporaryFileIndex(idx_t file_index = 
DConstants::INVALID_INDEX, - idx_t block_index = DConstants::INVALID_INDEX) - : file_index(file_index), block_index(block_index) { - } - - idx_t file_index; - idx_t block_index; - -public: - bool IsValid() { - return block_index != DConstants::INVALID_INDEX; - } -}; - -struct BlockIndexManager { - BlockIndexManager() : max_index(0) { - } - -public: - //! Obtains a new block index from the index manager - idx_t GetNewBlockIndex() { - auto index = GetNewBlockIndexInternal(); - indexes_in_use.insert(index); - return index; - } - - //! Removes an index from the block manager - //! Returns true if the max_index has been altered - bool RemoveIndex(idx_t index) { - // remove this block from the set of blocks - auto entry = indexes_in_use.find(index); - if (entry == indexes_in_use.end()) { - throw InternalException("RemoveIndex - index %llu not found in indexes_in_use", index); - } - indexes_in_use.erase(entry); - free_indexes.insert(index); - // check if we can truncate the file - - // get the max_index in use right now - auto max_index_in_use = indexes_in_use.empty() ? 0 : *indexes_in_use.rbegin(); - if (max_index_in_use < max_index) { - // max index in use is lower than the max_index - // reduce the max_index - max_index = indexes_in_use.empty() ? 
0 : max_index_in_use + 1; - // we can remove any free_indexes that are larger than the current max_index - while (!free_indexes.empty()) { - auto max_entry = *free_indexes.rbegin(); - if (max_entry < max_index) { - break; - } - free_indexes.erase(max_entry); - } - return true; - } - return false; - } - - idx_t GetMaxIndex() { - return max_index; - } - - bool HasFreeBlocks() { - return !free_indexes.empty(); - } - -private: - idx_t GetNewBlockIndexInternal() { - if (free_indexes.empty()) { - return max_index++; - } - auto entry = free_indexes.begin(); - auto index = *entry; - free_indexes.erase(entry); - return index; - } - - idx_t max_index; - set free_indexes; - set indexes_in_use; -}; - -class TemporaryFileHandle { - constexpr static idx_t MAX_ALLOWED_INDEX_BASE = 4000; - -public: - TemporaryFileHandle(idx_t temp_file_count, DatabaseInstance &db, const string &temp_directory, idx_t index) - : max_allowed_index((1 << temp_file_count) * MAX_ALLOWED_INDEX_BASE), db(db), file_index(index), - path(FileSystem::GetFileSystem(db).JoinPath(temp_directory, - "duckdb_temp_storage-" + to_string(index) + ".tmp")) { - } - -public: - struct TemporaryFileLock { - explicit TemporaryFileLock(mutex &mutex) : lock(mutex) { - } - - lock_guard lock; - }; - -public: - TemporaryFileIndex TryGetBlockIndex() { - TemporaryFileLock lock(file_lock); - if (index_manager.GetMaxIndex() >= max_allowed_index && index_manager.HasFreeBlocks()) { - // file is at capacity - return TemporaryFileIndex(); - } - // open the file handle if it does not yet exist - CreateFileIfNotExists(lock); - // fetch a new block index to write to - auto block_index = index_manager.GetNewBlockIndex(); - return TemporaryFileIndex(file_index, block_index); - } - - void WriteTemporaryFile(FileBuffer &buffer, TemporaryFileIndex index) { - D_ASSERT(buffer.size == Storage::BLOCK_SIZE); - buffer.Write(*handle, GetPositionInFile(index.block_index)); - } - - unique_ptr ReadTemporaryBuffer(block_id_t id, idx_t block_index, - 
unique_ptr reusable_buffer) { - return ReadTemporaryBufferInternal(BufferManager::GetBufferManager(db), *handle, GetPositionInFile(block_index), - Storage::BLOCK_SIZE, id, std::move(reusable_buffer)); - } - - void EraseBlockIndex(block_id_t block_index) { - // remove the block (and potentially truncate the temp file) - TemporaryFileLock lock(file_lock); - D_ASSERT(handle); - RemoveTempBlockIndex(lock, block_index); - } - - bool DeleteIfEmpty() { - TemporaryFileLock lock(file_lock); - if (index_manager.GetMaxIndex() > 0) { - // there are still blocks in this file - return false; - } - // the file is empty: delete it - handle.reset(); - auto &fs = FileSystem::GetFileSystem(db); - fs.RemoveFile(path); - return true; - } - - TemporaryFileInformation GetTemporaryFile() { - TemporaryFileLock lock(file_lock); - TemporaryFileInformation info; - info.path = path; - info.size = GetPositionInFile(index_manager.GetMaxIndex()); - return info; - } - -private: - void CreateFileIfNotExists(TemporaryFileLock &) { - if (handle) { - return; - } - auto &fs = FileSystem::GetFileSystem(db); - handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_WRITE | - FileFlags::FILE_FLAGS_FILE_CREATE); - } - - void RemoveTempBlockIndex(TemporaryFileLock &, idx_t index) { - // remove the block index from the index manager - if (index_manager.RemoveIndex(index)) { - // the max_index that is currently in use has decreased - // as a result we can truncate the file -#ifndef WIN32 // this ended up causing issues when sorting - auto max_index = index_manager.GetMaxIndex(); - auto &fs = FileSystem::GetFileSystem(db); - fs.Truncate(*handle, GetPositionInFile(max_index + 1)); -#endif - } - } - - idx_t GetPositionInFile(idx_t index) { - return index * Storage::BLOCK_ALLOC_SIZE; - } - -private: - const idx_t max_allowed_index; - DatabaseInstance &db; - unique_ptr handle; - idx_t file_index; - string path; - mutex file_lock; - BlockIndexManager index_manager; -}; - -class 
TemporaryFileManager { -public: - TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p) - : db(db), temp_directory(temp_directory_p) { - } - -public: - struct TemporaryManagerLock { - explicit TemporaryManagerLock(mutex &mutex) : lock(mutex) { - } - - lock_guard lock; - }; - - void WriteTemporaryBuffer(block_id_t block_id, FileBuffer &buffer) { - D_ASSERT(buffer.size == Storage::BLOCK_SIZE); - TemporaryFileIndex index; - TemporaryFileHandle *handle = nullptr; - - { - TemporaryManagerLock lock(manager_lock); - // first check if we can write to an open existing file - for (auto &entry : files) { - auto &temp_file = entry.second; - index = temp_file->TryGetBlockIndex(); - if (index.IsValid()) { - handle = entry.second.get(); - break; - } - } - if (!handle) { - // no existing handle to write to; we need to create & open a new file - auto new_file_index = index_manager.GetNewBlockIndex(); - auto new_file = make_uniq(files.size(), db, temp_directory, new_file_index); - handle = new_file.get(); - files[new_file_index] = std::move(new_file); - - index = handle->TryGetBlockIndex(); - } - D_ASSERT(used_blocks.find(block_id) == used_blocks.end()); - used_blocks[block_id] = index; - } - D_ASSERT(handle); - D_ASSERT(index.IsValid()); - handle->WriteTemporaryFile(buffer, index); - } - - bool HasTemporaryBuffer(block_id_t block_id) { - lock_guard lock(manager_lock); - return used_blocks.find(block_id) != used_blocks.end(); - } - - unique_ptr ReadTemporaryBuffer(block_id_t id, unique_ptr reusable_buffer) { - TemporaryFileIndex index; - TemporaryFileHandle *handle; - { - TemporaryManagerLock lock(manager_lock); - index = GetTempBlockIndex(lock, id); - handle = GetFileHandle(lock, index.file_index); - } - auto buffer = handle->ReadTemporaryBuffer(id, index.block_index, std::move(reusable_buffer)); - { - // remove the block (and potentially erase the temp file) - TemporaryManagerLock lock(manager_lock); - EraseUsedBlock(lock, id, handle, index); - } - return 
buffer; - } - - void DeleteTemporaryBuffer(block_id_t id) { - TemporaryManagerLock lock(manager_lock); - auto index = GetTempBlockIndex(lock, id); - auto handle = GetFileHandle(lock, index.file_index); - EraseUsedBlock(lock, id, handle, index); - } - - vector GetTemporaryFiles() { - lock_guard lock(manager_lock); - vector result; - for (auto &file : files) { - result.push_back(file.second->GetTemporaryFile()); - } - return result; - } - -private: - void EraseUsedBlock(TemporaryManagerLock &lock, block_id_t id, TemporaryFileHandle *handle, - TemporaryFileIndex index) { - auto entry = used_blocks.find(id); - if (entry == used_blocks.end()) { - throw InternalException("EraseUsedBlock - Block %llu not found in used blocks", id); - } - used_blocks.erase(entry); - handle->EraseBlockIndex(index.block_index); - if (handle->DeleteIfEmpty()) { - EraseFileHandle(lock, index.file_index); - } - } - - TemporaryFileHandle *GetFileHandle(TemporaryManagerLock &, idx_t index) { - return files[index].get(); - } - - TemporaryFileIndex GetTempBlockIndex(TemporaryManagerLock &, block_id_t id) { - D_ASSERT(used_blocks.find(id) != used_blocks.end()); - return used_blocks[id]; - } - - void EraseFileHandle(TemporaryManagerLock &, idx_t file_index) { - files.erase(file_index); - index_manager.RemoveIndex(file_index); - } - -private: - DatabaseInstance &db; - mutex manager_lock; - //! The temporary directory - string temp_directory; - //! The set of active temporary file handles - unordered_map> files; - //! map of block_id -> temporary file position - unordered_map used_blocks; - //! 
Manager of in-use temporary file indexes - BlockIndexManager index_manager; -}; - -TemporaryDirectoryHandle::TemporaryDirectoryHandle(DatabaseInstance &db, string path_p) - : db(db), temp_directory(std::move(path_p)), temp_file(make_uniq(db, temp_directory)) { - auto &fs = FileSystem::GetFileSystem(db); - if (!temp_directory.empty()) { - if (!fs.DirectoryExists(temp_directory)) { - fs.CreateDirectory(temp_directory); - created_directory = true; - } - } -} -TemporaryDirectoryHandle::~TemporaryDirectoryHandle() { - // first release any temporary files - temp_file.reset(); - // then delete the temporary file directory - auto &fs = FileSystem::GetFileSystem(db); - if (!temp_directory.empty()) { - bool delete_directory = created_directory; - vector files_to_delete; - if (!created_directory) { - bool deleted_everything = true; - fs.ListFiles(temp_directory, [&](const string &path, bool isdir) { - if (isdir) { - deleted_everything = false; - return; - } - if (!StringUtil::StartsWith(path, "duckdb_temp_")) { - deleted_everything = false; - return; - } - files_to_delete.push_back(path); - }); - } - if (delete_directory) { - // we want to remove all files in the directory - fs.RemoveDirectory(temp_directory); - } else { - for (auto &file : files_to_delete) { - fs.RemoveFile(fs.JoinPath(temp_directory, file)); - } - } - } -} - -TemporaryFileManager &TemporaryDirectoryHandle::GetTempFile() { - return *temp_file; -} - -string StandardBufferManager::GetTemporaryPath(block_id_t id) { - auto &fs = FileSystem::GetFileSystem(db); - return fs.JoinPath(temp_directory, "duckdb_temp_block-" + to_string(id) + ".block"); -} - -void StandardBufferManager::RequireTemporaryDirectory() { - if (temp_directory.empty()) { - throw Exception( - "Out-of-memory: cannot write buffer because no temporary directory is specified!\nTo enable " - "temporary buffer eviction set a temporary directory using PRAGMA temp_directory='/path/to/tmp.tmp'"); - } - lock_guard temp_handle_guard(temp_handle_lock); - if 
(!temp_directory_handle) { - // temp directory has not been created yet: initialize it - temp_directory_handle = make_uniq(db, temp_directory); - } -} - -void StandardBufferManager::WriteTemporaryBuffer(block_id_t block_id, FileBuffer &buffer) { - RequireTemporaryDirectory(); - if (buffer.size == Storage::BLOCK_SIZE) { - temp_directory_handle->GetTempFile().WriteTemporaryBuffer(block_id, buffer); - return; - } - // get the path to write to - auto path = GetTemporaryPath(block_id); - D_ASSERT(buffer.size > Storage::BLOCK_SIZE); - // create the file and write the size followed by the buffer contents - auto &fs = FileSystem::GetFileSystem(db); - auto handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE); - handle->Write(&buffer.size, sizeof(idx_t), 0); - buffer.Write(*handle, sizeof(idx_t)); -} - -unique_ptr StandardBufferManager::ReadTemporaryBuffer(block_id_t id, - unique_ptr reusable_buffer) { - D_ASSERT(!temp_directory.empty()); - D_ASSERT(temp_directory_handle.get()); - if (temp_directory_handle->GetTempFile().HasTemporaryBuffer(id)) { - return temp_directory_handle->GetTempFile().ReadTemporaryBuffer(id, std::move(reusable_buffer)); - } - idx_t block_size; - // open the temporary file and read the size - auto path = GetTemporaryPath(id); - auto &fs = FileSystem::GetFileSystem(db); - auto handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ); - handle->Read(&block_size, sizeof(idx_t), 0); - - // now allocate a buffer of this size and read the data into that buffer - auto buffer = - ReadTemporaryBufferInternal(*this, *handle, sizeof(idx_t), block_size, id, std::move(reusable_buffer)); - - handle.reset(); - DeleteTemporaryFile(id); - return buffer; -} - -void StandardBufferManager::DeleteTemporaryFile(block_id_t id) { - if (temp_directory.empty()) { - // no temporary directory specified: nothing to delete - return; - } - { - lock_guard temp_handle_guard(temp_handle_lock); - if (!temp_directory_handle) { - // temporary directory 
was not initialized yet: nothing to delete - return; - } - } - // check if we should delete the file from the shared pool of files, or from the general file system - if (temp_directory_handle->GetTempFile().HasTemporaryBuffer(id)) { - temp_directory_handle->GetTempFile().DeleteTemporaryBuffer(id); - return; - } - auto &fs = FileSystem::GetFileSystem(db); - auto path = GetTemporaryPath(id); - if (fs.FileExists(path)) { - fs.RemoveFile(path); - } -} - -bool StandardBufferManager::HasTemporaryDirectory() const { - return !temp_directory.empty(); -} - -vector StandardBufferManager::GetTemporaryFiles() { - vector result; - if (temp_directory.empty()) { - return result; - } - { - lock_guard temp_handle_guard(temp_handle_lock); - if (temp_directory_handle) { - result = temp_directory_handle->GetTempFile().GetTemporaryFiles(); - } - } - auto &fs = FileSystem::GetFileSystem(db); - fs.ListFiles(temp_directory, [&](const string &name, bool is_dir) { - if (is_dir) { - return; - } - if (!StringUtil::EndsWith(name, ".block")) { - return; - } - TemporaryFileInformation info; - info.path = name; - auto handle = fs.OpenFile(name, FileFlags::FILE_FLAGS_READ); - info.size = fs.GetFileSize(*handle); - handle.reset(); - result.push_back(info); - }); - return result; -} - -const char *StandardBufferManager::InMemoryWarning() { - if (!temp_directory.empty()) { - return ""; - } - return "\nDatabase is launched in in-memory mode and no temporary directory is specified." - "\nUnused blocks cannot be offloaded to disk." 
- "\n\nLaunch the database with a persistent storage back-end" - "\nOr set PRAGMA temp_directory='/path/to/tmp.tmp'"; -} - -void StandardBufferManager::ReserveMemory(idx_t size) { - if (size == 0) { - return; - } - auto reservation = EvictBlocksOrThrow(size, nullptr, "failed to reserve memory data of size %s%s", - StringUtil::BytesToHumanReadableString(size)); - reservation.size = 0; -} - -void StandardBufferManager::FreeReservedMemory(idx_t size) { - if (size == 0) { - return; - } - buffer_pool.current_memory -= size; -} - -//===--------------------------------------------------------------------===// -// Buffer Allocator -//===--------------------------------------------------------------------===// -data_ptr_t StandardBufferManager::BufferAllocatorAllocate(PrivateAllocatorData *private_data, idx_t size) { - auto &data = private_data->Cast(); - auto reservation = data.manager.EvictBlocksOrThrow(size, nullptr, "failed to allocate data of size %s%s", - StringUtil::BytesToHumanReadableString(size)); - // We rely on manual tracking of this one. 
:( - reservation.size = 0; - return Allocator::Get(data.manager.db).AllocateData(size); -} - -void StandardBufferManager::BufferAllocatorFree(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t size) { - auto &data = private_data->Cast(); - BufferPoolReservation r(data.manager.GetBufferPool()); - r.size = size; - r.Resize(0); - return Allocator::Get(data.manager.db).FreeData(pointer, size); -} - -data_ptr_t StandardBufferManager::BufferAllocatorRealloc(PrivateAllocatorData *private_data, data_ptr_t pointer, - idx_t old_size, idx_t size) { - if (old_size == size) { - return pointer; - } - auto &data = private_data->Cast(); - BufferPoolReservation r(data.manager.GetBufferPool()); - r.size = old_size; - r.Resize(size); - r.size = 0; - return Allocator::Get(data.manager.db).ReallocateData(pointer, old_size, size); -} - -Allocator &BufferAllocator::Get(ClientContext &context) { - auto &manager = StandardBufferManager::GetBufferManager(context); - return manager.GetBufferAllocator(); -} - -Allocator &BufferAllocator::Get(DatabaseInstance &db) { - return StandardBufferManager::GetBufferManager(db).GetBufferAllocator(); -} - -Allocator &BufferAllocator::Get(AttachedDatabase &db) { - return BufferAllocator::Get(db.GetDatabase()); -} - -Allocator &StandardBufferManager::GetBufferAllocator() { - return buffer_allocator; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -BaseStatistics::BaseStatistics() : type(LogicalType::INVALID) { -} - -BaseStatistics::BaseStatistics(LogicalType type) { - Construct(*this, std::move(type)); -} - -void BaseStatistics::Construct(BaseStatistics &stats, LogicalType type) { - stats.distinct_count = 0; - stats.type = std::move(type); - switch (GetStatsType(stats.type)) { - case StatisticsType::LIST_STATS: - ListStats::Construct(stats); - break; - case StatisticsType::STRUCT_STATS: - StructStats::Construct(stats); - break; - default: - break; - } -} - -BaseStatistics::~BaseStatistics() { -} - 
-BaseStatistics::BaseStatistics(BaseStatistics &&other) noexcept { - std::swap(type, other.type); - has_null = other.has_null; - has_no_null = other.has_no_null; - distinct_count = other.distinct_count; - stats_union = other.stats_union; - std::swap(child_stats, other.child_stats); -} - -BaseStatistics &BaseStatistics::operator=(BaseStatistics &&other) noexcept { - std::swap(type, other.type); - has_null = other.has_null; - has_no_null = other.has_no_null; - distinct_count = other.distinct_count; - stats_union = other.stats_union; - std::swap(child_stats, other.child_stats); - return *this; -} - -StatisticsType BaseStatistics::GetStatsType(const LogicalType &type) { - if (type.id() == LogicalTypeId::SQLNULL) { - return StatisticsType::BASE_STATS; - } - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - case PhysicalType::UINT8: - case PhysicalType::UINT16: - case PhysicalType::UINT32: - case PhysicalType::UINT64: - case PhysicalType::INT128: - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - return StatisticsType::NUMERIC_STATS; - case PhysicalType::VARCHAR: - return StatisticsType::STRING_STATS; - case PhysicalType::STRUCT: - return StatisticsType::STRUCT_STATS; - case PhysicalType::LIST: - return StatisticsType::LIST_STATS; - case PhysicalType::BIT: - case PhysicalType::INTERVAL: - default: - return StatisticsType::BASE_STATS; - } -} - -StatisticsType BaseStatistics::GetStatsType() const { - return GetStatsType(GetType()); -} - -void BaseStatistics::InitializeUnknown() { - has_null = true; - has_no_null = true; -} - -void BaseStatistics::InitializeEmpty() { - has_null = false; - has_no_null = true; -} - -bool BaseStatistics::CanHaveNull() const { - return has_null; -} - -bool BaseStatistics::CanHaveNoNull() const { - return has_no_null; -} - -bool BaseStatistics::IsConstant() const { - if (type.id() == LogicalTypeId::VALIDITY) { - // 
validity mask - if (CanHaveNull() && !CanHaveNoNull()) { - return true; - } - if (!CanHaveNull() && CanHaveNoNull()) { - return true; - } - return false; - } - switch (GetStatsType()) { - case StatisticsType::NUMERIC_STATS: - return NumericStats::IsConstant(*this); - default: - break; - } - return false; -} - -void BaseStatistics::Merge(const BaseStatistics &other) { - has_null = has_null || other.has_null; - has_no_null = has_no_null || other.has_no_null; - switch (GetStatsType()) { - case StatisticsType::NUMERIC_STATS: - NumericStats::Merge(*this, other); - break; - case StatisticsType::STRING_STATS: - StringStats::Merge(*this, other); - break; - case StatisticsType::LIST_STATS: - ListStats::Merge(*this, other); - break; - case StatisticsType::STRUCT_STATS: - StructStats::Merge(*this, other); - break; - default: - break; - } -} - -idx_t BaseStatistics::GetDistinctCount() { - return distinct_count; -} - -BaseStatistics BaseStatistics::CreateUnknownType(LogicalType type) { - switch (GetStatsType(type)) { - case StatisticsType::NUMERIC_STATS: - return NumericStats::CreateUnknown(std::move(type)); - case StatisticsType::STRING_STATS: - return StringStats::CreateUnknown(std::move(type)); - case StatisticsType::LIST_STATS: - return ListStats::CreateUnknown(std::move(type)); - case StatisticsType::STRUCT_STATS: - return StructStats::CreateUnknown(std::move(type)); - default: - return BaseStatistics(std::move(type)); - } -} - -BaseStatistics BaseStatistics::CreateEmptyType(LogicalType type) { - switch (GetStatsType(type)) { - case StatisticsType::NUMERIC_STATS: - return NumericStats::CreateEmpty(std::move(type)); - case StatisticsType::STRING_STATS: - return StringStats::CreateEmpty(std::move(type)); - case StatisticsType::LIST_STATS: - return ListStats::CreateEmpty(std::move(type)); - case StatisticsType::STRUCT_STATS: - return StructStats::CreateEmpty(std::move(type)); - default: - return BaseStatistics(std::move(type)); - } -} - -BaseStatistics 
BaseStatistics::CreateUnknown(LogicalType type) { - auto result = CreateUnknownType(std::move(type)); - result.InitializeUnknown(); - return result; -} - -BaseStatistics BaseStatistics::CreateEmpty(LogicalType type) { - if (type.InternalType() == PhysicalType::BIT) { - // FIXME: this special case should not be necessary - // but currently InitializeEmpty sets StatsInfo::CAN_HAVE_VALID_VALUES - BaseStatistics result(std::move(type)); - result.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); - result.Set(StatsInfo::CANNOT_HAVE_VALID_VALUES); - return result; - } - auto result = CreateEmptyType(std::move(type)); - result.InitializeEmpty(); - return result; -} - -void BaseStatistics::Copy(const BaseStatistics &other) { - D_ASSERT(GetType() == other.GetType()); - CopyBase(other); - stats_union = other.stats_union; - switch (GetStatsType()) { - case StatisticsType::LIST_STATS: - ListStats::Copy(*this, other); - break; - case StatisticsType::STRUCT_STATS: - StructStats::Copy(*this, other); - break; - default: - break; - } -} - -BaseStatistics BaseStatistics::Copy() const { - BaseStatistics result(type); - result.Copy(*this); - return result; -} - -unique_ptr BaseStatistics::ToUnique() const { - auto result = unique_ptr(new BaseStatistics(type)); - result->Copy(*this); - return result; -} - -void BaseStatistics::CopyBase(const BaseStatistics &other) { - has_null = other.has_null; - has_no_null = other.has_no_null; - distinct_count = other.distinct_count; -} - -void BaseStatistics::Set(StatsInfo info) { - switch (info) { - case StatsInfo::CAN_HAVE_NULL_VALUES: - has_null = true; - break; - case StatsInfo::CANNOT_HAVE_NULL_VALUES: - has_null = false; - break; - case StatsInfo::CAN_HAVE_VALID_VALUES: - has_no_null = true; - break; - case StatsInfo::CANNOT_HAVE_VALID_VALUES: - has_no_null = false; - break; - case StatsInfo::CAN_HAVE_NULL_AND_VALID_VALUES: - has_null = true; - has_no_null = true; - break; - default: - throw InternalException("Unrecognized StatsInfo for 
BaseStatistics::Set"); - } -} - -void BaseStatistics::CombineValidity(BaseStatistics &left, BaseStatistics &right) { - has_null = left.has_null || right.has_null; - has_no_null = left.has_no_null || right.has_no_null; -} - -void BaseStatistics::CopyValidity(BaseStatistics &stats) { - has_null = stats.has_null; - has_no_null = stats.has_no_null; -} - -void BaseStatistics::SetDistinctCount(idx_t count) { - this->distinct_count = count; -} - -void BaseStatistics::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "has_null", has_null); - serializer.WriteProperty(101, "has_no_null", has_no_null); - serializer.WriteProperty(102, "distinct_count", distinct_count); - serializer.WriteObject(103, "type_stats", [&](Serializer &serializer) { - switch (GetStatsType()) { - case StatisticsType::NUMERIC_STATS: - NumericStats::Serialize(*this, serializer); - break; - case StatisticsType::STRING_STATS: - StringStats::Serialize(*this, serializer); - break; - case StatisticsType::LIST_STATS: - ListStats::Serialize(*this, serializer); - break; - case StatisticsType::STRUCT_STATS: - StructStats::Serialize(*this, serializer); - break; - default: - break; - } - }); -} - -BaseStatistics BaseStatistics::Deserialize(Deserializer &deserializer) { - auto has_null = deserializer.ReadProperty(100, "has_null"); - auto has_no_null = deserializer.ReadProperty(101, "has_no_null"); - auto distinct_count = deserializer.ReadProperty(102, "distinct_count"); - - // Get the logical type from the deserializer context. 
- auto type = deserializer.Get(); - - auto stats_type = GetStatsType(type); - - BaseStatistics stats(std::move(type)); - - stats.has_null = has_null; - stats.has_no_null = has_no_null; - stats.distinct_count = distinct_count; - - deserializer.ReadObject(103, "type_stats", [&](Deserializer &obj) { - switch (stats_type) { - case StatisticsType::NUMERIC_STATS: - NumericStats::Deserialize(obj, stats); - break; - case StatisticsType::STRING_STATS: - StringStats::Deserialize(obj, stats); - break; - case StatisticsType::LIST_STATS: - ListStats::Deserialize(obj, stats); - break; - case StatisticsType::STRUCT_STATS: - StructStats::Deserialize(obj, stats); - break; - default: - break; - } - }); - - return stats; -} - -string BaseStatistics::ToString() const { - auto has_n = has_null ? "true" : "false"; - auto has_n_n = has_no_null ? "true" : "false"; - string result = - StringUtil::Format("%s%s", StringUtil::Format("[Has Null: %s, Has No Null: %s]", has_n, has_n_n), - distinct_count > 0 ? StringUtil::Format("[Approx Unique: %lld]", distinct_count) : ""); - switch (GetStatsType()) { - case StatisticsType::NUMERIC_STATS: - result = NumericStats::ToString(*this) + result; - break; - case StatisticsType::STRING_STATS: - result = StringStats::ToString(*this) + result; - break; - case StatisticsType::LIST_STATS: - result = ListStats::ToString(*this) + result; - break; - case StatisticsType::STRUCT_STATS: - result = StructStats::ToString(*this) + result; - break; - default: - break; - } - return result; -} - -void BaseStatistics::Verify(Vector &vector, const SelectionVector &sel, idx_t count) const { - D_ASSERT(vector.GetType() == this->type); - switch (GetStatsType()) { - case StatisticsType::NUMERIC_STATS: - NumericStats::Verify(*this, vector, sel, count); - break; - case StatisticsType::STRING_STATS: - StringStats::Verify(*this, vector, sel, count); - break; - case StatisticsType::LIST_STATS: - ListStats::Verify(*this, vector, sel, count); - break; - case 
StatisticsType::STRUCT_STATS: - StructStats::Verify(*this, vector, sel, count); - break; - default: - break; - } - if (has_null && has_no_null) { - // nothing to verify - return; - } - UnifiedVectorFormat vdata; - vector.ToUnifiedFormat(count, vdata); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto index = vdata.sel->get_index(idx); - bool row_is_valid = vdata.validity.RowIsValid(index); - if (row_is_valid && !has_no_null) { - throw InternalException( - "Statistics mismatch: vector labeled as having only NULL values, but vector contains valid values: %s", - vector.ToString(count)); - } - if (!row_is_valid && !has_null) { - throw InternalException( - "Statistics mismatch: vector labeled as not having NULL values, but vector contains null values: %s", - vector.ToString(count)); - } - } -} - -void BaseStatistics::Verify(Vector &vector, idx_t count) const { - auto sel = FlatVector::IncrementalSelectionVector(); - Verify(vector, *sel, count); -} - -BaseStatistics BaseStatistics::FromConstantType(const Value &input) { - switch (GetStatsType(input.type())) { - case StatisticsType::NUMERIC_STATS: { - auto result = NumericStats::CreateEmpty(input.type()); - NumericStats::SetMin(result, input); - NumericStats::SetMax(result, input); - return result; - } - case StatisticsType::STRING_STATS: { - auto result = StringStats::CreateEmpty(input.type()); - if (!input.IsNull()) { - auto &string_value = StringValue::Get(input); - StringStats::Update(result, string_t(string_value)); - } - return result; - } - case StatisticsType::LIST_STATS: { - auto result = ListStats::CreateEmpty(input.type()); - auto &child_stats = ListStats::GetChildStats(result); - if (!input.IsNull()) { - auto &list_children = ListValue::GetChildren(input); - for (auto &child_element : list_children) { - child_stats.Merge(FromConstant(child_element)); - } - } - return result; - } - case StatisticsType::STRUCT_STATS: { - auto result = StructStats::CreateEmpty(input.type()); - auto 
&child_types = StructType::GetChildTypes(input.type()); - if (input.IsNull()) { - for (idx_t i = 0; i < child_types.size(); i++) { - StructStats::SetChildStats(result, i, FromConstant(Value(child_types[i].second))); - } - } else { - auto &struct_children = StructValue::GetChildren(input); - for (idx_t i = 0; i < child_types.size(); i++) { - StructStats::SetChildStats(result, i, FromConstant(struct_children[i])); - } - } - return result; - } - default: - return BaseStatistics(input.type()); - } -} - -BaseStatistics BaseStatistics::FromConstant(const Value &input) { - auto result = FromConstantType(input); - result.SetDistinctCount(1); - if (input.IsNull()) { - result.Set(StatsInfo::CAN_HAVE_NULL_VALUES); - result.Set(StatsInfo::CANNOT_HAVE_VALID_VALUES); - } else { - result.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); - result.Set(StatsInfo::CAN_HAVE_VALID_VALUES); - } - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -ColumnStatistics::ColumnStatistics(BaseStatistics stats_p) : stats(std::move(stats_p)) { - if (DistinctStatistics::TypeIsSupported(stats.GetType())) { - distinct_stats = make_uniq(); - } -} -ColumnStatistics::ColumnStatistics(BaseStatistics stats_p, unique_ptr distinct_stats_p) - : stats(std::move(stats_p)), distinct_stats(std::move(distinct_stats_p)) { -} - -shared_ptr ColumnStatistics::CreateEmptyStats(const LogicalType &type) { - return make_shared(BaseStatistics::CreateEmpty(type)); -} - -void ColumnStatistics::Merge(ColumnStatistics &other) { - stats.Merge(other.stats); - if (distinct_stats) { - distinct_stats->Merge(*other.distinct_stats); - } -} - -BaseStatistics &ColumnStatistics::Statistics() { - return stats; -} - -bool ColumnStatistics::HasDistinctStats() { - return distinct_stats.get(); -} - -DistinctStatistics &ColumnStatistics::DistinctStats() { - if (!distinct_stats) { - throw InternalException("DistinctStats called without distinct_stats"); - } - return *distinct_stats; -} - -void 
ColumnStatistics::SetDistinct(unique_ptr distinct) { - this->distinct_stats = std::move(distinct); -} - -void ColumnStatistics::UpdateDistinctStatistics(Vector &v, idx_t count) { - if (!distinct_stats) { - return; - } - auto &d_stats = (DistinctStatistics &)*distinct_stats; - d_stats.Update(v, count); -} - -shared_ptr ColumnStatistics::Copy() const { - return make_shared(stats.Copy(), distinct_stats ? distinct_stats->Copy() : nullptr); -} - -void ColumnStatistics::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "statistics", stats); - serializer.WritePropertyWithDefault(101, "distinct", distinct_stats, unique_ptr()); -} - -shared_ptr ColumnStatistics::Deserialize(Deserializer &deserializer) { - auto stats = deserializer.ReadProperty(100, "statistics"); - auto distinct_stats = deserializer.ReadPropertyWithDefault>( - 101, "distinct", unique_ptr()); - return make_shared(std::move(stats), std::move(distinct_stats)); -} - -} // namespace duckdb - - - - -#include - -namespace duckdb { - -DistinctStatistics::DistinctStatistics() : log(make_uniq()), sample_count(0), total_count(0) { -} - -DistinctStatistics::DistinctStatistics(unique_ptr log, idx_t sample_count, idx_t total_count) - : log(std::move(log)), sample_count(sample_count), total_count(total_count) { -} - -unique_ptr DistinctStatistics::Copy() const { - return make_uniq(log->Copy(), sample_count, total_count); -} - -void DistinctStatistics::Merge(const DistinctStatistics &other) { - log = log->Merge(*other.log); - sample_count += other.sample_count; - total_count += other.total_count; -} - -void DistinctStatistics::Update(Vector &v, idx_t count, bool sample) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(count, vdata); - Update(vdata, v.GetType(), count, sample); -} - -void DistinctStatistics::Update(UnifiedVectorFormat &vdata, const LogicalType &type, idx_t count, bool sample) { - if (count == 0) { - return; - } - - total_count += count; - if (sample) { - count = 
MinValue(idx_t(SAMPLE_RATE * MaxValue(STANDARD_VECTOR_SIZE, count)), count); - } - sample_count += count; - - uint64_t indices[STANDARD_VECTOR_SIZE]; - uint8_t counts[STANDARD_VECTOR_SIZE]; - - HyperLogLog::ProcessEntries(vdata, type, indices, counts, count); - log->AddToLog(vdata, count, indices, counts); -} - -string DistinctStatistics::ToString() const { - return StringUtil::Format("[Approx Unique: %s]", to_string(GetCount())); -} - -idx_t DistinctStatistics::GetCount() const { - if (sample_count == 0 || total_count == 0) { - return 0; - } - - double u = MinValue(log->Count(), sample_count); - double s = sample_count; - double n = total_count; - - // Assume this proportion of the the sampled values occurred only once - double u1 = pow(u / s, 2) * u; - - // Estimate total uniques using Good Turing Estimation - idx_t estimate = u + u1 / s * (n - s); - return MinValue(estimate, total_count); -} - -bool DistinctStatistics::TypeIsSupported(const LogicalType &type) { - return type.InternalType() != PhysicalType::LIST && type.InternalType() != PhysicalType::STRUCT; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -void ListStats::Construct(BaseStatistics &stats) { - stats.child_stats = unsafe_unique_array(new BaseStatistics[1]); - BaseStatistics::Construct(stats.child_stats[0], ListType::GetChildType(stats.GetType())); -} - -BaseStatistics ListStats::CreateUnknown(LogicalType type) { - auto &child_type = ListType::GetChildType(type); - BaseStatistics result(std::move(type)); - result.InitializeUnknown(); - result.child_stats[0].Copy(BaseStatistics::CreateUnknown(child_type)); - return result; -} - -BaseStatistics ListStats::CreateEmpty(LogicalType type) { - auto &child_type = ListType::GetChildType(type); - BaseStatistics result(std::move(type)); - result.InitializeEmpty(); - result.child_stats[0].Copy(BaseStatistics::CreateEmpty(child_type)); - return result; -} - -void ListStats::Copy(BaseStatistics &stats, const BaseStatistics &other) { - 
D_ASSERT(stats.child_stats); - D_ASSERT(other.child_stats); - stats.child_stats[0].Copy(other.child_stats[0]); -} - -const BaseStatistics &ListStats::GetChildStats(const BaseStatistics &stats) { - if (stats.GetStatsType() != StatisticsType::LIST_STATS) { - throw InternalException("ListStats::GetChildStats called on stats that is not a list"); - } - D_ASSERT(stats.child_stats); - return stats.child_stats[0]; -} -BaseStatistics &ListStats::GetChildStats(BaseStatistics &stats) { - if (stats.GetStatsType() != StatisticsType::LIST_STATS) { - throw InternalException("ListStats::GetChildStats called on stats that is not a list"); - } - D_ASSERT(stats.child_stats); - return stats.child_stats[0]; -} - -void ListStats::SetChildStats(BaseStatistics &stats, unique_ptr new_stats) { - if (!new_stats) { - stats.child_stats[0].Copy(BaseStatistics::CreateUnknown(ListType::GetChildType(stats.GetType()))); - } else { - stats.child_stats[0].Copy(*new_stats); - } -} - -void ListStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { - if (other.GetType().id() == LogicalTypeId::VALIDITY) { - return; - } - - auto &child_stats = ListStats::GetChildStats(stats); - auto &other_child_stats = ListStats::GetChildStats(other); - child_stats.Merge(other_child_stats); -} - -void ListStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { - auto &child_stats = ListStats::GetChildStats(stats); - serializer.WriteProperty(200, "child_stats", child_stats); -} - -void ListStats::Deserialize(Deserializer &deserializer, BaseStatistics &base) { - auto &type = base.GetType(); - D_ASSERT(type.InternalType() == PhysicalType::LIST); - auto &child_type = ListType::GetChildType(type); - - // Push the logical type of the child type to the deserialization context - deserializer.Set(const_cast(child_type)); - base.child_stats[0].Copy(deserializer.ReadProperty(200, "child_stats")); - deserializer.Unset(); -} - -string ListStats::ToString(const BaseStatistics &stats) { - auto &child_stats 
= ListStats::GetChildStats(stats); - return StringUtil::Format("[%s]", child_stats.ToString()); -} - -void ListStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { - auto &child_stats = ListStats::GetChildStats(stats); - auto &child_entry = ListVector::GetEntry(vector); - UnifiedVectorFormat vdata; - vector.ToUnifiedFormat(count, vdata); - - auto list_data = UnifiedVectorFormat::GetData(vdata); - idx_t total_list_count = 0; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto index = vdata.sel->get_index(idx); - auto list = list_data[index]; - if (vdata.validity.RowIsValid(index)) { - for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { - total_list_count++; - } - } - } - SelectionVector list_sel(total_list_count); - idx_t list_count = 0; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto index = vdata.sel->get_index(idx); - auto list = list_data[index]; - if (vdata.validity.RowIsValid(index)) { - for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { - list_sel.set_index(list_count++, list.offset + list_idx); - } - } - } - - child_stats.Verify(child_entry, list_sel, list_count); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -template <> -void NumericStats::Update(BaseStatistics &stats, interval_t new_value) { -} - -template <> -void NumericStats::Update(BaseStatistics &stats, list_entry_t new_value) { -} - -//===--------------------------------------------------------------------===// -// NumericStats -//===--------------------------------------------------------------------===// -BaseStatistics NumericStats::CreateUnknown(LogicalType type) { - BaseStatistics result(std::move(type)); - result.InitializeUnknown(); - SetMin(result, Value(result.GetType())); - SetMax(result, Value(result.GetType())); - return result; -} - -BaseStatistics NumericStats::CreateEmpty(LogicalType type) { - BaseStatistics result(std::move(type)); - 
result.InitializeEmpty(); - SetMin(result, Value::MaximumValue(result.GetType())); - SetMax(result, Value::MinimumValue(result.GetType())); - return result; -} - -NumericStatsData &NumericStats::GetDataUnsafe(BaseStatistics &stats) { - D_ASSERT(stats.GetStatsType() == StatisticsType::NUMERIC_STATS); - return stats.stats_union.numeric_data; -} - -const NumericStatsData &NumericStats::GetDataUnsafe(const BaseStatistics &stats) { - D_ASSERT(stats.GetStatsType() == StatisticsType::NUMERIC_STATS); - return stats.stats_union.numeric_data; -} - -void NumericStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { - if (other.GetType().id() == LogicalTypeId::VALIDITY) { - return; - } - D_ASSERT(stats.GetType() == other.GetType()); - if (NumericStats::HasMin(other) && NumericStats::HasMin(stats)) { - auto other_min = NumericStats::Min(other); - if (other_min < NumericStats::Min(stats)) { - NumericStats::SetMin(stats, other_min); - } - } else { - NumericStats::SetMin(stats, Value()); - } - if (NumericStats::HasMax(other) && NumericStats::HasMax(stats)) { - auto other_max = NumericStats::Max(other); - if (other_max > NumericStats::Max(stats)) { - NumericStats::SetMax(stats, other_max); - } - } else { - NumericStats::SetMax(stats, Value()); - } -} - -struct GetNumericValueUnion { - template - static T Operation(const NumericValueUnion &v); -}; - -template <> -int8_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { - return v.value_.tinyint; -} - -template <> -int16_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { - return v.value_.smallint; -} - -template <> -int32_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { - return v.value_.integer; -} - -template <> -int64_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { - return v.value_.bigint; -} - -template <> -hugeint_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { - return v.value_.hugeint; -} - -template <> -uint8_t 
GetNumericValueUnion::Operation(const NumericValueUnion &v) { - return v.value_.utinyint; -} - -template <> -uint16_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { - return v.value_.usmallint; -} - -template <> -uint32_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { - return v.value_.uinteger; -} - -template <> -uint64_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { - return v.value_.ubigint; -} - -template <> -float GetNumericValueUnion::Operation(const NumericValueUnion &v) { - return v.value_.float_; -} - -template <> -double GetNumericValueUnion::Operation(const NumericValueUnion &v) { - return v.value_.double_; -} - -template -T NumericStats::GetMinUnsafe(const BaseStatistics &stats) { - return GetNumericValueUnion::Operation(NumericStats::GetDataUnsafe(stats).min); -} - -template -T NumericStats::GetMaxUnsafe(const BaseStatistics &stats) { - return GetNumericValueUnion::Operation(NumericStats::GetDataUnsafe(stats).max); -} - -template -bool ConstantExactRange(T min, T max, T constant) { - return Equals::Operation(constant, min) && Equals::Operation(constant, max); -} - -template -bool ConstantValueInRange(T min, T max, T constant) { - return !(LessThan::Operation(constant, min) || GreaterThan::Operation(constant, max)); -} - -template -FilterPropagateResult CheckZonemapTemplated(const BaseStatistics &stats, ExpressionType comparison_type, - const Value &constant_value) { - T min_value = NumericStats::GetMinUnsafe(stats); - T max_value = NumericStats::GetMaxUnsafe(stats); - T constant = constant_value.GetValueUnsafe(); - switch (comparison_type) { - case ExpressionType::COMPARE_EQUAL: - if (ConstantExactRange(min_value, max_value, constant)) { - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - if (ConstantValueInRange(min_value, max_value, constant)) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - case ExpressionType::COMPARE_NOTEQUAL: - if 
(!ConstantValueInRange(min_value, max_value, constant)) { - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } else if (ConstantExactRange(min_value, max_value, constant)) { - // corner case of a cluster with one numeric equal to the target constant - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - // GreaterThanEquals::Operation(X, C) - // this can be true only if max(X) >= C - // if min(X) >= C, then this is always true - if (GreaterThanEquals::Operation(min_value, constant)) { - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } else if (GreaterThanEquals::Operation(max_value, constant)) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } else { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - case ExpressionType::COMPARE_GREATERTHAN: - // GreaterThan::Operation(X, C) - // this can be true only if max(X) > C - // if min(X) > C, then this is always true - if (GreaterThan::Operation(min_value, constant)) { - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } else if (GreaterThan::Operation(max_value, constant)) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } else { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - // LessThanEquals::Operation(X, C) - // this can be true only if min(X) <= C - // if max(X) <= C, then this is always true - if (LessThanEquals::Operation(max_value, constant)) { - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } else if (LessThanEquals::Operation(min_value, constant)) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } else { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - case ExpressionType::COMPARE_LESSTHAN: - // LessThan::Operation(X, C) - // this can be true only if min(X) < C - // if max(X) < C, then this is always true - if (LessThan::Operation(max_value, constant)) { - return 
FilterPropagateResult::FILTER_ALWAYS_TRUE; - } else if (LessThan::Operation(min_value, constant)) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } else { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - default: - throw InternalException("Expression type in zonemap check not implemented"); - } -} - -FilterPropagateResult NumericStats::CheckZonemap(const BaseStatistics &stats, ExpressionType comparison_type, - const Value &constant) { - D_ASSERT(constant.type() == stats.GetType()); - if (constant.IsNull()) { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - if (!NumericStats::HasMinMax(stats)) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - switch (stats.GetType().InternalType()) { - case PhysicalType::INT8: - return CheckZonemapTemplated(stats, comparison_type, constant); - case PhysicalType::INT16: - return CheckZonemapTemplated(stats, comparison_type, constant); - case PhysicalType::INT32: - return CheckZonemapTemplated(stats, comparison_type, constant); - case PhysicalType::INT64: - return CheckZonemapTemplated(stats, comparison_type, constant); - case PhysicalType::UINT8: - return CheckZonemapTemplated(stats, comparison_type, constant); - case PhysicalType::UINT16: - return CheckZonemapTemplated(stats, comparison_type, constant); - case PhysicalType::UINT32: - return CheckZonemapTemplated(stats, comparison_type, constant); - case PhysicalType::UINT64: - return CheckZonemapTemplated(stats, comparison_type, constant); - case PhysicalType::INT128: - return CheckZonemapTemplated(stats, comparison_type, constant); - case PhysicalType::FLOAT: - return CheckZonemapTemplated(stats, comparison_type, constant); - case PhysicalType::DOUBLE: - return CheckZonemapTemplated(stats, comparison_type, constant); - default: - throw InternalException("Unsupported type for NumericStats::CheckZonemap"); - } -} - -bool NumericStats::IsConstant(const BaseStatistics &stats) { - return NumericStats::Max(stats) <= NumericStats::Min(stats); -} - 
-void SetNumericValueInternal(const Value &input, const LogicalType &type, NumericValueUnion &val, bool &has_val) { - if (input.IsNull()) { - has_val = false; - return; - } - if (input.type().InternalType() != type.InternalType()) { - throw InternalException("SetMin or SetMax called with Value that does not match statistics' column value"); - } - has_val = true; - switch (type.InternalType()) { - case PhysicalType::BOOL: - val.value_.boolean = BooleanValue::Get(input); - break; - case PhysicalType::INT8: - val.value_.tinyint = TinyIntValue::Get(input); - break; - case PhysicalType::INT16: - val.value_.smallint = SmallIntValue::Get(input); - break; - case PhysicalType::INT32: - val.value_.integer = IntegerValue::Get(input); - break; - case PhysicalType::INT64: - val.value_.bigint = BigIntValue::Get(input); - break; - case PhysicalType::UINT8: - val.value_.utinyint = UTinyIntValue::Get(input); - break; - case PhysicalType::UINT16: - val.value_.usmallint = USmallIntValue::Get(input); - break; - case PhysicalType::UINT32: - val.value_.uinteger = UIntegerValue::Get(input); - break; - case PhysicalType::UINT64: - val.value_.ubigint = UBigIntValue::Get(input); - break; - case PhysicalType::INT128: - val.value_.hugeint = HugeIntValue::Get(input); - break; - case PhysicalType::FLOAT: - val.value_.float_ = FloatValue::Get(input); - break; - case PhysicalType::DOUBLE: - val.value_.double_ = DoubleValue::Get(input); - break; - default: - throw InternalException("Unsupported type for NumericStatistics::SetValueInternal"); - } -} - -void NumericStats::SetMin(BaseStatistics &stats, const Value &new_min) { - auto &data = NumericStats::GetDataUnsafe(stats); - SetNumericValueInternal(new_min, stats.GetType(), data.min, data.has_min); -} - -void NumericStats::SetMax(BaseStatistics &stats, const Value &new_max) { - auto &data = NumericStats::GetDataUnsafe(stats); - SetNumericValueInternal(new_max, stats.GetType(), data.max, data.has_max); -} - -Value 
NumericValueUnionToValueInternal(const LogicalType &type, const NumericValueUnion &val) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - return Value::BOOLEAN(val.value_.boolean); - case PhysicalType::INT8: - return Value::TINYINT(val.value_.tinyint); - case PhysicalType::INT16: - return Value::SMALLINT(val.value_.smallint); - case PhysicalType::INT32: - return Value::INTEGER(val.value_.integer); - case PhysicalType::INT64: - return Value::BIGINT(val.value_.bigint); - case PhysicalType::UINT8: - return Value::UTINYINT(val.value_.utinyint); - case PhysicalType::UINT16: - return Value::USMALLINT(val.value_.usmallint); - case PhysicalType::UINT32: - return Value::UINTEGER(val.value_.uinteger); - case PhysicalType::UINT64: - return Value::UBIGINT(val.value_.ubigint); - case PhysicalType::INT128: - return Value::HUGEINT(val.value_.hugeint); - case PhysicalType::FLOAT: - return Value::FLOAT(val.value_.float_); - case PhysicalType::DOUBLE: - return Value::DOUBLE(val.value_.double_); - default: - throw InternalException("Unsupported type for NumericValueUnionToValue"); - } -} - -Value NumericValueUnionToValue(const LogicalType &type, const NumericValueUnion &val) { - Value result = NumericValueUnionToValueInternal(type, val); - result.GetTypeMutable() = type; - return result; -} - -bool NumericStats::HasMinMax(const BaseStatistics &stats) { - return NumericStats::HasMin(stats) && NumericStats::HasMax(stats); -} - -bool NumericStats::HasMin(const BaseStatistics &stats) { - if (stats.GetType().id() == LogicalTypeId::SQLNULL) { - return false; - } - return NumericStats::GetDataUnsafe(stats).has_min; -} - -bool NumericStats::HasMax(const BaseStatistics &stats) { - if (stats.GetType().id() == LogicalTypeId::SQLNULL) { - return false; - } - return NumericStats::GetDataUnsafe(stats).has_max; -} - -Value NumericStats::Min(const BaseStatistics &stats) { - if (!NumericStats::HasMin(stats)) { - throw InternalException("Min() called on statistics that does not have 
min"); - } - return NumericValueUnionToValue(stats.GetType(), NumericStats::GetDataUnsafe(stats).min); -} - -Value NumericStats::Max(const BaseStatistics &stats) { - if (!NumericStats::HasMax(stats)) { - throw InternalException("Max() called on statistics that does not have max"); - } - return NumericValueUnionToValue(stats.GetType(), NumericStats::GetDataUnsafe(stats).max); -} - -Value NumericStats::MinOrNull(const BaseStatistics &stats) { - if (!NumericStats::HasMin(stats)) { - return Value(stats.GetType()); - } - return NumericStats::Min(stats); -} - -Value NumericStats::MaxOrNull(const BaseStatistics &stats) { - if (!NumericStats::HasMax(stats)) { - return Value(stats.GetType()); - } - return NumericStats::Max(stats); -} - -static void SerializeNumericStatsValue(const LogicalType &type, NumericValueUnion val, bool has_value, - Serializer &serializer) { - serializer.WriteProperty(100, "has_value", has_value); - if (!has_value) { - return; - } - switch (type.InternalType()) { - case PhysicalType::BOOL: - serializer.WriteProperty(101, "value", val.value_.boolean); - break; - case PhysicalType::INT8: - serializer.WriteProperty(101, "value", val.value_.tinyint); - break; - case PhysicalType::INT16: - serializer.WriteProperty(101, "value", val.value_.smallint); - break; - case PhysicalType::INT32: - serializer.WriteProperty(101, "value", val.value_.integer); - break; - case PhysicalType::INT64: - serializer.WriteProperty(101, "value", val.value_.bigint); - break; - case PhysicalType::UINT8: - serializer.WriteProperty(101, "value", val.value_.utinyint); - break; - case PhysicalType::UINT16: - serializer.WriteProperty(101, "value", val.value_.usmallint); - break; - case PhysicalType::UINT32: - serializer.WriteProperty(101, "value", val.value_.uinteger); - break; - case PhysicalType::UINT64: - serializer.WriteProperty(101, "value", val.value_.ubigint); - break; - case PhysicalType::INT128: - serializer.WriteProperty(101, "value", val.value_.hugeint); - break; - case 
PhysicalType::FLOAT: - serializer.WriteProperty(101, "value", val.value_.float_); - break; - case PhysicalType::DOUBLE: - serializer.WriteProperty(101, "value", val.value_.double_); - break; - default: - throw InternalException("Unsupported type for serializing numeric statistics"); - } -} - -static void DeserializeNumericStatsValue(const LogicalType &type, NumericValueUnion &result, bool &has_stats, - Deserializer &deserializer) { - auto has_value = deserializer.ReadProperty(100, "has_value"); - if (!has_value) { - has_stats = false; - return; - } - has_stats = true; - switch (type.InternalType()) { - case PhysicalType::BOOL: - result.value_.boolean = deserializer.ReadProperty(101, "value"); - break; - case PhysicalType::INT8: - result.value_.tinyint = deserializer.ReadProperty(101, "value"); - break; - case PhysicalType::INT16: - result.value_.smallint = deserializer.ReadProperty(101, "value"); - break; - case PhysicalType::INT32: - result.value_.integer = deserializer.ReadProperty(101, "value"); - break; - case PhysicalType::INT64: - result.value_.bigint = deserializer.ReadProperty(101, "value"); - break; - case PhysicalType::UINT8: - result.value_.utinyint = deserializer.ReadProperty(101, "value"); - break; - case PhysicalType::UINT16: - result.value_.usmallint = deserializer.ReadProperty(101, "value"); - break; - case PhysicalType::UINT32: - result.value_.uinteger = deserializer.ReadProperty(101, "value"); - break; - case PhysicalType::UINT64: - result.value_.ubigint = deserializer.ReadProperty(101, "value"); - break; - case PhysicalType::INT128: - result.value_.hugeint = deserializer.ReadProperty(101, "value"); - break; - case PhysicalType::FLOAT: - result.value_.float_ = deserializer.ReadProperty(101, "value"); - break; - case PhysicalType::DOUBLE: - result.value_.double_ = deserializer.ReadProperty(101, "value"); - break; - default: - throw InternalException("Unsupported type for serializing numeric statistics"); - } -} - -void NumericStats::Serialize(const 
BaseStatistics &stats, Serializer &serializer) { - auto &numeric_stats = NumericStats::GetDataUnsafe(stats); - serializer.WriteObject(200, "max", [&](Serializer &object) { - SerializeNumericStatsValue(stats.GetType(), numeric_stats.min, numeric_stats.has_min, object); - }); - serializer.WriteObject(201, "min", [&](Serializer &object) { - SerializeNumericStatsValue(stats.GetType(), numeric_stats.max, numeric_stats.has_max, object); - }); -} - -void NumericStats::Deserialize(Deserializer &deserializer, BaseStatistics &result) { - auto &numeric_stats = NumericStats::GetDataUnsafe(result); - - deserializer.ReadObject(200, "max", [&](Deserializer &object) { - DeserializeNumericStatsValue(result.GetType(), numeric_stats.min, numeric_stats.has_min, object); - }); - deserializer.ReadObject(201, "min", [&](Deserializer &object) { - DeserializeNumericStatsValue(result.GetType(), numeric_stats.max, numeric_stats.has_max, object); - }); -} - -string NumericStats::ToString(const BaseStatistics &stats) { - return StringUtil::Format("[Min: %s, Max: %s]", NumericStats::MinOrNull(stats).ToString(), - NumericStats::MaxOrNull(stats).ToString()); -} - -template -void NumericStats::TemplatedVerify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, - idx_t count) { - UnifiedVectorFormat vdata; - vector.ToUnifiedFormat(count, vdata); - - auto data = UnifiedVectorFormat::GetData(vdata); - auto min_value = NumericStats::MinOrNull(stats); - auto max_value = NumericStats::MaxOrNull(stats); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto index = vdata.sel->get_index(idx); - if (!vdata.validity.RowIsValid(index)) { - continue; - } - if (!min_value.IsNull() && LessThan::Operation(data[index], min_value.GetValueUnsafe())) { // LCOV_EXCL_START - throw InternalException("Statistics mismatch: value is smaller than min.\nStatistics: %s\nVector: %s", - stats.ToString(), vector.ToString(count)); - } // LCOV_EXCL_STOP - if (!max_value.IsNull() && 
GreaterThan::Operation(data[index], max_value.GetValueUnsafe())) { - throw InternalException("Statistics mismatch: value is bigger than max.\nStatistics: %s\nVector: %s", - stats.ToString(), vector.ToString(count)); - } - } -} - -void NumericStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { - auto &type = stats.GetType(); - switch (type.InternalType()) { - case PhysicalType::BOOL: - break; - case PhysicalType::INT8: - TemplatedVerify(stats, vector, sel, count); - break; - case PhysicalType::INT16: - TemplatedVerify(stats, vector, sel, count); - break; - case PhysicalType::INT32: - TemplatedVerify(stats, vector, sel, count); - break; - case PhysicalType::INT64: - TemplatedVerify(stats, vector, sel, count); - break; - case PhysicalType::UINT8: - TemplatedVerify(stats, vector, sel, count); - break; - case PhysicalType::UINT16: - TemplatedVerify(stats, vector, sel, count); - break; - case PhysicalType::UINT32: - TemplatedVerify(stats, vector, sel, count); - break; - case PhysicalType::UINT64: - TemplatedVerify(stats, vector, sel, count); - break; - case PhysicalType::INT128: - TemplatedVerify(stats, vector, sel, count); - break; - case PhysicalType::FLOAT: - TemplatedVerify(stats, vector, sel, count); - break; - case PhysicalType::DOUBLE: - TemplatedVerify(stats, vector, sel, count); - break; - default: - throw InternalException("Unsupported type %s for numeric statistics verify", type.ToString()); - } -} - -} // namespace duckdb - - -namespace duckdb { - -template <> -bool &NumericValueUnion::GetReferenceUnsafe() { - return value_.boolean; -} - -template <> -int8_t &NumericValueUnion::GetReferenceUnsafe() { - return value_.tinyint; -} - -template <> -int16_t &NumericValueUnion::GetReferenceUnsafe() { - return value_.smallint; -} - -template <> -int32_t &NumericValueUnion::GetReferenceUnsafe() { - return value_.integer; -} - -template <> -int64_t &NumericValueUnion::GetReferenceUnsafe() { - return value_.bigint; -} - 
-template <> -hugeint_t &NumericValueUnion::GetReferenceUnsafe() { - return value_.hugeint; -} - -template <> -uint8_t &NumericValueUnion::GetReferenceUnsafe() { - return value_.utinyint; -} - -template <> -uint16_t &NumericValueUnion::GetReferenceUnsafe() { - return value_.usmallint; -} - -template <> -uint32_t &NumericValueUnion::GetReferenceUnsafe() { - return value_.uinteger; -} - -template <> -uint64_t &NumericValueUnion::GetReferenceUnsafe() { - return value_.ubigint; -} - -template <> -float &NumericValueUnion::GetReferenceUnsafe() { - return value_.float_; -} - -template <> -double &NumericValueUnion::GetReferenceUnsafe() { - return value_.double_; -} - -} // namespace duckdb - - - - -namespace duckdb { - -SegmentStatistics::SegmentStatistics(LogicalType type) : statistics(BaseStatistics::CreateEmpty(std::move(type))) { -} - -SegmentStatistics::SegmentStatistics(BaseStatistics stats) : statistics(std::move(stats)) { -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -BaseStatistics StringStats::CreateUnknown(LogicalType type) { - BaseStatistics result(std::move(type)); - result.InitializeUnknown(); - auto &string_data = StringStats::GetDataUnsafe(result); - for (idx_t i = 0; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) { - string_data.min[i] = 0; - string_data.max[i] = 0xFF; - } - string_data.max_string_length = 0; - string_data.has_max_string_length = false; - string_data.has_unicode = true; - return result; -} - -BaseStatistics StringStats::CreateEmpty(LogicalType type) { - BaseStatistics result(std::move(type)); - result.InitializeEmpty(); - auto &string_data = StringStats::GetDataUnsafe(result); - for (idx_t i = 0; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) { - string_data.min[i] = 0xFF; - string_data.max[i] = 0; - } - string_data.max_string_length = 0; - string_data.has_max_string_length = true; - string_data.has_unicode = false; - return result; -} - -StringStatsData &StringStats::GetDataUnsafe(BaseStatistics &stats) { 
- D_ASSERT(stats.GetStatsType() == StatisticsType::STRING_STATS); - return stats.stats_union.string_data; -} - -const StringStatsData &StringStats::GetDataUnsafe(const BaseStatistics &stats) { - D_ASSERT(stats.GetStatsType() == StatisticsType::STRING_STATS); - return stats.stats_union.string_data; -} - -bool StringStats::HasMaxStringLength(const BaseStatistics &stats) { - if (stats.GetType().id() == LogicalTypeId::SQLNULL) { - return false; - } - return StringStats::GetDataUnsafe(stats).has_max_string_length; -} - -uint32_t StringStats::MaxStringLength(const BaseStatistics &stats) { - if (!HasMaxStringLength(stats)) { - throw InternalException("MaxStringLength called on statistics that does not have a max string length"); - } - return StringStats::GetDataUnsafe(stats).max_string_length; -} - -bool StringStats::CanContainUnicode(const BaseStatistics &stats) { - if (stats.GetType().id() == LogicalTypeId::SQLNULL) { - return true; - } - return StringStats::GetDataUnsafe(stats).has_unicode; -} - -string GetStringMinMaxValue(const data_t data[]) { - idx_t len; - for (len = 0; len < StringStatsData::MAX_STRING_MINMAX_SIZE; len++) { - if (!data[len]) { - break; - } - } - return string(const_char_ptr_cast(data), len); -} - -string StringStats::Min(const BaseStatistics &stats) { - return GetStringMinMaxValue(StringStats::GetDataUnsafe(stats).min); -} - -string StringStats::Max(const BaseStatistics &stats) { - return GetStringMinMaxValue(StringStats::GetDataUnsafe(stats).max); -} - -void StringStats::ResetMaxStringLength(BaseStatistics &stats) { - StringStats::GetDataUnsafe(stats).has_max_string_length = false; -} - -void StringStats::SetContainsUnicode(BaseStatistics &stats) { - StringStats::GetDataUnsafe(stats).has_unicode = true; -} - -void StringStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { - auto &string_data = StringStats::GetDataUnsafe(stats); - serializer.WriteProperty(200, "min", string_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE); - 
serializer.WriteProperty(201, "max", string_data.max, StringStatsData::MAX_STRING_MINMAX_SIZE); - serializer.WriteProperty(202, "has_unicode", string_data.has_unicode); - serializer.WriteProperty(203, "has_max_string_length", string_data.has_max_string_length); - serializer.WriteProperty(204, "max_string_length", string_data.max_string_length); -} - -void StringStats::Deserialize(Deserializer &deserializer, BaseStatistics &base) { - auto &string_data = StringStats::GetDataUnsafe(base); - deserializer.ReadProperty(200, "min", string_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE); - deserializer.ReadProperty(201, "max", string_data.max, StringStatsData::MAX_STRING_MINMAX_SIZE); - deserializer.ReadProperty(202, "has_unicode", string_data.has_unicode); - deserializer.ReadProperty(203, "has_max_string_length", string_data.has_max_string_length); - deserializer.ReadProperty(204, "max_string_length", string_data.max_string_length); -} - -static int StringValueComparison(const_data_ptr_t data, idx_t len, const_data_ptr_t comparison) { - D_ASSERT(len <= StringStatsData::MAX_STRING_MINMAX_SIZE); - for (idx_t i = 0; i < len; i++) { - if (data[i] < comparison[i]) { - return -1; - } else if (data[i] > comparison[i]) { - return 1; - } - } - return 0; -} - -static void ConstructValue(const_data_ptr_t data, idx_t size, data_t target[]) { - idx_t value_size = size > StringStatsData::MAX_STRING_MINMAX_SIZE ? StringStatsData::MAX_STRING_MINMAX_SIZE : size; - memcpy(target, data, value_size); - for (idx_t i = value_size; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) { - target[i] = '\0'; - } -} - -void StringStats::Update(BaseStatistics &stats, const string_t &value) { - auto data = const_data_ptr_cast(value.GetData()); - auto size = value.GetSize(); - - //! 
we can only fit 8 bytes, so we might need to trim our string - // construct the value - data_t target[StringStatsData::MAX_STRING_MINMAX_SIZE]; - ConstructValue(data, size, target); - - // update the min and max - auto &string_data = StringStats::GetDataUnsafe(stats); - if (StringValueComparison(target, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.min) < 0) { - memcpy(string_data.min, target, StringStatsData::MAX_STRING_MINMAX_SIZE); - } - if (StringValueComparison(target, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.max) > 0) { - memcpy(string_data.max, target, StringStatsData::MAX_STRING_MINMAX_SIZE); - } - if (size > string_data.max_string_length) { - string_data.max_string_length = size; - } - if (stats.GetType().id() == LogicalTypeId::VARCHAR && !string_data.has_unicode) { - auto unicode = Utf8Proc::Analyze(const_char_ptr_cast(data), size); - if (unicode == UnicodeType::UNICODE) { - string_data.has_unicode = true; - } else if (unicode == UnicodeType::INVALID) { - throw InvalidInputException(ErrorManager::InvalidUnicodeError(string(const_char_ptr_cast(data), size), - "segment statistics update")); - } - } -} - -void StringStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { - if (other.GetType().id() == LogicalTypeId::VALIDITY) { - return; - } - auto &string_data = StringStats::GetDataUnsafe(stats); - auto &other_data = StringStats::GetDataUnsafe(other); - if (StringValueComparison(other_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.min) < 0) { - memcpy(string_data.min, other_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE); - } - if (StringValueComparison(other_data.max, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.max) > 0) { - memcpy(string_data.max, other_data.max, StringStatsData::MAX_STRING_MINMAX_SIZE); - } - string_data.has_unicode = string_data.has_unicode || other_data.has_unicode; - string_data.has_max_string_length = string_data.has_max_string_length && other_data.has_max_string_length; - 
	// (tail of StringStats::Merge) widen the max-string-length bound to cover both inputs
	string_data.max_string_length = MaxValue(string_data.max_string_length, other_data.max_string_length);
}

// Evaluate a comparison filter against the string min/max zone map.
// Only the first MAX_STRING_MINMAX_SIZE bytes of the constant are compared, so the
// result is conservative: prune only when the prefix comparison is decisive.
FilterPropagateResult StringStats::CheckZonemap(const BaseStatistics &stats, ExpressionType comparison_type,
                                                const string &constant) {
	auto &string_data = StringStats::GetDataUnsafe(stats);
	auto data = const_data_ptr_cast(constant.c_str());
	auto size = constant.size();

	// clamp the constant to the stored min/max prefix length
	idx_t value_size = size > StringStatsData::MAX_STRING_MINMAX_SIZE ? StringStatsData::MAX_STRING_MINMAX_SIZE : size;
	int min_comp = StringValueComparison(data, value_size, string_data.min);
	int max_comp = StringValueComparison(data, value_size, string_data.max);
	switch (comparison_type) {
	case ExpressionType::COMPARE_EQUAL:
		// constant can only match if it lies within [min, max]
		if (min_comp >= 0 && max_comp <= 0) {
			return FilterPropagateResult::NO_PRUNING_POSSIBLE;
		} else {
			return FilterPropagateResult::FILTER_ALWAYS_FALSE;
		}
	case ExpressionType::COMPARE_NOTEQUAL:
		// if the constant lies entirely outside [min, max] every row passes
		if (min_comp < 0 || max_comp > 0) {
			return FilterPropagateResult::FILTER_ALWAYS_TRUE;
		}
		return FilterPropagateResult::NO_PRUNING_POSSIBLE;
	case ExpressionType::COMPARE_GREATERTHANOREQUALTO:
	case ExpressionType::COMPARE_GREATERTHAN:
		// some row may be > constant only if constant <= max
		if (max_comp <= 0) {
			return FilterPropagateResult::NO_PRUNING_POSSIBLE;
		} else {
			return FilterPropagateResult::FILTER_ALWAYS_FALSE;
		}
	case ExpressionType::COMPARE_LESSTHAN:
	case ExpressionType::COMPARE_LESSTHANOREQUALTO:
		// some row may be < constant only if constant >= min
		if (min_comp >= 0) {
			return FilterPropagateResult::NO_PRUNING_POSSIBLE;
		} else {
			return FilterPropagateResult::FILTER_ALWAYS_FALSE;
		}
	default:
		throw InternalException("Expression type not implemented for string statistics zone map");
	}
}

// Length of the printable ASCII prefix of a min/max buffer: stops at the first NUL
// terminator or the first non-ASCII byte (high bit set), capped at the buffer size.
static idx_t GetValidMinMaxSubstring(const_data_ptr_t data) {
	for (idx_t i = 0; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) {
		if (data[i] == '\0') {
			return i;
		}
		if ((data[i] & 0x80) != 0) {
			return i;
		}
	}
	return StringStatsData::MAX_STRING_MINMAX_SIZE;
}

// Human-readable rendering of string statistics; "?" when no max length is tracked.
string StringStats::ToString(const BaseStatistics &stats) {
	auto &string_data = StringStats::GetDataUnsafe(stats);
	idx_t min_len = GetValidMinMaxSubstring(string_data.min);
	idx_t max_len = GetValidMinMaxSubstring(string_data.max);
	return StringUtil::Format("[Min: %s, Max: %s, Has Unicode: %s, Max String Length: %s]",
	                          string(const_char_ptr_cast(string_data.min), min_len),
	                          string(const_char_ptr_cast(string_data.max), max_len),
	                          string_data.has_unicode ? "true" : "false",
	                          string_data.has_max_string_length ? to_string(string_data.max_string_length) : "?");
}

// Debug-time check that every selected, non-NULL string in the vector is consistent
// with the statistics (length bound, unicode flag, min/max prefix ordering).
void StringStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) {
	auto &string_data = StringStats::GetDataUnsafe(stats);

	UnifiedVectorFormat vdata;
	vector.ToUnifiedFormat(count, vdata);
	auto data = UnifiedVectorFormat::GetData(vdata);
	for (idx_t i = 0; i < count; i++) {
		auto idx = sel.get_index(i);
		auto index = vdata.sel->get_index(idx);
		if (!vdata.validity.RowIsValid(index)) {
			// NULLs are not covered by min/max statistics
			continue;
		}
		auto value = data[index];
		// NOTE(review): this inner 'data' shadows the string_t array above; the outer
		// array is not needed past this point, but the shadowing is easy to misread
		auto data = value.GetData();
		auto len = value.GetSize();
		// LCOV_EXCL_START
		if (string_data.has_max_string_length && len > string_data.max_string_length) {
			throw InternalException(
			    "Statistics mismatch: string value exceeds maximum string length.\nStatistics: %s\nVector: %s",
			    stats.ToString(), vector.ToString(count));
		}
		if (stats.GetType().id() == LogicalTypeId::VARCHAR && !string_data.has_unicode) {
			auto unicode = Utf8Proc::Analyze(data, len);
			if (unicode == UnicodeType::UNICODE) {
				throw InternalException("Statistics mismatch: string value contains unicode, but statistics says it "
				                        "shouldn't.\nStatistics: %s\nVector: %s",
				                        stats.ToString(), vector.ToString(count));
			} else if (unicode == UnicodeType::INVALID) {
				throw InternalException("Invalid unicode detected in vector: %s", vector.ToString(count));
			}
		}
		// compare only the stored prefix, matching how min/max are maintained
		if (StringValueComparison(const_data_ptr_cast(data),
		                          MinValue(len, StringStatsData::MAX_STRING_MINMAX_SIZE), string_data.min) < 0) {
			throw InternalException("Statistics mismatch: value is smaller than min.\nStatistics: %s\nVector: %s",
			                        stats.ToString(), vector.ToString(count));
		}
		if (StringValueComparison(const_data_ptr_cast(data),
		                          MinValue(len, StringStatsData::MAX_STRING_MINMAX_SIZE), string_data.max) > 0) {
			throw InternalException("Statistics mismatch: value is bigger than max.\nStatistics: %s\nVector: %s",
			                        stats.ToString(), vector.ToString(count));
		}
		// LCOV_EXCL_STOP
	}
}

} // namespace duckdb

namespace duckdb {

// Allocate one child-statistics slot per struct member and construct each in place.
void StructStats::Construct(BaseStatistics &stats) {
	auto &child_types = StructType::GetChildTypes(stats.GetType());
	stats.child_stats = unsafe_unique_array(new BaseStatistics[child_types.size()]);
	for (idx_t i = 0; i < child_types.size(); i++) {
		BaseStatistics::Construct(stats.child_stats[i], child_types[i].second);
	}
}

// "Unknown" struct stats: nothing can be assumed about any child.
// NOTE(review): child_types references 'type' before it is moved into 'result';
// this relies on LogicalType's internals remaining valid across the move — confirm.
BaseStatistics StructStats::CreateUnknown(LogicalType type) {
	auto &child_types = StructType::GetChildTypes(type);
	BaseStatistics result(std::move(type));
	result.InitializeUnknown();
	for (idx_t i = 0; i < child_types.size(); i++) {
		result.child_stats[i].Copy(BaseStatistics::CreateUnknown(child_types[i].second));
	}
	return result;
}

// "Empty" struct stats: statistics for a column with no rows yet.
BaseStatistics StructStats::CreateEmpty(LogicalType type) {
	auto &child_types = StructType::GetChildTypes(type);
	BaseStatistics result(std::move(type));
	result.InitializeEmpty();
	for (idx_t i = 0; i < child_types.size(); i++) {
		result.child_stats[i].Copy(BaseStatistics::CreateEmpty(child_types[i].second));
	}
	return result;
}

// Raw pointer to the child-statistics array (throws on non-struct stats).
const BaseStatistics *StructStats::GetChildStats(const BaseStatistics &stats) {
	if (stats.GetStatsType() != StatisticsType::STRUCT_STATS) {
		throw InternalException("Calling StructStats::GetChildStats on stats that is not a struct");
	}
	return stats.child_stats.get();
}

// Bounds-checked access to child i's statistics (const overload).
const BaseStatistics &StructStats::GetChildStats(const BaseStatistics &stats, idx_t i) {
	D_ASSERT(stats.GetStatsType() == StatisticsType::STRUCT_STATS);
	if (i >= StructType::GetChildCount(stats.GetType())) {
		throw InternalException("Calling StructStats::GetChildStats but there are no stats for this index");
	}
	return stats.child_stats[i];
}

// Bounds-checked access to child i's statistics (mutable overload).
BaseStatistics &StructStats::GetChildStats(BaseStatistics &stats, idx_t i) {
	D_ASSERT(stats.GetStatsType() == StatisticsType::STRUCT_STATS);
	if (i >= StructType::GetChildCount(stats.GetType())) {
		throw InternalException("Calling StructStats::GetChildStats but there are no stats for this index");
	}
	return stats.child_stats[i];
}

// Replace child i's statistics with a copy of new_stats.
void StructStats::SetChildStats(BaseStatistics &stats, idx_t i, const BaseStatistics &new_stats) {
	D_ASSERT(stats.GetStatsType() == StatisticsType::STRUCT_STATS);
	D_ASSERT(i < StructType::GetChildCount(stats.GetType()));
	stats.child_stats[i].Copy(new_stats);
}

// Replace child i's statistics; a null pointer means "unknown" for that child.
void StructStats::SetChildStats(BaseStatistics &stats, idx_t i, unique_ptr new_stats) {
	D_ASSERT(stats.GetStatsType() == StatisticsType::STRUCT_STATS);
	if (!new_stats) {
		StructStats::SetChildStats(stats, i,
		                           BaseStatistics::CreateUnknown(StructType::GetChildType(stats.GetType(), i)));
	} else {
		StructStats::SetChildStats(stats, i, *new_stats);
	}
}

// Deep-copy all child statistics from 'other' into 'stats'.
void StructStats::Copy(BaseStatistics &stats, const BaseStatistics &other) {
	auto count = StructType::GetChildCount(stats.GetType());
	for (idx_t i = 0; i < count; i++) {
		stats.child_stats[i].Copy(other.child_stats[i]);
	}
}

// Merge child statistics pairwise; validity-only stats carry no struct info and are skipped.
void StructStats::Merge(BaseStatistics &stats, const BaseStatistics &other) {
	if (other.GetType().id() == LogicalTypeId::VALIDITY) {
		return;
	}
	D_ASSERT(stats.GetType() == other.GetType());
	auto child_count = StructType::GetChildCount(stats.GetType());
	for (idx_t i = 0; i < child_count; i++) {
		stats.child_stats[i].Merge(other.child_stats[i]);
	}
}

// Serialize child statistics as a list under field id 200.
void StructStats::Serialize(const BaseStatistics &stats, Serializer &serializer) {
	auto child_stats = StructStats::GetChildStats(stats);
	auto child_count = StructType::GetChildCount(stats.GetType());

	serializer.WriteList(200, "child_stats", child_count,
	                     [&](Serializer::List &list, idx_t i) { list.WriteElement(child_stats[i]); });
}

// Deserialize child statistics; each child's logical type is pushed into the
// deserializer context so nested readers know what to construct.
void StructStats::Deserialize(Deserializer &deserializer, BaseStatistics &base) {
	auto &type = base.GetType();
	D_ASSERT(type.InternalType() == PhysicalType::STRUCT);

	auto &child_types = StructType::GetChildTypes(type);

	deserializer.ReadList(200, "child_stats", [&](Deserializer::List &list, idx_t i) {
		deserializer.Set(const_cast(child_types[i].second));
		auto stat = list.ReadElement();
		base.child_stats[i].Copy(stat);
		deserializer.Unset();
	});
}

// Render struct statistics as "{name: stats, ...}".
string StructStats::ToString(const BaseStatistics &stats) {
	string result;
	result += " {";
	auto &child_types = StructType::GetChildTypes(stats.GetType());
	for (idx_t i = 0; i < child_types.size(); i++) {
		if (i > 0) {
			result += ", ";
		}
		result += child_types[i].first + ": " + stats.child_stats[i].ToString();
	}
	result += "}";
	return result;
}

// Verify each child's statistics against the corresponding child vector.
void StructStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) {
	auto &child_entries = StructVector::GetEntries(vector);
	for (idx_t i = 0; i < child_entries.size(); i++) {
		stats.child_stats[i].Verify(*child_entries[i], sel, count);
	}
}

} // namespace duckdb

namespace duckdb {

// Current on-disk storage format version.
const uint64_t VERSION_NUMBER = 64;

// Maps a storage version number back to the DuckDB release(s) that wrote it.
struct StorageVersionInfo {
	const char *version_name;
	idx_t storage_version;
};

// Sentinel-terminated (nullptr entry) lookup table, newest first.
static StorageVersionInfo storage_version_info[] = {{"v0.8.0 or v0.8.1", 51},
                                                    {"v0.7.0 or v0.7.1", 43},
                                                    {"v0.6.0 or v0.6.1", 39},
                                                    {"v0.5.0 or v0.5.1", 38},
                                                    {"v0.3.3, v0.3.4 or v0.4.0", 33},
                                                    {"v0.3.2", 31},
                                                    {"v0.3.1", 27},
                                                    {"v0.3.0", 25},
                                                    {"v0.2.9", 21},
                                                    {"v0.2.8", 18},
                                                    {"v0.2.7", 17},
                                                    {"v0.2.6", 15},
                                                    {"v0.2.5", 13},
                                                    {"v0.2.4", 11},
                                                    {"v0.2.3", 6},
                                                    {"v0.2.2", 4},
                                                    {"v0.2.1 and prior", 1},
                                                    {nullptr, 0}};

// Linear scan of the table above; returns nullptr for an unknown version number.
const char *GetDuckDBVersion(idx_t version_number) {
	for (idx_t i = 0; storage_version_info[i].version_name; i++) {
		if (version_number == storage_version_info[i].storage_version) {
			return storage_version_info[i].version_name;
		}
	}
	return nullptr;
}

} // namespace duckdb

namespace duckdb {

// RAII handle for a held storage lock; releases the matching lock kind on destruction.
StorageLockKey::StorageLockKey(StorageLock &lock, StorageLockType type) : lock(lock), type(type) {
}

StorageLockKey::~StorageLockKey() {
	if (type == StorageLockType::EXCLUSIVE) {
		lock.ReleaseExclusiveLock();
	} else {
		D_ASSERT(type == StorageLockType::SHARED);
		lock.ReleaseSharedLock();
	}
}

StorageLock::StorageLock() : read_count(0) {
}

// Acquire the exclusive lock: take the mutex, then wait for all readers to drain.
// NOTE(review): this is a busy spin on read_count with no yield/backoff — it burns a
// core while readers finish; acceptable only if reader critical sections are short.
unique_ptr StorageLock::GetExclusiveLock() {
	exclusive_lock.lock();
	while (read_count != 0) {
	}
	return make_uniq(*this, StorageLockType::EXCLUSIVE);
}

// Acquire a shared lock: readers register under the mutex, then release it so
// multiple readers can proceed concurrently.
unique_ptr StorageLock::GetSharedLock() {
	exclusive_lock.lock();
	read_count++;
	exclusive_lock.unlock();
	return make_uniq(*this, StorageLockType::SHARED);
}

void StorageLock::ReleaseExclusiveLock() {
	exclusive_lock.unlock();
}

void StorageLock::ReleaseSharedLock() {
	read_count--;
}

} // namespace duckdb

namespace duckdb {

// An empty path selects the transient in-memory database (":memory:");
// otherwise the path is expanded to an absolute form via the file system.
StorageManager::StorageManager(AttachedDatabase &db, string path_p, bool read_only)
    : db(db), path(std::move(path_p)), read_only(read_only) {
	if (path.empty()) {
		path = ":memory:";
	} else {
		auto &fs = FileSystem::Get(db);
		this->path = fs.ExpandPath(path);
	}
}

StorageManager::~StorageManager() {
}

StorageManager &StorageManager::Get(AttachedDatabase &db) {
	return db.GetStorageManager();
}
StorageManager &StorageManager::Get(Catalog &catalog) {
	return StorageManager::Get(catalog.GetAttached());
}

DatabaseInstance &StorageManager::GetDatabase() {
	return db.GetDatabase();
}

BufferManager &BufferManager::GetBufferManager(ClientContext &context) {
	return BufferManager::GetBufferManager(*context.db);
}

ObjectCache &ObjectCache::GetObjectCache(ClientContext &context) {
	return context.db->GetObjectCache();
}

bool ObjectCache::ObjectCacheEnabled(ClientContext &context) {
	return context.db->config.options.object_cache_enable;
}

bool StorageManager::InMemory() {
	D_ASSERT(!path.empty());
	return path == ":memory:";
}

// Entry point called on attach: validates the mode combination and loads/creates storage.
void StorageManager::Initialize() {
	bool in_memory = InMemory();
	if (in_memory && read_only) {
		throw CatalogException("Cannot launch in-memory database in read-only mode!");
	}

	// create or load the database from disk, if not in-memory mode
	LoadDatabase();
}

///////////////////////////////////////////////////////////////////////////
// TableIOManager that routes all table I/O through a single block manager
// (the single-file storage layout).
class SingleFileTableIOManager : public TableIOManager {
public:
	explicit SingleFileTableIOManager(BlockManager &block_manager) : block_manager(block_manager) {
	}

	BlockManager &block_manager;

public:
	BlockManager &GetIndexBlockManager() override {
		return block_manager;
	}
	BlockManager &GetBlockManagerForRowData() override {
		return block_manager;
	}
	MetadataManager &GetMetadataManager() override {
		return block_manager.GetMetadataManager();
	}
};

SingleFileStorageManager::SingleFileStorageManager(AttachedDatabase &db, string path, bool read_only)
    : StorageManager(db, std::move(path), read_only) {
}

// Create or open the database file (plus its ".wal" sidecar) and replay the WAL
// if one is present. In-memory databases get a pure in-memory block manager.
void SingleFileStorageManager::LoadDatabase() {
	if (InMemory()) {
		block_manager = make_uniq(BufferManager::GetBufferManager(db));
		table_io_manager = make_uniq(*block_manager);
		return;
	}
	// insert ".wal" before any query-string suffix in the path (e.g. "file.db?opt" -> "file.db.wal?opt")
	std::size_t question_mark_pos = path.find('?');
	auto wal_path = path;
	if (question_mark_pos != std::string::npos) {
		wal_path.insert(question_mark_pos, ".wal");
	} else {
		wal_path += ".wal";
	}
	auto &fs = FileSystem::Get(db);
	auto &config = DBConfig::Get(db);
	bool truncate_wal = false;
	if (!config.options.enable_external_access) {
		if (!db.IsInitialDatabase()) {
			throw PermissionException("Attaching on-disk databases is disabled through configuration");
		}
	}

	StorageManagerOptions options;
	options.read_only = read_only;
	options.use_direct_io = config.options.use_direct_io;
	options.debug_initialize = config.options.debug_initialize;
	// first check if the database exists
	if (!fs.FileExists(path)) {
		if (read_only) {
			throw CatalogException("Cannot open database \"%s\" in read-only mode: database does not exist", path);
		}
		// check if the WAL exists
		if (fs.FileExists(wal_path)) {
			// WAL file exists but database file does not
			// remove the WAL
			fs.RemoveFile(wal_path);
		}
		// initialize the block manager while creating a new db file
		auto sf_block_manager = make_uniq(db, path, options);
		sf_block_manager->CreateNewDatabase();
		block_manager = std::move(sf_block_manager);
		table_io_manager = make_uniq(*block_manager);
	} else {
		// initialize the block manager while loading the current db file
		auto sf_block_manager = make_uniq(db, path, options);
		sf_block_manager->LoadExistingDatabase();
		block_manager = std::move(sf_block_manager);
		table_io_manager = make_uniq(*block_manager);

		//! Load from storage
		auto checkpointer = SingleFileCheckpointReader(*this);
		checkpointer.LoadFromStorage();
		// check if the WAL file exists
		if (fs.FileExists(wal_path)) {
			// replay the WAL
			truncate_wal = WriteAheadLog::Replay(db, wal_path);
		}
	}
	// initialize the WAL file
	if (!read_only) {
		wal = make_uniq(db, wal_path);
		if (truncate_wal) {
			wal->Truncate(0);
		}
	}
}

///////////////////////////////////////////////////////////////////////////////

// Tracks WAL state across a commit so that a failed commit can be rolled back
// by truncating the WAL to its pre-commit size.
class SingleFileStorageCommitState : public StorageCommitState {
	idx_t initial_wal_size = 0;
	idx_t initial_written = 0;
	optional_ptr log;
	bool checkpoint;

public:
	SingleFileStorageCommitState(StorageManager &storage_manager, bool checkpoint);
	~SingleFileStorageCommitState() override {
		// If log is non-null, then commit threw an exception before flushing.
		if (log) {
			auto &wal = *log.get();
			wal.skip_writing = false;
			if (wal.GetTotalWritten() > initial_written) {
				// remove any entries written into the WAL by truncating it
				wal.Truncate(initial_wal_size);
			}
		}
	}

	// Make the commit persistent
	void FlushCommit() override;
};

SingleFileStorageCommitState::SingleFileStorageCommitState(StorageManager &storage_manager, bool checkpoint)
    : checkpoint(checkpoint) {
	log = storage_manager.GetWriteAheadLog();
	if (log) {
		auto initial_size = log->GetWALSize();
		initial_written = log->GetTotalWritten();
		// a negative WAL size means there is no file yet; treat as empty
		initial_wal_size = initial_size < 0 ? 0 : idx_t(initial_size);

		if (checkpoint) {
			// check if we are checkpointing after this commit
			// if we are checkpointing, we don't need to write anything to the WAL
			// this saves us a lot of unnecessary writes to disk in the case of large commits
			log->skip_writing = true;
		}
	} else {
		D_ASSERT(!checkpoint);
	}
}

// Make the commit persistent
void SingleFileStorageCommitState::FlushCommit() {
	if (log) {
		// flush the WAL if any changes were made
		if (log->GetTotalWritten() > initial_written) {
			(void)checkpoint;
			D_ASSERT(!checkpoint);
			D_ASSERT(!log->skip_writing);
			log->Flush();
		}
		log->skip_writing = false;
	}
	// Null so that the destructor will not truncate the log.
	log = nullptr;
}

unique_ptr SingleFileStorageManager::GenStorageCommitState(Transaction &transaction,
                                                           bool checkpoint) {
	return make_uniq(*this, checkpoint);
}

bool SingleFileStorageManager::IsCheckpointClean(MetaBlockPointer checkpoint_id) {
	return block_manager->IsRootBlock(checkpoint_id);
}

// Write a checkpoint if there is anything worth persisting, and optionally
// delete the WAL afterwards. No-op for in-memory / read-only databases.
void SingleFileStorageManager::CreateCheckpoint(bool delete_wal, bool force_checkpoint) {
	if (InMemory() || read_only || !wal) {
		return;
	}
	auto &config = DBConfig::Get(db);
	if (wal->GetWALSize() > 0 || config.options.force_checkpoint || force_checkpoint) {
		// we only need to checkpoint if there is anything in the WAL
		try {
			SingleFileCheckpointWriter checkpointer(db, *block_manager);
			checkpointer.CreateCheckpoint();
		} catch (std::exception &ex) {
			// a failed checkpoint may leave storage in an inconsistent state: escalate
			throw FatalException("Failed to create checkpoint because of error: %s", ex.what());
		}
	}
	if (delete_wal) {
		wal->Delete();
		wal.reset();
	}
}

// Report block-level size statistics; all fields stay zero for in-memory databases.
DatabaseSize SingleFileStorageManager::GetDatabaseSize() {
	// All members default to zero
	DatabaseSize ds;
	if (!InMemory()) {
		ds.total_blocks = block_manager->TotalBlocks();
		ds.block_size = Storage::BLOCK_ALLOC_SIZE;
		ds.free_blocks = block_manager->FreeBlocks();
		ds.used_blocks = ds.total_blocks - ds.free_blocks;
		ds.bytes = (ds.total_blocks * ds.block_size);
		if (auto wal = GetWriteAheadLog()) {
			ds.wal_size = wal->GetWALSize();
		}
	}
	return ds;
}

vector SingleFileStorageManager::GetMetadataInfo() {
	auto &metadata_manager = block_manager->GetMetadataManager();
	return metadata_manager.GetMetadataInfo();
}

// Decide whether an upcoming write should trigger a checkpoint: true when the
// projected WAL size would exceed the configured threshold.
bool SingleFileStorageManager::AutomaticCheckpoint(idx_t estimated_wal_bytes) {
	auto log = GetWriteAheadLog();
	if (!log) {
		return false;
	}

	auto &config = DBConfig::Get(db);
	auto initial_size = log->GetWALSize();
	idx_t expected_wal_size = initial_size + estimated_wal_bytes;
	return expected_wal_size > config.options.checkpoint_wal_size;
}

shared_ptr SingleFileStorageManager::GetTableIOManager(BoundCreateTableInfo *info /*info*/) {
	// This is an unmanaged reference. No ref/deref overhead. Lifetime of the
	// TableIoManager follows lifetime of the StorageManager (this).
	return shared_ptr(shared_ptr(nullptr), table_io_manager.get());
}

} // namespace duckdb

namespace duckdb {

// Visibility rules for a running transaction: a row version is visible if it was
// committed before the transaction started, or written by the transaction itself.
struct TransactionVersionOperator {
	static bool UseInsertedVersion(transaction_t start_time, transaction_t transaction_id, transaction_t id) {
		return id < start_time || id == transaction_id;
	}

	static bool UseDeletedVersion(transaction_t start_time, transaction_t transaction_id, transaction_t id) {
		return !UseInsertedVersion(start_time, transaction_id, id);
	}
};

// Visibility rules for committed-only scans (e.g. checkpointing).
struct CommittedVersionOperator {
	static bool UseInsertedVersion(transaction_t start_time, transaction_t transaction_id, transaction_t id) {
		return true;
	}

	static bool UseDeletedVersion(transaction_t min_start_time, transaction_t min_transaction_id, transaction_t id) {
		return (id >= min_start_time && id < TRANSACTION_ID_START) || (id >= min_transaction_id);
	}
};

static bool UseVersion(TransactionData transaction, transaction_t id) {
	return TransactionVersionOperator::UseInsertedVersion(transaction.start_time, transaction.transaction_id, id);
}

// Base serialization: write only the chunk-info type tag; subclasses append payload.
void ChunkInfo::Write(WriteStream &writer) const {
	writer.Write(type);
}

// Dispatch deserialization on the type tag; EMPTY_INFO round-trips as a null pointer.
unique_ptr ChunkInfo::Read(ReadStream &reader) {
	auto type = reader.Read();
	switch (type) {
	case ChunkInfoType::EMPTY_INFO:
		return nullptr;
	case ChunkInfoType::CONSTANT_INFO:
		return ChunkConstantInfo::Read(reader);
	case ChunkInfoType::VECTOR_INFO:
		return ChunkVectorInfo::Read(reader);
	default:
		throw SerializationException("Could not deserialize Chunk Info Type: unrecognized type");
	}
}

//===--------------------------------------------------------------------===//
// Constant info
//===--------------------------------------------------------------------===//
// Version info where every row in the vector shares one insert id and one delete id.
ChunkConstantInfo::ChunkConstantInfo(idx_t start)
    : ChunkInfo(start, ChunkInfoType::CONSTANT_INFO), insert_id(0), delete_id(NOT_DELETED_ID) {
}

// All-or-nothing visibility: either every row is visible (max_count) or none (0).
template
idx_t ChunkConstantInfo::TemplatedGetSelVector(transaction_t start_time, transaction_t transaction_id,
                                               SelectionVector &sel_vector, idx_t max_count) const {
	if (OP::UseInsertedVersion(start_time, transaction_id, insert_id) &&
	    OP::UseDeletedVersion(start_time, transaction_id, delete_id)) {
		return max_count;
	}
	return 0;
}

idx_t ChunkConstantInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) {
	return TemplatedGetSelVector(transaction.start_time, transaction.transaction_id,
	                             sel_vector, max_count);
}

idx_t ChunkConstantInfo::GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id,
                                               SelectionVector &sel_vector, idx_t max_count) {
	return TemplatedGetSelVector(min_start_id, min_transaction_id, sel_vector, max_count);
}

bool ChunkConstantInfo::Fetch(TransactionData transaction, row_t row) {
	return UseVersion(transaction, insert_id) && !UseVersion(transaction, delete_id);
}

// On commit, stamp the whole vector's insert id with the commit id.
void ChunkConstantInfo::CommitAppend(transaction_t commit_id, idx_t start, idx_t end) {
	D_ASSERT(start == 0 && end == STANDARD_VECTOR_SIZE);
	insert_id = commit_id;
}

// "Has deletes" here means: not plainly visible to everyone — either the insert is
// still uncommitted or a committed delete exists.
bool ChunkConstantInfo::HasDeletes() const {
	bool is_deleted = insert_id >= TRANSACTION_ID_START || delete_id < TRANSACTION_ID_START;
	return is_deleted;
}

idx_t ChunkConstantInfo::GetCommittedDeletedCount(idx_t max_count) {
	return delete_id < TRANSACTION_ID_START ? max_count : 0;
}

// Only written when HasDeletes() — serialized constant info encodes "fully deleted".
void ChunkConstantInfo::Write(WriteStream &writer) const {
	D_ASSERT(HasDeletes());
	ChunkInfo::Write(writer);
	writer.Write(start);
}

// Deserialized constant info marks every row as deleted (insert and delete both
// committed at id 0), matching what Write() is allowed to emit.
unique_ptr ChunkConstantInfo::Read(ReadStream &reader) {
	auto start = reader.Read();
	auto info = make_uniq(start);
	info->insert_id = 0;
	info->delete_id = 0;
	return std::move(info);
}

//===--------------------------------------------------------------------===//
// Vector info
//===--------------------------------------------------------------------===//
// Per-row version info; same_inserted_id / any_deleted are fast paths that avoid
// scanning the per-row arrays when possible.
ChunkVectorInfo::ChunkVectorInfo(idx_t start)
    : ChunkInfo(start, ChunkInfoType::VECTOR_INFO), insert_id(0), same_inserted_id(true), any_deleted(false) {
	for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) {
		inserted[i] = 0;
		deleted[i] = NOT_DELETED_ID;
	}
}

// Build a selection vector of the rows visible under OP's rules. The four branches
// are the cross product of the two fast-path flags; only the visible-row count is
// returned, and sel_vector is filled only when a partial selection is needed.
template
idx_t ChunkVectorInfo::TemplatedGetSelVector(transaction_t start_time, transaction_t transaction_id,
                                             SelectionVector &sel_vector, idx_t max_count) const {
	idx_t count = 0;
	if (same_inserted_id && !any_deleted) {
		// all tuples have the same inserted id: and no tuples were deleted
		if (OP::UseInsertedVersion(start_time, transaction_id, insert_id)) {
			return max_count;
		} else {
			return 0;
		}
	} else if (same_inserted_id) {
		if (!OP::UseInsertedVersion(start_time, transaction_id, insert_id)) {
			return 0;
		}
		// have to check deleted flag
		for (idx_t i = 0; i < max_count; i++) {
			if (OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) {
				sel_vector.set_index(count++, i);
			}
		}
	} else if (!any_deleted) {
		// have to check inserted flag
		for (idx_t i = 0; i < max_count; i++) {
			if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i])) {
				sel_vector.set_index(count++, i);
			}
		}
	} else {
		// have to check both flags
		for (idx_t i = 0; i < max_count; i++) {
			if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i]) &&
			    OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) {
				sel_vector.set_index(count++, i);
			}
		}
	}
	return count;
}

idx_t ChunkVectorInfo::GetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector,
                                    idx_t max_count) const {
	return TemplatedGetSelVector(start_time, transaction_id, sel_vector, max_count);
}

idx_t ChunkVectorInfo::GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id,
                                             SelectionVector &sel_vector, idx_t max_count) {
	return TemplatedGetSelVector(min_start_id, min_transaction_id, sel_vector, max_count);
}

idx_t ChunkVectorInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) {
	return GetSelVector(transaction.start_time, transaction.transaction_id, sel_vector, max_count);
}

bool ChunkVectorInfo::Fetch(TransactionData transaction, row_t row) {
	return UseVersion(transaction, inserted[row]) && !UseVersion(transaction, deleted[row]);
}

// Mark rows deleted by this transaction. Rows already deleted by another
// transaction raise a write-write conflict; 'rows' is compacted in place to the
// rows actually deleted and their count is returned.
idx_t ChunkVectorInfo::Delete(transaction_t transaction_id, row_t rows[], idx_t count) {
	any_deleted = true;

	idx_t deleted_tuples = 0;
	for (idx_t i = 0; i < count; i++) {
		if (deleted[rows[i]] == transaction_id) {
			// already deleted by this same transaction: skip silently
			continue;
		}
		// first check the chunk for conflicts
		if (deleted[rows[i]] != NOT_DELETED_ID) {
			// tuple was already deleted by another transaction
			throw TransactionException("Conflict on tuple deletion!");
		}
		// after verifying that there are no conflicts we mark the tuple as deleted
		deleted[rows[i]] = transaction_id;
		rows[deleted_tuples] = rows[i];
		deleted_tuples++;
	}
	return deleted_tuples;
}

// Replace the (uncommitted) transaction id in the delete markers with the commit id.
void ChunkVectorInfo::CommitDelete(transaction_t commit_id, row_t rows[], idx_t count) {
	for (idx_t i = 0; i < count; i++) {
		deleted[rows[i]] = commit_id;
	}
}

// Record an append of rows [start, end). The shared insert id fast path is kept
// only while every append in the vector uses the same commit id.
void ChunkVectorInfo::Append(idx_t start, idx_t end, transaction_t commit_id) {
	if (start == 0) {
		insert_id = commit_id;
	} else if (insert_id != commit_id) {
		same_inserted_id = false;
		insert_id = NOT_DELETED_ID;
	}
	for (idx_t i = start; i < end; i++) {
		inserted[i] = commit_id;
	}
}

void ChunkVectorInfo::CommitAppend(transaction_t commit_id, idx_t start, idx_t end) {
	if (same_inserted_id) {
		insert_id = commit_id;
	}
	for (idx_t i = start; i < end; i++) {
		inserted[i] = commit_id;
	}
}

bool ChunkVectorInfo::HasDeletes() const {
	return any_deleted;
}

// Count rows whose delete marker carries a committed id.
idx_t ChunkVectorInfo::GetCommittedDeletedCount(idx_t max_count) {
	if (!any_deleted) {
		return 0;
	}
	idx_t delete_count = 0;
	for (idx_t i = 0; i < max_count; i++) {
		if (deleted[i] < TRANSACTION_ID_START) {
			delete_count++;
		}
	}
	return delete_count;
}

// Serialize the committed delete set. The pseudo-transaction below (start just
// before TRANSACTION_ID_START, invalid id) selects exactly the committed rows.
// Fully-visible vectors are written as EMPTY_INFO and fully-deleted ones as
// CONSTANT_INFO; otherwise a validity mask flags the deleted rows.
void ChunkVectorInfo::Write(WriteStream &writer) const {
	SelectionVector sel(STANDARD_VECTOR_SIZE);
	transaction_t start_time = TRANSACTION_ID_START - 1;
	transaction_t transaction_id = DConstants::INVALID_INDEX;
	idx_t count = GetSelVector(start_time, transaction_id, sel, STANDARD_VECTOR_SIZE);
	if (count == STANDARD_VECTOR_SIZE) {
		// nothing is deleted: skip writing anything
		writer.Write(ChunkInfoType::EMPTY_INFO);
		return;
	}
	if (count == 0) {
		// everything is deleted: write a constant vector
		writer.Write(ChunkInfoType::CONSTANT_INFO);
		writer.Write(start);
		return;
	}
	// write a boolean vector
	ChunkInfo::Write(writer);
	writer.Write(start);
	ValidityMask mask(STANDARD_VECTOR_SIZE);
	mask.Initialize(STANDARD_VECTOR_SIZE);
	for (idx_t i = 0; i < count; i++) {
		mask.SetInvalid(sel.get_index(i));
	}
	mask.Write(writer, STANDARD_VECTOR_SIZE);
}

// Inverse of Write's mask case: valid mask bits mark rows that were deleted.
unique_ptr ChunkVectorInfo::Read(ReadStream &reader) {
	auto start = reader.Read();
	auto result = make_uniq(start);
	result->any_deleted = true;
	ValidityMask mask;
	mask.Read(reader, STANDARD_VECTOR_SIZE);
	for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) {
		if (mask.RowIsValid(i)) {
			result->deleted[i] = 0;
		}
	}
	return std::move(result);
}

} // namespace duckdb

namespace duckdb {

// Per-column state accumulated while checkpointing a row group.
ColumnCheckpointState::ColumnCheckpointState(RowGroup &row_group, ColumnData &column_data,
                                             PartialBlockManager &partial_block_manager)
    : row_group(row_group), column_data(column_data), partial_block_manager(partial_block_manager) {
}

ColumnCheckpointState::~ColumnCheckpointState() {
}

unique_ptr ColumnCheckpointState::GetStatistics() {
	D_ASSERT(global_stats);
	return std::move(global_stats);
}

// A partially-filled block being assembled during checkpointing; additional small
// segments can be appended until the block is flushed to disk.
PartialBlockForCheckpoint::PartialBlockForCheckpoint(ColumnData &data, ColumnSegment &segment, PartialBlockState state,
                                                     BlockManager &block_manager)
    : PartialBlock(state, block_manager, segment.block) {
	AddSegmentToTail(data, segment, 0);
}

PartialBlockForCheckpoint::~PartialBlockForCheckpoint() {
	D_ASSERT(IsFlushed() || Exception::UncaughtException());
}

bool PartialBlockForCheckpoint::IsFlushed() {
	// segments are cleared on Flush
	return segments.empty();
}

// Persist the assembled block: the first segment's write carries the data of all
// tail segments (they share the page), the rest are only re-pointed at the block.
void PartialBlockForCheckpoint::Flush(const idx_t free_space_left) {

	if (IsFlushed()) {
		throw InternalException("Flush called on partial block that was already flushed");
	}

	// zero-initialize unused memory
	FlushInternal(free_space_left);

	// At this point, we've already copied all data from tail_segments
	// into the page owned by first_segment. We flush all segment data to
	// disk with the following call.
	// persist the first segment to disk and point the remaining segments to the same block
	bool fetch_new_block = state.block_id == INVALID_BLOCK;
	if (fetch_new_block) {
		state.block_id = block_manager.GetFreeBlockId();
	}

	for (idx_t i = 0; i < segments.size(); i++) {
		auto &segment = segments[i];
		segment.data.IncrementVersion();
		if (i == 0) {
			// the first segment is converted to persistent - this writes the data for ALL segments to disk
			D_ASSERT(segment.offset_in_block == 0);
			segment.segment.ConvertToPersistent(&block_manager, state.block_id);
			// update the block after it has been converted to a persistent segment
			block_handle = segment.segment.block;
		} else {
			// subsequent segments are MARKED as persistent - they don't need to be rewritten
			segment.segment.MarkAsPersistent(block_handle, segment.offset_in_block);
			if (fetch_new_block) {
				// if we fetched a new block we need to increase the reference count to the block
				block_manager.IncreaseBlockReferenceCount(state.block_id);
			}
		}
	}

	Clear();
}

// Absorb another partial block's contents at byte offset 'offset' within this block.
void PartialBlockForCheckpoint::Merge(PartialBlock &other_p, idx_t offset, idx_t other_size) {
	auto &other = other_p.Cast();

	auto &buffer_manager = block_manager.buffer_manager;
	// pin the source block
	auto old_handle = buffer_manager.Pin(other.block_handle);
	// pin the target block
	auto new_handle = buffer_manager.Pin(block_handle);
	// memcpy the contents of the old block to the new block
	memcpy(new_handle.Ptr() + offset, old_handle.Ptr(), other_size);

	// now copy over all segments to the new block
	// move over the uninitialized regions
	for (auto &region : other.uninitialized_regions) {
		region.start += offset;
		region.end += offset;
		uninitialized_regions.push_back(region);
	}

	// move over the segments
	for (auto &segment : other.segments) {
		AddSegmentToTail(segment.data, segment.segment, segment.offset_in_block + offset);
	}

	other.Clear();
}

void PartialBlockForCheckpoint::AddSegmentToTail(ColumnData &data, ColumnSegment &segment, uint32_t offset_in_block) {
	segments.emplace_back(data, segment, offset_in_block);
}

void PartialBlockForCheckpoint::Clear() {
	uninitialized_regions.clear();
	block_handle.reset();
	segments.clear();
}

// Write one finished column segment during checkpointing: merge its stats, place
// its data in a (possibly shared) block, and record a DataPointer for the catalog.
// Constant segments are metadata-only and write no block data.
void ColumnCheckpointState::FlushSegment(unique_ptr segment, idx_t segment_size) {
	D_ASSERT(segment_size <= Storage::BLOCK_SIZE);
	auto tuple_count = segment->count.load();
	if (tuple_count == 0) { // LCOV_EXCL_START
		return;
	} // LCOV_EXCL_STOP

	// merge the segment stats into the global stats
	global_stats->Merge(segment->stats.statistics);

	// get the buffer of the segment and pin it
	auto &db = column_data.GetDatabase();
	auto &buffer_manager = BufferManager::GetBufferManager(db);
	block_id_t block_id = INVALID_BLOCK;
	uint32_t offset_in_block = 0;

	if (!segment->stats.statistics.IsConstant()) {
		// non-constant block
		PartialBlockAllocation allocation = partial_block_manager.GetBlockAllocation(segment_size);
		block_id = allocation.state.block_id;
		offset_in_block = allocation.state.offset;

		if (allocation.partial_block) {
			// Use an existing block.
			D_ASSERT(offset_in_block > 0);
			auto &pstate = allocation.partial_block->Cast();
			// pin the source block
			auto old_handle = buffer_manager.Pin(segment->block);
			// pin the target block
			auto new_handle = buffer_manager.Pin(pstate.block_handle);
			// memcpy the contents of the old block to the new block
			memcpy(new_handle.Ptr() + offset_in_block, old_handle.Ptr(), segment_size);
			pstate.AddSegmentToTail(column_data, *segment, offset_in_block);
		} else {
			// Create a new block for future reuse.
			if (segment->SegmentSize() != Storage::BLOCK_SIZE) {
				// the segment is smaller than the block size
				// allocate a new block and copy the data over
				D_ASSERT(segment->SegmentSize() < Storage::BLOCK_SIZE);
				segment->Resize(Storage::BLOCK_SIZE);
			}
			D_ASSERT(offset_in_block == 0);
			allocation.partial_block = make_uniq(column_data, *segment, allocation.state,
			                                     *allocation.block_manager);
		}
		// Writer will decide whether to reuse this block.
		partial_block_manager.RegisterPartialBlock(std::move(allocation));
	} else {
		// constant block: no need to write anything to disk besides the stats
		// set up the compression function to constant
		auto &config = DBConfig::GetConfig(db);
		segment->function =
		    *config.GetCompressionFunction(CompressionType::COMPRESSION_CONSTANT, segment->type.InternalType());
		segment->ConvertToPersistent(nullptr, INVALID_BLOCK);
	}

	// construct the data pointer
	DataPointer data_pointer(segment->stats.statistics.Copy());
	data_pointer.block_pointer.block_id = block_id;
	data_pointer.block_pointer.offset = offset_in_block;
	data_pointer.row_start = row_group.start;
	if (!data_pointers.empty()) {
		// rows continue where the previous segment left off
		auto &last_pointer = data_pointers.back();
		data_pointer.row_start = last_pointer.row_start + last_pointer.tuple_count;
	}
	data_pointer.tuple_count = tuple_count;
	data_pointer.compression_type = segment->function.get().type;
	if (segment->function.get().serialize_state) {
		data_pointer.segment_state = segment->function.get().serialize_state(*segment);
	}

	// append the segment to the new segment tree
	new_tree.AppendSegment(std::move(segment));
	data_pointers.push_back(std::move(data_pointer));
}

void ColumnCheckpointState::WriteDataPointers(RowGroupWriter &writer, Serializer &serializer) {
	writer.WriteColumnDataPointers(*this, serializer);
}

} // namespace duckdb

namespace duckdb {

// Column storage for a table column. Only root columns (no parent) own their own
// statistics; child columns roll up into the parent's.
ColumnData::ColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row,
                       LogicalType type_p, optional_ptr parent)
    : start(start_row), count(0), block_manager(block_manager), info(info), column_index(column_index),
      type(std::move(type_p)), parent(parent), version(0) {
	if (!parent) {
		stats = make_uniq(type);
	}
}

ColumnData::~ColumnData() {
}

// Rebase the column to a new starting row id, shifting every segment accordingly.
void ColumnData::SetStart(idx_t new_start) {
	this->start = new_start;
	idx_t offset = 0;
	for (auto &segment : data.Segments()) {
		segment.start = start + offset;
		offset += segment.count;
	}
	data.Reinitialize();
}

DatabaseInstance &ColumnData::GetDatabase() const {
	return info.db.GetDatabase();
}

DataTableInfo &ColumnData::GetTableInfo() const {
	return info;
}

// The logical type of the outermost ancestor column.
const LogicalType &ColumnData::RootType() const {
	if (parent) {
		return parent->RootType();
	}
	return type;
}

// Bump the version counter to invalidate any outstanding scan states.
void ColumnData::IncrementVersion() {
	version++;
}

idx_t ColumnData::GetMaxEntry() {
	return count;
}

// Position a scan state at the first segment of the column.
void ColumnData::InitializeScan(ColumnScanState &state) {
	state.current = data.GetRootSegment();
	state.segment_tree = &data;
	state.row_index = state.current ? state.current->start : 0;
	state.internal_index = state.row_index;
	state.initialized = false;
	state.version = version;
	state.scan_state.reset();
	state.last_offset = 0;
}

// Position a scan state at the segment containing row_idx.
void ColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) {
	state.current = data.GetSegment(row_idx);
	state.segment_tree = &data;
	state.row_index = row_idx;
	state.internal_index = state.current->start;
	state.initialized = false;
	state.version = version;
	state.scan_state.reset();
	state.last_offset = 0;
}

// Core scan loop: read up to 'remaining' rows into 'result', advancing across
// segment boundaries as needed. A version mismatch means the segment tree changed
// under this state, so the scan is re-anchored at the current row index first.
idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remaining, bool has_updates) {
	state.previous_states.clear();
	if (state.version != version) {
		InitializeScanWithOffset(state, state.row_index);
		state.current->InitializeScan(state);
		state.initialized = true;
	} else if (!state.initialized) {
		D_ASSERT(state.current);
		state.current->InitializeScan(state);
		state.internal_index = state.current->start;
		state.initialized = true;
	}
	D_ASSERT(data.HasSegment(state.current));
	D_ASSERT(state.version == version);
	D_ASSERT(state.internal_index <= state.row_index);
	if (state.internal_index < state.row_index) {
		state.current->Skip(state);
	}
	D_ASSERT(state.current->type == type);
	idx_t initial_remaining = remaining;
	while (remaining > 0) {
		D_ASSERT(state.row_index >= state.current->start &&
		         state.row_index <= state.current->start + state.current->count);
		idx_t scan_count = MinValue(remaining, state.current->start + state.current->count - state.row_index);
		idx_t result_offset = initial_remaining - remaining;
		if (scan_count > 0) {
			// the final flag enables zero-copy scans when the whole request fits
			// one segment and no updates need to be merged in afterwards
			state.current->Scan(state, scan_count, result, result_offset,
			                    !has_updates && scan_count == initial_remaining);

			state.row_index += scan_count;
			remaining -= scan_count;
		}

		if (remaining > 0) {
			auto next = data.GetNextSegment(state.current);
			if (!next) {
				break;
			}
			// keep the old segment's scan state alive until the scan completes
			state.previous_states.emplace_back(std::move(state.scan_state));
			state.current = next;
			state.current->InitializeScan(state);
			state.segment_checked = false;
			D_ASSERT(state.row_index >= state.current->start &&
			         state.row_index <= state.current->start + state.current->count);
		}
	}
	state.internal_index = state.row_index;
	return initial_remaining - remaining;
}

// Transactional scan: read the base data, then overlay updates according to the
// template flags (committed-only vs. transaction-local; updates allowed or not).
template
idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result) {
	bool has_updates;
	{
		lock_guard update_guard(update_lock);
		has_updates = updates ? true : false;
	}
	auto scan_count = ScanVector(state, result, STANDARD_VECTOR_SIZE, has_updates);
	if (has_updates) {
		lock_guard update_guard(update_lock);
		if (!ALLOW_UPDATES && updates->HasUncommittedUpdates(vector_index)) {
			throw TransactionException("Cannot create index with outstanding updates");
		}
		// updates are merged row-by-row, so the result must be flat
		result.Flatten(scan_count);
		if (SCAN_COMMITTED) {
			updates->FetchCommitted(vector_index, result);
		} else {
			updates->FetchUpdates(transaction, vector_index, result);
		}
	}
	return scan_count;
}

// explicit instantiations for all four committed/updates flag combinations
template idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index,
                                      ColumnScanState &state, Vector &result);
template idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index,
                                      ColumnScanState &state, Vector &result);
template idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index,
                                      ColumnScanState &state, Vector &result);
template idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index,
                                      ColumnScanState &state, Vector &result);

idx_t ColumnData::Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result) {
	return ScanVector(transaction, vector_index, state, result);
}

// Scan only committed data; a dummy TransactionData(0, 0) is used since committed
// scans do not depend on a real transaction.
idx_t ColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates) {
	if (allow_updates) {
		return ScanVector(TransactionData(0, 0), vector_index, state, result);
	} else {
		return
ScanVector(TransactionData(0, 0), vector_index, state, result); - } -} - -void ColumnData::ScanCommittedRange(idx_t row_group_start, idx_t offset_in_row_group, idx_t count, Vector &result) { - ColumnScanState child_state; - InitializeScanWithOffset(child_state, row_group_start + offset_in_row_group); - auto scan_count = ScanVector(child_state, result, count, updates ? true : false); - if (updates) { - result.Flatten(scan_count); - updates->FetchCommittedRange(offset_in_row_group, count, result); - } -} - -idx_t ColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count) { - if (count == 0) { - return 0; - } - // ScanCount can only be used if there are no updates - D_ASSERT(!updates); - return ScanVector(state, result, count, false); -} - -void ColumnData::Select(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, - SelectionVector &sel, idx_t &count, const TableFilter &filter) { - idx_t scan_count = Scan(transaction, vector_index, state, result); - result.Flatten(scan_count); - ColumnSegment::FilterSelection(sel, result, filter, count, FlatVector::Validity(result)); -} - -void ColumnData::FilterScan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, - SelectionVector &sel, idx_t count) { - Scan(transaction, vector_index, state, result); - result.Slice(sel, count); -} - -void ColumnData::FilterScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, SelectionVector &sel, - idx_t count, bool allow_updates) { - ScanCommitted(vector_index, state, result, allow_updates); - result.Slice(sel, count); -} - -void ColumnData::Skip(ColumnScanState &state, idx_t count) { - state.Next(count); -} - -void ColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) { - UnifiedVectorFormat vdata; - vector.ToUnifiedFormat(count, vdata); - AppendData(stats, state, vdata, count); -} - -void ColumnData::Append(ColumnAppendState &state, Vector 
&vector, idx_t count) { - if (parent || !stats) { - throw InternalException("ColumnData::Append called on a column with a parent or without stats"); - } - Append(stats->statistics, state, vector, count); -} - -bool ColumnData::CheckZonemap(TableFilter &filter) { - if (!stats) { - throw InternalException("ColumnData::CheckZonemap called on a column without stats"); - } - auto propagate_result = filter.CheckStatistics(stats->statistics); - if (propagate_result == FilterPropagateResult::FILTER_ALWAYS_FALSE || - propagate_result == FilterPropagateResult::FILTER_FALSE_OR_NULL) { - return false; - } - return true; -} - -unique_ptr ColumnData::GetStatistics() { - if (!stats) { - throw InternalException("ColumnData::GetStatistics called on a column without stats"); - } - return stats->statistics.ToUnique(); -} - -void ColumnData::MergeStatistics(const BaseStatistics &other) { - if (!stats) { - throw InternalException("ColumnData::MergeStatistics called on a column without stats"); - } - return stats->statistics.Merge(other); -} - -void ColumnData::MergeIntoStatistics(BaseStatistics &other) { - if (!stats) { - throw InternalException("ColumnData::MergeIntoStatistics called on a column without stats"); - } - return other.Merge(stats->statistics); -} - -void ColumnData::InitializeAppend(ColumnAppendState &state) { - auto l = data.Lock(); - if (data.IsEmpty(l)) { - // no segments yet, append an empty segment - AppendTransientSegment(l, start); - } - auto segment = data.GetLastSegment(l); - if (segment->segment_type == ColumnSegmentType::PERSISTENT || !segment->function.get().init_append) { - // we cannot append to this segment - append a new segment - auto total_rows = segment->start + segment->count; - AppendTransientSegment(l, total_rows); - state.current = data.GetLastSegment(l); - } else { - state.current = segment; - } - - D_ASSERT(state.current->segment_type == ColumnSegmentType::TRANSIENT); - state.current->InitializeAppend(state); - 
D_ASSERT(state.current->function.get().append); -} - -void ColumnData::AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) { - idx_t offset = 0; - this->count += count; - while (true) { - // append the data from the vector - idx_t copied_elements = state.current->Append(state, vdata, offset, count); - stats.Merge(state.current->stats.statistics); - if (copied_elements == count) { - // finished copying everything - break; - } - - // we couldn't fit everything we wanted in the current column segment, create a new one - { - auto l = data.Lock(); - AppendTransientSegment(l, state.current->start + state.current->count); - state.current = data.GetLastSegment(l); - state.current->InitializeAppend(state); - } - offset += copied_elements; - count -= copied_elements; - } -} - -void ColumnData::RevertAppend(row_t start_row) { - auto l = data.Lock(); - // check if this row is in the segment tree at all - auto last_segment = data.GetLastSegment(l); - if (idx_t(start_row) >= last_segment->start + last_segment->count) { - // the start row is equal to the final portion of the column data: nothing was ever appended here - D_ASSERT(idx_t(start_row) == last_segment->start + last_segment->count); - return; - } - // find the segment index that the current row belongs to - idx_t segment_index = data.GetSegmentIndex(l, start_row); - auto segment = data.GetSegmentByIndex(l, segment_index); - auto &transient = *segment; - D_ASSERT(transient.segment_type == ColumnSegmentType::TRANSIENT); - - // remove any segments AFTER this segment: they should be deleted entirely - data.EraseSegments(l, segment_index); - - this->count = start_row - this->start; - segment->next = nullptr; - transient.RevertAppend(start_row); -} - -idx_t ColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { - D_ASSERT(row_id >= 0); - D_ASSERT(idx_t(row_id) >= start); - // perform the fetch within the segment - state.row_index = start + ((row_id - start) / 
STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE); - state.current = data.GetSegment(state.row_index); - state.internal_index = state.current->start; - return ScanVector(state, result, STANDARD_VECTOR_SIZE, false); -} - -void ColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, - idx_t result_idx) { - auto segment = data.GetSegment(row_id); - - // now perform the fetch within the segment - segment->FetchRow(state, row_id, result, result_idx); - // merge any updates made to this row - lock_guard update_guard(update_lock); - if (updates) { - updates->FetchRow(transaction, row_id, result, result_idx); - } -} - -void ColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { - lock_guard update_guard(update_lock); - if (!updates) { - updates = make_uniq(*this); - } - Vector base_vector(type); - ColumnScanState state; - auto fetch_count = Fetch(state, row_ids[0], base_vector); - - base_vector.Flatten(fetch_count); - updates->Update(transaction, column_index, update_vector, row_ids, update_count, base_vector); -} - -void ColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) { - // this method should only be called at the end of the path in the base column case - D_ASSERT(depth >= column_path.size()); - ColumnData::Update(transaction, column_path[0], update_vector, row_ids, update_count); -} - -unique_ptr ColumnData::GetUpdateStatistics() { - lock_guard update_guard(update_lock); - return updates ? 
updates->GetStatistics() : nullptr; -} - -void ColumnData::AppendTransientSegment(SegmentLock &l, idx_t start_row) { - idx_t segment_size = Storage::BLOCK_SIZE; - if (start_row == idx_t(MAX_ROW_ID)) { -#if STANDARD_VECTOR_SIZE < 1024 - segment_size = 1024 * GetTypeIdSize(type.InternalType()); -#else - segment_size = STANDARD_VECTOR_SIZE * GetTypeIdSize(type.InternalType()); -#endif - } - auto new_segment = ColumnSegment::CreateTransientSegment(GetDatabase(), type, start_row, segment_size); - data.AppendSegment(l, std::move(new_segment)); -} - -void ColumnData::CommitDropColumn() { - for (auto &segment_p : data.Segments()) { - auto &segment = segment_p; - segment.CommitDropSegment(); - } -} - -unique_ptr ColumnData::CreateCheckpointState(RowGroup &row_group, - PartialBlockManager &partial_block_manager) { - return make_uniq(row_group, *this, partial_block_manager); -} - -void ColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, idx_t count, - Vector &scan_vector) { - segment.Scan(state, count, scan_vector, 0, true); - if (updates) { - scan_vector.Flatten(count); - updates->FetchCommittedRange(state.row_index - row_group_start, count, scan_vector); - } -} - -unique_ptr ColumnData::Checkpoint(RowGroup &row_group, - PartialBlockManager &partial_block_manager, - ColumnCheckpointInfo &checkpoint_info) { - // scan the segments of the column data - // set up the checkpoint state - auto checkpoint_state = CreateCheckpointState(row_group, partial_block_manager); - checkpoint_state->global_stats = BaseStatistics::CreateEmpty(type).ToUnique(); - - auto l = data.Lock(); - auto nodes = data.MoveSegments(l); - if (nodes.empty()) { - // empty table: flush the empty list - return checkpoint_state; - } - lock_guard update_guard(update_lock); - - ColumnDataCheckpointer checkpointer(*this, row_group, *checkpoint_state, checkpoint_info); - checkpointer.Checkpoint(std::move(nodes)); - - // replace the old tree with the new one - 
data.Replace(l, checkpoint_state->new_tree); - version++; - - return checkpoint_state; -} - -void ColumnData::DeserializeColumn(Deserializer &deserializer) { - // load the data pointers for the column - deserializer.Set(info.db.GetDatabase()); - deserializer.Set(type); - - vector data_pointers; - deserializer.ReadProperty(100, "data_pointers", data_pointers); - - deserializer.Unset(); - deserializer.Unset(); - - // construct the segments based on the data pointers - this->count = 0; - for (auto &data_pointer : data_pointers) { - // Update the count and statistics - this->count += data_pointer.tuple_count; - if (stats) { - stats->statistics.Merge(data_pointer.statistics); - } - - // create a persistent segment - auto segment = ColumnSegment::CreatePersistentSegment( - GetDatabase(), block_manager, data_pointer.block_pointer.block_id, data_pointer.block_pointer.offset, type, - data_pointer.row_start, data_pointer.tuple_count, data_pointer.compression_type, - std::move(data_pointer.statistics), std::move(data_pointer.segment_state)); - - data.AppendSegment(std::move(segment)); - } -} - -shared_ptr ColumnData::Deserialize(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, - idx_t start_row, ReadStream &source, const LogicalType &type, - optional_ptr parent) { - auto entry = ColumnData::CreateColumn(block_manager, info, column_index, start_row, type, parent); - BinaryDeserializer deserializer(source); - deserializer.Begin(); - entry->DeserializeColumn(deserializer); - deserializer.End(); - return entry; -} - -void ColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_path, - vector &result) { - D_ASSERT(!col_path.empty()); - - // convert the column path to a string - string col_path_str = "["; - for (idx_t i = 0; i < col_path.size(); i++) { - if (i > 0) { - col_path_str += ", "; - } - col_path_str += to_string(col_path[i]); - } - col_path_str += "]"; - - // iterate over the segments - idx_t segment_idx = 0; - auto segment = (ColumnSegment 
*)data.GetRootSegment(); - while (segment) { - ColumnSegmentInfo column_info; - column_info.row_group_index = row_group_index; - column_info.column_id = col_path[0]; - column_info.column_path = col_path_str; - column_info.segment_idx = segment_idx; - column_info.segment_type = type.ToString(); - column_info.segment_start = segment->start; - column_info.segment_count = segment->count; - column_info.compression_type = CompressionTypeToString(segment->function.get().type); - column_info.segment_stats = segment->stats.statistics.ToString(); - { - lock_guard ulock(update_lock); - column_info.has_updates = updates ? true : false; - } - // persistent - // block_id - // block_offset - if (segment->segment_type == ColumnSegmentType::PERSISTENT) { - column_info.persistent = true; - column_info.block_id = segment->GetBlockId(); - column_info.block_offset = segment->GetBlockOffset(); - } else { - column_info.persistent = false; - } - auto segment_state = segment->GetSegmentState(); - if (segment_state) { - column_info.segment_info = segment_state->GetSegmentInfo(); - } - result.emplace_back(column_info); - - segment_idx++; - segment = data.GetNextSegment(segment); - } -} - -void ColumnData::Verify(RowGroup &parent) { -#ifdef DEBUG - D_ASSERT(this->start == parent.start); - data.Verify(); - if (type.InternalType() == PhysicalType::STRUCT) { - // structs don't have segments - D_ASSERT(!data.GetRootSegment()); - return; - } - idx_t current_index = 0; - idx_t current_start = this->start; - idx_t total_count = 0; - for (auto &segment : data.Segments()) { - D_ASSERT(segment.index == current_index); - D_ASSERT(segment.start == current_start); - current_start += segment.count; - total_count += segment.count; - current_index++; - } - D_ASSERT(this->count == total_count); -#endif -} - -template -static RET CreateColumnInternal(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, - const LogicalType &type, optional_ptr parent) { - if (type.InternalType() 
== PhysicalType::STRUCT) { - return OP::template Create(block_manager, info, column_index, start_row, type, parent); - } else if (type.InternalType() == PhysicalType::LIST) { - return OP::template Create(block_manager, info, column_index, start_row, type, parent); - } else if (type.id() == LogicalTypeId::VALIDITY) { - return OP::template Create(block_manager, info, column_index, start_row, *parent); - } - return OP::template Create(block_manager, info, column_index, start_row, type, parent); -} - -shared_ptr ColumnData::CreateColumn(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, - idx_t start_row, const LogicalType &type, - optional_ptr parent) { - return CreateColumnInternal, SharedConstructor>(block_manager, info, column_index, start_row, - type, parent); -} - -unique_ptr ColumnData::CreateColumnUnique(BlockManager &block_manager, DataTableInfo &info, - idx_t column_index, idx_t start_row, const LogicalType &type, - optional_ptr parent) { - return CreateColumnInternal, UniqueConstructor>(block_manager, info, column_index, start_row, - type, parent); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -ColumnDataCheckpointer::ColumnDataCheckpointer(ColumnData &col_data_p, RowGroup &row_group_p, - ColumnCheckpointState &state_p, ColumnCheckpointInfo &checkpoint_info_p) - : col_data(col_data_p), row_group(row_group_p), state(state_p), - is_validity(GetType().id() == LogicalTypeId::VALIDITY), - intermediate(is_validity ? 
LogicalType::BOOLEAN : GetType(), true, is_validity), - checkpoint_info(checkpoint_info_p) { - auto &config = DBConfig::GetConfig(GetDatabase()); - auto functions = config.GetCompressionFunctions(GetType().InternalType()); - for (auto &func : functions) { - compression_functions.push_back(&func.get()); - } -} - -DatabaseInstance &ColumnDataCheckpointer::GetDatabase() { - return col_data.GetDatabase(); -} - -const LogicalType &ColumnDataCheckpointer::GetType() const { - return col_data.type; -} - -ColumnData &ColumnDataCheckpointer::GetColumnData() { - return col_data; -} - -RowGroup &ColumnDataCheckpointer::GetRowGroup() { - return row_group; -} - -ColumnCheckpointState &ColumnDataCheckpointer::GetCheckpointState() { - return state; -} - -void ColumnDataCheckpointer::ScanSegments(const std::function &callback) { - Vector scan_vector(intermediate.GetType(), nullptr); - for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { - auto &segment = *nodes[segment_idx].node; - ColumnScanState scan_state; - scan_state.current = &segment; - segment.InitializeScan(scan_state); - - for (idx_t base_row_index = 0; base_row_index < segment.count; base_row_index += STANDARD_VECTOR_SIZE) { - scan_vector.Reference(intermediate); - - idx_t count = MinValue(segment.count - base_row_index, STANDARD_VECTOR_SIZE); - scan_state.row_index = segment.start + base_row_index; - - col_data.CheckpointScan(segment, scan_state, row_group.start, count, scan_vector); - - callback(scan_vector, count); - } - } -} - -CompressionType ForceCompression(vector> &compression_functions, - CompressionType compression_type) { - // On of the force_compression flags has been set - // check if this compression method is available - bool found = false; - for (idx_t i = 0; i < compression_functions.size(); i++) { - auto &compression_function = *compression_functions[i]; - if (compression_function.type == compression_type) { - found = true; - break; - } - } - if (found) { - // the force_compression 
method is available - // clear all other compression methods - // except the uncompressed method, so we can fall back on that - for (idx_t i = 0; i < compression_functions.size(); i++) { - auto &compression_function = *compression_functions[i]; - if (compression_function.type == CompressionType::COMPRESSION_UNCOMPRESSED) { - continue; - } - if (compression_function.type != compression_type) { - compression_functions[i] = nullptr; - } - } - } - return found ? compression_type : CompressionType::COMPRESSION_AUTO; -} - -unique_ptr ColumnDataCheckpointer::DetectBestCompressionMethod(idx_t &compression_idx) { - D_ASSERT(!compression_functions.empty()); - auto &config = DBConfig::GetConfig(GetDatabase()); - CompressionType forced_method = CompressionType::COMPRESSION_AUTO; - - auto compression_type = checkpoint_info.compression_type; - if (compression_type != CompressionType::COMPRESSION_AUTO) { - forced_method = ForceCompression(compression_functions, compression_type); - } - if (compression_type == CompressionType::COMPRESSION_AUTO && - config.options.force_compression != CompressionType::COMPRESSION_AUTO) { - forced_method = ForceCompression(compression_functions, config.options.force_compression); - } - // set up the analyze states for each compression method - vector> analyze_states; - analyze_states.reserve(compression_functions.size()); - for (idx_t i = 0; i < compression_functions.size(); i++) { - if (!compression_functions[i]) { - analyze_states.push_back(nullptr); - continue; - } - analyze_states.push_back(compression_functions[i]->init_analyze(col_data, col_data.type.InternalType())); - } - - // scan over all the segments and run the analyze step - ScanSegments([&](Vector &scan_vector, idx_t count) { - for (idx_t i = 0; i < compression_functions.size(); i++) { - if (!compression_functions[i]) { - continue; - } - auto success = compression_functions[i]->analyze(*analyze_states[i], scan_vector, count); - if (!success) { - // could not use this compression 
function on this data set - // erase it - compression_functions[i] = nullptr; - analyze_states[i].reset(); - } - } - }); - - // now that we have passed over all the data, we need to figure out the best method - // we do this using the final_analyze method - unique_ptr state; - compression_idx = DConstants::INVALID_INDEX; - idx_t best_score = NumericLimits::Maximum(); - for (idx_t i = 0; i < compression_functions.size(); i++) { - if (!compression_functions[i]) { - continue; - } - //! Check if the method type is the forced method (if forced is used) - bool forced_method_found = compression_functions[i]->type == forced_method; - auto score = compression_functions[i]->final_analyze(*analyze_states[i]); - - //! The finalize method can return this value from final_analyze to indicate it should not be used. - if (score == DConstants::INVALID_INDEX) { - continue; - } - - if (score < best_score || forced_method_found) { - compression_idx = i; - best_score = score; - state = std::move(analyze_states[i]); - } - //! 
If we have found the forced method, we're done - if (forced_method_found) { - break; - } - } - return state; -} - -void ColumnDataCheckpointer::WriteToDisk() { - // there were changes or transient segments - // we need to rewrite the column segments to disk - - // first we check the current segments - // if there are any persistent segments, we will mark their old block ids as modified - // since the segments will be rewritten their old on disk data is no longer required - for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { - auto segment = nodes[segment_idx].node.get(); - segment->CommitDropSegment(); - } - - // now we need to write our segment - // we will first run an analyze step that determines which compression function to use - idx_t compression_idx; - auto analyze_state = DetectBestCompressionMethod(compression_idx); - - if (!analyze_state) { - throw FatalException("No suitable compression/storage method found to store column"); - } - - // now that we have analyzed the compression functions we can start writing to disk - auto best_function = compression_functions[compression_idx]; - auto compress_state = best_function->init_compression(*this, std::move(analyze_state)); - ScanSegments( - [&](Vector &scan_vector, idx_t count) { best_function->compress(*compress_state, scan_vector, count); }); - best_function->compress_finalize(*compress_state); - - nodes.clear(); -} - -bool ColumnDataCheckpointer::HasChanges() { - for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { - auto segment = nodes[segment_idx].node.get(); - if (segment->segment_type == ColumnSegmentType::TRANSIENT) { - // transient segment: always need to write to disk - return true; - } else { - // persistent segment; check if there were any updates or deletions in this segment - idx_t start_row_idx = segment->start - row_group.start; - idx_t end_row_idx = start_row_idx + segment->count; - if (col_data.updates && col_data.updates->HasUpdates(start_row_idx, 
end_row_idx)) { - return true; - } - } - } - return false; -} - -void ColumnDataCheckpointer::WritePersistentSegments() { - // all segments are persistent and there are no updates - // we only need to write the metadata - for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { - auto segment = nodes[segment_idx].node.get(); - D_ASSERT(segment->segment_type == ColumnSegmentType::PERSISTENT); - - // set up the data pointer directly using the data from the persistent segment - DataPointer pointer(segment->stats.statistics.Copy()); - pointer.block_pointer.block_id = segment->GetBlockId(); - pointer.block_pointer.offset = segment->GetBlockOffset(); - pointer.row_start = segment->start; - pointer.tuple_count = segment->count; - pointer.compression_type = segment->function.get().type; - if (segment->function.get().serialize_state) { - pointer.segment_state = segment->function.get().serialize_state(*segment); - } - - // merge the persistent stats into the global column stats - state.global_stats->Merge(segment->stats.statistics); - - // directly append the current segment to the new tree - state.new_tree.AppendSegment(std::move(nodes[segment_idx].node)); - - state.data_pointers.push_back(std::move(pointer)); - } -} - -void ColumnDataCheckpointer::Checkpoint(vector> nodes_p) { - D_ASSERT(!nodes_p.empty()); - this->nodes = std::move(nodes_p); - // first check if any of the segments have changes - if (!HasChanges()) { - // no changes: only need to write the metadata for this column - WritePersistentSegments(); - } else { - // there are changes: rewrite the set of columns); - WriteToDisk(); - } -} - -CompressionFunction &ColumnDataCheckpointer::GetCompressionFunction(CompressionType compression_type) { - auto &db = GetDatabase(); - auto &column_type = GetType(); - auto &config = DBConfig::GetConfig(db); - return *config.GetCompressionFunction(compression_type, column_type.InternalType()); -} - -} // namespace duckdb - - - - - - - - - - - - - -#include - 
-namespace duckdb { - -unique_ptr ColumnSegment::CreatePersistentSegment(DatabaseInstance &db, BlockManager &block_manager, - block_id_t block_id, idx_t offset, - const LogicalType &type, idx_t start, idx_t count, - CompressionType compression_type, - BaseStatistics statistics, - unique_ptr segment_state) { - auto &config = DBConfig::GetConfig(db); - optional_ptr function; - shared_ptr block; - if (block_id == INVALID_BLOCK) { - // constant segment, no need to allocate an actual block - function = config.GetCompressionFunction(CompressionType::COMPRESSION_CONSTANT, type.InternalType()); - } else { - function = config.GetCompressionFunction(compression_type, type.InternalType()); - block = block_manager.RegisterBlock(block_id); - } - auto segment_size = Storage::BLOCK_SIZE; - return make_uniq(db, std::move(block), type, ColumnSegmentType::PERSISTENT, start, count, *function, - std::move(statistics), block_id, offset, segment_size, std::move(segment_state)); -} - -unique_ptr ColumnSegment::CreateTransientSegment(DatabaseInstance &db, const LogicalType &type, - idx_t start, idx_t segment_size) { - auto &config = DBConfig::GetConfig(db); - auto function = config.GetCompressionFunction(CompressionType::COMPRESSION_UNCOMPRESSED, type.InternalType()); - auto &buffer_manager = BufferManager::GetBufferManager(db); - shared_ptr block; - // transient: allocate a buffer for the uncompressed segment - if (segment_size < Storage::BLOCK_SIZE) { - block = buffer_manager.RegisterSmallMemory(segment_size); - } else { - buffer_manager.Allocate(segment_size, false, &block); - } - return make_uniq(db, std::move(block), type, ColumnSegmentType::TRANSIENT, start, 0, *function, - BaseStatistics::CreateEmpty(type), INVALID_BLOCK, 0, segment_size); -} - -unique_ptr ColumnSegment::CreateSegment(ColumnSegment &other, idx_t start) { - return make_uniq(other, start); -} - -ColumnSegment::ColumnSegment(DatabaseInstance &db, shared_ptr block, LogicalType type_p, - ColumnSegmentType segment_type, 
idx_t start, idx_t count, CompressionFunction &function_p, - BaseStatistics statistics, block_id_t block_id_p, idx_t offset_p, idx_t segment_size_p, - unique_ptr segment_state) - : SegmentBase(start, count), db(db), type(std::move(type_p)), - type_size(GetTypeIdSize(type.InternalType())), segment_type(segment_type), function(function_p), - stats(std::move(statistics)), block(std::move(block)), block_id(block_id_p), offset(offset_p), - segment_size(segment_size_p) { - if (function.get().init_segment) { - this->segment_state = function.get().init_segment(*this, block_id, segment_state.get()); - } -} - -ColumnSegment::ColumnSegment(ColumnSegment &other, idx_t start) - : SegmentBase(start, other.count.load()), db(other.db), type(std::move(other.type)), - type_size(other.type_size), segment_type(other.segment_type), function(other.function), - stats(std::move(other.stats)), block(std::move(other.block)), block_id(other.block_id), offset(other.offset), - segment_size(other.segment_size), segment_state(std::move(other.segment_state)) { -} - -ColumnSegment::~ColumnSegment() { -} - -//===--------------------------------------------------------------------===// -// Scan -//===--------------------------------------------------------------------===// -void ColumnSegment::InitializeScan(ColumnScanState &state) { - state.scan_state = function.get().init_scan(*this); -} - -void ColumnSegment::Scan(ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset, - bool entire_vector) { - if (entire_vector) { - D_ASSERT(result_offset == 0); - Scan(state, scan_count, result); - } else { - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - ScanPartial(state, scan_count, result, result_offset); - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - } -} - -void ColumnSegment::Skip(ColumnScanState &state) { - function.get().skip(*this, state, state.row_index - state.internal_index); - state.internal_index = state.row_index; -} - -void 
ColumnSegment::Scan(ColumnScanState &state, idx_t scan_count, Vector &result) { - function.get().scan_vector(*this, state, scan_count, result); -} - -void ColumnSegment::ScanPartial(ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset) { - function.get().scan_partial(*this, state, scan_count, result, result_offset); -} - -//===--------------------------------------------------------------------===// -// Fetch -//===--------------------------------------------------------------------===// -void ColumnSegment::FetchRow(ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - function.get().fetch_row(*this, state, row_id - this->start, result, result_idx); -} - -//===--------------------------------------------------------------------===// -// Append -//===--------------------------------------------------------------------===// -idx_t ColumnSegment::SegmentSize() const { - return segment_size; -} - -void ColumnSegment::Resize(idx_t new_size) { - D_ASSERT(new_size > this->segment_size); - D_ASSERT(offset == 0); - auto &buffer_manager = BufferManager::GetBufferManager(db); - auto old_handle = buffer_manager.Pin(block); - shared_ptr new_block; - auto new_handle = buffer_manager.Allocate(Storage::BLOCK_SIZE, false, &new_block); - memcpy(new_handle.Ptr(), old_handle.Ptr(), segment_size); - this->block_id = new_block->BlockId(); - this->block = std::move(new_block); - this->segment_size = new_size; -} - -void ColumnSegment::InitializeAppend(ColumnAppendState &state) { - D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); - if (!function.get().init_append) { - throw InternalException("Attempting to init append to a segment without init_append method"); - } - state.append_state = function.get().init_append(*this); -} - -idx_t ColumnSegment::Append(ColumnAppendState &state, UnifiedVectorFormat &append_data, idx_t offset, idx_t count) { - D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); - if (!function.get().append) { - 
throw InternalException("Attempting to append to a segment without append method"); - } - return function.get().append(*state.append_state, *this, stats, append_data, offset, count); -} - -idx_t ColumnSegment::FinalizeAppend(ColumnAppendState &state) { - D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); - if (!function.get().finalize_append) { - throw InternalException("Attempting to call FinalizeAppend on a segment without a finalize_append method"); - } - auto result_count = function.get().finalize_append(*this, stats); - state.append_state.reset(); - return result_count; -} - -void ColumnSegment::RevertAppend(idx_t start_row) { - D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); - if (function.get().revert_append) { - function.get().revert_append(*this, start_row); - } - this->count = start_row - this->start; -} - -//===--------------------------------------------------------------------===// -// Convert To Persistent -//===--------------------------------------------------------------------===// -void ColumnSegment::ConvertToPersistent(optional_ptr block_manager, block_id_t block_id_p) { - D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); - segment_type = ColumnSegmentType::PERSISTENT; - - block_id = block_id_p; - offset = 0; - - if (block_id == INVALID_BLOCK) { - // constant block: reset the block buffer - D_ASSERT(stats.statistics.IsConstant()); - block.reset(); - } else { - D_ASSERT(!stats.statistics.IsConstant()); - // non-constant block: write the block to disk - // the data for the block already exists in-memory of our block - // instead of copying the data we alter some metadata so the buffer points to an on-disk block - block = block_manager->ConvertToPersistent(block_id, std::move(block)); - } -} - -void ColumnSegment::MarkAsPersistent(shared_ptr block_p, uint32_t offset_p) { - D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); - segment_type = ColumnSegmentType::PERSISTENT; - - block_id = block_p->BlockId(); - offset = offset_p; 
- block = std::move(block_p); -} - -//===--------------------------------------------------------------------===// -// Drop Segment -//===--------------------------------------------------------------------===// -void ColumnSegment::CommitDropSegment() { - if (segment_type != ColumnSegmentType::PERSISTENT) { - // not persistent - return; - } - if (block_id != INVALID_BLOCK) { - GetBlockManager().MarkBlockAsModified(block_id); - } - if (function.get().cleanup_state) { - function.get().cleanup_state(*this); - } -} - -//===--------------------------------------------------------------------===// -// Filter Selection -//===--------------------------------------------------------------------===// -template -static idx_t TemplatedFilterSelection(T *vec, T predicate, SelectionVector &sel, idx_t approved_tuple_count, - ValidityMask &mask, SelectionVector &result_sel) { - idx_t result_count = 0; - for (idx_t i = 0; i < approved_tuple_count; i++) { - auto idx = sel.get_index(i); - if ((!HAS_NULL || mask.RowIsValid(idx)) && OP::Operation(vec[idx], predicate)) { - result_sel.set_index(result_count++, idx); - } - } - return result_count; -} - -template -static void FilterSelectionSwitch(T *vec, T predicate, SelectionVector &sel, idx_t &approved_tuple_count, - ExpressionType comparison_type, ValidityMask &mask) { - SelectionVector new_sel(approved_tuple_count); - // the inplace loops take the result as the last parameter - switch (comparison_type) { - case ExpressionType::COMPARE_EQUAL: { - if (mask.AllValid()) { - approved_tuple_count = - TemplatedFilterSelection(vec, predicate, sel, approved_tuple_count, mask, new_sel); - } else { - approved_tuple_count = - TemplatedFilterSelection(vec, predicate, sel, approved_tuple_count, mask, new_sel); - } - break; - } - case ExpressionType::COMPARE_NOTEQUAL: { - if (mask.AllValid()) { - approved_tuple_count = - TemplatedFilterSelection(vec, predicate, sel, approved_tuple_count, mask, new_sel); - } else { - approved_tuple_count = - 
TemplatedFilterSelection(vec, predicate, sel, approved_tuple_count, mask, new_sel); - } - break; - } - case ExpressionType::COMPARE_LESSTHAN: { - if (mask.AllValid()) { - approved_tuple_count = - TemplatedFilterSelection(vec, predicate, sel, approved_tuple_count, mask, new_sel); - } else { - approved_tuple_count = - TemplatedFilterSelection(vec, predicate, sel, approved_tuple_count, mask, new_sel); - } - break; - } - case ExpressionType::COMPARE_GREATERTHAN: { - if (mask.AllValid()) { - approved_tuple_count = TemplatedFilterSelection(vec, predicate, sel, - approved_tuple_count, mask, new_sel); - } else { - approved_tuple_count = TemplatedFilterSelection(vec, predicate, sel, - approved_tuple_count, mask, new_sel); - } - break; - } - case ExpressionType::COMPARE_LESSTHANOREQUALTO: { - if (mask.AllValid()) { - approved_tuple_count = TemplatedFilterSelection( - vec, predicate, sel, approved_tuple_count, mask, new_sel); - } else { - approved_tuple_count = TemplatedFilterSelection( - vec, predicate, sel, approved_tuple_count, mask, new_sel); - } - break; - } - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: { - if (mask.AllValid()) { - approved_tuple_count = TemplatedFilterSelection( - vec, predicate, sel, approved_tuple_count, mask, new_sel); - } else { - approved_tuple_count = TemplatedFilterSelection( - vec, predicate, sel, approved_tuple_count, mask, new_sel); - } - break; - } - default: - throw NotImplementedException("Unknown comparison type for filter pushed down to table!"); - } - sel.Initialize(new_sel); -} - -template -static idx_t TemplatedNullSelection(SelectionVector &sel, idx_t &approved_tuple_count, ValidityMask &mask) { - if (mask.AllValid()) { - // no NULL values - if (IS_NULL) { - approved_tuple_count = 0; - return 0; - } else { - return approved_tuple_count; - } - } else { - SelectionVector result_sel(approved_tuple_count); - idx_t result_count = 0; - for (idx_t i = 0; i < approved_tuple_count; i++) { - auto idx = sel.get_index(i); - if 
(mask.RowIsValid(idx) != IS_NULL) { - result_sel.set_index(result_count++, idx); - } - } - sel.Initialize(result_sel); - approved_tuple_count = result_count; - return result_count; - } -} - -idx_t ColumnSegment::FilterSelection(SelectionVector &sel, Vector &result, const TableFilter &filter, - idx_t &approved_tuple_count, ValidityMask &mask) { - switch (filter.filter_type) { - case TableFilterType::CONJUNCTION_OR: { - // similar to the CONJUNCTION_AND, but we need to take care of the SelectionVectors (OR all of them) - idx_t count_total = 0; - SelectionVector result_sel(approved_tuple_count); - auto &conjunction_or = filter.Cast(); - for (auto &child_filter : conjunction_or.child_filters) { - SelectionVector temp_sel; - temp_sel.Initialize(sel); - idx_t temp_tuple_count = approved_tuple_count; - idx_t temp_count = FilterSelection(temp_sel, result, *child_filter, temp_tuple_count, mask); - // tuples passed, move them into the actual result vector - for (idx_t i = 0; i < temp_count; i++) { - auto new_idx = temp_sel.get_index(i); - bool is_new_idx = true; - for (idx_t res_idx = 0; res_idx < count_total; res_idx++) { - if (result_sel.get_index(res_idx) == new_idx) { - is_new_idx = false; - break; - } - } - if (is_new_idx) { - result_sel.set_index(count_total++, new_idx); - } - } - } - sel.Initialize(result_sel); - approved_tuple_count = count_total; - return approved_tuple_count; - } - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction_and = filter.Cast(); - for (auto &child_filter : conjunction_and.child_filters) { - FilterSelection(sel, result, *child_filter, approved_tuple_count, mask); - } - return approved_tuple_count; - } - case TableFilterType::CONSTANT_COMPARISON: { - auto &constant_filter = filter.Cast(); - // the inplace loops take the result as the last parameter - switch (result.GetType().InternalType()) { - case PhysicalType::UINT8: { - auto result_flat = FlatVector::GetData(result); - auto predicate = 
UTinyIntValue::Get(constant_filter.constant); - FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - case PhysicalType::UINT16: { - auto result_flat = FlatVector::GetData(result); - auto predicate = USmallIntValue::Get(constant_filter.constant); - FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - case PhysicalType::UINT32: { - auto result_flat = FlatVector::GetData(result); - auto predicate = UIntegerValue::Get(constant_filter.constant); - FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - case PhysicalType::UINT64: { - auto result_flat = FlatVector::GetData(result); - auto predicate = UBigIntValue::Get(constant_filter.constant); - FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - case PhysicalType::INT8: { - auto result_flat = FlatVector::GetData(result); - auto predicate = TinyIntValue::Get(constant_filter.constant); - FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - case PhysicalType::INT16: { - auto result_flat = FlatVector::GetData(result); - auto predicate = SmallIntValue::Get(constant_filter.constant); - FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - case PhysicalType::INT32: { - auto result_flat = FlatVector::GetData(result); - auto predicate = IntegerValue::Get(constant_filter.constant); - FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - case PhysicalType::INT64: { - auto result_flat = FlatVector::GetData(result); - auto predicate = BigIntValue::Get(constant_filter.constant); - 
FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - case PhysicalType::INT128: { - auto result_flat = FlatVector::GetData(result); - auto predicate = HugeIntValue::Get(constant_filter.constant); - FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - case PhysicalType::FLOAT: { - auto result_flat = FlatVector::GetData(result); - auto predicate = FloatValue::Get(constant_filter.constant); - FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - case PhysicalType::DOUBLE: { - auto result_flat = FlatVector::GetData(result); - auto predicate = DoubleValue::Get(constant_filter.constant); - FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - case PhysicalType::VARCHAR: { - auto result_flat = FlatVector::GetData(result); - auto predicate = string_t(StringValue::Get(constant_filter.constant)); - FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - case PhysicalType::BOOL: { - auto result_flat = FlatVector::GetData(result); - auto predicate = BooleanValue::Get(constant_filter.constant); - FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, - constant_filter.comparison_type, mask); - break; - } - default: - throw InvalidTypeException(result.GetType(), "Invalid type for filter pushed down to table comparison"); - } - return approved_tuple_count; - } - case TableFilterType::IS_NULL: - return TemplatedNullSelection(sel, approved_tuple_count, mask); - case TableFilterType::IS_NOT_NULL: - return TemplatedNullSelection(sel, approved_tuple_count, mask); - default: - throw InternalException("FIXME: unsupported type for filter selection"); - } -} - -} // namespace duckdb - - - - - - - - 
-namespace duckdb { - -ListColumnData::ListColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, - LogicalType type_p, optional_ptr parent) - : ColumnData(block_manager, info, column_index, start_row, std::move(type_p), parent), - validity(block_manager, info, 0, start_row, *this) { - D_ASSERT(type.InternalType() == PhysicalType::LIST); - auto &child_type = ListType::GetChildType(type); - // the child column, with column index 1 (0 is the validity mask) - child_column = ColumnData::CreateColumnUnique(block_manager, info, 1, start_row, child_type, this); -} - -void ListColumnData::SetStart(idx_t new_start) { - ColumnData::SetStart(new_start); - child_column->SetStart(new_start); - validity.SetStart(new_start); -} - -bool ListColumnData::CheckZonemap(ColumnScanState &state, TableFilter &filter) { - // table filters are not supported yet for list columns - return false; -} - -void ListColumnData::InitializeScan(ColumnScanState &state) { - ColumnData::InitializeScan(state); - - // initialize the validity segment - D_ASSERT(state.child_states.size() == 2); - validity.InitializeScan(state.child_states[0]); - - // initialize the child scan - child_column->InitializeScan(state.child_states[1]); -} - -uint64_t ListColumnData::FetchListOffset(idx_t row_idx) { - auto segment = data.GetSegment(row_idx); - ColumnFetchState fetch_state; - Vector result(type, 1); - segment->FetchRow(fetch_state, row_idx, result, 0); - - // initialize the child scan with the required offset - return FlatVector::GetData(result)[0]; -} - -void ListColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) { - if (row_idx == 0) { - InitializeScan(state); - return; - } - ColumnData::InitializeScanWithOffset(state, row_idx); - - // initialize the validity segment - D_ASSERT(state.child_states.size() == 2); - validity.InitializeScanWithOffset(state.child_states[0], row_idx); - - // we need to read the list at position row_idx to get the 
correct row offset of the child - auto child_offset = row_idx == start ? 0 : FetchListOffset(row_idx - 1); - D_ASSERT(child_offset <= child_column->GetMaxEntry()); - if (child_offset < child_column->GetMaxEntry()) { - child_column->InitializeScanWithOffset(state.child_states[1], start + child_offset); - } - state.last_offset = child_offset; -} - -idx_t ListColumnData::Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result) { - return ScanCount(state, result, STANDARD_VECTOR_SIZE); -} - -idx_t ListColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates) { - return ScanCount(state, result, STANDARD_VECTOR_SIZE); -} - -idx_t ListColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count) { - if (count == 0) { - return 0; - } - // updates not supported for lists - D_ASSERT(!updates); - - Vector offset_vector(LogicalType::UBIGINT, count); - idx_t scan_count = ScanVector(state, offset_vector, count, false); - D_ASSERT(scan_count > 0); - validity.ScanCount(state.child_states[0], result, count); - - UnifiedVectorFormat offsets; - offset_vector.ToUnifiedFormat(scan_count, offsets); - auto data = UnifiedVectorFormat::GetData(offsets); - auto last_entry = data[offsets.sel->get_index(scan_count - 1)]; - - // shift all offsets so they are 0 at the first entry - auto result_data = FlatVector::GetData(result); - auto base_offset = state.last_offset; - idx_t current_offset = 0; - for (idx_t i = 0; i < scan_count; i++) { - auto offset_index = offsets.sel->get_index(i); - result_data[i].offset = current_offset; - result_data[i].length = data[offset_index] - current_offset - base_offset; - current_offset += result_data[i].length; - } - - D_ASSERT(last_entry >= base_offset); - idx_t child_scan_count = last_entry - base_offset; - ListVector::Reserve(result, child_scan_count); - - if (child_scan_count > 0) { - auto &child_entry = ListVector::GetEntry(result); - if 
(child_entry.GetType().InternalType() != PhysicalType::STRUCT && - state.child_states[1].row_index + child_scan_count > child_column->start + child_column->GetMaxEntry()) { - throw InternalException("ListColumnData::ScanCount - internal list scan offset is out of range"); - } - child_column->ScanCount(state.child_states[1], child_entry, child_scan_count); - } - state.last_offset = last_entry; - - ListVector::SetListSize(result, child_scan_count); - return scan_count; -} - -void ListColumnData::Skip(ColumnScanState &state, idx_t count) { - // skip inside the validity segment - validity.Skip(state.child_states[0], count); - - // we need to read the list entries/offsets to figure out how much to skip - // note that we only need to read the first and last entry - // however, let's just read all "count" entries for now - Vector result(LogicalType::UBIGINT, count); - idx_t scan_count = ScanVector(state, result, count, false); - if (scan_count == 0) { - return; - } - - auto data = FlatVector::GetData(result); - auto last_entry = data[scan_count - 1]; - idx_t child_scan_count = last_entry - state.last_offset; - if (child_scan_count == 0) { - return; - } - state.last_offset = last_entry; - - // skip the child state forward by the child_scan_count - child_column->Skip(state.child_states[1], child_scan_count); -} - -void ListColumnData::InitializeAppend(ColumnAppendState &state) { - // initialize the list offset append - ColumnData::InitializeAppend(state); - - // initialize the validity append - ColumnAppendState validity_append_state; - validity.InitializeAppend(validity_append_state); - state.child_appends.push_back(std::move(validity_append_state)); - - // initialize the child column append - ColumnAppendState child_append_state; - child_column->InitializeAppend(child_append_state); - state.child_appends.push_back(std::move(child_append_state)); -} - -void ListColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) { - 
D_ASSERT(count > 0); - UnifiedVectorFormat list_data; - vector.ToUnifiedFormat(count, list_data); - auto &list_validity = list_data.validity; - - // construct the list_entry_t entries to append to the column data - auto input_offsets = UnifiedVectorFormat::GetData(list_data); - auto start_offset = child_column->GetMaxEntry(); - idx_t child_count = 0; - - ValidityMask append_mask(count); - auto append_offsets = unique_ptr(new uint64_t[count]); - bool child_contiguous = true; - for (idx_t i = 0; i < count; i++) { - auto input_idx = list_data.sel->get_index(i); - if (list_validity.RowIsValid(input_idx)) { - auto &input_list = input_offsets[input_idx]; - if (input_list.offset != child_count) { - child_contiguous = false; - } - append_offsets[i] = start_offset + child_count + input_list.length; - child_count += input_list.length; - } else { - append_mask.SetInvalid(i); - append_offsets[i] = start_offset + child_count; - } - } - auto &list_child = ListVector::GetEntry(vector); - Vector child_vector(list_child); - if (!child_contiguous) { - // if the child of the list vector is a non-contiguous vector (i.e. 
list elements are repeating or have gaps) - // we first push a selection vector and flatten the child vector to turn it into a contiguous vector - SelectionVector child_sel(child_count); - idx_t current_count = 0; - for (idx_t i = 0; i < count; i++) { - auto input_idx = list_data.sel->get_index(i); - if (list_validity.RowIsValid(input_idx)) { - auto &input_list = input_offsets[input_idx]; - for (idx_t list_idx = 0; list_idx < input_list.length; list_idx++) { - child_sel.set_index(current_count++, input_list.offset + list_idx); - } - } - } - D_ASSERT(current_count == child_count); - child_vector.Slice(list_child, child_sel, child_count); - } - - UnifiedVectorFormat vdata; - vdata.sel = FlatVector::IncrementalSelectionVector(); - vdata.data = data_ptr_cast(append_offsets.get()); - - // append the list offsets - ColumnData::AppendData(stats, state, vdata, count); - // append the validity data - vdata.validity = append_mask; - validity.AppendData(stats, state.child_appends[0], vdata, count); - // append the child vector - if (child_count > 0) { - child_column->Append(ListStats::GetChildStats(stats), state.child_appends[1], child_vector, child_count); - } -} - -void ListColumnData::RevertAppend(row_t start_row) { - ColumnData::RevertAppend(start_row); - validity.RevertAppend(start_row); - auto column_count = GetMaxEntry(); - if (column_count > start) { - // revert append in the child column - auto list_offset = FetchListOffset(column_count - 1); - child_column->RevertAppend(list_offset); - } -} - -idx_t ListColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { - throw NotImplementedException("List Fetch"); -} - -void ListColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { - throw NotImplementedException("List Update is not supported."); -} - -void ListColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, 
idx_t update_count, idx_t depth) { - throw NotImplementedException("List Update Column is not supported"); -} - -unique_ptr ListColumnData::GetUpdateStatistics() { - return nullptr; -} - -void ListColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, - idx_t result_idx) { - // insert any child states that are required - // we need two (validity & list child) - // note that we need a scan state for the child vector - // this is because we will (potentially) fetch more than one tuple from the list child - if (state.child_states.empty()) { - auto child_state = make_uniq(); - state.child_states.push_back(std::move(child_state)); - } - - // now perform the fetch within the segment - auto start_offset = idx_t(row_id) == this->start ? 0 : FetchListOffset(row_id - 1); - auto end_offset = FetchListOffset(row_id); - validity.FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); - - auto &validity = FlatVector::Validity(result); - auto list_data = FlatVector::GetData(result); - auto &list_entry = list_data[result_idx]; - // set the list entry offset to the size of the current list - list_entry.offset = ListVector::GetListSize(result); - list_entry.length = end_offset - start_offset; - if (!validity.RowIsValid(result_idx)) { - // the list is NULL! 
no need to fetch the child - D_ASSERT(list_entry.length == 0); - return; - } - - // now we need to read from the child all the elements between [offset...length] - auto child_scan_count = list_entry.length; - if (child_scan_count > 0) { - auto child_state = make_uniq(); - auto &child_type = ListType::GetChildType(result.GetType()); - Vector child_scan(child_type, child_scan_count); - // seek the scan towards the specified position and read [length] entries - child_state->Initialize(child_type); - child_column->InitializeScanWithOffset(*child_state, start + start_offset); - D_ASSERT(child_type.InternalType() == PhysicalType::STRUCT || - child_state->row_index + child_scan_count - this->start <= child_column->GetMaxEntry()); - child_column->ScanCount(*child_state, child_scan, child_scan_count); - - ListVector::Append(result, child_scan, child_scan_count); - } -} - -void ListColumnData::CommitDropColumn() { - validity.CommitDropColumn(); - child_column->CommitDropColumn(); -} - -struct ListColumnCheckpointState : public ColumnCheckpointState { - ListColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, PartialBlockManager &partial_block_manager) - : ColumnCheckpointState(row_group, column_data, partial_block_manager) { - global_stats = ListStats::CreateEmpty(column_data.type).ToUnique(); - } - - unique_ptr validity_state; - unique_ptr child_state; - -public: - unique_ptr GetStatistics() override { - auto stats = global_stats->Copy(); - ListStats::SetChildStats(stats, child_state->GetStatistics()); - return stats.ToUnique(); - } - - void WriteDataPointers(RowGroupWriter &writer, Serializer &serializer) override { - ColumnCheckpointState::WriteDataPointers(writer, serializer); - serializer.WriteObject(101, "validity", - [&](Serializer &serializer) { validity_state->WriteDataPointers(writer, serializer); }); - serializer.WriteObject(102, "child_column", - [&](Serializer &serializer) { child_state->WriteDataPointers(writer, serializer); }); - } -}; - 
-unique_ptr ListColumnData::CreateCheckpointState(RowGroup &row_group, - PartialBlockManager &partial_block_manager) { - return make_uniq(row_group, *this, partial_block_manager); -} - -unique_ptr ListColumnData::Checkpoint(RowGroup &row_group, - PartialBlockManager &partial_block_manager, - ColumnCheckpointInfo &checkpoint_info) { - auto validity_state = validity.Checkpoint(row_group, partial_block_manager, checkpoint_info); - auto base_state = ColumnData::Checkpoint(row_group, partial_block_manager, checkpoint_info); - auto child_state = child_column->Checkpoint(row_group, partial_block_manager, checkpoint_info); - - auto &checkpoint_state = base_state->Cast(); - checkpoint_state.validity_state = std::move(validity_state); - checkpoint_state.child_state = std::move(child_state); - return base_state; -} - -void ListColumnData::DeserializeColumn(Deserializer &deserializer) { - ColumnData::DeserializeColumn(deserializer); - - deserializer.ReadObject(101, "validity", - [&](Deserializer &deserializer) { validity.DeserializeColumn(deserializer); }); - - deserializer.ReadObject(102, "child_column", - [&](Deserializer &deserializer) { child_column->DeserializeColumn(deserializer); }); -} - -void ListColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) { - ColumnData::GetColumnSegmentInfo(row_group_index, col_path, result); - col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, col_path, result); - col_path.back() = 1; - child_column->GetColumnSegmentInfo(row_group_index, col_path, result); -} - -} // namespace duckdb - - - -namespace duckdb { - -PersistentTableData::PersistentTableData(idx_t column_count) : total_rows(0), row_group_count(0) { -} - -PersistentTableData::~PersistentTableData() { -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -RowGroup::RowGroup(RowGroupCollection &collection, idx_t start, idx_t count) - : SegmentBase(start, count), 
collection(collection) { - Verify(); -} - -RowGroup::RowGroup(RowGroupCollection &collection, RowGroupPointer &&pointer) - : SegmentBase(pointer.row_start, pointer.tuple_count), collection(collection) { - // deserialize the columns - if (pointer.data_pointers.size() != collection.GetTypes().size()) { - throw IOException("Row group column count is unaligned with table column count. Corrupt file?"); - } - this->column_pointers = std::move(pointer.data_pointers); - this->columns.resize(column_pointers.size()); - this->is_loaded = unique_ptr[]>(new atomic[columns.size()]); - for (idx_t c = 0; c < columns.size(); c++) { - this->is_loaded[c] = false; - } - this->deletes_pointers = std::move(pointer.deletes_pointers); - this->deletes_is_loaded = false; - - Verify(); -} - -void RowGroup::MoveToCollection(RowGroupCollection &collection, idx_t new_start) { - this->collection = collection; - this->start = new_start; - for (auto &column : GetColumns()) { - column->SetStart(new_start); - } - if (!HasUnloadedDeletes()) { - auto &vinfo = GetVersionInfo(); - if (vinfo) { - vinfo->SetStart(new_start); - } - } -} - -RowGroup::~RowGroup() { -} - -vector> &RowGroup::GetColumns() { - // ensure all columns are loaded - for (idx_t c = 0; c < GetColumnCount(); c++) { - GetColumn(c); - } - return columns; -} - -idx_t RowGroup::GetColumnCount() const { - return columns.size(); -} - -ColumnData &RowGroup::GetColumn(storage_t c) { - D_ASSERT(c < columns.size()); - if (!is_loaded) { - // not being lazy loaded - D_ASSERT(columns[c]); - return *columns[c]; - } - if (is_loaded[c]) { - D_ASSERT(columns[c]); - return *columns[c]; - } - lock_guard l(row_group_lock); - if (columns[c]) { - D_ASSERT(is_loaded[c]); - return *columns[c]; - } - if (column_pointers.size() != columns.size()) { - throw InternalException("Lazy loading a column but the pointer was not set"); - } - auto &metadata_manager = GetCollection().GetMetadataManager(); - auto &types = GetCollection().GetTypes(); - auto &block_pointer = 
column_pointers[c]; - MetadataReader column_data_reader(metadata_manager, block_pointer); - this->columns[c] = - ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), c, start, column_data_reader, types[c], nullptr); - is_loaded[c] = true; - if (this->columns[c]->count != this->count) { - throw InternalException("Corrupted database - loaded column with index %llu at row start %llu, count %llu did " - "not match count of row group %llu", - c, start, this->columns[c]->count, this->count.load()); - } - return *columns[c]; -} - -BlockManager &RowGroup::GetBlockManager() { - return GetCollection().GetBlockManager(); -} -DataTableInfo &RowGroup::GetTableInfo() { - return GetCollection().GetTableInfo(); -} - -void RowGroup::InitializeEmpty(const vector &types) { - // set up the segment trees for the column segments - D_ASSERT(columns.empty()); - for (idx_t i = 0; i < types.size(); i++) { - auto column_data = ColumnData::CreateColumn(GetBlockManager(), GetTableInfo(), i, start, types[i]); - columns.push_back(std::move(column_data)); - } -} - -void ColumnScanState::Initialize(const LogicalType &type) { - if (type.id() == LogicalTypeId::VALIDITY) { - // validity - nothing to initialize - return; - } - if (type.InternalType() == PhysicalType::STRUCT) { - // validity + struct children - auto &struct_children = StructType::GetChildTypes(type); - child_states.resize(struct_children.size() + 1); - for (idx_t i = 0; i < struct_children.size(); i++) { - child_states[i + 1].Initialize(struct_children[i].second); - } - } else if (type.InternalType() == PhysicalType::LIST) { - // validity + list child - child_states.resize(2); - child_states[1].Initialize(ListType::GetChildType(type)); - } else { - // validity - child_states.resize(1); - } -} - -void CollectionScanState::Initialize(const vector &types) { - auto &column_ids = GetColumnIds(); - column_scans = make_unsafe_uniq_array(column_ids.size()); - for (idx_t i = 0; i < column_ids.size(); i++) { - if (column_ids[i] == 
COLUMN_IDENTIFIER_ROW_ID) { - continue; - } - column_scans[i].Initialize(types[column_ids[i]]); - } -} - -bool RowGroup::InitializeScanWithOffset(CollectionScanState &state, idx_t vector_offset) { - auto &column_ids = state.GetColumnIds(); - auto filters = state.GetFilters(); - if (filters) { - if (!CheckZonemap(*filters, column_ids)) { - return false; - } - } - - state.row_group = this; - state.vector_index = vector_offset; - state.max_row_group_row = - this->start > state.max_row ? 0 : MinValue(this->count, state.max_row - this->start); - D_ASSERT(state.column_scans); - for (idx_t i = 0; i < column_ids.size(); i++) { - const auto &column = column_ids[i]; - if (column != COLUMN_IDENTIFIER_ROW_ID) { - auto &column_data = GetColumn(column); - column_data.InitializeScanWithOffset(state.column_scans[i], start + vector_offset * STANDARD_VECTOR_SIZE); - } else { - state.column_scans[i].current = nullptr; - } - } - return true; -} - -bool RowGroup::InitializeScan(CollectionScanState &state) { - auto &column_ids = state.GetColumnIds(); - auto filters = state.GetFilters(); - if (filters) { - if (!CheckZonemap(*filters, column_ids)) { - return false; - } - } - state.row_group = this; - state.vector_index = 0; - state.max_row_group_row = - this->start > state.max_row ? 
0 : MinValue(this->count, state.max_row - this->start); - if (state.max_row_group_row == 0) { - return false; - } - D_ASSERT(state.column_scans); - for (idx_t i = 0; i < column_ids.size(); i++) { - auto column = column_ids[i]; - if (column != COLUMN_IDENTIFIER_ROW_ID) { - auto &column_data = GetColumn(column); - column_data.InitializeScan(state.column_scans[i]); - } else { - state.column_scans[i].current = nullptr; - } - } - return true; -} - -unique_ptr RowGroup::AlterType(RowGroupCollection &new_collection, const LogicalType &target_type, - idx_t changed_idx, ExpressionExecutor &executor, - CollectionScanState &scan_state, DataChunk &scan_chunk) { - Verify(); - - // construct a new column data for this type - auto column_data = ColumnData::CreateColumn(GetBlockManager(), GetTableInfo(), changed_idx, start, target_type); - - ColumnAppendState append_state; - column_data->InitializeAppend(append_state); - - // scan the original table, and fill the new column with the transformed value - scan_state.Initialize(GetCollection().GetTypes()); - InitializeScan(scan_state); - - DataChunk append_chunk; - vector append_types; - append_types.push_back(target_type); - append_chunk.Initialize(Allocator::DefaultAllocator(), append_types); - auto &append_vector = append_chunk.data[0]; - while (true) { - // scan the table - scan_chunk.Reset(); - ScanCommitted(scan_state, scan_chunk, TableScanType::TABLE_SCAN_COMMITTED_ROWS); - if (scan_chunk.size() == 0) { - break; - } - // execute the expression - append_chunk.Reset(); - executor.ExecuteExpression(scan_chunk, append_vector); - column_data->Append(append_state, append_vector, scan_chunk.size()); - } - - // set up the row_group based on this row_group - auto row_group = make_uniq(new_collection, this->start, this->count); - row_group->version_info = GetOrCreateVersionInfoPtr(); - auto &cols = GetColumns(); - for (idx_t i = 0; i < cols.size(); i++) { - if (i == changed_idx) { - // this is the altered column: use the new column - 
row_group->columns.push_back(std::move(column_data)); - } else { - // this column was not altered: use the data directly - row_group->columns.push_back(cols[i]); - } - } - row_group->Verify(); - return row_group; -} - -unique_ptr RowGroup::AddColumn(RowGroupCollection &new_collection, ColumnDefinition &new_column, - ExpressionExecutor &executor, Expression &default_value, Vector &result) { - Verify(); - - // construct a new column data for the new column - auto added_column = - ColumnData::CreateColumn(GetBlockManager(), GetTableInfo(), GetColumnCount(), start, new_column.Type()); - - idx_t rows_to_write = this->count; - if (rows_to_write > 0) { - DataChunk dummy_chunk; - - ColumnAppendState state; - added_column->InitializeAppend(state); - for (idx_t i = 0; i < rows_to_write; i += STANDARD_VECTOR_SIZE) { - idx_t rows_in_this_vector = MinValue(rows_to_write - i, STANDARD_VECTOR_SIZE); - dummy_chunk.SetCardinality(rows_in_this_vector); - executor.ExecuteExpression(dummy_chunk, result); - added_column->Append(state, result, rows_in_this_vector); - } - } - - // set up the row_group based on this row_group - auto row_group = make_uniq(new_collection, this->start, this->count); - row_group->version_info = GetOrCreateVersionInfoPtr(); - row_group->columns = GetColumns(); - // now add the new column - row_group->columns.push_back(std::move(added_column)); - - row_group->Verify(); - return row_group; -} - -unique_ptr RowGroup::RemoveColumn(RowGroupCollection &new_collection, idx_t removed_column) { - Verify(); - - D_ASSERT(removed_column < columns.size()); - - auto row_group = make_uniq(new_collection, this->start, this->count); - row_group->version_info = GetOrCreateVersionInfoPtr(); - // copy over all columns except for the removed one - auto &cols = GetColumns(); - for (idx_t i = 0; i < cols.size(); i++) { - if (i != removed_column) { - row_group->columns.push_back(cols[i]); - } - } - - row_group->Verify(); - return row_group; -} - -void RowGroup::CommitDrop() { - for 
(idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { - CommitDropColumn(column_idx); - } -} - -void RowGroup::CommitDropColumn(idx_t column_idx) { - GetColumn(column_idx).CommitDropColumn(); -} - -void RowGroup::NextVector(CollectionScanState &state) { - state.vector_index++; - const auto &column_ids = state.GetColumnIds(); - for (idx_t i = 0; i < column_ids.size(); i++) { - const auto &column = column_ids[i]; - if (column == COLUMN_IDENTIFIER_ROW_ID) { - continue; - } - D_ASSERT(column < columns.size()); - GetColumn(column).Skip(state.column_scans[i]); - } -} - -bool RowGroup::CheckZonemap(TableFilterSet &filters, const vector &column_ids) { - for (auto &entry : filters.filters) { - auto column_index = entry.first; - auto &filter = entry.second; - const auto &base_column_index = column_ids[column_index]; - if (!GetColumn(base_column_index).CheckZonemap(*filter)) { - return false; - } - } - return true; -} - -bool RowGroup::CheckZonemapSegments(CollectionScanState &state) { - auto &column_ids = state.GetColumnIds(); - auto filters = state.GetFilters(); - if (!filters) { - return true; - } - for (auto &entry : filters->filters) { - D_ASSERT(entry.first < column_ids.size()); - auto column_idx = entry.first; - const auto &base_column_idx = column_ids[column_idx]; - bool read_segment = GetColumn(base_column_idx).CheckZonemap(state.column_scans[column_idx], *entry.second); - if (!read_segment) { - idx_t target_row = - state.column_scans[column_idx].current->start + state.column_scans[column_idx].current->count; - D_ASSERT(target_row >= this->start); - D_ASSERT(target_row <= this->start + this->count); - idx_t target_vector_index = (target_row - this->start) / STANDARD_VECTOR_SIZE; - if (state.vector_index == target_vector_index) { - // we can't skip any full vectors because this segment contains less than a full vector - // for now we just bail-out - // FIXME: we could check if we can ALSO skip the next segments, in which case skipping a full vector - 
// might be possible - // we don't care that much though, since a single segment that fits less than a full vector is - // exceedingly rare - return true; - } - while (state.vector_index < target_vector_index) { - NextVector(state); - } - return false; - } - } - - return true; -} - -template -void RowGroup::TemplatedScan(TransactionData transaction, CollectionScanState &state, DataChunk &result) { - const bool ALLOW_UPDATES = TYPE != TableScanType::TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES && - TYPE != TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED; - auto table_filters = state.GetFilters(); - const auto &column_ids = state.GetColumnIds(); - auto adaptive_filter = state.GetAdaptiveFilter(); - while (true) { - if (state.vector_index * STANDARD_VECTOR_SIZE >= state.max_row_group_row) { - // exceeded the amount of rows to scan - return; - } - idx_t current_row = state.vector_index * STANDARD_VECTOR_SIZE; - auto max_count = MinValue(STANDARD_VECTOR_SIZE, state.max_row_group_row - current_row); - - //! 
first check the zonemap if we have to scan this partition - if (!CheckZonemapSegments(state)) { - continue; - } - // second, scan the version chunk manager to figure out which tuples to load for this transaction - idx_t count; - SelectionVector valid_sel(STANDARD_VECTOR_SIZE); - if (TYPE == TableScanType::TABLE_SCAN_REGULAR) { - count = state.row_group->GetSelVector(transaction, state.vector_index, valid_sel, max_count); - if (count == 0) { - // nothing to scan for this vector, skip the entire vector - NextVector(state); - continue; - } - } else if (TYPE == TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED) { - count = state.row_group->GetCommittedSelVector(transaction.start_time, transaction.transaction_id, - state.vector_index, valid_sel, max_count); - if (count == 0) { - // nothing to scan for this vector, skip the entire vector - NextVector(state); - continue; - } - } else { - count = max_count; - } - if (count == max_count && !table_filters) { - // scan all vectors completely: full scan without deletions or table filters - for (idx_t i = 0; i < column_ids.size(); i++) { - const auto &column = column_ids[i]; - if (column == COLUMN_IDENTIFIER_ROW_ID) { - // scan row id - D_ASSERT(result.data[i].GetType().InternalType() == ROW_TYPE); - result.data[i].Sequence(this->start + current_row, 1, count); - } else { - auto &col_data = GetColumn(column); - if (TYPE != TableScanType::TABLE_SCAN_REGULAR) { - col_data.ScanCommitted(state.vector_index, state.column_scans[i], result.data[i], - ALLOW_UPDATES); - } else { - col_data.Scan(transaction, state.vector_index, state.column_scans[i], result.data[i]); - } - } - } - } else { - // partial scan: we have deletions or table filters - idx_t approved_tuple_count = count; - SelectionVector sel; - if (count != max_count) { - sel.Initialize(valid_sel); - } else { - sel.Initialize(nullptr); - } - //! first, we scan the columns with filters, fetch their data and generate a selection vector. - //! 
get runtime statistics - auto start_time = high_resolution_clock::now(); - if (table_filters) { - D_ASSERT(adaptive_filter); - D_ASSERT(ALLOW_UPDATES); - for (idx_t i = 0; i < table_filters->filters.size(); i++) { - auto tf_idx = adaptive_filter->permutation[i]; - auto col_idx = column_ids[tf_idx]; - auto &col_data = GetColumn(col_idx); - col_data.Select(transaction, state.vector_index, state.column_scans[tf_idx], result.data[tf_idx], - sel, approved_tuple_count, *table_filters->filters[tf_idx]); - } - for (auto &table_filter : table_filters->filters) { - result.data[table_filter.first].Slice(sel, approved_tuple_count); - } - } - if (approved_tuple_count == 0) { - // all rows were filtered out by the table filters - // skip this vector in all the scans that were not scanned yet - D_ASSERT(table_filters); - result.Reset(); - for (idx_t i = 0; i < column_ids.size(); i++) { - auto col_idx = column_ids[i]; - if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { - continue; - } - if (table_filters->filters.find(i) == table_filters->filters.end()) { - auto &col_data = GetColumn(col_idx); - col_data.Skip(state.column_scans[i]); - } - } - state.vector_index++; - continue; - } - //! Now we use the selection vector to fetch data for the other columns. 
- for (idx_t i = 0; i < column_ids.size(); i++) { - if (!table_filters || table_filters->filters.find(i) == table_filters->filters.end()) { - auto column = column_ids[i]; - if (column == COLUMN_IDENTIFIER_ROW_ID) { - D_ASSERT(result.data[i].GetType().InternalType() == PhysicalType::INT64); - result.data[i].SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result.data[i]); - for (size_t sel_idx = 0; sel_idx < approved_tuple_count; sel_idx++) { - result_data[sel_idx] = this->start + current_row + sel.get_index(sel_idx); - } - } else { - auto &col_data = GetColumn(column); - if (TYPE == TableScanType::TABLE_SCAN_REGULAR) { - col_data.FilterScan(transaction, state.vector_index, state.column_scans[i], result.data[i], - sel, approved_tuple_count); - } else { - col_data.FilterScanCommitted(state.vector_index, state.column_scans[i], result.data[i], sel, - approved_tuple_count, ALLOW_UPDATES); - } - } - } - } - auto end_time = high_resolution_clock::now(); - if (adaptive_filter && table_filters->filters.size() > 1) { - adaptive_filter->AdaptRuntimeStatistics(duration_cast>(end_time - start_time).count()); - } - D_ASSERT(approved_tuple_count > 0); - count = approved_tuple_count; - } - result.SetCardinality(count); - state.vector_index++; - break; - } -} - -void RowGroup::Scan(TransactionData transaction, CollectionScanState &state, DataChunk &result) { - TemplatedScan(transaction, state, result); -} - -void RowGroup::ScanCommitted(CollectionScanState &state, DataChunk &result, TableScanType type) { - auto &transaction_manager = DuckTransactionManager::Get(GetCollection().GetAttached()); - - auto lowest_active_start = transaction_manager.LowestActiveStart(); - auto lowest_active_id = transaction_manager.LowestActiveId(); - TransactionData data(lowest_active_id, lowest_active_start); - switch (type) { - case TableScanType::TABLE_SCAN_COMMITTED_ROWS: - TemplatedScan(data, state, result); - break; - case 
TableScanType::TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES: - TemplatedScan(data, state, result); - break; - case TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED: - TemplatedScan(data, state, result); - break; - default: - throw InternalException("Unrecognized table scan type"); - } -} - -shared_ptr &RowGroup::GetVersionInfo() { - if (!HasUnloadedDeletes()) { - // deletes are loaded - return the version info - return version_info; - } - lock_guard lock(row_group_lock); - // double-check after obtaining the lock whether or not deletes are still not loaded to avoid double load - if (HasUnloadedDeletes()) { - // deletes are not loaded - reload - auto root_delete = deletes_pointers[0]; - version_info = RowVersionManager::Deserialize(root_delete, GetBlockManager().GetMetadataManager(), start); - deletes_is_loaded = true; - } - return version_info; -} - -shared_ptr &RowGroup::GetOrCreateVersionInfoPtr() { - auto vinfo = GetVersionInfo(); - if (!vinfo) { - lock_guard lock(row_group_lock); - if (!version_info) { - version_info = make_shared(start); - } - } - return version_info; -} - -RowVersionManager &RowGroup::GetOrCreateVersionInfo() { - return *GetOrCreateVersionInfoPtr(); -} - -idx_t RowGroup::GetSelVector(TransactionData transaction, idx_t vector_idx, SelectionVector &sel_vector, - idx_t max_count) { - auto &vinfo = GetVersionInfo(); - if (!vinfo) { - return max_count; - } - return vinfo->GetSelVector(transaction, vector_idx, sel_vector, max_count); -} - -idx_t RowGroup::GetCommittedSelVector(transaction_t start_time, transaction_t transaction_id, idx_t vector_idx, - SelectionVector &sel_vector, idx_t max_count) { - auto &vinfo = GetVersionInfo(); - if (!vinfo) { - return max_count; - } - return vinfo->GetCommittedSelVector(start_time, transaction_id, vector_idx, sel_vector, max_count); -} - -bool RowGroup::Fetch(TransactionData transaction, idx_t row) { - D_ASSERT(row < this->count); - auto &vinfo = GetVersionInfo(); - if (!vinfo) { - return true; 
- } - return vinfo->Fetch(transaction, row); -} - -void RowGroup::FetchRow(TransactionData transaction, ColumnFetchState &state, const vector &column_ids, - row_t row_id, DataChunk &result, idx_t result_idx) { - for (idx_t col_idx = 0; col_idx < column_ids.size(); col_idx++) { - auto column = column_ids[col_idx]; - if (column == COLUMN_IDENTIFIER_ROW_ID) { - // row id column: fill in the row ids - D_ASSERT(result.data[col_idx].GetType().InternalType() == PhysicalType::INT64); - result.data[col_idx].SetVectorType(VectorType::FLAT_VECTOR); - auto data = FlatVector::GetData(result.data[col_idx]); - data[result_idx] = row_id; - } else { - // regular column: fetch data from the base column - auto &col_data = GetColumn(column); - col_data.FetchRow(transaction, state, row_id, result.data[col_idx], result_idx); - } - } -} - -void RowGroup::AppendVersionInfo(TransactionData transaction, idx_t count) { - idx_t row_group_start = this->count.load(); - idx_t row_group_end = row_group_start + count; - if (row_group_end > Storage::ROW_GROUP_SIZE) { - row_group_end = Storage::ROW_GROUP_SIZE; - } - // create the version_info if it doesn't exist yet - auto &vinfo = GetOrCreateVersionInfo(); - vinfo.AppendVersionInfo(transaction, count, row_group_start, row_group_end); - this->count = row_group_end; -} - -void RowGroup::CommitAppend(transaction_t commit_id, idx_t row_group_start, idx_t count) { - auto &vinfo = GetOrCreateVersionInfo(); - vinfo.CommitAppend(commit_id, row_group_start, count); -} - -void RowGroup::RevertAppend(idx_t row_group_start) { - auto &vinfo = GetOrCreateVersionInfo(); - vinfo.RevertAppend(row_group_start - this->start); - for (auto &column : columns) { - column->RevertAppend(row_group_start); - } - this->count = MinValue(row_group_start - this->start, this->count); - Verify(); -} - -void RowGroup::InitializeAppend(RowGroupAppendState &append_state) { - append_state.row_group = this; - append_state.offset_in_row_group = this->count; - // for each column, 
initialize the append state - append_state.states = make_unsafe_uniq_array(GetColumnCount()); - for (idx_t i = 0; i < GetColumnCount(); i++) { - auto &col_data = GetColumn(i); - col_data.InitializeAppend(append_state.states[i]); - } -} - -void RowGroup::Append(RowGroupAppendState &state, DataChunk &chunk, idx_t append_count) { - // append to the current row_group - for (idx_t i = 0; i < GetColumnCount(); i++) { - auto &col_data = GetColumn(i); - col_data.Append(state.states[i], chunk.data[i], append_count); - } - state.offset_in_row_group += append_count; -} - -void RowGroup::Update(TransactionData transaction, DataChunk &update_chunk, row_t *ids, idx_t offset, idx_t count, - const vector &column_ids) { -#ifdef DEBUG - for (size_t i = offset; i < offset + count; i++) { - D_ASSERT(ids[i] >= row_t(this->start) && ids[i] < row_t(this->start + this->count)); - } -#endif - for (idx_t i = 0; i < column_ids.size(); i++) { - auto column = column_ids[i]; - D_ASSERT(column.index != COLUMN_IDENTIFIER_ROW_ID); - auto &col_data = GetColumn(column.index); - D_ASSERT(col_data.type.id() == update_chunk.data[i].GetType().id()); - if (offset > 0) { - Vector sliced_vector(update_chunk.data[i], offset, offset + count); - sliced_vector.Flatten(count); - col_data.Update(transaction, column.index, sliced_vector, ids + offset, count); - } else { - col_data.Update(transaction, column.index, update_chunk.data[i], ids, count); - } - MergeStatistics(column.index, *col_data.GetUpdateStatistics()); - } -} - -void RowGroup::UpdateColumn(TransactionData transaction, DataChunk &updates, Vector &row_ids, - const vector &column_path) { - D_ASSERT(updates.ColumnCount() == 1); - auto ids = FlatVector::GetData(row_ids); - - auto primary_column_idx = column_path[0]; - D_ASSERT(primary_column_idx != COLUMN_IDENTIFIER_ROW_ID); - D_ASSERT(primary_column_idx < columns.size()); - auto &col_data = GetColumn(primary_column_idx); - col_data.UpdateColumn(transaction, column_path, updates.data[0], ids, 
updates.size(), 1); - MergeStatistics(primary_column_idx, *col_data.GetUpdateStatistics()); -} - -unique_ptr RowGroup::GetStatistics(idx_t column_idx) { - auto &col_data = GetColumn(column_idx); - lock_guard slock(stats_lock); - return col_data.GetStatistics(); -} - -void RowGroup::MergeStatistics(idx_t column_idx, const BaseStatistics &other) { - auto &col_data = GetColumn(column_idx); - lock_guard slock(stats_lock); - col_data.MergeStatistics(other); -} - -void RowGroup::MergeIntoStatistics(idx_t column_idx, BaseStatistics &other) { - auto &col_data = GetColumn(column_idx); - lock_guard slock(stats_lock); - col_data.MergeIntoStatistics(other); -} - -RowGroupWriteData RowGroup::WriteToDisk(PartialBlockManager &manager, - const vector &compression_types) { - RowGroupWriteData result; - result.states.reserve(columns.size()); - result.statistics.reserve(columns.size()); - - // Checkpoint the individual columns of the row group - // Here we're iterating over columns. Each column can have multiple segments. - // (Some columns will be wider than others, and require different numbers - // of blocks to encode.) Segments cannot span blocks. - // - // Some of these columns are composite (list, struct). The data is written - // first sequentially, and the pointers are written later, so that the - // pointers all end up densely packed, and thus more cache-friendly. 
- for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { - auto &column = GetColumn(column_idx); - ColumnCheckpointInfo checkpoint_info {compression_types[column_idx]}; - auto checkpoint_state = column.Checkpoint(*this, manager, checkpoint_info); - D_ASSERT(checkpoint_state); - - auto stats = checkpoint_state->GetStatistics(); - D_ASSERT(stats); - - result.statistics.push_back(stats->Copy()); - result.states.push_back(std::move(checkpoint_state)); - } - D_ASSERT(result.states.size() == result.statistics.size()); - return result; -} - -bool RowGroup::AllDeleted() { - if (HasUnloadedDeletes()) { - // deletes aren't loaded yet - we know not everything is deleted - return false; - } - auto &vinfo = GetVersionInfo(); - if (!vinfo) { - return false; - } - return vinfo->GetCommittedDeletedCount(count) == count; -} - -bool RowGroup::HasUnloadedDeletes() const { - if (deletes_pointers.empty()) { - // no stored deletes at all - return false; - } - // return whether or not the deletes have been loaded - return !deletes_is_loaded; -} - -RowGroupPointer RowGroup::Checkpoint(RowGroupWriter &writer, TableStatistics &global_stats) { - RowGroupPointer row_group_pointer; - - vector compression_types; - compression_types.reserve(columns.size()); - for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { - auto &column = GetColumn(column_idx); - if (column.count != this->count) { - throw InternalException("Corrupted in-memory column - column with index %llu has misaligned count (row " - "group has %llu rows, column has %llu)", - column_idx, this->count.load(), column.count); - } - compression_types.push_back(writer.GetColumnCompressionType(column_idx)); - } - auto result = WriteToDisk(writer.GetPartialBlockManager(), compression_types); - for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { - global_stats.GetStats(column_idx).Statistics().Merge(result.statistics[column_idx]); - } - - // construct the row group pointer and write 
the column meta data to disk - D_ASSERT(result.states.size() == columns.size()); - row_group_pointer.row_start = start; - row_group_pointer.tuple_count = count; - for (auto &state : result.states) { - // get the current position of the table data writer - auto &data_writer = writer.GetPayloadWriter(); - auto pointer = data_writer.GetMetaBlockPointer(); - - // store the stats and the data pointers in the row group pointers - row_group_pointer.data_pointers.push_back(pointer); - - // Write pointers to the column segments. - // - // Just as above, the state can refer to many other states, so this - // can cascade recursively into more pointer writes. - BinarySerializer serializer(data_writer); - serializer.Begin(); - state->WriteDataPointers(writer, serializer); - serializer.End(); - } - row_group_pointer.deletes_pointers = CheckpointDeletes(writer.GetPayloadWriter().GetManager()); - Verify(); - return row_group_pointer; -} - -vector RowGroup::CheckpointDeletes(MetadataManager &manager) { - if (HasUnloadedDeletes()) { - // deletes were not loaded so they cannot be changed - // re-use them as-is - manager.ClearModifiedBlocks(deletes_pointers); - return deletes_pointers; - } - if (!version_info) { - // no version information: write nothing - return vector(); - } - return version_info->Checkpoint(manager); -} - -void RowGroup::Serialize(RowGroupPointer &pointer, Serializer &serializer) { - serializer.WriteProperty(100, "row_start", pointer.row_start); - serializer.WriteProperty(101, "tuple_count", pointer.tuple_count); - serializer.WriteProperty(102, "data_pointers", pointer.data_pointers); - serializer.WriteProperty(103, "delete_pointers", pointer.deletes_pointers); -} - -RowGroupPointer RowGroup::Deserialize(Deserializer &deserializer) { - RowGroupPointer result; - result.row_start = deserializer.ReadProperty(100, "row_start"); - result.tuple_count = deserializer.ReadProperty(101, "tuple_count"); - result.data_pointers = deserializer.ReadProperty>(102, 
"data_pointers"); - result.deletes_pointers = deserializer.ReadProperty>(103, "delete_pointers"); - return result; -} - -//===--------------------------------------------------------------------===// -// GetColumnSegmentInfo -//===--------------------------------------------------------------------===// -void RowGroup::GetColumnSegmentInfo(idx_t row_group_index, vector &result) { - for (idx_t col_idx = 0; col_idx < GetColumnCount(); col_idx++) { - auto &col_data = GetColumn(col_idx); - col_data.GetColumnSegmentInfo(row_group_index, {col_idx}, result); - } -} - -//===--------------------------------------------------------------------===// -// Version Delete Information -//===--------------------------------------------------------------------===// -class VersionDeleteState { -public: - VersionDeleteState(RowGroup &info, TransactionData transaction, DataTable &table, idx_t base_row) - : info(info), transaction(transaction), table(table), current_chunk(DConstants::INVALID_INDEX), count(0), - base_row(base_row), delete_count(0) { - } - - RowGroup &info; - TransactionData transaction; - DataTable &table; - idx_t current_chunk; - row_t rows[STANDARD_VECTOR_SIZE]; - idx_t count; - idx_t base_row; - idx_t chunk_row; - idx_t delete_count; - -public: - void Delete(row_t row_id); - void Flush(); -}; - -idx_t RowGroup::Delete(TransactionData transaction, DataTable &table, row_t *ids, idx_t count) { - VersionDeleteState del_state(*this, transaction, table, this->start); - - // obtain a write lock - for (idx_t i = 0; i < count; i++) { - D_ASSERT(ids[i] >= 0); - D_ASSERT(idx_t(ids[i]) >= this->start && idx_t(ids[i]) < this->start + this->count); - del_state.Delete(ids[i] - this->start); - } - del_state.Flush(); - return del_state.delete_count; -} - -void RowGroup::Verify() { -#ifdef DEBUG - for (auto &column : GetColumns()) { - column->Verify(*this); - } -#endif -} - -idx_t RowGroup::DeleteRows(idx_t vector_idx, transaction_t transaction_id, row_t rows[], idx_t count) { - return 
GetOrCreateVersionInfo().DeleteRows(vector_idx, transaction_id, rows, count); -} - -void VersionDeleteState::Delete(row_t row_id) { - D_ASSERT(row_id >= 0); - idx_t vector_idx = row_id / STANDARD_VECTOR_SIZE; - idx_t idx_in_vector = row_id - vector_idx * STANDARD_VECTOR_SIZE; - if (current_chunk != vector_idx) { - Flush(); - - current_chunk = vector_idx; - chunk_row = vector_idx * STANDARD_VECTOR_SIZE; - } - rows[count++] = idx_in_vector; -} - -void VersionDeleteState::Flush() { - if (count == 0) { - return; - } - // it is possible for delete statements to delete the same tuple multiple times when combined with a USING clause - // in the current_info->Delete, we check which tuples are actually deleted (excluding duplicate deletions) - // this is returned in the actual_delete_count - auto actual_delete_count = info.DeleteRows(current_chunk, transaction.transaction_id, rows, count); - delete_count += actual_delete_count; - if (transaction.transaction && actual_delete_count > 0) { - // now push the delete into the undo buffer, but only if any deletes were actually performed - transaction.transaction->PushDelete(table, info.GetOrCreateVersionInfo(), current_chunk, rows, - actual_delete_count, base_row + chunk_row); - } - count = 0; -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Row Group Segment Tree -//===--------------------------------------------------------------------===// -RowGroupSegmentTree::RowGroupSegmentTree(RowGroupCollection &collection) - : SegmentTree(), collection(collection), current_row_group(0), max_row_group(0) { -} -RowGroupSegmentTree::~RowGroupSegmentTree() { -} - -void RowGroupSegmentTree::Initialize(PersistentTableData &data) { - D_ASSERT(data.row_group_count > 0); - current_row_group = 0; - max_row_group = data.row_group_count; - finished_loading = false; - reader = make_uniq(collection.GetMetadataManager(), data.block_pointer); 
-} - -unique_ptr RowGroupSegmentTree::LoadSegment() { - if (current_row_group >= max_row_group) { - reader.reset(); - finished_loading = true; - return nullptr; - } - BinaryDeserializer deserializer(*reader); - deserializer.Begin(); - auto row_group_pointer = RowGroup::Deserialize(deserializer); - deserializer.End(); - current_row_group++; - return make_uniq(collection, std::move(row_group_pointer)); -} - -//===--------------------------------------------------------------------===// -// Row Group Collection -//===--------------------------------------------------------------------===// -RowGroupCollection::RowGroupCollection(shared_ptr info_p, BlockManager &block_manager, - vector types_p, idx_t row_start_p, idx_t total_rows_p) - : block_manager(block_manager), total_rows(total_rows_p), info(std::move(info_p)), types(std::move(types_p)), - row_start(row_start_p) { - row_groups = make_shared(*this); -} - -idx_t RowGroupCollection::GetTotalRows() const { - return total_rows.load(); -} - -const vector &RowGroupCollection::GetTypes() const { - return types; -} - -Allocator &RowGroupCollection::GetAllocator() const { - return Allocator::Get(info->db); -} - -AttachedDatabase &RowGroupCollection::GetAttached() { - return GetTableInfo().db; -} - -MetadataManager &RowGroupCollection::GetMetadataManager() { - return GetBlockManager().GetMetadataManager(); -} - -//===--------------------------------------------------------------------===// -// Initialize -//===--------------------------------------------------------------------===// -void RowGroupCollection::Initialize(PersistentTableData &data) { - D_ASSERT(this->row_start == 0); - auto l = row_groups->Lock(); - this->total_rows = data.total_rows; - row_groups->Initialize(data); - stats.Initialize(types, data); -} - -void RowGroupCollection::InitializeEmpty() { - stats.InitializeEmpty(types); -} - -void RowGroupCollection::AppendRowGroup(SegmentLock &l, idx_t start_row) { - D_ASSERT(start_row >= row_start); - auto 
new_row_group = make_uniq(*this, start_row, 0); - new_row_group->InitializeEmpty(types); - row_groups->AppendSegment(l, std::move(new_row_group)); -} - -RowGroup *RowGroupCollection::GetRowGroup(int64_t index) { - return (RowGroup *)row_groups->GetSegmentByIndex(index); -} - -void RowGroupCollection::Verify() { -#ifdef DEBUG - idx_t current_total_rows = 0; - row_groups->Verify(); - for (auto &row_group : row_groups->Segments()) { - row_group.Verify(); - D_ASSERT(&row_group.GetCollection() == this); - D_ASSERT(row_group.start == this->row_start + current_total_rows); - current_total_rows += row_group.count; - } - D_ASSERT(current_total_rows == total_rows.load()); -#endif -} - -//===--------------------------------------------------------------------===// -// Scan -//===--------------------------------------------------------------------===// -void RowGroupCollection::InitializeScan(CollectionScanState &state, const vector &column_ids, - TableFilterSet *table_filters) { - auto row_group = row_groups->GetRootSegment(); - D_ASSERT(row_group); - state.row_groups = row_groups.get(); - state.max_row = row_start + total_rows; - state.Initialize(GetTypes()); - while (row_group && !row_group->InitializeScan(state)) { - row_group = row_groups->GetNextSegment(row_group); - } -} - -void RowGroupCollection::InitializeCreateIndexScan(CreateIndexScanState &state) { - state.segment_lock = row_groups->Lock(); -} - -void RowGroupCollection::InitializeScanWithOffset(CollectionScanState &state, const vector &column_ids, - idx_t start_row, idx_t end_row) { - auto row_group = row_groups->GetSegment(start_row); - D_ASSERT(row_group); - state.row_groups = row_groups.get(); - state.max_row = end_row; - state.Initialize(GetTypes()); - idx_t start_vector = (start_row - row_group->start) / STANDARD_VECTOR_SIZE; - if (!row_group->InitializeScanWithOffset(state, start_vector)) { - throw InternalException("Failed to initialize row group scan with offset"); - } -} - -bool 
RowGroupCollection::InitializeScanInRowGroup(CollectionScanState &state, RowGroupCollection &collection, - RowGroup &row_group, idx_t vector_index, idx_t max_row) { - state.max_row = max_row; - state.row_groups = collection.row_groups.get(); - if (!state.column_scans) { - // initialize the scan state - state.Initialize(collection.GetTypes()); - } - return row_group.InitializeScanWithOffset(state, vector_index); -} - -void RowGroupCollection::InitializeParallelScan(ParallelCollectionScanState &state) { - state.collection = this; - state.current_row_group = row_groups->GetRootSegment(); - state.vector_index = 0; - state.max_row = row_start + total_rows; - state.batch_index = 0; - state.processed_rows = 0; -} - -bool RowGroupCollection::NextParallelScan(ClientContext &context, ParallelCollectionScanState &state, - CollectionScanState &scan_state) { - while (true) { - idx_t vector_index; - idx_t max_row; - RowGroupCollection *collection; - RowGroup *row_group; - { - // select the next row group to scan from the parallel state - lock_guard l(state.lock); - if (!state.current_row_group || state.current_row_group->count == 0) { - // no more data left to scan - break; - } - collection = state.collection; - row_group = state.current_row_group; - if (ClientConfig::GetConfig(context).verify_parallelism) { - vector_index = state.vector_index; - max_row = state.current_row_group->start + - MinValue(state.current_row_group->count, - STANDARD_VECTOR_SIZE * state.vector_index + STANDARD_VECTOR_SIZE); - D_ASSERT(vector_index * STANDARD_VECTOR_SIZE < state.current_row_group->count); - state.vector_index++; - if (state.vector_index * STANDARD_VECTOR_SIZE >= state.current_row_group->count) { - state.current_row_group = row_groups->GetNextSegment(state.current_row_group); - state.vector_index = 0; - } - } else { - state.processed_rows += state.current_row_group->count; - vector_index = 0; - max_row = state.current_row_group->start + state.current_row_group->count; - 
state.current_row_group = row_groups->GetNextSegment(state.current_row_group); - } - max_row = MinValue(max_row, state.max_row); - scan_state.batch_index = ++state.batch_index; - } - D_ASSERT(collection); - D_ASSERT(row_group); - - // initialize the scan for this row group - bool need_to_scan = InitializeScanInRowGroup(scan_state, *collection, *row_group, vector_index, max_row); - if (!need_to_scan) { - // skip this row group - continue; - } - return true; - } - return false; -} - -bool RowGroupCollection::Scan(DuckTransaction &transaction, const vector &column_ids, - const std::function &fun) { - vector scan_types; - for (idx_t i = 0; i < column_ids.size(); i++) { - scan_types.push_back(types[column_ids[i]]); - } - DataChunk chunk; - chunk.Initialize(GetAllocator(), scan_types); - - // initialize the scan - TableScanState state; - state.Initialize(column_ids, nullptr); - InitializeScan(state.local_state, column_ids, nullptr); - - while (true) { - chunk.Reset(); - state.local_state.Scan(transaction, chunk); - if (chunk.size() == 0) { - return true; - } - if (!fun(chunk)) { - return false; - } - } -} - -bool RowGroupCollection::Scan(DuckTransaction &transaction, const std::function &fun) { - vector column_ids; - column_ids.reserve(types.size()); - for (idx_t i = 0; i < types.size(); i++) { - column_ids.push_back(i); - } - return Scan(transaction, column_ids, fun); -} - -//===--------------------------------------------------------------------===// -// Fetch -//===--------------------------------------------------------------------===// -void RowGroupCollection::Fetch(TransactionData transaction, DataChunk &result, const vector &column_ids, - const Vector &row_identifiers, idx_t fetch_count, ColumnFetchState &state) { - // figure out which row_group to fetch from - auto row_ids = FlatVector::GetData(row_identifiers); - idx_t count = 0; - for (idx_t i = 0; i < fetch_count; i++) { - auto row_id = row_ids[i]; - RowGroup *row_group; - { - idx_t segment_index; - auto l = 
row_groups->Lock(); - if (!row_groups->TryGetSegmentIndex(l, row_id, segment_index)) { - // in parallel append scenarios it is possible for the row_id - continue; - } - row_group = row_groups->GetSegmentByIndex(l, segment_index); - } - if (!row_group->Fetch(transaction, row_id - row_group->start)) { - continue; - } - row_group->FetchRow(transaction, state, column_ids, row_id, result, count); - count++; - } - result.SetCardinality(count); -} - -//===--------------------------------------------------------------------===// -// Append -//===--------------------------------------------------------------------===// -TableAppendState::TableAppendState() - : row_group_append_state(*this), total_append_count(0), start_row_group(nullptr), transaction(0, 0), remaining(0) { -} - -TableAppendState::~TableAppendState() { - D_ASSERT(Exception::UncaughtException() || remaining == 0); -} - -bool RowGroupCollection::IsEmpty() const { - auto l = row_groups->Lock(); - return IsEmpty(l); -} - -bool RowGroupCollection::IsEmpty(SegmentLock &l) const { - return row_groups->IsEmpty(l); -} - -void RowGroupCollection::InitializeAppend(TransactionData transaction, TableAppendState &state, idx_t append_count) { - state.row_start = total_rows; - state.current_row = state.row_start; - state.total_append_count = 0; - - // start writing to the row_groups - auto l = row_groups->Lock(); - if (IsEmpty(l)) { - // empty row group collection: empty first row group - AppendRowGroup(l, row_start); - } - state.start_row_group = row_groups->GetLastSegment(l); - D_ASSERT(this->row_start + total_rows == state.start_row_group->start + state.start_row_group->count); - state.start_row_group->InitializeAppend(state.row_group_append_state); - state.remaining = append_count; - state.transaction = transaction; - if (state.remaining > 0) { - state.start_row_group->AppendVersionInfo(transaction, state.remaining); - total_rows += state.remaining; - } -} - -void RowGroupCollection::InitializeAppend(TableAppendState 
&state) { - TransactionData tdata(0, 0); - InitializeAppend(tdata, state, 0); -} - -bool RowGroupCollection::Append(DataChunk &chunk, TableAppendState &state) { - D_ASSERT(chunk.ColumnCount() == types.size()); - chunk.Verify(); - - bool new_row_group = false; - idx_t append_count = chunk.size(); - idx_t remaining = chunk.size(); - state.total_append_count += append_count; - while (true) { - auto current_row_group = state.row_group_append_state.row_group; - // check how much we can fit into the current row_group - idx_t append_count = - MinValue(remaining, Storage::ROW_GROUP_SIZE - state.row_group_append_state.offset_in_row_group); - if (append_count > 0) { - current_row_group->Append(state.row_group_append_state, chunk, append_count); - // merge the stats - auto stats_lock = stats.GetLock(); - for (idx_t i = 0; i < types.size(); i++) { - current_row_group->MergeIntoStatistics(i, stats.GetStats(i).Statistics()); - } - } - remaining -= append_count; - if (state.remaining > 0) { - state.remaining -= append_count; - } - if (remaining > 0) { - // we expect max 1 iteration of this loop (i.e. 
a single chunk should never overflow more than one - // row_group) - D_ASSERT(chunk.size() == remaining + append_count); - // slice the input chunk - if (remaining < chunk.size()) { - SelectionVector sel(remaining); - for (idx_t i = 0; i < remaining; i++) { - sel.set_index(i, append_count + i); - } - chunk.Slice(sel, remaining); - } - // append a new row_group - new_row_group = true; - auto next_start = current_row_group->start + state.row_group_append_state.offset_in_row_group; - - auto l = row_groups->Lock(); - AppendRowGroup(l, next_start); - // set up the append state for this row_group - auto last_row_group = row_groups->GetLastSegment(l); - last_row_group->InitializeAppend(state.row_group_append_state); - if (state.remaining > 0) { - last_row_group->AppendVersionInfo(state.transaction, state.remaining); - } - continue; - } else { - break; - } - } - state.current_row += append_count; - auto stats_lock = stats.GetLock(); - for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { - stats.GetStats(col_idx).UpdateDistinctStatistics(chunk.data[col_idx], chunk.size()); - } - return new_row_group; -} - -void RowGroupCollection::FinalizeAppend(TransactionData transaction, TableAppendState &state) { - auto remaining = state.total_append_count; - auto row_group = state.start_row_group; - while (remaining > 0) { - auto append_count = MinValue(remaining, Storage::ROW_GROUP_SIZE - row_group->count); - row_group->AppendVersionInfo(transaction, append_count); - remaining -= append_count; - row_group = row_groups->GetNextSegment(row_group); - } - total_rows += state.total_append_count; - - state.total_append_count = 0; - state.start_row_group = nullptr; - - Verify(); -} - -void RowGroupCollection::CommitAppend(transaction_t commit_id, idx_t row_start, idx_t count) { - auto row_group = row_groups->GetSegment(row_start); - D_ASSERT(row_group); - idx_t current_row = row_start; - idx_t remaining = count; - while (true) { - idx_t start_in_row_group = current_row - 
row_group->start; - idx_t append_count = MinValue(row_group->count - start_in_row_group, remaining); - - row_group->CommitAppend(commit_id, start_in_row_group, append_count); - - current_row += append_count; - remaining -= append_count; - if (remaining == 0) { - break; - } - row_group = row_groups->GetNextSegment(row_group); - } -} - -void RowGroupCollection::RevertAppendInternal(idx_t start_row) { - if (total_rows <= start_row) { - return; - } - total_rows = start_row; - - auto l = row_groups->Lock(); - // find the segment index that the current row belongs to - idx_t segment_index = row_groups->GetSegmentIndex(l, start_row); - auto segment = row_groups->GetSegmentByIndex(l, segment_index); - auto &info = *segment; - - // remove any segments AFTER this segment: they should be deleted entirely - row_groups->EraseSegments(l, segment_index); - - info.next = nullptr; - info.RevertAppend(start_row); -} - -void RowGroupCollection::MergeStorage(RowGroupCollection &data) { - D_ASSERT(data.types == types); - auto index = row_start + total_rows.load(); - auto segments = data.row_groups->MoveSegments(); - for (auto &entry : segments) { - auto &row_group = entry.node; - row_group->MoveToCollection(*this, index); - index += row_group->count; - row_groups->AppendSegment(std::move(row_group)); - } - stats.MergeStats(data.stats); - total_rows += data.total_rows.load(); -} - -//===--------------------------------------------------------------------===// -// Delete -//===--------------------------------------------------------------------===// -idx_t RowGroupCollection::Delete(TransactionData transaction, DataTable &table, row_t *ids, idx_t count) { - idx_t delete_count = 0; - // delete is in the row groups - // we need to figure out for each id to which row group it belongs - // usually all (or many) ids belong to the same row group - // we iterate over the ids and check for every id if it belongs to the same row group as their predecessor - idx_t pos = 0; - do { - idx_t start = 
pos; - auto row_group = row_groups->GetSegment(ids[start]); - for (pos++; pos < count; pos++) { - D_ASSERT(ids[pos] >= 0); - // check if this id still belongs to this row group - if (idx_t(ids[pos]) < row_group->start) { - // id is before row_group start -> it does not - break; - } - if (idx_t(ids[pos]) >= row_group->start + row_group->count) { - // id is after row group end -> it does not - break; - } - } - delete_count += row_group->Delete(transaction, table, ids + start, pos - start); - } while (pos < count); - return delete_count; -} - -//===--------------------------------------------------------------------===// -// Update -//===--------------------------------------------------------------------===// -void RowGroupCollection::Update(TransactionData transaction, row_t *ids, const vector &column_ids, - DataChunk &updates) { - idx_t pos = 0; - do { - idx_t start = pos; - auto row_group = row_groups->GetSegment(ids[pos]); - row_t base_id = - row_group->start + ((ids[pos] - row_group->start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE); - row_t max_id = MinValue(base_id + STANDARD_VECTOR_SIZE, row_group->start + row_group->count); - for (pos++; pos < updates.size(); pos++) { - D_ASSERT(ids[pos] >= 0); - // check if this id still belongs to this vector in this row group - if (ids[pos] < base_id) { - // id is before vector start -> it does not - break; - } - if (ids[pos] >= max_id) { - // id is after the maximum id in this vector -> it does not - break; - } - } - row_group->Update(transaction, updates, ids, start, pos - start, column_ids); - - auto l = stats.GetLock(); - for (idx_t i = 0; i < column_ids.size(); i++) { - auto column_id = column_ids[i]; - stats.MergeStats(*l, column_id.index, *row_group->GetStatistics(column_id.index)); - } - } while (pos < updates.size()); -} - -void RowGroupCollection::RemoveFromIndexes(TableIndexList &indexes, Vector &row_identifiers, idx_t count) { - auto row_ids = FlatVector::GetData(row_identifiers); - - // initialize the 
fetch state - // FIXME: we do not need to fetch all columns, only the columns required by the indices! - TableScanState state; - vector column_ids; - column_ids.reserve(types.size()); - for (idx_t i = 0; i < types.size(); i++) { - column_ids.push_back(i); - } - state.Initialize(std::move(column_ids)); - state.table_state.max_row = row_start + total_rows; - - // initialize the fetch chunk - DataChunk result; - result.Initialize(GetAllocator(), types); - - SelectionVector sel(STANDARD_VECTOR_SIZE); - // now iterate over the row ids - for (idx_t r = 0; r < count;) { - result.Reset(); - // figure out which row_group to fetch from - auto row_id = row_ids[r]; - auto row_group = row_groups->GetSegment(row_id); - auto row_group_vector_idx = (row_id - row_group->start) / STANDARD_VECTOR_SIZE; - auto base_row_id = row_group_vector_idx * STANDARD_VECTOR_SIZE + row_group->start; - - // fetch the current vector - state.table_state.Initialize(GetTypes()); - row_group->InitializeScanWithOffset(state.table_state, row_group_vector_idx); - row_group->ScanCommitted(state.table_state, result, TableScanType::TABLE_SCAN_COMMITTED_ROWS); - result.Verify(); - - // check for any remaining row ids if they also fall into this vector - // we try to fetch handle as many rows as possible at the same time - idx_t sel_count = 0; - for (; r < count; r++) { - idx_t current_row = idx_t(row_ids[r]); - if (current_row < base_row_id || current_row >= base_row_id + result.size()) { - // this row-id does not fall into the current chunk - break - break; - } - auto row_in_vector = current_row - base_row_id; - D_ASSERT(row_in_vector < result.size()); - sel.set_index(sel_count++, row_in_vector); - } - D_ASSERT(sel_count > 0); - // slice the vector with all rows that are present in this vector and erase from the index - result.Slice(sel, sel_count); - - indexes.Scan([&](Index &index) { - index.Delete(result, row_identifiers); - return false; - }); - } -} - -void 
RowGroupCollection::UpdateColumn(TransactionData transaction, Vector &row_ids, const vector &column_path, - DataChunk &updates) { - auto first_id = FlatVector::GetValue(row_ids, 0); - if (first_id >= MAX_ROW_ID) { - throw NotImplementedException("Cannot update a column-path on transaction local data"); - } - // find the row_group this id belongs to - auto primary_column_idx = column_path[0]; - auto row_group = row_groups->GetSegment(first_id); - row_group->UpdateColumn(transaction, updates, row_ids, column_path); - - row_group->MergeIntoStatistics(primary_column_idx, stats.GetStats(primary_column_idx).Statistics()); -} - -//===--------------------------------------------------------------------===// -// Checkpoint -//===--------------------------------------------------------------------===// -void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &global_stats) { - bool can_vacuum_deletes = info->indexes.Empty(); - idx_t start = this->row_start; - auto segments = row_groups->MoveSegments(); - auto l = row_groups->Lock(); - for (auto &entry : segments) { - auto &row_group = *entry.node; - if (can_vacuum_deletes && row_group.AllDeleted()) { - row_group.CommitDrop(); - continue; - } - row_group.MoveToCollection(*this, start); - auto row_group_writer = writer.GetRowGroupWriter(row_group); - auto pointer = row_group.Checkpoint(*row_group_writer, global_stats); - writer.AddRowGroup(std::move(pointer), std::move(row_group_writer)); - row_groups->AppendSegment(l, std::move(entry.node)); - start += row_group.count; - } - total_rows = start; -} - -//===--------------------------------------------------------------------===// -// CommitDrop -//===--------------------------------------------------------------------===// -void RowGroupCollection::CommitDropColumn(idx_t index) { - for (auto &row_group : row_groups->Segments()) { - row_group.CommitDropColumn(index); - } -} - -void RowGroupCollection::CommitDropTable() { - for (auto &row_group : 
row_groups->Segments()) { - row_group.CommitDrop(); - } -} - -//===--------------------------------------------------------------------===// -// GetColumnSegmentInfo -//===--------------------------------------------------------------------===// -vector RowGroupCollection::GetColumnSegmentInfo() { - vector result; - for (auto &row_group : row_groups->Segments()) { - row_group.GetColumnSegmentInfo(row_group.index, result); - } - return result; -} - -//===--------------------------------------------------------------------===// -// Alter -//===--------------------------------------------------------------------===// -shared_ptr RowGroupCollection::AddColumn(ClientContext &context, ColumnDefinition &new_column, - Expression &default_value) { - idx_t new_column_idx = types.size(); - auto new_types = types; - new_types.push_back(new_column.GetType()); - auto result = - make_shared(info, block_manager, std::move(new_types), row_start, total_rows.load()); - - ExpressionExecutor executor(context); - DataChunk dummy_chunk; - Vector default_vector(new_column.GetType()); - executor.AddExpression(default_value); - - result->stats.InitializeAddColumn(stats, new_column.GetType()); - auto &new_column_stats = result->stats.GetStats(new_column_idx); - - // fill the column with its DEFAULT value, or NULL if none is specified - auto new_stats = make_uniq(new_column.GetType()); - for (auto ¤t_row_group : row_groups->Segments()) { - auto new_row_group = current_row_group.AddColumn(*result, new_column, executor, default_value, default_vector); - // merge in the statistics - new_row_group->MergeIntoStatistics(new_column_idx, new_column_stats.Statistics()); - - result->row_groups->AppendSegment(std::move(new_row_group)); - } - return result; -} - -shared_ptr RowGroupCollection::RemoveColumn(idx_t col_idx) { - D_ASSERT(col_idx < types.size()); - auto new_types = types; - new_types.erase(new_types.begin() + col_idx); - - auto result = - make_shared(info, block_manager, std::move(new_types), 
row_start, total_rows.load()); - result->stats.InitializeRemoveColumn(stats, col_idx); - - for (auto ¤t_row_group : row_groups->Segments()) { - auto new_row_group = current_row_group.RemoveColumn(*result, col_idx); - result->row_groups->AppendSegment(std::move(new_row_group)); - } - return result; -} - -shared_ptr RowGroupCollection::AlterType(ClientContext &context, idx_t changed_idx, - const LogicalType &target_type, - vector bound_columns, Expression &cast_expr) { - D_ASSERT(changed_idx < types.size()); - auto new_types = types; - new_types[changed_idx] = target_type; - - auto result = - make_shared(info, block_manager, std::move(new_types), row_start, total_rows.load()); - result->stats.InitializeAlterType(stats, changed_idx, target_type); - - vector scan_types; - for (idx_t i = 0; i < bound_columns.size(); i++) { - if (bound_columns[i] == COLUMN_IDENTIFIER_ROW_ID) { - scan_types.emplace_back(LogicalType::ROW_TYPE); - } else { - scan_types.push_back(types[bound_columns[i]]); - } - } - DataChunk scan_chunk; - scan_chunk.Initialize(GetAllocator(), scan_types); - - ExpressionExecutor executor(context); - executor.AddExpression(cast_expr); - - TableScanState scan_state; - scan_state.Initialize(bound_columns); - scan_state.table_state.max_row = row_start + total_rows; - - // now alter the type of the column within all of the row_groups individually - auto &changed_stats = result->stats.GetStats(changed_idx); - for (auto ¤t_row_group : row_groups->Segments()) { - auto new_row_group = current_row_group.AlterType(*result, target_type, changed_idx, executor, - scan_state.table_state, scan_chunk); - new_row_group->MergeIntoStatistics(changed_idx, changed_stats.Statistics()); - result->row_groups->AppendSegment(std::move(new_row_group)); - } - - return result; -} - -void RowGroupCollection::VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint) { - if (total_rows == 0) { - return; - } - // scan the original table, check if there's any null value - auto 
¬_null_constraint = constraint.Cast(); - vector scan_types; - auto physical_index = not_null_constraint.index.index; - D_ASSERT(physical_index < types.size()); - scan_types.push_back(types[physical_index]); - DataChunk scan_chunk; - scan_chunk.Initialize(GetAllocator(), scan_types); - - CreateIndexScanState state; - vector cids; - cids.push_back(physical_index); - // Use ScanCommitted to scan the latest committed data - state.Initialize(cids, nullptr); - InitializeScan(state.table_state, cids, nullptr); - InitializeCreateIndexScan(state); - while (true) { - scan_chunk.Reset(); - state.table_state.ScanCommitted(scan_chunk, state.segment_lock, - TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED); - if (scan_chunk.size() == 0) { - break; - } - // Check constraint - if (VectorOperations::HasNull(scan_chunk.data[0], scan_chunk.size())) { - throw ConstraintException("NOT NULL constraint failed: %s.%s", info->table, - parent.column_definitions[physical_index].GetName()); - } - } -} - -//===--------------------------------------------------------------------===// -// Statistics -//===--------------------------------------------------------------------===// -void RowGroupCollection::CopyStats(TableStatistics &other_stats) { - stats.CopyStats(other_stats); -} - -unique_ptr RowGroupCollection::CopyStats(column_t column_id) { - return stats.CopyStats(column_id); -} - -void RowGroupCollection::SetDistinct(column_t column_id, unique_ptr distinct_stats) { - D_ASSERT(column_id != COLUMN_IDENTIFIER_ROW_ID); - auto stats_guard = stats.GetLock(); - stats.GetStats(column_id).SetDistinct(std::move(distinct_stats)); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -RowVersionManager::RowVersionManager(idx_t start) : start(start), has_changes(false) { -} - -void RowVersionManager::SetStart(idx_t new_start) { - lock_guard l(version_lock); - this->start = new_start; - idx_t current_start = start; - for (idx_t i = 0; i < Storage::ROW_GROUP_VECTOR_COUNT; i++) { 
- if (vector_info[i]) { - vector_info[i]->start = current_start; - } - current_start += STANDARD_VECTOR_SIZE; - } -} - -idx_t RowVersionManager::GetCommittedDeletedCount(idx_t count) { - lock_guard l(version_lock); - idx_t deleted_count = 0; - for (idx_t r = 0, i = 0; r < count; r += STANDARD_VECTOR_SIZE, i++) { - if (!vector_info[i]) { - continue; - } - idx_t max_count = MinValue(STANDARD_VECTOR_SIZE, count - r); - if (max_count == 0) { - break; - } - deleted_count += vector_info[i]->GetCommittedDeletedCount(max_count); - } - return deleted_count; -} - -optional_ptr RowVersionManager::GetChunkInfo(idx_t vector_idx) { - return vector_info[vector_idx].get(); -} - -idx_t RowVersionManager::GetSelVector(TransactionData transaction, idx_t vector_idx, SelectionVector &sel_vector, - idx_t max_count) { - lock_guard l(version_lock); - auto chunk_info = GetChunkInfo(vector_idx); - if (!chunk_info) { - return max_count; - } - return chunk_info->GetSelVector(transaction, sel_vector, max_count); -} - -idx_t RowVersionManager::GetCommittedSelVector(transaction_t start_time, transaction_t transaction_id, idx_t vector_idx, - SelectionVector &sel_vector, idx_t max_count) { - lock_guard l(version_lock); - auto info = GetChunkInfo(vector_idx); - if (!info) { - return max_count; - } - return info->GetCommittedSelVector(start_time, transaction_id, sel_vector, max_count); -} - -bool RowVersionManager::Fetch(TransactionData transaction, idx_t row) { - lock_guard lock(version_lock); - idx_t vector_index = row / STANDARD_VECTOR_SIZE; - auto info = GetChunkInfo(vector_index); - if (!info) { - return true; - } - return info->Fetch(transaction, row - vector_index * STANDARD_VECTOR_SIZE); -} - -void RowVersionManager::AppendVersionInfo(TransactionData transaction, idx_t count, idx_t row_group_start, - idx_t row_group_end) { - lock_guard lock(version_lock); - has_changes = true; - idx_t start_vector_idx = row_group_start / STANDARD_VECTOR_SIZE; - idx_t end_vector_idx = (row_group_end - 1) / 
STANDARD_VECTOR_SIZE; - for (idx_t vector_idx = start_vector_idx; vector_idx <= end_vector_idx; vector_idx++) { - idx_t vector_start = - vector_idx == start_vector_idx ? row_group_start - start_vector_idx * STANDARD_VECTOR_SIZE : 0; - idx_t vector_end = - vector_idx == end_vector_idx ? row_group_end - end_vector_idx * STANDARD_VECTOR_SIZE : STANDARD_VECTOR_SIZE; - if (vector_start == 0 && vector_end == STANDARD_VECTOR_SIZE) { - // entire vector is encapsulated by append: append a single constant - auto constant_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); - constant_info->insert_id = transaction.transaction_id; - constant_info->delete_id = NOT_DELETED_ID; - vector_info[vector_idx] = std::move(constant_info); - } else { - // part of a vector is encapsulated: append to that part - optional_ptr new_info; - if (!vector_info[vector_idx]) { - // first time appending to this vector: create new info - auto insert_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); - new_info = insert_info.get(); - vector_info[vector_idx] = std::move(insert_info); - } else if (vector_info[vector_idx]->type == ChunkInfoType::VECTOR_INFO) { - // use existing vector - new_info = &vector_info[vector_idx]->Cast(); - } else { - throw InternalException("Error in RowVersionManager::AppendVersionInfo - expected either a " - "ChunkVectorInfo or no version info"); - } - new_info->Append(vector_start, vector_end, transaction.transaction_id); - } - } -} - -void RowVersionManager::CommitAppend(transaction_t commit_id, idx_t row_group_start, idx_t count) { - idx_t row_group_end = row_group_start + count; - - lock_guard lock(version_lock); - idx_t start_vector_idx = row_group_start / STANDARD_VECTOR_SIZE; - idx_t end_vector_idx = (row_group_end - 1) / STANDARD_VECTOR_SIZE; - for (idx_t vector_idx = start_vector_idx; vector_idx <= end_vector_idx; vector_idx++) { - idx_t vstart = vector_idx == start_vector_idx ? 
row_group_start - start_vector_idx * STANDARD_VECTOR_SIZE : 0; - idx_t vend = - vector_idx == end_vector_idx ? row_group_end - end_vector_idx * STANDARD_VECTOR_SIZE : STANDARD_VECTOR_SIZE; - - auto info = vector_info[vector_idx].get(); - info->CommitAppend(commit_id, vstart, vend); - } -} - -void RowVersionManager::RevertAppend(idx_t start_row) { - lock_guard lock(version_lock); - idx_t start_vector_idx = (start_row + (STANDARD_VECTOR_SIZE - 1)) / STANDARD_VECTOR_SIZE; - for (idx_t vector_idx = start_vector_idx; vector_idx < Storage::ROW_GROUP_VECTOR_COUNT; vector_idx++) { - vector_info[vector_idx].reset(); - } -} - -ChunkVectorInfo &RowVersionManager::GetVectorInfo(idx_t vector_idx) { - if (!vector_info[vector_idx]) { - // no info yet: create it - vector_info[vector_idx] = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); - } else if (vector_info[vector_idx]->type == ChunkInfoType::CONSTANT_INFO) { - auto &constant = vector_info[vector_idx]->Cast(); - // info exists but it's a constant info: convert to a vector info - auto new_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); - new_info->insert_id = constant.insert_id; - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { - new_info->inserted[i] = constant.insert_id; - } - vector_info[vector_idx] = std::move(new_info); - } - D_ASSERT(vector_info[vector_idx]->type == ChunkInfoType::VECTOR_INFO); - return vector_info[vector_idx]->Cast(); -} - -idx_t RowVersionManager::DeleteRows(idx_t vector_idx, transaction_t transaction_id, row_t rows[], idx_t count) { - lock_guard lock(version_lock); - has_changes = true; - return GetVectorInfo(vector_idx).Delete(transaction_id, rows, count); -} - -void RowVersionManager::CommitDelete(idx_t vector_idx, transaction_t commit_id, row_t rows[], idx_t count) { - lock_guard lock(version_lock); - has_changes = true; - GetVectorInfo(vector_idx).CommitDelete(commit_id, rows, count); -} - -vector RowVersionManager::Checkpoint(MetadataManager &manager) { - if (!has_changes && 
!storage_pointers.empty()) { - // the row version manager already exists on disk and no changes were made - // we can write the current pointer as-is - // ensure the blocks we are pointing to are not marked as free - manager.ClearModifiedBlocks(storage_pointers); - // return the root pointer - return storage_pointers; - } - // first count how many ChunkInfo's we need to deserialize - vector>> to_serialize; - for (idx_t vector_idx = 0; vector_idx < Storage::ROW_GROUP_VECTOR_COUNT; vector_idx++) { - auto chunk_info = vector_info[vector_idx].get(); - if (!chunk_info) { - continue; - } - if (!chunk_info->HasDeletes()) { - continue; - } - to_serialize.emplace_back(vector_idx, *chunk_info); - } - if (to_serialize.empty()) { - return vector(); - } - - storage_pointers.clear(); - - MetadataWriter writer(manager, &storage_pointers); - // now serialize the actual version information - writer.Write(to_serialize.size()); - for (auto &entry : to_serialize) { - auto &vector_idx = entry.first; - auto &chunk_info = entry.second.get(); - writer.Write(vector_idx); - chunk_info.Write(writer); - } - writer.Flush(); - - has_changes = false; - return storage_pointers; -} - -shared_ptr RowVersionManager::Deserialize(MetaBlockPointer delete_pointer, MetadataManager &manager, - idx_t start) { - if (!delete_pointer.IsValid()) { - return nullptr; - } - auto version_info = make_shared(start); - MetadataReader source(manager, delete_pointer, &version_info->storage_pointers); - auto chunk_count = source.Read(); - D_ASSERT(chunk_count > 0); - for (idx_t i = 0; i < chunk_count; i++) { - idx_t vector_index = source.Read(); - if (vector_index >= Storage::ROW_GROUP_VECTOR_COUNT) { - throw Exception("In DeserializeDeletes, vector_index is out of range for the row group. 
Corrupted file?"); - } - version_info->vector_info[vector_index] = ChunkInfo::Read(source); - } - version_info->has_changes = false; - return version_info; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -void TableScanState::Initialize(vector column_ids, TableFilterSet *table_filters) { - this->column_ids = std::move(column_ids); - this->table_filters = table_filters; - if (table_filters) { - D_ASSERT(table_filters->filters.size() > 0); - this->adaptive_filter = make_uniq(table_filters); - } -} - -const vector &TableScanState::GetColumnIds() { - D_ASSERT(!column_ids.empty()); - return column_ids; -} - -TableFilterSet *TableScanState::GetFilters() { - D_ASSERT(!table_filters || adaptive_filter.get()); - return table_filters; -} - -AdaptiveFilter *TableScanState::GetAdaptiveFilter() { - return adaptive_filter.get(); -} - -void ColumnScanState::NextInternal(idx_t count) { - if (!current) { - //! There is no column segment - return; - } - row_index += count; - while (row_index >= current->start + current->count) { - current = segment_tree->GetNextSegment(current); - initialized = false; - segment_checked = false; - if (!current) { - break; - } - } - D_ASSERT(!current || (row_index >= current->start && row_index < current->start + current->count)); -} - -void ColumnScanState::Next(idx_t count) { - NextInternal(count); - for (auto &child_state : child_states) { - child_state.Next(count); - } -} - -const vector &CollectionScanState::GetColumnIds() { - return parent.GetColumnIds(); -} - -TableFilterSet *CollectionScanState::GetFilters() { - return parent.GetFilters(); -} - -AdaptiveFilter *CollectionScanState::GetAdaptiveFilter() { - return parent.GetAdaptiveFilter(); -} - -ParallelCollectionScanState::ParallelCollectionScanState() - : collection(nullptr), current_row_group(nullptr), processed_rows(0) { -} - -CollectionScanState::CollectionScanState(TableScanState &parent_p) - : row_group(nullptr), vector_index(0), max_row_group_row(0), 
row_groups(nullptr), max_row(0), batch_index(0), - parent(parent_p) { -} - -bool CollectionScanState::Scan(DuckTransaction &transaction, DataChunk &result) { - while (row_group) { - row_group->Scan(transaction, *this, result); - if (result.size() > 0) { - return true; - } else if (max_row <= row_group->start + row_group->count) { - row_group = nullptr; - return false; - } else { - do { - row_group = row_groups->GetNextSegment(row_group); - if (row_group) { - if (row_group->start >= max_row) { - row_group = nullptr; - break; - } - bool scan_row_group = row_group->InitializeScan(*this); - if (scan_row_group) { - // scan this row group - break; - } - } - } while (row_group); - } - } - return false; -} - -bool CollectionScanState::ScanCommitted(DataChunk &result, SegmentLock &l, TableScanType type) { - while (row_group) { - row_group->ScanCommitted(*this, result, type); - if (result.size() > 0) { - return true; - } else { - row_group = row_groups->GetNextSegment(l, row_group); - if (row_group) { - row_group->InitializeScan(*this); - } - } - } - return false; -} - -bool CollectionScanState::ScanCommitted(DataChunk &result, TableScanType type) { - while (row_group) { - row_group->ScanCommitted(*this, result, type); - if (result.size() > 0) { - return true; - } else { - row_group = row_groups->GetNextSegment(row_group); - if (row_group) { - row_group->InitializeScan(*this); - } - } - } - return false; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -StandardColumnData::StandardColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, - idx_t start_row, LogicalType type, optional_ptr parent) - : ColumnData(block_manager, info, column_index, start_row, std::move(type), parent), - validity(block_manager, info, 0, start_row, *this) { -} - -void StandardColumnData::SetStart(idx_t new_start) { - ColumnData::SetStart(new_start); - validity.SetStart(new_start); -} - -bool StandardColumnData::CheckZonemap(ColumnScanState &state, 
TableFilter &filter) { - if (!state.segment_checked) { - if (!state.current) { - return true; - } - state.segment_checked = true; - auto prune_result = filter.CheckStatistics(state.current->stats.statistics); - if (prune_result != FilterPropagateResult::FILTER_ALWAYS_FALSE) { - return true; - } - if (updates) { - auto update_stats = updates->GetStatistics(); - prune_result = filter.CheckStatistics(*update_stats); - return prune_result != FilterPropagateResult::FILTER_ALWAYS_FALSE; - } else { - return false; - } - } else { - return true; - } -} - -void StandardColumnData::InitializeScan(ColumnScanState &state) { - ColumnData::InitializeScan(state); - - // initialize the validity segment - D_ASSERT(state.child_states.size() == 1); - validity.InitializeScan(state.child_states[0]); -} - -void StandardColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) { - ColumnData::InitializeScanWithOffset(state, row_idx); - - // initialize the validity segment - D_ASSERT(state.child_states.size() == 1); - validity.InitializeScanWithOffset(state.child_states[0], row_idx); -} - -idx_t StandardColumnData::Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, - Vector &result) { - D_ASSERT(state.row_index == state.child_states[0].row_index); - auto scan_count = ColumnData::Scan(transaction, vector_index, state, result); - validity.Scan(transaction, vector_index, state.child_states[0], result); - return scan_count; -} - -idx_t StandardColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, - bool allow_updates) { - D_ASSERT(state.row_index == state.child_states[0].row_index); - auto scan_count = ColumnData::ScanCommitted(vector_index, state, result, allow_updates); - validity.ScanCommitted(vector_index, state.child_states[0], result, allow_updates); - return scan_count; -} - -idx_t StandardColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count) { - auto scan_count = 
ColumnData::ScanCount(state, result, count); - validity.ScanCount(state.child_states[0], result, count); - return scan_count; -} - -void StandardColumnData::InitializeAppend(ColumnAppendState &state) { - ColumnData::InitializeAppend(state); - - ColumnAppendState child_append; - validity.InitializeAppend(child_append); - state.child_appends.push_back(std::move(child_append)); -} - -void StandardColumnData::AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, - idx_t count) { - ColumnData::AppendData(stats, state, vdata, count); - validity.AppendData(stats, state.child_appends[0], vdata, count); -} - -void StandardColumnData::RevertAppend(row_t start_row) { - ColumnData::RevertAppend(start_row); - - validity.RevertAppend(start_row); -} - -idx_t StandardColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { - // fetch validity mask - if (state.child_states.empty()) { - ColumnScanState child_state; - state.child_states.push_back(std::move(child_state)); - } - auto scan_count = ColumnData::Fetch(state, row_id, result); - validity.Fetch(state.child_states[0], row_id, result); - return scan_count; -} - -void StandardColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { - ColumnData::Update(transaction, column_index, update_vector, row_ids, update_count); - validity.Update(transaction, column_index, update_vector, row_ids, update_count); -} - -void StandardColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { - if (depth >= column_path.size()) { - // update this column - ColumnData::Update(transaction, column_path[0], update_vector, row_ids, update_count); - } else { - // update the child column (i.e. 
the validity column) - validity.UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, depth + 1); - } -} - -unique_ptr StandardColumnData::GetUpdateStatistics() { - auto stats = updates ? updates->GetStatistics() : nullptr; - auto validity_stats = validity.GetUpdateStatistics(); - if (!stats && !validity_stats) { - return nullptr; - } - if (!stats) { - stats = BaseStatistics::CreateEmpty(type).ToUnique(); - } - if (validity_stats) { - stats->Merge(*validity_stats); - } - return stats; -} - -void StandardColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, - idx_t result_idx) { - // find the segment the row belongs to - if (state.child_states.empty()) { - auto child_state = make_uniq(); - state.child_states.push_back(std::move(child_state)); - } - validity.FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); - ColumnData::FetchRow(transaction, state, row_id, result, result_idx); -} - -void StandardColumnData::CommitDropColumn() { - ColumnData::CommitDropColumn(); - validity.CommitDropColumn(); -} - -struct StandardColumnCheckpointState : public ColumnCheckpointState { - StandardColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, - PartialBlockManager &partial_block_manager) - : ColumnCheckpointState(row_group, column_data, partial_block_manager) { - } - - unique_ptr validity_state; - -public: - unique_ptr GetStatistics() override { - D_ASSERT(global_stats); - return std::move(global_stats); - } - - void WriteDataPointers(RowGroupWriter &writer, Serializer &serializer) override { - ColumnCheckpointState::WriteDataPointers(writer, serializer); - serializer.WriteObject(101, "validity", - [&](Serializer &serializer) { validity_state->WriteDataPointers(writer, serializer); }); - } -}; - -unique_ptr -StandardColumnData::CreateCheckpointState(RowGroup &row_group, PartialBlockManager &partial_block_manager) { - return make_uniq(row_group, *this, 
partial_block_manager); -} - -unique_ptr StandardColumnData::Checkpoint(RowGroup &row_group, - PartialBlockManager &partial_block_manager, - ColumnCheckpointInfo &checkpoint_info) { - auto validity_state = validity.Checkpoint(row_group, partial_block_manager, checkpoint_info); - auto base_state = ColumnData::Checkpoint(row_group, partial_block_manager, checkpoint_info); - auto &checkpoint_state = base_state->Cast(); - checkpoint_state.validity_state = std::move(validity_state); - return base_state; -} - -void StandardColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, - idx_t count, Vector &scan_vector) { - ColumnData::CheckpointScan(segment, state, row_group_start, count, scan_vector); - - idx_t offset_in_row_group = state.row_index - row_group_start; - validity.ScanCommittedRange(row_group_start, offset_in_row_group, count, scan_vector); -} - -void StandardColumnData::DeserializeColumn(Deserializer &deserializer) { - ColumnData::DeserializeColumn(deserializer); - deserializer.ReadObject(101, "validity", - [&](Deserializer &deserializer) { validity.DeserializeColumn(deserializer); }); -} - -void StandardColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) { - ColumnData::GetColumnSegmentInfo(row_group_index, col_path, result); - col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, std::move(col_path), result); -} - -void StandardColumnData::Verify(RowGroup &parent) { -#ifdef DEBUG - ColumnData::Verify(parent); - validity.Verify(parent); -#endif -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -StructColumnData::StructColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, - idx_t start_row, LogicalType type_p, optional_ptr parent) - : ColumnData(block_manager, info, column_index, start_row, std::move(type_p), parent), - validity(block_manager, info, 0, start_row, *this) { - D_ASSERT(type.InternalType() == 
PhysicalType::STRUCT); - auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(child_types.size() > 0); - if (type.id() != LogicalTypeId::UNION && StructType::IsUnnamed(type)) { - throw InvalidInputException("A table cannot be created from an unnamed struct"); - } - // the sub column index, starting at 1 (0 is the validity mask) - idx_t sub_column_index = 1; - for (auto &child_type : child_types) { - sub_columns.push_back( - ColumnData::CreateColumnUnique(block_manager, info, sub_column_index, start_row, child_type.second, this)); - sub_column_index++; - } -} - -void StructColumnData::SetStart(idx_t new_start) { - this->start = new_start; - for (auto &sub_column : sub_columns) { - sub_column->SetStart(new_start); - } - validity.SetStart(new_start); -} - -bool StructColumnData::CheckZonemap(ColumnScanState &state, TableFilter &filter) { - // table filters are not supported yet for struct columns - return false; -} - -idx_t StructColumnData::GetMaxEntry() { - return sub_columns[0]->GetMaxEntry(); -} - -void StructColumnData::InitializeScan(ColumnScanState &state) { - D_ASSERT(state.child_states.size() == sub_columns.size() + 1); - state.row_index = 0; - state.current = nullptr; - - // initialize the validity segment - validity.InitializeScan(state.child_states[0]); - - // initialize the sub-columns - for (idx_t i = 0; i < sub_columns.size(); i++) { - sub_columns[i]->InitializeScan(state.child_states[i + 1]); - } -} - -void StructColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) { - D_ASSERT(state.child_states.size() == sub_columns.size() + 1); - state.row_index = row_idx; - state.current = nullptr; - - // initialize the validity segment - validity.InitializeScanWithOffset(state.child_states[0], row_idx); - - // initialize the sub-columns - for (idx_t i = 0; i < sub_columns.size(); i++) { - sub_columns[i]->InitializeScanWithOffset(state.child_states[i + 1], row_idx); - } -} - -idx_t StructColumnData::Scan(TransactionData 
transaction, idx_t vector_index, ColumnScanState &state, Vector &result) { - auto scan_count = validity.Scan(transaction, vector_index, state.child_states[0], result); - auto &child_entries = StructVector::GetEntries(result); - for (idx_t i = 0; i < sub_columns.size(); i++) { - sub_columns[i]->Scan(transaction, vector_index, state.child_states[i + 1], *child_entries[i]); - } - return scan_count; -} - -idx_t StructColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates) { - auto scan_count = validity.ScanCommitted(vector_index, state.child_states[0], result, allow_updates); - auto &child_entries = StructVector::GetEntries(result); - for (idx_t i = 0; i < sub_columns.size(); i++) { - sub_columns[i]->ScanCommitted(vector_index, state.child_states[i + 1], *child_entries[i], allow_updates); - } - return scan_count; -} - -idx_t StructColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count) { - auto scan_count = validity.ScanCount(state.child_states[0], result, count); - auto &child_entries = StructVector::GetEntries(result); - for (idx_t i = 0; i < sub_columns.size(); i++) { - sub_columns[i]->ScanCount(state.child_states[i + 1], *child_entries[i], count); - } - return scan_count; -} - -void StructColumnData::Skip(ColumnScanState &state, idx_t count) { - validity.Skip(state.child_states[0], count); - - // skip inside the sub-columns - for (idx_t child_idx = 0; child_idx < sub_columns.size(); child_idx++) { - sub_columns[child_idx]->Skip(state.child_states[child_idx + 1], count); - } -} - -void StructColumnData::InitializeAppend(ColumnAppendState &state) { - ColumnAppendState validity_append; - validity.InitializeAppend(validity_append); - state.child_appends.push_back(std::move(validity_append)); - - for (auto &sub_column : sub_columns) { - ColumnAppendState child_append; - sub_column->InitializeAppend(child_append); - state.child_appends.push_back(std::move(child_append)); - } -} - -void 
StructColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) { - vector.Flatten(count); - - // append the null values - validity.Append(stats, state.child_appends[0], vector, count); - - auto &child_entries = StructVector::GetEntries(vector); - for (idx_t i = 0; i < child_entries.size(); i++) { - sub_columns[i]->Append(StructStats::GetChildStats(stats, i), state.child_appends[i + 1], *child_entries[i], - count); - } - this->count += count; -} - -void StructColumnData::RevertAppend(row_t start_row) { - validity.RevertAppend(start_row); - for (auto &sub_column : sub_columns) { - sub_column->RevertAppend(start_row); - } - this->count = start_row - this->start; -} - -idx_t StructColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { - // fetch validity mask - auto &child_entries = StructVector::GetEntries(result); - // insert any child states that are required - for (idx_t i = state.child_states.size(); i < child_entries.size() + 1; i++) { - ColumnScanState child_state; - state.child_states.push_back(std::move(child_state)); - } - // fetch the validity state - idx_t scan_count = validity.Fetch(state.child_states[0], row_id, result); - // fetch the sub-column states - for (idx_t i = 0; i < child_entries.size(); i++) { - sub_columns[i]->Fetch(state.child_states[i + 1], row_id, *child_entries[i]); - } - return scan_count; -} - -void StructColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { - validity.Update(transaction, column_index, update_vector, row_ids, update_count); - auto &child_entries = StructVector::GetEntries(update_vector); - for (idx_t i = 0; i < child_entries.size(); i++) { - sub_columns[i]->Update(transaction, column_index, *child_entries[i], row_ids, update_count); - } -} - -void StructColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t 
update_count, idx_t depth) { - // we can never DIRECTLY update a struct column - if (depth >= column_path.size()) { - throw InternalException("Attempting to directly update a struct column - this should not be possible"); - } - auto update_column = column_path[depth]; - if (update_column == 0) { - // update the validity column - validity.UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, depth + 1); - } else { - if (update_column > sub_columns.size()) { - throw InternalException("Update column_path out of range"); - } - sub_columns[update_column - 1]->UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, - depth + 1); - } -} - -unique_ptr StructColumnData::GetUpdateStatistics() { - // check if any child column has updates - auto stats = BaseStatistics::CreateEmpty(type); - auto validity_stats = validity.GetUpdateStatistics(); - if (validity_stats) { - stats.Merge(*validity_stats); - } - for (idx_t i = 0; i < sub_columns.size(); i++) { - auto child_stats = sub_columns[i]->GetUpdateStatistics(); - if (child_stats) { - StructStats::SetChildStats(stats, i, std::move(child_stats)); - } - } - return stats.ToUnique(); -} - -void StructColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, - idx_t result_idx) { - // fetch validity mask - auto &child_entries = StructVector::GetEntries(result); - // insert any child states that are required - for (idx_t i = state.child_states.size(); i < child_entries.size() + 1; i++) { - auto child_state = make_uniq(); - state.child_states.push_back(std::move(child_state)); - } - // fetch the validity state - validity.FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); - // fetch the sub-column states - for (idx_t i = 0; i < child_entries.size(); i++) { - sub_columns[i]->FetchRow(transaction, *state.child_states[i + 1], row_id, *child_entries[i], result_idx); - } -} - -void StructColumnData::CommitDropColumn() { - 
validity.CommitDropColumn(); - for (auto &sub_column : sub_columns) { - sub_column->CommitDropColumn(); - } -} - -struct StructColumnCheckpointState : public ColumnCheckpointState { - StructColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, - PartialBlockManager &partial_block_manager) - : ColumnCheckpointState(row_group, column_data, partial_block_manager) { - global_stats = StructStats::CreateEmpty(column_data.type).ToUnique(); - } - - unique_ptr validity_state; - vector> child_states; - -public: - unique_ptr GetStatistics() override { - auto stats = StructStats::CreateEmpty(column_data.type); - for (idx_t i = 0; i < child_states.size(); i++) { - StructStats::SetChildStats(stats, i, child_states[i]->GetStatistics()); - } - return stats.ToUnique(); - } - - void WriteDataPointers(RowGroupWriter &writer, Serializer &serializer) override { - serializer.WriteObject(101, "validity", - [&](Serializer &serializer) { validity_state->WriteDataPointers(writer, serializer); }); - serializer.WriteList(102, "sub_columns", child_states.size(), [&](Serializer::List &list, idx_t i) { - auto &state = child_states[i]; - list.WriteObject([&](Serializer &serializer) { state->WriteDataPointers(writer, serializer); }); - }); - } -}; - -unique_ptr StructColumnData::CreateCheckpointState(RowGroup &row_group, - PartialBlockManager &partial_block_manager) { - return make_uniq(row_group, *this, partial_block_manager); -} - -unique_ptr StructColumnData::Checkpoint(RowGroup &row_group, - PartialBlockManager &partial_block_manager, - ColumnCheckpointInfo &checkpoint_info) { - auto checkpoint_state = make_uniq(row_group, *this, partial_block_manager); - checkpoint_state->validity_state = validity.Checkpoint(row_group, partial_block_manager, checkpoint_info); - for (auto &sub_column : sub_columns) { - checkpoint_state->child_states.push_back( - sub_column->Checkpoint(row_group, partial_block_manager, checkpoint_info)); - } - return std::move(checkpoint_state); -} - -void 
StructColumnData::DeserializeColumn(Deserializer &deserializer) { - deserializer.ReadObject(101, "validity", - [&](Deserializer &deserializer) { validity.DeserializeColumn(deserializer); }); - - deserializer.ReadList(102, "sub_columns", [&](Deserializer::List &list, idx_t i) { - list.ReadObject([&](Deserializer &item) { sub_columns[i]->DeserializeColumn(item); }); - }); - - this->count = validity.count; -} - -void StructColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) { - col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, col_path, result); - for (idx_t i = 0; i < sub_columns.size(); i++) { - col_path.back() = i + 1; - sub_columns[i]->GetColumnSegmentInfo(row_group_index, col_path, result); - } -} - -void StructColumnData::Verify(RowGroup &parent) { -#ifdef DEBUG - ColumnData::Verify(parent); - validity.Verify(parent); - for (auto &sub_column : sub_columns) { - sub_column->Verify(parent); - } -#endif -} - -} // namespace duckdb - - - - - -namespace duckdb { - -void TableStatistics::Initialize(const vector &types, PersistentTableData &data) { - D_ASSERT(Empty()); - - column_stats = std::move(data.table_stats.column_stats); - if (column_stats.size() != types.size()) { // LCOV_EXCL_START - throw IOException("Table statistics column count is not aligned with table column count. 
Corrupt file?"); - } // LCOV_EXCL_STOP -} - -void TableStatistics::InitializeEmpty(const vector &types) { - D_ASSERT(Empty()); - - for (auto &type : types) { - column_stats.push_back(ColumnStatistics::CreateEmptyStats(type)); - } -} - -void TableStatistics::InitializeAddColumn(TableStatistics &parent, const LogicalType &new_column_type) { - D_ASSERT(Empty()); - - lock_guard stats_lock(parent.stats_lock); - for (idx_t i = 0; i < parent.column_stats.size(); i++) { - column_stats.push_back(parent.column_stats[i]); - } - column_stats.push_back(ColumnStatistics::CreateEmptyStats(new_column_type)); -} - -void TableStatistics::InitializeRemoveColumn(TableStatistics &parent, idx_t removed_column) { - D_ASSERT(Empty()); - - lock_guard stats_lock(parent.stats_lock); - for (idx_t i = 0; i < parent.column_stats.size(); i++) { - if (i != removed_column) { - column_stats.push_back(parent.column_stats[i]); - } - } -} - -void TableStatistics::InitializeAlterType(TableStatistics &parent, idx_t changed_idx, const LogicalType &new_type) { - D_ASSERT(Empty()); - - lock_guard stats_lock(parent.stats_lock); - for (idx_t i = 0; i < parent.column_stats.size(); i++) { - if (i == changed_idx) { - column_stats.push_back(ColumnStatistics::CreateEmptyStats(new_type)); - } else { - column_stats.push_back(parent.column_stats[i]); - } - } -} - -void TableStatistics::InitializeAddConstraint(TableStatistics &parent) { - D_ASSERT(Empty()); - - lock_guard stats_lock(parent.stats_lock); - for (idx_t i = 0; i < parent.column_stats.size(); i++) { - column_stats.push_back(parent.column_stats[i]); - } -} - -void TableStatistics::MergeStats(TableStatistics &other) { - auto l = GetLock(); - D_ASSERT(column_stats.size() == other.column_stats.size()); - for (idx_t i = 0; i < column_stats.size(); i++) { - column_stats[i]->Merge(*other.column_stats[i]); - } -} - -void TableStatistics::MergeStats(idx_t i, BaseStatistics &stats) { - auto l = GetLock(); - MergeStats(*l, i, stats); -} - -void 
TableStatistics::MergeStats(TableStatisticsLock &lock, idx_t i, BaseStatistics &stats) { - column_stats[i]->Statistics().Merge(stats); -} - -ColumnStatistics &TableStatistics::GetStats(idx_t i) { - return *column_stats[i]; -} - -unique_ptr TableStatistics::CopyStats(idx_t i) { - lock_guard l(stats_lock); - auto result = column_stats[i]->Statistics().Copy(); - if (column_stats[i]->HasDistinctStats()) { - result.SetDistinctCount(column_stats[i]->DistinctStats().GetCount()); - } - return result.ToUnique(); -} - -void TableStatistics::CopyStats(TableStatistics &other) { - for (auto &stats : column_stats) { - other.column_stats.push_back(stats->Copy()); - } -} - -void TableStatistics::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "column_stats", column_stats); -} - -void TableStatistics::Deserialize(Deserializer &deserializer, ColumnList &columns) { - auto physical_columns = columns.Physical(); - - auto iter = physical_columns.begin(); - deserializer.ReadList(100, "column_stats", [&](Deserializer::List &list, idx_t i) { - auto &col = *iter; - iter.operator++(); - - auto type = col.GetType(); - deserializer.Set(type); - - column_stats.push_back(list.ReadElement>()); - - deserializer.Unset(); - }); -} - -unique_ptr TableStatistics::GetLock() { - return make_uniq(stats_lock); -} - -bool TableStatistics::Empty() { - return column_stats.empty(); -} - -} // namespace duckdb - - - - - - - - - -#include - -namespace duckdb { - -static UpdateSegment::initialize_update_function_t GetInitializeUpdateFunction(PhysicalType type); -static UpdateSegment::fetch_update_function_t GetFetchUpdateFunction(PhysicalType type); -static UpdateSegment::fetch_committed_function_t GetFetchCommittedFunction(PhysicalType type); -static UpdateSegment::fetch_committed_range_function_t GetFetchCommittedRangeFunction(PhysicalType type); - -static UpdateSegment::merge_update_function_t GetMergeUpdateFunction(PhysicalType type); -static 
UpdateSegment::rollback_update_function_t GetRollbackUpdateFunction(PhysicalType type); -static UpdateSegment::statistics_update_function_t GetStatisticsUpdateFunction(PhysicalType type); -static UpdateSegment::fetch_row_function_t GetFetchRowFunction(PhysicalType type); - -UpdateSegment::UpdateSegment(ColumnData &column_data) - : column_data(column_data), stats(column_data.type), heap(BufferAllocator::Get(column_data.GetDatabase())) { - auto physical_type = column_data.type.InternalType(); - - this->type_size = GetTypeIdSize(physical_type); - - this->initialize_update_function = GetInitializeUpdateFunction(physical_type); - this->fetch_update_function = GetFetchUpdateFunction(physical_type); - this->fetch_committed_function = GetFetchCommittedFunction(physical_type); - this->fetch_committed_range = GetFetchCommittedRangeFunction(physical_type); - this->fetch_row_function = GetFetchRowFunction(physical_type); - this->merge_update_function = GetMergeUpdateFunction(physical_type); - this->rollback_update_function = GetRollbackUpdateFunction(physical_type); - this->statistics_update_function = GetStatisticsUpdateFunction(physical_type); -} - -UpdateSegment::~UpdateSegment() { -} - -//===--------------------------------------------------------------------===// -// Update Info Helpers -//===--------------------------------------------------------------------===// -Value UpdateInfo::GetValue(idx_t index) { - auto &type = segment->column_data.type; - - switch (type.id()) { - case LogicalTypeId::VALIDITY: - return Value::BOOLEAN(reinterpret_cast(tuple_data)[index]); - case LogicalTypeId::INTEGER: - return Value::INTEGER(reinterpret_cast(tuple_data)[index]); - default: - throw NotImplementedException("Unimplemented type for UpdateInfo::GetValue"); - } -} - -void UpdateInfo::Print() { - Printer::Print(ToString()); -} - -string UpdateInfo::ToString() { - auto &type = segment->column_data.type; - string result = "Update Info [" + type.ToString() + ", Count: " + to_string(N) + 
- ", Transaction Id: " + to_string(version_number) + "]\n"; - for (idx_t i = 0; i < N; i++) { - result += to_string(tuples[i]) + ": " + GetValue(i).ToString() + "\n"; - } - if (next) { - result += "\nChild Segment: " + next->ToString(); - } - return result; -} - -void UpdateInfo::Verify() { -#ifdef DEBUG - for (idx_t i = 1; i < N; i++) { - D_ASSERT(tuples[i] > tuples[i - 1] && tuples[i] < STANDARD_VECTOR_SIZE); - } -#endif -} - -//===--------------------------------------------------------------------===// -// Update Fetch -//===--------------------------------------------------------------------===// -static void MergeValidityInfo(UpdateInfo *current, ValidityMask &result_mask) { - auto info_data = reinterpret_cast(current->tuple_data); - for (idx_t i = 0; i < current->N; i++) { - result_mask.Set(current->tuples[i], info_data[i]); - } -} - -static void UpdateMergeValidity(transaction_t start_time, transaction_t transaction_id, UpdateInfo *info, - Vector &result) { - auto &result_mask = FlatVector::Validity(result); - UpdateInfo::UpdatesForTransaction(info, start_time, transaction_id, - [&](UpdateInfo *current) { MergeValidityInfo(current, result_mask); }); -} - -template -static void MergeUpdateInfo(UpdateInfo *current, T *result_data) { - auto info_data = reinterpret_cast(current->tuple_data); - if (current->N == STANDARD_VECTOR_SIZE) { - // special case: update touches ALL tuples of this vector - // in this case we can just memcpy the data - // since the layout of the update info is guaranteed to be [0, 1, 2, 3, ...] 
- memcpy(result_data, info_data, sizeof(T) * current->N); - } else { - for (idx_t i = 0; i < current->N; i++) { - result_data[current->tuples[i]] = info_data[i]; - } - } -} - -template -static void UpdateMergeFetch(transaction_t start_time, transaction_t transaction_id, UpdateInfo *info, Vector &result) { - auto result_data = FlatVector::GetData(result); - UpdateInfo::UpdatesForTransaction(info, start_time, transaction_id, - [&](UpdateInfo *current) { MergeUpdateInfo(current, result_data); }); -} - -static UpdateSegment::fetch_update_function_t GetFetchUpdateFunction(PhysicalType type) { - switch (type) { - case PhysicalType::BIT: - return UpdateMergeValidity; - case PhysicalType::BOOL: - case PhysicalType::INT8: - return UpdateMergeFetch; - case PhysicalType::INT16: - return UpdateMergeFetch; - case PhysicalType::INT32: - return UpdateMergeFetch; - case PhysicalType::INT64: - return UpdateMergeFetch; - case PhysicalType::UINT8: - return UpdateMergeFetch; - case PhysicalType::UINT16: - return UpdateMergeFetch; - case PhysicalType::UINT32: - return UpdateMergeFetch; - case PhysicalType::UINT64: - return UpdateMergeFetch; - case PhysicalType::INT128: - return UpdateMergeFetch; - case PhysicalType::FLOAT: - return UpdateMergeFetch; - case PhysicalType::DOUBLE: - return UpdateMergeFetch; - case PhysicalType::INTERVAL: - return UpdateMergeFetch; - case PhysicalType::VARCHAR: - return UpdateMergeFetch; - default: - throw NotImplementedException("Unimplemented type for update segment"); - } -} - -void UpdateSegment::FetchUpdates(TransactionData transaction, idx_t vector_index, Vector &result) { - auto lock_handle = lock.GetSharedLock(); - if (!root) { - return; - } - if (!root->info[vector_index]) { - return; - } - // FIXME: normalify if this is not the case... need to pass in count? 
- D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - - fetch_update_function(transaction.start_time, transaction.transaction_id, root->info[vector_index]->info.get(), - result); -} - -//===--------------------------------------------------------------------===// -// Fetch Committed -//===--------------------------------------------------------------------===// -static void FetchCommittedValidity(UpdateInfo *info, Vector &result) { - auto &result_mask = FlatVector::Validity(result); - MergeValidityInfo(info, result_mask); -} - -template -static void TemplatedFetchCommitted(UpdateInfo *info, Vector &result) { - auto result_data = FlatVector::GetData(result); - MergeUpdateInfo(info, result_data); -} - -static UpdateSegment::fetch_committed_function_t GetFetchCommittedFunction(PhysicalType type) { - switch (type) { - case PhysicalType::BIT: - return FetchCommittedValidity; - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedFetchCommitted; - case PhysicalType::INT16: - return TemplatedFetchCommitted; - case PhysicalType::INT32: - return TemplatedFetchCommitted; - case PhysicalType::INT64: - return TemplatedFetchCommitted; - case PhysicalType::UINT8: - return TemplatedFetchCommitted; - case PhysicalType::UINT16: - return TemplatedFetchCommitted; - case PhysicalType::UINT32: - return TemplatedFetchCommitted; - case PhysicalType::UINT64: - return TemplatedFetchCommitted; - case PhysicalType::INT128: - return TemplatedFetchCommitted; - case PhysicalType::FLOAT: - return TemplatedFetchCommitted; - case PhysicalType::DOUBLE: - return TemplatedFetchCommitted; - case PhysicalType::INTERVAL: - return TemplatedFetchCommitted; - case PhysicalType::VARCHAR: - return TemplatedFetchCommitted; - default: - throw NotImplementedException("Unimplemented type for update segment"); - } -} - -void UpdateSegment::FetchCommitted(idx_t vector_index, Vector &result) { - auto lock_handle = lock.GetSharedLock(); - - if (!root) { - return; - } - if 
(!root->info[vector_index]) { - return; - } - // FIXME: normalify if this is not the case... need to pass in count? - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - - fetch_committed_function(root->info[vector_index]->info.get(), result); -} - -//===--------------------------------------------------------------------===// -// Fetch Range -//===--------------------------------------------------------------------===// -static void MergeUpdateInfoRangeValidity(UpdateInfo *current, idx_t start, idx_t end, idx_t result_offset, - ValidityMask &result_mask) { - auto info_data = reinterpret_cast(current->tuple_data); - for (idx_t i = 0; i < current->N; i++) { - auto tuple_idx = current->tuples[i]; - if (tuple_idx < start) { - continue; - } else if (tuple_idx >= end) { - break; - } - auto result_idx = result_offset + tuple_idx - start; - result_mask.Set(result_idx, info_data[i]); - } -} - -static void FetchCommittedRangeValidity(UpdateInfo *info, idx_t start, idx_t end, idx_t result_offset, Vector &result) { - auto &result_mask = FlatVector::Validity(result); - MergeUpdateInfoRangeValidity(info, start, end, result_offset, result_mask); -} - -template -static void MergeUpdateInfoRange(UpdateInfo *current, idx_t start, idx_t end, idx_t result_offset, T *result_data) { - auto info_data = reinterpret_cast(current->tuple_data); - for (idx_t i = 0; i < current->N; i++) { - auto tuple_idx = current->tuples[i]; - if (tuple_idx < start) { - continue; - } else if (tuple_idx >= end) { - break; - } - auto result_idx = result_offset + tuple_idx - start; - result_data[result_idx] = info_data[i]; - } -} - -template -static void TemplatedFetchCommittedRange(UpdateInfo *info, idx_t start, idx_t end, idx_t result_offset, - Vector &result) { - auto result_data = FlatVector::GetData(result); - MergeUpdateInfoRange(info, start, end, result_offset, result_data); -} - -static UpdateSegment::fetch_committed_range_function_t GetFetchCommittedRangeFunction(PhysicalType type) { - 
switch (type) { - case PhysicalType::BIT: - return FetchCommittedRangeValidity; - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedFetchCommittedRange; - case PhysicalType::INT16: - return TemplatedFetchCommittedRange; - case PhysicalType::INT32: - return TemplatedFetchCommittedRange; - case PhysicalType::INT64: - return TemplatedFetchCommittedRange; - case PhysicalType::UINT8: - return TemplatedFetchCommittedRange; - case PhysicalType::UINT16: - return TemplatedFetchCommittedRange; - case PhysicalType::UINT32: - return TemplatedFetchCommittedRange; - case PhysicalType::UINT64: - return TemplatedFetchCommittedRange; - case PhysicalType::INT128: - return TemplatedFetchCommittedRange; - case PhysicalType::FLOAT: - return TemplatedFetchCommittedRange; - case PhysicalType::DOUBLE: - return TemplatedFetchCommittedRange; - case PhysicalType::INTERVAL: - return TemplatedFetchCommittedRange; - case PhysicalType::VARCHAR: - return TemplatedFetchCommittedRange; - default: - throw NotImplementedException("Unimplemented type for update segment"); - } -} - -void UpdateSegment::FetchCommittedRange(idx_t start_row, idx_t count, Vector &result) { - D_ASSERT(count > 0); - if (!root) { - return; - } - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - - idx_t end_row = start_row + count; - idx_t start_vector = start_row / STANDARD_VECTOR_SIZE; - idx_t end_vector = (end_row - 1) / STANDARD_VECTOR_SIZE; - D_ASSERT(start_vector <= end_vector); - D_ASSERT(end_vector < Storage::ROW_GROUP_VECTOR_COUNT); - - for (idx_t vector_idx = start_vector; vector_idx <= end_vector; vector_idx++) { - if (!root->info[vector_idx]) { - continue; - } - idx_t start_in_vector = vector_idx == start_vector ? start_row - start_vector * STANDARD_VECTOR_SIZE : 0; - idx_t end_in_vector = - vector_idx == end_vector ? 
end_row - end_vector * STANDARD_VECTOR_SIZE : STANDARD_VECTOR_SIZE; - D_ASSERT(start_in_vector < end_in_vector); - D_ASSERT(end_in_vector > 0 && end_in_vector <= STANDARD_VECTOR_SIZE); - idx_t result_offset = ((vector_idx * STANDARD_VECTOR_SIZE) + start_in_vector) - start_row; - fetch_committed_range(root->info[vector_idx]->info.get(), start_in_vector, end_in_vector, result_offset, - result); - } -} - -//===--------------------------------------------------------------------===// -// Fetch Row -//===--------------------------------------------------------------------===// -static void FetchRowValidity(transaction_t start_time, transaction_t transaction_id, UpdateInfo *info, idx_t row_idx, - Vector &result, idx_t result_idx) { - auto &result_mask = FlatVector::Validity(result); - UpdateInfo::UpdatesForTransaction(info, start_time, transaction_id, [&](UpdateInfo *current) { - auto info_data = reinterpret_cast(current->tuple_data); - // FIXME: we could do a binary search in here - for (idx_t i = 0; i < current->N; i++) { - if (current->tuples[i] == row_idx) { - result_mask.Set(result_idx, info_data[i]); - break; - } else if (current->tuples[i] > row_idx) { - break; - } - } - }); -} - -template -static void TemplatedFetchRow(transaction_t start_time, transaction_t transaction_id, UpdateInfo *info, idx_t row_idx, - Vector &result, idx_t result_idx) { - auto result_data = FlatVector::GetData(result); - UpdateInfo::UpdatesForTransaction(info, start_time, transaction_id, [&](UpdateInfo *current) { - auto info_data = (T *)current->tuple_data; - // FIXME: we could do a binary search in here - for (idx_t i = 0; i < current->N; i++) { - if (current->tuples[i] == row_idx) { - result_data[result_idx] = info_data[i]; - break; - } else if (current->tuples[i] > row_idx) { - break; - } - } - }); -} - -static UpdateSegment::fetch_row_function_t GetFetchRowFunction(PhysicalType type) { - switch (type) { - case PhysicalType::BIT: - return FetchRowValidity; - case PhysicalType::BOOL: - 
case PhysicalType::INT8: - return TemplatedFetchRow; - case PhysicalType::INT16: - return TemplatedFetchRow; - case PhysicalType::INT32: - return TemplatedFetchRow; - case PhysicalType::INT64: - return TemplatedFetchRow; - case PhysicalType::UINT8: - return TemplatedFetchRow; - case PhysicalType::UINT16: - return TemplatedFetchRow; - case PhysicalType::UINT32: - return TemplatedFetchRow; - case PhysicalType::UINT64: - return TemplatedFetchRow; - case PhysicalType::INT128: - return TemplatedFetchRow; - case PhysicalType::FLOAT: - return TemplatedFetchRow; - case PhysicalType::DOUBLE: - return TemplatedFetchRow; - case PhysicalType::INTERVAL: - return TemplatedFetchRow; - case PhysicalType::VARCHAR: - return TemplatedFetchRow; - default: - throw NotImplementedException("Unimplemented type for update segment fetch row"); - } -} - -void UpdateSegment::FetchRow(TransactionData transaction, idx_t row_id, Vector &result, idx_t result_idx) { - if (!root) { - return; - } - idx_t vector_index = (row_id - column_data.start) / STANDARD_VECTOR_SIZE; - if (!root->info[vector_index]) { - return; - } - idx_t row_in_vector = (row_id - column_data.start) - vector_index * STANDARD_VECTOR_SIZE; - fetch_row_function(transaction.start_time, transaction.transaction_id, root->info[vector_index]->info.get(), - row_in_vector, result, result_idx); -} - -//===--------------------------------------------------------------------===// -// Rollback update -//===--------------------------------------------------------------------===// -template -static void RollbackUpdate(UpdateInfo &base_info, UpdateInfo &rollback_info) { - auto base_data = (T *)base_info.tuple_data; - auto rollback_data = (T *)rollback_info.tuple_data; - idx_t base_offset = 0; - for (idx_t i = 0; i < rollback_info.N; i++) { - auto id = rollback_info.tuples[i]; - while (base_info.tuples[base_offset] < id) { - base_offset++; - D_ASSERT(base_offset < base_info.N); - } - base_data[base_offset] = rollback_data[i]; - } -} - -static 
UpdateSegment::rollback_update_function_t GetRollbackUpdateFunction(PhysicalType type) { - switch (type) { - case PhysicalType::BIT: - return RollbackUpdate; - case PhysicalType::BOOL: - case PhysicalType::INT8: - return RollbackUpdate; - case PhysicalType::INT16: - return RollbackUpdate; - case PhysicalType::INT32: - return RollbackUpdate; - case PhysicalType::INT64: - return RollbackUpdate; - case PhysicalType::UINT8: - return RollbackUpdate; - case PhysicalType::UINT16: - return RollbackUpdate; - case PhysicalType::UINT32: - return RollbackUpdate; - case PhysicalType::UINT64: - return RollbackUpdate; - case PhysicalType::INT128: - return RollbackUpdate; - case PhysicalType::FLOAT: - return RollbackUpdate; - case PhysicalType::DOUBLE: - return RollbackUpdate; - case PhysicalType::INTERVAL: - return RollbackUpdate; - case PhysicalType::VARCHAR: - return RollbackUpdate; - default: - throw NotImplementedException("Unimplemented type for uncompressed segment"); - } -} - -void UpdateSegment::RollbackUpdate(UpdateInfo &info) { - // obtain an exclusive lock - auto lock_handle = lock.GetExclusiveLock(); - - // move the data from the UpdateInfo back into the base info - D_ASSERT(root->info[info.vector_index]); - rollback_update_function(*root->info[info.vector_index]->info, info); - - // clean up the update chain - CleanupUpdateInternal(*lock_handle, info); -} - -//===--------------------------------------------------------------------===// -// Cleanup Update -//===--------------------------------------------------------------------===// -void UpdateSegment::CleanupUpdateInternal(const StorageLockKey &lock, UpdateInfo &info) { - D_ASSERT(info.prev); - auto prev = info.prev; - prev->next = info.next; - if (prev->next) { - prev->next->prev = prev; - } -} - -void UpdateSegment::CleanupUpdate(UpdateInfo &info) { - // obtain an exclusive lock - auto lock_handle = lock.GetExclusiveLock(); - CleanupUpdateInternal(*lock_handle, info); -} - 
-//===--------------------------------------------------------------------===// -// Check for conflicts in update -//===--------------------------------------------------------------------===// -static void CheckForConflicts(UpdateInfo *info, TransactionData transaction, row_t *ids, const SelectionVector &sel, - idx_t count, row_t offset, UpdateInfo *&node) { - if (!info) { - return; - } - if (info->version_number == transaction.transaction_id) { - // this UpdateInfo belongs to the current transaction, set it in the node - node = info; - } else if (info->version_number > transaction.start_time) { - // potential conflict, check that tuple ids do not conflict - // as both ids and info->tuples are sorted, this is similar to a merge join - idx_t i = 0, j = 0; - while (true) { - auto id = ids[sel.get_index(i)] - offset; - if (id == info->tuples[j]) { - throw TransactionException("Conflict on update!"); - } else if (id < info->tuples[j]) { - // id < the current tuple in info, move to next id - i++; - if (i == count) { - break; - } - } else { - // id > the current tuple, move to next tuple in info - j++; - if (j == info->N) { - break; - } - } - } - } - CheckForConflicts(info->next, transaction, ids, sel, count, offset, node); -} - -//===--------------------------------------------------------------------===// -// Initialize update info -//===--------------------------------------------------------------------===// -void UpdateSegment::InitializeUpdateInfo(UpdateInfo &info, row_t *ids, const SelectionVector &sel, idx_t count, - idx_t vector_index, idx_t vector_offset) { - info.segment = this; - info.vector_index = vector_index; - info.prev = nullptr; - info.next = nullptr; - - // set up the tuple ids - info.N = count; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto id = ids[idx]; - D_ASSERT(idx_t(id) >= vector_offset && idx_t(id) < vector_offset + STANDARD_VECTOR_SIZE); - info.tuples[i] = id - vector_offset; - }; -} - -static void 
InitializeUpdateValidity(UpdateInfo *base_info, Vector &base_data, UpdateInfo *update_info, Vector &update, - const SelectionVector &sel) { - auto &update_mask = FlatVector::Validity(update); - auto tuple_data = reinterpret_cast(update_info->tuple_data); - - if (!update_mask.AllValid()) { - for (idx_t i = 0; i < update_info->N; i++) { - auto idx = sel.get_index(i); - tuple_data[i] = update_mask.RowIsValidUnsafe(idx); - } - } else { - for (idx_t i = 0; i < update_info->N; i++) { - tuple_data[i] = true; - } - } - - auto &base_mask = FlatVector::Validity(base_data); - auto base_tuple_data = reinterpret_cast(base_info->tuple_data); - if (!base_mask.AllValid()) { - for (idx_t i = 0; i < base_info->N; i++) { - base_tuple_data[i] = base_mask.RowIsValidUnsafe(base_info->tuples[i]); - } - } else { - for (idx_t i = 0; i < base_info->N; i++) { - base_tuple_data[i] = true; - } - } -} - -struct UpdateSelectElement { - template - static T Operation(UpdateSegment *segment, T element) { - return element; - } -}; - -template <> -string_t UpdateSelectElement::Operation(UpdateSegment *segment, string_t element) { - return element.IsInlined() ? 
element : segment->GetStringHeap().AddBlob(element); -} - -template -static void InitializeUpdateData(UpdateInfo *base_info, Vector &base_data, UpdateInfo *update_info, Vector &update, - const SelectionVector &sel) { - auto update_data = FlatVector::GetData(update); - auto tuple_data = (T *)update_info->tuple_data; - - for (idx_t i = 0; i < update_info->N; i++) { - auto idx = sel.get_index(i); - tuple_data[i] = update_data[idx]; - } - - auto base_array_data = FlatVector::GetData(base_data); - auto &base_validity = FlatVector::Validity(base_data); - auto base_tuple_data = (T *)base_info->tuple_data; - for (idx_t i = 0; i < base_info->N; i++) { - auto base_idx = base_info->tuples[i]; - if (!base_validity.RowIsValid(base_idx)) { - continue; - } - base_tuple_data[i] = UpdateSelectElement::Operation(base_info->segment, base_array_data[base_idx]); - } -} - -static UpdateSegment::initialize_update_function_t GetInitializeUpdateFunction(PhysicalType type) { - switch (type) { - case PhysicalType::BIT: - return InitializeUpdateValidity; - case PhysicalType::BOOL: - case PhysicalType::INT8: - return InitializeUpdateData; - case PhysicalType::INT16: - return InitializeUpdateData; - case PhysicalType::INT32: - return InitializeUpdateData; - case PhysicalType::INT64: - return InitializeUpdateData; - case PhysicalType::UINT8: - return InitializeUpdateData; - case PhysicalType::UINT16: - return InitializeUpdateData; - case PhysicalType::UINT32: - return InitializeUpdateData; - case PhysicalType::UINT64: - return InitializeUpdateData; - case PhysicalType::INT128: - return InitializeUpdateData; - case PhysicalType::FLOAT: - return InitializeUpdateData; - case PhysicalType::DOUBLE: - return InitializeUpdateData; - case PhysicalType::INTERVAL: - return InitializeUpdateData; - case PhysicalType::VARCHAR: - return InitializeUpdateData; - default: - throw NotImplementedException("Unimplemented type for update segment"); - } -} - 
-//===--------------------------------------------------------------------===// -// Merge update info -//===--------------------------------------------------------------------===// -template -static idx_t MergeLoop(row_t a[], sel_t b[], idx_t acount, idx_t bcount, idx_t aoffset, F1 merge, F2 pick_a, F3 pick_b, - const SelectionVector &asel) { - idx_t aidx = 0, bidx = 0; - idx_t count = 0; - while (aidx < acount && bidx < bcount) { - auto a_index = asel.get_index(aidx); - auto a_id = a[a_index] - aoffset; - auto b_id = b[bidx]; - if (a_id == b_id) { - merge(a_id, a_index, bidx, count); - aidx++; - bidx++; - count++; - } else if (a_id < b_id) { - pick_a(a_id, a_index, count); - aidx++; - count++; - } else { - pick_b(b_id, bidx, count); - bidx++; - count++; - } - } - for (; aidx < acount; aidx++) { - auto a_index = asel.get_index(aidx); - pick_a(a[a_index] - aoffset, a_index, count); - count++; - } - for (; bidx < bcount; bidx++) { - pick_b(b[bidx], bidx, count); - count++; - } - return count; -} - -struct ExtractStandardEntry { - template - static T Extract(V *data, idx_t entry) { - return data[entry]; - } -}; - -struct ExtractValidityEntry { - template - static T Extract(V *data, idx_t entry) { - return data->RowIsValid(entry); - } -}; - -template -static void MergeUpdateLoopInternal(UpdateInfo *base_info, V *base_table_data, UpdateInfo *update_info, - V *update_vector_data, row_t *ids, idx_t count, const SelectionVector &sel) { - auto base_id = base_info->segment->column_data.start + base_info->vector_index * STANDARD_VECTOR_SIZE; -#ifdef DEBUG - // all of these should be sorted, otherwise the below algorithm does not work - for (idx_t i = 1; i < count; i++) { - auto prev_idx = sel.get_index(i - 1); - auto idx = sel.get_index(i); - D_ASSERT(ids[idx] > ids[prev_idx] && ids[idx] >= row_t(base_id) && - ids[idx] < row_t(base_id + STANDARD_VECTOR_SIZE)); - } -#endif - - // we have a new batch of updates (update, ids, count) - // we already have existing updates 
(base_info) - // and potentially, this transaction already has updates present (update_info) - // we need to merge these all together so that the latest updates get merged into base_info - // and the "old" values (fetched from EITHER base_info OR from base_data) get placed into update_info - auto base_info_data = (T *)base_info->tuple_data; - auto update_info_data = (T *)update_info->tuple_data; - - // we first do the merging of the old values - // what we are trying to do here is update the "update_info" of this transaction with all the old data we require - // this means we need to merge (1) any previously updated values (stored in update_info->tuples) - // together with (2) - // to simplify this, we create new arrays here - // we memcpy these over afterwards - T result_values[STANDARD_VECTOR_SIZE]; - sel_t result_ids[STANDARD_VECTOR_SIZE]; - - idx_t base_info_offset = 0; - idx_t update_info_offset = 0; - idx_t result_offset = 0; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - // we have to merge the info for "ids[i]" - auto update_id = ids[idx] - base_id; - - while (update_info_offset < update_info->N && update_info->tuples[update_info_offset] < update_id) { - // old id comes before the current id: write it - result_values[result_offset] = update_info_data[update_info_offset]; - result_ids[result_offset++] = update_info->tuples[update_info_offset]; - update_info_offset++; - } - // write the new id - if (update_info_offset < update_info->N && update_info->tuples[update_info_offset] == update_id) { - // we have an id that is equivalent in the current update info: write the update info - result_values[result_offset] = update_info_data[update_info_offset]; - result_ids[result_offset++] = update_info->tuples[update_info_offset]; - update_info_offset++; - continue; - } - - /// now check if we have the current update_id in the base_info, or if we should fetch it from the base data - while (base_info_offset < base_info->N && 
base_info->tuples[base_info_offset] < update_id) { - base_info_offset++; - } - if (base_info_offset < base_info->N && base_info->tuples[base_info_offset] == update_id) { - // it is! we have to move the tuple from base_info->ids[base_info_offset] to update_info - result_values[result_offset] = base_info_data[base_info_offset]; - } else { - // it is not! we have to move base_table_data[update_id] to update_info - result_values[result_offset] = UpdateSelectElement::Operation( - base_info->segment, OP::template Extract(base_table_data, update_id)); - } - result_ids[result_offset++] = update_id; - } - // write any remaining entries from the old updates - while (update_info_offset < update_info->N) { - result_values[result_offset] = update_info_data[update_info_offset]; - result_ids[result_offset++] = update_info->tuples[update_info_offset]; - update_info_offset++; - } - // now copy them back - update_info->N = result_offset; - memcpy(update_info_data, result_values, result_offset * sizeof(T)); - memcpy(update_info->tuples, result_ids, result_offset * sizeof(sel_t)); - - // now we merge the new values into the base_info - result_offset = 0; - auto pick_new = [&](idx_t id, idx_t aidx, idx_t count) { - result_values[result_offset] = OP::template Extract(update_vector_data, aidx); - result_ids[result_offset] = id; - result_offset++; - }; - auto pick_old = [&](idx_t id, idx_t bidx, idx_t count) { - result_values[result_offset] = base_info_data[bidx]; - result_ids[result_offset] = id; - result_offset++; - }; - // now we perform a merge of the new ids with the old ids - auto merge = [&](idx_t id, idx_t aidx, idx_t bidx, idx_t count) { - pick_new(id, aidx, count); - }; - MergeLoop(ids, base_info->tuples, count, base_info->N, base_id, merge, pick_new, pick_old, sel); - - base_info->N = result_offset; - memcpy(base_info_data, result_values, result_offset * sizeof(T)); - memcpy(base_info->tuples, result_ids, result_offset * sizeof(sel_t)); -} - -static void 
MergeValidityLoop(UpdateInfo *base_info, Vector &base_data, UpdateInfo *update_info, Vector &update, - row_t *ids, idx_t count, const SelectionVector &sel) { - auto &base_validity = FlatVector::Validity(base_data); - auto &update_validity = FlatVector::Validity(update); - MergeUpdateLoopInternal(base_info, &base_validity, update_info, - &update_validity, ids, count, sel); -} - -template -static void MergeUpdateLoop(UpdateInfo *base_info, Vector &base_data, UpdateInfo *update_info, Vector &update, - row_t *ids, idx_t count, const SelectionVector &sel) { - auto base_table_data = FlatVector::GetData(base_data); - auto update_vector_data = FlatVector::GetData(update); - MergeUpdateLoopInternal(base_info, base_table_data, update_info, update_vector_data, ids, count, sel); -} - -static UpdateSegment::merge_update_function_t GetMergeUpdateFunction(PhysicalType type) { - switch (type) { - case PhysicalType::BIT: - return MergeValidityLoop; - case PhysicalType::BOOL: - case PhysicalType::INT8: - return MergeUpdateLoop; - case PhysicalType::INT16: - return MergeUpdateLoop; - case PhysicalType::INT32: - return MergeUpdateLoop; - case PhysicalType::INT64: - return MergeUpdateLoop; - case PhysicalType::UINT8: - return MergeUpdateLoop; - case PhysicalType::UINT16: - return MergeUpdateLoop; - case PhysicalType::UINT32: - return MergeUpdateLoop; - case PhysicalType::UINT64: - return MergeUpdateLoop; - case PhysicalType::INT128: - return MergeUpdateLoop; - case PhysicalType::FLOAT: - return MergeUpdateLoop; - case PhysicalType::DOUBLE: - return MergeUpdateLoop; - case PhysicalType::INTERVAL: - return MergeUpdateLoop; - case PhysicalType::VARCHAR: - return MergeUpdateLoop; - default: - throw NotImplementedException("Unimplemented type for uncompressed segment"); - } -} - -//===--------------------------------------------------------------------===// -// Update statistics -//===--------------------------------------------------------------------===// -unique_ptr 
UpdateSegment::GetStatistics() { - lock_guard stats_guard(stats_lock); - return stats.statistics.ToUnique(); -} - -idx_t UpdateValidityStatistics(UpdateSegment *segment, SegmentStatistics &stats, Vector &update, idx_t count, - SelectionVector &sel) { - auto &mask = FlatVector::Validity(update); - auto &validity = stats.statistics; - if (!mask.AllValid() && !validity.CanHaveNull()) { - for (idx_t i = 0; i < count; i++) { - if (!mask.RowIsValid(i)) { - validity.SetHasNull(); - break; - } - } - } - sel.Initialize(nullptr); - return count; -} - -template -idx_t TemplatedUpdateNumericStatistics(UpdateSegment *segment, SegmentStatistics &stats, Vector &update, idx_t count, - SelectionVector &sel) { - auto update_data = FlatVector::GetData(update); - auto &mask = FlatVector::Validity(update); - - if (mask.AllValid()) { - for (idx_t i = 0; i < count; i++) { - NumericStats::Update(stats.statistics, update_data[i]); - } - sel.Initialize(nullptr); - return count; - } else { - idx_t not_null_count = 0; - sel.Initialize(STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < count; i++) { - if (mask.RowIsValid(i)) { - sel.set_index(not_null_count++, i); - NumericStats::Update(stats.statistics, update_data[i]); - } - } - return not_null_count; - } -} - -idx_t UpdateStringStatistics(UpdateSegment *segment, SegmentStatistics &stats, Vector &update, idx_t count, - SelectionVector &sel) { - auto update_data = FlatVector::GetData(update); - auto &mask = FlatVector::Validity(update); - if (mask.AllValid()) { - for (idx_t i = 0; i < count; i++) { - StringStats::Update(stats.statistics, update_data[i]); - if (!update_data[i].IsInlined()) { - update_data[i] = segment->GetStringHeap().AddBlob(update_data[i]); - } - } - sel.Initialize(nullptr); - return count; - } else { - idx_t not_null_count = 0; - sel.Initialize(STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < count; i++) { - if (mask.RowIsValid(i)) { - sel.set_index(not_null_count++, i); - StringStats::Update(stats.statistics, update_data[i]); - 
if (!update_data[i].IsInlined()) { - update_data[i] = segment->GetStringHeap().AddBlob(update_data[i]); - } - } - } - return not_null_count; - } -} - -UpdateSegment::statistics_update_function_t GetStatisticsUpdateFunction(PhysicalType type) { - switch (type) { - case PhysicalType::BIT: - return UpdateValidityStatistics; - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedUpdateNumericStatistics; - case PhysicalType::INT16: - return TemplatedUpdateNumericStatistics; - case PhysicalType::INT32: - return TemplatedUpdateNumericStatistics; - case PhysicalType::INT64: - return TemplatedUpdateNumericStatistics; - case PhysicalType::UINT8: - return TemplatedUpdateNumericStatistics; - case PhysicalType::UINT16: - return TemplatedUpdateNumericStatistics; - case PhysicalType::UINT32: - return TemplatedUpdateNumericStatistics; - case PhysicalType::UINT64: - return TemplatedUpdateNumericStatistics; - case PhysicalType::INT128: - return TemplatedUpdateNumericStatistics; - case PhysicalType::FLOAT: - return TemplatedUpdateNumericStatistics; - case PhysicalType::DOUBLE: - return TemplatedUpdateNumericStatistics; - case PhysicalType::INTERVAL: - return TemplatedUpdateNumericStatistics; - case PhysicalType::VARCHAR: - return UpdateStringStatistics; - default: - throw NotImplementedException("Unimplemented type for uncompressed segment"); - } -} - -//===--------------------------------------------------------------------===// -// Update -//===--------------------------------------------------------------------===// -static idx_t SortSelectionVector(SelectionVector &sel, idx_t count, row_t *ids) { - D_ASSERT(count > 0); - - bool is_sorted = true; - for (idx_t i = 1; i < count; i++) { - auto prev_idx = sel.get_index(i - 1); - auto idx = sel.get_index(i); - if (ids[idx] <= ids[prev_idx]) { - is_sorted = false; - break; - } - } - if (is_sorted) { - // already sorted: bailout - return count; - } - // not sorted: need to sort the selection vector - SelectionVector 
sorted_sel(count); - for (idx_t i = 0; i < count; i++) { - sorted_sel.set_index(i, sel.get_index(i)); - } - std::sort(sorted_sel.data(), sorted_sel.data() + count, [&](sel_t l, sel_t r) { return ids[l] < ids[r]; }); - // eliminate any duplicates - idx_t pos = 1; - for (idx_t i = 1; i < count; i++) { - auto prev_idx = sorted_sel.get_index(i - 1); - auto idx = sorted_sel.get_index(i); - D_ASSERT(ids[idx] >= ids[prev_idx]); - if (ids[prev_idx] != ids[idx]) { - sorted_sel.set_index(pos++, idx); - } - } -#ifdef DEBUG - for (idx_t i = 1; i < pos; i++) { - auto prev_idx = sorted_sel.get_index(i - 1); - auto idx = sorted_sel.get_index(i); - D_ASSERT(ids[idx] > ids[prev_idx]); - } -#endif - - sel.Initialize(sorted_sel); - D_ASSERT(pos > 0); - return pos; -} - -UpdateInfo *CreateEmptyUpdateInfo(TransactionData transaction, idx_t type_size, idx_t count, - unsafe_unique_array &data) { - data = make_unsafe_uniq_array(sizeof(UpdateInfo) + (sizeof(sel_t) + type_size) * STANDARD_VECTOR_SIZE); - auto update_info = reinterpret_cast(data.get()); - update_info->max = STANDARD_VECTOR_SIZE; - update_info->tuples = reinterpret_cast((data_ptr_cast(update_info)) + sizeof(UpdateInfo)); - update_info->tuple_data = (data_ptr_cast(update_info)) + sizeof(UpdateInfo) + sizeof(sel_t) * update_info->max; - update_info->version_number = transaction.transaction_id; - return update_info; -} - -void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vector &update, row_t *ids, idx_t count, - Vector &base_data) { - // obtain an exclusive lock - auto write_lock = lock.GetExclusiveLock(); - - update.Flatten(count); - - // update statistics - SelectionVector sel; - { - lock_guard stats_guard(stats_lock); - count = statistics_update_function(this, stats, update, count, sel); - } - if (count == 0) { - return; - } - - // subsequent algorithms used by the update require row ids to be (1) sorted, and (2) unique - // this is usually the case for "standard" queries (e.g. 
UPDATE tbl SET x=bla WHERE cond) - // however, for more exotic queries involving e.g. cross products/joins this might not be the case - // hence we explicitly check here if the ids are sorted and, if not, sort + duplicate eliminate them - count = SortSelectionVector(sel, count, ids); - D_ASSERT(count > 0); - - // create the versions for this segment, if there are none yet - if (!root) { - root = make_uniq(); - } - - // get the vector index based on the first id - // we assert that all updates must be part of the same vector - auto first_id = ids[sel.get_index(0)]; - idx_t vector_index = (first_id - column_data.start) / STANDARD_VECTOR_SIZE; - idx_t vector_offset = column_data.start + vector_index * STANDARD_VECTOR_SIZE; - - D_ASSERT(idx_t(first_id) >= column_data.start); - D_ASSERT(vector_index < Storage::ROW_GROUP_VECTOR_COUNT); - - // first check the version chain - UpdateInfo *node = nullptr; - - if (root->info[vector_index]) { - // there is already a version here, check if there are any conflicts and search for the node that belongs to - // this transaction in the version chain - auto base_info = root->info[vector_index]->info.get(); - CheckForConflicts(base_info->next, transaction, ids, sel, count, vector_offset, node); - - // there are no conflicts - // first, check if this thread has already done any updates - auto node = base_info->next; - while (node) { - if (node->version_number == transaction.transaction_id) { - // it has! 
use this node - break; - } - node = node->next; - } - unsafe_unique_array update_info_data; - if (!node) { - // no updates made yet by this transaction: initially the update info to empty - if (transaction.transaction) { - auto &dtransaction = transaction.transaction->Cast(); - node = dtransaction.CreateUpdateInfo(type_size, count); - } else { - node = CreateEmptyUpdateInfo(transaction, type_size, count, update_info_data); - } - node->segment = this; - node->vector_index = vector_index; - node->N = 0; - node->column_index = column_index; - - // insert the new node into the chain - node->next = base_info->next; - if (node->next) { - node->next->prev = node; - } - node->prev = base_info; - base_info->next = transaction.transaction ? node : nullptr; - } - base_info->Verify(); - node->Verify(); - - // now we are going to perform the merge - merge_update_function(base_info, base_data, node, update, ids, count, sel); - - base_info->Verify(); - node->Verify(); - } else { - // there is no version info yet: create the top level update info and fill it with the updates - auto result = make_uniq(); - - result->info = make_uniq(); - result->tuples = make_unsafe_uniq_array(STANDARD_VECTOR_SIZE); - result->tuple_data = make_unsafe_uniq_array(STANDARD_VECTOR_SIZE * type_size); - result->info->tuples = result->tuples.get(); - result->info->tuple_data = result->tuple_data.get(); - result->info->version_number = TRANSACTION_ID_START - 1; - result->info->column_index = column_index; - InitializeUpdateInfo(*result->info, ids, sel, count, vector_index, vector_offset); - - // now create the transaction level update info in the undo log - unsafe_unique_array update_info_data; - UpdateInfo *transaction_node; - if (transaction.transaction) { - transaction_node = transaction.transaction->CreateUpdateInfo(type_size, count); - } else { - transaction_node = CreateEmptyUpdateInfo(transaction, type_size, count, update_info_data); - } - - InitializeUpdateInfo(*transaction_node, ids, sel, count, 
vector_index, vector_offset); - - // we write the updates in the update node data, and write the updates in the info - initialize_update_function(transaction_node, base_data, result->info.get(), update, sel); - - result->info->next = transaction.transaction ? transaction_node : nullptr; - result->info->prev = nullptr; - transaction_node->next = nullptr; - transaction_node->prev = result->info.get(); - transaction_node->column_index = column_index; - - transaction_node->Verify(); - result->info->Verify(); - - root->info[vector_index] = std::move(result); - } -} - -bool UpdateSegment::HasUpdates() const { - return root.get() != nullptr; -} - -bool UpdateSegment::HasUpdates(idx_t vector_index) const { - if (!HasUpdates()) { - return false; - } - return root->info[vector_index].get(); -} - -bool UpdateSegment::HasUncommittedUpdates(idx_t vector_index) { - if (!HasUpdates(vector_index)) { - return false; - } - auto read_lock = lock.GetSharedLock(); - auto entry = root->info[vector_index].get(); - if (entry->info->next) { - return true; - } - return false; -} - -bool UpdateSegment::HasUpdates(idx_t start_row_index, idx_t end_row_index) { - if (!HasUpdates()) { - return false; - } - auto read_lock = lock.GetSharedLock(); - idx_t base_vector_index = start_row_index / STANDARD_VECTOR_SIZE; - idx_t end_vector_index = end_row_index / STANDARD_VECTOR_SIZE; - for (idx_t i = base_vector_index; i <= end_vector_index; i++) { - if (root->info[i]) { - return true; - } - } - return false; -} - -} // namespace duckdb - - - - -namespace duckdb { - -ValidityColumnData::ValidityColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, - idx_t start_row, ColumnData &parent) - : ColumnData(block_manager, info, column_index, start_row, LogicalType(LogicalTypeId::VALIDITY), &parent) { -} - -bool ValidityColumnData::CheckZonemap(ColumnScanState &state, TableFilter &filter) { - return true; -} - -} // namespace duckdb - - - - - -namespace duckdb { -void 
TableIndexList::AddIndex(unique_ptr index) { - D_ASSERT(index); - lock_guard lock(indexes_lock); - indexes.push_back(std::move(index)); -} - -void TableIndexList::RemoveIndex(Index &index) { - lock_guard lock(indexes_lock); - - for (idx_t index_idx = 0; index_idx < indexes.size(); index_idx++) { - auto &index_entry = indexes[index_idx]; - if (index_entry.get() == &index) { - indexes.erase(indexes.begin() + index_idx); - break; - } - } -} - -bool TableIndexList::Empty() { - lock_guard lock(indexes_lock); - return indexes.empty(); -} - -idx_t TableIndexList::Count() { - lock_guard lock(indexes_lock); - return indexes.size(); -} - -void TableIndexList::Move(TableIndexList &other) { - D_ASSERT(indexes.empty()); - indexes = std::move(other.indexes); -} - -Index *TableIndexList::FindForeignKeyIndex(const vector &fk_keys, ForeignKeyType fk_type) { - Index *result = nullptr; - Scan([&](Index &index) { - if (DataTable::IsForeignKeyIndex(fk_keys, index, fk_type)) { - result = &index; - } - return false; - }); - return result; -} - -void TableIndexList::VerifyForeignKey(const vector &fk_keys, DataChunk &chunk, - ConflictManager &conflict_manager) { - auto fk_type = conflict_manager.LookupType() == VerifyExistenceType::APPEND_FK - ? 
ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE - : ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE; - - // check whether the chunk can be inserted or deleted into the referenced table storage - auto index = FindForeignKeyIndex(fk_keys, fk_type); - if (!index) { - throw InternalException("Internal Foreign Key error: could not find index to verify..."); - } - conflict_manager.SetIndexCount(1); - index->CheckConstraintsForChunk(chunk, conflict_manager); -} - -vector TableIndexList::GetRequiredColumns() { - lock_guard lock(indexes_lock); - set unique_indexes; - for (auto &index : indexes) { - for (auto col_index : index->column_ids) { - unique_indexes.insert(col_index); - } - } - vector result; - result.reserve(unique_indexes.size()); - for (auto column_index : unique_indexes) { - result.emplace_back(column_index); - } - return result; -} - -vector TableIndexList::SerializeIndexes(duckdb::MetadataWriter &writer) { - vector blocks_info; - for (auto &index : indexes) { - blocks_info.emplace_back(index->Serialize(writer)); - } - return blocks_info; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -bool WriteAheadLog::Replay(AttachedDatabase &database, string &path) { - Connection con(database.GetDatabase()); - auto initial_source = make_uniq(FileSystem::Get(database), path.c_str()); - if (initial_source->Finished()) { - // WAL is empty - return false; - } - - con.BeginTransaction(); - - // first deserialize the WAL to look for a checkpoint flag - // if there is a checkpoint flag, we might have already flushed the contents of the WAL to disk - ReplayState checkpoint_state(database, *con.context); - checkpoint_state.deserialize_only = true; - try { - while (true) { - // read the current entry - BinaryDeserializer deserializer(*initial_source); - deserializer.Begin(); - auto entry_type = deserializer.ReadProperty(100, "wal_type"); - if (entry_type == WALType::WAL_FLUSH) { - deserializer.End(); - // check if the file is exhausted - if 
(initial_source->Finished()) { - // we finished reading the file: break - break; - } - } else { - // replay the entry - checkpoint_state.ReplayEntry(entry_type, deserializer); - deserializer.End(); - } - } - } catch (SerializationException &ex) { // LCOV_EXCL_START - // serialization exception - torn WAL - // continue reading - } catch (std::exception &ex) { - Printer::PrintF("Exception in WAL playback during initial read: %s\n", ex.what()); - return false; - } catch (...) { - Printer::Print("Unknown Exception in WAL playback during initial read"); - return false; - } // LCOV_EXCL_STOP - initial_source.reset(); - if (checkpoint_state.checkpoint_id.IsValid()) { - // there is a checkpoint flag: check if we need to deserialize the WAL - auto &manager = database.GetStorageManager(); - if (manager.IsCheckpointClean(checkpoint_state.checkpoint_id)) { - // the contents of the WAL have already been checkpointed - // we can safely truncate the WAL and ignore its contents - return true; - } - } - - // we need to recover from the WAL: actually set up the replay state - BufferedFileReader reader(FileSystem::Get(database), path.c_str()); - ReplayState state(database, *con.context); - - // replay the WAL - // note that everything is wrapped inside a try/catch block here - // there can be errors in WAL replay because of a corrupt WAL file - // in this case we should throw a warning but startup anyway - try { - while (true) { - // read the current entry - BinaryDeserializer deserializer(reader); - deserializer.Begin(); - auto entry_type = deserializer.ReadProperty(100, "wal_type"); - if (entry_type == WALType::WAL_FLUSH) { - deserializer.End(); - con.Commit(); - // check if the file is exhausted - if (reader.Finished()) { - // we finished reading the file: break - break; - } - con.BeginTransaction(); - } else { - // replay the entry - state.ReplayEntry(entry_type, deserializer); - deserializer.End(); - } - } - } catch (SerializationException &ex) { // LCOV_EXCL_START - // 
serialization error during WAL replay: rollback - con.Rollback(); - } catch (std::exception &ex) { - // FIXME: this should report a proper warning in the connection - Printer::PrintF("Exception in WAL playback: %s\n", ex.what()); - // exception thrown in WAL replay: rollback - con.Rollback(); - } catch (...) { - Printer::Print("Unknown Exception in WAL playback: %s\n"); - // exception thrown in WAL replay: rollback - con.Rollback(); - } // LCOV_EXCL_STOP - return false; -} - -//===--------------------------------------------------------------------===// -// Replay Entries -//===--------------------------------------------------------------------===// -void ReplayState::ReplayEntry(WALType entry_type, BinaryDeserializer &deserializer) { - switch (entry_type) { - case WALType::CREATE_TABLE: - ReplayCreateTable(deserializer); - break; - case WALType::DROP_TABLE: - ReplayDropTable(deserializer); - break; - case WALType::ALTER_INFO: - ReplayAlter(deserializer); - break; - case WALType::CREATE_VIEW: - ReplayCreateView(deserializer); - break; - case WALType::DROP_VIEW: - ReplayDropView(deserializer); - break; - case WALType::CREATE_SCHEMA: - ReplayCreateSchema(deserializer); - break; - case WALType::DROP_SCHEMA: - ReplayDropSchema(deserializer); - break; - case WALType::CREATE_SEQUENCE: - ReplayCreateSequence(deserializer); - break; - case WALType::DROP_SEQUENCE: - ReplayDropSequence(deserializer); - break; - case WALType::SEQUENCE_VALUE: - ReplaySequenceValue(deserializer); - break; - case WALType::CREATE_MACRO: - ReplayCreateMacro(deserializer); - break; - case WALType::DROP_MACRO: - ReplayDropMacro(deserializer); - break; - case WALType::CREATE_TABLE_MACRO: - ReplayCreateTableMacro(deserializer); - break; - case WALType::DROP_TABLE_MACRO: - ReplayDropTableMacro(deserializer); - break; - case WALType::CREATE_INDEX: - ReplayCreateIndex(deserializer); - break; - case WALType::DROP_INDEX: - ReplayDropIndex(deserializer); - break; - case WALType::USE_TABLE: - 
ReplayUseTable(deserializer); - break; - case WALType::INSERT_TUPLE: - ReplayInsert(deserializer); - break; - case WALType::DELETE_TUPLE: - ReplayDelete(deserializer); - break; - case WALType::UPDATE_TUPLE: - ReplayUpdate(deserializer); - break; - case WALType::CHECKPOINT: - ReplayCheckpoint(deserializer); - break; - case WALType::CREATE_TYPE: - ReplayCreateType(deserializer); - break; - case WALType::DROP_TYPE: - ReplayDropType(deserializer); - break; - default: - throw InternalException("Invalid WAL entry type!"); - } -} - -//===--------------------------------------------------------------------===// -// Replay Table -//===--------------------------------------------------------------------===// -void ReplayState::ReplayCreateTable(BinaryDeserializer &deserializer) { - auto info = deserializer.ReadProperty>(101, "table"); - if (deserialize_only) { - return; - } - // bind the constraints to the table again - auto binder = Binder::CreateBinder(context); - auto &schema = catalog.GetSchema(context, info->schema); - auto bound_info = binder->BindCreateTableInfo(std::move(info), schema); - - catalog.CreateTable(context, *bound_info); -} - -void ReplayState::ReplayDropTable(BinaryDeserializer &deserializer) { - - DropInfo info; - - info.type = CatalogType::TABLE_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); - if (deserialize_only) { - return; - } - - catalog.DropEntry(context, info); -} - -void ReplayState::ReplayAlter(BinaryDeserializer &deserializer) { - - auto info = deserializer.ReadProperty>(101, "info"); - auto &alter_info = info->Cast(); - if (deserialize_only) { - return; - } - catalog.Alter(context, alter_info); -} - -//===--------------------------------------------------------------------===// -// Replay View -//===--------------------------------------------------------------------===// -void ReplayState::ReplayCreateView(BinaryDeserializer &deserializer) { - auto entry = 
deserializer.ReadProperty>(101, "view"); - if (deserialize_only) { - return; - } - catalog.CreateView(context, entry->Cast()); -} - -void ReplayState::ReplayDropView(BinaryDeserializer &deserializer) { - DropInfo info; - info.type = CatalogType::VIEW_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); - if (deserialize_only) { - return; - } - catalog.DropEntry(context, info); -} - -//===--------------------------------------------------------------------===// -// Replay Schema -//===--------------------------------------------------------------------===// -void ReplayState::ReplayCreateSchema(BinaryDeserializer &deserializer) { - CreateSchemaInfo info; - info.schema = deserializer.ReadProperty(101, "schema"); - if (deserialize_only) { - return; - } - - catalog.CreateSchema(context, info); -} - -void ReplayState::ReplayDropSchema(BinaryDeserializer &deserializer) { - DropInfo info; - - info.type = CatalogType::SCHEMA_ENTRY; - info.name = deserializer.ReadProperty(101, "schema"); - if (deserialize_only) { - return; - } - - catalog.DropEntry(context, info); -} - -//===--------------------------------------------------------------------===// -// Replay Custom Type -//===--------------------------------------------------------------------===// -void ReplayState::ReplayCreateType(BinaryDeserializer &deserializer) { - auto info = deserializer.ReadProperty>(101, "type"); - info->on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; - catalog.CreateType(context, info->Cast()); -} - -void ReplayState::ReplayDropType(BinaryDeserializer &deserializer) { - DropInfo info; - - info.type = CatalogType::TYPE_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); - if (deserialize_only) { - return; - } - - catalog.DropEntry(context, info); -} - -//===--------------------------------------------------------------------===// -// Replay Sequence 
-//===--------------------------------------------------------------------===// -void ReplayState::ReplayCreateSequence(BinaryDeserializer &deserializer) { - auto entry = deserializer.ReadProperty>(101, "sequence"); - if (deserialize_only) { - return; - } - - catalog.CreateSequence(context, entry->Cast()); -} - -void ReplayState::ReplayDropSequence(BinaryDeserializer &deserializer) { - DropInfo info; - info.type = CatalogType::SEQUENCE_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); - if (deserialize_only) { - return; - } - - catalog.DropEntry(context, info); -} - -void ReplayState::ReplaySequenceValue(BinaryDeserializer &deserializer) { - auto schema = deserializer.ReadProperty(101, "schema"); - auto name = deserializer.ReadProperty(102, "name"); - auto usage_count = deserializer.ReadProperty(103, "usage_count"); - auto counter = deserializer.ReadProperty(104, "counter"); - if (deserialize_only) { - return; - } - - // fetch the sequence from the catalog - auto &seq = catalog.GetEntry(context, schema, name); - if (usage_count > seq.usage_count) { - seq.usage_count = usage_count; - seq.counter = counter; - } -} - -//===--------------------------------------------------------------------===// -// Replay Macro -//===--------------------------------------------------------------------===// -void ReplayState::ReplayCreateMacro(BinaryDeserializer &deserializer) { - auto entry = deserializer.ReadProperty>(101, "macro"); - if (deserialize_only) { - return; - } - - catalog.CreateFunction(context, entry->Cast()); -} - -void ReplayState::ReplayDropMacro(BinaryDeserializer &deserializer) { - DropInfo info; - info.type = CatalogType::MACRO_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); - if (deserialize_only) { - return; - } - - catalog.DropEntry(context, info); -} - 
-//===--------------------------------------------------------------------===// -// Replay Table Macro -//===--------------------------------------------------------------------===// -void ReplayState::ReplayCreateTableMacro(BinaryDeserializer &deserializer) { - auto entry = deserializer.ReadProperty>(101, "table_macro"); - if (deserialize_only) { - return; - } - catalog.CreateFunction(context, entry->Cast()); -} - -void ReplayState::ReplayDropTableMacro(BinaryDeserializer &deserializer) { - DropInfo info; - info.type = CatalogType::TABLE_MACRO_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); - if (deserialize_only) { - return; - } - - catalog.DropEntry(context, info); -} - -//===--------------------------------------------------------------------===// -// Replay Index -//===--------------------------------------------------------------------===// -void ReplayState::ReplayCreateIndex(BinaryDeserializer &deserializer) { - auto info = deserializer.ReadProperty>(101, "index"); - if (deserialize_only) { - return; - } - auto &index_info = info->Cast(); - - // get the physical table to which we'll add the index - auto &table = catalog.GetEntry(context, info->schema, index_info.table); - auto &data_table = table.GetStorage(); - - // bind the parsed expressions - if (index_info.expressions.empty()) { - for (auto &parsed_expr : index_info.parsed_expressions) { - index_info.expressions.push_back(parsed_expr->Copy()); - } - } - auto binder = Binder::CreateBinder(context); - auto expressions = binder->BindCreateIndexExpressions(table, index_info); - - // create the empty index - unique_ptr index; - switch (index_info.index_type) { - case IndexType::ART: { - index = make_uniq(index_info.column_ids, TableIOManager::Get(data_table), expressions, - index_info.constraint_type, data_table.db); - break; - } - default: - throw InternalException("Unimplemented index type"); - } - - // add the index to the catalog - 
auto &index_entry = catalog.CreateIndex(context, index_info)->Cast(); - index_entry.index = index.get(); - index_entry.info = data_table.info; - for (auto &parsed_expr : index_info.parsed_expressions) { - index_entry.parsed_expressions.push_back(parsed_expr->Copy()); - } - - // physically add the index to the data table storage - data_table.WALAddIndex(context, std::move(index), expressions); -} - -void ReplayState::ReplayDropIndex(BinaryDeserializer &deserializer) { - DropInfo info; - info.type = CatalogType::INDEX_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); - if (deserialize_only) { - return; - } - - catalog.DropEntry(context, info); -} - -//===--------------------------------------------------------------------===// -// Replay Data -//===--------------------------------------------------------------------===// -void ReplayState::ReplayUseTable(BinaryDeserializer &deserializer) { - auto schema_name = deserializer.ReadProperty(101, "schema"); - auto table_name = deserializer.ReadProperty(102, "table"); - if (deserialize_only) { - return; - } - current_table = &catalog.GetEntry(context, schema_name, table_name); -} - -void ReplayState::ReplayInsert(BinaryDeserializer &deserializer) { - DataChunk chunk; - deserializer.ReadObject(101, "chunk", [&](Deserializer &object) { chunk.Deserialize(object); }); - if (deserialize_only) { - return; - } - if (!current_table) { - throw Exception("Corrupt WAL: insert without table"); - } - - // append to the current table - current_table->GetStorage().LocalAppend(*current_table, context, chunk); -} - -void ReplayState::ReplayDelete(BinaryDeserializer &deserializer) { - DataChunk chunk; - deserializer.ReadObject(101, "chunk", [&](Deserializer &object) { chunk.Deserialize(object); }); - if (deserialize_only) { - return; - } - if (!current_table) { - throw InternalException("Corrupt WAL: delete without table"); - } - - D_ASSERT(chunk.ColumnCount() == 1 && 
chunk.data[0].GetType() == LogicalType::ROW_TYPE); - row_t row_ids[1]; - Vector row_identifiers(LogicalType::ROW_TYPE, data_ptr_cast(row_ids)); - - auto source_ids = FlatVector::GetData(chunk.data[0]); - // delete the tuples from the current table - for (idx_t i = 0; i < chunk.size(); i++) { - row_ids[0] = source_ids[i]; - current_table->GetStorage().Delete(*current_table, context, row_identifiers, 1); - } -} - -void ReplayState::ReplayUpdate(BinaryDeserializer &deserializer) { - auto column_path = deserializer.ReadProperty>(101, "column_indexes"); - - DataChunk chunk; - deserializer.ReadObject(102, "chunk", [&](Deserializer &object) { chunk.Deserialize(object); }); - - if (deserialize_only) { - return; - } - if (!current_table) { - throw InternalException("Corrupt WAL: update without table"); - } - - if (column_path[0] >= current_table->GetColumns().PhysicalColumnCount()) { - throw InternalException("Corrupt WAL: column index for update out of bounds"); - } - - // remove the row id vector from the chunk - auto row_ids = std::move(chunk.data.back()); - chunk.data.pop_back(); - - // now perform the update - current_table->GetStorage().UpdateColumn(*current_table, context, row_ids, column_path, chunk); -} - -void ReplayState::ReplayCheckpoint(BinaryDeserializer &deserializer) { - checkpoint_id = deserializer.ReadProperty(101, "meta_block"); -} - -} // namespace duckdb - - - - - - - - - - -#include - -namespace duckdb { - -WriteAheadLog::WriteAheadLog(AttachedDatabase &database, const string &path) : skip_writing(false), database(database) { - wal_path = path; - writer = make_uniq(FileSystem::Get(database), path.c_str(), - FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE | - FileFlags::FILE_FLAGS_APPEND); -} - -WriteAheadLog::~WriteAheadLog() { -} - -int64_t WriteAheadLog::GetWALSize() { - D_ASSERT(writer); - return writer->GetFileSize(); -} - -idx_t WriteAheadLog::GetTotalWritten() { - D_ASSERT(writer); - return writer->GetTotalWritten(); -} - -void 
WriteAheadLog::Truncate(int64_t size) { - writer->Truncate(size); -} - -void WriteAheadLog::Delete() { - if (!writer) { - return; - } - writer.reset(); - - auto &fs = FileSystem::Get(database); - fs.RemoveFile(wal_path); -} - -//===--------------------------------------------------------------------===// -// Write Entries -//===--------------------------------------------------------------------===// -void WriteAheadLog::WriteCheckpoint(MetaBlockPointer meta_block) { - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::CHECKPOINT); - serializer.WriteProperty(101, "meta_block", meta_block); - serializer.End(); -} - -//===--------------------------------------------------------------------===// -// CREATE TABLE -//===--------------------------------------------------------------------===// -void WriteAheadLog::WriteCreateTable(const TableCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::CREATE_TABLE); - serializer.WriteProperty(101, "table", &entry); - serializer.End(); -} - -//===--------------------------------------------------------------------===// -// DROP TABLE -//===--------------------------------------------------------------------===// -void WriteAheadLog::WriteDropTable(const TableCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::DROP_TABLE); - serializer.WriteProperty(101, "schema", entry.schema.name); - serializer.WriteProperty(102, "name", entry.name); - serializer.End(); -} - -//===--------------------------------------------------------------------===// -// CREATE SCHEMA -//===--------------------------------------------------------------------===// -void WriteAheadLog::WriteCreateSchema(const SchemaCatalogEntry &entry) { - if (skip_writing) { - 
return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::CREATE_SCHEMA); - serializer.WriteProperty(101, "schema", entry.name); - serializer.End(); -} - -//===--------------------------------------------------------------------===// -// SEQUENCES -//===--------------------------------------------------------------------===// -void WriteAheadLog::WriteCreateSequence(const SequenceCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::CREATE_SEQUENCE); - serializer.WriteProperty(101, "sequence", &entry); - serializer.End(); -} - -void WriteAheadLog::WriteDropSequence(const SequenceCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::DROP_SEQUENCE); - serializer.WriteProperty(101, "schema", entry.schema.name); - serializer.WriteProperty(102, "name", entry.name); - serializer.End(); -} - -void WriteAheadLog::WriteSequenceValue(const SequenceCatalogEntry &entry, SequenceValue val) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::SEQUENCE_VALUE); - serializer.WriteProperty(101, "schema", entry.schema.name); - serializer.WriteProperty(102, "name", entry.name); - serializer.WriteProperty(103, "usage_count", val.usage_count); - serializer.WriteProperty(104, "counter", val.counter); - serializer.End(); -} - -//===--------------------------------------------------------------------===// -// MACROS -//===--------------------------------------------------------------------===// -void WriteAheadLog::WriteCreateMacro(const ScalarMacroCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - 
serializer.WriteProperty(100, "wal_type", WALType::CREATE_MACRO); - serializer.WriteProperty(101, "macro", &entry); - serializer.End(); -} - -void WriteAheadLog::WriteDropMacro(const ScalarMacroCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::DROP_MACRO); - serializer.WriteProperty(101, "schema", entry.schema.name); - serializer.WriteProperty(102, "name", entry.name); - serializer.End(); -} - -void WriteAheadLog::WriteCreateTableMacro(const TableMacroCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::CREATE_TABLE_MACRO); - serializer.WriteProperty(101, "table", &entry); - serializer.End(); -} - -void WriteAheadLog::WriteDropTableMacro(const TableMacroCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::DROP_TABLE_MACRO); - serializer.WriteProperty(101, "schema", entry.schema.name); - serializer.WriteProperty(102, "name", entry.name); - serializer.End(); -} - -//===--------------------------------------------------------------------===// -// Indexes -//===--------------------------------------------------------------------===// -void WriteAheadLog::WriteCreateIndex(const IndexCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::CREATE_INDEX); - serializer.WriteProperty(101, "index", &entry); - serializer.End(); -} - -void WriteAheadLog::WriteDropIndex(const IndexCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::DROP_INDEX); - serializer.WriteProperty(101, 
"schema", entry.schema.name); - serializer.WriteProperty(102, "name", entry.name); - serializer.End(); -} - -//===--------------------------------------------------------------------===// -// Custom Types -//===--------------------------------------------------------------------===// -void WriteAheadLog::WriteCreateType(const TypeCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::CREATE_TYPE); - serializer.WriteProperty(101, "type", &entry); - serializer.End(); -} - -void WriteAheadLog::WriteDropType(const TypeCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::DROP_TYPE); - serializer.WriteProperty(101, "schema", entry.schema.name); - serializer.WriteProperty(102, "name", entry.name); - serializer.End(); -} - -//===--------------------------------------------------------------------===// -// VIEWS -//===--------------------------------------------------------------------===// -void WriteAheadLog::WriteCreateView(const ViewCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::CREATE_VIEW); - serializer.WriteProperty(101, "view", &entry); - serializer.End(); -} - -void WriteAheadLog::WriteDropView(const ViewCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::DROP_VIEW); - serializer.WriteProperty(101, "schema", entry.schema.name); - serializer.WriteProperty(102, "name", entry.name); - serializer.End(); -} - -//===--------------------------------------------------------------------===// -// DROP SCHEMA -//===--------------------------------------------------------------------===// -void 
WriteAheadLog::WriteDropSchema(const SchemaCatalogEntry &entry) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::DROP_SCHEMA); - serializer.WriteProperty(101, "schema", entry.name); - serializer.End(); -} - -//===--------------------------------------------------------------------===// -// DATA -//===--------------------------------------------------------------------===// -void WriteAheadLog::WriteSetTable(string &schema, string &table) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::USE_TABLE); - serializer.WriteProperty(101, "schema", schema); - serializer.WriteProperty(102, "table", table); - serializer.End(); -} - -void WriteAheadLog::WriteInsert(DataChunk &chunk) { - if (skip_writing) { - return; - } - D_ASSERT(chunk.size() > 0); - chunk.Verify(); - - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::INSERT_TUPLE); - serializer.WriteProperty(101, "chunk", chunk); - serializer.End(); -} - -void WriteAheadLog::WriteDelete(DataChunk &chunk) { - if (skip_writing) { - return; - } - D_ASSERT(chunk.size() > 0); - D_ASSERT(chunk.ColumnCount() == 1 && chunk.data[0].GetType() == LogicalType::ROW_TYPE); - chunk.Verify(); - - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::DELETE_TUPLE); - serializer.WriteProperty(101, "chunk", chunk); - serializer.End(); -} - -void WriteAheadLog::WriteUpdate(DataChunk &chunk, const vector &column_indexes) { - if (skip_writing) { - return; - } - D_ASSERT(chunk.size() > 0); - D_ASSERT(chunk.ColumnCount() == 2); - D_ASSERT(chunk.data[1].GetType().id() == LogicalType::ROW_TYPE); - chunk.Verify(); - - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", 
WALType::UPDATE_TUPLE); - serializer.WriteProperty(101, "column_indexes", column_indexes); - serializer.WriteProperty(102, "chunk", chunk); - serializer.End(); -} - -//===--------------------------------------------------------------------===// -// Write ALTER Statement -//===--------------------------------------------------------------------===// -void WriteAheadLog::WriteAlter(const AlterInfo &info) { - if (skip_writing) { - return; - } - BinarySerializer serializer(*writer); - serializer.Begin(); - serializer.WriteProperty(100, "wal_type", WALType::ALTER_INFO); - serializer.WriteProperty(101, "info", &info); - serializer.End(); -} - -//===--------------------------------------------------------------------===// -// FLUSH -//===--------------------------------------------------------------------===// -void WriteAheadLog::Flush() { - if (skip_writing) { - return; - } - - BinarySerializer serializer(*writer); - serializer.Begin(); - // write an empty entry - serializer.WriteProperty(100, "wal_type", WALType::WAL_FLUSH); - serializer.End(); - - // flushes all changes made to the WAL to disk - writer->Sync(); -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -CleanupState::CleanupState() : current_table(nullptr), count(0) { -} - -CleanupState::~CleanupState() { - Flush(); -} - -void CleanupState::CleanupEntry(UndoFlags type, data_ptr_t data) { - switch (type) { - case UndoFlags::CATALOG_ENTRY: { - auto catalog_entry = Load(data); - D_ASSERT(catalog_entry); - D_ASSERT(catalog_entry->set); - catalog_entry->set->CleanupEntry(*catalog_entry); - break; - } - case UndoFlags::DELETE_TUPLE: { - auto info = reinterpret_cast(data); - CleanupDelete(*info); - break; - } - case UndoFlags::UPDATE_TUPLE: { - auto info = reinterpret_cast(data); - CleanupUpdate(*info); - break; - } - default: - break; - } -} - -void CleanupState::CleanupUpdate(UpdateInfo &info) { - // remove the update info from the update chain - // first obtain an exclusive lock on the 
segment - info.segment->CleanupUpdate(info); -} - -void CleanupState::CleanupDelete(DeleteInfo &info) { - auto version_table = info.table; - D_ASSERT(version_table->info->cardinality >= info.count); - version_table->info->cardinality -= info.count; - - if (version_table->info->indexes.Empty()) { - // this table has no indexes: no cleanup to be done - return; - } - - if (current_table != version_table) { - // table for this entry differs from previous table: flush and switch to the new table - Flush(); - current_table = version_table; - } - - // possibly vacuum any indexes in this table later - indexed_tables[current_table->info->table] = current_table; - - count = 0; - for (idx_t i = 0; i < info.count; i++) { - row_numbers[count++] = info.base_row + info.rows[i]; - } - Flush(); -} - -void CleanupState::Flush() { - if (count == 0) { - return; - } - - // set up the row identifiers vector - Vector row_identifiers(LogicalType::ROW_TYPE, data_ptr_cast(row_numbers)); - - // delete the tuples from all the indexes - try { - current_table->RemoveFromIndexes(row_identifiers, count); - } catch (...) 
{ - } - - count = 0; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -CommitState::CommitState(transaction_t commit_id, optional_ptr log) - : log(log), commit_id(commit_id), current_table_info(nullptr) { -} - -void CommitState::SwitchTable(DataTableInfo *table_info, UndoFlags new_op) { - if (current_table_info != table_info) { - // write the current table to the log - log->WriteSetTable(table_info->schema, table_info->table); - current_table_info = table_info; - } -} - -void CommitState::WriteCatalogEntry(CatalogEntry &entry, data_ptr_t dataptr) { - if (entry.temporary || entry.parent->temporary) { - return; - } - D_ASSERT(log); - // look at the type of the parent entry - auto parent = entry.parent; - switch (parent->type) { - case CatalogType::TABLE_ENTRY: - if (entry.type == CatalogType::TABLE_ENTRY) { - auto &table_entry = entry.Cast(); - D_ASSERT(table_entry.IsDuckTable()); - // ALTER TABLE statement, read the extra data after the entry - - auto extra_data_size = Load(dataptr); - auto extra_data = data_ptr_cast(dataptr + sizeof(idx_t)); - - MemoryStream source(extra_data, extra_data_size); - BinaryDeserializer deserializer(source); - deserializer.Begin(); - auto column_name = deserializer.ReadProperty(100, "column_name"); - auto parse_info = deserializer.ReadProperty>(101, "alter_info"); - deserializer.End(); - - if (!column_name.empty()) { - // write the alter table in the log - table_entry.CommitAlter(column_name); - } - auto &alter_info = parse_info->Cast(); - log->WriteAlter(alter_info); - } else { - // CREATE TABLE statement - log->WriteCreateTable(parent->Cast()); - } - break; - case CatalogType::SCHEMA_ENTRY: - if (entry.type == CatalogType::SCHEMA_ENTRY) { - // ALTER TABLE statement, skip it - return; - } - log->WriteCreateSchema(parent->Cast()); - break; - case CatalogType::VIEW_ENTRY: - if (entry.type == CatalogType::VIEW_ENTRY) { - // ALTER TABLE statement, read the extra data after the entry - auto 
extra_data_size = Load(dataptr); - auto extra_data = data_ptr_cast(dataptr + sizeof(idx_t)); - // deserialize it - MemoryStream source(extra_data, extra_data_size); - BinaryDeserializer deserializer(source); - deserializer.Begin(); - auto column_name = deserializer.ReadProperty(100, "column_name"); - auto parse_info = deserializer.ReadProperty>(101, "alter_info"); - deserializer.End(); - - (void)column_name; - - // write the alter table in the log - auto &alter_info = parse_info->Cast(); - log->WriteAlter(alter_info); - } else { - log->WriteCreateView(parent->Cast()); - } - break; - case CatalogType::SEQUENCE_ENTRY: - log->WriteCreateSequence(parent->Cast()); - break; - case CatalogType::MACRO_ENTRY: - log->WriteCreateMacro(parent->Cast()); - break; - case CatalogType::TABLE_MACRO_ENTRY: - log->WriteCreateTableMacro(parent->Cast()); - break; - case CatalogType::INDEX_ENTRY: - log->WriteCreateIndex(parent->Cast()); - break; - case CatalogType::TYPE_ENTRY: - log->WriteCreateType(parent->Cast()); - break; - case CatalogType::DELETED_ENTRY: - switch (entry.type) { - case CatalogType::TABLE_ENTRY: { - auto &table_entry = entry.Cast(); - D_ASSERT(table_entry.IsDuckTable()); - table_entry.CommitDrop(); - log->WriteDropTable(table_entry); - break; - } - case CatalogType::SCHEMA_ENTRY: - log->WriteDropSchema(entry.Cast()); - break; - case CatalogType::VIEW_ENTRY: - log->WriteDropView(entry.Cast()); - break; - case CatalogType::SEQUENCE_ENTRY: - log->WriteDropSequence(entry.Cast()); - break; - case CatalogType::MACRO_ENTRY: - log->WriteDropMacro(entry.Cast()); - break; - case CatalogType::TABLE_MACRO_ENTRY: - log->WriteDropTableMacro(entry.Cast()); - break; - case CatalogType::TYPE_ENTRY: - log->WriteDropType(entry.Cast()); - break; - case CatalogType::INDEX_ENTRY: { - auto &index_entry = entry.Cast(); - index_entry.CommitDrop(); - log->WriteDropIndex(entry.Cast()); - break; - } - case CatalogType::PREPARED_STATEMENT: - case CatalogType::SCALAR_FUNCTION_ENTRY: - // do 
nothing, indexes/prepared statements/functions aren't persisted to disk - break; - default: - throw InternalException("Don't know how to drop this type!"); - } - break; - case CatalogType::PREPARED_STATEMENT: - case CatalogType::AGGREGATE_FUNCTION_ENTRY: - case CatalogType::SCALAR_FUNCTION_ENTRY: - case CatalogType::TABLE_FUNCTION_ENTRY: - case CatalogType::COPY_FUNCTION_ENTRY: - case CatalogType::PRAGMA_FUNCTION_ENTRY: - case CatalogType::COLLATION_ENTRY: - // do nothing, these entries are not persisted to disk - break; - default: - throw InternalException("UndoBuffer - don't know how to write this entry to the WAL"); - } -} - -void CommitState::WriteDelete(DeleteInfo &info) { - D_ASSERT(log); - // switch to the current table, if necessary - SwitchTable(info.table->info.get(), UndoFlags::DELETE_TUPLE); - - if (!delete_chunk) { - delete_chunk = make_uniq(); - vector delete_types = {LogicalType::ROW_TYPE}; - delete_chunk->Initialize(Allocator::DefaultAllocator(), delete_types); - } - auto rows = FlatVector::GetData(delete_chunk->data[0]); - for (idx_t i = 0; i < info.count; i++) { - rows[i] = info.base_row + info.rows[i]; - } - delete_chunk->SetCardinality(info.count); - log->WriteDelete(*delete_chunk); -} - -void CommitState::WriteUpdate(UpdateInfo &info) { - D_ASSERT(log); - // switch to the current table, if necessary - auto &column_data = info.segment->column_data; - auto &table_info = column_data.GetTableInfo(); - - SwitchTable(&table_info, UndoFlags::UPDATE_TUPLE); - - // initialize the update chunk - vector update_types; - if (column_data.type.id() == LogicalTypeId::VALIDITY) { - update_types.emplace_back(LogicalType::BOOLEAN); - } else { - update_types.push_back(column_data.type); - } - update_types.emplace_back(LogicalType::ROW_TYPE); - - update_chunk = make_uniq(); - update_chunk->Initialize(Allocator::DefaultAllocator(), update_types); - - // fetch the updated values from the base segment - info.segment->FetchCommitted(info.vector_index, 
update_chunk->data[0]); - - // write the row ids into the chunk - auto row_ids = FlatVector::GetData(update_chunk->data[1]); - idx_t start = column_data.start + info.vector_index * STANDARD_VECTOR_SIZE; - for (idx_t i = 0; i < info.N; i++) { - row_ids[info.tuples[i]] = start + info.tuples[i]; - } - if (column_data.type.id() == LogicalTypeId::VALIDITY) { - // zero-initialize the booleans - // FIXME: this is only required because of NullValue in Vector::Serialize... - auto booleans = FlatVector::GetData(update_chunk->data[0]); - for (idx_t i = 0; i < info.N; i++) { - auto idx = info.tuples[i]; - booleans[idx] = false; - } - } - SelectionVector sel(info.tuples); - update_chunk->Slice(sel, info.N); - - // construct the column index path - vector column_indexes; - reference current_column_data = column_data; - while (current_column_data.get().parent) { - column_indexes.push_back(current_column_data.get().column_index); - current_column_data = *current_column_data.get().parent; - } - column_indexes.push_back(info.column_index); - std::reverse(column_indexes.begin(), column_indexes.end()); - - log->WriteUpdate(*update_chunk, column_indexes); -} - -template -void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { - switch (type) { - case UndoFlags::CATALOG_ENTRY: { - // set the commit timestamp of the catalog entry to the given id - auto catalog_entry = Load(data); - D_ASSERT(catalog_entry->parent); - - auto &catalog = catalog_entry->ParentCatalog(); - D_ASSERT(catalog.IsDuckCatalog()); - - // Grab a write lock on the catalog - auto &duck_catalog = catalog.Cast(); - lock_guard write_lock(duck_catalog.GetWriteLock()); - catalog_entry->set->UpdateTimestamp(*catalog_entry->parent, commit_id); - if (catalog_entry->name != catalog_entry->parent->name) { - catalog_entry->set->UpdateTimestamp(*catalog_entry, commit_id); - } - if (HAS_LOG) { - // push the catalog update to the WAL - WriteCatalogEntry(*catalog_entry, data + sizeof(CatalogEntry *)); - } - break; - } - case 
UndoFlags::INSERT_TUPLE: { - // append: - auto info = reinterpret_cast(data); - if (HAS_LOG && !info->table->info->IsTemporary()) { - info->table->WriteToLog(*log, info->start_row, info->count); - } - // mark the tuples as committed - info->table->CommitAppend(commit_id, info->start_row, info->count); - break; - } - case UndoFlags::DELETE_TUPLE: { - // deletion: - auto info = reinterpret_cast(data); - if (HAS_LOG && !info->table->info->IsTemporary()) { - WriteDelete(*info); - } - // mark the tuples as committed - info->version_info->CommitDelete(info->vector_idx, commit_id, info->rows, info->count); - break; - } - case UndoFlags::UPDATE_TUPLE: { - // update: - auto info = reinterpret_cast(data); - if (HAS_LOG && !info->segment->column_data.GetTableInfo().IsTemporary()) { - WriteUpdate(*info); - } - info->version_number = commit_id; - break; - } - default: - throw InternalException("UndoBuffer - don't know how to commit this type!"); - } -} - -void CommitState::RevertCommit(UndoFlags type, data_ptr_t data) { - transaction_t transaction_id = commit_id; - switch (type) { - case UndoFlags::CATALOG_ENTRY: { - // set the commit timestamp of the catalog entry to the given id - auto catalog_entry = Load(data); - D_ASSERT(catalog_entry->parent); - catalog_entry->set->UpdateTimestamp(*catalog_entry->parent, transaction_id); - if (catalog_entry->name != catalog_entry->parent->name) { - catalog_entry->set->UpdateTimestamp(*catalog_entry, transaction_id); - } - break; - } - case UndoFlags::INSERT_TUPLE: { - auto info = reinterpret_cast(data); - // revert this append - info->table->RevertAppend(info->start_row, info->count); - break; - } - case UndoFlags::DELETE_TUPLE: { - // deletion: - auto info = reinterpret_cast(data); - info->table->info->cardinality += info->count; - // revert the commit by writing the (uncommitted) transaction_id back into the version info - info->version_info->CommitDelete(info->vector_idx, transaction_id, info->rows, info->count); - break; - } - case 
UndoFlags::UPDATE_TUPLE: { - // update: - auto info = reinterpret_cast(data); - info->version_number = transaction_id; - break; - } - default: - throw InternalException("UndoBuffer - don't know how to revert commit of this type!"); - } -} - -template void CommitState::CommitEntry(UndoFlags type, data_ptr_t data); -template void CommitState::CommitEntry(UndoFlags type, data_ptr_t data); - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -TransactionData::TransactionData(DuckTransaction &transaction_p) // NOLINT - : transaction(&transaction_p), transaction_id(transaction_p.transaction_id), start_time(transaction_p.start_time) { -} -TransactionData::TransactionData(transaction_t transaction_id_p, transaction_t start_time_p) - : transaction(nullptr), transaction_id(transaction_id_p), start_time(start_time_p) { -} - -DuckTransaction::DuckTransaction(TransactionManager &manager, ClientContext &context_p, transaction_t start_time, - transaction_t transaction_id) - : Transaction(manager, context_p), start_time(start_time), transaction_id(transaction_id), commit_id(0), - highest_active_query(0), undo_buffer(context_p), storage(make_uniq(context_p, *this)) { -} - -DuckTransaction::~DuckTransaction() { -} - -DuckTransaction &DuckTransaction::Get(ClientContext &context, AttachedDatabase &db) { - return DuckTransaction::Get(context, db.GetCatalog()); -} - -DuckTransaction &DuckTransaction::Get(ClientContext &context, Catalog &catalog) { - auto &transaction = Transaction::Get(context, catalog); - if (!transaction.IsDuckTransaction()) { - throw InternalException("DuckTransaction::Get called on non-DuckDB transaction"); - } - return transaction.Cast(); -} - -LocalStorage &DuckTransaction::GetLocalStorage() { - return *storage; -} - -void DuckTransaction::PushCatalogEntry(CatalogEntry &entry, data_ptr_t extra_data, idx_t extra_data_size) { - idx_t alloc_size = sizeof(CatalogEntry *); - if (extra_data_size > 0) { - alloc_size += extra_data_size + 
sizeof(idx_t); - } - auto baseptr = undo_buffer.CreateEntry(UndoFlags::CATALOG_ENTRY, alloc_size); - // store the pointer to the catalog entry - Store(&entry, baseptr); - if (extra_data_size > 0) { - // copy the extra data behind the catalog entry pointer (if any) - baseptr += sizeof(CatalogEntry *); - // first store the extra data size - Store(extra_data_size, baseptr); - baseptr += sizeof(idx_t); - // then copy over the actual data - memcpy(baseptr, extra_data, extra_data_size); - } -} - -void DuckTransaction::PushDelete(DataTable &table, RowVersionManager &info, idx_t vector_idx, row_t rows[], idx_t count, - idx_t base_row) { - auto delete_info = reinterpret_cast( - undo_buffer.CreateEntry(UndoFlags::DELETE_TUPLE, sizeof(DeleteInfo) + sizeof(row_t) * count)); - delete_info->version_info = &info; - delete_info->vector_idx = vector_idx; - delete_info->table = &table; - delete_info->count = count; - delete_info->base_row = base_row; - memcpy(delete_info->rows, rows, sizeof(row_t) * count); -} - -void DuckTransaction::PushAppend(DataTable &table, idx_t start_row, idx_t row_count) { - auto append_info = - reinterpret_cast(undo_buffer.CreateEntry(UndoFlags::INSERT_TUPLE, sizeof(AppendInfo))); - append_info->table = &table; - append_info->start_row = start_row; - append_info->count = row_count; -} - -UpdateInfo *DuckTransaction::CreateUpdateInfo(idx_t type_size, idx_t entries) { - data_ptr_t base_info = undo_buffer.CreateEntry( - UndoFlags::UPDATE_TUPLE, sizeof(UpdateInfo) + (sizeof(sel_t) + type_size) * STANDARD_VECTOR_SIZE); - auto update_info = reinterpret_cast(base_info); - update_info->max = STANDARD_VECTOR_SIZE; - update_info->tuples = reinterpret_cast(base_info + sizeof(UpdateInfo)); - update_info->tuple_data = base_info + sizeof(UpdateInfo) + sizeof(sel_t) * update_info->max; - update_info->version_number = transaction_id; - return update_info; -} - -bool DuckTransaction::ChangesMade() { - return undo_buffer.ChangesMade() || storage->ChangesMade(); -} - -bool 
DuckTransaction::AutomaticCheckpoint(AttachedDatabase &db) { - auto &storage_manager = db.GetStorageManager(); - return storage_manager.AutomaticCheckpoint(storage->EstimatedSize() + undo_buffer.EstimatedSize()); -} - -string DuckTransaction::Commit(AttachedDatabase &db, transaction_t commit_id, bool checkpoint) noexcept { - // "checkpoint" parameter indicates if the caller will checkpoint. If checkpoint == - // true: Then this function will NOT write to the WAL or flush/persist. - // This method only makes commit in memory, expecting caller to checkpoint/flush. - // false: Then this function WILL write to the WAL and Flush/Persist it. - this->commit_id = commit_id; - - UndoBuffer::IteratorState iterator_state; - LocalStorage::CommitState commit_state; - unique_ptr storage_commit_state; - optional_ptr log; - if (!db.IsSystem()) { - auto &storage_manager = db.GetStorageManager(); - log = storage_manager.GetWriteAheadLog(); - storage_commit_state = storage_manager.GenStorageCommitState(*this, checkpoint); - } else { - log = nullptr; - } - try { - storage->Commit(commit_state, *this); - undo_buffer.Commit(iterator_state, log, commit_id); - if (log) { - // commit any sequences that were used to the WAL - for (auto &entry : sequence_usage) { - log->WriteSequenceValue(*entry.first, entry.second); - } - } - if (storage_commit_state) { - storage_commit_state->FlushCommit(); - } - return string(); - } catch (std::exception &ex) { - undo_buffer.RevertCommit(iterator_state, this->transaction_id); - return ex.what(); - } -} - -void DuckTransaction::Rollback() noexcept { - storage->Rollback(); - undo_buffer.Rollback(); -} - -void DuckTransaction::Cleanup() { - undo_buffer.Cleanup(); -} - -} // namespace duckdb - - - - - - - - - - - - - - - -namespace duckdb { - -struct CheckpointLock { - explicit CheckpointLock(DuckTransactionManager &manager) : manager(manager), is_locked(false) { - } - ~CheckpointLock() { - Unlock(); - } - - DuckTransactionManager &manager; - bool is_locked; 
- - void Lock() { - D_ASSERT(!manager.thread_is_checkpointing); - manager.thread_is_checkpointing = true; - is_locked = true; - } - void Unlock() { - if (!is_locked) { - return; - } - D_ASSERT(manager.thread_is_checkpointing); - manager.thread_is_checkpointing = false; - is_locked = false; - } -}; - -DuckTransactionManager::DuckTransactionManager(AttachedDatabase &db) - : TransactionManager(db), thread_is_checkpointing(false) { - // start timestamp starts at two - current_start_timestamp = 2; - // transaction ID starts very high: - // it should be much higher than the current start timestamp - // if transaction_id < start_timestamp for any set of active transactions - // uncommited data could be read by - current_transaction_id = TRANSACTION_ID_START; - lowest_active_id = TRANSACTION_ID_START; - lowest_active_start = MAX_TRANSACTION_ID; -} - -DuckTransactionManager::~DuckTransactionManager() { -} - -DuckTransactionManager &DuckTransactionManager::Get(AttachedDatabase &db) { - auto &transaction_manager = TransactionManager::Get(db); - if (!transaction_manager.IsDuckTransactionManager()) { - throw InternalException("Calling DuckTransactionManager::Get on non-DuckDB transaction manager"); - } - return reinterpret_cast(transaction_manager); -} - -Transaction *DuckTransactionManager::StartTransaction(ClientContext &context) { - // obtain the transaction lock during this function - lock_guard lock(transaction_lock); - if (current_start_timestamp >= TRANSACTION_ID_START) { // LCOV_EXCL_START - throw InternalException("Cannot start more transactions, ran out of " - "transaction identifiers!"); - } // LCOV_EXCL_STOP - - // obtain the start time and transaction ID of this transaction - transaction_t start_time = current_start_timestamp++; - transaction_t transaction_id = current_transaction_id++; - if (active_transactions.empty()) { - lowest_active_start = start_time; - lowest_active_id = transaction_id; - } - - // create the actual transaction - auto transaction = 
make_uniq(*this, context, start_time, transaction_id); - auto transaction_ptr = transaction.get(); - - // store it in the set of active transactions - active_transactions.push_back(std::move(transaction)); - return transaction_ptr; -} - -struct ClientLockWrapper { - ClientLockWrapper(mutex &client_lock, shared_ptr connection) - : connection(std::move(connection)), connection_lock(make_uniq>(client_lock)) { - } - - shared_ptr connection; - unique_ptr> connection_lock; -}; - -void DuckTransactionManager::LockClients(vector &client_locks, ClientContext &context) { - auto &connection_manager = ConnectionManager::Get(context); - client_locks.emplace_back(connection_manager.connections_lock, nullptr); - auto connection_list = connection_manager.GetConnectionList(); - for (auto &con : connection_list) { - if (con.get() == &context) { - continue; - } - auto &context_lock = con->context_lock; - client_locks.emplace_back(context_lock, std::move(con)); - } -} - -void DuckTransactionManager::Checkpoint(ClientContext &context, bool force) { - auto &storage_manager = db.GetStorageManager(); - if (storage_manager.InMemory()) { - return; - } - - // first check if no other thread is checkpointing right now - auto lock = unique_lock(transaction_lock); - if (thread_is_checkpointing) { - throw TransactionException("Cannot CHECKPOINT: another thread is checkpointing right now"); - } - CheckpointLock checkpoint_lock(*this); - checkpoint_lock.Lock(); - lock.unlock(); - - // lock all the clients AND the connection manager now - // this ensures no new queries can be started, and no new connections to the database can be made - // to avoid deadlock we release the transaction lock while locking the clients - vector client_locks; - LockClients(client_locks, context); - - auto current = &DuckTransaction::Get(context, db); - lock.lock(); - if (current->ChangesMade()) { - throw TransactionException("Cannot CHECKPOINT: the current transaction has transaction local changes"); - } - if (!force) { - 
if (!CanCheckpoint(current)) { - throw TransactionException("Cannot CHECKPOINT: there are other transactions. Use FORCE CHECKPOINT to abort " - "the other transactions and force a checkpoint"); - } - } else { - if (!CanCheckpoint(current)) { - for (size_t i = 0; i < active_transactions.size(); i++) { - auto &transaction = active_transactions[i]; - // rollback the transaction - transaction->Rollback(); - auto transaction_context = transaction->context.lock(); - - // remove the transaction id from the list of active transactions - // potentially resulting in garbage collection - RemoveTransaction(*transaction); - if (transaction_context) { - transaction_context->transaction.ClearTransaction(); - } - i--; - } - D_ASSERT(CanCheckpoint(nullptr)); - } - } - storage_manager.CreateCheckpoint(); -} - -bool DuckTransactionManager::CanCheckpoint(optional_ptr current) { - if (db.IsSystem()) { - return false; - } - auto &storage_manager = db.GetStorageManager(); - if (storage_manager.InMemory()) { - return false; - } - if (!recently_committed_transactions.empty() || !old_transactions.empty()) { - return false; - } - for (auto &transaction : active_transactions) { - if (transaction.get() != current.get()) { - return false; - } - } - return true; -} - -string DuckTransactionManager::CommitTransaction(ClientContext &context, Transaction *transaction_p) { - auto &transaction = transaction_p->Cast(); - vector client_locks; - auto lock = make_uniq>(transaction_lock); - CheckpointLock checkpoint_lock(*this); - // check if we can checkpoint - bool checkpoint = thread_is_checkpointing ? 
false : CanCheckpoint(&transaction); - if (checkpoint) { - if (transaction.AutomaticCheckpoint(db)) { - checkpoint_lock.Lock(); - // we might be able to checkpoint: lock all clients - // to avoid deadlock we release the transaction lock while locking the clients - lock.reset(); - - LockClients(client_locks, context); - - lock = make_uniq>(transaction_lock); - checkpoint = CanCheckpoint(&transaction); - if (!checkpoint) { - checkpoint_lock.Unlock(); - client_locks.clear(); - } - } else { - checkpoint = false; - } - } - // obtain a commit id for the transaction - transaction_t commit_id = current_start_timestamp++; - // commit the UndoBuffer of the transaction - string error = transaction.Commit(db, commit_id, checkpoint); - if (!error.empty()) { - // commit unsuccessful: rollback the transaction instead - checkpoint = false; - transaction.commit_id = 0; - transaction.Rollback(); - } - if (!checkpoint) { - // we won't checkpoint after all: unlock the clients again - checkpoint_lock.Unlock(); - client_locks.clear(); - } - - // commit successful: remove the transaction id from the list of active transactions - // potentially resulting in garbage collection - RemoveTransaction(transaction); - // now perform a checkpoint if (1) we are able to checkpoint, and (2) the WAL has reached sufficient size to - // checkpoint - if (checkpoint) { - // checkpoint the database to disk - auto &storage_manager = db.GetStorageManager(); - storage_manager.CreateCheckpoint(false, true); - } - return error; -} - -void DuckTransactionManager::RollbackTransaction(Transaction *transaction_p) { - auto &transaction = transaction_p->Cast(); - // obtain the transaction lock during this function - lock_guard lock(transaction_lock); - - // rollback the transaction - transaction.Rollback(); - - // remove the transaction id from the list of active transactions - // potentially resulting in garbage collection - RemoveTransaction(transaction); -} - -void 
DuckTransactionManager::RemoveTransaction(DuckTransaction &transaction) noexcept { - // remove the transaction from the list of active transactions - idx_t t_index = active_transactions.size(); - // check for the lowest and highest start time in the list of transactions - transaction_t lowest_start_time = TRANSACTION_ID_START; - transaction_t lowest_transaction_id = MAX_TRANSACTION_ID; - transaction_t lowest_active_query = MAXIMUM_QUERY_ID; - for (idx_t i = 0; i < active_transactions.size(); i++) { - if (active_transactions[i].get() == &transaction) { - t_index = i; - } else { - transaction_t active_query = active_transactions[i]->active_query; - lowest_start_time = MinValue(lowest_start_time, active_transactions[i]->start_time); - lowest_active_query = MinValue(lowest_active_query, active_query); - lowest_transaction_id = MinValue(lowest_transaction_id, active_transactions[i]->transaction_id); - } - } - lowest_active_start = lowest_start_time; - lowest_active_id = lowest_transaction_id; - - transaction_t lowest_stored_query = lowest_start_time; - D_ASSERT(t_index != active_transactions.size()); - auto current_transaction = std::move(active_transactions[t_index]); - auto current_query = DatabaseManager::Get(db).ActiveQueryNumber(); - if (transaction.commit_id != 0) { - // the transaction was committed, add it to the list of recently - // committed transactions - recently_committed_transactions.push_back(std::move(current_transaction)); - } else { - // the transaction was aborted, but we might still need its information - // add it to the set of transactions awaiting GC - current_transaction->highest_active_query = current_query; - old_transactions.push_back(std::move(current_transaction)); - } - // remove the transaction from the set of currently active transactions - active_transactions.erase(active_transactions.begin() + t_index); - // traverse the recently_committed transactions to see if we can remove any - idx_t i = 0; - for (; i < 
recently_committed_transactions.size(); i++) { - D_ASSERT(recently_committed_transactions[i]); - lowest_stored_query = MinValue(recently_committed_transactions[i]->start_time, lowest_stored_query); - if (recently_committed_transactions[i]->commit_id < lowest_start_time) { - // changes made BEFORE this transaction are no longer relevant - // we can cleanup the undo buffer - - // HOWEVER: any currently running QUERY can still be using - // the version information after the cleanup! - - // if we remove the UndoBuffer immediately, we have a race - // condition - - // we can only safely do the actual memory cleanup when all the - // currently active queries have finished running! (actually, - // when all the currently active scans have finished running...) - recently_committed_transactions[i]->Cleanup(); - // store the current highest active query - recently_committed_transactions[i]->highest_active_query = current_query; - // move it to the list of transactions awaiting GC - old_transactions.push_back(std::move(recently_committed_transactions[i])); - } else { - // recently_committed_transactions is ordered on commit_id - // implicitly thus if the current one is bigger than - // lowest_start_time any subsequent ones are also bigger - break; - } - } - if (i > 0) { - // we garbage collected transactions: remove them from the list - recently_committed_transactions.erase(recently_committed_transactions.begin(), - recently_committed_transactions.begin() + i); - } - // check if we can free the memory of any old transactions - i = active_transactions.empty() ? 
old_transactions.size() : 0; - for (; i < old_transactions.size(); i++) { - D_ASSERT(old_transactions[i]); - D_ASSERT(old_transactions[i]->highest_active_query > 0); - if (old_transactions[i]->highest_active_query >= lowest_active_query) { - // there is still a query running that could be using - // this transactions' data - break; - } - } - if (i > 0) { - // we garbage collected transactions: remove them from the list - old_transactions.erase(old_transactions.begin(), old_transactions.begin() + i); - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -MetaTransaction::MetaTransaction(ClientContext &context_p, timestamp_t start_timestamp_p, idx_t catalog_version_p) - : context(context_p), start_timestamp(start_timestamp_p), catalog_version(catalog_version_p), read_only(true), - active_query(MAXIMUM_QUERY_ID), modified_database(nullptr) { -} - -MetaTransaction &MetaTransaction::Get(ClientContext &context) { - return context.transaction.ActiveTransaction(); -} - -ValidChecker &ValidChecker::Get(MetaTransaction &transaction) { - return transaction.transaction_validity; -} - -Transaction &Transaction::Get(ClientContext &context, AttachedDatabase &db) { - auto &meta_transaction = MetaTransaction::Get(context); - return meta_transaction.GetTransaction(db); -} - -Transaction &MetaTransaction::GetTransaction(AttachedDatabase &db) { - auto entry = transactions.find(&db); - if (entry == transactions.end()) { - auto new_transaction = db.GetTransactionManager().StartTransaction(context); - if (!new_transaction) { - throw InternalException("StartTransaction did not return a valid transaction"); - } - new_transaction->active_query = active_query; - all_transactions.push_back(&db); - transactions[&db] = new_transaction; - return *new_transaction; - } else { - D_ASSERT(entry->second->active_query == active_query); - return *entry->second; - } -} - -Transaction &Transaction::Get(ClientContext &context, Catalog &catalog) { - return Transaction::Get(context, 
catalog.GetAttached()); -} - -string MetaTransaction::Commit() { - string error; - // commit transactions in reverse order - for (idx_t i = all_transactions.size(); i > 0; i--) { - auto db = all_transactions[i - 1]; - auto entry = transactions.find(db.get()); - if (entry == transactions.end()) { - throw InternalException("Could not find transaction corresponding to database in MetaTransaction"); - } - auto &transaction_manager = db->GetTransactionManager(); - auto transaction = entry->second; - if (error.empty()) { - // commit - error = transaction_manager.CommitTransaction(context, transaction); - } else { - // we have encountered an error previously - roll back subsequent entries - transaction_manager.RollbackTransaction(transaction); - } - } - return error; -} - -void MetaTransaction::Rollback() { - // rollback transactions in reverse order - for (idx_t i = all_transactions.size(); i > 0; i--) { - auto db = all_transactions[i - 1]; - auto &transaction_manager = db->GetTransactionManager(); - auto entry = transactions.find(db.get()); - D_ASSERT(entry != transactions.end()); - auto transaction = entry->second; - transaction_manager.RollbackTransaction(transaction); - } -} - -idx_t MetaTransaction::GetActiveQuery() { - return active_query; -} - -void MetaTransaction::SetActiveQuery(transaction_t query_number) { - active_query = query_number; - for (auto &entry : transactions) { - entry.second->active_query = query_number; - } -} - -void MetaTransaction::ModifyDatabase(AttachedDatabase &db) { - if (db.IsSystem() || db.IsTemporary()) { - // we can always modify the system and temp databases - return; - } - if (!modified_database) { - modified_database = &db; - return; - } - if (&db != modified_database.get()) { - throw TransactionException( - "Attempting to write to database \"%s\" in a transaction that has already modified database \"%s\" - a " - "single transaction can only write to a single attached database.", - db.GetName(), modified_database->GetName()); - } -} 
- -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -void RollbackState::RollbackEntry(UndoFlags type, data_ptr_t data) { - switch (type) { - case UndoFlags::CATALOG_ENTRY: { - // undo this catalog entry - auto catalog_entry = Load(data); - D_ASSERT(catalog_entry->set); - catalog_entry->set->Undo(*catalog_entry); - break; - } - case UndoFlags::INSERT_TUPLE: { - auto info = reinterpret_cast(data); - // revert the append in the base table - info->table->RevertAppend(info->start_row, info->count); - break; - } - case UndoFlags::DELETE_TUPLE: { - auto info = reinterpret_cast(data); - // reset the deleted flag on rollback - info->version_info->CommitDelete(info->vector_idx, NOT_DELETED_ID, info->rows, info->count); - break; - } - case UndoFlags::UPDATE_TUPLE: { - auto info = reinterpret_cast(data); - info->segment->RollbackUpdate(*info); - break; - } - default: // LCOV_EXCL_START - D_ASSERT(type == UndoFlags::EMPTY_ENTRY); - break; - } // LCOV_EXCL_STOP -} - -} // namespace duckdb - - - - - -namespace duckdb { - -Transaction::Transaction(TransactionManager &manager_p, ClientContext &context_p) - : manager(manager_p), context(context_p.shared_from_this()), active_query(MAXIMUM_QUERY_ID) { -} - -Transaction::~Transaction() { -} - -bool Transaction::IsReadOnly() { - auto ctxt = context.lock(); - if (!ctxt) { - throw InternalException("Transaction::IsReadOnly() called after client context has been destroyed"); - } - auto &db = manager.GetDB(); - return MetaTransaction::Get(*ctxt).ModifiedDatabase().get() != &db; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -TransactionContext::TransactionContext(ClientContext &context) - : context(context), auto_commit(true), current_transaction(nullptr) { -} - -TransactionContext::~TransactionContext() { - if (current_transaction) { - try { - Rollback(); - } catch (...) 
{ - } - } -} - -void TransactionContext::BeginTransaction() { - if (current_transaction) { - throw TransactionException("cannot start a transaction within a transaction"); - } - auto start_timestamp = Timestamp::GetCurrentTimestamp(); - auto catalog_version = Catalog::GetSystemCatalog(context).GetCatalogVersion(); - current_transaction = make_uniq(context, start_timestamp, catalog_version); - - auto &config = DBConfig::GetConfig(context); - if (config.options.immediate_transaction_mode) { - // if immediate transaction mode is enabled then start all transactions immediately - auto databases = DatabaseManager::Get(context).GetDatabases(context); - for (auto db : databases) { - current_transaction->GetTransaction(db.get()); - } - } -} - -void TransactionContext::Commit() { - if (!current_transaction) { - throw TransactionException("failed to commit: no transaction active"); - } - auto transaction = std::move(current_transaction); - ClearTransaction(); - string error = transaction->Commit(); - if (!error.empty()) { - throw TransactionException("Failed to commit: %s", error); - } -} - -void TransactionContext::SetAutoCommit(bool value) { - auto_commit = value; - if (!auto_commit && !current_transaction) { - BeginTransaction(); - } -} - -void TransactionContext::Rollback() { - if (!current_transaction) { - throw TransactionException("failed to rollback: no transaction active"); - } - auto transaction = std::move(current_transaction); - ClearTransaction(); - transaction->Rollback(); -} - -void TransactionContext::ClearTransaction() { - SetAutoCommit(true); - current_transaction = nullptr; -} - -idx_t TransactionContext::GetActiveQuery() { - if (!current_transaction) { - throw InternalException("GetActiveQuery called without active transaction"); - } - return current_transaction->GetActiveQuery(); -} - -void TransactionContext::ResetActiveQuery() { - if (current_transaction) { - SetActiveQuery(MAXIMUM_QUERY_ID); - } -} - -void 
TransactionContext::SetActiveQuery(transaction_t query_number) { - if (!current_transaction) { - throw InternalException("SetActiveQuery called without active transaction"); - } - current_transaction->SetActiveQuery(query_number); -} - -} // namespace duckdb - - -namespace duckdb { - -TransactionManager::TransactionManager(AttachedDatabase &db) : db(db) { -} - -TransactionManager::~TransactionManager() { -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { -constexpr uint32_t UNDO_ENTRY_HEADER_SIZE = sizeof(UndoFlags) + sizeof(uint32_t); - -UndoBuffer::UndoBuffer(ClientContext &context_p) : allocator(BufferAllocator::Get(context_p)) { -} - -data_ptr_t UndoBuffer::CreateEntry(UndoFlags type, idx_t len) { - D_ASSERT(len <= NumericLimits::Maximum()); - len = AlignValue(len); - idx_t needed_space = len + UNDO_ENTRY_HEADER_SIZE; - auto data = allocator.Allocate(needed_space); - Store(type, data); - data += sizeof(UndoFlags); - Store(len, data); - data += sizeof(uint32_t); - return data; -} - -template -void UndoBuffer::IterateEntries(UndoBuffer::IteratorState &state, T &&callback) { - // iterate in insertion order: start with the tail - state.current = allocator.GetTail(); - while (state.current) { - state.start = state.current->data.get(); - state.end = state.start + state.current->current_position; - while (state.start < state.end) { - UndoFlags type = Load(state.start); - state.start += sizeof(UndoFlags); - - uint32_t len = Load(state.start); - state.start += sizeof(uint32_t); - callback(type, state.start); - state.start += len; - } - state.current = state.current->prev; - } -} - -template -void UndoBuffer::IterateEntries(UndoBuffer::IteratorState &state, UndoBuffer::IteratorState &end_state, T &&callback) { - // iterate in insertion order: start with the tail - state.current = allocator.GetTail(); - while (state.current) { - state.start = state.current->data.get(); - state.end = - state.current == end_state.current ? 
end_state.start : state.start + state.current->current_position; - while (state.start < state.end) { - auto type = Load(state.start); - state.start += sizeof(UndoFlags); - auto len = Load(state.start); - state.start += sizeof(uint32_t); - callback(type, state.start); - state.start += len; - } - if (state.current == end_state.current) { - // finished executing until the current end state - return; - } - state.current = state.current->prev; - } -} - -template -void UndoBuffer::ReverseIterateEntries(T &&callback) { - // iterate in reverse insertion order: start with the head - auto current = allocator.GetHead(); - while (current) { - data_ptr_t start = current->data.get(); - data_ptr_t end = start + current->current_position; - // create a vector with all nodes in this chunk - vector> nodes; - while (start < end) { - auto type = Load(start); - start += sizeof(UndoFlags); - auto len = Load(start); - start += sizeof(uint32_t); - nodes.emplace_back(type, start); - start += len; - } - // iterate over it in reverse order - for (idx_t i = nodes.size(); i > 0; i--) { - callback(nodes[i - 1].first, nodes[i - 1].second); - } - current = current->next.get(); - } -} - -bool UndoBuffer::ChangesMade() { - return !allocator.IsEmpty(); -} - -idx_t UndoBuffer::EstimatedSize() { - idx_t estimated_size = 0; - auto node = allocator.GetHead(); - while (node) { - estimated_size += node->current_position; - node = node->next.get(); - } - return estimated_size; -} - -void UndoBuffer::Cleanup() { - // garbage collect everything in the Undo Chunk - // this should only happen if - // (1) the transaction this UndoBuffer belongs to has successfully - // committed - // (on Rollback the Rollback() function should be called, that clears - // the chunks) - // (2) there is no active transaction with start_id < commit_id of this - // transaction - CleanupState state; - UndoBuffer::IteratorState iterator_state; - IterateEntries(iterator_state, [&](UndoFlags type, data_ptr_t data) { 
state.CleanupEntry(type, data); }); - - // possibly vacuum indexes - for (const auto &table : state.indexed_tables) { - table.second->info->indexes.Scan([&](Index &index) { - index.Vacuum(); - return false; - }); - } -} - -void UndoBuffer::Commit(UndoBuffer::IteratorState &iterator_state, optional_ptr log, - transaction_t commit_id) { - CommitState state(commit_id, log); - if (log) { - // commit WITH write ahead log - IterateEntries(iterator_state, [&](UndoFlags type, data_ptr_t data) { state.CommitEntry(type, data); }); - } else { - // commit WITHOUT write ahead log - IterateEntries(iterator_state, [&](UndoFlags type, data_ptr_t data) { state.CommitEntry(type, data); }); - } -} - -void UndoBuffer::RevertCommit(UndoBuffer::IteratorState &end_state, transaction_t transaction_id) { - CommitState state(transaction_id, nullptr); - UndoBuffer::IteratorState start_state; - IterateEntries(start_state, end_state, [&](UndoFlags type, data_ptr_t data) { state.RevertCommit(type, data); }); -} - -void UndoBuffer::Rollback() noexcept { - // rollback needs to be performed in reverse - RollbackState state; - ReverseIterateEntries([&](UndoFlags type, data_ptr_t data) { state.RollbackEntry(type, data); }); -} -} // namespace duckdb - - -namespace duckdb { - -CopiedStatementVerifier::CopiedStatementVerifier(unique_ptr statement_p) - : StatementVerifier(VerificationType::COPIED, "Copied", std::move(statement_p)) { -} - -unique_ptr CopiedStatementVerifier::Create(const SQLStatement &statement) { - return make_uniq(statement.Copy()); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -DeserializedStatementVerifier::DeserializedStatementVerifier(unique_ptr statement_p) - : StatementVerifier(VerificationType::DESERIALIZED, "Deserialized", std::move(statement_p)) { -} - -unique_ptr DeserializedStatementVerifier::Create(const SQLStatement &statement) { - - auto &select_stmt = statement.Cast(); - - MemoryStream stream; - BinarySerializer::Serialize(select_stmt, stream); - 
stream.Rewind(); - auto result = BinaryDeserializer::Deserialize(stream); - - return make_uniq(std::move(result)); -} - -} // namespace duckdb - - -namespace duckdb { - -ExternalStatementVerifier::ExternalStatementVerifier(unique_ptr statement_p) - : StatementVerifier(VerificationType::EXTERNAL, "External", std::move(statement_p)) { -} - -unique_ptr ExternalStatementVerifier::Create(const SQLStatement &statement) { - return make_uniq(statement.Copy()); -} - -} // namespace duckdb - - -namespace duckdb { - -NoOperatorCachingVerifier::NoOperatorCachingVerifier(unique_ptr statement_p) - : StatementVerifier(VerificationType::NO_OPERATOR_CACHING, "No operator caching", std::move(statement_p)) { -} - -unique_ptr NoOperatorCachingVerifier::Create(const SQLStatement &statement_p) { - return make_uniq(statement_p.Copy()); -} - -} // namespace duckdb - - - - -namespace duckdb { - -ParsedStatementVerifier::ParsedStatementVerifier(unique_ptr statement_p) - : StatementVerifier(VerificationType::PARSED, "Parsed", std::move(statement_p)) { -} - -unique_ptr ParsedStatementVerifier::Create(const SQLStatement &statement) { - auto query_str = statement.ToString(); - Parser parser; - try { - parser.ParseQuery(query_str); - } catch (std::exception &ex) { - throw InternalException("Parsed statement verification failed. 
Query:\n%s\n\nError: %s", query_str, ex.what()); - } - D_ASSERT(parser.statements.size() == 1); - D_ASSERT(parser.statements[0]->type == StatementType::SELECT_STATEMENT); - return make_uniq(std::move(parser.statements[0])); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -PreparedStatementVerifier::PreparedStatementVerifier(unique_ptr statement_p) - : StatementVerifier(VerificationType::PREPARED, "Prepared", std::move(statement_p)) { -} - -unique_ptr PreparedStatementVerifier::Create(const SQLStatement &statement) { - return make_uniq(statement.Copy()); -} - -void PreparedStatementVerifier::Extract() { - auto &select = *statement; - // replace all the constants from the select statement and replace them with parameter expressions - ParsedExpressionIterator::EnumerateQueryNodeChildren( - *select.node, [&](unique_ptr &child) { ConvertConstants(child); }); - statement->n_param = values.size(); - for (auto &kv : values) { - statement->named_param_map[kv.first] = 0; - } - // create the PREPARE and EXECUTE statements - string name = "__duckdb_verification_prepared_statement"; - auto prepare = make_uniq(); - prepare->name = name; - prepare->statement = std::move(statement); - - auto execute = make_uniq(); - execute->name = name; - execute->named_values = std::move(values); - - auto dealloc = make_uniq(); - dealloc->info->type = CatalogType::PREPARED_STATEMENT; - dealloc->info->name = string(name); - - prepare_statement = std::move(prepare); - execute_statement = std::move(execute); - dealloc_statement = std::move(dealloc); -} - -void PreparedStatementVerifier::ConvertConstants(unique_ptr &child) { - if (child->type == ExpressionType::VALUE_CONSTANT) { - // constant: extract the constant value - auto alias = child->alias; - child->alias = string(); - // check if the value already exists - idx_t index = values.size(); - auto identifier = std::to_string(index + 1); - const auto predicate = [&](const std::pair> &pair) { - return 
pair.second->Equals(*child.get()); - }; - auto result = std::find_if(values.begin(), values.end(), predicate); - if (result == values.end()) { - // If it doesn't exist yet, add it - values[identifier] = std::move(child); - } else { - identifier = result->first; - } - - // replace it with an expression - auto parameter = make_uniq(); - parameter->identifier = identifier; - parameter->alias = alias; - child = std::move(parameter); - return; - } - ParsedExpressionIterator::EnumerateChildren(*child, - [&](unique_ptr &child) { ConvertConstants(child); }); -} - -bool PreparedStatementVerifier::Run( - ClientContext &context, const string &query, - const std::function(const string &, unique_ptr)> &run) { - bool failed = false; - // verify that we can extract all constants from the query and run the query as a prepared statement - // create the PREPARE and EXECUTE statements - Extract(); - // execute the prepared statements - try { - auto prepare_result = run(string(), std::move(prepare_statement)); - if (prepare_result->HasError()) { - prepare_result->ThrowError("Failed prepare during verify: "); - } - auto execute_result = run(string(), std::move(execute_statement)); - if (execute_result->HasError()) { - execute_result->ThrowError("Failed execute during verify: "); - } - materialized_result = unique_ptr_cast(std::move(execute_result)); - } catch (const Exception &ex) { - if (ex.type != ExceptionType::PARAMETER_NOT_ALLOWED) { - materialized_result = make_uniq(PreservedError(ex)); - } - failed = true; - } catch (std::exception &ex) { - materialized_result = make_uniq(PreservedError(ex)); - failed = true; - } - run(string(), std::move(dealloc_statement)); - context.interrupted = false; - - return failed; -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -StatementVerifier::StatementVerifier(VerificationType type, string name, unique_ptr statement_p) - : type(type), name(std::move(name)), - statement(unique_ptr_cast(std::move(statement_p))), - 
select_list(statement->node->GetSelectList()) { -} - -StatementVerifier::StatementVerifier(unique_ptr statement_p) - : StatementVerifier(VerificationType::ORIGINAL, "Original", std::move(statement_p)) { -} - -StatementVerifier::~StatementVerifier() noexcept { -} - -unique_ptr StatementVerifier::Create(VerificationType type, const SQLStatement &statement_p) { - switch (type) { - case VerificationType::COPIED: - return CopiedStatementVerifier::Create(statement_p); - case VerificationType::DESERIALIZED: - return DeserializedStatementVerifier::Create(statement_p); - case VerificationType::PARSED: - return ParsedStatementVerifier::Create(statement_p); - case VerificationType::UNOPTIMIZED: - return UnoptimizedStatementVerifier::Create(statement_p); - case VerificationType::NO_OPERATOR_CACHING: - return NoOperatorCachingVerifier::Create(statement_p); - case VerificationType::PREPARED: - return PreparedStatementVerifier::Create(statement_p); - case VerificationType::EXTERNAL: - return ExternalStatementVerifier::Create(statement_p); - case VerificationType::INVALID: - default: - throw InternalException("Invalid statement verification type!"); - } -} - -void StatementVerifier::CheckExpressions(const StatementVerifier &other) const { - // Only the original statement should check other statements - D_ASSERT(type == VerificationType::ORIGINAL); - - // Check equality - if (other.RequireEquality()) { - D_ASSERT(statement->Equals(*other.statement)); - } - -#ifdef DEBUG - // Now perform checking on the expressions - D_ASSERT(select_list.size() == other.select_list.size()); - const auto expr_count = select_list.size(); - if (other.RequireEquality()) { - for (idx_t i = 0; i < expr_count; i++) { - // Run the ToString, to verify that it doesn't crash - select_list[i]->ToString(); - - if (select_list[i]->HasSubquery()) { - continue; - } - - // Check that the expressions are equivalent - D_ASSERT(select_list[i]->Equals(*other.select_list[i])); - // Check that the hashes are equivalent 
too - D_ASSERT(select_list[i]->Hash() == other.select_list[i]->Hash()); - - other.select_list[i]->Verify(); - } - } -#endif -} - -void StatementVerifier::CheckExpressions() const { -#ifdef DEBUG - D_ASSERT(type == VerificationType::ORIGINAL); - // Perform additional checking within the expressions - const auto expr_count = select_list.size(); - for (idx_t outer_idx = 0; outer_idx < expr_count; outer_idx++) { - auto hash = select_list[outer_idx]->Hash(); - for (idx_t inner_idx = 0; inner_idx < expr_count; inner_idx++) { - auto hash2 = select_list[inner_idx]->Hash(); - if (hash != hash2) { - // if the hashes are not equivalent, the expressions should not be equivalent - D_ASSERT(!select_list[outer_idx]->Equals(*select_list[inner_idx])); - } - } - } -#endif -} - -bool StatementVerifier::Run( - ClientContext &context, const string &query, - const std::function(const string &, unique_ptr)> &run) { - bool failed = false; - - context.interrupted = false; - context.config.enable_optimizer = !DisableOptimizer(); - context.config.enable_caching_operators = !DisableOperatorCaching(); - context.config.force_external = ForceExternal(); - try { - auto result = run(query, std::move(statement)); - if (result->HasError()) { - failed = true; - } - materialized_result = unique_ptr_cast(std::move(result)); - } catch (const Exception &ex) { - failed = true; - materialized_result = make_uniq(PreservedError(ex)); - } catch (std::exception &ex) { - failed = true; - materialized_result = make_uniq(PreservedError(ex)); - } - context.interrupted = false; - - return failed; -} - -string StatementVerifier::CompareResults(const StatementVerifier &other) { - D_ASSERT(type == VerificationType::ORIGINAL); - string error; - if (materialized_result->HasError() != other.materialized_result->HasError()) { // LCOV_EXCL_START - string result = other.name + " statement differs from original result!\n"; - result += "Original Result:\n" + materialized_result->ToString(); - result += other.name + ":\n" + 
other.materialized_result->ToString(); - return result; - } // LCOV_EXCL_STOP - if (materialized_result->HasError()) { - return ""; - } - if (!ColumnDataCollection::ResultEquals(materialized_result->Collection(), other.materialized_result->Collection(), - error)) { // LCOV_EXCL_START - string result = other.name + " statement differs from original result!\n"; - result += "Original Result:\n" + materialized_result->ToString(); - result += other.name + ":\n" + other.materialized_result->ToString(); - result += "\n\n---------------------------------\n" + error; - return result; - } // LCOV_EXCL_STOP - - return ""; -} - -} // namespace duckdb - - -namespace duckdb { - -UnoptimizedStatementVerifier::UnoptimizedStatementVerifier(unique_ptr statement_p) - : StatementVerifier(VerificationType::UNOPTIMIZED, "Unoptimized", std::move(statement_p)) { -} - -unique_ptr UnoptimizedStatementVerifier::Create(const SQLStatement &statement_p) { - return make_uniq(statement_p.Copy()); -} - -} // namespace duckdb - -#endif diff --git a/lib/duckdb-2.cpp b/lib/duckdb-2.cpp deleted file mode 100644 index d61cfcd1..00000000 --- a/lib/duckdb-2.cpp +++ /dev/null @@ -1,20858 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - - - - - - - - - - - - - - - - - - - - - - - - -#include -#include -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Cast bool -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(bool input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - 
-template <> -bool TryCast::Operation(bool input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast int8_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(int8_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, 
int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast int16_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(int16_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, int16_t &result, bool strict) { - 
return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast int32_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(int32_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, int16_t &result, bool strict) { - return 
NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast int64_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(int64_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, int16_t &result, bool strict) { - return 
NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast hugeint_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(hugeint_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, int16_t &result, bool strict) { - return 
NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast uint8_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(uint8_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, int16_t &result, bool strict) { - return 
NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast uint16_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(uint16_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, int16_t &result, bool strict) { - return 
NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast uint32_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(uint32_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, int16_t &result, bool strict) { - return 
NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast uint64_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(uint64_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, int16_t &result, bool strict) { - return 
NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast float -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(float input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, int16_t &result, bool strict) { - return 
NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast double -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(double input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, 
strict); -} - -template <> -bool TryCast::Operation(double input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast String -> Numeric -//===--------------------------------------------------------------------===// -template -struct IntegerCastData { - using Result = T; - Result result; - bool seen_decimal; -}; - -struct IntegerCastOperation { - template - static bool HandleDigit(T &state, uint8_t digit) { - using result_t = typename T::Result; - if (NEGATIVE) { - if (state.result < (NumericLimits::Minimum() + digit) / 10) { - return false; - } - state.result = state.result * 10 - digit; - } else { - if (state.result > (NumericLimits::Maximum() - digit) / 10) { - return 
false; - } - state.result = state.result * 10 + digit; - } - return true; - } - - template - static bool HandleHexDigit(T &state, uint8_t digit) { - using result_t = typename T::Result; - if (state.result > (NumericLimits::Maximum() - digit) / 16) { - return false; - } - state.result = state.result * 16 + digit; - return true; - } - - template - static bool HandleBinaryDigit(T &state, uint8_t digit) { - using result_t = typename T::Result; - if (state.result > (NumericLimits::Maximum() - digit) / 2) { - return false; - } - state.result = state.result * 2 + digit; - return true; - } - - template - static bool HandleExponent(T &state, int32_t exponent) { - using result_t = typename T::Result; - double dbl_res = state.result * std::pow(10.0L, exponent); - if (dbl_res < (double)NumericLimits::Minimum() || - dbl_res > (double)NumericLimits::Maximum()) { - return false; - } - state.result = (result_t)std::nearbyint(dbl_res); - return true; - } - - template - static bool HandleDecimal(T &state, uint8_t digit) { - if (state.seen_decimal) { - return true; - } - state.seen_decimal = true; - // round the integer based on what is after the decimal point - // if digit >= 5, then we round up (or down in case of negative numbers) - auto increment = digit >= 5; - if (!increment) { - return true; - } - if (NEGATIVE) { - if (state.result == NumericLimits::Minimum()) { - return false; - } - state.result--; - } else { - if (state.result == NumericLimits::Maximum()) { - return false; - } - state.result++; - } - return true; - } - - template - static bool Finalize(T &state) { - return true; - } -}; - -template -static bool IntegerCastLoop(const char *buf, idx_t len, T &result, bool strict) { - idx_t start_pos; - if (NEGATIVE) { - start_pos = 1; - } else { - if (*buf == '+') { - if (strict) { - // leading plus is not allowed in strict mode - return false; - } - start_pos = 1; - } else { - start_pos = 0; - } - } - idx_t pos = start_pos; - while (pos < len) { - if 
(!StringUtil::CharacterIsDigit(buf[pos])) { - // not a digit! - if (buf[pos] == decimal_separator) { - if (strict) { - return false; - } - bool number_before_period = pos > start_pos; - // decimal point: we accept decimal values for integers as well - // we just truncate them - // make sure everything after the period is a number - pos++; - idx_t start_digit = pos; - while (pos < len) { - if (!StringUtil::CharacterIsDigit(buf[pos])) { - break; - } - if (!OP::template HandleDecimal(result, buf[pos] - '0')) { - return false; - } - pos++; - } - // make sure there is either (1) one number after the period, or (2) one number before the period - // i.e. we accept "1." and ".1" as valid numbers, but not "." - if (!(number_before_period || pos > start_digit)) { - return false; - } - if (pos >= len) { - break; - } - } - if (StringUtil::CharacterIsSpace(buf[pos])) { - // skip any trailing spaces - while (++pos < len) { - if (!StringUtil::CharacterIsSpace(buf[pos])) { - return false; - } - } - break; - } - if (ALLOW_EXPONENT) { - if (buf[pos] == 'e' || buf[pos] == 'E') { - if (pos == start_pos) { - return false; - } - pos++; - if (pos >= len) { - return false; - } - using ExponentData = IntegerCastData; - ExponentData exponent {0, false}; - int negative = buf[pos] == '-'; - if (negative) { - if (!IntegerCastLoop( - buf + pos, len - pos, exponent, strict)) { - return false; - } - } else { - if (!IntegerCastLoop( - buf + pos, len - pos, exponent, strict)) { - return false; - } - } - return OP::template HandleExponent(result, exponent.result); - } - } - return false; - } - uint8_t digit = buf[pos++] - '0'; - if (!OP::template HandleDigit(result, digit)) { - return false; - } - } - if (!OP::template Finalize(result)) { - return false; - } - return pos > start_pos; -} - -template -static bool IntegerHexCastLoop(const char *buf, idx_t len, T &result, bool strict) { - if (ALLOW_EXPONENT || NEGATIVE) { - return false; - } - idx_t start_pos = 1; - idx_t pos = start_pos; - char 
current_char; - while (pos < len) { - current_char = StringUtil::CharacterToLower(buf[pos]); - if (!StringUtil::CharacterIsHex(current_char)) { - return false; - } - uint8_t digit; - if (current_char >= 'a') { - digit = current_char - 'a' + 10; - } else { - digit = current_char - '0'; - } - pos++; - if (!OP::template HandleHexDigit(result, digit)) { - return false; - } - } - if (!OP::template Finalize(result)) { - return false; - } - return pos > start_pos; -} - -template -static bool IntegerBinaryCastLoop(const char *buf, idx_t len, T &result, bool strict) { - if (ALLOW_EXPONENT || NEGATIVE) { - return false; - } - idx_t start_pos = 1; - idx_t pos = start_pos; - uint8_t digit; - char current_char; - while (pos < len) { - current_char = buf[pos]; - if (current_char == '_' && pos > start_pos) { - // skip underscore, if it is not the first character - pos++; - if (pos == len) { - // we cant end on an underscore either - return false; - } - continue; - } else if (current_char == '0') { - digit = 0; - } else if (current_char == '1') { - digit = 1; - } else { - return false; - } - pos++; - if (!OP::template HandleBinaryDigit(result, digit)) { - return false; - } - } - if (!OP::template Finalize(result)) { - return false; - } - return pos > start_pos; -} - -template -static bool TryIntegerCast(const char *buf, idx_t len, T &result, bool strict) { - // skip any spaces at the start - while (len > 0 && StringUtil::CharacterIsSpace(*buf)) { - buf++; - len--; - } - if (len == 0) { - return false; - } - if (ZERO_INITIALIZE) { - memset(&result, 0, sizeof(T)); - } - // if the number is negative, we set the negative flag and skip the negative sign - if (*buf == '-') { - if (!IS_SIGNED) { - // Need to check if its not -0 - idx_t pos = 1; - while (pos < len) { - if (buf[pos++] != '0') { - return false; - } - } - } - return IntegerCastLoop(buf, len, result, strict); - } - if (len > 1 && *buf == '0') { - if (buf[1] == 'x' || buf[1] == 'X') { - // If it starts with 0x or 0X, we parse 
it as a hex value - buf++; - len--; - return IntegerHexCastLoop(buf, len, result, strict); - } else if (buf[1] == 'b' || buf[1] == 'B') { - // If it starts with 0b or 0B, we parse it as a binary value - buf++; - len--; - return IntegerBinaryCastLoop(buf, len, result, strict); - } else if (strict && StringUtil::CharacterIsDigit(buf[1])) { - // leading zeros are not allowed in strict mode - return false; - } - } - return IntegerCastLoop(buf, len, result, strict); -} - -template -static inline bool TrySimpleIntegerCast(const char *buf, idx_t len, T &result, bool strict) { - IntegerCastData data; - if (TryIntegerCast, IS_SIGNED>(buf, len, data, strict)) { - result = data.result; - return true; - } - return false; -} - -template <> -bool TryCast::Operation(string_t input, bool &result, bool strict) { - auto input_data = input.GetData(); - auto input_size = input.GetSize(); - - switch (input_size) { - case 1: { - char c = std::tolower(*input_data); - if (c == 't' || (!strict && c == '1')) { - result = true; - return true; - } else if (c == 'f' || (!strict && c == '0')) { - result = false; - return true; - } - return false; - } - case 4: { - char t = std::tolower(input_data[0]); - char r = std::tolower(input_data[1]); - char u = std::tolower(input_data[2]); - char e = std::tolower(input_data[3]); - if (t == 't' && r == 'r' && u == 'u' && e == 'e') { - result = true; - return true; - } - return false; - } - case 5: { - char f = std::tolower(input_data[0]); - char a = std::tolower(input_data[1]); - char l = std::tolower(input_data[2]); - char s = std::tolower(input_data[3]); - char e = std::tolower(input_data[4]); - if (f == 'f' && a == 'a' && l == 'l' && s == 's' && e == 'e') { - result = false; - return true; - } - return false; - } - default: - return false; - } -} -template <> -bool TryCast::Operation(string_t input, int8_t &result, bool strict) { - return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); -} -template <> -bool 
TryCast::Operation(string_t input, int16_t &result, bool strict) { - return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); -} -template <> -bool TryCast::Operation(string_t input, int32_t &result, bool strict) { - return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); -} -template <> -bool TryCast::Operation(string_t input, int64_t &result, bool strict) { - return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); -} - -template <> -bool TryCast::Operation(string_t input, uint8_t &result, bool strict) { - return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); -} -template <> -bool TryCast::Operation(string_t input, uint16_t &result, bool strict) { - return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); -} -template <> -bool TryCast::Operation(string_t input, uint32_t &result, bool strict) { - return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); -} -template <> -bool TryCast::Operation(string_t input, uint64_t &result, bool strict) { - return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); -} - -template -static bool TryDoubleCast(const char *buf, idx_t len, T &result, bool strict) { - // skip any spaces at the start - while (len > 0 && StringUtil::CharacterIsSpace(*buf)) { - buf++; - len--; - } - if (len == 0) { - return false; - } - if (*buf == '+') { - if (strict) { - // plus is not allowed in strict mode - return false; - } - buf++; - len--; - } - if (strict && len >= 2) { - if (buf[0] == '0' && StringUtil::CharacterIsDigit(buf[1])) { - // leading zeros are not allowed in strict mode - return false; - } - } - auto endptr = buf + len; - auto parse_result = duckdb_fast_float::from_chars(buf, buf + len, result, decimal_separator); - if (parse_result.ec != std::errc()) { - return false; - } - auto current_end = parse_result.ptr; - if (!strict) { - while (current_end < endptr && 
StringUtil::CharacterIsSpace(*current_end)) { - current_end++; - } - } - return current_end == endptr; -} - -template <> -bool TryCast::Operation(string_t input, float &result, bool strict) { - return TryDoubleCast(input.GetData(), input.GetSize(), result, strict); -} - -template <> -bool TryCast::Operation(string_t input, double &result, bool strict) { - return TryDoubleCast(input.GetData(), input.GetSize(), result, strict); -} - -template <> -bool TryCastErrorMessageCommaSeparated::Operation(string_t input, float &result, string *error_message, bool strict) { - if (!TryDoubleCast(input.GetData(), input.GetSize(), result, strict)) { - HandleCastError::AssignError(StringUtil::Format("Could not cast string to float: \"%s\"", input.GetString()), - error_message); - return false; - } - return true; -} - -template <> -bool TryCastErrorMessageCommaSeparated::Operation(string_t input, double &result, string *error_message, bool strict) { - if (!TryDoubleCast(input.GetData(), input.GetSize(), result, strict)) { - HandleCastError::AssignError(StringUtil::Format("Could not cast string to double: \"%s\"", input.GetString()), - error_message); - return false; - } - return true; -} - -//===--------------------------------------------------------------------===// -// Cast From Date -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(date_t input, date_t &result, bool strict) { - result = input; - return true; -} - -template <> -bool TryCast::Operation(date_t input, timestamp_t &result, bool strict) { - if (input == date_t::infinity()) { - result = timestamp_t::infinity(); - return true; - } else if (input == date_t::ninfinity()) { - result = timestamp_t::ninfinity(); - return true; - } - return Timestamp::TryFromDatetime(input, Time::FromTime(0, 0, 0), result); -} - -//===--------------------------------------------------------------------===// -// Cast From Time 
-//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(dtime_t input, dtime_t &result, bool strict) { - result = input; - return true; -} - -template <> -bool TryCast::Operation(dtime_t input, dtime_tz_t &result, bool strict) { - result = dtime_tz_t(input, 0); - return true; -} - -//===--------------------------------------------------------------------===// -// Cast From Time With Time Zone (Offset) -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(dtime_tz_t input, dtime_tz_t &result, bool strict) { - result = input; - return true; -} - -template <> -bool TryCast::Operation(dtime_tz_t input, dtime_t &result, bool strict) { - result = input.time(); - return true; -} - -//===--------------------------------------------------------------------===// -// Cast From Timestamps -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(timestamp_t input, date_t &result, bool strict) { - result = Timestamp::GetDate(input); - return true; -} - -template <> -bool TryCast::Operation(timestamp_t input, dtime_t &result, bool strict) { - if (!Timestamp::IsFinite(input)) { - return false; - } - result = Timestamp::GetTime(input); - return true; -} - -template <> -bool TryCast::Operation(timestamp_t input, timestamp_t &result, bool strict) { - result = input; - return true; -} - -template <> -bool TryCast::Operation(timestamp_t input, dtime_tz_t &result, bool strict) { - if (!Timestamp::IsFinite(input)) { - return false; - } - result = dtime_tz_t(Timestamp::GetTime(input), 0); - return true; -} - -//===--------------------------------------------------------------------===// -// Cast from Interval -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(interval_t input, interval_t &result, bool strict) { - result = input; - return 
true; -} - -//===--------------------------------------------------------------------===// -// Non-Standard Timestamps -//===--------------------------------------------------------------------===// -template <> -duckdb::string_t CastFromTimestampNS::Operation(duckdb::timestamp_t input, Vector &result) { - return StringCast::Operation(Timestamp::FromEpochNanoSeconds(input.value), result); -} -template <> -duckdb::string_t CastFromTimestampMS::Operation(duckdb::timestamp_t input, Vector &result) { - return StringCast::Operation(Timestamp::FromEpochMs(input.value), result); -} -template <> -duckdb::string_t CastFromTimestampSec::Operation(duckdb::timestamp_t input, Vector &result) { - return StringCast::Operation(Timestamp::FromEpochSeconds(input.value), result); -} - -template <> -timestamp_t CastTimestampUsToMs::Operation(timestamp_t input) { - timestamp_t cast_timestamp(Timestamp::GetEpochMs(input)); - return cast_timestamp; -} - -template <> -timestamp_t CastTimestampUsToNs::Operation(timestamp_t input) { - timestamp_t cast_timestamp(Timestamp::GetEpochNanoSeconds(input)); - return cast_timestamp; -} - -template <> -timestamp_t CastTimestampUsToSec::Operation(timestamp_t input) { - timestamp_t cast_timestamp(Timestamp::GetEpochSeconds(input)); - return cast_timestamp; -} -template <> -timestamp_t CastTimestampMsToUs::Operation(timestamp_t input) { - return Timestamp::FromEpochMs(input.value); -} - -template <> -timestamp_t CastTimestampMsToNs::Operation(timestamp_t input) { - auto us = CastTimestampMsToUs::Operation(input); - return CastTimestampUsToNs::Operation(us); -} - -template <> -timestamp_t CastTimestampNsToUs::Operation(timestamp_t input) { - return Timestamp::FromEpochNanoSeconds(input.value); -} - -template <> -timestamp_t CastTimestampSecToUs::Operation(timestamp_t input) { - return Timestamp::FromEpochSeconds(input.value); -} - -template <> -timestamp_t CastTimestampSecToMs::Operation(timestamp_t input) { - auto us = 
CastTimestampSecToUs::Operation(input); - return CastTimestampUsToMs::Operation(us); -} - -template <> -timestamp_t CastTimestampSecToNs::Operation(timestamp_t input) { - auto us = CastTimestampSecToUs::Operation(input); - return CastTimestampUsToNs::Operation(us); -} - -//===--------------------------------------------------------------------===// -// Cast To Timestamp -//===--------------------------------------------------------------------===// -template <> -bool TryCastToTimestampNS::Operation(string_t input, timestamp_t &result, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - return false; - } - result = Timestamp::GetEpochNanoSeconds(result); - return true; -} - -template <> -bool TryCastToTimestampMS::Operation(string_t input, timestamp_t &result, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - return false; - } - result = Timestamp::GetEpochMs(result); - return true; -} - -template <> -bool TryCastToTimestampSec::Operation(string_t input, timestamp_t &result, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - return false; - } - result = Timestamp::GetEpochSeconds(result); - return true; -} - -template <> -bool TryCastToTimestampNS::Operation(date_t input, timestamp_t &result, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - return false; - } - if (!TryMultiplyOperator::Operation(result.value, Interval::NANOS_PER_MICRO, result.value)) { - return false; - } - return true; -} - -template <> -bool TryCastToTimestampMS::Operation(date_t input, timestamp_t &result, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - return false; - } - result.value /= Interval::MICROS_PER_MSEC; - return true; -} - -template <> -bool TryCastToTimestampSec::Operation(date_t input, timestamp_t &result, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - return false; - } - result.value /= Interval::MICROS_PER_MSEC * Interval::MSECS_PER_SEC; - return true; -} - 
-//===--------------------------------------------------------------------===// -// Cast From Blob -//===--------------------------------------------------------------------===// -template <> -string_t CastFromBlob::Operation(string_t input, Vector &vector) { - idx_t result_size = Blob::GetStringSize(input); - - string_t result = StringVector::EmptyString(vector, result_size); - Blob::ToString(input, result.GetDataWriteable()); - result.Finalize(); - - return result; -} - -template <> -string_t CastFromBlobToBit::Operation(string_t input, Vector &vector) { - idx_t result_size = input.GetSize() + 1; - if (result_size <= 1) { - throw ConversionException("Cannot cast empty BLOB to BIT"); - } - return StringVector::AddStringOrBlob(vector, Bit::BlobToBit(input)); -} - -//===--------------------------------------------------------------------===// -// Cast From Bit -//===--------------------------------------------------------------------===// -template <> -string_t CastFromBitToString::Operation(string_t input, Vector &vector) { - - idx_t result_size = Bit::BitLength(input); - string_t result = StringVector::EmptyString(vector, result_size); - Bit::ToString(input, result.GetDataWriteable()); - result.Finalize(); - - return result; -} - -//===--------------------------------------------------------------------===// -// Cast From Pointer -//===--------------------------------------------------------------------===// -template <> -string_t CastFromPointer::Operation(uintptr_t input, Vector &vector) { - std::string s = duckdb_fmt::format("0x{:x}", input); - return StringVector::AddString(vector, s); -} - -//===--------------------------------------------------------------------===// -// Cast To Blob -//===--------------------------------------------------------------------===// -template <> -bool TryCastToBlob::Operation(string_t input, string_t &result, Vector &result_vector, string *error_message, - bool strict) { - idx_t result_size; - if (!Blob::TryGetBlobSize(input, 
result_size, error_message)) { - return false; - } - - result = StringVector::EmptyString(result_vector, result_size); - Blob::ToBlob(input, data_ptr_cast(result.GetDataWriteable())); - result.Finalize(); - return true; -} - -//===--------------------------------------------------------------------===// -// Cast To Bit -//===--------------------------------------------------------------------===// -template <> -bool TryCastToBit::Operation(string_t input, string_t &result, Vector &result_vector, string *error_message, - bool strict) { - idx_t result_size; - if (!Bit::TryGetBitStringSize(input, result_size, error_message)) { - return false; - } - - result = StringVector::EmptyString(result_vector, result_size); - Bit::ToBit(input, result); - result.Finalize(); - return true; -} - -template <> -bool CastFromBitToNumeric::Operation(string_t input, bool &result, bool strict) { - D_ASSERT(input.GetSize() > 1); - - uint8_t value; - bool success = CastFromBitToNumeric::Operation(input, value, strict); - result = (value > 0); - return (success); -} - -template <> -bool CastFromBitToNumeric::Operation(string_t input, hugeint_t &result, bool strict) { - D_ASSERT(input.GetSize() > 1); - - if (input.GetSize() - 1 > sizeof(hugeint_t)) { - throw ConversionException("Bitstring doesn't fit inside of %s", GetTypeId()); - } - Bit::BitToNumeric(input, result); - if (result < NumericLimits::Minimum()) { - throw ConversionException("Minimum limit for HUGEINT is %s", NumericLimits::Minimum().ToString()); - } - return (true); -} - -//===--------------------------------------------------------------------===// -// Cast From UUID -//===--------------------------------------------------------------------===// -template <> -string_t CastFromUUID::Operation(hugeint_t input, Vector &vector) { - string_t result = StringVector::EmptyString(vector, 36); - UUID::ToString(input, result.GetDataWriteable()); - result.Finalize(); - return result; -} - 
-//===--------------------------------------------------------------------===// -// Cast To UUID -//===--------------------------------------------------------------------===// -template <> -bool TryCastToUUID::Operation(string_t input, hugeint_t &result, Vector &result_vector, string *error_message, - bool strict) { - return UUID::FromString(input.GetString(), result); -} - -//===--------------------------------------------------------------------===// -// Cast To Date -//===--------------------------------------------------------------------===// -template <> -bool TryCastErrorMessage::Operation(string_t input, date_t &result, string *error_message, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - HandleCastError::AssignError(Date::ConversionError(input), error_message); - return false; - } - return true; -} - -template <> -bool TryCast::Operation(string_t input, date_t &result, bool strict) { - idx_t pos; - bool special = false; - return Date::TryConvertDate(input.GetData(), input.GetSize(), pos, result, special, strict); -} - -template <> -date_t Cast::Operation(string_t input) { - return Date::FromCString(input.GetData(), input.GetSize()); -} - -//===--------------------------------------------------------------------===// -// Cast To Time -//===--------------------------------------------------------------------===// -template <> -bool TryCastErrorMessage::Operation(string_t input, dtime_t &result, string *error_message, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - HandleCastError::AssignError(Time::ConversionError(input), error_message); - return false; - } - return true; -} - -template <> -bool TryCast::Operation(string_t input, dtime_t &result, bool strict) { - idx_t pos; - return Time::TryConvertTime(input.GetData(), input.GetSize(), pos, result, strict); -} - -template <> -dtime_t Cast::Operation(string_t input) { - return Time::FromCString(input.GetData(), input.GetSize()); -} - 
-//===--------------------------------------------------------------------===// -// Cast To TimeTZ -//===--------------------------------------------------------------------===// -template <> -bool TryCastErrorMessage::Operation(string_t input, dtime_tz_t &result, string *error_message, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - HandleCastError::AssignError(Time::ConversionError(input), error_message); - return false; - } - return true; -} - -template <> -bool TryCast::Operation(string_t input, dtime_tz_t &result, bool strict) { - idx_t pos; - return Time::TryConvertTimeTZ(input.GetData(), input.GetSize(), pos, result, strict); -} - -template <> -dtime_tz_t Cast::Operation(string_t input) { - dtime_tz_t result; - if (!TryCast::Operation(input, result, false)) { - throw ConversionException(Time::ConversionError(input)); - } - return result; -} - -//===--------------------------------------------------------------------===// -// Cast To Timestamp -//===--------------------------------------------------------------------===// -template <> -bool TryCastErrorMessage::Operation(string_t input, timestamp_t &result, string *error_message, bool strict) { - auto cast_result = Timestamp::TryConvertTimestamp(input.GetData(), input.GetSize(), result); - if (cast_result == TimestampCastResult::SUCCESS) { - return true; - } - if (cast_result == TimestampCastResult::ERROR_INCORRECT_FORMAT) { - HandleCastError::AssignError(Timestamp::ConversionError(input), error_message); - } else { - HandleCastError::AssignError(Timestamp::UnsupportedTimezoneError(input), error_message); - } - return false; -} - -template <> -bool TryCast::Operation(string_t input, timestamp_t &result, bool strict) { - return Timestamp::TryConvertTimestamp(input.GetData(), input.GetSize(), result) == TimestampCastResult::SUCCESS; -} - -template <> -timestamp_t Cast::Operation(string_t input) { - return Timestamp::FromCString(input.GetData(), input.GetSize()); -} - 
-//===--------------------------------------------------------------------===// -// Cast From Interval -//===--------------------------------------------------------------------===// -template <> -bool TryCastErrorMessage::Operation(string_t input, interval_t &result, string *error_message, bool strict) { - return Interval::FromCString(input.GetData(), input.GetSize(), result, error_message, strict); -} - -//===--------------------------------------------------------------------===// -// Cast From Hugeint -//===--------------------------------------------------------------------===// -// parsing hugeint from string is done a bit differently for performance reasons -// for other integer types we keep track of a single value -// and multiply that value by 10 for every digit we read -// however, for hugeints, multiplication is very expensive (>20X as expensive as for int64) -// for that reason, we parse numbers first into an int64 value -// when that value is full, we perform a HUGEINT multiplication to flush it into the hugeint -// this takes the number of HUGEINT multiplications down from [0-38] to [0-2] -struct HugeIntCastData { - hugeint_t hugeint; - int64_t intermediate; - uint8_t digits; - bool decimal; - - bool Flush() { - if (digits == 0 && intermediate == 0) { - return true; - } - if (hugeint.lower != 0 || hugeint.upper != 0) { - if (digits > 38) { - return false; - } - if (!Hugeint::TryMultiply(hugeint, Hugeint::POWERS_OF_TEN[digits], hugeint)) { - return false; - } - } - if (!Hugeint::AddInPlace(hugeint, hugeint_t(intermediate))) { - return false; - } - digits = 0; - intermediate = 0; - return true; - } -}; - -struct HugeIntegerCastOperation { - template - static bool HandleDigit(T &result, uint8_t digit) { - if (NEGATIVE) { - if (result.intermediate < (NumericLimits::Minimum() + digit) / 10) { - // intermediate is full: need to flush it - if (!result.Flush()) { - return false; - } - } - result.intermediate = result.intermediate * 10 - digit; - } else { - 
if (result.intermediate > (NumericLimits::Maximum() - digit) / 10) { - if (!result.Flush()) { - return false; - } - } - result.intermediate = result.intermediate * 10 + digit; - } - result.digits++; - return true; - } - - template - static bool HandleHexDigit(T &result, uint8_t digit) { - return false; - } - - template - static bool HandleBinaryDigit(T &result, uint8_t digit) { - if (result.intermediate > (NumericLimits::Maximum() - digit) / 2) { - // intermediate is full: need to flush it - if (!result.Flush()) { - return false; - } - } - result.intermediate = result.intermediate * 2 + digit; - result.digits++; - return true; - } - - template - static bool HandleExponent(T &result, int32_t exponent) { - if (!result.Flush()) { - return false; - } - if (exponent < -38 || exponent > 38) { - // out of range for exact exponent: use double and convert - double dbl_res = Hugeint::Cast(result.hugeint) * std::pow(10.0L, exponent); - if (dbl_res < Hugeint::Cast(NumericLimits::Minimum()) || - dbl_res > Hugeint::Cast(NumericLimits::Maximum())) { - return false; - } - result.hugeint = Hugeint::Convert(dbl_res); - return true; - } - if (exponent < 0) { - // negative exponent: divide by power of 10 - result.hugeint = Hugeint::Divide(result.hugeint, Hugeint::POWERS_OF_TEN[-exponent]); - return true; - } else { - // positive exponent: multiply by power of 10 - return Hugeint::TryMultiply(result.hugeint, Hugeint::POWERS_OF_TEN[exponent], result.hugeint); - } - } - - template - static bool HandleDecimal(T &result, uint8_t digit) { - // Integer casts round - if (!result.decimal) { - if (!result.Flush()) { - return false; - } - if (NEGATIVE) { - result.intermediate = -(digit >= 5); - } else { - result.intermediate = (digit >= 5); - } - } - result.decimal = true; - - return true; - } - - template - static bool Finalize(T &result) { - return result.Flush(); - } -}; - -template <> -bool TryCast::Operation(string_t input, hugeint_t &result, bool strict) { - HugeIntCastData data; - if 
(!TryIntegerCast(input.GetData(), input.GetSize(), data, - strict)) { - return false; - } - result = data.hugeint; - return true; -} - -//===--------------------------------------------------------------------===// -// Decimal String Cast -//===--------------------------------------------------------------------===// - -template -struct DecimalCastData { - typedef TYPE type_t; - TYPE result; - uint8_t width; - uint8_t scale; - uint8_t digit_count; - uint8_t decimal_count; - //! Whether we have determined if the result should be rounded - bool round_set; - //! If the result should be rounded - bool should_round; - //! Only set when ALLOW_EXPONENT is enabled - enum class ExponentType : uint8_t { NONE, POSITIVE, NEGATIVE }; - uint8_t excessive_decimals; - ExponentType exponent_type; -}; - -struct DecimalCastOperation { - template - static bool HandleDigit(T &state, uint8_t digit) { - if (state.result == 0 && digit == 0) { - // leading zero's don't count towards the digit count - return true; - } - if (state.digit_count == state.width - state.scale) { - // width of decimal type is exceeded! - return false; - } - state.digit_count++; - if (NEGATIVE) { - if (state.result < (NumericLimits::Minimum() / 10)) { - return false; - } - state.result = state.result * 10 - digit; - } else { - if (state.result > (NumericLimits::Maximum() / 10)) { - return false; - } - state.result = state.result * 10 + digit; - } - return true; - } - - template - static bool HandleHexDigit(T &state, uint8_t digit) { - return false; - } - - template - static bool HandleBinaryDigit(T &state, uint8_t digit) { - return false; - } - - template - static void RoundUpResult(T &state) { - if (NEGATIVE) { - state.result -= 1; - } else { - state.result += 1; - } - } - - template - static bool HandleExponent(T &state, int32_t exponent) { - auto decimal_excess = (state.decimal_count > state.scale) ? 
state.decimal_count - state.scale : 0; - if (exponent > 0) { - state.exponent_type = T::ExponentType::POSITIVE; - // Positive exponents need up to 'exponent' amount of digits - // Everything beyond that amount needs to be truncated - if (decimal_excess > exponent) { - // We've allowed too many decimals - state.excessive_decimals = decimal_excess - exponent; - exponent = 0; - } else { - exponent -= decimal_excess; - } - D_ASSERT(exponent >= 0); - } else if (exponent < 0) { - state.exponent_type = T::ExponentType::NEGATIVE; - } - if (!Finalize(state)) { - return false; - } - if (exponent < 0) { - bool round_up = false; - for (idx_t i = 0; i < idx_t(-int64_t(exponent)); i++) { - auto mod = state.result % 10; - round_up = NEGATIVE ? mod <= -5 : mod >= 5; - state.result /= 10; - if (state.result == 0) { - break; - } - } - if (round_up) { - RoundUpResult(state); - } - return true; - } else { - // positive exponent: append 0's - for (idx_t i = 0; i < idx_t(exponent); i++) { - if (!HandleDigit(state, 0)) { - return false; - } - } - return true; - } - } - - template - static bool HandleDecimal(T &state, uint8_t digit) { - if (state.decimal_count == state.scale && !state.round_set) { - // Determine whether the last registered decimal should be rounded or not - state.round_set = true; - state.should_round = digit >= 5; - } - if (!ALLOW_EXPONENT && state.decimal_count == state.scale) { - // we exceeded the amount of supported decimals - // however, we don't throw an error here - // we just truncate the decimal - return true; - } - //! If we expect an exponent, we need to preserve the decimals - //! 
But we don't want to overflow, so we prevent overflowing the result with this check - if (state.digit_count + state.decimal_count >= DecimalWidth::max) { - return true; - } - state.decimal_count++; - if (NEGATIVE) { - state.result = state.result * 10 - digit; - } else { - state.result = state.result * 10 + digit; - } - return true; - } - - template - static bool TruncateExcessiveDecimals(T &state) { - D_ASSERT(state.excessive_decimals); - bool round_up = false; - for (idx_t i = 0; i < state.excessive_decimals; i++) { - auto mod = state.result % 10; - round_up = NEGATIVE ? mod <= -5 : mod >= 5; - state.result /= 10.0; - } - //! Only round up when exponents are involved - if (state.exponent_type == T::ExponentType::POSITIVE && round_up) { - RoundUpResult(state); - } - D_ASSERT(state.decimal_count > state.scale); - state.decimal_count = state.scale; - return true; - } - - template - static bool Finalize(T &state) { - if (state.exponent_type != T::ExponentType::POSITIVE && state.decimal_count > state.scale) { - //! Did not encounter an exponent, but ALLOW_EXPONENT was on - state.excessive_decimals = state.decimal_count - state.scale; - } - if (state.excessive_decimals && !TruncateExcessiveDecimals(state)) { - return false; - } - if (state.exponent_type == T::ExponentType::NONE && state.round_set && state.should_round) { - RoundUpResult(state); - } - // if we have not gotten exactly "scale" decimals, we need to multiply the result - // e.g. 
if we have a string "1.0" that is cast to a DECIMAL(9,3), the value needs to be 1000 - // but we have only gotten the value "10" so far, so we multiply by 1000 - for (uint8_t i = state.decimal_count; i < state.scale; i++) { - state.result *= 10; - } - return true; - } -}; - -template -bool TryDecimalStringCast(string_t input, T &result, string *error_message, uint8_t width, uint8_t scale) { - DecimalCastData state; - state.result = 0; - state.width = width; - state.scale = scale; - state.digit_count = 0; - state.decimal_count = 0; - state.excessive_decimals = 0; - state.exponent_type = DecimalCastData::ExponentType::NONE; - state.round_set = false; - state.should_round = false; - if (!TryIntegerCast, true, true, DecimalCastOperation, false, decimal_separator>( - input.GetData(), input.GetSize(), state, false)) { - string error = StringUtil::Format("Could not convert string \"%s\" to DECIMAL(%d,%d)", input.GetString(), - (int)width, (int)scale); - HandleCastError::AssignError(error, error_message); - return false; - } - result = state.result; - return true; -} - -template <> -bool TryCastToDecimal::Operation(string_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { - return TryDecimalStringCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(string_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { - return TryDecimalStringCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(string_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { - return TryDecimalStringCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(string_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryDecimalStringCast(input, result, error_message, width, scale); -} - -template <> -bool 
TryCastToDecimalCommaSeparated::Operation(string_t input, int16_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryDecimalStringCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimalCommaSeparated::Operation(string_t input, int32_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryDecimalStringCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimalCommaSeparated::Operation(string_t input, int64_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryDecimalStringCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimalCommaSeparated::Operation(string_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryDecimalStringCast(input, result, error_message, width, scale); -} - -template <> -string_t StringCastFromDecimal::Operation(int16_t input, uint8_t width, uint8_t scale, Vector &result) { - return DecimalToString::Format(input, width, scale, result); -} - -template <> -string_t StringCastFromDecimal::Operation(int32_t input, uint8_t width, uint8_t scale, Vector &result) { - return DecimalToString::Format(input, width, scale, result); -} - -template <> -string_t StringCastFromDecimal::Operation(int64_t input, uint8_t width, uint8_t scale, Vector &result) { - return DecimalToString::Format(input, width, scale, result); -} - -template <> -string_t StringCastFromDecimal::Operation(hugeint_t input, uint8_t width, uint8_t scale, Vector &result) { - return HugeintToStringCast::FormatDecimal(input, width, scale, result); -} - -//===--------------------------------------------------------------------===// -// Decimal Casts -//===--------------------------------------------------------------------===// -// Decimal <-> Bool -//===--------------------------------------------------------------------===// -template -bool TryCastBoolToDecimal(bool input, T 
&result, string *error_message, uint8_t width, uint8_t scale) { - if (width > scale) { - result = input ? OP::POWERS_OF_TEN[scale] : 0; - return true; - } else { - return TryCast::Operation(input, result); - } -} - -template <> -bool TryCastToDecimal::Operation(bool input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastBoolToDecimal(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(bool input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastBoolToDecimal(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(bool input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastBoolToDecimal(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(bool input, hugeint_t &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastBoolToDecimal(input, result, error_message, width, scale); -} - -template <> -bool TryCastFromDecimal::Operation(int16_t input, bool &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCast::Operation(input, result); -} - -template <> -bool TryCastFromDecimal::Operation(int32_t input, bool &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCast::Operation(input, result); -} - -template <> -bool TryCastFromDecimal::Operation(int64_t input, bool &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCast::Operation(input, result); -} - -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, bool &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCast::Operation(input, result); -} - -//===--------------------------------------------------------------------===// -// Numeric -> Decimal Cast -//===--------------------------------------------------------------------===// -struct 
SignedToDecimalOperator { - template - static bool Operation(SRC input, DST max_width) { - return int64_t(input) >= int64_t(max_width) || int64_t(input) <= int64_t(-max_width); - } -}; - -struct UnsignedToDecimalOperator { - template - static bool Operation(SRC input, DST max_width) { - return uint64_t(input) >= uint64_t(max_width); - } -}; - -template -bool StandardNumericToDecimalCast(SRC input, DST &result, string *error_message, uint8_t width, uint8_t scale) { - // check for overflow - DST max_width = NumericHelper::POWERS_OF_TEN[width - scale]; - if (OP::template Operation(input, max_width)) { - string error = StringUtil::Format("Could not cast value %d to DECIMAL(%d,%d)", input, width, scale); - HandleCastError::AssignError(error, error_message); - return false; - } - result = DST(input) * NumericHelper::POWERS_OF_TEN[scale]; - return true; -} - -template -bool NumericToHugeDecimalCast(SRC input, hugeint_t &result, string *error_message, uint8_t width, uint8_t scale) { - // check for overflow - hugeint_t max_width = Hugeint::POWERS_OF_TEN[width - scale]; - hugeint_t hinput = Hugeint::Convert(input); - if (hinput >= max_width || hinput <= -max_width) { - string error = StringUtil::Format("Could not cast value %s to DECIMAL(%d,%d)", hinput.ToString(), width, scale); - HandleCastError::AssignError(error, error_message); - return false; - } - result = hinput * Hugeint::POWERS_OF_TEN[scale]; - return true; -} - -//===--------------------------------------------------------------------===// -// Cast int8_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(int8_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int8_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { - return 
StandardNumericToDecimalCast(input, result, error_message, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int8_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int8_t input, hugeint_t &result, string *error_message, uint8_t width, uint8_t scale) { - return NumericToHugeDecimalCast(input, result, error_message, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast int16_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(int16_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int16_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int16_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int16_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, error_message, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast int32_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(int32_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, width, scale); -} 
-template <> -bool TryCastToDecimal::Operation(int32_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int32_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int32_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, error_message, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast int64_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(int64_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int64_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int64_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int64_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, error_message, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast uint8_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(uint8_t input, int16_t &result, 
string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, - width, scale); -} -template <> -bool TryCastToDecimal::Operation(uint8_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, - width, scale); -} -template <> -bool TryCastToDecimal::Operation(uint8_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, - width, scale); -} -template <> -bool TryCastToDecimal::Operation(uint8_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, error_message, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast uint16_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(uint16_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, - width, scale); -} -template <> -bool TryCastToDecimal::Operation(uint16_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, - width, scale); -} -template <> -bool TryCastToDecimal::Operation(uint16_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, - width, scale); -} -template <> -bool TryCastToDecimal::Operation(uint16_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, error_message, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast uint32_t -> Decimal 
-//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(uint32_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, - width, scale); -} -template <> -bool TryCastToDecimal::Operation(uint32_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, - width, scale); -} -template <> -bool TryCastToDecimal::Operation(uint32_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, - width, scale); -} -template <> -bool TryCastToDecimal::Operation(uint32_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, error_message, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast uint64_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(uint64_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, - width, scale); -} -template <> -bool TryCastToDecimal::Operation(uint64_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, - width, scale); -} -template <> -bool TryCastToDecimal::Operation(uint64_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { - return StandardNumericToDecimalCast(input, result, error_message, - width, scale); -} -template <> -bool TryCastToDecimal::Operation(uint64_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return 
NumericToHugeDecimalCast(input, result, error_message, width, scale); -} - -//===--------------------------------------------------------------------===// -// Hugeint -> Decimal Cast -//===--------------------------------------------------------------------===// -template -bool HugeintToDecimalCast(hugeint_t input, DST &result, string *error_message, uint8_t width, uint8_t scale) { - // check for overflow - hugeint_t max_width = Hugeint::POWERS_OF_TEN[width - scale]; - if (input >= max_width || input <= -max_width) { - string error = StringUtil::Format("Could not cast value %s to DECIMAL(%d,%d)", input.ToString(), width, scale); - HandleCastError::AssignError(error, error_message); - return false; - } - result = Hugeint::Cast(input * Hugeint::POWERS_OF_TEN[scale]); - return true; -} - -template <> -bool TryCastToDecimal::Operation(hugeint_t input, int16_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return HugeintToDecimalCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(hugeint_t input, int32_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return HugeintToDecimalCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(hugeint_t input, int64_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return HugeintToDecimalCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(hugeint_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return HugeintToDecimalCast(input, result, error_message, width, scale); -} - -//===--------------------------------------------------------------------===// -// Float/Double -> Decimal Cast -//===--------------------------------------------------------------------===// -template -bool DoubleToDecimalCast(SRC input, DST &result, string *error_message, uint8_t width, uint8_t scale) { - double value = 
input * NumericHelper::DOUBLE_POWERS_OF_TEN[scale]; - // Add the sign (-1, 0, 1) times a tiny value to fix floating point issues (issue 3091) - double sign = (double(0) < value) - (value < double(0)); - value += 1e-9 * sign; - if (value <= -NumericHelper::DOUBLE_POWERS_OF_TEN[width] || value >= NumericHelper::DOUBLE_POWERS_OF_TEN[width]) { - string error = StringUtil::Format("Could not cast value %f to DECIMAL(%d,%d)", value, width, scale); - HandleCastError::AssignError(error, error_message); - return false; - } - result = Cast::Operation(value); - return true; -} - -template <> -bool TryCastToDecimal::Operation(float input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { - return DoubleToDecimalCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(float input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { - return DoubleToDecimalCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(float input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { - return DoubleToDecimalCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(float input, hugeint_t &result, string *error_message, uint8_t width, uint8_t scale) { - return DoubleToDecimalCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(double input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { - return DoubleToDecimalCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(double input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { - return DoubleToDecimalCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(double input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { - return 
DoubleToDecimalCast(input, result, error_message, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(double input, hugeint_t &result, string *error_message, uint8_t width, uint8_t scale) { - return DoubleToDecimalCast(input, result, error_message, width, scale); -} - -//===--------------------------------------------------------------------===// -// Decimal -> Numeric Cast -//===--------------------------------------------------------------------===// -template -bool TryCastDecimalToNumeric(SRC input, DST &result, string *error_message, uint8_t scale) { - // Round away from 0. - const auto power = NumericHelper::POWERS_OF_TEN[scale]; - // https://graphics.stanford.edu/~seander/bithacks.html#ConditionalNegate - const auto fNegate = int64_t(input < 0); - const auto rounding = ((power ^ -fNegate) + fNegate) / 2; - const auto scaled_value = (input + rounding) / power; - if (!TryCast::Operation(scaled_value, result)) { - string error = StringUtil::Format("Failed to cast decimal value %d to type %s", scaled_value, GetTypeId()); - HandleCastError::AssignError(error, error_message); - return false; - } - return true; -} - -template -bool TryCastHugeDecimalToNumeric(hugeint_t input, DST &result, string *error_message, uint8_t scale) { - const auto power = Hugeint::POWERS_OF_TEN[scale]; - const auto rounding = ((input < 0) ? 
-power : power) / 2; - auto scaled_value = (input + rounding) / power; - if (!TryCast::Operation(scaled_value, result)) { - string error = StringUtil::Format("Failed to cast decimal value %s to type %s", - ConvertToString::Operation(scaled_value), GetTypeId()); - HandleCastError::AssignError(error, error_message); - return false; - } - return true; -} - -//===--------------------------------------------------------------------===// -// Cast Decimal -> int8_t -//===--------------------------------------------------------------------===// -template <> -bool TryCastFromDecimal::Operation(int16_t input, int8_t &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int32_t input, int8_t &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int64_t input, int8_t &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, int8_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastHugeDecimalToNumeric(input, result, error_message, scale); -} - -//===--------------------------------------------------------------------===// -// Cast Decimal -> int16_t -//===--------------------------------------------------------------------===// -template <> -bool TryCastFromDecimal::Operation(int16_t input, int16_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int32_t input, int16_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); 
-} -template <> -bool TryCastFromDecimal::Operation(int64_t input, int16_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, int16_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastHugeDecimalToNumeric(input, result, error_message, scale); -} - -//===--------------------------------------------------------------------===// -// Cast Decimal -> int32_t -//===--------------------------------------------------------------------===// -template <> -bool TryCastFromDecimal::Operation(int16_t input, int32_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int32_t input, int32_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int64_t input, int32_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, int32_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastHugeDecimalToNumeric(input, result, error_message, scale); -} - -//===--------------------------------------------------------------------===// -// Cast Decimal -> int64_t -//===--------------------------------------------------------------------===// -template <> -bool TryCastFromDecimal::Operation(int16_t input, int64_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int32_t input, int64_t &result, string *error_message, uint8_t width, - 
uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int64_t input, int64_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, int64_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastHugeDecimalToNumeric(input, result, error_message, scale); -} - -//===--------------------------------------------------------------------===// -// Cast Decimal -> uint8_t -//===--------------------------------------------------------------------===// -template <> -bool TryCastFromDecimal::Operation(int16_t input, uint8_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int32_t input, uint8_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int64_t input, uint8_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, uint8_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastHugeDecimalToNumeric(input, result, error_message, scale); -} - -//===--------------------------------------------------------------------===// -// Cast Decimal -> uint16_t -//===--------------------------------------------------------------------===// -template <> -bool TryCastFromDecimal::Operation(int16_t input, uint16_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool 
TryCastFromDecimal::Operation(int32_t input, uint16_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int64_t input, uint16_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, uint16_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastHugeDecimalToNumeric(input, result, error_message, scale); -} - -//===--------------------------------------------------------------------===// -// Cast Decimal -> uint32_t -//===--------------------------------------------------------------------===// -template <> -bool TryCastFromDecimal::Operation(int16_t input, uint32_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int32_t input, uint32_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int64_t input, uint32_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, uint32_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastHugeDecimalToNumeric(input, result, error_message, scale); -} - -//===--------------------------------------------------------------------===// -// Cast Decimal -> uint64_t -//===--------------------------------------------------------------------===// -template <> -bool TryCastFromDecimal::Operation(int16_t input, uint64_t &result, string *error_message, uint8_t width, - uint8_t 
scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int32_t input, uint64_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int64_t input, uint64_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, uint64_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastHugeDecimalToNumeric(input, result, error_message, scale); -} - -//===--------------------------------------------------------------------===// -// Cast Decimal -> hugeint_t -//===--------------------------------------------------------------------===// -template <> -bool TryCastFromDecimal::Operation(int16_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int32_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(int64_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToNumeric(input, result, error_message, scale); -} -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, hugeint_t &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastHugeDecimalToNumeric(input, result, error_message, scale); -} - -//===--------------------------------------------------------------------===// -// Decimal -> Float/Double Cast -//===--------------------------------------------------------------------===// -template -bool 
TryCastDecimalToFloatingPoint(SRC input, DST &result, uint8_t scale) { - result = Cast::Operation(input) / DST(NumericHelper::DOUBLE_POWERS_OF_TEN[scale]); - return true; -} - -// DECIMAL -> FLOAT -template <> -bool TryCastFromDecimal::Operation(int16_t input, float &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -template <> -bool TryCastFromDecimal::Operation(int32_t input, float &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -template <> -bool TryCastFromDecimal::Operation(int64_t input, float &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, float &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -// DECIMAL -> DOUBLE -template <> -bool TryCastFromDecimal::Operation(int16_t input, double &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -template <> -bool TryCastFromDecimal::Operation(int32_t input, double &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -template <> -bool TryCastFromDecimal::Operation(int64_t input, double &result, string *error_message, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, double &result, string *error_message, uint8_t width, - uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -} // namespace duckdb - - - - -namespace duckdb { - -template -string StandardStringCast(T input) { - Vector v(LogicalType::VARCHAR); - return 
StringCast::Operation(input, v).GetString(); -} - -template <> -string ConvertToString::Operation(bool input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(int8_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(int16_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(int32_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(int64_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(uint8_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(uint16_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(uint32_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(uint64_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(hugeint_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(float input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(double input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(interval_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(date_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(dtime_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(timestamp_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(string_t input) { - return input.GetString(); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Cast Numeric -> String 
-//===--------------------------------------------------------------------===// -template <> -string_t StringCast::Operation(bool input, Vector &vector) { - if (input) { - return StringVector::AddString(vector, "true", 4); - } else { - return StringVector::AddString(vector, "false", 5); - } -} - -template <> -string_t StringCast::Operation(int8_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} - -template <> -string_t StringCast::Operation(int16_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} -template <> -string_t StringCast::Operation(int32_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} - -template <> -string_t StringCast::Operation(int64_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} -template <> -duckdb::string_t StringCast::Operation(uint8_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} -template <> -duckdb::string_t StringCast::Operation(uint16_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} -template <> -duckdb::string_t StringCast::Operation(uint32_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} -template <> -duckdb::string_t StringCast::Operation(uint64_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} - -template <> -string_t StringCast::Operation(float input, Vector &vector) { - std::string s = duckdb_fmt::format("{}", input); - return StringVector::AddString(vector, s); -} - -template <> -string_t StringCast::Operation(double input, Vector &vector) { - std::string s = duckdb_fmt::format("{}", input); - return StringVector::AddString(vector, s); -} - -template <> -string_t StringCast::Operation(interval_t input, Vector &vector) { - char buffer[70]; - idx_t length = IntervalToStringCast::Format(input, buffer); - return StringVector::AddString(vector, buffer, length); -} - -template 
<> -duckdb::string_t StringCast::Operation(hugeint_t input, Vector &vector) { - return HugeintToStringCast::FormatSigned(input, vector); -} - -template <> -duckdb::string_t StringCast::Operation(date_t input, Vector &vector) { - if (input == date_t::infinity()) { - return StringVector::AddString(vector, Date::PINF); - } else if (input == date_t::ninfinity()) { - return StringVector::AddString(vector, Date::NINF); - } - int32_t date[3]; - Date::Convert(input, date[0], date[1], date[2]); - - idx_t year_length; - bool add_bc; - idx_t length = DateToStringCast::Length(date, year_length, add_bc); - - string_t result = StringVector::EmptyString(vector, length); - auto data = result.GetDataWriteable(); - - DateToStringCast::Format(data, date, year_length, add_bc); - - result.Finalize(); - return result; -} - -template <> -duckdb::string_t StringCast::Operation(dtime_t input, Vector &vector) { - int32_t time[4]; - Time::Convert(input, time[0], time[1], time[2], time[3]); - - char micro_buffer[10]; - idx_t length = TimeToStringCast::Length(time, micro_buffer); - - string_t result = StringVector::EmptyString(vector, length); - auto data = result.GetDataWriteable(); - - TimeToStringCast::Format(data, length, time, micro_buffer); - - result.Finalize(); - return result; -} - -template <> -duckdb::string_t StringCast::Operation(timestamp_t input, Vector &vector) { - if (input == timestamp_t::infinity()) { - return StringVector::AddString(vector, Date::PINF); - } else if (input == timestamp_t::ninfinity()) { - return StringVector::AddString(vector, Date::NINF); - } - date_t date_entry; - dtime_t time_entry; - Timestamp::Convert(input, date_entry, time_entry); - - int32_t date[3], time[4]; - Date::Convert(date_entry, date[0], date[1], date[2]); - Time::Convert(time_entry, time[0], time[1], time[2], time[3]); - - // format for timestamp is DATE TIME (separated by space) - idx_t year_length; - bool add_bc; - char micro_buffer[6]; - idx_t date_length = DateToStringCast::Length(date, 
year_length, add_bc); - idx_t time_length = TimeToStringCast::Length(time, micro_buffer); - idx_t length = date_length + time_length + 1; - - string_t result = StringVector::EmptyString(vector, length); - auto data = result.GetDataWriteable(); - - DateToStringCast::Format(data, date, year_length, add_bc); - data[date_length] = ' '; - TimeToStringCast::Format(data + date_length + 1, time_length, time, micro_buffer); - - result.Finalize(); - return result; -} - -template <> -duckdb::string_t StringCast::Operation(duckdb::string_t input, Vector &result) { - return StringVector::AddStringOrBlob(result, input); -} - -template <> -string_t StringCastTZ::Operation(dtime_tz_t input, Vector &vector) { - int32_t time[4]; - Time::Convert(input.time(), time[0], time[1], time[2], time[3]); - - char micro_buffer[10]; - const auto time_length = TimeToStringCast::Length(time, micro_buffer); - idx_t length = time_length; - - const auto offset = input.offset(); - const bool negative = (offset < 0); - ++length; - - auto ss = std::abs(offset); - const auto hh = ss / Interval::SECS_PER_HOUR; - - const auto hh_length = (hh < 100) ? 2 : NumericHelper::UnsignedLength(uint32_t(hh)); - length += hh_length; - - ss %= Interval::SECS_PER_HOUR; - const auto mm = ss / Interval::SECS_PER_MINUTE; - if (mm) { - length += 3; - } - - ss %= Interval::SECS_PER_MINUTE; - if (ss) { - length += 3; - } - - string_t result = StringVector::EmptyString(vector, length); - auto data = result.GetDataWriteable(); - - idx_t pos = 0; - TimeToStringCast::Format(data + pos, time_length, time, micro_buffer); - pos += time_length; - - data[pos++] = negative ? 
'-' : '+'; - if (hh < 100) { - TimeToStringCast::FormatTwoDigits(data + pos, hh); - } else { - NumericHelper::FormatUnsigned(hh, data + pos + hh_length); - } - pos += hh_length; - - if (mm) { - data[pos++] = ':'; - TimeToStringCast::FormatTwoDigits(data + pos, mm); - pos += 2; - } - - if (ss) { - data[pos++] = ':'; - TimeToStringCast::FormatTwoDigits(data + pos, ss); - pos += 2; - } - - result.Finalize(); - return result; -} - -template <> -string_t StringCastTZ::Operation(timestamp_t input, Vector &vector) { - if (input == timestamp_t::infinity()) { - return StringVector::AddString(vector, Date::PINF); - } else if (input == timestamp_t::ninfinity()) { - return StringVector::AddString(vector, Date::NINF); - } - date_t date_entry; - dtime_t time_entry; - Timestamp::Convert(input, date_entry, time_entry); - - int32_t date[3], time[4]; - Date::Convert(date_entry, date[0], date[1], date[2]); - Time::Convert(time_entry, time[0], time[1], time[2], time[3]); - - // format for timestamptz is DATE TIME+00 (separated by space) - idx_t year_length; - bool add_bc; - char micro_buffer[6]; - const idx_t date_length = DateToStringCast::Length(date, year_length, add_bc); - const idx_t time_length = TimeToStringCast::Length(time, micro_buffer); - const idx_t length = date_length + 1 + time_length + 3; - - string_t result = StringVector::EmptyString(vector, length); - auto data = result.GetDataWriteable(); - - idx_t pos = 0; - DateToStringCast::Format(data + pos, date, year_length, add_bc); - pos += date_length; - data[pos++] = ' '; - TimeToStringCast::Format(data + pos, time_length, time, micro_buffer); - pos += time_length; - data[pos++] = '+'; - data[pos++] = '0'; - data[pos++] = '0'; - - result.Finalize(); - return result; -} - -} // namespace duckdb - - - - - -namespace duckdb { -class PipeFile : public FileHandle { -public: - PipeFile(unique_ptr child_handle_p, const string &path) - : FileHandle(pipe_fs, path), child_handle(std::move(child_handle_p)) { - } - - PipeFileSystem 
pipe_fs; - unique_ptr child_handle; - -public: - int64_t ReadChunk(void *buffer, int64_t nr_bytes); - int64_t WriteChunk(void *buffer, int64_t nr_bytes); - - void Close() override { - } -}; - -int64_t PipeFile::ReadChunk(void *buffer, int64_t nr_bytes) { - return child_handle->Read(buffer, nr_bytes); -} -int64_t PipeFile::WriteChunk(void *buffer, int64_t nr_bytes) { - return child_handle->Write(buffer, nr_bytes); -} - -void PipeFileSystem::Reset(FileHandle &handle) { - throw InternalException("Cannot reset pipe file system"); -} - -int64_t PipeFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - auto &pipe = handle.Cast(); - return pipe.ReadChunk(buffer, nr_bytes); -} - -int64_t PipeFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { - auto &pipe = handle.Cast(); - return pipe.WriteChunk(buffer, nr_bytes); -} - -int64_t PipeFileSystem::GetFileSize(FileHandle &handle) { - return 0; -} - -void PipeFileSystem::FileSync(FileHandle &handle) { -} - -unique_ptr PipeFileSystem::OpenPipe(unique_ptr handle) { - auto path = handle->path; - return make_uniq(std::move(handle), path); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -PreservedError::PreservedError() : initialized(false), exception_instance(nullptr) { -} - -PreservedError::PreservedError(const Exception &exception) - : initialized(true), type(exception.type), raw_message(SanitizeErrorMessage(exception.RawMessage())), - exception_instance(exception.Copy()) { -} - -PreservedError::PreservedError(const string &message) - : initialized(true), type(ExceptionType::INVALID), raw_message(SanitizeErrorMessage(message)), - exception_instance(nullptr) { -} - -const string &PreservedError::Message() { - if (final_message.empty()) { - final_message = Exception::ExceptionTypeToString(type) + " Error: " + raw_message; - } - return final_message; -} - -string PreservedError::SanitizeErrorMessage(string error) { - return StringUtil::Replace(std::move(error), string("\0", 1), 
"\\0"); -} - -void PreservedError::Throw(const string &prepended_message) const { - D_ASSERT(initialized); - if (!prepended_message.empty()) { - string new_message = prepended_message + raw_message; - Exception::ThrowAsTypeWithMessage(type, new_message, exception_instance); - } - Exception::ThrowAsTypeWithMessage(type, raw_message, exception_instance); -} - -const ExceptionType &PreservedError::Type() const { - D_ASSERT(initialized); - return this->type; -} - -PreservedError &PreservedError::AddToMessage(const string &prepended_message) { - raw_message = prepended_message + raw_message; - return *this; -} - -PreservedError::operator bool() const { - return initialized; -} - -bool PreservedError::operator==(const PreservedError &other) const { - if (initialized != other.initialized) { - return false; - } - if (type != other.type) { - return false; - } - return raw_message == other.raw_message; -} - -} // namespace duckdb - - - - -#include - -#ifndef DUCKDB_DISABLE_PRINT -#ifdef DUCKDB_WINDOWS -#include -#else -#include -#include -#include -#endif -#endif - -namespace duckdb { - -void Printer::RawPrint(OutputStream stream, const string &str) { -#ifndef DUCKDB_DISABLE_PRINT -#ifdef DUCKDB_WINDOWS - if (IsTerminal(stream)) { - // print utf8 to terminal - auto unicode = WindowsUtil::UTF8ToMBCS(str.c_str()); - fprintf(stream == OutputStream::STREAM_STDERR ? stderr : stdout, "%s", unicode.c_str()); - return; - } -#endif - fprintf(stream == OutputStream::STREAM_STDERR ? stderr : stdout, "%s", str.c_str()); -#endif -} - -// LCOV_EXCL_START -void Printer::Print(OutputStream stream, const string &str) { - Printer::RawPrint(stream, str); - Printer::RawPrint(stream, "\n"); -} -void Printer::Flush(OutputStream stream) { -#ifndef DUCKDB_DISABLE_PRINT - fflush(stream == OutputStream::STREAM_STDERR ? 
stderr : stdout); -#endif -} - -void Printer::Print(const string &str) { - Printer::Print(OutputStream::STREAM_STDERR, str); -} - -bool Printer::IsTerminal(OutputStream stream) { -#ifndef DUCKDB_DISABLE_PRINT -#ifdef DUCKDB_WINDOWS - auto stream_handle = stream == OutputStream::STREAM_STDERR ? STD_ERROR_HANDLE : STD_OUTPUT_HANDLE; - return GetFileType(GetStdHandle(stream_handle)) == FILE_TYPE_CHAR; -#else - return isatty(stream == OutputStream::STREAM_STDERR ? 2 : 1); -#endif -#else - throw InternalException("IsTerminal called while printing is disabled"); -#endif -} - -idx_t Printer::TerminalWidth() { -#ifndef DUCKDB_DISABLE_PRINT -#ifdef DUCKDB_WINDOWS - CONSOLE_SCREEN_BUFFER_INFO csbi; - int columns, rows; - - GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); - rows = csbi.srWindow.Right - csbi.srWindow.Left + 1; - return rows; -#else - struct winsize w; - ioctl(0, TIOCGWINSZ, &w); - return w.ws_col; -#endif -#else - throw InternalException("TerminalWidth called while printing is disabled"); -#endif -} -// LCOV_EXCL_STOP - -} // namespace duckdb - - - - -namespace duckdb { - -void ProgressBar::SystemOverrideCheck(ClientConfig &config) { - if (config.system_progress_bar_disable_reason != nullptr) { - throw InvalidInputException("Could not change the progress bar setting because: '%s'", - config.system_progress_bar_disable_reason); - } -} - -unique_ptr ProgressBar::DefaultProgressBarDisplay() { - return make_uniq(); -} - -ProgressBar::ProgressBar(Executor &executor, idx_t show_progress_after, - progress_bar_display_create_func_t create_display_func) - : executor(executor), show_progress_after(show_progress_after), current_percentage(-1) { - if (create_display_func) { - display = create_display_func(); - } -} - -double ProgressBar::GetCurrentPercentage() { - return current_percentage; -} - -void ProgressBar::Start() { - profiler.Start(); - current_percentage = 0; - supported = true; -} - -bool ProgressBar::PrintEnabled() const { - return display 
!= nullptr; -} - -bool ProgressBar::ShouldPrint(bool final) const { - if (!PrintEnabled()) { - // Don't print progress at all - return false; - } - // FIXME - do we need to check supported before running `profiler.Elapsed()` ? - auto sufficient_time_elapsed = profiler.Elapsed() > show_progress_after / 1000.0; - if (!sufficient_time_elapsed) { - // Don't print yet - return false; - } - if (final) { - // Print the last completed bar - return true; - } - if (!supported) { - return false; - } - return current_percentage > -1; -} - -void ProgressBar::Update(bool final) { - if (!final && !supported) { - return; - } - double new_percentage; - supported = executor.GetPipelinesProgress(new_percentage); - if (!final && !supported) { - return; - } - if (new_percentage > current_percentage) { - current_percentage = new_percentage; - } - if (ShouldPrint(final)) { -#ifndef DUCKDB_DISABLE_PRINT - if (final) { - FinishProgressBarPrint(); - } else { - PrintProgress(current_percentage); - } -#endif - } -} - -void ProgressBar::PrintProgress(int current_percentage) { - D_ASSERT(display); - display->Update(current_percentage); -} - -void ProgressBar::FinishProgressBarPrint() { - if (finished) { - return; - } - D_ASSERT(display); - display->Finish(); - finished = true; -} - -} // namespace duckdb - - - - -namespace duckdb { - -void TerminalProgressBarDisplay::PrintProgressInternal(int percentage) { - if (percentage > 100) { - percentage = 100; - } - if (percentage < 0) { - percentage = 0; - } - string result; - // we divide the number of blocks by the percentage - // 0% = 0 - // 100% = PROGRESS_BAR_WIDTH - // the percentage determines how many blocks we need to draw - double blocks_to_draw = PROGRESS_BAR_WIDTH * (percentage / 100.0); - // because of the power of unicode, we can also draw partial blocks - - // render the percentage with some padding to ensure everything stays nicely aligned - result = "\r"; - if (percentage < 100) { - result += " "; - } - if (percentage < 10) { - result 
+= " "; - } - result += to_string(percentage) + "%"; - result += " "; - result += PROGRESS_START; - idx_t i; - for (i = 0; i < idx_t(blocks_to_draw); i++) { - result += PROGRESS_BLOCK; - } - if (i < PROGRESS_BAR_WIDTH) { - // print a partial block based on the percentage of the progress bar remaining - idx_t index = idx_t((blocks_to_draw - idx_t(blocks_to_draw)) * PARTIAL_BLOCK_COUNT); - if (index >= PARTIAL_BLOCK_COUNT) { - index = PARTIAL_BLOCK_COUNT - 1; - } - result += PROGRESS_PARTIAL[index]; - i++; - } - for (; i < PROGRESS_BAR_WIDTH; i++) { - result += PROGRESS_EMPTY; - } - result += PROGRESS_END; - result += " "; - - Printer::RawPrint(OutputStream::STREAM_STDOUT, result); -} - -void TerminalProgressBarDisplay::Update(double percentage) { - PrintProgressInternal(percentage); - Printer::Flush(OutputStream::STREAM_STDOUT); -} - -void TerminalProgressBarDisplay::Finish() { - PrintProgressInternal(100); - Printer::RawPrint(OutputStream::STREAM_STDOUT, "\n"); - Printer::Flush(OutputStream::STREAM_STDOUT); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -//! Templated radix partitioning constants, can be templated to the number of radix bits -template -struct RadixPartitioningConstants { -public: - //! Bitmask of the upper bits starting at the 5th byte - static constexpr const idx_t NUM_PARTITIONS = RadixPartitioning::NumberOfPartitions(radix_bits); - static constexpr const idx_t SHIFT = RadixPartitioning::Shift(radix_bits); - static constexpr const hash_t MASK = RadixPartitioning::Mask(radix_bits); - -public: - //! Apply bitmask and right shift to get a number between 0 and NUM_PARTITIONS - static inline hash_t ApplyMask(hash_t hash) { - D_ASSERT((hash & MASK) >> SHIFT < NUM_PARTITIONS); - return (hash & MASK) >> SHIFT; - } -}; - -template -RETURN_TYPE RadixBitsSwitch(idx_t radix_bits, ARGS &&... 
args) { - D_ASSERT(radix_bits <= RadixPartitioning::MAX_RADIX_BITS); - switch (radix_bits) { - case 0: - return OP::template Operation<0>(std::forward(args)...); - case 1: - return OP::template Operation<1>(std::forward(args)...); - case 2: - return OP::template Operation<2>(std::forward(args)...); - case 3: - return OP::template Operation<3>(std::forward(args)...); - case 4: - return OP::template Operation<4>(std::forward(args)...); - case 5: // LCOV_EXCL_START - return OP::template Operation<5>(std::forward(args)...); - case 6: - return OP::template Operation<6>(std::forward(args)...); - case 7: - return OP::template Operation<7>(std::forward(args)...); - case 8: - return OP::template Operation<8>(std::forward(args)...); - case 9: - return OP::template Operation<9>(std::forward(args)...); - case 10: - return OP::template Operation<10>(std::forward(args)...); - case 11: - return OP::template Operation<10>(std::forward(args)...); - case 12: - return OP::template Operation<10>(std::forward(args)...); - default: - throw InternalException( - "radix_bits higher than RadixPartitioning::MAX_RADIX_BITS encountered in RadixBitsSwitch"); - } // LCOV_EXCL_STOP -} - -template -struct RadixLessThan { - static inline bool Operation(hash_t hash, hash_t cutoff) { - using CONSTANTS = RadixPartitioningConstants; - return CONSTANTS::ApplyMask(hash) < cutoff; - } -}; - -struct SelectFunctor { - template - static idx_t Operation(Vector &hashes, const SelectionVector *sel, idx_t count, idx_t cutoff, - SelectionVector *true_sel, SelectionVector *false_sel) { - Vector cutoff_vector(Value::HASH(cutoff)); - return BinaryExecutor::Select>(hashes, cutoff_vector, sel, count, - true_sel, false_sel); - } -}; - -idx_t RadixPartitioning::Select(Vector &hashes, const SelectionVector *sel, idx_t count, idx_t radix_bits, idx_t cutoff, - SelectionVector *true_sel, SelectionVector *false_sel) { - return RadixBitsSwitch(radix_bits, hashes, sel, count, cutoff, true_sel, false_sel); -} - -struct 
ComputePartitionIndicesFunctor { - template - static void Operation(Vector &hashes, Vector &partition_indices, idx_t count) { - UnaryExecutor::Execute(hashes, partition_indices, count, [&](hash_t hash) { - using CONSTANTS = RadixPartitioningConstants; - return CONSTANTS::ApplyMask(hash); - }); - } -}; - -//===--------------------------------------------------------------------===// -// Column Data Partitioning -//===--------------------------------------------------------------------===// -RadixPartitionedColumnData::RadixPartitionedColumnData(ClientContext &context_p, vector types_p, - idx_t radix_bits_p, idx_t hash_col_idx_p) - : PartitionedColumnData(PartitionedColumnDataType::RADIX, context_p, std::move(types_p)), radix_bits(radix_bits_p), - hash_col_idx(hash_col_idx_p) { - D_ASSERT(radix_bits <= RadixPartitioning::MAX_RADIX_BITS); - D_ASSERT(hash_col_idx < types.size()); - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - allocators->allocators.reserve(num_partitions); - for (idx_t i = 0; i < num_partitions; i++) { - CreateAllocator(); - } - D_ASSERT(allocators->allocators.size() == num_partitions); -} - -RadixPartitionedColumnData::RadixPartitionedColumnData(const RadixPartitionedColumnData &other) - : PartitionedColumnData(other), radix_bits(other.radix_bits), hash_col_idx(other.hash_col_idx) { - for (idx_t i = 0; i < RadixPartitioning::NumberOfPartitions(radix_bits); i++) { - partitions.emplace_back(CreatePartitionCollection(i)); - } -} - -RadixPartitionedColumnData::~RadixPartitionedColumnData() { -} - -void RadixPartitionedColumnData::InitializeAppendStateInternal(PartitionedColumnDataAppendState &state) const { - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - state.partition_append_states.reserve(num_partitions); - state.partition_buffers.reserve(num_partitions); - for (idx_t i = 0; i < num_partitions; i++) { - state.partition_append_states.emplace_back(make_uniq()); - 
partitions[i]->InitializeAppend(*state.partition_append_states[i]); - state.partition_buffers.emplace_back(CreatePartitionBuffer()); - } -} - -void RadixPartitionedColumnData::ComputePartitionIndices(PartitionedColumnDataAppendState &state, DataChunk &input) { - D_ASSERT(partitions.size() == RadixPartitioning::NumberOfPartitions(radix_bits)); - D_ASSERT(state.partition_buffers.size() == RadixPartitioning::NumberOfPartitions(radix_bits)); - RadixBitsSwitch(radix_bits, input.data[hash_col_idx], state.partition_indices, - input.size()); -} - -//===--------------------------------------------------------------------===// -// Tuple Data Partitioning -//===--------------------------------------------------------------------===// -RadixPartitionedTupleData::RadixPartitionedTupleData(BufferManager &buffer_manager, const TupleDataLayout &layout_p, - idx_t radix_bits_p, idx_t hash_col_idx_p) - : PartitionedTupleData(PartitionedTupleDataType::RADIX, buffer_manager, layout_p.Copy()), radix_bits(radix_bits_p), - hash_col_idx(hash_col_idx_p) { - D_ASSERT(radix_bits <= RadixPartitioning::MAX_RADIX_BITS); - D_ASSERT(hash_col_idx < layout.GetTypes().size()); - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - allocators->allocators.reserve(num_partitions); - for (idx_t i = 0; i < num_partitions; i++) { - CreateAllocator(); - } - D_ASSERT(allocators->allocators.size() == num_partitions); - Initialize(); -} - -RadixPartitionedTupleData::RadixPartitionedTupleData(const RadixPartitionedTupleData &other) - : PartitionedTupleData(other), radix_bits(other.radix_bits), hash_col_idx(other.hash_col_idx) { - Initialize(); -} - -RadixPartitionedTupleData::~RadixPartitionedTupleData() { -} - -void RadixPartitionedTupleData::Initialize() { - for (idx_t i = 0; i < RadixPartitioning::NumberOfPartitions(radix_bits); i++) { - partitions.emplace_back(CreatePartitionCollection(i)); - } -} - -void 
RadixPartitionedTupleData::InitializeAppendStateInternal(PartitionedTupleDataAppendState &state, - TupleDataPinProperties properties) const { - // Init pin state per partition - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - state.partition_pin_states.reserve(num_partitions); - for (idx_t i = 0; i < num_partitions; i++) { - state.partition_pin_states.emplace_back(make_uniq()); - partitions[i]->InitializeAppend(*state.partition_pin_states[i], properties); - } - - // Init single chunk state - auto column_count = layout.ColumnCount(); - vector column_ids; - column_ids.reserve(column_count); - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - column_ids.emplace_back(col_idx); - } - partitions[0]->InitializeChunkState(state.chunk_state, std::move(column_ids)); - - // Initialize fixed-size map - state.fixed_partition_entries.resize(RadixPartitioning::NumberOfPartitions(radix_bits)); -} - -void RadixPartitionedTupleData::ComputePartitionIndices(PartitionedTupleDataAppendState &state, DataChunk &input) { - D_ASSERT(partitions.size() == RadixPartitioning::NumberOfPartitions(radix_bits)); - RadixBitsSwitch(radix_bits, input.data[hash_col_idx], state.partition_indices, - input.size()); -} - -void RadixPartitionedTupleData::ComputePartitionIndices(Vector &row_locations, idx_t count, - Vector &partition_indices) const { - Vector intermediate(LogicalType::HASH); - partitions[0]->Gather(row_locations, *FlatVector::IncrementalSelectionVector(), count, hash_col_idx, intermediate, - *FlatVector::IncrementalSelectionVector()); - RadixBitsSwitch(radix_bits, intermediate, partition_indices, count); -} - -void RadixPartitionedTupleData::RepartitionFinalizeStates(PartitionedTupleData &old_partitioned_data, - PartitionedTupleData &new_partitioned_data, - PartitionedTupleDataAppendState &state, - idx_t finished_partition_idx) const { - D_ASSERT(old_partitioned_data.GetType() == PartitionedTupleDataType::RADIX && - 
new_partitioned_data.GetType() == PartitionedTupleDataType::RADIX); - const auto &old_radix_partitions = old_partitioned_data.Cast(); - const auto &new_radix_partitions = new_partitioned_data.Cast(); - const auto old_radix_bits = old_radix_partitions.GetRadixBits(); - const auto new_radix_bits = new_radix_partitions.GetRadixBits(); - D_ASSERT(new_radix_bits > old_radix_bits); - - // We take the most significant digits as the partition index - // When repartitioning, e.g., partition 0 from "old" goes into the first N partitions in "new" - // When partition 0 is done, we can already finalize the append states, unpinning blocks - const auto multiplier = RadixPartitioning::NumberOfPartitions(new_radix_bits - old_radix_bits); - const auto from_idx = finished_partition_idx * multiplier; - const auto to_idx = from_idx + multiplier; - auto &partitions = new_partitioned_data.GetPartitions(); - for (idx_t partition_index = from_idx; partition_index < to_idx; partition_index++) { - auto &partition = *partitions[partition_index]; - auto &partition_pin_state = *state.partition_pin_states[partition_index]; - partition.FinalizePinState(partition_pin_state); - } -} - -} // namespace duckdb - - -#include - -namespace duckdb { - -struct RandomState { - RandomState() { - } - - pcg32 pcg; -}; - -RandomEngine::RandomEngine(int64_t seed) : random_state(make_uniq()) { - if (seed < 0) { - random_state->pcg.seed(pcg_extras::seed_seq_from()); - } else { - random_state->pcg.seed(seed); - } -} - -RandomEngine::~RandomEngine() { -} - -double RandomEngine::NextRandom(double min, double max) { - D_ASSERT(max >= min); - return min + (NextRandom() * (max - min)); -} - -double RandomEngine::NextRandom() { - return std::ldexp(random_state->pcg(), -32); -} -uint32_t RandomEngine::NextRandomInteger() { - return random_state->pcg(); -} - -void RandomEngine::SetSeed(uint32_t seed) { - random_state->pcg.seed(seed); -} - -} // namespace duckdb - -#include - - - - -namespace duckdb_re2 { - 
-Regex::Regex(const std::string &pattern, RegexOptions options) { - RE2::Options o; - o.set_case_sensitive(options == RegexOptions::CASE_INSENSITIVE); - regex = std::make_shared(StringPiece(pattern), o); -} - -bool RegexSearchInternal(const char *input, Match &match, const Regex &r, RE2::Anchor anchor, size_t start, - size_t end) { - auto ®ex = r.GetRegex(); - duckdb::vector target_groups; - auto group_count = regex.NumberOfCapturingGroups() + 1; - target_groups.resize(group_count); - match.groups.clear(); - if (!regex.Match(StringPiece(input), start, end, anchor, target_groups.data(), group_count)) { - return false; - } - for (auto &group : target_groups) { - GroupMatch group_match; - group_match.text = group.ToString(); - group_match.position = group.data() - input; - match.groups.emplace_back(group_match); - } - return true; -} - -bool RegexSearch(const std::string &input, Match &match, const Regex ®ex) { - return RegexSearchInternal(input.c_str(), match, regex, RE2::UNANCHORED, 0, input.size()); -} - -bool RegexMatch(const std::string &input, Match &match, const Regex ®ex) { - return RegexSearchInternal(input.c_str(), match, regex, RE2::ANCHOR_BOTH, 0, input.size()); -} - -bool RegexMatch(const char *start, const char *end, Match &match, const Regex ®ex) { - return RegexSearchInternal(start, match, regex, RE2::ANCHOR_BOTH, 0, end - start); -} - -bool RegexMatch(const std::string &input, const Regex ®ex) { - Match nop_match; - return RegexSearchInternal(input.c_str(), nop_match, regex, RE2::ANCHOR_BOTH, 0, input.size()); -} - -duckdb::vector RegexFindAll(const std::string &input, const Regex ®ex) { - duckdb::vector matches; - size_t position = 0; - Match match; - while (RegexSearchInternal(input.c_str(), match, regex, RE2::UNANCHORED, position, input.size())) { - position += match.position(0) + match.length(0); - matches.emplace_back(match); - } - return matches; -} - -} // namespace duckdb_re2 
-//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row_operations/row_aggregate.cpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -void RowOperations::InitializeStates(TupleDataLayout &layout, Vector &addresses, const SelectionVector &sel, - idx_t count) { - if (count == 0) { - return; - } - auto pointers = FlatVector::GetData(addresses); - auto &offsets = layout.GetOffsets(); - auto aggr_idx = layout.ColumnCount(); - - for (const auto &aggr : layout.GetAggregates()) { - for (idx_t i = 0; i < count; ++i) { - auto row_idx = sel.get_index(i); - auto row = pointers[row_idx]; - aggr.function.initialize(row + offsets[aggr_idx]); - } - ++aggr_idx; - } -} - -void RowOperations::DestroyStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, idx_t count) { - if (count == 0) { - return; - } - // Move to the first aggregate state - VectorOperations::AddInPlace(addresses, layout.GetAggrOffset(), count); - for (const auto &aggr : layout.GetAggregates()) { - if (aggr.function.destructor) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.destructor(addresses, aggr_input_data, count); - } - // Move to the next aggregate state - VectorOperations::AddInPlace(addresses, aggr.payload_size, count); - } -} - -void RowOperations::UpdateStates(RowOperationsState &state, AggregateObject &aggr, Vector &addresses, - DataChunk &payload, idx_t arg_idx, idx_t count) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.update(aggr.child_count == 0 ? 
nullptr : &payload.data[arg_idx], aggr_input_data, aggr.child_count, - addresses, count); -} - -void RowOperations::UpdateFilteredStates(RowOperationsState &state, AggregateFilterData &filter_data, - AggregateObject &aggr, Vector &addresses, DataChunk &payload, idx_t arg_idx) { - idx_t count = filter_data.ApplyFilter(payload); - if (count == 0) { - return; - } - - Vector filtered_addresses(addresses, filter_data.true_sel, count); - filtered_addresses.Flatten(count); - - UpdateStates(state, aggr, filtered_addresses, filter_data.filtered_payload, arg_idx, count); -} - -void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &layout, Vector &sources, Vector &targets, - idx_t count) { - if (count == 0) { - return; - } - - // Move to the first aggregate states - VectorOperations::AddInPlace(sources, layout.GetAggrOffset(), count); - VectorOperations::AddInPlace(targets, layout.GetAggrOffset(), count); - - // Keep track of the offset - idx_t offset = layout.GetAggrOffset(); - - for (auto &aggr : layout.GetAggregates()) { - D_ASSERT(aggr.function.combine); - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.combine(sources, targets, aggr_input_data, count); - - // Move to the next aggregate states - VectorOperations::AddInPlace(sources, aggr.payload_size, count); - VectorOperations::AddInPlace(targets, aggr.payload_size, count); - - // Increment the offset - offset += aggr.payload_size; - } - - // Now subtract the offset to get back to the original position - VectorOperations::AddInPlace(sources, -offset, count); - VectorOperations::AddInPlace(targets, -offset, count); -} - -void RowOperations::FinalizeStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, - DataChunk &result, idx_t aggr_idx) { - // Copy the addresses - Vector addresses_copy(LogicalType::POINTER); - VectorOperations::Copy(addresses, addresses_copy, result.size(), 0, 0); - - // Move to the first aggregate state - 
VectorOperations::AddInPlace(addresses_copy, layout.GetAggrOffset(), result.size()); - - auto &aggregates = layout.GetAggregates(); - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &target = result.data[aggr_idx + i]; - auto &aggr = aggregates[i]; - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.finalize(addresses_copy, aggr_input_data, target, result.size(), 0); - - // Move to the next aggregate state - VectorOperations::AddInPlace(addresses_copy, aggr.payload_size, result.size()); - } -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row_operations/row_external.cpp -// -// -//===----------------------------------------------------------------------===// - - - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -void RowOperations::SwizzleColumns(const RowLayout &layout, const data_ptr_t base_row_ptr, const idx_t count) { - const idx_t row_width = layout.GetRowWidth(); - data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; - idx_t done = 0; - while (done != count) { - const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); - const data_ptr_t row_ptr = base_row_ptr + done * row_width; - // Load heap row pointers - data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < next; i++) { - heap_row_ptrs[i] = Load(heap_ptr_ptr); - heap_ptr_ptr += row_width; - } - // Loop through the blob columns - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto physical_type = layout.GetTypes()[col_idx].InternalType(); - if (TypeIsConstantSize(physical_type)) { - continue; - } - data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; - if (physical_type == PhysicalType::VARCHAR) { - data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; - for (idx_t i = 0; i < next; i++) { - if (Load(col_ptr) > string_t::INLINE_LENGTH) { - // Overwrite the string 
pointer with the within-row offset (if not inlined) - Store(Load(string_ptr) - heap_row_ptrs[i], string_ptr); - } - col_ptr += row_width; - string_ptr += row_width; - } - } else { - // Non-varchar blob columns - for (idx_t i = 0; i < next; i++) { - // Overwrite the column data pointer with the within-row offset - Store(Load(col_ptr) - heap_row_ptrs[i], col_ptr); - col_ptr += row_width; - } - } - } - done += next; - } -} - -void RowOperations::SwizzleHeapPointer(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - const idx_t count, const idx_t base_offset) { - const idx_t row_width = layout.GetRowWidth(); - row_ptr += layout.GetHeapOffset(); - idx_t cumulative_offset = 0; - for (idx_t i = 0; i < count; i++) { - Store(base_offset + cumulative_offset, row_ptr); - cumulative_offset += Load(heap_base_ptr + cumulative_offset); - row_ptr += row_width; - } -} - -void RowOperations::CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - data_ptr_t heap_ptr, const idx_t count) { - const auto row_width = layout.GetRowWidth(); - const auto heap_offset = layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - // Figure out source and size - const auto source_heap_ptr = Load(row_ptr + heap_offset); - const auto size = Load(source_heap_ptr); - D_ASSERT(size >= sizeof(uint32_t)); - - // Copy and swizzle - memcpy(heap_ptr, source_heap_ptr, size); - Store(heap_ptr - heap_base_ptr, row_ptr + heap_offset); - - // Increment for next iteration - row_ptr += row_width; - heap_ptr += size; - } -} - -void RowOperations::UnswizzleHeapPointer(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count) { - const auto row_width = layout.GetRowWidth(); - data_ptr_t heap_ptr_ptr = base_row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - Store(base_heap_ptr + Load(heap_ptr_ptr), heap_ptr_ptr); - heap_ptr_ptr += row_width; - } -} - -static inline void 
VerifyUnswizzledString(const RowLayout &layout, const idx_t &col_idx, const data_ptr_t &row_ptr) { -#ifdef DEBUG - if (layout.GetTypes()[col_idx].id() != LogicalTypeId::VARCHAR) { - return; - } - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - ValidityBytes row_mask(row_ptr); - if (row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - auto str = Load(row_ptr + layout.GetOffsets()[col_idx]); - str.Verify(); - } -#endif -} - -void RowOperations::UnswizzlePointers(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count) { - const idx_t row_width = layout.GetRowWidth(); - data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; - idx_t done = 0; - while (done != count) { - const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); - const data_ptr_t row_ptr = base_row_ptr + done * row_width; - // Restore heap row pointers - data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < next; i++) { - heap_row_ptrs[i] = base_heap_ptr + Load(heap_ptr_ptr); - Store(heap_row_ptrs[i], heap_ptr_ptr); - heap_ptr_ptr += row_width; - } - // Loop through the blob columns - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto physical_type = layout.GetTypes()[col_idx].InternalType(); - if (TypeIsConstantSize(physical_type)) { - continue; - } - data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; - if (physical_type == PhysicalType::VARCHAR) { - data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; - for (idx_t i = 0; i < next; i++) { - if (Load(col_ptr) > string_t::INLINE_LENGTH) { - // Overwrite the string offset with the pointer (if not inlined) - Store(heap_row_ptrs[i] + Load(string_ptr), string_ptr); - VerifyUnswizzledString(layout, col_idx, row_ptr + i * row_width); - } - col_ptr += row_width; - string_ptr += row_width; - } - } else { - // Non-varchar blob columns - for (idx_t i = 0; i < 
next; i++) { - // Overwrite the column data offset with the pointer - Store(heap_row_ptrs[i] + Load(col_ptr), col_ptr); - col_ptr += row_width; - } - } - } - done += next; - } -} - -} // namespace duckdb -//===--------------------------------------------------------------------===// -// row_gather.cpp -// Description: This file contains the implementation of the gather operators -//===--------------------------------------------------------------------===// - - - - - - - - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -template -static void TemplatedGatherLoop(Vector &rows, const SelectionVector &row_sel, Vector &col, - const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, - idx_t build_size) { - // Precompute mask indexes - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - auto col_idx = col_sel.get_index(i); - data[col_idx] = Load(row + col_offset); - ValidityBytes row_mask(row); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { - //! We need to initialize the mask with the vector size. 
- col_mask.Initialize(build_size); - } - col_mask.SetInvalid(col_idx); - } - } -} - -static void GatherVarchar(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - idx_t count, const RowLayout &layout, idx_t col_no, idx_t build_size, - data_ptr_t base_heap_ptr) { - // Precompute mask indexes - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - const auto heap_offset = layout.GetHeapOffset(); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - auto col_idx = col_sel.get_index(i); - auto col_ptr = row + col_offset; - data[col_idx] = Load(col_ptr); - ValidityBytes row_mask(row); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { - //! We need to initialize the mask with the vector size. 
- col_mask.Initialize(build_size); - } - col_mask.SetInvalid(col_idx); - } else if (base_heap_ptr && Load(col_ptr) > string_t::INLINE_LENGTH) { - // Not inline, so unswizzle the copied pointer the pointer - auto heap_ptr_ptr = row + heap_offset; - auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); - auto string_ptr = data_ptr_t(data + col_idx) + string_t::HEADER_SIZE; - Store(heap_row_ptr + Load(string_ptr), string_ptr); -#ifdef DEBUG - data[col_idx].Verify(); -#endif - } - } -} - -static void GatherNestedVector(Vector &rows, const SelectionVector &row_sel, Vector &col, - const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, - data_ptr_t base_heap_ptr) { - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - const auto heap_offset = layout.GetHeapOffset(); - auto ptrs = FlatVector::GetData(rows); - - // Build the gather locations - auto data_locations = make_unsafe_uniq_array(count); - auto mask_locations = make_unsafe_uniq_array(count); - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - mask_locations[i] = row; - auto col_ptr = ptrs[row_idx] + col_offset; - if (base_heap_ptr) { - auto heap_ptr_ptr = row + heap_offset; - auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); - data_locations[i] = heap_row_ptr + Load(col_ptr); - } else { - data_locations[i] = Load(col_ptr); - } - } - - // Deserialise into the selected locations - RowOperations::HeapGather(col, count, col_sel, col_no, data_locations.get(), mask_locations.get()); -} - -void RowOperations::Gather(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - const idx_t count, const RowLayout &layout, const idx_t col_no, const idx_t build_size, - data_ptr_t heap_ptr) { - D_ASSERT(rows.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(rows.GetType().id() == LogicalTypeId::POINTER); // "Cannot gather from non-pointer type!" 
- - col.SetVectorType(VectorType::FLAT_VECTOR); - switch (col.GetType().InternalType()) { - case PhysicalType::UINT8: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT16: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT32: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT64: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT16: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT32: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT64: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT128: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::FLOAT: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::DOUBLE: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INTERVAL: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::VARCHAR: - GatherVarchar(rows, row_sel, col, col_sel, count, layout, col_no, build_size, heap_ptr); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - GatherNestedVector(rows, row_sel, col, col_sel, count, layout, col_no, heap_ptr); - break; - default: - throw InternalException("Unimplemented type for RowOperations::Gather"); - } -} - -template 
-static void TemplatedFullScanLoop(Vector &rows, Vector &col, idx_t count, idx_t col_offset, idx_t col_no) { - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - // auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row = ptrs[i]; - data[i] = Load(row + col_offset); - ValidityBytes row_mask(row); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - throw InternalException("Null value comparisons not implemented for perfect hash table yet"); - // col_mask.SetInvalid(i); - } - } -} - -void RowOperations::FullScanColumn(const TupleDataLayout &layout, Vector &rows, Vector &col, idx_t count, - idx_t col_no) { - const auto col_offset = layout.GetOffsets()[col_no]; - col.SetVectorType(VectorType::FLAT_VECTOR); - switch (col.GetType().InternalType()) { - case PhysicalType::UINT8: - TemplatedFullScanLoop(rows, col, count, col_offset, col_no); - break; - case PhysicalType::UINT16: - TemplatedFullScanLoop(rows, col, count, col_offset, col_no); - break; - case PhysicalType::UINT32: - TemplatedFullScanLoop(rows, col, count, col_offset, col_no); - break; - case PhysicalType::UINT64: - TemplatedFullScanLoop(rows, col, count, col_offset, col_no); - break; - case PhysicalType::INT8: - TemplatedFullScanLoop(rows, col, count, col_offset, col_no); - break; - case PhysicalType::INT16: - TemplatedFullScanLoop(rows, col, count, col_offset, col_no); - break; - case PhysicalType::INT32: - TemplatedFullScanLoop(rows, col, count, col_offset, col_no); - break; - case PhysicalType::INT64: - TemplatedFullScanLoop(rows, col, count, col_offset, col_no); - break; - default: - throw NotImplementedException("Unimplemented type for RowOperations::FullScanColumn"); - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -using ValidityBytes = 
TemplatedValidityMask; - -template -static void TemplatedHeapGather(Vector &v, const idx_t count, const SelectionVector &sel, data_ptr_t *key_locations) { - auto target = FlatVector::GetData(v); - - for (idx_t i = 0; i < count; ++i) { - const auto col_idx = sel.get_index(i); - target[col_idx] = Load(key_locations[i]); - key_locations[i] += sizeof(T); - } -} - -static void HeapGatherStringVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - const auto &validity = FlatVector::Validity(v); - auto target = FlatVector::GetData(v); - - for (idx_t i = 0; i < vcount; i++) { - const auto col_idx = sel.get_index(i); - if (!validity.RowIsValid(col_idx)) { - continue; - } - auto len = Load(key_locations[i]); - key_locations[i] += sizeof(uint32_t); - target[col_idx] = StringVector::AddStringOrBlob(v, string_t(const_char_ptr_cast(key_locations[i]), len)); - key_locations[i] += len; - } -} - -static void HeapGatherStructVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - // struct must have a validitymask for its fields - auto &child_types = StructType::GetChildTypes(v.GetType()); - const idx_t struct_validitymask_size = (child_types.size() + 7) / 8; - data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < vcount; i++) { - // use key_locations as the validitymask, and create struct_key_locations - struct_validitymask_locations[i] = key_locations[i]; - key_locations[i] += struct_validitymask_size; - } - - // now deserialize into the struct vectors - auto &children = StructVector::GetEntries(v); - for (idx_t i = 0; i < child_types.size(); i++) { - RowOperations::HeapGather(*children[i], vcount, sel, i, key_locations, struct_validitymask_locations); - } -} - -static void HeapGatherListVector(Vector &v, const idx_t vcount, const SelectionVector &sel, data_ptr_t *key_locations) { - const auto &validity = FlatVector::Validity(v); - - auto child_type = 
ListType::GetChildType(v.GetType()); - auto list_data = ListVector::GetData(v); - data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; - - uint64_t entry_offset = ListVector::GetListSize(v); - for (idx_t i = 0; i < vcount; i++) { - const auto col_idx = sel.get_index(i); - if (!validity.RowIsValid(col_idx)) { - continue; - } - // read list length - auto entry_remaining = Load(key_locations[i]); - key_locations[i] += sizeof(uint64_t); - // set list entry attributes - list_data[col_idx].length = entry_remaining; - list_data[col_idx].offset = entry_offset; - // skip over the validity mask - data_ptr_t validitymask_location = key_locations[i]; - idx_t offset_in_byte = 0; - key_locations[i] += (entry_remaining + 7) / 8; - // entry sizes - data_ptr_t var_entry_size_ptr = nullptr; - if (!TypeIsConstantSize(child_type.InternalType())) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += entry_remaining * sizeof(idx_t); - } - - // now read the list data - while (entry_remaining > 0) { - auto next = MinValue(entry_remaining, (idx_t)STANDARD_VECTOR_SIZE); - - // initialize a new vector to append - Vector append_vector(v.GetType()); - append_vector.SetVectorType(v.GetVectorType()); - - auto &list_vec_to_append = ListVector::GetEntry(append_vector); - - // set validity - //! Since we are constructing the vector, this will always be a flat vector. 
- auto &append_validity = FlatVector::Validity(list_vec_to_append); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - append_validity.Set(entry_idx, *(validitymask_location) & (1 << offset_in_byte)); - if (++offset_in_byte == 8) { - validitymask_location++; - offset_in_byte = 0; - } - } - - // compute entry sizes and set locations where the list entries are - if (TypeIsConstantSize(child_type.InternalType())) { - // constant size list entries - const idx_t type_size = GetTypeIdSize(child_type.InternalType()); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += type_size; - } - } else { - // variable size list entries - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += Load(var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } - - // now deserialize and add to listvector - RowOperations::HeapGather(list_vec_to_append, next, *FlatVector::IncrementalSelectionVector(), 0, - list_entry_locations, nullptr); - ListVector::Append(v, list_vec_to_append, next); - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } -} - -void RowOperations::HeapGather(Vector &v, const idx_t &vcount, const SelectionVector &sel, const idx_t &col_no, - data_ptr_t *key_locations, data_ptr_t *validitymask_locations) { - v.SetVectorType(VectorType::FLAT_VECTOR); - - auto &validity = FlatVector::Validity(v); - if (validitymask_locations) { - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - for (idx_t i = 0; i < vcount; i++) { - ValidityBytes row_mask(validitymask_locations[i]); - const auto valid = row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry); - const auto col_idx = sel.get_index(i); - validity.Set(col_idx, valid); - } - } - - auto type = 
v.GetType().InternalType(); - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT16: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT32: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT64: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT8: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT16: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT32: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT64: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT128: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::FLOAT: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::DOUBLE: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INTERVAL: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::VARCHAR: - HeapGatherStringVector(v, vcount, sel, key_locations); - break; - case PhysicalType::STRUCT: - HeapGatherStructVector(v, vcount, sel, key_locations); - break; - case PhysicalType::LIST: - HeapGatherListVector(v, vcount, sel, key_locations); - break; - default: - throw NotImplementedException("Unimplemented deserialize from row-format"); - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -using ValidityBytes = TemplatedValidityMask; - -static void ComputeStringEntrySizes(UnifiedVectorFormat &vdata, idx_t entry_sizes[], const idx_t ser_count, - const SelectionVector &sel, const idx_t offset) { - auto strings = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto str_idx = vdata.sel->get_index(idx + offset); - if 
(vdata.validity.RowIsValid(str_idx)) { - entry_sizes[i] += sizeof(uint32_t) + strings[str_idx].GetSize(); - } - } -} - -static void ComputeStructEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - // obtain child vectors - idx_t num_children; - auto &children = StructVector::GetEntries(v); - num_children = children.size(); - // add struct validitymask size - const idx_t struct_validitymask_size = (num_children + 7) / 8; - for (idx_t i = 0; i < ser_count; i++) { - entry_sizes[i] += struct_validitymask_size; - } - // compute size of child vectors - for (auto &struct_vector : children) { - RowOperations::ComputeEntrySizes(*struct_vector, entry_sizes, vcount, ser_count, sel, offset); - } -} - -static void ComputeListEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto list_entry = list_data[source_idx]; - - // make room for list length, list validitymask - entry_sizes[i] += sizeof(list_entry.length); - entry_sizes[i] += (list_entry.length + 7) / 8; - - // serialize size of each entry (if non-constant size) - if (!TypeIsConstantSize(ListType::GetChildType(v.GetType()).InternalType())) { - entry_sizes[i] += list_entry.length * sizeof(list_entry.length); - } - - // compute size of each the elements in list_entry and sum them - auto entry_remaining = list_entry.length; - auto entry_offset = list_entry.offset; - while (entry_remaining > 0) { - // the list entry can span multiple vectors - auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); - - // compute and add to the total - 
std::fill_n(list_entry_sizes, next, 0); - RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, - *FlatVector::IncrementalSelectionVector(), entry_offset); - for (idx_t list_idx = 0; list_idx < next; list_idx++) { - entry_sizes[i] += list_entry_sizes[list_idx]; - } - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } - } -} - -void RowOperations::ComputeEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t vcount, - idx_t ser_count, const SelectionVector &sel, idx_t offset) { - const auto physical_type = v.GetType().InternalType(); - if (TypeIsConstantSize(physical_type)) { - const auto type_size = GetTypeIdSize(physical_type); - for (idx_t i = 0; i < ser_count; i++) { - entry_sizes[i] += type_size; - } - } else { - switch (physical_type) { - case PhysicalType::VARCHAR: - ComputeStringEntrySizes(vdata, entry_sizes, ser_count, sel, offset); - break; - case PhysicalType::STRUCT: - ComputeStructEntrySizes(v, entry_sizes, vcount, ser_count, sel, offset); - break; - case PhysicalType::LIST: - ComputeListEntrySizes(v, vdata, entry_sizes, ser_count, sel, offset); - break; - default: - // LCOV_EXCL_START - throw NotImplementedException("Column with variable size type %s cannot be serialized to row-format", - v.GetType().ToString()); - // LCOV_EXCL_STOP - } - } -} - -void RowOperations::ComputeEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - ComputeEntrySizes(v, vdata, entry_sizes, vcount, ser_count, sel, offset); -} - -template -static void TemplatedHeapScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t count, idx_t col_idx, - data_ptr_t *key_locations, data_ptr_t *validitymask_locations, idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (!validitymask_locations) { - for (idx_t i = 0; i < count; i++) { - auto 
idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - - auto target = (T *)key_locations[i]; - Store(source[source_idx], data_ptr_cast(target)); - key_locations[i] += sizeof(T); - } - } else { - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - const auto bit = ~(1UL << idx_in_entry); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - - auto target = (T *)key_locations[i]; - Store(source[source_idx], data_ptr_cast(target)); - key_locations[i] += sizeof(T); - - // set the validitymask - if (!vdata.validity.RowIsValid(source_idx)) { - *(validitymask_locations[i] + entry_idx) &= bit; - } - } - } -} - -static void HeapScatterStringVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, idx_t col_idx, - data_ptr_t *key_locations, data_ptr_t *validitymask_locations, idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto strings = UnifiedVectorFormat::GetData(vdata); - if (!validitymask_locations) { - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto &string_entry = strings[source_idx]; - // store string size - Store(string_entry.GetSize(), key_locations[i]); - key_locations[i] += sizeof(uint32_t); - // store the string - memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); - key_locations[i] += string_entry.GetSize(); - } - } - } else { - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - const auto bit = ~(1UL << idx_in_entry); - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto &string_entry = strings[source_idx]; - // 
store string size - Store(string_entry.GetSize(), key_locations[i]); - key_locations[i] += sizeof(uint32_t); - // store the string - memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); - key_locations[i] += string_entry.GetSize(); - } else { - // set the validitymask - *(validitymask_locations[i] + entry_idx) &= bit; - } - } - } -} - -static void HeapScatterStructVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, idx_t col_idx, - data_ptr_t *key_locations, data_ptr_t *validitymask_locations, idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto &children = StructVector::GetEntries(v); - idx_t num_children = children.size(); - - // the whole struct itself can be NULL - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - const auto bit = ~(1UL << idx_in_entry); - - // struct must have a validitymask for its fields - const idx_t struct_validitymask_size = (num_children + 7) / 8; - data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < ser_count; i++) { - // initialize the struct validity mask - struct_validitymask_locations[i] = key_locations[i]; - memset(struct_validitymask_locations[i], -1, struct_validitymask_size); - key_locations[i] += struct_validitymask_size; - - // set whether the whole struct is null - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - if (validitymask_locations && !vdata.validity.RowIsValid(source_idx)) { - *(validitymask_locations[i] + entry_idx) &= bit; - } - } - - // now serialize the struct vectors - for (idx_t i = 0; i < children.size(); i++) { - auto &struct_vector = *children[i]; - RowOperations::HeapScatter(struct_vector, vcount, sel, ser_count, i, key_locations, - struct_validitymask_locations, offset); - } -} - -static void HeapScatterListVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, idx_t col_no, 
- data_ptr_t *key_locations, data_ptr_t *validitymask_locations, idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto list_data = ListVector::GetData(v); - - auto &child_vector = ListVector::GetEntry(v); - - UnifiedVectorFormat list_vdata; - child_vector.ToUnifiedFormat(ListVector::GetListSize(v), list_vdata); - auto child_type = ListType::GetChildType(v.GetType()).InternalType(); - - idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; - data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; - - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (!vdata.validity.RowIsValid(source_idx)) { - if (validitymask_locations) { - // set the row validitymask for this column to invalid - ValidityBytes row_mask(validitymask_locations[i]); - row_mask.SetInvalidUnsafe(entry_idx, idx_in_entry); - } - continue; - } - auto list_entry = list_data[source_idx]; - - // store list length - Store(list_entry.length, key_locations[i]); - key_locations[i] += sizeof(list_entry.length); - - // make room for the validitymask - data_ptr_t list_validitymask_location = key_locations[i]; - idx_t entry_offset_in_byte = 0; - idx_t validitymask_size = (list_entry.length + 7) / 8; - memset(list_validitymask_location, -1, validitymask_size); - key_locations[i] += validitymask_size; - - // serialize size of each entry (if non-constant size) - data_ptr_t var_entry_size_ptr = nullptr; - if (!TypeIsConstantSize(child_type)) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += list_entry.length * sizeof(idx_t); - } - - auto entry_remaining = list_entry.length; - auto entry_offset = list_entry.offset; - while (entry_remaining > 0) { - // the list entry can span multiple vectors - auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); - - // serialize list validity - 
for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - auto list_idx = list_vdata.sel->get_index(entry_idx + entry_offset); - if (!list_vdata.validity.RowIsValid(list_idx)) { - *(list_validitymask_location) &= ~(1UL << entry_offset_in_byte); - } - if (++entry_offset_in_byte == 8) { - list_validitymask_location++; - entry_offset_in_byte = 0; - } - } - - if (TypeIsConstantSize(child_type)) { - // constant size list entries: set list entry locations - const idx_t type_size = GetTypeIdSize(child_type); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += type_size; - } - } else { - // variable size list entries: compute entry sizes and set list entry locations - std::fill_n(list_entry_sizes, next, 0); - RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, - *FlatVector::IncrementalSelectionVector(), entry_offset); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += list_entry_sizes[entry_idx]; - Store(list_entry_sizes[entry_idx], var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } - - // now serialize to the locations - RowOperations::HeapScatter(child_vector, ListVector::GetListSize(v), - *FlatVector::IncrementalSelectionVector(), next, 0, list_entry_locations, - nullptr, entry_offset); - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } -} - -void RowOperations::HeapScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, idx_t col_idx, - data_ptr_t *key_locations, data_ptr_t *validitymask_locations, idx_t offset) { - if (TypeIsConstantSize(v.GetType().InternalType())) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - RowOperations::HeapScatterVData(vdata, v.GetType().InternalType(), sel, ser_count, col_idx, key_locations, - validitymask_locations, offset); - } else { - switch 
(v.GetType().InternalType()) { - case PhysicalType::VARCHAR: - HeapScatterStringVector(v, vcount, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - case PhysicalType::STRUCT: - HeapScatterStructVector(v, vcount, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - case PhysicalType::LIST: - HeapScatterListVector(v, vcount, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - default: - // LCOV_EXCL_START - throw NotImplementedException("Serialization of variable length vector with type %s", - v.GetType().ToString()); - // LCOV_EXCL_STOP - } - } -} - -void RowOperations::HeapScatterVData(UnifiedVectorFormat &vdata, PhysicalType type, const SelectionVector &sel, - idx_t ser_count, idx_t col_idx, data_ptr_t *key_locations, - data_ptr_t *validitymask_locations, idx_t offset) { - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - case PhysicalType::INT16: - TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - case PhysicalType::INT32: - TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - case PhysicalType::INT64: - TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - case PhysicalType::UINT8: - TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - case PhysicalType::UINT16: - TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - case PhysicalType::UINT32: - TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - case PhysicalType::UINT64: - TemplatedHeapScatter(vdata, sel, ser_count, col_idx, 
key_locations, validitymask_locations, offset); - break; - case PhysicalType::INT128: - TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - case PhysicalType::FLOAT: - TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - case PhysicalType::DOUBLE: - TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - case PhysicalType::INTERVAL: - TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); - break; - default: - throw NotImplementedException("FIXME: Serialize to of constant type column to row-format"); - } -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -using ValidityBytes = TupleDataLayout::ValidityBytes; - -template -static idx_t TemplatedMatch(Vector &, const TupleDataVectorFormat &lhs_format, SelectionVector &sel, const idx_t count, - const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, const idx_t col_idx, - const vector &, SelectionVector *no_match_sel, idx_t &no_match_count) { - using COMPARISON_OP = ComparisonOperationWrapper; - - // LHS - const auto &lhs_sel = *lhs_format.unified.sel; - const auto lhs_data = UnifiedVectorFormat::GetData(lhs_format.unified); - const auto &lhs_validity = lhs_format.unified.validity; - - // RHS - const auto rhs_locations = FlatVector::GetData(rhs_row_locations); - const auto rhs_offset_in_row = rhs_layout.GetOffsets()[col_idx]; - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - idx_t match_count = 0; - for (idx_t i = 0; i < count; i++) { - const auto idx = sel.get_index(i); - - const auto lhs_idx = lhs_sel.get_index(idx); - const auto lhs_null = lhs_validity.AllValid() ? 
false : !lhs_validity.RowIsValid(lhs_idx); - - const auto &rhs_location = rhs_locations[idx]; - const ValidityBytes rhs_mask(rhs_location); - const auto rhs_null = !rhs_mask.RowIsValid(rhs_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry); - - if (COMPARISON_OP::template Operation(lhs_data[lhs_idx], Load(rhs_location + rhs_offset_in_row), lhs_null, - rhs_null)) { - sel.set_index(match_count++, idx); - } else if (NO_MATCH_SEL) { - no_match_sel->set_index(no_match_count++, idx); - } - } - return match_count; -} - -template -static idx_t StructMatchEquality(Vector &lhs_vector, const TupleDataVectorFormat &lhs_format, SelectionVector &sel, - const idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, - const idx_t col_idx, const vector &child_functions, - SelectionVector *no_match_sel, idx_t &no_match_count) { - using COMPARISON_OP = ComparisonOperationWrapper; - - // LHS - const auto &lhs_sel = *lhs_format.unified.sel; - const auto &lhs_validity = lhs_format.unified.validity; - - // RHS - const auto rhs_locations = FlatVector::GetData(rhs_row_locations); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - idx_t match_count = 0; - for (idx_t i = 0; i < count; i++) { - const auto idx = sel.get_index(i); - - const auto lhs_idx = lhs_sel.get_index(idx); - const auto lhs_null = lhs_validity.AllValid() ? 
false : !lhs_validity.RowIsValid(lhs_idx); - - const auto &rhs_location = rhs_locations[idx]; - const ValidityBytes rhs_mask(rhs_location); - const auto rhs_null = !rhs_mask.RowIsValid(rhs_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry); - - // For structs there is no value to compare, here we match NULLs and let recursion do the rest - // So we use the comparison only if rhs or LHS is NULL and COMPARE_NULL is true - if (!(lhs_null || rhs_null) || - (COMPARISON_OP::COMPARE_NULL && COMPARISON_OP::template Operation(0, 0, lhs_null, rhs_null))) { - sel.set_index(match_count++, idx); - } else if (NO_MATCH_SEL) { - no_match_sel->set_index(no_match_count++, idx); - } - } - - // Create a Vector of pointers to the start of the TupleDataLayout of the STRUCT - Vector rhs_struct_row_locations(LogicalType::POINTER); - const auto rhs_offset_in_row = rhs_layout.GetOffsets()[col_idx]; - auto rhs_struct_locations = FlatVector::GetData(rhs_struct_row_locations); - for (idx_t i = 0; i < match_count; i++) { - const auto idx = sel.get_index(i); - rhs_struct_locations[idx] = rhs_locations[idx] + rhs_offset_in_row; - } - - // Get the struct layout and struct entries - const auto &rhs_struct_layout = rhs_layout.GetStructLayout(col_idx); - auto &lhs_struct_vectors = StructVector::GetEntries(lhs_vector); - D_ASSERT(rhs_struct_layout.ColumnCount() == lhs_struct_vectors.size()); - - for (idx_t struct_col_idx = 0; struct_col_idx < rhs_struct_layout.ColumnCount(); struct_col_idx++) { - auto &lhs_struct_vector = *lhs_struct_vectors[struct_col_idx]; - auto &lhs_struct_format = lhs_format.children[struct_col_idx]; - const auto &child_function = child_functions[struct_col_idx]; - match_count = child_function.function(lhs_struct_vector, lhs_struct_format, sel, match_count, rhs_struct_layout, - rhs_struct_row_locations, struct_col_idx, child_function.child_functions, - no_match_sel, no_match_count); - } - - return match_count; -} - -template -static idx_t SelectComparison(Vector &, Vector &, 
const SelectionVector &, idx_t, SelectionVector *, - SelectionVector *) { - throw NotImplementedException("Unsupported list comparison operand for RowMatcher::GetMatchFunction"); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::NestedEquals(left, right, sel, count, true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::NestedNotEquals(left, right, sel, count, true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::DistinctFrom(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::NotDistinctFrom(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::DistinctGreaterThanEquals(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::DistinctLessThan(left, right, &sel, count, 
true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::DistinctLessThanEquals(left, right, &sel, count, true_sel, false_sel); -} - -template -static idx_t GenericNestedMatch(Vector &lhs_vector, const TupleDataVectorFormat &, SelectionVector &sel, - const idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, - const idx_t col_idx, const vector &, SelectionVector *no_match_sel, - idx_t &no_match_count) { - const auto &type = rhs_layout.GetTypes()[col_idx]; - - // Gather a dense Vector containing the column values being matched - Vector key(type); - const auto gather_function = TupleDataCollection::GetGatherFunction(type); - gather_function.function(rhs_layout, rhs_row_locations, col_idx, sel, count, key, - *FlatVector::IncrementalSelectionVector(), key, gather_function.child_functions); - - // Densify the input column - Vector sliced(lhs_vector, sel, count); - - if (NO_MATCH_SEL) { - SelectionVector no_match_sel_offset(no_match_sel->data() + no_match_count); - auto match_count = SelectComparison(sliced, key, sel, count, &sel, &no_match_sel_offset); - no_match_count += count - match_count; - return match_count; - } - return SelectComparison(sliced, key, sel, count, &sel, nullptr); -} - -void RowMatcher::Initialize(const bool no_match_sel, const TupleDataLayout &layout, const Predicates &predicates) { - match_functions.reserve(predicates.size()); - for (idx_t col_idx = 0; col_idx < predicates.size(); col_idx++) { - match_functions.push_back(GetMatchFunction(no_match_sel, layout.GetTypes()[col_idx], predicates[col_idx])); - } -} - -idx_t RowMatcher::Match(DataChunk &lhs, const vector &lhs_formats, SelectionVector &sel, - idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, - SelectionVector *no_match_sel, idx_t &no_match_count) { - 
D_ASSERT(!match_functions.empty()); - for (idx_t col_idx = 0; col_idx < match_functions.size(); col_idx++) { - const auto &match_function = match_functions[col_idx]; - count = - match_function.function(lhs.data[col_idx], lhs_formats[col_idx], sel, count, rhs_layout, rhs_row_locations, - col_idx, match_function.child_functions, no_match_sel, no_match_count); - } - return count; -} - -MatchFunction RowMatcher::GetMatchFunction(const bool no_match_sel, const LogicalType &type, - const ExpressionType predicate) { - return no_match_sel ? GetMatchFunction(type, predicate) : GetMatchFunction(type, predicate); -} - -template -MatchFunction RowMatcher::GetMatchFunction(const LogicalType &type, const ExpressionType predicate) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - return GetMatchFunction(predicate); - case PhysicalType::INT8: - return GetMatchFunction(predicate); - case PhysicalType::INT16: - return GetMatchFunction(predicate); - case PhysicalType::INT32: - return GetMatchFunction(predicate); - case PhysicalType::INT64: - return GetMatchFunction(predicate); - case PhysicalType::INT128: - return GetMatchFunction(predicate); - case PhysicalType::UINT8: - return GetMatchFunction(predicate); - case PhysicalType::UINT16: - return GetMatchFunction(predicate); - case PhysicalType::UINT32: - return GetMatchFunction(predicate); - case PhysicalType::UINT64: - return GetMatchFunction(predicate); - case PhysicalType::FLOAT: - return GetMatchFunction(predicate); - case PhysicalType::DOUBLE: - return GetMatchFunction(predicate); - case PhysicalType::INTERVAL: - return GetMatchFunction(predicate); - case PhysicalType::VARCHAR: - return GetMatchFunction(predicate); - case PhysicalType::STRUCT: - return GetStructMatchFunction(type, predicate); - case PhysicalType::LIST: - return GetListMatchFunction(predicate); - default: - throw InternalException("Unsupported PhysicalType for RowMatcher::GetMatchFunction: %s", - EnumUtil::ToString(type.InternalType())); - } -} - 
-template -MatchFunction RowMatcher::GetMatchFunction(const ExpressionType predicate) { - MatchFunction result; - switch (predicate) { - case ExpressionType::COMPARE_EQUAL: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_NOTEQUAL: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_DISTINCT_FROM: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_GREATERTHAN: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_LESSTHAN: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - result.function = TemplatedMatch; - break; - default: - throw InternalException("Unsupported ExpressionType for RowMatcher::GetMatchFunction: %s", - EnumUtil::ToString(predicate)); - } - return result; -} - -template -MatchFunction RowMatcher::GetStructMatchFunction(const LogicalType &type, const ExpressionType predicate) { - // We perform equality conditions like it's just a row, but we cannot perform inequality conditions like a row, - // because for equality conditions we need to always loop through all columns, but for inequality conditions, - // we need to find the first inequality, so the loop looks very different - MatchFunction result; - ExpressionType child_predicate = predicate; - switch (predicate) { - case ExpressionType::COMPARE_EQUAL: - result.function = StructMatchEquality; - child_predicate = ExpressionType::COMPARE_NOT_DISTINCT_FROM; - break; - case ExpressionType::COMPARE_NOTEQUAL: - result.function = GenericNestedMatch; - return result; - case ExpressionType::COMPARE_DISTINCT_FROM: - result.function = GenericNestedMatch; - return result; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - result.function = StructMatchEquality; - break; - case 
ExpressionType::COMPARE_GREATERTHAN: - result.function = GenericNestedMatch; - return result; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - result.function = GenericNestedMatch; - return result; - case ExpressionType::COMPARE_LESSTHAN: - result.function = GenericNestedMatch; - return result; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - result.function = GenericNestedMatch; - return result; - default: - throw InternalException("Unsupported ExpressionType for RowMatcher::GetStructMatchFunction: %s", - EnumUtil::ToString(predicate)); - } - - result.child_functions.reserve(StructType::GetChildCount(type)); - for (const auto &child_type : StructType::GetChildTypes(type)) { - result.child_functions.push_back(GetMatchFunction(child_type.second, child_predicate)); - } - - return result; -} - -template -MatchFunction RowMatcher::GetListMatchFunction(const ExpressionType predicate) { - MatchFunction result; - switch (predicate) { - case ExpressionType::COMPARE_EQUAL: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_NOTEQUAL: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_DISTINCT_FROM: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_GREATERTHAN: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_LESSTHAN: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - result.function = GenericNestedMatch; - break; - default: - throw InternalException("Unsupported ExpressionType for RowMatcher::GetListMatchFunction: %s", - EnumUtil::ToString(predicate)); - } - return result; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -template -void TemplatedRadixScatter(UnifiedVectorFormat &vdata, const 
SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - Radix::EncodeData(key_locations[i] + 1, source[source_idx]); - // invert bits if desc - if (desc) { - for (idx_t s = 1; s < sizeof(T) + 1; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', sizeof(T)); - } - key_locations[i] += sizeof(T) + 1; - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write value - Radix::EncodeData(key_locations[i], source[source_idx]); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < sizeof(T); s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - key_locations[i] += sizeof(T); - } - } -} - -void RadixScatterStringVector(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t prefix_len, idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - Radix::EncodeStringDataPrefix(key_locations[i] + 1, source[source_idx], prefix_len); - // invert bits if desc - if (desc) { - for (idx_t s = 1; s < prefix_len + 1; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', prefix_len); - } - key_locations[i] += prefix_len + 1; - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write value - Radix::EncodeStringDataPrefix(key_locations[i], source[source_idx], prefix_len); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < prefix_len; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - key_locations[i] += prefix_len; - } - } -} - -void RadixScatterListVector(Vector &v, UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t prefix_len, const idx_t width, const idx_t offset) { - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - auto list_size = ListVector::GetListSize(v); - child_vector.Flatten(list_size); - - // serialize null values - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t key_location = key_locations[i] + 1; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - key_locations[i]++; - auto &list_entry = list_data[source_idx]; - if (list_entry.length > 0) { - // denote that the list is not empty with a 1 - key_locations[i][0] = 1; - key_locations[i]++; - RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, - list_entry.offset); - } else { - // denote that the list is empty with a 0 - key_locations[i][0] = 0; - key_locations[i]++; - memset(key_locations[i], '\0', width - 2); - } - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < width - 1; s++) { - *(key_location + s) = ~*(key_location + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', width - 1); - key_locations[i] += width; - } - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - auto &list_entry = list_data[source_idx]; - data_ptr_t key_location = key_locations[i]; - if (list_entry.length > 0) { - // denote that the list is not empty with a 1 - key_locations[i][0] = 1; - key_locations[i]++; - RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, - list_entry.offset); - } else { - // denote that the list is empty with a 0 - key_locations[i][0] = 0; - key_locations[i]++; - memset(key_locations[i], '\0', width - 1); - } - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < width; s++) { - *(key_location + s) = ~*(key_location + s); - } - } - } - } -} - -void 
RadixScatterStructVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, - idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, - const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { - // serialize null values - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - } else { - key_locations[i][0] = invalid; - } - key_locations[i]++; - } - width--; - } - // serialize the struct - auto &child_vector = *StructVector::GetEntries(v)[0]; - RowOperations::RadixScatter(child_vector, vcount, *FlatVector::IncrementalSelectionVector(), add_count, - key_locations, false, true, false, prefix_len, width, offset); - // invert bits if desc - if (desc) { - for (idx_t i = 0; i < add_count; i++) { - for (idx_t s = 0; s < width; s++) { - *(key_locations[i] - width + s) = ~*(key_locations[i] - width + s); - } - } - } -} - -void RowOperations::RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, bool desc, bool has_null, bool nulls_first, - idx_t prefix_len, idx_t width, idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - switch (v.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT16: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT32: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - 
case PhysicalType::INT64: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT8: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT16: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT32: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT64: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT128: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::FLOAT: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::DOUBLE: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INTERVAL: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::VARCHAR: - RadixScatterStringVector(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, offset); - break; - case PhysicalType::LIST: - RadixScatterListVector(v, vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, width, - offset); - break; - case PhysicalType::STRUCT: - RadixScatterStructVector(v, vdata, vcount, sel, ser_count, key_locations, desc, has_null, nulls_first, - prefix_len, width, offset); - break; - default: - throw NotImplementedException("Cannot ORDER BY column with type %s", v.GetType().ToString()); - } -} - -} // namespace duckdb -//===--------------------------------------------------------------------===// -// row_scatter.cpp -// Description: This file 
contains the implementation of the row scattering -// operators -//===--------------------------------------------------------------------===// - - - - - - - - - - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -template -static void TemplatedScatter(UnifiedVectorFormat &col, Vector &rows, const SelectionVector &sel, const idx_t count, - const idx_t col_offset, const idx_t col_no) { - auto data = UnifiedVectorFormat::GetData(col); - auto ptrs = FlatVector::GetData(rows); - - if (!col.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - - auto isnull = !col.validity.RowIsValid(col_idx); - T store_value = isnull ? NullValue() : data[col_idx]; - Store(store_value, row + col_offset); - if (isnull) { - ValidityBytes col_mask(ptrs[idx]); - col_mask.SetInvalidUnsafe(col_no); - } - } - } else { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - - Store(data[col_idx], row + col_offset); - } - } -} - -static void ComputeStringEntrySizes(const UnifiedVectorFormat &col, idx_t entry_sizes[], const SelectionVector &sel, - const idx_t count, const idx_t offset = 0) { - auto data = UnifiedVectorFormat::GetData(col); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx) + offset; - const auto &str = data[col_idx]; - if (col.validity.RowIsValid(col_idx) && !str.IsInlined()) { - entry_sizes[i] += str.GetSize(); - } - } -} - -static void ScatterStringVector(UnifiedVectorFormat &col, Vector &rows, data_ptr_t str_locations[], - const SelectionVector &sel, const idx_t count, const idx_t col_offset, - const idx_t col_no) { - auto string_data = UnifiedVectorFormat::GetData(col); - auto ptrs = FlatVector::GetData(rows); - - // Write out zero length to avoid swizzling problems. 
- const string_t null(nullptr, 0); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - if (!col.validity.RowIsValid(col_idx)) { - ValidityBytes col_mask(row); - col_mask.SetInvalidUnsafe(col_no); - Store(null, row + col_offset); - } else if (string_data[col_idx].IsInlined()) { - Store(string_data[col_idx], row + col_offset); - } else { - const auto &str = string_data[col_idx]; - string_t inserted(const_char_ptr_cast(str_locations[i]), str.GetSize()); - memcpy(inserted.GetDataWriteable(), str.GetData(), str.GetSize()); - str_locations[i] += str.GetSize(); - inserted.Finalize(); - Store(inserted, row + col_offset); - } - } -} - -static void ScatterNestedVector(Vector &vec, UnifiedVectorFormat &col, Vector &rows, data_ptr_t data_locations[], - const SelectionVector &sel, const idx_t count, const idx_t col_offset, - const idx_t col_no, const idx_t vcount) { - // Store pointers to the data in the row - // Do this first because SerializeVector destroys the locations - auto ptrs = FlatVector::GetData(rows); - data_ptr_t validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto row = ptrs[idx]; - validitymask_locations[i] = row; - - Store(data_locations[i], row + col_offset); - } - - // Serialise the data - RowOperations::HeapScatter(vec, vcount, sel, count, col_no, data_locations, validitymask_locations); -} - -void RowOperations::Scatter(DataChunk &columns, UnifiedVectorFormat col_data[], const RowLayout &layout, Vector &rows, - RowDataCollection &string_heap, const SelectionVector &sel, idx_t count) { - if (count == 0) { - return; - } - - // Set the validity mask for each row before inserting data - auto ptrs = FlatVector::GetData(rows); - for (idx_t i = 0; i < count; ++i) { - auto row_idx = sel.get_index(i); - auto row = ptrs[row_idx]; - ValidityBytes(row).SetAllValid(layout.ColumnCount()); - } - - const auto vcount = 
columns.size(); - auto &offsets = layout.GetOffsets(); - auto &types = layout.GetTypes(); - - // Compute the entry size of the variable size columns - vector handles; - data_ptr_t data_locations[STANDARD_VECTOR_SIZE]; - if (!layout.AllConstant()) { - idx_t entry_sizes[STANDARD_VECTOR_SIZE]; - std::fill_n(entry_sizes, count, sizeof(uint32_t)); - for (idx_t col_no = 0; col_no < types.size(); col_no++) { - if (TypeIsConstantSize(types[col_no].InternalType())) { - continue; - } - - auto &vec = columns.data[col_no]; - auto &col = col_data[col_no]; - switch (types[col_no].InternalType()) { - case PhysicalType::VARCHAR: - ComputeStringEntrySizes(col, entry_sizes, sel, count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - RowOperations::ComputeEntrySizes(vec, col, entry_sizes, vcount, count, sel); - break; - default: - throw InternalException("Unsupported type for RowOperations::Scatter"); - } - } - - // Build out the buffer space - handles = string_heap.Build(count, data_locations, entry_sizes); - - // Serialize information that is needed for swizzling if the computation goes out-of-core - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - auto row_idx = sel.get_index(i); - auto row = ptrs[row_idx]; - // Pointer to this row in the heap block - Store(data_locations[i], row + heap_pointer_offset); - // Row size is stored in the heap in front of each row - Store(entry_sizes[i], data_locations[i]); - data_locations[i] += sizeof(uint32_t); - } - } - - for (idx_t col_no = 0; col_no < types.size(); col_no++) { - auto &vec = columns.data[col_no]; - auto &col = col_data[col_no]; - auto col_offset = offsets[col_no]; - - switch (types[col_no].InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedScatter(col, rows, sel, count, col_offset, col_no); - break; - case PhysicalType::INT16: - TemplatedScatter(col, rows, sel, count, col_offset, col_no); - break; - case PhysicalType::INT32: - 
TemplatedScatter(col, rows, sel, count, col_offset, col_no); - break; - case PhysicalType::INT64: - TemplatedScatter(col, rows, sel, count, col_offset, col_no); - break; - case PhysicalType::UINT8: - TemplatedScatter(col, rows, sel, count, col_offset, col_no); - break; - case PhysicalType::UINT16: - TemplatedScatter(col, rows, sel, count, col_offset, col_no); - break; - case PhysicalType::UINT32: - TemplatedScatter(col, rows, sel, count, col_offset, col_no); - break; - case PhysicalType::UINT64: - TemplatedScatter(col, rows, sel, count, col_offset, col_no); - break; - case PhysicalType::INT128: - TemplatedScatter(col, rows, sel, count, col_offset, col_no); - break; - case PhysicalType::FLOAT: - TemplatedScatter(col, rows, sel, count, col_offset, col_no); - break; - case PhysicalType::DOUBLE: - TemplatedScatter(col, rows, sel, count, col_offset, col_no); - break; - case PhysicalType::INTERVAL: - TemplatedScatter(col, rows, sel, count, col_offset, col_no); - break; - case PhysicalType::VARCHAR: - ScatterStringVector(col, rows, data_locations, sel, count, col_offset, col_no); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - ScatterNestedVector(vec, col, rows, data_locations, sel, count, col_offset, col_no, vcount); - break; - default: - throw InternalException("Unsupported type for RowOperations::Scatter"); - } - } -} - -} // namespace duckdb - - -namespace duckdb { - -//------------------------------------------------------------------------- -// Nested Type Hooks -//------------------------------------------------------------------------- -void BinaryDeserializer::OnPropertyBegin(const field_id_t field_id, const char *) { - auto field = NextField(); - if (field != field_id) { - throw InternalException("Failed to deserialize: field id mismatch, expected: %d, got: %d", field_id, field); - } -} - -void BinaryDeserializer::OnPropertyEnd() { -} - -bool BinaryDeserializer::OnOptionalPropertyBegin(const field_id_t field_id, const char *s) { - auto 
next_field = PeekField(); - auto present = next_field == field_id; - if (present) { - ConsumeField(); - } - return present; -} - -void BinaryDeserializer::OnOptionalPropertyEnd(bool present) { -} - -void BinaryDeserializer::OnObjectBegin() { - nesting_level++; -} - -void BinaryDeserializer::OnObjectEnd() { - auto next_field = NextField(); - if (next_field != MESSAGE_TERMINATOR_FIELD_ID) { - throw InternalException("Failed to deserialize: expected end of object, but found field id: %d", next_field); - } - nesting_level--; -} - -idx_t BinaryDeserializer::OnListBegin() { - return VarIntDecode(); -} - -void BinaryDeserializer::OnListEnd() { -} - -bool BinaryDeserializer::OnNullableBegin() { - return ReadBool(); -} - -void BinaryDeserializer::OnNullableEnd() { -} - -//------------------------------------------------------------------------- -// Primitive Types -//------------------------------------------------------------------------- -bool BinaryDeserializer::ReadBool() { - return static_cast(ReadPrimitive()); -} - -char BinaryDeserializer::ReadChar() { - return ReadPrimitive(); -} - -int8_t BinaryDeserializer::ReadSignedInt8() { - return VarIntDecode(); -} - -uint8_t BinaryDeserializer::ReadUnsignedInt8() { - return VarIntDecode(); -} - -int16_t BinaryDeserializer::ReadSignedInt16() { - return VarIntDecode(); -} - -uint16_t BinaryDeserializer::ReadUnsignedInt16() { - return VarIntDecode(); -} - -int32_t BinaryDeserializer::ReadSignedInt32() { - return VarIntDecode(); -} - -uint32_t BinaryDeserializer::ReadUnsignedInt32() { - return VarIntDecode(); -} - -int64_t BinaryDeserializer::ReadSignedInt64() { - return VarIntDecode(); -} - -uint64_t BinaryDeserializer::ReadUnsignedInt64() { - return VarIntDecode(); -} - -float BinaryDeserializer::ReadFloat() { - auto value = ReadPrimitive(); - return value; -} - -double BinaryDeserializer::ReadDouble() { - auto value = ReadPrimitive(); - return value; -} - -string BinaryDeserializer::ReadString() { - auto len = VarIntDecode(); 
- if (len == 0) { - return string(); - } - auto buffer = make_unsafe_uniq_array(len); - ReadData(buffer.get(), len); - return string(const_char_ptr_cast(buffer.get()), len); -} - -hugeint_t BinaryDeserializer::ReadHugeInt() { - auto upper = VarIntDecode(); - auto lower = VarIntDecode(); - return hugeint_t(upper, lower); -} - -void BinaryDeserializer::ReadDataPtr(data_ptr_t &ptr_p, idx_t count) { - auto len = VarIntDecode(); - if (len != count) { - throw SerializationException("Tried to read blob of %d size, but only %d elements are available", count, len); - } - ReadData(ptr_p, count); -} - -} // namespace duckdb - - -#ifdef DEBUG - -#endif - -namespace duckdb { - -void BinarySerializer::OnPropertyBegin(const field_id_t field_id, const char *tag) { - // Just write the field id straight up - Write(field_id); -#ifdef DEBUG - // First of check that we are inside an object - if (debug_stack.empty()) { - throw InternalException("OnPropertyBegin called outside of object"); - } - - // Check that the tag is unique - auto &state = debug_stack.back(); - auto &seen_field_ids = state.seen_field_ids; - auto &seen_field_tags = state.seen_field_tags; - auto &seen_fields = state.seen_fields; - - if (seen_field_ids.find(field_id) != seen_field_ids.end() || seen_field_tags.find(tag) != seen_field_tags.end()) { - string all_fields; - for (auto &field : seen_fields) { - all_fields += StringUtil::Format("\"%s\":%d ", field.first, field.second); - } - throw InternalException("Duplicate field id/tag in field: \"%s\":%d, other fields: %s", tag, field_id, - all_fields); - } - - seen_field_ids.insert(field_id); - seen_field_tags.insert(tag); - seen_fields.emplace_back(tag, field_id); -#else - (void)tag; -#endif -} - -void BinarySerializer::OnPropertyEnd() { - // Nothing to do here -} - -void BinarySerializer::OnOptionalPropertyBegin(const field_id_t field_id, const char *tag, bool present) { - // Dont write anything at all if the property is not present - if (present) { - 
OnPropertyBegin(field_id, tag); - } -} - -void BinarySerializer::OnOptionalPropertyEnd(bool present) { - // Nothing to do here -} - -//------------------------------------------------------------------------- -// Nested Type Hooks -//------------------------------------------------------------------------- -void BinarySerializer::OnObjectBegin() { -#ifdef DEBUG - debug_stack.emplace_back(); -#endif -} - -void BinarySerializer::OnObjectEnd() { -#ifdef DEBUG - debug_stack.pop_back(); -#endif - // Write object terminator - Write(MESSAGE_TERMINATOR_FIELD_ID); -} - -void BinarySerializer::OnListBegin(idx_t count) { - VarIntEncode(count); -} - -void BinarySerializer::OnListEnd() { -} - -void BinarySerializer::OnNullableBegin(bool present) { - WriteValue(present); -} - -void BinarySerializer::OnNullableEnd() { -} - -//------------------------------------------------------------------------- -// Primitive Types -//------------------------------------------------------------------------- -void BinarySerializer::WriteNull() { - // This should never be called, optional writes should be handled by OnOptionalBegin -} - -void BinarySerializer::WriteValue(bool value) { - Write(value); -} - -void BinarySerializer::WriteValue(uint8_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(char value) { - Write(value); -} - -void BinarySerializer::WriteValue(int8_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(uint16_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(int16_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(uint32_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(int32_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(uint64_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(int64_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(hugeint_t value) { - VarIntEncode(value.upper); - 
VarIntEncode(value.lower); -} - -void BinarySerializer::WriteValue(float value) { - Write(value); -} - -void BinarySerializer::WriteValue(double value) { - Write(value); -} - -void BinarySerializer::WriteValue(const string &value) { - uint32_t len = value.length(); - VarIntEncode(len); - WriteData(value.c_str(), len); -} - -void BinarySerializer::WriteValue(const string_t value) { - uint32_t len = value.GetSize(); - VarIntEncode(len); - WriteData(value.GetDataUnsafe(), len); -} - -void BinarySerializer::WriteValue(const char *value) { - uint32_t len = strlen(value); - VarIntEncode(len); - WriteData(value, len); -} - -void BinarySerializer::WriteDataPtr(const_data_ptr_t ptr, idx_t count) { - VarIntEncode(static_cast(count)); - WriteData(ptr, count); -} - -} // namespace duckdb - - - - -#include -#include - -namespace duckdb { - -BufferedFileReader::BufferedFileReader(FileSystem &fs, const char *path, FileLockType lock_type, - optional_ptr opener) - : fs(fs), data(make_unsafe_uniq_array(FILE_BUFFER_SIZE)), offset(0), read_data(0), total_read(0) { - handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ, lock_type, FileSystem::DEFAULT_COMPRESSION, opener.get()); - file_size = fs.GetFileSize(*handle); -} - -void BufferedFileReader::ReadData(data_ptr_t target_buffer, uint64_t read_size) { - // first copy anything we can from the buffer - data_ptr_t end_ptr = target_buffer + read_size; - while (true) { - idx_t to_read = MinValue(end_ptr - target_buffer, read_data - offset); - if (to_read > 0) { - memcpy(target_buffer, data.get() + offset, to_read); - offset += to_read; - target_buffer += to_read; - } - if (target_buffer < end_ptr) { - D_ASSERT(offset == read_data); - total_read += read_data; - // did not finish reading yet but exhausted buffer - // read data into buffer - offset = 0; - read_data = fs.Read(*handle, data.get(), FILE_BUFFER_SIZE); - if (read_data == 0) { - throw SerializationException("not enough data in file to deserialize result"); - } - } else { - return; 
- } - } -} - -bool BufferedFileReader::Finished() { - return total_read + offset == file_size; -} - -void BufferedFileReader::Seek(uint64_t location) { - D_ASSERT(location <= file_size); - handle->Seek(location); - total_read = location; - read_data = offset = 0; -} - -uint64_t BufferedFileReader::CurrentOffset() { - return total_read + offset; -} - -} // namespace duckdb - - - -#include - -namespace duckdb { - -// Remove this when we switch C++17: https://stackoverflow.com/a/53350948 -constexpr uint8_t BufferedFileWriter::DEFAULT_OPEN_FLAGS; - -BufferedFileWriter::BufferedFileWriter(FileSystem &fs, const string &path_p, uint8_t open_flags) - : fs(fs), path(path_p), data(make_unsafe_uniq_array(FILE_BUFFER_SIZE)), offset(0), total_written(0) { - handle = fs.OpenFile(path, open_flags, FileLockType::WRITE_LOCK); -} - -int64_t BufferedFileWriter::GetFileSize() { - return fs.GetFileSize(*handle) + offset; -} - -idx_t BufferedFileWriter::GetTotalWritten() { - return total_written + offset; -} - -void BufferedFileWriter::WriteData(const_data_ptr_t buffer, idx_t write_size) { - // first copy anything we can from the buffer - const_data_ptr_t end_ptr = buffer + write_size; - while (buffer < end_ptr) { - idx_t to_write = MinValue((end_ptr - buffer), FILE_BUFFER_SIZE - offset); - D_ASSERT(to_write > 0); - memcpy(data.get() + offset, buffer, to_write); - offset += to_write; - buffer += to_write; - if (offset == FILE_BUFFER_SIZE) { - Flush(); - } - } -} - -void BufferedFileWriter::Flush() { - if (offset == 0) { - return; - } - fs.Write(*handle, data.get(), offset); - total_written += offset; - offset = 0; -} - -void BufferedFileWriter::Sync() { - Flush(); - handle->Sync(); -} - -void BufferedFileWriter::Truncate(int64_t size) { - uint64_t persistent = fs.GetFileSize(*handle); - D_ASSERT((uint64_t)size <= persistent + offset); - if (persistent <= (uint64_t)size) { - // truncating into the pending write buffer. 
- offset = size - persistent; - } else { - // truncate the physical file on disk - handle->Truncate(size); - // reset anything written in the buffer - offset = 0; - } -} - -} // namespace duckdb - - -namespace duckdb { - -MemoryStream::MemoryStream(idx_t capacity) - : position(0), capacity(capacity), owns_data(true), data(static_cast(malloc(capacity))) { -} - -MemoryStream::MemoryStream(data_ptr_t buffer, idx_t capacity) - : position(0), capacity(capacity), owns_data(false), data(buffer) { -} - -MemoryStream::~MemoryStream() { - if (owns_data) { - free(data); - } -} - -void MemoryStream::WriteData(const_data_ptr_t source, idx_t write_size) { - while (position + write_size > capacity) { - if (owns_data) { - capacity *= 2; - data = static_cast(realloc(data, capacity)); - } else { - throw SerializationException("Failed to serialize: not enough space in buffer to fulfill write request"); - } - } - - memcpy(data + position, source, write_size); - position += write_size; -} - -void MemoryStream::ReadData(data_ptr_t destination, idx_t read_size) { - if (position + read_size > capacity) { - throw SerializationException("Failed to deserialize: not enough data in buffer to fulfill read request"); - } - memcpy(destination, data + position, read_size); - position += read_size; -} - -void MemoryStream::Rewind() { - position = 0; -} - -void MemoryStream::Release() { - owns_data = false; -} - -data_ptr_t MemoryStream::GetData() const { - return data; -} - -idx_t MemoryStream::GetPosition() const { - return position; -} - -idx_t MemoryStream::GetCapacity() const { - return capacity; -} - -} // namespace duckdb - - -namespace duckdb { - -template <> -void Serializer::WriteValue(const vector &vec) { - auto count = vec.size(); - OnListBegin(count); - for (auto item : vec) { - WriteValue(item); - } - OnListEnd(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -bool Comparators::TieIsBreakable(const idx_t &tie_col, const data_ptr_t &row_ptr, const SortLayout &sort_layout) 
{ - const auto &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - // Check if the blob is NULL - ValidityBytes row_mask(row_ptr); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - // Can't break a NULL tie - return false; - } - auto &row_layout = sort_layout.blob_layout; - if (row_layout.GetTypes()[col_idx].InternalType() != PhysicalType::VARCHAR) { - // Nested type, must be broken - return true; - } - const auto &tie_col_offset = row_layout.GetOffsets()[col_idx]; - auto tie_string = Load(row_ptr + tie_col_offset); - if (tie_string.GetSize() < sort_layout.prefix_lengths[tie_col]) { - // No need to break the tie - we already compared the full string - return false; - } - return true; -} - -int Comparators::CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr, - const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort) { - // Compare the sorting columns one by one - int comp_res = 0; - data_ptr_t l_ptr_offset = l_ptr; - data_ptr_t r_ptr_offset = r_ptr; - for (idx_t col_idx = 0; col_idx < sort_layout.column_count; col_idx++) { - comp_res = FastMemcmp(l_ptr_offset, r_ptr_offset, sort_layout.column_sizes[col_idx]); - if (comp_res == 0 && !sort_layout.constant_size[col_idx]) { - comp_res = BreakBlobTie(col_idx, left, right, sort_layout, external_sort); - } - if (comp_res != 0) { - break; - } - l_ptr_offset += sort_layout.column_sizes[col_idx]; - r_ptr_offset += sort_layout.column_sizes[col_idx]; - } - return comp_res; -} - -int Comparators::CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type) { - switch (type.InternalType()) { - case PhysicalType::VARCHAR: - return TemplatedCompareVal(l_ptr, r_ptr); - case PhysicalType::LIST: - case PhysicalType::STRUCT: { - auto l_nested_ptr = Load(l_ptr); - auto r_nested_ptr = Load(r_ptr); - return 
CompareValAndAdvance(l_nested_ptr, r_nested_ptr, type, true); - } - default: - throw NotImplementedException("Unimplemented CompareVal for type %s", type.ToString()); - } -} - -int Comparators::BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right, - const SortLayout &sort_layout, const bool &external) { - data_ptr_t l_data_ptr = left.DataPtr(*left.sb->blob_sorting_data); - data_ptr_t r_data_ptr = right.DataPtr(*right.sb->blob_sorting_data); - if (!TieIsBreakable(tie_col, l_data_ptr, sort_layout)) { - // Quick check to see if ties can be broken - return 0; - } - // Align the pointers - const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; - l_data_ptr += tie_col_offset; - r_data_ptr += tie_col_offset; - // Do the comparison - const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? -1 : 1; - const auto &type = sort_layout.blob_layout.GetTypes()[col_idx]; - int result; - if (external) { - // Store heap pointers - data_ptr_t l_heap_ptr = left.HeapPtr(*left.sb->blob_sorting_data); - data_ptr_t r_heap_ptr = right.HeapPtr(*right.sb->blob_sorting_data); - // Unswizzle offset to pointer - UnswizzleSingleValue(l_data_ptr, l_heap_ptr, type); - UnswizzleSingleValue(r_data_ptr, r_heap_ptr, type); - // Compare - result = CompareVal(l_data_ptr, r_data_ptr, type); - // Swizzle the pointers back to offsets - SwizzleSingleValue(l_data_ptr, l_heap_ptr, type); - SwizzleSingleValue(r_data_ptr, r_heap_ptr, type); - } else { - result = CompareVal(l_data_ptr, r_data_ptr, type); - } - return order * result; -} - -template -int Comparators::TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr) { - const auto left_val = Load(left_ptr); - const auto right_val = Load(right_ptr); - if (Equals::Operation(left_val, right_val)) { - return 0; - } else if (LessThan::Operation(left_val, right_val)) { - return -1; - } else { - return 
1; - } -} - -int Comparators::CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT16: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT32: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT64: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT8: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT16: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT32: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT64: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT128: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::FLOAT: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::DOUBLE: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INTERVAL: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::VARCHAR: - return CompareStringAndAdvance(l_ptr, r_ptr, valid); - case PhysicalType::LIST: - return CompareListAndAdvance(l_ptr, r_ptr, ListType::GetChildType(type), valid); - case PhysicalType::STRUCT: - return CompareStructAndAdvance(l_ptr, r_ptr, StructType::GetChildTypes(type), valid); - default: - throw NotImplementedException("Unimplemented CompareValAndAdvance for type %s", type.ToString()); - } -} - -template -int Comparators::TemplatedCompareAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr) { - auto result = TemplatedCompareVal(left_ptr, right_ptr); - left_ptr += sizeof(T); - right_ptr += sizeof(T); - return result; -} - -int Comparators::CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid) { - if (!valid) { - return 0; - } - uint32_t left_string_size = Load(left_ptr); - uint32_t 
right_string_size = Load(right_ptr); - left_ptr += sizeof(uint32_t); - right_ptr += sizeof(uint32_t); - auto memcmp_res = memcmp(const_char_ptr_cast(left_ptr), const_char_ptr_cast(right_ptr), - std::min(left_string_size, right_string_size)); - - left_ptr += left_string_size; - right_ptr += right_string_size; - - if (memcmp_res != 0) { - return memcmp_res; - } - if (left_string_size == right_string_size) { - return 0; - } - if (left_string_size < right_string_size) { - return -1; - } - return 1; -} - -int Comparators::CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const child_list_t &types, bool valid) { - idx_t count = types.size(); - // Load validity masks - ValidityBytes left_validity(left_ptr); - ValidityBytes right_validity(right_ptr); - left_ptr += (count + 7) / 8; - right_ptr += (count + 7) / 8; - // Initialize variables - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Compare - int comp_res = 0; - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - auto &type = types[i].second; - if ((left_valid == right_valid) || TypeIsConstantSize(type.InternalType())) { - comp_res = CompareValAndAdvance(left_ptr, right_ptr, types[i].second, left_valid && valid); - } - if (!left_valid && !right_valid) { - comp_res = 0; - } else if (!left_valid) { - comp_res = 1; - } else if (!right_valid) { - comp_res = -1; - } - if (comp_res != 0) { - break; - } - } - return comp_res; -} - -int Comparators::CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, - bool valid) { - if (!valid) { - return 0; - } - // Load list lengths - auto left_len = Load(left_ptr); - auto right_len = Load(right_ptr); - left_ptr += sizeof(idx_t); - right_ptr += 
sizeof(idx_t); - // Load list validity masks - ValidityBytes left_validity(left_ptr); - ValidityBytes right_validity(right_ptr); - left_ptr += (left_len + 7) / 8; - right_ptr += (right_len + 7) / 8; - // Compare - int comp_res = 0; - idx_t count = MinValue(left_len, right_len); - if (TypeIsConstantSize(type.InternalType())) { - // Templated code for fixed-size types - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT16: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT32: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT64: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT16: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT32: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT64: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT128: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::FLOAT: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::DOUBLE: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INTERVAL: - comp_res = TemplatedCompareListLoop(left_ptr, 
right_ptr, left_validity, right_validity, count); - break; - default: - throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString()); - } - } else { - // Variable-sized list entries - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Size (in bytes) of all variable-sizes entries is stored before the entries begin, - // to make deserialization easier. We need to skip over them - left_ptr += left_len * sizeof(idx_t); - right_ptr += right_len * sizeof(idx_t); - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - if (left_valid && right_valid) { - switch (type.InternalType()) { - case PhysicalType::LIST: - comp_res = CompareListAndAdvance(left_ptr, right_ptr, ListType::GetChildType(type), left_valid); - break; - case PhysicalType::VARCHAR: - comp_res = CompareStringAndAdvance(left_ptr, right_ptr, left_valid); - break; - case PhysicalType::STRUCT: - comp_res = - CompareStructAndAdvance(left_ptr, right_ptr, StructType::GetChildTypes(type), left_valid); - break; - default: - throw NotImplementedException("CompareListAndAdvance for variable-size type %s", type.ToString()); - } - } else if (!left_valid && !right_valid) { - comp_res = 0; - } else if (left_valid) { - comp_res = -1; - } else { - comp_res = 1; - } - if (comp_res != 0) { - break; - } - } - } - // All values that we looped over were equal - if (comp_res == 0 && left_len != right_len) { - // Smaller lists first - if (left_len < right_len) { - comp_res = -1; - } else { - comp_res = 1; - } - } - return comp_res; -} - -template -int Comparators::TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const ValidityBytes &left_validity, const ValidityBytes &right_validity, - const idx_t &count) { 
- int comp_res = 0; - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - comp_res = TemplatedCompareAndAdvance(left_ptr, right_ptr); - if (!left_valid && !right_valid) { - comp_res = 0; - } else if (!left_valid) { - comp_res = 1; - } else if (!right_valid) { - comp_res = -1; - } - if (comp_res != 0) { - break; - } - } - return comp_res; -} - -void Comparators::UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - data_ptr += string_t::HEADER_SIZE; - } - Store(heap_ptr + Load(data_ptr), data_ptr); -} - -void Comparators::SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - data_ptr += string_t::HEADER_SIZE; - } - Store(Load(data_ptr) - heap_ptr, data_ptr); -} - -} // namespace duckdb - - - - -namespace duckdb { - -MergeSorter::MergeSorter(GlobalSortState &state, BufferManager &buffer_manager) - : state(state), buffer_manager(buffer_manager), sort_layout(state.sort_layout) { -} - -void MergeSorter::PerformInMergeRound() { - while (true) { - { - lock_guard pair_guard(state.lock); - if (state.pair_idx == state.num_pairs) { - break; - } - GetNextPartition(); - } - MergePartition(); - } -} - -void MergeSorter::MergePartition() { - auto &left_block = *left->sb; - auto &right_block = *right->sb; -#ifdef DEBUG - D_ASSERT(left_block.radix_sorting_data.size() == left_block.payload_data->data_blocks.size()); - D_ASSERT(right_block.radix_sorting_data.size() == right_block.payload_data->data_blocks.size()); - if (!state.payload_layout.AllConstant() && 
state.external) { - D_ASSERT(left_block.payload_data->data_blocks.size() == left_block.payload_data->heap_blocks.size()); - D_ASSERT(right_block.payload_data->data_blocks.size() == right_block.payload_data->heap_blocks.size()); - } - if (!sort_layout.all_constant) { - D_ASSERT(left_block.radix_sorting_data.size() == left_block.blob_sorting_data->data_blocks.size()); - D_ASSERT(right_block.radix_sorting_data.size() == right_block.blob_sorting_data->data_blocks.size()); - if (state.external) { - D_ASSERT(left_block.blob_sorting_data->data_blocks.size() == - left_block.blob_sorting_data->heap_blocks.size()); - D_ASSERT(right_block.blob_sorting_data->data_blocks.size() == - right_block.blob_sorting_data->heap_blocks.size()); - } - } -#endif - // Set up the write block - // Each merge task produces a SortedBlock with exactly state.block_capacity rows or less - result->InitializeWrite(); - // Initialize arrays to store merge data - bool left_smaller[STANDARD_VECTOR_SIZE]; - idx_t next_entry_sizes[STANDARD_VECTOR_SIZE]; - // Merge loop -#ifdef DEBUG - auto l_count = left->Remaining(); - auto r_count = right->Remaining(); -#endif - while (true) { - auto l_remaining = left->Remaining(); - auto r_remaining = right->Remaining(); - if (l_remaining + r_remaining == 0) { - // Done - break; - } - const idx_t next = MinValue(l_remaining + r_remaining, (idx_t)STANDARD_VECTOR_SIZE); - if (l_remaining != 0 && r_remaining != 0) { - // Compute the merge (not needed if one side is exhausted) - ComputeMerge(next, left_smaller); - } - // Actually merge the data (radix, blob, and payload) - MergeRadix(next, left_smaller); - if (!sort_layout.all_constant) { - MergeData(*result->blob_sorting_data, *left_block.blob_sorting_data, *right_block.blob_sorting_data, next, - left_smaller, next_entry_sizes, true); - D_ASSERT(result->radix_sorting_data.size() == result->blob_sorting_data->data_blocks.size()); - } - MergeData(*result->payload_data, *left_block.payload_data, *right_block.payload_data, 
next, left_smaller, - next_entry_sizes, false); - D_ASSERT(result->radix_sorting_data.size() == result->payload_data->data_blocks.size()); - } -#ifdef DEBUG - D_ASSERT(result->Count() == l_count + r_count); -#endif -} - -void MergeSorter::GetNextPartition() { - // Create result block - state.sorted_blocks_temp[state.pair_idx].push_back(make_uniq(buffer_manager, state)); - result = state.sorted_blocks_temp[state.pair_idx].back().get(); - // Determine which blocks must be merged - auto &left_block = *state.sorted_blocks[state.pair_idx * 2]; - auto &right_block = *state.sorted_blocks[state.pair_idx * 2 + 1]; - const idx_t l_count = left_block.Count(); - const idx_t r_count = right_block.Count(); - // Initialize left and right reader - left = make_uniq(buffer_manager, state); - right = make_uniq(buffer_manager, state); - // Compute the work that this thread must do using Merge Path - idx_t l_end; - idx_t r_end; - if (state.l_start + state.r_start + state.block_capacity < l_count + r_count) { - left->sb = state.sorted_blocks[state.pair_idx * 2].get(); - right->sb = state.sorted_blocks[state.pair_idx * 2 + 1].get(); - const idx_t intersection = state.l_start + state.r_start + state.block_capacity; - GetIntersection(intersection, l_end, r_end); - D_ASSERT(l_end <= l_count); - D_ASSERT(r_end <= r_count); - D_ASSERT(intersection == l_end + r_end); - } else { - l_end = l_count; - r_end = r_count; - } - // Create slices of the data that this thread must merge - left->SetIndices(0, 0); - right->SetIndices(0, 0); - left_input = left_block.CreateSlice(state.l_start, l_end, left->entry_idx); - right_input = right_block.CreateSlice(state.r_start, r_end, right->entry_idx); - left->sb = left_input.get(); - right->sb = right_input.get(); - state.l_start = l_end; - state.r_start = r_end; - D_ASSERT(left->Remaining() + right->Remaining() == state.block_capacity || (l_end == l_count && r_end == r_count)); - // Update global state - if (state.l_start == l_count && state.r_start == 
r_count) { - // Delete references to previous pair - state.sorted_blocks[state.pair_idx * 2] = nullptr; - state.sorted_blocks[state.pair_idx * 2 + 1] = nullptr; - // Advance pair - state.pair_idx++; - state.l_start = 0; - state.r_start = 0; - } -} - -int MergeSorter::CompareUsingGlobalIndex(SBScanState &l, SBScanState &r, const idx_t l_idx, const idx_t r_idx) { - D_ASSERT(l_idx < l.sb->Count()); - D_ASSERT(r_idx < r.sb->Count()); - - // Easy comparison using the previous result (intersections must increase monotonically) - if (l_idx < state.l_start) { - return -1; - } - if (r_idx < state.r_start) { - return 1; - } - - l.sb->GlobalToLocalIndex(l_idx, l.block_idx, l.entry_idx); - r.sb->GlobalToLocalIndex(r_idx, r.block_idx, r.entry_idx); - - l.PinRadix(l.block_idx); - r.PinRadix(r.block_idx); - data_ptr_t l_ptr = l.radix_handle.Ptr() + l.entry_idx * sort_layout.entry_size; - data_ptr_t r_ptr = r.radix_handle.Ptr() + r.entry_idx * sort_layout.entry_size; - - int comp_res; - if (sort_layout.all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, sort_layout.comparison_size); - } else { - l.PinData(*l.sb->blob_sorting_data); - r.PinData(*r.sb->blob_sorting_data); - comp_res = Comparators::CompareTuple(l, r, l_ptr, r_ptr, sort_layout, state.external); - } - return comp_res; -} - -void MergeSorter::GetIntersection(const idx_t diagonal, idx_t &l_idx, idx_t &r_idx) { - const idx_t l_count = left->sb->Count(); - const idx_t r_count = right->sb->Count(); - // Cover some edge cases - // Code coverage off because these edge cases cannot happen unless other code changes - // Edge cases have been tested extensively while developing Merge Path in a script - // LCOV_EXCL_START - if (diagonal >= l_count + r_count) { - l_idx = l_count; - r_idx = r_count; - return; - } else if (diagonal == 0) { - l_idx = 0; - r_idx = 0; - return; - } else if (l_count == 0) { - l_idx = 0; - r_idx = diagonal; - return; - } else if (r_count == 0) { - r_idx = 0; - l_idx = diagonal; - return; - } - // 
LCOV_EXCL_STOP - // Determine offsets for the binary search - const idx_t l_offset = MinValue(l_count, diagonal); - const idx_t r_offset = diagonal > l_count ? diagonal - l_count : 0; - D_ASSERT(l_offset + r_offset == diagonal); - const idx_t search_space = diagonal > MaxValue(l_count, r_count) ? l_count + r_count - diagonal - : MinValue(diagonal, MinValue(l_count, r_count)); - // Double binary search - idx_t li = 0; - idx_t ri = search_space - 1; - idx_t middle; - int comp_res; - while (li <= ri) { - middle = (li + ri) / 2; - l_idx = l_offset - middle; - r_idx = r_offset + middle; - if (l_idx == l_count || r_idx == 0) { - comp_res = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); - if (comp_res > 0) { - l_idx--; - r_idx++; - } else { - return; - } - if (l_idx == 0 || r_idx == r_count) { - // This case is incredibly difficult to cover as it is dependent on parallelism randomness - // But it has been tested extensively during development in a script - // LCOV_EXCL_START - return; - // LCOV_EXCL_STOP - } else { - break; - } - } - comp_res = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx); - if (comp_res > 0) { - li = middle + 1; - } else { - ri = middle - 1; - } - } - int l_r_min1 = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx - 1); - int l_min1_r = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); - if (l_r_min1 > 0 && l_min1_r < 0) { - return; - } else if (l_r_min1 > 0) { - l_idx--; - r_idx++; - } else if (l_min1_r < 0) { - l_idx++; - r_idx--; - } -} - -void MergeSorter::ComputeMerge(const idx_t &count, bool left_smaller[]) { - auto &l = *left; - auto &r = *right; - auto &l_sorted_block = *l.sb; - auto &r_sorted_block = *r.sb; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - // Data pointers for both sides - data_ptr_t l_radix_ptr; - data_ptr_t r_radix_ptr; - // Compute the 
merge of the next 'count' tuples - idx_t compared = 0; - while (compared < count) { - // Move to the next block (if needed) - if (l.block_idx < l_sorted_block.radix_sorting_data.size() && - l.entry_idx == l_sorted_block.radix_sorting_data[l.block_idx]->count) { - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_sorted_block.radix_sorting_data.size() && - r.entry_idx == r_sorted_block.radix_sorting_data[r.block_idx]->count) { - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_sorted_block.radix_sorting_data.size(); - const bool r_done = r.block_idx == r_sorted_block.radix_sorting_data.size(); - if (l_done || r_done) { - // One of the sides is exhausted, no need to compare - break; - } - // Pin the radix sorting data - left->PinRadix(l.block_idx); - l_radix_ptr = left->RadixPtr(); - right->PinRadix(r.block_idx); - r_radix_ptr = right->RadixPtr(); - - const idx_t l_count = l_sorted_block.radix_sorting_data[l.block_idx]->count; - const idx_t r_count = r_sorted_block.radix_sorting_data[r.block_idx]->count; - // Compute the merge - if (sort_layout.all_constant) { - // All sorting columns are constant size - for (; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { - left_smaller[compared] = FastMemcmp(l_radix_ptr, r_radix_ptr, sort_layout.comparison_size) < 0; - const bool &l_smaller = left_smaller[compared]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to increment entries and pointers - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - l_radix_ptr += l_smaller * sort_layout.entry_size; - r_radix_ptr += r_smaller * sort_layout.entry_size; - } - } else { - // Pin the blob data - left->PinData(*l_sorted_block.blob_sorting_data); - right->PinData(*r_sorted_block.blob_sorting_data); - // Merge with variable size sorting columns - for (; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { - left_smaller[compared] = - 
Comparators::CompareTuple(*left, *right, l_radix_ptr, r_radix_ptr, sort_layout, state.external) < 0; - const bool &l_smaller = left_smaller[compared]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to increment entries and pointers - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - l_radix_ptr += l_smaller * sort_layout.entry_size; - r_radix_ptr += r_smaller * sort_layout.entry_size; - } - } - } - // Reset block indices - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); -} - -void MergeSorter::MergeRadix(const idx_t &count, const bool left_smaller[]) { - auto &l = *left; - auto &r = *right; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - - auto &l_blocks = l.sb->radix_sorting_data; - auto &r_blocks = r.sb->radix_sorting_data; - RowDataBlock *l_block = nullptr; - RowDataBlock *r_block = nullptr; - - data_ptr_t l_ptr; - data_ptr_t r_ptr; - - RowDataBlock *result_block = result->radix_sorting_data.back().get(); - auto result_handle = buffer_manager.Pin(result_block->block); - data_ptr_t result_ptr = result_handle.Ptr() + result_block->count * sort_layout.entry_size; - - idx_t copied = 0; - while (copied < count) { - // Move to the next block (if needed) - if (l.block_idx < l_blocks.size() && l.entry_idx == l_blocks[l.block_idx]->count) { - // Delete reference to previous block - l_blocks[l.block_idx]->block = nullptr; - // Advance block - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_blocks.size() && r.entry_idx == r_blocks[r.block_idx]->count) { - // Delete reference to previous block - r_blocks[r.block_idx]->block = nullptr; - // Advance block - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_blocks.size(); - const bool r_done = r.block_idx == r_blocks.size(); - // 
Pin the radix sortable blocks - idx_t l_count; - if (!l_done) { - l_block = l_blocks[l.block_idx].get(); - left->PinRadix(l.block_idx); - l_ptr = l.RadixPtr(); - l_count = l_block->count; - } else { - l_count = 0; - } - idx_t r_count; - if (!r_done) { - r_block = r_blocks[r.block_idx].get(); - r.PinRadix(r.block_idx); - r_ptr = r.RadixPtr(); - r_count = r_block->count; - } else { - r_count = 0; - } - // Copy using computed merge - if (!l_done && !r_done) { - // Both sides have data - merge - MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_block, result_ptr, - sort_layout.entry_size, left_smaller, copied, count); - } else if (r_done) { - // Right side is exhausted - FlushRows(l_ptr, l.entry_idx, l_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); - } else { - // Left side is exhausted - FlushRows(r_ptr, r.entry_idx, r_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); - } - } - // Reset block indices - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); -} - -void MergeSorter::MergeData(SortedData &result_data, SortedData &l_data, SortedData &r_data, const idx_t &count, - const bool left_smaller[], idx_t next_entry_sizes[], bool reset_indices) { - auto &l = *left; - auto &r = *right; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - - const auto &layout = result_data.layout; - const idx_t row_width = layout.GetRowWidth(); - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - - // Left and right row data to merge - data_ptr_t l_ptr; - data_ptr_t r_ptr; - // Accompanying left and right heap data (if needed) - data_ptr_t l_heap_ptr; - data_ptr_t r_heap_ptr; - - // Result rows to write to - RowDataBlock *result_data_block = result_data.data_blocks.back().get(); - 
auto result_data_handle = buffer_manager.Pin(result_data_block->block); - data_ptr_t result_data_ptr = result_data_handle.Ptr() + result_data_block->count * row_width; - // Result heap to write to (if needed) - RowDataBlock *result_heap_block = nullptr; - BufferHandle result_heap_handle; - data_ptr_t result_heap_ptr; - if (!layout.AllConstant() && state.external) { - result_heap_block = result_data.heap_blocks.back().get(); - result_heap_handle = buffer_manager.Pin(result_heap_block->block); - result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; - } - - idx_t copied = 0; - while (copied < count) { - // Move to new data blocks (if needed) - if (l.block_idx < l_data.data_blocks.size() && l.entry_idx == l_data.data_blocks[l.block_idx]->count) { - // Delete reference to previous block - l_data.data_blocks[l.block_idx]->block = nullptr; - if (!layout.AllConstant() && state.external) { - l_data.heap_blocks[l.block_idx]->block = nullptr; - } - // Advance block - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_data.data_blocks.size() && r.entry_idx == r_data.data_blocks[r.block_idx]->count) { - // Delete reference to previous block - r_data.data_blocks[r.block_idx]->block = nullptr; - if (!layout.AllConstant() && state.external) { - r_data.heap_blocks[r.block_idx]->block = nullptr; - } - // Advance block - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_data.data_blocks.size(); - const bool r_done = r.block_idx == r_data.data_blocks.size(); - // Pin the row data blocks - if (!l_done) { - l.PinData(l_data); - l_ptr = l.DataPtr(l_data); - } - if (!r_done) { - r.PinData(r_data); - r_ptr = r.DataPtr(r_data); - } - const idx_t &l_count = !l_done ? l_data.data_blocks[l.block_idx]->count : 0; - const idx_t &r_count = !r_done ? 
r_data.data_blocks[r.block_idx]->count : 0; - // Perform the merge - if (layout.AllConstant() || !state.external) { - // If all constant size, or if we are doing an in-memory sort, we do not need to touch the heap - if (!l_done && !r_done) { - // Both sides have data - merge - MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, - row_width, left_smaller, copied, count); - } else if (r_done) { - // Right side is exhausted - FlushRows(l_ptr, l.entry_idx, l_count, *result_data_block, result_data_ptr, row_width, copied, count); - } else { - // Left side is exhausted - FlushRows(r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, row_width, copied, count); - } - } else { - // External sorting with variable size data. Pin the heap blocks too - if (!l_done) { - l_heap_ptr = l.BaseHeapPtr(l_data) + Load(l_ptr + heap_pointer_offset); - D_ASSERT(l_heap_ptr - l.BaseHeapPtr(l_data) >= 0); - D_ASSERT((idx_t)(l_heap_ptr - l.BaseHeapPtr(l_data)) < l_data.heap_blocks[l.block_idx]->byte_offset); - } - if (!r_done) { - r_heap_ptr = r.BaseHeapPtr(r_data) + Load(r_ptr + heap_pointer_offset); - D_ASSERT(r_heap_ptr - r.BaseHeapPtr(r_data) >= 0); - D_ASSERT((idx_t)(r_heap_ptr - r.BaseHeapPtr(r_data)) < r_data.heap_blocks[r.block_idx]->byte_offset); - } - // Both the row and heap data need to be dealt with - if (!l_done && !r_done) { - // Both sides have data - merge - idx_t l_idx_copy = l.entry_idx; - idx_t r_idx_copy = r.entry_idx; - data_ptr_t result_data_ptr_copy = result_data_ptr; - idx_t copied_copy = copied; - // Merge row data - MergeRows(l_ptr, l_idx_copy, l_count, r_ptr, r_idx_copy, r_count, *result_data_block, - result_data_ptr_copy, row_width, left_smaller, copied_copy, count); - const idx_t merged = copied_copy - copied; - // Compute the entry sizes and number of heap bytes that will be copied - idx_t copy_bytes = 0; - data_ptr_t l_heap_ptr_copy = l_heap_ptr; - data_ptr_t r_heap_ptr_copy = r_heap_ptr; - for 
(idx_t i = 0; i < merged; i++) { - // Store base heap offset in the row data - Store(result_heap_block->byte_offset + copy_bytes, result_data_ptr + heap_pointer_offset); - result_data_ptr += row_width; - // Compute entry size and add to total - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - auto &entry_size = next_entry_sizes[copied + i]; - entry_size = - l_smaller * Load(l_heap_ptr_copy) + r_smaller * Load(r_heap_ptr_copy); - D_ASSERT(entry_size >= sizeof(uint32_t)); - D_ASSERT(l_heap_ptr_copy - l.BaseHeapPtr(l_data) + l_smaller * entry_size <= - l_data.heap_blocks[l.block_idx]->byte_offset); - D_ASSERT(r_heap_ptr_copy - r.BaseHeapPtr(r_data) + r_smaller * entry_size <= - r_data.heap_blocks[r.block_idx]->byte_offset); - l_heap_ptr_copy += l_smaller * entry_size; - r_heap_ptr_copy += r_smaller * entry_size; - copy_bytes += entry_size; - } - // Reallocate result heap block size (if needed) - if (result_heap_block->byte_offset + copy_bytes > result_heap_block->capacity) { - idx_t new_capacity = result_heap_block->byte_offset + copy_bytes; - buffer_manager.ReAllocate(result_heap_block->block, new_capacity); - result_heap_block->capacity = new_capacity; - result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; - } - D_ASSERT(result_heap_block->byte_offset + copy_bytes <= result_heap_block->capacity); - // Now copy the heap data - for (idx_t i = 0; i < merged; i++) { - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - const auto &entry_size = next_entry_sizes[copied + i]; - memcpy(result_heap_ptr, - reinterpret_cast(l_smaller * CastPointerToValue(l_heap_ptr) + - r_smaller * CastPointerToValue(r_heap_ptr)), - entry_size); - D_ASSERT(Load(result_heap_ptr) == entry_size); - result_heap_ptr += entry_size; - l_heap_ptr += l_smaller * entry_size; - r_heap_ptr += r_smaller * entry_size; - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - } - // Update result indices and 
pointers - result_heap_block->count += merged; - result_heap_block->byte_offset += copy_bytes; - copied += merged; - } else if (r_done) { - // Right side is exhausted - flush left - FlushBlobs(layout, l_count, l_ptr, l.entry_idx, l_heap_ptr, *result_data_block, result_data_ptr, - *result_heap_block, result_heap_handle, result_heap_ptr, copied, count); - } else { - // Left side is exhausted - flush right - FlushBlobs(layout, r_count, r_ptr, r.entry_idx, r_heap_ptr, *result_data_block, result_data_ptr, - *result_heap_block, result_heap_handle, result_heap_ptr, copied, count); - } - D_ASSERT(result_data_block->count == result_heap_block->count); - } - } - if (reset_indices) { - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); - } -} - -void MergeSorter::MergeRows(data_ptr_t &l_ptr, idx_t &l_entry_idx, const idx_t &l_count, data_ptr_t &r_ptr, - idx_t &r_entry_idx, const idx_t &r_count, RowDataBlock &target_block, - data_ptr_t &target_ptr, const idx_t &entry_size, const bool left_smaller[], idx_t &copied, - const idx_t &count) { - const idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); - idx_t i; - for (i = 0; i < next && l_entry_idx < l_count && r_entry_idx < r_count; i++) { - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to copy an entry from either side - FastMemcpy( - target_ptr, - reinterpret_cast(l_smaller * CastPointerToValue(l_ptr) + r_smaller * CastPointerToValue(r_ptr)), - entry_size); - target_ptr += entry_size; - // Use the comparison bool to increment entries and pointers - l_entry_idx += l_smaller; - r_entry_idx += r_smaller; - l_ptr += l_smaller * entry_size; - r_ptr += r_smaller * entry_size; - } - // Update counts - target_block.count += i; - copied += i; -} - -void MergeSorter::FlushRows(data_ptr_t &source_ptr, idx_t &source_entry_idx, const idx_t &source_count, - RowDataBlock 
&target_block, data_ptr_t &target_ptr, const idx_t &entry_size, idx_t &copied, - const idx_t &count) { - // Compute how many entries we can fit - idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); - next = MinValue(next, source_count - source_entry_idx); - // Copy them all in a single memcpy - const idx_t copy_bytes = next * entry_size; - memcpy(target_ptr, source_ptr, copy_bytes); - target_ptr += copy_bytes; - source_ptr += copy_bytes; - // Update counts - source_entry_idx += next; - target_block.count += next; - copied += next; -} - -void MergeSorter::FlushBlobs(const RowLayout &layout, const idx_t &source_count, data_ptr_t &source_data_ptr, - idx_t &source_entry_idx, data_ptr_t &source_heap_ptr, RowDataBlock &target_data_block, - data_ptr_t &target_data_ptr, RowDataBlock &target_heap_block, - BufferHandle &target_heap_handle, data_ptr_t &target_heap_ptr, idx_t &copied, - const idx_t &count) { - const idx_t row_width = layout.GetRowWidth(); - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - idx_t source_entry_idx_copy = source_entry_idx; - data_ptr_t target_data_ptr_copy = target_data_ptr; - idx_t copied_copy = copied; - // Flush row data - FlushRows(source_data_ptr, source_entry_idx_copy, source_count, target_data_block, target_data_ptr_copy, row_width, - copied_copy, count); - const idx_t flushed = copied_copy - copied; - // Compute the entry sizes and number of heap bytes that will be copied - idx_t copy_bytes = 0; - data_ptr_t source_heap_ptr_copy = source_heap_ptr; - for (idx_t i = 0; i < flushed; i++) { - // Store base heap offset in the row data - Store(target_heap_block.byte_offset + copy_bytes, target_data_ptr + heap_pointer_offset); - target_data_ptr += row_width; - // Compute entry size and add to total - auto entry_size = Load(source_heap_ptr_copy); - D_ASSERT(entry_size >= sizeof(uint32_t)); - source_heap_ptr_copy += entry_size; - copy_bytes += entry_size; - } - // Reallocate result heap block size (if 
needed) - if (target_heap_block.byte_offset + copy_bytes > target_heap_block.capacity) { - idx_t new_capacity = target_heap_block.byte_offset + copy_bytes; - buffer_manager.ReAllocate(target_heap_block.block, new_capacity); - target_heap_block.capacity = new_capacity; - target_heap_ptr = target_heap_handle.Ptr() + target_heap_block.byte_offset; - } - D_ASSERT(target_heap_block.byte_offset + copy_bytes <= target_heap_block.capacity); - // Copy the heap data in one go - memcpy(target_heap_ptr, source_heap_ptr, copy_bytes); - target_heap_ptr += copy_bytes; - source_heap_ptr += copy_bytes; - source_entry_idx += flushed; - copied += flushed; - // Update result indices and pointers - target_heap_block.count += flushed; - target_heap_block.byte_offset += copy_bytes; - D_ASSERT(target_heap_block.byte_offset <= target_heap_block.capacity); -} - -} // namespace duckdb - - - - - - - -#include - -namespace duckdb { - -PartitionGlobalHashGroup::PartitionGlobalHashGroup(BufferManager &buffer_manager, const Orders &partitions, - const Orders &orders, const Types &payload_types, bool external) - : count(0), batch_base(0) { - - RowLayout payload_layout; - payload_layout.Initialize(payload_types); - global_sort = make_uniq(buffer_manager, orders, payload_layout); - global_sort->external = external; - - // Set up a comparator for the partition subset - partition_layout = global_sort->sort_layout.GetPrefixComparisonLayout(partitions.size()); -} - -int PartitionGlobalHashGroup::ComparePartitions(const SBIterator &left, const SBIterator &right) const { - int part_cmp = 0; - if (partition_layout.all_constant) { - part_cmp = FastMemcmp(left.entry_ptr, right.entry_ptr, partition_layout.comparison_size); - } else { - part_cmp = Comparators::CompareTuple(left.scan, right.scan, left.entry_ptr, right.entry_ptr, partition_layout, - left.external); - } - return part_cmp; -} - -void PartitionGlobalHashGroup::ComputeMasks(ValidityMask &partition_mask, ValidityMask &order_mask) { - D_ASSERT(count > 
0); - - SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN); - SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN); - - partition_mask.SetValidUnsafe(0); - order_mask.SetValidUnsafe(0); - for (++curr; curr.GetIndex() < count; ++curr) { - // Compare the partition subset first because if that differs, then so does the full ordering - const auto part_cmp = ComparePartitions(prev, curr); - ; - - if (part_cmp) { - partition_mask.SetValidUnsafe(curr.GetIndex()); - order_mask.SetValidUnsafe(curr.GetIndex()); - } else if (prev.Compare(curr)) { - order_mask.SetValidUnsafe(curr.GetIndex()); - } - ++prev; - } -} - -void PartitionGlobalSinkState::GenerateOrderings(Orders &partitions, Orders &orders, - const vector> &partition_bys, - const Orders &order_bys, - const vector> &partition_stats) { - - // we sort by both 1) partition by expression list and 2) order by expressions - const auto partition_cols = partition_bys.size(); - for (idx_t prt_idx = 0; prt_idx < partition_cols; prt_idx++) { - auto &pexpr = partition_bys[prt_idx]; - - if (partition_stats.empty() || !partition_stats[prt_idx]) { - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), nullptr); - } else { - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), - partition_stats[prt_idx]->ToUnique()); - } - partitions.emplace_back(orders.back().Copy()); - } - - for (const auto &order : order_bys) { - orders.emplace_back(order.Copy()); - } -} - -PartitionGlobalSinkState::PartitionGlobalSinkState(ClientContext &context, - const vector> &partition_bys, - const vector &order_bys, - const Types &payload_types, - const vector> &partition_stats, - idx_t estimated_cardinality) - : context(context), buffer_manager(BufferManager::GetBufferManager(context)), allocator(Allocator::Get(context)), - fixed_bits(0), payload_types(payload_types), memory_per_thread(0), max_bits(1), count(0) { - - GenerateOrderings(partitions, orders, 
partition_bys, order_bys, partition_stats); - - memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); - external = ClientConfig::GetConfig(context).force_external; - - const auto thread_pages = PreviousPowerOfTwo(memory_per_thread / (4 * idx_t(Storage::BLOCK_ALLOC_SIZE))); - while (max_bits < 10 && (thread_pages >> max_bits) > 1) { - ++max_bits; - } - - if (!orders.empty()) { - if (partitions.empty()) { - // Sort early into a dedicated hash group if we only sort. - grouping_types.Initialize(payload_types); - auto new_group = - make_uniq(buffer_manager, partitions, orders, payload_types, external); - hash_groups.emplace_back(std::move(new_group)); - } else { - auto types = payload_types; - types.push_back(LogicalType::HASH); - grouping_types.Initialize(types); - ResizeGroupingData(estimated_cardinality); - } - } -} - -bool PartitionGlobalSinkState::HasMergeTasks() const { - if (grouping_data) { - auto &groups = grouping_data->GetPartitions(); - return !groups.empty(); - } else if (!hash_groups.empty()) { - D_ASSERT(hash_groups.size() == 1); - return hash_groups[0]->count > 0; - } else { - return false; - } -} - -void PartitionGlobalSinkState::SyncPartitioning(const PartitionGlobalSinkState &other) { - fixed_bits = other.grouping_data ? other.grouping_data->GetRadixBits() : 0; - - const auto old_bits = grouping_data ? grouping_data->GetRadixBits() : 0; - if (fixed_bits != old_bits) { - const auto hash_col_idx = payload_types.size(); - grouping_data = make_uniq(buffer_manager, grouping_types, fixed_bits, hash_col_idx); - } -} - -unique_ptr PartitionGlobalSinkState::CreatePartition(idx_t new_bits) const { - const auto hash_col_idx = payload_types.size(); - return make_uniq(buffer_manager, grouping_types, new_bits, hash_col_idx); -} - -void PartitionGlobalSinkState::ResizeGroupingData(idx_t cardinality) { - // Have we started to combine? Then just live with it. 
- if (fixed_bits || (grouping_data && !grouping_data->GetPartitions().empty())) { - return; - } - // Is the average partition size too large? - const idx_t partition_size = STANDARD_ROW_GROUPS_SIZE; - const auto bits = grouping_data ? grouping_data->GetRadixBits() : 0; - auto new_bits = bits ? bits : 4; - while (new_bits < max_bits && (cardinality / RadixPartitioning::NumberOfPartitions(new_bits)) > partition_size) { - ++new_bits; - } - - // Repartition the grouping data - if (new_bits != bits) { - grouping_data = CreatePartition(new_bits); - } -} - -void PartitionGlobalSinkState::SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - // We are done if the local_partition is right sized. - auto &local_radix = local_partition->Cast(); - const auto new_bits = grouping_data->GetRadixBits(); - if (local_radix.GetRadixBits() == new_bits) { - return; - } - - // If the local partition is now too small, flush it and reallocate - auto new_partition = CreatePartition(new_bits); - local_partition->FlushAppendState(*local_append); - local_partition->Repartition(*new_partition); - - local_partition = std::move(new_partition); - local_append = make_uniq(); - local_partition->InitializeAppendState(*local_append); -} - -void PartitionGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - // Make sure grouping_data doesn't change under us. 
- lock_guard guard(lock); - - if (!local_partition) { - local_partition = CreatePartition(grouping_data->GetRadixBits()); - local_append = make_uniq(); - local_partition->InitializeAppendState(*local_append); - return; - } - - // Grow the groups if they are too big - ResizeGroupingData(count); - - // Sync local partition to have the same bit count - SyncLocalPartition(local_partition, local_append); -} - -void PartitionGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - if (!local_partition) { - return; - } - local_partition->FlushAppendState(*local_append); - - // Make sure grouping_data doesn't change under us. - // Combine has an internal mutex, so this is single-threaded anyway. - lock_guard guard(lock); - SyncLocalPartition(local_partition, local_append); - grouping_data->Combine(*local_partition); -} - -PartitionLocalMergeState::PartitionLocalMergeState(PartitionGlobalSinkState &gstate) - : merge_state(nullptr), stage(PartitionSortStage::INIT), finished(true), executor(gstate.context) { - - // Set up the sort expression computation. - vector sort_types; - for (auto &order : gstate.orders) { - auto &oexpr = order.expression; - sort_types.emplace_back(oexpr->return_type); - executor.AddExpression(*oexpr); - } - sort_chunk.Initialize(gstate.allocator, sort_types); - payload_chunk.Initialize(gstate.allocator, gstate.payload_types); -} - -void PartitionLocalMergeState::Scan() { - if (!merge_state->group_data) { - // OVER(ORDER BY...) - // Already sorted - return; - } - - auto &group_data = *merge_state->group_data; - auto &hash_group = *merge_state->hash_group; - auto &chunk_state = merge_state->chunk_state; - // Copy the data from the group into the sort code. 
- auto &global_sort = *hash_group.global_sort; - LocalSortState local_sort; - local_sort.Initialize(global_sort, global_sort.buffer_manager); - - TupleDataScanState local_scan; - group_data.InitializeScan(local_scan, merge_state->column_ids); - while (group_data.Scan(chunk_state, local_scan, payload_chunk)) { - sort_chunk.Reset(); - executor.Execute(payload_chunk, sort_chunk); - - local_sort.SinkChunk(sort_chunk, payload_chunk); - if (local_sort.SizeInBytes() > merge_state->memory_per_thread) { - local_sort.Sort(global_sort, true); - } - hash_group.count += payload_chunk.size(); - } - - global_sort.AddLocalState(local_sort); -} - -// Per-thread sink state -PartitionLocalSinkState::PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p) - : gstate(gstate_p), allocator(Allocator::Get(context)), executor(context) { - - vector group_types; - for (idx_t prt_idx = 0; prt_idx < gstate.partitions.size(); prt_idx++) { - auto &pexpr = *gstate.partitions[prt_idx].expression.get(); - group_types.push_back(pexpr.return_type); - executor.AddExpression(pexpr); - } - sort_cols = gstate.orders.size() + group_types.size(); - - if (sort_cols) { - auto payload_types = gstate.payload_types; - if (!group_types.empty()) { - // OVER(PARTITION BY...) - group_chunk.Initialize(allocator, group_types); - payload_types.emplace_back(LogicalType::HASH); - } else { - // OVER(ORDER BY...) - for (idx_t ord_idx = 0; ord_idx < gstate.orders.size(); ord_idx++) { - auto &pexpr = *gstate.orders[ord_idx].expression.get(); - group_types.push_back(pexpr.return_type); - executor.AddExpression(pexpr); - } - group_chunk.Initialize(allocator, group_types); - - // Single partition - auto &global_sort = *gstate.hash_groups[0]->global_sort; - local_sort = make_uniq(); - local_sort->Initialize(global_sort, global_sort.buffer_manager); - } - // OVER(...) 
- payload_chunk.Initialize(allocator, payload_types); - } else { - // OVER() - payload_layout.Initialize(gstate.payload_types); - } -} - -void PartitionLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) { - const auto count = input_chunk.size(); - D_ASSERT(group_chunk.ColumnCount() > 0); - - // OVER(PARTITION BY...) (hash grouping) - group_chunk.Reset(); - executor.Execute(input_chunk, group_chunk); - VectorOperations::Hash(group_chunk.data[0], hash_vector, count); - for (idx_t prt_idx = 1; prt_idx < group_chunk.ColumnCount(); ++prt_idx) { - VectorOperations::CombineHash(hash_vector, group_chunk.data[prt_idx], count); - } -} - -void PartitionLocalSinkState::Sink(DataChunk &input_chunk) { - gstate.count += input_chunk.size(); - - // OVER() - if (sort_cols == 0) { - // No sorts, so build paged row chunks - if (!rows) { - const auto entry_size = payload_layout.GetRowWidth(); - const auto capacity = MaxValue(STANDARD_VECTOR_SIZE, (Storage::BLOCK_SIZE / entry_size) + 1); - rows = make_uniq(gstate.buffer_manager, capacity, entry_size); - strings = make_uniq(gstate.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true); - } - const auto row_count = input_chunk.size(); - const auto row_sel = FlatVector::IncrementalSelectionVector(); - Vector addresses(LogicalType::POINTER); - auto key_locations = FlatVector::GetData(addresses); - const auto prev_rows_blocks = rows->blocks.size(); - auto handles = rows->Build(row_count, key_locations, nullptr, row_sel); - auto input_data = input_chunk.ToUnifiedFormat(); - RowOperations::Scatter(input_chunk, input_data.get(), payload_layout, addresses, *strings, *row_sel, row_count); - // Mark that row blocks contain pointers (heap blocks are pinned) - if (!payload_layout.AllConstant()) { - D_ASSERT(strings->keep_pinned); - for (size_t i = prev_rows_blocks; i < rows->blocks.size(); ++i) { - rows->blocks[i]->block->SetSwizzling("PartitionLocalSinkState::Sink"); - } - } - return; - } - - if (local_sort) { - // OVER(ORDER BY...) 
- group_chunk.Reset(); - executor.Execute(input_chunk, group_chunk); - local_sort->SinkChunk(group_chunk, input_chunk); - - auto &hash_group = *gstate.hash_groups[0]; - hash_group.count += input_chunk.size(); - - if (local_sort->SizeInBytes() > gstate.memory_per_thread) { - auto &global_sort = *hash_group.global_sort; - local_sort->Sort(global_sort, true); - } - return; - } - - // OVER(...) - payload_chunk.Reset(); - auto &hash_vector = payload_chunk.data.back(); - Hash(input_chunk, hash_vector); - for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); ++col_idx) { - payload_chunk.data[col_idx].Reference(input_chunk.data[col_idx]); - } - payload_chunk.SetCardinality(input_chunk); - - gstate.UpdateLocalPartition(local_partition, local_append); - local_partition->Append(*local_append, payload_chunk); -} - -void PartitionLocalSinkState::Combine() { - // OVER() - if (sort_cols == 0) { - // Only one partition again, so need a global lock. - lock_guard glock(gstate.lock); - if (gstate.rows) { - if (rows) { - gstate.rows->Merge(*rows); - gstate.strings->Merge(*strings); - rows.reset(); - strings.reset(); - } - } else { - gstate.rows = std::move(rows); - gstate.strings = std::move(strings); - } - return; - } - - if (local_sort) { - // OVER(ORDER BY...) - auto &hash_group = *gstate.hash_groups[0]; - auto &global_sort = *hash_group.global_sort; - global_sort.AddLocalState(*local_sort); - local_sort.reset(); - return; - } - - // OVER(...) 
- gstate.CombineLocalPartition(local_partition, local_append); -} - -PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data_p, - hash_t hash_bin) - : sink(sink), group_data(std::move(group_data_p)), memory_per_thread(sink.memory_per_thread), - num_threads(TaskScheduler::GetScheduler(sink.context).NumberOfThreads()), stage(PartitionSortStage::INIT), - total_tasks(0), tasks_assigned(0), tasks_completed(0) { - - const auto group_idx = sink.hash_groups.size(); - auto new_group = make_uniq(sink.buffer_manager, sink.partitions, sink.orders, - sink.payload_types, sink.external); - sink.hash_groups.emplace_back(std::move(new_group)); - - hash_group = sink.hash_groups[group_idx].get(); - global_sort = sink.hash_groups[group_idx]->global_sort.get(); - - sink.bin_groups[hash_bin] = group_idx; - - column_ids.reserve(sink.payload_types.size()); - for (column_t i = 0; i < sink.payload_types.size(); ++i) { - column_ids.emplace_back(i); - } - group_data->InitializeScan(chunk_state, column_ids); -} - -PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink) - : sink(sink), memory_per_thread(sink.memory_per_thread), - num_threads(TaskScheduler::GetScheduler(sink.context).NumberOfThreads()), stage(PartitionSortStage::INIT), - total_tasks(0), tasks_assigned(0), tasks_completed(0) { - - const hash_t hash_bin = 0; - const size_t group_idx = 0; - hash_group = sink.hash_groups[group_idx].get(); - global_sort = sink.hash_groups[group_idx]->global_sort.get(); - - sink.bin_groups[hash_bin] = group_idx; -} - -void PartitionLocalMergeState::Prepare() { - merge_state->group_data.reset(); - - auto &global_sort = *merge_state->global_sort; - global_sort.PrepareMergePhase(); -} - -void PartitionLocalMergeState::Merge() { - auto &global_sort = *merge_state->global_sort; - MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); - merge_sorter.PerformInMergeRound(); -} - -void 
PartitionLocalMergeState::ExecuteTask() { - switch (stage) { - case PartitionSortStage::SCAN: - Scan(); - break; - case PartitionSortStage::PREPARE: - Prepare(); - break; - case PartitionSortStage::MERGE: - Merge(); - break; - default: - throw InternalException("Unexpected PartitionSortStage in ExecuteTask!"); - } - - merge_state->CompleteTask(); - finished = true; -} - -bool PartitionGlobalMergeState::AssignTask(PartitionLocalMergeState &local_state) { - lock_guard guard(lock); - - if (tasks_assigned >= total_tasks) { - return false; - } - - local_state.merge_state = this; - local_state.stage = stage; - local_state.finished = false; - tasks_assigned++; - - return true; -} - -void PartitionGlobalMergeState::CompleteTask() { - lock_guard guard(lock); - - ++tasks_completed; -} - -bool PartitionGlobalMergeState::TryPrepareNextStage() { - lock_guard guard(lock); - - if (tasks_completed < total_tasks) { - return false; - } - - tasks_assigned = tasks_completed = 0; - - switch (stage) { - case PartitionSortStage::INIT: - // If the partitions are unordered, don't scan in parallel - // because it produces non-deterministic orderings. - // This can theoretically happen with ORDER BY, - // but that is something the query should be explicit about. - total_tasks = sink.orders.size() > sink.partitions.size() ? 
num_threads : 1; - stage = PartitionSortStage::SCAN; - return true; - - case PartitionSortStage::SCAN: - total_tasks = 1; - stage = PartitionSortStage::PREPARE; - return true; - - case PartitionSortStage::PREPARE: - total_tasks = global_sort->sorted_blocks.size() / 2; - if (!total_tasks) { - break; - } - stage = PartitionSortStage::MERGE; - global_sort->InitializeMergeRound(); - return true; - - case PartitionSortStage::MERGE: - global_sort->CompleteMergeRound(true); - total_tasks = global_sort->sorted_blocks.size() / 2; - if (!total_tasks) { - break; - } - global_sort->InitializeMergeRound(); - return true; - - case PartitionSortStage::SORTED: - break; - } - - stage = PartitionSortStage::SORTED; - - return false; -} - -PartitionGlobalMergeStates::PartitionGlobalMergeStates(PartitionGlobalSinkState &sink) { - // Schedule all the sorts for maximum thread utilisation - if (sink.grouping_data) { - auto &partitions = sink.grouping_data->GetPartitions(); - sink.bin_groups.resize(partitions.size(), partitions.size()); - for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { - auto &group_data = partitions[hash_bin]; - // Prepare for merge sort phase - if (group_data->Count()) { - auto state = make_uniq(sink, std::move(group_data), hash_bin); - states.emplace_back(std::move(state)); - } - } - } else { - // OVER(ORDER BY...) 
- // Already sunk into the single global sort, so set up single merge with no data - sink.bin_groups.resize(1, 1); - auto state = make_uniq(sink); - states.emplace_back(std::move(state)); - } -} - -class PartitionMergeTask : public ExecutorTask { -public: - PartitionMergeTask(shared_ptr event_p, ClientContext &context_p, PartitionGlobalMergeStates &hash_groups_p, - PartitionGlobalSinkState &gstate) - : ExecutorTask(context_p), event(std::move(event_p)), local_state(gstate), hash_groups(hash_groups_p) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; - -private: - struct ExecutorCallback : public PartitionGlobalMergeStates::Callback { - explicit ExecutorCallback(Executor &executor) : executor(executor) { - } - - bool HasError() const override { - return executor.HasError(); - } - - Executor &executor; - }; - - shared_ptr event; - PartitionLocalMergeState local_state; - PartitionGlobalMergeStates &hash_groups; -}; - -bool PartitionGlobalMergeStates::ExecuteTask(PartitionLocalMergeState &local_state, Callback &callback) { - // Loop until all hash groups are done - size_t sorted = 0; - while (sorted < states.size()) { - // First check if there is an unfinished task for this thread - if (callback.HasError()) { - return false; - } - if (!local_state.TaskFinished()) { - local_state.ExecuteTask(); - continue; - } - - // Thread is done with its assigned task, try to fetch new work - for (auto group = sorted; group < states.size(); ++group) { - auto &global_state = states[group]; - if (global_state->IsSorted()) { - // This hash group is done - // Update the high water mark of densely completed groups - if (sorted == group) { - ++sorted; - } - continue; - } - - // Try to assign work for this hash group to this thread - if (global_state->AssignTask(local_state)) { - // We assigned a task to this thread! 
- // Break out of this loop to re-enter the top-level loop and execute the task - break; - } - - // Hash group global state couldn't assign a task to this thread - // Try to prepare the next stage - if (!global_state->TryPrepareNextStage()) { - // This current hash group is not yet done - // But we were not able to assign a task for it to this thread - // See if the next hash group is better - continue; - } - - // We were able to prepare the next stage for this hash group! - // Try to assign a task once more - if (global_state->AssignTask(local_state)) { - // We assigned a task to this thread! - // Break out of this loop to re-enter the top-level loop and execute the task - break; - } - - // We were able to prepare the next merge round, - // but we were not able to assign a task for it to this thread - // The tasks were assigned to other threads while this thread waited for the lock - // Go to the next iteration to see if another hash group has a task - } - } - - return true; -} - -TaskExecutionResult PartitionMergeTask::ExecuteTask(TaskExecutionMode mode) { - ExecutorCallback callback(executor); - - if (!hash_groups.ExecuteTask(local_state, callback)) { - return TaskExecutionResult::TASK_ERROR; - } - - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -void PartitionMergeEvent::Schedule() { - auto &context = pipeline->GetClientContext(); - - // Schedule tasks equal to the number of threads, which will each merge multiple partitions - auto &ts = TaskScheduler::GetScheduler(context); - idx_t num_threads = ts.NumberOfThreads(); - - vector> merge_tasks; - for (idx_t tnum = 0; tnum < num_threads; tnum++) { - merge_tasks.emplace_back(make_uniq(shared_from_this(), context, merge_states, gstate)); - } - SetTasks(std::move(merge_tasks)); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -//! 
Calls std::sort on strings that are tied by their prefix after the radix sort -static void SortTiedBlobs(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &start, const idx_t &end, - const idx_t &tie_col, bool *ties, const data_ptr_t blob_ptr, const SortLayout &sort_layout) { - const auto row_width = sort_layout.blob_layout.GetRowWidth(); - // Locate the first blob row in question - data_ptr_t row_ptr = dataptr + start * sort_layout.entry_size; - data_ptr_t blob_row_ptr = blob_ptr + Load(row_ptr + sort_layout.comparison_size) * row_width; - if (!Comparators::TieIsBreakable(tie_col, blob_row_ptr, sort_layout)) { - // Quick check to see if ties can be broken - return; - } - // Fill pointer array for sorting - auto ptr_block = make_unsafe_uniq_array(end - start); - auto entry_ptrs = (data_ptr_t *)ptr_block.get(); - for (idx_t i = start; i < end; i++) { - entry_ptrs[i - start] = row_ptr; - row_ptr += sort_layout.entry_size; - } - // Slow pointer-based sorting - const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? 
-1 : 1; - const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; - auto logical_type = sort_layout.blob_layout.GetTypes()[col_idx]; - std::sort(entry_ptrs, entry_ptrs + end - start, - [&blob_ptr, &order, &sort_layout, &tie_col_offset, &row_width, &logical_type](const data_ptr_t l, - const data_ptr_t r) { - idx_t left_idx = Load(l + sort_layout.comparison_size); - idx_t right_idx = Load(r + sort_layout.comparison_size); - data_ptr_t left_ptr = blob_ptr + left_idx * row_width + tie_col_offset; - data_ptr_t right_ptr = blob_ptr + right_idx * row_width + tie_col_offset; - return order * Comparators::CompareVal(left_ptr, right_ptr, logical_type) < 0; - }); - // Re-order - auto temp_block = buffer_manager.GetBufferAllocator().Allocate((end - start) * sort_layout.entry_size); - data_ptr_t temp_ptr = temp_block.get(); - for (idx_t i = 0; i < end - start; i++) { - FastMemcpy(temp_ptr, entry_ptrs[i], sort_layout.entry_size); - temp_ptr += sort_layout.entry_size; - } - memcpy(dataptr + start * sort_layout.entry_size, temp_block.get(), (end - start) * sort_layout.entry_size); - // Determine if there are still ties (if this is not the last column) - if (tie_col < sort_layout.column_count - 1) { - data_ptr_t idx_ptr = dataptr + start * sort_layout.entry_size + sort_layout.comparison_size; - // Load current entry - data_ptr_t current_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; - for (idx_t i = 0; i < end - start - 1; i++) { - // Load next entry and compare - idx_ptr += sort_layout.entry_size; - data_ptr_t next_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; - ties[start + i] = Comparators::CompareVal(current_ptr, next_ptr, logical_type) == 0; - current_ptr = next_ptr; - } - } -} - -//! 
Identifies sequences of rows that are tied by the prefix of a blob column, and sorts them -static void SortTiedBlobs(BufferManager &buffer_manager, SortedBlock &sb, bool *ties, data_ptr_t dataptr, - const idx_t &count, const idx_t &tie_col, const SortLayout &sort_layout) { - D_ASSERT(!ties[count - 1]); - auto &blob_block = *sb.blob_sorting_data->data_blocks.back(); - auto blob_handle = buffer_manager.Pin(blob_block.block); - const data_ptr_t blob_ptr = blob_handle.Ptr(); - - for (idx_t i = 0; i < count; i++) { - if (!ties[i]) { - continue; - } - idx_t j; - for (j = i; j < count; j++) { - if (!ties[j]) { - break; - } - } - SortTiedBlobs(buffer_manager, dataptr, i, j + 1, tie_col, ties, blob_ptr, sort_layout); - i = j; - } -} - -//! Returns whether there are any 'true' values in the ties[] array -static bool AnyTies(bool ties[], const idx_t &count) { - D_ASSERT(!ties[count - 1]); - bool any_ties = false; - for (idx_t i = 0; i < count - 1; i++) { - any_ties = any_ties || ties[i]; - } - return any_ties; -} - -//! Compares subsequent rows to check for ties -static void ComputeTies(data_ptr_t dataptr, const idx_t &count, const idx_t &col_offset, const idx_t &tie_size, - bool ties[], const SortLayout &sort_layout) { - D_ASSERT(!ties[count - 1]); - D_ASSERT(col_offset + tie_size <= sort_layout.comparison_size); - // Align dataptr - dataptr += col_offset; - for (idx_t i = 0; i < count - 1; i++) { - ties[i] = ties[i] && FastMemcmp(dataptr, dataptr + sort_layout.entry_size, tie_size) == 0; - dataptr += sort_layout.entry_size; - } -} - -//! 
Textbook LSD radix sort -void RadixSortLSD(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, - const idx_t &row_width, const idx_t &sorting_size) { - auto temp_block = buffer_manager.GetBufferAllocator().Allocate(count * row_width); - bool swap = false; - - idx_t counts[SortConstants::VALUES_PER_RADIX]; - for (idx_t r = 1; r <= sorting_size; r++) { - // Init counts to 0 - memset(counts, 0, sizeof(counts)); - // Const some values for convenience - const data_ptr_t source_ptr = swap ? temp_block.get() : dataptr; - const data_ptr_t target_ptr = swap ? dataptr : temp_block.get(); - const idx_t offset = col_offset + sorting_size - r; - // Collect counts - data_ptr_t offset_ptr = source_ptr + offset; - for (idx_t i = 0; i < count; i++) { - counts[*offset_ptr]++; - offset_ptr += row_width; - } - // Compute offsets from counts - idx_t max_count = counts[0]; - for (idx_t val = 1; val < SortConstants::VALUES_PER_RADIX; val++) { - max_count = MaxValue(max_count, counts[val]); - counts[val] = counts[val] + counts[val - 1]; - } - if (max_count == count) { - continue; - } - // Re-order the data in temporary array - data_ptr_t row_ptr = source_ptr + (count - 1) * row_width; - for (idx_t i = 0; i < count; i++) { - idx_t &radix_offset = --counts[*(row_ptr + offset)]; - FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); - row_ptr -= row_width; - } - swap = !swap; - } - // Move data back to original buffer (if it was swapped) - if (swap) { - memcpy(dataptr, temp_block.get(), count * row_width); - } -} - -//! Insertion sort, used when count of values is low -inline void InsertionSort(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, - const idx_t &col_offset, const idx_t &row_width, const idx_t &total_comp_width, - const idx_t &offset, bool swap) { - const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; - const data_ptr_t target_ptr = swap ? 
orig_ptr : temp_ptr; - if (count > 1) { - const idx_t total_offset = col_offset + offset; - auto temp_val = make_unsafe_uniq_array(row_width); - const data_ptr_t val = temp_val.get(); - const auto comp_width = total_comp_width - offset; - for (idx_t i = 1; i < count; i++) { - FastMemcpy(val, source_ptr + i * row_width, row_width); - idx_t j = i; - while (j > 0 && - FastMemcmp(source_ptr + (j - 1) * row_width + total_offset, val + total_offset, comp_width) > 0) { - FastMemcpy(source_ptr + j * row_width, source_ptr + (j - 1) * row_width, row_width); - j--; - } - FastMemcpy(source_ptr + j * row_width, val, row_width); - } - } - if (swap) { - memcpy(target_ptr, source_ptr, count * row_width); - } -} - -//! MSD radix sort that switches to insertion sort with low bucket sizes -void RadixSortMSD(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, const idx_t &col_offset, - const idx_t &row_width, const idx_t &comp_width, const idx_t &offset, idx_t locations[], bool swap) { - const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; - const data_ptr_t target_ptr = swap ? 
orig_ptr : temp_ptr; - // Init counts to 0 - memset(locations, 0, SortConstants::MSD_RADIX_LOCATIONS * sizeof(idx_t)); - idx_t *counts = locations + 1; - // Collect counts - const idx_t total_offset = col_offset + offset; - data_ptr_t offset_ptr = source_ptr + total_offset; - for (idx_t i = 0; i < count; i++) { - counts[*offset_ptr]++; - offset_ptr += row_width; - } - // Compute locations from counts - idx_t max_count = 0; - for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { - max_count = MaxValue(max_count, counts[radix]); - counts[radix] += locations[radix]; - } - if (max_count != count) { - // Re-order the data in temporary array - data_ptr_t row_ptr = source_ptr; - for (idx_t i = 0; i < count; i++) { - const idx_t &radix_offset = locations[*(row_ptr + total_offset)]++; - FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); - row_ptr += row_width; - } - swap = !swap; - } - // Check if done - if (offset == comp_width - 1) { - if (swap) { - memcpy(orig_ptr, temp_ptr, count * row_width); - } - return; - } - if (max_count == count) { - RadixSortMSD(orig_ptr, temp_ptr, count, col_offset, row_width, comp_width, offset + 1, - locations + SortConstants::MSD_RADIX_LOCATIONS, swap); - return; - } - // Recurse - idx_t radix_count = locations[0]; - for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { - const idx_t loc = (locations[radix] - radix_count) * row_width; - if (radix_count > SortConstants::INSERTION_SORT_THRESHOLD) { - RadixSortMSD(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, - locations + SortConstants::MSD_RADIX_LOCATIONS, swap); - } else if (radix_count != 0) { - InsertionSort(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, - swap); - } - radix_count = locations[radix + 1] - locations[radix]; - } -} - -//! 
Calls different sort functions, depending on the count and sorting sizes -void RadixSort(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, - const idx_t &sorting_size, const SortLayout &sort_layout, bool contains_string) { - if (contains_string) { - auto begin = duckdb_pdqsort::PDQIterator(dataptr, sort_layout.entry_size); - auto end = begin + count; - duckdb_pdqsort::PDQConstants constants(sort_layout.entry_size, col_offset, sorting_size, *end); - duckdb_pdqsort::pdqsort_branchless(begin, begin + count, constants); - } else if (count <= SortConstants::INSERTION_SORT_THRESHOLD) { - InsertionSort(dataptr, nullptr, count, 0, sort_layout.entry_size, sort_layout.comparison_size, 0, false); - } else if (sorting_size <= SortConstants::MSD_RADIX_SORT_SIZE_THRESHOLD) { - RadixSortLSD(buffer_manager, dataptr, count, col_offset, sort_layout.entry_size, sorting_size); - } else { - auto temp_block = buffer_manager.Allocate(MaxValue(count * sort_layout.entry_size, (idx_t)Storage::BLOCK_SIZE)); - auto preallocated_array = make_unsafe_uniq_array(sorting_size * SortConstants::MSD_RADIX_LOCATIONS); - RadixSortMSD(dataptr, temp_block.Ptr(), count, col_offset, sort_layout.entry_size, sorting_size, 0, - preallocated_array.get(), false); - } -} - -//! 
Identifies sequences of rows that are tied, and calls radix sort on these -static void SubSortTiedTuples(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &count, - const idx_t &col_offset, const idx_t &sorting_size, bool ties[], - const SortLayout &sort_layout, bool contains_string) { - D_ASSERT(!ties[count - 1]); - for (idx_t i = 0; i < count; i++) { - if (!ties[i]) { - continue; - } - idx_t j; - for (j = i + 1; j < count; j++) { - if (!ties[j]) { - break; - } - } - RadixSort(buffer_manager, dataptr + i * sort_layout.entry_size, j - i + 1, col_offset, sorting_size, - sort_layout, contains_string); - i = j; - } -} - -void LocalSortState::SortInMemory() { - auto &sb = *sorted_blocks.back(); - auto &block = *sb.radix_sorting_data.back(); - const auto &count = block.count; - auto handle = buffer_manager->Pin(block.block); - const auto dataptr = handle.Ptr(); - // Assign an index to each row - data_ptr_t idx_dataptr = dataptr + sort_layout->comparison_size; - for (uint32_t i = 0; i < count; i++) { - Store(i, idx_dataptr); - idx_dataptr += sort_layout->entry_size; - } - // Radix sort and break ties until no more ties, or until all columns are sorted - idx_t sorting_size = 0; - idx_t col_offset = 0; - unsafe_unique_array ties_ptr; - bool *ties = nullptr; - bool contains_string = false; - for (idx_t i = 0; i < sort_layout->column_count; i++) { - sorting_size += sort_layout->column_sizes[i]; - contains_string = contains_string || sort_layout->logical_types[i].InternalType() == PhysicalType::VARCHAR; - if (sort_layout->constant_size[i] && i < sort_layout->column_count - 1) { - // Add columns to the sorting size until we reach a variable size column, or the last column - continue; - } - - if (!ties) { - // This is the first sort - RadixSort(*buffer_manager, dataptr, count, col_offset, sorting_size, *sort_layout, contains_string); - ties_ptr = make_unsafe_uniq_array(count); - ties = ties_ptr.get(); - std::fill_n(ties, count - 1, true); - ties[count - 1] = 
false; - } else { - // For subsequent sorts, we only have to subsort the tied tuples - SubSortTiedTuples(*buffer_manager, dataptr, count, col_offset, sorting_size, ties, *sort_layout, - contains_string); - } - - contains_string = false; - - if (sort_layout->constant_size[i] && i == sort_layout->column_count - 1) { - // All columns are sorted, no ties to break because last column is constant size - break; - } - - ComputeTies(dataptr, count, col_offset, sorting_size, ties, *sort_layout); - if (!AnyTies(ties, count)) { - // No ties, stop sorting - break; - } - - if (!sort_layout->constant_size[i]) { - SortTiedBlobs(*buffer_manager, sb, ties, dataptr, count, i, *sort_layout); - if (!AnyTies(ties, count)) { - // No more ties after tie-breaking, stop - break; - } - } - - col_offset += sorting_size; - sorting_size = 0; - } -} - -} // namespace duckdb - - - - - - -#include -#include - -namespace duckdb { - -idx_t GetNestedSortingColSize(idx_t &col_size, const LogicalType &type) { - auto physical_type = type.InternalType(); - if (TypeIsConstantSize(physical_type)) { - col_size += GetTypeIdSize(physical_type); - return 0; - } else { - switch (physical_type) { - case PhysicalType::VARCHAR: { - // Nested strings are between 4 and 11 chars long for alignment - auto size_before_str = col_size; - col_size += 11; - col_size -= (col_size - 12) % 8; - return col_size - size_before_str; - } - case PhysicalType::LIST: - // Lists get 2 bytes (null and empty list) - col_size += 2; - return GetNestedSortingColSize(col_size, ListType::GetChildType(type)); - case PhysicalType::STRUCT: - // Structs get 1 bytes (null) - col_size++; - return GetNestedSortingColSize(col_size, StructType::GetChildType(type, 0)); - default: - throw NotImplementedException("Unable to order column with type %s", type.ToString()); - } - } -} - -SortLayout::SortLayout(const vector &orders) - : column_count(orders.size()), all_constant(true), comparison_size(0), entry_size(0) { - vector blob_layout_types; - for 
(idx_t i = 0; i < column_count; i++) { - const auto &order = orders[i]; - - order_types.push_back(order.type); - order_by_null_types.push_back(order.null_order); - auto &expr = *order.expression; - logical_types.push_back(expr.return_type); - - auto physical_type = expr.return_type.InternalType(); - constant_size.push_back(TypeIsConstantSize(physical_type)); - - if (order.stats) { - stats.push_back(order.stats.get()); - has_null.push_back(stats.back()->CanHaveNull()); - } else { - stats.push_back(nullptr); - has_null.push_back(true); - } - - idx_t col_size = has_null.back() ? 1 : 0; - prefix_lengths.push_back(0); - if (!TypeIsConstantSize(physical_type) && physical_type != PhysicalType::VARCHAR) { - prefix_lengths.back() = GetNestedSortingColSize(col_size, expr.return_type); - } else if (physical_type == PhysicalType::VARCHAR) { - idx_t size_before = col_size; - if (stats.back() && StringStats::HasMaxStringLength(*stats.back())) { - col_size += StringStats::MaxStringLength(*stats.back()); - if (col_size > 12) { - col_size = 12; - } else { - constant_size.back() = true; - } - } else { - col_size = 12; - } - prefix_lengths.back() = col_size - size_before; - } else { - col_size += GetTypeIdSize(physical_type); - } - - comparison_size += col_size; - column_sizes.push_back(col_size); - } - entry_size = comparison_size + sizeof(uint32_t); - - // 8-byte alignment - if (entry_size % 8 != 0) { - // First assign more bytes to strings instead of aligning - idx_t bytes_to_fill = 8 - (entry_size % 8); - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - if (bytes_to_fill == 0) { - break; - } - if (logical_types[col_idx].InternalType() == PhysicalType::VARCHAR && stats[col_idx] && - StringStats::HasMaxStringLength(*stats[col_idx])) { - idx_t diff = StringStats::MaxStringLength(*stats[col_idx]) - prefix_lengths[col_idx]; - if (diff > 0) { - // Increase all sizes accordingly - idx_t increase = MinValue(bytes_to_fill, diff); - column_sizes[col_idx] += increase; - 
prefix_lengths[col_idx] += increase; - constant_size[col_idx] = increase == diff; - comparison_size += increase; - entry_size += increase; - bytes_to_fill -= increase; - } - } - } - entry_size = AlignValue(entry_size); - } - - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - all_constant = all_constant && constant_size[col_idx]; - if (!constant_size[col_idx]) { - sorting_to_blob_col[col_idx] = blob_layout_types.size(); - blob_layout_types.push_back(logical_types[col_idx]); - } - } - - blob_layout.Initialize(blob_layout_types); -} - -SortLayout SortLayout::GetPrefixComparisonLayout(idx_t num_prefix_cols) const { - SortLayout result; - result.column_count = num_prefix_cols; - result.all_constant = true; - result.comparison_size = 0; - for (idx_t col_idx = 0; col_idx < num_prefix_cols; col_idx++) { - result.order_types.push_back(order_types[col_idx]); - result.order_by_null_types.push_back(order_by_null_types[col_idx]); - result.logical_types.push_back(logical_types[col_idx]); - - result.all_constant = result.all_constant && constant_size[col_idx]; - result.constant_size.push_back(constant_size[col_idx]); - - result.comparison_size += column_sizes[col_idx]; - result.column_sizes.push_back(column_sizes[col_idx]); - - result.prefix_lengths.push_back(prefix_lengths[col_idx]); - result.stats.push_back(stats[col_idx]); - result.has_null.push_back(has_null[col_idx]); - } - result.entry_size = entry_size; - result.blob_layout = blob_layout; - result.sorting_to_blob_col = sorting_to_blob_col; - return result; -} - -LocalSortState::LocalSortState() : initialized(false) { - if (!Radix::IsLittleEndian()) { - throw NotImplementedException("Sorting is not supported on big endian architectures"); - } -} - -void LocalSortState::Initialize(GlobalSortState &global_sort_state, BufferManager &buffer_manager_p) { - sort_layout = &global_sort_state.sort_layout; - payload_layout = &global_sort_state.payload_layout; - buffer_manager = &buffer_manager_p; - // Radix sorting data 
- radix_sorting_data = make_uniq( - *buffer_manager, RowDataCollection::EntriesPerBlock(sort_layout->entry_size), sort_layout->entry_size); - // Blob sorting data - if (!sort_layout->all_constant) { - auto blob_row_width = sort_layout->blob_layout.GetRowWidth(); - blob_sorting_data = make_uniq( - *buffer_manager, RowDataCollection::EntriesPerBlock(blob_row_width), blob_row_width); - blob_sorting_heap = make_uniq(*buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true); - } - // Payload data - auto payload_row_width = payload_layout->GetRowWidth(); - payload_data = make_uniq(*buffer_manager, RowDataCollection::EntriesPerBlock(payload_row_width), - payload_row_width); - payload_heap = make_uniq(*buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true); - // Init done - initialized = true; -} - -void LocalSortState::SinkChunk(DataChunk &sort, DataChunk &payload) { - D_ASSERT(sort.size() == payload.size()); - // Build and serialize sorting data to radix sortable rows - auto data_pointers = FlatVector::GetData(addresses); - auto handles = radix_sorting_data->Build(sort.size(), data_pointers, nullptr); - for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { - bool has_null = sort_layout->has_null[sort_col]; - bool nulls_first = sort_layout->order_by_null_types[sort_col] == OrderByNullType::NULLS_FIRST; - bool desc = sort_layout->order_types[sort_col] == OrderType::DESCENDING; - RowOperations::RadixScatter(sort.data[sort_col], sort.size(), sel_ptr, sort.size(), data_pointers, desc, - has_null, nulls_first, sort_layout->prefix_lengths[sort_col], - sort_layout->column_sizes[sort_col]); - } - - // Also fully serialize blob sorting columns (to be able to break ties - if (!sort_layout->all_constant) { - DataChunk blob_chunk; - blob_chunk.SetCardinality(sort.size()); - for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { - if (!sort_layout->constant_size[sort_col]) { - blob_chunk.data.emplace_back(sort.data[sort_col]); - } - } - handles = 
blob_sorting_data->Build(blob_chunk.size(), data_pointers, nullptr); - auto blob_data = blob_chunk.ToUnifiedFormat(); - RowOperations::Scatter(blob_chunk, blob_data.get(), sort_layout->blob_layout, addresses, *blob_sorting_heap, - sel_ptr, blob_chunk.size()); - D_ASSERT(blob_sorting_heap->keep_pinned); - } - - // Finally, serialize payload data - handles = payload_data->Build(payload.size(), data_pointers, nullptr); - auto input_data = payload.ToUnifiedFormat(); - RowOperations::Scatter(payload, input_data.get(), *payload_layout, addresses, *payload_heap, sel_ptr, - payload.size()); - D_ASSERT(payload_heap->keep_pinned); -} - -idx_t LocalSortState::SizeInBytes() const { - idx_t size_in_bytes = radix_sorting_data->SizeInBytes() + payload_data->SizeInBytes(); - if (!sort_layout->all_constant) { - size_in_bytes += blob_sorting_data->SizeInBytes() + blob_sorting_heap->SizeInBytes(); - } - if (!payload_layout->AllConstant()) { - size_in_bytes += payload_heap->SizeInBytes(); - } - return size_in_bytes; -} - -void LocalSortState::Sort(GlobalSortState &global_sort_state, bool reorder_heap) { - D_ASSERT(radix_sorting_data->count == payload_data->count); - if (radix_sorting_data->count == 0) { - return; - } - // Move all data to a single SortedBlock - sorted_blocks.emplace_back(make_uniq(*buffer_manager, global_sort_state)); - auto &sb = *sorted_blocks.back(); - // Fixed-size sorting data - auto sorting_block = ConcatenateBlocks(*radix_sorting_data); - sb.radix_sorting_data.push_back(std::move(sorting_block)); - // Variable-size sorting data - if (!sort_layout->all_constant) { - auto &blob_data = *blob_sorting_data; - auto new_block = ConcatenateBlocks(blob_data); - sb.blob_sorting_data->data_blocks.push_back(std::move(new_block)); - } - // Payload data - auto payload_block = ConcatenateBlocks(*payload_data); - sb.payload_data->data_blocks.push_back(std::move(payload_block)); - // Now perform the actual sort - SortInMemory(); - // Re-order before the merge sort - 
ReOrder(global_sort_state, reorder_heap); -} - -unique_ptr LocalSortState::ConcatenateBlocks(RowDataCollection &row_data) { - // Don't copy and delete if there is only one block. - if (row_data.blocks.size() == 1) { - auto new_block = std::move(row_data.blocks[0]); - row_data.blocks.clear(); - row_data.count = 0; - return new_block; - } - // Create block with the correct capacity - auto buffer_manager = &row_data.buffer_manager; - const idx_t &entry_size = row_data.entry_size; - idx_t capacity = MaxValue(((idx_t)Storage::BLOCK_SIZE + entry_size - 1) / entry_size, row_data.count); - auto new_block = make_uniq(*buffer_manager, capacity, entry_size); - new_block->count = row_data.count; - auto new_block_handle = buffer_manager->Pin(new_block->block); - data_ptr_t new_block_ptr = new_block_handle.Ptr(); - // Copy the data of the blocks into a single block - for (idx_t i = 0; i < row_data.blocks.size(); i++) { - auto &block = row_data.blocks[i]; - auto block_handle = buffer_manager->Pin(block->block); - memcpy(new_block_ptr, block_handle.Ptr(), block->count * entry_size); - new_block_ptr += block->count * entry_size; - block.reset(); - } - row_data.blocks.clear(); - row_data.count = 0; - return new_block; -} - -void LocalSortState::ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataCollection &heap, GlobalSortState &gstate, - bool reorder_heap) { - sd.swizzled = reorder_heap; - auto &unordered_data_block = sd.data_blocks.back(); - const idx_t count = unordered_data_block->count; - auto unordered_data_handle = buffer_manager->Pin(unordered_data_block->block); - const data_ptr_t unordered_data_ptr = unordered_data_handle.Ptr(); - // Create new block that will hold re-ordered row data - auto ordered_data_block = - make_uniq(*buffer_manager, unordered_data_block->capacity, unordered_data_block->entry_size); - ordered_data_block->count = count; - auto ordered_data_handle = buffer_manager->Pin(ordered_data_block->block); - data_ptr_t ordered_data_ptr = 
ordered_data_handle.Ptr(); - // Re-order fixed-size row layout - const idx_t row_width = sd.layout.GetRowWidth(); - const idx_t sorting_entry_size = gstate.sort_layout.entry_size; - for (idx_t i = 0; i < count; i++) { - auto index = Load(sorting_ptr); - FastMemcpy(ordered_data_ptr, unordered_data_ptr + index * row_width, row_width); - ordered_data_ptr += row_width; - sorting_ptr += sorting_entry_size; - } - ordered_data_block->block->SetSwizzling( - sd.layout.AllConstant() || !sd.swizzled ? nullptr : "LocalSortState::ReOrder.ordered_data"); - // Replace the unordered data block with the re-ordered data block - sd.data_blocks.clear(); - sd.data_blocks.push_back(std::move(ordered_data_block)); - // Deal with the heap (if necessary) - if (!sd.layout.AllConstant() && reorder_heap) { - // Swizzle the column pointers to offsets - RowOperations::SwizzleColumns(sd.layout, ordered_data_handle.Ptr(), count); - sd.data_blocks.back()->block->SetSwizzling(nullptr); - // Create a single heap block to store the ordered heap - idx_t total_byte_offset = - std::accumulate(heap.blocks.begin(), heap.blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->byte_offset; }); - idx_t heap_block_size = MaxValue(total_byte_offset, (idx_t)Storage::BLOCK_SIZE); - auto ordered_heap_block = make_uniq(*buffer_manager, heap_block_size, 1); - ordered_heap_block->count = count; - ordered_heap_block->byte_offset = total_byte_offset; - auto ordered_heap_handle = buffer_manager->Pin(ordered_heap_block->block); - data_ptr_t ordered_heap_ptr = ordered_heap_handle.Ptr(); - // Fill the heap in order - ordered_data_ptr = ordered_data_handle.Ptr(); - const idx_t heap_pointer_offset = sd.layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - auto heap_row_ptr = Load(ordered_data_ptr + heap_pointer_offset); - auto heap_row_size = Load(heap_row_ptr); - memcpy(ordered_heap_ptr, heap_row_ptr, heap_row_size); - ordered_heap_ptr += heap_row_size; - ordered_data_ptr += row_width; - } - // 
Swizzle the base pointer to the offset of each row in the heap - RowOperations::SwizzleHeapPointer(sd.layout, ordered_data_handle.Ptr(), ordered_heap_handle.Ptr(), count); - // Move the re-ordered heap to the SortedData, and clear the local heap - sd.heap_blocks.push_back(std::move(ordered_heap_block)); - heap.pinned_blocks.clear(); - heap.blocks.clear(); - heap.count = 0; - } -} - -void LocalSortState::ReOrder(GlobalSortState &gstate, bool reorder_heap) { - auto &sb = *sorted_blocks.back(); - auto sorting_handle = buffer_manager->Pin(sb.radix_sorting_data.back()->block); - const data_ptr_t sorting_ptr = sorting_handle.Ptr() + gstate.sort_layout.comparison_size; - // Re-order variable size sorting columns - if (!gstate.sort_layout.all_constant) { - ReOrder(*sb.blob_sorting_data, sorting_ptr, *blob_sorting_heap, gstate, reorder_heap); - } - // And the payload - ReOrder(*sb.payload_data, sorting_ptr, *payload_heap, gstate, reorder_heap); -} - -GlobalSortState::GlobalSortState(BufferManager &buffer_manager, const vector &orders, - RowLayout &payload_layout) - : buffer_manager(buffer_manager), sort_layout(SortLayout(orders)), payload_layout(payload_layout), - block_capacity(0), external(false) { -} - -void GlobalSortState::AddLocalState(LocalSortState &local_sort_state) { - if (!local_sort_state.radix_sorting_data) { - return; - } - - // Sort accumulated data - // we only re-order the heap when the data is expected to not fit in memory - // re-ordering the heap avoids random access when reading/merging but incurs a significant cost of shuffling data - // when data fits in memory, doing random access on reads is cheaper than re-shuffling - local_sort_state.Sort(*this, external || !local_sort_state.sorted_blocks.empty()); - - // Append local state sorted data to this global state - lock_guard append_guard(lock); - for (auto &sb : local_sort_state.sorted_blocks) { - sorted_blocks.push_back(std::move(sb)); - } - auto &payload_heap = local_sort_state.payload_heap; - for 
(idx_t i = 0; i < payload_heap->blocks.size(); i++) { - heap_blocks.push_back(std::move(payload_heap->blocks[i])); - pinned_blocks.push_back(std::move(payload_heap->pinned_blocks[i])); - } - if (!sort_layout.all_constant) { - auto &blob_heap = local_sort_state.blob_sorting_heap; - for (idx_t i = 0; i < blob_heap->blocks.size(); i++) { - heap_blocks.push_back(std::move(blob_heap->blocks[i])); - pinned_blocks.push_back(std::move(blob_heap->pinned_blocks[i])); - } - } -} - -void GlobalSortState::PrepareMergePhase() { - // Determine if we need to use do an external sort - idx_t total_heap_size = - std::accumulate(sorted_blocks.begin(), sorted_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->HeapSize(); }); - if (external || (pinned_blocks.empty() && total_heap_size > 0.25 * buffer_manager.GetMaxMemory())) { - external = true; - } - // Use the data that we have to determine which partition size to use during the merge - if (external && total_heap_size > 0) { - // If we have variable size data we need to be conservative, as there might be skew - idx_t max_block_size = 0; - for (auto &sb : sorted_blocks) { - idx_t size_in_bytes = sb->SizeInBytes(); - if (size_in_bytes > max_block_size) { - max_block_size = size_in_bytes; - block_capacity = sb->Count(); - } - } - } else { - for (auto &sb : sorted_blocks) { - block_capacity = MaxValue(block_capacity, sb->Count()); - } - } - // Unswizzle and pin heap blocks if we can fit everything in memory - if (!external) { - for (auto &sb : sorted_blocks) { - sb->blob_sorting_data->Unswizzle(); - sb->payload_data->Unswizzle(); - } - } -} - -void GlobalSortState::InitializeMergeRound() { - D_ASSERT(sorted_blocks_temp.empty()); - // If we reverse this list, the blocks that were merged last will be merged first in the next round - // These are still in memory, therefore this reduces the amount of read/write to disk! 
- std::reverse(sorted_blocks.begin(), sorted_blocks.end()); - // Uneven number of blocks - keep one on the side - if (sorted_blocks.size() % 2 == 1) { - odd_one_out = std::move(sorted_blocks.back()); - sorted_blocks.pop_back(); - } - // Init merge path path indices - pair_idx = 0; - num_pairs = sorted_blocks.size() / 2; - l_start = 0; - r_start = 0; - // Allocate room for merge results - for (idx_t p_idx = 0; p_idx < num_pairs; p_idx++) { - sorted_blocks_temp.emplace_back(); - } -} - -void GlobalSortState::CompleteMergeRound(bool keep_radix_data) { - sorted_blocks.clear(); - for (auto &sorted_block_vector : sorted_blocks_temp) { - sorted_blocks.push_back(make_uniq(buffer_manager, *this)); - sorted_blocks.back()->AppendSortedBlocks(sorted_block_vector); - } - sorted_blocks_temp.clear(); - if (odd_one_out) { - sorted_blocks.push_back(std::move(odd_one_out)); - odd_one_out = nullptr; - } - // Only one block left: Done! - if (sorted_blocks.size() == 1 && !keep_radix_data) { - sorted_blocks[0]->radix_sorting_data.clear(); - sorted_blocks[0]->blob_sorting_data = nullptr; - } -} -void GlobalSortState::Print() { - PayloadScanner scanner(*this, false); - DataChunk chunk; - chunk.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); - for (;;) { - scanner.Scan(chunk); - const auto count = chunk.size(); - if (!count) { - break; - } - chunk.Print(); - } -} - -} // namespace duckdb - - - - - - - -#include - -namespace duckdb { - -SortedData::SortedData(SortedDataType type, const RowLayout &layout, BufferManager &buffer_manager, - GlobalSortState &state) - : type(type), layout(layout), swizzled(state.external), buffer_manager(buffer_manager), state(state) { -} - -idx_t SortedData::Count() { - idx_t count = std::accumulate(data_blocks.begin(), data_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; }); - if (!layout.AllConstant() && state.external) { - D_ASSERT(count == std::accumulate(heap_blocks.begin(), heap_blocks.end(), 
(idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; })); - } - return count; -} - -void SortedData::CreateBlock() { - auto capacity = - MaxValue(((idx_t)Storage::BLOCK_SIZE + layout.GetRowWidth() - 1) / layout.GetRowWidth(), state.block_capacity); - data_blocks.push_back(make_uniq(buffer_manager, capacity, layout.GetRowWidth())); - if (!layout.AllConstant() && state.external) { - heap_blocks.push_back(make_uniq(buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1)); - D_ASSERT(data_blocks.size() == heap_blocks.size()); - } -} - -unique_ptr SortedData::CreateSlice(idx_t start_block_index, idx_t end_block_index, idx_t end_entry_index) { - // Add the corresponding blocks to the result - auto result = make_uniq(type, layout, buffer_manager, state); - for (idx_t i = start_block_index; i <= end_block_index; i++) { - result->data_blocks.push_back(data_blocks[i]->Copy()); - if (!layout.AllConstant() && state.external) { - result->heap_blocks.push_back(heap_blocks[i]->Copy()); - } - } - // All of the blocks that come before block with idx = start_block_idx can be reset (other references exist) - for (idx_t i = 0; i < start_block_index; i++) { - data_blocks[i]->block = nullptr; - if (!layout.AllConstant() && state.external) { - heap_blocks[i]->block = nullptr; - } - } - // Use start and end entry indices to set the boundaries - D_ASSERT(end_entry_index <= result->data_blocks.back()->count); - result->data_blocks.back()->count = end_entry_index; - if (!layout.AllConstant() && state.external) { - result->heap_blocks.back()->count = end_entry_index; - } - return result; -} - -void SortedData::Unswizzle() { - if (layout.AllConstant() || !swizzled) { - return; - } - for (idx_t i = 0; i < data_blocks.size(); i++) { - auto &data_block = data_blocks[i]; - auto &heap_block = heap_blocks[i]; - D_ASSERT(data_block->block->IsSwizzled()); - auto data_handle_p = buffer_manager.Pin(data_block->block); - auto heap_handle_p = buffer_manager.Pin(heap_block->block); - 
RowOperations::UnswizzlePointers(layout, data_handle_p.Ptr(), heap_handle_p.Ptr(), data_block->count); - state.heap_blocks.push_back(std::move(heap_block)); - state.pinned_blocks.push_back(std::move(heap_handle_p)); - } - swizzled = false; - heap_blocks.clear(); -} - -SortedBlock::SortedBlock(BufferManager &buffer_manager, GlobalSortState &state) - : buffer_manager(buffer_manager), state(state), sort_layout(state.sort_layout), - payload_layout(state.payload_layout) { - blob_sorting_data = make_uniq(SortedDataType::BLOB, sort_layout.blob_layout, buffer_manager, state); - payload_data = make_uniq(SortedDataType::PAYLOAD, payload_layout, buffer_manager, state); -} - -idx_t SortedBlock::Count() const { - idx_t count = std::accumulate(radix_sorting_data.begin(), radix_sorting_data.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; }); - if (!sort_layout.all_constant) { - D_ASSERT(count == blob_sorting_data->Count()); - } - D_ASSERT(count == payload_data->Count()); - return count; -} - -void SortedBlock::InitializeWrite() { - CreateBlock(); - if (!sort_layout.all_constant) { - blob_sorting_data->CreateBlock(); - } - payload_data->CreateBlock(); -} - -void SortedBlock::CreateBlock() { - auto capacity = MaxValue(((idx_t)Storage::BLOCK_SIZE + sort_layout.entry_size - 1) / sort_layout.entry_size, - state.block_capacity); - radix_sorting_data.push_back(make_uniq(buffer_manager, capacity, sort_layout.entry_size)); -} - -void SortedBlock::AppendSortedBlocks(vector> &sorted_blocks) { - D_ASSERT(Count() == 0); - for (auto &sb : sorted_blocks) { - for (auto &radix_block : sb->radix_sorting_data) { - radix_sorting_data.push_back(std::move(radix_block)); - } - if (!sort_layout.all_constant) { - for (auto &blob_block : sb->blob_sorting_data->data_blocks) { - blob_sorting_data->data_blocks.push_back(std::move(blob_block)); - } - for (auto &heap_block : sb->blob_sorting_data->heap_blocks) { - blob_sorting_data->heap_blocks.push_back(std::move(heap_block)); - } - 
} - for (auto &payload_data_block : sb->payload_data->data_blocks) { - payload_data->data_blocks.push_back(std::move(payload_data_block)); - } - if (!payload_data->layout.AllConstant()) { - for (auto &payload_heap_block : sb->payload_data->heap_blocks) { - payload_data->heap_blocks.push_back(std::move(payload_heap_block)); - } - } - } -} - -void SortedBlock::GlobalToLocalIndex(const idx_t &global_idx, idx_t &local_block_index, idx_t &local_entry_index) { - if (global_idx == Count()) { - local_block_index = radix_sorting_data.size() - 1; - local_entry_index = radix_sorting_data.back()->count; - return; - } - D_ASSERT(global_idx < Count()); - local_entry_index = global_idx; - for (local_block_index = 0; local_block_index < radix_sorting_data.size(); local_block_index++) { - const idx_t &block_count = radix_sorting_data[local_block_index]->count; - if (local_entry_index >= block_count) { - local_entry_index -= block_count; - } else { - break; - } - } - D_ASSERT(local_entry_index < radix_sorting_data[local_block_index]->count); -} - -unique_ptr SortedBlock::CreateSlice(const idx_t start, const idx_t end, idx_t &entry_idx) { - // Identify blocks/entry indices of this slice - idx_t start_block_index; - idx_t start_entry_index; - GlobalToLocalIndex(start, start_block_index, start_entry_index); - idx_t end_block_index; - idx_t end_entry_index; - GlobalToLocalIndex(end, end_block_index, end_entry_index); - // Add the corresponding blocks to the result - auto result = make_uniq(buffer_manager, state); - for (idx_t i = start_block_index; i <= end_block_index; i++) { - result->radix_sorting_data.push_back(radix_sorting_data[i]->Copy()); - } - // Reset all blocks that come before block with idx = start_block_idx (slice holds new reference) - for (idx_t i = 0; i < start_block_index; i++) { - radix_sorting_data[i]->block = nullptr; - } - // Use start and end entry indices to set the boundaries - entry_idx = start_entry_index; - D_ASSERT(end_entry_index <= 
result->radix_sorting_data.back()->count); - result->radix_sorting_data.back()->count = end_entry_index; - // Same for the var size sorting data - if (!sort_layout.all_constant) { - result->blob_sorting_data = blob_sorting_data->CreateSlice(start_block_index, end_block_index, end_entry_index); - } - // And the payload data - result->payload_data = payload_data->CreateSlice(start_block_index, end_block_index, end_entry_index); - return result; -} - -idx_t SortedBlock::HeapSize() const { - idx_t result = 0; - if (!sort_layout.all_constant) { - for (auto &block : blob_sorting_data->heap_blocks) { - result += block->capacity; - } - } - if (!payload_layout.AllConstant()) { - for (auto &block : payload_data->heap_blocks) { - result += block->capacity; - } - } - return result; -} - -idx_t SortedBlock::SizeInBytes() const { - idx_t bytes = 0; - for (idx_t i = 0; i < radix_sorting_data.size(); i++) { - bytes += radix_sorting_data[i]->capacity * sort_layout.entry_size; - if (!sort_layout.all_constant) { - bytes += blob_sorting_data->data_blocks[i]->capacity * sort_layout.blob_layout.GetRowWidth(); - bytes += blob_sorting_data->heap_blocks[i]->capacity; - } - bytes += payload_data->data_blocks[i]->capacity * payload_layout.GetRowWidth(); - if (!payload_layout.AllConstant()) { - bytes += payload_data->heap_blocks[i]->capacity; - } - } - return bytes; -} - -SBScanState::SBScanState(BufferManager &buffer_manager, GlobalSortState &state) - : buffer_manager(buffer_manager), sort_layout(state.sort_layout), state(state), block_idx(0), entry_idx(0) { -} - -void SBScanState::PinRadix(idx_t block_idx_to) { - auto &radix_sorting_data = sb->radix_sorting_data; - D_ASSERT(block_idx_to < radix_sorting_data.size()); - auto &block = radix_sorting_data[block_idx_to]; - if (!radix_handle.IsValid() || radix_handle.GetBlockHandle() != block->block) { - radix_handle = buffer_manager.Pin(block->block); - } -} - -void SBScanState::PinData(SortedData &sd) { - D_ASSERT(block_idx < 
sd.data_blocks.size()); - auto &data_handle = sd.type == SortedDataType::BLOB ? blob_sorting_data_handle : payload_data_handle; - auto &heap_handle = sd.type == SortedDataType::BLOB ? blob_sorting_heap_handle : payload_heap_handle; - - auto &data_block = sd.data_blocks[block_idx]; - if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { - data_handle = buffer_manager.Pin(data_block->block); - } - if (sd.layout.AllConstant() || !state.external) { - return; - } - auto &heap_block = sd.heap_blocks[block_idx]; - if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { - heap_handle = buffer_manager.Pin(heap_block->block); - } -} - -data_ptr_t SBScanState::RadixPtr() const { - return radix_handle.Ptr() + entry_idx * sort_layout.entry_size; -} - -data_ptr_t SBScanState::DataPtr(SortedData &sd) const { - auto &data_handle = sd.type == SortedDataType::BLOB ? blob_sorting_data_handle : payload_data_handle; - D_ASSERT(sd.data_blocks[block_idx]->block->Readers() != 0 && - data_handle.GetBlockHandle() == sd.data_blocks[block_idx]->block); - return data_handle.Ptr() + entry_idx * sd.layout.GetRowWidth(); -} - -data_ptr_t SBScanState::HeapPtr(SortedData &sd) const { - return BaseHeapPtr(sd) + Load(DataPtr(sd) + sd.layout.GetHeapOffset()); -} - -data_ptr_t SBScanState::BaseHeapPtr(SortedData &sd) const { - auto &heap_handle = sd.type == SortedDataType::BLOB ? 
blob_sorting_heap_handle : payload_heap_handle; - D_ASSERT(!sd.layout.AllConstant() && state.external); - D_ASSERT(sd.heap_blocks[block_idx]->block->Readers() != 0 && - heap_handle.GetBlockHandle() == sd.heap_blocks[block_idx]->block); - return heap_handle.Ptr(); -} - -idx_t SBScanState::Remaining() const { - const auto &blocks = sb->radix_sorting_data; - idx_t remaining = 0; - if (block_idx < blocks.size()) { - remaining += blocks[block_idx]->count - entry_idx; - for (idx_t i = block_idx + 1; i < blocks.size(); i++) { - remaining += blocks[i]->count; - } - } - return remaining; -} - -void SBScanState::SetIndices(idx_t block_idx_to, idx_t entry_idx_to) { - block_idx = block_idx_to; - entry_idx = entry_idx_to; -} - -PayloadScanner::PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush_p) { - auto count = sorted_data.Count(); - auto &layout = sorted_data.layout; - - // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1); - rows->count = count; - - heap = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1); - if (!sorted_data.layout.AllConstant()) { - heap->count = count; - } - - if (flush_p) { - // If we are flushing, we can just move the data - rows->blocks = std::move(sorted_data.data_blocks); - if (!layout.AllConstant()) { - heap->blocks = std::move(sorted_data.heap_blocks); - } - } else { - // Not flushing, create references to the blocks - for (auto &block : sorted_data.data_blocks) { - rows->blocks.emplace_back(block->Copy()); - } - if (!layout.AllConstant()) { - for (auto &block : sorted_data.heap_blocks) { - heap->blocks.emplace_back(block->Copy()); - } - } - } - - scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); -} - -PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, bool flush_p) - : 
PayloadScanner(*global_sort_state.sorted_blocks[0]->payload_data, global_sort_state, flush_p) { -} - -PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, idx_t block_idx, bool flush_p) { - auto &sorted_data = *global_sort_state.sorted_blocks[0]->payload_data; - auto count = sorted_data.data_blocks[block_idx]->count; - auto &layout = sorted_data.layout; - - // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1); - if (flush_p) { - rows->blocks.emplace_back(std::move(sorted_data.data_blocks[block_idx])); - } else { - rows->blocks.emplace_back(sorted_data.data_blocks[block_idx]->Copy()); - } - rows->count = count; - - heap = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1); - if (!sorted_data.layout.AllConstant() && sorted_data.swizzled) { - if (flush_p) { - heap->blocks.emplace_back(std::move(sorted_data.heap_blocks[block_idx])); - } else { - heap->blocks.emplace_back(sorted_data.heap_blocks[block_idx]->Copy()); - } - heap->count = count; - } - - scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); -} - -void PayloadScanner::Scan(DataChunk &chunk) { - scanner->Scan(chunk); -} - -int SBIterator::ComparisonValue(ExpressionType comparison) { - switch (comparison) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_GREATERTHAN: - return -1; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return 0; - default: - throw InternalException("Unimplemented comparison type for IEJoin!"); - } -} - -static idx_t GetBlockCountWithEmptyCheck(const GlobalSortState &gss) { - D_ASSERT(!gss.sorted_blocks.empty()); - return gss.sorted_blocks[0]->radix_sorting_data.size(); -} - -SBIterator::SBIterator(GlobalSortState &gss, ExpressionType comparison, idx_t entry_idx_p) - : sort_layout(gss.sort_layout), 
block_count(GetBlockCountWithEmptyCheck(gss)), block_capacity(gss.block_capacity), - cmp_size(sort_layout.comparison_size), entry_size(sort_layout.entry_size), all_constant(sort_layout.all_constant), - external(gss.external), cmp(ComparisonValue(comparison)), scan(gss.buffer_manager, gss), block_ptr(nullptr), - entry_ptr(nullptr) { - - scan.sb = gss.sorted_blocks[0].get(); - scan.block_idx = block_count; - SetIndex(entry_idx_p); -} - -} // namespace duckdb - - - - - - - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace duckdb { - -string StringUtil::GenerateRandomName(idx_t length) { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(0, 15); - - std::stringstream ss; - ss << std::hex; - for (idx_t i = 0; i < length; i++) { - ss << dis(gen); - } - return ss.str(); -} - -bool StringUtil::Contains(const string &haystack, const string &needle) { - return (haystack.find(needle) != string::npos); -} - -void StringUtil::LTrim(string &str) { - auto it = str.begin(); - while (it != str.end() && CharacterIsSpace(*it)) { - it++; - } - str.erase(str.begin(), it); -} - -// Remove trailing ' ', '\f', '\n', '\r', '\t', '\v' -void StringUtil::RTrim(string &str) { - str.erase(find_if(str.rbegin(), str.rend(), [](int ch) { return ch > 0 && !CharacterIsSpace(ch); }).base(), - str.end()); -} - -void StringUtil::RTrim(string &str, const string &chars_to_trim) { - str.erase(find_if(str.rbegin(), str.rend(), - [&chars_to_trim](int ch) { return ch > 0 && chars_to_trim.find(ch) == string::npos; }) - .base(), - str.end()); -} - -void StringUtil::Trim(string &str) { - StringUtil::LTrim(str); - StringUtil::RTrim(str); -} - -bool StringUtil::StartsWith(string str, string prefix) { - if (prefix.size() > str.size()) { - return false; - } - return equal(prefix.begin(), prefix.end(), str.begin()); -} - -bool StringUtil::EndsWith(const string &str, const string &suffix) { - if (suffix.size() > str.size()) { - return 
false; - } - return equal(suffix.rbegin(), suffix.rend(), str.rbegin()); -} - -string StringUtil::Repeat(const string &str, idx_t n) { - std::ostringstream os; - for (idx_t i = 0; i < n; i++) { - os << str; - } - return (os.str()); -} - -vector StringUtil::Split(const string &str, char delimiter) { - std::stringstream ss(str); - vector lines; - string temp; - while (getline(ss, temp, delimiter)) { - lines.push_back(temp); - } - return (lines); -} - -namespace string_util_internal { - -inline void SkipSpaces(const string &str, idx_t &index) { - while (index < str.size() && std::isspace(str[index])) { - index++; - } -} - -inline void ConsumeLetter(const string &str, idx_t &index, char expected) { - if (index >= str.size() || str[index] != expected) { - throw ParserException("Invalid quoted list: %s", str); - } - - index++; -} - -template -inline void TakeWhile(const string &str, idx_t &index, const F &cond, string &taker) { - while (index < str.size() && cond(str[index])) { - taker.push_back(str[index]); - index++; - } -} - -inline string TakePossiblyQuotedItem(const string &str, idx_t &index, char delimiter, char quote) { - string entry; - - if (str[index] == quote) { - index++; - TakeWhile( - str, index, [quote](char c) { return c != quote; }, entry); - ConsumeLetter(str, index, quote); - } else { - TakeWhile( - str, index, [delimiter, quote](char c) { return c != delimiter && c != quote && !std::isspace(c); }, entry); - } - - return entry; -} - -} // namespace string_util_internal - -vector StringUtil::SplitWithQuote(const string &str, char delimiter, char quote) { - vector entries; - idx_t i = 0; - - string_util_internal::SkipSpaces(str, i); - while (i < str.size()) { - if (!entries.empty()) { - string_util_internal::ConsumeLetter(str, i, delimiter); - } - - entries.emplace_back(string_util_internal::TakePossiblyQuotedItem(str, i, delimiter, quote)); - string_util_internal::SkipSpaces(str, i); - } - - return entries; -} - -string StringUtil::Join(const vector 
&input, const string &separator) { - return StringUtil::Join(input, input.size(), separator, [](const string &s) { return s; }); -} - -string StringUtil::BytesToHumanReadableString(idx_t bytes) { - string db_size; - auto kilobytes = bytes / 1000; - auto megabytes = kilobytes / 1000; - kilobytes -= megabytes * 1000; - auto gigabytes = megabytes / 1000; - megabytes -= gigabytes * 1000; - auto terabytes = gigabytes / 1000; - gigabytes -= terabytes * 1000; - auto petabytes = terabytes / 1000; - terabytes -= petabytes * 1000; - if (petabytes > 0) { - return to_string(petabytes) + "." + to_string(terabytes / 100) + "PB"; - } - if (terabytes > 0) { - return to_string(terabytes) + "." + to_string(gigabytes / 100) + "TB"; - } else if (gigabytes > 0) { - return to_string(gigabytes) + "." + to_string(megabytes / 100) + "GB"; - } else if (megabytes > 0) { - return to_string(megabytes) + "." + to_string(kilobytes / 100) + "MB"; - } else if (kilobytes > 0) { - return to_string(kilobytes) + "KB"; - } else { - return to_string(bytes) + (bytes == 1 ? 
" byte" : " bytes"); - } -} - -string StringUtil::Upper(const string &str) { - string copy(str); - transform(copy.begin(), copy.end(), copy.begin(), [](unsigned char c) { return std::toupper(c); }); - return (copy); -} - -string StringUtil::Lower(const string &str) { - string copy(str); - transform(copy.begin(), copy.end(), copy.begin(), [](unsigned char c) { return StringUtil::CharacterToLower(c); }); - return (copy); -} - -bool StringUtil::IsLower(const string &str) { - return str == Lower(str); -} - -// Jenkins hash function: https://en.wikipedia.org/wiki/Jenkins_hash_function -uint64_t StringUtil::CIHash(const string &str) { - uint32_t hash = 0; - for (auto c : str) { - hash += StringUtil::CharacterToLower(c); - hash += hash << 10; - hash ^= hash >> 6; - } - hash += hash << 3; - hash ^= hash >> 11; - hash += hash << 15; - return hash; -} - -bool StringUtil::CIEquals(const string &l1, const string &l2) { - if (l1.size() != l2.size()) { - return false; - } - for (idx_t c = 0; c < l1.size(); c++) { - if (StringUtil::CharacterToLower(l1[c]) != StringUtil::CharacterToLower(l2[c])) { - return false; - } - } - return true; -} - -vector StringUtil::Split(const string &input, const string &split) { - vector splits; - - idx_t last = 0; - idx_t input_len = input.size(); - idx_t split_len = split.size(); - while (last <= input_len) { - idx_t next = input.find(split, last); - if (next == string::npos) { - next = input_len; - } - - // Push the substring [last, next) on to splits - string substr = input.substr(last, next - last); - if (!substr.empty()) { - splits.push_back(substr); - } - last = next + split_len; - } - if (splits.empty()) { - splits.push_back(input); - } - return splits; -} - -string StringUtil::Replace(string source, const string &from, const string &to) { - if (from.empty()) { - throw InternalException("Invalid argument to StringUtil::Replace - empty FROM"); - } - idx_t start_pos = 0; - while ((start_pos = source.find(from, start_pos)) != string::npos) { - 
source.replace(start_pos, from.length(), to); - start_pos += to.length(); // In case 'to' contains 'from', like - // replacing 'x' with 'yx' - } - return source; -} - -vector StringUtil::TopNStrings(vector> scores, idx_t n, idx_t threshold) { - if (scores.empty()) { - return vector(); - } - sort(scores.begin(), scores.end(), [](const pair &a, const pair &b) -> bool { - return a.second < b.second || (a.second == b.second && a.first.size() < b.first.size()); - }); - vector result; - result.push_back(scores[0].first); - for (idx_t i = 1; i < MinValue(scores.size(), n); i++) { - if (scores[i].second > threshold) { - break; - } - result.push_back(scores[i].first); - } - return result; -} - -struct LevenshteinArray { - LevenshteinArray(idx_t len1, idx_t len2) : len1(len1) { - dist = make_unsafe_uniq_array(len1 * len2); - } - - idx_t &Score(idx_t i, idx_t j) { - return dist[GetIndex(i, j)]; - } - -private: - idx_t len1; - unsafe_unique_array dist; - - idx_t GetIndex(idx_t i, idx_t j) { - return j * len1 + i; - } -}; - -// adapted from https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance#C++ -idx_t StringUtil::LevenshteinDistance(const string &s1_p, const string &s2_p, idx_t not_equal_penalty) { - auto s1 = StringUtil::Lower(s1_p); - auto s2 = StringUtil::Lower(s2_p); - idx_t len1 = s1.size(); - idx_t len2 = s2.size(); - if (len1 == 0) { - return len2; - } - if (len2 == 0) { - return len1; - } - LevenshteinArray array(len1 + 1, len2 + 1); - array.Score(0, 0) = 0; - for (idx_t i = 0; i <= len1; i++) { - array.Score(i, 0) = i; - } - for (idx_t j = 0; j <= len2; j++) { - array.Score(0, j) = j; - } - for (idx_t i = 1; i <= len1; i++) { - for (idx_t j = 1; j <= len2; j++) { - // d[i][j] = std::min({ d[i - 1][j] + 1, - // d[i][j - 1] + 1, - // d[i - 1][j - 1] + (s1[i - 1] == s2[j - 1] ? 0 : 1) }); - int equal = s1[i - 1] == s2[j - 1] ? 
0 : not_equal_penalty; - idx_t adjacent_score1 = array.Score(i - 1, j) + 1; - idx_t adjacent_score2 = array.Score(i, j - 1) + 1; - idx_t adjacent_score3 = array.Score(i - 1, j - 1) + equal; - - idx_t t = MinValue(adjacent_score1, adjacent_score2); - array.Score(i, j) = MinValue(t, adjacent_score3); - } - } - return array.Score(len1, len2); -} - -idx_t StringUtil::SimilarityScore(const string &s1, const string &s2) { - return LevenshteinDistance(s1, s2, 3); -} - -vector StringUtil::TopNLevenshtein(const vector &strings, const string &target, idx_t n, - idx_t threshold) { - vector> scores; - scores.reserve(strings.size()); - for (auto &str : strings) { - if (target.size() < str.size()) { - scores.emplace_back(str, SimilarityScore(str.substr(0, target.size()), target)); - } else { - scores.emplace_back(str, SimilarityScore(str, target)); - } - } - return TopNStrings(scores, n, threshold); -} - -string StringUtil::CandidatesMessage(const vector &candidates, const string &candidate) { - string result_str; - if (!candidates.empty()) { - result_str = "\n" + candidate + ": "; - for (idx_t i = 0; i < candidates.size(); i++) { - if (i > 0) { - result_str += ", "; - } - result_str += "\"" + candidates[i] + "\""; - } - } - return result_str; -} - -string StringUtil::CandidatesErrorMessage(const vector &strings, const string &target, - const string &message_prefix, idx_t n) { - auto closest_strings = StringUtil::TopNLevenshtein(strings, target, n); - return StringUtil::CandidatesMessage(closest_strings, message_prefix); -} - -} // namespace duckdb - - - - - - - - - - - -#include - -namespace duckdb { - -RenderTree::RenderTree(idx_t width_p, idx_t height_p) : width(width_p), height(height_p) { - nodes = unique_ptr[]>(new unique_ptr[(width + 1) * (height + 1)]); -} - -RenderTreeNode *RenderTree::GetNode(idx_t x, idx_t y) { - if (x >= width || y >= height) { - return nullptr; - } - return nodes[GetPosition(x, y)].get(); -} - -bool RenderTree::HasNode(idx_t x, idx_t y) { - if (x >= 
width || y >= height) { - return false; - } - return nodes[GetPosition(x, y)].get() != nullptr; -} - -idx_t RenderTree::GetPosition(idx_t x, idx_t y) { - return y * width + x; -} - -void RenderTree::SetNode(idx_t x, idx_t y, unique_ptr node) { - nodes[GetPosition(x, y)] = std::move(node); -} - -void TreeRenderer::RenderTopLayer(RenderTree &root, std::ostream &ss, idx_t y) { - for (idx_t x = 0; x < root.width; x++) { - if (x * config.NODE_RENDER_WIDTH >= config.MAXIMUM_RENDER_WIDTH) { - break; - } - if (root.HasNode(x, y)) { - ss << config.LTCORNER; - ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH / 2 - 1); - if (y == 0) { - // top level node: no node above this one - ss << config.HORIZONTAL; - } else { - // render connection to node above this one - ss << config.DMIDDLE; - } - ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH / 2 - 1); - ss << config.RTCORNER; - } else { - ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH); - } - } - ss << std::endl; -} - -void TreeRenderer::RenderBottomLayer(RenderTree &root, std::ostream &ss, idx_t y) { - for (idx_t x = 0; x <= root.width; x++) { - if (x * config.NODE_RENDER_WIDTH >= config.MAXIMUM_RENDER_WIDTH) { - break; - } - if (root.HasNode(x, y)) { - ss << config.LDCORNER; - ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH / 2 - 1); - if (root.HasNode(x, y + 1)) { - // node below this one: connect to that one - ss << config.TMIDDLE; - } else { - // no node below this one: end the box - ss << config.HORIZONTAL; - } - ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH / 2 - 1); - ss << config.RDCORNER; - } else if (root.HasNode(x, y + 1)) { - ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH / 2); - ss << config.VERTICAL; - ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH / 2); - } else { - ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH); - } - } - ss << std::endl; -} - -string AdjustTextForRendering(string source, idx_t 
max_render_width) { - idx_t cpos = 0; - idx_t render_width = 0; - vector> render_widths; - while (cpos < source.size()) { - idx_t char_render_width = Utf8Proc::RenderWidth(source.c_str(), source.size(), cpos); - cpos = Utf8Proc::NextGraphemeCluster(source.c_str(), source.size(), cpos); - render_width += char_render_width; - render_widths.emplace_back(cpos, render_width); - if (render_width > max_render_width) { - break; - } - } - if (render_width > max_render_width) { - // need to find a position to truncate - for (idx_t pos = render_widths.size(); pos > 0; pos--) { - if (render_widths[pos - 1].second < max_render_width - 4) { - return source.substr(0, render_widths[pos - 1].first) + "..." + - string(max_render_width - render_widths[pos - 1].second - 3, ' '); - } - } - source = "..."; - } - // need to pad with spaces - idx_t total_spaces = max_render_width - render_width; - idx_t half_spaces = total_spaces / 2; - idx_t extra_left_space = total_spaces % 2 == 0 ? 0 : 1; - return string(half_spaces + extra_left_space, ' ') + source + string(half_spaces, ' '); -} - -static bool NodeHasMultipleChildren(RenderTree &root, idx_t x, idx_t y) { - for (; x < root.width && !root.HasNode(x + 1, y); x++) { - if (root.HasNode(x + 1, y + 1)) { - return true; - } - } - return false; -} - -void TreeRenderer::RenderBoxContent(RenderTree &root, std::ostream &ss, idx_t y) { - // we first need to figure out how high our boxes are going to be - vector> extra_info; - idx_t extra_height = 0; - extra_info.resize(root.width); - for (idx_t x = 0; x < root.width; x++) { - auto node = root.GetNode(x, y); - if (node) { - SplitUpExtraInfo(node->extra_text, extra_info[x]); - if (extra_info[x].size() > extra_height) { - extra_height = extra_info[x].size(); - } - } - } - extra_height = MinValue(extra_height, config.MAX_EXTRA_LINES); - idx_t halfway_point = (extra_height + 1) / 2; - // now we render the actual node - for (idx_t render_y = 0; render_y <= extra_height; render_y++) { - for (idx_t x = 0; 
x < root.width; x++) { - if (x * config.NODE_RENDER_WIDTH >= config.MAXIMUM_RENDER_WIDTH) { - break; - } - auto node = root.GetNode(x, y); - if (!node) { - if (render_y == halfway_point) { - bool has_child_to_the_right = NodeHasMultipleChildren(root, x, y); - if (root.HasNode(x, y + 1)) { - // node right below this one - ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH / 2); - ss << config.RTCORNER; - if (has_child_to_the_right) { - // but we have another child to the right! keep rendering the line - ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH / 2); - } else { - // only a child below this one: fill the rest with spaces - ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH / 2); - } - } else if (has_child_to_the_right) { - // child to the right, but no child right below this one: render a full line - ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH); - } else { - // empty spot: render spaces - ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH); - } - } else if (render_y >= halfway_point) { - if (root.HasNode(x, y + 1)) { - // we have a node below this empty spot: render a vertical line - ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH / 2); - ss << config.VERTICAL; - ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH / 2); - } else { - // empty spot: render spaces - ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH); - } - } else { - // empty spot: render spaces - ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH); - } - } else { - ss << config.VERTICAL; - // figure out what to render - string render_text; - if (render_y == 0) { - render_text = node->name; - } else { - if (render_y <= extra_info[x].size()) { - render_text = extra_info[x][render_y - 1]; - } - } - render_text = AdjustTextForRendering(render_text, config.NODE_RENDER_WIDTH - 2); - ss << render_text; - - if (render_y == halfway_point && NodeHasMultipleChildren(root, x, y)) { - ss << config.LMIDDLE; - } else { - 
ss << config.VERTICAL; - } - } - } - ss << std::endl; - } -} - -string TreeRenderer::ToString(const LogicalOperator &op) { - std::stringstream ss; - Render(op, ss); - return ss.str(); -} - -string TreeRenderer::ToString(const PhysicalOperator &op) { - std::stringstream ss; - Render(op, ss); - return ss.str(); -} - -string TreeRenderer::ToString(const QueryProfiler::TreeNode &op) { - std::stringstream ss; - Render(op, ss); - return ss.str(); -} - -string TreeRenderer::ToString(const Pipeline &op) { - std::stringstream ss; - Render(op, ss); - return ss.str(); -} - -void TreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) { - auto tree = CreateTree(op); - ToStream(*tree, ss); -} - -void TreeRenderer::Render(const PhysicalOperator &op, std::ostream &ss) { - auto tree = CreateTree(op); - ToStream(*tree, ss); -} - -void TreeRenderer::Render(const QueryProfiler::TreeNode &op, std::ostream &ss) { - auto tree = CreateTree(op); - ToStream(*tree, ss); -} - -void TreeRenderer::Render(const Pipeline &op, std::ostream &ss) { - auto tree = CreateTree(op); - ToStream(*tree, ss); -} - -void TreeRenderer::ToStream(RenderTree &root, std::ostream &ss) { - while (root.width * config.NODE_RENDER_WIDTH > config.MAXIMUM_RENDER_WIDTH) { - if (config.NODE_RENDER_WIDTH - 2 < config.MINIMUM_RENDER_WIDTH) { - break; - } - config.NODE_RENDER_WIDTH -= 2; - } - - for (idx_t y = 0; y < root.height; y++) { - // start by rendering the top layer - RenderTopLayer(root, ss, y); - // now we render the content of the boxes - RenderBoxContent(root, ss, y); - // render the bottom layer of each of the boxes - RenderBottomLayer(root, ss, y); - } -} - -bool TreeRenderer::CanSplitOnThisChar(char l) { - return (l < '0' || (l > '9' && l < 'A') || (l > 'Z' && l < 'a')) && l != '_'; -} - -bool TreeRenderer::IsPadding(char l) { - return l == ' ' || l == '\t' || l == '\n' || l == '\r'; -} - -string TreeRenderer::RemovePadding(string l) { - idx_t start = 0, end = l.size(); - while (start < l.size() && 
IsPadding(l[start])) { - start++; - } - while (end > 0 && IsPadding(l[end - 1])) { - end--; - } - return l.substr(start, end - start); -} - -void TreeRenderer::SplitStringBuffer(const string &source, vector &result) { - D_ASSERT(Utf8Proc::IsValid(source.c_str(), source.size())); - idx_t max_line_render_size = config.NODE_RENDER_WIDTH - 2; - // utf8 in prompt, get render width - idx_t cpos = 0; - idx_t start_pos = 0; - idx_t render_width = 0; - idx_t last_possible_split = 0; - while (cpos < source.size()) { - // check if we can split on this character - if (CanSplitOnThisChar(source[cpos])) { - last_possible_split = cpos; - } - size_t char_render_width = Utf8Proc::RenderWidth(source.c_str(), source.size(), cpos); - idx_t next_cpos = Utf8Proc::NextGraphemeCluster(source.c_str(), source.size(), cpos); - if (render_width + char_render_width > max_line_render_size) { - if (last_possible_split <= start_pos + 8) { - last_possible_split = cpos; - } - result.push_back(source.substr(start_pos, last_possible_split - start_pos)); - start_pos = last_possible_split; - cpos = last_possible_split; - render_width = 0; - } - cpos = next_cpos; - render_width += char_render_width; - } - if (source.size() > start_pos) { - result.push_back(source.substr(start_pos, source.size() - start_pos)); - } -} - -void TreeRenderer::SplitUpExtraInfo(const string &extra_info, vector &result) { - if (extra_info.empty()) { - return; - } - if (!Utf8Proc::IsValid(extra_info.c_str(), extra_info.size())) { - return; - } - auto splits = StringUtil::Split(extra_info, "\n"); - if (!splits.empty() && splits[0] != "[INFOSEPARATOR]") { - result.push_back(ExtraInfoSeparator()); - } - for (auto &split : splits) { - if (split == "[INFOSEPARATOR]") { - result.push_back(ExtraInfoSeparator()); - continue; - } - string str = RemovePadding(split); - if (str.empty()) { - continue; - } - SplitStringBuffer(str, result); - } -} - -string TreeRenderer::ExtraInfoSeparator() { - return 
StringUtil::Repeat(string(config.HORIZONTAL) + " ", (config.NODE_RENDER_WIDTH - 7) / 2); -} - -unique_ptr TreeRenderer::CreateRenderNode(string name, string extra_info) { - auto result = make_uniq(); - result->name = std::move(name); - result->extra_text = std::move(extra_info); - return result; -} - -class TreeChildrenIterator { -public: - template - static bool HasChildren(const T &op) { - return !op.children.empty(); - } - template - static void Iterate(const T &op, const std::function &callback) { - for (auto &child : op.children) { - callback(*child); - } - } -}; - -template <> -bool TreeChildrenIterator::HasChildren(const PhysicalOperator &op) { - switch (op.type) { - case PhysicalOperatorType::DELIM_JOIN: - case PhysicalOperatorType::POSITIONAL_SCAN: - return true; - default: - return !op.children.empty(); - } -} -template <> -void TreeChildrenIterator::Iterate(const PhysicalOperator &op, - const std::function &callback) { - for (auto &child : op.children) { - callback(*child); - } - if (op.type == PhysicalOperatorType::DELIM_JOIN) { - auto &delim = op.Cast(); - callback(*delim.join); - } else if ((op.type == PhysicalOperatorType::POSITIONAL_SCAN)) { - auto &pscan = op.Cast(); - for (auto &table : pscan.child_tables) { - callback(*table); - } - } -} - -struct PipelineRenderNode { - explicit PipelineRenderNode(const PhysicalOperator &op) : op(op) { - } - - const PhysicalOperator &op; - unique_ptr child; -}; - -template <> -bool TreeChildrenIterator::HasChildren(const PipelineRenderNode &op) { - return op.child.get(); -} - -template <> -void TreeChildrenIterator::Iterate(const PipelineRenderNode &op, - const std::function &callback) { - if (op.child) { - callback(*op.child); - } -} - -template -static void GetTreeWidthHeight(const T &op, idx_t &width, idx_t &height) { - if (!TreeChildrenIterator::HasChildren(op)) { - width = 1; - height = 1; - return; - } - width = 0; - height = 0; - - TreeChildrenIterator::Iterate(op, [&](const T &child) { - idx_t 
child_width, child_height; - GetTreeWidthHeight(child, child_width, child_height); - width += child_width; - height = MaxValue(height, child_height); - }); - height++; -} - -template -idx_t TreeRenderer::CreateRenderTreeRecursive(RenderTree &result, const T &op, idx_t x, idx_t y) { - auto node = TreeRenderer::CreateNode(op); - result.SetNode(x, y, std::move(node)); - - if (!TreeChildrenIterator::HasChildren(op)) { - return 1; - } - idx_t width = 0; - // render the children of this node - TreeChildrenIterator::Iterate( - op, [&](const T &child) { width += CreateRenderTreeRecursive(result, child, x + width, y + 1); }); - return width; -} - -template -unique_ptr TreeRenderer::CreateRenderTree(const T &op) { - idx_t width, height; - GetTreeWidthHeight(op, width, height); - - auto result = make_uniq(width, height); - - // now fill in the tree - CreateRenderTreeRecursive(*result, op, 0, 0); - return result; -} - -unique_ptr TreeRenderer::CreateNode(const LogicalOperator &op) { - return CreateRenderNode(op.GetName(), op.ParamsToString()); -} - -unique_ptr TreeRenderer::CreateNode(const PhysicalOperator &op) { - return CreateRenderNode(op.GetName(), op.ParamsToString()); -} - -unique_ptr TreeRenderer::CreateNode(const PipelineRenderNode &op) { - return CreateNode(op.op); -} - -string TreeRenderer::ExtractExpressionsRecursive(ExpressionInfo &state) { - string result = "\n[INFOSEPARATOR]"; - result += "\n" + state.function_name; - result += "\n" + StringUtil::Format("%.9f", double(state.function_time)); - if (state.children.empty()) { - return result; - } - // render the children of this node - for (auto &child : state.children) { - result += ExtractExpressionsRecursive(*child); - } - return result; -} - -unique_ptr TreeRenderer::CreateNode(const QueryProfiler::TreeNode &op) { - auto result = TreeRenderer::CreateRenderNode(op.name, op.extra_info); - result->extra_text += "\n[INFOSEPARATOR]"; - result->extra_text += "\n" + to_string(op.info.elements); - string timing = 
StringUtil::Format("%.2f", op.info.time); - result->extra_text += "\n(" + timing + "s)"; - if (config.detailed) { - for (auto &info : op.info.executors_info) { - if (!info) { - continue; - } - for (auto &executor_info : info->roots) { - string sample_count = to_string(executor_info->sample_count); - result->extra_text += "\n[INFOSEPARATOR]"; - result->extra_text += "\nsample_count: " + sample_count; - string sample_tuples_count = to_string(executor_info->sample_tuples_count); - result->extra_text += "\n[INFOSEPARATOR]"; - result->extra_text += "\nsample_tuples_count: " + sample_tuples_count; - string total_count = to_string(executor_info->total_count); - result->extra_text += "\n[INFOSEPARATOR]"; - result->extra_text += "\ntotal_count: " + total_count; - for (auto &state : executor_info->root->children) { - result->extra_text += ExtractExpressionsRecursive(*state); - } - } - } - } - return result; -} - -unique_ptr TreeRenderer::CreateTree(const LogicalOperator &op) { - return CreateRenderTree(op); -} - -unique_ptr TreeRenderer::CreateTree(const PhysicalOperator &op) { - return CreateRenderTree(op); -} - -unique_ptr TreeRenderer::CreateTree(const QueryProfiler::TreeNode &op) { - return CreateRenderTree(op); -} - -unique_ptr TreeRenderer::CreateTree(const Pipeline &op) { - auto operators = op.GetOperators(); - D_ASSERT(!operators.empty()); - unique_ptr node; - for (auto &op : operators) { - auto new_node = make_uniq(op.get()); - new_node->child = std::move(node); - node = std::move(new_node); - } - return CreateRenderTree(*node); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -BatchedDataCollection::BatchedDataCollection(ClientContext &context_p, vector types_p, - bool buffer_managed_p) - : context(context_p), types(std::move(types_p)), buffer_managed(buffer_managed_p) { -} - -void BatchedDataCollection::Append(DataChunk &input, idx_t batch_index) { - D_ASSERT(batch_index != DConstants::INVALID_INDEX); - optional_ptr collection; - if 
(last_collection.collection && last_collection.batch_index == batch_index) { - // we are inserting into the same collection as before: use it directly - collection = last_collection.collection; - } else { - // new collection: check if there is already an entry - D_ASSERT(data.find(batch_index) == data.end()); - unique_ptr new_collection; - if (last_collection.collection) { - new_collection = make_uniq(*last_collection.collection); - } else if (buffer_managed) { - new_collection = make_uniq(BufferManager::GetBufferManager(context), types); - } else { - new_collection = make_uniq(Allocator::DefaultAllocator(), types); - } - // cache the new collection so consecutive appends with the same batch index skip the map lookup - last_collection.collection = new_collection.get(); - last_collection.batch_index = batch_index; - new_collection->InitializeAppend(last_collection.append_state); - collection = new_collection.get(); - data.insert(make_pair(batch_index, std::move(new_collection))); - } - collection->Append(last_collection.append_state, input); -} - -// Moves every batch of "other" into this collection; a given batch index may be present in only one of the two. -void BatchedDataCollection::Merge(BatchedDataCollection &other) { - for (auto &entry : other.data) { - if (data.find(entry.first) != data.end()) { - throw InternalException( - "BatchedDataCollection::Merge error - batch index %d is present in both collections. This occurs when " - "batch indexes are not uniquely distributed over threads", - entry.first); - } - data[entry.first] = std::move(entry.second); - } - other.data.clear(); -} - -// Positions the scan at the first per-batch collection (early return when there is no data at all). -void BatchedDataCollection::InitializeScan(BatchedChunkScanState &state) { - state.iterator = data.begin(); - if (state.iterator == data.end()) { - return; - } - state.iterator->second->InitializeScan(state.scan_state); -} - -// Scans the per-batch collections one after another in map iteration order; an empty "output" signals exhaustion. -void BatchedDataCollection::Scan(BatchedChunkScanState &state, DataChunk &output) { - while (state.iterator != data.end()) { - // check if there is a chunk remaining in this collection - auto collection = state.iterator->second.get(); - collection->Scan(state.scan_state, output); - if (output.size() > 0) { - return; - } - // there isn't! 
move to the next collection - state.iterator++; - if (state.iterator == data.end()) { - return; - } - state.iterator->second->InitializeScan(state.scan_state); - } -} - -unique_ptr BatchedDataCollection::FetchCollection() { - unique_ptr result; - for (auto &entry : data) { - if (!result) { - result = std::move(entry.second); - } else { - result->Combine(*entry.second); - } - } - data.clear(); - if (!result) { - // empty result - return make_uniq(Allocator::DefaultAllocator(), types); - } - return result; -} - -string BatchedDataCollection::ToString() const { - string result; - result += "Batched Data Collection\n"; - for (auto &entry : data) { - result += "Batch Index - " + to_string(entry.first) + "\n"; - result += entry.second->ToString() + "\n\n"; - } - return result; -} - -void BatchedDataCollection::Print() const { - Printer::Print(ToString()); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -// **** helper functions **** -static char ComputePadding(idx_t len) { - return (8 - (len % 8)) % 8; -} - -idx_t Bit::ComputeBitstringLen(idx_t len) { - idx_t result = len / 8; - if (len % 8 != 0) { - result++; - } - // additional first byte to store info on zero padding - result++; - return result; -} - -static inline idx_t GetBitPadding(const string_t &bit_string) { - auto data = const_data_ptr_cast(bit_string.GetData()); - D_ASSERT(idx_t(data[0]) <= 8); - return data[0]; -} - -static inline idx_t GetBitSize(const string_t &str) { - string error_message; - idx_t str_len; - if (!Bit::TryGetBitStringSize(str, str_len, &error_message)) { - throw ConversionException(error_message); - } - return str_len; -} - -uint8_t Bit::GetFirstByte(const string_t &str) { - D_ASSERT(str.GetSize() > 1); - - auto data = const_data_ptr_cast(str.GetData()); - return data[1] & ((1 << (8 - data[0])) - 1); -} - -void Bit::Finalize(string_t &str) { - // bit strings require all padding bits to be set to 1 - // this method sets all padding bits to 1 - auto padding = 
GetBitPadding(str); - for (idx_t i = 0; i < idx_t(padding); i++) { - Bit::SetBitInternal(str, i, 1); - } - Bit::Verify(str); -} - -// Zeroes the payload of "target" and copies the padding-count byte (byte 0) from "input". -void Bit::SetEmptyBitString(string_t &target, string_t &input) { - char *res_buf = target.GetDataWriteable(); - const char *buf = input.GetData(); - memset(res_buf, 0, input.GetSize()); - res_buf[0] = buf[0]; - Bit::Finalize(target); -} - -// Zeroes "target" and derives the padding-count byte from the requested bit length "len". -void Bit::SetEmptyBitString(string_t &target, idx_t len) { - char *res_buf = target.GetDataWriteable(); - memset(res_buf, 0, target.GetSize()); - res_buf[0] = ComputePadding(len); - Bit::Finalize(target); -} - -// **** casting functions **** -// Renders "bits" as a '0'/'1' string: byte 0 of a bit string stores the padding count, payload starts at byte 1. -// "output" must hold BitLength(bits) characters. -void Bit::ToString(string_t bits, char *output) { - auto data = const_data_ptr_cast(bits.GetData()); - auto len = bits.GetSize(); - - idx_t padding = GetBitPadding(bits); - idx_t output_idx = 0; - // the first payload byte is only partially significant: skip its "padding" leading bits - for (idx_t bit_idx = padding; bit_idx < 8; bit_idx++) { - output[output_idx++] = data[1] & (1 << (7 - bit_idx)) ? '1' : '0'; - } - // every following byte contributes all 8 bits, most significant first - for (idx_t byte_idx = 2; byte_idx < len; byte_idx++) { - for (idx_t bit_idx = 0; bit_idx < 8; bit_idx++) { - output[output_idx++] = data[byte_idx] & (1 << (7 - bit_idx)) ? 
'1' : '0'; - } - } -} - -string Bit::ToString(string_t str) { - auto len = BitLength(str); - auto buffer = make_unsafe_uniq_array(len); - ToString(str, buffer.get()); - return string(buffer.get(), len); -} - -bool Bit::TryGetBitStringSize(string_t str, idx_t &str_len, string *error_message) { - auto data = const_data_ptr_cast(str.GetData()); - auto len = str.GetSize(); - str_len = 0; - for (idx_t i = 0; i < len; i++) { - if (data[i] == '0' || data[i] == '1') { - str_len++; - } else { - string error = StringUtil::Format("Invalid character encountered in string -> bit conversion: '%s'", - string(const_char_ptr_cast(data) + i, 1)); - HandleCastError::AssignError(error, error_message); - return false; - } - } - if (str_len == 0) { - string error = "Cannot cast empty string to BIT"; - HandleCastError::AssignError(error, error_message); - return false; - } - str_len = ComputeBitstringLen(str_len); - return true; -} - -void Bit::ToBit(string_t str, string_t &output_str) { - auto data = const_data_ptr_cast(str.GetData()); - auto len = str.GetSize(); - auto output = output_str.GetDataWriteable(); - - char byte = 0; - idx_t padded_byte = len % 8; - for (idx_t i = 0; i < padded_byte; i++) { - byte <<= 1; - if (data[i] == '1') { - byte |= 1; - } - } - if (padded_byte != 0) { - *(output++) = (8 - padded_byte); // the first byte contains the number of padded zeroes - } - *(output++) = byte; - - for (idx_t byte_idx = padded_byte; byte_idx < len; byte_idx += 8) { - byte = 0; - for (idx_t bit_idx = 0; bit_idx < 8; bit_idx++) { - byte <<= 1; - if (data[byte_idx + bit_idx] == '1') { - byte |= 1; - } - } - *(output++) = byte; - } - Bit::Finalize(output_str); - Bit::Verify(output_str); -} - -string Bit::ToBit(string_t str) { - auto bit_len = GetBitSize(str); - auto buffer = make_unsafe_uniq_array(bit_len); - string_t output_str(buffer.get(), bit_len); - Bit::ToBit(str, output_str); - return output_str.GetString(); -} - -void Bit::BlobToBit(string_t blob, string_t &output_str) { - auto 
data = const_data_ptr_cast(blob.GetData()); - auto output = output_str.GetDataWriteable(); - idx_t size = blob.GetSize(); - - *output = 0; // No padding - memcpy(output + 1, data, size); -} - -string Bit::BlobToBit(string_t blob) { - auto buffer = make_unsafe_uniq_array(blob.GetSize() + 1); - string_t output_str(buffer.get(), blob.GetSize() + 1); - Bit::BlobToBit(blob, output_str); - return output_str.GetString(); -} - -void Bit::BitToBlob(string_t bit, string_t &output_blob) { - D_ASSERT(bit.GetSize() == output_blob.GetSize() + 1); - - auto data = const_data_ptr_cast(bit.GetData()); - auto output = output_blob.GetDataWriteable(); - idx_t size = output_blob.GetSize(); - - output[0] = GetFirstByte(bit); - if (size > 2) { - ++output; - // First byte in bitstring contains amount of padded bits, - // second byte in bitstring is the padded byte, - // therefore the rest of the data starts at data + 2 (third byte) - memcpy(output, data + 2, size - 1); - } -} - -string Bit::BitToBlob(string_t bit) { - D_ASSERT(bit.GetSize() > 1); - - auto buffer = make_unsafe_uniq_array(bit.GetSize() - 1); - string_t output_str(buffer.get(), bit.GetSize() - 1); - Bit::BitToBlob(bit, output_str); - return output_str.GetString(); -} - -// **** scalar functions **** -void Bit::BitString(const string_t &input, const idx_t &bit_length, string_t &result) { - char *res_buf = result.GetDataWriteable(); - const char *buf = input.GetData(); - - auto padding = ComputePadding(bit_length); - res_buf[0] = padding; - for (idx_t i = 0; i < bit_length; i++) { - if (i < bit_length - input.GetSize()) { - Bit::SetBit(result, i, 0); - } else { - idx_t bit = buf[i - (bit_length - input.GetSize())] == '1' ? 
1 : 0; - Bit::SetBit(result, i, bit); - } - } - Bit::Finalize(result); -} - -idx_t Bit::BitLength(string_t bits) { - return ((bits.GetSize() - 1) * 8) - GetBitPadding(bits); -} - -idx_t Bit::OctetLength(string_t bits) { - return bits.GetSize() - 1; -} - -idx_t Bit::BitCount(string_t bits) { - idx_t count = 0; - const char *buf = bits.GetData(); - for (idx_t byte_idx = 1; byte_idx < OctetLength(bits) + 1; byte_idx++) { - for (idx_t bit_idx = 0; bit_idx < 8; bit_idx++) { - count += (buf[byte_idx] & (1 << bit_idx)) ? 1 : 0; - } - } - return count - GetBitPadding(bits); -} - -idx_t Bit::BitPosition(string_t substring, string_t bits) { - const char *buf = bits.GetData(); - auto len = bits.GetSize(); - auto substr_len = BitLength(substring); - idx_t substr_idx = 0; - - for (idx_t bit_idx = GetBitPadding(bits); bit_idx < 8; bit_idx++) { - idx_t bit = buf[1] & (1 << (7 - bit_idx)) ? 1 : 0; - if (bit == GetBit(substring, substr_idx)) { - substr_idx++; - if (substr_idx == substr_len) { - return (bit_idx - GetBitPadding(bits)) - substr_len + 2; - } - } else { - substr_idx = 0; - } - } - - for (idx_t byte_idx = 2; byte_idx < len; byte_idx++) { - for (idx_t bit_idx = 0; bit_idx < 8; bit_idx++) { - idx_t bit = buf[byte_idx] & (1 << (7 - bit_idx)) ? 1 : 0; - if (bit == GetBit(substring, substr_idx)) { - substr_idx++; - if (substr_idx == substr_len) { - return (((byte_idx - 1) * 8) + bit_idx - GetBitPadding(bits)) - substr_len + 2; - } - } else { - substr_idx = 0; - } - } - } - return 0; -} - -idx_t Bit::GetBit(string_t bit_string, idx_t n) { - return Bit::GetBitInternal(bit_string, n + GetBitPadding(bit_string)); -} - -idx_t Bit::GetBitIndex(idx_t n) { - return n / 8 + 1; -} - -idx_t Bit::GetBitInternal(string_t bit_string, idx_t n) { - const char *buf = bit_string.GetData(); - auto idx = Bit::GetBitIndex(n); - D_ASSERT(idx < bit_string.GetSize()); - char byte = buf[idx] >> (7 - (n % 8)); - return (byte & 1 ? 
1 : 0); -} - -void Bit::SetBit(string_t &bit_string, idx_t n, idx_t new_value) { - SetBitInternal(bit_string, n + GetBitPadding(bit_string), new_value); -} - -void Bit::SetBitInternal(string_t &bit_string, idx_t n, idx_t new_value) { - char *buf = bit_string.GetDataWriteable(); - - auto idx = Bit::GetBitIndex(n); - D_ASSERT(idx < bit_string.GetSize()); - char shift_byte = 1 << (7 - (n % 8)); - if (new_value == 0) { - shift_byte = ~shift_byte; - buf[idx] &= shift_byte; - } else { - buf[idx] |= shift_byte; - } -} - -// **** BITWISE operators **** -void Bit::RightShift(const string_t &bit_string, const idx_t &shift, string_t &result) { - char *res_buf = result.GetDataWriteable(); - const char *buf = bit_string.GetData(); - res_buf[0] = buf[0]; - for (idx_t i = 0; i < Bit::BitLength(result); i++) { - if (i < shift) { - Bit::SetBit(result, i, 0); - } else { - idx_t bit = Bit::GetBit(bit_string, i - shift); - Bit::SetBit(result, i, bit); - } - } - Bit::Finalize(result); -} - -void Bit::LeftShift(const string_t &bit_string, const idx_t &shift, string_t &result) { - char *res_buf = result.GetDataWriteable(); - const char *buf = bit_string.GetData(); - res_buf[0] = buf[0]; - for (idx_t i = 0; i < Bit::BitLength(bit_string); i++) { - if (i < (Bit::BitLength(bit_string) - shift)) { - idx_t bit = Bit::GetBit(bit_string, shift + i); - Bit::SetBit(result, i, bit); - } else { - Bit::SetBit(result, i, 0); - } - } - Bit::Finalize(result); - Bit::Verify(result); -} - -void Bit::BitwiseAnd(const string_t &rhs, const string_t &lhs, string_t &result) { - if (Bit::BitLength(lhs) != Bit::BitLength(rhs)) { - throw InvalidInputException("Cannot AND bit strings of different sizes"); - } - - char *buf = result.GetDataWriteable(); - const char *r_buf = rhs.GetData(); - const char *l_buf = lhs.GetData(); - - buf[0] = l_buf[0]; - for (idx_t i = 1; i < lhs.GetSize(); i++) { - buf[i] = l_buf[i] & r_buf[i]; - } - // and should preserve padding bits - Bit::Verify(result); -} - -void 
Bit::BitwiseOr(const string_t &rhs, const string_t &lhs, string_t &result) { - if (Bit::BitLength(lhs) != Bit::BitLength(rhs)) { - throw InvalidInputException("Cannot OR bit strings of different sizes"); - } - - char *buf = result.GetDataWriteable(); - const char *r_buf = rhs.GetData(); - const char *l_buf = lhs.GetData(); - - buf[0] = l_buf[0]; - for (idx_t i = 1; i < lhs.GetSize(); i++) { - buf[i] = l_buf[i] | r_buf[i]; - } - // or should preserve padding bits - Bit::Verify(result); -} - -void Bit::BitwiseXor(const string_t &rhs, const string_t &lhs, string_t &result) { - if (Bit::BitLength(lhs) != Bit::BitLength(rhs)) { - throw InvalidInputException("Cannot XOR bit strings of different sizes"); - } - - char *buf = result.GetDataWriteable(); - const char *r_buf = rhs.GetData(); - const char *l_buf = lhs.GetData(); - - buf[0] = l_buf[0]; - for (idx_t i = 1; i < lhs.GetSize(); i++) { - buf[i] = l_buf[i] ^ r_buf[i]; - } - Bit::Finalize(result); -} - -void Bit::BitwiseNot(const string_t &input, string_t &result) { - char *result_buf = result.GetDataWriteable(); - const char *buf = input.GetData(); - - result_buf[0] = buf[0]; - for (idx_t i = 1; i < input.GetSize(); i++) { - result_buf[i] = ~buf[i]; - } - Bit::Finalize(result); -} - -void Bit::Verify(const string_t &input) { -#ifdef DEBUG - // bit strings require all padding bits to be set to 1 - auto padding = GetBitPadding(input); - for (idx_t i = 0; i < padding; i++) { - D_ASSERT(Bit::GetBitInternal(input, i)); - } -#endif -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -constexpr const char *Blob::HEX_TABLE; -const int Blob::HEX_MAP[256] = { - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, - -1, -1, -1, -1, -1, -1, -1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; - -bool IsRegularCharacter(data_t c) { - return c >= 32 && c <= 126 && c != '\\' && c != '\'' && c != '"'; -} - -idx_t Blob::GetStringSize(string_t blob) { - auto data = const_data_ptr_cast(blob.GetData()); - auto len = blob.GetSize(); - idx_t str_len = 0; - for (idx_t i = 0; i < len; i++) { - if (IsRegularCharacter(data[i])) { - // ascii characters are rendered as-is - str_len++; - } else { - // non-ascii characters are rendered as hexadecimal (e.g. \x00) - str_len += 4; - } - } - return str_len; -} - -void Blob::ToString(string_t blob, char *output) { - auto data = const_data_ptr_cast(blob.GetData()); - auto len = blob.GetSize(); - idx_t str_idx = 0; - for (idx_t i = 0; i < len; i++) { - if (IsRegularCharacter(data[i])) { - // ascii characters are rendered as-is - output[str_idx++] = data[i]; - } else { - auto byte_a = data[i] >> 4; - auto byte_b = data[i] & 0x0F; - D_ASSERT(byte_a >= 0 && byte_a < 16); - D_ASSERT(byte_b >= 0 && byte_b < 16); - // non-ascii characters are rendered as hexadecimal (e.g. 
\x00) - output[str_idx++] = '\\'; - output[str_idx++] = 'x'; - output[str_idx++] = Blob::HEX_TABLE[byte_a]; - output[str_idx++] = Blob::HEX_TABLE[byte_b]; - } - } - D_ASSERT(str_idx == GetStringSize(blob)); -} - -string Blob::ToString(string_t blob) { - auto str_len = GetStringSize(blob); - auto buffer = make_unsafe_uniq_array(str_len); - Blob::ToString(blob, buffer.get()); - return string(buffer.get(), str_len); -} - -bool Blob::TryGetBlobSize(string_t str, idx_t &str_len, string *error_message) { - auto data = const_data_ptr_cast(str.GetData()); - auto len = str.GetSize(); - str_len = 0; - for (idx_t i = 0; i < len; i++) { - if (data[i] == '\\') { - if (i + 3 >= len) { - string error = "Invalid hex escape code encountered in string -> blob conversion: " - "unterminated escape code at end of blob"; - HandleCastError::AssignError(error, error_message); - return false; - } - if (data[i + 1] != 'x' || Blob::HEX_MAP[data[i + 2]] < 0 || Blob::HEX_MAP[data[i + 3]] < 0) { - string error = - StringUtil::Format("Invalid hex escape code encountered in string -> blob conversion: %s", - string(const_char_ptr_cast(data) + i, 4)); - HandleCastError::AssignError(error, error_message); - return false; - } - str_len++; - i += 3; - } else if (data[i] <= 127) { - str_len++; - } else { - string error = "Invalid byte encountered in STRING -> BLOB conversion. All non-ascii characters " - "must be escaped with hex codes (e.g. 
\\xAA)"; - HandleCastError::AssignError(error, error_message); - return false; - } - } - return true; -} - -idx_t Blob::GetBlobSize(string_t str) { - string error_message; - idx_t str_len; - if (!Blob::TryGetBlobSize(str, str_len, &error_message)) { - throw ConversionException(error_message); - } - return str_len; -} - -void Blob::ToBlob(string_t str, data_ptr_t output) { - auto data = const_data_ptr_cast(str.GetData()); - auto len = str.GetSize(); - idx_t blob_idx = 0; - for (idx_t i = 0; i < len; i++) { - if (data[i] == '\\') { - int byte_a = Blob::HEX_MAP[data[i + 2]]; - int byte_b = Blob::HEX_MAP[data[i + 3]]; - D_ASSERT(i + 3 < len); - D_ASSERT(byte_a >= 0 && byte_b >= 0); - D_ASSERT(data[i + 1] == 'x'); - output[blob_idx++] = (byte_a << 4) + byte_b; - i += 3; - } else if (data[i] <= 127) { - output[blob_idx++] = data_t(data[i]); - } else { - throw ConversionException("Invalid byte encountered in STRING -> BLOB conversion. All non-ascii characters " - "must be escaped with hex codes (e.g. 
\\xAA)"); - } - } - D_ASSERT(blob_idx == GetBlobSize(str)); -} - -string Blob::ToBlob(string_t str) { - auto blob_len = GetBlobSize(str); - auto buffer = make_unsafe_uniq_array(blob_len); - Blob::ToBlob(str, data_ptr_cast(buffer.get())); - return string(buffer.get(), blob_len); -} - -// base64 functions are adapted from https://gist.github.com/tomykaira/f0fd86b6c73063283afe550bc5d77594 -idx_t Blob::ToBase64Size(string_t blob) { - // every 4 characters in base64 encode 3 bytes, plus (potential) padding at the end - auto input_size = blob.GetSize(); - return ((input_size + 2) / 3) * 4; -} - -void Blob::ToBase64(string_t blob, char *output) { - auto input_data = const_data_ptr_cast(blob.GetData()); - auto input_size = blob.GetSize(); - idx_t out_idx = 0; - idx_t i; - // convert the bulk of the string to base64 - // this happens in steps of 3 bytes -> 4 output bytes - for (i = 0; i + 2 < input_size; i += 3) { - output[out_idx++] = Blob::BASE64_MAP[(input_data[i] >> 2) & 0x3F]; - output[out_idx++] = Blob::BASE64_MAP[((input_data[i] & 0x3) << 4) | ((input_data[i + 1] & 0xF0) >> 4)]; - output[out_idx++] = Blob::BASE64_MAP[((input_data[i + 1] & 0xF) << 2) | ((input_data[i + 2] & 0xC0) >> 6)]; - output[out_idx++] = Blob::BASE64_MAP[input_data[i + 2] & 0x3F]; - } - - if (i < input_size) { - // there are one or two bytes left over: we have to insert padding - // first write the first 6 bits of the first byte - output[out_idx++] = Blob::BASE64_MAP[(input_data[i] >> 2) & 0x3F]; - // now check the character count - if (i == input_size - 1) { - // single byte left over: convert the remainder of that byte and insert padding - output[out_idx++] = Blob::BASE64_MAP[((input_data[i] & 0x3) << 4)]; - output[out_idx++] = Blob::BASE64_PADDING; - } else { - // two bytes left over: convert the second byte as well - output[out_idx++] = Blob::BASE64_MAP[((input_data[i] & 0x3) << 4) | ((input_data[i + 1] & 0xF0) >> 4)]; - output[out_idx++] = Blob::BASE64_MAP[((input_data[i + 1] & 0xF) << 2)]; 
- } - output[out_idx++] = Blob::BASE64_PADDING; - } -} - -static constexpr int BASE64_DECODING_TABLE[256] = { - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, - -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, - 22, 23, 24, 25, -1, -1, -1, -1, -1, -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, - 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; - -idx_t Blob::FromBase64Size(string_t str) { - auto input_data = str.GetData(); - auto input_size = str.GetSize(); - if (input_size % 4 != 0) { - // valid base64 needs to always be cleanly divisible by 4 - throw ConversionException("Could not decode string \"%s\" as base64: length must be a multiple of 4", - str.GetString()); - } - if (input_size < 4) { - // empty string - return 0; - } - auto base_size = input_size / 4 * 3; - // check for padding to figure out the length - if (input_data[input_size - 2] == Blob::BASE64_PADDING) { - // two bytes of padding - return base_size - 2; - } - if (input_data[input_size - 1] == Blob::BASE64_PADDING) { - // one byte of padding - return base_size - 1; - } - // no padding - return base_size; -} - -template -uint32_t DecodeBase64Bytes(const string_t &str, const_data_ptr_t input_data, idx_t base_idx) { - 
int decoded_bytes[4]; - for (idx_t decode_idx = 0; decode_idx < 4; decode_idx++) { - if (ALLOW_PADDING && decode_idx >= 2 && input_data[base_idx + decode_idx] == Blob::BASE64_PADDING) { - // the last two bytes of a base64 string can have padding: in this case we set the byte to 0 - decoded_bytes[decode_idx] = 0; - } else { - decoded_bytes[decode_idx] = BASE64_DECODING_TABLE[input_data[base_idx + decode_idx]]; - } - if (decoded_bytes[decode_idx] < 0) { - throw ConversionException( - "Could not decode string \"%s\" as base64: invalid byte value '%d' at position %d", str.GetString(), - input_data[base_idx + decode_idx], base_idx + decode_idx); - } - } - return (decoded_bytes[0] << 3 * 6) + (decoded_bytes[1] << 2 * 6) + (decoded_bytes[2] << 1 * 6) + - (decoded_bytes[3] << 0 * 6); -} - -void Blob::FromBase64(string_t str, data_ptr_t output, idx_t output_size) { - D_ASSERT(output_size == FromBase64Size(str)); - auto input_data = const_data_ptr_cast(str.GetData()); - auto input_size = str.GetSize(); - if (input_size == 0) { - return; - } - idx_t out_idx = 0; - idx_t i = 0; - for (i = 0; i + 4 < input_size; i += 4) { - auto combined = DecodeBase64Bytes(str, input_data, i); - output[out_idx++] = (combined >> 2 * 8) & 0xFF; - output[out_idx++] = (combined >> 1 * 8) & 0xFF; - output[out_idx++] = (combined >> 0 * 8) & 0xFF; - } - // decode the final four bytes: padding is allowed here - auto combined = DecodeBase64Bytes(str, input_data, i); - output[out_idx++] = (combined >> 2 * 8) & 0xFF; - if (out_idx < output_size) { - output[out_idx++] = (combined >> 1 * 8) & 0xFF; - } - if (out_idx < output_size) { - output[out_idx++] = (combined >> 0 * 8) & 0xFF; - } -} - -} // namespace duckdb - - - -namespace duckdb { - -const int64_t NumericHelper::POWERS_OF_TEN[] {1, - 10, - 100, - 1000, - 10000, - 100000, - 1000000, - 10000000, - 100000000, - 1000000000, - 10000000000, - 100000000000, - 1000000000000, - 10000000000000, - 100000000000000, - 1000000000000000, - 10000000000000000, - 
100000000000000000, - 1000000000000000000}; - -const double NumericHelper::DOUBLE_POWERS_OF_TEN[] {1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, - 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19, - 1e20, 1e21, 1e22, 1e23, 1e24, 1e25, 1e26, 1e27, 1e28, 1e29, - 1e30, 1e31, 1e32, 1e33, 1e34, 1e35, 1e36, 1e37, 1e38, 1e39}; - -template <> -int NumericHelper::UnsignedLength(uint8_t value) { - int length = 1; - length += value >= 10; - length += value >= 100; - return length; -} - -template <> -int NumericHelper::UnsignedLength(uint16_t value) { - int length = 1; - length += value >= 10; - length += value >= 100; - length += value >= 1000; - length += value >= 10000; - return length; -} - -template <> -int NumericHelper::UnsignedLength(uint32_t value) { - if (value >= 10000) { - int length = 5; - length += value >= 100000; - length += value >= 1000000; - length += value >= 10000000; - length += value >= 100000000; - length += value >= 1000000000; - return length; - } else { - int length = 1; - length += value >= 10; - length += value >= 100; - length += value >= 1000; - return length; - } -} - -template <> -int NumericHelper::UnsignedLength(uint64_t value) { - if (value >= 10000000000ULL) { - if (value >= 1000000000000000ULL) { - int length = 16; - length += value >= 10000000000000000ULL; - length += value >= 100000000000000000ULL; - length += value >= 1000000000000000000ULL; - length += value >= 10000000000000000000ULL; - return length; - } else { - int length = 11; - length += value >= 100000000000ULL; - length += value >= 1000000000000ULL; - length += value >= 10000000000000ULL; - length += value >= 100000000000000ULL; - return length; - } - } else { - if (value >= 100000ULL) { - int length = 6; - length += value >= 1000000ULL; - length += value >= 10000000ULL; - length += value >= 100000000ULL; - length += value >= 1000000000ULL; - return length; - } else { - int length = 1; - length += value >= 10ULL; - length += value >= 100ULL; - length += value >= 
1000ULL; - length += value >= 10000ULL; - return length; - } - } -} - -template <> -std::string NumericHelper::ToString(hugeint_t value) { - return Hugeint::ToString(value); -} - -} // namespace duckdb - - - - - - - - - - -#include -#include - -namespace duckdb { - -ChunkCollection::ChunkCollection(Allocator &allocator) : allocator(allocator), count(0) { -} - -ChunkCollection::ChunkCollection(ClientContext &context) : ChunkCollection(Allocator::Get(context)) { -} - -void ChunkCollection::Verify() { -#ifdef DEBUG - for (auto &chunk : chunks) { - chunk->Verify(); - } -#endif -} - -void ChunkCollection::Append(ChunkCollection &other) { - for (auto &chunk : other.chunks) { - Append(*chunk); - } -} - -void ChunkCollection::Merge(ChunkCollection &other) { - if (other.count == 0) { - return; - } - if (count == 0) { - chunks = std::move(other.chunks); - types = std::move(other.types); - count = other.count; - return; - } - unique_ptr old_back; - if (!chunks.empty() && chunks.back()->size() != STANDARD_VECTOR_SIZE) { - old_back = std::move(chunks.back()); - chunks.pop_back(); - count -= old_back->size(); - } - for (auto &chunk : other.chunks) { - chunks.push_back(std::move(chunk)); - } - count += other.count; - if (old_back) { - Append(*old_back); - } - Verify(); -} - -void ChunkCollection::Append(DataChunk &new_chunk) { - if (new_chunk.size() == 0) { - return; - } - new_chunk.Verify(); - - // we have to ensure that every chunk in the ChunkCollection is completely - // filled, otherwise our O(1) lookup in GetValue and SetValue does not work - // first fill the latest chunk, if it exists - count += new_chunk.size(); - - idx_t remaining_data = new_chunk.size(); - idx_t offset = 0; - if (chunks.empty()) { - // first chunk - types = new_chunk.GetTypes(); - } else { - // the types of the new chunk should match the types of the previous one - D_ASSERT(types.size() == new_chunk.ColumnCount()); - auto new_types = new_chunk.GetTypes(); - for (idx_t i = 0; i < types.size(); i++) { - 
if (new_types[i] != types[i]) { - throw TypeMismatchException(new_types[i], types[i], "Type mismatch when combining rows"); - } - if (types[i].InternalType() == PhysicalType::LIST) { - // need to check all the chunks because they can have only-null list entries - for (auto &chunk : chunks) { - auto &chunk_vec = chunk->data[i]; - auto &new_vec = new_chunk.data[i]; - auto &chunk_type = chunk_vec.GetType(); - auto &new_type = new_vec.GetType(); - if (chunk_type != new_type) { - throw TypeMismatchException(chunk_type, new_type, "Type mismatch when combining lists"); - } - } - } - // TODO check structs, too - } - - // first append data to the current chunk - DataChunk &last_chunk = *chunks.back(); - idx_t added_data = MinValue(remaining_data, STANDARD_VECTOR_SIZE - last_chunk.size()); - if (added_data > 0) { - // copy elements to the last chunk - new_chunk.Flatten(); - // have to be careful here: setting the cardinality without calling normalify can cause incorrect partial - // decompression - idx_t old_count = new_chunk.size(); - new_chunk.SetCardinality(added_data); - - last_chunk.Append(new_chunk); - remaining_data -= added_data; - // reset the chunk to the old data - new_chunk.SetCardinality(old_count); - offset = added_data; - } - } - - if (remaining_data > 0) { - // create a new chunk and fill it with the remainder - auto chunk = make_uniq(); - chunk->Initialize(allocator, types); - new_chunk.Copy(*chunk, offset); - chunks.push_back(std::move(chunk)); - } -} - -void ChunkCollection::Append(unique_ptr new_chunk) { - if (types.empty()) { - types = new_chunk->GetTypes(); - } - D_ASSERT(types == new_chunk->GetTypes()); - count += new_chunk->size(); - chunks.push_back(std::move(new_chunk)); -} - -void ChunkCollection::Fuse(ChunkCollection &other) { - if (count == 0) { - chunks.reserve(other.ChunkCount()); - for (idx_t chunk_idx = 0; chunk_idx < other.ChunkCount(); ++chunk_idx) { - auto lhs = make_uniq(); - auto &rhs = other.GetChunk(chunk_idx); - 
lhs->data.reserve(rhs.data.size()); - for (auto &v : rhs.data) { - lhs->data.emplace_back(v); - } - lhs->SetCardinality(rhs.size()); - chunks.push_back(std::move(lhs)); - } - count = other.Count(); - } else { - D_ASSERT(this->ChunkCount() == other.ChunkCount()); - for (idx_t chunk_idx = 0; chunk_idx < ChunkCount(); ++chunk_idx) { - auto &lhs = this->GetChunk(chunk_idx); - auto &rhs = other.GetChunk(chunk_idx); - D_ASSERT(lhs.size() == rhs.size()); - for (auto &v : rhs.data) { - lhs.data.emplace_back(v); - } - } - } - types.insert(types.end(), other.types.begin(), other.types.end()); -} - -Value ChunkCollection::GetValue(idx_t column, idx_t index) { - return chunks[LocateChunk(index)]->GetValue(column, index % STANDARD_VECTOR_SIZE); -} - -void ChunkCollection::SetValue(idx_t column, idx_t index, const Value &value) { - chunks[LocateChunk(index)]->SetValue(column, index % STANDARD_VECTOR_SIZE, value); -} - -void ChunkCollection::CopyCell(idx_t column, idx_t index, Vector &target, idx_t target_offset) { - auto &chunk = GetChunkForRow(index); - auto &source = chunk.data[column]; - const auto source_offset = index % STANDARD_VECTOR_SIZE; - VectorOperations::Copy(source, target, source_offset + 1, source_offset, target_offset); -} - -string ChunkCollection::ToString() const { - return chunks.empty() ? 
"ChunkCollection [ 0 ]" - : "ChunkCollection [ " + std::to_string(count) + " ]: \n" + chunks[0]->ToString(); -} - -void ChunkCollection::Print() const { - Printer::Print(ToString()); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -ColumnDataAllocator::ColumnDataAllocator(Allocator &allocator) : type(ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { - alloc.allocator = &allocator; -} - -ColumnDataAllocator::ColumnDataAllocator(BufferManager &buffer_manager) - : type(ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { - alloc.buffer_manager = &buffer_manager; -} - -ColumnDataAllocator::ColumnDataAllocator(ClientContext &context, ColumnDataAllocatorType allocator_type) - : type(allocator_type) { - switch (type) { - case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: - case ColumnDataAllocatorType::HYBRID: - alloc.buffer_manager = &BufferManager::GetBufferManager(context); - break; - case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: - alloc.allocator = &Allocator::Get(context); - break; - default: - throw InternalException("Unrecognized column data allocator type"); - } -} - -ColumnDataAllocator::ColumnDataAllocator(ColumnDataAllocator &other) { - type = other.GetType(); - switch (type) { - case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: - case ColumnDataAllocatorType::HYBRID: - alloc.allocator = other.alloc.allocator; - break; - case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: - alloc.buffer_manager = other.alloc.buffer_manager; - break; - default: - throw InternalException("Unrecognized column data allocator type"); - } -} - -BufferHandle ColumnDataAllocator::Pin(uint32_t block_id) { - D_ASSERT(type == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR || type == ColumnDataAllocatorType::HYBRID); - shared_ptr handle; - if (shared) { - // we only need to grab the lock when accessing the vector, because vector access is not thread-safe: - // the vector can be resized by another thread while we try to access it - lock_guard guard(lock); - 
handle = blocks[block_id].handle; - } else { - handle = blocks[block_id].handle; - } - return alloc.buffer_manager->Pin(handle); -} - -BufferHandle ColumnDataAllocator::AllocateBlock(idx_t size) { - D_ASSERT(type == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR || type == ColumnDataAllocatorType::HYBRID); - auto block_size = MaxValue(size, Storage::BLOCK_SIZE); - BlockMetaData data; - data.size = 0; - data.capacity = block_size; - auto pin = alloc.buffer_manager->Allocate(block_size, false, &data.handle); - blocks.push_back(std::move(data)); - return pin; -} - -void ColumnDataAllocator::AllocateEmptyBlock(idx_t size) { - auto allocation_amount = MaxValue(NextPowerOfTwo(size), 4096); - if (!blocks.empty()) { - idx_t last_capacity = blocks.back().capacity; - auto next_capacity = MinValue(last_capacity * 2, last_capacity + Storage::BLOCK_SIZE); - allocation_amount = MaxValue(next_capacity, allocation_amount); - } - D_ASSERT(type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR); - BlockMetaData data; - data.size = 0; - data.capacity = allocation_amount; - data.handle = nullptr; - blocks.push_back(std::move(data)); -} - -void ColumnDataAllocator::AssignPointer(uint32_t &block_id, uint32_t &offset, data_ptr_t pointer) { - auto pointer_value = uintptr_t(pointer); - if (sizeof(uintptr_t) == sizeof(uint32_t)) { - block_id = uint32_t(pointer_value); - } else if (sizeof(uintptr_t) == sizeof(uint64_t)) { - block_id = uint32_t(pointer_value & 0xFFFFFFFF); - offset = uint32_t(pointer_value >> 32); - } else { - throw InternalException("ColumnDataCollection: Architecture not supported!?"); - } -} - -void ColumnDataAllocator::AllocateBuffer(idx_t size, uint32_t &block_id, uint32_t &offset, - ChunkManagementState *chunk_state) { - D_ASSERT(allocated_data.empty()); - if (blocks.empty() || blocks.back().Capacity() < size) { - auto pinned_block = AllocateBlock(size); - if (chunk_state) { - D_ASSERT(!blocks.empty()); - auto new_block_id = blocks.size() - 1; - 
chunk_state->handles[new_block_id] = std::move(pinned_block); - } - } - auto &block = blocks.back(); - D_ASSERT(size <= block.capacity - block.size); - block_id = blocks.size() - 1; - if (chunk_state && chunk_state->handles.find(block_id) == chunk_state->handles.end()) { - // not guaranteed to be pinned already by this thread (if shared allocator) - chunk_state->handles[block_id] = alloc.buffer_manager->Pin(blocks[block_id].handle); - } - offset = block.size; - block.size += size; -} - -void ColumnDataAllocator::AllocateMemory(idx_t size, uint32_t &block_id, uint32_t &offset, - ChunkManagementState *chunk_state) { - D_ASSERT(blocks.size() == allocated_data.size()); - if (blocks.empty() || blocks.back().Capacity() < size) { - AllocateEmptyBlock(size); - auto &last_block = blocks.back(); - auto allocated = alloc.allocator->Allocate(last_block.capacity); - allocated_data.push_back(std::move(allocated)); - } - auto &block = blocks.back(); - D_ASSERT(size <= block.capacity - block.size); - AssignPointer(block_id, offset, allocated_data.back().get() + block.size); - block.size += size; -} - -void ColumnDataAllocator::AllocateData(idx_t size, uint32_t &block_id, uint32_t &offset, - ChunkManagementState *chunk_state) { - switch (type) { - case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: - case ColumnDataAllocatorType::HYBRID: - if (shared) { - lock_guard guard(lock); - AllocateBuffer(size, block_id, offset, chunk_state); - } else { - AllocateBuffer(size, block_id, offset, chunk_state); - } - break; - case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: - D_ASSERT(!shared); - AllocateMemory(size, block_id, offset, chunk_state); - break; - default: - throw InternalException("Unrecognized allocator type"); - } -} - -void ColumnDataAllocator::Initialize(ColumnDataAllocator &other) { - D_ASSERT(other.HasBlocks()); - blocks.push_back(other.blocks.back()); -} - -data_ptr_t ColumnDataAllocator::GetDataPointer(ChunkManagementState &state, uint32_t block_id, uint32_t offset) { 
- if (type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { - // in-memory allocator: construct pointer from block_id and offset - if (sizeof(uintptr_t) == sizeof(uint32_t)) { - uintptr_t pointer_value = uintptr_t(block_id); - return (data_ptr_t)pointer_value; // NOLINT - convert from pointer value back to pointer - } else if (sizeof(uintptr_t) == sizeof(uint64_t)) { - uintptr_t pointer_value = (uintptr_t(offset) << 32) | uintptr_t(block_id); - return (data_ptr_t)pointer_value; // NOLINT - convert from pointer value back to pointer - } else { - throw InternalException("ColumnDataCollection: Architecture not supported!?"); - } - } - D_ASSERT(state.handles.find(block_id) != state.handles.end()); - return state.handles[block_id].Ptr() + offset; -} - -void ColumnDataAllocator::UnswizzlePointers(ChunkManagementState &state, Vector &result, idx_t v_offset, uint16_t count, - uint32_t block_id, uint32_t offset) { - D_ASSERT(result.GetType().InternalType() == PhysicalType::VARCHAR); - lock_guard guard(lock); - - auto &validity = FlatVector::Validity(result); - auto strings = FlatVector::GetData(result); - - // find first non-inlined string - uint32_t i = v_offset; - const uint32_t end = v_offset + count; - for (; i < end; i++) { - if (!validity.RowIsValid(i)) { - continue; - } - if (!strings[i].IsInlined()) { - break; - } - } - // at least one string must be non-inlined, otherwise this function should not be called - D_ASSERT(i < end); - - auto base_ptr = char_ptr_cast(GetDataPointer(state, block_id, offset)); - if (strings[i].GetData() == base_ptr) { - // pointers are still valid - return; - } - - // pointer mismatch! 
pointers are invalid, set them correctly - for (; i < end; i++) { - if (!validity.RowIsValid(i)) { - continue; - } - if (strings[i].IsInlined()) { - continue; - } - strings[i].SetPointer(base_ptr); - base_ptr += strings[i].GetSize(); - } -} - -void ColumnDataAllocator::DeleteBlock(uint32_t block_id) { - blocks[block_id].handle->SetCanDestroy(true); -} - -Allocator &ColumnDataAllocator::GetAllocator() { - return type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR ? *alloc.allocator - : alloc.buffer_manager->GetBufferAllocator(); -} - -void ColumnDataAllocator::InitializeChunkState(ChunkManagementState &state, ChunkMetaData &chunk) { - if (type != ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR && type != ColumnDataAllocatorType::HYBRID) { - // nothing to pin - return; - } - // release any handles that are no longer required - bool found_handle; - do { - found_handle = false; - for (auto it = state.handles.begin(); it != state.handles.end(); it++) { - if (chunk.block_ids.find(it->first) != chunk.block_ids.end()) { - // still required: do not release - continue; - } - state.handles.erase(it); - found_handle = true; - break; - } - } while (found_handle); - - // grab any handles that are now required - for (auto &block_id : chunk.block_ids) { - if (state.handles.find(block_id) != state.handles.end()) { - // already pinned: don't need to do anything - continue; - } - state.handles[block_id] = Pin(block_id); - } -} - -uint32_t BlockMetaData::Capacity() { - D_ASSERT(size <= capacity); - return capacity - size; -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -struct ColumnDataMetaData; - -typedef void (*column_data_copy_function_t)(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, - Vector &source, idx_t offset, idx_t copy_count); - -struct ColumnDataCopyFunction { - column_data_copy_function_t function; - vector child_functions; -}; - -struct ColumnDataMetaData { - ColumnDataMetaData(ColumnDataCopyFunction ©_function, 
ColumnDataCollectionSegment &segment, - ColumnDataAppendState &state, ChunkMetaData &chunk_data, VectorDataIndex vector_data_index) - : copy_function(copy_function), segment(segment), state(state), chunk_data(chunk_data), - vector_data_index(vector_data_index) { - } - ColumnDataMetaData(ColumnDataCopyFunction ©_function, ColumnDataMetaData &parent, - VectorDataIndex vector_data_index) - : copy_function(copy_function), segment(parent.segment), state(parent.state), chunk_data(parent.chunk_data), - vector_data_index(vector_data_index) { - } - - ColumnDataCopyFunction ©_function; - ColumnDataCollectionSegment &segment; - ColumnDataAppendState &state; - ChunkMetaData &chunk_data; - VectorDataIndex vector_data_index; - idx_t child_list_size = DConstants::INVALID_INDEX; - - VectorMetaData &GetVectorMetaData() { - return segment.GetVectorData(vector_data_index); - } -}; - -//! Explicitly initialized without types -ColumnDataCollection::ColumnDataCollection(Allocator &allocator_p) { - types.clear(); - count = 0; - this->finished_append = false; - allocator = make_shared(allocator_p); -} - -ColumnDataCollection::ColumnDataCollection(Allocator &allocator_p, vector types_p) { - Initialize(std::move(types_p)); - allocator = make_shared(allocator_p); -} - -ColumnDataCollection::ColumnDataCollection(BufferManager &buffer_manager, vector types_p) { - Initialize(std::move(types_p)); - allocator = make_shared(buffer_manager); -} - -ColumnDataCollection::ColumnDataCollection(shared_ptr allocator_p, vector types_p) { - Initialize(std::move(types_p)); - this->allocator = std::move(allocator_p); -} - -ColumnDataCollection::ColumnDataCollection(ClientContext &context, vector types_p, - ColumnDataAllocatorType type) - : ColumnDataCollection(make_shared(context, type), std::move(types_p)) { - D_ASSERT(!types.empty()); -} - -ColumnDataCollection::ColumnDataCollection(ColumnDataCollection &other) - : ColumnDataCollection(other.allocator, other.types) { - other.finished_append = true; - 
D_ASSERT(!types.empty()); -} - -ColumnDataCollection::~ColumnDataCollection() { -} - -void ColumnDataCollection::Initialize(vector types_p) { - this->types = std::move(types_p); - this->count = 0; - this->finished_append = false; - D_ASSERT(!types.empty()); - copy_functions.reserve(types.size()); - for (auto &type : types) { - copy_functions.push_back(GetCopyFunction(type)); - } -} - -void ColumnDataCollection::CreateSegment() { - segments.emplace_back(make_uniq(allocator, types)); -} - -Allocator &ColumnDataCollection::GetAllocator() const { - return allocator->GetAllocator(); -} - -idx_t ColumnDataCollection::SizeInBytes() const { - idx_t total_size = 0; - for (const auto &segment : segments) { - total_size += segment->SizeInBytes(); - } - return total_size; -} - -//===--------------------------------------------------------------------===// -// ColumnDataRow -//===--------------------------------------------------------------------===// -ColumnDataRow::ColumnDataRow(DataChunk &chunk_p, idx_t row_index, idx_t base_index) - : chunk(chunk_p), row_index(row_index), base_index(base_index) { -} - -Value ColumnDataRow::GetValue(idx_t column_index) const { - D_ASSERT(column_index < chunk.ColumnCount()); - D_ASSERT(row_index < chunk.size()); - return chunk.data[column_index].GetValue(row_index); -} - -idx_t ColumnDataRow::RowIndex() const { - return base_index + row_index; -} - -//===--------------------------------------------------------------------===// -// ColumnDataRowCollection -//===--------------------------------------------------------------------===// -ColumnDataRowCollection::ColumnDataRowCollection(const ColumnDataCollection &collection) { - if (collection.Count() == 0) { - return; - } - // read all the chunks - ColumnDataScanState temp_scan_state; - collection.InitializeScan(temp_scan_state, ColumnDataScanProperties::DISALLOW_ZERO_COPY); - while (true) { - auto chunk = make_uniq(); - collection.InitializeScanChunk(*chunk); - if 
(!collection.Scan(temp_scan_state, *chunk)) { - break; - } - chunks.push_back(std::move(chunk)); - } - // now create all of the column data rows - rows.reserve(collection.Count()); - idx_t base_row = 0; - for (auto &chunk : chunks) { - for (idx_t row_idx = 0; row_idx < chunk->size(); row_idx++) { - rows.emplace_back(*chunk, row_idx, base_row); - } - base_row += chunk->size(); - } -} - -ColumnDataRow &ColumnDataRowCollection::operator[](idx_t i) { - return rows[i]; -} - -const ColumnDataRow &ColumnDataRowCollection::operator[](idx_t i) const { - return rows[i]; -} - -Value ColumnDataRowCollection::GetValue(idx_t column, idx_t index) const { - return rows[index].GetValue(column); -} - -//===--------------------------------------------------------------------===// -// ColumnDataChunkIterator -//===--------------------------------------------------------------------===// -ColumnDataChunkIterationHelper ColumnDataCollection::Chunks() const { - vector column_ids; - for (idx_t i = 0; i < ColumnCount(); i++) { - column_ids.push_back(i); - } - return Chunks(column_ids); -} - -ColumnDataChunkIterationHelper ColumnDataCollection::Chunks(vector column_ids) const { - return ColumnDataChunkIterationHelper(*this, std::move(column_ids)); -} - -ColumnDataChunkIterationHelper::ColumnDataChunkIterationHelper(const ColumnDataCollection &collection_p, - vector column_ids_p) - : collection(collection_p), column_ids(std::move(column_ids_p)) { -} - -ColumnDataChunkIterationHelper::ColumnDataChunkIterator::ColumnDataChunkIterator( - const ColumnDataCollection *collection_p, vector column_ids_p) - : collection(collection_p), scan_chunk(make_shared()), row_index(0) { - if (!collection) { - return; - } - collection->InitializeScan(scan_state, std::move(column_ids_p)); - collection->InitializeScanChunk(scan_state, *scan_chunk); - collection->Scan(scan_state, *scan_chunk); -} - -void ColumnDataChunkIterationHelper::ColumnDataChunkIterator::Next() { - if (!collection) { - return; - } - if 
(!collection->Scan(scan_state, *scan_chunk)) { - collection = nullptr; - row_index = 0; - } else { - row_index += scan_chunk->size(); - } -} - -ColumnDataChunkIterationHelper::ColumnDataChunkIterator & -ColumnDataChunkIterationHelper::ColumnDataChunkIterator::operator++() { - Next(); - return *this; -} - -bool ColumnDataChunkIterationHelper::ColumnDataChunkIterator::operator!=(const ColumnDataChunkIterator &other) const { - return collection != other.collection || row_index != other.row_index; -} - -DataChunk &ColumnDataChunkIterationHelper::ColumnDataChunkIterator::operator*() const { - return *scan_chunk; -} - -//===--------------------------------------------------------------------===// -// ColumnDataRowIterator -//===--------------------------------------------------------------------===// -ColumnDataRowIterationHelper ColumnDataCollection::Rows() const { - return ColumnDataRowIterationHelper(*this); -} - -ColumnDataRowIterationHelper::ColumnDataRowIterationHelper(const ColumnDataCollection &collection_p) - : collection(collection_p) { -} - -ColumnDataRowIterationHelper::ColumnDataRowIterator::ColumnDataRowIterator(const ColumnDataCollection *collection_p) - : collection(collection_p), scan_chunk(make_shared()), current_row(*scan_chunk, 0, 0) { - if (!collection) { - return; - } - collection->InitializeScan(scan_state); - collection->InitializeScanChunk(*scan_chunk); - collection->Scan(scan_state, *scan_chunk); -} - -void ColumnDataRowIterationHelper::ColumnDataRowIterator::Next() { - if (!collection) { - return; - } - current_row.row_index++; - if (current_row.row_index >= scan_chunk->size()) { - current_row.base_index += scan_chunk->size(); - current_row.row_index = 0; - if (!collection->Scan(scan_state, *scan_chunk)) { - // exhausted collection: move iterator to nop state - current_row.base_index = 0; - collection = nullptr; - } - } -} - -ColumnDataRowIterationHelper::ColumnDataRowIterator ColumnDataRowIterationHelper::begin() { // NOLINT - return 
ColumnDataRowIterationHelper::ColumnDataRowIterator(collection.Count() == 0 ? nullptr : &collection); -} -ColumnDataRowIterationHelper::ColumnDataRowIterator ColumnDataRowIterationHelper::end() { // NOLINT - return ColumnDataRowIterationHelper::ColumnDataRowIterator(nullptr); -} - -ColumnDataRowIterationHelper::ColumnDataRowIterator &ColumnDataRowIterationHelper::ColumnDataRowIterator::operator++() { - Next(); - return *this; -} - -bool ColumnDataRowIterationHelper::ColumnDataRowIterator::operator!=(const ColumnDataRowIterator &other) const { - return collection != other.collection || current_row.row_index != other.current_row.row_index || - current_row.base_index != other.current_row.base_index; -} - -const ColumnDataRow &ColumnDataRowIterationHelper::ColumnDataRowIterator::operator*() const { - return current_row; -} - -//===--------------------------------------------------------------------===// -// Append -//===--------------------------------------------------------------------===// -void ColumnDataCollection::InitializeAppend(ColumnDataAppendState &state) { - D_ASSERT(!finished_append); - state.vector_data.resize(types.size()); - if (segments.empty()) { - CreateSegment(); - } - auto &segment = *segments.back(); - if (segment.chunk_data.empty()) { - segment.AllocateNewChunk(); - } - segment.InitializeChunkState(segment.chunk_data.size() - 1, state.current_chunk_state); -} - -void ColumnDataCopyValidity(const UnifiedVectorFormat &source_data, validity_t *target, idx_t source_offset, - idx_t target_offset, idx_t copy_count) { - ValidityMask validity(target); - if (target_offset == 0) { - // first time appending to this vector - // all data here is still uninitialized - // initialize the validity mask to set all to valid - validity.SetAllValid(STANDARD_VECTOR_SIZE); - } - // FIXME: we can do something more optimized here using bitshifts & bitwise ors - if (!source_data.validity.AllValid()) { - for (idx_t i = 0; i < copy_count; i++) { - auto idx = 
source_data.sel->get_index(source_offset + i); - if (!source_data.validity.RowIsValid(idx)) { - validity.SetInvalid(target_offset + i); - } - } - } -} - -template -struct BaseValueCopy { - static idx_t TypeSize() { - return sizeof(T); - } - - template - static void Assign(ColumnDataMetaData &meta_data, data_ptr_t target, data_ptr_t source, idx_t target_idx, - idx_t source_idx) { - auto result_data = (T *)target; - auto source_data = (T *)source; - result_data[target_idx] = OP::Operation(meta_data, source_data[source_idx]); - } -}; - -template -struct StandardValueCopy : public BaseValueCopy { - static T Operation(ColumnDataMetaData &, T input) { - return input; - } -}; - -struct StringValueCopy : public BaseValueCopy { - static string_t Operation(ColumnDataMetaData &meta_data, string_t input) { - return input.IsInlined() ? input : meta_data.segment.heap->AddBlob(input); - } -}; - -struct ConstListValueCopy : public BaseValueCopy { - using TYPE = list_entry_t; - - static TYPE Operation(ColumnDataMetaData &meta_data, TYPE input) { - input.offset = meta_data.child_list_size; - return input; - } -}; - -struct ListValueCopy : public BaseValueCopy { - using TYPE = list_entry_t; - - static TYPE Operation(ColumnDataMetaData &meta_data, TYPE input) { - input.offset = meta_data.child_list_size; - meta_data.child_list_size += input.length; - return input; - } -}; - -struct StructValueCopy { - static idx_t TypeSize() { - return 0; - } - - template - static void Assign(ColumnDataMetaData &meta_data, data_ptr_t target, data_ptr_t source, idx_t target_idx, - idx_t source_idx) { - } -}; - -template -static void TemplatedColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, - Vector &source, idx_t offset, idx_t count) { - auto &segment = meta_data.segment; - auto &append_state = meta_data.state; - - auto current_index = meta_data.vector_data_index; - idx_t remaining = count; - while (remaining > 0) { - auto ¤t_segment = 
segment.GetVectorData(current_index); - idx_t append_count = MinValue(STANDARD_VECTOR_SIZE - current_segment.count, remaining); - - auto base_ptr = segment.allocator->GetDataPointer(append_state.current_chunk_state, current_segment.block_id, - current_segment.offset); - auto validity_data = ColumnDataCollectionSegment::GetValidityPointer(base_ptr, OP::TypeSize()); - - ValidityMask result_validity(validity_data); - if (current_segment.count == 0) { - // first time appending to this vector - // all data here is still uninitialized - // initialize the validity mask to set all to valid - result_validity.SetAllValid(STANDARD_VECTOR_SIZE); - } - for (idx_t i = 0; i < append_count; i++) { - auto source_idx = source_data.sel->get_index(offset + i); - if (source_data.validity.RowIsValid(source_idx)) { - OP::template Assign(meta_data, base_ptr, source_data.data, current_segment.count + i, source_idx); - } else { - result_validity.SetInvalid(current_segment.count + i); - } - } - current_segment.count += append_count; - offset += append_count; - remaining -= append_count; - if (remaining > 0) { - // need to append more, check if we need to allocate a new vector or not - if (!current_segment.next_data.IsValid()) { - segment.AllocateVector(source.GetType(), meta_data.chunk_data, append_state, current_index); - } - D_ASSERT(segment.GetVectorData(current_index).next_data.IsValid()); - current_index = segment.GetVectorData(current_index).next_data; - } - } -} - -template -static void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, - idx_t offset, idx_t copy_count) { - TemplatedColumnDataCopy>(meta_data, source_data, source, offset, copy_count); -} - -template <> -void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, - idx_t offset, idx_t copy_count) { - - const auto &allocator_type = meta_data.segment.allocator->GetType(); - if (allocator_type == 
ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR || - allocator_type == ColumnDataAllocatorType::HYBRID) { - // strings cannot be spilled to disk - use StringHeap - TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); - return; - } - D_ASSERT(allocator_type == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); - - auto &segment = meta_data.segment; - auto &append_state = meta_data.state; - - VectorDataIndex child_index; - if (meta_data.GetVectorMetaData().child_index.IsValid()) { - // find the last child index - child_index = segment.GetChildIndex(meta_data.GetVectorMetaData().child_index); - auto next_child_index = segment.GetVectorData(child_index).next_data; - while (next_child_index.IsValid()) { - child_index = next_child_index; - next_child_index = segment.GetVectorData(child_index).next_data; - } - } - - auto current_index = meta_data.vector_data_index; - idx_t remaining = copy_count; - while (remaining > 0) { - // how many values fit in the current string vector - idx_t vector_remaining = - MinValue(STANDARD_VECTOR_SIZE - segment.GetVectorData(current_index).count, remaining); - - // 'append_count' is less if we cannot fit that amount of non-inlined strings on one buffer-managed block - idx_t append_count; - idx_t heap_size = 0; - const auto source_entries = UnifiedVectorFormat::GetData(source_data); - for (append_count = 0; append_count < vector_remaining; append_count++) { - auto source_idx = source_data.sel->get_index(offset + append_count); - if (!source_data.validity.RowIsValid(source_idx)) { - continue; - } - const auto &entry = source_entries[source_idx]; - if (entry.IsInlined()) { - continue; - } - if (heap_size + entry.GetSize() > Storage::BLOCK_SIZE) { - break; - } - heap_size += entry.GetSize(); - } - - if (vector_remaining != 0 && append_count == 0) { - // single string is longer than Storage::BLOCK_SIZE - // we allocate one block at a time for long strings - auto source_idx = source_data.sel->get_index(offset + 
append_count); - D_ASSERT(source_data.validity.RowIsValid(source_idx)); - D_ASSERT(!source_entries[source_idx].IsInlined()); - D_ASSERT(source_entries[source_idx].GetSize() > Storage::BLOCK_SIZE); - heap_size += source_entries[source_idx].GetSize(); - append_count++; - } - - // allocate string heap for the next 'append_count' strings - data_ptr_t heap_ptr = nullptr; - if (heap_size != 0) { - child_index = segment.AllocateStringHeap(heap_size, meta_data.chunk_data, append_state, child_index); - if (!meta_data.GetVectorMetaData().child_index.IsValid()) { - meta_data.GetVectorMetaData().child_index = meta_data.segment.AddChildIndex(child_index); - } - auto &child_segment = segment.GetVectorData(child_index); - heap_ptr = segment.allocator->GetDataPointer(append_state.current_chunk_state, child_segment.block_id, - child_segment.offset); - } - - auto ¤t_segment = segment.GetVectorData(current_index); - auto base_ptr = segment.allocator->GetDataPointer(append_state.current_chunk_state, current_segment.block_id, - current_segment.offset); - auto validity_data = ColumnDataCollectionSegment::GetValidityPointer(base_ptr, sizeof(string_t)); - ValidityMask target_validity(validity_data); - if (current_segment.count == 0) { - // first time appending to this vector - // all data here is still uninitialized - // initialize the validity mask to set all to valid - target_validity.SetAllValid(STANDARD_VECTOR_SIZE); - } - - auto target_entries = reinterpret_cast(base_ptr); - for (idx_t i = 0; i < append_count; i++) { - auto source_idx = source_data.sel->get_index(offset + i); - auto target_idx = current_segment.count + i; - if (!source_data.validity.RowIsValid(source_idx)) { - target_validity.SetInvalid(target_idx); - continue; - } - const auto &source_entry = source_entries[source_idx]; - auto &target_entry = target_entries[target_idx]; - if (source_entry.IsInlined()) { - target_entry = source_entry; - } else { - D_ASSERT(heap_ptr != nullptr); - memcpy(heap_ptr, 
source_entry.GetData(), source_entry.GetSize()); - target_entry = string_t(const_char_ptr_cast(heap_ptr), source_entry.GetSize()); - heap_ptr += source_entry.GetSize(); - } - } - - if (heap_size != 0) { - current_segment.swizzle_data.emplace_back(child_index, current_segment.count, append_count); - } - - current_segment.count += append_count; - offset += append_count; - remaining -= append_count; - - if (vector_remaining - append_count == 0) { - // need to append more, check if we need to allocate a new vector or not - if (!current_segment.next_data.IsValid()) { - segment.AllocateVector(source.GetType(), meta_data.chunk_data, append_state, current_index); - } - D_ASSERT(segment.GetVectorData(current_index).next_data.IsValid()); - current_index = segment.GetVectorData(current_index).next_data; - } - } -} - -template <> -void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, - idx_t offset, idx_t copy_count) { - - auto &segment = meta_data.segment; - - auto &child_vector = ListVector::GetEntry(source); - auto &child_type = child_vector.GetType(); - - if (!meta_data.GetVectorMetaData().child_index.IsValid()) { - auto child_index = segment.AllocateVector(child_type, meta_data.chunk_data, meta_data.state); - meta_data.GetVectorMetaData().child_index = meta_data.segment.AddChildIndex(child_index); - } - - auto &child_function = meta_data.copy_function.child_functions[0]; - auto child_index = segment.GetChildIndex(meta_data.GetVectorMetaData().child_index); - - // figure out the current list size by traversing the set of child entries - idx_t current_list_size = 0; - auto current_child_index = child_index; - while (current_child_index.IsValid()) { - auto &child_vdata = segment.GetVectorData(current_child_index); - current_list_size += child_vdata.count; - current_child_index = child_vdata.next_data; - } - - // set the child vector - UnifiedVectorFormat child_vector_data; - ColumnDataMetaData 
child_meta_data(child_function, meta_data, child_index); - auto info = ListVector::GetConsecutiveChildListInfo(source, offset, copy_count); - - if (info.needs_slicing) { - SelectionVector sel(info.child_list_info.length); - ListVector::GetConsecutiveChildSelVector(source, sel, offset, copy_count); - - auto sliced_child_vector = Vector(child_vector, sel, info.child_list_info.length); - sliced_child_vector.Flatten(info.child_list_info.length); - info.child_list_info.offset = 0; - - sliced_child_vector.ToUnifiedFormat(info.child_list_info.length, child_vector_data); - child_function.function(child_meta_data, child_vector_data, sliced_child_vector, info.child_list_info.offset, - info.child_list_info.length); - - } else { - child_vector.ToUnifiedFormat(info.child_list_info.length, child_vector_data); - child_function.function(child_meta_data, child_vector_data, child_vector, info.child_list_info.offset, - info.child_list_info.length); - } - - // now copy the list entries - meta_data.child_list_size = current_list_size; - if (info.is_constant) { - TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); - } else { - TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); - } -} - -void ColumnDataCopyStruct(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, - idx_t offset, idx_t copy_count) { - auto &segment = meta_data.segment; - - // copy the NULL values for the main struct vector - TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); - - auto &child_types = StructType::GetChildTypes(source.GetType()); - // now copy all the child vectors - D_ASSERT(meta_data.GetVectorMetaData().child_index.IsValid()); - auto &child_vectors = StructVector::GetEntries(source); - for (idx_t child_idx = 0; child_idx < child_types.size(); child_idx++) { - auto &child_function = meta_data.copy_function.child_functions[child_idx]; - auto child_index = 
segment.GetChildIndex(meta_data.GetVectorMetaData().child_index, child_idx); - ColumnDataMetaData child_meta_data(child_function, meta_data, child_index); - - UnifiedVectorFormat child_data; - child_vectors[child_idx]->ToUnifiedFormat(copy_count, child_data); - - child_function.function(child_meta_data, child_data, *child_vectors[child_idx], offset, copy_count); - } -} - -ColumnDataCopyFunction ColumnDataCollection::GetCopyFunction(const LogicalType &type) { - ColumnDataCopyFunction result; - column_data_copy_function_t function; - switch (type.InternalType()) { - case PhysicalType::BOOL: - function = ColumnDataCopy; - break; - case PhysicalType::INT8: - function = ColumnDataCopy; - break; - case PhysicalType::INT16: - function = ColumnDataCopy; - break; - case PhysicalType::INT32: - function = ColumnDataCopy; - break; - case PhysicalType::INT64: - function = ColumnDataCopy; - break; - case PhysicalType::INT128: - function = ColumnDataCopy; - break; - case PhysicalType::UINT8: - function = ColumnDataCopy; - break; - case PhysicalType::UINT16: - function = ColumnDataCopy; - break; - case PhysicalType::UINT32: - function = ColumnDataCopy; - break; - case PhysicalType::UINT64: - function = ColumnDataCopy; - break; - case PhysicalType::FLOAT: - function = ColumnDataCopy; - break; - case PhysicalType::DOUBLE: - function = ColumnDataCopy; - break; - case PhysicalType::INTERVAL: - function = ColumnDataCopy; - break; - case PhysicalType::VARCHAR: - function = ColumnDataCopy; - break; - case PhysicalType::STRUCT: { - function = ColumnDataCopyStruct; - auto &child_types = StructType::GetChildTypes(type); - for (auto &kv : child_types) { - result.child_functions.push_back(GetCopyFunction(kv.second)); - } - break; - } - case PhysicalType::LIST: { - function = ColumnDataCopy; - auto child_function = GetCopyFunction(ListType::GetChildType(type)); - result.child_functions.push_back(child_function); - break; - } - default: - throw InternalException("Unsupported type for 
ColumnDataCollection::GetCopyFunction"); - } - result.function = function; - return result; -} - -static bool IsComplexType(const LogicalType &type) { - switch (type.InternalType()) { - case PhysicalType::STRUCT: - case PhysicalType::LIST: - return true; - default: - return false; - }; -} - -void ColumnDataCollection::Append(ColumnDataAppendState &state, DataChunk &input) { - D_ASSERT(!finished_append); - D_ASSERT(types == input.GetTypes()); - - auto &segment = *segments.back(); - for (idx_t vector_idx = 0; vector_idx < types.size(); vector_idx++) { - if (IsComplexType(input.data[vector_idx].GetType())) { - input.data[vector_idx].Flatten(input.size()); - } - input.data[vector_idx].ToUnifiedFormat(input.size(), state.vector_data[vector_idx]); - } - - idx_t remaining = input.size(); - while (remaining > 0) { - auto &chunk_data = segment.chunk_data.back(); - idx_t append_amount = MinValue(remaining, STANDARD_VECTOR_SIZE - chunk_data.count); - if (append_amount > 0) { - idx_t offset = input.size() - remaining; - for (idx_t vector_idx = 0; vector_idx < types.size(); vector_idx++) { - ColumnDataMetaData meta_data(copy_functions[vector_idx], segment, state, chunk_data, - chunk_data.vector_data[vector_idx]); - copy_functions[vector_idx].function(meta_data, state.vector_data[vector_idx], input.data[vector_idx], - offset, append_amount); - } - chunk_data.count += append_amount; - } - remaining -= append_amount; - if (remaining > 0) { - // more to do - // allocate a new chunk - segment.AllocateNewChunk(); - segment.InitializeChunkState(segment.chunk_data.size() - 1, state.current_chunk_state); - } - } - segment.count += input.size(); - count += input.size(); -} - -void ColumnDataCollection::Append(DataChunk &input) { - ColumnDataAppendState state; - InitializeAppend(state); - Append(state, input); -} - -//===--------------------------------------------------------------------===// -// Scan -//===--------------------------------------------------------------------===// -void 
ColumnDataCollection::InitializeScan(ColumnDataScanState &state, ColumnDataScanProperties properties) const { - vector column_ids; - column_ids.reserve(types.size()); - for (idx_t i = 0; i < types.size(); i++) { - column_ids.push_back(i); - } - InitializeScan(state, std::move(column_ids), properties); -} - -void ColumnDataCollection::InitializeScan(ColumnDataScanState &state, vector column_ids, - ColumnDataScanProperties properties) const { - state.chunk_index = 0; - state.segment_index = 0; - state.current_row_index = 0; - state.next_row_index = 0; - state.current_chunk_state.handles.clear(); - state.properties = properties; - state.column_ids = std::move(column_ids); -} - -void ColumnDataCollection::InitializeScan(ColumnDataParallelScanState &state, - ColumnDataScanProperties properties) const { - InitializeScan(state.scan_state, properties); -} - -void ColumnDataCollection::InitializeScan(ColumnDataParallelScanState &state, vector column_ids, - ColumnDataScanProperties properties) const { - InitializeScan(state.scan_state, std::move(column_ids), properties); -} - -bool ColumnDataCollection::Scan(ColumnDataParallelScanState &state, ColumnDataLocalScanState &lstate, - DataChunk &result) const { - result.Reset(); - - idx_t chunk_index; - idx_t segment_index; - idx_t row_index; - { - lock_guard l(state.lock); - if (!NextScanIndex(state.scan_state, chunk_index, segment_index, row_index)) { - return false; - } - } - ScanAtIndex(state, lstate, result, chunk_index, segment_index, row_index); - return true; -} - -void ColumnDataCollection::InitializeScanChunk(DataChunk &chunk) const { - chunk.Initialize(allocator->GetAllocator(), types); -} - -void ColumnDataCollection::InitializeScanChunk(ColumnDataScanState &state, DataChunk &chunk) const { - D_ASSERT(!state.column_ids.empty()); - vector chunk_types; - chunk_types.reserve(state.column_ids.size()); - for (idx_t i = 0; i < state.column_ids.size(); i++) { - auto column_idx = state.column_ids[i]; - D_ASSERT(column_idx < 
types.size()); - chunk_types.push_back(types[column_idx]); - } - chunk.Initialize(allocator->GetAllocator(), chunk_types); -} - -bool ColumnDataCollection::NextScanIndex(ColumnDataScanState &state, idx_t &chunk_index, idx_t &segment_index, - idx_t &row_index) const { - row_index = state.current_row_index = state.next_row_index; - // check if we still have collections to scan - if (state.segment_index >= segments.size()) { - // no more data left in the scan - return false; - } - // check within the current collection if we still have chunks to scan - while (state.chunk_index >= segments[state.segment_index]->chunk_data.size()) { - // exhausted all chunks for this internal data structure: move to the next one - state.chunk_index = 0; - state.segment_index++; - state.current_chunk_state.handles.clear(); - if (state.segment_index >= segments.size()) { - return false; - } - } - state.next_row_index += segments[state.segment_index]->chunk_data[state.chunk_index].count; - segment_index = state.segment_index; - chunk_index = state.chunk_index++; - return true; -} - -void ColumnDataCollection::ScanAtIndex(ColumnDataParallelScanState &state, ColumnDataLocalScanState &lstate, - DataChunk &result, idx_t chunk_index, idx_t segment_index, - idx_t row_index) const { - if (segment_index != lstate.current_segment_index) { - lstate.current_chunk_state.handles.clear(); - lstate.current_segment_index = segment_index; - } - auto &segment = *segments[segment_index]; - lstate.current_chunk_state.properties = state.scan_state.properties; - segment.ReadChunk(chunk_index, lstate.current_chunk_state, result, state.scan_state.column_ids); - lstate.current_row_index = row_index; - result.Verify(); -} - -bool ColumnDataCollection::Scan(ColumnDataScanState &state, DataChunk &result) const { - result.Reset(); - - idx_t chunk_index; - idx_t segment_index; - idx_t row_index; - if (!NextScanIndex(state, chunk_index, segment_index, row_index)) { - return false; - } - - // found a chunk to scan -> 
scan it - auto &segment = *segments[segment_index]; - state.current_chunk_state.properties = state.properties; - segment.ReadChunk(chunk_index, state.current_chunk_state, result, state.column_ids); - result.Verify(); - return true; -} - -ColumnDataRowCollection ColumnDataCollection::GetRows() const { - return ColumnDataRowCollection(*this); -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -void ColumnDataCollection::Combine(ColumnDataCollection &other) { - if (other.count == 0) { - return; - } - if (types != other.types) { - throw InternalException("Attempting to combine ColumnDataCollections with mismatching types"); - } - this->count += other.count; - this->segments.reserve(segments.size() + other.segments.size()); - for (auto &other_seg : other.segments) { - segments.push_back(std::move(other_seg)); - } - other.Reset(); - Verify(); -} - -//===--------------------------------------------------------------------===// -// Fetch -//===--------------------------------------------------------------------===// -idx_t ColumnDataCollection::ChunkCount() const { - idx_t chunk_count = 0; - for (auto &segment : segments) { - chunk_count += segment->ChunkCount(); - } - return chunk_count; -} - -void ColumnDataCollection::FetchChunk(idx_t chunk_idx, DataChunk &result) const { - D_ASSERT(chunk_idx < ChunkCount()); - for (auto &segment : segments) { - if (chunk_idx >= segment->ChunkCount()) { - chunk_idx -= segment->ChunkCount(); - } else { - segment->FetchChunk(chunk_idx, result); - return; - } - } - throw InternalException("Failed to find chunk in ColumnDataCollection"); -} - -//===--------------------------------------------------------------------===// -// Helpers -//===--------------------------------------------------------------------===// -void ColumnDataCollection::Verify() { -#ifdef DEBUG - // verify counts - idx_t total_segment_count = 0; - 
for (auto &segment : segments) { - segment->Verify(); - total_segment_count += segment->count; - } - D_ASSERT(total_segment_count == this->count); -#endif -} - -// LCOV_EXCL_START -string ColumnDataCollection::ToString() const { - DataChunk chunk; - InitializeScanChunk(chunk); - - ColumnDataScanState scan_state; - InitializeScan(scan_state); - - string result = StringUtil::Format("ColumnDataCollection - [%llu Chunks, %llu Rows]\n", ChunkCount(), Count()); - idx_t chunk_idx = 0; - idx_t row_count = 0; - while (Scan(scan_state, chunk)) { - result += - StringUtil::Format("Chunk %llu - [Rows %llu - %llu]\n", chunk_idx, row_count, row_count + chunk.size()) + - chunk.ToString(); - chunk_idx++; - row_count += chunk.size(); - } - - return result; -} -// LCOV_EXCL_STOP - -void ColumnDataCollection::Print() const { - Printer::Print(ToString()); -} - -void ColumnDataCollection::Reset() { - count = 0; - segments.clear(); - - // Refreshes the ColumnDataAllocator to prevent holding on to allocated data unnecessarily - allocator = make_shared(*allocator); -} - -struct ValueResultEquals { - bool operator()(const Value &a, const Value &b) const { - return Value::DefaultValuesAreEqual(a, b); - } -}; - -bool ColumnDataCollection::ResultEquals(const ColumnDataCollection &left, const ColumnDataCollection &right, - string &error_message, bool ordered) { - if (left.ColumnCount() != right.ColumnCount()) { - error_message = "Column count mismatch"; - return false; - } - if (left.Count() != right.Count()) { - error_message = "Row count mismatch"; - return false; - } - auto left_rows = left.GetRows(); - auto right_rows = right.GetRows(); - for (idx_t r = 0; r < left.Count(); r++) { - for (idx_t c = 0; c < left.ColumnCount(); c++) { - auto lvalue = left_rows.GetValue(c, r); - auto rvalue = right_rows.GetValue(c, r); - - if (!Value::DefaultValuesAreEqual(lvalue, rvalue)) { - error_message = - StringUtil::Format("%s <> %s (row: %lld, col: %lld)\n", lvalue.ToString(), rvalue.ToString(), r, c); - 
break; - } - } - if (!error_message.empty()) { - if (ordered) { - return false; - } else { - break; - } - } - } - if (!error_message.empty()) { - // do an unordered comparison - bool found_all = true; - for (idx_t c = 0; c < left.ColumnCount(); c++) { - std::unordered_multiset lvalues; - for (idx_t r = 0; r < left.Count(); r++) { - auto lvalue = left_rows.GetValue(c, r); - lvalues.insert(lvalue); - } - for (idx_t r = 0; r < right.Count(); r++) { - auto rvalue = right_rows.GetValue(c, r); - auto entry = lvalues.find(rvalue); - if (entry == lvalues.end()) { - found_all = false; - break; - } - lvalues.erase(entry); - } - if (!found_all) { - break; - } - } - if (!found_all) { - return false; - } - error_message = string(); - } - return true; -} - -vector> ColumnDataCollection::GetHeapReferences() { - vector> result(segments.size(), nullptr); - for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { - result[segment_idx] = segments[segment_idx]->heap; - } - return result; -} - -ColumnDataAllocatorType ColumnDataCollection::GetAllocatorType() const { - return allocator->GetType(); -} - -const vector> &ColumnDataCollection::GetSegments() const { - return segments; -} - -void ColumnDataCollection::Serialize(Serializer &serializer) const { - vector> values; - values.resize(ColumnCount()); - for (auto &chunk : Chunks()) { - for (idx_t c = 0; c < chunk.ColumnCount(); c++) { - for (idx_t r = 0; r < chunk.size(); r++) { - values[c].push_back(chunk.GetValue(c, r)); - } - } - } - serializer.WriteProperty(100, "types", types); - serializer.WriteProperty(101, "values", values); -} - -unique_ptr ColumnDataCollection::Deserialize(Deserializer &deserializer) { - auto types = deserializer.ReadProperty>(100, "types"); - auto values = deserializer.ReadProperty>>(101, "values"); - - auto collection = make_uniq(Allocator::DefaultAllocator(), types); - if (values.empty()) { - return collection; - } - DataChunk chunk; - chunk.Initialize(Allocator::DefaultAllocator(), 
types); - - for (idx_t r = 0; r < values[0].size(); r++) { - for (idx_t c = 0; c < types.size(); c++) { - chunk.SetValue(c, chunk.size(), values[c][r]); - } - chunk.SetCardinality(chunk.size() + 1); - if (chunk.size() == STANDARD_VECTOR_SIZE) { - collection->Append(chunk); - chunk.Reset(); - } - } - if (chunk.size() > 0) { - collection->Append(chunk); - } - return collection; -} - -} // namespace duckdb - - - - -namespace duckdb { - -ColumnDataCollectionSegment::ColumnDataCollectionSegment(shared_ptr allocator_p, - vector types_p) - : allocator(std::move(allocator_p)), types(std::move(types_p)), count(0), - heap(make_shared(allocator->GetAllocator())) { -} - -idx_t ColumnDataCollectionSegment::GetDataSize(idx_t type_size) { - return AlignValue(type_size * STANDARD_VECTOR_SIZE); -} - -validity_t *ColumnDataCollectionSegment::GetValidityPointer(data_ptr_t base_ptr, idx_t type_size) { - return reinterpret_cast(base_ptr + GetDataSize(type_size)); -} - -VectorDataIndex ColumnDataCollectionSegment::AllocateVectorInternal(const LogicalType &type, ChunkMetaData &chunk_meta, - ChunkManagementState *chunk_state) { - VectorMetaData meta_data; - meta_data.count = 0; - - auto internal_type = type.InternalType(); - auto type_size = internal_type == PhysicalType::STRUCT ? 
0 : GetTypeIdSize(internal_type); - allocator->AllocateData(GetDataSize(type_size) + ValidityMask::STANDARD_MASK_SIZE, meta_data.block_id, - meta_data.offset, chunk_state); - if (allocator->GetType() == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR || - allocator->GetType() == ColumnDataAllocatorType::HYBRID) { - chunk_meta.block_ids.insert(meta_data.block_id); - } - - auto index = vector_data.size(); - vector_data.push_back(meta_data); - return VectorDataIndex(index); -} - -VectorDataIndex ColumnDataCollectionSegment::AllocateVector(const LogicalType &type, ChunkMetaData &chunk_meta, - ChunkManagementState *chunk_state, - VectorDataIndex prev_index) { - auto index = AllocateVectorInternal(type, chunk_meta, chunk_state); - if (prev_index.IsValid()) { - GetVectorData(prev_index).next_data = index; - } - if (type.InternalType() == PhysicalType::STRUCT) { - // initialize the struct children - auto &child_types = StructType::GetChildTypes(type); - auto base_child_index = ReserveChildren(child_types.size()); - for (idx_t child_idx = 0; child_idx < child_types.size(); child_idx++) { - VectorDataIndex prev_child_index; - if (prev_index.IsValid()) { - prev_child_index = GetChildIndex(GetVectorData(prev_index).child_index, child_idx); - } - auto child_index = AllocateVector(child_types[child_idx].second, chunk_meta, chunk_state, prev_child_index); - SetChildIndex(base_child_index, child_idx, child_index); - } - GetVectorData(index).child_index = base_child_index; - } - return index; -} - -VectorDataIndex ColumnDataCollectionSegment::AllocateVector(const LogicalType &type, ChunkMetaData &chunk_meta, - ColumnDataAppendState &append_state, - VectorDataIndex prev_index) { - return AllocateVector(type, chunk_meta, &append_state.current_chunk_state, prev_index); -} - -VectorDataIndex ColumnDataCollectionSegment::AllocateStringHeap(idx_t size, ChunkMetaData &chunk_meta, - ColumnDataAppendState &append_state, - VectorDataIndex prev_index) { - D_ASSERT(allocator->GetType() == 
ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); - D_ASSERT(size != 0); - - VectorMetaData meta_data; - meta_data.count = 0; - - allocator->AllocateData(AlignValue(size), meta_data.block_id, meta_data.offset, &append_state.current_chunk_state); - chunk_meta.block_ids.insert(meta_data.block_id); - - VectorDataIndex index(vector_data.size()); - vector_data.push_back(meta_data); - - if (prev_index.IsValid()) { - GetVectorData(prev_index).next_data = index; - } - - return index; -} - -void ColumnDataCollectionSegment::AllocateNewChunk() { - ChunkMetaData meta_data; - meta_data.count = 0; - meta_data.vector_data.reserve(types.size()); - for (idx_t i = 0; i < types.size(); i++) { - auto vector_idx = AllocateVector(types[i], meta_data); - meta_data.vector_data.push_back(vector_idx); - } - chunk_data.push_back(std::move(meta_data)); -} - -void ColumnDataCollectionSegment::InitializeChunkState(idx_t chunk_index, ChunkManagementState &state) { - auto &chunk = chunk_data[chunk_index]; - allocator->InitializeChunkState(state, chunk); -} - -VectorDataIndex ColumnDataCollectionSegment::GetChildIndex(VectorChildIndex index, idx_t child_entry) { - D_ASSERT(index.IsValid()); - D_ASSERT(index.index + child_entry < child_indices.size()); - return VectorDataIndex(child_indices[index.index + child_entry]); -} - -VectorChildIndex ColumnDataCollectionSegment::AddChildIndex(VectorDataIndex index) { - auto result = child_indices.size(); - child_indices.push_back(index); - return VectorChildIndex(result); -} - -VectorChildIndex ColumnDataCollectionSegment::ReserveChildren(idx_t child_count) { - auto result = child_indices.size(); - for (idx_t i = 0; i < child_count; i++) { - child_indices.emplace_back(); - } - return VectorChildIndex(result); -} - -void ColumnDataCollectionSegment::SetChildIndex(VectorChildIndex base_idx, idx_t child_number, VectorDataIndex index) { - D_ASSERT(base_idx.IsValid()); - D_ASSERT(index.IsValid()); - D_ASSERT(base_idx.index + child_number < 
child_indices.size()); - child_indices[base_idx.index + child_number] = index; -} - -idx_t ColumnDataCollectionSegment::ReadVectorInternal(ChunkManagementState &state, VectorDataIndex vector_index, - Vector &result) { - auto &vector_type = result.GetType(); - auto internal_type = vector_type.InternalType(); - auto type_size = GetTypeIdSize(internal_type); - auto &vdata = GetVectorData(vector_index); - - auto base_ptr = allocator->GetDataPointer(state, vdata.block_id, vdata.offset); - auto validity_data = GetValidityPointer(base_ptr, type_size); - if (!vdata.next_data.IsValid() && state.properties != ColumnDataScanProperties::DISALLOW_ZERO_COPY) { - // no next data, we can do a zero-copy read of this vector - FlatVector::SetData(result, base_ptr); - FlatVector::Validity(result).Initialize(validity_data); - return vdata.count; - } - - // the data for this vector is spread over multiple vector data entries - // we need to copy over the data for each of the vectors - // first figure out how many rows we need to copy by looping over all of the child vector indexes - idx_t vector_count = 0; - auto next_index = vector_index; - while (next_index.IsValid()) { - auto ¤t_vdata = GetVectorData(next_index); - vector_count += current_vdata.count; - next_index = current_vdata.next_data; - } - // resize the result vector - result.Resize(0, vector_count); - next_index = vector_index; - // now perform the copy of each of the vectors - auto target_data = FlatVector::GetData(result); - auto &target_validity = FlatVector::Validity(result); - idx_t current_offset = 0; - while (next_index.IsValid()) { - auto ¤t_vdata = GetVectorData(next_index); - base_ptr = allocator->GetDataPointer(state, current_vdata.block_id, current_vdata.offset); - validity_data = GetValidityPointer(base_ptr, type_size); - if (type_size > 0) { - memcpy(target_data + current_offset * type_size, base_ptr, current_vdata.count * type_size); - } - ValidityMask current_validity(validity_data); - 
target_validity.SliceInPlace(current_validity, current_offset, 0, current_vdata.count); - current_offset += current_vdata.count; - next_index = current_vdata.next_data; - } - return vector_count; -} - -idx_t ColumnDataCollectionSegment::ReadVector(ChunkManagementState &state, VectorDataIndex vector_index, - Vector &result) { - auto &vector_type = result.GetType(); - auto internal_type = vector_type.InternalType(); - auto &vdata = GetVectorData(vector_index); - if (vdata.count == 0) { - return 0; - } - auto vcount = ReadVectorInternal(state, vector_index, result); - if (internal_type == PhysicalType::LIST) { - // list: copy child - auto &child_vector = ListVector::GetEntry(result); - auto child_count = ReadVector(state, GetChildIndex(vdata.child_index), child_vector); - ListVector::SetListSize(result, child_count); - } else if (internal_type == PhysicalType::STRUCT) { - auto &child_vectors = StructVector::GetEntries(result); - for (idx_t child_idx = 0; child_idx < child_vectors.size(); child_idx++) { - auto child_count = - ReadVector(state, GetChildIndex(vdata.child_index, child_idx), *child_vectors[child_idx]); - if (child_count != vcount) { - throw InternalException("Column Data Collection: mismatch in struct child sizes"); - } - } - } else if (internal_type == PhysicalType::VARCHAR) { - if (allocator->GetType() == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { - auto next_index = vector_index; - idx_t offset = 0; - while (next_index.IsValid()) { - auto ¤t_vdata = GetVectorData(next_index); - for (auto &swizzle_segment : current_vdata.swizzle_data) { - auto &string_heap_segment = GetVectorData(swizzle_segment.child_index); - allocator->UnswizzlePointers(state, result, offset + swizzle_segment.offset, swizzle_segment.count, - string_heap_segment.block_id, string_heap_segment.offset); - } - offset += current_vdata.count; - next_index = current_vdata.next_data; - } - } - if (state.properties == ColumnDataScanProperties::DISALLOW_ZERO_COPY) { - 
VectorOperations::Copy(result, result, vdata.count, 0, 0); - } - } - return vcount; -} - -void ColumnDataCollectionSegment::ReadChunk(idx_t chunk_index, ChunkManagementState &state, DataChunk &chunk, - const vector &column_ids) { - D_ASSERT(chunk.ColumnCount() == column_ids.size()); - D_ASSERT(state.properties != ColumnDataScanProperties::INVALID); - InitializeChunkState(chunk_index, state); - auto &chunk_meta = chunk_data[chunk_index]; - for (idx_t i = 0; i < column_ids.size(); i++) { - auto vector_idx = column_ids[i]; - D_ASSERT(vector_idx < chunk_meta.vector_data.size()); - ReadVector(state, chunk_meta.vector_data[vector_idx], chunk.data[i]); - } - chunk.SetCardinality(chunk_meta.count); -} - -idx_t ColumnDataCollectionSegment::ChunkCount() const { - return chunk_data.size(); -} - -idx_t ColumnDataCollectionSegment::SizeInBytes() const { - D_ASSERT(!allocator->IsShared()); - return allocator->SizeInBytes() + heap->SizeInBytes(); -} - -void ColumnDataCollectionSegment::FetchChunk(idx_t chunk_idx, DataChunk &result) { - vector column_ids; - column_ids.reserve(types.size()); - for (idx_t i = 0; i < types.size(); i++) { - column_ids.push_back(i); - } - FetchChunk(chunk_idx, result, column_ids); -} - -void ColumnDataCollectionSegment::FetchChunk(idx_t chunk_idx, DataChunk &result, const vector &column_ids) { - D_ASSERT(chunk_idx < chunk_data.size()); - ChunkManagementState state; - state.properties = ColumnDataScanProperties::DISALLOW_ZERO_COPY; - ReadChunk(chunk_idx, state, result, column_ids); -} - -void ColumnDataCollectionSegment::Verify() { -#ifdef DEBUG - idx_t total_count = 0; - for (idx_t i = 0; i < chunk_data.size(); i++) { - total_count += chunk_data[i].count; - } - D_ASSERT(total_count == this->count); -#endif -} - -} // namespace duckdb - - -#include - -namespace duckdb { - -using ChunkReference = ColumnDataConsumer::ChunkReference; - -ChunkReference::ChunkReference(ColumnDataCollectionSegment *segment_p, uint32_t chunk_index_p) - : segment(segment_p), 
chunk_index_in_segment(chunk_index_p) { -} - -uint32_t ChunkReference::GetMinimumBlockID() const { - const auto &block_ids = segment->chunk_data[chunk_index_in_segment].block_ids; - return *std::min_element(block_ids.begin(), block_ids.end()); -} - -ColumnDataConsumer::ColumnDataConsumer(ColumnDataCollection &collection_p, vector column_ids) - : collection(collection_p), column_ids(std::move(column_ids)) { -} - -void ColumnDataConsumer::InitializeScan() { - chunk_count = collection.ChunkCount(); - current_chunk_index = 0; - chunk_delete_index = DConstants::INVALID_INDEX; - - // Initialize chunk references and sort them, so we can scan them in a sane order, regardless of how it was created - chunk_references.reserve(chunk_count); - for (auto &segment : collection.GetSegments()) { - for (idx_t chunk_index = 0; chunk_index < segment->chunk_data.size(); chunk_index++) { - chunk_references.emplace_back(segment.get(), chunk_index); - } - } - std::sort(chunk_references.begin(), chunk_references.end()); -} - -bool ColumnDataConsumer::AssignChunk(ColumnDataConsumerScanState &state) { - lock_guard guard(lock); - if (current_chunk_index == chunk_count) { - // All chunks have been assigned - state.current_chunk_state.handles.clear(); - state.chunk_index = DConstants::INVALID_INDEX; - return false; - } - // Assign chunk index - state.chunk_index = current_chunk_index++; - D_ASSERT(chunks_in_progress.find(state.chunk_index) == chunks_in_progress.end()); - chunks_in_progress.insert(state.chunk_index); - return true; -} - -void ColumnDataConsumer::ScanChunk(ColumnDataConsumerScanState &state, DataChunk &chunk) const { - D_ASSERT(state.chunk_index < chunk_count); - auto &chunk_ref = chunk_references[state.chunk_index]; - if (state.allocator != chunk_ref.segment->allocator.get()) { - // Previously scanned a chunk from a different allocator, reset the handles - state.allocator = chunk_ref.segment->allocator.get(); - state.current_chunk_state.handles.clear(); - } - 
chunk_ref.segment->ReadChunk(chunk_ref.chunk_index_in_segment, state.current_chunk_state, chunk, column_ids); -} - -void ColumnDataConsumer::FinishChunk(ColumnDataConsumerScanState &state) { - D_ASSERT(state.chunk_index < chunk_count); - idx_t delete_index_start; - idx_t delete_index_end; - { - lock_guard guard(lock); - D_ASSERT(chunks_in_progress.find(state.chunk_index) != chunks_in_progress.end()); - delete_index_start = chunk_delete_index; - delete_index_end = *std::min_element(chunks_in_progress.begin(), chunks_in_progress.end()); - chunks_in_progress.erase(state.chunk_index); - chunk_delete_index = delete_index_end; - } - ConsumeChunks(delete_index_start, delete_index_end); -} -void ColumnDataConsumer::ConsumeChunks(idx_t delete_index_start, idx_t delete_index_end) { - for (idx_t chunk_index = delete_index_start; chunk_index < delete_index_end; chunk_index++) { - if (chunk_index == 0) { - continue; - } - auto &prev_chunk_ref = chunk_references[chunk_index - 1]; - auto &curr_chunk_ref = chunk_references[chunk_index]; - auto prev_allocator = prev_chunk_ref.segment->allocator.get(); - auto curr_allocator = curr_chunk_ref.segment->allocator.get(); - auto prev_min_block_id = prev_chunk_ref.GetMinimumBlockID(); - auto curr_min_block_id = curr_chunk_ref.GetMinimumBlockID(); - if (prev_allocator != curr_allocator) { - // Moved to the next allocator, delete all remaining blocks in the previous one - for (uint32_t block_id = prev_min_block_id; block_id < prev_allocator->BlockCount(); block_id++) { - prev_allocator->DeleteBlock(block_id); - } - continue; - } - // Same allocator, see if we can delete blocks - for (uint32_t block_id = prev_min_block_id; block_id < curr_min_block_id; block_id++) { - prev_allocator->DeleteBlock(block_id); - } - } -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -PartitionedColumnData::PartitionedColumnData(PartitionedColumnDataType type_p, ClientContext &context_p, - vector types_p) - : type(type_p), context(context_p), 
types(std::move(types_p)), - allocators(make_shared()) { -} - -PartitionedColumnData::PartitionedColumnData(const PartitionedColumnData &other) - : type(other.type), context(other.context), types(other.types), allocators(other.allocators) { -} - -unique_ptr PartitionedColumnData::CreateShared() { - switch (type) { - case PartitionedColumnDataType::RADIX: - return make_uniq(Cast()); - case PartitionedColumnDataType::HIVE: - return make_uniq(Cast()); - default: - throw NotImplementedException("CreateShared for this type of PartitionedColumnData"); - } -} - -PartitionedColumnData::~PartitionedColumnData() { -} - -void PartitionedColumnData::InitializeAppendState(PartitionedColumnDataAppendState &state) const { - state.partition_sel.Initialize(); - state.slice_chunk.Initialize(BufferAllocator::Get(context), types); - InitializeAppendStateInternal(state); -} - -unique_ptr PartitionedColumnData::CreatePartitionBuffer() const { - auto result = make_uniq(); - result->Initialize(BufferAllocator::Get(context), types, BufferSize()); - return result; -} - -void PartitionedColumnData::Append(PartitionedColumnDataAppendState &state, DataChunk &input) { - // Compute partition indices and store them in state.partition_indices - ComputePartitionIndices(state, input); - - // Compute the counts per partition - const auto count = input.size(); - const auto partition_indices = FlatVector::GetData(state.partition_indices); - auto &partition_entries = state.partition_entries; - partition_entries.clear(); - switch (state.partition_indices.GetVectorType()) { - case VectorType::FLAT_VECTOR: - for (idx_t i = 0; i < count; i++) { - const auto &partition_index = partition_indices[i]; - auto partition_entry = partition_entries.find(partition_index); - if (partition_entry == partition_entries.end()) { - partition_entries[partition_index] = list_entry_t(0, 1); - } else { - partition_entry->second.length++; - } - } - break; - case VectorType::CONSTANT_VECTOR: - 
partition_entries[partition_indices[0]] = list_entry_t(0, count); - break; - default: - throw InternalException("Unexpected VectorType in PartitionedColumnData::Append"); - } - - // Early out: check if everything belongs to a single partition - if (partition_entries.size() == 1) { - const auto &partition_index = partition_entries.begin()->first; - auto &partition = *partitions[partition_index]; - auto &partition_append_state = *state.partition_append_states[partition_index]; - partition.Append(partition_append_state, input); - return; - } - - // Compute offsets from the counts - idx_t offset = 0; - for (auto &pc : partition_entries) { - auto &partition_entry = pc.second; - partition_entry.offset = offset; - offset += partition_entry.length; - } - - // Now initialize a single selection vector that acts as a selection vector for every partition - auto &all_partitions_sel = state.partition_sel; - for (idx_t i = 0; i < count; i++) { - const auto &partition_index = partition_indices[i]; - auto &partition_offset = partition_entries[partition_index].offset; - all_partitions_sel[partition_offset++] = i; - } - - // Loop through the partitions to append the new data to the partition buffers, and flush the buffers if necessary - SelectionVector partition_sel; - for (auto &pc : partition_entries) { - const auto &partition_index = pc.first; - - // Partition, buffer, and append state for this partition index - auto &partition = *partitions[partition_index]; - auto &partition_buffer = *state.partition_buffers[partition_index]; - auto &partition_append_state = *state.partition_append_states[partition_index]; - - // Length and offset into the selection vector for this chunk, for this partition - const auto &partition_entry = pc.second; - const auto &partition_length = partition_entry.length; - const auto partition_offset = partition_entry.offset - partition_length; - - // Create a selection vector for this partition using the offset into the single selection vector - 
partition_sel.Initialize(all_partitions_sel.data() + partition_offset); - - if (partition_length >= HalfBufferSize()) { - // Slice the input chunk using the selection vector - state.slice_chunk.Reset(); - state.slice_chunk.Slice(input, partition_sel, partition_length); - - // Append it to the partition directly - partition.Append(partition_append_state, state.slice_chunk); - } else { - // Append the input chunk to the partition buffer using the selection vector - partition_buffer.Append(input, false, &partition_sel, partition_length); - - if (partition_buffer.size() >= HalfBufferSize()) { - // Next batch won't fit in the buffer, flush it to the partition - partition.Append(partition_append_state, partition_buffer); - partition_buffer.Reset(); - partition_buffer.SetCapacity(BufferSize()); - } - } - } -} - -void PartitionedColumnData::FlushAppendState(PartitionedColumnDataAppendState &state) { - for (idx_t i = 0; i < state.partition_buffers.size(); i++) { - auto &partition_buffer = *state.partition_buffers[i]; - if (partition_buffer.size() > 0) { - partitions[i]->Append(partition_buffer); - partition_buffer.Reset(); - } - } -} - -void PartitionedColumnData::Combine(PartitionedColumnData &other) { - // Now combine the state's partitions into this - lock_guard guard(lock); - - if (partitions.empty()) { - // This is the first merge, we just copy them over - partitions = std::move(other.partitions); - } else { - D_ASSERT(partitions.size() == other.partitions.size()); - // Combine the append state's partitions into this PartitionedColumnData - for (idx_t i = 0; i < other.partitions.size(); i++) { - partitions[i]->Combine(*other.partitions[i]); - } - } -} - -vector> &PartitionedColumnData::GetPartitions() { - return partitions; -} - -void PartitionedColumnData::CreateAllocator() { - allocators->allocators.emplace_back(make_shared(BufferManager::GetBufferManager(context))); - allocators->allocators.back()->MakeShared(); -} - -} // namespace duckdb - - - -namespace duckdb { 
- -bool ConflictInfo::ConflictTargetMatches(Index &index) const { - if (only_check_unique && !index.IsUnique()) { - // We only support checking ON CONFLICT for Unique/Primary key constraints - return false; - } - if (column_ids.empty()) { - return true; - } - // Check whether the column ids match - return column_ids == index.column_id_set; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -ConflictManager::ConflictManager(VerifyExistenceType lookup_type, idx_t input_size, - optional_ptr conflict_info) - : lookup_type(lookup_type), input_size(input_size), conflict_info(conflict_info), conflicts(input_size, false), - mode(ConflictManagerMode::THROW) { -} - -ManagedSelection &ConflictManager::InternalSelection() { - if (!conflicts.Initialized()) { - conflicts.Initialize(input_size); - } - return conflicts; -} - -const unordered_set &ConflictManager::InternalConflictSet() const { - D_ASSERT(conflict_set); - return *conflict_set; -} - -Vector &ConflictManager::InternalRowIds() { - if (!row_ids) { - row_ids = make_uniq(LogicalType::ROW_TYPE, input_size); - } - return *row_ids; -} - -Vector &ConflictManager::InternalIntermediate() { - if (!intermediate_vector) { - intermediate_vector = make_uniq(LogicalType::BOOLEAN, true, true, input_size); - } - return *intermediate_vector; -} - -const ConflictInfo &ConflictManager::GetConflictInfo() const { - D_ASSERT(conflict_info); - return *conflict_info; -} - -void ConflictManager::FinishLookup() { - if (mode == ConflictManagerMode::THROW) { - return; - } - if (!SingleIndexTarget()) { - return; - } - if (conflicts.Count() != 0) { - // We have recorded conflicts from the one index we're interested in - // We set this so we don't duplicate the conflicts when there are duplicate indexes - // that also match our conflict target - single_index_finished = true; - } -} - -void ConflictManager::SetMode(ConflictManagerMode mode) { - // Only allow SCAN when we have conflict info - D_ASSERT(mode != ConflictManagerMode::SCAN || 
conflict_info != nullptr); - this->mode = mode; -} - -void ConflictManager::AddToConflictSet(idx_t chunk_index) { - if (!conflict_set) { - conflict_set = make_uniq>(); - } - auto &set = *conflict_set; - set.insert(chunk_index); -} - -void ConflictManager::AddConflictInternal(idx_t chunk_index, row_t row_id) { - D_ASSERT(mode == ConflictManagerMode::SCAN); - - // Only when we should not throw on conflict should we get here - D_ASSERT(!ShouldThrow(chunk_index)); - AddToConflictSet(chunk_index); - if (SingleIndexTarget()) { - // If we have identical indexes, only the conflicts of the first index should be recorded - // as the other index(es) would produce the exact same conflicts anyways - if (single_index_finished) { - return; - } - - // We can be more efficient because we don't need to merge conflicts of multiple indexes - auto &selection = InternalSelection(); - auto &row_ids = InternalRowIds(); - auto data = FlatVector::GetData(row_ids); - data[selection.Count()] = row_id; - selection.Append(chunk_index); - } else { - auto &intermediate = InternalIntermediate(); - auto data = FlatVector::GetData(intermediate); - // Mark this index in the chunk as producing a conflict - data[chunk_index] = true; - if (row_id_map.empty()) { - row_id_map.resize(input_size); - } - row_id_map[chunk_index] = row_id; - } -} - -bool ConflictManager::IsConflict(LookupResultType type) { - switch (type) { - case LookupResultType::LOOKUP_NULL: { - if (ShouldIgnoreNulls()) { - return false; - } - // If nulls are not ignored, treat this as a hit instead - return IsConflict(LookupResultType::LOOKUP_HIT); - } - case LookupResultType::LOOKUP_HIT: { - return true; - } - case LookupResultType::LOOKUP_MISS: { - // FIXME: If we record a miss as a conflict when the verify type is APPEND_FK, then we can simplify the checks - // in VerifyForeignKeyConstraint This also means we should not record a hit as a conflict when the verify type - // is APPEND_FK - return false; - } - default: { - throw 
NotImplementedException("Type not implemented for LookupResultType"); - } - } -} - -bool ConflictManager::AddHit(idx_t chunk_index, row_t row_id) { - D_ASSERT(chunk_index < input_size); - // First check if this causes a conflict - if (!IsConflict(LookupResultType::LOOKUP_HIT)) { - return false; - } - - // Then check if we should throw on a conflict - if (ShouldThrow(chunk_index)) { - return true; - } - if (mode == ConflictManagerMode::THROW) { - // When our mode is THROW, and the chunk index is part of the previously scanned conflicts - // then we ignore the conflict instead - D_ASSERT(!ShouldThrow(chunk_index)); - return false; - } - D_ASSERT(conflict_info); - // Because we don't throw, we need to register the conflict - AddConflictInternal(chunk_index, row_id); - return false; -} - -bool ConflictManager::AddMiss(idx_t chunk_index) { - D_ASSERT(chunk_index < input_size); - return IsConflict(LookupResultType::LOOKUP_MISS); -} - -bool ConflictManager::AddNull(idx_t chunk_index) { - D_ASSERT(chunk_index < input_size); - if (!IsConflict(LookupResultType::LOOKUP_NULL)) { - return false; - } - return AddHit(chunk_index, DConstants::INVALID_INDEX); -} - -bool ConflictManager::SingleIndexTarget() const { - D_ASSERT(conflict_info); - // We are only interested in a specific index - return !conflict_info->column_ids.empty(); -} - -bool ConflictManager::ShouldThrow(idx_t chunk_index) const { - if (mode == ConflictManagerMode::SCAN) { - return false; - } - D_ASSERT(mode == ConflictManagerMode::THROW); - if (conflict_set == nullptr) { - // No conflicts were scanned, so this conflict is not in the set - return true; - } - auto &set = InternalConflictSet(); - if (set.count(chunk_index)) { - return false; - } - // None of the scanned conflicts arose from this insert tuple - return true; -} - -bool ConflictManager::ShouldIgnoreNulls() const { - switch (lookup_type) { - case VerifyExistenceType::APPEND: - return true; - case VerifyExistenceType::APPEND_FK: - return false; - case 
VerifyExistenceType::DELETE_FK: - return true; - default: - throw InternalException("Type not implemented for VerifyExistenceType"); - } -} - -Vector &ConflictManager::RowIds() { - D_ASSERT(finalized); - return *row_ids; -} - -const ManagedSelection &ConflictManager::Conflicts() const { - D_ASSERT(finalized); - return conflicts; -} - -idx_t ConflictManager::ConflictCount() const { - return conflicts.Count(); -} - -void ConflictManager::Finalize() { - D_ASSERT(!finalized); - if (SingleIndexTarget()) { - // Selection vector has been directly populated already, no need to finalize - finalized = true; - return; - } - finalized = true; - if (!intermediate_vector) { - // No conflicts were found, we're done - return; - } - auto &intermediate = InternalIntermediate(); - auto data = FlatVector::GetData(intermediate); - auto &selection = InternalSelection(); - // Create the selection vector from the encountered conflicts - for (idx_t i = 0; i < input_size; i++) { - if (data[i]) { - selection.Append(i); - } - } - // Now create the row_ids Vector, aligned with the selection vector - auto &row_ids = InternalRowIds(); - auto row_id_data = FlatVector::GetData(row_ids); - - for (idx_t i = 0; i < selection.Count(); i++) { - D_ASSERT(!row_id_map.empty()); - auto index = selection[i]; - D_ASSERT(index < row_id_map.size()); - auto row_id = row_id_map[index]; - row_id_data[i] = row_id; - } - intermediate_vector.reset(); -} - -VerifyExistenceType ConflictManager::LookupType() const { - return this->lookup_type; -} - -void ConflictManager::SetIndexCount(idx_t count) { - index_count = count; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -DataChunk::DataChunk() : count(0), capacity(STANDARD_VECTOR_SIZE) { -} - -DataChunk::~DataChunk() { -} - -void DataChunk::InitializeEmpty(const vector &types) { - InitializeEmpty(types.begin(), types.end()); -} - -void DataChunk::Initialize(Allocator &allocator, const vector &types, idx_t capacity_p) { - 
Initialize(allocator, types.begin(), types.end(), capacity_p); -} - -void DataChunk::Initialize(ClientContext &context, const vector &types, idx_t capacity_p) { - Initialize(Allocator::Get(context), types, capacity_p); -} - -void DataChunk::Initialize(Allocator &allocator, vector::const_iterator begin, - vector::const_iterator end, idx_t capacity_p) { - D_ASSERT(data.empty()); // can only be initialized once - D_ASSERT(std::distance(begin, end) != 0); // empty chunk not allowed - capacity = capacity_p; - for (; begin != end; begin++) { - VectorCache cache(allocator, *begin, capacity); - data.emplace_back(cache); - vector_caches.push_back(std::move(cache)); - } -} - -void DataChunk::Initialize(ClientContext &context, vector::const_iterator begin, - vector::const_iterator end, idx_t capacity_p) { - Initialize(Allocator::Get(context), begin, end, capacity_p); -} - -void DataChunk::InitializeEmpty(vector::const_iterator begin, vector::const_iterator end) { - capacity = STANDARD_VECTOR_SIZE; - D_ASSERT(data.empty()); // can only be initialized once - D_ASSERT(std::distance(begin, end) != 0); // empty chunk not allowed - for (; begin != end; begin++) { - data.emplace_back(*begin, nullptr); - } -} - -void DataChunk::Reset() { - if (data.empty()) { - return; - } - if (vector_caches.size() != data.size()) { - throw InternalException("VectorCache and column count mismatch in DataChunk::Reset"); - } - for (idx_t i = 0; i < ColumnCount(); i++) { - data[i].ResetFromCache(vector_caches[i]); - } - capacity = STANDARD_VECTOR_SIZE; - SetCardinality(0); -} - -void DataChunk::Destroy() { - data.clear(); - vector_caches.clear(); - capacity = 0; - SetCardinality(0); -} - -Value DataChunk::GetValue(idx_t col_idx, idx_t index) const { - D_ASSERT(index < size()); - return data[col_idx].GetValue(index); -} - -void DataChunk::SetValue(idx_t col_idx, idx_t index, const Value &val) { - data[col_idx].SetValue(index, val); -} - -bool DataChunk::AllConstant() const { - for (auto &v : data) { - 
if (v.GetVectorType() != VectorType::CONSTANT_VECTOR) { - return false; - } - } - return true; -} - -void DataChunk::Reference(DataChunk &chunk) { - D_ASSERT(chunk.ColumnCount() <= ColumnCount()); - SetCapacity(chunk); - SetCardinality(chunk); - for (idx_t i = 0; i < chunk.ColumnCount(); i++) { - data[i].Reference(chunk.data[i]); - } -} - -void DataChunk::Move(DataChunk &chunk) { - SetCardinality(chunk); - SetCapacity(chunk); - data = std::move(chunk.data); - vector_caches = std::move(chunk.vector_caches); - - chunk.Destroy(); -} - -void DataChunk::Copy(DataChunk &other, idx_t offset) const { - D_ASSERT(ColumnCount() == other.ColumnCount()); - D_ASSERT(other.size() == 0); - - for (idx_t i = 0; i < ColumnCount(); i++) { - D_ASSERT(other.data[i].GetVectorType() == VectorType::FLAT_VECTOR); - VectorOperations::Copy(data[i], other.data[i], size(), offset, 0); - } - other.SetCardinality(size() - offset); -} - -void DataChunk::Copy(DataChunk &other, const SelectionVector &sel, const idx_t source_count, const idx_t offset) const { - D_ASSERT(ColumnCount() == other.ColumnCount()); - D_ASSERT(other.size() == 0); - D_ASSERT((offset + source_count) <= size()); - - for (idx_t i = 0; i < ColumnCount(); i++) { - D_ASSERT(other.data[i].GetVectorType() == VectorType::FLAT_VECTOR); - VectorOperations::Copy(data[i], other.data[i], sel, source_count, offset, 0); - } - other.SetCardinality(source_count - offset); -} - -void DataChunk::Split(DataChunk &other, idx_t split_idx) { - D_ASSERT(other.size() == 0); - D_ASSERT(other.data.empty()); - D_ASSERT(split_idx < data.size()); - const idx_t num_cols = data.size(); - for (idx_t col_idx = split_idx; col_idx < num_cols; col_idx++) { - other.data.push_back(std::move(data[col_idx])); - other.vector_caches.push_back(std::move(vector_caches[col_idx])); - } - for (idx_t col_idx = split_idx; col_idx < num_cols; col_idx++) { - data.pop_back(); - vector_caches.pop_back(); - } - other.SetCapacity(*this); - other.SetCardinality(*this); -} - -void 
DataChunk::Fuse(DataChunk &other) { - D_ASSERT(other.size() == size()); - const idx_t num_cols = other.data.size(); - for (idx_t col_idx = 0; col_idx < num_cols; ++col_idx) { - data.emplace_back(std::move(other.data[col_idx])); - vector_caches.emplace_back(std::move(other.vector_caches[col_idx])); - } - other.Destroy(); -} - -void DataChunk::ReferenceColumns(DataChunk &other, const vector &column_ids) { - D_ASSERT(ColumnCount() == column_ids.size()); - Reset(); - for (idx_t col_idx = 0; col_idx < ColumnCount(); col_idx++) { - auto &other_col = other.data[column_ids[col_idx]]; - auto &this_col = data[col_idx]; - D_ASSERT(other_col.GetType() == this_col.GetType()); - this_col.Reference(other_col); - } - SetCardinality(other.size()); -} - -void DataChunk::Append(const DataChunk &other, bool resize, SelectionVector *sel, idx_t sel_count) { - idx_t new_size = sel ? size() + sel_count : size() + other.size(); - if (other.size() == 0) { - return; - } - if (ColumnCount() != other.ColumnCount()) { - throw InternalException("Column counts of appending chunk doesn't match!"); - } - if (new_size > capacity) { - if (resize) { - auto new_capacity = NextPowerOfTwo(new_size); - for (idx_t i = 0; i < ColumnCount(); i++) { - data[i].Resize(size(), new_capacity); - } - capacity = new_capacity; - } else { - throw InternalException("Can't append chunk to other chunk without resizing"); - } - } - for (idx_t i = 0; i < ColumnCount(); i++) { - D_ASSERT(data[i].GetVectorType() == VectorType::FLAT_VECTOR); - if (sel) { - VectorOperations::Copy(other.data[i], data[i], *sel, sel_count, 0, size()); - } else { - VectorOperations::Copy(other.data[i], data[i], other.size(), 0, size()); - } - } - SetCardinality(new_size); -} - -void DataChunk::Flatten() { - for (idx_t i = 0; i < ColumnCount(); i++) { - data[i].Flatten(size()); - } -} - -vector DataChunk::GetTypes() { - vector types; - for (idx_t i = 0; i < ColumnCount(); i++) { - types.push_back(data[i].GetType()); - } - return types; -} - -string 
DataChunk::ToString() const { - string retval = "Chunk - [" + to_string(ColumnCount()) + " Columns]\n"; - for (idx_t i = 0; i < ColumnCount(); i++) { - retval += "- " + data[i].ToString(size()) + "\n"; - } - return retval; -} - -void DataChunk::Serialize(Serializer &serializer) const { - - // write the count - auto row_count = size(); - serializer.WriteProperty(100, "rows", row_count); - - // we should never try to serialize empty data chunks - auto column_count = ColumnCount(); - D_ASSERT(column_count); - - // write the types - serializer.WriteList(101, "types", column_count, - [&](Serializer::List &list, idx_t i) { list.WriteElement(data[i].GetType()); }); - - // write the data - serializer.WriteList(102, "columns", column_count, [&](Serializer::List &list, idx_t i) { - list.WriteObject([&](Serializer &object) { - // Reference the vector to avoid potentially mutating it during serialization - Vector serialized_vector(data[i].GetType()); - serialized_vector.Reference(data[i]); - serialized_vector.Serialize(object, row_count); - }); - }); -} - -void DataChunk::Deserialize(Deserializer &deserializer) { - - // read and set the row count - auto row_count = deserializer.ReadProperty(100, "rows"); - - // read the types - vector types; - deserializer.ReadList(101, "types", [&](Deserializer::List &list, idx_t i) { - auto type = list.ReadElement(); - types.push_back(type); - }); - - // initialize the data chunk - D_ASSERT(!types.empty()); - Initialize(Allocator::DefaultAllocator(), types); - SetCardinality(row_count); - - // read the data - deserializer.ReadList(102, "columns", [&](Deserializer::List &list, idx_t i) { - list.ReadObject([&](Deserializer &object) { data[i].Deserialize(object, row_count); }); - }); -} - -void DataChunk::Slice(const SelectionVector &sel_vector, idx_t count_p) { - this->count = count_p; - SelCache merge_cache; - for (idx_t c = 0; c < ColumnCount(); c++) { - data[c].Slice(sel_vector, count_p, merge_cache); - } -} - -void 
DataChunk::Slice(DataChunk &other, const SelectionVector &sel, idx_t count_p, idx_t col_offset) { - D_ASSERT(other.ColumnCount() <= col_offset + ColumnCount()); - this->count = count_p; - SelCache merge_cache; - for (idx_t c = 0; c < other.ColumnCount(); c++) { - if (other.data[c].GetVectorType() == VectorType::DICTIONARY_VECTOR) { - // already a dictionary! merge the dictionaries - data[col_offset + c].Reference(other.data[c]); - data[col_offset + c].Slice(sel, count_p, merge_cache); - } else { - data[col_offset + c].Slice(other.data[c], sel, count_p); - } - } -} - -unsafe_unique_array DataChunk::ToUnifiedFormat() { - auto unified_data = make_unsafe_uniq_array(ColumnCount()); - for (idx_t col_idx = 0; col_idx < ColumnCount(); col_idx++) { - data[col_idx].ToUnifiedFormat(size(), unified_data[col_idx]); - } - return unified_data; -} - -void DataChunk::Hash(Vector &result) { - D_ASSERT(result.GetType().id() == LogicalType::HASH); - VectorOperations::Hash(data[0], result, size()); - for (idx_t i = 1; i < ColumnCount(); i++) { - VectorOperations::CombineHash(result, data[i], size()); - } -} - -void DataChunk::Hash(vector &column_ids, Vector &result) { - D_ASSERT(result.GetType().id() == LogicalType::HASH); - D_ASSERT(!column_ids.empty()); - - VectorOperations::Hash(data[column_ids[0]], result, size()); - for (idx_t i = 1; i < column_ids.size(); i++) { - VectorOperations::CombineHash(result, data[column_ids[i]], size()); - } -} - -void DataChunk::Verify() { -#ifdef DEBUG - D_ASSERT(size() <= capacity); - - // verify that all vectors in this chunk have the chunk selection vector - for (idx_t i = 0; i < ColumnCount(); i++) { - data[i].Verify(size()); - } - - if (!ColumnCount()) { - // don't try to round-trip dummy data chunks with no data - // e.g., these exist in queries like 'SELECT distinct(col0, col1) FROM tbl', where we have groups, but no - // payload so the payload will be such an empty data chunk - return; - } - - // verify that we can round-trip chunk 
serialization - MemoryStream mem_stream; - BinarySerializer serializer(mem_stream); - - serializer.Begin(); - Serialize(serializer); - serializer.End(); - - mem_stream.Rewind(); - - BinaryDeserializer deserializer(mem_stream); - DataChunk new_chunk; - - deserializer.Begin(); - new_chunk.Deserialize(deserializer); - deserializer.End(); - - D_ASSERT(size() == new_chunk.size()); -#endif -} - -void DataChunk::Print() const { - Printer::Print(ToString()); -} - -} // namespace duckdb - - - - - - - - - - -#include -#include -#include - -namespace duckdb { - -static_assert(sizeof(date_t) == sizeof(int32_t), "date_t was padded"); - -const char *Date::PINF = "infinity"; // NOLINT -const char *Date::NINF = "-infinity"; // NOLINT -const char *Date::EPOCH = "epoch"; // NOLINT - -const string_t Date::MONTH_NAMES_ABBREVIATED[] = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", - "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}; -const string_t Date::MONTH_NAMES[] = {"January", "February", "March", "April", "May", "June", - "July", "August", "September", "October", "November", "December"}; -const string_t Date::DAY_NAMES[] = {"Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"}; -const string_t Date::DAY_NAMES_ABBREVIATED[] = {"Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"}; - -const int32_t Date::NORMAL_DAYS[] = {0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}; -const int32_t Date::CUMULATIVE_DAYS[] = {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365}; -const int32_t Date::LEAP_DAYS[] = {0, 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}; -const int32_t Date::CUMULATIVE_LEAP_DAYS[] = {0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, 366}; -const int8_t Date::MONTH_PER_DAY_OF_YEAR[] = { - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 
- 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, - 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, - 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, - 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, - 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, - 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, - 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, - 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, - 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12}; -const int8_t Date::LEAP_MONTH_PER_DAY_OF_YEAR[] = { - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, - 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, - 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, - 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, - 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, - 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 
11, 11, 11, 11, 11, 11, - 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, - 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12}; -const int32_t Date::CUMULATIVE_YEAR_DAYS[] = { - 0, 365, 730, 1096, 1461, 1826, 2191, 2557, 2922, 3287, 3652, 4018, 4383, 4748, - 5113, 5479, 5844, 6209, 6574, 6940, 7305, 7670, 8035, 8401, 8766, 9131, 9496, 9862, - 10227, 10592, 10957, 11323, 11688, 12053, 12418, 12784, 13149, 13514, 13879, 14245, 14610, 14975, - 15340, 15706, 16071, 16436, 16801, 17167, 17532, 17897, 18262, 18628, 18993, 19358, 19723, 20089, - 20454, 20819, 21184, 21550, 21915, 22280, 22645, 23011, 23376, 23741, 24106, 24472, 24837, 25202, - 25567, 25933, 26298, 26663, 27028, 27394, 27759, 28124, 28489, 28855, 29220, 29585, 29950, 30316, - 30681, 31046, 31411, 31777, 32142, 32507, 32872, 33238, 33603, 33968, 34333, 34699, 35064, 35429, - 35794, 36160, 36525, 36890, 37255, 37621, 37986, 38351, 38716, 39082, 39447, 39812, 40177, 40543, - 40908, 41273, 41638, 42004, 42369, 42734, 43099, 43465, 43830, 44195, 44560, 44926, 45291, 45656, - 46021, 46387, 46752, 47117, 47482, 47847, 48212, 48577, 48942, 49308, 49673, 50038, 50403, 50769, - 51134, 51499, 51864, 52230, 52595, 52960, 53325, 53691, 54056, 54421, 54786, 55152, 55517, 55882, - 56247, 56613, 56978, 57343, 57708, 58074, 58439, 58804, 59169, 59535, 59900, 60265, 60630, 60996, - 61361, 61726, 62091, 62457, 62822, 63187, 63552, 63918, 64283, 64648, 65013, 65379, 65744, 66109, - 66474, 66840, 67205, 67570, 67935, 68301, 68666, 69031, 69396, 69762, 70127, 70492, 70857, 71223, - 71588, 71953, 72318, 72684, 73049, 73414, 73779, 74145, 74510, 74875, 75240, 75606, 75971, 76336, - 76701, 77067, 77432, 77797, 78162, 78528, 78893, 79258, 79623, 79989, 80354, 80719, 81084, 81450, - 81815, 82180, 82545, 82911, 83276, 83641, 84006, 84371, 84736, 85101, 85466, 85832, 86197, 86562, - 86927, 87293, 87658, 88023, 88388, 88754, 89119, 89484, 89849, 
90215, 90580, 90945, 91310, 91676, - 92041, 92406, 92771, 93137, 93502, 93867, 94232, 94598, 94963, 95328, 95693, 96059, 96424, 96789, - 97154, 97520, 97885, 98250, 98615, 98981, 99346, 99711, 100076, 100442, 100807, 101172, 101537, 101903, - 102268, 102633, 102998, 103364, 103729, 104094, 104459, 104825, 105190, 105555, 105920, 106286, 106651, 107016, - 107381, 107747, 108112, 108477, 108842, 109208, 109573, 109938, 110303, 110669, 111034, 111399, 111764, 112130, - 112495, 112860, 113225, 113591, 113956, 114321, 114686, 115052, 115417, 115782, 116147, 116513, 116878, 117243, - 117608, 117974, 118339, 118704, 119069, 119435, 119800, 120165, 120530, 120895, 121260, 121625, 121990, 122356, - 122721, 123086, 123451, 123817, 124182, 124547, 124912, 125278, 125643, 126008, 126373, 126739, 127104, 127469, - 127834, 128200, 128565, 128930, 129295, 129661, 130026, 130391, 130756, 131122, 131487, 131852, 132217, 132583, - 132948, 133313, 133678, 134044, 134409, 134774, 135139, 135505, 135870, 136235, 136600, 136966, 137331, 137696, - 138061, 138427, 138792, 139157, 139522, 139888, 140253, 140618, 140983, 141349, 141714, 142079, 142444, 142810, - 143175, 143540, 143905, 144271, 144636, 145001, 145366, 145732, 146097}; - -void Date::ExtractYearOffset(int32_t &n, int32_t &year, int32_t &year_offset) { - year = Date::EPOCH_YEAR; - // first we normalize n to be in the year range [1970, 2370] - // since leap years repeat every 400 years, we can safely normalize just by "shifting" the CumulativeYearDays array - while (n < 0) { - n += Date::DAYS_PER_YEAR_INTERVAL; - year -= Date::YEAR_INTERVAL; - } - while (n >= Date::DAYS_PER_YEAR_INTERVAL) { - n -= Date::DAYS_PER_YEAR_INTERVAL; - year += Date::YEAR_INTERVAL; - } - // interpolation search - // we can find an upper bound of the year by assuming each year has 365 days - year_offset = n / 365; - // because of leap years we might be off by a little bit: compensate by decrementing the year offset until we find - // our year - while (n 
< Date::CUMULATIVE_YEAR_DAYS[year_offset]) { - year_offset--; - D_ASSERT(year_offset >= 0); - } - year += year_offset; - D_ASSERT(n >= Date::CUMULATIVE_YEAR_DAYS[year_offset]); -} - -void Date::Convert(date_t d, int32_t &year, int32_t &month, int32_t &day) { - auto n = d.days; - int32_t year_offset; - Date::ExtractYearOffset(n, year, year_offset); - - day = n - Date::CUMULATIVE_YEAR_DAYS[year_offset]; - D_ASSERT(day >= 0 && day <= 365); - - bool is_leap_year = (Date::CUMULATIVE_YEAR_DAYS[year_offset + 1] - Date::CUMULATIVE_YEAR_DAYS[year_offset]) == 366; - if (is_leap_year) { - month = Date::LEAP_MONTH_PER_DAY_OF_YEAR[day]; - day -= Date::CUMULATIVE_LEAP_DAYS[month - 1]; - } else { - month = Date::MONTH_PER_DAY_OF_YEAR[day]; - day -= Date::CUMULATIVE_DAYS[month - 1]; - } - day++; - D_ASSERT(day > 0 && day <= (is_leap_year ? Date::LEAP_DAYS[month] : Date::NORMAL_DAYS[month])); - D_ASSERT(month > 0 && month <= 12); -} - -bool Date::TryFromDate(int32_t year, int32_t month, int32_t day, date_t &result) { - int32_t n = 0; - if (!Date::IsValid(year, month, day)) { - return false; - } - n += Date::IsLeapYear(year) ? 
Date::CUMULATIVE_LEAP_DAYS[month - 1] : Date::CUMULATIVE_DAYS[month - 1]; - n += day - 1; - if (year < 1970) { - int32_t diff_from_base = 1970 - year; - int32_t year_index = 400 - (diff_from_base % 400); - int32_t fractions = diff_from_base / 400; - n += Date::CUMULATIVE_YEAR_DAYS[year_index]; - n -= Date::DAYS_PER_YEAR_INTERVAL; - n -= fractions * Date::DAYS_PER_YEAR_INTERVAL; - } else if (year >= 2370) { - int32_t diff_from_base = year - 2370; - int32_t year_index = diff_from_base % 400; - int32_t fractions = diff_from_base / 400; - n += Date::CUMULATIVE_YEAR_DAYS[year_index]; - n += Date::DAYS_PER_YEAR_INTERVAL; - n += fractions * Date::DAYS_PER_YEAR_INTERVAL; - } else { - n += Date::CUMULATIVE_YEAR_DAYS[year - 1970]; - } -#ifdef DEBUG - int32_t y, m, d; - Date::Convert(date_t(n), y, m, d); - D_ASSERT(year == y); - D_ASSERT(month == m); - D_ASSERT(day == d); -#endif - result = date_t(n); - return true; -} - -date_t Date::FromDate(int32_t year, int32_t month, int32_t day) { - date_t result; - if (!Date::TryFromDate(year, month, day, result)) { - throw ConversionException("Date out of range: %d-%d-%d", year, month, day); - } - return result; -} - -bool Date::ParseDoubleDigit(const char *buf, idx_t len, idx_t &pos, int32_t &result) { - if (pos < len && StringUtil::CharacterIsDigit(buf[pos])) { - result = buf[pos++] - '0'; - if (pos < len && StringUtil::CharacterIsDigit(buf[pos])) { - result = (buf[pos++] - '0') + result * 10; - } - return true; - } - return false; -} - -static bool TryConvertDateSpecial(const char *buf, idx_t len, idx_t &pos, const char *special) { - auto p = pos; - for (; p < len && *special; ++p) { - const auto s = *special++; - if (!s || StringUtil::CharacterToLower(buf[p]) != s) { - return false; - } - } - if (*special) { - return false; - } - pos = p; - return true; -} - -bool Date::TryConvertDate(const char *buf, idx_t len, idx_t &pos, date_t &result, bool &special, bool strict) { - special = false; - pos = 0; - if (len == 0) { - return 
false; - } - - int32_t day = 0; - int32_t month = -1; - int32_t year = 0; - bool yearneg = false; - int sep; - - // skip leading spaces - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } - - if (pos >= len) { - return false; - } - if (buf[pos] == '-') { - yearneg = true; - pos++; - if (pos >= len) { - return false; - } - } - if (!StringUtil::CharacterIsDigit(buf[pos])) { - // Check for special values - if (TryConvertDateSpecial(buf, len, pos, PINF)) { - result = yearneg ? date_t::ninfinity() : date_t::infinity(); - } else if (TryConvertDateSpecial(buf, len, pos, EPOCH)) { - result = date_t::epoch(); - } else { - return false; - } - // skip trailing spaces - parsing must be strict here - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } - special = true; - return pos == len; - } - // first parse the year - for (; pos < len && StringUtil::CharacterIsDigit(buf[pos]); pos++) { - if (year >= 100000000) { - return false; - } - year = (buf[pos] - '0') + year * 10; - } - if (yearneg) { - year = -year; - } - - if (pos >= len) { - return false; - } - - // fetch the separator - sep = buf[pos++]; - if (sep != ' ' && sep != '-' && sep != '/' && sep != '\\') { - // invalid separator - return false; - } - - // parse the month - if (!Date::ParseDoubleDigit(buf, len, pos, month)) { - return false; - } - - if (pos >= len) { - return false; - } - - if (buf[pos++] != sep) { - return false; - } - - if (pos >= len) { - return false; - } - - // now parse the day - if (!Date::ParseDoubleDigit(buf, len, pos, day)) { - return false; - } - - // check for an optional trailing " (BC)"" - if (len - pos >= 5 && StringUtil::CharacterIsSpace(buf[pos]) && buf[pos + 1] == '(' && - StringUtil::CharacterToLower(buf[pos + 2]) == 'b' && StringUtil::CharacterToLower(buf[pos + 3]) == 'c' && - buf[pos + 4] == ')') { - if (yearneg || year == 0) { - return false; - } - year = -year + 1; - pos += 5; - } - - // in strict mode, check remaining string for 
non-space characters - if (strict) { - // skip trailing spaces - while (pos < len && StringUtil::CharacterIsSpace((unsigned char)buf[pos])) { - pos++; - } - // check position. if end was not reached, non-space chars remaining - if (pos < len) { - return false; - } - } else { - // in non-strict mode, check for any direct trailing digits - if (pos < len && StringUtil::CharacterIsDigit((unsigned char)buf[pos])) { - return false; - } - } - - return Date::TryFromDate(year, month, day, result); -} - -string Date::ConversionError(const string &str) { - return StringUtil::Format("date field value out of range: \"%s\", " - "expected format is (YYYY-MM-DD)", - str); -} - -string Date::ConversionError(string_t str) { - return ConversionError(str.GetString()); -} - -date_t Date::FromCString(const char *buf, idx_t len, bool strict) { - date_t result; - idx_t pos; - bool special = false; - if (!TryConvertDate(buf, len, pos, result, special, strict)) { - throw ConversionException(ConversionError(string(buf, len))); - } - return result; -} - -date_t Date::FromString(const string &str, bool strict) { - return Date::FromCString(str.c_str(), str.size(), strict); -} - -string Date::ToString(date_t date) { - // PG displays temporal infinities in lowercase, - // but numerics in Titlecase. 
- if (date == date_t::infinity()) { - return PINF; - } else if (date == date_t::ninfinity()) { - return NINF; - } - int32_t date_units[3]; - idx_t year_length; - bool add_bc; - Date::Convert(date, date_units[0], date_units[1], date_units[2]); - - auto length = DateToStringCast::Length(date_units, year_length, add_bc); - auto buffer = make_unsafe_uniq_array(length); - DateToStringCast::Format(buffer.get(), date_units, year_length, add_bc); - return string(buffer.get(), length); -} - -string Date::Format(int32_t year, int32_t month, int32_t day) { - return ToString(Date::FromDate(year, month, day)); -} - -bool Date::IsLeapYear(int32_t year) { - return year % 4 == 0 && (year % 100 != 0 || year % 400 == 0); -} - -bool Date::IsValid(int32_t year, int32_t month, int32_t day) { - if (month < 1 || month > 12) { - return false; - } - if (day < 1) { - return false; - } - if (year <= DATE_MIN_YEAR) { - if (year < DATE_MIN_YEAR) { - return false; - } else if (year == DATE_MIN_YEAR) { - if (month < DATE_MIN_MONTH || (month == DATE_MIN_MONTH && day < DATE_MIN_DAY)) { - return false; - } - } - } - if (year >= DATE_MAX_YEAR) { - if (year > DATE_MAX_YEAR) { - return false; - } else if (year == DATE_MAX_YEAR) { - if (month > DATE_MAX_MONTH || (month == DATE_MAX_MONTH && day > DATE_MAX_DAY)) { - return false; - } - } - } - return Date::IsLeapYear(year) ? day <= Date::LEAP_DAYS[month] : day <= Date::NORMAL_DAYS[month]; -} - -int32_t Date::MonthDays(int32_t year, int32_t month) { - D_ASSERT(month >= 1 && month <= 12); - return Date::IsLeapYear(year) ? 
Date::LEAP_DAYS[month] : Date::NORMAL_DAYS[month]; -} - -date_t Date::EpochDaysToDate(int32_t epoch) { - return (date_t)epoch; -} - -int32_t Date::EpochDays(date_t date) { - return date.days; -} - -date_t Date::EpochToDate(int64_t epoch) { - return date_t(epoch / Interval::SECS_PER_DAY); -} - -int64_t Date::Epoch(date_t date) { - return ((int64_t)date.days) * Interval::SECS_PER_DAY; -} - -int64_t Date::EpochNanoseconds(date_t date) { - int64_t result; - if (!TryMultiplyOperator::Operation(date.days, Interval::MICROS_PER_DAY * 1000, - result)) { - throw ConversionException("Could not convert DATE (%s) to nanoseconds", Date::ToString(date)); - } - return result; -} - -int64_t Date::EpochMicroseconds(date_t date) { - int64_t result; - if (!TryMultiplyOperator::Operation(date.days, Interval::MICROS_PER_DAY, result)) { - throw ConversionException("Could not convert DATE (%s) to microseconds", Date::ToString(date)); - } - return result; -} - -int64_t Date::EpochMilliseconds(date_t date) { - int64_t result; - const auto MILLIS_PER_DAY = Interval::MICROS_PER_DAY / Interval::MICROS_PER_MSEC; - if (!TryMultiplyOperator::Operation(date.days, MILLIS_PER_DAY, result)) { - throw ConversionException("Could not convert DATE (%s) to milliseconds", Date::ToString(date)); - } - return result; -} - -int32_t Date::ExtractYear(date_t d, int32_t *last_year) { - auto n = d.days; - // cached look up: check if year of this date is the same as the last one we looked up - // note that this only works for years in the range [1970, 2370] - if (n >= Date::CUMULATIVE_YEAR_DAYS[*last_year] && n < Date::CUMULATIVE_YEAR_DAYS[*last_year + 1]) { - return Date::EPOCH_YEAR + *last_year; - } - int32_t year; - Date::ExtractYearOffset(n, year, *last_year); - return year; -} - -int32_t Date::ExtractYear(timestamp_t ts, int32_t *last_year) { - return Date::ExtractYear(Timestamp::GetDate(ts), last_year); -} - -int32_t Date::ExtractYear(date_t d) { - int32_t year, year_offset; - Date::ExtractYearOffset(d.days, 
year, year_offset); - return year; -} - -int32_t Date::ExtractMonth(date_t date) { - int32_t out_year, out_month, out_day; - Date::Convert(date, out_year, out_month, out_day); - return out_month; -} - -int32_t Date::ExtractDay(date_t date) { - int32_t out_year, out_month, out_day; - Date::Convert(date, out_year, out_month, out_day); - return out_day; -} - -int32_t Date::ExtractDayOfTheYear(date_t date) { - int32_t year, year_offset; - Date::ExtractYearOffset(date.days, year, year_offset); - return date.days - Date::CUMULATIVE_YEAR_DAYS[year_offset] + 1; -} - -int64_t Date::ExtractJulianDay(date_t date) { - // Julian Day 0 is (-4713, 11, 24) in the proleptic Gregorian calendar. - static const int64_t JULIAN_EPOCH = -2440588; - return date.days - JULIAN_EPOCH; -} - -int32_t Date::ExtractISODayOfTheWeek(date_t date) { - // date of 0 is 1970-01-01, which was a Thursday (4) - // -7 = 4 - // -6 = 5 - // -5 = 6 - // -4 = 7 - // -3 = 1 - // -2 = 2 - // -1 = 3 - // 0 = 4 - // 1 = 5 - // 2 = 6 - // 3 = 7 - // 4 = 1 - // 5 = 2 - // 6 = 3 - // 7 = 4 - if (date.days < 0) { - // negative date: start off at 4 and cycle downwards - return (7 - ((-int64_t(date.days) + 3) % 7)); - } else { - // positive date: start off at 4 and cycle upwards - return ((int64_t(date.days) + 3) % 7) + 1; - } -} - -template -static T PythonDivMod(const T &x, const T &y, T &r) { - // D_ASSERT(y > 0); - T quo = x / y; - r = x - quo * y; - if (r < 0) { - --quo; - r += y; - } - // D_ASSERT(0 <= r && r < y); - return quo; -} - -static date_t GetISOWeekOne(int32_t year) { - const auto first_day = Date::FromDate(year, 1, 1); /* ord of 1/1 */ - /* 0 if 1/1 is a Monday, 1 if a Tue, etc. 
*/ - const auto first_weekday = Date::ExtractISODayOfTheWeek(first_day) - 1; - /* ordinal of closest Monday at or before 1/1 */ - auto week1_monday = first_day - first_weekday; - - if (first_weekday > 3) { /* if 1/1 was Fri, Sat, Sun */ - week1_monday += 7; - } - - return week1_monday; -} - -static int32_t GetISOYearWeek(const date_t date, int32_t &year) { - int32_t month, day; - Date::Convert(date, year, month, day); - auto week1_monday = GetISOWeekOne(year); - auto week = PythonDivMod((date.days - week1_monday.days), 7, day); - if (week < 0) { - week1_monday = GetISOWeekOne(--year); - week = PythonDivMod((date.days - week1_monday.days), 7, day); - } else if (week >= 52 && date >= GetISOWeekOne(year + 1)) { - ++year; - week = 0; - } - - return week + 1; -} - -void Date::ExtractISOYearWeek(date_t date, int32_t &year, int32_t &week) { - week = GetISOYearWeek(date, year); -} - -int32_t Date::ExtractISOWeekNumber(date_t date) { - int32_t year, week; - ExtractISOYearWeek(date, year, week); - return week; -} - -int32_t Date::ExtractISOYearNumber(date_t date) { - int32_t year, week; - ExtractISOYearWeek(date, year, week); - return year; -} - -int32_t Date::ExtractWeekNumberRegular(date_t date, bool monday_first) { - int32_t year, month, day; - Date::Convert(date, year, month, day); - month -= 1; - day -= 1; - // get the day of the year - auto day_of_the_year = - (Date::IsLeapYear(year) ? Date::CUMULATIVE_LEAP_DAYS[month] : Date::CUMULATIVE_DAYS[month]) + day; - // now figure out the first monday or sunday of the year - // what day is January 1st? 
- auto day_of_jan_first = Date::ExtractISODayOfTheWeek(Date::FromDate(year, 1, 1)); - // monday = 1, sunday = 7 - int32_t first_week_start; - if (monday_first) { - // have to find next "1" - if (day_of_jan_first == 1) { - // jan 1 is monday: starts immediately - first_week_start = 0; - } else { - // jan 1 is not monday: count days until next monday - first_week_start = 8 - day_of_jan_first; - } - } else { - first_week_start = 7 - day_of_jan_first; - } - if (day_of_the_year < first_week_start) { - // day occurs before first week starts: week 0 - return 0; - } - return ((day_of_the_year - first_week_start) / 7) + 1; -} - -// Returns the date of the monday of the current week. -date_t Date::GetMondayOfCurrentWeek(date_t date) { - int32_t dotw = Date::ExtractISODayOfTheWeek(date); - return date - (dotw - 1); -} - -} // namespace duckdb - - - -namespace duckdb { - -template -string TemplatedDecimalToString(SIGNED value, uint8_t width, uint8_t scale) { - auto len = DecimalToString::DecimalLength(value, width, scale); - auto data = make_unsafe_uniq_array(len + 1); - DecimalToString::FormatDecimal(value, width, scale, data.get(), len); - return string(data.get(), len); -} - -string Decimal::ToString(int16_t value, uint8_t width, uint8_t scale) { - return TemplatedDecimalToString(value, width, scale); -} - -string Decimal::ToString(int32_t value, uint8_t width, uint8_t scale) { - return TemplatedDecimalToString(value, width, scale); -} - -string Decimal::ToString(int64_t value, uint8_t width, uint8_t scale) { - return TemplatedDecimalToString(value, width, scale); -} - -string Decimal::ToString(hugeint_t value, uint8_t width, uint8_t scale) { - auto len = HugeintToStringCast::DecimalLength(value, width, scale); - auto data = make_unsafe_uniq_array(len + 1); - HugeintToStringCast::FormatDecimal(value, width, scale, data.get(), len); - return string(data.get(), len); -} - -} // namespace duckdb - - - - - - -#include -#include - -namespace duckdb { - -template <> -hash_t 
Hash(uint64_t val) { - return murmurhash64(val); -} - -template <> -hash_t Hash(int64_t val) { - return murmurhash64((uint64_t)val); -} - -template <> -hash_t Hash(hugeint_t val) { - return murmurhash64(val.lower) ^ murmurhash64(val.upper); -} - -template -struct FloatingPointEqualityTransform { - static void OP(T &val) { - if (val == (T)0.0) { - // Turn negative zero into positive zero - val = (T)0.0; - } else if (std::isnan(val)) { - val = std::numeric_limits::quiet_NaN(); - } - } -}; - -template <> -hash_t Hash(float val) { - static_assert(sizeof(float) == sizeof(uint32_t), ""); - FloatingPointEqualityTransform::OP(val); - uint32_t uval = Load(const_data_ptr_cast(&val)); - return murmurhash64(uval); -} - -template <> -hash_t Hash(double val) { - static_assert(sizeof(double) == sizeof(uint64_t), ""); - FloatingPointEqualityTransform::OP(val); - uint64_t uval = Load(const_data_ptr_cast(&val)); - return murmurhash64(uval); -} - -template <> -hash_t Hash(interval_t val) { - return Hash(val.days) ^ Hash(val.months) ^ Hash(val.micros); -} - -template <> -hash_t Hash(const char *str) { - return Hash(str, strlen(str)); -} - -template <> -hash_t Hash(string_t val) { - return Hash(val.GetData(), val.GetSize()); -} - -template <> -hash_t Hash(char *val) { - return Hash(val); -} - -// MIT License -// Copyright (c) 2018-2021 Martin Ankerl -// https://github.com/martinus/robin-hood-hashing/blob/3.11.5/LICENSE -hash_t HashBytes(void *ptr, size_t len) noexcept { - static constexpr uint64_t M = UINT64_C(0xc6a4a7935bd1e995); - static constexpr uint64_t SEED = UINT64_C(0xe17a1465); - static constexpr unsigned int R = 47; - - auto const *const data64 = static_cast(ptr); - uint64_t h = SEED ^ (len * M); - - size_t const n_blocks = len / 8; - for (size_t i = 0; i < n_blocks; ++i) { - auto k = Load(reinterpret_cast(data64 + i)); - - k *= M; - k ^= k >> R; - k *= M; - - h ^= k; - h *= M; - } - - auto const *const data8 = reinterpret_cast(data64 + n_blocks); - switch (len & 7U) { - case 
7: - h ^= static_cast(data8[6]) << 48U; - DUCKDB_EXPLICIT_FALLTHROUGH; - case 6: - h ^= static_cast(data8[5]) << 40U; - DUCKDB_EXPLICIT_FALLTHROUGH; - case 5: - h ^= static_cast(data8[4]) << 32U; - DUCKDB_EXPLICIT_FALLTHROUGH; - case 4: - h ^= static_cast(data8[3]) << 24U; - DUCKDB_EXPLICIT_FALLTHROUGH; - case 3: - h ^= static_cast(data8[2]) << 16U; - DUCKDB_EXPLICIT_FALLTHROUGH; - case 2: - h ^= static_cast(data8[1]) << 8U; - DUCKDB_EXPLICIT_FALLTHROUGH; - case 1: - h ^= static_cast(data8[0]); - h *= M; - DUCKDB_EXPLICIT_FALLTHROUGH; - default: - break; - } - h ^= h >> R; - h *= M; - h ^= h >> R; - return static_cast(h); -} - -hash_t Hash(const char *val, size_t size) { - return HashBytes((void *)val, size); -} - -hash_t Hash(uint8_t *val, size_t size) { - return HashBytes((void *)val, size); -} - -} // namespace duckdb - - - - - - - - - -#include -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// String Conversion -//===--------------------------------------------------------------------===// -const hugeint_t Hugeint::POWERS_OF_TEN[] { - hugeint_t(1), - hugeint_t(10), - hugeint_t(100), - hugeint_t(1000), - hugeint_t(10000), - hugeint_t(100000), - hugeint_t(1000000), - hugeint_t(10000000), - hugeint_t(100000000), - hugeint_t(1000000000), - hugeint_t(10000000000), - hugeint_t(100000000000), - hugeint_t(1000000000000), - hugeint_t(10000000000000), - hugeint_t(100000000000000), - hugeint_t(1000000000000000), - hugeint_t(10000000000000000), - hugeint_t(100000000000000000), - hugeint_t(1000000000000000000), - hugeint_t(1000000000000000000) * hugeint_t(10), - hugeint_t(1000000000000000000) * hugeint_t(100), - hugeint_t(1000000000000000000) * hugeint_t(1000), - hugeint_t(1000000000000000000) * hugeint_t(10000), - hugeint_t(1000000000000000000) * hugeint_t(100000), - hugeint_t(1000000000000000000) * hugeint_t(1000000), - hugeint_t(1000000000000000000) * hugeint_t(10000000), - hugeint_t(1000000000000000000) 
* hugeint_t(100000000), - hugeint_t(1000000000000000000) * hugeint_t(1000000000), - hugeint_t(1000000000000000000) * hugeint_t(10000000000), - hugeint_t(1000000000000000000) * hugeint_t(100000000000), - hugeint_t(1000000000000000000) * hugeint_t(1000000000000), - hugeint_t(1000000000000000000) * hugeint_t(10000000000000), - hugeint_t(1000000000000000000) * hugeint_t(100000000000000), - hugeint_t(1000000000000000000) * hugeint_t(1000000000000000), - hugeint_t(1000000000000000000) * hugeint_t(10000000000000000), - hugeint_t(1000000000000000000) * hugeint_t(100000000000000000), - hugeint_t(1000000000000000000) * hugeint_t(1000000000000000000), - hugeint_t(1000000000000000000) * hugeint_t(1000000000000000000) * hugeint_t(10), - hugeint_t(1000000000000000000) * hugeint_t(1000000000000000000) * hugeint_t(100)}; - -static uint8_t PositiveHugeintHighestBit(hugeint_t bits) { - uint8_t out = 0; - if (bits.upper) { - out = 64; - uint64_t up = bits.upper; - while (up) { - up >>= 1; - out++; - } - } else { - uint64_t low = bits.lower; - while (low) { - low >>= 1; - out++; - } - } - return out; -} - -static bool PositiveHugeintIsBitSet(hugeint_t lhs, uint8_t bit_position) { - if (bit_position < 64) { - return lhs.lower & (uint64_t(1) << uint64_t(bit_position)); - } else { - return lhs.upper & (uint64_t(1) << uint64_t(bit_position - 64)); - } -} - -hugeint_t PositiveHugeintLeftShift(hugeint_t lhs, uint32_t amount) { - D_ASSERT(amount > 0 && amount < 64); - hugeint_t result; - result.lower = lhs.lower << amount; - result.upper = (lhs.upper << amount) + (lhs.lower >> (64 - amount)); - return result; -} - -hugeint_t Hugeint::DivModPositive(hugeint_t lhs, uint64_t rhs, uint64_t &remainder) { - D_ASSERT(lhs.upper >= 0); - // DivMod code adapted from: - // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp - - // initialize the result and remainder to 0 - hugeint_t div_result; - div_result.lower = 0; - div_result.upper = 0; - remainder = 0; - - uint8_t highest_bit_set = 
PositiveHugeintHighestBit(lhs); - // now iterate over the amount of bits that are set in the LHS - for (uint8_t x = highest_bit_set; x > 0; x--) { - // left-shift the current result and remainder by 1 - div_result = PositiveHugeintLeftShift(div_result, 1); - remainder <<= 1; - // we get the value of the bit at position X, where position 0 is the least-significant bit - if (PositiveHugeintIsBitSet(lhs, x - 1)) { - // increment the remainder - remainder++; - } - if (remainder >= rhs) { - // the remainder has passed the division multiplier: add one to the divide result - remainder -= rhs; - div_result.lower++; - if (div_result.lower == 0) { - // overflow - div_result.upper++; - } - } - } - return div_result; -} - -string Hugeint::ToString(hugeint_t input) { - uint64_t remainder; - string result; - bool negative = input.upper < 0; - if (negative) { - NegateInPlace(input); - } - while (true) { - if (!input.lower && !input.upper) { - break; - } - input = Hugeint::DivModPositive(input, 10, remainder); - result = string(1, '0' + remainder) + result; // NOLINT - } - if (result.empty()) { - // value is zero - return "0"; - } - return negative ? 
"-" + result : result; -} - -//===--------------------------------------------------------------------===// -// Multiply -//===--------------------------------------------------------------------===// -bool Hugeint::TryMultiply(hugeint_t lhs, hugeint_t rhs, hugeint_t &result) { - bool lhs_negative = lhs.upper < 0; - bool rhs_negative = rhs.upper < 0; - if (lhs_negative) { - NegateInPlace(lhs); - } - if (rhs_negative) { - NegateInPlace(rhs); - } -#if ((__GNUC__ >= 5) || defined(__clang__)) && defined(__SIZEOF_INT128__) - __uint128_t left = __uint128_t(lhs.lower) + (__uint128_t(lhs.upper) << 64); - __uint128_t right = __uint128_t(rhs.lower) + (__uint128_t(rhs.upper) << 64); - __uint128_t result_i128; - if (__builtin_mul_overflow(left, right, &result_i128)) { - return false; - } - uint64_t upper = uint64_t(result_i128 >> 64); - if (upper & 0x8000000000000000) { - return false; - } - result.upper = int64_t(upper); - result.lower = uint64_t(result_i128 & 0xffffffffffffffff); -#else - // Multiply code adapted from: - // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp - - // split values into 4 32-bit parts - uint64_t top[4] = {uint64_t(lhs.upper) >> 32, uint64_t(lhs.upper) & 0xffffffff, lhs.lower >> 32, - lhs.lower & 0xffffffff}; - uint64_t bottom[4] = {uint64_t(rhs.upper) >> 32, uint64_t(rhs.upper) & 0xffffffff, rhs.lower >> 32, - rhs.lower & 0xffffffff}; - uint64_t products[4][4]; - - // multiply each component of the values - for (auto x = 0; x < 4; x++) { - for (auto y = 0; y < 4; y++) { - products[x][y] = top[x] * bottom[y]; - } - } - - // if any of these products are set to a non-zero value, there is always an overflow - if (products[0][0] || products[0][1] || products[0][2] || products[1][0] || products[2][0] || products[1][1]) { - return false; - } - // if the high bits of any of these are set, there is always an overflow - if ((products[0][3] & 0xffffffff80000000) || (products[1][2] & 0xffffffff80000000) || - (products[2][1] & 
0xffffffff80000000) || (products[3][0] & 0xffffffff80000000)) { - return false; - } - - // otherwise we merge the result of the different products together in-order - - // first row - uint64_t fourth32 = (products[3][3] & 0xffffffff); - uint64_t third32 = (products[3][2] & 0xffffffff) + (products[3][3] >> 32); - uint64_t second32 = (products[3][1] & 0xffffffff) + (products[3][2] >> 32); - uint64_t first32 = (products[3][0] & 0xffffffff) + (products[3][1] >> 32); - - // second row - third32 += (products[2][3] & 0xffffffff); - second32 += (products[2][2] & 0xffffffff) + (products[2][3] >> 32); - first32 += (products[2][1] & 0xffffffff) + (products[2][2] >> 32); - - // third row - second32 += (products[1][3] & 0xffffffff); - first32 += (products[1][2] & 0xffffffff) + (products[1][3] >> 32); - - // fourth row - first32 += (products[0][3] & 0xffffffff); - - // move carry to next digit - third32 += fourth32 >> 32; - second32 += third32 >> 32; - first32 += second32 >> 32; - - // check if the combination of the different products resulted in an overflow - if (first32 & 0xffffff80000000) { - return false; - } - - // remove carry from current digit - fourth32 &= 0xffffffff; - third32 &= 0xffffffff; - second32 &= 0xffffffff; - first32 &= 0xffffffff; - - // combine components - result.lower = (third32 << 32) | fourth32; - result.upper = (first32 << 32) | second32; -#endif - if (lhs_negative ^ rhs_negative) { - NegateInPlace(result); - } - return true; -} - -hugeint_t Hugeint::Multiply(hugeint_t lhs, hugeint_t rhs) { - hugeint_t result; - if (!TryMultiply(lhs, rhs, result)) { - throw OutOfRangeException("Overflow in HUGEINT multiplication!"); - } - return result; -} - -//===--------------------------------------------------------------------===// -// Divide -//===--------------------------------------------------------------------===// -hugeint_t Hugeint::DivMod(hugeint_t lhs, hugeint_t rhs, hugeint_t &remainder) { - // division by zero not allowed - D_ASSERT(!(rhs.upper == 0 
&& rhs.lower == 0)); - - bool lhs_negative = lhs.upper < 0; - bool rhs_negative = rhs.upper < 0; - if (lhs_negative) { - Hugeint::NegateInPlace(lhs); - } - if (rhs_negative) { - Hugeint::NegateInPlace(rhs); - } - // DivMod code adapted from: - // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp - - // initialize the result and remainder to 0 - hugeint_t div_result; - div_result.lower = 0; - div_result.upper = 0; - remainder.lower = 0; - remainder.upper = 0; - - uint8_t highest_bit_set = PositiveHugeintHighestBit(lhs); - // now iterate over the amount of bits that are set in the LHS - for (uint8_t x = highest_bit_set; x > 0; x--) { - // left-shift the current result and remainder by 1 - div_result = PositiveHugeintLeftShift(div_result, 1); - remainder = PositiveHugeintLeftShift(remainder, 1); - - // we get the value of the bit at position X, where position 0 is the least-significant bit - if (PositiveHugeintIsBitSet(lhs, x - 1)) { - // increment the remainder - Hugeint::AddInPlace(remainder, 1); - } - if (Hugeint::GreaterThanEquals(remainder, rhs)) { - // the remainder has passed the division multiplier: add one to the divide result - remainder = Hugeint::Subtract(remainder, rhs); - Hugeint::AddInPlace(div_result, 1); - } - } - if (lhs_negative ^ rhs_negative) { - Hugeint::NegateInPlace(div_result); - } - if (lhs_negative) { - Hugeint::NegateInPlace(remainder); - } - return div_result; -} - -hugeint_t Hugeint::Divide(hugeint_t lhs, hugeint_t rhs) { - hugeint_t remainder; - return Hugeint::DivMod(lhs, rhs, remainder); -} - -hugeint_t Hugeint::Modulo(hugeint_t lhs, hugeint_t rhs) { - hugeint_t remainder; - Hugeint::DivMod(lhs, rhs, remainder); - return remainder; -} - -//===--------------------------------------------------------------------===// -// Add/Subtract -//===--------------------------------------------------------------------===// -bool Hugeint::AddInPlace(hugeint_t &lhs, hugeint_t rhs) { - int overflow = lhs.lower + rhs.lower < lhs.lower; 
- if (rhs.upper >= 0) { - // RHS is positive: check for overflow - if (lhs.upper > (std::numeric_limits::max() - rhs.upper - overflow)) { - return false; - } - lhs.upper = lhs.upper + overflow + rhs.upper; - } else { - // RHS is negative: check for underflow - if (lhs.upper < std::numeric_limits::min() - rhs.upper - overflow) { - return false; - } - lhs.upper = lhs.upper + (overflow + rhs.upper); - } - lhs.lower += rhs.lower; - if (lhs.upper == std::numeric_limits::min() && lhs.lower == 0) { - return false; - } - return true; -} - -bool Hugeint::SubtractInPlace(hugeint_t &lhs, hugeint_t rhs) { - // underflow - int underflow = lhs.lower - rhs.lower > lhs.lower; - if (rhs.upper >= 0) { - // RHS is positive: check for underflow - if (lhs.upper < (std::numeric_limits::min() + rhs.upper + underflow)) { - return false; - } - lhs.upper = (lhs.upper - rhs.upper) - underflow; - } else { - // RHS is negative: check for overflow - if (lhs.upper > std::numeric_limits::min() && - lhs.upper - 1 >= (std::numeric_limits::max() + rhs.upper + underflow)) { - return false; - } - lhs.upper = lhs.upper - (rhs.upper + underflow); - } - lhs.lower -= rhs.lower; - if (lhs.upper == std::numeric_limits::min() && lhs.lower == 0) { - return false; - } - return true; -} - -hugeint_t Hugeint::Add(hugeint_t lhs, hugeint_t rhs) { - if (!AddInPlace(lhs, rhs)) { - throw OutOfRangeException("Overflow in HUGEINT addition"); - } - return lhs; -} - -hugeint_t Hugeint::Subtract(hugeint_t lhs, hugeint_t rhs) { - if (!SubtractInPlace(lhs, rhs)) { - throw OutOfRangeException("Underflow in HUGEINT addition"); - } - return lhs; -} - -//===--------------------------------------------------------------------===// -// Hugeint Cast/Conversion -//===--------------------------------------------------------------------===// -template -bool HugeintTryCastInteger(hugeint_t input, DST &result) { - switch (input.upper) { - case 0: - // positive number: check if the positive number is in range - if (input.lower <= 
uint64_t(NumericLimits::Maximum())) { - result = DST(input.lower); - return true; - } - break; - case -1: - if (!SIGNED) { - return false; - } - // negative number: check if the negative number is in range - if (input.lower >= NumericLimits::Maximum() - uint64_t(NumericLimits::Maximum())) { - result = -DST(NumericLimits::Maximum() - input.lower) - 1; - return true; - } - break; - default: - break; - } - return false; -} - -template <> -bool Hugeint::TryCast(hugeint_t input, int8_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, int16_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, int32_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, int64_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, uint8_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, uint16_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, uint32_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, uint64_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, hugeint_t &result) { - result = input; - return true; -} - -template <> -bool Hugeint::TryCast(hugeint_t input, float &result) { - double dbl_result; - Hugeint::TryCast(input, dbl_result); - result = (float)dbl_result; - return true; -} - -template -bool CastBigintToFloating(hugeint_t input, REAL_T &result) { - switch (input.upper) { - case -1: - // special case for upper = -1 to avoid rounding issues in small negative numbers - result = -REAL_T(NumericLimits::Maximum() - input.lower) - 1; - break; - default: - result = 
REAL_T(input.lower) + REAL_T(input.upper) * REAL_T(NumericLimits::Maximum()); - break; - } - return true; -} - -template <> -bool Hugeint::TryCast(hugeint_t input, double &result) { - return CastBigintToFloating(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, long double &result) { - return CastBigintToFloating(input, result); -} - -template -hugeint_t HugeintConvertInteger(DST input) { - hugeint_t result; - result.lower = (uint64_t)input; - result.upper = (input < 0) * -1; - return result; -} - -template <> -bool Hugeint::TryConvert(int8_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} - -template <> -bool Hugeint::TryConvert(const char *value, hugeint_t &result) { - auto len = strlen(value); - string_t string_val(value, len); - return TryCast::Operation(string_val, result, true); -} - -template <> -bool Hugeint::TryConvert(int16_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} - -template <> -bool Hugeint::TryConvert(int32_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} - -template <> -bool Hugeint::TryConvert(int64_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} -template <> -bool Hugeint::TryConvert(uint8_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} -template <> -bool Hugeint::TryConvert(uint16_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} -template <> -bool Hugeint::TryConvert(uint32_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} -template <> -bool Hugeint::TryConvert(uint64_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} - -template <> -bool Hugeint::TryConvert(hugeint_t value, hugeint_t &result) { - result = value; - return true; -} - -template <> -bool Hugeint::TryConvert(float value, hugeint_t &result) { 
- return Hugeint::TryConvert(double(value), result); -} - -template -bool ConvertFloatingToBigint(REAL_T value, hugeint_t &result) { - if (!Value::IsFinite(value)) { - return false; - } - if (value <= -170141183460469231731687303715884105728.0 || value >= 170141183460469231731687303715884105727.0) { - return false; - } - bool negative = value < 0; - if (negative) { - value = -value; - } - result.lower = (uint64_t)fmod(value, REAL_T(NumericLimits::Maximum())); - result.upper = (uint64_t)(value / REAL_T(NumericLimits::Maximum())); - if (negative) { - Hugeint::NegateInPlace(result); - } - return true; -} - -template <> -bool Hugeint::TryConvert(double value, hugeint_t &result) { - return ConvertFloatingToBigint(value, result); -} - -template <> -bool Hugeint::TryConvert(long double value, hugeint_t &result) { - return ConvertFloatingToBigint(value, result); -} - -//===--------------------------------------------------------------------===// -// hugeint_t operators -//===--------------------------------------------------------------------===// -hugeint_t::hugeint_t(int64_t value) { - auto result = Hugeint::Convert(value); - this->lower = result.lower; - this->upper = result.upper; -} - -bool hugeint_t::operator==(const hugeint_t &rhs) const { - return Hugeint::Equals(*this, rhs); -} - -bool hugeint_t::operator!=(const hugeint_t &rhs) const { - return Hugeint::NotEquals(*this, rhs); -} - -bool hugeint_t::operator<(const hugeint_t &rhs) const { - return Hugeint::LessThan(*this, rhs); -} - -bool hugeint_t::operator<=(const hugeint_t &rhs) const { - return Hugeint::LessThanEquals(*this, rhs); -} - -bool hugeint_t::operator>(const hugeint_t &rhs) const { - return Hugeint::GreaterThan(*this, rhs); -} - -bool hugeint_t::operator>=(const hugeint_t &rhs) const { - return Hugeint::GreaterThanEquals(*this, rhs); -} - -hugeint_t hugeint_t::operator+(const hugeint_t &rhs) const { - return Hugeint::Add(*this, rhs); -} - -hugeint_t hugeint_t::operator-(const hugeint_t &rhs) const { - 
return Hugeint::Subtract(*this, rhs); -} - -hugeint_t hugeint_t::operator*(const hugeint_t &rhs) const { - return Hugeint::Multiply(*this, rhs); -} - -hugeint_t hugeint_t::operator/(const hugeint_t &rhs) const { - return Hugeint::Divide(*this, rhs); -} - -hugeint_t hugeint_t::operator%(const hugeint_t &rhs) const { - return Hugeint::Modulo(*this, rhs); -} - -hugeint_t hugeint_t::operator-() const { - return Hugeint::Negate(*this); -} - -hugeint_t hugeint_t::operator>>(const hugeint_t &rhs) const { - hugeint_t result; - uint64_t shift = rhs.lower; - if (rhs.upper != 0 || shift >= 128) { - return hugeint_t(0); - } else if (shift == 0) { - return *this; - } else if (shift == 64) { - result.upper = (upper < 0) ? -1 : 0; - result.lower = upper; - } else if (shift < 64) { - // perform lower shift in unsigned integer, and mask away the most significant bit - result.lower = (uint64_t(upper) << (64 - shift)) | (lower >> shift); - result.upper = upper >> shift; - } else { - D_ASSERT(shift < 128); - result.lower = upper >> (shift - 64); - result.upper = (upper < 0) ? 
-1 : 0; - } - return result; -} - -hugeint_t hugeint_t::operator<<(const hugeint_t &rhs) const { - if (upper < 0) { - return hugeint_t(0); - } - hugeint_t result; - uint64_t shift = rhs.lower; - if (rhs.upper != 0 || shift >= 128) { - return hugeint_t(0); - } else if (shift == 64) { - result.upper = lower; - result.lower = 0; - } else if (shift == 0) { - return *this; - } else if (shift < 64) { - // perform upper shift in unsigned integer, and mask away the most significant bit - uint64_t upper_shift = ((uint64_t(upper) << shift) + (lower >> (64 - shift))) & 0x7FFFFFFFFFFFFFFF; - result.lower = lower << shift; - result.upper = upper_shift; - } else { - D_ASSERT(shift < 128); - result.lower = 0; - result.upper = (lower << (shift - 64)) & 0x7FFFFFFFFFFFFFFF; - } - return result; -} - -hugeint_t hugeint_t::operator&(const hugeint_t &rhs) const { - hugeint_t result; - result.lower = lower & rhs.lower; - result.upper = upper & rhs.upper; - return result; -} - -hugeint_t hugeint_t::operator|(const hugeint_t &rhs) const { - hugeint_t result; - result.lower = lower | rhs.lower; - result.upper = upper | rhs.upper; - return result; -} - -hugeint_t hugeint_t::operator^(const hugeint_t &rhs) const { - hugeint_t result; - result.lower = lower ^ rhs.lower; - result.upper = upper ^ rhs.upper; - return result; -} - -hugeint_t hugeint_t::operator~() const { - hugeint_t result; - result.lower = ~lower; - result.upper = ~upper; - return result; -} - -hugeint_t &hugeint_t::operator+=(const hugeint_t &rhs) { - Hugeint::AddInPlace(*this, rhs); - return *this; -} -hugeint_t &hugeint_t::operator-=(const hugeint_t &rhs) { - Hugeint::SubtractInPlace(*this, rhs); - return *this; -} -hugeint_t &hugeint_t::operator*=(const hugeint_t &rhs) { - *this = Hugeint::Multiply(*this, rhs); - return *this; -} -hugeint_t &hugeint_t::operator/=(const hugeint_t &rhs) { - *this = Hugeint::Divide(*this, rhs); - return *this; -} -hugeint_t &hugeint_t::operator%=(const hugeint_t &rhs) { - *this = 
Hugeint::Modulo(*this, rhs); - return *this; -} -hugeint_t &hugeint_t::operator>>=(const hugeint_t &rhs) { - *this = *this >> rhs; - return *this; -} -hugeint_t &hugeint_t::operator<<=(const hugeint_t &rhs) { - *this = *this << rhs; - return *this; -} -hugeint_t &hugeint_t::operator&=(const hugeint_t &rhs) { - lower &= rhs.lower; - upper &= rhs.upper; - return *this; -} -hugeint_t &hugeint_t::operator|=(const hugeint_t &rhs) { - lower |= rhs.lower; - upper |= rhs.upper; - return *this; -} -hugeint_t &hugeint_t::operator^=(const hugeint_t &rhs) { - lower ^= rhs.lower; - upper ^= rhs.upper; - return *this; -} - -bool hugeint_t::operator!() const { - return *this == 0; -} - -hugeint_t::operator bool() const { - return *this != 0; -} - -template -static T NarrowCast(const hugeint_t &input) { - // NarrowCast is supposed to truncate (take lower) - return static_cast(input.lower); -} - -hugeint_t::operator uint8_t() const { - return NarrowCast(*this); -} -hugeint_t::operator uint16_t() const { - return NarrowCast(*this); -} -hugeint_t::operator uint32_t() const { - return NarrowCast(*this); -} -hugeint_t::operator uint64_t() const { - return NarrowCast(*this); -} -hugeint_t::operator int8_t() const { - return NarrowCast(*this); -} -hugeint_t::operator int16_t() const { - return NarrowCast(*this); -} -hugeint_t::operator int32_t() const { - return NarrowCast(*this); -} -hugeint_t::operator int64_t() const { - return NarrowCast(*this); -} - -string hugeint_t::ToString() const { - return Hugeint::ToString(*this); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -HyperLogLog::HyperLogLog() : hll(nullptr) { - hll = duckdb_hll::hll_create(); - // Insert into a dense hll can be vectorized, sparse cannot, so we immediately convert - duckdb_hll::hllSparseToDense(hll); -} - -HyperLogLog::HyperLogLog(duckdb_hll::robj *hll) : hll(hll) { -} - -HyperLogLog::~HyperLogLog() { - duckdb_hll::hll_destroy(hll); -} - -void HyperLogLog::Add(data_ptr_t element, idx_t size) { - 
if (duckdb_hll::hll_add(hll, element, size) == HLL_C_ERR) { - throw InternalException("Could not add to HLL?"); - } -} - -idx_t HyperLogLog::Count() const { - // exception from size_t ban - size_t result; - - if (duckdb_hll::hll_count(hll, &result) != HLL_C_OK) { - throw InternalException("Could not count HLL?"); - } - return result; -} - -unique_ptr HyperLogLog::Merge(HyperLogLog &other) { - duckdb_hll::robj *hlls[2]; - hlls[0] = hll; - hlls[1] = other.hll; - auto new_hll = duckdb_hll::hll_merge(hlls, 2); - if (!new_hll) { - throw InternalException("Could not merge HLLs"); - } - return unique_ptr(new HyperLogLog(new_hll)); -} - -HyperLogLog *HyperLogLog::MergePointer(HyperLogLog &other) { - duckdb_hll::robj *hlls[2]; - hlls[0] = hll; - hlls[1] = other.hll; - auto new_hll = duckdb_hll::hll_merge(hlls, 2); - if (!new_hll) { - throw Exception("Could not merge HLLs"); - } - return new HyperLogLog(new_hll); -} - -unique_ptr HyperLogLog::Merge(HyperLogLog logs[], idx_t count) { - auto hlls_uptr = unique_ptr { - new duckdb_hll::robj *[count] - }; - auto hlls = hlls_uptr.get(); - for (idx_t i = 0; i < count; i++) { - hlls[i] = logs[i].hll; - } - auto new_hll = duckdb_hll::hll_merge(hlls, count); - if (!new_hll) { - throw InternalException("Could not merge HLLs"); - } - return unique_ptr(new HyperLogLog(new_hll)); -} - -idx_t HyperLogLog::GetSize() { - return duckdb_hll::get_size(); -} - -data_ptr_t HyperLogLog::GetPtr() const { - return data_ptr_cast((hll)->ptr); -} - -unique_ptr HyperLogLog::Copy() { - auto result = make_uniq(); - lock_guard guard(lock); - memcpy(result->GetPtr(), GetPtr(), GetSize()); - D_ASSERT(result->Count() == Count()); - return result; -} - -void HyperLogLog::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "type", HLLStorageType::UNCOMPRESSED); - serializer.WriteProperty(101, "data", GetPtr(), GetSize()); -} - -unique_ptr HyperLogLog::Deserialize(Deserializer &deserializer) { - auto result = make_uniq(); - auto 
storage_type = deserializer.ReadProperty(100, "type"); - switch (storage_type) { - case HLLStorageType::UNCOMPRESSED: - deserializer.ReadProperty(101, "data", result->GetPtr(), GetSize()); - break; - default: - throw SerializationException("Unknown HyperLogLog storage type!"); - } - return result; -} - -//===--------------------------------------------------------------------===// -// Vectorized HLL implementation -//===--------------------------------------------------------------------===// -//! Taken from https://nullprogram.com/blog/2018/07/31/ -template -inline uint64_t TemplatedHash(const T &elem) { - uint64_t x = elem; - x ^= x >> 30; - x *= UINT64_C(0xbf58476d1ce4e5b9); - x ^= x >> 27; - x *= UINT64_C(0x94d049bb133111eb); - x ^= x >> 31; - return x; -} - -template <> -inline uint64_t TemplatedHash(const hugeint_t &elem) { - return TemplatedHash(Load(const_data_ptr_cast(&elem.upper))) ^ - TemplatedHash(elem.lower); -} - -template -inline void CreateIntegerRecursive(const_data_ptr_t &data, uint64_t &x) { - x ^= (uint64_t)data[rest - 1] << ((rest - 1) * 8); - return CreateIntegerRecursive(data, x); -} - -template <> -inline void CreateIntegerRecursive<1>(const_data_ptr_t &data, uint64_t &x) { - x ^= (uint64_t)data[0]; -} - -inline uint64_t HashOtherSize(const_data_ptr_t &data, const idx_t &len) { - uint64_t x = 0; - switch (len & 7) { - case 7: - CreateIntegerRecursive<7>(data, x); - break; - case 6: - CreateIntegerRecursive<6>(data, x); - break; - case 5: - CreateIntegerRecursive<5>(data, x); - break; - case 4: - CreateIntegerRecursive<4>(data, x); - break; - case 3: - CreateIntegerRecursive<3>(data, x); - break; - case 2: - CreateIntegerRecursive<2>(data, x); - break; - case 1: - CreateIntegerRecursive<1>(data, x); - break; - case 0: - break; - } - return TemplatedHash(x); -} - -template <> -inline uint64_t TemplatedHash(const string_t &elem) { - auto data = const_data_ptr_cast(elem.GetData()); - const auto &len = elem.GetSize(); - uint64_t h = 0; - for 
(idx_t i = 0; i + sizeof(uint64_t) <= len; i += sizeof(uint64_t)) { - h ^= TemplatedHash(Load(data)); - data += sizeof(uint64_t); - } - switch (len & (sizeof(uint64_t) - 1)) { - case 4: - h ^= TemplatedHash(Load(data)); - break; - case 2: - h ^= TemplatedHash(Load(data)); - break; - case 1: - h ^= TemplatedHash(Load(data)); - break; - default: - h ^= HashOtherSize(data, len); - } - return h; -} - -template -void TemplatedComputeHashes(UnifiedVectorFormat &vdata, const idx_t &count, uint64_t hashes[]) { - auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - if (vdata.validity.RowIsValid(idx)) { - hashes[i] = TemplatedHash(data[idx]); - } else { - hashes[i] = 0; - } - } -} - -static void ComputeHashes(UnifiedVectorFormat &vdata, const LogicalType &type, uint64_t hashes[], idx_t count) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - case PhysicalType::UINT8: - return TemplatedComputeHashes(vdata, count, hashes); - case PhysicalType::INT16: - case PhysicalType::UINT16: - return TemplatedComputeHashes(vdata, count, hashes); - case PhysicalType::INT32: - case PhysicalType::UINT32: - case PhysicalType::FLOAT: - return TemplatedComputeHashes(vdata, count, hashes); - case PhysicalType::INT64: - case PhysicalType::UINT64: - case PhysicalType::DOUBLE: - return TemplatedComputeHashes(vdata, count, hashes); - case PhysicalType::INT128: - case PhysicalType::INTERVAL: - static_assert(sizeof(hugeint_t) == sizeof(interval_t), "ComputeHashes assumes these are the same size!"); - return TemplatedComputeHashes(vdata, count, hashes); - case PhysicalType::VARCHAR: - return TemplatedComputeHashes(vdata, count, hashes); - default: - throw InternalException("Unimplemented type for HyperLogLog::ComputeHashes"); - } -} - -//! 
Taken from https://stackoverflow.com/a/72088344 -static inline uint8_t CountTrailingZeros(uint64_t &x) { - static constexpr const uint64_t DEBRUIJN = 0x03f79d71b4cb0a89; - static constexpr const uint8_t LOOKUP[] = {0, 47, 1, 56, 48, 27, 2, 60, 57, 49, 41, 37, 28, 16, 3, 61, - 54, 58, 35, 52, 50, 42, 21, 44, 38, 32, 29, 23, 17, 11, 4, 62, - 46, 55, 26, 59, 40, 36, 15, 53, 34, 51, 20, 43, 31, 22, 10, 45, - 25, 39, 14, 33, 19, 30, 9, 24, 13, 18, 8, 12, 7, 6, 5, 63}; - return LOOKUP[(DEBRUIJN * (x ^ (x - 1))) >> 58]; -} - -static inline void ComputeIndexAndCount(uint64_t &hash, uint8_t &prefix) { - uint64_t index = hash & ((1 << 12) - 1); /* Register index. */ - hash >>= 12; /* Remove bits used to address the register. */ - hash |= ((uint64_t)1 << (64 - 12)); /* Make sure the count will be <= Q+1. */ - - prefix = CountTrailingZeros(hash) + 1; /* Add 1 since we count the "00000...1" pattern. */ - hash = index; -} - -void HyperLogLog::ProcessEntries(UnifiedVectorFormat &vdata, const LogicalType &type, uint64_t hashes[], - uint8_t counts[], idx_t count) { - ComputeHashes(vdata, type, hashes, count); - for (idx_t i = 0; i < count; i++) { - ComputeIndexAndCount(hashes[i], counts[i]); - } -} - -void HyperLogLog::AddToLogs(UnifiedVectorFormat &vdata, idx_t count, uint64_t indices[], uint8_t counts[], - HyperLogLog **logs[], const SelectionVector *log_sel) { - AddToLogsInternal(vdata, count, indices, counts, reinterpret_cast(logs), log_sel); -} - -void HyperLogLog::AddToLog(UnifiedVectorFormat &vdata, idx_t count, uint64_t indices[], uint8_t counts[]) { - lock_guard guard(lock); - AddToSingleLogInternal(vdata, count, indices, counts, hll); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - -namespace duckdb { - -bool Interval::FromString(const string &str, interval_t &result) { - string error_message; - return Interval::FromCString(str.c_str(), str.size(), result, &error_message, false); -} - -template -void IntervalTryAddition(T &target, int64_t input, int64_t 
multiplier) { - int64_t addition; - if (!TryMultiplyOperator::Operation(input, multiplier, addition)) { - throw OutOfRangeException("interval value is out of range"); - } - T addition_base = Cast::Operation(addition); - if (!TryAddOperator::Operation(target, addition_base, target)) { - throw OutOfRangeException("interval value is out of range"); - } -} - -bool Interval::FromCString(const char *str, idx_t len, interval_t &result, string *error_message, bool strict) { - idx_t pos = 0; - idx_t start_pos; - bool negative; - bool found_any = false; - int64_t number; - DatePartSpecifier specifier; - string specifier_str; - - result.days = 0; - result.micros = 0; - result.months = 0; - - if (len == 0) { - return false; - } - - switch (str[pos]) { - case '@': - pos++; - goto standard_interval; - case 'P': - case 'p': - pos++; - goto posix_interval; - default: - goto standard_interval; - } -standard_interval: - // start parsing a standard interval (e.g. 2 years 3 months...) - for (; pos < len; pos++) { - char c = str[pos]; - if (c == ' ' || c == '\t' || c == '\n') { - // skip spaces - continue; - } else if (c >= '0' && c <= '9') { - // start parsing a positive number - negative = false; - goto interval_parse_number; - } else if (c == '-') { - // negative number - negative = true; - pos++; - goto interval_parse_number; - } else if (c == 'a' || c == 'A') { - // parse the word "ago" as the final specifier - goto interval_parse_ago; - } else { - // unrecognized character, expected a number or end of string - return false; - } - } - goto end_of_string; -interval_parse_number: - start_pos = pos; - for (; pos < len; pos++) { - char c = str[pos]; - if (c >= '0' && c <= '9') { - // the number continues - continue; - } else if (c == ':') { - // colon: we are parsing a time - goto interval_parse_time; - } else { - if (pos == start_pos) { - return false; - } - // finished the number, parse it from the string - string_t nr_string(str + start_pos, pos - start_pos); - number = 
Cast::Operation(nr_string); - if (negative) { - number = -number; - } - goto interval_parse_identifier; - } - } - goto end_of_string; -interval_parse_time : { - // parse the remainder of the time as a Time type - dtime_t time; - idx_t pos; - if (!Time::TryConvertTime(str + start_pos, len - start_pos, pos, time)) { - return false; - } - result.micros += time.micros; - found_any = true; - if (negative) { - result.micros = -result.micros; - } - goto end_of_string; -} -interval_parse_identifier: - for (; pos < len; pos++) { - char c = str[pos]; - if (c == ' ' || c == '\t' || c == '\n') { - // skip spaces at the start - continue; - } else { - break; - } - } - // now parse the identifier - start_pos = pos; - for (; pos < len; pos++) { - char c = str[pos]; - if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) { - // keep parsing the string - continue; - } else { - break; - } - } - specifier_str = string(str + start_pos, pos - start_pos); - if (!TryGetDatePartSpecifier(specifier_str, specifier)) { - HandleCastError::AssignError(StringUtil::Format("extract specifier \"%s\" not recognized", specifier_str), - error_message); - return false; - } - // add the specifier to the interval - switch (specifier) { - case DatePartSpecifier::MILLENNIUM: - IntervalTryAddition(result.months, number, MONTHS_PER_MILLENIUM); - break; - case DatePartSpecifier::CENTURY: - IntervalTryAddition(result.months, number, MONTHS_PER_CENTURY); - break; - case DatePartSpecifier::DECADE: - IntervalTryAddition(result.months, number, MONTHS_PER_DECADE); - break; - case DatePartSpecifier::YEAR: - IntervalTryAddition(result.months, number, MONTHS_PER_YEAR); - break; - case DatePartSpecifier::QUARTER: - IntervalTryAddition(result.months, number, MONTHS_PER_QUARTER); - break; - case DatePartSpecifier::MONTH: - IntervalTryAddition(result.months, number, 1); - break; - case DatePartSpecifier::DAY: - IntervalTryAddition(result.days, number, 1); - break; - case DatePartSpecifier::WEEK: - 
IntervalTryAddition(result.days, number, DAYS_PER_WEEK); - break; - case DatePartSpecifier::MICROSECONDS: - IntervalTryAddition(result.micros, number, 1); - break; - case DatePartSpecifier::MILLISECONDS: - IntervalTryAddition(result.micros, number, MICROS_PER_MSEC); - break; - case DatePartSpecifier::SECOND: - IntervalTryAddition(result.micros, number, MICROS_PER_SEC); - break; - case DatePartSpecifier::MINUTE: - IntervalTryAddition(result.micros, number, MICROS_PER_MINUTE); - break; - case DatePartSpecifier::HOUR: - IntervalTryAddition(result.micros, number, MICROS_PER_HOUR); - break; - default: - HandleCastError::AssignError( - StringUtil::Format("extract specifier \"%s\" not supported for interval", specifier_str), error_message); - return false; - } - found_any = true; - goto standard_interval; -interval_parse_ago: - D_ASSERT(str[pos] == 'a' || str[pos] == 'A'); - // parse the "ago" string at the end of the interval - if (len - pos < 3) { - return false; - } - pos++; - if (!(str[pos] == 'g' || str[pos] == 'G')) { - return false; - } - pos++; - if (!(str[pos] == 'o' || str[pos] == 'O')) { - return false; - } - pos++; - // parse any trailing whitespace - for (; pos < len; pos++) { - char c = str[pos]; - if (c == ' ' || c == '\t' || c == '\n') { - continue; - } else { - return false; - } - } - // invert all the values - result.months = -result.months; - result.days = -result.days; - result.micros = -result.micros; - goto end_of_string; -end_of_string: - if (!found_any) { - // end of string and no identifiers were found: cannot convert empty interval - return false; - } - return true; -posix_interval: - return false; -} - -string Interval::ToString(const interval_t &interval) { - char buffer[70]; - idx_t length = IntervalToStringCast::Format(interval, buffer); - return string(buffer, length); -} - -int64_t Interval::GetMilli(const interval_t &val) { - int64_t milli_month, milli_day, milli; - if (!TryMultiplyOperator::Operation((int64_t)val.months, 
Interval::MICROS_PER_MONTH / 1000, milli_month)) { - throw ConversionException("Could not convert Interval to Milliseconds"); - } - if (!TryMultiplyOperator::Operation((int64_t)val.days, Interval::MICROS_PER_DAY / 1000, milli_day)) { - throw ConversionException("Could not convert Interval to Milliseconds"); - } - milli = val.micros / 1000; - if (!TryAddOperator::Operation(milli, milli_month, milli)) { - throw ConversionException("Could not convert Interval to Milliseconds"); - } - if (!TryAddOperator::Operation(milli, milli_day, milli)) { - throw ConversionException("Could not convert Interval to Milliseconds"); - } - return milli; -} - -int64_t Interval::GetMicro(const interval_t &val) { - int64_t micro_month, micro_day, micro_total; - micro_total = val.micros; - if (!TryMultiplyOperator::Operation((int64_t)val.months, MICROS_PER_MONTH, micro_month)) { - throw ConversionException("Could not convert Month to Microseconds"); - } - if (!TryMultiplyOperator::Operation((int64_t)val.days, MICROS_PER_DAY, micro_day)) { - throw ConversionException("Could not convert Day to Microseconds"); - } - if (!TryAddOperator::Operation(micro_total, micro_month, micro_total)) { - throw ConversionException("Could not convert Interval to Microseconds"); - } - if (!TryAddOperator::Operation(micro_total, micro_day, micro_total)) { - throw ConversionException("Could not convert Interval to Microseconds"); - } - - return micro_total; -} - -int64_t Interval::GetNanoseconds(const interval_t &val) { - int64_t nano; - const auto micro_total = GetMicro(val); - if (!TryMultiplyOperator::Operation(micro_total, NANOS_PER_MICRO, nano)) { - throw ConversionException("Could not convert Interval to Nanoseconds"); - } - - return nano; -} - -interval_t Interval::GetAge(timestamp_t timestamp_1, timestamp_t timestamp_2) { - D_ASSERT(Timestamp::IsFinite(timestamp_1) && Timestamp::IsFinite(timestamp_2)); - date_t date1, date2; - dtime_t time1, time2; - - Timestamp::Convert(timestamp_1, date1, time1); - 
Timestamp::Convert(timestamp_2, date2, time2); - - // and from date extract the years, months and days - int32_t year1, month1, day1; - int32_t year2, month2, day2; - Date::Convert(date1, year1, month1, day1); - Date::Convert(date2, year2, month2, day2); - // finally perform the differences - auto year_diff = year1 - year2; - auto month_diff = month1 - month2; - auto day_diff = day1 - day2; - - // and from time extract hours, minutes, seconds and milliseconds - int32_t hour1, min1, sec1, micros1; - int32_t hour2, min2, sec2, micros2; - Time::Convert(time1, hour1, min1, sec1, micros1); - Time::Convert(time2, hour2, min2, sec2, micros2); - // finally perform the differences - auto hour_diff = hour1 - hour2; - auto min_diff = min1 - min2; - auto sec_diff = sec1 - sec2; - auto micros_diff = micros1 - micros2; - - // flip sign if necessary - bool sign_flipped = false; - if (timestamp_1 < timestamp_2) { - year_diff = -year_diff; - month_diff = -month_diff; - day_diff = -day_diff; - hour_diff = -hour_diff; - min_diff = -min_diff; - sec_diff = -sec_diff; - micros_diff = -micros_diff; - sign_flipped = true; - } - // now propagate any negative field into the next higher field - while (micros_diff < 0) { - micros_diff += MICROS_PER_SEC; - sec_diff--; - } - while (sec_diff < 0) { - sec_diff += SECS_PER_MINUTE; - min_diff--; - } - while (min_diff < 0) { - min_diff += MINS_PER_HOUR; - hour_diff--; - } - while (hour_diff < 0) { - hour_diff += HOURS_PER_DAY; - day_diff--; - } - while (day_diff < 0) { - if (timestamp_1 < timestamp_2) { - day_diff += Date::IsLeapYear(year1) ? Date::LEAP_DAYS[month1] : Date::NORMAL_DAYS[month1]; - month_diff--; - } else { - day_diff += Date::IsLeapYear(year2) ? 
Date::LEAP_DAYS[month2] : Date::NORMAL_DAYS[month2]; - month_diff--; - } - } - while (month_diff < 0) { - month_diff += MONTHS_PER_YEAR; - year_diff--; - } - - // recover sign if necessary - if (sign_flipped) { - year_diff = -year_diff; - month_diff = -month_diff; - day_diff = -day_diff; - hour_diff = -hour_diff; - min_diff = -min_diff; - sec_diff = -sec_diff; - micros_diff = -micros_diff; - } - interval_t interval; - interval.months = year_diff * MONTHS_PER_YEAR + month_diff; - interval.days = day_diff; - interval.micros = Time::FromTime(hour_diff, min_diff, sec_diff, micros_diff).micros; - - return interval; -} - -interval_t Interval::GetDifference(timestamp_t timestamp_1, timestamp_t timestamp_2) { - if (!Timestamp::IsFinite(timestamp_1) || !Timestamp::IsFinite(timestamp_2)) { - throw InvalidInputException("Cannot subtract infinite timestamps"); - } - const auto us_1 = Timestamp::GetEpochMicroSeconds(timestamp_1); - const auto us_2 = Timestamp::GetEpochMicroSeconds(timestamp_2); - int64_t delta_us; - if (!TrySubtractOperator::Operation(us_1, us_2, delta_us)) { - throw ConversionException("Timestamp difference is out of bounds"); - } - return FromMicro(delta_us); -} - -interval_t Interval::FromMicro(int64_t delta_us) { - interval_t result; - result.months = 0; - result.days = delta_us / Interval::MICROS_PER_DAY; - result.micros = delta_us % Interval::MICROS_PER_DAY; - - return result; -} - -interval_t Interval::Invert(interval_t interval) { - interval.days = -interval.days; - interval.micros = -interval.micros; - interval.months = -interval.months; - return interval; -} - -date_t Interval::Add(date_t left, interval_t right) { - if (!Date::IsFinite(left)) { - return left; - } - date_t result; - if (right.months != 0) { - int32_t year, month, day; - Date::Convert(left, year, month, day); - int32_t year_diff = right.months / Interval::MONTHS_PER_YEAR; - year += year_diff; - month += right.months - year_diff * Interval::MONTHS_PER_YEAR; - if (month > 
Interval::MONTHS_PER_YEAR) { - year++; - month -= Interval::MONTHS_PER_YEAR; - } else if (month <= 0) { - year--; - month += Interval::MONTHS_PER_YEAR; - } - day = MinValue(day, Date::MonthDays(year, month)); - result = Date::FromDate(year, month, day); - } else { - result = left; - } - if (right.days != 0) { - if (!TryAddOperator::Operation(result.days, right.days, result.days)) { - throw OutOfRangeException("Date out of range"); - } - } - if (right.micros != 0) { - if (!TryAddOperator::Operation(result.days, int32_t(right.micros / Interval::MICROS_PER_DAY), result.days)) { - throw OutOfRangeException("Date out of range"); - } - } - if (!Date::IsFinite(result)) { - throw OutOfRangeException("Date out of range"); - } - return result; -} - -dtime_t Interval::Add(dtime_t left, interval_t right, date_t &date) { - int64_t diff = right.micros - ((right.micros / Interval::MICROS_PER_DAY) * Interval::MICROS_PER_DAY); - left += diff; - if (left.micros >= Interval::MICROS_PER_DAY) { - left.micros -= Interval::MICROS_PER_DAY; - date.days++; - } else if (left.micros < 0) { - left.micros += Interval::MICROS_PER_DAY; - date.days--; - } - return left; -} - -timestamp_t Interval::Add(timestamp_t left, interval_t right) { - if (!Timestamp::IsFinite(left)) { - return left; - } - date_t date; - dtime_t time; - Timestamp::Convert(left, date, time); - auto new_date = Interval::Add(date, right); - auto new_time = Interval::Add(time, right, new_date); - return Timestamp::FromDatetime(new_date, new_time); -} - -} // namespace duckdb - - -namespace duckdb { - -// forward declarations -//===--------------------------------------------------------------------===// -// Primitives -//===--------------------------------------------------------------------===// -template -static idx_t GetAllocationSize(uint16_t capacity) { - return AlignValue(sizeof(ListSegment) + capacity * (sizeof(bool) + sizeof(T))); -} - -template -static data_ptr_t AllocatePrimitiveData(ArenaAllocator &allocator, uint16_t 
capacity) { - return allocator.Allocate(GetAllocationSize(capacity)); -} - -template -static T *GetPrimitiveData(ListSegment *segment) { - return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment) + segment->capacity * sizeof(bool)); -} - -template -static const T *GetPrimitiveData(const ListSegment *segment) { - return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * sizeof(bool)); -} - -//===--------------------------------------------------------------------===// -// Lists -//===--------------------------------------------------------------------===// -static idx_t GetAllocationSizeList(uint16_t capacity) { - return AlignValue(sizeof(ListSegment) + capacity * (sizeof(bool) + sizeof(uint64_t)) + sizeof(LinkedList)); -} - -static data_ptr_t AllocateListData(ArenaAllocator &allocator, uint16_t capacity) { - return allocator.Allocate(GetAllocationSizeList(capacity)); -} - -static uint64_t *GetListLengthData(ListSegment *segment) { - return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * sizeof(bool)); -} - -static const uint64_t *GetListLengthData(const ListSegment *segment) { - return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * sizeof(bool)); -} - -static const LinkedList *GetListChildData(const ListSegment *segment) { - return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * (sizeof(bool) + sizeof(uint64_t))); -} - -static LinkedList *GetListChildData(ListSegment *segment) { - return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * (sizeof(bool) + sizeof(uint64_t))); -} - -//===--------------------------------------------------------------------===// -// Structs -//===--------------------------------------------------------------------===// -static idx_t GetAllocationSizeStruct(uint16_t capacity, idx_t child_count) { - return 
AlignValue(sizeof(ListSegment) + capacity * sizeof(bool) + child_count * sizeof(ListSegment *)); -} - -static data_ptr_t AllocateStructData(ArenaAllocator &allocator, uint16_t capacity, idx_t child_count) { - return allocator.Allocate(GetAllocationSizeStruct(capacity, child_count)); -} - -static ListSegment **GetStructData(ListSegment *segment) { - return reinterpret_cast(data_ptr_cast(segment) + +sizeof(ListSegment) + - segment->capacity * sizeof(bool)); -} - -static const ListSegment *const *GetStructData(const ListSegment *segment) { - return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * sizeof(bool)); -} - -static bool *GetNullMask(ListSegment *segment) { - return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment)); -} - -static const bool *GetNullMask(const ListSegment *segment) { - return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment)); -} - -static uint16_t GetCapacityForNewSegment(uint16_t capacity) { - auto next_power_of_two = idx_t(capacity) * 2; - if (next_power_of_two >= NumericLimits::Maximum()) { - return capacity; - } - return uint16_t(next_power_of_two); -} - -//===--------------------------------------------------------------------===// -// Create -//===--------------------------------------------------------------------===// -template -static ListSegment *CreatePrimitiveSegment(const ListSegmentFunctions &, ArenaAllocator &allocator, uint16_t capacity) { - // allocate data and set the header - auto segment = (ListSegment *)AllocatePrimitiveData(allocator, capacity); - segment->capacity = capacity; - segment->count = 0; - segment->next = nullptr; - return segment; -} - -static ListSegment *CreateListSegment(const ListSegmentFunctions &, ArenaAllocator &allocator, uint16_t capacity) { - // allocate data and set the header - auto segment = reinterpret_cast(AllocateListData(allocator, capacity)); - segment->capacity = capacity; - segment->count = 0; - segment->next = 
nullptr; - - // create an empty linked list for the child vector - auto linked_child_list = GetListChildData(segment); - LinkedList linked_list(0, nullptr, nullptr); - Store(linked_list, data_ptr_cast(linked_child_list)); - - return segment; -} - -static ListSegment *CreateStructSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - uint16_t capacity) { - // allocate data and set header - auto segment = - reinterpret_cast(AllocateStructData(allocator, capacity, functions.child_functions.size())); - segment->capacity = capacity; - segment->count = 0; - segment->next = nullptr; - - // create a child ListSegment with exactly the same capacity for each child vector - auto child_segments = GetStructData(segment); - for (idx_t i = 0; i < functions.child_functions.size(); i++) { - auto child_function = functions.child_functions[i]; - auto child_segment = child_function.create_segment(child_function, allocator, capacity); - Store(child_segment, data_ptr_cast(child_segments + i)); - } - - return segment; -} - -static ListSegment *GetSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - LinkedList &linked_list) { - ListSegment *segment; - - // determine segment - if (!linked_list.last_segment) { - // empty linked list, create the first (and last) segment - auto capacity = ListSegment::INITIAL_CAPACITY; - segment = functions.create_segment(functions, allocator, capacity); - linked_list.first_segment = segment; - linked_list.last_segment = segment; - - } else if (linked_list.last_segment->capacity == linked_list.last_segment->count) { - // the last segment of the linked list is full, create a new one and append it - auto capacity = GetCapacityForNewSegment(linked_list.last_segment->capacity); - segment = functions.create_segment(functions, allocator, capacity); - linked_list.last_segment->next = segment; - linked_list.last_segment = segment; - } else { - // the last segment of the linked list is not full, append the data to it - 
segment = linked_list.last_segment; - } - - D_ASSERT(segment); - return segment; -} - -//===--------------------------------------------------------------------===// -// Append -//===--------------------------------------------------------------------===// -template -static void WriteDataToPrimitiveSegment(const ListSegmentFunctions &, ArenaAllocator &, ListSegment *segment, - RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write null validity - auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; - - // write value - if (valid) { - auto segment_data = GetPrimitiveData(segment); - auto input_data_ptr = UnifiedVectorFormat::GetData(input_data.unified); - Store(input_data_ptr[sel_entry_idx], data_ptr_cast(segment_data + segment->count)); - } -} - -static void WriteDataToVarcharSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, - idx_t &entry_idx) { - - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write null validity - auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; - - // set the length of this string - auto str_length_data = GetListLengthData(segment); - uint64_t str_length = 0; - - // get the string - string_t str_entry; - if (valid) { - str_entry = UnifiedVectorFormat::GetData(input_data.unified)[sel_entry_idx]; - str_length = str_entry.GetSize(); - } - - // we can reconstruct the offset from the length - Store(str_length, data_ptr_cast(str_length_data + segment->count)); - if (!valid) { - return; - } - - // write the characters to the linked list of child segments - auto child_segments = Load(data_ptr_cast(GetListChildData(segment))); - for (char &c : 
str_entry.GetString()) { - auto child_segment = GetSegment(functions.child_functions.back(), allocator, child_segments); - auto data = GetPrimitiveData(child_segment); - data[child_segment->count] = c; - child_segment->count++; - child_segments.total_capacity++; - } - - // store the updated linked list - Store(child_segments, data_ptr_cast(GetListChildData(segment))); -} - -static void WriteDataToListSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write null validity - auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; - - // set the length of this list - auto list_length_data = GetListLengthData(segment); - uint64_t list_length = 0; - - if (valid) { - // get list entry information - const auto &list_entry = UnifiedVectorFormat::GetData(input_data.unified)[sel_entry_idx]; - list_length = list_entry.length; - - // loop over the child vector entries and recurse on them - auto child_segments = Load(data_ptr_cast(GetListChildData(segment))); - D_ASSERT(functions.child_functions.size() == 1); - for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { - auto source_idx_child = list_entry.offset + child_idx; - functions.child_functions[0].AppendRow(allocator, child_segments, input_data.children.back(), - source_idx_child); - } - // store the updated linked list - Store(child_segments, data_ptr_cast(GetListChildData(segment))); - } - - Store(list_length, data_ptr_cast(list_length_data + segment->count)); -} - -static void WriteDataToStructSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write 
null validity - auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; - - // write value - D_ASSERT(input_data.children.size() == functions.child_functions.size()); - auto child_list = GetStructData(segment); - - // write the data of each of the children of the struct - for (idx_t i = 0; i < input_data.children.size(); i++) { - auto child_list_segment = Load(data_ptr_cast(child_list + i)); - auto &child_function = functions.child_functions[i]; - child_function.write_data(child_function, allocator, child_list_segment, input_data.children[i], entry_idx); - child_list_segment->count++; - } -} - -void ListSegmentFunctions::AppendRow(ArenaAllocator &allocator, LinkedList &linked_list, - RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) const { - - auto &write_data_to_segment = *this; - auto segment = GetSegment(write_data_to_segment, allocator, linked_list); - write_data_to_segment.write_data(write_data_to_segment, allocator, segment, input_data, entry_idx); - - linked_list.total_capacity++; - segment->count++; -} - -//===--------------------------------------------------------------------===// -// Read -//===--------------------------------------------------------------------===// -template -static void ReadDataFromPrimitiveSegment(const ListSegmentFunctions &, const ListSegment *segment, Vector &result, - idx_t &total_count) { - - auto &aggr_vector_validity = FlatVector::Validity(result); - - // set NULLs - auto null_mask = GetNullMask(segment); - for (idx_t i = 0; i < segment->count; i++) { - if (null_mask[i]) { - aggr_vector_validity.SetInvalid(total_count + i); - } - } - - auto aggr_vector_data = FlatVector::GetData(result); - - // load values - for (idx_t i = 0; i < segment->count; i++) { - if (aggr_vector_validity.RowIsValid(total_count + i)) { - auto data = GetPrimitiveData(segment); - aggr_vector_data[total_count + i] = Load(const_data_ptr_cast(data + i)); - } 
- } -} - -static void ReadDataFromVarcharSegment(const ListSegmentFunctions &, const ListSegment *segment, Vector &result, - idx_t &total_count) { - - auto &aggr_vector_validity = FlatVector::Validity(result); - - // set NULLs - auto null_mask = GetNullMask(segment); - for (idx_t i = 0; i < segment->count; i++) { - if (null_mask[i]) { - aggr_vector_validity.SetInvalid(total_count + i); - } - } - - // append all the child chars to one string - string str = ""; - auto linked_child_list = Load(const_data_ptr_cast(GetListChildData(segment))); - while (linked_child_list.first_segment) { - auto child_segment = linked_child_list.first_segment; - auto data = GetPrimitiveData(child_segment); - str.append(data, child_segment->count); - linked_child_list.first_segment = child_segment->next; - } - linked_child_list.last_segment = nullptr; - - // use length and (reconstructed) offset to get the correct substrings - auto aggr_vector_data = FlatVector::GetData(result); - auto str_length_data = GetListLengthData(segment); - - // get the substrings and write them to the result vector - idx_t offset = 0; - for (idx_t i = 0; i < segment->count; i++) { - if (!null_mask[i]) { - auto str_length = Load(const_data_ptr_cast(str_length_data + i)); - auto substr = str.substr(offset, str_length); - auto str_t = StringVector::AddStringOrBlob(result, substr); - aggr_vector_data[total_count + i] = str_t; - offset += str_length; - } - } -} - -static void ReadDataFromListSegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, - idx_t &total_count) { - - auto &aggr_vector_validity = FlatVector::Validity(result); - - // set NULLs - auto null_mask = GetNullMask(segment); - for (idx_t i = 0; i < segment->count; i++) { - if (null_mask[i]) { - aggr_vector_validity.SetInvalid(total_count + i); - } - } - - auto list_vector_data = FlatVector::GetData(result); - - // get the starting offset - idx_t offset = 0; - if (total_count != 0) { - offset = 
list_vector_data[total_count - 1].offset + list_vector_data[total_count - 1].length; - } - idx_t starting_offset = offset; - - // set length and offsets - auto list_length_data = GetListLengthData(segment); - for (idx_t i = 0; i < segment->count; i++) { - auto list_length = Load(const_data_ptr_cast(list_length_data + i)); - list_vector_data[total_count + i].length = list_length; - list_vector_data[total_count + i].offset = offset; - offset += list_length; - } - - auto &child_vector = ListVector::GetEntry(result); - auto linked_child_list = Load(const_data_ptr_cast(GetListChildData(segment))); - ListVector::Reserve(result, offset); - - // recurse into the linked list of child values - D_ASSERT(functions.child_functions.size() == 1); - functions.child_functions[0].BuildListVector(linked_child_list, child_vector, starting_offset); - ListVector::SetListSize(result, offset); -} - -static void ReadDataFromStructSegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, - idx_t &total_count) { - - auto &aggr_vector_validity = FlatVector::Validity(result); - - // set NULLs - auto null_mask = GetNullMask(segment); - for (idx_t i = 0; i < segment->count; i++) { - if (null_mask[i]) { - aggr_vector_validity.SetInvalid(total_count + i); - } - } - - auto &children = StructVector::GetEntries(result); - - // recurse into the child segments of each child of the struct - D_ASSERT(children.size() == functions.child_functions.size()); - auto struct_children = GetStructData(segment); - for (idx_t child_count = 0; child_count < children.size(); child_count++) { - auto struct_children_segment = Load(const_data_ptr_cast(struct_children + child_count)); - auto &child_function = functions.child_functions[child_count]; - child_function.read_data(child_function, struct_children_segment, *children[child_count], total_count); - } -} - -void ListSegmentFunctions::BuildListVector(const LinkedList &linked_list, Vector &result, - idx_t &initial_total_count) const { - 
auto &read_data_from_segment = *this; - idx_t total_count = initial_total_count; - auto segment = linked_list.first_segment; - while (segment) { - read_data_from_segment.read_data(read_data_from_segment, segment, result, total_count); - - total_count += segment->count; - segment = segment->next; - } -} - -//===--------------------------------------------------------------------===// -// Functions -//===--------------------------------------------------------------------===// -template -void SegmentPrimitiveFunction(ListSegmentFunctions &functions) { - functions.create_segment = CreatePrimitiveSegment; - functions.write_data = WriteDataToPrimitiveSegment; - functions.read_data = ReadDataFromPrimitiveSegment; -} - -void GetSegmentDataFunctions(ListSegmentFunctions &functions, const LogicalType &type) { - - auto physical_type = type.InternalType(); - switch (physical_type) { - case PhysicalType::BIT: - case PhysicalType::BOOL: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::INT8: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::INT16: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::INT32: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::INT64: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::UINT8: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::UINT16: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::UINT32: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::UINT64: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::FLOAT: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::DOUBLE: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::INT128: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::INTERVAL: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::VARCHAR: { - functions.create_segment = CreateListSegment; - 
functions.write_data = WriteDataToVarcharSegment; - functions.read_data = ReadDataFromVarcharSegment; - - functions.child_functions.emplace_back(); - SegmentPrimitiveFunction(functions.child_functions.back()); - break; - } - case PhysicalType::LIST: { - functions.create_segment = CreateListSegment; - functions.write_data = WriteDataToListSegment; - functions.read_data = ReadDataFromListSegment; - - // recurse - functions.child_functions.emplace_back(); - GetSegmentDataFunctions(functions.child_functions.back(), ListType::GetChildType(type)); - break; - } - case PhysicalType::STRUCT: { - functions.create_segment = CreateStructSegment; - functions.write_data = WriteDataToStructSegment; - functions.read_data = ReadDataFromStructSegment; - - // recurse - auto child_types = StructType::GetChildTypes(type); - for (idx_t i = 0; i < child_types.size(); i++) { - functions.child_functions.emplace_back(); - GetSegmentDataFunctions(functions.child_functions.back(), child_types[i].second); - } - break; - } - default: - throw InternalException("LIST aggregate not yet implemented for " + type.ToString()); - } -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -PartitionedTupleData::PartitionedTupleData(PartitionedTupleDataType type_p, BufferManager &buffer_manager_p, - const TupleDataLayout &layout_p) - : type(type_p), buffer_manager(buffer_manager_p), layout(layout_p.Copy()), count(0), data_size(0), - allocators(make_shared()) { -} - -PartitionedTupleData::PartitionedTupleData(const PartitionedTupleData &other) - : type(other.type), buffer_manager(other.buffer_manager), layout(other.layout.Copy()) { -} - -PartitionedTupleData::~PartitionedTupleData() { -} - -const TupleDataLayout &PartitionedTupleData::GetLayout() const { - return layout; -} - -PartitionedTupleDataType PartitionedTupleData::GetType() const { - return type; -} - -void PartitionedTupleData::InitializeAppendState(PartitionedTupleDataAppendState &state, - TupleDataPinProperties properties) const { - 
state.partition_sel.Initialize(); - state.reverse_partition_sel.Initialize(); - - vector column_ids; - column_ids.reserve(layout.ColumnCount()); - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - column_ids.emplace_back(col_idx); - } - - InitializeAppendStateInternal(state, properties); -} - -void PartitionedTupleData::Append(PartitionedTupleDataAppendState &state, DataChunk &input, - const SelectionVector &append_sel, const idx_t append_count) { - TupleDataCollection::ToUnifiedFormat(state.chunk_state, input); - AppendUnified(state, input, append_sel, append_count); -} - -bool PartitionedTupleData::UseFixedSizeMap() const { - return MaxPartitionIndex() < PartitionedTupleDataAppendState::MAP_THRESHOLD; -} - -void PartitionedTupleData::AppendUnified(PartitionedTupleDataAppendState &state, DataChunk &input, - const SelectionVector &append_sel, const idx_t append_count) { - const idx_t actual_append_count = append_count == DConstants::INVALID_INDEX ? input.size() : append_count; - - // Compute partition indices and store them in state.partition_indices - ComputePartitionIndices(state, input); - - // Build the selection vector for the partitions - BuildPartitionSel(state, append_sel, actual_append_count); - - // Early out: check if everything belongs to a single partition - optional_idx partition_index; - if (UseFixedSizeMap()) { - if (state.fixed_partition_entries.size() == 1) { - partition_index = state.fixed_partition_entries.begin().GetKey(); - } - } else { - if (state.partition_entries.size() == 1) { - partition_index = state.partition_entries.begin()->first; - } - } - if (partition_index.IsValid()) { - auto &partition = *partitions[partition_index.GetIndex()]; - auto &partition_pin_state = *state.partition_pin_states[partition_index.GetIndex()]; - - const auto size_before = partition.SizeInBytes(); - partition.AppendUnified(partition_pin_state, state.chunk_state, input, append_sel, actual_append_count); - data_size += partition.SizeInBytes() 
- size_before; - } else { - // Compute the heap sizes for the whole chunk - if (!layout.AllConstant()) { - TupleDataCollection::ComputeHeapSizes(state.chunk_state, input, state.partition_sel, actual_append_count); - } - - // Build the buffer space - BuildBufferSpace(state); - - // Now scatter everything in one go - partitions[0]->Scatter(state.chunk_state, input, state.partition_sel, actual_append_count); - } - - count += actual_append_count; - Verify(); -} - -void PartitionedTupleData::Append(PartitionedTupleDataAppendState &state, TupleDataChunkState &input, - const idx_t append_count) { - // Compute partition indices and store them in state.partition_indices - ComputePartitionIndices(input.row_locations, append_count, state.partition_indices); - - // Build the selection vector for the partitions - BuildPartitionSel(state, *FlatVector::IncrementalSelectionVector(), append_count); - - // Early out: check if everything belongs to a single partition - optional_idx partition_index; - if (UseFixedSizeMap()) { - if (state.fixed_partition_entries.size() == 1) { - partition_index = state.fixed_partition_entries.begin().GetKey(); - } - } else { - if (state.partition_entries.size() == 1) { - partition_index = state.partition_entries.begin()->first; - } - } - - if (partition_index.IsValid()) { - auto &partition = *partitions[partition_index.GetIndex()]; - auto &partition_pin_state = *state.partition_pin_states[partition_index.GetIndex()]; - - state.chunk_state.heap_sizes.Reference(input.heap_sizes); - - const auto size_before = partition.SizeInBytes(); - partition.Build(partition_pin_state, state.chunk_state, 0, append_count); - data_size += partition.SizeInBytes() - size_before; - - partition.CopyRows(state.chunk_state, input, *FlatVector::IncrementalSelectionVector(), append_count); - } else { - // Build the buffer space - state.chunk_state.heap_sizes.Slice(input.heap_sizes, state.partition_sel, append_count); - state.chunk_state.heap_sizes.Flatten(append_count); - 
BuildBufferSpace(state); - - // Copy the rows - partitions[0]->CopyRows(state.chunk_state, input, state.partition_sel, append_count); - } - - count += append_count; - Verify(); -} - -// LCOV_EXCL_START -template -struct UnorderedMapGetter { - static inline const typename MAP_TYPE::key_type &GetKey(typename MAP_TYPE::iterator &iterator) { - return iterator->first; - } - - static inline const typename MAP_TYPE::key_type &GetKey(const typename MAP_TYPE::const_iterator &iterator) { - return iterator->first; - } - - static inline typename MAP_TYPE::mapped_type &GetValue(typename MAP_TYPE::iterator &iterator) { - return iterator->second; - } - - static inline const typename MAP_TYPE::mapped_type &GetValue(const typename MAP_TYPE::const_iterator &iterator) { - return iterator->second; - } -}; - -template -struct FixedSizeMapGetter { - static inline const idx_t &GetKey(fixed_size_map_iterator_t &iterator) { - return iterator.GetKey(); - } - - static inline const idx_t &GetKey(const const_fixed_size_map_iterator_t &iterator) { - return iterator.GetKey(); - } - - static inline T &GetValue(fixed_size_map_iterator_t &iterator) { - return iterator.GetValue(); - } - - static inline const T &GetValue(const const_fixed_size_map_iterator_t &iterator) { - return iterator.GetValue(); - } -}; -// LCOV_EXCL_STOP - -void PartitionedTupleData::BuildPartitionSel(PartitionedTupleDataAppendState &state, const SelectionVector &append_sel, - const idx_t append_count) { - if (UseFixedSizeMap()) { - BuildPartitionSel, FixedSizeMapGetter>( - state, state.fixed_partition_entries, append_sel, append_count); - } else { - BuildPartitionSel, UnorderedMapGetter>>( - state, state.partition_entries, append_sel, append_count); - } -} - -template -void PartitionedTupleData::BuildPartitionSel(PartitionedTupleDataAppendState &state, MAP_TYPE &partition_entries, - const SelectionVector &append_sel, const idx_t append_count) { - const auto partition_indices = FlatVector::GetData(state.partition_indices); - 
partition_entries.clear(); - - switch (state.partition_indices.GetVectorType()) { - case VectorType::FLAT_VECTOR: - for (idx_t i = 0; i < append_count; i++) { - const auto index = append_sel.get_index(i); - const auto &partition_index = partition_indices[index]; - auto partition_entry = partition_entries.find(partition_index); - if (partition_entry == partition_entries.end()) { - partition_entries[partition_index] = list_entry_t(0, 1); - } else { - GETTER::GetValue(partition_entry).length++; - } - } - break; - case VectorType::CONSTANT_VECTOR: - partition_entries[partition_indices[0]] = list_entry_t(0, append_count); - break; - default: - throw InternalException("Unexpected VectorType in PartitionedTupleData::Append"); - } - - // Early out: check if everything belongs to a single partition - if (partition_entries.size() == 1) { - // This needs to be initialized, even if we go the short path here - for (idx_t i = 0; i < append_count; i++) { - const auto index = append_sel.get_index(i); - state.reverse_partition_sel[index] = i; - } - return; - } - - // Compute offsets from the counts - idx_t offset = 0; - for (auto it = partition_entries.begin(); it != partition_entries.end(); ++it) { - auto &partition_entry = GETTER::GetValue(it); - partition_entry.offset = offset; - offset += partition_entry.length; - } - - // Now initialize a single selection vector that acts as a selection vector for every partition - auto &partition_sel = state.partition_sel; - auto &reverse_partition_sel = state.reverse_partition_sel; - for (idx_t i = 0; i < append_count; i++) { - const auto index = append_sel.get_index(i); - const auto &partition_index = partition_indices[index]; - auto &partition_offset = partition_entries[partition_index].offset; - reverse_partition_sel[index] = partition_offset; - partition_sel[partition_offset++] = index; - } -} - -void PartitionedTupleData::BuildBufferSpace(PartitionedTupleDataAppendState &state) { - if (UseFixedSizeMap()) { - BuildBufferSpace, 
FixedSizeMapGetter>( - state, state.fixed_partition_entries); - } else { - BuildBufferSpace, UnorderedMapGetter>>( - state, state.partition_entries); - } -} - -template -void PartitionedTupleData::BuildBufferSpace(PartitionedTupleDataAppendState &state, const MAP_TYPE &partition_entries) { - for (auto it = partition_entries.begin(); it != partition_entries.end(); ++it) { - const auto &partition_index = GETTER::GetKey(it); - - // Partition, pin state for this partition index - auto &partition = *partitions[partition_index]; - auto &partition_pin_state = *state.partition_pin_states[partition_index]; - - // Length and offset for this partition - const auto &partition_entry = GETTER::GetValue(it); - const auto &partition_length = partition_entry.length; - const auto partition_offset = partition_entry.offset - partition_length; - - // Build out the buffer space for this partition - const auto size_before = partition.SizeInBytes(); - partition.Build(partition_pin_state, state.chunk_state, partition_offset, partition_length); - data_size += partition.SizeInBytes() - size_before; - } -} - -void PartitionedTupleData::FlushAppendState(PartitionedTupleDataAppendState &state) { - for (idx_t partition_index = 0; partition_index < partitions.size(); partition_index++) { - auto &partition = *partitions[partition_index]; - auto &partition_pin_state = *state.partition_pin_states[partition_index]; - partition.FinalizePinState(partition_pin_state); - } -} - -void PartitionedTupleData::Combine(PartitionedTupleData &other) { - if (other.Count() == 0) { - return; - } - - // Now combine the state's partitions into this - lock_guard guard(lock); - if (partitions.empty()) { - // This is the first merge, we just copy them over - partitions = std::move(other.partitions); - } else { - D_ASSERT(partitions.size() == other.partitions.size()); - // Combine the append state's partitions into this PartitionedTupleData - for (idx_t i = 0; i < other.partitions.size(); i++) { - 
partitions[i]->Combine(*other.partitions[i]); - } - } - this->count += other.count; - this->data_size += other.data_size; - Verify(); -} - -void PartitionedTupleData::Reset() { - for (auto &partition : partitions) { - partition->Reset(); - } - this->count = 0; - this->data_size = 0; - Verify(); -} - -void PartitionedTupleData::Repartition(PartitionedTupleData &new_partitioned_data) { - D_ASSERT(layout.GetTypes() == new_partitioned_data.layout.GetTypes()); - - if (partitions.size() == new_partitioned_data.partitions.size()) { - new_partitioned_data.Combine(*this); - return; - } - - PartitionedTupleDataAppendState append_state; - new_partitioned_data.InitializeAppendState(append_state); - - const auto reverse = RepartitionReverseOrder(); - const idx_t start_idx = reverse ? partitions.size() : 0; - const idx_t end_idx = reverse ? 0 : partitions.size(); - const int64_t update = reverse ? -1 : 1; - const int64_t adjustment = reverse ? -1 : 0; - - for (idx_t partition_idx = start_idx; partition_idx != end_idx; partition_idx += update) { - auto actual_partition_idx = partition_idx + adjustment; - auto &partition = *partitions[actual_partition_idx]; - - if (partition.Count() > 0) { - TupleDataChunkIterator iterator(partition, TupleDataPinProperties::DESTROY_AFTER_DONE, true); - auto &chunk_state = iterator.GetChunkState(); - do { - new_partitioned_data.Append(append_state, chunk_state, iterator.GetCurrentChunkCount()); - } while (iterator.Next()); - - RepartitionFinalizeStates(*this, new_partitioned_data, append_state, actual_partition_idx); - } - partitions[actual_partition_idx]->Reset(); - } - new_partitioned_data.FlushAppendState(append_state); - - count = 0; - data_size = 0; - - Verify(); -} - -void PartitionedTupleData::Unpin() { - for (auto &partition : partitions) { - partition->Unpin(); - } -} - -vector> &PartitionedTupleData::GetPartitions() { - return partitions; -} - -unique_ptr PartitionedTupleData::GetUnpartitioned() { - auto data_collection = 
std::move(partitions[0]); - partitions[0] = make_uniq(buffer_manager, layout); - - for (idx_t i = 1; i < partitions.size(); i++) { - data_collection->Combine(*partitions[i]); - } - count = 0; - data_size = 0; - - data_collection->Verify(); - Verify(); - - return data_collection; -} - -idx_t PartitionedTupleData::Count() const { - return count; -} - -idx_t PartitionedTupleData::SizeInBytes() const { - idx_t total_size = 0; - for (auto &partition : partitions) { - total_size += partition->SizeInBytes(); - } - return total_size; -} - -idx_t PartitionedTupleData::PartitionCount() const { - return partitions.size(); -} - -void PartitionedTupleData::Verify() const { -#ifdef DEBUG - idx_t total_count = 0; - idx_t total_size = 0; - for (auto &partition : partitions) { - partition->Verify(); - total_count += partition->Count(); - total_size += partition->SizeInBytes(); - } - D_ASSERT(total_count == this->count); - D_ASSERT(total_size == this->data_size); -#endif -} - -// LCOV_EXCL_START -string PartitionedTupleData::ToString() { - string result = - StringUtil::Format("PartitionedTupleData - [%llu Partitions, %llu Rows]\n", partitions.size(), Count()); - for (idx_t partition_idx = 0; partition_idx < partitions.size(); partition_idx++) { - result += StringUtil::Format("Partition %llu: ", partition_idx) + partitions[partition_idx]->ToString(); - } - return result; -} - -void PartitionedTupleData::Print() { - Printer::Print(ToString()); -} -// LCOV_EXCL_STOP - -void PartitionedTupleData::CreateAllocator() { - allocators->allocators.emplace_back(make_shared(buffer_manager, layout)); -} - -} // namespace duckdb - - -namespace duckdb { - -RowDataCollection::RowDataCollection(BufferManager &buffer_manager, idx_t block_capacity, idx_t entry_size, - bool keep_pinned) - : buffer_manager(buffer_manager), count(0), block_capacity(block_capacity), entry_size(entry_size), - keep_pinned(keep_pinned) { - D_ASSERT(block_capacity * entry_size + entry_size > Storage::BLOCK_SIZE); -} - -idx_t 
RowDataCollection::AppendToBlock(RowDataBlock &block, BufferHandle &handle, - vector &append_entries, idx_t remaining, idx_t entry_sizes[]) { - idx_t append_count = 0; - data_ptr_t dataptr; - if (entry_sizes) { - D_ASSERT(entry_size == 1); - // compute how many entries fit if entry size is variable - dataptr = handle.Ptr() + block.byte_offset; - for (idx_t i = 0; i < remaining; i++) { - if (block.byte_offset + entry_sizes[i] > block.capacity) { - if (block.count == 0 && append_count == 0 && entry_sizes[i] > block.capacity) { - // special case: single entry is bigger than block capacity - // resize current block to fit the entry, append it, and move to the next block - block.capacity = entry_sizes[i]; - buffer_manager.ReAllocate(block.block, block.capacity); - dataptr = handle.Ptr(); - append_count++; - block.byte_offset += entry_sizes[i]; - } - break; - } - append_count++; - block.byte_offset += entry_sizes[i]; - } - } else { - append_count = MinValue(remaining, block.capacity - block.count); - dataptr = handle.Ptr() + block.count * entry_size; - } - append_entries.emplace_back(dataptr, append_count); - block.count += append_count; - return append_count; -} - -RowDataBlock &RowDataCollection::CreateBlock() { - blocks.push_back(make_uniq(buffer_manager, block_capacity, entry_size)); - return *blocks.back(); -} - -vector RowDataCollection::Build(idx_t added_count, data_ptr_t key_locations[], idx_t entry_sizes[], - const SelectionVector *sel) { - vector handles; - vector append_entries; - - // first allocate space of where to serialize the keys and payload columns - idx_t remaining = added_count; - { - // first append to the last block (if any) - lock_guard append_lock(rdc_lock); - count += added_count; - - if (!blocks.empty()) { - auto &last_block = *blocks.back(); - if (last_block.count < last_block.capacity) { - // last block has space: pin the buffer of this block - auto handle = buffer_manager.Pin(last_block.block); - // now append to the block - idx_t 
append_count = AppendToBlock(last_block, handle, append_entries, remaining, entry_sizes); - remaining -= append_count; - handles.push_back(std::move(handle)); - } - } - while (remaining > 0) { - // now for the remaining data, allocate new buffers to store the data and append there - auto &new_block = CreateBlock(); - auto handle = buffer_manager.Pin(new_block.block); - - // offset the entry sizes array if we have added entries already - idx_t *offset_entry_sizes = entry_sizes ? entry_sizes + added_count - remaining : nullptr; - - idx_t append_count = AppendToBlock(new_block, handle, append_entries, remaining, offset_entry_sizes); - D_ASSERT(new_block.count > 0); - remaining -= append_count; - - if (keep_pinned) { - pinned_blocks.push_back(std::move(handle)); - } else { - handles.push_back(std::move(handle)); - } - } - } - // now set up the key_locations based on the append entries - idx_t append_idx = 0; - for (auto &append_entry : append_entries) { - idx_t next = append_idx + append_entry.count; - if (entry_sizes) { - for (; append_idx < next; append_idx++) { - key_locations[append_idx] = append_entry.baseptr; - append_entry.baseptr += entry_sizes[append_idx]; - } - } else { - for (; append_idx < next; append_idx++) { - auto idx = sel->get_index(append_idx); - key_locations[idx] = append_entry.baseptr; - append_entry.baseptr += entry_size; - } - } - } - // return the unique pointers to the handles because they must stay pinned - return handles; -} - -void RowDataCollection::Merge(RowDataCollection &other) { - if (other.count == 0) { - return; - } - RowDataCollection temp(buffer_manager, Storage::BLOCK_SIZE, 1); - { - // One lock at a time to avoid deadlocks - lock_guard read_lock(other.rdc_lock); - temp.count = other.count; - temp.block_capacity = other.block_capacity; - temp.entry_size = other.entry_size; - temp.blocks = std::move(other.blocks); - temp.pinned_blocks = std::move(other.pinned_blocks); - } - other.Clear(); - - lock_guard write_lock(rdc_lock); - 
count += temp.count; - block_capacity = MaxValue(block_capacity, temp.block_capacity); - entry_size = MaxValue(entry_size, temp.entry_size); - for (auto &block : temp.blocks) { - blocks.emplace_back(std::move(block)); - } - for (auto &handle : temp.pinned_blocks) { - pinned_blocks.emplace_back(std::move(handle)); - } -} - -} // namespace duckdb - - - - - - -#include - -namespace duckdb { - -void RowDataCollectionScanner::AlignHeapBlocks(RowDataCollection &swizzled_block_collection, - RowDataCollection &swizzled_string_heap, - RowDataCollection &block_collection, RowDataCollection &string_heap, - const RowLayout &layout) { - if (block_collection.count == 0) { - return; - } - - if (layout.AllConstant()) { - // No heap blocks! Just merge fixed-size data - swizzled_block_collection.Merge(block_collection); - return; - } - - // We create one heap block per data block and swizzle the pointers - D_ASSERT(string_heap.keep_pinned == swizzled_string_heap.keep_pinned); - auto &buffer_manager = block_collection.buffer_manager; - auto &heap_blocks = string_heap.blocks; - idx_t heap_block_idx = 0; - idx_t heap_block_remaining = heap_blocks[heap_block_idx]->count; - for (auto &data_block : block_collection.blocks) { - if (heap_block_remaining == 0) { - heap_block_remaining = heap_blocks[++heap_block_idx]->count; - } - - // Pin the data block and swizzle the pointers within the rows - auto data_handle = buffer_manager.Pin(data_block->block); - auto data_ptr = data_handle.Ptr(); - if (!string_heap.keep_pinned) { - D_ASSERT(!data_block->block->IsSwizzled()); - RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); - data_block->block->SetSwizzling(nullptr); - } - // At this point the data block is pinned and the heap pointer is valid - // so we can copy heap data as needed - - // We want to copy as little of the heap data as possible, check how the data and heap blocks line up - if (heap_block_remaining >= data_block->count) { - // Easy: current heap block contains all 
strings for this data block, just copy (reference) the block - swizzled_string_heap.blocks.emplace_back(heap_blocks[heap_block_idx]->Copy()); - swizzled_string_heap.blocks.back()->count = data_block->count; - - // Swizzle the heap pointer if we are not pinning the heap - auto &heap_block = swizzled_string_heap.blocks.back()->block; - auto heap_handle = buffer_manager.Pin(heap_block); - if (!swizzled_string_heap.keep_pinned) { - auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block->count, heap_offset); - } else { - swizzled_string_heap.pinned_blocks.emplace_back(std::move(heap_handle)); - } - - // Update counter - heap_block_remaining -= data_block->count; - } else { - // Strings for this data block are spread over the current heap block and the next (and possibly more) - if (string_heap.keep_pinned) { - // The heap is changing underneath the data block, - // so swizzle the string pointers to make them portable. 
- RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); - } - idx_t data_block_remaining = data_block->count; - vector> ptrs_and_sizes; - idx_t total_size = 0; - const auto base_row_ptr = data_ptr; - while (data_block_remaining > 0) { - if (heap_block_remaining == 0) { - heap_block_remaining = heap_blocks[++heap_block_idx]->count; - } - auto next = MinValue(data_block_remaining, heap_block_remaining); - - // Figure out where to start copying strings, and how many bytes we need to copy - auto heap_start_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_end_ptr = - Load(data_ptr + layout.GetHeapOffset() + (next - 1) * layout.GetRowWidth()); - idx_t size = heap_end_ptr - heap_start_ptr + Load(heap_end_ptr); - ptrs_and_sizes.emplace_back(heap_start_ptr, size); - D_ASSERT(size <= heap_blocks[heap_block_idx]->byte_offset); - - // Swizzle the heap pointer - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_start_ptr, next, total_size); - total_size += size; - - // Update where we are in the data and heap blocks - data_ptr += next * layout.GetRowWidth(); - data_block_remaining -= next; - heap_block_remaining -= next; - } - - // Finally, we allocate a new heap block and copy data to it - swizzled_string_heap.blocks.emplace_back( - make_uniq(buffer_manager, MaxValue(total_size, (idx_t)Storage::BLOCK_SIZE), 1)); - auto new_heap_handle = buffer_manager.Pin(swizzled_string_heap.blocks.back()->block); - auto new_heap_ptr = new_heap_handle.Ptr(); - for (auto &ptr_and_size : ptrs_and_sizes) { - memcpy(new_heap_ptr, ptr_and_size.first, ptr_and_size.second); - new_heap_ptr += ptr_and_size.second; - } - new_heap_ptr = new_heap_handle.Ptr(); - if (swizzled_string_heap.keep_pinned) { - // Since the heap blocks are pinned, we can unswizzle the data again. 
- swizzled_string_heap.pinned_blocks.emplace_back(std::move(new_heap_handle)); - RowOperations::UnswizzlePointers(layout, base_row_ptr, new_heap_ptr, data_block->count); - RowOperations::UnswizzleHeapPointer(layout, base_row_ptr, new_heap_ptr, data_block->count); - } - } - } - - // We're done with variable-sized data, now just merge the fixed-size data - swizzled_block_collection.Merge(block_collection); - D_ASSERT(swizzled_block_collection.blocks.size() == swizzled_string_heap.blocks.size()); - - // Update counts and cleanup - swizzled_string_heap.count = string_heap.count; - string_heap.Clear(); -} - -void RowDataCollectionScanner::ScanState::PinData() { - auto &rows = scanner.rows; - D_ASSERT(block_idx < rows.blocks.size()); - auto &data_block = rows.blocks[block_idx]; - if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { - data_handle = rows.buffer_manager.Pin(data_block->block); - } - if (scanner.layout.AllConstant() || !scanner.external) { - return; - } - - auto &heap = scanner.heap; - D_ASSERT(block_idx < heap.blocks.size()); - auto &heap_block = heap.blocks[block_idx]; - if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { - heap_handle = heap.buffer_manager.Pin(heap_block->block); - } -} - -RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, - const RowLayout &layout_p, bool external_p, bool flush_p) - : rows(rows_p), heap(heap_p), layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), - external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { - - if (unswizzling) { - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - } - - ValidateUnscannedBlock(); -} - -RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, - const RowLayout &layout_p, bool external_p, idx_t block_idx, - bool flush_p) - : rows(rows_p), heap(heap_p), 
layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), - external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { - - if (unswizzling) { - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - } - - D_ASSERT(block_idx < rows.blocks.size()); - read_state.block_idx = block_idx; - read_state.entry_idx = 0; - - // Pretend that we have scanned up to the start block - // and will stop at the end - auto begin = rows.blocks.begin(); - auto end = begin + block_idx; - total_scanned = - std::accumulate(begin, end, idx_t(0), [&](idx_t c, const unique_ptr &b) { return c + b->count; }); - total_count = total_scanned + (*end)->count; - - ValidateUnscannedBlock(); -} - -void RowDataCollectionScanner::SwizzleBlock(RowDataBlock &data_block, RowDataBlock &heap_block) { - // Pin the data block and swizzle the pointers within the rows - D_ASSERT(!data_block.block->IsSwizzled()); - auto data_handle = rows.buffer_manager.Pin(data_block.block); - auto data_ptr = data_handle.Ptr(); - RowOperations::SwizzleColumns(layout, data_ptr, data_block.count); - data_block.block->SetSwizzling(nullptr); - - // Swizzle the heap pointers - auto heap_handle = heap.buffer_manager.Pin(heap_block.block); - auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block.count, heap_offset); -} - -void RowDataCollectionScanner::ReSwizzle() { - if (rows.count == 0) { - return; - } - - if (!unswizzling) { - // No swizzled blocks! 
- return; - } - - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - for (idx_t i = 0; i < rows.blocks.size(); ++i) { - auto &data_block = rows.blocks[i]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlock(*data_block, *heap.blocks[i]); - } - } -} - -void RowDataCollectionScanner::ValidateUnscannedBlock() const { - if (unswizzling && read_state.block_idx < rows.blocks.size() && Remaining()) { - D_ASSERT(rows.blocks[read_state.block_idx]->block->IsSwizzled()); - } -} - -void RowDataCollectionScanner::Scan(DataChunk &chunk) { - auto count = MinValue((idx_t)STANDARD_VECTOR_SIZE, total_count - total_scanned); - if (count == 0) { - chunk.SetCardinality(count); - return; - } - - // Only flush blocks we processed. - const auto flush_block_idx = read_state.block_idx; - - const idx_t &row_width = layout.GetRowWidth(); - // Set up a batch of pointers to scan data from - idx_t scanned = 0; - auto data_pointers = FlatVector::GetData(addresses); - - // We must pin ALL blocks we are going to gather from - vector pinned_blocks; - while (scanned < count) { - read_state.PinData(); - auto &data_block = rows.blocks[read_state.block_idx]; - idx_t next = MinValue(data_block->count - read_state.entry_idx, count - scanned); - const data_ptr_t data_ptr = read_state.data_handle.Ptr() + read_state.entry_idx * row_width; - // Set up the next pointers - data_ptr_t row_ptr = data_ptr; - for (idx_t i = 0; i < next; i++) { - data_pointers[scanned + i] = row_ptr; - row_ptr += row_width; - } - // Unswizzle the offsets back to pointers (if needed) - if (unswizzling) { - RowOperations::UnswizzlePointers(layout, data_ptr, read_state.heap_handle.Ptr(), next); - rows.blocks[read_state.block_idx]->block->SetSwizzling("RowDataCollectionScanner::Scan"); - } - // Update state indices - read_state.entry_idx += next; - scanned += next; - total_scanned += next; - if (read_state.entry_idx == data_block->count) { - // Pin completed blocks so we don't lose them - 
pinned_blocks.emplace_back(rows.buffer_manager.Pin(data_block->block)); - if (unswizzling) { - auto &heap_block = heap.blocks[read_state.block_idx]; - pinned_blocks.emplace_back(heap.buffer_manager.Pin(heap_block->block)); - } - read_state.block_idx++; - read_state.entry_idx = 0; - ValidateUnscannedBlock(); - } - } - D_ASSERT(scanned == count); - // Deserialize the payload data - for (idx_t col_no = 0; col_no < layout.ColumnCount(); col_no++) { - RowOperations::Gather(addresses, *FlatVector::IncrementalSelectionVector(), chunk.data[col_no], - *FlatVector::IncrementalSelectionVector(), count, layout, col_no); - } - chunk.SetCardinality(count); - chunk.Verify(); - - // Switch to a new set of pinned blocks - read_state.pinned_blocks.swap(pinned_blocks); - - if (flush) { - // Release blocks we have passed. - for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { - rows.blocks[i]->block = nullptr; - if (unswizzling) { - heap.blocks[i]->block = nullptr; - } - } - } else if (unswizzling) { - // Reswizzle blocks we have passed so they can be flushed safely. 
- for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { - auto &data_block = rows.blocks[i]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlock(*data_block, *heap.blocks[i]); - } - } - } -} - -void RowDataCollectionScanner::Reset(bool flush_p) { - flush = flush_p; - total_scanned = 0; - - read_state.block_idx = 0; - read_state.entry_idx = 0; -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row_layout.cpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -RowLayout::RowLayout() : flag_width(0), data_width(0), row_width(0), all_constant(true), heap_pointer_offset(0) { -} - -void RowLayout::Initialize(vector types_p, bool align) { - offsets.clear(); - types = std::move(types_p); - - // Null mask at the front - 1 bit per value. - flag_width = ValidityBytes::ValidityMaskSize(types.size()); - row_width = flag_width; - - // Whether all columns are constant size. - for (const auto &type : types) { - all_constant = all_constant && TypeIsConstantSize(type.InternalType()); - } - - // This enables pointer swizzling for out-of-core computation. - if (!all_constant) { - // When unswizzled, the pointer lives here. - // When swizzled, the pointer is replaced by an offset. - heap_pointer_offset = row_width; - // The 8 byte pointer will be replaced with an 8 byte idx_t when swizzled. - // However, this cannot be sizeof(data_ptr_t), since 32 bit builds use 4 byte pointers. - row_width += sizeof(idx_t); - } - - // Data columns. No alignment required. 
- for (const auto &type : types) { - offsets.push_back(row_width); - const auto internal_type = type.InternalType(); - if (TypeIsConstantSize(internal_type) || internal_type == PhysicalType::VARCHAR) { - row_width += GetTypeIdSize(type.InternalType()); - } else { - // Variable size types use pointers to the actual data (can be swizzled). - // Again, we would use sizeof(data_ptr_t), but this is not guaranteed to be equal to sizeof(idx_t). - row_width += sizeof(idx_t); - } - } - - data_width = row_width - flag_width; - - // Alignment padding for the next row - if (align) { - row_width = AlignValue(row_width); - } -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -using ValidityBytes = TupleDataLayout::ValidityBytes; - -TupleDataBlock::TupleDataBlock(BufferManager &buffer_manager, idx_t capacity_p) : capacity(capacity_p), size(0) { - buffer_manager.Allocate(capacity, false, &handle); -} - -TupleDataBlock::TupleDataBlock(TupleDataBlock &&other) noexcept { - std::swap(handle, other.handle); - std::swap(capacity, other.capacity); - std::swap(size, other.size); -} - -TupleDataBlock &TupleDataBlock::operator=(TupleDataBlock &&other) noexcept { - std::swap(handle, other.handle); - std::swap(capacity, other.capacity); - std::swap(size, other.size); - return *this; -} - -TupleDataAllocator::TupleDataAllocator(BufferManager &buffer_manager, const TupleDataLayout &layout) - : buffer_manager(buffer_manager), layout(layout.Copy()) { -} - -TupleDataAllocator::TupleDataAllocator(TupleDataAllocator &allocator) - : buffer_manager(allocator.buffer_manager), layout(allocator.layout.Copy()) { -} - -BufferManager &TupleDataAllocator::GetBufferManager() { - return buffer_manager; -} - -Allocator &TupleDataAllocator::GetAllocator() { - return buffer_manager.GetBufferAllocator(); -} - -const TupleDataLayout &TupleDataAllocator::GetLayout() const { - return layout; -} - -idx_t TupleDataAllocator::RowBlockCount() const { - return row_blocks.size(); -} - -idx_t 
TupleDataAllocator::HeapBlockCount() const { - return heap_blocks.size(); -} - -void TupleDataAllocator::Build(TupleDataSegment &segment, TupleDataPinState &pin_state, - TupleDataChunkState &chunk_state, const idx_t append_offset, const idx_t append_count) { - D_ASSERT(this == segment.allocator.get()); - auto &chunks = segment.chunks; - if (!chunks.empty()) { - ReleaseOrStoreHandles(pin_state, segment, chunks.back(), true); - } - - // Build the chunk parts for the incoming data - chunk_part_indices.clear(); - idx_t offset = 0; - while (offset != append_count) { - if (chunks.empty() || chunks.back().count == STANDARD_VECTOR_SIZE) { - chunks.emplace_back(); - } - auto &chunk = chunks.back(); - - // Build the next part - auto next = MinValue(append_count - offset, STANDARD_VECTOR_SIZE - chunk.count); - chunk.AddPart(BuildChunkPart(pin_state, chunk_state, append_offset + offset, next, chunk), layout); - auto &chunk_part = chunk.parts.back(); - next = chunk_part.count; - - segment.count += next; - segment.data_size += chunk_part.count * layout.GetRowWidth(); - if (!layout.AllConstant()) { - segment.data_size += chunk_part.total_heap_size; - } - - offset += next; - chunk_part_indices.emplace_back(chunks.size() - 1, chunk.parts.size() - 1); - } - - // Now initialize the pointers to write the data to - chunk_parts.clear(); - for (auto &indices : chunk_part_indices) { - chunk_parts.emplace_back(segment.chunks[indices.first].parts[indices.second]); - } - InitializeChunkStateInternal(pin_state, chunk_state, append_offset, false, true, false, chunk_parts); - - // To reduce metadata, we try to merge chunk parts where possible - // Due to the way chunk parts are constructed, only the last part of the first chunk is eligible for merging - segment.chunks[chunk_part_indices[0].first].MergeLastChunkPart(layout); - - segment.Verify(); -} - -TupleDataChunkPart TupleDataAllocator::BuildChunkPart(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - const idx_t 
append_offset, const idx_t append_count, - TupleDataChunk &chunk) { - D_ASSERT(append_count != 0); - TupleDataChunkPart result(*chunk.lock); - - // Allocate row block (if needed) - if (row_blocks.empty() || row_blocks.back().RemainingCapacity() < layout.GetRowWidth()) { - row_blocks.emplace_back(buffer_manager, (idx_t)Storage::BLOCK_SIZE); - } - result.row_block_index = row_blocks.size() - 1; - auto &row_block = row_blocks[result.row_block_index]; - result.row_block_offset = row_block.size; - - // Set count (might be reduced later when checking heap space) - result.count = MinValue(row_block.RemainingCapacity(layout.GetRowWidth()), append_count); - if (!layout.AllConstant()) { - const auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); - - // Compute total heap size first - idx_t total_heap_size = 0; - for (idx_t i = 0; i < result.count; i++) { - const auto &heap_size = heap_sizes[append_offset + i]; - total_heap_size += heap_size; - } - - if (total_heap_size == 0) { - // We don't need a heap at all - result.heap_block_index = TupleDataChunkPart::INVALID_INDEX; - result.heap_block_offset = TupleDataChunkPart::INVALID_INDEX; - result.total_heap_size = 0; - result.base_heap_ptr = nullptr; - } else { - // Allocate heap block (if needed) - if (heap_blocks.empty() || heap_blocks.back().RemainingCapacity() < heap_sizes[append_offset]) { - const auto size = MaxValue((idx_t)Storage::BLOCK_SIZE, heap_sizes[append_offset]); - heap_blocks.emplace_back(buffer_manager, size); - } - result.heap_block_index = heap_blocks.size() - 1; - auto &heap_block = heap_blocks[result.heap_block_index]; - result.heap_block_offset = heap_block.size; - - const auto heap_remaining = heap_block.RemainingCapacity(); - if (total_heap_size <= heap_remaining) { - // Everything fits - result.total_heap_size = total_heap_size; - } else { - // Not everything fits - determine how many we can read next - result.total_heap_size = 0; - for (idx_t i = 0; i < result.count; i++) { - const auto 
&heap_size = heap_sizes[append_offset + i]; - if (result.total_heap_size + heap_size > heap_remaining) { - result.count = i; - break; - } - result.total_heap_size += heap_size; - } - } - - // Mark this portion of the heap block as filled and set the pointer - heap_block.size += result.total_heap_size; - result.base_heap_ptr = GetBaseHeapPointer(pin_state, result); - } - } - D_ASSERT(result.count != 0 && result.count <= STANDARD_VECTOR_SIZE); - - // Mark this portion of the row block as filled - row_block.size += result.count * layout.GetRowWidth(); - - return result; -} - -void TupleDataAllocator::InitializeChunkState(TupleDataSegment &segment, TupleDataPinState &pin_state, - TupleDataChunkState &chunk_state, idx_t chunk_idx, bool init_heap) { - D_ASSERT(this == segment.allocator.get()); - D_ASSERT(chunk_idx < segment.ChunkCount()); - auto &chunk = segment.chunks[chunk_idx]; - - // Release or store any handles that are no longer required: - // We can't release the heap here if the current chunk's heap_block_ids is empty, because if we are iterating with - // PinProperties::DESTROY_AFTER_DONE, we might destroy a heap block that is needed by a later chunk, e.g., - // when chunk 0 needs heap block 0, chunk 1 does not need any heap blocks, and chunk 2 needs heap block 0 again - ReleaseOrStoreHandles(pin_state, segment, chunk, !chunk.heap_block_ids.empty()); - - unsafe_vector> parts; - parts.reserve(chunk.parts.size()); - for (auto &part : chunk.parts) { - parts.emplace_back(part); - } - - InitializeChunkStateInternal(pin_state, chunk_state, 0, true, init_heap, init_heap, parts); -} - -static inline void InitializeHeapSizes(const data_ptr_t row_locations[], idx_t heap_sizes[], const idx_t offset, - const idx_t next, const TupleDataChunkPart &part, const idx_t heap_size_offset) { - // Read the heap sizes from the rows - for (idx_t i = 0; i < next; i++) { - auto idx = offset + i; - heap_sizes[idx] = Load(row_locations[idx] + heap_size_offset); - } - - // Verify total size 
-#ifdef DEBUG - idx_t total_heap_size = 0; - for (idx_t i = 0; i < next; i++) { - auto idx = offset + i; - total_heap_size += heap_sizes[idx]; - } - D_ASSERT(total_heap_size == part.total_heap_size); -#endif -} - -void TupleDataAllocator::InitializeChunkStateInternal(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - idx_t offset, bool recompute, bool init_heap_pointers, - bool init_heap_sizes, - unsafe_vector> &parts) { - auto row_locations = FlatVector::GetData(chunk_state.row_locations); - auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); - auto heap_locations = FlatVector::GetData(chunk_state.heap_locations); - - for (auto &part_ref : parts) { - auto &part = part_ref.get(); - const auto next = part.count; - - // Set up row locations for the scan - const auto row_width = layout.GetRowWidth(); - const auto base_row_ptr = GetRowPointer(pin_state, part); - for (idx_t i = 0; i < next; i++) { - row_locations[offset + i] = base_row_ptr + i * row_width; - } - - if (layout.AllConstant()) { // Can't have a heap - offset += next; - continue; - } - - if (part.total_heap_size == 0) { - if (init_heap_sizes) { // No heap, but we need the heap sizes - InitializeHeapSizes(row_locations, heap_sizes, offset, next, part, layout.GetHeapSizeOffset()); - } - offset += next; - continue; - } - - // Check if heap block has changed - re-compute the pointers within each row if so - if (recompute && pin_state.properties != TupleDataPinProperties::ALREADY_PINNED) { - const auto new_base_heap_ptr = GetBaseHeapPointer(pin_state, part); - if (part.base_heap_ptr != new_base_heap_ptr) { - lock_guard guard(part.lock); - const auto old_base_heap_ptr = part.base_heap_ptr; - if (old_base_heap_ptr != new_base_heap_ptr) { - Vector old_heap_ptrs( - Value::POINTER(CastPointerToValue(old_base_heap_ptr + part.heap_block_offset))); - Vector new_heap_ptrs( - Value::POINTER(CastPointerToValue(new_base_heap_ptr + part.heap_block_offset))); - RecomputeHeapPointers(old_heap_ptrs, 
*ConstantVector::ZeroSelectionVector(), row_locations, - new_heap_ptrs, offset, next, layout, 0); - part.base_heap_ptr = new_base_heap_ptr; - } - } - } - - if (init_heap_sizes) { - InitializeHeapSizes(row_locations, heap_sizes, offset, next, part, layout.GetHeapSizeOffset()); - } - - if (init_heap_pointers) { - // Set the pointers where the heap data will be written (if needed) - heap_locations[offset] = part.base_heap_ptr + part.heap_block_offset; - for (idx_t i = 1; i < next; i++) { - auto idx = offset + i; - heap_locations[idx] = heap_locations[idx - 1] + heap_sizes[idx - 1]; - } - } - - offset += next; - } - D_ASSERT(offset <= STANDARD_VECTOR_SIZE); -} - -static inline void VerifyStrings(const LogicalTypeId type_id, const data_ptr_t row_locations[], const idx_t col_idx, - const idx_t base_col_offset, const idx_t col_offset, const idx_t offset, - const idx_t count) { -#ifdef DEBUG - if (type_id != LogicalTypeId::VARCHAR) { - // Make sure we don't verify BLOB / AGGREGATE_STATE - return; - } - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - for (idx_t i = 0; i < count; i++) { - const auto &row_location = row_locations[offset + i] + base_col_offset; - ValidityBytes row_mask(row_location); - if (row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - auto recomputed_string = Load(row_location + col_offset); - recomputed_string.Verify(); - } - } -#endif -} - -void TupleDataAllocator::RecomputeHeapPointers(Vector &old_heap_ptrs, const SelectionVector &old_heap_sel, - const data_ptr_t row_locations[], Vector &new_heap_ptrs, - const idx_t offset, const idx_t count, const TupleDataLayout &layout, - const idx_t base_col_offset) { - const auto old_heap_locations = FlatVector::GetData(old_heap_ptrs); - - UnifiedVectorFormat new_heap_data; - new_heap_ptrs.ToUnifiedFormat(offset + count, new_heap_data); - const auto new_heap_locations = UnifiedVectorFormat::GetData(new_heap_data); - const 
auto new_heap_sel = *new_heap_data.sel; - - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - const auto &col_offset = layout.GetOffsets()[col_idx]; - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - const auto &type = layout.GetTypes()[col_idx]; - switch (type.InternalType()) { - case PhysicalType::VARCHAR: { - for (idx_t i = 0; i < count; i++) { - const auto idx = offset + i; - const auto &row_location = row_locations[idx] + base_col_offset; - ValidityBytes row_mask(row_location); - if (!row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - continue; - } - - const auto &old_heap_ptr = old_heap_locations[old_heap_sel.get_index(idx)]; - const auto &new_heap_ptr = new_heap_locations[new_heap_sel.get_index(idx)]; - - const auto string_location = row_location + col_offset; - if (Load(string_location) > string_t::INLINE_LENGTH) { - const auto string_ptr_location = string_location + string_t::HEADER_SIZE; - const auto string_ptr = Load(string_ptr_location); - const auto diff = string_ptr - old_heap_ptr; - D_ASSERT(diff >= 0); - Store(new_heap_ptr + diff, string_ptr_location); - } - } - VerifyStrings(type.id(), row_locations, col_idx, base_col_offset, col_offset, offset, count); - break; - } - case PhysicalType::LIST: { - for (idx_t i = 0; i < count; i++) { - const auto idx = offset + i; - const auto &row_location = row_locations[idx] + base_col_offset; - ValidityBytes row_mask(row_location); - if (!row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - continue; - } - - const auto &old_heap_ptr = old_heap_locations[old_heap_sel.get_index(idx)]; - const auto &new_heap_ptr = new_heap_locations[new_heap_sel.get_index(idx)]; - - const auto &list_ptr_location = row_location + col_offset; - const auto list_ptr = Load(list_ptr_location); - const auto diff = list_ptr - old_heap_ptr; - D_ASSERT(diff >= 0); - 
Store(new_heap_ptr + diff, list_ptr_location); - } - break; - } - case PhysicalType::STRUCT: { - const auto &struct_layout = layout.GetStructLayout(col_idx); - if (!struct_layout.AllConstant()) { - RecomputeHeapPointers(old_heap_ptrs, old_heap_sel, row_locations, new_heap_ptrs, offset, count, - struct_layout, base_col_offset + col_offset); - } - break; - } - default: - continue; - } - } -} - -void TupleDataAllocator::ReleaseOrStoreHandles(TupleDataPinState &pin_state, TupleDataSegment &segment, - TupleDataChunk &chunk, bool release_heap) { - D_ASSERT(this == segment.allocator.get()); - ReleaseOrStoreHandlesInternal(segment, segment.pinned_row_handles, pin_state.row_handles, chunk.row_block_ids, - row_blocks, pin_state.properties); - if (!layout.AllConstant() && release_heap) { - ReleaseOrStoreHandlesInternal(segment, segment.pinned_heap_handles, pin_state.heap_handles, - chunk.heap_block_ids, heap_blocks, pin_state.properties); - } -} - -void TupleDataAllocator::ReleaseOrStoreHandles(TupleDataPinState &pin_state, TupleDataSegment &segment) { - static TupleDataChunk DUMMY_CHUNK; - ReleaseOrStoreHandles(pin_state, segment, DUMMY_CHUNK, true); -} - -void TupleDataAllocator::ReleaseOrStoreHandlesInternal( - TupleDataSegment &segment, unsafe_vector &pinned_handles, perfect_map_t &handles, - const perfect_set_t &block_ids, unsafe_vector &blocks, TupleDataPinProperties properties) { - bool found_handle; - do { - found_handle = false; - for (auto it = handles.begin(); it != handles.end(); it++) { - const auto block_id = it->first; - if (block_ids.find(block_id) != block_ids.end()) { - // still required: do not release - continue; - } - switch (properties) { - case TupleDataPinProperties::KEEP_EVERYTHING_PINNED: { - lock_guard guard(segment.pinned_handles_lock); - const auto block_count = block_id + 1; - if (block_count > pinned_handles.size()) { - pinned_handles.resize(block_count); - } - pinned_handles[block_id] = std::move(it->second); - break; - } - case 
TupleDataPinProperties::UNPIN_AFTER_DONE: - case TupleDataPinProperties::ALREADY_PINNED: - break; - case TupleDataPinProperties::DESTROY_AFTER_DONE: - blocks[block_id].handle = nullptr; - break; - default: - D_ASSERT(properties == TupleDataPinProperties::INVALID); - throw InternalException("Encountered TupleDataPinProperties::INVALID"); - } - handles.erase(it); - found_handle = true; - break; - } - } while (found_handle); -} - -BufferHandle &TupleDataAllocator::PinRowBlock(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { - const auto &row_block_index = part.row_block_index; - auto it = pin_state.row_handles.find(row_block_index); - if (it == pin_state.row_handles.end()) { - D_ASSERT(row_block_index < row_blocks.size()); - auto &row_block = row_blocks[row_block_index]; - D_ASSERT(row_block.handle); - D_ASSERT(part.row_block_offset < row_block.size); - D_ASSERT(part.row_block_offset + part.count * layout.GetRowWidth() <= row_block.size); - it = pin_state.row_handles.emplace(row_block_index, buffer_manager.Pin(row_block.handle)).first; - } - return it->second; -} - -BufferHandle &TupleDataAllocator::PinHeapBlock(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { - const auto &heap_block_index = part.heap_block_index; - auto it = pin_state.heap_handles.find(heap_block_index); - if (it == pin_state.heap_handles.end()) { - D_ASSERT(heap_block_index < heap_blocks.size()); - auto &heap_block = heap_blocks[heap_block_index]; - D_ASSERT(heap_block.handle); - D_ASSERT(part.heap_block_offset < heap_block.size); - D_ASSERT(part.heap_block_offset + part.total_heap_size <= heap_block.size); - it = pin_state.heap_handles.emplace(heap_block_index, buffer_manager.Pin(heap_block.handle)).first; - } - return it->second; -} - -data_ptr_t TupleDataAllocator::GetRowPointer(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { - return PinRowBlock(pin_state, part).Ptr() + part.row_block_offset; -} - -data_ptr_t 
TupleDataAllocator::GetBaseHeapPointer(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { - return PinHeapBlock(pin_state, part).Ptr(); -} - -} // namespace duckdb - - - - - - - -#include - -namespace duckdb { - -using ValidityBytes = TupleDataLayout::ValidityBytes; - -TupleDataCollection::TupleDataCollection(BufferManager &buffer_manager, const TupleDataLayout &layout_p) - : layout(layout_p.Copy()), allocator(make_shared(buffer_manager, layout)) { - Initialize(); -} - -TupleDataCollection::TupleDataCollection(shared_ptr allocator) - : layout(allocator->GetLayout().Copy()), allocator(std::move(allocator)) { - Initialize(); -} - -TupleDataCollection::~TupleDataCollection() { -} - -void TupleDataCollection::Initialize() { - D_ASSERT(!layout.GetTypes().empty()); - this->count = 0; - this->data_size = 0; - scatter_functions.reserve(layout.ColumnCount()); - gather_functions.reserve(layout.ColumnCount()); - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto &type = layout.GetTypes()[col_idx]; - scatter_functions.emplace_back(GetScatterFunction(type)); - gather_functions.emplace_back(GetGatherFunction(type)); - } -} - -void GetAllColumnIDsInternal(vector &column_ids, const idx_t column_count) { - column_ids.reserve(column_count); - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - column_ids.emplace_back(col_idx); - } -} - -void TupleDataCollection::GetAllColumnIDs(vector &column_ids) { - GetAllColumnIDsInternal(column_ids, layout.ColumnCount()); -} - -const TupleDataLayout &TupleDataCollection::GetLayout() const { - return layout; -} - -const idx_t &TupleDataCollection::Count() const { - return count; -} - -idx_t TupleDataCollection::ChunkCount() const { - idx_t total_chunk_count = 0; - for (const auto &segment : segments) { - total_chunk_count += segment.ChunkCount(); - } - return total_chunk_count; -} - -idx_t TupleDataCollection::SizeInBytes() const { - idx_t total_size = 0; - for (const auto &segment : segments) { - 
total_size += segment.SizeInBytes(); - } - return total_size; -} - -void TupleDataCollection::Unpin() { - for (auto &segment : segments) { - segment.Unpin(); - } -} - -// LCOV_EXCL_START -void VerifyAppendColumns(const TupleDataLayout &layout, const vector &column_ids) { -#ifdef DEBUG - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - if (std::find(column_ids.begin(), column_ids.end(), col_idx) != column_ids.end()) { - continue; - } - // This column will not be appended in the first go - verify that it is fixed-size - we cannot resize heap after - const auto physical_type = layout.GetTypes()[col_idx].InternalType(); - D_ASSERT(physical_type != PhysicalType::VARCHAR && physical_type != PhysicalType::LIST); - if (physical_type == PhysicalType::STRUCT) { - const auto &struct_layout = layout.GetStructLayout(col_idx); - vector struct_column_ids; - struct_column_ids.reserve(struct_layout.ColumnCount()); - for (idx_t struct_col_idx = 0; struct_col_idx < struct_layout.ColumnCount(); struct_col_idx++) { - struct_column_ids.emplace_back(struct_col_idx); - } - VerifyAppendColumns(struct_layout, struct_column_ids); - } - } -#endif -} -// LCOV_EXCL_STOP - -void TupleDataCollection::InitializeAppend(TupleDataAppendState &append_state, TupleDataPinProperties properties) { - vector column_ids; - GetAllColumnIDs(column_ids); - InitializeAppend(append_state, std::move(column_ids), properties); -} - -void TupleDataCollection::InitializeAppend(TupleDataAppendState &append_state, vector column_ids, - TupleDataPinProperties properties) { - VerifyAppendColumns(layout, column_ids); - InitializeAppend(append_state.pin_state, properties); - InitializeChunkState(append_state.chunk_state, std::move(column_ids)); -} - -void TupleDataCollection::InitializeAppend(TupleDataPinState &pin_state, TupleDataPinProperties properties) { - pin_state.properties = properties; - if (segments.empty()) { - segments.emplace_back(allocator); - } -} - -static void 
InitializeVectorFormat(vector &vector_data, const vector &types) { - vector_data.resize(types.size()); - for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { - const auto &type = types[col_idx]; - switch (type.InternalType()) { - case PhysicalType::STRUCT: { - const auto &child_list = StructType::GetChildTypes(type); - vector child_types; - child_types.reserve(child_list.size()); - for (const auto &child_entry : child_list) { - child_types.emplace_back(child_entry.second); - } - InitializeVectorFormat(vector_data[col_idx].children, child_types); - break; - } - case PhysicalType::LIST: - InitializeVectorFormat(vector_data[col_idx].children, {ListType::GetChildType(type)}); - break; - default: - break; - } - } -} - -void TupleDataCollection::InitializeChunkState(TupleDataChunkState &chunk_state, vector column_ids) { - TupleDataCollection::InitializeChunkState(chunk_state, layout.GetTypes(), std::move(column_ids)); -} - -void TupleDataCollection::InitializeChunkState(TupleDataChunkState &chunk_state, const vector &types, - vector column_ids) { - if (column_ids.empty()) { - GetAllColumnIDsInternal(column_ids, types.size()); - } - InitializeVectorFormat(chunk_state.vector_data, types); - chunk_state.column_ids = std::move(column_ids); -} - -void TupleDataCollection::Append(DataChunk &new_chunk, const SelectionVector &append_sel, idx_t append_count) { - TupleDataAppendState append_state; - InitializeAppend(append_state); - Append(append_state, new_chunk, append_sel, append_count); -} - -void TupleDataCollection::Append(DataChunk &new_chunk, vector column_ids, const SelectionVector &append_sel, - const idx_t append_count) { - TupleDataAppendState append_state; - InitializeAppend(append_state, std::move(column_ids)); - Append(append_state, new_chunk, append_sel, append_count); -} - -void TupleDataCollection::Append(TupleDataAppendState &append_state, DataChunk &new_chunk, - const SelectionVector &append_sel, const idx_t append_count) { - 
Append(append_state.pin_state, append_state.chunk_state, new_chunk, append_sel, append_count); -} - -void TupleDataCollection::Append(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, DataChunk &new_chunk, - const SelectionVector &append_sel, const idx_t append_count) { - TupleDataCollection::ToUnifiedFormat(chunk_state, new_chunk); - AppendUnified(pin_state, chunk_state, new_chunk, append_sel, append_count); -} - -void TupleDataCollection::AppendUnified(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - DataChunk &new_chunk, const SelectionVector &append_sel, - const idx_t append_count) { - const idx_t actual_append_count = append_count == DConstants::INVALID_INDEX ? new_chunk.size() : append_count; - if (actual_append_count == 0) { - return; - } - - if (!layout.AllConstant()) { - TupleDataCollection::ComputeHeapSizes(chunk_state, new_chunk, append_sel, actual_append_count); - } - - Build(pin_state, chunk_state, 0, actual_append_count); - -#ifdef DEBUG - Vector heap_locations_copy(LogicalType::POINTER); - if (!layout.AllConstant()) { - VectorOperations::Copy(chunk_state.heap_locations, heap_locations_copy, actual_append_count, 0, 0); - } -#endif - - Scatter(chunk_state, new_chunk, append_sel, actual_append_count); - -#ifdef DEBUG - // Verify that the size of the data written to the heap is the same as the size we computed it would be - if (!layout.AllConstant()) { - const auto original_heap_locations = FlatVector::GetData(heap_locations_copy); - const auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); - const auto offset_heap_locations = FlatVector::GetData(chunk_state.heap_locations); - for (idx_t i = 0; i < actual_append_count; i++) { - D_ASSERT(offset_heap_locations[i] == original_heap_locations[i] + heap_sizes[i]); - } - } -#endif -} - -static inline void ToUnifiedFormatInternal(TupleDataVectorFormat &format, Vector &vector, const idx_t count) { - vector.ToUnifiedFormat(count, format.unified); - format.original_sel = 
format.unified.sel; - format.original_owned_sel.Initialize(format.unified.owned_sel); - switch (vector.GetType().InternalType()) { - case PhysicalType::STRUCT: { - auto &entries = StructVector::GetEntries(vector); - D_ASSERT(format.children.size() == entries.size()); - for (idx_t struct_col_idx = 0; struct_col_idx < entries.size(); struct_col_idx++) { - ToUnifiedFormatInternal(reinterpret_cast(format.children[struct_col_idx]), - *entries[struct_col_idx], count); - } - break; - } - case PhysicalType::LIST: - D_ASSERT(format.children.size() == 1); - ToUnifiedFormatInternal(reinterpret_cast(format.children[0]), - ListVector::GetEntry(vector), ListVector::GetListSize(vector)); - break; - default: - break; - } -} - -void TupleDataCollection::ToUnifiedFormat(TupleDataChunkState &chunk_state, DataChunk &new_chunk) { - D_ASSERT(chunk_state.vector_data.size() >= chunk_state.column_ids.size()); // Needs InitializeAppend - for (const auto &col_idx : chunk_state.column_ids) { - ToUnifiedFormatInternal(chunk_state.vector_data[col_idx], new_chunk.data[col_idx], new_chunk.size()); - } -} - -void TupleDataCollection::GetVectorData(const TupleDataChunkState &chunk_state, UnifiedVectorFormat result[]) { - const auto &vector_data = chunk_state.vector_data; - for (idx_t i = 0; i < vector_data.size(); i++) { - const auto &source = vector_data[i].unified; - auto &target = result[i]; - target.sel = source.sel; - target.data = source.data; - target.validity = source.validity; - } -} - -void TupleDataCollection::Build(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - const idx_t append_offset, const idx_t append_count) { - auto &segment = segments.back(); - const auto size_before = segment.SizeInBytes(); - segment.allocator->Build(segment, pin_state, chunk_state, append_offset, append_count); - data_size += segment.SizeInBytes() - size_before; - count += append_count; - Verify(); -} - -// LCOV_EXCL_START -void VerifyHeapSizes(const data_ptr_t source_locations[], const idx_t 
heap_sizes[], const SelectionVector &append_sel, - const idx_t append_count, const idx_t heap_size_offset) { -#ifdef DEBUG - for (idx_t i = 0; i < append_count; i++) { - auto idx = append_sel.get_index(i); - const auto stored_heap_size = Load(source_locations[idx] + heap_size_offset); - D_ASSERT(stored_heap_size == heap_sizes[idx]); - } -#endif -} -// LCOV_EXCL_STOP - -void TupleDataCollection::CopyRows(TupleDataChunkState &chunk_state, TupleDataChunkState &input, - const SelectionVector &append_sel, const idx_t append_count) const { - const auto source_locations = FlatVector::GetData(input.row_locations); - const auto target_locations = FlatVector::GetData(chunk_state.row_locations); - - // Copy rows - const auto row_width = layout.GetRowWidth(); - for (idx_t i = 0; i < append_count; i++) { - auto idx = append_sel.get_index(i); - FastMemcpy(target_locations[i], source_locations[idx], row_width); - } - - // Copy heap if we need to - if (!layout.AllConstant()) { - const auto source_heap_locations = FlatVector::GetData(input.heap_locations); - const auto target_heap_locations = FlatVector::GetData(chunk_state.heap_locations); - const auto heap_sizes = FlatVector::GetData(input.heap_sizes); - VerifyHeapSizes(source_locations, heap_sizes, append_sel, append_count, layout.GetHeapSizeOffset()); - - // Check if we need to copy anything at all - idx_t total_heap_size = 0; - for (idx_t i = 0; i < append_count; i++) { - auto idx = append_sel.get_index(i); - total_heap_size += heap_sizes[idx]; - } - if (total_heap_size == 0) { - return; - } - - // Copy heap - for (idx_t i = 0; i < append_count; i++) { - auto idx = append_sel.get_index(i); - FastMemcpy(target_heap_locations[i], source_heap_locations[idx], heap_sizes[idx]); - } - - // Recompute pointers after copying the data - TupleDataAllocator::RecomputeHeapPointers(input.heap_locations, append_sel, target_locations, - chunk_state.heap_locations, 0, append_count, layout, 0); - } -} - -void 
TupleDataCollection::Combine(TupleDataCollection &other) { - if (other.count == 0) { - return; - } - if (this->layout.GetTypes() != other.GetLayout().GetTypes()) { - throw InternalException("Attempting to combine TupleDataCollection with mismatching types"); - } - this->segments.reserve(this->segments.size() + other.segments.size()); - for (auto &other_seg : other.segments) { - AddSegment(std::move(other_seg)); - } - other.Reset(); -} - -void TupleDataCollection::AddSegment(TupleDataSegment &&segment) { - count += segment.count; - data_size += segment.data_size; - segments.emplace_back(std::move(segment)); - Verify(); -} - -void TupleDataCollection::Combine(unique_ptr other) { - Combine(*other); -} - -void TupleDataCollection::Reset() { - count = 0; - data_size = 0; - segments.clear(); - - // Refreshes the TupleDataAllocator to prevent holding on to allocated data unnecessarily - allocator = make_shared(*allocator); -} - -void TupleDataCollection::InitializeChunk(DataChunk &chunk) const { - chunk.Initialize(allocator->GetAllocator(), layout.GetTypes()); -} - -void TupleDataCollection::InitializeScanChunk(TupleDataScanState &state, DataChunk &chunk) const { - auto &column_ids = state.chunk_state.column_ids; - D_ASSERT(!column_ids.empty()); - vector chunk_types; - chunk_types.reserve(column_ids.size()); - for (idx_t i = 0; i < column_ids.size(); i++) { - auto column_idx = column_ids[i]; - D_ASSERT(column_idx < layout.ColumnCount()); - chunk_types.push_back(layout.GetTypes()[column_idx]); - } - chunk.Initialize(allocator->GetAllocator(), chunk_types); -} - -void TupleDataCollection::InitializeScan(TupleDataScanState &state, TupleDataPinProperties properties) const { - vector column_ids; - column_ids.reserve(layout.ColumnCount()); - for (idx_t i = 0; i < layout.ColumnCount(); i++) { - column_ids.push_back(i); - } - InitializeScan(state, std::move(column_ids), properties); -} - -void TupleDataCollection::InitializeScan(TupleDataScanState &state, vector column_ids, - 
TupleDataPinProperties properties) const { - state.pin_state.row_handles.clear(); - state.pin_state.heap_handles.clear(); - state.pin_state.properties = properties; - state.segment_index = 0; - state.chunk_index = 0; - state.chunk_state.column_ids = std::move(column_ids); -} - -void TupleDataCollection::InitializeScan(TupleDataParallelScanState &gstate, TupleDataPinProperties properties) const { - InitializeScan(gstate.scan_state, properties); -} - -void TupleDataCollection::InitializeScan(TupleDataParallelScanState &state, vector column_ids, - TupleDataPinProperties properties) const { - InitializeScan(state.scan_state, std::move(column_ids), properties); -} - -bool TupleDataCollection::Scan(TupleDataScanState &state, DataChunk &result) { - const auto segment_index_before = state.segment_index; - idx_t segment_index; - idx_t chunk_index; - if (!NextScanIndex(state, segment_index, chunk_index)) { - if (!segments.empty()) { - FinalizePinState(state.pin_state, segments[segment_index_before]); - } - result.SetCardinality(0); - return false; - } - if (segment_index_before != DConstants::INVALID_INDEX && segment_index != segment_index_before) { - FinalizePinState(state.pin_state, segments[segment_index_before]); - } - ScanAtIndex(state.pin_state, state.chunk_state, state.chunk_state.column_ids, segment_index, chunk_index, result); - return true; -} - -bool TupleDataCollection::Scan(TupleDataParallelScanState &gstate, TupleDataLocalScanState &lstate, DataChunk &result) { - lstate.pin_state.properties = gstate.scan_state.pin_state.properties; - - const auto segment_index_before = lstate.segment_index; - { - lock_guard guard(gstate.lock); - if (!NextScanIndex(gstate.scan_state, lstate.segment_index, lstate.chunk_index)) { - if (!segments.empty()) { - FinalizePinState(lstate.pin_state, segments[segment_index_before]); - } - result.SetCardinality(0); - return false; - } - } - if (segment_index_before != DConstants::INVALID_INDEX && segment_index_before != 
lstate.segment_index) { - FinalizePinState(lstate.pin_state, segments[lstate.segment_index]); - } - ScanAtIndex(lstate.pin_state, lstate.chunk_state, gstate.scan_state.chunk_state.column_ids, lstate.segment_index, - lstate.chunk_index, result); - return true; -} - -bool TupleDataCollection::ScanComplete(const TupleDataScanState &state) const { - if (Count() == 0) { - return true; - } - return state.segment_index == segments.size() - 1 && state.chunk_index == segments.back().ChunkCount(); -} - -void TupleDataCollection::FinalizePinState(TupleDataPinState &pin_state, TupleDataSegment &segment) { - segment.allocator->ReleaseOrStoreHandles(pin_state, segment); -} - -void TupleDataCollection::FinalizePinState(TupleDataPinState &pin_state) { - D_ASSERT(!segments.empty()); - FinalizePinState(pin_state, segments.back()); -} - -bool TupleDataCollection::NextScanIndex(TupleDataScanState &state, idx_t &segment_index, idx_t &chunk_index) { - // Check if we still have segments to scan - if (state.segment_index >= segments.size()) { - // No more data left in the scan - return false; - } - // Check within the current segment if we still have chunks to scan - while (state.chunk_index >= segments[state.segment_index].ChunkCount()) { - // Exhausted all chunks for this segment: Move to the next one - state.segment_index++; - state.chunk_index = 0; - if (state.segment_index >= segments.size()) { - return false; - } - } - segment_index = state.segment_index; - chunk_index = state.chunk_index++; - return true; -} - -void TupleDataCollection::ScanAtIndex(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - const vector &column_ids, idx_t segment_index, idx_t chunk_index, - DataChunk &result) { - auto &segment = segments[segment_index]; - auto &chunk = segment.chunks[chunk_index]; - segment.allocator->InitializeChunkState(segment, pin_state, chunk_state, chunk_index, false); - result.Reset(); - Gather(chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), 
chunk.count, column_ids, result, - *FlatVector::IncrementalSelectionVector()); - result.SetCardinality(chunk.count); -} - -// LCOV_EXCL_START -string TupleDataCollection::ToString() { - DataChunk chunk; - InitializeChunk(chunk); - - TupleDataScanState scan_state; - InitializeScan(scan_state); - - string result = StringUtil::Format("TupleDataCollection - [%llu Chunks, %llu Rows]\n", ChunkCount(), Count()); - idx_t chunk_idx = 0; - idx_t row_count = 0; - while (Scan(scan_state, chunk)) { - result += - StringUtil::Format("Chunk %llu - [Rows %llu - %llu]\n", chunk_idx, row_count, row_count + chunk.size()) + - chunk.ToString(); - chunk_idx++; - row_count += chunk.size(); - } - - return result; -} - -void TupleDataCollection::Print() { - Printer::Print(ToString()); -} - -void TupleDataCollection::Verify() const { -#ifdef DEBUG - idx_t total_count = 0; - idx_t total_size = 0; - for (const auto &segment : segments) { - segment.Verify(); - total_count += segment.count; - total_size += segment.data_size; - } - D_ASSERT(total_count == this->count); - D_ASSERT(total_size == this->data_size); -#endif -} - -void TupleDataCollection::VerifyEverythingPinned() const { -#ifdef DEBUG - for (const auto &segment : segments) { - segment.VerifyEverythingPinned(); - } -#endif -} -// LCOV_EXCL_STOP - -} // namespace duckdb - - - - -namespace duckdb { - -TupleDataChunkIterator::TupleDataChunkIterator(TupleDataCollection &collection_p, TupleDataPinProperties properties_p, - bool init_heap) - : TupleDataChunkIterator(collection_p, properties_p, 0, collection_p.ChunkCount(), init_heap) { -} - -TupleDataChunkIterator::TupleDataChunkIterator(TupleDataCollection &collection_p, TupleDataPinProperties properties, - idx_t chunk_idx_from, idx_t chunk_idx_to, bool init_heap_p) - : collection(collection_p), init_heap(init_heap_p) { - state.pin_state.properties = properties; - D_ASSERT(chunk_idx_from < chunk_idx_to); - D_ASSERT(chunk_idx_to <= collection.ChunkCount()); - idx_t overall_chunk_index = 0; - 
for (idx_t segment_idx = 0; segment_idx < collection.segments.size(); segment_idx++) { - const auto &segment = collection.segments[segment_idx]; - if (chunk_idx_from >= overall_chunk_index && chunk_idx_from <= overall_chunk_index + segment.ChunkCount()) { - // We start in this segment - start_segment_idx = segment_idx; - start_chunk_idx = chunk_idx_from - overall_chunk_index; - } - if (chunk_idx_to >= overall_chunk_index && chunk_idx_to <= overall_chunk_index + segment.ChunkCount()) { - // We end in this segment - end_segment_idx = segment_idx; - end_chunk_idx = chunk_idx_to - overall_chunk_index; - } - overall_chunk_index += segment.ChunkCount(); - } - - Reset(); -} - -void TupleDataChunkIterator::InitializeCurrentChunk() { - auto &segment = collection.segments[current_segment_idx]; - segment.allocator->InitializeChunkState(segment, state.pin_state, state.chunk_state, current_chunk_idx, init_heap); -} - -bool TupleDataChunkIterator::Done() const { - return current_segment_idx == end_segment_idx && current_chunk_idx == end_chunk_idx; -} - -bool TupleDataChunkIterator::Next() { - D_ASSERT(!Done()); // Check if called after already done - - // Set the next indices and checks if we're at the end of the collection - // NextScanIndex can go past this iterators 'end', so we have to check the indices again - const auto segment_idx_before = current_segment_idx; - if (!collection.NextScanIndex(state, current_segment_idx, current_chunk_idx) || Done()) { - // Drop pins / stores them if TupleDataPinProperties::KEEP_EVERYTHING_PINNED - collection.FinalizePinState(state.pin_state, collection.segments[segment_idx_before]); - current_segment_idx = end_segment_idx; - current_chunk_idx = end_chunk_idx; - return false; - } - - // Finalize pin state when moving from one segment to the next - if (current_segment_idx != segment_idx_before) { - collection.FinalizePinState(state.pin_state, collection.segments[segment_idx_before]); - } - - InitializeCurrentChunk(); - return true; -} - 
-void TupleDataChunkIterator::Reset() { - state.segment_index = start_segment_idx; - state.chunk_index = start_chunk_idx; - collection.NextScanIndex(state, current_segment_idx, current_chunk_idx); - InitializeCurrentChunk(); -} - -idx_t TupleDataChunkIterator::GetCurrentChunkCount() const { - return collection.segments[current_segment_idx].chunks[current_chunk_idx].count; -} - -TupleDataChunkState &TupleDataChunkIterator::GetChunkState() { - return state.chunk_state; -} - -data_ptr_t *TupleDataChunkIterator::GetRowLocations() { - return FlatVector::GetData(state.chunk_state.row_locations); -} - -data_ptr_t *TupleDataChunkIterator::GetHeapLocations() { - return FlatVector::GetData(state.chunk_state.heap_locations); -} - -idx_t *TupleDataChunkIterator::GetHeapSizes() { - return FlatVector::GetData(state.chunk_state.heap_sizes); -} - -} // namespace duckdb - - - - -namespace duckdb { - -TupleDataLayout::TupleDataLayout() - : flag_width(0), data_width(0), aggr_width(0), row_width(0), all_constant(true), heap_size_offset(0), - has_destructor(false) { -} - -TupleDataLayout TupleDataLayout::Copy() const { - TupleDataLayout result; - result.types = this->types; - result.aggregates = this->aggregates; - if (this->struct_layouts) { - result.struct_layouts = make_uniq>(); - for (const auto &entry : *this->struct_layouts) { - result.struct_layouts->emplace(entry.first, entry.second.Copy()); - } - } - result.flag_width = this->flag_width; - result.data_width = this->data_width; - result.aggr_width = this->aggr_width; - result.row_width = this->row_width; - result.offsets = this->offsets; - result.all_constant = this->all_constant; - result.heap_size_offset = this->heap_size_offset; - result.has_destructor = this->has_destructor; - return result; -} - -void TupleDataLayout::Initialize(vector types_p, Aggregates aggregates_p, bool align, bool heap_offset_p) { - offsets.clear(); - types = std::move(types_p); - - // Null mask at the front - 1 bit per value. 
- flag_width = ValidityBytes::ValidityMaskSize(types.size()); - row_width = flag_width; - - // Whether all columns are constant size. - for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { - const auto &type = types[col_idx]; - if (type.InternalType() == PhysicalType::STRUCT) { - // structs are recursively stored as a TupleDataLayout again - const auto &child_types = StructType::GetChildTypes(type); - vector child_type_vector; - child_type_vector.reserve(child_types.size()); - for (auto &ct : child_types) { - child_type_vector.emplace_back(ct.second); - } - if (!struct_layouts) { - struct_layouts = make_uniq>(); - } - auto struct_entry = struct_layouts->emplace(col_idx, TupleDataLayout()); - struct_entry.first->second.Initialize(std::move(child_type_vector), false, false); - all_constant = all_constant && struct_entry.first->second.AllConstant(); - } else { - all_constant = all_constant && TypeIsConstantSize(type.InternalType()); - } - } - - // This enables pointer swizzling for out-of-core computation. - if (heap_offset_p && !all_constant) { - heap_size_offset = row_width; - row_width += sizeof(uint32_t); - } - - // Data columns. No alignment required. - for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { - const auto &type = types[col_idx]; - offsets.push_back(row_width); - const auto internal_type = type.InternalType(); - if (TypeIsConstantSize(internal_type) || internal_type == PhysicalType::VARCHAR) { - row_width += GetTypeIdSize(type.InternalType()); - } else if (internal_type == PhysicalType::STRUCT) { - // Just get the size of the TupleDataLayout of the struct - row_width += GetStructLayout(col_idx).GetRowWidth(); - } else { - // Variable size types use pointers to the actual data (can be swizzled). - // Again, we would use sizeof(data_ptr_t), but this is not guaranteed to be equal to sizeof(idx_t). 
- row_width += sizeof(idx_t); - } - } - - // Alignment padding for aggregates -#ifndef DUCKDB_ALLOW_UNDEFINED - if (align) { - row_width = AlignValue(row_width); - } -#endif - data_width = row_width - flag_width; - - // Aggregate fields. - aggregates = std::move(aggregates_p); - for (auto &aggregate : aggregates) { - offsets.push_back(row_width); - row_width += aggregate.payload_size; -#ifndef DUCKDB_ALLOW_UNDEFINED - D_ASSERT(aggregate.payload_size == AlignValue(aggregate.payload_size)); -#endif - } - aggr_width = row_width - data_width - flag_width; - - // Alignment padding for the next row -#ifndef DUCKDB_ALLOW_UNDEFINED - if (align) { - row_width = AlignValue(row_width); - } -#endif - - has_destructor = false; - for (auto &aggr : GetAggregates()) { - if (aggr.function.destructor) { - has_destructor = true; - break; - } - } -} - -void TupleDataLayout::Initialize(vector types_p, bool align, bool heap_offset_p) { - Initialize(std::move(types_p), Aggregates(), align, heap_offset_p); -} - -void TupleDataLayout::Initialize(Aggregates aggregates_p, bool align, bool heap_offset_p) { - Initialize(vector(), std::move(aggregates_p), align, heap_offset_p); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -using ValidityBytes = TupleDataLayout::ValidityBytes; - -template -static constexpr idx_t TupleDataWithinListFixedSize() { - return sizeof(T); -} - -template <> -constexpr idx_t TupleDataWithinListFixedSize() { - return sizeof(uint32_t); -} - -template -static inline void TupleDataValueStore(const T &source, const data_ptr_t &row_location, const idx_t offset_in_row, - data_ptr_t &heap_location) { - Store(source, row_location + offset_in_row); -} - -template <> -inline void TupleDataValueStore(const string_t &source, const data_ptr_t &row_location, const idx_t offset_in_row, - data_ptr_t &heap_location) { - if (source.IsInlined()) { - Store(source, row_location + offset_in_row); - } else { - memcpy(heap_location, source.GetData(), source.GetSize()); - 
Store(string_t(const_char_ptr_cast(heap_location), source.GetSize()), row_location + offset_in_row); - heap_location += source.GetSize(); - } -} - -template -static inline void TupleDataWithinListValueStore(const T &source, const data_ptr_t &location, - data_ptr_t &heap_location) { - Store(source, location); -} - -template <> -inline void TupleDataWithinListValueStore(const string_t &source, const data_ptr_t &location, - data_ptr_t &heap_location) { - Store(source.GetSize(), location); - memcpy(heap_location, source.GetData(), source.GetSize()); - heap_location += source.GetSize(); -} - -template -static inline T TupleDataWithinListValueLoad(const data_ptr_t &location, data_ptr_t &heap_location) { - return Load(location); -} - -template <> -inline string_t TupleDataWithinListValueLoad(const data_ptr_t &location, data_ptr_t &heap_location) { - const auto size = Load(location); - string_t result(const_char_ptr_cast(heap_location), size); - heap_location += size; - return result; -} - -#ifdef DEBUG -static void ResetCombinedListData(vector &vector_data) { - for (auto &vd : vector_data) { - vd.combined_list_data = nullptr; - ResetCombinedListData(vd.children); - } -} -#endif - -void TupleDataCollection::ComputeHeapSizes(TupleDataChunkState &chunk_state, const DataChunk &new_chunk, - const SelectionVector &append_sel, const idx_t append_count) { -#ifdef DEBUG - ResetCombinedListData(chunk_state.vector_data); -#endif - - auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); - std::fill_n(heap_sizes, new_chunk.size(), 0); - - for (idx_t col_idx = 0; col_idx < new_chunk.ColumnCount(); col_idx++) { - auto &source_v = new_chunk.data[col_idx]; - auto &source_format = chunk_state.vector_data[col_idx]; - TupleDataCollection::ComputeHeapSizes(chunk_state.heap_sizes, source_v, source_format, append_sel, - append_count); - } -} - -static inline idx_t StringHeapSize(const string_t &val) { - return val.IsInlined() ? 
0 : val.GetSize(); -} - -void TupleDataCollection::ComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, const SelectionVector &append_sel, - const idx_t append_count) { - const auto type = source_v.GetType().InternalType(); - if (type != PhysicalType::VARCHAR && type != PhysicalType::STRUCT && type != PhysicalType::LIST) { - return; - } - - auto heap_sizes = FlatVector::GetData(heap_sizes_v); - - const auto &source_vector_data = source_format.unified; - const auto &source_sel = *source_vector_data.sel; - const auto &source_validity = source_vector_data.validity; - - switch (type) { - case PhysicalType::VARCHAR: { - // Only non-inlined strings are stored in the heap - const auto source_data = UnifiedVectorFormat::GetData(source_vector_data); - for (idx_t i = 0; i < append_count; i++) { - const auto source_idx = source_sel.get_index(append_sel.get_index(i)); - if (source_validity.RowIsValid(source_idx)) { - heap_sizes[i] += StringHeapSize(source_data[source_idx]); - } else { - heap_sizes[i] += StringHeapSize(NullValue()); - } - } - break; - } - case PhysicalType::STRUCT: { - // Recurse through the struct children - auto &struct_sources = StructVector::GetEntries(source_v); - for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { - const auto &struct_source = struct_sources[struct_col_idx]; - auto &struct_format = source_format.children[struct_col_idx]; - TupleDataCollection::ComputeHeapSizes(heap_sizes_v, *struct_source, struct_format, append_sel, - append_count); - } - break; - } - case PhysicalType::LIST: { - // Lists are stored entirely in the heap - for (idx_t i = 0; i < append_count; i++) { - auto source_idx = source_sel.get_index(append_sel.get_index(i)); - if (source_validity.RowIsValid(source_idx)) { - heap_sizes[i] += sizeof(uint64_t); // Size of the list - } - } - - // Recurse - D_ASSERT(source_format.children.size() == 1); - auto &child_source_v = 
ListVector::GetEntry(source_v); - auto &child_format = source_format.children[0]; - TupleDataCollection::WithinListHeapComputeSizes(heap_sizes_v, child_source_v, child_format, append_sel, - append_count, source_vector_data); - break; - } - default: - throw NotImplementedException("ComputeHeapSizes for %s", EnumUtil::ToString(source_v.GetType().id())); - } -} - -void TupleDataCollection::WithinListHeapComputeSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const UnifiedVectorFormat &list_data) { - auto type = source_v.GetType().InternalType(); - if (TypeIsConstantSize(type)) { - TupleDataCollection::ComputeFixedWithinListHeapSizes(heap_sizes_v, source_v, source_format, append_sel, - append_count, list_data); - return; - } - - switch (type) { - case PhysicalType::VARCHAR: - TupleDataCollection::StringWithinListComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, - append_count, list_data); - break; - case PhysicalType::STRUCT: - TupleDataCollection::StructWithinListComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, - append_count, list_data); - break; - case PhysicalType::LIST: - TupleDataCollection::ListWithinListComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, - append_count, list_data); - break; - default: - throw NotImplementedException("WithinListHeapComputeSizes for %s", EnumUtil::ToString(source_v.GetType().id())); - } -} - -void TupleDataCollection::ComputeFixedWithinListHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const UnifiedVectorFormat &list_data) { - // List data - const auto list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Target - auto heap_sizes = 
FlatVector::GetData(heap_sizes_v); - - D_ASSERT(TypeIsConstantSize(source_v.GetType().InternalType())); - const auto type_size = GetTypeIdSize(source_v.GetType().InternalType()); - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child - } - - // Get the current list length - const auto &list_length = list_entries[list_idx].length; - - // Size is validity mask and all values - auto &heap_size = heap_sizes[i]; - heap_size += ValidityBytes::SizeInBytes(list_length); - heap_size += list_length * type_size; - } -} - -void TupleDataCollection::StringWithinListComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const UnifiedVectorFormat &list_data) { - // Source - const auto &source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto data = UnifiedVectorFormat::GetData(source_data); - const auto &source_validity = source_data.validity; - - // List data - const auto list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Target - auto heap_sizes = FlatVector::GetData(heap_sizes_v); - - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child - } - - // Get the current list entry - const auto &list_entry = list_entries[list_idx]; - const auto &list_offset = list_entry.offset; - const auto &list_length = list_entry.length; - - // Size is validity mask and all string sizes - auto &heap_size = heap_sizes[i]; - heap_size += ValidityBytes::SizeInBytes(list_length); - heap_size += 
list_length * TupleDataWithinListFixedSize(); - - // Plus all the actual strings - for (idx_t child_i = 0; child_i < list_length; child_i++) { - const auto child_source_idx = source_sel.get_index(list_offset + child_i); - if (source_validity.RowIsValid(child_source_idx)) { - heap_size += data[child_source_idx].GetSize(); - } - } - } -} - -void TupleDataCollection::StructWithinListComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const UnifiedVectorFormat &list_data) { - // List data - const auto list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Target - auto heap_sizes = FlatVector::GetData(heap_sizes_v); - - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child - } - - // Get the current list length - const auto &list_length = list_entries[list_idx].length; - - // Size is just the validity mask - heap_sizes[i] += ValidityBytes::SizeInBytes(list_length); - } - - // Recurse - auto &struct_sources = StructVector::GetEntries(source_v); - for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { - auto &struct_source = *struct_sources[struct_col_idx]; - auto &struct_format = source_format.children[struct_col_idx]; - TupleDataCollection::WithinListHeapComputeSizes(heap_sizes_v, struct_source, struct_format, append_sel, - append_count, list_data); - } -} - -static void ApplySliceRecursive(const Vector &source_v, TupleDataVectorFormat &source_format, - const SelectionVector &combined_sel, const idx_t count) { - D_ASSERT(source_format.combined_list_data); - auto &combined_list_data = *source_format.combined_list_data; - - 
combined_list_data.selection_data = source_format.original_sel->Slice(combined_sel, count); - source_format.unified.owned_sel.Initialize(combined_list_data.selection_data); - source_format.unified.sel = &source_format.unified.owned_sel; - - if (source_v.GetType().InternalType() == PhysicalType::STRUCT) { - // We have to apply it to the child vectors too - auto &struct_sources = StructVector::GetEntries(source_v); - for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { - auto &struct_source = *struct_sources[struct_col_idx]; - auto &struct_format = source_format.children[struct_col_idx]; -#ifdef DEBUG - D_ASSERT(!struct_format.combined_list_data); -#endif - if (!struct_format.combined_list_data) { - struct_format.combined_list_data = make_uniq(); - } - ApplySliceRecursive(struct_source, struct_format, *source_format.unified.sel, count); - } - } -} - -void TupleDataCollection::ListWithinListComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const UnifiedVectorFormat &list_data) { - // List data (of the list Vector that "source_v" is in) - const auto list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Child list ("source_v") - const auto &child_list_data = source_format.unified; - const auto child_list_sel = *child_list_data.sel; - const auto child_list_entries = UnifiedVectorFormat::GetData(child_list_data); - const auto &child_list_validity = child_list_data.validity; - - // Figure out actual child list size (can differ from ListVector::GetListSize if dict/const vector), - // and we cannot use ConstantVector::ZeroSelectionVector because it may need to be longer than STANDARD_VECTOR_SIZE - idx_t sum_of_sizes = 0; - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = 
list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; - } - const auto &list_entry = list_entries[list_idx]; - const auto &list_offset = list_entry.offset; - const auto &list_length = list_entry.length; - - for (idx_t child_i = 0; child_i < list_length; child_i++) { - const auto child_list_idx = child_list_sel.get_index(list_offset + child_i); - if (!child_list_validity.RowIsValid(child_list_idx)) { - continue; - } - - const auto &child_list_entry = child_list_entries[child_list_idx]; - const auto &child_list_length = child_list_entry.length; - - sum_of_sizes += child_list_length; - } - } - const auto child_list_child_count = MaxValue(sum_of_sizes, ListVector::GetListSize(source_v)); - - // Target - auto heap_sizes = FlatVector::GetData(heap_sizes_v); - - // Construct combined list entries and a selection vector for the child list child - auto &child_format = source_format.children[0]; -#ifdef DEBUG - // In debug mode this should be deleted by ResetCombinedListData - D_ASSERT(!child_format.combined_list_data); -#endif - if (!child_format.combined_list_data) { - child_format.combined_list_data = make_uniq(); - } - auto &combined_list_data = *child_format.combined_list_data; - auto &combined_list_entries = combined_list_data.combined_list_entries; - SelectionVector combined_sel(child_list_child_count); - for (idx_t i = 0; i < child_list_child_count; i++) { - combined_sel.set_index(i, 0); - } - - idx_t combined_list_offset = 0; - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child list - } - - // Get the current list entry - const auto &list_entry = list_entries[list_idx]; - const auto &list_offset = list_entry.offset; - const auto &list_length = list_entry.length; - - // Size is the validity mask and the list sizes - auto &heap_size = 
heap_sizes[i]; - heap_size += ValidityBytes::SizeInBytes(list_length); - heap_size += list_length * sizeof(uint64_t); - - idx_t child_list_size = 0; - for (idx_t child_i = 0; child_i < list_length; child_i++) { - const auto child_list_idx = child_list_sel.get_index(list_offset + child_i); - const auto &child_list_entry = child_list_entries[child_list_idx]; - if (child_list_validity.RowIsValid(child_list_idx)) { - const auto &child_list_offset = child_list_entry.offset; - const auto &child_list_length = child_list_entry.length; - - // Add this child's list entries to the combined selection vector - for (idx_t child_value_i = 0; child_value_i < child_list_length; child_value_i++) { - auto idx = combined_list_offset + child_list_size + child_value_i; - auto loc = child_list_offset + child_value_i; - combined_sel.set_index(idx, loc); - } - - child_list_size += child_list_length; - } - } - - // Combine the child list entries into one - combined_list_entries[list_idx] = {combined_list_offset, child_list_size}; - combined_list_offset += child_list_size; - } - - // Create a combined child_list_data to be used as list_data in the recursion - auto &combined_child_list_data = combined_list_data.combined_data; - combined_child_list_data.sel = list_data.sel; - combined_child_list_data.data = data_ptr_cast(combined_list_entries); - combined_child_list_data.validity = list_data.validity; - - // Combine the selection vectors - D_ASSERT(source_format.children.size() == 1); - auto &child_source = ListVector::GetEntry(source_v); - ApplySliceRecursive(child_source, child_format, combined_sel, child_list_child_count); - - // Recurse - TupleDataCollection::WithinListHeapComputeSizes(heap_sizes_v, child_source, child_format, append_sel, append_count, - combined_child_list_data); -} - -void TupleDataCollection::Scatter(TupleDataChunkState &chunk_state, const DataChunk &new_chunk, - const SelectionVector &append_sel, const idx_t append_count) const { - const auto row_locations = 
FlatVector::GetData(chunk_state.row_locations); - - // Set the validity mask for each row before inserting data - const auto validity_bytes = ValidityBytes::SizeInBytes(layout.ColumnCount()); - for (idx_t i = 0; i < append_count; i++) { - FastMemset(row_locations[i], ~0, validity_bytes); - } - - if (!layout.AllConstant()) { - // Set the heap size for each row - const auto heap_size_offset = layout.GetHeapSizeOffset(); - const auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); - for (idx_t i = 0; i < append_count; i++) { - Store(heap_sizes[i], row_locations[i] + heap_size_offset); - } - } - - // Write the data - for (const auto &col_idx : chunk_state.column_ids) { - Scatter(chunk_state, new_chunk.data[col_idx], col_idx, append_sel, append_count); - } -} - -void TupleDataCollection::Scatter(TupleDataChunkState &chunk_state, const Vector &source, const column_t column_id, - const SelectionVector &append_sel, const idx_t append_count) const { - const auto &scatter_function = scatter_functions[column_id]; - scatter_function.function(source, chunk_state.vector_data[column_id], append_sel, append_count, layout, - chunk_state.row_locations, chunk_state.heap_locations, column_id, - chunk_state.vector_data[column_id].unified, scatter_function.child_functions); -} - -template -static void TupleDataTemplatedScatter(const Vector &source, const TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, - Vector &heap_locations, const idx_t col_idx, const UnifiedVectorFormat &dummy_arg, - const vector &child_functions) { - // Source - const auto &source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto data = UnifiedVectorFormat::GetData(source_data); - const auto &validity = source_data.validity; - - // Target - auto target_locations = FlatVector::GetData(row_locations); - auto target_heap_locations = 
FlatVector::GetData(heap_locations); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - const auto offset_in_row = layout.GetOffsets()[col_idx]; - if (validity.AllValid()) { - for (idx_t i = 0; i < append_count; i++) { - const auto source_idx = source_sel.get_index(append_sel.get_index(i)); - TupleDataValueStore(data[source_idx], target_locations[i], offset_in_row, target_heap_locations[i]); - } - } else { - for (idx_t i = 0; i < append_count; i++) { - const auto source_idx = source_sel.get_index(append_sel.get_index(i)); - if (validity.RowIsValid(source_idx)) { - TupleDataValueStore(data[source_idx], target_locations[i], offset_in_row, target_heap_locations[i]); - } else { - TupleDataValueStore(NullValue(), target_locations[i], offset_in_row, target_heap_locations[i]); - ValidityBytes(target_locations[i]).SetInvalidUnsafe(entry_idx, idx_in_entry); - } - } - } -} - -static void TupleDataStructScatter(const Vector &source, const TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, Vector &heap_locations, - const idx_t col_idx, const UnifiedVectorFormat &dummy_arg, - const vector &child_functions) { - // Source - const auto &source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto &validity = source_data.validity; - - // Target - auto target_locations = FlatVector::GetData(row_locations); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - // Set validity of the STRUCT in this layout - if (!validity.AllValid()) { - for (idx_t i = 0; i < append_count; i++) { - const auto source_idx = source_sel.get_index(append_sel.get_index(i)); - if (!validity.RowIsValid(source_idx)) { - ValidityBytes(target_locations[i]).SetInvalidUnsafe(entry_idx, 
idx_in_entry); - } - } - } - - // Create a Vector of pointers to the TupleDataLayout of the STRUCT - Vector struct_row_locations(LogicalType::POINTER, append_count); - auto struct_target_locations = FlatVector::GetData(struct_row_locations); - const auto offset_in_row = layout.GetOffsets()[col_idx]; - for (idx_t i = 0; i < append_count; i++) { - struct_target_locations[i] = target_locations[i] + offset_in_row; - } - - const auto &struct_layout = layout.GetStructLayout(col_idx); - auto &struct_sources = StructVector::GetEntries(source); - D_ASSERT(struct_layout.ColumnCount() == struct_sources.size()); - - // Set the validity of the entries within the STRUCTs - const auto validity_bytes = ValidityBytes::SizeInBytes(struct_layout.ColumnCount()); - for (idx_t i = 0; i < append_count; i++) { - memset(struct_target_locations[i], ~0, validity_bytes); - } - - // Recurse through the struct children - for (idx_t struct_col_idx = 0; struct_col_idx < struct_layout.ColumnCount(); struct_col_idx++) { - auto &struct_source = *struct_sources[struct_col_idx]; - const auto &struct_source_format = source_format.children[struct_col_idx]; - const auto &struct_scatter_function = child_functions[struct_col_idx]; - struct_scatter_function.function(struct_source, struct_source_format, append_sel, append_count, struct_layout, - struct_row_locations, heap_locations, struct_col_idx, dummy_arg, - struct_scatter_function.child_functions); - } -} - -static void TupleDataListScatter(const Vector &source, const TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, Vector &heap_locations, - const idx_t col_idx, const UnifiedVectorFormat &dummy_arg, - const vector &child_functions) { - // Source - const auto &source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto data = UnifiedVectorFormat::GetData(source_data); - const auto &validity = 
source_data.validity; - - // Target - auto target_locations = FlatVector::GetData(row_locations); - auto target_heap_locations = FlatVector::GetData(heap_locations); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - // Set validity of the LIST in this layout, and store pointer to where it's stored - const auto offset_in_row = layout.GetOffsets()[col_idx]; - for (idx_t i = 0; i < append_count; i++) { - const auto source_idx = source_sel.get_index(append_sel.get_index(i)); - if (validity.RowIsValid(source_idx)) { - auto &target_heap_location = target_heap_locations[i]; - Store(target_heap_location, target_locations[i] + offset_in_row); - - // Store list length and skip over it - Store(data[source_idx].length, target_heap_location); - target_heap_location += sizeof(uint64_t); - } else { - ValidityBytes(target_locations[i]).SetInvalidUnsafe(entry_idx, idx_in_entry); - } - } - - // Recurse - D_ASSERT(child_functions.size() == 1); - auto &child_source = ListVector::GetEntry(source); - auto &child_format = source_format.children[0]; - const auto &child_function = child_functions[0]; - child_function.function(child_source, child_format, append_sel, append_count, layout, row_locations, heap_locations, - col_idx, source_format.unified, child_function.child_functions); -} - -template -static void TupleDataTemplatedWithinListScatter(const Vector &source, const TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, - Vector &heap_locations, const idx_t col_idx, - const UnifiedVectorFormat &list_data, - const vector &child_functions) { - // Source - const auto &source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto data = UnifiedVectorFormat::GetData(source_data); - const auto &source_validity = source_data.validity; - - // List data - const 
auto list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Target - auto target_heap_locations = FlatVector::GetData(heap_locations); - - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child - } - - // Get the current list entry - const auto &list_entry = list_entries[list_idx]; - const auto &list_offset = list_entry.offset; - const auto &list_length = list_entry.length; - - // Initialize validity mask and skip heap pointer over it - auto &target_heap_location = target_heap_locations[i]; - ValidityBytes child_mask(target_heap_location); - child_mask.SetAllValid(list_length); - target_heap_location += ValidityBytes::SizeInBytes(list_length); - - // Get the start to the fixed-size data and skip the heap pointer over it - const auto child_data_location = target_heap_location; - target_heap_location += list_length * TupleDataWithinListFixedSize(); - - // Store the data and validity belonging to this list entry - for (idx_t child_i = 0; child_i < list_length; child_i++) { - const auto child_source_idx = source_sel.get_index(list_offset + child_i); - if (source_validity.RowIsValid(child_source_idx)) { - TupleDataWithinListValueStore(data[child_source_idx], - child_data_location + child_i * TupleDataWithinListFixedSize(), - target_heap_location); - } else { - child_mask.SetInvalidUnsafe(child_i); - } - } - } -} - -static void TupleDataStructWithinListScatter(const Vector &source, const TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, - Vector &heap_locations, const idx_t col_idx, - const UnifiedVectorFormat &list_data, - const vector &child_functions) { - // Source - const auto 
&source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto &source_validity = source_data.validity; - - // List data - const auto list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Target - auto target_heap_locations = FlatVector::GetData(heap_locations); - - // Initialize the validity of the STRUCTs - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child - } - - // Get the current list entry - const auto &list_entry = list_entries[list_idx]; - const auto &list_offset = list_entry.offset; - const auto &list_length = list_entry.length; - - // Initialize validity mask and skip the heap pointer over it - auto &target_heap_location = target_heap_locations[i]; - ValidityBytes child_mask(target_heap_location); - child_mask.SetAllValid(list_length); - target_heap_location += ValidityBytes::SizeInBytes(list_length); - - // Store the validity belonging to this list entry - for (idx_t child_i = 0; child_i < list_length; child_i++) { - const auto child_source_idx = source_sel.get_index(list_offset + child_i); - if (!source_validity.RowIsValid(child_source_idx)) { - child_mask.SetInvalidUnsafe(child_i); - } - } - } - - // Recurse through the children - auto &struct_sources = StructVector::GetEntries(source); - for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { - auto &struct_source = *struct_sources[struct_col_idx]; - auto &struct_format = source_format.children[struct_col_idx]; - const auto &struct_scatter_function = child_functions[struct_col_idx]; - struct_scatter_function.function(struct_source, struct_format, append_sel, append_count, layout, row_locations, - heap_locations, struct_col_idx, list_data, - 
struct_scatter_function.child_functions); - } -} - -static void TupleDataListWithinListScatter(const Vector &child_list, const TupleDataVectorFormat &child_list_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, - Vector &heap_locations, const idx_t col_idx, - const UnifiedVectorFormat &list_data, - const vector &child_functions) { - // List data (of the list Vector that "child_list" is in) - const auto list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Child list - const auto &child_list_data = child_list_format.unified; - const auto child_list_sel = *child_list_data.sel; - const auto child_list_entries = UnifiedVectorFormat::GetData(child_list_data); - const auto &child_list_validity = child_list_data.validity; - - // Target - auto target_heap_locations = FlatVector::GetData(heap_locations); - - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child list - } - - // Get the current list entry - const auto &list_entry = list_entries[list_idx]; - const auto &list_offset = list_entry.offset; - const auto &list_length = list_entry.length; - - // Initialize validity mask and skip heap pointer over it - auto &target_heap_location = target_heap_locations[i]; - ValidityBytes child_mask(target_heap_location); - child_mask.SetAllValid(list_length); - target_heap_location += ValidityBytes::SizeInBytes(list_length); - - // Get the start to the fixed-size data and skip the heap pointer over it - const auto child_data_location = target_heap_location; - target_heap_location += list_length * sizeof(uint64_t); - - for (idx_t child_i = 0; child_i < list_length; child_i++) { - const auto child_list_idx = 
child_list_sel.get_index(list_offset + child_i); - if (child_list_validity.RowIsValid(child_list_idx)) { - const auto &child_list_length = child_list_entries[child_list_idx].length; - Store(child_list_length, child_data_location + child_i * sizeof(uint64_t)); - } else { - child_mask.SetInvalidUnsafe(child_i); - } - } - } - - // Recurse - D_ASSERT(child_functions.size() == 1); - auto &child_vec = ListVector::GetEntry(child_list); - auto &child_format = child_list_format.children[0]; - auto &combined_child_list_data = child_format.combined_list_data->combined_data; - const auto &child_function = child_functions[0]; - child_function.function(child_vec, child_format, append_sel, append_count, layout, row_locations, heap_locations, - col_idx, combined_child_list_data, child_function.child_functions); -} - -template -tuple_data_scatter_function_t TupleDataGetScatterFunction(bool within_list) { - return within_list ? TupleDataTemplatedWithinListScatter : TupleDataTemplatedScatter; -} - -TupleDataScatterFunction TupleDataCollection::GetScatterFunction(const LogicalType &type, bool within_list) { - TupleDataScatterFunction result; - switch (type.InternalType()) { - case PhysicalType::BOOL: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::INT8: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::INT16: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::INT32: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::INT64: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::INT128: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::UINT8: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::UINT16: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::UINT32: - 
result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::UINT64: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::FLOAT: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::DOUBLE: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::INTERVAL: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::VARCHAR: - result.function = TupleDataGetScatterFunction(within_list); - break; - case PhysicalType::STRUCT: { - result.function = within_list ? TupleDataStructWithinListScatter : TupleDataStructScatter; - for (const auto &child_type : StructType::GetChildTypes(type)) { - result.child_functions.push_back(GetScatterFunction(child_type.second, within_list)); - } - break; - } - case PhysicalType::LIST: - result.function = within_list ? TupleDataListWithinListScatter : TupleDataListScatter; - result.child_functions.emplace_back(GetScatterFunction(ListType::GetChildType(type), true)); - break; - default: - throw InternalException("Unsupported type for TupleDataCollection::GetScatterFunction"); - } - return result; -} - -void TupleDataCollection::Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, - DataChunk &result, const SelectionVector &target_sel) const { - D_ASSERT(result.ColumnCount() == layout.ColumnCount()); - vector column_ids; - column_ids.reserve(layout.ColumnCount()); - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - column_ids.emplace_back(col_idx); - } - Gather(row_locations, scan_sel, scan_count, column_ids, result, target_sel); -} - -void TupleDataCollection::Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, - const vector &column_ids, DataChunk &result, - const SelectionVector &target_sel) const { - for (idx_t col_idx = 0; col_idx < column_ids.size(); col_idx++) { - 
Gather(row_locations, scan_sel, scan_count, column_ids[col_idx], result.data[col_idx], target_sel); - } -} - -void TupleDataCollection::Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, - const column_t column_id, Vector &result, const SelectionVector &target_sel) const { - const auto &gather_function = gather_functions[column_id]; - gather_function.function(layout, row_locations, column_id, scan_sel, scan_count, result, target_sel, result, - gather_function.child_functions); -} - -template -static void TupleDataTemplatedGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, - const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, Vector &dummy_vector, - const vector &child_functions) { - // Source - auto source_locations = FlatVector::GetData(row_locations); - - // Target - auto target_data = FlatVector::GetData(target); - auto &target_validity = FlatVector::Validity(target); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - const auto offset_in_row = layout.GetOffsets()[col_idx]; - for (idx_t i = 0; i < scan_count; i++) { - const auto &source_row = source_locations[scan_sel.get_index(i)]; - const auto target_idx = target_sel.get_index(i); - ValidityBytes row_mask(source_row); - if (row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - target_data[target_idx] = Load(source_row + offset_in_row); - } else { - target_validity.SetInvalid(target_idx); - } - } -} - -static void TupleDataStructGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, - const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, Vector &dummy_vector, - const vector &child_functions) { - // Source - auto source_locations = FlatVector::GetData(row_locations); - - // Target - auto 
&target_validity = FlatVector::Validity(target); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - // Get validity of the struct and create a Vector of pointers to the start of the TupleDataLayout of the STRUCT - Vector struct_row_locations(LogicalType::POINTER); - auto struct_source_locations = FlatVector::GetData(struct_row_locations); - const auto offset_in_row = layout.GetOffsets()[col_idx]; - for (idx_t i = 0; i < scan_count; i++) { - const auto source_idx = scan_sel.get_index(i); - const auto &source_row = source_locations[source_idx]; - - // Set the validity - ValidityBytes row_mask(source_row); - if (!row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - const auto target_idx = target_sel.get_index(i); - target_validity.SetInvalid(target_idx); - } - - // Set the pointer - struct_source_locations[source_idx] = source_row + offset_in_row; - } - - // Get the struct layout and struct entries - const auto &struct_layout = layout.GetStructLayout(col_idx); - auto &struct_targets = StructVector::GetEntries(target); - D_ASSERT(struct_layout.ColumnCount() == struct_targets.size()); - - // Recurse through the struct children - for (idx_t struct_col_idx = 0; struct_col_idx < struct_layout.ColumnCount(); struct_col_idx++) { - auto &struct_target = *struct_targets[struct_col_idx]; - const auto &struct_gather_function = child_functions[struct_col_idx]; - struct_gather_function.function(struct_layout, struct_row_locations, struct_col_idx, scan_sel, scan_count, - struct_target, target_sel, dummy_vector, - struct_gather_function.child_functions); - } -} - -static void TupleDataListGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, - const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, Vector &dummy_vector, - const vector &child_functions) { - // Source - auto 
source_locations = FlatVector::GetData(row_locations); - - // Target - auto target_list_entries = FlatVector::GetData(target); - auto &target_validity = FlatVector::Validity(target); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - // Load pointers to the data from the row - Vector heap_locations(LogicalType::POINTER); - auto source_heap_locations = FlatVector::GetData(heap_locations); - auto &source_heap_validity = FlatVector::Validity(heap_locations); - - const auto offset_in_row = layout.GetOffsets()[col_idx]; - uint64_t target_list_offset = 0; - for (idx_t i = 0; i < scan_count; i++) { - const auto source_idx = scan_sel.get_index(i); - const auto target_idx = target_sel.get_index(i); - - const auto &source_row = source_locations[source_idx]; - ValidityBytes row_mask(source_row); - if (row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - auto &source_heap_location = source_heap_locations[source_idx]; - source_heap_location = Load(source_row + offset_in_row); - - // Load list size and skip over - const auto list_length = Load(source_heap_location); - source_heap_location += sizeof(uint64_t); - - // Initialize list entry, and increment offset - target_list_entries[target_idx] = {target_list_offset, list_length}; - target_list_offset += list_length; - } else { - source_heap_validity.SetInvalid(source_idx); - target_validity.SetInvalid(target_idx); - } - } - auto list_size_before = ListVector::GetListSize(target); - ListVector::Reserve(target, list_size_before + target_list_offset); - ListVector::SetListSize(target, list_size_before + target_list_offset); - - // Recurse - D_ASSERT(child_functions.size() == 1); - const auto &child_function = child_functions[0]; - child_function.function(layout, heap_locations, list_size_before, scan_sel, scan_count, - ListVector::GetEntry(target), target_sel, target, child_function.child_functions); -} - 
-template -static void TupleDataTemplatedWithinListGather(const TupleDataLayout &layout, Vector &heap_locations, - const idx_t list_size_before, const SelectionVector &scan_sel, - const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, Vector &list_vector, - const vector &child_functions) { - // Source - auto source_heap_locations = FlatVector::GetData(heap_locations); - auto &source_heap_validity = FlatVector::Validity(heap_locations); - - // Target - auto target_data = FlatVector::GetData(target); - auto &target_validity = FlatVector::Validity(target); - - // List parent - const auto list_entries = FlatVector::GetData(list_vector); - - uint64_t target_offset = list_size_before; - for (idx_t i = 0; i < scan_count; i++) { - const auto source_idx = scan_sel.get_index(i); - if (!source_heap_validity.RowIsValid(source_idx)) { - continue; - } - - const auto &list_length = list_entries[target_sel.get_index(i)].length; - - // Initialize validity mask - auto &source_heap_location = source_heap_locations[source_idx]; - ValidityBytes source_mask(source_heap_location); - source_heap_location += ValidityBytes::SizeInBytes(list_length); - - // Get the start to the fixed-size data and skip the heap pointer over it - const auto source_data_location = source_heap_location; - source_heap_location += list_length * TupleDataWithinListFixedSize(); - - // Load the child validity and data belonging to this list entry - for (idx_t child_i = 0; child_i < list_length; child_i++) { - if (source_mask.RowIsValidUnsafe(child_i)) { - target_data[target_offset + child_i] = TupleDataWithinListValueLoad( - source_data_location + child_i * TupleDataWithinListFixedSize(), source_heap_location); - } else { - target_validity.SetInvalid(target_offset + child_i); - } - } - target_offset += list_length; - } -} - -static void TupleDataStructWithinListGather(const TupleDataLayout &layout, Vector &heap_locations, - const idx_t list_size_before, const SelectionVector &scan_sel, - const 
idx_t scan_count, Vector &target, const SelectionVector &target_sel, - Vector &list_vector, - const vector &child_functions) { - // Source - auto source_heap_locations = FlatVector::GetData(heap_locations); - auto &source_heap_validity = FlatVector::Validity(heap_locations); - - // Target - auto &target_validity = FlatVector::Validity(target); - - // List parent - const auto list_entries = FlatVector::GetData(list_vector); - - uint64_t target_offset = list_size_before; - for (idx_t i = 0; i < scan_count; i++) { - const auto source_idx = scan_sel.get_index(i); - if (!source_heap_validity.RowIsValid(source_idx)) { - continue; - } - - const auto &list_length = list_entries[target_sel.get_index(i)].length; - - // Initialize validity mask and skip over it - auto &source_heap_location = source_heap_locations[source_idx]; - ValidityBytes source_mask(source_heap_location); - source_heap_location += ValidityBytes::SizeInBytes(list_length); - - // Load the child validity belonging to this list entry - for (idx_t child_i = 0; child_i < list_length; child_i++) { - if (!source_mask.RowIsValidUnsafe(child_i)) { - target_validity.SetInvalid(target_offset + child_i); - } - } - target_offset += list_length; - } - - // Recurse - auto &struct_targets = StructVector::GetEntries(target); - for (idx_t struct_col_idx = 0; struct_col_idx < struct_targets.size(); struct_col_idx++) { - auto &struct_target = *struct_targets[struct_col_idx]; - const auto &struct_gather_function = child_functions[struct_col_idx]; - struct_gather_function.function(layout, heap_locations, list_size_before, scan_sel, scan_count, struct_target, - target_sel, list_vector, struct_gather_function.child_functions); - } -} - -static void TupleDataListWithinListGather(const TupleDataLayout &layout, Vector &heap_locations, - const idx_t list_size_before, const SelectionVector &scan_sel, - const idx_t scan_count, Vector &target, const SelectionVector &target_sel, - Vector &list_vector, const vector &child_functions) { - 
// Source - auto source_heap_locations = FlatVector::GetData(heap_locations); - auto &source_heap_validity = FlatVector::Validity(heap_locations); - - // Target - auto target_list_entries = FlatVector::GetData(target); - auto &target_validity = FlatVector::Validity(target); - const auto child_list_size_before = ListVector::GetListSize(target); - - // List parent - const auto list_entries = FlatVector::GetData(list_vector); - - // We need to create a vector that has the combined list sizes (hugeint_t has same size as list_entry_t) - Vector combined_list_vector(LogicalType::HUGEINT); - auto combined_list_entries = FlatVector::GetData(combined_list_vector); - - uint64_t target_offset = list_size_before; - uint64_t target_child_offset = child_list_size_before; - for (idx_t i = 0; i < scan_count; i++) { - const auto source_idx = scan_sel.get_index(i); - if (!source_heap_validity.RowIsValid(source_idx)) { - continue; - } - - const auto &list_length = list_entries[target_sel.get_index(i)].length; - - // Initialize validity mask and skip over it - auto &source_heap_location = source_heap_locations[source_idx]; - ValidityBytes source_mask(source_heap_location); - source_heap_location += ValidityBytes::SizeInBytes(list_length); - - // Get the start to the fixed-size data and skip the heap pointer over it - const auto source_data_location = source_heap_location; - source_heap_location += list_length * sizeof(uint64_t); - - // Set the offset of the combined list entry - auto &combined_list_entry = combined_list_entries[target_sel.get_index(i)]; - combined_list_entry.offset = target_child_offset; - - // Load the child validity and data belonging to this list entry - for (idx_t child_i = 0; child_i < list_length; child_i++) { - if (source_mask.RowIsValidUnsafe(child_i)) { - auto &target_list_entry = target_list_entries[target_offset + child_i]; - target_list_entry.offset = target_child_offset; - target_list_entry.length = Load(source_data_location + child_i * sizeof(uint64_t)); 
- target_child_offset += target_list_entry.length; - } else { - target_validity.SetInvalid(target_offset + child_i); - } - } - - // Set the length of the combined list entry - combined_list_entry.length = target_child_offset - combined_list_entry.offset; - - target_offset += list_length; - } - ListVector::Reserve(target, target_child_offset); - ListVector::SetListSize(target, target_child_offset); - - // Recurse - D_ASSERT(child_functions.size() == 1); - const auto &child_function = child_functions[0]; - child_function.function(layout, heap_locations, child_list_size_before, scan_sel, scan_count, - ListVector::GetEntry(target), target_sel, combined_list_vector, - child_function.child_functions); -} - -template -tuple_data_gather_function_t TupleDataGetGatherFunction(bool within_list) { - return within_list ? TupleDataTemplatedWithinListGather : TupleDataTemplatedGather; -} - -TupleDataGatherFunction TupleDataCollection::GetGatherFunction(const LogicalType &type, bool within_list) { - TupleDataGatherFunction result; - switch (type.InternalType()) { - case PhysicalType::BOOL: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::INT8: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::INT16: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::INT32: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::INT64: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::INT128: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::UINT8: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::UINT16: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::UINT32: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::UINT64: - result.function = 
TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::FLOAT: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::DOUBLE: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::INTERVAL: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::VARCHAR: - result.function = TupleDataGetGatherFunction(within_list); - break; - case PhysicalType::STRUCT: { - result.function = within_list ? TupleDataStructWithinListGather : TupleDataStructGather; - for (const auto &child_type : StructType::GetChildTypes(type)) { - result.child_functions.push_back(GetGatherFunction(child_type.second, within_list)); - } - break; - } - case PhysicalType::LIST: - result.function = within_list ? TupleDataListWithinListGather : TupleDataListGather; - result.child_functions.push_back(GetGatherFunction(ListType::GetChildType(type), true)); - break; - default: - throw InternalException("Unsupported type for TupleDataCollection::GetGatherFunction"); - } - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -TupleDataChunkPart::TupleDataChunkPart(mutex &lock_p) : lock(lock_p) { -} - -void SwapTupleDataChunkPart(TupleDataChunkPart &a, TupleDataChunkPart &b) { - std::swap(a.row_block_index, b.row_block_index); - std::swap(a.row_block_offset, b.row_block_offset); - std::swap(a.heap_block_index, b.heap_block_index); - std::swap(a.heap_block_offset, b.heap_block_offset); - std::swap(a.base_heap_ptr, b.base_heap_ptr); - std::swap(a.total_heap_size, b.total_heap_size); - std::swap(a.count, b.count); - std::swap(a.lock, b.lock); -} - -TupleDataChunkPart::TupleDataChunkPart(TupleDataChunkPart &&other) noexcept : lock((other.lock)) { - SwapTupleDataChunkPart(*this, other); -} - -TupleDataChunkPart &TupleDataChunkPart::operator=(TupleDataChunkPart &&other) noexcept { - SwapTupleDataChunkPart(*this, other); - return *this; -} - -TupleDataChunk::TupleDataChunk() : 
count(0), lock(make_unsafe_uniq()) { - parts.reserve(2); -} - -static inline void SwapTupleDataChunk(TupleDataChunk &a, TupleDataChunk &b) noexcept { - std::swap(a.parts, b.parts); - std::swap(a.row_block_ids, b.row_block_ids); - std::swap(a.heap_block_ids, b.heap_block_ids); - std::swap(a.count, b.count); - std::swap(a.lock, b.lock); -} - -TupleDataChunk::TupleDataChunk(TupleDataChunk &&other) noexcept { - SwapTupleDataChunk(*this, other); -} - -TupleDataChunk &TupleDataChunk::operator=(TupleDataChunk &&other) noexcept { - SwapTupleDataChunk(*this, other); - return *this; -} - -void TupleDataChunk::AddPart(TupleDataChunkPart &&part, const TupleDataLayout &layout) { - count += part.count; - row_block_ids.insert(part.row_block_index); - if (!layout.AllConstant() && part.total_heap_size > 0) { - heap_block_ids.insert(part.heap_block_index); - } - part.lock = *lock; - parts.emplace_back(std::move(part)); -} - -void TupleDataChunk::Verify() const { -#ifdef DEBUG - idx_t total_count = 0; - for (const auto &part : parts) { - total_count += part.count; - } - D_ASSERT(this->count == total_count); - D_ASSERT(this->count <= STANDARD_VECTOR_SIZE); -#endif -} - -void TupleDataChunk::MergeLastChunkPart(const TupleDataLayout &layout) { - if (parts.size() < 2) { - return; - } - - auto &second_to_last = parts[parts.size() - 2]; - auto &last = parts[parts.size() - 1]; - - auto rows_align = - last.row_block_index == second_to_last.row_block_index && - last.row_block_offset == second_to_last.row_block_offset + second_to_last.count * layout.GetRowWidth(); - - if (!rows_align) { // If rows don't align we can never merge - return; - } - - if (layout.AllConstant()) { // No heap and rows align - merge - second_to_last.count += last.count; - parts.pop_back(); - return; - } - - if (last.heap_block_index == second_to_last.heap_block_index && - last.heap_block_offset == second_to_last.heap_block_index + second_to_last.total_heap_size && - last.base_heap_ptr == second_to_last.base_heap_ptr) { 
// There is a heap and it aligns - merge - second_to_last.total_heap_size += last.total_heap_size; - second_to_last.count += last.count; - parts.pop_back(); - } -} - -TupleDataSegment::TupleDataSegment(shared_ptr allocator_p) - : allocator(std::move(allocator_p)), count(0), data_size(0) { -} - -TupleDataSegment::~TupleDataSegment() { - lock_guard guard(pinned_handles_lock); - pinned_row_handles.clear(); - pinned_heap_handles.clear(); - allocator = nullptr; -} - -void SwapTupleDataSegment(TupleDataSegment &a, TupleDataSegment &b) { - std::swap(a.allocator, b.allocator); - std::swap(a.chunks, b.chunks); - std::swap(a.count, b.count); - std::swap(a.data_size, b.data_size); - std::swap(a.pinned_row_handles, b.pinned_row_handles); - std::swap(a.pinned_heap_handles, b.pinned_heap_handles); -} - -TupleDataSegment::TupleDataSegment(TupleDataSegment &&other) noexcept { - SwapTupleDataSegment(*this, other); -} - -TupleDataSegment &TupleDataSegment::operator=(TupleDataSegment &&other) noexcept { - SwapTupleDataSegment(*this, other); - return *this; -} - -idx_t TupleDataSegment::ChunkCount() const { - return chunks.size(); -} - -idx_t TupleDataSegment::SizeInBytes() const { - return data_size; -} - -void TupleDataSegment::Unpin() { - lock_guard guard(pinned_handles_lock); - pinned_row_handles.clear(); - pinned_heap_handles.clear(); -} - -void TupleDataSegment::Verify() const { -#ifdef DEBUG - const auto &layout = allocator->GetLayout(); - - idx_t total_count = 0; - idx_t total_size = 0; - for (const auto &chunk : chunks) { - chunk.Verify(); - total_count += chunk.count; - - total_size += chunk.count * layout.GetRowWidth(); - if (!layout.AllConstant()) { - for (const auto &part : chunk.parts) { - total_size += part.total_heap_size; - } - } - } - D_ASSERT(total_count == this->count); - D_ASSERT(total_size == this->data_size); -#endif -} - -void TupleDataSegment::VerifyEverythingPinned() const { -#ifdef DEBUG - D_ASSERT(pinned_row_handles.size() == allocator->RowBlockCount()); - 
D_ASSERT(pinned_heap_handles.size() == allocator->HeapBlockCount()); -#endif -} - -} // namespace duckdb - - - - -namespace duckdb { - -SelectionData::SelectionData(idx_t count) { - owned_data = make_unsafe_uniq_array(count); -#ifdef DEBUG - for (idx_t i = 0; i < count; i++) { - owned_data[i] = std::numeric_limits::max(); - } -#endif -} - -// LCOV_EXCL_START -string SelectionVector::ToString(idx_t count) const { - string result = "Selection Vector (" + to_string(count) + ") ["; - for (idx_t i = 0; i < count; i++) { - if (i != 0) { - result += ", "; - } - result += to_string(get_index(i)); - } - result += "]"; - return result; -} - -void SelectionVector::Print(idx_t count) const { - Printer::Print(ToString(count)); -} -// LCOV_EXCL_STOP - -buffer_ptr SelectionVector::Slice(const SelectionVector &sel, idx_t count) const { - auto data = make_buffer(count); - auto result_ptr = data->owned_data.get(); - // for every element, we perform result[i] = target[new[i]] - for (idx_t i = 0; i < count; i++) { - auto new_idx = sel.get_index(i); - auto idx = this->get_index(new_idx); - result_ptr[i] = idx; - } - return data; -} - -} // namespace duckdb - - - - - - - -#include - -namespace duckdb { - -StringHeap::StringHeap(Allocator &allocator) : allocator(allocator) { -} - -void StringHeap::Destroy() { - allocator.Destroy(); -} - -void StringHeap::Move(StringHeap &other) { - other.allocator.Move(allocator); -} - -string_t StringHeap::AddString(const char *data, idx_t len) { - D_ASSERT(Utf8Proc::Analyze(data, len) != UnicodeType::INVALID); - return AddBlob(data, len); -} - -string_t StringHeap::AddString(const char *data) { - return AddString(data, strlen(data)); -} - -string_t StringHeap::AddString(const string &data) { - return AddString(data.c_str(), data.size()); -} - -string_t StringHeap::AddString(const string_t &data) { - return AddString(data.GetData(), data.GetSize()); -} - -string_t StringHeap::AddBlob(const char *data, idx_t len) { - auto insert_string = 
EmptyString(len); - auto insert_pos = insert_string.GetDataWriteable(); - memcpy(insert_pos, data, len); - insert_string.Finalize(); - return insert_string; -} - -string_t StringHeap::AddBlob(const string_t &data) { - return AddBlob(data.GetData(), data.GetSize()); -} - -string_t StringHeap::EmptyString(idx_t len) { - D_ASSERT(len > string_t::INLINE_LENGTH); - auto insert_pos = const_char_ptr_cast(allocator.Allocate(len)); - return string_t(insert_pos, len); -} - -idx_t StringHeap::SizeInBytes() const { - return allocator.SizeInBytes(); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -void string_t::Verify() const { - auto dataptr = GetData(); - (void)dataptr; - D_ASSERT(dataptr); - -#ifdef DEBUG - auto utf_type = Utf8Proc::Analyze(dataptr, GetSize()); - D_ASSERT(utf_type != UnicodeType::INVALID); -#endif - - // verify that the prefix contains the first four characters of the string - for (idx_t i = 0; i < MinValue(PREFIX_LENGTH, GetSize()); i++) { - D_ASSERT(GetPrefix()[i] == dataptr[i]); - } - // verify that for strings with length <= INLINE_LENGTH, the rest of the string is zero - for (idx_t i = GetSize(); i < INLINE_LENGTH; i++) { - D_ASSERT(GetData()[i] == '\0'); - } -} - -} // namespace duckdb - - - - - - - - - - -#include -#include -#include - -namespace duckdb { - -static_assert(sizeof(dtime_t) == sizeof(int64_t), "dtime_t was padded"); - -// string format is hh:mm:ss.microsecondsZ -// microseconds and Z are optional -// ISO 8601 - -bool Time::TryConvertInternal(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict) { - int32_t hour = -1, min = -1, sec = -1, micros = -1; - pos = 0; - - if (len == 0) { - return false; - } - - int sep; - - // skip leading spaces - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } - - if (pos >= len) { - return false; - } - - if (!StringUtil::CharacterIsDigit(buf[pos])) { - return false; - } - - if (!Date::ParseDoubleDigit(buf, len, pos, hour)) { - return false; - } - if 
(hour < 0 || hour >= 24) { - return false; - } - - if (pos >= len) { - return false; - } - - // fetch the separator - sep = buf[pos++]; - if (sep != ':') { - // invalid separator - return false; - } - - if (!Date::ParseDoubleDigit(buf, len, pos, min)) { - return false; - } - if (min < 0 || min >= 60) { - return false; - } - - if (pos >= len) { - return false; - } - - if (buf[pos++] != sep) { - return false; - } - - if (!Date::ParseDoubleDigit(buf, len, pos, sec)) { - return false; - } - if (sec < 0 || sec >= 60) { - return false; - } - - micros = 0; - if (pos < len && buf[pos] == '.') { - pos++; - // we expect some microseconds - int32_t mult = 100000; - for (; pos < len && StringUtil::CharacterIsDigit(buf[pos]); pos++, mult /= 10) { - if (mult > 0) { - micros += (buf[pos] - '0') * mult; - } - } - } - - // in strict mode, check remaining string for non-space characters - if (strict) { - // skip trailing spaces - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } - // check position. 
if end was not reached, non-space chars remaining - if (pos < len) { - return false; - } - } - - result = Time::FromTime(hour, min, sec, micros); - return true; -} - -bool Time::TryConvertTime(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict) { - if (!Time::TryConvertInternal(buf, len, pos, result, strict)) { - if (!strict) { - // last chance, check if we can parse as timestamp - timestamp_t timestamp; - if (Timestamp::TryConvertTimestamp(buf, len, timestamp) == TimestampCastResult::SUCCESS) { - if (!Timestamp::IsFinite(timestamp)) { - return false; - } - result = Timestamp::GetTime(timestamp); - return true; - } - } - return false; - } - return true; -} - -bool Time::TryParseUTCOffset(const char *str, idx_t &pos, idx_t len, int32_t &offset) { - offset = 0; - if (pos == len || StringUtil::CharacterIsSpace(str[pos])) { - return true; - } - - idx_t curpos = pos; - // Minimum of 3 characters - if (curpos + 3 > len) { - // no characters left to parse - return false; - } - - const auto sign_char = str[curpos]; - if (sign_char != '+' && sign_char != '-') { - // expected either + or - - return false; - } - curpos++; - - int32_t hh = 0; - idx_t start = curpos; - for (; curpos < len; ++curpos) { - const auto c = str[curpos]; - if (!StringUtil::CharacterIsDigit(c)) { - break; - } - hh = hh * 10 + (c - '0'); - } - // HH is in [-1559,+1559] and must be at least two digits - if (curpos - start < 2 || hh > 1559) { - return false; - } - - // optional minute specifier: expected ":MM" - int32_t mm = 0; - if (curpos + 3 <= len && str[curpos] == ':') { - ++curpos; - if (!Date::ParseDoubleDigit(str, len, curpos, mm) || mm >= Interval::MINS_PER_HOUR) { - return false; - } - } - - // optional seconds specifier: expected ":SS" - int32_t ss = 0; - if (curpos + 3 <= len && str[curpos] == ':') { - ++curpos; - if (!Date::ParseDoubleDigit(str, len, curpos, ss) || ss >= Interval::SECS_PER_MINUTE) { - return false; - } - } - - // Assemble the offset now that we know nothing 
went wrong - offset += hh * Interval::SECS_PER_HOUR; - offset += mm * Interval::SECS_PER_MINUTE; - offset += ss; - if (sign_char == '-') { - offset = -offset; - } - - pos = curpos; - - return true; -} - -bool Time::TryConvertTimeTZ(const char *buf, idx_t len, idx_t &pos, dtime_tz_t &result, bool strict) { - dtime_t time_part; - if (!Time::TryConvertInternal(buf, len, pos, time_part, false)) { - if (!strict) { - // last chance, check if we can parse as timestamp - timestamp_t timestamp; - if (Timestamp::TryConvertTimestamp(buf, len, timestamp) == TimestampCastResult::SUCCESS) { - if (!Timestamp::IsFinite(timestamp)) { - return false; - } - result = dtime_tz_t(Timestamp::GetTime(timestamp), 0); - return true; - } - } - return false; - } - - // We can't use Timestamp::TryParseUTCOffset because the colon is optional there but required here. - int32_t offset = 0; - if (!TryParseUTCOffset(buf, pos, len, offset)) { - return false; - } - - // in strict mode, check remaining string for non-space characters - if (strict) { - // skip trailing spaces - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } - // check position. 
if end was not reached, non-space chars remaining - if (pos < len) { - return false; - } - } - - result = dtime_tz_t(time_part, offset); - - return true; -} - -string Time::ConversionError(const string &str) { - return StringUtil::Format("time field value out of range: \"%s\", " - "expected format is ([YYYY-MM-DD ]HH:MM:SS[.MS])", - str); -} - -string Time::ConversionError(string_t str) { - return Time::ConversionError(str.GetString()); -} - -dtime_t Time::FromCString(const char *buf, idx_t len, bool strict) { - dtime_t result; - idx_t pos; - if (!Time::TryConvertTime(buf, len, pos, result, strict)) { - throw ConversionException(ConversionError(string(buf, len))); - } - return result; -} - -dtime_t Time::FromString(const string &str, bool strict) { - return Time::FromCString(str.c_str(), str.size(), strict); -} - -string Time::ToString(dtime_t time) { - int32_t time_units[4]; - Time::Convert(time, time_units[0], time_units[1], time_units[2], time_units[3]); - - char micro_buffer[6]; - auto length = TimeToStringCast::Length(time_units, micro_buffer); - auto buffer = make_unsafe_uniq_array(length); - TimeToStringCast::Format(buffer.get(), length, time_units, micro_buffer); - return string(buffer.get(), length); -} - -string Time::ToUTCOffset(int hour_offset, int minute_offset) { - dtime_t time((hour_offset * Interval::MINS_PER_HOUR + minute_offset) * Interval::MICROS_PER_MINUTE); - - char buffer[1 + 2 + 1 + 2]; - idx_t length = 0; - buffer[length++] = (time.micros < 0 ? 
'-' : '+'); - time.micros = std::abs(time.micros); - - int32_t time_units[4]; - Time::Convert(time, time_units[0], time_units[1], time_units[2], time_units[3]); - - TimeToStringCast::FormatTwoDigits(buffer + length, time_units[0]); - length += 2; - if (time_units[1]) { - buffer[length++] = ':'; - TimeToStringCast::FormatTwoDigits(buffer + length, time_units[1]); - length += 2; - } - - return string(buffer, length); -} - -dtime_t Time::FromTime(int32_t hour, int32_t minute, int32_t second, int32_t microseconds) { - int64_t result; - result = hour; // hours - result = result * Interval::MINS_PER_HOUR + minute; // hours -> minutes - result = result * Interval::SECS_PER_MINUTE + second; // minutes -> seconds - result = result * Interval::MICROS_PER_SEC + microseconds; // seconds -> microseconds - return dtime_t(result); -} - -bool Time::IsValidTime(int32_t hour, int32_t minute, int32_t second, int32_t microseconds) { - if (hour < 0 || hour >= 24) { - return false; - } - if (minute < 0 || minute >= 60) { - return false; - } - if (second < 0 || second > 60) { - return false; - } - if (microseconds < 0 || microseconds > 1000000) { - return false; - } - return true; -} - -void Time::Convert(dtime_t dtime, int32_t &hour, int32_t &min, int32_t &sec, int32_t µs) { - int64_t time = dtime.micros; - hour = int32_t(time / Interval::MICROS_PER_HOUR); - time -= int64_t(hour) * Interval::MICROS_PER_HOUR; - min = int32_t(time / Interval::MICROS_PER_MINUTE); - time -= int64_t(min) * Interval::MICROS_PER_MINUTE; - sec = int32_t(time / Interval::MICROS_PER_SEC); - time -= int64_t(sec) * Interval::MICROS_PER_SEC; - micros = int32_t(time); - D_ASSERT(Time::IsValidTime(hour, min, sec, micros)); -} - -dtime_t Time::FromTimeMs(int64_t time_ms) { - int64_t result; - if (!TryMultiplyOperator::Operation(time_ms, Interval::MICROS_PER_MSEC, result)) { - throw ConversionException("Could not convert Time(MS) to Time(US)"); - } - return dtime_t(result); -} - -dtime_t Time::FromTimeNs(int64_t 
time_ns) { - return dtime_t(time_ns / Interval::NANOS_PER_MICRO); -} - -} // namespace duckdb - -#endif diff --git a/lib/duckdb-3.cpp b/lib/duckdb-3.cpp deleted file mode 100644 index 745ed712..00000000 --- a/lib/duckdb-3.cpp +++ /dev/null @@ -1,20726 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - - - - - - - - - - - -#include - -namespace duckdb { - -static_assert(sizeof(timestamp_t) == sizeof(int64_t), "timestamp_t was padded"); - -// timestamp/datetime uses 64 bits, high 32 bits for date and low 32 bits for time -// string format is YYYY-MM-DDThh:mm:ssZ -// T may be a space -// Z is optional -// ISO 8601 - -// arithmetic operators -timestamp_t timestamp_t::operator+(const double &value) const { - timestamp_t result; - if (!TryAddOperator::Operation(this->value, int64_t(value), result.value)) { - throw OutOfRangeException("Overflow in timestamp addition"); - } - return result; -} - -int64_t timestamp_t::operator-(const timestamp_t &other) const { - int64_t result; - if (!TrySubtractOperator::Operation(value, int64_t(other.value), result)) { - throw OutOfRangeException("Overflow in timestamp subtraction"); - } - return result; -} - -// in-place operators -timestamp_t ×tamp_t::operator+=(const int64_t &delta) { - if (!TryAddOperator::Operation(value, delta, value)) { - throw OutOfRangeException("Overflow in timestamp increment"); - } - return *this; -} - -timestamp_t ×tamp_t::operator-=(const int64_t &delta) { - if (!TrySubtractOperator::Operation(value, delta, value)) { - throw OutOfRangeException("Overflow in timestamp decrement"); - } - return *this; -} - -bool Timestamp::TryConvertTimestampTZ(const char *str, idx_t len, timestamp_t &result, bool &has_offset, string_t &tz) { - idx_t pos; - date_t date; - dtime_t time; - has_offset = false; - if 
(!Date::TryConvertDate(str, len, pos, date, has_offset)) { - return false; - } - if (pos == len) { - // no time: only a date or special - if (date == date_t::infinity()) { - result = timestamp_t::infinity(); - return true; - } else if (date == date_t::ninfinity()) { - result = timestamp_t::ninfinity(); - return true; - } - return Timestamp::TryFromDatetime(date, dtime_t(0), result); - } - // try to parse a time field - if (str[pos] == ' ' || str[pos] == 'T') { - pos++; - } - idx_t time_pos = 0; - if (!Time::TryConvertTime(str + pos, len - pos, time_pos, time)) { - return false; - } - pos += time_pos; - if (!Timestamp::TryFromDatetime(date, time, result)) { - return false; - } - if (pos < len) { - // skip a "Z" at the end (as per the ISO8601 specs) - int hour_offset, minute_offset; - if (str[pos] == 'Z') { - pos++; - has_offset = true; - } else if (Timestamp::TryParseUTCOffset(str, pos, len, hour_offset, minute_offset)) { - const int64_t delta = hour_offset * Interval::MICROS_PER_HOUR + minute_offset * Interval::MICROS_PER_MINUTE; - if (!TrySubtractOperator::Operation(result.value, delta, result.value)) { - return false; - } - has_offset = true; - } else { - // Parse a time zone: / [A-Za-z0-9/_]+/ - if (str[pos++] != ' ') { - return false; - } - auto tz_name = str + pos; - for (; pos < len && CharacterIsTimeZone(str[pos]); ++pos) { - continue; - } - auto tz_len = str + pos - tz_name; - if (tz_len) { - tz = string_t(tz_name, tz_len); - } - // Note that the caller must reinterpret the instant we return to the given time zone - } - - // skip any spaces at the end - while (pos < len && StringUtil::CharacterIsSpace(str[pos])) { - pos++; - } - if (pos < len) { - return false; - } - } - return true; -} - -TimestampCastResult Timestamp::TryConvertTimestamp(const char *str, idx_t len, timestamp_t &result) { - string_t tz(nullptr, 0); - bool has_offset = false; - // We don't understand TZ without an extension, so fail if one was provided. 
- auto success = TryConvertTimestampTZ(str, len, result, has_offset, tz); - if (!success) { - return TimestampCastResult::ERROR_INCORRECT_FORMAT; - } - if (tz.GetSize() == 0) { - // no timezone provided - success! - return TimestampCastResult::SUCCESS; - } - if (tz.GetSize() == 3) { - // we can ONLY handle UTC without ICU being loaded - auto tz_ptr = tz.GetData(); - if ((tz_ptr[0] == 'u' || tz_ptr[0] == 'U') && (tz_ptr[1] == 't' || tz_ptr[1] == 'T') && - (tz_ptr[2] == 'c' || tz_ptr[2] == 'C')) { - return TimestampCastResult::SUCCESS; - } - } - return TimestampCastResult::ERROR_NON_UTC_TIMEZONE; -} - -string Timestamp::ConversionError(const string &str) { - return StringUtil::Format("timestamp field value out of range: \"%s\", " - "expected format is (YYYY-MM-DD HH:MM:SS[.US][±HH:MM| ZONE])", - str); -} - -string Timestamp::UnsupportedTimezoneError(const string &str) { - return StringUtil::Format("timestamp field value \"%s\" has a timestamp that is not UTC.\nUse the TIMESTAMPTZ type " - "with the ICU extension loaded to handle non-UTC timestamps.", - str); -} - -string Timestamp::ConversionError(string_t str) { - return Timestamp::ConversionError(str.GetString()); -} - -string Timestamp::UnsupportedTimezoneError(string_t str) { - return Timestamp::UnsupportedTimezoneError(str.GetString()); -} - -timestamp_t Timestamp::FromCString(const char *str, idx_t len) { - timestamp_t result; - auto cast_result = Timestamp::TryConvertTimestamp(str, len, result); - if (cast_result == TimestampCastResult::SUCCESS) { - return result; - } - if (cast_result == TimestampCastResult::ERROR_NON_UTC_TIMEZONE) { - throw ConversionException(Timestamp::UnsupportedTimezoneError(string(str, len))); - } else { - throw ConversionException(Timestamp::ConversionError(string(str, len))); - } -} - -bool Timestamp::TryParseUTCOffset(const char *str, idx_t &pos, idx_t len, int &hour_offset, int &minute_offset) { - minute_offset = 0; - idx_t curpos = pos; - // parse the next 3 characters - if (curpos 
+ 3 > len) { - // no characters left to parse - return false; - } - char sign_char = str[curpos]; - if (sign_char != '+' && sign_char != '-') { - // expected either + or - - return false; - } - curpos++; - if (!StringUtil::CharacterIsDigit(str[curpos]) || !StringUtil::CharacterIsDigit(str[curpos + 1])) { - // expected +HH or -HH - return false; - } - hour_offset = (str[curpos] - '0') * 10 + (str[curpos + 1] - '0'); - if (sign_char == '-') { - hour_offset = -hour_offset; - } - curpos += 2; - - // optional minute specifier: expected either "MM" or ":MM" - if (curpos >= len) { - // done, nothing left - pos = curpos; - return true; - } - if (str[curpos] == ':') { - curpos++; - } - if (curpos + 2 > len || !StringUtil::CharacterIsDigit(str[curpos]) || - !StringUtil::CharacterIsDigit(str[curpos + 1])) { - // no MM specifier - pos = curpos; - return true; - } - // we have an MM specifier: parse it - minute_offset = (str[curpos] - '0') * 10 + (str[curpos + 1] - '0'); - if (sign_char == '-') { - minute_offset = -minute_offset; - } - pos = curpos + 2; - return true; -} - -timestamp_t Timestamp::FromString(const string &str) { - return Timestamp::FromCString(str.c_str(), str.size()); -} - -string Timestamp::ToString(timestamp_t timestamp) { - if (timestamp == timestamp_t::infinity()) { - return Date::PINF; - } else if (timestamp == timestamp_t::ninfinity()) { - return Date::NINF; - } - date_t date; - dtime_t time; - Timestamp::Convert(timestamp, date, time); - return Date::ToString(date) + " " + Time::ToString(time); -} - -date_t Timestamp::GetDate(timestamp_t timestamp) { - if (timestamp == timestamp_t::infinity()) { - return date_t::infinity(); - } else if (timestamp == timestamp_t::ninfinity()) { - return date_t::ninfinity(); - } - return date_t((timestamp.value + (timestamp.value < 0)) / Interval::MICROS_PER_DAY - (timestamp.value < 0)); -} - -dtime_t Timestamp::GetTime(timestamp_t timestamp) { - if (!IsFinite(timestamp)) { - throw ConversionException("Can't get TIME of 
infinite TIMESTAMP"); - } - date_t date = Timestamp::GetDate(timestamp); - return dtime_t(timestamp.value - (int64_t(date.days) * int64_t(Interval::MICROS_PER_DAY))); -} - -bool Timestamp::TryFromDatetime(date_t date, dtime_t time, timestamp_t &result) { - if (!TryMultiplyOperator::Operation(date.days, Interval::MICROS_PER_DAY, result.value)) { - return false; - } - if (!TryAddOperator::Operation(result.value, time.micros, result.value)) { - return false; - } - return Timestamp::IsFinite(result); -} - -timestamp_t Timestamp::FromDatetime(date_t date, dtime_t time) { - timestamp_t result; - if (!TryFromDatetime(date, time, result)) { - throw Exception("Overflow exception in date/time -> timestamp conversion"); - } - return result; -} - -void Timestamp::Convert(timestamp_t timestamp, date_t &out_date, dtime_t &out_time) { - out_date = GetDate(timestamp); - int64_t days_micros; - if (!TryMultiplyOperator::Operation(out_date.days, Interval::MICROS_PER_DAY, - days_micros)) { - throw ConversionException("Date out of range in timestamp conversion"); - } - out_time = dtime_t(timestamp.value - days_micros); - D_ASSERT(timestamp == Timestamp::FromDatetime(out_date, out_time)); -} - -timestamp_t Timestamp::GetCurrentTimestamp() { - auto now = system_clock::now(); - auto epoch_ms = duration_cast(now.time_since_epoch()).count(); - return Timestamp::FromEpochMs(epoch_ms); -} - -timestamp_t Timestamp::FromEpochSeconds(int64_t sec) { - int64_t result; - if (!TryMultiplyOperator::Operation(sec, Interval::MICROS_PER_SEC, result)) { - throw ConversionException("Could not convert Timestamp(S) to Timestamp(US)"); - } - return timestamp_t(result); -} - -timestamp_t Timestamp::FromEpochMs(int64_t ms) { - int64_t result; - if (!TryMultiplyOperator::Operation(ms, Interval::MICROS_PER_MSEC, result)) { - throw ConversionException("Could not convert Timestamp(MS) to Timestamp(US)"); - } - return timestamp_t(result); -} - -timestamp_t Timestamp::FromEpochMicroSeconds(int64_t micros) { - return 
timestamp_t(micros); -} - -timestamp_t Timestamp::FromEpochNanoSeconds(int64_t ns) { - return timestamp_t(ns / 1000); -} - -int64_t Timestamp::GetEpochSeconds(timestamp_t timestamp) { - return timestamp.value / Interval::MICROS_PER_SEC; -} - -int64_t Timestamp::GetEpochMs(timestamp_t timestamp) { - return timestamp.value / Interval::MICROS_PER_MSEC; -} - -int64_t Timestamp::GetEpochMicroSeconds(timestamp_t timestamp) { - return timestamp.value; -} - -int64_t Timestamp::GetEpochNanoSeconds(timestamp_t timestamp) { - int64_t result; - int64_t ns_in_us = 1000; - if (!TryMultiplyOperator::Operation(timestamp.value, ns_in_us, result)) { - throw ConversionException("Could not convert Timestamp(US) to Timestamp(NS)"); - } - return result; -} - -double Timestamp::GetJulianDay(timestamp_t timestamp) { - double result = Timestamp::GetTime(timestamp).micros; - result /= Interval::MICROS_PER_DAY; - result += Date::ExtractJulianDay(Timestamp::GetDate(timestamp)); - return result; -} - -} // namespace duckdb - - - -namespace duckdb { - -bool UUID::FromString(string str, hugeint_t &result) { - auto hex2char = [](char ch) -> unsigned char { - if (ch >= '0' && ch <= '9') { - return ch - '0'; - } - if (ch >= 'a' && ch <= 'f') { - return 10 + ch - 'a'; - } - if (ch >= 'A' && ch <= 'F') { - return 10 + ch - 'A'; - } - return 0; - }; - auto is_hex = [](char ch) -> bool { - return (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f') || (ch >= 'A' && ch <= 'F'); - }; - - if (str.empty()) { - return false; - } - int has_braces = 0; - if (str.front() == '{') { - has_braces = 1; - } - if (has_braces && str.back() != '}') { - return false; - } - - result.lower = 0; - result.upper = 0; - size_t count = 0; - for (size_t i = has_braces; i < str.size() - has_braces; ++i) { - if (str[i] == '-') { - continue; - } - if (count >= 32 || !is_hex(str[i])) { - return false; - } - if (count >= 16) { - result.lower = (result.lower << 4) | hex2char(str[i]); - } else { - result.upper = (result.upper << 4) | 
hex2char(str[i]); - } - count++; - } - // Flip the first bit to make `order by uuid` same as `order by uuid::varchar` - result.upper ^= (uint64_t(1) << 63); - return count == 32; -} - -void UUID::ToString(hugeint_t input, char *buf) { - auto byte_to_hex = [](char byte_val, char *buf, idx_t &pos) { - static char const HEX_DIGITS[] = "0123456789abcdef"; - buf[pos++] = HEX_DIGITS[(byte_val >> 4) & 0xf]; - buf[pos++] = HEX_DIGITS[byte_val & 0xf]; - }; - - // Flip back before convert to string - int64_t upper = input.upper ^ (uint64_t(1) << 63); - idx_t pos = 0; - byte_to_hex(upper >> 56 & 0xFF, buf, pos); - byte_to_hex(upper >> 48 & 0xFF, buf, pos); - byte_to_hex(upper >> 40 & 0xFF, buf, pos); - byte_to_hex(upper >> 32 & 0xFF, buf, pos); - buf[pos++] = '-'; - byte_to_hex(upper >> 24 & 0xFF, buf, pos); - byte_to_hex(upper >> 16 & 0xFF, buf, pos); - buf[pos++] = '-'; - byte_to_hex(upper >> 8 & 0xFF, buf, pos); - byte_to_hex(upper & 0xFF, buf, pos); - buf[pos++] = '-'; - byte_to_hex(input.lower >> 56 & 0xFF, buf, pos); - byte_to_hex(input.lower >> 48 & 0xFF, buf, pos); - buf[pos++] = '-'; - byte_to_hex(input.lower >> 40 & 0xFF, buf, pos); - byte_to_hex(input.lower >> 32 & 0xFF, buf, pos); - byte_to_hex(input.lower >> 24 & 0xFF, buf, pos); - byte_to_hex(input.lower >> 16 & 0xFF, buf, pos); - byte_to_hex(input.lower >> 8 & 0xFF, buf, pos); - byte_to_hex(input.lower & 0xFF, buf, pos); -} - -hugeint_t UUID::GenerateRandomUUID(RandomEngine &engine) { - uint8_t bytes[16]; - for (int i = 0; i < 16; i += 4) { - *reinterpret_cast(bytes + i) = engine.NextRandomInteger(); - } - // variant must be 10xxxxxx - bytes[8] &= 0xBF; - bytes[8] |= 0x80; - // version must be 0100xxxx - bytes[6] &= 0x4F; - bytes[6] |= 0x40; - - hugeint_t result; - result.upper = 0; - result.upper |= ((int64_t)bytes[0] << 56); - result.upper |= ((int64_t)bytes[1] << 48); - result.upper |= ((int64_t)bytes[2] << 40); - result.upper |= ((int64_t)bytes[3] << 32); - result.upper |= ((int64_t)bytes[4] << 24); - 
result.upper |= ((int64_t)bytes[5] << 16); - result.upper |= ((int64_t)bytes[6] << 8); - result.upper |= bytes[7]; - result.lower = 0; - result.lower |= ((uint64_t)bytes[8] << 56); - result.lower |= ((uint64_t)bytes[9] << 48); - result.lower |= ((uint64_t)bytes[10] << 40); - result.lower |= ((uint64_t)bytes[11] << 32); - result.lower |= ((uint64_t)bytes[12] << 24); - result.lower |= ((uint64_t)bytes[13] << 16); - result.lower |= ((uint64_t)bytes[14] << 8); - result.lower |= bytes[15]; - return result; -} - -hugeint_t UUID::GenerateRandomUUID() { - RandomEngine engine; - return GenerateRandomUUID(engine); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -ValidityData::ValidityData(idx_t count) : TemplatedValidityData(count) { -} -ValidityData::ValidityData(const ValidityMask &original, idx_t count) - : TemplatedValidityData(original.GetData(), count) { -} - -void ValidityMask::Combine(const ValidityMask &other, idx_t count) { - if (other.AllValid()) { - // X & 1 = X - return; - } - if (AllValid()) { - // 1 & Y = Y - Initialize(other); - return; - } - if (validity_mask == other.validity_mask) { - // X & X == X - return; - } - // have to merge - // create a new validity mask that contains the combined mask - auto owned_data = std::move(validity_data); - auto data = GetData(); - auto other_data = other.GetData(); - - Initialize(count); - auto result_data = GetData(); - - auto entry_count = ValidityData::EntryCount(count); - for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { - result_data[entry_idx] = data[entry_idx] & other_data[entry_idx]; - } -} - -// LCOV_EXCL_START -string ValidityMask::ToString(idx_t count) const { - string result = "Validity Mask (" + to_string(count) + ") ["; - for (idx_t i = 0; i < count; i++) { - result += RowIsValid(i) ? "." 
: "X"; - } - result += "]"; - return result; -} -// LCOV_EXCL_STOP - -void ValidityMask::Resize(idx_t old_size, idx_t new_size) { - D_ASSERT(new_size >= old_size); - if (validity_mask) { - auto new_size_count = EntryCount(new_size); - auto old_size_count = EntryCount(old_size); - auto new_validity_data = make_buffer(new_size); - auto new_owned_data = new_validity_data->owned_data.get(); - for (idx_t entry_idx = 0; entry_idx < old_size_count; entry_idx++) { - new_owned_data[entry_idx] = validity_mask[entry_idx]; - } - for (idx_t entry_idx = old_size_count; entry_idx < new_size_count; entry_idx++) { - new_owned_data[entry_idx] = ValidityData::MAX_ENTRY; - } - validity_data = std::move(new_validity_data); - validity_mask = validity_data->owned_data.get(); - } else { - Initialize(new_size); - } -} - -void ValidityMask::Slice(const ValidityMask &other, idx_t source_offset, idx_t count) { - if (other.AllValid()) { - validity_mask = nullptr; - validity_data.reset(); - return; - } - if (source_offset == 0) { - Initialize(other); - return; - } - ValidityMask new_mask(count); - new_mask.SliceInPlace(other, 0, source_offset, count); - Initialize(new_mask); -} - -bool ValidityMask::IsAligned(idx_t count) { - return count % BITS_PER_VALUE == 0; -} - -void ValidityMask::SliceInPlace(const ValidityMask &other, idx_t target_offset, idx_t source_offset, idx_t count) { - if (IsAligned(source_offset) && IsAligned(target_offset)) { - auto target_validity = GetData(); - auto source_validity = other.GetData(); - auto source_offset_entries = EntryCount(source_offset); - auto target_offset_entries = EntryCount(target_offset); - memcpy(target_validity + target_offset_entries, source_validity + source_offset_entries, - sizeof(validity_t) * EntryCount(count)); - return; - } else if (IsAligned(target_offset)) { - // Simple common case where we are shifting into an aligned mask (e.g., 0 in Slice above) - const idx_t entire_units = count / BITS_PER_VALUE; - const idx_t ragged = count % 
BITS_PER_VALUE; - const idx_t tail = source_offset % BITS_PER_VALUE; - const idx_t head = BITS_PER_VALUE - tail; - auto source_validity = other.GetData() + (source_offset / BITS_PER_VALUE); - auto target_validity = this->GetData() + (target_offset / BITS_PER_VALUE); - auto src_entry = *source_validity++; - for (idx_t i = 0; i < entire_units; ++i) { - // Start with head of previous src - validity_t tgt_entry = src_entry >> tail; - src_entry = *source_validity++; - // Add in tail of current src - tgt_entry |= (src_entry << head); - *target_validity++ = tgt_entry; - } - // Finish last ragged entry - if (ragged) { - // Start with head of previous src - validity_t tgt_entry = (src_entry >> tail); - // Add in the tail of the next src, if head was too small - if (head < ragged) { - src_entry = *source_validity++; - tgt_entry |= (src_entry << head); - } - // Mask off the bits that go past the ragged end - tgt_entry &= (ValidityBuffer::MAX_ENTRY >> (BITS_PER_VALUE - ragged)); - // Restore the ragged end of the target - tgt_entry |= *target_validity & (ValidityBuffer::MAX_ENTRY << ragged); - *target_validity++ = tgt_entry; - } - return; - } - - // FIXME: use bitwise operations here -#if 1 - for (idx_t i = 0; i < count; i++) { - Set(target_offset + i, other.RowIsValid(source_offset + i)); - } -#else - // first shift the "whole" units - idx_t entire_units = offset / BITS_PER_VALUE; - idx_t sub_units = offset - entire_units * BITS_PER_VALUE; - if (entire_units > 0) { - idx_t validity_idx; - for (validity_idx = 0; validity_idx + entire_units < STANDARD_ENTRY_COUNT; validity_idx++) { - new_mask.validity_mask[validity_idx] = other.validity_mask[validity_idx + entire_units]; - } - } - // now we shift the remaining sub units - // this gets a bit more complicated because we have to shift over the borders of the entries - // e.g. 
suppose we have 2 entries of length 4 and we left-shift by two - // 0101|1010 - // a regular left-shift of both gets us: - // 0100|1000 - // we then OR the overflow (right-shifted by BITS_PER_VALUE - offset) together to get the correct result - // 0100|1000 -> - // 0110|1000 - if (sub_units > 0) { - idx_t validity_idx; - for (validity_idx = 0; validity_idx + 1 < STANDARD_ENTRY_COUNT; validity_idx++) { - new_mask.validity_mask[validity_idx] = - (other.validity_mask[validity_idx] >> sub_units) | - (other.validity_mask[validity_idx + 1] << (BITS_PER_VALUE - sub_units)); - } - new_mask.validity_mask[validity_idx] >>= sub_units; - } -#ifdef DEBUG - for (idx_t i = offset; i < STANDARD_VECTOR_SIZE; i++) { - D_ASSERT(new_mask.RowIsValid(i - offset) == other.RowIsValid(i)); - } - Initialize(new_mask); -#endif -#endif -} - -enum class ValiditySerialization : uint8_t { BITMASK = 0, VALID_VALUES = 1, INVALID_VALUES = 2 }; - -void ValidityMask::Write(WriteStream &writer, idx_t count) { - auto valid_values = CountValid(count); - auto invalid_values = count - valid_values; - auto bitmask_bytes = ValidityMask::ValidityMaskSize(count); - auto need_u32 = count >= NumericLimits::Maximum(); - auto bytes_per_value = need_u32 ? sizeof(uint32_t) : sizeof(uint16_t); - auto valid_value_size = bytes_per_value * valid_values + sizeof(uint32_t); - auto invalid_value_size = bytes_per_value * invalid_values + sizeof(uint32_t); - if (valid_value_size < bitmask_bytes || invalid_value_size < bitmask_bytes) { - auto serialize_valid = valid_value_size < invalid_value_size; - // serialize (in)valid value indexes as [COUNT][V0][V1][...][VN] - auto flag = serialize_valid ? 
ValiditySerialization::VALID_VALUES : ValiditySerialization::INVALID_VALUES; - writer.Write(flag); - writer.Write(MinValue(valid_values, invalid_values)); - for (idx_t i = 0; i < count; i++) { - if (RowIsValid(i) == serialize_valid) { - if (need_u32) { - writer.Write(i); - } else { - writer.Write(i); - } - } - } - } else { - // serialize the entire bitmask - writer.Write(ValiditySerialization::BITMASK); - writer.WriteData(const_data_ptr_cast(GetData()), bitmask_bytes); - } -} - -void ValidityMask::Read(ReadStream &reader, idx_t count) { - Initialize(count); - // deserialize the storage type - auto flag = reader.Read(); - if (flag == ValiditySerialization::BITMASK) { - // deserialize the bitmask - reader.ReadData(data_ptr_cast(GetData()), ValidityMask::ValidityMaskSize(count)); - return; - } - auto is_u32 = count >= NumericLimits::Maximum(); - auto is_valid = flag == ValiditySerialization::VALID_VALUES; - auto serialize_count = reader.Read(); - if (is_valid) { - SetAllInvalid(count); - } - for (idx_t i = 0; i < serialize_count; i++) { - idx_t index = is_u32 ? 
reader.Read() : reader.Read(); - Set(index, is_valid); - } -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -#include -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Extra Value Info -//===--------------------------------------------------------------------===// -enum class ExtraValueInfoType : uint8_t { INVALID_TYPE_INFO = 0, STRING_VALUE_INFO = 1, NESTED_VALUE_INFO = 2 }; - -struct ExtraValueInfo { - explicit ExtraValueInfo(ExtraValueInfoType type) : type(type) { - } - virtual ~ExtraValueInfo() { - } - - ExtraValueInfoType type; - -public: - bool Equals(ExtraValueInfo *other_p) const { - if (!other_p) { - return false; - } - if (type != other_p->type) { - return false; - } - return EqualsInternal(other_p); - } - - template - T &Get() { - if (type != T::TYPE) { - throw InternalException("ExtraValueInfo type mismatch"); - } - return (T &)*this; - } - -protected: - virtual bool EqualsInternal(ExtraValueInfo *other_p) const { - return true; - } -}; - -//===--------------------------------------------------------------------===// -// String Value Info -//===--------------------------------------------------------------------===// -struct StringValueInfo : public ExtraValueInfo { - static constexpr const ExtraValueInfoType TYPE = ExtraValueInfoType::STRING_VALUE_INFO; - -public: - explicit StringValueInfo(string str_p) - : ExtraValueInfo(ExtraValueInfoType::STRING_VALUE_INFO), str(std::move(str_p)) { - } - - const string &GetString() { - return str; - } - -protected: - bool EqualsInternal(ExtraValueInfo *other_p) const override { - return other_p->Get().str == str; - } - - string str; -}; - -//===--------------------------------------------------------------------===// -// Nested Value Info -//===--------------------------------------------------------------------===// -struct NestedValueInfo : public ExtraValueInfo { - static constexpr const 
ExtraValueInfoType TYPE = ExtraValueInfoType::NESTED_VALUE_INFO; - -public: - NestedValueInfo() : ExtraValueInfo(ExtraValueInfoType::NESTED_VALUE_INFO) { - } - explicit NestedValueInfo(vector values_p) - : ExtraValueInfo(ExtraValueInfoType::NESTED_VALUE_INFO), values(std::move(values_p)) { - } - - const vector &GetValues() { - return values; - } - -protected: - bool EqualsInternal(ExtraValueInfo *other_p) const override { - return other_p->Get().values == values; - } - - vector values; -}; -//===--------------------------------------------------------------------===// -// Value -//===--------------------------------------------------------------------===// -Value::Value(LogicalType type) : type_(std::move(type)), is_null(true) { -} - -Value::Value(int32_t val) : type_(LogicalType::INTEGER), is_null(false) { - value_.integer = val; -} - -Value::Value(int64_t val) : type_(LogicalType::BIGINT), is_null(false) { - value_.bigint = val; -} - -Value::Value(float val) : type_(LogicalType::FLOAT), is_null(false) { - value_.float_ = val; -} - -Value::Value(double val) : type_(LogicalType::DOUBLE), is_null(false) { - value_.double_ = val; -} - -Value::Value(const char *val) : Value(val ? 
string(val) : string()) { -} - -Value::Value(std::nullptr_t val) : Value(LogicalType::VARCHAR) { -} - -Value::Value(string_t val) : Value(val.GetString()) { -} - -Value::Value(string val) : type_(LogicalType::VARCHAR), is_null(false) { - if (!Value::StringIsValid(val.c_str(), val.size())) { - throw Exception(ErrorManager::InvalidUnicodeError(val, "value construction")); - } - value_info_ = make_shared(std::move(val)); -} - -Value::~Value() { -} - -Value::Value(const Value &other) - : type_(other.type_), is_null(other.is_null), value_(other.value_), value_info_(other.value_info_) { -} - -Value::Value(Value &&other) noexcept - : type_(std::move(other.type_)), is_null(other.is_null), value_(other.value_), - value_info_(std::move(other.value_info_)) { -} - -Value &Value::operator=(const Value &other) { - if (this == &other) { - return *this; - } - type_ = other.type_; - is_null = other.is_null; - value_ = other.value_; - value_info_ = other.value_info_; - return *this; -} - -Value &Value::operator=(Value &&other) noexcept { - type_ = std::move(other.type_); - is_null = other.is_null; - value_ = other.value_; - value_info_ = std::move(other.value_info_); - return *this; -} - -Value Value::MinimumValue(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - return Value::BOOLEAN(false); - case LogicalTypeId::TINYINT: - return Value::TINYINT(NumericLimits::Minimum()); - case LogicalTypeId::SMALLINT: - return Value::SMALLINT(NumericLimits::Minimum()); - case LogicalTypeId::INTEGER: - case LogicalTypeId::SQLNULL: - return Value::INTEGER(NumericLimits::Minimum()); - case LogicalTypeId::BIGINT: - return Value::BIGINT(NumericLimits::Minimum()); - case LogicalTypeId::HUGEINT: - return Value::HUGEINT(NumericLimits::Minimum()); - case LogicalTypeId::UUID: - return Value::UUID(NumericLimits::Minimum()); - case LogicalTypeId::UTINYINT: - return Value::UTINYINT(NumericLimits::Minimum()); - case LogicalTypeId::USMALLINT: - return 
Value::USMALLINT(NumericLimits::Minimum()); - case LogicalTypeId::UINTEGER: - return Value::UINTEGER(NumericLimits::Minimum()); - case LogicalTypeId::UBIGINT: - return Value::UBIGINT(NumericLimits::Minimum()); - case LogicalTypeId::DATE: - return Value::DATE(Date::FromDate(Date::DATE_MIN_YEAR, Date::DATE_MIN_MONTH, Date::DATE_MIN_DAY)); - case LogicalTypeId::TIME: - return Value::TIME(dtime_t(0)); - case LogicalTypeId::TIMESTAMP: - return Value::TIMESTAMP(Date::FromDate(Timestamp::MIN_YEAR, Timestamp::MIN_MONTH, Timestamp::MIN_DAY), - dtime_t(0)); - case LogicalTypeId::TIMESTAMP_SEC: - return MinimumValue(LogicalType::TIMESTAMP).DefaultCastAs(LogicalType::TIMESTAMP_S); - case LogicalTypeId::TIMESTAMP_MS: - return MinimumValue(LogicalType::TIMESTAMP).DefaultCastAs(LogicalType::TIMESTAMP_MS); - case LogicalTypeId::TIMESTAMP_NS: - return Value::TIMESTAMPNS(timestamp_t(NumericLimits::Minimum())); - case LogicalTypeId::TIME_TZ: - return Value::TIMETZ(dtime_tz_t(dtime_t(0), dtime_tz_t::MIN_OFFSET)); - case LogicalTypeId::TIMESTAMP_TZ: - return Value::TIMESTAMPTZ(Timestamp::FromDatetime( - Date::FromDate(Timestamp::MIN_YEAR, Timestamp::MIN_MONTH, Timestamp::MIN_DAY), dtime_t(0))); - case LogicalTypeId::FLOAT: - return Value::FLOAT(NumericLimits::Minimum()); - case LogicalTypeId::DOUBLE: - return Value::DOUBLE(NumericLimits::Minimum()); - case LogicalTypeId::DECIMAL: { - auto width = DecimalType::GetWidth(type); - auto scale = DecimalType::GetScale(type); - switch (type.InternalType()) { - case PhysicalType::INT16: - return Value::DECIMAL(int16_t(-NumericHelper::POWERS_OF_TEN[width] + 1), width, scale); - case PhysicalType::INT32: - return Value::DECIMAL(int32_t(-NumericHelper::POWERS_OF_TEN[width] + 1), width, scale); - case PhysicalType::INT64: - return Value::DECIMAL(int64_t(-NumericHelper::POWERS_OF_TEN[width] + 1), width, scale); - case PhysicalType::INT128: - return Value::DECIMAL(-Hugeint::POWERS_OF_TEN[width] + 1, width, scale); - default: - throw 
InternalException("Unknown decimal type"); - } - } - case LogicalTypeId::ENUM: - return Value::ENUM(0, type); - default: - throw InvalidTypeException(type, "MinimumValue requires numeric type"); - } -} - -Value Value::MaximumValue(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - return Value::BOOLEAN(true); - case LogicalTypeId::TINYINT: - return Value::TINYINT(NumericLimits::Maximum()); - case LogicalTypeId::SMALLINT: - return Value::SMALLINT(NumericLimits::Maximum()); - case LogicalTypeId::INTEGER: - case LogicalTypeId::SQLNULL: - return Value::INTEGER(NumericLimits::Maximum()); - case LogicalTypeId::BIGINT: - return Value::BIGINT(NumericLimits::Maximum()); - case LogicalTypeId::HUGEINT: - return Value::HUGEINT(NumericLimits::Maximum()); - case LogicalTypeId::UUID: - return Value::UUID(NumericLimits::Maximum()); - case LogicalTypeId::UTINYINT: - return Value::UTINYINT(NumericLimits::Maximum()); - case LogicalTypeId::USMALLINT: - return Value::USMALLINT(NumericLimits::Maximum()); - case LogicalTypeId::UINTEGER: - return Value::UINTEGER(NumericLimits::Maximum()); - case LogicalTypeId::UBIGINT: - return Value::UBIGINT(NumericLimits::Maximum()); - case LogicalTypeId::DATE: - return Value::DATE(Date::FromDate(Date::DATE_MAX_YEAR, Date::DATE_MAX_MONTH, Date::DATE_MAX_DAY)); - case LogicalTypeId::TIME: - return Value::TIME(dtime_t(Interval::SECS_PER_DAY * Interval::MICROS_PER_SEC - 1)); - case LogicalTypeId::TIMESTAMP: - return Value::TIMESTAMP(timestamp_t(NumericLimits::Maximum() - 1)); - case LogicalTypeId::TIMESTAMP_MS: - return MaximumValue(LogicalType::TIMESTAMP).DefaultCastAs(LogicalType::TIMESTAMP_MS); - case LogicalTypeId::TIMESTAMP_NS: - return Value::TIMESTAMPNS(timestamp_t(NumericLimits::Maximum() - 1)); - case LogicalTypeId::TIMESTAMP_SEC: - return MaximumValue(LogicalType::TIMESTAMP).DefaultCastAs(LogicalType::TIMESTAMP_S); - case LogicalTypeId::TIME_TZ: - return Value::TIMETZ( - dtime_tz_t(dtime_t(Interval::SECS_PER_DAY * 
Interval::MICROS_PER_SEC - 1), dtime_tz_t::MAX_OFFSET)); - case LogicalTypeId::TIMESTAMP_TZ: - return MaximumValue(LogicalType::TIMESTAMP); - case LogicalTypeId::FLOAT: - return Value::FLOAT(NumericLimits::Maximum()); - case LogicalTypeId::DOUBLE: - return Value::DOUBLE(NumericLimits::Maximum()); - case LogicalTypeId::DECIMAL: { - auto width = DecimalType::GetWidth(type); - auto scale = DecimalType::GetScale(type); - switch (type.InternalType()) { - case PhysicalType::INT16: - return Value::DECIMAL(int16_t(NumericHelper::POWERS_OF_TEN[width] - 1), width, scale); - case PhysicalType::INT32: - return Value::DECIMAL(int32_t(NumericHelper::POWERS_OF_TEN[width] - 1), width, scale); - case PhysicalType::INT64: - return Value::DECIMAL(int64_t(NumericHelper::POWERS_OF_TEN[width] - 1), width, scale); - case PhysicalType::INT128: - return Value::DECIMAL(Hugeint::POWERS_OF_TEN[width] - 1, width, scale); - default: - throw InternalException("Unknown decimal type"); - } - } - case LogicalTypeId::ENUM: - return Value::ENUM(EnumType::GetSize(type) - 1, type); - default: - throw InvalidTypeException(type, "MaximumValue requires numeric type"); - } -} - -Value Value::Infinity(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::DATE: - return Value::DATE(date_t::infinity()); - case LogicalTypeId::TIMESTAMP: - return Value::TIMESTAMP(timestamp_t::infinity()); - case LogicalTypeId::TIMESTAMP_MS: - return Value::TIMESTAMPMS(timestamp_t::infinity()); - case LogicalTypeId::TIMESTAMP_NS: - return Value::TIMESTAMPNS(timestamp_t::infinity()); - case LogicalTypeId::TIMESTAMP_SEC: - return Value::TIMESTAMPSEC(timestamp_t::infinity()); - case LogicalTypeId::TIMESTAMP_TZ: - return Value::TIMESTAMPTZ(timestamp_t::infinity()); - case LogicalTypeId::FLOAT: - return Value::FLOAT(std::numeric_limits::infinity()); - case LogicalTypeId::DOUBLE: - return Value::DOUBLE(std::numeric_limits::infinity()); - default: - throw InvalidTypeException(type, "Infinity requires numeric type"); - 
} -} - -Value Value::NegativeInfinity(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::DATE: - return Value::DATE(date_t::ninfinity()); - case LogicalTypeId::TIMESTAMP: - return Value::TIMESTAMP(timestamp_t::ninfinity()); - case LogicalTypeId::TIMESTAMP_MS: - return Value::TIMESTAMPMS(timestamp_t::ninfinity()); - case LogicalTypeId::TIMESTAMP_NS: - return Value::TIMESTAMPNS(timestamp_t::ninfinity()); - case LogicalTypeId::TIMESTAMP_SEC: - return Value::TIMESTAMPSEC(timestamp_t::ninfinity()); - case LogicalTypeId::TIMESTAMP_TZ: - return Value::TIMESTAMPTZ(timestamp_t::ninfinity()); - case LogicalTypeId::FLOAT: - return Value::FLOAT(-std::numeric_limits::infinity()); - case LogicalTypeId::DOUBLE: - return Value::DOUBLE(-std::numeric_limits::infinity()); - default: - throw InvalidTypeException(type, "NegativeInfinity requires numeric type"); - } -} - -Value Value::BOOLEAN(int8_t value) { - Value result(LogicalType::BOOLEAN); - result.value_.boolean = bool(value); - result.is_null = false; - return result; -} - -Value Value::TINYINT(int8_t value) { - Value result(LogicalType::TINYINT); - result.value_.tinyint = value; - result.is_null = false; - return result; -} - -Value Value::SMALLINT(int16_t value) { - Value result(LogicalType::SMALLINT); - result.value_.smallint = value; - result.is_null = false; - return result; -} - -Value Value::INTEGER(int32_t value) { - Value result(LogicalType::INTEGER); - result.value_.integer = value; - result.is_null = false; - return result; -} - -Value Value::BIGINT(int64_t value) { - Value result(LogicalType::BIGINT); - result.value_.bigint = value; - result.is_null = false; - return result; -} - -Value Value::HUGEINT(hugeint_t value) { - Value result(LogicalType::HUGEINT); - result.value_.hugeint = value; - result.is_null = false; - return result; -} - -Value Value::UUID(hugeint_t value) { - Value result(LogicalType::UUID); - result.value_.hugeint = value; - result.is_null = false; - return result; -} - -Value 
Value::UUID(const string &value) { - Value result(LogicalType::UUID); - result.value_.hugeint = UUID::FromString(value); - result.is_null = false; - return result; -} - -Value Value::UTINYINT(uint8_t value) { - Value result(LogicalType::UTINYINT); - result.value_.utinyint = value; - result.is_null = false; - return result; -} - -Value Value::USMALLINT(uint16_t value) { - Value result(LogicalType::USMALLINT); - result.value_.usmallint = value; - result.is_null = false; - return result; -} - -Value Value::UINTEGER(uint32_t value) { - Value result(LogicalType::UINTEGER); - result.value_.uinteger = value; - result.is_null = false; - return result; -} - -Value Value::UBIGINT(uint64_t value) { - Value result(LogicalType::UBIGINT); - result.value_.ubigint = value; - result.is_null = false; - return result; -} - -bool Value::FloatIsFinite(float value) { - return !(std::isnan(value) || std::isinf(value)); -} - -bool Value::DoubleIsFinite(double value) { - return !(std::isnan(value) || std::isinf(value)); -} - -template <> -bool Value::IsNan(float input) { - return std::isnan(input); -} - -template <> -bool Value::IsNan(double input) { - return std::isnan(input); -} - -template <> -bool Value::IsFinite(float input) { - return Value::FloatIsFinite(input); -} - -template <> -bool Value::IsFinite(double input) { - return Value::DoubleIsFinite(input); -} - -template <> -bool Value::IsFinite(date_t input) { - return Date::IsFinite(input); -} - -template <> -bool Value::IsFinite(timestamp_t input) { - return Timestamp::IsFinite(input); -} - -bool Value::StringIsValid(const char *str, idx_t length) { - auto utf_type = Utf8Proc::Analyze(str, length); - return utf_type != UnicodeType::INVALID; -} - -Value Value::DECIMAL(int16_t value, uint8_t width, uint8_t scale) { - return Value::DECIMAL(int64_t(value), width, scale); -} - -Value Value::DECIMAL(int32_t value, uint8_t width, uint8_t scale) { - return Value::DECIMAL(int64_t(value), width, scale); -} - -Value Value::DECIMAL(int64_t 
value, uint8_t width, uint8_t scale) { - auto decimal_type = LogicalType::DECIMAL(width, scale); - Value result(decimal_type); - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - result.value_.smallint = value; - break; - case PhysicalType::INT32: - result.value_.integer = value; - break; - case PhysicalType::INT64: - result.value_.bigint = value; - break; - default: - result.value_.hugeint = value; - break; - } - result.type_.Verify(); - result.is_null = false; - return result; -} - -Value Value::DECIMAL(hugeint_t value, uint8_t width, uint8_t scale) { - D_ASSERT(width >= Decimal::MAX_WIDTH_INT64 && width <= Decimal::MAX_WIDTH_INT128); - Value result(LogicalType::DECIMAL(width, scale)); - result.value_.hugeint = value; - result.is_null = false; - return result; -} - -Value Value::FLOAT(float value) { - Value result(LogicalType::FLOAT); - result.value_.float_ = value; - result.is_null = false; - return result; -} - -Value Value::DOUBLE(double value) { - Value result(LogicalType::DOUBLE); - result.value_.double_ = value; - result.is_null = false; - return result; -} - -Value Value::HASH(hash_t value) { - Value result(LogicalType::HASH); - result.value_.hash = value; - result.is_null = false; - return result; -} - -Value Value::POINTER(uintptr_t value) { - Value result(LogicalType::POINTER); - result.value_.pointer = value; - result.is_null = false; - return result; -} - -Value Value::DATE(date_t value) { - Value result(LogicalType::DATE); - result.value_.date = value; - result.is_null = false; - return result; -} - -Value Value::DATE(int32_t year, int32_t month, int32_t day) { - return Value::DATE(Date::FromDate(year, month, day)); -} - -Value Value::TIME(dtime_t value) { - Value result(LogicalType::TIME); - result.value_.time = value; - result.is_null = false; - return result; -} - -Value Value::TIMETZ(dtime_tz_t value) { - Value result(LogicalType::TIME_TZ); - result.value_.timetz = value; - result.is_null = false; - return result; -} - -Value 
Value::TIME(int32_t hour, int32_t min, int32_t sec, int32_t micros) { - return Value::TIME(Time::FromTime(hour, min, sec, micros)); -} - -Value Value::TIMESTAMP(timestamp_t value) { - Value result(LogicalType::TIMESTAMP); - result.value_.timestamp = value; - result.is_null = false; - return result; -} - -Value Value::TIMESTAMPTZ(timestamp_t value) { - Value result(LogicalType::TIMESTAMP_TZ); - result.value_.timestamp = value; - result.is_null = false; - return result; -} - -Value Value::TIMESTAMPNS(timestamp_t timestamp) { - Value result(LogicalType::TIMESTAMP_NS); - result.value_.timestamp = timestamp; - result.is_null = false; - return result; -} - -Value Value::TIMESTAMPMS(timestamp_t timestamp) { - Value result(LogicalType::TIMESTAMP_MS); - result.value_.timestamp = timestamp; - result.is_null = false; - return result; -} - -Value Value::TIMESTAMPSEC(timestamp_t timestamp) { - Value result(LogicalType::TIMESTAMP_S); - result.value_.timestamp = timestamp; - result.is_null = false; - return result; -} - -Value Value::TIMESTAMP(date_t date, dtime_t time) { - return Value::TIMESTAMP(Timestamp::FromDatetime(date, time)); -} - -Value Value::TIMESTAMP(int32_t year, int32_t month, int32_t day, int32_t hour, int32_t min, int32_t sec, - int32_t micros) { - auto val = Value::TIMESTAMP(Date::FromDate(year, month, day), Time::FromTime(hour, min, sec, micros)); - val.type_ = LogicalType::TIMESTAMP; - return val; -} - -Value Value::STRUCT(child_list_t values) { - Value result; - child_list_t child_types; - vector struct_values; - for (auto &child : values) { - child_types.push_back(make_pair(std::move(child.first), child.second.type())); - struct_values.push_back(std::move(child.second)); - } - result.value_info_ = make_shared(std::move(struct_values)); - result.type_ = LogicalType::STRUCT(child_types); - result.is_null = false; - return result; -} - -Value Value::MAP(const LogicalType &child_type, vector values) { - Value result; - - result.type_ = 
// Construct a MAP value from a list of {key, value} STRUCT entries.
// NOTE(review): the signature line lies just before this chunk — reconstructed from the visible body.
Value Value::MAP(const LogicalType &child_type, vector<Value> values) {
	Value result;
	result.type_ = LogicalType::MAP(child_type);
	result.is_null = false;
	for (auto &val : values) {
		D_ASSERT(val.type().InternalType() == PhysicalType::STRUCT);
		auto &children = StructValue::GetChildren(val);

		// Ensure that the field containing the keys is called 'key'
		// and that the field containing the values is called 'value'
		// this is required to make equality checks work
		D_ASSERT(children.size() == 2);
		child_list_t<Value> new_children;
		new_children.reserve(2);
		new_children.push_back(std::make_pair("key", children[0]));
		new_children.push_back(std::make_pair("value", children[1]));
		val = Value::STRUCT(std::move(new_children));
	}
	result.value_info_ = make_shared<NestedValueInfo>(std::move(values));
	return result;
}

// Construct a UNION value: member `tag` holds `value`, every other member slot is NULL.
// Internally represented as a STRUCT whose first child is the UTINYINT tag.
Value Value::UNION(child_list_t<LogicalType> members, uint8_t tag, Value value) {
	D_ASSERT(!members.empty());
	D_ASSERT(members.size() <= UnionType::MAX_UNION_MEMBERS);
	D_ASSERT(members.size() > tag);

	D_ASSERT(value.type() == members[tag].second);

	Value result;
	result.is_null = false;
	// add the tag to the front of the struct
	vector<Value> union_values;
	union_values.emplace_back(Value::UTINYINT(tag));
	for (idx_t i = 0; i < members.size(); i++) {
		if (i != tag) {
			// non-selected members are NULL values of their member type
			union_values.emplace_back(members[i].second);
		} else {
			union_values.emplace_back(nullptr);
		}
	}
	// slot 0 is the tag, so member i lives at index i + 1
	union_values[tag + 1] = std::move(value);
	result.value_info_ = make_shared<NestedValueInfo>(std::move(union_values));
	result.type_ = LogicalType::UNION(std::move(members));
	return result;
}

// Construct a LIST value, inferring the child type from the first element.
// Requires a non-empty list; use LIST(child_type, values) or EMPTYLIST otherwise.
Value Value::LIST(vector<Value> values) {
	if (values.empty()) {
		throw InternalException("Value::LIST without providing a child-type requires a non-empty list of values. Use "
		                        "Value::LIST(child_type, list) instead.");
	}
#ifdef DEBUG
	// in debug mode verify that all elements share the same type
	for (idx_t i = 1; i < values.size(); i++) {
		D_ASSERT(values[i].type() == values[0].type());
	}
#endif
	Value result;
	result.type_ = LogicalType::LIST(values[0].type());
	result.value_info_ = make_shared<NestedValueInfo>(std::move(values));
	result.is_null = false;
	return result;
}

// Construct a LIST value with an explicit child type; elements are cast to it first.
Value Value::LIST(const LogicalType &child_type, vector<Value> values) {
	if (values.empty()) {
		return Value::EMPTYLIST(child_type);
	}
	for (auto &val : values) {
		val = val.DefaultCastAs(child_type);
	}
	return Value::LIST(std::move(values));
}

// Construct an empty (but non-NULL) LIST value of the given child type.
Value Value::EMPTYLIST(const LogicalType &child_type) {
	Value result;
	result.type_ = LogicalType::LIST(child_type);
	result.value_info_ = make_shared<NestedValueInfo>();
	result.is_null = false;
	return result;
}

// Construct a BLOB value from a raw byte buffer (bytes are stored verbatim).
Value Value::BLOB(const_data_ptr_t data, idx_t len) {
	Value result(LogicalType::BLOB);
	result.is_null = false;
	result.value_info_ = make_shared<StringValueInfo>(string(const_char_ptr_cast(data), len));
	return result;
}

// Construct a BLOB value from a string; the string is parsed/escaped via Blob::ToBlob.
Value Value::BLOB(const string &data) {
	Value result(LogicalType::BLOB);
	result.is_null = false;
	result.value_info_ = make_shared<StringValueInfo>(Blob::ToBlob(string_t(data)));
	return result;
}

// Construct a BIT value from a raw buffer (stored verbatim).
Value Value::BIT(const_data_ptr_t data, idx_t len) {
	Value result(LogicalType::BIT);
	result.is_null = false;
	result.value_info_ = make_shared<StringValueInfo>(string(const_char_ptr_cast(data), len));
	return result;
}

// Construct a BIT value from a bit-string such as "0101" via Bit::ToBit.
Value Value::BIT(const string &data) {
	Value result(LogicalType::BIT);
	result.is_null = false;
	result.value_info_ = make_shared<StringValueInfo>(Bit::ToBit(string_t(data)));
	return result;
}

// Construct an ENUM value holding dictionary position `value`; the physical width
// (uint8/16/32) depends on the enum's dictionary size.
Value Value::ENUM(uint64_t value, const LogicalType &original_type) {
	D_ASSERT(original_type.id() == LogicalTypeId::ENUM);
	Value result(original_type);
	switch (original_type.InternalType()) {
	case PhysicalType::UINT8:
		result.value_.utinyint = value;
		break;
	case PhysicalType::UINT16:
		result.value_.usmallint = value;
		break;
	case PhysicalType::UINT32:
		result.value_.uinteger = value;
		break;
	default:
		throw InternalException("Incorrect Physical Type for ENUM");
	}
	result.is_null = false;
	return result;
}

// Construct an INTERVAL value from its three components.
Value Value::INTERVAL(int32_t months, int32_t days, int64_t micros) {
	Value result(LogicalType::INTERVAL);
	result.is_null = false;
	result.value_.interval.months = months;
	result.value_.interval.days = days;
	result.value_.interval.micros = micros;
	return result;
}

Value Value::INTERVAL(interval_t interval) {
	return Value::INTERVAL(interval.months, interval.days, interval.micros);
}

//===--------------------------------------------------------------------===//
// CreateValue
//===--------------------------------------------------------------------===//
// Template specializations mapping native C++ types to the corresponding
// Value factory, used by templated code that constructs Values generically.
template <>
Value Value::CreateValue(bool value) {
	return Value::BOOLEAN(value);
}

template <>
Value Value::CreateValue(int8_t value) {
	return Value::TINYINT(value);
}

template <>
Value Value::CreateValue(int16_t value) {
	return Value::SMALLINT(value);
}

template <>
Value Value::CreateValue(int32_t value) {
	return Value::INTEGER(value);
}

template <>
Value Value::CreateValue(int64_t value) {
	return Value::BIGINT(value);
}

template <>
Value Value::CreateValue(uint8_t value) {
	return Value::UTINYINT(value);
}

template <>
Value Value::CreateValue(uint16_t value) {
	return Value::USMALLINT(value);
}

template <>
Value Value::CreateValue(uint32_t value) {
	return Value::UINTEGER(value);
}

template <>
Value Value::CreateValue(uint64_t value) {
	return Value::UBIGINT(value);
}

template <>
Value Value::CreateValue(hugeint_t value) {
	return Value::HUGEINT(value);
}

template <>
Value Value::CreateValue(date_t value) {
	return Value::DATE(value);
}

template <>
Value Value::CreateValue(dtime_t value) {
	return Value::TIME(value);
}

template <>
Value Value::CreateValue(dtime_tz_t value) {
	return Value::TIMETZ(value);
}
template <>
Value Value::CreateValue(timestamp_t value) {
	return Value::TIMESTAMP(value);
}

template <>
Value Value::CreateValue(timestamp_sec_t value) {
	return Value::TIMESTAMPSEC(value);
}

template <>
Value Value::CreateValue(timestamp_ms_t value) {
	return Value::TIMESTAMPMS(value);
}

template <>
Value Value::CreateValue(timestamp_ns_t value) {
	return Value::TIMESTAMPNS(value);
}

template <>
Value Value::CreateValue(timestamp_tz_t value) {
	return Value::TIMESTAMPTZ(value);
}

template <>
Value Value::CreateValue(const char *value) {
	return Value(string(value));
}

// NOTE(review): a std::string maps to BLOB (not VARCHAR) here — intentional per
// the original, presumably because arbitrary bytes need not be valid UTF-8.
template <>
Value Value::CreateValue(string value) { // NOLINT: required for templating
	return Value::BLOB(value);
}

template <>
Value Value::CreateValue(string_t value) {
	return Value(value);
}

template <>
Value Value::CreateValue(float value) {
	return Value::FLOAT(value);
}

template <>
Value Value::CreateValue(double value) {
	return Value::DOUBLE(value);
}

template <>
Value Value::CreateValue(interval_t value) {
	return Value::INTERVAL(value);
}

template <>
Value Value::CreateValue(Value value) {
	return value;
}

//===--------------------------------------------------------------------===//
// GetValue
//===--------------------------------------------------------------------===//
// Extract the stored value as T, using the (throwing) Cast framework so that
// out-of-range conversions raise rather than silently truncate.
// Throws InternalException when called on a NULL value.
template <class T>
T Value::GetValueInternal() const {
	if (IsNull()) {
		throw InternalException("Calling GetValueInternal on a value that is NULL");
	}
	switch (type_.id()) {
	case LogicalTypeId::BOOLEAN:
		return Cast::Operation<bool, T>(value_.boolean);
	case LogicalTypeId::TINYINT:
		return Cast::Operation<int8_t, T>(value_.tinyint);
	case LogicalTypeId::SMALLINT:
		return Cast::Operation<int16_t, T>(value_.smallint);
	case LogicalTypeId::INTEGER:
		return Cast::Operation<int32_t, T>(value_.integer);
	case LogicalTypeId::BIGINT:
		return Cast::Operation<int64_t, T>(value_.bigint);
	case LogicalTypeId::HUGEINT:
	case LogicalTypeId::UUID:
		return Cast::Operation<hugeint_t, T>(value_.hugeint);
	case LogicalTypeId::DATE:
		return Cast::Operation<date_t, T>(value_.date);
	case LogicalTypeId::TIME:
		return Cast::Operation<dtime_t, T>(value_.time);
	case LogicalTypeId::TIME_TZ:
		return Cast::Operation<dtime_tz_t, T>(value_.timetz);
	case LogicalTypeId::TIMESTAMP:
	case LogicalTypeId::TIMESTAMP_TZ:
		return Cast::Operation<timestamp_t, T>(value_.timestamp);
	case LogicalTypeId::UTINYINT:
		return Cast::Operation<uint8_t, T>(value_.utinyint);
	case LogicalTypeId::USMALLINT:
		return Cast::Operation<uint16_t, T>(value_.usmallint);
	case LogicalTypeId::UINTEGER:
		return Cast::Operation<uint32_t, T>(value_.uinteger);
	case LogicalTypeId::TIMESTAMP_MS:
	case LogicalTypeId::TIMESTAMP_NS:
	case LogicalTypeId::TIMESTAMP_SEC:
	case LogicalTypeId::UBIGINT:
		// these four share the uint64 slot of the union
		return Cast::Operation<uint64_t, T>(value_.ubigint);
	case LogicalTypeId::FLOAT:
		return Cast::Operation<float, T>(value_.float_);
	case LogicalTypeId::DOUBLE:
		return Cast::Operation<double, T>(value_.double_);
	case LogicalTypeId::VARCHAR:
		return Cast::Operation<string_t, T>(StringValue::Get(*this).c_str());
	case LogicalTypeId::INTERVAL:
		return Cast::Operation<interval_t, T>(value_.interval);
	case LogicalTypeId::DECIMAL:
		// decimals are first widened to DOUBLE, then converted to T
		return DefaultCastAs(LogicalType::DOUBLE).GetValueInternal<T>();
	case LogicalTypeId::ENUM: {
		switch (type_.InternalType()) {
		case PhysicalType::UINT8:
			return Cast::Operation<uint8_t, T>(value_.utinyint);
		case PhysicalType::UINT16:
			return Cast::Operation<uint16_t, T>(value_.usmallint);
		case PhysicalType::UINT32:
			return Cast::Operation<uint32_t, T>(value_.uinteger);
		default:
			throw InternalException("Invalid Internal Type for ENUMs");
		}
	}
	default:
		throw NotImplementedException("Unimplemented type \"%s\" for GetValue()", type_.ToString());
	}
}

template <>
bool Value::GetValue() const {
	return GetValueInternal<bool>();
}
template <>
int8_t Value::GetValue() const {
	return GetValueInternal<int8_t>();
}
template <>
int16_t Value::GetValue() const {
	return GetValueInternal<int16_t>();
}
template <>
int32_t Value::GetValue() const {
	// DATE is stored as a raw int32 day count; return it without casting
	if (type_.id() == LogicalTypeId::DATE) {
		return value_.integer;
	}
	return GetValueInternal<int32_t>();
}
template <>
int64_t Value::GetValue() const {
	if (IsNull()) {
		throw InternalException("Calling GetValue on a value that is NULL");
	}
	switch (type_.id()) {
	// temporal types are stored as raw int64 ticks; return them directly
	case LogicalTypeId::TIMESTAMP:
	case LogicalTypeId::TIMESTAMP_SEC:
	case LogicalTypeId::TIMESTAMP_NS:
	case LogicalTypeId::TIMESTAMP_MS:
	case LogicalTypeId::TIME:
	case LogicalTypeId::TIMESTAMP_TZ:
		return value_.bigint;
	default:
		return GetValueInternal<int64_t>();
	}
}
template <>
hugeint_t Value::GetValue() const {
	return GetValueInternal<hugeint_t>();
}
template <>
uint8_t Value::GetValue() const {
	return GetValueInternal<uint8_t>();
}
template <>
uint16_t Value::GetValue() const {
	return GetValueInternal<uint16_t>();
}
template <>
uint32_t Value::GetValue() const {
	return GetValueInternal<uint32_t>();
}
template <>
uint64_t Value::GetValue() const {
	return GetValueInternal<uint64_t>();
}
template <>
string Value::GetValue() const {
	return ToString();
}
template <>
float Value::GetValue() const {
	return GetValueInternal<float>();
}
template <>
double Value::GetValue() const {
	return GetValueInternal<double>();
}
template <>
date_t Value::GetValue() const {
	return GetValueInternal<date_t>();
}
template <>
dtime_t Value::GetValue() const {
	return GetValueInternal<dtime_t>();
}
template <>
timestamp_t Value::GetValue() const {
	return GetValueInternal<timestamp_t>();
}

template <>
DUCKDB_API interval_t Value::GetValue() const {
	return GetValueInternal<interval_t>();
}

template <>
DUCKDB_API Value Value::GetValue() const {
	return Value(*this);
}

// Return the stored POINTER payload as an integer address.
uintptr_t Value::GetPointer() const {
	D_ASSERT(type() == LogicalType::POINTER);
	return value_.pointer;
}
// Construct a Value of the requested numeric (or numeric-backed) type from an
// int64 payload. Range is checked with D_ASSERT only (debug builds).
// NOTE(review): the opening of this function sits at the tail of the previous
// chunk line; signature reconstructed from the visible call pattern.
Value Value::Numeric(const LogicalType &type, int64_t value) {
	switch (type.id()) {
	case LogicalTypeId::BOOLEAN:
		D_ASSERT(value == 0 || value == 1);
		return Value::BOOLEAN(value ? 1 : 0);
	case LogicalTypeId::TINYINT:
		D_ASSERT(value >= NumericLimits<int8_t>::Minimum() && value <= NumericLimits<int8_t>::Maximum());
		return Value::TINYINT((int8_t)value);
	case LogicalTypeId::SMALLINT:
		D_ASSERT(value >= NumericLimits<int16_t>::Minimum() && value <= NumericLimits<int16_t>::Maximum());
		return Value::SMALLINT((int16_t)value);
	case LogicalTypeId::INTEGER:
		D_ASSERT(value >= NumericLimits<int32_t>::Minimum() && value <= NumericLimits<int32_t>::Maximum());
		return Value::INTEGER((int32_t)value);
	case LogicalTypeId::BIGINT:
		return Value::BIGINT(value);
	case LogicalTypeId::UTINYINT:
		D_ASSERT(value >= NumericLimits<uint8_t>::Minimum() && value <= NumericLimits<uint8_t>::Maximum());
		return Value::UTINYINT((uint8_t)value);
	case LogicalTypeId::USMALLINT:
		D_ASSERT(value >= NumericLimits<uint16_t>::Minimum() && value <= NumericLimits<uint16_t>::Maximum());
		return Value::USMALLINT((uint16_t)value);
	case LogicalTypeId::UINTEGER:
		D_ASSERT(value >= NumericLimits<uint32_t>::Minimum() && value <= NumericLimits<uint32_t>::Maximum());
		return Value::UINTEGER((uint32_t)value);
	case LogicalTypeId::UBIGINT:
		D_ASSERT(value >= 0);
		return Value::UBIGINT(value);
	case LogicalTypeId::HUGEINT:
		return Value::HUGEINT(value);
	case LogicalTypeId::DECIMAL:
		return Value::DECIMAL(value, DecimalType::GetWidth(type), DecimalType::GetScale(type));
	case LogicalTypeId::FLOAT:
		return Value((float)value);
	case LogicalTypeId::DOUBLE:
		return Value((double)value);
	case LogicalTypeId::POINTER:
		return Value::POINTER(value);
	case LogicalTypeId::DATE:
		D_ASSERT(value >= NumericLimits<int32_t>::Minimum() && value <= NumericLimits<int32_t>::Maximum());
		return Value::DATE(date_t(value));
	case LogicalTypeId::TIME:
		return Value::TIME(dtime_t(value));
	case LogicalTypeId::TIMESTAMP:
		return Value::TIMESTAMP(timestamp_t(value));
	case LogicalTypeId::TIMESTAMP_NS:
		return Value::TIMESTAMPNS(timestamp_t(value));
	case LogicalTypeId::TIMESTAMP_MS:
		return Value::TIMESTAMPMS(timestamp_t(value));
	case LogicalTypeId::TIMESTAMP_SEC:
		return Value::TIMESTAMPSEC(timestamp_t(value));
	case LogicalTypeId::TIMESTAMP_TZ:
		return Value::TIMESTAMPTZ(timestamp_t(value));
	case LogicalTypeId::ENUM:
		// the raw value is an index into the enum dictionary
		switch (type.InternalType()) {
		case PhysicalType::UINT8:
			D_ASSERT(value >= NumericLimits<uint8_t>::Minimum() && value <= NumericLimits<uint8_t>::Maximum());
			return Value::UTINYINT((uint8_t)value);
		case PhysicalType::UINT16:
			D_ASSERT(value >= NumericLimits<uint16_t>::Minimum() && value <= NumericLimits<uint16_t>::Maximum());
			return Value::USMALLINT((uint16_t)value);
		case PhysicalType::UINT32:
			D_ASSERT(value >= NumericLimits<uint32_t>::Minimum() && value <= NumericLimits<uint32_t>::Maximum());
			return Value::UINTEGER((uint32_t)value);
		default:
			throw InternalException("Enum doesn't accept this physical type");
		}
	default:
		throw InvalidTypeException(type, "Numeric requires numeric type");
	}
}

// Hugeint overload: delegates to the int64 overload unless the target is
// HUGEINT/UBIGINT, which can hold values outside the int64 range.
Value Value::Numeric(const LogicalType &type, hugeint_t value) {
#ifdef DEBUG
	// perform a throwing cast to verify that the type fits
	Value::HUGEINT(value).DefaultCastAs(type);
#endif
	switch (type.id()) {
	case LogicalTypeId::HUGEINT:
		return Value::HUGEINT(value);
	case LogicalTypeId::UBIGINT:
		return Value::UBIGINT(Hugeint::Cast<uint64_t>(value));
	default:
		return Value::Numeric(type, Hugeint::Cast<int64_t>(value));
	}
}

//===--------------------------------------------------------------------===//
// GetValueUnsafe
//===--------------------------------------------------------------------===//
// Direct union-slot accessors. No NULL check and no conversion — the caller
// must know the physical type; D_ASSERTs guard debug builds only.
template <>
DUCKDB_API bool Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::BOOL);
	return value_.boolean;
}

template <>
int8_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::INT8 || type_.InternalType() == PhysicalType::BOOL);
	return value_.tinyint;
}

template <>
int16_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::INT16);
	return value_.smallint;
}

template <>
int32_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::INT32);
	return value_.integer;
}

template <>
int64_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::INT64);
	return value_.bigint;
}

template <>
hugeint_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::INT128);
	return value_.hugeint;
}

template <>
uint8_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::UINT8);
	return value_.utinyint;
}

template <>
uint16_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::UINT16);
	return value_.usmallint;
}

template <>
uint32_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::UINT32);
	return value_.uinteger;
}

template <>
uint64_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::UINT64);
	return value_.ubigint;
}

template <>
string Value::GetValueUnsafe() const {
	return StringValue::Get(*this);
}

template <>
DUCKDB_API string_t Value::GetValueUnsafe() const {
	return string_t(StringValue::Get(*this));
}

template <>
float Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::FLOAT);
	return value_.float_;
}

template <>
double Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::DOUBLE);
	return value_.double_;
}

template <>
date_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::INT32);
	return value_.date;
}

template <>
dtime_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::INT64);
	return value_.time;
}

template <>
timestamp_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::INT64);
	return value_.timestamp;
}

template <>
interval_t Value::GetValueUnsafe() const {
	D_ASSERT(type_.InternalType() == PhysicalType::INTERVAL);
	return value_.interval;
}
//===--------------------------------------------------------------------===//
// Hash
//===--------------------------------------------------------------------===//
// Hash the value by routing it through the vectorized hash kernel
// (single-element vector), so scalar and vector hashing always agree.
hash_t Value::Hash() const {
	if (IsNull()) {
		return 0;
	}
	Vector input(*this);
	Vector result(LogicalType::HASH);
	VectorOperations::Hash(input, result, 1);

	auto data = FlatVector::GetData<hash_t>(result);
	return data[0];
}

// Human-readable rendering via a cast to VARCHAR; NULL renders as "NULL".
string Value::ToString() const {
	if (IsNull()) {
		return "NULL";
	}
	return StringValue::Get(DefaultCastAs(LogicalType::VARCHAR));
}

// Render the value as a SQL literal that parses back to the same value
// (quoting, escaping and explicit ::TYPE casts where required).
string Value::ToSQLString() const {
	if (IsNull()) {
		return ToString();
	}
	switch (type_.id()) {
	case LogicalTypeId::UUID:
	case LogicalTypeId::DATE:
	case LogicalTypeId::TIME:
	case LogicalTypeId::TIMESTAMP:
	case LogicalTypeId::TIME_TZ:
	case LogicalTypeId::TIMESTAMP_TZ:
	case LogicalTypeId::TIMESTAMP_SEC:
	case LogicalTypeId::TIMESTAMP_MS:
	case LogicalTypeId::TIMESTAMP_NS:
	case LogicalTypeId::INTERVAL:
	case LogicalTypeId::BLOB:
		// quoted string plus an explicit cast back to the original type
		return "'" + ToString() + "'::" + type_.ToString();
	case LogicalTypeId::VARCHAR:
	case LogicalTypeId::ENUM:
		// escape embedded single quotes by doubling them
		return "'" + StringUtil::Replace(ToString(), "'", "''") + "'";
	case LogicalTypeId::STRUCT: {
		string ret = "{";
		auto &child_types = StructType::GetChildTypes(type_);
		auto &struct_values = StructValue::GetChildren(*this);
		for (size_t i = 0; i < struct_values.size(); i++) {
			auto &name = child_types[i].first;
			auto &child = struct_values[i];
			ret += "'" + name + "': " + child.ToSQLString();
			if (i < struct_values.size() - 1) {
				ret += ", ";
			}
		}
		ret += "}";
		return ret;
	}
	case LogicalTypeId::FLOAT:
		// inf/nan have no bare literal form; quote-and-cast them
		if (!FloatIsFinite(FloatValue::Get(*this))) {
			return "'" + ToString() + "'::" + type_.ToString();
		}
		return ToString();
	case LogicalTypeId::DOUBLE: {
		double val = DoubleValue::Get(*this);
		if (!DoubleIsFinite(val)) {
			if (!Value::IsNan(val)) {
				// to infinity and beyond
				return val < 0 ? "-1e1000" : "1e1000";
			}
			return "'" + ToString() + "'::" + type_.ToString();
		}
		return ToString();
	}
	case LogicalTypeId::LIST: {
		string ret = "[";
		auto &list_values = ListValue::GetChildren(*this);
		for (size_t i = 0; i < list_values.size(); i++) {
			auto &child = list_values[i];
			ret += child.ToSQLString();
			if (i < list_values.size() - 1) {
				ret += ", ";
			}
		}
		ret += "]";
		return ret;
	}
	default:
		return ToString();
	}
}

//===--------------------------------------------------------------------===//
// Type-specific getters
//===--------------------------------------------------------------------===//
// Thin named wrappers over GetValueUnsafe, one per SQL type.
bool BooleanValue::Get(const Value &value) {
	return value.GetValueUnsafe<bool>();
}

int8_t TinyIntValue::Get(const Value &value) {
	return value.GetValueUnsafe<int8_t>();
}

int16_t SmallIntValue::Get(const Value &value) {
	return value.GetValueUnsafe<int16_t>();
}

int32_t IntegerValue::Get(const Value &value) {
	return value.GetValueUnsafe<int32_t>();
}

int64_t BigIntValue::Get(const Value &value) {
	return value.GetValueUnsafe<int64_t>();
}

hugeint_t HugeIntValue::Get(const Value &value) {
	return value.GetValueUnsafe<hugeint_t>();
}

uint8_t UTinyIntValue::Get(const Value &value) {
	return value.GetValueUnsafe<uint8_t>();
}

uint16_t USmallIntValue::Get(const Value &value) {
	return value.GetValueUnsafe<uint16_t>();
}

uint32_t UIntegerValue::Get(const Value &value) {
	return value.GetValueUnsafe<uint32_t>();
}

uint64_t UBigIntValue::Get(const Value &value) {
	return value.GetValueUnsafe<uint64_t>();
}

float FloatValue::Get(const Value &value) {
	return value.GetValueUnsafe<float>();
}

double DoubleValue::Get(const Value &value) {
	return value.GetValueUnsafe<double>();
}

// Access the underlying string storage; throws (not asserts) on NULL since
// this is reachable from user-facing paths.
const string &StringValue::Get(const Value &value) {
	if (value.is_null) {
		throw InternalException("Calling StringValue::Get on a NULL value");
	}
	D_ASSERT(value.type().InternalType() == PhysicalType::VARCHAR);
	D_ASSERT(value.value_info_);
	return value.value_info_->Get<StringValueInfo>().GetString();
}

date_t DateValue::Get(const Value &value) {
	return value.GetValueUnsafe<date_t>();
}

dtime_t TimeValue::Get(const Value &value) {
	return value.GetValueUnsafe<dtime_t>();
}

timestamp_t TimestampValue::Get(const Value &value) {
	return value.GetValueUnsafe<timestamp_t>();
}

interval_t IntervalValue::Get(const Value &value) {
	return value.GetValueUnsafe<interval_t>();
}

const vector<Value> &StructValue::GetChildren(const Value &value) {
	if (value.is_null) {
		throw InternalException("Calling StructValue::GetChildren on a NULL value");
	}
	D_ASSERT(value.type().InternalType() == PhysicalType::STRUCT);
	D_ASSERT(value.value_info_);
	return value.value_info_->Get<NestedValueInfo>().GetValues();
}

const vector<Value> &ListValue::GetChildren(const Value &value) {
	if (value.is_null) {
		throw InternalException("Calling ListValue::GetChildren on a NULL value");
	}
	D_ASSERT(value.type().InternalType() == PhysicalType::LIST);
	D_ASSERT(value.value_info_);
	return value.value_info_->Get<NestedValueInfo>().GetValues();
}

// A UNION is stored as a STRUCT whose child 0 is the tag; the selected
// member's payload lives at index tag + 1.
const Value &UnionValue::GetValue(const Value &value) {
	D_ASSERT(value.type().id() == LogicalTypeId::UNION);
	auto &children = StructValue::GetChildren(value);
	auto tag = children[0].GetValueUnsafe<union_tag_t>();
	D_ASSERT(tag < children.size() - 1);
	return children[tag + 1];
}

union_tag_t UnionValue::GetTag(const Value &value) {
	D_ASSERT(value.type().id() == LogicalTypeId::UNION);
	auto children = StructValue::GetChildren(value);
	auto tag = children[0].GetValueUnsafe<union_tag_t>();
	D_ASSERT(tag < children.size() - 1);
	return tag;
}

const LogicalType &UnionValue::GetType(const Value &value) {
	return UnionType::GetMemberType(value.type(), UnionValue::GetTag(value));
}

// Widen any integral physical representation to hugeint.
hugeint_t IntegralValue::Get(const Value &value) {
	switch (value.type().InternalType()) {
	case PhysicalType::INT8:
		return TinyIntValue::Get(value);
	case PhysicalType::INT16:
		return SmallIntValue::Get(value);
	case PhysicalType::INT32:
		return IntegerValue::Get(value);
	case PhysicalType::INT64:
		return BigIntValue::Get(value);
	case PhysicalType::INT128:
		return HugeIntValue::Get(value);
	case PhysicalType::UINT8:
		return UTinyIntValue::Get(value);
	case PhysicalType::UINT16:
		return USmallIntValue::Get(value);
	case PhysicalType::UINT32:
		return UIntegerValue::Get(value);
	case PhysicalType::UINT64:
		return UBigIntValue::Get(value);
	default:
		throw InternalException("Invalid internal type \"%s\" for IntegralValue::Get", value.type().ToString());
	}
}

//===--------------------------------------------------------------------===//
// Comparison Operators
//===--------------------------------------------------------------------===//
// All comparisons delegate to ValueOperations (SQL semantics). The int64
// overloads first lift the scalar to this value's own type via Numeric.
bool Value::operator==(const Value &rhs) const {
	return ValueOperations::Equals(*this, rhs);
}

bool Value::operator!=(const Value &rhs) const {
	return ValueOperations::NotEquals(*this, rhs);
}

bool Value::operator<(const Value &rhs) const {
	return ValueOperations::LessThan(*this, rhs);
}

bool Value::operator>(const Value &rhs) const {
	return ValueOperations::GreaterThan(*this, rhs);
}

bool Value::operator<=(const Value &rhs) const {
	return ValueOperations::LessThanEquals(*this, rhs);
}

bool Value::operator>=(const Value &rhs) const {
	return ValueOperations::GreaterThanEquals(*this, rhs);
}

bool Value::operator==(const int64_t &rhs) const {
	return *this == Value::Numeric(type_, rhs);
}

bool Value::operator!=(const int64_t &rhs) const {
	return *this != Value::Numeric(type_, rhs);
}

bool Value::operator<(const int64_t &rhs) const {
	return *this < Value::Numeric(type_, rhs);
}

bool Value::operator>(const int64_t &rhs) const {
	return *this > Value::Numeric(type_, rhs);
}

bool Value::operator<=(const int64_t &rhs) const {
	return *this <= Value::Numeric(type_, rhs);
}

bool Value::operator>=(const int64_t &rhs) const {
	return *this >= Value::Numeric(type_, rhs);
}
// Attempt to cast this value to target_type, writing the result into new_value.
// Returns false (and fills *error_message if non-null) on failure. The cast is
// executed through the vectorized cast machinery on a 1-element vector so that
// scalar and columnar casts share one implementation.
// NOTE(review): the signature opens at the tail of the previous chunk line;
// reconstructed from the overloads that delegate to it.
bool Value::TryCastAs(CastFunctionSet &set, GetCastFunctionInput &get_input, const LogicalType &target_type,
                      Value &new_value, string *error_message, bool strict) const {
	if (type_ == target_type) {
		// no-op cast: just copy
		new_value = Copy();
		return true;
	}
	Vector input(*this);
	Vector result(target_type);
	if (!VectorOperations::TryCast(set, get_input, input, result, 1, error_message, strict)) {
		return false;
	}
	new_value = result.GetValue(0);
	return true;
}

// Convenience overload using the client context's registered cast functions.
bool Value::TryCastAs(ClientContext &context, const LogicalType &target_type, Value &new_value, string *error_message,
                      bool strict) const {
	GetCastFunctionInput get_input(context);
	return TryCastAs(CastFunctionSet::Get(context), get_input, target_type, new_value, error_message, strict);
}

// Convenience overload using a default-constructed (built-in) cast set.
bool Value::DefaultTryCastAs(const LogicalType &target_type, Value &new_value, string *error_message,
                             bool strict) const {
	CastFunctionSet set;
	GetCastFunctionInput get_input;
	return TryCastAs(set, get_input, target_type, new_value, error_message, strict);
}

// Throwing variant of TryCastAs.
Value Value::CastAs(CastFunctionSet &set, GetCastFunctionInput &get_input, const LogicalType &target_type,
                    bool strict) const {
	Value new_value;
	string error_message;
	if (!TryCastAs(set, get_input, target_type, new_value, &error_message, strict)) {
		throw InvalidInputException("Failed to cast value: %s", error_message);
	}
	return new_value;
}

Value Value::CastAs(ClientContext &context, const LogicalType &target_type, bool strict) const {
	GetCastFunctionInput get_input(context);
	return CastAs(CastFunctionSet::Get(context), get_input, target_type, strict);
}

Value Value::DefaultCastAs(const LogicalType &target_type, bool strict) const {
	CastFunctionSet set;
	GetCastFunctionInput get_input;
	return CastAs(set, get_input, target_type, strict);
}

// In-place variant: on success this value is replaced by the cast result;
// on failure it is left untouched.
bool Value::TryCastAs(CastFunctionSet &set, GetCastFunctionInput &get_input, const LogicalType &target_type,
                      bool strict) {
	Value new_value;
	string error_message;
	if (!TryCastAs(set, get_input, target_type, new_value, &error_message, strict)) {
		return false;
	}
	type_ = target_type;
	is_null = new_value.is_null;
	value_ = new_value.value_;
	value_info_ = std::move(new_value.value_info_);
	return true;
}

bool Value::TryCastAs(ClientContext &context, const LogicalType &target_type, bool strict) {
	GetCastFunctionInput get_input(context);
	return TryCastAs(CastFunctionSet::Get(context), get_input, target_type, strict);
}

bool Value::DefaultTryCastAs(const LogicalType &target_type, bool strict) {
	CastFunctionSet set;
	GetCastFunctionInput get_input;
	return TryCastAs(set, get_input, target_type, strict);
}

// Relabel the logical type WITHOUT converting the payload — caller must
// guarantee the physical representations match.
void Value::Reinterpret(LogicalType new_type) {
	this->type_ = std::move(new_type);
}

// Serialize type, null flag, and (if not NULL) the payload keyed by physical type.
// Property ids 100/101/102 must stay in sync with Deserialize below.
void Value::Serialize(Serializer &serializer) const {
	serializer.WriteProperty(100, "type", type_);
	serializer.WriteProperty(101, "is_null", is_null);
	if (!IsNull()) {
		switch (type_.InternalType()) {
		case PhysicalType::BIT:
			throw InternalException("BIT type should not be serialized");
		case PhysicalType::BOOL:
			serializer.WriteProperty(102, "value", value_.boolean);
			break;
		case PhysicalType::INT8:
			serializer.WriteProperty(102, "value", value_.tinyint);
			break;
		case PhysicalType::INT16:
			serializer.WriteProperty(102, "value", value_.smallint);
			break;
		case PhysicalType::INT32:
			serializer.WriteProperty(102, "value", value_.integer);
			break;
		case PhysicalType::INT64:
			serializer.WriteProperty(102, "value", value_.bigint);
			break;
		case PhysicalType::UINT8:
			serializer.WriteProperty(102, "value", value_.utinyint);
			break;
		case PhysicalType::UINT16:
			serializer.WriteProperty(102, "value", value_.usmallint);
			break;
		case PhysicalType::UINT32:
			serializer.WriteProperty(102, "value", value_.uinteger);
			break;
		case PhysicalType::UINT64:
			serializer.WriteProperty(102, "value", value_.ubigint);
			break;
		case PhysicalType::INT128:
			serializer.WriteProperty(102, "value", value_.hugeint);
			break;
		case PhysicalType::FLOAT:
			serializer.WriteProperty(102, "value", value_.float_);
			break;
		case PhysicalType::DOUBLE:
			serializer.WriteProperty(102, "value", value_.double_);
			break;
		case PhysicalType::INTERVAL:
			serializer.WriteProperty(102, "value", value_.interval);
			break;
		case PhysicalType::VARCHAR: {
			if (type_.id() == LogicalTypeId::BLOB) {
				// blobs are serialized in their escaped string form
				auto blob_str = Blob::ToString(StringValue::Get(*this));
				serializer.WriteProperty(102, "value", blob_str);
			} else {
				serializer.WriteProperty(102, "value", StringValue::Get(*this));
			}
		} break;
		case PhysicalType::LIST: {
			serializer.WriteObject(102, "value", [&](Serializer &serializer) {
				auto &children = ListValue::GetChildren(*this);
				serializer.WriteProperty(100, "children", children);
			});
		} break;
		case PhysicalType::STRUCT: {
			serializer.WriteObject(102, "value", [&](Serializer &serializer) {
				auto &children = StructValue::GetChildren(*this);
				serializer.WriteProperty(100, "children", children);
			});
		} break;
		default:
			throw NotImplementedException("Unimplemented type for Serialize");
		}
	}
}

// Inverse of Serialize: reads type, null flag, then the physical-type payload.
Value Value::Deserialize(Deserializer &deserializer) {
	auto type = deserializer.ReadProperty<LogicalType>(100, "type");
	auto is_null = deserializer.ReadProperty<bool>(101, "is_null");
	Value new_value = Value(type);
	if (is_null) {
		return new_value;
	}
	new_value.is_null = false;
	switch (type.InternalType()) {
	case PhysicalType::BIT:
		throw InternalException("BIT type should not be deserialized");
	case PhysicalType::BOOL:
		new_value.value_.boolean = deserializer.ReadProperty<bool>(102, "value");
		break;
	case PhysicalType::UINT8:
		new_value.value_.utinyint = deserializer.ReadProperty<uint8_t>(102, "value");
		break;
	case PhysicalType::INT8:
		new_value.value_.tinyint = deserializer.ReadProperty<int8_t>(102, "value");
		break;
	case PhysicalType::UINT16:
		new_value.value_.usmallint = deserializer.ReadProperty<uint16_t>(102, "value");
		break;
	case PhysicalType::INT16:
		new_value.value_.smallint = deserializer.ReadProperty<int16_t>(102, "value");
		break;
	case PhysicalType::UINT32:
		new_value.value_.uinteger = deserializer.ReadProperty<uint32_t>(102, "value");
		break;
	case PhysicalType::INT32:
		new_value.value_.integer = deserializer.ReadProperty<int32_t>(102, "value");
		break;
	case PhysicalType::UINT64:
		new_value.value_.ubigint = deserializer.ReadProperty<uint64_t>(102, "value");
		break;
	case PhysicalType::INT64:
		new_value.value_.bigint = deserializer.ReadProperty<int64_t>(102, "value");
		break;
	case PhysicalType::INT128:
		new_value.value_.hugeint = deserializer.ReadProperty<hugeint_t>(102, "value");
		break;
	case PhysicalType::FLOAT:
		new_value.value_.float_ = deserializer.ReadProperty<float>(102, "value");
		break;
	case PhysicalType::DOUBLE:
		new_value.value_.double_ = deserializer.ReadProperty<double>(102, "value");
		break;
	case PhysicalType::INTERVAL:
		new_value.value_.interval = deserializer.ReadProperty<interval_t>(102, "value");
		break;
	case PhysicalType::VARCHAR: {
		auto str = deserializer.ReadProperty<string>(102, "value");
		if (type.id() == LogicalTypeId::BLOB) {
			new_value.value_info_ = make_shared<StringValueInfo>(Blob::ToBlob(str));
		} else {
			new_value.value_info_ = make_shared<StringValueInfo>(str);
		}
	} break;
	case PhysicalType::LIST: {
		deserializer.ReadObject(102, "value", [&](Deserializer &obj) {
			auto children = obj.ReadProperty<vector<Value>>(100, "children");
			new_value.value_info_ = make_shared<NestedValueInfo>(children);
		});
	} break;
	case PhysicalType::STRUCT: {
		deserializer.ReadObject(102, "value", [&](Deserializer &obj) {
			auto children = obj.ReadProperty<vector<Value>>(100, "children");
			new_value.value_info_ = make_shared<NestedValueInfo>(children);
		});
	} break;
	default:
		throw NotImplementedException("Unimplemented type for Deserialize");
	}
	return new_value;
}

void Value::Print() const {
	Printer::Print(ToString());
}

// SQL IS NOT DISTINCT FROM semantics (NULL == NULL is true).
bool Value::NotDistinctFrom(const Value &lvalue, const Value &rvalue) {
	return ValueOperations::NotDistinctFrom(lvalue, rvalue);
}

// Normalize a string for result comparison: strip right padding and escape NULs.
static string SanitizeValue(string input) {
	// some results might contain padding spaces, e.g. when rendering
	// VARCHAR(10) and the string only has 6 characters, they will be padded
	// with spaces to 10 in the rendering. We don't do that here yet as we
	// are looking at internal structures. So just ignore any extra spaces
	// on the right
	StringUtil::RTrim(input);
	// for result checking code, replace null bytes with their escaped value (\0)
	return StringUtil::Replace(input, string("\0", 1), "\\0");
}

// Lenient equality used by result-checking code: approximate float compare,
// sanitized string compare, NULL == NULL allowed.
bool Value::ValuesAreEqual(CastFunctionSet &set, GetCastFunctionInput &get_input, const Value &result_value,
                           const Value &value) {
	if (result_value.IsNull() != value.IsNull()) {
		return false;
	}
	if (result_value.IsNull() && value.IsNull()) {
		// NULL = NULL in checking code
		return true;
	}
	switch (value.type_.id()) {
	case LogicalTypeId::FLOAT: {
		auto other = result_value.CastAs(set, get_input, LogicalType::FLOAT);
		float ldecimal = value.value_.float_;
		float rdecimal = other.value_.float_;
		return ApproxEqual(ldecimal, rdecimal);
	}
	case LogicalTypeId::DOUBLE: {
		auto other = result_value.CastAs(set, get_input, LogicalType::DOUBLE);
		double ldecimal = value.value_.double_;
		double rdecimal = other.value_.double_;
		return ApproxEqual(ldecimal, rdecimal);
	}
	case LogicalTypeId::VARCHAR: {
		auto other = result_value.CastAs(set, get_input, LogicalType::VARCHAR);
		string left = SanitizeValue(StringValue::Get(other));
		string right = SanitizeValue(StringValue::Get(value));
		return left == right;
	}
	default:
		// if only the result side is floating point, swap so the float branch runs
		if (result_value.type_.id() == LogicalTypeId::FLOAT || result_value.type_.id() == LogicalTypeId::DOUBLE) {
			return Value::ValuesAreEqual(set, get_input, value, result_value);
		}
		return value == result_value;
	}
}

bool Value::ValuesAreEqual(ClientContext &context, const Value &result_value, const Value &value) {
	GetCastFunctionInput get_input(context);
	return Value::ValuesAreEqual(CastFunctionSet::Get(context), get_input, result_value, value);
}
// Default-cast-set variant of ValuesAreEqual.
// NOTE(review): the signature opens at the tail of the previous chunk line;
// reconstructed from the sibling overloads.
bool Value::DefaultValuesAreEqual(const Value &result_value, const Value &value) {
	CastFunctionSet set;
	GetCastFunctionInput get_input;
	return Value::ValuesAreEqual(set, get_input, result_value, value);
}

} // namespace duckdb

#include <cstring> // strlen() on Solaris

namespace duckdb {

// Primary constructor: optionally allocate (and optionally zero) backing data
// for `capacity` entries of the given logical type.
Vector::Vector(LogicalType type_p, bool create_data, bool zero_data, idx_t capacity)
    : vector_type(VectorType::FLAT_VECTOR), type(std::move(type_p)), data(nullptr) {
	if (create_data) {
		Initialize(zero_data, capacity);
	}
}

Vector::Vector(LogicalType type_p, idx_t capacity) : Vector(std::move(type_p), true, false, capacity) {
}

// Wrap an externally-owned data pointer (no allocation, no ownership taken).
Vector::Vector(LogicalType type_p, data_ptr_t dataptr)
    : vector_type(VectorType::FLAT_VECTOR), type(std::move(type_p)), data(dataptr) {
	if (dataptr && !type.IsValid()) {
		throw InternalException("Cannot create a vector of type INVALID!");
	}
}

// Construct from a vector cache (reuses the cache's allocated buffers).
Vector::Vector(const VectorCache &cache) : type(cache.GetType()) {
	ResetFromCache(cache);
}

// Referencing constructor: shares the other vector's buffers.
Vector::Vector(Vector &other) : type(other.type) {
	Reference(other);
}

// Dictionary-slice constructor.
Vector::Vector(Vector &other, const SelectionVector &sel, idx_t count) : type(other.type) {
	Slice(other, sel, count);
}

// Offset-slice constructor over [offset, end).
Vector::Vector(Vector &other, idx_t offset, idx_t end) : type(other.type) {
	Slice(other, offset, end);
}

// Constant-vector constructor holding a single value.
Vector::Vector(const Value &value) : type(value.type()) {
	Reference(value);
}

Vector::Vector(Vector &&other) noexcept
    : vector_type(other.vector_type), type(std::move(other.type)), data(other.data),
      validity(std::move(other.validity)), buffer(std::move(other.buffer)), auxiliary(std::move(other.auxiliary)) {
}

// Turn this vector into a CONSTANT_VECTOR holding `value`. Nested types
// (STRUCT/LIST) allocate the appropriate auxiliary child buffers.
void Vector::Reference(const Value &value) {
	D_ASSERT(GetType().id() == value.type().id());
	this->vector_type = VectorType::CONSTANT_VECTOR;
	buffer = VectorBuffer::CreateConstantVector(value.type());
	auto internal_type = value.type().InternalType();
	if (internal_type == PhysicalType::STRUCT) {
		auto struct_buffer = make_uniq<VectorStructBuffer>();
		auto &child_types = StructType::GetChildTypes(value.type());
		auto &child_vectors = struct_buffer->GetChildren();
		for (idx_t i = 0; i < child_types.size(); i++) {
			// a NULL struct gets NULL children of the matching child types
			auto vector =
			    make_uniq<Vector>(value.IsNull() ? Value(child_types[i].second) : StructValue::GetChildren(value)[i]);
			child_vectors.push_back(std::move(vector));
		}
		auxiliary = shared_ptr<VectorBuffer>(struct_buffer.release());
		if (value.IsNull()) {
			SetValue(0, value);
		}
	} else if (internal_type == PhysicalType::LIST) {
		auto list_buffer = make_uniq<VectorListBuffer>(value.type());
		auxiliary = shared_ptr<VectorBuffer>(list_buffer.release());
		data = buffer->GetData();
		SetValue(0, value);
	} else {
		auxiliary.reset();
		data = buffer->GetData();
		SetValue(0, value);
	}
}

// Share another vector's buffers; types must match exactly.
void Vector::Reference(const Vector &other) {
	if (other.GetType().id() != GetType().id()) {
		throw InternalException("Vector::Reference used on vector of different type");
	}
	D_ASSERT(other.GetType() == GetType());
	Reinterpret(other);
}

// Like Reference, but also adopts the other vector's type.
void Vector::ReferenceAndSetType(const Vector &other) {
	type = other.GetType();
	Reference(other);
}

// Share buffers without any type check (caller guarantees layout compatibility).
void Vector::Reinterpret(const Vector &other) {
	vector_type = other.vector_type;
	AssignSharedPointer(buffer, other.buffer);
	AssignSharedPointer(auxiliary, other.auxiliary);
	data = other.data;
	validity = other.validity;
}

void Vector::ResetFromCache(const VectorCache &cache) {
	cache.ResetFromCache(*this);
}

// NOTE(review): Vector::Slice(Vector &, idx_t offset, idx_t end) begins here in
// the original but is truncated at this chunk boundary — its body continues
// beyond the visible source and is intentionally not reconstructed.
} - new_vector.validity.Slice(other.validity, offset, end - offset); - Reference(new_vector); - } else { - Reference(other); - if (offset > 0) { - data = data + GetTypeIdSize(internal_type) * offset; - validity.Slice(other.validity, offset, end - offset); - } - } -} - -void Vector::Slice(Vector &other, const SelectionVector &sel, idx_t count) { - Reference(other); - Slice(sel, count); -} - -void Vector::Slice(const SelectionVector &sel, idx_t count) { - if (GetVectorType() == VectorType::CONSTANT_VECTOR) { - // dictionary on a constant is just a constant - return; - } - if (GetVectorType() == VectorType::DICTIONARY_VECTOR) { - // already a dictionary, slice the current dictionary - auto ¤t_sel = DictionaryVector::SelVector(*this); - auto sliced_dictionary = current_sel.Slice(sel, count); - buffer = make_buffer(std::move(sliced_dictionary)); - if (GetType().InternalType() == PhysicalType::STRUCT) { - auto &child_vector = DictionaryVector::Child(*this); - - Vector new_child(child_vector); - new_child.auxiliary = make_buffer(new_child, sel, count); - auxiliary = make_buffer(std::move(new_child)); - } - return; - } - - if (GetVectorType() == VectorType::FSST_VECTOR) { - Flatten(sel, count); - return; - } - - Vector child_vector(*this); - auto internal_type = GetType().InternalType(); - if (internal_type == PhysicalType::STRUCT) { - child_vector.auxiliary = make_buffer(*this, sel, count); - } - auto child_ref = make_buffer(std::move(child_vector)); - auto dict_buffer = make_buffer(sel); - vector_type = VectorType::DICTIONARY_VECTOR; - buffer = std::move(dict_buffer); - auxiliary = std::move(child_ref); -} - -void Vector::Slice(const SelectionVector &sel, idx_t count, SelCache &cache) { - if (GetVectorType() == VectorType::DICTIONARY_VECTOR && GetType().InternalType() != PhysicalType::STRUCT) { - // dictionary vector: need to merge dictionaries - // check if we have a cached entry - auto ¤t_sel = DictionaryVector::SelVector(*this); - auto target_data = 
current_sel.data(); - auto entry = cache.cache.find(target_data); - if (entry != cache.cache.end()) { - // cached entry exists: use that - this->buffer = make_buffer(entry->second->Cast().GetSelVector()); - vector_type = VectorType::DICTIONARY_VECTOR; - } else { - Slice(sel, count); - cache.cache[target_data] = this->buffer; - } - } else { - Slice(sel, count); - } -} - -void Vector::Initialize(bool zero_data, idx_t capacity) { - auxiliary.reset(); - validity.Reset(); - auto &type = GetType(); - auto internal_type = type.InternalType(); - if (internal_type == PhysicalType::STRUCT) { - auto struct_buffer = make_uniq(type, capacity); - auxiliary = shared_ptr(struct_buffer.release()); - } else if (internal_type == PhysicalType::LIST) { - auto list_buffer = make_uniq(type, capacity); - auxiliary = shared_ptr(list_buffer.release()); - } - auto type_size = GetTypeIdSize(internal_type); - if (type_size > 0) { - buffer = VectorBuffer::CreateStandardVector(type, capacity); - data = buffer->GetData(); - if (zero_data) { - memset(data, 0, capacity * type_size); - } - } - if (capacity > STANDARD_VECTOR_SIZE) { - validity.Resize(STANDARD_VECTOR_SIZE, capacity); - } -} - -struct DataArrays { - Vector &vec; - data_ptr_t data; - optional_ptr buffer; - idx_t type_size; - bool is_nested; - DataArrays(Vector &vec, data_ptr_t data, optional_ptr buffer, idx_t type_size, bool is_nested) - : vec(vec), data(data), buffer(buffer), type_size(type_size), is_nested(is_nested) { - } -}; - -void FindChildren(vector &to_resize, VectorBuffer &auxiliary) { - if (auxiliary.GetBufferType() == VectorBufferType::LIST_BUFFER) { - auto &buffer = auxiliary.Cast(); - auto &child = buffer.GetChild(); - auto data = child.GetData(); - if (!data) { - //! 
Nested type - DataArrays arrays(child, data, child.GetBuffer().get(), GetTypeIdSize(child.GetType().InternalType()), - true); - to_resize.emplace_back(arrays); - FindChildren(to_resize, *child.GetAuxiliary()); - } else { - DataArrays arrays(child, data, child.GetBuffer().get(), GetTypeIdSize(child.GetType().InternalType()), - false); - to_resize.emplace_back(arrays); - } - } else if (auxiliary.GetBufferType() == VectorBufferType::STRUCT_BUFFER) { - auto &buffer = auxiliary.Cast(); - auto &children = buffer.GetChildren(); - for (auto &child : children) { - auto data = child->GetData(); - if (!data) { - //! Nested type - DataArrays arrays(*child, data, child->GetBuffer().get(), - GetTypeIdSize(child->GetType().InternalType()), true); - to_resize.emplace_back(arrays); - FindChildren(to_resize, *child->GetAuxiliary()); - } else { - DataArrays arrays(*child, data, child->GetBuffer().get(), - GetTypeIdSize(child->GetType().InternalType()), false); - to_resize.emplace_back(arrays); - } - } - } -} -void Vector::Resize(idx_t cur_size, idx_t new_size) { - vector to_resize; - if (!buffer) { - buffer = make_buffer(0); - } - if (!data) { - //! 
this is a nested structure - DataArrays arrays(*this, data, buffer.get(), GetTypeIdSize(GetType().InternalType()), true); - to_resize.emplace_back(arrays); - FindChildren(to_resize, *auxiliary); - } else { - DataArrays arrays(*this, data, buffer.get(), GetTypeIdSize(GetType().InternalType()), false); - to_resize.emplace_back(arrays); - } - for (auto &data_to_resize : to_resize) { - if (!data_to_resize.is_nested) { - auto new_data = make_unsafe_uniq_array(new_size * data_to_resize.type_size); - memcpy(new_data.get(), data_to_resize.data, cur_size * data_to_resize.type_size * sizeof(data_t)); - data_to_resize.buffer->SetData(std::move(new_data)); - data_to_resize.vec.data = data_to_resize.buffer->GetData(); - } - data_to_resize.vec.validity.Resize(cur_size, new_size); - } -} - -void Vector::SetValue(idx_t index, const Value &val) { - if (GetVectorType() == VectorType::DICTIONARY_VECTOR) { - // dictionary: apply dictionary and forward to child - auto &sel_vector = DictionaryVector::SelVector(*this); - auto &child = DictionaryVector::Child(*this); - return child.SetValue(sel_vector.get_index(index), val); - } - if (val.type() != GetType()) { - SetValue(index, val.DefaultCastAs(GetType())); - return; - } - D_ASSERT(val.type().InternalType() == GetType().InternalType()); - - validity.EnsureWritable(); - validity.Set(index, !val.IsNull()); - if (val.IsNull() && GetType().InternalType() != PhysicalType::STRUCT) { - // for structs we still need to set the child-entries to NULL - // so we do not bail out yet - return; - } - - switch (GetType().InternalType()) { - case PhysicalType::BOOL: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::INT8: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::INT16: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::INT32: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::INT64: - 
reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::INT128: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::UINT8: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::UINT16: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::UINT32: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::UINT64: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::FLOAT: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::DOUBLE: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::INTERVAL: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::VARCHAR: - reinterpret_cast(data)[index] = StringVector::AddStringOrBlob(*this, StringValue::Get(val)); - break; - case PhysicalType::STRUCT: { - D_ASSERT(GetVectorType() == VectorType::CONSTANT_VECTOR || GetVectorType() == VectorType::FLAT_VECTOR); - - auto &children = StructVector::GetEntries(*this); - if (val.IsNull()) { - for (size_t i = 0; i < children.size(); i++) { - auto &vec_child = children[i]; - vec_child->SetValue(index, Value()); - } - } else { - auto &val_children = StructValue::GetChildren(val); - D_ASSERT(children.size() == val_children.size()); - for (size_t i = 0; i < children.size(); i++) { - auto &vec_child = children[i]; - auto &struct_child = val_children[i]; - vec_child->SetValue(index, struct_child); - } - } - break; - } - case PhysicalType::LIST: { - auto offset = ListVector::GetListSize(*this); - auto &val_children = ListValue::GetChildren(val); - if (!val_children.empty()) { - for (idx_t i = 0; i < val_children.size(); i++) { - ListVector::PushBack(*this, val_children[i]); - } - } - //! 
now set the pointer - auto &entry = reinterpret_cast(data)[index]; - entry.length = val_children.size(); - entry.offset = offset; - break; - } - default: - throw InternalException("Unimplemented type for Vector::SetValue"); - } -} - -Value Vector::GetValueInternal(const Vector &v_p, idx_t index_p) { - const Vector *vector = &v_p; - idx_t index = index_p; - bool finished = false; - while (!finished) { - switch (vector->GetVectorType()) { - case VectorType::CONSTANT_VECTOR: - index = 0; - finished = true; - break; - case VectorType::FLAT_VECTOR: - finished = true; - break; - case VectorType::FSST_VECTOR: - finished = true; - break; - // dictionary: apply dictionary and forward to child - case VectorType::DICTIONARY_VECTOR: { - auto &sel_vector = DictionaryVector::SelVector(*vector); - auto &child = DictionaryVector::Child(*vector); - vector = &child; - index = sel_vector.get_index(index); - break; - } - case VectorType::SEQUENCE_VECTOR: { - int64_t start, increment; - SequenceVector::GetSequence(*vector, start, increment); - return Value::Numeric(vector->GetType(), start + increment * index); - } - default: - throw InternalException("Unimplemented vector type for Vector::GetValue"); - } - } - auto data = vector->data; - auto &validity = vector->validity; - auto &type = vector->GetType(); - - if (!validity.RowIsValid(index)) { - return Value(vector->GetType()); - } - - if (vector->GetVectorType() == VectorType::FSST_VECTOR) { - if (vector->GetType().InternalType() != PhysicalType::VARCHAR) { - throw InternalException("FSST Vector with non-string datatype found!"); - } - auto str_compressed = reinterpret_cast(data)[index]; - Value result = FSSTPrimitives::DecompressValue(FSSTVector::GetDecoder(const_cast(*vector)), - str_compressed.GetData(), str_compressed.GetSize()); - return result; - } - - switch (vector->GetType().id()) { - case LogicalTypeId::BOOLEAN: - return Value::BOOLEAN(reinterpret_cast(data)[index]); - case LogicalTypeId::TINYINT: - return 
Value::TINYINT(reinterpret_cast(data)[index]); - case LogicalTypeId::SMALLINT: - return Value::SMALLINT(reinterpret_cast(data)[index]); - case LogicalTypeId::INTEGER: - return Value::INTEGER(reinterpret_cast(data)[index]); - case LogicalTypeId::DATE: - return Value::DATE(reinterpret_cast(data)[index]); - case LogicalTypeId::TIME: - return Value::TIME(reinterpret_cast(data)[index]); - case LogicalTypeId::TIME_TZ: - return Value::TIMETZ(reinterpret_cast(data)[index]); - case LogicalTypeId::BIGINT: - return Value::BIGINT(reinterpret_cast(data)[index]); - case LogicalTypeId::UTINYINT: - return Value::UTINYINT(reinterpret_cast(data)[index]); - case LogicalTypeId::USMALLINT: - return Value::USMALLINT(reinterpret_cast(data)[index]); - case LogicalTypeId::UINTEGER: - return Value::UINTEGER(reinterpret_cast(data)[index]); - case LogicalTypeId::UBIGINT: - return Value::UBIGINT(reinterpret_cast(data)[index]); - case LogicalTypeId::TIMESTAMP: - return Value::TIMESTAMP(reinterpret_cast(data)[index]); - case LogicalTypeId::TIMESTAMP_NS: - return Value::TIMESTAMPNS(reinterpret_cast(data)[index]); - case LogicalTypeId::TIMESTAMP_MS: - return Value::TIMESTAMPMS(reinterpret_cast(data)[index]); - case LogicalTypeId::TIMESTAMP_SEC: - return Value::TIMESTAMPSEC(reinterpret_cast(data)[index]); - case LogicalTypeId::TIMESTAMP_TZ: - return Value::TIMESTAMPTZ(reinterpret_cast(data)[index]); - case LogicalTypeId::HUGEINT: - return Value::HUGEINT(reinterpret_cast(data)[index]); - case LogicalTypeId::UUID: - return Value::UUID(reinterpret_cast(data)[index]); - case LogicalTypeId::DECIMAL: { - auto width = DecimalType::GetWidth(type); - auto scale = DecimalType::GetScale(type); - switch (type.InternalType()) { - case PhysicalType::INT16: - return Value::DECIMAL(reinterpret_cast(data)[index], width, scale); - case PhysicalType::INT32: - return Value::DECIMAL(reinterpret_cast(data)[index], width, scale); - case PhysicalType::INT64: - return Value::DECIMAL(reinterpret_cast(data)[index], width, 
scale); - case PhysicalType::INT128: - return Value::DECIMAL(reinterpret_cast(data)[index], width, scale); - default: - throw InternalException("Physical type '%s' has a width bigger than 38, which is not supported", - TypeIdToString(type.InternalType())); - } - } - case LogicalTypeId::ENUM: { - switch (type.InternalType()) { - case PhysicalType::UINT8: - return Value::ENUM(reinterpret_cast(data)[index], type); - case PhysicalType::UINT16: - return Value::ENUM(reinterpret_cast(data)[index], type); - case PhysicalType::UINT32: - return Value::ENUM(reinterpret_cast(data)[index], type); - default: - throw InternalException("ENUM can only have unsigned integers as physical types"); - } - } - case LogicalTypeId::POINTER: - return Value::POINTER(reinterpret_cast(data)[index]); - case LogicalTypeId::FLOAT: - return Value::FLOAT(reinterpret_cast(data)[index]); - case LogicalTypeId::DOUBLE: - return Value::DOUBLE(reinterpret_cast(data)[index]); - case LogicalTypeId::INTERVAL: - return Value::INTERVAL(reinterpret_cast(data)[index]); - case LogicalTypeId::VARCHAR: { - auto str = reinterpret_cast(data)[index]; - return Value(str.GetString()); - } - case LogicalTypeId::AGGREGATE_STATE: - case LogicalTypeId::BLOB: { - auto str = reinterpret_cast(data)[index]; - return Value::BLOB(const_data_ptr_cast(str.GetData()), str.GetSize()); - } - case LogicalTypeId::BIT: { - auto str = reinterpret_cast(data)[index]; - return Value::BIT(const_data_ptr_cast(str.GetData()), str.GetSize()); - } - case LogicalTypeId::MAP: { - auto offlen = reinterpret_cast(data)[index]; - auto &child_vec = ListVector::GetEntry(*vector); - duckdb::vector children; - for (idx_t i = offlen.offset; i < offlen.offset + offlen.length; i++) { - children.push_back(child_vec.GetValue(i)); - } - return Value::MAP(ListType::GetChildType(type), std::move(children)); - } - case LogicalTypeId::UNION: { - auto tag = UnionVector::GetTag(*vector, index); - auto value = UnionVector::GetMember(*vector, tag).GetValue(index); - 
auto members = UnionType::CopyMemberTypes(type); - return Value::UNION(members, tag, std::move(value)); - } - case LogicalTypeId::STRUCT: { - // we can derive the value schema from the vector schema - auto &child_entries = StructVector::GetEntries(*vector); - child_list_t children; - for (idx_t child_idx = 0; child_idx < child_entries.size(); child_idx++) { - auto &struct_child = child_entries[child_idx]; - children.push_back(make_pair(StructType::GetChildName(type, child_idx), struct_child->GetValue(index_p))); - } - return Value::STRUCT(std::move(children)); - } - case LogicalTypeId::LIST: { - auto offlen = reinterpret_cast(data)[index]; - auto &child_vec = ListVector::GetEntry(*vector); - duckdb::vector children; - for (idx_t i = offlen.offset; i < offlen.offset + offlen.length; i++) { - children.push_back(child_vec.GetValue(i)); - } - return Value::LIST(ListType::GetChildType(type), std::move(children)); - } - default: - throw InternalException("Unimplemented type for value access"); - } -} - -Value Vector::GetValue(const Vector &v_p, idx_t index_p) { - auto value = GetValueInternal(v_p, index_p); - // set the alias of the type to the correct value, if there is a type alias - if (v_p.GetType().HasAlias()) { - value.GetTypeMutable().CopyAuxInfo(v_p.GetType()); - } - if (v_p.GetType().id() != LogicalTypeId::AGGREGATE_STATE && value.type().id() != LogicalTypeId::AGGREGATE_STATE) { - - D_ASSERT(v_p.GetType() == value.type()); - } - return value; -} - -Value Vector::GetValue(idx_t index) const { - return GetValue(*this, index); -} - -// LCOV_EXCL_START -string VectorTypeToString(VectorType type) { - switch (type) { - case VectorType::FLAT_VECTOR: - return "FLAT"; - case VectorType::FSST_VECTOR: - return "FSST"; - case VectorType::SEQUENCE_VECTOR: - return "SEQUENCE"; - case VectorType::DICTIONARY_VECTOR: - return "DICTIONARY"; - case VectorType::CONSTANT_VECTOR: - return "CONSTANT"; - default: - return "UNKNOWN"; - } -} - -string Vector::ToString(idx_t count) const 
{ - string retval = - VectorTypeToString(GetVectorType()) + " " + GetType().ToString() + ": " + to_string(count) + " = [ "; - switch (GetVectorType()) { - case VectorType::FLAT_VECTOR: - case VectorType::DICTIONARY_VECTOR: - for (idx_t i = 0; i < count; i++) { - retval += GetValue(i).ToString() + (i == count - 1 ? "" : ", "); - } - break; - case VectorType::FSST_VECTOR: { - for (idx_t i = 0; i < count; i++) { - string_t compressed_string = reinterpret_cast(data)[i]; - Value val = FSSTPrimitives::DecompressValue(FSSTVector::GetDecoder(const_cast(*this)), - compressed_string.GetData(), compressed_string.GetSize()); - retval += GetValue(i).ToString() + (i == count - 1 ? "" : ", "); - } - } break; - case VectorType::CONSTANT_VECTOR: - retval += GetValue(0).ToString(); - break; - case VectorType::SEQUENCE_VECTOR: { - int64_t start, increment; - SequenceVector::GetSequence(*this, start, increment); - for (idx_t i = 0; i < count; i++) { - retval += to_string(start + increment * i) + (i == count - 1 ? 
"" : ", "); - } - break; - } - default: - retval += "UNKNOWN VECTOR TYPE"; - break; - } - retval += "]"; - return retval; -} - -void Vector::Print(idx_t count) const { - Printer::Print(ToString(count)); -} - -string Vector::ToString() const { - string retval = VectorTypeToString(GetVectorType()) + " " + GetType().ToString() + ": (UNKNOWN COUNT) [ "; - switch (GetVectorType()) { - case VectorType::FLAT_VECTOR: - case VectorType::DICTIONARY_VECTOR: - break; - case VectorType::CONSTANT_VECTOR: - retval += GetValue(0).ToString(); - break; - case VectorType::SEQUENCE_VECTOR: { - break; - } - default: - retval += "UNKNOWN VECTOR TYPE"; - break; - } - retval += "]"; - return retval; -} - -void Vector::Print() const { - Printer::Print(ToString()); -} -// LCOV_EXCL_STOP - -template -static void TemplatedFlattenConstantVector(data_ptr_t data, data_ptr_t old_data, idx_t count) { - auto constant = Load(old_data); - auto output = (T *)data; - for (idx_t i = 0; i < count; i++) { - output[i] = constant; - } -} - -void Vector::Flatten(idx_t count) { - switch (GetVectorType()) { - case VectorType::FLAT_VECTOR: - // already a flat vector - break; - case VectorType::FSST_VECTOR: { - // Even though count may only be a part of the vector, we need to flatten the whole thing due to the way - // ToUnifiedFormat uses flatten - idx_t total_count = FSSTVector::GetCount(*this); - // create vector to decompress into - Vector other(GetType(), total_count); - // now copy the data of this vector to the other vector, decompressing the strings in the process - VectorOperations::Copy(*this, other, total_count, 0, 0); - // create a reference to the data in the other vector - this->Reference(other); - break; - } - case VectorType::DICTIONARY_VECTOR: { - // create a new flat vector of this type - Vector other(GetType(), count); - // now copy the data of this vector to the other vector, removing the selection vector in the process - VectorOperations::Copy(*this, other, count, 0, 0); - // create a 
reference to the data in the other vector - this->Reference(other); - break; - } - case VectorType::CONSTANT_VECTOR: { - bool is_null = ConstantVector::IsNull(*this); - // allocate a new buffer for the vector - auto old_buffer = std::move(buffer); - auto old_data = data; - buffer = VectorBuffer::CreateStandardVector(type, MaxValue(STANDARD_VECTOR_SIZE, count)); - if (old_buffer) { - D_ASSERT(buffer->GetAuxiliaryData() == nullptr); - // The old buffer might be relying on the auxiliary data, keep it alive - buffer->MoveAuxiliaryData(*old_buffer); - } - data = buffer->GetData(); - vector_type = VectorType::FLAT_VECTOR; - if (is_null) { - // constant NULL, set nullmask - validity.EnsureWritable(); - validity.SetAllInvalid(count); - return; - } - // non-null constant: have to repeat the constant - switch (GetType().InternalType()) { - case PhysicalType::BOOL: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::INT8: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::INT16: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::INT32: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::INT64: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::UINT8: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::UINT16: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::UINT32: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::UINT64: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::INT128: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::FLOAT: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::DOUBLE: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case 
PhysicalType::INTERVAL: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::VARCHAR: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::LIST: { - TemplatedFlattenConstantVector(data, old_data, count); - break; - } - case PhysicalType::STRUCT: { - auto normalified_buffer = make_uniq(); - - auto &new_children = normalified_buffer->GetChildren(); - - auto &child_entries = StructVector::GetEntries(*this); - for (auto &child : child_entries) { - D_ASSERT(child->GetVectorType() == VectorType::CONSTANT_VECTOR); - auto vector = make_uniq(*child); - vector->Flatten(count); - new_children.push_back(std::move(vector)); - } - auxiliary = shared_ptr(normalified_buffer.release()); - } break; - default: - throw InternalException("Unimplemented type for VectorOperations::Flatten"); - } - break; - } - case VectorType::SEQUENCE_VECTOR: { - int64_t start, increment, sequence_count; - SequenceVector::GetSequence(*this, start, increment, sequence_count); - - buffer = VectorBuffer::CreateStandardVector(GetType()); - data = buffer->GetData(); - VectorOperations::GenerateSequence(*this, sequence_count, start, increment); - break; - } - default: - throw InternalException("Unimplemented type for normalify"); - } -} - -void Vector::Flatten(const SelectionVector &sel, idx_t count) { - switch (GetVectorType()) { - case VectorType::FLAT_VECTOR: - // already a flat vector - break; - case VectorType::FSST_VECTOR: { - // create a new flat vector of this type - Vector other(GetType()); - // copy the data of this vector to the other vector, removing compression and selection vector in the process - VectorOperations::Copy(*this, other, sel, count, 0, 0); - // create a reference to the data in the other vector - this->Reference(other); - break; - } - case VectorType::SEQUENCE_VECTOR: { - int64_t start, increment; - SequenceVector::GetSequence(*this, start, increment); - - buffer = VectorBuffer::CreateStandardVector(GetType()); - 
data = buffer->GetData(); - VectorOperations::GenerateSequence(*this, count, sel, start, increment); - break; - } - default: - throw InternalException("Unimplemented type for normalify with selection vector"); - } -} - -void Vector::ToUnifiedFormat(idx_t count, UnifiedVectorFormat &format) { - switch (GetVectorType()) { - case VectorType::DICTIONARY_VECTOR: { - auto &sel = DictionaryVector::SelVector(*this); - format.owned_sel.Initialize(sel); - format.sel = &format.owned_sel; - - auto &child = DictionaryVector::Child(*this); - if (child.GetVectorType() == VectorType::FLAT_VECTOR) { - format.data = FlatVector::GetData(child); - format.validity = FlatVector::Validity(child); - } else { - // dictionary with non-flat child: create a new reference to the child and flatten it - Vector child_vector(child); - child_vector.Flatten(sel, count); - auto new_aux = make_buffer(std::move(child_vector)); - - format.data = FlatVector::GetData(new_aux->data); - format.validity = FlatVector::Validity(new_aux->data); - this->auxiliary = std::move(new_aux); - } - break; - } - case VectorType::CONSTANT_VECTOR: - format.sel = ConstantVector::ZeroSelectionVector(count, format.owned_sel); - format.data = ConstantVector::GetData(*this); - format.validity = ConstantVector::Validity(*this); - break; - default: - Flatten(count); - format.sel = FlatVector::IncrementalSelectionVector(); - format.data = FlatVector::GetData(*this); - format.validity = FlatVector::Validity(*this); - break; - } -} - -void Vector::RecursiveToUnifiedFormat(Vector &input, idx_t count, RecursiveUnifiedVectorFormat &data) { - - input.ToUnifiedFormat(count, data.unified); - - if (input.GetType().InternalType() == PhysicalType::LIST) { - auto &child = ListVector::GetEntry(input); - auto child_count = ListVector::GetListSize(input); - data.children.emplace_back(); - Vector::RecursiveToUnifiedFormat(child, child_count, data.children.back()); - - } else if (input.GetType().InternalType() == PhysicalType::STRUCT) { - auto 
&children = StructVector::GetEntries(input); - for (idx_t i = 0; i < children.size(); i++) { - data.children.emplace_back(); - } - for (idx_t i = 0; i < children.size(); i++) { - Vector::RecursiveToUnifiedFormat(*children[i], count, data.children[i]); - } - } -} - -void Vector::Sequence(int64_t start, int64_t increment, idx_t count) { - this->vector_type = VectorType::SEQUENCE_VECTOR; - this->buffer = make_buffer(sizeof(int64_t) * 3); - auto data = reinterpret_cast(buffer->GetData()); - data[0] = start; - data[1] = increment; - data[2] = int64_t(count); - validity.Reset(); - auxiliary.reset(); -} - -void Vector::Serialize(Serializer &serializer, idx_t count) { - auto &logical_type = GetType(); - - UnifiedVectorFormat vdata; - ToUnifiedFormat(count, vdata); - - const bool all_valid = (count > 0) && !vdata.validity.AllValid(); - serializer.WriteProperty(100, "all_valid", all_valid); - if (all_valid) { - ValidityMask flat_mask(count); - for (idx_t i = 0; i < count; ++i) { - auto row_idx = vdata.sel->get_index(i); - flat_mask.Set(i, vdata.validity.RowIsValid(row_idx)); - } - serializer.WriteProperty(101, "validity", const_data_ptr_cast(flat_mask.GetData()), - flat_mask.ValidityMaskSize(count)); - } - if (TypeIsConstantSize(logical_type.InternalType())) { - // constant size type: simple copy - idx_t write_size = GetTypeIdSize(logical_type.InternalType()) * count; - auto ptr = make_unsafe_uniq_array(write_size); - VectorOperations::WriteToStorage(*this, count, ptr.get()); - serializer.WriteProperty(102, "data", ptr.get(), write_size); - } else { - switch (logical_type.InternalType()) { - case PhysicalType::VARCHAR: { - auto strings = UnifiedVectorFormat::GetData(vdata); - - // Serialize data as a list - serializer.WriteList(102, "data", count, [&](Serializer::List &list, idx_t i) { - auto idx = vdata.sel->get_index(i); - auto str = !vdata.validity.RowIsValid(idx) ? 
NullValue() : strings[idx]; - list.WriteElement(str); - }); - break; - } - case PhysicalType::STRUCT: { - auto &entries = StructVector::GetEntries(*this); - - // Serialize entries as a list - serializer.WriteList(103, "children", entries.size(), [&](Serializer::List &list, idx_t i) { - list.WriteObject([&](Serializer &object) { entries[i]->Serialize(object, count); }); - }); - break; - } - case PhysicalType::LIST: { - auto &child = ListVector::GetEntry(*this); - auto list_size = ListVector::GetListSize(*this); - - // serialize the list entries in a flat array - auto entries = make_unsafe_uniq_array(count); - auto source_array = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - auto source = source_array[idx]; - entries[i].offset = source.offset; - entries[i].length = source.length; - } - serializer.WriteProperty(104, "list_size", list_size); - serializer.WriteList(105, "entries", count, [&](Serializer::List &list, idx_t i) { - list.WriteObject([&](Serializer &object) { - object.WriteProperty(100, "offset", entries[i].offset); - object.WriteProperty(101, "length", entries[i].length); - }); - }); - serializer.WriteObject(106, "child", [&](Serializer &object) { child.Serialize(object, list_size); }); - break; - } - default: - throw InternalException("Unimplemented variable width type for Vector::Serialize!"); - } - } -} - -void Vector::Deserialize(Deserializer &deserializer, idx_t count) { - auto &logical_type = GetType(); - - auto &validity = FlatVector::Validity(*this); - validity.Reset(); - const auto has_validity = deserializer.ReadProperty(100, "all_valid"); - if (has_validity) { - validity.Initialize(count); - deserializer.ReadProperty(101, "validity", data_ptr_cast(validity.GetData()), validity.ValidityMaskSize(count)); - } - - if (TypeIsConstantSize(logical_type.InternalType())) { - // constant size type: read fixed amount of data - auto column_size = GetTypeIdSize(logical_type.InternalType()) * 
count; - auto ptr = make_unsafe_uniq_array(column_size); - deserializer.ReadProperty(102, "data", ptr.get(), column_size); - - VectorOperations::ReadFromStorage(ptr.get(), count, *this); - } else { - switch (logical_type.InternalType()) { - case PhysicalType::VARCHAR: { - auto strings = FlatVector::GetData(*this); - deserializer.ReadList(102, "data", [&](Deserializer::List &list, idx_t i) { - auto str = list.ReadElement(); - if (validity.RowIsValid(i)) { - strings[i] = StringVector::AddStringOrBlob(*this, str); - } - }); - break; - } - case PhysicalType::STRUCT: { - auto &entries = StructVector::GetEntries(*this); - // Deserialize entries as a list - deserializer.ReadList(103, "children", [&](Deserializer::List &list, idx_t i) { - list.ReadObject([&](Deserializer &obj) { entries[i]->Deserialize(obj, count); }); - }); - break; - } - case PhysicalType::LIST: { - // Read the list size - auto list_size = deserializer.ReadProperty(104, "list_size"); - ListVector::Reserve(*this, list_size); - ListVector::SetListSize(*this, list_size); - - // Read the entries - auto list_entries = FlatVector::GetData(*this); - deserializer.ReadList(105, "entries", [&](Deserializer::List &list, idx_t i) { - list.ReadObject([&](Deserializer &obj) { - list_entries[i].offset = obj.ReadProperty(100, "offset"); - list_entries[i].length = obj.ReadProperty(101, "length"); - }); - }); - - // Read the child vector - deserializer.ReadObject(106, "child", [&](Deserializer &obj) { - auto &child = ListVector::GetEntry(*this); - child.Deserialize(obj, list_size); - }); - break; - } - default: - throw InternalException("Unimplemented variable width type for Vector::Deserialize!"); - } - } -} - -void Vector::SetVectorType(VectorType vector_type_p) { - this->vector_type = vector_type_p; - if (TypeIsConstantSize(GetType().InternalType()) && - (GetVectorType() == VectorType::CONSTANT_VECTOR || GetVectorType() == VectorType::FLAT_VECTOR)) { - auxiliary.reset(); - } - if (vector_type == 
VectorType::CONSTANT_VECTOR && GetType().InternalType() == PhysicalType::STRUCT) { - auto &entries = StructVector::GetEntries(*this); - for (auto &entry : entries) { - entry->SetVectorType(vector_type); - } - } -} - -void Vector::UTFVerify(const SelectionVector &sel, idx_t count) { -#ifdef DEBUG - if (count == 0) { - return; - } - if (GetType().InternalType() == PhysicalType::VARCHAR) { - // we just touch all the strings and let the sanitizer figure out if any - // of them are deallocated/corrupt - switch (GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - auto string = ConstantVector::GetData(*this); - if (!ConstantVector::IsNull(*this)) { - string->Verify(); - } - break; - } - case VectorType::FLAT_VECTOR: { - auto strings = FlatVector::GetData(*this); - for (idx_t i = 0; i < count; i++) { - auto oidx = sel.get_index(i); - if (validity.RowIsValid(oidx)) { - strings[oidx].Verify(); - } - } - break; - } - default: - break; - } - } -#endif -} - -void Vector::UTFVerify(idx_t count) { - auto flat_sel = FlatVector::IncrementalSelectionVector(); - - UTFVerify(*flat_sel, count); -} - -void Vector::VerifyMap(Vector &vector_p, const SelectionVector &sel_p, idx_t count) { -#ifdef DEBUG - D_ASSERT(vector_p.GetType().id() == LogicalTypeId::MAP); - auto &child = ListType::GetChildType(vector_p.GetType()); - D_ASSERT(StructType::GetChildCount(child) == 2); - D_ASSERT(StructType::GetChildName(child, 0) == "key"); - D_ASSERT(StructType::GetChildName(child, 1) == "value"); - - auto valid_check = MapVector::CheckMapValidity(vector_p, count, sel_p); - D_ASSERT(valid_check == MapInvalidReason::VALID); -#endif // DEBUG -} - -void Vector::VerifyUnion(Vector &vector_p, const SelectionVector &sel_p, idx_t count) { -#ifdef DEBUG - D_ASSERT(vector_p.GetType().id() == LogicalTypeId::UNION); - auto valid_check = UnionVector::CheckUnionValidity(vector_p, count, sel_p); - D_ASSERT(valid_check == UnionInvalidReason::VALID); -#endif // DEBUG -} - -void Vector::Verify(Vector &vector_p, 
const SelectionVector &sel_p, idx_t count) { -#ifdef DEBUG - if (count == 0) { - return; - } - Vector *vector = &vector_p; - const SelectionVector *sel = &sel_p; - SelectionVector owned_sel; - auto &type = vector->GetType(); - auto vtype = vector->GetVectorType(); - if (vector->GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(*vector); - D_ASSERT(child.GetVectorType() != VectorType::DICTIONARY_VECTOR); - auto &dict_sel = DictionaryVector::SelVector(*vector); - // merge the selection vectors and verify the child - auto new_buffer = dict_sel.Slice(*sel, count); - owned_sel.Initialize(new_buffer); - sel = &owned_sel; - vector = &child; - vtype = vector->GetVectorType(); - } - if (TypeIsConstantSize(type.InternalType()) && - (vtype == VectorType::CONSTANT_VECTOR || vtype == VectorType::FLAT_VECTOR)) { - D_ASSERT(!vector->auxiliary); - } - if (type.id() == LogicalTypeId::VARCHAR) { - // verify that the string is correct unicode - switch (vtype) { - case VectorType::FLAT_VECTOR: { - auto &validity = FlatVector::Validity(*vector); - auto strings = FlatVector::GetData(*vector); - for (idx_t i = 0; i < count; i++) { - auto oidx = sel->get_index(i); - if (validity.RowIsValid(oidx)) { - strings[oidx].Verify(); - } - } - break; - } - default: - break; - } - } - - if (type.id() == LogicalTypeId::BIT) { - switch (vtype) { - case VectorType::FLAT_VECTOR: { - auto &validity = FlatVector::Validity(*vector); - auto strings = FlatVector::GetData(*vector); - for (idx_t i = 0; i < count; i++) { - auto oidx = sel->get_index(i); - if (validity.RowIsValid(oidx)) { - auto buf = strings[oidx].GetData(); - D_ASSERT(*buf >= 0 && *buf < 8); - Bit::Verify(strings[oidx]); - } - } - break; - } - default: - break; - } - } - - if (type.InternalType() == PhysicalType::STRUCT) { - auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(!child_types.empty()); - // create a selection vector of the non-null entries of the struct vector - auto &children 
= StructVector::GetEntries(*vector); - D_ASSERT(child_types.size() == children.size()); - for (idx_t child_idx = 0; child_idx < children.size(); child_idx++) { - D_ASSERT(children[child_idx]->GetType() == child_types[child_idx].second); - Vector::Verify(*children[child_idx], sel_p, count); - if (vtype == VectorType::CONSTANT_VECTOR) { - D_ASSERT(children[child_idx]->GetVectorType() == VectorType::CONSTANT_VECTOR); - if (ConstantVector::IsNull(*vector)) { - D_ASSERT(ConstantVector::IsNull(*children[child_idx])); - } - } - if (vtype != VectorType::FLAT_VECTOR) { - continue; - } - optional_ptr child_validity; - SelectionVector owned_child_sel; - const SelectionVector *child_sel = &owned_child_sel; - if (children[child_idx]->GetVectorType() == VectorType::FLAT_VECTOR) { - child_sel = FlatVector::IncrementalSelectionVector(); - child_validity = &FlatVector::Validity(*children[child_idx]); - } else if (children[child_idx]->GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(*children[child_idx]); - if (child.GetVectorType() != VectorType::FLAT_VECTOR) { - continue; - } - child_validity = &FlatVector::Validity(child); - child_sel = &DictionaryVector::SelVector(*children[child_idx]); - } else if (children[child_idx]->GetVectorType() == VectorType::CONSTANT_VECTOR) { - child_sel = ConstantVector::ZeroSelectionVector(count, owned_child_sel); - child_validity = &ConstantVector::Validity(*children[child_idx]); - } else { - continue; - } - // for any NULL entry in the struct, the child should be NULL as well - auto &validity = FlatVector::Validity(*vector); - for (idx_t i = 0; i < count; i++) { - auto index = sel->get_index(i); - if (!validity.RowIsValid(index)) { - auto child_index = child_sel->get_index(sel_p.get_index(i)); - D_ASSERT(!child_validity->RowIsValid(child_index)); - } - } - } - - if (vector->GetType().id() == LogicalTypeId::UNION) { - VerifyUnion(*vector, *sel, count); - } - } - - if (type.InternalType() == 
PhysicalType::LIST) { - if (vtype == VectorType::CONSTANT_VECTOR) { - if (!ConstantVector::IsNull(*vector)) { - auto &child = ListVector::GetEntry(*vector); - SelectionVector child_sel(ListVector::GetListSize(*vector)); - idx_t child_count = 0; - auto le = ConstantVector::GetData(*vector); - D_ASSERT(le->offset + le->length <= ListVector::GetListSize(*vector)); - for (idx_t k = 0; k < le->length; k++) { - child_sel.set_index(child_count++, le->offset + k); - } - Vector::Verify(child, child_sel, child_count); - } - } else if (vtype == VectorType::FLAT_VECTOR) { - auto &validity = FlatVector::Validity(*vector); - auto &child = ListVector::GetEntry(*vector); - auto child_size = ListVector::GetListSize(*vector); - auto list_data = FlatVector::GetData(*vector); - idx_t total_size = 0; - for (idx_t i = 0; i < count; i++) { - auto idx = sel->get_index(i); - auto &le = list_data[idx]; - if (validity.RowIsValid(idx)) { - D_ASSERT(le.offset + le.length <= child_size); - total_size += le.length; - } - } - SelectionVector child_sel(total_size); - idx_t child_count = 0; - for (idx_t i = 0; i < count; i++) { - auto idx = sel->get_index(i); - auto &le = list_data[idx]; - if (validity.RowIsValid(idx)) { - D_ASSERT(le.offset + le.length <= child_size); - for (idx_t k = 0; k < le.length; k++) { - child_sel.set_index(child_count++, le.offset + k); - } - } - } - Vector::Verify(child, child_sel, child_count); - } - - if (vector->GetType().id() == LogicalTypeId::MAP) { - VerifyMap(*vector, *sel, count); - } - } -#endif -} - -void Vector::Verify(idx_t count) { - auto flat_sel = FlatVector::IncrementalSelectionVector(); - Verify(*this, *flat_sel, count); -} - -//===--------------------------------------------------------------------===// -// FlatVector -//===--------------------------------------------------------------------===// -void FlatVector::SetNull(Vector &vector, idx_t idx, bool is_null) { - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); - vector.validity.Set(idx, 
!is_null); - if (is_null && vector.GetType().InternalType() == PhysicalType::STRUCT) { - // set all child entries to null as well - auto &entries = StructVector::GetEntries(vector); - for (auto &entry : entries) { - FlatVector::SetNull(*entry, idx, is_null); - } - } -} - -//===--------------------------------------------------------------------===// -// ConstantVector -//===--------------------------------------------------------------------===// -void ConstantVector::SetNull(Vector &vector, bool is_null) { - D_ASSERT(vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - vector.validity.Set(0, !is_null); - if (is_null && vector.GetType().InternalType() == PhysicalType::STRUCT) { - // set all child entries to null as well - auto &entries = StructVector::GetEntries(vector); - for (auto &entry : entries) { - entry->SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(*entry, is_null); - } - } -} - -const SelectionVector *ConstantVector::ZeroSelectionVector(idx_t count, SelectionVector &owned_sel) { - if (count <= STANDARD_VECTOR_SIZE) { - return ConstantVector::ZeroSelectionVector(); - } - owned_sel.Initialize(count); - for (idx_t i = 0; i < count; i++) { - owned_sel.set_index(i, 0); - } - return &owned_sel; -} - -void ConstantVector::Reference(Vector &vector, Vector &source, idx_t position, idx_t count) { - auto &source_type = source.GetType(); - switch (source_type.InternalType()) { - case PhysicalType::LIST: { - // retrieve the list entry from the source vector - UnifiedVectorFormat vdata; - source.ToUnifiedFormat(count, vdata); - - auto list_index = vdata.sel->get_index(position); - if (!vdata.validity.RowIsValid(list_index)) { - // list is null: create null value - Value null_value(source_type); - vector.Reference(null_value); - break; - } - - auto list_data = UnifiedVectorFormat::GetData(vdata); - auto list_entry = list_data[list_index]; - - // add the list entry as the first element of "vector" - // FIXME: we only need to allocate space for 
1 tuple here - auto target_data = FlatVector::GetData(vector); - target_data[0] = list_entry; - - // create a reference to the child list of the source vector - auto &child = ListVector::GetEntry(vector); - child.Reference(ListVector::GetEntry(source)); - - ListVector::SetListSize(vector, ListVector::GetListSize(source)); - vector.SetVectorType(VectorType::CONSTANT_VECTOR); - break; - } - case PhysicalType::STRUCT: { - UnifiedVectorFormat vdata; - source.ToUnifiedFormat(count, vdata); - - auto struct_index = vdata.sel->get_index(position); - if (!vdata.validity.RowIsValid(struct_index)) { - // null struct: create null value - Value null_value(source_type); - vector.Reference(null_value); - break; - } - - // struct: pass constant reference into child entries - auto &source_entries = StructVector::GetEntries(source); - auto &target_entries = StructVector::GetEntries(vector); - for (idx_t i = 0; i < source_entries.size(); i++) { - ConstantVector::Reference(*target_entries[i], *source_entries[i], position, count); - } - vector.SetVectorType(VectorType::CONSTANT_VECTOR); - vector.validity.Set(0, true); - break; - } - default: - // default behavior: get a value from the vector and reference it - // this is not that expensive for scalar types - auto value = source.GetValue(position); - vector.Reference(value); - D_ASSERT(vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - break; - } -} - -//===--------------------------------------------------------------------===// -// StringVector -//===--------------------------------------------------------------------===// -string_t StringVector::AddString(Vector &vector, const char *data, idx_t len) { - return StringVector::AddString(vector, string_t(data, len)); -} - -string_t StringVector::AddStringOrBlob(Vector &vector, const char *data, idx_t len) { - return StringVector::AddStringOrBlob(vector, string_t(data, len)); -} - -string_t StringVector::AddString(Vector &vector, const char *data) { - return 
StringVector::AddString(vector, string_t(data, strlen(data))); -} - -string_t StringVector::AddString(Vector &vector, const string &data) { - return StringVector::AddString(vector, string_t(data.c_str(), data.size())); -} - -string_t StringVector::AddString(Vector &vector, string_t data) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::VARCHAR || vector.GetType().id() == LogicalTypeId::BIT); - if (data.IsInlined()) { - // string will be inlined: no need to store in string heap - return data; - } - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRING_BUFFER); - auto &string_buffer = vector.auxiliary->Cast(); - return string_buffer.AddString(data); -} - -string_t StringVector::AddStringOrBlob(Vector &vector, string_t data) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - if (data.IsInlined()) { - // string will be inlined: no need to store in string heap - return data; - } - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRING_BUFFER); - auto &string_buffer = vector.auxiliary->Cast(); - return string_buffer.AddBlob(data); -} - -string_t StringVector::EmptyString(Vector &vector, idx_t len) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - if (len <= string_t::INLINE_LENGTH) { - return string_t(len); - } - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRING_BUFFER); - auto &string_buffer = vector.auxiliary->Cast(); - return string_buffer.EmptyString(len); -} - -void StringVector::AddHandle(Vector &vector, BufferHandle handle) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - auto &string_buffer = vector.auxiliary->Cast(); - 
string_buffer.AddHeapReference(make_buffer(std::move(handle))); -} - -void StringVector::AddBuffer(Vector &vector, buffer_ptr buffer) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - D_ASSERT(buffer.get() != vector.auxiliary.get()); - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - auto &string_buffer = vector.auxiliary->Cast(); - string_buffer.AddHeapReference(std::move(buffer)); -} - -void StringVector::AddHeapReference(Vector &vector, Vector &other) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - D_ASSERT(other.GetType().InternalType() == PhysicalType::VARCHAR); - - if (other.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - StringVector::AddHeapReference(vector, DictionaryVector::Child(other)); - return; - } - if (!other.auxiliary) { - return; - } - StringVector::AddBuffer(vector, other.auxiliary); -} - -//===--------------------------------------------------------------------===// -// FSSTVector -//===--------------------------------------------------------------------===// -string_t FSSTVector::AddCompressedString(Vector &vector, const char *data, idx_t len) { - return FSSTVector::AddCompressedString(vector, string_t(data, len)); -} - -string_t FSSTVector::AddCompressedString(Vector &vector, string_t data) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - if (data.IsInlined()) { - // string will be inlined: no need to store in string heap - return data; - } - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); - auto &fsst_string_buffer = vector.auxiliary->Cast(); - return fsst_string_buffer.AddBlob(data); -} - -void *FSSTVector::GetDecoder(const Vector &vector) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - if (!vector.auxiliary) { - throw InternalException("GetDecoder called on FSST Vector without registered buffer"); - } - 
D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); - auto &fsst_string_buffer = vector.auxiliary->Cast(); - return fsst_string_buffer.GetDecoder(); -} - -void FSSTVector::RegisterDecoder(Vector &vector, buffer_ptr &duckdb_fsst_decoder) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); - - auto &fsst_string_buffer = vector.auxiliary->Cast(); - fsst_string_buffer.AddDecoder(duckdb_fsst_decoder); -} - -void FSSTVector::SetCount(Vector &vector, idx_t count) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); - - auto &fsst_string_buffer = vector.auxiliary->Cast(); - fsst_string_buffer.SetCount(count); -} - -idx_t FSSTVector::GetCount(Vector &vector) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); - - auto &fsst_string_buffer = vector.auxiliary->Cast(); - return fsst_string_buffer.GetCount(); -} - -void FSSTVector::DecompressVector(const Vector &src, Vector &dst, idx_t src_offset, idx_t dst_offset, idx_t copy_count, - const SelectionVector *sel) { - D_ASSERT(src.GetVectorType() == VectorType::FSST_VECTOR); - D_ASSERT(dst.GetVectorType() == VectorType::FLAT_VECTOR); - auto dst_mask = FlatVector::Validity(dst); - auto ldata = FSSTVector::GetCompressedData(src); - auto tdata = FlatVector::GetData(dst); - for (idx_t i = 0; i < copy_count; i++) { - auto source_idx = sel->get_index(src_offset + i); - auto target_idx = dst_offset + i; - string_t compressed_string = ldata[source_idx]; - if (dst_mask.RowIsValid(target_idx) && 
compressed_string.GetSize() > 0) { - tdata[target_idx] = FSSTPrimitives::DecompressValue( - FSSTVector::GetDecoder(src), dst, compressed_string.GetData(), compressed_string.GetSize()); - } else { - tdata[target_idx] = string_t(nullptr, 0); - } - } -} - -//===--------------------------------------------------------------------===// -// MapVector -//===--------------------------------------------------------------------===// -Vector &MapVector::GetKeys(Vector &vector) { - auto &entries = StructVector::GetEntries(ListVector::GetEntry(vector)); - D_ASSERT(entries.size() == 2); - return *entries[0]; -} -Vector &MapVector::GetValues(Vector &vector) { - auto &entries = StructVector::GetEntries(ListVector::GetEntry(vector)); - D_ASSERT(entries.size() == 2); - return *entries[1]; -} - -const Vector &MapVector::GetKeys(const Vector &vector) { - return GetKeys((Vector &)vector); -} -const Vector &MapVector::GetValues(const Vector &vector) { - return GetValues((Vector &)vector); -} - -MapInvalidReason MapVector::CheckMapValidity(Vector &map, idx_t count, const SelectionVector &sel) { - D_ASSERT(map.GetType().id() == LogicalTypeId::MAP); - UnifiedVectorFormat map_vdata; - - map.ToUnifiedFormat(count, map_vdata); - auto &map_validity = map_vdata.validity; - - auto list_data = ListVector::GetData(map); - auto &keys = MapVector::GetKeys(map); - UnifiedVectorFormat key_vdata; - keys.ToUnifiedFormat(count, key_vdata); - auto &key_validity = key_vdata.validity; - - for (idx_t row = 0; row < count; row++) { - auto mapped_row = sel.get_index(row); - auto map_idx = map_vdata.sel->get_index(mapped_row); - // map is allowed to be NULL - if (!map_validity.RowIsValid(map_idx)) { - continue; - } - value_set_t unique_keys; - for (idx_t i = 0; i < list_data[map_idx].length; i++) { - auto index = list_data[map_idx].offset + i; - index = key_vdata.sel->get_index(index); - if (!key_validity.RowIsValid(index)) { - return MapInvalidReason::NULL_KEY; - } - auto value = keys.GetValue(index); - auto 
result = unique_keys.insert(value); - if (!result.second) { - return MapInvalidReason::DUPLICATE_KEY; - } - } - } - return MapInvalidReason::VALID; -} - -void MapVector::MapConversionVerify(Vector &vector, idx_t count) { - auto valid_check = MapVector::CheckMapValidity(vector, count); - switch (valid_check) { - case MapInvalidReason::VALID: - break; - case MapInvalidReason::DUPLICATE_KEY: { - throw InvalidInputException("Map keys have to be unique"); - } - case MapInvalidReason::NULL_KEY: { - throw InvalidInputException("Map keys can not be NULL"); - } - case MapInvalidReason::NULL_KEY_LIST: { - throw InvalidInputException("The list of map keys is not allowed to be NULL"); - } - default: { - throw InternalException("MapInvalidReason not implemented"); - } - } -} - -//===--------------------------------------------------------------------===// -// StructVector -//===--------------------------------------------------------------------===// -vector> &StructVector::GetEntries(Vector &vector) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::STRUCT || vector.GetType().id() == LogicalTypeId::UNION); - - if (vector.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(vector); - return StructVector::GetEntries(child); - } - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || - vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - D_ASSERT(vector.auxiliary); - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRUCT_BUFFER); - return vector.auxiliary->Cast().GetChildren(); -} - -const vector> &StructVector::GetEntries(const Vector &vector) { - return GetEntries((Vector &)vector); -} - -//===--------------------------------------------------------------------===// -// ListVector -//===--------------------------------------------------------------------===// -const Vector &ListVector::GetEntry(const Vector &vector) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::LIST || vector.GetType().id() == 
LogicalTypeId::MAP); - if (vector.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(vector); - return ListVector::GetEntry(child); - } - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || - vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - D_ASSERT(vector.auxiliary); - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::LIST_BUFFER); - return vector.auxiliary->Cast().GetChild(); -} - -Vector &ListVector::GetEntry(Vector &vector) { - const Vector &cvector = vector; - return const_cast(ListVector::GetEntry(cvector)); -} - -void ListVector::Reserve(Vector &vector, idx_t required_capacity) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::LIST || vector.GetType().id() == LogicalTypeId::MAP); - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || - vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - D_ASSERT(vector.auxiliary); - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::LIST_BUFFER); - auto &child_buffer = vector.auxiliary->Cast(); - child_buffer.Reserve(required_capacity); -} - -idx_t ListVector::GetListSize(const Vector &vec) { - if (vec.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(vec); - return ListVector::GetListSize(child); - } - D_ASSERT(vec.auxiliary); - return vec.auxiliary->Cast().GetSize(); -} - -idx_t ListVector::GetListCapacity(const Vector &vec) { - if (vec.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(vec); - return ListVector::GetListSize(child); - } - D_ASSERT(vec.auxiliary); - return vec.auxiliary->Cast().GetCapacity(); -} - -void ListVector::ReferenceEntry(Vector &vector, Vector &other) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::LIST); - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || - vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - D_ASSERT(other.GetType().id() == LogicalTypeId::LIST); - 
D_ASSERT(other.GetVectorType() == VectorType::FLAT_VECTOR || other.GetVectorType() == VectorType::CONSTANT_VECTOR); - vector.auxiliary = other.auxiliary; -} - -void ListVector::SetListSize(Vector &vec, idx_t size) { - if (vec.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(vec); - ListVector::SetListSize(child, size); - } - vec.auxiliary->Cast().SetSize(size); -} - -void ListVector::Append(Vector &target, const Vector &source, idx_t source_size, idx_t source_offset) { - if (source_size - source_offset == 0) { - //! Nothing to add - return; - } - auto &target_buffer = target.auxiliary->Cast(); - target_buffer.Append(source, source_size, source_offset); -} - -void ListVector::Append(Vector &target, const Vector &source, const SelectionVector &sel, idx_t source_size, - idx_t source_offset) { - if (source_size - source_offset == 0) { - //! Nothing to add - return; - } - auto &target_buffer = target.auxiliary->Cast(); - target_buffer.Append(source, sel, source_size, source_offset); -} - -void ListVector::PushBack(Vector &target, const Value &insert) { - auto &target_buffer = target.auxiliary->Cast(); - target_buffer.PushBack(insert); -} - -idx_t ListVector::GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset, idx_t count) { - - auto info = ListVector::GetConsecutiveChildListInfo(list, offset, count); - if (info.needs_slicing) { - SelectionVector sel(info.child_list_info.length); - ListVector::GetConsecutiveChildSelVector(list, sel, offset, count); - - result.Slice(sel, info.child_list_info.length); - result.Flatten(info.child_list_info.length); - } - return info.child_list_info.length; -} - -ConsecutiveChildListInfo ListVector::GetConsecutiveChildListInfo(Vector &list, idx_t offset, idx_t count) { - - ConsecutiveChildListInfo info; - UnifiedVectorFormat unified_list_data; - list.ToUnifiedFormat(offset + count, unified_list_data); - auto list_data = UnifiedVectorFormat::GetData(unified_list_data); - - // find the 
first non-NULL entry - idx_t first_length = 0; - for (idx_t i = offset; i < offset + count; i++) { - auto idx = unified_list_data.sel->get_index(i); - if (!unified_list_data.validity.RowIsValid(idx)) { - continue; - } - info.child_list_info.offset = list_data[idx].offset; - first_length = list_data[idx].length; - break; - } - - // small performance improvement for constant vectors - // avoids iterating over all their (constant) elements - if (list.GetVectorType() == VectorType::CONSTANT_VECTOR) { - info.child_list_info.length = first_length; - return info; - } - - // now get the child count and determine whether the children are stored consecutively - // also determine if a flat vector has pseudo constant values (all offsets + length the same) - // this can happen e.g. for UNNESTs - bool is_consecutive = true; - for (idx_t i = offset; i < offset + count; i++) { - auto idx = unified_list_data.sel->get_index(i); - if (!unified_list_data.validity.RowIsValid(idx)) { - continue; - } - if (list_data[idx].offset != info.child_list_info.offset || list_data[idx].length != first_length) { - info.is_constant = false; - } - if (list_data[idx].offset != info.child_list_info.offset + info.child_list_info.length) { - is_consecutive = false; - } - info.child_list_info.length += list_data[idx].length; - } - - if (info.is_constant) { - info.child_list_info.length = first_length; - } - if (!info.is_constant && !is_consecutive) { - info.needs_slicing = true; - } - - return info; -} - -void ListVector::GetConsecutiveChildSelVector(Vector &list, SelectionVector &sel, idx_t offset, idx_t count) { - UnifiedVectorFormat unified_list_data; - list.ToUnifiedFormat(offset + count, unified_list_data); - auto list_data = UnifiedVectorFormat::GetData(unified_list_data); - - // SelectionVector child_sel(info.second.length); - idx_t entry = 0; - for (idx_t i = offset; i < offset + count; i++) { - auto idx = unified_list_data.sel->get_index(i); - if (!unified_list_data.validity.RowIsValid(idx)) { - 
continue; - } - for (idx_t k = 0; k < list_data[idx].length; k++) { - // child_sel.set_index(entry++, list_data[idx].offset + k); - sel.set_index(entry++, list_data[idx].offset + k); - } - } - // - // result.Slice(child_sel, info.second.length); - // result.Flatten(info.second.length); - // info.second.offset = 0; -} - -//===--------------------------------------------------------------------===// -// UnionVector -//===--------------------------------------------------------------------===// -const Vector &UnionVector::GetMember(const Vector &vector, idx_t member_index) { - D_ASSERT(member_index < UnionType::GetMemberCount(vector.GetType())); - auto &entries = StructVector::GetEntries(vector); - return *entries[member_index + 1]; // skip the "tag" entry -} - -Vector &UnionVector::GetMember(Vector &vector, idx_t member_index) { - D_ASSERT(member_index < UnionType::GetMemberCount(vector.GetType())); - auto &entries = StructVector::GetEntries(vector); - return *entries[member_index + 1]; // skip the "tag" entry -} - -const Vector &UnionVector::GetTags(const Vector &vector) { - // the tag vector is always the first struct child. - return *StructVector::GetEntries(vector)[0]; -} - -Vector &UnionVector::GetTags(Vector &vector) { - // the tag vector is always the first struct child. 
- return *StructVector::GetEntries(vector)[0]; -} - -void UnionVector::SetToMember(Vector &union_vector, union_tag_t tag, Vector &member_vector, idx_t count, - bool keep_tags_for_null) { - D_ASSERT(union_vector.GetType().id() == LogicalTypeId::UNION); - D_ASSERT(tag < UnionType::GetMemberCount(union_vector.GetType())); - - // Set the union member to the specified vector - UnionVector::GetMember(union_vector, tag).Reference(member_vector); - auto &tag_vector = UnionVector::GetTags(union_vector); - - if (member_vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // if the member vector is constant, we can set the union to constant as well - union_vector.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::GetData(tag_vector)[0] = tag; - ConstantVector::SetNull(union_vector, ConstantVector::IsNull(member_vector)); - - } else { - // otherwise flatten and set to flatvector - member_vector.Flatten(count); - union_vector.SetVectorType(VectorType::FLAT_VECTOR); - - if (member_vector.validity.AllValid()) { - // if the member vector is all valid, we can set the tag to constant - tag_vector.SetVectorType(VectorType::CONSTANT_VECTOR); - auto tag_data = ConstantVector::GetData(tag_vector); - *tag_data = tag; - } else { - tag_vector.SetVectorType(VectorType::FLAT_VECTOR); - if (keep_tags_for_null) { - FlatVector::Validity(tag_vector).SetAllValid(count); - FlatVector::Validity(union_vector).SetAllValid(count); - } else { - // ensure the tags have the same validity as the member - FlatVector::Validity(union_vector) = FlatVector::Validity(member_vector); - FlatVector::Validity(tag_vector) = FlatVector::Validity(member_vector); - } - - auto tag_data = FlatVector::GetData(tag_vector); - memset(tag_data, tag, count); - } - } - - // Set the non-selected members to constant null vectors - for (idx_t i = 0; i < UnionType::GetMemberCount(union_vector.GetType()); i++) { - if (i != tag) { - auto &member = UnionVector::GetMember(union_vector, i); - 
member.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(member, true); - } - } -} - -union_tag_t UnionVector::GetTag(const Vector &vector, idx_t index) { - // the tag vector is always the first struct child. - auto &tag_vector = *StructVector::GetEntries(vector)[0]; - if (tag_vector.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(tag_vector); - return FlatVector::GetData(child)[index]; - } - if (tag_vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { - return ConstantVector::GetData(tag_vector)[0]; - } - return FlatVector::GetData(tag_vector)[index]; -} - -UnionInvalidReason UnionVector::CheckUnionValidity(Vector &vector, idx_t count, const SelectionVector &sel) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::UNION); - auto member_count = UnionType::GetMemberCount(vector.GetType()); - if (member_count == 0) { - return UnionInvalidReason::NO_MEMBERS; - } - - UnifiedVectorFormat union_vdata; - vector.ToUnifiedFormat(count, union_vdata); - - UnifiedVectorFormat tags_vdata; - auto &tag_vector = UnionVector::GetTags(vector); - tag_vector.ToUnifiedFormat(count, tags_vdata); - - // check that only one member is valid at a time - for (idx_t row_idx = 0; row_idx < count; row_idx++) { - auto union_mapped_row_idx = sel.get_index(row_idx); - if (!union_vdata.validity.RowIsValid(union_mapped_row_idx)) { - continue; - } - - auto tag_mapped_row_idx = tags_vdata.sel->get_index(row_idx); - if (!tags_vdata.validity.RowIsValid(tag_mapped_row_idx)) { - continue; - } - - auto tag = (UnifiedVectorFormat::GetData(tags_vdata))[tag_mapped_row_idx]; - if (tag >= member_count) { - return UnionInvalidReason::TAG_OUT_OF_RANGE; - } - - bool found_valid = false; - for (idx_t member_idx = 0; member_idx < member_count; member_idx++) { - - UnifiedVectorFormat member_vdata; - auto &member = UnionVector::GetMember(vector, member_idx); - member.ToUnifiedFormat(count, member_vdata); - - auto mapped_row_idx = 
member_vdata.sel->get_index(row_idx); - if (member_vdata.validity.RowIsValid(mapped_row_idx)) { - if (found_valid) { - return UnionInvalidReason::VALIDITY_OVERLAP; - } - found_valid = true; - if (tag != static_cast(member_idx)) { - return UnionInvalidReason::TAG_MISMATCH; - } - } - } - } - - return UnionInvalidReason::VALID; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -buffer_ptr VectorBuffer::CreateStandardVector(PhysicalType type, idx_t capacity) { - return make_buffer(capacity * GetTypeIdSize(type)); -} - -buffer_ptr VectorBuffer::CreateConstantVector(PhysicalType type) { - return make_buffer(GetTypeIdSize(type)); -} - -buffer_ptr VectorBuffer::CreateConstantVector(const LogicalType &type) { - return VectorBuffer::CreateConstantVector(type.InternalType()); -} - -buffer_ptr VectorBuffer::CreateStandardVector(const LogicalType &type, idx_t capacity) { - return VectorBuffer::CreateStandardVector(type.InternalType(), capacity); -} - -VectorStringBuffer::VectorStringBuffer() : VectorBuffer(VectorBufferType::STRING_BUFFER) { -} - -VectorStringBuffer::VectorStringBuffer(VectorBufferType type) : VectorBuffer(type) { -} - -VectorFSSTStringBuffer::VectorFSSTStringBuffer() : VectorStringBuffer(VectorBufferType::FSST_BUFFER) { -} - -VectorStructBuffer::VectorStructBuffer() : VectorBuffer(VectorBufferType::STRUCT_BUFFER) { -} - -VectorStructBuffer::VectorStructBuffer(const LogicalType &type, idx_t capacity) - : VectorBuffer(VectorBufferType::STRUCT_BUFFER) { - auto &child_types = StructType::GetChildTypes(type); - for (auto &child_type : child_types) { - auto vector = make_uniq(child_type.second, capacity); - children.push_back(std::move(vector)); - } -} - -VectorStructBuffer::VectorStructBuffer(Vector &other, const SelectionVector &sel, idx_t count) - : VectorBuffer(VectorBufferType::STRUCT_BUFFER) { - auto &other_vector = StructVector::GetEntries(other); - for (auto &child_vector : other_vector) { - auto vector = make_uniq(*child_vector, sel, count); 
- children.push_back(std::move(vector)); - } -} - -VectorStructBuffer::~VectorStructBuffer() { -} - -VectorListBuffer::VectorListBuffer(unique_ptr vector, idx_t initial_capacity) - : VectorBuffer(VectorBufferType::LIST_BUFFER), child(std::move(vector)), capacity(initial_capacity) { -} - -VectorListBuffer::VectorListBuffer(const LogicalType &list_type, idx_t initial_capacity) - : VectorBuffer(VectorBufferType::LIST_BUFFER), - child(make_uniq(ListType::GetChildType(list_type), initial_capacity)), capacity(initial_capacity) { -} - -void VectorListBuffer::Reserve(idx_t to_reserve) { - if (to_reserve > capacity) { - idx_t new_capacity = NextPowerOfTwo(to_reserve); - D_ASSERT(new_capacity >= to_reserve); - child->Resize(capacity, new_capacity); - capacity = new_capacity; - } -} - -void VectorListBuffer::Append(const Vector &to_append, idx_t to_append_size, idx_t source_offset) { - Reserve(size + to_append_size - source_offset); - VectorOperations::Copy(to_append, *child, to_append_size, source_offset, size); - size += to_append_size - source_offset; -} - -void VectorListBuffer::Append(const Vector &to_append, const SelectionVector &sel, idx_t to_append_size, - idx_t source_offset) { - Reserve(size + to_append_size - source_offset); - VectorOperations::Copy(to_append, *child, sel, to_append_size, source_offset, size); - size += to_append_size - source_offset; -} - -void VectorListBuffer::PushBack(const Value &insert) { - while (size + 1 > capacity) { - child->Resize(capacity, capacity * 2); - capacity *= 2; - } - child->SetValue(size++, insert); -} - -void VectorListBuffer::SetCapacity(idx_t new_capacity) { - this->capacity = new_capacity; -} - -void VectorListBuffer::SetSize(idx_t new_size) { - this->size = new_size; -} - -VectorListBuffer::~VectorListBuffer() { -} - -ManagedVectorBuffer::ManagedVectorBuffer(BufferHandle handle) - : VectorBuffer(VectorBufferType::MANAGED_BUFFER), handle(std::move(handle)) { -} - -ManagedVectorBuffer::~ManagedVectorBuffer() { -} - -} // 
namespace duckdb - - - - - -namespace duckdb { - -class VectorCacheBuffer : public VectorBuffer { -public: - explicit VectorCacheBuffer(Allocator &allocator, const LogicalType &type_p, idx_t capacity_p = STANDARD_VECTOR_SIZE) - : VectorBuffer(VectorBufferType::OPAQUE_BUFFER), type(type_p), capacity(capacity_p) { - auto internal_type = type.InternalType(); - switch (internal_type) { - case PhysicalType::LIST: { - // memory for the list offsets - owned_data = allocator.Allocate(capacity * GetTypeIdSize(internal_type)); - // child data of the list - auto &child_type = ListType::GetChildType(type); - child_caches.push_back(make_buffer(allocator, child_type, capacity)); - auto child_vector = make_uniq(child_type, false, false); - auxiliary = make_shared(std::move(child_vector)); - break; - } - case PhysicalType::STRUCT: { - auto &child_types = StructType::GetChildTypes(type); - for (auto &child_type : child_types) { - child_caches.push_back(make_buffer(allocator, child_type.second, capacity)); - } - auto struct_buffer = make_shared(type); - auxiliary = std::move(struct_buffer); - break; - } - default: - owned_data = allocator.Allocate(capacity * GetTypeIdSize(internal_type)); - break; - } - } - - void ResetFromCache(Vector &result, const buffer_ptr &buffer) { - D_ASSERT(type == result.GetType()); - auto internal_type = type.InternalType(); - result.vector_type = VectorType::FLAT_VECTOR; - AssignSharedPointer(result.buffer, buffer); - result.validity.Reset(); - switch (internal_type) { - case PhysicalType::LIST: { - result.data = owned_data.get(); - // reinitialize the VectorListBuffer - AssignSharedPointer(result.auxiliary, auxiliary); - // propagate through child - auto &child_cache = child_caches[0]->Cast(); - auto &list_buffer = result.auxiliary->Cast(); - list_buffer.SetCapacity(child_cache.capacity); - list_buffer.SetSize(0); - list_buffer.SetAuxiliaryData(nullptr); - - auto &list_child = list_buffer.GetChild(); - child_cache.ResetFromCache(list_child, 
child_caches[0]); - break; - } - case PhysicalType::STRUCT: { - // struct does not have data - result.data = nullptr; - // reinitialize the VectorStructBuffer - auxiliary->SetAuxiliaryData(nullptr); - AssignSharedPointer(result.auxiliary, auxiliary); - // propagate through children - auto &children = result.auxiliary->Cast().GetChildren(); - for (idx_t i = 0; i < children.size(); i++) { - auto &child_cache = child_caches[i]->Cast(); - child_cache.ResetFromCache(*children[i], child_caches[i]); - } - break; - } - default: - // regular type: no aux data and reset data to cached data - result.data = owned_data.get(); - result.auxiliary.reset(); - break; - } - } - - const LogicalType &GetType() { - return type; - } - -private: - //! The type of the vector cache - LogicalType type; - //! Owned data - AllocatedData owned_data; - //! Child caches (if any). Used for nested types. - vector> child_caches; - //! Aux data for the vector (if any) - buffer_ptr auxiliary; - //! Capacity of the vector - idx_t capacity; -}; - -VectorCache::VectorCache(Allocator &allocator, const LogicalType &type_p, idx_t capacity_p) { - buffer = make_buffer(allocator, type_p, capacity_p); -} - -void VectorCache::ResetFromCache(Vector &result) const { - D_ASSERT(buffer); - auto &vcache = buffer->Cast(); - vcache.ResetFromCache(result, buffer); -} - -const LogicalType &VectorCache::GetType() const { - auto &vcache = buffer->Cast(); - return vcache.GetType(); -} - -} // namespace duckdb - - -namespace duckdb { - -const SelectionVector *ConstantVector::ZeroSelectionVector() { - static const SelectionVector ZERO_SELECTION_VECTOR = - SelectionVector(const_cast(ConstantVector::ZERO_VECTOR)); // NOLINT - return &ZERO_SELECTION_VECTOR; -} - -const SelectionVector *FlatVector::IncrementalSelectionVector() { - static const SelectionVector INCREMENTAL_SELECTION_VECTOR; - return &INCREMENTAL_SELECTION_VECTOR; -} - -const sel_t ConstantVector::ZERO_VECTOR[STANDARD_VECTOR_SIZE] = {0}; - -} // namespace duckdb - - 
- - - - - - - - - - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -LogicalType::LogicalType() : LogicalType(LogicalTypeId::INVALID) { -} - -LogicalType::LogicalType(LogicalTypeId id) : id_(id) { - physical_type_ = GetInternalType(); -} -LogicalType::LogicalType(LogicalTypeId id, shared_ptr type_info_p) - : id_(id), type_info_(std::move(type_info_p)) { - physical_type_ = GetInternalType(); -} - -LogicalType::LogicalType(const LogicalType &other) - : id_(other.id_), physical_type_(other.physical_type_), type_info_(other.type_info_) { -} - -LogicalType::LogicalType(LogicalType &&other) noexcept - : id_(other.id_), physical_type_(other.physical_type_), type_info_(std::move(other.type_info_)) { -} - -hash_t LogicalType::Hash() const { - return duckdb::Hash((uint8_t)id_); -} - -PhysicalType LogicalType::GetInternalType() { - switch (id_) { - case LogicalTypeId::BOOLEAN: - return PhysicalType::BOOL; - case LogicalTypeId::TINYINT: - return PhysicalType::INT8; - case LogicalTypeId::UTINYINT: - return PhysicalType::UINT8; - case LogicalTypeId::SMALLINT: - return PhysicalType::INT16; - case LogicalTypeId::USMALLINT: - return PhysicalType::UINT16; - case LogicalTypeId::SQLNULL: - case LogicalTypeId::DATE: - case LogicalTypeId::INTEGER: - return PhysicalType::INT32; - case LogicalTypeId::UINTEGER: - return PhysicalType::UINT32; - case LogicalTypeId::BIGINT: - case LogicalTypeId::TIME: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::TIMESTAMP_TZ: - return PhysicalType::INT64; - case LogicalTypeId::UBIGINT: - return PhysicalType::UINT64; - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UUID: - return PhysicalType::INT128; - case LogicalTypeId::FLOAT: - return PhysicalType::FLOAT; - case LogicalTypeId::DOUBLE: - return PhysicalType::DOUBLE; - case LogicalTypeId::DECIMAL: { - if (!type_info_) { - return 
PhysicalType::INVALID; - } - auto width = DecimalType::GetWidth(*this); - if (width <= Decimal::MAX_WIDTH_INT16) { - return PhysicalType::INT16; - } else if (width <= Decimal::MAX_WIDTH_INT32) { - return PhysicalType::INT32; - } else if (width <= Decimal::MAX_WIDTH_INT64) { - return PhysicalType::INT64; - } else if (width <= Decimal::MAX_WIDTH_INT128) { - return PhysicalType::INT128; - } else { - throw InternalException("Decimal has a width of %d which is bigger than the maximum supported width of %d", - width, DecimalType::MaxWidth()); - } - } - case LogicalTypeId::VARCHAR: - case LogicalTypeId::CHAR: - case LogicalTypeId::BLOB: - case LogicalTypeId::BIT: - return PhysicalType::VARCHAR; - case LogicalTypeId::INTERVAL: - return PhysicalType::INTERVAL; - case LogicalTypeId::UNION: - case LogicalTypeId::STRUCT: - return PhysicalType::STRUCT; - case LogicalTypeId::LIST: - case LogicalTypeId::MAP: - return PhysicalType::LIST; - case LogicalTypeId::POINTER: - // LCOV_EXCL_START - if (sizeof(uintptr_t) == sizeof(uint32_t)) { - return PhysicalType::UINT32; - } else if (sizeof(uintptr_t) == sizeof(uint64_t)) { - return PhysicalType::UINT64; - } else { - throw InternalException("Unsupported pointer size"); - } - // LCOV_EXCL_STOP - case LogicalTypeId::VALIDITY: - return PhysicalType::BIT; - case LogicalTypeId::ENUM: { - if (!type_info_) { - return PhysicalType::INVALID; - } - return EnumType::GetPhysicalType(*this); - } - case LogicalTypeId::TABLE: - case LogicalTypeId::LAMBDA: - case LogicalTypeId::ANY: - case LogicalTypeId::INVALID: - case LogicalTypeId::UNKNOWN: - return PhysicalType::INVALID; - case LogicalTypeId::USER: - return PhysicalType::UNKNOWN; - case LogicalTypeId::AGGREGATE_STATE: - return PhysicalType::VARCHAR; - default: - throw InternalException("Invalid LogicalType %s", ToString()); - } -} - -// **DEPRECATED**: Use EnumUtil directly instead. 
-string LogicalTypeIdToString(LogicalTypeId type) { - return EnumUtil::ToString(type); -} - -constexpr const LogicalTypeId LogicalType::INVALID; -constexpr const LogicalTypeId LogicalType::SQLNULL; -constexpr const LogicalTypeId LogicalType::BOOLEAN; -constexpr const LogicalTypeId LogicalType::TINYINT; -constexpr const LogicalTypeId LogicalType::UTINYINT; -constexpr const LogicalTypeId LogicalType::SMALLINT; -constexpr const LogicalTypeId LogicalType::USMALLINT; -constexpr const LogicalTypeId LogicalType::INTEGER; -constexpr const LogicalTypeId LogicalType::UINTEGER; -constexpr const LogicalTypeId LogicalType::BIGINT; -constexpr const LogicalTypeId LogicalType::UBIGINT; -constexpr const LogicalTypeId LogicalType::HUGEINT; -constexpr const LogicalTypeId LogicalType::UUID; -constexpr const LogicalTypeId LogicalType::FLOAT; -constexpr const LogicalTypeId LogicalType::DOUBLE; -constexpr const LogicalTypeId LogicalType::DATE; - -constexpr const LogicalTypeId LogicalType::TIMESTAMP; -constexpr const LogicalTypeId LogicalType::TIMESTAMP_MS; -constexpr const LogicalTypeId LogicalType::TIMESTAMP_NS; -constexpr const LogicalTypeId LogicalType::TIMESTAMP_S; - -constexpr const LogicalTypeId LogicalType::TIME; - -constexpr const LogicalTypeId LogicalType::TIME_TZ; -constexpr const LogicalTypeId LogicalType::TIMESTAMP_TZ; - -constexpr const LogicalTypeId LogicalType::HASH; -constexpr const LogicalTypeId LogicalType::POINTER; - -constexpr const LogicalTypeId LogicalType::VARCHAR; - -constexpr const LogicalTypeId LogicalType::BLOB; -constexpr const LogicalTypeId LogicalType::BIT; -constexpr const LogicalTypeId LogicalType::INTERVAL; -constexpr const LogicalTypeId LogicalType::ROW_TYPE; - -// TODO these are incomplete and should maybe not exist as such -constexpr const LogicalTypeId LogicalType::TABLE; -constexpr const LogicalTypeId LogicalType::LAMBDA; - -constexpr const LogicalTypeId LogicalType::ANY; - -const vector LogicalType::Numeric() { - vector types = 
{LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, - LogicalType::BIGINT, LogicalType::HUGEINT, LogicalType::FLOAT, - LogicalType::DOUBLE, LogicalTypeId::DECIMAL, LogicalType::UTINYINT, - LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT}; - return types; -} - -const vector LogicalType::Integral() { - vector types = {LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, - LogicalType::BIGINT, LogicalType::HUGEINT, LogicalType::UTINYINT, - LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT}; - return types; -} - -const vector LogicalType::AllTypes() { - vector types = { - LogicalType::BOOLEAN, LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, - LogicalType::BIGINT, LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::DOUBLE, - LogicalType::FLOAT, LogicalType::VARCHAR, LogicalType::BLOB, LogicalType::BIT, - LogicalType::INTERVAL, LogicalType::HUGEINT, LogicalTypeId::DECIMAL, LogicalType::UTINYINT, - LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT, LogicalType::TIME, - LogicalTypeId::LIST, LogicalTypeId::STRUCT, LogicalType::TIME_TZ, LogicalType::TIMESTAMP_TZ, - LogicalTypeId::MAP, LogicalTypeId::UNION, LogicalType::UUID}; - return types; -} - -const PhysicalType ROW_TYPE = PhysicalType::INT64; - -// LCOV_EXCL_START -string TypeIdToString(PhysicalType type) { - switch (type) { - case PhysicalType::BOOL: - return "BOOL"; - case PhysicalType::INT8: - return "INT8"; - case PhysicalType::INT16: - return "INT16"; - case PhysicalType::INT32: - return "INT32"; - case PhysicalType::INT64: - return "INT64"; - case PhysicalType::UINT8: - return "UINT8"; - case PhysicalType::UINT16: - return "UINT16"; - case PhysicalType::UINT32: - return "UINT32"; - case PhysicalType::UINT64: - return "UINT64"; - case PhysicalType::INT128: - return "INT128"; - case PhysicalType::FLOAT: - return "FLOAT"; - case PhysicalType::DOUBLE: - return "DOUBLE"; - case PhysicalType::VARCHAR: - return 
"VARCHAR"; - case PhysicalType::INTERVAL: - return "INTERVAL"; - case PhysicalType::STRUCT: - return "STRUCT"; - case PhysicalType::LIST: - return "LIST"; - case PhysicalType::INVALID: - return "INVALID"; - case PhysicalType::BIT: - return "BIT"; - case PhysicalType::UNKNOWN: - return "UNKNOWN"; - } - return "INVALID"; -} -// LCOV_EXCL_STOP - -idx_t GetTypeIdSize(PhysicalType type) { - switch (type) { - case PhysicalType::BIT: - case PhysicalType::BOOL: - return sizeof(bool); - case PhysicalType::INT8: - return sizeof(int8_t); - case PhysicalType::INT16: - return sizeof(int16_t); - case PhysicalType::INT32: - return sizeof(int32_t); - case PhysicalType::INT64: - return sizeof(int64_t); - case PhysicalType::UINT8: - return sizeof(uint8_t); - case PhysicalType::UINT16: - return sizeof(uint16_t); - case PhysicalType::UINT32: - return sizeof(uint32_t); - case PhysicalType::UINT64: - return sizeof(uint64_t); - case PhysicalType::INT128: - return sizeof(hugeint_t); - case PhysicalType::FLOAT: - return sizeof(float); - case PhysicalType::DOUBLE: - return sizeof(double); - case PhysicalType::VARCHAR: - return sizeof(string_t); - case PhysicalType::INTERVAL: - return sizeof(interval_t); - case PhysicalType::STRUCT: - case PhysicalType::UNKNOWN: - return 0; // no own payload - case PhysicalType::LIST: - return sizeof(list_entry_t); // offset + len - default: - throw InternalException("Invalid PhysicalType for GetTypeIdSize"); - } -} - -bool TypeIsConstantSize(PhysicalType type) { - return (type >= PhysicalType::BOOL && type <= PhysicalType::DOUBLE) || type == PhysicalType::INTERVAL || - type == PhysicalType::INT128; -} -bool TypeIsIntegral(PhysicalType type) { - return (type >= PhysicalType::UINT8 && type <= PhysicalType::INT64) || type == PhysicalType::INT128; -} -bool TypeIsNumeric(PhysicalType type) { - return (type >= PhysicalType::UINT8 && type <= PhysicalType::DOUBLE) || type == PhysicalType::INT128; -} -bool TypeIsInteger(PhysicalType type) { - return (type >= 
PhysicalType::UINT8 && type <= PhysicalType::INT64) || type == PhysicalType::INT128; -} - -string LogicalType::ToString() const { - auto alias = GetAlias(); - if (!alias.empty()) { - return alias; - } - switch (id_) { - case LogicalTypeId::STRUCT: { - if (!type_info_) { - return "STRUCT"; - } - auto &child_types = StructType::GetChildTypes(*this); - string ret = "STRUCT("; - for (size_t i = 0; i < child_types.size(); i++) { - ret += StringUtil::Format("%s %s", SQLIdentifier(child_types[i].first), child_types[i].second); - if (i < child_types.size() - 1) { - ret += ", "; - } - } - ret += ")"; - return ret; - } - case LogicalTypeId::LIST: { - if (!type_info_) { - return "LIST"; - } - return ListType::GetChildType(*this).ToString() + "[]"; - } - case LogicalTypeId::MAP: { - if (!type_info_) { - return "MAP"; - } - auto &key_type = MapType::KeyType(*this); - auto &value_type = MapType::ValueType(*this); - return "MAP(" + key_type.ToString() + ", " + value_type.ToString() + ")"; - } - case LogicalTypeId::UNION: { - if (!type_info_) { - return "UNION"; - } - string ret = "UNION("; - size_t count = UnionType::GetMemberCount(*this); - for (size_t i = 0; i < count; i++) { - ret += UnionType::GetMemberName(*this, i) + " " + UnionType::GetMemberType(*this, i).ToString(); - if (i < count - 1) { - ret += ", "; - } - } - ret += ")"; - return ret; - } - case LogicalTypeId::DECIMAL: { - if (!type_info_) { - return "DECIMAL"; - } - auto width = DecimalType::GetWidth(*this); - auto scale = DecimalType::GetScale(*this); - if (width == 0) { - return "DECIMAL"; - } - return StringUtil::Format("DECIMAL(%d,%d)", width, scale); - } - case LogicalTypeId::ENUM: { - string ret = "ENUM("; - for (idx_t i = 0; i < EnumType::GetSize(*this); i++) { - if (i > 0) { - ret += ", "; - } - ret += KeywordHelper::WriteQuoted(EnumType::GetString(*this, i).GetString(), '\''); - } - ret += ")"; - return ret; - } - case LogicalTypeId::USER: { - return 
KeywordHelper::WriteOptionallyQuoted(UserType::GetTypeName(*this)); - } - case LogicalTypeId::AGGREGATE_STATE: { - return AggregateStateType::GetTypeName(*this); - } - default: - return EnumUtil::ToString(id_); - } -} -// LCOV_EXCL_STOP - -LogicalTypeId TransformStringToLogicalTypeId(const string &str) { - auto type = DefaultTypeGenerator::GetDefaultType(str); - if (type == LogicalTypeId::INVALID) { - // This is a User Type, at this point we don't know if its one of the User Defined Types or an error - // It is checked in the binder - type = LogicalTypeId::USER; - } - return type; -} - -LogicalType TransformStringToLogicalType(const string &str) { - if (StringUtil::Lower(str) == "null") { - return LogicalType::SQLNULL; - } - return Parser::ParseColumnList("dummy " + str).GetColumn(LogicalIndex(0)).Type(); -} - -LogicalType GetUserTypeRecursive(const LogicalType &type, ClientContext &context) { - if (type.id() == LogicalTypeId::USER && type.HasAlias()) { - return Catalog::GetType(context, INVALID_CATALOG, INVALID_SCHEMA, type.GetAlias()); - } - // Look for LogicalTypeId::USER in nested types - if (type.id() == LogicalTypeId::STRUCT) { - child_list_t children; - children.reserve(StructType::GetChildCount(type)); - for (auto &child : StructType::GetChildTypes(type)) { - children.emplace_back(child.first, GetUserTypeRecursive(child.second, context)); - } - return LogicalType::STRUCT(children); - } - if (type.id() == LogicalTypeId::LIST) { - return LogicalType::LIST(GetUserTypeRecursive(ListType::GetChildType(type), context)); - } - if (type.id() == LogicalTypeId::MAP) { - return LogicalType::MAP(GetUserTypeRecursive(MapType::KeyType(type), context), - GetUserTypeRecursive(MapType::ValueType(type), context)); - } - // Not LogicalTypeId::USER or a nested type - return type; -} - -LogicalType TransformStringToLogicalType(const string &str, ClientContext &context) { - return GetUserTypeRecursive(TransformStringToLogicalType(str), context); -} - -bool 
LogicalType::IsIntegral() const { - switch (id_) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::HUGEINT: - return true; - default: - return false; - } -} - -bool LogicalType::IsNumeric() const { - switch (id_) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - return true; - default: - return false; - } -} - -bool LogicalType::IsValid() const { - return id() != LogicalTypeId::INVALID && id() != LogicalTypeId::UNKNOWN; -} - -bool LogicalType::GetDecimalProperties(uint8_t &width, uint8_t &scale) const { - switch (id_) { - case LogicalTypeId::SQLNULL: - width = 0; - scale = 0; - break; - case LogicalTypeId::BOOLEAN: - width = 1; - scale = 0; - break; - case LogicalTypeId::TINYINT: - // tinyint: [-127, 127] = DECIMAL(3,0) - width = 3; - scale = 0; - break; - case LogicalTypeId::SMALLINT: - // smallint: [-32767, 32767] = DECIMAL(5,0) - width = 5; - scale = 0; - break; - case LogicalTypeId::INTEGER: - // integer: [-2147483647, 2147483647] = DECIMAL(10,0) - width = 10; - scale = 0; - break; - case LogicalTypeId::BIGINT: - // bigint: [-9223372036854775807, 9223372036854775807] = DECIMAL(19,0) - width = 19; - scale = 0; - break; - case LogicalTypeId::UTINYINT: - // UInt8 — [0 : 255] - width = 3; - scale = 0; - break; - case LogicalTypeId::USMALLINT: - // UInt16 — [0 : 65535] - width = 5; - scale = 0; - break; - case LogicalTypeId::UINTEGER: - // UInt32 — [0 : 4294967295] - width = 10; - scale = 0; - break; - case 
LogicalTypeId::UBIGINT: - // UInt64 — [0 : 18446744073709551615] - width = 20; - scale = 0; - break; - case LogicalTypeId::HUGEINT: - // hugeint: max size decimal (38, 0) - // note that a hugeint is not guaranteed to fit in this - width = 38; - scale = 0; - break; - case LogicalTypeId::DECIMAL: - width = DecimalType::GetWidth(*this); - scale = DecimalType::GetScale(*this); - break; - default: - // Nonsense values to ensure initialization - width = 255u; - scale = 255u; - // FIXME(carlo): This should be probably a throw, requires checkign the various call-sites - return false; - } - return true; -} - -//! Grows Decimal width/scale when appropriate -static LogicalType DecimalSizeCheck(const LogicalType &left, const LogicalType &right) { - D_ASSERT(left.id() == LogicalTypeId::DECIMAL || right.id() == LogicalTypeId::DECIMAL); - D_ASSERT(left.id() != right.id()); - - //! Make sure the 'right' is the DECIMAL type - if (left.id() == LogicalTypeId::DECIMAL) { - return DecimalSizeCheck(right, left); - } - auto width = DecimalType::GetWidth(right); - auto scale = DecimalType::GetScale(right); - - uint8_t other_width; - uint8_t other_scale; - bool success = left.GetDecimalProperties(other_width, other_scale); - if (!success) { - throw InternalException("Type provided to DecimalSizeCheck was not a numeric type"); - } - D_ASSERT(other_scale == 0); - const auto effective_width = width - scale; - if (other_width > effective_width) { - auto new_width = other_width + scale; - //! 
Cap the width at max, if an actual value exceeds this, an exception will be thrown later - if (new_width > DecimalType::MaxWidth()) { - new_width = DecimalType::MaxWidth(); - } - return LogicalType::DECIMAL(new_width, scale); - } - return right; -} - -static LogicalType CombineNumericTypes(const LogicalType &left, const LogicalType &right) { - D_ASSERT(left.id() != right.id()); - if (left.id() > right.id()) { - // this method is symmetric - // arrange it so the left type is smaller to limit the number of options we need to check - return CombineNumericTypes(right, left); - } - if (CastRules::ImplicitCast(left, right) >= 0) { - // we can implicitly cast left to right, return right - //! Depending on the type, we might need to grow the `width` of the DECIMAL type - if (right.id() == LogicalTypeId::DECIMAL) { - return DecimalSizeCheck(left, right); - } - return right; - } - if (CastRules::ImplicitCast(right, left) >= 0) { - // we can implicitly cast right to left, return left - //! Depending on the type, we might need to grow the `width` of the DECIMAL type - if (left.id() == LogicalTypeId::DECIMAL) { - return DecimalSizeCheck(right, left); - } - return left; - } - // we can't cast implicitly either way and types are not equal - // this happens when left is signed and right is unsigned - // e.g. 
INTEGER and UINTEGER - // in this case we need to upcast to make sure the types fit - - if (left.id() == LogicalTypeId::BIGINT || right.id() == LogicalTypeId::UBIGINT) { - return LogicalType::HUGEINT; - } - if (left.id() == LogicalTypeId::INTEGER || right.id() == LogicalTypeId::UINTEGER) { - return LogicalType::BIGINT; - } - if (left.id() == LogicalTypeId::SMALLINT || right.id() == LogicalTypeId::USMALLINT) { - return LogicalType::INTEGER; - } - if (left.id() == LogicalTypeId::TINYINT || right.id() == LogicalTypeId::UTINYINT) { - return LogicalType::SMALLINT; - } - throw InternalException("Cannot combine these numeric types!?"); -} - -LogicalType LogicalType::MaxLogicalType(const LogicalType &left, const LogicalType &right) { - // we always prefer aliased types - if (!left.GetAlias().empty()) { - return left; - } - if (!right.GetAlias().empty()) { - return right; - } - if (left.id() != right.id() && left.IsNumeric() && right.IsNumeric()) { - return CombineNumericTypes(left, right); - } else if (left.id() == LogicalTypeId::UNKNOWN) { - return right; - } else if (right.id() == LogicalTypeId::UNKNOWN) { - return left; - } else if ((right.id() == LogicalTypeId::ENUM || left.id() == LogicalTypeId::ENUM) && right.id() != left.id()) { - // if one is an enum and the other is not, compare strings, not enums - // see https://github.com/duckdb/duckdb/issues/8561 - return LogicalTypeId::VARCHAR; - } else if (left.id() < right.id()) { - return right; - } - if (right.id() < left.id()) { - return left; - } - // Since both left and right are equal we get the left type as our type_id for checks - auto type_id = left.id(); - if (type_id == LogicalTypeId::ENUM) { - // If both types are different ENUMs we do a string comparison. - return left == right ? 
left : LogicalType::VARCHAR; - } - if (type_id == LogicalTypeId::VARCHAR) { - // varchar: use type that has collation (if any) - if (StringType::GetCollation(right).empty()) { - return left; - } - return right; - } - if (type_id == LogicalTypeId::DECIMAL) { - // unify the width/scale so that the resulting decimal always fits - // "width - scale" gives us the number of digits on the left side of the decimal point - // "scale" gives us the number of digits allowed on the right of the decimal point - // using the max of these of the two types gives us the new decimal size - auto extra_width_left = DecimalType::GetWidth(left) - DecimalType::GetScale(left); - auto extra_width_right = DecimalType::GetWidth(right) - DecimalType::GetScale(right); - auto extra_width = MaxValue(extra_width_left, extra_width_right); - auto scale = MaxValue(DecimalType::GetScale(left), DecimalType::GetScale(right)); - auto width = extra_width + scale; - if (width > DecimalType::MaxWidth()) { - // if the resulting decimal does not fit, we truncate the scale - width = DecimalType::MaxWidth(); - scale = width - extra_width; - } - return LogicalType::DECIMAL(width, scale); - } - if (type_id == LogicalTypeId::LIST) { - // list: perform max recursively on child type - auto new_child = MaxLogicalType(ListType::GetChildType(left), ListType::GetChildType(right)); - return LogicalType::LIST(new_child); - } - if (type_id == LogicalTypeId::MAP) { - // list: perform max recursively on child type - auto new_child = MaxLogicalType(ListType::GetChildType(left), ListType::GetChildType(right)); - return LogicalType::MAP(new_child); - } - if (type_id == LogicalTypeId::STRUCT) { - // struct: perform recursively - auto &left_child_types = StructType::GetChildTypes(left); - auto &right_child_types = StructType::GetChildTypes(right); - if (left_child_types.size() != right_child_types.size()) { - // child types are not of equal size, we can't cast anyway - // just return the left child - return left; - } - 
child_list_t child_types; - for (idx_t i = 0; i < left_child_types.size(); i++) { - auto child_type = MaxLogicalType(left_child_types[i].second, right_child_types[i].second); - child_types.emplace_back(left_child_types[i].first, std::move(child_type)); - } - - return LogicalType::STRUCT(child_types); - } - if (type_id == LogicalTypeId::UNION) { - auto left_member_count = UnionType::GetMemberCount(left); - auto right_member_count = UnionType::GetMemberCount(right); - if (left_member_count != right_member_count) { - // return the "larger" type, with the most members - return left_member_count > right_member_count ? left : right; - } - // otherwise, keep left, don't try to meld the two together. - return left; - } - // types are equal but no extra specifier: just return the type - return left; -} - -void LogicalType::Verify() const { -#ifdef DEBUG - if (id_ == LogicalTypeId::DECIMAL) { - D_ASSERT(DecimalType::GetWidth(*this) >= 1 && DecimalType::GetWidth(*this) <= Decimal::MAX_WIDTH_DECIMAL); - D_ASSERT(DecimalType::GetScale(*this) >= 0 && DecimalType::GetScale(*this) <= DecimalType::GetWidth(*this)); - } -#endif -} - -bool ApproxEqual(float ldecimal, float rdecimal) { - if (Value::IsNan(ldecimal) && Value::IsNan(rdecimal)) { - return true; - } - if (!Value::FloatIsFinite(ldecimal) || !Value::FloatIsFinite(rdecimal)) { - return ldecimal == rdecimal; - } - float epsilon = std::fabs(rdecimal) * 0.01 + 0.00000001; - return std::fabs(ldecimal - rdecimal) <= epsilon; -} - -bool ApproxEqual(double ldecimal, double rdecimal) { - if (Value::IsNan(ldecimal) && Value::IsNan(rdecimal)) { - return true; - } - if (!Value::DoubleIsFinite(ldecimal) || !Value::DoubleIsFinite(rdecimal)) { - return ldecimal == rdecimal; - } - double epsilon = std::fabs(rdecimal) * 0.01 + 0.00000001; - return std::fabs(ldecimal - rdecimal) <= epsilon; -} - -//===--------------------------------------------------------------------===// -// Extra Type Info 
-//===--------------------------------------------------------------------===// -void LogicalType::SetAlias(string alias) { - if (!type_info_) { - type_info_ = make_shared(ExtraTypeInfoType::GENERIC_TYPE_INFO, std::move(alias)); - } else { - type_info_->alias = std::move(alias); - } -} - -string LogicalType::GetAlias() const { - if (id() == LogicalTypeId::USER) { - return UserType::GetTypeName(*this); - } - if (type_info_) { - return type_info_->alias; - } - return string(); -} - -bool LogicalType::HasAlias() const { - if (id() == LogicalTypeId::USER) { - return !UserType::GetTypeName(*this).empty(); - } - if (type_info_ && !type_info_->alias.empty()) { - return true; - } - return false; -} - -//===--------------------------------------------------------------------===// -// Decimal Type -//===--------------------------------------------------------------------===// -uint8_t DecimalType::GetWidth(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::DECIMAL); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().width; -} - -uint8_t DecimalType::GetScale(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::DECIMAL); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().scale; -} - -uint8_t DecimalType::MaxWidth() { - return DecimalWidth::max; -} - -LogicalType LogicalType::DECIMAL(int width, int scale) { - D_ASSERT(width >= scale); - auto type_info = make_shared(width, scale); - return LogicalType(LogicalTypeId::DECIMAL, std::move(type_info)); -} - -//===--------------------------------------------------------------------===// -// String Type -//===--------------------------------------------------------------------===// -string StringType::GetCollation(const LogicalType &type) { - if (type.id() != LogicalTypeId::VARCHAR) { - return string(); - } - auto info = type.AuxInfo(); - if (!info) { - return string(); - } - if (info->type == ExtraTypeInfoType::GENERIC_TYPE_INFO) { - return string(); - } - return 
info->Cast().collation; -} - -LogicalType LogicalType::VARCHAR_COLLATION(string collation) { // NOLINT - auto string_info = make_shared(std::move(collation)); - return LogicalType(LogicalTypeId::VARCHAR, std::move(string_info)); -} - -//===--------------------------------------------------------------------===// -// List Type -//===--------------------------------------------------------------------===// -const LogicalType &ListType::GetChildType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::LIST || type.id() == LogicalTypeId::MAP); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().child_type; -} - -LogicalType LogicalType::LIST(const LogicalType &child) { - auto info = make_shared(child); - return LogicalType(LogicalTypeId::LIST, std::move(info)); -} - -//===--------------------------------------------------------------------===// -// Aggregate State Type -//===--------------------------------------------------------------------===// -const aggregate_state_t &AggregateStateType::GetStateType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::AGGREGATE_STATE); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().state_type; -} - -const string AggregateStateType::GetTypeName(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::AGGREGATE_STATE); - auto info = type.AuxInfo(); - if (!info) { - return "AGGREGATE_STATE"; - } - auto aggr_state = info->Cast().state_type; - return "AGGREGATE_STATE<" + aggr_state.function_name + "(" + - StringUtil::Join(aggr_state.bound_argument_types, aggr_state.bound_argument_types.size(), ", ", - [](const LogicalType &arg_type) { return arg_type.ToString(); }) + - ")" + "::" + aggr_state.return_type.ToString() + ">"; -} - -//===--------------------------------------------------------------------===// -// Struct Type -//===--------------------------------------------------------------------===// -const child_list_t &StructType::GetChildTypes(const 
LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION); - - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().child_types; -} - -const LogicalType &StructType::GetChildType(const LogicalType &type, idx_t index) { - auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(index < child_types.size()); - return child_types[index].second; -} - -const string &StructType::GetChildName(const LogicalType &type, idx_t index) { - auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(index < child_types.size()); - return child_types[index].first; -} - -idx_t StructType::GetChildCount(const LogicalType &type) { - return StructType::GetChildTypes(type).size(); -} -bool StructType::IsUnnamed(const LogicalType &type) { - auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(child_types.size() > 0); - return child_types[0].first.empty(); -} - -LogicalType LogicalType::STRUCT(child_list_t children) { - auto info = make_shared(std::move(children)); - return LogicalType(LogicalTypeId::STRUCT, std::move(info)); -} - -LogicalType LogicalType::AGGREGATE_STATE(aggregate_state_t state_type) { // NOLINT - auto info = make_shared(std::move(state_type)); - return LogicalType(LogicalTypeId::AGGREGATE_STATE, std::move(info)); -} - -//===--------------------------------------------------------------------===// -// Map Type -//===--------------------------------------------------------------------===// -LogicalType LogicalType::MAP(const LogicalType &child_p) { - D_ASSERT(child_p.id() == LogicalTypeId::STRUCT); - auto &children = StructType::GetChildTypes(child_p); - D_ASSERT(children.size() == 2); - - // We do this to enforce that for every MAP created, the keys are called "key" - // and the values are called "value" - - // This is done because for Vector the keys of the STRUCT are used in equality checks. 
- // Vector::Reference will throw if the types don't match - child_list_t new_children(2); - new_children[0] = children[0]; - new_children[0].first = "key"; - - new_children[1] = children[1]; - new_children[1].first = "value"; - - auto child = LogicalType::STRUCT(std::move(new_children)); - auto info = make_shared(child); - return LogicalType(LogicalTypeId::MAP, std::move(info)); -} - -LogicalType LogicalType::MAP(LogicalType key, LogicalType value) { - child_list_t child_types; - child_types.emplace_back("key", std::move(key)); - child_types.emplace_back("value", std::move(value)); - return LogicalType::MAP(LogicalType::STRUCT(child_types)); -} - -const LogicalType &MapType::KeyType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::MAP); - return StructType::GetChildTypes(ListType::GetChildType(type))[0].second; -} - -const LogicalType &MapType::ValueType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::MAP); - return StructType::GetChildTypes(ListType::GetChildType(type))[1].second; -} - -//===--------------------------------------------------------------------===// -// Union Type -//===--------------------------------------------------------------------===// -LogicalType LogicalType::UNION(child_list_t members) { - D_ASSERT(!members.empty()); - D_ASSERT(members.size() <= UnionType::MAX_UNION_MEMBERS); - // union types always have a hidden "tag" field in front - members.insert(members.begin(), {"", LogicalType::UTINYINT}); - auto info = make_shared(std::move(members)); - return LogicalType(LogicalTypeId::UNION, std::move(info)); -} - -const LogicalType &UnionType::GetMemberType(const LogicalType &type, idx_t index) { - auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(index < child_types.size()); - // skip the "tag" field - return child_types[index + 1].second; -} - -const string &UnionType::GetMemberName(const LogicalType &type, idx_t index) { - auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(index < 
child_types.size()); - // skip the "tag" field - return child_types[index + 1].first; -} - -idx_t UnionType::GetMemberCount(const LogicalType &type) { - // don't count the "tag" field - return StructType::GetChildTypes(type).size() - 1; -} -const child_list_t UnionType::CopyMemberTypes(const LogicalType &type) { - auto child_types = StructType::GetChildTypes(type); - child_types.erase(child_types.begin()); - return child_types; -} - -//===--------------------------------------------------------------------===// -// User Type -//===--------------------------------------------------------------------===// -const string &UserType::GetTypeName(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::USER); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().user_type_name; -} - -LogicalType LogicalType::USER(const string &user_type_name) { - auto info = make_shared(user_type_name); - return LogicalType(LogicalTypeId::USER, std::move(info)); -} - -//===--------------------------------------------------------------------===// -// Enum Type -//===--------------------------------------------------------------------===// -LogicalType LogicalType::ENUM(Vector &ordered_data, idx_t size) { - return EnumTypeInfo::CreateType(ordered_data, size); -} - -LogicalType LogicalType::ENUM(const string &enum_name, Vector &ordered_data, idx_t size) { - return LogicalType::ENUM(ordered_data, size); -} - -const string EnumType::GetValue(const Value &val) { - auto info = val.type().AuxInfo(); - auto &values_insert_order = info->Cast().GetValuesInsertOrder(); - return StringValue::Get(values_insert_order.GetValue(val.GetValue())); -} - -const Vector &EnumType::GetValuesInsertOrder(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::ENUM); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().GetValuesInsertOrder(); -} - -idx_t EnumType::GetSize(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::ENUM); - auto info = 
type.AuxInfo(); - D_ASSERT(info); - return info->Cast().GetDictSize(); -} - -PhysicalType EnumType::GetPhysicalType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::ENUM); - auto aux_info = type.AuxInfo(); - D_ASSERT(aux_info); - auto &info = aux_info->Cast(); - D_ASSERT(info.GetEnumDictType() == EnumDictType::VECTOR_DICT); - return EnumTypeInfo::DictType(info.GetDictSize()); -} - -//===--------------------------------------------------------------------===// -// Logical Type -//===--------------------------------------------------------------------===// - -// the destructor needs to know about the extra type info -LogicalType::~LogicalType() { -} - -bool LogicalType::EqualTypeInfo(const LogicalType &rhs) const { - if (type_info_.get() == rhs.type_info_.get()) { - return true; - } - if (type_info_) { - return type_info_->Equals(rhs.type_info_.get()); - } else { - D_ASSERT(rhs.type_info_); - return rhs.type_info_->Equals(type_info_.get()); - } -} - -bool LogicalType::operator==(const LogicalType &rhs) const { - if (id_ != rhs.id_) { - return false; - } - return EqualTypeInfo(rhs); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Comparison Operations -//===--------------------------------------------------------------------===// - -struct ValuePositionComparator { - // Return true if the positional Values definitely match. - // Default to the same as the final value - template - static inline bool Definite(const Value &lhs, const Value &rhs) { - return Final(lhs, rhs); - } - - // Select the positional Values that need further testing. 
- // Usually this means Is Not Distinct, as those are the semantics used by Postges - template - static inline bool Possible(const Value &lhs, const Value &rhs) { - return ValueOperations::NotDistinctFrom(lhs, rhs); - } - - // Return true if the positional Values definitely match in the final position - // This needs to be specialised. - template - static inline bool Final(const Value &lhs, const Value &rhs) { - return false; - } - - // Tie-break based on length when one of the sides has been exhausted, returning true if the LHS matches. - // This essentially means that the existing positions compare equal. - // Default to the same semantics as the OP for idx_t. This works in most cases. - template - static inline bool TieBreak(const idx_t lpos, const idx_t rpos) { - return OP::Operation(lpos, rpos); - } -}; - -// Equals must always check every column -template <> -inline bool ValuePositionComparator::Definite(const Value &lhs, const Value &rhs) { - return false; -} - -template <> -inline bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return ValueOperations::NotDistinctFrom(lhs, rhs); -} - -// NotEquals must check everything that matched -template <> -inline bool ValuePositionComparator::Possible(const Value &lhs, const Value &rhs) { - return true; -} - -template <> -inline bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return ValueOperations::NotDistinctFrom(lhs, rhs); -} - -// Non-strict inequalities must use strict comparisons for Definite -template <> -bool ValuePositionComparator::Definite(const Value &lhs, const Value &rhs) { - return !ValuePositionComparator::Definite(lhs, rhs); -} - -template <> -bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return ValueOperations::DistinctGreaterThan(lhs, rhs); -} - -template <> -bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return !ValuePositionComparator::Final(lhs, rhs); -} - -template <> -bool 
ValuePositionComparator::Definite(const Value &lhs, const Value &rhs) { - return !ValuePositionComparator::Definite(rhs, lhs); -} - -template <> -bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return !ValuePositionComparator::Final(rhs, lhs); -} - -// Strict inequalities just use strict for both Definite and Final -template <> -bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return ValuePositionComparator::Final(rhs, lhs); -} - -template -static bool TemplatedBooleanOperation(const Value &left, const Value &right) { - const auto &left_type = left.type(); - const auto &right_type = right.type(); - if (left_type != right_type) { - Value left_copy = left; - Value right_copy = right; - - LogicalType comparison_type = BoundComparisonExpression::BindComparison(left_type, right_type); - if (!left_copy.DefaultTryCastAs(comparison_type) || !right_copy.DefaultTryCastAs(comparison_type)) { - return false; - } - D_ASSERT(left_copy.type() == right_copy.type()); - return TemplatedBooleanOperation(left_copy, right_copy); - } - switch (left_type.InternalType()) { - case PhysicalType::BOOL: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::INT8: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::INT16: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::INT32: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::INT64: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::UINT8: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::UINT16: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::UINT32: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::UINT64: - return OP::Operation(left.GetValueUnsafe(), 
right.GetValueUnsafe()); - case PhysicalType::INT128: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::FLOAT: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::DOUBLE: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::INTERVAL: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::VARCHAR: - return OP::Operation(StringValue::Get(left), StringValue::Get(right)); - case PhysicalType::STRUCT: { - auto &left_children = StructValue::GetChildren(left); - auto &right_children = StructValue::GetChildren(right); - // this should be enforced by the type - D_ASSERT(left_children.size() == right_children.size()); - idx_t i = 0; - for (; i < left_children.size() - 1; ++i) { - if (ValuePositionComparator::Definite(left_children[i], right_children[i])) { - return true; - } - if (!ValuePositionComparator::Possible(left_children[i], right_children[i])) { - return false; - } - } - return ValuePositionComparator::Final(left_children[i], right_children[i]); - } - case PhysicalType::LIST: { - auto &left_children = ListValue::GetChildren(left); - auto &right_children = ListValue::GetChildren(right); - for (idx_t pos = 0;; ++pos) { - if (pos == left_children.size() || pos == right_children.size()) { - return ValuePositionComparator::TieBreak(left_children.size(), right_children.size()); - } - if (ValuePositionComparator::Definite(left_children[pos], right_children[pos])) { - return true; - } - if (!ValuePositionComparator::Possible(left_children[pos], right_children[pos])) { - return false; - } - } - return false; - } - default: - throw InternalException("Unimplemented type for value comparison"); - } -} - -bool ValueOperations::Equals(const Value &left, const Value &right) { - if (left.IsNull() || right.IsNull()) { - throw InternalException("Comparison on NULL values"); - } - return TemplatedBooleanOperation(left, right); -} 
- -bool ValueOperations::NotEquals(const Value &left, const Value &right) { - return !ValueOperations::Equals(left, right); -} - -bool ValueOperations::GreaterThan(const Value &left, const Value &right) { - if (left.IsNull() || right.IsNull()) { - throw InternalException("Comparison on NULL values"); - } - return TemplatedBooleanOperation(left, right); -} - -bool ValueOperations::GreaterThanEquals(const Value &left, const Value &right) { - return !ValueOperations::GreaterThan(right, left); -} - -bool ValueOperations::LessThan(const Value &left, const Value &right) { - return ValueOperations::GreaterThan(right, left); -} - -bool ValueOperations::LessThanEquals(const Value &left, const Value &right) { - return !ValueOperations::GreaterThan(left, right); -} - -bool ValueOperations::NotDistinctFrom(const Value &left, const Value &right) { - if (left.IsNull() && right.IsNull()) { - return true; - } - if (left.IsNull() != right.IsNull()) { - return false; - } - return TemplatedBooleanOperation(left, right); -} - -bool ValueOperations::DistinctFrom(const Value &left, const Value &right) { - return !ValueOperations::NotDistinctFrom(left, right); -} - -bool ValueOperations::DistinctGreaterThan(const Value &left, const Value &right) { - if (left.IsNull() && right.IsNull()) { - return false; - } else if (right.IsNull()) { - return false; - } else if (left.IsNull()) { - return true; - } - return TemplatedBooleanOperation(left, right); -} - -bool ValueOperations::DistinctGreaterThanEquals(const Value &left, const Value &right) { - return !ValueOperations::DistinctGreaterThan(right, left); -} - -bool ValueOperations::DistinctLessThan(const Value &left, const Value &right) { - return ValueOperations::DistinctGreaterThan(right, left); -} - -bool ValueOperations::DistinctLessThanEquals(const Value &left, const Value &right) { - return !ValueOperations::DistinctGreaterThan(left, right); -} - -} // namespace duckdb 
-//===--------------------------------------------------------------------===// -// boolean_operators.cpp -// Description: This file contains the implementation of the boolean -// operations AND OR ! -//===--------------------------------------------------------------------===// - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// AND/OR -//===--------------------------------------------------------------------===// -template -static void TemplatedBooleanNullmask(Vector &left, Vector &right, Vector &result, idx_t count) { - D_ASSERT(left.GetType().id() == LogicalTypeId::BOOLEAN && right.GetType().id() == LogicalTypeId::BOOLEAN && - result.GetType().id() == LogicalTypeId::BOOLEAN); - - if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && right.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // operation on two constants, result is constant vector - result.SetVectorType(VectorType::CONSTANT_VECTOR); - auto ldata = ConstantVector::GetData(left); - auto rdata = ConstantVector::GetData(right); - auto result_data = ConstantVector::GetData(result); - - bool is_null = OP::Operation(*ldata > 0, *rdata > 0, ConstantVector::IsNull(left), - ConstantVector::IsNull(right), *result_data); - ConstantVector::SetNull(result, is_null); - } else { - // perform generic loop - UnifiedVectorFormat ldata, rdata; - left.ToUnifiedFormat(count, ldata); - right.ToUnifiedFormat(count, rdata); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto left_data = UnifiedVectorFormat::GetData(ldata); // we use uint8 to avoid load of gunk bools - auto right_data = UnifiedVectorFormat::GetData(rdata); - auto result_data = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - if (!ldata.validity.AllValid() || !rdata.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto lidx = ldata.sel->get_index(i); - auto ridx = rdata.sel->get_index(i); - bool is_null = - 
OP::Operation(left_data[lidx] > 0, right_data[ridx] > 0, !ldata.validity.RowIsValid(lidx), - !rdata.validity.RowIsValid(ridx), result_data[i]); - result_mask.Set(i, !is_null); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto lidx = ldata.sel->get_index(i); - auto ridx = rdata.sel->get_index(i); - result_data[i] = OP::SimpleOperation(left_data[lidx], right_data[ridx]); - } - } - } -} - -/* -SQL AND Rules: - -TRUE AND TRUE = TRUE -TRUE AND FALSE = FALSE -TRUE AND NULL = NULL -FALSE AND TRUE = FALSE -FALSE AND FALSE = FALSE -FALSE AND NULL = FALSE -NULL AND TRUE = NULL -NULL AND FALSE = FALSE -NULL AND NULL = NULL - -Basically: -- Only true if both are true -- False if either is false (regardless of NULLs) -- NULL otherwise -*/ -struct TernaryAnd { - static bool SimpleOperation(bool left, bool right) { - return left && right; - } - static bool Operation(bool left, bool right, bool left_null, bool right_null, bool &result) { - if (left_null && right_null) { - // both NULL: - // result is NULL - return true; - } else if (left_null) { - // left is NULL: - // result is FALSE if right is false - // result is NULL if right is true - result = right; - return right; - } else if (right_null) { - // right is NULL: - // result is FALSE if left is false - // result is NULL if left is true - result = left; - return left; - } else { - // no NULL: perform the AND - result = left && right; - return false; - } - } -}; - -void VectorOperations::And(Vector &left, Vector &right, Vector &result, idx_t count) { - TemplatedBooleanNullmask(left, right, result, count); -} - -/* -SQL OR Rules: - -OR -TRUE OR TRUE = TRUE -TRUE OR FALSE = TRUE -TRUE OR NULL = TRUE -FALSE OR TRUE = TRUE -FALSE OR FALSE = FALSE -FALSE OR NULL = NULL -NULL OR TRUE = TRUE -NULL OR FALSE = NULL -NULL OR NULL = NULL - -Basically: -- Only false if both are false -- True if either is true (regardless of NULLs) -- NULL otherwise -*/ - -struct TernaryOr { - static bool SimpleOperation(bool left, bool right) { - 
return left || right; - } - static bool Operation(bool left, bool right, bool left_null, bool right_null, bool &result) { - if (left_null && right_null) { - // both NULL: - // result is NULL - return true; - } else if (left_null) { - // left is NULL: - // result is TRUE if right is true - // result is NULL if right is false - result = right; - return !right; - } else if (right_null) { - // right is NULL: - // result is TRUE if left is true - // result is NULL if left is false - result = left; - return !left; - } else { - // no NULL: perform the OR - result = left || right; - return false; - } - } -}; - -void VectorOperations::Or(Vector &left, Vector &right, Vector &result, idx_t count) { - TemplatedBooleanNullmask(left, right, result, count); -} - -struct NotOperator { - template - static inline TR Operation(TA left) { - return !left; - } -}; - -void VectorOperations::Not(Vector &input, Vector &result, idx_t count) { - D_ASSERT(input.GetType() == LogicalType::BOOLEAN && result.GetType() == LogicalType::BOOLEAN); - UnaryExecutor::Execute(input, result, count); -} - -} // namespace duckdb -//===--------------------------------------------------------------------===// -// comparison_operators.cpp -// Description: This file contains the implementation of the comparison -// operations == != >= <= > < -//===--------------------------------------------------------------------===// - - - - - - - - -namespace duckdb { - -template -bool EqualsFloat(T left, T right) { - if (DUCKDB_UNLIKELY(Value::IsNan(left) && Value::IsNan(right))) { - return true; - } - return left == right; -} - -template <> -bool Equals::Operation(const float &left, const float &right) { - return EqualsFloat(left, right); -} - -template <> -bool Equals::Operation(const double &left, const double &right) { - return EqualsFloat(left, right); -} - -template -bool GreaterThanFloat(T left, T right) { - // handle nans - // nan is always bigger than everything else - bool left_is_nan = Value::IsNan(left); - bool 
right_is_nan = Value::IsNan(right); - // if right is nan, there is no number that is bigger than right - if (DUCKDB_UNLIKELY(right_is_nan)) { - return false; - } - // if left is nan, but right is not, left is always bigger - if (DUCKDB_UNLIKELY(left_is_nan)) { - return true; - } - return left > right; -} - -template <> -bool GreaterThan::Operation(const float &left, const float &right) { - return GreaterThanFloat(left, right); -} - -template <> -bool GreaterThan::Operation(const double &left, const double &right) { - return GreaterThanFloat(left, right); -} - -template -bool GreaterThanEqualsFloat(T left, T right) { - // handle nans - // nan is always bigger than everything else - bool left_is_nan = Value::IsNan(left); - bool right_is_nan = Value::IsNan(right); - // if right is nan, there is no bigger number - // we only return true if left is also nan (in which case the numbers are equal) - if (DUCKDB_UNLIKELY(right_is_nan)) { - return left_is_nan; - } - // if left is nan, but right is not, left is always bigger - if (DUCKDB_UNLIKELY(left_is_nan)) { - return true; - } - return left >= right; -} - -template <> -bool GreaterThanEquals::Operation(const float &left, const float &right) { - return GreaterThanEqualsFloat(left, right); -} - -template <> -bool GreaterThanEquals::Operation(const double &left, const double &right) { - return GreaterThanEqualsFloat(left, right); -} - -struct ComparisonSelector { - template - static idx_t Select(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - throw NotImplementedException("Unknown comparison operation!"); - } -}; - -template <> -inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::Equals(left, right, sel, count, true_sel, false_sel); -} - -template <> -inline idx_t ComparisonSelector::Select(Vector 
&left, Vector &right, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::NotEquals(left, right, sel, count, true_sel, false_sel); -} - -template <> -inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::GreaterThan(left, right, sel, count, true_sel, false_sel); -} - -template <> -inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, - const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, false_sel); -} - -template <> -inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::GreaterThan(right, left, sel, count, true_sel, false_sel); -} - -template <> -inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::GreaterThanEquals(right, left, sel, count, true_sel, false_sel); -} - -static void ComparesNotNull(UnifiedVectorFormat &ldata, UnifiedVectorFormat &rdata, ValidityMask &vresult, - idx_t count) { - for (idx_t i = 0; i < count; ++i) { - auto lidx = ldata.sel->get_index(i); - auto ridx = rdata.sel->get_index(i); - if (!ldata.validity.RowIsValid(lidx) || !rdata.validity.RowIsValid(ridx)) { - vresult.SetInvalid(i); - } - } -} - -template -static void NestedComparisonExecutor(Vector &left, Vector &right, Vector &result, idx_t count) { - const auto left_constant = left.GetVectorType() == VectorType::CONSTANT_VECTOR; - const auto right_constant = right.GetVectorType() == VectorType::CONSTANT_VECTOR; - - if ((left_constant && 
ConstantVector::IsNull(left)) || (right_constant && ConstantVector::IsNull(right))) { - // either left or right is constant NULL: result is constant NULL - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - - if (left_constant && right_constant) { - // both sides are constant, and neither is NULL so just compare one element. - result.SetVectorType(VectorType::CONSTANT_VECTOR); - SelectionVector true_sel(1); - auto match_count = ComparisonSelector::Select(left, right, nullptr, 1, &true_sel, nullptr); - auto result_data = ConstantVector::GetData(result); - result_data[0] = match_count > 0; - return; - } - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - - UnifiedVectorFormat leftv, rightv; - left.ToUnifiedFormat(count, leftv); - right.ToUnifiedFormat(count, rightv); - if (!leftv.validity.AllValid() || !rightv.validity.AllValid()) { - ComparesNotNull(leftv, rightv, result_validity, count); - } - SelectionVector true_sel(count); - SelectionVector false_sel(count); - idx_t match_count = ComparisonSelector::Select(left, right, nullptr, count, &true_sel, &false_sel); - - for (idx_t i = 0; i < match_count; ++i) { - const auto idx = true_sel.get_index(i); - result_data[idx] = true; - } - - const idx_t no_match_count = count - match_count; - for (idx_t i = 0; i < no_match_count; ++i) { - const auto idx = false_sel.get_index(i); - result_data[idx] = false; - } -} - -struct ComparisonExecutor { -private: - template - static inline void TemplatedExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - BinaryExecutor::Execute(left, right, result, count); - } - -public: - template - static inline void Execute(Vector &left, Vector &right, Vector &result, idx_t count) { - D_ASSERT(left.GetType() == right.GetType() && result.GetType() == LogicalType::BOOLEAN); - // the inplace loops take the result as the 
last parameter - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::INT16: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::INT32: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::INT64: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::UINT8: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::UINT16: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::UINT32: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::UINT64: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::INT128: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::FLOAT: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::DOUBLE: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::INTERVAL: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::VARCHAR: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - NestedComparisonExecutor(left, right, result, count); - break; - default: - throw InternalException("Invalid type for comparison"); - } - } -}; - -void VectorOperations::Equals(Vector &left, Vector &right, Vector &result, idx_t count) { - ComparisonExecutor::Execute(left, right, result, count); -} - -void VectorOperations::NotEquals(Vector &left, Vector &right, Vector &result, idx_t count) { - ComparisonExecutor::Execute(left, right, result, count); -} - -void VectorOperations::GreaterThanEquals(Vector &left, Vector &right, Vector &result, idx_t count) { - ComparisonExecutor::Execute(left, right, result, count); -} - -void VectorOperations::LessThanEquals(Vector &left, Vector &right, Vector &result, idx_t count) { - 
ComparisonExecutor::Execute(right, left, result, count); -} - -void VectorOperations::GreaterThan(Vector &left, Vector &right, Vector &result, idx_t count) { - ComparisonExecutor::Execute(left, right, result, count); -} - -void VectorOperations::LessThan(Vector &left, Vector &right, Vector &result, idx_t count) { - ComparisonExecutor::Execute(right, left, result, count); -} - -} // namespace duckdb -//===--------------------------------------------------------------------===// -// generators.cpp -// Description: This file contains the implementation of different generators -//===--------------------------------------------------------------------===// - - - - - -namespace duckdb { - -template -void TemplatedGenerateSequence(Vector &result, idx_t count, int64_t start, int64_t increment) { - D_ASSERT(result.GetType().IsNumeric()); - if (start > NumericLimits::Maximum() || increment > NumericLimits::Maximum()) { - throw Exception("Sequence start or increment out of type range"); - } - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - auto value = (T)start; - for (idx_t i = 0; i < count; i++) { - if (i > 0) { - value += increment; - } - result_data[i] = value; - } -} - -void VectorOperations::GenerateSequence(Vector &result, idx_t count, int64_t start, int64_t increment) { - if (!result.GetType().IsNumeric()) { - throw InvalidTypeException(result.GetType(), "Can only generate sequences for numeric values!"); - } - switch (result.GetType().InternalType()) { - case PhysicalType::INT8: - TemplatedGenerateSequence(result, count, start, increment); - break; - case PhysicalType::INT16: - TemplatedGenerateSequence(result, count, start, increment); - break; - case PhysicalType::INT32: - TemplatedGenerateSequence(result, count, start, increment); - break; - case PhysicalType::INT64: - TemplatedGenerateSequence(result, count, start, increment); - break; - case PhysicalType::FLOAT: - TemplatedGenerateSequence(result, count, start, 
increment); - break; - case PhysicalType::DOUBLE: - TemplatedGenerateSequence(result, count, start, increment); - break; - default: - throw NotImplementedException("Unimplemented type for generate sequence"); - } -} - -template -void TemplatedGenerateSequence(Vector &result, idx_t count, const SelectionVector &sel, int64_t start, - int64_t increment) { - D_ASSERT(result.GetType().IsNumeric()); - if (start > NumericLimits::Maximum() || increment > NumericLimits::Maximum()) { - throw Exception("Sequence start or increment out of type range"); - } - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - auto value = (T)start; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - result_data[idx] = value + increment * idx; - } -} - -void VectorOperations::GenerateSequence(Vector &result, idx_t count, const SelectionVector &sel, int64_t start, - int64_t increment) { - if (!result.GetType().IsNumeric()) { - throw InvalidTypeException(result.GetType(), "Can only generate sequences for numeric values!"); - } - switch (result.GetType().InternalType()) { - case PhysicalType::INT8: - TemplatedGenerateSequence(result, count, sel, start, increment); - break; - case PhysicalType::INT16: - TemplatedGenerateSequence(result, count, sel, start, increment); - break; - case PhysicalType::INT32: - TemplatedGenerateSequence(result, count, sel, start, increment); - break; - case PhysicalType::INT64: - TemplatedGenerateSequence(result, count, sel, start, increment); - break; - case PhysicalType::FLOAT: - TemplatedGenerateSequence(result, count, sel, start, increment); - break; - case PhysicalType::DOUBLE: - TemplatedGenerateSequence(result, count, sel, start, increment); - break; - default: - throw NotImplementedException("Unimplemented type for generate sequence"); - } -} - -} // namespace duckdb - - - -namespace duckdb { - -struct DistinctBinaryLambdaWrapper { - template - static inline RESULT_TYPE Operation(LEFT_TYPE left, 
RIGHT_TYPE right, bool is_left_null, bool is_right_null) { - return OP::template Operation(left, right, is_left_null, is_right_null); - } -}; - -template -static void DistinctExecuteGenericLoop(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, - RESULT_TYPE *__restrict result_data, const SelectionVector *__restrict lsel, - const SelectionVector *__restrict rsel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, ValidityMask &result_mask) { - for (idx_t i = 0; i < count; i++) { - auto lindex = lsel->get_index(i); - auto rindex = rsel->get_index(i); - auto lentry = ldata[lindex]; - auto rentry = rdata[rindex]; - result_data[i] = - OP::template Operation(lentry, rentry, !lmask.RowIsValid(lindex), !rmask.RowIsValid(rindex)); - } -} - -template -static void DistinctExecuteConstant(Vector &left, Vector &right, Vector &result) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - - auto ldata = ConstantVector::GetData(left); - auto rdata = ConstantVector::GetData(right); - auto result_data = ConstantVector::GetData(result); - *result_data = - OP::template Operation(*ldata, *rdata, ConstantVector::IsNull(left), ConstantVector::IsNull(right)); -} - -template -static void DistinctExecuteGeneric(Vector &left, Vector &right, Vector &result, idx_t count) { - if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && right.GetVectorType() == VectorType::CONSTANT_VECTOR) { - DistinctExecuteConstant(left, right, result); - } else { - UnifiedVectorFormat ldata, rdata; - - left.ToUnifiedFormat(count, ldata); - right.ToUnifiedFormat(count, rdata); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - DistinctExecuteGenericLoop( - UnifiedVectorFormat::GetData(ldata), UnifiedVectorFormat::GetData(rdata), - result_data, ldata.sel, rdata.sel, count, ldata.validity, rdata.validity, FlatVector::Validity(result)); - } -} - -template -static void DistinctExecuteSwitch(Vector &left, Vector &right, Vector &result, 
idx_t count) { - DistinctExecuteGeneric(left, right, result, count); -} - -template -static void DistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - DistinctExecuteSwitch(left, right, result, count); -} - -template -static inline idx_t -DistinctSelectGenericLoop(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, - const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, - const SelectionVector *__restrict result_sel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, SelectionVector *true_sel, SelectionVector *false_sel) { - idx_t true_count = 0, false_count = 0; - for (idx_t i = 0; i < count; i++) { - auto result_idx = result_sel->get_index(i); - auto lindex = lsel->get_index(i); - auto rindex = rsel->get_index(i); - if (NO_NULL) { - if (OP::Operation(ldata[lindex], rdata[rindex], false, false)) { - if (HAS_TRUE_SEL) { - true_sel->set_index(true_count++, result_idx); - } - } else { - if (HAS_FALSE_SEL) { - false_sel->set_index(false_count++, result_idx); - } - } - } else { - if (OP::Operation(ldata[lindex], rdata[rindex], !lmask.RowIsValid(lindex), !rmask.RowIsValid(rindex))) { - if (HAS_TRUE_SEL) { - true_sel->set_index(true_count++, result_idx); - } - } else { - if (HAS_FALSE_SEL) { - false_sel->set_index(false_count++, result_idx); - } - } - } - } - if (HAS_TRUE_SEL) { - return true_count; - } else { - return count - false_count; - } -} -template -static inline idx_t -DistinctSelectGenericLoopSelSwitch(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, - const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, - const SelectionVector *__restrict result_sel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, SelectionVector *true_sel, SelectionVector *false_sel) { - if (true_sel && false_sel) { - return DistinctSelectGenericLoop( - ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); - } else if (true_sel) { - 
return DistinctSelectGenericLoop( - ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); - } else { - D_ASSERT(false_sel); - return DistinctSelectGenericLoop( - ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); - } -} - -template -static inline idx_t -DistinctSelectGenericLoopSwitch(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, - const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, - const SelectionVector *__restrict result_sel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, SelectionVector *true_sel, SelectionVector *false_sel) { - if (!lmask.AllValid() || !rmask.AllValid()) { - return DistinctSelectGenericLoopSelSwitch( - ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); - } else { - return DistinctSelectGenericLoopSelSwitch( - ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); - } -} - -template -static idx_t DistinctSelectGeneric(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - UnifiedVectorFormat ldata, rdata; - - left.ToUnifiedFormat(count, ldata); - right.ToUnifiedFormat(count, rdata); - - return DistinctSelectGenericLoopSwitch( - UnifiedVectorFormat::GetData(ldata), UnifiedVectorFormat::GetData(rdata), ldata.sel, - rdata.sel, sel, count, ldata.validity, rdata.validity, true_sel, false_sel); -} -template -static inline idx_t DistinctSelectFlatLoop(LEFT_TYPE *__restrict ldata, RIGHT_TYPE *__restrict rdata, - const SelectionVector *sel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, SelectionVector *true_sel, SelectionVector *false_sel) { - idx_t true_count = 0, false_count = 0; - for (idx_t i = 0; i < count; i++) { - idx_t result_idx = sel->get_index(i); - idx_t lidx = LEFT_CONSTANT ? 0 : i; - idx_t ridx = RIGHT_CONSTANT ? 
0 : i; - const bool lnull = !lmask.RowIsValid(lidx); - const bool rnull = !rmask.RowIsValid(ridx); - bool comparison_result = OP::Operation(ldata[lidx], rdata[ridx], lnull, rnull); - if (HAS_TRUE_SEL) { - true_sel->set_index(true_count, result_idx); - true_count += comparison_result; - } - if (HAS_FALSE_SEL) { - false_sel->set_index(false_count, result_idx); - false_count += !comparison_result; - } - } - if (HAS_TRUE_SEL) { - return true_count; - } else { - return count - false_count; - } -} - -template -static inline idx_t DistinctSelectFlatLoopSelSwitch(LEFT_TYPE *__restrict ldata, RIGHT_TYPE *__restrict rdata, - const SelectionVector *sel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, SelectionVector *true_sel, - SelectionVector *false_sel) { - if (true_sel && false_sel) { - return DistinctSelectFlatLoop( - ldata, rdata, sel, count, lmask, rmask, true_sel, false_sel); - } else if (true_sel) { - return DistinctSelectFlatLoop( - ldata, rdata, sel, count, lmask, rmask, true_sel, false_sel); - } else { - D_ASSERT(false_sel); - return DistinctSelectFlatLoop( - ldata, rdata, sel, count, lmask, rmask, true_sel, false_sel); - } -} - -template -static inline idx_t DistinctSelectFlatLoopSwitch(LEFT_TYPE *__restrict ldata, RIGHT_TYPE *__restrict rdata, - const SelectionVector *sel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, SelectionVector *true_sel, - SelectionVector *false_sel) { - return DistinctSelectFlatLoopSelSwitch( - ldata, rdata, sel, count, lmask, rmask, true_sel, false_sel); -} -template -static idx_t DistinctSelectFlat(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - auto ldata = FlatVector::GetData(left); - auto rdata = FlatVector::GetData(right); - if (LEFT_CONSTANT) { - ValidityMask validity; - if (ConstantVector::IsNull(left)) { - validity.SetAllInvalid(1); - } - return DistinctSelectFlatLoopSwitch( - ldata, rdata, sel, count, validity, 
FlatVector::Validity(right), true_sel, false_sel); - } else if (RIGHT_CONSTANT) { - ValidityMask validity; - if (ConstantVector::IsNull(right)) { - validity.SetAllInvalid(1); - } - return DistinctSelectFlatLoopSwitch( - ldata, rdata, sel, count, FlatVector::Validity(left), validity, true_sel, false_sel); - } else { - return DistinctSelectFlatLoopSwitch( - ldata, rdata, sel, count, FlatVector::Validity(left), FlatVector::Validity(right), true_sel, false_sel); - } -} -template -static idx_t DistinctSelectConstant(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - auto ldata = ConstantVector::GetData(left); - auto rdata = ConstantVector::GetData(right); - - // both sides are constant, return either 0 or the count - // in this case we do not fill in the result selection vector at all - if (!OP::Operation(*ldata, *rdata, ConstantVector::IsNull(left), ConstantVector::IsNull(right))) { - if (false_sel) { - for (idx_t i = 0; i < count; i++) { - false_sel->set_index(i, sel->get_index(i)); - } - } - return 0; - } else { - if (true_sel) { - for (idx_t i = 0; i < count; i++) { - true_sel->set_index(i, sel->get_index(i)); - } - } - return count; - } -} - -template -static idx_t DistinctSelect(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - if (!sel) { - sel = FlatVector::IncrementalSelectionVector(); - } - if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && right.GetVectorType() == VectorType::CONSTANT_VECTOR) { - return DistinctSelectConstant(left, right, sel, count, true_sel, false_sel); - } else if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && - right.GetVectorType() == VectorType::FLAT_VECTOR) { - return DistinctSelectFlat(left, right, sel, count, true_sel, false_sel); - } else if (left.GetVectorType() == VectorType::FLAT_VECTOR && - right.GetVectorType() == VectorType::CONSTANT_VECTOR) { - return 
DistinctSelectFlat(left, right, sel, count, true_sel, false_sel); - } else if (left.GetVectorType() == VectorType::FLAT_VECTOR && right.GetVectorType() == VectorType::FLAT_VECTOR) { - return DistinctSelectFlat(left, right, sel, count, true_sel, - false_sel); - } else { - return DistinctSelectGeneric(left, right, sel, count, true_sel, false_sel); - } -} - -template -static idx_t DistinctSelectNotNull(Vector &left, Vector &right, const idx_t count, idx_t &true_count, - const SelectionVector &sel, SelectionVector &maybe_vec, OptionalSelection &true_opt, - OptionalSelection &false_opt) { - UnifiedVectorFormat lvdata, rvdata; - left.ToUnifiedFormat(count, lvdata); - right.ToUnifiedFormat(count, rvdata); - - auto &lmask = lvdata.validity; - auto &rmask = rvdata.validity; - - idx_t remaining = 0; - if (lmask.AllValid() && rmask.AllValid()) { - // None are NULL, distinguish values. - for (idx_t i = 0; i < count; ++i) { - const auto idx = sel.get_index(i); - maybe_vec.set_index(remaining++, idx); - } - return remaining; - } - - // Slice the Vectors down to the rows that are not determined (i.e., neither is NULL) - SelectionVector slicer(count); - true_count = 0; - idx_t false_count = 0; - for (idx_t i = 0; i < count; ++i) { - const auto result_idx = sel.get_index(i); - const auto lidx = lvdata.sel->get_index(i); - const auto ridx = rvdata.sel->get_index(i); - const auto lnull = !lmask.RowIsValid(lidx); - const auto rnull = !rmask.RowIsValid(ridx); - if (lnull || rnull) { - // If either is NULL then we can major distinguish them - if (!OP::Operation(false, false, lnull, rnull)) { - false_opt.Append(false_count, result_idx); - } else { - true_opt.Append(true_count, result_idx); - } - } else { - // Neither is NULL, distinguish values. 
- slicer.set_index(remaining, i); - maybe_vec.set_index(remaining++, result_idx); - } - } - - true_opt.Advance(true_count); - false_opt.Advance(false_count); - - if (remaining && remaining < count) { - left.Slice(slicer, remaining); - right.Slice(slicer, remaining); - } - - return remaining; -} - -struct PositionComparator { - // Select the rows that definitely match. - // Default to the same as the final row - template - static idx_t Definite(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector &false_sel) { - return Final(left, right, sel, count, true_sel, &false_sel); - } - - // Select the possible rows that need further testing. - // Usually this means Is Not Distinct, as those are the semantics used by Postges - template - static idx_t Possible(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector &true_sel, SelectionVector *false_sel) { - return VectorOperations::NestedEquals(left, right, sel, count, &true_sel, false_sel); - } - - // Select the matching rows for the final position. - // This needs to be specialised. - template - static idx_t Final(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return 0; - } - - // Tie-break based on length when one of the sides has been exhausted, returning true if the LHS matches. - // This essentially means that the existing positions compare equal. - // Default to the same semantics as the OP for idx_t. This works in most cases. 
- template - static bool TieBreak(const idx_t lpos, const idx_t rpos) { - return OP::Operation(lpos, rpos, false, false); - } -}; - -// NotDistinctFrom must always check every column -template <> -idx_t PositionComparator::Definite(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, SelectionVector *true_sel, - SelectionVector &false_sel) { - return 0; -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::NestedEquals(left, right, sel, count, true_sel, false_sel); -} - -// DistinctFrom must check everything that matched -template <> -idx_t PositionComparator::Possible(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, SelectionVector &true_sel, - SelectionVector *false_sel) { - return count; -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::NestedNotEquals(left, right, sel, count, true_sel, false_sel); -} - -// Non-strict inequalities must use strict comparisons for Definite -template <> -idx_t PositionComparator::Definite(Vector &left, Vector &right, - const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, - SelectionVector &false_sel) { - return VectorOperations::DistinctGreaterThan(right, left, &sel, count, true_sel, &false_sel); -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::DistinctGreaterThanEquals(right, left, &sel, count, true_sel, false_sel); -} - -template <> -idx_t PositionComparator::Definite(Vector &left, Vector &right, - const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, - SelectionVector &false_sel) { - return 
VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, &false_sel); -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, - const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::DistinctGreaterThanEquals(left, right, &sel, count, true_sel, false_sel); -} - -// Strict inequalities just use strict for both Definite and Final -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::DistinctGreaterThan(right, left, &sel, count, true_sel, false_sel); -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, false_sel); -} - -using StructEntries = vector>; - -static void ExtractNestedSelection(const SelectionVector &slice_sel, const idx_t count, const SelectionVector &sel, - OptionalSelection &opt) { - - for (idx_t i = 0; i < count;) { - const auto slice_idx = slice_sel.get_index(i); - const auto result_idx = sel.get_index(slice_idx); - opt.Append(i, result_idx); - } - opt.Advance(count); -} - -static void DensifyNestedSelection(const SelectionVector &dense_sel, const idx_t count, SelectionVector &slice_sel) { - for (idx_t i = 0; i < count; ++i) { - slice_sel.set_index(i, dense_sel.get_index(i)); - } -} - -template -static idx_t DistinctSelectStruct(Vector &left, Vector &right, idx_t count, const SelectionVector &sel, - OptionalSelection &true_opt, OptionalSelection &false_opt) { - if (count == 0) { - return 0; - } - - // Avoid allocating in the 99% of the cases where we don't need to. 
- StructEntries lsliced, rsliced; - auto &lchildren = StructVector::GetEntries(left); - auto &rchildren = StructVector::GetEntries(right); - D_ASSERT(lchildren.size() == rchildren.size()); - - // In order to reuse the comparators, we have to track what passed and failed internally. - // To do that, we need local SVs that we then merge back into the real ones after every pass. - const auto vcount = count; - SelectionVector slice_sel(count); - for (idx_t i = 0; i < count; ++i) { - slice_sel.set_index(i, i); - } - - SelectionVector true_sel(count); - SelectionVector false_sel(count); - - idx_t match_count = 0; - for (idx_t col_no = 0; col_no < lchildren.size(); ++col_no) { - // Slice the children to maintain density - Vector lchild(*lchildren[col_no]); - lchild.Flatten(vcount); - lchild.Slice(slice_sel, count); - - Vector rchild(*rchildren[col_no]); - rchild.Flatten(vcount); - rchild.Slice(slice_sel, count); - - // Find everything that definitely matches - auto true_count = PositionComparator::Definite(lchild, rchild, slice_sel, count, &true_sel, false_sel); - if (true_count > 0) { - auto false_count = count - true_count; - - // Extract the definite matches into the true result - ExtractNestedSelection(false_count ? true_sel : slice_sel, true_count, sel, true_opt); - - // Remove the definite matches from the slicing vector - DensifyNestedSelection(false_sel, false_count, slice_sel); - - match_count += true_count; - count -= true_count; - } - - if (col_no != lchildren.size() - 1) { - // Find what might match on the next position - true_count = PositionComparator::Possible(lchild, rchild, slice_sel, count, true_sel, &false_sel); - auto false_count = count - true_count; - - // Extract the definite failures into the false result - ExtractNestedSelection(true_count ? 
false_sel : slice_sel, false_count, sel, false_opt); - - // Remove any definite failures from the slicing vector - if (false_count) { - DensifyNestedSelection(true_sel, true_count, slice_sel); - } - - count = true_count; - } else { - true_count = PositionComparator::Final(lchild, rchild, slice_sel, count, &true_sel, &false_sel); - auto false_count = count - true_count; - - // Extract the definite matches into the true result - ExtractNestedSelection(false_count ? true_sel : slice_sel, true_count, sel, true_opt); - - // Extract the definite failures into the false result - ExtractNestedSelection(true_count ? false_sel : slice_sel, false_count, sel, false_opt); - - match_count += true_count; - } - } - return match_count; -} - -static void PositionListCursor(SelectionVector &cursor, UnifiedVectorFormat &vdata, const idx_t pos, - const SelectionVector &slice_sel, const idx_t count) { - const auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; ++i) { - const auto slice_idx = slice_sel.get_index(i); - - const auto lidx = vdata.sel->get_index(slice_idx); - const auto &entry = data[lidx]; - cursor.set_index(i, entry.offset + pos); - } -} - -template -static idx_t DistinctSelectList(Vector &left, Vector &right, idx_t count, const SelectionVector &sel, - OptionalSelection &true_opt, OptionalSelection &false_opt) { - if (count == 0) { - return count; - } - - // Create dictionary views of the children so we can vectorise the positional comparisons. 
- SelectionVector lcursor(count); - SelectionVector rcursor(count); - - Vector lentry_flattened(ListVector::GetEntry(left)); - Vector rentry_flattened(ListVector::GetEntry(right)); - lentry_flattened.Flatten(ListVector::GetListSize(left)); - rentry_flattened.Flatten(ListVector::GetListSize(right)); - Vector lchild(lentry_flattened, lcursor, count); - Vector rchild(rentry_flattened, rcursor, count); - - // To perform the positional comparison, we use a vectorisation of the following algorithm: - // bool CompareLists(T *left, idx_t nleft, T *right, nright) { - // for (idx_t pos = 0; ; ++pos) { - // if (nleft == pos || nright == pos) - // return OP::TieBreak(nleft, nright); - // if (OP::Definite(*left, *right)) - // return true; - // if (!OP::Maybe(*left, *right)) - // return false; - // } - // ++left, ++right; - // } - // } - - // Get pointers to the list entries - UnifiedVectorFormat lvdata; - left.ToUnifiedFormat(count, lvdata); - const auto ldata = UnifiedVectorFormat::GetData(lvdata); - - UnifiedVectorFormat rvdata; - right.ToUnifiedFormat(count, rvdata); - const auto rdata = UnifiedVectorFormat::GetData(rvdata); - - // In order to reuse the comparators, we have to track what passed and failed internally. - // To do that, we need local SVs that we then merge back into the real ones after every pass. - SelectionVector slice_sel(count); - for (idx_t i = 0; i < count; ++i) { - slice_sel.set_index(i, i); - } - - SelectionVector true_sel(count); - SelectionVector false_sel(count); - - idx_t match_count = 0; - for (idx_t pos = 0; count > 0; ++pos) { - // Set up the cursors for the current position - PositionListCursor(lcursor, lvdata, pos, slice_sel, count); - PositionListCursor(rcursor, rvdata, pos, slice_sel, count); - - // Tie-break the pairs where one of the LISTs is exhausted. 
- idx_t true_count = 0; - idx_t false_count = 0; - idx_t maybe_count = 0; - for (idx_t i = 0; i < count; ++i) { - const auto slice_idx = slice_sel.get_index(i); - const auto lidx = lvdata.sel->get_index(slice_idx); - const auto &lentry = ldata[lidx]; - const auto ridx = rvdata.sel->get_index(slice_idx); - const auto &rentry = rdata[ridx]; - if (lentry.length == pos || rentry.length == pos) { - const auto idx = sel.get_index(slice_idx); - if (PositionComparator::TieBreak(lentry.length, rentry.length)) { - true_opt.Append(true_count, idx); - } else { - false_opt.Append(false_count, idx); - } - } else { - true_sel.set_index(maybe_count++, slice_idx); - } - } - true_opt.Advance(true_count); - false_opt.Advance(false_count); - match_count += true_count; - - // Redensify the list cursors - if (maybe_count < count) { - count = maybe_count; - DensifyNestedSelection(true_sel, count, slice_sel); - PositionListCursor(lcursor, lvdata, pos, slice_sel, count); - PositionListCursor(rcursor, rvdata, pos, slice_sel, count); - } - - // Find everything that definitely matches - true_count = PositionComparator::Definite(lchild, rchild, slice_sel, count, &true_sel, false_sel); - if (true_count) { - false_count = count - true_count; - ExtractNestedSelection(false_count ? true_sel : slice_sel, true_count, sel, true_opt); - match_count += true_count; - - // Redensify the list cursors - count -= true_count; - DensifyNestedSelection(false_sel, count, slice_sel); - PositionListCursor(lcursor, lvdata, pos, slice_sel, count); - PositionListCursor(rcursor, rvdata, pos, slice_sel, count); - } - - // Find what might match on the next position - true_count = PositionComparator::Possible(lchild, rchild, slice_sel, count, true_sel, &false_sel); - false_count = count - true_count; - ExtractNestedSelection(true_count ? 
false_sel : slice_sel, false_count, sel, false_opt); - - if (false_count) { - DensifyNestedSelection(true_sel, true_count, slice_sel); - } - count = true_count; - } - - return match_count; -} - -template -static idx_t DistinctSelectNested(Vector &left, Vector &right, const SelectionVector *sel, const idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - // The Select operations all use a dense pair of input vectors to partition - // a selection vector in a single pass. But to implement progressive comparisons, - // we have to make multiple passes, so we need to keep track of the original input positions - // and then scatter the output selections when we are done. - if (!sel) { - sel = FlatVector::IncrementalSelectionVector(); - } - - // Make buffered selections for progressive comparisons - // TODO: Remove unnecessary allocations - SelectionVector true_vec(count); - OptionalSelection true_opt(&true_vec); - - SelectionVector false_vec(count); - OptionalSelection false_opt(&false_vec); - - SelectionVector maybe_vec(count); - - // Handle NULL nested values - Vector l_not_null(left); - Vector r_not_null(right); - - idx_t match_count = 0; - auto unknown = - DistinctSelectNotNull(l_not_null, r_not_null, count, match_count, *sel, maybe_vec, true_opt, false_opt); - - if (PhysicalType::LIST == left.GetType().InternalType()) { - match_count += DistinctSelectList(l_not_null, r_not_null, unknown, maybe_vec, true_opt, false_opt); - } else { - match_count += DistinctSelectStruct(l_not_null, r_not_null, unknown, maybe_vec, true_opt, false_opt); - } - - // Copy the buffered selections to the output selections - if (true_sel) { - DensifyNestedSelection(true_vec, match_count, *true_sel); - } - - if (false_sel) { - DensifyNestedSelection(false_vec, count - match_count, *false_sel); - } - - return match_count; -} - -template -static void NestedDistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count); - -template -static inline void 
TemplatedDistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - DistinctExecute(left, right, result, count); -} -template -static void ExecuteDistinct(Vector &left, Vector &right, Vector &result, idx_t count) { - D_ASSERT(left.GetType() == right.GetType() && result.GetType() == LogicalType::BOOLEAN); - // the inplace loops take the result as the last parameter - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::INT16: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::INT32: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::INT64: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::UINT8: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::UINT16: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::UINT32: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::UINT64: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::INT128: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::FLOAT: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::DOUBLE: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::INTERVAL: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::VARCHAR: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - NestedDistinctExecute(left, right, result, count); - break; - default: - throw InternalException("Invalid type for distinct comparison"); - } -} - -template -static idx_t TemplatedDistinctSelectOperation(Vector &left, Vector &right, const SelectionVector *sel, idx_t 
count, - SelectionVector *true_sel, SelectionVector *false_sel) { - // the inplace loops take the result as the last parameter - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::INT16: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::INT32: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::INT64: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::UINT8: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::UINT16: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::UINT32: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::UINT64: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::INT128: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::FLOAT: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::DOUBLE: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::INTERVAL: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::VARCHAR: - return DistinctSelect(left, right, sel, count, true_sel, false_sel); - case PhysicalType::STRUCT: - case PhysicalType::LIST: - return DistinctSelectNested(left, right, sel, count, true_sel, false_sel); - default: - throw InternalException("Invalid type for distinct selection"); - } -} - -template -static void NestedDistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - const auto left_constant = left.GetVectorType() == VectorType::CONSTANT_VECTOR; - const auto right_constant = right.GetVectorType() == VectorType::CONSTANT_VECTOR; - - if 
(left_constant && right_constant) { - // both sides are constant, so just compare one element. - result.SetVectorType(VectorType::CONSTANT_VECTOR); - auto result_data = ConstantVector::GetData(result); - SelectionVector true_sel(1); - auto match_count = TemplatedDistinctSelectOperation(left, right, nullptr, 1, &true_sel, nullptr); - result_data[0] = match_count > 0; - return; - } - - SelectionVector true_sel(count); - SelectionVector false_sel(count); - - // DISTINCT is either true or false - idx_t match_count = TemplatedDistinctSelectOperation(left, right, nullptr, count, &true_sel, &false_sel); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - - for (idx_t i = 0; i < match_count; ++i) { - const auto idx = true_sel.get_index(i); - result_data[idx] = true; - } - - const idx_t no_match_count = count - match_count; - for (idx_t i = 0; i < no_match_count; ++i) { - const auto idx = false_sel.get_index(i); - result_data[idx] = false; - } -} - -void VectorOperations::DistinctFrom(Vector &left, Vector &right, Vector &result, idx_t count) { - ExecuteDistinct(left, right, result, count); -} - -void VectorOperations::NotDistinctFrom(Vector &left, Vector &right, Vector &result, idx_t count) { - ExecuteDistinct(left, right, result, count); -} - -// true := A != B with nulls being equal -idx_t VectorOperations::DistinctFrom(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return TemplatedDistinctSelectOperation(left, right, sel, count, true_sel, false_sel); -} -// true := A == B with nulls being equal -idx_t VectorOperations::NotDistinctFrom(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return count - TemplatedDistinctSelectOperation(left, right, sel, count, false_sel, true_sel); -} - -// true := A > B with nulls being maximal -idx_t 
VectorOperations::DistinctGreaterThan(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return TemplatedDistinctSelectOperation(left, right, sel, count, true_sel, false_sel); -} - -// true := A > B with nulls being minimal -idx_t VectorOperations::DistinctGreaterThanNullsFirst(Vector &left, Vector &right, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return TemplatedDistinctSelectOperation( - left, right, sel, count, true_sel, false_sel); -} -// true := A >= B with nulls being maximal -idx_t VectorOperations::DistinctGreaterThanEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return count - - TemplatedDistinctSelectOperation(right, left, sel, count, false_sel, true_sel); -} -// true := A < B with nulls being maximal -idx_t VectorOperations::DistinctLessThan(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return TemplatedDistinctSelectOperation(right, left, sel, count, true_sel, false_sel); -} - -// true := A < B with nulls being minimal -idx_t VectorOperations::DistinctLessThanNullsFirst(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return TemplatedDistinctSelectOperation( - right, left, sel, count, true_sel, false_sel); -} - -// true := A <= B with nulls being maximal -idx_t VectorOperations::DistinctLessThanEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return count - - TemplatedDistinctSelectOperation(left, right, sel, count, false_sel, true_sel); -} - -// true := A != B with nulls being equal, inputs selected -idx_t VectorOperations::NestedNotEquals(Vector &left, Vector &right, const 
SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return TemplatedDistinctSelectOperation(left, right, &sel, count, true_sel, false_sel); -} -// true := A == B with nulls being equal, inputs selected -idx_t VectorOperations::NestedEquals(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return count - - TemplatedDistinctSelectOperation(left, right, &sel, count, false_sel, true_sel); -} - -} // namespace duckdb -//===--------------------------------------------------------------------===// -// null_operators.cpp -// Description: This file contains the implementation of the -// IS NULL/NOT IS NULL operators -//===--------------------------------------------------------------------===// - - - - -namespace duckdb { - -template -void IsNullLoop(Vector &input, Vector &result, idx_t count) { - D_ASSERT(result.GetType() == LogicalType::BOOLEAN); - - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - auto result_data = ConstantVector::GetData(result); - *result_data = INVERSE ? !ConstantVector::IsNull(input) : ConstantVector::IsNull(input); - } else { - UnifiedVectorFormat data; - input.ToUnifiedFormat(count, data); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - auto idx = data.sel->get_index(i); - result_data[i] = INVERSE ? 
data.validity.RowIsValid(idx) : !data.validity.RowIsValid(idx); - } - } -} - -void VectorOperations::IsNotNull(Vector &input, Vector &result, idx_t count) { - IsNullLoop(input, result, count); -} - -void VectorOperations::IsNull(Vector &input, Vector &result, idx_t count) { - IsNullLoop(input, result, count); -} - -bool VectorOperations::HasNotNull(Vector &input, idx_t count) { - if (count == 0) { - return false; - } - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - return !ConstantVector::IsNull(input); - } else { - UnifiedVectorFormat data; - input.ToUnifiedFormat(count, data); - - if (data.validity.AllValid()) { - return true; - } - for (idx_t i = 0; i < count; i++) { - auto idx = data.sel->get_index(i); - if (data.validity.RowIsValid(idx)) { - return true; - } - } - return false; - } -} - -bool VectorOperations::HasNull(Vector &input, idx_t count) { - if (count == 0) { - return false; - } - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - return ConstantVector::IsNull(input); - } else { - UnifiedVectorFormat data; - input.ToUnifiedFormat(count, data); - - if (data.validity.AllValid()) { - return false; - } - for (idx_t i = 0; i < count; i++) { - auto idx = data.sel->get_index(i); - if (!data.validity.RowIsValid(idx)) { - return true; - } - } - return false; - } -} - -idx_t VectorOperations::CountNotNull(Vector &input, const idx_t count) { - idx_t valid = 0; - - UnifiedVectorFormat vdata; - input.ToUnifiedFormat(count, vdata); - if (vdata.validity.AllValid()) { - return count; - } - switch (input.GetVectorType()) { - case VectorType::FLAT_VECTOR: - valid += vdata.validity.CountValid(count); - break; - case VectorType::CONSTANT_VECTOR: - valid += vdata.validity.CountValid(1) * count; - break; - default: - for (idx_t i = 0; i < count; ++i) { - const auto row_idx = vdata.sel->get_index(i); - valid += int(vdata.validity.RowIsValid(row_idx)); - } - break; - } - - return valid; -} - -} // namespace duckdb 
-//===--------------------------------------------------------------------===// -// numeric_inplace_operators.cpp -// Description: This file contains the implementation of numeric inplace ops -// += *= /= -= %= -//===--------------------------------------------------------------------===// - - - -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// In-Place Addition -//===--------------------------------------------------------------------===// - -void VectorOperations::AddInPlace(Vector &input, int64_t right, idx_t count) { - D_ASSERT(input.GetType().id() == LogicalTypeId::POINTER); - if (right == 0) { - return; - } - switch (input.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - D_ASSERT(!ConstantVector::IsNull(input)); - auto data = ConstantVector::GetData(input); - *data += right; - break; - } - default: { - D_ASSERT(input.GetVectorType() == VectorType::FLAT_VECTOR); - auto data = FlatVector::GetData(input); - for (idx_t i = 0; i < count; i++) { - data[i] += right; - } - break; - } - } -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -bool VectorOperations::TryCast(CastFunctionSet &set, GetCastFunctionInput &input, Vector &source, Vector &result, - idx_t count, string *error_message, bool strict) { - auto cast_function = set.GetCastFunction(source.GetType(), result.GetType(), input); - unique_ptr local_state; - if (cast_function.init_local_state) { - CastLocalStateParameters lparameters(input.context, cast_function.cast_data); - local_state = cast_function.init_local_state(lparameters); - } - CastParameters parameters(cast_function.cast_data.get(), strict, error_message, local_state.get()); - return cast_function.function(source, result, count, parameters); -} - -bool VectorOperations::DefaultTryCast(Vector &source, Vector &result, idx_t count, string *error_message, bool strict) { - CastFunctionSet set; - GetCastFunctionInput input; - return VectorOperations::TryCast(set, 
input, source, result, count, error_message, strict); -} - -void VectorOperations::DefaultCast(Vector &source, Vector &result, idx_t count, bool strict) { - VectorOperations::DefaultTryCast(source, result, count, nullptr, strict); -} - -bool VectorOperations::TryCast(ClientContext &context, Vector &source, Vector &result, idx_t count, - string *error_message, bool strict) { - auto &config = DBConfig::GetConfig(context); - auto &set = config.GetCastFunctions(); - GetCastFunctionInput get_input(context); - return VectorOperations::TryCast(set, get_input, source, result, count, error_message, strict); -} - -void VectorOperations::Cast(ClientContext &context, Vector &source, Vector &result, idx_t count, bool strict) { - VectorOperations::TryCast(context, source, result, count, nullptr, strict); -} - -} // namespace duckdb -//===--------------------------------------------------------------------===// -// copy.cpp -// Description: This file contains the implementation of the different copy -// functions -//===--------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -template -static void TemplatedCopy(const Vector &source, const SelectionVector &sel, Vector &target, idx_t source_offset, - idx_t target_offset, idx_t copy_count) { - auto ldata = FlatVector::GetData(source); - auto tdata = FlatVector::GetData(target); - for (idx_t i = 0; i < copy_count; i++) { - auto source_idx = sel.get_index(source_offset + i); - tdata[target_offset + i] = ldata[source_idx]; - } -} - -static const ValidityMask &CopyValidityMask(const Vector &v) { - switch (v.GetVectorType()) { - case VectorType::FLAT_VECTOR: - return FlatVector::Validity(v); - case VectorType::FSST_VECTOR: - return FSSTVector::Validity(v); - default: - throw InternalException("Unsupported vector type in vector copy"); - } -} - -void VectorOperations::Copy(const Vector &source_p, Vector &target, const SelectionVector &sel_p, idx_t source_count, - idx_t source_offset, 
idx_t target_offset) { - D_ASSERT(source_offset <= source_count); - D_ASSERT(source_p.GetType() == target.GetType()); - idx_t copy_count = source_count - source_offset; - - SelectionVector owned_sel; - const SelectionVector *sel = &sel_p; - - const Vector *source = &source_p; - bool finished = false; - while (!finished) { - switch (source->GetVectorType()) { - case VectorType::DICTIONARY_VECTOR: { - // dictionary vector: merge selection vectors - auto &child = DictionaryVector::Child(*source); - auto &dict_sel = DictionaryVector::SelVector(*source); - // merge the selection vectors and verify the child - auto new_buffer = dict_sel.Slice(*sel, source_count); - owned_sel.Initialize(new_buffer); - sel = &owned_sel; - source = &child; - break; - } - case VectorType::SEQUENCE_VECTOR: { - int64_t start, increment; - Vector seq(source->GetType()); - SequenceVector::GetSequence(*source, start, increment); - VectorOperations::GenerateSequence(seq, source_count, *sel, start, increment); - VectorOperations::Copy(seq, target, *sel, source_count, source_offset, target_offset); - return; - } - case VectorType::CONSTANT_VECTOR: - sel = ConstantVector::ZeroSelectionVector(copy_count, owned_sel); - finished = true; - break; - case VectorType::FSST_VECTOR: - finished = true; - break; - case VectorType::FLAT_VECTOR: - finished = true; - break; - default: - throw NotImplementedException("FIXME unimplemented vector type for VectorOperations::Copy"); - } - } - - if (copy_count == 0) { - return; - } - - // Allow copying of a single value to constant vectors - const auto target_vector_type = target.GetVectorType(); - if (copy_count == 1 && target_vector_type == VectorType::CONSTANT_VECTOR) { - target_offset = 0; - target.SetVectorType(VectorType::FLAT_VECTOR); - } - D_ASSERT(target.GetVectorType() == VectorType::FLAT_VECTOR); - - // first copy the nullmask - auto &tmask = FlatVector::Validity(target); - if (source->GetVectorType() == VectorType::CONSTANT_VECTOR) { - const bool valid = 
!ConstantVector::IsNull(*source); - for (idx_t i = 0; i < copy_count; i++) { - tmask.Set(target_offset + i, valid); - } - } else { - auto &smask = CopyValidityMask(*source); - if (smask.IsMaskSet()) { - for (idx_t i = 0; i < copy_count; i++) { - auto idx = sel->get_index(source_offset + i); - - if (smask.RowIsValid(idx)) { - // set valid - if (!tmask.AllValid()) { - tmask.SetValidUnsafe(target_offset + i); - } - } else { - // set invalid - if (tmask.AllValid()) { - auto init_size = MaxValue(STANDARD_VECTOR_SIZE, target_offset + copy_count); - tmask.Initialize(init_size); - } - tmask.SetInvalidUnsafe(target_offset + i); - } - } - } - } - - D_ASSERT(sel); - - // For FSST Vectors we decompress instead of copying. - if (source->GetVectorType() == VectorType::FSST_VECTOR) { - FSSTVector::DecompressVector(*source, target, source_offset, target_offset, copy_count, sel); - return; - } - - // now copy over the data - switch (source->GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); - break; - case PhysicalType::INT16: - TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); - break; - case PhysicalType::INT32: - TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); - break; - case PhysicalType::INT64: - TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); - break; - case PhysicalType::UINT8: - TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); - break; - case PhysicalType::UINT16: - TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); - break; - case PhysicalType::UINT32: - TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); - break; - case PhysicalType::UINT64: - TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); - break; - case PhysicalType::INT128: - 
TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); - break; - case PhysicalType::FLOAT: - TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); - break; - case PhysicalType::DOUBLE: - TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); - break; - case PhysicalType::INTERVAL: - TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); - break; - case PhysicalType::VARCHAR: { - auto ldata = FlatVector::GetData(*source); - auto tdata = FlatVector::GetData(target); - for (idx_t i = 0; i < copy_count; i++) { - auto source_idx = sel->get_index(source_offset + i); - auto target_idx = target_offset + i; - if (tmask.RowIsValid(target_idx)) { - tdata[target_idx] = StringVector::AddStringOrBlob(target, ldata[source_idx]); - } - } - break; - } - case PhysicalType::STRUCT: { - auto &source_children = StructVector::GetEntries(*source); - auto &target_children = StructVector::GetEntries(target); - D_ASSERT(source_children.size() == target_children.size()); - for (idx_t i = 0; i < source_children.size(); i++) { - VectorOperations::Copy(*source_children[i], *target_children[i], sel_p, source_count, source_offset, - target_offset); - } - break; - } - case PhysicalType::LIST: { - D_ASSERT(target.GetType().InternalType() == PhysicalType::LIST); - - auto &source_child = ListVector::GetEntry(*source); - auto sdata = FlatVector::GetData(*source); - auto tdata = FlatVector::GetData(target); - - if (target_vector_type == VectorType::CONSTANT_VECTOR) { - // If we are only writing one value, then the copied values (if any) are contiguous - // and we can just Append from the offset position - if (!tmask.RowIsValid(target_offset)) { - break; - } - auto source_idx = sel->get_index(source_offset); - auto &source_entry = sdata[source_idx]; - const idx_t source_child_size = source_entry.length + source_entry.offset; - - //! overwrite constant target vectors. 
- ListVector::SetListSize(target, 0); - ListVector::Append(target, source_child, source_child_size, source_entry.offset); - - auto &target_entry = tdata[target_offset]; - target_entry.length = source_entry.length; - target_entry.offset = 0; - } else { - //! if the source has list offsets, we need to append them to the target - //! build a selection vector for the copied child elements - vector child_rows; - for (idx_t i = 0; i < copy_count; ++i) { - if (tmask.RowIsValid(target_offset + i)) { - auto source_idx = sel->get_index(source_offset + i); - auto &source_entry = sdata[source_idx]; - for (idx_t j = 0; j < source_entry.length; ++j) { - child_rows.emplace_back(source_entry.offset + j); - } - } - } - idx_t source_child_size = child_rows.size(); - SelectionVector child_sel(child_rows.data()); - - idx_t old_target_child_len = ListVector::GetListSize(target); - - //! append to list itself - ListVector::Append(target, source_child, child_sel, source_child_size); - - //! now write the list offsets - for (idx_t i = 0; i < copy_count; i++) { - auto source_idx = sel->get_index(source_offset + i); - auto &source_entry = sdata[source_idx]; - auto &target_entry = tdata[target_offset + i]; - - target_entry.length = source_entry.length; - target_entry.offset = old_target_child_len; - if (tmask.RowIsValid(target_offset + i)) { - old_target_child_len += target_entry.length; - } - } - } - break; - } - default: - throw NotImplementedException("Unimplemented type '%s' for copy!", - TypeIdToString(source->GetType().InternalType())); - } - - if (target_vector_type != VectorType::FLAT_VECTOR) { - target.SetVectorType(target_vector_type); - } -} - -void VectorOperations::Copy(const Vector &source, Vector &target, idx_t source_count, idx_t source_offset, - idx_t target_offset) { - VectorOperations::Copy(source, target, *FlatVector::IncrementalSelectionVector(), source_count, source_offset, - target_offset); -} - -} // namespace duckdb 
-//===--------------------------------------------------------------------===// -// hash.cpp -// Description: This file contains the vectorized hash implementations -//===--------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -struct HashOp { - static const hash_t NULL_HASH = 0xbf58476d1ce4e5b9; - - template - static inline hash_t Operation(T input, bool is_null) { - return is_null ? NULL_HASH : duckdb::Hash(input); - } -}; - -static inline hash_t CombineHashScalar(hash_t a, hash_t b) { - return (a * UINT64_C(0xbf58476d1ce4e5b9)) ^ b; -} - -template -static inline void TightLoopHash(const T *__restrict ldata, hash_t *__restrict result_data, const SelectionVector *rsel, - idx_t count, const SelectionVector *__restrict sel_vector, ValidityMask &mask) { - if (!mask.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto ridx = HAS_RSEL ? rsel->get_index(i) : i; - auto idx = sel_vector->get_index(ridx); - result_data[ridx] = HashOp::Operation(ldata[idx], !mask.RowIsValid(idx)); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto ridx = HAS_RSEL ? 
rsel->get_index(i) : i; - auto idx = sel_vector->get_index(ridx); - result_data[ridx] = duckdb::Hash(ldata[idx]); - } - } -} - -template -static inline void TemplatedLoopHash(Vector &input, Vector &result, const SelectionVector *rsel, idx_t count) { - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - - auto ldata = ConstantVector::GetData(input); - auto result_data = ConstantVector::GetData(result); - *result_data = HashOp::Operation(*ldata, ConstantVector::IsNull(input)); - } else { - result.SetVectorType(VectorType::FLAT_VECTOR); - - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - - TightLoopHash(UnifiedVectorFormat::GetData(idata), FlatVector::GetData(result), rsel, - count, idata.sel, idata.validity); - } -} - -template -static inline void StructLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { - auto &children = StructVector::GetEntries(input); - - D_ASSERT(!children.empty()); - idx_t col_no = 0; - if (HAS_RSEL) { - if (FIRST_HASH) { - VectorOperations::Hash(*children[col_no++], hashes, *rsel, count); - } else { - VectorOperations::CombineHash(hashes, *children[col_no++], *rsel, count); - } - while (col_no < children.size()) { - VectorOperations::CombineHash(hashes, *children[col_no++], *rsel, count); - } - } else { - if (FIRST_HASH) { - VectorOperations::Hash(*children[col_no++], hashes, count); - } else { - VectorOperations::CombineHash(hashes, *children[col_no++], count); - } - while (col_no < children.size()) { - VectorOperations::CombineHash(hashes, *children[col_no++], count); - } - } -} - -template -static inline void ListLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { - auto hdata = FlatVector::GetData(hashes); - - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - const auto ldata = UnifiedVectorFormat::GetData(idata); - - // Hash the children into a temporary - auto &child = 
ListVector::GetEntry(input); - const auto child_count = ListVector::GetListSize(input); - - Vector child_hashes(LogicalType::HASH, child_count); - if (child_count > 0) { - VectorOperations::Hash(child, child_hashes, child_count); - child_hashes.Flatten(child_count); - } - auto chdata = FlatVector::GetData(child_hashes); - - // Reduce the number of entries to check to the non-empty ones - SelectionVector unprocessed(count); - SelectionVector cursor(HAS_RSEL ? STANDARD_VECTOR_SIZE : count); - idx_t remaining = 0; - for (idx_t i = 0; i < count; ++i) { - const idx_t ridx = HAS_RSEL ? rsel->get_index(i) : i; - const auto lidx = idata.sel->get_index(ridx); - const auto &entry = ldata[lidx]; - if (idata.validity.RowIsValid(lidx) && entry.length > 0) { - unprocessed.set_index(remaining++, ridx); - cursor.set_index(ridx, entry.offset); - } else if (FIRST_HASH) { - hdata[ridx] = HashOp::NULL_HASH; - } - // Empty or NULL non-first elements have no effect. - } - - count = remaining; - if (count == 0) { - return; - } - - // Merge the first position hash into the main hash - idx_t position = 1; - if (FIRST_HASH) { - remaining = 0; - for (idx_t i = 0; i < count; ++i) { - const auto ridx = unprocessed.get_index(i); - const auto cidx = cursor.get_index(ridx); - hdata[ridx] = chdata[cidx]; - - const auto lidx = idata.sel->get_index(ridx); - const auto &entry = ldata[lidx]; - if (entry.length > position) { - // Entry still has values to hash - unprocessed.set_index(remaining++, ridx); - cursor.set_index(ridx, cidx + 1); - } - } - count = remaining; - if (count == 0) { - return; - } - ++position; - } - - // Combine the hashes for the remaining positions until there are none left - for (;; ++position) { - remaining = 0; - for (idx_t i = 0; i < count; ++i) { - const auto ridx = unprocessed.get_index(i); - const auto cidx = cursor.get_index(ridx); - hdata[ridx] = CombineHashScalar(hdata[ridx], chdata[cidx]); - - const auto lidx = idata.sel->get_index(ridx); - const auto &entry = 
ldata[lidx]; - if (entry.length > position) { - // Entry still has values to hash - unprocessed.set_index(remaining++, ridx); - cursor.set_index(ridx, cidx + 1); - } - } - - count = remaining; - if (count == 0) { - break; - } - } -} - -template -static inline void HashTypeSwitch(Vector &input, Vector &result, const SelectionVector *rsel, idx_t count) { - D_ASSERT(result.GetType().id() == LogicalType::HASH); - switch (input.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::INT16: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::INT32: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::INT64: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::UINT8: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::UINT16: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::UINT32: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::UINT64: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::INT128: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::FLOAT: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::DOUBLE: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::INTERVAL: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::VARCHAR: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::STRUCT: - StructLoopHash(input, result, rsel, count); - break; - case PhysicalType::LIST: - ListLoopHash(input, result, rsel, count); - break; - default: - throw InvalidTypeException(input.GetType(), "Invalid type for hash"); - } -} - -void VectorOperations::Hash(Vector &input, Vector &result, idx_t count) { - HashTypeSwitch(input, result, nullptr, count); -} - -void 
VectorOperations::Hash(Vector &input, Vector &result, const SelectionVector &sel, idx_t count) { - HashTypeSwitch(input, result, &sel, count); -} - -template -static inline void TightLoopCombineHashConstant(const T *__restrict ldata, hash_t constant_hash, - hash_t *__restrict hash_data, const SelectionVector *rsel, idx_t count, - const SelectionVector *__restrict sel_vector, ValidityMask &mask) { - if (!mask.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto ridx = HAS_RSEL ? rsel->get_index(i) : i; - auto idx = sel_vector->get_index(ridx); - auto other_hash = HashOp::Operation(ldata[idx], !mask.RowIsValid(idx)); - hash_data[ridx] = CombineHashScalar(constant_hash, other_hash); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto ridx = HAS_RSEL ? rsel->get_index(i) : i; - auto idx = sel_vector->get_index(ridx); - auto other_hash = duckdb::Hash(ldata[idx]); - hash_data[ridx] = CombineHashScalar(constant_hash, other_hash); - } - } -} - -template -static inline void TightLoopCombineHash(const T *__restrict ldata, hash_t *__restrict hash_data, - const SelectionVector *rsel, idx_t count, - const SelectionVector *__restrict sel_vector, ValidityMask &mask) { - if (!mask.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto ridx = HAS_RSEL ? rsel->get_index(i) : i; - auto idx = sel_vector->get_index(ridx); - auto other_hash = HashOp::Operation(ldata[idx], !mask.RowIsValid(idx)); - hash_data[ridx] = CombineHashScalar(hash_data[ridx], other_hash); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto ridx = HAS_RSEL ? 
rsel->get_index(i) : i; - auto idx = sel_vector->get_index(ridx); - auto other_hash = duckdb::Hash(ldata[idx]); - hash_data[ridx] = CombineHashScalar(hash_data[ridx], other_hash); - } - } -} - -template -void TemplatedLoopCombineHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR && hashes.GetVectorType() == VectorType::CONSTANT_VECTOR) { - auto ldata = ConstantVector::GetData(input); - auto hash_data = ConstantVector::GetData(hashes); - - auto other_hash = HashOp::Operation(*ldata, ConstantVector::IsNull(input)); - *hash_data = CombineHashScalar(*hash_data, other_hash); - } else { - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - if (hashes.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // mix constant with non-constant, first get the constant value - auto constant_hash = *ConstantVector::GetData(hashes); - // now re-initialize the hashes vector to an empty flat vector - hashes.SetVectorType(VectorType::FLAT_VECTOR); - TightLoopCombineHashConstant(UnifiedVectorFormat::GetData(idata), constant_hash, - FlatVector::GetData(hashes), rsel, count, idata.sel, - idata.validity); - } else { - D_ASSERT(hashes.GetVectorType() == VectorType::FLAT_VECTOR); - TightLoopCombineHash(UnifiedVectorFormat::GetData(idata), - FlatVector::GetData(hashes), rsel, count, idata.sel, - idata.validity); - } - } -} - -template -static inline void CombineHashTypeSwitch(Vector &hashes, Vector &input, const SelectionVector *rsel, idx_t count) { - D_ASSERT(hashes.GetType().id() == LogicalType::HASH); - switch (input.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::INT16: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::INT32: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::INT64: - 
TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::UINT8: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::UINT16: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::UINT32: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::UINT64: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::INT128: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::FLOAT: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::DOUBLE: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::INTERVAL: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::VARCHAR: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::STRUCT: - StructLoopHash(input, hashes, rsel, count); - break; - case PhysicalType::LIST: - ListLoopHash(input, hashes, rsel, count); - break; - default: - throw InvalidTypeException(input.GetType(), "Invalid type for hash"); - } -} - -void VectorOperations::CombineHash(Vector &hashes, Vector &input, idx_t count) { - CombineHashTypeSwitch(hashes, input, nullptr, count); -} - -void VectorOperations::CombineHash(Vector &hashes, Vector &input, const SelectionVector &rsel, idx_t count) { - CombineHashTypeSwitch(hashes, input, &rsel, count); -} - -} // namespace duckdb - - - - -namespace duckdb { - -template -static void CopyToStorageLoop(UnifiedVectorFormat &vdata, idx_t count, data_ptr_t target) { - auto ldata = UnifiedVectorFormat::GetData(vdata); - auto result_data = (T *)target; - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - if (!vdata.validity.RowIsValid(idx)) { - result_data[i] = NullValue(); - } else { - result_data[i] = ldata[idx]; - } - } -} - -void VectorOperations::WriteToStorage(Vector &source, idx_t count, 
data_ptr_t target) { - if (count == 0) { - return; - } - UnifiedVectorFormat vdata; - source.ToUnifiedFormat(count, vdata); - - switch (source.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - CopyToStorageLoop(vdata, count, target); - break; - case PhysicalType::INT16: - CopyToStorageLoop(vdata, count, target); - break; - case PhysicalType::INT32: - CopyToStorageLoop(vdata, count, target); - break; - case PhysicalType::INT64: - CopyToStorageLoop(vdata, count, target); - break; - case PhysicalType::UINT8: - CopyToStorageLoop(vdata, count, target); - break; - case PhysicalType::UINT16: - CopyToStorageLoop(vdata, count, target); - break; - case PhysicalType::UINT32: - CopyToStorageLoop(vdata, count, target); - break; - case PhysicalType::UINT64: - CopyToStorageLoop(vdata, count, target); - break; - case PhysicalType::INT128: - CopyToStorageLoop(vdata, count, target); - break; - case PhysicalType::FLOAT: - CopyToStorageLoop(vdata, count, target); - break; - case PhysicalType::DOUBLE: - CopyToStorageLoop(vdata, count, target); - break; - case PhysicalType::INTERVAL: - CopyToStorageLoop(vdata, count, target); - break; - default: - throw NotImplementedException("Unimplemented type for WriteToStorage"); - } -} - -template -static void ReadFromStorageLoop(data_ptr_t source, idx_t count, Vector &result) { - auto ldata = (T *)source; - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - result_data[i] = ldata[i]; - } -} - -void VectorOperations::ReadFromStorage(data_ptr_t source, idx_t count, Vector &result) { - result.SetVectorType(VectorType::FLAT_VECTOR); - switch (result.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - ReadFromStorageLoop(source, count, result); - break; - case PhysicalType::INT16: - ReadFromStorageLoop(source, count, result); - break; - case PhysicalType::INT32: - ReadFromStorageLoop(source, count, result); - break; - case PhysicalType::INT64: - 
ReadFromStorageLoop(source, count, result); - break; - case PhysicalType::UINT8: - ReadFromStorageLoop(source, count, result); - break; - case PhysicalType::UINT16: - ReadFromStorageLoop(source, count, result); - break; - case PhysicalType::UINT32: - ReadFromStorageLoop(source, count, result); - break; - case PhysicalType::UINT64: - ReadFromStorageLoop(source, count, result); - break; - case PhysicalType::INT128: - ReadFromStorageLoop(source, count, result); - break; - case PhysicalType::FLOAT: - ReadFromStorageLoop(source, count, result); - break; - case PhysicalType::DOUBLE: - ReadFromStorageLoop(source, count, result); - break; - case PhysicalType::INTERVAL: - ReadFromStorageLoop(source, count, result); - break; - default: - throw NotImplementedException("Unimplemented type for ReadFromStorage"); - } -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -VirtualFileSystem::VirtualFileSystem() : default_fs(FileSystem::CreateLocal()) { - VirtualFileSystem::RegisterSubSystem(FileCompressionType::GZIP, make_uniq()); -} - -unique_ptr VirtualFileSystem::OpenFile(const string &path, uint8_t flags, FileLockType lock, - FileCompressionType compression, FileOpener *opener) { - if (compression == FileCompressionType::AUTO_DETECT) { - // auto detect compression settings based on file name - auto lower_path = StringUtil::Lower(path); - if (StringUtil::EndsWith(lower_path, ".tmp")) { - // strip .tmp - lower_path = lower_path.substr(0, lower_path.length() - 4); - } - if (StringUtil::EndsWith(lower_path, ".gz")) { - compression = FileCompressionType::GZIP; - } else if (StringUtil::EndsWith(lower_path, ".zst")) { - compression = FileCompressionType::ZSTD; - } else { - compression = FileCompressionType::UNCOMPRESSED; - } - } - // open the base file handle - auto file_handle = FindFileSystem(path).OpenFile(path, flags, lock, FileCompressionType::UNCOMPRESSED, opener); - if (file_handle->GetType() == FileType::FILE_TYPE_FIFO) { - file_handle = 
PipeFileSystem::OpenPipe(std::move(file_handle)); - } else if (compression != FileCompressionType::UNCOMPRESSED) { - auto entry = compressed_fs.find(compression); - if (entry == compressed_fs.end()) { - throw NotImplementedException( - "Attempting to open a compressed file, but the compression type is not supported"); - } - file_handle = entry->second->OpenCompressedFile(std::move(file_handle), flags & FileFlags::FILE_FLAGS_WRITE); - } - return file_handle; -} - -void VirtualFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - handle.file_system.Read(handle, buffer, nr_bytes, location); -} - -void VirtualFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - handle.file_system.Write(handle, buffer, nr_bytes, location); -} - -int64_t VirtualFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - return handle.file_system.Read(handle, buffer, nr_bytes); -} - -int64_t VirtualFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { - return handle.file_system.Write(handle, buffer, nr_bytes); -} - -int64_t VirtualFileSystem::GetFileSize(FileHandle &handle) { - return handle.file_system.GetFileSize(handle); -} -time_t VirtualFileSystem::GetLastModifiedTime(FileHandle &handle) { - return handle.file_system.GetLastModifiedTime(handle); -} -FileType VirtualFileSystem::GetFileType(FileHandle &handle) { - return handle.file_system.GetFileType(handle); -} - -void VirtualFileSystem::Truncate(FileHandle &handle, int64_t new_size) { - handle.file_system.Truncate(handle, new_size); -} - -void VirtualFileSystem::FileSync(FileHandle &handle) { - handle.file_system.FileSync(handle); -} - -// need to look up correct fs for this -bool VirtualFileSystem::DirectoryExists(const string &directory) { - return FindFileSystem(directory).DirectoryExists(directory); -} -void VirtualFileSystem::CreateDirectory(const string &directory) { - FindFileSystem(directory).CreateDirectory(directory); -} 
- -void VirtualFileSystem::RemoveDirectory(const string &directory) { - FindFileSystem(directory).RemoveDirectory(directory); -} - -bool VirtualFileSystem::ListFiles(const string &directory, const std::function &callback, - FileOpener *opener) { - return FindFileSystem(directory).ListFiles(directory, callback, opener); -} - -void VirtualFileSystem::MoveFile(const string &source, const string &target) { - FindFileSystem(source).MoveFile(source, target); -} - -bool VirtualFileSystem::FileExists(const string &filename) { - return FindFileSystem(filename).FileExists(filename); -} - -bool VirtualFileSystem::IsPipe(const string &filename) { - return FindFileSystem(filename).IsPipe(filename); -} -void VirtualFileSystem::RemoveFile(const string &filename) { - FindFileSystem(filename).RemoveFile(filename); -} - -string VirtualFileSystem::PathSeparator(const string &path) { - return FindFileSystem(path).PathSeparator(path); -} - -vector VirtualFileSystem::Glob(const string &path, FileOpener *opener) { - return FindFileSystem(path).Glob(path, opener); -} - -void VirtualFileSystem::RegisterSubSystem(unique_ptr fs) { - sub_systems.push_back(std::move(fs)); -} - -void VirtualFileSystem::UnregisterSubSystem(const string &name) { - for (auto sub_system = sub_systems.begin(); sub_system != sub_systems.end(); sub_system++) { - if (sub_system->get()->GetName() == name) { - sub_systems.erase(sub_system); - return; - } - } - throw InvalidInputException("Could not find filesystem with name %s", name); -} - -void VirtualFileSystem::RegisterSubSystem(FileCompressionType compression_type, unique_ptr fs) { - compressed_fs[compression_type] = std::move(fs); -} - -vector VirtualFileSystem::ListSubSystems() { - vector names(sub_systems.size()); - for (idx_t i = 0; i < sub_systems.size(); i++) { - names[i] = sub_systems[i]->GetName(); - } - return names; -} - -std::string VirtualFileSystem::GetName() const { - return "VirtualFileSystem"; -} - -void 
VirtualFileSystem::SetDisabledFileSystems(const vector &names) { - unordered_set new_disabled_file_systems; - for (auto &name : names) { - if (name.empty()) { - continue; - } - if (new_disabled_file_systems.find(name) != new_disabled_file_systems.end()) { - throw InvalidInputException("Duplicate disabled file system \"%s\"", name); - } - new_disabled_file_systems.insert(name); - } - for (auto &disabled_fs : disabled_file_systems) { - if (new_disabled_file_systems.find(disabled_fs) == new_disabled_file_systems.end()) { - throw InvalidInputException("File system \"%s\" has been disabled previously, it cannot be re-enabled", - disabled_fs); - } - } - disabled_file_systems = std::move(new_disabled_file_systems); -} - -FileSystem &VirtualFileSystem::FindFileSystem(const string &path) { - auto &fs = FindFileSystemInternal(path); - if (!disabled_file_systems.empty() && disabled_file_systems.find(fs.GetName()) != disabled_file_systems.end()) { - throw PermissionException("File system %s has been disabled by configuration", fs.GetName()); - } - return fs; -} - -FileSystem &VirtualFileSystem::FindFileSystemInternal(const string &path) { - for (auto &sub_system : sub_systems) { - if (sub_system->CanHandleFile(path)) { - return *sub_system; - } - } - return *default_fs; -} - -} // namespace duckdb - - -namespace duckdb { - -#ifdef DUCKDB_WINDOWS - -std::wstring WindowsUtil::UTF8ToUnicode(const char *input) { - idx_t result_size; - - result_size = MultiByteToWideChar(CP_UTF8, 0, input, -1, nullptr, 0); - if (result_size == 0) { - throw IOException("Failure in MultiByteToWideChar"); - } - auto buffer = make_unsafe_uniq_array(result_size); - result_size = MultiByteToWideChar(CP_UTF8, 0, input, -1, buffer.get(), result_size); - if (result_size == 0) { - throw IOException("Failure in MultiByteToWideChar"); - } - return std::wstring(buffer.get(), result_size); -} - -static string WideCharToMultiByteWrapper(LPCWSTR input, uint32_t code_page) { - idx_t result_size; - - result_size = 
WideCharToMultiByte(code_page, 0, input, -1, 0, 0, 0, 0); - if (result_size == 0) { - throw IOException("Failure in WideCharToMultiByte"); - } - auto buffer = make_unsafe_uniq_array(result_size); - result_size = WideCharToMultiByte(code_page, 0, input, -1, buffer.get(), result_size, 0, 0); - if (result_size == 0) { - throw IOException("Failure in WideCharToMultiByte"); - } - return string(buffer.get(), result_size - 1); -} - -string WindowsUtil::UnicodeToUTF8(LPCWSTR input) { - return WideCharToMultiByteWrapper(input, CP_UTF8); -} - -static string WindowsUnicodeToMBCS(LPCWSTR unicode_text, int use_ansi) { - uint32_t code_page = use_ansi ? CP_ACP : CP_OEMCP; - return WideCharToMultiByteWrapper(unicode_text, code_page); -} - -string WindowsUtil::UTF8ToMBCS(const char *input, bool use_ansi) { - auto unicode = WindowsUtil::UTF8ToUnicode(input); - return WindowsUnicodeToMBCS(unicode.c_str(), use_ansi); -} - -#endif - -} // namespace duckdb - - - - - - - -namespace duckdb { - -template -struct AvgState { - uint64_t count; - T value; - - void Initialize() { - this->count = 0; - } - - void Combine(const AvgState &other) { - this->count += other.count; - this->value += other.value; - } -}; - -struct KahanAvgState { - uint64_t count; - double value; - double err; - - void Initialize() { - this->count = 0; - this->err = 0.0; - } - - void Combine(const KahanAvgState &other) { - this->count += other.count; - KahanAddInternal(other.value, this->value, this->err); - KahanAddInternal(other.err, this->value, this->err); - } -}; - -struct AverageDecimalBindData : public FunctionData { - explicit AverageDecimalBindData(double scale) : scale(scale) { - } - - double scale; - -public: - unique_ptr Copy() const override { - return make_uniq(scale); - }; - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return scale == other.scale; - } -}; - -struct AverageSetOperation { - template - static void Initialize(STATE &state) { - state.Initialize(); 
- } - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.Combine(source); - } - template - static void AddValues(STATE &state, idx_t count) { - state.count += count; - } -}; - -template -static T GetAverageDivident(uint64_t count, optional_ptr bind_data) { - T divident = T(count); - if (bind_data) { - auto &avg_bind_data = bind_data->Cast(); - divident *= avg_bind_data.scale; - } - return divident; -} - -struct IntegerAverageOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.count == 0) { - finalize_data.ReturnNull(); - } else { - double divident = GetAverageDivident(state.count, finalize_data.input.bind_data); - target = double(state.value) / divident; - } - } -}; - -struct IntegerAverageOperationHugeint : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.count == 0) { - finalize_data.ReturnNull(); - } else { - long double divident = GetAverageDivident(state.count, finalize_data.input.bind_data); - target = Hugeint::Cast(state.value) / divident; - } - } -}; - -struct HugeintAverageOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.count == 0) { - finalize_data.ReturnNull(); - } else { - long double divident = GetAverageDivident(state.count, finalize_data.input.bind_data); - target = Hugeint::Cast(state.value) / divident; - } - } -}; - -struct NumericAverageOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.count == 0) { - finalize_data.ReturnNull(); - } else { - target = state.value / state.count; - } - } -}; - -struct KahanAverageOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, 
AggregateFinalizeData &finalize_data) { - if (state.count == 0) { - finalize_data.ReturnNull(); - } else { - target = (state.value / state.count) + (state.err / state.count); - } - } -}; - -AggregateFunction GetAverageAggregate(PhysicalType type) { - switch (type) { - case PhysicalType::INT16: { - return AggregateFunction::UnaryAggregate, int16_t, double, IntegerAverageOperation>( - LogicalType::SMALLINT, LogicalType::DOUBLE); - } - case PhysicalType::INT32: { - return AggregateFunction::UnaryAggregate, int32_t, double, IntegerAverageOperationHugeint>( - LogicalType::INTEGER, LogicalType::DOUBLE); - } - case PhysicalType::INT64: { - return AggregateFunction::UnaryAggregate, int64_t, double, IntegerAverageOperationHugeint>( - LogicalType::BIGINT, LogicalType::DOUBLE); - } - case PhysicalType::INT128: { - return AggregateFunction::UnaryAggregate, hugeint_t, double, HugeintAverageOperation>( - LogicalType::HUGEINT, LogicalType::DOUBLE); - } - default: - throw InternalException("Unimplemented average aggregate"); - } -} - -unique_ptr BindDecimalAvg(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto decimal_type = arguments[0]->return_type; - function = GetAverageAggregate(decimal_type.InternalType()); - function.name = "avg"; - function.arguments[0] = decimal_type; - function.return_type = LogicalType::DOUBLE; - return make_uniq( - Hugeint::Cast(Hugeint::POWERS_OF_TEN[DecimalType::GetScale(decimal_type)])); -} - -AggregateFunctionSet AvgFun::GetFunctions() { - AggregateFunctionSet avg; - - avg.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, - BindDecimalAvg)); - avg.AddFunction(GetAverageAggregate(PhysicalType::INT16)); - avg.AddFunction(GetAverageAggregate(PhysicalType::INT32)); - avg.AddFunction(GetAverageAggregate(PhysicalType::INT64)); - avg.AddFunction(GetAverageAggregate(PhysicalType::INT128)); - 
avg.AddFunction(AggregateFunction::UnaryAggregate, double, double, NumericAverageOperation>( - LogicalType::DOUBLE, LogicalType::DOUBLE)); - return avg; -} - -AggregateFunction FAvgFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -AggregateFunction CorrFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} -} // namespace duckdb - - - - -namespace duckdb { - -AggregateFunction CovarPopFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -AggregateFunction CovarSampFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb - - - - -#include - -namespace duckdb { - -AggregateFunction StdDevSampFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -AggregateFunction StdDevPopFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -AggregateFunction VarPopFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -AggregateFunction VarSampFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -AggregateFunction StandardErrorOfTheMeanFun::GetFunction() { - return AggregateFunction::UnaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -struct ApproxDistinctCountState { - ApproxDistinctCountState() : log(nullptr) { - } - ~ApproxDistinctCountState() { - if (log) { - delete log; - } - } - - HyperLogLog *log; -}; - -struct ApproxCountDistinctFunction { - template - static void Initialize(STATE 
&state) { - state.log = nullptr; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.log) { - return; - } - if (!target.log) { - target.log = new HyperLogLog(); - } - D_ASSERT(target.log); - D_ASSERT(source.log); - auto new_log = target.log->MergePointer(*source.log); - delete target.log; - target.log = new_log; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.log) { - target = state.log->Count(); - } else { - target = 0; - } - } - - static bool IgnoreNull() { - return true; - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.log) { - delete state.log; - state.log = nullptr; - } - } -}; - -static void ApproxCountDistinctSimpleUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, - data_ptr_t state, idx_t count) { - D_ASSERT(input_count == 1); - - auto agg_state = reinterpret_cast(state); - if (!agg_state->log) { - agg_state->log = new HyperLogLog(); - } - - UnifiedVectorFormat vdata; - inputs[0].ToUnifiedFormat(count, vdata); - - if (count > STANDARD_VECTOR_SIZE) { - throw InternalException("ApproxCountDistinct - count must be at most vector size"); - } - uint64_t indices[STANDARD_VECTOR_SIZE]; - uint8_t counts[STANDARD_VECTOR_SIZE]; - HyperLogLog::ProcessEntries(vdata, inputs[0].GetType(), indices, counts, count); - agg_state->log->AddToLog(vdata, count, indices, counts); -} - -static void ApproxCountDistinctUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, - Vector &state_vector, idx_t count) { - D_ASSERT(input_count == 1); - - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - auto states = UnifiedVectorFormat::GetDataNoConst(sdata); - - for (idx_t i = 0; i < count; i++) { - auto agg_state = states[sdata.sel->get_index(i)]; - if (!agg_state->log) { - agg_state->log = new HyperLogLog(); - } - } - - UnifiedVectorFormat 
vdata; - inputs[0].ToUnifiedFormat(count, vdata); - - if (count > STANDARD_VECTOR_SIZE) { - throw InternalException("ApproxCountDistinct - count must be at most vector size"); - } - uint64_t indices[STANDARD_VECTOR_SIZE]; - uint8_t counts[STANDARD_VECTOR_SIZE]; - HyperLogLog::ProcessEntries(vdata, inputs[0].GetType(), indices, counts, count); - HyperLogLog::AddToLogs(vdata, count, indices, counts, reinterpret_cast(states), sdata.sel); -} - -AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type) { - auto fun = AggregateFunction( - {input_type}, LogicalTypeId::BIGINT, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - ApproxCountDistinctUpdateFunction, - AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, - ApproxCountDistinctSimpleUpdateFunction, nullptr, - AggregateFunction::StateDestroy); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -AggregateFunctionSet ApproxCountDistinctFun::GetFunctions() { - AggregateFunctionSet approx_count("approx_count_distinct"); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UTINYINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::USMALLINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UINTEGER)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UBIGINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TINYINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::SMALLINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::BIGINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::HUGEINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::FLOAT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::DOUBLE)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::VARCHAR)); - 
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TIMESTAMP)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TIMESTAMP_TZ)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::BLOB)); - return approx_count; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -struct ArgMinMaxStateBase { - ArgMinMaxStateBase() : is_initialized(false) { - } - - template - static inline void CreateValue(T &value) { - } - - template - static inline void DestroyValue(T &value) { - } - - template - static inline void AssignValue(T &target, T new_value, bool is_initialized) { - target = new_value; - } - - template - static inline void ReadValue(Vector &result, T &arg, T &target) { - target = arg; - } - - bool is_initialized; -}; - -// Out-of-line specialisations -template <> -void ArgMinMaxStateBase::CreateValue(Vector *&value) { - value = nullptr; -} - -template <> -void ArgMinMaxStateBase::DestroyValue(string_t &value) { - if (!value.IsInlined()) { - delete[] value.GetData(); - } -} - -template <> -void ArgMinMaxStateBase::DestroyValue(Vector *&value) { - delete value; - value = nullptr; -} - -template <> -void ArgMinMaxStateBase::AssignValue(string_t &target, string_t new_value, bool is_initialized) { - if (is_initialized) { - DestroyValue(target); - } - if (new_value.IsInlined()) { - target = new_value; - } else { - // non-inlined string, need to allocate space for it - auto len = new_value.GetSize(); - auto ptr = new char[len]; - memcpy(ptr, new_value.GetData(), len); - - target = string_t(ptr, len); - } -} - -template <> -void ArgMinMaxStateBase::ReadValue(Vector &result, string_t &arg, string_t &target) { - target = StringVector::AddStringOrBlob(result, arg); -} - -template -struct ArgMinMaxState : public ArgMinMaxStateBase { - using ARG_TYPE = A; - using BY_TYPE = B; - - ARG_TYPE arg; - BY_TYPE value; - - ArgMinMaxState() { - CreateValue(arg); - CreateValue(value); - } - - ~ArgMinMaxState() { - if 
(is_initialized) { - DestroyValue(arg); - DestroyValue(value); - is_initialized = false; - } - } -}; - -template -struct ArgMinMaxBase { - - template - static void Initialize(STATE &state) { - new (&state) STATE; - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - state.~STATE(); - } - - template - static void Operation(STATE &state, const A_TYPE &x, const B_TYPE &y, AggregateBinaryInput &) { - if (!state.is_initialized) { - STATE::template AssignValue(state.arg, x, false); - STATE::template AssignValue(state.value, y, false); - state.is_initialized = true; - } else { - OP::template Execute(state, x, y); - } - } - - template - static void Execute(STATE &state, A_TYPE x_data, B_TYPE y_data) { - if (COMPARATOR::Operation(y_data, state.value)) { - STATE::template AssignValue(state.arg, x_data, true); - STATE::template AssignValue(state.value, y_data, true); - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.is_initialized) { - return; - } - if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { - STATE::template AssignValue(target.arg, source.arg, target.is_initialized); - STATE::template AssignValue(target.value, source.value, target.is_initialized); - target.is_initialized = true; - } - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.is_initialized) { - finalize_data.ReturnNull(); - } else { - STATE::template ReadValue(finalize_data.result, state.arg, target); - } - } - - static bool IgnoreNull() { - return true; - } -}; - -template -struct VectorArgMinMaxBase : ArgMinMaxBase { - template - static void AssignVector(STATE &state, Vector &arg, const idx_t idx) { - if (!state.is_initialized) { - state.arg = new Vector(arg.GetType()); - state.arg->SetVectorType(VectorType::CONSTANT_VECTOR); - } - sel_t selv = idx; - SelectionVector sel(&selv); - 
VectorOperations::Copy(arg, *state.arg, sel, 1, 0, 0); - } - - template - static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { - auto &arg = inputs[0]; - UnifiedVectorFormat adata; - arg.ToUnifiedFormat(count, adata); - - using BY_TYPE = typename STATE::BY_TYPE; - auto &by = inputs[1]; - UnifiedVectorFormat bdata; - by.ToUnifiedFormat(count, bdata); - const auto bys = UnifiedVectorFormat::GetData(bdata); - - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - - auto states = (STATE **)sdata.data; - for (idx_t i = 0; i < count; i++) { - const auto bidx = bdata.sel->get_index(i); - if (!bdata.validity.RowIsValid(bidx)) { - continue; - } - const auto bval = bys[bidx]; - - const auto sidx = sdata.sel->get_index(i); - auto &state = *states[sidx]; - if (!state.is_initialized) { - STATE::template AssignValue(state.value, bval, false); - AssignVector(state, arg, i); - state.is_initialized = true; - - } else if (COMPARATOR::template Operation(bval, state.value)) { - STATE::template AssignValue(state.value, bval, true); - AssignVector(state, arg, i); - } - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.is_initialized) { - return; - } - if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { - STATE::template AssignValue(target.value, source.value, target.is_initialized); - AssignVector(target, *source.arg, 0); - target.is_initialized = true; - } - } - - template - static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { - if (!state.is_initialized) { - finalize_data.ReturnNull(); - } else { - VectorOperations::Copy(*state.arg, finalize_data.result, 1, 0, finalize_data.result_idx); - } - } - - static unique_ptr Bind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function.arguments[0] = arguments[0]->return_type; - function.return_type = 
arguments[0]->return_type; - return nullptr; - } -}; - -template -AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) { - using STATE = ArgMinMaxState; - return AggregateFunction( - {type, by_type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, - OP::template Update, AggregateFunction::StateCombine, - AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, AggregateFunction::StateDestroy); -} - -template -AggregateFunction GetVectorArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) { - switch (by_type.InternalType()) { - case PhysicalType::INT32: - return GetVectorArgMinMaxFunctionInternal(by_type, type); - case PhysicalType::INT64: - return GetVectorArgMinMaxFunctionInternal(by_type, type); - case PhysicalType::DOUBLE: - return GetVectorArgMinMaxFunctionInternal(by_type, type); - case PhysicalType::VARCHAR: - return GetVectorArgMinMaxFunctionInternal(by_type, type); - default: - throw InternalException("Unimplemented arg_min/arg_max aggregate"); - } -} - -template -void AddVectorArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) { - fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::INTEGER, type)); - fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::BIGINT, type)); - fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::DOUBLE, type)); - fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::VARCHAR, type)); - fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::DATE, type)); - fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::TIMESTAMP, type)); - fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::TIMESTAMP_TZ, type)); - fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::BLOB, type)); -} - -template -AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) { - using STATE = ArgMinMaxState; - auto function = AggregateFunction::BinaryAggregate(type, 
by_type, type); - if (type.InternalType() == PhysicalType::VARCHAR || by_type.InternalType() == PhysicalType::VARCHAR) { - function.destructor = AggregateFunction::StateDestroy; - } - return function; -} - -template -AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) { - switch (by_type.InternalType()) { - case PhysicalType::INT32: - return GetArgMinMaxFunctionInternal(by_type, type); - case PhysicalType::INT64: - return GetArgMinMaxFunctionInternal(by_type, type); - case PhysicalType::DOUBLE: - return GetArgMinMaxFunctionInternal(by_type, type); - case PhysicalType::VARCHAR: - return GetArgMinMaxFunctionInternal(by_type, type); - default: - throw InternalException("Unimplemented arg_min/arg_max aggregate"); - } -} - -template -void AddArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) { - fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::INTEGER, type)); - fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::BIGINT, type)); - fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::DOUBLE, type)); - fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::VARCHAR, type)); - fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::DATE, type)); - fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::TIMESTAMP, type)); - fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::TIMESTAMP_TZ, type)); - fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::BLOB, type)); -} - -template -static void AddArgMinMaxFunctions(AggregateFunctionSet &fun) { - using OP = ArgMinMaxBase; - AddArgMinMaxFunctionBy(fun, LogicalType::INTEGER); - AddArgMinMaxFunctionBy(fun, LogicalType::BIGINT); - AddArgMinMaxFunctionBy(fun, LogicalType::DOUBLE); - AddArgMinMaxFunctionBy(fun, LogicalType::VARCHAR); - AddArgMinMaxFunctionBy(fun, LogicalType::DATE); - AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP); - AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP_TZ); - AddArgMinMaxFunctionBy(fun, LogicalType::BLOB); - - using VECTOR_OP = 
VectorArgMinMaxBase; - AddVectorArgMinMaxFunctionBy(fun, LogicalType::ANY); -} - -AggregateFunctionSet ArgMinFun::GetFunctions() { - AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); - return fun; -} - -AggregateFunctionSet ArgMaxFun::GetFunctions() { - AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); - return fun; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -template -struct BitState { - bool is_set; - T value; -}; - -template -static AggregateFunction GetBitfieldUnaryAggregate(LogicalType type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return AggregateFunction::UnaryAggregate, int8_t, int8_t, OP>(type, type); - case LogicalTypeId::SMALLINT: - return AggregateFunction::UnaryAggregate, int16_t, int16_t, OP>(type, type); - case LogicalTypeId::INTEGER: - return AggregateFunction::UnaryAggregate, int32_t, int32_t, OP>(type, type); - case LogicalTypeId::BIGINT: - return AggregateFunction::UnaryAggregate, int64_t, int64_t, OP>(type, type); - case LogicalTypeId::HUGEINT: - return AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, OP>(type, type); - case LogicalTypeId::UTINYINT: - return AggregateFunction::UnaryAggregate, uint8_t, uint8_t, OP>(type, type); - case LogicalTypeId::USMALLINT: - return AggregateFunction::UnaryAggregate, uint16_t, uint16_t, OP>(type, type); - case LogicalTypeId::UINTEGER: - return AggregateFunction::UnaryAggregate, uint32_t, uint32_t, OP>(type, type); - case LogicalTypeId::UBIGINT: - return AggregateFunction::UnaryAggregate, uint64_t, uint64_t, OP>(type, type); - default: - throw InternalException("Unimplemented bitfield type for unary aggregate"); - } -} - -struct BitwiseOperation { - template - static void Initialize(STATE &state) { - // If there are no matching rows, returns a null value. 
- state.is_set = false; - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { - if (!state.is_set) { - OP::template Assign(state, input); - state.is_set = true; - } else { - OP::template Execute(state, input); - } - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - OP::template Operation(state, input, unary_input); - } - - template - static void Assign(STATE &state, INPUT_TYPE input) { - state.value = input; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.is_set) { - // source is NULL, nothing to do. - return; - } - if (!target.is_set) { - // target is NULL, use source value directly. - OP::template Assign(target, source.value); - target.is_set = true; - } else { - OP::template Execute(target, source.value); - } - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.is_set) { - finalize_data.ReturnNull(); - } else { - target = state.value; - } - } - - static bool IgnoreNull() { - return true; - } -}; - -struct BitAndOperation : public BitwiseOperation { - template - static void Execute(STATE &state, INPUT_TYPE input) { - state.value &= input; - } -}; - -struct BitOrOperation : public BitwiseOperation { - template - static void Execute(STATE &state, INPUT_TYPE input) { - state.value |= input; - } -}; - -struct BitXorOperation : public BitwiseOperation { - template - static void Execute(STATE &state, INPUT_TYPE input) { - state.value ^= input; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } -}; - -struct BitStringBitwiseOperation : public BitwiseOperation { - template - static void Destroy(STATE &state, AggregateInputData 
&aggr_input_data) { - if (state.is_set && !state.value.IsInlined()) { - delete[] state.value.GetData(); - } - } - - template - static void Assign(STATE &state, INPUT_TYPE input) { - D_ASSERT(state.is_set == false); - if (input.IsInlined()) { - state.value = input; - } else { // non-inlined string, need to allocate space for it - auto len = input.GetSize(); - auto ptr = new char[len]; - memcpy(ptr, input.GetData(), len); - - state.value = string_t(ptr, len); - } - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.is_set) { - finalize_data.ReturnNull(); - } else { - target = finalize_data.ReturnString(state.value); - } - } -}; - -struct BitStringAndOperation : public BitStringBitwiseOperation { - - template - static void Execute(STATE &state, INPUT_TYPE input) { - Bit::BitwiseAnd(input, state.value, state.value); - } -}; - -struct BitStringOrOperation : public BitStringBitwiseOperation { - - template - static void Execute(STATE &state, INPUT_TYPE input) { - Bit::BitwiseOr(input, state.value, state.value); - } -}; - -struct BitStringXorOperation : public BitStringBitwiseOperation { - template - static void Execute(STATE &state, INPUT_TYPE input) { - Bit::BitwiseXor(input, state.value, state.value); - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } -}; - -AggregateFunctionSet BitAndFun::GetFunctions() { - AggregateFunctionSet bit_and; - for (auto &type : LogicalType::Integral()) { - bit_and.AddFunction(GetBitfieldUnaryAggregate(type)); - } - - bit_and.AddFunction( - AggregateFunction::UnaryAggregateDestructor, string_t, string_t, BitStringAndOperation>( - LogicalType::BIT, LogicalType::BIT)); - return bit_and; -} - -AggregateFunctionSet BitOrFun::GetFunctions() { - AggregateFunctionSet bit_or; - for (auto &type : 
LogicalType::Integral()) { - bit_or.AddFunction(GetBitfieldUnaryAggregate(type)); - } - bit_or.AddFunction( - AggregateFunction::UnaryAggregateDestructor, string_t, string_t, BitStringOrOperation>( - LogicalType::BIT, LogicalType::BIT)); - return bit_or; -} - -AggregateFunctionSet BitXorFun::GetFunctions() { - AggregateFunctionSet bit_xor; - for (auto &type : LogicalType::Integral()) { - bit_xor.AddFunction(GetBitfieldUnaryAggregate(type)); - } - bit_xor.AddFunction( - AggregateFunction::UnaryAggregateDestructor, string_t, string_t, BitStringXorOperation>( - LogicalType::BIT, LogicalType::BIT)); - return bit_xor; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -template -struct BitAggState { - bool is_set; - string_t value; - INPUT_TYPE min; - INPUT_TYPE max; -}; - -struct BitstringAggBindData : public FunctionData { - Value min; - Value max; - - BitstringAggBindData() { - } - - BitstringAggBindData(Value min, Value max) : min(std::move(min)), max(std::move(max)) { - } - - unique_ptr Copy() const override { - return make_uniq(*this); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - if (min.IsNull() && other.min.IsNull() && max.IsNull() && other.max.IsNull()) { - return true; - } - if (Value::NotDistinctFrom(min, other.min) && Value::NotDistinctFrom(max, other.max)) { - return true; - } - return false; - } -}; - -struct BitStringAggOperation { - static constexpr const idx_t MAX_BIT_RANGE = 1000000000; // for now capped at 1 billion bits - - template - static void Initialize(STATE &state) { - state.is_set = false; - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - auto &bind_agg_data = unary_input.input.bind_data->template Cast(); - if (!state.is_set) { - if (bind_agg_data.min.IsNull() || bind_agg_data.max.IsNull()) { - throw BinderException( - "Could not retrieve required statistics. 
Alternatively, try by providing the statistics " - "explicitly: BITSTRING_AGG(col, min, max) "); - } - state.min = bind_agg_data.min.GetValue(); - state.max = bind_agg_data.max.GetValue(); - idx_t bit_range = - GetRange(bind_agg_data.min.GetValue(), bind_agg_data.max.GetValue()); - if (bit_range > MAX_BIT_RANGE) { - throw OutOfRangeException( - "The range between min and max value (%s <-> %s) is too large for bitstring aggregation", - NumericHelper::ToString(state.min), NumericHelper::ToString(state.max)); - } - idx_t len = Bit::ComputeBitstringLen(bit_range); - auto target = len > string_t::INLINE_LENGTH ? string_t(new char[len], len) : string_t(len); - Bit::SetEmptyBitString(target, bit_range); - - state.value = target; - state.is_set = true; - } - if (input >= state.min && input <= state.max) { - Execute(state, input, bind_agg_data.min.GetValue()); - } else { - throw OutOfRangeException("Value %s is outside of provided min and max range (%s <-> %s)", - NumericHelper::ToString(input), NumericHelper::ToString(state.min), - NumericHelper::ToString(state.max)); - } - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - OP::template Operation(state, input, unary_input); - } - - template - static idx_t GetRange(INPUT_TYPE min, INPUT_TYPE max) { - D_ASSERT(max >= min); - INPUT_TYPE result; - if (!TrySubtractOperator::Operation(max, min, result)) { - return NumericLimits::Maximum(); - } - idx_t val(result); - if (val == NumericLimits::Maximum()) { - return val; - } - return val + 1; - } - - template - static void Execute(STATE &state, INPUT_TYPE input, INPUT_TYPE min) { - Bit::SetBit(state.value, input - min, 1); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.is_set) { - return; - } - if (!target.is_set) { - Assign(target, source.value); - target.is_set = true; - target.min = source.min; - target.max = source.max; - } 
else { - Bit::BitwiseOr(source.value, target.value, target.value); - } - } - - template - static void Assign(STATE &state, INPUT_TYPE input) { - D_ASSERT(state.is_set == false); - if (input.IsInlined()) { - state.value = input; - } else { // non-inlined string, need to allocate space for it - auto len = input.GetSize(); - auto ptr = new char[len]; - memcpy(ptr, input.GetData(), len); - state.value = string_t(ptr, len); - } - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.is_set) { - finalize_data.ReturnNull(); - } else { - target = StringVector::AddStringOrBlob(finalize_data.result, state.value); - } - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.is_set && !state.value.IsInlined()) { - delete[] state.value.GetData(); - } - } - - static bool IgnoreNull() { - return true; - } -}; - -template <> -void BitStringAggOperation::Execute(BitAggState &state, hugeint_t input, hugeint_t min) { - idx_t val; - if (Hugeint::TryCast(input - min, val)) { - Bit::SetBit(state.value, val, 1); - } else { - throw OutOfRangeException("Range too large for bitstring aggregation"); - } -} - -template <> -idx_t BitStringAggOperation::GetRange(hugeint_t min, hugeint_t max) { - hugeint_t result; - if (!TrySubtractOperator::Operation(max, min, result)) { - return NumericLimits::Maximum(); - } - idx_t range; - if (!Hugeint::TryCast(result + 1, range)) { - return NumericLimits::Maximum(); - } - return range; -} - -unique_ptr BitstringPropagateStats(ClientContext &context, BoundAggregateExpression &expr, - AggregateStatisticsInput &input) { - - if (!NumericStats::HasMinMax(input.child_stats[0])) { - throw BinderException("Could not retrieve required statistics. 
Alternatively, try by providing the statistics " - "explicitly: BITSTRING_AGG(col, min, max) "); - } - auto &bind_agg_data = input.bind_data->Cast(); - bind_agg_data.min = NumericStats::Min(input.child_stats[0]); - bind_agg_data.max = NumericStats::Max(input.child_stats[0]); - return nullptr; -} - -unique_ptr BindBitstringAgg(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - if (arguments.size() == 3) { - if (!arguments[1]->IsFoldable() || !arguments[2]->IsFoldable()) { - throw BinderException("bitstring_agg requires a constant min and max argument"); - } - auto min = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - auto max = ExpressionExecutor::EvaluateScalar(context, *arguments[2]); - Function::EraseArgument(function, arguments, 2); - Function::EraseArgument(function, arguments, 1); - return make_uniq(min, max); - } - return make_uniq(); -} - -template -static void BindBitString(AggregateFunctionSet &bitstring_agg, const LogicalTypeId &type) { - auto function = - AggregateFunction::UnaryAggregateDestructor, TYPE, string_t, BitStringAggOperation>( - type, LogicalType::BIT); - function.bind = BindBitstringAgg; // create new a 'BitstringAggBindData' - function.statistics = BitstringPropagateStats; // stores min and max from column stats in BitstringAggBindData - bitstring_agg.AddFunction(function); // uses the BitstringAggBindData to access statistics for creating bitstring - function.arguments = {type, type, type}; - function.statistics = nullptr; // min and max are provided as arguments - bitstring_agg.AddFunction(function); -} - -void GetBitStringAggregate(const LogicalType &type, AggregateFunctionSet &bitstring_agg) { - switch (type.id()) { - case LogicalType::TINYINT: { - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::SMALLINT: { - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::INTEGER: { - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::BIGINT: 
{ - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::HUGEINT: { - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::UTINYINT: { - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::USMALLINT: { - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::UINTEGER: { - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::UBIGINT: { - return BindBitString(bitstring_agg, type.id()); - } - default: - throw InternalException("Unimplemented bitstring aggregate"); - } -} - -AggregateFunctionSet BitstringAggFun::GetFunctions() { - AggregateFunctionSet bitstring_agg("bitstring_agg"); - for (auto &type : LogicalType::Integral()) { - GetBitStringAggregate(type, bitstring_agg); - } - return bitstring_agg; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct BoolState { - bool empty; - bool val; -}; - -struct BoolAndFunFunction { - template - static void Initialize(STATE &state) { - state.val = true; - state.empty = true; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.val = target.val && source.val; - target.empty = target.empty && source.empty; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.empty) { - finalize_data.ReturnNull(); - return; - } - target = state.val; - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - state.empty = false; - state.val = input && state.val; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - static bool IgnoreNull() { - return true; - } -}; - -struct BoolOrFunFunction { - template - static void Initialize(STATE &state) { - state.val = false; - state.empty = true; - } - - 
template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.val = target.val || source.val; - target.empty = target.empty && source.empty; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.empty) { - finalize_data.ReturnNull(); - return; - } - target = state.val; - } - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - state.empty = false; - state.val = input || state.val; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction BoolOrFun::GetFunction() { - auto fun = AggregateFunction::UnaryAggregate( - LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} - -AggregateFunction BoolAndFun::GetFunction() { - auto fun = AggregateFunction::UnaryAggregate( - LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -template -struct EntropyState { - using DistinctMap = unordered_map; - - idx_t count; - DistinctMap *distinct; - - EntropyState &operator=(const EntropyState &other) = delete; - - EntropyState &Assign(const EntropyState &other) { - D_ASSERT(!distinct); - distinct = new DistinctMap(*other.distinct); - count = other.count; - return *this; - } -}; - -struct EntropyFunctionBase { - template - static void Initialize(STATE &state) { - state.distinct = nullptr; - state.count = 0; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.distinct) { - return; - } - if 
(!target.distinct) { - target.Assign(source); - return; - } - for (auto &val : *source.distinct) { - auto value = val.first; - (*target.distinct)[value] += val.second; - } - target.count += source.count; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - double count = state.count; - if (state.distinct) { - double entropy = 0; - for (auto &val : *state.distinct) { - entropy += (val.second / count) * log2(count / val.second); - } - target = entropy; - } else { - target = 0; - } - } - - static bool IgnoreNull() { - return true; - } - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.distinct) { - delete state.distinct; - } - } -}; - -struct EntropyFunction : EntropyFunctionBase { - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - if (!state.distinct) { - state.distinct = new unordered_map(); - } - (*state.distinct)[input]++; - state.count++; - } - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } -}; - -struct EntropyFunctionString : EntropyFunctionBase { - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - if (!state.distinct) { - state.distinct = new unordered_map(); - } - auto value = input.GetString(); - (*state.distinct)[value]++; - state.count++; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } -}; - -template -AggregateFunction GetEntropyFunction(const LogicalType &input_type, const LogicalType &result_type) { - auto fun = - AggregateFunction::UnaryAggregateDestructor, INPUT_TYPE, RESULT_TYPE, 
EntropyFunction>( - input_type, result_type); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -AggregateFunction GetEntropyFunctionInternal(PhysicalType type) { - switch (type) { - case PhysicalType::UINT16: - return AggregateFunction::UnaryAggregateDestructor, uint16_t, double, EntropyFunction>( - LogicalType::USMALLINT, LogicalType::DOUBLE); - case PhysicalType::UINT32: - return AggregateFunction::UnaryAggregateDestructor, uint32_t, double, EntropyFunction>( - LogicalType::UINTEGER, LogicalType::DOUBLE); - case PhysicalType::UINT64: - return AggregateFunction::UnaryAggregateDestructor, uint64_t, double, EntropyFunction>( - LogicalType::UBIGINT, LogicalType::DOUBLE); - case PhysicalType::INT16: - return AggregateFunction::UnaryAggregateDestructor, int16_t, double, EntropyFunction>( - LogicalType::SMALLINT, LogicalType::DOUBLE); - case PhysicalType::INT32: - return AggregateFunction::UnaryAggregateDestructor, int32_t, double, EntropyFunction>( - LogicalType::INTEGER, LogicalType::DOUBLE); - case PhysicalType::INT64: - return AggregateFunction::UnaryAggregateDestructor, int64_t, double, EntropyFunction>( - LogicalType::BIGINT, LogicalType::DOUBLE); - case PhysicalType::FLOAT: - return AggregateFunction::UnaryAggregateDestructor, float, double, EntropyFunction>( - LogicalType::FLOAT, LogicalType::DOUBLE); - case PhysicalType::DOUBLE: - return AggregateFunction::UnaryAggregateDestructor, double, double, EntropyFunction>( - LogicalType::DOUBLE, LogicalType::DOUBLE); - case PhysicalType::VARCHAR: - return AggregateFunction::UnaryAggregateDestructor, string_t, double, - EntropyFunctionString>(LogicalType::VARCHAR, - LogicalType::DOUBLE); - - default: - throw InternalException("Unimplemented approximate_count aggregate"); - } -} - -AggregateFunction GetEntropyFunction(PhysicalType type) { - auto fun = GetEntropyFunctionInternal(type); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -AggregateFunctionSet 
EntropyFun::GetFunctions() { - AggregateFunctionSet entropy("entropy"); - entropy.AddFunction(GetEntropyFunction(PhysicalType::UINT16)); - entropy.AddFunction(GetEntropyFunction(PhysicalType::UINT32)); - entropy.AddFunction(GetEntropyFunction(PhysicalType::UINT64)); - entropy.AddFunction(GetEntropyFunction(PhysicalType::FLOAT)); - entropy.AddFunction(GetEntropyFunction(PhysicalType::INT16)); - entropy.AddFunction(GetEntropyFunction(PhysicalType::INT32)); - entropy.AddFunction(GetEntropyFunction(PhysicalType::INT64)); - entropy.AddFunction(GetEntropyFunction(PhysicalType::DOUBLE)); - entropy.AddFunction(GetEntropyFunction(PhysicalType::VARCHAR)); - entropy.AddFunction(GetEntropyFunction(LogicalType::TIMESTAMP, LogicalType::DOUBLE)); - entropy.AddFunction(GetEntropyFunction(LogicalType::TIMESTAMP_TZ, LogicalType::DOUBLE)); - return entropy; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct KurtosisState { - idx_t n; - double sum; - double sum_sqr; - double sum_cub; - double sum_four; -}; - -struct KurtosisOperation { - template - static void Initialize(STATE &state) { - state.n = 0; - state.sum = state.sum_sqr = state.sum_cub = state.sum_four = 0.0; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - state.n++; - state.sum += input; - state.sum_sqr += pow(input, 2); - state.sum_cub += pow(input, 3); - state.sum_four += pow(input, 4); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (source.n == 0) { - return; - } - target.n += source.n; - target.sum += source.sum; - target.sum_sqr += source.sum_sqr; - target.sum_cub += source.sum_cub; - target.sum_four += source.sum_four; - } - - template - static void 
Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) { - auto n = (double)state.n; - if (n <= 3) { - finalize_data.ReturnNull(); - return; - } - double temp = 1 / n; - //! This is necessary due to linux 32 bits - long double temp_aux = 1 / n; - if (state.sum_sqr - state.sum * state.sum * temp == 0 || - state.sum_sqr - state.sum * state.sum * temp_aux == 0) { - finalize_data.ReturnNull(); - return; - } - double m4 = - temp * (state.sum_four - 4 * state.sum_cub * state.sum * temp + - 6 * state.sum_sqr * state.sum * state.sum * temp * temp - 3 * pow(state.sum, 4) * pow(temp, 3)); - - double m2 = temp * (state.sum_sqr - state.sum * state.sum * temp); - if (m2 <= 0 || ((n - 2) * (n - 3)) == 0) { // m2 shouldn't be below 0 but floating points are weird - finalize_data.ReturnNull(); - return; - } - target = (n - 1) * ((n + 1) * m4 / (m2 * m2) - 3 * (n - 1)) / ((n - 2) * (n - 3)); - if (!Value::DoubleIsFinite(target)) { - throw OutOfRangeException("Kurtosis is out of range!"); - } - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction KurtosisFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -template -struct MinMaxState { - T value; - bool isset; -}; - -template -static AggregateFunction GetUnaryAggregate(LogicalType type) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - return AggregateFunction::UnaryAggregate, int8_t, int8_t, OP>(type, type); - case PhysicalType::INT8: - return AggregateFunction::UnaryAggregate, int8_t, int8_t, OP>(type, type); - case PhysicalType::INT16: - return AggregateFunction::UnaryAggregate, int16_t, int16_t, OP>(type, type); - case PhysicalType::INT32: - return AggregateFunction::UnaryAggregate, int32_t, int32_t, OP>(type, type); - case PhysicalType::INT64: - return AggregateFunction::UnaryAggregate, int64_t, int64_t, OP>(type, type); - case 
PhysicalType::UINT8: - return AggregateFunction::UnaryAggregate, uint8_t, uint8_t, OP>(type, type); - case PhysicalType::UINT16: - return AggregateFunction::UnaryAggregate, uint16_t, uint16_t, OP>(type, type); - case PhysicalType::UINT32: - return AggregateFunction::UnaryAggregate, uint32_t, uint32_t, OP>(type, type); - case PhysicalType::UINT64: - return AggregateFunction::UnaryAggregate, uint64_t, uint64_t, OP>(type, type); - case PhysicalType::INT128: - return AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, OP>(type, type); - case PhysicalType::FLOAT: - return AggregateFunction::UnaryAggregate, float, float, OP>(type, type); - case PhysicalType::DOUBLE: - return AggregateFunction::UnaryAggregate, double, double, OP>(type, type); - case PhysicalType::INTERVAL: - return AggregateFunction::UnaryAggregate, interval_t, interval_t, OP>(type, type); - default: - throw InternalException("Unimplemented type for min/max aggregate"); - } -} - -struct MinMaxBase { - template - static void Initialize(STATE &state) { - state.isset = false; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - if (!state.isset) { - OP::template Assign(state, input, unary_input.input); - state.isset = true; - } else { - OP::template Execute(state, input, unary_input.input); - } - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - if (!state.isset) { - OP::template Assign(state, input, unary_input.input); - state.isset = true; - } else { - OP::template Execute(state, input, unary_input.input); - } - } - - static bool IgnoreNull() { - return true; - } -}; - -struct NumericMinMaxBase : public MinMaxBase { - template - static void Assign(STATE &state, INPUT_TYPE input, AggregateInputData &) { - state.value = input; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { 
- finalize_data.ReturnNull(); - } else { - target = state.value; - } - } -}; - -struct MinOperation : public NumericMinMaxBase { - template - static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &) { - if (LessThan::Operation(input, state.value)) { - state.value = input; - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.isset) { - // source is NULL, nothing to do - return; - } - if (!target.isset) { - // target is NULL, use source value directly - target = source; - } else if (GreaterThan::Operation(target.value, source.value)) { - target.value = source.value; - } - } -}; - -struct MaxOperation : public NumericMinMaxBase { - template - static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &) { - if (GreaterThan::Operation(input, state.value)) { - state.value = input; - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.isset) { - // source is NULL, nothing to do - return; - } - if (!target.isset) { - // target is NULL, use source value directly - target = source; - } else if (LessThan::Operation(target.value, source.value)) { - target.value = source.value; - } - } -}; - -struct StringMinMaxBase : public MinMaxBase { - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.isset && !state.value.IsInlined()) { - delete[] state.value.GetData(); - } - } - - template - static void Assign(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { - Destroy(state, input_data); - if (input.IsInlined()) { - state.value = input; - } else { - // non-inlined string, need to allocate space for it - auto len = input.GetSize(); - auto ptr = new char[len]; - memcpy(ptr, input.GetData(), len); - - state.value = string_t(ptr, len); - } - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - 
finalize_data.ReturnNull(); - } else { - target = StringVector::AddStringOrBlob(finalize_data.result, state.value); - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { - if (!source.isset) { - // source is NULL, nothing to do - return; - } - if (!target.isset) { - // target is NULL, use source value directly - Assign(target, source.value, input_data); - target.isset = true; - } else { - OP::template Execute(target, source.value, input_data); - } - } -}; - -struct MinOperationString : public StringMinMaxBase { - template - static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { - if (LessThan::Operation(input, state.value)) { - Assign(state, input, input_data); - } - } -}; - -struct MaxOperationString : public StringMinMaxBase { - template - static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { - if (GreaterThan::Operation(input, state.value)) { - Assign(state, input, input_data); - } - } -}; - -template -static bool TemplatedOptimumType(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount) { - UnifiedVectorFormat lvdata, rvdata; - left.ToUnifiedFormat(lcount, lvdata); - right.ToUnifiedFormat(rcount, rvdata); - - lidx = lvdata.sel->get_index(lidx); - ridx = rvdata.sel->get_index(ridx); - - auto ldata = UnifiedVectorFormat::GetData(lvdata); - auto rdata = UnifiedVectorFormat::GetData(rvdata); - - auto &lval = ldata[lidx]; - auto &rval = rdata[ridx]; - - auto lnull = !lvdata.validity.RowIsValid(lidx); - auto rnull = !rvdata.validity.RowIsValid(ridx); - - return OP::Operation(lval, rval, lnull, rnull); -} - -template -static bool TemplatedOptimumList(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount); - -template -static bool TemplatedOptimumStruct(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount); - -template -static bool TemplatedOptimumValue(Vector &left, 
idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount) { - D_ASSERT(left.GetType() == right.GetType()); - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::INT16: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::INT32: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::INT64: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::UINT8: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::UINT16: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::UINT32: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::UINT64: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::INT128: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::FLOAT: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::DOUBLE: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::INTERVAL: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::VARCHAR: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::LIST: - return TemplatedOptimumList(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::STRUCT: - return TemplatedOptimumStruct(left, lidx, lcount, right, ridx, rcount); - default: - throw InternalException("Invalid type for distinct comparison"); - } -} - -template -static bool TemplatedOptimumStruct(Vector &left, idx_t lidx_p, idx_t lcount, Vector &right, idx_t ridx_p, - idx_t rcount) { - // STRUCT dictionaries apply to all the children - // so map the indexes first 
- UnifiedVectorFormat lvdata, rvdata; - left.ToUnifiedFormat(lcount, lvdata); - right.ToUnifiedFormat(rcount, rvdata); - - idx_t lidx = lvdata.sel->get_index(lidx_p); - idx_t ridx = rvdata.sel->get_index(ridx_p); - - // DISTINCT semantics are in effect for nested types - auto lnull = !lvdata.validity.RowIsValid(lidx); - auto rnull = !rvdata.validity.RowIsValid(ridx); - if (lnull || rnull) { - return OP::Operation(0, 0, lnull, rnull); - } - - auto &lchildren = StructVector::GetEntries(left); - auto &rchildren = StructVector::GetEntries(right); - - D_ASSERT(lchildren.size() == rchildren.size()); - for (idx_t col_no = 0; col_no < lchildren.size(); ++col_no) { - auto &lchild = *lchildren[col_no]; - auto &rchild = *rchildren[col_no]; - - // Strict comparisons use the OP for definite - if (TemplatedOptimumValue(lchild, lidx_p, lcount, rchild, ridx_p, rcount)) { - return true; - } - - if (col_no == lchildren.size() - 1) { - break; - } - - // Strict comparisons use IS NOT DISTINCT for possible - if (!TemplatedOptimumValue(lchild, lidx_p, lcount, rchild, ridx_p, rcount)) { - return false; - } - } - - return false; -} - -template -static bool TemplatedOptimumList(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount) { - UnifiedVectorFormat lvdata, rvdata; - left.ToUnifiedFormat(lcount, lvdata); - right.ToUnifiedFormat(rcount, rvdata); - - // Update the indexes and vector sizes for recursion. 
- lidx = lvdata.sel->get_index(lidx); - ridx = rvdata.sel->get_index(ridx); - - lcount = ListVector::GetListSize(left); - rcount = ListVector::GetListSize(right); - - // DISTINCT semantics are in effect for nested types - auto lnull = !lvdata.validity.RowIsValid(lidx); - auto rnull = !rvdata.validity.RowIsValid(ridx); - if (lnull || rnull) { - return OP::Operation(0, 0, lnull, rnull); - } - - auto &lchild = ListVector::GetEntry(left); - auto &rchild = ListVector::GetEntry(right); - - auto ldata = UnifiedVectorFormat::GetData(lvdata); - auto rdata = UnifiedVectorFormat::GetData(rvdata); - - auto &lval = ldata[lidx]; - auto &rval = rdata[ridx]; - - for (idx_t pos = 0;; ++pos) { - // Tie-breaking uses the OP - if (pos == lval.length || pos == rval.length) { - return OP::Operation(lval.length, rval.length, false, false); - } - - // Strict comparisons use the OP for definite - lidx = lval.offset + pos; - ridx = rval.offset + pos; - if (TemplatedOptimumValue(lchild, lidx, lcount, rchild, ridx, rcount)) { - return true; - } - - // Strict comparisons use IS NOT DISTINCT for possible - if (!TemplatedOptimumValue(lchild, lidx, lcount, rchild, ridx, rcount)) { - return false; - } - } - - return false; -} - -struct VectorMinMaxState { - Vector *value; -}; - -struct VectorMinMaxBase { - static bool IgnoreNull() { - return true; - } - - template - static void Initialize(STATE &state) { - state.value = nullptr; - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.value) { - delete state.value; - } - state.value = nullptr; - } - - template - static void Assign(STATE &state, Vector &input, const idx_t idx) { - if (!state.value) { - state.value = new Vector(input.GetType()); - state.value->SetVectorType(VectorType::CONSTANT_VECTOR); - } - sel_t selv = idx; - SelectionVector sel(&selv); - VectorOperations::Copy(input, *state.value, sel, 1, 0, 0); - } - - template - static void Execute(STATE &state, Vector &input, const idx_t idx, 
const idx_t count) { - Assign(state, input, idx); - } - - template - static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { - auto &input = inputs[0]; - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - - auto states = (STATE **)sdata.data; - for (idx_t i = 0; i < count; i++) { - const auto idx = idata.sel->get_index(i); - if (!idata.validity.RowIsValid(idx)) { - continue; - } - const auto sidx = sdata.sel->get_index(i); - auto &state = *states[sidx]; - if (!state.value) { - Assign(state, input, i); - } else { - OP::template Execute(state, input, i, count); - } - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.value) { - return; - } else if (!target.value) { - Assign(target, *source.value, 0); - } else { - OP::template Execute(target, *source.value, 0, 1); - } - } - - template - static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { - if (!state.value) { - finalize_data.ReturnNull(); - } else { - VectorOperations::Copy(*state.value, finalize_data.result, 1, 0, finalize_data.result_idx); - } - } - - static unique_ptr Bind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; - return nullptr; - } -}; - -struct MinOperationVector : public VectorMinMaxBase { - template - static void Execute(STATE &state, Vector &input, const idx_t idx, const idx_t count) { - if (TemplatedOptimumValue(input, idx, count, *state.value, 0, 1)) { - Assign(state, input, idx); - } - } -}; - -struct MaxOperationVector : public VectorMinMaxBase { - template - static void Execute(STATE &state, Vector &input, const idx_t idx, const idx_t count) { - if (TemplatedOptimumValue(input, idx, count, *state.value, 0, 1)) { - Assign(state, 
input, idx); - } - } -}; - -template -unique_ptr BindDecimalMinMax(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto decimal_type = arguments[0]->return_type; - auto name = function.name; - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - function = GetUnaryAggregate(LogicalType::SMALLINT); - break; - case PhysicalType::INT32: - function = GetUnaryAggregate(LogicalType::INTEGER); - break; - case PhysicalType::INT64: - function = GetUnaryAggregate(LogicalType::BIGINT); - break; - default: - function = GetUnaryAggregate(LogicalType::HUGEINT); - break; - } - function.name = std::move(name); - function.arguments[0] = decimal_type; - function.return_type = decimal_type; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return nullptr; -} - -template -static AggregateFunction GetMinMaxFunction(const LogicalType &type) { - return AggregateFunction( - {type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, - OP::template Update, AggregateFunction::StateCombine, - AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, AggregateFunction::StateDestroy); -} - -template -static AggregateFunction GetMinMaxOperator(const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - return AggregateFunction::UnaryAggregateDestructor, string_t, string_t, OP_STRING>( - type.id(), type.id()); - } else if (type.InternalType() == PhysicalType::LIST || type.InternalType() == PhysicalType::STRUCT) { - return GetMinMaxFunction(type); - } else { - return GetUnaryAggregate(type); - } -} - -template -unique_ptr BindMinMax(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto input_type = arguments[0]->return_type; - auto name = std::move(function.name); - function = GetMinMaxOperator(input_type); - function.name = std::move(name); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - if (function.bind) { - return 
function.bind(context, function, arguments); - } else { - return nullptr; - } -} - -template -static void AddMinMaxOperator(AggregateFunctionSet &set) { - set.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, BindDecimalMinMax)); - set.AddFunction(AggregateFunction({LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, BindMinMax)); -} - -AggregateFunctionSet MinFun::GetFunctions() { - AggregateFunctionSet min("min"); - AddMinMaxOperator(min); - return min; -} - -AggregateFunctionSet MaxFun::GetFunctions() { - AggregateFunctionSet max("max"); - AddMinMaxOperator(max); - return max; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct ProductState { - bool empty; - double val; -}; - -struct ProductFunction { - template - static void Initialize(STATE &state) { - state.val = 1; - state.empty = true; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.val *= source.val; - target.empty = target.empty && source.empty; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.empty) { - finalize_data.ReturnNull(); - return; - } - target = state.val; - } - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - if (state.empty) { - state.empty = false; - } - state.val *= input; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction ProductFun::GetFunction() { - return AggregateFunction::UnaryAggregate( - LogicalType(LogicalTypeId::DOUBLE), LogicalType::DOUBLE); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - 
-struct SkewState { - size_t n; - double sum; - double sum_sqr; - double sum_cub; -}; - -struct SkewnessOperation { - template - static void Initialize(STATE &state) { - state.n = 0; - state.sum = state.sum_sqr = state.sum_cub = 0; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - state.n++; - state.sum += input; - state.sum_sqr += pow(input, 2); - state.sum_cub += pow(input, 3); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (source.n == 0) { - return; - } - - target.n += source.n; - target.sum += source.sum; - target.sum_sqr += source.sum_sqr; - target.sum_cub += source.sum_cub; - } - - template - static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) { - if (state.n <= 2) { - finalize_data.ReturnNull(); - return; - } - double n = state.n; - double temp = 1 / n; - auto p = std::pow(temp * (state.sum_sqr - state.sum * state.sum * temp), 3); - if (p < 0) { - p = 0; // Shouldn't be below 0 but floating points are weird - } - double div = std::sqrt(p); - if (div == 0) { - finalize_data.ReturnNull(); - return; - } - double temp1 = std::sqrt(n * (n - 1)) / (n - 2); - target = temp1 * temp * - (state.sum_cub - 3 * state.sum_sqr * state.sum * temp + 2 * pow(state.sum, 3) * temp * temp) / div; - if (!Value::DoubleIsFinite(target)) { - throw OutOfRangeException("SKEW is out of range!"); - } - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction SkewnessFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -struct StringAggState { - idx_t size; - 
idx_t alloc_size; - char *dataptr; -}; - -struct StringAggBindData : public FunctionData { - explicit StringAggBindData(string sep_p) : sep(std::move(sep_p)) { - } - - string sep; - - unique_ptr Copy() const override { - return make_uniq(sep); - } - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return sep == other.sep; - } -}; - -struct StringAggFunction { - template - static void Initialize(STATE &state) { - state.dataptr = nullptr; - state.alloc_size = 0; - state.size = 0; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.dataptr) { - finalize_data.ReturnNull(); - } else { - target = StringVector::AddString(finalize_data.result, state.dataptr, state.size); - } - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.dataptr) { - delete[] state.dataptr; - } - } - - static bool IgnoreNull() { - return true; - } - - static inline void PerformOperation(StringAggState &state, const char *str, const char *sep, idx_t str_size, - idx_t sep_size) { - if (!state.dataptr) { - // first iteration: allocate space for the string and copy it into the state - state.alloc_size = MaxValue(8, NextPowerOfTwo(str_size)); - state.dataptr = new char[state.alloc_size]; - state.size = str_size; - memcpy(state.dataptr, str, str_size); - } else { - // subsequent iteration: first check if we have space to place the string and separator - idx_t required_size = state.size + str_size + sep_size; - if (required_size > state.alloc_size) { - // no space! 
allocate extra space - while (state.alloc_size < required_size) { - state.alloc_size *= 2; - } - auto new_data = new char[state.alloc_size]; - memcpy(new_data, state.dataptr, state.size); - delete[] state.dataptr; - state.dataptr = new_data; - } - // copy the separator - memcpy(state.dataptr + state.size, sep, sep_size); - state.size += sep_size; - // copy the string - memcpy(state.dataptr + state.size, str, str_size); - state.size += str_size; - } - } - - static inline void PerformOperation(StringAggState &state, string_t str, optional_ptr data_p) { - auto &data = data_p->Cast(); - PerformOperation(state, str.GetData(), data.sep.c_str(), str.GetSize(), data.sep.size()); - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - PerformOperation(state, input, unary_input.input.bind_data); - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - if (!source.dataptr) { - // source is not set: skip combining - return; - } - PerformOperation(target, string_t(source.dataptr, source.size), aggr_input_data.bind_data); - } -}; - -unique_ptr StringAggBind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - if (arguments.size() == 1) { - // single argument: default to comma - return make_uniq(","); - } - D_ASSERT(arguments.size() == 2); - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[1]->IsFoldable()) { - throw BinderException("Separator argument to StringAgg must be a constant"); - } - auto separator_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - string separator_string = ","; - if (separator_val.IsNull()) { - arguments[0] = 
make_uniq(Value(LogicalType::VARCHAR)); - } else { - separator_string = separator_val.ToString(); - } - Function::EraseArgument(function, arguments, arguments.size() - 1); - return make_uniq(std::move(separator_string)); -} - -static void StringAggSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - auto bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "separator", bind_data.sep); -} - -unique_ptr StringAggDeserialize(Deserializer &deserializer, AggregateFunction &bound_function) { - auto sep = deserializer.ReadProperty(100, "separator"); - return make_uniq(std::move(sep)); -} - -AggregateFunctionSet StringAggFun::GetFunctions() { - AggregateFunctionSet string_agg; - AggregateFunction string_agg_param( - {LogicalType::VARCHAR}, LogicalType::VARCHAR, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - AggregateFunction::UnaryScatterUpdate, - AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, - AggregateFunction::UnaryUpdate, StringAggBind, - AggregateFunction::StateDestroy); - string_agg_param.serialize = StringAggSerialize; - string_agg_param.deserialize = StringAggDeserialize; - string_agg.AddFunction(string_agg_param); - string_agg_param.arguments.emplace_back(LogicalType::VARCHAR); - string_agg.AddFunction(string_agg_param); - return string_agg; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct SumSetOperation { - template - static void Initialize(STATE &state) { - state.Initialize(); - } - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.Combine(source); - } - template - static void AddValues(STATE &state, idx_t count) { - state.isset = true; - } -}; - -struct IntegerSumOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); - } else { - target = 
Hugeint::Convert(state.value); - } - } -}; - -struct SumToHugeintOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); - } else { - target = state.value; - } - } -}; - -template -struct DoubleSumOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); - } else { - target = state.value; - } - } -}; - -using NumericSumOperation = DoubleSumOperation; -using KahanSumOperation = DoubleSumOperation; - -struct HugeintSumOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); - } else { - target = state.value; - } - } -}; - -AggregateFunction GetSumAggregateNoOverflow(PhysicalType type) { - switch (type) { - case PhysicalType::INT32: { - auto function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, IntegerSumOperation>( - LogicalType::INTEGER, LogicalType::HUGEINT); - function.name = "sum_no_overflow"; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return function; - } - case PhysicalType::INT64: { - auto function = AggregateFunction::UnaryAggregate, int64_t, hugeint_t, IntegerSumOperation>( - LogicalType::BIGINT, LogicalType::HUGEINT); - function.name = "sum_no_overflow"; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return function; - } - default: - throw BinderException("Unsupported internal type for sum_no_overflow"); - } -} - -unique_ptr SumPropagateStats(ClientContext &context, BoundAggregateExpression &expr, - AggregateStatisticsInput &input) { - if (input.node_stats && input.node_stats->has_max_cardinality) { - auto &numeric_stats = input.child_stats[0]; - if (!NumericStats::HasMinMax(numeric_stats)) { - 
return nullptr; - } - auto internal_type = numeric_stats.GetType().InternalType(); - hugeint_t max_negative; - hugeint_t max_positive; - switch (internal_type) { - case PhysicalType::INT32: - max_negative = NumericStats::Min(numeric_stats).GetValueUnsafe(); - max_positive = NumericStats::Max(numeric_stats).GetValueUnsafe(); - break; - case PhysicalType::INT64: - max_negative = NumericStats::Min(numeric_stats).GetValueUnsafe(); - max_positive = NumericStats::Max(numeric_stats).GetValueUnsafe(); - break; - default: - throw InternalException("Unsupported type for propagate sum stats"); - } - auto max_sum_negative = max_negative * hugeint_t(input.node_stats->max_cardinality); - auto max_sum_positive = max_positive * hugeint_t(input.node_stats->max_cardinality); - if (max_sum_positive >= NumericLimits::Maximum() || - max_sum_negative <= NumericLimits::Minimum()) { - // sum can potentially exceed int64_t bounds: use hugeint sum - return nullptr; - } - // total sum is guaranteed to fit in a single int64: use int64 sum instead of hugeint sum - expr.function = GetSumAggregateNoOverflow(internal_type); - } - return nullptr; -} - -AggregateFunction GetSumAggregate(PhysicalType type) { - switch (type) { - case PhysicalType::INT16: { - auto function = AggregateFunction::UnaryAggregate, int16_t, hugeint_t, IntegerSumOperation>( - LogicalType::SMALLINT, LogicalType::HUGEINT); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return function; - } - - case PhysicalType::INT32: { - auto function = - AggregateFunction::UnaryAggregate, int32_t, hugeint_t, SumToHugeintOperation>( - LogicalType::INTEGER, LogicalType::HUGEINT); - function.statistics = SumPropagateStats; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return function; - } - case PhysicalType::INT64: { - auto function = - AggregateFunction::UnaryAggregate, int64_t, hugeint_t, SumToHugeintOperation>( - LogicalType::BIGINT, LogicalType::HUGEINT); - function.statistics = 
SumPropagateStats; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return function; - } - case PhysicalType::INT128: { - auto function = - AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, HugeintSumOperation>( - LogicalType::HUGEINT, LogicalType::HUGEINT); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return function; - } - default: - throw InternalException("Unimplemented sum aggregate"); - } -} - -unique_ptr BindDecimalSum(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto decimal_type = arguments[0]->return_type; - function = GetSumAggregate(decimal_type.InternalType()); - function.name = "sum"; - function.arguments[0] = decimal_type; - function.return_type = LogicalType::DECIMAL(Decimal::MAX_WIDTH_DECIMAL, DecimalType::GetScale(decimal_type)); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return nullptr; -} - -unique_ptr BindDecimalSumNoOverflow(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto decimal_type = arguments[0]->return_type; - function = GetSumAggregateNoOverflow(decimal_type.InternalType()); - function.name = "sum_no_overflow"; - function.arguments[0] = decimal_type; - function.return_type = LogicalType::DECIMAL(Decimal::MAX_WIDTH_DECIMAL, DecimalType::GetScale(decimal_type)); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return nullptr; -} - -AggregateFunctionSet SumFun::GetFunctions() { - AggregateFunctionSet sum; - // decimal - sum.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, - BindDecimalSum)); - sum.AddFunction(GetSumAggregate(PhysicalType::INT16)); - sum.AddFunction(GetSumAggregate(PhysicalType::INT32)); - sum.AddFunction(GetSumAggregate(PhysicalType::INT64)); - sum.AddFunction(GetSumAggregate(PhysicalType::INT128)); 
- sum.AddFunction(AggregateFunction::UnaryAggregate, double, double, NumericSumOperation>( - LogicalType::DOUBLE, LogicalType::DOUBLE)); - return sum; -} - -AggregateFunctionSet SumNoOverflowFun::GetFunctions() { - AggregateFunctionSet sum_no_overflow; - sum_no_overflow.AddFunction(GetSumAggregateNoOverflow(PhysicalType::INT32)); - sum_no_overflow.AddFunction(GetSumAggregateNoOverflow(PhysicalType::INT64)); - sum_no_overflow.AddFunction( - AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, nullptr, nullptr, - FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, BindDecimalSumNoOverflow)); - return sum_no_overflow; -} - -AggregateFunction KahanSumFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -} // namespace duckdb - - - - - - - - -#include -#include -#include - -namespace duckdb { - -struct ApproxQuantileState { - duckdb_tdigest::TDigest *h; - idx_t pos; -}; - -struct ApproximateQuantileBindData : public FunctionData { - ApproximateQuantileBindData() { - } - explicit ApproximateQuantileBindData(float quantile_p) : quantiles(1, quantile_p) { - } - - explicit ApproximateQuantileBindData(vector quantiles_p) : quantiles(std::move(quantiles_p)) { - } - - unique_ptr Copy() const override { - return make_uniq(quantiles); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - // return quantiles == other.quantiles; - if (quantiles != other.quantiles) { - return false; - } - return true; - } - - static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - auto &bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "quantiles", bind_data.quantiles); - } - - static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { - auto result = make_uniq(); - deserializer.ReadProperty(100, "quantiles", result->quantiles); - return 
std::move(result); - } - - vector quantiles; -}; - -struct ApproxQuantileOperation { - using SAVE_TYPE = duckdb_tdigest::Value; - - template - static void Initialize(STATE &state) { - state.pos = 0; - state.h = nullptr; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - auto val = Cast::template Operation(input); - if (!Value::DoubleIsFinite(val)) { - return; - } - if (!state.h) { - state.h = new duckdb_tdigest::TDigest(100); - } - state.h->add(val); - state.pos++; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (source.pos == 0) { - return; - } - D_ASSERT(source.h); - if (!target.h) { - target.h = new duckdb_tdigest::TDigest(100); - } - target.h->merge(source.h); - target.pos += source.pos; - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.h) { - delete state.h; - } - } - - static bool IgnoreNull() { - return true; - } -}; - -struct ApproxQuantileScalarOperation : public ApproxQuantileOperation { - template - static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) { - if (state.pos == 0) { - finalize_data.ReturnNull(); - return; - } - D_ASSERT(state.h); - D_ASSERT(finalize_data.input.bind_data); - state.h->compress(); - auto &bind_data = finalize_data.input.bind_data->template Cast(); - D_ASSERT(bind_data.quantiles.size() == 1); - target = Cast::template Operation(state.h->quantile(bind_data.quantiles[0])); - } -}; - -AggregateFunction GetApproximateQuantileAggregateFunction(PhysicalType type) { - switch (type) { - case PhysicalType::INT16: - return AggregateFunction::UnaryAggregateDestructor(LogicalType::SMALLINT, - 
LogicalType::SMALLINT); - case PhysicalType::INT32: - return AggregateFunction::UnaryAggregateDestructor(LogicalType::INTEGER, - LogicalType::INTEGER); - case PhysicalType::INT64: - return AggregateFunction::UnaryAggregateDestructor(LogicalType::BIGINT, - LogicalType::BIGINT); - case PhysicalType::INT128: - return AggregateFunction::UnaryAggregateDestructor(LogicalType::HUGEINT, - LogicalType::HUGEINT); - case PhysicalType::DOUBLE: - return AggregateFunction::UnaryAggregateDestructor(LogicalType::DOUBLE, - LogicalType::DOUBLE); - default: - throw InternalException("Unimplemented quantile aggregate"); - } -} - -static float CheckApproxQuantile(const Value &quantile_val) { - if (quantile_val.IsNull()) { - throw BinderException("APPROXIMATE QUANTILE parameter cannot be NULL"); - } - auto quantile = quantile_val.GetValue(); - if (quantile < 0 || quantile > 1) { - throw BinderException("APPROXIMATE QUANTILE can only take parameters in range [0, 1]"); - } - - return quantile; -} - -unique_ptr BindApproxQuantile(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[1]->IsFoldable()) { - throw BinderException("APPROXIMATE QUANTILE can only take constant quantile parameters"); - } - Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - - vector quantiles; - if (quantile_val.type().id() != LogicalTypeId::LIST) { - quantiles.push_back(CheckApproxQuantile(quantile_val)); - } else { - for (const auto &element_val : ListValue::GetChildren(quantile_val)) { - quantiles.push_back(CheckApproxQuantile(element_val)); - } - } - - // remove the quantile argument so we can use the unary aggregate - Function::EraseArgument(function, arguments, arguments.size() - 1); - return make_uniq(quantiles); -} - -unique_ptr BindApproxQuantileDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = 
BindApproxQuantile(context, function, arguments); - function = GetApproximateQuantileAggregateFunction(arguments[0]->return_type.InternalType()); - function.name = "approx_quantile"; - function.serialize = ApproximateQuantileBindData::Serialize; - function.deserialize = ApproximateQuantileBindData::Deserialize; - return bind_data; -} - -AggregateFunction GetApproximateQuantileAggregate(PhysicalType type) { - auto fun = GetApproximateQuantileAggregateFunction(type); - fun.bind = BindApproxQuantile; - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproximateQuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - fun.arguments.emplace_back(LogicalType::FLOAT); - return fun; -} - -template -struct ApproxQuantileListOperation : public ApproxQuantileOperation { - - template - static void Finalize(STATE &state, RESULT_TYPE &target, AggregateFinalizeData &finalize_data) { - if (state.pos == 0) { - finalize_data.ReturnNull(); - return; - } - - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->template Cast(); - - auto &result = ListVector::GetEntry(finalize_data.result); - auto ridx = ListVector::GetListSize(finalize_data.result); - ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); - auto rdata = FlatVector::GetData(result); - - D_ASSERT(state.h); - state.h->compress(); - - auto &entry = target; - entry.offset = ridx; - entry.length = bind_data.quantiles.size(); - for (size_t q = 0; q < entry.length; ++q) { - const auto &quantile = bind_data.quantiles[q]; - rdata[ridx + q] = Cast::template Operation(state.h->quantile(quantile)); - } - - ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); - } -}; - -template -static AggregateFunction ApproxQuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { - LogicalType result_type = LogicalType::LIST(child_type); - return 
AggregateFunction( - {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, - AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, - nullptr, AggregateFunction::StateDestroy); -} - -template -AggregateFunction GetTypedApproxQuantileListAggregateFunction(const LogicalType &type) { - using STATE = ApproxQuantileState; - using OP = ApproxQuantileListOperation; - auto fun = ApproxQuantileListAggregate(type, type); - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproximateQuantileBindData::Deserialize; - return fun; -} - -AggregateFunction GetApproxQuantileListAggregateFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::SMALLINT: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::INTEGER: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::BIGINT: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::HUGEINT: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::FLOAT: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::DOUBLE: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedApproxQuantileListAggregateFunction(type); - case PhysicalType::INT32: - return GetTypedApproxQuantileListAggregateFunction(type); - case PhysicalType::INT64: - return GetTypedApproxQuantileListAggregateFunction(type); - case PhysicalType::INT128: - return GetTypedApproxQuantileListAggregateFunction(type); - default: - throw NotImplementedException("Unimplemented approximate quantile list aggregate"); - } - default: - // TODO: Add quantitative temporal types - throw 
NotImplementedException("Unimplemented approximate quantile list aggregate"); - } -} - -unique_ptr BindApproxQuantileDecimalList(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindApproxQuantile(context, function, arguments); - function = GetApproxQuantileListAggregateFunction(arguments[0]->return_type); - function.name = "approx_quantile"; - function.serialize = ApproximateQuantileBindData::Serialize; - function.deserialize = ApproximateQuantileBindData::Deserialize; - return bind_data; -} - -AggregateFunction GetApproxQuantileListAggregate(const LogicalType &type) { - auto fun = GetApproxQuantileListAggregateFunction(type); - fun.bind = BindApproxQuantile; - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproximateQuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - auto list_of_float = LogicalType::LIST(LogicalType::FLOAT); - fun.arguments.push_back(list_of_float); - return fun; -} - -AggregateFunctionSet ApproxQuantileFun::GetFunctions() { - AggregateFunctionSet approx_quantile; - approx_quantile.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, LogicalType::FLOAT}, LogicalTypeId::DECIMAL, - nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - BindApproxQuantileDecimal)); - - approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::INT16)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::INT32)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::INT64)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::INT128)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::DOUBLE)); - - // List variants - approx_quantile.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::FLOAT)}, - LogicalType::LIST(LogicalTypeId::DECIMAL), nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, 
BindApproxQuantileDecimalList)); - - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::TINYINT)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::SMALLINT)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::INTEGER)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::BIGINT)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::HUGEINT)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::FLOAT)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::DOUBLE)); - return approx_quantile; -} - -} // namespace duckdb -// MODE( ) -// Returns the most frequent value for the values within expr1. -// NULL values are ignored. If all the values are NULL, or there are 0 rows, then the function returns NULL. - - - - - - - - -#include - -namespace std { - -template <> -struct hash { - inline size_t operator()(const duckdb::interval_t &val) const { - return hash {}(val.days) ^ hash {}(val.months) ^ hash {}(val.micros); - } -}; - -template <> -struct hash { - inline size_t operator()(const duckdb::hugeint_t &val) const { - return hash {}(val.upper) ^ hash {}(val.lower); - } -}; - -} // namespace std - -namespace duckdb { - -template -struct ModeState { - struct ModeAttr { - ModeAttr() : count(0), first_row(std::numeric_limits::max()) { - } - size_t count; - idx_t first_row; - }; - using Counts = unordered_map; - - Counts *frequency_map; - KEY_TYPE *mode; - size_t nonzero; - bool valid; - size_t count; - - void Initialize() { - frequency_map = nullptr; - mode = nullptr; - nonzero = 0; - valid = false; - count = 0; - } - - void Destroy() { - if (frequency_map) { - delete frequency_map; - } - if (mode) { - delete mode; - } - } - - void Reset() { - Counts empty; - frequency_map->swap(empty); - nonzero = 0; - count = 0; - valid = false; - } - - void ModeAdd(const KEY_TYPE &key, idx_t row) { - auto &attr = 
(*frequency_map)[key]; - auto new_count = (attr.count += 1); - if (new_count == 1) { - ++nonzero; - attr.first_row = row; - } else { - attr.first_row = MinValue(row, attr.first_row); - } - if (new_count > count) { - valid = true; - count = new_count; - if (mode) { - *mode = key; - } else { - mode = new KEY_TYPE(key); - } - } - } - - void ModeRm(const KEY_TYPE &key, idx_t frame) { - auto &attr = (*frequency_map)[key]; - auto old_count = attr.count; - nonzero -= int(old_count == 1); - - attr.count -= 1; - if (count == old_count && key == *mode) { - valid = false; - } - } - - typename Counts::const_iterator Scan() const { - //! Initialize control variables to first variable of the frequency map - auto highest_frequency = frequency_map->begin(); - for (auto i = highest_frequency; i != frequency_map->end(); ++i) { - // Tie break with the lowest insert position - if (i->second.count > highest_frequency->second.count || - (i->second.count == highest_frequency->second.count && - i->second.first_row < highest_frequency->second.first_row)) { - highest_frequency = i; - } - } - return highest_frequency; - } -}; - -struct ModeIncluded { - inline explicit ModeIncluded(const ValidityMask &fmask_p, const ValidityMask &dmask_p, idx_t bias_p) - : fmask(fmask_p), dmask(dmask_p), bias(bias_p) { - } - - inline bool operator()(const idx_t &idx) const { - return fmask.RowIsValid(idx) && dmask.RowIsValid(idx - bias); - } - const ValidityMask &fmask; - const ValidityMask &dmask; - const idx_t bias; -}; - -struct ModeAssignmentStandard { - template - static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) { - return RESULT_TYPE(input); - } -}; - -struct ModeAssignmentString { - template - static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) { - return StringVector::AddString(result, input); - } -}; - -template -struct ModeFunction { - template - static void Initialize(STATE &state) { - state.Initialize(); - } - - template - static void Operation(STATE &state, const INPUT_TYPE 
&input, AggregateUnaryInput &) { - if (!state.frequency_map) { - state.frequency_map = new typename STATE::Counts(); - } - auto key = KEY_TYPE(input); - auto &i = (*state.frequency_map)[key]; - i.count++; - i.first_row = MinValue(i.first_row, state.count); - state.count++; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.frequency_map) { - return; - } - if (!target.frequency_map) { - // Copy - don't destroy! Otherwise windowing will break. - target.frequency_map = new typename STATE::Counts(*source.frequency_map); - return; - } - for (auto &val : *source.frequency_map) { - auto &i = (*target.frequency_map)[val.first]; - i.count += val.second.count; - i.first_row = MinValue(i.first_row, val.second.first_row); - } - target.count += source.count; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.frequency_map) { - finalize_data.ReturnNull(); - return; - } - auto highest_frequency = state.Scan(); - if (highest_frequency != state.frequency_map->end()) { - target = ASSIGN_OP::template Assign(finalize_data.result, highest_frequency->first); - } else { - finalize_data.ReturnNull(); - } - } - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &, idx_t count) { - if (!state.frequency_map) { - state.frequency_map = new typename STATE::Counts(); - } - auto key = KEY_TYPE(input); - auto &i = (*state.frequency_map)[key]; - i.count += count; - i.first_row = MinValue(i.first_row, state.count); - state.count += count; - } - - template - static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, - AggregateInputData &, STATE &state, const FrameBounds &frame, const FrameBounds &prev, - Vector &result, idx_t rid, idx_t bias) { - auto rdata = FlatVector::GetData(result); - auto &rmask = FlatVector::Validity(result); - - ModeIncluded included(fmask, dmask, bias); - - if 
(!state.frequency_map) { - state.frequency_map = new typename STATE::Counts; - } - const double tau = .25; - if (state.nonzero <= tau * state.frequency_map->size() || prev.end <= frame.start || frame.end <= prev.start) { - state.Reset(); - // for f ∈ F do - for (auto f = frame.start; f < frame.end; ++f) { - if (included(f)) { - state.ModeAdd(KEY_TYPE(data[f]), f); - } - } - } else { - // for f ∈ P \ F do - for (auto p = prev.start; p < frame.start; ++p) { - if (included(p)) { - state.ModeRm(KEY_TYPE(data[p]), p); - } - } - for (auto p = frame.end; p < prev.end; ++p) { - if (included(p)) { - state.ModeRm(KEY_TYPE(data[p]), p); - } - } - - // for f ∈ F \ P do - for (auto f = frame.start; f < prev.start; ++f) { - if (included(f)) { - state.ModeAdd(KEY_TYPE(data[f]), f); - } - } - for (auto f = prev.end; f < frame.end; ++f) { - if (included(f)) { - state.ModeAdd(KEY_TYPE(data[f]), f); - } - } - } - - if (!state.valid) { - // Rescan - auto highest_frequency = state.Scan(); - if (highest_frequency != state.frequency_map->end()) { - *(state.mode) = highest_frequency->first; - state.count = highest_frequency->second.count; - state.valid = (state.count > 0); - } - } - - if (state.valid) { - rdata[rid] = ASSIGN_OP::template Assign(result, *state.mode); - } else { - rmask.Set(rid, false); - } - } - - static bool IgnoreNull() { - return true; - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - state.Destroy(); - } -}; - -template -AggregateFunction GetTypedModeFunction(const LogicalType &type) { - using STATE = ModeState; - using OP = ModeFunction; - auto func = AggregateFunction::UnaryAggregateDestructor(type, type); - func.window = AggregateFunction::UnaryWindow; - return func; -} - -AggregateFunction GetModeAggregate(const LogicalType &type) { - switch (type.InternalType()) { - case PhysicalType::INT8: - return GetTypedModeFunction(type); - case PhysicalType::UINT8: - return GetTypedModeFunction(type); - case PhysicalType::INT16: - 
return GetTypedModeFunction(type); - case PhysicalType::UINT16: - return GetTypedModeFunction(type); - case PhysicalType::INT32: - return GetTypedModeFunction(type); - case PhysicalType::UINT32: - return GetTypedModeFunction(type); - case PhysicalType::INT64: - return GetTypedModeFunction(type); - case PhysicalType::UINT64: - return GetTypedModeFunction(type); - case PhysicalType::INT128: - return GetTypedModeFunction(type); - - case PhysicalType::FLOAT: - return GetTypedModeFunction(type); - case PhysicalType::DOUBLE: - return GetTypedModeFunction(type); - - case PhysicalType::INTERVAL: - return GetTypedModeFunction(type); - - case PhysicalType::VARCHAR: - return GetTypedModeFunction(type); - - default: - throw NotImplementedException("Unimplemented mode aggregate"); - } -} - -unique_ptr BindModeDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetModeAggregate(arguments[0]->return_type); - function.name = "mode"; - return nullptr; -} - -AggregateFunctionSet ModeFun::GetFunctions() { - const vector TEMPORAL = {LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::TIME, - LogicalType::TIMESTAMP_TZ, LogicalType::TIME_TZ, LogicalType::INTERVAL}; - - AggregateFunctionSet mode; - mode.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, BindModeDecimal)); - - for (const auto &type : LogicalType::Numeric()) { - if (type.id() != LogicalTypeId::DECIMAL) { - mode.AddFunction(GetModeAggregate(type)); - } - } - - for (const auto &type : TEMPORAL) { - mode.AddFunction(GetModeAggregate(type)); - } - - mode.AddFunction(GetModeAggregate(LogicalType::VARCHAR)); - return mode; -} -} // namespace duckdb - - - - - - - - - - - - -#include -#include -#include - -namespace duckdb { - -// Hugeint arithmetic -static hugeint_t MultiplyByDouble(const hugeint_t &h, const double &d) { - D_ASSERT(d >= 0 && d <= 1); - return Hugeint::Convert(Hugeint::Cast(h) * 
d); -} - -// Interval arithmetic -static interval_t MultiplyByDouble(const interval_t &i, const double &d) { // NOLINT - D_ASSERT(d >= 0 && d <= 1); - return Interval::FromMicro(std::llround(Interval::GetMicro(i) * d)); -} - -inline interval_t operator+(const interval_t &lhs, const interval_t &rhs) { - return Interval::FromMicro(Interval::GetMicro(lhs) + Interval::GetMicro(rhs)); -} - -inline interval_t operator-(const interval_t &lhs, const interval_t &rhs) { - return Interval::FromMicro(Interval::GetMicro(lhs) - Interval::GetMicro(rhs)); -} - -template -struct QuantileState { - using SaveType = SAVE_TYPE; - - // Regular aggregation - vector v; - - // Windowed Quantile indirection - vector w; - idx_t pos; - - // Windowed MAD indirection - vector m; - - QuantileState() : pos(0) { - } - - ~QuantileState() { - } - - inline void SetPos(size_t pos_p) { - pos = pos_p; - if (pos >= w.size()) { - w.resize(pos); - } - } -}; - -struct QuantileIncluded { - inline explicit QuantileIncluded(const ValidityMask &fmask_p, const ValidityMask &dmask_p, idx_t bias_p) - : fmask(fmask_p), dmask(dmask_p), bias(bias_p) { - } - - inline bool operator()(const idx_t &idx) const { - return fmask.RowIsValid(idx) && dmask.RowIsValid(idx - bias); - } - - inline bool AllValid() const { - return fmask.AllValid() && dmask.AllValid(); - } - - const ValidityMask &fmask; - const ValidityMask &dmask; - const idx_t bias; -}; - -void ReuseIndexes(idx_t *index, const FrameBounds &frame, const FrameBounds &prev) { - idx_t j = 0; - - // Copy overlapping indices - for (idx_t p = 0; p < (prev.end - prev.start); ++p) { - auto idx = index[p]; - - // Shift down into any hole - if (j != p) { - index[j] = idx; - } - - // Skip overlapping values - if (frame.start <= idx && idx < frame.end) { - ++j; - } - } - - // Insert new indices - if (j > 0) { - // Overlap: append the new ends - for (auto f = frame.start; f < prev.start; ++f, ++j) { - index[j] = f; - } - for (auto f = prev.end; f < frame.end; ++f, ++j) { - 
index[j] = f; - } - } else { - // No overlap: overwrite with new values - for (auto f = frame.start; f < frame.end; ++f, ++j) { - index[j] = f; - } - } -} - -static idx_t ReplaceIndex(idx_t *index, const FrameBounds &frame, const FrameBounds &prev) { // NOLINT - D_ASSERT(index); - - idx_t j = 0; - for (idx_t p = 0; p < (prev.end - prev.start); ++p) { - auto idx = index[p]; - if (j != p) { - break; - } - - if (frame.start <= idx && idx < frame.end) { - ++j; - } - } - index[j] = frame.end - 1; - - return j; -} - -template -static inline int CanReplace(const idx_t *index, const INPUT_TYPE *fdata, const idx_t j, const idx_t k0, const idx_t k1, - const QuantileIncluded &validity) { - D_ASSERT(index); - - // NULLs sort to the end, so if we have inserted a NULL, - // it must be past the end of the quantile to be replaceable. - // Note that the quantile values are never NULL. - const auto ij = index[j]; - if (!validity(ij)) { - return k1 < j ? 1 : 0; - } - - auto curr = fdata[ij]; - if (k1 < j) { - auto hi = fdata[index[k0]]; - return hi < curr ? 1 : 0; - } else if (j < k0) { - auto lo = fdata[index[k1]]; - return curr < lo ? 
-1 : 0; - } - - return 0; -} - -template -struct IndirectLess { - inline explicit IndirectLess(const INPUT_TYPE *inputs_p) : inputs(inputs_p) { - } - - inline bool operator()(const idx_t &lhi, const idx_t &rhi) const { - return inputs[lhi] < inputs[rhi]; - } - - const INPUT_TYPE *inputs; -}; - -struct CastInterpolation { - - template - static inline TARGET_TYPE Cast(const INPUT_TYPE &src, Vector &result) { - return Cast::Operation(src); - } - template - static inline TARGET_TYPE Interpolate(const TARGET_TYPE &lo, const double d, const TARGET_TYPE &hi) { - const auto delta = hi - lo; - return lo + delta * d; - } -}; - -template <> -interval_t CastInterpolation::Cast(const dtime_t &src, Vector &result) { - return {0, 0, src.micros}; -} - -template <> -double CastInterpolation::Interpolate(const double &lo, const double d, const double &hi) { - return lo * (1.0 - d) + hi * d; -} - -template <> -dtime_t CastInterpolation::Interpolate(const dtime_t &lo, const double d, const dtime_t &hi) { - return dtime_t(std::llround(lo.micros * (1.0 - d) + hi.micros * d)); -} - -template <> -timestamp_t CastInterpolation::Interpolate(const timestamp_t &lo, const double d, const timestamp_t &hi) { - return timestamp_t(std::llround(lo.value * (1.0 - d) + hi.value * d)); -} - -template <> -hugeint_t CastInterpolation::Interpolate(const hugeint_t &lo, const double d, const hugeint_t &hi) { - const hugeint_t delta = hi - lo; - return lo + MultiplyByDouble(delta, d); -} - -template <> -interval_t CastInterpolation::Interpolate(const interval_t &lo, const double d, const interval_t &hi) { - const interval_t delta = hi - lo; - return lo + MultiplyByDouble(delta, d); -} - -template <> -string_t CastInterpolation::Cast(const std::string &src, Vector &result) { - return StringVector::AddString(result, src); -} - -template <> -string_t CastInterpolation::Cast(const string_t &src, Vector &result) { - return StringVector::AddString(result, src); -} - -// Direct access -template -struct 
QuantileDirect { - using INPUT_TYPE = T; - using RESULT_TYPE = T; - - inline const INPUT_TYPE &operator()(const INPUT_TYPE &x) const { - return x; - } -}; - -// Indirect access -template -struct QuantileIndirect { - using INPUT_TYPE = idx_t; - using RESULT_TYPE = T; - const RESULT_TYPE *data; - - explicit QuantileIndirect(const RESULT_TYPE *data_p) : data(data_p) { - } - - inline RESULT_TYPE operator()(const idx_t &input) const { - return data[input]; - } -}; - -// Composed access -template -struct QuantileComposed { - using INPUT_TYPE = typename INNER::INPUT_TYPE; - using RESULT_TYPE = typename OUTER::RESULT_TYPE; - - const OUTER &outer; - const INNER &inner; - - explicit QuantileComposed(const OUTER &outer_p, const INNER &inner_p) : outer(outer_p), inner(inner_p) { - } - - inline RESULT_TYPE operator()(const idx_t &input) const { - return outer(inner(input)); - } -}; - -// Accessed comparison -template -struct QuantileCompare { - using INPUT_TYPE = typename ACCESSOR::INPUT_TYPE; - const ACCESSOR &accessor; - const bool desc; - explicit QuantileCompare(const ACCESSOR &accessor_p, bool desc_p) : accessor(accessor_p), desc(desc_p) { - } - - inline bool operator()(const INPUT_TYPE &lhs, const INPUT_TYPE &rhs) const { - const auto lval = accessor(lhs); - const auto rval = accessor(rhs); - - return desc ? (rval < lval) : (lval < rval); - } -}; - -// Avoid using naked Values in inner loops... 
-struct QuantileValue { - explicit QuantileValue(const Value &v) : val(v), dbl(v.GetValue()) { - const auto &type = val.type(); - switch (type.id()) { - case LogicalTypeId::DECIMAL: { - integral = IntegralValue::Get(v); - scaling = Hugeint::POWERS_OF_TEN[DecimalType::GetScale(type)]; - break; - } - default: - break; - } - } - - Value val; - - // DOUBLE - double dbl; - - // DECIMAL - hugeint_t integral; - hugeint_t scaling; -}; - -bool operator==(const QuantileValue &x, const QuantileValue &y) { - return x.val == y.val; -} - -// Continuous interpolation -template -struct Interpolator { - Interpolator(const QuantileValue &q, const idx_t n_p, const bool desc_p) - : desc(desc_p), RN((double)(n_p - 1) * q.dbl), FRN(floor(RN)), CRN(ceil(RN)), begin(0), end(n_p) { - } - - template > - TARGET_TYPE Operation(INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - QuantileCompare comp(accessor, desc); - if (CRN == FRN) { - std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); - return CastInterpolation::Cast(accessor(v_t[FRN]), result); - } else { - std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); - std::nth_element(v_t + FRN, v_t + CRN, v_t + end, comp); - auto lo = CastInterpolation::Cast(accessor(v_t[FRN]), result); - auto hi = CastInterpolation::Cast(accessor(v_t[CRN]), result); - return CastInterpolation::Interpolate(lo, RN - FRN, hi); - } - } - - template > - TARGET_TYPE Replace(const INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - if (CRN == FRN) { - return CastInterpolation::Cast(accessor(v_t[FRN]), result); - } else { - auto lo = CastInterpolation::Cast(accessor(v_t[FRN]), result); - auto hi = CastInterpolation::Cast(accessor(v_t[CRN]), result); - return CastInterpolation::Interpolate(lo, RN - FRN, hi); - } - } - - const bool desc; - const double RN; - const idx_t FRN; - const idx_t 
CRN; - - idx_t begin; - idx_t end; -}; - -// Discrete "interpolation" -template <> -struct Interpolator { - static inline idx_t Index(const QuantileValue &q, const idx_t n) { - idx_t floored; - switch (q.val.type().id()) { - case LogicalTypeId::DECIMAL: { - // Integer arithmetic for accuracy - const auto integral = q.integral; - const auto scaling = q.scaling; - const auto scaled_q = DecimalMultiplyOverflowCheck::Operation(n, integral); - const auto scaled_n = DecimalMultiplyOverflowCheck::Operation(n, scaling); - floored = Cast::Operation((scaled_n - scaled_q) / scaling); - break; - } - default: - const auto scaled_q = (double)(n * q.dbl); - floored = floor(n - scaled_q); - break; - } - - return MaxValue(1, n - floored) - 1; - } - - Interpolator(const QuantileValue &q, const idx_t n_p, bool desc_p) - : desc(desc_p), FRN(Index(q, n_p)), CRN(FRN), begin(0), end(n_p) { - } - - template > - TARGET_TYPE Operation(INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - QuantileCompare comp(accessor, desc); - std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); - return CastInterpolation::Cast(accessor(v_t[FRN]), result); - } - - template > - TARGET_TYPE Replace(const INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - return CastInterpolation::Cast(accessor(v_t[FRN]), result); - } - - const bool desc; - const idx_t FRN; - const idx_t CRN; - - idx_t begin; - idx_t end; -}; - -template -static inline T QuantileAbs(const T &t) { - return AbsOperator::Operation(t); -} - -template <> -inline Value QuantileAbs(const Value &v) { - const auto &type = v.type(); - switch (type.id()) { - case LogicalTypeId::DECIMAL: { - const auto integral = IntegralValue::Get(v); - const auto width = DecimalType::GetWidth(type); - const auto scale = DecimalType::GetScale(type); - switch (type.InternalType()) { - case 
PhysicalType::INT16: - return Value::DECIMAL(QuantileAbs(Cast::Operation(integral)), width, scale); - case PhysicalType::INT32: - return Value::DECIMAL(QuantileAbs(Cast::Operation(integral)), width, scale); - case PhysicalType::INT64: - return Value::DECIMAL(QuantileAbs(Cast::Operation(integral)), width, scale); - case PhysicalType::INT128: - return Value::DECIMAL(QuantileAbs(integral), width, scale); - default: - throw InternalException("Unknown DECIMAL type"); - } - } - default: - return Value::DOUBLE(QuantileAbs(v.GetValue())); - } -} - -struct QuantileBindData : public FunctionData { - QuantileBindData() { - } - - explicit QuantileBindData(const Value &quantile_p) - : quantiles(1, QuantileValue(QuantileAbs(quantile_p))), order(1, 0), desc(quantile_p < 0) { - } - - explicit QuantileBindData(const vector &quantiles_p) { - vector normalised; - size_t pos = 0; - size_t neg = 0; - for (idx_t i = 0; i < quantiles_p.size(); ++i) { - const auto &q = quantiles_p[i]; - pos += (q > 0); - neg += (q < 0); - normalised.emplace_back(QuantileAbs(q)); - order.push_back(i); - } - if (pos && neg) { - throw BinderException("QUANTILE parameters must have consistent signs"); - } - desc = (neg > 0); - - IndirectLess lt(normalised.data()); - std::sort(order.begin(), order.end(), lt); - - for (const auto &q : normalised) { - quantiles.emplace_back(QuantileValue(q)); - } - } - - QuantileBindData(const QuantileBindData &other) : order(other.order), desc(other.desc) { - for (const auto &q : other.quantiles) { - quantiles.emplace_back(q); - } - } - - unique_ptr Copy() const override { - return make_uniq(*this); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return desc == other.desc && quantiles == other.quantiles && order == other.order; - } - - static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - auto &bind_data = bind_data_p->Cast(); - vector raw; - for (const auto &q : 
bind_data.quantiles) { - raw.emplace_back(q.val); - } - serializer.WriteProperty(100, "quantiles", raw); - serializer.WriteProperty(101, "order", bind_data.order); - serializer.WriteProperty(102, "desc", bind_data.desc); - } - - static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { - auto result = make_uniq(); - vector raw; - deserializer.ReadProperty(100, "quantiles", raw); - deserializer.ReadProperty(101, "order", result->order); - deserializer.ReadProperty(102, "desc", result->desc); - for (const auto &r : raw) { - result->quantiles.emplace_back(QuantileValue(r)); - } - return std::move(result); - } - - static void SerializeDecimal(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - throw NotImplementedException("FIXME: serializing quantiles with decimals is not supported right now"); - } - - vector quantiles; - vector order; - bool desc; -}; - -struct QuantileOperation { - template - static void Initialize(STATE &state) { - new (&state) STATE(); - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { - state.v.emplace_back(input); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (source.v.empty()) { - return; - } - target.v.insert(target.v.end(), source.v.begin(), source.v.end()); - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - state.~STATE(); - } - - static bool IgnoreNull() { - return true; - } -}; - -template -static AggregateFunction QuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { // NOLINT - LogicalType result_type = LogicalType::LIST(child_type); - return AggregateFunction( - 
{input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, - AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, - nullptr, AggregateFunction::StateDestroy); -} - -template -struct QuantileScalarOperation : public QuantileOperation { - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.v.empty()) { - finalize_data.ReturnNull(); - return; - } - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->Cast(); - D_ASSERT(bind_data.quantiles.size() == 1); - Interpolator interp(bind_data.quantiles[0], state.v.size(), bind_data.desc); - target = interp.template Operation(state.v.data(), finalize_data.result); - } - - template - static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, - AggregateInputData &aggr_input_data, STATE &state, const FrameBounds &frame, - const FrameBounds &prev, Vector &result, idx_t ridx, idx_t bias) { - auto rdata = FlatVector::GetData(result); - auto &rmask = FlatVector::Validity(result); - - QuantileIncluded included(fmask, dmask, bias); - - // Lazily initialise frame state - auto prev_pos = state.pos; - state.SetPos(frame.end - frame.start); - - auto index = state.w.data(); - D_ASSERT(index); - - D_ASSERT(aggr_input_data.bind_data); - auto &bind_data = aggr_input_data.bind_data->Cast(); - - // Find the two positions needed - const auto &q = bind_data.quantiles[0]; - - bool replace = false; - if (frame.start == prev.start + 1 && frame.end == prev.end + 1) { - // Fixed frame size - const auto j = ReplaceIndex(index, frame, prev); - // We can only replace if the number of NULLs has not changed - if (included.AllValid() || included(prev.start) == included(prev.end)) { - Interpolator interp(q, prev_pos, false); - replace = CanReplace(index, data, j, interp.FRN, interp.CRN, included); - if 
(replace) { - state.pos = prev_pos; - } - } - } else { - ReuseIndexes(index, frame, prev); - } - - if (!replace && !included.AllValid()) { - // Remove the NULLs - state.pos = std::partition(index, index + state.pos, included) - index; - } - if (state.pos) { - Interpolator interp(q, state.pos, false); - - using ID = QuantileIndirect; - ID indirect(data); - rdata[ridx] = replace ? interp.template Replace(index, result, indirect) - : interp.template Operation(index, result, indirect); - } else { - rmask.Set(ridx, false); - } - } -}; - -template -AggregateFunction GetTypedDiscreteQuantileAggregateFunction(const LogicalType &type) { - using STATE = QuantileState; - using OP = QuantileScalarOperation; - auto fun = AggregateFunction::UnaryAggregateDestructor(type, type); - fun.window = AggregateFunction::UnaryWindow; - return fun; -} - -AggregateFunction GetDiscreteQuantileAggregateFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::SMALLINT: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::INTEGER: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::BIGINT: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::HUGEINT: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::FLOAT: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::DOUBLE: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedDiscreteQuantileAggregateFunction(type); - case PhysicalType::INT32: - return GetTypedDiscreteQuantileAggregateFunction(type); - case PhysicalType::INT64: - return GetTypedDiscreteQuantileAggregateFunction(type); - case PhysicalType::INT128: - return GetTypedDiscreteQuantileAggregateFunction(type); - default: - throw 
NotImplementedException("Unimplemented discrete quantile aggregate"); - } - case LogicalTypeId::DATE: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::INTERVAL: - return GetTypedDiscreteQuantileAggregateFunction(type); - - case LogicalTypeId::VARCHAR: - return GetTypedDiscreteQuantileAggregateFunction(type); - - default: - throw NotImplementedException("Unimplemented discrete quantile aggregate"); - } -} - -template -struct QuantileListOperation : public QuantileOperation { - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.v.empty()) { - finalize_data.ReturnNull(); - return; - } - - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->Cast(); - - auto &result = ListVector::GetEntry(finalize_data.result); - auto ridx = ListVector::GetListSize(finalize_data.result); - ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); - auto rdata = FlatVector::GetData(result); - - auto v_t = state.v.data(); - D_ASSERT(v_t); - - auto &entry = target; - entry.offset = ridx; - idx_t lower = 0; - for (const auto &q : bind_data.order) { - const auto &quantile = bind_data.quantiles[q]; - Interpolator interp(quantile, state.v.size(), bind_data.desc); - interp.begin = lower; - rdata[ridx + q] = interp.template Operation(v_t, result); - lower = interp.FRN; - } - entry.length = bind_data.quantiles.size(); - - ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); - } - - template - static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, - AggregateInputData &aggr_input_data, STATE &state, const FrameBounds &frame, - const FrameBounds 
&prev, Vector &list, idx_t lidx, idx_t bias) { - D_ASSERT(aggr_input_data.bind_data); - auto &bind_data = aggr_input_data.bind_data->Cast(); - - QuantileIncluded included(fmask, dmask, bias); - - // Result is a constant LIST with a fixed length - auto ldata = FlatVector::GetData(list); - auto &lmask = FlatVector::Validity(list); - auto &lentry = ldata[lidx]; - lentry.offset = ListVector::GetListSize(list); - lentry.length = bind_data.quantiles.size(); - - ListVector::Reserve(list, lentry.offset + lentry.length); - ListVector::SetListSize(list, lentry.offset + lentry.length); - auto &result = ListVector::GetEntry(list); - auto rdata = FlatVector::GetData(result); - - // Lazily initialise frame state - auto prev_pos = state.pos; - state.SetPos(frame.end - frame.start); - - auto index = state.w.data(); - - // We can generalise replacement for quantile lists by observing that when a replacement is - // valid for a single quantile, it is valid for all quantiles greater/less than that quantile - // based on whether the insertion is below/above the quantile location. - // So if a replaced index in an IQR is located between Q25 and Q50, but has a value below Q25, - // then Q25 must be recomputed, but Q50 and Q75 are unaffected. - // For a single element list, this reduces to the scalar case. - std::pair replaceable {state.pos, 0}; - if (frame.start == prev.start + 1 && frame.end == prev.end + 1) { - // Fixed frame size - const auto j = ReplaceIndex(index, frame, prev); - // We can only replace if the number of NULLs has not changed - if (included.AllValid() || included(prev.start) == included(prev.end)) { - for (const auto &q : bind_data.order) { - const auto &quantile = bind_data.quantiles[q]; - Interpolator interp(quantile, prev_pos, false); - const auto replace = CanReplace(index, data, j, interp.FRN, interp.CRN, included); - if (replace < 0) { - // Replacement is before this quantile, so the rest will be replaceable too. 
- replaceable.first = MinValue(replaceable.first, interp.FRN); - replaceable.second = prev_pos; - break; - } else if (replace > 0) { - // Replacement is after this quantile, so everything before it is replaceable too. - replaceable.first = 0; - replaceable.second = MaxValue(replaceable.second, interp.CRN); - } - } - if (replaceable.first < replaceable.second) { - state.pos = prev_pos; - } - } - } else { - ReuseIndexes(index, frame, prev); - } - - if (replaceable.first >= replaceable.second && !included.AllValid()) { - // Remove the NULLs - state.pos = std::partition(index, index + state.pos, included) - index; - } - - if (state.pos) { - using ID = QuantileIndirect; - ID indirect(data); - for (const auto &q : bind_data.order) { - const auto &quantile = bind_data.quantiles[q]; - Interpolator interp(quantile, state.pos, false); - if (replaceable.first <= interp.FRN && interp.CRN <= replaceable.second) { - rdata[lentry.offset + q] = interp.template Replace(index, result, indirect); - } else { - // Make sure we don't disturb any replacements - if (replaceable.first < replaceable.second) { - if (interp.FRN < replaceable.first) { - interp.end = replaceable.first; - } - if (replaceable.second < interp.CRN) { - interp.begin = replaceable.second; - } - } - rdata[lentry.offset + q] = - interp.template Operation(index, result, indirect); - } - } - } else { - lmask.Set(lidx, false); - } - } -}; - -template -AggregateFunction GetTypedDiscreteQuantileListAggregateFunction(const LogicalType &type) { - using STATE = QuantileState; - using OP = QuantileListOperation; - auto fun = QuantileListAggregate(type, type); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - fun.window = AggregateFunction::UnaryWindow; - return fun; -} - -AggregateFunction GetDiscreteQuantileListAggregateFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::SMALLINT: - 
return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::INTEGER: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::BIGINT: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::HUGEINT: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::FLOAT: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::DOUBLE: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case PhysicalType::INT32: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case PhysicalType::INT64: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case PhysicalType::INT128: - return GetTypedDiscreteQuantileListAggregateFunction(type); - default: - throw NotImplementedException("Unimplemented discrete quantile list aggregate"); - } - case LogicalTypeId::DATE: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::INTERVAL: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::VARCHAR: - return GetTypedDiscreteQuantileListAggregateFunction(type); - default: - throw NotImplementedException("Unimplemented discrete quantile list aggregate"); - } -} - -template -AggregateFunction GetTypedContinuousQuantileAggregateFunction(const LogicalType &input_type, - const LogicalType &target_type) { - using STATE = QuantileState; - using OP = QuantileScalarOperation; - auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); - fun.order_dependent = 
AggregateOrderDependent::NOT_ORDER_DEPENDENT; - fun.window = AggregateFunction::UnaryWindow; - return fun; -} - -AggregateFunction GetContinuousQuantileAggregateFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::SMALLINT: - return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::INTEGER: - return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::BIGINT: - return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::HUGEINT: - return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::FLOAT: - return GetTypedContinuousQuantileAggregateFunction(type, type); - case LogicalTypeId::DOUBLE: - return GetTypedContinuousQuantileAggregateFunction(type, type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedContinuousQuantileAggregateFunction(type, type); - case PhysicalType::INT32: - return GetTypedContinuousQuantileAggregateFunction(type, type); - case PhysicalType::INT64: - return GetTypedContinuousQuantileAggregateFunction(type, type); - case PhysicalType::INT128: - return GetTypedContinuousQuantileAggregateFunction(type, type); - default: - throw NotImplementedException("Unimplemented continuous quantile DECIMAL aggregate"); - } - case LogicalTypeId::DATE: - return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::TIMESTAMP); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return GetTypedContinuousQuantileAggregateFunction(type, type); - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return GetTypedContinuousQuantileAggregateFunction(type, type); - - default: - throw NotImplementedException("Unimplemented continuous quantile aggregate"); - } -} - -template 
-AggregateFunction GetTypedContinuousQuantileListAggregateFunction(const LogicalType &input_type, - const LogicalType &result_type) { - using STATE = QuantileState; - using OP = QuantileListOperation; - auto fun = QuantileListAggregate(input_type, result_type); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - fun.window = AggregateFunction::UnaryWindow; - return fun; -} - -AggregateFunction GetContinuousQuantileListAggregateFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::SMALLINT: - return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::INTEGER: - return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::BIGINT: - return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::HUGEINT: - return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); - - case LogicalTypeId::FLOAT: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - case LogicalTypeId::DOUBLE: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - case PhysicalType::INT32: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - case PhysicalType::INT64: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - case PhysicalType::INT128: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - default: - throw NotImplementedException("Unimplemented discrete quantile DECIMAL list aggregate"); - } - break; - - case LogicalTypeId::DATE: - return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::TIMESTAMP); - case 
LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - - default: - throw NotImplementedException("Unimplemented discrete quantile list aggregate"); - } -} - -template -struct MadAccessor { - using INPUT_TYPE = T; - using RESULT_TYPE = R; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const auto delta = input - median; - return TryAbsOperator::Operation(delta); - } -}; - -// hugeint_t - double => undefined -template <> -struct MadAccessor { - using INPUT_TYPE = hugeint_t; - using RESULT_TYPE = double; - using MEDIAN_TYPE = double; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const auto delta = Hugeint::Cast(input) - median; - return TryAbsOperator::Operation(delta); - } -}; - -// date_t - timestamp_t => interval_t -template <> -struct MadAccessor { - using INPUT_TYPE = date_t; - using RESULT_TYPE = interval_t; - using MEDIAN_TYPE = timestamp_t; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const auto dt = Cast::Operation(input); - const auto delta = dt - median; - return Interval::FromMicro(TryAbsOperator::Operation(delta)); - } -}; - -// timestamp_t - timestamp_t => int64_t -template <> -struct MadAccessor { - using INPUT_TYPE = timestamp_t; - using RESULT_TYPE = interval_t; - using MEDIAN_TYPE = timestamp_t; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const 
auto delta = input - median; - return Interval::FromMicro(TryAbsOperator::Operation(delta)); - } -}; - -// dtime_t - dtime_t => int64_t -template <> -struct MadAccessor { - using INPUT_TYPE = dtime_t; - using RESULT_TYPE = interval_t; - using MEDIAN_TYPE = dtime_t; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const auto delta = input - median; - return Interval::FromMicro(TryAbsOperator::Operation(delta)); - } -}; - -template -struct MedianAbsoluteDeviationOperation : public QuantileOperation { - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.v.empty()) { - finalize_data.ReturnNull(); - return; - } - using SAVE_TYPE = typename STATE::SaveType; - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->Cast(); - D_ASSERT(bind_data.quantiles.size() == 1); - const auto &q = bind_data.quantiles[0]; - Interpolator interp(q, state.v.size(), false); - const auto med = interp.template Operation(state.v.data(), finalize_data.result); - - MadAccessor accessor(med); - target = interp.template Operation(state.v.data(), finalize_data.result, accessor); - } - - template - static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, - AggregateInputData &aggr_input_data, STATE &state, const FrameBounds &frame, - const FrameBounds &prev, Vector &result, idx_t ridx, idx_t bias) { - auto rdata = FlatVector::GetData(result); - auto &rmask = FlatVector::Validity(result); - - QuantileIncluded included(fmask, dmask, bias); - - // Lazily initialise frame state - auto prev_pos = state.pos; - state.SetPos(frame.end - frame.start); - - auto index = state.w.data(); - D_ASSERT(index); - - // We need a second index for the second pass. 
- if (state.pos > state.m.size()) { - state.m.resize(state.pos); - } - - auto index2 = state.m.data(); - D_ASSERT(index2); - - // The replacement trick does not work on the second index because if - // the median has changed, the previous order is not correct. - // It is probably close, however, and so reuse is helpful. - ReuseIndexes(index2, frame, prev); - std::partition(index2, index2 + state.pos, included); - - // Find the two positions needed for the median - D_ASSERT(aggr_input_data.bind_data); - auto &bind_data = aggr_input_data.bind_data->Cast(); - D_ASSERT(bind_data.quantiles.size() == 1); - const auto &q = bind_data.quantiles[0]; - - bool replace = false; - if (frame.start == prev.start + 1 && frame.end == prev.end + 1) { - // Fixed frame size - const auto j = ReplaceIndex(index, frame, prev); - // We can only replace if the number of NULLs has not changed - if (included.AllValid() || included(prev.start) == included(prev.end)) { - Interpolator interp(q, prev_pos, false); - replace = CanReplace(index, data, j, interp.FRN, interp.CRN, included); - if (replace) { - state.pos = prev_pos; - } - } - } else { - ReuseIndexes(index, frame, prev); - } - - if (!replace && !included.AllValid()) { - // Remove the NULLs - state.pos = std::partition(index, index + state.pos, included) - index; - } - - if (state.pos) { - Interpolator interp(q, state.pos, false); - - // Compute or replace median from the first index - using ID = QuantileIndirect; - ID indirect(data); - const auto med = replace ? 
interp.template Replace(index, result, indirect) - : interp.template Operation(index, result, indirect); - - // Compute mad from the second index - using MAD = MadAccessor; - MAD mad(med); - - using MadIndirect = QuantileComposed; - MadIndirect mad_indirect(mad, indirect); - rdata[ridx] = interp.template Operation(index2, result, mad_indirect); - } else { - rmask.Set(ridx, false); - } - } -}; - -unique_ptr BindMedian(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - return make_uniq(Value::DECIMAL(int16_t(5), 2, 1)); -} - -template -AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const LogicalType &input_type, - const LogicalType &target_type) { - using STATE = QuantileState; - using OP = MedianAbsoluteDeviationOperation; - auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); - fun.bind = BindMedian; - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - fun.window = AggregateFunction::UnaryWindow; - return fun; -} - -AggregateFunction GetMedianAbsoluteDeviationAggregateFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::FLOAT: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); - case LogicalTypeId::DOUBLE: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); - case PhysicalType::INT32: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); - case PhysicalType::INT64: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); - case PhysicalType::INT128: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); - default: - throw NotImplementedException("Unimplemented Median Absolute Deviation DECIMAL aggregate"); - } - break; - - case LogicalTypeId::DATE: - return 
GetTypedMedianAbsoluteDeviationAggregateFunction(type, - LogicalType::INTERVAL); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return GetTypedMedianAbsoluteDeviationAggregateFunction( - type, LogicalType::INTERVAL); - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, - LogicalType::INTERVAL); - - default: - throw NotImplementedException("Unimplemented Median Absolute Deviation aggregate"); - } -} - -unique_ptr BindMedianDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindMedian(context, function, arguments); - - function = GetDiscreteQuantileAggregateFunction(arguments[0]->return_type); - function.name = "median"; - function.serialize = QuantileBindData::SerializeDecimal; - function.deserialize = QuantileBindData::Deserialize; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return bind_data; -} - -unique_ptr BindMedianAbsoluteDeviationDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetMedianAbsoluteDeviationAggregateFunction(arguments[0]->return_type); - function.name = "mad"; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return BindMedian(context, function, arguments); -} - -static const Value &CheckQuantile(const Value &quantile_val) { - if (quantile_val.IsNull()) { - throw BinderException("QUANTILE parameter cannot be NULL"); - } - auto quantile = quantile_val.GetValue(); - if (quantile < -1 || quantile > 1) { - throw BinderException("QUANTILE can only take parameters in the range [-1, 1]"); - } - if (Value::IsNan(quantile)) { - throw BinderException("QUANTILE parameter cannot be NaN"); - } - - return quantile_val; -} - -unique_ptr BindQuantile(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - 
if (!arguments[1]->IsFoldable()) { - throw BinderException("QUANTILE can only take constant parameters"); - } - Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - vector quantiles; - if (quantile_val.type().id() != LogicalTypeId::LIST) { - quantiles.push_back(CheckQuantile(quantile_val)); - } else { - for (const auto &element_val : ListValue::GetChildren(quantile_val)) { - quantiles.push_back(CheckQuantile(element_val)); - } - } - - Function::EraseArgument(function, arguments, arguments.size() - 1); - return make_uniq(quantiles); -} - -unique_ptr BindDiscreteQuantileDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindQuantile(context, function, arguments); - function = GetDiscreteQuantileAggregateFunction(arguments[0]->return_type); - function.name = "quantile_disc"; - function.serialize = QuantileBindData::SerializeDecimal; - function.deserialize = QuantileBindData::Deserialize; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return bind_data; -} - -unique_ptr BindDiscreteQuantileDecimalList(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindQuantile(context, function, arguments); - function = GetDiscreteQuantileListAggregateFunction(arguments[0]->return_type); - function.name = "quantile_disc"; - function.serialize = QuantileBindData::SerializeDecimal; - function.deserialize = QuantileBindData::Deserialize; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return bind_data; -} - -unique_ptr BindContinuousQuantileDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindQuantile(context, function, arguments); - function = GetContinuousQuantileAggregateFunction(arguments[0]->return_type); - function.name = "quantile_cont"; - function.serialize = QuantileBindData::SerializeDecimal; - function.deserialize = 
QuantileBindData::Deserialize; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return bind_data; -} - -unique_ptr BindContinuousQuantileDecimalList(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindQuantile(context, function, arguments); - function = GetContinuousQuantileListAggregateFunction(arguments[0]->return_type); - function.name = "quantile_cont"; - function.serialize = QuantileBindData::SerializeDecimal; - function.deserialize = QuantileBindData::Deserialize; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return bind_data; -} - -static bool CanInterpolate(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::INTERVAL: - case LogicalTypeId::VARCHAR: - return false; - default: - return true; - } -} - -AggregateFunction GetMedianAggregate(const LogicalType &type) { - auto fun = CanInterpolate(type) ? GetContinuousQuantileAggregateFunction(type) - : GetDiscreteQuantileAggregateFunction(type); - fun.bind = BindMedian; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = QuantileBindData::Deserialize; - return fun; -} - -AggregateFunction GetDiscreteQuantileAggregate(const LogicalType &type) { - auto fun = GetDiscreteQuantileAggregateFunction(type); - fun.bind = BindQuantile; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = QuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - fun.arguments.emplace_back(LogicalType::DOUBLE); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} - -AggregateFunction GetDiscreteQuantileListAggregate(const LogicalType &type) { - auto fun = GetDiscreteQuantileListAggregateFunction(type); - fun.bind = BindQuantile; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = QuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - auto 
list_of_double = LogicalType::LIST(LogicalType::DOUBLE); - fun.arguments.push_back(list_of_double); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} - -AggregateFunction GetContinuousQuantileAggregate(const LogicalType &type) { - auto fun = GetContinuousQuantileAggregateFunction(type); - fun.bind = BindQuantile; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = QuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - fun.arguments.emplace_back(LogicalType::DOUBLE); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} - -AggregateFunction GetContinuousQuantileListAggregate(const LogicalType &type) { - auto fun = GetContinuousQuantileListAggregateFunction(type); - fun.bind = BindQuantile; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = QuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); - fun.arguments.push_back(list_of_double); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} - -AggregateFunction GetQuantileDecimalAggregate(const vector &arguments, const LogicalType &return_type, - bind_aggregate_function_t bind) { - AggregateFunction fun(arguments, return_type, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, bind); - fun.bind = bind; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = QuantileBindData::Deserialize; - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} - -vector GetQuantileTypes() { - return {LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, LogicalType::BIGINT, - LogicalType::HUGEINT, LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE, - LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, LogicalType::TIME_TZ, - LogicalType::INTERVAL, 
LogicalType::VARCHAR}; -} - -AggregateFunctionSet MedianFun::GetFunctions() { - AggregateFunctionSet median("median"); - median.AddFunction( - GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, BindMedianDecimal)); - for (const auto &type : GetQuantileTypes()) { - median.AddFunction(GetMedianAggregate(type)); - } - return median; -} - -AggregateFunctionSet QuantileDiscFun::GetFunctions() { - AggregateFunctionSet quantile_disc("quantile_disc"); - quantile_disc.AddFunction(GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL, LogicalType::DOUBLE}, - LogicalTypeId::DECIMAL, BindDiscreteQuantileDecimal)); - quantile_disc.AddFunction( - GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE)}, - LogicalType::LIST(LogicalTypeId::DECIMAL), BindDiscreteQuantileDecimalList)); - for (const auto &type : GetQuantileTypes()) { - quantile_disc.AddFunction(GetDiscreteQuantileAggregate(type)); - quantile_disc.AddFunction(GetDiscreteQuantileListAggregate(type)); - } - return quantile_disc; - // quantile -} - -AggregateFunctionSet QuantileContFun::GetFunctions() { - AggregateFunctionSet quantile_cont("quantile_cont"); - quantile_cont.AddFunction(GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL, LogicalType::DOUBLE}, - LogicalTypeId::DECIMAL, BindContinuousQuantileDecimal)); - quantile_cont.AddFunction( - GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE)}, - LogicalType::LIST(LogicalTypeId::DECIMAL), BindContinuousQuantileDecimalList)); - - for (const auto &type : GetQuantileTypes()) { - if (CanInterpolate(type)) { - quantile_cont.AddFunction(GetContinuousQuantileAggregate(type)); - quantile_cont.AddFunction(GetContinuousQuantileListAggregate(type)); - } - } - return quantile_cont; -} - -AggregateFunctionSet MadFun::GetFunctions() { - AggregateFunctionSet mad("mad"); - mad.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, 
- nullptr, nullptr, nullptr, BindMedianAbsoluteDeviationDecimal)); - - const vector MAD_TYPES = {LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE, - LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, - LogicalType::TIME_TZ}; - for (const auto &type : MAD_TYPES) { - mad.AddFunction(GetMedianAbsoluteDeviationAggregateFunction(type)); - } - return mad; -} - -} // namespace duckdb - - - - - - - - -#include -#include - -namespace duckdb { - -template -struct ReservoirQuantileState { - T *v; - idx_t len; - idx_t pos; - BaseReservoirSampling *r_samp; - - void Resize(idx_t new_len) { - if (new_len <= len) { - return; - } - T *old_v = v; - v = (T *)realloc(v, new_len * sizeof(T)); - if (!v) { - free(old_v); - throw InternalException("Memory allocation failure"); - } - len = new_len; - } - - void ReplaceElement(T &input) { - v[r_samp->min_entry] = input; - r_samp->ReplaceElement(); - } - - void FillReservoir(idx_t sample_size, T element) { - if (pos < sample_size) { - v[pos++] = element; - r_samp->InitializeReservoir(pos, len); - } else { - D_ASSERT(r_samp->next_index >= r_samp->current_count); - if (r_samp->next_index == r_samp->current_count) { - ReplaceElement(element); - } - } - } -}; - -struct ReservoirQuantileBindData : public FunctionData { - ReservoirQuantileBindData() { - } - ReservoirQuantileBindData(double quantile_p, int32_t sample_size_p) - : quantiles(1, quantile_p), sample_size(sample_size_p) { - } - - ReservoirQuantileBindData(vector quantiles_p, int32_t sample_size_p) - : quantiles(std::move(quantiles_p)), sample_size(sample_size_p) { - } - - unique_ptr Copy() const override { - return make_uniq(quantiles, sample_size); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return quantiles == other.quantiles && sample_size == other.sample_size; - } - - static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - auto &bind_data 
= bind_data_p->Cast(); - serializer.WriteProperty(100, "quantiles", bind_data.quantiles); - serializer.WriteProperty(101, "sample_size", bind_data.sample_size); - } - - static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { - auto result = make_uniq(); - deserializer.ReadProperty(100, "quantiles", result->quantiles); - deserializer.ReadProperty(101, "sample_size", result->sample_size); - return std::move(result); - } - - vector quantiles; - int32_t sample_size; -}; - -struct ReservoirQuantileOperation { - template - static void Initialize(STATE &state) { - state.v = nullptr; - state.len = 0; - state.pos = 0; - state.r_samp = nullptr; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - auto &bind_data = unary_input.input.bind_data->template Cast(); - if (state.pos == 0) { - state.Resize(bind_data.sample_size); - } - if (!state.r_samp) { - state.r_samp = new BaseReservoirSampling(); - } - D_ASSERT(state.v); - state.FillReservoir(bind_data.sample_size, input); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (source.pos == 0) { - return; - } - if (target.pos == 0) { - target.Resize(source.len); - } - if (!target.r_samp) { - target.r_samp = new BaseReservoirSampling(); - } - for (idx_t src_idx = 0; src_idx < source.pos; src_idx++) { - target.FillReservoir(target.len, source.v[src_idx]); - } - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.v) { - free(state.v); - state.v = nullptr; - } - if (state.r_samp) { - delete state.r_samp; - state.r_samp = nullptr; - } - } - - static bool IgnoreNull() { - return true; - } -}; - -struct 
ReservoirQuantileScalarOperation : public ReservoirQuantileOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.pos == 0) { - finalize_data.ReturnNull(); - return; - } - D_ASSERT(state.v); - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->template Cast(); - auto v_t = state.v; - D_ASSERT(bind_data.quantiles.size() == 1); - auto offset = (idx_t)((double)(state.pos - 1) * bind_data.quantiles[0]); - std::nth_element(v_t, v_t + offset, v_t + state.pos); - target = v_t[offset]; - } -}; - -AggregateFunction GetReservoirQuantileAggregateFunction(PhysicalType type) { - switch (type) { - case PhysicalType::INT8: - return AggregateFunction::UnaryAggregateDestructor, int8_t, int8_t, - ReservoirQuantileScalarOperation>(LogicalType::TINYINT, - LogicalType::TINYINT); - - case PhysicalType::INT16: - return AggregateFunction::UnaryAggregateDestructor, int16_t, int16_t, - ReservoirQuantileScalarOperation>(LogicalType::SMALLINT, - LogicalType::SMALLINT); - - case PhysicalType::INT32: - return AggregateFunction::UnaryAggregateDestructor, int32_t, int32_t, - ReservoirQuantileScalarOperation>(LogicalType::INTEGER, - LogicalType::INTEGER); - - case PhysicalType::INT64: - return AggregateFunction::UnaryAggregateDestructor, int64_t, int64_t, - ReservoirQuantileScalarOperation>(LogicalType::BIGINT, - LogicalType::BIGINT); - - case PhysicalType::INT128: - return AggregateFunction::UnaryAggregateDestructor, hugeint_t, hugeint_t, - ReservoirQuantileScalarOperation>(LogicalType::HUGEINT, - LogicalType::HUGEINT); - case PhysicalType::FLOAT: - return AggregateFunction::UnaryAggregateDestructor, float, float, - ReservoirQuantileScalarOperation>(LogicalType::FLOAT, - LogicalType::FLOAT); - case PhysicalType::DOUBLE: - return AggregateFunction::UnaryAggregateDestructor, double, double, - ReservoirQuantileScalarOperation>(LogicalType::DOUBLE, - LogicalType::DOUBLE); - default: - throw 
InternalException("Unimplemented reservoir quantile aggregate"); - } -} - -template -struct ReservoirQuantileListOperation : public ReservoirQuantileOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.pos == 0) { - finalize_data.ReturnNull(); - return; - } - - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->template Cast(); - - auto &result = ListVector::GetEntry(finalize_data.result); - auto ridx = ListVector::GetListSize(finalize_data.result); - ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); - auto rdata = FlatVector::GetData(result); - - auto v_t = state.v; - D_ASSERT(v_t); - - auto &entry = target; - entry.offset = ridx; - entry.length = bind_data.quantiles.size(); - for (size_t q = 0; q < entry.length; ++q) { - const auto &quantile = bind_data.quantiles[q]; - auto offset = (idx_t)((double)(state.pos - 1) * quantile); - std::nth_element(v_t, v_t + offset, v_t + state.pos); - rdata[ridx + q] = v_t[offset]; - } - - ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); - } -}; - -template -static AggregateFunction ReservoirQuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { - LogicalType result_type = LogicalType::LIST(child_type); - return AggregateFunction( - {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, - AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, - nullptr, AggregateFunction::StateDestroy); -} - -template -AggregateFunction GetTypedReservoirQuantileListAggregateFunction(const LogicalType &type) { - using STATE = ReservoirQuantileState; - using OP = ReservoirQuantileListOperation; - auto fun = ReservoirQuantileListAggregate(type, type); - return fun; -} - -AggregateFunction GetReservoirQuantileListAggregateFunction(const 
LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::SMALLINT: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::INTEGER: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::BIGINT: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::HUGEINT: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::FLOAT: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::DOUBLE: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedReservoirQuantileListAggregateFunction(type); - case PhysicalType::INT32: - return GetTypedReservoirQuantileListAggregateFunction(type); - case PhysicalType::INT64: - return GetTypedReservoirQuantileListAggregateFunction(type); - case PhysicalType::INT128: - return GetTypedReservoirQuantileListAggregateFunction(type); - default: - throw NotImplementedException("Unimplemented reservoir quantile list aggregate"); - } - default: - // TODO: Add quantitative temporal types - throw NotImplementedException("Unimplemented reservoir quantile list aggregate"); - } -} - -static double CheckReservoirQuantile(const Value &quantile_val) { - if (quantile_val.IsNull()) { - throw BinderException("RESERVOIR_QUANTILE QUANTILE parameter cannot be NULL"); - } - auto quantile = quantile_val.GetValue(); - if (quantile < 0 || quantile > 1) { - throw BinderException("RESERVOIR_QUANTILE can only take parameters in the range [0, 1]"); - } - return quantile; -} - -unique_ptr BindReservoirQuantile(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - D_ASSERT(arguments.size() >= 2); - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if 
(!arguments[1]->IsFoldable()) { - throw BinderException("RESERVOIR_QUANTILE can only take constant quantile parameters"); - } - Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - vector quantiles; - if (quantile_val.type().id() != LogicalTypeId::LIST) { - quantiles.push_back(CheckReservoirQuantile(quantile_val)); - } else { - for (const auto &element_val : ListValue::GetChildren(quantile_val)) { - quantiles.push_back(CheckReservoirQuantile(element_val)); - } - } - - if (arguments.size() == 2) { - if (function.arguments.size() == 2) { - Function::EraseArgument(function, arguments, arguments.size() - 1); - } else { - arguments.pop_back(); - } - return make_uniq(quantiles, 8192); - } - if (!arguments[2]->IsFoldable()) { - throw BinderException("RESERVOIR_QUANTILE can only take constant sample size parameters"); - } - Value sample_size_val = ExpressionExecutor::EvaluateScalar(context, *arguments[2]); - if (sample_size_val.IsNull()) { - throw BinderException("Size of the RESERVOIR_QUANTILE sample cannot be NULL"); - } - auto sample_size = sample_size_val.GetValue(); - - if (sample_size_val.IsNull() || sample_size <= 0) { - throw BinderException("Size of the RESERVOIR_QUANTILE sample must be bigger than 0"); - } - - // remove the quantile argument so we can use the unary aggregate - Function::EraseArgument(function, arguments, arguments.size() - 1); - Function::EraseArgument(function, arguments, arguments.size() - 1); - return make_uniq(quantiles, sample_size); -} - -unique_ptr BindReservoirQuantileDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetReservoirQuantileAggregateFunction(arguments[0]->return_type.InternalType()); - auto bind_data = BindReservoirQuantile(context, function, arguments); - function.name = "reservoir_quantile"; - function.serialize = ReservoirQuantileBindData::Serialize; - function.deserialize = ReservoirQuantileBindData::Deserialize; - return bind_data; -} - 
-AggregateFunction GetReservoirQuantileAggregate(PhysicalType type) { - auto fun = GetReservoirQuantileAggregateFunction(type); - fun.bind = BindReservoirQuantile; - fun.serialize = ReservoirQuantileBindData::Serialize; - fun.deserialize = ReservoirQuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - fun.arguments.emplace_back(LogicalType::DOUBLE); - return fun; -} - -unique_ptr BindReservoirQuantileDecimalList(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetReservoirQuantileListAggregateFunction(arguments[0]->return_type); - auto bind_data = BindReservoirQuantile(context, function, arguments); - function.serialize = ReservoirQuantileBindData::Serialize; - function.deserialize = ReservoirQuantileBindData::Deserialize; - function.name = "reservoir_quantile"; - return bind_data; -} - -AggregateFunction GetReservoirQuantileListAggregate(const LogicalType &type) { - auto fun = GetReservoirQuantileListAggregateFunction(type); - fun.bind = BindReservoirQuantile; - fun.serialize = ReservoirQuantileBindData::Serialize; - fun.deserialize = ReservoirQuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); - fun.arguments.push_back(list_of_double); - return fun; -} - -static void DefineReservoirQuantile(AggregateFunctionSet &set, const LogicalType &type) { - // Four versions: type, scalar/list[, count] - auto fun = GetReservoirQuantileAggregate(type.InternalType()); - set.AddFunction(fun); - - fun.arguments.emplace_back(LogicalType::INTEGER); - set.AddFunction(fun); - - // List variants - fun = GetReservoirQuantileListAggregate(type); - set.AddFunction(fun); - - fun.arguments.emplace_back(LogicalType::INTEGER); - set.AddFunction(fun); -} - -static void GetReservoirQuantileDecimalFunction(AggregateFunctionSet &set, const vector &arguments, - const LogicalType &return_value) { - 
AggregateFunction fun(arguments, return_value, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - BindReservoirQuantileDecimal); - fun.serialize = ReservoirQuantileBindData::Serialize; - fun.deserialize = ReservoirQuantileBindData::Deserialize; - set.AddFunction(fun); - - fun.arguments.emplace_back(LogicalType::INTEGER); - set.AddFunction(fun); -} - -AggregateFunctionSet ReservoirQuantileFun::GetFunctions() { - AggregateFunctionSet reservoir_quantile; - - // DECIMAL - GetReservoirQuantileDecimalFunction(reservoir_quantile, {LogicalTypeId::DECIMAL, LogicalType::DOUBLE}, - LogicalTypeId::DECIMAL); - GetReservoirQuantileDecimalFunction(reservoir_quantile, - {LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE)}, - LogicalType::LIST(LogicalTypeId::DECIMAL)); - - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::TINYINT); - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::SMALLINT); - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::INTEGER); - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::BIGINT); - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::HUGEINT); - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::FLOAT); - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::DOUBLE); - return reservoir_quantile; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -struct HistogramFunctor { - template > - static void HistogramUpdate(UnifiedVectorFormat &sdata, UnifiedVectorFormat &input_data, idx_t count) { - auto states = (HistogramAggState **)sdata.data; - for (idx_t i = 0; i < count; i++) { - if (input_data.validity.RowIsValid(input_data.sel->get_index(i))) { - auto &state = *states[sdata.sel->get_index(i)]; - if (!state.hist) { - state.hist = new MAP_TYPE(); - } - auto value = UnifiedVectorFormat::GetData(input_data); - (*state.hist)[value[input_data.sel->get_index(i)]]++; - } - } - } - - template - static Value HistogramFinalize(T first) { - return Value::CreateValue(first); - 
} -}; - -struct HistogramStringFunctor { - template > - static void HistogramUpdate(UnifiedVectorFormat &sdata, UnifiedVectorFormat &input_data, idx_t count) { - auto states = (HistogramAggState **)sdata.data; - auto input_strings = UnifiedVectorFormat::GetData(input_data); - for (idx_t i = 0; i < count; i++) { - if (input_data.validity.RowIsValid(input_data.sel->get_index(i))) { - auto &state = *states[sdata.sel->get_index(i)]; - if (!state.hist) { - state.hist = new MAP_TYPE(); - } - (*state.hist)[input_strings[input_data.sel->get_index(i)].GetString()]++; - } - } - } - - template - static Value HistogramFinalize(T first) { - string_t value = first; - return Value::CreateValue(value); - } -}; - -struct HistogramFunction { - template - static void Initialize(STATE &state) { - state.hist = nullptr; - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.hist) { - delete state.hist; - } - } - - static bool IgnoreNull() { - return true; - } -}; - -template -static void HistogramUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, - idx_t count) { - - D_ASSERT(input_count == 1); - - auto &input = inputs[0]; - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - UnifiedVectorFormat input_data; - input.ToUnifiedFormat(count, input_data); - - OP::template HistogramUpdate(sdata, input_data, count); -} - -template -static void HistogramCombineFunction(Vector &state_vector, Vector &combined, AggregateInputData &, idx_t count) { - - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - auto states_ptr = (HistogramAggState **)sdata.data; - - auto combined_ptr = FlatVector::GetData *>(combined); - - for (idx_t i = 0; i < count; i++) { - auto &state = *states_ptr[sdata.sel->get_index(i)]; - if (!state.hist) { - continue; - } - if (!combined_ptr[i]->hist) { - combined_ptr[i]->hist = new MAP_TYPE(); - } - D_ASSERT(combined_ptr[i]->hist); - 
D_ASSERT(state.hist); - for (auto &entry : *state.hist) { - (*combined_ptr[i]->hist)[entry.first] += entry.second; - } - } -} - -template -static void HistogramFinalizeFunction(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, - idx_t offset) { - - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - auto states = (HistogramAggState **)sdata.data; - - auto &mask = FlatVector::Validity(result); - auto old_len = ListVector::GetListSize(result); - - for (idx_t i = 0; i < count; i++) { - const auto rid = i + offset; - auto &state = *states[sdata.sel->get_index(i)]; - if (!state.hist) { - mask.SetInvalid(rid); - continue; - } - - for (auto &entry : *state.hist) { - Value bucket_value = OP::template HistogramFinalize(entry.first); - auto count_value = Value::CreateValue(entry.second); - auto struct_value = - Value::STRUCT({std::make_pair("key", bucket_value), std::make_pair("value", count_value)}); - ListVector::PushBack(result, struct_value); - } - - auto list_struct_data = ListVector::GetData(result); - list_struct_data[rid].length = ListVector::GetListSize(result) - old_len; - list_struct_data[rid].offset = old_len; - old_len += list_struct_data[rid].length; - } - result.Verify(count); -} - -unique_ptr HistogramBindFunction(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - - D_ASSERT(arguments.size() == 1); - - if (arguments[0]->return_type.id() == LogicalTypeId::LIST || - arguments[0]->return_type.id() == LogicalTypeId::STRUCT || - arguments[0]->return_type.id() == LogicalTypeId::MAP) { - throw NotImplementedException("Unimplemented type for histogram %s", arguments[0]->return_type.ToString()); - } - - auto struct_type = LogicalType::MAP(arguments[0]->return_type, LogicalType::UBIGINT); - - function.return_type = struct_type; - return make_uniq(function.return_type); -} - -template > -static AggregateFunction GetHistogramFunction(const LogicalType &type) { - - using STATE_TYPE = 
HistogramAggState; - - return AggregateFunction("histogram", {type}, LogicalTypeId::MAP, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - HistogramUpdateFunction, HistogramCombineFunction, - HistogramFinalizeFunction, nullptr, HistogramBindFunction, - AggregateFunction::StateDestroy); -} - -template -AggregateFunction GetMapType(const LogicalType &type) { - - if (IS_ORDERED) { - return GetHistogramFunction(type); - } - return GetHistogramFunction>(type); -} - -template -AggregateFunction GetHistogramFunction(const LogicalType &type) { - - switch (type.id()) { - case LogicalType::BOOLEAN: - return GetMapType(type); - case LogicalType::UTINYINT: - return GetMapType(type); - case LogicalType::USMALLINT: - return GetMapType(type); - case LogicalType::UINTEGER: - return GetMapType(type); - case LogicalType::UBIGINT: - return GetMapType(type); - case LogicalType::TINYINT: - return GetMapType(type); - case LogicalType::SMALLINT: - return GetMapType(type); - case LogicalType::INTEGER: - return GetMapType(type); - case LogicalType::BIGINT: - return GetMapType(type); - case LogicalType::FLOAT: - return GetMapType(type); - case LogicalType::DOUBLE: - return GetMapType(type); - case LogicalType::VARCHAR: - return GetMapType(type); - case LogicalType::TIMESTAMP: - return GetMapType(type); - case LogicalType::TIMESTAMP_TZ: - return GetMapType(type); - case LogicalType::TIMESTAMP_S: - return GetMapType(type); - case LogicalType::TIMESTAMP_MS: - return GetMapType(type); - case LogicalType::TIMESTAMP_NS: - return GetMapType(type); - case LogicalType::TIME: - return GetMapType(type); - case LogicalType::TIME_TZ: - return GetMapType(type); - case LogicalType::DATE: - return GetMapType(type); - default: - throw InternalException("Unimplemented histogram aggregate"); - } -} - -AggregateFunctionSet HistogramFun::GetFunctions() { - AggregateFunctionSet fun; - fun.AddFunction(GetHistogramFunction<>(LogicalType::BOOLEAN)); - 
fun.AddFunction(GetHistogramFunction<>(LogicalType::UTINYINT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::USMALLINT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::UINTEGER)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::UBIGINT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TINYINT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::SMALLINT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::INTEGER)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::BIGINT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::FLOAT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::DOUBLE)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::VARCHAR)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP_TZ)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP_S)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP_MS)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP_NS)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIME)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIME_TZ)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::DATE)); - return fun; -} - -AggregateFunction HistogramFun::GetHistogramUnorderedMap(LogicalType &type) { - const auto &const_type = type; - return GetHistogramFunction(const_type); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -struct ListBindData : public FunctionData { - explicit ListBindData(const LogicalType &stype_p); - ~ListBindData() override; - - LogicalType stype; - ListSegmentFunctions functions; - - unique_ptr Copy() const override { - return make_uniq(stype); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return stype == other.stype; - } -}; - -ListBindData::ListBindData(const LogicalType &stype_p) : stype(stype_p) { - // always unnest once because the result vector is 
of type LIST - auto type = ListType::GetChildType(stype_p); - GetSegmentDataFunctions(functions, type); -} - -ListBindData::~ListBindData() { -} - -struct ListAggState { - LinkedList linked_list; -}; - -struct ListFunction { - template - static void Initialize(STATE &state) { - state.linked_list.total_capacity = 0; - state.linked_list.first_segment = nullptr; - state.linked_list.last_segment = nullptr; - } - static bool IgnoreNull() { - return false; - } -}; - -static void ListUpdateFunction(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, - Vector &state_vector, idx_t count) { - - D_ASSERT(input_count == 1); - auto &input = inputs[0]; - RecursiveUnifiedVectorFormat input_data; - Vector::RecursiveToUnifiedFormat(input, count, input_data); - - UnifiedVectorFormat states_data; - state_vector.ToUnifiedFormat(count, states_data); - auto states = UnifiedVectorFormat::GetData(states_data); - - auto &list_bind_data = aggr_input_data.bind_data->Cast(); - - for (idx_t i = 0; i < count; i++) { - auto &state = *states[states_data.sel->get_index(i)]; - list_bind_data.functions.AppendRow(aggr_input_data.allocator, state.linked_list, input_data, i); - } -} - -static void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInputData &, idx_t count) { - - UnifiedVectorFormat states_data; - states_vector.ToUnifiedFormat(count, states_data); - auto states_ptr = UnifiedVectorFormat::GetData(states_data); - - auto combined_ptr = FlatVector::GetData(combined); - for (idx_t i = 0; i < count; i++) { - - auto &state = *states_ptr[states_data.sel->get_index(i)]; - if (state.linked_list.total_capacity == 0) { - // NULL, no need to append - // this can happen when adding a FILTER to the grouping, e.g., - // LIST(i) FILTER (WHERE i <> 3) - continue; - } - - if (combined_ptr[i]->linked_list.total_capacity == 0) { - combined_ptr[i]->linked_list = state.linked_list; - continue; - } - - // append the linked list - 
combined_ptr[i]->linked_list.last_segment->next = state.linked_list.first_segment; - combined_ptr[i]->linked_list.last_segment = state.linked_list.last_segment; - combined_ptr[i]->linked_list.total_capacity += state.linked_list.total_capacity; - } -} - -static void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_data, Vector &result, idx_t count, - idx_t offset) { - - UnifiedVectorFormat states_data; - states_vector.ToUnifiedFormat(count, states_data); - auto states = UnifiedVectorFormat::GetData(states_data); - - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - - auto &mask = FlatVector::Validity(result); - auto result_data = FlatVector::GetData(result); - size_t total_len = ListVector::GetListSize(result); - - auto &list_bind_data = aggr_input_data.bind_data->Cast(); - - // first iterate over all entries and set up the list entries, and get the newly required total length - for (idx_t i = 0; i < count; i++) { - - auto &state = *states[states_data.sel->get_index(i)]; - const auto rid = i + offset; - result_data[rid].offset = total_len; - if (state.linked_list.total_capacity == 0) { - mask.SetInvalid(rid); - result_data[rid].length = 0; - continue; - } - - // set the length and offset of this list in the result vector - auto total_capacity = state.linked_list.total_capacity; - result_data[rid].length = total_capacity; - total_len += total_capacity; - } - - // reserve capacity, then iterate over all entries again and copy over the data to the child vector - ListVector::Reserve(result, total_len); - auto &result_child = ListVector::GetEntry(result); - for (idx_t i = 0; i < count; i++) { - - auto &state = *states[states_data.sel->get_index(i)]; - const auto rid = i + offset; - if (state.linked_list.total_capacity == 0) { - continue; - } - - idx_t current_offset = result_data[rid].offset; - list_bind_data.functions.BuildListVector(state.linked_list, result_child, current_offset); - } - - ListVector::SetListSize(result, total_len); -} - 
-static void ListWindow(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, - idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, - Vector &result, idx_t rid, idx_t bias) { - - auto &list_bind_data = aggr_input_data.bind_data->Cast(); - LinkedList linked_list; - - // UPDATE step - - D_ASSERT(input_count == 1); - auto &input = inputs[0]; - - // FIXME: we unify more values than necessary (count is frame.end) - RecursiveUnifiedVectorFormat input_data; - Vector::RecursiveToUnifiedFormat(input, frame.end, input_data); - - for (idx_t i = frame.start; i < frame.end; i++) { - list_bind_data.functions.AppendRow(aggr_input_data.allocator, linked_list, input_data, i); - } - - // FINALIZE step - - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - auto result_data = FlatVector::GetData(result); - size_t total_len = ListVector::GetListSize(result); - - // set the length and offset of this list in the result vector - result_data[rid].offset = total_len; - result_data[rid].length = linked_list.total_capacity; - D_ASSERT(linked_list.total_capacity != 0); - total_len += linked_list.total_capacity; - - // reserve capacity, then copy over the data to the child vector - ListVector::Reserve(result, total_len); - auto &result_child = ListVector::GetEntry(result); - idx_t offset = result_data[rid].offset; - list_bind_data.functions.BuildListVector(linked_list, result_child, offset); - - ListVector::SetListSize(result, total_len); -} - -unique_ptr ListBindFunction(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - D_ASSERT(arguments.size() == 1); - D_ASSERT(function.arguments.size() == 1); - - if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - function.arguments[0] = LogicalTypeId::UNKNOWN; - function.return_type = LogicalType::SQLNULL; - return nullptr; - } - - function.return_type = LogicalType::LIST(arguments[0]->return_type); - return 
make_uniq(function.return_type); -} - -AggregateFunction ListFun::GetFunction() { - return AggregateFunction({LogicalType::ANY}, LogicalTypeId::LIST, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, ListUpdateFunction, - ListCombineFunction, ListFinalize, nullptr, ListBindFunction, nullptr, nullptr, - ListWindow); -} - -} // namespace duckdb - - - - - - -namespace duckdb { -struct RegrState { - double sum; - size_t count; -}; - -struct RegrAvgFunction { - template - static void Initialize(STATE &state) { - state.sum = 0; - state.count = 0; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.sum += source.sum; - target.count += source.count; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.count == 0) { - finalize_data.ReturnNull(); - } else { - target = state.sum / (double)state.count; - } - } - static bool IgnoreNull() { - return true; - } -}; -struct RegrAvgXFunction : RegrAvgFunction { - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - state.sum += x; - state.count++; - } -}; - -struct RegrAvgYFunction : RegrAvgFunction { - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - state.sum += y; - state.count++; - } -}; - -AggregateFunction RegrAvgxFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -AggregateFunction RegrAvgyFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -AggregateFunction RegrCountFun::GetFunction() { - auto regr_count = AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::UINTEGER); - regr_count.name = 
"regr_count"; - regr_count.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return regr_count; -} - -} // namespace duckdb -//! AVG(y)-REGR_SLOPE(y,x)*AVG(x) - - - - - -namespace duckdb { - -struct RegrInterceptState { - size_t count; - double sum_x; - double sum_y; - RegrSlopeState slope; -}; - -struct RegrInterceptOperation { - template - static void Initialize(STATE &state) { - state.count = 0; - state.sum_x = 0; - state.sum_y = 0; - RegrSlopeOperation::Initialize(state.slope); - } - - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - state.count++; - state.sum_x += x; - state.sum_y += y; - RegrSlopeOperation::Operation(state.slope, y, x, idata); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - target.count += source.count; - target.sum_x += source.sum_x; - target.sum_y += source.sum_y; - RegrSlopeOperation::Combine(source.slope, target.slope, aggr_input_data); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.count == 0) { - finalize_data.ReturnNull(); - return; - } - RegrSlopeOperation::Finalize(state.slope, target, finalize_data); - auto x_avg = state.sum_x / state.count; - auto y_avg = state.sum_y / state.count; - target = y_avg - target * x_avg; - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction RegrInterceptFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb -// REGR_R2(y, x) -// Returns the coefficient of determination for non-null pairs in a group. 
-// It is computed for non-null pairs using the following formula: -// null if var_pop(x) = 0, else -// 1 if var_pop(y) = 0 and var_pop(x) <> 0, else -// power(corr(y,x), 2) - - - - - -namespace duckdb { -struct RegrR2State { - CorrState corr; - StddevState var_pop_x; - StddevState var_pop_y; -}; - -struct RegrR2Operation { - template - static void Initialize(STATE &state) { - CorrOperation::Initialize(state.corr); - STDDevBaseOperation::Initialize(state.var_pop_x); - STDDevBaseOperation::Initialize(state.var_pop_y); - } - - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - CorrOperation::Operation(state.corr, y, x, idata); - STDDevBaseOperation::Execute(state.var_pop_x, x); - STDDevBaseOperation::Execute(state.var_pop_y, y); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - CorrOperation::Combine(source.corr, target.corr, aggr_input_data); - STDDevBaseOperation::Combine(source.var_pop_x, target.var_pop_x, aggr_input_data); - STDDevBaseOperation::Combine(source.var_pop_y, target.var_pop_y, aggr_input_data); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - auto var_pop_x = state.var_pop_x.count > 1 ? (state.var_pop_x.dsquared / state.var_pop_x.count) : 0; - if (!Value::DoubleIsFinite(var_pop_x)) { - throw OutOfRangeException("VARPOP(X) is out of range!"); - } - if (var_pop_x == 0) { - finalize_data.ReturnNull(); - return; - } - auto var_pop_y = state.var_pop_y.count > 1 ? 
(state.var_pop_y.dsquared / state.var_pop_y.count) : 0; - if (!Value::DoubleIsFinite(var_pop_y)) { - throw OutOfRangeException("VARPOP(Y) is out of range!"); - } - if (var_pop_y == 0) { - target = 1; - return; - } - CorrOperation::Finalize(state.corr, target, finalize_data); - target = pow(target, 2); - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction RegrR2Fun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} -} // namespace duckdb -// REGR_SLOPE(y, x) -// Returns the slope of the linear regression line for non-null pairs in a group. -// It is computed for non-null pairs using the following formula: -// COVAR_POP(x,y) / VAR_POP(x) - -//! Input : Any numeric type -//! Output : Double - - - - - -namespace duckdb { - -AggregateFunction RegrSlopeFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb -// REGR_SXX(y, x) -// Returns REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs. -// REGR_SYY(y, x) -// Returns REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs. - - - - - -namespace duckdb { - -struct RegrSState { - size_t count; - StddevState var_pop; -}; - -struct RegrBaseOperation { - template - static void Initialize(STATE &state) { - RegrCountFunction::Initialize(state.count); - STDDevBaseOperation::Initialize(state.var_pop); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - RegrCountFunction::Combine(source.count, target.count, aggr_input_data); - STDDevBaseOperation::Combine(source.var_pop, target.var_pop, aggr_input_data); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.var_pop.count == 0) { - finalize_data.ReturnNull(); - return; - } - auto var_pop = state.var_pop.count > 1 ? 
(state.var_pop.dsquared / state.var_pop.count) : 0; - if (!Value::DoubleIsFinite(var_pop)) { - throw OutOfRangeException("VARPOP is out of range!"); - } - RegrCountFunction::Finalize(state.count, target, finalize_data); - target *= var_pop; - } - - static bool IgnoreNull() { - return true; - } -}; - -struct RegrSXXOperation : RegrBaseOperation { - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - RegrCountFunction::Operation(state.count, y, x, idata); - STDDevBaseOperation::Execute(state.var_pop, x); - } -}; - -struct RegrSYYOperation : RegrBaseOperation { - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - RegrCountFunction::Operation(state.count, y, x, idata); - STDDevBaseOperation::Execute(state.var_pop, y); - } -}; - -AggregateFunction RegrSXXFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -AggregateFunction RegrSYYFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb -// REGR_SXY(y, x) -// Returns REGR_COUNT(expr1, expr2) * COVAR_POP(expr1, expr2) for non-null pairs. 
- - - - - - -namespace duckdb { - -struct RegrSXyState { - size_t count; - CovarState cov_pop; -}; - -struct RegrSXYOperation { - template - static void Initialize(STATE &state) { - RegrCountFunction::Initialize(state.count); - CovarOperation::Initialize(state.cov_pop); - } - - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - RegrCountFunction::Operation(state.count, y, x, idata); - CovarOperation::Operation(state.cov_pop, y, x, idata); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - CovarOperation::Combine(source.cov_pop, target.cov_pop, aggr_input_data); - RegrCountFunction::Combine(source.count, target.count, aggr_input_data); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - CovarPopOperation::Finalize(state.cov_pop, target, finalize_data); - auto cov_pop = target; - RegrCountFunction::Finalize(state.count, target, finalize_data); - target *= cov_pop; - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction RegrSXYFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -template -void FillExtraInfo(StaticFunctionDefinition &function, T &info) { - info.internal = true; - info.description = function.description; - info.parameter_names = StringUtil::Split(function.parameters, ","); - info.example = function.example; -} - -void CoreFunctions::RegisterFunctions(Catalog &catalog, CatalogTransaction transaction) { - auto functions = StaticFunctionDefinition::GetFunctionList(); - for (idx_t i = 0; functions[i].name; i++) { - auto &function = functions[i]; - if (function.get_function || function.get_function_set) { - // scalar function - ScalarFunctionSet result; - if (function.get_function) { - 
result.AddFunction(function.get_function()); - } else { - result = function.get_function_set(); - } - result.name = function.name; - CreateScalarFunctionInfo info(result); - FillExtraInfo(function, info); - catalog.CreateFunction(transaction, info); - } else if (function.get_aggregate_function || function.get_aggregate_function_set) { - // aggregate function - AggregateFunctionSet result; - if (function.get_aggregate_function) { - result.AddFunction(function.get_aggregate_function()); - } else { - result = function.get_aggregate_function_set(); - } - result.name = function.name; - CreateAggregateFunctionInfo info(result); - FillExtraInfo(function, info); - catalog.CreateFunction(transaction, info); - } else { - throw InternalException("Do not know how to register function of this type"); - } - } -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -// Scalar Function -#define DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _NAME) \ - { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::GetFunction, nullptr, nullptr, nullptr } -#define DUCKDB_SCALAR_FUNCTION(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _PARAM::Name) -#define DUCKDB_SCALAR_FUNCTION_ALIAS(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name) -// Scalar Function Set -#define DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _NAME) \ - { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, _PARAM::GetFunctions, nullptr, nullptr } -#define DUCKDB_SCALAR_FUNCTION_SET(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _PARAM::Name) -#define DUCKDB_SCALAR_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name) -// Aggregate Function -#define DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _NAME) \ - { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, nullptr, _PARAM::GetFunction, nullptr } -#define DUCKDB_AGGREGATE_FUNCTION(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _PARAM::Name) -#define 
DUCKDB_AGGREGATE_FUNCTION_ALIAS(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name) -// Aggregate Function Set -#define DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _NAME) \ - { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, nullptr, nullptr, _PARAM::GetFunctions } -#define DUCKDB_AGGREGATE_FUNCTION_SET(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _PARAM::Name) -#define DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name) -#define FINAL_FUNCTION \ - { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr } - -// this list is generated by scripts/generate_functions.py -static StaticFunctionDefinition internal_functions[] = { - DUCKDB_SCALAR_FUNCTION(FactorialOperatorFun), - DUCKDB_SCALAR_FUNCTION_SET(BitwiseAndFun), - DUCKDB_SCALAR_FUNCTION(PowOperatorFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListInnerProductFunAlias), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListDistanceFunAlias), - DUCKDB_SCALAR_FUNCTION_SET(LeftShiftFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListCosineSimilarityFunAlias), - DUCKDB_SCALAR_FUNCTION_SET(RightShiftFun), - DUCKDB_SCALAR_FUNCTION_SET(AbsOperatorFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(PowOperatorFunAlias), - DUCKDB_SCALAR_FUNCTION(StartsWithOperatorFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(AbsFun), - DUCKDB_SCALAR_FUNCTION(AcosFun), - DUCKDB_SCALAR_FUNCTION_SET(AgeFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(AggregateFun), - DUCKDB_SCALAR_FUNCTION(AliasFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ApplyFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ApproxCountDistinctFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ApproxQuantileFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgmaxFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgminFun), - DUCKDB_AGGREGATE_FUNCTION_ALIAS(ArrayAggFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayAggrFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayAggregateFun), - 
DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayApplyFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayDistinctFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayFilterFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayReverseSortFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArraySliceFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArraySortFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayTransformFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayUniqueFun), - DUCKDB_SCALAR_FUNCTION(ASCIIFun), - DUCKDB_SCALAR_FUNCTION(AsinFun), - DUCKDB_SCALAR_FUNCTION(AtanFun), - DUCKDB_SCALAR_FUNCTION(Atan2Fun), - DUCKDB_AGGREGATE_FUNCTION_SET(AvgFun), - DUCKDB_SCALAR_FUNCTION_SET(BarFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(Base64Fun), - DUCKDB_SCALAR_FUNCTION_SET(BinFun), - DUCKDB_AGGREGATE_FUNCTION_SET(BitAndFun), - DUCKDB_SCALAR_FUNCTION_SET(BitCountFun), - DUCKDB_AGGREGATE_FUNCTION_SET(BitOrFun), - DUCKDB_SCALAR_FUNCTION(BitPositionFun), - DUCKDB_AGGREGATE_FUNCTION_SET(BitXorFun), - DUCKDB_SCALAR_FUNCTION(BitStringFun), - DUCKDB_AGGREGATE_FUNCTION_SET(BitstringAggFun), - DUCKDB_AGGREGATE_FUNCTION(BoolAndFun), - DUCKDB_AGGREGATE_FUNCTION(BoolOrFun), - DUCKDB_SCALAR_FUNCTION(CardinalityFun), - DUCKDB_SCALAR_FUNCTION(CbrtFun), - DUCKDB_SCALAR_FUNCTION_SET(CeilFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(CeilingFun), - DUCKDB_SCALAR_FUNCTION_SET(CenturyFun), - DUCKDB_SCALAR_FUNCTION(ChrFun), - DUCKDB_AGGREGATE_FUNCTION(CorrFun), - DUCKDB_SCALAR_FUNCTION(CosFun), - DUCKDB_SCALAR_FUNCTION(CotFun), - DUCKDB_AGGREGATE_FUNCTION(CovarPopFun), - DUCKDB_AGGREGATE_FUNCTION(CovarSampFun), - DUCKDB_SCALAR_FUNCTION(CurrentDatabaseFun), - DUCKDB_SCALAR_FUNCTION(CurrentDateFun), - DUCKDB_SCALAR_FUNCTION(CurrentQueryFun), - DUCKDB_SCALAR_FUNCTION(CurrentSchemaFun), - DUCKDB_SCALAR_FUNCTION(CurrentSchemasFun), - DUCKDB_SCALAR_FUNCTION(CurrentSettingFun), - DUCKDB_SCALAR_FUNCTION(DamerauLevenshteinFun), - DUCKDB_SCALAR_FUNCTION_SET(DateDiffFun), - DUCKDB_SCALAR_FUNCTION_SET(DatePartFun), - DUCKDB_SCALAR_FUNCTION_SET(DateSubFun), - 
DUCKDB_SCALAR_FUNCTION_SET(DateTruncFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatediffFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatepartFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatesubFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatetruncFun), - DUCKDB_SCALAR_FUNCTION_SET(DayFun), - DUCKDB_SCALAR_FUNCTION_SET(DayNameFun), - DUCKDB_SCALAR_FUNCTION_SET(DayOfMonthFun), - DUCKDB_SCALAR_FUNCTION_SET(DayOfWeekFun), - DUCKDB_SCALAR_FUNCTION_SET(DayOfYearFun), - DUCKDB_SCALAR_FUNCTION_SET(DecadeFun), - DUCKDB_SCALAR_FUNCTION(DecodeFun), - DUCKDB_SCALAR_FUNCTION(DegreesFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(Editdist3Fun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ElementAtFun), - DUCKDB_SCALAR_FUNCTION(EncodeFun), - DUCKDB_AGGREGATE_FUNCTION_SET(EntropyFun), - DUCKDB_SCALAR_FUNCTION(EnumCodeFun), - DUCKDB_SCALAR_FUNCTION(EnumFirstFun), - DUCKDB_SCALAR_FUNCTION(EnumLastFun), - DUCKDB_SCALAR_FUNCTION(EnumRangeFun), - DUCKDB_SCALAR_FUNCTION(EnumRangeBoundaryFun), - DUCKDB_SCALAR_FUNCTION_SET(EpochFun), - DUCKDB_SCALAR_FUNCTION_SET(EpochMsFun), - DUCKDB_SCALAR_FUNCTION_SET(EpochNsFun), - DUCKDB_SCALAR_FUNCTION_SET(EpochUsFun), - DUCKDB_SCALAR_FUNCTION_SET(EraFun), - DUCKDB_SCALAR_FUNCTION(ErrorFun), - DUCKDB_SCALAR_FUNCTION(EvenFun), - DUCKDB_SCALAR_FUNCTION(ExpFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(FactorialFun), - DUCKDB_AGGREGATE_FUNCTION(FAvgFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(FilterFun), - DUCKDB_SCALAR_FUNCTION(ListFlattenFun), - DUCKDB_SCALAR_FUNCTION_SET(FloorFun), - DUCKDB_SCALAR_FUNCTION(FormatFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(FormatreadabledecimalsizeFun), - DUCKDB_SCALAR_FUNCTION(FormatBytesFun), - DUCKDB_SCALAR_FUNCTION(FromBase64Fun), - DUCKDB_SCALAR_FUNCTION_ALIAS(FromBinaryFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(FromHexFun), - DUCKDB_AGGREGATE_FUNCTION_ALIAS(FsumFun), - DUCKDB_SCALAR_FUNCTION(GammaFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(GcdFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(GenRandomUuidFun), - DUCKDB_SCALAR_FUNCTION_SET(GenerateSeriesFun), - DUCKDB_SCALAR_FUNCTION(GetBitFun), - 
DUCKDB_SCALAR_FUNCTION(CurrentTimeFun), - DUCKDB_SCALAR_FUNCTION(GetCurrentTimestampFun), - DUCKDB_SCALAR_FUNCTION_SET(GreatestFun), - DUCKDB_SCALAR_FUNCTION_SET(GreatestCommonDivisorFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(GroupConcatFun), - DUCKDB_SCALAR_FUNCTION(HammingFun), - DUCKDB_SCALAR_FUNCTION(HashFun), - DUCKDB_SCALAR_FUNCTION_SET(HexFun), - DUCKDB_AGGREGATE_FUNCTION_SET(HistogramFun), - DUCKDB_SCALAR_FUNCTION_SET(HoursFun), - DUCKDB_SCALAR_FUNCTION(InSearchPathFun), - DUCKDB_SCALAR_FUNCTION(InstrFun), - DUCKDB_SCALAR_FUNCTION_SET(IsFiniteFun), - DUCKDB_SCALAR_FUNCTION_SET(IsInfiniteFun), - DUCKDB_SCALAR_FUNCTION_SET(IsNanFun), - DUCKDB_SCALAR_FUNCTION_SET(ISODayOfWeekFun), - DUCKDB_SCALAR_FUNCTION_SET(ISOYearFun), - DUCKDB_SCALAR_FUNCTION(JaccardFun), - DUCKDB_SCALAR_FUNCTION(JaroSimilarityFun), - DUCKDB_SCALAR_FUNCTION(JaroWinklerSimilarityFun), - DUCKDB_SCALAR_FUNCTION_SET(JulianDayFun), - DUCKDB_AGGREGATE_FUNCTION(KahanSumFun), - DUCKDB_AGGREGATE_FUNCTION(KurtosisFun), - DUCKDB_SCALAR_FUNCTION_SET(LastDayFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(LcmFun), - DUCKDB_SCALAR_FUNCTION_SET(LeastFun), - DUCKDB_SCALAR_FUNCTION_SET(LeastCommonMultipleFun), - DUCKDB_SCALAR_FUNCTION(LeftFun), - DUCKDB_SCALAR_FUNCTION(LeftGraphemeFun), - DUCKDB_SCALAR_FUNCTION(LevenshteinFun), - DUCKDB_SCALAR_FUNCTION(LogGammaFun), - DUCKDB_AGGREGATE_FUNCTION(ListFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ListAggrFun), - DUCKDB_SCALAR_FUNCTION(ListAggregateFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ListApplyFun), - DUCKDB_SCALAR_FUNCTION_SET(ListCosineSimilarityFun), - DUCKDB_SCALAR_FUNCTION_SET(ListDistanceFun), - DUCKDB_SCALAR_FUNCTION(ListDistinctFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListDotProductFun), - DUCKDB_SCALAR_FUNCTION(ListFilterFun), - DUCKDB_SCALAR_FUNCTION_SET(ListInnerProductFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ListPackFun), - DUCKDB_SCALAR_FUNCTION_SET(ListReverseSortFun), - DUCKDB_SCALAR_FUNCTION_SET(ListSliceFun), - DUCKDB_SCALAR_FUNCTION_SET(ListSortFun), - 
DUCKDB_SCALAR_FUNCTION(ListTransformFun), - DUCKDB_SCALAR_FUNCTION(ListUniqueFun), - DUCKDB_SCALAR_FUNCTION(ListValueFun), - DUCKDB_SCALAR_FUNCTION(LnFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(LogFun), - DUCKDB_SCALAR_FUNCTION(Log10Fun), - DUCKDB_SCALAR_FUNCTION(Log2Fun), - DUCKDB_SCALAR_FUNCTION(LpadFun), - DUCKDB_SCALAR_FUNCTION_SET(LtrimFun), - DUCKDB_AGGREGATE_FUNCTION_SET(MadFun), - DUCKDB_SCALAR_FUNCTION_SET(MakeDateFun), - DUCKDB_SCALAR_FUNCTION(MakeTimeFun), - DUCKDB_SCALAR_FUNCTION_SET(MakeTimestampFun), - DUCKDB_SCALAR_FUNCTION(MapFun), - DUCKDB_SCALAR_FUNCTION(MapConcatFun), - DUCKDB_SCALAR_FUNCTION(MapEntriesFun), - DUCKDB_SCALAR_FUNCTION(MapExtractFun), - DUCKDB_SCALAR_FUNCTION(MapFromEntriesFun), - DUCKDB_SCALAR_FUNCTION(MapKeysFun), - DUCKDB_SCALAR_FUNCTION(MapValuesFun), - DUCKDB_AGGREGATE_FUNCTION_SET(MaxFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MaxByFun), - DUCKDB_SCALAR_FUNCTION(MD5Fun), - DUCKDB_SCALAR_FUNCTION(MD5NumberFun), - DUCKDB_SCALAR_FUNCTION(MD5NumberLowerFun), - DUCKDB_SCALAR_FUNCTION(MD5NumberUpperFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MeanFun), - DUCKDB_AGGREGATE_FUNCTION_SET(MedianFun), - DUCKDB_SCALAR_FUNCTION_SET(MicrosecondsFun), - DUCKDB_SCALAR_FUNCTION_SET(MillenniumFun), - DUCKDB_SCALAR_FUNCTION_SET(MillisecondsFun), - DUCKDB_AGGREGATE_FUNCTION_SET(MinFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MinByFun), - DUCKDB_SCALAR_FUNCTION_SET(MinutesFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(MismatchesFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ModeFun), - DUCKDB_SCALAR_FUNCTION_SET(MonthFun), - DUCKDB_SCALAR_FUNCTION_SET(MonthNameFun), - DUCKDB_SCALAR_FUNCTION_SET(NextAfterFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(NowFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(OrdFun), - DUCKDB_SCALAR_FUNCTION(PiFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(PositionFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(PowFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(PowerFun), - DUCKDB_SCALAR_FUNCTION(PrintfFun), - DUCKDB_AGGREGATE_FUNCTION(ProductFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(QuantileFun), - 
DUCKDB_AGGREGATE_FUNCTION_SET(QuantileContFun), - DUCKDB_AGGREGATE_FUNCTION_SET(QuantileDiscFun), - DUCKDB_SCALAR_FUNCTION_SET(QuarterFun), - DUCKDB_SCALAR_FUNCTION(RadiansFun), - DUCKDB_SCALAR_FUNCTION(RandomFun), - DUCKDB_SCALAR_FUNCTION_SET(ListRangeFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(RegexpSplitToArrayFun), - DUCKDB_AGGREGATE_FUNCTION(RegrAvgxFun), - DUCKDB_AGGREGATE_FUNCTION(RegrAvgyFun), - DUCKDB_AGGREGATE_FUNCTION(RegrCountFun), - DUCKDB_AGGREGATE_FUNCTION(RegrInterceptFun), - DUCKDB_AGGREGATE_FUNCTION(RegrR2Fun), - DUCKDB_AGGREGATE_FUNCTION(RegrSlopeFun), - DUCKDB_AGGREGATE_FUNCTION(RegrSXXFun), - DUCKDB_AGGREGATE_FUNCTION(RegrSXYFun), - DUCKDB_AGGREGATE_FUNCTION(RegrSYYFun), - DUCKDB_SCALAR_FUNCTION_SET(RepeatFun), - DUCKDB_SCALAR_FUNCTION(ReplaceFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ReservoirQuantileFun), - DUCKDB_SCALAR_FUNCTION(ReverseFun), - DUCKDB_SCALAR_FUNCTION(RightFun), - DUCKDB_SCALAR_FUNCTION(RightGraphemeFun), - DUCKDB_SCALAR_FUNCTION_SET(RoundFun), - DUCKDB_SCALAR_FUNCTION(RowFun), - DUCKDB_SCALAR_FUNCTION(RpadFun), - DUCKDB_SCALAR_FUNCTION_SET(RtrimFun), - DUCKDB_SCALAR_FUNCTION_SET(SecondsFun), - DUCKDB_AGGREGATE_FUNCTION(StandardErrorOfTheMeanFun), - DUCKDB_SCALAR_FUNCTION(SetBitFun), - DUCKDB_SCALAR_FUNCTION(SetseedFun), - DUCKDB_SCALAR_FUNCTION(SHA256Fun), - DUCKDB_SCALAR_FUNCTION_SET(SignFun), - DUCKDB_SCALAR_FUNCTION_SET(SignBitFun), - DUCKDB_SCALAR_FUNCTION(SinFun), - DUCKDB_AGGREGATE_FUNCTION(SkewnessFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(SplitFun), - DUCKDB_SCALAR_FUNCTION(SqrtFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(StartsWithFun), - DUCKDB_SCALAR_FUNCTION(StatsFun), - DUCKDB_AGGREGATE_FUNCTION_ALIAS(StddevFun), - DUCKDB_AGGREGATE_FUNCTION(StdDevPopFun), - DUCKDB_AGGREGATE_FUNCTION(StdDevSampFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(StrSplitFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(StrSplitRegexFun), - DUCKDB_SCALAR_FUNCTION_SET(StrfTimeFun), - DUCKDB_AGGREGATE_FUNCTION_SET(StringAggFun), - DUCKDB_SCALAR_FUNCTION(StringSplitFun), - 
DUCKDB_SCALAR_FUNCTION_SET(StringSplitRegexFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(StringToArrayFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(StrposFun), - DUCKDB_SCALAR_FUNCTION_SET(StrpTimeFun), - DUCKDB_SCALAR_FUNCTION(StructInsertFun), - DUCKDB_SCALAR_FUNCTION(StructPackFun), - DUCKDB_AGGREGATE_FUNCTION_SET(SumFun), - DUCKDB_AGGREGATE_FUNCTION_SET(SumNoOverflowFun), - DUCKDB_AGGREGATE_FUNCTION_ALIAS(SumkahanFun), - DUCKDB_SCALAR_FUNCTION(TanFun), - DUCKDB_SCALAR_FUNCTION_SET(TimeBucketFun), - DUCKDB_SCALAR_FUNCTION_SET(TimezoneFun), - DUCKDB_SCALAR_FUNCTION_SET(TimezoneHourFun), - DUCKDB_SCALAR_FUNCTION_SET(TimezoneMinuteFun), - DUCKDB_SCALAR_FUNCTION_SET(ToBaseFun), - DUCKDB_SCALAR_FUNCTION(ToBase64Fun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ToBinaryFun), - DUCKDB_SCALAR_FUNCTION(ToDaysFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ToHexFun), - DUCKDB_SCALAR_FUNCTION(ToHoursFun), - DUCKDB_SCALAR_FUNCTION(ToMicrosecondsFun), - DUCKDB_SCALAR_FUNCTION(ToMillisecondsFun), - DUCKDB_SCALAR_FUNCTION(ToMinutesFun), - DUCKDB_SCALAR_FUNCTION(ToMonthsFun), - DUCKDB_SCALAR_FUNCTION(ToSecondsFun), - DUCKDB_SCALAR_FUNCTION(ToTimestampFun), - DUCKDB_SCALAR_FUNCTION(ToYearsFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(TodayFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(TransactionTimestampFun), - DUCKDB_SCALAR_FUNCTION(TranslateFun), - DUCKDB_SCALAR_FUNCTION_SET(TrimFun), - DUCKDB_SCALAR_FUNCTION_SET(TruncFun), - DUCKDB_SCALAR_FUNCTION_SET(TryStrpTimeFun), - DUCKDB_SCALAR_FUNCTION(CurrentTransactionIdFun), - DUCKDB_SCALAR_FUNCTION(TypeOfFun), - DUCKDB_SCALAR_FUNCTION(UnbinFun), - DUCKDB_SCALAR_FUNCTION(UnhexFun), - DUCKDB_SCALAR_FUNCTION(UnicodeFun), - DUCKDB_SCALAR_FUNCTION(UnionExtractFun), - DUCKDB_SCALAR_FUNCTION(UnionTagFun), - DUCKDB_SCALAR_FUNCTION(UnionValueFun), - DUCKDB_SCALAR_FUNCTION(UUIDFun), - DUCKDB_AGGREGATE_FUNCTION(VarPopFun), - DUCKDB_AGGREGATE_FUNCTION(VarSampFun), - DUCKDB_AGGREGATE_FUNCTION_ALIAS(VarianceFun), - DUCKDB_SCALAR_FUNCTION(VectorTypeFun), - DUCKDB_SCALAR_FUNCTION(VersionFun), - 
DUCKDB_SCALAR_FUNCTION_SET(WeekFun), - DUCKDB_SCALAR_FUNCTION_SET(WeekDayFun), - DUCKDB_SCALAR_FUNCTION_SET(WeekOfYearFun), - DUCKDB_SCALAR_FUNCTION_SET(BitwiseXorFun), - DUCKDB_SCALAR_FUNCTION_SET(YearFun), - DUCKDB_SCALAR_FUNCTION_SET(YearWeekFun), - DUCKDB_SCALAR_FUNCTION_SET(BitwiseOrFun), - DUCKDB_SCALAR_FUNCTION_SET(BitwiseNotFun), - FINAL_FUNCTION -}; - -StaticFunctionDefinition *StaticFunctionDefinition::GetFunctionList() { - return internal_functions; -} - -} // namespace duckdb - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// BitStringFunction -//===--------------------------------------------------------------------===// -static void BitStringFunction(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t n) { - if (n < 0) { - throw InvalidInputException("The bitstring length cannot be negative"); - } - if (idx_t(n) < input.GetSize()) { - throw InvalidInputException("Length must be equal or larger than input string"); - } - idx_t len; - Bit::TryGetBitStringSize(input, len, nullptr); // string verification - - len = Bit::ComputeBitstringLen(n); - string_t target = StringVector::EmptyString(result, len); - Bit::BitString(input, n, target); - target.Finalize(); - return target; - }); -} - -ScalarFunction BitStringFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction); -} - -//===--------------------------------------------------------------------===// -// get_bit -//===--------------------------------------------------------------------===// -struct GetBitOperator { - template - static inline TR Operation(TA input, TB n) { - if (n < 0 || (idx_t)n > Bit::BitLength(input) - 1) { - throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), - 
NumericHelper::ToString(Bit::BitLength(input) - 1)); - } - return Bit::GetBit(input, n); - } -}; - -ScalarFunction GetBitFun::GetFunction() { - return ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::INTEGER, - ScalarFunction::BinaryFunction); -} - -//===--------------------------------------------------------------------===// -// set_bit -//===--------------------------------------------------------------------===// -static void SetBitOperation(DataChunk &args, ExpressionState &state, Vector &result) { - TernaryExecutor::Execute( - args.data[0], args.data[1], args.data[2], result, args.size(), - [&](string_t input, int32_t n, int32_t new_value) { - if (new_value != 0 && new_value != 1) { - throw InvalidInputException("The new bit must be 1 or 0"); - } - if (n < 0 || (idx_t)n > Bit::BitLength(input) - 1) { - throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), - NumericHelper::ToString(Bit::BitLength(input) - 1)); - } - string_t target = StringVector::EmptyString(result, input.GetSize()); - memcpy(target.GetDataWriteable(), input.GetData(), input.GetSize()); - Bit::SetBit(target, n, new_value); - return target; - }); -} - -ScalarFunction SetBitFun::GetFunction() { - return ScalarFunction({LogicalType::BIT, LogicalType::INTEGER, LogicalType::INTEGER}, LogicalType::BIT, - SetBitOperation); -} - -//===--------------------------------------------------------------------===// -// bit_position -//===--------------------------------------------------------------------===// -struct BitPositionOperator { - template - static inline TR Operation(TA substring, TB input) { - if (substring.GetSize() > input.GetSize()) { - return 0; - } - return Bit::BitPosition(substring, input); - } -}; - -ScalarFunction BitPositionFun::GetFunction() { - return ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::INTEGER, - ScalarFunction::BinaryFunction); -} - -} // namespace duckdb - - - -namespace duckdb { - 
-struct Base64EncodeOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto result_str = StringVector::EmptyString(result, Blob::ToBase64Size(input)); - Blob::ToBase64(input, result_str.GetDataWriteable()); - result_str.Finalize(); - return result_str; - } -}; - -struct Base64DecodeOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto result_size = Blob::FromBase64Size(input); - auto result_blob = StringVector::EmptyString(result, result_size); - Blob::FromBase64(input, data_ptr_cast(result_blob.GetDataWriteable()), result_size); - result_blob.Finalize(); - return result_blob; - } -}; - -static void Base64EncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // decode is also a nop cast, but requires verification if the provided string is actually - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); -} - -static void Base64DecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // decode is also a nop cast, but requires verification if the provided string is actually - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); -} - -ScalarFunction ToBase64Fun::GetFunction() { - return ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, Base64EncodeFunction); -} - -ScalarFunction FromBase64Fun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::BLOB, Base64DecodeFunction); -} - -} // namespace duckdb - - - -namespace duckdb { - -static void EncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // encode is essentially a nop cast from varchar to blob - // we only need to reinterpret the data using the blob type - result.Reinterpret(args.data[0]); -} - -struct BlobDecodeOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input) { - auto input_data = input.GetData(); - auto input_length = input.GetSize(); - if (Utf8Proc::Analyze(input_data, input_length) == 
UnicodeType::INVALID) { - throw ConversionException( - "Failure in decode: could not convert blob to UTF8 string, the blob contained invalid UTF8 characters"); - } - return input; - } -}; - -static void DecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // decode is also a nop cast, but requires verification if the provided string is actually - UnaryExecutor::Execute(args.data[0], result, args.size()); - StringVector::AddHeapReference(result, args.data[0]); -} - -ScalarFunction EncodeFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::BLOB, EncodeFunction); -} - -ScalarFunction DecodeFun::GetFunction() { - return ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, DecodeFunction); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -static void AgeFunctionStandard(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 1); - auto current_timestamp = Timestamp::GetCurrentTimestamp(); - - UnaryExecutor::ExecuteWithNulls(input.data[0], result, input.size(), - [&](timestamp_t input, ValidityMask &mask, idx_t idx) { - if (Timestamp::IsFinite(input)) { - return Interval::GetAge(current_timestamp, input); - } else { - mask.SetInvalid(idx); - return interval_t(); - } - }); -} - -static void AgeFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 2); - - BinaryExecutor::ExecuteWithNulls( - input.data[0], input.data[1], result, input.size(), - [&](timestamp_t input1, timestamp_t input2, ValidityMask &mask, idx_t idx) { - if (Timestamp::IsFinite(input1) && Timestamp::IsFinite(input2)) { - return Interval::GetAge(input1, input2); - } else { - mask.SetInvalid(idx); - return interval_t(); - } - }); -} - -ScalarFunctionSet AgeFun::GetFunctions() { - ScalarFunctionSet age("age"); - age.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::INTERVAL, AgeFunctionStandard)); - age.AddFunction( - 
ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, LogicalType::INTERVAL, AgeFunction)); - return age; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -static timestamp_t GetTransactionTimestamp(ExpressionState &state) { - return MetaTransaction::Get(state.GetContext()).start_timestamp; -} - -static void CurrentTimeFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 0); - auto val = Value::TIME(Timestamp::GetTime(GetTransactionTimestamp(state))); - result.Reference(val); -} - -static void CurrentDateFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 0); - - auto val = Value::DATE(Timestamp::GetDate(GetTransactionTimestamp(state))); - result.Reference(val); -} - -static void CurrentTimestampFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 0); - - auto val = Value::TIMESTAMPTZ(GetTransactionTimestamp(state)); - result.Reference(val); -} - -ScalarFunction CurrentTimeFun::GetFunction() { - ScalarFunction current_time({}, LogicalType::TIME, CurrentTimeFunction); - current_time.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; - return current_time; -} - -ScalarFunction CurrentDateFun::GetFunction() { - ScalarFunction current_date({}, LogicalType::DATE, CurrentDateFunction); - current_date.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; - return current_date; -} - -ScalarFunction GetCurrentTimestampFun::GetFunction() { - ScalarFunction current_timestamp({}, LogicalType::TIMESTAMP_TZ, CurrentTimestampFunction); - current_timestamp.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; - return current_timestamp; -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -// This function is an implementation of the "period-crossing" date difference function from T-SQL -// 
https://docs.microsoft.com/en-us/sql/t-sql/functions/datediff-transact-sql?view=sql-server-ver15 -struct DateDiff { - template - static inline void BinaryExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - BinaryExecutor::ExecuteWithNulls( - left, right, result, count, [&](TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { - if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { - return OP::template Operation(startdate, enddate); - } else { - mask.SetInvalid(idx); - return TR(); - } - }); - } - - struct YearOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::ExtractYear(enddate) - Date::ExtractYear(startdate); - } - }; - - struct MonthOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - int32_t start_year, start_month, start_day; - Date::Convert(startdate, start_year, start_month, start_day); - int32_t end_year, end_month, end_day; - Date::Convert(enddate, end_year, end_month, end_day); - - return (end_year * 12 + end_month - 1) - (start_year * 12 + start_month - 1); - } - }; - - struct DayOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return TR(Date::EpochDays(enddate)) - TR(Date::EpochDays(startdate)); - } - }; - - struct DecadeOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::ExtractYear(enddate) / 10 - Date::ExtractYear(startdate) / 10; - } - }; - - struct CenturyOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::ExtractYear(enddate) / 100 - Date::ExtractYear(startdate) / 100; - } - }; - - struct MilleniumOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::ExtractYear(enddate) / 1000 - Date::ExtractYear(startdate) / 1000; - } - }; - - struct QuarterOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - int32_t start_year, start_month, start_day; - Date::Convert(startdate, 
start_year, start_month, start_day); - int32_t end_year, end_month, end_day; - Date::Convert(enddate, end_year, end_month, end_day); - - return (end_year * 12 + end_month - 1) / Interval::MONTHS_PER_QUARTER - - (start_year * 12 + start_month - 1) / Interval::MONTHS_PER_QUARTER; - } - }; - - struct WeekOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::Epoch(Date::GetMondayOfCurrentWeek(enddate)) / Interval::SECS_PER_WEEK - - Date::Epoch(Date::GetMondayOfCurrentWeek(startdate)) / Interval::SECS_PER_WEEK; - } - }; - - struct ISOYearOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::ExtractISOYearNumber(enddate) - Date::ExtractISOYearNumber(startdate); - } - }; - - struct MicrosecondsOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::EpochMicroseconds(enddate) - Date::EpochMicroseconds(startdate); - } - }; - - struct MillisecondsOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::EpochMicroseconds(enddate) / Interval::MICROS_PER_MSEC - - Date::EpochMicroseconds(startdate) / Interval::MICROS_PER_MSEC; - } - }; - - struct SecondsOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::Epoch(enddate) - Date::Epoch(startdate); - } - }; - - struct MinutesOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::Epoch(enddate) / Interval::SECS_PER_MINUTE - - Date::Epoch(startdate) / Interval::SECS_PER_MINUTE; - } - }; - - struct HoursOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::Epoch(enddate) / Interval::SECS_PER_HOUR - Date::Epoch(startdate) / Interval::SECS_PER_HOUR; - } - }; -}; - -// TIMESTAMP specialisations -template <> -int64_t DateDiff::YearOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return YearOperator::Operation(Timestamp::GetDate(startdate), 
Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::MonthOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return MonthOperator::Operation(Timestamp::GetDate(startdate), - Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::DayOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return DayOperator::Operation(Timestamp::GetDate(startdate), Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::DecadeOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return DecadeOperator::Operation(Timestamp::GetDate(startdate), - Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::CenturyOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return CenturyOperator::Operation(Timestamp::GetDate(startdate), - Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::MilleniumOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return MilleniumOperator::Operation(Timestamp::GetDate(startdate), - Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::QuarterOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return QuarterOperator::Operation(Timestamp::GetDate(startdate), - Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::WeekOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return WeekOperator::Operation(Timestamp::GetDate(startdate), Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::ISOYearOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return ISOYearOperator::Operation(Timestamp::GetDate(startdate), - Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::MicrosecondsOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - const auto start = Timestamp::GetEpochMicroSeconds(startdate); - const auto end = Timestamp::GetEpochMicroSeconds(enddate); - return SubtractOperatorOverflowCheck::Operation(end, start); -} - 
-template <> -int64_t DateDiff::MillisecondsOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return Timestamp::GetEpochMs(enddate) - Timestamp::GetEpochMs(startdate); -} - -template <> -int64_t DateDiff::SecondsOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return Timestamp::GetEpochSeconds(enddate) - Timestamp::GetEpochSeconds(startdate); -} - -template <> -int64_t DateDiff::MinutesOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return Timestamp::GetEpochSeconds(enddate) / Interval::SECS_PER_MINUTE - - Timestamp::GetEpochSeconds(startdate) / Interval::SECS_PER_MINUTE; -} - -template <> -int64_t DateDiff::HoursOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return Timestamp::GetEpochSeconds(enddate) / Interval::SECS_PER_HOUR - - Timestamp::GetEpochSeconds(startdate) / Interval::SECS_PER_HOUR; -} - -// TIME specialisations -template <> -int64_t DateDiff::YearOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"year\" not recognized"); -} - -template <> -int64_t DateDiff::MonthOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"month\" not recognized"); -} - -template <> -int64_t DateDiff::DayOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"day\" not recognized"); -} - -template <> -int64_t DateDiff::DecadeOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"decade\" not recognized"); -} - -template <> -int64_t DateDiff::CenturyOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"century\" not recognized"); -} - -template <> -int64_t DateDiff::MilleniumOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"millennium\" not recognized"); -} - -template <> 
-int64_t DateDiff::QuarterOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"quarter\" not recognized"); -} - -template <> -int64_t DateDiff::WeekOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"week\" not recognized"); -} - -template <> -int64_t DateDiff::ISOYearOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"isoyear\" not recognized"); -} - -template <> -int64_t DateDiff::MicrosecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { - return enddate.micros - startdate.micros; -} - -template <> -int64_t DateDiff::MillisecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { - return enddate.micros / Interval::MICROS_PER_MSEC - startdate.micros / Interval::MICROS_PER_MSEC; -} - -template <> -int64_t DateDiff::SecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { - return enddate.micros / Interval::MICROS_PER_SEC - startdate.micros / Interval::MICROS_PER_SEC; -} - -template <> -int64_t DateDiff::MinutesOperator::Operation(dtime_t startdate, dtime_t enddate) { - return enddate.micros / Interval::MICROS_PER_MINUTE - startdate.micros / Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DateDiff::HoursOperator::Operation(dtime_t startdate, dtime_t enddate) { - return enddate.micros / Interval::MICROS_PER_HOUR - startdate.micros / Interval::MICROS_PER_HOUR; -} - -template -static int64_t DifferenceDates(DatePartSpecifier type, TA startdate, TB enddate) { - switch (type) { - case DatePartSpecifier::YEAR: - return DateDiff::YearOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MONTH: - return DateDiff::MonthOperator::template Operation(startdate, enddate); - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - return 
DateDiff::DayOperator::template Operation(startdate, enddate); - case DatePartSpecifier::DECADE: - return DateDiff::DecadeOperator::template Operation(startdate, enddate); - case DatePartSpecifier::CENTURY: - return DateDiff::CenturyOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MILLENNIUM: - return DateDiff::MilleniumOperator::template Operation(startdate, enddate); - case DatePartSpecifier::QUARTER: - return DateDiff::QuarterOperator::template Operation(startdate, enddate); - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - return DateDiff::WeekOperator::template Operation(startdate, enddate); - case DatePartSpecifier::ISOYEAR: - return DateDiff::ISOYearOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MICROSECONDS: - return DateDiff::MicrosecondsOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MILLISECONDS: - return DateDiff::MillisecondsOperator::template Operation(startdate, enddate); - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - return DateDiff::SecondsOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MINUTE: - return DateDiff::MinutesOperator::template Operation(startdate, enddate); - case DatePartSpecifier::HOUR: - return DateDiff::HoursOperator::template Operation(startdate, enddate); - default: - throw NotImplementedException("Specifier type not implemented for DATEDIFF"); - } -} - -struct DateDiffTernaryOperator { - template - static inline TR Operation(TS part, TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { - if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { - return DifferenceDates(GetDatePartSpecifier(part.GetString()), startdate, enddate); - } else { - mask.SetInvalid(idx); - return TR(); - } - } -}; - -template -static void DateDiffBinaryExecutor(DatePartSpecifier type, Vector &left, Vector &right, Vector &result, idx_t count) { - switch (type) { - case DatePartSpecifier::YEAR: 
- DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MONTH: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::DECADE: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::CENTURY: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MILLENNIUM: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::QUARTER: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::ISOYEAR: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MICROSECONDS: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MILLISECONDS: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MINUTE: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::HOUR: - DateDiff::BinaryExecute(left, right, result, count); - break; - default: - throw NotImplementedException("Specifier type not implemented for DATEDIFF"); - } -} - -template -static void DateDiffFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 3); - auto &part_arg = args.data[0]; - auto &start_arg = args.data[1]; - auto &end_arg = args.data[2]; - - if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // Common case of constant part. 
- if (ConstantVector::IsNull(part_arg)) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); - DateDiffBinaryExecutor(type, start_arg, end_arg, result, args.size()); - } - } else { - TernaryExecutor::ExecuteWithNulls( - part_arg, start_arg, end_arg, result, args.size(), - DateDiffTernaryOperator::Operation); - } -} - -ScalarFunctionSet DateDiffFun::GetFunctions() { - ScalarFunctionSet date_diff("date_diff"); - date_diff.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE, LogicalType::DATE}, - LogicalType::BIGINT, DateDiffFunction)); - date_diff.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, - LogicalType::BIGINT, DateDiffFunction)); - date_diff.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME, LogicalType::TIME}, - LogicalType::BIGINT, DateDiffFunction)); - return date_diff; -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -DatePartSpecifier GetDateTypePartSpecifier(const string &specifier, LogicalType &type) { - const auto part = GetDatePartSpecifier(specifier); - switch (type.id()) { - case LogicalType::TIMESTAMP: - case LogicalType::TIMESTAMP_TZ: - return part; - case LogicalType::DATE: - switch (part) { - case DatePartSpecifier::YEAR: - case DatePartSpecifier::MONTH: - case DatePartSpecifier::DAY: - case DatePartSpecifier::DECADE: - case DatePartSpecifier::CENTURY: - case DatePartSpecifier::MILLENNIUM: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::ISOYEAR: - case DatePartSpecifier::WEEK: - case DatePartSpecifier::QUARTER: - case DatePartSpecifier::DOY: - case DatePartSpecifier::YEARWEEK: - case DatePartSpecifier::ERA: - case DatePartSpecifier::EPOCH: - case DatePartSpecifier::JULIAN_DAY: - return part; - default: - break; - } - break; - case LogicalType::TIME: - 
switch (part) { - case DatePartSpecifier::MICROSECONDS: - case DatePartSpecifier::MILLISECONDS: - case DatePartSpecifier::SECOND: - case DatePartSpecifier::MINUTE: - case DatePartSpecifier::HOUR: - case DatePartSpecifier::EPOCH: - case DatePartSpecifier::TIMEZONE: - case DatePartSpecifier::TIMEZONE_HOUR: - case DatePartSpecifier::TIMEZONE_MINUTE: - return part; - default: - break; - } - break; - case LogicalType::INTERVAL: - switch (part) { - case DatePartSpecifier::YEAR: - case DatePartSpecifier::MONTH: - case DatePartSpecifier::DAY: - case DatePartSpecifier::DECADE: - case DatePartSpecifier::CENTURY: - case DatePartSpecifier::QUARTER: - case DatePartSpecifier::MILLENNIUM: - case DatePartSpecifier::MICROSECONDS: - case DatePartSpecifier::MILLISECONDS: - case DatePartSpecifier::SECOND: - case DatePartSpecifier::MINUTE: - case DatePartSpecifier::HOUR: - case DatePartSpecifier::EPOCH: - return part; - default: - break; - } - break; - default: - break; - } - - throw NotImplementedException("\"%s\" units \"%s\" not recognized", EnumUtil::ToString(type.id()), specifier); -} - -template -static unique_ptr PropagateSimpleDatePartStatistics(vector &child_stats) { - // we can always propagate simple date part statistics - // since the min and max can never exceed these bounds - auto result = NumericStats::CreateEmpty(LogicalType::BIGINT); - result.CopyValidity(child_stats[0]); - NumericStats::SetMin(result, Value::BIGINT(MIN)); - NumericStats::SetMax(result, Value::BIGINT(MAX)); - return result.ToUnique(); -} - -struct DatePart { - template - static unique_ptr PropagateDatePartStatistics(vector &child_stats, - const LogicalType &stats_type = LogicalType::BIGINT) { - // we can only propagate complex date part stats if the child has stats - auto &nstats = child_stats[0]; - if (!NumericStats::HasMinMax(nstats)) { - return nullptr; - } - // run the operator on both the min and the max, this gives us the [min, max] bound - auto min = NumericStats::GetMin(nstats); - auto max = 
NumericStats::GetMax(nstats); - if (min > max) { - return nullptr; - } - // Infinities prevent us from computing generic ranges - if (!Value::IsFinite(min) || !Value::IsFinite(max)) { - return nullptr; - } - TR min_part = OP::template Operation(min); - TR max_part = OP::template Operation(max); - auto result = NumericStats::CreateEmpty(stats_type); - NumericStats::SetMin(result, Value(min_part)); - NumericStats::SetMax(result, Value(max_part)); - result.CopyValidity(child_stats[0]); - return result.ToUnique(); - } - - template - struct PartOperator { - template - static inline TR Operation(TA input, ValidityMask &mask, idx_t idx, void *dataptr) { - if (Value::IsFinite(input)) { - return OP::template Operation(input); - } else { - mask.SetInvalid(idx); - return TR(); - } - } - }; - - template - static void UnaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() >= 1); - using IOP = PartOperator; - UnaryExecutor::GenericExecute(input.data[0], result, input.size(), nullptr, true); - } - - struct YearOperator { - template - static inline TR Operation(TA input) { - return Date::ExtractYear(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct MonthOperator { - template - static inline TR Operation(TA input) { - return Date::ExtractMonth(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - // min/max of month operator is [1, 12] - return PropagateSimpleDatePartStatistics<1, 12>(input.child_stats); - } - }; - - struct DayOperator { - template - static inline TR Operation(TA input) { - return Date::ExtractDay(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - // min/max of day operator is [1, 31] - return 
PropagateSimpleDatePartStatistics<1, 31>(input.child_stats); - } - }; - - struct DecadeOperator { - // From the PG docs: "The year field divided by 10" - template - static inline TR DecadeFromYear(TR yyyy) { - return yyyy / 10; - } - - template - static inline TR Operation(TA input) { - return DecadeFromYear(YearOperator::Operation(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct CenturyOperator { - // From the PG docs: - // "The first century starts at 0001-01-01 00:00:00 AD, although they did not know it at the time. - // This definition applies to all Gregorian calendar countries. - // There is no century number 0, you go from -1 century to 1 century. - // If you disagree with this, please write your complaint to: Pope, Cathedral Saint-Peter of Roma, Vatican." - // (To be fair, His Holiness had nothing to do with this - - // it was the lack of zero in the counting systems of the time...) 
- template - static inline TR CenturyFromYear(TR yyyy) { - if (yyyy > 0) { - return ((yyyy - 1) / 100) + 1; - } else { - return (yyyy / 100) - 1; - } - } - - template - static inline TR Operation(TA input) { - return CenturyFromYear(YearOperator::Operation(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct MillenniumOperator { - // See the century comment - template - static inline TR MillenniumFromYear(TR yyyy) { - if (yyyy > 0) { - return ((yyyy - 1) / 1000) + 1; - } else { - return (yyyy / 1000) - 1; - } - } - - template - static inline TR Operation(TA input) { - return MillenniumFromYear(YearOperator::Operation(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct QuarterOperator { - template - static inline TR QuarterFromMonth(TR mm) { - return (mm - 1) / Interval::MONTHS_PER_QUARTER + 1; - } - - template - static inline TR Operation(TA input) { - return QuarterFromMonth(Date::ExtractMonth(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - // min/max of quarter operator is [1, 4] - return PropagateSimpleDatePartStatistics<1, 4>(input.child_stats); - } - }; - - struct DayOfWeekOperator { - template - static inline TR DayOfWeekFromISO(TR isodow) { - // day of the week (Sunday = 0, Saturday = 6) - // turn sunday into 0 by doing mod 7 - return isodow % 7; - } - - template - static inline TR Operation(TA input) { - return DayOfWeekFromISO(Date::ExtractISODayOfTheWeek(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 6>(input.child_stats); - } - }; - - struct ISODayOfWeekOperator { - 
template - static inline TR Operation(TA input) { - // isodow (Monday = 1, Sunday = 7) - return Date::ExtractISODayOfTheWeek(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<1, 7>(input.child_stats); - } - }; - - struct DayOfYearOperator { - template - static inline TR Operation(TA input) { - return Date::ExtractDayOfTheYear(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<1, 366>(input.child_stats); - } - }; - - struct WeekOperator { - template - static inline TR Operation(TA input) { - return Date::ExtractISOWeekNumber(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<1, 54>(input.child_stats); - } - }; - - struct ISOYearOperator { - template - static inline TR Operation(TA input) { - return Date::ExtractISOYearNumber(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct YearWeekOperator { - template - static inline TR YearWeekFromParts(TR yyyy, TR ww) { - return yyyy * 100 + ((yyyy > 0) ? 
ww : -ww); - } - - template - static inline TR Operation(TA input) { - int32_t yyyy, ww; - Date::ExtractISOYearWeek(input, yyyy, ww); - return YearWeekFromParts(yyyy, ww); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct EpochNanosecondsOperator { - template - static inline TR Operation(TA input) { - return input.micros * Interval::NANOS_PER_MICRO; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct EpochMicrosecondsOperator { - template - static inline TR Operation(TA input) { - return input.micros; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct EpochMillisOperator { - template - static inline TR Operation(TA input) { - return input.micros / Interval::MICROS_PER_MSEC; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - - static void Inverse(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 1); - - UnaryExecutor::Execute(input.data[0], result, input.size(), - [&](int64_t input) { return Timestamp::FromEpochMs(input); }); - } - }; - - struct MicrosecondsOperator { - template - static inline TR Operation(TA input) { - return 0; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 60000000>(input.child_stats); - } - }; - - struct MillisecondsOperator { - template - static inline TR Operation(TA input) { - return 0; - } - - template - static unique_ptr 
PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 60000>(input.child_stats); - } - }; - - struct SecondsOperator { - template - static inline TR Operation(TA input) { - return 0; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 60>(input.child_stats); - } - }; - - struct MinutesOperator { - template - static inline TR Operation(TA input) { - return 0; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 60>(input.child_stats); - } - }; - - struct HoursOperator { - template - static inline TR Operation(TA input) { - return 0; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 24>(input.child_stats); - } - }; - - struct EpochOperator { - template - static inline TR Operation(TA input) { - return Date::Epoch(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats, LogicalType::DOUBLE); - } - }; - - struct EraOperator { - template - static inline TR EraFromYear(TR yyyy) { - return yyyy > 0 ? 1 : 0; - } - - template - static inline TR Operation(TA input) { - return EraFromYear(Date::ExtractYear(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 1>(input.child_stats); - } - }; - - struct TimezoneOperator { - template - static inline TR Operation(TA input) { - // Regular timestamps are UTC. 
- return 0; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 0>(input.child_stats); - } - }; - - struct JulianDayOperator { - template - static inline TR Operation(TA input) { - return Timestamp::GetJulianDay(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats, LogicalType::DOUBLE); - } - }; - - // These are all zero and have the same restrictions - using TimezoneHourOperator = TimezoneOperator; - using TimezoneMinuteOperator = TimezoneOperator; - - struct StructOperator { - using part_codes_t = vector; - using part_mask_t = uint64_t; - - enum MaskBits : uint8_t { - YMD = 1 << 0, - DOW = 1 << 1, - DOY = 1 << 2, - EPOCH = 1 << 3, - TIME = 1 << 4, - ZONE = 1 << 5, - ISO = 1 << 6, - JD = 1 << 7 - }; - - static part_mask_t GetMask(const part_codes_t &part_codes) { - part_mask_t mask = 0; - for (const auto &part_code : part_codes) { - switch (part_code) { - case DatePartSpecifier::YEAR: - case DatePartSpecifier::MONTH: - case DatePartSpecifier::DAY: - case DatePartSpecifier::DECADE: - case DatePartSpecifier::CENTURY: - case DatePartSpecifier::MILLENNIUM: - case DatePartSpecifier::QUARTER: - case DatePartSpecifier::ERA: - mask |= YMD; - break; - case DatePartSpecifier::YEARWEEK: - case DatePartSpecifier::WEEK: - case DatePartSpecifier::ISOYEAR: - mask |= ISO; - break; - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - mask |= DOW; - break; - case DatePartSpecifier::DOY: - mask |= DOY; - break; - case DatePartSpecifier::EPOCH: - mask |= EPOCH; - break; - case DatePartSpecifier::JULIAN_DAY: - mask |= JD; - break; - case DatePartSpecifier::MICROSECONDS: - case DatePartSpecifier::MILLISECONDS: - case DatePartSpecifier::SECOND: - case DatePartSpecifier::MINUTE: - case DatePartSpecifier::HOUR: - mask |= TIME; - break; - case 
DatePartSpecifier::TIMEZONE: - case DatePartSpecifier::TIMEZONE_HOUR: - case DatePartSpecifier::TIMEZONE_MINUTE: - mask |= ZONE; - break; - case DatePartSpecifier::INVALID: - throw InternalException("Invalid DatePartSpecifier for STRUCT mask!"); - } - } - return mask; - } - - template - static inline P HasPartValue(vector

part_values, DatePartSpecifier part) { - auto idx = size_t(part); - if (IsBigintDatepart(part)) { - return part_values[idx - size_t(DatePartSpecifier::BEGIN_BIGINT)]; - } else { - return part_values[idx - size_t(DatePartSpecifier::BEGIN_DOUBLE)]; - } - } - - using bigint_vec = vector; - using double_vec = vector; - - template - static inline void Operation(bigint_vec &bigint_values, double_vec &double_values, const TA &input, - const idx_t idx, const part_mask_t mask) { - int64_t *bigint_data; - // YMD calculations - int32_t yyyy = 1970; - int32_t mm = 0; - int32_t dd = 1; - if (mask & YMD) { - Date::Convert(input, yyyy, mm, dd); - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::YEAR); - if (bigint_data) { - bigint_data[idx] = yyyy; - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::MONTH); - if (bigint_data) { - bigint_data[idx] = mm; - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DAY); - if (bigint_data) { - bigint_data[idx] = dd; - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DECADE); - if (bigint_data) { - bigint_data[idx] = DecadeOperator::DecadeFromYear(yyyy); - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::CENTURY); - if (bigint_data) { - bigint_data[idx] = CenturyOperator::CenturyFromYear(yyyy); - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::MILLENNIUM); - if (bigint_data) { - bigint_data[idx] = MillenniumOperator::MillenniumFromYear(yyyy); - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::QUARTER); - if (bigint_data) { - bigint_data[idx] = QuarterOperator::QuarterFromMonth(mm); - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::ERA); - if (bigint_data) { - bigint_data[idx] = EraOperator::EraFromYear(yyyy); - } - } - - // Week calculations - if (mask & DOW) { - auto isodow = Date::ExtractISODayOfTheWeek(input); - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DOW); - if (bigint_data) { - bigint_data[idx] 
= DayOfWeekOperator::DayOfWeekFromISO(isodow); - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::ISODOW); - if (bigint_data) { - bigint_data[idx] = isodow; - } - } - - // ISO calculations - if (mask & ISO) { - int32_t ww = 0; - int32_t iyyy = 0; - Date::ExtractISOYearWeek(input, iyyy, ww); - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::WEEK); - if (bigint_data) { - bigint_data[idx] = ww; - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::ISOYEAR); - if (bigint_data) { - bigint_data[idx] = iyyy; - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::YEARWEEK); - if (bigint_data) { - bigint_data[idx] = YearWeekOperator::YearWeekFromParts(iyyy, ww); - } - } - - if (mask & EPOCH) { - auto double_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); - if (double_data) { - double_data[idx] = Date::Epoch(input); - } - } - if (mask & DOY) { - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DOY); - if (bigint_data) { - bigint_data[idx] = Date::ExtractDayOfTheYear(input); - } - } - if (mask & JD) { - auto double_data = HasPartValue(double_values, DatePartSpecifier::JULIAN_DAY); - if (double_data) { - double_data[idx] = Date::ExtractJulianDay(input); - } - } - } - }; -}; - -template -static void LastYearFunction(DataChunk &args, ExpressionState &state, Vector &result) { - int32_t last_year = 0; - UnaryExecutor::ExecuteWithNulls(args.data[0], result, args.size(), - [&](T input, ValidityMask &mask, idx_t idx) { - if (Value::IsFinite(input)) { - return Date::ExtractYear(input, &last_year); - } else { - mask.SetInvalid(idx); - return 0; - } - }); -} - -template <> -int64_t DatePart::YearOperator::Operation(timestamp_t input) { - return YearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::YearOperator::Operation(interval_t input) { - return input.months / Interval::MONTHS_PER_YEAR; -} - -template <> -int64_t DatePart::YearOperator::Operation(dtime_t input) { - 
throw NotImplementedException("\"time\" units \"year\" not recognized"); -} - -template <> -int64_t DatePart::MonthOperator::Operation(timestamp_t input) { - return MonthOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::MonthOperator::Operation(interval_t input) { - return input.months % Interval::MONTHS_PER_YEAR; -} - -template <> -int64_t DatePart::MonthOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"month\" not recognized"); -} - -template <> -int64_t DatePart::DayOperator::Operation(timestamp_t input) { - return DayOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::DayOperator::Operation(interval_t input) { - return input.days; -} - -template <> -int64_t DatePart::DayOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"day\" not recognized"); -} - -template <> -int64_t DatePart::DecadeOperator::Operation(interval_t input) { - return input.months / Interval::MONTHS_PER_DECADE; -} - -template <> -int64_t DatePart::DecadeOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"decade\" not recognized"); -} - -template <> -int64_t DatePart::CenturyOperator::Operation(interval_t input) { - return input.months / Interval::MONTHS_PER_CENTURY; -} - -template <> -int64_t DatePart::CenturyOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"century\" not recognized"); -} - -template <> -int64_t DatePart::MillenniumOperator::Operation(interval_t input) { - return input.months / Interval::MONTHS_PER_MILLENIUM; -} - -template <> -int64_t DatePart::MillenniumOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"millennium\" not recognized"); -} - -template <> -int64_t DatePart::QuarterOperator::Operation(timestamp_t input) { - return QuarterOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t 
DatePart::QuarterOperator::Operation(interval_t input) { - return MonthOperator::Operation(input) / Interval::MONTHS_PER_QUARTER + 1; -} - -template <> -int64_t DatePart::QuarterOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"quarter\" not recognized"); -} - -template <> -int64_t DatePart::DayOfWeekOperator::Operation(timestamp_t input) { - return DayOfWeekOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::DayOfWeekOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"dow\" not recognized"); -} - -template <> -int64_t DatePart::DayOfWeekOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"dow\" not recognized"); -} - -template <> -int64_t DatePart::ISODayOfWeekOperator::Operation(timestamp_t input) { - return ISODayOfWeekOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::ISODayOfWeekOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"isodow\" not recognized"); -} - -template <> -int64_t DatePart::ISODayOfWeekOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"isodow\" not recognized"); -} - -template <> -int64_t DatePart::DayOfYearOperator::Operation(timestamp_t input) { - return DayOfYearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::DayOfYearOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"doy\" not recognized"); -} - -template <> -int64_t DatePart::DayOfYearOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"doy\" not recognized"); -} - -template <> -int64_t DatePart::WeekOperator::Operation(timestamp_t input) { - return WeekOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::WeekOperator::Operation(interval_t input) { - throw NotImplementedException("interval units 
\"week\" not recognized"); -} - -template <> -int64_t DatePart::WeekOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"week\" not recognized"); -} - -template <> -int64_t DatePart::ISOYearOperator::Operation(timestamp_t input) { - return ISOYearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::ISOYearOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"isoyear\" not recognized"); -} - -template <> -int64_t DatePart::ISOYearOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"isoyear\" not recognized"); -} - -template <> -int64_t DatePart::YearWeekOperator::Operation(timestamp_t input) { - return YearWeekOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::YearWeekOperator::Operation(interval_t input) { - const auto yyyy = YearOperator::Operation(input); - const auto ww = WeekOperator::Operation(input); - return YearWeekOperator::YearWeekFromParts(yyyy, ww); -} - -template <> -int64_t DatePart::YearWeekOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"yearweek\" not recognized"); -} - -template <> -int64_t DatePart::EpochNanosecondsOperator::Operation(timestamp_t input) { - return Timestamp::GetEpochNanoSeconds(input); -} - -template <> -int64_t DatePart::EpochNanosecondsOperator::Operation(date_t input) { - return Date::EpochNanoseconds(input); -} - -template <> -int64_t DatePart::EpochNanosecondsOperator::Operation(interval_t input) { - return Interval::GetNanoseconds(input); -} - -template <> -int64_t DatePart::EpochMicrosecondsOperator::Operation(timestamp_t input) { - return Timestamp::GetEpochMicroSeconds(input); -} - -template <> -int64_t DatePart::EpochMicrosecondsOperator::Operation(date_t input) { - return Date::EpochMicroseconds(input); -} - -template <> -int64_t DatePart::EpochMicrosecondsOperator::Operation(interval_t input) { - return 
Interval::GetMicro(input); -} - -template <> -int64_t DatePart::EpochMillisOperator::Operation(timestamp_t input) { - return Timestamp::GetEpochMs(input); -} - -template <> -int64_t DatePart::EpochMillisOperator::Operation(date_t input) { - return Date::EpochMilliseconds(input); -} - -template <> -int64_t DatePart::EpochMillisOperator::Operation(interval_t input) { - return Interval::GetMilli(input); -} - -template <> -int64_t DatePart::MicrosecondsOperator::Operation(timestamp_t input) { - auto time = Timestamp::GetTime(input); - // remove everything but the second & microsecond part - return time.micros % Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DatePart::MicrosecondsOperator::Operation(interval_t input) { - // remove everything but the second & microsecond part - return input.micros % Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DatePart::MicrosecondsOperator::Operation(dtime_t input) { - // remove everything but the second & microsecond part - return input.micros % Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DatePart::MillisecondsOperator::Operation(timestamp_t input) { - return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_MSEC; -} - -template <> -int64_t DatePart::MillisecondsOperator::Operation(interval_t input) { - return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_MSEC; -} - -template <> -int64_t DatePart::MillisecondsOperator::Operation(dtime_t input) { - return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_MSEC; -} - -template <> -int64_t DatePart::SecondsOperator::Operation(timestamp_t input) { - return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_SEC; -} - -template <> -int64_t DatePart::SecondsOperator::Operation(interval_t input) { - return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_SEC; -} - -template <> -int64_t DatePart::SecondsOperator::Operation(dtime_t input) { - return MicrosecondsOperator::Operation(input) / 
Interval::MICROS_PER_SEC; -} - -template <> -int64_t DatePart::MinutesOperator::Operation(timestamp_t input) { - auto time = Timestamp::GetTime(input); - // remove the hour part, and truncate to minutes - return (time.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DatePart::MinutesOperator::Operation(interval_t input) { - // remove the hour part, and truncate to minutes - return (input.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DatePart::MinutesOperator::Operation(dtime_t input) { - // remove the hour part, and truncate to minutes - return (input.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DatePart::HoursOperator::Operation(timestamp_t input) { - return Timestamp::GetTime(input).micros / Interval::MICROS_PER_HOUR; -} - -template <> -int64_t DatePart::HoursOperator::Operation(interval_t input) { - return input.micros / Interval::MICROS_PER_HOUR; -} - -template <> -int64_t DatePart::HoursOperator::Operation(dtime_t input) { - return input.micros / Interval::MICROS_PER_HOUR; -} - -template <> -double DatePart::EpochOperator::Operation(timestamp_t input) { - return Timestamp::GetEpochMicroSeconds(input) / double(Interval::MICROS_PER_SEC); -} - -template <> -double DatePart::EpochOperator::Operation(interval_t input) { - int64_t interval_years = input.months / Interval::MONTHS_PER_YEAR; - int64_t interval_days; - interval_days = Interval::DAYS_PER_YEAR * interval_years; - interval_days += Interval::DAYS_PER_MONTH * (input.months % Interval::MONTHS_PER_YEAR); - interval_days += input.days; - int64_t interval_epoch; - interval_epoch = interval_days * Interval::SECS_PER_DAY; - // we add 0.25 days per year to sort of account for leap days - interval_epoch += interval_years * (Interval::SECS_PER_DAY / 4); - return interval_epoch + input.micros / double(Interval::MICROS_PER_SEC); -} - -// TODO: We can't propagate interval 
statistics because we can't easily compare interval_t for order. -template <> -unique_ptr DatePart::EpochOperator::PropagateStatistics(ClientContext &context, - FunctionStatisticsInput &input) { - return nullptr; -} - -template <> -double DatePart::EpochOperator::Operation(dtime_t input) { - return input.micros / double(Interval::MICROS_PER_SEC); -} - -template <> -unique_ptr DatePart::EpochOperator::PropagateStatistics(ClientContext &context, - FunctionStatisticsInput &input) { - auto result = NumericStats::CreateEmpty(LogicalType::DOUBLE); - result.CopyValidity(input.child_stats[0]); - NumericStats::SetMin(result, Value::DOUBLE(0)); - NumericStats::SetMax(result, Value::DOUBLE(Interval::SECS_PER_DAY)); - return result.ToUnique(); -} - -template <> -int64_t DatePart::EraOperator::Operation(timestamp_t input) { - return EraOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::EraOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"era\" not recognized"); -} - -template <> -int64_t DatePart::EraOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"era\" not recognized"); -} - -template <> -int64_t DatePart::TimezoneOperator::Operation(date_t input) { - throw NotImplementedException("\"date\" units \"timezone\" not recognized"); -} - -template <> -int64_t DatePart::TimezoneOperator::Operation(interval_t input) { - throw NotImplementedException("\"interval\" units \"timezone\" not recognized"); -} - -template <> -int64_t DatePart::TimezoneOperator::Operation(dtime_t input) { - return 0; -} - -template <> -double DatePart::JulianDayOperator::Operation(date_t input) { - return Date::ExtractJulianDay(input); -} - -template <> -double DatePart::JulianDayOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"julian\" not recognized"); -} - -template <> -double DatePart::JulianDayOperator::Operation(dtime_t input) { - throw 
NotImplementedException("\"time\" units \"julian\" not recognized"); -} - -template <> -void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const dtime_t &input, - const idx_t idx, const part_mask_t mask) { - int64_t *part_data; - if (mask & TIME) { - const auto micros = MicrosecondsOperator::Operation(input); - part_data = HasPartValue(bigint_values, DatePartSpecifier::MICROSECONDS); - if (part_data) { - part_data[idx] = micros; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLISECONDS); - if (part_data) { - part_data[idx] = micros / Interval::MICROS_PER_MSEC; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::SECOND); - if (part_data) { - part_data[idx] = micros / Interval::MICROS_PER_SEC; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MINUTE); - if (part_data) { - part_data[idx] = MinutesOperator::Operation(input); - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::HOUR); - if (part_data) { - part_data[idx] = HoursOperator::Operation(input); - } - } - - if (mask & EPOCH) { - auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); - if (part_data) { - part_data[idx] = EpochOperator::Operation(input); - ; - } - } - - if (mask & ZONE) { - part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE); - if (part_data) { - part_data[idx] = 0; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE_HOUR); - if (part_data) { - part_data[idx] = 0; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE_MINUTE); - if (part_data) { - part_data[idx] = 0; - } - } -} - -template <> -void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const timestamp_t &input, - const idx_t idx, const part_mask_t mask) { - date_t d; - dtime_t t; - Timestamp::Convert(input, d, t); - - // Both define epoch, and the correct value is the sum. - // So mask it out and compute it separately. 
- Operation(bigint_values, double_values, d, idx, mask & ~EPOCH); - Operation(bigint_values, double_values, t, idx, mask & ~EPOCH); - - if (mask & EPOCH) { - auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); - if (part_data) { - part_data[idx] = EpochOperator::Operation(input); - } - } - - if (mask & JD) { - auto part_data = HasPartValue(double_values, DatePartSpecifier::JULIAN_DAY); - if (part_data) { - part_data[idx] = JulianDayOperator::Operation(input); - } - } -} - -template <> -void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const interval_t &input, - const idx_t idx, const part_mask_t mask) { - int64_t *part_data; - if (mask & YMD) { - const auto mm = input.months % Interval::MONTHS_PER_YEAR; - part_data = HasPartValue(bigint_values, DatePartSpecifier::YEAR); - if (part_data) { - part_data[idx] = input.months / Interval::MONTHS_PER_YEAR; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MONTH); - if (part_data) { - part_data[idx] = mm; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::DAY); - if (part_data) { - part_data[idx] = input.days; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::DECADE); - if (part_data) { - part_data[idx] = input.months / Interval::MONTHS_PER_DECADE; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::CENTURY); - if (part_data) { - part_data[idx] = input.months / Interval::MONTHS_PER_CENTURY; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLENNIUM); - if (part_data) { - part_data[idx] = input.months / Interval::MONTHS_PER_MILLENIUM; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::QUARTER); - if (part_data) { - part_data[idx] = mm / Interval::MONTHS_PER_QUARTER + 1; - } - } - - if (mask & TIME) { - const auto micros = MicrosecondsOperator::Operation(input); - part_data = HasPartValue(bigint_values, DatePartSpecifier::MICROSECONDS); - if (part_data) { - part_data[idx] = 
micros; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLISECONDS); - if (part_data) { - part_data[idx] = micros / Interval::MICROS_PER_MSEC; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::SECOND); - if (part_data) { - part_data[idx] = micros / Interval::MICROS_PER_SEC; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MINUTE); - if (part_data) { - part_data[idx] = MinutesOperator::Operation(input); - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::HOUR); - if (part_data) { - part_data[idx] = HoursOperator::Operation(input); - } - } - - if (mask & EPOCH) { - auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); - if (part_data) { - part_data[idx] = EpochOperator::Operation(input); - } - } -} - -template -static int64_t ExtractElement(DatePartSpecifier type, T element) { - switch (type) { - case DatePartSpecifier::YEAR: - return DatePart::YearOperator::template Operation(element); - case DatePartSpecifier::MONTH: - return DatePart::MonthOperator::template Operation(element); - case DatePartSpecifier::DAY: - return DatePart::DayOperator::template Operation(element); - case DatePartSpecifier::DECADE: - return DatePart::DecadeOperator::template Operation(element); - case DatePartSpecifier::CENTURY: - return DatePart::CenturyOperator::template Operation(element); - case DatePartSpecifier::MILLENNIUM: - return DatePart::MillenniumOperator::template Operation(element); - case DatePartSpecifier::QUARTER: - return DatePart::QuarterOperator::template Operation(element); - case DatePartSpecifier::DOW: - return DatePart::DayOfWeekOperator::template Operation(element); - case DatePartSpecifier::ISODOW: - return DatePart::ISODayOfWeekOperator::template Operation(element); - case DatePartSpecifier::DOY: - return DatePart::DayOfYearOperator::template Operation(element); - case DatePartSpecifier::WEEK: - return DatePart::WeekOperator::template Operation(element); - case DatePartSpecifier::ISOYEAR: 
- return DatePart::ISOYearOperator::template Operation(element); - case DatePartSpecifier::YEARWEEK: - return DatePart::YearWeekOperator::template Operation(element); - case DatePartSpecifier::MICROSECONDS: - return DatePart::MicrosecondsOperator::template Operation(element); - case DatePartSpecifier::MILLISECONDS: - return DatePart::MillisecondsOperator::template Operation(element); - case DatePartSpecifier::SECOND: - return DatePart::SecondsOperator::template Operation(element); - case DatePartSpecifier::MINUTE: - return DatePart::MinutesOperator::template Operation(element); - case DatePartSpecifier::HOUR: - return DatePart::HoursOperator::template Operation(element); - case DatePartSpecifier::ERA: - return DatePart::EraOperator::template Operation(element); - case DatePartSpecifier::TIMEZONE: - case DatePartSpecifier::TIMEZONE_HOUR: - case DatePartSpecifier::TIMEZONE_MINUTE: - return DatePart::TimezoneOperator::template Operation(element); - default: - throw NotImplementedException("Specifier type not implemented for DATEPART"); - } -} - -template -static void DatePartFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - auto &spec_arg = args.data[0]; - auto &date_arg = args.data[1]; - - BinaryExecutor::ExecuteWithNulls( - spec_arg, date_arg, result, args.size(), [&](string_t specifier, T date, ValidityMask &mask, idx_t idx) { - if (Value::IsFinite(date)) { - return ExtractElement(GetDatePartSpecifier(specifier.GetString()), date); - } else { - mask.SetInvalid(idx); - return int64_t(0); - } - }); -} - -static unique_ptr DatePartBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // If we are only looking for Julian Days for timestamps, - // then return doubles. 
- if (arguments[0]->HasParameter() || !arguments[0]->IsFoldable()) { - return nullptr; - } - - Value part_value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - const auto part_name = part_value.ToString(); - switch (GetDatePartSpecifier(part_name)) { - case DatePartSpecifier::JULIAN_DAY: - arguments.erase(arguments.begin()); - bound_function.arguments.erase(bound_function.arguments.begin()); - bound_function.name = "julian"; - bound_function.return_type = LogicalType::DOUBLE; - switch (arguments[0]->return_type.id()) { - case LogicalType::TIMESTAMP: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::JulianDayOperator::template PropagateStatistics; - break; - case LogicalType::DATE: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::JulianDayOperator::template PropagateStatistics; - break; - default: - throw BinderException("%s can only take DATE or TIMESTAMP arguments", bound_function.name); - } - break; - case DatePartSpecifier::EPOCH: - arguments.erase(arguments.begin()); - bound_function.arguments.erase(bound_function.arguments.begin()); - bound_function.name = "epoch"; - bound_function.return_type = LogicalType::DOUBLE; - switch (arguments[0]->return_type.id()) { - case LogicalType::TIMESTAMP: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; - break; - case LogicalType::DATE: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; - break; - case LogicalType::INTERVAL: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; - break; - case LogicalType::TIME: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; - break; - default: - 
throw BinderException("%s can only take temporal arguments", bound_function.name); - } - break; - default: - break; - } - - return nullptr; -} - -ScalarFunctionSet GetGenericDatePartFunction(scalar_function_t date_func, scalar_function_t ts_func, - scalar_function_t interval_func, function_statistics_t date_stats, - function_statistics_t ts_stats) { - ScalarFunctionSet operator_set; - operator_set.AddFunction( - ScalarFunction({LogicalType::DATE}, LogicalType::BIGINT, std::move(date_func), nullptr, nullptr, date_stats)); - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BIGINT, std::move(ts_func), nullptr, nullptr, ts_stats)); - operator_set.AddFunction(ScalarFunction({LogicalType::INTERVAL}, LogicalType::BIGINT, std::move(interval_func))); - return operator_set; -} - -template -static ScalarFunctionSet GetDatePartFunction() { - return GetGenericDatePartFunction( - DatePart::UnaryFunction, DatePart::UnaryFunction, - ScalarFunction::UnaryFunction, OP::template PropagateStatistics, - OP::template PropagateStatistics); -} - -ScalarFunctionSet GetGenericTimePartFunction(const LogicalType &result_type, scalar_function_t date_func, - scalar_function_t ts_func, scalar_function_t interval_func, - scalar_function_t time_func, function_statistics_t date_stats, - function_statistics_t ts_stats, function_statistics_t time_stats) { - ScalarFunctionSet operator_set; - operator_set.AddFunction( - ScalarFunction({LogicalType::DATE}, result_type, std::move(date_func), nullptr, nullptr, date_stats)); - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP}, result_type, std::move(ts_func), nullptr, nullptr, ts_stats)); - operator_set.AddFunction(ScalarFunction({LogicalType::INTERVAL}, result_type, std::move(interval_func))); - operator_set.AddFunction( - ScalarFunction({LogicalType::TIME}, result_type, std::move(time_func), nullptr, nullptr, time_stats)); - return operator_set; -} - -template -static ScalarFunctionSet 
GetTimePartFunction(const LogicalType &result_type = LogicalType::BIGINT) { - return GetGenericTimePartFunction( - result_type, DatePart::UnaryFunction, DatePart::UnaryFunction, - ScalarFunction::UnaryFunction, ScalarFunction::UnaryFunction, - OP::template PropagateStatistics, OP::template PropagateStatistics, - OP::template PropagateStatistics); -} - -struct LastDayOperator { - template - static inline TR Operation(TA input) { - int32_t yyyy, mm, dd; - Date::Convert(input, yyyy, mm, dd); - yyyy += (mm / 12); - mm %= 12; - ++mm; - return Date::FromDate(yyyy, mm, 1) - 1; - } -}; - -template <> -date_t LastDayOperator::Operation(timestamp_t input) { - return LastDayOperator::Operation(Timestamp::GetDate(input)); -} - -struct MonthNameOperator { - template - static inline TR Operation(TA input) { - return Date::MONTH_NAMES[DatePart::MonthOperator::Operation(input) - 1]; - } -}; - -struct DayNameOperator { - template - static inline TR Operation(TA input) { - return Date::DAY_NAMES[DatePart::DayOfWeekOperator::Operation(input)]; - } -}; - -struct StructDatePart { - using part_codes_t = vector; - - struct BindData : public VariableReturnBindData { - part_codes_t part_codes; - - explicit BindData(const LogicalType &stype, const part_codes_t &part_codes_p) - : VariableReturnBindData(stype), part_codes(part_codes_p) { - } - - unique_ptr Copy() const override { - return make_uniq(stype, part_codes); - } - }; - - static unique_ptr Bind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // collect names and deconflict, construct return type - if (arguments[0]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[0]->IsFoldable()) { - throw BinderException("%s can only take constant lists of part names", bound_function.name); - } - - case_insensitive_set_t name_collision_set; - child_list_t struct_children; - part_codes_t part_codes; - - Value parts_list = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - if 
(parts_list.type().id() == LogicalTypeId::LIST) { - auto &list_children = ListValue::GetChildren(parts_list); - if (list_children.empty()) { - throw BinderException("%s requires non-empty lists of part names", bound_function.name); - } - for (const auto &part_value : list_children) { - if (part_value.IsNull()) { - throw BinderException("NULL struct entry name in %s", bound_function.name); - } - const auto part_name = part_value.ToString(); - const auto part_code = GetDateTypePartSpecifier(part_name, arguments[1]->return_type); - if (name_collision_set.find(part_name) != name_collision_set.end()) { - throw BinderException("Duplicate struct entry name \"%s\" in %s", part_name, bound_function.name); - } - name_collision_set.insert(part_name); - part_codes.emplace_back(part_code); - const auto part_type = IsBigintDatepart(part_code) ? LogicalType::BIGINT : LogicalType::DOUBLE; - struct_children.emplace_back(make_pair(part_name, part_type)); - } - } else { - throw BinderException("%s can only take constant lists of part names", bound_function.name); - } - - Function::EraseArgument(bound_function, arguments, 0); - bound_function.return_type = LogicalType::STRUCT(struct_children); - return make_uniq(bound_function.return_type, part_codes); - } - - template - static void Function(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - D_ASSERT(args.ColumnCount() == 1); - - const auto count = args.size(); - Vector &input = args.data[0]; - - // Type counts - const auto BIGINT_COUNT = size_t(DatePartSpecifier::BEGIN_DOUBLE) - size_t(DatePartSpecifier::BEGIN_BIGINT); - const auto DOUBLE_COUNT = size_t(DatePartSpecifier::BEGIN_INVALID) - size_t(DatePartSpecifier::BEGIN_DOUBLE); - DatePart::StructOperator::bigint_vec bigint_values(BIGINT_COUNT, nullptr); - DatePart::StructOperator::double_vec double_values(DOUBLE_COUNT, nullptr); - const auto part_mask = 
DatePart::StructOperator::GetMask(info.part_codes); - - auto &child_entries = StructVector::GetEntries(result); - - // The first computer of a part "owns" it - // and other requestors just reference the owner - vector owners(int(DatePartSpecifier::JULIAN_DAY) + 1, child_entries.size()); - for (size_t col = 0; col < child_entries.size(); ++col) { - const auto part_index = size_t(info.part_codes[col]); - if (owners[part_index] == child_entries.size()) { - owners[part_index] = col; - } - } - - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - - if (ConstantVector::IsNull(input)) { - ConstantVector::SetNull(result, true); - } else { - ConstantVector::SetNull(result, false); - for (size_t col = 0; col < child_entries.size(); ++col) { - auto &child_entry = child_entries[col]; - ConstantVector::SetNull(*child_entry, false); - const auto part_index = size_t(info.part_codes[col]); - if (owners[part_index] == col) { - if (IsBigintDatepart(info.part_codes[col])) { - bigint_values[part_index - size_t(DatePartSpecifier::BEGIN_BIGINT)] = - ConstantVector::GetData(*child_entry); - } else { - double_values[part_index - size_t(DatePartSpecifier::BEGIN_DOUBLE)] = - ConstantVector::GetData(*child_entry); - } - } - } - auto tdata = ConstantVector::GetData(input); - if (Value::IsFinite(tdata[0])) { - DatePart::StructOperator::Operation(bigint_values, double_values, tdata[0], 0, part_mask); - } else { - for (auto &child_entry : child_entries) { - ConstantVector::SetNull(*child_entry, true); - } - } - } - } else { - UnifiedVectorFormat rdata; - input.ToUnifiedFormat(count, rdata); - - const auto &arg_valid = rdata.validity; - auto tdata = UnifiedVectorFormat::GetData(rdata); - - // Start with a valid flat vector - result.SetVectorType(VectorType::FLAT_VECTOR); - auto &res_valid = FlatVector::Validity(result); - if (res_valid.GetData()) { - res_valid.SetAllValid(count); - } - - // Start with valid children - for (size_t col 
= 0; col < child_entries.size(); ++col) { - auto &child_entry = child_entries[col]; - child_entry->SetVectorType(VectorType::FLAT_VECTOR); - auto &child_validity = FlatVector::Validity(*child_entry); - if (child_validity.GetData()) { - child_validity.SetAllValid(count); - } - - // Pre-multiplex - const auto part_index = size_t(info.part_codes[col]); - if (owners[part_index] == col) { - if (IsBigintDatepart(info.part_codes[col])) { - bigint_values[part_index - size_t(DatePartSpecifier::BEGIN_BIGINT)] = - FlatVector::GetData(*child_entry); - } else { - double_values[part_index - size_t(DatePartSpecifier::BEGIN_DOUBLE)] = - FlatVector::GetData(*child_entry); - } - } - } - - for (idx_t i = 0; i < count; ++i) { - const auto idx = rdata.sel->get_index(i); - if (arg_valid.RowIsValid(idx)) { - if (Value::IsFinite(tdata[idx])) { - DatePart::StructOperator::Operation(bigint_values, double_values, tdata[idx], i, part_mask); - } else { - for (auto &child_entry : child_entries) { - FlatVector::Validity(*child_entry).SetInvalid(i); - } - } - } else { - res_valid.SetInvalid(i); - for (auto &child_entry : child_entries) { - FlatVector::Validity(*child_entry).SetInvalid(i); - } - } - } - } - - // Reference any duplicate parts - for (size_t col = 0; col < child_entries.size(); ++col) { - const auto part_index = size_t(info.part_codes[col]); - const auto owner = owners[part_index]; - if (owner != col) { - child_entries[col]->Reference(*child_entries[owner]); - } - } - - result.Verify(count); - } - - static void SerializeFunction(Serializer &serializer, const optional_ptr bind_data_p, - const ScalarFunction &function) { - D_ASSERT(bind_data_p); - auto &info = bind_data_p->Cast(); - serializer.WriteProperty(100, "stype", info.stype); - serializer.WriteProperty(101, "part_codes", info.part_codes); - } - - static unique_ptr DeserializeFunction(Deserializer &deserializer, ScalarFunction &bound_function) { - auto stype = deserializer.ReadProperty(100, "stype"); - auto part_codes = 
deserializer.ReadProperty>(101, "part_codes"); - return make_uniq(std::move(stype), std::move(part_codes)); - } - - template - static ScalarFunction GetFunction(const LogicalType &temporal_type) { - auto part_type = LogicalType::LIST(LogicalType::VARCHAR); - auto result_type = LogicalType::STRUCT({}); - ScalarFunction result({part_type, temporal_type}, result_type, Function, Bind); - result.serialize = SerializeFunction; - result.deserialize = DeserializeFunction; - return result; - } -}; - -ScalarFunctionSet YearFun::GetFunctions() { - return GetGenericDatePartFunction(LastYearFunction, LastYearFunction, - ScalarFunction::UnaryFunction, - DatePart::YearOperator::PropagateStatistics, - DatePart::YearOperator::PropagateStatistics); -} - -ScalarFunctionSet MonthFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet DayFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet DecadeFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet CenturyFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet MillenniumFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet QuarterFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet DayOfWeekFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet ISODayOfWeekFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet DayOfYearFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet WeekFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet ISOYearFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet EraFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet TimezoneFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet TimezoneHourFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet TimezoneMinuteFun::GetFunctions() { - 
return GetDatePartFunction(); -} - -ScalarFunctionSet EpochFun::GetFunctions() { - return GetTimePartFunction(LogicalType::DOUBLE); -} - -ScalarFunctionSet EpochNsFun::GetFunctions() { - using OP = DatePart::EpochNanosecondsOperator; - auto operator_set = GetTimePartFunction(); - - // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU - auto tstz_func = DatePart::UnaryFunction; - auto tstz_stats = OP::template PropagateStatistics; - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); - return operator_set; -} - -ScalarFunctionSet EpochUsFun::GetFunctions() { - using OP = DatePart::EpochMicrosecondsOperator; - auto operator_set = GetTimePartFunction(); - - // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU - auto tstz_func = DatePart::UnaryFunction; - auto tstz_stats = OP::template PropagateStatistics; - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); - return operator_set; -} - -ScalarFunctionSet EpochMsFun::GetFunctions() { - using OP = DatePart::EpochMillisOperator; - auto operator_set = GetTimePartFunction(); - - // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU - auto tstz_func = DatePart::UnaryFunction; - auto tstz_stats = OP::template PropagateStatistics; - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); - - // Legacy inverse BIGINT => TIMESTAMP - operator_set.AddFunction( - ScalarFunction({LogicalType::BIGINT}, LogicalType::TIMESTAMP, DatePart::EpochMillisOperator::Inverse)); - - return operator_set; -} - -ScalarFunctionSet MicrosecondsFun::GetFunctions() { - return GetTimePartFunction(); -} - -ScalarFunctionSet MillisecondsFun::GetFunctions() { - return 
GetTimePartFunction(); -} - -ScalarFunctionSet SecondsFun::GetFunctions() { - return GetTimePartFunction(); -} - -ScalarFunctionSet MinutesFun::GetFunctions() { - return GetTimePartFunction(); -} - -ScalarFunctionSet HoursFun::GetFunctions() { - return GetTimePartFunction(); -} - -ScalarFunctionSet YearWeekFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet DayOfMonthFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet WeekDayFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet WeekOfYearFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet LastDayFun::GetFunctions() { - ScalarFunctionSet last_day; - last_day.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::DATE, - DatePart::UnaryFunction)); - last_day.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::DATE, - DatePart::UnaryFunction)); - return last_day; -} - -ScalarFunctionSet MonthNameFun::GetFunctions() { - ScalarFunctionSet monthname; - monthname.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::VARCHAR, - DatePart::UnaryFunction)); - monthname.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::VARCHAR, - DatePart::UnaryFunction)); - return monthname; -} - -ScalarFunctionSet DayNameFun::GetFunctions() { - ScalarFunctionSet dayname; - dayname.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::VARCHAR, - DatePart::UnaryFunction)); - dayname.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::VARCHAR, - DatePart::UnaryFunction)); - return dayname; -} - -ScalarFunctionSet JulianDayFun::GetFunctions() { - using OP = DatePart::JulianDayOperator; - - ScalarFunctionSet operator_set; - auto date_func = DatePart::UnaryFunction; - auto date_stats = OP::template PropagateStatistics; - operator_set.AddFunction( - ScalarFunction({LogicalType::DATE}, LogicalType::DOUBLE, date_func, nullptr, nullptr, date_stats)); - auto ts_func = 
DatePart::UnaryFunction; - auto ts_stats = OP::template PropagateStatistics; - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::DOUBLE, ts_func, nullptr, nullptr, ts_stats)); - - return operator_set; -} - -ScalarFunctionSet DatePartFun::GetFunctions() { - ScalarFunctionSet date_part; - date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE}, LogicalType::BIGINT, - DatePartFunction, DatePartBind)); - date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::BIGINT, - DatePartFunction, DatePartBind)); - date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME}, LogicalType::BIGINT, - DatePartFunction, DatePartBind)); - date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::INTERVAL}, LogicalType::BIGINT, - DatePartFunction, DatePartBind)); - - // struct variants - date_part.AddFunction(StructDatePart::GetFunction(LogicalType::DATE)); - date_part.AddFunction(StructDatePart::GetFunction(LogicalType::TIMESTAMP)); - date_part.AddFunction(StructDatePart::GetFunction(LogicalType::TIME)); - date_part.AddFunction(StructDatePart::GetFunction(LogicalType::INTERVAL)); - - return date_part; -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -struct DateSub { - static int64_t SubtractMicros(timestamp_t startdate, timestamp_t enddate) { - const auto start = Timestamp::GetEpochMicroSeconds(startdate); - const auto end = Timestamp::GetEpochMicroSeconds(enddate); - return SubtractOperatorOverflowCheck::Operation(end, start); - } - - template - static inline void BinaryExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - BinaryExecutor::ExecuteWithNulls( - left, right, result, count, [&](TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { - if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { - return OP::template Operation(startdate, enddate); - } else { - mask.SetInvalid(idx); - return TR(); - } 
- }); - } - - struct MonthOperator { - template - static inline TR Operation(TA start_ts, TB end_ts) { - - if (start_ts > end_ts) { - return -MonthOperator::Operation(end_ts, start_ts); - } - // The number of complete months depends on whether end_ts is on the last day of the month. - date_t end_date; - dtime_t end_time; - Timestamp::Convert(end_ts, end_date, end_time); - - int32_t yyyy, mm, dd; - Date::Convert(end_date, yyyy, mm, dd); - const auto end_days = Date::MonthDays(yyyy, mm); - if (end_days == dd) { - // Now check whether the start day is after the end day - date_t start_date; - dtime_t start_time; - Timestamp::Convert(start_ts, start_date, start_time); - Date::Convert(start_date, yyyy, mm, dd); - if (dd > end_days || (dd == end_days && start_time < end_time)) { - // Move back to the same time on the last day of the (shorter) end month - start_date = Date::FromDate(yyyy, mm, end_days); - start_ts = Timestamp::FromDatetime(start_date, start_time); - } - } - - // Our interval difference will now give the correct result. - // Note that PG gives different interval subtraction results, - // so if we change this we will have to reimplement. 
- return Interval::GetAge(end_ts, start_ts).months; - } - }; - - struct QuarterOperator { - template - static inline TR Operation(TA start_ts, TB end_ts) { - return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_QUARTER; - } - }; - - struct YearOperator { - template - static inline TR Operation(TA start_ts, TB end_ts) { - return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_YEAR; - } - }; - - struct DecadeOperator { - template - static inline TR Operation(TA start_ts, TB end_ts) { - return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_DECADE; - } - }; - - struct CenturyOperator { - template - static inline TR Operation(TA start_ts, TB end_ts) { - return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_CENTURY; - } - }; - - struct MilleniumOperator { - template - static inline TR Operation(TA start_ts, TB end_ts) { - return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_MILLENIUM; - } - }; - - struct DayOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_DAY; - } - }; - - struct WeekOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_WEEK; - } - }; - - struct MicrosecondsOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate); - } - }; - - struct MillisecondsOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_MSEC; - } - }; - - struct SecondsOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_SEC; - } - }; - - struct MinutesOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, 
enddate) / Interval::MICROS_PER_MINUTE; - } - }; - - struct HoursOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_HOUR; - } - }; -}; - -// DATE specialisations -template <> -int64_t DateSub::YearOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return YearOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::MonthOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return MonthOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::DayOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return DayOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::DecadeOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return DecadeOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::CenturyOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return CenturyOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::MilleniumOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return MilleniumOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::QuarterOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return QuarterOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::WeekOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return 
WeekOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::MicrosecondsOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return MicrosecondsOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::MillisecondsOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return MillisecondsOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::SecondsOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return SecondsOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::MinutesOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return MinutesOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::HoursOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return HoursOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -// TIME specialisations -template <> -int64_t DateSub::YearOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"year\" not recognized"); -} - -template <> -int64_t DateSub::MonthOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"month\" not recognized"); -} - -template <> -int64_t DateSub::DayOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"day\" not recognized"); -} - -template <> -int64_t DateSub::DecadeOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"decade\" not 
recognized"); -} - -template <> -int64_t DateSub::CenturyOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"century\" not recognized"); -} - -template <> -int64_t DateSub::MilleniumOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"millennium\" not recognized"); -} - -template <> -int64_t DateSub::QuarterOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"quarter\" not recognized"); -} - -template <> -int64_t DateSub::WeekOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"week\" not recognized"); -} - -template <> -int64_t DateSub::MicrosecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { - return enddate.micros - startdate.micros; -} - -template <> -int64_t DateSub::MillisecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { - return (enddate.micros - startdate.micros) / Interval::MICROS_PER_MSEC; -} - -template <> -int64_t DateSub::SecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { - return (enddate.micros - startdate.micros) / Interval::MICROS_PER_SEC; -} - -template <> -int64_t DateSub::MinutesOperator::Operation(dtime_t startdate, dtime_t enddate) { - return (enddate.micros - startdate.micros) / Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DateSub::HoursOperator::Operation(dtime_t startdate, dtime_t enddate) { - return (enddate.micros - startdate.micros) / Interval::MICROS_PER_HOUR; -} - -template -static int64_t SubtractDateParts(DatePartSpecifier type, TA startdate, TB enddate) { - switch (type) { - case DatePartSpecifier::YEAR: - case DatePartSpecifier::ISOYEAR: - return DateSub::YearOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MONTH: - return DateSub::MonthOperator::template Operation(startdate, enddate); - case DatePartSpecifier::DAY: - case 
DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - return DateSub::DayOperator::template Operation(startdate, enddate); - case DatePartSpecifier::DECADE: - return DateSub::DecadeOperator::template Operation(startdate, enddate); - case DatePartSpecifier::CENTURY: - return DateSub::CenturyOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MILLENNIUM: - return DateSub::MilleniumOperator::template Operation(startdate, enddate); - case DatePartSpecifier::QUARTER: - return DateSub::QuarterOperator::template Operation(startdate, enddate); - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - return DateSub::WeekOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MICROSECONDS: - return DateSub::MicrosecondsOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MILLISECONDS: - return DateSub::MillisecondsOperator::template Operation(startdate, enddate); - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - return DateSub::SecondsOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MINUTE: - return DateSub::MinutesOperator::template Operation(startdate, enddate); - case DatePartSpecifier::HOUR: - return DateSub::HoursOperator::template Operation(startdate, enddate); - default: - throw NotImplementedException("Specifier type not implemented for DATESUB"); - } -} - -struct DateSubTernaryOperator { - template - static inline TR Operation(TS part, TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { - if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { - return SubtractDateParts(GetDatePartSpecifier(part.GetString()), startdate, enddate); - } else { - mask.SetInvalid(idx); - return TR(); - } - } -}; - -template -static void DateSubBinaryExecutor(DatePartSpecifier type, Vector &left, Vector &right, Vector &result, idx_t count) { - switch (type) { - case 
DatePartSpecifier::YEAR: - case DatePartSpecifier::ISOYEAR: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MONTH: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::DECADE: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::CENTURY: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MILLENNIUM: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::QUARTER: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MICROSECONDS: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MILLISECONDS: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MINUTE: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::HOUR: - DateSub::BinaryExecute(left, right, result, count); - break; - default: - throw NotImplementedException("Specifier type not implemented for DATESUB"); - } -} - -template -static void DateSubFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 3); - auto &part_arg = args.data[0]; - auto &start_arg = args.data[1]; - auto &end_arg = args.data[2]; - - if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // Common case of constant part. 
- if (ConstantVector::IsNull(part_arg)) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); - DateSubBinaryExecutor(type, start_arg, end_arg, result, args.size()); - } - } else { - TernaryExecutor::ExecuteWithNulls( - part_arg, start_arg, end_arg, result, args.size(), - DateSubTernaryOperator::Operation); - } -} - -ScalarFunctionSet DateSubFun::GetFunctions() { - ScalarFunctionSet date_sub("date_sub"); - date_sub.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE, LogicalType::DATE}, - LogicalType::BIGINT, DateSubFunction)); - date_sub.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, - LogicalType::BIGINT, DateSubFunction)); - date_sub.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME, LogicalType::TIME}, - LogicalType::BIGINT, DateSubFunction)); - return date_sub; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -struct DateTrunc { - template - static inline TR UnaryFunction(TA input) { - if (Value::IsFinite(input)) { - return OP::template Operation(input); - } else { - return Cast::template Operation(input); - } - } - - template - static inline void UnaryExecute(Vector &left, Vector &result, idx_t count) { - UnaryExecutor::Execute(left, result, count, UnaryFunction); - } - - struct MillenniumOperator { - template - static inline TR Operation(TA input) { - return Date::FromDate((Date::ExtractYear(input) / 1000) * 1000, 1, 1); - } - }; - - struct CenturyOperator { - template - static inline TR Operation(TA input) { - return Date::FromDate((Date::ExtractYear(input) / 100) * 100, 1, 1); - } - }; - - struct DecadeOperator { - template - static inline TR Operation(TA input) { - return Date::FromDate((Date::ExtractYear(input) / 10) * 10, 1, 1); - } - }; - - struct YearOperator { - template - static inline TR Operation(TA 
input) { - return Date::FromDate(Date::ExtractYear(input), 1, 1); - } - }; - - struct QuarterOperator { - template - static inline TR Operation(TA input) { - int32_t yyyy, mm, dd; - Date::Convert(input, yyyy, mm, dd); - mm = 1 + (((mm - 1) / 3) * 3); - return Date::FromDate(yyyy, mm, 1); - } - }; - - struct MonthOperator { - template - static inline TR Operation(TA input) { - return Date::FromDate(Date::ExtractYear(input), Date::ExtractMonth(input), 1); - } - }; - - struct WeekOperator { - template - static inline TR Operation(TA input) { - return Date::GetMondayOfCurrentWeek(input); - } - }; - - struct ISOYearOperator { - template - static inline TR Operation(TA input) { - date_t date = Date::GetMondayOfCurrentWeek(input); - date.days -= (Date::ExtractISOWeekNumber(date) - 1) * Interval::DAYS_PER_WEEK; - - return date; - } - }; - - struct DayOperator { - template - static inline TR Operation(TA input) { - return input; - } - }; - - struct HourOperator { - template - static inline TR Operation(TA input) { - int32_t hour, min, sec, micros; - date_t date; - dtime_t time; - Timestamp::Convert(input, date, time); - Time::Convert(time, hour, min, sec, micros); - return Timestamp::FromDatetime(date, Time::FromTime(hour, 0, 0, 0)); - } - }; - - struct MinuteOperator { - template - static inline TR Operation(TA input) { - int32_t hour, min, sec, micros; - date_t date; - dtime_t time; - Timestamp::Convert(input, date, time); - Time::Convert(time, hour, min, sec, micros); - return Timestamp::FromDatetime(date, Time::FromTime(hour, min, 0, 0)); - } - }; - - struct SecondOperator { - template - static inline TR Operation(TA input) { - int32_t hour, min, sec, micros; - date_t date; - dtime_t time; - Timestamp::Convert(input, date, time); - Time::Convert(time, hour, min, sec, micros); - return Timestamp::FromDatetime(date, Time::FromTime(hour, min, sec, 0)); - } - }; - - struct MillisecondOperator { - template - static inline TR Operation(TA input) { - int32_t hour, min, sec, 
micros; - date_t date; - dtime_t time; - Timestamp::Convert(input, date, time); - Time::Convert(time, hour, min, sec, micros); - micros -= micros % Interval::MICROS_PER_MSEC; - return Timestamp::FromDatetime(date, Time::FromTime(hour, min, sec, micros)); - } - }; - - struct MicrosecondOperator { - template - static inline TR Operation(TA input) { - return input; - } - }; -}; - -// DATE specialisations -template <> -date_t DateTrunc::MillenniumOperator::Operation(timestamp_t input) { - return MillenniumOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::MillenniumOperator::Operation(date_t input) { - return Timestamp::FromDatetime(MillenniumOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::MillenniumOperator::Operation(timestamp_t input) { - return MillenniumOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::CenturyOperator::Operation(timestamp_t input) { - return CenturyOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::CenturyOperator::Operation(date_t input) { - return Timestamp::FromDatetime(CenturyOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::CenturyOperator::Operation(timestamp_t input) { - return CenturyOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::DecadeOperator::Operation(timestamp_t input) { - return DecadeOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::DecadeOperator::Operation(date_t input) { - return Timestamp::FromDatetime(DecadeOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::DecadeOperator::Operation(timestamp_t input) { - return DecadeOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::YearOperator::Operation(timestamp_t input) { - return YearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t 
DateTrunc::YearOperator::Operation(date_t input) { - return Timestamp::FromDatetime(YearOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::YearOperator::Operation(timestamp_t input) { - return YearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::QuarterOperator::Operation(timestamp_t input) { - return QuarterOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::QuarterOperator::Operation(date_t input) { - return Timestamp::FromDatetime(QuarterOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::QuarterOperator::Operation(timestamp_t input) { - return QuarterOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::MonthOperator::Operation(timestamp_t input) { - return MonthOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::MonthOperator::Operation(date_t input) { - return Timestamp::FromDatetime(MonthOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::MonthOperator::Operation(timestamp_t input) { - return MonthOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::WeekOperator::Operation(timestamp_t input) { - return WeekOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::WeekOperator::Operation(date_t input) { - return Timestamp::FromDatetime(WeekOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::WeekOperator::Operation(timestamp_t input) { - return WeekOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::ISOYearOperator::Operation(timestamp_t input) { - return ISOYearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::ISOYearOperator::Operation(date_t input) { - return Timestamp::FromDatetime(ISOYearOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t 
DateTrunc::ISOYearOperator::Operation(timestamp_t input) { - return ISOYearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::DayOperator::Operation(timestamp_t input) { - return DayOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::DayOperator::Operation(date_t input) { - return Timestamp::FromDatetime(DayOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::DayOperator::Operation(timestamp_t input) { - return DayOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::HourOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -timestamp_t DateTrunc::HourOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -date_t DateTrunc::HourOperator::Operation(timestamp_t input) { - return Timestamp::GetDate(HourOperator::Operation(input)); -} - -template <> -date_t DateTrunc::MinuteOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -timestamp_t DateTrunc::MinuteOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -date_t DateTrunc::MinuteOperator::Operation(timestamp_t input) { - return Timestamp::GetDate(HourOperator::Operation(input)); -} - -template <> -date_t DateTrunc::SecondOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -timestamp_t DateTrunc::SecondOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -date_t DateTrunc::SecondOperator::Operation(timestamp_t input) { - return Timestamp::GetDate(DayOperator::Operation(input)); -} - -template <> -date_t DateTrunc::MillisecondOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -timestamp_t DateTrunc::MillisecondOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> 
-date_t DateTrunc::MillisecondOperator::Operation(timestamp_t input) { - return Timestamp::GetDate(MillisecondOperator::Operation(input)); -} - -template <> -date_t DateTrunc::MicrosecondOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -timestamp_t DateTrunc::MicrosecondOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -date_t DateTrunc::MicrosecondOperator::Operation(timestamp_t input) { - return Timestamp::GetDate(MicrosecondOperator::Operation(input)); -} - -// INTERVAL specialisations -template <> -interval_t DateTrunc::MillenniumOperator::Operation(interval_t input) { - input.days = 0; - input.micros = 0; - input.months = (input.months / Interval::MONTHS_PER_MILLENIUM) * Interval::MONTHS_PER_MILLENIUM; - return input; -} - -template <> -interval_t DateTrunc::CenturyOperator::Operation(interval_t input) { - input.days = 0; - input.micros = 0; - input.months = (input.months / Interval::MONTHS_PER_CENTURY) * Interval::MONTHS_PER_CENTURY; - return input; -} - -template <> -interval_t DateTrunc::DecadeOperator::Operation(interval_t input) { - input.days = 0; - input.micros = 0; - input.months = (input.months / Interval::MONTHS_PER_DECADE) * Interval::MONTHS_PER_DECADE; - return input; -} - -template <> -interval_t DateTrunc::YearOperator::Operation(interval_t input) { - input.days = 0; - input.micros = 0; - input.months = (input.months / Interval::MONTHS_PER_YEAR) * Interval::MONTHS_PER_YEAR; - return input; -} - -template <> -interval_t DateTrunc::QuarterOperator::Operation(interval_t input) { - input.days = 0; - input.micros = 0; - input.months = (input.months / Interval::MONTHS_PER_QUARTER) * Interval::MONTHS_PER_QUARTER; - return input; -} - -template <> -interval_t DateTrunc::MonthOperator::Operation(interval_t input) { - input.days = 0; - input.micros = 0; - return input; -} - -template <> -interval_t DateTrunc::WeekOperator::Operation(interval_t input) { - input.micros = 
0; - input.days = (input.days / Interval::DAYS_PER_WEEK) * Interval::DAYS_PER_WEEK; - return input; -} - -template <> -interval_t DateTrunc::ISOYearOperator::Operation(interval_t input) { - return YearOperator::Operation(input); -} - -template <> -interval_t DateTrunc::DayOperator::Operation(interval_t input) { - input.micros = 0; - return input; -} - -template <> -interval_t DateTrunc::HourOperator::Operation(interval_t input) { - input.micros = (input.micros / Interval::MICROS_PER_HOUR) * Interval::MICROS_PER_HOUR; - return input; -} - -template <> -interval_t DateTrunc::MinuteOperator::Operation(interval_t input) { - input.micros = (input.micros / Interval::MICROS_PER_MINUTE) * Interval::MICROS_PER_MINUTE; - return input; -} - -template <> -interval_t DateTrunc::SecondOperator::Operation(interval_t input) { - input.micros = (input.micros / Interval::MICROS_PER_SEC) * Interval::MICROS_PER_SEC; - return input; -} - -template <> -interval_t DateTrunc::MillisecondOperator::Operation(interval_t input) { - input.micros = (input.micros / Interval::MICROS_PER_MSEC) * Interval::MICROS_PER_MSEC; - return input; -} - -template <> -interval_t DateTrunc::MicrosecondOperator::Operation(interval_t input) { - return input; -} - -template -static TR TruncateElement(DatePartSpecifier type, TA element) { - if (!Value::IsFinite(element)) { - return Cast::template Operation(element); - } - - switch (type) { - case DatePartSpecifier::MILLENNIUM: - return DateTrunc::MillenniumOperator::Operation(element); - case DatePartSpecifier::CENTURY: - return DateTrunc::CenturyOperator::Operation(element); - case DatePartSpecifier::DECADE: - return DateTrunc::DecadeOperator::Operation(element); - case DatePartSpecifier::YEAR: - return DateTrunc::YearOperator::Operation(element); - case DatePartSpecifier::QUARTER: - return DateTrunc::QuarterOperator::Operation(element); - case DatePartSpecifier::MONTH: - return DateTrunc::MonthOperator::Operation(element); - case DatePartSpecifier::WEEK: - case 
DatePartSpecifier::YEARWEEK: - return DateTrunc::WeekOperator::Operation(element); - case DatePartSpecifier::ISOYEAR: - return DateTrunc::ISOYearOperator::Operation(element); - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - return DateTrunc::DayOperator::Operation(element); - case DatePartSpecifier::HOUR: - return DateTrunc::HourOperator::Operation(element); - case DatePartSpecifier::MINUTE: - return DateTrunc::MinuteOperator::Operation(element); - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - return DateTrunc::SecondOperator::Operation(element); - case DatePartSpecifier::MILLISECONDS: - return DateTrunc::MillisecondOperator::Operation(element); - case DatePartSpecifier::MICROSECONDS: - return DateTrunc::MicrosecondOperator::Operation(element); - default: - throw NotImplementedException("Specifier type not implemented for DATETRUNC"); - } -} - -struct DateTruncBinaryOperator { - template - static inline TR Operation(TA specifier, TB date) { - return TruncateElement(GetDatePartSpecifier(specifier.GetString()), date); - } -}; - -template -static void DateTruncUnaryExecutor(DatePartSpecifier type, Vector &left, Vector &result, idx_t count) { - switch (type) { - case DatePartSpecifier::MILLENNIUM: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::CENTURY: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::DECADE: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::YEAR: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::QUARTER: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::MONTH: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - DateTrunc::UnaryExecute(left, result, count); - break; - case 
DatePartSpecifier::ISOYEAR: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::HOUR: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::MINUTE: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::MILLISECONDS: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::MICROSECONDS: - DateTrunc::UnaryExecute(left, result, count); - break; - default: - throw NotImplementedException("Specifier type not implemented for DATETRUNC"); - } -} - -template -static void DateTruncFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - auto &part_arg = args.data[0]; - auto &date_arg = args.data[1]; - - if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // Common case of constant part. 
- if (ConstantVector::IsNull(part_arg)) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); - DateTruncUnaryExecutor(type, date_arg, result, args.size()); - } - } else { - BinaryExecutor::ExecuteStandard(part_arg, date_arg, result, - args.size()); - } -} - -template -static unique_ptr DateTruncStatistics(vector &child_stats) { - // we can only propagate date stats if the child has stats - auto &nstats = child_stats[1]; - if (!NumericStats::HasMinMax(nstats)) { - return nullptr; - } - // run the operator on both the min and the max, this gives us the [min, max] bound - auto min = NumericStats::GetMin(nstats); - auto max = NumericStats::GetMax(nstats); - if (min > max) { - return nullptr; - } - - // Infinite values are unmodified - auto min_part = DateTrunc::UnaryFunction(min); - auto max_part = DateTrunc::UnaryFunction(max); - - auto min_value = Value::CreateValue(min_part); - auto max_value = Value::CreateValue(max_part); - auto result = NumericStats::CreateEmpty(min_value.type()); - NumericStats::SetMin(result, min_value); - NumericStats::SetMax(result, max_value); - result.CopyValidity(child_stats[0]); - return result.ToUnique(); -} - -template -static unique_ptr PropagateDateTruncStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return DateTruncStatistics(input.child_stats); -} - -template -static function_statistics_t DateTruncStats(DatePartSpecifier type) { - switch (type) { - case DatePartSpecifier::MILLENNIUM: - return PropagateDateTruncStatistics; - case DatePartSpecifier::CENTURY: - return PropagateDateTruncStatistics; - case DatePartSpecifier::DECADE: - return PropagateDateTruncStatistics; - case DatePartSpecifier::YEAR: - return PropagateDateTruncStatistics; - case DatePartSpecifier::QUARTER: - return PropagateDateTruncStatistics; - case DatePartSpecifier::MONTH: - return 
PropagateDateTruncStatistics; - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - return PropagateDateTruncStatistics; - case DatePartSpecifier::ISOYEAR: - return PropagateDateTruncStatistics; - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - return PropagateDateTruncStatistics; - case DatePartSpecifier::HOUR: - return PropagateDateTruncStatistics; - case DatePartSpecifier::MINUTE: - return PropagateDateTruncStatistics; - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - return PropagateDateTruncStatistics; - case DatePartSpecifier::MILLISECONDS: - return PropagateDateTruncStatistics; - case DatePartSpecifier::MICROSECONDS: - return PropagateDateTruncStatistics; - default: - throw NotImplementedException("Specifier type not implemented for DATETRUNC statistics"); - } -} - -static unique_ptr DateTruncBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (!arguments[0]->IsFoldable()) { - return nullptr; - } - - // Rebind to return a date if we are truncating that far - Value part_value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - if (part_value.IsNull()) { - return nullptr; - } - const auto part_name = part_value.ToString(); - const auto part_code = GetDatePartSpecifier(part_name); - switch (part_code) { - case DatePartSpecifier::MILLENNIUM: - case DatePartSpecifier::CENTURY: - case DatePartSpecifier::DECADE: - case DatePartSpecifier::YEAR: - case DatePartSpecifier::QUARTER: - case DatePartSpecifier::MONTH: - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - case DatePartSpecifier::ISOYEAR: - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - switch (bound_function.arguments[1].id()) { - case LogicalType::TIMESTAMP: - 
bound_function.function = DateTruncFunction; - bound_function.statistics = DateTruncStats(part_code); - break; - case LogicalType::DATE: - bound_function.function = DateTruncFunction; - bound_function.statistics = DateTruncStats(part_code); - break; - default: - throw NotImplementedException("Temporal argument type for DATETRUNC"); - } - bound_function.return_type = LogicalType::DATE; - break; - default: - switch (bound_function.arguments[1].id()) { - case LogicalType::TIMESTAMP: - bound_function.statistics = DateTruncStats(part_code); - break; - case LogicalType::DATE: - bound_function.statistics = DateTruncStats(part_code); - break; - default: - throw NotImplementedException("Temporal argument type for DATETRUNC"); - } - break; - } - - return nullptr; -} - -ScalarFunctionSet DateTruncFun::GetFunctions() { - ScalarFunctionSet date_trunc("date_trunc"); - date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP, - DateTruncFunction, DateTruncBind)); - date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE}, LogicalType::TIMESTAMP, - DateTruncFunction, DateTruncBind)); - date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::INTERVAL}, LogicalType::INTERVAL, - DateTruncFunction)); - return date_trunc; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct EpochSecOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE sec) { - int64_t result; - if (!TryCast::Operation(sec * Interval::MICROS_PER_SEC, result)) { - throw ConversionException("Could not convert epoch seconds to TIMESTAMP WITH TIME ZONE"); - } - return timestamp_t(result); - } -}; - -static void EpochSecFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 1); - - UnaryExecutor::Execute(input.data[0], result, input.size()); -} - -ScalarFunction ToTimestampFun::GetFunction() { - // to_timestamp is an alias from Postgres that converts the time 
in seconds to a timestamp - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::TIMESTAMP_TZ, EpochSecFunction); -} - -} // namespace duckdb - - - - - - - -#include - -namespace duckdb { - -struct MakeDateOperator { - template - static RESULT_TYPE Operation(YYYY yyyy, MM mm, DD dd) { - return Date::FromDate(yyyy, mm, dd); - } -}; - -template -static void ExecuteMakeDate(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 3); - auto &yyyy = input.data[0]; - auto &mm = input.data[1]; - auto &dd = input.data[2]; - - TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), - MakeDateOperator::Operation); -} - -template -static void ExecuteStructMakeDate(DataChunk &input, ExpressionState &state, Vector &result) { - // this should be guaranteed by the binder - D_ASSERT(input.ColumnCount() == 1); - auto &vec = input.data[0]; - - auto &children = StructVector::GetEntries(vec); - D_ASSERT(children.size() == 3); - auto &yyyy = *children[0]; - auto &mm = *children[1]; - auto &dd = *children[2]; - - TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), Date::FromDate); -} - -struct MakeTimeOperator { - template - static RESULT_TYPE Operation(HH hh, MM mm, SS ss) { - int64_t secs = ss; - int64_t micros = std::round((ss - secs) * Interval::MICROS_PER_SEC); - if (!Time::IsValidTime(hh, mm, secs, micros)) { - throw ConversionException("Time out of range: %d:%d:%d.%d", hh, mm, secs, micros); - } - return Time::FromTime(hh, mm, secs, micros); - } -}; - -template -static void ExecuteMakeTime(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 3); - auto &yyyy = input.data[0]; - auto &mm = input.data[1]; - auto &dd = input.data[2]; - - TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), - MakeTimeOperator::Operation); -} - -struct MakeTimestampOperator { - template - static RESULT_TYPE Operation(YYYY yyyy, MM mm, DD dd, HR hr, MN mn, SS ss) { - const auto d = 
MakeDateOperator::Operation(yyyy, mm, dd); - const auto t = MakeTimeOperator::Operation(hr, mn, ss); - return Timestamp::FromDatetime(d, t); - } - - template - static RESULT_TYPE Operation(T micros) { - return timestamp_t(micros); - } -}; - -template -static void ExecuteMakeTimestamp(DataChunk &input, ExpressionState &state, Vector &result) { - if (input.ColumnCount() == 1) { - auto func = MakeTimestampOperator::Operation; - UnaryExecutor::Execute(input.data[0], result, input.size(), func); - return; - } - - D_ASSERT(input.ColumnCount() == 6); - - auto func = MakeTimestampOperator::Operation; - SenaryExecutor::Execute(input, result, func); -} - -ScalarFunctionSet MakeDateFun::GetFunctions() { - ScalarFunctionSet make_date("make_date"); - make_date.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::DATE, ExecuteMakeDate)); - - child_list_t make_date_children { - {"year", LogicalType::BIGINT}, {"month", LogicalType::BIGINT}, {"day", LogicalType::BIGINT}}; - make_date.AddFunction( - ScalarFunction({LogicalType::STRUCT(make_date_children)}, LogicalType::DATE, ExecuteStructMakeDate)); - return make_date; -} - -ScalarFunction MakeTimeFun::GetFunction() { - return ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::DOUBLE}, LogicalType::TIME, - ExecuteMakeTime); -} - -ScalarFunctionSet MakeTimestampFun::GetFunctions() { - ScalarFunctionSet operator_set("make_timestamp"); - operator_set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, - LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::DOUBLE}, - LogicalType::TIMESTAMP, ExecuteMakeTimestamp)); - operator_set.AddFunction( - ScalarFunction({LogicalType::BIGINT}, LogicalType::TIMESTAMP, ExecuteMakeTimestamp)); - return operator_set; -} - -} // namespace duckdb - - - - - - - - -#include -#include - -namespace duckdb { - -struct StrfTimeBindData : public FunctionData { - explicit 
StrfTimeBindData(StrfTimeFormat format_p, string format_string_p, bool is_null) - : format(std::move(format_p)), format_string(std::move(format_string_p)), is_null(is_null) { - } - - StrfTimeFormat format; - string format_string; - bool is_null; - - unique_ptr Copy() const override { - return make_uniq(format, format_string, is_null); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return format_string == other.format_string; - } -}; - -template -static unique_ptr StrfTimeBindFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto format_idx = REVERSED ? 0 : 1; - auto &format_arg = arguments[format_idx]; - if (format_arg->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!format_arg->IsFoldable()) { - throw InvalidInputException("strftime format must be a constant"); - } - Value options_str = ExpressionExecutor::EvaluateScalar(context, *format_arg); - auto format_string = options_str.GetValue(); - StrfTimeFormat format; - bool is_null = options_str.IsNull(); - if (!is_null) { - string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); - if (!error.empty()) { - throw InvalidInputException("Failed to parse format specifier %s: %s", format_string, error); - } - } - return make_uniq(format, format_string, is_null); -} - -template -static void StrfTimeFunctionDate(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - if (info.is_null) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - info.format.ConvertDateVector(args.data[REVERSED ? 
1 : 0], result, args.size()); -} - -template -static void StrfTimeFunctionTimestamp(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - if (info.is_null) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - info.format.ConvertTimestampVector(args.data[REVERSED ? 1 : 0], result, args.size()); -} - -ScalarFunctionSet StrfTimeFun::GetFunctions() { - ScalarFunctionSet strftime; - - strftime.AddFunction(ScalarFunction({LogicalType::DATE, LogicalType::VARCHAR}, LogicalType::VARCHAR, - StrfTimeFunctionDate, StrfTimeBindFunction)); - strftime.AddFunction(ScalarFunction({LogicalType::TIMESTAMP, LogicalType::VARCHAR}, LogicalType::VARCHAR, - StrfTimeFunctionTimestamp, StrfTimeBindFunction)); - strftime.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE}, LogicalType::VARCHAR, - StrfTimeFunctionDate, StrfTimeBindFunction)); - strftime.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::VARCHAR, - StrfTimeFunctionTimestamp, StrfTimeBindFunction)); - return strftime; -} - -StrpTimeFormat::StrpTimeFormat() { -} - -StrpTimeFormat::StrpTimeFormat(const string &format_string) { - if (format_string.empty()) { - return; - } - StrTimeFormat::ParseFormatSpecifier(format_string, *this); -} - -struct StrpTimeBindData : public FunctionData { - StrpTimeBindData(const StrpTimeFormat &format, const string &format_string) - : formats(1, format), format_strings(1, format_string) { - } - - StrpTimeBindData(vector formats_p, vector format_strings_p) - : formats(std::move(formats_p)), format_strings(std::move(format_strings_p)) { - } - - vector formats; - vector format_strings; - - unique_ptr Copy() const override { - return make_uniq(formats, format_strings); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return format_strings == 
other.format_strings; - } -}; - -static unique_ptr StrpTimeBindFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[1]->IsFoldable()) { - throw InvalidInputException("strptime format must be a constant"); - } - Value format_value = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - string format_string; - StrpTimeFormat format; - if (format_value.IsNull()) { - return make_uniq(format, format_string); - } else if (format_value.type().id() == LogicalTypeId::VARCHAR) { - format_string = format_value.ToString(); - format.format_specifier = format_string; - string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); - if (!error.empty()) { - throw InvalidInputException("Failed to parse format specifier %s: %s", format_string, error); - } - if (format.HasFormatSpecifier(StrTimeSpecifier::UTC_OFFSET)) { - bound_function.return_type = LogicalType::TIMESTAMP_TZ; - } - return make_uniq(format, format_string); - } else if (format_value.type() == LogicalType::LIST(LogicalType::VARCHAR)) { - const auto &children = ListValue::GetChildren(format_value); - if (children.empty()) { - throw InvalidInputException("strptime format list must not be empty"); - } - vector format_strings; - vector formats; - for (const auto &child : children) { - format_string = child.ToString(); - format.format_specifier = format_string; - string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); - if (!error.empty()) { - throw InvalidInputException("Failed to parse format specifier %s: %s", format_string, error); - } - // If any format has UTC offsets, then we have to produce TSTZ - if (format.HasFormatSpecifier(StrTimeSpecifier::UTC_OFFSET)) { - bound_function.return_type = LogicalType::TIMESTAMP_TZ; - } - format_strings.emplace_back(format_string); - formats.emplace_back(format); - } - return make_uniq(formats, 
format_strings); - } else { - throw InvalidInputException("strptime format must be a string"); - } -} - -struct StrpTimeFunction { - - static void Parse(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - if (args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR && ConstantVector::IsNull(args.data[1])) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { - StrpTimeFormat::ParseResult result; - for (auto &format : info.formats) { - if (format.Parse(input, result)) { - return result.ToTimestamp(); - } - } - throw InvalidInputException(result.FormatError(input, info.formats[0].format_specifier)); - }); - } - - static void TryParse(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - if (args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR && ConstantVector::IsNull(args.data[1])) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - - UnaryExecutor::ExecuteWithNulls( - args.data[0], result, args.size(), [&](string_t input, ValidityMask &mask, idx_t idx) { - timestamp_t result; - string error; - for (auto &format : info.formats) { - if (format.TryParseTimestamp(input, result, error)) { - return result; - } - } - - mask.SetInvalid(idx); - return timestamp_t(); - }); - } -}; - -ScalarFunctionSet StrpTimeFun::GetFunctions() { - ScalarFunctionSet strptime; - - const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); - auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::TIMESTAMP, - StrpTimeFunction::Parse, StrpTimeBindFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - strptime.AddFunction(fun); - - fun = 
ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, StrpTimeFunction::Parse, - StrpTimeBindFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - strptime.AddFunction(fun); - return strptime; -} - -ScalarFunctionSet TryStrpTimeFun::GetFunctions() { - ScalarFunctionSet try_strptime; - - const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); - auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::TIMESTAMP, - StrpTimeFunction::TryParse, StrpTimeBindFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - try_strptime.AddFunction(fun); - - fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, StrpTimeFunction::TryParse, - StrpTimeBindFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - try_strptime.AddFunction(fun); - - return try_strptime; -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -struct TimeBucket { - - // Use 2000-01-03 00:00:00 (Monday) as origin when bucket_width is days, hours, ... for TimescaleDB compatibility - // There are 10959 days between 1970-01-01 and 2000-01-03 - constexpr static const int64_t DEFAULT_ORIGIN_MICROS = 10959 * Interval::MICROS_PER_DAY; - // Use 2000-01-01 as origin when bucket_width is months, years, ... 
for TimescaleDB compatibility - // There are 360 months between 1970-01-01 and 2000-01-01 - constexpr static const int32_t DEFAULT_ORIGIN_MONTHS = 360; - - enum struct BucketWidthType { CONVERTIBLE_TO_MICROS, CONVERTIBLE_TO_MONTHS, UNCLASSIFIED }; - - static inline BucketWidthType ClassifyBucketWidth(const interval_t bucket_width) { - if (bucket_width.months == 0 && Interval::GetMicro(bucket_width) > 0) { - return BucketWidthType::CONVERTIBLE_TO_MICROS; - } else if (bucket_width.months > 0 && bucket_width.days == 0 && bucket_width.micros == 0) { - return BucketWidthType::CONVERTIBLE_TO_MONTHS; - } else { - return BucketWidthType::UNCLASSIFIED; - } - } - - static inline BucketWidthType ClassifyBucketWidthErrorThrow(const interval_t bucket_width) { - if (bucket_width.months == 0) { - int64_t bucket_width_micros = Interval::GetMicro(bucket_width); - if (bucket_width_micros <= 0) { - throw NotImplementedException("Period must be greater than 0"); - } - return BucketWidthType::CONVERTIBLE_TO_MICROS; - } else if (bucket_width.months != 0 && bucket_width.days == 0 && bucket_width.micros == 0) { - if (bucket_width.months < 0) { - throw NotImplementedException("Period must be greater than 0"); - } - return BucketWidthType::CONVERTIBLE_TO_MONTHS; - } else { - throw NotImplementedException("Month intervals cannot have day or time component"); - } - } - - template - static inline int32_t EpochMonths(T ts) { - date_t ts_date = Cast::template Operation(ts); - return (Date::ExtractYear(ts_date) - 1970) * 12 + Date::ExtractMonth(ts_date) - 1; - } - - static inline timestamp_t WidthConvertibleToMicrosCommon(int64_t bucket_width_micros, int64_t ts_micros, - int64_t origin_micros) { - origin_micros %= bucket_width_micros; - ts_micros = SubtractOperatorOverflowCheck::Operation(ts_micros, origin_micros); - - int64_t result_micros = (ts_micros / bucket_width_micros) * bucket_width_micros; - if (ts_micros < 0 && ts_micros % bucket_width_micros != 0) { - result_micros = - 
SubtractOperatorOverflowCheck::Operation(result_micros, bucket_width_micros); - } - result_micros += origin_micros; - - return Timestamp::FromEpochMicroSeconds(result_micros); - } - - static inline date_t WidthConvertibleToMonthsCommon(int32_t bucket_width_months, int32_t ts_months, - int32_t origin_months) { - origin_months %= bucket_width_months; - ts_months = SubtractOperatorOverflowCheck::Operation(ts_months, origin_months); - - int32_t result_months = (ts_months / bucket_width_months) * bucket_width_months; - if (ts_months < 0 && ts_months % bucket_width_months != 0) { - result_months = - SubtractOperatorOverflowCheck::Operation(result_months, bucket_width_months); - } - result_months += origin_months; - - int32_t year = - (result_months < 0 && result_months % 12 != 0) ? 1970 + result_months / 12 - 1 : 1970 + result_months / 12; - int32_t month = - (result_months < 0 && result_months % 12 != 0) ? result_months % 12 + 13 : result_months % 12 + 1; - - return Date::FromDate(year, month, 1); - } - - struct WidthConvertibleToMicrosBinaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts) { - if (!Value::IsFinite(ts)) { - return Cast::template Operation(ts); - } - int64_t bucket_width_micros = Interval::GetMicro(bucket_width); - int64_t ts_micros = Timestamp::GetEpochMicroSeconds(Cast::template Operation(ts)); - return Cast::template Operation( - WidthConvertibleToMicrosCommon(bucket_width_micros, ts_micros, DEFAULT_ORIGIN_MICROS)); - } - }; - - struct WidthConvertibleToMonthsBinaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts) { - if (!Value::IsFinite(ts)) { - return Cast::template Operation(ts); - } - int32_t ts_months = EpochMonths(ts); - return Cast::template Operation( - WidthConvertibleToMonthsCommon(bucket_width.months, ts_months, DEFAULT_ORIGIN_MONTHS)); - } - }; - - struct BinaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts) { - BucketWidthType bucket_width_type = 
ClassifyBucketWidthErrorThrow(bucket_width); - switch (bucket_width_type) { - case BucketWidthType::CONVERTIBLE_TO_MICROS: - return WidthConvertibleToMicrosBinaryOperator::Operation(bucket_width, ts); - case BucketWidthType::CONVERTIBLE_TO_MONTHS: - return WidthConvertibleToMonthsBinaryOperator::Operation(bucket_width, ts); - default: - throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); - } - } - }; - - struct OffsetWidthConvertibleToMicrosTernaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts, TC offset) { - if (!Value::IsFinite(ts)) { - return Cast::template Operation(ts); - } - int64_t bucket_width_micros = Interval::GetMicro(bucket_width); - int64_t ts_micros = Timestamp::GetEpochMicroSeconds( - Interval::Add(Cast::template Operation(ts), Interval::Invert(offset))); - return Cast::template Operation(Interval::Add( - WidthConvertibleToMicrosCommon(bucket_width_micros, ts_micros, DEFAULT_ORIGIN_MICROS), offset)); - } - }; - - struct OffsetWidthConvertibleToMonthsTernaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts, TC offset) { - if (!Value::IsFinite(ts)) { - return Cast::template Operation(ts); - } - int32_t ts_months = EpochMonths(Interval::Add(ts, Interval::Invert(offset))); - return Interval::Add(Cast::template Operation(WidthConvertibleToMonthsCommon( - bucket_width.months, ts_months, DEFAULT_ORIGIN_MONTHS)), - offset); - } - }; - - struct OffsetTernaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts, TC offset) { - BucketWidthType bucket_width_type = ClassifyBucketWidthErrorThrow(bucket_width); - switch (bucket_width_type) { - case BucketWidthType::CONVERTIBLE_TO_MICROS: - return OffsetWidthConvertibleToMicrosTernaryOperator::Operation(bucket_width, ts, - offset); - case BucketWidthType::CONVERTIBLE_TO_MONTHS: - return OffsetWidthConvertibleToMonthsTernaryOperator::Operation(bucket_width, ts, - offset); - default: - throw NotImplementedException("Bucket 
type not implemented for TIME_BUCKET"); - } - } - }; - - struct OriginWidthConvertibleToMicrosTernaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts, TC origin) { - if (!Value::IsFinite(ts)) { - return Cast::template Operation(ts); - } - int64_t bucket_width_micros = Interval::GetMicro(bucket_width); - int64_t ts_micros = Timestamp::GetEpochMicroSeconds(Cast::template Operation(ts)); - int64_t origin_micros = Timestamp::GetEpochMicroSeconds(Cast::template Operation(origin)); - return Cast::template Operation( - WidthConvertibleToMicrosCommon(bucket_width_micros, ts_micros, origin_micros)); - } - }; - - struct OriginWidthConvertibleToMonthsTernaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts, TC origin) { - if (!Value::IsFinite(ts)) { - return Cast::template Operation(ts); - } - int32_t ts_months = EpochMonths(ts); - int32_t origin_months = EpochMonths(origin); - return Cast::template Operation( - WidthConvertibleToMonthsCommon(bucket_width.months, ts_months, origin_months)); - } - }; - - struct OriginTernaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts, TC origin, ValidityMask &mask, idx_t idx) { - if (!Value::IsFinite(origin)) { - mask.SetInvalid(idx); - return TR(); - } - BucketWidthType bucket_width_type = ClassifyBucketWidthErrorThrow(bucket_width); - switch (bucket_width_type) { - case BucketWidthType::CONVERTIBLE_TO_MICROS: - return OriginWidthConvertibleToMicrosTernaryOperator::Operation(bucket_width, ts, - origin); - case BucketWidthType::CONVERTIBLE_TO_MONTHS: - return OriginWidthConvertibleToMonthsTernaryOperator::Operation(bucket_width, ts, - origin); - default: - throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); - } - } - }; -}; - -template -static void TimeBucketFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; - - if 
(bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { - if (ConstantVector::IsNull(bucket_width_arg)) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - interval_t bucket_width = *ConstantVector::GetData(bucket_width_arg); - TimeBucket::BucketWidthType bucket_width_type = TimeBucket::ClassifyBucketWidth(bucket_width); - switch (bucket_width_type) { - case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: - BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), - TimeBucket::WidthConvertibleToMicrosBinaryOperator::Operation); - break; - case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: - BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), - TimeBucket::WidthConvertibleToMonthsBinaryOperator::Operation); - break; - case TimeBucket::BucketWidthType::UNCLASSIFIED: - BinaryExecutor::Execute(bucket_width_arg, ts_arg, result, args.size(), - TimeBucket::BinaryOperator::Operation); - break; - default: - throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); - } - } - } else { - BinaryExecutor::Execute(bucket_width_arg, ts_arg, result, args.size(), - TimeBucket::BinaryOperator::Operation); - } -} - -template -static void TimeBucketOffsetFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 3); - - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; - auto &offset_arg = args.data[2]; - - if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { - if (ConstantVector::IsNull(bucket_width_arg)) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - interval_t bucket_width = *ConstantVector::GetData(bucket_width_arg); - TimeBucket::BucketWidthType bucket_width_type = TimeBucket::ClassifyBucketWidth(bucket_width); - switch (bucket_width_type) { - case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: - 
TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), - TimeBucket::OffsetWidthConvertibleToMicrosTernaryOperator::Operation); - break; - case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: - TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), - TimeBucket::OffsetWidthConvertibleToMonthsTernaryOperator::Operation); - break; - case TimeBucket::BucketWidthType::UNCLASSIFIED: - TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), - TimeBucket::OffsetTernaryOperator::Operation); - break; - default: - throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); - } - } - } else { - TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), - TimeBucket::OffsetTernaryOperator::Operation); - } -} - -template -static void TimeBucketOriginFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 3); - - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; - auto &origin_arg = args.data[2]; - - if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR && - origin_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { - if (ConstantVector::IsNull(bucket_width_arg) || ConstantVector::IsNull(origin_arg) || - !Value::IsFinite(*ConstantVector::GetData(origin_arg))) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - interval_t bucket_width = *ConstantVector::GetData(bucket_width_arg); - TimeBucket::BucketWidthType bucket_width_type = TimeBucket::ClassifyBucketWidth(bucket_width); - switch (bucket_width_type) { - case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: - TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), - TimeBucket::OriginWidthConvertibleToMicrosTernaryOperator::Operation); - break; - case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: - 
TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), - TimeBucket::OriginWidthConvertibleToMonthsTernaryOperator::Operation); - break; - case TimeBucket::BucketWidthType::UNCLASSIFIED: - TernaryExecutor::ExecuteWithNulls( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), - TimeBucket::OriginTernaryOperator::Operation); - break; - default: - throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); - } - } - } else { - TernaryExecutor::ExecuteWithNulls( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), - TimeBucket::OriginTernaryOperator::Operation); - } -} - -ScalarFunctionSet TimeBucketFun::GetFunctions() { - ScalarFunctionSet time_bucket; - time_bucket.AddFunction( - ScalarFunction({LogicalType::INTERVAL, LogicalType::DATE}, LogicalType::DATE, TimeBucketFunction)); - time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP, - TimeBucketFunction)); - time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::DATE, LogicalType::INTERVAL}, - LogicalType::DATE, TimeBucketOffsetFunction)); - time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, - LogicalType::TIMESTAMP, TimeBucketOffsetFunction)); - time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::DATE, LogicalType::DATE}, - LogicalType::DATE, TimeBucketOriginFunction)); - time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, - LogicalType::TIMESTAMP, TimeBucketOriginFunction)); - return time_bucket; -} - -} // namespace duckdb - - - - -namespace duckdb { - -struct ToYearsOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.days = 0; - result.micros = 0; - if (!TryMultiplyOperator::Operation(input, Interval::MONTHS_PER_YEAR, - result.months)) { - throw OutOfRangeException("Interval value %d 
years out of range", input); - } - return result; - } -}; - -struct ToMonthsOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = input; - result.days = 0; - result.micros = 0; - return result; - } -}; - -struct ToDaysOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = 0; - result.days = input; - result.micros = 0; - return result; - } -}; - -struct ToHoursOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = 0; - result.days = 0; - if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_HOUR, - result.micros)) { - throw OutOfRangeException("Interval value %d hours out of range", input); - } - return result; - } -}; - -struct ToMinutesOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = 0; - result.days = 0; - if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_MINUTE, - result.micros)) { - throw OutOfRangeException("Interval value %d minutes out of range", input); - } - return result; - } -}; - -struct ToSecondsOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = 0; - result.days = 0; - if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_SEC, - result.micros)) { - throw OutOfRangeException("Interval value %d seconds out of range", input); - } - return result; - } -}; - -struct ToMilliSecondsOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = 0; - result.days = 0; - if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_MSEC, - result.micros)) { - throw OutOfRangeException("Interval value %d milliseconds out of range", input); - } - return result; - } -}; - -struct ToMicroSecondsOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = 0; - result.days = 0; - result.micros = input; - return 
result; - } -}; - -ScalarFunction ToYearsFun::GetFunction() { - return ScalarFunction({LogicalType::INTEGER}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); -} - -ScalarFunction ToMonthsFun::GetFunction() { - return ScalarFunction({LogicalType::INTEGER}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); -} - -ScalarFunction ToDaysFun::GetFunction() { - return ScalarFunction({LogicalType::INTEGER}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); -} - -ScalarFunction ToHoursFun::GetFunction() { - return ScalarFunction({LogicalType::BIGINT}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); -} - -ScalarFunction ToMinutesFun::GetFunction() { - return ScalarFunction({LogicalType::BIGINT}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); -} - -ScalarFunction ToSecondsFun::GetFunction() { - return ScalarFunction({LogicalType::BIGINT}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); -} - -ScalarFunction ToMillisecondsFun::GetFunction() { - return ScalarFunction({LogicalType::BIGINT}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); -} - -ScalarFunction ToMicrosecondsFun::GetFunction() { - return ScalarFunction({LogicalType::BIGINT}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -static void VectorTypeFunction(DataChunk &input, ExpressionState &state, Vector &result) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - auto data = ConstantVector::GetData(result); - data[0] = StringVector::AddString(result, EnumUtil::ToString(input.data[0].GetVectorType())); -} - -ScalarFunction VectorTypeFun::GetFunction() { - return ScalarFunction("vector_type", // name of the function - {LogicalType::ANY}, // argument list - LogicalType::VARCHAR, // return type - VectorTypeFunction); -} - -} // namespace duckdb - - -namespace duckdb { - -static void EnumFirstFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto types = 
input.GetTypes(); - D_ASSERT(types.size() == 1); - auto &enum_vector = EnumType::GetValuesInsertOrder(types[0]); - auto val = Value(enum_vector.GetValue(0)); - result.Reference(val); -} - -static void EnumLastFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto types = input.GetTypes(); - D_ASSERT(types.size() == 1); - auto enum_size = EnumType::GetSize(types[0]); - auto &enum_vector = EnumType::GetValuesInsertOrder(types[0]); - auto val = Value(enum_vector.GetValue(enum_size - 1)); - result.Reference(val); -} - -static void EnumRangeFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto types = input.GetTypes(); - D_ASSERT(types.size() == 1); - auto enum_size = EnumType::GetSize(types[0]); - auto &enum_vector = EnumType::GetValuesInsertOrder(types[0]); - vector enum_values; - for (idx_t i = 0; i < enum_size; i++) { - enum_values.emplace_back(enum_vector.GetValue(i)); - } - auto val = Value::LIST(enum_values); - result.Reference(val); -} - -static void EnumRangeBoundaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto types = input.GetTypes(); - D_ASSERT(types.size() == 2); - idx_t start, end; - auto first_param = input.GetValue(0, 0); - auto second_param = input.GetValue(1, 0); - - auto &enum_vector = - first_param.IsNull() ? 
EnumType::GetValuesInsertOrder(types[1]) : EnumType::GetValuesInsertOrder(types[0]); - - if (first_param.IsNull()) { - start = 0; - } else { - start = first_param.GetValue(); - } - if (second_param.IsNull()) { - end = EnumType::GetSize(types[0]); - } else { - end = second_param.GetValue() + 1; - } - vector enum_values; - for (idx_t i = start; i < end; i++) { - enum_values.emplace_back(enum_vector.GetValue(i)); - } - Value val; - if (enum_values.empty()) { - val = Value::EMPTYLIST(LogicalType::VARCHAR); - } else { - val = Value::LIST(enum_values); - } - result.Reference(val); -} - -static void EnumCodeFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.GetTypes().size() == 1); - result.Reinterpret(input.data[0]); -} - -static void CheckEnumParameter(const Expression &expr) { - if (expr.HasParameter()) { - throw ParameterNotResolvedException(); - } -} - -unique_ptr BindEnumFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - CheckEnumParameter(*arguments[0]); - if (arguments[0]->return_type.id() != LogicalTypeId::ENUM) { - throw BinderException("This function needs an ENUM as an argument"); - } - return nullptr; -} - -unique_ptr BindEnumCodeFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - CheckEnumParameter(*arguments[0]); - if (arguments[0]->return_type.id() != LogicalTypeId::ENUM) { - throw BinderException("This function needs an ENUM as an argument"); - } - - auto phy_type = EnumType::GetPhysicalType(arguments[0]->return_type); - switch (phy_type) { - case PhysicalType::UINT8: - bound_function.return_type = LogicalType(LogicalTypeId::UTINYINT); - break; - case PhysicalType::UINT16: - bound_function.return_type = LogicalType(LogicalTypeId::USMALLINT); - break; - case PhysicalType::UINT32: - bound_function.return_type = LogicalType(LogicalTypeId::UINTEGER); - break; - case PhysicalType::UINT64: - bound_function.return_type = 
LogicalType(LogicalTypeId::UBIGINT); - break; - default: - throw InternalException("Unsupported Enum Internal Type"); - } - - return nullptr; -} - -unique_ptr BindEnumRangeBoundaryFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - CheckEnumParameter(*arguments[0]); - CheckEnumParameter(*arguments[1]); - if (arguments[0]->return_type.id() != LogicalTypeId::ENUM && arguments[0]->return_type != LogicalType::SQLNULL) { - throw BinderException("This function needs an ENUM as an argument"); - } - if (arguments[1]->return_type.id() != LogicalTypeId::ENUM && arguments[1]->return_type != LogicalType::SQLNULL) { - throw BinderException("This function needs an ENUM as an argument"); - } - if (arguments[0]->return_type == LogicalType::SQLNULL && arguments[1]->return_type == LogicalType::SQLNULL) { - throw BinderException("This function needs an ENUM as an argument"); - } - if (arguments[0]->return_type.id() == LogicalTypeId::ENUM && - arguments[1]->return_type.id() == LogicalTypeId::ENUM && - arguments[0]->return_type != arguments[1]->return_type) { - throw BinderException("The parameters need to link to ONLY one enum OR be NULL "); - } - return nullptr; -} - -ScalarFunction EnumFirstFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, EnumFirstFunction, BindEnumFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -ScalarFunction EnumLastFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, EnumLastFunction, BindEnumFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -ScalarFunction EnumCodeFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::ANY, EnumCodeFunction, BindEnumCodeFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -ScalarFunction EnumRangeFun::GetFunction() { - auto fun = 
ScalarFunction({LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), EnumRangeFunction, - BindEnumFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -ScalarFunction EnumRangeBoundaryFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), - EnumRangeBoundaryFunction, BindEnumRangeBoundaryFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb - - - -namespace duckdb { - -static void AliasFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - Value v(state.expr.alias.empty() ? func_expr.children[0]->GetName() : state.expr.alias); - result.Reference(v); -} - -ScalarFunction AliasFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, AliasFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -struct CurrentSettingBindData : public FunctionData { - explicit CurrentSettingBindData(Value value_p) : value(std::move(value_p)) { - } - - Value value; - -public: - unique_ptr Copy() const override { - return make_uniq(value); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return Value::NotDistinctFrom(value, other.value); - } -}; - -static void CurrentSettingFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - result.Reference(info.value); -} - -unique_ptr CurrentSettingBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - auto &key_child = arguments[0]; - if (key_child->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - if (key_child->return_type.id() != LogicalTypeId::VARCHAR || - 
key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) { - throw ParserException("Key name for current_setting needs to be a constant string"); - } - Value key_val = ExpressionExecutor::EvaluateScalar(context, *key_child); - D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR); - auto &key_str = StringValue::Get(key_val); - if (key_val.IsNull() || key_str.empty()) { - throw ParserException("Key name for current_setting needs to be neither NULL nor empty"); - } - - auto key = StringUtil::Lower(key_str); - Value val; - if (!context.TryGetCurrentSetting(key, val)) { - Catalog::AutoloadExtensionByConfigName(context, key); - // If autoloader didn't throw, the config is now available - context.TryGetCurrentSetting(key, val); - } - - bound_function.return_type = val.type(); - return make_uniq(val); -} - -ScalarFunction CurrentSettingFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::VARCHAR}, LogicalType::ANY, CurrentSettingFunction, CurrentSettingBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb - -#include - -namespace duckdb { - -struct ErrorOperator { - template - static inline TR Operation(const TA &input) { - throw Exception(input.GetString()); - } -}; - -ScalarFunction ErrorFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::VARCHAR}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction); - // Set the function with side effects to avoid the optimization. 
- fun.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; - return fun; -} - -} // namespace duckdb - - -namespace duckdb { - -static void HashFunction(DataChunk &args, ExpressionState &state, Vector &result) { - args.Hash(result); - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -ScalarFunction HashFun::GetFunction() { - auto hash_fun = ScalarFunction({LogicalType::ANY}, LogicalType::HASH, HashFunction); - hash_fun.varargs = LogicalType::ANY; - hash_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return hash_fun; -} - -} // namespace duckdb - - - -namespace duckdb { - -template -struct LeastOperator { - template - static T Operation(T left, T right) { - return OP::Operation(left, right) ? left : right; - } -}; - -template -static void LeastGreatestFunction(DataChunk &args, ExpressionState &state, Vector &result) { - if (args.ColumnCount() == 1) { - // single input: nop - result.Reference(args.data[0]); - return; - } - auto result_type = VectorType::CONSTANT_VECTOR; - for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { - if (args.data[col_idx].GetVectorType() != VectorType::CONSTANT_VECTOR) { - // non-constant input: result is not a constant vector - result_type = VectorType::FLAT_VECTOR; - } - if (IS_STRING) { - // for string vectors we add a reference to the heap of the children - StringVector::AddHeapReference(result, args.data[col_idx]); - } - } - - auto result_data = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - // copy over the first column - bool result_has_value[STANDARD_VECTOR_SIZE]; - { - UnifiedVectorFormat vdata; - args.data[0].ToUnifiedFormat(args.size(), vdata); - auto input_data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < args.size(); i++) { - auto vindex = vdata.sel->get_index(i); - if (vdata.validity.RowIsValid(vindex)) { - result_data[i] = input_data[vindex]; - result_has_value[i] = true; - } else { - result_has_value[i] 
= false; - } - } - } - // now handle the remainder of the columns - for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { - if (args.data[col_idx].GetVectorType() == VectorType::CONSTANT_VECTOR && - ConstantVector::IsNull(args.data[col_idx])) { - // ignore null vector - continue; - } - - UnifiedVectorFormat vdata; - args.data[col_idx].ToUnifiedFormat(args.size(), vdata); - - auto input_data = UnifiedVectorFormat::GetData(vdata); - if (!vdata.validity.AllValid()) { - // potential new null entries: have to check the null mask - for (idx_t i = 0; i < args.size(); i++) { - auto vindex = vdata.sel->get_index(i); - if (vdata.validity.RowIsValid(vindex)) { - // not a null entry: perform the operation and add to new set - auto ivalue = input_data[vindex]; - if (!result_has_value[i] || OP::template Operation(ivalue, result_data[i])) { - result_has_value[i] = true; - result_data[i] = ivalue; - } - } - } - } else { - // no new null entries: only need to perform the operation - for (idx_t i = 0; i < args.size(); i++) { - auto vindex = vdata.sel->get_index(i); - - auto ivalue = input_data[vindex]; - if (!result_has_value[i] || OP::template Operation(ivalue, result_data[i])) { - result_has_value[i] = true; - result_data[i] = ivalue; - } - } - } - } - for (idx_t i = 0; i < args.size(); i++) { - if (!result_has_value[i]) { - result_mask.SetInvalid(i); - } - } - result.SetVectorType(result_type); -} - -template -ScalarFunction GetLeastGreatestFunction(const LogicalType &type) { - return ScalarFunction({type}, type, LeastGreatestFunction, nullptr, nullptr, nullptr, nullptr, type, - FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING); -} - -template -static ScalarFunctionSet GetLeastGreatestFunctions() { - ScalarFunctionSet fun_set; - fun_set.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::BIGINT, LeastGreatestFunction, - nullptr, nullptr, nullptr, nullptr, LogicalType::BIGINT, - FunctionSideEffects::NO_SIDE_EFFECTS, 
FunctionNullHandling::SPECIAL_HANDLING)); - fun_set.AddFunction(ScalarFunction( - {LogicalType::HUGEINT}, LogicalType::HUGEINT, LeastGreatestFunction, nullptr, nullptr, nullptr, - nullptr, LogicalType::HUGEINT, FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - fun_set.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, LeastGreatestFunction, - nullptr, nullptr, nullptr, nullptr, LogicalType::DOUBLE, - FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - fun_set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, - LeastGreatestFunction, nullptr, nullptr, nullptr, nullptr, - LogicalType::VARCHAR, FunctionSideEffects::NO_SIDE_EFFECTS, - FunctionNullHandling::SPECIAL_HANDLING)); - - fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::TIMESTAMP)); - fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::TIME)); - fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::DATE)); - - fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::TIMESTAMP_TZ)); - fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::TIME_TZ)); - return fun_set; -} - -ScalarFunctionSet LeastFun::GetFunctions() { - return GetLeastGreatestFunctions(); -} - -ScalarFunctionSet GreatestFun::GetFunctions() { - return GetLeastGreatestFunctions(); -} - -} // namespace duckdb - - - -namespace duckdb { - -struct StatsBindData : public FunctionData { - explicit StatsBindData(string stats_p = string()) : stats(std::move(stats_p)) { - } - - string stats; - -public: - unique_ptr Copy() const override { - return make_uniq(stats); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return stats == other.stats; - } -}; - -static void StatsFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - if (info.stats.empty()) { - info.stats = "No statistics"; - } - 
Value v(info.stats); - result.Reference(v); -} - -unique_ptr StatsBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return make_uniq(); -} - -static unique_ptr StatsPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &bind_data = input.bind_data; - auto &info = bind_data->Cast(); - info.stats = child_stats[0].ToString(); - return nullptr; -} - -ScalarFunction StatsFun::GetFunction() { - ScalarFunction stats({LogicalType::ANY}, LogicalType::VARCHAR, StatsFunction, StatsBind, nullptr, - StatsPropagateStats); - stats.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - stats.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; - return stats; -} - -} // namespace duckdb - -#endif diff --git a/lib/duckdb-4.cpp b/lib/duckdb-4.cpp deleted file mode 100644 index 07d6af8d..00000000 --- a/lib/duckdb-4.cpp +++ /dev/null @@ -1,21310 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - - - - - - - - - -namespace duckdb { - -// current_query -static void CurrentQueryFunction(DataChunk &input, ExpressionState &state, Vector &result) { - Value val(state.GetContext().GetCurrentQuery()); - result.Reference(val); -} - -// current_schema -static void CurrentSchemaFunction(DataChunk &input, ExpressionState &state, Vector &result) { - Value val(ClientData::Get(state.GetContext()).catalog_search_path->GetDefault().schema); - result.Reference(val); -} - -// current_database -static void CurrentDatabaseFunction(DataChunk &input, ExpressionState &state, Vector &result) { - Value val(DatabaseManager::GetDefaultDatabase(state.GetContext())); - result.Reference(val); -} - -// current_schemas -static void CurrentSchemasFunction(DataChunk &input, ExpressionState &state, Vector &result) { 
- if (!input.AllConstant()) { - throw NotImplementedException("current_schemas requires a constant input"); - } - if (ConstantVector::IsNull(input.data[0])) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - auto implicit_schemas = *ConstantVector::GetData(input.data[0]); - vector schema_list; - auto &catalog_search_path = ClientData::Get(state.GetContext()).catalog_search_path; - auto &search_path = implicit_schemas ? catalog_search_path->Get() : catalog_search_path->GetSetPaths(); - std::transform(search_path.begin(), search_path.end(), std::back_inserter(schema_list), - [](const CatalogSearchEntry &s) -> Value { return Value(s.schema); }); - - auto val = Value::LIST(LogicalType::VARCHAR, schema_list); - result.Reference(val); -} - -// in_search_path -static void InSearchPathFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto &context = state.GetContext(); - auto &search_path = ClientData::Get(context).catalog_search_path; - BinaryExecutor::Execute( - input.data[0], input.data[1], result, input.size(), [&](string_t db_name, string_t schema_name) { - return search_path->SchemaInSearchPath(context, db_name.GetString(), schema_name.GetString()); - }); -} - -// txid_current -static void TransactionIdCurrent(DataChunk &input, ExpressionState &state, Vector &result) { - auto &context = state.GetContext(); - auto &catalog = Catalog::GetCatalog(context, DatabaseManager::GetDefaultDatabase(context)); - auto &transaction = DuckTransaction::Get(context, catalog); - auto val = Value::BIGINT(transaction.start_time); - result.Reference(val); -} - -// version -static void VersionFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto val = Value(DuckDB::LibraryVersion()); - result.Reference(val); -} - -ScalarFunction CurrentQueryFun::GetFunction() { - ScalarFunction current_query({}, LogicalType::VARCHAR, CurrentQueryFunction); - current_query.side_effects = 
FunctionSideEffects::HAS_SIDE_EFFECTS; - return current_query; -} - -ScalarFunction CurrentSchemaFun::GetFunction() { - return ScalarFunction({}, LogicalType::VARCHAR, CurrentSchemaFunction); -} - -ScalarFunction CurrentDatabaseFun::GetFunction() { - return ScalarFunction({}, LogicalType::VARCHAR, CurrentDatabaseFunction); -} - -ScalarFunction CurrentSchemasFun::GetFunction() { - auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); - return ScalarFunction({LogicalType::BOOLEAN}, varchar_list_type, CurrentSchemasFunction); -} - -ScalarFunction InSearchPathFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, InSearchPathFunction); -} - -ScalarFunction CurrentTransactionIdFun::GetFunction() { - return ScalarFunction({}, LogicalType::BIGINT, TransactionIdCurrent); -} - -ScalarFunction VersionFun::GetFunction() { - return ScalarFunction({}, LogicalType::VARCHAR, VersionFunction); -} - -} // namespace duckdb - - -namespace duckdb { - -static void TypeOfFunction(DataChunk &args, ExpressionState &state, Vector &result) { - Value v(args.data[0].GetType().ToString()); - result.Reference(v); -} - -ScalarFunction TypeOfFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, TypeOfFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -struct ListSliceBindData : public FunctionData { - ListSliceBindData(const LogicalType &return_type_p, bool begin_is_empty_p, bool end_is_empty_p) - : return_type(return_type_p), begin_is_empty(begin_is_empty_p), end_is_empty(end_is_empty_p) { - } - ~ListSliceBindData() override; - - LogicalType return_type; - - bool begin_is_empty; - bool end_is_empty; - -public: - bool Equals(const FunctionData &other_p) const override; - unique_ptr Copy() const override; -}; - -ListSliceBindData::~ListSliceBindData() { -} - -bool 
ListSliceBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return return_type == other.return_type && begin_is_empty == other.begin_is_empty && - end_is_empty == other.end_is_empty; -} - -unique_ptr ListSliceBindData::Copy() const { - return make_uniq(return_type, begin_is_empty, end_is_empty); -} - -template -static int CalculateSliceLength(idx_t begin, idx_t end, INDEX_TYPE step, bool svalid) { - if (step < 0) { - step = abs(step); - } - if (step == 0 && svalid) { - throw InvalidInputException("Slice step cannot be zero"); - } - if (step == 1) { - return end - begin; - } else if (static_cast(step) >= (end - begin)) { - return 1; - } - if ((end - begin) % step != 0) { - return (end - begin) / step + 1; - } - return (end - begin) / step; -} - -template -INDEX_TYPE ValueLength(const INPUT_TYPE &value) { - return 0; -} - -template <> -int64_t ValueLength(const list_entry_t &value) { - return value.length; -} - -template <> -int64_t ValueLength(const string_t &value) { - return LengthFun::Length(value); -} - -template -static void ClampIndex(INDEX_TYPE &index, const INPUT_TYPE &value, const INDEX_TYPE length, bool is_min) { - if (index < 0) { - index = (!is_min) ? index + 1 : index; - index = length + index; - return; - } else if (index > length) { - index = length; - } - return; -} - -template -static bool ClampSlice(const INPUT_TYPE &value, INDEX_TYPE &begin, INDEX_TYPE &end) { - // Clamp offsets - begin = (begin != 0 && begin != (INDEX_TYPE)NumericLimits::Minimum()) ? 
begin - 1 : begin; - - bool is_min = false; - if (begin == (INDEX_TYPE)NumericLimits::Minimum()) { - begin++; - is_min = true; - } - - const auto length = ValueLength(value); - if (begin < 0 && -begin > length && end < 0 && -end > length) { - begin = 0; - end = 0; - return true; - } - if (begin < 0 && -begin > length) { - begin = 0; - } - ClampIndex(begin, value, length, is_min); - ClampIndex(end, value, length, false); - end = MaxValue(begin, end); - - return true; -} - -template -INPUT_TYPE SliceValue(Vector &result, INPUT_TYPE input, INDEX_TYPE begin, INDEX_TYPE end) { - return input; -} - -template <> -list_entry_t SliceValue(Vector &result, list_entry_t input, int64_t begin, int64_t end) { - input.offset += begin; - input.length = end - begin; - return input; -} - -template <> -string_t SliceValue(Vector &result, string_t input, int64_t begin, int64_t end) { - // one-based - zero has strange semantics - return SubstringFun::SubstringUnicode(result, input, begin + 1, end - begin); -} - -template -INPUT_TYPE SliceValueWithSteps(Vector &result, SelectionVector &sel, INPUT_TYPE input, INDEX_TYPE begin, INDEX_TYPE end, - INDEX_TYPE step, idx_t &sel_idx) { - return input; -} - -template <> -list_entry_t SliceValueWithSteps(Vector &result, SelectionVector &sel, list_entry_t input, int64_t begin, int64_t end, - int64_t step, idx_t &sel_idx) { - if (end - begin == 0) { - input.length = 0; - input.offset = sel_idx; - return input; - } - input.length = CalculateSliceLength(begin, end, step, true); - idx_t child_idx = input.offset + begin; - if (step < 0) { - child_idx = input.offset + end - 1; - } - input.offset = sel_idx; - for (idx_t i = 0; i < input.length; i++) { - sel.set_index(sel_idx, child_idx); - child_idx += step; - sel_idx++; - } - return input; -} - -template -static void ExecuteConstantSlice(Vector &result, Vector &str_vector, Vector &begin_vector, Vector &end_vector, - optional_ptr step_vector, const idx_t count, SelectionVector &sel, - idx_t &sel_idx, 
optional_ptr result_child_vector, bool begin_is_empty, - bool end_is_empty) { - auto result_data = ConstantVector::GetData(result); - auto str_data = ConstantVector::GetData(str_vector); - auto begin_data = ConstantVector::GetData(begin_vector); - auto end_data = ConstantVector::GetData(end_vector); - auto step_data = step_vector ? ConstantVector::GetData(*step_vector) : nullptr; - - auto str = str_data[0]; - auto begin = begin_is_empty ? 0 : begin_data[0]; - auto end = end_is_empty ? ValueLength(str) : end_data[0]; - auto step = step_data ? step_data[0] : 1; - - if (step < 0) { - swap(begin, end); - begin = end_is_empty ? 0 : begin; - end = begin_is_empty ? ValueLength(str) : end; - } - - auto str_valid = !ConstantVector::IsNull(str_vector); - auto begin_valid = !ConstantVector::IsNull(begin_vector); - auto end_valid = !ConstantVector::IsNull(end_vector); - auto step_valid = step_vector && !ConstantVector::IsNull(*step_vector); - - // Clamp offsets - bool clamp_result = false; - if (str_valid && begin_valid && end_valid && (step_valid || step == 1)) { - clamp_result = ClampSlice(str, begin, end); - } - - auto sel_length = 0; - bool sel_valid = false; - if (step_vector && step_valid && str_valid && begin_valid && end_valid && step != 1 && end - begin > 0) { - sel_length = CalculateSliceLength(begin, end, step, step_valid); - sel.Initialize(sel_length); - sel_valid = true; - } - - // Try to slice - if (!str_valid || !begin_valid || !end_valid || (step_vector && !step_valid) || !clamp_result) { - ConstantVector::SetNull(result, true); - } else if (step == 1) { - result_data[0] = SliceValue(result, str, begin, end); - } else { - result_data[0] = SliceValueWithSteps(result, sel, str, begin, end, step, sel_idx); - } - - if (sel_valid) { - result_child_vector->Slice(sel, sel_length); - ListVector::SetListSize(result, sel_length); - } -} - -template -static void ExecuteFlatSlice(Vector &result, Vector &list_vector, Vector &begin_vector, Vector &end_vector, - optional_ptr 
step_vector, const idx_t count, SelectionVector &sel, idx_t &sel_idx, - optional_ptr result_child_vector, bool begin_is_empty, bool end_is_empty) { - UnifiedVectorFormat list_data, begin_data, end_data, step_data; - idx_t sel_length = 0; - - list_vector.ToUnifiedFormat(count, list_data); - begin_vector.ToUnifiedFormat(count, begin_data); - end_vector.ToUnifiedFormat(count, end_data); - if (step_vector) { - step_vector->ToUnifiedFormat(count, step_data); - sel.Initialize(ListVector::GetListSize(list_vector)); - } - - auto result_data = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - - for (idx_t i = 0; i < count; ++i) { - auto list_idx = list_data.sel->get_index(i); - auto begin_idx = begin_data.sel->get_index(i); - auto end_idx = end_data.sel->get_index(i); - auto step_idx = step_vector ? step_data.sel->get_index(i) : 0; - - auto list_valid = list_data.validity.RowIsValid(list_idx); - auto begin_valid = begin_data.validity.RowIsValid(begin_idx); - auto end_valid = end_data.validity.RowIsValid(end_idx); - auto step_valid = step_vector && step_data.validity.RowIsValid(step_idx); - - if (!list_valid || !begin_valid || !end_valid || (step_vector && !step_valid)) { - result_mask.SetInvalid(i); - continue; - } - - auto sliced = reinterpret_cast(list_data.data)[list_idx]; - auto begin = begin_is_empty ? 0 : reinterpret_cast(begin_data.data)[begin_idx]; - auto end = end_is_empty ? ValueLength(sliced) - : reinterpret_cast(end_data.data)[end_idx]; - auto step = step_vector ? reinterpret_cast(step_data.data)[step_idx] : 1; - - if (step < 0) { - swap(begin, end); - begin = end_is_empty ? 0 : begin; - end = begin_is_empty ? 
ValueLength(sliced) : end; - } - - bool clamp_result = false; - if (step_valid || step == 1) { - clamp_result = ClampSlice(sliced, begin, end); - } - - auto length = 0; - if (end - begin > 0) { - length = CalculateSliceLength(begin, end, step, step_valid); - } - sel_length += length; - - if (!clamp_result) { - result_mask.SetInvalid(i); - } else if (!step_vector) { - result_data[i] = SliceValue(result, sliced, begin, end); - } else { - result_data[i] = - SliceValueWithSteps(result, sel, sliced, begin, end, step, sel_idx); - } - } - if (step_vector) { - SelectionVector new_sel(sel_length); - for (idx_t i = 0; i < sel_length; ++i) { - new_sel.set_index(i, sel.get_index(i)); - } - result_child_vector->Slice(new_sel, sel_length); - ListVector::SetListSize(result, sel_length); - } -} - -template -static void ExecuteSlice(Vector &result, Vector &list_or_str_vector, Vector &begin_vector, Vector &end_vector, - optional_ptr step_vector, const idx_t count, bool begin_is_empty, bool end_is_empty) { - optional_ptr result_child_vector; - if (step_vector) { - result_child_vector = &ListVector::GetEntry(result); - } - - SelectionVector sel; - idx_t sel_idx = 0; - - if (result.GetVectorType() == VectorType::CONSTANT_VECTOR) { - ExecuteConstantSlice(result, list_or_str_vector, begin_vector, end_vector, step_vector, - count, sel, sel_idx, result_child_vector, begin_is_empty, - end_is_empty); - } else { - ExecuteFlatSlice(result, list_or_str_vector, begin_vector, end_vector, step_vector, - count, sel, sel_idx, result_child_vector, begin_is_empty, - end_is_empty); - } - result.Verify(count); -} - -static void ArraySliceFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 3 || args.ColumnCount() == 4); - D_ASSERT(args.data.size() == 3 || args.data.size() == 4); - auto count = args.size(); - - Vector &list_or_str_vector = args.data[0]; - if (list_or_str_vector.GetType().id() == LogicalTypeId::SQLNULL) { - auto &result_validity = 
FlatVector::Validity(result); - result_validity.SetInvalid(0); - return; - } - - Vector &begin_vector = args.data[1]; - Vector &end_vector = args.data[2]; - - optional_ptr step_vector; - if (args.ColumnCount() == 4) { - step_vector = &args.data[3]; - } - - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto begin_is_empty = info.begin_is_empty; - auto end_is_empty = info.end_is_empty; - - result.SetVectorType(args.AllConstant() ? VectorType::CONSTANT_VECTOR : VectorType::FLAT_VECTOR); - switch (result.GetType().id()) { - case LogicalTypeId::LIST: { - // Share the value dictionary as we are just going to slice it - if (list_or_str_vector.GetVectorType() != VectorType::FLAT_VECTOR && - list_or_str_vector.GetVectorType() != VectorType::CONSTANT_VECTOR) { - list_or_str_vector.Flatten(count); - } - ListVector::ReferenceEntry(result, list_or_str_vector); - ExecuteSlice(result, list_or_str_vector, begin_vector, end_vector, step_vector, count, - begin_is_empty, end_is_empty); - break; - } - case LogicalTypeId::VARCHAR: { - ExecuteSlice(result, list_or_str_vector, begin_vector, end_vector, step_vector, count, - begin_is_empty, end_is_empty); - break; - } - default: - throw NotImplementedException("Specifier type not implemented"); - } -} - -static bool CheckIfParamIsEmpty(duckdb::unique_ptr ¶m) { - bool is_empty = false; - if (param->return_type.id() == LogicalTypeId::LIST) { - auto empty_list = make_uniq(Value::LIST(LogicalType::INTEGER, vector())); - is_empty = param->Equals(*empty_list); - if (!is_empty) { - // if the param is not empty, the user has entered a list instead of a BIGINT - throw BinderException("The upper and lower bounds of the slice must be a BIGINT"); - } - } - return is_empty; -} - -static unique_ptr ArraySliceBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(arguments.size() == 3 || arguments.size() == 4); - D_ASSERT(bound_function.arguments.size() == 3 || 
bound_function.arguments.size() == 4); - - switch (arguments[0]->return_type.id()) { - case LogicalTypeId::LIST: - // The result is the same type - bound_function.return_type = arguments[0]->return_type; - break; - case LogicalTypeId::VARCHAR: - // string slice returns a string - if (bound_function.arguments.size() == 4) { - throw NotImplementedException( - "Slice with steps has not been implemented for string types, you can consider rewriting your query as " - "follows:\n SELECT array_to_string((str_split(string, '')[begin:end:step], '');"); - } - bound_function.return_type = arguments[0]->return_type; - for (idx_t i = 1; i < 3; i++) { - if (arguments[i]->return_type.id() != LogicalTypeId::LIST) { - bound_function.arguments[i] = LogicalType::BIGINT; - } - } - break; - case LogicalTypeId::SQLNULL: - case LogicalTypeId::UNKNOWN: - bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; - break; - default: - throw BinderException("ARRAY_SLICE can only operate on LISTs and VARCHARs"); - } - - bool begin_is_empty = CheckIfParamIsEmpty(arguments[1]); - if (!begin_is_empty) { - bound_function.arguments[1] = LogicalType::BIGINT; - } - bool end_is_empty = CheckIfParamIsEmpty(arguments[2]); - if (!end_is_empty) { - bound_function.arguments[2] = LogicalType::BIGINT; - } - - return make_uniq(bound_function.return_type, begin_is_empty, end_is_empty); -} - -ScalarFunctionSet ListSliceFun::GetFunctions() { - // the arguments and return types are actually set in the binder function - ScalarFunction fun({LogicalType::ANY, LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, ArraySliceFunction, - ArraySliceBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - - ScalarFunctionSet set; - set.AddFunction(fun); - fun.arguments.push_back(LogicalType::BIGINT); - set.AddFunction(fun); - return set; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -void ListFlattenFunction(DataChunk &args, ExpressionState 
&state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - - Vector &input = args.data[0]; - if (input.GetType().id() == LogicalTypeId::SQLNULL) { - result.Reference(input); - return; - } - - idx_t count = args.size(); - - UnifiedVectorFormat list_data; - input.ToUnifiedFormat(count, list_data); - auto list_entries = UnifiedVectorFormat::GetData(list_data); - auto &child_vector = ListVector::GetEntry(input); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_entries = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - - if (child_vector.GetType().id() == LogicalTypeId::SQLNULL) { - for (idx_t i = 0; i < count; i++) { - auto list_index = list_data.sel->get_index(i); - if (!list_data.validity.RowIsValid(list_index)) { - result_validity.SetInvalid(i); - continue; - } - result_entries[i].offset = 0; - result_entries[i].length = 0; - } - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - return; - } - - auto child_size = ListVector::GetListSize(input); - UnifiedVectorFormat child_data; - child_vector.ToUnifiedFormat(child_size, child_data); - auto child_entries = UnifiedVectorFormat::GetData(child_data); - auto &data_vector = ListVector::GetEntry(child_vector); - - idx_t offset = 0; - for (idx_t i = 0; i < count; i++) { - auto list_index = list_data.sel->get_index(i); - if (!list_data.validity.RowIsValid(list_index)) { - result_validity.SetInvalid(i); - continue; - } - auto list_entry = list_entries[list_index]; - - idx_t source_offset = 0; - // Find first valid child list entry to get offset - for (idx_t j = 0; j < list_entry.length; j++) { - auto child_list_index = child_data.sel->get_index(list_entry.offset + j); - if (child_data.validity.RowIsValid(child_list_index)) { - source_offset = child_entries[child_list_index].offset; - break; - } - } - - idx_t length = 0; - // Find last valid child list entry to get length - for (idx_t j = list_entry.length - 1; j != (idx_t)-1; j--) { 
- auto child_list_index = child_data.sel->get_index(list_entry.offset + j); - if (child_data.validity.RowIsValid(child_list_index)) { - auto child_entry = child_entries[child_list_index]; - length = child_entry.offset + child_entry.length - source_offset; - break; - } - } - ListVector::Append(result, data_vector, source_offset + length, source_offset); - - result_entries[i].offset = offset; - result_entries[i].length = length; - offset += length; - } - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr ListFlattenBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 1); - - auto &input_type = arguments[0]->return_type; - bound_function.arguments[0] = input_type; - if (input_type.id() == LogicalTypeId::UNKNOWN) { - bound_function.arguments[0] = LogicalType(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - D_ASSERT(input_type.id() == LogicalTypeId::LIST); - - auto child_type = ListType::GetChildType(input_type); - if (child_type.id() == LogicalType::SQLNULL) { - bound_function.return_type = input_type; - return make_uniq(bound_function.return_type); - } - if (child_type.id() == LogicalTypeId::UNKNOWN) { - bound_function.arguments[0] = LogicalType(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - D_ASSERT(child_type.id() == LogicalTypeId::LIST); - - bound_function.return_type = child_type; - return make_uniq(bound_function.return_type); -} - -static unique_ptr ListFlattenStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &list_child_stats = ListStats::GetChildStats(child_stats[0]); - auto child_copy = list_child_stats.Copy(); - child_copy.Set(StatsInfo::CAN_HAVE_NULL_VALUES); - return child_copy.ToUnique(); -} - -ScalarFunction 
ListFlattenFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::LIST(LogicalType::ANY))}, LogicalType::LIST(LogicalType::ANY), - ListFlattenFunction, ListFlattenBind, nullptr, ListFlattenStats); -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -// FIXME: use a local state for each thread to increase performance? -// FIXME: benchmark the use of simple_update against using update (if applicable) - -static unique_ptr ListAggregatesBindFailure(ScalarFunction &bound_function) { - bound_function.arguments[0] = LogicalType::SQLNULL; - bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(LogicalType::SQLNULL); -} - -struct ListAggregatesBindData : public FunctionData { - ListAggregatesBindData(const LogicalType &stype_p, unique_ptr aggr_expr_p); - ~ListAggregatesBindData() override; - - LogicalType stype; - unique_ptr aggr_expr; - - unique_ptr Copy() const override { - return make_uniq(stype, aggr_expr->Copy()); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return stype == other.stype && aggr_expr->Equals(*other.aggr_expr); - } - void Serialize(Serializer &serializer) const { - serializer.WriteProperty(1, "stype", stype); - serializer.WriteProperty(2, "aggr_expr", aggr_expr); - } - static unique_ptr Deserialize(Deserializer &deserializer) { - auto stype = deserializer.ReadProperty(1, "stype"); - auto aggr_expr = deserializer.ReadProperty>(2, "aggr_expr"); - auto result = make_uniq(std::move(stype), std::move(aggr_expr)); - return result; - } - - static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, - const ScalarFunction &function) { - auto bind_data = dynamic_cast(bind_data_p.get()); - serializer.WritePropertyWithDefault(100, "bind_data", bind_data, (const ListAggregatesBindData *)nullptr); - } - - static unique_ptr Deserialize(Deserializer &deserializer, ScalarFunction &bound_function) { - auto result = 
deserializer.ReadPropertyWithDefault>( - 100, "bind_data", unique_ptr(nullptr)); - if (!result) { - return ListAggregatesBindFailure(bound_function); - } - return std::move(result); - } -}; - -ListAggregatesBindData::ListAggregatesBindData(const LogicalType &stype_p, unique_ptr aggr_expr_p) - : stype(stype_p), aggr_expr(std::move(aggr_expr_p)) { -} - -ListAggregatesBindData::~ListAggregatesBindData() { -} - -struct StateVector { - StateVector(idx_t count_p, unique_ptr aggr_expr_p) - : count(count_p), aggr_expr(std::move(aggr_expr_p)), state_vector(Vector(LogicalType::POINTER, count_p)) { - } - - ~StateVector() { // NOLINT - // destroy objects within the aggregate states - auto &aggr = aggr_expr->Cast(); - if (aggr.function.destructor) { - ArenaAllocator allocator(Allocator::DefaultAllocator()); - AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); - aggr.function.destructor(state_vector, aggr_input_data, count); - } - } - - idx_t count; - unique_ptr aggr_expr; - Vector state_vector; -}; - -struct FinalizeValueFunctor { - template - static Value FinalizeValue(T first) { - return Value::CreateValue(first); - } -}; - -struct FinalizeStringValueFunctor { - template - static Value FinalizeValue(T first) { - string_t value = first; - return Value::CreateValue(value); - } -}; - -struct AggregateFunctor { - template > - static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { - } -}; - -struct DistinctFunctor { - template > - static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { - - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - auto states = (HistogramAggState **)sdata.data; - - auto result_data = FlatVector::GetData(result); - - idx_t offset = 0; - for (idx_t i = 0; i < count; i++) { - - auto state = states[sdata.sel->get_index(i)]; - result_data[i].offset = offset; - - if (!state->hist) { - result_data[i].length = 0; - continue; - } - - result_data[i].length = 
state->hist->size(); - offset += state->hist->size(); - - for (auto &entry : *state->hist) { - Value bucket_value = OP::template FinalizeValue(entry.first); - ListVector::PushBack(result, bucket_value); - } - } - result.Verify(count); - } -}; - -struct UniqueFunctor { - template > - static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { - - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - auto states = (HistogramAggState **)sdata.data; - - auto result_data = FlatVector::GetData(result); - - for (idx_t i = 0; i < count; i++) { - - auto state = states[sdata.sel->get_index(i)]; - - if (!state->hist) { - result_data[i] = 0; - continue; - } - - result_data[i] = state->hist->size(); - } - result.Verify(count); - } -}; - -template -static void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto count = args.size(); - Vector &lists = args.data[0]; - - // set the result vector - result.SetVectorType(VectorType::FLAT_VECTOR); - auto &result_validity = FlatVector::Validity(result); - - if (lists.GetType().id() == LogicalTypeId::SQLNULL) { - result_validity.SetInvalid(0); - return; - } - - // get the aggregate function - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto &aggr = info.aggr_expr->Cast(); - ArenaAllocator allocator(Allocator::DefaultAllocator()); - AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); - - D_ASSERT(aggr.function.update); - - auto lists_size = ListVector::GetListSize(lists); - auto &child_vector = ListVector::GetEntry(lists); - child_vector.Flatten(lists_size); - - UnifiedVectorFormat child_data; - child_vector.ToUnifiedFormat(lists_size, child_data); - - UnifiedVectorFormat lists_data; - lists.ToUnifiedFormat(count, lists_data); - auto list_entries = UnifiedVectorFormat::GetData(lists_data); - - // state_buffer holds the state for each list of this chunk - idx_t size = aggr.function.state_size(); - auto 
state_buffer = make_unsafe_uniq_array(size * count); - - // state vector for initialize and finalize - StateVector state_vector(count, info.aggr_expr->Copy()); - auto states = FlatVector::GetData(state_vector.state_vector); - - // state vector of STANDARD_VECTOR_SIZE holds the pointers to the states - Vector state_vector_update = Vector(LogicalType::POINTER); - auto states_update = FlatVector::GetData(state_vector_update); - - // selection vector pointing to the data - SelectionVector sel_vector(STANDARD_VECTOR_SIZE); - idx_t states_idx = 0; - - for (idx_t i = 0; i < count; i++) { - - // initialize the state for this list - auto state_ptr = state_buffer.get() + size * i; - states[i] = state_ptr; - aggr.function.initialize(states[i]); - - auto lists_index = lists_data.sel->get_index(i); - const auto &list_entry = list_entries[lists_index]; - - // nothing to do for this list - if (!lists_data.validity.RowIsValid(lists_index)) { - result_validity.SetInvalid(i); - continue; - } - - // skip empty list - if (list_entry.length == 0) { - continue; - } - - for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { - // states vector is full, update - if (states_idx == STANDARD_VECTOR_SIZE) { - // update the aggregate state(s) - Vector slice(child_vector, sel_vector, states_idx); - aggr.function.update(&slice, aggr_input_data, 1, state_vector_update, states_idx); - - // reset values - states_idx = 0; - } - - auto source_idx = child_data.sel->get_index(list_entry.offset + child_idx); - sel_vector.set_index(states_idx, source_idx); - states_update[states_idx] = state_ptr; - states_idx++; - } - } - - // update the remaining elements of the last list(s) - if (states_idx != 0) { - Vector slice(child_vector, sel_vector, states_idx); - aggr.function.update(&slice, aggr_input_data, 1, state_vector_update, states_idx); - } - - if (IS_AGGR) { - // finalize all the aggregate states - aggr.function.finalize(state_vector.state_vector, aggr_input_data, result, count, 0); - - 
} else { - // finalize manually to use the map - D_ASSERT(aggr.function.arguments.size() == 1); - auto key_type = aggr.function.arguments[0]; - - switch (key_type.InternalType()) { - case PhysicalType::BOOL: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::UINT8: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::UINT16: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::UINT32: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::UINT64: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::INT8: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::INT16: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::INT32: - if (key_type.id() == LogicalTypeId::DATE) { - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - } else { - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - } - break; - case PhysicalType::INT64: - switch (key_type.id()) { - case LogicalTypeId::TIME: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case LogicalTypeId::TIME_TZ: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case LogicalTypeId::TIMESTAMP: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case LogicalTypeId::TIMESTAMP_MS: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case 
LogicalTypeId::TIMESTAMP_NS: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case LogicalTypeId::TIMESTAMP_SEC: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case LogicalTypeId::TIMESTAMP_TZ: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - default: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - } - break; - case PhysicalType::FLOAT: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::DOUBLE: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::VARCHAR: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - default: - throw InternalException("Unimplemented histogram aggregate"); - } - } - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static void ListAggregateFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() >= 2); - ListAggregatesFunction(args, state, result); -} - -static void ListDistinctFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - ListAggregatesFunction(args, state, result); -} - -static void ListUniqueFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - ListAggregatesFunction(args, state, result); -} - -template -static unique_ptr -ListAggregatesBindFunction(ClientContext &context, ScalarFunction &bound_function, const LogicalType &list_child_type, - AggregateFunction &aggr_function, vector> &arguments) { - - // create the child expression and its type - vector> children; - auto expr = make_uniq(Value(list_child_type)); - 
children.push_back(std::move(expr)); - // push any extra arguments into the list aggregate bind - if (arguments.size() > 2) { - for (idx_t i = 2; i < arguments.size(); i++) { - children.push_back(std::move(arguments[i])); - } - arguments.resize(2); - } - - FunctionBinder function_binder(context); - auto bound_aggr_function = function_binder.BindAggregateFunction(aggr_function, std::move(children)); - bound_function.arguments[0] = LogicalType::LIST(bound_aggr_function->function.arguments[0]); - - if (IS_AGGR) { - bound_function.return_type = bound_aggr_function->function.return_type; - } - // check if the aggregate function consumed all the extra input arguments - if (bound_aggr_function->children.size() > 1) { - throw InvalidInputException( - "Aggregate function %s is not supported for list_aggr: extra arguments were not removed during bind", - bound_aggr_function->ToString()); - } - - return make_uniq(bound_function.return_type, std::move(bound_aggr_function)); -} - -template -static unique_ptr ListAggregatesBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments[0]->return_type.id() == LogicalTypeId::SQLNULL) { - return ListAggregatesBindFailure(bound_function); - } - - bool is_parameter = arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN; - auto list_child_type = is_parameter ? 
LogicalTypeId::UNKNOWN : ListType::GetChildType(arguments[0]->return_type); - - string function_name = "histogram"; - if (IS_AGGR) { // get the name of the aggregate function - if (!arguments[1]->IsFoldable()) { - throw InvalidInputException("Aggregate function name must be a constant"); - } - // get the function name - Value function_value = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - function_name = function_value.ToString(); - } - - // look up the aggregate function in the catalog - QueryErrorContext error_context(nullptr, 0); - auto &func = Catalog::GetSystemCatalog(context).GetEntry( - context, DEFAULT_SCHEMA, function_name, error_context); - D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); - - if (is_parameter) { - bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; - return nullptr; - } - - // find a matching aggregate function - string error; - vector types; - types.push_back(list_child_type); - // push any extra arguments into the type list - for (idx_t i = 2; i < arguments.size(); i++) { - types.push_back(arguments[i]->return_type); - } - - FunctionBinder function_binder(context); - auto best_function_idx = function_binder.BindFunction(func.name, func.functions, types, error); - if (best_function_idx == DConstants::INVALID_INDEX) { - throw BinderException("No matching aggregate function\n%s", error); - } - - // found a matching function, bind it as an aggregate - auto best_function = func.functions.GetFunctionByOffset(best_function_idx); - if (IS_AGGR) { - return ListAggregatesBindFunction(context, bound_function, list_child_type, best_function, arguments); - } - - // create the unordered map histogram function - D_ASSERT(best_function.arguments.size() == 1); - auto key_type = best_function.arguments[0]; - auto aggr_function = HistogramFun::GetHistogramUnorderedMap(key_type); - return ListAggregatesBindFunction(context, bound_function, list_child_type, aggr_function, 
arguments); -} - -static unique_ptr ListAggregateBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - // the list column and the name of the aggregate function - D_ASSERT(bound_function.arguments.size() >= 2); - D_ASSERT(arguments.size() >= 2); - - return ListAggregatesBind(context, bound_function, arguments); -} - -static unique_ptr ListDistinctBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - D_ASSERT(bound_function.arguments.size() == 1); - D_ASSERT(arguments.size() == 1); - bound_function.return_type = arguments[0]->return_type; - - return ListAggregatesBind<>(context, bound_function, arguments); -} - -static unique_ptr ListUniqueBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - D_ASSERT(bound_function.arguments.size() == 1); - D_ASSERT(arguments.size() == 1); - bound_function.return_type = LogicalType::UBIGINT; - - return ListAggregatesBind<>(context, bound_function, arguments); -} - -ScalarFunction ListAggregateFun::GetFunction() { - auto result = ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, LogicalType::ANY, - ListAggregateFunction, ListAggregateBind); - result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - result.varargs = LogicalType::ANY; - result.serialize = ListAggregatesBindData::Serialize; - result.deserialize = ListAggregatesBindData::Deserialize; - return result; -} - -ScalarFunction ListDistinctFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), - ListDistinctFunction, ListDistinctBind); -} - -ScalarFunction ListUniqueFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::UBIGINT, ListUniqueFunction, - ListUniqueBind); -} - -} // namespace duckdb - -#include -#include - -namespace duckdb { - -template -static void ListCosineSimilarity(DataChunk &args, ExpressionState &, Vector &result) 
{ - D_ASSERT(args.ColumnCount() == 2); - - auto count = args.size(); - auto &left = args.data[0]; - auto &right = args.data[1]; - auto left_count = ListVector::GetListSize(left); - auto right_count = ListVector::GetListSize(right); - - auto &left_child = ListVector::GetEntry(left); - auto &right_child = ListVector::GetEntry(right); - - D_ASSERT(left_child.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(right_child.GetVectorType() == VectorType::FLAT_VECTOR); - - if (!FlatVector::Validity(left_child).CheckAllValid(left_count)) { - throw InvalidInputException("list_cosine_similarity: left argument can not contain NULL values"); - } - - if (!FlatVector::Validity(right_child).CheckAllValid(right_count)) { - throw InvalidInputException("list_cosine_similarity: right argument can not contain NULL values"); - } - - auto left_data = FlatVector::GetData(left_child); - auto right_data = FlatVector::GetData(right_child); - - BinaryExecutor::Execute( - left, right, result, count, [&](list_entry_t left, list_entry_t right) { - if (left.length != right.length) { - throw InvalidInputException(StringUtil::Format( - "list_cosine_similarity: list dimensions must be equal, got left length %d and right length %d", - left.length, right.length)); - } - - auto dimensions = left.length; - - NUMERIC_TYPE distance = 0; - NUMERIC_TYPE norm_l = 0; - NUMERIC_TYPE norm_r = 0; - - auto l_ptr = left_data + left.offset; - auto r_ptr = right_data + right.offset; - for (idx_t i = 0; i < dimensions; i++) { - auto x = *l_ptr++; - auto y = *r_ptr++; - distance += x * y; - norm_l += x * x; - norm_r += y * y; - } - - auto similarity = distance / (std::sqrt(norm_l) * std::sqrt(norm_r)); - - // clamp to [-1, 1] to avoid floating point errors - return std::max(static_cast(-1), std::min(similarity, static_cast(1))); - }); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -ScalarFunctionSet ListCosineSimilarityFun::GetFunctions() { - ScalarFunctionSet 
set("list_cosine_similarity"); - set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::FLOAT), LogicalType::LIST(LogicalType::FLOAT)}, - LogicalType::FLOAT, ListCosineSimilarity)); - set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::DOUBLE), LogicalType::LIST(LogicalType::DOUBLE)}, - LogicalType::DOUBLE, ListCosineSimilarity)); - return set; -} - -} // namespace duckdb - -#include - -namespace duckdb { - -template -static void ListDistance(DataChunk &args, ExpressionState &, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - - auto count = args.size(); - auto &left = args.data[0]; - auto &right = args.data[1]; - auto left_count = ListVector::GetListSize(left); - auto right_count = ListVector::GetListSize(right); - - auto &left_child = ListVector::GetEntry(left); - auto &right_child = ListVector::GetEntry(right); - - D_ASSERT(left_child.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(right_child.GetVectorType() == VectorType::FLAT_VECTOR); - - if (!FlatVector::Validity(left_child).CheckAllValid(left_count)) { - throw InvalidInputException("list_distance: left argument can not contain NULL values"); - } - - if (!FlatVector::Validity(right_child).CheckAllValid(right_count)) { - throw InvalidInputException("list_distance: right argument can not contain NULL values"); - } - - auto left_data = FlatVector::GetData(left_child); - auto right_data = FlatVector::GetData(right_child); - - BinaryExecutor::Execute( - left, right, result, count, [&](list_entry_t left, list_entry_t right) { - if (left.length != right.length) { - throw InvalidInputException(StringUtil::Format( - "list_distance: list dimensions must be equal, got left length %d and right length %d", left.length, - right.length)); - } - - auto dimensions = left.length; - - NUMERIC_TYPE distance = 0; - - auto l_ptr = left_data + left.offset; - auto r_ptr = right_data + right.offset; - - for (idx_t i = 0; i < dimensions; i++) { - auto x = *l_ptr++; - auto y = *r_ptr++; - auto diff = 
x - y; - distance += diff * diff; - } - - return std::sqrt(distance); - }); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -ScalarFunctionSet ListDistanceFun::GetFunctions() { - ScalarFunctionSet set("list_distance"); - set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::FLOAT), LogicalType::LIST(LogicalType::FLOAT)}, - LogicalType::FLOAT, ListDistance)); - set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::DOUBLE), LogicalType::LIST(LogicalType::DOUBLE)}, - LogicalType::DOUBLE, ListDistance)); - return set; -} - -} // namespace duckdb - - -namespace duckdb { - -template -static void ListInnerProduct(DataChunk &args, ExpressionState &, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - - auto count = args.size(); - auto &left = args.data[0]; - auto &right = args.data[1]; - auto left_count = ListVector::GetListSize(left); - auto right_count = ListVector::GetListSize(right); - - auto &left_child = ListVector::GetEntry(left); - auto &right_child = ListVector::GetEntry(right); - - D_ASSERT(left_child.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(right_child.GetVectorType() == VectorType::FLAT_VECTOR); - - if (!FlatVector::Validity(left_child).CheckAllValid(left_count)) { - throw InvalidInputException("list_inner_product: left argument can not contain NULL values"); - } - - if (!FlatVector::Validity(right_child).CheckAllValid(right_count)) { - throw InvalidInputException("list_inner_product: right argument can not contain NULL values"); - } - - auto left_data = FlatVector::GetData(left_child); - auto right_data = FlatVector::GetData(right_child); - - BinaryExecutor::Execute( - left, right, result, count, [&](list_entry_t left, list_entry_t right) { - if (left.length != right.length) { - throw InvalidInputException(StringUtil::Format( - "list_inner_product: list dimensions must be equal, got left length %d and right length %d", - left.length, right.length)); - } - - auto dimensions = 
left.length; - - NUMERIC_TYPE distance = 0; - - auto l_ptr = left_data + left.offset; - auto r_ptr = right_data + right.offset; - - for (idx_t i = 0; i < dimensions; i++) { - auto x = *l_ptr++; - auto y = *r_ptr++; - distance += x * y; - } - - return distance; - }); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -ScalarFunctionSet ListInnerProductFun::GetFunctions() { - ScalarFunctionSet set("list_inner_product"); - set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::FLOAT), LogicalType::LIST(LogicalType::FLOAT)}, - LogicalType::FLOAT, ListInnerProduct)); - set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::DOUBLE), LogicalType::LIST(LogicalType::DOUBLE)}, - LogicalType::DOUBLE, ListInnerProduct)); - return set; -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -struct ListLambdaBindData : public FunctionData { - ListLambdaBindData(const LogicalType &stype_p, unique_ptr lambda_expr); - ~ListLambdaBindData() override; - - LogicalType stype; - unique_ptr lambda_expr; - -public: - bool Equals(const FunctionData &other_p) const override; - unique_ptr Copy() const override; - - static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, - const ScalarFunction &function) { - // auto &bind_data = bind_data_p->Cast(); - // serializer.WriteProperty(100, "stype", bind_data.stype); - // serializer.WritePropertyWithDefault(101, "lambda_expr", bind_data.lambda_expr, - // unique_ptr()); - throw NotImplementedException("FIXME: list lambda serialize"); - } - - static unique_ptr Deserialize(Deserializer &deserializer, ScalarFunction &function) { - // auto stype = deserializer.ReadProperty(100, "stype"); - // auto lambda_expr = - // deserializer.ReadPropertyWithDefault>(101, "lambda_expr", - // unique_ptr()); return make_uniq(stype, std::move(lambda_expr)); - throw NotImplementedException("FIXME: list lambda deserialize"); - } -}; - 
-ListLambdaBindData::ListLambdaBindData(const LogicalType &stype_p, unique_ptr lambda_expr_p) - : stype(stype_p), lambda_expr(std::move(lambda_expr_p)) { -} - -unique_ptr ListLambdaBindData::Copy() const { - return make_uniq(stype, lambda_expr ? lambda_expr->Copy() : nullptr); -} - -bool ListLambdaBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return Expression::Equals(lambda_expr, other.lambda_expr) && stype == other.stype; -} - -ListLambdaBindData::~ListLambdaBindData() { -} - -static void AppendTransformedToResult(Vector &lambda_vector, idx_t &elem_cnt, Vector &result) { - - // append the lambda_vector to the result list - UnifiedVectorFormat lambda_child_data; - lambda_vector.ToUnifiedFormat(elem_cnt, lambda_child_data); - ListVector::Append(result, lambda_vector, *lambda_child_data.sel, elem_cnt, 0); -} - -static void AppendFilteredToResult(Vector &lambda_vector, list_entry_t *result_entries, idx_t &elem_cnt, Vector &result, - idx_t &curr_list_len, idx_t &curr_list_offset, idx_t &appended_lists_cnt, - vector &lists_len, idx_t &curr_original_list_len, DataChunk &input_chunk) { - - idx_t true_count = 0; - SelectionVector true_sel(elem_cnt); - UnifiedVectorFormat lambda_data; - lambda_vector.ToUnifiedFormat(elem_cnt, lambda_data); - - auto lambda_values = UnifiedVectorFormat::GetData(lambda_data); - auto &lambda_validity = lambda_data.validity; - - // compute the new lengths and offsets, and create a selection vector - for (idx_t i = 0; i < elem_cnt; i++) { - auto entry = lambda_data.sel->get_index(i); - - while (appended_lists_cnt < lists_len.size() && lists_len[appended_lists_cnt] == 0) { - result_entries[appended_lists_cnt].offset = curr_list_offset; - result_entries[appended_lists_cnt].length = 0; - appended_lists_cnt++; - } - - // found a true value - if (lambda_validity.RowIsValid(entry) && lambda_values[entry]) { - true_sel.set_index(true_count++, i); - curr_list_len++; - } - - curr_original_list_len++; - - if 
(lists_len[appended_lists_cnt] == curr_original_list_len) { - result_entries[appended_lists_cnt].offset = curr_list_offset; - result_entries[appended_lists_cnt].length = curr_list_len; - curr_list_offset += curr_list_len; - appended_lists_cnt++; - curr_list_len = 0; - curr_original_list_len = 0; - } - } - - while (appended_lists_cnt < lists_len.size() && lists_len[appended_lists_cnt] == 0) { - result_entries[appended_lists_cnt].offset = curr_list_offset; - result_entries[appended_lists_cnt].length = 0; - appended_lists_cnt++; - } - - // slice to get the new lists and append them to the result - Vector new_lists(input_chunk.data[0], true_sel, true_count); - new_lists.Flatten(true_count); - UnifiedVectorFormat new_lists_child_data; - new_lists.ToUnifiedFormat(true_count, new_lists_child_data); - ListVector::Append(result, new_lists, *new_lists_child_data.sel, true_count, 0); -} - -static void ExecuteExpression(vector &types, vector &result_types, idx_t &elem_cnt, - SelectionVector &sel, vector &sel_vectors, DataChunk &input_chunk, - DataChunk &lambda_chunk, Vector &child_vector, DataChunk &args, - ExpressionExecutor &expr_executor) { - - input_chunk.SetCardinality(elem_cnt); - lambda_chunk.SetCardinality(elem_cnt); - - // set the list child vector - Vector slice(child_vector, sel, elem_cnt); - Vector second_slice(child_vector, sel, elem_cnt); - slice.Flatten(elem_cnt); - second_slice.Flatten(elem_cnt); - - input_chunk.data[0].Reference(slice); - input_chunk.data[1].Reference(second_slice); - - // set the other vectors - vector slices; - for (idx_t col_idx = 0; col_idx < args.ColumnCount() - 1; col_idx++) { - slices.emplace_back(args.data[col_idx + 1], sel_vectors[col_idx], elem_cnt); - slices[col_idx].Flatten(elem_cnt); - input_chunk.data[col_idx + 2].Reference(slices[col_idx]); - } - - // execute the lambda expression - expr_executor.Execute(input_chunk, lambda_chunk); -} - -template -static void ListLambdaFunction(DataChunk &args, ExpressionState &state, Vector 
&result) { - - // always at least the list argument - D_ASSERT(args.ColumnCount() >= 1); - - auto count = args.size(); - Vector &lists = args.data[0]; - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_entries = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - - if (lists.GetType().id() == LogicalTypeId::SQLNULL) { - result_validity.SetInvalid(0); - return; - } - - // e.g. window functions in sub queries return dictionary vectors, which segfault on expression execution - // if not flattened first - for (idx_t i = 1; i < args.ColumnCount(); i++) { - if (args.data[i].GetVectorType() != VectorType::FLAT_VECTOR && - args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { - args.data[i].Flatten(count); - } - } - - // get the lists data - UnifiedVectorFormat lists_data; - lists.ToUnifiedFormat(count, lists_data); - auto list_entries = UnifiedVectorFormat::GetData(lists_data); - - // get the lambda expression - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto &lambda_expr = info.lambda_expr; - - // get the child vector and child data - auto lists_size = ListVector::GetListSize(lists); - auto &child_vector = ListVector::GetEntry(lists); - child_vector.Flatten(lists_size); - UnifiedVectorFormat child_data; - child_vector.ToUnifiedFormat(lists_size, child_data); - - // to slice the child vector - SelectionVector sel(STANDARD_VECTOR_SIZE); - - // this vector never contains more than one element - vector result_types; - result_types.push_back(lambda_expr->return_type); - - // non-lambda parameter columns - vector columns; - vector indexes; - vector sel_vectors; - - vector types; - types.push_back(child_vector.GetType()); - types.push_back(child_vector.GetType()); - - // skip the list column - for (idx_t i = 1; i < args.ColumnCount(); i++) { - columns.emplace_back(); - args.data[i].ToUnifiedFormat(count, columns[i - 1]); - indexes.push_back(0); - 
sel_vectors.emplace_back(STANDARD_VECTOR_SIZE); - types.push_back(args.data[i].GetType()); - } - - // get the expression executor - ExpressionExecutor expr_executor(state.GetContext(), *lambda_expr); - - // these are only for the list_filter - vector lists_len; - idx_t curr_list_len = 0; - idx_t curr_list_offset = 0; - idx_t appended_lists_cnt = 0; - idx_t curr_original_list_len = 0; - - if (!IS_TRANSFORM) { - lists_len.reserve(count); - } - - DataChunk input_chunk; - DataChunk lambda_chunk; - input_chunk.InitializeEmpty(types); - lambda_chunk.Initialize(Allocator::DefaultAllocator(), result_types); - - // loop over the child entries and create chunks to be executed by the expression executor - idx_t elem_cnt = 0; - idx_t offset = 0; - for (idx_t row_idx = 0; row_idx < count; row_idx++) { - - auto lists_index = lists_data.sel->get_index(row_idx); - const auto &list_entry = list_entries[lists_index]; - - // set the result to NULL for this row - if (!lists_data.validity.RowIsValid(lists_index)) { - result_validity.SetInvalid(row_idx); - if (!IS_TRANSFORM) { - lists_len.push_back(0); - } - continue; - } - - // set the length and offset of the resulting lists of list_transform - if (IS_TRANSFORM) { - result_entries[row_idx].offset = offset; - result_entries[row_idx].length = list_entry.length; - offset += list_entry.length; - } else { - lists_len.push_back(list_entry.length); - } - - // empty list, nothing to execute - if (list_entry.length == 0) { - continue; - } - - // get the data indexes - for (idx_t col_idx = 0; col_idx < args.ColumnCount() - 1; col_idx++) { - indexes[col_idx] = columns[col_idx].sel->get_index(row_idx); - } - - // iterate list elements and create transformed expression columns - for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { - // reached STANDARD_VECTOR_SIZE elements - if (elem_cnt == STANDARD_VECTOR_SIZE) { - lambda_chunk.Reset(); - ExecuteExpression(types, result_types, elem_cnt, sel, sel_vectors, input_chunk, 
lambda_chunk, - child_vector, args, expr_executor); - - auto &lambda_vector = lambda_chunk.data[0]; - - if (IS_TRANSFORM) { - AppendTransformedToResult(lambda_vector, elem_cnt, result); - } else { - AppendFilteredToResult(lambda_vector, result_entries, elem_cnt, result, curr_list_len, - curr_list_offset, appended_lists_cnt, lists_len, curr_original_list_len, - input_chunk); - } - elem_cnt = 0; - } - - // to slice the child vector - auto source_idx = child_data.sel->get_index(list_entry.offset + child_idx); - sel.set_index(elem_cnt, source_idx); - - // for each column, set the index of the selection vector to slice properly - for (idx_t col_idx = 0; col_idx < args.ColumnCount() - 1; col_idx++) { - sel_vectors[col_idx].set_index(elem_cnt, indexes[col_idx]); - } - elem_cnt++; - } - } - - lambda_chunk.Reset(); - ExecuteExpression(types, result_types, elem_cnt, sel, sel_vectors, input_chunk, lambda_chunk, child_vector, args, - expr_executor); - auto &lambda_vector = lambda_chunk.data[0]; - - if (IS_TRANSFORM) { - AppendTransformedToResult(lambda_vector, elem_cnt, result); - } else { - AppendFilteredToResult(lambda_vector, result_entries, elem_cnt, result, curr_list_len, curr_list_offset, - appended_lists_cnt, lists_len, curr_original_list_len, input_chunk); - } - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static void ListTransformFunction(DataChunk &args, ExpressionState &state, Vector &result) { - ListLambdaFunction<>(args, state, result); -} - -static void ListFilterFunction(DataChunk &args, ExpressionState &state, Vector &result) { - ListLambdaFunction(args, state, result); -} - -template -static unique_ptr ListLambdaBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto &bound_lambda_expr = arguments[1]->Cast(); - if (bound_lambda_expr.parameter_count != LAMBDA_PARAM_CNT) { - throw BinderException("Incorrect number of parameters in lambda function! 
" + bound_function.name + - " expects " + to_string(LAMBDA_PARAM_CNT) + " parameter(s)."); - } - - if (arguments[0]->return_type.id() == LogicalTypeId::SQLNULL) { - bound_function.arguments[0] = LogicalType::SQLNULL; - bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(bound_function.return_type, nullptr); - } - - if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - - D_ASSERT(arguments[0]->return_type.id() == LogicalTypeId::LIST); - - // get the lambda expression and put it in the bind info - auto lambda_expr = std::move(bound_lambda_expr.lambda_expr); - return make_uniq(bound_function.return_type, std::move(lambda_expr)); -} - -static unique_ptr ListTransformBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - // at least the list column and the lambda function - D_ASSERT(arguments.size() == 2); - if (arguments[1]->expression_class != ExpressionClass::BOUND_LAMBDA) { - throw BinderException("Invalid lambda expression!"); - } - - auto &bound_lambda_expr = arguments[1]->Cast(); - bound_function.return_type = LogicalType::LIST(bound_lambda_expr.lambda_expr->return_type); - return ListLambdaBind<1>(context, bound_function, arguments); -} - -static unique_ptr ListFilterBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - // at least the list column and the lambda function - D_ASSERT(arguments.size() == 2); - if (arguments[1]->expression_class != ExpressionClass::BOUND_LAMBDA) { - throw BinderException("Invalid lambda expression!"); - } - - // try to cast to boolean, if the return type of the lambda filter expression is not already boolean - auto &bound_lambda_expr = arguments[1]->Cast(); - if (bound_lambda_expr.lambda_expr->return_type != LogicalType::BOOLEAN) { - auto cast_lambda_expr = - BoundCastExpression::AddCastToType(context, std::move(bound_lambda_expr.lambda_expr), LogicalType::BOOLEAN); - 
bound_lambda_expr.lambda_expr = std::move(cast_lambda_expr); - } - - bound_function.return_type = arguments[0]->return_type; - return ListLambdaBind<1>(context, bound_function, arguments); -} - -ScalarFunction ListTransformFun::GetFunction() { - ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::LIST(LogicalType::ANY), - ListTransformFunction, ListTransformBind, nullptr, nullptr); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = ListLambdaBindData::Serialize; - fun.deserialize = ListLambdaBindData::Deserialize; - return fun; -} - -ScalarFunction ListFilterFun::GetFunction() { - ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::LIST(LogicalType::ANY), - ListFilterFunction, ListFilterBind, nullptr, nullptr); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = ListLambdaBindData::Serialize; - fun.deserialize = ListLambdaBindData::Deserialize; - return fun; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -struct ListSortBindData : public FunctionData { - ListSortBindData(OrderType order_type_p, OrderByNullType null_order_p, const LogicalType &return_type_p, - const LogicalType &child_type_p, ClientContext &context_p); - ~ListSortBindData() override; - - OrderType order_type; - OrderByNullType null_order; - LogicalType return_type; - LogicalType child_type; - - vector types; - vector payload_types; - - ClientContext &context; - RowLayout payload_layout; - vector orders; - -public: - bool Equals(const FunctionData &other_p) const override; - unique_ptr Copy() const override; -}; - -ListSortBindData::ListSortBindData(OrderType order_type_p, OrderByNullType null_order_p, - const LogicalType &return_type_p, const LogicalType &child_type_p, - ClientContext &context_p) - : order_type(order_type_p), null_order(null_order_p), return_type(return_type_p), child_type(child_type_p), - context(context_p) { - - // get the 
vector types - types.emplace_back(LogicalType::USMALLINT); - types.emplace_back(child_type); - D_ASSERT(types.size() == 2); - - // get the payload types - payload_types.emplace_back(LogicalType::UINTEGER); - D_ASSERT(payload_types.size() == 1); - - // initialize the payload layout - payload_layout.Initialize(payload_types); - - // get the BoundOrderByNode - auto idx_col_expr = make_uniq_base(LogicalType::USMALLINT, 0); - auto lists_col_expr = make_uniq_base(child_type, 1); - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::ORDER_DEFAULT, std::move(idx_col_expr)); - orders.emplace_back(order_type, null_order, std::move(lists_col_expr)); -} - -unique_ptr ListSortBindData::Copy() const { - return make_uniq(order_type, null_order, return_type, child_type, context); -} - -bool ListSortBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return order_type == other.order_type && null_order == other.null_order; -} - -ListSortBindData::~ListSortBindData() { -} - -// create the key_chunk and the payload_chunk and sink them into the local_sort_state -void SinkDataChunk(Vector *child_vector, SelectionVector &sel, idx_t offset_lists_indices, vector &types, - vector &payload_types, Vector &payload_vector, LocalSortState &local_sort_state, - bool &data_to_sort, Vector &lists_indices) { - - // slice the child vector - Vector slice(*child_vector, sel, offset_lists_indices); - - // initialize and fill key_chunk - DataChunk key_chunk; - key_chunk.InitializeEmpty(types); - key_chunk.data[0].Reference(lists_indices); - key_chunk.data[1].Reference(slice); - key_chunk.SetCardinality(offset_lists_indices); - - // initialize and fill key_chunk and payload_chunk - DataChunk payload_chunk; - payload_chunk.InitializeEmpty(payload_types); - payload_chunk.data[0].Reference(payload_vector); - payload_chunk.SetCardinality(offset_lists_indices); - - key_chunk.Verify(); - payload_chunk.Verify(); - - // sink - key_chunk.Flatten(); - 
local_sort_state.SinkChunk(key_chunk, payload_chunk); - data_to_sort = true; -} - -static void ListSortFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() >= 1 && args.ColumnCount() <= 3); - auto count = args.size(); - Vector &input_lists = args.data[0]; - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto &result_validity = FlatVector::Validity(result); - - if (input_lists.GetType().id() == LogicalTypeId::SQLNULL) { - result_validity.SetInvalid(0); - return; - } - - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - // initialize the global and local sorting state - auto &buffer_manager = BufferManager::GetBufferManager(info.context); - GlobalSortState global_sort_state(buffer_manager, info.orders, info.payload_layout); - LocalSortState local_sort_state; - local_sort_state.Initialize(global_sort_state, buffer_manager); - - // this ensures that we do not change the order of the entries in the input chunk - VectorOperations::Copy(input_lists, result, count, 0, 0); - - // get the child vector - auto lists_size = ListVector::GetListSize(result); - auto &child_vector = ListVector::GetEntry(result); - UnifiedVectorFormat child_data; - child_vector.ToUnifiedFormat(lists_size, child_data); - - // get the lists data - UnifiedVectorFormat lists_data; - result.ToUnifiedFormat(count, lists_data); - auto list_entries = UnifiedVectorFormat::GetData(lists_data); - - // create the lists_indices vector, this contains an element for each list's entry, - // the element corresponds to the list's index, e.g. 
for [1, 2, 4], [5, 4] - // lists_indices contains [0, 0, 0, 1, 1] - Vector lists_indices(LogicalType::USMALLINT); - auto lists_indices_data = FlatVector::GetData(lists_indices); - - // create the payload_vector, this is just a vector containing incrementing integers - // this will later be used as the 'new' selection vector of the child_vector, after - // rearranging the payload according to the sorting order - Vector payload_vector(LogicalType::UINTEGER); - auto payload_vector_data = FlatVector::GetData(payload_vector); - - // selection vector pointing to the data of the child vector, - // used for slicing the child_vector correctly - SelectionVector sel(STANDARD_VECTOR_SIZE); - - idx_t offset_lists_indices = 0; - uint32_t incr_payload_count = 0; - bool data_to_sort = false; - - for (idx_t i = 0; i < count; i++) { - auto lists_index = lists_data.sel->get_index(i); - const auto &list_entry = list_entries[lists_index]; - - // nothing to do for this list - if (!lists_data.validity.RowIsValid(lists_index)) { - result_validity.SetInvalid(i); - continue; - } - - // empty list, no sorting required - if (list_entry.length == 0) { - continue; - } - - for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { - // lists_indices vector is full, sink - if (offset_lists_indices == STANDARD_VECTOR_SIZE) { - SinkDataChunk(&child_vector, sel, offset_lists_indices, info.types, info.payload_types, payload_vector, - local_sort_state, data_to_sort, lists_indices); - offset_lists_indices = 0; - } - - auto source_idx = list_entry.offset + child_idx; - sel.set_index(offset_lists_indices, source_idx); - lists_indices_data[offset_lists_indices] = (uint32_t)i; - payload_vector_data[offset_lists_indices] = source_idx; - offset_lists_indices++; - incr_payload_count++; - } - } - - if (offset_lists_indices != 0) { - SinkDataChunk(&child_vector, sel, offset_lists_indices, info.types, info.payload_types, payload_vector, - local_sort_state, data_to_sort, lists_indices); - } - - if 
(data_to_sort) { - // add local state to global state, which sorts the data - global_sort_state.AddLocalState(local_sort_state); - global_sort_state.PrepareMergePhase(); - - // selection vector that is to be filled with the 'sorted' payload - SelectionVector sel_sorted(incr_payload_count); - idx_t sel_sorted_idx = 0; - - // scan the sorted row data - PayloadScanner scanner(*global_sort_state.sorted_blocks[0]->payload_data, global_sort_state); - for (;;) { - DataChunk result_chunk; - result_chunk.Initialize(Allocator::DefaultAllocator(), info.payload_types); - result_chunk.SetCardinality(0); - scanner.Scan(result_chunk); - if (result_chunk.size() == 0) { - break; - } - - // construct the selection vector with the new order from the result vectors - Vector result_vector(result_chunk.data[0]); - auto result_data = FlatVector::GetData(result_vector); - auto row_count = result_chunk.size(); - - for (idx_t i = 0; i < row_count; i++) { - sel_sorted.set_index(sel_sorted_idx, result_data[i]); - D_ASSERT(result_data[i] < lists_size); - sel_sorted_idx++; - } - } - - D_ASSERT(sel_sorted_idx == incr_payload_count); - child_vector.Slice(sel_sorted, sel_sorted_idx); - child_vector.Flatten(sel_sorted_idx); - } - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr ListSortBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments, OrderType &order, - OrderByNullType &null_order) { - - LogicalType child_type; - if (arguments[0]->return_type == LogicalTypeId::UNKNOWN) { - bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; - child_type = bound_function.return_type; - return make_uniq(order, null_order, bound_function.return_type, child_type, context); - } - - bound_function.arguments[0] = arguments[0]->return_type; - bound_function.return_type = arguments[0]->return_type; - child_type = ListType::GetChildType(arguments[0]->return_type); - - return 
make_uniq(order, null_order, bound_function.return_type, child_type, context); -} - -template -static T GetOrder(ClientContext &context, Expression &expr) { - if (!expr.IsFoldable()) { - throw InvalidInputException("Sorting order must be a constant"); - } - Value order_value = ExpressionExecutor::EvaluateScalar(context, expr); - auto order_name = StringUtil::Upper(order_value.ToString()); - return EnumUtil::FromString(order_name.c_str()); -} - -static unique_ptr ListNormalSortBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(!arguments.empty() && arguments.size() <= 3); - auto order = OrderType::ORDER_DEFAULT; - auto null_order = OrderByNullType::ORDER_DEFAULT; - - // get the sorting order - if (arguments.size() >= 2) { - order = GetOrder(context, *arguments[1]); - } - // get the null sorting order - if (arguments.size() == 3) { - null_order = GetOrder(context, *arguments[2]); - } - auto &config = DBConfig::GetConfig(context); - order = config.ResolveOrder(order); - null_order = config.ResolveNullOrder(order, null_order); - return ListSortBind(context, bound_function, arguments, order, null_order); -} - -static unique_ptr ListReverseSortBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto order = OrderType::ORDER_DEFAULT; - auto null_order = OrderByNullType::ORDER_DEFAULT; - - if (arguments.size() == 2) { - null_order = GetOrder(context, *arguments[1]); - } - auto &config = DBConfig::GetConfig(context); - order = config.ResolveOrder(order); - switch (order) { - case OrderType::ASCENDING: - order = OrderType::DESCENDING; - break; - case OrderType::DESCENDING: - order = OrderType::ASCENDING; - break; - default: - throw InternalException("Unexpected order type in list reverse sort"); - } - null_order = config.ResolveNullOrder(order, null_order); - return ListSortBind(context, bound_function, arguments, order, null_order); -} - -ScalarFunctionSet ListSortFun::GetFunctions() { - // one 
parameter: list - ScalarFunction sort({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), ListSortFunction, - ListNormalSortBind); - - // two parameters: list, order - ScalarFunction sort_order({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, - LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListNormalSortBind); - - // three parameters: list, order, null order - ScalarFunction sort_orders({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListNormalSortBind); - - ScalarFunctionSet list_sort; - list_sort.AddFunction(sort); - list_sort.AddFunction(sort_order); - list_sort.AddFunction(sort_orders); - return list_sort; -} - -ScalarFunctionSet ListReverseSortFun::GetFunctions() { - // one parameter: list - ScalarFunction sort_reverse({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), - ListSortFunction, ListReverseSortBind); - - // two parameters: list, null order - ScalarFunction sort_reverse_null_order({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, - LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListReverseSortBind); - - ScalarFunctionSet list_reverse_sort; - list_reverse_sort.AddFunction(sort_reverse); - list_reverse_sort.AddFunction(sort_reverse_null_order); - return list_reverse_sort; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -static void ListValueFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - auto &child_type = ListType::GetChildType(result.GetType()); - - result.SetVectorType(VectorType::CONSTANT_VECTOR); - for (idx_t i = 0; i < args.ColumnCount(); i++) { - if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::FLAT_VECTOR); - } - } - - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < args.size(); i++) { - 
result_data[i].offset = ListVector::GetListSize(result); - for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { - auto val = args.GetValue(col_idx, i).DefaultCastAs(child_type); - ListVector::PushBack(result, val); - } - result_data[i].length = args.ColumnCount(); - } - result.Verify(args.size()); -} - -static unique_ptr ListValueBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // collect names and deconflict, construct return type - LogicalType child_type = arguments.empty() ? LogicalType::SQLNULL : arguments[0]->return_type; - for (idx_t i = 1; i < arguments.size(); i++) { - child_type = LogicalType::MaxLogicalType(child_type, arguments[i]->return_type); - } - - // this is more for completeness reasons - bound_function.varargs = child_type; - bound_function.return_type = LogicalType::LIST(child_type); - return make_uniq(bound_function.return_type); -} - -unique_ptr ListValueStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - auto list_stats = ListStats::CreateEmpty(expr.return_type); - auto &list_child_stats = ListStats::GetChildStats(list_stats); - for (idx_t i = 0; i < child_stats.size(); i++) { - list_child_stats.Merge(child_stats[i]); - } - return list_stats.ToUnique(); -} - -ScalarFunction ListValueFun::GetFunction() { - // the arguments and return types are actually set in the binder function - ScalarFunction fun("list_value", {}, LogicalTypeId::LIST, ListValueFunction, ListValueBind, nullptr, - ListValueStats); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct NumericRangeInfo { - using TYPE = int64_t; - using INCREMENT_TYPE = int64_t; - - static int64_t DefaultStart() { - return 0; - } - static int64_t DefaultIncrement() { - return 1; - } - - static uint64_t ListLength(int64_t start_value, 
int64_t end_value, int64_t increment_value, bool inclusive_bound) { - if (increment_value == 0) { - return 0; - } - if (start_value > end_value && increment_value > 0) { - return 0; - } - if (start_value < end_value && increment_value < 0) { - return 0; - } - hugeint_t total_diff = AbsValue(hugeint_t(end_value) - hugeint_t(start_value)); - hugeint_t increment = AbsValue(hugeint_t(increment_value)); - hugeint_t total_values = total_diff / increment; - if (total_diff % increment == 0) { - if (inclusive_bound) { - total_values += 1; - } - } else { - total_values += 1; - } - if (total_values > NumericLimits::Maximum()) { - throw InvalidInputException("Lists larger than 2^32 elements are not supported"); - } - return Hugeint::Cast(total_values); - } - - static void Increment(int64_t &input, int64_t increment) { - input += increment; - } -}; -struct TimestampRangeInfo { - using TYPE = timestamp_t; - using INCREMENT_TYPE = interval_t; - - static timestamp_t DefaultStart() { - throw InternalException("Default start not implemented for timestamp range"); - } - static interval_t DefaultIncrement() { - throw InternalException("Default increment not implemented for timestamp range"); - } - static uint64_t ListLength(timestamp_t start_value, timestamp_t end_value, interval_t increment_value, - bool inclusive_bound) { - bool is_positive = increment_value.months > 0 || increment_value.days > 0 || increment_value.micros > 0; - bool is_negative = increment_value.months < 0 || increment_value.days < 0 || increment_value.micros < 0; - if (!is_negative && !is_positive) { - // interval is 0: no result - return 0; - } - // We don't allow infinite bounds because they generate errors or infinite loops - if (!Timestamp::IsFinite(start_value) || !Timestamp::IsFinite(end_value)) { - throw InvalidInputException("Interval infinite bounds not supported"); - } - - if (is_negative && is_positive) { - // we don't allow a mix of - throw InvalidInputException("Interval with mix of negative/positive 
entries not supported"); - } - if (start_value > end_value && is_positive) { - return 0; - } - if (start_value < end_value && is_negative) { - return 0; - } - int64_t total_values = 0; - if (is_negative) { - // negative interval, start_value is going down - while (inclusive_bound ? start_value >= end_value : start_value > end_value) { - start_value = Interval::Add(start_value, increment_value); - total_values++; - if (total_values > NumericLimits::Maximum()) { - throw InvalidInputException("Lists larger than 2^32 elements are not supported"); - } - } - } else { - // positive interval, start_value is going up - while (inclusive_bound ? start_value <= end_value : start_value < end_value) { - start_value = Interval::Add(start_value, increment_value); - total_values++; - if (total_values > NumericLimits::Maximum()) { - throw InvalidInputException("Lists larger than 2^32 elements are not supported"); - } - } - } - return total_values; - } - - static void Increment(timestamp_t &input, interval_t increment) { - input = Interval::Add(input, increment); - } -}; - -template -class RangeInfoStruct { -public: - explicit RangeInfoStruct(DataChunk &args_p) : args(args_p) { - switch (args.ColumnCount()) { - case 1: - args.data[0].ToUnifiedFormat(args.size(), vdata[0]); - break; - case 2: - args.data[0].ToUnifiedFormat(args.size(), vdata[0]); - args.data[1].ToUnifiedFormat(args.size(), vdata[1]); - break; - case 3: - args.data[0].ToUnifiedFormat(args.size(), vdata[0]); - args.data[1].ToUnifiedFormat(args.size(), vdata[1]); - args.data[2].ToUnifiedFormat(args.size(), vdata[2]); - break; - default: - throw InternalException("Unsupported number of parameters for range"); - } - } - - bool RowIsValid(idx_t row_idx) { - for (idx_t i = 0; i < args.ColumnCount(); i++) { - auto idx = vdata[i].sel->get_index(row_idx); - if (!vdata[i].validity.RowIsValid(idx)) { - return false; - } - } - return true; - } - - typename OP::TYPE StartListValue(idx_t row_idx) { - if (args.ColumnCount() == 1) { - 
return OP::DefaultStart(); - } else { - auto data = (typename OP::TYPE *)vdata[0].data; - auto idx = vdata[0].sel->get_index(row_idx); - return data[idx]; - } - } - - typename OP::TYPE EndListValue(idx_t row_idx) { - idx_t vdata_idx = args.ColumnCount() == 1 ? 0 : 1; - auto data = (typename OP::TYPE *)vdata[vdata_idx].data; - auto idx = vdata[vdata_idx].sel->get_index(row_idx); - return data[idx]; - } - - typename OP::INCREMENT_TYPE ListIncrementValue(idx_t row_idx) { - if (args.ColumnCount() < 3) { - return OP::DefaultIncrement(); - } else { - auto data = (typename OP::INCREMENT_TYPE *)vdata[2].data; - auto idx = vdata[2].sel->get_index(row_idx); - return data[idx]; - } - } - - void GetListValues(idx_t row_idx, typename OP::TYPE &start_value, typename OP::TYPE &end_value, - typename OP::INCREMENT_TYPE &increment_value) { - start_value = StartListValue(row_idx); - end_value = EndListValue(row_idx); - increment_value = ListIncrementValue(row_idx); - } - - uint64_t ListLength(idx_t row_idx) { - typename OP::TYPE start_value; - typename OP::TYPE end_value; - typename OP::INCREMENT_TYPE increment_value; - GetListValues(row_idx, start_value, end_value, increment_value); - return OP::ListLength(start_value, end_value, increment_value, INCLUSIVE_BOUND); - } - -private: - DataChunk &args; - UnifiedVectorFormat vdata[3]; -}; - -template -static void ListRangeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - - RangeInfoStruct info(args); - idx_t args_size = 1; - auto result_type = VectorType::CONSTANT_VECTOR; - for (idx_t i = 0; i < args.ColumnCount(); i++) { - if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { - args_size = args.size(); - result_type = VectorType::FLAT_VECTOR; - break; - } - } - auto list_data = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - int64_t total_size = 0; - for (idx_t i = 0; i < args_size; i++) { - if 
(!info.RowIsValid(i)) { - result_validity.SetInvalid(i); - list_data[i].offset = total_size; - list_data[i].length = 0; - } else { - list_data[i].offset = total_size; - list_data[i].length = info.ListLength(i); - total_size += list_data[i].length; - } - } - - // now construct the child vector of the list - ListVector::Reserve(result, total_size); - auto range_data = FlatVector::GetData(ListVector::GetEntry(result)); - idx_t total_idx = 0; - for (idx_t i = 0; i < args_size; i++) { - typename OP::TYPE start_value = info.StartListValue(i); - typename OP::INCREMENT_TYPE increment = info.ListIncrementValue(i); - - typename OP::TYPE range_value = start_value; - for (idx_t range_idx = 0; range_idx < list_data[i].length; range_idx++) { - if (range_idx > 0) { - OP::Increment(range_value, increment); - } - range_data[total_idx++] = range_value; - } - } - - ListVector::SetListSize(result, total_size); - result.SetVectorType(result_type); - - result.Verify(args.size()); -} - -ScalarFunctionSet ListRangeFun::GetFunctions() { - // the arguments and return types are actually set in the binder function - ScalarFunctionSet range_set; - range_set.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::LIST(LogicalType::BIGINT), - ListRangeFunction)); - range_set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::LIST(LogicalType::BIGINT), - ListRangeFunction)); - range_set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::LIST(LogicalType::BIGINT), - ListRangeFunction)); - range_set.AddFunction(ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, - LogicalType::LIST(LogicalType::TIMESTAMP), - ListRangeFunction)); - return range_set; -} - -ScalarFunctionSet GenerateSeriesFun::GetFunctions() { - ScalarFunctionSet generate_series; - generate_series.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::LIST(LogicalType::BIGINT), - 
ListRangeFunction)); - generate_series.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::LIST(LogicalType::BIGINT), - ListRangeFunction)); - generate_series.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::LIST(LogicalType::BIGINT), - ListRangeFunction)); - generate_series.AddFunction(ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, - LogicalType::LIST(LogicalType::TIMESTAMP), - ListRangeFunction)); - return generate_series; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -static void CardinalityFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &map = args.data[0]; - UnifiedVectorFormat map_data; - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - - map.ToUnifiedFormat(args.size(), map_data); - for (idx_t row = 0; row < args.size(); row++) { - auto list_entry = UnifiedVectorFormat::GetData(map_data)[map_data.sel->get_index(row)]; - result_data[row] = list_entry.length; - result_validity.Set(row, map_data.validity.RowIsValid(map_data.sel->get_index(row))); - } - - if (args.size() == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr CardinalityBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments.size() != 1) { - throw BinderException("Cardinality must have exactly one arguments"); - } - - if (arguments[0]->return_type.id() != LogicalTypeId::MAP) { - throw BinderException("Cardinality can only operate on MAPs"); - } - - bound_function.return_type = LogicalType::UBIGINT; - return make_uniq(bound_function.return_type); -} - -ScalarFunction CardinalityFun::GetFunction() { - ScalarFunction fun({LogicalType::ANY}, LogicalType::UBIGINT, CardinalityFunction, CardinalityBind); - fun.varargs = LogicalType::ANY; - 
fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; - return fun; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -// Example: -// source: [1,2,3], expansion_factor: 4 -// target (result): [1,2,3,1,2,3,1,2,3,1,2,3] -static void CreateExpandedVector(const Vector &source, Vector &target, idx_t expansion_factor) { - idx_t count = ListVector::GetListSize(source); - auto &entry = ListVector::GetEntry(source); - - idx_t target_idx = 0; - for (idx_t copy = 0; copy < expansion_factor; copy++) { - for (idx_t key_idx = 0; key_idx < count; key_idx++) { - target.SetValue(target_idx, entry.GetValue(key_idx)); - target_idx++; - } - } - D_ASSERT(target_idx == count * expansion_factor); -} - -static void AlignVectorToReference(const Vector &original, const Vector &reference, idx_t tuple_count, Vector &result) { - auto original_length = ListVector::GetListSize(original); - auto new_length = ListVector::GetListSize(reference); - - Vector expanded_const(ListType::GetChildType(original.GetType()), new_length); - - auto expansion_factor = new_length / original_length; - if (expansion_factor != tuple_count) { - throw InvalidInputException("Error in MAP creation: key list and value list do not align. i.e. 
different " - "size or incompatible structure"); - } - CreateExpandedVector(original, expanded_const, expansion_factor); - result.Reference(expanded_const); -} - -static bool ListEntriesEqual(Vector &keys, Vector &values, idx_t count) { - auto key_count = ListVector::GetListSize(keys); - auto value_count = ListVector::GetListSize(values); - bool same_vector_type = keys.GetVectorType() == values.GetVectorType(); - - D_ASSERT(keys.GetType().id() == LogicalTypeId::LIST); - D_ASSERT(values.GetType().id() == LogicalTypeId::LIST); - - UnifiedVectorFormat keys_data; - UnifiedVectorFormat values_data; - - keys.ToUnifiedFormat(count, keys_data); - values.ToUnifiedFormat(count, values_data); - - auto keys_entries = UnifiedVectorFormat::GetData(keys_data); - auto values_entries = UnifiedVectorFormat::GetData(values_data); - - if (same_vector_type) { - const auto key_data = keys_data.data; - const auto value_data = values_data.data; - - if (keys.GetVectorType() == VectorType::CONSTANT_VECTOR) { - D_ASSERT(values.GetVectorType() == VectorType::CONSTANT_VECTOR); - // Only need to compare one entry in this case - return memcmp(key_data, value_data, sizeof(list_entry_t)) == 0; - } - - // Fast path if the vector types are equal, can just check if the entries are the same - if (key_count != value_count) { - return false; - } - return memcmp(key_data, value_data, count * sizeof(list_entry_t)) == 0; - } - - // Compare the list_entries one by one - for (idx_t i = 0; i < count; i++) { - auto keys_idx = keys_data.sel->get_index(i); - auto values_idx = values_data.sel->get_index(i); - - if (keys_entries[keys_idx] != values_entries[values_idx]) { - return false; - } - } - return true; -} - -static void MapFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(result.GetType().id() == LogicalTypeId::MAP); - - auto &key_vector = MapVector::GetKeys(result); - auto &value_vector = MapVector::GetValues(result); - auto result_data = ListVector::GetData(result); - - 
result.SetVectorType(VectorType::CONSTANT_VECTOR); - if (args.data.empty()) { - ListVector::SetListSize(result, 0); - result_data->offset = 0; - result_data->length = 0; - result.Verify(args.size()); - return; - } - - bool keys_are_const = args.data[0].GetVectorType() == VectorType::CONSTANT_VECTOR; - bool values_are_const = args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR; - if (!keys_are_const || !values_are_const) { - result.SetVectorType(VectorType::FLAT_VECTOR); - } - - auto key_count = ListVector::GetListSize(args.data[0]); - auto value_count = ListVector::GetListSize(args.data[1]); - auto key_data = ListVector::GetData(args.data[0]); - auto value_data = ListVector::GetData(args.data[1]); - auto src_data = key_data; - - if (keys_are_const && !values_are_const) { - AlignVectorToReference(args.data[0], args.data[1], args.size(), key_vector); - src_data = value_data; - } else if (values_are_const && !keys_are_const) { - AlignVectorToReference(args.data[1], args.data[0], args.size(), value_vector); - } else { - if (!ListEntriesEqual(args.data[0], args.data[1], args.size())) { - throw InvalidInputException("Error in MAP creation: key list and value list do not align. i.e. different " - "size or incompatible structure"); - } - } - - ListVector::SetListSize(result, MaxValue(key_count, value_count)); - - result_data = ListVector::GetData(result); - for (idx_t i = 0; i < args.size(); i++) { - result_data[i] = src_data[i]; - } - - // check whether one of the vectors has already been referenced to an expanded vector in the case of const/non-const - // combination. 
If not, then referencing is still necessary - if (!(keys_are_const && !values_are_const)) { - key_vector.Reference(ListVector::GetEntry(args.data[0])); - } - if (!(values_are_const && !keys_are_const)) { - value_vector.Reference(ListVector::GetEntry(args.data[1])); - } - - MapVector::MapConversionVerify(result, args.size()); - result.Verify(args.size()); -} - -static unique_ptr MapBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - child_list_t child_types; - - if (arguments.size() != 2 && !arguments.empty()) { - throw Exception("We need exactly two lists for a map"); - } - if (arguments.size() == 2) { - if (arguments[0]->return_type.id() != LogicalTypeId::LIST) { - throw Exception("First argument is not a list"); - } - if (arguments[1]->return_type.id() != LogicalTypeId::LIST) { - throw Exception("Second argument is not a list"); - } - child_types.push_back(make_pair("key", arguments[0]->return_type)); - child_types.push_back(make_pair("value", arguments[1]->return_type)); - } - - if (arguments.empty()) { - auto empty = LogicalType::LIST(LogicalTypeId::SQLNULL); - child_types.push_back(make_pair("key", empty)); - child_types.push_back(make_pair("value", empty)); - } - - bound_function.return_type = - LogicalType::MAP(ListType::GetChildType(child_types[0].second), ListType::GetChildType(child_types[1].second)); - - return make_uniq(bound_function.return_type); -} - -ScalarFunction MapFun::GetFunction() { - //! 
the arguments and return types are actually set in the binder function - ScalarFunction fun({}, LogicalTypeId::MAP, MapFunction, MapBind); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -namespace { - -struct MapKeyIndexPair { - MapKeyIndexPair(idx_t map, idx_t key) : map_index(map), key_index(key) { - } - // The index of the map that this key comes from - idx_t map_index; - // The index within the maps key_list - idx_t key_index; -}; - -} // namespace - -vector GetListEntries(vector keys, vector values) { - D_ASSERT(keys.size() == values.size()); - vector entries; - for (idx_t i = 0; i < keys.size(); i++) { - child_list_t children; - children.emplace_back(make_pair("key", std::move(keys[i]))); - children.emplace_back(make_pair("value", std::move(values[i]))); - entries.push_back(Value::STRUCT(std::move(children))); - } - return entries; -} - -static void MapConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { - if (result.GetType().id() == LogicalTypeId::SQLNULL) { - // All inputs are NULL, just return NULL - auto &validity = FlatVector::Validity(result); - validity.SetInvalid(0); - result.SetVectorType(VectorType::CONSTANT_VECTOR); - return; - } - D_ASSERT(result.GetType().id() == LogicalTypeId::MAP); - auto count = args.size(); - - auto map_count = args.ColumnCount(); - vector map_formats(map_count); - for (idx_t i = 0; i < map_count; i++) { - auto &map = args.data[i]; - map.ToUnifiedFormat(count, map_formats[i]); - } - auto result_data = FlatVector::GetData(result); - - for (idx_t i = 0; i < count; i++) { - // Loop through all the maps per list - // we cant do better because all the entries of the child vector have to be contiguous - // so we cant start the next row before we have finished the one before it - auto &result_entry = result_data[i]; - vector index_to_map; - vector keys_list; - for (idx_t map_idx = 
0; map_idx < map_count; map_idx++) { - if (args.data[map_idx].GetType().id() == LogicalTypeId::SQLNULL) { - continue; - } - auto &map_format = map_formats[map_idx]; - auto &keys = MapVector::GetKeys(args.data[map_idx]); - - auto index = map_format.sel->get_index(i); - auto entry = UnifiedVectorFormat::GetData(map_format)[index]; - - // Update the list for this row - for (idx_t list_idx = 0; list_idx < entry.length; list_idx++) { - auto key_index = entry.offset + list_idx; - auto key = keys.GetValue(key_index); - auto entry = std::find(keys_list.begin(), keys_list.end(), key); - if (entry == keys_list.end()) { - // Result list does not contain this value yet - keys_list.push_back(key); - index_to_map.emplace_back(map_idx, key_index); - } else { - // Result list already contains this, update where to find the value at - auto distance = std::distance(keys_list.begin(), entry); - auto &mapping = *(index_to_map.begin() + distance); - mapping.key_index = key_index; - mapping.map_index = map_idx; - } - } - } - vector values_list; - D_ASSERT(keys_list.size() == index_to_map.size()); - // Get the values from the mapping - for (auto &mapping : index_to_map) { - auto &map = args.data[mapping.map_index]; - auto &values = MapVector::GetValues(map); - values_list.push_back(values.GetValue(mapping.key_index)); - } - D_ASSERT(values_list.size() == keys_list.size()); - result_entry.offset = ListVector::GetListSize(result); - result_entry.length = values_list.size(); - auto list_entries = GetListEntries(std::move(keys_list), std::move(values_list)); - for (auto &list_entry : list_entries) { - ListVector::PushBack(result, list_entry); - } - } - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - result.Verify(count); -} - -static bool IsEmptyMap(const LogicalType &map) { - D_ASSERT(map.id() == LogicalTypeId::MAP); - auto &key_type = MapType::KeyType(map); - auto &value_type = MapType::ValueType(map); - return key_type.id() == LogicalType::SQLNULL && 
value_type.id() == LogicalType::SQLNULL; -} - -static unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - auto arg_count = arguments.size(); - if (arg_count < 2) { - throw InvalidInputException("The provided amount of arguments is incorrect, please provide 2 or more maps"); - } - - if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - // Prepared statement - bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - - LogicalType expected = LogicalType::SQLNULL; - - bool is_null = true; - // Check and verify that all the maps are of the same type - for (idx_t i = 0; i < arg_count; i++) { - auto &arg = arguments[i]; - auto &map = arg->return_type; - if (map.id() == LogicalTypeId::UNKNOWN) { - // Prepared statement - bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - if (map.id() == LogicalTypeId::SQLNULL) { - // The maps are allowed to be NULL - continue; - } - if (map.id() != LogicalTypeId::MAP) { - throw InvalidInputException("MAP_CONCAT only takes map arguments"); - } - is_null = false; - if (IsEmptyMap(map)) { - // Map is allowed to be empty - continue; - } - - if (expected.id() == LogicalTypeId::SQLNULL) { - expected = map; - } else if (map != expected) { - throw InvalidInputException( - "'value' type of map differs between arguments, expected '%s', found '%s' instead", expected.ToString(), - map.ToString()); - } - } - - if (expected.id() == LogicalTypeId::SQLNULL && is_null == false) { - expected = LogicalType::MAP(LogicalType::SQLNULL, LogicalType::SQLNULL); - } - bound_function.return_type = expected; - return make_uniq(bound_function.return_type); -} - -ScalarFunction MapConcatFun::GetFunction() { - //! 
the arguments and return types are actually set in the binder function - ScalarFunction fun("map_concat", {}, LogicalTypeId::LIST, MapConcatFunction, MapConcatBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.varargs = LogicalType::ANY; - return fun; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -// Reverse of map_from_entries -static void MapEntriesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - idx_t count = args.size(); - - result.Reinterpret(args.data[0]); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - result.Verify(count); -} - -static unique_ptr MapEntriesBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - child_list_t child_types; - - if (arguments.size() != 1) { - throw InvalidInputException("Too many arguments provided, only expecting a single map"); - } - auto &map = arguments[0]->return_type; - - if (map.id() == LogicalTypeId::UNKNOWN) { - // Prepared statement - bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - - if (map.id() != LogicalTypeId::MAP) { - throw InvalidInputException("The provided argument is not a map"); - } - auto &key_type = MapType::KeyType(map); - auto &value_type = MapType::ValueType(map); - - child_types.push_back(make_pair("key", key_type)); - child_types.push_back(make_pair("value", value_type)); - - auto row_type = LogicalType::STRUCT(child_types); - - bound_function.return_type = LogicalType::LIST(row_type); - return make_uniq(bound_function.return_type); -} - -ScalarFunction MapEntriesFun::GetFunction() { - //! 
the arguments and return types are actually set in the binder function - ScalarFunction fun({}, LogicalTypeId::LIST, MapEntriesFunction, MapEntriesBind); - fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; - fun.varargs = LogicalType::ANY; - return fun; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -struct MapKeyArgFunctor { - // MAP is a LIST(STRUCT(K,V)) - // meaning the MAP itself is a List, but the child vector that we're interested in (the keys) - // are a level deeper than the initial child vector - - static Vector &GetList(Vector &map) { - return map; - } - static idx_t GetListSize(Vector &map) { - return ListVector::GetListSize(map); - } - static Vector &GetEntry(Vector &map) { - return MapVector::GetKeys(map); - } -}; - -void FillResult(Vector &map, Vector &offsets, Vector &result, idx_t count) { - UnifiedVectorFormat map_data; - map.ToUnifiedFormat(count, map_data); - - UnifiedVectorFormat offset_data; - offsets.ToUnifiedFormat(count, offset_data); - - auto result_data = FlatVector::GetData(result); - auto entry_count = ListVector::GetListSize(map); - auto &values_entries = MapVector::GetValues(map); - UnifiedVectorFormat values_entry_data; - // Note: this vector can have a different size than the map - values_entries.ToUnifiedFormat(entry_count, values_entry_data); - - for (idx_t row = 0; row < count; row++) { - idx_t offset_idx = offset_data.sel->get_index(row); - auto offset = UnifiedVectorFormat::GetData(offset_data)[offset_idx]; - - // Get the current size of the list, for the offset - idx_t current_offset = ListVector::GetListSize(result); - if (!offset_data.validity.RowIsValid(offset_idx) || !offset) { - // Set the entry data for this result row - auto &entry = result_data[row]; - entry.length = 0; - entry.offset = current_offset; - continue; - } - // All list indices start at 1, reduce by 1 to get the actual index - offset--; - - // Get the 'values' list entry corresponding to the offset - idx_t value_index = 
map_data.sel->get_index(row); - auto &value_list_entry = UnifiedVectorFormat::GetData(map_data)[value_index]; - - // Add the values to the result - idx_t list_offset = value_list_entry.offset + offset; - // All keys are unique, only one will ever match - idx_t length = 1; - ListVector::Append(result, values_entries, length + list_offset, list_offset); - - // Set the entry data for this result row - auto &entry = result_data[row]; - entry.length = length; - entry.offset = current_offset; - } -} - -static void MapExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.data.size() == 2); - D_ASSERT(args.data[0].GetType().id() == LogicalTypeId::MAP); - result.SetVectorType(VectorType::FLAT_VECTOR); - - idx_t tuple_count = args.size(); - // Optimization: because keys are not allowed to be NULL, we can early-out - if (args.data[1].GetType().id() == LogicalTypeId::SQLNULL) { - //! We don't need to look through the map if the 'key' to look for is NULL - ListVector::SetListSize(result, 0); - result.SetVectorType(VectorType::CONSTANT_VECTOR); - auto list_data = ConstantVector::GetData(result); - list_data->offset = 0; - list_data->length = 0; - result.Verify(tuple_count); - return; - } - - auto &map = args.data[0]; - auto &key = args.data[1]; - - UnifiedVectorFormat map_data; - - // Create the chunk we'll feed to ListPosition - DataChunk list_position_chunk; - vector chunk_types; - chunk_types.reserve(2); - chunk_types.push_back(map.GetType()); - chunk_types.push_back(key.GetType()); - list_position_chunk.InitializeEmpty(chunk_types.begin(), chunk_types.end()); - - // Populate it with the map keys list and the key vector - list_position_chunk.data[0].Reference(map); - list_position_chunk.data[1].Reference(key); - list_position_chunk.SetCardinality(tuple_count); - - Vector position_vector(LogicalType::LIST(LogicalType::INTEGER), tuple_count); - // We can pass around state as it's not used by ListPositionFunction anyways - 
ListContainsOrPosition(list_position_chunk, position_vector); - - FillResult(map, position_vector, result, tuple_count); - - if (tuple_count == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - - result.Verify(tuple_count); -} - -static unique_ptr MapExtractBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments.size() != 2) { - throw BinderException("MAP_EXTRACT must have exactly two arguments"); - } - if (arguments[0]->return_type.id() != LogicalTypeId::MAP) { - throw BinderException("MAP_EXTRACT can only operate on MAPs"); - } - auto &value_type = MapType::ValueType(arguments[0]->return_type); - - //! Here we have to construct the List Type that will be returned - bound_function.return_type = LogicalType::LIST(value_type); - auto key_type = MapType::KeyType(arguments[0]->return_type); - if (key_type.id() != LogicalTypeId::SQLNULL && arguments[1]->return_type.id() != LogicalTypeId::SQLNULL) { - bound_function.arguments[1] = MapType::KeyType(arguments[0]->return_type); - } - return make_uniq(value_type); -} - -ScalarFunction MapExtractFun::GetFunction() { - ScalarFunction fun({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, MapExtractFunction, MapExtractBind); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -static void MapFromEntriesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto count = args.size(); - - result.Reinterpret(args.data[0]); - - MapVector::MapConversionVerify(result, count); - result.Verify(count); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr MapFromEntriesBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments.size() != 1) { - throw InvalidInputException("The input argument must be a list of structs."); - } - auto &list = 
arguments[0]->return_type; - - if (list.id() == LogicalTypeId::UNKNOWN) { - bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - - if (list.id() != LogicalTypeId::LIST) { - throw InvalidInputException("The provided argument is not a list of structs"); - } - auto &elem_type = ListType::GetChildType(list); - if (elem_type.id() != LogicalTypeId::STRUCT) { - throw InvalidInputException("The elements of the list must be structs"); - } - auto &children = StructType::GetChildTypes(elem_type); - if (children.size() != 2) { - throw InvalidInputException("The provided struct type should only contain 2 fields, a key and a value"); - } - - bound_function.return_type = LogicalType::MAP(elem_type); - return make_uniq(bound_function.return_type); -} - -ScalarFunction MapFromEntriesFun::GetFunction() { - //! the arguments and return types are actually set in the binder function - ScalarFunction fun({}, LogicalTypeId::MAP, MapFromEntriesFunction, MapFromEntriesBind); - fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; - fun.varargs = LogicalType::ANY; - return fun; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -static void MapKeyValueFunction(DataChunk &args, ExpressionState &state, Vector &result, - Vector &(*get_child_vector)(Vector &)) { - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - auto count = args.size(); - - auto &map = args.data[0]; - D_ASSERT(map.GetType().id() == LogicalTypeId::MAP); - auto child = get_child_vector(map); - - auto &entries = ListVector::GetEntry(result); - entries.Reference(child); - - UnifiedVectorFormat map_data; - map.ToUnifiedFormat(count, map_data); - - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - FlatVector::SetData(result, map_data.data); - FlatVector::SetValidity(result, map_data.validity); - auto list_size = ListVector::GetListSize(map); - ListVector::SetListSize(result, 
list_size); - if (map.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - result.Slice(*map_data.sel, count); - } - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - result.Verify(count); -} - -static void MapKeysFunction(DataChunk &args, ExpressionState &state, Vector &result) { - MapKeyValueFunction(args, state, result, MapVector::GetKeys); -} - -static void MapValuesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - MapKeyValueFunction(args, state, result, MapVector::GetValues); -} - -static unique_ptr MapKeyValueBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments, - const LogicalType &(*type_func)(const LogicalType &)) { - if (arguments.size() != 1) { - throw InvalidInputException("Too many arguments provided, only expecting a single map"); - } - auto &map = arguments[0]->return_type; - - if (map.id() == LogicalTypeId::UNKNOWN) { - // Prepared statement - bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - - if (map.id() != LogicalTypeId::MAP) { - throw InvalidInputException("The provided argument is not a map"); - } - - auto &type = type_func(map); - - bound_function.return_type = LogicalType::LIST(type); - return make_uniq(bound_function.return_type); -} - -static unique_ptr MapKeysBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return MapKeyValueBind(context, bound_function, arguments, MapType::KeyType); -} - -static unique_ptr MapValuesBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return MapKeyValueBind(context, bound_function, arguments, MapType::ValueType); -} - -ScalarFunction MapKeysFun::GetFunction() { - //! 
the arguments and return types are actually set in the binder function - ScalarFunction fun({}, LogicalTypeId::LIST, MapKeysFunction, MapKeysBind); - fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; - fun.varargs = LogicalType::ANY; - return fun; -} - -ScalarFunction MapValuesFun::GetFunction() { - ScalarFunction fun({}, LogicalTypeId::LIST, MapValuesFunction, MapValuesBind); - fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; - fun.varargs = LogicalType::ANY; - return fun; -} - -} // namespace duckdb - - - - - - - - - - - -#include -#include - -namespace duckdb { - -template -static scalar_function_t GetScalarIntegerUnaryFunctionFixedReturn(const LogicalType &type) { - scalar_function_t function; - switch (type.id()) { - case LogicalTypeId::TINYINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::SMALLINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::INTEGER: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::BIGINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::HUGEINT: - function = &ScalarFunction::UnaryFunction; - break; - default: - throw NotImplementedException("Unimplemented type for GetScalarIntegerUnaryFunctionFixedReturn"); - } - return function; -} - -//===--------------------------------------------------------------------===// -// nextafter -//===--------------------------------------------------------------------===// -struct NextAfterOperator { - template - static inline TR Operation(TA base, TB exponent) { - throw NotImplementedException("Unimplemented type for NextAfter Function"); - } - - template - static inline double Operation(double input, double approximate_to) { - return nextafter(input, approximate_to); - } - template - static inline float Operation(float input, float approximate_to) { - return nextafterf(input, approximate_to); - } -}; - -ScalarFunctionSet NextAfterFun::GetFunctions() { - 
ScalarFunctionSet next_after_fun; - next_after_fun.AddFunction( - ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::BinaryFunction)); - next_after_fun.AddFunction(ScalarFunction({LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT, - ScalarFunction::BinaryFunction)); - return next_after_fun; -} - -//===--------------------------------------------------------------------===// -// abs -//===--------------------------------------------------------------------===// -static unique_ptr PropagateAbsStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() == 1); - // can only propagate stats if the children have stats - auto &lstats = child_stats[0]; - Value new_min, new_max; - bool potential_overflow = true; - if (NumericStats::HasMinMax(lstats)) { - switch (expr.return_type.InternalType()) { - case PhysicalType::INT8: - potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); - break; - case PhysicalType::INT16: - potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); - break; - case PhysicalType::INT32: - potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); - break; - case PhysicalType::INT64: - potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); - break; - default: - return nullptr; - } - } - if (potential_overflow) { - new_min = Value(expr.return_type); - new_max = Value(expr.return_type); - } else { - // no potential overflow - - // compute stats - auto current_min = NumericStats::Min(lstats).GetValue(); - auto current_max = NumericStats::Max(lstats).GetValue(); - - int64_t min_val, max_val; - - if (current_min < 0 && current_max < 0) { - // if both min and max are below zero, then min=abs(cur_max) and max=abs(cur_min) - min_val = AbsValue(current_max); - max_val = 
AbsValue(current_min); - } else if (current_min < 0) { - D_ASSERT(current_max >= 0); - // if min is below zero and max is above 0, then min=0 and max=max(cur_max, abs(cur_min)) - min_val = 0; - max_val = MaxValue(AbsValue(current_min), current_max); - } else { - // if both current_min and current_max are > 0, then the abs is a no-op and can be removed entirely - *input.expr_ptr = std::move(input.expr.children[0]); - return child_stats[0].ToUnique(); - } - new_min = Value::Numeric(expr.return_type, min_val); - new_max = Value::Numeric(expr.return_type, max_val); - expr.function.function = ScalarFunction::GetScalarUnaryFunction(expr.return_type); - } - auto stats = NumericStats::CreateEmpty(expr.return_type); - NumericStats::SetMin(stats, new_min); - NumericStats::SetMax(stats, new_max); - stats.CopyValidity(lstats); - return stats.ToUnique(); -} - -template -unique_ptr DecimalUnaryOpBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto decimal_type = arguments[0]->return_type; - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::SMALLINT); - break; - case PhysicalType::INT32: - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::INTEGER); - break; - case PhysicalType::INT64: - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::BIGINT); - break; - default: - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::HUGEINT); - break; - } - bound_function.arguments[0] = decimal_type; - bound_function.return_type = decimal_type; - return nullptr; -} - -ScalarFunctionSet AbsOperatorFun::GetFunctions() { - ScalarFunctionSet abs; - for (auto &type : LogicalType::Numeric()) { - switch (type.id()) { - case LogicalTypeId::DECIMAL: - abs.AddFunction(ScalarFunction({type}, type, nullptr, DecimalUnaryOpBind)); - break; - case LogicalTypeId::TINYINT: - case 
LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: { - ScalarFunction func({type}, type, ScalarFunction::GetScalarUnaryFunction(type)); - func.statistics = PropagateAbsStats; - abs.AddFunction(func); - break; - } - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - abs.AddFunction(ScalarFunction({type}, type, ScalarFunction::NopFunction)); - break; - default: - abs.AddFunction(ScalarFunction({type}, type, ScalarFunction::GetScalarUnaryFunction(type))); - break; - } - } - return abs; -} - -//===--------------------------------------------------------------------===// -// bit_count -//===--------------------------------------------------------------------===// -struct BitCntOperator { - template - static inline TR Operation(TA input) { - using TU = typename std::make_unsigned::type; - TR count = 0; - for (auto value = TU(input); value; ++count) { - value &= (value - 1); - } - return count; - } -}; - -struct HugeIntBitCntOperator { - template - static inline TR Operation(TA input) { - using TU = typename std::make_unsigned::type; - TR count = 0; - - for (auto value = TU(input.upper); value; ++count) { - value &= (value - 1); - } - for (auto value = TU(input.lower); value; ++count) { - value &= (value - 1); - } - return count; - } -}; - -struct BitStringBitCntOperator { - template - static inline TR Operation(TA input) { - TR count = Bit::BitCount(input); - return count; - } -}; - -ScalarFunctionSet BitCountFun::GetFunctions() { - ScalarFunctionSet functions; - functions.AddFunction(ScalarFunction({LogicalType::TINYINT}, LogicalType::TINYINT, - ScalarFunction::UnaryFunction)); - functions.AddFunction(ScalarFunction({LogicalType::SMALLINT}, LogicalType::TINYINT, - ScalarFunction::UnaryFunction)); - functions.AddFunction(ScalarFunction({LogicalType::INTEGER}, LogicalType::TINYINT, - ScalarFunction::UnaryFunction)); - 
functions.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::TINYINT, - ScalarFunction::UnaryFunction)); - functions.AddFunction(ScalarFunction({LogicalType::HUGEINT}, LogicalType::TINYINT, - ScalarFunction::UnaryFunction)); - functions.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction)); - return functions; -} - -//===--------------------------------------------------------------------===// -// sign -//===--------------------------------------------------------------------===// -struct SignOperator { - template - static TR Operation(TA input) { - if (input == TA(0)) { - return 0; - } else if (input > TA(0)) { - return 1; - } else { - return -1; - } - } -}; - -template <> -int8_t SignOperator::Operation(float input) { - if (input == 0 || Value::IsNan(input)) { - return 0; - } else if (input > 0) { - return 1; - } else { - return -1; - } -} - -template <> -int8_t SignOperator::Operation(double input) { - if (input == 0 || Value::IsNan(input)) { - return 0; - } else if (input > 0) { - return 1; - } else { - return -1; - } -} - -ScalarFunctionSet SignFun::GetFunctions() { - ScalarFunctionSet sign; - for (auto &type : LogicalType::Numeric()) { - if (type.id() == LogicalTypeId::DECIMAL) { - continue; - } else { - sign.AddFunction( - ScalarFunction({type}, LogicalType::TINYINT, - ScalarFunction::GetScalarUnaryFunctionFixedReturn(type))); - } - } - return sign; -} - -//===--------------------------------------------------------------------===// -// ceil -//===--------------------------------------------------------------------===// -struct CeilOperator { - template - static inline TR Operation(TA left) { - return std::ceil(left); - } -}; - -template -static void GenericRoundFunctionDecimal(DataChunk &input, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - OP::template Operation(input, DecimalType::GetScale(func_expr.children[0]->return_type), result); -} - -template -unique_ptr 
BindGenericRoundFunctionDecimal(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // ceil essentially removes the scale - auto &decimal_type = arguments[0]->return_type; - auto scale = DecimalType::GetScale(decimal_type); - auto width = DecimalType::GetWidth(decimal_type); - if (scale == 0) { - bound_function.function = ScalarFunction::NopFunction; - } else { - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - bound_function.function = GenericRoundFunctionDecimal; - break; - case PhysicalType::INT32: - bound_function.function = GenericRoundFunctionDecimal; - break; - case PhysicalType::INT64: - bound_function.function = GenericRoundFunctionDecimal; - break; - default: - bound_function.function = GenericRoundFunctionDecimal; - break; - } - } - bound_function.arguments[0] = decimal_type; - bound_function.return_type = LogicalType::DECIMAL(width, 0); - return nullptr; -} - -struct CeilDecimalOperator { - template - static void Operation(DataChunk &input, uint8_t scale, Vector &result) { - T power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]; - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { - if (input < 0) { - // below 0 we floor the number (e.g. 
-10.5 -> -10) - return input / power_of_ten; - } else { - // above 0 we ceil the number - return ((input - 1) / power_of_ten) + 1; - } - }); - } -}; - -ScalarFunctionSet CeilFun::GetFunctions() { - ScalarFunctionSet ceil; - for (auto &type : LogicalType::Numeric()) { - scalar_function_t func = nullptr; - bind_scalar_function_t bind_func = nullptr; - if (type.IsIntegral()) { - // no ceil for integral numbers - continue; - } - switch (type.id()) { - case LogicalTypeId::FLOAT: - func = ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::DOUBLE: - func = ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::DECIMAL: - bind_func = BindGenericRoundFunctionDecimal; - break; - default: - throw InternalException("Unimplemented numeric type for function \"ceil\""); - } - ceil.AddFunction(ScalarFunction({type}, type, func, bind_func)); - } - return ceil; -} - -//===--------------------------------------------------------------------===// -// floor -//===--------------------------------------------------------------------===// -struct FloorOperator { - template - static inline TR Operation(TA left) { - return std::floor(left); - } -}; - -struct FloorDecimalOperator { - template - static void Operation(DataChunk &input, uint8_t scale, Vector &result) { - T power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]; - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { - if (input < 0) { - // below 0 we ceil the number (e.g. 
-10.5 -> -11) - return ((input + 1) / power_of_ten) - 1; - } else { - // above 0 we floor the number - return input / power_of_ten; - } - }); - } -}; - -ScalarFunctionSet FloorFun::GetFunctions() { - ScalarFunctionSet floor; - for (auto &type : LogicalType::Numeric()) { - scalar_function_t func = nullptr; - bind_scalar_function_t bind_func = nullptr; - if (type.IsIntegral()) { - // no floor for integral numbers - continue; - } - switch (type.id()) { - case LogicalTypeId::FLOAT: - func = ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::DOUBLE: - func = ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::DECIMAL: - bind_func = BindGenericRoundFunctionDecimal; - break; - default: - throw InternalException("Unimplemented numeric type for function \"floor\""); - } - floor.AddFunction(ScalarFunction({type}, type, func, bind_func)); - } - return floor; -} - -//===--------------------------------------------------------------------===// -// trunc -//===--------------------------------------------------------------------===// -struct TruncOperator { - // Integer truncation is a NOP - template - static inline TR Operation(TA left) { - return std::trunc(left); - } -}; - -struct TruncDecimalOperator { - template - static void Operation(DataChunk &input, uint8_t scale, Vector &result) { - T power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]; - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { - // Always floor - return (input / power_of_ten); - }); - } -}; - -ScalarFunctionSet TruncFun::GetFunctions() { - ScalarFunctionSet trunc; - for (auto &type : LogicalType::Numeric()) { - scalar_function_t func = nullptr; - bind_scalar_function_t bind_func = nullptr; - // Truncation of integers gets generated by some tools (e.g., Tableau/JDBC:Postgres) - switch (type.id()) { - case LogicalTypeId::FLOAT: - func = ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::DOUBLE: - func = ScalarFunction::UnaryFunction; - break; - case 
LogicalTypeId::DECIMAL: - bind_func = BindGenericRoundFunctionDecimal; - break; - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - func = ScalarFunction::NopFunction; - break; - default: - throw InternalException("Unimplemented numeric type for function \"trunc\""); - } - trunc.AddFunction(ScalarFunction({type}, type, func, bind_func)); - } - return trunc; -} - -//===--------------------------------------------------------------------===// -// round -//===--------------------------------------------------------------------===// -struct RoundOperatorPrecision { - template - static inline TR Operation(TA input, TB precision) { - double rounded_value; - if (precision < 0) { - double modifier = std::pow(10, -TA(precision)); - rounded_value = (std::round(input / modifier)) * modifier; - if (std::isinf(rounded_value) || std::isnan(rounded_value)) { - return 0; - } - } else { - double modifier = std::pow(10, TA(precision)); - rounded_value = (std::round(input * modifier)) / modifier; - if (std::isinf(rounded_value) || std::isnan(rounded_value)) { - return input; - } - } - return rounded_value; - } -}; - -struct RoundOperator { - template - static inline TR Operation(TA input) { - double rounded_value = round(input); - if (std::isinf(rounded_value) || std::isnan(rounded_value)) { - return input; - } - return rounded_value; - } -}; - -struct RoundDecimalOperator { - template - static void Operation(DataChunk &input, uint8_t scale, Vector &result) { - T power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]; - T addition = power_of_ten / 2; - // regular round rounds towards the nearest number - // in case of a tie we round away from zero - // i.e. 
-10.5 -> -11, 10.5 -> 11 - // we implement this by adding (positive) or subtracting (negative) 0.5 - // and then flooring the number - // e.g. 10.5 + 0.5 = 11, floor(11) = 11 - // 10.4 + 0.5 = 10.9, floor(10.9) = 10 - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { - if (input < 0) { - input -= addition; - } else { - input += addition; - } - return input / power_of_ten; - }); - } -}; - -struct RoundPrecisionFunctionData : public FunctionData { - explicit RoundPrecisionFunctionData(int32_t target_scale) : target_scale(target_scale) { - } - - int32_t target_scale; - - unique_ptr Copy() const override { - return make_uniq(target_scale); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return target_scale == other.target_scale; - } -}; - -template -static void DecimalRoundNegativePrecisionFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto source_scale = DecimalType::GetScale(func_expr.children[0]->return_type); - auto width = DecimalType::GetWidth(func_expr.children[0]->return_type); - if (info.target_scale <= -int32_t(width)) { - // scale too big for width - result.SetVectorType(VectorType::CONSTANT_VECTOR); - result.SetValue(0, Value::INTEGER(0)); - return; - } - T divide_power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale + source_scale]; - T multiply_power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale]; - T addition = divide_power_of_ten / 2; - - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { - if (input < 0) { - input -= addition; - } else { - input += addition; - } - return input / divide_power_of_ten * multiply_power_of_ten; - }); -} - -template -static void DecimalRoundPositivePrecisionFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = 
func_expr.bind_info->Cast(); - auto source_scale = DecimalType::GetScale(func_expr.children[0]->return_type); - T power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[source_scale - info.target_scale]; - T addition = power_of_ten / 2; - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { - if (input < 0) { - input -= addition; - } else { - input += addition; - } - return input / power_of_ten; - }); -} - -unique_ptr BindDecimalRoundPrecision(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto &decimal_type = arguments[0]->return_type; - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[1]->IsFoldable()) { - throw NotImplementedException("ROUND(DECIMAL, INTEGER) with non-constant precision is not supported"); - } - Value val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]).DefaultCastAs(LogicalType::INTEGER); - if (val.IsNull()) { - throw NotImplementedException("ROUND(DECIMAL, INTEGER) with non-constant precision is not supported"); - } - // our new precision becomes the round value - // e.g. ROUND(DECIMAL(18,3), 1) -> DECIMAL(18,1) - // but ONLY if the round value is positive - // if it is negative the scale becomes zero - // i.e. 
ROUND(DECIMAL(18,3), -1) -> DECIMAL(18,0) - int32_t round_value = IntegerValue::Get(val); - uint8_t target_scale; - auto width = DecimalType::GetWidth(decimal_type); - auto scale = DecimalType::GetScale(decimal_type); - if (round_value < 0) { - target_scale = 0; - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - bound_function.function = DecimalRoundNegativePrecisionFunction; - break; - case PhysicalType::INT32: - bound_function.function = DecimalRoundNegativePrecisionFunction; - break; - case PhysicalType::INT64: - bound_function.function = DecimalRoundNegativePrecisionFunction; - break; - default: - bound_function.function = DecimalRoundNegativePrecisionFunction; - break; - } - } else { - if (round_value >= (int32_t)scale) { - // if round_value is bigger than or equal to scale we do nothing - bound_function.function = ScalarFunction::NopFunction; - target_scale = scale; - } else { - target_scale = round_value; - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - bound_function.function = DecimalRoundPositivePrecisionFunction; - break; - case PhysicalType::INT32: - bound_function.function = DecimalRoundPositivePrecisionFunction; - break; - case PhysicalType::INT64: - bound_function.function = DecimalRoundPositivePrecisionFunction; - break; - default: - bound_function.function = DecimalRoundPositivePrecisionFunction; - break; - } - } - } - bound_function.arguments[0] = decimal_type; - bound_function.return_type = LogicalType::DECIMAL(width, target_scale); - return make_uniq(round_value); -} - -ScalarFunctionSet RoundFun::GetFunctions() { - ScalarFunctionSet round; - for (auto &type : LogicalType::Numeric()) { - scalar_function_t round_prec_func = nullptr; - scalar_function_t round_func = nullptr; - bind_scalar_function_t bind_func = nullptr; - bind_scalar_function_t bind_prec_func = nullptr; - if (type.IsIntegral()) { - // no round for integral numbers - continue; - } - switch (type.id()) { - case LogicalTypeId::FLOAT: - 
round_func = ScalarFunction::UnaryFunction; - round_prec_func = ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::DOUBLE: - round_func = ScalarFunction::UnaryFunction; - round_prec_func = ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::DECIMAL: - bind_func = BindGenericRoundFunctionDecimal; - bind_prec_func = BindDecimalRoundPrecision; - break; - default: - throw InternalException("Unimplemented numeric type for function \"floor\""); - } - round.AddFunction(ScalarFunction({type}, type, round_func, bind_func)); - round.AddFunction(ScalarFunction({type, LogicalType::INTEGER}, type, round_prec_func, bind_prec_func)); - } - return round; -} - -//===--------------------------------------------------------------------===// -// exp -//===--------------------------------------------------------------------===// -struct ExpOperator { - template - static inline TR Operation(TA left) { - return std::exp(left); - } -}; - -ScalarFunction ExpFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// pow -//===--------------------------------------------------------------------===// -struct PowOperator { - template - static inline TR Operation(TA base, TB exponent) { - return std::pow(base, exponent); - } -}; - -ScalarFunction PowOperatorFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::BinaryFunction); -} - -//===--------------------------------------------------------------------===// -// sqrt -//===--------------------------------------------------------------------===// -struct SqrtOperator { - template - static inline TR Operation(TA input) { - if (input < 0) { - throw OutOfRangeException("cannot take square root of a negative number"); - } - return std::sqrt(input); - } -}; - -ScalarFunction SqrtFun::GetFunction() { - 
return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// cbrt -//===--------------------------------------------------------------------===// -struct CbRtOperator { - template - static inline TR Operation(TA left) { - return std::cbrt(left); - } -}; - -ScalarFunction CbrtFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// ln -//===--------------------------------------------------------------------===// - -struct LnOperator { - template - static inline TR Operation(TA input) { - if (input < 0) { - throw OutOfRangeException("cannot take logarithm of a negative number"); - } - if (input == 0) { - throw OutOfRangeException("cannot take logarithm of zero"); - } - return std::log(input); - } -}; - -ScalarFunction LnFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// log -//===--------------------------------------------------------------------===// -struct Log10Operator { - template - static inline TR Operation(TA input) { - if (input < 0) { - throw OutOfRangeException("cannot take logarithm of a negative number"); - } - if (input == 0) { - throw OutOfRangeException("cannot take logarithm of zero"); - } - return std::log10(input); - } -}; - -ScalarFunction Log10Fun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// log2 -//===--------------------------------------------------------------------===// -struct Log2Operator { - template - static inline TR Operation(TA input) { 
- if (input < 0) { - throw OutOfRangeException("cannot take logarithm of a negative number"); - } - if (input == 0) { - throw OutOfRangeException("cannot take logarithm of zero"); - } - return std::log2(input); - } -}; - -ScalarFunction Log2Fun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// pi -//===--------------------------------------------------------------------===// -static void PiFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 0); - Value pi_value = Value::DOUBLE(PI); - result.Reference(pi_value); -} - -ScalarFunction PiFun::GetFunction() { - return ScalarFunction({}, LogicalType::DOUBLE, PiFunction); -} - -//===--------------------------------------------------------------------===// -// degrees -//===--------------------------------------------------------------------===// -struct DegreesOperator { - template - static inline TR Operation(TA left) { - return left * (180 / PI); - } -}; - -ScalarFunction DegreesFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// radians -//===--------------------------------------------------------------------===// -struct RadiansOperator { - template - static inline TR Operation(TA left) { - return left * (PI / 180); - } -}; - -ScalarFunction RadiansFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// isnan -//===--------------------------------------------------------------------===// -struct IsNanOperator { - template - static inline TR Operation(TA input) { - return Value::IsNan(input); - } -}; - 
-ScalarFunctionSet IsNanFun::GetFunctions() { - ScalarFunctionSet funcs; - funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - return funcs; -} - -//===--------------------------------------------------------------------===// -// signbit -//===--------------------------------------------------------------------===// -struct SignBitOperator { - template - static inline TR Operation(TA input) { - return std::signbit(input); - } -}; - -ScalarFunctionSet SignBitFun::GetFunctions() { - ScalarFunctionSet funcs; - funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - return funcs; -} - -//===--------------------------------------------------------------------===// -// isinf -//===--------------------------------------------------------------------===// -struct IsInfiniteOperator { - template - static inline TR Operation(TA input) { - return !Value::IsNan(input) && !Value::IsFinite(input); - } -}; - -template <> -bool IsInfiniteOperator::Operation(date_t input) { - return !Value::IsFinite(input); -} - -template <> -bool IsInfiniteOperator::Operation(timestamp_t input) { - return !Value::IsFinite(input); -} - -ScalarFunctionSet IsInfiniteFun::GetFunctions() { - ScalarFunctionSet funcs("isinf"); - funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BOOLEAN, - 
ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - return funcs; -} - -//===--------------------------------------------------------------------===// -// isfinite -//===--------------------------------------------------------------------===// -struct IsFiniteOperator { - template - static inline TR Operation(TA input) { - return Value::IsFinite(input); - } -}; - -ScalarFunctionSet IsFiniteFun::GetFunctions() { - ScalarFunctionSet funcs; - funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - return funcs; -} - -//===--------------------------------------------------------------------===// -// sin -//===--------------------------------------------------------------------===// -template -struct NoInfiniteDoubleWrapper { - template - static RESULT_TYPE Operation(INPUT_TYPE input) { - if (DUCKDB_UNLIKELY(!Value::IsFinite(input))) { - if (Value::IsNan(input)) { - return input; - } - throw OutOfRangeException("input value %lf is out of range for numeric function", input); - } - return OP::template Operation(input); - } -}; - -struct SinOperator { - template - static inline TR Operation(TA input) { - return std::sin(input); - } -}; - -ScalarFunction SinFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction>); -} - 
-//===--------------------------------------------------------------------===// -// cos -//===--------------------------------------------------------------------===// -struct CosOperator { - template - static inline TR Operation(TA input) { - return (double)std::cos(input); - } -}; - -ScalarFunction CosFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction>); -} - -//===--------------------------------------------------------------------===// -// tan -//===--------------------------------------------------------------------===// -struct TanOperator { - template - static inline TR Operation(TA input) { - return (double)std::tan(input); - } -}; - -ScalarFunction TanFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction>); -} - -//===--------------------------------------------------------------------===// -// asin -//===--------------------------------------------------------------------===// -struct ASinOperator { - template - static inline TR Operation(TA input) { - if (input < -1 || input > 1) { - throw Exception("ASIN is undefined outside [-1,1]"); - } - return (double)std::asin(input); - } -}; - -ScalarFunction AsinFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction>); -} - -//===--------------------------------------------------------------------===// -// atan -//===--------------------------------------------------------------------===// -struct ATanOperator { - template - static inline TR Operation(TA input) { - return (double)std::atan(input); - } -}; - -ScalarFunction AtanFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// atan2 
-//===--------------------------------------------------------------------===// -struct ATan2 { - template - static inline TR Operation(TA left, TB right) { - return (double)std::atan2(left, right); - } -}; - -ScalarFunction Atan2Fun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::BinaryFunction); -} - -//===--------------------------------------------------------------------===// -// acos -//===--------------------------------------------------------------------===// -struct ACos { - template - static inline TR Operation(TA input) { - return (double)std::acos(input); - } -}; - -ScalarFunction AcosFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction>); -} - -//===--------------------------------------------------------------------===// -// cot -//===--------------------------------------------------------------------===// -struct CotOperator { - template - static inline TR Operation(TA input) { - return 1.0 / (double)std::tan(input); - } -}; - -ScalarFunction CotFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction>); -} - -//===--------------------------------------------------------------------===// -// gamma -//===--------------------------------------------------------------------===// -struct GammaOperator { - template - static inline TR Operation(TA input) { - if (input == 0) { - throw OutOfRangeException("cannot take gamma of zero"); - } - return std::tgamma(input); - } -}; - -ScalarFunction GammaFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// gamma -//===--------------------------------------------------------------------===// -struct LogGammaOperator { - template - static inline TR 
Operation(TA input) { - if (input == 0) { - throw OutOfRangeException("cannot take log gamma of zero"); - } - return std::lgamma(input); - } -}; - -ScalarFunction LogGammaFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// factorial(), ! -//===--------------------------------------------------------------------===// -struct FactorialOperator { - template - static inline TR Operation(TA left) { - TR ret = 1; - for (TA i = 2; i <= left; i++) { - ret *= i; - } - return ret; - } -}; - -ScalarFunction FactorialOperatorFun::GetFunction() { - return ScalarFunction({LogicalType::INTEGER}, LogicalType::HUGEINT, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// even -//===--------------------------------------------------------------------===// -struct EvenOperator { - template - static inline TR Operation(TA left) { - double value; - if (left >= 0) { - value = std::ceil(left); - } else { - value = std::ceil(-left); - value = -value; - } - if (std::floor(value / 2) * 2 != value) { - if (left >= 0) { - return value += 1; - } - return value -= 1; - } - return value; - } -}; - -ScalarFunction EvenFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// gcd -//===--------------------------------------------------------------------===// - -// should be replaced with std::gcd in a newer C++ standard -template -TA GreatestCommonDivisor(TA left, TA right) { - TA a = left; - TA b = right; - - // This protects the following modulo operations from a corner case, - // where we would get a runtime error due to an integer overflow. 
- if ((left == NumericLimits::Minimum() && right == -1) || - (left == -1 && right == NumericLimits::Minimum())) { - return 1; - } - - while (true) { - if (a == 0) { - return TryAbsOperator::Operation(b); - } - b %= a; - - if (b == 0) { - return TryAbsOperator::Operation(a); - } - a %= b; - } -} - -struct GreatestCommonDivisorOperator { - template - static inline TR Operation(TA left, TB right) { - return GreatestCommonDivisor(left, right); - } -}; - -ScalarFunctionSet GreatestCommonDivisorFun::GetFunctions() { - ScalarFunctionSet funcs; - funcs.AddFunction( - ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT, - ScalarFunction::BinaryFunction)); - funcs.AddFunction( - ScalarFunction({LogicalType::HUGEINT, LogicalType::HUGEINT}, LogicalType::HUGEINT, - ScalarFunction::BinaryFunction)); - return funcs; -} - -//===--------------------------------------------------------------------===// -// lcm -//===--------------------------------------------------------------------===// - -// should be replaced with std::lcm in a newer C++ standard -struct LeastCommonMultipleOperator { - template - static inline TR Operation(TA left, TB right) { - if (left == 0 || right == 0) { - return 0; - } - TR result; - if (!TryMultiplyOperator::Operation(left, right / GreatestCommonDivisor(left, right), result)) { - throw OutOfRangeException("lcm value is out of range"); - } - return TryAbsOperator::Operation(result); - } -}; - -ScalarFunctionSet LeastCommonMultipleFun::GetFunctions() { - ScalarFunctionSet funcs; - - funcs.AddFunction( - ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT, - ScalarFunction::BinaryFunction)); - funcs.AddFunction( - ScalarFunction({LogicalType::HUGEINT, LogicalType::HUGEINT}, LogicalType::HUGEINT, - ScalarFunction::BinaryFunction)); - return funcs; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -template -static scalar_function_t GetScalarIntegerUnaryFunction(const LogicalType &type) { - 
scalar_function_t function; - switch (type.id()) { - case LogicalTypeId::TINYINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::SMALLINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::INTEGER: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::BIGINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::UTINYINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::USMALLINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::UINTEGER: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::UBIGINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::HUGEINT: - function = &ScalarFunction::UnaryFunction; - break; - default: - throw NotImplementedException("Unimplemented type for GetScalarIntegerUnaryFunction"); - } - return function; -} - -template -static scalar_function_t GetScalarIntegerBinaryFunction(const LogicalType &type) { - scalar_function_t function; - switch (type.id()) { - case LogicalTypeId::TINYINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::SMALLINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::INTEGER: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::BIGINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::UTINYINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::USMALLINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::UINTEGER: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::UBIGINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::HUGEINT: - function = &ScalarFunction::BinaryFunction; - break; - default: - throw NotImplementedException("Unimplemented type for GetScalarIntegerBinaryFunction"); - } - return function; -} 
- -//===--------------------------------------------------------------------===// -// & [bitwise_and] -//===--------------------------------------------------------------------===// -struct BitwiseANDOperator { - template - static inline TR Operation(TA left, TB right) { - return left & right; - } -}; - -static void BitwiseANDOperation(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t rhs, string_t lhs) { - string_t target = StringVector::EmptyString(result, rhs.GetSize()); - - Bit::BitwiseAnd(rhs, lhs, target); - return target; - }); -} - -ScalarFunctionSet BitwiseAndFun::GetFunctions() { - ScalarFunctionSet functions; - for (auto &type : LogicalType::Integral()) { - functions.AddFunction( - ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); - } - functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseANDOperation)); - return functions; -} - -//===--------------------------------------------------------------------===// -// | [bitwise_or] -//===--------------------------------------------------------------------===// -struct BitwiseOROperator { - template - static inline TR Operation(TA left, TB right) { - return left | right; - } -}; - -static void BitwiseOROperation(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t rhs, string_t lhs) { - string_t target = StringVector::EmptyString(result, rhs.GetSize()); - - Bit::BitwiseOr(rhs, lhs, target); - return target; - }); -} - -ScalarFunctionSet BitwiseOrFun::GetFunctions() { - ScalarFunctionSet functions; - for (auto &type : LogicalType::Integral()) { - functions.AddFunction( - ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); - } - functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, 
BitwiseOROperation)); - return functions; -} - -//===--------------------------------------------------------------------===// -// # [bitwise_xor] -//===--------------------------------------------------------------------===// -struct BitwiseXOROperator { - template - static inline TR Operation(TA left, TB right) { - return left ^ right; - } -}; - -static void BitwiseXOROperation(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t rhs, string_t lhs) { - string_t target = StringVector::EmptyString(result, rhs.GetSize()); - - Bit::BitwiseXor(rhs, lhs, target); - return target; - }); -} - -ScalarFunctionSet BitwiseXorFun::GetFunctions() { - ScalarFunctionSet functions; - for (auto &type : LogicalType::Integral()) { - functions.AddFunction( - ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); - } - functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseXOROperation)); - return functions; -} - -//===--------------------------------------------------------------------===// -// ~ [bitwise_not] -//===--------------------------------------------------------------------===// -struct BitwiseNotOperator { - template - static inline TR Operation(TA input) { - return ~input; - } -}; - -static void BitwiseNOTOperation(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { - string_t target = StringVector::EmptyString(result, input.GetSize()); - - Bit::BitwiseNot(input, target); - return target; - }); -} - -ScalarFunctionSet BitwiseNotFun::GetFunctions() { - ScalarFunctionSet functions; - for (auto &type : LogicalType::Integral()) { - functions.AddFunction(ScalarFunction({type}, type, GetScalarIntegerUnaryFunction(type))); - } - functions.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIT, BitwiseNOTOperation)); - 
return functions; -} - -//===--------------------------------------------------------------------===// -// << [bitwise_left_shift] -//===--------------------------------------------------------------------===// - -struct BitwiseShiftLeftOperator { - template - static inline TR Operation(TA input, TB shift) { - TA max_shift = TA(sizeof(TA) * 8); - if (input < 0) { - throw OutOfRangeException("Cannot left-shift negative number %s", NumericHelper::ToString(input)); - } - if (shift < 0) { - throw OutOfRangeException("Cannot left-shift by negative number %s", NumericHelper::ToString(shift)); - } - if (shift >= max_shift) { - if (input == 0) { - return 0; - } - throw OutOfRangeException("Left-shift value %s is out of range", NumericHelper::ToString(shift)); - } - if (shift == 0) { - return input; - } - TA max_value = (TA(1) << (max_shift - shift - 1)); - if (input >= max_value) { - throw OutOfRangeException("Overflow in left shift (%s << %s)", NumericHelper::ToString(input), - NumericHelper::ToString(shift)); - } - return input << shift; - } -}; - -static void BitwiseShiftLeftOperation(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t shift) { - int32_t max_shift = Bit::BitLength(input); - if (shift == 0) { - return input; - } - if (shift < 0) { - throw OutOfRangeException("Cannot left-shift by negative number %s", NumericHelper::ToString(shift)); - } - string_t target = StringVector::EmptyString(result, input.GetSize()); - - if (shift >= max_shift) { - Bit::SetEmptyBitString(target, input); - return target; - } - Bit::LeftShift(input, shift, target); - return target; - }); -} - -ScalarFunctionSet LeftShiftFun::GetFunctions() { - ScalarFunctionSet functions; - for (auto &type : LogicalType::Integral()) { - functions.AddFunction( - ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); - } - functions.AddFunction( - 
ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitwiseShiftLeftOperation)); - return functions; -} - -//===--------------------------------------------------------------------===// -// >> [bitwise_right_shift] -//===--------------------------------------------------------------------===// -template -bool RightShiftInRange(T shift) { - return shift >= 0 && shift < T(sizeof(T) * 8); -} - -struct BitwiseShiftRightOperator { - template - static inline TR Operation(TA input, TB shift) { - return RightShiftInRange(shift) ? input >> shift : 0; - } -}; - -static void BitwiseShiftRightOperation(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t shift) { - int32_t max_shift = Bit::BitLength(input); - if (shift == 0) { - return input; - } - string_t target = StringVector::EmptyString(result, input.GetSize()); - if (shift < 0 || shift >= max_shift) { - Bit::SetEmptyBitString(target, input); - return target; - } - Bit::RightShift(input, shift, target); - return target; - }); -} - -ScalarFunctionSet RightShiftFun::GetFunctions() { - ScalarFunctionSet functions; - for (auto &type : LogicalType::Integral()) { - functions.AddFunction( - ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); - } - functions.AddFunction( - ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitwiseShiftRightOperation)); - return functions; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -struct RandomLocalState : public FunctionLocalState { - explicit RandomLocalState(uint32_t seed) : random_engine(seed) { - } - - RandomEngine random_engine; -}; - -static void RandomFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 0); - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto 
result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < args.size(); i++) { - result_data[i] = lstate.random_engine.NextRandom(); - } -} - -static unique_ptr RandomInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, - FunctionData *bind_data) { - auto &random_engine = RandomEngine::Get(state.GetContext()); - lock_guard guard(random_engine.lock); - return make_uniq(random_engine.NextRandomInteger()); -} - -ScalarFunction RandomFun::GetFunction() { - ScalarFunction random("random", {}, LogicalType::DOUBLE, RandomFunction, nullptr, nullptr, nullptr, - RandomInitLocalState); - random.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; - return random; -} - -static void GenerateUUIDFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 0); - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - - for (idx_t i = 0; i < args.size(); i++) { - result_data[i] = UUID::GenerateRandomUUID(lstate.random_engine); - } -} - -ScalarFunction UUIDFun::GetFunction() { - ScalarFunction uuid_function({}, LogicalType::UUID, GenerateUUIDFunction, nullptr, nullptr, nullptr, - RandomInitLocalState); - // generate a random uuid - uuid_function.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; - return uuid_function; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -struct SetseedBindData : public FunctionData { - //! 
The client context for the function call - ClientContext &context; - - explicit SetseedBindData(ClientContext &context) : context(context) { - } - - unique_ptr Copy() const override { - return make_uniq(context); - } - - bool Equals(const FunctionData &other_p) const override { - return true; - } -}; - -static void SetSeedFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto &input = args.data[0]; - input.Flatten(args.size()); - - auto input_seeds = FlatVector::GetData(input); - uint32_t half_max = NumericLimits::Maximum() / 2; - - auto &random_engine = RandomEngine::Get(info.context); - for (idx_t i = 0; i < args.size(); i++) { - if (input_seeds[i] < -1.0 || input_seeds[i] > 1.0 || Value::IsNan(input_seeds[i])) { - throw Exception("SETSEED accepts seed values between -1.0 and 1.0, inclusive"); - } - uint32_t norm_seed = (input_seeds[i] + 1.0) * half_max; - random_engine.SetSeed(norm_seed); - } - - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); -} - -unique_ptr SetSeedBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return make_uniq(context); -} - -ScalarFunction SetseedFun::GetFunction() { - ScalarFunction setseed("setseed", {LogicalType::DOUBLE}, LogicalType::SQLNULL, SetSeedFunction, SetSeedBind); - setseed.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; - return setseed; -} - -} // namespace duckdb - - - - -namespace duckdb { - -struct AsciiOperator { - template - static inline TR Operation(const TA &input) { - auto str = input.GetData(); - if (Utf8Proc::Analyze(str, input.GetSize()) == UnicodeType::ASCII) { - return str[0]; - } - int utf8_bytes = 4; - return Utf8Proc::UTF8ToCodepoint(str, utf8_bytes); - } -}; - -ScalarFunction ASCIIFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::INTEGER, - ScalarFunction::UnaryFunction); -} - -} // 
namespace duckdb - - - - - - - - - -namespace duckdb { - -static string_t BarScalarFunction(double x, double min, double max, double max_width, string &result) { - static const char *FULL_BLOCK = UnicodeBar::FullBlock(); - static const char *const *PARTIAL_BLOCKS = UnicodeBar::PartialBlocks(); - static const idx_t PARTIAL_BLOCKS_COUNT = UnicodeBar::PartialBlocksCount(); - - if (!Value::IsFinite(max_width)) { - throw ValueOutOfRangeException("Max bar width must not be NaN or infinity"); - } - if (max_width < 1) { - throw ValueOutOfRangeException("Max bar width must be >= 1"); - } - if (max_width > 1000) { - throw ValueOutOfRangeException("Max bar width must be <= 1000"); - } - - double width; - - if (Value::IsNan(x) || Value::IsNan(min) || Value::IsNan(max) || x <= min) { - width = 0; - } else if (x >= max) { - width = max_width; - } else { - width = max_width * (x - min) / (max - min); - } - - if (!Value::IsFinite(width)) { - throw ValueOutOfRangeException("Bar width must not be NaN or infinity"); - } - - result.clear(); - - int32_t width_as_int = static_cast(width * PARTIAL_BLOCKS_COUNT); - idx_t full_blocks_count = (width_as_int / PARTIAL_BLOCKS_COUNT); - for (idx_t i = 0; i < full_blocks_count; i++) { - result += FULL_BLOCK; - } - - idx_t remaining = width_as_int % PARTIAL_BLOCKS_COUNT; - - if (remaining) { - result += PARTIAL_BLOCKS[remaining]; - } - - return string_t(result); -} - -static void BarFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 3 || args.ColumnCount() == 4); - auto &x_arg = args.data[0]; - auto &min_arg = args.data[1]; - auto &max_arg = args.data[2]; - string buffer; - - if (args.ColumnCount() == 3) { - GenericExecutor::ExecuteTernary, PrimitiveType, PrimitiveType, - PrimitiveType>( - x_arg, min_arg, max_arg, result, args.size(), - [&](PrimitiveType x, PrimitiveType min, PrimitiveType max) { - return StringVector::AddString(result, BarScalarFunction(x.val, min.val, max.val, 80, buffer)); - 
}); - } else { - auto &width_arg = args.data[3]; - GenericExecutor::ExecuteQuaternary, PrimitiveType, PrimitiveType, - PrimitiveType, PrimitiveType>( - x_arg, min_arg, max_arg, width_arg, result, args.size(), - [&](PrimitiveType x, PrimitiveType min, PrimitiveType max, - PrimitiveType width) { - return StringVector::AddString(result, BarScalarFunction(x.val, min.val, max.val, width.val, buffer)); - }); - } -} - -ScalarFunctionSet BarFun::GetFunctions() { - ScalarFunctionSet bar; - bar.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE}, - LogicalType::VARCHAR, BarFunction)); - bar.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE}, - LogicalType::VARCHAR, BarFunction)); - return bar; -} - -} // namespace duckdb - - - - -namespace duckdb { - -struct ChrOperator { - static void GetCodepoint(int32_t input, char c[], int &utf8_bytes) { - if (input < 0 || !Utf8Proc::CodepointToUtf8(input, utf8_bytes, &c[0])) { - throw InvalidInputException("Invalid UTF8 Codepoint %d", input); - } - } - - template - static inline TR Operation(const TA &input) { - char c[5] = {'\0', '\0', '\0', '\0', '\0'}; - int utf8_bytes; - GetCodepoint(input, c, utf8_bytes); - return string_t(&c[0], utf8_bytes); - } -}; - -#ifdef DUCKDB_DEBUG_NO_INLINE -// the chr function depends on the data always being inlined (which is always possible, since it outputs max 4 bytes) -// to enable chr when string inlining is disabled we create a special function here -static void ChrFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &code_vec = args.data[0]; - - char c[5] = {'\0', '\0', '\0', '\0', '\0'}; - int utf8_bytes; - UnaryExecutor::Execute(code_vec, result, args.size(), [&](int32_t input) { - ChrOperator::GetCodepoint(input, c, utf8_bytes); - return StringVector::AddString(result, &c[0], utf8_bytes); - }); -} -#endif - -ScalarFunction ChrFun::GetFunction() { - return 
ScalarFunction("chr", {LogicalType::INTEGER}, LogicalType::VARCHAR, -#ifdef DUCKDB_DEBUG_NO_INLINE - ChrFunction -#else - ScalarFunction::UnaryFunction -#endif - ); -} - -} // namespace duckdb - - - - -namespace duckdb { - -// Using Lowrance-Wagner (LW) algorithm: https://doi.org/10.1145%2F321879.321880 -// Can't calculate as trivial modification to levenshtein algorithm -// as we need to potentially know about earlier in the string -static idx_t DamerauLevenshteinDistance(const string_t &source, const string_t &target) { - // costs associated with each type of edit, to aid readability - constexpr uint8_t COST_SUBSTITUTION = 1; - constexpr uint8_t COST_INSERTION = 1; - constexpr uint8_t COST_DELETION = 1; - constexpr uint8_t COST_TRANSPOSITION = 1; - const auto source_len = source.GetSize(); - const auto target_len = target.GetSize(); - - // If one string is empty, the distance equals the length of the other string - // either through target_len insertions - // or source_len deletions - if (source_len == 0) { - return target_len * COST_INSERTION; - } else if (target_len == 0) { - return source_len * COST_DELETION; - } - - const auto source_str = source.GetData(); - const auto target_str = target.GetData(); - - // larger than the largest possible value: - const auto inf = source_len * COST_DELETION + target_len * COST_INSERTION + 1; - // minimum edit distance from prefix of source string to prefix of target string - // same object as H in LW paper (with indices offset by 1) - vector> distance(source_len + 2, vector(target_len + 2, inf)); - // keeps track of the largest string indices of source string matching each character - // same as DA in LW paper - map largest_source_chr_matching; - - // initialise row/column corresponding to zero-length strings - // partial string -> empty requires a deletion for each character - for (idx_t source_idx = 0; source_idx <= source_len; source_idx++) { - distance[source_idx + 1][1] = source_idx * COST_DELETION; - } - // and empty 
-> partial string means simply inserting characters - for (idx_t target_idx = 1; target_idx <= target_len; target_idx++) { - distance[1][target_idx + 1] = target_idx * COST_INSERTION; - } - // loop through string indices - these are offset by 2 from distance indices - for (idx_t source_idx = 0; source_idx < source_len; source_idx++) { - // keeps track of the largest string indices of target string matching current source character - // same as DB in LW paper - idx_t largest_target_chr_matching; - largest_target_chr_matching = 0; - for (idx_t target_idx = 0; target_idx < target_len; target_idx++) { - // correspond to i1 and j1 in LW paper respectively - idx_t largest_source_chr_matching_target; - idx_t largest_target_chr_matching_source; - // cost associated to diagnanl shift in distance matrix - // corresponds to d in LW paper - uint8_t cost_diagonal_shift; - largest_source_chr_matching_target = largest_source_chr_matching[target_str[target_idx]]; - largest_target_chr_matching_source = largest_target_chr_matching; - // if characters match, diagonal move costs nothing and we update our largest target index - // otherwise move is substitution and costs as such - if (source_str[source_idx] == target_str[target_idx]) { - cost_diagonal_shift = 0; - largest_target_chr_matching = target_idx + 1; - } else { - cost_diagonal_shift = COST_SUBSTITUTION; - } - distance[source_idx + 2][target_idx + 2] = MinValue( - distance[source_idx + 1][target_idx + 1] + cost_diagonal_shift, - MinValue(distance[source_idx + 2][target_idx + 1] + COST_INSERTION, - MinValue(distance[source_idx + 1][target_idx + 2] + COST_DELETION, - distance[largest_source_chr_matching_target][largest_target_chr_matching_source] + - (source_idx - largest_source_chr_matching_target) * COST_DELETION + - COST_TRANSPOSITION + - (target_idx - largest_target_chr_matching_source) * COST_INSERTION))); - } - largest_source_chr_matching[source_str[source_idx]] = source_idx + 1; - } - return distance[source_len + 
1][target_len + 1]; -} - -static int64_t DamerauLevenshteinScalarFunction(Vector &result, const string_t source, const string_t target) { - return (int64_t)DamerauLevenshteinDistance(source, target); -} - -static void DamerauLevenshteinFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &source_vec = args.data[0]; - auto &target_vec = args.data[1]; - - BinaryExecutor::Execute( - source_vec, target_vec, result, args.size(), - [&](string_t source, string_t target) { return DamerauLevenshteinScalarFunction(result, source, target); }); -} - -ScalarFunction DamerauLevenshteinFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, - DamerauLevenshteinFunction); -} - -} // namespace duckdb - - - - -namespace duckdb { - -static void FormatBytesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](int64_t bytes) { - bool is_negative = bytes < 0; - idx_t unsigned_bytes; - if (bytes < 0) { - if (bytes == NumericLimits::Minimum()) { - unsigned_bytes = idx_t(NumericLimits::Maximum()) + 1; - } else { - unsigned_bytes = idx_t(-bytes); - } - } else { - unsigned_bytes = idx_t(bytes); - } - return StringVector::AddString(result, (is_negative ? 
"-" : "") + - StringUtil::BytesToHumanReadableString(unsigned_bytes)); - }); -} - -ScalarFunction FormatBytesFun::GetFunction() { - return ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, FormatBytesFunction); -} - -} // namespace duckdb - - - -#include -#include - -namespace duckdb { - -static int64_t MismatchesScalarFunction(Vector &result, const string_t str, string_t tgt) { - idx_t str_len = str.GetSize(); - idx_t tgt_len = tgt.GetSize(); - - if (str_len != tgt_len) { - throw InvalidInputException("Mismatch Function: Strings must be of equal length!"); - } - if (str_len < 1) { - throw InvalidInputException("Mismatch Function: Strings must be of length > 0!"); - } - - idx_t mismatches = 0; - auto str_str = str.GetData(); - auto tgt_str = tgt.GetData(); - - for (idx_t idx = 0; idx < str_len; ++idx) { - if (str_str[idx] != tgt_str[idx]) { - mismatches++; - } - } - return (int64_t)mismatches; -} - -static void MismatchesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &tgt_vec = args.data[1]; - - BinaryExecutor::Execute( - str_vec, tgt_vec, result, args.size(), - [&](string_t str, string_t tgt) { return MismatchesScalarFunction(result, str, tgt); }); -} - -ScalarFunction HammingFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, MismatchesFunction); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -static void WriteHexBytes(uint64_t x, char *&output, idx_t buffer_size) { - idx_t offset = buffer_size * 4; - - for (; offset >= 4; offset -= 4) { - uint8_t byte = (x >> (offset - 4)) & 0x0F; - *output = Blob::HEX_TABLE[byte]; - output++; - } -} - -static void WriteHugeIntHexBytes(hugeint_t x, char *&output, idx_t buffer_size) { - idx_t offset = buffer_size * 4; - auto upper = x.upper; - auto lower = x.lower; - - for (; offset >= 68; offset -= 4) { - uint8_t byte = (upper >> (offset - 68)) & 0x0F; - *output = 
Blob::HEX_TABLE[byte]; - output++; - } - - for (; offset >= 4; offset -= 4) { - uint8_t byte = (lower >> (offset - 4)) & 0x0F; - *output = Blob::HEX_TABLE[byte]; - output++; - } -} - -static void WriteBinBytes(uint64_t x, char *&output, idx_t buffer_size) { - idx_t offset = buffer_size; - for (; offset >= 1; offset -= 1) { - *output = ((x >> (offset - 1)) & 0x01) + '0'; - output++; - } -} - -static void WriteHugeIntBinBytes(hugeint_t x, char *&output, idx_t buffer_size) { - auto upper = x.upper; - auto lower = x.lower; - idx_t offset = buffer_size; - - for (; offset >= 65; offset -= 1) { - *output = ((upper >> (offset - 65)) & 0x01) + '0'; - output++; - } - - for (; offset >= 1; offset -= 1) { - *output = ((lower >> (offset - 1)) & 0x01) + '0'; - output++; - } -} - -struct HexStrOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto data = input.GetData(); - auto size = input.GetSize(); - - // Allocate empty space - auto target = StringVector::EmptyString(result, size * 2); - auto output = target.GetDataWriteable(); - - for (idx_t i = 0; i < size; ++i) { - *output = Blob::HEX_TABLE[(data[i] >> 4) & 0x0F]; - output++; - *output = Blob::HEX_TABLE[data[i] & 0x0F]; - output++; - } - - target.Finalize(); - return target; - } -}; - -struct HexIntegralOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - - idx_t num_leading_zero = CountZeros::Leading(input); - idx_t num_bits_to_check = 64 - num_leading_zero; - D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); - - idx_t buffer_size = (num_bits_to_check + 3) / 4; - - // Special case: All bits are zero - if (buffer_size == 0) { - auto target = StringVector::EmptyString(result, 1); - auto output = target.GetDataWriteable(); - *output = '0'; - target.Finalize(); - return target; - } - - D_ASSERT(buffer_size > 0); - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - WriteHexBytes(input, 
output, buffer_size); - - target.Finalize(); - return target; - } -}; - -struct HexHugeIntOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - - idx_t num_leading_zero = CountZeros::Leading(input); - idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); - - // Special case: All bits are zero - if (buffer_size == 0) { - auto target = StringVector::EmptyString(result, 1); - auto output = target.GetDataWriteable(); - *output = '0'; - target.Finalize(); - return target; - } - - D_ASSERT(buffer_size > 0); - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - WriteHugeIntHexBytes(input, output, buffer_size); - - target.Finalize(); - return target; - } -}; - -template -static void ToHexFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - auto &input = args.data[0]; - idx_t count = args.size(); - UnaryExecutor::ExecuteString(input, result, count); -} - -struct BinaryStrOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto data = input.GetData(); - auto size = input.GetSize(); - - // Allocate empty space - auto target = StringVector::EmptyString(result, size * 8); - auto output = target.GetDataWriteable(); - - for (idx_t i = 0; i < size; ++i) { - uint8_t byte = data[i]; - for (idx_t i = 8; i >= 1; --i) { - *output = ((byte >> (i - 1)) & 0x01) + '0'; - output++; - } - } - - target.Finalize(); - return target; - } -}; - -struct BinaryIntegralOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - - idx_t num_leading_zero = CountZeros::Leading(input); - idx_t num_bits_to_check = 64 - num_leading_zero; - D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); - - idx_t buffer_size = num_bits_to_check; - - // Special case: All bits are zero - if (buffer_size == 0) { - auto target = StringVector::EmptyString(result, 1); - auto output = 
target.GetDataWriteable(); - *output = '0'; - target.Finalize(); - return target; - } - - D_ASSERT(buffer_size > 0); - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - WriteBinBytes(input, output, buffer_size); - - target.Finalize(); - return target; - } -}; - -struct BinaryHugeIntOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); - idx_t buffer_size = sizeof(INPUT_TYPE) * 8 - num_leading_zero; - - // Special case: All bits are zero - if (buffer_size == 0) { - auto target = StringVector::EmptyString(result, 1); - auto output = target.GetDataWriteable(); - *output = '0'; - target.Finalize(); - return target; - } - - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - WriteHugeIntBinBytes(input, output, buffer_size); - - target.Finalize(); - return target; - } -}; - -struct FromHexOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto data = input.GetData(); - auto size = input.GetSize(); - - if (size > NumericLimits::Maximum()) { - throw InvalidInputException("Hexadecimal input length larger than 2^32 are not supported"); - } - - D_ASSERT(size <= NumericLimits::Maximum()); - auto buffer_size = (size + 1) / 2; - - // Allocate empty space - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - // Treated as a single byte - idx_t i = 0; - if (size % 2 != 0) { - *output = StringUtil::GetHexValue(data[i]); - i++; - output++; - } - - for (; i < size; i += 2) { - uint8_t major = StringUtil::GetHexValue(data[i]); - uint8_t minor = StringUtil::GetHexValue(data[i + 1]); - *output = (major << 4) | minor; - output++; - } - - target.Finalize(); - return target; - } -}; - -struct FromBinaryOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - 
auto data = input.GetData(); - auto size = input.GetSize(); - - if (size > NumericLimits::Maximum()) { - throw InvalidInputException("Binary input length larger than 2^32 are not supported"); - } - - D_ASSERT(size <= NumericLimits::Maximum()); - auto buffer_size = (size + 7) / 8; - - // Allocate empty space - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - // Treated as a single byte - idx_t i = 0; - if (size % 8 != 0) { - uint8_t byte = 0; - for (idx_t j = size % 8; j > 0; --j) { - byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1); - i++; - } - *output = byte; - output++; - } - - while (i < size) { - uint8_t byte = 0; - for (idx_t j = 8; j > 0; --j) { - byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1); - i++; - } - *output = byte; - output++; - } - - target.Finalize(); - return target; - } -}; - -template -static void ToBinaryFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - auto &input = args.data[0]; - idx_t count = args.size(); - UnaryExecutor::ExecuteString(input, result, count); -} - -static void FromBinaryFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - D_ASSERT(args.data[0].GetType().InternalType() == PhysicalType::VARCHAR); - auto &input = args.data[0]; - idx_t count = args.size(); - - UnaryExecutor::ExecuteString(input, result, count); -} - -static void FromHexFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - D_ASSERT(args.data[0].GetType().InternalType() == PhysicalType::VARCHAR); - auto &input = args.data[0]; - idx_t count = args.size(); - - UnaryExecutor::ExecuteString(input, result, count); -} - -ScalarFunctionSet HexFun::GetFunctions() { - ScalarFunctionSet to_hex; - to_hex.AddFunction( - ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, ToHexFunction)); - - to_hex.AddFunction( - 
ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, ToHexFunction)); - - to_hex.AddFunction( - ScalarFunction({LogicalType::UBIGINT}, LogicalType::VARCHAR, ToHexFunction)); - - to_hex.AddFunction( - ScalarFunction({LogicalType::HUGEINT}, LogicalType::VARCHAR, ToHexFunction)); - return to_hex; -} - -ScalarFunction UnhexFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::BLOB, FromHexFunction); -} - -ScalarFunctionSet BinFun::GetFunctions() { - ScalarFunctionSet to_binary; - - to_binary.AddFunction( - ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, ToBinaryFunction)); - to_binary.AddFunction(ScalarFunction({LogicalType::UBIGINT}, LogicalType::VARCHAR, - ToBinaryFunction)); - to_binary.AddFunction( - ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, ToBinaryFunction)); - to_binary.AddFunction(ScalarFunction({LogicalType::HUGEINT}, LogicalType::VARCHAR, - ToBinaryFunction)); - return to_binary; -} - -ScalarFunction UnbinFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::BLOB, FromBinaryFunction); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -struct InstrOperator { - template - static inline TR Operation(TA haystack, TB needle) { - int64_t string_position = 0; - - auto location = ContainsFun::Find(haystack, needle); - if (location != DConstants::INVALID_INDEX) { - auto len = (utf8proc_ssize_t)location; - auto str = reinterpret_cast(haystack.GetData()); - D_ASSERT(len <= (utf8proc_ssize_t)haystack.GetSize()); - for (++string_position; len > 0; ++string_position) { - utf8proc_int32_t codepoint; - auto bytes = utf8proc_iterate(str, len, &codepoint); - str += bytes; - len -= bytes; - } - } - return string_position; - } -}; - -struct InstrAsciiOperator { - template - static inline TR Operation(TA haystack, TB needle) { - auto location = ContainsFun::Find(haystack, needle); - return location == DConstants::INVALID_INDEX ? 
0 : location + 1; - } -}; - -static unique_ptr InStrPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() == 2); - // can only propagate stats if the children have stats - // for strpos, we only care if the FIRST string has unicode or not - if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = ScalarFunction::BinaryFunction; - } - return nullptr; -} - -ScalarFunction InstrFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, - ScalarFunction::BinaryFunction, nullptr, nullptr, - InStrPropagateStats); -} - -} // namespace duckdb - - - - -#include - -namespace duckdb { - -static inline map GetSet(const string_t &str) { - auto map_of_chars = map {}; - idx_t str_len = str.GetSize(); - auto s = str.GetData(); - - for (idx_t pos = 0; pos < str_len; pos++) { - map_of_chars.insert(std::make_pair(s[pos], 1)); - } - return map_of_chars; -} - -static double JaccardSimilarity(const string_t &str, const string_t &txt) { - if (str.GetSize() < 1 || txt.GetSize() < 1) { - throw InvalidInputException("Jaccard Function: An argument too short!"); - } - map m_str, m_txt; - - m_str = GetSet(str); - m_txt = GetSet(txt); - - if (m_str.size() > m_txt.size()) { - m_str.swap(m_txt); - } - - for (auto const &achar : m_str) { - ++m_txt[achar.first]; - } - // m_txt.size is now size of union. 
- - idx_t size_intersect = 0; - for (const auto &apair : m_txt) { - if (apair.second > 1) { - size_intersect++; - } - } - - return (double)size_intersect / (double)m_txt.size(); -} - -static double JaccardScalarFunction(Vector &result, const string_t str, string_t tgt) { - return (double)JaccardSimilarity(str, tgt); -} - -static void JaccardFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &tgt_vec = args.data[1]; - - BinaryExecutor::Execute( - str_vec, tgt_vec, result, args.size(), - [&](string_t str, string_t tgt) { return JaccardScalarFunction(result, str, tgt); }); -} - -ScalarFunction JaccardFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaccardFunction); -} - -} // namespace duckdb - - - - -namespace duckdb { - -static inline double JaroScalarFunction(const string_t &s1, const string_t &s2) { - auto s1_begin = s1.GetData(); - auto s2_begin = s2.GetData(); - return duckdb_jaro_winkler::jaro_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, s2_begin + s2.GetSize()); -} - -static inline double JaroWinklerScalarFunction(const string_t &s1, const string_t &s2) { - auto s1_begin = s1.GetData(); - auto s2_begin = s2.GetData(); - return duckdb_jaro_winkler::jaro_winkler_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, - s2_begin + s2.GetSize()); -} - -template -static void CachedFunction(Vector &constant, Vector &other, Vector &result, idx_t count) { - auto val = constant.GetValue(0); - if (val.IsNull()) { - auto &result_validity = FlatVector::Validity(result); - result_validity.SetAllInvalid(count); - return; - } - - auto str_val = StringValue::Get(val); - auto cached = CACHED_SIMILARITY(str_val); - UnaryExecutor::Execute(other, result, count, [&](const string_t &other_str) { - auto other_str_begin = other_str.GetData(); - return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize()); - }); -} - -template > -static 
void TemplatedJaroWinklerFunction(DataChunk &args, Vector &result, SIMILARITY_FUNCTION fun) { - bool arg0_constant = args.data[0].GetVectorType() == VectorType::CONSTANT_VECTOR; - bool arg1_constant = args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR; - if (!(arg0_constant ^ arg1_constant)) { - // We can't optimize by caching one of the two strings - BinaryExecutor::Execute(args.data[0], args.data[1], result, args.size(), fun); - return; - } - - if (arg0_constant) { - CachedFunction(args.data[0], args.data[1], result, args.size()); - } else { - CachedFunction(args.data[1], args.data[0], result, args.size()); - } -} - -static void JaroFunction(DataChunk &args, ExpressionState &state, Vector &result) { - TemplatedJaroWinklerFunction>(args, result, JaroScalarFunction); -} - -static void JaroWinklerFunction(DataChunk &args, ExpressionState &state, Vector &result) { - TemplatedJaroWinklerFunction>(args, result, - JaroWinklerScalarFunction); -} - -ScalarFunction JaroSimilarityFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroFunction); -} - -ScalarFunction JaroWinklerSimilarityFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroWinklerFunction); -} - -} // namespace duckdb - - - - - -#include -#include - -namespace duckdb { - -struct LeftRightUnicode { - template - static inline TR Operation(TA input) { - return LengthFun::Length(input); - } - - static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) { - return SubstringFun::SubstringUnicode(result, input, offset, length); - } -}; - -struct LeftRightGrapheme { - template - static inline TR Operation(TA input) { - return LengthFun::GraphemeCount(input); - } - - static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) { - return SubstringFun::SubstringGrapheme(result, input, offset, length); - } -}; - -template -static 
string_t LeftScalarFunction(Vector &result, const string_t str, int64_t pos) { - if (pos >= 0) { - return OP::Substring(result, str, 1, pos); - } - - int64_t num_characters = OP::template Operation(str); - pos = MaxValue(0, num_characters + pos); - return OP::Substring(result, str, 1, pos); -} - -template -static void LeftFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &pos_vec = args.data[1]; - - BinaryExecutor::Execute( - str_vec, pos_vec, result, args.size(), - [&](string_t str, int64_t pos) { return LeftScalarFunction(result, str, pos); }); -} - -ScalarFunction LeftFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, - LeftFunction); -} - -ScalarFunction LeftGraphemeFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, - LeftFunction); -} - -template -static string_t RightScalarFunction(Vector &result, const string_t str, int64_t pos) { - int64_t num_characters = OP::template Operation(str); - if (pos >= 0) { - int64_t len = MinValue(num_characters, pos); - int64_t start = num_characters - len + 1; - return OP::Substring(result, str, start, len); - } - - int64_t len = 0; - if (pos != std::numeric_limits::min()) { - len = num_characters - MinValue(num_characters, -pos); - } - int64_t start = num_characters - len + 1; - return OP::Substring(result, str, start, len); -} - -template -static void RightFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &pos_vec = args.data[1]; - BinaryExecutor::Execute( - str_vec, pos_vec, result, args.size(), - [&](string_t str, int64_t pos) { return RightScalarFunction(result, str, pos); }); -} - -ScalarFunction RightFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, - RightFunction); -} - -ScalarFunction RightGraphemeFun::GetFunction() { - 
return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, - RightFunction); -} - -} // namespace duckdb - - - - -#include -#include - -namespace duckdb { - -// See: https://www.kdnuggets.com/2020/10/optimizing-levenshtein-distance-measuring-text-similarity.html -// And: Iterative 2-row algorithm: https://en.wikipedia.org/wiki/Levenshtein_distance -// Note: A first implementation using the array algorithm version resulted in an error raised by duckdb -// (too muach memory usage) - -static idx_t LevenshteinDistance(const string_t &txt, const string_t &tgt) { - auto txt_len = txt.GetSize(); - auto tgt_len = tgt.GetSize(); - - // If one string is empty, the distance equals the length of the other string - if (txt_len == 0) { - return tgt_len; - } else if (tgt_len == 0) { - return txt_len; - } - - auto txt_str = txt.GetData(); - auto tgt_str = tgt.GetData(); - - // Create two working vectors - vector distances0(tgt_len + 1, 0); - vector distances1(tgt_len + 1, 0); - - idx_t cost_substitution = 0; - idx_t cost_insertion = 0; - idx_t cost_deletion = 0; - - // initialize distances0 vector - // edit distance for an empty txt string is just the number of characters to delete from tgt - for (idx_t pos_tgt = 0; pos_tgt <= tgt_len; pos_tgt++) { - distances0[pos_tgt] = pos_tgt; - } - - for (idx_t pos_txt = 0; pos_txt < txt_len; pos_txt++) { - // calculate distances1 (current raw distances) from the previous row - - distances1[0] = pos_txt + 1; - - for (idx_t pos_tgt = 0; pos_tgt < tgt_len; pos_tgt++) { - cost_deletion = distances0[pos_tgt + 1] + 1; - cost_insertion = distances1[pos_tgt] + 1; - cost_substitution = distances0[pos_tgt]; - - if (txt_str[pos_txt] != tgt_str[pos_tgt]) { - cost_substitution += 1; - } - - distances1[pos_tgt + 1] = MinValue(cost_deletion, MinValue(cost_substitution, cost_insertion)); - } - // copy distances1 (current row) to distances0 (previous row) for next iteration - // since data in distances1 is always invalidated, a 
swap without copy is more efficient - distances0 = distances1; - } - - return distances0[tgt_len]; -} - -static int64_t LevenshteinScalarFunction(Vector &result, const string_t str, string_t tgt) { - return (int64_t)LevenshteinDistance(str, tgt); -} - -static void LevenshteinFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &tgt_vec = args.data[1]; - - BinaryExecutor::Execute( - str_vec, tgt_vec, result, args.size(), - [&](string_t str, string_t tgt) { return LevenshteinScalarFunction(result, str, tgt); }); -} - -ScalarFunction LevenshteinFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, LevenshteinFunction); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct MD5Operator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto hash = StringVector::EmptyString(result, MD5Context::MD5_HASH_LENGTH_TEXT); - MD5Context context; - context.Add(input); - context.FinishHex(hash.GetDataWriteable()); - hash.Finalize(); - return hash; - } -}; - -struct MD5Number128Operator { - template - static RESULT_TYPE Operation(INPUT_TYPE input) { - data_t digest[MD5Context::MD5_HASH_LENGTH_BINARY]; - - MD5Context context; - context.Add(input); - context.Finish(digest); - return *reinterpret_cast(digest); - } -}; - -template -struct MD5Number64Operator { - template - static RESULT_TYPE Operation(INPUT_TYPE input) { - data_t digest[MD5Context::MD5_HASH_LENGTH_BINARY]; - - MD5Context context; - context.Add(input); - context.Finish(digest); - return *reinterpret_cast(&digest[lower ? 
8 : 0]); - } -}; - -static void MD5Function(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - - UnaryExecutor::ExecuteString(input, result, args.size()); -} - -static void MD5NumberFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - - UnaryExecutor::Execute(input, result, args.size()); -} - -static void MD5NumberUpperFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - - UnaryExecutor::Execute>(input, result, args.size()); -} - -static void MD5NumberLowerFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - - UnaryExecutor::Execute>(input, result, args.size()); -} - -ScalarFunction MD5Fun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, MD5Function); -} - -ScalarFunction MD5NumberFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::HUGEINT, MD5NumberFunction); -} - -ScalarFunction MD5NumberUpperFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::UBIGINT, MD5NumberUpperFunction); -} - -ScalarFunction MD5NumberLowerFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::UBIGINT, MD5NumberLowerFunction); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -static pair PadCountChars(const idx_t len, const char *data, const idx_t size) { - // Count how much of str will fit in the output - auto str = reinterpret_cast(data); - idx_t nbytes = 0; - idx_t nchars = 0; - for (; nchars < len && nbytes < size; ++nchars) { - utf8proc_int32_t codepoint; - auto bytes = utf8proc_iterate(str + nbytes, size - nbytes, &codepoint); - D_ASSERT(bytes > 0); - nbytes += bytes; - } - - return pair(nbytes, nchars); -} - -static bool InsertPadding(const idx_t len, const string_t &pad, vector &result) { - // Copy the padding until the output is long enough - auto data = 
pad.GetData(); - auto size = pad.GetSize(); - - // Check whether we need data that we don't have - if (len > 0 && size == 0) { - return false; - } - - // Insert characters until we have all we need. - auto str = reinterpret_cast(data); - idx_t nbytes = 0; - for (idx_t nchars = 0; nchars < len; ++nchars) { - // If we are at the end of the pad, flush all of it and loop back - if (nbytes >= size) { - result.insert(result.end(), data, data + size); - nbytes = 0; - } - - // Write the next character - utf8proc_int32_t codepoint; - auto bytes = utf8proc_iterate(str + nbytes, size - nbytes, &codepoint); - D_ASSERT(bytes > 0); - nbytes += bytes; - } - - // Flush the remaining pad - result.insert(result.end(), data, data + nbytes); - - return true; -} - -static string_t LeftPadFunction(const string_t &str, const int32_t len, const string_t &pad, vector &result) { - // Reuse the buffer - result.clear(); - - // Get information about the base string - auto data_str = str.GetData(); - auto size_str = str.GetSize(); - - // Count how much of str will fit in the output - auto written = PadCountChars(len, data_str, size_str); - - // Left pad by the number of characters still needed - if (!InsertPadding(len - written.second, pad, result)) { - throw Exception("Insufficient padding in LPAD."); - } - - // Append as much of the original string as fits - result.insert(result.end(), data_str, data_str + written.first); - - return string_t(result.data(), result.size()); -} - -struct LeftPadOperator { - static inline string_t Operation(const string_t &str, const int32_t len, const string_t &pad, - vector &result) { - return LeftPadFunction(str, len, pad, result); - } -}; - -static string_t RightPadFunction(const string_t &str, const int32_t len, const string_t &pad, vector &result) { - // Reuse the buffer - result.clear(); - - // Get information about the base string - auto data_str = str.GetData(); - auto size_str = str.GetSize(); - - // Count how much of str will fit in the output - auto 
written = PadCountChars(len, data_str, size_str); - - // Append as much of the original string as fits - result.insert(result.end(), data_str, data_str + written.first); - - // Right pad by the number of characters still needed - if (!InsertPadding(len - written.second, pad, result)) { - throw Exception("Insufficient padding in RPAD."); - }; - - return string_t(result.data(), result.size()); -} - -struct RightPadOperator { - static inline string_t Operation(const string_t &str, const int32_t len, const string_t &pad, - vector &result) { - return RightPadFunction(str, len, pad, result); - } -}; - -template -static void PadFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vector = args.data[0]; - auto &len_vector = args.data[1]; - auto &pad_vector = args.data[2]; - - vector buffer; - TernaryExecutor::Execute( - str_vector, len_vector, pad_vector, result, args.size(), [&](string_t str, int32_t len, string_t pad) { - len = MaxValue(len, 0); - return StringVector::AddString(result, OP::Operation(str, len, pad, buffer)); - }); -} - -ScalarFunction LpadFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, - PadFunction); -} - -ScalarFunction RpadFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, - PadFunction); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct FMTPrintf { - template - static string OP(const char *format_str, vector> &format_args) { - return duckdb_fmt::vsprintf( - format_str, duckdb_fmt::basic_format_args(format_args.data(), static_cast(format_args.size()))); - } -}; - -struct FMTFormat { - template - static string OP(const char *format_str, vector> &format_args) { - return duckdb_fmt::vformat( - format_str, duckdb_fmt::basic_format_args(format_args.data(), static_cast(format_args.size()))); - } -}; - -unique_ptr 
BindPrintfFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - for (idx_t i = 1; i < arguments.size(); i++) { - switch (arguments[i]->return_type.id()) { - case LogicalTypeId::BOOLEAN: - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::VARCHAR: - // these types are natively supported - bound_function.arguments.push_back(arguments[i]->return_type); - break; - case LogicalTypeId::DECIMAL: - // decimal type: add cast to double - bound_function.arguments.emplace_back(LogicalType::DOUBLE); - break; - case LogicalTypeId::UNKNOWN: - // parameter: accept any input and rebind later - bound_function.arguments.emplace_back(LogicalType::ANY); - break; - default: - // all other types: add cast to string - bound_function.arguments.emplace_back(LogicalType::VARCHAR); - break; - } - } - return nullptr; -} - -template -static void PrintfFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &format_string = args.data[0]; - auto &result_validity = FlatVector::Validity(result); - result.SetVectorType(VectorType::CONSTANT_VECTOR); - result_validity.Initialize(args.size()); - for (idx_t i = 0; i < args.ColumnCount(); i++) { - switch (args.data[i].GetVectorType()) { - case VectorType::CONSTANT_VECTOR: - if (ConstantVector::IsNull(args.data[i])) { - // constant null! result is always NULL regardless of other input - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - break; - default: - // FLAT VECTOR, we can directly OR the nullmask - args.data[i].Flatten(args.size()); - result.SetVectorType(VectorType::FLAT_VECTOR); - result_validity.Combine(FlatVector::Validity(args.data[i]), args.size()); - break; - } - } - idx_t count = result.GetVectorType() == VectorType::CONSTANT_VECTOR ? 
1 : args.size(); - - auto format_data = FlatVector::GetData(format_string); - auto result_data = FlatVector::GetData(result); - for (idx_t idx = 0; idx < count; idx++) { - if (result.GetVectorType() == VectorType::FLAT_VECTOR && FlatVector::IsNull(result, idx)) { - // this entry is NULL: skip it - continue; - } - - // first fetch the format string - auto fmt_idx = format_string.GetVectorType() == VectorType::CONSTANT_VECTOR ? 0 : idx; - auto format_string = format_data[fmt_idx].GetString(); - - // now gather all the format arguments - vector> format_args; - vector> string_args; - - for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { - auto &col = args.data[col_idx]; - idx_t arg_idx = col.GetVectorType() == VectorType::CONSTANT_VECTOR ? 0 : idx; - switch (col.GetType().id()) { - case LogicalTypeId::BOOLEAN: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::TINYINT: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::SMALLINT: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::INTEGER: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::BIGINT: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::FLOAT: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::DOUBLE: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case 
LogicalTypeId::VARCHAR: { - auto arg_data = FlatVector::GetData(col); - auto string_view = - duckdb_fmt::basic_string_view(arg_data[arg_idx].GetData(), arg_data[arg_idx].GetSize()); - format_args.emplace_back(duckdb_fmt::internal::make_arg(string_view)); - break; - } - default: - throw InternalException("Unexpected type for printf format"); - } - } - // finally actually perform the format - string dynamic_result = FORMAT_FUN::template OP(format_string.c_str(), format_args); - result_data[idx] = StringVector::AddString(result, dynamic_result); - } -} - -ScalarFunction PrintfFun::GetFunction() { - // duckdb_fmt::printf_context, duckdb_fmt::vsprintf - ScalarFunction printf_fun({LogicalType::VARCHAR}, LogicalType::VARCHAR, - PrintfFunction, BindPrintfFunction); - printf_fun.varargs = LogicalType::ANY; - return printf_fun; -} - -ScalarFunction FormatFun::GetFunction() { - // duckdb_fmt::format_context, duckdb_fmt::vformat - ScalarFunction format_fun({LogicalType::VARCHAR}, LogicalType::VARCHAR, - PrintfFunction, BindPrintfFunction); - format_fun.varargs = LogicalType::ANY; - return format_fun; -} - -} // namespace duckdb - - - - -#include -#include - -namespace duckdb { - -static string_t RepeatScalarFunction(const string_t &str, const int64_t cnt, vector &result) { - // Get information about the repeated string - auto input_str = str.GetData(); - auto size_str = str.GetSize(); - - // Reuse the buffer - result.clear(); - for (auto remaining = cnt; remaining-- > 0;) { - result.insert(result.end(), input_str, input_str + size_str); - } - - return string_t(result.data(), result.size()); -} - -static void RepeatFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vector = args.data[0]; - auto &cnt_vector = args.data[1]; - - vector buffer; - BinaryExecutor::Execute( - str_vector, cnt_vector, result, args.size(), [&](string_t str, int64_t cnt) { - return StringVector::AddString(result, RepeatScalarFunction(str, cnt, buffer)); - }); -} - 
-ScalarFunctionSet RepeatFun::GetFunctions() { - ScalarFunctionSet repeat; - for (const auto &type : {LogicalType::VARCHAR, LogicalType::BLOB}) { - repeat.AddFunction(ScalarFunction({type, LogicalType::BIGINT}, type, RepeatFunction)); - } - return repeat; -} - -} // namespace duckdb - - - - - - -#include -#include -#include - -namespace duckdb { - -static idx_t NextNeedle(const char *input_haystack, idx_t size_haystack, const char *input_needle, - const idx_t size_needle) { - // Needle needs something to proceed - if (size_needle > 0) { - // Haystack should be bigger or equal size to the needle - for (idx_t string_position = 0; (size_haystack - string_position) >= size_needle; ++string_position) { - // Compare Needle to the Haystack - if ((memcmp(input_haystack + string_position, input_needle, size_needle) == 0)) { - return string_position; - } - } - } - // Did not find the needle - return size_haystack; -} - -static string_t ReplaceScalarFunction(const string_t &haystack, const string_t &needle, const string_t &thread, - vector &result) { - // Get information about the needle, the haystack and the "thread" - auto input_haystack = haystack.GetData(); - auto size_haystack = haystack.GetSize(); - - auto input_needle = needle.GetData(); - auto size_needle = needle.GetSize(); - - auto input_thread = thread.GetData(); - auto size_thread = thread.GetSize(); - - // Reuse the buffer - result.clear(); - - for (;;) { - // Append the non-matching characters - auto string_position = NextNeedle(input_haystack, size_haystack, input_needle, size_needle); - result.insert(result.end(), input_haystack, input_haystack + string_position); - input_haystack += string_position; - size_haystack -= string_position; - - // Stop when we have read the entire haystack - if (size_haystack == 0) { - break; - } - - // Replace the matching characters - result.insert(result.end(), input_thread, input_thread + size_thread); - input_haystack += size_needle; - size_haystack -= size_needle; - } - - 
return string_t(result.data(), result.size()); -} - -static void ReplaceFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &haystack_vector = args.data[0]; - auto &needle_vector = args.data[1]; - auto &thread_vector = args.data[2]; - - vector buffer; - TernaryExecutor::Execute( - haystack_vector, needle_vector, thread_vector, result, args.size(), - [&](string_t input_string, string_t needle_string, string_t thread_string) { - return StringVector::AddString(result, - ReplaceScalarFunction(input_string, needle_string, thread_string, buffer)); - }); -} - -ScalarFunction ReplaceFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - ReplaceFunction); -} - -} // namespace duckdb - - - - - - - -#include - -namespace duckdb { - -//! Fast ASCII string reverse, returns false if the input data is not ascii -static bool StrReverseASCII(const char *input, idx_t n, char *output) { - for (idx_t i = 0; i < n; i++) { - if (input[i] & 0x80) { - // non-ascii character - return false; - } - output[n - i - 1] = input[i]; - } - return true; -} - -//! 
Unicode string reverse using grapheme breakers -static void StrReverseUnicode(const char *input, idx_t n, char *output) { - utf8proc_grapheme_callback(input, n, [&](size_t start, size_t end) { - memcpy(output + n - end, input + start, end - start); - return true; - }); -} - -struct ReverseOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto input_data = input.GetData(); - auto input_length = input.GetSize(); - - auto target = StringVector::EmptyString(result, input_length); - auto target_data = target.GetDataWriteable(); - if (!StrReverseASCII(input_data, input_length, target_data)) { - StrReverseUnicode(input_data, input_length, target_data); - } - target.Finalize(); - return target; - } -}; - -static void ReverseFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); -} - -ScalarFunction ReverseFun::GetFunction() { - return ScalarFunction("reverse", {LogicalType::VARCHAR}, LogicalType::VARCHAR, ReverseFunction); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -struct SHA256Operator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto hash = StringVector::EmptyString(result, duckdb_mbedtls::MbedTlsWrapper::SHA256_HASH_LENGTH_TEXT); - - duckdb_mbedtls::MbedTlsWrapper::SHA256State state; - state.AddString(input.GetString()); - state.FinishHex(hash.GetDataWriteable()); - - hash.Finalize(); - return hash; - } -}; - -static void SHA256Function(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - - UnaryExecutor::ExecuteString(input, result, args.size()); -} - -ScalarFunction SHA256Fun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, SHA256Function); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -static bool StartsWith(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, - idx_t 
needle_size) { - D_ASSERT(needle_size > 0); - if (needle_size > haystack_size) { - // needle is bigger than haystack: haystack cannot start with needle - return false; - } - return memcmp(haystack, needle, needle_size) == 0; -} - -static bool StartsWith(const string_t &haystack_s, const string_t &needle_s) { - - auto haystack = const_uchar_ptr_cast(haystack_s.GetData()); - auto haystack_size = haystack_s.GetSize(); - auto needle = const_uchar_ptr_cast(needle_s.GetData()); - auto needle_size = needle_s.GetSize(); - if (needle_size == 0) { - // empty needle: always true - return true; - } - return StartsWith(haystack, haystack_size, needle, needle_size); -} - -struct StartsWithOperator { - template - static inline TR Operation(TA left, TB right) { - return StartsWith(left, right); - } -}; - -ScalarFunction StartsWithOperatorFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - ScalarFunction::BinaryFunction); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -struct StringSplitInput { - StringSplitInput(Vector &result_list, Vector &result_child, idx_t offset) - : result_list(result_list), result_child(result_child), offset(offset) { - } - - Vector &result_list; - Vector &result_child; - idx_t offset; - - void AddSplit(const char *split_data, idx_t split_size, idx_t list_idx) { - auto list_entry = offset + list_idx; - if (list_entry >= ListVector::GetListCapacity(result_list)) { - ListVector::SetListSize(result_list, offset + list_idx); - ListVector::Reserve(result_list, ListVector::GetListCapacity(result_list) * 2); - } - FlatVector::GetData(result_child)[list_entry] = - StringVector::AddString(result_child, split_data, split_size); - } -}; - -struct RegularStringSplit { - static idx_t Find(const char *input_data, idx_t input_size, const char *delim_data, idx_t delim_size, - idx_t &match_size, void *data) { - match_size = delim_size; - if (delim_size == 0) { - return 0; - } - return 
ContainsFun::Find(const_uchar_ptr_cast(input_data), input_size, const_uchar_ptr_cast(delim_data), - delim_size); - } -}; - -struct ConstantRegexpStringSplit { - static idx_t Find(const char *input_data, idx_t input_size, const char *delim_data, idx_t delim_size, - idx_t &match_size, void *data) { - D_ASSERT(data); - auto regex = reinterpret_cast(data); - duckdb_re2::StringPiece match; - if (!regex->Match(duckdb_re2::StringPiece(input_data, input_size), 0, input_size, RE2::UNANCHORED, &match, 1)) { - return DConstants::INVALID_INDEX; - } - match_size = match.size(); - return match.data() - input_data; - } -}; - -struct RegexpStringSplit { - static idx_t Find(const char *input_data, idx_t input_size, const char *delim_data, idx_t delim_size, - idx_t &match_size, void *data) { - duckdb_re2::RE2 regex(duckdb_re2::StringPiece(delim_data, delim_size)); - if (!regex.ok()) { - throw InvalidInputException(regex.error()); - } - return ConstantRegexpStringSplit::Find(input_data, input_size, delim_data, delim_size, match_size, ®ex); - } -}; - -struct StringSplitter { - template - static idx_t Split(string_t input, string_t delim, StringSplitInput &state, void *data) { - auto input_data = input.GetData(); - auto input_size = input.GetSize(); - auto delim_data = delim.GetData(); - auto delim_size = delim.GetSize(); - idx_t list_idx = 0; - while (input_size > 0) { - idx_t match_size = 0; - auto pos = OP::Find(input_data, input_size, delim_data, delim_size, match_size, data); - if (pos > input_size) { - break; - } - if (match_size == 0 && pos == 0) { - // special case: 0 length match and pos is 0 - // move to the next character - for (pos++; pos < input_size; pos++) { - if (LengthFun::IsCharacter(input_data[pos])) { - break; - } - } - if (pos == input_size) { - break; - } - } - D_ASSERT(input_size >= pos + match_size); - state.AddSplit(input_data, pos, list_idx); - - list_idx++; - input_data += (pos + match_size); - input_size -= (pos + match_size); - } - 
state.AddSplit(input_data, input_size, list_idx); - list_idx++; - return list_idx; - } -}; - -template -static void StringSplitExecutor(DataChunk &args, ExpressionState &state, Vector &result, void *data = nullptr) { - UnifiedVectorFormat input_data; - args.data[0].ToUnifiedFormat(args.size(), input_data); - auto inputs = UnifiedVectorFormat::GetData(input_data); - - UnifiedVectorFormat delim_data; - args.data[1].ToUnifiedFormat(args.size(), delim_data); - auto delims = UnifiedVectorFormat::GetData(delim_data); - - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - - result.SetVectorType(VectorType::FLAT_VECTOR); - ListVector::SetListSize(result, 0); - - auto list_struct_data = FlatVector::GetData(result); - - // count all the splits and set up the list entries - auto &child_entry = ListVector::GetEntry(result); - auto &result_mask = FlatVector::Validity(result); - idx_t total_splits = 0; - for (idx_t i = 0; i < args.size(); i++) { - auto input_idx = input_data.sel->get_index(i); - auto delim_idx = delim_data.sel->get_index(i); - if (!input_data.validity.RowIsValid(input_idx)) { - result_mask.SetInvalid(i); - continue; - } - StringSplitInput split_input(result, child_entry, total_splits); - if (!delim_data.validity.RowIsValid(delim_idx)) { - // delim is NULL: copy the complete entry - split_input.AddSplit(inputs[input_idx].GetData(), inputs[input_idx].GetSize(), 0); - list_struct_data[i].length = 1; - list_struct_data[i].offset = total_splits; - total_splits++; - continue; - } - auto list_length = StringSplitter::Split(inputs[input_idx], delims[delim_idx], split_input, data); - list_struct_data[i].length = list_length; - list_struct_data[i].offset = total_splits; - total_splits += list_length; - } - ListVector::SetListSize(result, total_splits); - D_ASSERT(ListVector::GetListSize(result) == total_splits); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static void StringSplitFunction(DataChunk &args, 
ExpressionState &state, Vector &result) { - StringSplitExecutor(args, state, result, nullptr); -} - -static void StringSplitRegexFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - if (info.constant_pattern) { - // fast path: pre-compiled regex - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - StringSplitExecutor(args, state, result, &lstate.constant_pattern); - } else { - // slow path: have to re-compile regex for every row - StringSplitExecutor(args, state, result); - } -} - -ScalarFunction StringSplitFun::GetFunction() { - auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); - - ScalarFunction string_split({LogicalType::VARCHAR, LogicalType::VARCHAR}, varchar_list_type, StringSplitFunction); - string_split.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return string_split; -} - -ScalarFunctionSet StringSplitRegexFun::GetFunctions() { - auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); - ScalarFunctionSet regexp_split; - ScalarFunction regex_fun({LogicalType::VARCHAR, LogicalType::VARCHAR}, varchar_list_type, StringSplitRegexFunction, - RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, - FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING); - regexp_split.AddFunction(regex_fun); - // regexp options - regex_fun.arguments.emplace_back(LogicalType::VARCHAR); - regexp_split.AddFunction(regex_fun); - return regexp_split; -} - -} // namespace duckdb - - - - -namespace duckdb { - -static const char alphabet[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"; - -static unique_ptr ToBaseBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // If no min_length is specified, default to 0 - D_ASSERT(arguments.size() == 2 || arguments.size() == 3); - if (arguments.size() == 2) { - arguments.push_back(make_uniq_base(Value::INTEGER(0))); 
- } - return nullptr; -} - -static void ToBaseFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - auto &radix = args.data[1]; - auto &min_length = args.data[2]; - auto count = args.size(); - - TernaryExecutor::Execute( - input, radix, min_length, result, count, [&](int64_t input, int32_t radix, int32_t min_length) { - if (input < 0) { - throw InvalidInputException("'to_base' number must be greater than or equal to 0"); - } - if (radix < 2 || radix > 36) { - throw InvalidInputException("'to_base' radix must be between 2 and 36"); - } - if (min_length > 64 || min_length < 0) { - throw InvalidInputException("'to_base' min_length must be between 0 and 64"); - } - - char buf[64]; - char *end = buf + sizeof(buf); - char *ptr = end; - do { - *--ptr = alphabet[input % radix]; - input /= radix; - } while (input > 0); - - auto length = end - ptr; - while (length < min_length) { - *--ptr = '0'; - length++; - } - - return StringVector::AddString(result, ptr, end - ptr); - }); -} - -ScalarFunctionSet ToBaseFun::GetFunctions() { - ScalarFunctionSet set("to_base"); - - set.AddFunction( - ScalarFunction({LogicalType::BIGINT, LogicalType::INTEGER}, LogicalType::VARCHAR, ToBaseFunction, ToBaseBind)); - set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::INTEGER, LogicalType::INTEGER}, - LogicalType::VARCHAR, ToBaseFunction, ToBaseBind)); - - return set; -} - -} // namespace duckdb - - - - - - - - -#include -#include -#include -#include - -namespace duckdb { - -static string_t TranslateScalarFunction(const string_t &haystack, const string_t &needle, const string_t &thread, - vector &result) { - // Get information about the haystack, the needle and the "thread" - auto input_haystack = haystack.GetData(); - auto size_haystack = haystack.GetSize(); - - auto input_needle = needle.GetData(); - auto size_needle = needle.GetSize(); - - auto input_thread = thread.GetData(); - auto size_thread = thread.GetSize(); - - // Reuse the 
buffer - result.clear(); - result.reserve(size_haystack); - - idx_t i = 0, j = 0; - int sz = 0, c_sz = 0; - - // Character to be replaced - unordered_map to_replace; - while (i < size_needle && j < size_thread) { - auto codepoint_needle = Utf8Proc::UTF8ToCodepoint(input_needle, sz); - input_needle += sz; - i += sz; - auto codepoint_thread = Utf8Proc::UTF8ToCodepoint(input_thread, sz); - input_thread += sz; - j += sz; - // Ignore unicode character that is existed in to_replace - if (to_replace.count(codepoint_needle) == 0) { - to_replace[codepoint_needle] = codepoint_thread; - } - } - - // Character to be deleted - unordered_set to_delete; - while (i < size_needle) { - auto codepoint_needle = Utf8Proc::UTF8ToCodepoint(input_needle, sz); - input_needle += sz; - i += sz; - // Add unicode character that will be deleted - if (to_replace.count(codepoint_needle) == 0) { - to_delete.insert(codepoint_needle); - } - } - - char c[5] = {'\0', '\0', '\0', '\0', '\0'}; - for (i = 0; i < size_haystack; i += sz) { - auto codepoint_haystack = Utf8Proc::UTF8ToCodepoint(input_haystack, sz); - if (to_replace.count(codepoint_haystack) != 0) { - Utf8Proc::CodepointToUtf8(to_replace[codepoint_haystack], c_sz, c); - result.insert(result.end(), c, c + c_sz); - } else if (to_delete.count(codepoint_haystack) == 0) { - result.insert(result.end(), input_haystack, input_haystack + sz); - } - input_haystack += sz; - } - - return string_t(result.data(), result.size()); -} - -static void TranslateFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &haystack_vector = args.data[0]; - auto &needle_vector = args.data[1]; - auto &thread_vector = args.data[2]; - - vector buffer; - TernaryExecutor::Execute( - haystack_vector, needle_vector, thread_vector, result, args.size(), - [&](string_t input_string, string_t needle_string, string_t thread_string) { - return StringVector::AddString(result, - TranslateScalarFunction(input_string, needle_string, thread_string, buffer)); - }); -} 
- -ScalarFunction TranslateFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - TranslateFunction); -} - -} // namespace duckdb - - - - - - - -#include - -namespace duckdb { - -template -struct TrimOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto data = input.GetData(); - auto size = input.GetSize(); - - utf8proc_int32_t codepoint; - auto str = reinterpret_cast(data); - - // Find the first character that is not left trimmed - idx_t begin = 0; - if (LTRIM) { - while (begin < size) { - auto bytes = utf8proc_iterate(str + begin, size - begin, &codepoint); - D_ASSERT(bytes > 0); - if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { - break; - } - begin += bytes; - } - } - - // Find the last character that is not right trimmed - idx_t end; - if (RTRIM) { - end = begin; - for (auto next = begin; next < size;) { - auto bytes = utf8proc_iterate(str + next, size - next, &codepoint); - D_ASSERT(bytes > 0); - next += bytes; - if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { - end = next; - } - } - } else { - end = size; - } - - // Copy the trimmed string - auto target = StringVector::EmptyString(result, end - begin); - auto output = target.GetDataWriteable(); - memcpy(output, data + begin, end - begin); - - target.Finalize(); - return target; - } -}; - -template -static void UnaryTrimFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString>(args.data[0], result, args.size()); -} - -static void GetIgnoredCodepoints(string_t ignored, unordered_set &ignored_codepoints) { - auto dataptr = reinterpret_cast(ignored.GetData()); - auto size = ignored.GetSize(); - idx_t pos = 0; - while (pos < size) { - utf8proc_int32_t codepoint; - pos += utf8proc_iterate(dataptr + pos, size - pos, &codepoint); - ignored_codepoints.insert(codepoint); - } -} - -template -static void BinaryTrimFunction(DataChunk 
&input, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - input.data[0], input.data[1], result, input.size(), [&](string_t input, string_t ignored) { - auto data = input.GetData(); - auto size = input.GetSize(); - - unordered_set ignored_codepoints; - GetIgnoredCodepoints(ignored, ignored_codepoints); - - utf8proc_int32_t codepoint; - auto str = reinterpret_cast(data); - - // Find the first character that is not left trimmed - idx_t begin = 0; - if (LTRIM) { - while (begin < size) { - auto bytes = utf8proc_iterate(str + begin, size - begin, &codepoint); - if (ignored_codepoints.find(codepoint) == ignored_codepoints.end()) { - break; - } - begin += bytes; - } - } - - // Find the last character that is not right trimmed - idx_t end; - if (RTRIM) { - end = begin; - for (auto next = begin; next < size;) { - auto bytes = utf8proc_iterate(str + next, size - next, &codepoint); - D_ASSERT(bytes > 0); - next += bytes; - if (ignored_codepoints.find(codepoint) == ignored_codepoints.end()) { - end = next; - } - } - } else { - end = size; - } - - // Copy the trimmed string - auto target = StringVector::EmptyString(result, end - begin); - auto output = target.GetDataWriteable(); - memcpy(output, data + begin, end - begin); - - target.Finalize(); - return target; - }); -} - -ScalarFunctionSet TrimFun::GetFunctions() { - ScalarFunctionSet trim; - trim.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, UnaryTrimFunction)); - - trim.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - BinaryTrimFunction)); - return trim; -} - -ScalarFunctionSet LtrimFun::GetFunctions() { - ScalarFunctionSet ltrim; - ltrim.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, UnaryTrimFunction)); - ltrim.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - BinaryTrimFunction)); - return ltrim; -} - -ScalarFunctionSet RtrimFun::GetFunctions() { - 
ScalarFunctionSet rtrim; - rtrim.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, UnaryTrimFunction)); - - rtrim.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - BinaryTrimFunction)); - return rtrim; -} - -} // namespace duckdb - - - - - - - -#include - -namespace duckdb { - -struct UnicodeOperator { - template - static inline TR Operation(const TA &input) { - auto str = reinterpret_cast(input.GetData()); - auto len = input.GetSize(); - utf8proc_int32_t codepoint; - (void)utf8proc_iterate(str, len, &codepoint); - return codepoint; - } -}; - -ScalarFunction UnicodeFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::INTEGER, - ScalarFunction::UnaryFunction); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -static void StructInsertFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &starting_vec = args.data[0]; - - starting_vec.Verify(args.size()); - - auto &starting_child_entries = StructVector::GetEntries(starting_vec); - auto &result_child_entries = StructVector::GetEntries(result); - - // Assign the starting vector entries to the result vector - for (size_t i = 0; i < starting_child_entries.size(); i++) { - auto &starting_child = starting_child_entries[i]; - result_child_entries[i]->Reference(*starting_child); - } - - // Assign the new entries to the result vector - for (size_t i = 1; i < args.ColumnCount(); i++) { - result_child_entries[starting_child_entries.size() + i - 1]->Reference(args.data[i]); - } - - result.Verify(args.size()); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr StructInsertBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - case_insensitive_set_t name_collision_set; - - if (arguments.empty()) { - throw Exception("Missing required arguments for struct_insert function."); - } - - if (LogicalTypeId::STRUCT 
!= arguments[0]->return_type.id()) { - throw Exception("The first argument to struct_insert must be a STRUCT"); - } - - if (arguments.size() < 2) { - throw Exception("Can't insert nothing into a struct"); - } - - child_list_t new_struct_children; - - auto &existing_struct_children = StructType::GetChildTypes(arguments[0]->return_type); - - for (size_t i = 0; i < existing_struct_children.size(); i++) { - auto &child = existing_struct_children[i]; - name_collision_set.insert(child.first); - new_struct_children.push_back(make_pair(child.first, child.second)); - } - - // Loop through the additional arguments (name/value pairs) - for (idx_t i = 1; i < arguments.size(); i++) { - auto &child = arguments[i]; - if (child->alias.empty() && bound_function.name == "struct_insert") { - throw BinderException("Need named argument for struct insert, e.g. STRUCT_PACK(a := b)"); - } - if (name_collision_set.find(child->alias) != name_collision_set.end()) { - throw BinderException("Duplicate struct entry name \"%s\"", child->alias); - } - name_collision_set.insert(child->alias); - new_struct_children.push_back(make_pair(child->alias, arguments[i]->return_type)); - } - - // this is more for completeness reasons - bound_function.return_type = LogicalType::STRUCT(new_struct_children); - return make_uniq(bound_function.return_type); -} - -unique_ptr StructInsertStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - auto new_struct_stats = StructStats::CreateUnknown(expr.return_type); - - auto existing_count = StructType::GetChildCount(child_stats[0].GetType()); - auto existing_stats = StructStats::GetChildStats(child_stats[0]); - for (idx_t i = 0; i < existing_count; i++) { - StructStats::SetChildStats(new_struct_stats, i, existing_stats[i]); - } - auto new_count = StructType::GetChildCount(expr.return_type); - auto offset = new_count - child_stats.size(); - for (idx_t i = 1; i < child_stats.size(); i++) { - 
StructStats::SetChildStats(new_struct_stats, offset + i, child_stats[i]); - } - return new_struct_stats.ToUnique(); -} - -ScalarFunction StructInsertFun::GetFunction() { - // the arguments and return types are actually set in the binder function - ScalarFunction fun({}, LogicalTypeId::STRUCT, StructInsertFunction, StructInsertBind, nullptr, StructInsertStats); - fun.varargs = LogicalType::ANY; - fun.serialize = VariableReturnBindData::Serialize; - fun.deserialize = VariableReturnBindData::Deserialize; - return fun; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -static void StructPackFunction(DataChunk &args, ExpressionState &state, Vector &result) { -#ifdef DEBUG - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - // this should never happen if the binder below is sane - D_ASSERT(args.ColumnCount() == StructType::GetChildTypes(info.stype).size()); -#endif - bool all_const = true; - auto &child_entries = StructVector::GetEntries(result); - for (size_t i = 0; i < args.ColumnCount(); i++) { - if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { - all_const = false; - } - // same holds for this - child_entries[i]->Reference(args.data[i]); - } - result.SetVectorType(all_const ? VectorType::CONSTANT_VECTOR : VectorType::FLAT_VECTOR); - - result.Verify(args.size()); -} - -template -static unique_ptr StructPackBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - case_insensitive_set_t name_collision_set; - - // collect names and deconflict, construct return type - if (arguments.empty()) { - throw Exception("Can't pack nothing into a struct"); - } - child_list_t struct_children; - for (idx_t i = 0; i < arguments.size(); i++) { - auto &child = arguments[i]; - string alias; - if (IS_STRUCT_PACK) { - if (child->alias.empty()) { - throw BinderException("Need named argument for struct pack, e.g. 
STRUCT_PACK(a := b)"); - } - alias = child->alias; - if (name_collision_set.find(alias) != name_collision_set.end()) { - throw BinderException("Duplicate struct entry name \"%s\"", alias); - } - name_collision_set.insert(alias); - } - struct_children.push_back(make_pair(alias, arguments[i]->return_type)); - } - - // this is more for completeness reasons - bound_function.return_type = LogicalType::STRUCT(struct_children); - return make_uniq(bound_function.return_type); -} - -unique_ptr StructPackStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - auto struct_stats = StructStats::CreateUnknown(expr.return_type); - for (idx_t i = 0; i < child_stats.size(); i++) { - StructStats::SetChildStats(struct_stats, i, child_stats[i]); - } - return struct_stats.ToUnique(); -} - -template -ScalarFunction GetStructPackFunction() { - ScalarFunction fun(IS_STRUCT_PACK ? "struct_pack" : "row", {}, LogicalTypeId::STRUCT, StructPackFunction, - StructPackBind, nullptr, StructPackStats); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = VariableReturnBindData::Serialize; - fun.deserialize = VariableReturnBindData::Deserialize; - return fun; -} - -ScalarFunction StructPackFun::GetFunction() { - return GetStructPackFunction(); -} - -ScalarFunction RowFun::GetFunction() { - return GetStructPackFunction(); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct UnionExtractBindData : public FunctionData { - UnionExtractBindData(string key, idx_t index, LogicalType type) - : key(std::move(key)), index(index), type(std::move(type)) { - } - - string key; - idx_t index; - LogicalType type; - -public: - unique_ptr Copy() const override { - return make_uniq(key, index, type); - } - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return key == other.key && index == other.index && type == 
other.type; - } -}; - -static void UnionExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - // this should be guaranteed by the binder - auto &vec = args.data[0]; - vec.Verify(args.size()); - - D_ASSERT(info.index < UnionType::GetMemberCount(vec.GetType())); - auto &member = UnionVector::GetMember(vec, info.index); - result.Reference(member); - result.Verify(args.size()); -} - -static unique_ptr UnionExtractBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - D_ASSERT(LogicalTypeId::UNION == arguments[0]->return_type.id()); - idx_t union_member_count = UnionType::GetMemberCount(arguments[0]->return_type); - if (union_member_count == 0) { - throw InternalException("Can't extract something from an empty union"); - } - bound_function.arguments[0] = arguments[0]->return_type; - - auto &key_child = arguments[1]; - if (key_child->HasParameter()) { - throw ParameterNotResolvedException(); - } - - if (key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) { - throw BinderException("Key name for union_extract needs to be a constant string"); - } - Value key_val = ExpressionExecutor::EvaluateScalar(context, *key_child); - D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR); - auto &key_str = StringValue::Get(key_val); - if (key_val.IsNull() || key_str.empty()) { - throw BinderException("Key name for union_extract needs to be neither NULL nor empty"); - } - string key = StringUtil::Lower(key_str); - - LogicalType return_type; - idx_t key_index = 0; - bool found_key = false; - - for (size_t i = 0; i < union_member_count; i++) { - auto &member_name = UnionType::GetMemberName(arguments[0]->return_type, i); - if (StringUtil::Lower(member_name) == 
key) { - found_key = true; - key_index = i; - return_type = UnionType::GetMemberType(arguments[0]->return_type, i); - break; - } - } - - if (!found_key) { - vector candidates; - candidates.reserve(union_member_count); - for (idx_t i = 0; i < union_member_count; i++) { - candidates.push_back(UnionType::GetMemberName(arguments[0]->return_type, i)); - } - auto closest_settings = StringUtil::TopNLevenshtein(candidates, key); - auto message = StringUtil::CandidatesMessage(closest_settings, "Candidate Entries"); - throw BinderException("Could not find key \"%s\" in union\n%s", key, message); - } - - bound_function.return_type = return_type; - return make_uniq(key, key_index, return_type); -} - -ScalarFunction UnionExtractFun::GetFunction() { - // the arguments and return types are actually set in the binder function - return ScalarFunction({LogicalTypeId::UNION, LogicalType::VARCHAR}, LogicalType::ANY, UnionExtractFunction, - UnionExtractBind, nullptr, nullptr); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -static unique_ptr UnionTagBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - if (arguments.empty()) { - throw BinderException("Missing required arguments for union_tag function."); - } - - if (LogicalTypeId::UNKNOWN == arguments[0]->return_type.id()) { - throw ParameterNotResolvedException(); - } - - if (LogicalTypeId::UNION != arguments[0]->return_type.id()) { - throw BinderException("First argument to union_tag function must be a union type."); - } - - if (arguments.size() > 1) { - throw BinderException("Too many arguments, union_tag takes at most one argument."); - } - - auto member_count = UnionType::GetMemberCount(arguments[0]->return_type); - if (member_count == 0) { - // this should never happen, empty unions are not allowed - throw InternalException("Can't get tags from an empty union"); - } - - bound_function.arguments[0] = arguments[0]->return_type; - - auto varchar_vector = Vector(LogicalType::VARCHAR, 
member_count); - for (idx_t i = 0; i < member_count; i++) { - auto str = string_t(UnionType::GetMemberName(arguments[0]->return_type, i)); - FlatVector::GetData(varchar_vector)[i] = - str.IsInlined() ? str : StringVector::AddString(varchar_vector, str); - } - auto enum_type = LogicalType::ENUM(varchar_vector, member_count); - bound_function.return_type = enum_type; - - return nullptr; -} - -static void UnionTagFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(result.GetType().id() == LogicalTypeId::ENUM); - result.Reinterpret(UnionVector::GetTags(args.data[0])); -} - -ScalarFunction UnionTagFun::GetFunction() { - return ScalarFunction({LogicalTypeId::UNION}, LogicalTypeId::ANY, UnionTagFunction, UnionTagBind, nullptr, - nullptr); // TODO: Statistics? -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -struct UnionValueBindData : public FunctionData { - UnionValueBindData() { - } - -public: - unique_ptr Copy() const override { - return make_uniq(); - } - bool Equals(const FunctionData &other_p) const override { - return true; - } -}; - -static void UnionValueFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // Assign the new entries to the result vector - UnionVector::GetMember(result, 0).Reference(args.data[0]); - - // Set the result tag vector to a constant value - auto &tag_vector = UnionVector::GetTags(result); - tag_vector.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::GetData(tag_vector)[0] = 0; - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - - result.Verify(args.size()); -} - -static unique_ptr UnionValueBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - if (arguments.size() != 1) { - throw BinderException("union_value takes exactly one argument"); - } - auto &child = arguments[0]; - - if (child->alias.empty()) { - throw BinderException("Need named argument for union tag, e.g. 
UNION_VALUE(a := b)"); - } - - child_list_t union_members; - - union_members.push_back(make_pair(child->alias, child->return_type)); - - bound_function.return_type = LogicalType::UNION(std::move(union_members)); - return make_uniq(bound_function.return_type); -} - -ScalarFunction UnionValueFun::GetFunction() { - ScalarFunction fun("union_value", {}, LogicalTypeId::UNION, UnionValueFunction, UnionValueBind, nullptr, nullptr); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = VariableReturnBindData::Serialize; - fun.deserialize = VariableReturnBindData::Deserialize; - return fun; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -AdaptiveFilter::AdaptiveFilter(const Expression &expr) - : iteration_count(0), observe_interval(10), execute_interval(20), warmup(true) { - auto &conj_expr = expr.Cast(); - D_ASSERT(conj_expr.children.size() > 1); - for (idx_t idx = 0; idx < conj_expr.children.size(); idx++) { - permutation.push_back(idx); - if (idx != conj_expr.children.size() - 1) { - swap_likeliness.push_back(100); - } - } - right_random_border = 100 * (conj_expr.children.size() - 1); -} - -AdaptiveFilter::AdaptiveFilter(TableFilterSet *table_filters) - : iteration_count(0), observe_interval(10), execute_interval(20), warmup(true) { - for (auto &table_filter : table_filters->filters) { - permutation.push_back(table_filter.first); - swap_likeliness.push_back(100); - } - swap_likeliness.pop_back(); - right_random_border = 100 * (table_filters->filters.size() - 1); -} -void AdaptiveFilter::AdaptRuntimeStatistics(double duration) { - iteration_count++; - runtime_sum += duration; - - if (!warmup) { - // the last swap was observed - if (observe && iteration_count == observe_interval) { - // keep swap if runtime decreased, else reverse swap - if (prev_mean - (runtime_sum / iteration_count) <= 0) { - // reverse swap because runtime didn't decrease - std::swap(permutation[swap_idx], permutation[swap_idx + 
1]); - - // decrease swap likeliness, but make sure there is always a small likeliness left - if (swap_likeliness[swap_idx] > 1) { - swap_likeliness[swap_idx] /= 2; - } - } else { - // keep swap because runtime decreased, reset likeliness - swap_likeliness[swap_idx] = 100; - } - observe = false; - - // reset values - iteration_count = 0; - runtime_sum = 0.0; - } else if (!observe && iteration_count == execute_interval) { - // save old mean to evaluate swap - prev_mean = runtime_sum / iteration_count; - - // get swap index and swap likeliness - std::uniform_int_distribution distribution(1, right_random_border); // a <= i <= b - idx_t random_number = distribution(generator) - 1; - - swap_idx = random_number / 100; // index to be swapped - idx_t likeliness = random_number - 100 * swap_idx; // random number between [0, 100) - - // check if swap is going to happen - if (swap_likeliness[swap_idx] > likeliness) { // always true for the first swap of an index - // swap - std::swap(permutation[swap_idx], permutation[swap_idx + 1]); - - // observe whether swap will be applied - observe = true; - } - - // reset values - iteration_count = 0; - runtime_sum = 0.0; - } - } else { - if (iteration_count == 5) { - // initially set all values - iteration_count = 0; - runtime_sum = 0.0; - observe = false; - warmup = false; - } - } -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -using ValidityBytes = TupleDataLayout::ValidityBytes; - -GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator, - vector group_types, vector payload_types, - const vector &bindings, - idx_t initial_capacity, idx_t radix_bits) - : GroupedAggregateHashTable(context, allocator, std::move(group_types), std::move(payload_types), - AggregateObject::CreateAggregateObjects(bindings), initial_capacity, radix_bits) { -} - -GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator, - vector group_types) - : 
GroupedAggregateHashTable(context, allocator, std::move(group_types), {}, vector()) { -} - -GroupedAggregateHashTable::AggregateHTAppendState::AggregateHTAppendState() - : ht_offsets(LogicalType::UBIGINT), hash_salts(LogicalType::HASH), group_compare_vector(STANDARD_VECTOR_SIZE), - no_match_vector(STANDARD_VECTOR_SIZE), empty_vector(STANDARD_VECTOR_SIZE), new_groups(STANDARD_VECTOR_SIZE), - addresses(LogicalType::POINTER) { -} - -GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator, - vector group_types_p, - vector payload_types_p, - vector aggregate_objects_p, - idx_t initial_capacity, idx_t radix_bits) - : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)), - radix_bits(radix_bits), count(0), capacity(0), aggregate_allocator(make_shared(allocator)) { - - // Append hash column to the end and initialise the row layout - group_types_p.emplace_back(LogicalType::HASH); - layout.Initialize(std::move(group_types_p), std::move(aggregate_objects_p)); - - hash_offset = layout.GetOffsets()[layout.ColumnCount() - 1]; - - // Partitioned data and pointer table - InitializePartitionedData(); - Resize(initial_capacity); - - // Predicates - predicates.resize(layout.ColumnCount() - 1, ExpressionType::COMPARE_NOT_DISTINCT_FROM); - row_matcher.Initialize(true, layout, predicates); -} - -void GroupedAggregateHashTable::InitializePartitionedData() { - if (!partitioned_data || RadixPartitioning::RadixBits(partitioned_data->PartitionCount()) != radix_bits) { - D_ASSERT(!partitioned_data || partitioned_data->Count() == 0); - partitioned_data = - make_uniq(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1); - } else { - partitioned_data->Reset(); - } - - D_ASSERT(GetLayout().GetAggrWidth() == layout.GetAggrWidth()); - D_ASSERT(GetLayout().GetDataWidth() == layout.GetDataWidth()); - D_ASSERT(GetLayout().GetRowWidth() == layout.GetRowWidth()); - - 
partitioned_data->InitializeAppendState(state.append_state, TupleDataPinProperties::KEEP_EVERYTHING_PINNED); -} - -unique_ptr &GroupedAggregateHashTable::GetPartitionedData() { - return partitioned_data; -} - -shared_ptr GroupedAggregateHashTable::GetAggregateAllocator() { - return aggregate_allocator; -} - -GroupedAggregateHashTable::~GroupedAggregateHashTable() { - Destroy(); -} - -void GroupedAggregateHashTable::Destroy() { - if (!partitioned_data || partitioned_data->Count() == 0 || !layout.HasDestructor()) { - return; - } - - // There are aggregates with destructors: Call the destructor for each of the aggregates - // Currently does not happen because aggregate destructors are called while scanning in RadixPartitionedHashTable - // LCOV_EXCL_START - RowOperationsState row_state(*aggregate_allocator); - for (auto &data_collection : partitioned_data->GetPartitions()) { - if (data_collection->Count() == 0) { - continue; - } - TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::DESTROY_AFTER_DONE, false); - auto &row_locations = iterator.GetChunkState().row_locations; - do { - RowOperations::DestroyStates(row_state, layout, row_locations, iterator.GetCurrentChunkCount()); - } while (iterator.Next()); - data_collection->Reset(); - } - // LCOV_EXCL_STOP -} - -const TupleDataLayout &GroupedAggregateHashTable::GetLayout() const { - return partitioned_data->GetLayout(); -} - -idx_t GroupedAggregateHashTable::Count() const { - return count; -} - -idx_t GroupedAggregateHashTable::InitialCapacity() { - return STANDARD_VECTOR_SIZE * 2ULL; -} - -idx_t GroupedAggregateHashTable::GetCapacityForCount(idx_t count) { - count = MaxValue(InitialCapacity(), count); - return NextPowerOfTwo(count * LOAD_FACTOR); -} - -idx_t GroupedAggregateHashTable::Capacity() const { - return capacity; -} - -idx_t GroupedAggregateHashTable::ResizeThreshold() const { - return Capacity() / LOAD_FACTOR; -} - -idx_t GroupedAggregateHashTable::ApplyBitMask(hash_t hash) const { - 
return hash & bitmask; -} - -void GroupedAggregateHashTable::Verify() { -#ifdef DEBUG - idx_t total_count = 0; - for (idx_t i = 0; i < capacity; i++) { - const auto &entry = entries[i]; - if (!entry.IsOccupied()) { - continue; - } - auto hash = Load(entry.GetPointer() + hash_offset); - D_ASSERT(entry.GetSalt() == aggr_ht_entry_t::ExtractSalt(hash)); - total_count++; - } - D_ASSERT(total_count == Count()); -#endif -} - -void GroupedAggregateHashTable::ClearPointerTable() { - std::fill_n(entries, capacity, aggr_ht_entry_t(0)); -} - -void GroupedAggregateHashTable::ResetCount() { - count = 0; -} - -void GroupedAggregateHashTable::SetRadixBits(idx_t radix_bits_p) { - radix_bits = radix_bits_p; -} - -void GroupedAggregateHashTable::Resize(idx_t size) { - D_ASSERT(size >= STANDARD_VECTOR_SIZE); - D_ASSERT(IsPowerOfTwo(size)); - if (size < capacity) { - throw InternalException("Cannot downsize a hash table!"); - } - - capacity = size; - hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(aggr_ht_entry_t)); - entries = reinterpret_cast(hash_map.get()); - ClearPointerTable(); - bitmask = capacity - 1; - - if (Count() != 0) { - for (auto &data_collection : partitioned_data->GetPartitions()) { - if (data_collection->Count() == 0) { - continue; - } - TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::ALREADY_PINNED, false); - const auto row_locations = iterator.GetRowLocations(); - do { - for (idx_t i = 0; i < iterator.GetCurrentChunkCount(); i++) { - const auto &row_location = row_locations[i]; - const auto hash = Load(row_location + hash_offset); - - // Find an empty entry - auto entry_idx = ApplyBitMask(hash); - D_ASSERT(entry_idx == hash % capacity); - while (entries[entry_idx].IsOccupied() > 0) { - entry_idx++; - if (entry_idx >= capacity) { - entry_idx = 0; - } - } - auto &entry = entries[entry_idx]; - D_ASSERT(!entry.IsOccupied()); - entry.SetSalt(aggr_ht_entry_t::ExtractSalt(hash)); - entry.SetPointer(row_location); - 
D_ASSERT(entry.IsOccupied()); - } - } while (iterator.Next()); - } - } - - Verify(); -} - -idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload, AggregateType filter) { - unsafe_vector aggregate_filter; - - auto &aggregates = layout.GetAggregates(); - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggregate = aggregates[i]; - if (aggregate.aggr_type == filter) { - aggregate_filter.push_back(i); - } - } - return AddChunk(groups, payload, aggregate_filter); -} - -idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload, const unsafe_vector &filter) { - Vector hashes(LogicalType::HASH); - groups.Hash(hashes); - - return AddChunk(groups, hashes, payload, filter); -} - -idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, Vector &group_hashes, DataChunk &payload, - const unsafe_vector &filter) { - if (groups.size() == 0) { - return 0; - } - -#ifdef DEBUG - D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount()); - for (idx_t i = 0; i < groups.ColumnCount(); i++) { - D_ASSERT(groups.GetTypes()[i] == layout.GetTypes()[i]); - } -#endif - - const auto new_group_count = FindOrCreateGroups(groups, group_hashes, state.addresses, state.new_groups); - VectorOperations::AddInPlace(state.addresses, layout.GetAggrOffset(), payload.size()); - - // Now every cell has an entry, update the aggregates - auto &aggregates = layout.GetAggregates(); - idx_t filter_idx = 0; - idx_t payload_idx = 0; - RowOperationsState row_state(*aggregate_allocator); - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggr = aggregates[i]; - if (filter_idx >= filter.size() || i < filter[filter_idx]) { - // Skip all the aggregates that are not in the filter - payload_idx += aggr.child_count; - VectorOperations::AddInPlace(state.addresses, aggr.payload_size, payload.size()); - continue; - } - D_ASSERT(i == filter[filter_idx]); - - if (aggr.aggr_type != AggregateType::DISTINCT && aggr.filter) { - 
RowOperations::UpdateFilteredStates(row_state, filter_set.GetFilterData(i), aggr, state.addresses, payload, - payload_idx); - } else { - RowOperations::UpdateStates(row_state, aggr, state.addresses, payload, payload_idx, payload.size()); - } - - // Move to the next aggregate - payload_idx += aggr.child_count; - VectorOperations::AddInPlace(state.addresses, aggr.payload_size, payload.size()); - filter_idx++; - } - - Verify(); - return new_group_count; -} - -void GroupedAggregateHashTable::FetchAggregates(DataChunk &groups, DataChunk &result) { -#ifdef DEBUG - groups.Verify(); - D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount()); - for (idx_t i = 0; i < result.ColumnCount(); i++) { - D_ASSERT(result.data[i].GetType() == payload_types[i]); - } -#endif - - result.SetCardinality(groups); - if (groups.size() == 0) { - return; - } - - // find the groups associated with the addresses - // FIXME: this should not use the FindOrCreateGroups, creating them is unnecessary - Vector addresses(LogicalType::POINTER); - FindOrCreateGroups(groups, addresses); - // now fetch the aggregates - RowOperationsState row_state(*aggregate_allocator); - RowOperations::FinalizeStates(row_state, layout, addresses, result, 0); -} - -idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, Vector &group_hashes_v, - Vector &addresses_v, SelectionVector &new_groups_out) { - D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount()); - D_ASSERT(group_hashes_v.GetType() == LogicalType::HASH); - D_ASSERT(state.ht_offsets.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(state.ht_offsets.GetType() == LogicalType::UBIGINT); - D_ASSERT(addresses_v.GetType() == LogicalType::POINTER); - D_ASSERT(state.hash_salts.GetType() == LogicalType::HASH); - - // Need to fit the entire vector, and resize at threshold - if (Count() + groups.size() > capacity || Count() + groups.size() > ResizeThreshold()) { - Verify(); - Resize(capacity * 2); - } - D_ASSERT(capacity - Count() >= 
groups.size()); // we need to be able to fit at least one vector of data - - group_hashes_v.Flatten(groups.size()); - auto hashes = FlatVector::GetData(group_hashes_v); - - addresses_v.Flatten(groups.size()); - auto addresses = FlatVector::GetData(addresses_v); - - // Compute the entry in the table based on the hash using a modulo, - // and precompute the hash salts for faster comparison below - auto ht_offsets = FlatVector::GetData(state.ht_offsets); - auto hash_salts = FlatVector::GetData(state.hash_salts); - for (idx_t r = 0; r < groups.size(); r++) { - const auto &hash = hashes[r]; - ht_offsets[r] = ApplyBitMask(hash); - D_ASSERT(ht_offsets[r] == hash % capacity); - hash_salts[r] = aggr_ht_entry_t::ExtractSalt(hash); - } - - // we start out with all entries [0, 1, 2, ..., groups.size()] - const SelectionVector *sel_vector = FlatVector::IncrementalSelectionVector(); - - // Make a chunk that references the groups and the hashes and convert to unified format - if (state.group_chunk.ColumnCount() == 0) { - state.group_chunk.InitializeEmpty(layout.GetTypes()); - } - D_ASSERT(state.group_chunk.ColumnCount() == layout.GetTypes().size()); - for (idx_t grp_idx = 0; grp_idx < groups.ColumnCount(); grp_idx++) { - state.group_chunk.data[grp_idx].Reference(groups.data[grp_idx]); - } - state.group_chunk.data[groups.ColumnCount()].Reference(group_hashes_v); - state.group_chunk.SetCardinality(groups); - - // convert all vectors to unified format - auto &chunk_state = state.append_state.chunk_state; - TupleDataCollection::ToUnifiedFormat(chunk_state, state.group_chunk); - if (!state.group_data) { - state.group_data = make_unsafe_uniq_array(state.group_chunk.ColumnCount()); - } - TupleDataCollection::GetVectorData(chunk_state, state.group_data.get()); - - idx_t new_group_count = 0; - idx_t remaining_entries = groups.size(); - while (remaining_entries > 0) { - idx_t new_entry_count = 0; - idx_t need_compare_count = 0; - idx_t no_match_count = 0; - - // For each remaining entry, 
figure out whether or not it belongs to a full or empty group - for (idx_t i = 0; i < remaining_entries; i++) { - const auto index = sel_vector->get_index(i); - const auto &salt = hash_salts[index]; - auto &entry = entries[ht_offsets[index]]; - if (entry.IsOccupied()) { // Cell is occupied: Compare salts - if (entry.GetSalt() == salt) { - state.group_compare_vector.set_index(need_compare_count++, index); - } else { - state.no_match_vector.set_index(no_match_count++, index); - } - } else { // Cell is unoccupied - // Set salt (also marks as occupied) - entry.SetSalt(salt); - - // Update selection lists for outer loops - state.empty_vector.set_index(new_entry_count++, index); - new_groups_out.set_index(new_group_count++, index); - } - } - - if (new_entry_count != 0) { - // Append everything that belongs to an empty group - partitioned_data->AppendUnified(state.append_state, state.group_chunk, state.empty_vector, new_entry_count); - RowOperations::InitializeStates(layout, chunk_state.row_locations, - *FlatVector::IncrementalSelectionVector(), new_entry_count); - - // Set the entry pointers in the 1st part of the HT now that the data has been appended - const auto row_locations = FlatVector::GetData(chunk_state.row_locations); - const auto &row_sel = state.append_state.reverse_partition_sel; - for (idx_t new_entry_idx = 0; new_entry_idx < new_entry_count; new_entry_idx++) { - const auto index = state.empty_vector.get_index(new_entry_idx); - const auto row_idx = row_sel.get_index(index); - const auto &row_location = row_locations[row_idx]; - - auto &entry = entries[ht_offsets[index]]; - - entry.SetPointer(row_location); - addresses[index] = row_location; - } - } - - if (need_compare_count != 0) { - // Get the pointers to the rows that need to be compared - for (idx_t need_compare_idx = 0; need_compare_idx < need_compare_count; need_compare_idx++) { - const auto index = state.group_compare_vector.get_index(need_compare_idx); - const auto &entry = 
entries[ht_offsets[index]]; - addresses[index] = entry.GetPointer(); - } - - // Perform group comparisons - row_matcher.Match(state.group_chunk, chunk_state.vector_data, state.group_compare_vector, - need_compare_count, layout, addresses_v, &state.no_match_vector, no_match_count); - } - - // Linear probing: each of the entries that do not match move to the next entry in the HT - for (idx_t i = 0; i < no_match_count; i++) { - idx_t index = state.no_match_vector.get_index(i); - ht_offsets[index]++; - if (ht_offsets[index] >= capacity) { - ht_offsets[index] = 0; - } - } - sel_vector = &state.no_match_vector; - remaining_entries = no_match_count; - } - - count += new_group_count; - return new_group_count; -} - -// this is to support distinct aggregations where we need to record whether we -// have already seen a value for a group -idx_t GroupedAggregateHashTable::FindOrCreateGroups(DataChunk &groups, Vector &group_hashes, Vector &addresses_out, - SelectionVector &new_groups_out) { - return FindOrCreateGroupsInternal(groups, group_hashes, addresses_out, new_groups_out); -} - -void GroupedAggregateHashTable::FindOrCreateGroups(DataChunk &groups, Vector &addresses) { - // create a dummy new_groups sel vector - FindOrCreateGroups(groups, addresses, state.new_groups); -} - -idx_t GroupedAggregateHashTable::FindOrCreateGroups(DataChunk &groups, Vector &addresses_out, - SelectionVector &new_groups_out) { - Vector hashes(LogicalType::HASH); - groups.Hash(hashes); - return FindOrCreateGroups(groups, hashes, addresses_out, new_groups_out); -} - -struct FlushMoveState { - explicit FlushMoveState(TupleDataCollection &collection_p) - : collection(collection_p), hashes(LogicalType::HASH), group_addresses(LogicalType::POINTER), - new_groups_sel(STANDARD_VECTOR_SIZE) { - const auto &layout = collection.GetLayout(); - vector column_ids; - column_ids.reserve(layout.ColumnCount() - 1); - for (idx_t col_idx = 0; col_idx < layout.ColumnCount() - 1; col_idx++) { - 
column_ids.emplace_back(col_idx); - } - collection.InitializeScan(scan_state, column_ids, TupleDataPinProperties::DESTROY_AFTER_DONE); - collection.InitializeScanChunk(scan_state, groups); - hash_col_idx = layout.ColumnCount() - 1; - } - - bool Scan() { - if (collection.Scan(scan_state, groups)) { - collection.Gather(scan_state.chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), - groups.size(), hash_col_idx, hashes, *FlatVector::IncrementalSelectionVector()); - return true; - } - - collection.FinalizePinState(scan_state.pin_state); - return false; - } - - TupleDataCollection &collection; - TupleDataScanState scan_state; - DataChunk groups; - - idx_t hash_col_idx; - Vector hashes; - - Vector group_addresses; - SelectionVector new_groups_sel; -}; - -void GroupedAggregateHashTable::Combine(GroupedAggregateHashTable &other) { - auto other_data = other.partitioned_data->GetUnpartitioned(); - Combine(*other_data); - - // Inherit ownership to all stored aggregate allocators - stored_allocators.emplace_back(other.aggregate_allocator); - for (const auto &stored_allocator : other.stored_allocators) { - stored_allocators.emplace_back(stored_allocator); - } -} - -void GroupedAggregateHashTable::Combine(TupleDataCollection &other_data) { - D_ASSERT(other_data.GetLayout().GetAggrWidth() == layout.GetAggrWidth()); - D_ASSERT(other_data.GetLayout().GetDataWidth() == layout.GetDataWidth()); - D_ASSERT(other_data.GetLayout().GetRowWidth() == layout.GetRowWidth()); - - if (other_data.Count() == 0) { - return; - } - - FlushMoveState fm_state(other_data); - RowOperationsState row_state(*aggregate_allocator); - while (fm_state.Scan()) { - FindOrCreateGroups(fm_state.groups, fm_state.hashes, fm_state.group_addresses, fm_state.new_groups_sel); - RowOperations::CombineStates(row_state, layout, fm_state.scan_state.chunk_state.row_locations, - fm_state.group_addresses, fm_state.groups.size()); - if (layout.HasDestructor()) { - RowOperations::DestroyStates(row_state, 
layout, fm_state.scan_state.chunk_state.row_locations, - fm_state.groups.size()); - } - } - - Verify(); -} - -void GroupedAggregateHashTable::UnpinData() { - partitioned_data->FlushAppendState(state.append_state); - partitioned_data->Unpin(); -} - -} // namespace duckdb - - - - -namespace duckdb { - -BaseAggregateHashTable::BaseAggregateHashTable(ClientContext &context, Allocator &allocator, - const vector &aggregates, - vector payload_types_p) - : allocator(allocator), buffer_manager(BufferManager::GetBufferManager(context)), - payload_types(std::move(payload_types_p)) { - filter_set.Initialize(context, aggregates, payload_types); -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -ColumnBindingResolver::ColumnBindingResolver() { -} - -void ColumnBindingResolver::VisitOperator(LogicalOperator &op) { - switch (op.type) { - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - case LogicalOperatorType::LOGICAL_DELIM_JOIN: { - // special case: comparison join - auto &comp_join = op.Cast(); - // first get the bindings of the LHS and resolve the LHS expressions - VisitOperator(*comp_join.children[0]); - for (auto &cond : comp_join.conditions) { - VisitExpression(&cond.left); - } - // visit the duplicate eliminated columns on the LHS, if any - for (auto &expr : comp_join.duplicate_eliminated_columns) { - VisitExpression(&expr); - } - // then get the bindings of the RHS and resolve the RHS expressions - VisitOperator(*comp_join.children[1]); - for (auto &cond : comp_join.conditions) { - VisitExpression(&cond.right); - } - // finally update the bindings with the result bindings of the join - bindings = op.GetColumnBindings(); - return; - } - case LogicalOperatorType::LOGICAL_ANY_JOIN: { - // ANY join, this join is different because we evaluate the expression on the bindings of BOTH join sides at - // once i.e. 
we set the bindings first to the bindings of the entire join, and then resolve the expressions of - // this operator - VisitOperatorChildren(op); - bindings = op.GetColumnBindings(); - auto &any_join = op.Cast(); - if (any_join.join_type == JoinType::SEMI || any_join.join_type == JoinType::ANTI) { - auto right_bindings = op.children[1]->GetColumnBindings(); - bindings.insert(bindings.end(), right_bindings.begin(), right_bindings.end()); - } - VisitOperatorExpressions(op); - return; - } - case LogicalOperatorType::LOGICAL_CREATE_INDEX: { - // CREATE INDEX statement, add the columns of the table with table index 0 to the binding set - // afterwards bind the expressions of the CREATE INDEX statement - auto &create_index = op.Cast(); - bindings = LogicalOperator::GenerateColumnBindings(0, create_index.table.GetColumns().LogicalColumnCount()); - VisitOperatorExpressions(op); - return; - } - case LogicalOperatorType::LOGICAL_GET: { - //! We first need to update the current set of bindings and then visit operator expressions - bindings = op.GetColumnBindings(); - VisitOperatorExpressions(op); - return; - } - case LogicalOperatorType::LOGICAL_INSERT: { - //! 
We want to execute the normal path, but also add a dummy 'excluded' binding if there is a - // ON CONFLICT DO UPDATE clause - auto &insert_op = op.Cast(); - if (insert_op.action_type != OnConflictAction::THROW) { - // Get the bindings from the children - VisitOperatorChildren(op); - auto column_count = insert_op.table.GetColumns().PhysicalColumnCount(); - auto dummy_bindings = LogicalOperator::GenerateColumnBindings(insert_op.excluded_table_index, column_count); - // Now insert our dummy bindings at the start of the bindings, - // so the first 'column_count' indices of the chunk are reserved for our 'excluded' columns - bindings.insert(bindings.begin(), dummy_bindings.begin(), dummy_bindings.end()); - if (insert_op.on_conflict_condition) { - VisitExpression(&insert_op.on_conflict_condition); - } - if (insert_op.do_update_condition) { - VisitExpression(&insert_op.do_update_condition); - } - VisitOperatorExpressions(op); - bindings = op.GetColumnBindings(); - return; - } - break; - } - case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: { - auto &ext_op = op.Cast(); - ext_op.ResolveColumnBindings(*this, bindings); - return; - } - default: - break; - } - - // general case - // first visit the children of this operator - VisitOperatorChildren(op); - // now visit the expressions of this operator to resolve any bound column references - VisitOperatorExpressions(op); - // finally update the current set of bindings to the current set of column bindings - bindings = op.GetColumnBindings(); -} - -unique_ptr ColumnBindingResolver::VisitReplace(BoundColumnRefExpression &expr, - unique_ptr *expr_ptr) { - D_ASSERT(expr.depth == 0); - // check the current set of column bindings to see which index corresponds to the column reference - for (idx_t i = 0; i < bindings.size(); i++) { - if (expr.binding == bindings[i]) { - return make_uniq(expr.alias, expr.return_type, i); - } - } - // LCOV_EXCL_START - // could not bind the column reference, this should never happen and indicates a 
bug in the code - // generate an error message - string bound_columns = "["; - for (idx_t i = 0; i < bindings.size(); i++) { - if (i != 0) { - bound_columns += " "; - } - bound_columns += to_string(bindings[i].table_index) + "." + to_string(bindings[i].column_index); - } - bound_columns += "]"; - - throw InternalException("Failed to bind column reference \"%s\" [%d.%d] (bindings: %s)", expr.alias, - expr.binding.table_index, expr.binding.column_index, bound_columns); - // LCOV_EXCL_STOP -} - -unordered_set ColumnBindingResolver::VerifyInternal(LogicalOperator &op) { - unordered_set result; - for (auto &child : op.children) { - auto child_indexes = VerifyInternal(*child); - for (auto index : child_indexes) { - D_ASSERT(index != DConstants::INVALID_INDEX); - if (result.find(index) != result.end()) { - throw InternalException("Duplicate table index \"%lld\" found", index); - } - result.insert(index); - } - } - auto indexes = op.GetTableIndex(); - for (auto index : indexes) { - D_ASSERT(index != DConstants::INVALID_INDEX); - if (result.find(index) != result.end()) { - throw InternalException("Duplicate table index \"%lld\" found", index); - } - result.insert(index); - } - return result; -} - -void ColumnBindingResolver::Verify(LogicalOperator &op) { -#ifdef DEBUG - VerifyInternal(op); -#endif -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct BothInclusiveBetweenOperator { - template - static inline bool Operation(T input, T lower, T upper) { - return GreaterThanEquals::Operation(input, lower) && LessThanEquals::Operation(input, upper); - } -}; - -struct LowerInclusiveBetweenOperator { - template - static inline bool Operation(T input, T lower, T upper) { - return GreaterThanEquals::Operation(input, lower) && LessThan::Operation(input, upper); - } -}; - -struct UpperInclusiveBetweenOperator { - template - static inline bool Operation(T input, T lower, T upper) { - return GreaterThan::Operation(input, lower) && LessThanEquals::Operation(input, upper); 
- } -}; - -struct ExclusiveBetweenOperator { - template - static inline bool Operation(T input, T lower, T upper) { - return GreaterThan::Operation(input, lower) && LessThan::Operation(input, upper); - } -}; - -template -static idx_t BetweenLoopTypeSwitch(Vector &input, Vector &lower, Vector &upper, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - switch (input.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::INT16: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::INT32: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::INT64: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::INT128: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::UINT8: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::UINT16: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::UINT32: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::UINT64: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::FLOAT: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, false_sel); - case PhysicalType::DOUBLE: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::VARCHAR: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::INTERVAL: - return TernaryExecutor::Select(input, lower, upper, sel, count, - 
true_sel, false_sel); - default: - throw InvalidTypeException(input.GetType(), "Invalid type for BETWEEN"); - } -} - -unique_ptr ExpressionExecutor::InitializeState(const BoundBetweenExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - result->AddChild(expr.input.get()); - result->AddChild(expr.lower.get()); - result->AddChild(expr.upper.get()); - result->Finalize(); - return result; -} - -void ExpressionExecutor::Execute(const BoundBetweenExpression &expr, ExpressionState *state, const SelectionVector *sel, - idx_t count, Vector &result) { - // resolve the children - state->intermediate_chunk.Reset(); - - auto &input = state->intermediate_chunk.data[0]; - auto &lower = state->intermediate_chunk.data[1]; - auto &upper = state->intermediate_chunk.data[2]; - - Execute(*expr.input, state->child_states[0].get(), sel, count, input); - Execute(*expr.lower, state->child_states[1].get(), sel, count, lower); - Execute(*expr.upper, state->child_states[2].get(), sel, count, upper); - - Vector intermediate1(LogicalType::BOOLEAN); - Vector intermediate2(LogicalType::BOOLEAN); - - if (expr.upper_inclusive && expr.lower_inclusive) { - VectorOperations::GreaterThanEquals(input, lower, intermediate1, count); - VectorOperations::LessThanEquals(input, upper, intermediate2, count); - } else if (expr.lower_inclusive) { - VectorOperations::GreaterThanEquals(input, lower, intermediate1, count); - VectorOperations::LessThan(input, upper, intermediate2, count); - } else if (expr.upper_inclusive) { - VectorOperations::GreaterThan(input, lower, intermediate1, count); - VectorOperations::LessThanEquals(input, upper, intermediate2, count); - } else { - VectorOperations::GreaterThan(input, lower, intermediate1, count); - VectorOperations::LessThan(input, upper, intermediate2, count); - } - VectorOperations::And(intermediate1, intermediate2, result, count); -} - -idx_t ExpressionExecutor::Select(const BoundBetweenExpression &expr, ExpressionState *state, 
const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { - // resolve the children - Vector input(state->intermediate_chunk.data[0]); - Vector lower(state->intermediate_chunk.data[1]); - Vector upper(state->intermediate_chunk.data[2]); - - Execute(*expr.input, state->child_states[0].get(), sel, count, input); - Execute(*expr.lower, state->child_states[1].get(), sel, count, lower); - Execute(*expr.upper, state->child_states[2].get(), sel, count, upper); - - if (expr.upper_inclusive && expr.lower_inclusive) { - return BetweenLoopTypeSwitch(input, lower, upper, sel, count, true_sel, - false_sel); - } else if (expr.lower_inclusive) { - return BetweenLoopTypeSwitch(input, lower, upper, sel, count, true_sel, - false_sel); - } else if (expr.upper_inclusive) { - return BetweenLoopTypeSwitch(input, lower, upper, sel, count, true_sel, - false_sel); - } else { - return BetweenLoopTypeSwitch(input, lower, upper, sel, count, true_sel, false_sel); - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -struct CaseExpressionState : public ExpressionState { - CaseExpressionState(const Expression &expr, ExpressionExecutorState &root) - : ExpressionState(expr, root), true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE) { - } - - SelectionVector true_sel; - SelectionVector false_sel; -}; - -unique_ptr ExpressionExecutor::InitializeState(const BoundCaseExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - for (auto &case_check : expr.case_checks) { - result->AddChild(case_check.when_expr.get()); - result->AddChild(case_check.then_expr.get()); - } - result->AddChild(expr.else_expr.get()); - result->Finalize(); - return std::move(result); -} - -void ExpressionExecutor::Execute(const BoundCaseExpression &expr, ExpressionState *state_p, const SelectionVector *sel, - idx_t count, Vector &result) { - auto &state = state_p->Cast(); - - state.intermediate_chunk.Reset(); - - // first execute 
the check expression - auto current_true_sel = &state.true_sel; - auto current_false_sel = &state.false_sel; - auto current_sel = sel; - idx_t current_count = count; - for (idx_t i = 0; i < expr.case_checks.size(); i++) { - auto &case_check = expr.case_checks[i]; - auto &intermediate_result = state.intermediate_chunk.data[i * 2 + 1]; - auto check_state = state.child_states[i * 2].get(); - auto then_state = state.child_states[i * 2 + 1].get(); - - idx_t tcount = - Select(*case_check.when_expr, check_state, current_sel, current_count, current_true_sel, current_false_sel); - if (tcount == 0) { - // everything is false: do nothing - continue; - } - idx_t fcount = current_count - tcount; - if (fcount == 0 && current_count == count) { - // everything is true in the first CHECK statement - // we can skip the entire case and only execute the TRUE side - Execute(*case_check.then_expr, then_state, sel, count, result); - return; - } else { - // we need to execute and then fill in the desired tuples in the result - Execute(*case_check.then_expr, then_state, current_true_sel, tcount, intermediate_result); - FillSwitch(intermediate_result, result, *current_true_sel, tcount); - } - // continue with the false tuples - current_sel = current_false_sel; - current_count = fcount; - if (fcount == 0) { - // everything is true: we are done - break; - } - } - if (current_count > 0) { - auto else_state = state.child_states.back().get(); - if (current_count == count) { - // everything was false, we can just evaluate the else expression directly - Execute(*expr.else_expr, else_state, sel, count, result); - return; - } else { - auto &intermediate_result = state.intermediate_chunk.data[expr.case_checks.size() * 2]; - - D_ASSERT(current_sel); - Execute(*expr.else_expr, else_state, current_sel, current_count, intermediate_result); - FillSwitch(intermediate_result, result, *current_sel, current_count); - } - } - if (sel) { - result.Slice(*sel, count); - } -} - -template -void 
TemplatedFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { - result.SetVectorType(VectorType::FLAT_VECTOR); - auto res = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { - auto data = ConstantVector::GetData(vector); - if (ConstantVector::IsNull(vector)) { - for (idx_t i = 0; i < count; i++) { - result_mask.SetInvalid(sel.get_index(i)); - } - } else { - for (idx_t i = 0; i < count; i++) { - res[sel.get_index(i)] = *data; - } - } - } else { - UnifiedVectorFormat vdata; - vector.ToUnifiedFormat(count, vdata); - auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { - auto source_idx = vdata.sel->get_index(i); - auto res_idx = sel.get_index(i); - - res[res_idx] = data[source_idx]; - result_mask.Set(res_idx, vdata.validity.RowIsValid(source_idx)); - } - } -} - -void ValidityFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { - result.SetVectorType(VectorType::FLAT_VECTOR); - auto &result_mask = FlatVector::Validity(result); - if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { - if (ConstantVector::IsNull(vector)) { - for (idx_t i = 0; i < count; i++) { - result_mask.SetInvalid(sel.get_index(i)); - } - } - } else { - UnifiedVectorFormat vdata; - vector.ToUnifiedFormat(count, vdata); - if (vdata.validity.AllValid()) { - return; - } - for (idx_t i = 0; i < count; i++) { - auto source_idx = vdata.sel->get_index(i); - if (!vdata.validity.RowIsValid(source_idx)) { - result_mask.SetInvalid(sel.get_index(i)); - } - } - } -} - -void ExpressionExecutor::FillSwitch(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { - switch (result.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::INT16: - TemplatedFillLoop(vector, result, sel, count); - break; - 
case PhysicalType::INT32: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::INT64: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::UINT8: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::UINT16: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::UINT32: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::UINT64: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::INT128: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::FLOAT: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::DOUBLE: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::INTERVAL: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::VARCHAR: - TemplatedFillLoop(vector, result, sel, count); - StringVector::AddHeapReference(result, vector); - break; - case PhysicalType::STRUCT: { - auto &vector_entries = StructVector::GetEntries(vector); - auto &result_entries = StructVector::GetEntries(result); - ValidityFillLoop(vector, result, sel, count); - D_ASSERT(vector_entries.size() == result_entries.size()); - for (idx_t i = 0; i < vector_entries.size(); i++) { - FillSwitch(*vector_entries[i], *result_entries[i], sel, count); - } - break; - } - case PhysicalType::LIST: { - idx_t offset = ListVector::GetListSize(result); - auto &list_child = ListVector::GetEntry(vector); - ListVector::Append(result, list_child, ListVector::GetListSize(vector)); - - // all the false offsets need to be incremented by true_child.count - TemplatedFillLoop(vector, result, sel, count); - if (offset == 0) { - break; - } - - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - auto result_idx = sel.get_index(i); - result_data[result_idx].offset += offset; - } - - Vector::Verify(result, sel, count); - break; - } - default: - 
throw NotImplementedException("Unimplemented type for case expression: %s", result.GetType().ToString()); - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr ExpressionExecutor::InitializeState(const BoundCastExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - result->AddChild(expr.child.get()); - result->Finalize(); - if (expr.bound_cast.init_local_state) { - CastLocalStateParameters parameters(root.executor->GetContext(), expr.bound_cast.cast_data); - result->local_state = expr.bound_cast.init_local_state(parameters); - } - return std::move(result); -} - -void ExpressionExecutor::Execute(const BoundCastExpression &expr, ExpressionState *state, const SelectionVector *sel, - idx_t count, Vector &result) { - auto lstate = ExecuteFunctionState::GetFunctionState(*state); - - // resolve the child - state->intermediate_chunk.Reset(); - - auto &child = state->intermediate_chunk.data[0]; - auto child_state = state->child_states[0].get(); - - Execute(*expr.child, child_state, sel, count, child); - if (expr.try_cast) { - string error_message; - CastParameters parameters(expr.bound_cast.cast_data.get(), false, &error_message, lstate); - expr.bound_cast.function(child, result, count, parameters); - } else { - // cast it to the type specified by the cast expression - D_ASSERT(result.GetType() == expr.return_type); - CastParameters parameters(expr.bound_cast.cast_data.get(), false, nullptr, lstate); - expr.bound_cast.function(child, result, count, parameters); - } -} - -} // namespace duckdb - - - - - - -#include - -namespace duckdb { - -unique_ptr ExpressionExecutor::InitializeState(const BoundComparisonExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - result->AddChild(expr.left.get()); - result->AddChild(expr.right.get()); - result->Finalize(); - return result; -} - -void ExpressionExecutor::Execute(const BoundComparisonExpression &expr, ExpressionState *state, - 
const SelectionVector *sel, idx_t count, Vector &result) { - // resolve the children - state->intermediate_chunk.Reset(); - auto &left = state->intermediate_chunk.data[0]; - auto &right = state->intermediate_chunk.data[1]; - - Execute(*expr.left, state->child_states[0].get(), sel, count, left); - Execute(*expr.right, state->child_states[1].get(), sel, count, right); - - switch (expr.type) { - case ExpressionType::COMPARE_EQUAL: - VectorOperations::Equals(left, right, result, count); - break; - case ExpressionType::COMPARE_NOTEQUAL: - VectorOperations::NotEquals(left, right, result, count); - break; - case ExpressionType::COMPARE_LESSTHAN: - VectorOperations::LessThan(left, right, result, count); - break; - case ExpressionType::COMPARE_GREATERTHAN: - VectorOperations::GreaterThan(left, right, result, count); - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - VectorOperations::LessThanEquals(left, right, result, count); - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - VectorOperations::GreaterThanEquals(left, right, result, count); - break; - case ExpressionType::COMPARE_DISTINCT_FROM: - VectorOperations::DistinctFrom(left, right, result, count); - break; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - VectorOperations::NotDistinctFrom(left, right, result, count); - break; - default: - throw InternalException("Unknown comparison type!"); - } -} - -template -static idx_t NestedSelectOperation(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel); - -template -static idx_t TemplatedSelectOperation(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - // the inplace loops take the result as the last parameter - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case 
PhysicalType::INT16: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case PhysicalType::INT32: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case PhysicalType::INT64: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case PhysicalType::UINT8: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case PhysicalType::UINT16: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case PhysicalType::UINT32: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case PhysicalType::UINT64: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case PhysicalType::INT128: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case PhysicalType::FLOAT: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case PhysicalType::DOUBLE: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case PhysicalType::INTERVAL: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case PhysicalType::VARCHAR: - return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); - case PhysicalType::LIST: - case PhysicalType::STRUCT: - return NestedSelectOperation(left, right, sel, count, true_sel, false_sel); - default: - throw InternalException("Invalid type for comparison"); - } -} - -struct NestedSelector { - // Select the matching rows for the values of a nested type that are not both NULL. 
- // Those semantics are the same as the corresponding non-distinct comparator - template - static idx_t Select(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - throw InvalidTypeException(left.GetType(), "Invalid operation for nested SELECT"); - } -}; - -template <> -idx_t NestedSelector::Select(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::NestedEquals(left, right, sel, count, true_sel, false_sel); -} - -template <> -idx_t NestedSelector::Select(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::NestedNotEquals(left, right, sel, count, true_sel, false_sel); -} - -template <> -idx_t NestedSelector::Select(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::DistinctLessThan(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t NestedSelector::Select(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::DistinctLessThanEquals(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t NestedSelector::Select(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t NestedSelector::Select(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - return VectorOperations::DistinctGreaterThanEquals(left, right, &sel, count, true_sel, false_sel); -} - -static inline idx_t 
SelectNotNull(Vector &left, Vector &right, const idx_t count, const SelectionVector &sel, - SelectionVector &maybe_vec, OptionalSelection &false_opt) { - - UnifiedVectorFormat lvdata, rvdata; - left.ToUnifiedFormat(count, lvdata); - right.ToUnifiedFormat(count, rvdata); - - auto &lmask = lvdata.validity; - auto &rmask = rvdata.validity; - - // For top-level comparisons, NULL semantics are in effect, - // so filter out any NULLs - idx_t remaining = 0; - if (lmask.AllValid() && rmask.AllValid()) { - // None are NULL, distinguish values. - for (idx_t i = 0; i < count; ++i) { - const auto idx = sel.get_index(i); - maybe_vec.set_index(remaining++, idx); - } - return remaining; - } - - // Slice the Vectors down to the rows that are not determined (i.e., neither is NULL) - SelectionVector slicer(count); - idx_t false_count = 0; - for (idx_t i = 0; i < count; ++i) { - const auto result_idx = sel.get_index(i); - const auto lidx = lvdata.sel->get_index(i); - const auto ridx = rvdata.sel->get_index(i); - if (!lmask.RowIsValid(lidx) || !rmask.RowIsValid(ridx)) { - false_opt.Append(false_count, result_idx); - } else { - // Neither is NULL, distinguish values. - slicer.set_index(remaining, i); - maybe_vec.set_index(remaining++, result_idx); - } - } - false_opt.Advance(false_count); - - if (remaining && remaining < count) { - left.Slice(slicer, remaining); - right.Slice(slicer, remaining); - } - - return remaining; -} - -static void ScatterSelection(SelectionVector *target, const idx_t count, const SelectionVector &dense_vec) { - if (target) { - for (idx_t i = 0; i < count; ++i) { - target->set_index(i, dense_vec.get_index(i)); - } - } -} - -template -static idx_t NestedSelectOperation(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - // The Select operations all use a dense pair of input vectors to partition - // a selection vector in a single pass. 
But to implement progressive comparisons, - // we have to make multiple passes, so we need to keep track of the original input positions - // and then scatter the output selections when we are done. - if (!sel) { - sel = FlatVector::IncrementalSelectionVector(); - } - - // Make buffered selections for progressive comparisons - // TODO: Remove unnecessary allocations - SelectionVector true_vec(count); - OptionalSelection true_opt(&true_vec); - - SelectionVector false_vec(count); - OptionalSelection false_opt(&false_vec); - - SelectionVector maybe_vec(count); - - // Handle NULL nested values - Vector l_not_null(left); - Vector r_not_null(right); - - auto match_count = SelectNotNull(l_not_null, r_not_null, count, *sel, maybe_vec, false_opt); - auto no_match_count = count - match_count; - count = match_count; - - // Now that we have handled the NULLs, we can use the recursive nested comparator for the rest. - match_count = NestedSelector::Select(l_not_null, r_not_null, maybe_vec, count, true_opt, false_opt); - no_match_count += (count - match_count); - - // Copy the buffered selections to the output selections - ScatterSelection(true_sel, match_count, true_vec); - ScatterSelection(false_sel, no_match_count, false_vec); - - return match_count; -} - -idx_t VectorOperations::Equals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return TemplatedSelectOperation(left, right, sel, count, true_sel, false_sel); -} - -idx_t VectorOperations::NotEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return TemplatedSelectOperation(left, right, sel, count, true_sel, false_sel); -} - -idx_t VectorOperations::GreaterThan(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return TemplatedSelectOperation(left, right, sel, count, true_sel, 
false_sel); -} - -idx_t VectorOperations::GreaterThanEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return TemplatedSelectOperation(left, right, sel, count, true_sel, false_sel); -} - -idx_t VectorOperations::LessThan(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return TemplatedSelectOperation(right, left, sel, count, true_sel, false_sel); -} - -idx_t VectorOperations::LessThanEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return TemplatedSelectOperation(right, left, sel, count, true_sel, false_sel); -} - -idx_t ExpressionExecutor::Select(const BoundComparisonExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - // resolve the children - state->intermediate_chunk.Reset(); - auto &left = state->intermediate_chunk.data[0]; - auto &right = state->intermediate_chunk.data[1]; - - Execute(*expr.left, state->child_states[0].get(), sel, count, left); - Execute(*expr.right, state->child_states[1].get(), sel, count, right); - - switch (expr.type) { - case ExpressionType::COMPARE_EQUAL: - return VectorOperations::Equals(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_NOTEQUAL: - return VectorOperations::NotEquals(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_LESSTHAN: - return VectorOperations::LessThan(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_GREATERTHAN: - return VectorOperations::GreaterThan(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return VectorOperations::LessThanEquals(left, right, sel, count, true_sel, false_sel); - case 
ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_DISTINCT_FROM: - return VectorOperations::DistinctFrom(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - return VectorOperations::NotDistinctFrom(left, right, sel, count, true_sel, false_sel); - default: - throw InternalException("Unknown comparison type!"); - } -} - -} // namespace duckdb - - - - - - -#include - -namespace duckdb { - -struct ConjunctionState : public ExpressionState { - ConjunctionState(const Expression &expr, ExpressionExecutorState &root) : ExpressionState(expr, root) { - adaptive_filter = make_uniq(expr); - } - unique_ptr adaptive_filter; -}; - -unique_ptr ExpressionExecutor::InitializeState(const BoundConjunctionExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - for (auto &child : expr.children) { - result->AddChild(child.get()); - } - result->Finalize(); - return std::move(result); -} - -void ExpressionExecutor::Execute(const BoundConjunctionExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - // execute the children - state->intermediate_chunk.Reset(); - for (idx_t i = 0; i < expr.children.size(); i++) { - auto ¤t_result = state->intermediate_chunk.data[i]; - Execute(*expr.children[i], state->child_states[i].get(), sel, count, current_result); - if (i == 0) { - // move the result - result.Reference(current_result); - } else { - Vector intermediate(LogicalType::BOOLEAN); - // AND/OR together - switch (expr.type) { - case ExpressionType::CONJUNCTION_AND: - VectorOperations::And(current_result, result, intermediate, count); - break; - case ExpressionType::CONJUNCTION_OR: - VectorOperations::Or(current_result, result, intermediate, count); - break; - default: - throw InternalException("Unknown conjunction type!"); - } - 
result.Reference(intermediate); - } - } -} - -idx_t ExpressionExecutor::Select(const BoundConjunctionExpression &expr, ExpressionState *state_p, - const SelectionVector *sel, idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - auto &state = state_p->Cast(); - - if (expr.type == ExpressionType::CONJUNCTION_AND) { - // get runtime statistics - auto start_time = high_resolution_clock::now(); - - const SelectionVector *current_sel = sel; - idx_t current_count = count; - idx_t false_count = 0; - - unique_ptr temp_true, temp_false; - if (false_sel) { - temp_false = make_uniq(STANDARD_VECTOR_SIZE); - } - if (!true_sel) { - temp_true = make_uniq(STANDARD_VECTOR_SIZE); - true_sel = temp_true.get(); - } - for (idx_t i = 0; i < expr.children.size(); i++) { - idx_t tcount = Select(*expr.children[state.adaptive_filter->permutation[i]], - state.child_states[state.adaptive_filter->permutation[i]].get(), current_sel, - current_count, true_sel, temp_false.get()); - idx_t fcount = current_count - tcount; - if (fcount > 0 && false_sel) { - // move failing tuples into the false_sel - // tuples passed, move them into the actual result vector - for (idx_t i = 0; i < fcount; i++) { - false_sel->set_index(false_count++, temp_false->get_index(i)); - } - } - current_count = tcount; - if (current_count == 0) { - break; - } - if (current_count < count) { - // tuples were filtered out: move on to using the true_sel to only evaluate passing tuples in subsequent - // iterations - current_sel = true_sel; - } - } - - // adapt runtime statistics - auto end_time = high_resolution_clock::now(); - state.adaptive_filter->AdaptRuntimeStatistics(duration_cast>(end_time - start_time).count()); - return current_count; - } else { - // get runtime statistics - auto start_time = high_resolution_clock::now(); - - const SelectionVector *current_sel = sel; - idx_t current_count = count; - idx_t result_count = 0; - - unique_ptr temp_true, temp_false; - if (true_sel) { - temp_true = 
make_uniq(STANDARD_VECTOR_SIZE); - } - if (!false_sel) { - temp_false = make_uniq(STANDARD_VECTOR_SIZE); - false_sel = temp_false.get(); - } - for (idx_t i = 0; i < expr.children.size(); i++) { - idx_t tcount = Select(*expr.children[state.adaptive_filter->permutation[i]], - state.child_states[state.adaptive_filter->permutation[i]].get(), current_sel, - current_count, temp_true.get(), false_sel); - if (tcount > 0) { - if (true_sel) { - // tuples passed, move them into the actual result vector - for (idx_t i = 0; i < tcount; i++) { - true_sel->set_index(result_count++, temp_true->get_index(i)); - } - } - // now move on to check only the non-passing tuples - current_count -= tcount; - current_sel = false_sel; - } - } - - // adapt runtime statistics - auto end_time = high_resolution_clock::now(); - state.adaptive_filter->AdaptRuntimeStatistics(duration_cast>(end_time - start_time).count()); - return result_count; - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr ExpressionExecutor::InitializeState(const BoundConstantExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - result->Finalize(); - return result; -} - -void ExpressionExecutor::Execute(const BoundConstantExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - D_ASSERT(expr.value.type() == expr.return_type); - result.Reference(expr.value); -} - -} // namespace duckdb - - - -namespace duckdb { - -ExecuteFunctionState::ExecuteFunctionState(const Expression &expr, ExpressionExecutorState &root) - : ExpressionState(expr, root) { -} - -ExecuteFunctionState::~ExecuteFunctionState() { -} - -unique_ptr ExpressionExecutor::InitializeState(const BoundFunctionExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - for (auto &child : expr.children) { - result->AddChild(child.get()); - } - result->Finalize(); - if (expr.function.init_local_state) { - result->local_state = 
expr.function.init_local_state(*result, expr, expr.bind_info.get()); - } - return std::move(result); -} - -static void VerifyNullHandling(const BoundFunctionExpression &expr, DataChunk &args, Vector &result) { -#ifdef DEBUG - if (args.data.empty() || expr.function.null_handling != FunctionNullHandling::DEFAULT_NULL_HANDLING) { - return; - } - - // Combine all the argument validity masks into a flat validity mask - idx_t count = args.size(); - ValidityMask combined_mask(count); - for (auto &arg : args.data) { - UnifiedVectorFormat arg_data; - arg.ToUnifiedFormat(count, arg_data); - - for (idx_t i = 0; i < count; i++) { - auto idx = arg_data.sel->get_index(i); - if (!arg_data.validity.RowIsValid(idx)) { - combined_mask.SetInvalid(i); - } - } - } - - // Default is that if any of the arguments are NULL, the result is also NULL - UnifiedVectorFormat result_data; - result.ToUnifiedFormat(count, result_data); - for (idx_t i = 0; i < count; i++) { - if (!combined_mask.RowIsValid(i)) { - auto idx = result_data.sel->get_index(i); - D_ASSERT(!result_data.validity.RowIsValid(idx)); - } - } -#endif -} - -void ExpressionExecutor::Execute(const BoundFunctionExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - state->intermediate_chunk.Reset(); - auto &arguments = state->intermediate_chunk; - if (!state->types.empty()) { - for (idx_t i = 0; i < expr.children.size(); i++) { - D_ASSERT(state->types[i] == expr.children[i]->return_type); - Execute(*expr.children[i], state->child_states[i].get(), sel, count, arguments.data[i]); -#ifdef DEBUG - if (expr.children[i]->return_type.id() == LogicalTypeId::VARCHAR) { - arguments.data[i].UTFVerify(count); - } -#endif - } - arguments.Verify(); - } - arguments.SetCardinality(count); - - state->profiler.BeginSample(); - D_ASSERT(expr.function.function); - expr.function.function(arguments, *state, result); - state->profiler.EndSample(count); - - VerifyNullHandling(expr, arguments, result); - 
D_ASSERT(result.GetType() == expr.return_type); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr ExpressionExecutor::InitializeState(const BoundOperatorExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - for (auto &child : expr.children) { - result->AddChild(child.get()); - } - result->Finalize(); - return result; -} - -void ExpressionExecutor::Execute(const BoundOperatorExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - // special handling for special snowflake 'IN' - // IN has n children - if (expr.type == ExpressionType::COMPARE_IN || expr.type == ExpressionType::COMPARE_NOT_IN) { - if (expr.children.size() < 2) { - throw InvalidInputException("IN needs at least two children"); - } - - Vector left(expr.children[0]->return_type); - // eval left side - Execute(*expr.children[0], state->child_states[0].get(), sel, count, left); - - // init result to false - Vector intermediate(LogicalType::BOOLEAN); - Value false_val = Value::BOOLEAN(false); - intermediate.Reference(false_val); - - // in rhs is a list of constants - // for every child, OR the result of the comparision with the left - // to get the overall result. 
- for (idx_t child = 1; child < expr.children.size(); child++) { - Vector vector_to_check(expr.children[child]->return_type); - Vector comp_res(LogicalType::BOOLEAN); - - Execute(*expr.children[child], state->child_states[child].get(), sel, count, vector_to_check); - VectorOperations::Equals(left, vector_to_check, comp_res, count); - - if (child == 1) { - // first child: move to result - intermediate.Reference(comp_res); - } else { - // otherwise OR together - Vector new_result(LogicalType::BOOLEAN, true, false); - VectorOperations::Or(intermediate, comp_res, new_result, count); - intermediate.Reference(new_result); - } - } - if (expr.type == ExpressionType::COMPARE_NOT_IN) { - // NOT IN: invert result - VectorOperations::Not(intermediate, result, count); - } else { - // directly use the result - result.Reference(intermediate); - } - } else if (expr.type == ExpressionType::OPERATOR_COALESCE) { - SelectionVector sel_a(count); - SelectionVector sel_b(count); - SelectionVector slice_sel(count); - SelectionVector result_sel(count); - SelectionVector *next_sel = &sel_a; - const SelectionVector *current_sel = sel; - idx_t remaining_count = count; - idx_t next_count; - for (idx_t child = 0; child < expr.children.size(); child++) { - Vector vector_to_check(expr.children[child]->return_type); - Execute(*expr.children[child], state->child_states[child].get(), current_sel, remaining_count, - vector_to_check); - - UnifiedVectorFormat vdata; - vector_to_check.ToUnifiedFormat(remaining_count, vdata); - - idx_t result_count = 0; - next_count = 0; - for (idx_t i = 0; i < remaining_count; i++) { - auto base_idx = current_sel ? 
current_sel->get_index(i) : i; - auto idx = vdata.sel->get_index(i); - if (vdata.validity.RowIsValid(idx)) { - slice_sel.set_index(result_count, i); - result_sel.set_index(result_count++, base_idx); - } else { - next_sel->set_index(next_count++, base_idx); - } - } - if (result_count > 0) { - vector_to_check.Slice(slice_sel, result_count); - FillSwitch(vector_to_check, result, result_sel, result_count); - } - current_sel = next_sel; - next_sel = next_sel == &sel_a ? &sel_b : &sel_a; - remaining_count = next_count; - if (next_count == 0) { - break; - } - } - if (remaining_count > 0) { - for (idx_t i = 0; i < remaining_count; i++) { - FlatVector::SetNull(result, current_sel->get_index(i), true); - } - } - if (sel) { - result.Slice(*sel, count); - } else if (count == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - } else if (expr.children.size() == 1) { - state->intermediate_chunk.Reset(); - auto &child = state->intermediate_chunk.data[0]; - - Execute(*expr.children[0], state->child_states[0].get(), sel, count, child); - switch (expr.type) { - case ExpressionType::OPERATOR_NOT: { - VectorOperations::Not(child, result, count); - break; - } - case ExpressionType::OPERATOR_IS_NULL: { - VectorOperations::IsNull(child, result, count); - break; - } - case ExpressionType::OPERATOR_IS_NOT_NULL: { - VectorOperations::IsNotNull(child, result, count); - break; - } - default: - throw NotImplementedException("Unsupported operator type with 1 child!"); - } - } else { - throw NotImplementedException("operator"); - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr ExpressionExecutor::InitializeState(const BoundParameterExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - result->Finalize(); - return result; -} - -void ExpressionExecutor::Execute(const BoundParameterExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - D_ASSERT(expr.parameter_data); - 
D_ASSERT(expr.parameter_data->return_type == expr.return_type); - D_ASSERT(expr.parameter_data->GetValue().type() == expr.return_type); - result.Reference(expr.parameter_data->GetValue()); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr ExpressionExecutor::InitializeState(const BoundReferenceExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - result->Finalize(); - return result; -} - -void ExpressionExecutor::Execute(const BoundReferenceExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - D_ASSERT(expr.index != DConstants::INVALID_INDEX); - D_ASSERT(expr.index < chunk->ColumnCount()); - - if (sel) { - result.Slice(chunk->data[expr.index], *sel, count); - } else { - result.Reference(chunk->data[expr.index]); - } -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -ExpressionExecutor::ExpressionExecutor(ClientContext &context) : context(&context) { -} - -ExpressionExecutor::ExpressionExecutor(ClientContext &context, const Expression *expression) - : ExpressionExecutor(context) { - D_ASSERT(expression); - AddExpression(*expression); -} - -ExpressionExecutor::ExpressionExecutor(ClientContext &context, const Expression &expression) - : ExpressionExecutor(context) { - AddExpression(expression); -} - -ExpressionExecutor::ExpressionExecutor(ClientContext &context, const vector> &exprs) - : ExpressionExecutor(context) { - D_ASSERT(exprs.size() > 0); - for (auto &expr : exprs) { - AddExpression(*expr); - } -} - -ExpressionExecutor::ExpressionExecutor(const vector> &exprs) : context(nullptr) { - D_ASSERT(exprs.size() > 0); - for (auto &expr : exprs) { - AddExpression(*expr); - } -} - -ExpressionExecutor::ExpressionExecutor() : context(nullptr) { -} - -bool ExpressionExecutor::HasContext() { - return context; -} - -ClientContext &ExpressionExecutor::GetContext() { - if (!context) { - throw InternalException("Calling ExpressionExecutor::GetContext on an 
expression executor without a context"); - } - return *context; -} - -Allocator &ExpressionExecutor::GetAllocator() { - return context ? Allocator::Get(*context) : Allocator::DefaultAllocator(); -} - -void ExpressionExecutor::AddExpression(const Expression &expr) { - expressions.push_back(&expr); - auto state = make_uniq(); - Initialize(expr, *state); - state->Verify(); - states.push_back(std::move(state)); -} - -void ExpressionExecutor::Initialize(const Expression &expression, ExpressionExecutorState &state) { - state.executor = this; - state.root_state = InitializeState(expression, state); -} - -void ExpressionExecutor::Execute(DataChunk *input, DataChunk &result) { - SetChunk(input); - D_ASSERT(expressions.size() == result.ColumnCount()); - D_ASSERT(!expressions.empty()); - - for (idx_t i = 0; i < expressions.size(); i++) { - ExecuteExpression(i, result.data[i]); - } - result.SetCardinality(input ? input->size() : 1); - result.Verify(); -} - -void ExpressionExecutor::ExecuteExpression(DataChunk &input, Vector &result) { - SetChunk(&input); - ExecuteExpression(result); -} - -idx_t ExpressionExecutor::SelectExpression(DataChunk &input, SelectionVector &sel) { - D_ASSERT(expressions.size() == 1); - SetChunk(&input); - states[0]->profiler.BeginSample(); - idx_t selected_tuples = Select(*expressions[0], states[0]->root_state.get(), nullptr, input.size(), &sel, nullptr); - states[0]->profiler.EndSample(chunk ? chunk->size() : 0); - return selected_tuples; -} - -void ExpressionExecutor::ExecuteExpression(Vector &result) { - D_ASSERT(expressions.size() == 1); - ExecuteExpression(0, result); -} - -void ExpressionExecutor::ExecuteExpression(idx_t expr_idx, Vector &result) { - D_ASSERT(expr_idx < expressions.size()); - D_ASSERT(result.GetType().id() == expressions[expr_idx]->return_type.id()); - states[expr_idx]->profiler.BeginSample(); - Execute(*expressions[expr_idx], states[expr_idx]->root_state.get(), nullptr, chunk ? 
chunk->size() : 1, result); - states[expr_idx]->profiler.EndSample(chunk ? chunk->size() : 0); -} - -Value ExpressionExecutor::EvaluateScalar(ClientContext &context, const Expression &expr, bool allow_unfoldable) { - D_ASSERT(allow_unfoldable || expr.IsFoldable()); - D_ASSERT(expr.IsScalar()); - // use an ExpressionExecutor to execute the expression - ExpressionExecutor executor(context, expr); - - Vector result(expr.return_type); - executor.ExecuteExpression(result); - - D_ASSERT(allow_unfoldable || result.GetVectorType() == VectorType::CONSTANT_VECTOR); - auto result_value = result.GetValue(0); - D_ASSERT(result_value.type().InternalType() == expr.return_type.InternalType()); - return result_value; -} - -bool ExpressionExecutor::TryEvaluateScalar(ClientContext &context, const Expression &expr, Value &result) { - try { - result = EvaluateScalar(context, expr); - return true; - } catch (InternalException &ex) { - throw; - } catch (...) { - return false; - } -} - -void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t count) { - D_ASSERT(expr.return_type.id() == vector.GetType().id()); - vector.Verify(count); - if (expr.verification_stats) { - expr.verification_stats->Verify(vector, count); - } -} - -unique_ptr ExpressionExecutor::InitializeState(const Expression &expr, - ExpressionExecutorState &state) { - switch (expr.expression_class) { - case ExpressionClass::BOUND_REF: - return InitializeState(expr.Cast(), state); - case ExpressionClass::BOUND_BETWEEN: - return InitializeState(expr.Cast(), state); - case ExpressionClass::BOUND_CASE: - return InitializeState(expr.Cast(), state); - case ExpressionClass::BOUND_CAST: - return InitializeState(expr.Cast(), state); - case ExpressionClass::BOUND_COMPARISON: - return InitializeState(expr.Cast(), state); - case ExpressionClass::BOUND_CONJUNCTION: - return InitializeState(expr.Cast(), state); - case ExpressionClass::BOUND_CONSTANT: - return InitializeState(expr.Cast(), state); - case 
ExpressionClass::BOUND_FUNCTION: - return InitializeState(expr.Cast(), state); - case ExpressionClass::BOUND_OPERATOR: - return InitializeState(expr.Cast(), state); - case ExpressionClass::BOUND_PARAMETER: - return InitializeState(expr.Cast(), state); - default: - throw InternalException("Attempting to initialize state of expression of unknown type!"); - } -} - -void ExpressionExecutor::Execute(const Expression &expr, ExpressionState *state, const SelectionVector *sel, - idx_t count, Vector &result) { -#ifdef DEBUG - //! The result Vector must be "clean" - if (result.GetVectorType() == VectorType::FLAT_VECTOR) { - D_ASSERT(FlatVector::Validity(result).CheckAllValid(count)); - } -#endif - - if (count == 0) { - return; - } - if (result.GetType().id() != expr.return_type.id()) { - throw InternalException( - "ExpressionExecutor::Execute called with a result vector of type %s that does not match expression type %s", - result.GetType(), expr.return_type); - } - switch (expr.expression_class) { - case ExpressionClass::BOUND_BETWEEN: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_REF: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_CASE: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_CAST: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_COMPARISON: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_CONJUNCTION: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_CONSTANT: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_FUNCTION: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_OPERATOR: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_PARAMETER: - Execute(expr.Cast(), state, sel, count, result); - break; - 
default: - throw InternalException("Attempting to execute expression of unknown type!"); - } - Verify(expr, result, count); -} - -idx_t ExpressionExecutor::Select(const Expression &expr, ExpressionState *state, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { - if (count == 0) { - return 0; - } - D_ASSERT(true_sel || false_sel); - D_ASSERT(expr.return_type.id() == LogicalTypeId::BOOLEAN); - switch (expr.expression_class) { - case ExpressionClass::BOUND_BETWEEN: - return Select(expr.Cast(), state, sel, count, true_sel, false_sel); - case ExpressionClass::BOUND_COMPARISON: - return Select(expr.Cast(), state, sel, count, true_sel, false_sel); - case ExpressionClass::BOUND_CONJUNCTION: - return Select(expr.Cast(), state, sel, count, true_sel, false_sel); - default: - return DefaultSelect(expr, state, sel, count, true_sel, false_sel); - } -} - -template -static inline idx_t DefaultSelectLoop(const SelectionVector *bsel, const uint8_t *__restrict bdata, ValidityMask &mask, - const SelectionVector *sel, idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - idx_t true_count = 0, false_count = 0; - for (idx_t i = 0; i < count; i++) { - auto bidx = bsel->get_index(i); - auto result_idx = sel->get_index(i); - if (bdata[bidx] > 0 && (NO_NULL || mask.RowIsValid(bidx))) { - if (HAS_TRUE_SEL) { - true_sel->set_index(true_count++, result_idx); - } - } else { - if (HAS_FALSE_SEL) { - false_sel->set_index(false_count++, result_idx); - } - } - } - if (HAS_TRUE_SEL) { - return true_count; - } else { - return count - false_count; - } -} - -template -static inline idx_t DefaultSelectSwitch(UnifiedVectorFormat &idata, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - if (true_sel && false_sel) { - return DefaultSelectLoop(idata.sel, UnifiedVectorFormat::GetData(idata), - idata.validity, sel, count, true_sel, false_sel); - } else if (true_sel) { - return 
DefaultSelectLoop(idata.sel, UnifiedVectorFormat::GetData(idata), - idata.validity, sel, count, true_sel, false_sel); - } else { - D_ASSERT(false_sel); - return DefaultSelectLoop(idata.sel, UnifiedVectorFormat::GetData(idata), - idata.validity, sel, count, true_sel, false_sel); - } -} - -idx_t ExpressionExecutor::DefaultSelect(const Expression &expr, ExpressionState *state, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { - // generic selection of boolean expression: - // resolve the true/false expression first - // then use that to generate the selection vector - bool intermediate_bools[STANDARD_VECTOR_SIZE]; - Vector intermediate(LogicalType::BOOLEAN, data_ptr_cast(intermediate_bools)); - Execute(expr, state, sel, count, intermediate); - - UnifiedVectorFormat idata; - intermediate.ToUnifiedFormat(count, idata); - - if (!sel) { - sel = FlatVector::IncrementalSelectionVector(); - } - if (!idata.validity.AllValid()) { - return DefaultSelectSwitch(idata, sel, count, true_sel, false_sel); - } else { - return DefaultSelectSwitch(idata, sel, count, true_sel, false_sel); - } -} - -vector> &ExpressionExecutor::GetStates() { - return states; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -void ExpressionState::AddChild(Expression *expr) { - types.push_back(expr->return_type); - child_states.push_back(ExpressionExecutor::InitializeState(*expr, root)); -} - -void ExpressionState::Finalize() { - if (!types.empty()) { - intermediate_chunk.Initialize(GetAllocator(), types); - } -} - -Allocator &ExpressionState::GetAllocator() { - return root.executor->GetAllocator(); -} - -bool ExpressionState::HasContext() { - return root.executor->HasContext(); -} - -ClientContext &ExpressionState::GetContext() { - if (!HasContext()) { - throw BinderException("Cannot use %s in this context", (expr.Cast()).function.name); - } - return root.executor->GetContext(); -} - -ExpressionState::ExpressionState(const Expression &expr, 
ExpressionExecutorState &root) : expr(expr), root(root) { -} - -ExpressionExecutorState::ExpressionExecutorState() : profiler() { -} - -void ExpressionState::Verify(ExpressionExecutorState &root_executor) { - D_ASSERT(&root_executor == &root); - for (auto &entry : child_states) { - entry->Verify(root_executor); - } -} - -void ExpressionExecutorState::Verify() { - D_ASSERT(executor); - root_state->Verify(*this); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -struct ARTIndexScanState : public IndexScanState { - - //! Scan predicates (single predicate scan or range scan) - Value values[2]; - //! Expressions of the scan predicates - ExpressionType expressions[2]; - bool checked = false; - //! All scanned row IDs - vector result_ids; - Iterator iterator; -}; - -ART::ART(const vector &column_ids, TableIOManager &table_io_manager, - const vector> &unbound_expressions, const IndexConstraintType constraint_type, - AttachedDatabase &db, const shared_ptr, ALLOCATOR_COUNT>> &allocators_ptr, - const BlockPointer &pointer) - : Index(db, IndexType::ART, table_io_manager, column_ids, unbound_expressions, constraint_type), - allocators(allocators_ptr), owns_data(false) { - if (!Radix::IsLittleEndian()) { - throw NotImplementedException("ART indexes are not supported on big endian architectures"); - } - - // initialize all allocators - if (!allocators) { - owns_data = true; - auto &block_manager = table_io_manager.GetIndexBlockManager(); - - array, ALLOCATOR_COUNT> allocator_array = { - make_uniq(sizeof(Prefix), block_manager), - make_uniq(sizeof(Leaf), block_manager), - make_uniq(sizeof(Node4), block_manager), - make_uniq(sizeof(Node16), block_manager), - make_uniq(sizeof(Node48), block_manager), - make_uniq(sizeof(Node256), block_manager)}; - allocators = make_shared, ALLOCATOR_COUNT>>(std::move(allocator_array)); - } - - if (pointer.IsValid()) { - Deserialize(pointer); - } - - // validate the types of the key columns - for 
(idx_t i = 0; i < types.size(); i++) { - switch (types[i]) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - case PhysicalType::INT128: - case PhysicalType::UINT8: - case PhysicalType::UINT16: - case PhysicalType::UINT32: - case PhysicalType::UINT64: - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - case PhysicalType::VARCHAR: - break; - default: - throw InvalidTypeException(logical_types[i], "Invalid type for index key."); - } - } -} - -//===--------------------------------------------------------------------===// -// Initialize Predicate Scans -//===--------------------------------------------------------------------===// - -unique_ptr ART::InitializeScanSinglePredicate(const Transaction &transaction, const Value &value, - const ExpressionType expression_type) { - // initialize point lookup - auto result = make_uniq(); - result->values[0] = value; - result->expressions[0] = expression_type; - return std::move(result); -} - -unique_ptr ART::InitializeScanTwoPredicates(const Transaction &transaction, const Value &low_value, - const ExpressionType low_expression_type, - const Value &high_value, - const ExpressionType high_expression_type) { - // initialize range lookup - auto result = make_uniq(); - result->values[0] = low_value; - result->expressions[0] = low_expression_type; - result->values[1] = high_value; - result->expressions[1] = high_expression_type; - return std::move(result); -} - -//===--------------------------------------------------------------------===// -// Keys -//===--------------------------------------------------------------------===// - -template -static void TemplatedGenerateKeys(ArenaAllocator &allocator, Vector &input, idx_t count, vector &keys) { - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - - D_ASSERT(keys.size() >= count); - auto input_data = UnifiedVectorFormat::GetData(idata); - for (idx_t i = 0; i < count; i++) { 
- auto idx = idata.sel->get_index(i); - if (idata.validity.RowIsValid(idx)) { - ARTKey::CreateARTKey(allocator, input.GetType(), keys[i], input_data[idx]); - } else { - // we need to possibly reset the former key value in the keys vector - keys[i] = ARTKey(); - } - } -} - -template -static void ConcatenateKeys(ArenaAllocator &allocator, Vector &input, idx_t count, vector &keys) { - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - - auto input_data = UnifiedVectorFormat::GetData(idata); - for (idx_t i = 0; i < count; i++) { - auto idx = idata.sel->get_index(i); - - // key is not NULL (no previous column entry was NULL) - if (!keys[i].Empty()) { - if (!idata.validity.RowIsValid(idx)) { - // this column entry is NULL, set whole key to NULL - keys[i] = ARTKey(); - } else { - auto other_key = ARTKey::CreateARTKey(allocator, input.GetType(), input_data[idx]); - keys[i].ConcatenateARTKey(allocator, other_key); - } - } - } -} - -void ART::GenerateKeys(ArenaAllocator &allocator, DataChunk &input, vector &keys) { - // generate keys for the first input column - switch (input.data[0].GetType().InternalType()) { - case PhysicalType::BOOL: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::INT8: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::INT16: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::INT32: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::INT64: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::INT128: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::UINT8: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::UINT16: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; 
- case PhysicalType::UINT32: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::UINT64: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::FLOAT: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::DOUBLE: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::VARCHAR: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - default: - throw InternalException("Invalid type for index"); - } - - for (idx_t i = 1; i < input.ColumnCount(); i++) { - // for each of the remaining columns, concatenate - switch (input.data[i].GetType().InternalType()) { - case PhysicalType::BOOL: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::INT8: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::INT16: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::INT32: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::INT64: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::INT128: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::UINT8: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::UINT16: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::UINT32: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::UINT64: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::FLOAT: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::DOUBLE: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - 
case PhysicalType::VARCHAR: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - default: - throw InternalException("Invalid type for index"); - } - } -} - -//===--------------------------------------------------------------------===// -// Construct from sorted data (only during CREATE (UNIQUE) INDEX statements) -//===--------------------------------------------------------------------===// - -struct KeySection { - KeySection(idx_t start_p, idx_t end_p, idx_t depth_p, data_t key_byte_p) - : start(start_p), end(end_p), depth(depth_p), key_byte(key_byte_p) {}; - KeySection(idx_t start_p, idx_t end_p, vector &keys, KeySection &key_section) - : start(start_p), end(end_p), depth(key_section.depth + 1), key_byte(keys[end_p].data[key_section.depth]) {}; - idx_t start; - idx_t end; - idx_t depth; - data_t key_byte; -}; - -void GetChildSections(vector &child_sections, vector &keys, KeySection &key_section) { - - idx_t child_start_idx = key_section.start; - for (idx_t i = key_section.start + 1; i <= key_section.end; i++) { - if (keys[i - 1].data[key_section.depth] != keys[i].data[key_section.depth]) { - child_sections.emplace_back(child_start_idx, i - 1, keys, key_section); - child_start_idx = i; - } - } - child_sections.emplace_back(child_start_idx, key_section.end, keys, key_section); -} - -bool Construct(ART &art, vector &keys, row_t *row_ids, Node &node, KeySection &key_section, - bool &has_constraint) { - - D_ASSERT(key_section.start < keys.size()); - D_ASSERT(key_section.end < keys.size()); - D_ASSERT(key_section.start <= key_section.end); - - auto &start_key = keys[key_section.start]; - auto &end_key = keys[key_section.end]; - - // increment the depth until we reach a leaf or find a mismatching byte - auto prefix_start = key_section.depth; - while (start_key.len != key_section.depth && start_key.ByteMatches(end_key, key_section.depth)) { - key_section.depth++; - } - - // we reached a leaf, i.e. 
all the bytes of start_key and end_key match - if (start_key.len == key_section.depth) { - // end_idx is inclusive - auto num_row_ids = key_section.end - key_section.start + 1; - - // check for possible constraint violation - auto single_row_id = num_row_ids == 1; - if (has_constraint && !single_row_id) { - return false; - } - - reference ref_node(node); - Prefix::New(art, ref_node, start_key, prefix_start, start_key.len - prefix_start); - if (single_row_id) { - Leaf::New(ref_node, row_ids[key_section.start]); - } else { - Leaf::New(art, ref_node, row_ids + key_section.start, num_row_ids); - } - return true; - } - - // create a new node and recurse - - // we will find at least two child entries of this node, otherwise we'd have reached a leaf - vector child_sections; - GetChildSections(child_sections, keys, key_section); - - // set the prefix - reference ref_node(node); - auto prefix_length = key_section.depth - prefix_start; - Prefix::New(art, ref_node, start_key, prefix_start, prefix_length); - - // set the node - auto node_type = Node::GetARTNodeTypeByCount(child_sections.size()); - Node::New(art, ref_node, node_type); - - // recurse on each child section - for (auto &child_section : child_sections) { - Node new_child; - auto no_violation = Construct(art, keys, row_ids, new_child, child_section, has_constraint); - Node::InsertChild(art, ref_node, child_section.key_byte, new_child); - if (!no_violation) { - return false; - } - } - return true; -} - -bool ART::ConstructFromSorted(idx_t count, vector &keys, Vector &row_identifiers) { - - // prepare the row_identifiers - row_identifiers.Flatten(count); - auto row_ids = FlatVector::GetData(row_identifiers); - - auto key_section = KeySection(0, count - 1, 0, 0); - auto has_constraint = IsUnique(); - if (!Construct(*this, keys, row_ids, tree, key_section, has_constraint)) { - return false; - } - -#ifdef DEBUG - D_ASSERT(!VerifyAndToStringInternal(true).empty()); - for (idx_t i = 0; i < count; i++) { - 
D_ASSERT(!keys[i].Empty()); - auto leaf = Lookup(tree, keys[i], 0); - D_ASSERT(Leaf::ContainsRowId(*this, *leaf, row_ids[i])); - } -#endif - - return true; -} - -//===--------------------------------------------------------------------===// -// Insert / Verification / Constraint Checking -//===--------------------------------------------------------------------===// -PreservedError ART::Insert(IndexLock &lock, DataChunk &input, Vector &row_ids) { - - D_ASSERT(row_ids.GetType().InternalType() == ROW_TYPE); - D_ASSERT(logical_types[0] == input.data[0].GetType()); - - // generate the keys for the given input - ArenaAllocator arena_allocator(BufferAllocator::Get(db)); - vector keys(input.size()); - GenerateKeys(arena_allocator, input, keys); - - // get the corresponding row IDs - row_ids.Flatten(input.size()); - auto row_identifiers = FlatVector::GetData(row_ids); - - // now insert the elements into the index - idx_t failed_index = DConstants::INVALID_INDEX; - for (idx_t i = 0; i < input.size(); i++) { - if (keys[i].Empty()) { - continue; - } - - row_t row_id = row_identifiers[i]; - if (!Insert(tree, keys[i], 0, row_id)) { - // failed to insert because of constraint violation - failed_index = i; - break; - } - } - - // failed to insert because of constraint violation: remove previously inserted entries - if (failed_index != DConstants::INVALID_INDEX) { - for (idx_t i = 0; i < failed_index; i++) { - if (keys[i].Empty()) { - continue; - } - row_t row_id = row_identifiers[i]; - Erase(tree, keys[i], 0, row_id); - } - } - - if (failed_index != DConstants::INVALID_INDEX) { - return PreservedError(ConstraintException("PRIMARY KEY or UNIQUE constraint violated: duplicate key \"%s\"", - AppendRowError(input, failed_index))); - } - -#ifdef DEBUG - for (idx_t i = 0; i < input.size(); i++) { - if (keys[i].Empty()) { - continue; - } - - auto leaf = Lookup(tree, keys[i], 0); - D_ASSERT(Leaf::ContainsRowId(*this, *leaf, row_identifiers[i])); - } -#endif - - return PreservedError(); 
-} - -PreservedError ART::Append(IndexLock &lock, DataChunk &appended_data, Vector &row_identifiers) { - DataChunk expression_result; - expression_result.Initialize(Allocator::DefaultAllocator(), logical_types); - - // first resolve the expressions for the index - ExecuteExpressions(appended_data, expression_result); - - // now insert into the index - return Insert(lock, expression_result, row_identifiers); -} - -void ART::VerifyAppend(DataChunk &chunk) { - ConflictManager conflict_manager(VerifyExistenceType::APPEND, chunk.size()); - CheckConstraintsForChunk(chunk, conflict_manager); -} - -void ART::VerifyAppend(DataChunk &chunk, ConflictManager &conflict_manager) { - D_ASSERT(conflict_manager.LookupType() == VerifyExistenceType::APPEND); - CheckConstraintsForChunk(chunk, conflict_manager); -} - -bool ART::InsertToLeaf(Node &leaf, const row_t &row_id) { - - if (IsUnique()) { - return false; - } - - Leaf::Insert(*this, leaf, row_id); - return true; -} - -bool ART::Insert(Node &node, const ARTKey &key, idx_t depth, const row_t &row_id) { - - // node is currently empty, create a leaf here with the key - if (!node.HasMetadata()) { - D_ASSERT(depth <= key.len); - reference ref_node(node); - Prefix::New(*this, ref_node, key, depth, key.len - depth); - Leaf::New(ref_node, row_id); - return true; - } - - auto node_type = node.GetType(); - - // insert the row ID into this leaf - if (node_type == NType::LEAF || node_type == NType::LEAF_INLINED) { - return InsertToLeaf(node, row_id); - } - - if (node_type != NType::PREFIX) { - D_ASSERT(depth < key.len); - auto child = node.GetChildMutable(*this, key[depth]); - - // recurse, if a child exists at key[depth] - if (child) { - bool success = Insert(*child, key, depth + 1, row_id); - node.ReplaceChild(*this, key[depth], *child); - return success; - } - - // insert a new leaf node at key[depth] - Node leaf_node; - reference ref_node(leaf_node); - if (depth + 1 < key.len) { - Prefix::New(*this, ref_node, key, depth + 1, key.len - 
depth - 1); - } - Leaf::New(ref_node, row_id); - Node::InsertChild(*this, node, key[depth], leaf_node); - return true; - } - - // this is a prefix node, traverse - reference next_node(node); - auto mismatch_position = Prefix::TraverseMutable(*this, next_node, key, depth); - - // prefix matches key - if (next_node.get().GetType() != NType::PREFIX) { - return Insert(next_node, key, depth, row_id); - } - - // prefix does not match the key, we need to create a new Node4; this new Node4 has two children, - // the remaining part of the prefix, and the new leaf - Node remaining_prefix; - auto prefix_byte = Prefix::GetByte(*this, next_node, mismatch_position); - Prefix::Split(*this, next_node, remaining_prefix, mismatch_position); - Node4::New(*this, next_node); - - // insert remaining prefix - Node4::InsertChild(*this, next_node, prefix_byte, remaining_prefix); - - // insert new leaf - Node leaf_node; - reference ref_node(leaf_node); - if (depth + 1 < key.len) { - Prefix::New(*this, ref_node, key, depth + 1, key.len - depth - 1); - } - Leaf::New(ref_node, row_id); - Node4::InsertChild(*this, next_node, key[depth], leaf_node); - return true; -} - -//===--------------------------------------------------------------------===// -// Drop and Delete -//===--------------------------------------------------------------------===// - -void ART::CommitDrop(IndexLock &index_lock) { - for (auto &allocator : *allocators) { - allocator->Reset(); - } - tree.Clear(); -} - -void ART::Delete(IndexLock &state, DataChunk &input, Vector &row_ids) { - - DataChunk expression; - expression.Initialize(Allocator::DefaultAllocator(), logical_types); - - // first resolve the expressions - ExecuteExpressions(input, expression); - - // then generate the keys for the given input - ArenaAllocator arena_allocator(BufferAllocator::Get(db)); - vector keys(expression.size()); - GenerateKeys(arena_allocator, expression, keys); - - // now erase the elements from the database - row_ids.Flatten(input.size()); - 
auto row_identifiers = FlatVector::GetData(row_ids); - - for (idx_t i = 0; i < input.size(); i++) { - if (keys[i].Empty()) { - continue; - } - Erase(tree, keys[i], 0, row_identifiers[i]); - } - -#ifdef DEBUG - // verify that we removed all row IDs - for (idx_t i = 0; i < input.size(); i++) { - if (keys[i].Empty()) { - continue; - } - - auto leaf = Lookup(tree, keys[i], 0); - if (leaf) { - D_ASSERT(!Leaf::ContainsRowId(*this, *leaf, row_identifiers[i])); - } - } -#endif -} - -void ART::Erase(Node &node, const ARTKey &key, idx_t depth, const row_t &row_id) { - - if (!node.HasMetadata()) { - return; - } - - // handle prefix - reference next_node(node); - if (next_node.get().GetType() == NType::PREFIX) { - Prefix::TraverseMutable(*this, next_node, key, depth); - if (next_node.get().GetType() == NType::PREFIX) { - return; - } - } - - // delete a row ID from a leaf (root is leaf with possible prefix nodes) - if (next_node.get().GetType() == NType::LEAF || next_node.get().GetType() == NType::LEAF_INLINED) { - if (Leaf::Remove(*this, next_node, row_id)) { - Node::Free(*this, node); - } - return; - } - - D_ASSERT(depth < key.len); - auto child = next_node.get().GetChildMutable(*this, key[depth]); - if (child) { - D_ASSERT(child->HasMetadata()); - - auto temp_depth = depth + 1; - reference child_node(*child); - if (child_node.get().GetType() == NType::PREFIX) { - Prefix::TraverseMutable(*this, child_node, key, temp_depth); - if (child_node.get().GetType() == NType::PREFIX) { - return; - } - } - - if (child_node.get().GetType() == NType::LEAF || child_node.get().GetType() == NType::LEAF_INLINED) { - // leaf found, remove entry - if (Leaf::Remove(*this, child_node, row_id)) { - Node::DeleteChild(*this, next_node, node, key[depth]); - } - return; - } - - // recurse - Erase(*child, key, depth + 1, row_id); - next_node.get().ReplaceChild(*this, key[depth], *child); - } -} - -//===--------------------------------------------------------------------===// -// Point Query (Equal) 
-//===--------------------------------------------------------------------===// - -static ARTKey CreateKey(ArenaAllocator &allocator, PhysicalType type, Value &value) { - D_ASSERT(type == value.type().InternalType()); - switch (type) { - case PhysicalType::BOOL: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::INT8: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::INT16: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::INT32: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::INT64: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::UINT8: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::UINT16: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::UINT32: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::UINT64: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::INT128: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::FLOAT: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::DOUBLE: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::VARCHAR: - return ARTKey::CreateARTKey(allocator, value.type(), value); - default: - throw InternalException("Invalid type for the ART key"); - } -} - -bool ART::SearchEqual(ARTKey &key, idx_t max_count, vector &result_ids) { - - auto leaf = Lookup(tree, key, 0); - if (!leaf) { - return true; - } - return Leaf::GetRowIds(*this, *leaf, result_ids, max_count); -} - -void ART::SearchEqualJoinNoFetch(ARTKey &key, idx_t &result_size) { - - // we need to look for a leaf - auto leaf_node = Lookup(tree, key, 0); - if (!leaf_node) { - result_size = 0; - return; - } - - // we only perform index joins on PK/FK columns - D_ASSERT(leaf_node->GetType() 
== NType::LEAF_INLINED); - result_size = 1; - return; -} - -//===--------------------------------------------------------------------===// -// Lookup -//===--------------------------------------------------------------------===// - -optional_ptr ART::Lookup(const Node &node, const ARTKey &key, idx_t depth) { - - reference node_ref(node); - while (node_ref.get().HasMetadata()) { - - // traverse prefix, if exists - reference next_node(node_ref.get()); - if (next_node.get().GetType() == NType::PREFIX) { - Prefix::Traverse(*this, next_node, key, depth); - if (next_node.get().GetType() == NType::PREFIX) { - return nullptr; - } - } - - if (next_node.get().GetType() == NType::LEAF || next_node.get().GetType() == NType::LEAF_INLINED) { - return &next_node.get(); - } - - D_ASSERT(depth < key.len); - auto child = next_node.get().GetChild(*this, key[depth]); - if (!child) { - // prefix matches key, but no child at byte, ART/subtree does not contain key - return nullptr; - } - - // lookup in child node - node_ref = *child; - D_ASSERT(node_ref.get().HasMetadata()); - depth++; - } - - return nullptr; -} - -//===--------------------------------------------------------------------===// -// Greater Than and Less Than -//===--------------------------------------------------------------------===// - -bool ART::SearchGreater(ARTIndexScanState &state, ARTKey &key, bool equal, idx_t max_count, vector &result_ids) { - - if (!tree.HasMetadata()) { - return true; - } - Iterator &it = state.iterator; - - // find the lowest value that satisfies the predicate - if (!it.art) { - it.art = this; - if (!it.LowerBound(tree, key, equal, 0)) { - // early-out, if the maximum value in the ART is lower than the lower bound - return true; - } - } - - // after that we continue the scan; we don't need to check the bounds as any value following this value is - // automatically bigger and hence satisfies our predicate - ARTKey empty_key = ARTKey(); - return it.Scan(empty_key, max_count, result_ids, false); 
-} - -bool ART::SearchLess(ARTIndexScanState &state, ARTKey &upper_bound, bool equal, idx_t max_count, - vector &result_ids) { - - if (!tree.HasMetadata()) { - return true; - } - Iterator &it = state.iterator; - - if (!it.art) { - it.art = this; - // find the minimum value in the ART: we start scanning from this value - it.FindMinimum(tree); - // early-out, if the minimum value is higher than the upper bound - if (it.current_key > upper_bound) { - return true; - } - } - - // now continue the scan until we reach the upper bound - return it.Scan(upper_bound, max_count, result_ids, equal); -} - -//===--------------------------------------------------------------------===// -// Closed Range Query -//===--------------------------------------------------------------------===// - -bool ART::SearchCloseRange(ARTIndexScanState &state, ARTKey &lower_bound, ARTKey &upper_bound, bool left_equal, - bool right_equal, idx_t max_count, vector &result_ids) { - - Iterator &it = state.iterator; - - // find the first node that satisfies the left predicate - if (!it.art) { - it.art = this; - if (!it.LowerBound(tree, lower_bound, left_equal, 0)) { - // early-out, if the maximum value in the ART is lower than the lower bound - return true; - } - } - - // now continue the scan until we reach the upper bound - return it.Scan(upper_bound, max_count, result_ids, right_equal); -} - -bool ART::Scan(const Transaction &transaction, const DataTable &table, IndexScanState &state, const idx_t max_count, - vector &result_ids) { - - auto &scan_state = state.Cast(); - vector row_ids; - bool success; - - // FIXME: the key directly owning the data for a single key might be more efficient - D_ASSERT(scan_state.values[0].type().InternalType() == types[0]); - ArenaAllocator arena_allocator(Allocator::Get(db)); - auto key = CreateKey(arena_allocator, types[0], scan_state.values[0]); - - if (scan_state.values[1].IsNull()) { - - // single predicate - lock_guard l(lock); - switch (scan_state.expressions[0]) { 
- case ExpressionType::COMPARE_EQUAL: - success = SearchEqual(key, max_count, row_ids); - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - success = SearchGreater(scan_state, key, true, max_count, row_ids); - break; - case ExpressionType::COMPARE_GREATERTHAN: - success = SearchGreater(scan_state, key, false, max_count, row_ids); - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - success = SearchLess(scan_state, key, true, max_count, row_ids); - break; - case ExpressionType::COMPARE_LESSTHAN: - success = SearchLess(scan_state, key, false, max_count, row_ids); - break; - default: - throw InternalException("Index scan type not implemented"); - } - - } else { - - // two predicates - lock_guard l(lock); - - D_ASSERT(scan_state.values[1].type().InternalType() == types[0]); - auto upper_bound = CreateKey(arena_allocator, types[0], scan_state.values[1]); - - bool left_equal = scan_state.expressions[0] == ExpressionType ::COMPARE_GREATERTHANOREQUALTO; - bool right_equal = scan_state.expressions[1] == ExpressionType ::COMPARE_LESSTHANOREQUALTO; - success = SearchCloseRange(scan_state, key, upper_bound, left_equal, right_equal, max_count, row_ids); - } - - if (!success) { - return false; - } - if (row_ids.empty()) { - return true; - } - - // sort the row ids - sort(row_ids.begin(), row_ids.end()); - // duplicate eliminate the row ids and append them to the row ids of the state - result_ids.reserve(row_ids.size()); - - result_ids.push_back(row_ids[0]); - for (idx_t i = 1; i < row_ids.size(); i++) { - if (row_ids[i] != row_ids[i - 1]) { - result_ids.push_back(row_ids[i]); - } - } - return true; -} - -//===--------------------------------------------------------------------===// -// More Verification / Constraint Checking -//===--------------------------------------------------------------------===// - -string ART::GenerateErrorKeyName(DataChunk &input, idx_t row) { - - // FIXME: why exactly can we not pass the expression_chunk as an argument to this - // 
FIXME: function instead of re-executing? - // re-executing the expressions is not very fast, but we're going to throw, so we don't care - DataChunk expression_chunk; - expression_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); - ExecuteExpressions(input, expression_chunk); - - string key_name; - for (idx_t k = 0; k < expression_chunk.ColumnCount(); k++) { - if (k > 0) { - key_name += ", "; - } - key_name += unbound_expressions[k]->GetName() + ": " + expression_chunk.data[k].GetValue(row).ToString(); - } - return key_name; -} - -string ART::GenerateConstraintErrorMessage(VerifyExistenceType verify_type, const string &key_name) { - switch (verify_type) { - case VerifyExistenceType::APPEND: { - // APPEND to PK/UNIQUE table, but node/key already exists in PK/UNIQUE table - string type = IsPrimary() ? "primary key" : "unique"; - return StringUtil::Format( - "Duplicate key \"%s\" violates %s constraint. " - "If this is an unexpected constraint violation please double " - "check with the known index limitations section in our documentation (docs - sql - indexes).", - key_name, type); - } - case VerifyExistenceType::APPEND_FK: { - // APPEND_FK to FK table, node/key does not exist in PK/UNIQUE table - return StringUtil::Format( - "Violates foreign key constraint because key \"%s\" does not exist in the referenced table", key_name); - } - case VerifyExistenceType::DELETE_FK: { - // DELETE_FK that still exists in a FK table, i.e., not a valid delete - return StringUtil::Format("Violates foreign key constraint because key \"%s\" is still referenced by a foreign " - "key in a different table", - key_name); - } - default: - throw NotImplementedException("Type not implemented for VerifyExistenceType"); - } -} - -void ART::CheckConstraintsForChunk(DataChunk &input, ConflictManager &conflict_manager) { - - // don't alter the index during constraint checking - lock_guard l(lock); - - // first resolve the expressions for the index - DataChunk expression_chunk; - 
expression_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); - ExecuteExpressions(input, expression_chunk); - - // generate the keys for the given input - ArenaAllocator arena_allocator(BufferAllocator::Get(db)); - vector keys(expression_chunk.size()); - GenerateKeys(arena_allocator, expression_chunk, keys); - - idx_t found_conflict = DConstants::INVALID_INDEX; - for (idx_t i = 0; found_conflict == DConstants::INVALID_INDEX && i < input.size(); i++) { - - if (keys[i].Empty()) { - if (conflict_manager.AddNull(i)) { - found_conflict = i; - } - continue; - } - - auto leaf = Lookup(tree, keys[i], 0); - if (!leaf) { - if (conflict_manager.AddMiss(i)) { - found_conflict = i; - } - continue; - } - - // when we find a node, we need to update the 'matches' and 'row_ids' - // NOTE: leaves can have more than one row_id, but for UNIQUE/PRIMARY KEY they will only have one - D_ASSERT(leaf->GetType() == NType::LEAF_INLINED); - if (conflict_manager.AddHit(i, leaf->GetRowId())) { - found_conflict = i; - } - } - - conflict_manager.FinishLookup(); - - if (found_conflict == DConstants::INVALID_INDEX) { - return; - } - - auto key_name = GenerateErrorKeyName(input, found_conflict); - auto exception_msg = GenerateConstraintErrorMessage(conflict_manager.LookupType(), key_name); - throw ConstraintException(exception_msg); -} - -//===--------------------------------------------------------------------===// -// Serialization -//===--------------------------------------------------------------------===// - -BlockPointer ART::Serialize(MetadataWriter &writer) { - - D_ASSERT(owns_data); - - // early-out, if all allocators are empty - if (!tree.HasMetadata()) { - root_block_pointer = BlockPointer(); - return root_block_pointer; - } - - lock_guard l(lock); - auto &block_manager = table_io_manager.GetIndexBlockManager(); - PartialBlockManager partial_block_manager(block_manager, CheckpointType::FULL_CHECKPOINT); - - vector allocator_pointers; - for (auto &allocator : *allocators) { - 
allocator_pointers.push_back(allocator->Serialize(partial_block_manager, writer)); - } - partial_block_manager.FlushPartialBlocks(); - - root_block_pointer = writer.GetBlockPointer(); - writer.Write(tree); - for (auto &allocator_pointer : allocator_pointers) { - writer.Write(allocator_pointer); - } - - return root_block_pointer; -} - -void ART::Deserialize(const BlockPointer &pointer) { - - D_ASSERT(pointer.IsValid()); - MetadataReader reader(table_io_manager.GetMetadataManager(), pointer); - tree = reader.Read(); - - for (idx_t i = 0; i < ALLOCATOR_COUNT; i++) { - (*allocators)[i]->Deserialize(reader.Read()); - } -} - -//===--------------------------------------------------------------------===// -// Vacuum -//===--------------------------------------------------------------------===// - -void ART::InitializeVacuum(ARTFlags &flags) { - - flags.vacuum_flags.reserve(allocators->size()); - for (auto &allocator : *allocators) { - flags.vacuum_flags.push_back(allocator->InitializeVacuum()); - } -} - -void ART::FinalizeVacuum(const ARTFlags &flags) { - - for (idx_t i = 0; i < allocators->size(); i++) { - if (flags.vacuum_flags[i]) { - (*allocators)[i]->FinalizeVacuum(); - } - } -} - -void ART::Vacuum(IndexLock &state) { - - D_ASSERT(owns_data); - - if (!tree.HasMetadata()) { - for (auto &allocator : *allocators) { - allocator->Reset(); - } - return; - } - - // holds true, if an allocator needs a vacuum, and false otherwise - ARTFlags flags; - InitializeVacuum(flags); - - // skip vacuum if no allocators require it - auto perform_vacuum = false; - for (const auto &vacuum_flag : flags.vacuum_flags) { - if (vacuum_flag) { - perform_vacuum = true; - break; - } - } - if (!perform_vacuum) { - return; - } - - // traverse the allocated memory of the tree to perform a vacuum - tree.Vacuum(*this, flags); - - // finalize the vacuum operation - FinalizeVacuum(flags); -} - -//===--------------------------------------------------------------------===// -// Merging 
-//===--------------------------------------------------------------------===// - -void ART::InitializeMerge(ARTFlags &flags) { - - D_ASSERT(owns_data); - - flags.merge_buffer_counts.reserve(allocators->size()); - for (auto &allocator : *allocators) { - flags.merge_buffer_counts.emplace_back(allocator->GetUpperBoundBufferId()); - } -} - -bool ART::MergeIndexes(IndexLock &state, Index &other_index) { - - auto &other_art = other_index.Cast(); - if (!other_art.tree.HasMetadata()) { - return true; - } - - if (other_art.owns_data) { - if (tree.HasMetadata()) { - // fully deserialize other_index, and traverse it to increment its buffer IDs - ARTFlags flags; - InitializeMerge(flags); - other_art.tree.InitializeMerge(other_art, flags); - } - - // merge the node storage - for (idx_t i = 0; i < allocators->size(); i++) { - (*allocators)[i]->Merge(*(*other_art.allocators)[i]); - } - } - - // merge the ARTs - if (!tree.Merge(*this, other_art.tree)) { - return false; - } - return true; -} - -//===--------------------------------------------------------------------===// -// Utility -//===--------------------------------------------------------------------===// - -string ART::VerifyAndToString(IndexLock &state, const bool only_verify) { - // FIXME: this can be improved by counting the allocations of each node type, - // FIXME: and by asserting that each fixed-size allocator lists an equal number of - // FIXME: allocations of that type - return VerifyAndToStringInternal(only_verify); -} - -string ART::VerifyAndToStringInternal(const bool only_verify) { - if (tree.HasMetadata()) { - return "ART: " + tree.VerifyAndToString(*this, only_verify); - } - return "[empty]"; -} - -} // namespace duckdb - - -namespace duckdb { - -ARTKey::ARTKey() : len(0) { -} - -ARTKey::ARTKey(const data_ptr_t &data, const uint32_t &len) : len(len), data(data) { -} - -ARTKey::ARTKey(ArenaAllocator &allocator, const uint32_t &len) : len(len) { - data = allocator.Allocate(len); -} - -template <> -ARTKey 
ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, string_t value) { - uint32_t len = value.GetSize() + 1; - auto data = allocator.Allocate(len); - memcpy(data, value.GetData(), len - 1); - - // FIXME: rethink this - if (type == LogicalType::BLOB || type == LogicalType::VARCHAR) { - // indexes cannot contain BLOBs (or BLOBs cast to VARCHARs) that contain null-terminated bytes - for (uint32_t i = 0; i < len - 1; i++) { - if (data[i] == '\0') { - throw NotImplementedException("Indexes cannot contain BLOBs that contain null-terminated bytes."); - } - } - } - - data[len - 1] = '\0'; - return ARTKey(data, len); -} - -template <> -ARTKey ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, const char *value) { - return ARTKey::CreateARTKey(allocator, type, string_t(value, strlen(value))); -} - -template <> -void ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, ARTKey &key, string_t value) { - key.len = value.GetSize() + 1; - key.data = allocator.Allocate(key.len); - memcpy(key.data, value.GetData(), key.len - 1); - - // FIXME: rethink this - if (type == LogicalType::BLOB || type == LogicalType::VARCHAR) { - // indexes cannot contain BLOBs (or BLOBs cast to VARCHARs) that contain null-terminated bytes - for (uint32_t i = 0; i < key.len - 1; i++) { - if (key.data[i] == '\0') { - throw NotImplementedException("Indexes cannot contain BLOBs that contain null-terminated bytes."); - } - } - } - - key.data[key.len - 1] = '\0'; -} - -template <> -void ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, ARTKey &key, const char *value) { - ARTKey::CreateARTKey(allocator, type, key, string_t(value, strlen(value))); -} - -bool ARTKey::operator>(const ARTKey &k) const { - for (uint32_t i = 0; i < MinValue(len, k.len); i++) { - if (data[i] > k.data[i]) { - return true; - } else if (data[i] < k.data[i]) { - return false; - } - } - return len > k.len; -} - -bool ARTKey::operator>=(const ARTKey &k) 
const { - for (uint32_t i = 0; i < MinValue(len, k.len); i++) { - if (data[i] > k.data[i]) { - return true; - } else if (data[i] < k.data[i]) { - return false; - } - } - return len >= k.len; -} - -bool ARTKey::operator==(const ARTKey &k) const { - if (len != k.len) { - return false; - } - for (uint32_t i = 0; i < len; i++) { - if (data[i] != k.data[i]) { - return false; - } - } - return true; -} - -void ARTKey::ConcatenateARTKey(ArenaAllocator &allocator, ARTKey &other_key) { - - auto compound_data = allocator.Allocate(len + other_key.len); - memcpy(compound_data, data, len); - memcpy(compound_data + len, other_key.data, other_key.len); - len += other_key.len; - data = compound_data; -} -} // namespace duckdb - - - - - - - -namespace duckdb { - -bool IteratorKey::operator>(const ARTKey &key) const { - for (idx_t i = 0; i < MinValue(key_bytes.size(), key.len); i++) { - if (key_bytes[i] > key.data[i]) { - return true; - } else if (key_bytes[i] < key.data[i]) { - return false; - } - } - return key_bytes.size() > key.len; -} - -bool IteratorKey::operator>=(const ARTKey &key) const { - for (idx_t i = 0; i < MinValue(key_bytes.size(), key.len); i++) { - if (key_bytes[i] > key.data[i]) { - return true; - } else if (key_bytes[i] < key.data[i]) { - return false; - } - } - return key_bytes.size() >= key.len; -} - -bool IteratorKey::operator==(const ARTKey &key) const { - // NOTE: we only use this for finding the LowerBound, in which case the length - // has to be equal - D_ASSERT(key_bytes.size() == key.len); - for (idx_t i = 0; i < key_bytes.size(); i++) { - if (key_bytes[i] != key.data[i]) { - return false; - } - } - return true; -} - -bool Iterator::Scan(const ARTKey &upper_bound, const idx_t max_count, vector &result_ids, const bool equal) { - - bool has_next; - do { - if (!upper_bound.Empty()) { - // no more row IDs within the key bounds - if (equal) { - if (current_key > upper_bound) { - return true; - } - } else { - if (current_key >= upper_bound) { - return true; - } 
- } - } - - // copy all row IDs of this leaf into the result IDs (if they don't exceed max_count) - if (!Leaf::GetRowIds(*art, last_leaf, result_ids, max_count)) { - return false; - } - - // get the next leaf - has_next = Next(); - - } while (has_next); - - return true; -} - -void Iterator::FindMinimum(const Node &node) { - - D_ASSERT(node.HasMetadata()); - - // found the minimum - if (node.GetType() == NType::LEAF || node.GetType() == NType::LEAF_INLINED) { - last_leaf = node; - return; - } - - // traverse the prefix - if (node.GetType() == NType::PREFIX) { - auto &prefix = Node::Ref(*art, node, NType::PREFIX); - for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { - current_key.Push(prefix.data[i]); - } - nodes.emplace(node, 0); - return FindMinimum(prefix.ptr); - } - - // go to the leftmost entry in the current node and recurse - uint8_t byte = 0; - auto next = node.GetNextChild(*art, byte); - D_ASSERT(next); - current_key.Push(byte); - nodes.emplace(node, byte); - FindMinimum(*next); -} - -bool Iterator::LowerBound(const Node &node, const ARTKey &key, const bool equal, idx_t depth) { - - if (!node.HasMetadata()) { - return false; - } - - // we found the lower bound - if (node.GetType() == NType::LEAF || node.GetType() == NType::LEAF_INLINED) { - if (!equal && current_key == key) { - return Next(); - } - last_leaf = node; - return true; - } - - if (node.GetType() != NType::PREFIX) { - auto next_byte = key[depth]; - auto child = node.GetNextChild(*art, next_byte); - if (!child) { - // the key is greater than any key in this subtree - return Next(); - } - - current_key.Push(next_byte); - nodes.emplace(node, next_byte); - - if (next_byte > key[depth]) { - // we only need to find the minimum from here - // because all keys will be greater than the lower bound - FindMinimum(*child); - return true; - } - - // recurse into the child - return LowerBound(*child, key, equal, depth + 1); - } - - // resolve the prefix - auto &prefix = Node::Ref(*art, node, 
NType::PREFIX); - for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { - current_key.Push(prefix.data[i]); - } - nodes.emplace(node, 0); - - for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { - // the key down to this node is less than the lower bound, the next key will be - // greater than the lower bound - if (prefix.data[i] < key[depth + i]) { - return Next(); - } - // we only need to find the minimum from here - // because all keys will be greater than the lower bound - if (prefix.data[i] > key[depth + i]) { - FindMinimum(prefix.ptr); - return true; - } - } - - // recurse into the child - depth += prefix.data[Node::PREFIX_SIZE]; - return LowerBound(prefix.ptr, key, equal, depth); -} - -bool Iterator::Next() { - - while (!nodes.empty()) { - - auto &top = nodes.top(); - D_ASSERT(top.node.GetType() != NType::LEAF && top.node.GetType() != NType::LEAF_INLINED); - - if (top.node.GetType() == NType::PREFIX) { - PopNode(); - continue; - } - - if (top.byte == NumericLimits::Maximum()) { - // no node found: move up the tree, pop key byte of current node - PopNode(); - continue; - } - - top.byte++; - auto next_node = top.node.GetNextChild(*art, top.byte); - if (!next_node) { - PopNode(); - continue; - } - - current_key.Pop(1); - current_key.Push(top.byte); - - FindMinimum(*next_node); - return true; - } - return false; -} - -void Iterator::PopNode() { - if (nodes.top().node.GetType() == NType::PREFIX) { - auto &prefix = Node::Ref(*art, nodes.top().node, NType::PREFIX); - auto prefix_byte_count = prefix.data[Node::PREFIX_SIZE]; - current_key.Pop(prefix_byte_count); - } else { - current_key.Pop(1); - } - nodes.pop(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -void Leaf::New(Node &node, const row_t row_id) { - - // we directly inline this row ID into the node pointer - D_ASSERT(row_id < MAX_ROW_ID_LOCAL); - node.Clear(); - node.SetMetadata(static_cast(NType::LEAF_INLINED)); - node.SetRowId(row_id); -} - -void Leaf::New(ART &art, reference 
&node, const row_t *row_ids, idx_t count) { - - D_ASSERT(count > 1); - - idx_t copy_count = 0; - while (count) { - node.get() = Node::GetAllocator(art, NType::LEAF).New(); - node.get().SetMetadata(static_cast(NType::LEAF)); - - auto &leaf = Node::RefMutable(art, node, NType::LEAF); - - leaf.count = MinValue((idx_t)Node::LEAF_SIZE, count); - - for (idx_t i = 0; i < leaf.count; i++) { - leaf.row_ids[i] = row_ids[copy_count + i]; - } - - copy_count += leaf.count; - count -= leaf.count; - - node = leaf.ptr; - leaf.ptr.Clear(); - } -} - -Leaf &Leaf::New(ART &art, Node &node) { - node = Node::GetAllocator(art, NType::LEAF).New(); - node.SetMetadata(static_cast(NType::LEAF)); - auto &leaf = Node::RefMutable(art, node, NType::LEAF); - - leaf.count = 0; - leaf.ptr.Clear(); - return leaf; -} - -void Leaf::Free(ART &art, Node &node) { - - Node current_node = node; - Node next_node; - while (current_node.HasMetadata()) { - next_node = Node::RefMutable(art, current_node, NType::LEAF).ptr; - Node::GetAllocator(art, NType::LEAF).Free(current_node); - current_node = next_node; - } - - node.Clear(); -} - -void Leaf::InitializeMerge(ART &art, Node &node, const ARTFlags &flags) { - - auto merge_buffer_count = flags.merge_buffer_counts[static_cast(NType::LEAF) - 1]; - - Node next_node = node; - node.IncreaseBufferId(merge_buffer_count); - - while (next_node.HasMetadata()) { - auto &leaf = Node::RefMutable(art, next_node, NType::LEAF); - next_node = leaf.ptr; - if (leaf.ptr.HasMetadata()) { - leaf.ptr.IncreaseBufferId(merge_buffer_count); - } - } -} - -void Leaf::Merge(ART &art, Node &l_node, Node &r_node) { - - D_ASSERT(l_node.HasMetadata() && r_node.HasMetadata()); - - // copy inlined row ID of r_node - if (r_node.GetType() == NType::LEAF_INLINED) { - Insert(art, l_node, r_node.GetRowId()); - r_node.Clear(); - return; - } - - // l_node has an inlined row ID, swap and insert - if (l_node.GetType() == NType::LEAF_INLINED) { - auto row_id = l_node.GetRowId(); - l_node = r_node; - 
Insert(art, l_node, row_id); - r_node.Clear(); - return; - } - - D_ASSERT(l_node.GetType() != NType::LEAF_INLINED); - D_ASSERT(r_node.GetType() != NType::LEAF_INLINED); - - reference l_node_ref(l_node); - reference l_leaf = Node::RefMutable(art, l_node_ref, NType::LEAF); - - // find a non-full node - while (l_leaf.get().count == Node::LEAF_SIZE) { - l_node_ref = l_leaf.get().ptr; - - // the last leaf is full - if (!l_leaf.get().ptr.HasMetadata()) { - break; - } - l_leaf = Node::RefMutable(art, l_node_ref, NType::LEAF); - } - - // store the last leaf and then append r_node - auto last_leaf_node = l_node_ref.get(); - l_node_ref.get() = r_node; - r_node.Clear(); - - // append the remaining row IDs of the last leaf node - if (last_leaf_node.HasMetadata()) { - // find the tail - l_leaf = Node::RefMutable(art, l_node_ref, NType::LEAF); - while (l_leaf.get().ptr.HasMetadata()) { - l_leaf = Node::RefMutable(art, l_leaf.get().ptr, NType::LEAF); - } - // append the row IDs - auto &last_leaf = Node::RefMutable(art, last_leaf_node, NType::LEAF); - for (idx_t i = 0; i < last_leaf.count; i++) { - l_leaf = l_leaf.get().Append(art, last_leaf.row_ids[i]); - } - Node::GetAllocator(art, NType::LEAF).Free(last_leaf_node); - } -} - -void Leaf::Insert(ART &art, Node &node, const row_t row_id) { - - D_ASSERT(node.HasMetadata()); - - if (node.GetType() == NType::LEAF_INLINED) { - MoveInlinedToLeaf(art, node); - Insert(art, node, row_id); - return; - } - - // append to the tail - reference leaf = Node::RefMutable(art, node, NType::LEAF); - while (leaf.get().ptr.HasMetadata()) { - leaf = Node::RefMutable(art, leaf.get().ptr, NType::LEAF); - } - leaf.get().Append(art, row_id); -} - -bool Leaf::Remove(ART &art, reference &node, const row_t row_id) { - - D_ASSERT(node.get().HasMetadata()); - - if (node.get().GetType() == NType::LEAF_INLINED) { - if (node.get().GetRowId() == row_id) { - return true; - } - return false; - } - - reference leaf = Node::RefMutable(art, node, NType::LEAF); - - // 
inline the remaining row ID - if (leaf.get().count == 2) { - if (leaf.get().row_ids[0] == row_id || leaf.get().row_ids[1] == row_id) { - auto remaining_row_id = leaf.get().row_ids[0] == row_id ? leaf.get().row_ids[1] : leaf.get().row_ids[0]; - Node::Free(art, node); - New(node, remaining_row_id); - } - return false; - } - - // get the last row ID (the order within a leaf does not matter) - // because we want to overwrite the row ID to remove with that one - - // go to the tail and keep track of the previous leaf node - reference prev_leaf(leaf); - while (leaf.get().ptr.HasMetadata()) { - prev_leaf = leaf; - leaf = Node::RefMutable(art, leaf.get().ptr, NType::LEAF); - } - - auto last_idx = leaf.get().count; - auto last_row_id = leaf.get().row_ids[last_idx - 1]; - - // only one row ID in this leaf segment, free it - if (leaf.get().count == 1) { - Node::Free(art, prev_leaf.get().ptr); - if (last_row_id == row_id) { - return false; - } - } else { - leaf.get().count--; - } - - // find the row ID and copy the last row ID to that position - while (node.get().HasMetadata()) { - leaf = Node::RefMutable(art, node, NType::LEAF); - for (idx_t i = 0; i < leaf.get().count; i++) { - if (leaf.get().row_ids[i] == row_id) { - leaf.get().row_ids[i] = last_row_id; - return false; - } - } - node = leaf.get().ptr; - } - return false; -} - -idx_t Leaf::TotalCount(ART &art, const Node &node) { - - D_ASSERT(node.HasMetadata()); - if (node.GetType() == NType::LEAF_INLINED) { - return 1; - } - - idx_t count = 0; - reference node_ref(node); - while (node_ref.get().HasMetadata()) { - auto &leaf = Node::Ref(art, node_ref, NType::LEAF); - count += leaf.count; - node_ref = leaf.ptr; - } - return count; -} - -bool Leaf::GetRowIds(ART &art, const Node &node, vector &result_ids, idx_t max_count) { - - // adding more elements would exceed the maximum count - D_ASSERT(node.HasMetadata()); - if (result_ids.size() + TotalCount(art, node) > max_count) { - return false; - } - - if (node.GetType() == 
NType::LEAF_INLINED) { - // push back the inlined row ID of this leaf - result_ids.push_back(node.GetRowId()); - - } else { - // push back all the row IDs of this leaf - reference last_leaf_ref(node); - while (last_leaf_ref.get().HasMetadata()) { - auto &leaf = Node::Ref(art, last_leaf_ref, NType::LEAF); - for (idx_t i = 0; i < leaf.count; i++) { - result_ids.push_back(leaf.row_ids[i]); - } - last_leaf_ref = leaf.ptr; - } - } - - return true; -} - -bool Leaf::ContainsRowId(ART &art, const Node &node, const row_t row_id) { - - D_ASSERT(node.HasMetadata()); - - if (node.GetType() == NType::LEAF_INLINED) { - return node.GetRowId() == row_id; - } - - reference ref_node(node); - while (ref_node.get().HasMetadata()) { - auto &leaf = Node::Ref(art, ref_node, NType::LEAF); - for (idx_t i = 0; i < leaf.count; i++) { - if (leaf.row_ids[i] == row_id) { - return true; - } - } - ref_node = leaf.ptr; - } - - return false; -} - -string Leaf::VerifyAndToString(ART &art, const Node &node, const bool only_verify) { - - if (node.GetType() == NType::LEAF_INLINED) { - return only_verify ? "" : "Leaf [count: 1, row ID: " + to_string(node.GetRowId()) + "]"; - } - - string str = ""; - - reference node_ref(node); - while (node_ref.get().HasMetadata()) { - - auto &leaf = Node::Ref(art, node_ref, NType::LEAF); - D_ASSERT(leaf.count <= Node::LEAF_SIZE); - - str += "Leaf [count: " + to_string(leaf.count) + ", row IDs: "; - for (idx_t i = 0; i < leaf.count; i++) { - str += to_string(leaf.row_ids[i]) + "-"; - } - str += "] "; - - node_ref = leaf.ptr; - } - return only_verify ? 
"" : str; -} - -void Leaf::Vacuum(ART &art, Node &node) { - - auto &allocator = Node::GetAllocator(art, NType::LEAF); - - reference node_ref(node); - while (node_ref.get().HasMetadata()) { - if (allocator.NeedsVacuum(node_ref)) { - node_ref.get() = allocator.VacuumPointer(node_ref); - node_ref.get().SetMetadata(static_cast(NType::LEAF)); - } - auto &leaf = Node::RefMutable(art, node_ref, NType::LEAF); - node_ref = leaf.ptr; - } -} - -void Leaf::MoveInlinedToLeaf(ART &art, Node &node) { - - D_ASSERT(node.GetType() == NType::LEAF_INLINED); - auto row_id = node.GetRowId(); - auto &leaf = New(art, node); - - leaf.count = 1; - leaf.row_ids[0] = row_id; -} - -Leaf &Leaf::Append(ART &art, const row_t row_id) { - - reference leaf(*this); - - // we need a new leaf node - if (leaf.get().count == Node::LEAF_SIZE) { - leaf = New(art, leaf.get().ptr); - } - - leaf.get().row_ids[leaf.get().count] = row_id; - leaf.get().count++; - return leaf.get(); -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// New / Free -//===--------------------------------------------------------------------===// - -void Node::New(ART &art, Node &node, const NType type) { - - // NOTE: leaves and prefixes should not pass through this function - - switch (type) { - case NType::NODE_4: - Node4::New(art, node); - break; - case NType::NODE_16: - Node16::New(art, node); - break; - case NType::NODE_48: - Node48::New(art, node); - break; - case NType::NODE_256: - Node256::New(art, node); - break; - default: - throw InternalException("Invalid node type for New."); - } -} - -void Node::Free(ART &art, Node &node) { - - if (!node.HasMetadata()) { - return node.Clear(); - } - - // free the children of the nodes - auto type = node.GetType(); - switch (type) { - case NType::PREFIX: - // iterative - return Prefix::Free(art, node); - case NType::LEAF: - // iterative - return Leaf::Free(art, node); - case 
NType::NODE_4: - Node4::Free(art, node); - break; - case NType::NODE_16: - Node16::Free(art, node); - break; - case NType::NODE_48: - Node48::Free(art, node); - break; - case NType::NODE_256: - Node256::Free(art, node); - break; - case NType::LEAF_INLINED: - return node.Clear(); - } - - GetAllocator(art, type).Free(node); - node.Clear(); -} - -//===--------------------------------------------------------------------===// -// Get Allocators -//===--------------------------------------------------------------------===// - -FixedSizeAllocator &Node::GetAllocator(const ART &art, const NType type) { - return *(*art.allocators)[static_cast(type) - 1]; -} - -//===--------------------------------------------------------------------===// -// Inserts -//===--------------------------------------------------------------------===// - -void Node::ReplaceChild(const ART &art, const uint8_t byte, const Node child) const { - - switch (GetType()) { - case NType::NODE_4: - return RefMutable(art, *this, NType::NODE_4).ReplaceChild(byte, child); - case NType::NODE_16: - return RefMutable(art, *this, NType::NODE_16).ReplaceChild(byte, child); - case NType::NODE_48: - return RefMutable(art, *this, NType::NODE_48).ReplaceChild(byte, child); - case NType::NODE_256: - return RefMutable(art, *this, NType::NODE_256).ReplaceChild(byte, child); - default: - throw InternalException("Invalid node type for ReplaceChild."); - } -} - -void Node::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - - switch (node.GetType()) { - case NType::NODE_4: - return Node4::InsertChild(art, node, byte, child); - case NType::NODE_16: - return Node16::InsertChild(art, node, byte, child); - case NType::NODE_48: - return Node48::InsertChild(art, node, byte, child); - case NType::NODE_256: - return Node256::InsertChild(art, node, byte, child); - default: - throw InternalException("Invalid node type for InsertChild."); - } -} - 
-//===--------------------------------------------------------------------===// -// Deletes -//===--------------------------------------------------------------------===// - -void Node::DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte) { - - switch (node.GetType()) { - case NType::NODE_4: - return Node4::DeleteChild(art, node, prefix, byte); - case NType::NODE_16: - return Node16::DeleteChild(art, node, byte); - case NType::NODE_48: - return Node48::DeleteChild(art, node, byte); - case NType::NODE_256: - return Node256::DeleteChild(art, node, byte); - default: - throw InternalException("Invalid node type for DeleteChild."); - } -} - -//===--------------------------------------------------------------------===// -// Get functions -//===--------------------------------------------------------------------===// - -optional_ptr Node::GetChild(ART &art, const uint8_t byte) const { - - D_ASSERT(HasMetadata()); - - switch (GetType()) { - case NType::NODE_4: - return Ref(art, *this, NType::NODE_4).GetChild(byte); - case NType::NODE_16: - return Ref(art, *this, NType::NODE_16).GetChild(byte); - case NType::NODE_48: - return Ref(art, *this, NType::NODE_48).GetChild(byte); - case NType::NODE_256: - return Ref(art, *this, NType::NODE_256).GetChild(byte); - default: - throw InternalException("Invalid node type for GetChild."); - } -} - -optional_ptr Node::GetChildMutable(ART &art, const uint8_t byte) const { - - D_ASSERT(HasMetadata()); - - switch (GetType()) { - case NType::NODE_4: - return RefMutable(art, *this, NType::NODE_4).GetChildMutable(byte); - case NType::NODE_16: - return RefMutable(art, *this, NType::NODE_16).GetChildMutable(byte); - case NType::NODE_48: - return RefMutable(art, *this, NType::NODE_48).GetChildMutable(byte); - case NType::NODE_256: - return RefMutable(art, *this, NType::NODE_256).GetChildMutable(byte); - default: - throw InternalException("Invalid node type for GetChildMutable."); - } -} - -optional_ptr Node::GetNextChild(ART &art, 
uint8_t &byte) const { - - D_ASSERT(HasMetadata()); - - switch (GetType()) { - case NType::NODE_4: - return Ref(art, *this, NType::NODE_4).GetNextChild(byte); - case NType::NODE_16: - return Ref(art, *this, NType::NODE_16).GetNextChild(byte); - case NType::NODE_48: - return Ref(art, *this, NType::NODE_48).GetNextChild(byte); - case NType::NODE_256: - return Ref(art, *this, NType::NODE_256).GetNextChild(byte); - default: - throw InternalException("Invalid node type for GetNextChild."); - } -} - -optional_ptr Node::GetNextChildMutable(ART &art, uint8_t &byte) const { - - D_ASSERT(HasMetadata()); - - switch (GetType()) { - case NType::NODE_4: - return RefMutable(art, *this, NType::NODE_4).GetNextChildMutable(byte); - case NType::NODE_16: - return RefMutable(art, *this, NType::NODE_16).GetNextChildMutable(byte); - case NType::NODE_48: - return RefMutable(art, *this, NType::NODE_48).GetNextChildMutable(byte); - case NType::NODE_256: - return RefMutable(art, *this, NType::NODE_256).GetNextChildMutable(byte); - default: - throw InternalException("Invalid node type for GetNextChildMutable."); - } -} - -//===--------------------------------------------------------------------===// -// Utility -//===--------------------------------------------------------------------===// - -string Node::VerifyAndToString(ART &art, const bool only_verify) const { - - D_ASSERT(HasMetadata()); - - if (GetType() == NType::LEAF || GetType() == NType::LEAF_INLINED) { - auto str = Leaf::VerifyAndToString(art, *this, only_verify); - return only_verify ? "" : "\n" + str; - } - if (GetType() == NType::PREFIX) { - auto str = Prefix::VerifyAndToString(art, *this, only_verify); - return only_verify ? 
"" : "\n" + str; - } - - string str = "Node" + to_string(GetCapacity()) + ": ["; - uint8_t byte = 0; - auto child = GetNextChild(art, byte); - - while (child) { - str += "(" + to_string(byte) + ", " + child->VerifyAndToString(art, only_verify) + ")"; - if (byte == NumericLimits::Maximum()) { - break; - } - - byte++; - child = GetNextChild(art, byte); - } - - return only_verify ? "" : "\n" + str + "]"; -} - -idx_t Node::GetCapacity() const { - - switch (GetType()) { - case NType::NODE_4: - return NODE_4_CAPACITY; - case NType::NODE_16: - return NODE_16_CAPACITY; - case NType::NODE_48: - return NODE_48_CAPACITY; - case NType::NODE_256: - return NODE_256_CAPACITY; - default: - throw InternalException("Invalid node type for GetCapacity."); - } -} - -NType Node::GetARTNodeTypeByCount(const idx_t count) { - - if (count <= NODE_4_CAPACITY) { - return NType::NODE_4; - } else if (count <= NODE_16_CAPACITY) { - return NType::NODE_16; - } else if (count <= NODE_48_CAPACITY) { - return NType::NODE_48; - } - return NType::NODE_256; -} - -//===--------------------------------------------------------------------===// -// Merging -//===--------------------------------------------------------------------===// - -void Node::InitializeMerge(ART &art, const ARTFlags &flags) { - - D_ASSERT(HasMetadata()); - - switch (GetType()) { - case NType::PREFIX: - // iterative - return Prefix::InitializeMerge(art, *this, flags); - case NType::LEAF: - // iterative - return Leaf::InitializeMerge(art, *this, flags); - case NType::NODE_4: - RefMutable(art, *this, NType::NODE_4).InitializeMerge(art, flags); - break; - case NType::NODE_16: - RefMutable(art, *this, NType::NODE_16).InitializeMerge(art, flags); - break; - case NType::NODE_48: - RefMutable(art, *this, NType::NODE_48).InitializeMerge(art, flags); - break; - case NType::NODE_256: - RefMutable(art, *this, NType::NODE_256).InitializeMerge(art, flags); - break; - case NType::LEAF_INLINED: - return; - } - - 
IncreaseBufferId(flags.merge_buffer_counts[static_cast(GetType()) - 1]); -} - -bool Node::Merge(ART &art, Node &other) { - - if (!HasMetadata()) { - *this = other; - other = Node(); - return true; - } - - return ResolvePrefixes(art, other); -} - -bool MergePrefixContainsOtherPrefix(ART &art, reference &l_node, reference &r_node, - idx_t &mismatch_position) { - - // r_node's prefix contains l_node's prefix - // l_node cannot be a leaf, otherwise the key represented by l_node would be a subset of another key - // which is not possible by our construction - D_ASSERT(l_node.get().GetType() != NType::LEAF && l_node.get().GetType() != NType::LEAF_INLINED); - - // test if the next byte (mismatch_position) in r_node (prefix) exists in l_node - auto mismatch_byte = Prefix::GetByte(art, r_node, mismatch_position); - auto child_node = l_node.get().GetChildMutable(art, mismatch_byte); - - // update the prefix of r_node to only consist of the bytes after mismatch_position - Prefix::Reduce(art, r_node, mismatch_position); - - if (!child_node) { - // insert r_node as a child of l_node at the empty position - Node::InsertChild(art, l_node, mismatch_byte, r_node); - r_node.get().Clear(); - return true; - } - - // recurse - return child_node->ResolvePrefixes(art, r_node); -} - -void MergePrefixesDiffer(ART &art, reference &l_node, reference &r_node, idx_t &mismatch_position) { - - // create a new node and insert both nodes as children - - Node l_child; - auto l_byte = Prefix::GetByte(art, l_node, mismatch_position); - Prefix::Split(art, l_node, l_child, mismatch_position); - Node4::New(art, l_node); - - // insert children - Node4::InsertChild(art, l_node, l_byte, l_child); - auto r_byte = Prefix::GetByte(art, r_node, mismatch_position); - Prefix::Reduce(art, r_node, mismatch_position); - Node4::InsertChild(art, l_node, r_byte, r_node); - - r_node.get().Clear(); -} - -bool Node::ResolvePrefixes(ART &art, Node &other) { - - // NOTE: we always merge into the left ART - - 
D_ASSERT(HasMetadata() && other.HasMetadata()); - - // case 1: both nodes have no prefix - if (GetType() != NType::PREFIX && other.GetType() != NType::PREFIX) { - return MergeInternal(art, other); - } - - reference l_node(*this); - reference r_node(other); - - idx_t mismatch_position = DConstants::INVALID_INDEX; - - // traverse prefixes - if (l_node.get().GetType() == NType::PREFIX && r_node.get().GetType() == NType::PREFIX) { - - if (!Prefix::Traverse(art, l_node, r_node, mismatch_position)) { - return false; - } - // we already recurse because the prefixes matched (so far) - if (mismatch_position == DConstants::INVALID_INDEX) { - return true; - } - - } else { - - // l_prefix contains r_prefix - if (l_node.get().GetType() == NType::PREFIX) { - swap(*this, other); - } - mismatch_position = 0; - } - D_ASSERT(mismatch_position != DConstants::INVALID_INDEX); - - // case 2: one prefix contains the other prefix - if (l_node.get().GetType() != NType::PREFIX && r_node.get().GetType() == NType::PREFIX) { - return MergePrefixContainsOtherPrefix(art, l_node, r_node, mismatch_position); - } - - // case 3: prefixes differ at a specific byte - MergePrefixesDiffer(art, l_node, r_node, mismatch_position); - return true; -} - -bool Node::MergeInternal(ART &art, Node &other) { - - D_ASSERT(HasMetadata() && other.HasMetadata()); - D_ASSERT(GetType() != NType::PREFIX && other.GetType() != NType::PREFIX); - - // always try to merge the smaller node into the bigger node - // because maybe there is enough free space in the bigger node to fit the smaller one - // without too much recursion - if (GetType() < other.GetType()) { - swap(*this, other); - } - - Node empty_node; - auto &l_node = *this; - auto &r_node = other; - - if (r_node.GetType() == NType::LEAF || r_node.GetType() == NType::LEAF_INLINED) { - D_ASSERT(l_node.GetType() == NType::LEAF || l_node.GetType() == NType::LEAF_INLINED); - - if (art.IsUnique()) { - return false; - } - - Leaf::Merge(art, l_node, r_node); - return true; 
- } - - uint8_t byte = 0; - auto r_child = r_node.GetNextChildMutable(art, byte); - - // while r_node still has children to merge - while (r_child) { - auto l_child = l_node.GetChildMutable(art, byte); - if (!l_child) { - // insert child at empty byte - InsertChild(art, l_node, byte, *r_child); - r_node.ReplaceChild(art, byte, empty_node); - - } else { - // recurse - if (!l_child->ResolvePrefixes(art, *r_child)) { - return false; - } - } - - if (byte == NumericLimits::Maximum()) { - break; - } - byte++; - r_child = r_node.GetNextChildMutable(art, byte); - } - - Free(art, r_node); - return true; -} - -//===--------------------------------------------------------------------===// -// Vacuum -//===--------------------------------------------------------------------===// - -void Node::Vacuum(ART &art, const ARTFlags &flags) { - - D_ASSERT(HasMetadata()); - - auto node_type = GetType(); - auto node_type_idx = static_cast(node_type); - - // iterative functions - if (node_type == NType::PREFIX) { - return Prefix::Vacuum(art, *this, flags); - } - if (node_type == NType::LEAF_INLINED) { - return; - } - if (node_type == NType::LEAF) { - if (flags.vacuum_flags[node_type_idx - 1]) { - Leaf::Vacuum(art, *this); - } - return; - } - - auto &allocator = GetAllocator(art, node_type); - auto needs_vacuum = flags.vacuum_flags[node_type_idx - 1] && allocator.NeedsVacuum(*this); - if (needs_vacuum) { - *this = allocator.VacuumPointer(*this); - SetMetadata(node_type_idx); - } - - // recursive functions - switch (node_type) { - case NType::NODE_4: - return RefMutable(art, *this, NType::NODE_4).Vacuum(art, flags); - case NType::NODE_16: - return RefMutable(art, *this, NType::NODE_16).Vacuum(art, flags); - case NType::NODE_48: - return RefMutable(art, *this, NType::NODE_48).Vacuum(art, flags); - case NType::NODE_256: - return RefMutable(art, *this, NType::NODE_256).Vacuum(art, flags); - default: - throw InternalException("Invalid node type for Vacuum."); - } -} - -} // namespace duckdb - - 
- - - -namespace duckdb { - -Node16 &Node16::New(ART &art, Node &node) { - - node = Node::GetAllocator(art, NType::NODE_16).New(); - node.SetMetadata(static_cast(NType::NODE_16)); - auto &n16 = Node::RefMutable(art, node, NType::NODE_16); - - n16.count = 0; - return n16; -} - -void Node16::Free(ART &art, Node &node) { - - D_ASSERT(node.HasMetadata()); - auto &n16 = Node::RefMutable(art, node, NType::NODE_16); - - // free all children - for (idx_t i = 0; i < n16.count; i++) { - Node::Free(art, n16.children[i]); - } -} - -Node16 &Node16::GrowNode4(ART &art, Node &node16, Node &node4) { - - auto &n4 = Node::RefMutable(art, node4, NType::NODE_4); - auto &n16 = New(art, node16); - - n16.count = n4.count; - for (idx_t i = 0; i < n4.count; i++) { - n16.key[i] = n4.key[i]; - n16.children[i] = n4.children[i]; - } - - n4.count = 0; - Node::Free(art, node4); - return n16; -} - -Node16 &Node16::ShrinkNode48(ART &art, Node &node16, Node &node48) { - - auto &n16 = New(art, node16); - auto &n48 = Node::RefMutable(art, node48, NType::NODE_48); - - n16.count = 0; - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - D_ASSERT(n16.count <= Node::NODE_16_CAPACITY); - if (n48.child_index[i] != Node::EMPTY_MARKER) { - n16.key[n16.count] = i; - n16.children[n16.count] = n48.children[n48.child_index[i]]; - n16.count++; - } - } - - n48.count = 0; - Node::Free(art, node48); - return n16; -} - -void Node16::InitializeMerge(ART &art, const ARTFlags &flags) { - - for (idx_t i = 0; i < count; i++) { - children[i].InitializeMerge(art, flags); - } -} - -void Node16::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - - D_ASSERT(node.HasMetadata()); - auto &n16 = Node::RefMutable(art, node, NType::NODE_16); - - // ensure that there is no other child at the same byte - for (idx_t i = 0; i < n16.count; i++) { - D_ASSERT(n16.key[i] != byte); - } - - // insert new child node into node - if (n16.count < Node::NODE_16_CAPACITY) { - // still space, just insert the child - 
idx_t child_pos = 0; - while (child_pos < n16.count && n16.key[child_pos] < byte) { - child_pos++; - } - // move children backwards to make space - for (idx_t i = n16.count; i > child_pos; i--) { - n16.key[i] = n16.key[i - 1]; - n16.children[i] = n16.children[i - 1]; - } - - n16.key[child_pos] = byte; - n16.children[child_pos] = child; - n16.count++; - - } else { - // node is full, grow to Node48 - auto node16 = node; - Node48::GrowNode16(art, node, node16); - Node48::InsertChild(art, node, byte, child); - } -} - -void Node16::DeleteChild(ART &art, Node &node, const uint8_t byte) { - - D_ASSERT(node.HasMetadata()); - auto &n16 = Node::RefMutable(art, node, NType::NODE_16); - - idx_t child_pos = 0; - for (; child_pos < n16.count; child_pos++) { - if (n16.key[child_pos] == byte) { - break; - } - } - - D_ASSERT(child_pos < n16.count); - - // free the child and decrease the count - Node::Free(art, n16.children[child_pos]); - n16.count--; - - // potentially move any children backwards - for (idx_t i = child_pos; i < n16.count; i++) { - n16.key[i] = n16.key[i + 1]; - n16.children[i] = n16.children[i + 1]; - } - - // shrink node to Node4 - if (n16.count < Node::NODE_4_CAPACITY) { - auto node16 = node; - Node4::ShrinkNode16(art, node, node16); - } -} - -void Node16::ReplaceChild(const uint8_t byte, const Node child) { - for (idx_t i = 0; i < count; i++) { - if (key[i] == byte) { - children[i] = child; - return; - } - } -} - -optional_ptr Node16::GetChild(const uint8_t byte) const { - for (idx_t i = 0; i < count; i++) { - if (key[i] == byte) { - D_ASSERT(children[i].HasMetadata()); - return &children[i]; - } - } - return nullptr; -} - -optional_ptr Node16::GetChildMutable(const uint8_t byte) { - for (idx_t i = 0; i < count; i++) { - if (key[i] == byte) { - D_ASSERT(children[i].HasMetadata()); - return &children[i]; - } - } - return nullptr; -} - -optional_ptr Node16::GetNextChild(uint8_t &byte) const { - for (idx_t i = 0; i < count; i++) { - if (key[i] >= byte) { - byte = 
key[i]; - D_ASSERT(children[i].HasMetadata()); - return &children[i]; - } - } - return nullptr; -} - -optional_ptr Node16::GetNextChildMutable(uint8_t &byte) { - for (idx_t i = 0; i < count; i++) { - if (key[i] >= byte) { - byte = key[i]; - D_ASSERT(children[i].HasMetadata()); - return &children[i]; - } - } - return nullptr; -} - -void Node16::Vacuum(ART &art, const ARTFlags &flags) { - - for (idx_t i = 0; i < count; i++) { - children[i].Vacuum(art, flags); - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -Node256 &Node256::New(ART &art, Node &node) { - - node = Node::GetAllocator(art, NType::NODE_256).New(); - node.SetMetadata(static_cast(NType::NODE_256)); - auto &n256 = Node::RefMutable(art, node, NType::NODE_256); - - n256.count = 0; - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - n256.children[i].Clear(); - } - - return n256; -} - -void Node256::Free(ART &art, Node &node) { - - D_ASSERT(node.HasMetadata()); - auto &n256 = Node::RefMutable(art, node, NType::NODE_256); - - if (!n256.count) { - return; - } - - // free all children - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (n256.children[i].HasMetadata()) { - Node::Free(art, n256.children[i]); - } - } -} - -Node256 &Node256::GrowNode48(ART &art, Node &node256, Node &node48) { - - auto &n48 = Node::RefMutable(art, node48, NType::NODE_48); - auto &n256 = New(art, node256); - - n256.count = n48.count; - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (n48.child_index[i] != Node::EMPTY_MARKER) { - n256.children[i] = n48.children[n48.child_index[i]]; - } else { - n256.children[i].Clear(); - } - } - - n48.count = 0; - Node::Free(art, node48); - return n256; -} - -void Node256::InitializeMerge(ART &art, const ARTFlags &flags) { - - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (children[i].HasMetadata()) { - children[i].InitializeMerge(art, flags); - } - } -} - -void Node256::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - - 
D_ASSERT(node.HasMetadata()); - auto &n256 = Node::RefMutable(art, node, NType::NODE_256); - - // ensure that there is no other child at the same byte - D_ASSERT(!n256.children[byte].HasMetadata()); - - n256.count++; - D_ASSERT(n256.count <= Node::NODE_256_CAPACITY); - n256.children[byte] = child; -} - -void Node256::DeleteChild(ART &art, Node &node, const uint8_t byte) { - - D_ASSERT(node.HasMetadata()); - auto &n256 = Node::RefMutable(art, node, NType::NODE_256); - - // free the child and decrease the count - Node::Free(art, n256.children[byte]); - n256.count--; - - // shrink node to Node48 - if (n256.count <= Node::NODE_256_SHRINK_THRESHOLD) { - auto node256 = node; - Node48::ShrinkNode256(art, node, node256); - } -} - -optional_ptr Node256::GetChild(const uint8_t byte) const { - if (children[byte].HasMetadata()) { - return &children[byte]; - } - return nullptr; -} - -optional_ptr Node256::GetChildMutable(const uint8_t byte) { - if (children[byte].HasMetadata()) { - return &children[byte]; - } - return nullptr; -} - -optional_ptr Node256::GetNextChild(uint8_t &byte) const { - for (idx_t i = byte; i < Node::NODE_256_CAPACITY; i++) { - if (children[i].HasMetadata()) { - byte = i; - return &children[i]; - } - } - return nullptr; -} - -optional_ptr Node256::GetNextChildMutable(uint8_t &byte) { - for (idx_t i = byte; i < Node::NODE_256_CAPACITY; i++) { - if (children[i].HasMetadata()) { - byte = i; - return &children[i]; - } - } - return nullptr; -} - -void Node256::Vacuum(ART &art, const ARTFlags &flags) { - - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (children[i].HasMetadata()) { - children[i].Vacuum(art, flags); - } - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -Node4 &Node4::New(ART &art, Node &node) { - - node = Node::GetAllocator(art, NType::NODE_4).New(); - node.SetMetadata(static_cast(NType::NODE_4)); - auto &n4 = Node::RefMutable(art, node, NType::NODE_4); - - n4.count = 0; - return n4; -} - -void Node4::Free(ART &art, Node 
&node) { - - D_ASSERT(node.HasMetadata()); - auto &n4 = Node::RefMutable(art, node, NType::NODE_4); - - // free all children - for (idx_t i = 0; i < n4.count; i++) { - Node::Free(art, n4.children[i]); - } -} - -Node4 &Node4::ShrinkNode16(ART &art, Node &node4, Node &node16) { - - auto &n4 = New(art, node4); - auto &n16 = Node::RefMutable(art, node16, NType::NODE_16); - - D_ASSERT(n16.count <= Node::NODE_4_CAPACITY); - n4.count = n16.count; - for (idx_t i = 0; i < n16.count; i++) { - n4.key[i] = n16.key[i]; - n4.children[i] = n16.children[i]; - } - - n16.count = 0; - Node::Free(art, node16); - return n4; -} - -void Node4::InitializeMerge(ART &art, const ARTFlags &flags) { - - for (idx_t i = 0; i < count; i++) { - children[i].InitializeMerge(art, flags); - } -} - -void Node4::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - - D_ASSERT(node.HasMetadata()); - auto &n4 = Node::RefMutable(art, node, NType::NODE_4); - - // ensure that there is no other child at the same byte - for (idx_t i = 0; i < n4.count; i++) { - D_ASSERT(n4.key[i] != byte); - } - - // insert new child node into node - if (n4.count < Node::NODE_4_CAPACITY) { - // still space, just insert the child - idx_t child_pos = 0; - while (child_pos < n4.count && n4.key[child_pos] < byte) { - child_pos++; - } - // move children backwards to make space - for (idx_t i = n4.count; i > child_pos; i--) { - n4.key[i] = n4.key[i - 1]; - n4.children[i] = n4.children[i - 1]; - } - - n4.key[child_pos] = byte; - n4.children[child_pos] = child; - n4.count++; - - } else { - // node is full, grow to Node16 - auto node4 = node; - Node16::GrowNode4(art, node, node4); - Node16::InsertChild(art, node, byte, child); - } -} - -void Node4::DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte) { - - D_ASSERT(node.HasMetadata()); - auto &n4 = Node::RefMutable(art, node, NType::NODE_4); - - idx_t child_pos = 0; - for (; child_pos < n4.count; child_pos++) { - if (n4.key[child_pos] == byte) { - 
break; - } - } - - D_ASSERT(child_pos < n4.count); - D_ASSERT(n4.count > 1); - - // free the child and decrease the count - Node::Free(art, n4.children[child_pos]); - n4.count--; - - // potentially move any children backwards - for (idx_t i = child_pos; i < n4.count; i++) { - n4.key[i] = n4.key[i + 1]; - n4.children[i] = n4.children[i + 1]; - } - - // this is a one way node, compress - if (n4.count == 1) { - - // we need to keep track of the old node pointer - // because Concatenate() might overwrite that pointer while appending bytes to - // the prefix (and by doing so overwriting the subsequent node with - // new prefix nodes) - auto old_n4_node = node; - - // get only child and concatenate prefixes - auto child = *n4.GetChildMutable(n4.key[0]); - Prefix::Concatenate(art, prefix, n4.key[0], child); - - n4.count--; - Node::Free(art, old_n4_node); - } -} - -void Node4::ReplaceChild(const uint8_t byte, const Node child) { - for (idx_t i = 0; i < count; i++) { - if (key[i] == byte) { - children[i] = child; - return; - } - } -} - -optional_ptr Node4::GetChild(const uint8_t byte) const { - for (idx_t i = 0; i < count; i++) { - if (key[i] == byte) { - D_ASSERT(children[i].HasMetadata()); - return &children[i]; - } - } - return nullptr; -} - -optional_ptr Node4::GetChildMutable(const uint8_t byte) { - for (idx_t i = 0; i < count; i++) { - if (key[i] == byte) { - D_ASSERT(children[i].HasMetadata()); - return &children[i]; - } - } - return nullptr; -} - -optional_ptr Node4::GetNextChild(uint8_t &byte) const { - for (idx_t i = 0; i < count; i++) { - if (key[i] >= byte) { - byte = key[i]; - D_ASSERT(children[i].HasMetadata()); - return &children[i]; - } - } - return nullptr; -} - -optional_ptr Node4::GetNextChildMutable(uint8_t &byte) { - for (idx_t i = 0; i < count; i++) { - if (key[i] >= byte) { - byte = key[i]; - D_ASSERT(children[i].HasMetadata()); - return &children[i]; - } - } - return nullptr; -} - -void Node4::Vacuum(ART &art, const ARTFlags &flags) { - - for (idx_t 
i = 0; i < count; i++) { - children[i].Vacuum(art, flags); - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -Node48 &Node48::New(ART &art, Node &node) { - - node = Node::GetAllocator(art, NType::NODE_48).New(); - node.SetMetadata(static_cast(NType::NODE_48)); - auto &n48 = Node::RefMutable(art, node, NType::NODE_48); - - n48.count = 0; - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - n48.child_index[i] = Node::EMPTY_MARKER; - } - for (idx_t i = 0; i < Node::NODE_48_CAPACITY; i++) { - n48.children[i].Clear(); - } - - return n48; -} - -void Node48::Free(ART &art, Node &node) { - - D_ASSERT(node.HasMetadata()); - auto &n48 = Node::RefMutable(art, node, NType::NODE_48); - - if (!n48.count) { - return; - } - - // free all children - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (n48.child_index[i] != Node::EMPTY_MARKER) { - Node::Free(art, n48.children[n48.child_index[i]]); - } - } -} - -Node48 &Node48::GrowNode16(ART &art, Node &node48, Node &node16) { - - auto &n16 = Node::RefMutable(art, node16, NType::NODE_16); - auto &n48 = New(art, node48); - - n48.count = n16.count; - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - n48.child_index[i] = Node::EMPTY_MARKER; - } - - for (idx_t i = 0; i < n16.count; i++) { - n48.child_index[n16.key[i]] = i; - n48.children[i] = n16.children[i]; - } - - // necessary for faster child insertion/deletion - for (idx_t i = n16.count; i < Node::NODE_48_CAPACITY; i++) { - n48.children[i].Clear(); - } - - n16.count = 0; - Node::Free(art, node16); - return n48; -} - -Node48 &Node48::ShrinkNode256(ART &art, Node &node48, Node &node256) { - - auto &n48 = New(art, node48); - auto &n256 = Node::RefMutable(art, node256, NType::NODE_256); - - n48.count = 0; - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - D_ASSERT(n48.count <= Node::NODE_48_CAPACITY); - if (n256.children[i].HasMetadata()) { - n48.child_index[i] = n48.count; - n48.children[n48.count] = n256.children[i]; - n48.count++; - } else { - 
n48.child_index[i] = Node::EMPTY_MARKER; - } - } - - // necessary for faster child insertion/deletion - for (idx_t i = n48.count; i < Node::NODE_48_CAPACITY; i++) { - n48.children[i].Clear(); - } - - n256.count = 0; - Node::Free(art, node256); - return n48; -} - -void Node48::InitializeMerge(ART &art, const ARTFlags &flags) { - - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (child_index[i] != Node::EMPTY_MARKER) { - children[child_index[i]].InitializeMerge(art, flags); - } - } -} - -void Node48::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - - D_ASSERT(node.HasMetadata()); - auto &n48 = Node::RefMutable(art, node, NType::NODE_48); - - // ensure that there is no other child at the same byte - D_ASSERT(n48.child_index[byte] == Node::EMPTY_MARKER); - - // insert new child node into node - if (n48.count < Node::NODE_48_CAPACITY) { - // still space, just insert the child - idx_t child_pos = n48.count; - if (n48.children[child_pos].HasMetadata()) { - // find an empty position in the node list if the current position is occupied - child_pos = 0; - while (n48.children[child_pos].HasMetadata()) { - child_pos++; - } - } - n48.children[child_pos] = child; - n48.child_index[byte] = child_pos; - n48.count++; - - } else { - // node is full, grow to Node256 - auto node48 = node; - Node256::GrowNode48(art, node, node48); - Node256::InsertChild(art, node, byte, child); - } -} - -void Node48::DeleteChild(ART &art, Node &node, const uint8_t byte) { - - D_ASSERT(node.HasMetadata()); - auto &n48 = Node::RefMutable(art, node, NType::NODE_48); - - // free the child and decrease the count - Node::Free(art, n48.children[n48.child_index[byte]]); - n48.child_index[byte] = Node::EMPTY_MARKER; - n48.count--; - - // shrink node to Node16 - if (n48.count < Node::NODE_48_SHRINK_THRESHOLD) { - auto node48 = node; - Node16::ShrinkNode48(art, node, node48); - } -} - -optional_ptr Node48::GetChild(const uint8_t byte) const { - if (child_index[byte] != 
Node::EMPTY_MARKER) { - D_ASSERT(children[child_index[byte]].HasMetadata()); - return &children[child_index[byte]]; - } - return nullptr; -} - -optional_ptr Node48::GetChildMutable(const uint8_t byte) { - if (child_index[byte] != Node::EMPTY_MARKER) { - D_ASSERT(children[child_index[byte]].HasMetadata()); - return &children[child_index[byte]]; - } - return nullptr; -} - -optional_ptr Node48::GetNextChild(uint8_t &byte) const { - for (idx_t i = byte; i < Node::NODE_256_CAPACITY; i++) { - if (child_index[i] != Node::EMPTY_MARKER) { - byte = i; - D_ASSERT(children[child_index[i]].HasMetadata()); - return &children[child_index[i]]; - } - } - return nullptr; -} - -optional_ptr Node48::GetNextChildMutable(uint8_t &byte) { - for (idx_t i = byte; i < Node::NODE_256_CAPACITY; i++) { - if (child_index[i] != Node::EMPTY_MARKER) { - byte = i; - D_ASSERT(children[child_index[i]].HasMetadata()); - return &children[child_index[i]]; - } - } - return nullptr; -} - -void Node48::Vacuum(ART &art, const ARTFlags &flags) { - - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (child_index[i] != Node::EMPTY_MARKER) { - children[child_index[i]].Vacuum(art, flags); - } - } -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -Prefix &Prefix::New(ART &art, Node &node) { - - node = Node::GetAllocator(art, NType::PREFIX).New(); - node.SetMetadata(static_cast(NType::PREFIX)); - - auto &prefix = Node::RefMutable(art, node, NType::PREFIX); - prefix.data[Node::PREFIX_SIZE] = 0; - return prefix; -} - -Prefix &Prefix::New(ART &art, Node &node, uint8_t byte, const Node &next) { - - node = Node::GetAllocator(art, NType::PREFIX).New(); - node.SetMetadata(static_cast(NType::PREFIX)); - - auto &prefix = Node::RefMutable(art, node, NType::PREFIX); - prefix.data[Node::PREFIX_SIZE] = 1; - prefix.data[0] = byte; - prefix.ptr = next; - return prefix; -} - -void Prefix::New(ART &art, reference &node, const ARTKey &key, const uint32_t depth, uint32_t count) { - - if (count == 0) { - 
return; - } - idx_t copy_count = 0; - - while (count) { - node.get() = Node::GetAllocator(art, NType::PREFIX).New(); - node.get().SetMetadata(static_cast(NType::PREFIX)); - auto &prefix = Node::RefMutable(art, node, NType::PREFIX); - - auto this_count = MinValue((uint32_t)Node::PREFIX_SIZE, count); - prefix.data[Node::PREFIX_SIZE] = (uint8_t)this_count; - memcpy(prefix.data, key.data + depth + copy_count, this_count); - - node = prefix.ptr; - copy_count += this_count; - count -= this_count; - } -} - -void Prefix::Free(ART &art, Node &node) { - - Node current_node = node; - Node next_node; - while (current_node.HasMetadata() && current_node.GetType() == NType::PREFIX) { - next_node = Node::RefMutable(art, current_node, NType::PREFIX).ptr; - Node::GetAllocator(art, NType::PREFIX).Free(current_node); - current_node = next_node; - } - - Node::Free(art, current_node); - node.Clear(); -} - -void Prefix::InitializeMerge(ART &art, Node &node, const ARTFlags &flags) { - - auto merge_buffer_count = flags.merge_buffer_counts[static_cast(NType::PREFIX) - 1]; - - Node next_node = node; - reference prefix = Node::RefMutable(art, next_node, NType::PREFIX); - - while (next_node.GetType() == NType::PREFIX) { - next_node = prefix.get().ptr; - if (prefix.get().ptr.GetType() == NType::PREFIX) { - prefix.get().ptr.IncreaseBufferId(merge_buffer_count); - prefix = Node::RefMutable(art, next_node, NType::PREFIX); - } - } - - node.IncreaseBufferId(merge_buffer_count); - prefix.get().ptr.InitializeMerge(art, flags); -} - -void Prefix::Concatenate(ART &art, Node &prefix_node, const uint8_t byte, Node &child_prefix_node) { - - D_ASSERT(prefix_node.HasMetadata() && child_prefix_node.HasMetadata()); - - // append a byte and a child_prefix to prefix - if (prefix_node.GetType() == NType::PREFIX) { - - // get the tail - reference prefix = Node::RefMutable(art, prefix_node, NType::PREFIX); - D_ASSERT(prefix.get().ptr.HasMetadata()); - - while (prefix.get().ptr.GetType() == NType::PREFIX) { - prefix 
= Node::RefMutable(art, prefix.get().ptr, NType::PREFIX); - D_ASSERT(prefix.get().ptr.HasMetadata()); - } - - // append the byte - prefix = prefix.get().Append(art, byte); - - if (child_prefix_node.GetType() == NType::PREFIX) { - // append the child prefix - prefix.get().Append(art, child_prefix_node); - } else { - // set child_prefix_node to succeed prefix - prefix.get().ptr = child_prefix_node; - } - return; - } - - // create a new prefix node containing the byte, then append the child_prefix to it - if (prefix_node.GetType() != NType::PREFIX && child_prefix_node.GetType() == NType::PREFIX) { - - auto child_prefix = child_prefix_node; - auto &prefix = New(art, prefix_node, byte); - prefix.Append(art, child_prefix); - return; - } - - // neither prefix nor child_prefix are prefix nodes - // create a new prefix containing the byte - New(art, prefix_node, byte, child_prefix_node); -} - -idx_t Prefix::Traverse(ART &art, reference &prefix_node, const ARTKey &key, idx_t &depth) { - - D_ASSERT(prefix_node.get().HasMetadata()); - D_ASSERT(prefix_node.get().GetType() == NType::PREFIX); - - // compare prefix nodes to key bytes - while (prefix_node.get().GetType() == NType::PREFIX) { - auto &prefix = Node::Ref(art, prefix_node, NType::PREFIX); - for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { - if (prefix.data[i] != key[depth]) { - return i; - } - depth++; - } - prefix_node = prefix.ptr; - D_ASSERT(prefix_node.get().HasMetadata()); - } - - return DConstants::INVALID_INDEX; -} - -idx_t Prefix::TraverseMutable(ART &art, reference &prefix_node, const ARTKey &key, idx_t &depth) { - - D_ASSERT(prefix_node.get().HasMetadata()); - D_ASSERT(prefix_node.get().GetType() == NType::PREFIX); - - // compare prefix nodes to key bytes - while (prefix_node.get().GetType() == NType::PREFIX) { - auto &prefix = Node::RefMutable(art, prefix_node, NType::PREFIX); - for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { - if (prefix.data[i] != key[depth]) { - return i; - } - 
depth++; - } - prefix_node = prefix.ptr; - D_ASSERT(prefix_node.get().HasMetadata()); - } - - return DConstants::INVALID_INDEX; -} - -bool Prefix::Traverse(ART &art, reference &l_node, reference &r_node, idx_t &mismatch_position) { - - auto &l_prefix = Node::RefMutable(art, l_node.get(), NType::PREFIX); - auto &r_prefix = Node::RefMutable(art, r_node.get(), NType::PREFIX); - - // compare prefix bytes - idx_t max_count = MinValue(l_prefix.data[Node::PREFIX_SIZE], r_prefix.data[Node::PREFIX_SIZE]); - for (idx_t i = 0; i < max_count; i++) { - if (l_prefix.data[i] != r_prefix.data[i]) { - mismatch_position = i; - break; - } - } - - if (mismatch_position == DConstants::INVALID_INDEX) { - - // prefixes match (so far) - if (l_prefix.data[Node::PREFIX_SIZE] == r_prefix.data[Node::PREFIX_SIZE]) { - return l_prefix.ptr.ResolvePrefixes(art, r_prefix.ptr); - } - - mismatch_position = max_count; - - // l_prefix contains r_prefix - if (r_prefix.ptr.GetType() != NType::PREFIX && r_prefix.data[Node::PREFIX_SIZE] == max_count) { - swap(l_node.get(), r_node.get()); - l_node = r_prefix.ptr; - - } else { - // r_prefix contains l_prefix - l_node = l_prefix.ptr; - } - } - - return true; -} - -void Prefix::Reduce(ART &art, Node &prefix_node, const idx_t n) { - - D_ASSERT(prefix_node.HasMetadata()); - D_ASSERT(n < Node::PREFIX_SIZE); - - reference prefix = Node::RefMutable(art, prefix_node, NType::PREFIX); - - // free this prefix node - if (n == (idx_t)(prefix.get().data[Node::PREFIX_SIZE] - 1)) { - auto next_ptr = prefix.get().ptr; - D_ASSERT(next_ptr.HasMetadata()); - prefix.get().ptr.Clear(); - Node::Free(art, prefix_node); - prefix_node = next_ptr; - return; - } - - // shift by n bytes in the current prefix - for (idx_t i = 0; i < Node::PREFIX_SIZE - n - 1; i++) { - prefix.get().data[i] = prefix.get().data[n + i + 1]; - } - D_ASSERT(n < (idx_t)(prefix.get().data[Node::PREFIX_SIZE] - 1)); - prefix.get().data[Node::PREFIX_SIZE] -= n + 1; - - // append the remaining prefix bytes - 
prefix.get().Append(art, prefix.get().ptr); -} - -void Prefix::Split(ART &art, reference &prefix_node, Node &child_node, idx_t position) { - - D_ASSERT(prefix_node.get().HasMetadata()); - - auto &prefix = Node::RefMutable(art, prefix_node, NType::PREFIX); - - // the split is at the last byte of this prefix, so the child_node contains all subsequent - // prefix nodes (prefix.ptr) (if any), and the count of this prefix decreases by one, - // then, we reference prefix.ptr, to overwrite it with a new node later - if (position + 1 == Node::PREFIX_SIZE) { - prefix.data[Node::PREFIX_SIZE]--; - prefix_node = prefix.ptr; - child_node = prefix.ptr; - return; - } - - // append the remaining bytes after the split - if (position + 1 < prefix.data[Node::PREFIX_SIZE]) { - reference child_prefix = New(art, child_node); - for (idx_t i = position + 1; i < prefix.data[Node::PREFIX_SIZE]; i++) { - child_prefix = child_prefix.get().Append(art, prefix.data[i]); - } - - D_ASSERT(prefix.ptr.HasMetadata()); - - if (prefix.ptr.GetType() == NType::PREFIX) { - child_prefix.get().Append(art, prefix.ptr); - } else { - // this is the last prefix node of the prefix - child_prefix.get().ptr = prefix.ptr; - } - } - - // this is the last prefix node of the prefix - if (position + 1 == prefix.data[Node::PREFIX_SIZE]) { - child_node = prefix.ptr; - } - - // set the new size of this node - prefix.data[Node::PREFIX_SIZE] = position; - - // no bytes left before the split, free this node - if (position == 0) { - prefix.ptr.Clear(); - Node::Free(art, prefix_node.get()); - return; - } - - // bytes left before the split, reference subsequent node - prefix_node = prefix.ptr; - return; -} - -string Prefix::VerifyAndToString(ART &art, const Node &node, const bool only_verify) { - - // NOTE: we could do this recursively, but the function-call overhead can become kinda crazy - string str = ""; - - reference node_ref(node); - while (node_ref.get().GetType() == NType::PREFIX) { - - auto &prefix = Node::Ref(art, 
node_ref, NType::PREFIX); - D_ASSERT(prefix.data[Node::PREFIX_SIZE] != 0); - D_ASSERT(prefix.data[Node::PREFIX_SIZE] <= Node::PREFIX_SIZE); - - str += " prefix_bytes:["; - for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { - str += to_string(prefix.data[i]) + "-"; - } - str += "] "; - - node_ref = prefix.ptr; - } - - auto subtree = node_ref.get().VerifyAndToString(art, only_verify); - return only_verify ? "" : str + subtree; -} - -void Prefix::Vacuum(ART &art, Node &node, const ARTFlags &flags) { - - bool flag_set = flags.vacuum_flags[static_cast(NType::PREFIX) - 1]; - auto &allocator = Node::GetAllocator(art, NType::PREFIX); - - reference node_ref(node); - while (node_ref.get().GetType() == NType::PREFIX) { - if (flag_set && allocator.NeedsVacuum(node_ref)) { - node_ref.get() = allocator.VacuumPointer(node_ref); - node_ref.get().SetMetadata(static_cast(NType::PREFIX)); - } - auto &prefix = Node::RefMutable(art, node_ref, NType::PREFIX); - node_ref = prefix.ptr; - } - - node_ref.get().Vacuum(art, flags); -} - -Prefix &Prefix::Append(ART &art, const uint8_t byte) { - - reference prefix(*this); - - // we need a new prefix node - if (prefix.get().data[Node::PREFIX_SIZE] == Node::PREFIX_SIZE) { - prefix = New(art, prefix.get().ptr); - } - - prefix.get().data[prefix.get().data[Node::PREFIX_SIZE]] = byte; - prefix.get().data[Node::PREFIX_SIZE]++; - return prefix.get(); -} - -void Prefix::Append(ART &art, Node other_prefix) { - - D_ASSERT(other_prefix.HasMetadata()); - - reference prefix(*this); - while (other_prefix.GetType() == NType::PREFIX) { - - // copy prefix bytes - auto &other = Node::RefMutable(art, other_prefix, NType::PREFIX); - for (idx_t i = 0; i < other.data[Node::PREFIX_SIZE]; i++) { - prefix = prefix.get().Append(art, other.data[i]); - } - - D_ASSERT(other.ptr.HasMetadata()); - - prefix.get().ptr = other.ptr; - Node::GetAllocator(art, NType::PREFIX).Free(other_prefix); - other_prefix = prefix.get().ptr; - } - - D_ASSERT(prefix.get().ptr.GetType() 
!= NType::PREFIX); -} - -} // namespace duckdb - - - - -namespace duckdb { - -FixedSizeAllocator::FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager) - : block_manager(block_manager), buffer_manager(block_manager.buffer_manager), - metadata_manager(block_manager.GetMetadataManager()), segment_size(segment_size), total_segment_count(0) { - - if (segment_size > Storage::BLOCK_SIZE - sizeof(validity_t)) { - throw InternalException("The maximum segment size of fixed-size allocators is " + - to_string(Storage::BLOCK_SIZE - sizeof(validity_t))); - } - - // calculate how many segments fit into one buffer (available_segments_per_buffer) - - idx_t bits_per_value = sizeof(validity_t) * 8; - idx_t byte_count = 0; - - bitmask_count = 0; - available_segments_per_buffer = 0; - - while (byte_count < Storage::BLOCK_SIZE) { - if (!bitmask_count || (bitmask_count * bits_per_value) % available_segments_per_buffer == 0) { - // we need to add another validity_t value to the bitmask, to allow storing another - // bits_per_value segments on a buffer - bitmask_count++; - byte_count += sizeof(validity_t); - } - - auto remaining_bytes = Storage::BLOCK_SIZE - byte_count; - auto remaining_segments = MinValue(remaining_bytes / segment_size, bits_per_value); - - if (remaining_segments == 0) { - break; - } - - available_segments_per_buffer += remaining_segments; - byte_count += remaining_segments * segment_size; - } - - bitmask_offset = bitmask_count * sizeof(validity_t); -} - -IndexPointer FixedSizeAllocator::New() { - - // no more segments available - if (buffers_with_free_space.empty()) { - - // add a new buffer - auto buffer_id = GetAvailableBufferId(); - FixedSizeBuffer new_buffer(block_manager); - buffers.insert(make_pair(buffer_id, std::move(new_buffer))); - buffers_with_free_space.insert(buffer_id); - - // set the bitmask - D_ASSERT(buffers.find(buffer_id) != buffers.end()); - auto &buffer = buffers.find(buffer_id)->second; - ValidityMask 
mask(reinterpret_cast(buffer.Get())); - - // zero-initialize the bitmask to avoid leaking memory to disk - auto data = mask.GetData(); - for (idx_t i = 0; i < bitmask_count; i++) { - data[i] = 0; - } - - // initializing the bitmask of the new buffer - mask.SetAllValid(available_segments_per_buffer); - } - - // return a pointer to a free segment - D_ASSERT(!buffers_with_free_space.empty()); - auto buffer_id = uint32_t(*buffers_with_free_space.begin()); - - D_ASSERT(buffers.find(buffer_id) != buffers.end()); - auto &buffer = buffers.find(buffer_id)->second; - auto offset = buffer.GetOffset(bitmask_count); - - total_segment_count++; - buffer.segment_count++; - if (buffer.segment_count == available_segments_per_buffer) { - buffers_with_free_space.erase(buffer_id); - } - - // zero-initialize that segment - auto buffer_ptr = buffer.Get(); - auto offset_in_buffer = buffer_ptr + offset * segment_size + bitmask_offset; - memset(offset_in_buffer, 0, segment_size); - - return IndexPointer(buffer_id, offset); -} - -void FixedSizeAllocator::Free(const IndexPointer ptr) { - - auto buffer_id = ptr.GetBufferId(); - auto offset = ptr.GetOffset(); - - D_ASSERT(buffers.find(buffer_id) != buffers.end()); - auto &buffer = buffers.find(buffer_id)->second; - - auto bitmask_ptr = reinterpret_cast(buffer.Get()); - ValidityMask mask(bitmask_ptr); - D_ASSERT(!mask.RowIsValid(offset)); - mask.SetValid(offset); - - D_ASSERT(total_segment_count > 0); - D_ASSERT(buffer.segment_count > 0); - - // adjust the allocator fields - buffers_with_free_space.insert(buffer_id); - total_segment_count--; - buffer.segment_count--; -} - -void FixedSizeAllocator::Reset() { - for (auto &buffer : buffers) { - buffer.second.Destroy(); - } - buffers.clear(); - buffers_with_free_space.clear(); - total_segment_count = 0; -} - -idx_t FixedSizeAllocator::GetMemoryUsage() const { - idx_t memory_usage = 0; - for (auto &buffer : buffers) { - if (buffer.second.InMemory()) { - memory_usage += Storage::BLOCK_SIZE; - } - } - 
return memory_usage; -} - -idx_t FixedSizeAllocator::GetUpperBoundBufferId() const { - idx_t upper_bound_id = 0; - for (auto &buffer : buffers) { - if (buffer.first >= upper_bound_id) { - upper_bound_id = buffer.first + 1; - } - } - return upper_bound_id; -} - -void FixedSizeAllocator::Merge(FixedSizeAllocator &other) { - - D_ASSERT(segment_size == other.segment_size); - - // remember the buffer count and merge the buffers - idx_t upper_bound_id = GetUpperBoundBufferId(); - for (auto &buffer : other.buffers) { - buffers.insert(make_pair(buffer.first + upper_bound_id, std::move(buffer.second))); - } - other.buffers.clear(); - - // merge the buffers with free spaces - for (auto &buffer_id : other.buffers_with_free_space) { - buffers_with_free_space.insert(buffer_id + upper_bound_id); - } - other.buffers_with_free_space.clear(); - - // add the total allocations - total_segment_count += other.total_segment_count; -} - -bool FixedSizeAllocator::InitializeVacuum() { - - // NOTE: we do not vacuum buffers that are not in memory. 
We might consider changing this - // in the future, although buffers on disk should almost never be eligible for a vacuum - - if (total_segment_count == 0) { - Reset(); - return false; - } - - // remove all empty buffers - auto buffer_it = buffers.begin(); - while (buffer_it != buffers.end()) { - if (!buffer_it->second.segment_count) { - buffers_with_free_space.erase(buffer_it->first); - buffer_it->second.Destroy(); - buffer_it = buffers.erase(buffer_it); - } else { - buffer_it++; - } - } - - // determine if a vacuum is necessary - multimap temporary_vacuum_buffers; - D_ASSERT(vacuum_buffers.empty()); - idx_t available_segments_in_memory = 0; - - for (auto &buffer : buffers) { - buffer.second.vacuum = false; - if (buffer.second.InMemory()) { - auto available_segments_in_buffer = available_segments_per_buffer - buffer.second.segment_count; - available_segments_in_memory += available_segments_in_buffer; - temporary_vacuum_buffers.emplace(available_segments_in_buffer, buffer.first); - } - } - - // no buffers in memory - if (temporary_vacuum_buffers.empty()) { - return false; - } - - auto excess_buffer_count = available_segments_in_memory / available_segments_per_buffer; - - // calculate the vacuum threshold adaptively - D_ASSERT(excess_buffer_count < temporary_vacuum_buffers.size()); - idx_t memory_usage = GetMemoryUsage(); - idx_t excess_memory_usage = excess_buffer_count * Storage::BLOCK_SIZE; - auto excess_percentage = double(excess_memory_usage) / double(memory_usage); - auto threshold = double(VACUUM_THRESHOLD) / 100.0; - if (excess_percentage < threshold) { - return false; - } - - D_ASSERT(excess_buffer_count <= temporary_vacuum_buffers.size()); - D_ASSERT(temporary_vacuum_buffers.size() <= buffers.size()); - - // erasing from a multimap, we vacuum the buffers with the most free spaces (least full) - while (temporary_vacuum_buffers.size() != excess_buffer_count) { - temporary_vacuum_buffers.erase(temporary_vacuum_buffers.begin()); - } - - // adjust the buffers, 
and erase all to-be-vacuumed buffers from the available buffer list - for (auto &vacuum_buffer : temporary_vacuum_buffers) { - auto buffer_id = vacuum_buffer.second; - D_ASSERT(buffers.find(buffer_id) != buffers.end()); - buffers.find(buffer_id)->second.vacuum = true; - buffers_with_free_space.erase(buffer_id); - } - - for (auto &vacuum_buffer : temporary_vacuum_buffers) { - vacuum_buffers.insert(vacuum_buffer.second); - } - - return true; -} - -void FixedSizeAllocator::FinalizeVacuum() { - - for (auto &buffer_id : vacuum_buffers) { - D_ASSERT(buffers.find(buffer_id) != buffers.end()); - auto &buffer = buffers.find(buffer_id)->second; - D_ASSERT(buffer.InMemory()); - buffer.Destroy(); - buffers.erase(buffer_id); - } - vacuum_buffers.clear(); -} - -IndexPointer FixedSizeAllocator::VacuumPointer(const IndexPointer ptr) { - - // we do not need to adjust the bitmask of the old buffer, because we will free the entire - // buffer after the vacuum operation - - auto new_ptr = New(); - // new increases the allocation count, we need to counter that here - total_segment_count--; - - memcpy(Get(new_ptr), Get(ptr), segment_size); - return new_ptr; -} - -BlockPointer FixedSizeAllocator::Serialize(PartialBlockManager &partial_block_manager, MetadataWriter &writer) { - - for (auto &buffer : buffers) { - buffer.second.Serialize(partial_block_manager, available_segments_per_buffer, segment_size, bitmask_offset); - } - - auto block_pointer = writer.GetBlockPointer(); - writer.Write(segment_size); - writer.Write(static_cast(buffers.size())); - writer.Write(static_cast(buffers_with_free_space.size())); - - for (auto &buffer : buffers) { - writer.Write(buffer.first); - writer.Write(buffer.second.block_pointer); - writer.Write(buffer.second.segment_count); - writer.Write(buffer.second.allocation_size); - } - for (auto &buffer_id : buffers_with_free_space) { - writer.Write(buffer_id); - } - - return block_pointer; -} - -void FixedSizeAllocator::Deserialize(const BlockPointer 
&block_pointer) { - - MetadataReader reader(metadata_manager, block_pointer); - segment_size = reader.Read(); - auto buffer_count = reader.Read(); - auto buffers_with_free_space_count = reader.Read(); - - total_segment_count = 0; - - for (idx_t i = 0; i < buffer_count; i++) { - auto buffer_id = reader.Read(); - auto buffer_block_pointer = reader.Read(); - auto segment_count = reader.Read(); - auto allocation_size = reader.Read(); - FixedSizeBuffer new_buffer(block_manager, segment_count, allocation_size, buffer_block_pointer); - buffers.insert(make_pair(buffer_id, std::move(new_buffer))); - total_segment_count += segment_count; - } - for (idx_t i = 0; i < buffers_with_free_space_count; i++) { - buffers_with_free_space.insert(reader.Read()); - } -} - -idx_t FixedSizeAllocator::GetAvailableBufferId() const { - idx_t buffer_id = buffers.size(); - while (buffers.find(buffer_id) != buffers.end()) { - D_ASSERT(buffer_id > 0); - buffer_id--; - } - return buffer_id; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// PartialBlockForIndex -//===--------------------------------------------------------------------===// - -PartialBlockForIndex::PartialBlockForIndex(PartialBlockState state, BlockManager &block_manager, - const shared_ptr &block_handle) - : PartialBlock(state, block_manager, block_handle) { -} - -void PartialBlockForIndex::Flush(const idx_t free_space_left) { - FlushInternal(free_space_left); - block_handle = block_manager.ConvertToPersistent(state.block_id, std::move(block_handle)); - Clear(); -} - -void PartialBlockForIndex::Merge(PartialBlock &other, idx_t offset, idx_t other_size) { - throw InternalException("no merge for PartialBlockForIndex"); -} - -void PartialBlockForIndex::Clear() { - block_handle.reset(); -} - -//===--------------------------------------------------------------------===// -// FixedSizeBuffer 
-//===--------------------------------------------------------------------===// - -constexpr idx_t FixedSizeBuffer::BASE[]; -constexpr uint8_t FixedSizeBuffer::SHIFT[]; - -FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager) - : block_manager(block_manager), segment_count(0), allocation_size(0), dirty(false), vacuum(false), block_pointer(), - block_handle(nullptr) { - - auto &buffer_manager = block_manager.buffer_manager; - buffer_handle = buffer_manager.Allocate(Storage::BLOCK_SIZE, false, &block_handle); -} - -FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager, const idx_t segment_count, const idx_t allocation_size, - const BlockPointer &block_pointer) - : block_manager(block_manager), segment_count(segment_count), allocation_size(allocation_size), dirty(false), - vacuum(false), block_pointer(block_pointer) { - - D_ASSERT(block_pointer.IsValid()); - block_handle = block_manager.RegisterBlock(block_pointer.block_id); - D_ASSERT(block_handle->BlockId() < MAXIMUM_BLOCK); -} - -void FixedSizeBuffer::Destroy() { - if (InMemory()) { - // we can have multiple readers on a pinned block, and unpinning the buffer handle - // decrements the reader count on the underlying block handle (Destroy() unpins) - buffer_handle.Destroy(); - } - if (OnDisk()) { - // marking a block as modified decreases the reference count of multi-use blocks - block_manager.MarkBlockAsModified(block_pointer.block_id); - } -} - -void FixedSizeBuffer::Serialize(PartialBlockManager &partial_block_manager, const idx_t available_segments, - const idx_t segment_size, const idx_t bitmask_offset) { - - // we do not serialize a block that is already on disk and not in memory - if (!InMemory()) { - if (!OnDisk() || dirty) { - throw InternalException("invalid or missing buffer in FixedSizeAllocator"); - } - return; - } - - // we do not serialize a block that is already on disk and not dirty - if (!dirty && OnDisk()) { - return; - } - - if (dirty) { - // the allocation possibly changed - auto 
max_offset = GetMaxOffset(available_segments); - allocation_size = max_offset * segment_size + bitmask_offset; - } - - // the buffer is in memory, so we copied it onto a new buffer when pinning - D_ASSERT(InMemory() && !OnDisk()); - - // now we write the changes, first get a partial block allocation - PartialBlockAllocation allocation = partial_block_manager.GetBlockAllocation(allocation_size); - block_pointer.block_id = allocation.state.block_id; - block_pointer.offset = allocation.state.offset; - - auto &buffer_manager = block_manager.buffer_manager; - - if (allocation.partial_block) { - // copy to an existing partial block - D_ASSERT(block_pointer.offset > 0); - auto &p_block_for_index = allocation.partial_block->Cast(); - auto dst_handle = buffer_manager.Pin(p_block_for_index.block_handle); - memcpy(dst_handle.Ptr() + block_pointer.offset, buffer_handle.Ptr(), allocation_size); - SetUninitializedRegions(p_block_for_index, segment_size, block_pointer.offset, bitmask_offset); - - } else { - // create a new block that can potentially be used as a partial block - D_ASSERT(block_handle); - D_ASSERT(!block_pointer.offset); - auto p_block_for_index = make_uniq(allocation.state, block_manager, block_handle); - SetUninitializedRegions(*p_block_for_index, segment_size, block_pointer.offset, bitmask_offset); - allocation.partial_block = std::move(p_block_for_index); - } - - partial_block_manager.RegisterPartialBlock(std::move(allocation)); - - // resetting this buffer - buffer_handle.Destroy(); - block_handle = block_manager.RegisterBlock(block_pointer.block_id); - D_ASSERT(block_handle->BlockId() < MAXIMUM_BLOCK); - - // we persist any changes, so the buffer is no longer dirty - dirty = false; -} - -void FixedSizeBuffer::Pin() { - - auto &buffer_manager = block_manager.buffer_manager; - D_ASSERT(block_pointer.IsValid()); - D_ASSERT(block_handle && block_handle->BlockId() < MAXIMUM_BLOCK); - D_ASSERT(!dirty); - - buffer_handle = buffer_manager.Pin(block_handle); - - // we 
need to copy the (partial) data into a new (not yet disk-backed) buffer handle - shared_ptr new_block_handle; - auto new_buffer_handle = buffer_manager.Allocate(Storage::BLOCK_SIZE, false, &new_block_handle); - - memcpy(new_buffer_handle.Ptr(), buffer_handle.Ptr() + block_pointer.offset, allocation_size); - - Destroy(); - buffer_handle = std::move(new_buffer_handle); - block_handle = new_block_handle; - block_pointer = BlockPointer(); -} - -uint32_t FixedSizeBuffer::GetOffset(const idx_t bitmask_count) { - - // get the bitmask data - auto bitmask_ptr = reinterpret_cast(Get()); - ValidityMask mask(bitmask_ptr); - auto data = mask.GetData(); - - // fills up a buffer sequentially before searching for free bits - if (mask.RowIsValid(segment_count)) { - mask.SetInvalid(segment_count); - return segment_count; - } - - for (idx_t entry_idx = 0; entry_idx < bitmask_count; entry_idx++) { - // get an entry with free bits - if (data[entry_idx] == 0) { - continue; - } - - // find the position of the free bit - auto entry = data[entry_idx]; - idx_t first_valid_bit = 0; - - // this loop finds the position of the rightmost set bit in entry and stores it - // in first_valid_bit - for (idx_t i = 0; i < 6; i++) { - // set the left half of the bits of this level to zero and test if the entry is still not zero - if (entry & BASE[i]) { - // first valid bit is in the rightmost s[i] bits - // permanently set the left half of the bits to zero - entry &= BASE[i]; - } else { - // first valid bit is in the leftmost s[i] bits - // shift by s[i] for the next iteration and add s[i] to the position of the rightmost set bit - entry >>= SHIFT[i]; - first_valid_bit += SHIFT[i]; - } - } - D_ASSERT(entry); - - auto prev_bits = entry_idx * sizeof(validity_t) * 8; - D_ASSERT(mask.RowIsValid(prev_bits + first_valid_bit)); - mask.SetInvalid(prev_bits + first_valid_bit); - return (prev_bits + first_valid_bit); - } - - throw InternalException("Invalid bitmask for FixedSizeAllocator"); -} - -uint32_t 
FixedSizeBuffer::GetMaxOffset(const idx_t available_segments) { - - // this function calls Get() on the buffer - D_ASSERT(InMemory()); - - // finds the maximum zero bit in a bitmask, and adds one to it, - // so that max_offset * segment_size = allocated_size of this bitmask's buffer - idx_t entry_size = sizeof(validity_t) * 8; - idx_t bitmask_count = available_segments / entry_size; - if (available_segments % entry_size != 0) { - bitmask_count++; - } - uint32_t max_offset = bitmask_count * sizeof(validity_t) * 8; - auto bits_in_last_entry = available_segments % (sizeof(validity_t) * 8); - - // get the bitmask data - auto bitmask_ptr = reinterpret_cast(Get()); - const ValidityMask mask(bitmask_ptr); - const auto data = mask.GetData(); - - D_ASSERT(bitmask_count > 0); - for (idx_t i = bitmask_count; i > 0; i--) { - - auto entry = data[i - 1]; - - // set all bits after bits_in_last_entry - if (i == bitmask_count) { - entry |= ~idx_t(0) << bits_in_last_entry; - } - - if (entry == ~idx_t(0)) { - max_offset -= sizeof(validity_t) * 8; - continue; - } - - // invert data[entry_idx] - auto entry_inv = ~entry; - idx_t first_valid_bit = 0; - - // then find the position of the LEFTMOST set bit - for (idx_t level = 0; level < 6; level++) { - - // set the right half of the bits of this level to zero and test if the entry is still not zero - if (entry_inv & ~BASE[level]) { - // first valid bit is in the leftmost s[level] bits - // shift by s[level] for the next iteration and add s[level] to the position of the leftmost set bit - entry_inv >>= SHIFT[level]; - first_valid_bit += SHIFT[level]; - } else { - // first valid bit is in the rightmost s[level] bits - // permanently set the left half of the bits to zero - entry_inv &= BASE[level]; - } - } - D_ASSERT(entry_inv); - max_offset -= sizeof(validity_t) * 8 - first_valid_bit; - D_ASSERT(!mask.RowIsValid(max_offset)); - return max_offset + 1; - } - - // there are no allocations in this buffer - throw InternalException("tried to 
serialize empty buffer"); -} - -void FixedSizeBuffer::SetUninitializedRegions(PartialBlockForIndex &p_block_for_index, const idx_t segment_size, - const idx_t offset, const idx_t bitmask_offset) { - - // this function calls Get() on the buffer - D_ASSERT(InMemory()); - - auto bitmask_ptr = reinterpret_cast(Get()); - ValidityMask mask(bitmask_ptr); - - idx_t i = 0; - idx_t max_offset = offset + allocation_size; - idx_t current_offset = offset + bitmask_offset; - while (current_offset < max_offset) { - - if (mask.RowIsValid(i)) { - D_ASSERT(current_offset + segment_size <= max_offset); - p_block_for_index.AddUninitializedRegion(current_offset, current_offset + segment_size); - } - current_offset += segment_size; - i++; - } -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -using ValidityBytes = JoinHashTable::ValidityBytes; -using ScanStructure = JoinHashTable::ScanStructure; -using ProbeSpill = JoinHashTable::ProbeSpill; -using ProbeSpillLocalState = JoinHashTable::ProbeSpillLocalAppendState; - -JoinHashTable::JoinHashTable(BufferManager &buffer_manager_p, const vector &conditions_p, - vector btypes, JoinType type_p) - : buffer_manager(buffer_manager_p), conditions(conditions_p), build_types(std::move(btypes)), entry_size(0), - tuple_size(0), vfound(Value::BOOLEAN(false)), join_type(type_p), finalized(false), has_null(false), - external(false), radix_bits(4), partition_start(0), partition_end(0) { - - for (auto &condition : conditions) { - D_ASSERT(condition.left->return_type == condition.right->return_type); - auto type = condition.left->return_type; - if (condition.comparison == ExpressionType::COMPARE_EQUAL || - condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { - - // ensure that all equality conditions are at the front, - // and that all other conditions are at the back - D_ASSERT(equality_types.size() == condition_types.size()); - equality_types.push_back(type); - } - - predicates.push_back(condition.comparison); - 
null_values_are_equal.push_back(condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM || - condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM); - - condition_types.push_back(type); - } - // at least one equality is necessary - D_ASSERT(!equality_types.empty()); - - // Types for the layout - vector layout_types(condition_types); - layout_types.insert(layout_types.end(), build_types.begin(), build_types.end()); - if (IsRightOuterJoin(join_type)) { - // full/right outer joins need an extra bool to keep track of whether or not a tuple has found a matching entry - // we place the bool before the NEXT pointer - layout_types.emplace_back(LogicalType::BOOLEAN); - } - layout_types.emplace_back(LogicalType::HASH); - layout.Initialize(layout_types, false); - row_matcher.Initialize(false, layout, predicates); - row_matcher_no_match_sel.Initialize(true, layout, predicates); - - const auto &offsets = layout.GetOffsets(); - tuple_size = offsets[condition_types.size() + build_types.size()]; - pointer_offset = offsets.back(); - entry_size = layout.GetRowWidth(); - - data_collection = make_uniq(buffer_manager, layout); - sink_collection = - make_uniq(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1); -} - -JoinHashTable::~JoinHashTable() { -} - -void JoinHashTable::Merge(JoinHashTable &other) { - { - lock_guard guard(data_lock); - data_collection->Combine(*other.data_collection); - } - - if (join_type == JoinType::MARK) { - auto &info = correlated_mark_join_info; - lock_guard mj_lock(info.mj_lock); - has_null = has_null || other.has_null; - if (!info.correlated_types.empty()) { - auto &other_info = other.correlated_mark_join_info; - info.correlated_counts->Combine(*other_info.correlated_counts); - } - } - - sink_collection->Combine(*other.sink_collection); -} - -void JoinHashTable::ApplyBitmask(Vector &hashes, idx_t count) { - if (hashes.GetVectorType() == VectorType::CONSTANT_VECTOR) { - D_ASSERT(!ConstantVector::IsNull(hashes)); - auto indices = 
ConstantVector::GetData(hashes); - *indices = *indices & bitmask; - } else { - hashes.Flatten(count); - auto indices = FlatVector::GetData(hashes); - for (idx_t i = 0; i < count; i++) { - indices[i] &= bitmask; - } - } -} - -void JoinHashTable::ApplyBitmask(Vector &hashes, const SelectionVector &sel, idx_t count, Vector &pointers) { - UnifiedVectorFormat hdata; - hashes.ToUnifiedFormat(count, hdata); - - auto hash_data = UnifiedVectorFormat::GetData(hdata); - auto result_data = FlatVector::GetData(pointers); - auto main_ht = reinterpret_cast(hash_map.get()); - for (idx_t i = 0; i < count; i++) { - auto rindex = sel.get_index(i); - auto hindex = hdata.sel->get_index(rindex); - auto hash = hash_data[hindex]; - result_data[rindex] = main_ht + (hash & bitmask); - } -} - -void JoinHashTable::Hash(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes) { - if (count == keys.size()) { - // no null values are filtered: use regular hash functions - VectorOperations::Hash(keys.data[0], hashes, keys.size()); - for (idx_t i = 1; i < equality_types.size(); i++) { - VectorOperations::CombineHash(hashes, keys.data[i], keys.size()); - } - } else { - // null values were filtered: use selection vector - VectorOperations::Hash(keys.data[0], hashes, sel, count); - for (idx_t i = 1; i < equality_types.size(); i++) { - VectorOperations::CombineHash(hashes, keys.data[i], sel, count); - } - } -} - -static idx_t FilterNullValues(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t count, - SelectionVector &result) { - idx_t result_count = 0; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto key_idx = vdata.sel->get_index(idx); - if (vdata.validity.RowIsValid(key_idx)) { - result.set_index(result_count++, idx); - } - } - return result_count; -} - -void JoinHashTable::Build(PartitionedTupleDataAppendState &append_state, DataChunk &keys, DataChunk &payload) { - D_ASSERT(!finalized); - D_ASSERT(keys.size() == payload.size()); - if 
(keys.size() == 0) { - return; - } - // special case: correlated mark join - if (join_type == JoinType::MARK && !correlated_mark_join_info.correlated_types.empty()) { - auto &info = correlated_mark_join_info; - lock_guard mj_lock(info.mj_lock); - // Correlated MARK join - // for the correlated mark join we need to keep track of COUNT(*) and COUNT(COLUMN) for each of the correlated - // columns push into the aggregate hash table - D_ASSERT(info.correlated_counts); - info.group_chunk.SetCardinality(keys); - for (idx_t i = 0; i < info.correlated_types.size(); i++) { - info.group_chunk.data[i].Reference(keys.data[i]); - } - if (info.correlated_payload.data.empty()) { - vector types; - types.push_back(keys.data[info.correlated_types.size()].GetType()); - info.correlated_payload.InitializeEmpty(types); - } - info.correlated_payload.SetCardinality(keys); - info.correlated_payload.data[0].Reference(keys.data[info.correlated_types.size()]); - info.correlated_counts->AddChunk(info.group_chunk, info.correlated_payload, AggregateType::NON_DISTINCT); - } - - // build a chunk to append to the data collection [keys, payload, (optional "found" boolean), hash] - DataChunk source_chunk; - source_chunk.InitializeEmpty(layout.GetTypes()); - for (idx_t i = 0; i < keys.ColumnCount(); i++) { - source_chunk.data[i].Reference(keys.data[i]); - } - idx_t col_offset = keys.ColumnCount(); - D_ASSERT(build_types.size() == payload.ColumnCount()); - for (idx_t i = 0; i < payload.ColumnCount(); i++) { - source_chunk.data[col_offset + i].Reference(payload.data[i]); - } - col_offset += payload.ColumnCount(); - if (IsRightOuterJoin(join_type)) { - // for FULL/RIGHT OUTER joins initialize the "found" boolean to false - source_chunk.data[col_offset].Reference(vfound); - col_offset++; - } - Vector hash_values(LogicalType::HASH); - source_chunk.data[col_offset].Reference(hash_values); - source_chunk.SetCardinality(keys); - - // ToUnifiedFormat the source chunk - 
TupleDataCollection::ToUnifiedFormat(append_state.chunk_state, source_chunk); - - // prepare the keys for processing - const SelectionVector *current_sel; - SelectionVector sel(STANDARD_VECTOR_SIZE); - idx_t added_count = PrepareKeys(keys, append_state.chunk_state.vector_data, current_sel, sel, true); - if (added_count < keys.size()) { - has_null = true; - } - if (added_count == 0) { - return; - } - - // hash the keys and obtain an entry in the list - // note that we only hash the keys used in the equality comparison - Hash(keys, *current_sel, added_count, hash_values); - - // Re-reference and ToUnifiedFormat the hash column after computing it - source_chunk.data[col_offset].Reference(hash_values); - hash_values.ToUnifiedFormat(source_chunk.size(), append_state.chunk_state.vector_data.back().unified); - - // We already called TupleDataCollection::ToUnifiedFormat, so we can AppendUnified here - sink_collection->AppendUnified(append_state, source_chunk, *current_sel, added_count); -} - -idx_t JoinHashTable::PrepareKeys(DataChunk &keys, vector &vector_data, - const SelectionVector *¤t_sel, SelectionVector &sel, bool build_side) { - // figure out which keys are NULL, and create a selection vector out of them - current_sel = FlatVector::IncrementalSelectionVector(); - idx_t added_count = keys.size(); - if (build_side && IsRightOuterJoin(join_type)) { - // in case of a right or full outer join, we cannot remove NULL keys from the build side - return added_count; - } - - for (idx_t col_idx = 0; col_idx < keys.ColumnCount(); col_idx++) { - if (!null_values_are_equal[col_idx]) { - auto &col_key_data = vector_data[col_idx].unified; - if (col_key_data.validity.AllValid()) { - continue; - } - added_count = FilterNullValues(col_key_data, *current_sel, added_count, sel); - // null values are NOT equal for this column, filter them out - current_sel = &sel; - } - } - return added_count; -} - -template -static inline void InsertHashesLoop(atomic pointers[], const hash_t indices[], 
const idx_t count, - const data_ptr_t key_locations[], const idx_t pointer_offset) { - for (idx_t i = 0; i < count; i++) { - const auto index = indices[i]; - if (PARALLEL) { - data_ptr_t head; - do { - head = pointers[index]; - Store(head, key_locations[i] + pointer_offset); - } while (!std::atomic_compare_exchange_weak(&pointers[index], &head, key_locations[i])); - } else { - // set prev in current key to the value (NOTE: this will be nullptr if there is none) - Store(pointers[index], key_locations[i] + pointer_offset); - - // set pointer to current tuple - pointers[index] = key_locations[i]; - } - } -} - -void JoinHashTable::InsertHashes(Vector &hashes, idx_t count, data_ptr_t key_locations[], bool parallel) { - D_ASSERT(hashes.GetType().id() == LogicalType::HASH); - - // use bitmask to get position in array - ApplyBitmask(hashes, count); - - hashes.Flatten(count); - D_ASSERT(hashes.GetVectorType() == VectorType::FLAT_VECTOR); - - auto pointers = reinterpret_cast *>(hash_map.get()); - auto indices = FlatVector::GetData(hashes); - - if (parallel) { - InsertHashesLoop(pointers, indices, count, key_locations, pointer_offset); - } else { - InsertHashesLoop(pointers, indices, count, key_locations, pointer_offset); - } -} - -void JoinHashTable::InitializePointerTable() { - idx_t capacity = PointerTableCapacity(Count()); - D_ASSERT(IsPowerOfTwo(capacity)); - - if (hash_map.get()) { - // There is already a hash map - auto current_capacity = hash_map.GetSize() / sizeof(data_ptr_t); - if (capacity > current_capacity) { - // Need more space - hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(data_ptr_t)); - } else { - // Just use the current hash map - capacity = current_capacity; - } - } else { - // Allocate a hash map - hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(data_ptr_t)); - } - D_ASSERT(hash_map.GetSize() == capacity * sizeof(data_ptr_t)); - - // initialize HT with all-zero entries - 
std::fill_n(reinterpret_cast(hash_map.get()), capacity, nullptr); - - bitmask = capacity - 1; -} - -void JoinHashTable::Finalize(idx_t chunk_idx_from, idx_t chunk_idx_to, bool parallel) { - // Pointer table should be allocated - D_ASSERT(hash_map.get()); - - Vector hashes(LogicalType::HASH); - auto hash_data = FlatVector::GetData(hashes); - - TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::KEEP_EVERYTHING_PINNED, chunk_idx_from, - chunk_idx_to, false); - const auto row_locations = iterator.GetRowLocations(); - do { - const auto count = iterator.GetCurrentChunkCount(); - for (idx_t i = 0; i < count; i++) { - hash_data[i] = Load(row_locations[i] + pointer_offset); - } - InsertHashes(hashes, count, row_locations, parallel); - } while (iterator.Next()); -} - -unique_ptr JoinHashTable::InitializeScanStructure(DataChunk &keys, TupleDataChunkState &key_state, - const SelectionVector *¤t_sel) { - D_ASSERT(Count() > 0); // should be handled before - D_ASSERT(finalized); - - // set up the scan structure - auto ss = make_uniq(*this, key_state); - - if (join_type != JoinType::INNER) { - ss->found_match = make_unsafe_uniq_array(STANDARD_VECTOR_SIZE); - memset(ss->found_match.get(), 0, sizeof(bool) * STANDARD_VECTOR_SIZE); - } - - // first prepare the keys for probing - TupleDataCollection::ToUnifiedFormat(key_state, keys); - ss->count = PrepareKeys(keys, key_state.vector_data, current_sel, ss->sel_vector, false); - return ss; -} - -unique_ptr JoinHashTable::Probe(DataChunk &keys, TupleDataChunkState &key_state, - Vector *precomputed_hashes) { - const SelectionVector *current_sel; - auto ss = InitializeScanStructure(keys, key_state, current_sel); - if (ss->count == 0) { - return ss; - } - - if (precomputed_hashes) { - ApplyBitmask(*precomputed_hashes, *current_sel, ss->count, ss->pointers); - } else { - // hash all the keys - Vector hashes(LogicalType::HASH); - Hash(keys, *current_sel, ss->count, hashes); - - // now initialize the pointers of the scan 
structure based on the hashes - ApplyBitmask(hashes, *current_sel, ss->count, ss->pointers); - } - - // create the selection vector linking to only non-empty entries - ss->InitializeSelectionVector(current_sel); - - return ss; -} - -ScanStructure::ScanStructure(JoinHashTable &ht_p, TupleDataChunkState &key_state_p) - : key_state(key_state_p), pointers(LogicalType::POINTER), sel_vector(STANDARD_VECTOR_SIZE), ht(ht_p), - finished(false) { -} - -void ScanStructure::Next(DataChunk &keys, DataChunk &left, DataChunk &result) { - if (finished) { - return; - } - switch (ht.join_type) { - case JoinType::INNER: - case JoinType::RIGHT: - NextInnerJoin(keys, left, result); - break; - case JoinType::SEMI: - NextSemiJoin(keys, left, result); - break; - case JoinType::MARK: - NextMarkJoin(keys, left, result); - break; - case JoinType::ANTI: - NextAntiJoin(keys, left, result); - break; - case JoinType::OUTER: - case JoinType::LEFT: - NextLeftJoin(keys, left, result); - break; - case JoinType::SINGLE: - NextSingleJoin(keys, left, result); - break; - default: - throw InternalException("Unhandled join type in JoinHashTable"); - } -} - -idx_t ScanStructure::ResolvePredicates(DataChunk &keys, SelectionVector &match_sel, SelectionVector *no_match_sel) { - // Start with the scan selection - for (idx_t i = 0; i < this->count; ++i) { - match_sel.set_index(i, this->sel_vector.get_index(i)); - } - idx_t no_match_count = 0; - - auto &matcher = no_match_sel ? 
ht.row_matcher_no_match_sel : ht.row_matcher; - return matcher.Match(keys, key_state.vector_data, match_sel, this->count, ht.layout, pointers, no_match_sel, - no_match_count); -} - -idx_t ScanStructure::ScanInnerJoin(DataChunk &keys, SelectionVector &result_vector) { - while (true) { - // resolve the predicates for this set of keys - idx_t result_count = ResolvePredicates(keys, result_vector, nullptr); - - // after doing all the comparisons set the found_match vector - if (found_match) { - for (idx_t i = 0; i < result_count; i++) { - auto idx = result_vector.get_index(i); - found_match[idx] = true; - } - } - if (result_count > 0) { - return result_count; - } - // no matches found: check the next set of pointers - AdvancePointers(); - if (this->count == 0) { - return 0; - } - } -} - -void ScanStructure::AdvancePointers(const SelectionVector &sel, idx_t sel_count) { - // now for all the pointers, we move on to the next set of pointers - idx_t new_count = 0; - auto ptrs = FlatVector::GetData(this->pointers); - for (idx_t i = 0; i < sel_count; i++) { - auto idx = sel.get_index(i); - ptrs[idx] = Load(ptrs[idx] + ht.pointer_offset); - if (ptrs[idx]) { - this->sel_vector.set_index(new_count++, idx); - } - } - this->count = new_count; -} - -void ScanStructure::InitializeSelectionVector(const SelectionVector *¤t_sel) { - idx_t non_empty_count = 0; - auto ptrs = FlatVector::GetData(pointers); - auto cnt = count; - for (idx_t i = 0; i < cnt; i++) { - const auto idx = current_sel->get_index(i); - ptrs[idx] = Load(ptrs[idx]); - if (ptrs[idx]) { - sel_vector.set_index(non_empty_count++, idx); - } - } - count = non_empty_count; -} - -void ScanStructure::AdvancePointers() { - AdvancePointers(this->sel_vector, this->count); -} - -void ScanStructure::GatherResult(Vector &result, const SelectionVector &result_vector, - const SelectionVector &sel_vector, const idx_t count, const idx_t col_no) { - ht.data_collection->Gather(pointers, sel_vector, count, col_no, result, result_vector); 
-} - -void ScanStructure::GatherResult(Vector &result, const SelectionVector &sel_vector, const idx_t count, - const idx_t col_idx) { - GatherResult(result, *FlatVector::IncrementalSelectionVector(), sel_vector, count, col_idx); -} - -void ScanStructure::NextInnerJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { - D_ASSERT(result.ColumnCount() == left.ColumnCount() + ht.build_types.size()); - if (this->count == 0) { - // no pointers left to chase - return; - } - - SelectionVector result_vector(STANDARD_VECTOR_SIZE); - - idx_t result_count = ScanInnerJoin(keys, result_vector); - if (result_count > 0) { - if (IsRightOuterJoin(ht.join_type)) { - // full/right outer join: mark join matches as FOUND in the HT - auto ptrs = FlatVector::GetData(pointers); - for (idx_t i = 0; i < result_count; i++) { - auto idx = result_vector.get_index(i); - // NOTE: threadsan reports this as a data race because this can be set concurrently by separate threads - // Technically it is, but it does not matter, since the only value that can be written is "true" - Store(true, ptrs[idx] + ht.tuple_size); - } - } - // matches were found - // construct the result - // on the LHS, we create a slice using the result vector - result.Slice(left, result_vector, result_count); - - // on the RHS, we need to fetch the data from the hash table - for (idx_t i = 0; i < ht.build_types.size(); i++) { - auto &vector = result.data[left.ColumnCount() + i]; - D_ASSERT(vector.GetType() == ht.build_types[i]); - GatherResult(vector, result_vector, result_count, i + ht.condition_types.size()); - } - AdvancePointers(); - } -} - -void ScanStructure::ScanKeyMatches(DataChunk &keys) { - // the semi-join, anti-join and mark-join we handle a differently from the inner join - // since there can be at most STANDARD_VECTOR_SIZE results - // we handle the entire chunk in one call to Next(). - // for every pointer, we keep chasing pointers and doing comparisons. 
- // this results in a boolean array indicating whether or not the tuple has a match - SelectionVector match_sel(STANDARD_VECTOR_SIZE), no_match_sel(STANDARD_VECTOR_SIZE); - while (this->count > 0) { - // resolve the predicates for the current set of pointers - idx_t match_count = ResolvePredicates(keys, match_sel, &no_match_sel); - idx_t no_match_count = this->count - match_count; - - // mark each of the matches as found - for (idx_t i = 0; i < match_count; i++) { - found_match[match_sel.get_index(i)] = true; - } - // continue searching for the ones where we did not find a match yet - AdvancePointers(no_match_sel, no_match_count); - } -} - -template -void ScanStructure::NextSemiOrAntiJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { - D_ASSERT(left.ColumnCount() == result.ColumnCount()); - D_ASSERT(keys.size() == left.size()); - // create the selection vector from the matches that were found - SelectionVector sel(STANDARD_VECTOR_SIZE); - idx_t result_count = 0; - for (idx_t i = 0; i < keys.size(); i++) { - if (found_match[i] == MATCH) { - // part of the result - sel.set_index(result_count++, i); - } - } - // construct the final result - if (result_count > 0) { - // we only return the columns on the left side - // reference the columns of the left side from the result - result.Slice(left, sel, result_count); - } else { - D_ASSERT(result.size() == 0); - } -} - -void ScanStructure::NextSemiJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { - // first scan for key matches - ScanKeyMatches(keys); - // then construct the result from all tuples with a match - NextSemiOrAntiJoin(keys, left, result); - - finished = true; -} - -void ScanStructure::NextAntiJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { - // first scan for key matches - ScanKeyMatches(keys); - // then construct the result from all tuples that did not find a match - NextSemiOrAntiJoin(keys, left, result); - - finished = true; -} - -void 
ScanStructure::ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &child, DataChunk &result) { - // for the initial set of columns we just reference the left side - result.SetCardinality(child); - for (idx_t i = 0; i < child.ColumnCount(); i++) { - result.data[i].Reference(child.data[i]); - } - auto &mark_vector = result.data.back(); - mark_vector.SetVectorType(VectorType::FLAT_VECTOR); - // first we set the NULL values from the join keys - // if there is any NULL in the keys, the result is NULL - auto bool_result = FlatVector::GetData(mark_vector); - auto &mask = FlatVector::Validity(mark_vector); - for (idx_t col_idx = 0; col_idx < join_keys.ColumnCount(); col_idx++) { - if (ht.null_values_are_equal[col_idx]) { - continue; - } - UnifiedVectorFormat jdata; - join_keys.data[col_idx].ToUnifiedFormat(join_keys.size(), jdata); - if (!jdata.validity.AllValid()) { - for (idx_t i = 0; i < join_keys.size(); i++) { - auto jidx = jdata.sel->get_index(i); - mask.Set(i, jdata.validity.RowIsValidUnsafe(jidx)); - } - } - } - // now set the remaining entries to either true or false based on whether a match was found - if (found_match) { - for (idx_t i = 0; i < child.size(); i++) { - bool_result[i] = found_match[i]; - } - } else { - memset(bool_result, 0, sizeof(bool) * child.size()); - } - // if the right side contains NULL values, the result of any FALSE becomes NULL - if (ht.has_null) { - for (idx_t i = 0; i < child.size(); i++) { - if (!bool_result[i]) { - mask.SetInvalid(i); - } - } - } -} - -void ScanStructure::NextMarkJoin(DataChunk &keys, DataChunk &input, DataChunk &result) { - D_ASSERT(result.ColumnCount() == input.ColumnCount() + 1); - D_ASSERT(result.data.back().GetType() == LogicalType::BOOLEAN); - // this method should only be called for a non-empty HT - D_ASSERT(ht.Count() > 0); - - ScanKeyMatches(keys); - if (ht.correlated_mark_join_info.correlated_types.empty()) { - ConstructMarkJoinResult(keys, input, result); - } else { - auto &info = 
ht.correlated_mark_join_info; - lock_guard mj_lock(info.mj_lock); - - // there are correlated columns - // first we fetch the counts from the aggregate hashtable corresponding to these entries - D_ASSERT(keys.ColumnCount() == info.group_chunk.ColumnCount() + 1); - info.group_chunk.SetCardinality(keys); - for (idx_t i = 0; i < info.group_chunk.ColumnCount(); i++) { - info.group_chunk.data[i].Reference(keys.data[i]); - } - info.correlated_counts->FetchAggregates(info.group_chunk, info.result_chunk); - - // for the initial set of columns we just reference the left side - result.SetCardinality(input); - for (idx_t i = 0; i < input.ColumnCount(); i++) { - result.data[i].Reference(input.data[i]); - } - // create the result matching vector - auto &last_key = keys.data.back(); - auto &result_vector = result.data.back(); - // first set the nullmask based on whether or not there were NULL values in the join key - result_vector.SetVectorType(VectorType::FLAT_VECTOR); - auto bool_result = FlatVector::GetData(result_vector); - auto &mask = FlatVector::Validity(result_vector); - switch (last_key.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: - if (ConstantVector::IsNull(last_key)) { - mask.SetAllInvalid(input.size()); - } - break; - case VectorType::FLAT_VECTOR: - mask.Copy(FlatVector::Validity(last_key), input.size()); - break; - default: { - UnifiedVectorFormat kdata; - last_key.ToUnifiedFormat(keys.size(), kdata); - for (idx_t i = 0; i < input.size(); i++) { - auto kidx = kdata.sel->get_index(i); - mask.Set(i, kdata.validity.RowIsValid(kidx)); - } - break; - } - } - - auto count_star = FlatVector::GetData(info.result_chunk.data[0]); - auto count = FlatVector::GetData(info.result_chunk.data[1]); - // set the entries to either true or false based on whether a match was found - for (idx_t i = 0; i < input.size(); i++) { - D_ASSERT(count_star[i] >= count[i]); - bool_result[i] = found_match ? 
found_match[i] : false; - if (!bool_result[i] && count_star[i] > count[i]) { - // RHS has NULL value and result is false: set to null - mask.SetInvalid(i); - } - if (count_star[i] == 0) { - // count == 0, set nullmask to false (we know the result is false now) - mask.SetValid(i); - } - } - } - finished = true; -} - -void ScanStructure::NextLeftJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { - // a LEFT OUTER JOIN is identical to an INNER JOIN except all tuples that do - // not have a match must return at least one tuple (with the right side set - // to NULL in every column) - NextInnerJoin(keys, left, result); - if (result.size() == 0) { - // no entries left from the normal join - // fill in the result of the remaining left tuples - // together with NULL values on the right-hand side - idx_t remaining_count = 0; - SelectionVector sel(STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < left.size(); i++) { - if (!found_match[i]) { - sel.set_index(remaining_count++, i); - } - } - if (remaining_count > 0) { - // have remaining tuples - // slice the left side with tuples that did not find a match - result.Slice(left, sel, remaining_count); - - // now set the right side to NULL - for (idx_t i = left.ColumnCount(); i < result.ColumnCount(); i++) { - Vector &vec = result.data[i]; - vec.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(vec, true); - } - } - finished = true; - } -} - -void ScanStructure::NextSingleJoin(DataChunk &keys, DataChunk &input, DataChunk &result) { - // single join - // this join is similar to the semi join except that - // (1) we actually return data from the RHS and - // (2) we return NULL for that data if there is no match - idx_t result_count = 0; - SelectionVector result_sel(STANDARD_VECTOR_SIZE); - SelectionVector match_sel(STANDARD_VECTOR_SIZE), no_match_sel(STANDARD_VECTOR_SIZE); - while (this->count > 0) { - // resolve the predicates for the current set of pointers - idx_t match_count = ResolvePredicates(keys, 
match_sel, &no_match_sel); - idx_t no_match_count = this->count - match_count; - - // mark each of the matches as found - for (idx_t i = 0; i < match_count; i++) { - // found a match for this index - auto index = match_sel.get_index(i); - found_match[index] = true; - result_sel.set_index(result_count++, index); - } - // continue searching for the ones where we did not find a match yet - AdvancePointers(no_match_sel, no_match_count); - } - // reference the columns of the left side from the result - D_ASSERT(input.ColumnCount() > 0); - for (idx_t i = 0; i < input.ColumnCount(); i++) { - result.data[i].Reference(input.data[i]); - } - // now fetch the data from the RHS - for (idx_t i = 0; i < ht.build_types.size(); i++) { - auto &vector = result.data[input.ColumnCount() + i]; - // set NULL entries for every entry that was not found - for (idx_t j = 0; j < input.size(); j++) { - if (!found_match[j]) { - FlatVector::SetNull(vector, j, true); - } - } - // for the remaining values we fetch the values - GatherResult(vector, result_sel, result_sel, result_count, i + ht.condition_types.size()); - } - result.SetCardinality(input.size()); - - // like the SEMI, ANTI and MARK join types, the SINGLE join only ever does one pass over the HT per input chunk - finished = true; -} - -void JoinHashTable::ScanFullOuter(JoinHTScanState &state, Vector &addresses, DataChunk &result) { - // scan the HT starting from the current position and check which rows from the build side did not find a match - auto key_locations = FlatVector::GetData(addresses); - idx_t found_entries = 0; - - auto &iterator = state.iterator; - if (iterator.Done()) { - return; - } - - const auto row_locations = iterator.GetRowLocations(); - do { - const auto count = iterator.GetCurrentChunkCount(); - for (idx_t i = state.offset_in_chunk; i < count; i++) { - auto found_match = Load(row_locations[i] + tuple_size); - if (!found_match) { - key_locations[found_entries++] = row_locations[i]; - if (found_entries == 
STANDARD_VECTOR_SIZE) { - state.offset_in_chunk = i + 1; - break; - } - } - } - if (found_entries == STANDARD_VECTOR_SIZE) { - break; - } - state.offset_in_chunk = 0; - } while (iterator.Next()); - - // now gather from the found rows - if (found_entries == 0) { - return; - } - result.SetCardinality(found_entries); - idx_t left_column_count = result.ColumnCount() - build_types.size(); - const auto &sel_vector = *FlatVector::IncrementalSelectionVector(); - // set the left side as a constant NULL - for (idx_t i = 0; i < left_column_count; i++) { - Vector &vec = result.data[i]; - vec.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(vec, true); - } - - // gather the values from the RHS - for (idx_t i = 0; i < build_types.size(); i++) { - auto &vector = result.data[left_column_count + i]; - D_ASSERT(vector.GetType() == build_types[i]); - const auto col_no = condition_types.size() + i; - data_collection->Gather(addresses, sel_vector, found_entries, col_no, vector, sel_vector); - } -} - -idx_t JoinHashTable::FillWithHTOffsets(JoinHTScanState &state, Vector &addresses) { - // iterate over HT - auto key_locations = FlatVector::GetData(addresses); - idx_t key_count = 0; - - auto &iterator = state.iterator; - const auto row_locations = iterator.GetRowLocations(); - do { - const auto count = iterator.GetCurrentChunkCount(); - for (idx_t i = 0; i < count; i++) { - key_locations[key_count + i] = row_locations[i]; - } - key_count += count; - } while (iterator.Next()); - - return key_count; -} - -bool JoinHashTable::RequiresExternalJoin(ClientConfig &config, vector> &local_hts) { - total_count = 0; - idx_t data_size = 0; - for (auto &ht : local_hts) { - auto &local_sink_collection = ht->GetSinkCollection(); - total_count += local_sink_collection.Count(); - data_size += local_sink_collection.SizeInBytes(); - } - - if (total_count == 0) { - return false; - } - - if (config.force_external) { - // Do 1 round per partition if forcing external join to test all code 
paths - const auto r = RadixPartitioning::NumberOfPartitions(radix_bits); - auto data_size_per_round = (data_size + r - 1) / r; - auto count_per_round = (total_count + r - 1) / r; - max_ht_size = data_size_per_round + PointerTableSize(count_per_round); - external = true; - } else { - auto ht_size = data_size + PointerTableSize(total_count); - external = ht_size > max_ht_size; - } - return external; -} - -void JoinHashTable::Unpartition() { - for (auto &partition : sink_collection->GetPartitions()) { - data_collection->Combine(*partition); - } -} - -bool JoinHashTable::RequiresPartitioning(ClientConfig &config, vector> &local_hts) { - D_ASSERT(total_count != 0); - D_ASSERT(external); - - idx_t num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - vector partition_counts(num_partitions, 0); - vector partition_sizes(num_partitions, 0); - for (auto &ht : local_hts) { - const auto &local_partitions = ht->GetSinkCollection().GetPartitions(); - for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) { - auto &local_partition = local_partitions[partition_idx]; - partition_counts[partition_idx] += local_partition->Count(); - partition_sizes[partition_idx] += local_partition->SizeInBytes(); - } - } - - // Figure out if we can fit all single partitions in memory - idx_t max_partition_idx = 0; - idx_t max_partition_size = 0; - for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) { - const auto &partition_count = partition_counts[partition_idx]; - const auto &partition_size = partition_sizes[partition_idx]; - auto partition_ht_size = partition_size + PointerTableSize(partition_count); - if (partition_ht_size > max_partition_size) { - max_partition_size = partition_ht_size; - max_partition_idx = partition_idx; - } - } - - if (config.force_external || max_partition_size > max_ht_size) { - const auto partition_count = partition_counts[max_partition_idx]; - const auto partition_size = 
partition_sizes[max_partition_idx]; - - const auto max_added_bits = RadixPartitioning::MAX_RADIX_BITS - radix_bits; - idx_t added_bits = config.force_external ? 2 : 1; - for (; added_bits < max_added_bits; added_bits++) { - double partition_multiplier = RadixPartitioning::NumberOfPartitions(added_bits); - - auto new_estimated_count = double(partition_count) / partition_multiplier; - auto new_estimated_size = double(partition_size) / partition_multiplier; - auto new_estimated_ht_size = new_estimated_size + PointerTableSize(new_estimated_count); - - if (config.force_external || new_estimated_ht_size <= double(max_ht_size) / 4) { - // Aim for an estimated partition size of max_ht_size / 4 - break; - } - } - radix_bits += added_bits; - sink_collection = - make_uniq(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1); - return true; - } else { - return false; - } -} - -void JoinHashTable::Partition(JoinHashTable &global_ht) { - auto new_sink_collection = - make_uniq(buffer_manager, layout, global_ht.radix_bits, layout.ColumnCount() - 1); - sink_collection->Repartition(*new_sink_collection); - sink_collection = std::move(new_sink_collection); - global_ht.Merge(*this); -} - -void JoinHashTable::Reset() { - data_collection->Reset(); - finalized = false; -} - -bool JoinHashTable::PrepareExternalFinalize() { - if (finalized) { - Reset(); - } - - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - if (partition_end == num_partitions) { - return false; - } - - // Start where we left off - auto &partitions = sink_collection->GetPartitions(); - partition_start = partition_end; - - // Determine how many partitions we can do next (at least one) - idx_t count = 0; - idx_t data_size = 0; - idx_t partition_idx; - for (partition_idx = partition_start; partition_idx < num_partitions; partition_idx++) { - auto incl_count = count + partitions[partition_idx]->Count(); - auto incl_data_size = data_size + partitions[partition_idx]->SizeInBytes(); - 
auto incl_ht_size = incl_data_size + PointerTableSize(incl_count); - if (count > 0 && incl_ht_size > max_ht_size) { - break; - } - count = incl_count; - data_size = incl_data_size; - } - partition_end = partition_idx; - - // Move the partitions to the main data collection - for (partition_idx = partition_start; partition_idx < partition_end; partition_idx++) { - data_collection->Combine(*partitions[partition_idx]); - } - D_ASSERT(Count() == count); - - return true; -} - -static void CreateSpillChunk(DataChunk &spill_chunk, DataChunk &keys, DataChunk &payload, Vector &hashes) { - spill_chunk.Reset(); - idx_t spill_col_idx = 0; - for (idx_t col_idx = 0; col_idx < keys.ColumnCount(); col_idx++) { - spill_chunk.data[col_idx].Reference(keys.data[col_idx]); - } - spill_col_idx += keys.ColumnCount(); - for (idx_t col_idx = 0; col_idx < payload.data.size(); col_idx++) { - spill_chunk.data[spill_col_idx + col_idx].Reference(payload.data[col_idx]); - } - spill_col_idx += payload.ColumnCount(); - spill_chunk.data[spill_col_idx].Reference(hashes); -} - -unique_ptr JoinHashTable::ProbeAndSpill(DataChunk &keys, TupleDataChunkState &key_state, - DataChunk &payload, ProbeSpill &probe_spill, - ProbeSpillLocalAppendState &spill_state, - DataChunk &spill_chunk) { - // hash all the keys - Vector hashes(LogicalType::HASH); - Hash(keys, *FlatVector::IncrementalSelectionVector(), keys.size(), hashes); - - // find out which keys we can match with the current pinned partitions - SelectionVector true_sel; - SelectionVector false_sel; - true_sel.Initialize(); - false_sel.Initialize(); - auto true_count = RadixPartitioning::Select(hashes, FlatVector::IncrementalSelectionVector(), keys.size(), - radix_bits, partition_end, &true_sel, &false_sel); - auto false_count = keys.size() - true_count; - - CreateSpillChunk(spill_chunk, keys, payload, hashes); - - // can't probe these values right now, append to spill - spill_chunk.Slice(false_sel, false_count); - spill_chunk.Verify(); - 
probe_spill.Append(spill_chunk, spill_state); - - // slice the stuff we CAN probe right now - hashes.Slice(true_sel, true_count); - keys.Slice(true_sel, true_count); - payload.Slice(true_sel, true_count); - - const SelectionVector *current_sel; - auto ss = InitializeScanStructure(keys, key_state, current_sel); - if (ss->count == 0) { - return ss; - } - - // now initialize the pointers of the scan structure based on the hashes - ApplyBitmask(hashes, *current_sel, ss->count, ss->pointers); - - // create the selection vector linking to only non-empty entries - ss->InitializeSelectionVector(current_sel); - - return ss; -} - -ProbeSpill::ProbeSpill(JoinHashTable &ht, ClientContext &context, const vector &probe_types) - : ht(ht), context(context), probe_types(probe_types) { - auto remaining_count = ht.GetSinkCollection().Count(); - auto remaining_data_size = ht.GetSinkCollection().SizeInBytes(); - auto remaining_ht_size = remaining_data_size + ht.PointerTableSize(remaining_count); - if (remaining_ht_size <= ht.max_ht_size) { - // No need to partition as we will only have one more probe round - partitioned = false; - } else { - // More than one probe round to go, so we need to partition - partitioned = true; - global_partitions = - make_uniq(context, probe_types, ht.radix_bits, probe_types.size() - 1); - } - column_ids.reserve(probe_types.size()); - for (column_t column_id = 0; column_id < probe_types.size(); column_id++) { - column_ids.emplace_back(column_id); - } -} - -ProbeSpillLocalState ProbeSpill::RegisterThread() { - ProbeSpillLocalAppendState result; - lock_guard guard(lock); - if (partitioned) { - local_partitions.emplace_back(global_partitions->CreateShared()); - local_partition_append_states.emplace_back(make_uniq()); - local_partitions.back()->InitializeAppendState(*local_partition_append_states.back()); - - result.local_partition = local_partitions.back().get(); - result.local_partition_append_state = local_partition_append_states.back().get(); - } else { - 
local_spill_collections.emplace_back( - make_uniq(BufferManager::GetBufferManager(context), probe_types)); - local_spill_append_states.emplace_back(make_uniq()); - local_spill_collections.back()->InitializeAppend(*local_spill_append_states.back()); - - result.local_spill_collection = local_spill_collections.back().get(); - result.local_spill_append_state = local_spill_append_states.back().get(); - } - return result; -} - -void ProbeSpill::Append(DataChunk &chunk, ProbeSpillLocalAppendState &local_state) { - if (partitioned) { - local_state.local_partition->Append(*local_state.local_partition_append_state, chunk); - } else { - local_state.local_spill_collection->Append(*local_state.local_spill_append_state, chunk); - } -} - -void ProbeSpill::Finalize() { - if (partitioned) { - D_ASSERT(local_partitions.size() == local_partition_append_states.size()); - for (idx_t i = 0; i < local_partition_append_states.size(); i++) { - local_partitions[i]->FlushAppendState(*local_partition_append_states[i]); - } - for (auto &local_partition : local_partitions) { - global_partitions->Combine(*local_partition); - } - local_partitions.clear(); - local_partition_append_states.clear(); - } else { - if (local_spill_collections.empty()) { - global_spill_collection = - make_uniq(BufferManager::GetBufferManager(context), probe_types); - } else { - global_spill_collection = std::move(local_spill_collections[0]); - for (idx_t i = 1; i < local_spill_collections.size(); i++) { - global_spill_collection->Combine(*local_spill_collections[i]); - } - } - local_spill_collections.clear(); - local_spill_append_states.clear(); - } -} - -void ProbeSpill::PrepareNextProbe() { - if (partitioned) { - auto &partitions = global_partitions->GetPartitions(); - if (partitions.empty() || ht.partition_start == partitions.size()) { - // Can't probe, just make an empty one - global_spill_collection = - make_uniq(BufferManager::GetBufferManager(context), probe_types); - } else { - // Move specific partitions to the 
global spill collection - global_spill_collection = std::move(partitions[ht.partition_start]); - for (idx_t i = ht.partition_start + 1; i < ht.partition_end; i++) { - auto &partition = partitions[i]; - if (global_spill_collection->Count() == 0) { - global_spill_collection = std::move(partition); - } else { - global_spill_collection->Combine(*partition); - } - } - } - } - consumer = make_uniq(*global_spill_collection, column_ids); - consumer->InitializeScan(); -} - -} // namespace duckdb - - - -namespace duckdb { - -struct InitialNestedLoopJoin { - template - static idx_t Operation(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, idx_t &rpos, - SelectionVector &lvector, SelectionVector &rvector, idx_t current_match_count) { - using MATCH_OP = ComparisonOperationWrapper; - - // initialize phase of nested loop join - // fill lvector and rvector with matches from the base vectors - UnifiedVectorFormat left_data, right_data; - left.ToUnifiedFormat(left_size, left_data); - right.ToUnifiedFormat(right_size, right_data); - - auto ldata = UnifiedVectorFormat::GetData(left_data); - auto rdata = UnifiedVectorFormat::GetData(right_data); - idx_t result_count = 0; - for (; rpos < right_size; rpos++) { - idx_t right_position = right_data.sel->get_index(rpos); - bool right_is_valid = right_data.validity.RowIsValid(right_position); - for (; lpos < left_size; lpos++) { - if (result_count == STANDARD_VECTOR_SIZE) { - // out of space! 
- return result_count; - } - idx_t left_position = left_data.sel->get_index(lpos); - bool left_is_valid = left_data.validity.RowIsValid(left_position); - if (MATCH_OP::Operation(ldata[left_position], rdata[right_position], !left_is_valid, !right_is_valid)) { - // emit tuple - lvector.set_index(result_count, lpos); - rvector.set_index(result_count, rpos); - result_count++; - } - } - lpos = 0; - } - return result_count; - } -}; - -struct RefineNestedLoopJoin { - template - static idx_t Operation(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, idx_t &rpos, - SelectionVector &lvector, SelectionVector &rvector, idx_t current_match_count) { - using MATCH_OP = ComparisonOperationWrapper; - - UnifiedVectorFormat left_data, right_data; - left.ToUnifiedFormat(left_size, left_data); - right.ToUnifiedFormat(right_size, right_data); - - // refine phase of the nested loop join - // refine lvector and rvector based on matches of subsequent conditions (in case there are multiple conditions - // in the join) - D_ASSERT(current_match_count > 0); - auto ldata = UnifiedVectorFormat::GetData(left_data); - auto rdata = UnifiedVectorFormat::GetData(right_data); - idx_t result_count = 0; - for (idx_t i = 0; i < current_match_count; i++) { - auto lidx = lvector.get_index(i); - auto ridx = rvector.get_index(i); - auto left_idx = left_data.sel->get_index(lidx); - auto right_idx = right_data.sel->get_index(ridx); - bool left_is_valid = left_data.validity.RowIsValid(left_idx); - bool right_is_valid = right_data.validity.RowIsValid(right_idx); - if (MATCH_OP::Operation(ldata[left_idx], rdata[right_idx], !left_is_valid, !right_is_valid)) { - lvector.set_index(result_count, lidx); - rvector.set_index(result_count, ridx); - result_count++; - } - } - return result_count; - } -}; - -template -static idx_t NestedLoopJoinTypeSwitch(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, - idx_t &rpos, SelectionVector &lvector, SelectionVector &rvector, 
- idx_t current_match_count) { - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::INT16: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::INT32: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::INT64: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::UINT8: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::UINT16: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case PhysicalType::UINT32: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case PhysicalType::UINT64: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case PhysicalType::INT128: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case PhysicalType::FLOAT: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::DOUBLE: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::INTERVAL: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case PhysicalType::VARCHAR: - return 
NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - default: - throw InternalException("Unimplemented type for join!"); - } -} - -template -idx_t NestedLoopJoinComparisonSwitch(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, - idx_t &rpos, SelectionVector &lvector, SelectionVector &rvector, - idx_t current_match_count, ExpressionType comparison_type) { - D_ASSERT(left.GetType() == right.GetType()); - switch (comparison_type) { - case ExpressionType::COMPARE_EQUAL: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case ExpressionType::COMPARE_NOTEQUAL: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case ExpressionType::COMPARE_LESSTHAN: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case ExpressionType::COMPARE_GREATERTHAN: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, - lvector, rvector, current_match_count); - case ExpressionType::COMPARE_DISTINCT_FROM: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - default: - throw NotImplementedException("Unimplemented comparison type for join!"); - } -} - -idx_t NestedLoopJoinInner::Perform(idx_t &lpos, idx_t &rpos, DataChunk &left_conditions, DataChunk &right_conditions, - SelectionVector &lvector, SelectionVector &rvector, - const vector 
&conditions) { - D_ASSERT(left_conditions.ColumnCount() == right_conditions.ColumnCount()); - if (lpos >= left_conditions.size() || rpos >= right_conditions.size()) { - return 0; - } - // for the first condition, lvector and rvector are not set yet - // we initialize them using the InitialNestedLoopJoin - idx_t match_count = NestedLoopJoinComparisonSwitch( - left_conditions.data[0], right_conditions.data[0], left_conditions.size(), right_conditions.size(), lpos, rpos, - lvector, rvector, 0, conditions[0].comparison); - // now resolve the rest of the conditions - for (idx_t i = 1; i < conditions.size(); i++) { - // check if we have run out of tuples to compare - if (match_count == 0) { - return 0; - } - // if not, get the vectors to compare - Vector &l = left_conditions.data[i]; - Vector &r = right_conditions.data[i]; - // then we refine the currently obtained results using the RefineNestedLoopJoin - match_count = NestedLoopJoinComparisonSwitch( - l, r, left_conditions.size(), right_conditions.size(), lpos, rpos, lvector, rvector, match_count, - conditions[i].comparison); - } - return match_count; -} - -} // namespace duckdb - - - - -namespace duckdb { - -template -static void TemplatedMarkJoin(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[]) { - using MATCH_OP = ComparisonOperationWrapper; - - UnifiedVectorFormat left_data, right_data; - left.ToUnifiedFormat(lcount, left_data); - right.ToUnifiedFormat(rcount, right_data); - - auto ldata = UnifiedVectorFormat::GetData(left_data); - auto rdata = UnifiedVectorFormat::GetData(right_data); - for (idx_t i = 0; i < lcount; i++) { - if (found_match[i]) { - continue; - } - auto lidx = left_data.sel->get_index(i); - const auto left_null = !left_data.validity.RowIsValid(lidx); - if (!MATCH_OP::COMPARE_NULL && left_null) { - continue; - } - for (idx_t j = 0; j < rcount; j++) { - auto ridx = right_data.sel->get_index(j); - const auto right_null = !right_data.validity.RowIsValid(ridx); - if 
(!MATCH_OP::COMPARE_NULL && right_null) { - continue; - } - if (MATCH_OP::template Operation(ldata[lidx], rdata[ridx], left_null, right_null)) { - found_match[i] = true; - break; - } - } - } -} - -static void MarkJoinNested(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[], - ExpressionType comparison_type) { - Vector left_reference(left.GetType()); - SelectionVector true_sel(rcount); - for (idx_t i = 0; i < lcount; i++) { - if (found_match[i]) { - continue; - } - ConstantVector::Reference(left_reference, left, i, rcount); - idx_t count; - switch (comparison_type) { - case ExpressionType::COMPARE_EQUAL: - count = VectorOperations::Equals(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_NOTEQUAL: - count = VectorOperations::NotEquals(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_LESSTHAN: - count = VectorOperations::LessThan(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_GREATERTHAN: - count = VectorOperations::GreaterThan(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - count = VectorOperations::LessThanEquals(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - count = VectorOperations::GreaterThanEquals(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_DISTINCT_FROM: - count = VectorOperations::DistinctFrom(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - count = VectorOperations::NotDistinctFrom(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - default: - throw InternalException("Unsupported comparison type for MarkJoinNested"); - } - if (count > 0) { - found_match[i] = true; - } - } -} - -template -static 
void MarkJoinSwitch(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[]) { - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::INT16: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::INT32: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::INT64: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::INT128: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::UINT8: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::UINT16: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::UINT32: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::UINT64: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::FLOAT: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::DOUBLE: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::VARCHAR: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - default: - throw NotImplementedException("Unimplemented type for mark join!"); - } -} - -static void MarkJoinComparisonSwitch(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[], - ExpressionType comparison_type) { - switch (left.GetType().InternalType()) { - case PhysicalType::STRUCT: - case PhysicalType::LIST: - return MarkJoinNested(left, right, lcount, rcount, found_match, comparison_type); - default: - break; - } - D_ASSERT(left.GetType() == right.GetType()); - switch (comparison_type) { - case ExpressionType::COMPARE_EQUAL: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - case 
ExpressionType::COMPARE_NOTEQUAL: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - case ExpressionType::COMPARE_LESSTHAN: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - case ExpressionType::COMPARE_GREATERTHAN: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - case ExpressionType::COMPARE_DISTINCT_FROM: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - default: - throw NotImplementedException("Unimplemented comparison type for join!"); - } -} - -void NestedLoopJoinMark::Perform(DataChunk &left, ColumnDataCollection &right, bool found_match[], - const vector &conditions) { - // initialize a new temporary selection vector for the left chunk - // loop over all chunks in the RHS - ColumnDataScanState scan_state; - right.InitializeScan(scan_state); - - DataChunk scan_chunk; - right.InitializeScanChunk(scan_chunk); - - while (right.Scan(scan_state, scan_chunk)) { - for (idx_t i = 0; i < conditions.size(); i++) { - MarkJoinComparisonSwitch(left.data[i], scan_chunk.data[i], left.size(), scan_chunk.size(), found_match, - conditions[i].comparison); - } - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -AggregateObject::AggregateObject(AggregateFunction function, FunctionData *bind_data, idx_t child_count, - idx_t payload_size, AggregateType aggr_type, PhysicalType return_type, - Expression *filter) - : function(std::move(function)), - bind_data_wrapper(bind_data ? 
make_shared(bind_data->Copy()) : nullptr), - child_count(child_count), payload_size(payload_size), aggr_type(aggr_type), return_type(return_type), - filter(filter) { -} - -AggregateObject::AggregateObject(BoundAggregateExpression *aggr) - : AggregateObject(aggr->function, aggr->bind_info.get(), aggr->children.size(), - AlignValue(aggr->function.state_size()), aggr->aggr_type, aggr->return_type.InternalType(), - aggr->filter.get()) { -} - -AggregateObject::AggregateObject(BoundWindowExpression &window) - : AggregateObject(*window.aggregate, window.bind_info.get(), window.children.size(), - AlignValue(window.aggregate->state_size()), AggregateType::NON_DISTINCT, - window.return_type.InternalType(), window.filter_expr.get()) { -} - -vector AggregateObject::CreateAggregateObjects(const vector &bindings) { - vector aggregates; - aggregates.reserve(aggregates.size()); - for (auto &binding : bindings) { - aggregates.emplace_back(binding); - } - return aggregates; -} - -AggregateFilterData::AggregateFilterData(ClientContext &context, Expression &filter_expr, - const vector &payload_types) - : filter_executor(context, &filter_expr), true_sel(STANDARD_VECTOR_SIZE) { - if (payload_types.empty()) { - return; - } - filtered_payload.Initialize(Allocator::Get(context), payload_types); -} - -idx_t AggregateFilterData::ApplyFilter(DataChunk &payload) { - filtered_payload.Reset(); - - auto count = filter_executor.SelectExpression(payload, true_sel); - filtered_payload.Slice(payload, true_sel, count); - return count; -} - -AggregateFilterDataSet::AggregateFilterDataSet() { -} - -void AggregateFilterDataSet::Initialize(ClientContext &context, const vector &aggregates, - const vector &payload_types) { - bool has_filters = false; - for (auto &aggregate : aggregates) { - if (aggregate.filter) { - has_filters = true; - break; - } - } - if (!has_filters) { - // no filters: nothing to do - return; - } - filter_data.resize(aggregates.size()); - for (idx_t aggr_idx = 0; aggr_idx < 
aggregates.size(); aggr_idx++) { - auto &aggr = aggregates[aggr_idx]; - if (aggr.filter) { - filter_data[aggr_idx] = make_uniq(context, *aggr.filter, payload_types); - } - } -} - -AggregateFilterData &AggregateFilterDataSet::GetFilterData(idx_t aggr_idx) { - D_ASSERT(aggr_idx < filter_data.size()); - D_ASSERT(filter_data[aggr_idx]); - return *filter_data[aggr_idx]; -} -} // namespace duckdb - - - - - - -namespace duckdb { - -//! Shared information about a collection of distinct aggregates -DistinctAggregateCollectionInfo::DistinctAggregateCollectionInfo(const vector> &aggregates, - vector indices) - : indices(std::move(indices)), aggregates(aggregates) { - table_count = CreateTableIndexMap(); - - const idx_t aggregate_count = aggregates.size(); - - total_child_count = 0; - for (idx_t i = 0; i < aggregate_count; i++) { - auto &aggregate = aggregates[i]->Cast(); - - if (!aggregate.IsDistinct()) { - continue; - } - total_child_count += aggregate.children.size(); - } -} - -//! Stateful data for the distinct aggregates - -DistinctAggregateState::DistinctAggregateState(const DistinctAggregateData &data, ClientContext &client) - : child_executor(client) { - - radix_states.resize(data.info.table_count); - distinct_output_chunks.resize(data.info.table_count); - - idx_t aggregate_count = data.info.aggregates.size(); - for (idx_t i = 0; i < aggregate_count; i++) { - auto &aggregate = data.info.aggregates[i]->Cast(); - - // Initialize the child executor and get the payload types for every aggregate - for (auto &child : aggregate.children) { - child_executor.AddExpression(*child); - } - if (!aggregate.IsDistinct()) { - continue; - } - D_ASSERT(data.info.table_map.count(i)); - idx_t table_idx = data.info.table_map.at(i); - if (data.radix_tables[table_idx] == nullptr) { - //! 
This table is unused because the aggregate shares its data with another - continue; - } - - // Get the global sinkstate for the aggregate - auto &radix_table = *data.radix_tables[table_idx]; - radix_states[table_idx] = radix_table.GetGlobalSinkState(client); - - // Fill the chunk_types (group_by + children) - vector chunk_types; - for (auto &group_type : data.grouped_aggregate_data[table_idx]->group_types) { - chunk_types.push_back(group_type); - } - - // This is used in Finalize to get the data from the radix table - distinct_output_chunks[table_idx] = make_uniq(); - distinct_output_chunks[table_idx]->Initialize(client, chunk_types); - } -} - -//! Persistent + shared (read-only) data for the distinct aggregates -DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info) - : DistinctAggregateData(info, {}, nullptr) { -} - -DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info, const GroupingSet &groups, - const vector> *group_expressions) - : info(info) { - grouped_aggregate_data.resize(info.table_count); - radix_tables.resize(info.table_count); - grouping_sets.resize(info.table_count); - - for (auto &i : info.indices) { - auto &aggregate = info.aggregates[i]->Cast(); - - D_ASSERT(info.table_map.count(i)); - idx_t table_idx = info.table_map.at(i); - if (radix_tables[table_idx] != nullptr) { - //! This aggregate shares a table with another aggregate, and the table is already initialized - continue; - } - // The grouping set contains the indices of the chunk that correspond to the data vector - // that will be used to figure out in which bucket the payload should be put - auto &grouping_set = grouping_sets[table_idx]; - //! Populate the group with the children of the aggregate - for (auto &group : groups) { - grouping_set.insert(group); - } - idx_t group_by_size = group_expressions ? 
group_expressions->size() : 0; - for (idx_t set_idx = 0; set_idx < aggregate.children.size(); set_idx++) { - grouping_set.insert(set_idx + group_by_size); - } - // Create the hashtable for the aggregate - grouped_aggregate_data[table_idx] = make_uniq(); - grouped_aggregate_data[table_idx]->InitializeDistinct(info.aggregates[i], group_expressions); - radix_tables[table_idx] = - make_uniq(grouping_set, *grouped_aggregate_data[table_idx]); - - // Fill the chunk_types (only contains the payload of the distinct aggregates) - vector chunk_types; - for (auto &child_p : aggregate.children) { - chunk_types.push_back(child_p->return_type); - } - } -} - -using aggr_ref_t = reference; - -struct FindMatchingAggregate { - explicit FindMatchingAggregate(const aggr_ref_t &aggr) : aggr_r(aggr) { - } - bool operator()(const aggr_ref_t other_r) { - auto &other = other_r.get(); - auto &aggr = aggr_r.get(); - if (other.children.size() != aggr.children.size()) { - return false; - } - if (!Expression::Equals(aggr.filter, other.filter)) { - return false; - } - for (idx_t i = 0; i < aggr.children.size(); i++) { - auto &other_child = other.children[i]->Cast(); - auto &aggr_child = aggr.children[i]->Cast(); - if (other_child.index != aggr_child.index) { - return false; - } - } - return true; - } - const aggr_ref_t aggr_r; -}; - -idx_t DistinctAggregateCollectionInfo::CreateTableIndexMap() { - vector table_inputs; - - D_ASSERT(table_map.empty()); - for (auto &agg_idx : indices) { - D_ASSERT(agg_idx < aggregates.size()); - auto &aggregate = aggregates[agg_idx]->Cast(); - - auto matching_inputs = - std::find_if(table_inputs.begin(), table_inputs.end(), FindMatchingAggregate(std::ref(aggregate))); - if (matching_inputs != table_inputs.end()) { - //! Assign the existing table to the aggregate - idx_t found_idx = std::distance(table_inputs.begin(), matching_inputs); - table_map[agg_idx] = found_idx; - continue; - } - //! 
Create a new table and assign its index to the aggregate - table_map[agg_idx] = table_inputs.size(); - table_inputs.push_back(std::ref(aggregate)); - } - //! Every distinct aggregate needs to be assigned an index - D_ASSERT(table_map.size() == indices.size()); - //! There can not be more tables than there are distinct aggregates - D_ASSERT(table_inputs.size() <= indices.size()); - - return table_inputs.size(); -} - -bool DistinctAggregateCollectionInfo::AnyDistinct() const { - return !indices.empty(); -} - -const unsafe_vector &DistinctAggregateCollectionInfo::Indices() const { - return this->indices; -} - -static vector GetDistinctIndices(vector> &aggregates) { - vector distinct_indices; - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggregate = aggregates[i]; - auto &aggr = aggregate->Cast(); - if (aggr.IsDistinct()) { - distinct_indices.push_back(i); - } - } - return distinct_indices; -} - -unique_ptr -DistinctAggregateCollectionInfo::Create(vector> &aggregates) { - vector indices = GetDistinctIndices(aggregates); - if (indices.empty()) { - return nullptr; - } - return make_uniq(aggregates, std::move(indices)); -} - -bool DistinctAggregateData::IsDistinct(idx_t index) const { - bool is_distinct = !radix_tables.empty() && info.table_map.count(index); -#ifdef DEBUG - //! Make sure that if it is distinct, it's also in the indices - //! 
And if it's not distinct, that it's also not in the indices - bool found = false; - for (auto &idx : info.indices) { - if (idx == index) { - found = true; - break; - } - } - D_ASSERT(found == is_distinct); -#endif - return is_distinct; -} - -} // namespace duckdb - - -namespace duckdb { - -idx_t GroupedAggregateData::GroupCount() const { - return groups.size(); -} - -const vector> &GroupedAggregateData::GetGroupingFunctions() const { - return grouping_functions; -} - -void GroupedAggregateData::InitializeGroupby(vector> groups, - vector> expressions, - vector> grouping_functions) { - InitializeGroupbyGroups(std::move(groups)); - vector payload_types_filters; - - SetGroupingFunctions(grouping_functions); - - filter_count = 0; - for (auto &expr : expressions) { - D_ASSERT(expr->expression_class == ExpressionClass::BOUND_AGGREGATE); - D_ASSERT(expr->IsAggregate()); - auto &aggr = expr->Cast(); - bindings.push_back(&aggr); - - aggregate_return_types.push_back(aggr.return_type); - for (auto &child : aggr.children) { - payload_types.push_back(child->return_type); - } - if (aggr.filter) { - filter_count++; - payload_types_filters.push_back(aggr.filter->return_type); - } - if (!aggr.function.combine) { - throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); - } - aggregates.push_back(std::move(expr)); - } - for (const auto &pay_filters : payload_types_filters) { - payload_types.push_back(pay_filters); - } -} - -void GroupedAggregateData::InitializeDistinct(const unique_ptr &aggregate, - const vector> *groups_p) { - auto &aggr = aggregate->Cast(); - D_ASSERT(aggr.IsDistinct()); - - // Add the (empty in ungrouped case) groups of the aggregates - InitializeDistinctGroups(groups_p); - - // bindings.push_back(&aggr); - filter_count = 0; - aggregate_return_types.push_back(aggr.return_type); - for (idx_t i = 0; i < aggr.children.size(); i++) { - auto &child = aggr.children[i]; - group_types.push_back(child->return_type); - 
groups.push_back(child->Copy()); - payload_types.push_back(child->return_type); - if (aggr.filter) { - filter_count++; - } - } - if (!aggr.function.combine) { - throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); - } -} - -void GroupedAggregateData::InitializeDistinctGroups(const vector> *groups_p) { - if (!groups_p) { - return; - } - for (auto &expr : *groups_p) { - group_types.push_back(expr->return_type); - groups.push_back(expr->Copy()); - } -} - -void GroupedAggregateData::InitializeGroupbyGroups(vector> groups) { - // Add all the expressions of the group by clause - for (auto &expr : groups) { - group_types.push_back(expr->return_type); - } - this->groups = std::move(groups); -} - -void GroupedAggregateData::SetGroupingFunctions(vector> &functions) { - grouping_functions.reserve(functions.size()); - for (idx_t i = 0; i < functions.size(); i++) { - grouping_functions.push_back(std::move(functions[i])); - } -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - -namespace duckdb { - -HashAggregateGroupingData::HashAggregateGroupingData(GroupingSet &grouping_set_p, - const GroupedAggregateData &grouped_aggregate_data, - unique_ptr &info) - : table_data(grouping_set_p, grouped_aggregate_data) { - if (info) { - distinct_data = make_uniq(*info, grouping_set_p, &grouped_aggregate_data.groups); - } -} - -bool HashAggregateGroupingData::HasDistinct() const { - return distinct_data != nullptr; -} - -HashAggregateGroupingGlobalState::HashAggregateGroupingGlobalState(const HashAggregateGroupingData &data, - ClientContext &context) { - table_state = data.table_data.GetGlobalSinkState(context); - if (data.HasDistinct()) { - distinct_state = make_uniq(*data.distinct_data, context); - } -} - -HashAggregateGroupingLocalState::HashAggregateGroupingLocalState(const PhysicalHashAggregate &op, - const HashAggregateGroupingData &data, - ExecutionContext &context) { - table_state = data.table_data.GetLocalSinkState(context); - if 
(!data.HasDistinct()) { - return; - } - auto &distinct_data = *data.distinct_data; - - auto &distinct_indices = op.distinct_collection_info->Indices(); - D_ASSERT(!distinct_indices.empty()); - - distinct_states.resize(op.distinct_collection_info->aggregates.size()); - auto &table_map = op.distinct_collection_info->table_map; - - for (auto &idx : distinct_indices) { - idx_t table_idx = table_map[idx]; - auto &radix_table = distinct_data.radix_tables[table_idx]; - if (radix_table == nullptr) { - // This aggregate has identical input as another aggregate, so no table is created for it - continue; - } - // Initialize the states of the radix tables used for the distinct aggregates - distinct_states[table_idx] = radix_table->GetLocalSinkState(context); - } -} - -static vector CreateGroupChunkTypes(vector> &groups) { - set group_indices; - - if (groups.empty()) { - return {}; - } - - for (auto &group : groups) { - D_ASSERT(group->type == ExpressionType::BOUND_REF); - auto &bound_ref = group->Cast(); - group_indices.insert(bound_ref.index); - } - idx_t highest_index = *group_indices.rbegin(); - vector types(highest_index + 1, LogicalType::SQLNULL); - for (auto &group : groups) { - auto &bound_ref = group->Cast(); - types[bound_ref.index] = bound_ref.return_type; - } - return types; -} - -bool PhysicalHashAggregate::CanSkipRegularSink() const { - if (!filter_indexes.empty()) { - // If we have filters, we can't skip the regular sink, because we might lose groups otherwise. 
- return false; - } - if (grouped_aggregate_data.aggregates.empty()) { - // When there are no aggregates, we have to add to the main ht right away - return false; - } - if (!non_distinct_filter.empty()) { - return false; - } - return true; -} - -PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector types, - vector> expressions, idx_t estimated_cardinality) - : PhysicalHashAggregate(context, std::move(types), std::move(expressions), {}, estimated_cardinality) { -} - -PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector types, - vector> expressions, - vector> groups_p, idx_t estimated_cardinality) - : PhysicalHashAggregate(context, std::move(types), std::move(expressions), std::move(groups_p), {}, {}, - estimated_cardinality) { -} - -PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector types, - vector> expressions, - vector> groups_p, - vector grouping_sets_p, - vector> grouping_functions_p, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::HASH_GROUP_BY, std::move(types), estimated_cardinality), - grouping_sets(std::move(grouping_sets_p)) { - // get a list of all aggregates to be computed - const idx_t group_count = groups_p.size(); - if (grouping_sets.empty()) { - GroupingSet set; - for (idx_t i = 0; i < group_count; i++) { - set.insert(i); - } - grouping_sets.push_back(std::move(set)); - } - input_group_types = CreateGroupChunkTypes(groups_p); - - grouped_aggregate_data.InitializeGroupby(std::move(groups_p), std::move(expressions), - std::move(grouping_functions_p)); - - auto &aggregates = grouped_aggregate_data.aggregates; - // filter_indexes must be pre-built, not lazily instantiated in parallel... 
- // Because everything that lives in this class should be read-only at execution time - idx_t aggregate_input_idx = 0; - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggregate = aggregates[i]; - auto &aggr = aggregate->Cast(); - aggregate_input_idx += aggr.children.size(); - if (aggr.aggr_type == AggregateType::DISTINCT) { - distinct_filter.push_back(i); - } else if (aggr.aggr_type == AggregateType::NON_DISTINCT) { - non_distinct_filter.push_back(i); - } else { // LCOV_EXCL_START - throw NotImplementedException("AggregateType not implemented in PhysicalHashAggregate"); - } // LCOV_EXCL_STOP - } - - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggregate = aggregates[i]; - auto &aggr = aggregate->Cast(); - if (aggr.filter) { - auto &bound_ref_expr = aggr.filter->Cast(); - if (!filter_indexes.count(aggr.filter.get())) { - // Replace the bound reference expression's index with the corresponding index of the payload chunk - filter_indexes[aggr.filter.get()] = bound_ref_expr.index; - bound_ref_expr.index = aggregate_input_idx; - } - aggregate_input_idx++; - } - } - - distinct_collection_info = DistinctAggregateCollectionInfo::Create(grouped_aggregate_data.aggregates); - - for (idx_t i = 0; i < grouping_sets.size(); i++) { - groupings.emplace_back(grouping_sets[i], grouped_aggregate_data, distinct_collection_info); - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class HashAggregateGlobalSinkState : public GlobalSinkState { -public: - HashAggregateGlobalSinkState(const PhysicalHashAggregate &op, ClientContext &context) { - grouping_states.reserve(op.groupings.size()); - for (idx_t i = 0; i < op.groupings.size(); i++) { - auto &grouping = op.groupings[i]; - grouping_states.emplace_back(grouping, context); - } - vector filter_types; - for (auto &aggr : op.grouped_aggregate_data.aggregates) { - auto &aggregate = 
aggr->Cast(); - for (auto &child : aggregate.children) { - payload_types.push_back(child->return_type); - } - if (aggregate.filter) { - filter_types.push_back(aggregate.filter->return_type); - } - } - payload_types.reserve(payload_types.size() + filter_types.size()); - payload_types.insert(payload_types.end(), filter_types.begin(), filter_types.end()); - } - - vector grouping_states; - vector payload_types; - //! Whether or not the aggregate is finished - bool finished = false; -}; - -class HashAggregateLocalSinkState : public LocalSinkState { -public: - HashAggregateLocalSinkState(const PhysicalHashAggregate &op, ExecutionContext &context) { - - auto &payload_types = op.grouped_aggregate_data.payload_types; - if (!payload_types.empty()) { - aggregate_input_chunk.InitializeEmpty(payload_types); - } - - grouping_states.reserve(op.groupings.size()); - for (auto &grouping : op.groupings) { - grouping_states.emplace_back(op, grouping, context); - } - // The filter set is only needed here for the distinct aggregates - // the filtering of data for the regular aggregates is done within the hashtable - vector aggregate_objects; - for (auto &aggregate : op.grouped_aggregate_data.aggregates) { - auto &aggr = aggregate->Cast(); - aggregate_objects.emplace_back(&aggr); - } - - filter_set.Initialize(context.client, aggregate_objects, payload_types); - } - - DataChunk aggregate_input_chunk; - vector grouping_states; - AggregateFilterDataSet filter_set; -}; - -void PhysicalHashAggregate::SetMultiScan(GlobalSinkState &state) { - auto &gstate = state.Cast(); - for (auto &grouping_state : gstate.grouping_states) { - RadixPartitionedHashTable::SetMultiScan(*grouping_state.table_state); - if (!grouping_state.distinct_state) { - continue; - } - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -unique_ptr PhysicalHashAggregate::GetGlobalSinkState(ClientContext 
&context) const { - return make_uniq(*this, context); -} - -unique_ptr PhysicalHashAggregate::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(*this, context); -} - -void PhysicalHashAggregate::SinkDistinctGrouping(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input, - idx_t grouping_idx) const { - auto &sink = input.local_state.Cast(); - auto &global_sink = input.global_state.Cast(); - - auto &grouping_gstate = global_sink.grouping_states[grouping_idx]; - auto &grouping_lstate = sink.grouping_states[grouping_idx]; - auto &distinct_info = *distinct_collection_info; - - auto &distinct_state = grouping_gstate.distinct_state; - auto &distinct_data = groupings[grouping_idx].distinct_data; - - DataChunk empty_chunk; - - // Create an empty filter for Sink, since we don't need to update any aggregate states here - unsafe_vector empty_filter; - - for (idx_t &idx : distinct_info.indices) { - auto &aggregate = grouped_aggregate_data.aggregates[idx]->Cast(); - - D_ASSERT(distinct_info.table_map.count(idx)); - idx_t table_idx = distinct_info.table_map[idx]; - if (!distinct_data->radix_tables[table_idx]) { - continue; - } - D_ASSERT(distinct_data->radix_tables[table_idx]); - auto &radix_table = *distinct_data->radix_tables[table_idx]; - auto &radix_global_sink = *distinct_state->radix_states[table_idx]; - auto &radix_local_sink = *grouping_lstate.distinct_states[table_idx]; - - InterruptState interrupt_state; - OperatorSinkInput sink_input {radix_global_sink, radix_local_sink, interrupt_state}; - - if (aggregate.filter) { - DataChunk filter_chunk; - auto &filtered_data = sink.filter_set.GetFilterData(idx); - filter_chunk.InitializeEmpty(filtered_data.filtered_payload.GetTypes()); - - // Add the filter Vector (BOOL) - auto it = filter_indexes.find(aggregate.filter.get()); - D_ASSERT(it != filter_indexes.end()); - D_ASSERT(it->second < chunk.data.size()); - auto &filter_bound_ref = aggregate.filter->Cast(); - 
filter_chunk.data[filter_bound_ref.index].Reference(chunk.data[it->second]); - filter_chunk.SetCardinality(chunk.size()); - - // We cant use the AggregateFilterData::ApplyFilter method, because the chunk we need to - // apply the filter to also has the groups, and the filtered_data.filtered_payload does not have those. - SelectionVector sel_vec(STANDARD_VECTOR_SIZE); - idx_t count = filtered_data.filter_executor.SelectExpression(filter_chunk, sel_vec); - - if (count == 0) { - continue; - } - - // Because the 'input' chunk needs to be re-used after this, we need to create - // a duplicate of it, that we can apply the filter to - DataChunk filtered_input; - filtered_input.InitializeEmpty(chunk.GetTypes()); - - for (idx_t group_idx = 0; group_idx < grouped_aggregate_data.groups.size(); group_idx++) { - auto &group = grouped_aggregate_data.groups[group_idx]; - auto &bound_ref = group->Cast(); - filtered_input.data[bound_ref.index].Reference(chunk.data[bound_ref.index]); - } - for (idx_t child_idx = 0; child_idx < aggregate.children.size(); child_idx++) { - auto &child = aggregate.children[child_idx]; - auto &bound_ref = child->Cast(); - - filtered_input.data[bound_ref.index].Reference(chunk.data[bound_ref.index]); - } - filtered_input.Slice(sel_vec, count); - filtered_input.SetCardinality(count); - - radix_table.Sink(context, filtered_input, sink_input, empty_chunk, empty_filter); - } else { - radix_table.Sink(context, chunk, sink_input, empty_chunk, empty_filter); - } - } -} - -void PhysicalHashAggregate::SinkDistinct(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - for (idx_t i = 0; i < groupings.size(); i++) { - SinkDistinctGrouping(context, chunk, input, i); - } -} - -SinkResultType PhysicalHashAggregate::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &local_state = input.local_state.Cast(); - auto &global_state = input.global_state.Cast(); - - if (distinct_collection_info) { - 
SinkDistinct(context, chunk, input); - } - - if (CanSkipRegularSink()) { - return SinkResultType::NEED_MORE_INPUT; - } - - DataChunk &aggregate_input_chunk = local_state.aggregate_input_chunk; - auto &aggregates = grouped_aggregate_data.aggregates; - idx_t aggregate_input_idx = 0; - - // Populate the aggregate child vectors - for (auto &aggregate : aggregates) { - auto &aggr = aggregate->Cast(); - for (auto &child_expr : aggr.children) { - D_ASSERT(child_expr->type == ExpressionType::BOUND_REF); - auto &bound_ref_expr = child_expr->Cast(); - D_ASSERT(bound_ref_expr.index < chunk.data.size()); - aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[bound_ref_expr.index]); - } - } - // Populate the filter vectors - for (auto &aggregate : aggregates) { - auto &aggr = aggregate->Cast(); - if (aggr.filter) { - auto it = filter_indexes.find(aggr.filter.get()); - D_ASSERT(it != filter_indexes.end()); - D_ASSERT(it->second < chunk.data.size()); - aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[it->second]); - } - } - - aggregate_input_chunk.SetCardinality(chunk.size()); - aggregate_input_chunk.Verify(); - - // For every grouping set there is one radix_table - for (idx_t i = 0; i < groupings.size(); i++) { - auto &grouping_local_state = global_state.grouping_states[i]; - auto &grouping_global_state = local_state.grouping_states[i]; - InterruptState interrupt_state; - OperatorSinkInput sink_input {*grouping_local_state.table_state, *grouping_global_state.table_state, - interrupt_state}; - - auto &grouping = groupings[i]; - auto &table = grouping.table_data; - table.Sink(context, chunk, sink_input, aggregate_input_chunk, non_distinct_filter); - } - - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -void PhysicalHashAggregate::CombineDistinct(ExecutionContext &context, 
OperatorSinkCombineInput &input) const { - - auto &global_sink = input.global_state.Cast(); - auto &sink = input.local_state.Cast(); - - if (!distinct_collection_info) { - return; - } - for (idx_t i = 0; i < groupings.size(); i++) { - auto &grouping_gstate = global_sink.grouping_states[i]; - auto &grouping_lstate = sink.grouping_states[i]; - - auto &distinct_data = groupings[i].distinct_data; - auto &distinct_state = grouping_gstate.distinct_state; - - const auto table_count = distinct_data->radix_tables.size(); - for (idx_t table_idx = 0; table_idx < table_count; table_idx++) { - if (!distinct_data->radix_tables[table_idx]) { - continue; - } - auto &radix_table = *distinct_data->radix_tables[table_idx]; - auto &radix_global_sink = *distinct_state->radix_states[table_idx]; - auto &radix_local_sink = *grouping_lstate.distinct_states[table_idx]; - - radix_table.Combine(context, radix_global_sink, radix_local_sink); - } - } -} - -SinkCombineResultType PhysicalHashAggregate::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &llstate = input.local_state.Cast(); - - OperatorSinkCombineInput combine_distinct_input {gstate, llstate, input.interrupt_state}; - CombineDistinct(context, combine_distinct_input); - - if (CanSkipRegularSink()) { - return SinkCombineResultType::FINISHED; - } - for (idx_t i = 0; i < groupings.size(); i++) { - auto &grouping_gstate = gstate.grouping_states[i]; - auto &grouping_lstate = llstate.grouping_states[i]; - - auto &grouping = groupings[i]; - auto &table = grouping.table_data; - table.Combine(context, *grouping_gstate.table_state, *grouping_lstate.table_state); - } - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -class HashAggregateFinalizeEvent : public BasePipelineEvent { -public: - //! 
"Regular" Finalize Event that is scheduled after combining the thread-local distinct HTs - HashAggregateFinalizeEvent(ClientContext &context, Pipeline *pipeline_p, const PhysicalHashAggregate &op_p, - HashAggregateGlobalSinkState &gstate_p) - : BasePipelineEvent(*pipeline_p), context(context), op(op_p), gstate(gstate_p) { - } - -public: - void Schedule() override; - -private: - ClientContext &context; - - const PhysicalHashAggregate &op; - HashAggregateGlobalSinkState &gstate; -}; - -class HashAggregateFinalizeTask : public ExecutorTask { -public: - HashAggregateFinalizeTask(ClientContext &context, Pipeline &pipeline, shared_ptr event_p, - const PhysicalHashAggregate &op, HashAggregateGlobalSinkState &state_p) - : ExecutorTask(pipeline.executor), context(context), pipeline(pipeline), event(std::move(event_p)), op(op), - gstate(state_p) { - } - -public: - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; - -private: - ClientContext &context; - Pipeline &pipeline; - shared_ptr event; - - const PhysicalHashAggregate &op; - HashAggregateGlobalSinkState &gstate; -}; - -void HashAggregateFinalizeEvent::Schedule() { - vector> tasks; - tasks.push_back(make_uniq(context, *pipeline, shared_from_this(), op, gstate)); - D_ASSERT(!tasks.empty()); - SetTasks(std::move(tasks)); -} - -TaskExecutionResult HashAggregateFinalizeTask::ExecuteTask(TaskExecutionMode mode) { - op.FinalizeInternal(pipeline, *event, context, gstate, false); - D_ASSERT(!gstate.finished); - gstate.finished = true; - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -class HashAggregateDistinctFinalizeEvent : public BasePipelineEvent { -public: - //! 
Distinct Finalize Event that is scheduled if we have distinct aggregates - HashAggregateDistinctFinalizeEvent(ClientContext &context, Pipeline &pipeline_p, const PhysicalHashAggregate &op_p, - HashAggregateGlobalSinkState &gstate_p) - : BasePipelineEvent(pipeline_p), context(context), op(op_p), gstate(gstate_p) { - } - -public: - void Schedule() override; - void FinishEvent() override; - -private: - void CreateGlobalSources(); - -private: - ClientContext &context; - - const PhysicalHashAggregate &op; - HashAggregateGlobalSinkState &gstate; - -public: - //! The GlobalSourceStates for all the radix tables of the distinct aggregates - vector>> global_source_states; -}; - -class HashAggregateDistinctFinalizeTask : public ExecutorTask { -public: - HashAggregateDistinctFinalizeTask(Pipeline &pipeline, shared_ptr event_p, const PhysicalHashAggregate &op, - HashAggregateGlobalSinkState &state_p) - : ExecutorTask(pipeline.executor), pipeline(pipeline), event(std::move(event_p)), op(op), gstate(state_p) { - } - -public: - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; - -private: - void AggregateDistinctGrouping(const idx_t grouping_idx); - -private: - Pipeline &pipeline; - shared_ptr event; - - const PhysicalHashAggregate &op; - HashAggregateGlobalSinkState &gstate; -}; - -void HashAggregateDistinctFinalizeEvent::Schedule() { - CreateGlobalSources(); - - const idx_t n_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - vector> tasks; - for (idx_t i = 0; i < n_threads; i++) { - tasks.push_back(make_uniq(*pipeline, shared_from_this(), op, gstate)); - } - SetTasks(std::move(tasks)); -} - -void HashAggregateDistinctFinalizeEvent::CreateGlobalSources() { - auto &aggregates = op.grouped_aggregate_data.aggregates; - global_source_states.reserve(op.groupings.size()); - for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) { - auto &grouping = op.groupings[grouping_idx]; - auto &distinct_data = *grouping.distinct_data; - 
- vector> aggregate_sources; - aggregate_sources.reserve(aggregates.size()); - for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) { - auto &aggregate = aggregates[agg_idx]; - auto &aggr = aggregate->Cast(); - - if (!aggr.IsDistinct()) { - aggregate_sources.push_back(nullptr); - continue; - } - D_ASSERT(distinct_data.info.table_map.count(agg_idx)); - - auto table_idx = distinct_data.info.table_map.at(agg_idx); - auto &radix_table_p = distinct_data.radix_tables[table_idx]; - aggregate_sources.push_back(radix_table_p->GetGlobalSourceState(context)); - } - global_source_states.push_back(std::move(aggregate_sources)); - } -} - -void HashAggregateDistinctFinalizeEvent::FinishEvent() { - // Now that everything is added to the main ht, we can actually finalize - auto new_event = make_shared(context, pipeline.get(), op, gstate); - this->InsertEvent(std::move(new_event)); -} - -TaskExecutionResult HashAggregateDistinctFinalizeTask::ExecuteTask(TaskExecutionMode mode) { - for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) { - AggregateDistinctGrouping(grouping_idx); - } - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -void HashAggregateDistinctFinalizeTask::AggregateDistinctGrouping(const idx_t grouping_idx) { - D_ASSERT(op.distinct_collection_info); - auto &info = *op.distinct_collection_info; - - auto &grouping_data = op.groupings[grouping_idx]; - auto &grouping_state = gstate.grouping_states[grouping_idx]; - D_ASSERT(grouping_state.distinct_state); - auto &distinct_state = *grouping_state.distinct_state; - auto &distinct_data = *grouping_data.distinct_data; - - auto &aggregates = info.aggregates; - - // Thread-local contexts - ThreadContext thread_context(executor.context); - ExecutionContext execution_context(executor.context, thread_context, &pipeline); - - // Sink state to sink into global HTs - InterruptState interrupt_state; - auto &global_sink_state = *grouping_state.table_state; - auto 
local_sink_state = grouping_data.table_data.GetLocalSinkState(execution_context); - OperatorSinkInput sink_input {global_sink_state, *local_sink_state, interrupt_state}; - - // Create a chunk that mimics the 'input' chunk in Sink, for storing the group vectors - DataChunk group_chunk; - if (!op.input_group_types.empty()) { - group_chunk.Initialize(executor.context, op.input_group_types); - } - - auto &groups = op.grouped_aggregate_data.groups; - const idx_t group_by_size = groups.size(); - - DataChunk aggregate_input_chunk; - if (!gstate.payload_types.empty()) { - aggregate_input_chunk.Initialize(executor.context, gstate.payload_types); - } - - auto &finalize_event = event->Cast(); - - idx_t payload_idx; - idx_t next_payload_idx = 0; - for (idx_t agg_idx = 0; agg_idx < op.grouped_aggregate_data.aggregates.size(); agg_idx++) { - auto &aggregate = aggregates[agg_idx]->Cast(); - - // Forward the payload idx - payload_idx = next_payload_idx; - next_payload_idx = payload_idx + aggregate.children.size(); - - // If aggregate is not distinct, skip it - if (!distinct_data.IsDistinct(agg_idx)) { - continue; - } - - D_ASSERT(distinct_data.info.table_map.count(agg_idx)); - const auto &table_idx = distinct_data.info.table_map.at(agg_idx); - auto &radix_table = distinct_data.radix_tables[table_idx]; - - auto &sink = *distinct_state.radix_states[table_idx]; - auto local_source = radix_table->GetLocalSourceState(execution_context); - OperatorSourceInput source_input {*finalize_event.global_source_states[grouping_idx][agg_idx], *local_source, - interrupt_state}; - - // Create a duplicate of the output_chunk, because of multi-threading we cant alter the original - DataChunk output_chunk; - output_chunk.Initialize(executor.context, distinct_state.distinct_output_chunks[table_idx]->GetTypes()); - - // Fetch all the data from the aggregate ht, and Sink it into the main ht - while (true) { - output_chunk.Reset(); - group_chunk.Reset(); - aggregate_input_chunk.Reset(); - - auto res = 
radix_table->GetData(execution_context, output_chunk, sink, source_input); - if (res == SourceResultType::FINISHED) { - D_ASSERT(output_chunk.size() == 0); - break; - } else if (res == SourceResultType::BLOCKED) { - throw InternalException( - "Unexpected interrupt from radix table GetData in HashAggregateDistinctFinalizeTask"); - } - - auto &grouped_aggregate_data = *distinct_data.grouped_aggregate_data[table_idx]; - for (idx_t group_idx = 0; group_idx < group_by_size; group_idx++) { - auto &group = grouped_aggregate_data.groups[group_idx]; - auto &bound_ref_expr = group->Cast(); - group_chunk.data[bound_ref_expr.index].Reference(output_chunk.data[group_idx]); - } - group_chunk.SetCardinality(output_chunk); - - for (idx_t child_idx = 0; child_idx < grouped_aggregate_data.groups.size() - group_by_size; child_idx++) { - aggregate_input_chunk.data[payload_idx + child_idx].Reference( - output_chunk.data[group_by_size + child_idx]); - } - aggregate_input_chunk.SetCardinality(output_chunk); - - // Sink it into the main ht - grouping_data.table_data.Sink(execution_context, group_chunk, sink_input, aggregate_input_chunk, {agg_idx}); - } - } - grouping_data.table_data.Combine(execution_context, global_sink_state, *local_sink_state); -} - -SinkFinalizeType PhysicalHashAggregate::FinalizeDistinct(Pipeline &pipeline, Event &event, ClientContext &context, - GlobalSinkState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - D_ASSERT(distinct_collection_info); - - for (idx_t i = 0; i < groupings.size(); i++) { - auto &grouping = groupings[i]; - auto &distinct_data = *grouping.distinct_data; - auto &distinct_state = *gstate.grouping_states[i].distinct_state; - - for (idx_t table_idx = 0; table_idx < distinct_data.radix_tables.size(); table_idx++) { - if (!distinct_data.radix_tables[table_idx]) { - continue; - } - auto &radix_table = distinct_data.radix_tables[table_idx]; - auto &radix_state = *distinct_state.radix_states[table_idx]; - radix_table->Finalize(context, 
radix_state); - } - } - auto new_event = make_shared(context, pipeline, *this, gstate); - event.InsertEvent(std::move(new_event)); - return SinkFinalizeType::READY; -} - -SinkFinalizeType PhysicalHashAggregate::FinalizeInternal(Pipeline &pipeline, Event &event, ClientContext &context, - GlobalSinkState &gstate_p, bool check_distinct) const { - auto &gstate = gstate_p.Cast(); - - if (check_distinct && distinct_collection_info) { - // There are distinct aggregates - // If these are partitioned those need to be combined first - // Then we Finalize again, skipping this step - return FinalizeDistinct(pipeline, event, context, gstate_p); - } - - for (idx_t i = 0; i < groupings.size(); i++) { - auto &grouping = groupings[i]; - auto &grouping_gstate = gstate.grouping_states[i]; - grouping.table_data.Finalize(context, *grouping_gstate.table_state); - } - return SinkFinalizeType::READY; -} - -SinkFinalizeType PhysicalHashAggregate::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - return FinalizeInternal(pipeline, event, context, input.global_state, true); -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class HashAggregateGlobalSourceState : public GlobalSourceState { -public: - HashAggregateGlobalSourceState(ClientContext &context, const PhysicalHashAggregate &op) : op(op), state_index(0) { - for (auto &grouping : op.groupings) { - auto &rt = grouping.table_data; - radix_states.push_back(rt.GetGlobalSourceState(context)); - } - } - - const PhysicalHashAggregate &op; - mutex lock; - atomic state_index; - - vector> radix_states; - -public: - idx_t MaxThreads() override { - // If there are no tables, we only need one thread. 
- if (op.groupings.empty()) { - return 1; - } - - auto &ht_state = op.sink_state->Cast(); - idx_t partitions = 0; - for (size_t sidx = 0; sidx < op.groupings.size(); ++sidx) { - auto &grouping = op.groupings[sidx]; - auto &grouping_gstate = ht_state.grouping_states[sidx]; - partitions += grouping.table_data.NumberOfPartitions(*grouping_gstate.table_state); - } - return MaxValue(1, partitions); - } -}; - -unique_ptr PhysicalHashAggregate::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(context, *this); -} - -class HashAggregateLocalSourceState : public LocalSourceState { -public: - explicit HashAggregateLocalSourceState(ExecutionContext &context, const PhysicalHashAggregate &op) { - for (auto &grouping : op.groupings) { - auto &rt = grouping.table_data; - radix_states.push_back(rt.GetLocalSourceState(context)); - } - } - - vector> radix_states; -}; - -unique_ptr PhysicalHashAggregate::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(context, *this); -} - -SourceResultType PhysicalHashAggregate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &sink_gstate = sink_state->Cast(); - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - while (true) { - idx_t radix_idx = gstate.state_index; - if (radix_idx >= groupings.size()) { - break; - } - auto &grouping = groupings[radix_idx]; - auto &radix_table = grouping.table_data; - auto &grouping_gstate = sink_gstate.grouping_states[radix_idx]; - - InterruptState interrupt_state; - OperatorSourceInput source_input {*gstate.radix_states[radix_idx], *lstate.radix_states[radix_idx], - interrupt_state}; - auto res = radix_table.GetData(context, chunk, *grouping_gstate.table_state, source_input); - if (chunk.size() != 0) { - return SourceResultType::HAVE_MORE_OUTPUT; - } else if (res == SourceResultType::BLOCKED) { - throw InternalException("Unexpectedly Blocked from 
radix_table"); - } - - // move to the next table - lock_guard l(gstate.lock); - radix_idx++; - if (radix_idx > gstate.state_index) { - // we have not yet worked on the table - // move the global index forwards - gstate.state_index = radix_idx; - } - } - - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -string PhysicalHashAggregate::ParamsToString() const { - string result; - auto &groups = grouped_aggregate_data.groups; - auto &aggregates = grouped_aggregate_data.aggregates; - for (idx_t i = 0; i < groups.size(); i++) { - if (i > 0) { - result += "\n"; - } - result += groups[i]->GetName(); - } - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggregate = aggregates[i]->Cast(); - if (i > 0 || !groups.empty()) { - result += "\n"; - } - result += aggregates[i]->GetName(); - if (aggregate.filter) { - result += " Filter: " + aggregate.filter->GetName(); - } - } - return result; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -PhysicalPerfectHashAggregate::PhysicalPerfectHashAggregate(ClientContext &context, vector types_p, - vector> aggregates_p, - vector> groups_p, - const vector> &group_stats, - vector required_bits_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::PERFECT_HASH_GROUP_BY, std::move(types_p), estimated_cardinality), - groups(std::move(groups_p)), aggregates(std::move(aggregates_p)), required_bits(std::move(required_bits_p)) { - D_ASSERT(groups.size() == group_stats.size()); - group_minima.reserve(group_stats.size()); - for (auto &stats : group_stats) { - D_ASSERT(stats); - auto &nstats = *stats; - D_ASSERT(NumericStats::HasMin(nstats)); - group_minima.push_back(NumericStats::Min(nstats)); - } - for (auto &expr : groups) { - group_types.push_back(expr->return_type); - } - - vector bindings; - vector payload_types_filters; - for (auto &expr : aggregates) { - D_ASSERT(expr->expression_class == ExpressionClass::BOUND_AGGREGATE); - 
D_ASSERT(expr->IsAggregate()); - auto &aggr = expr->Cast(); - bindings.push_back(&aggr); - - D_ASSERT(!aggr.IsDistinct()); - D_ASSERT(aggr.function.combine); - for (auto &child : aggr.children) { - payload_types.push_back(child->return_type); - } - if (aggr.filter) { - payload_types_filters.push_back(aggr.filter->return_type); - } - } - for (const auto &pay_filters : payload_types_filters) { - payload_types.push_back(pay_filters); - } - aggregate_objects = AggregateObject::CreateAggregateObjects(bindings); - - // filter_indexes must be pre-built, not lazily instantiated in parallel... - idx_t aggregate_input_idx = 0; - for (auto &aggregate : aggregates) { - auto &aggr = aggregate->Cast(); - aggregate_input_idx += aggr.children.size(); - } - for (auto &aggregate : aggregates) { - auto &aggr = aggregate->Cast(); - if (aggr.filter) { - auto &bound_ref_expr = aggr.filter->Cast(); - auto it = filter_indexes.find(aggr.filter.get()); - if (it == filter_indexes.end()) { - filter_indexes[aggr.filter.get()] = bound_ref_expr.index; - bound_ref_expr.index = aggregate_input_idx++; - } else { - ++aggregate_input_idx; - } - } - } -} - -unique_ptr PhysicalPerfectHashAggregate::CreateHT(Allocator &allocator, - ClientContext &context) const { - return make_uniq(context, allocator, group_types, payload_types, aggregate_objects, - group_minima, required_bits); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class PerfectHashAggregateGlobalState : public GlobalSinkState { -public: - PerfectHashAggregateGlobalState(const PhysicalPerfectHashAggregate &op, ClientContext &context) - : ht(op.CreateHT(Allocator::Get(context), context)) { - } - - //! The lock for updating the global aggregate state - mutex lock; - //! 
The global aggregate hash table - unique_ptr ht; -}; - -class PerfectHashAggregateLocalState : public LocalSinkState { -public: - PerfectHashAggregateLocalState(const PhysicalPerfectHashAggregate &op, ExecutionContext &context) - : ht(op.CreateHT(Allocator::Get(context.client), context.client)) { - group_chunk.InitializeEmpty(op.group_types); - if (!op.payload_types.empty()) { - aggregate_input_chunk.InitializeEmpty(op.payload_types); - } - } - - //! The local aggregate hash table - unique_ptr ht; - DataChunk group_chunk; - DataChunk aggregate_input_chunk; -}; - -unique_ptr PhysicalPerfectHashAggregate::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(*this, context); -} - -unique_ptr PhysicalPerfectHashAggregate::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(*this, context); -} - -SinkResultType PhysicalPerfectHashAggregate::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - DataChunk &group_chunk = lstate.group_chunk; - DataChunk &aggregate_input_chunk = lstate.aggregate_input_chunk; - - for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { - auto &group = groups[group_idx]; - D_ASSERT(group->type == ExpressionType::BOUND_REF); - auto &bound_ref_expr = group->Cast(); - group_chunk.data[group_idx].Reference(chunk.data[bound_ref_expr.index]); - } - idx_t aggregate_input_idx = 0; - for (auto &aggregate : aggregates) { - auto &aggr = aggregate->Cast(); - for (auto &child_expr : aggr.children) { - D_ASSERT(child_expr->type == ExpressionType::BOUND_REF); - auto &bound_ref_expr = child_expr->Cast(); - aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[bound_ref_expr.index]); - } - } - for (auto &aggregate : aggregates) { - auto &aggr = aggregate->Cast(); - if (aggr.filter) { - auto it = filter_indexes.find(aggr.filter.get()); - D_ASSERT(it != filter_indexes.end()); - 
aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[it->second]); - } - } - - group_chunk.SetCardinality(chunk.size()); - - aggregate_input_chunk.SetCardinality(chunk.size()); - - group_chunk.Verify(); - aggregate_input_chunk.Verify(); - D_ASSERT(aggregate_input_chunk.ColumnCount() == 0 || group_chunk.size() == aggregate_input_chunk.size()); - - lstate.ht->AddChunk(group_chunk, aggregate_input_chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -SinkCombineResultType PhysicalPerfectHashAggregate::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast(); - auto &gstate = input.global_state.Cast(); - - lock_guard l(gstate.lock); - gstate.ht->Combine(*lstate.ht); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class PerfectHashAggregateState : public GlobalSourceState { -public: - PerfectHashAggregateState() : ht_scan_position(0) { - } - - //! 
The current position to scan the HT for output tuples - idx_t ht_scan_position; -}; - -unique_ptr PhysicalPerfectHashAggregate::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(); -} - -SourceResultType PhysicalPerfectHashAggregate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &state = input.global_state.Cast(); - auto &gstate = sink_state->Cast(); - - gstate.ht->Scan(state.ht_scan_position, chunk); - - if (chunk.size() > 0) { - return SourceResultType::HAVE_MORE_OUTPUT; - } else { - return SourceResultType::FINISHED; - } -} - -string PhysicalPerfectHashAggregate::ParamsToString() const { - string result; - for (idx_t i = 0; i < groups.size(); i++) { - if (i > 0) { - result += "\n"; - } - result += groups[i]->GetName(); - } - for (idx_t i = 0; i < aggregates.size(); i++) { - if (i > 0 || !groups.empty()) { - result += "\n"; - } - result += aggregates[i]->GetName(); - auto &aggregate = aggregates[i]->Cast(); - if (aggregate.filter) { - result += " Filter: " + aggregate.filter->GetName(); - } - } - return result; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -PhysicalStreamingWindow::PhysicalStreamingWindow(vector types, vector> select_list, - idx_t estimated_cardinality, PhysicalOperatorType type) - : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list)) { -} - -class StreamingWindowGlobalState : public GlobalOperatorState { -public: - StreamingWindowGlobalState() : row_number(1) { - } - - //! The next row number. 
- std::atomic row_number; -}; - -class StreamingWindowState : public OperatorState { -public: - using StateBuffer = vector; - - StreamingWindowState() - : initialized(false), allocator(Allocator::DefaultAllocator()), - statev(LogicalType::POINTER, data_ptr_cast(&state_ptr)) { - } - - ~StreamingWindowState() override { - for (size_t i = 0; i < aggregate_dtors.size(); ++i) { - auto dtor = aggregate_dtors[i]; - if (dtor) { - AggregateInputData aggr_input_data(aggregate_bind_data[i], allocator); - state_ptr = aggregate_states[i].data(); - dtor(statev, aggr_input_data, 1); - } - } - } - - void Initialize(ClientContext &context, DataChunk &input, const vector> &expressions) { - const_vectors.resize(expressions.size()); - aggregate_states.resize(expressions.size()); - aggregate_bind_data.resize(expressions.size(), nullptr); - aggregate_dtors.resize(expressions.size(), nullptr); - - for (idx_t expr_idx = 0; expr_idx < expressions.size(); expr_idx++) { - auto &expr = *expressions[expr_idx]; - auto &wexpr = expr.Cast(); - switch (expr.GetExpressionType()) { - case ExpressionType::WINDOW_AGGREGATE: { - auto &aggregate = *wexpr.aggregate; - auto &state = aggregate_states[expr_idx]; - aggregate_bind_data[expr_idx] = wexpr.bind_info.get(); - aggregate_dtors[expr_idx] = aggregate.destructor; - state.resize(aggregate.state_size()); - aggregate.initialize(state.data()); - break; - } - case ExpressionType::WINDOW_FIRST_VALUE: { - // Just execute the expression once - ExpressionExecutor executor(context); - executor.AddExpression(*wexpr.children[0]); - DataChunk result; - result.Initialize(Allocator::Get(context), {wexpr.children[0]->return_type}); - executor.Execute(input, result); - - const_vectors[expr_idx] = make_uniq(result.GetValue(0, 0)); - break; - } - case ExpressionType::WINDOW_PERCENT_RANK: { - const_vectors[expr_idx] = make_uniq(Value((double)0)); - break; - } - case ExpressionType::WINDOW_RANK: - case ExpressionType::WINDOW_RANK_DENSE: { - const_vectors[expr_idx] = 
make_uniq(Value((int64_t)1)); - break; - } - default: - break; - } - } - initialized = true; - } - -public: - bool initialized; - vector> const_vectors; - ArenaAllocator allocator; - - // Aggregation - vector aggregate_states; - vector aggregate_bind_data; - vector aggregate_dtors; - data_ptr_t state_ptr; - Vector statev; -}; - -unique_ptr PhysicalStreamingWindow::GetGlobalOperatorState(ClientContext &context) const { - return make_uniq(); -} - -unique_ptr PhysicalStreamingWindow::GetOperatorState(ExecutionContext &context) const { - return make_uniq(); -} - -OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate_p, OperatorState &state_p) const { - auto &gstate = gstate_p.Cast(); - auto &state = state_p.Cast(); - state.allocator.Reset(); - - if (!state.initialized) { - state.Initialize(context.client, input, select_list); - } - // Put payload columns in place - for (idx_t col_idx = 0; col_idx < input.data.size(); col_idx++) { - chunk.data[col_idx].Reference(input.data[col_idx]); - } - // Compute window function - const idx_t count = input.size(); - for (idx_t expr_idx = 0; expr_idx < select_list.size(); expr_idx++) { - idx_t col_idx = input.data.size() + expr_idx; - auto &expr = *select_list[expr_idx]; - auto &result = chunk.data[col_idx]; - switch (expr.GetExpressionType()) { - case ExpressionType::WINDOW_AGGREGATE: { - // Establish the aggregation environment - auto &wexpr = expr.Cast(); - auto &aggregate = *wexpr.aggregate; - auto &statev = state.statev; - state.state_ptr = state.aggregate_states[expr_idx].data(); - AggregateInputData aggr_input_data(wexpr.bind_info.get(), state.allocator); - - // Check for COUNT(*) - if (wexpr.children.empty()) { - D_ASSERT(GetTypeIdSize(result.GetType().InternalType()) == sizeof(int64_t)); - auto data = FlatVector::GetData(result); - int64_t start_row = gstate.row_number; - for (idx_t i = 0; i < input.size(); ++i) { - data[i] = start_row 
+ i; - } - break; - } - - // Compute the arguments - auto &allocator = Allocator::Get(context.client); - ExpressionExecutor executor(context.client); - vector payload_types; - for (auto &child : wexpr.children) { - payload_types.push_back(child->return_type); - executor.AddExpression(*child); - } - - DataChunk payload; - payload.Initialize(allocator, payload_types); - executor.Execute(input, payload); - - // Iterate through them using a single SV - payload.Flatten(); - DataChunk row; - row.Initialize(allocator, payload_types); - sel_t s = 0; - SelectionVector sel(&s); - row.Slice(sel, 1); - for (size_t col_idx = 0; col_idx < payload.ColumnCount(); ++col_idx) { - DictionaryVector::Child(row.data[col_idx]).Reference(payload.data[col_idx]); - } - - // Update the state and finalize it one row at a time. - for (idx_t i = 0; i < input.size(); ++i) { - sel.set_index(0, i); - aggregate.update(row.data.data(), aggr_input_data, row.ColumnCount(), statev, 1); - aggregate.finalize(statev, aggr_input_data, result, 1, i); - } - break; - } - case ExpressionType::WINDOW_FIRST_VALUE: - case ExpressionType::WINDOW_PERCENT_RANK: - case ExpressionType::WINDOW_RANK: - case ExpressionType::WINDOW_RANK_DENSE: { - // Reference constant vector - chunk.data[col_idx].Reference(*state.const_vectors[expr_idx]); - break; - } - case ExpressionType::WINDOW_ROW_NUMBER: { - // Set row numbers - int64_t start_row = gstate.row_number; - auto rdata = FlatVector::GetData(chunk.data[col_idx]); - for (idx_t i = 0; i < count; i++) { - rdata[i] = start_row + i; - } - break; - } - default: - throw NotImplementedException("%s for StreamingWindow", ExpressionTypeToString(expr.GetExpressionType())); - } - } - gstate.row_number += count; - chunk.SetCardinality(count); - return OperatorResultType::NEED_MORE_INPUT; -} - -string PhysicalStreamingWindow::ParamsToString() const { - string result; - for (idx_t i = 0; i < select_list.size(); i++) { - if (i > 0) { - result += "\n"; - } - result += 
select_list[i]->GetName(); - } - return result; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -PhysicalUngroupedAggregate::PhysicalUngroupedAggregate(vector types, - vector> expressions, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::UNGROUPED_AGGREGATE, std::move(types), estimated_cardinality), - aggregates(std::move(expressions)) { - - distinct_collection_info = DistinctAggregateCollectionInfo::Create(aggregates); - if (!distinct_collection_info) { - return; - } - distinct_data = make_uniq(*distinct_collection_info); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -struct AggregateState { - explicit AggregateState(const vector> &aggregate_expressions) { - counts = make_uniq_array>(aggregate_expressions.size()); - for (idx_t i = 0; i < aggregate_expressions.size(); i++) { - auto &aggregate = aggregate_expressions[i]; - D_ASSERT(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); - auto &aggr = aggregate->Cast(); - auto state = make_unsafe_uniq_array(aggr.function.state_size()); - aggr.function.initialize(state.get()); - aggregates.push_back(std::move(state)); - bind_data.push_back(aggr.bind_info.get()); - destructors.push_back(aggr.function.destructor); -#ifdef DEBUG - counts[i] = 0; -#endif - } - } - ~AggregateState() { - D_ASSERT(destructors.size() == aggregates.size()); - for (idx_t i = 0; i < destructors.size(); i++) { - if (!destructors[i]) { - continue; - } - Vector state_vector(Value::POINTER(CastPointerToValue(aggregates[i].get()))); - state_vector.SetVectorType(VectorType::FLAT_VECTOR); - - ArenaAllocator allocator(Allocator::DefaultAllocator()); - AggregateInputData aggr_input_data(bind_data[i], allocator); - destructors[i](state_vector, aggr_input_data, 1); - } - } - - void Move(AggregateState &other) { - other.aggregates = 
std::move(aggregates); - other.destructors = std::move(destructors); - } - - //! The aggregate values - vector> aggregates; - //! The bind data - vector bind_data; - //! The destructors - vector destructors; - //! Counts (used for verification) - unique_array> counts; -}; - -class UngroupedAggregateGlobalSinkState : public GlobalSinkState { -public: - UngroupedAggregateGlobalSinkState(const PhysicalUngroupedAggregate &op, ClientContext &client) - : state(op.aggregates), finished(false), allocator(BufferAllocator::Get(client)) { - if (op.distinct_data) { - distinct_state = make_uniq(*op.distinct_data, client); - } - } - - //! The lock for updating the global aggregate state - mutex lock; - //! The global aggregate state - AggregateState state; - //! Whether or not the aggregate is finished - bool finished; - //! The data related to the distinct aggregates (if there are any) - unique_ptr distinct_state; - //! Global arena allocator - ArenaAllocator allocator; -}; - -class UngroupedAggregateLocalSinkState : public LocalSinkState { -public: - UngroupedAggregateLocalSinkState(const PhysicalUngroupedAggregate &op, const vector &child_types, - GlobalSinkState &gstate_p, ExecutionContext &context) - : allocator(BufferAllocator::Get(context.client)), state(op.aggregates), child_executor(context.client), - aggregate_input_chunk(), filter_set() { - auto &gstate = gstate_p.Cast(); - - auto &allocator = BufferAllocator::Get(context.client); - InitializeDistinctAggregates(op, gstate, context); - - vector payload_types; - vector aggregate_objects; - for (auto &aggregate : op.aggregates) { - D_ASSERT(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); - auto &aggr = aggregate->Cast(); - // initialize the payload chunk - for (auto &child : aggr.children) { - payload_types.push_back(child->return_type); - child_executor.AddExpression(*child); - } - aggregate_objects.emplace_back(&aggr); - } - if (!payload_types.empty()) { // for select count(*) from t; there is no 
payload at all - aggregate_input_chunk.Initialize(allocator, payload_types); - } - filter_set.Initialize(context.client, aggregate_objects, child_types); - } - - //! Local arena allocator - ArenaAllocator allocator; - //! The local aggregate state - AggregateState state; - //! The executor - ExpressionExecutor child_executor; - //! The payload chunk, containing all the Vectors for the aggregates - DataChunk aggregate_input_chunk; - //! Aggregate filter data set - AggregateFilterDataSet filter_set; - //! The local sink states of the distinct aggregates hash tables - vector> radix_states; - -public: - void Reset() { - aggregate_input_chunk.Reset(); - } - void InitializeDistinctAggregates(const PhysicalUngroupedAggregate &op, - const UngroupedAggregateGlobalSinkState &gstate, ExecutionContext &context) { - - if (!op.distinct_data) { - return; - } - auto &data = *op.distinct_data; - auto &state = *gstate.distinct_state; - D_ASSERT(!data.radix_tables.empty()); - - const idx_t aggregate_count = state.radix_states.size(); - radix_states.resize(aggregate_count); - - auto &distinct_info = *op.distinct_collection_info; - - for (auto &idx : distinct_info.indices) { - idx_t table_idx = distinct_info.table_map[idx]; - if (data.radix_tables[table_idx] == nullptr) { - // This aggregate has identical input as another aggregate, so no table is created for it - continue; - } - auto &radix_table = *data.radix_tables[table_idx]; - radix_states[table_idx] = radix_table.GetLocalSinkState(context); - } - } -}; - -bool PhysicalUngroupedAggregate::SinkOrderDependent() const { - for (auto &expr : aggregates) { - auto &aggr = expr->Cast(); - if (aggr.function.order_dependent == AggregateOrderDependent::ORDER_DEPENDENT) { - return true; - } - } - return false; -} - -unique_ptr PhysicalUngroupedAggregate::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(*this, context); -} - -unique_ptr PhysicalUngroupedAggregate::GetLocalSinkState(ExecutionContext &context) const { - 
D_ASSERT(sink_state); - auto &gstate = *sink_state; - return make_uniq(*this, children[0]->GetTypes(), gstate, context); -} - -void PhysicalUngroupedAggregate::SinkDistinct(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &sink = input.local_state.Cast(); - auto &global_sink = input.global_state.Cast(); - D_ASSERT(distinct_data); - auto &distinct_state = *global_sink.distinct_state; - auto &distinct_info = *distinct_collection_info; - auto &distinct_indices = distinct_info.Indices(); - - DataChunk empty_chunk; - - auto &distinct_filter = distinct_info.Indices(); - - for (auto &idx : distinct_indices) { - auto &aggregate = aggregates[idx]->Cast(); - - idx_t table_idx = distinct_info.table_map[idx]; - if (!distinct_data->radix_tables[table_idx]) { - // This distinct aggregate shares its data with another - continue; - } - D_ASSERT(distinct_data->radix_tables[table_idx]); - auto &radix_table = *distinct_data->radix_tables[table_idx]; - auto &radix_global_sink = *distinct_state.radix_states[table_idx]; - auto &radix_local_sink = *sink.radix_states[table_idx]; - OperatorSinkInput sink_input {radix_global_sink, radix_local_sink, input.interrupt_state}; - - if (aggregate.filter) { - // The hashtable can apply a filter, but only on the payload - // And in our case, we need to filter the groups (the distinct aggr children) - - // Apply the filter before inserting into the hashtable - auto &filtered_data = sink.filter_set.GetFilterData(idx); - idx_t count = filtered_data.ApplyFilter(chunk); - filtered_data.filtered_payload.SetCardinality(count); - - radix_table.Sink(context, filtered_data.filtered_payload, sink_input, empty_chunk, distinct_filter); - } else { - radix_table.Sink(context, chunk, sink_input, empty_chunk, distinct_filter); - } - } -} - -SinkResultType PhysicalUngroupedAggregate::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &sink = input.local_state.Cast(); - - // perform the 
aggregation inside the local state - sink.Reset(); - - if (distinct_data) { - SinkDistinct(context, chunk, input); - } - - DataChunk &payload_chunk = sink.aggregate_input_chunk; - - idx_t payload_idx = 0; - idx_t next_payload_idx = 0; - - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - auto &aggregate = aggregates[aggr_idx]->Cast(); - - payload_idx = next_payload_idx; - next_payload_idx = payload_idx + aggregate.children.size(); - - if (aggregate.IsDistinct()) { - continue; - } - - idx_t payload_cnt = 0; - // resolve the filter (if any) - if (aggregate.filter) { - auto &filtered_data = sink.filter_set.GetFilterData(aggr_idx); - auto count = filtered_data.ApplyFilter(chunk); - - sink.child_executor.SetChunk(filtered_data.filtered_payload); - payload_chunk.SetCardinality(count); - } else { - sink.child_executor.SetChunk(chunk); - payload_chunk.SetCardinality(chunk); - } - -#ifdef DEBUG - sink.state.counts[aggr_idx] += payload_chunk.size(); -#endif - - // resolve the child expressions of the aggregate (if any) - for (idx_t i = 0; i < aggregate.children.size(); ++i) { - sink.child_executor.ExecuteExpression(payload_idx + payload_cnt, - payload_chunk.data[payload_idx + payload_cnt]); - payload_cnt++; - } - - auto start_of_input = payload_cnt == 0 ? 
nullptr : &payload_chunk.data[payload_idx]; - AggregateInputData aggr_input_data(aggregate.bind_info.get(), sink.allocator); - aggregate.function.simple_update(start_of_input, aggr_input_data, payload_cnt, - sink.state.aggregates[aggr_idx].get(), payload_chunk.size()); - } - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -void PhysicalUngroupedAggregate::CombineDistinct(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - if (!distinct_data) { - return; - } - auto &distinct_state = gstate.distinct_state; - auto table_count = distinct_data->radix_tables.size(); - for (idx_t table_idx = 0; table_idx < table_count; table_idx++) { - D_ASSERT(distinct_data->radix_tables[table_idx]); - auto &radix_table = *distinct_data->radix_tables[table_idx]; - auto &radix_global_sink = *distinct_state->radix_states[table_idx]; - auto &radix_local_sink = *lstate.radix_states[table_idx]; - - radix_table.Combine(context, radix_global_sink, radix_local_sink); - } -} - -SinkCombineResultType PhysicalUngroupedAggregate::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - D_ASSERT(!gstate.finished); - - // finalize: combine the local state into the global state - // all aggregates are combinable: we might be doing a parallel aggregate - // use the combine method to combine the partial aggregates - OperatorSinkCombineInput distinct_input {gstate, lstate, input.interrupt_state}; - CombineDistinct(context, distinct_input); - - lock_guard glock(gstate.lock); - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - auto &aggregate = aggregates[aggr_idx]->Cast(); - - if (aggregate.IsDistinct()) { 
- continue; - } - - Vector source_state(Value::POINTER(CastPointerToValue(lstate.state.aggregates[aggr_idx].get()))); - Vector dest_state(Value::POINTER(CastPointerToValue(gstate.state.aggregates[aggr_idx].get()))); - - AggregateInputData aggr_input_data(aggregate.bind_info.get(), gstate.allocator); - aggregate.function.combine(source_state, dest_state, aggr_input_data, 1); -#ifdef DEBUG - gstate.state.counts[aggr_idx] += lstate.state.counts[aggr_idx]; -#endif - } - lstate.allocator.Destroy(); - - auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this, lstate.child_executor, "child_executor", 0); - client_profiler.Flush(context.thread.profiler); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -class UngroupedDistinctAggregateFinalizeEvent : public BasePipelineEvent { -public: - UngroupedDistinctAggregateFinalizeEvent(ClientContext &context, const PhysicalUngroupedAggregate &op_p, - UngroupedAggregateGlobalSinkState &gstate_p, Pipeline &pipeline_p) - : BasePipelineEvent(pipeline_p), context(context), op(op_p), gstate(gstate_p), tasks_scheduled(0), - tasks_done(0) { - } - -public: - void Schedule() override; - -private: - ClientContext &context; - - const PhysicalUngroupedAggregate &op; - UngroupedAggregateGlobalSinkState &gstate; - -public: - mutex lock; - idx_t tasks_scheduled; - idx_t tasks_done; - - vector> global_source_states; -}; - -class UngroupedDistinctAggregateFinalizeTask : public ExecutorTask { -public: - UngroupedDistinctAggregateFinalizeTask(Executor &executor, shared_ptr event_p, - const PhysicalUngroupedAggregate &op, - UngroupedAggregateGlobalSinkState &state_p) - : ExecutorTask(executor), event(std::move(event_p)), op(op), gstate(state_p), - allocator(BufferAllocator::Get(executor.context)) { - } - - TaskExecutionResult 
ExecuteTask(TaskExecutionMode mode) override; - -private: - void AggregateDistinct(); - -private: - shared_ptr event; - - const PhysicalUngroupedAggregate &op; - UngroupedAggregateGlobalSinkState &gstate; - - ArenaAllocator allocator; -}; - -void UngroupedDistinctAggregateFinalizeEvent::Schedule() { - D_ASSERT(gstate.distinct_state); - auto &aggregates = op.aggregates; - auto &distinct_data = *op.distinct_data; - - idx_t payload_idx = 0; - idx_t next_payload_idx = 0; - for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) { - auto &aggregate = aggregates[agg_idx]->Cast(); - - // Forward the payload idx - payload_idx = next_payload_idx; - next_payload_idx = payload_idx + aggregate.children.size(); - - // If aggregate is not distinct, skip it - if (!distinct_data.IsDistinct(agg_idx)) { - global_source_states.push_back(nullptr); - continue; - } - D_ASSERT(distinct_data.info.table_map.count(agg_idx)); - - // Create global state for scanning - auto table_idx = distinct_data.info.table_map.at(agg_idx); - auto &radix_table_p = *distinct_data.radix_tables[table_idx]; - global_source_states.push_back(radix_table_p.GetGlobalSourceState(context)); - } - - const idx_t n_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - vector> tasks; - for (idx_t i = 0; i < n_threads; i++) { - tasks.push_back( - make_uniq(pipeline->executor, shared_from_this(), op, gstate)); - tasks_scheduled++; - } - SetTasks(std::move(tasks)); -} - -TaskExecutionResult UngroupedDistinctAggregateFinalizeTask::ExecuteTask(TaskExecutionMode mode) { - AggregateDistinct(); - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -void UngroupedDistinctAggregateFinalizeTask::AggregateDistinct() { - D_ASSERT(gstate.distinct_state); - auto &distinct_state = *gstate.distinct_state; - auto &distinct_data = *op.distinct_data; - - // Create thread-local copy of aggregate state - auto &aggregates = op.aggregates; - AggregateState state(aggregates); - - // Thread-local contexts 
- ThreadContext thread_context(executor.context); - ExecutionContext execution_context(executor.context, thread_context, nullptr); - - auto &finalize_event = event->Cast(); - - // Now loop through the distinct aggregates, scanning the distinct HTs - idx_t payload_idx = 0; - idx_t next_payload_idx = 0; - for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) { - auto &aggregate = aggregates[agg_idx]->Cast(); - - // Forward the payload idx - payload_idx = next_payload_idx; - next_payload_idx = payload_idx + aggregate.children.size(); - - // If aggregate is not distinct, skip it - if (!distinct_data.IsDistinct(agg_idx)) { - continue; - } - - const auto table_idx = distinct_data.info.table_map.at(agg_idx); - auto &radix_table = *distinct_data.radix_tables[table_idx]; - auto lstate = radix_table.GetLocalSourceState(execution_context); - - auto &sink = *distinct_state.radix_states[table_idx]; - InterruptState interrupt_state; - OperatorSourceInput source_input {*finalize_event.global_source_states[agg_idx], *lstate, interrupt_state}; - - DataChunk output_chunk; - output_chunk.Initialize(executor.context, distinct_state.distinct_output_chunks[table_idx]->GetTypes()); - - DataChunk payload_chunk; - payload_chunk.InitializeEmpty(distinct_data.grouped_aggregate_data[table_idx]->group_types); - payload_chunk.SetCardinality(0); - - AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator); - while (true) { - output_chunk.Reset(); - - auto res = radix_table.GetData(execution_context, output_chunk, sink, source_input); - if (res == SourceResultType::FINISHED) { - D_ASSERT(output_chunk.size() == 0); - break; - } else if (res == SourceResultType::BLOCKED) { - throw InternalException( - "Unexpected interrupt from radix table GetData in UngroupedDistinctAggregateFinalizeTask"); - } - - // We dont need to resolve the filter, we already did this in Sink - idx_t payload_cnt = aggregate.children.size(); - for (idx_t i = 0; i < payload_cnt; i++) { - 
payload_chunk.data[i].Reference(output_chunk.data[i]); - } - payload_chunk.SetCardinality(output_chunk); - -#ifdef DEBUG - gstate.state.counts[agg_idx] += payload_chunk.size(); -#endif - - // Update the aggregate state - auto start_of_input = payload_cnt ? &payload_chunk.data[0] : nullptr; - aggregate.function.simple_update(start_of_input, aggr_input_data, payload_cnt, - state.aggregates[agg_idx].get(), payload_chunk.size()); - } - } - - // After scanning the distinct HTs, we can combine the thread-local agg states with the thread-global - lock_guard guard(finalize_event.lock); - payload_idx = 0; - next_payload_idx = 0; - for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) { - if (!distinct_data.IsDistinct(agg_idx)) { - continue; - } - - auto &aggregate = aggregates[agg_idx]->Cast(); - AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator); - - Vector state_vec(Value::POINTER(CastPointerToValue(state.aggregates[agg_idx].get()))); - Vector combined_vec(Value::POINTER(CastPointerToValue(gstate.state.aggregates[agg_idx].get()))); - aggregate.function.combine(state_vec, combined_vec, aggr_input_data, 1); - } - - D_ASSERT(!gstate.finished); - if (++finalize_event.tasks_done == finalize_event.tasks_scheduled) { - gstate.finished = true; - } -} - -SinkFinalizeType PhysicalUngroupedAggregate::FinalizeDistinct(Pipeline &pipeline, Event &event, ClientContext &context, - GlobalSinkState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - D_ASSERT(distinct_data); - auto &distinct_state = *gstate.distinct_state; - - for (idx_t table_idx = 0; table_idx < distinct_data->radix_tables.size(); table_idx++) { - auto &radix_table_p = distinct_data->radix_tables[table_idx]; - auto &radix_state = *distinct_state.radix_states[table_idx]; - radix_table_p->Finalize(context, radix_state); - } - auto new_event = make_shared(context, *this, gstate, pipeline); - event.InsertEvent(std::move(new_event)); - return SinkFinalizeType::READY; -} - -SinkFinalizeType 
PhysicalUngroupedAggregate::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - - if (distinct_data) { - return FinalizeDistinct(pipeline, event, context, input.global_state); - } - - D_ASSERT(!gstate.finished); - gstate.finished = true; - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -void VerifyNullHandling(DataChunk &chunk, AggregateState &state, const vector> &aggregates) { -#ifdef DEBUG - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - auto &aggr = aggregates[aggr_idx]->Cast(); - if (state.counts[aggr_idx] == 0 && aggr.function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) { - // Default is when 0 values go in, NULL comes out - UnifiedVectorFormat vdata; - chunk.data[aggr_idx].ToUnifiedFormat(1, vdata); - D_ASSERT(!vdata.validity.RowIsValid(vdata.sel->get_index(0))); - } - } -#endif -} - -SourceResultType PhysicalUngroupedAggregate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gstate = sink_state->Cast(); - D_ASSERT(gstate.finished); - - // initialize the result chunk with the aggregate values - chunk.SetCardinality(1); - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - auto &aggregate = aggregates[aggr_idx]->Cast(); - - Vector state_vector(Value::POINTER(CastPointerToValue(gstate.state.aggregates[aggr_idx].get()))); - AggregateInputData aggr_input_data(aggregate.bind_info.get(), gstate.allocator); - aggregate.function.finalize(state_vector, aggr_input_data, chunk.data[aggr_idx], 1, 0); - } - VerifyNullHandling(chunk, gstate.state, aggregates); - - return SourceResultType::FINISHED; -} - -string PhysicalUngroupedAggregate::ParamsToString() const { - string result; - for (idx_t i = 0; i 
< aggregates.size(); i++) { - auto &aggregate = aggregates[i]->Cast(); - if (i > 0) { - result += "\n"; - } - result += aggregates[i]->GetName(); - if (aggregate.filter) { - result += " Filter: " + aggregate.filter->GetName(); - } - } - return result; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - - -#include -#include -#include - -namespace duckdb { - -// Global sink state -class WindowGlobalSinkState : public GlobalSinkState { -public: - WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &context) - : op(op), mode(DBConfig::GetConfig(context).options.window_mode) { - - D_ASSERT(op.select_list[0]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &wexpr = op.select_list[0]->Cast(); - - global_partition = - make_uniq(context, wexpr.partitions, wexpr.orders, op.children[0]->types, - wexpr.partitions_stats, op.estimated_cardinality); - } - - const PhysicalWindow &op; - unique_ptr global_partition; - WindowAggregationMode mode; -}; - -// Per-thread sink state -class WindowLocalSinkState : public LocalSinkState { -public: - WindowLocalSinkState(ClientContext &context, const WindowGlobalSinkState &gstate) - : local_partition(context, *gstate.global_partition) { - } - - void Sink(DataChunk &input_chunk) { - local_partition.Sink(input_chunk); - } - - void Combine() { - local_partition.Combine(); - } - - PartitionLocalSinkState local_partition; -}; - -// this implements a sorted window functions variant -PhysicalWindow::PhysicalWindow(vector types, vector> select_list_p, - idx_t estimated_cardinality, PhysicalOperatorType type) - : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list_p)) { - is_order_dependent = false; - for (auto &expr : select_list) { - D_ASSERT(expr->expression_class == ExpressionClass::BOUND_WINDOW); - auto &bound_window = expr->Cast(); - if (bound_window.partitions.empty() && bound_window.orders.empty()) { - is_order_dependent = true; - } - } -} - -static 
unique_ptr WindowExecutorFactory(BoundWindowExpression &wexpr, ClientContext &context, - const ValidityMask &partition_mask, - const ValidityMask &order_mask, const idx_t payload_count, - WindowAggregationMode mode) { - switch (wexpr.type) { - case ExpressionType::WINDOW_AGGREGATE: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask, mode); - case ExpressionType::WINDOW_ROW_NUMBER: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); - case ExpressionType::WINDOW_RANK_DENSE: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); - case ExpressionType::WINDOW_RANK: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); - case ExpressionType::WINDOW_PERCENT_RANK: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); - case ExpressionType::WINDOW_CUME_DIST: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); - case ExpressionType::WINDOW_NTILE: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); - case ExpressionType::WINDOW_LEAD: - case ExpressionType::WINDOW_LAG: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); - case ExpressionType::WINDOW_FIRST_VALUE: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); - case ExpressionType::WINDOW_LAST_VALUE: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); - case ExpressionType::WINDOW_NTH_VALUE: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); - break; - default: - throw InternalException("Window aggregate type %s", ExpressionTypeToString(wexpr.type)); - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -SinkResultType PhysicalWindow::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const 
{ - auto &lstate = input.local_state.Cast(); - - lstate.Sink(chunk); - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalWindow::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast(); - lstate.Combine(); - - return SinkCombineResultType::FINISHED; -} - -unique_ptr PhysicalWindow::GetLocalSinkState(ExecutionContext &context) const { - auto &gstate = sink_state->Cast(); - return make_uniq(context.client, gstate); -} - -unique_ptr PhysicalWindow::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(*this, context); -} - -SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &state = input.global_state.Cast(); - - // Did we get any data? - if (!state.global_partition->count) { - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Do we have any sorting to schedule? - if (state.global_partition->rows) { - D_ASSERT(!state.global_partition->grouping_data); - return state.global_partition->rows->count ? SinkFinalizeType::READY : SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Find the first group to sort - if (!state.global_partition->HasMergeTasks()) { - // Empty input! 
- return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Schedule all the sorts for maximum thread utilisation - auto new_event = make_shared(*state.global_partition, pipeline); - event.InsertEvent(std::move(new_event)); - - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class WindowPartitionSourceState; - -class WindowGlobalSourceState : public GlobalSourceState { -public: - using HashGroupSourcePtr = unique_ptr; - using ScannerPtr = unique_ptr; - using Task = std::pair; - - WindowGlobalSourceState(ClientContext &context_p, WindowGlobalSinkState &gsink_p); - - //! Get the next task - Task NextTask(idx_t hash_bin); - - //! Context for executing computations - ClientContext &context; - //! All the sunk data - WindowGlobalSinkState &gsink; - //! The next group to build. - atomic next_build; - //! The built groups - vector built; - //! Serialise access to the built hash groups - mutable mutex built_lock; - //! 
The number of unfinished tasks - atomic tasks_remaining; - -public: - idx_t MaxThreads() override { - return tasks_remaining; - } - -private: - Task CreateTask(idx_t hash_bin); - Task StealWork(); -}; - -WindowGlobalSourceState::WindowGlobalSourceState(ClientContext &context_p, WindowGlobalSinkState &gsink_p) - : context(context_p), gsink(gsink_p), next_build(0), tasks_remaining(0) { - auto &hash_groups = gsink.global_partition->hash_groups; - - auto &gpart = gsink.global_partition; - if (hash_groups.empty()) { - // OVER() - built.resize(1); - if (gpart->rows) { - tasks_remaining += gpart->rows->blocks.size(); - } - } else { - built.resize(hash_groups.size()); - idx_t batch_base = 0; - for (auto &hash_group : hash_groups) { - if (!hash_group) { - continue; - } - auto &global_sort_state = *hash_group->global_sort; - if (global_sort_state.sorted_blocks.empty()) { - continue; - } - - D_ASSERT(global_sort_state.sorted_blocks.size() == 1); - auto &sb = *global_sort_state.sorted_blocks[0]; - auto &sd = *sb.payload_data; - tasks_remaining += sd.data_blocks.size(); - - hash_group->batch_base = batch_base; - batch_base += sd.data_blocks.size(); - } - } -} - -// Per-bin evaluation state (build and evaluate) -class WindowPartitionSourceState { -public: - using HashGroupPtr = unique_ptr; - using ExecutorPtr = unique_ptr; - using Executors = vector; - - WindowPartitionSourceState(ClientContext &context, WindowGlobalSourceState &gsource) - : context(context), op(gsource.gsink.op), gsource(gsource), read_block_idx(0), unscanned(0) { - layout.Initialize(gsource.gsink.global_partition->payload_types); - } - - unique_ptr GetScanner() const; - void MaterializeSortedData(); - void BuildPartition(WindowGlobalSinkState &gstate, const idx_t hash_bin); - - ClientContext &context; - const PhysicalWindow &op; - WindowGlobalSourceState &gsource; - - HashGroupPtr hash_group; - //! The generated input chunks - unique_ptr rows; - unique_ptr heap; - RowLayout layout; - //! 
The partition boundary mask - vector partition_bits; - ValidityMask partition_mask; - //! The order boundary mask - vector order_bits; - ValidityMask order_mask; - //! External paging - bool external; - //! The current execution functions - Executors executors; - - //! The bin number - idx_t hash_bin; - - //! The next block to read. - mutable atomic read_block_idx; - //! The number of remaining unscanned blocks. - atomic unscanned; -}; - -void WindowPartitionSourceState::MaterializeSortedData() { - auto &global_sort_state = *hash_group->global_sort; - if (global_sort_state.sorted_blocks.empty()) { - return; - } - - // scan the sorted row data - D_ASSERT(global_sort_state.sorted_blocks.size() == 1); - auto &sb = *global_sort_state.sorted_blocks[0]; - - // Free up some memory before allocating more - sb.radix_sorting_data.clear(); - sb.blob_sorting_data = nullptr; - - // Move the sorting row blocks into our RDCs - auto &buffer_manager = global_sort_state.buffer_manager; - auto &sd = *sb.payload_data; - - // Data blocks are required - D_ASSERT(!sd.data_blocks.empty()); - auto &block = sd.data_blocks[0]; - rows = make_uniq(buffer_manager, block->capacity, block->entry_size); - rows->blocks = std::move(sd.data_blocks); - rows->count = std::accumulate(rows->blocks.begin(), rows->blocks.end(), idx_t(0), - [&](idx_t c, const unique_ptr &b) { return c + b->count; }); - - // Heap blocks are optional, but we want both for iteration. 
- if (!sd.heap_blocks.empty()) { - auto &block = sd.heap_blocks[0]; - heap = make_uniq(buffer_manager, block->capacity, block->entry_size); - heap->blocks = std::move(sd.heap_blocks); - hash_group.reset(); - } else { - heap = make_uniq(buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true); - } - heap->count = std::accumulate(heap->blocks.begin(), heap->blocks.end(), idx_t(0), - [&](idx_t c, const unique_ptr &b) { return c + b->count; }); -} - -unique_ptr WindowPartitionSourceState::GetScanner() const { - auto &gsink = *gsource.gsink.global_partition; - if ((gsink.rows && !hash_bin) || hash_bin < gsink.hash_groups.size()) { - const auto block_idx = read_block_idx++; - if (block_idx >= rows->blocks.size()) { - return nullptr; - } - // Second pass can flush - --gsource.tasks_remaining; - return make_uniq(*rows, *heap, layout, external, block_idx, true); - } - return nullptr; -} - -void WindowPartitionSourceState::BuildPartition(WindowGlobalSinkState &gstate, const idx_t hash_bin_p) { - // Get rid of any stale data - hash_bin = hash_bin_p; - - // There are three types of partitions: - // 1. No partition (no sorting) - // 2. One partition (sorting, but no hashing) - // 3. Multiple partitions (sorting and hashing) - - // How big is the partition? 
- auto &gpart = *gsource.gsink.global_partition; - idx_t count = 0; - if (hash_bin < gpart.hash_groups.size() && gpart.hash_groups[hash_bin]) { - count = gpart.hash_groups[hash_bin]->count; - } else if (gpart.rows && !hash_bin) { - count = gpart.count; - } else { - return; - } - - // Initialise masks to false - const auto bit_count = ValidityMask::ValidityMaskSize(count); - partition_bits.clear(); - partition_bits.resize(bit_count, 0); - partition_mask.Initialize(partition_bits.data()); - - order_bits.clear(); - order_bits.resize(bit_count, 0); - order_mask.Initialize(order_bits.data()); - - // Scan the sorted data into new Collections - external = gpart.external; - if (gpart.rows && !hash_bin) { - // Simple mask - partition_mask.SetValidUnsafe(0); - order_mask.SetValidUnsafe(0); - // No partition - align the heap blocks with the row blocks - rows = gpart.rows->CloneEmpty(gpart.rows->keep_pinned); - heap = gpart.strings->CloneEmpty(gpart.strings->keep_pinned); - RowDataCollectionScanner::AlignHeapBlocks(*rows, *heap, *gpart.rows, *gpart.strings, layout); - external = true; - } else if (hash_bin < gpart.hash_groups.size()) { - // Overwrite the collections with the sorted data - D_ASSERT(gpart.hash_groups[hash_bin].get()); - hash_group = std::move(gpart.hash_groups[hash_bin]); - hash_group->ComputeMasks(partition_mask, order_mask); - external = hash_group->global_sort->external; - MaterializeSortedData(); - } else { - return; - } - - // Create the executors for each function - executors.clear(); - for (idx_t expr_idx = 0; expr_idx < op.select_list.size(); ++expr_idx) { - D_ASSERT(op.select_list[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &wexpr = op.select_list[expr_idx]->Cast(); - auto wexec = WindowExecutorFactory(wexpr, context, partition_mask, order_mask, count, gstate.mode); - executors.emplace_back(std::move(wexec)); - } - - // First pass over the input without flushing - DataChunk input_chunk; - 
input_chunk.Initialize(gpart.allocator, gpart.payload_types); - auto scanner = make_uniq(*rows, *heap, layout, external, false); - idx_t input_idx = 0; - while (true) { - input_chunk.Reset(); - scanner->Scan(input_chunk); - if (input_chunk.size() == 0) { - break; - } - - // TODO: Parallelization opportunity - for (auto &wexec : executors) { - wexec->Sink(input_chunk, input_idx, scanner->Count()); - } - input_idx += input_chunk.size(); - } - - // TODO: Parallelization opportunity - for (auto &wexec : executors) { - wexec->Finalize(); - } - - // External scanning assumes all blocks are swizzled. - scanner->ReSwizzle(); - - // Start the block countdown - unscanned = rows->blocks.size(); -} - -// Per-thread scan state -class WindowLocalSourceState : public LocalSourceState { -public: - using ReadStatePtr = unique_ptr; - using ReadStates = vector; - - explicit WindowLocalSourceState(WindowGlobalSourceState &gsource); - void UpdateBatchIndex(); - bool NextPartition(); - void Scan(DataChunk &chunk); - - //! The shared source state - WindowGlobalSourceState &gsource; - //! The current bin being processed - idx_t hash_bin; - //! The current batch index (for output reordering) - idx_t batch_index; - //! The current source being processed - optional_ptr partition_source; - //! The read cursor - unique_ptr scanner; - //! Buffer for the inputs - DataChunk input_chunk; - //! Executor read states. - ReadStates read_states; - //! 
Buffer for window results - DataChunk output_chunk; -}; - -WindowLocalSourceState::WindowLocalSourceState(WindowGlobalSourceState &gsource) - : gsource(gsource), hash_bin(gsource.built.size()), batch_index(0) { - auto &gsink = *gsource.gsink.global_partition; - auto &op = gsource.gsink.op; - - input_chunk.Initialize(gsink.allocator, gsink.payload_types); - - vector output_types; - for (idx_t expr_idx = 0; expr_idx < op.select_list.size(); ++expr_idx) { - D_ASSERT(op.select_list[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &wexpr = op.select_list[expr_idx]->Cast(); - output_types.emplace_back(wexpr.return_type); - } - output_chunk.Initialize(Allocator::Get(gsource.context), output_types); -} - -WindowGlobalSourceState::Task WindowGlobalSourceState::CreateTask(idx_t hash_bin) { - // Build outside the lock so no one tries to steal before we are done. - auto partition_source = make_uniq(context, *this); - partition_source->BuildPartition(gsink, hash_bin); - Task result(partition_source.get(), partition_source->GetScanner()); - - // Is there any data to scan? - if (result.second) { - lock_guard built_guard(built_lock); - built[hash_bin] = std::move(partition_source); - - return result; - } - - return Task(); -} - -WindowGlobalSourceState::Task WindowGlobalSourceState::StealWork() { - for (idx_t hash_bin = 0; hash_bin < built.size(); ++hash_bin) { - lock_guard built_guard(built_lock); - auto &partition_source = built[hash_bin]; - if (!partition_source) { - continue; - } - - Task result(partition_source.get(), partition_source->GetScanner()); - - // Is there any data to scan? 
- if (result.second) { - return result; - } - } - - // Nothing to steal - return Task(); -} - -WindowGlobalSourceState::Task WindowGlobalSourceState::NextTask(idx_t hash_bin) { - auto &hash_groups = gsink.global_partition->hash_groups; - const auto bin_count = built.size(); - - // Flush unneeded data - if (hash_bin < bin_count) { - // Lock and delete when all blocks have been scanned - // We do this here instead of in NextScan so the WindowLocalSourceState - // has a chance to delete its state objects first, - // which may reference the partition_source - - // Delete data outside the lock in case it is slow - HashGroupSourcePtr killed; - lock_guard built_guard(built_lock); - auto &partition_source = built[hash_bin]; - if (partition_source && !partition_source->unscanned) { - killed = std::move(partition_source); - } - } - - hash_bin = next_build++; - if (hash_bin < bin_count) { - // Find a non-empty hash group. - for (; hash_bin < hash_groups.size(); hash_bin = next_build++) { - if (hash_groups[hash_bin] && hash_groups[hash_bin]->count) { - auto result = CreateTask(hash_bin); - if (result.second) { - return result; - } - } - } - - // OVER() doesn't have a hash_group - if (hash_groups.empty()) { - auto result = CreateTask(hash_bin); - if (result.second) { - return result; - } - } - } - - // Work stealing - while (!context.interrupted && tasks_remaining) { - auto result = StealWork(); - if (result.second) { - return result; - } - - // If there is nothing to steal but there are unfinished partitions, - // yield until any pending builds are done. - TaskScheduler::YieldThread(); - } - - return Task(); -} - -void WindowLocalSourceState::UpdateBatchIndex() { - D_ASSERT(partition_source); - D_ASSERT(scanner.get()); - - batch_index = partition_source->hash_group ? 
partition_source->hash_group->batch_base : 0; - batch_index += scanner->BlockIndex(); -} - -bool WindowLocalSourceState::NextPartition() { - // Release old states before the source - scanner.reset(); - read_states.clear(); - - // Get a partition_source that is not finished - while (!scanner) { - auto task = gsource.NextTask(hash_bin); - if (!task.first) { - return false; - } - partition_source = task.first; - scanner = std::move(task.second); - hash_bin = partition_source->hash_bin; - UpdateBatchIndex(); - } - - for (auto &wexec : partition_source->executors) { - read_states.emplace_back(wexec->GetExecutorState()); - } - - return true; -} - -void WindowLocalSourceState::Scan(DataChunk &result) { - D_ASSERT(scanner); - if (!scanner->Remaining()) { - lock_guard built_guard(gsource.built_lock); - --partition_source->unscanned; - scanner = partition_source->GetScanner(); - - if (!scanner) { - partition_source = nullptr; - read_states.clear(); - return; - } - - UpdateBatchIndex(); - } - - const auto position = scanner->Scanned(); - input_chunk.Reset(); - scanner->Scan(input_chunk); - - auto &executors = partition_source->executors; - output_chunk.Reset(); - for (idx_t expr_idx = 0; expr_idx < executors.size(); ++expr_idx) { - auto &executor = *executors[expr_idx]; - auto &lstate = *read_states[expr_idx]; - auto &result = output_chunk.data[expr_idx]; - executor.Evaluate(position, input_chunk, result, lstate); - } - output_chunk.SetCardinality(input_chunk); - output_chunk.Verify(); - - idx_t out_idx = 0; - result.SetCardinality(input_chunk); - for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); col_idx++) { - result.data[out_idx++].Reference(input_chunk.data[col_idx]); - } - for (idx_t col_idx = 0; col_idx < output_chunk.ColumnCount(); col_idx++) { - result.data[out_idx++].Reference(output_chunk.data[col_idx]); - } - result.Verify(); -} - -unique_ptr PhysicalWindow::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gsource_p) const { - auto 
&gsource = gsource_p.Cast(); - return make_uniq(gsource); -} - -unique_ptr PhysicalWindow::GetGlobalSourceState(ClientContext &context) const { - auto &gsink = sink_state->Cast(); - return make_uniq(context, gsink); -} - -bool PhysicalWindow::SupportsBatchIndex() const { - // We can only preserve order for single partitioning - // or work stealing causes out of order batch numbers - auto &wexpr = select_list[0]->Cast(); - return wexpr.partitions.empty() && !wexpr.orders.empty(); -} - -OrderPreservationType PhysicalWindow::SourceOrder() const { - return SupportsBatchIndex() ? OrderPreservationType::FIXED_ORDER : OrderPreservationType::NO_ORDER; -} - -idx_t PhysicalWindow::GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate_p, - LocalSourceState &lstate_p) const { - auto &lstate = lstate_p.Cast(); - return lstate.batch_index; -} - -SourceResultType PhysicalWindow::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &lsource = input.local_state.Cast(); - while (chunk.size() == 0) { - // Move to the next bin if we are done. - while (!lsource.scanner) { - if (!lsource.NextPartition()) { - return chunk.size() > 0 ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; - } - } - - lsource.Scan(chunk); - } - - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -string PhysicalWindow::ParamsToString() const { - string result; - for (idx_t i = 0; i < select_list.size(); i++) { - if (i > 0) { - result += "\n"; - } - result += select_list[i]->GetName(); - } - return result; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - -#include -#include -#include -#include - -namespace duckdb { - -string BaseCSVReader::GetLineNumberStr(idx_t line_error, bool is_line_estimated, idx_t buffer_idx) { - // If an error happens during auto-detect it is an estimated line - string estimated = (is_line_estimated ? 
string(" (estimated)") : string("")); - return to_string(GetLineError(line_error, buffer_idx)) + estimated; -} - -BaseCSVReader::BaseCSVReader(ClientContext &context_p, CSVReaderOptions options_p, - const vector &requested_types) - : context(context_p), fs(FileSystem::GetFileSystem(context)), allocator(BufferAllocator::Get(context)), - options(std::move(options_p)) { -} - -BaseCSVReader::~BaseCSVReader() { -} - -unique_ptr BaseCSVReader::OpenCSV(ClientContext &context, const CSVReaderOptions &options_p) { - return CSVFileHandle::OpenFile(FileSystem::GetFileSystem(context), BufferAllocator::Get(context), - options_p.file_path, options_p.compression); -} - -void BaseCSVReader::InitParseChunk(idx_t num_cols) { - // adapt not null info - if (options.force_not_null.size() != num_cols) { - options.force_not_null.resize(num_cols, false); - } - if (num_cols == parse_chunk.ColumnCount()) { - parse_chunk.Reset(); - } else { - parse_chunk.Destroy(); - - // initialize the parse_chunk with a set of VARCHAR types - vector varchar_types(num_cols, LogicalType::VARCHAR); - parse_chunk.Initialize(allocator, varchar_types); - } -} - -void BaseCSVReader::InitializeProjection() { - for (idx_t i = 0; i < GetTypes().size(); i++) { - reader_data.column_ids.push_back(i); - reader_data.column_mapping.push_back(i); - } -} - -template -static bool TemplatedTryCastDateVector(map &options, Vector &input_vector, - Vector &result_vector, idx_t count, string &error_message, idx_t &line_error) { - D_ASSERT(input_vector.GetType().id() == LogicalTypeId::VARCHAR); - bool all_converted = true; - idx_t cur_line = 0; - UnaryExecutor::Execute(input_vector, result_vector, count, [&](string_t input) { - T result; - if (!OP::Operation(options, input, result, error_message)) { - line_error = cur_line; - all_converted = false; - } - cur_line++; - return result; - }); - return all_converted; -} - -struct TryCastDateOperator { - static bool Operation(map &options, string_t input, date_t &result, - string 
&error_message) { - return options[LogicalTypeId::DATE].TryParseDate(input, result, error_message); - } -}; - -struct TryCastTimestampOperator { - static bool Operation(map &options, string_t input, timestamp_t &result, - string &error_message) { - return options[LogicalTypeId::TIMESTAMP].TryParseTimestamp(input, result, error_message); - } -}; - -bool BaseCSVReader::TryCastDateVector(map &options, Vector &input_vector, - Vector &result_vector, idx_t count, string &error_message, idx_t &line_error) { - return TemplatedTryCastDateVector(options, input_vector, result_vector, count, - error_message, line_error); -} - -bool BaseCSVReader::TryCastTimestampVector(map &options, Vector &input_vector, - Vector &result_vector, idx_t count, string &error_message) { - idx_t line_error; - return TemplatedTryCastDateVector(options, input_vector, result_vector, - count, error_message, line_error); -} - -void BaseCSVReader::VerifyLineLength(idx_t line_size, idx_t buffer_idx) { - if (line_size > options.maximum_line_size) { - throw InvalidInputException( - "Error in file \"%s\" on line %s: Maximum line size of %llu bytes exceeded!", options.file_path, - GetLineNumberStr(parse_chunk.size(), linenr_estimated, buffer_idx).c_str(), options.maximum_line_size); - } -} - -template -bool TemplatedTryCastFloatingVector(CSVReaderOptions &options, Vector &input_vector, Vector &result_vector, idx_t count, - string &error_message, idx_t &line_error) { - D_ASSERT(input_vector.GetType().id() == LogicalTypeId::VARCHAR); - bool all_converted = true; - idx_t row = 0; - UnaryExecutor::Execute(input_vector, result_vector, count, [&](string_t input) { - T result; - if (!OP::Operation(input, result, &error_message)) { - line_error = row; - all_converted = false; - } else { - row++; - } - return result; - }); - return all_converted; -} - -template -bool TemplatedTryCastDecimalVector(CSVReaderOptions &options, Vector &input_vector, Vector &result_vector, idx_t count, - string &error_message, uint8_t 
width, uint8_t scale) { - D_ASSERT(input_vector.GetType().id() == LogicalTypeId::VARCHAR); - bool all_converted = true; - UnaryExecutor::Execute(input_vector, result_vector, count, [&](string_t input) { - T result; - if (!OP::Operation(input, result, &error_message, width, scale)) { - all_converted = false; - } - return result; - }); - return all_converted; -} - -void BaseCSVReader::AddValue(string_t str_val, idx_t &column, vector &escape_positions, bool has_quotes, - idx_t buffer_idx) { - auto length = str_val.GetSize(); - if (length == 0 && column == 0) { - row_empty = true; - } else { - row_empty = false; - } - if (!return_types.empty() && column == return_types.size() && length == 0) { - // skip a single trailing delimiter in last column - return; - } - if (column >= return_types.size()) { - if (options.ignore_errors) { - error_column_overflow = true; - return; - } else { - throw InvalidInputException( - "Error in file \"%s\", on line %s: expected %lld values per row, but got more. (%s)", options.file_path, - GetLineNumberStr(linenr, linenr_estimated, buffer_idx).c_str(), return_types.size(), - options.ToString()); - } - } - - // insert the line number into the chunk - idx_t row_entry = parse_chunk.size(); - - // test against null string, but only if the value was not quoted - if ((!(has_quotes && !options.allow_quoted_nulls) || return_types[column].id() != LogicalTypeId::VARCHAR) && - !options.force_not_null[column] && Equals::Operation(str_val, string_t(options.null_str))) { - FlatVector::SetNull(parse_chunk.data[column], row_entry, true); - } else { - auto &v = parse_chunk.data[column]; - auto parse_data = FlatVector::GetData(v); - if (!escape_positions.empty()) { - // remove escape characters (if any) - string old_val = str_val.GetString(); - string new_val = ""; - idx_t prev_pos = 0; - for (idx_t i = 0; i < escape_positions.size(); i++) { - idx_t next_pos = escape_positions[i]; - new_val += old_val.substr(prev_pos, next_pos - prev_pos); - prev_pos = 
++next_pos; - } - new_val += old_val.substr(prev_pos, old_val.size() - prev_pos); - escape_positions.clear(); - parse_data[row_entry] = StringVector::AddStringOrBlob(v, string_t(new_val)); - } else { - parse_data[row_entry] = str_val; - } - } - - // move to the next column - column++; -} - -bool BaseCSVReader::AddRow(DataChunk &insert_chunk, idx_t &column, string &error_message, idx_t buffer_idx) { - linenr++; - - if (row_empty) { - row_empty = false; - if (return_types.size() != 1) { - if (mode == ParserMode::PARSING) { - FlatVector::SetNull(parse_chunk.data[0], parse_chunk.size(), false); - } - column = 0; - return false; - } - } - - // Error forwarded by 'ignore_errors' - originally encountered in 'AddValue' - if (error_column_overflow) { - D_ASSERT(options.ignore_errors); - error_column_overflow = false; - column = 0; - return false; - } - - if (column < return_types.size()) { - if (options.null_padding) { - for (; column < return_types.size(); column++) { - FlatVector::SetNull(parse_chunk.data[column], parse_chunk.size(), true); - } - } else if (options.ignore_errors) { - column = 0; - return false; - } else { - if (mode == ParserMode::SNIFFING_DATATYPES) { - error_message = "Error when adding line"; - return false; - } else { - throw InvalidInputException( - "Error in file \"%s\" on line %s: expected %lld values per row, but got %d.\nParser options:\n%s", - options.file_path, GetLineNumberStr(linenr, linenr_estimated, buffer_idx).c_str(), - return_types.size(), column, options.ToString()); - } - } - } - - parse_chunk.SetCardinality(parse_chunk.size() + 1); - - if (mode == ParserMode::PARSING_HEADER) { - return true; - } - - if (mode == ParserMode::SNIFFING_DATATYPES) { - return true; - } - - if (mode == ParserMode::PARSING && parse_chunk.size() == STANDARD_VECTOR_SIZE) { - Flush(insert_chunk, buffer_idx); - return true; - } - - column = 0; - return false; -} - -void BaseCSVReader::VerifyUTF8(idx_t col_idx, idx_t row_idx, DataChunk &chunk, int64_t offset) { - 
D_ASSERT(col_idx < chunk.data.size()); - D_ASSERT(row_idx < chunk.size()); - auto &v = chunk.data[col_idx]; - if (FlatVector::IsNull(v, row_idx)) { - return; - } - - auto parse_data = FlatVector::GetData(chunk.data[col_idx]); - auto s = parse_data[row_idx]; - auto utf_type = Utf8Proc::Analyze(s.GetData(), s.GetSize()); - if (utf_type == UnicodeType::INVALID) { - string col_name = to_string(col_idx); - if (col_idx < names.size()) { - col_name = "\"" + names[col_idx] + "\""; - } - int64_t error_line = linenr - (chunk.size() - row_idx) + 1 + offset; - D_ASSERT(error_line >= 0); - throw InvalidInputException("Error in file \"%s\" at line %llu in column \"%s\": " - "%s. Parser options:\n%s", - options.file_path, error_line, col_name, - ErrorManager::InvalidUnicodeError(s.GetString(), "CSV file"), options.ToString()); - } -} - -void BaseCSVReader::VerifyUTF8(idx_t col_idx) { - D_ASSERT(col_idx < parse_chunk.data.size()); - for (idx_t i = 0; i < parse_chunk.size(); i++) { - VerifyUTF8(col_idx, i, parse_chunk); - } -} - -bool TryCastDecimalVectorCommaSeparated(CSVReaderOptions &options, Vector &input_vector, Vector &result_vector, - idx_t count, string &error_message, const LogicalType &result_type) { - auto width = DecimalType::GetWidth(result_type); - auto scale = DecimalType::GetScale(result_type); - switch (result_type.InternalType()) { - case PhysicalType::INT16: - return TemplatedTryCastDecimalVector( - options, input_vector, result_vector, count, error_message, width, scale); - case PhysicalType::INT32: - return TemplatedTryCastDecimalVector( - options, input_vector, result_vector, count, error_message, width, scale); - case PhysicalType::INT64: - return TemplatedTryCastDecimalVector( - options, input_vector, result_vector, count, error_message, width, scale); - case PhysicalType::INT128: - return TemplatedTryCastDecimalVector( - options, input_vector, result_vector, count, error_message, width, scale); - default: - throw InternalException("Unimplemented physical 
type for decimal"); - } -} - -bool TryCastFloatingVectorCommaSeparated(CSVReaderOptions &options, Vector &input_vector, Vector &result_vector, - idx_t count, string &error_message, const LogicalType &result_type, - idx_t &line_error) { - switch (result_type.InternalType()) { - case PhysicalType::DOUBLE: - return TemplatedTryCastFloatingVector( - options, input_vector, result_vector, count, error_message, line_error); - case PhysicalType::FLOAT: - return TemplatedTryCastFloatingVector( - options, input_vector, result_vector, count, error_message, line_error); - default: - throw InternalException("Unimplemented physical type for floating"); - } -} - -// Location of erroneous value in the current parse chunk -struct ErrorLocation { - idx_t row_idx; - idx_t col_idx; - idx_t row_line; - - ErrorLocation(idx_t row_idx, idx_t col_idx, idx_t row_line) - : row_idx(row_idx), col_idx(col_idx), row_line(row_line) { - } -}; - -bool BaseCSVReader::Flush(DataChunk &insert_chunk, idx_t buffer_idx, bool try_add_line) { - if (parse_chunk.size() == 0) { - return true; - } - - bool conversion_error_ignored = false; - - // convert the columns in the parsed chunk to the types of the table - insert_chunk.SetCardinality(parse_chunk); - if (reader_data.column_ids.empty() && !reader_data.empty_columns) { - throw InternalException("BaseCSVReader::Flush called on a CSV reader that was not correctly initialized. 
Call " - "MultiFileReader::InitializeReader or InitializeProjection"); - } - D_ASSERT(reader_data.column_ids.size() == reader_data.column_mapping.size()); - for (idx_t c = 0; c < reader_data.column_ids.size(); c++) { - auto col_idx = reader_data.column_ids[c]; - auto result_idx = reader_data.column_mapping[c]; - auto &parse_vector = parse_chunk.data[col_idx]; - auto &result_vector = insert_chunk.data[result_idx]; - auto &type = result_vector.GetType(); - if (type.id() == LogicalTypeId::VARCHAR) { - // target type is varchar: no need to convert - // just test that all strings are valid utf-8 strings - VerifyUTF8(col_idx); - // reinterpret rather than reference so we can deal with user-defined types - result_vector.Reinterpret(parse_vector); - } else { - string error_message; - bool success; - idx_t line_error = 0; - bool target_type_not_varchar = false; - if (options.dialect_options.has_format[LogicalTypeId::DATE] && type.id() == LogicalTypeId::DATE) { - // use the date format to cast the chunk - success = TryCastDateVector(options.dialect_options.date_format, parse_vector, result_vector, - parse_chunk.size(), error_message, line_error); - } else if (options.dialect_options.has_format[LogicalTypeId::TIMESTAMP] && - type.id() == LogicalTypeId::TIMESTAMP) { - // use the date format to cast the chunk - success = TryCastTimestampVector(options.dialect_options.date_format, parse_vector, result_vector, - parse_chunk.size(), error_message); - } else if (options.decimal_separator != "." && - (type.id() == LogicalTypeId::FLOAT || type.id() == LogicalTypeId::DOUBLE)) { - success = TryCastFloatingVectorCommaSeparated(options, parse_vector, result_vector, parse_chunk.size(), - error_message, type, line_error); - } else if (options.decimal_separator != "." 
&& type.id() == LogicalTypeId::DECIMAL) { - success = TryCastDecimalVectorCommaSeparated(options, parse_vector, result_vector, parse_chunk.size(), - error_message, type); - } else { - // target type is not varchar: perform a cast - target_type_not_varchar = true; - success = - VectorOperations::TryCast(context, parse_vector, result_vector, parse_chunk.size(), &error_message); - } - if (success) { - continue; - } - if (try_add_line) { - return false; - } - - string col_name = to_string(col_idx); - if (col_idx < names.size()) { - col_name = "\"" + names[col_idx] + "\""; - } - - // figure out the exact line number - if (target_type_not_varchar) { - UnifiedVectorFormat inserted_column_data; - result_vector.ToUnifiedFormat(parse_chunk.size(), inserted_column_data); - for (; line_error < parse_chunk.size(); line_error++) { - if (!inserted_column_data.validity.RowIsValid(line_error) && - !FlatVector::IsNull(parse_vector, line_error)) { - break; - } - } - } - - // The line_error must be summed with linenr (All lines emmited from this batch) - // But subtracted from the parse_chunk - D_ASSERT(line_error + linenr >= parse_chunk.size()); - line_error += linenr; - line_error -= parse_chunk.size(); - - auto error_line = GetLineError(line_error, buffer_idx); - - if (options.ignore_errors) { - conversion_error_ignored = true; - - } else if (options.auto_detect) { - throw InvalidInputException("%s in column %s, at line %llu.\n\nParser " - "options:\n%s.\n\nConsider either increasing the sample size " - "(SAMPLE_SIZE=X [X rows] or SAMPLE_SIZE=-1 [all rows]), " - "or skipping column conversion (ALL_VARCHAR=1)", - error_message, col_name, error_line, options.ToString()); - } else { - throw InvalidInputException("%s at line %llu in column %s. 
Parser options:\n%s ", error_message, - error_line, col_name, options.ToString()); - } - } - } - if (conversion_error_ignored) { - D_ASSERT(options.ignore_errors); - - SelectionVector succesful_rows(parse_chunk.size()); - idx_t sel_size = 0; - - // Keep track of failed cells - vector failed_cells; - - for (idx_t row_idx = 0; row_idx < parse_chunk.size(); row_idx++) { - - auto global_row_idx = row_idx + linenr - parse_chunk.size(); - auto row_line = GetLineError(global_row_idx, buffer_idx, false); - - bool row_failed = false; - for (idx_t c = 0; c < reader_data.column_ids.size(); c++) { - auto col_idx = reader_data.column_ids[c]; - auto result_idx = reader_data.column_mapping[c]; - - auto &parse_vector = parse_chunk.data[col_idx]; - auto &result_vector = insert_chunk.data[result_idx]; - - bool was_already_null = FlatVector::IsNull(parse_vector, row_idx); - if (!was_already_null && FlatVector::IsNull(result_vector, row_idx)) { - Increment(buffer_idx); - auto bla = GetLineError(global_row_idx, buffer_idx, false); - row_idx += bla; - row_idx -= bla; - row_failed = true; - failed_cells.emplace_back(row_idx, col_idx, row_line); - } - } - if (!row_failed) { - succesful_rows.set_index(sel_size++, row_idx); - } - } - - // Now do a second pass to produce the reject table entries - if (!failed_cells.empty() && !options.rejects_table_name.empty()) { - auto limit = options.rejects_limit; - - auto rejects = CSVRejectsTable::GetOrCreate(context, options.rejects_table_name); - lock_guard lock(rejects->write_lock); - - // short circuit if we already have too many rejects - if (limit == 0 || rejects->count < limit) { - auto &table = rejects->GetTable(context); - InternalAppender appender(context, table); - auto file_name = GetFileName(); - - for (auto &cell : failed_cells) { - if (limit != 0 && rejects->count >= limit) { - break; - } - rejects->count++; - - auto row_idx = cell.row_idx; - auto col_idx = cell.col_idx; - auto row_line = cell.row_line; - - auto col_name = 
to_string(col_idx); - if (col_idx < names.size()) { - col_name = "\"" + names[col_idx] + "\""; - } - - auto &parse_vector = parse_chunk.data[col_idx]; - auto parsed_str = FlatVector::GetData(parse_vector)[row_idx]; - auto &type = insert_chunk.data[col_idx].GetType(); - auto row_error_msg = StringUtil::Format("Could not convert string '%s' to '%s'", - parsed_str.GetString(), type.ToString()); - - // Add the row to the rejects table - appender.BeginRow(); - appender.Append(string_t(file_name)); - appender.Append(row_line); - appender.Append(col_idx); - appender.Append(string_t(col_name)); - appender.Append(parsed_str); - - if (!options.rejects_recovery_columns.empty()) { - child_list_t recovery_key; - for (auto &key_idx : options.rejects_recovery_column_ids) { - // Figure out if the recovery key is valid. - // If not, error out for real. - auto &component_vector = parse_chunk.data[key_idx]; - if (FlatVector::IsNull(component_vector, row_idx)) { - throw InvalidInputException("%s at line %llu in column %s. 
Parser options:\n%s ", - "Could not parse recovery column", row_line, col_name, - options.ToString()); - } - auto component = Value(FlatVector::GetData(component_vector)[row_idx]); - recovery_key.emplace_back(names[key_idx], component); - } - appender.Append(Value::STRUCT(recovery_key)); - } - - appender.Append(string_t(row_error_msg)); - appender.EndRow(); - } - appender.Close(); - } - } - - // Now slice the insert chunk to only include the succesful rows - insert_chunk.Slice(succesful_rows, sel_size); - } - parse_chunk.Reset(); - return true; -} - -void BaseCSVReader::SetNewLineDelimiter(bool carry, bool carry_followed_by_nl) { - if (options.dialect_options.new_line == NewLineIdentifier::NOT_SET) { - if (options.dialect_options.new_line == NewLineIdentifier::MIX) { - return; - } - NewLineIdentifier this_line_identifier; - if (carry) { - if (carry_followed_by_nl) { - this_line_identifier = NewLineIdentifier::CARRY_ON; - } else { - this_line_identifier = NewLineIdentifier::SINGLE; - } - } else { - this_line_identifier = NewLineIdentifier::SINGLE; - } - if (options.dialect_options.new_line == NewLineIdentifier::NOT_SET) { - options.dialect_options.new_line = this_line_identifier; - return; - } - if (options.dialect_options.new_line != this_line_identifier) { - options.dialect_options.new_line = NewLineIdentifier::MIX; - return; - } - options.dialect_options.new_line = this_line_identifier; - } -} -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - -#include -#include -#include -#include - -namespace duckdb { - -BufferedCSVReader::BufferedCSVReader(ClientContext &context, CSVReaderOptions options_p, - const vector &requested_types) - : BaseCSVReader(context, std::move(options_p), requested_types), buffer_size(0), position(0), start(0) { - file_handle = OpenCSV(context, options); - Initialize(requested_types); -} - -BufferedCSVReader::BufferedCSVReader(ClientContext &context, string filename, CSVReaderOptions options_p, - const vector &requested_types) - 
: BaseCSVReader(context, std::move(options_p), requested_types), buffer_size(0), position(0), start(0) { - options.file_path = std::move(filename); - file_handle = OpenCSV(context, options); - Initialize(requested_types); -} - -void BufferedCSVReader::Initialize(const vector &requested_types) { - if (options.auto_detect && options.file_options.union_by_name) { - // This is required for the sniffer to work on Union By Name - D_ASSERT(options.file_path == file_handle->GetFilePath()); - auto bm_file_handle = BaseCSVReader::OpenCSV(context, options); - auto csv_buffer_manager = make_shared(context, std::move(bm_file_handle), options); - CSVSniffer sniffer(options, csv_buffer_manager, state_machine_cache); - auto sniffer_result = sniffer.SniffCSV(); - return_types = sniffer_result.return_types; - names = sniffer_result.names; - if (return_types.empty()) { - throw InvalidInputException("Failed to detect column types from CSV: is the file a valid CSV file?"); - } - } else { - return_types = requested_types; - ResetBuffer(); - } - SkipRowsAndReadHeader(options.dialect_options.skip_rows, options.dialect_options.header); - InitParseChunk(return_types.size()); -} - -void BufferedCSVReader::ResetBuffer() { - buffer.reset(); - buffer_size = 0; - position = 0; - start = 0; - cached_buffers.clear(); -} - -void BufferedCSVReader::SkipRowsAndReadHeader(idx_t skip_rows, bool skip_header) { - for (idx_t i = 0; i < skip_rows; i++) { - // ignore skip rows - string read_line = file_handle->ReadLine(); - linenr++; - } - - if (skip_header) { - // ignore the first line as a header line - InitParseChunk(return_types.size()); - ParseCSV(ParserMode::PARSING_HEADER); - } -} - -string BufferedCSVReader::ColumnTypesError(case_insensitive_map_t sql_types_per_column, - const vector &names) { - for (idx_t i = 0; i < names.size(); i++) { - auto it = sql_types_per_column.find(names[i]); - if (it != sql_types_per_column.end()) { - sql_types_per_column.erase(names[i]); - continue; - } - } - if 
(sql_types_per_column.empty()) { - return string(); - } - string exception = "COLUMN_TYPES error: Columns with names: "; - for (auto &col : sql_types_per_column) { - exception += "\"" + col.first + "\","; - } - exception.pop_back(); - exception += " do not exist in the CSV File"; - return exception; -} - -void BufferedCSVReader::SkipEmptyLines() { - if (parse_chunk.data.size() == 1) { - // Empty lines are null data. - return; - } - for (; position < buffer_size; position++) { - if (!StringUtil::CharacterIsNewline(buffer[position])) { - return; - } - } -} - -void UpdateMaxLineLength(ClientContext &context, idx_t line_length) { - if (!context.client_data->debug_set_max_line_length) { - return; - } - if (line_length < context.client_data->debug_max_line_length) { - return; - } - context.client_data->debug_max_line_length = line_length; -} - -bool BufferedCSVReader::ReadBuffer(idx_t &start, idx_t &line_start) { - if (start > buffer_size) { - return false; - } - auto old_buffer = std::move(buffer); - - // the remaining part of the last buffer - idx_t remaining = buffer_size - start; - - idx_t buffer_read_size = INITIAL_BUFFER_SIZE_LARGE; - - while (remaining > buffer_read_size) { - buffer_read_size *= 2; - } - - // Check line length - if (remaining > options.maximum_line_size) { - throw InvalidInputException("Maximum line size of %llu bytes exceeded on line %s!", options.maximum_line_size, - GetLineNumberStr(linenr, linenr_estimated)); - } - - buffer = make_unsafe_uniq_array(buffer_read_size + remaining + 1); - buffer_size = remaining + buffer_read_size; - if (remaining > 0) { - // remaining from last buffer: copy it here - memcpy(buffer.get(), old_buffer.get() + start, remaining); - } - idx_t read_count = file_handle->Read(buffer.get() + remaining, buffer_read_size); - - bytes_in_chunk += read_count; - buffer_size = remaining + read_count; - buffer[buffer_size] = '\0'; - if (old_buffer) { - cached_buffers.push_back(std::move(old_buffer)); - } - start = 0; - position = 
remaining; - if (!bom_checked) { - bom_checked = true; - if (read_count >= 3 && buffer[0] == '\xEF' && buffer[1] == '\xBB' && buffer[2] == '\xBF') { - start += 3; - position += 3; - } - } - line_start = start; - - return read_count > 0; -} - -void BufferedCSVReader::ParseCSV(DataChunk &insert_chunk) { - string error_message; - if (!TryParseCSV(ParserMode::PARSING, insert_chunk, error_message)) { - throw InvalidInputException(error_message); - } -} - -void BufferedCSVReader::ParseCSV(ParserMode mode) { - DataChunk dummy_chunk; - string error_message; - if (!TryParseCSV(mode, dummy_chunk, error_message)) { - throw InvalidInputException(error_message); - } -} - -bool BufferedCSVReader::TryParseCSV(ParserMode parser_mode, DataChunk &insert_chunk, string &error_message) { - mode = parser_mode; - // used for parsing algorithm - bool finished_chunk = false; - idx_t column = 0; - idx_t offset = 0; - bool has_quotes = false; - vector escape_positions; - - idx_t line_start = position; - idx_t line_size = 0; - // read values into the buffer (if any) - if (position >= buffer_size) { - if (!ReadBuffer(start, line_start)) { - return true; - } - } - - // start parsing the first value - goto value_start; -value_start: - offset = 0; - /* state: value_start */ - // this state parses the first character of a value - if (buffer[position] == options.dialect_options.state_machine_options.quote) { - // quote: actual value starts in the next position - // move to in_quotes state - start = position + 1; - line_size++; - goto in_quotes; - } else { - // no quote, move to normal parsing state - start = position; - goto normal; - } -normal: - /* state: normal parsing state */ - // this state parses the remainder of a non-quoted value until we reach a delimiter or newline - do { - for (; position < buffer_size; position++) { - line_size++; - if (buffer[position] == options.dialect_options.state_machine_options.delimiter) { - // delimiter: end the value and add it to the chunk - goto add_value; 
- } else if (StringUtil::CharacterIsNewline(buffer[position])) { - // newline: add row - goto add_row; - } - } - } while (ReadBuffer(start, line_start)); - // file ends during normal scan: go to end state - goto final_state; -add_value: - AddValue(string_t(buffer.get() + start, position - start - offset), column, escape_positions, has_quotes); - // increase position by 1 and move start to the new position - offset = 0; - has_quotes = false; - start = ++position; - line_size++; - if (position >= buffer_size && !ReadBuffer(start, line_start)) { - // file ends right after delimiter, go to final state - goto final_state; - } - goto value_start; -add_row : { - // check type of newline (\r or \n) - bool carriage_return = buffer[position] == '\r'; - AddValue(string_t(buffer.get() + start, position - start - offset), column, escape_positions, has_quotes); - if (!error_message.empty()) { - return false; - } - VerifyLineLength(position - line_start); - - finished_chunk = AddRow(insert_chunk, column, error_message); - UpdateMaxLineLength(context, position - line_start); - if (!error_message.empty()) { - return false; - } - // increase position by 1 and move start to the new position - offset = 0; - has_quotes = false; - position++; - line_size = 0; - start = position; - line_start = position; - if (position >= buffer_size && !ReadBuffer(start, line_start)) { - // file ends right after delimiter, go to final state - goto final_state; - } - if (carriage_return) { - // \r newline, go to special state that parses an optional \n afterwards - goto carriage_return; - } else { - SetNewLineDelimiter(); - SkipEmptyLines(); - - start = position; - line_start = position; - if (position >= buffer_size && !ReadBuffer(start, line_start)) { - // file ends right after delimiter, go to final state - goto final_state; - } - // \n newline, move to value start - if (finished_chunk) { - return true; - } - goto value_start; - } -} -in_quotes: - /* state: in_quotes */ - // this state parses the 
remainder of a quoted value - has_quotes = true; - position++; - line_size++; - do { - for (; position < buffer_size; position++) { - line_size++; - if (buffer[position] == options.dialect_options.state_machine_options.quote) { - // quote: move to unquoted state - goto unquote; - } else if (buffer[position] == options.dialect_options.state_machine_options.escape) { - // escape: store the escaped position and move to handle_escape state - escape_positions.push_back(position - start); - goto handle_escape; - } - } - } while (ReadBuffer(start, line_start)); - // still in quoted state at the end of the file, error: - throw InvalidInputException("Error in file \"%s\" on line %s: unterminated quotes. (%s)", options.file_path, - GetLineNumberStr(linenr, linenr_estimated).c_str(), options.ToString()); -unquote: - /* state: unquote */ - // this state handles the state directly after we unquote - // in this state we expect either another quote (entering the quoted state again, and escaping the quote) - // or a delimiter/newline, ending the current value and moving on to the next value - position++; - line_size++; - if (position >= buffer_size && !ReadBuffer(start, line_start)) { - // file ends right after unquote, go to final state - offset = 1; - goto final_state; - } - if (buffer[position] == options.dialect_options.state_machine_options.quote && - (options.dialect_options.state_machine_options.escape == '\0' || - options.dialect_options.state_machine_options.escape == options.dialect_options.state_machine_options.quote)) { - // escaped quote, return to quoted state and store escape position - escape_positions.push_back(position - start); - goto in_quotes; - } else if (buffer[position] == options.dialect_options.state_machine_options.delimiter) { - // delimiter, add value - offset = 1; - goto add_value; - } else if (StringUtil::CharacterIsNewline(buffer[position])) { - offset = 1; - goto add_row; - } else { - error_message = StringUtil::Format( - "Error in file \"%s\" on 
line %s: quote should be followed by end of value, end of " - "row or another quote. (%s)", - options.file_path, GetLineNumberStr(linenr, linenr_estimated).c_str(), options.ToString()); - return false; - } -handle_escape: - /* state: handle_escape */ - // escape should be followed by a quote or another escape character - position++; - line_size++; - if (position >= buffer_size && !ReadBuffer(start, line_start)) { - error_message = StringUtil::Format( - "Error in file \"%s\" on line %s: neither QUOTE nor ESCAPE is proceeded by ESCAPE. (%s)", options.file_path, - GetLineNumberStr(linenr, linenr_estimated).c_str(), options.ToString()); - return false; - } - if (buffer[position] != options.dialect_options.state_machine_options.quote && - buffer[position] != options.dialect_options.state_machine_options.escape) { - error_message = StringUtil::Format( - "Error in file \"%s\" on line %s: neither QUOTE nor ESCAPE is proceeded by ESCAPE. (%s)", options.file_path, - GetLineNumberStr(linenr, linenr_estimated).c_str(), options.ToString()); - return false; - } - // escape was followed by quote or escape, go back to quoted state - goto in_quotes; -carriage_return: - /* state: carriage_return */ - // this stage optionally skips a newline (\n) character, which allows \r\n to be interpreted as a single line - if (buffer[position] == '\n') { - SetNewLineDelimiter(true, true); - // newline after carriage return: skip - // increase position by 1 and move start to the new position - start = ++position; - line_size++; - - if (position >= buffer_size && !ReadBuffer(start, line_start)) { - // file ends right after delimiter, go to final state - goto final_state; - } - } else { - SetNewLineDelimiter(true, false); - } - if (finished_chunk) { - return true; - } - SkipEmptyLines(); - start = position; - line_start = position; - if (position >= buffer_size && !ReadBuffer(start, line_start)) { - // file ends right after delimiter, go to final state - goto final_state; - } - - goto value_start; 
-final_state: - if (finished_chunk) { - return true; - } - - if (column > 0 || position > start) { - // remaining values to be added to the chunk - AddValue(string_t(buffer.get() + start, position - start - offset), column, escape_positions, has_quotes); - VerifyLineLength(position - line_start); - - finished_chunk = AddRow(insert_chunk, column, error_message); - SkipEmptyLines(); - UpdateMaxLineLength(context, line_size); - if (!error_message.empty()) { - return false; - } - } - - // final stage, only reached after parsing the file is finished - // flush the parsed chunk and finalize parsing - if (mode == ParserMode::PARSING) { - Flush(insert_chunk); - } - - end_of_file_reached = true; - return true; -} - -} // namespace duckdb - - - -namespace duckdb { - -CSVBuffer::CSVBuffer(ClientContext &context, idx_t buffer_size_p, CSVFileHandle &file_handle, - idx_t &global_csv_current_position, idx_t file_number_p) - : context(context), first_buffer(true), file_number(file_number_p), can_seek(file_handle.CanSeek()) { - AllocateBuffer(buffer_size_p); - auto buffer = Ptr(); - actual_buffer_size = file_handle.Read(buffer, buffer_size_p); - while (actual_buffer_size < buffer_size_p && !file_handle.FinishedReading()) { - // We keep reading until this block is full - actual_buffer_size += file_handle.Read(&buffer[actual_buffer_size], buffer_size_p - actual_buffer_size); - } - global_csv_start = global_csv_current_position; - // BOM check (https://en.wikipedia.org/wiki/Byte_order_mark) - if (actual_buffer_size >= 3 && buffer[0] == '\xEF' && buffer[1] == '\xBB' && buffer[2] == '\xBF') { - start_position += 3; - } - last_buffer = file_handle.FinishedReading(); -} - -CSVBuffer::CSVBuffer(CSVFileHandle &file_handle, ClientContext &context, idx_t buffer_size, - idx_t global_csv_current_position, idx_t file_number_p) - : context(context), global_csv_start(global_csv_current_position), file_number(file_number_p), - can_seek(file_handle.CanSeek()) { - AllocateBuffer(buffer_size); - auto 
buffer = handle.Ptr(); - actual_buffer_size = file_handle.Read(handle.Ptr(), buffer_size); - while (actual_buffer_size < buffer_size && !file_handle.FinishedReading()) { - // We keep reading until this block is full - actual_buffer_size += file_handle.Read(&buffer[actual_buffer_size], buffer_size - actual_buffer_size); - } - last_buffer = file_handle.FinishedReading(); -} - -shared_ptr CSVBuffer::Next(CSVFileHandle &file_handle, idx_t buffer_size, idx_t file_number_p) { - auto next_csv_buffer = - make_shared(file_handle, context, buffer_size, global_csv_start + actual_buffer_size, file_number_p); - if (next_csv_buffer->GetBufferSize() == 0) { - // We are done reading - return nullptr; - } - return next_csv_buffer; -} - -void CSVBuffer::AllocateBuffer(idx_t buffer_size) { - auto &buffer_manager = BufferManager::GetBufferManager(context); - bool can_destroy = can_seek; - handle = buffer_manager.Allocate(MaxValue(Storage::BLOCK_SIZE, buffer_size), can_destroy, &block); -} - -idx_t CSVBuffer::GetBufferSize() { - return actual_buffer_size; -} - -void CSVBuffer::Reload(CSVFileHandle &file_handle) { - AllocateBuffer(actual_buffer_size); - file_handle.Seek(global_csv_start); - file_handle.Read(handle.Ptr(), actual_buffer_size); -} - -unique_ptr CSVBuffer::Pin(CSVFileHandle &file_handle) { - auto &buffer_manager = BufferManager::GetBufferManager(context); - if (can_seek && block->IsUnloaded()) { - // We have to reload it from disk - block = nullptr; - Reload(file_handle); - } - return make_uniq(buffer_manager.Pin(block), actual_buffer_size, first_buffer, last_buffer, - global_csv_start, start_position, file_number); -} - -void CSVBuffer::Unpin() { - if (handle.IsValid()) { - handle.Destroy(); - } -} - -idx_t CSVBuffer::GetStart() { - return start_position; -} - -bool CSVBuffer::IsCSVFileLastBuffer() { - return last_buffer; -} - -} // namespace duckdb - - -namespace duckdb { - -CSVBufferManager::CSVBufferManager(ClientContext &context_p, unique_ptr file_handle_p, - const 
CSVReaderOptions &options, idx_t file_idx_p) - : file_handle(std::move(file_handle_p)), context(context_p), file_idx(file_idx_p), - buffer_size(CSVBuffer::CSV_BUFFER_SIZE) { - if (options.skip_rows_set) { - // Skip rows if they are set - skip_rows = options.dialect_options.skip_rows; - } - auto file_size = file_handle->FileSize(); - if (file_size > 0 && file_size < buffer_size) { - buffer_size = CSVBuffer::CSV_MINIMUM_BUFFER_SIZE; - } - if (options.buffer_size < buffer_size) { - buffer_size = options.buffer_size; - } - for (idx_t i = 0; i < skip_rows; i++) { - file_handle->ReadLine(); - } - Initialize(); -} - -void CSVBufferManager::UnpinBuffer(idx_t cache_idx) { - if (cache_idx < cached_buffers.size()) { - cached_buffers[cache_idx]->Unpin(); - } -} - -void CSVBufferManager::Initialize() { - if (cached_buffers.empty()) { - cached_buffers.emplace_back( - make_shared(context, buffer_size, *file_handle, global_csv_pos, file_idx)); - last_buffer = cached_buffers.front(); - } - start_pos = last_buffer->GetStart(); -} - -idx_t CSVBufferManager::GetStartPos() { - return start_pos; -} -bool CSVBufferManager::ReadNextAndCacheIt() { - D_ASSERT(last_buffer); - if (!last_buffer->IsCSVFileLastBuffer()) { - auto maybe_last_buffer = last_buffer->Next(*file_handle, buffer_size, file_idx); - if (!maybe_last_buffer) { - last_buffer->last_buffer = true; - return false; - } - last_buffer = std::move(maybe_last_buffer); - cached_buffers.emplace_back(last_buffer); - return true; - } - return false; -} - -unique_ptr CSVBufferManager::GetBuffer(const idx_t pos) { - while (pos >= cached_buffers.size()) { - if (done) { - return nullptr; - } - if (!ReadNextAndCacheIt()) { - done = true; - } - } - if (pos != 0) { - cached_buffers[pos - 1]->Unpin(); - } - return cached_buffers[pos]->Pin(*file_handle); -} - -bool CSVBufferIterator::Finished() { - return !cur_buffer_handle; -} - -void CSVBufferIterator::Reset() { - if (cur_buffer_handle) { - cur_buffer_handle.reset(); - } - if (cur_buffer_idx > 
0) { - buffer_manager->UnpinBuffer(cur_buffer_idx - 1); - } - cur_buffer_idx = 0; - buffer_manager->Initialize(); - cur_pos = buffer_manager->GetStartPos(); -} - -} // namespace duckdb - - -namespace duckdb { - -CSVFileHandle::CSVFileHandle(FileSystem &fs, Allocator &allocator, unique_ptr file_handle_p, - const string &path_p, FileCompressionType compression) - : file_handle(std::move(file_handle_p)), path(path_p) { - can_seek = file_handle->CanSeek(); - on_disk_file = file_handle->OnDiskFile(); - file_size = file_handle->GetFileSize(); -} - -unique_ptr CSVFileHandle::OpenFileHandle(FileSystem &fs, Allocator &allocator, const string &path, - FileCompressionType compression) { - auto file_handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ, FileLockType::NO_LOCK, compression); - if (file_handle->CanSeek()) { - file_handle->Reset(); - } - return file_handle; -} - -unique_ptr CSVFileHandle::OpenFile(FileSystem &fs, Allocator &allocator, const string &path, - FileCompressionType compression) { - auto file_handle = CSVFileHandle::OpenFileHandle(fs, allocator, path, compression); - return make_uniq(fs, allocator, std::move(file_handle), path, compression); -} - -bool CSVFileHandle::CanSeek() { - return can_seek; -} - -void CSVFileHandle::Seek(idx_t position) { - if (!can_seek) { - throw InternalException("Cannot seek in this file"); - } - file_handle->Seek(position); -} - -bool CSVFileHandle::OnDiskFile() { - return on_disk_file; -} - -idx_t CSVFileHandle::FileSize() { - return file_size; -} - -bool CSVFileHandle::FinishedReading() { - return finished; -} - -idx_t CSVFileHandle::Read(void *buffer, idx_t nr_bytes) { - requested_bytes += nr_bytes; - // if this is a plain file source OR we can seek we are not caching anything - auto bytes_read = file_handle->Read(buffer, nr_bytes); - if (!finished) { - finished = bytes_read == 0; - } - return bytes_read; -} - -string CSVFileHandle::ReadLine() { - bool carriage_return = false; - string result; - char buffer[1]; - while 
(true) { - idx_t bytes_read = Read(buffer, 1); - if (bytes_read == 0) { - return result; - } - if (carriage_return) { - if (buffer[0] != '\n') { - if (!file_handle->CanSeek()) { - throw BinderException( - "Carriage return newlines not supported when reading CSV files in which we cannot seek"); - } - file_handle->Seek(file_handle->SeekPosition() - 1); - return result; - } - } - if (buffer[0] == '\n') { - return result; - } - if (buffer[0] != '\r') { - result += buffer[0]; - } else { - carriage_return = true; - } - } -} - -string CSVFileHandle::GetFilePath() { - return path; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -static bool ParseBoolean(const Value &value, const string &loption); - -static bool ParseBoolean(const vector &set, const string &loption) { - if (set.empty()) { - // no option specified: default to true - return true; - } - if (set.size() > 1) { - throw BinderException("\"%s\" expects a single argument as a boolean value (e.g. TRUE or 1)", loption); - } - return ParseBoolean(set[0], loption); -} - -static bool ParseBoolean(const Value &value, const string &loption) { - - if (value.type().id() == LogicalTypeId::LIST) { - auto &children = ListValue::GetChildren(value); - return ParseBoolean(children, loption); - } - if (value.type() == LogicalType::FLOAT || value.type() == LogicalType::DOUBLE || - value.type().id() == LogicalTypeId::DECIMAL) { - throw BinderException("\"%s\" expects a boolean value (e.g. 
TRUE or 1)", loption); - } - return BooleanValue::Get(value.DefaultCastAs(LogicalType::BOOLEAN)); -} - -static string ParseString(const Value &value, const string &loption) { - if (value.IsNull()) { - return string(); - } - if (value.type().id() == LogicalTypeId::LIST) { - auto &children = ListValue::GetChildren(value); - if (children.size() != 1) { - throw BinderException("\"%s\" expects a single argument as a string value", loption); - } - return ParseString(children[0], loption); - } - if (value.type().id() != LogicalTypeId::VARCHAR) { - throw BinderException("\"%s\" expects a string argument!", loption); - } - return value.GetValue(); -} - -static int64_t ParseInteger(const Value &value, const string &loption) { - if (value.type().id() == LogicalTypeId::LIST) { - auto &children = ListValue::GetChildren(value); - if (children.size() != 1) { - // no option specified or multiple options specified - throw BinderException("\"%s\" expects a single argument as an integer value", loption); - } - return ParseInteger(children[0], loption); - } - return value.GetValue(); -} - -bool CSVReaderOptions::GetHeader() const { - return this->dialect_options.header; -} - -void CSVReaderOptions::SetHeader(bool input) { - this->dialect_options.header = input; - this->has_header = true; -} - -void CSVReaderOptions::SetCompression(const string &compression_p) { - this->compression = FileCompressionTypeFromString(compression_p); -} - -string CSVReaderOptions::GetEscape() const { - return std::string(1, this->dialect_options.state_machine_options.escape); -} - -void CSVReaderOptions::SetEscape(const string &input) { - auto escape_str = input; - if (escape_str.size() > 1) { - throw InvalidInputException("The escape option cannot exceed a size of 1 byte."); - } - if (escape_str.empty()) { - escape_str = string("\0", 1); - } - this->dialect_options.state_machine_options.escape = escape_str[0]; - this->has_escape = true; -} - -int64_t CSVReaderOptions::GetSkipRows() const { - return 
this->dialect_options.skip_rows; -} - -void CSVReaderOptions::SetSkipRows(int64_t skip_rows) { - dialect_options.skip_rows = skip_rows; - skip_rows_set = true; -} - -string CSVReaderOptions::GetDelimiter() const { - return std::string(1, this->dialect_options.state_machine_options.delimiter); -} - -void CSVReaderOptions::SetDelimiter(const string &input) { - auto delim_str = StringUtil::Replace(input, "\\t", "\t"); - if (delim_str.size() > 1) { - throw InvalidInputException("The delimiter option cannot exceed a size of 1 byte."); - } - this->has_delimiter = true; - if (input.empty()) { - delim_str = string("\0", 1); - } - this->dialect_options.state_machine_options.delimiter = delim_str[0]; -} - -string CSVReaderOptions::GetQuote() const { - return std::string(1, this->dialect_options.state_machine_options.quote); -} - -void CSVReaderOptions::SetQuote(const string "e_p) { - auto quote_str = quote_p; - if (quote_str.size() > 1) { - throw InvalidInputException("The quote option cannot exceed a size of 1 byte."); - } - if (quote_str.empty()) { - quote_str = string("\0", 1); - } - this->dialect_options.state_machine_options.quote = quote_str[0]; - this->has_quote = true; -} - -NewLineIdentifier CSVReaderOptions::GetNewline() const { - return dialect_options.new_line; -} - -void CSVReaderOptions::SetNewline(const string &input) { - if (input == "\\n" || input == "\\r") { - dialect_options.new_line = NewLineIdentifier::SINGLE; - } else if (input == "\\r\\n") { - dialect_options.new_line = NewLineIdentifier::CARRY_ON; - } else { - throw InvalidInputException("This is not accepted as a newline: " + input); - } - has_newline = true; -} - -void CSVReaderOptions::SetDateFormat(LogicalTypeId type, const string &format, bool read_format) { - string error; - if (read_format) { - error = StrTimeFormat::ParseFormatSpecifier(format, dialect_options.date_format[type]); - dialect_options.date_format[type].format_specifier = format; - } else { - error = 
StrTimeFormat::ParseFormatSpecifier(format, write_date_format[type]); - } - if (!error.empty()) { - throw InvalidInputException("Could not parse DATEFORMAT: %s", error.c_str()); - } - dialect_options.has_format[type] = true; -} - -void CSVReaderOptions::SetReadOption(const string &loption, const Value &value, vector &expected_names) { - if (SetBaseOption(loption, value)) { - return; - } - if (loption == "auto_detect") { - auto_detect = ParseBoolean(value, loption); - } else if (loption == "sample_size") { - int64_t sample_size_option = ParseInteger(value, loption); - if (sample_size_option < 1 && sample_size_option != -1) { - throw BinderException("Unsupported parameter for SAMPLE_SIZE: cannot be smaller than 1"); - } - if (sample_size_option == -1) { - // If -1, we basically read the whole thing - sample_size_chunks = NumericLimits().Maximum(); - } else { - sample_size_chunks = sample_size_option / STANDARD_VECTOR_SIZE; - if (sample_size_option % STANDARD_VECTOR_SIZE != 0) { - sample_size_chunks++; - } - } - - } else if (loption == "skip") { - SetSkipRows(ParseInteger(value, loption)); - } else if (loption == "max_line_size" || loption == "maximum_line_size") { - maximum_line_size = ParseInteger(value, loption); - } else if (loption == "force_not_null") { - force_not_null = ParseColumnList(value, expected_names, loption); - } else if (loption == "date_format" || loption == "dateformat") { - string format = ParseString(value, loption); - SetDateFormat(LogicalTypeId::DATE, format, true); - } else if (loption == "timestamp_format" || loption == "timestampformat") { - string format = ParseString(value, loption); - SetDateFormat(LogicalTypeId::TIMESTAMP, format, true); - } else if (loption == "ignore_errors") { - ignore_errors = ParseBoolean(value, loption); - } else if (loption == "buffer_size") { - buffer_size = ParseInteger(value, loption); - if (buffer_size == 0) { - throw InvalidInputException("Buffer Size option must be higher than 0"); - } - } else if (loption 
== "decimal_separator") { - decimal_separator = ParseString(value, loption); - if (decimal_separator != "." && decimal_separator != ",") { - throw BinderException("Unsupported parameter for DECIMAL_SEPARATOR: should be '.' or ','"); - } - } else if (loption == "null_padding") { - null_padding = ParseBoolean(value, loption); - } else if (loption == "allow_quoted_nulls") { - allow_quoted_nulls = ParseBoolean(value, loption); - } else if (loption == "parallel") { - parallel_mode = ParseBoolean(value, loption) ? ParallelMode::PARALLEL : ParallelMode::SINGLE_THREADED; - } else if (loption == "rejects_table") { - // skip, handled in SetRejectsOptions - auto table_name = ParseString(value, loption); - if (table_name.empty()) { - throw BinderException("REJECTS_TABLE option cannot be empty"); - } - rejects_table_name = table_name; - } else if (loption == "rejects_recovery_columns") { - // Get the list of columns to use as a recovery key - auto &children = ListValue::GetChildren(value); - for (auto &child : children) { - auto col_name = child.GetValue(); - rejects_recovery_columns.push_back(col_name); - } - } else if (loption == "rejects_limit") { - int64_t limit = ParseInteger(value, loption); - if (limit < 0) { - throw BinderException("Unsupported parameter for REJECTS_LIMIT: cannot be negative"); - } - rejects_limit = limit; - } else { - throw BinderException("Unrecognized option for CSV reader \"%s\"", loption); - } -} - -void CSVReaderOptions::SetWriteOption(const string &loption, const Value &value) { - if (loption == "new_line") { - // Steal this from SetBaseOption so we can write different newlines (e.g., format JSON ARRAY) - write_newline = ParseString(value, loption); - return; - } - - if (SetBaseOption(loption, value)) { - return; - } - - if (loption == "force_quote") { - force_quote = ParseColumnList(value, name_list, loption); - } else if (loption == "date_format" || loption == "dateformat") { - string format = ParseString(value, loption); - 
SetDateFormat(LogicalTypeId::DATE, format, false); - } else if (loption == "timestamp_format" || loption == "timestampformat") { - string format = ParseString(value, loption); - if (StringUtil::Lower(format) == "iso") { - format = "%Y-%m-%dT%H:%M:%S.%fZ"; - } - SetDateFormat(LogicalTypeId::TIMESTAMP, format, false); - SetDateFormat(LogicalTypeId::TIMESTAMP_TZ, format, false); - } else if (loption == "prefix") { - prefix = ParseString(value, loption); - } else if (loption == "suffix") { - suffix = ParseString(value, loption); - } else { - throw BinderException("Unrecognized option CSV writer \"%s\"", loption); - } -} - -bool CSVReaderOptions::SetBaseOption(const string &loption, const Value &value) { - // Make sure this function was only called after the option was turned into lowercase - D_ASSERT(!std::any_of(loption.begin(), loption.end(), ::isupper)); - - if (StringUtil::StartsWith(loption, "delim") || StringUtil::StartsWith(loption, "sep")) { - SetDelimiter(ParseString(value, loption)); - } else if (loption == "quote") { - SetQuote(ParseString(value, loption)); - } else if (loption == "new_line") { - SetNewline(ParseString(value, loption)); - } else if (loption == "escape") { - SetEscape(ParseString(value, loption)); - } else if (loption == "header") { - SetHeader(ParseBoolean(value, loption)); - } else if (loption == "null" || loption == "nullstr") { - null_str = ParseString(value, loption); - } else if (loption == "encoding") { - auto encoding = StringUtil::Lower(ParseString(value, loption)); - if (encoding != "utf8" && encoding != "utf-8") { - throw BinderException("Copy is only supported for UTF-8 encoded files, ENCODING 'UTF-8'"); - } - } else if (loption == "compression") { - SetCompression(ParseString(value, loption)); - } else { - // unrecognized option in base CSV - return false; - } - return true; -} - -string CSVReaderOptions::ToString() const { - return " file=" + file_path + "\n delimiter='" + dialect_options.state_machine_options.delimiter + - 
(has_delimiter ? "'" : (auto_detect ? "' (auto detected)" : "' (default)")) + "\n quote='" + - dialect_options.state_machine_options.quote + - (has_quote ? "'" : (auto_detect ? "' (auto detected)" : "' (default)")) + "\n escape='" + - dialect_options.state_machine_options.escape + - (has_escape ? "'" : (auto_detect ? "' (auto detected)" : "' (default)")) + - "\n header=" + std::to_string(dialect_options.header) + - (has_header ? "" : (auto_detect ? " (auto detected)" : "' (default)")) + - "\n sample_size=" + std::to_string(sample_size_chunks * STANDARD_VECTOR_SIZE) + - "\n ignore_errors=" + std::to_string(ignore_errors) + "\n all_varchar=" + std::to_string(all_varchar); -} - -static Value StringVectorToValue(const vector &vec) { - vector content; - content.reserve(vec.size()); - for (auto &item : vec) { - content.push_back(Value(item)); - } - return Value::LIST(std::move(content)); -} - -static uint8_t GetCandidateSpecificity(const LogicalType &candidate_type) { - //! Const ht with accepted auto_types and their weights in specificity - const duckdb::unordered_map auto_type_candidates_specificity { - {(uint8_t)LogicalTypeId::VARCHAR, 0}, {(uint8_t)LogicalTypeId::TIMESTAMP, 1}, - {(uint8_t)LogicalTypeId::DATE, 2}, {(uint8_t)LogicalTypeId::TIME, 3}, - {(uint8_t)LogicalTypeId::DOUBLE, 4}, {(uint8_t)LogicalTypeId::FLOAT, 5}, - {(uint8_t)LogicalTypeId::BIGINT, 6}, {(uint8_t)LogicalTypeId::INTEGER, 7}, - {(uint8_t)LogicalTypeId::SMALLINT, 8}, {(uint8_t)LogicalTypeId::TINYINT, 9}, - {(uint8_t)LogicalTypeId::BOOLEAN, 10}, {(uint8_t)LogicalTypeId::SQLNULL, 11}}; - - auto id = (uint8_t)candidate_type.id(); - auto it = auto_type_candidates_specificity.find(id); - if (it == auto_type_candidates_specificity.end()) { - throw BinderException("Auto Type Candidate of type %s is not accepted as a valid input", - EnumUtil::ToString(candidate_type.id())); - } - return it->second; -} - -void CSVReaderOptions::FromNamedParameters(named_parameter_map_t &in, ClientContext &context, - 
vector &return_types, vector &names) { - for (auto &kv : in) { - if (MultiFileReader::ParseOption(kv.first, kv.second, file_options, context)) { - continue; - } - auto loption = StringUtil::Lower(kv.first); - if (loption == "columns") { - explicitly_set_columns = true; - auto &child_type = kv.second.type(); - if (child_type.id() != LogicalTypeId::STRUCT) { - throw BinderException("read_csv columns requires a struct as input"); - } - auto &struct_children = StructValue::GetChildren(kv.second); - D_ASSERT(StructType::GetChildCount(child_type) == struct_children.size()); - for (idx_t i = 0; i < struct_children.size(); i++) { - auto &name = StructType::GetChildName(child_type, i); - auto &val = struct_children[i]; - names.push_back(name); - if (val.type().id() != LogicalTypeId::VARCHAR) { - throw BinderException("read_csv requires a type specification as string"); - } - return_types.emplace_back(TransformStringToLogicalType(StringValue::Get(val), context)); - } - if (names.empty()) { - throw BinderException("read_csv requires at least a single column as input!"); - } - } else if (loption == "auto_type_candidates") { - auto_type_candidates.clear(); - map candidate_types; - // We always have the extremes of Null and Varchar, so we can default to varchar if the - // sniffer is not able to confidently detect that column type - candidate_types[GetCandidateSpecificity(LogicalType::VARCHAR)] = LogicalType::VARCHAR; - candidate_types[GetCandidateSpecificity(LogicalType::SQLNULL)] = LogicalType::SQLNULL; - - auto &child_type = kv.second.type(); - if (child_type.id() != LogicalTypeId::LIST) { - throw BinderException("read_csv auto_types requires a list as input"); - } - auto &list_children = ListValue::GetChildren(kv.second); - if (list_children.empty()) { - throw BinderException("auto_type_candidates requires at least one type"); - } - for (auto &child : list_children) { - if (child.type().id() != LogicalTypeId::VARCHAR) { - throw BinderException("auto_type_candidates requires 
a type specification as string"); - } - auto candidate_type = TransformStringToLogicalType(StringValue::Get(child), context); - candidate_types[GetCandidateSpecificity(candidate_type)] = candidate_type; - } - for (auto &candidate_type : candidate_types) { - auto_type_candidates.emplace_back(candidate_type.second); - } - } else if (loption == "column_names" || loption == "names") { - if (!name_list.empty()) { - throw BinderException("read_csv_auto column_names/names can only be supplied once"); - } - if (kv.second.IsNull()) { - throw BinderException("read_csv_auto %s cannot be NULL", kv.first); - } - auto &children = ListValue::GetChildren(kv.second); - for (auto &child : children) { - name_list.push_back(StringValue::Get(child)); - } - } else if (loption == "column_types" || loption == "types" || loption == "dtypes") { - auto &child_type = kv.second.type(); - if (child_type.id() != LogicalTypeId::STRUCT && child_type.id() != LogicalTypeId::LIST) { - throw BinderException("read_csv_auto %s requires a struct or list as input", kv.first); - } - if (!sql_type_list.empty()) { - throw BinderException("read_csv_auto column_types/types/dtypes can only be supplied once"); - } - vector sql_type_names; - if (child_type.id() == LogicalTypeId::STRUCT) { - auto &struct_children = StructValue::GetChildren(kv.second); - D_ASSERT(StructType::GetChildCount(child_type) == struct_children.size()); - for (idx_t i = 0; i < struct_children.size(); i++) { - auto &name = StructType::GetChildName(child_type, i); - auto &val = struct_children[i]; - if (val.type().id() != LogicalTypeId::VARCHAR) { - throw BinderException("read_csv_auto %s requires a type specification as string", kv.first); - } - sql_type_names.push_back(StringValue::Get(val)); - sql_types_per_column[name] = i; - } - } else { - auto &list_child = ListType::GetChildType(child_type); - if (list_child.id() != LogicalTypeId::VARCHAR) { - throw BinderException("read_csv_auto %s requires a list of types (varchar) as input", 
kv.first); - } - auto &children = ListValue::GetChildren(kv.second); - for (auto &child : children) { - sql_type_names.push_back(StringValue::Get(child)); - } - } - sql_type_list.reserve(sql_type_names.size()); - for (auto &sql_type : sql_type_names) { - auto def_type = TransformStringToLogicalType(sql_type, context); - if (def_type.id() == LogicalTypeId::USER) { - throw BinderException("Unrecognized type \"%s\" for read_csv_auto %s definition", sql_type, - kv.first); - } - sql_type_list.push_back(std::move(def_type)); - } - } else if (loption == "all_varchar") { - all_varchar = BooleanValue::Get(kv.second); - } else if (loption == "normalize_names") { - normalize_names = BooleanValue::Get(kv.second); - } else { - SetReadOption(loption, kv.second, names); - } - } -} - -//! This function is used to remember options set by the sniffer, for use in ReadCSVRelation -void CSVReaderOptions::ToNamedParameters(named_parameter_map_t &named_params) { - if (has_delimiter) { - named_params["delim"] = Value(GetDelimiter()); - } - if (has_newline) { - named_params["newline"] = Value(EnumUtil::ToString(GetNewline())); - } - if (has_quote) { - named_params["quote"] = Value(GetQuote()); - } - if (has_escape) { - named_params["escape"] = Value(GetEscape()); - } - if (has_header) { - named_params["header"] = Value(GetHeader()); - } - named_params["max_line_size"] = Value::BIGINT(maximum_line_size); - if (skip_rows_set) { - named_params["skip"] = Value::BIGINT(GetSkipRows()); - } - named_params["null_padding"] = Value::BOOLEAN(null_padding); - if (!date_format.at(LogicalType::DATE).format_specifier.empty()) { - named_params["dateformat"] = Value(date_format.at(LogicalType::DATE).format_specifier); - } - if (!date_format.at(LogicalType::TIMESTAMP).format_specifier.empty()) { - named_params["timestampformat"] = Value(date_format.at(LogicalType::TIMESTAMP).format_specifier); - } - - named_params["normalize_names"] = Value::BOOLEAN(normalize_names); - if (!name_list.empty() && 
!named_params.count("column_names") && !named_params.count("names")) { - named_params["column_names"] = StringVectorToValue(name_list); - } - named_params["all_varchar"] = Value::BOOLEAN(all_varchar); - named_params["maximum_line_size"] = Value::BIGINT(maximum_line_size); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -CSVStateMachine::CSVStateMachine(CSVReaderOptions &options_p, const CSVStateMachineOptions &state_machine_options, - shared_ptr buffer_manager_p, - CSVStateMachineCache &csv_state_machine_cache_p) - : csv_state_machine_cache(csv_state_machine_cache_p), options(options_p), - csv_buffer_iterator(std::move(buffer_manager_p)), - transition_array(csv_state_machine_cache.Get(state_machine_options)) { - dialect_options.state_machine_options = state_machine_options; - dialect_options.has_format = options.dialect_options.has_format; - dialect_options.date_format = options.dialect_options.date_format; - dialect_options.skip_rows = options.dialect_options.skip_rows; -} - -void CSVStateMachine::Reset() { - csv_buffer_iterator.Reset(); -} - -void CSVStateMachine::VerifyUTF8() { - auto utf_type = Utf8Proc::Analyze(value.c_str(), value.size()); - if (utf_type == UnicodeType::INVALID) { - int64_t error_line = cur_rows; - throw InvalidInputException("Error in file \"%s\" at line %llu: " - "%s. 
Parser options:\n%s", - options.file_path, error_line, ErrorManager::InvalidUnicodeError(value, "CSV file"), - options.ToString()); - } -} -} // namespace duckdb - - - -namespace duckdb { - -void InitializeTransitionArray(unsigned char *transition_array, const uint8_t state) { - for (uint32_t i = 0; i < NUM_TRANSITIONS; i++) { - transition_array[i] = state; - } -} - -void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_options) { - D_ASSERT(state_machine_cache.find(state_machine_options) == state_machine_cache.end()); - // Initialize transition array with default values to the Standard option - auto &transition_array = state_machine_cache[state_machine_options]; - const uint8_t standard_state = static_cast(CSVState::STANDARD); - const uint8_t field_separator_state = static_cast(CSVState::DELIMITER); - const uint8_t record_separator_state = static_cast(CSVState::RECORD_SEPARATOR); - const uint8_t carriage_return_state = static_cast(CSVState::CARRIAGE_RETURN); - const uint8_t quoted_state = static_cast(CSVState::QUOTED); - const uint8_t unquoted_state = static_cast(CSVState::UNQUOTED); - const uint8_t escape_state = static_cast(CSVState::ESCAPE); - const uint8_t empty_line_state = static_cast(CSVState::EMPTY_LINE); - const uint8_t invalid_state = static_cast(CSVState::INVALID); - - for (uint32_t i = 0; i < NUM_STATES; i++) { - switch (i) { - case quoted_state: - InitializeTransitionArray(transition_array[i], quoted_state); - break; - case unquoted_state: - case invalid_state: - case escape_state: - InitializeTransitionArray(transition_array[i], invalid_state); - break; - default: - InitializeTransitionArray(transition_array[i], standard_state); - break; - } - } - - // Now set values depending on configuration - // 1) Standard State - transition_array[standard_state][static_cast(state_machine_options.delimiter)] = field_separator_state; - transition_array[standard_state][static_cast('\n')] = record_separator_state; - 
transition_array[standard_state][static_cast('\r')] = carriage_return_state; - transition_array[standard_state][static_cast(state_machine_options.quote)] = quoted_state; - // 2) Field Separator State - transition_array[field_separator_state][static_cast(state_machine_options.delimiter)] = - field_separator_state; - transition_array[field_separator_state][static_cast('\n')] = record_separator_state; - transition_array[field_separator_state][static_cast('\r')] = carriage_return_state; - transition_array[field_separator_state][static_cast(state_machine_options.quote)] = quoted_state; - // 3) Record Separator State - transition_array[record_separator_state][static_cast(state_machine_options.delimiter)] = - field_separator_state; - transition_array[record_separator_state][static_cast('\n')] = empty_line_state; - transition_array[record_separator_state][static_cast('\r')] = empty_line_state; - transition_array[record_separator_state][static_cast(state_machine_options.quote)] = quoted_state; - // 4) Carriage Return State - transition_array[carriage_return_state][static_cast('\n')] = record_separator_state; - transition_array[carriage_return_state][static_cast('\r')] = empty_line_state; - transition_array[carriage_return_state][static_cast(state_machine_options.escape)] = escape_state; - // 5) Quoted State - transition_array[quoted_state][static_cast(state_machine_options.quote)] = unquoted_state; - if (state_machine_options.quote != state_machine_options.escape) { - transition_array[quoted_state][static_cast(state_machine_options.escape)] = escape_state; - } - // 6) Unquoted State - transition_array[unquoted_state][static_cast('\n')] = record_separator_state; - transition_array[unquoted_state][static_cast('\r')] = carriage_return_state; - transition_array[unquoted_state][static_cast(state_machine_options.delimiter)] = field_separator_state; - if (state_machine_options.quote == state_machine_options.escape) { - 
transition_array[unquoted_state][static_cast(state_machine_options.escape)] = quoted_state; - } - // 7) Escaped State - transition_array[escape_state][static_cast(state_machine_options.quote)] = quoted_state; - transition_array[escape_state][static_cast(state_machine_options.escape)] = quoted_state; - // 8) Empty Line State - transition_array[empty_line_state][static_cast('\r')] = empty_line_state; - transition_array[empty_line_state][static_cast('\n')] = empty_line_state; -} - -CSVStateMachineCache::CSVStateMachineCache() { - for (auto quoterule : default_quote_rule) { - const auto "e_candidates = default_quote[static_cast(quoterule)]; - for (const auto "e : quote_candidates) { - for (const auto &delimiter : default_delimiter) { - const auto &escape_candidates = default_escape[static_cast(quoterule)]; - for (const auto &escape : escape_candidates) { - Insert({delimiter, quote, escape}); - } - } - } - } -} - -const state_machine_t &CSVStateMachineCache::Get(const CSVStateMachineOptions &state_machine_options) { - //! 
Custom State Machine, we need to create it and cache it first - if (state_machine_cache.find(state_machine_options) == state_machine_cache.end()) { - Insert(state_machine_options); - } - const auto &transition_array = state_machine_cache[state_machine_options]; - return transition_array; -} -} // namespace duckdb - - - - - - - - - - - - - - - - - - - -#include -#include -#include -#include - -namespace duckdb { - -ParallelCSVReader::ParallelCSVReader(ClientContext &context, CSVReaderOptions options_p, - unique_ptr buffer_p, idx_t first_pos_first_buffer_p, - const vector &requested_types, idx_t file_idx_p) - : BaseCSVReader(context, std::move(options_p), requested_types), file_idx(file_idx_p), - first_pos_first_buffer(first_pos_first_buffer_p) { - Initialize(requested_types); - SetBufferRead(std::move(buffer_p)); -} - -void ParallelCSVReader::Initialize(const vector &requested_types) { - return_types = requested_types; - InitParseChunk(return_types.size()); -} - -bool ParallelCSVReader::NewLineDelimiter(bool carry, bool carry_followed_by_nl, bool first_char) { - // Set the delimiter if not set yet. - SetNewLineDelimiter(carry, carry_followed_by_nl); - D_ASSERT(options.dialect_options.new_line == NewLineIdentifier::SINGLE || - options.dialect_options.new_line == NewLineIdentifier::CARRY_ON); - if (options.dialect_options.new_line == NewLineIdentifier::SINGLE) { - return (!carry) || (carry && !carry_followed_by_nl); - } - return (carry && carry_followed_by_nl) || (!carry && first_char); -} - -void ParallelCSVReader::SkipEmptyLines() { - idx_t new_pos_buffer = position_buffer; - if (parse_chunk.data.size() == 1) { - // Empty lines are null data. 
- return; - } - for (; new_pos_buffer < end_buffer; new_pos_buffer++) { - if (StringUtil::CharacterIsNewline((*buffer)[new_pos_buffer])) { - bool carrier_return = (*buffer)[new_pos_buffer] == '\r'; - new_pos_buffer++; - if (carrier_return && new_pos_buffer < buffer_size && (*buffer)[new_pos_buffer] == '\n') { - position_buffer++; - } - if (new_pos_buffer > end_buffer) { - return; - } - position_buffer = new_pos_buffer; - } else if ((*buffer)[new_pos_buffer] != ' ') { - return; - } - } -} - -bool ParallelCSVReader::SetPosition() { - if (buffer->buffer->is_first_buffer && start_buffer == position_buffer && start_buffer == first_pos_first_buffer) { - start_buffer = buffer->buffer->start_position; - position_buffer = start_buffer; - verification_positions.beginning_of_first_line = position_buffer; - verification_positions.end_of_last_line = position_buffer; - // First buffer doesn't need any setting - - if (options.dialect_options.header) { - for (; position_buffer < end_buffer; position_buffer++) { - if (StringUtil::CharacterIsNewline((*buffer)[position_buffer])) { - bool carrier_return = (*buffer)[position_buffer] == '\r'; - position_buffer++; - if (carrier_return && position_buffer < buffer_size && (*buffer)[position_buffer] == '\n') { - position_buffer++; - } - if (position_buffer > end_buffer) { - VerifyLineLength(position_buffer, buffer->batch_index); - return false; - } - SkipEmptyLines(); - if (verification_positions.beginning_of_first_line == 0) { - verification_positions.beginning_of_first_line = position_buffer; - } - VerifyLineLength(position_buffer, buffer->batch_index); - verification_positions.end_of_last_line = position_buffer; - return true; - } - } - VerifyLineLength(position_buffer, buffer->batch_index); - return false; - } - SkipEmptyLines(); - if (verification_positions.beginning_of_first_line == 0) { - verification_positions.beginning_of_first_line = position_buffer; - } - - verification_positions.end_of_last_line = position_buffer; - return true; 
- } - - // We have to move position up to next new line - idx_t end_buffer_real = end_buffer; - // Check if we already start in a valid line - string error_message; - bool successfully_read_first_line = false; - while (!successfully_read_first_line) { - DataChunk first_line_chunk; - first_line_chunk.Initialize(allocator, return_types); - // Ensure that parse_chunk has no gunk when trying to figure new line - parse_chunk.Reset(); - for (; position_buffer < end_buffer; position_buffer++) { - if (StringUtil::CharacterIsNewline((*buffer)[position_buffer])) { - bool carriage_return = (*buffer)[position_buffer] == '\r'; - bool carriage_return_followed = false; - position_buffer++; - if (position_buffer < end_buffer) { - if (carriage_return && (*buffer)[position_buffer] == '\n') { - carriage_return_followed = true; - position_buffer++; - } - } - if (NewLineDelimiter(carriage_return, carriage_return_followed, position_buffer - 1 == start_buffer)) { - break; - } - } - } - SkipEmptyLines(); - - if (position_buffer > buffer_size) { - break; - } - - auto pos_check = position_buffer == 0 ? position_buffer : position_buffer - 1; - if (position_buffer >= end_buffer && !StringUtil::CharacterIsNewline((*buffer)[pos_check])) { - break; - } - - if (position_buffer > end_buffer && options.dialect_options.new_line == NewLineIdentifier::CARRY_ON && - (*buffer)[pos_check] == '\n') { - break; - } - idx_t position_set = position_buffer; - start_buffer = position_buffer; - // We check if we can add this line - // disable the projection pushdown while reading the first line - // otherwise the first line parsing can be influenced by which columns we are reading - auto column_ids = std::move(reader_data.column_ids); - auto column_mapping = std::move(reader_data.column_mapping); - InitializeProjection(); - try { - successfully_read_first_line = TryParseSimpleCSV(first_line_chunk, error_message, true); - } catch (...) 
{ - successfully_read_first_line = false; - } - // restore the projection pushdown - reader_data.column_ids = std::move(column_ids); - reader_data.column_mapping = std::move(column_mapping); - end_buffer = end_buffer_real; - start_buffer = position_set; - if (position_buffer >= end_buffer) { - if (successfully_read_first_line) { - position_buffer = position_set; - } - break; - } - position_buffer = position_set; - } - if (verification_positions.beginning_of_first_line == 0) { - verification_positions.beginning_of_first_line = position_buffer; - } - // Ensure that parse_chunk has no gunk when trying to figure new line - parse_chunk.Reset(); - - verification_positions.end_of_last_line = position_buffer; - finished = false; - return successfully_read_first_line; -} - -void ParallelCSVReader::SetBufferRead(unique_ptr buffer_read_p) { - if (!buffer_read_p->buffer) { - throw InternalException("ParallelCSVReader::SetBufferRead - CSVBufferRead does not have a buffer to read"); - } - position_buffer = buffer_read_p->buffer_start; - start_buffer = buffer_read_p->buffer_start; - end_buffer = buffer_read_p->buffer_end; - if (buffer_read_p->next_buffer) { - buffer_size = buffer_read_p->buffer->actual_size + buffer_read_p->next_buffer->actual_size; - } else { - buffer_size = buffer_read_p->buffer->actual_size; - } - buffer = std::move(buffer_read_p); - - reached_remainder_state = false; - verification_positions.beginning_of_first_line = 0; - verification_positions.end_of_last_line = 0; - finished = false; - D_ASSERT(end_buffer <= buffer_size); -} - -VerificationPositions ParallelCSVReader::GetVerificationPositions() { - verification_positions.beginning_of_first_line += buffer->buffer->csv_global_start; - verification_positions.end_of_last_line += buffer->buffer->csv_global_start; - return verification_positions; -} - -// If BufferRemainder returns false, it means we are done scanning this buffer and should go to the end_state -bool ParallelCSVReader::BufferRemainder() { - if 
(position_buffer >= end_buffer && !reached_remainder_state) { - // First time we finish the buffer piece we should scan here, we set the variables - // to allow this piece to be scanned up to the end of the buffer or the next new line - reached_remainder_state = true; - // end_buffer is allowed to go to buffer size to finish its last line - end_buffer = buffer_size; - } - if (position_buffer >= end_buffer) { - // buffer ends, return false - return false; - } - // we can still scan stuff, return true - return true; -} - -bool AllNewLine(string_t value, idx_t column_amount) { - auto value_str = value.GetString(); - if (value_str.empty() && column_amount == 1) { - // This is a one column (empty) - return false; - } - for (idx_t i = 0; i < value.GetSize(); i++) { - if (!StringUtil::CharacterIsNewline(value_str[i])) { - return false; - } - } - return true; -} - -bool ParallelCSVReader::TryParseSimpleCSV(DataChunk &insert_chunk, string &error_message, bool try_add_line) { - // If line is not set, we have to figure it out, we assume whatever is in the first line - if (options.dialect_options.new_line == NewLineIdentifier::NOT_SET) { - idx_t cur_pos = position_buffer; - // we can start in the middle of a new line, so move a bit forward. 
- while (cur_pos < end_buffer) { - if (StringUtil::CharacterIsNewline((*buffer)[cur_pos])) { - cur_pos++; - } else { - break; - } - } - for (; cur_pos < end_buffer; cur_pos++) { - if (StringUtil::CharacterIsNewline((*buffer)[cur_pos])) { - bool carriage_return = (*buffer)[cur_pos] == '\r'; - bool carriage_return_followed = false; - cur_pos++; - if (cur_pos < end_buffer) { - if (carriage_return && (*buffer)[cur_pos] == '\n') { - carriage_return_followed = true; - cur_pos++; - } - } - SetNewLineDelimiter(carriage_return, carriage_return_followed); - break; - } - } - } - // used for parsing algorithm - if (start_buffer == buffer_size) { - // Nothing to read - finished = true; - return true; - } - D_ASSERT(end_buffer <= buffer_size); - bool finished_chunk = false; - idx_t column = 0; - idx_t offset = 0; - bool has_quotes = false; - - vector escape_positions; - if ((start_buffer == buffer->buffer_start || start_buffer == buffer->buffer_end) && !try_add_line) { - // First time reading this buffer piece - if (!SetPosition()) { - finished = true; - return true; - } - } - if (position_buffer == buffer_size) { - // Nothing to read - finished = true; - return true; - } - // Keep track of line size - idx_t line_start = position_buffer; - // start parsing the first value - goto value_start; - -value_start : { - /* state: value_start */ - if (!BufferRemainder()) { - goto final_state; - } - offset = 0; - - // this state parses the first character of a value - if ((*buffer)[position_buffer] == options.dialect_options.state_machine_options.quote) { - // quote: actual value starts in the next position - // move to in_quotes state - start_buffer = position_buffer + 1; - goto in_quotes; - } else { - // no quote, move to normal parsing state - start_buffer = position_buffer; - goto normal; - } -}; - -normal : { - /* state: normal parsing state */ - // this state parses the remainder of a non-quoted value until we reach a delimiter or newline - for (; position_buffer < end_buffer; 
position_buffer++) { - auto c = (*buffer)[position_buffer]; - if (c == options.dialect_options.state_machine_options.delimiter) { - // Check if previous character is a quote, if yes, this means we are in a non-initialized quoted value - // This only matters for when trying to figure out where csv lines start - if (position_buffer > 0 && try_add_line) { - if ((*buffer)[position_buffer - 1] == options.dialect_options.state_machine_options.quote) { - return false; - } - } - // delimiter: end the value and add it to the chunk - goto add_value; - } else if (StringUtil::CharacterIsNewline(c)) { - // Check if previous character is a quote, if yes, this means we are in a non-initialized quoted value - // This only matters for when trying to figure out where csv lines start - if (position_buffer > 0 && try_add_line) { - if ((*buffer)[position_buffer - 1] == options.dialect_options.state_machine_options.quote) { - return false; - } - } - // newline: add row - if (column > 0 || try_add_line || parse_chunk.data.size() == 1) { - goto add_row; - } - if (column == 0 && position_buffer == start_buffer) { - start_buffer++; - } - } - } - if (!BufferRemainder()) { - goto final_state; - } else { - goto normal; - } -}; - -add_value : { - /* state: Add value to string vector */ - AddValue(buffer->GetValue(start_buffer, position_buffer, offset), column, escape_positions, has_quotes, - buffer->local_batch_index); - // increase position by 1 and move start to the new position - offset = 0; - has_quotes = false; - start_buffer = ++position_buffer; - if (!BufferRemainder()) { - goto final_state; - } - goto value_start; -}; - -add_row : { - /* state: Add Row to Parse chunk */ - // check type of newline (\r or \n) - bool carriage_return = (*buffer)[position_buffer] == '\r'; - - AddValue(buffer->GetValue(start_buffer, position_buffer, offset), column, escape_positions, has_quotes, - buffer->local_batch_index); - if (try_add_line) { - bool success = column == insert_chunk.ColumnCount(); - if 
(success) { - idx_t cur_linenr = linenr; - AddRow(insert_chunk, column, error_message, buffer->local_batch_index); - success = Flush(insert_chunk, buffer->local_batch_index, true); - linenr = cur_linenr; - } - reached_remainder_state = false; - parse_chunk.Reset(); - return success; - } else { - VerifyLineLength(position_buffer - line_start, buffer->batch_index); - line_start = position_buffer; - finished_chunk = AddRow(insert_chunk, column, error_message, buffer->local_batch_index); - } - // increase position by 1 and move start to the new position - offset = 0; - has_quotes = false; - position_buffer++; - start_buffer = position_buffer; - verification_positions.end_of_last_line = position_buffer; - if (carriage_return) { - // \r newline, go to special state that parses an optional \n afterwards - // optionally skips a newline (\n) character, which allows \r\n to be interpreted as a single line - if (!BufferRemainder()) { - goto final_state; - } - if ((*buffer)[position_buffer] == '\n') { - if (options.dialect_options.new_line == NewLineIdentifier::SINGLE) { - error_message = "Wrong NewLine Identifier. Expecting \\r\\n"; - return false; - } - // newline after carriage return: skip - // increase position by 1 and move start to the new position - start_buffer = ++position_buffer; - - SkipEmptyLines(); - verification_positions.end_of_last_line = position_buffer; - start_buffer = position_buffer; - if (reached_remainder_state) { - goto final_state; - } - } else { - if (options.dialect_options.new_line == NewLineIdentifier::CARRY_ON) { - error_message = "Wrong NewLine Identifier. Expecting \\r or \\n"; - return false; - } - } - if (!BufferRemainder()) { - goto final_state; - } - if (reached_remainder_state || finished_chunk) { - goto final_state; - } - goto value_start; - } else { - if (options.dialect_options.new_line == NewLineIdentifier::CARRY_ON) { - error_message = "Wrong NewLine Identifier. 
Expecting \\r or \\n"; - return false; - } - if (reached_remainder_state) { - goto final_state; - } - if (!BufferRemainder()) { - goto final_state; - } - SkipEmptyLines(); - if (position_buffer - verification_positions.end_of_last_line > options.buffer_size) { - error_message = "Line does not fit in one buffer. Increase the buffer size."; - return false; - } - verification_positions.end_of_last_line = position_buffer; - start_buffer = position_buffer; - // \n newline, move to value start - if (finished_chunk) { - goto final_state; - } - goto value_start; - } -} -in_quotes: - /* state: in_quotes this state parses the remainder of a quoted value*/ - has_quotes = true; - position_buffer++; - for (; position_buffer < end_buffer; position_buffer++) { - auto c = (*buffer)[position_buffer]; - if (c == options.dialect_options.state_machine_options.quote) { - // quote: move to unquoted state - goto unquote; - } else if (c == options.dialect_options.state_machine_options.escape) { - // escape: store the escaped position and move to handle_escape state - escape_positions.push_back(position_buffer - start_buffer); - goto handle_escape; - } - } - if (!BufferRemainder()) { - if (buffer->buffer->is_last_buffer) { - if (try_add_line) { - return false; - } - // still in quoted state at the end of the file or at the end of a buffer when running multithreaded, error: - throw InvalidInputException("Error in file \"%s\" on line %s: unterminated quotes. 
(%s)", options.file_path, - GetLineNumberStr(linenr, linenr_estimated, buffer->local_batch_index).c_str(), - options.ToString()); - } else { - goto final_state; - } - } else { - position_buffer--; - goto in_quotes; - } - -unquote : { - /* state: unquote: this state handles the state directly after we unquote*/ - // - // in this state we expect either another quote (entering the quoted state again, and escaping the quote) - // or a delimiter/newline, ending the current value and moving on to the next value - position_buffer++; - if (!BufferRemainder()) { - offset = 1; - goto final_state; - } - auto c = (*buffer)[position_buffer]; - if (c == options.dialect_options.state_machine_options.quote && - (options.dialect_options.state_machine_options.escape == '\0' || - options.dialect_options.state_machine_options.escape == options.dialect_options.state_machine_options.quote)) { - // escaped quote, return to quoted state and store escape position - escape_positions.push_back(position_buffer - start_buffer); - goto in_quotes; - } else if (c == options.dialect_options.state_machine_options.delimiter) { - // delimiter, add value - offset = 1; - goto add_value; - } else if (StringUtil::CharacterIsNewline(c)) { - offset = 1; - // FIXME: should this be an assertion? - D_ASSERT(try_add_line || (!try_add_line && column == parse_chunk.ColumnCount() - 1)); - goto add_row; - } else if (position_buffer >= end_buffer) { - // reached end of buffer - offset = 1; - goto final_state; - } else { - error_message = StringUtil::Format( - "Error in file \"%s\" on line %s: quote should be followed by end of value, end of " - "row or another quote. (%s). 
", - options.file_path, GetLineNumberStr(linenr, linenr_estimated, buffer->local_batch_index).c_str(), - options.ToString()); - return false; - } -} -handle_escape : { - /* state: handle_escape */ - // escape should be followed by a quote or another escape character - position_buffer++; - if (!BufferRemainder()) { - goto final_state; - } - if (position_buffer >= buffer_size && buffer->buffer->is_last_buffer) { - error_message = StringUtil::Format( - "Error in file \"%s\" on line %s: neither QUOTE nor ESCAPE is proceeded by ESCAPE. (%s)", options.file_path, - GetLineNumberStr(linenr, linenr_estimated, buffer->local_batch_index).c_str(), options.ToString()); - return false; - } - if ((*buffer)[position_buffer] != options.dialect_options.state_machine_options.quote && - (*buffer)[position_buffer] != options.dialect_options.state_machine_options.escape) { - error_message = StringUtil::Format( - "Error in file \"%s\" on line %s: neither QUOTE nor ESCAPE is proceeded by ESCAPE. (%s)", options.file_path, - GetLineNumberStr(linenr, linenr_estimated, buffer->local_batch_index).c_str(), options.ToString()); - return false; - } - // escape was followed by quote or escape, go back to quoted state - goto in_quotes; -} -final_state : { - /* state: final_stage reached after we finished reading the end_buffer of the csv buffer */ - // reset end buffer - end_buffer = buffer->buffer_end; - if (position_buffer == end_buffer) { - reached_remainder_state = false; - } - if (finished_chunk) { - if (position_buffer >= end_buffer) { - if (position_buffer == end_buffer && StringUtil::CharacterIsNewline((*buffer)[position_buffer - 1]) && - position_buffer < buffer_size) { - // last position is a new line, we still have to go through one more line of this buffer - finished = false; - } else { - finished = true; - } - } - buffer->lines_read += insert_chunk.size(); - return true; - } - // If this is the last buffer, we have to read the last value - if (buffer->buffer->is_last_buffer || 
!buffer->next_buffer || - (buffer->next_buffer && buffer->next_buffer->is_last_buffer)) { - if (column > 0 || start_buffer != position_buffer || try_add_line || - (insert_chunk.data.size() == 1 && start_buffer != position_buffer)) { - // remaining values to be added to the chunk - auto str_value = buffer->GetValue(start_buffer, position_buffer, offset); - if (!AllNewLine(str_value, insert_chunk.data.size()) || offset == 0) { - AddValue(str_value, column, escape_positions, has_quotes, buffer->local_batch_index); - if (try_add_line) { - bool success = column == return_types.size(); - if (success) { - auto cur_linenr = linenr; - AddRow(insert_chunk, column, error_message, buffer->local_batch_index); - success = Flush(insert_chunk, buffer->local_batch_index); - linenr = cur_linenr; - } - parse_chunk.Reset(); - reached_remainder_state = false; - return success; - } else { - VerifyLineLength(position_buffer - line_start, buffer->batch_index); - line_start = position_buffer; - AddRow(insert_chunk, column, error_message, buffer->local_batch_index); - if (position_buffer - verification_positions.end_of_last_line > options.buffer_size) { - error_message = "Line does not fit in one buffer. Increase the buffer size."; - return false; - } - verification_positions.end_of_last_line = position_buffer; - } - } - } - } - // flush the parsed chunk and finalize parsing - if (mode == ParserMode::PARSING) { - Flush(insert_chunk, buffer->local_batch_index); - buffer->lines_read += insert_chunk.size(); - } - if (position_buffer - verification_positions.end_of_last_line > options.buffer_size) { - error_message = "Line does not fit in one buffer. 
Increase the buffer size."; - return false; - } - end_buffer = buffer_size; - SkipEmptyLines(); - end_buffer = buffer->buffer_end; - verification_positions.end_of_last_line = position_buffer; - if (position_buffer >= end_buffer) { - if (position_buffer >= end_buffer) { - if (position_buffer == end_buffer && StringUtil::CharacterIsNewline((*buffer)[position_buffer - 1]) && - position_buffer < buffer_size) { - // last position is a new line, we still have to go through one more line of this buffer - finished = false; - } else { - finished = true; - } - } - } - return true; -}; -} - -void ParallelCSVReader::ParseCSV(DataChunk &insert_chunk) { - string error_message; - if (!TryParseCSV(ParserMode::PARSING, insert_chunk, error_message)) { - throw InvalidInputException(error_message); - } -} - -idx_t ParallelCSVReader::GetLineError(idx_t line_error, idx_t buffer_idx, bool stop_at_first) { - while (true) { - if (buffer->line_info->CanItGetLine(file_idx, buffer_idx)) { - auto cur_start = verification_positions.beginning_of_first_line + buffer->buffer->csv_global_start; - return buffer->line_info->GetLine(buffer_idx, line_error, file_idx, cur_start, false, stop_at_first); - } - } -} - -void ParallelCSVReader::Increment(idx_t buffer_idx) { - return buffer->line_info->Increment(file_idx, buffer_idx); -} - -bool ParallelCSVReader::TryParseCSV(ParserMode mode) { - DataChunk dummy_chunk; - string error_message; - return TryParseCSV(mode, dummy_chunk, error_message); -} - -void ParallelCSVReader::ParseCSV(ParserMode mode) { - DataChunk dummy_chunk; - string error_message; - if (!TryParseCSV(mode, dummy_chunk, error_message)) { - throw InvalidInputException(error_message); - } -} - -bool ParallelCSVReader::TryParseCSV(ParserMode parser_mode, DataChunk &insert_chunk, string &error_message) { - mode = parser_mode; - return TryParseSimpleCSV(insert_chunk, error_message); -} - -} // namespace duckdb - -#endif diff --git a/lib/duckdb-5.cpp b/lib/duckdb-5.cpp deleted file mode 100644 
index 35a2effe..00000000 --- a/lib/duckdb-5.cpp +++ /dev/null @@ -1,20490 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - -namespace duckdb { - -CSVSniffer::CSVSniffer(CSVReaderOptions &options_p, shared_ptr buffer_manager_p, - CSVStateMachineCache &state_machine_cache_p, bool explicit_set_columns_p) - : state_machine_cache(state_machine_cache_p), options(options_p), buffer_manager(std::move(buffer_manager_p)), - explicit_set_columns(explicit_set_columns_p) { - - // Check if any type is BLOB - for (auto &type : options.sql_type_list) { - if (type.id() == LogicalTypeId::BLOB) { - throw InvalidInputException( - "CSV auto-detect for blobs not supported: there may be invalid UTF-8 in the file"); - } - } - - // Initialize Format Candidates - for (const auto &format_template : format_template_candidates) { - auto &logical_type = format_template.first; - best_format_candidates[logical_type].clear(); - } -} - -SnifferResult CSVSniffer::SniffCSV() { - // 1. Dialect Detection - DetectDialect(); - if (explicit_set_columns) { - if (!candidates.empty()) { - options.dialect_options.state_machine_options = candidates[0]->dialect_options.state_machine_options; - options.dialect_options.new_line = candidates[0]->dialect_options.new_line; - } - // We do not need to run type and header detection as these were defined by the user - return SnifferResult(detected_types, names); - } - // 2. Type Detection - DetectTypes(); - // 3. Header Detection - DetectHeader(); - D_ASSERT(best_sql_types_candidates_per_column_idx.size() == names.size()); - // 4. Type Replacement - ReplaceTypes(); - // 5. Type Refinement - RefineTypes(); - // We are done, construct and return the result. 
- - // Set the CSV Options in the reference - options.dialect_options = best_candidate->dialect_options; - options.has_header = best_candidate->dialect_options.header; - options.skip_rows_set = options.dialect_options.skip_rows > 0; - if (options.has_header) { - options.dialect_options.true_start = best_start_with_header; - } else { - options.dialect_options.true_start = best_start_without_header; - } - - // Return the types and names - return SnifferResult(detected_types, names); -} - -} // namespace duckdb - - - -namespace duckdb { - -struct SniffDialect { - inline static void Initialize(CSVStateMachine &machine) { - machine.state = CSVState::STANDARD; - machine.previous_state = CSVState::STANDARD; - machine.pre_previous_state = CSVState::STANDARD; - machine.cur_rows = 0; - machine.column_count = 1; - } - - inline static bool Process(CSVStateMachine &machine, vector &sniffed_column_counts, char current_char, - idx_t current_pos) { - - D_ASSERT(sniffed_column_counts.size() == STANDARD_VECTOR_SIZE); - - if (machine.state == CSVState::INVALID) { - sniffed_column_counts.clear(); - return true; - } - machine.pre_previous_state = machine.previous_state; - machine.previous_state = machine.state; - - machine.state = static_cast( - machine.transition_array[static_cast(machine.state)][static_cast(current_char)]); - - bool carriage_return = machine.previous_state == CSVState::CARRIAGE_RETURN; - machine.column_count += machine.previous_state == CSVState::DELIMITER; - sniffed_column_counts[machine.cur_rows] = machine.column_count; - machine.cur_rows += - machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE; - machine.column_count -= (machine.column_count - 1) * (machine.previous_state == CSVState::RECORD_SEPARATOR); - - // It means our carriage return is actually a record separator - machine.cur_rows += machine.state != CSVState::RECORD_SEPARATOR && carriage_return; - machine.column_count -= - (machine.column_count - 1) * 
(machine.state != CSVState::RECORD_SEPARATOR && carriage_return); - - // Identify what is our line separator - machine.carry_on_separator = - (machine.state == CSVState::RECORD_SEPARATOR && carriage_return) || machine.carry_on_separator; - machine.single_record_separator = ((machine.state != CSVState::RECORD_SEPARATOR && carriage_return) || - (machine.state == CSVState::RECORD_SEPARATOR && !carriage_return)) || - machine.single_record_separator; - if (machine.cur_rows >= STANDARD_VECTOR_SIZE) { - // We sniffed enough rows - return true; - } - return false; - } - inline static void Finalize(CSVStateMachine &machine, vector &sniffed_column_counts) { - if (machine.state == CSVState::INVALID) { - return; - } - if (machine.cur_rows < STANDARD_VECTOR_SIZE && machine.state == CSVState::DELIMITER) { - sniffed_column_counts[machine.cur_rows] = ++machine.column_count; - } - if (machine.cur_rows < STANDARD_VECTOR_SIZE && machine.state != CSVState::EMPTY_LINE) { - sniffed_column_counts[machine.cur_rows++] = machine.column_count; - } - NewLineIdentifier suggested_newline; - if (machine.carry_on_separator) { - if (machine.single_record_separator) { - suggested_newline = NewLineIdentifier::MIX; - } else { - suggested_newline = NewLineIdentifier::CARRY_ON; - } - } else { - suggested_newline = NewLineIdentifier::SINGLE; - } - if (machine.options.dialect_options.new_line == NewLineIdentifier::NOT_SET) { - machine.dialect_options.new_line = suggested_newline; - } else { - if (machine.options.dialect_options.new_line != suggested_newline) { - // Invalidate this whole detection - machine.cur_rows = 0; - } - } - sniffed_column_counts.erase(sniffed_column_counts.begin() + machine.cur_rows, sniffed_column_counts.end()); - } -}; - -void CSVSniffer::GenerateCandidateDetectionSearchSpace(vector &delim_candidates, - vector "erule_candidates, - unordered_map> "e_candidates_map, - unordered_map> &escape_candidates_map) { - if (options.has_delimiter) { - // user provided a delimiter: use that 
delimiter - delim_candidates = {options.dialect_options.state_machine_options.delimiter}; - } else { - // no delimiter provided: try standard/common delimiters - delim_candidates = {',', '|', ';', '\t'}; - } - if (options.has_quote) { - // user provided quote: use that quote rule - quote_candidates_map[(uint8_t)QuoteRule::QUOTES_RFC] = {options.dialect_options.state_machine_options.quote}; - quote_candidates_map[(uint8_t)QuoteRule::QUOTES_OTHER] = {options.dialect_options.state_machine_options.quote}; - quote_candidates_map[(uint8_t)QuoteRule::NO_QUOTES] = {options.dialect_options.state_machine_options.quote}; - } else { - // no quote rule provided: use standard/common quotes - quote_candidates_map[(uint8_t)QuoteRule::QUOTES_RFC] = {'\"'}; - quote_candidates_map[(uint8_t)QuoteRule::QUOTES_OTHER] = {'\"', '\''}; - quote_candidates_map[(uint8_t)QuoteRule::NO_QUOTES] = {'\0'}; - } - if (options.has_escape) { - // user provided escape: use that escape rule - if (options.dialect_options.state_machine_options.escape == '\0') { - quoterule_candidates = {QuoteRule::QUOTES_RFC}; - } else { - quoterule_candidates = {QuoteRule::QUOTES_OTHER}; - } - escape_candidates_map[(uint8_t)quoterule_candidates[0]] = { - options.dialect_options.state_machine_options.escape}; - } else { - // no escape provided: try standard/common escapes - quoterule_candidates = {QuoteRule::QUOTES_RFC, QuoteRule::QUOTES_OTHER, QuoteRule::NO_QUOTES}; - } -} - -void CSVSniffer::GenerateStateMachineSearchSpace(vector> &csv_state_machines, - const vector &delimiter_candidates, - const vector "erule_candidates, - const unordered_map> "e_candidates_map, - const unordered_map> &escape_candidates_map) { - // Generate state machines for all option combinations - for (const auto quoterule : quoterule_candidates) { - const auto "e_candidates = quote_candidates_map.at((uint8_t)quoterule); - for (const auto "e : quote_candidates) { - for (const auto &delimiter : delimiter_candidates) { - const auto &escape_candidates 
= escape_candidates_map.at((uint8_t)quoterule); - for (const auto &escape : escape_candidates) { - D_ASSERT(buffer_manager); - CSVStateMachineOptions state_machine_options(delimiter, quote, escape); - csv_state_machines.emplace_back(make_uniq(options, state_machine_options, - buffer_manager, state_machine_cache)); - } - } - } - } -} - -void CSVSniffer::AnalyzeDialectCandidate(unique_ptr state_machine, idx_t &rows_read, - idx_t &best_consistent_rows, idx_t &prev_padding_count) { - // The sniffed_column_counts variable keeps track of the number of columns found for each row - vector sniffed_column_counts(STANDARD_VECTOR_SIZE); - - state_machine->csv_buffer_iterator.Process(*state_machine, sniffed_column_counts); - idx_t start_row = options.dialect_options.skip_rows; - idx_t consistent_rows = 0; - idx_t num_cols = sniffed_column_counts.empty() ? 0 : sniffed_column_counts[0]; - idx_t padding_count = 0; - bool allow_padding = options.null_padding; - if (sniffed_column_counts.size() > rows_read) { - rows_read = sniffed_column_counts.size(); - } - for (idx_t row = 0; row < sniffed_column_counts.size(); row++) { - if (sniffed_column_counts[row] == num_cols) { - consistent_rows++; - } else if (num_cols < sniffed_column_counts[row] && !options.skip_rows_set) { - // all rows up to this point will need padding - padding_count = 0; - // we use the maximum amount of num_cols that we find - num_cols = sniffed_column_counts[row]; - start_row = row + options.dialect_options.skip_rows; - consistent_rows = 1; - - } else if (num_cols >= sniffed_column_counts[row]) { - // we are missing some columns, we can parse this as long as we add padding - padding_count++; - } - } - - // Calculate the total number of consistent rows after adding padding. - consistent_rows += padding_count; - - // Whether there are more values (rows) available that are consistent, exceeding the current best. 
- bool more_values = (consistent_rows > best_consistent_rows && num_cols >= max_columns_found); - - // If additional padding is required when compared to the previous padding count. - bool require_more_padding = padding_count > prev_padding_count; - - // If less padding is now required when compared to the previous padding count. - bool require_less_padding = padding_count < prev_padding_count; - - // If there was only a single column before, and the new number of columns exceeds that. - bool single_column_before = max_columns_found < 2 && num_cols > max_columns_found; - - // If the number of rows is consistent with the calculated value after accounting for skipped rows and the - // start row. - bool rows_consistent = - start_row + consistent_rows - options.dialect_options.skip_rows == sniffed_column_counts.size(); - - // If there are more than one consistent row. - bool more_than_one_row = (consistent_rows > 1); - - // If there are more than one column. - bool more_than_one_column = (num_cols > 1); - - // If the start position is valid. - bool start_good = !candidates.empty() && (start_row <= candidates.front()->start_row); - - // If padding happened but it is not allowed. - bool invalid_padding = !allow_padding && padding_count > 0; - - // If rows are consistent and no invalid padding happens, this is the best suitable candidate if one of the - // following is valid: - // - There's a single column before. - // - There are more values and no additional padding is required. - // - There's more than one column and less padding is required. 
- if (rows_consistent && - (single_column_before || (more_values && !require_more_padding) || - (more_than_one_column && require_less_padding)) && - !invalid_padding) { - best_consistent_rows = consistent_rows; - max_columns_found = num_cols; - prev_padding_count = padding_count; - state_machine->start_row = start_row; - candidates.clear(); - state_machine->dialect_options.num_cols = num_cols; - candidates.emplace_back(std::move(state_machine)); - return; - } - // If there's more than one row and column, the start is good, rows are consistent, - // no additional padding is required, and there is no invalid padding, and there is not yet a candidate - // with the same quote, we add this state_machine as a suitable candidate. - if (more_than_one_row && more_than_one_column && start_good && rows_consistent && !require_more_padding && - !invalid_padding) { - bool same_quote_is_candidate = false; - for (auto &candidate : candidates) { - if (state_machine->dialect_options.state_machine_options.quote == - candidate->dialect_options.state_machine_options.quote) { - same_quote_is_candidate = true; - } - } - if (!same_quote_is_candidate) { - state_machine->start_row = start_row; - state_machine->dialect_options.num_cols = num_cols; - candidates.emplace_back(std::move(state_machine)); - } - } -} - -bool CSVSniffer::RefineCandidateNextChunk(CSVStateMachine &candidate) { - vector sniffed_column_counts(STANDARD_VECTOR_SIZE); - candidate.csv_buffer_iterator.Process(candidate, sniffed_column_counts); - bool allow_padding = options.null_padding; - - for (idx_t row = 0; row < sniffed_column_counts.size(); row++) { - if (max_columns_found != sniffed_column_counts[row] && !allow_padding) { - return false; - } - } - return true; -} - -void CSVSniffer::RefineCandidates() { - // It's very frequent that more than one dialect can parse a csv file, hence here we run one state machine - // fully on the whole sample dataset, when/if it fails we go to the next one. 
- if (candidates.empty()) { - // No candidates to refine - return; - } - if (candidates.size() == 1 || candidates[0]->csv_buffer_iterator.Finished()) { - // Only one candidate nothing to refine or all candidates already checked - return; - } - for (auto &cur_candidate : candidates) { - for (idx_t i = 1; i <= options.sample_size_chunks; i++) { - bool finished_file = cur_candidate->csv_buffer_iterator.Finished(); - if (finished_file || i == options.sample_size_chunks) { - // we finished the file or our chunk sample successfully: stop - auto successful_candidate = std::move(cur_candidate); - candidates.clear(); - candidates.emplace_back(std::move(successful_candidate)); - return; - } - cur_candidate->cur_rows = 0; - cur_candidate->column_count = 1; - if (!RefineCandidateNextChunk(*cur_candidate)) { - // This candidate failed, move to the next one - break; - } - } - } - candidates.clear(); - return; -} - -// Dialect Detection consists of five steps: -// 1. Generate a search space of all possible dialects -// 2. Generate a state machine for each dialect -// 3. Analyze the first chunk of the file and find the best dialect candidates -// 4. 
Analyze the remaining chunks of the file and find the best dialect candidate -void CSVSniffer::DetectDialect() { - // Variables for Dialect Detection - // Candidates for the delimiter - vector delim_candidates; - // Quote-Rule Candidates - vector quoterule_candidates; - // Candidates for the quote option - unordered_map> quote_candidates_map; - // Candidates for the escape option - unordered_map> escape_candidates_map; - escape_candidates_map[(uint8_t)QuoteRule::QUOTES_RFC] = {'\0', '\"', '\''}; - escape_candidates_map[(uint8_t)QuoteRule::QUOTES_OTHER] = {'\\'}; - escape_candidates_map[(uint8_t)QuoteRule::NO_QUOTES] = {'\0'}; - // Number of rows read - idx_t rows_read = 0; - // Best Number of consistent rows (i.e., presenting all columns) - idx_t best_consistent_rows = 0; - // If padding was necessary (i.e., rows are missing some columns, how many) - idx_t prev_padding_count = 0; - // Vector of CSV State Machines - vector> csv_state_machines; - - // Step 1: Generate search space - GenerateCandidateDetectionSearchSpace(delim_candidates, quoterule_candidates, quote_candidates_map, - escape_candidates_map); - // Step 2: Generate state machines - GenerateStateMachineSearchSpace(csv_state_machines, delim_candidates, quoterule_candidates, quote_candidates_map, - escape_candidates_map); - // Step 3: Analyze all candidates on the first chunk - for (auto &state_machine : csv_state_machines) { - state_machine->Reset(); - AnalyzeDialectCandidate(std::move(state_machine), rows_read, best_consistent_rows, prev_padding_count); - } - // Step 4: Loop over candidates and find if they can still produce good results for the remaining chunks - RefineCandidates(); - // if no dialect candidate was found, we throw an exception - if (candidates.empty()) { - throw InvalidInputException( - "Error in file \"%s\": CSV options could not be auto-detected. 
Consider setting parser options manually.", - options.file_path); - } -} -} // namespace duckdb - - - - -namespace duckdb { - -// Helper function to generate column names -static string GenerateColumnName(const idx_t total_cols, const idx_t col_number, const string &prefix = "column") { - int max_digits = NumericHelper::UnsignedLength(total_cols - 1); - int digits = NumericHelper::UnsignedLength(col_number); - string leading_zeros = string(max_digits - digits, '0'); - string value = to_string(col_number); - return string(prefix + leading_zeros + value); -} - -// Helper function for UTF-8 aware space trimming -static string TrimWhitespace(const string &col_name) { - utf8proc_int32_t codepoint; - auto str = reinterpret_cast(col_name.c_str()); - idx_t size = col_name.size(); - // Find the first character that is not left trimmed - idx_t begin = 0; - while (begin < size) { - auto bytes = utf8proc_iterate(str + begin, size - begin, &codepoint); - D_ASSERT(bytes > 0); - if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { - break; - } - begin += bytes; - } - - // Find the last character that is not right trimmed - idx_t end; - end = begin; - for (auto next = begin; next < col_name.size();) { - auto bytes = utf8proc_iterate(str + next, size - next, &codepoint); - D_ASSERT(bytes > 0); - next += bytes; - if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { - end = next; - } - } - - // return the trimmed string - return col_name.substr(begin, end - begin); -} - -static string NormalizeColumnName(const string &col_name) { - // normalize UTF8 characters to NFKD - auto nfkd = utf8proc_NFKD(reinterpret_cast(col_name.c_str()), col_name.size()); - const string col_name_nfkd = string(const_char_ptr_cast(nfkd), strlen(const_char_ptr_cast(nfkd))); - free(nfkd); - - // only keep ASCII characters 0-9 a-z A-Z and replace spaces with regular whitespace - string col_name_ascii = ""; - for (idx_t i = 0; i < col_name_nfkd.size(); i++) { - if (col_name_nfkd[i] == '_' || 
(col_name_nfkd[i] >= '0' && col_name_nfkd[i] <= '9') || - (col_name_nfkd[i] >= 'A' && col_name_nfkd[i] <= 'Z') || - (col_name_nfkd[i] >= 'a' && col_name_nfkd[i] <= 'z')) { - col_name_ascii += col_name_nfkd[i]; - } else if (StringUtil::CharacterIsSpace(col_name_nfkd[i])) { - col_name_ascii += " "; - } - } - - // trim whitespace and replace remaining whitespace by _ - string col_name_trimmed = TrimWhitespace(col_name_ascii); - string col_name_cleaned = ""; - bool in_whitespace = false; - for (idx_t i = 0; i < col_name_trimmed.size(); i++) { - if (col_name_trimmed[i] == ' ') { - if (!in_whitespace) { - col_name_cleaned += "_"; - in_whitespace = true; - } - } else { - col_name_cleaned += col_name_trimmed[i]; - in_whitespace = false; - } - } - - // don't leave string empty; if not empty, make lowercase - if (col_name_cleaned.empty()) { - col_name_cleaned = "_"; - } else { - col_name_cleaned = StringUtil::Lower(col_name_cleaned); - } - - // prepend _ if name starts with a digit or is a reserved keyword - if (KeywordHelper::IsKeyword(col_name_cleaned) || (col_name_cleaned[0] >= '0' && col_name_cleaned[0] <= '9')) { - col_name_cleaned = "_" + col_name_cleaned; - } - return col_name_cleaned; -} -void CSVSniffer::DetectHeader() { - // information for header detection - bool first_row_consistent = true; - // check if header row is all null and/or consistent with detected column data types - bool first_row_nulls = true; - // This case will fail in dialect detection, so we assert here just for sanity - D_ASSERT(best_candidate->options.null_padding || - best_sql_types_candidates_per_column_idx.size() == best_header_row.size()); - for (idx_t col = 0; col < best_header_row.size(); col++) { - auto dummy_val = best_header_row[col]; - if (!dummy_val.IsNull()) { - first_row_nulls = false; - } - - // try cast to sql_type of column - const auto &sql_type = best_sql_types_candidates_per_column_idx[col].back(); - if (!TryCastValue(*best_candidate, dummy_val, sql_type)) { - 
first_row_consistent = false; - } - } - bool has_header; - if (!best_candidate->options.has_header) { - has_header = !first_row_consistent || first_row_nulls; - } else { - has_header = best_candidate->options.dialect_options.header; - } - // update parser info, and read, generate & set col_names based on previous findings - if (has_header) { - best_candidate->dialect_options.header = true; - case_insensitive_map_t name_collision_count; - - // get header names from CSV - for (idx_t col = 0; col < best_header_row.size(); col++) { - const auto &val = best_header_row[col]; - string col_name = val.ToString(); - - // generate name if field is empty - if (col_name.empty() || val.IsNull()) { - col_name = GenerateColumnName(best_candidate->dialect_options.num_cols, col); - } - - // normalize names or at least trim whitespace - if (best_candidate->options.normalize_names) { - col_name = NormalizeColumnName(col_name); - } else { - col_name = TrimWhitespace(col_name); - } - - // avoid duplicate header names - while (name_collision_count.find(col_name) != name_collision_count.end()) { - name_collision_count[col_name] += 1; - col_name = col_name + "_" + to_string(name_collision_count[col_name]); - } - names.push_back(col_name); - name_collision_count[col_name] = 0; - } - if (best_header_row.size() < best_candidate->dialect_options.num_cols && options.null_padding) { - for (idx_t col = best_header_row.size(); col < best_candidate->dialect_options.num_cols; col++) { - names.push_back(GenerateColumnName(best_candidate->dialect_options.num_cols, col)); - } - } else if (best_header_row.size() < best_candidate->dialect_options.num_cols) { - throw InternalException("Detected header has number of columns inferior to dialect detection"); - } - - } else { - best_candidate->dialect_options.header = false; - for (idx_t col = 0; col < best_candidate->dialect_options.num_cols; col++) { - names.push_back(GenerateColumnName(best_candidate->dialect_options.num_cols, col)); - } - } - - // If the 
user provided names, we must replace our header with the user provided names - for (idx_t i = 0; i < MinValue(names.size(), best_candidate->options.name_list.size()); i++) { - names[i] = best_candidate->options.name_list[i]; - } -} -} // namespace duckdb - - - - - -namespace duckdb { -struct TryCastFloatingOperator { - template - static bool Operation(string_t input) { - T result; - string error_message; - return OP::Operation(input, result, &error_message); - } -}; - -struct TupleSniffing { - idx_t line_number; - idx_t position; - bool set = false; - vector values; -}; - -static bool StartsWithNumericDate(string &separator, const string &value) { - auto begin = value.c_str(); - auto end = begin + value.size(); - - // StrpTimeFormat::Parse will skip whitespace, so we can too - auto field1 = std::find_if_not(begin, end, StringUtil::CharacterIsSpace); - if (field1 == end) { - return false; - } - - // first numeric field must start immediately - if (!StringUtil::CharacterIsDigit(*field1)) { - return false; - } - auto literal1 = std::find_if_not(field1, end, StringUtil::CharacterIsDigit); - if (literal1 == end) { - return false; - } - - // second numeric field must exist - auto field2 = std::find_if(literal1, end, StringUtil::CharacterIsDigit); - if (field2 == end) { - return false; - } - auto literal2 = std::find_if_not(field2, end, StringUtil::CharacterIsDigit); - if (literal2 == end) { - return false; - } - - // third numeric field must exist - auto field3 = std::find_if(literal2, end, StringUtil::CharacterIsDigit); - if (field3 == end) { - return false; - } - - // second literal must match first - if (((field3 - literal2) != (field2 - literal1)) || strncmp(literal1, literal2, (field2 - literal1)) != 0) { - return false; - } - - // copy the literal as the separator, escaping percent signs - separator.clear(); - while (literal1 < field2) { - const auto literal_char = *literal1++; - if (literal_char == '%') { - separator.push_back(literal_char); - } - 
separator.push_back(literal_char); - } - - return true; -} - -string GenerateDateFormat(const string &separator, const char *format_template) { - string format_specifier = format_template; - auto amount_of_dashes = std::count(format_specifier.begin(), format_specifier.end(), '-'); - // All our date formats must have at least one - - D_ASSERT(amount_of_dashes); - string result; - result.reserve(format_specifier.size() - amount_of_dashes + (amount_of_dashes * separator.size())); - for (auto &character : format_specifier) { - if (character == '-') { - result += separator; - } else { - result += character; - } - } - return result; -} - -bool CSVSniffer::TryCastValue(CSVStateMachine &candidate, const Value &value, const LogicalType &sql_type) { - if (value.IsNull()) { - return true; - } - if (candidate.dialect_options.has_format.find(LogicalTypeId::DATE)->second && - sql_type.id() == LogicalTypeId::DATE) { - date_t result; - string error_message; - return candidate.dialect_options.date_format.find(LogicalTypeId::DATE) - ->second.TryParseDate(string_t(StringValue::Get(value)), result, error_message); - } - if (candidate.dialect_options.has_format.find(LogicalTypeId::TIMESTAMP)->second && - sql_type.id() == LogicalTypeId::TIMESTAMP) { - timestamp_t result; - string error_message; - return candidate.dialect_options.date_format.find(LogicalTypeId::TIMESTAMP) - ->second.TryParseTimestamp(string_t(StringValue::Get(value)), result, error_message); - } - if (candidate.options.decimal_separator != "." 
&& (sql_type.id() == LogicalTypeId::DOUBLE)) { - return TryCastFloatingOperator::Operation(StringValue::Get(value)); - } - Value new_value; - string error_message; - return value.TryCastAs(buffer_manager->context, sql_type, new_value, &error_message, true); -} - -void CSVSniffer::SetDateFormat(CSVStateMachine &candidate, const string &format_specifier, - const LogicalTypeId &sql_type) { - candidate.dialect_options.has_format[sql_type] = true; - auto &date_format = candidate.dialect_options.date_format[sql_type]; - date_format.format_specifier = format_specifier; - StrTimeFormat::ParseFormatSpecifier(date_format.format_specifier, date_format); -} - -struct SniffValue { - inline static void Initialize(CSVStateMachine &machine) { - machine.state = CSVState::STANDARD; - machine.previous_state = CSVState::STANDARD; - machine.pre_previous_state = CSVState::STANDARD; - machine.cur_rows = 0; - machine.value = ""; - machine.rows_read = 0; - } - - inline static bool Process(CSVStateMachine &machine, vector &sniffed_values, char current_char, - idx_t current_pos) { - - if ((machine.dialect_options.new_line == NewLineIdentifier::SINGLE && - (current_char == '\r' || current_char == '\n')) || - (machine.dialect_options.new_line == NewLineIdentifier::CARRY_ON && current_char == '\n')) { - machine.rows_read++; - } - - if ((machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE) || - (machine.state != CSVState::RECORD_SEPARATOR && machine.previous_state == CSVState::CARRIAGE_RETURN)) { - sniffed_values[machine.cur_rows].position = machine.line_start_pos; - sniffed_values[machine.cur_rows].set = true; - machine.line_start_pos = current_pos; - } - machine.pre_previous_state = machine.previous_state; - machine.previous_state = machine.state; - machine.state = static_cast( - machine.transition_array[static_cast(machine.state)][static_cast(current_char)]); - - bool carriage_return = machine.previous_state == CSVState::CARRIAGE_RETURN; - if 
(machine.previous_state == CSVState::DELIMITER || - (machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE) || - (machine.state != CSVState::RECORD_SEPARATOR && carriage_return)) { - // Started a new value - // Check if it's UTF-8 - machine.VerifyUTF8(); - if (machine.value.empty() || machine.value == machine.options.null_str) { - // We set empty == null value - sniffed_values[machine.cur_rows].values.push_back(Value(LogicalType::VARCHAR)); - } else { - sniffed_values[machine.cur_rows].values.push_back(Value(machine.value)); - } - sniffed_values[machine.cur_rows].line_number = machine.rows_read; - - machine.value = ""; - } - if (machine.state == CSVState::STANDARD || - (machine.state == CSVState::QUOTED && machine.previous_state == CSVState::QUOTED)) { - machine.value += current_char; - } - machine.cur_rows += - machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE; - // It means our carriage return is actually a record separator - machine.cur_rows += machine.state != CSVState::RECORD_SEPARATOR && carriage_return; - if (machine.cur_rows >= sniffed_values.size()) { - // We sniffed enough rows - return true; - } - return false; - } - - inline static void Finalize(CSVStateMachine &machine, vector &sniffed_values) { - if (machine.cur_rows < sniffed_values.size() && machine.state == CSVState::DELIMITER) { - // Started a new empty value - sniffed_values[machine.cur_rows].values.push_back(Value(machine.value)); - } - if (machine.cur_rows < sniffed_values.size() && machine.state != CSVState::EMPTY_LINE) { - machine.VerifyUTF8(); - sniffed_values[machine.cur_rows].line_number = machine.rows_read; - if (!sniffed_values[machine.cur_rows].set) { - sniffed_values[machine.cur_rows].position = machine.line_start_pos; - sniffed_values[machine.cur_rows].set = true; - } - - sniffed_values[machine.cur_rows++].values.push_back(Value(machine.value)); - } - sniffed_values.erase(sniffed_values.end() - 
(sniffed_values.size() - machine.cur_rows), sniffed_values.end()); - } -}; - -void CSVSniffer::DetectDateAndTimeStampFormats(CSVStateMachine &candidate, - map &has_format_candidates, - map> &format_candidates, - const LogicalType &sql_type, const string &separator, Value &dummy_val) { - // generate date format candidates the first time through - auto &type_format_candidates = format_candidates[sql_type.id()]; - const auto had_format_candidates = has_format_candidates[sql_type.id()]; - if (!has_format_candidates[sql_type.id()]) { - has_format_candidates[sql_type.id()] = true; - // order by preference - auto entry = format_template_candidates.find(sql_type.id()); - if (entry != format_template_candidates.end()) { - const auto &format_template_list = entry->second; - for (const auto &t : format_template_list) { - const auto format_string = GenerateDateFormat(separator, t); - // don't parse ISO 8601 - if (format_string.find("%Y-%m-%d") == string::npos) { - type_format_candidates.emplace_back(format_string); - } - } - } - // initialise the first candidate - candidate.dialect_options.has_format[sql_type.id()] = true; - // all formats are constructed to be valid - SetDateFormat(candidate, type_format_candidates.back(), sql_type.id()); - } - // check all formats and keep the first one that works - StrpTimeFormat::ParseResult result; - auto save_format_candidates = type_format_candidates; - while (!type_format_candidates.empty()) { - // avoid using exceptions for flow control... 
- auto ¤t_format = candidate.dialect_options.date_format[sql_type.id()]; - if (current_format.Parse(StringValue::Get(dummy_val), result)) { - break; - } - // doesn't work - move to the next one - type_format_candidates.pop_back(); - candidate.dialect_options.has_format[sql_type.id()] = (!type_format_candidates.empty()); - if (!type_format_candidates.empty()) { - SetDateFormat(candidate, type_format_candidates.back(), sql_type.id()); - } - } - // if none match, then this is not a value of type sql_type, - if (type_format_candidates.empty()) { - // so restore the candidates that did work. - // or throw them out if they were generated by this value. - if (had_format_candidates) { - type_format_candidates.swap(save_format_candidates); - if (!type_format_candidates.empty()) { - SetDateFormat(candidate, type_format_candidates.back(), sql_type.id()); - } - } else { - has_format_candidates[sql_type.id()] = false; - } - } -} - -void CSVSniffer::DetectTypes() { - idx_t min_varchar_cols = max_columns_found + 1; - vector return_types; - // check which info candidate leads to minimum amount of non-varchar columns... 
- for (auto &candidate : candidates) { - unordered_map> info_sql_types_candidates; - for (idx_t i = 0; i < candidate->dialect_options.num_cols; i++) { - info_sql_types_candidates[i] = candidate->options.auto_type_candidates; - } - map has_format_candidates; - map> format_candidates; - for (const auto &t : format_template_candidates) { - has_format_candidates[t.first] = false; - format_candidates[t.first].clear(); - } - D_ASSERT(candidate->dialect_options.num_cols > 0); - - // Set all return_types to VARCHAR so we can do datatype detection based on VARCHAR values - return_types.clear(); - return_types.assign(candidate->dialect_options.num_cols, LogicalType::VARCHAR); - - // Reset candidate for parsing - candidate->Reset(); - - // Parse chunk and read csv with info candidate - vector tuples(STANDARD_VECTOR_SIZE); - candidate->csv_buffer_iterator.Process(*candidate, tuples); - // Potentially Skip empty rows (I find this dirty, but it is what the original code does) - // The true line where parsing starts in reference to the csv file - idx_t true_line_start = 0; - idx_t true_pos = 0; - // The start point of the tuples - idx_t tuple_true_start = 0; - while (tuple_true_start < tuples.size()) { - if (tuples[tuple_true_start].values.empty() || - (tuples[tuple_true_start].values.size() == 1 && tuples[tuple_true_start].values[0].IsNull())) { - true_line_start = tuples[tuple_true_start].line_number; - true_pos = tuples[tuple_true_start].position; - tuple_true_start++; - } else { - break; - } - } - - // Potentially Skip Notes (I also find this dirty, but it is what the original code does) - while (tuple_true_start < tuples.size()) { - if (tuples[tuple_true_start].values.size() < max_columns_found && !options.null_padding) { - true_line_start = tuples[tuple_true_start].line_number; - true_pos = tuples[tuple_true_start].position; - tuple_true_start++; - } else { - break; - } - } - if (tuple_true_start < tuples.size()) { - true_pos = tuples[tuple_true_start].position; - } - if 
(tuple_true_start > 0) { - tuples.erase(tuples.begin(), tuples.begin() + tuple_true_start); - } - - idx_t row_idx = 0; - if (tuples.size() > 1 && (!options.has_header || (options.has_header && options.dialect_options.header))) { - // This means we have more than one row, hence we can use the first row to detect if we have a header - row_idx = 1; - } - if (!tuples.empty()) { - best_start_without_header = tuples[0].position - true_pos; - } - - // First line where we start our type detection - const idx_t start_idx_detection = row_idx; - for (; row_idx < tuples.size(); row_idx++) { - for (idx_t col = 0; col < tuples[row_idx].values.size(); col++) { - auto &col_type_candidates = info_sql_types_candidates[col]; - // col_type_candidates can't be empty since anything in a CSV file should at least be a string - // and we validate utf-8 compatibility when creating the type - D_ASSERT(!col_type_candidates.empty()); - auto cur_top_candidate = col_type_candidates.back(); - auto dummy_val = tuples[row_idx].values[col]; - // try cast from string to sql_type - while (col_type_candidates.size() > 1) { - const auto &sql_type = col_type_candidates.back(); - // try formatting for date types if the user did not specify one and it starts with numeric values. 
- string separator; - bool has_format_is_set = false; - auto format_iterator = candidate->dialect_options.has_format.find(sql_type.id()); - if (format_iterator != candidate->dialect_options.has_format.end()) { - has_format_is_set = format_iterator->second; - } - if (has_format_candidates.count(sql_type.id()) && - (!has_format_is_set || format_candidates[sql_type.id()].size() > 1) && !dummy_val.IsNull() && - StartsWithNumericDate(separator, StringValue::Get(dummy_val))) { - DetectDateAndTimeStampFormats(*candidate, has_format_candidates, format_candidates, sql_type, - separator, dummy_val); - } - // try cast from string to sql_type - if (TryCastValue(*candidate, dummy_val, sql_type)) { - break; - } else { - if (row_idx != start_idx_detection && cur_top_candidate == LogicalType::BOOLEAN) { - // If we thought this was a boolean value (i.e., T,F, True, False) and it is not, we - // immediately pop to varchar. - while (col_type_candidates.back() != LogicalType::VARCHAR) { - col_type_candidates.pop_back(); - } - break; - } - col_type_candidates.pop_back(); - } - } - } - } - - idx_t varchar_cols = 0; - - for (idx_t col = 0; col < info_sql_types_candidates.size(); col++) { - auto &col_type_candidates = info_sql_types_candidates[col]; - // check number of varchar columns - const auto &col_type = col_type_candidates.back(); - if (col_type == LogicalType::VARCHAR) { - varchar_cols++; - } - } - - // it's good if the dialect creates more non-varchar columns, but only if we sacrifice < 30% of best_num_cols. 
- if (varchar_cols < min_varchar_cols && info_sql_types_candidates.size() > (max_columns_found * 0.7)) { - // we have a new best_options candidate - if (true_line_start > 0) { - // Add empty rows to skip_rows - candidate->dialect_options.skip_rows += true_line_start; - } - best_candidate = std::move(candidate); - min_varchar_cols = varchar_cols; - best_sql_types_candidates_per_column_idx = info_sql_types_candidates; - best_format_candidates = format_candidates; - best_header_row = tuples[0].values; - best_start_with_header = tuples[0].position - true_pos; - } - } - // Assert that it's all good at this point. - D_ASSERT(best_candidate && !best_format_candidates.empty() && !best_header_row.empty()); - - for (const auto &best : best_format_candidates) { - if (!best.second.empty()) { - SetDateFormat(*best_candidate, best.second.back(), best.first); - } - } -} - -} // namespace duckdb - - -namespace duckdb { -struct Parse { - inline static void Initialize(CSVStateMachine &machine) { - machine.state = CSVState::STANDARD; - machine.previous_state = CSVState::STANDARD; - machine.pre_previous_state = CSVState::STANDARD; - - machine.cur_rows = 0; - machine.column_count = 0; - machine.value = ""; - } - - inline static bool Process(CSVStateMachine &machine, DataChunk &parse_chunk, char current_char, idx_t current_pos) { - - machine.pre_previous_state = machine.previous_state; - machine.previous_state = machine.state; - machine.state = static_cast( - machine.transition_array[static_cast(machine.state)][static_cast(current_char)]); - - bool carriage_return = machine.previous_state == CSVState::CARRIAGE_RETURN; - if (machine.previous_state == CSVState::DELIMITER || - (machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE) || - (machine.state != CSVState::RECORD_SEPARATOR && carriage_return)) { - // Started a new value - // Check if it's UTF-8 (Or not?) 
- machine.VerifyUTF8(); - auto &v = parse_chunk.data[machine.column_count++]; - auto parse_data = FlatVector::GetData(v); - auto &validity_mask = FlatVector::Validity(v); - if (machine.value.empty()) { - validity_mask.SetInvalid(machine.cur_rows); - } else { - parse_data[machine.cur_rows] = StringVector::AddStringOrBlob(v, string_t(machine.value)); - } - machine.value = ""; - } - if (((machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE) || - (machine.state != CSVState::RECORD_SEPARATOR && carriage_return)) && - machine.options.null_padding && machine.column_count < parse_chunk.ColumnCount()) { - // It's a new row, check if we need to pad stuff - while (machine.column_count < parse_chunk.ColumnCount()) { - auto &v = parse_chunk.data[machine.column_count++]; - auto &validity_mask = FlatVector::Validity(v); - validity_mask.SetInvalid(machine.cur_rows); - } - } - if (machine.state == CSVState::STANDARD || - (machine.state == CSVState::QUOTED && machine.previous_state == CSVState::QUOTED)) { - machine.value += current_char; - } - machine.cur_rows += - machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE; - machine.column_count -= machine.column_count * (machine.previous_state == CSVState::RECORD_SEPARATOR); - - // It means our carriage return is actually a record separator - machine.cur_rows += machine.state != CSVState::RECORD_SEPARATOR && carriage_return; - machine.column_count -= machine.column_count * (machine.state != CSVState::RECORD_SEPARATOR && carriage_return); - - if (machine.cur_rows >= STANDARD_VECTOR_SIZE) { - // We sniffed enough rows - return true; - } - return false; - } - - inline static void Finalize(CSVStateMachine &machine, DataChunk &parse_chunk) { - if (machine.cur_rows < STANDARD_VECTOR_SIZE && machine.state != CSVState::EMPTY_LINE) { - machine.VerifyUTF8(); - auto &v = parse_chunk.data[machine.column_count++]; - auto parse_data = FlatVector::GetData(v); - if 
(machine.value.empty()) { - auto &validity_mask = FlatVector::Validity(v); - validity_mask.SetInvalid(machine.cur_rows); - } else { - parse_data[machine.cur_rows] = StringVector::AddStringOrBlob(v, string_t(machine.value)); - } - while (machine.column_count < parse_chunk.ColumnCount()) { - auto &v_pad = parse_chunk.data[machine.column_count++]; - auto &validity_mask = FlatVector::Validity(v_pad); - validity_mask.SetInvalid(machine.cur_rows); - } - machine.cur_rows++; - } - parse_chunk.SetCardinality(machine.cur_rows); - } -}; - -bool CSVSniffer::TryCastVector(Vector &parse_chunk_col, idx_t size, const LogicalType &sql_type) { - // try vector-cast from string to sql_type - Vector dummy_result(sql_type); - if (best_candidate->dialect_options.has_format[LogicalTypeId::DATE] && sql_type == LogicalTypeId::DATE) { - // use the date format to cast the chunk - string error_message; - idx_t line_error; - return BaseCSVReader::TryCastDateVector(best_candidate->dialect_options.date_format, parse_chunk_col, - dummy_result, size, error_message, line_error); - } - if (best_candidate->dialect_options.has_format[LogicalTypeId::TIMESTAMP] && sql_type == LogicalTypeId::TIMESTAMP) { - // use the timestamp format to cast the chunk - string error_message; - return BaseCSVReader::TryCastTimestampVector(best_candidate->dialect_options.date_format, parse_chunk_col, - dummy_result, size, error_message); - } - // target type is not varchar: perform a cast - string error_message; - return VectorOperations::DefaultTryCast(parse_chunk_col, dummy_result, size, &error_message, true); -} - -void CSVSniffer::RefineTypes() { - // if data types were provided, exit here if number of columns does not match - detected_types.assign(best_candidate->dialect_options.num_cols, LogicalType::VARCHAR); - if (best_candidate->options.all_varchar) { - // return all types varchar - return; - } - DataChunk parse_chunk; - parse_chunk.Initialize(BufferAllocator::Get(buffer_manager->context), detected_types, 
STANDARD_VECTOR_SIZE); - for (idx_t i = 1; i < best_candidate->options.sample_size_chunks; i++) { - bool finished_file = best_candidate->csv_buffer_iterator.Finished(); - if (finished_file) { - // we finished the file: stop - // set sql types - detected_types.clear(); - for (idx_t column_idx = 0; column_idx < best_sql_types_candidates_per_column_idx.size(); column_idx++) { - LogicalType d_type = best_sql_types_candidates_per_column_idx[column_idx].back(); - if (best_sql_types_candidates_per_column_idx[column_idx].size() == - best_candidate->options.auto_type_candidates.size()) { - d_type = LogicalType::VARCHAR; - } - detected_types.push_back(d_type); - } - return; - } - best_candidate->csv_buffer_iterator.Process(*best_candidate, parse_chunk); - for (idx_t col = 0; col < parse_chunk.ColumnCount(); col++) { - vector &col_type_candidates = best_sql_types_candidates_per_column_idx[col]; - bool is_bool_type = col_type_candidates.back() == LogicalType::BOOLEAN; - while (col_type_candidates.size() > 1) { - const auto &sql_type = col_type_candidates.back(); - // narrow down the date formats - if (best_format_candidates.count(sql_type.id())) { - auto &best_type_format_candidates = best_format_candidates[sql_type.id()]; - auto save_format_candidates = best_type_format_candidates; - while (!best_type_format_candidates.empty()) { - if (TryCastVector(parse_chunk.data[col], parse_chunk.size(), sql_type)) { - break; - } - // doesn't work - move to the next one - best_type_format_candidates.pop_back(); - best_candidate->dialect_options.has_format[sql_type.id()] = - (!best_type_format_candidates.empty()); - if (!best_type_format_candidates.empty()) { - SetDateFormat(*best_candidate, best_type_format_candidates.back(), sql_type.id()); - } - } - // if none match, then this is not a column of type sql_type, - if (best_type_format_candidates.empty()) { - // so restore the candidates that did work. 
- best_type_format_candidates.swap(save_format_candidates); - if (!best_type_format_candidates.empty()) { - SetDateFormat(*best_candidate, best_type_format_candidates.back(), sql_type.id()); - } - } - } - if (TryCastVector(parse_chunk.data[col], parse_chunk.size(), sql_type)) { - break; - } else { - if (col_type_candidates.back() == LogicalType::BOOLEAN && is_bool_type) { - // If we thought this was a boolean value (i.e., T,F, True, False) and it is not, we - // immediately pop to varchar. - while (col_type_candidates.back() != LogicalType::VARCHAR) { - col_type_candidates.pop_back(); - } - break; - } - col_type_candidates.pop_back(); - } - } - } - // reset parse chunk for the next iteration - parse_chunk.Reset(); - } - detected_types.clear(); - // set sql types - for (idx_t column_idx = 0; column_idx < best_sql_types_candidates_per_column_idx.size(); column_idx++) { - LogicalType d_type = best_sql_types_candidates_per_column_idx[column_idx].back(); - if (best_sql_types_candidates_per_column_idx[column_idx].size() == - best_candidate->options.auto_type_candidates.size()) { - d_type = LogicalType::VARCHAR; - } - detected_types.push_back(d_type); - } -} -} // namespace duckdb - - - -namespace duckdb { -void CSVSniffer::ReplaceTypes() { - if (best_candidate->options.sql_type_list.empty()) { - return; - } - // user-defined types were supplied for certain columns - // override the types - if (!best_candidate->options.sql_types_per_column.empty()) { - // types supplied as name -> value map - idx_t found = 0; - for (idx_t i = 0; i < names.size(); i++) { - auto it = best_candidate->options.sql_types_per_column.find(names[i]); - if (it != best_candidate->options.sql_types_per_column.end()) { - best_sql_types_candidates_per_column_idx[i] = {best_candidate->options.sql_type_list[it->second]}; - found++; - } - } - if (!best_candidate->options.file_options.union_by_name && - found < best_candidate->options.sql_types_per_column.size()) { - string error_msg = 
BufferedCSVReader::ColumnTypesError(options.sql_types_per_column, names); - if (!error_msg.empty()) { - throw BinderException(error_msg); - } - } - return; - } - // types supplied as list - if (names.size() < best_candidate->options.sql_type_list.size()) { - throw BinderException("read_csv: %d types were provided, but CSV file only has %d columns", - best_candidate->options.sql_type_list.size(), names.size()); - } - for (idx_t i = 0; i < best_candidate->options.sql_type_list.size(); i++) { - best_sql_types_candidates_per_column_idx[i] = {best_candidate->options.sql_type_list[i]}; - } -} -} // namespace duckdb - - - - -namespace duckdb { - -PhysicalFilter::PhysicalFilter(vector types, vector> select_list, - idx_t estimated_cardinality) - : CachingPhysicalOperator(PhysicalOperatorType::FILTER, std::move(types), estimated_cardinality) { - D_ASSERT(select_list.size() > 0); - if (select_list.size() > 1) { - // create a big AND out of the expressions - auto conjunction = make_uniq(ExpressionType::CONJUNCTION_AND); - for (auto &expr : select_list) { - conjunction->children.push_back(std::move(expr)); - } - expression = std::move(conjunction); - } else { - expression = std::move(select_list[0]); - } -} - -class FilterState : public CachingOperatorState { -public: - explicit FilterState(ExecutionContext &context, Expression &expr) - : executor(context.client, expr), sel(STANDARD_VECTOR_SIZE) { - } - - ExpressionExecutor executor; - SelectionVector sel; - -public: - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op, executor, "filter", 0); - } -}; - -unique_ptr PhysicalFilter::GetOperatorState(ExecutionContext &context) const { - return make_uniq(context, *expression); -} - -OperatorResultType PhysicalFilter::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast(); - idx_t result_count = 
state.executor.SelectExpression(input, state.sel); - if (result_count == input.size()) { - // nothing was filtered: skip adding any selection vectors - chunk.Reference(input); - } else { - chunk.Slice(input, state.sel, result_count); - } - return OperatorResultType::NEED_MORE_INPUT; -} - -string PhysicalFilter::ParamsToString() const { - auto result = expression->GetName(); - result += "\n[INFOSEPARATOR]\n"; - result += StringUtil::Format("EC: %llu", estimated_cardinality); - return result; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -PhysicalBatchCollector::PhysicalBatchCollector(PreparedStatementData &data) : PhysicalResultCollector(data) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class BatchCollectorGlobalState : public GlobalSinkState { -public: - BatchCollectorGlobalState(ClientContext &context, const PhysicalBatchCollector &op) : data(context, op.types) { - } - - mutex glock; - BatchedDataCollection data; - unique_ptr result; -}; - -class BatchCollectorLocalState : public LocalSinkState { -public: - BatchCollectorLocalState(ClientContext &context, const PhysicalBatchCollector &op) : data(context, op.types) { - } - - BatchedDataCollection data; -}; - -SinkResultType PhysicalBatchCollector::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &state = input.local_state.Cast(); - state.data.Append(chunk, state.partition_info.batch_index.GetIndex()); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalBatchCollector::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &state = input.local_state.Cast(); - - lock_guard lock(gstate.glock); - gstate.data.Merge(state.data); - - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType 
PhysicalBatchCollector::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - auto collection = gstate.data.FetchCollection(); - D_ASSERT(collection); - auto result = make_uniq(statement_type, properties, names, std::move(collection), - context.GetClientProperties()); - gstate.result = std::move(result); - return SinkFinalizeType::READY; -} - -unique_ptr PhysicalBatchCollector::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - -unique_ptr PhysicalBatchCollector::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr PhysicalBatchCollector::GetResult(GlobalSinkState &state) { - auto &gstate = state.Cast(); - D_ASSERT(gstate.result); - return std::move(gstate.result); -} - -} // namespace duckdb - - - - -namespace duckdb { - -PhysicalExecute::PhysicalExecute(PhysicalOperator &plan) - : PhysicalOperator(PhysicalOperatorType::EXECUTE, plan.types, -1), plan(plan) { -} - -vector> PhysicalExecute::GetChildren() const { - return {plan}; -} - -void PhysicalExecute::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - // EXECUTE statement: build pipeline on child - meta_pipeline.Build(plan); -} - -} // namespace duckdb - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class ExplainAnalyzeStateGlobalState : public GlobalSinkState { -public: - string analyzed_plan; -}; - -SinkResultType PhysicalExplainAnalyze::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - return SinkResultType::NEED_MORE_INPUT; -} - -SinkFinalizeType PhysicalExplainAnalyze::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = 
input.global_state.Cast(); - auto &profiler = QueryProfiler::Get(context); - gstate.analyzed_plan = profiler.ToString(); - return SinkFinalizeType::READY; -} - -unique_ptr PhysicalExplainAnalyze::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(); -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalExplainAnalyze::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gstate = sink_state->Cast(); - - chunk.SetValue(0, 0, Value("analyzed_plan")); - chunk.SetValue(1, 0, Value(gstate.analyzed_plan)); - chunk.SetCardinality(1); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -PhysicalLimit::PhysicalLimit(vector types, idx_t limit, idx_t offset, - unique_ptr limit_expression, unique_ptr offset_expression, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::LIMIT, std::move(types), estimated_cardinality), limit_value(limit), - offset_value(offset), limit_expression(std::move(limit_expression)), - offset_expression(std::move(offset_expression)) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class LimitGlobalState : public GlobalSinkState { -public: - explicit LimitGlobalState(ClientContext &context, const PhysicalLimit &op) : data(context, op.types, true) { - limit = 0; - offset = 0; - } - - mutex glock; - idx_t limit; - idx_t offset; - BatchedDataCollection data; -}; - -class LimitLocalState : public LocalSinkState { -public: - explicit LimitLocalState(ClientContext &context, const PhysicalLimit &op) - : current_offset(0), data(context, op.types, true) { - this->limit = op.limit_expression ? 
DConstants::INVALID_INDEX : op.limit_value; - this->offset = op.offset_expression ? DConstants::INVALID_INDEX : op.offset_value; - } - - idx_t current_offset; - idx_t limit; - idx_t offset; - BatchedDataCollection data; -}; - -unique_ptr PhysicalLimit::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr PhysicalLimit::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - -bool PhysicalLimit::ComputeOffset(ExecutionContext &context, DataChunk &input, idx_t &limit, idx_t &offset, - idx_t current_offset, idx_t &max_element, Expression *limit_expression, - Expression *offset_expression) { - if (limit != DConstants::INVALID_INDEX && offset != DConstants::INVALID_INDEX) { - max_element = limit + offset; - if ((limit == 0 || current_offset >= max_element) && !(limit_expression || offset_expression)) { - return false; - } - } - - // get the next chunk from the child - if (limit == DConstants::INVALID_INDEX) { - limit = 1ULL << 62ULL; - Value val = GetDelimiter(context, input, limit_expression); - if (!val.IsNull()) { - limit = val.GetValue(); - } - if (limit > 1ULL << 62ULL) { - throw BinderException("Max value %lld for LIMIT/OFFSET is %lld", limit, 1ULL << 62ULL); - } - } - if (offset == DConstants::INVALID_INDEX) { - offset = 0; - Value val = GetDelimiter(context, input, offset_expression); - if (!val.IsNull()) { - offset = val.GetValue(); - } - if (offset > 1ULL << 62ULL) { - throw BinderException("Max value %lld for LIMIT/OFFSET is %lld", offset, 1ULL << 62ULL); - } - } - max_element = limit + offset; - if (limit == 0 || current_offset >= max_element) { - return false; - } - return true; -} - -SinkResultType PhysicalLimit::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - - D_ASSERT(chunk.size() > 0); - auto &state = input.local_state.Cast(); - auto &limit = state.limit; - auto &offset = state.offset; - - idx_t max_element; - if 
(!ComputeOffset(context, chunk, limit, offset, state.current_offset, max_element, limit_expression.get(), - offset_expression.get())) { - return SinkResultType::FINISHED; - } - auto max_cardinality = max_element - state.current_offset; - if (max_cardinality < chunk.size()) { - chunk.SetCardinality(max_cardinality); - } - state.data.Append(chunk, state.partition_info.batch_index.GetIndex()); - state.current_offset += chunk.size(); - if (state.current_offset == max_element) { - return SinkResultType::FINISHED; - } - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalLimit::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &state = input.local_state.Cast(); - - lock_guard lock(gstate.glock); - gstate.limit = state.limit; - gstate.offset = state.offset; - gstate.data.Merge(state.data); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class LimitSourceState : public GlobalSourceState { -public: - LimitSourceState() { - initialized = false; - current_offset = 0; - } - - bool initialized; - idx_t current_offset; - BatchedChunkScanState scan_state; -}; - -unique_ptr PhysicalLimit::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(); -} - -SourceResultType PhysicalLimit::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - auto &gstate = sink_state->Cast(); - auto &state = input.global_state.Cast(); - while (state.current_offset < gstate.limit + gstate.offset) { - if (!state.initialized) { - gstate.data.InitializeScan(state.scan_state); - state.initialized = true; - } - gstate.data.Scan(state.scan_state, chunk); - if (chunk.size() == 0) { - return SourceResultType::FINISHED; - } - if (HandleOffset(chunk, state.current_offset, gstate.offset, 
gstate.limit)) { - break; - } - } - - return chunk.size() > 0 ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; -} - -bool PhysicalLimit::HandleOffset(DataChunk &input, idx_t ¤t_offset, idx_t offset, idx_t limit) { - idx_t max_element = limit + offset; - if (limit == DConstants::INVALID_INDEX) { - max_element = DConstants::INVALID_INDEX; - } - idx_t input_size = input.size(); - if (current_offset < offset) { - // we are not yet at the offset point - if (current_offset + input.size() > offset) { - // however we will reach it in this chunk - // we have to copy part of the chunk with an offset - idx_t start_position = offset - current_offset; - auto chunk_count = MinValue(limit, input.size() - start_position); - SelectionVector sel(STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < chunk_count; i++) { - sel.set_index(i, start_position + i); - } - // set up a slice of the input chunks - input.Slice(input, sel, chunk_count); - } else { - current_offset += input_size; - return false; - } - } else { - // have to copy either the entire chunk or part of it - idx_t chunk_count; - if (current_offset + input.size() >= max_element) { - // have to limit the count of the chunk - chunk_count = max_element - current_offset; - } else { - // we copy the entire chunk - chunk_count = input.size(); - } - // instead of copying we just change the pointer in the current chunk - input.Reference(input); - input.SetCardinality(chunk_count); - } - - current_offset += input_size; - return true; -} - -Value PhysicalLimit::GetDelimiter(ExecutionContext &context, DataChunk &input, Expression *expr) { - DataChunk limit_chunk; - vector types {expr->return_type}; - auto &allocator = Allocator::Get(context.client); - limit_chunk.Initialize(allocator, types); - ExpressionExecutor limit_executor(context.client, expr); - auto input_size = input.size(); - input.SetCardinality(1); - limit_executor.Execute(input, limit_chunk); - input.SetCardinality(input_size); - auto limit_value = 
limit_chunk.GetValue(0, 0); - return limit_value; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class LimitPercentGlobalState : public GlobalSinkState { -public: - explicit LimitPercentGlobalState(ClientContext &context, const PhysicalLimitPercent &op) - : current_offset(0), data(context, op.GetTypes()) { - if (!op.limit_expression) { - this->limit_percent = op.limit_percent; - is_limit_percent_delimited = true; - } else { - this->limit_percent = 100.0; - } - - if (!op.offset_expression) { - this->offset = op.offset_value; - is_offset_delimited = true; - } else { - this->offset = 0; - } - } - - idx_t current_offset; - double limit_percent; - idx_t offset; - ColumnDataCollection data; - - bool is_limit_percent_delimited = false; - bool is_offset_delimited = false; -}; - -unique_ptr PhysicalLimitPercent::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -SinkResultType PhysicalLimitPercent::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - D_ASSERT(chunk.size() > 0); - auto &state = input.global_state.Cast(); - auto &limit_percent = state.limit_percent; - auto &offset = state.offset; - - // get the next chunk from the child - if (!state.is_limit_percent_delimited) { - Value val = PhysicalLimit::GetDelimiter(context, chunk, limit_expression.get()); - if (!val.IsNull()) { - limit_percent = val.GetValue(); - } - if (limit_percent < 0.0) { - throw BinderException("Percentage value(%f) can't be negative", limit_percent); - } - state.is_limit_percent_delimited = true; - } - if (!state.is_offset_delimited) { - Value val = PhysicalLimit::GetDelimiter(context, chunk, offset_expression.get()); - if (!val.IsNull()) { - offset = val.GetValue(); - } - if (offset > 1ULL << 62ULL) { - throw BinderException("Max value 
%lld for LIMIT/OFFSET is %lld", offset, 1ULL << 62ULL); - } - state.is_offset_delimited = true; - } - - if (!PhysicalLimit::HandleOffset(chunk, state.current_offset, offset, DConstants::INVALID_INDEX)) { - return SinkResultType::NEED_MORE_INPUT; - } - - state.data.Append(chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class LimitPercentOperatorState : public GlobalSourceState { -public: - explicit LimitPercentOperatorState(const PhysicalLimitPercent &op) - : limit(DConstants::INVALID_INDEX), current_offset(0) { - D_ASSERT(op.sink_state); - auto &gstate = op.sink_state->Cast(); - gstate.data.InitializeScan(scan_state); - } - - ColumnDataScanState scan_state; - idx_t limit; - idx_t current_offset; -}; - -unique_ptr PhysicalLimitPercent::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -SourceResultType PhysicalLimitPercent::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gstate = sink_state->Cast(); - auto &state = input.global_state.Cast(); - auto &percent_limit = gstate.limit_percent; - auto &offset = gstate.offset; - auto &limit = state.limit; - auto ¤t_offset = state.current_offset; - - if (gstate.is_limit_percent_delimited && limit == DConstants::INVALID_INDEX) { - idx_t count = gstate.data.Count(); - if (count > 0) { - count += offset; - } - if (Value::IsNan(percent_limit) || percent_limit < 0 || percent_limit > 100) { - throw OutOfRangeException("Limit percent out of range, should be between 0% and 100%"); - } - double limit_dbl = percent_limit / 100 * count; - if (limit_dbl > count) { - limit = count; - } else { - limit = idx_t(limit_dbl); - } - if (limit == 0) { - return SourceResultType::FINISHED; - } - } - - if (current_offset >= limit) { - return SourceResultType::FINISHED; - } - if 
(!gstate.data.Scan(state.scan_state, chunk)) { - return SourceResultType::FINISHED; - } - - PhysicalLimit::HandleOffset(chunk, current_offset, 0, limit); - - return SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb - - - -namespace duckdb { - -SourceResultType PhysicalLoad::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - if (info->load_type == LoadType::INSTALL || info->load_type == LoadType::FORCE_INSTALL) { - ExtensionHelper::InstallExtension(context.client, info->filename, info->load_type == LoadType::FORCE_INSTALL, - info->repository); - } else { - ExtensionHelper::LoadExternalExtension(context.client, info->filename); - } - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -PhysicalMaterializedCollector::PhysicalMaterializedCollector(PreparedStatementData &data, bool parallel) - : PhysicalResultCollector(data), parallel(parallel) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class MaterializedCollectorGlobalState : public GlobalSinkState { -public: - mutex glock; - unique_ptr collection; - shared_ptr context; -}; - -class MaterializedCollectorLocalState : public LocalSinkState { -public: - unique_ptr collection; - ColumnDataAppendState append_state; -}; - -SinkResultType PhysicalMaterializedCollector::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - lstate.collection->Append(lstate.append_state, chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalMaterializedCollector::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - if (lstate.collection->Count() == 0) { - return SinkCombineResultType::FINISHED; - } 
- - lock_guard l(gstate.glock); - if (!gstate.collection) { - gstate.collection = std::move(lstate.collection); - } else { - gstate.collection->Combine(*lstate.collection); - } - - return SinkCombineResultType::FINISHED; -} - -unique_ptr PhysicalMaterializedCollector::GetGlobalSinkState(ClientContext &context) const { - auto state = make_uniq(); - state->context = context.shared_from_this(); - return std::move(state); -} - -unique_ptr PhysicalMaterializedCollector::GetLocalSinkState(ExecutionContext &context) const { - auto state = make_uniq(); - state->collection = make_uniq(Allocator::DefaultAllocator(), types); - state->collection->InitializeAppend(state->append_state); - return std::move(state); -} - -unique_ptr PhysicalMaterializedCollector::GetResult(GlobalSinkState &state) { - auto &gstate = state.Cast(); - if (!gstate.collection) { - gstate.collection = make_uniq(Allocator::DefaultAllocator(), types); - } - auto result = make_uniq(statement_type, properties, names, std::move(gstate.collection), - gstate.context->GetClientProperties()); - return std::move(result); -} - -bool PhysicalMaterializedCollector::ParallelSink() const { - return parallel; -} - -bool PhysicalMaterializedCollector::SinkOrderDependent() const { - return true; -} - -} // namespace duckdb - - -namespace duckdb { - -SourceResultType PhysicalPragma::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &client = context.client; - FunctionParameters parameters {info.parameters, info.named_parameters}; - function.function(client, parameters); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - -namespace duckdb { - -SourceResultType PhysicalPrepare::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &client = context.client; - - // store the prepared statement in the context - ClientData::Get(client).prepared_statements[name] = prepared; - - return SourceResultType::FINISHED; -} - 
-} // namespace duckdb - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class SampleGlobalSinkState : public GlobalSinkState { -public: - explicit SampleGlobalSinkState(Allocator &allocator, SampleOptions &options) { - if (options.is_percentage) { - auto percentage = options.sample_size.GetValue(); - if (percentage == 0) { - return; - } - sample = make_uniq(allocator, percentage, options.seed); - } else { - auto size = options.sample_size.GetValue(); - if (size == 0) { - return; - } - sample = make_uniq(allocator, size, options.seed); - } - } - - //! The lock for updating the global aggregate state - mutex lock; - //! The reservoir sample - unique_ptr sample; -}; - -unique_ptr PhysicalReservoirSample::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(Allocator::Get(context), *options); -} - -SinkResultType PhysicalReservoirSample::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - if (!gstate.sample) { - return SinkResultType::FINISHED; - } - // we implement reservoir sampling without replacement and exponential jumps here - // the algorithm is adopted from the paper Weighted random sampling with a reservoir by Pavlos S. Efraimidis et al. 
- // note that the original algorithm is about weighted sampling; this is a simplified approach for uniform sampling - lock_guard glock(gstate.lock); - gstate.sample->AddToReservoir(chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalReservoirSample::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &sink = this->sink_state->Cast(); - if (!sink.sample) { - return SourceResultType::FINISHED; - } - auto sample_chunk = sink.sample->GetChunk(); - if (!sample_chunk) { - return SourceResultType::FINISHED; - } - chunk.Move(*sample_chunk); - - return SourceResultType::HAVE_MORE_OUTPUT; -} - -string PhysicalReservoirSample::ParamsToString() const { - return options->sample_size.ToString() + (options->is_percentage ? "%" : " rows"); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -void PhysicalReset::ResetExtensionVariable(ExecutionContext &context, DBConfig &config, - ExtensionOption &extension_option) const { - if (extension_option.set_function) { - extension_option.set_function(context.client, scope, extension_option.default_value); - } - if (scope == SetScope::GLOBAL) { - config.ResetOption(name); - } else { - auto &client_config = ClientConfig::GetConfig(context.client); - client_config.set_variables[name] = extension_option.default_value; - } -} - -SourceResultType PhysicalReset::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - auto &config = DBConfig::GetConfig(context.client); - if (config.options.lock_configuration) { - throw InvalidInputException("Cannot reset configuration option \"%s\" - the configuration has been locked", - name); - } - auto option = DBConfig::GetOptionByName(name); - if (!option) { - // check if this is an extra extension variable - auto entry = 
config.extension_parameters.find(name); - if (entry == config.extension_parameters.end()) { - Catalog::AutoloadExtensionByConfigName(context.client, name); - entry = config.extension_parameters.find(name); - D_ASSERT(entry != config.extension_parameters.end()); - } - ResetExtensionVariable(context, config, entry->second); - return SourceResultType::FINISHED; - } - - // Transform scope - SetScope variable_scope = scope; - if (variable_scope == SetScope::AUTOMATIC) { - if (option->set_local) { - variable_scope = SetScope::SESSION; - } else { - D_ASSERT(option->set_global); - variable_scope = SetScope::GLOBAL; - } - } - - switch (variable_scope) { - case SetScope::GLOBAL: { - if (!option->set_global) { - throw CatalogException("option \"%s\" cannot be reset globally", name); - } - auto &db = DatabaseInstance::GetDatabase(context.client); - config.ResetOption(&db, *option); - break; - } - case SetScope::SESSION: - if (!option->reset_local) { - throw CatalogException("option \"%s\" cannot be reset locally", name); - } - option->reset_local(context.client); - break; - default: - throw InternalException("Unsupported SetScope for variable"); - } - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -PhysicalResultCollector::PhysicalResultCollector(PreparedStatementData &data) - : PhysicalOperator(PhysicalOperatorType::RESULT_COLLECTOR, {LogicalType::BOOLEAN}, 0), - statement_type(data.statement_type), properties(data.properties), plan(*data.plan), names(data.names) { - this->types = data.types; -} - -unique_ptr PhysicalResultCollector::GetResultCollector(ClientContext &context, - PreparedStatementData &data) { - if (!PhysicalPlanGenerator::PreserveInsertionOrder(context, *data.plan)) { - // the plan is not order preserving, so we just use the parallel materialized collector - return make_uniq_base(data, true); - } else if (!PhysicalPlanGenerator::UseBatchIndex(context, *data.plan)) { - // the plan is order preserving, 
but we cannot use the batch index: use a single-threaded result collector - return make_uniq_base(data, false); - } else { - // we care about maintaining insertion order and the sources all support batch indexes - // use a batch collector - return make_uniq_base(data); - } -} - -vector> PhysicalResultCollector::GetChildren() const { - return {plan}; -} - -void PhysicalResultCollector::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - // operator is a sink, build a pipeline - sink_state.reset(); - - D_ASSERT(children.empty()); - - // single operator: the operator becomes the data source of the current pipeline - auto &state = meta_pipeline.GetState(); - state.SetPipelineSource(current, *this); - - // we create a new pipeline starting from the child - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - child_meta_pipeline.Build(plan); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -void PhysicalSet::SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const string &name, - SetScope scope, const Value &value) { - auto &config = DBConfig::GetConfig(context); - auto &target_type = extension_option.type; - Value target_value = value.CastAs(context, target_type); - if (extension_option.set_function) { - extension_option.set_function(context, scope, target_value); - } - if (scope == SetScope::GLOBAL) { - config.SetOption(name, std::move(target_value)); - } else { - auto &client_config = ClientConfig::GetConfig(context); - client_config.set_variables[name] = std::move(target_value); - } -} - -SourceResultType PhysicalSet::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - auto &config = DBConfig::GetConfig(context.client); - if (config.options.lock_configuration) { - throw InvalidInputException("Cannot change configuration option \"%s\" - the configuration has been locked", - name); - } - auto option = DBConfig::GetOptionByName(name); - if (!option) 
{ - // check if this is an extra extension variable - auto entry = config.extension_parameters.find(name); - if (entry == config.extension_parameters.end()) { - Catalog::AutoloadExtensionByConfigName(context.client, name); - entry = config.extension_parameters.find(name); - D_ASSERT(entry != config.extension_parameters.end()); - } - SetExtensionVariable(context.client, entry->second, name, scope, value); - return SourceResultType::FINISHED; - } - SetScope variable_scope = scope; - if (variable_scope == SetScope::AUTOMATIC) { - if (option->set_local) { - variable_scope = SetScope::SESSION; - } else { - D_ASSERT(option->set_global); - variable_scope = SetScope::GLOBAL; - } - } - - Value input_val = value.CastAs(context.client, option->parameter_type); - switch (variable_scope) { - case SetScope::GLOBAL: { - if (!option->set_global) { - throw CatalogException("option \"%s\" cannot be set globally", name); - } - auto &db = DatabaseInstance::GetDatabase(context.client); - auto &config = DBConfig::GetConfig(context.client); - config.SetOption(&db, *option, input_val); - break; - } - case SetScope::SESSION: - if (!option->set_local) { - throw CatalogException("option \"%s\" cannot be set locally", name); - } - option->set_local(context.client, input_val); - break; - default: - throw InternalException("Unsupported SetScope for variable"); - } - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - -namespace duckdb { - -PhysicalStreamingLimit::PhysicalStreamingLimit(vector types, idx_t limit, idx_t offset, - unique_ptr limit_expression, - unique_ptr offset_expression, idx_t estimated_cardinality, - bool parallel) - : PhysicalOperator(PhysicalOperatorType::STREAMING_LIMIT, std::move(types), estimated_cardinality), - limit_value(limit), offset_value(offset), limit_expression(std::move(limit_expression)), - offset_expression(std::move(offset_expression)), parallel(parallel) { -} - -//===--------------------------------------------------------------------===// 
-// Operator -//===--------------------------------------------------------------------===// -class StreamingLimitOperatorState : public OperatorState { -public: - explicit StreamingLimitOperatorState(const PhysicalStreamingLimit &op) { - this->limit = op.limit_expression ? DConstants::INVALID_INDEX : op.limit_value; - this->offset = op.offset_expression ? DConstants::INVALID_INDEX : op.offset_value; - } - - idx_t limit; - idx_t offset; -}; - -class StreamingLimitGlobalState : public GlobalOperatorState { -public: - StreamingLimitGlobalState() : current_offset(0) { - } - - std::atomic current_offset; -}; - -unique_ptr PhysicalStreamingLimit::GetOperatorState(ExecutionContext &context) const { - return make_uniq(*this); -} - -unique_ptr PhysicalStreamingLimit::GetGlobalOperatorState(ClientContext &context) const { - return make_uniq(); -} - -OperatorResultType PhysicalStreamingLimit::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate_p, OperatorState &state_p) const { - auto &gstate = gstate_p.Cast(); - auto &state = state_p.Cast(); - auto &limit = state.limit; - auto &offset = state.offset; - idx_t current_offset = gstate.current_offset.fetch_add(input.size()); - idx_t max_element; - if (!PhysicalLimit::ComputeOffset(context, input, limit, offset, current_offset, max_element, - limit_expression.get(), offset_expression.get())) { - return OperatorResultType::FINISHED; - } - if (PhysicalLimit::HandleOffset(input, current_offset, offset, limit)) { - chunk.Reference(input); - } - return OperatorResultType::NEED_MORE_INPUT; -} - -OrderPreservationType PhysicalStreamingLimit::OperatorOrder() const { - return OrderPreservationType::FIXED_ORDER; -} - -bool PhysicalStreamingLimit::ParallelOperator() const { - return parallel; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -PhysicalStreamingSample::PhysicalStreamingSample(vector types, SampleMethod method, double percentage, - int64_t seed, idx_t 
estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::STREAMING_SAMPLE, std::move(types), estimated_cardinality), method(method), - percentage(percentage / 100), seed(seed) { -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class StreamingSampleOperatorState : public OperatorState { -public: - explicit StreamingSampleOperatorState(int64_t seed) : random(seed) { - } - - RandomEngine random; -}; - -void PhysicalStreamingSample::SystemSample(DataChunk &input, DataChunk &result, OperatorState &state_p) const { - // system sampling: we throw one dice per chunk - auto &state = state_p.Cast(); - double rand = state.random.NextRandom(); - if (rand <= percentage) { - // rand is smaller than sample_size: output chunk - result.Reference(input); - } -} - -void PhysicalStreamingSample::BernoulliSample(DataChunk &input, DataChunk &result, OperatorState &state_p) const { - // bernoulli sampling: we throw one dice per tuple - // then slice the result chunk - auto &state = state_p.Cast(); - idx_t result_count = 0; - SelectionVector sel(STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < input.size(); i++) { - double rand = state.random.NextRandom(); - if (rand <= percentage) { - sel.set_index(result_count++, i); - } - } - if (result_count > 0) { - result.Slice(input, sel, result_count); - } -} - -unique_ptr PhysicalStreamingSample::GetOperatorState(ExecutionContext &context) const { - return make_uniq(seed); -} - -OperatorResultType PhysicalStreamingSample::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state) const { - switch (method) { - case SampleMethod::BERNOULLI_SAMPLE: - BernoulliSample(input, chunk, state); - break; - case SampleMethod::SYSTEM_SAMPLE: - SystemSample(input, chunk, state); - break; - default: - throw InternalException("Unsupported sample method for 
streaming sample"); - } - return OperatorResultType::NEED_MORE_INPUT; -} - -string PhysicalStreamingSample::ParamsToString() const { - return EnumUtil::ToString(method) + ": " + to_string(100 * percentage) + "%"; -} - -} // namespace duckdb - - - - -namespace duckdb { - -SourceResultType PhysicalTransaction::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &client = context.client; - - auto type = info->type; - if (type == TransactionType::COMMIT && ValidChecker::IsInvalidated(client.ActiveTransaction())) { - // transaction is invalidated - turn COMMIT into ROLLBACK - type = TransactionType::ROLLBACK; - } - switch (type) { - case TransactionType::BEGIN_TRANSACTION: { - if (client.transaction.IsAutoCommit()) { - // start the active transaction - // if autocommit is active, we have already called - // BeginTransaction by setting autocommit to false we - // prevent it from being closed after this query, hence - // preserving the transaction context for the next query - client.transaction.SetAutoCommit(false); - } else { - throw TransactionException("cannot start a transaction within a transaction"); - } - break; - } - case TransactionType::COMMIT: { - if (client.transaction.IsAutoCommit()) { - throw TransactionException("cannot commit - no transaction is active"); - } else { - // explicitly commit the current transaction - client.transaction.Commit(); - } - break; - } - case TransactionType::ROLLBACK: { - if (client.transaction.IsAutoCommit()) { - throw TransactionException("cannot rollback - no transaction is active"); - } else { - // explicitly rollback the current transaction - client.transaction.Rollback(); - } - break; - } - default: - throw NotImplementedException("Unrecognized transaction type!"); - } - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -PhysicalVacuum::PhysicalVacuum(unique_ptr info_p, idx_t estimated_cardinality) - : 
PhysicalOperator(PhysicalOperatorType::VACUUM, {LogicalType::BOOLEAN}, estimated_cardinality), - info(std::move(info_p)) { -} - -class VacuumLocalSinkState : public LocalSinkState { -public: - explicit VacuumLocalSinkState(VacuumInfo &info) { - for (idx_t col_idx = 0; col_idx < info.columns.size(); col_idx++) { - column_distinct_stats.push_back(make_uniq()); - } - }; - - vector> column_distinct_stats; -}; - -unique_ptr PhysicalVacuum::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(*info); -} - -class VacuumGlobalSinkState : public GlobalSinkState { -public: - explicit VacuumGlobalSinkState(VacuumInfo &info) { - for (idx_t col_idx = 0; col_idx < info.columns.size(); col_idx++) { - column_distinct_stats.push_back(make_uniq()); - } - }; - - mutex stats_lock; - vector> column_distinct_stats; -}; - -unique_ptr PhysicalVacuum::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(*info); -} - -SinkResultType PhysicalVacuum::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - D_ASSERT(lstate.column_distinct_stats.size() == info->column_id_map.size()); - - for (idx_t col_idx = 0; col_idx < chunk.data.size(); col_idx++) { - if (!DistinctStatistics::TypeIsSupported(chunk.data[col_idx].GetType())) { - continue; - } - lstate.column_distinct_stats[col_idx]->Update(chunk.data[col_idx], chunk.size(), false); - } - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalVacuum::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - lock_guard lock(gstate.stats_lock); - D_ASSERT(gstate.column_distinct_stats.size() == lstate.column_distinct_stats.size()); - for (idx_t col_idx = 0; col_idx < gstate.column_distinct_stats.size(); col_idx++) { - gstate.column_distinct_stats[col_idx]->Merge(*lstate.column_distinct_stats[col_idx]); - } - - 
return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalVacuum::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &sink = input.global_state.Cast(); - - auto table = info->table; - for (idx_t col_idx = 0; col_idx < sink.column_distinct_stats.size(); col_idx++) { - table->GetStorage().SetDistinct(info->column_id_map.at(col_idx), - std::move(sink.column_distinct_stats[col_idx])); - } - - return SinkFinalizeType::READY; -} - -SourceResultType PhysicalVacuum::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - // NOP - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - -namespace duckdb { - -OuterJoinMarker::OuterJoinMarker(bool enabled_p) : enabled(enabled_p), count(0) { -} - -void OuterJoinMarker::Initialize(idx_t count_p) { - if (!enabled) { - return; - } - this->count = count_p; - found_match = make_unsafe_uniq_array(count); - Reset(); -} - -void OuterJoinMarker::Reset() { - if (!enabled) { - return; - } - memset(found_match.get(), 0, sizeof(bool) * count); -} - -void OuterJoinMarker::SetMatch(idx_t position) { - if (!enabled) { - return; - } - D_ASSERT(position < count); - found_match[position] = true; -} - -void OuterJoinMarker::SetMatches(const SelectionVector &sel, idx_t count, idx_t base_idx) { - if (!enabled) { - return; - } - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto pos = base_idx + idx; - D_ASSERT(pos < this->count); - found_match[pos] = true; - } -} - -void OuterJoinMarker::ConstructLeftJoinResult(DataChunk &left, DataChunk &result) { - if (!enabled) { - return; - } - D_ASSERT(count == STANDARD_VECTOR_SIZE); - SelectionVector remaining_sel(STANDARD_VECTOR_SIZE); - idx_t remaining_count = 0; - for (idx_t i = 0; i < left.size(); i++) { - if (!found_match[i]) { - remaining_sel.set_index(remaining_count++, i); - } - } - if (remaining_count > 0) { - result.Slice(left, remaining_sel, 
remaining_count); - for (idx_t idx = left.ColumnCount(); idx < result.ColumnCount(); idx++) { - result.data[idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result.data[idx], true); - } - } -} - -idx_t OuterJoinMarker::MaxThreads() const { - return count / (STANDARD_VECTOR_SIZE * 10ULL); -} - -void OuterJoinMarker::InitializeScan(ColumnDataCollection &data, OuterJoinGlobalScanState &gstate) { - gstate.data = &data; - data.InitializeScan(gstate.global_scan); -} - -void OuterJoinMarker::InitializeScan(OuterJoinGlobalScanState &gstate, OuterJoinLocalScanState &lstate) { - D_ASSERT(gstate.data); - lstate.match_sel.Initialize(STANDARD_VECTOR_SIZE); - gstate.data->InitializeScanChunk(lstate.scan_chunk); -} - -void OuterJoinMarker::Scan(OuterJoinGlobalScanState &gstate, OuterJoinLocalScanState &lstate, DataChunk &result) { - D_ASSERT(gstate.data); - // fill in NULL values for the LHS - while (gstate.data->Scan(gstate.global_scan, lstate.local_scan, lstate.scan_chunk)) { - idx_t result_count = 0; - // figure out which tuples didn't find a match in the RHS - for (idx_t i = 0; i < lstate.scan_chunk.size(); i++) { - if (!found_match[lstate.local_scan.current_row_index + i]) { - lstate.match_sel.set_index(result_count++, i); - } - } - if (result_count > 0) { - // if there were any tuples that didn't find a match, output them - idx_t left_column_count = result.ColumnCount() - lstate.scan_chunk.ColumnCount(); - for (idx_t i = 0; i < left_column_count; i++) { - result.data[i].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result.data[i], true); - } - for (idx_t col_idx = left_column_count; col_idx < result.ColumnCount(); col_idx++) { - result.data[col_idx].Slice(lstate.scan_chunk.data[col_idx - left_column_count], lstate.match_sel, - result_count); - } - result.SetCardinality(result_count); - return; - } - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -PerfectHashJoinExecutor::PerfectHashJoinExecutor(const 
PhysicalHashJoin &join_p, JoinHashTable &ht_p, - PerfectHashJoinStats perfect_join_stats) - : join(join_p), ht(ht_p), perfect_join_statistics(std::move(perfect_join_stats)) { -} - -bool PerfectHashJoinExecutor::CanDoPerfectHashJoin() { - return perfect_join_statistics.is_build_small; -} - -//===--------------------------------------------------------------------===// -// Build -//===--------------------------------------------------------------------===// -bool PerfectHashJoinExecutor::BuildPerfectHashTable(LogicalType &key_type) { - // First, allocate memory for each build column - auto build_size = perfect_join_statistics.build_range + 1; - for (const auto &type : ht.build_types) { - perfect_hash_table.emplace_back(type, build_size); - } - - // and for duplicate_checking - bitmap_build_idx = make_unsafe_uniq_array(build_size); - memset(bitmap_build_idx.get(), 0, sizeof(bool) * build_size); // set false - - // Now fill columns with build data - - return FullScanHashTable(key_type); -} - -bool PerfectHashJoinExecutor::FullScanHashTable(LogicalType &key_type) { - auto &data_collection = ht.GetDataCollection(); - - // TODO: In a parallel finalize: One should exclusively lock and each thread should do one part of the code below. 
- Vector tuples_addresses(LogicalType::POINTER, ht.Count()); // allocate space for all the tuples - - idx_t key_count = 0; - if (data_collection.ChunkCount() > 0) { - JoinHTScanState join_ht_state(data_collection, 0, data_collection.ChunkCount(), - TupleDataPinProperties::KEEP_EVERYTHING_PINNED); - - // Go through all the blocks and fill the keys addresses - key_count = ht.FillWithHTOffsets(join_ht_state, tuples_addresses); - } - - // Scan the build keys in the hash table - Vector build_vector(key_type, key_count); - RowOperations::FullScanColumn(ht.layout, tuples_addresses, build_vector, key_count, 0); - - // Now fill the selection vector using the build keys and create a sequential vector - // TODO: add check for fast pass when probe is part of build domain - SelectionVector sel_build(key_count + 1); - SelectionVector sel_tuples(key_count + 1); - bool success = FillSelectionVectorSwitchBuild(build_vector, sel_build, sel_tuples, key_count); - - // early out - if (!success) { - return false; - } - if (unique_keys == perfect_join_statistics.build_range + 1 && !ht.has_null) { - perfect_join_statistics.is_build_dense = true; - } - key_count = unique_keys; // do not consider keys out of the range - - // Full scan the remaining build columns and fill the perfect hash table - const auto build_size = perfect_join_statistics.build_range + 1; - for (idx_t i = 0; i < ht.build_types.size(); i++) { - auto &vector = perfect_hash_table[i]; - D_ASSERT(vector.GetType() == ht.build_types[i]); - if (build_size > STANDARD_VECTOR_SIZE) { - auto &col_mask = FlatVector::Validity(vector); - col_mask.Initialize(build_size); - } - - const auto col_no = ht.condition_types.size() + i; - data_collection.Gather(tuples_addresses, sel_tuples, key_count, col_no, vector, sel_build); - } - - return true; -} - -bool PerfectHashJoinExecutor::FillSelectionVectorSwitchBuild(Vector &source, SelectionVector &sel_vec, - SelectionVector &seq_sel_vec, idx_t count) { - switch 
(source.GetType().InternalType()) { - case PhysicalType::INT8: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::INT16: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::INT32: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::INT64: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::UINT8: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::UINT16: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::UINT32: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::UINT64: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - default: - throw NotImplementedException("Type not supported for perfect hash join"); - } -} - -template -bool PerfectHashJoinExecutor::TemplatedFillSelectionVectorBuild(Vector &source, SelectionVector &sel_vec, - SelectionVector &seq_sel_vec, idx_t count) { - if (perfect_join_statistics.build_min.IsNull() || perfect_join_statistics.build_max.IsNull()) { - return false; - } - auto min_value = perfect_join_statistics.build_min.GetValueUnsafe(); - auto max_value = perfect_join_statistics.build_max.GetValueUnsafe(); - UnifiedVectorFormat vector_data; - source.ToUnifiedFormat(count, vector_data); - auto data = reinterpret_cast(vector_data.data); - // generate the selection vector - for (idx_t i = 0, sel_idx = 0; i < count; ++i) { - auto data_idx = vector_data.sel->get_index(i); - auto input_value = data[data_idx]; - // add index to selection vector if value in the range - if (min_value <= input_value && input_value <= max_value) { - auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position - sel_vec.set_index(sel_idx, idx); - if 
(bitmap_build_idx[idx]) { - return false; - } else { - bitmap_build_idx[idx] = true; - unique_keys++; - } - seq_sel_vec.set_index(sel_idx++, i); - } - } - return true; -} - -//===--------------------------------------------------------------------===// -// Probe -//===--------------------------------------------------------------------===// -class PerfectHashJoinState : public OperatorState { -public: - PerfectHashJoinState(ClientContext &context, const PhysicalHashJoin &join) : probe_executor(context) { - join_keys.Initialize(Allocator::Get(context), join.condition_types); - for (auto &cond : join.conditions) { - probe_executor.AddExpression(*cond.left); - } - build_sel_vec.Initialize(STANDARD_VECTOR_SIZE); - probe_sel_vec.Initialize(STANDARD_VECTOR_SIZE); - seq_sel_vec.Initialize(STANDARD_VECTOR_SIZE); - } - - DataChunk join_keys; - ExpressionExecutor probe_executor; - SelectionVector build_sel_vec; - SelectionVector probe_sel_vec; - SelectionVector seq_sel_vec; -}; - -unique_ptr PerfectHashJoinExecutor::GetOperatorState(ExecutionContext &context) { - auto state = make_uniq(context.client, join); - return std::move(state); -} - -OperatorResultType PerfectHashJoinExecutor::ProbePerfectHashTable(ExecutionContext &context, DataChunk &input, - DataChunk &result, OperatorState &state_p) { - auto &state = state_p.Cast(); - // keeps track of how many probe keys have a match - idx_t probe_sel_count = 0; - - // fetch the join keys from the chunk - state.join_keys.Reset(); - state.probe_executor.Execute(input, state.join_keys); - // select the keys that are in the min-max range - auto &keys_vec = state.join_keys.data[0]; - auto keys_count = state.join_keys.size(); - // todo: add check for fast pass when probe is part of build domain - FillSelectionVectorSwitchProbe(keys_vec, state.build_sel_vec, state.probe_sel_vec, keys_count, probe_sel_count); - - // If build is dense and probe is in build's domain, just reference probe - if (perfect_join_statistics.is_build_dense && 
keys_count == probe_sel_count) { - result.Reference(input); - } else { - // otherwise, filter it out the values that do not match - result.Slice(input, state.probe_sel_vec, probe_sel_count, 0); - } - // on the build side, we need to fetch the data and build dictionary vectors with the sel_vec - for (idx_t i = 0; i < ht.build_types.size(); i++) { - auto &result_vector = result.data[input.ColumnCount() + i]; - D_ASSERT(result_vector.GetType() == ht.build_types[i]); - auto &build_vec = perfect_hash_table[i]; - result_vector.Reference(build_vec); - result_vector.Slice(state.build_sel_vec, probe_sel_count); - } - return OperatorResultType::NEED_MORE_INPUT; -} - -void PerfectHashJoinExecutor::FillSelectionVectorSwitchProbe(Vector &source, SelectionVector &build_sel_vec, - SelectionVector &probe_sel_vec, idx_t count, - idx_t &probe_sel_count) { - switch (source.GetType().InternalType()) { - case PhysicalType::INT8: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::INT16: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::INT32: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::INT64: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::UINT8: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::UINT16: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::UINT32: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::UINT64: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - default: - 
throw NotImplementedException("Type not supported"); - } -} - -template -void PerfectHashJoinExecutor::TemplatedFillSelectionVectorProbe(Vector &source, SelectionVector &build_sel_vec, - SelectionVector &probe_sel_vec, idx_t count, - idx_t &probe_sel_count) { - auto min_value = perfect_join_statistics.build_min.GetValueUnsafe(); - auto max_value = perfect_join_statistics.build_max.GetValueUnsafe(); - - UnifiedVectorFormat vector_data; - source.ToUnifiedFormat(count, vector_data); - auto data = reinterpret_cast(vector_data.data); - auto validity_mask = &vector_data.validity; - // build selection vector for non-dense build - if (validity_mask->AllValid()) { - for (idx_t i = 0, sel_idx = 0; i < count; ++i) { - // retrieve value from vector - auto data_idx = vector_data.sel->get_index(i); - auto input_value = data[data_idx]; - // add index to selection vector if value in the range - if (min_value <= input_value && input_value <= max_value) { - auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position - // check for matches in the build - if (bitmap_build_idx[idx]) { - build_sel_vec.set_index(sel_idx, idx); - probe_sel_vec.set_index(sel_idx++, i); - probe_sel_count++; - } - } - } - } else { - for (idx_t i = 0, sel_idx = 0; i < count; ++i) { - // retrieve value from vector - auto data_idx = vector_data.sel->get_index(i); - if (!validity_mask->RowIsValid(data_idx)) { - continue; - } - auto input_value = data[data_idx]; - // add index to selection vector if value in the range - if (min_value <= input_value && input_value <= max_value) { - auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position - // check for matches in the build - if (bitmap_build_idx[idx]) { - build_sel_vec.set_index(sel_idx, idx); - probe_sel_vec.set_index(sel_idx++, i); - probe_sel_count++; - } - } - } - } -} - -} // namespace duckdb - - - - - - - - - - - - - - - -#include - -namespace duckdb { - 
-PhysicalAsOfJoin::PhysicalAsOfJoin(LogicalComparisonJoin &op, unique_ptr left, - unique_ptr right) - : PhysicalComparisonJoin(op, PhysicalOperatorType::ASOF_JOIN, std::move(op.conditions), op.join_type, - op.estimated_cardinality), - comparison_type(ExpressionType::INVALID) { - - // Convert the conditions partitions and sorts - for (auto &cond : conditions) { - D_ASSERT(cond.left->return_type == cond.right->return_type); - join_key_types.push_back(cond.left->return_type); - - auto left = cond.left->Copy(); - auto right = cond.right->Copy(); - switch (cond.comparison) { - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHAN: - null_sensitive.emplace_back(lhs_orders.size()); - lhs_orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_LAST, std::move(left)); - rhs_orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_LAST, std::move(right)); - comparison_type = cond.comparison; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_LESSTHAN: - // Always put NULLS LAST so they can be ignored. - null_sensitive.emplace_back(lhs_orders.size()); - lhs_orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_LAST, std::move(left)); - rhs_orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_LAST, std::move(right)); - comparison_type = cond.comparison; - break; - case ExpressionType::COMPARE_EQUAL: - null_sensitive.emplace_back(lhs_orders.size()); - // Fall through - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - lhs_partitions.emplace_back(std::move(left)); - rhs_partitions.emplace_back(std::move(right)); - break; - default: - throw NotImplementedException("Unsupported join condition for ASOF join"); - } - } - D_ASSERT(!lhs_orders.empty()); - D_ASSERT(!rhs_orders.empty()); - - children.push_back(std::move(left)); - children.push_back(std::move(right)); - - // Fill out the right projection map. 
- right_projection_map = op.right_projection_map; - if (right_projection_map.empty()) { - const auto right_count = children[1]->types.size(); - right_projection_map.reserve(right_count); - for (column_t i = 0; i < right_count; ++i) { - right_projection_map.emplace_back(i); - } - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class AsOfGlobalSinkState : public GlobalSinkState { -public: - AsOfGlobalSinkState(ClientContext &context, const PhysicalAsOfJoin &op) - : rhs_sink(context, op.rhs_partitions, op.rhs_orders, op.children[1]->types, {}, op.estimated_cardinality), - is_outer(IsRightOuterJoin(op.join_type)), has_null(false) { - } - - idx_t Count() const { - return rhs_sink.count; - } - - PartitionLocalSinkState *RegisterBuffer(ClientContext &context) { - lock_guard guard(lock); - lhs_buffers.emplace_back(make_uniq(context, *lhs_sink)); - return lhs_buffers.back().get(); - } - - PartitionGlobalSinkState rhs_sink; - - // One per partition - const bool is_outer; - vector right_outers; - bool has_null; - - // Left side buffering - unique_ptr lhs_sink; - - mutex lock; - vector> lhs_buffers; -}; - -class AsOfLocalSinkState : public LocalSinkState { -public: - explicit AsOfLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p) - : local_partition(context, gstate_p) { - } - - void Sink(DataChunk &input_chunk) { - local_partition.Sink(input_chunk); - } - - void Combine() { - local_partition.Combine(); - } - - PartitionLocalSinkState local_partition; -}; - -unique_ptr PhysicalAsOfJoin::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr PhysicalAsOfJoin::GetLocalSinkState(ExecutionContext &context) const { - // We only sink the RHS - auto &gsink = sink_state->Cast(); - return make_uniq(context.client, gsink.rhs_sink); -} - -SinkResultType 
PhysicalAsOfJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - - lstate.Sink(chunk); - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalAsOfJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast(); - lstate.Combine(); - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - - // The data is all in so we can initialise the left partitioning. - const vector> partitions_stats; - gstate.lhs_sink = make_uniq(context, lhs_partitions, lhs_orders, children[0]->types, - partitions_stats, 0); - gstate.lhs_sink->SyncPartitioning(gstate.rhs_sink); - - // Find the first group to sort - if (!gstate.rhs_sink.HasMergeTasks() && EmptyResultIfRHSIsEmpty()) { - // Empty input! 
- return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Schedule all the sorts for maximum thread utilisation - auto new_event = make_shared(gstate.rhs_sink, pipeline); - event.InsertEvent(std::move(new_event)); - - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class AsOfGlobalState : public GlobalOperatorState { -public: - explicit AsOfGlobalState(AsOfGlobalSinkState &gsink) { - // for FULL/RIGHT OUTER JOIN, initialize right_outers to false for every tuple - auto &rhs_partition = gsink.rhs_sink; - auto &right_outers = gsink.right_outers; - right_outers.reserve(rhs_partition.hash_groups.size()); - for (const auto &hash_group : rhs_partition.hash_groups) { - right_outers.emplace_back(OuterJoinMarker(gsink.is_outer)); - right_outers.back().Initialize(hash_group->count); - } - } -}; - -unique_ptr PhysicalAsOfJoin::GetGlobalOperatorState(ClientContext &context) const { - auto &gsink = sink_state->Cast(); - return make_uniq(gsink); -} - -class AsOfLocalState : public CachingOperatorState { -public: - AsOfLocalState(ClientContext &context, const PhysicalAsOfJoin &op) - : context(context), allocator(Allocator::Get(context)), op(op), lhs_executor(context), - left_outer(IsLeftOuterJoin(op.join_type)), fetch_next_left(true) { - lhs_keys.Initialize(allocator, op.join_key_types); - for (const auto &cond : op.conditions) { - lhs_executor.AddExpression(*cond.left); - } - - lhs_payload.Initialize(allocator, op.children[0]->types); - lhs_sel.Initialize(); - left_outer.Initialize(STANDARD_VECTOR_SIZE); - - auto &gsink = op.sink_state->Cast(); - lhs_partition_sink = gsink.RegisterBuffer(context); - } - - bool Sink(DataChunk &input); - OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk); - - ClientContext &context; - Allocator &allocator; - const PhysicalAsOfJoin &op; - - 
ExpressionExecutor lhs_executor; - DataChunk lhs_keys; - ValidityMask lhs_valid_mask; - SelectionVector lhs_sel; - DataChunk lhs_payload; - - OuterJoinMarker left_outer; - bool fetch_next_left; - - optional_ptr lhs_partition_sink; -}; - -bool AsOfLocalState::Sink(DataChunk &input) { - // Compute the join keys - lhs_keys.Reset(); - lhs_executor.Execute(input, lhs_keys); - - // Combine the NULLs - const auto count = input.size(); - lhs_valid_mask.Reset(); - for (auto col_idx : op.null_sensitive) { - auto &col = lhs_keys.data[col_idx]; - UnifiedVectorFormat unified; - col.ToUnifiedFormat(count, unified); - lhs_valid_mask.Combine(unified.validity, count); - } - - // Convert the mask to a selection vector - // and mark all the rows that cannot match for early return. - idx_t lhs_valid = 0; - const auto entry_count = lhs_valid_mask.EntryCount(count); - idx_t base_idx = 0; - left_outer.Reset(); - for (idx_t entry_idx = 0; entry_idx < entry_count;) { - const auto validity_entry = lhs_valid_mask.GetValidityEntry(entry_idx++); - const auto next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); - if (ValidityMask::AllValid(validity_entry)) { - for (; base_idx < next; ++base_idx) { - lhs_sel.set_index(lhs_valid++, base_idx); - left_outer.SetMatch(base_idx); - } - } else if (ValidityMask::NoneValid(validity_entry)) { - base_idx = next; - } else { - const auto start = base_idx; - for (; base_idx < next; ++base_idx) { - if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { - lhs_sel.set_index(lhs_valid++, base_idx); - left_outer.SetMatch(base_idx); - } - } - } - } - - // Slice the keys to the ones we can match - lhs_payload.Reset(); - if (lhs_valid == count) { - lhs_payload.Reference(input); - lhs_payload.SetCardinality(input); - } else { - lhs_payload.Slice(input, lhs_sel, lhs_valid); - lhs_payload.SetCardinality(lhs_valid); - - // Flush the ones that can't match - fetch_next_left = false; - } - - lhs_partition_sink->Sink(lhs_payload); - - return false; -} 
- -OperatorResultType AsOfLocalState::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk) { - input.Verify(); - Sink(input); - - // If there were any unmatchable rows, return them now so we can forget about them. - if (!fetch_next_left) { - fetch_next_left = true; - left_outer.ConstructLeftJoinResult(input, chunk); - left_outer.Reset(); - } - - // Just keep asking for data and buffering it - return OperatorResultType::NEED_MORE_INPUT; -} - -OperatorResultType PhysicalAsOfJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &lstate_p) const { - auto &gsink = sink_state->Cast(); - auto &lstate = lstate_p.Cast(); - - if (gsink.rhs_sink.count == 0) { - // empty RHS - if (!EmptyResultIfRHSIsEmpty()) { - ConstructEmptyJoinResult(join_type, gsink.has_null, input, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } else { - return OperatorResultType::FINISHED; - } - } - - return lstate.ExecuteInternal(context, input, chunk); -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class AsOfProbeBuffer { -public: - using Orders = vector; - - static bool IsExternal(ClientContext &context) { - return ClientConfig::GetConfig(context).force_external; - } - - AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op); - -public: - void ResolveJoin(bool *found_matches, idx_t *matches = nullptr); - bool Scanning() const { - return lhs_scanner.get(); - } - void BeginLeftScan(hash_t scan_bin); - bool NextLeft(); - void EndScan(); - - // resolve joins that output max N elements (SEMI, ANTI, MARK) - void ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk); - // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) - void ResolveComplexJoin(ExecutionContext &context, DataChunk &chunk); - // Chunk may be empty - void 
GetData(ExecutionContext &context, DataChunk &chunk); - bool HasMoreData() const { - return !fetch_next_left || (lhs_scanner && lhs_scanner->Remaining()); - } - - ClientContext &context; - Allocator &allocator; - const PhysicalAsOfJoin &op; - BufferManager &buffer_manager; - const bool force_external; - const idx_t memory_per_thread; - Orders lhs_orders; - - // LHS scanning - SelectionVector lhs_sel; - optional_ptr left_hash; - OuterJoinMarker left_outer; - unique_ptr left_itr; - unique_ptr lhs_scanner; - DataChunk lhs_payload; - - // RHS scanning - optional_ptr right_hash; - optional_ptr right_outer; - unique_ptr right_itr; - unique_ptr rhs_scanner; - DataChunk rhs_payload; - - idx_t lhs_match_count; - bool fetch_next_left; -}; - -AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op) - : context(context), allocator(Allocator::Get(context)), op(op), - buffer_manager(BufferManager::GetBufferManager(context)), force_external(IsExternal(context)), - memory_per_thread(op.GetMaxThreadMemory(context)), left_outer(IsLeftOuterJoin(op.join_type)), - fetch_next_left(true) { - vector> partition_stats; - Orders partitions; // Not used. 
- PartitionGlobalSinkState::GenerateOrderings(partitions, lhs_orders, op.lhs_partitions, op.lhs_orders, - partition_stats); - - // We sort the row numbers of the incoming block, not the rows - lhs_payload.Initialize(allocator, op.children[0]->types); - rhs_payload.Initialize(allocator, op.children[1]->types); - - lhs_sel.Initialize(); - left_outer.Initialize(STANDARD_VECTOR_SIZE); -} - -void AsOfProbeBuffer::BeginLeftScan(hash_t scan_bin) { - auto &gsink = op.sink_state->Cast(); - auto &lhs_sink = *gsink.lhs_sink; - const auto left_group = lhs_sink.bin_groups[scan_bin]; - if (left_group >= lhs_sink.bin_groups.size()) { - return; - } - - auto iterator_comp = ExpressionType::INVALID; - switch (op.comparison_type) { - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - iterator_comp = ExpressionType::COMPARE_LESSTHANOREQUALTO; - break; - case ExpressionType::COMPARE_GREATERTHAN: - iterator_comp = ExpressionType::COMPARE_LESSTHAN; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - iterator_comp = ExpressionType::COMPARE_GREATERTHANOREQUALTO; - break; - case ExpressionType::COMPARE_LESSTHAN: - iterator_comp = ExpressionType::COMPARE_GREATERTHAN; - break; - default: - throw NotImplementedException("Unsupported comparison type for ASOF join"); - } - - left_hash = lhs_sink.hash_groups[left_group].get(); - auto &left_sort = *(left_hash->global_sort); - if (left_sort.sorted_blocks.empty()) { - return; - } - lhs_scanner = make_uniq(left_sort, false); - left_itr = make_uniq(left_sort, iterator_comp); - - // We are only probing the corresponding right side bin, which may be empty - // If they are empty, we leave the iterator as null so we can emit left matches - auto &rhs_sink = gsink.rhs_sink; - const auto right_group = rhs_sink.bin_groups[scan_bin]; - if (right_group < rhs_sink.bin_groups.size()) { - right_hash = rhs_sink.hash_groups[right_group].get(); - right_outer = gsink.right_outers.data() + right_group; - auto &right_sort = *(right_hash->global_sort); - 
right_itr = make_uniq(right_sort, iterator_comp); - rhs_scanner = make_uniq(right_sort, false); - } -} - -bool AsOfProbeBuffer::NextLeft() { - if (!HasMoreData()) { - return false; - } - - // Scan the next sorted chunk - lhs_payload.Reset(); - left_itr->SetIndex(lhs_scanner->Scanned()); - lhs_scanner->Scan(lhs_payload); - - return true; -} - -void AsOfProbeBuffer::EndScan() { - right_hash = nullptr; - right_itr.reset(); - rhs_scanner.reset(); - right_outer = nullptr; - - left_hash = nullptr; - left_itr.reset(); - lhs_scanner.reset(); -} - -void AsOfProbeBuffer::ResolveJoin(bool *found_match, idx_t *matches) { - // If there was no right partition, there are no matches - lhs_match_count = 0; - left_outer.Reset(); - if (!right_itr) { - return; - } - - const auto count = lhs_payload.size(); - const auto left_base = left_itr->GetIndex(); - // Searching for right <= left - for (idx_t i = 0; i < count; ++i) { - left_itr->SetIndex(left_base + i); - - // If right > left, then there is no match - if (!right_itr->Compare(*left_itr)) { - continue; - } - - // Exponential search forward for a non-matching value using radix iterators - // (We use exponential search to avoid thrashing the block manager on large probes) - idx_t bound = 1; - idx_t begin = right_itr->GetIndex(); - right_itr->SetIndex(begin + bound); - while (right_itr->GetIndex() < right_hash->count) { - if (right_itr->Compare(*left_itr)) { - // If right <= left, jump ahead - bound *= 2; - right_itr->SetIndex(begin + bound); - } else { - break; - } - } - - // Binary search for the first non-matching value using radix iterators - // The previous value (which we know exists) is the match - auto first = begin + bound / 2; - auto last = MinValue(begin + bound, right_hash->count); - while (first < last) { - const auto mid = first + (last - first) / 2; - right_itr->SetIndex(mid); - if (right_itr->Compare(*left_itr)) { - // If right <= left, new lower bound - first = mid + 1; - } else { - last = mid; - } - } - 
right_itr->SetIndex(--first); - - // Check partitions for strict equality - if (right_hash->ComparePartitions(*left_itr, *right_itr)) { - continue; - } - - // Emit match data - right_outer->SetMatch(first); - left_outer.SetMatch(i); - if (found_match) { - found_match[i] = true; - } - if (matches) { - matches[i] = first; - } - lhs_sel.set_index(lhs_match_count++, i); - } -} - -unique_ptr PhysicalAsOfJoin::GetOperatorState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - -void AsOfProbeBuffer::ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk) { - // perform the actual join - bool found_match[STANDARD_VECTOR_SIZE] = {false}; - ResolveJoin(found_match); - - // now construct the result based on the join result - switch (op.join_type) { - case JoinType::SEMI: - PhysicalJoin::ConstructSemiJoinResult(lhs_payload, chunk, found_match); - break; - case JoinType::ANTI: - PhysicalJoin::ConstructAntiJoinResult(lhs_payload, chunk, found_match); - break; - default: - throw NotImplementedException("Unimplemented join type for AsOf join"); - } -} - -void AsOfProbeBuffer::ResolveComplexJoin(ExecutionContext &context, DataChunk &chunk) { - // perform the actual join - idx_t matches[STANDARD_VECTOR_SIZE]; - ResolveJoin(nullptr, matches); - - for (idx_t i = 0; i < lhs_match_count; ++i) { - const auto idx = lhs_sel[i]; - const auto match_pos = matches[idx]; - // Skip to the range containing the match - while (match_pos >= rhs_scanner->Scanned()) { - rhs_payload.Reset(); - rhs_scanner->Scan(rhs_payload); - } - // Append the individual values - // TODO: Batch the copies - const auto source_offset = match_pos - (rhs_scanner->Scanned() - rhs_payload.size()); - for (column_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { - const auto rhs_idx = op.right_projection_map[col_idx]; - auto &source = rhs_payload.data[rhs_idx]; - auto &target = chunk.data[lhs_payload.ColumnCount() + col_idx]; - VectorOperations::Copy(source, target, 
source_offset + 1, source_offset, i); - } - } - - // Slice the left payload into the result - for (column_t i = 0; i < lhs_payload.ColumnCount(); ++i) { - chunk.data[i].Slice(lhs_payload.data[i], lhs_sel, lhs_match_count); - } - chunk.SetCardinality(lhs_match_count); - - // If we are doing a left join, come back for the NULLs - fetch_next_left = !left_outer.Enabled(); -} - -void AsOfProbeBuffer::GetData(ExecutionContext &context, DataChunk &chunk) { - // Handle dangling left join results from current chunk - if (!fetch_next_left) { - fetch_next_left = true; - if (left_outer.Enabled()) { - // left join: before we move to the next chunk, see if we need to output any vectors that didn't - // have a match found - left_outer.ConstructLeftJoinResult(lhs_payload, chunk); - left_outer.Reset(); - } - return; - } - - // Stop if there is no more data - if (!NextLeft()) { - return; - } - - switch (op.join_type) { - case JoinType::SEMI: - case JoinType::ANTI: - case JoinType::MARK: - // simple joins can have max STANDARD_VECTOR_SIZE matches per chunk - ResolveSimpleJoin(context, chunk); - break; - case JoinType::LEFT: - case JoinType::INNER: - case JoinType::RIGHT: - case JoinType::OUTER: - ResolveComplexJoin(context, chunk); - break; - default: - throw NotImplementedException("Unimplemented type for as-of join!"); - } -} - -class AsOfGlobalSourceState : public GlobalSourceState { -public: - explicit AsOfGlobalSourceState(AsOfGlobalSinkState &gsink_p) - : gsink(gsink_p), next_combine(0), combined(0), merged(0), mergers(0), next_left(0), flushed(0), next_right(0) { - } - - PartitionGlobalMergeStates &GetMergeStates() { - lock_guard guard(lock); - if (!merge_states) { - merge_states = make_uniq(*gsink.lhs_sink); - } - return *merge_states; - } - - AsOfGlobalSinkState &gsink; - //! The next buffer to combine - atomic next_combine; - //! The number of combined buffers - atomic combined; - //! The number of combined buffers - atomic merged; - //! 
The number of combined buffers - atomic mergers; - //! The next buffer to flush - atomic next_left; - //! The number of flushed buffers - atomic flushed; - //! The right outer output read position. - atomic next_right; - //! The merge handler - mutex lock; - unique_ptr merge_states; - -public: - idx_t MaxThreads() override { - return gsink.lhs_buffers.size(); - } -}; - -unique_ptr PhysicalAsOfJoin::GetGlobalSourceState(ClientContext &context) const { - auto &gsink = sink_state->Cast(); - return make_uniq(gsink); -} - -class AsOfLocalSourceState : public LocalSourceState { -public: - using HashGroupPtr = unique_ptr; - - AsOfLocalSourceState(AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op, ClientContext &client_p); - - // Return true if we were not interrupted (another thread died) - bool CombineLeftPartitions(); - bool MergeLeftPartitions(); - - idx_t BeginRightScan(const idx_t hash_bin); - - AsOfGlobalSourceState &gsource; - ClientContext &client; - - //! The left side partition being probed - AsOfProbeBuffer probe_buffer; - - //! The read partition - idx_t hash_bin; - HashGroupPtr hash_group; - //! The read cursor - unique_ptr scanner; - //! 
Pointer to the matches - const bool *found_match = {}; -}; - -AsOfLocalSourceState::AsOfLocalSourceState(AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op, - ClientContext &client_p) - : gsource(gsource), client(client_p), probe_buffer(gsource.gsink.lhs_sink->context, op) { - gsource.mergers++; -} - -bool AsOfLocalSourceState::CombineLeftPartitions() { - const auto buffer_count = gsource.gsink.lhs_buffers.size(); - while (gsource.combined < buffer_count && !client.interrupted) { - const auto next_combine = gsource.next_combine++; - if (next_combine < buffer_count) { - gsource.gsink.lhs_buffers[next_combine]->Combine(); - ++gsource.combined; - } else { - TaskScheduler::GetScheduler(client).YieldThread(); - } - } - - return !client.interrupted; -} - -bool AsOfLocalSourceState::MergeLeftPartitions() { - PartitionGlobalMergeStates::Callback local_callback; - PartitionLocalMergeState local_merge(*gsource.gsink.lhs_sink); - gsource.GetMergeStates().ExecuteTask(local_merge, local_callback); - gsource.merged++; - while (gsource.merged < gsource.mergers && !client.interrupted) { - TaskScheduler::GetScheduler(client).YieldThread(); - } - return !client.interrupted; -} - -idx_t AsOfLocalSourceState::BeginRightScan(const idx_t hash_bin_p) { - hash_bin = hash_bin_p; - - hash_group = std::move(gsource.gsink.rhs_sink.hash_groups[hash_bin]); - if (hash_group->global_sort->sorted_blocks.empty()) { - return 0; - } - scanner = make_uniq(*hash_group->global_sort); - found_match = gsource.gsink.right_outers[hash_bin].GetMatches(); - - return scanner->Remaining(); -} - -unique_ptr PhysicalAsOfJoin::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - auto &gsource = gstate.Cast(); - return make_uniq(gsource, *this, context.client); -} - -SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gsource = input.global_state.Cast(); - auto &lsource = 
input.local_state.Cast(); - auto &rhs_sink = gsource.gsink.rhs_sink; - auto &client = context.client; - - // Step 1: Combine the partitions - if (!lsource.CombineLeftPartitions()) { - return SourceResultType::FINISHED; - } - - // Step 2: Sort on all threads - if (!lsource.MergeLeftPartitions()) { - return SourceResultType::FINISHED; - } - - // Step 3: Join the partitions - auto &lhs_sink = *gsource.gsink.lhs_sink; - const auto left_bins = lhs_sink.grouping_data ? lhs_sink.grouping_data->GetPartitions().size() : 1; - while (gsource.flushed < left_bins) { - // Make sure we have something to flush - if (!lsource.probe_buffer.Scanning()) { - const auto left_bin = gsource.next_left++; - if (left_bin < left_bins) { - // More to flush - lsource.probe_buffer.BeginLeftScan(left_bin); - } else if (!IsRightOuterJoin(join_type) || client.interrupted) { - return SourceResultType::FINISHED; - } else { - // Wait for all threads to finish - // TODO: How to implement a spin wait correctly? - // Returning BLOCKED seems to hang the system. - TaskScheduler::GetScheduler(client).YieldThread(); - continue; - } - } - - lsource.probe_buffer.GetData(context, chunk); - if (chunk.size()) { - return SourceResultType::HAVE_MORE_OUTPUT; - } else if (lsource.probe_buffer.HasMoreData()) { - // Join the next partition - continue; - } else { - lsource.probe_buffer.EndScan(); - gsource.flushed++; - } - } - - // Step 4: Emit right join matches - if (!IsRightOuterJoin(join_type)) { - return SourceResultType::FINISHED; - } - - auto &hash_groups = rhs_sink.hash_groups; - const auto right_groups = hash_groups.size(); - - DataChunk rhs_chunk; - rhs_chunk.Initialize(Allocator::Get(context.client), rhs_sink.payload_types); - SelectionVector rsel(STANDARD_VECTOR_SIZE); - - while (chunk.size() == 0) { - // Move to the next bin if we are done. 
- while (!lsource.scanner || !lsource.scanner->Remaining()) { - lsource.scanner.reset(); - lsource.hash_group.reset(); - auto hash_bin = gsource.next_right++; - if (hash_bin >= right_groups) { - return SourceResultType::FINISHED; - } - - for (; hash_bin < hash_groups.size(); hash_bin = gsource.next_right++) { - if (hash_groups[hash_bin]) { - break; - } - } - lsource.BeginRightScan(hash_bin); - } - const auto rhs_position = lsource.scanner->Scanned(); - lsource.scanner->Scan(rhs_chunk); - - const auto count = rhs_chunk.size(); - if (count == 0) { - return SourceResultType::FINISHED; - } - - // figure out which tuples didn't find a match in the RHS - auto found_match = lsource.found_match; - idx_t result_count = 0; - for (idx_t i = 0; i < count; i++) { - if (!found_match[rhs_position + i]) { - rsel.set_index(result_count++, i); - } - } - - if (result_count > 0) { - // if there were any tuples that didn't find a match, output them - const idx_t left_column_count = children[0]->types.size(); - for (idx_t col_idx = 0; col_idx < left_column_count; ++col_idx) { - chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); - } - for (idx_t col_idx = 0; col_idx < right_projection_map.size(); ++col_idx) { - const auto rhs_idx = right_projection_map[col_idx]; - chunk.data[left_column_count + col_idx].Slice(rhs_chunk.data[rhs_idx], rsel, result_count); - } - chunk.SetCardinality(result_count); - break; - } - } - - return chunk.size() > 0 ? 
SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -PhysicalBlockwiseNLJoin::PhysicalBlockwiseNLJoin(LogicalOperator &op, unique_ptr left, - unique_ptr right, unique_ptr condition, - JoinType join_type, idx_t estimated_cardinality) - : PhysicalJoin(op, PhysicalOperatorType::BLOCKWISE_NL_JOIN, join_type, estimated_cardinality), - condition(std::move(condition)) { - children.push_back(std::move(left)); - children.push_back(std::move(right)); - // MARK and SINGLE joins not handled - D_ASSERT(join_type != JoinType::MARK); - D_ASSERT(join_type != JoinType::SINGLE); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class BlockwiseNLJoinLocalState : public LocalSinkState { -public: - BlockwiseNLJoinLocalState() { - } -}; - -class BlockwiseNLJoinGlobalState : public GlobalSinkState { -public: - explicit BlockwiseNLJoinGlobalState(ClientContext &context, const PhysicalBlockwiseNLJoin &op) - : right_chunks(context, op.children[1]->GetTypes()), right_outer(IsRightOuterJoin(op.join_type)) { - } - - mutex lock; - ColumnDataCollection right_chunks; - OuterJoinMarker right_outer; -}; - -unique_ptr PhysicalBlockwiseNLJoin::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr PhysicalBlockwiseNLJoin::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(); -} - -SinkResultType PhysicalBlockwiseNLJoin::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - lock_guard nl_lock(gstate.lock); - gstate.right_chunks.Append(chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Finalize 
-//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalBlockwiseNLJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - gstate.right_outer.Initialize(gstate.right_chunks.Count()); - - if (gstate.right_chunks.Count() == 0 && EmptyResultIfRHSIsEmpty()) { - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class BlockwiseNLJoinState : public CachingOperatorState { -public: - explicit BlockwiseNLJoinState(ExecutionContext &context, ColumnDataCollection &rhs, - const PhysicalBlockwiseNLJoin &op) - : cross_product(rhs), left_outer(IsLeftOuterJoin(op.join_type)), match_sel(STANDARD_VECTOR_SIZE), - executor(context.client, *op.condition) { - left_outer.Initialize(STANDARD_VECTOR_SIZE); - } - - CrossProductExecutor cross_product; - OuterJoinMarker left_outer; - SelectionVector match_sel; - ExpressionExecutor executor; - DataChunk intermediate_chunk; -}; - -unique_ptr PhysicalBlockwiseNLJoin::GetOperatorState(ExecutionContext &context) const { - auto &gstate = sink_state->Cast(); - auto result = make_uniq(context, gstate.right_chunks, *this); - if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { - vector intermediate_types; - for (auto &type : children[0]->types) { - intermediate_types.emplace_back(type); - } - for (auto &type : children[1]->types) { - intermediate_types.emplace_back(type); - } - result->intermediate_chunk.Initialize(Allocator::DefaultAllocator(), intermediate_types); - } - return std::move(result); -} - -OperatorResultType PhysicalBlockwiseNLJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, - DataChunk &chunk, GlobalOperatorState &gstate_p, - OperatorState 
&state_p) const { - D_ASSERT(input.size() > 0); - auto &state = state_p.Cast(); - auto &gstate = sink_state->Cast(); - - if (gstate.right_chunks.Count() == 0) { - // empty RHS - if (!EmptyResultIfRHSIsEmpty()) { - PhysicalComparisonJoin::ConstructEmptyJoinResult(join_type, false, input, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } else { - return OperatorResultType::FINISHED; - } - } - - DataChunk *intermediate_chunk = &chunk; - if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { - intermediate_chunk = &state.intermediate_chunk; - intermediate_chunk->Reset(); - } - - // now perform the actual join - // we perform a cross product, then execute the expression directly on the cross product result - idx_t result_count = 0; - bool found_match[STANDARD_VECTOR_SIZE] = {false}; - - do { - auto result = state.cross_product.Execute(input, *intermediate_chunk); - if (result == OperatorResultType::NEED_MORE_INPUT) { - // exhausted input, have to pull new LHS chunk - if (state.left_outer.Enabled()) { - // left join: before we move to the next chunk, see if we need to output any vectors that didn't - // have a match found - state.left_outer.ConstructLeftJoinResult(input, *intermediate_chunk); - state.left_outer.Reset(); - } - - if (join_type == JoinType::SEMI) { - PhysicalJoin::ConstructSemiJoinResult(input, chunk, found_match); - } - if (join_type == JoinType::ANTI) { - PhysicalJoin::ConstructAntiJoinResult(input, chunk, found_match); - } - - return OperatorResultType::NEED_MORE_INPUT; - } - - // now perform the computation - result_count = state.executor.SelectExpression(*intermediate_chunk, state.match_sel); - - // handle anti and semi joins with different logic - if (result_count > 0) { - // found a match! 
- // handle anti semi join conditions first - if (join_type == JoinType::ANTI || join_type == JoinType::SEMI) { - if (state.cross_product.ScanLHS()) { - found_match[state.cross_product.PositionInChunk()] = true; - } else { - for (idx_t i = 0; i < result_count; i++) { - found_match[state.match_sel.get_index(i)] = true; - } - } - intermediate_chunk->Reset(); - // trick the loop to continue as semi and anti joins will never produce more output than - // the LHS cardinality - result_count = 0; - } else { - // check if the cross product is scanning the LHS or the RHS in its entirety - if (!state.cross_product.ScanLHS()) { - // set the match flags in the LHS - state.left_outer.SetMatches(state.match_sel, result_count); - // set the match flag in the RHS - gstate.right_outer.SetMatch(state.cross_product.ScanPosition() + - state.cross_product.PositionInChunk()); - } else { - // set the match flag in the LHS - state.left_outer.SetMatch(state.cross_product.PositionInChunk()); - // set the match flags in the RHS - gstate.right_outer.SetMatches(state.match_sel, result_count, state.cross_product.ScanPosition()); - } - intermediate_chunk->Slice(state.match_sel, result_count); - } - } else { - // no result: reset the chunk - intermediate_chunk->Reset(); - } - } while (result_count == 0); - - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -string PhysicalBlockwiseNLJoin::ParamsToString() const { - string extra_info = EnumUtil::ToString(join_type) + "\n"; - extra_info += condition->GetName(); - return extra_info; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class BlockwiseNLJoinGlobalScanState : public GlobalSourceState { -public: - explicit BlockwiseNLJoinGlobalScanState(const PhysicalBlockwiseNLJoin &op) : op(op) { - D_ASSERT(op.sink_state); - auto &sink = op.sink_state->Cast(); - sink.right_outer.InitializeScan(sink.right_chunks, scan_state); - } 
- - const PhysicalBlockwiseNLJoin &op; - OuterJoinGlobalScanState scan_state; - -public: - idx_t MaxThreads() override { - auto &sink = op.sink_state->Cast(); - return sink.right_outer.MaxThreads(); - } -}; - -class BlockwiseNLJoinLocalScanState : public LocalSourceState { -public: - explicit BlockwiseNLJoinLocalScanState(const PhysicalBlockwiseNLJoin &op, BlockwiseNLJoinGlobalScanState &gstate) { - D_ASSERT(op.sink_state); - auto &sink = op.sink_state->Cast(); - sink.right_outer.InitializeScan(gstate.scan_state, scan_state); - } - - OuterJoinLocalScanState scan_state; -}; - -unique_ptr PhysicalBlockwiseNLJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -unique_ptr PhysicalBlockwiseNLJoin::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(*this, gstate.Cast()); -} - -SourceResultType PhysicalBlockwiseNLJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - D_ASSERT(IsRightOuterJoin(join_type)); - // check if we need to scan any unmatched tuples from the RHS for the full/right outer join - auto &sink = sink_state->Cast(); - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - // if the LHS is exhausted in a FULL/RIGHT OUTER JOIN, we scan chunks we still need to output - sink.right_outer.Scan(gstate.scan_state, lstate.scan_state, chunk); - - return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb - - - - -namespace duckdb { - -PhysicalComparisonJoin::PhysicalComparisonJoin(LogicalOperator &op, PhysicalOperatorType type, - vector conditions_p, JoinType join_type, - idx_t estimated_cardinality) - : PhysicalJoin(op, type, join_type, estimated_cardinality) { - conditions.resize(conditions_p.size()); - // we reorder conditions so the ones with COMPARE_EQUAL occur first - idx_t equal_position = 0; - idx_t other_position = conditions_p.size() - 1; - for (idx_t i = 0; i < conditions_p.size(); i++) { - if (conditions_p[i].comparison == ExpressionType::COMPARE_EQUAL || - conditions_p[i].comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { - // COMPARE_EQUAL and COMPARE_NOT_DISTINCT_FROM, move to the start - conditions[equal_position++] = std::move(conditions_p[i]); - } else { - // other expression, move to the end - conditions[other_position--] = std::move(conditions_p[i]); - } - } -} - -string PhysicalComparisonJoin::ParamsToString() const { - string extra_info = EnumUtil::ToString(join_type) + "\n"; - for (auto &it : conditions) { - string op = ExpressionTypeToOperator(it.comparison); - extra_info += it.left->GetName() + " " + op + " " + it.right->GetName() + "\n"; - } - extra_info += "\n[INFOSEPARATOR]\n"; - extra_info += StringUtil::Format("EC: %llu\n", estimated_cardinality); - return extra_info; -} - -void PhysicalComparisonJoin::ConstructEmptyJoinResult(JoinType join_type, bool has_null, DataChunk &input, - DataChunk &result) { - // empty hash table, special case - if (join_type == JoinType::ANTI) { - // anti join with empty hash table, NOP join - // return the input - D_ASSERT(input.ColumnCount() == result.ColumnCount()); - result.Reference(input); - } else if (join_type == JoinType::MARK) { - // MARK join with empty hash table - D_ASSERT(join_type == JoinType::MARK); - D_ASSERT(result.ColumnCount() == input.ColumnCount() + 1); - auto &result_vector = 
result.data.back(); - D_ASSERT(result_vector.GetType() == LogicalType::BOOLEAN); - // for every data vector, we just reference the child chunk - result.SetCardinality(input); - for (idx_t i = 0; i < input.ColumnCount(); i++) { - result.data[i].Reference(input.data[i]); - } - // for the MARK vector: - // if the HT has no NULL values (i.e. empty result set), return a vector that has false for every input - // entry if the HT has NULL values (i.e. result set had values, but all were NULL), return a vector that - // has NULL for every input entry - if (!has_null) { - auto bool_result = FlatVector::GetData(result_vector); - for (idx_t i = 0; i < result.size(); i++) { - bool_result[i] = false; - } - } else { - FlatVector::Validity(result_vector).SetAllInvalid(result.size()); - } - } else if (join_type == JoinType::LEFT || join_type == JoinType::OUTER || join_type == JoinType::SINGLE) { - // LEFT/FULL OUTER/SINGLE join and build side is empty - // for the LHS we reference the data - result.SetCardinality(input.size()); - for (idx_t i = 0; i < input.ColumnCount(); i++) { - result.data[i].Reference(input.data[i]); - } - // for the RHS - for (idx_t k = input.ColumnCount(); k < result.ColumnCount(); k++) { - result.data[k].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result.data[k], true); - } - } -} -} // namespace duckdb - - - - - - -namespace duckdb { - -PhysicalCrossProduct::PhysicalCrossProduct(vector types, unique_ptr left, - unique_ptr right, idx_t estimated_cardinality) - : CachingPhysicalOperator(PhysicalOperatorType::CROSS_PRODUCT, std::move(types), estimated_cardinality) { - children.push_back(std::move(left)); - children.push_back(std::move(right)); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class CrossProductGlobalState : public GlobalSinkState { -public: - explicit CrossProductGlobalState(ClientContext 
&context, const PhysicalCrossProduct &op) - : rhs_materialized(context, op.children[1]->GetTypes()) { - rhs_materialized.InitializeAppend(append_state); - } - - ColumnDataCollection rhs_materialized; - ColumnDataAppendState append_state; - mutex rhs_lock; -}; - -unique_ptr PhysicalCrossProduct::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -SinkResultType PhysicalCrossProduct::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &sink = input.global_state.Cast(); - lock_guard client_guard(sink.rhs_lock); - sink.rhs_materialized.Append(sink.append_state, chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -CrossProductExecutor::CrossProductExecutor(ColumnDataCollection &rhs) - : rhs(rhs), position_in_chunk(0), initialized(false), finished(false) { - rhs.InitializeScanChunk(scan_chunk); -} - -void CrossProductExecutor::Reset(DataChunk &input, DataChunk &output) { - initialized = true; - finished = false; - scan_input_chunk = false; - rhs.InitializeScan(scan_state); - position_in_chunk = 0; - scan_chunk.Reset(); -} - -bool CrossProductExecutor::NextValue(DataChunk &input, DataChunk &output) { - if (!initialized) { - // not initialized yet: initialize the scan - Reset(input, output); - } - position_in_chunk++; - idx_t chunk_size = scan_input_chunk ? 
input.size() : scan_chunk.size(); - if (position_in_chunk < chunk_size) { - return true; - } - // fetch the next chunk - rhs.Scan(scan_state, scan_chunk); - position_in_chunk = 0; - if (scan_chunk.size() == 0) { - return false; - } - // the way the cross product works is that we keep one chunk constantly referenced - // while iterating over the other chunk one value at a time - // the second one is the chunk we are "scanning" - - // for the engine, it is better if we emit larger chunks - // hence the chunk that we keep constantly referenced should be the larger of the two - scan_input_chunk = input.size() < scan_chunk.size(); - return true; -} - -OperatorResultType CrossProductExecutor::Execute(DataChunk &input, DataChunk &output) { - if (rhs.Count() == 0) { - // no RHS: empty result - return OperatorResultType::FINISHED; - } - if (!NextValue(input, output)) { - // ran out of entries on the RHS - // reset the RHS and move to the next chunk on the LHS - initialized = false; - return OperatorResultType::NEED_MORE_INPUT; - } - - // set up the constant chunk - auto &constant_chunk = scan_input_chunk ? scan_chunk : input; - auto col_count = constant_chunk.ColumnCount(); - auto col_offset = scan_input_chunk ? input.ColumnCount() : 0; - output.SetCardinality(constant_chunk.size()); - for (idx_t i = 0; i < col_count; i++) { - output.data[col_offset + i].Reference(constant_chunk.data[i]); - } - - // for the chunk that we are scanning, scan a single value from that chunk - auto &scan = scan_input_chunk ? input : scan_chunk; - col_count = scan.ColumnCount(); - col_offset = scan_input_chunk ? 
0 : input.ColumnCount(); - for (idx_t i = 0; i < col_count; i++) { - ConstantVector::Reference(output.data[col_offset + i], scan.data[i], position_in_chunk, scan.size()); - } - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -class CrossProductOperatorState : public CachingOperatorState { -public: - explicit CrossProductOperatorState(ColumnDataCollection &rhs) : executor(rhs) { - } - - CrossProductExecutor executor; -}; - -unique_ptr PhysicalCrossProduct::GetOperatorState(ExecutionContext &context) const { - auto &sink = sink_state->Cast(); - return make_uniq(sink.rhs_materialized); -} - -OperatorResultType PhysicalCrossProduct::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast(); - return state.executor.Execute(input, chunk); -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalCrossProduct::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - PhysicalJoin::BuildJoinPipelines(current, meta_pipeline, *this); -} - -vector> PhysicalCrossProduct::GetSources() const { - return children[0]->GetSources(); -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -PhysicalDelimJoin::PhysicalDelimJoin(vector types, unique_ptr original_join, - vector> delim_scans, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::DELIM_JOIN, std::move(types), estimated_cardinality), - join(std::move(original_join)), delim_scans(std::move(delim_scans)) { - D_ASSERT(join->children.size() == 2); - // now for the original join - // we take its left child, this is the side that we will duplicate eliminate - children.push_back(std::move(join->children[0])); - - // we replace it with a PhysicalColumnDataScan, that scans the ColumnDataCollection that we keep cached - // the actual 
chunk collection to scan will be created in the DelimJoinGlobalState - auto cached_chunk_scan = make_uniq( - children[0]->GetTypes(), PhysicalOperatorType::COLUMN_DATA_SCAN, estimated_cardinality); - join->children[0] = std::move(cached_chunk_scan); -} - -vector> PhysicalDelimJoin::GetChildren() const { - vector> result; - for (auto &child : children) { - result.push_back(*child); - } - result.push_back(*join); - result.push_back(*distinct); - return result; -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class DelimJoinGlobalState : public GlobalSinkState { -public: - explicit DelimJoinGlobalState(ClientContext &context, const PhysicalDelimJoin &delim_join) - : lhs_data(context, delim_join.children[0]->GetTypes()) { - D_ASSERT(delim_join.delim_scans.size() > 0); - // set up the delim join chunk to scan in the original join - auto &cached_chunk_scan = delim_join.join->children[0]->Cast(); - cached_chunk_scan.collection = &lhs_data; - } - - ColumnDataCollection lhs_data; - mutex lhs_lock; - - void Merge(ColumnDataCollection &input) { - lock_guard guard(lhs_lock); - lhs_data.Combine(input); - } -}; - -class DelimJoinLocalState : public LocalSinkState { -public: - explicit DelimJoinLocalState(ClientContext &context, const PhysicalDelimJoin &delim_join) - : lhs_data(context, delim_join.children[0]->GetTypes()) { - lhs_data.InitializeAppend(append_state); - } - - unique_ptr distinct_state; - ColumnDataCollection lhs_data; - ColumnDataAppendState append_state; - - void Append(DataChunk &input) { - lhs_data.Append(input); - } -}; - -unique_ptr PhysicalDelimJoin::GetGlobalSinkState(ClientContext &context) const { - auto state = make_uniq(context, *this); - distinct->sink_state = distinct->GetGlobalSinkState(context); - if (delim_scans.size() > 1) { - PhysicalHashAggregate::SetMultiScan(*distinct->sink_state); - } - return std::move(state); -} - 
-unique_ptr PhysicalDelimJoin::GetLocalSinkState(ExecutionContext &context) const { - auto state = make_uniq(context.client, *this); - state->distinct_state = distinct->GetLocalSinkState(context); - return std::move(state); -} - -SinkResultType PhysicalDelimJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - lstate.lhs_data.Append(lstate.append_state, chunk); - OperatorSinkInput distinct_sink_input {*distinct->sink_state, *lstate.distinct_state, input.interrupt_state}; - distinct->Sink(context, chunk, distinct_sink_input); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalDelimJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast(); - auto &gstate = input.global_state.Cast(); - gstate.Merge(lstate.lhs_data); - - OperatorSinkCombineInput distinct_combine_input {*distinct->sink_state, *lstate.distinct_state, - input.interrupt_state}; - distinct->Combine(context, distinct_combine_input); - - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalDelimJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, - OperatorSinkFinalizeInput &input) const { - // finalize the distinct HT - D_ASSERT(distinct); - - OperatorSinkFinalizeInput finalize_input {*distinct->sink_state, input.interrupt_state}; - distinct->Finalize(pipeline, event, client, finalize_input); - return SinkFinalizeType::READY; -} - -string PhysicalDelimJoin::ParamsToString() const { - return join->ParamsToString(); -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalDelimJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - op_state.reset(); - sink_state.reset(); - - auto &child_meta_pipeline = 
meta_pipeline.CreateChildMetaPipeline(current, *this); - child_meta_pipeline.Build(*children[0]); - - if (type == PhysicalOperatorType::DELIM_JOIN) { - // recurse into the actual join - // any pipelines in there depend on the main pipeline - // any scan of the duplicate eliminated data on the RHS depends on this pipeline - // we add an entry to the mapping of (PhysicalOperator*) -> (Pipeline*) - auto &state = meta_pipeline.GetState(); - for (auto &delim_scan : delim_scans) { - state.delim_join_dependencies.insert( - make_pair(delim_scan, reference(*child_meta_pipeline.GetBasePipeline()))); - } - join->BuildPipelines(current, meta_pipeline); - } -} - -} // namespace duckdb - - - - - - - - - - - - - - - - -namespace duckdb { - -PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr left, - unique_ptr right, vector cond, JoinType join_type, - const vector &left_projection_map, - const vector &right_projection_map_p, vector delim_types, - idx_t estimated_cardinality, PerfectHashJoinStats perfect_join_stats) - : PhysicalComparisonJoin(op, PhysicalOperatorType::HASH_JOIN, std::move(cond), join_type, estimated_cardinality), - right_projection_map(right_projection_map_p), delim_types(std::move(delim_types)), - perfect_join_statistics(std::move(perfect_join_stats)) { - - children.push_back(std::move(left)); - children.push_back(std::move(right)); - - D_ASSERT(left_projection_map.empty()); - for (auto &condition : conditions) { - condition_types.push_back(condition.left->return_type); - } - - // for ANTI, SEMI and MARK join, we only need to store the keys, so for these the build types are empty - if (join_type != JoinType::ANTI && join_type != JoinType::SEMI && join_type != JoinType::MARK) { - build_types = LogicalOperator::MapTypes(children[1]->GetTypes(), right_projection_map); - } -} - -PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr left, - unique_ptr right, vector cond, JoinType join_type, - idx_t estimated_cardinality, 
PerfectHashJoinStats perfect_join_state) - : PhysicalHashJoin(op, std::move(left), std::move(right), std::move(cond), join_type, {}, {}, {}, - estimated_cardinality, std::move(perfect_join_state)) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class HashJoinGlobalSinkState : public GlobalSinkState { -public: - HashJoinGlobalSinkState(const PhysicalHashJoin &op, ClientContext &context_p) - : context(context_p), finalized(false), scanned_data(false) { - hash_table = op.InitializeHashTable(context); - - // for perfect hash join - perfect_join_executor = make_uniq(op, *hash_table, op.perfect_join_statistics); - // for external hash join - external = ClientConfig::GetConfig(context).force_external; - // Set probe types - const auto &payload_types = op.children[0]->types; - probe_types.insert(probe_types.end(), op.condition_types.begin(), op.condition_types.end()); - probe_types.insert(probe_types.end(), payload_types.begin(), payload_types.end()); - probe_types.emplace_back(LogicalType::HASH); - } - - void ScheduleFinalize(Pipeline &pipeline, Event &event); - void InitializeProbeSpill(); - -public: - ClientContext &context; - //! Global HT used by the join - unique_ptr hash_table; - //! The perfect hash join executor (if any) - unique_ptr perfect_join_executor; - //! Whether or not the hash table has been finalized - bool finalized = false; - - //! Whether we are doing an external join - bool external; - - //! Hash tables built by each thread - mutex lock; - vector> local_hash_tables; - - //! Excess probe data gathered during Sink - vector probe_types; - unique_ptr probe_spill; - - //! 
Whether or not we have started scanning data using GetData - atomic scanned_data; -}; - -class HashJoinLocalSinkState : public LocalSinkState { -public: - HashJoinLocalSinkState(const PhysicalHashJoin &op, ClientContext &context) : build_executor(context) { - auto &allocator = BufferAllocator::Get(context); - if (!op.right_projection_map.empty()) { - build_chunk.Initialize(allocator, op.build_types); - } - for (auto &cond : op.conditions) { - build_executor.AddExpression(*cond.right); - } - join_keys.Initialize(allocator, op.condition_types); - - hash_table = op.InitializeHashTable(context); - - hash_table->GetSinkCollection().InitializeAppendState(append_state); - } - -public: - PartitionedTupleDataAppendState append_state; - - DataChunk build_chunk; - DataChunk join_keys; - ExpressionExecutor build_executor; - - //! Thread-local HT - unique_ptr hash_table; -}; - -unique_ptr PhysicalHashJoin::InitializeHashTable(ClientContext &context) const { - auto result = - make_uniq(BufferManager::GetBufferManager(context), conditions, build_types, join_type); - result->max_ht_size = double(0.6) * BufferManager::GetBufferManager(context).GetMaxMemory(); - if (!delim_types.empty() && join_type == JoinType::MARK) { - // correlated MARK join - if (delim_types.size() + 1 == conditions.size()) { - // the correlated MARK join has one more condition than the amount of correlated columns - // this is the case in a correlated ANY() expression - // in this case we need to keep track of additional entries, namely: - // - (1) the total amount of elements per group - // - (2) the amount of non-null elements per group - // we need these to correctly deal with the cases of either: - // - (1) the group being empty [in which case the result is always false, even if the comparison is NULL] - // - (2) the group containing a NULL value [in which case FALSE becomes NULL] - auto &info = result->correlated_mark_join_info; - - vector payload_types; - vector correlated_aggregates; - unique_ptr aggr; 
- - // jury-rigging the GroupedAggregateHashTable - // we need a count_star and a count to get counts with and without NULLs - - FunctionBinder function_binder(context); - aggr = function_binder.BindAggregateFunction(CountStarFun::GetFunction(), {}, nullptr, - AggregateType::NON_DISTINCT); - correlated_aggregates.push_back(&*aggr); - payload_types.push_back(aggr->return_type); - info.correlated_aggregates.push_back(std::move(aggr)); - - auto count_fun = CountFun::GetFunction(); - vector> children; - // this is a dummy but we need it to make the hash table understand whats going on - children.push_back(make_uniq_base(count_fun.return_type, 0)); - aggr = function_binder.BindAggregateFunction(count_fun, std::move(children), nullptr, - AggregateType::NON_DISTINCT); - correlated_aggregates.push_back(&*aggr); - payload_types.push_back(aggr->return_type); - info.correlated_aggregates.push_back(std::move(aggr)); - - auto &allocator = BufferAllocator::Get(context); - info.correlated_counts = make_uniq(context, allocator, delim_types, - payload_types, correlated_aggregates); - info.correlated_types = delim_types; - info.group_chunk.Initialize(allocator, delim_types); - info.result_chunk.Initialize(allocator, payload_types); - } - } - return result; -} - -unique_ptr PhysicalHashJoin::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(*this, context); -} - -unique_ptr PhysicalHashJoin::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(*this, context.client); -} - -SinkResultType PhysicalHashJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - - // resolve the join keys for the right chunk - lstate.join_keys.Reset(); - lstate.build_executor.Execute(chunk, lstate.join_keys); - - // build the HT - auto &ht = *lstate.hash_table; - if (!right_projection_map.empty()) { - // there is a projection map: fill the build chunk with the projected columns - 
lstate.build_chunk.Reset(); - lstate.build_chunk.SetCardinality(chunk); - for (idx_t i = 0; i < right_projection_map.size(); i++) { - lstate.build_chunk.data[i].Reference(chunk.data[right_projection_map[i]]); - } - ht.Build(lstate.append_state, lstate.join_keys, lstate.build_chunk); - } else if (!build_types.empty()) { - // there is not a projected map: place the entire right chunk in the HT - ht.Build(lstate.append_state, lstate.join_keys, chunk); - } else { - // there are only keys: place an empty chunk in the payload - lstate.build_chunk.SetCardinality(chunk.size()); - ht.Build(lstate.append_state, lstate.join_keys, lstate.build_chunk); - } - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalHashJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - if (lstate.hash_table) { - lstate.hash_table->GetSinkCollection().FlushAppendState(lstate.append_state); - lock_guard local_ht_lock(gstate.lock); - gstate.local_hash_tables.push_back(std::move(lstate.hash_table)); - } - auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this, lstate.build_executor, "build_executor", 1); - client_profiler.Flush(context.thread.profiler); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -class HashJoinFinalizeTask : public ExecutorTask { -public: - HashJoinFinalizeTask(shared_ptr event_p, ClientContext &context, HashJoinGlobalSinkState &sink_p, - idx_t chunk_idx_from_p, idx_t chunk_idx_to_p, bool parallel_p) - : ExecutorTask(context), event(std::move(event_p)), sink(sink_p), chunk_idx_from(chunk_idx_from_p), - chunk_idx_to(chunk_idx_to_p), parallel(parallel_p) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override 
{ - sink.hash_table->Finalize(chunk_idx_from, chunk_idx_to, parallel); - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; - } - -private: - shared_ptr event; - HashJoinGlobalSinkState &sink; - idx_t chunk_idx_from; - idx_t chunk_idx_to; - bool parallel; -}; - -class HashJoinFinalizeEvent : public BasePipelineEvent { -public: - HashJoinFinalizeEvent(Pipeline &pipeline_p, HashJoinGlobalSinkState &sink) - : BasePipelineEvent(pipeline_p), sink(sink) { - } - - HashJoinGlobalSinkState &sink; - -public: - void Schedule() override { - auto &context = pipeline->GetClientContext(); - - vector> finalize_tasks; - auto &ht = *sink.hash_table; - const auto chunk_count = ht.GetDataCollection().ChunkCount(); - const idx_t num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - if (num_threads == 1 || (ht.Count() < PARALLEL_CONSTRUCT_THRESHOLD && !context.config.verify_parallelism)) { - // Single-threaded finalize - finalize_tasks.push_back( - make_uniq(shared_from_this(), context, sink, 0, chunk_count, false)); - } else { - // Parallel finalize - auto chunks_per_thread = MaxValue((chunk_count + num_threads - 1) / num_threads, 1); - - idx_t chunk_idx = 0; - for (idx_t thread_idx = 0; thread_idx < num_threads; thread_idx++) { - auto chunk_idx_from = chunk_idx; - auto chunk_idx_to = MinValue(chunk_idx_from + chunks_per_thread, chunk_count); - finalize_tasks.push_back(make_uniq(shared_from_this(), context, sink, - chunk_idx_from, chunk_idx_to, true)); - chunk_idx = chunk_idx_to; - if (chunk_idx == chunk_count) { - break; - } - } - } - SetTasks(std::move(finalize_tasks)); - } - - void FinishEvent() override { - sink.hash_table->GetDataCollection().VerifyEverythingPinned(); - sink.hash_table->finalized = true; - } - - static constexpr const idx_t PARALLEL_CONSTRUCT_THRESHOLD = 1048576; -}; - -void HashJoinGlobalSinkState::ScheduleFinalize(Pipeline &pipeline, Event &event) { - if (hash_table->Count() == 0) { - hash_table->finalized = true; - return; - } - 
hash_table->InitializePointerTable(); - auto new_event = make_shared(pipeline, *this); - event.InsertEvent(std::move(new_event)); -} - -void HashJoinGlobalSinkState::InitializeProbeSpill() { - lock_guard guard(lock); - if (!probe_spill) { - probe_spill = make_uniq(*hash_table, context, probe_types); - } -} - -class HashJoinRepartitionTask : public ExecutorTask { -public: - HashJoinRepartitionTask(shared_ptr event_p, ClientContext &context, JoinHashTable &global_ht, - JoinHashTable &local_ht) - : ExecutorTask(context), event(std::move(event_p)), global_ht(global_ht), local_ht(local_ht) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - local_ht.Partition(global_ht); - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; - } - -private: - shared_ptr event; - - JoinHashTable &global_ht; - JoinHashTable &local_ht; -}; - -class HashJoinPartitionEvent : public BasePipelineEvent { -public: - HashJoinPartitionEvent(Pipeline &pipeline_p, HashJoinGlobalSinkState &sink, - vector> &local_hts) - : BasePipelineEvent(pipeline_p), sink(sink), local_hts(local_hts) { - } - - HashJoinGlobalSinkState &sink; - vector> &local_hts; - -public: - void Schedule() override { - auto &context = pipeline->GetClientContext(); - vector> partition_tasks; - partition_tasks.reserve(local_hts.size()); - for (auto &local_ht : local_hts) { - partition_tasks.push_back( - make_uniq(shared_from_this(), context, *sink.hash_table, *local_ht)); - } - SetTasks(std::move(partition_tasks)); - } - - void FinishEvent() override { - local_hts.clear(); - sink.hash_table->PrepareExternalFinalize(); - sink.ScheduleFinalize(*pipeline, *this); - } -}; - -SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &sink = input.global_state.Cast(); - auto &ht = *sink.hash_table; - - sink.external = ht.RequiresExternalJoin(context.config, sink.local_hash_tables); - if (sink.external) { - 
sink.perfect_join_executor.reset(); - if (ht.RequiresPartitioning(context.config, sink.local_hash_tables)) { - auto new_event = make_shared(pipeline, sink, sink.local_hash_tables); - event.InsertEvent(std::move(new_event)); - } else { - for (auto &local_ht : sink.local_hash_tables) { - ht.Merge(*local_ht); - } - sink.local_hash_tables.clear(); - sink.hash_table->PrepareExternalFinalize(); - sink.ScheduleFinalize(pipeline, event); - } - sink.finalized = true; - return SinkFinalizeType::READY; - } else { - for (auto &local_ht : sink.local_hash_tables) { - ht.Merge(*local_ht); - } - sink.local_hash_tables.clear(); - ht.Unpartition(); - } - - // check for possible perfect hash table - auto use_perfect_hash = sink.perfect_join_executor->CanDoPerfectHashJoin(); - if (use_perfect_hash) { - D_ASSERT(ht.equality_types.size() == 1); - auto key_type = ht.equality_types[0]; - use_perfect_hash = sink.perfect_join_executor->BuildPerfectHashTable(key_type); - } - // In case of a large build side or duplicates, use regular hash join - if (!use_perfect_hash) { - sink.perfect_join_executor.reset(); - sink.ScheduleFinalize(pipeline, event); - } - sink.finalized = true; - if (ht.Count() == 0 && EmptyResultIfRHSIsEmpty()) { - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class HashJoinOperatorState : public CachingOperatorState { -public: - explicit HashJoinOperatorState(ClientContext &context) : probe_executor(context), initialized(false) { - } - - DataChunk join_keys; - TupleDataChunkState join_key_state; - - ExpressionExecutor probe_executor; - unique_ptr scan_structure; - unique_ptr perfect_hash_join_state; - - bool initialized; - JoinHashTable::ProbeSpillLocalAppendState spill_state; - //! 
Chunk to sink data into for external join - DataChunk spill_chunk; - -public: - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op, probe_executor, "probe_executor", 0); - } -}; - -unique_ptr PhysicalHashJoin::GetOperatorState(ExecutionContext &context) const { - auto &allocator = BufferAllocator::Get(context.client); - auto &sink = sink_state->Cast(); - auto state = make_uniq(context.client); - if (sink.perfect_join_executor) { - state->perfect_hash_join_state = sink.perfect_join_executor->GetOperatorState(context); - } else { - state->join_keys.Initialize(allocator, condition_types); - for (auto &cond : conditions) { - state->probe_executor.AddExpression(*cond.left); - } - TupleDataCollection::InitializeChunkState(state->join_key_state, condition_types); - } - if (sink.external) { - state->spill_chunk.Initialize(allocator, sink.probe_types); - sink.InitializeProbeSpill(); - } - - return std::move(state); -} - -OperatorResultType PhysicalHashJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast(); - auto &sink = sink_state->Cast(); - D_ASSERT(sink.finalized); - D_ASSERT(!sink.scanned_data); - - // some initialization for external hash join - if (sink.external && !state.initialized) { - if (!sink.probe_spill) { - sink.InitializeProbeSpill(); - } - state.spill_state = sink.probe_spill->RegisterThread(); - state.initialized = true; - } - - if (sink.hash_table->Count() == 0 && EmptyResultIfRHSIsEmpty()) { - return OperatorResultType::FINISHED; - } - - if (sink.perfect_join_executor) { - D_ASSERT(!sink.external); - return sink.perfect_join_executor->ProbePerfectHashTable(context, input, chunk, *state.perfect_hash_join_state); - } - - if (state.scan_structure) { - // still have elements remaining (i.e. 
we got >STANDARD_VECTOR_SIZE elements in the previous probe) - state.scan_structure->Next(state.join_keys, input, chunk); - if (chunk.size() > 0) { - return OperatorResultType::HAVE_MORE_OUTPUT; - } - state.scan_structure = nullptr; - return OperatorResultType::NEED_MORE_INPUT; - } - - // probe the HT - if (sink.hash_table->Count() == 0) { - ConstructEmptyJoinResult(sink.hash_table->join_type, sink.hash_table->has_null, input, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } - - // resolve the join keys for the left chunk - state.join_keys.Reset(); - state.probe_executor.Execute(input, state.join_keys); - - // perform the actual probe - if (sink.external) { - state.scan_structure = sink.hash_table->ProbeAndSpill(state.join_keys, state.join_key_state, input, - *sink.probe_spill, state.spill_state, state.spill_chunk); - } else { - state.scan_structure = sink.hash_table->Probe(state.join_keys, state.join_key_state); - } - state.scan_structure->Next(state.join_keys, input, chunk); - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -enum class HashJoinSourceStage : uint8_t { INIT, BUILD, PROBE, SCAN_HT, DONE }; - -class HashJoinLocalSourceState; - -class HashJoinGlobalSourceState : public GlobalSourceState { -public: - HashJoinGlobalSourceState(const PhysicalHashJoin &op, ClientContext &context); - - //! Initialize this source state using the info in the sink - void Initialize(HashJoinGlobalSinkState &sink); - //! Try to prepare the next stage - void TryPrepareNextStage(HashJoinGlobalSinkState &sink); - //! Prepare the next build/probe/scan_ht stage for external hash join (must hold lock) - void PrepareBuild(HashJoinGlobalSinkState &sink); - void PrepareProbe(HashJoinGlobalSinkState &sink); - void PrepareScanHT(HashJoinGlobalSinkState &sink); - //! 
Assigns a task to a local source state - bool AssignTask(HashJoinGlobalSinkState &sink, HashJoinLocalSourceState &lstate); - - idx_t MaxThreads() override { - D_ASSERT(op.sink_state); - auto &gstate = op.sink_state->Cast(); - - idx_t count; - if (gstate.probe_spill) { - count = probe_count; - } else if (IsRightOuterJoin(op.join_type)) { - count = gstate.hash_table->Count(); - } else { - return 0; - } - return count / ((idx_t)STANDARD_VECTOR_SIZE * parallel_scan_chunk_count); - } - -public: - const PhysicalHashJoin &op; - - //! For synchronizing the external hash join - atomic global_stage; - mutex lock; - - //! For HT build synchronization - idx_t build_chunk_idx; - idx_t build_chunk_count; - idx_t build_chunk_done; - idx_t build_chunks_per_thread; - - //! For probe synchronization - idx_t probe_chunk_count; - idx_t probe_chunk_done; - - //! To determine the number of threads - idx_t probe_count; - idx_t parallel_scan_chunk_count; - - //! For full/outer synchronization - idx_t full_outer_chunk_idx; - idx_t full_outer_chunk_count; - idx_t full_outer_chunk_done; - idx_t full_outer_chunks_per_thread; -}; - -class HashJoinLocalSourceState : public LocalSourceState { -public: - HashJoinLocalSourceState(const PhysicalHashJoin &op, Allocator &allocator); - - //! Do the work this thread has been assigned - void ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); - //! Whether this thread has finished the work it has been assigned - bool TaskFinished(); - //! Build, probe and scan for external hash join - void ExternalBuild(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate); - void ExternalProbe(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); - void ExternalScanHT(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); - -public: - //! The stage that this thread was assigned work for - HashJoinSourceStage local_stage; - //! 
Vector with pointers here so we don't have to re-initialize - Vector addresses; - - //! Chunks assigned to this thread for building the pointer table - idx_t build_chunk_idx_from; - idx_t build_chunk_idx_to; - - //! Local scan state for probe spill - ColumnDataConsumerScanState probe_local_scan; - //! Chunks for holding the scanned probe collection - DataChunk probe_chunk; - DataChunk join_keys; - DataChunk payload; - TupleDataChunkState join_key_state; - //! Column indices to easily reference the join keys/payload columns in probe_chunk - vector join_key_indices; - vector payload_indices; - //! Scan structure for the external probe - unique_ptr scan_structure; - bool empty_ht_probe_in_progress; - - //! Chunks assigned to this thread for a full/outer scan - idx_t full_outer_chunk_idx_from; - idx_t full_outer_chunk_idx_to; - unique_ptr full_outer_scan_state; -}; - -unique_ptr PhysicalHashJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this, context); -} - -unique_ptr PhysicalHashJoin::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(*this, BufferAllocator::Get(context.client)); -} - -HashJoinGlobalSourceState::HashJoinGlobalSourceState(const PhysicalHashJoin &op, ClientContext &context) - : op(op), global_stage(HashJoinSourceStage::INIT), build_chunk_count(0), build_chunk_done(0), probe_chunk_count(0), - probe_chunk_done(0), probe_count(op.children[0]->estimated_cardinality), - parallel_scan_chunk_count(context.config.verify_parallelism ? 
1 : 120) { -} - -void HashJoinGlobalSourceState::Initialize(HashJoinGlobalSinkState &sink) { - lock_guard init_lock(lock); - if (global_stage != HashJoinSourceStage::INIT) { - // Another thread initialized - return; - } - - // Finalize the probe spill - if (sink.probe_spill) { - sink.probe_spill->Finalize(); - } - - global_stage = HashJoinSourceStage::PROBE; - TryPrepareNextStage(sink); -} - -void HashJoinGlobalSourceState::TryPrepareNextStage(HashJoinGlobalSinkState &sink) { - switch (global_stage.load()) { - case HashJoinSourceStage::BUILD: - if (build_chunk_done == build_chunk_count) { - sink.hash_table->GetDataCollection().VerifyEverythingPinned(); - sink.hash_table->finalized = true; - PrepareProbe(sink); - } - break; - case HashJoinSourceStage::PROBE: - if (probe_chunk_done == probe_chunk_count) { - if (IsRightOuterJoin(op.join_type)) { - PrepareScanHT(sink); - } else { - PrepareBuild(sink); - } - } - break; - case HashJoinSourceStage::SCAN_HT: - if (full_outer_chunk_done == full_outer_chunk_count) { - PrepareBuild(sink); - } - break; - default: - break; - } -} - -void HashJoinGlobalSourceState::PrepareBuild(HashJoinGlobalSinkState &sink) { - D_ASSERT(global_stage != HashJoinSourceStage::BUILD); - auto &ht = *sink.hash_table; - - // Try to put the next partitions in the block collection of the HT - if (!sink.external || !ht.PrepareExternalFinalize()) { - global_stage = HashJoinSourceStage::DONE; - return; - } - - auto &data_collection = ht.GetDataCollection(); - if (data_collection.Count() == 0 && op.EmptyResultIfRHSIsEmpty()) { - PrepareBuild(sink); - return; - } - - build_chunk_idx = 0; - build_chunk_count = data_collection.ChunkCount(); - build_chunk_done = 0; - - auto num_threads = TaskScheduler::GetScheduler(sink.context).NumberOfThreads(); - build_chunks_per_thread = MaxValue((build_chunk_count + num_threads - 1) / num_threads, 1); - - ht.InitializePointerTable(); - - global_stage = HashJoinSourceStage::BUILD; -} - -void 
HashJoinGlobalSourceState::PrepareProbe(HashJoinGlobalSinkState &sink) { - sink.probe_spill->PrepareNextProbe(); - const auto &consumer = *sink.probe_spill->consumer; - - probe_chunk_count = consumer.Count() == 0 ? 0 : consumer.ChunkCount(); - probe_chunk_done = 0; - - global_stage = HashJoinSourceStage::PROBE; - if (probe_chunk_count == 0) { - TryPrepareNextStage(sink); - return; - } -} - -void HashJoinGlobalSourceState::PrepareScanHT(HashJoinGlobalSinkState &sink) { - D_ASSERT(global_stage != HashJoinSourceStage::SCAN_HT); - auto &ht = *sink.hash_table; - - auto &data_collection = ht.GetDataCollection(); - full_outer_chunk_idx = 0; - full_outer_chunk_count = data_collection.ChunkCount(); - full_outer_chunk_done = 0; - - auto num_threads = TaskScheduler::GetScheduler(sink.context).NumberOfThreads(); - full_outer_chunks_per_thread = MaxValue((full_outer_chunk_count + num_threads - 1) / num_threads, 1); - - global_stage = HashJoinSourceStage::SCAN_HT; -} - -bool HashJoinGlobalSourceState::AssignTask(HashJoinGlobalSinkState &sink, HashJoinLocalSourceState &lstate) { - D_ASSERT(lstate.TaskFinished()); - - lock_guard guard(lock); - switch (global_stage.load()) { - case HashJoinSourceStage::BUILD: - if (build_chunk_idx != build_chunk_count) { - lstate.local_stage = global_stage; - lstate.build_chunk_idx_from = build_chunk_idx; - build_chunk_idx = MinValue(build_chunk_count, build_chunk_idx + build_chunks_per_thread); - lstate.build_chunk_idx_to = build_chunk_idx; - return true; - } - break; - case HashJoinSourceStage::PROBE: - if (sink.probe_spill->consumer && sink.probe_spill->consumer->AssignChunk(lstate.probe_local_scan)) { - lstate.local_stage = global_stage; - lstate.empty_ht_probe_in_progress = false; - return true; - } - break; - case HashJoinSourceStage::SCAN_HT: - if (full_outer_chunk_idx != full_outer_chunk_count) { - lstate.local_stage = global_stage; - lstate.full_outer_chunk_idx_from = full_outer_chunk_idx; - full_outer_chunk_idx = - 
MinValue(full_outer_chunk_count, full_outer_chunk_idx + full_outer_chunks_per_thread); - lstate.full_outer_chunk_idx_to = full_outer_chunk_idx; - return true; - } - break; - case HashJoinSourceStage::DONE: - break; - default: - throw InternalException("Unexpected HashJoinSourceStage in AssignTask!"); - } - return false; -} - -HashJoinLocalSourceState::HashJoinLocalSourceState(const PhysicalHashJoin &op, Allocator &allocator) - : local_stage(HashJoinSourceStage::INIT), addresses(LogicalType::POINTER) { - auto &chunk_state = probe_local_scan.current_chunk_state; - chunk_state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY; - - auto &sink = op.sink_state->Cast(); - probe_chunk.Initialize(allocator, sink.probe_types); - join_keys.Initialize(allocator, op.condition_types); - payload.Initialize(allocator, op.children[0]->types); - TupleDataCollection::InitializeChunkState(join_key_state, op.condition_types); - - // Store the indices of the columns to reference them easily - idx_t col_idx = 0; - for (; col_idx < op.condition_types.size(); col_idx++) { - join_key_indices.push_back(col_idx); - } - for (; col_idx < sink.probe_types.size() - 1; col_idx++) { - payload_indices.push_back(col_idx); - } -} - -void HashJoinLocalSourceState::ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, - DataChunk &chunk) { - switch (local_stage) { - case HashJoinSourceStage::BUILD: - ExternalBuild(sink, gstate); - break; - case HashJoinSourceStage::PROBE: - ExternalProbe(sink, gstate, chunk); - break; - case HashJoinSourceStage::SCAN_HT: - ExternalScanHT(sink, gstate, chunk); - break; - default: - throw InternalException("Unexpected HashJoinSourceStage in ExecuteTask!"); - } -} - -bool HashJoinLocalSourceState::TaskFinished() { - switch (local_stage) { - case HashJoinSourceStage::INIT: - case HashJoinSourceStage::BUILD: - return true; - case HashJoinSourceStage::PROBE: - return scan_structure == nullptr && !empty_ht_probe_in_progress; - case 
HashJoinSourceStage::SCAN_HT: - return full_outer_scan_state == nullptr; - default: - throw InternalException("Unexpected HashJoinSourceStage in TaskFinished!"); - } -} - -void HashJoinLocalSourceState::ExternalBuild(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate) { - D_ASSERT(local_stage == HashJoinSourceStage::BUILD); - - auto &ht = *sink.hash_table; - ht.Finalize(build_chunk_idx_from, build_chunk_idx_to, true); - - lock_guard guard(gstate.lock); - gstate.build_chunk_done += build_chunk_idx_to - build_chunk_idx_from; -} - -void HashJoinLocalSourceState::ExternalProbe(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, - DataChunk &chunk) { - D_ASSERT(local_stage == HashJoinSourceStage::PROBE && sink.hash_table->finalized); - - if (scan_structure) { - // Still have elements remaining (i.e. we got >STANDARD_VECTOR_SIZE elements in the previous probe) - scan_structure->Next(join_keys, payload, chunk); - if (chunk.size() != 0) { - return; - } - } - - if (scan_structure || empty_ht_probe_in_progress) { - // Previous probe is done - scan_structure = nullptr; - empty_ht_probe_in_progress = false; - sink.probe_spill->consumer->FinishChunk(probe_local_scan); - lock_guard lock(gstate.lock); - gstate.probe_chunk_done++; - return; - } - - // Scan input chunk for next probe - sink.probe_spill->consumer->ScanChunk(probe_local_scan, probe_chunk); - - // Get the probe chunk columns/hashes - join_keys.ReferenceColumns(probe_chunk, join_key_indices); - payload.ReferenceColumns(probe_chunk, payload_indices); - auto precomputed_hashes = &probe_chunk.data.back(); - - if (sink.hash_table->Count() == 0 && !gstate.op.EmptyResultIfRHSIsEmpty()) { - gstate.op.ConstructEmptyJoinResult(sink.hash_table->join_type, sink.hash_table->has_null, payload, chunk); - empty_ht_probe_in_progress = true; - return; - } - - // Perform the probe - scan_structure = sink.hash_table->Probe(join_keys, join_key_state, precomputed_hashes); - scan_structure->Next(join_keys, 
payload, chunk); -} - -void HashJoinLocalSourceState::ExternalScanHT(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, - DataChunk &chunk) { - D_ASSERT(local_stage == HashJoinSourceStage::SCAN_HT); - - if (!full_outer_scan_state) { - full_outer_scan_state = make_uniq(sink.hash_table->GetDataCollection(), - full_outer_chunk_idx_from, full_outer_chunk_idx_to); - } - sink.hash_table->ScanFullOuter(*full_outer_scan_state, addresses, chunk); - - if (chunk.size() == 0) { - full_outer_scan_state = nullptr; - lock_guard guard(gstate.lock); - gstate.full_outer_chunk_done += full_outer_chunk_idx_to - full_outer_chunk_idx_from; - } -} - -SourceResultType PhysicalHashJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &sink = sink_state->Cast(); - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - sink.scanned_data = true; - - if (!sink.external && !IsRightOuterJoin(join_type)) { - return SourceResultType::FINISHED; - } - - if (gstate.global_stage == HashJoinSourceStage::INIT) { - gstate.Initialize(sink); - } - - // Any call to GetData must produce tuples, otherwise the pipeline executor thinks that we're done - // Therefore, we loop until we've produced tuples, or until the operator is actually done - while (gstate.global_stage != HashJoinSourceStage::DONE && chunk.size() == 0) { - if (!lstate.TaskFinished() || gstate.AssignTask(sink, lstate)) { - lstate.ExecuteTask(sink, gstate, chunk); - } else { - lock_guard guard(gstate.lock); - gstate.TryPrepareNextStage(sink); - } - } - - return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb - - - - - - - - - - - - - - -#include - -namespace duckdb { - -PhysicalIEJoin::PhysicalIEJoin(LogicalComparisonJoin &op, unique_ptr left, - unique_ptr right, vector cond, JoinType join_type, - idx_t estimated_cardinality) - : PhysicalRangeJoin(op, PhysicalOperatorType::IE_JOIN, std::move(left), std::move(right), std::move(cond), - join_type, estimated_cardinality) { - - // 1. let L1 (resp. L2) be the array of column X (resp. Y) - D_ASSERT(conditions.size() >= 2); - lhs_orders.resize(2); - rhs_orders.resize(2); - for (idx_t i = 0; i < 2; ++i) { - auto &cond = conditions[i]; - D_ASSERT(cond.left->return_type == cond.right->return_type); - join_key_types.push_back(cond.left->return_type); - - // Convert the conditions to sort orders - auto left = cond.left->Copy(); - auto right = cond.right->Copy(); - auto sense = OrderType::INVALID; - - // 2. if (op1 ∈ {>, ≥}) sort L1 in descending order - // 3. else if (op1 ∈ {<, ≤}) sort L1 in ascending order - // 4. if (op2 ∈ {>, ≥}) sort L2 in ascending order - // 5. else if (op2 ∈ {<, ≤}) sort L2 in descending order - switch (cond.comparison) { - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - sense = i ? OrderType::ASCENDING : OrderType::DESCENDING; - break; - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - sense = i ? 
OrderType::DESCENDING : OrderType::ASCENDING; - break; - default: - throw NotImplementedException("Unimplemented join type for IEJoin"); - } - lhs_orders[i].emplace_back(BoundOrderByNode(sense, OrderByNullType::NULLS_LAST, std::move(left))); - rhs_orders[i].emplace_back(BoundOrderByNode(sense, OrderByNullType::NULLS_LAST, std::move(right))); - } - - for (idx_t i = 2; i < conditions.size(); ++i) { - auto &cond = conditions[i]; - D_ASSERT(cond.left->return_type == cond.right->return_type); - join_key_types.push_back(cond.left->return_type); - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class IEJoinLocalState : public LocalSinkState { -public: - using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; - - IEJoinLocalState(ClientContext &context, const PhysicalRangeJoin &op, const idx_t child) - : table(context, op, child) { - } - - //! The local sort state - LocalSortedTable table; -}; - -class IEJoinGlobalState : public GlobalSinkState { -public: - using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; - -public: - IEJoinGlobalState(ClientContext &context, const PhysicalIEJoin &op) : child(0) { - tables.resize(2); - RowLayout lhs_layout; - lhs_layout.Initialize(op.children[0]->types); - vector lhs_order; - lhs_order.emplace_back(op.lhs_orders[0][0].Copy()); - tables[0] = make_uniq(context, lhs_order, lhs_layout); - - RowLayout rhs_layout; - rhs_layout.Initialize(op.children[1]->types); - vector rhs_order; - rhs_order.emplace_back(op.rhs_orders[0][0].Copy()); - tables[1] = make_uniq(context, rhs_order, rhs_layout); - } - - IEJoinGlobalState(IEJoinGlobalState &prev) - : GlobalSinkState(prev), tables(std::move(prev.tables)), child(prev.child + 1) { - } - - void Sink(DataChunk &input, IEJoinLocalState &lstate) { - auto &table = *tables[child]; - auto &global_sort_state = table.global_sort_state; - auto &local_sort_state = 
lstate.table.local_sort_state; - - // Sink the data into the local sort state - lstate.table.Sink(input, global_sort_state); - - // When sorting data reaches a certain size, we sort it - if (local_sort_state.SizeInBytes() >= table.memory_per_thread) { - local_sort_state.Sort(global_sort_state, true); - } - } - - vector> tables; - size_t child; -}; - -unique_ptr PhysicalIEJoin::GetGlobalSinkState(ClientContext &context) const { - D_ASSERT(!sink_state); - return make_uniq(context, *this); -} - -unique_ptr PhysicalIEJoin::GetLocalSinkState(ExecutionContext &context) const { - idx_t sink_child = 0; - if (sink_state) { - const auto &ie_sink = sink_state->Cast(); - sink_child = ie_sink.child; - } - return make_uniq(context.client, *this, sink_child); -} - -SinkResultType PhysicalIEJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - gstate.Sink(chunk, lstate); - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalIEJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - gstate.tables[gstate.child]->Combine(lstate.table); - auto &client_profiler = QueryProfiler::Get(context.client); - - context.thread.profiler.Flush(*this, lstate.table.executor, gstate.child ? 
"rhs_executor" : "lhs_executor", 1); - client_profiler.Flush(context.thread.profiler); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalIEJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &table = *gstate.tables[gstate.child]; - auto &global_sort_state = table.global_sort_state; - - if ((gstate.child == 1 && IsRightOuterJoin(join_type)) || (gstate.child == 0 && IsLeftOuterJoin(join_type))) { - // for FULL/LEFT/RIGHT OUTER JOIN, initialize found_match to false for every tuple - table.IntializeMatches(); - } - if (gstate.child == 1 && global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { - // Empty input! - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Sort the current input child - table.Finalize(pipeline, event); - - // Move to the next input child - ++gstate.child; - - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -OperatorResultType PhysicalIEJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state) const { - return OperatorResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -struct IEJoinUnion { - using SortedTable = PhysicalRangeJoin::GlobalSortedTable; - - static idx_t AppendKey(SortedTable &table, ExpressionExecutor &executor, SortedTable &marked, int64_t increment, - int64_t base, const idx_t block_idx); - - static void Sort(SortedTable 
&table) { - auto &global_sort_state = table.global_sort_state; - global_sort_state.PrepareMergePhase(); - while (global_sort_state.sorted_blocks.size() > 1) { - global_sort_state.InitializeMergeRound(); - MergeSorter merge_sorter(global_sort_state, global_sort_state.buffer_manager); - merge_sorter.PerformInMergeRound(); - global_sort_state.CompleteMergeRound(true); - } - } - - template - static vector ExtractColumn(SortedTable &table, idx_t col_idx) { - vector result; - result.reserve(table.count); - - auto &gstate = table.global_sort_state; - auto &blocks = *gstate.sorted_blocks[0]->payload_data; - PayloadScanner scanner(blocks, gstate, false); - - DataChunk payload; - payload.Initialize(Allocator::DefaultAllocator(), gstate.payload_layout.GetTypes()); - for (;;) { - scanner.Scan(payload); - const auto count = payload.size(); - if (!count) { - break; - } - - const auto data_ptr = FlatVector::GetData(payload.data[col_idx]); - result.insert(result.end(), data_ptr, data_ptr + count); - } - - return result; - } - - IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, SortedTable &t1, const idx_t b1, SortedTable &t2, - const idx_t b2); - - idx_t SearchL1(idx_t pos); - bool NextRow(); - - //! Inverted loop - idx_t JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rsel); - - //! L1 - unique_ptr l1; - //! L2 - unique_ptr l2; - - //! Li - vector li; - //! P - vector p; - - //! B - vector bit_array; - ValidityMask bit_mask; - - //! Bloom Filter - static constexpr idx_t BLOOM_CHUNK_BITS = 1024; - idx_t bloom_count; - vector bloom_array; - ValidityMask bloom_filter; - - //! 
Iteration state - idx_t n; - idx_t i; - idx_t j; - unique_ptr op1; - unique_ptr off1; - unique_ptr op2; - unique_ptr off2; - int64_t lrid; -}; - -idx_t IEJoinUnion::AppendKey(SortedTable &table, ExpressionExecutor &executor, SortedTable &marked, int64_t increment, - int64_t base, const idx_t block_idx) { - LocalSortState local_sort_state; - local_sort_state.Initialize(marked.global_sort_state, marked.global_sort_state.buffer_manager); - - // Reading - const auto valid = table.count - table.has_null; - auto &gstate = table.global_sort_state; - PayloadScanner scanner(gstate, block_idx); - auto table_idx = block_idx * gstate.block_capacity; - - DataChunk scanned; - scanned.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); - - // Writing - auto types = local_sort_state.sort_layout->logical_types; - const idx_t payload_idx = types.size(); - - const auto &payload_types = local_sort_state.payload_layout->GetTypes(); - types.insert(types.end(), payload_types.begin(), payload_types.end()); - const idx_t rid_idx = types.size() - 1; - - DataChunk keys; - DataChunk payload; - keys.Initialize(Allocator::DefaultAllocator(), types); - - idx_t inserted = 0; - for (auto rid = base; table_idx < valid;) { - scanner.Scan(scanned); - - // NULLs are at the end, so stop when we reach them - auto scan_count = scanned.size(); - if (table_idx + scan_count > valid) { - scan_count = valid - table_idx; - scanned.SetCardinality(scan_count); - } - if (scan_count == 0) { - break; - } - table_idx += scan_count; - - // Compute the input columns from the payload - keys.Reset(); - keys.Split(payload, rid_idx); - executor.Execute(scanned, keys); - - // Mark the rid column - payload.data[0].Sequence(rid, increment, scan_count); - payload.SetCardinality(scan_count); - keys.Fuse(payload); - rid += increment * scan_count; - - // Sort on the sort columns (which will no longer be needed) - keys.Split(payload, payload_idx); - local_sort_state.SinkChunk(keys, payload); - inserted += 
scan_count; - keys.Fuse(payload); - - // Flush when we have enough data - if (local_sort_state.SizeInBytes() >= marked.memory_per_thread) { - local_sort_state.Sort(marked.global_sort_state, true); - } - } - marked.global_sort_state.AddLocalState(local_sort_state); - marked.count += inserted; - - return inserted; -} - -IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, SortedTable &t1, const idx_t b1, - SortedTable &t2, const idx_t b2) - : n(0), i(0) { - // input : query Q with 2 join predicates t1.X op1 t2.X' and t1.Y op2 t2.Y', tables T, T' of sizes m and n resp. - // output: a list of tuple pairs (ti , tj) - // Note that T/T' are already sorted on X/X' and contain the payload data - // We only join the two block numbers and use the sizes of the blocks as the counts - - // 0. Filter out tables with no overlap - if (!t1.BlockSize(b1) || !t2.BlockSize(b2)) { - return; - } - - const auto &cmp1 = op.conditions[0].comparison; - SBIterator bounds1(t1.global_sort_state, cmp1); - SBIterator bounds2(t2.global_sort_state, cmp1); - - // t1.X[0] op1 t2.X'[-1] - bounds1.SetIndex(bounds1.block_capacity * b1); - bounds2.SetIndex(bounds2.block_capacity * b2 + t2.BlockSize(b2) - 1); - if (!bounds1.Compare(bounds2)) { - return; - } - - // 1. let L1 (resp. L2) be the array of column X (resp. Y ) - const auto &order1 = op.lhs_orders[0][0]; - const auto &order2 = op.lhs_orders[1][0]; - - // 2. if (op1 ∈ {>, ≥}) sort L1 in descending order - // 3. else if (op1 ∈ {<, ≤}) sort L1 in ascending order - - // For the union algorithm, we make a unified table with the keys and the rids as the payload: - // X/X', Y/Y', R/R'/Li - // The first position is the sort key. 
- vector types; - types.emplace_back(order2.expression->return_type); - types.emplace_back(LogicalType::BIGINT); - RowLayout payload_layout; - payload_layout.Initialize(types); - - // Sort on the first expression - auto ref = make_uniq(order1.expression->return_type, 0); - vector orders; - orders.emplace_back(order1.type, order1.null_order, std::move(ref)); - - l1 = make_uniq(context, orders, payload_layout); - - // LHS has positive rids - ExpressionExecutor l_executor(context); - l_executor.AddExpression(*order1.expression); - l_executor.AddExpression(*order2.expression); - AppendKey(t1, l_executor, *l1, 1, 1, b1); - - // RHS has negative rids - ExpressionExecutor r_executor(context); - r_executor.AddExpression(*op.rhs_orders[0][0].expression); - r_executor.AddExpression(*op.rhs_orders[1][0].expression); - AppendKey(t2, r_executor, *l1, -1, -1, b2); - - if (l1->global_sort_state.sorted_blocks.empty()) { - return; - } - - Sort(*l1); - - op1 = make_uniq(l1->global_sort_state, cmp1); - off1 = make_uniq(l1->global_sort_state, cmp1); - - // We don't actually need the L1 column, just its sort key, which is in the sort blocks - li = ExtractColumn(*l1, types.size() - 1); - - // 4. if (op2 ∈ {>, ≥}) sort L2 in ascending order - // 5. else if (op2 ∈ {<, ≤}) sort L2 in descending order - - // We sort on Y/Y' to obtain the sort keys and the permutation array. 
- // For this we just need a two-column table of Y, P - types.clear(); - types.emplace_back(LogicalType::BIGINT); - payload_layout.Initialize(types); - - // Sort on the first expression - orders.clear(); - ref = make_uniq(order2.expression->return_type, 0); - orders.emplace_back(order2.type, order2.null_order, std::move(ref)); - - ExpressionExecutor executor(context); - executor.AddExpression(*orders[0].expression); - - l2 = make_uniq(context, orders, payload_layout); - for (idx_t base = 0, block_idx = 0; block_idx < l1->BlockCount(); ++block_idx) { - base += AppendKey(*l1, executor, *l2, 1, base, block_idx); - } - - Sort(*l2); - - // We don't actually need the L2 column, just its sort key, which is in the sort blocks - - // 6. compute the permutation array P of L2 w.r.t. L1 - p = ExtractColumn(*l2, types.size() - 1); - - // 7. initialize bit-array B (|B| = n), and set all bits to 0 - n = l2->count.load(); - bit_array.resize(ValidityMask::EntryCount(n), 0); - bit_mask.Initialize(bit_array.data()); - - // Bloom filter - bloom_count = (n + (BLOOM_CHUNK_BITS - 1)) / BLOOM_CHUNK_BITS; - bloom_array.resize(ValidityMask::EntryCount(bloom_count), 0); - bloom_filter.Initialize(bloom_array.data()); - - // 11. 
for(i←1 to n) do - const auto &cmp2 = op.conditions[1].comparison; - op2 = make_uniq(l2->global_sort_state, cmp2); - off2 = make_uniq(l2->global_sort_state, cmp2); - i = 0; - j = 0; - (void)NextRow(); -} - -idx_t IEJoinUnion::SearchL1(idx_t pos) { - // Perform an exponential search in the appropriate direction - op1->SetIndex(pos); - - idx_t step = 1; - auto hi = pos; - auto lo = pos; - if (!op1->cmp) { - // Scan left for loose inequality - lo -= MinValue(step, lo); - step *= 2; - off1->SetIndex(lo); - while (lo > 0 && op1->Compare(*off1)) { - hi = lo; - lo -= MinValue(step, lo); - step *= 2; - off1->SetIndex(lo); - } - } else { - // Scan right for strict inequality - hi += MinValue(step, n - hi); - step *= 2; - off1->SetIndex(hi); - while (hi < n && !op1->Compare(*off1)) { - lo = hi; - hi += MinValue(step, n - hi); - step *= 2; - off1->SetIndex(hi); - } - } - - // Binary search the target area - while (lo < hi) { - const auto mid = lo + (hi - lo) / 2; - off1->SetIndex(mid); - if (op1->Compare(*off1)) { - hi = mid; - } else { - lo = mid + 1; - } - } - - off1->SetIndex(lo); - - return lo; -} - -bool IEJoinUnion::NextRow() { - for (; i < n; ++i) { - // 12. pos ← P[i] - auto pos = p[i]; - lrid = li[pos]; - if (lrid < 0) { - continue; - } - - // 16. B[pos] ← 1 - op2->SetIndex(i); - for (; off2->GetIndex() < n; ++(*off2)) { - if (!off2->Compare(*op2)) { - break; - } - const auto p2 = p[off2->GetIndex()]; - if (li[p2] < 0) { - // Only mark rhs matches. - bit_mask.SetValid(p2); - bloom_filter.SetValid(p2 / BLOOM_CHUNK_BITS); - } - } - - // 9. if (op1 ∈ {≤,≥} and op2 ∈ {≤,≥}) eqOff = 0 - // 10. else eqOff = 1 - // No, because there could be more than one equal value. 
- // Find the leftmost off1 where L1[pos] op1 L1[off1..n] - // These are the rows that satisfy the op1 condition - // and that is where we should start scanning B from - j = SearchL1(pos); - - return true; - } - return false; -} - -static idx_t NextValid(const ValidityMask &bits, idx_t j, const idx_t n) { - if (j >= n) { - return n; - } - - // We can do a first approximation by checking entries one at a time - // which gives 64:1. - idx_t entry_idx, idx_in_entry; - bits.GetEntryIndex(j, entry_idx, idx_in_entry); - auto entry = bits.GetValidityEntry(entry_idx++); - - // Trim the bits before the start position - entry &= (ValidityMask::ValidityBuffer::MAX_ENTRY << idx_in_entry); - - // Check the non-ragged entries - for (const auto entry_count = bits.EntryCount(n); entry_idx < entry_count; ++entry_idx) { - if (entry) { - for (; idx_in_entry < bits.BITS_PER_VALUE; ++idx_in_entry, ++j) { - if (bits.RowIsValid(entry, idx_in_entry)) { - return j; - } - } - } else { - j += bits.BITS_PER_VALUE - idx_in_entry; - } - - entry = bits.GetValidityEntry(entry_idx); - idx_in_entry = 0; - } - - // Check the final entry - for (; j < n; ++idx_in_entry, ++j) { - if (bits.RowIsValid(entry, idx_in_entry)) { - return j; - } - } - - return j; -} - -idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rsel) { - // 8. initialize join result as an empty list for tuple pairs - idx_t result_count = 0; - - // 11. for(i←1 to n) do - while (i < n) { - // 13. for (j ← pos+eqOff to n) do - for (;;) { - // 14. 
if B[j] = 1 then - - // Use the Bloom filter to find candidate blocks - while (j < n) { - auto bloom_begin = NextValid(bloom_filter, j / BLOOM_CHUNK_BITS, bloom_count) * BLOOM_CHUNK_BITS; - auto bloom_end = MinValue(n, bloom_begin + BLOOM_CHUNK_BITS); - - j = MaxValue(j, bloom_begin); - j = NextValid(bit_mask, j, bloom_end); - if (j < bloom_end) { - break; - } - } - - if (j >= n) { - break; - } - - // Filter out tuples with the same sign (they come from the same table) - const auto rrid = li[j]; - ++j; - - // 15. add tuples w.r.t. (L1[j], L1[i]) to join result - if (lrid > 0 && rrid < 0) { - lsel.set_index(result_count, sel_t(+lrid - 1)); - rsel.set_index(result_count, sel_t(-rrid - 1)); - ++result_count; - if (result_count == STANDARD_VECTOR_SIZE) { - // out of space! - return result_count; - } - } - } - ++i; - - if (!NextRow()) { - break; - } - } - - return result_count; -} - -class IEJoinLocalSourceState : public LocalSourceState { -public: - explicit IEJoinLocalSourceState(ClientContext &context, const PhysicalIEJoin &op) - : op(op), true_sel(STANDARD_VECTOR_SIZE), left_executor(context), right_executor(context), - left_matches(nullptr), right_matches(nullptr) { - auto &allocator = Allocator::Get(context); - unprojected.Initialize(allocator, op.unprojected_types); - - if (op.conditions.size() < 3) { - return; - } - - vector left_types; - vector right_types; - for (idx_t i = 2; i < op.conditions.size(); ++i) { - const auto &cond = op.conditions[i]; - - left_types.push_back(cond.left->return_type); - left_executor.AddExpression(*cond.left); - - right_types.push_back(cond.left->return_type); - right_executor.AddExpression(*cond.right); - } - - left_keys.Initialize(allocator, left_types); - right_keys.Initialize(allocator, right_types); - } - - idx_t SelectOuterRows(bool *matches) { - idx_t count = 0; - for (; outer_idx < outer_count; ++outer_idx) { - if (!matches[outer_idx]) { - true_sel.set_index(count++, outer_idx); - if (count >= STANDARD_VECTOR_SIZE) { - 
outer_idx++; - break; - } - } - } - - return count; - } - - const PhysicalIEJoin &op; - - // Joining - unique_ptr joiner; - - idx_t left_base; - idx_t left_block_index; - - idx_t right_base; - idx_t right_block_index; - - // Trailing predicates - SelectionVector true_sel; - - ExpressionExecutor left_executor; - DataChunk left_keys; - - ExpressionExecutor right_executor; - DataChunk right_keys; - - DataChunk unprojected; - - // Outer joins - idx_t outer_idx; - idx_t outer_count; - bool *left_matches; - bool *right_matches; -}; - -void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &result, LocalSourceState &state_p) const { - auto &state = state_p.Cast(); - auto &ie_sink = sink_state->Cast(); - auto &left_table = *ie_sink.tables[0]; - auto &right_table = *ie_sink.tables[1]; - - const auto left_cols = children[0]->GetTypes().size(); - auto &chunk = state.unprojected; - do { - SelectionVector lsel(STANDARD_VECTOR_SIZE); - SelectionVector rsel(STANDARD_VECTOR_SIZE); - auto result_count = state.joiner->JoinComplexBlocks(lsel, rsel); - if (result_count == 0) { - // exhausted this pair - return; - } - - // found matches: extract them - - chunk.Reset(); - SliceSortedPayload(chunk, left_table.global_sort_state, state.left_block_index, lsel, result_count, 0); - SliceSortedPayload(chunk, right_table.global_sort_state, state.right_block_index, rsel, result_count, - left_cols); - chunk.SetCardinality(result_count); - - auto sel = FlatVector::IncrementalSelectionVector(); - if (conditions.size() > 2) { - // If there are more expressions to compute, - // split the result chunk into the left and right halves - // so we can compute the values for comparison. 
- const auto tail_cols = conditions.size() - 2; - - DataChunk right_chunk; - chunk.Split(right_chunk, left_cols); - state.left_executor.SetChunk(chunk); - state.right_executor.SetChunk(right_chunk); - - auto tail_count = result_count; - auto true_sel = &state.true_sel; - for (size_t cmp_idx = 0; cmp_idx < tail_cols; ++cmp_idx) { - auto &left = state.left_keys.data[cmp_idx]; - state.left_executor.ExecuteExpression(cmp_idx, left); - - auto &right = state.right_keys.data[cmp_idx]; - state.right_executor.ExecuteExpression(cmp_idx, right); - - if (tail_count < result_count) { - left.Slice(*sel, tail_count); - right.Slice(*sel, tail_count); - } - tail_count = SelectJoinTail(conditions[cmp_idx + 2].comparison, left, right, sel, tail_count, true_sel); - sel = true_sel; - } - chunk.Fuse(right_chunk); - - if (tail_count < result_count) { - result_count = tail_count; - chunk.Slice(*sel, result_count); - } - } - - // We need all of the data to compute other predicates, - // but we only return what is in the projection map - ProjectResult(chunk, result); - - // found matches: mark the found matches if required - if (left_table.found_match) { - for (idx_t i = 0; i < result_count; i++) { - left_table.found_match[state.left_base + lsel[sel->get_index(i)]] = true; - } - } - if (right_table.found_match) { - for (idx_t i = 0; i < result_count; i++) { - right_table.found_match[state.right_base + rsel[sel->get_index(i)]] = true; - } - } - result.Verify(); - } while (result.size() == 0); -} - -class IEJoinGlobalSourceState : public GlobalSourceState { -public: - explicit IEJoinGlobalSourceState(const PhysicalIEJoin &op) - : op(op), initialized(false), next_pair(0), completed(0), left_outers(0), next_left(0), right_outers(0), - next_right(0) { - } - - void Initialize(IEJoinGlobalState &sink_state) { - lock_guard initializing(lock); - if (initialized) { - return; - } - - // Compute the starting row for reach block - // (In theory these are all the same size, but you never know...) 
- auto &left_table = *sink_state.tables[0]; - const auto left_blocks = left_table.BlockCount(); - idx_t left_base = 0; - - for (size_t lhs = 0; lhs < left_blocks; ++lhs) { - left_bases.emplace_back(left_base); - left_base += left_table.BlockSize(lhs); - } - - auto &right_table = *sink_state.tables[1]; - const auto right_blocks = right_table.BlockCount(); - idx_t right_base = 0; - for (size_t rhs = 0; rhs < right_blocks; ++rhs) { - right_bases.emplace_back(right_base); - right_base += right_table.BlockSize(rhs); - } - - // Outer join block counts - if (left_table.found_match) { - left_outers = left_blocks; - } - - if (right_table.found_match) { - right_outers = right_blocks; - } - - // Ready for action - initialized = true; - } - -public: - idx_t MaxThreads() override { - // We can't leverage any more threads than block pairs. - const auto &sink_state = (op.sink_state->Cast()); - return sink_state.tables[0]->BlockCount() * sink_state.tables[1]->BlockCount(); - } - - void GetNextPair(ClientContext &client, IEJoinGlobalState &gstate, IEJoinLocalSourceState &lstate) { - auto &left_table = *gstate.tables[0]; - auto &right_table = *gstate.tables[1]; - - const auto left_blocks = left_table.BlockCount(); - const auto right_blocks = right_table.BlockCount(); - const auto pair_count = left_blocks * right_blocks; - - // Regular block - const auto i = next_pair++; - if (i < pair_count) { - const auto b1 = i / right_blocks; - const auto b2 = i % right_blocks; - - lstate.left_block_index = b1; - lstate.left_base = left_bases[b1]; - - lstate.right_block_index = b2; - lstate.right_base = right_bases[b2]; - - lstate.joiner = make_uniq(client, op, left_table, b1, right_table, b2); - return; - } - - // Outer joins - if (!left_outers && !right_outers) { - return; - } - - // Spin wait for regular blocks to finish(!) 
- while (completed < pair_count) { - std::this_thread::yield(); - } - - // Left outer blocks - const auto l = next_left++; - if (l < left_outers) { - lstate.joiner = nullptr; - lstate.left_block_index = l; - lstate.left_base = left_bases[l]; - - lstate.left_matches = left_table.found_match.get() + lstate.left_base; - lstate.outer_idx = 0; - lstate.outer_count = left_table.BlockSize(l); - return; - } else { - lstate.left_matches = nullptr; - } - - // Right outer block - const auto r = next_right++; - if (r < right_outers) { - lstate.joiner = nullptr; - lstate.right_block_index = r; - lstate.right_base = right_bases[r]; - - lstate.right_matches = right_table.found_match.get() + lstate.right_base; - lstate.outer_idx = 0; - lstate.outer_count = right_table.BlockSize(r); - return; - } else { - lstate.right_matches = nullptr; - } - } - - void PairCompleted(ClientContext &client, IEJoinGlobalState &gstate, IEJoinLocalSourceState &lstate) { - lstate.joiner.reset(); - ++completed; - GetNextPair(client, gstate, lstate); - } - - const PhysicalIEJoin &op; - - mutex lock; - bool initialized; - - // Join queue state - std::atomic next_pair; - std::atomic completed; - - // Block base row number - vector left_bases; - vector right_bases; - - // Outer joins - idx_t left_outers; - std::atomic next_left; - - idx_t right_outers; - std::atomic next_right; -}; - -unique_ptr PhysicalIEJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -unique_ptr PhysicalIEJoin::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(context.client, *this); -} - -SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &result, - OperatorSourceInput &input) const { - auto &ie_sink = sink_state->Cast(); - auto &ie_gstate = input.global_state.Cast(); - auto &ie_lstate = input.local_state.Cast(); - - ie_gstate.Initialize(ie_sink); - - if (!ie_lstate.joiner && !ie_lstate.left_matches && 
!ie_lstate.right_matches) { - ie_gstate.GetNextPair(context.client, ie_sink, ie_lstate); - } - - // Process INNER results - while (ie_lstate.joiner) { - ResolveComplexJoin(context, result, ie_lstate); - - if (result.size()) { - return SourceResultType::HAVE_MORE_OUTPUT; - } - - ie_gstate.PairCompleted(context.client, ie_sink, ie_lstate); - } - - // Process LEFT OUTER results - const auto left_cols = children[0]->GetTypes().size(); - while (ie_lstate.left_matches) { - const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.left_matches); - if (!count) { - ie_gstate.GetNextPair(context.client, ie_sink, ie_lstate); - continue; - } - auto &chunk = ie_lstate.unprojected; - chunk.Reset(); - SliceSortedPayload(chunk, ie_sink.tables[0]->global_sort_state, ie_lstate.left_block_index, ie_lstate.true_sel, - count); - - // Fill in NULLs to the right - for (auto col_idx = left_cols; col_idx < chunk.ColumnCount(); ++col_idx) { - chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); - } - - ProjectResult(chunk, result); - result.SetCardinality(count); - result.Verify(); - - return result.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; - } - - // Process RIGHT OUTER results - while (ie_lstate.right_matches) { - const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.right_matches); - if (!count) { - ie_gstate.GetNextPair(context.client, ie_sink, ie_lstate); - continue; - } - - auto &chunk = ie_lstate.unprojected; - chunk.Reset(); - SliceSortedPayload(chunk, ie_sink.tables[1]->global_sort_state, ie_lstate.right_block_index, ie_lstate.true_sel, - count, left_cols); - - // Fill in NULLs to the left - for (idx_t col_idx = 0; col_idx < left_cols; ++col_idx) { - chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); - } - - ProjectResult(chunk, result); - result.SetCardinality(count); - result.Verify(); - - break; - } - - return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalIEJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - D_ASSERT(children.size() == 2); - if (meta_pipeline.HasRecursiveCTE()) { - throw NotImplementedException("IEJoins are not supported in recursive CTEs yet"); - } - - // becomes a source after both children fully sink their data - meta_pipeline.GetState().SetPipelineSource(current, *this); - - // Create one child meta pipeline that will hold the LHS and RHS pipelines - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - - // Build out LHS - auto lhs_pipeline = child_meta_pipeline.GetBasePipeline(); - children[0]->BuildPipelines(*lhs_pipeline, child_meta_pipeline); - - // Build out RHS - auto rhs_pipeline = child_meta_pipeline.CreatePipeline(); - children[1]->BuildPipelines(*rhs_pipeline, child_meta_pipeline); - - // Despite having the same sink, RHS and everything 
created after it need their own (same) PipelineFinishEvent - child_meta_pipeline.AddFinishEvent(rhs_pipeline); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - -namespace duckdb { - -class IndexJoinOperatorState : public CachingOperatorState { -public: - IndexJoinOperatorState(ClientContext &context, const PhysicalIndexJoin &op) - : probe_executor(context), arena_allocator(BufferAllocator::Get(context)), keys(STANDARD_VECTOR_SIZE) { - auto &allocator = Allocator::Get(context); - rhs_rows.resize(STANDARD_VECTOR_SIZE); - result_sizes.resize(STANDARD_VECTOR_SIZE); - - join_keys.Initialize(allocator, op.condition_types); - for (auto &cond : op.conditions) { - probe_executor.AddExpression(*cond.left); - } - if (!op.fetch_types.empty()) { - rhs_chunk.Initialize(allocator, op.fetch_types); - } - rhs_sel.Initialize(STANDARD_VECTOR_SIZE); - } - - bool first_fetch = true; - idx_t lhs_idx = 0; - idx_t rhs_idx = 0; - idx_t result_size = 0; - vector result_sizes; - DataChunk join_keys; - DataChunk rhs_chunk; - SelectionVector rhs_sel; - - //! 
Vector of rows that mush be fetched for every LHS key - vector> rhs_rows; - ExpressionExecutor probe_executor; - - ArenaAllocator arena_allocator; - vector keys; - unique_ptr fetch_state; - -public: - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op, probe_executor, "probe_executor", 0); - } -}; - -PhysicalIndexJoin::PhysicalIndexJoin(LogicalOperator &op, unique_ptr left, - unique_ptr right, vector cond, JoinType join_type, - const vector &left_projection_map_p, vector right_projection_map_p, - vector column_ids_p, Index &index_p, bool lhs_first, - idx_t estimated_cardinality) - : CachingPhysicalOperator(PhysicalOperatorType::INDEX_JOIN, std::move(op.types), estimated_cardinality), - left_projection_map(left_projection_map_p), right_projection_map(std::move(right_projection_map_p)), - index(index_p), conditions(std::move(cond)), join_type(join_type), lhs_first(lhs_first) { - D_ASSERT(right->type == PhysicalOperatorType::TABLE_SCAN); - auto &tbl_scan = right->Cast(); - column_ids = std::move(column_ids_p); - children.push_back(std::move(left)); - children.push_back(std::move(right)); - for (auto &condition : conditions) { - condition_types.push_back(condition.left->return_type); - } - //! 
Only add to fetch_ids columns that are not indexed - for (auto &index_id : index.column_ids) { - index_ids.insert(index_id); - } - - for (idx_t i = 0; i < column_ids.size(); i++) { - auto column_id = column_ids[i]; - auto it = index_ids.find(column_id); - if (it == index_ids.end()) { - fetch_ids.push_back(column_id); - if (column_id == COLUMN_IDENTIFIER_ROW_ID) { - fetch_types.emplace_back(LogicalType::ROW_TYPE); - } else { - fetch_types.push_back(tbl_scan.returned_types[column_id]); - } - } - } - if (right_projection_map.empty()) { - for (column_t i = 0; i < column_ids.size(); i++) { - right_projection_map.push_back(i); - } - } - if (left_projection_map.empty()) { - for (column_t i = 0; i < children[0]->types.size(); i++) { - left_projection_map.push_back(i); - } - } -} - -unique_ptr PhysicalIndexJoin::GetOperatorState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - -void PhysicalIndexJoin::Output(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - OperatorState &state_p) const { - auto &phy_tbl_scan = children[1]->Cast(); - auto &bind_tbl = phy_tbl_scan.bind_data->Cast(); - auto &transaction = DuckTransaction::Get(context.client, bind_tbl.table.catalog); - auto &state = state_p.Cast(); - - auto &tbl = bind_tbl.table.GetStorage(); - idx_t output_sel_idx = 0; - vector fetch_rows; - - while (output_sel_idx < STANDARD_VECTOR_SIZE && state.lhs_idx < input.size()) { - if (state.rhs_idx < state.result_sizes[state.lhs_idx]) { - state.rhs_sel.set_index(output_sel_idx++, state.lhs_idx); - if (!fetch_types.empty()) { - //! We need to collect the rows we want to fetch - fetch_rows.push_back(state.rhs_rows[state.lhs_idx][state.rhs_idx]); - } - state.rhs_idx++; - } else { - //! We are done with the matches from this LHS Key - state.rhs_idx = 0; - state.lhs_idx++; - } - } - //! 
Now we fetch the RHS data - if (!fetch_types.empty()) { - if (fetch_rows.empty()) { - return; - } - state.rhs_chunk.Reset(); - state.fetch_state = make_uniq(); - Vector row_ids(LogicalType::ROW_TYPE, data_ptr_cast(&fetch_rows[0])); - tbl.Fetch(transaction, state.rhs_chunk, fetch_ids, row_ids, output_sel_idx, *state.fetch_state); - } - - //! Now we actually produce our result chunk - idx_t left_offset = lhs_first ? 0 : right_projection_map.size(); - idx_t right_offset = lhs_first ? left_projection_map.size() : 0; - idx_t rhs_column_idx = 0; - for (idx_t i = 0; i < right_projection_map.size(); i++) { - auto it = index_ids.find(column_ids[right_projection_map[i]]); - if (it == index_ids.end()) { - chunk.data[right_offset + i].Reference(state.rhs_chunk.data[rhs_column_idx++]); - } else { - chunk.data[right_offset + i].Slice(state.join_keys.data[0], state.rhs_sel, output_sel_idx); - } - } - for (idx_t i = 0; i < left_projection_map.size(); i++) { - chunk.data[left_offset + i].Slice(input.data[left_projection_map[i]], state.rhs_sel, output_sel_idx); - } - - state.result_size = output_sel_idx; - chunk.SetCardinality(state.result_size); -} - -void PhysicalIndexJoin::GetRHSMatches(ExecutionContext &context, DataChunk &input, OperatorState &state_p) const { - - auto &state = state_p.Cast(); - auto &art = index.Cast(); - - // generate the keys for this chunk - state.arena_allocator.Reset(); - ART::GenerateKeys(state.arena_allocator, state.join_keys, state.keys); - - for (idx_t i = 0; i < input.size(); i++) { - state.rhs_rows[i].clear(); - if (!state.keys[i].Empty()) { - if (fetch_types.empty()) { - IndexLock lock; - index.InitializeLock(lock); - art.SearchEqualJoinNoFetch(state.keys[i], state.result_sizes[i]); - } else { - IndexLock lock; - index.InitializeLock(lock); - art.SearchEqual(state.keys[i], (idx_t)-1, state.rhs_rows[i]); - state.result_sizes[i] = state.rhs_rows[i].size(); - } - } else { - //! 
This is null so no matches - state.result_sizes[i] = 0; - } - } - for (idx_t i = input.size(); i < STANDARD_VECTOR_SIZE; i++) { - //! No LHS chunk value so result size is empty - state.result_sizes[i] = 0; - } -} - -OperatorResultType PhysicalIndexJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast(); - - state.result_size = 0; - if (state.first_fetch) { - state.probe_executor.Execute(input, state.join_keys); - - //! Fill Matches for the current LHS chunk - GetRHSMatches(context, input, state_p); - state.first_fetch = false; - } - //! Check if we need to get a new LHS chunk - if (state.lhs_idx >= input.size()) { - state.lhs_idx = 0; - state.rhs_idx = 0; - state.first_fetch = true; - // reset the LHS chunk to reset the validity masks - state.join_keys.Reset(); - return OperatorResultType::NEED_MORE_INPUT; - } - //! Output vectors - if (state.lhs_idx < input.size()) { - Output(context, input, chunk, state_p); - } - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalIndexJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - // index join: we only continue into the LHS - // the right side is probed by the index join - // so we don't need to do anything in the pipeline with this child - meta_pipeline.GetState().AddPipelineOperator(current, *this); - children[0]->BuildPipelines(current, meta_pipeline); -} - -vector> PhysicalIndexJoin::GetSources() const { - return children[0]->GetSources(); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -PhysicalJoin::PhysicalJoin(LogicalOperator &op, PhysicalOperatorType type, JoinType join_type, - idx_t estimated_cardinality) - : CachingPhysicalOperator(type, op.types, 
estimated_cardinality), join_type(join_type) { -} - -bool PhysicalJoin::EmptyResultIfRHSIsEmpty() const { - // empty RHS with INNER, RIGHT or SEMI join means empty result set - switch (join_type) { - case JoinType::INNER: - case JoinType::RIGHT: - case JoinType::SEMI: - return true; - default: - return false; - } -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalJoin::BuildJoinPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline, PhysicalOperator &op) { - op.op_state.reset(); - op.sink_state.reset(); - - // 'current' is the probe pipeline: add this operator - auto &state = meta_pipeline.GetState(); - state.AddPipelineOperator(current, op); - - // save the last added pipeline to set up dependencies later (in case we need to add a child pipeline) - vector> pipelines_so_far; - meta_pipeline.GetPipelines(pipelines_so_far, false); - auto last_pipeline = pipelines_so_far.back().get(); - - // on the RHS (build side), we construct a child MetaPipeline with this operator as its sink - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, op); - child_meta_pipeline.Build(*op.children[1]); - - // continue building the current pipeline on the LHS (probe side) - op.children[0]->BuildPipelines(current, meta_pipeline); - - switch (op.type) { - case PhysicalOperatorType::POSITIONAL_JOIN: - // Positional joins are always outer - meta_pipeline.CreateChildPipeline(current, op, last_pipeline); - return; - case PhysicalOperatorType::CROSS_PRODUCT: - return; - default: - break; - } - - // Join can become a source operator if it's RIGHT/OUTER, or if the hash join goes out-of-core - bool add_child_pipeline = false; - auto &join_op = op.Cast(); - if (join_op.IsSource()) { - add_child_pipeline = true; - } - - if (add_child_pipeline) { - meta_pipeline.CreateChildPipeline(current, op, last_pipeline); - } -} - -void 
PhysicalJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - PhysicalJoin::BuildJoinPipelines(current, meta_pipeline, *this); -} - -vector> PhysicalJoin::GetSources() const { - auto result = children[0]->GetSources(); - if (IsSource()) { - result.push_back(*this); - } - return result; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(LogicalOperator &op, unique_ptr left, - unique_ptr right, vector cond, - JoinType join_type, idx_t estimated_cardinality) - : PhysicalComparisonJoin(op, PhysicalOperatorType::NESTED_LOOP_JOIN, std::move(cond), join_type, - estimated_cardinality) { - children.push_back(std::move(left)); - children.push_back(std::move(right)); -} - -bool PhysicalJoin::HasNullValues(DataChunk &chunk) { - for (idx_t col_idx = 0; col_idx < chunk.ColumnCount(); col_idx++) { - UnifiedVectorFormat vdata; - chunk.data[col_idx].ToUnifiedFormat(chunk.size(), vdata); - - if (vdata.validity.AllValid()) { - continue; - } - for (idx_t i = 0; i < chunk.size(); i++) { - auto idx = vdata.sel->get_index(i); - if (!vdata.validity.RowIsValid(idx)) { - return true; - } - } - } - return false; -} - -template -static void ConstructSemiOrAntiJoinResult(DataChunk &left, DataChunk &result, bool found_match[]) { - D_ASSERT(left.ColumnCount() == result.ColumnCount()); - // create the selection vector from the matches that were found - idx_t result_count = 0; - SelectionVector sel(STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < left.size(); i++) { - if (found_match[i] == MATCH) { - sel.set_index(result_count++, i); - } - } - // construct the final result - if (result_count > 0) { - // we only return the columns on the left side - // project them using the result selection vector - // reference the columns of the left side from the result - result.Slice(left, sel, result_count); - } else { - result.SetCardinality(0); - } -} - -void PhysicalJoin::ConstructSemiJoinResult(DataChunk &left, DataChunk &result, 
bool found_match[]) { - ConstructSemiOrAntiJoinResult(left, result, found_match); -} - -void PhysicalJoin::ConstructAntiJoinResult(DataChunk &left, DataChunk &result, bool found_match[]) { - ConstructSemiOrAntiJoinResult(left, result, found_match); -} - -void PhysicalJoin::ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &left, DataChunk &result, bool found_match[], - bool has_null) { - // for the initial set of columns we just reference the left side - result.SetCardinality(left); - for (idx_t i = 0; i < left.ColumnCount(); i++) { - result.data[i].Reference(left.data[i]); - } - auto &mark_vector = result.data.back(); - mark_vector.SetVectorType(VectorType::FLAT_VECTOR); - // first we set the NULL values from the join keys - // if there is any NULL in the keys, the result is NULL - auto bool_result = FlatVector::GetData(mark_vector); - auto &mask = FlatVector::Validity(mark_vector); - for (idx_t col_idx = 0; col_idx < join_keys.ColumnCount(); col_idx++) { - UnifiedVectorFormat jdata; - join_keys.data[col_idx].ToUnifiedFormat(join_keys.size(), jdata); - if (!jdata.validity.AllValid()) { - for (idx_t i = 0; i < join_keys.size(); i++) { - auto jidx = jdata.sel->get_index(i); - mask.Set(i, jdata.validity.RowIsValid(jidx)); - } - } - } - // now set the remaining entries to either true or false based on whether a match was found - if (found_match) { - for (idx_t i = 0; i < left.size(); i++) { - bool_result[i] = found_match[i]; - } - } else { - memset(bool_result, 0, sizeof(bool) * left.size()); - } - // if the right side contains NULL values, the result of any FALSE becomes NULL - if (has_null) { - for (idx_t i = 0; i < left.size(); i++) { - if (!bool_result[i]) { - mask.SetInvalid(i); - } - } - } -} - -bool PhysicalNestedLoopJoin::IsSupported(const vector &conditions, JoinType join_type) { - if (join_type == JoinType::MARK) { - return true; - } - for (auto &cond : conditions) { - if (cond.left->return_type.InternalType() == PhysicalType::STRUCT || - 
cond.left->return_type.InternalType() == PhysicalType::LIST) { - return false; - } - } - return true; -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class NestedLoopJoinLocalState : public LocalSinkState { -public: - explicit NestedLoopJoinLocalState(ClientContext &context, const vector &conditions) - : rhs_executor(context) { - vector condition_types; - for (auto &cond : conditions) { - rhs_executor.AddExpression(*cond.right); - condition_types.push_back(cond.right->return_type); - } - right_condition.Initialize(Allocator::Get(context), condition_types); - } - - //! The chunk holding the right condition - DataChunk right_condition; - //! The executor of the RHS condition - ExpressionExecutor rhs_executor; -}; - -class NestedLoopJoinGlobalState : public GlobalSinkState { -public: - explicit NestedLoopJoinGlobalState(ClientContext &context, const PhysicalNestedLoopJoin &op) - : right_payload_data(context, op.children[1]->types), right_condition_data(context, op.GetJoinTypes()), - has_null(false), right_outer(IsRightOuterJoin(op.join_type)) { - } - - mutex nj_lock; - //! Materialized data of the RHS - ColumnDataCollection right_payload_data; - //! Materialized join condition of the RHS - ColumnDataCollection right_condition_data; - //! Whether or not the RHS of the nested loop join has NULL values - atomic has_null; - //! 
A bool indicating for each tuple in the RHS if they found a match (only used in FULL OUTER JOIN) - OuterJoinMarker right_outer; -}; - -vector PhysicalNestedLoopJoin::GetJoinTypes() const { - vector result; - for (auto &op : conditions) { - result.push_back(op.right->return_type); - } - return result; -} - -SinkResultType PhysicalNestedLoopJoin::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &nlj_state = input.local_state.Cast(); - - // resolve the join expression of the right side - nlj_state.right_condition.Reset(); - nlj_state.rhs_executor.Execute(chunk, nlj_state.right_condition); - - // if we have not seen any NULL values yet, and we are performing a MARK join, check if there are NULL values in - // this chunk - if (join_type == JoinType::MARK && !gstate.has_null) { - if (HasNullValues(nlj_state.right_condition)) { - gstate.has_null = true; - } - } - - // append the payload data and the conditions - lock_guard nj_guard(gstate.nj_lock); - gstate.right_payload_data.Append(chunk); - gstate.right_condition_data.Append(nlj_state.right_condition); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalNestedLoopJoin::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &state = input.local_state.Cast(); - auto &client_profiler = QueryProfiler::Get(context.client); - - context.thread.profiler.Flush(*this, state.rhs_executor, "rhs_executor", 1); - client_profiler.Flush(context.thread.profiler); - - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalNestedLoopJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - gstate.right_outer.Initialize(gstate.right_payload_data.Count()); - if (gstate.right_payload_data.Count() == 0 && EmptyResultIfRHSIsEmpty()) { - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - 
} - return SinkFinalizeType::READY; -} - -unique_ptr PhysicalNestedLoopJoin::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr PhysicalNestedLoopJoin::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, conditions); -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class PhysicalNestedLoopJoinState : public CachingOperatorState { -public: - PhysicalNestedLoopJoinState(ClientContext &context, const PhysicalNestedLoopJoin &op, - const vector &conditions) - : fetch_next_left(true), fetch_next_right(false), lhs_executor(context), left_tuple(0), right_tuple(0), - left_outer(IsLeftOuterJoin(op.join_type)) { - vector condition_types; - for (auto &cond : conditions) { - lhs_executor.AddExpression(*cond.left); - condition_types.push_back(cond.left->return_type); - } - auto &allocator = Allocator::Get(context); - left_condition.Initialize(allocator, condition_types); - right_condition.Initialize(allocator, condition_types); - right_payload.Initialize(allocator, op.children[1]->GetTypes()); - left_outer.Initialize(STANDARD_VECTOR_SIZE); - } - - bool fetch_next_left; - bool fetch_next_right; - DataChunk left_condition; - //! 
The executor of the LHS condition - ExpressionExecutor lhs_executor; - - ColumnDataScanState condition_scan_state; - ColumnDataScanState payload_scan_state; - DataChunk right_condition; - DataChunk right_payload; - - idx_t left_tuple; - idx_t right_tuple; - - OuterJoinMarker left_outer; - -public: - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op, lhs_executor, "lhs_executor", 0); - } -}; - -unique_ptr PhysicalNestedLoopJoin::GetOperatorState(ExecutionContext &context) const { - return make_uniq(context.client, *this, conditions); -} - -OperatorResultType PhysicalNestedLoopJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, - DataChunk &chunk, GlobalOperatorState &gstate_p, - OperatorState &state_p) const { - auto &gstate = sink_state->Cast(); - - if (gstate.right_payload_data.Count() == 0) { - // empty RHS - if (!EmptyResultIfRHSIsEmpty()) { - ConstructEmptyJoinResult(join_type, gstate.has_null, input, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } else { - return OperatorResultType::FINISHED; - } - } - - switch (join_type) { - case JoinType::SEMI: - case JoinType::ANTI: - case JoinType::MARK: - // simple joins can have max STANDARD_VECTOR_SIZE matches per chunk - ResolveSimpleJoin(context, input, chunk, state_p); - return OperatorResultType::NEED_MORE_INPUT; - case JoinType::LEFT: - case JoinType::INNER: - case JoinType::OUTER: - case JoinType::RIGHT: - return ResolveComplexJoin(context, input, chunk, state_p); - default: - throw NotImplementedException("Unimplemented type for nested loop join!"); - } -} - -void PhysicalNestedLoopJoin::ResolveSimpleJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - OperatorState &state_p) const { - auto &state = state_p.Cast(); - auto &gstate = sink_state->Cast(); - - // resolve the left join condition for the current chunk - state.left_condition.Reset(); - state.lhs_executor.Execute(input, state.left_condition); - - 
bool found_match[STANDARD_VECTOR_SIZE] = {false}; - NestedLoopJoinMark::Perform(state.left_condition, gstate.right_condition_data, found_match, conditions); - switch (join_type) { - case JoinType::MARK: - // now construct the mark join result from the found matches - PhysicalJoin::ConstructMarkJoinResult(state.left_condition, input, chunk, found_match, gstate.has_null); - break; - case JoinType::SEMI: - // construct the semi join result from the found matches - PhysicalJoin::ConstructSemiJoinResult(input, chunk, found_match); - break; - case JoinType::ANTI: - // construct the anti join result from the found matches - PhysicalJoin::ConstructAntiJoinResult(input, chunk, found_match); - break; - default: - throw NotImplementedException("Unimplemented type for simple nested loop join!"); - } -} - -OperatorResultType PhysicalNestedLoopJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &input, - DataChunk &chunk, OperatorState &state_p) const { - auto &state = state_p.Cast(); - auto &gstate = sink_state->Cast(); - - idx_t match_count; - do { - if (state.fetch_next_right) { - // we exhausted the chunk on the right: move to the next chunk on the right - state.left_tuple = 0; - state.right_tuple = 0; - state.fetch_next_right = false; - // check if we exhausted all chunks on the RHS - if (gstate.right_condition_data.Scan(state.condition_scan_state, state.right_condition)) { - if (!gstate.right_payload_data.Scan(state.payload_scan_state, state.right_payload)) { - throw InternalException("Nested loop join: payload and conditions are unaligned!?"); - } - if (state.right_condition.size() != state.right_payload.size()) { - throw InternalException("Nested loop join: payload and conditions are unaligned!?"); - } - } else { - // we exhausted all chunks on the right: move to the next chunk on the left - state.fetch_next_left = true; - if (state.left_outer.Enabled()) { - // left join: before we move to the next chunk, see if we need to output any vectors that didn't - // 
have a match found - state.left_outer.ConstructLeftJoinResult(input, chunk); - state.left_outer.Reset(); - } - return OperatorResultType::NEED_MORE_INPUT; - } - } - if (state.fetch_next_left) { - // resolve the left join condition for the current chunk - state.left_condition.Reset(); - state.lhs_executor.Execute(input, state.left_condition); - - state.left_tuple = 0; - state.right_tuple = 0; - gstate.right_condition_data.InitializeScan(state.condition_scan_state); - gstate.right_condition_data.Scan(state.condition_scan_state, state.right_condition); - - gstate.right_payload_data.InitializeScan(state.payload_scan_state); - gstate.right_payload_data.Scan(state.payload_scan_state, state.right_payload); - state.fetch_next_left = false; - } - // now we have a left and a right chunk that we can join together - // note that we only get here in the case of a LEFT, INNER or FULL join - auto &left_chunk = input; - auto &right_condition = state.right_condition; - auto &right_payload = state.right_payload; - - // sanity check - left_chunk.Verify(); - right_condition.Verify(); - right_payload.Verify(); - - // now perform the join - SelectionVector lvector(STANDARD_VECTOR_SIZE), rvector(STANDARD_VECTOR_SIZE); - match_count = NestedLoopJoinInner::Perform(state.left_tuple, state.right_tuple, state.left_condition, - right_condition, lvector, rvector, conditions); - // we have finished resolving the join conditions - if (match_count > 0) { - // we have matching tuples! 
- // construct the result - state.left_outer.SetMatches(lvector, match_count); - gstate.right_outer.SetMatches(rvector, match_count, state.condition_scan_state.current_row_index); - - chunk.Slice(input, lvector, match_count); - chunk.Slice(right_payload, rvector, match_count, input.ColumnCount()); - } - - // check if we exhausted the RHS, if we did we need to move to the next right chunk in the next iteration - if (state.right_tuple >= right_condition.size()) { - state.fetch_next_right = true; - } - } while (match_count == 0); - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class NestedLoopJoinGlobalScanState : public GlobalSourceState { -public: - explicit NestedLoopJoinGlobalScanState(const PhysicalNestedLoopJoin &op) : op(op) { - D_ASSERT(op.sink_state); - auto &sink = op.sink_state->Cast(); - sink.right_outer.InitializeScan(sink.right_payload_data, scan_state); - } - - const PhysicalNestedLoopJoin &op; - OuterJoinGlobalScanState scan_state; - -public: - idx_t MaxThreads() override { - auto &sink = op.sink_state->Cast(); - return sink.right_outer.MaxThreads(); - } -}; - -class NestedLoopJoinLocalScanState : public LocalSourceState { -public: - explicit NestedLoopJoinLocalScanState(const PhysicalNestedLoopJoin &op, NestedLoopJoinGlobalScanState &gstate) { - D_ASSERT(op.sink_state); - auto &sink = op.sink_state->Cast(); - sink.right_outer.InitializeScan(gstate.scan_state, scan_state); - } - - OuterJoinLocalScanState scan_state; -}; - -unique_ptr PhysicalNestedLoopJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -unique_ptr PhysicalNestedLoopJoin::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(*this, gstate.Cast()); -} - -SourceResultType PhysicalNestedLoopJoin::GetData(ExecutionContext &context, 
DataChunk &chunk, - OperatorSourceInput &input) const { - D_ASSERT(IsRightOuterJoin(join_type)); - // check if we need to scan any unmatched tuples from the RHS for the full/right outer join - auto &sink = sink_state->Cast(); - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - // if the LHS is exhausted in a FULL/RIGHT OUTER JOIN, we scan chunks we still need to output - sink.right_outer.Scan(gstate.scan_state, lstate.scan_state, chunk); - - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -PhysicalPiecewiseMergeJoin::PhysicalPiecewiseMergeJoin(LogicalComparisonJoin &op, unique_ptr left, - unique_ptr right, vector cond, - JoinType join_type, idx_t estimated_cardinality) - : PhysicalRangeJoin(op, PhysicalOperatorType::PIECEWISE_MERGE_JOIN, std::move(left), std::move(right), - std::move(cond), join_type, estimated_cardinality) { - - for (auto &cond : conditions) { - D_ASSERT(cond.left->return_type == cond.right->return_type); - join_key_types.push_back(cond.left->return_type); - - // Convert the conditions to sort orders - auto left = cond.left->Copy(); - auto right = cond.right->Copy(); - switch (cond.comparison) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - lhs_orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_LAST, std::move(left)); - rhs_orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_LAST, std::move(right)); - break; - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - lhs_orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_LAST, std::move(left)); - rhs_orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_LAST, std::move(right)); - break; - case ExpressionType::COMPARE_NOTEQUAL: - case ExpressionType::COMPARE_DISTINCT_FROM: - // Allowed in 
multi-predicate joins, but can't be first/sort. - D_ASSERT(!lhs_orders.empty()); - lhs_orders.emplace_back(OrderType::INVALID, OrderByNullType::NULLS_LAST, std::move(left)); - rhs_orders.emplace_back(OrderType::INVALID, OrderByNullType::NULLS_LAST, std::move(right)); - break; - - default: - // COMPARE EQUAL not supported with merge join - throw NotImplementedException("Unimplemented join type for merge join"); - } - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class MergeJoinLocalState : public LocalSinkState { -public: - explicit MergeJoinLocalState(ClientContext &context, const PhysicalRangeJoin &op, const idx_t child) - : table(context, op, child) { - } - - //! The local sort state - PhysicalRangeJoin::LocalSortedTable table; -}; - -class MergeJoinGlobalState : public GlobalSinkState { -public: - using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; - -public: - MergeJoinGlobalState(ClientContext &context, const PhysicalPiecewiseMergeJoin &op) { - RowLayout rhs_layout; - rhs_layout.Initialize(op.children[1]->types); - vector rhs_order; - rhs_order.emplace_back(op.rhs_orders[0].Copy()); - table = make_uniq(context, rhs_order, rhs_layout); - } - - inline idx_t Count() const { - return table->count; - } - - void Sink(DataChunk &input, MergeJoinLocalState &lstate) { - auto &global_sort_state = table->global_sort_state; - auto &local_sort_state = lstate.table.local_sort_state; - - // Sink the data into the local sort state - lstate.table.Sink(input, global_sort_state); - - // When sorting data reaches a certain size, we sort it - if (local_sort_state.SizeInBytes() >= table->memory_per_thread) { - local_sort_state.Sort(global_sort_state, true); - } - } - - unique_ptr table; -}; - -unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr 
PhysicalPiecewiseMergeJoin::GetLocalSinkState(ExecutionContext &context) const { - // We only sink the RHS - return make_uniq(context.client, *this, 1); -} - -SinkResultType PhysicalPiecewiseMergeJoin::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - gstate.Sink(chunk, lstate); - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalPiecewiseMergeJoin::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - gstate.table->Combine(lstate.table); - auto &client_profiler = QueryProfiler::Get(context.client); - - context.thread.profiler.Flush(*this, lstate.table.executor, "rhs_executor", 1); - client_profiler.Flush(context.thread.profiler); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalPiecewiseMergeJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &global_sort_state = gstate.table->global_sort_state; - - if (IsRightOuterJoin(join_type)) { - // for FULL/RIGHT OUTER JOIN, initialize found_match to false for every tuple - gstate.table->IntializeMatches(); - } - if (global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { - // Empty input! 
- return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Sort the current input child - gstate.table->Finalize(pipeline, event); - - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class PiecewiseMergeJoinState : public CachingOperatorState { -public: - using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; - - PiecewiseMergeJoinState(ClientContext &context, const PhysicalPiecewiseMergeJoin &op, bool force_external) - : context(context), allocator(Allocator::Get(context)), op(op), - buffer_manager(BufferManager::GetBufferManager(context)), force_external(force_external), - left_outer(IsLeftOuterJoin(op.join_type)), left_position(0), first_fetch(true), finished(true), - right_position(0), right_chunk_index(0), rhs_executor(context) { - vector condition_types; - for (auto &order : op.lhs_orders) { - condition_types.push_back(order.expression->return_type); - } - left_outer.Initialize(STANDARD_VECTOR_SIZE); - lhs_layout.Initialize(op.children[0]->types); - lhs_payload.Initialize(allocator, op.children[0]->types); - - lhs_order.emplace_back(op.lhs_orders[0].Copy()); - - // Set up shared data for multiple predicates - sel.Initialize(STANDARD_VECTOR_SIZE); - condition_types.clear(); - for (auto &order : op.rhs_orders) { - rhs_executor.AddExpression(*order.expression); - condition_types.push_back(order.expression->return_type); - } - rhs_keys.Initialize(allocator, condition_types); - } - - ClientContext &context; - Allocator &allocator; - const PhysicalPiecewiseMergeJoin &op; - BufferManager &buffer_manager; - bool force_external; - - // Block sorting - DataChunk lhs_payload; - OuterJoinMarker left_outer; - vector lhs_order; - RowLayout lhs_layout; - unique_ptr lhs_local_table; - unique_ptr lhs_global_state; - unique_ptr scanner; - - // Simple scans - idx_t left_position; - - // Complex scans - bool 
first_fetch; - bool finished; - idx_t right_position; - idx_t right_chunk_index; - idx_t right_base; - idx_t prev_left_index; - - // Secondary predicate shared data - SelectionVector sel; - DataChunk rhs_keys; - DataChunk rhs_input; - ExpressionExecutor rhs_executor; - vector payload_heap_handles; - -public: - void ResolveJoinKeys(DataChunk &input) { - // sort by join key - lhs_global_state = make_uniq(buffer_manager, lhs_order, lhs_layout); - lhs_local_table = make_uniq(context, op, 0); - lhs_local_table->Sink(input, *lhs_global_state); - - // Set external (can be forced with the PRAGMA) - lhs_global_state->external = force_external; - lhs_global_state->AddLocalState(lhs_local_table->local_sort_state); - lhs_global_state->PrepareMergePhase(); - while (lhs_global_state->sorted_blocks.size() > 1) { - MergeSorter merge_sorter(*lhs_global_state, buffer_manager); - merge_sorter.PerformInMergeRound(); - lhs_global_state->CompleteMergeRound(); - } - - // Scan the sorted payload - D_ASSERT(lhs_global_state->sorted_blocks.size() == 1); - - scanner = make_uniq(*lhs_global_state->sorted_blocks[0]->payload_data, *lhs_global_state); - lhs_payload.Reset(); - scanner->Scan(lhs_payload); - - // Recompute the sorted keys from the sorted input - lhs_local_table->keys.Reset(); - lhs_local_table->executor.Execute(lhs_payload, lhs_local_table->keys); - } - - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - if (lhs_local_table) { - context.thread.profiler.Flush(op, lhs_local_table->executor, "lhs_executor", 0); - } - } -}; - -unique_ptr PhysicalPiecewiseMergeJoin::GetOperatorState(ExecutionContext &context) const { - auto &config = ClientConfig::GetConfig(context.client); - return make_uniq(context.client, *this, config.force_external); -} - -static inline idx_t SortedBlockNotNull(const idx_t base, const idx_t count, const idx_t not_null) { - return MinValue(base + count, MaxValue(base, not_null)) - base; -} - -static int 
MergeJoinComparisonValue(ExpressionType comparison) { - switch (comparison) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_GREATERTHAN: - return -1; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return 0; - default: - throw InternalException("Unimplemented comparison type for merge join!"); - } -} - -struct BlockMergeInfo { - GlobalSortState &state; - //! The block being scanned - const idx_t block_idx; - //! The number of not-NULL values in the block (they are at the end) - const idx_t not_null; - //! The current offset in the block - idx_t &entry_idx; - SelectionVector result; - - BlockMergeInfo(GlobalSortState &state, idx_t block_idx, idx_t &entry_idx, idx_t not_null) - : state(state), block_idx(block_idx), not_null(not_null), entry_idx(entry_idx), result(STANDARD_VECTOR_SIZE) { - } -}; - -static void MergeJoinPinSortingBlock(SBScanState &scan, const idx_t block_idx) { - scan.SetIndices(block_idx, 0); - scan.PinRadix(block_idx); - - auto &sd = *scan.sb->blob_sorting_data; - if (block_idx < sd.data_blocks.size()) { - scan.PinData(sd); - } -} - -static data_ptr_t MergeJoinRadixPtr(SBScanState &scan, const idx_t entry_idx) { - scan.entry_idx = entry_idx; - return scan.RadixPtr(); -} - -static idx_t MergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlobalState &rstate, bool *found_match, - const ExpressionType comparison) { - const auto cmp = MergeJoinComparisonValue(comparison); - - // The sort parameters should all be the same - auto &lsort = *lstate.lhs_global_state; - auto &rsort = rstate.table->global_sort_state; - D_ASSERT(lsort.sort_layout.all_constant == rsort.sort_layout.all_constant); - const auto all_constant = lsort.sort_layout.all_constant; - D_ASSERT(lsort.external == rsort.external); - const auto external = lsort.external; - - // There should only be one sorted block if they have been sorted - D_ASSERT(lsort.sorted_blocks.size() == 1); - SBScanState 
lread(lsort.buffer_manager, lsort); - lread.sb = lsort.sorted_blocks[0].get(); - - const idx_t l_block_idx = 0; - idx_t l_entry_idx = 0; - const auto lhs_not_null = lstate.lhs_local_table->count - lstate.lhs_local_table->has_null; - MergeJoinPinSortingBlock(lread, l_block_idx); - auto l_ptr = MergeJoinRadixPtr(lread, l_entry_idx); - - D_ASSERT(rsort.sorted_blocks.size() == 1); - SBScanState rread(rsort.buffer_manager, rsort); - rread.sb = rsort.sorted_blocks[0].get(); - - const auto cmp_size = lsort.sort_layout.comparison_size; - const auto entry_size = lsort.sort_layout.entry_size; - - idx_t right_base = 0; - for (idx_t r_block_idx = 0; r_block_idx < rread.sb->radix_sorting_data.size(); r_block_idx++) { - // we only care about the BIGGEST value in each of the RHS data blocks - // because we want to figure out if the LHS values are less than [or equal] to ANY value - // get the biggest value from the RHS chunk - MergeJoinPinSortingBlock(rread, r_block_idx); - - auto &rblock = *rread.sb->radix_sorting_data[r_block_idx]; - const auto r_not_null = - SortedBlockNotNull(right_base, rblock.count, rstate.table->count - rstate.table->has_null); - if (r_not_null == 0) { - break; - } - const auto r_entry_idx = r_not_null - 1; - right_base += rblock.count; - - auto r_ptr = MergeJoinRadixPtr(rread, r_entry_idx); - - // now we start from the current lpos value and check if we found a new value that is [<= OR <] the max RHS - // value - while (true) { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, cmp_size); - } else { - lread.entry_idx = l_entry_idx; - rread.entry_idx = r_entry_idx; - comp_res = Comparators::CompareTuple(lread, rread, l_ptr, r_ptr, lsort.sort_layout, external); - } - - if (comp_res <= cmp) { - // found a match for lpos, set it in the found_match vector - found_match[l_entry_idx] = true; - l_entry_idx++; - l_ptr += entry_size; - if (l_entry_idx >= lhs_not_null) { - // early out: we exhausted the entire LHS and they all match - return 
0; - } - } else { - // we found no match: any subsequent value from the LHS we scan now will be bigger and thus also not - // match move to the next RHS chunk - break; - } - } - } - return 0; -} - -void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - OperatorState &state_p) const { - auto &state = state_p.Cast(); - auto &gstate = sink_state->Cast(); - - state.ResolveJoinKeys(input); - auto &lhs_table = *state.lhs_local_table; - - // perform the actual join - bool found_match[STANDARD_VECTOR_SIZE]; - memset(found_match, 0, sizeof(found_match)); - MergeJoinSimpleBlocks(state, gstate, found_match, conditions[0].comparison); - - // use the sorted payload - const auto lhs_not_null = lhs_table.count - lhs_table.has_null; - auto &payload = state.lhs_payload; - - // now construct the result based on the join result - switch (join_type) { - case JoinType::MARK: { - // The only part of the join keys that is actually used is the validity mask. - // Since the payload is sorted, we can just set the tail end of the validity masks to invalid. 
- for (auto &key : lhs_table.keys.data) { - key.Flatten(lhs_table.keys.size()); - auto &mask = FlatVector::Validity(key); - if (mask.AllValid()) { - continue; - } - mask.SetAllValid(lhs_not_null); - for (idx_t i = lhs_not_null; i < lhs_table.count; ++i) { - mask.SetInvalid(i); - } - } - // So we make a set of keys that have the validity mask set for the - PhysicalJoin::ConstructMarkJoinResult(lhs_table.keys, payload, chunk, found_match, gstate.table->has_null); - break; - } - case JoinType::SEMI: - PhysicalJoin::ConstructSemiJoinResult(payload, chunk, found_match); - break; - case JoinType::ANTI: - PhysicalJoin::ConstructAntiJoinResult(payload, chunk, found_match); - break; - default: - throw NotImplementedException("Unimplemented join type for merge join"); - } -} - -static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const ExpressionType comparison, - idx_t &prev_left_index) { - const auto cmp = MergeJoinComparisonValue(comparison); - - // The sort parameters should all be the same - D_ASSERT(l.state.sort_layout.all_constant == r.state.sort_layout.all_constant); - const auto all_constant = r.state.sort_layout.all_constant; - D_ASSERT(l.state.external == r.state.external); - const auto external = l.state.external; - - // There should only be one sorted block if they have been sorted - D_ASSERT(l.state.sorted_blocks.size() == 1); - SBScanState lread(l.state.buffer_manager, l.state); - lread.sb = l.state.sorted_blocks[0].get(); - D_ASSERT(lread.sb->radix_sorting_data.size() == 1); - MergeJoinPinSortingBlock(lread, l.block_idx); - auto l_start = MergeJoinRadixPtr(lread, 0); - auto l_ptr = MergeJoinRadixPtr(lread, l.entry_idx); - - D_ASSERT(r.state.sorted_blocks.size() == 1); - SBScanState rread(r.state.buffer_manager, r.state); - rread.sb = r.state.sorted_blocks[0].get(); - - if (r.entry_idx >= r.not_null) { - return 0; - } - - MergeJoinPinSortingBlock(rread, r.block_idx); - auto r_ptr = MergeJoinRadixPtr(rread, r.entry_idx); - - const auto 
cmp_size = l.state.sort_layout.comparison_size; - const auto entry_size = l.state.sort_layout.entry_size; - - idx_t result_count = 0; - while (true) { - if (l.entry_idx < prev_left_index) { - // left side smaller: found match - l.result.set_index(result_count, sel_t(l.entry_idx)); - r.result.set_index(result_count, sel_t(r.entry_idx)); - result_count++; - // move left side forward - l.entry_idx++; - l_ptr += entry_size; - if (result_count == STANDARD_VECTOR_SIZE) { - // out of space! - break; - } - continue; - } - if (l.entry_idx < l.not_null) { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, cmp_size); - } else { - lread.entry_idx = l.entry_idx; - rread.entry_idx = r.entry_idx; - comp_res = Comparators::CompareTuple(lread, rread, l_ptr, r_ptr, l.state.sort_layout, external); - } - if (comp_res <= cmp) { - // left side smaller: found match - l.result.set_index(result_count, sel_t(l.entry_idx)); - r.result.set_index(result_count, sel_t(r.entry_idx)); - result_count++; - // move left side forward - l.entry_idx++; - l_ptr += entry_size; - if (result_count == STANDARD_VECTOR_SIZE) { - // out of space! 
- break; - } - continue; - } - } - - prev_left_index = l.entry_idx; - // right side smaller or equal, or left side exhausted: move - // right pointer forward reset left side to start - r.entry_idx++; - if (r.entry_idx >= r.not_null) { - break; - } - r_ptr += entry_size; - - l_ptr = l_start; - l.entry_idx = 0; - } - - return result_count; -} - -OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &input, - DataChunk &chunk, OperatorState &state_p) const { - auto &state = state_p.Cast(); - auto &gstate = sink_state->Cast(); - auto &rsorted = *gstate.table->global_sort_state.sorted_blocks[0]; - const auto left_cols = input.ColumnCount(); - const auto tail_cols = conditions.size() - 1; - - state.payload_heap_handles.clear(); - do { - if (state.first_fetch) { - state.ResolveJoinKeys(input); - - state.right_chunk_index = 0; - state.right_base = 0; - state.left_position = 0; - state.prev_left_index = 0; - state.right_position = 0; - state.first_fetch = false; - state.finished = false; - } - if (state.finished) { - if (state.left_outer.Enabled()) { - // left join: before we move to the next chunk, see if we need to output any vectors that didn't - // have a match found - state.left_outer.ConstructLeftJoinResult(state.lhs_payload, chunk); - state.left_outer.Reset(); - } - state.first_fetch = true; - state.finished = false; - return OperatorResultType::NEED_MORE_INPUT; - } - - auto &lhs_table = *state.lhs_local_table; - const auto lhs_not_null = lhs_table.count - lhs_table.has_null; - BlockMergeInfo left_info(*state.lhs_global_state, 0, state.left_position, lhs_not_null); - - const auto &rblock = *rsorted.radix_sorting_data[state.right_chunk_index]; - const auto rhs_not_null = - SortedBlockNotNull(state.right_base, rblock.count, gstate.table->count - gstate.table->has_null); - BlockMergeInfo right_info(gstate.table->global_sort_state, state.right_chunk_index, state.right_position, - rhs_not_null); - - idx_t result_count = - 
MergeJoinComplexBlocks(left_info, right_info, conditions[0].comparison, state.prev_left_index); - if (result_count == 0) { - // exhausted this chunk on the right side - // move to the next right chunk - state.left_position = 0; - state.right_position = 0; - state.right_base += rsorted.radix_sorting_data[state.right_chunk_index]->count; - state.right_chunk_index++; - if (state.right_chunk_index >= rsorted.radix_sorting_data.size()) { - state.finished = true; - } - } else { - // found matches: extract them - chunk.Reset(); - for (idx_t c = 0; c < state.lhs_payload.ColumnCount(); ++c) { - chunk.data[c].Slice(state.lhs_payload.data[c], left_info.result, result_count); - } - state.payload_heap_handles.push_back(SliceSortedPayload(chunk, right_info.state, right_info.block_idx, - right_info.result, result_count, left_cols)); - chunk.SetCardinality(result_count); - - auto sel = FlatVector::IncrementalSelectionVector(); - if (tail_cols) { - // If there are more expressions to compute, - // split the result chunk into the left and right halves - // so we can compute the values for comparison. 
- chunk.Split(state.rhs_input, left_cols); - state.rhs_executor.SetChunk(state.rhs_input); - state.rhs_keys.Reset(); - - auto tail_count = result_count; - for (size_t cmp_idx = 1; cmp_idx < conditions.size(); ++cmp_idx) { - Vector left(lhs_table.keys.data[cmp_idx]); - left.Slice(left_info.result, result_count); - - auto &right = state.rhs_keys.data[cmp_idx]; - state.rhs_executor.ExecuteExpression(cmp_idx, right); - - if (tail_count < result_count) { - left.Slice(*sel, tail_count); - right.Slice(*sel, tail_count); - } - tail_count = - SelectJoinTail(conditions[cmp_idx].comparison, left, right, sel, tail_count, &state.sel); - sel = &state.sel; - } - chunk.Fuse(state.rhs_input); - - if (tail_count < result_count) { - result_count = tail_count; - chunk.Slice(*sel, result_count); - } - } - - // found matches: mark the found matches if required - if (state.left_outer.Enabled()) { - for (idx_t i = 0; i < result_count; i++) { - state.left_outer.SetMatch(left_info.result[sel->get_index(i)]); - } - } - if (gstate.table->found_match) { - // Absolute position of the block + start position inside that block - for (idx_t i = 0; i < result_count; i++) { - gstate.table->found_match[state.right_base + right_info.result[sel->get_index(i)]] = true; - } - } - chunk.SetCardinality(result_count); - chunk.Verify(); - } - } while (chunk.size() == 0); - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -OperatorResultType PhysicalPiecewiseMergeJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, - DataChunk &chunk, GlobalOperatorState &gstate_p, - OperatorState &state) const { - auto &gstate = sink_state->Cast(); - - if (gstate.Count() == 0) { - // empty RHS - if (!EmptyResultIfRHSIsEmpty()) { - ConstructEmptyJoinResult(join_type, gstate.table->has_null, input, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } else { - return OperatorResultType::FINISHED; - } - } - - input.Verify(); - switch (join_type) { - case JoinType::SEMI: - case JoinType::ANTI: - case 
JoinType::MARK: - // simple joins can have max STANDARD_VECTOR_SIZE matches per chunk - ResolveSimpleJoin(context, input, chunk, state); - return OperatorResultType::NEED_MORE_INPUT; - case JoinType::LEFT: - case JoinType::INNER: - case JoinType::RIGHT: - case JoinType::OUTER: - return ResolveComplexJoin(context, input, chunk, state); - default: - throw NotImplementedException("Unimplemented type for piecewise merge loop join!"); - } -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class PiecewiseJoinScanState : public GlobalSourceState { -public: - explicit PiecewiseJoinScanState(const PhysicalPiecewiseMergeJoin &op) : op(op), right_outer_position(0) { - } - - mutex lock; - const PhysicalPiecewiseMergeJoin &op; - unique_ptr scanner; - idx_t right_outer_position; - -public: - idx_t MaxThreads() override { - auto &sink = op.sink_state->Cast(); - return sink.Count() / (STANDARD_VECTOR_SIZE * idx_t(10)); - } -}; - -unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -SourceResultType PhysicalPiecewiseMergeJoin::GetData(ExecutionContext &context, DataChunk &result, - OperatorSourceInput &input) const { - D_ASSERT(IsRightOuterJoin(join_type)); - // check if we need to scan any unmatched tuples from the RHS for the full/right outer join - auto &sink = sink_state->Cast(); - auto &state = input.global_state.Cast(); - - lock_guard l(state.lock); - if (!state.scanner) { - // Initialize scanner (if not yet initialized) - auto &sort_state = sink.table->global_sort_state; - if (sort_state.sorted_blocks.empty()) { - return SourceResultType::FINISHED; - } - state.scanner = make_uniq(*sort_state.sorted_blocks[0]->payload_data, sort_state); - } - - // if the LHS is exhausted in a FULL/RIGHT OUTER JOIN, we scan the found_match for any chunks we - // still need to output - const auto 
found_match = sink.table->found_match.get(); - - DataChunk rhs_chunk; - rhs_chunk.Initialize(Allocator::Get(context.client), sink.table->global_sort_state.payload_layout.GetTypes()); - SelectionVector rsel(STANDARD_VECTOR_SIZE); - for (;;) { - // Read the next sorted chunk - state.scanner->Scan(rhs_chunk); - - const auto count = rhs_chunk.size(); - if (count == 0) { - return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; - } - - idx_t result_count = 0; - // figure out which tuples didn't find a match in the RHS - for (idx_t i = 0; i < count; i++) { - if (!found_match[state.right_outer_position + i]) { - rsel.set_index(result_count++, i); - } - } - state.right_outer_position += count; - - if (result_count > 0) { - // if there were any tuples that didn't find a match, output them - const idx_t left_column_count = children[0]->types.size(); - for (idx_t col_idx = 0; col_idx < left_column_count; ++col_idx) { - result.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result.data[col_idx], true); - } - const idx_t right_column_count = children[1]->types.size(); - ; - for (idx_t col_idx = 0; col_idx < right_column_count; ++col_idx) { - result.data[left_column_count + col_idx].Slice(rhs_chunk.data[col_idx], rsel, result_count); - } - result.SetCardinality(result_count); - break; - } - } - - return result.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -PhysicalPositionalJoin::PhysicalPositionalJoin(vector types, unique_ptr left, - unique_ptr right, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::POSITIONAL_JOIN, std::move(types), estimated_cardinality) { - children.push_back(std::move(left)); - children.push_back(std::move(right)); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class PositionalJoinGlobalState : public GlobalSinkState { -public: - explicit PositionalJoinGlobalState(ClientContext &context, const PhysicalPositionalJoin &op) - : rhs(context, op.children[1]->GetTypes()), initialized(false), source_offset(0), exhausted(false) { - rhs.InitializeAppend(append_state); - } - - ColumnDataCollection rhs; - ColumnDataAppendState append_state; - mutex rhs_lock; - - bool initialized; - ColumnDataScanState scan_state; - DataChunk source; - idx_t source_offset; - bool exhausted; - - void InitializeScan(); - idx_t Refill(); - idx_t CopyData(DataChunk &output, const idx_t count, const idx_t col_offset); - void Execute(DataChunk &input, DataChunk &output); - void GetData(DataChunk &output); -}; - -unique_ptr PhysicalPositionalJoin::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -SinkResultType PhysicalPositionalJoin::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &sink = input.global_state.Cast(); - lock_guard client_guard(sink.rhs_lock); - sink.rhs.Append(sink.append_state, chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -void PositionalJoinGlobalState::InitializeScan() { - if 
(!initialized) { - // not initialized yet: initialize the scan - initialized = true; - rhs.InitializeScanChunk(source); - rhs.InitializeScan(scan_state); - } -} - -idx_t PositionalJoinGlobalState::Refill() { - if (source_offset >= source.size()) { - if (!exhausted) { - source.Reset(); - rhs.Scan(scan_state, source); - } - source_offset = 0; - } - - const auto available = source.size() - source_offset; - if (!available) { - if (!exhausted) { - source.Reset(); - for (idx_t i = 0; i < source.ColumnCount(); ++i) { - auto &vec = source.data[i]; - vec.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(vec, true); - } - exhausted = true; - } - } - - return available; -} - -idx_t PositionalJoinGlobalState::CopyData(DataChunk &output, const idx_t count, const idx_t col_offset) { - if (!source_offset && (source.size() >= count || exhausted)) { - // Fast track: aligned and has enough data - for (idx_t i = 0; i < source.ColumnCount(); ++i) { - output.data[col_offset + i].Reference(source.data[i]); - } - source_offset += count; - } else { - // Copy data - for (idx_t target_offset = 0; target_offset < count;) { - const auto needed = count - target_offset; - const auto available = exhausted ? 
needed : (source.size() - source_offset); - const auto copy_size = MinValue(needed, available); - const auto source_count = source_offset + copy_size; - for (idx_t i = 0; i < source.ColumnCount(); ++i) { - VectorOperations::Copy(source.data[i], output.data[col_offset + i], source_count, source_offset, - target_offset); - } - target_offset += copy_size; - source_offset += copy_size; - Refill(); - } - } - - return source.ColumnCount(); -} - -void PositionalJoinGlobalState::Execute(DataChunk &input, DataChunk &output) { - lock_guard client_guard(rhs_lock); - - // Reference the input and assume it will be full - const auto col_offset = input.ColumnCount(); - for (idx_t i = 0; i < col_offset; ++i) { - output.data[i].Reference(input.data[i]); - } - - // Copy or reference the RHS columns - const auto count = input.size(); - InitializeScan(); - Refill(); - CopyData(output, count, col_offset); - - output.SetCardinality(count); -} - -OperatorResultType PhysicalPositionalJoin::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &sink = sink_state->Cast(); - sink.Execute(input, chunk); - return OperatorResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -void PositionalJoinGlobalState::GetData(DataChunk &output) { - lock_guard client_guard(rhs_lock); - - InitializeScan(); - Refill(); - - // LHS exhausted - if (exhausted) { - // RHS exhausted too, so we are done - output.SetCardinality(0); - return; - } - - // LHS is all NULL - const auto col_offset = output.ColumnCount() - source.ColumnCount(); - for (idx_t i = 0; i < col_offset; ++i) { - auto &vec = output.data[i]; - vec.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(vec, true); - } - - // RHS still has data, so copy it - const auto count = 
MinValue(STANDARD_VECTOR_SIZE, source.size() - source_offset); - CopyData(output, count, col_offset); - output.SetCardinality(count); -} - -SourceResultType PhysicalPositionalJoin::GetData(ExecutionContext &context, DataChunk &result, - OperatorSourceInput &input) const { - auto &sink = sink_state->Cast(); - sink.GetData(result); - - return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalPositionalJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - PhysicalJoin::BuildJoinPipelines(current, meta_pipeline, *this); -} - -vector> PhysicalPositionalJoin::GetSources() const { - auto result = children[0]->GetSources(); - if (IsSource()) { - result.push_back(*this); - } - return result; -} - -} // namespace duckdb - - - - - - - - - - - - - -#include - -namespace duckdb { - -PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op, - const idx_t child) - : op(op), executor(context), has_null(0), count(0) { - // Initialize order clause expression executor and key DataChunk - vector types; - for (const auto &cond : op.conditions) { - const auto &expr = child ? 
cond.right : cond.left; - executor.AddExpression(*expr); - - types.push_back(expr->return_type); - } - auto &allocator = Allocator::Get(context); - keys.Initialize(allocator, types); -} - -void PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState &global_sort_state) { - // Initialize local state (if necessary) - if (!local_sort_state.initialized) { - local_sort_state.Initialize(global_sort_state, global_sort_state.buffer_manager); - } - - // Obtain sorting columns - keys.Reset(); - executor.Execute(input, keys); - - // Count the NULLs so we can exclude them later - has_null += MergeNulls(op.conditions); - count += keys.size(); - - // Only sort the primary key - DataChunk join_head; - join_head.data.emplace_back(keys.data[0]); - join_head.SetCardinality(keys.size()); - - // Sink the data into the local sort state - local_sort_state.SinkChunk(join_head, input); -} - -PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &context, const vector &orders, - RowLayout &payload_layout) - : global_sort_state(BufferManager::GetBufferManager(context), orders, payload_layout), has_null(0), count(0), - memory_per_thread(0) { - D_ASSERT(orders.size() == 1); - - // Set external (can be forced with the PRAGMA) - auto &config = ClientConfig::GetConfig(context); - global_sort_state.external = config.force_external; - memory_per_thread = PhysicalRangeJoin::GetMaxThreadMemory(context); -} - -void PhysicalRangeJoin::GlobalSortedTable::Combine(LocalSortedTable <able) { - global_sort_state.AddLocalState(ltable.local_sort_state); - has_null += ltable.has_null; - count += ltable.count; -} - -void PhysicalRangeJoin::GlobalSortedTable::IntializeMatches() { - found_match = make_unsafe_uniq_array(Count()); - memset(found_match.get(), 0, sizeof(bool) * Count()); -} - -void PhysicalRangeJoin::GlobalSortedTable::Print() { - global_sort_state.Print(); -} - -class RangeJoinMergeTask : public ExecutorTask { -public: - using GlobalSortedTable = 
PhysicalRangeJoin::GlobalSortedTable; - -public: - RangeJoinMergeTask(shared_ptr event_p, ClientContext &context, GlobalSortedTable &table) - : ExecutorTask(context), event(std::move(event_p)), context(context), table(table) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - // Initialize iejoin sorted and iterate until done - auto &global_sort_state = table.global_sort_state; - MergeSorter merge_sorter(global_sort_state, BufferManager::GetBufferManager(context)); - merge_sorter.PerformInMergeRound(); - event->FinishTask(); - - return TaskExecutionResult::TASK_FINISHED; - } - -private: - shared_ptr event; - ClientContext &context; - GlobalSortedTable &table; -}; - -class RangeJoinMergeEvent : public BasePipelineEvent { -public: - using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; - -public: - RangeJoinMergeEvent(GlobalSortedTable &table_p, Pipeline &pipeline_p) - : BasePipelineEvent(pipeline_p), table(table_p) { - } - - GlobalSortedTable &table; - -public: - void Schedule() override { - auto &context = pipeline->GetClientContext(); - - // Schedule tasks equal to the number of threads, which will each merge multiple partitions - auto &ts = TaskScheduler::GetScheduler(context); - idx_t num_threads = ts.NumberOfThreads(); - - vector> iejoin_tasks; - for (idx_t tnum = 0; tnum < num_threads; tnum++) { - iejoin_tasks.push_back(make_uniq(shared_from_this(), context, table)); - } - SetTasks(std::move(iejoin_tasks)); - } - - void FinishEvent() override { - auto &global_sort_state = table.global_sort_state; - - global_sort_state.CompleteMergeRound(true); - if (global_sort_state.sorted_blocks.size() > 1) { - // Multiple blocks remaining: Schedule the next round - table.ScheduleMergeTasks(*pipeline, *this); - } - } -}; - -void PhysicalRangeJoin::GlobalSortedTable::ScheduleMergeTasks(Pipeline &pipeline, Event &event) { - // Initialize global sort state for a round of merging - global_sort_state.InitializeMergeRound(); - auto new_event = 
make_shared(*this, pipeline); - event.InsertEvent(std::move(new_event)); -} - -void PhysicalRangeJoin::GlobalSortedTable::Finalize(Pipeline &pipeline, Event &event) { - // Prepare for merge sort phase - global_sort_state.PrepareMergePhase(); - - // Start the merge phase or finish if a merge is not necessary - if (global_sort_state.sorted_blocks.size() > 1) { - ScheduleMergeTasks(pipeline, event); - } -} - -PhysicalRangeJoin::PhysicalRangeJoin(LogicalComparisonJoin &op, PhysicalOperatorType type, - unique_ptr left, unique_ptr right, - vector cond, JoinType join_type, idx_t estimated_cardinality) - : PhysicalComparisonJoin(op, type, std::move(cond), join_type, estimated_cardinality) { - // Reorder the conditions so that ranges are at the front. - // TODO: use stats to improve the choice? - // TODO: Prefer fixed length types? - if (conditions.size() > 1) { - vector conditions_p(conditions.size()); - std::swap(conditions_p, conditions); - idx_t range_position = 0; - idx_t other_position = conditions_p.size(); - for (idx_t i = 0; i < conditions_p.size(); ++i) { - switch (conditions_p[i].comparison) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - conditions[range_position++] = std::move(conditions_p[i]); - break; - default: - conditions[--other_position] = std::move(conditions_p[i]); - break; - } - } - } - - children.push_back(std::move(left)); - children.push_back(std::move(right)); - - // Fill out the left projection map. - left_projection_map = op.left_projection_map; - if (left_projection_map.empty()) { - const auto left_count = children[0]->types.size(); - left_projection_map.reserve(left_count); - for (column_t i = 0; i < left_count; ++i) { - left_projection_map.emplace_back(i); - } - } - // Fill out the right projection map. 
- right_projection_map = op.right_projection_map; - if (right_projection_map.empty()) { - const auto right_count = children[1]->types.size(); - right_projection_map.reserve(right_count); - for (column_t i = 0; i < right_count; ++i) { - right_projection_map.emplace_back(i); - } - } - - // Construct the unprojected type layout from the children's types - unprojected_types = children[0]->GetTypes(); - auto &types = children[1]->GetTypes(); - unprojected_types.insert(unprojected_types.end(), types.begin(), types.end()); -} - -idx_t PhysicalRangeJoin::LocalSortedTable::MergeNulls(const vector &conditions) { - // Merge the validity masks of the comparison keys into the primary - // Return the number of NULLs in the resulting chunk - D_ASSERT(keys.ColumnCount() > 0); - const auto count = keys.size(); - - size_t all_constant = 0; - for (auto &v : keys.data) { - if (v.GetVectorType() == VectorType::CONSTANT_VECTOR) { - ++all_constant; - } - } - - auto &primary = keys.data[0]; - if (all_constant == keys.data.size()) { - // Either all NULL or no NULLs - for (auto &v : keys.data) { - if (ConstantVector::IsNull(v)) { - ConstantVector::SetNull(primary, true); - return count; - } - } - return 0; - } else if (keys.ColumnCount() > 1) { - // Flatten the primary, as it will need to merge arbitrary validity masks - primary.Flatten(count); - auto &pvalidity = FlatVector::Validity(primary); - D_ASSERT(keys.ColumnCount() == conditions.size()); - for (size_t c = 1; c < keys.data.size(); ++c) { - // Skip comparisons that accept NULLs - if (conditions[c].comparison == ExpressionType::COMPARE_DISTINCT_FROM) { - continue; - } - // ToUnifiedFormat the rest, as the sort code will do this anyway. 
- auto &v = keys.data[c]; - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(count, vdata); - auto &vvalidity = vdata.validity; - if (vvalidity.AllValid()) { - continue; - } - pvalidity.EnsureWritable(); - switch (v.GetVectorType()) { - case VectorType::FLAT_VECTOR: { - // Merge entire entries - auto pmask = pvalidity.GetData(); - const auto entry_count = pvalidity.EntryCount(count); - for (idx_t entry_idx = 0; entry_idx < entry_count; ++entry_idx) { - pmask[entry_idx] &= vvalidity.GetValidityEntry(entry_idx); - } - break; - } - case VectorType::CONSTANT_VECTOR: - // All or nothing - if (ConstantVector::IsNull(v)) { - pvalidity.SetAllInvalid(count); - return count; - } - break; - default: - // One by one - for (idx_t i = 0; i < count; ++i) { - const auto idx = vdata.sel->get_index(i); - if (!vvalidity.RowIsValidUnsafe(idx)) { - pvalidity.SetInvalidUnsafe(i); - } - } - break; - } - } - return count - pvalidity.CountValid(count); - } else { - return count - VectorOperations::CountNotNull(primary, count); - } -} - -void PhysicalRangeJoin::ProjectResult(DataChunk &chunk, DataChunk &result) const { - const auto left_projected = left_projection_map.size(); - for (idx_t i = 0; i < left_projected; ++i) { - result.data[i].Reference(chunk.data[left_projection_map[i]]); - } - const auto left_width = children[0]->types.size(); - for (idx_t i = 0; i < right_projection_map.size(); ++i) { - result.data[left_projected + i].Reference(chunk.data[left_width + right_projection_map[i]]); - } - result.SetCardinality(chunk); -} - -BufferHandle PhysicalRangeJoin::SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx, - const SelectionVector &result, const idx_t result_count, - const idx_t left_cols) { - // There should only be one sorted block if they have been sorted - D_ASSERT(state.sorted_blocks.size() == 1); - SBScanState read_state(state.buffer_manager, state); - read_state.sb = state.sorted_blocks[0].get(); - auto &sorted_data = 
*read_state.sb->payload_data; - - read_state.SetIndices(block_idx, 0); - read_state.PinData(sorted_data); - const auto data_ptr = read_state.DataPtr(sorted_data); - data_ptr_t heap_ptr = nullptr; - - // Set up a batch of pointers to scan data from - Vector addresses(LogicalType::POINTER, result_count); - auto data_pointers = FlatVector::GetData(addresses); - - // Set up the data pointers for the values that are actually referenced - const idx_t &row_width = sorted_data.layout.GetRowWidth(); - - auto prev_idx = result.get_index(0); - SelectionVector gsel(result_count); - idx_t addr_count = 0; - gsel.set_index(0, addr_count); - data_pointers[addr_count] = data_ptr + prev_idx * row_width; - for (idx_t i = 1; i < result_count; ++i) { - const auto row_idx = result.get_index(i); - if (row_idx != prev_idx) { - data_pointers[++addr_count] = data_ptr + row_idx * row_width; - prev_idx = row_idx; - } - gsel.set_index(i, addr_count); - } - ++addr_count; - - // Unswizzle the offsets back to pointers (if needed) - if (!sorted_data.layout.AllConstant() && state.external) { - heap_ptr = read_state.payload_heap_handle.Ptr(); - } - - // Deserialize the payload data - auto sel = FlatVector::IncrementalSelectionVector(); - for (idx_t col_no = 0; col_no < sorted_data.layout.ColumnCount(); col_no++) { - auto &col = payload.data[left_cols + col_no]; - RowOperations::Gather(addresses, *sel, col, *sel, addr_count, sorted_data.layout, col_no, 0, heap_ptr); - col.Slice(gsel, result_count); - } - - return std::move(read_state.payload_heap_handle); -} - -idx_t PhysicalRangeJoin::SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, - const SelectionVector *sel, idx_t count, SelectionVector *true_sel) { - switch (condition) { - case ExpressionType::COMPARE_NOTEQUAL: - return VectorOperations::NotEquals(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_LESSTHAN: - return VectorOperations::LessThan(left, right, sel, count, true_sel, nullptr); - 
case ExpressionType::COMPARE_GREATERTHAN: - return VectorOperations::GreaterThan(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return VectorOperations::LessThanEquals(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_DISTINCT_FROM: - return VectorOperations::DistinctFrom(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - return VectorOperations::NotDistinctFrom(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_EQUAL: - return VectorOperations::Equals(left, right, sel, count, true_sel, nullptr); - default: - throw InternalException("Unsupported comparison type for PhysicalRangeJoin"); - } - - return count; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -PhysicalOrder::PhysicalOrder(vector types, vector orders, vector projections, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::ORDER_BY, std::move(types), estimated_cardinality), - orders(std::move(orders)), projections(std::move(projections)) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class OrderGlobalSinkState : public GlobalSinkState { -public: - OrderGlobalSinkState(BufferManager &buffer_manager, const PhysicalOrder &order, RowLayout &payload_layout) - : global_sort_state(buffer_manager, order.orders, payload_layout) { - } - - //! Global sort state - GlobalSortState global_sort_state; - //! 
Memory usage per thread - idx_t memory_per_thread; -}; - -class OrderLocalSinkState : public LocalSinkState { -public: - OrderLocalSinkState(ClientContext &context, const PhysicalOrder &op) : key_executor(context) { - // Initialize order clause expression executor and DataChunk - vector key_types; - for (auto &order : op.orders) { - key_types.push_back(order.expression->return_type); - key_executor.AddExpression(*order.expression); - } - auto &allocator = Allocator::Get(context); - keys.Initialize(allocator, key_types); - payload.Initialize(allocator, op.types); - } - -public: - //! The local sort state - LocalSortState local_sort_state; - //! Key expression executor, and chunk to hold the vectors - ExpressionExecutor key_executor; - DataChunk keys; - //! Payload chunk to hold the vectors - DataChunk payload; -}; - -unique_ptr PhysicalOrder::GetGlobalSinkState(ClientContext &context) const { - // Get the payload layout from the return types - RowLayout payload_layout; - payload_layout.Initialize(types); - auto state = make_uniq(BufferManager::GetBufferManager(context), *this, payload_layout); - // Set external (can be force with the PRAGMA) - state->global_sort_state.external = ClientConfig::GetConfig(context).force_external; - state->memory_per_thread = GetMaxThreadMemory(context); - return std::move(state); -} - -unique_ptr PhysicalOrder::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - -SinkResultType PhysicalOrder::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - auto &global_sort_state = gstate.global_sort_state; - auto &local_sort_state = lstate.local_sort_state; - - // Initialize local state (if necessary) - if (!local_sort_state.initialized) { - local_sort_state.Initialize(global_sort_state, BufferManager::GetBufferManager(context.client)); - } - - // Obtain sorting columns - auto 
&keys = lstate.keys; - keys.Reset(); - lstate.key_executor.Execute(chunk, keys); - - auto &payload = lstate.payload; - payload.ReferenceColumns(chunk, projections); - - // Sink the data into the local sort state - keys.Verify(); - chunk.Verify(); - local_sort_state.SinkChunk(keys, payload); - - // When sorting data reaches a certain size, we sort it - if (local_sort_state.SizeInBytes() >= gstate.memory_per_thread) { - local_sort_state.Sort(global_sort_state, true); - } - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalOrder::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - gstate.global_sort_state.AddLocalState(lstate.local_sort_state); - - return SinkCombineResultType::FINISHED; -} - -class PhysicalOrderMergeTask : public ExecutorTask { -public: - PhysicalOrderMergeTask(shared_ptr event_p, ClientContext &context, OrderGlobalSinkState &state) - : ExecutorTask(context), event(std::move(event_p)), context(context), state(state) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - // Initialize merge sorted and iterate until done - auto &global_sort_state = state.global_sort_state; - MergeSorter merge_sorter(global_sort_state, BufferManager::GetBufferManager(context)); - merge_sorter.PerformInMergeRound(); - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; - } - -private: - shared_ptr event; - ClientContext &context; - OrderGlobalSinkState &state; -}; - -class OrderMergeEvent : public BasePipelineEvent { -public: - OrderMergeEvent(OrderGlobalSinkState &gstate_p, Pipeline &pipeline_p) - : BasePipelineEvent(pipeline_p), gstate(gstate_p) { - } - - OrderGlobalSinkState &gstate; - -public: - void Schedule() override { - auto &context = pipeline->GetClientContext(); - - // Schedule tasks equal to the number of threads, which will each merge multiple partitions - auto &ts = 
TaskScheduler::GetScheduler(context); - idx_t num_threads = ts.NumberOfThreads(); - - vector> merge_tasks; - for (idx_t tnum = 0; tnum < num_threads; tnum++) { - merge_tasks.push_back(make_uniq(shared_from_this(), context, gstate)); - } - SetTasks(std::move(merge_tasks)); - } - - void FinishEvent() override { - auto &global_sort_state = gstate.global_sort_state; - - global_sort_state.CompleteMergeRound(); - if (global_sort_state.sorted_blocks.size() > 1) { - // Multiple blocks remaining: Schedule the next round - PhysicalOrder::ScheduleMergeTasks(*pipeline, *this, gstate); - } - } -}; - -SinkFinalizeType PhysicalOrder::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &state = input.global_state.Cast(); - auto &global_sort_state = state.global_sort_state; - - if (global_sort_state.sorted_blocks.empty()) { - // Empty input! - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Prepare for merge sort phase - global_sort_state.PrepareMergePhase(); - - // Start the merge phase or finish if a merge is not necessary - if (global_sort_state.sorted_blocks.size() > 1) { - PhysicalOrder::ScheduleMergeTasks(pipeline, event, state); - } - return SinkFinalizeType::READY; -} - -void PhysicalOrder::ScheduleMergeTasks(Pipeline &pipeline, Event &event, OrderGlobalSinkState &state) { - // Initialize global sort state for a round of merging - state.global_sort_state.InitializeMergeRound(); - auto new_event = make_shared(state, pipeline); - event.InsertEvent(std::move(new_event)); -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class PhysicalOrderGlobalSourceState : public GlobalSourceState { -public: - explicit PhysicalOrderGlobalSourceState(OrderGlobalSinkState &sink) : next_batch_index(0) { - auto &global_sort_state = sink.global_sort_state; - if 
(global_sort_state.sorted_blocks.empty()) { - total_batches = 0; - } else { - D_ASSERT(global_sort_state.sorted_blocks.size() == 1); - total_batches = global_sort_state.sorted_blocks[0]->payload_data->data_blocks.size(); - } - } - - idx_t MaxThreads() override { - return total_batches; - } - -public: - atomic next_batch_index; - idx_t total_batches; -}; - -unique_ptr PhysicalOrder::GetGlobalSourceState(ClientContext &context) const { - auto &sink = this->sink_state->Cast(); - return make_uniq(sink); -} - -class PhysicalOrderLocalSourceState : public LocalSourceState { -public: - explicit PhysicalOrderLocalSourceState(PhysicalOrderGlobalSourceState &gstate) - : batch_index(gstate.next_batch_index++) { - } - -public: - idx_t batch_index; - unique_ptr scanner; -}; - -unique_ptr PhysicalOrder::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - return make_uniq(gstate); -} - -SourceResultType PhysicalOrder::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - if (lstate.scanner && lstate.scanner->Remaining() == 0) { - lstate.batch_index = gstate.next_batch_index++; - lstate.scanner = nullptr; - } - - if (lstate.batch_index >= gstate.total_batches) { - return SourceResultType::FINISHED; - } - - if (!lstate.scanner) { - auto &sink = this->sink_state->Cast(); - auto &global_sort_state = sink.global_sort_state; - lstate.scanner = make_uniq(global_sort_state, lstate.batch_index, true); - } - - lstate.scanner->Scan(chunk); - - return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -idx_t PhysicalOrder::GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate_p, - LocalSourceState &lstate_p) const { - auto &lstate = lstate_p.Cast(); - return lstate.batch_index; -} - -string PhysicalOrder::ParamsToString() const { - string result = "ORDERS:\n"; - for (idx_t i = 0; i < orders.size(); i++) { - if (i > 0) { - result += "\n"; - } - result += orders[i].expression->ToString() + " "; - result += orders[i].type == OrderType::DESCENDING ? "DESC" : "ASC"; - } - return result; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -PhysicalTopN::PhysicalTopN(vector types, vector orders, idx_t limit, idx_t offset, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::TOP_N, std::move(types), estimated_cardinality), orders(std::move(orders)), - limit(limit), offset(offset) { -} - -//===--------------------------------------------------------------------===// -// Heaps -//===--------------------------------------------------------------------===// -class TopNHeap; - -struct TopNScanState { - unique_ptr scanner; - idx_t pos; - bool exclude_offset; -}; - -class TopNSortState { -public: - explicit TopNSortState(TopNHeap &heap); - - TopNHeap &heap; - unique_ptr local_state; - unique_ptr global_state; - idx_t count; - bool is_sorted; - -public: - void Initialize(); - void Append(DataChunk &sort_chunk, DataChunk &payload); - - void Sink(DataChunk &input); - void Finalize(); - - void Move(TopNSortState &other); - - void InitializeScan(TopNScanState &state, bool exclude_offset); - void Scan(TopNScanState &state, DataChunk &chunk); -}; - -class TopNHeap { -public: - TopNHeap(ClientContext &context, const vector &payload_types, const vector &orders, - idx_t limit, idx_t offset); - TopNHeap(ExecutionContext &context, const vector &payload_types, - const vector &orders, idx_t limit, idx_t offset); - TopNHeap(ClientContext &context, 
Allocator &allocator, const vector &payload_types, - const vector &orders, idx_t limit, idx_t offset); - - Allocator &allocator; - BufferManager &buffer_manager; - const vector &payload_types; - const vector &orders; - idx_t limit; - idx_t offset; - TopNSortState sort_state; - ExpressionExecutor executor; - DataChunk sort_chunk; - DataChunk compare_chunk; - DataChunk payload_chunk; - //! A set of boundary values that determine either the minimum or the maximum value we have to consider for our - //! top-n - DataChunk boundary_values; - //! Whether or not the boundary_values has been set. The boundary_values are only set after a reduce step - bool has_boundary_values; - - SelectionVector final_sel; - SelectionVector true_sel; - SelectionVector false_sel; - SelectionVector new_remaining_sel; - -public: - void Sink(DataChunk &input); - void Combine(TopNHeap &other); - void Reduce(); - void Finalize(); - - void ExtractBoundaryValues(DataChunk ¤t_chunk, DataChunk &prev_chunk); - - void InitializeScan(TopNScanState &state, bool exclude_offset); - void Scan(TopNScanState &state, DataChunk &chunk); - - bool CheckBoundaryValues(DataChunk &sort_chunk, DataChunk &payload); -}; - -//===--------------------------------------------------------------------===// -// TopNSortState -//===--------------------------------------------------------------------===// -TopNSortState::TopNSortState(TopNHeap &heap) : heap(heap), count(0), is_sorted(false) { -} - -void TopNSortState::Initialize() { - RowLayout layout; - layout.Initialize(heap.payload_types); - auto &buffer_manager = heap.buffer_manager; - global_state = make_uniq(buffer_manager, heap.orders, layout); - local_state = make_uniq(); - local_state->Initialize(*global_state, buffer_manager); -} - -void TopNSortState::Append(DataChunk &sort_chunk, DataChunk &payload) { - D_ASSERT(!is_sorted); - if (heap.has_boundary_values) { - if (!heap.CheckBoundaryValues(sort_chunk, payload)) { - return; - } - } - - 
local_state->SinkChunk(sort_chunk, payload); - count += payload.size(); -} - -void TopNSortState::Sink(DataChunk &input) { - // compute the ordering values for the new chunk - heap.sort_chunk.Reset(); - heap.executor.Execute(input, heap.sort_chunk); - - // append the new chunk to what we have already - Append(heap.sort_chunk, input); -} - -void TopNSortState::Move(TopNSortState &other) { - local_state = std::move(other.local_state); - global_state = std::move(other.global_state); - count = other.count; - is_sorted = other.is_sorted; -} - -void TopNSortState::Finalize() { - D_ASSERT(!is_sorted); - global_state->AddLocalState(*local_state); - - global_state->PrepareMergePhase(); - while (global_state->sorted_blocks.size() > 1) { - MergeSorter merge_sorter(*global_state, heap.buffer_manager); - merge_sorter.PerformInMergeRound(); - global_state->CompleteMergeRound(); - } - is_sorted = true; -} - -void TopNSortState::InitializeScan(TopNScanState &state, bool exclude_offset) { - D_ASSERT(is_sorted); - if (global_state->sorted_blocks.empty()) { - state.scanner = nullptr; - } else { - D_ASSERT(global_state->sorted_blocks.size() == 1); - state.scanner = make_uniq(*global_state->sorted_blocks[0]->payload_data, *global_state); - } - state.pos = 0; - state.exclude_offset = exclude_offset && heap.offset > 0; -} - -void TopNSortState::Scan(TopNScanState &state, DataChunk &chunk) { - if (!state.scanner) { - return; - } - auto offset = heap.offset; - auto limit = heap.limit; - D_ASSERT(is_sorted); - while (chunk.size() == 0) { - state.scanner->Scan(chunk); - if (chunk.size() == 0) { - break; - } - idx_t start = state.pos; - idx_t end = state.pos + chunk.size(); - state.pos = end; - - idx_t chunk_start = 0; - idx_t chunk_end = chunk.size(); - if (state.exclude_offset) { - // we need to exclude all tuples before the OFFSET - // check if we should include anything - if (end <= offset) { - // end is smaller than offset: include nothing! 
- chunk.Reset(); - continue; - } else if (start < offset) { - // we need to slice - chunk_start = offset - start; - } - } - // check if we need to truncate at the offset + limit mark - if (start >= offset + limit) { - // we are finished - chunk_end = 0; - } else if (end > offset + limit) { - // the end extends past the offset + limit - // truncate the current chunk - chunk_end = offset + limit - start; - } - D_ASSERT(chunk_end - chunk_start <= STANDARD_VECTOR_SIZE); - if (chunk_end == chunk_start) { - chunk.Reset(); - break; - } else if (chunk_start > 0) { - SelectionVector sel(STANDARD_VECTOR_SIZE); - for (idx_t i = chunk_start; i < chunk_end; i++) { - sel.set_index(i - chunk_start, i); - } - chunk.Slice(sel, chunk_end - chunk_start); - } else if (chunk_end != chunk.size()) { - chunk.SetCardinality(chunk_end); - } - } -} - -//===--------------------------------------------------------------------===// -// TopNHeap -//===--------------------------------------------------------------------===// -TopNHeap::TopNHeap(ClientContext &context, Allocator &allocator, const vector &payload_types_p, - const vector &orders_p, idx_t limit, idx_t offset) - : allocator(allocator), buffer_manager(BufferManager::GetBufferManager(context)), payload_types(payload_types_p), - orders(orders_p), limit(limit), offset(offset), sort_state(*this), executor(context), has_boundary_values(false), - final_sel(STANDARD_VECTOR_SIZE), true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE), - new_remaining_sel(STANDARD_VECTOR_SIZE) { - // initialize the executor and the sort_chunk - vector sort_types; - for (auto &order : orders) { - auto &expr = order.expression; - sort_types.push_back(expr->return_type); - executor.AddExpression(*expr); - } - payload_chunk.Initialize(allocator, payload_types); - sort_chunk.Initialize(allocator, sort_types); - compare_chunk.Initialize(allocator, sort_types); - boundary_values.Initialize(allocator, sort_types); - sort_state.Initialize(); -} - 
-TopNHeap::TopNHeap(ClientContext &context, const vector &payload_types, - const vector &orders, idx_t limit, idx_t offset) - : TopNHeap(context, BufferAllocator::Get(context), payload_types, orders, limit, offset) { -} - -TopNHeap::TopNHeap(ExecutionContext &context, const vector &payload_types, - const vector &orders, idx_t limit, idx_t offset) - : TopNHeap(context.client, Allocator::Get(context.client), payload_types, orders, limit, offset) { -} - -void TopNHeap::Sink(DataChunk &input) { - sort_state.Sink(input); -} - -void TopNHeap::Combine(TopNHeap &other) { - other.Finalize(); - - TopNScanState state; - other.InitializeScan(state, false); - while (true) { - payload_chunk.Reset(); - other.Scan(state, payload_chunk); - if (payload_chunk.size() == 0) { - break; - } - Sink(payload_chunk); - } - Reduce(); -} - -void TopNHeap::Finalize() { - sort_state.Finalize(); -} - -void TopNHeap::Reduce() { - idx_t min_sort_threshold = MaxValue(STANDARD_VECTOR_SIZE * 5ULL, 2ULL * (limit + offset)); - if (sort_state.count < min_sort_threshold) { - // only reduce when we pass two times the limit + offset, or 5 vectors (whichever comes first) - return; - } - sort_state.Finalize(); - TopNSortState new_state(*this); - new_state.Initialize(); - - TopNScanState state; - sort_state.InitializeScan(state, false); - - DataChunk new_chunk; - new_chunk.Initialize(allocator, payload_types); - - DataChunk *current_chunk = &new_chunk; - DataChunk *prev_chunk = &payload_chunk; - has_boundary_values = false; - while (true) { - current_chunk->Reset(); - Scan(state, *current_chunk); - if (current_chunk->size() == 0) { - ExtractBoundaryValues(*current_chunk, *prev_chunk); - break; - } - new_state.Sink(*current_chunk); - std::swap(current_chunk, prev_chunk); - } - - sort_state.Move(new_state); -} - -void TopNHeap::ExtractBoundaryValues(DataChunk ¤t_chunk, DataChunk &prev_chunk) { - // extract the last entry of the prev_chunk and set as minimum value - D_ASSERT(prev_chunk.size() > 0); - for (idx_t 
col_idx = 0; col_idx < current_chunk.ColumnCount(); col_idx++) { - ConstantVector::Reference(current_chunk.data[col_idx], prev_chunk.data[col_idx], prev_chunk.size() - 1, - prev_chunk.size()); - } - current_chunk.SetCardinality(1); - sort_chunk.Reset(); - executor.Execute(¤t_chunk, sort_chunk); - - boundary_values.Reset(); - boundary_values.Append(sort_chunk); - boundary_values.SetCardinality(1); - for (idx_t i = 0; i < boundary_values.ColumnCount(); i++) { - boundary_values.data[i].SetVectorType(VectorType::CONSTANT_VECTOR); - } - has_boundary_values = true; -} - -bool TopNHeap::CheckBoundaryValues(DataChunk &sort_chunk, DataChunk &payload) { - // we have boundary values - // from these boundary values, determine which values we should insert (if any) - idx_t final_count = 0; - - SelectionVector remaining_sel(nullptr); - idx_t remaining_count = sort_chunk.size(); - for (idx_t i = 0; i < orders.size(); i++) { - if (remaining_sel.data()) { - compare_chunk.data[i].Slice(sort_chunk.data[i], remaining_sel, remaining_count); - } else { - compare_chunk.data[i].Reference(sort_chunk.data[i]); - } - bool is_last = i + 1 == orders.size(); - idx_t true_count; - if (orders[i].null_order == OrderByNullType::NULLS_LAST) { - if (orders[i].type == OrderType::ASCENDING) { - true_count = VectorOperations::DistinctLessThan(compare_chunk.data[i], boundary_values.data[i], - &remaining_sel, remaining_count, &true_sel, &false_sel); - } else { - true_count = VectorOperations::DistinctGreaterThanNullsFirst(compare_chunk.data[i], - boundary_values.data[i], &remaining_sel, - remaining_count, &true_sel, &false_sel); - } - } else { - D_ASSERT(orders[i].null_order == OrderByNullType::NULLS_FIRST); - if (orders[i].type == OrderType::ASCENDING) { - true_count = VectorOperations::DistinctLessThanNullsFirst(compare_chunk.data[i], - boundary_values.data[i], &remaining_sel, - remaining_count, &true_sel, &false_sel); - } else { - true_count = - 
VectorOperations::DistinctGreaterThan(compare_chunk.data[i], boundary_values.data[i], - &remaining_sel, remaining_count, &true_sel, &false_sel); - } - } - - if (true_count > 0) { - memcpy(final_sel.data() + final_count, true_sel.data(), true_count * sizeof(sel_t)); - final_count += true_count; - } - idx_t false_count = remaining_count - true_count; - if (false_count > 0) { - // check what we should continue to check - compare_chunk.data[i].Slice(sort_chunk.data[i], false_sel, false_count); - remaining_count = VectorOperations::NotDistinctFrom(compare_chunk.data[i], boundary_values.data[i], - &false_sel, false_count, &new_remaining_sel, nullptr); - if (is_last) { - memcpy(final_sel.data() + final_count, new_remaining_sel.data(), remaining_count * sizeof(sel_t)); - final_count += remaining_count; - } else { - remaining_sel.Initialize(new_remaining_sel); - } - } else { - break; - } - } - if (final_count == 0) { - return false; - } - if (final_count < sort_chunk.size()) { - sort_chunk.Slice(final_sel, final_count); - payload.Slice(final_sel, final_count); - } - return true; -} - -void TopNHeap::InitializeScan(TopNScanState &state, bool exclude_offset) { - sort_state.InitializeScan(state, exclude_offset); -} - -void TopNHeap::Scan(TopNScanState &state, DataChunk &chunk) { - sort_state.Scan(state, chunk); -} - -class TopNGlobalState : public GlobalSinkState { -public: - TopNGlobalState(ClientContext &context, const vector &payload_types, - const vector &orders, idx_t limit, idx_t offset) - : heap(context, payload_types, orders, limit, offset) { - } - - mutex lock; - TopNHeap heap; -}; - -class TopNLocalState : public LocalSinkState { -public: - TopNLocalState(ExecutionContext &context, const vector &payload_types, - const vector &orders, idx_t limit, idx_t offset) - : heap(context, payload_types, orders, limit, offset) { - } - - TopNHeap heap; -}; - -unique_ptr PhysicalTopN::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context, types, orders, 
limit, offset); -} - -unique_ptr PhysicalTopN::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, types, orders, limit, offset); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -SinkResultType PhysicalTopN::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - // append to the local sink state - auto &sink = input.local_state.Cast(); - sink.heap.Sink(chunk); - sink.heap.Reduce(); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -SinkCombineResultType PhysicalTopN::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - // scan the local top N and append it to the global heap - lock_guard glock(gstate.lock); - gstate.heap.Combine(lstate.heap); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalTopN::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - // global finalize: compute the final top N - gstate.heap.Finalize(); - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class TopNOperatorState : public GlobalSourceState { -public: - TopNScanState state; - bool initialized = false; -}; - -unique_ptr PhysicalTopN::GetGlobalSourceState(ClientContext &context) 
const { - return make_uniq(); -} - -SourceResultType PhysicalTopN::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - if (limit == 0) { - return SourceResultType::FINISHED; - } - auto &state = input.global_state.Cast(); - auto &gstate = sink_state->Cast(); - - if (!state.initialized) { - gstate.heap.InitializeScan(state.state, true); - state.initialized = true; - } - gstate.heap.Scan(state.state, chunk); - - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -string PhysicalTopN::ParamsToString() const { - string result; - result += "Top " + to_string(limit); - if (offset > 0) { - result += "\n"; - result += "Offset " + to_string(offset); - } - result += "\n[INFOSEPARATOR]"; - for (idx_t i = 0; i < orders.size(); i++) { - result += "\n"; - result += orders[i].expression->ToString() + " "; - result += orders[i].type == OrderType::DESCENDING ? "DESC" : "ASC"; - } - return result; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -TableCatalogEntry &CSVRejectsTable::GetTable(ClientContext &context) { - auto &temp_catalog = Catalog::GetCatalog(context, TEMP_CATALOG); - auto &table_entry = temp_catalog.GetEntry(context, TEMP_CATALOG, DEFAULT_SCHEMA, name); - return table_entry; -} - -shared_ptr CSVRejectsTable::GetOrCreate(ClientContext &context, const string &name) { - auto key = "CSV_REJECTS_TABLE_CACHE_ENTRY_" + StringUtil::Upper(name); - auto &cache = ObjectCache::GetObjectCache(context); - return cache.GetOrCreate(key, name); -} - -void CSVRejectsTable::InitializeTable(ClientContext &context, const ReadCSVData &data) { - // (Re)Create the temporary rejects table - auto &catalog = Catalog::GetCatalog(context, TEMP_CATALOG); - auto info = make_uniq(TEMP_CATALOG, DEFAULT_SCHEMA, name); - info->temporary = true; - info->on_conflict = OnCreateConflict::ERROR_ON_CONFLICT; - info->columns.AddColumn(ColumnDefinition("file", LogicalType::VARCHAR)); - 
info->columns.AddColumn(ColumnDefinition("line", LogicalType::BIGINT)); - info->columns.AddColumn(ColumnDefinition("column", LogicalType::BIGINT)); - info->columns.AddColumn(ColumnDefinition("column_name", LogicalType::VARCHAR)); - info->columns.AddColumn(ColumnDefinition("parsed_value", LogicalType::VARCHAR)); - - if (!data.options.rejects_recovery_columns.empty()) { - child_list_t recovery_key_components; - for (auto &col_name : data.options.rejects_recovery_columns) { - recovery_key_components.emplace_back(col_name, LogicalType::VARCHAR); - } - info->columns.AddColumn(ColumnDefinition("recovery_columns", LogicalType::STRUCT(recovery_key_components))); - } - - info->columns.AddColumn(ColumnDefinition("error", LogicalType::VARCHAR)); - - catalog.CreateTable(context, std::move(info)); - - count = 0; -} - -} // namespace duckdb - - - - - - - - -#include - -namespace duckdb { - -PhysicalBatchCopyToFile::PhysicalBatchCopyToFile(vector types, CopyFunction function_p, - unique_ptr bind_data_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::BATCH_COPY_TO_FILE, std::move(types), estimated_cardinality), - function(std::move(function_p)), bind_data(std::move(bind_data_p)) { - if (!function.flush_batch || !function.prepare_batch) { - throw InternalException( - "PhysicalBatchCopyToFile created for copy function that does not have prepare_batch/flush_batch defined"); - } -} - -//===--------------------------------------------------------------------===// -// States -//===--------------------------------------------------------------------===// -class BatchCopyToGlobalState : public GlobalSinkState { -public: - explicit BatchCopyToGlobalState(unique_ptr global_state) - : rows_copied(0), global_state(std::move(global_state)), any_flushing(false) { - } - - mutex lock; - //! The total number of rows copied to the file - atomic rows_copied; - //! Global copy state - unique_ptr global_state; - //! 
The prepared batch data by batch index - ready to flush - map> batch_data; - //! Lock for flushing to disk - mutex flush_lock; - //! Whether or not any threads are flushing (only one thread can flush at a time) - atomic any_flushing; - - void AddBatchData(idx_t batch_index, unique_ptr new_batch) { - // move the batch data to the set of prepared batch data - lock_guard l(lock); - auto entry = batch_data.insert(make_pair(batch_index, std::move(new_batch))); - if (!entry.second) { - throw InternalException("Duplicate batch index %llu encountered in PhysicalBatchCopyToFile", batch_index); - } - } -}; - -class BatchCopyToLocalState : public LocalSinkState { -public: - explicit BatchCopyToLocalState(unique_ptr local_state_p) - : local_state(std::move(local_state_p)), rows_copied(0) { - } - - //! Local copy state - unique_ptr local_state; - //! The current collection we are appending to - unique_ptr collection; - //! The append state of the collection - ColumnDataAppendState append_state; - //! How many rows have been copied in total - idx_t rows_copied; - //! 
The current batch index - optional_idx batch_index; - - void InitializeCollection(ClientContext &context, const PhysicalOperator &op) { - collection = make_uniq(BufferAllocator::Get(context), op.children[0]->types); - collection->InitializeAppend(append_state); - } -}; - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -SinkResultType PhysicalBatchCopyToFile::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &state = input.local_state.Cast(); - if (!state.collection) { - state.InitializeCollection(context.client, *this); - state.batch_index = state.partition_info.batch_index.GetIndex(); - } - state.rows_copied += chunk.size(); - state.collection->Append(state.append_state, chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalBatchCopyToFile::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &state = input.local_state.Cast(); - auto &gstate = input.global_state.Cast(); - gstate.rows_copied += state.rows_copied; - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalBatchCopyToFile::FinalFlush(ClientContext &context, GlobalSinkState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - idx_t min_batch_index = idx_t(NumericLimits::Maximum()); - FlushBatchData(context, gstate_p, min_batch_index); - if (function.copy_to_finalize) { - function.copy_to_finalize(context, *bind_data, *gstate.global_state); - - if (use_tmp_file) { - PhysicalCopyToFile::MoveTmpFile(context, file_path); - } - } - return SinkFinalizeType::READY; -} - -SinkFinalizeType PhysicalBatchCopyToFile::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - 
OperatorSinkFinalizeInput &input) const { - FinalFlush(context, input.global_state); - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Batch Data Handling -//===--------------------------------------------------------------------===// -void PhysicalBatchCopyToFile::PrepareBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t batch_index, - unique_ptr collection) const { - auto &gstate = gstate_p.Cast(); - - // prepare the batch - auto batch_data = function.prepare_batch(context, *bind_data, *gstate.global_state, std::move(collection)); - gstate.AddBatchData(batch_index, std::move(batch_data)); -} - -void PhysicalBatchCopyToFile::FlushBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t min_index) const { - auto &gstate = gstate_p.Cast(); - - // flush batch data to disk (if there are any to flush) - // grab the flush lock - we can only call flush_batch with this lock - // otherwise the data might end up in the wrong order - { - lock_guard l(gstate.flush_lock); - if (gstate.any_flushing) { - return; - } - gstate.any_flushing = true; - } - ActiveFlushGuard active_flush(gstate.any_flushing); - while (true) { - unique_ptr batch_data; - { - // fetch the next batch to flush (if any) - lock_guard l(gstate.lock); - if (gstate.batch_data.empty()) { - // no batch data left to flush - break; - } - auto entry = gstate.batch_data.begin(); - if (entry->first >= min_index) { - // this data is past the min_index - we cannot write it yet - break; - } - if (!entry->second) { - // this batch is in process of being prepared but is not ready yet - break; - } - batch_data = std::move(entry->second); - gstate.batch_data.erase(entry); - } - function.flush_batch(context, *bind_data, *gstate.global_state, *batch_data); - } -} - -//===--------------------------------------------------------------------===// -// Next Batch 
-//===--------------------------------------------------------------------===// -void PhysicalBatchCopyToFile::NextBatch(ExecutionContext &context, GlobalSinkState &gstate_p, - LocalSinkState &lstate) const { - auto &state = lstate.Cast(); - if (state.collection && state.collection->Count() > 0) { - // we finished processing this batch - // start flushing data - auto min_batch_index = lstate.partition_info.min_batch_index.GetIndex(); - PrepareBatchData(context.client, gstate_p, state.batch_index.GetIndex(), std::move(state.collection)); - FlushBatchData(context.client, gstate_p, min_batch_index); - } - state.batch_index = lstate.partition_info.batch_index.GetIndex(); - - state.InitializeCollection(context.client, *this); -} - -unique_ptr PhysicalBatchCopyToFile::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(function.copy_to_initialize_local(context, *bind_data)); -} - -unique_ptr PhysicalBatchCopyToFile::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(function.copy_to_initialize_global(context, *bind_data, file_path)); -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalBatchCopyToFile::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &g = sink_state->Cast(); - - chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(g.rows_copied)); - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -PhysicalBatchInsert::PhysicalBatchInsert(vector types, TableCatalogEntry &table, - physical_index_vector_t column_index_map, - vector> bound_defaults, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::BATCH_INSERT, std::move(types), estimated_cardinality), - column_index_map(std::move(column_index_map)), insert_table(&table), 
insert_types(table.GetTypes()), - bound_defaults(std::move(bound_defaults)) { -} - -PhysicalBatchInsert::PhysicalBatchInsert(LogicalOperator &op, SchemaCatalogEntry &schema, - unique_ptr info_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::BATCH_CREATE_TABLE_AS, op.types, estimated_cardinality), - insert_table(nullptr), schema(&schema), info(std::move(info_p)) { - PhysicalInsert::GetInsertInfo(*info, insert_types, bound_defaults); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// - -class CollectionMerger { -public: - explicit CollectionMerger(ClientContext &context) : context(context) { - } - - ClientContext &context; - vector> current_collections; - -public: - void AddCollection(unique_ptr collection) { - current_collections.push_back(std::move(collection)); - } - - bool Empty() { - return current_collections.empty(); - } - - unique_ptr Flush(OptimisticDataWriter &writer) { - if (Empty()) { - return nullptr; - } - unique_ptr new_collection = std::move(current_collections[0]); - if (current_collections.size() > 1) { - // we have gathered multiple collections: create one big collection and merge that - auto &types = new_collection->GetTypes(); - TableAppendState append_state; - new_collection->InitializeAppend(append_state); - - DataChunk scan_chunk; - scan_chunk.Initialize(context, types); - - vector column_ids; - for (idx_t i = 0; i < types.size(); i++) { - column_ids.push_back(i); - } - for (auto &collection : current_collections) { - if (!collection) { - continue; - } - TableScanState scan_state; - scan_state.Initialize(column_ids); - collection->InitializeScan(scan_state.local_state, column_ids, nullptr); - - while (true) { - scan_chunk.Reset(); - scan_state.local_state.ScanCommitted(scan_chunk, TableScanType::TABLE_SCAN_COMMITTED_ROWS); - if (scan_chunk.size() == 0) { - break; - } - auto new_row_group = 
new_collection->Append(scan_chunk, append_state); - if (new_row_group) { - writer.WriteNewRowGroup(*new_collection); - } - } - } - new_collection->FinalizeAppend(TransactionData(0, 0), append_state); - writer.WriteLastRowGroup(*new_collection); - } - current_collections.clear(); - return new_collection; - } -}; - -enum class RowGroupBatchType : uint8_t { FLUSHED, NOT_FLUSHED }; -struct RowGroupBatchEntry { - RowGroupBatchEntry(idx_t batch_idx, unique_ptr collection_p, RowGroupBatchType type) - : batch_idx(batch_idx), total_rows(collection_p->GetTotalRows()), collection(std::move(collection_p)), - type(type) { - } - - idx_t batch_idx; - idx_t total_rows; - unique_ptr collection; - RowGroupBatchType type; -}; - -class BatchInsertGlobalState : public GlobalSinkState { -public: - static constexpr const idx_t BATCH_FLUSH_THRESHOLD = LocalStorage::MERGE_THRESHOLD * 3; - -public: - explicit BatchInsertGlobalState(DuckTableEntry &table) : table(table), insert_count(0) { - } - - mutex lock; - DuckTableEntry &table; - idx_t insert_count; - vector collections; - idx_t next_start = 0; - bool optimistically_written = false; - - void FindMergeCollections(idx_t min_batch_index, optional_idx &merged_batch_index, - vector> &result) { - bool merge = false; - idx_t start_index = next_start; - idx_t current_idx; - idx_t total_count = 0; - for (current_idx = start_index; current_idx < collections.size(); current_idx++) { - auto &entry = collections[current_idx]; - if (entry.batch_idx >= min_batch_index) { - // this entry is AFTER the min_batch_index - // we might still find new entries! 
- break; - } - if (entry.type == RowGroupBatchType::FLUSHED) { - // already flushed: cannot flush anything here - if (total_count > 0) { - merge = true; - break; - } - start_index = current_idx + 1; - if (start_index > next_start) { - // avoid checking this segment again in the future - next_start = start_index; - } - total_count = 0; - continue; - } - // not flushed - add to set of indexes to flush - total_count += entry.total_rows; - if (total_count >= BATCH_FLUSH_THRESHOLD) { - merge = true; - break; - } - } - if (merge && total_count > 0) { - D_ASSERT(current_idx > start_index); - merged_batch_index = collections[start_index].batch_idx; - for (idx_t idx = start_index; idx < current_idx; idx++) { - auto &entry = collections[idx]; - if (!entry.collection || entry.type == RowGroupBatchType::FLUSHED) { - throw InternalException("Adding a row group collection that should not be flushed"); - } - result.push_back(std::move(entry.collection)); - entry.total_rows = total_count; - entry.type = RowGroupBatchType::FLUSHED; - } - if (start_index + 1 < current_idx) { - // erase all entries except the first one - collections.erase(collections.begin() + start_index + 1, collections.begin() + current_idx); - } - } - } - - unique_ptr MergeCollections(ClientContext &context, - vector> merge_collections, - OptimisticDataWriter &writer) { - D_ASSERT(!merge_collections.empty()); - CollectionMerger merger(context); - for (auto &collection : merge_collections) { - merger.AddCollection(std::move(collection)); - } - optimistically_written = true; - return merger.Flush(writer); - } - - void AddCollection(ClientContext &context, idx_t batch_index, idx_t min_batch_index, - unique_ptr current_collection, - optional_ptr writer = nullptr, - optional_ptr written_to_disk = nullptr) { - if (batch_index < min_batch_index) { - throw InternalException( - "Batch index of the added collection (%llu) is smaller than the min batch index (%llu)", batch_index, - min_batch_index); - } - auto new_count = 
current_collection->GetTotalRows(); - auto batch_type = - new_count < Storage::ROW_GROUP_SIZE ? RowGroupBatchType::NOT_FLUSHED : RowGroupBatchType::FLUSHED; - if (batch_type == RowGroupBatchType::FLUSHED && writer) { - writer->WriteLastRowGroup(*current_collection); - } - optional_idx merged_batch_index; - vector> merge_collections; - { - lock_guard l(lock); - insert_count += new_count; - - // add the collection to the batch index - RowGroupBatchEntry new_entry(batch_index, std::move(current_collection), batch_type); - - auto it = std::lower_bound( - collections.begin(), collections.end(), new_entry, - [&](const RowGroupBatchEntry &a, const RowGroupBatchEntry &b) { return a.batch_idx < b.batch_idx; }); - if (it != collections.end() && it->batch_idx == new_entry.batch_idx) { - throw InternalException( - "PhysicalBatchInsert::AddCollection error: batch index %d is present in multiple " - "collections. This occurs when " - "batch indexes are not uniquely distributed over threads", - batch_index); - } - collections.insert(it, std::move(new_entry)); - if (writer) { - FindMergeCollections(min_batch_index, merged_batch_index, merge_collections); - } - } - if (!merge_collections.empty()) { - // merge together the collections - D_ASSERT(writer); - auto final_collection = MergeCollections(context, std::move(merge_collections), *writer); - if (written_to_disk) { - *written_to_disk = true; - } - // add the merged-together collection to the set of batch indexes - { - lock_guard l(lock); - RowGroupBatchEntry new_entry(merged_batch_index.GetIndex(), std::move(final_collection), - RowGroupBatchType::FLUSHED); - auto it = std::lower_bound(collections.begin(), collections.end(), new_entry, - [&](const RowGroupBatchEntry &a, const RowGroupBatchEntry &b) { - return a.batch_idx < b.batch_idx; - }); - if (it->batch_idx != merged_batch_index.GetIndex()) { - throw InternalException("Merged batch index was no longer present in collection"); - } - it->collection = 
std::move(new_entry.collection); - } - } - } -}; - -class BatchInsertLocalState : public LocalSinkState { -public: - BatchInsertLocalState(ClientContext &context, const vector &types, - const vector> &bound_defaults) - : default_executor(context, bound_defaults), written_to_disk(false) { - insert_chunk.Initialize(Allocator::Get(context), types); - } - - DataChunk insert_chunk; - ExpressionExecutor default_executor; - idx_t current_index; - TableAppendState current_append_state; - unique_ptr current_collection; - optional_ptr writer; - bool written_to_disk; - - void CreateNewCollection(DuckTableEntry &table, const vector &insert_types) { - auto &table_info = table.GetStorage().info; - auto &block_manager = TableIOManager::Get(table.GetStorage()).GetBlockManagerForRowData(); - current_collection = make_uniq(table_info, block_manager, insert_types, MAX_ROW_ID); - current_collection->InitializeEmpty(); - current_collection->InitializeAppend(current_append_state); - written_to_disk = false; - } -}; - -unique_ptr PhysicalBatchInsert::GetGlobalSinkState(ClientContext &context) const { - optional_ptr table; - if (info) { - // CREATE TABLE AS - D_ASSERT(!insert_table); - auto &catalog = schema->catalog; - auto created_table = catalog.CreateTable(catalog.GetCatalogTransaction(context), *schema.get_mutable(), *info); - table = &created_table->Cast(); - } else { - D_ASSERT(insert_table); - D_ASSERT(insert_table->IsDuckTable()); - table = insert_table.get_mutable(); - } - auto result = make_uniq(table->Cast()); - return std::move(result); -} - -unique_ptr PhysicalBatchInsert::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, insert_types, bound_defaults); -} - -void PhysicalBatchInsert::NextBatch(ExecutionContext &context, GlobalSinkState &state, LocalSinkState &lstate_p) const { - auto &gstate = state.Cast(); - auto &lstate = lstate_p.Cast(); - - auto &table = gstate.table; - auto batch_index = lstate.partition_info.batch_index.GetIndex(); 
- if (lstate.current_collection) { - if (lstate.current_index == batch_index) { - throw InternalException("NextBatch called with the same batch index?"); - } - // batch index has changed: move the old collection to the global state and create a new collection - TransactionData tdata(0, 0); - lstate.current_collection->FinalizeAppend(tdata, lstate.current_append_state); - gstate.AddCollection(context.client, lstate.current_index, lstate.partition_info.min_batch_index.GetIndex(), - std::move(lstate.current_collection), lstate.writer, &lstate.written_to_disk); - lstate.CreateNewCollection(table, insert_types); - } - lstate.current_index = batch_index; -} - -SinkResultType PhysicalBatchInsert::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - auto &table = gstate.table; - PhysicalInsert::ResolveDefaults(table, chunk, column_index_map, lstate.default_executor, lstate.insert_chunk); - - auto batch_index = lstate.partition_info.batch_index.GetIndex(); - if (!lstate.current_collection) { - lock_guard l(gstate.lock); - // no collection yet: create a new one - lstate.CreateNewCollection(table, insert_types); - lstate.writer = &table.GetStorage().CreateOptimisticWriter(context.client); - } - - if (lstate.current_index != batch_index) { - throw InternalException("Current batch differs from batch - but NextBatch was not called!?"); - } - - table.GetStorage().VerifyAppendConstraints(table, context.client, lstate.insert_chunk); - - auto new_row_group = lstate.current_collection->Append(lstate.insert_chunk, lstate.current_append_state); - if (new_row_group) { - // we have already written to disk - flush the next row group as well - lstate.writer->WriteNewRowGroup(*lstate.current_collection); - lstate.written_to_disk = true; - } - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalBatchInsert::Combine(ExecutionContext &context, 
OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this, lstate.default_executor, "default_executor", 1); - client_profiler.Flush(context.thread.profiler); - - if (!lstate.current_collection) { - return SinkCombineResultType::FINISHED; - } - - if (lstate.current_collection->GetTotalRows() > 0) { - TransactionData tdata(0, 0); - lstate.current_collection->FinalizeAppend(tdata, lstate.current_append_state); - gstate.AddCollection(context.client, lstate.current_index, lstate.partition_info.min_batch_index.GetIndex(), - std::move(lstate.current_collection)); - } - { - lock_guard l(gstate.lock); - gstate.table.GetStorage().FinalizeOptimisticWriter(context.client, *lstate.writer); - } - - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalBatchInsert::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - - if (gstate.optimistically_written || gstate.insert_count >= LocalStorage::MERGE_THRESHOLD) { - // we have written data to disk optimistically or are inserting a large amount of data - // perform a final pass over all of the row groups and merge them together - vector> mergers; - unique_ptr current_merger; - - auto &storage = gstate.table.GetStorage(); - for (auto &entry : gstate.collections) { - if (entry.type == RowGroupBatchType::NOT_FLUSHED) { - // this collection has not been flushed: add it to the merge set - if (!current_merger) { - current_merger = make_uniq(context); - } - current_merger->AddCollection(std::move(entry.collection)); - } else { - // this collection has been flushed: it does not need to be merged - // create a separate collection merger only for this entry - if (current_merger) { - // we have small collections remaining: flush them - 
mergers.push_back(std::move(current_merger)); - current_merger.reset(); - } - auto larger_merger = make_uniq(context); - larger_merger->AddCollection(std::move(entry.collection)); - mergers.push_back(std::move(larger_merger)); - } - } - if (current_merger) { - mergers.push_back(std::move(current_merger)); - } - - // now that we have created all of the mergers, perform the actual merging - vector> final_collections; - final_collections.reserve(mergers.size()); - auto &writer = storage.CreateOptimisticWriter(context); - for (auto &merger : mergers) { - final_collections.push_back(merger->Flush(writer)); - } - storage.FinalizeOptimisticWriter(context, writer); - - // finally, merge the row groups into the local storage - for (auto &collection : final_collections) { - storage.LocalMerge(context, *collection); - } - } else { - // we are writing a small amount of data to disk - // append directly to transaction local storage - auto &table = gstate.table; - auto &storage = table.GetStorage(); - LocalAppendState append_state; - storage.InitializeLocalAppend(append_state, context); - auto &transaction = DuckTransaction::Get(context, table.catalog); - for (auto &entry : gstate.collections) { - entry.collection->Scan(transaction, [&](DataChunk &insert_chunk) { - storage.LocalAppend(append_state, table, context, insert_chunk); - return true; - }); - } - storage.FinalizeLocalAppend(append_state); - } - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// - -SourceResultType PhysicalBatchInsert::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &insert_gstate = sink_state->Cast(); - - chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(insert_gstate.insert_count)); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - -#include - -namespace 
duckdb { - -class CopyToFunctionGlobalState : public GlobalSinkState { -public: - explicit CopyToFunctionGlobalState(unique_ptr global_state) - : rows_copied(0), last_file_offset(0), global_state(std::move(global_state)) { - } - mutex lock; - idx_t rows_copied; - idx_t last_file_offset; - unique_ptr global_state; - - //! shared state for HivePartitionedColumnData - shared_ptr partition_state; -}; - -class CopyToFunctionLocalState : public LocalSinkState { -public: - explicit CopyToFunctionLocalState(unique_ptr local_state) - : local_state(std::move(local_state)), writer_offset(0) { - } - unique_ptr global_state; - unique_ptr local_state; - - //! Buffers the tuples in partitions before writing - unique_ptr part_buffer; - unique_ptr part_buffer_append_state; - - idx_t writer_offset; -}; - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// - -void PhysicalCopyToFile::MoveTmpFile(ClientContext &context, const string &tmp_file_path) { - auto &fs = FileSystem::GetFileSystem(context); - auto file_path = tmp_file_path.substr(0, tmp_file_path.length() - 4); - if (fs.FileExists(file_path)) { - fs.RemoveFile(file_path); - } - fs.MoveFile(tmp_file_path, file_path); -} - -PhysicalCopyToFile::PhysicalCopyToFile(vector types, CopyFunction function_p, - unique_ptr bind_data, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::COPY_TO_FILE, std::move(types), estimated_cardinality), - function(std::move(function_p)), bind_data(std::move(bind_data)), parallel(false) { -} - -SinkResultType PhysicalCopyToFile::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &g = input.global_state.Cast(); - auto &l = input.local_state.Cast(); - - if (partition_output) { - l.part_buffer->Append(*l.part_buffer_append_state, chunk); - return SinkResultType::NEED_MORE_INPUT; - } - - { - lock_guard glock(g.lock); - g.rows_copied 
+= chunk.size(); - } - function.copy_to_sink(context, *bind_data, per_thread_output ? *l.global_state : *g.global_state, *l.local_state, - chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -static void CreateDir(const string &dir_path, FileSystem &fs) { - if (!fs.DirectoryExists(dir_path)) { - fs.CreateDirectory(dir_path); - } -} - -static string CreateDirRecursive(const vector &cols, const vector &names, const vector &values, - string path, FileSystem &fs) { - CreateDir(path, fs); - - for (idx_t i = 0; i < cols.size(); i++) { - const auto &partition_col_name = names[cols[i]]; - const auto &partition_value = values[i]; - string p_dir = partition_col_name + "=" + partition_value.ToString(); - path = fs.JoinPath(path, p_dir); - CreateDir(path, fs); - } - - return path; -} - -SinkCombineResultType PhysicalCopyToFile::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &g = input.global_state.Cast(); - auto &l = input.local_state.Cast(); - - if (partition_output) { - auto &fs = FileSystem::GetFileSystem(context.client); - l.part_buffer->FlushAppendState(*l.part_buffer_append_state); - auto &partitions = l.part_buffer->GetPartitions(); - auto partition_key_map = l.part_buffer->GetReverseMap(); - - string trimmed_path = file_path; - StringUtil::RTrim(trimmed_path, fs.PathSeparator(trimmed_path)); - - for (idx_t i = 0; i < partitions.size(); i++) { - string hive_path = - CreateDirRecursive(partition_columns, names, partition_key_map[i]->values, trimmed_path, fs); - string full_path(filename_pattern.CreateFilename(fs, hive_path, function.extension, l.writer_offset)); - if (fs.FileExists(full_path) && !overwrite_or_ignore) { - throw IOException("failed to create " + full_path + - ", file exists! 
Enable OVERWRITE_OR_IGNORE option to force writing"); - } - // Create a writer for the current file - auto fun_data_global = function.copy_to_initialize_global(context.client, *bind_data, full_path); - auto fun_data_local = function.copy_to_initialize_local(context, *bind_data); - - for (auto &chunk : partitions[i]->Chunks()) { - function.copy_to_sink(context, *bind_data, *fun_data_global, *fun_data_local, chunk); - } - - function.copy_to_combine(context, *bind_data, *fun_data_global, *fun_data_local); - function.copy_to_finalize(context.client, *bind_data, *fun_data_global); - } - - return SinkCombineResultType::FINISHED; - } - - if (function.copy_to_combine) { - function.copy_to_combine(context, *bind_data, per_thread_output ? *l.global_state : *g.global_state, - *l.local_state); - - if (per_thread_output) { - function.copy_to_finalize(context.client, *bind_data, *l.global_state); - } - } - - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalCopyToFile::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - if (per_thread_output || partition_output) { - // already happened in combine - return SinkFinalizeType::READY; - } - if (function.copy_to_finalize) { - function.copy_to_finalize(context, *bind_data, *gstate.global_state); - - if (use_tmp_file) { - D_ASSERT(!per_thread_output); // FIXME - D_ASSERT(!partition_output); // FIXME - MoveTmpFile(context, file_path); - } - } - return SinkFinalizeType::READY; -} - -unique_ptr PhysicalCopyToFile::GetLocalSinkState(ExecutionContext &context) const { - if (partition_output) { - auto state = make_uniq(nullptr); - { - auto &g = sink_state->Cast(); - lock_guard glock(g.lock); - state->writer_offset = g.last_file_offset++; - - state->part_buffer = make_uniq(context.client, expected_types, partition_columns, - g.partition_state); - state->part_buffer_append_state = make_uniq(); - 
state->part_buffer->InitializeAppendState(*state->part_buffer_append_state); - } - return std::move(state); - } - auto res = make_uniq(function.copy_to_initialize_local(context, *bind_data)); - if (per_thread_output) { - idx_t this_file_offset; - { - auto &g = sink_state->Cast(); - lock_guard glock(g.lock); - this_file_offset = g.last_file_offset++; - } - auto &fs = FileSystem::GetFileSystem(context.client); - string output_path(filename_pattern.CreateFilename(fs, file_path, function.extension, this_file_offset)); - if (fs.FileExists(output_path) && !overwrite_or_ignore) { - throw IOException("%s exists! Enable OVERWRITE_OR_IGNORE option to force writing", output_path); - } - res->global_state = function.copy_to_initialize_global(context.client, *bind_data, output_path); - } - return std::move(res); -} - -unique_ptr PhysicalCopyToFile::GetGlobalSinkState(ClientContext &context) const { - - if (partition_output || per_thread_output) { - auto &fs = FileSystem::GetFileSystem(context); - - if (fs.FileExists(file_path) && !overwrite_or_ignore) { - throw IOException("%s exists! Enable OVERWRITE_OR_IGNORE option to force writing", file_path); - } - if (!fs.DirectoryExists(file_path)) { - fs.CreateDirectory(file_path); - } else if (!overwrite_or_ignore) { - idx_t n_files = 0; - fs.ListFiles(file_path, [&n_files](const string &path, bool) { n_files++; }); - if (n_files > 0) { - throw IOException("Directory %s is not empty! 
Enable OVERWRITE_OR_IGNORE option to force writing", - file_path); - } - } - - auto state = make_uniq(nullptr); - - if (partition_output) { - state->partition_state = make_shared(); - } - - return std::move(state); - } - - return make_uniq(function.copy_to_initialize_global(context, *bind_data, file_path)); -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// - -SourceResultType PhysicalCopyToFile::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &g = sink_state->Cast(); - - chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(g.rows_copied)); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class DeleteGlobalState : public GlobalSinkState { -public: - explicit DeleteGlobalState(ClientContext &context, const vector &return_types) - : deleted_count(0), return_collection(context, return_types) { - } - - mutex delete_lock; - idx_t deleted_count; - ColumnDataCollection return_collection; -}; - -class DeleteLocalState : public LocalSinkState { -public: - DeleteLocalState(Allocator &allocator, const vector &table_types) { - delete_chunk.Initialize(allocator, table_types); - } - DataChunk delete_chunk; -}; - -SinkResultType PhysicalDelete::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &ustate = input.local_state.Cast(); - - // get rows and - auto &transaction = DuckTransaction::Get(context.client, table.db); - auto &row_identifiers = chunk.data[row_id_index]; - - vector column_ids; - for (idx_t i = 0; i < table.column_definitions.size(); i++) { - column_ids.emplace_back(i); - }; - 
auto cfs = ColumnFetchState(); - - lock_guard delete_guard(gstate.delete_lock); - if (return_chunk) { - row_identifiers.Flatten(chunk.size()); - table.Fetch(transaction, ustate.delete_chunk, column_ids, row_identifiers, chunk.size(), cfs); - gstate.return_collection.Append(ustate.delete_chunk); - } - gstate.deleted_count += table.Delete(tableref, context.client, row_identifiers, chunk.size()); - - return SinkResultType::NEED_MORE_INPUT; -} - -unique_ptr PhysicalDelete::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, GetTypes()); -} - -unique_ptr PhysicalDelete::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(Allocator::Get(context.client), table.GetTypes()); -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class DeleteSourceState : public GlobalSourceState { -public: - explicit DeleteSourceState(const PhysicalDelete &op) { - if (op.return_chunk) { - D_ASSERT(op.sink_state); - auto &g = op.sink_state->Cast(); - g.return_collection.InitializeScan(scan_state); - } - } - - ColumnDataScanState scan_state; -}; - -unique_ptr PhysicalDelete::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -SourceResultType PhysicalDelete::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &state = input.global_state.Cast(); - auto &g = sink_state->Cast(); - if (!return_chunk) { - chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(g.deleted_count)); - return SourceResultType::FINISHED; - } - - g.return_collection.Scan(state.scan_state, chunk); - - return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb - - - - - - - - - - - - -#include -#include - -namespace duckdb { - -using std::stringstream; - -static void WriteCatalogEntries(stringstream &ss, vector> &entries) { - for (auto &entry : entries) { - if (entry.get().internal) { - continue; - } - ss << entry.get().ToSQL() << std::endl; - } - ss << std::endl; -} - -static void WriteStringStreamToFile(FileSystem &fs, stringstream &ss, const string &path) { - auto ss_string = ss.str(); - auto handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW, - FileLockType::WRITE_LOCK); - fs.Write(*handle, (void *)ss_string.c_str(), ss_string.size()); - handle.reset(); -} - -static void WriteValueAsSQL(stringstream &ss, Value &val) { - if (val.type().IsNumeric()) { - ss << val.ToString(); - } else { - ss << "'" << val.ToString() << "'"; - } -} - -static void WriteCopyStatement(FileSystem &fs, stringstream &ss, CopyInfo &info, ExportedTableData &exported_table, - CopyFunction const &function) { - ss << "COPY "; - - if (exported_table.schema_name != DEFAULT_SCHEMA) { - ss << KeywordHelper::WriteOptionallyQuoted(exported_table.schema_name) << "."; - } - - ss << StringUtil::Format("%s FROM %s (", SQLIdentifier(exported_table.table_name), - SQLString(exported_table.file_path)); - - // write the copy options - ss << "FORMAT '" << info.format << "'"; - if (info.format == "csv") { - // insert default csv options, if not specified - if (info.options.find("header") == info.options.end()) { - info.options["header"].push_back(Value::INTEGER(1)); - } - if (info.options.find("delimiter") == info.options.end() && info.options.find("sep") == info.options.end() && - info.options.find("delim") == info.options.end()) { - info.options["delimiter"].push_back(Value(",")); - } - if (info.options.find("quote") == info.options.end()) { - info.options["quote"].push_back(Value("\"")); - } - } - for (auto ©_option : 
info.options) { - if (copy_option.first == "force_quote") { - continue; - } - ss << ", " << copy_option.first << " "; - if (copy_option.second.size() == 1) { - WriteValueAsSQL(ss, copy_option.second[0]); - } else { - // FIXME handle multiple options - throw NotImplementedException("FIXME: serialize list of options"); - } - } - ss << ");" << std::endl; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class ExportSourceState : public GlobalSourceState { -public: - ExportSourceState() : finished(false) { - } - - bool finished; -}; - -unique_ptr PhysicalExport::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(); -} - -SourceResultType PhysicalExport::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &state = input.global_state.Cast(); - if (state.finished) { - return SourceResultType::FINISHED; - } - - auto &ccontext = context.client; - auto &fs = FileSystem::GetFileSystem(ccontext); - - // gather all catalog types to export - vector> schemas; - vector> custom_types; - vector> sequences; - vector> tables; - vector> views; - vector> indexes; - vector> macros; - - auto schema_list = Catalog::GetSchemas(ccontext, info->catalog); - for (auto &schema_p : schema_list) { - auto &schema = schema_p.get(); - if (!schema.internal) { - schemas.push_back(schema); - } - schema.Scan(context.client, CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) { - if (entry.internal) { - return; - } - if (entry.type != CatalogType::TABLE_ENTRY) { - views.push_back(entry); - } - }); - schema.Scan(context.client, CatalogType::SEQUENCE_ENTRY, - [&](CatalogEntry &entry) { sequences.push_back(entry); }); - schema.Scan(context.client, CatalogType::TYPE_ENTRY, - [&](CatalogEntry &entry) { custom_types.push_back(entry); }); - schema.Scan(context.client, CatalogType::INDEX_ENTRY, [&](CatalogEntry &entry) { 
indexes.push_back(entry); }); - schema.Scan(context.client, CatalogType::MACRO_ENTRY, [&](CatalogEntry &entry) { - if (!entry.internal && entry.type == CatalogType::MACRO_ENTRY) { - macros.push_back(entry); - } - }); - schema.Scan(context.client, CatalogType::TABLE_MACRO_ENTRY, [&](CatalogEntry &entry) { - if (!entry.internal && entry.type == CatalogType::TABLE_MACRO_ENTRY) { - macros.push_back(entry); - } - }); - } - - // consider the order of tables because of foreign key constraint - for (idx_t i = 0; i < exported_tables.data.size(); i++) { - tables.push_back(exported_tables.data[i].entry); - } - - // order macro's by timestamp so nested macro's are imported nicely - sort(macros.begin(), macros.end(), [](const reference &lhs, const reference &rhs) { - return lhs.get().oid < rhs.get().oid; - }); - - // write the schema.sql file - // export order is SCHEMA -> SEQUENCE -> TABLE -> VIEW -> INDEX - - stringstream ss; - WriteCatalogEntries(ss, schemas); - WriteCatalogEntries(ss, custom_types); - WriteCatalogEntries(ss, sequences); - WriteCatalogEntries(ss, tables); - WriteCatalogEntries(ss, views); - WriteCatalogEntries(ss, indexes); - WriteCatalogEntries(ss, macros); - - WriteStringStreamToFile(fs, ss, fs.JoinPath(info->file_path, "schema.sql")); - - // write the load.sql file - // for every table, we write COPY INTO statement with the specified options - stringstream load_ss; - for (idx_t i = 0; i < exported_tables.data.size(); i++) { - auto exported_table_info = exported_tables.data[i].table_data; - WriteCopyStatement(fs, load_ss, *info, exported_table_info, function); - } - WriteStringStreamToFile(fs, load_ss, fs.JoinPath(info->file_path, "load.sql")); - state.finished = true; - - return SourceResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -SinkResultType PhysicalExport::Sink(ExecutionContext &context, DataChunk 
&chunk, OperatorSinkInput &input) const { - // nop - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalExport::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - // EXPORT has an optional child - // we only need to schedule child pipelines if there is a child - auto &state = meta_pipeline.GetState(); - state.SetPipelineSource(current, *this); - if (children.empty()) { - return; - } - PhysicalOperator::BuildPipelines(current, meta_pipeline); -} - -vector> PhysicalExport::GetSources() const { - return {*this}; -} - -} // namespace duckdb - - - - - - - - - -#include - -namespace duckdb { - -PhysicalFixedBatchCopy::PhysicalFixedBatchCopy(vector types, CopyFunction function_p, - unique_ptr bind_data_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::BATCH_COPY_TO_FILE, std::move(types), estimated_cardinality), - function(std::move(function_p)), bind_data(std::move(bind_data_p)) { - if (!function.flush_batch || !function.prepare_batch || !function.desired_batch_size) { - throw InternalException("PhysicalFixedBatchCopy created for copy function that does not have " - "prepare_batch/flush_batch/desired_batch_size defined"); - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class BatchCopyTask { -public: - virtual ~BatchCopyTask() { - } - - virtual void Execute(const PhysicalFixedBatchCopy &op, ClientContext &context, GlobalSinkState &gstate_p) = 0; -}; - -//===--------------------------------------------------------------------===// -// States -//===--------------------------------------------------------------------===// -class FixedBatchCopyGlobalState : public GlobalSinkState { -public: - explicit 
FixedBatchCopyGlobalState(unique_ptr global_state) - : rows_copied(0), global_state(std::move(global_state)), batch_size(0), scheduled_batch_index(0), - flushed_batch_index(0), any_flushing(false), any_finished(false) { - } - - mutex lock; - mutex flush_lock; - //! The total number of rows copied to the file - atomic rows_copied; - //! Global copy state - unique_ptr global_state; - //! The desired batch size (if any) - idx_t batch_size; - //! Unpartitioned batches - only used in case batch_size is required - map> raw_batches; - //! The prepared batch data by batch index - ready to flush - map> batch_data; - //! The index of the latest batch index that has been scheduled - atomic scheduled_batch_index; - //! The index of the latest batch index that has been flushed - atomic flushed_batch_index; - //! Whether or not any thread is flushing - atomic any_flushing; - //! Whether or not any threads are finished - atomic any_finished; - - void AddTask(unique_ptr task) { - lock_guard l(task_lock); - task_queue.push(std::move(task)); - } - - unique_ptr GetTask() { - lock_guard l(task_lock); - if (task_queue.empty()) { - return nullptr; - } - auto entry = std::move(task_queue.front()); - task_queue.pop(); - return entry; - } - - idx_t TaskCount() { - lock_guard l(task_lock); - return task_queue.size(); - } - - void AddBatchData(idx_t batch_index, unique_ptr new_batch) { - // move the batch data to the set of prepared batch data - lock_guard l(lock); - auto entry = batch_data.insert(make_pair(batch_index, std::move(new_batch))); - if (!entry.second) { - throw InternalException("Duplicate batch index %llu encountered in PhysicalFixedBatchCopy", batch_index); - } - } - -private: - mutex task_lock; - //! The task queue for the batch copy to file - queue> task_queue; -}; - -class FixedBatchCopyLocalState : public LocalSinkState { -public: - explicit FixedBatchCopyLocalState(unique_ptr local_state_p) - : local_state(std::move(local_state_p)), rows_copied(0) { - } - - //! 
Local copy state - unique_ptr local_state; - //! The current collection we are appending to - unique_ptr collection; - //! The append state of the collection - ColumnDataAppendState append_state; - //! How many rows have been copied in total - idx_t rows_copied; - //! The current batch index - optional_idx batch_index; - - void InitializeCollection(ClientContext &context, const PhysicalOperator &op) { - collection = make_uniq(BufferAllocator::Get(context), op.children[0]->types); - collection->InitializeAppend(append_state); - } -}; - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -SinkResultType PhysicalFixedBatchCopy::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &state = input.local_state.Cast(); - if (!state.collection) { - state.InitializeCollection(context.client, *this); - state.batch_index = state.partition_info.batch_index.GetIndex(); - } - state.rows_copied += chunk.size(); - state.collection->Append(state.append_state, chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalFixedBatchCopy::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &state = input.local_state.Cast(); - auto &gstate = input.global_state.Cast(); - gstate.rows_copied += state.rows_copied; - if (!gstate.any_finished) { - // signal that this thread is finished processing batches and that we should move on to Finalize - lock_guard l(gstate.lock); - gstate.any_finished = true; - } - ExecuteTasks(context.client, gstate); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// ProcessRemainingBatchesEvent -//===--------------------------------------------------------------------===// -class ProcessRemainingBatchesTask : public ExecutorTask { -public: - 
ProcessRemainingBatchesTask(Executor &executor, shared_ptr event_p, FixedBatchCopyGlobalState &state_p, - ClientContext &context, const PhysicalFixedBatchCopy &op) - : ExecutorTask(executor), event(std::move(event_p)), op(op), gstate(state_p), context(context) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - while (op.ExecuteTask(context, gstate)) { - op.FlushBatchData(context, gstate, 0); - } - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; - } - -private: - shared_ptr event; - const PhysicalFixedBatchCopy &op; - FixedBatchCopyGlobalState &gstate; - ClientContext &context; -}; - -class ProcessRemainingBatchesEvent : public BasePipelineEvent { -public: - ProcessRemainingBatchesEvent(const PhysicalFixedBatchCopy &op_p, FixedBatchCopyGlobalState &gstate_p, - Pipeline &pipeline_p, ClientContext &context) - : BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), context(context) { - } - const PhysicalFixedBatchCopy &op; - FixedBatchCopyGlobalState &gstate; - ClientContext &context; - -public: - void Schedule() override { - vector> tasks; - for (idx_t i = 0; i < idx_t(TaskScheduler::GetScheduler(context).NumberOfThreads()); i++) { - auto process_task = - make_uniq(pipeline->executor, shared_from_this(), gstate, context, op); - tasks.push_back(std::move(process_task)); - } - D_ASSERT(!tasks.empty()); - SetTasks(std::move(tasks)); - } - - void FinishEvent() override { - //! 
Now that all batches are processed we finish flushing the file to disk - op.FinalFlush(context, gstate); - } -}; -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalFixedBatchCopy::FinalFlush(ClientContext &context, GlobalSinkState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - if (gstate.TaskCount() != 0) { - throw InternalException("Unexecuted tasks are remaining in PhysicalFixedBatchCopy::FinalFlush!?"); - } - idx_t min_batch_index = idx_t(NumericLimits::Maximum()); - FlushBatchData(context, gstate_p, min_batch_index); - if (gstate.scheduled_batch_index != gstate.flushed_batch_index) { - throw InternalException("Not all batches were flushed to disk - incomplete file?"); - } - if (function.copy_to_finalize) { - function.copy_to_finalize(context, *bind_data, *gstate.global_state); - - if (use_tmp_file) { - PhysicalCopyToFile::MoveTmpFile(context, file_path); - } - } - return SinkFinalizeType::READY; -} - -SinkFinalizeType PhysicalFixedBatchCopy::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - idx_t min_batch_index = idx_t(NumericLimits::Maximum()); - // repartition any remaining batches - RepartitionBatches(context, input.global_state, min_batch_index, true); - // check if we have multiple tasks to execute - if (gstate.TaskCount() <= 1) { - // we don't - just execute the remaining task and finish flushing to disk - ExecuteTasks(context, input.global_state); - FinalFlush(context, input.global_state); - return SinkFinalizeType::READY; - } - // we have multiple tasks remaining - launch an event to execute the tasks in parallel - auto new_event = make_shared(*this, gstate, pipeline, context); - event.InsertEvent(std::move(new_event)); - return SinkFinalizeType::READY; -} - 
-//===--------------------------------------------------------------------===// -// Tasks -//===--------------------------------------------------------------------===// -class RepartitionedFlushTask : public BatchCopyTask { -public: - RepartitionedFlushTask() { - } - - void Execute(const PhysicalFixedBatchCopy &op, ClientContext &context, GlobalSinkState &gstate_p) override { - op.FlushBatchData(context, gstate_p, 0); - } -}; - -class PrepareBatchTask : public BatchCopyTask { -public: - PrepareBatchTask(idx_t batch_index, unique_ptr collection_p) - : batch_index(batch_index), collection(std::move(collection_p)) { - } - - idx_t batch_index; - unique_ptr collection; - - void Execute(const PhysicalFixedBatchCopy &op, ClientContext &context, GlobalSinkState &gstate_p) override { - auto &gstate = gstate_p.Cast(); - auto batch_data = - op.function.prepare_batch(context, *op.bind_data, *gstate.global_state, std::move(collection)); - gstate.AddBatchData(batch_index, std::move(batch_data)); - if (batch_index == gstate.flushed_batch_index) { - gstate.AddTask(make_uniq()); - } - } -}; - -//===--------------------------------------------------------------------===// -// Batch Data Handling -//===--------------------------------------------------------------------===// -void PhysicalFixedBatchCopy::AddRawBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t batch_index, - unique_ptr collection) const { - auto &gstate = gstate_p.Cast(); - - // add the batch index to the set of raw batches - lock_guard l(gstate.lock); - auto entry = gstate.raw_batches.insert(make_pair(batch_index, std::move(collection))); - if (!entry.second) { - throw InternalException("Duplicate batch index %llu encountered in PhysicalFixedBatchCopy", batch_index); - } -} - -static bool CorrectSizeForBatch(idx_t collection_size, idx_t desired_size) { - return idx_t(AbsValue(int64_t(collection_size) - int64_t(desired_size))) < STANDARD_VECTOR_SIZE; -} - -void 
PhysicalFixedBatchCopy::RepartitionBatches(ClientContext &context, GlobalSinkState &gstate_p, idx_t min_index, - bool final) const { - auto &gstate = gstate_p.Cast(); - - // repartition batches until the min index is reached - lock_guard l(gstate.lock); - if (gstate.raw_batches.empty()) { - return; - } - if (!final) { - if (gstate.any_finished) { - // we only repartition in ::NextBatch if all threads are still busy processing batches - // otherwise we might end up repartitioning a lot of data with only a few threads remaining - // which causes erratic performance - return; - } - // if this is not the final flush we first check if we have enough data to merge past the batch threshold - idx_t candidate_rows = 0; - for (auto entry = gstate.raw_batches.begin(); entry != gstate.raw_batches.end(); entry++) { - if (entry->first >= min_index) { - // we have exceeded the minimum batch - break; - } - candidate_rows += entry->second->Count(); - } - if (candidate_rows < gstate.batch_size) { - // not enough rows - cancel! 
- return; - } - } - // gather all collections we can repartition - idx_t max_batch_index = 0; - vector> collections; - for (auto entry = gstate.raw_batches.begin(); entry != gstate.raw_batches.end();) { - if (entry->first >= min_index) { - break; - } - max_batch_index = entry->first; - collections.push_back(std::move(entry->second)); - entry = gstate.raw_batches.erase(entry); - } - unique_ptr current_collection; - ColumnDataAppendState append_state; - // now perform the actual repartitioning - for (auto &collection : collections) { - if (!current_collection) { - if (CorrectSizeForBatch(collection->Count(), gstate.batch_size)) { - // the collection is ~approximately equal to the batch size (off by at most one vector) - // use it directly - gstate.AddTask(make_uniq(gstate.scheduled_batch_index++, std::move(collection))); - collection.reset(); - } else if (collection->Count() < gstate.batch_size) { - // the collection is smaller than the batch size - use it as a starting point - current_collection = std::move(collection); - collection.reset(); - } else { - // the collection is too large for a batch - we need to repartition - // create an empty collection - current_collection = make_uniq(BufferAllocator::Get(context), children[0]->types); - } - if (current_collection) { - current_collection->InitializeAppend(append_state); - } - } - if (!collection) { - // we have consumed the collection already - no need to append - continue; - } - // iterate the collection while appending - for (auto &chunk : collection->Chunks()) { - // append the chunk to the collection - current_collection->Append(append_state, chunk); - if (current_collection->Count() < gstate.batch_size) { - // the collection is still under the batch size - continue - continue; - } - // the collection is full - move it to the result and create a new one - gstate.AddTask(make_uniq(gstate.scheduled_batch_index++, std::move(current_collection))); - current_collection = make_uniq(BufferAllocator::Get(context), 
children[0]->types); - current_collection->InitializeAppend(append_state); - } - } - if (current_collection && current_collection->Count() > 0) { - // if there are any remaining batches that are not filled up to the batch size - // AND this is not the final collection - // re-add it to the set of raw (to-be-merged) batches - if (final || CorrectSizeForBatch(current_collection->Count(), gstate.batch_size)) { - gstate.AddTask(make_uniq(gstate.scheduled_batch_index++, std::move(current_collection))); - } else { - gstate.raw_batches[max_batch_index] = std::move(current_collection); - } - } -} - -void PhysicalFixedBatchCopy::FlushBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t min_index) const { - auto &gstate = gstate_p.Cast(); - - // flush batch data to disk (if there are any to flush) - // grab the flush lock - we can only call flush_batch with this lock - // otherwise the data might end up in the wrong order - { - lock_guard l(gstate.flush_lock); - if (gstate.any_flushing) { - return; - } - gstate.any_flushing = true; - } - ActiveFlushGuard active_flush(gstate.any_flushing); - while (true) { - unique_ptr batch_data; - { - lock_guard l(gstate.lock); - if (gstate.batch_data.empty()) { - // no batch data left to flush - break; - } - auto entry = gstate.batch_data.begin(); - if (entry->first != gstate.flushed_batch_index) { - // this entry is not yet ready to be flushed - break; - } - if (entry->first < gstate.flushed_batch_index) { - throw InternalException("Batch index was out of order!?"); - } - batch_data = std::move(entry->second); - gstate.batch_data.erase(entry); - } - function.flush_batch(context, *bind_data, *gstate.global_state, *batch_data); - gstate.flushed_batch_index++; - } -} - -//===--------------------------------------------------------------------===// -// Tasks -//===--------------------------------------------------------------------===// -bool PhysicalFixedBatchCopy::ExecuteTask(ClientContext &context, GlobalSinkState &gstate_p) 
const { - auto &gstate = gstate_p.Cast(); - auto task = gstate.GetTask(); - if (!task) { - return false; - } - task->Execute(*this, context, gstate_p); - return true; -} - -void PhysicalFixedBatchCopy::ExecuteTasks(ClientContext &context, GlobalSinkState &gstate_p) const { - while (ExecuteTask(context, gstate_p)) { - } -} - -//===--------------------------------------------------------------------===// -// Next Batch -//===--------------------------------------------------------------------===// -void PhysicalFixedBatchCopy::NextBatch(ExecutionContext &context, GlobalSinkState &gstate_p, - LocalSinkState &lstate) const { - auto &state = lstate.Cast(); - if (state.collection && state.collection->Count() > 0) { - // we finished processing this batch - // start flushing data - auto min_batch_index = lstate.partition_info.min_batch_index.GetIndex(); - // push the raw batch data into the set of unprocessed batches - AddRawBatchData(context.client, gstate_p, state.batch_index.GetIndex(), std::move(state.collection)); - // attempt to repartition to our desired batch size - RepartitionBatches(context.client, gstate_p, min_batch_index); - // execute a single batch task - ExecuteTask(context.client, gstate_p); - FlushBatchData(context.client, gstate_p, min_batch_index); - } - state.batch_index = lstate.partition_info.batch_index.GetIndex(); - - state.InitializeCollection(context.client, *this); -} - -unique_ptr PhysicalFixedBatchCopy::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(function.copy_to_initialize_local(context, *bind_data)); -} - -unique_ptr PhysicalFixedBatchCopy::GetGlobalSinkState(ClientContext &context) const { - auto result = - make_uniq(function.copy_to_initialize_global(context, *bind_data, file_path)); - result->batch_size = function.desired_batch_size(context, *bind_data); - return std::move(result); -} - -//===--------------------------------------------------------------------===// -// Source 
-//===--------------------------------------------------------------------===// -SourceResultType PhysicalFixedBatchCopy::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &g = sink_state->Cast(); - - chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(g.rows_copied)); - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -PhysicalInsert::PhysicalInsert(vector types_p, TableCatalogEntry &table, - physical_index_vector_t column_index_map, - vector> bound_defaults, - vector> set_expressions, vector set_columns, - vector set_types, idx_t estimated_cardinality, bool return_chunk, - bool parallel, OnConflictAction action_type, - unique_ptr on_conflict_condition_p, - unique_ptr do_update_condition_p, unordered_set conflict_target_p, - vector columns_to_fetch_p) - : PhysicalOperator(PhysicalOperatorType::INSERT, std::move(types_p), estimated_cardinality), - column_index_map(std::move(column_index_map)), insert_table(&table), insert_types(table.GetTypes()), - bound_defaults(std::move(bound_defaults)), return_chunk(return_chunk), parallel(parallel), - action_type(action_type), set_expressions(std::move(set_expressions)), set_columns(std::move(set_columns)), - set_types(std::move(set_types)), on_conflict_condition(std::move(on_conflict_condition_p)), - do_update_condition(std::move(do_update_condition_p)), conflict_target(std::move(conflict_target_p)), - columns_to_fetch(std::move(columns_to_fetch_p)) { - - if (action_type == OnConflictAction::THROW) { - return; - } - - D_ASSERT(this->set_expressions.size() == this->set_columns.size()); - - // One or more columns are referenced from the existing table, - // we use the 'insert_types' to figure out which types these columns have - types_to_fetch = vector(columns_to_fetch.size(), LogicalType::SQLNULL); - for (idx_t i = 0; i < columns_to_fetch.size(); i++) { - auto &id = columns_to_fetch[i]; - 
D_ASSERT(id < insert_types.size()); - types_to_fetch[i] = insert_types[id]; - } -} - -PhysicalInsert::PhysicalInsert(LogicalOperator &op, SchemaCatalogEntry &schema, unique_ptr info_p, - idx_t estimated_cardinality, bool parallel) - : PhysicalOperator(PhysicalOperatorType::CREATE_TABLE_AS, op.types, estimated_cardinality), insert_table(nullptr), - return_chunk(false), schema(&schema), info(std::move(info_p)), parallel(parallel), - action_type(OnConflictAction::THROW) { - GetInsertInfo(*info, insert_types, bound_defaults); -} - -void PhysicalInsert::GetInsertInfo(const BoundCreateTableInfo &info, vector &insert_types, - vector> &bound_defaults) { - auto &create_info = info.base->Cast(); - for (auto &col : create_info.columns.Physical()) { - insert_types.push_back(col.GetType()); - bound_defaults.push_back(make_uniq(Value(col.GetType()))); - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class InsertGlobalState : public GlobalSinkState { -public: - explicit InsertGlobalState(ClientContext &context, const vector &return_types, DuckTableEntry &table) - : table(table), insert_count(0), initialized(false), return_collection(context, return_types) { - } - - mutex lock; - DuckTableEntry &table; - idx_t insert_count; - bool initialized; - LocalAppendState append_state; - ColumnDataCollection return_collection; -}; - -class InsertLocalState : public LocalSinkState { -public: - InsertLocalState(ClientContext &context, const vector &types, - const vector> &bound_defaults) - : default_executor(context, bound_defaults) { - insert_chunk.Initialize(Allocator::Get(context), types); - } - - DataChunk insert_chunk; - ExpressionExecutor default_executor; - TableAppendState local_append_state; - unique_ptr local_collection; - optional_ptr writer; - // Rows that have been updated by a DO UPDATE conflict - unordered_set updated_global_rows; - // Rows in the 
transaction-local storage that have been updated by a DO UPDATE conflict - unordered_set updated_local_rows; - idx_t update_count = 0; -}; - -unique_ptr PhysicalInsert::GetGlobalSinkState(ClientContext &context) const { - optional_ptr table; - if (info) { - // CREATE TABLE AS - D_ASSERT(!insert_table); - auto &catalog = schema->catalog; - table = &catalog.CreateTable(catalog.GetCatalogTransaction(context), *schema.get_mutable(), *info) - ->Cast(); - } else { - D_ASSERT(insert_table); - D_ASSERT(insert_table->IsDuckTable()); - table = insert_table.get_mutable(); - } - auto result = make_uniq(context, GetTypes(), table->Cast()); - return std::move(result); -} - -unique_ptr PhysicalInsert::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, insert_types, bound_defaults); -} - -void PhysicalInsert::ResolveDefaults(const TableCatalogEntry &table, DataChunk &chunk, - const physical_index_vector_t &column_index_map, - ExpressionExecutor &default_executor, DataChunk &result) { - chunk.Flatten(); - default_executor.SetChunk(chunk); - - result.Reset(); - result.SetCardinality(chunk); - - if (!column_index_map.empty()) { - // columns specified by the user, use column_index_map - for (auto &col : table.GetColumns().Physical()) { - auto storage_idx = col.StorageOid(); - auto mapped_index = column_index_map[col.Physical()]; - if (mapped_index == DConstants::INVALID_INDEX) { - // insert default value - default_executor.ExecuteExpression(storage_idx, result.data[storage_idx]); - } else { - // get value from child chunk - D_ASSERT((idx_t)mapped_index < chunk.ColumnCount()); - D_ASSERT(result.data[storage_idx].GetType() == chunk.data[mapped_index].GetType()); - result.data[storage_idx].Reference(chunk.data[mapped_index]); - } - } - } else { - // no columns specified, just append directly - for (idx_t i = 0; i < result.ColumnCount(); i++) { - D_ASSERT(result.data[i].GetType() == chunk.data[i].GetType()); - result.data[i].Reference(chunk.data[i]); - 
} - } -} - -bool AllConflictsMeetCondition(DataChunk &result) { - auto data = FlatVector::GetData(result.data[0]); - for (idx_t i = 0; i < result.size(); i++) { - if (!data[i]) { - return false; - } - } - return true; -} - -void CheckOnConflictCondition(ExecutionContext &context, DataChunk &conflicts, const unique_ptr &condition, - DataChunk &result) { - ExpressionExecutor executor(context.client, *condition); - result.Initialize(context.client, {LogicalType::BOOLEAN}); - executor.Execute(conflicts, result); - result.SetCardinality(conflicts.size()); -} - -static void CombineExistingAndInsertTuples(DataChunk &result, DataChunk &scan_chunk, DataChunk &input_chunk, - ClientContext &client, const PhysicalInsert &op) { - auto &types_to_fetch = op.types_to_fetch; - auto &insert_types = op.insert_types; - - if (types_to_fetch.empty()) { - // We have not scanned the initial table, so we can just duplicate the initial chunk - result.Initialize(client, input_chunk.GetTypes()); - result.Reference(input_chunk); - result.SetCardinality(input_chunk); - return; - } - vector combined_types; - combined_types.reserve(insert_types.size() + types_to_fetch.size()); - combined_types.insert(combined_types.end(), insert_types.begin(), insert_types.end()); - combined_types.insert(combined_types.end(), types_to_fetch.begin(), types_to_fetch.end()); - - result.Initialize(client, combined_types); - result.Reset(); - // Add the VALUES list - for (idx_t i = 0; i < insert_types.size(); i++) { - idx_t col_idx = i; - auto &other_col = input_chunk.data[i]; - auto &this_col = result.data[col_idx]; - D_ASSERT(other_col.GetType() == this_col.GetType()); - this_col.Reference(other_col); - } - // Add the columns from the original conflicting tuples - for (idx_t i = 0; i < types_to_fetch.size(); i++) { - idx_t col_idx = i + insert_types.size(); - auto &other_col = scan_chunk.data[i]; - auto &this_col = result.data[col_idx]; - D_ASSERT(other_col.GetType() == this_col.GetType()); - 
this_col.Reference(other_col); - } - // This is guaranteed by the requirement of a conflict target to have a condition or set expressions - // Only when we have any sort of condition or SET expression that references the existing table is this possible - // to not be true. - // We can have a SET expression without a conflict target ONLY if there is only 1 Index on the table - // In which case this also can't cause a discrepancy between existing tuple count and insert tuple count - D_ASSERT(input_chunk.size() == scan_chunk.size()); - result.SetCardinality(input_chunk.size()); -} - -static void CreateUpdateChunk(ExecutionContext &context, DataChunk &chunk, TableCatalogEntry &table, Vector &row_ids, - DataChunk &update_chunk, const PhysicalInsert &op) { - - auto &do_update_condition = op.do_update_condition; - auto &set_types = op.set_types; - auto &set_expressions = op.set_expressions; - // Check the optional condition for the DO UPDATE clause, to filter which rows will be updated - if (do_update_condition) { - DataChunk do_update_filter_result; - do_update_filter_result.Initialize(context.client, {LogicalType::BOOLEAN}); - ExpressionExecutor where_executor(context.client, *do_update_condition); - where_executor.Execute(chunk, do_update_filter_result); - do_update_filter_result.SetCardinality(chunk.size()); - - ManagedSelection selection(chunk.size()); - - auto where_data = FlatVector::GetData(do_update_filter_result.data[0]); - for (idx_t i = 0; i < chunk.size(); i++) { - if (where_data[i]) { - selection.Append(i); - } - } - if (selection.Count() != selection.Size()) { - // Not all conflicts met the condition, need to filter out the ones that don't - chunk.Slice(selection.Selection(), selection.Count()); - chunk.SetCardinality(selection.Count()); - // Also apply this Slice to the to-update row_ids - row_ids.Slice(selection.Selection(), selection.Count()); - } - } - - // Execute the SET expressions - update_chunk.Initialize(context.client, set_types); - 
ExpressionExecutor executor(context.client, set_expressions); - executor.Execute(chunk, update_chunk); - update_chunk.SetCardinality(chunk); -} - -template -static idx_t PerformOnConflictAction(ExecutionContext &context, DataChunk &chunk, TableCatalogEntry &table, - Vector &row_ids, const PhysicalInsert &op) { - - if (op.action_type == OnConflictAction::NOTHING) { - return 0; - } - auto &set_columns = op.set_columns; - - DataChunk update_chunk; - CreateUpdateChunk(context, chunk, table, row_ids, update_chunk, op); - - auto &data_table = table.GetStorage(); - // Perform the update, using the results of the SET expressions - if (GLOBAL) { - data_table.Update(table, context.client, row_ids, set_columns, update_chunk); - } else { - auto &local_storage = LocalStorage::Get(context.client, data_table.db); - // Perform the update, using the results of the SET expressions - local_storage.Update(data_table, row_ids, set_columns, update_chunk); - } - return update_chunk.size(); -} - -// TODO: should we use a hash table to keep track of this instead? -template -static void RegisterUpdatedRows(InsertLocalState &lstate, const Vector &row_ids, idx_t count) { - // Insert all rows, if any of the rows has already been updated before, we throw an error - auto data = FlatVector::GetData(row_ids); - - // The rowids in the transaction-local ART aren't final yet so we have to separately keep track of the two sets of - // rowids - unordered_set &updated_rows = GLOBAL ? 
lstate.updated_global_rows : lstate.updated_local_rows; - for (idx_t i = 0; i < count; i++) { - auto result = updated_rows.insert(data[i]); - if (result.second == false) { - throw InvalidInputException( - "ON CONFLICT DO UPDATE can not update the same row twice in the same command, Ensure that no rows " - "proposed for insertion within the same command have duplicate constrained values"); - } - } -} - -template -static idx_t HandleInsertConflicts(TableCatalogEntry &table, ExecutionContext &context, InsertLocalState &lstate, - DataTable &data_table, const PhysicalInsert &op) { - auto &types_to_fetch = op.types_to_fetch; - auto &on_conflict_condition = op.on_conflict_condition; - auto &conflict_target = op.conflict_target; - auto &columns_to_fetch = op.columns_to_fetch; - - auto &local_storage = LocalStorage::Get(context.client, data_table.db); - - // We either want to do nothing, or perform an update when conflicts arise - ConflictInfo conflict_info(conflict_target); - ConflictManager conflict_manager(VerifyExistenceType::APPEND, lstate.insert_chunk.size(), &conflict_info); - if (GLOBAL) { - data_table.VerifyAppendConstraints(table, context.client, lstate.insert_chunk, &conflict_manager); - } else { - DataTable::VerifyUniqueIndexes(local_storage.GetIndexes(data_table), context.client, lstate.insert_chunk, - &conflict_manager); - } - conflict_manager.Finalize(); - if (conflict_manager.ConflictCount() == 0) { - // No conflicts found, 0 updates performed - return 0; - } - auto &conflicts = conflict_manager.Conflicts(); - auto &row_ids = conflict_manager.RowIds(); - - DataChunk conflict_chunk; // contains only the conflicting values - DataChunk scan_chunk; // contains the original values, that caused the conflict - DataChunk combined_chunk; // contains conflict_chunk + scan_chunk (wide) - - // Filter out everything but the conflicting rows - conflict_chunk.Initialize(context.client, lstate.insert_chunk.GetTypes()); - conflict_chunk.Reference(lstate.insert_chunk); - 
conflict_chunk.Slice(conflicts.Selection(), conflicts.Count()); - conflict_chunk.SetCardinality(conflicts.Count()); - - // Holds the pins for the fetched rows - unique_ptr fetch_state; - if (!types_to_fetch.empty()) { - D_ASSERT(scan_chunk.size() == 0); - // When these values are required for the conditions or the SET expressions, - // then we scan the existing table for the conflicting tuples, using the rowids - scan_chunk.Initialize(context.client, types_to_fetch); - fetch_state = make_uniq(); - if (GLOBAL) { - auto &transaction = DuckTransaction::Get(context.client, table.catalog); - data_table.Fetch(transaction, scan_chunk, columns_to_fetch, row_ids, conflicts.Count(), *fetch_state); - } else { - local_storage.FetchChunk(data_table, row_ids, conflicts.Count(), columns_to_fetch, scan_chunk, - *fetch_state); - } - } - - // Splice the Input chunk and the fetched chunk together - CombineExistingAndInsertTuples(combined_chunk, scan_chunk, conflict_chunk, context.client, op); - - if (on_conflict_condition) { - DataChunk conflict_condition_result; - CheckOnConflictCondition(context, combined_chunk, on_conflict_condition, conflict_condition_result); - bool conditions_met = AllConflictsMeetCondition(conflict_condition_result); - if (!conditions_met) { - // Filter out the tuples that did pass the filter, then run the verify again - ManagedSelection sel(combined_chunk.size()); - auto data = FlatVector::GetData(conflict_condition_result.data[0]); - for (idx_t i = 0; i < combined_chunk.size(); i++) { - if (!data[i]) { - // Only populate the selection vector with the tuples that did not meet the condition - sel.Append(i); - } - } - combined_chunk.Slice(sel.Selection(), sel.Count()); - row_ids.Slice(sel.Selection(), sel.Count()); - if (GLOBAL) { - data_table.VerifyAppendConstraints(table, context.client, combined_chunk, nullptr); - } else { - DataTable::VerifyUniqueIndexes(local_storage.GetIndexes(data_table), context.client, - lstate.insert_chunk, nullptr); - } - throw 
InternalException("The previous operation was expected to throw but didn't"); - } - } - - RegisterUpdatedRows(lstate, row_ids, combined_chunk.size()); - - idx_t updated_tuples = PerformOnConflictAction(context, combined_chunk, table, row_ids, op); - - // Remove the conflicting tuples from the insert chunk - SelectionVector sel_vec(lstate.insert_chunk.size()); - idx_t new_size = - SelectionVector::Inverted(conflicts.Selection(), sel_vec, conflicts.Count(), lstate.insert_chunk.size()); - lstate.insert_chunk.Slice(sel_vec, new_size); - lstate.insert_chunk.SetCardinality(new_size); - return updated_tuples; -} - -idx_t PhysicalInsert::OnConflictHandling(TableCatalogEntry &table, ExecutionContext &context, - InsertLocalState &lstate) const { - auto &data_table = table.GetStorage(); - if (action_type == OnConflictAction::THROW) { - data_table.VerifyAppendConstraints(table, context.client, lstate.insert_chunk, nullptr); - return 0; - } - // Check whether any conflicts arise, and if they all meet the conflict_target + condition - // If that's not the case - We throw the first error - idx_t updated_tuples = 0; - updated_tuples += HandleInsertConflicts(table, context, lstate, data_table, *this); - // Also check the transaction-local storage+ART so we can detect conflicts within this transaction - updated_tuples += HandleInsertConflicts(table, context, lstate, data_table, *this); - - return updated_tuples; -} - -SinkResultType PhysicalInsert::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - auto &table = gstate.table; - auto &storage = table.GetStorage(); - PhysicalInsert::ResolveDefaults(table, chunk, column_index_map, lstate.default_executor, lstate.insert_chunk); - - if (!parallel) { - if (!gstate.initialized) { - storage.InitializeLocalAppend(gstate.append_state, context.client); - gstate.initialized = true; - } - - idx_t updated_tuples = 
OnConflictHandling(table, context, lstate); - gstate.insert_count += lstate.insert_chunk.size(); - gstate.insert_count += updated_tuples; - storage.LocalAppend(gstate.append_state, table, context.client, lstate.insert_chunk, true); - - if (return_chunk) { - gstate.return_collection.Append(lstate.insert_chunk); - } - } else { - D_ASSERT(!return_chunk); - // parallel append - if (!lstate.local_collection) { - lock_guard l(gstate.lock); - auto &table_info = storage.info; - auto &block_manager = TableIOManager::Get(storage).GetBlockManagerForRowData(); - lstate.local_collection = - make_uniq(table_info, block_manager, insert_types, MAX_ROW_ID); - lstate.local_collection->InitializeEmpty(); - lstate.local_collection->InitializeAppend(lstate.local_append_state); - lstate.writer = &gstate.table.GetStorage().CreateOptimisticWriter(context.client); - } - OnConflictHandling(table, context, lstate); - - auto new_row_group = lstate.local_collection->Append(lstate.insert_chunk, lstate.local_append_state); - if (new_row_group) { - lstate.writer->WriteNewRowGroup(*lstate.local_collection); - } - } - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalInsert::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this, lstate.default_executor, "default_executor", 1); - client_profiler.Flush(context.thread.profiler); - - if (!parallel || !lstate.local_collection) { - return SinkCombineResultType::FINISHED; - } - - // parallel append: finalize the append - TransactionData tdata(0, 0); - lstate.local_collection->FinalizeAppend(tdata, lstate.local_append_state); - - auto append_count = lstate.local_collection->GetTotalRows(); - - lock_guard lock(gstate.lock); - gstate.insert_count += append_count; - if (append_count < Storage::ROW_GROUP_SIZE) { - // we have 
few rows - append to the local storage directly - auto &table = gstate.table; - auto &storage = table.GetStorage(); - storage.InitializeLocalAppend(gstate.append_state, context.client); - auto &transaction = DuckTransaction::Get(context.client, table.catalog); - lstate.local_collection->Scan(transaction, [&](DataChunk &insert_chunk) { - storage.LocalAppend(gstate.append_state, table, context.client, insert_chunk); - return true; - }); - storage.FinalizeLocalAppend(gstate.append_state); - } else { - // we have written rows to disk optimistically - merge directly into the transaction-local storage - gstate.table.GetStorage().FinalizeOptimisticWriter(context.client, *lstate.writer); - gstate.table.GetStorage().LocalMerge(context.client, *lstate.local_collection); - } - - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalInsert::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - if (!parallel && gstate.initialized) { - auto &table = gstate.table; - auto &storage = table.GetStorage(); - storage.FinalizeLocalAppend(gstate.append_state); - } - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class InsertSourceState : public GlobalSourceState { -public: - explicit InsertSourceState(const PhysicalInsert &op) { - if (op.return_chunk) { - D_ASSERT(op.sink_state); - auto &g = op.sink_state->Cast(); - g.return_collection.InitializeScan(scan_state); - } - } - - ColumnDataScanState scan_state; -}; - -unique_ptr PhysicalInsert::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -SourceResultType PhysicalInsert::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &state = input.global_state.Cast(); - auto &insert_gstate = 
sink_state->Cast(); - if (!return_chunk) { - chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(insert_gstate.insert_count)); - return SourceResultType::FINISHED; - } - - insert_gstate.return_collection.Scan(state.scan_state, chunk); - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -PhysicalUpdate::PhysicalUpdate(vector types, TableCatalogEntry &tableref, DataTable &table, - vector columns, vector> expressions, - vector> bound_defaults, idx_t estimated_cardinality, - bool return_chunk) - : PhysicalOperator(PhysicalOperatorType::UPDATE, std::move(types), estimated_cardinality), tableref(tableref), - table(table), columns(std::move(columns)), expressions(std::move(expressions)), - bound_defaults(std::move(bound_defaults)), return_chunk(return_chunk) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class UpdateGlobalState : public GlobalSinkState { -public: - explicit UpdateGlobalState(ClientContext &context, const vector &return_types) - : updated_count(0), return_collection(context, return_types) { - } - - mutex lock; - idx_t updated_count; - unordered_set updated_columns; - ColumnDataCollection return_collection; -}; - -class UpdateLocalState : public LocalSinkState { -public: - UpdateLocalState(ClientContext &context, const vector> &expressions, - const vector &table_types, const vector> &bound_defaults) - : default_executor(context, bound_defaults) { - // initialize the update chunk - auto &allocator = Allocator::Get(context); - vector update_types; - update_types.reserve(expressions.size()); - for (auto &expr : expressions) { - update_types.push_back(expr->return_type); - } - update_chunk.Initialize(allocator, update_types); - // initialize the mock chunk - mock_chunk.Initialize(allocator, table_types); - } - - 
DataChunk update_chunk; - DataChunk mock_chunk; - ExpressionExecutor default_executor; -}; - -SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - DataChunk &update_chunk = lstate.update_chunk; - DataChunk &mock_chunk = lstate.mock_chunk; - - chunk.Flatten(); - lstate.default_executor.SetChunk(chunk); - - // update data in the base table - // the row ids are given to us as the last column of the child chunk - auto &row_ids = chunk.data[chunk.ColumnCount() - 1]; - update_chunk.Reset(); - update_chunk.SetCardinality(chunk); - - for (idx_t i = 0; i < expressions.size(); i++) { - if (expressions[i]->type == ExpressionType::VALUE_DEFAULT) { - // default expression, set to the default value of the column - lstate.default_executor.ExecuteExpression(columns[i].index, update_chunk.data[i]); - } else { - D_ASSERT(expressions[i]->type == ExpressionType::BOUND_REF); - // index into child chunk - auto &binding = expressions[i]->Cast(); - update_chunk.data[i].Reference(chunk.data[binding.index]); - } - } - - lock_guard glock(gstate.lock); - if (update_is_del_and_insert) { - // index update or update on complex type, perform a delete and an append instead - - // figure out which rows have not yet been deleted in this update - // this is required since we might see the same row_id multiple times - // in the case of an UPDATE query that e.g. 
has joins - auto row_id_data = FlatVector::GetData(row_ids); - SelectionVector sel(STANDARD_VECTOR_SIZE); - idx_t update_count = 0; - for (idx_t i = 0; i < update_chunk.size(); i++) { - auto row_id = row_id_data[i]; - if (gstate.updated_columns.find(row_id) == gstate.updated_columns.end()) { - gstate.updated_columns.insert(row_id); - sel.set_index(update_count++, i); - } - } - if (update_count != update_chunk.size()) { - // we need to slice here - update_chunk.Slice(sel, update_count); - } - table.Delete(tableref, context.client, row_ids, update_chunk.size()); - // for the append we need to arrange the columns in a specific manner (namely the "standard table order") - mock_chunk.SetCardinality(update_chunk); - for (idx_t i = 0; i < columns.size(); i++) { - mock_chunk.data[columns[i].index].Reference(update_chunk.data[i]); - } - table.LocalAppend(tableref, context.client, mock_chunk); - } else { - if (return_chunk) { - mock_chunk.SetCardinality(update_chunk); - for (idx_t i = 0; i < columns.size(); i++) { - mock_chunk.data[columns[i].index].Reference(update_chunk.data[i]); - } - } - table.Update(tableref, context.client, row_ids, columns, update_chunk); - } - - if (return_chunk) { - gstate.return_collection.Append(mock_chunk); - } - - gstate.updated_count += chunk.size(); - - return SinkResultType::NEED_MORE_INPUT; -} - -unique_ptr PhysicalUpdate::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, GetTypes()); -} - -unique_ptr PhysicalUpdate::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, expressions, table.GetTypes(), bound_defaults); -} - -SinkCombineResultType PhysicalUpdate::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &state = input.local_state.Cast(); - auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this, state.default_executor, "default_executor", 1); - client_profiler.Flush(context.thread.profiler); - - 
return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class UpdateSourceState : public GlobalSourceState { -public: - explicit UpdateSourceState(const PhysicalUpdate &op) { - if (op.return_chunk) { - D_ASSERT(op.sink_state); - auto &g = op.sink_state->Cast(); - g.return_collection.InitializeScan(scan_state); - } - } - - ColumnDataScanState scan_state; -}; - -unique_ptr PhysicalUpdate::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -SourceResultType PhysicalUpdate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &state = input.global_state.Cast(); - auto &g = sink_state->Cast(); - if (!return_chunk) { - chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(g.updated_count)); - return SourceResultType::FINISHED; - } - - g.return_collection.Scan(state.scan_state, chunk); - - return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb - - - -namespace duckdb { - -PhysicalPivot::PhysicalPivot(vector types_p, unique_ptr child, - BoundPivotInfo bound_pivot_p) - : PhysicalOperator(PhysicalOperatorType::PIVOT, std::move(types_p), child->estimated_cardinality), - bound_pivot(std::move(bound_pivot_p)) { - children.push_back(std::move(child)); - for (idx_t p = 0; p < bound_pivot.pivot_values.size(); p++) { - auto entry = pivot_map.find(bound_pivot.pivot_values[p]); - if (entry != pivot_map.end()) { - continue; - } - pivot_map[bound_pivot.pivot_values[p]] = bound_pivot.group_count + p; - } - // extract the empty aggregate expressions - ArenaAllocator allocator(Allocator::DefaultAllocator()); - for (auto &aggr_expr : bound_pivot.aggregates) { - auto &aggr = aggr_expr->Cast(); - // for each aggregate, initialize an empty aggregate state and finalize it immediately - auto state = make_unsafe_uniq_array(aggr.function.state_size()); - aggr.function.initialize(state.get()); - Vector state_vector(Value::POINTER(CastPointerToValue(state.get()))); - Vector result_vector(aggr_expr->return_type); - AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); - aggr.function.finalize(state_vector, aggr_input_data, result_vector, 1, 0); - empty_aggregates.push_back(result_vector.GetValue(0)); - } -} - -OperatorResultType PhysicalPivot::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state) const { - // copy the groups as-is - for (idx_t i = 0; i < bound_pivot.group_count; i++) { - chunk.data[i].Reference(input.data[i]); - } - auto pivot_column_lists = FlatVector::GetData(input.data.back()); - auto &pivot_column_values = ListVector::GetEntry(input.data.back()); - auto pivot_columns = FlatVector::GetData(pivot_column_values); - - // initialize all aggregate columns with the empty aggregate value - // if there are multiple aggregates the columns are in 
order of [AGGR1][AGGR2][AGGR1][AGGR2] - // so we need to alternate the empty_aggregate that we use - idx_t aggregate = 0; - for (idx_t c = bound_pivot.group_count; c < chunk.ColumnCount(); c++) { - chunk.data[c].Reference(empty_aggregates[aggregate]); - chunk.data[c].Flatten(input.size()); - aggregate++; - if (aggregate >= empty_aggregates.size()) { - aggregate = 0; - } - } - - // move the pivots to the given columns - for (idx_t r = 0; r < input.size(); r++) { - auto list = pivot_column_lists[r]; - for (idx_t l = 0; l < list.length; l++) { - // figure out the column value number of this list - auto &column_name = pivot_columns[list.offset + l]; - auto entry = pivot_map.find(column_name); - if (entry == pivot_map.end()) { - // column entry not found in map - that means this element is explicitly excluded from the pivot list - continue; - } - auto column_idx = entry->second; - for (idx_t aggr = 0; aggr < empty_aggregates.size(); aggr++) { - auto pivot_value_lists = FlatVector::GetData(input.data[bound_pivot.group_count + aggr]); - auto &pivot_value_child = ListVector::GetEntry(input.data[bound_pivot.group_count + aggr]); - if (list.offset != pivot_value_lists[r].offset || list.length != pivot_value_lists[r].length) { - throw InternalException("Pivot - unaligned lists between values and columns!?"); - } - chunk.data[column_idx + aggr].SetValue(r, pivot_value_child.GetValue(list.offset + l)); - } - } - } - chunk.SetCardinality(input.size()); - return OperatorResultType::NEED_MORE_INPUT; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -class ProjectionState : public OperatorState { -public: - explicit ProjectionState(ExecutionContext &context, const vector> &expressions) - : executor(context.client, expressions) { - } - - ExpressionExecutor executor; - -public: - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op, executor, "projection", 0); - } -}; - 
-PhysicalProjection::PhysicalProjection(vector types, vector> select_list, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::PROJECTION, std::move(types), estimated_cardinality), - select_list(std::move(select_list)) { -} - -OperatorResultType PhysicalProjection::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast(); - state.executor.Execute(input, chunk); - return OperatorResultType::NEED_MORE_INPUT; -} - -unique_ptr PhysicalProjection::GetOperatorState(ExecutionContext &context) const { - return make_uniq(context, select_list); -} - -unique_ptr -PhysicalProjection::CreateJoinProjection(vector proj_types, const vector &lhs_types, - const vector &rhs_types, const vector &left_projection_map, - const vector &right_projection_map, const idx_t estimated_cardinality) { - - vector> proj_selects; - proj_selects.reserve(proj_types.size()); - - if (left_projection_map.empty()) { - for (storage_t i = 0; i < lhs_types.size(); ++i) { - proj_selects.emplace_back(make_uniq(lhs_types[i], i)); - } - } else { - for (auto i : left_projection_map) { - proj_selects.emplace_back(make_uniq(lhs_types[i], i)); - } - } - const auto left_cols = lhs_types.size(); - - if (right_projection_map.empty()) { - for (storage_t i = 0; i < rhs_types.size(); ++i) { - proj_selects.emplace_back(make_uniq(rhs_types[i], left_cols + i)); - } - - } else { - for (auto i : right_projection_map) { - proj_selects.emplace_back(make_uniq(rhs_types[i], left_cols + i)); - } - } - - return make_uniq(std::move(proj_types), std::move(proj_selects), estimated_cardinality); -} - -string PhysicalProjection::ParamsToString() const { - string extra_info; - for (auto &expr : select_list) { - extra_info += expr->GetName() + "\n"; - } - return extra_info; -} - -} // namespace duckdb - - -namespace duckdb { - -class TableInOutLocalState : public OperatorState { -public: - TableInOutLocalState() 
: row_index(0), new_row(true) { - } - - unique_ptr local_state; - idx_t row_index; - bool new_row; - DataChunk input_chunk; -}; - -class TableInOutGlobalState : public GlobalOperatorState { -public: - TableInOutGlobalState() { - } - - unique_ptr global_state; -}; - -PhysicalTableInOutFunction::PhysicalTableInOutFunction(vector types, TableFunction function_p, - unique_ptr bind_data_p, - vector column_ids_p, idx_t estimated_cardinality, - vector project_input_p) - : PhysicalOperator(PhysicalOperatorType::INOUT_FUNCTION, std::move(types), estimated_cardinality), - function(std::move(function_p)), bind_data(std::move(bind_data_p)), column_ids(std::move(column_ids_p)), - projected_input(std::move(project_input_p)) { -} - -unique_ptr PhysicalTableInOutFunction::GetOperatorState(ExecutionContext &context) const { - auto &gstate = op_state->Cast(); - auto result = make_uniq(); - if (function.init_local) { - TableFunctionInitInput input(bind_data.get(), column_ids, vector(), nullptr); - result->local_state = function.init_local(context, input, gstate.global_state.get()); - } - if (!projected_input.empty()) { - result->input_chunk.Initialize(context.client, children[0]->types); - } - return std::move(result); -} - -unique_ptr PhysicalTableInOutFunction::GetGlobalOperatorState(ClientContext &context) const { - auto result = make_uniq(); - if (function.init_global) { - TableFunctionInitInput input(bind_data.get(), column_ids, vector(), nullptr); - result->global_state = function.init_global(context, input); - } - return std::move(result); -} - -OperatorResultType PhysicalTableInOutFunction::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate_p, OperatorState &state_p) const { - auto &gstate = gstate_p.Cast(); - auto &state = state_p.Cast(); - TableFunctionInput data(bind_data.get(), state.local_state.get(), gstate.global_state.get()); - if (projected_input.empty()) { - // straightforward case - no need to project input - return 
function.in_out_function(context, data, input, chunk); - } - // when project_input is set we execute the input function row-by-row - if (state.new_row) { - if (state.row_index >= input.size()) { - // finished processing this chunk - state.new_row = true; - state.row_index = 0; - return OperatorResultType::NEED_MORE_INPUT; - } - // we are processing a new row: fetch the data for the current row - state.input_chunk.Reset(); - D_ASSERT(input.ColumnCount() == state.input_chunk.ColumnCount()); - // set up the input data to the table in-out function - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { - ConstantVector::Reference(state.input_chunk.data[col_idx], input.data[col_idx], state.row_index, 1); - } - state.input_chunk.SetCardinality(1); - state.row_index++; - state.new_row = false; - } - // set up the output data in "chunk" - D_ASSERT(chunk.ColumnCount() > projected_input.size()); - D_ASSERT(state.row_index > 0); - idx_t base_idx = chunk.ColumnCount() - projected_input.size(); - for (idx_t project_idx = 0; project_idx < projected_input.size(); project_idx++) { - auto source_idx = projected_input[project_idx]; - auto target_idx = base_idx + project_idx; - ConstantVector::Reference(chunk.data[target_idx], input.data[source_idx], state.row_index - 1, 1); - } - auto result = function.in_out_function(context, data, state.input_chunk, chunk); - if (result == OperatorResultType::FINISHED) { - return result; - } - if (result == OperatorResultType::NEED_MORE_INPUT) { - // we finished processing this row: move to the next row - state.new_row = true; - } - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -OperatorFinalizeResultType PhysicalTableInOutFunction::FinalExecute(ExecutionContext &context, DataChunk &chunk, - GlobalOperatorState &gstate_p, - OperatorState &state_p) const { - auto &gstate = gstate_p.Cast(); - auto &state = state_p.Cast(); - if (!projected_input.empty()) { - throw InternalException("FinalExecute not supported for project_input"); - 
} - TableFunctionInput data(bind_data.get(), state.local_state.get(), gstate.global_state.get()); - return function.in_out_function_final(context, data, chunk); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -class UnnestOperatorState : public OperatorState { -public: - UnnestOperatorState(ClientContext &context, const vector> &select_list) - : current_row(0), list_position(0), longest_list_length(DConstants::INVALID_INDEX), first_fetch(true), - executor(context) { - - // for each UNNEST in the select_list, we add the child expression to the expression executor - // and set the return type in the list_data chunk, which will contain the evaluated expression results - vector list_data_types; - for (auto &exp : select_list) { - D_ASSERT(exp->type == ExpressionType::BOUND_UNNEST); - auto &bue = exp->Cast(); - list_data_types.push_back(bue.child->return_type); - executor.AddExpression(*bue.child.get()); - } - - auto &allocator = Allocator::Get(context); - list_data.Initialize(allocator, list_data_types); - - list_vector_data.resize(list_data.ColumnCount()); - list_child_data.resize(list_data.ColumnCount()); - } - - idx_t current_row; - idx_t list_position; - idx_t longest_list_length; - bool first_fetch; - - ExpressionExecutor executor; - DataChunk list_data; - vector list_vector_data; - vector list_child_data; - -public: - //! Reset the fields of the unnest operator state - void Reset(); - //! 
Set the longest list's length for the current row - void SetLongestListLength(); -}; - -void UnnestOperatorState::Reset() { - current_row = 0; - list_position = 0; - longest_list_length = DConstants::INVALID_INDEX; - first_fetch = true; -} - -void UnnestOperatorState::SetLongestListLength() { - longest_list_length = 0; - for (idx_t col_idx = 0; col_idx < list_data.ColumnCount(); col_idx++) { - - auto &vector_data = list_vector_data[col_idx]; - auto current_idx = vector_data.sel->get_index(current_row); - - if (vector_data.validity.RowIsValid(current_idx)) { - - // check if this list is longer - auto list_data_entries = UnifiedVectorFormat::GetData(vector_data); - auto list_entry = list_data_entries[current_idx]; - if (list_entry.length > longest_list_length) { - longest_list_length = list_entry.length; - } - } - } -} - -PhysicalUnnest::PhysicalUnnest(vector types, vector> select_list, - idx_t estimated_cardinality, PhysicalOperatorType type) - : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list)) { - D_ASSERT(!this->select_list.empty()); -} - -static void UnnestNull(idx_t start, idx_t end, Vector &result) { - - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - auto &validity = FlatVector::Validity(result); - for (idx_t i = start; i < end; i++) { - validity.SetInvalid(i); - } - if (result.GetType().InternalType() == PhysicalType::STRUCT) { - auto &struct_children = StructVector::GetEntries(result); - for (auto &child : struct_children) { - UnnestNull(start, end, *child); - } - } -} - -template -static void TemplatedUnnest(UnifiedVectorFormat &vector_data, idx_t start, idx_t end, Vector &result) { - - auto source_data = UnifiedVectorFormat::GetData(vector_data); - auto &source_mask = vector_data.validity; - - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - - for (idx_t i = start; i < end; 
i++) { - auto source_idx = vector_data.sel->get_index(i); - auto target_idx = i - start; - if (source_mask.RowIsValid(source_idx)) { - result_data[target_idx] = source_data[source_idx]; - result_mask.SetValid(target_idx); - } else { - result_mask.SetInvalid(target_idx); - } - } -} - -static void UnnestValidity(UnifiedVectorFormat &vector_data, idx_t start, idx_t end, Vector &result) { - - auto &source_mask = vector_data.validity; - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - auto &result_mask = FlatVector::Validity(result); - - for (idx_t i = start; i < end; i++) { - auto source_idx = vector_data.sel->get_index(i); - auto target_idx = i - start; - result_mask.Set(target_idx, source_mask.RowIsValid(source_idx)); - } -} - -static void UnnestVector(UnifiedVectorFormat &child_vector_data, Vector &child_vector, idx_t list_size, idx_t start, - idx_t end, Vector &result) { - - D_ASSERT(child_vector.GetType() == result.GetType()); - switch (result.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case PhysicalType::INT16: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case PhysicalType::INT32: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case PhysicalType::INT64: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case PhysicalType::INT128: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case PhysicalType::UINT8: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case PhysicalType::UINT16: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case PhysicalType::UINT32: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case PhysicalType::UINT64: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case PhysicalType::FLOAT: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case 
PhysicalType::DOUBLE: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case PhysicalType::INTERVAL: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case PhysicalType::VARCHAR: - TemplatedUnnest(child_vector_data, start, end, result); - break; - case PhysicalType::LIST: { - // the child vector of result now references the child vector source - // FIXME: only reference relevant children (start - end) instead of all - auto &target = ListVector::GetEntry(result); - target.Reference(ListVector::GetEntry(child_vector)); - ListVector::SetListSize(result, ListVector::GetListSize(child_vector)); - // unnest - TemplatedUnnest(child_vector_data, start, end, result); - break; - } - case PhysicalType::STRUCT: { - auto &child_vector_entries = StructVector::GetEntries(child_vector); - auto &result_entries = StructVector::GetEntries(result); - - // set the validity mask for the 'outer' struct vector before unnesting its children - UnnestValidity(child_vector_data, start, end, result); - - for (idx_t i = 0; i < child_vector_entries.size(); i++) { - UnifiedVectorFormat child_vector_entries_data; - child_vector_entries[i]->ToUnifiedFormat(list_size, child_vector_entries_data); - UnnestVector(child_vector_entries_data, *child_vector_entries[i], list_size, start, end, - *result_entries[i]); - } - break; - } - default: - throw InternalException("Unimplemented type for UNNEST."); - } -} - -static void PrepareInput(UnnestOperatorState &state, DataChunk &input, - const vector> &select_list) { - - state.list_data.Reset(); - // execute the expressions inside each UNNEST in the select_list to get the list data - // execution results (lists) are kept in state.list_data chunk - state.executor.Execute(input, state.list_data); - - // verify incoming lists - state.list_data.Verify(); - D_ASSERT(input.size() == state.list_data.size()); - D_ASSERT(state.list_data.ColumnCount() == select_list.size()); - D_ASSERT(state.list_vector_data.size() == 
state.list_data.ColumnCount()); - D_ASSERT(state.list_child_data.size() == state.list_data.ColumnCount()); - - // get the UnifiedVectorFormat of each list_data vector (LIST vectors for the different UNNESTs) - // both for the vector itself and its child vector - for (idx_t col_idx = 0; col_idx < state.list_data.ColumnCount(); col_idx++) { - - auto &list_vector = state.list_data.data[col_idx]; - list_vector.ToUnifiedFormat(state.list_data.size(), state.list_vector_data[col_idx]); - - if (list_vector.GetType() == LogicalType::SQLNULL) { - // UNNEST(NULL): SQLNULL vectors don't have child vectors, but we need to point to the child vector of - // each vector, so we just get the UnifiedVectorFormat of the vector itself - auto &child_vector = list_vector; - child_vector.ToUnifiedFormat(0, state.list_child_data[col_idx]); - } else { - auto list_size = ListVector::GetListSize(list_vector); - auto &child_vector = ListVector::GetEntry(list_vector); - child_vector.ToUnifiedFormat(list_size, state.list_child_data[col_idx]); - } - } - - state.first_fetch = false; -} - -unique_ptr PhysicalUnnest::GetOperatorState(ExecutionContext &context) const { - return PhysicalUnnest::GetState(context, select_list); -} - -unique_ptr PhysicalUnnest::GetState(ExecutionContext &context, - const vector> &select_list) { - return make_uniq(context.client, select_list); -} - -OperatorResultType PhysicalUnnest::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - OperatorState &state_p, - const vector> &select_list, - bool include_input) { - - auto &state = state_p.Cast(); - - do { - // reset validities, if previous loop iteration contained UNNEST(NULL) - if (include_input) { - chunk.Reset(); - } - - // prepare the input data by executing any expressions and getting the - // UnifiedVectorFormat of each LIST vector (list_vector_data) and its child vector (list_child_data) - if (state.first_fetch) { - PrepareInput(state, input, select_list); - } - - // finished with all 
rows of this input chunk, reset - if (state.current_row >= input.size()) { - state.Reset(); - return OperatorResultType::NEED_MORE_INPUT; - } - - // each UNNEST in the select_list contains a list (or NULL) for this row, find the longest list - // because this length determines how many times we need to repeat for the current row - if (state.longest_list_length == DConstants::INVALID_INDEX) { - state.SetLongestListLength(); - } - D_ASSERT(state.longest_list_length != DConstants::INVALID_INDEX); - - // we emit chunks of either STANDARD_VECTOR_SIZE or smaller - auto this_chunk_len = MinValue(STANDARD_VECTOR_SIZE, state.longest_list_length - state.list_position); - chunk.SetCardinality(this_chunk_len); - - // if we include other projection input columns, e.g. SELECT 1, UNNEST([1, 2]);, then - // we need to add them as a constant vector to the resulting chunk - // FIXME: emit multiple unnested rows. Currently, we never emit a chunk containing multiple unnested input rows, - // so setting a constant vector for the value at state.current_row is fine - idx_t col_offset = 0; - if (include_input) { - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { - ConstantVector::Reference(chunk.data[col_idx], input.data[col_idx], state.current_row, input.size()); - } - col_offset = input.ColumnCount(); - } - - // unnest the lists - for (idx_t col_idx = 0; col_idx < state.list_data.ColumnCount(); col_idx++) { - - auto &result_vector = chunk.data[col_idx + col_offset]; - - if (state.list_data.data[col_idx].GetType() == LogicalType::SQLNULL) { - // UNNEST(NULL) - chunk.SetCardinality(0); - break; - } - - auto &vector_data = state.list_vector_data[col_idx]; - auto current_idx = vector_data.sel->get_index(state.current_row); - - if (!vector_data.validity.RowIsValid(current_idx)) { - UnnestNull(0, this_chunk_len, result_vector); - continue; - } - - auto list_data = UnifiedVectorFormat::GetData(vector_data); - auto list_entry = list_data[current_idx]; - - idx_t list_count = 
0; - if (state.list_position < list_entry.length) { - // there are still list_count elements to unnest - list_count = MinValue(this_chunk_len, list_entry.length - state.list_position); - - auto &list_vector = state.list_data.data[col_idx]; - auto &child_vector = ListVector::GetEntry(list_vector); - auto list_size = ListVector::GetListSize(list_vector); - auto &child_vector_data = state.list_child_data[col_idx]; - - auto base_offset = list_entry.offset + state.list_position; - UnnestVector(child_vector_data, child_vector, list_size, base_offset, base_offset + list_count, - result_vector); - } - - // fill the rest with NULLs - if (list_count != this_chunk_len) { - UnnestNull(list_count, this_chunk_len, result_vector); - } - } - - chunk.Verify(); - - state.list_position += this_chunk_len; - if (state.list_position == state.longest_list_length) { - state.current_row++; - state.longest_list_length = DConstants::INVALID_INDEX; - state.list_position = 0; - } - - // we only emit one unnested row (that contains data) at a time - } while (chunk.size() == 0); - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -OperatorResultType PhysicalUnnest::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &, OperatorState &state) const { - return ExecuteInternal(context, input, chunk, state, select_list); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -PhysicalColumnDataScan::PhysicalColumnDataScan(vector types, PhysicalOperatorType op_type, - idx_t estimated_cardinality, - unique_ptr owned_collection_p) - : PhysicalOperator(op_type, std::move(types), estimated_cardinality), collection(owned_collection_p.get()), - owned_collection(std::move(owned_collection_p)) { -} - -class PhysicalColumnDataScanState : public GlobalSourceState { -public: - explicit PhysicalColumnDataScanState() : initialized(false) { - } - - //! 
The current position in the scan - ColumnDataScanState scan_state; - bool initialized; -}; - -unique_ptr PhysicalColumnDataScan::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(); -} - -SourceResultType PhysicalColumnDataScan::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &state = input.global_state.Cast(); - if (collection->Count() == 0) { - return SourceResultType::FINISHED; - } - if (!state.initialized) { - collection->InitializeScan(state.scan_state); - state.initialized = true; - } - collection->Scan(state.scan_state, chunk); - - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalColumnDataScan::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - // check if there is any additional action we need to do depending on the type - auto &state = meta_pipeline.GetState(); - switch (type) { - case PhysicalOperatorType::DELIM_SCAN: { - auto entry = state.delim_join_dependencies.find(*this); - D_ASSERT(entry != state.delim_join_dependencies.end()); - // this chunk scan introduces a dependency to the current pipeline - // namely a dependency on the duplicate elimination pipeline to finish - auto delim_dependency = entry->second.get().shared_from_this(); - auto delim_sink = state.GetPipelineSink(*delim_dependency); - D_ASSERT(delim_sink); - D_ASSERT(delim_sink->type == PhysicalOperatorType::DELIM_JOIN); - auto &delim_join = delim_sink->Cast(); - current.AddDependency(delim_dependency); - state.SetPipelineSource(current, delim_join.distinct->Cast()); - return; - } - case PhysicalOperatorType::CTE_SCAN: { - break; - } - case PhysicalOperatorType::RECURSIVE_CTE_SCAN: - if (!meta_pipeline.HasRecursiveCTE()) { - throw InternalException("Recursive CTE 
scan found without recursive CTE node"); - } - break; - default: - break; - } - D_ASSERT(children.empty()); - state.SetPipelineSource(current, *this); -} - -string PhysicalColumnDataScan::ParamsToString() const { - string result = ""; - switch (type) { - case PhysicalOperatorType::CTE_SCAN: - case PhysicalOperatorType::RECURSIVE_CTE_SCAN: { - result += "\n[INFOSEPARATOR]\n"; - result += StringUtil::Format("idx: %llu", cte_index); - break; - } - default: - break; - } - - return result; -} - -} // namespace duckdb - - -namespace duckdb { - -SourceResultType PhysicalDummyScan::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - // return a single row on the first call to the dummy scan - chunk.SetCardinality(1); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - -namespace duckdb { - -SourceResultType PhysicalEmptyResult::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - -namespace duckdb { - -class ExpressionScanState : public OperatorState { -public: - explicit ExpressionScanState(Allocator &allocator, const PhysicalExpressionScan &op) : expression_index(0) { - temp_chunk.Initialize(allocator, op.GetTypes()); - } - - //! The current position in the scan - idx_t expression_index; - //! 
Temporary chunk for evaluating expressions - DataChunk temp_chunk; -}; - -unique_ptr PhysicalExpressionScan::GetOperatorState(ExecutionContext &context) const { - return make_uniq(Allocator::Get(context.client), *this); -} - -OperatorResultType PhysicalExpressionScan::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast(); - - for (; chunk.size() + input.size() <= STANDARD_VECTOR_SIZE && state.expression_index < expressions.size(); - state.expression_index++) { - state.temp_chunk.Reset(); - EvaluateExpression(context.client, state.expression_index, &input, state.temp_chunk); - chunk.Append(state.temp_chunk); - } - if (state.expression_index < expressions.size()) { - return OperatorResultType::HAVE_MORE_OUTPUT; - } else { - state.expression_index = 0; - return OperatorResultType::NEED_MORE_INPUT; - } -} - -void PhysicalExpressionScan::EvaluateExpression(ClientContext &context, idx_t expression_idx, DataChunk *child_chunk, - DataChunk &result) const { - ExpressionExecutor executor(context, expressions[expression_idx]); - if (child_chunk) { - child_chunk->Verify(); - executor.Execute(*child_chunk, result); - } else { - executor.Execute(result); - } -} - -bool PhysicalExpressionScan::IsFoldable() const { - for (auto &expr_list : expressions) { - for (auto &expr : expr_list) { - if (!expr->IsFoldable()) { - return false; - } - } - } - return true; -} - -} // namespace duckdb - - - - - - - - -#include - -namespace duckdb { - -PhysicalPositionalScan::PhysicalPositionalScan(vector types, unique_ptr left, - unique_ptr right) - : PhysicalOperator(PhysicalOperatorType::POSITIONAL_SCAN, std::move(types), - MaxValue(left->estimated_cardinality, right->estimated_cardinality)) { - - // Manage the children ourselves - if (left->type == PhysicalOperatorType::TABLE_SCAN) { - child_tables.emplace_back(std::move(left)); - } else if (left->type == 
PhysicalOperatorType::POSITIONAL_SCAN) { - auto &left_scan = left->Cast(); - child_tables = std::move(left_scan.child_tables); - } else { - throw InternalException("Invalid left input for PhysicalPositionalScan"); - } - - if (right->type == PhysicalOperatorType::TABLE_SCAN) { - child_tables.emplace_back(std::move(right)); - } else if (right->type == PhysicalOperatorType::POSITIONAL_SCAN) { - auto &right_scan = right->Cast(); - auto &right_tables = right_scan.child_tables; - child_tables.reserve(child_tables.size() + right_tables.size()); - std::move(right_tables.begin(), right_tables.end(), std::back_inserter(child_tables)); - } else { - throw InternalException("Invalid right input for PhysicalPositionalScan"); - } -} - -class PositionalScanGlobalSourceState : public GlobalSourceState { -public: - PositionalScanGlobalSourceState(ClientContext &context, const PhysicalPositionalScan &op) { - for (const auto &table : op.child_tables) { - global_states.emplace_back(table->GetGlobalSourceState(context)); - } - } - - vector> global_states; - - idx_t MaxThreads() override { - return 1; - } -}; - -class PositionalTableScanner { -public: - PositionalTableScanner(ExecutionContext &context, PhysicalOperator &table_p, GlobalSourceState &gstate_p) - : table(table_p), global_state(gstate_p), source_offset(0), exhausted(false) { - local_state = table.GetLocalSourceState(context, gstate_p); - source.Initialize(Allocator::Get(context.client), table.types); - } - - idx_t Refill(ExecutionContext &context) { - if (source_offset >= source.size()) { - if (!exhausted) { - source.Reset(); - - InterruptState interrupt_state; - OperatorSourceInput source_input {global_state, *local_state, interrupt_state}; - auto source_result = table.GetData(context, source, source_input); - if (source_result == SourceResultType::BLOCKED) { - throw NotImplementedException( - "Unexpected interrupt from table Source in PositionalTableScanner refill"); - } - } - source_offset = 0; - } - - const auto available 
= source.size() - source_offset; - if (!available) { - if (!exhausted) { - source.Reset(); - for (idx_t i = 0; i < source.ColumnCount(); ++i) { - auto &vec = source.data[i]; - vec.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(vec, true); - } - exhausted = true; - } - } - - return available; - } - - idx_t CopyData(ExecutionContext &context, DataChunk &output, const idx_t count, const idx_t col_offset) { - if (!source_offset && (source.size() >= count || exhausted)) { - // Fast track: aligned and has enough data - for (idx_t i = 0; i < source.ColumnCount(); ++i) { - output.data[col_offset + i].Reference(source.data[i]); - } - source_offset += count; - } else { - // Copy data - for (idx_t target_offset = 0; target_offset < count;) { - const auto needed = count - target_offset; - const auto available = exhausted ? needed : (source.size() - source_offset); - const auto copy_size = MinValue(needed, available); - const auto source_count = source_offset + copy_size; - for (idx_t i = 0; i < source.ColumnCount(); ++i) { - VectorOperations::Copy(source.data[i], output.data[col_offset + i], source_count, source_offset, - target_offset); - } - target_offset += copy_size; - source_offset += copy_size; - Refill(context); - } - } - - return source.ColumnCount(); - } - - double GetProgress(ClientContext &context) { - return table.GetProgress(context, global_state); - } - - PhysicalOperator &table; - GlobalSourceState &global_state; - unique_ptr local_state; - DataChunk source; - idx_t source_offset; - bool exhausted; -}; - -class PositionalScanLocalSourceState : public LocalSourceState { -public: - PositionalScanLocalSourceState(ExecutionContext &context, PositionalScanGlobalSourceState &gstate, - const PhysicalPositionalScan &op) { - for (size_t i = 0; i < op.child_tables.size(); ++i) { - auto &child = *op.child_tables[i]; - auto &global_state = *gstate.global_states[i]; - scanners.emplace_back(make_uniq(context, child, global_state)); - } - } - - vector> 
scanners; -}; - -unique_ptr PhysicalPositionalScan::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(context, gstate.Cast(), *this); -} - -unique_ptr PhysicalPositionalScan::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(context, *this); -} - -SourceResultType PhysicalPositionalScan::GetData(ExecutionContext &context, DataChunk &output, - OperatorSourceInput &input) const { - auto &lstate = input.local_state.Cast(); - - // Find the longest source block - idx_t count = 0; - for (auto &scanner : lstate.scanners) { - count = MaxValue(count, scanner->Refill(context)); - } - - // All done? - if (!count) { - return SourceResultType::FINISHED; - } - - // Copy or reference the source columns - idx_t col_offset = 0; - for (auto &scanner : lstate.scanners) { - col_offset += scanner->CopyData(context, output, count, col_offset); - } - - output.SetCardinality(count); - return SourceResultType::HAVE_MORE_OUTPUT; -} - -double PhysicalPositionalScan::GetProgress(ClientContext &context, GlobalSourceState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - - double result = child_tables[0]->GetProgress(context, *gstate.global_states[0]); - for (size_t t = 1; t < child_tables.size(); ++t) { - result = MinValue(result, child_tables[t]->GetProgress(context, *gstate.global_states[t])); - } - - return result; -} - -bool PhysicalPositionalScan::Equals(const PhysicalOperator &other_p) const { - if (type != other_p.type) { - return false; - } - - auto &other = other_p.Cast(); - if (child_tables.size() != other.child_tables.size()) { - return false; - } - for (size_t i = 0; i < child_tables.size(); ++i) { - if (!child_tables[i]->Equals(*other.child_tables[i])) { - return false; - } - } - - return true; -} - -} // namespace duckdb - - - - - - - -#include - -namespace duckdb { - -PhysicalTableScan::PhysicalTableScan(vector types, TableFunction function_p, - unique_ptr bind_data_p, vector returned_types_p, - 
vector column_ids_p, vector projection_ids_p, - vector names_p, unique_ptr table_filters_p, - idx_t estimated_cardinality, ExtraOperatorInfo extra_info) - : PhysicalOperator(PhysicalOperatorType::TABLE_SCAN, std::move(types), estimated_cardinality), - function(std::move(function_p)), bind_data(std::move(bind_data_p)), returned_types(std::move(returned_types_p)), - column_ids(std::move(column_ids_p)), projection_ids(std::move(projection_ids_p)), names(std::move(names_p)), - table_filters(std::move(table_filters_p)), extra_info(extra_info) { -} - -class TableScanGlobalSourceState : public GlobalSourceState { -public: - TableScanGlobalSourceState(ClientContext &context, const PhysicalTableScan &op) { - if (op.function.init_global) { - TableFunctionInitInput input(op.bind_data.get(), op.column_ids, op.projection_ids, op.table_filters.get()); - global_state = op.function.init_global(context, input); - if (global_state) { - max_threads = global_state->MaxThreads(); - } - } else { - max_threads = 1; - } - } - - idx_t max_threads = 0; - unique_ptr global_state; - - idx_t MaxThreads() override { - return max_threads; - } -}; - -class TableScanLocalSourceState : public LocalSourceState { -public: - TableScanLocalSourceState(ExecutionContext &context, TableScanGlobalSourceState &gstate, - const PhysicalTableScan &op) { - if (op.function.init_local) { - TableFunctionInitInput input(op.bind_data.get(), op.column_ids, op.projection_ids, op.table_filters.get()); - local_state = op.function.init_local(context, input, gstate.global_state.get()); - } - } - - unique_ptr local_state; -}; - -unique_ptr PhysicalTableScan::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(context, gstate.Cast(), *this); -} - -unique_ptr PhysicalTableScan::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(context, *this); -} - -SourceResultType PhysicalTableScan::GetData(ExecutionContext &context, DataChunk &chunk, - 
OperatorSourceInput &input) const { - D_ASSERT(!column_ids.empty()); - auto &gstate = input.global_state.Cast(); - auto &state = input.local_state.Cast(); - - TableFunctionInput data(bind_data.get(), state.local_state.get(), gstate.global_state.get()); - function.function(context.client, data, chunk); - - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -double PhysicalTableScan::GetProgress(ClientContext &context, GlobalSourceState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - if (function.table_scan_progress) { - return function.table_scan_progress(context, bind_data.get(), gstate.global_state.get()); - } - // if table_scan_progress is not implemented we don't support this function yet in the progress bar - return -1; -} - -idx_t PhysicalTableScan::GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate_p, - LocalSourceState &lstate) const { - D_ASSERT(SupportsBatchIndex()); - D_ASSERT(function.get_batch_index); - auto &gstate = gstate_p.Cast(); - auto &state = lstate.Cast(); - return function.get_batch_index(context.client, bind_data.get(), state.local_state.get(), - gstate.global_state.get()); -} - -string PhysicalTableScan::GetName() const { - return StringUtil::Upper(function.name + " " + function.extra_info); -} - -string PhysicalTableScan::ParamsToString() const { - string result; - if (function.to_string) { - result = function.to_string(bind_data.get()); - result += "\n[INFOSEPARATOR]\n"; - } - if (function.projection_pushdown) { - if (function.filter_prune) { - for (idx_t i = 0; i < projection_ids.size(); i++) { - const auto &column_id = column_ids[projection_ids[i]]; - if (column_id < names.size()) { - if (i > 0) { - result += "\n"; - } - result += names[column_id]; - } - } - } else { - for (idx_t i = 0; i < column_ids.size(); i++) { - const auto &column_id = column_ids[i]; - if (column_id < names.size()) { - if (i > 0) { - result += "\n"; - } - result += 
names[column_id]; - } - } - } - } - if (function.filter_pushdown && table_filters) { - result += "\n[INFOSEPARATOR]\n"; - result += "Filters: "; - for (auto &f : table_filters->filters) { - auto &column_index = f.first; - auto &filter = f.second; - if (column_index < names.size()) { - result += filter->ToString(names[column_ids[column_index]]); - result += "\n"; - } - } - } - if (!extra_info.file_filters.empty()) { - result += "\n[INFOSEPARATOR]\n"; - result += "File Filters: " + extra_info.file_filters; - } - result += "\n[INFOSEPARATOR]\n"; - result += StringUtil::Format("EC: %llu", estimated_cardinality); - return result; -} - -bool PhysicalTableScan::Equals(const PhysicalOperator &other_p) const { - if (type != other_p.type) { - return false; - } - auto &other = other_p.Cast(); - if (function.function != other.function.function) { - return false; - } - if (column_ids != other.column_ids) { - return false; - } - if (!FunctionData::Equals(bind_data.get(), other.bind_data.get())) { - return false; - } - return true; -} - -} // namespace duckdb - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalAlter::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.Alter(context.client, *info); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalAttach::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - // parse the options - auto &config = DBConfig::GetConfig(context.client); - AccessMode 
access_mode = config.options.access_mode; - string type; - string unrecognized_option; - for (auto &entry : info->options) { - if (entry.first == "readonly" || entry.first == "read_only") { - auto read_only = BooleanValue::Get(entry.second.DefaultCastAs(LogicalType::BOOLEAN)); - if (read_only) { - access_mode = AccessMode::READ_ONLY; - } else { - access_mode = AccessMode::READ_WRITE; - } - } else if (entry.first == "readwrite" || entry.first == "read_write") { - auto read_only = !BooleanValue::Get(entry.second.DefaultCastAs(LogicalType::BOOLEAN)); - if (read_only) { - access_mode = AccessMode::READ_ONLY; - } else { - access_mode = AccessMode::READ_WRITE; - } - } else if (entry.first == "type") { - type = StringValue::Get(entry.second.DefaultCastAs(LogicalType::VARCHAR)); - } else if (unrecognized_option.empty()) { - unrecognized_option = entry.first; - } - } - auto &db = DatabaseInstance::GetDatabase(context.client); - if (type.empty()) { - // try to extract type from path - auto path_and_type = DBPathAndType::Parse(info->path, config); - type = path_and_type.type; - info->path = path_and_type.path; - } - - if (type.empty() && !unrecognized_option.empty()) { - throw BinderException("Unrecognized option for attach \"%s\"", unrecognized_option); - } - - // if we are loading a database type from an extension - check if that extension is loaded - if (!type.empty()) { - if (!Catalog::TryAutoLoad(context.client, type)) { - // FIXME: Here it might be preferrable to use an AutoLoadOrThrow kind of function - // so that either there will be success or a message to throw, and load will be - // attempted only once respecting the autoloading options - ExtensionHelper::LoadExternalExtension(context.client, type); - } - } - - // attach the database - auto &name = info->name; - const auto &path = info->path; - - if (name.empty()) { - auto &fs = FileSystem::GetFileSystem(context.client); - name = AttachedDatabase::ExtractDatabaseName(path, fs); - } - auto &db_manager = 
DatabaseManager::Get(context.client); - auto existing_db = db_manager.GetDatabaseFromPath(context.client, path); - if (existing_db) { - throw BinderException("Database \"%s\" is already attached with alias \"%s\"", path, existing_db->GetName()); - } - auto new_db = db.CreateAttachedDatabase(*info, type, access_mode); - new_db->Initialize(); - - db_manager.AddDatabase(context.client, std::move(new_db)); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -PhysicalCreateARTIndex::PhysicalCreateARTIndex(LogicalOperator &op, TableCatalogEntry &table_p, - const vector &column_ids, unique_ptr info, - vector> unbound_expressions, - idx_t estimated_cardinality, const bool sorted) - : PhysicalOperator(PhysicalOperatorType::CREATE_INDEX, op.types, estimated_cardinality), - table(table_p.Cast()), info(std::move(info)), unbound_expressions(std::move(unbound_expressions)), - sorted(sorted) { - // convert virtual column ids to storage column ids - for (auto &column_id : column_ids) { - storage_ids.push_back(table.GetColumns().LogicalToPhysical(LogicalIndex(column_id)).index); - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// - -class CreateARTIndexGlobalSinkState : public GlobalSinkState { -public: - //! 
Global index to be added to the table - unique_ptr global_index; -}; - -class CreateARTIndexLocalSinkState : public LocalSinkState { -public: - explicit CreateARTIndexLocalSinkState(ClientContext &context) : arena_allocator(Allocator::Get(context)) {}; - - unique_ptr local_index; - ArenaAllocator arena_allocator; - vector keys; - DataChunk key_chunk; - vector key_column_ids; -}; - -unique_ptr PhysicalCreateARTIndex::GetGlobalSinkState(ClientContext &context) const { - auto state = make_uniq(); - - // create the global index - auto &storage = table.GetStorage(); - state->global_index = make_uniq(storage_ids, TableIOManager::Get(storage), unbound_expressions, - info->constraint_type, storage.db); - - return (std::move(state)); -} - -unique_ptr PhysicalCreateARTIndex::GetLocalSinkState(ExecutionContext &context) const { - auto state = make_uniq(context.client); - - // create the local index - - auto &storage = table.GetStorage(); - state->local_index = make_uniq(storage_ids, TableIOManager::Get(storage), unbound_expressions, - info->constraint_type, storage.db); - - state->keys = vector(STANDARD_VECTOR_SIZE); - state->key_chunk.Initialize(Allocator::Get(context.client), state->local_index->logical_types); - - for (idx_t i = 0; i < state->key_chunk.ColumnCount(); i++) { - state->key_column_ids.push_back(i); - } - return std::move(state); -} - -SinkResultType PhysicalCreateARTIndex::SinkUnsorted(Vector &row_identifiers, OperatorSinkInput &input) const { - - auto &l_state = input.local_state.Cast(); - auto count = l_state.key_chunk.size(); - - // get the corresponding row IDs - row_identifiers.Flatten(count); - auto row_ids = FlatVector::GetData(row_identifiers); - - // insert the row IDs - auto &art = l_state.local_index->Cast(); - for (idx_t i = 0; i < count; i++) { - if (!art.Insert(art.tree, l_state.keys[i], 0, row_ids[i])) { - throw ConstraintException("Data contains duplicates on indexed column(s)"); - } - } - - return SinkResultType::NEED_MORE_INPUT; -} - 
-SinkResultType PhysicalCreateARTIndex::SinkSorted(Vector &row_identifiers, OperatorSinkInput &input) const { - - auto &l_state = input.local_state.Cast(); - auto &storage = table.GetStorage(); - auto &l_index = l_state.local_index; - - // create an ART from the chunk - auto art = make_uniq(l_index->column_ids, l_index->table_io_manager, l_index->unbound_expressions, - l_index->constraint_type, storage.db, l_index->Cast().allocators); - if (!art->ConstructFromSorted(l_state.key_chunk.size(), l_state.keys, row_identifiers)) { - throw ConstraintException("Data contains duplicates on indexed column(s)"); - } - - // merge into the local ART - if (!l_index->MergeIndexes(*art)) { - throw ConstraintException("Data contains duplicates on indexed column(s)"); - } - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkResultType PhysicalCreateARTIndex::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - - D_ASSERT(chunk.ColumnCount() >= 2); - - // generate the keys for the given input - auto &l_state = input.local_state.Cast(); - l_state.key_chunk.ReferenceColumns(chunk, l_state.key_column_ids); - l_state.arena_allocator.Reset(); - ART::GenerateKeys(l_state.arena_allocator, l_state.key_chunk, l_state.keys); - - // insert the keys and their corresponding row IDs - auto &row_identifiers = chunk.data[chunk.ColumnCount() - 1]; - if (sorted) { - return SinkSorted(row_identifiers, input); - } - return SinkUnsorted(row_identifiers, input); -} - -SinkCombineResultType PhysicalCreateARTIndex::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - // merge the local index into the global index - if (!gstate.global_index->MergeIndexes(*lstate.local_index)) { - throw ConstraintException("Data contains duplicates on indexed column(s)"); - } - - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType 
PhysicalCreateARTIndex::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - - // here, we set the resulting global index as the newly created index of the table - auto &state = input.global_state.Cast(); - - // vacuum excess memory and verify - state.global_index->Vacuum(); - D_ASSERT(!state.global_index->VerifyAndToString(true).empty()); - - auto &storage = table.GetStorage(); - if (!storage.IsRoot()) { - throw TransactionException("Transaction conflict: cannot add an index to a table that has been altered!"); - } - - auto &schema = table.schema; - auto index_entry = schema.CreateIndex(context, *info, table).get(); - if (!index_entry) { - D_ASSERT(info->on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT); - // index already exists, but error ignored because of IF NOT EXISTS - return SinkFinalizeType::READY; - } - auto &index = index_entry->Cast(); - - index.index = state.global_index.get(); - index.info = storage.info; - for (auto &parsed_expr : info->parsed_expressions) { - index.parsed_expressions.push_back(parsed_expr->Copy()); - } - - // add index to storage - storage.info->indexes.AddIndex(std::move(state.global_index)); - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// - -SourceResultType PhysicalCreateARTIndex::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateFunction::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &catalog = 
Catalog::GetCatalog(context.client, info->catalog); - catalog.CreateFunction(context.client, *info); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateSchema::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - if (catalog.IsSystemCatalog()) { - throw BinderException("Cannot create schema in system catalog"); - } - catalog.CreateSchema(context.client, *info); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateSequence::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.CreateSequence(context.client, *info); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -PhysicalCreateTable::PhysicalCreateTable(LogicalOperator &op, SchemaCatalogEntry &schema, - unique_ptr info, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::CREATE_TABLE, op.types, estimated_cardinality), schema(schema), - info(std::move(info)) { -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateTable::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &catalog = schema.catalog; - 
catalog.CreateTable(catalog.GetCatalogTransaction(context.client), schema, *info); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -PhysicalCreateType::PhysicalCreateType(unique_ptr info_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::CREATE_TYPE, {LogicalType::BIGINT}, estimated_cardinality), - info(std::move(info_p)) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class CreateTypeGlobalState : public GlobalSinkState { -public: - explicit CreateTypeGlobalState(ClientContext &context) : result(LogicalType::VARCHAR) { - } - Vector result; - idx_t size = 0; - idx_t capacity = STANDARD_VECTOR_SIZE; - string_set_t found_strings; -}; - -unique_ptr PhysicalCreateType::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context); -} - -SinkResultType PhysicalCreateType::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - idx_t total_row_count = gstate.size + chunk.size(); - if (total_row_count > NumericLimits::Maximum()) { - throw InvalidInputException("Attempted to create ENUM of size %llu, which exceeds the maximum size of %llu", - total_row_count, NumericLimits::Maximum()); - } - UnifiedVectorFormat sdata; - chunk.data[0].ToUnifiedFormat(chunk.size(), sdata); - - if (total_row_count > gstate.capacity) { - // We must resize our result vector - gstate.result.Resize(gstate.capacity, gstate.capacity * 2); - gstate.capacity *= 2; - } - - auto src_ptr = UnifiedVectorFormat::GetData(sdata); - auto result_ptr = FlatVector::GetData(gstate.result); - // Input vector has NULL value, we just throw an exception - for (idx_t i = 0; i < chunk.size(); i++) { - idx_t idx = sdata.sel->get_index(i); - if (!sdata.validity.RowIsValid(idx)) { - throw InvalidInputException("Attempted to 
create ENUM type with NULL value!"); - } - auto str = src_ptr[idx]; - auto entry = gstate.found_strings.find(src_ptr[idx]); - if (entry != gstate.found_strings.end()) { - // entry was already found - skip - continue; - } - auto owned_string = StringVector::AddStringOrBlob(gstate.result, str.GetData(), str.GetSize()); - gstate.found_strings.insert(owned_string); - result_ptr[gstate.size++] = owned_string; - } - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateType::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - if (IsSink()) { - D_ASSERT(info->type == LogicalType::INVALID); - auto &g_sink_state = sink_state->Cast(); - info->type = LogicalType::ENUM(g_sink_state.result, g_sink_state.size); - } - - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.CreateType(context.client, *info); - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateView::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.CreateView(context.client, *info); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalDetach::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &db_manager = 
DatabaseManager::Get(context.client); - db_manager.DetachDatabase(context.client, info->name, info->if_not_found); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalDrop::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - switch (info->type) { - case CatalogType::PREPARED_STATEMENT: { - // DEALLOCATE silently ignores errors - auto &statements = ClientData::Get(context.client).prepared_statements; - if (statements.find(info->name) != statements.end()) { - statements.erase(info->name); - } - break; - } - case CatalogType::SCHEMA_ENTRY: { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.DropEntry(context.client, *info); - auto qualified_name = QualifiedName::Parse(info->name); - - // Check if the dropped schema was set as the current schema - auto &client_data = ClientData::Get(context.client); - auto &default_entry = client_data.catalog_search_path->GetDefault(); - auto ¤t_catalog = default_entry.catalog; - auto ¤t_schema = default_entry.schema; - D_ASSERT(info->name != DEFAULT_SCHEMA); - - if (info->catalog == current_catalog && current_schema == info->name) { - // Reset the schema to default - SchemaSetting::SetLocal(context.client, DEFAULT_SCHEMA); - } - break; - } - default: { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.DropEntry(context.client, *info); - break; - } - } - - return SourceResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -PhysicalCTE::PhysicalCTE(string ctename, idx_t table_index, vector types, unique_ptr top, - unique_ptr bottom, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::CTE, std::move(types), estimated_cardinality), 
table_index(table_index), - ctename(std::move(ctename)) { - children.push_back(std::move(top)); - children.push_back(std::move(bottom)); -} - -PhysicalCTE::~PhysicalCTE() { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class CTEState : public GlobalSinkState { -public: - explicit CTEState(ClientContext &context, const PhysicalCTE &op) - : intermediate_table(context, op.children[1]->GetTypes()) { - } - ColumnDataCollection intermediate_table; - ColumnDataScanState scan_state; - bool initialized = false; - bool finished_scan = false; -}; - -unique_ptr PhysicalCTE::GetGlobalSinkState(ClientContext &context) const { - working_table->Reset(); - return make_uniq(context, *this); -} - -SinkResultType PhysicalCTE::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - if (!gstate.finished_scan) { - working_table->Append(chunk); - } else { - gstate.intermediate_table.Append(chunk); - } - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCTE::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - auto &gstate = sink_state->Cast(); - if (!gstate.initialized) { - gstate.intermediate_table.InitializeScan(gstate.scan_state); - gstate.finished_scan = false; - gstate.initialized = true; - } - if (!gstate.finished_scan) { - gstate.finished_scan = true; - ExecuteRecursivePipelines(context); - } - - gstate.intermediate_table.Scan(gstate.scan_state, chunk); - - return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -void PhysicalCTE::ExecuteRecursivePipelines(ExecutionContext &context) const { - if (!recursive_meta_pipeline) { - throw InternalException("Missing meta pipeline for recursive CTE"); - } - - // get and reset pipelines - vector> pipelines; - recursive_meta_pipeline->GetPipelines(pipelines, true); - for (auto &pipeline : pipelines) { - auto sink = pipeline->GetSink(); - if (sink.get() != this) { - sink->sink_state.reset(); - } - for (auto &op_ref : pipeline->GetOperators()) { - auto &op = op_ref.get(); - op.op_state.reset(); - } - pipeline->ClearSource(); - } - - // get the MetaPipelines in the recursive_meta_pipeline and reschedule them - vector> meta_pipelines; - recursive_meta_pipeline->GetMetaPipelines(meta_pipelines, true, false); - auto &executor = recursive_meta_pipeline->GetExecutor(); - vector> events; - executor.ReschedulePipelines(meta_pipelines, events); - - while (true) { - executor.WorkOnTasks(); - if (executor.HasError()) { - executor.ThrowException(); - } - bool finished = true; - for (auto &event : events) { - if (!event->IsFinished()) { - finished = false; - break; - } - } - if (finished) { - // all pipelines finished: done! 
- break; - } - } -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalCTE::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - D_ASSERT(children.size() == 2); - op_state.reset(); - sink_state.reset(); - recursive_meta_pipeline.reset(); - - auto &state = meta_pipeline.GetState(); - state.SetPipelineSource(current, *this); - - auto &executor = meta_pipeline.GetExecutor(); - executor.AddMaterializedCTE(*this); - - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - child_meta_pipeline.Build(*children[0]); - - // the RHS is the recursive pipeline - recursive_meta_pipeline = make_shared(executor, state, this); - if (meta_pipeline.HasRecursiveCTE()) { - recursive_meta_pipeline->SetRecursiveCTE(); - } - recursive_meta_pipeline->Build(*children[1]); -} - -vector> PhysicalCTE::GetSources() const { - return {*this}; -} - -string PhysicalCTE::ParamsToString() const { - string result = ""; - result += "\n[INFOSEPARATOR]\n"; - result += ctename; - result += "\n[INFOSEPARATOR]\n"; - result += StringUtil::Format("idx: %llu", table_index); - return result; -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -PhysicalRecursiveCTE::PhysicalRecursiveCTE(string ctename, idx_t table_index, vector types, bool union_all, - unique_ptr top, unique_ptr bottom, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::RECURSIVE_CTE, std::move(types), estimated_cardinality), - ctename(std::move(ctename)), table_index(table_index), union_all(union_all) { - children.push_back(std::move(top)); - children.push_back(std::move(bottom)); -} - -PhysicalRecursiveCTE::~PhysicalRecursiveCTE() { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class 
RecursiveCTEState : public GlobalSinkState { -public: - explicit RecursiveCTEState(ClientContext &context, const PhysicalRecursiveCTE &op) - : intermediate_table(context, op.GetTypes()), new_groups(STANDARD_VECTOR_SIZE) { - ht = make_uniq(context, BufferAllocator::Get(context), op.types, - vector(), vector()); - } - - unique_ptr ht; - - bool intermediate_empty = true; - ColumnDataCollection intermediate_table; - ColumnDataScanState scan_state; - bool initialized = false; - bool finished_scan = false; - SelectionVector new_groups; -}; - -unique_ptr PhysicalRecursiveCTE::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -idx_t PhysicalRecursiveCTE::ProbeHT(DataChunk &chunk, RecursiveCTEState &state) const { - Vector dummy_addresses(LogicalType::POINTER); - - // Use the HT to eliminate duplicate rows - idx_t new_group_count = state.ht->FindOrCreateGroups(chunk, dummy_addresses, state.new_groups); - - // we only return entries we have not seen before (i.e. 
new groups) - chunk.Slice(state.new_groups, new_group_count); - - return new_group_count; -} - -SinkResultType PhysicalRecursiveCTE::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - if (!union_all) { - idx_t match_count = ProbeHT(chunk, gstate); - if (match_count > 0) { - gstate.intermediate_table.Append(chunk); - } - } else { - gstate.intermediate_table.Append(chunk); - } - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalRecursiveCTE::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gstate = sink_state->Cast(); - if (!gstate.initialized) { - gstate.intermediate_table.InitializeScan(gstate.scan_state); - gstate.finished_scan = false; - gstate.initialized = true; - } - while (chunk.size() == 0) { - if (!gstate.finished_scan) { - // scan any chunks we have collected so far - gstate.intermediate_table.Scan(gstate.scan_state, chunk); - if (chunk.size() == 0) { - gstate.finished_scan = true; - } else { - break; - } - } else { - // we have run out of chunks - // now we need to recurse - // we set up the working table as the data we gathered in this iteration of the recursion - working_table->Reset(); - working_table->Combine(gstate.intermediate_table); - // and we clear the intermediate table - gstate.finished_scan = false; - gstate.intermediate_table.Reset(); - // now we need to re-execute all of the pipelines that depend on the recursion - ExecuteRecursivePipelines(context); - - // check if we obtained any results - // if not, we are done - if (gstate.intermediate_table.Count() == 0) { - gstate.finished_scan = true; - break; - } - // set up the scan again - gstate.intermediate_table.InitializeScan(gstate.scan_state); - } - } - - return chunk.size() 
== 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -void PhysicalRecursiveCTE::ExecuteRecursivePipelines(ExecutionContext &context) const { - if (!recursive_meta_pipeline) { - throw InternalException("Missing meta pipeline for recursive CTE"); - } - D_ASSERT(recursive_meta_pipeline->HasRecursiveCTE()); - - // get and reset pipelines - vector> pipelines; - recursive_meta_pipeline->GetPipelines(pipelines, true); - for (auto &pipeline : pipelines) { - auto sink = pipeline->GetSink(); - if (sink.get() != this) { - sink->sink_state.reset(); - } - for (auto &op_ref : pipeline->GetOperators()) { - auto &op = op_ref.get(); - op.op_state.reset(); - } - pipeline->ClearSource(); - } - - // get the MetaPipelines in the recursive_meta_pipeline and reschedule them - vector> meta_pipelines; - recursive_meta_pipeline->GetMetaPipelines(meta_pipelines, true, false); - auto &executor = recursive_meta_pipeline->GetExecutor(); - vector> events; - executor.ReschedulePipelines(meta_pipelines, events); - - while (true) { - executor.WorkOnTasks(); - if (executor.HasError()) { - executor.ThrowException(); - } - bool finished = true; - for (auto &event : events) { - if (!event->IsFinished()) { - finished = false; - break; - } - } - if (finished) { - // all pipelines finished: done! 
- break; - } - } -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalRecursiveCTE::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - op_state.reset(); - sink_state.reset(); - recursive_meta_pipeline.reset(); - - auto &state = meta_pipeline.GetState(); - state.SetPipelineSource(current, *this); - - auto &executor = meta_pipeline.GetExecutor(); - executor.AddRecursiveCTE(*this); - - // the LHS of the recursive CTE is our initial state - auto &initial_state_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - initial_state_pipeline.Build(*children[0]); - - // the RHS is the recursive pipeline - recursive_meta_pipeline = make_shared(executor, state, this); - recursive_meta_pipeline->SetRecursiveCTE(); - recursive_meta_pipeline->Build(*children[1]); -} - -vector> PhysicalRecursiveCTE::GetSources() const { - return {*this}; -} - -string PhysicalRecursiveCTE::ParamsToString() const { - string result = ""; - result += "\n[INFOSEPARATOR]\n"; - result += ctename; - result += "\n[INFOSEPARATOR]\n"; - result += StringUtil::Format("idx: %llu", table_index); - return result; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -PhysicalUnion::PhysicalUnion(vector types, unique_ptr top, - unique_ptr bottom, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::UNION, std::move(types), estimated_cardinality) { - children.push_back(std::move(top)); - children.push_back(std::move(bottom)); -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalUnion::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - op_state.reset(); - sink_state.reset(); - - // order matters if any of the downstream operators are order dependent, - // or 
if the sink preserves order, but does not support batch indices to do so - auto sink = meta_pipeline.GetSink(); - bool order_matters = false; - if (current.IsOrderDependent()) { - order_matters = true; - } - if (sink) { - if (sink->SinkOrderDependent() || sink->RequiresBatchIndex()) { - order_matters = true; - } - if (!sink->ParallelSink()) { - order_matters = true; - } - } - - // create a union pipeline that is identical to 'current' - auto union_pipeline = meta_pipeline.CreateUnionPipeline(current, order_matters); - - // continue with the current pipeline - children[0]->BuildPipelines(current, meta_pipeline); - - if (order_matters) { - // order matters, so 'union_pipeline' must come after all pipelines created by building out 'current' - meta_pipeline.AddDependenciesFrom(union_pipeline, union_pipeline, false); - } - - // build the union pipeline - children[1]->BuildPipelines(*union_pipeline, meta_pipeline); - - // Assign proper batch index to the union pipeline - // This needs to happen after the pipelines have been built because unions can be nested - meta_pipeline.AssignNextBatchIndex(union_pipeline); -} - -vector> PhysicalUnion::GetSources() const { - vector> result; - for (auto &child : children) { - auto child_sources = child->GetSources(); - result.insert(result.end(), child_sources.begin(), child_sources.end()); - } - return result; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -PerfectAggregateHashTable::PerfectAggregateHashTable(ClientContext &context, Allocator &allocator, - const vector &group_types_p, - vector payload_types_p, - vector aggregate_objects_p, - vector group_minima_p, vector required_bits_p) - : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)), - addresses(LogicalType::POINTER), required_bits(std::move(required_bits_p)), total_required_bits(0), - group_minima(std::move(group_minima_p)), sel(STANDARD_VECTOR_SIZE), - aggregate_allocator(make_uniq(allocator)) { - for (auto &group_bits 
: required_bits) { - total_required_bits += group_bits; - } - // the total amount of groups we allocate space for is 2^required_bits - total_groups = (uint64_t)1 << total_required_bits; - // we don't need to store the groups in a perfect hash table, since the group keys can be deduced by their location - grouping_columns = group_types_p.size(); - layout.Initialize(std::move(aggregate_objects_p)); - tuple_size = layout.GetRowWidth(); - - // allocate and null initialize the data - owned_data = make_unsafe_uniq_array(tuple_size * total_groups); - data = owned_data.get(); - - // set up the empty payloads for every tuple, and initialize the "occupied" flag to false - group_is_set = make_unsafe_uniq_array(total_groups); - memset(group_is_set.get(), 0, total_groups * sizeof(bool)); - - // initialize the hash table for each entry - auto address_data = FlatVector::GetData(addresses); - idx_t init_count = 0; - for (idx_t i = 0; i < total_groups; i++) { - address_data[init_count] = uintptr_t(data) + (tuple_size * i); - init_count++; - if (init_count == STANDARD_VECTOR_SIZE) { - RowOperations::InitializeStates(layout, addresses, *FlatVector::IncrementalSelectionVector(), init_count); - init_count = 0; - } - } - RowOperations::InitializeStates(layout, addresses, *FlatVector::IncrementalSelectionVector(), init_count); -} - -PerfectAggregateHashTable::~PerfectAggregateHashTable() { - Destroy(); -} - -template -static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value &min, uintptr_t *address_data, - idx_t current_shift, idx_t count) { - auto data = UnifiedVectorFormat::GetData(group_data); - auto min_val = min.GetValueUnsafe(); - if (!group_data.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto index = group_data.sel->get_index(i); - // check if the value is NULL - // NULL groups are considered as "0" in the hash table - // that is to say, they have no effect on the position of the element (because 0 << shift is 0) - // we only need to 
handle non-null values here - if (group_data.validity.RowIsValid(index)) { - D_ASSERT(data[index] >= min_val); - uintptr_t adjusted_value = (data[index] - min_val) + 1; - address_data[i] += adjusted_value << current_shift; - } - } - } else { - // no null values: we can directly compute the addresses - for (idx_t i = 0; i < count; i++) { - auto index = group_data.sel->get_index(i); - uintptr_t adjusted_value = (data[index] - min_val) + 1; - address_data[i] += adjusted_value << current_shift; - } - } -} - -static void ComputeGroupLocation(Vector &group, Value &min, uintptr_t *address_data, idx_t current_shift, idx_t count) { - UnifiedVectorFormat vdata; - group.ToUnifiedFormat(count, vdata); - - switch (group.GetType().InternalType()) { - case PhysicalType::INT8: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::INT16: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::INT32: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::INT64: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::UINT8: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::UINT16: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::UINT32: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::UINT64: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - default: - throw InternalException("Unsupported group type for perfect aggregate hash table"); - } -} - -void PerfectAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload) { - // first we need to find the location in the HT of each of the groups - auto address_data = 
FlatVector::GetData(addresses); - // zero-initialize the address data - memset(address_data, 0, groups.size() * sizeof(uintptr_t)); - D_ASSERT(groups.ColumnCount() == group_minima.size()); - - // then compute the actual group location by iterating over each of the groups - idx_t current_shift = total_required_bits; - for (idx_t i = 0; i < groups.ColumnCount(); i++) { - current_shift -= required_bits[i]; - ComputeGroupLocation(groups.data[i], group_minima[i], address_data, current_shift, groups.size()); - } - // now we have the HT entry number for every tuple - // compute the actual pointer to the data by adding it to the base HT pointer and multiplying by the tuple size - for (idx_t i = 0; i < groups.size(); i++) { - const auto group = address_data[i]; - D_ASSERT(group < total_groups); - group_is_set[group] = true; - address_data[i] = uintptr_t(data) + group * tuple_size; - } - - // after finding the group location we update the aggregates - idx_t payload_idx = 0; - auto &aggregates = layout.GetAggregates(); - RowOperationsState row_state(*aggregate_allocator); - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - auto &aggregate = aggregates[aggr_idx]; - auto input_count = (idx_t)aggregate.child_count; - if (aggregate.filter) { - RowOperations::UpdateFilteredStates(row_state, filter_set.GetFilterData(aggr_idx), aggregate, addresses, - payload, payload_idx); - } else { - RowOperations::UpdateStates(row_state, aggregate, addresses, payload, payload_idx, payload.size()); - } - // move to the next aggregate - payload_idx += input_count; - VectorOperations::AddInPlace(addresses, aggregate.payload_size, payload.size()); - } -} - -void PerfectAggregateHashTable::Combine(PerfectAggregateHashTable &other) { - D_ASSERT(total_groups == other.total_groups); - D_ASSERT(tuple_size == other.tuple_size); - - Vector source_addresses(LogicalType::POINTER); - Vector target_addresses(LogicalType::POINTER); - auto source_addresses_ptr = 
FlatVector::GetData(source_addresses); - auto target_addresses_ptr = FlatVector::GetData(target_addresses); - - // iterate over all entries of both hash tables and call combine for all entries that can be combined - data_ptr_t source_ptr = other.data; - data_ptr_t target_ptr = data; - idx_t combine_count = 0; - RowOperationsState row_state(*aggregate_allocator); - for (idx_t i = 0; i < total_groups; i++) { - auto has_entry_source = other.group_is_set[i]; - // we only have any work to do if the source has an entry for this group - if (has_entry_source) { - group_is_set[i] = true; - source_addresses_ptr[combine_count] = source_ptr; - target_addresses_ptr[combine_count] = target_ptr; - combine_count++; - if (combine_count == STANDARD_VECTOR_SIZE) { - RowOperations::CombineStates(row_state, layout, source_addresses, target_addresses, combine_count); - combine_count = 0; - } - } - source_ptr += tuple_size; - target_ptr += tuple_size; - } - RowOperations::CombineStates(row_state, layout, source_addresses, target_addresses, combine_count); - - // FIXME: after moving the arena allocator, we currently have to ensure that the pointer is not nullptr, because the - // FIXME: Destroy()-function of the hash table expects an allocator in some cases (e.g., for sorted aggregates) - stored_allocators.push_back(std::move(other.aggregate_allocator)); - other.aggregate_allocator = make_uniq(allocator); -} - -template -static void ReconstructGroupVectorTemplated(uint32_t group_values[], Value &min, idx_t mask, idx_t shift, - idx_t entry_count, Vector &result) { - auto data = FlatVector::GetData(result); - auto &validity_mask = FlatVector::Validity(result); - auto min_data = min.GetValueUnsafe(); - for (idx_t i = 0; i < entry_count; i++) { - // extract the value of this group from the total group index - auto group_index = (group_values[i] >> shift) & mask; - if (group_index == 0) { - // if it is 0, the value is NULL - validity_mask.SetInvalid(i); - } else { - // otherwise we add the 
value (minus 1) to the min value - data[i] = min_data + group_index - 1; - } - } -} - -static void ReconstructGroupVector(uint32_t group_values[], Value &min, idx_t required_bits, idx_t shift, - idx_t entry_count, Vector &result) { - // construct the mask for this entry - idx_t mask = ((uint64_t)1 << required_bits) - 1; - switch (result.GetType().InternalType()) { - case PhysicalType::INT8: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::INT16: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::INT32: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::INT64: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::UINT8: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::UINT16: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::UINT32: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::UINT64: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - default: - throw InternalException("Invalid type for perfect aggregate HT group"); - } -} - -void PerfectAggregateHashTable::Scan(idx_t &scan_position, DataChunk &result) { - auto data_pointers = FlatVector::GetData(addresses); - uint32_t group_values[STANDARD_VECTOR_SIZE]; - - // iterate over the HT until we either have exhausted the entire HT, or - idx_t entry_count = 0; - for (; scan_position < total_groups; scan_position++) { - if (group_is_set[scan_position]) { - // this group is set: add it to the set of groups to extract - data_pointers[entry_count] = data + tuple_size * scan_position; - group_values[entry_count] 
= scan_position; - entry_count++; - if (entry_count == STANDARD_VECTOR_SIZE) { - scan_position++; - break; - } - } - } - if (entry_count == 0) { - // no entries found - return; - } - // first reconstruct the groups from the group index - idx_t shift = total_required_bits; - for (idx_t i = 0; i < grouping_columns; i++) { - shift -= required_bits[i]; - ReconstructGroupVector(group_values, group_minima[i], required_bits[i], shift, entry_count, result.data[i]); - } - // then construct the payloads - result.SetCardinality(entry_count); - RowOperationsState row_state(*aggregate_allocator); - RowOperations::FinalizeStates(row_state, layout, addresses, result, grouping_columns); -} - -void PerfectAggregateHashTable::Destroy() { - // check if there is any destructor to call - bool has_destructor = false; - for (auto &aggr : layout.GetAggregates()) { - if (aggr.function.destructor) { - has_destructor = true; - } - } - if (!has_destructor) { - return; - } - // there are aggregates with destructors: loop over the hash table - // and call the destructor method for each of the aggregates - auto data_pointers = FlatVector::GetData(addresses); - idx_t count = 0; - - // iterate over all initialised slots of the hash table - RowOperationsState row_state(*aggregate_allocator); - data_ptr_t payload_ptr = data; - for (idx_t i = 0; i < total_groups; i++) { - if (group_is_set[i]) { - data_pointers[count++] = payload_ptr; - if (count == STANDARD_VECTOR_SIZE) { - RowOperations::DestroyStates(row_state, layout, addresses, count); - count = 0; - } - } - payload_ptr += tuple_size; - } - RowOperations::DestroyStates(row_state, layout, addresses, count); -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -string PhysicalOperator::GetName() const { - return PhysicalOperatorToString(type); -} - -string PhysicalOperator::ToString() const { - TreeRenderer renderer; - return renderer.ToString(*this); -} - -// LCOV_EXCL_START -void PhysicalOperator::Print() const { - 
Printer::Print(ToString()); -} -// LCOV_EXCL_STOP - -vector> PhysicalOperator::GetChildren() const { - vector> result; - for (auto &child : children) { - result.push_back(*child); - } - return result; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -// LCOV_EXCL_START -unique_ptr PhysicalOperator::GetOperatorState(ExecutionContext &context) const { - return make_uniq(); -} - -unique_ptr PhysicalOperator::GetGlobalOperatorState(ClientContext &context) const { - return make_uniq(); -} - -OperatorResultType PhysicalOperator::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state) const { - throw InternalException("Calling Execute on a node that is not an operator!"); -} - -OperatorFinalizeResultType PhysicalOperator::FinalExecute(ExecutionContext &context, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state) const { - throw InternalException("Calling FinalExecute on a node that is not an operator!"); -} -// LCOV_EXCL_STOP - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -unique_ptr PhysicalOperator::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(); -} - -unique_ptr PhysicalOperator::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(); -} - -// LCOV_EXCL_START -SourceResultType PhysicalOperator::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - throw InternalException("Calling GetData on a node that is not a source!"); -} - -idx_t PhysicalOperator::GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate, - LocalSourceState &lstate) const { - throw InternalException("Calling GetBatchIndex 
on a node that does not support it"); -} - -double PhysicalOperator::GetProgress(ClientContext &context, GlobalSourceState &gstate) const { - return -1; -} -// LCOV_EXCL_STOP - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -// LCOV_EXCL_START -SinkResultType PhysicalOperator::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - throw InternalException("Calling Sink on a node that is not a sink!"); -} - -// LCOV_EXCL_STOP - -SinkCombineResultType PhysicalOperator::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalOperator::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - return SinkFinalizeType::READY; -} - -void PhysicalOperator::NextBatch(ExecutionContext &context, GlobalSinkState &state, LocalSinkState &lstate_p) const { -} - -unique_ptr PhysicalOperator::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(); -} - -unique_ptr PhysicalOperator::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(); -} - -idx_t PhysicalOperator::GetMaxThreadMemory(ClientContext &context) { - // Memory usage per thread should scale with max mem / num threads - // We take 1/4th of this, to be conservative - idx_t max_memory = BufferManager::GetBufferManager(context).GetMaxMemory(); - idx_t num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - return (max_memory / num_threads) / 4; -} - -bool PhysicalOperator::OperatorCachingAllowed(ExecutionContext &context) { - if (!context.client.config.enable_caching_operators) { - return false; - } else if (!context.pipeline) { - return false; - } else if (!context.pipeline->GetSink()) { - return false; - } else if (context.pipeline->GetSink()->RequiresBatchIndex()) { - return 
false; - } else if (context.pipeline->IsOrderDependent()) { - return false; - } - - return true; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalOperator::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - op_state.reset(); - - auto &state = meta_pipeline.GetState(); - if (IsSink()) { - // operator is a sink, build a pipeline - sink_state.reset(); - D_ASSERT(children.size() == 1); - - // single operator: the operator becomes the data source of the current pipeline - state.SetPipelineSource(current, *this); - - // we create a new pipeline starting from the child - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - child_meta_pipeline.Build(*children[0]); - } else { - // operator is not a sink! recurse in children - if (children.empty()) { - // source - state.SetPipelineSource(current, *this); - } else { - if (children.size() != 1) { - throw InternalException("Operator not supported in BuildPipelines"); - } - state.AddPipelineOperator(current, *this); - children[0]->BuildPipelines(current, meta_pipeline); - } - } -} - -vector> PhysicalOperator::GetSources() const { - vector> result; - if (IsSink()) { - D_ASSERT(children.size() == 1); - result.push_back(*this); - return result; - } else { - if (children.empty()) { - // source - result.push_back(*this); - return result; - } else { - if (children.size() != 1) { - throw InternalException("Operator not supported in GetSource"); - } - return children[0]->GetSources(); - } - } -} - -bool PhysicalOperator::AllSourcesSupportBatchIndex() const { - auto sources = GetSources(); - for (auto &source : sources) { - if (!source.get().SupportsBatchIndex()) { - return false; - } - } - return true; -} - -void PhysicalOperator::Verify() { -#ifdef DEBUG - auto sources = GetSources(); - D_ASSERT(!sources.empty()); - for (auto &child : 
children) { - child->Verify(); - } -#endif -} - -bool CachingPhysicalOperator::CanCacheType(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::LIST: - case LogicalTypeId::MAP: - return false; - case LogicalTypeId::STRUCT: { - auto &entries = StructType::GetChildTypes(type); - for (auto &entry : entries) { - if (!CanCacheType(entry.second)) { - return false; - } - } - return true; - } - default: - return true; - } -} - -CachingPhysicalOperator::CachingPhysicalOperator(PhysicalOperatorType type, vector types_p, - idx_t estimated_cardinality) - : PhysicalOperator(type, std::move(types_p), estimated_cardinality) { - - caching_supported = true; - for (auto &col_type : types) { - if (!CanCacheType(col_type)) { - caching_supported = false; - break; - } - } -} - -OperatorResultType CachingPhysicalOperator::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast(); - - // Execute child operator - auto child_result = ExecuteInternal(context, input, chunk, gstate, state); - -#if STANDARD_VECTOR_SIZE >= 128 - if (!state.initialized) { - state.initialized = true; - state.can_cache_chunk = caching_supported && PhysicalOperator::OperatorCachingAllowed(context); - } - if (!state.can_cache_chunk) { - return child_result; - } - // TODO chunk size of 0 should not result in a cache being created! 
- if (chunk.size() < CACHE_THRESHOLD) { - // we have filtered out a significant amount of tuples - // add this chunk to the cache and continue - - if (!state.cached_chunk) { - state.cached_chunk = make_uniq(); - state.cached_chunk->Initialize(Allocator::Get(context.client), chunk.GetTypes()); - } - - state.cached_chunk->Append(chunk); - - if (state.cached_chunk->size() >= (STANDARD_VECTOR_SIZE - CACHE_THRESHOLD) || - child_result == OperatorResultType::FINISHED) { - // chunk cache full: return it - chunk.Move(*state.cached_chunk); - state.cached_chunk->Initialize(Allocator::Get(context.client), chunk.GetTypes()); - return child_result; - } else { - // chunk cache not full return empty result - chunk.Reset(); - } - } -#endif - - return child_result; -} - -OperatorFinalizeResultType CachingPhysicalOperator::FinalExecute(ExecutionContext &context, DataChunk &chunk, - GlobalOperatorState &gstate, - OperatorState &state_p) const { - auto &state = state_p.Cast(); - if (state.cached_chunk) { - chunk.Move(*state.cached_chunk); - state.cached_chunk.reset(); - } else { - chunk.SetCardinality(0); - } - return OperatorFinalizeResultType::FINISHED; -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -static uint32_t RequiredBitsForValue(uint32_t n) { - idx_t required_bits = 0; - while (n > 0) { - n >>= 1; - required_bits++; - } - return required_bits; -} - -template -hugeint_t GetRangeHugeint(const BaseStatistics &nstats) { - return Hugeint::Convert(NumericStats::GetMax(nstats)) - Hugeint::Convert(NumericStats::GetMin(nstats)); -} - -static bool CanUsePerfectHashAggregate(ClientContext &context, LogicalAggregate &op, vector &bits_per_group) { - if (op.grouping_sets.size() > 1 || !op.grouping_functions.empty()) { - return false; - } - idx_t perfect_hash_bits = 0; - if (op.group_stats.empty()) { - op.group_stats.resize(op.groups.size()); - } - for (idx_t group_idx = 0; group_idx < op.groups.size(); group_idx++) { - auto &group = op.groups[group_idx]; - 
auto &stats = op.group_stats[group_idx]; - - switch (group->return_type.InternalType()) { - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - case PhysicalType::UINT8: - case PhysicalType::UINT16: - case PhysicalType::UINT32: - case PhysicalType::UINT64: - break; - default: - // we only support simple integer types for perfect hashing - return false; - } - // check if the group has stats available - auto &group_type = group->return_type; - if (!stats) { - // no stats, but we might still be able to use perfect hashing if the type is small enough - // for small types we can just set the stats to [type_min, type_max] - switch (group_type.InternalType()) { - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::UINT8: - case PhysicalType::UINT16: - break; - default: - // type is too large and there are no stats: skip perfect hashing - return false; - } - // construct stats with the min and max value of the type - stats = NumericStats::CreateUnknown(group_type).ToUnique(); - NumericStats::SetMin(*stats, Value::MinimumValue(group_type)); - NumericStats::SetMax(*stats, Value::MaximumValue(group_type)); - } - auto &nstats = *stats; - - if (!NumericStats::HasMinMax(nstats)) { - return false; - } - - if (NumericStats::Max(*stats) < NumericStats::Min(*stats)) { - // May result in underflow - return false; - } - - // we have a min and a max value for the stats: use that to figure out how many bits we have - // we add two here, one for the NULL value, and one to make the computation one-indexed - // (e.g. 
if min and max are the same, we still need one entry in total) - hugeint_t range_h; - switch (group_type.InternalType()) { - case PhysicalType::INT8: - range_h = GetRangeHugeint(nstats); - break; - case PhysicalType::INT16: - range_h = GetRangeHugeint(nstats); - break; - case PhysicalType::INT32: - range_h = GetRangeHugeint(nstats); - break; - case PhysicalType::INT64: - range_h = GetRangeHugeint(nstats); - break; - case PhysicalType::UINT8: - range_h = GetRangeHugeint(nstats); - break; - case PhysicalType::UINT16: - range_h = GetRangeHugeint(nstats); - break; - case PhysicalType::UINT32: - range_h = GetRangeHugeint(nstats); - break; - case PhysicalType::UINT64: - range_h = GetRangeHugeint(nstats); - break; - default: - throw InternalException("Unsupported type for perfect hash (should be caught before)"); - } - - uint64_t range; - if (!Hugeint::TryCast(range_h, range)) { - return false; - } - - // bail out on any range bigger than 2^32 - if (range >= NumericLimits::Maximum()) { - return false; - } - - range += 2; - // figure out how many bits we need - idx_t required_bits = RequiredBitsForValue(range); - bits_per_group.push_back(required_bits); - perfect_hash_bits += required_bits; - // check if we have exceeded the bits for the hash - if (perfect_hash_bits > ClientConfig::GetConfig(context).perfect_ht_threshold) { - // too many bits for perfect hash - return false; - } - } - for (auto &expression : op.expressions) { - auto &aggregate = expression->Cast(); - if (aggregate.IsDistinct() || !aggregate.function.combine) { - // distinct aggregates are not supported in perfect hash aggregates - return false; - } - } - return true; -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalAggregate &op) { - unique_ptr groupby; - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); - - plan = ExtractAggregateExpressions(std::move(plan), op.expressions, op.groups); - - if (op.groups.empty() && op.grouping_sets.size() <= 1) { - // no groups, check 
if we can use a simple aggregation - // special case: aggregate entire columns together - bool use_simple_aggregation = true; - for (auto &expression : op.expressions) { - auto &aggregate = expression->Cast(); - if (!aggregate.function.simple_update) { - // unsupported aggregate for simple aggregation: use hash aggregation - use_simple_aggregation = false; - break; - } - } - if (use_simple_aggregation) { - groupby = make_uniq_base(op.types, std::move(op.expressions), - op.estimated_cardinality); - } else { - groupby = make_uniq_base( - context, op.types, std::move(op.expressions), op.estimated_cardinality); - } - } else { - // groups! create a GROUP BY aggregator - // use a perfect hash aggregate if possible - vector required_bits; - if (CanUsePerfectHashAggregate(context, op, required_bits)) { - groupby = make_uniq_base( - context, op.types, std::move(op.expressions), std::move(op.groups), std::move(op.group_stats), - std::move(required_bits), op.estimated_cardinality); - } else { - groupby = make_uniq_base( - context, op.types, std::move(op.expressions), std::move(op.groups), std::move(op.grouping_sets), - std::move(op.grouping_functions), op.estimated_cardinality); - } - } - groupby->children.push_back(std::move(plan)); - return groupby; -} - -unique_ptr -PhysicalPlanGenerator::ExtractAggregateExpressions(unique_ptr child, - vector> &aggregates, - vector> &groups) { - vector> expressions; - vector types; - - // bind sorted aggregates - for (auto &aggr : aggregates) { - auto &bound_aggr = aggr->Cast(); - if (bound_aggr.order_bys) { - // sorted aggregate! 
- FunctionBinder::BindSortedAggregate(context, bound_aggr, groups); - } - } - for (auto &group : groups) { - auto ref = make_uniq(group->return_type, expressions.size()); - types.push_back(group->return_type); - expressions.push_back(std::move(group)); - group = std::move(ref); - } - for (auto &aggr : aggregates) { - auto &bound_aggr = aggr->Cast(); - for (auto &child : bound_aggr.children) { - auto ref = make_uniq(child->return_type, expressions.size()); - types.push_back(child->return_type); - expressions.push_back(std::move(child)); - child = std::move(ref); - } - if (bound_aggr.filter) { - auto &filter = bound_aggr.filter; - auto ref = make_uniq(filter->return_type, expressions.size()); - types.push_back(filter->return_type); - expressions.push_back(std::move(filter)); - bound_aggr.filter = std::move(ref); - } - } - if (expressions.empty()) { - return child; - } - auto projection = - make_uniq(std::move(types), std::move(expressions), child->estimated_cardinality); - projection->children.push_back(std::move(child)); - return std::move(projection); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalAnyJoin &op) { - // first visit the child nodes - D_ASSERT(op.children.size() == 2); - D_ASSERT(op.condition); - - auto left = CreatePlan(*op.children[0]); - auto right = CreatePlan(*op.children[1]); - - // create the blockwise NL join - return make_uniq(op, std::move(left), std::move(right), std::move(op.condition), - op.join_type, op.estimated_cardinality); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::PlanAsOfJoin(LogicalComparisonJoin &op) { - // now visit the children - D_ASSERT(op.children.size() == 2); - idx_t lhs_cardinality = op.children[0]->EstimateCardinality(context); - idx_t rhs_cardinality = op.children[1]->EstimateCardinality(context); - auto left = CreatePlan(*op.children[0]); - auto right = CreatePlan(*op.children[1]); - D_ASSERT(left 
&& right); - - // Validate - vector equi_indexes; - auto asof_idx = op.conditions.size(); - for (size_t c = 0; c < op.conditions.size(); ++c) { - auto &cond = op.conditions[c]; - switch (cond.comparison) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - equi_indexes.emplace_back(c); - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_LESSTHAN: - D_ASSERT(asof_idx == op.conditions.size()); - asof_idx = c; - break; - default: - throw InternalException("Invalid ASOF JOIN comparison"); - } - } - D_ASSERT(asof_idx < op.conditions.size()); - - if (!ClientConfig::GetConfig(context).force_asof_iejoin) { - return make_uniq(op, std::move(left), std::move(right)); - } - - // Strip extra column from rhs projections - auto &right_projection_map = op.right_projection_map; - if (right_projection_map.empty()) { - const auto right_count = right->types.size(); - right_projection_map.reserve(right_count); - for (column_t i = 0; i < right_count; ++i) { - right_projection_map.emplace_back(i); - } - } - - // Debug implementation: IEJoin of Window - // LEAD(asof_column, 1, infinity) OVER (PARTITION BY equi_column... ORDER BY asof_column) AS asof_end - auto &asof_comp = op.conditions[asof_idx]; - auto &asof_column = asof_comp.right; - auto asof_type = asof_column->return_type; - auto asof_end = make_uniq(ExpressionType::WINDOW_LEAD, asof_type, nullptr, nullptr); - asof_end->children.emplace_back(asof_column->Copy()); - // TODO: If infinities are not supported for a type, fake them by looking at LHS statistics? 
- asof_end->offset_expr = make_uniq(Value::BIGINT(1)); - for (auto equi_idx : equi_indexes) { - asof_end->partitions.emplace_back(op.conditions[equi_idx].right->Copy()); - } - switch (asof_comp.comparison) { - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHAN: - asof_end->orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, asof_column->Copy()); - asof_end->default_expr = make_uniq(Value::Infinity(asof_type)); - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_LESSTHAN: - asof_end->orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_FIRST, asof_column->Copy()); - asof_end->default_expr = make_uniq(Value::NegativeInfinity(asof_type)); - break; - default: - throw InternalException("Invalid ASOF JOIN ordering for WINDOW"); - } - - asof_end->start = WindowBoundary::UNBOUNDED_PRECEDING; - asof_end->end = WindowBoundary::CURRENT_ROW_ROWS; - - vector> window_select; - window_select.emplace_back(std::move(asof_end)); - - auto &window_types = op.children[1]->types; - window_types.emplace_back(asof_type); - - auto window = make_uniq(window_types, std::move(window_select), rhs_cardinality); - window->children.emplace_back(std::move(right)); - - // IEJoin(left, window, conditions || asof_comp ~op asof_end) - JoinCondition asof_upper; - asof_upper.left = asof_comp.left->Copy(); - asof_upper.right = make_uniq(asof_type, window_types.size() - 1); - switch (asof_comp.comparison) { - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - asof_upper.comparison = ExpressionType::COMPARE_LESSTHAN; - break; - case ExpressionType::COMPARE_GREATERTHAN: - asof_upper.comparison = ExpressionType::COMPARE_LESSTHANOREQUALTO; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - asof_upper.comparison = ExpressionType::COMPARE_GREATERTHAN; - break; - case ExpressionType::COMPARE_LESSTHAN: - asof_upper.comparison = ExpressionType::COMPARE_GREATERTHANOREQUALTO; - break; - 
default: - throw InternalException("Invalid ASOF JOIN comparison for IEJoin"); - } - - op.conditions.emplace_back(std::move(asof_upper)); - - return make_uniq(op, std::move(left), std::move(window), std::move(op.conditions), op.join_type, - lhs_cardinality); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalColumnDataGet &op) { - D_ASSERT(op.children.size() == 0); - D_ASSERT(op.collection); - - // create a PhysicalChunkScan pointing towards the owned collection - auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::COLUMN_DATA_SCAN, - op.estimated_cardinality, std::move(op.collection)); - return std::move(chunk_scan); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -static bool CanPlanIndexJoin(ClientContext &context, TableScanBindData &bind_data, PhysicalTableScan &scan) { - auto &table = bind_data.table; - auto &transaction = DuckTransaction::Get(context, table.catalog); - auto &local_storage = LocalStorage::Get(transaction); - if (local_storage.Find(table.GetStorage())) { - // transaction local appends: skip index join - return false; - } - if (scan.table_filters && !scan.table_filters->filters.empty()) { - // table scan filters - return false; - } - return true; -} - -bool ExtractNumericValue(Value val, int64_t &result) { - if (!val.type().IsIntegral()) { - switch (val.type().InternalType()) { - case PhysicalType::INT16: - result = val.GetValueUnsafe(); - break; - case PhysicalType::INT32: - result = val.GetValueUnsafe(); - break; - case PhysicalType::INT64: - result = val.GetValueUnsafe(); - break; - default: - return false; - } - } else { - if (!val.DefaultTryCastAs(LogicalType::BIGINT)) { - return false; - } - result = val.GetValue(); - } - return true; -} - -void CheckForPerfectJoinOpt(LogicalComparisonJoin &op, PerfectHashJoinStats &join_state) { - // we only do this optimization for inner joins - if (op.join_type != JoinType::INNER) { - return; - } 
- // with one condition - if (op.conditions.size() != 1) { - return; - } - // with propagated statistics - if (op.join_stats.empty()) { - return; - } - for (auto &type : op.children[1]->types) { - switch (type.InternalType()) { - case PhysicalType::STRUCT: - case PhysicalType::LIST: - return; - default: - break; - } - } - // with equality condition and null values not equal - for (auto &&condition : op.conditions) { - if (condition.comparison != ExpressionType::COMPARE_EQUAL) { - return; - } - } - // with integral internal types - for (auto &&join_stat : op.join_stats) { - if (!TypeIsInteger(join_stat->GetType().InternalType()) || - join_stat->GetType().InternalType() == PhysicalType::INT128) { - // perfect join not possible for non-integral types or hugeint - return; - } - } - - // and when the build range is smaller than the threshold - auto &stats_build = *op.join_stats[0].get(); // lhs stats - if (!NumericStats::HasMinMax(stats_build)) { - return; - } - int64_t min_value, max_value; - if (!ExtractNumericValue(NumericStats::Min(stats_build), min_value) || - !ExtractNumericValue(NumericStats::Max(stats_build), max_value)) { - return; - } - int64_t build_range; - if (!TrySubtractOperator::Operation(max_value, min_value, build_range)) { - return; - } - - // Fill join_stats for invisible join - auto &stats_probe = *op.join_stats[1].get(); // rhs stats - if (!NumericStats::HasMinMax(stats_probe)) { - return; - } - - // The max size our build must have to run the perfect HJ - const idx_t MAX_BUILD_SIZE = 1000000; - join_state.probe_min = NumericStats::Min(stats_probe); - join_state.probe_max = NumericStats::Max(stats_probe); - join_state.build_min = NumericStats::Min(stats_build); - join_state.build_max = NumericStats::Max(stats_build); - join_state.estimated_cardinality = op.estimated_cardinality; - join_state.build_range = build_range; - if (join_state.build_range > MAX_BUILD_SIZE) { - return; - } - if (NumericStats::Min(stats_build) <= 
NumericStats::Min(stats_probe) && - NumericStats::Max(stats_probe) <= NumericStats::Max(stats_build)) { - join_state.is_probe_in_domain = true; - } - join_state.is_build_small = true; - return; -} - -static optional_ptr CanUseIndexJoin(TableScanBindData &tbl, Expression &expr) { - optional_ptr result; - tbl.table.GetStorage().info->indexes.Scan([&](Index &index) { - if (index.unbound_expressions.size() != 1) { - return false; - } - if (expr.alias == index.unbound_expressions[0]->alias) { - result = &index; - return true; - } - return false; - }); - return result; -} - -optional_ptr CheckIndexJoin(ClientContext &context, LogicalComparisonJoin &op, PhysicalOperator &plan, - Expression &condition) { - if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { - return nullptr; - } - // check if one of the tables has an index on column - if (op.join_type != JoinType::INNER) { - return nullptr; - } - if (op.conditions.size() != 1) { - return nullptr; - } - // check if the child is (1) a table scan, and (2) has an index on the join condition - if (plan.type != PhysicalOperatorType::TABLE_SCAN) { - return nullptr; - } - auto &tbl_scan = plan.Cast(); - auto tbl_data = dynamic_cast(tbl_scan.bind_data.get()); - if (!tbl_data) { - return nullptr; - } - optional_ptr result; - if (CanPlanIndexJoin(context, *tbl_data, tbl_scan)) { - result = CanUseIndexJoin(*tbl_data, condition); - } - return result; -} - -static bool PlanIndexJoin(ClientContext &context, LogicalComparisonJoin &op, unique_ptr &plan, - unique_ptr &left, unique_ptr &right, - optional_ptr index, bool swap_condition = false) { - if (!index) { - return false; - } - - // index joins are disabled if enable_optimizer is false - if (!ClientConfig::GetConfig(context).enable_optimizer) { - return false; - } - - // index joins are disabled on default - auto force_index_join = ClientConfig::GetConfig(context).force_index_join; - if (!ClientConfig::GetConfig(context).enable_index_join && !force_index_join) { - return false; - 
} - - // check if the cardinality difference justifies an index join - auto index_join_is_applicable = left->estimated_cardinality < 0.01 * right->estimated_cardinality; - if (!index_join_is_applicable && !force_index_join) { - return false; - } - - // plan the index join - if (swap_condition) { - swap(op.conditions[0].left, op.conditions[0].right); - swap(op.left_projection_map, op.right_projection_map); - } - D_ASSERT(right->type == PhysicalOperatorType::TABLE_SCAN); - auto &tbl_scan = right->Cast(); - - plan = make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), op.join_type, - op.left_projection_map, op.right_projection_map, tbl_scan.column_ids, *index, - !swap_condition, op.estimated_cardinality); - return true; -} - -static bool PlanIndexJoin(ClientContext &context, LogicalComparisonJoin &op, unique_ptr &plan, - unique_ptr &left, unique_ptr &right) { - if (op.conditions.empty()) { - return false; - } - // check if we can plan an index join on the RHS - auto right_index = CheckIndexJoin(context, op, *right, *op.conditions[0].right); - if (PlanIndexJoin(context, op, plan, left, right, right_index)) { - return true; - } - // else check if we can plan an index join on the left side - auto left_index = CheckIndexJoin(context, op, *left, *op.conditions[0].left); - if (PlanIndexJoin(context, op, plan, right, left, left_index, true)) { - return true; - } - return false; -} - -static void RewriteJoinCondition(Expression &expr, idx_t offset) { - if (expr.type == ExpressionType::BOUND_REF) { - auto &ref = expr.Cast(); - ref.index += offset; - } - ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { RewriteJoinCondition(child, offset); }); -} - -unique_ptr PhysicalPlanGenerator::PlanComparisonJoin(LogicalComparisonJoin &op) { - // now visit the children - D_ASSERT(op.children.size() == 2); - idx_t lhs_cardinality = op.children[0]->EstimateCardinality(context); - idx_t rhs_cardinality = op.children[1]->EstimateCardinality(context); - 
auto left = CreatePlan(*op.children[0]); - auto right = CreatePlan(*op.children[1]); - left->estimated_cardinality = lhs_cardinality; - right->estimated_cardinality = rhs_cardinality; - D_ASSERT(left && right); - - if (op.conditions.empty()) { - // no conditions: insert a cross product - return make_uniq(op.types, std::move(left), std::move(right), op.estimated_cardinality); - } - - bool has_equality = false; - size_t has_range = 0; - for (size_t c = 0; c < op.conditions.size(); ++c) { - auto &cond = op.conditions[c]; - switch (cond.comparison) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - has_equality = true; - break; - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - ++has_range; - break; - case ExpressionType::COMPARE_NOTEQUAL: - case ExpressionType::COMPARE_DISTINCT_FROM: - break; - default: - throw NotImplementedException("Unimplemented comparison join"); - } - } - - bool can_merge = has_range > 0; - bool can_iejoin = has_range >= 2 && recursive_cte_tables.empty(); - switch (op.join_type) { - case JoinType::SEMI: - case JoinType::ANTI: - case JoinType::MARK: - can_merge = can_merge && op.conditions.size() == 1; - can_iejoin = false; - break; - default: - break; - } - - // TODO: Extend PWMJ to handle all comparisons and projection maps - const auto prefer_range_joins = (ClientConfig::GetConfig(context).prefer_range_joins && can_iejoin); - - unique_ptr plan; - if (has_equality && !prefer_range_joins) { - // check if we can use an index join - if (PlanIndexJoin(context, op, plan, left, right)) { - return plan; - } - // Equality join with small number of keys : possible perfect join optimization - PerfectHashJoinStats perfect_join_stats; - CheckForPerfectJoinOpt(op, perfect_join_stats); - plan = make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), - 
op.join_type, op.left_projection_map, op.right_projection_map, - std::move(op.mark_types), op.estimated_cardinality, perfect_join_stats); - - } else { - static constexpr const idx_t NESTED_LOOP_JOIN_THRESHOLD = 5; - if (left->estimated_cardinality <= NESTED_LOOP_JOIN_THRESHOLD || - right->estimated_cardinality <= NESTED_LOOP_JOIN_THRESHOLD) { - can_iejoin = false; - can_merge = false; - } - if (can_iejoin) { - plan = make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), - op.join_type, op.estimated_cardinality); - } else if (can_merge) { - // range join: use piecewise merge join - plan = - make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), - op.join_type, op.estimated_cardinality); - } else if (PhysicalNestedLoopJoin::IsSupported(op.conditions, op.join_type)) { - // inequality join: use nested loop - plan = make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), - op.join_type, op.estimated_cardinality); - } else { - for (auto &cond : op.conditions) { - RewriteJoinCondition(*cond.right, left->types.size()); - } - auto condition = JoinCondition::CreateExpression(std::move(op.conditions)); - plan = make_uniq(op, std::move(left), std::move(right), std::move(condition), - op.join_type, op.estimated_cardinality); - } - } - return plan; -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalComparisonJoin &op) { - switch (op.type) { - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - return PlanAsOfJoin(op); - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - return PlanComparisonJoin(op); - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - return PlanDelimJoin(op); - default: - throw InternalException("Unrecognized operator type for LogicalComparisonJoin"); - } -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCopyToFile &op) { - auto plan = CreatePlan(*op.children[0]); - bool preserve_insertion_order = 
PhysicalPlanGenerator::PreserveInsertionOrder(context, *plan); - bool supports_batch_index = PhysicalPlanGenerator::UseBatchIndex(context, *plan); - auto &fs = FileSystem::GetFileSystem(context); - op.file_path = fs.ExpandPath(op.file_path); - if (op.use_tmp_file) { - op.file_path += ".tmp"; - } - if (op.per_thread_output || op.partition_output || !op.partition_columns.empty() || op.overwrite_or_ignore) { - // hive-partitioning/per-thread output does not care about insertion order, and does not support batch indexes - preserve_insertion_order = false; - supports_batch_index = false; - } - auto mode = CopyFunctionExecutionMode::REGULAR_COPY_TO_FILE; - if (op.function.execution_mode) { - mode = op.function.execution_mode(preserve_insertion_order, supports_batch_index); - } - if (mode == CopyFunctionExecutionMode::BATCH_COPY_TO_FILE) { - if (!supports_batch_index) { - throw InternalException("BATCH_COPY_TO_FILE can only be used if batch indexes are supported"); - } - // batched copy to file - if (op.function.desired_batch_size) { - auto copy = make_uniq(op.types, op.function, std::move(op.bind_data), - op.estimated_cardinality); - copy->file_path = op.file_path; - copy->use_tmp_file = op.use_tmp_file; - copy->children.push_back(std::move(plan)); - return std::move(copy); - } else { - auto copy = make_uniq(op.types, op.function, std::move(op.bind_data), - op.estimated_cardinality); - copy->file_path = op.file_path; - copy->use_tmp_file = op.use_tmp_file; - copy->children.push_back(std::move(plan)); - return std::move(copy); - } - } - // COPY from select statement to file - auto copy = make_uniq(op.types, op.function, std::move(op.bind_data), op.estimated_cardinality); - copy->file_path = op.file_path; - copy->use_tmp_file = op.use_tmp_file; - copy->overwrite_or_ignore = op.overwrite_or_ignore; - copy->filename_pattern = op.filename_pattern; - copy->per_thread_output = op.per_thread_output; - copy->partition_output = op.partition_output; - copy->partition_columns = 
op.partition_columns; - copy->names = op.names; - copy->expected_types = op.expected_types; - copy->parallel = mode == CopyFunctionExecutionMode::PARALLEL_COPY_TO_FILE; - - copy->children.push_back(std::move(plan)); - return std::move(copy); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCreate &op) { - switch (op.type) { - case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_CREATE_VIEW: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_CREATE_MACRO: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_CREATE_TYPE: { - unique_ptr create = make_uniq( - unique_ptr_cast(std::move(op.info)), op.estimated_cardinality); - if (!op.children.empty()) { - D_ASSERT(op.children.size() == 1); - auto plan = CreatePlan(*op.children[0]); - create->children.push_back(std::move(plan)); - } - return create; - } - default: - throw NotImplementedException("Unimplemented type for logical simple create"); - } -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCreateIndex &op) { - // generate a physical plan for the parallel index creation which consists of the following operators - // table scan - projection (for expression execution) - filter (NOT NULL) - order - create index - D_ASSERT(op.children.size() == 1); - auto table_scan = CreatePlan(*op.children[0]); - - // validate that all expressions contain valid scalar functions - // e.g. 
get_current_timestamp(), random(), and sequence values are not allowed as index keys - // because they make deletions and lookups unfeasible - for (idx_t i = 0; i < op.unbound_expressions.size(); i++) { - auto &expr = op.unbound_expressions[i]; - if (expr->HasSideEffects()) { - throw BinderException("Index keys cannot contain expressions with side " - "effects."); - } - } - - // If we get here without the plan and the index type is not ART, we throw an exception - // because we don't support any other index type yet. However an operator extension could have - // replaced this part of the plan with a different index creation operator. - if (op.info->index_type != IndexType::ART) { - throw BinderException("Index type not supported"); - } - - // table scan operator for index key columns and row IDs - dependencies.AddDependency(op.table); - - D_ASSERT(op.info->scan_types.size() - 1 <= op.info->names.size()); - D_ASSERT(op.info->scan_types.size() - 1 <= op.info->column_ids.size()); - - // projection to execute expressions on the key columns - - vector new_column_types; - vector> select_list; - for (idx_t i = 0; i < op.expressions.size(); i++) { - new_column_types.push_back(op.expressions[i]->return_type); - select_list.push_back(std::move(op.expressions[i])); - } - new_column_types.emplace_back(LogicalType::ROW_TYPE); - select_list.push_back(make_uniq(LogicalType::ROW_TYPE, op.info->scan_types.size() - 1)); - - auto projection = make_uniq(new_column_types, std::move(select_list), op.estimated_cardinality); - projection->children.push_back(std::move(table_scan)); - - // filter operator for IS_NOT_NULL on each key column - - vector filter_types; - vector> filter_select_list; - - for (idx_t i = 0; i < new_column_types.size() - 1; i++) { - filter_types.push_back(new_column_types[i]); - auto is_not_null_expr = - make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); - auto bound_ref = make_uniq(new_column_types[i], i); - 
is_not_null_expr->children.push_back(std::move(bound_ref)); - filter_select_list.push_back(std::move(is_not_null_expr)); - } - - auto null_filter = - make_uniq(std::move(filter_types), std::move(filter_select_list), op.estimated_cardinality); - null_filter->types.emplace_back(LogicalType::ROW_TYPE); - null_filter->children.push_back(std::move(projection)); - - // determine if we sort the data prior to index creation - // we don't sort, if either VARCHAR or compound key - auto perform_sorting = true; - if (op.unbound_expressions.size() > 1) { - perform_sorting = false; - } else if (op.unbound_expressions[0]->return_type.InternalType() == PhysicalType::VARCHAR) { - perform_sorting = false; - } - - // actual physical create index operator - - auto physical_create_index = - make_uniq(op, op.table, op.info->column_ids, std::move(op.info), - std::move(op.unbound_expressions), op.estimated_cardinality, perform_sorting); - - if (perform_sorting) { - - // optional order operator - vector orders; - vector projections; - for (idx_t i = 0; i < new_column_types.size() - 1; i++) { - auto col_expr = make_uniq_base(new_column_types[i], i); - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, std::move(col_expr)); - projections.emplace_back(i); - } - projections.emplace_back(new_column_types.size() - 1); - - auto physical_order = make_uniq(new_column_types, std::move(orders), std::move(projections), - op.estimated_cardinality); - physical_order->children.push_back(std::move(null_filter)); - - physical_create_index->children.push_back(std::move(physical_order)); - } else { - - // no ordering - physical_create_index->children.push_back(std::move(null_filter)); - } - - return std::move(physical_create_index); -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -unique_ptr DuckCatalog::PlanCreateTableAs(ClientContext &context, LogicalCreateTable &op, - unique_ptr plan) { - bool parallel_streaming_insert = 
!PhysicalPlanGenerator::PreserveInsertionOrder(context, *plan); - bool use_batch_index = PhysicalPlanGenerator::UseBatchIndex(context, *plan); - auto num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - unique_ptr create; - if (!parallel_streaming_insert && use_batch_index) { - create = make_uniq(op, op.schema, std::move(op.info), op.estimated_cardinality); - - } else { - create = make_uniq(op, op.schema, std::move(op.info), op.estimated_cardinality, - parallel_streaming_insert && num_threads > 1); - } - - D_ASSERT(op.children.size() == 1); - create->children.push_back(std::move(plan)); - return create; -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCreateTable &op) { - const auto &create_info = op.info->base->Cast(); - auto &catalog = op.info->schema.catalog; - auto existing_entry = catalog.GetEntry(context, create_info.schema, create_info.table, - OnEntryNotFound::RETURN_NULL); - bool replace = op.info->Base().on_conflict == OnCreateConflict::REPLACE_ON_CONFLICT; - if ((!existing_entry || replace) && !op.children.empty()) { - auto plan = CreatePlan(*op.children[0]); - return op.schema.catalog.PlanCreateTableAs(context, op, std::move(plan)); - } else { - return make_uniq(op, op.schema, std::move(op.info), op.estimated_cardinality); - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCrossProduct &op) { - D_ASSERT(op.children.size() == 2); - - auto left = CreatePlan(*op.children[0]); - auto right = CreatePlan(*op.children[1]); - return make_uniq(op.types, std::move(left), std::move(right), op.estimated_cardinality); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalMaterializedCTE &op) { - D_ASSERT(op.children.size() == 2); - - // Create the working_table that the PhysicalCTE will use for evaluation. 
- auto working_table = std::make_shared(context, op.children[0]->types); - - // Add the ColumnDataCollection to the context of this PhysicalPlanGenerator - recursive_cte_tables[op.table_index] = working_table; - - // Create the plan for the left side. This is the materialization. - auto left = CreatePlan(*op.children[0]); - // Initialize an empty vector to collect the scan operators. - materialized_ctes.insert(op.table_index); - auto right = CreatePlan(*op.children[1]); - - auto cte = make_uniq(op.ctename, op.table_index, op.children[1]->types, std::move(left), - std::move(right), op.estimated_cardinality); - cte->working_table = working_table; - - return std::move(cte); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -unique_ptr DuckCatalog::PlanDelete(ClientContext &context, LogicalDelete &op, - unique_ptr plan) { - // get the index of the row_id column - auto &bound_ref = op.expressions[0]->Cast(); - - auto del = make_uniq(op.types, op.table, op.table.GetStorage(), bound_ref.index, - op.estimated_cardinality, op.return_chunk); - del->children.push_back(std::move(plan)); - return std::move(del); -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalDelete &op) { - D_ASSERT(op.children.size() == 1); - D_ASSERT(op.expressions.size() == 1); - D_ASSERT(op.expressions[0]->type == ExpressionType::BOUND_REF); - - auto plan = CreatePlan(*op.children[0]); - - dependencies.AddDependency(op.table); - return op.table.catalog.PlanDelete(context, op, std::move(plan)); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalDelimGet &op) { - D_ASSERT(op.children.empty()); - - // create a PhysicalChunkScan without an owned_collection, the collection will be added later - auto chunk_scan = - make_uniq(op.types, PhysicalOperatorType::DELIM_SCAN, op.estimated_cardinality); - return std::move(chunk_scan); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -static void GatherDelimScans(const 
PhysicalOperator &op, vector> &delim_scans) { - if (op.type == PhysicalOperatorType::DELIM_SCAN) { - delim_scans.push_back(op); - } - for (auto &child : op.children) { - GatherDelimScans(*child, delim_scans); - } -} - -unique_ptr PhysicalPlanGenerator::PlanDelimJoin(LogicalComparisonJoin &op) { - // first create the underlying join - auto plan = PlanComparisonJoin(op); - // this should create a join, not a cross product - D_ASSERT(plan && plan->type != PhysicalOperatorType::CROSS_PRODUCT); - // duplicate eliminated join - // first gather the scans on the duplicate eliminated data set from the RHS - vector> delim_scans; - GatherDelimScans(*plan->children[1], delim_scans); - if (delim_scans.empty()) { - // no duplicate eliminated scans in the RHS! - // in this case we don't need to create a delim join - // just push the normal join - return plan; - } - vector delim_types; - vector> distinct_groups, distinct_expressions; - for (auto &delim_expr : op.duplicate_eliminated_columns) { - D_ASSERT(delim_expr->type == ExpressionType::BOUND_REF); - auto &bound_ref = delim_expr->Cast(); - delim_types.push_back(bound_ref.return_type); - distinct_groups.push_back(make_uniq(bound_ref.return_type, bound_ref.index)); - } - // now create the duplicate eliminated join - auto delim_join = make_uniq(op.types, std::move(plan), delim_scans, op.estimated_cardinality); - // we still have to create the DISTINCT clause that is used to generate the duplicate eliminated chunk - delim_join->distinct = make_uniq(context, delim_types, std::move(distinct_expressions), - std::move(distinct_groups), op.estimated_cardinality); - return std::move(delim_join); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalDistinct &op) { - D_ASSERT(op.children.size() == 1); - auto child = CreatePlan(*op.children[0]); - auto &distinct_targets = op.distinct_targets; - D_ASSERT(child); - D_ASSERT(!distinct_targets.empty()); - - auto &types = 
child->GetTypes(); - vector> groups, aggregates, projections; - idx_t group_count = distinct_targets.size(); - unordered_map group_by_references; - vector aggregate_types; - // creates one group per distinct_target - for (idx_t i = 0; i < distinct_targets.size(); i++) { - auto &target = distinct_targets[i]; - if (target->type == ExpressionType::BOUND_REF) { - auto &bound_ref = target->Cast(); - group_by_references[bound_ref.index] = i; - } - aggregate_types.push_back(target->return_type); - groups.push_back(std::move(target)); - } - bool requires_projection = false; - if (types.size() != group_count) { - requires_projection = true; - } - // we need to create one aggregate per column in the select_list - for (idx_t i = 0; i < types.size(); ++i) { - auto logical_type = types[i]; - // check if we can directly refer to a group, or if we need to push an aggregate with FIRST - auto entry = group_by_references.find(i); - if (entry != group_by_references.end()) { - auto group_index = entry->second; - // entry is found: can directly refer to a group - projections.push_back(make_uniq(logical_type, group_index)); - if (group_index != i) { - // we require a projection only if this group element is out of order - requires_projection = true; - } - } else { - if (op.distinct_type == DistinctType::DISTINCT && op.order_by) { - throw InternalException("Entry that is not a group, but not a DISTINCT ON aggregate"); - } - // entry is not one of the groups: need to push a FIRST aggregate - auto bound = make_uniq(logical_type, i); - vector> first_children; - first_children.push_back(std::move(bound)); - - FunctionBinder function_binder(context); - auto first_aggregate = function_binder.BindAggregateFunction( - FirstFun::GetFunction(logical_type), std::move(first_children), nullptr, AggregateType::NON_DISTINCT); - first_aggregate->order_bys = op.order_by ? 
op.order_by->Copy() : nullptr; - // add the projection - projections.push_back(make_uniq(logical_type, group_count + aggregates.size())); - // push it to the list of aggregates - aggregate_types.push_back(logical_type); - aggregates.push_back(std::move(first_aggregate)); - requires_projection = true; - } - } - - child = ExtractAggregateExpressions(std::move(child), aggregates, groups); - - // we add a physical hash aggregation in the plan to select the distinct groups - auto groupby = make_uniq(context, aggregate_types, std::move(aggregates), std::move(groups), - child->estimated_cardinality); - groupby->children.push_back(std::move(child)); - if (!requires_projection) { - return std::move(groupby); - } - - // we add a physical projection on top of the aggregation to project all members in the select list - auto aggr_projection = make_uniq(types, std::move(projections), groupby->estimated_cardinality); - aggr_projection->children.push_back(std::move(groupby)); - return std::move(aggr_projection); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalDummyScan &op) { - D_ASSERT(op.children.size() == 0); - return make_uniq(op.types, op.estimated_cardinality); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalEmptyResult &op) { - D_ASSERT(op.children.size() == 0); - return make_uniq(op.types, op.estimated_cardinality); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExecute &op) { - if (!op.prepared->plan) { - D_ASSERT(op.children.size() == 1); - auto owned_plan = CreatePlan(*op.children[0]); - auto execute = make_uniq(*owned_plan); - execute->owned_plan = std::move(owned_plan); - execute->prepared = std::move(op.prepared); - return std::move(execute); - } else { - D_ASSERT(op.children.size() == 0); - return make_uniq(*op.prepared->plan); - } -} - -} // namespace duckdb - - - - - - - - 
-namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExplain &op) { - D_ASSERT(op.children.size() == 1); - auto logical_plan_opt = op.children[0]->ToString(); - auto plan = CreatePlan(*op.children[0]); - if (op.explain_type == ExplainType::EXPLAIN_ANALYZE) { - auto result = make_uniq(op.types); - result->children.push_back(std::move(plan)); - return std::move(result); - } - - op.physical_plan = plan->ToString(); - // the output of the explain - vector keys, values; - switch (ClientConfig::GetConfig(context).explain_output_type) { - case ExplainOutputType::OPTIMIZED_ONLY: - keys = {"logical_opt"}; - values = {logical_plan_opt}; - break; - case ExplainOutputType::PHYSICAL_ONLY: - keys = {"physical_plan"}; - values = {op.physical_plan}; - break; - default: - keys = {"logical_plan", "logical_opt", "physical_plan"}; - values = {op.logical_plan_unopt, logical_plan_opt, op.physical_plan}; - } - - // create a ColumnDataCollection from the output - auto &allocator = Allocator::Get(context); - vector plan_types {LogicalType::VARCHAR, LogicalType::VARCHAR}; - auto collection = - make_uniq(context, plan_types, ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR); - - DataChunk chunk; - chunk.Initialize(allocator, op.types); - for (idx_t i = 0; i < keys.size(); i++) { - chunk.SetValue(0, chunk.size(), Value(keys[i])); - chunk.SetValue(1, chunk.size(), Value(values[i])); - chunk.SetCardinality(chunk.size() + 1); - if (chunk.size() == STANDARD_VECTOR_SIZE) { - collection->Append(chunk); - chunk.Reset(); - } - } - collection->Append(chunk); - - // create a chunk scan to output the result - auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::COLUMN_DATA_SCAN, - op.estimated_cardinality, std::move(collection)); - return std::move(chunk_scan); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExport &op) { - auto &config = DBConfig::GetConfig(context); - if (!config.options.enable_external_access) { 
- throw PermissionException("Export is disabled through configuration"); - } - auto export_node = make_uniq(op.types, op.function, std::move(op.copy_info), - op.estimated_cardinality, op.exported_tables); - // plan the underlying copy statements, if any - if (!op.children.empty()) { - auto plan = CreatePlan(*op.children[0]); - export_node->children.push_back(std::move(plan)); - } - return std::move(export_node); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExpressionGet &op) { - D_ASSERT(op.children.size() == 1); - auto plan = CreatePlan(*op.children[0]); - - auto expr_scan = make_uniq(op.types, std::move(op.expressions), op.estimated_cardinality); - expr_scan->children.push_back(std::move(plan)); - if (!expr_scan->IsFoldable()) { - return std::move(expr_scan); - } - auto &allocator = Allocator::Get(context); - // simple expression scan (i.e. no subqueries to evaluate and no prepared statement parameters) - // we can evaluate all the expressions right now and turn this into a chunk collection scan - auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::COLUMN_DATA_SCAN, - expr_scan->expressions.size()); - chunk_scan->owned_collection = make_uniq(context, op.types); - chunk_scan->collection = chunk_scan->owned_collection.get(); - - DataChunk chunk; - chunk.Initialize(allocator, op.types); - - ColumnDataAppendState append_state; - chunk_scan->owned_collection->InitializeAppend(append_state); - for (idx_t expression_idx = 0; expression_idx < expr_scan->expressions.size(); expression_idx++) { - chunk.Reset(); - expr_scan->EvaluateExpression(context, expression_idx, nullptr, chunk); - chunk_scan->owned_collection->Append(append_state, chunk); - } - return std::move(chunk_scan); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalFilter &op) { - D_ASSERT(op.children.size() == 1); - unique_ptr plan = 
CreatePlan(*op.children[0]); - if (!op.expressions.empty()) { - D_ASSERT(plan->types.size() > 0); - // create a filter if there is anything to filter - auto filter = make_uniq(plan->types, std::move(op.expressions), op.estimated_cardinality); - filter->children.push_back(std::move(plan)); - plan = std::move(filter); - } - if (!op.projection_map.empty()) { - // there is a projection map, generate a physical projection - vector> select_list; - for (idx_t i = 0; i < op.projection_map.size(); i++) { - select_list.push_back(make_uniq(op.types[i], op.projection_map[i])); - } - auto proj = make_uniq(op.types, std::move(select_list), op.estimated_cardinality); - proj->children.push_back(std::move(plan)); - plan = std::move(proj); - } - return plan; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -unique_ptr CreateTableFilterSet(TableFilterSet &table_filters, vector &column_ids) { - // create the table filter map - auto table_filter_set = make_uniq(); - for (auto &table_filter : table_filters.filters) { - // find the relative column index from the absolute column index into the table - idx_t column_index = DConstants::INVALID_INDEX; - for (idx_t i = 0; i < column_ids.size(); i++) { - if (table_filter.first == column_ids[i]) { - column_index = i; - break; - } - } - if (column_index == DConstants::INVALID_INDEX) { - throw InternalException("Could not find column index for table filter"); - } - table_filter_set->filters[column_index] = std::move(table_filter.second); - } - return table_filter_set; -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalGet &op) { - if (!op.children.empty()) { - // this is for table producing functions that consume subquery results - D_ASSERT(op.children.size() == 1); - auto node = make_uniq(op.types, op.function, std::move(op.bind_data), op.column_ids, - op.estimated_cardinality, std::move(op.projected_input)); - node->children.push_back(CreatePlan(std::move(op.children[0]))); - return std::move(node); - } - if 
(!op.projected_input.empty()) { - throw InternalException("LogicalGet::project_input can only be set for table-in-out functions"); - } - - unique_ptr table_filters; - if (!op.table_filters.filters.empty()) { - table_filters = CreateTableFilterSet(op.table_filters, op.column_ids); - } - - if (op.function.dependency) { - op.function.dependency(dependencies, op.bind_data.get()); - } - // create the table scan node - if (!op.function.projection_pushdown) { - // function does not support projection pushdown - auto node = make_uniq(op.returned_types, op.function, std::move(op.bind_data), - op.returned_types, op.column_ids, vector(), op.names, - std::move(table_filters), op.estimated_cardinality, op.extra_info); - // first check if an additional projection is necessary - if (op.column_ids.size() == op.returned_types.size()) { - bool projection_necessary = false; - for (idx_t i = 0; i < op.column_ids.size(); i++) { - if (op.column_ids[i] != i) { - projection_necessary = true; - break; - } - } - if (!projection_necessary) { - // a projection is not necessary if all columns have been requested in-order - // in that case we just return the node - - return std::move(node); - } - } - // push a projection on top that does the projection - vector types; - vector> expressions; - for (auto &column_id : op.column_ids) { - if (column_id == COLUMN_IDENTIFIER_ROW_ID) { - types.emplace_back(LogicalType::BIGINT); - expressions.push_back(make_uniq(Value::BIGINT(0))); - } else { - auto type = op.returned_types[column_id]; - types.push_back(type); - expressions.push_back(make_uniq(type, column_id)); - } - } - - auto projection = - make_uniq(std::move(types), std::move(expressions), op.estimated_cardinality); - projection->children.push_back(std::move(node)); - return std::move(projection); - } else { - return make_uniq(op.types, op.function, std::move(op.bind_data), op.returned_types, - op.column_ids, op.projection_ids, op.names, std::move(table_filters), - op.estimated_cardinality, 
op.extra_info); - } -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -static OrderPreservationType OrderPreservationRecursive(PhysicalOperator &op) { - if (op.IsSource()) { - return op.SourceOrder(); - } - for (auto &child : op.children) { - auto child_preservation = OrderPreservationRecursive(*child); - if (child_preservation != OrderPreservationType::INSERTION_ORDER) { - return child_preservation; - } - } - return OrderPreservationType::INSERTION_ORDER; -} - -bool PhysicalPlanGenerator::PreserveInsertionOrder(ClientContext &context, PhysicalOperator &plan) { - auto &config = DBConfig::GetConfig(context); - - auto preservation_type = OrderPreservationRecursive(plan); - if (preservation_type == OrderPreservationType::FIXED_ORDER) { - // always need to maintain preservation order - return true; - } - if (preservation_type == OrderPreservationType::NO_ORDER) { - // never need to preserve order - return false; - } - // preserve insertion order - check flags - if (!config.options.preserve_insertion_order) { - // preserving insertion order is disabled by config - return false; - } - return true; -} - -bool PhysicalPlanGenerator::PreserveInsertionOrder(PhysicalOperator &plan) { - return PreserveInsertionOrder(context, plan); -} - -bool PhysicalPlanGenerator::UseBatchIndex(ClientContext &context, PhysicalOperator &plan) { - // TODO: always preserve order if query contains ORDER BY - auto &scheduler = TaskScheduler::GetScheduler(context); - if (scheduler.NumberOfThreads() == 1) { - // batch index usage only makes sense if we are using multiple threads - return false; - } - if (!plan.AllSourcesSupportBatchIndex()) { - // batch index is not supported - return false; - } - return true; -} - -bool PhysicalPlanGenerator::UseBatchIndex(PhysicalOperator &plan) { - return UseBatchIndex(context, plan); -} - -unique_ptr DuckCatalog::PlanInsert(ClientContext &context, LogicalInsert &op, - unique_ptr plan) { - bool parallel_streaming_insert = 
!PhysicalPlanGenerator::PreserveInsertionOrder(context, *plan); - bool use_batch_index = PhysicalPlanGenerator::UseBatchIndex(context, *plan); - auto num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - if (op.return_chunk) { - // not supported for RETURNING (yet?) - parallel_streaming_insert = false; - use_batch_index = false; - } - if (op.action_type != OnConflictAction::THROW) { - // We don't support ON CONFLICT clause in batch insertion operation currently - use_batch_index = false; - } - if (op.action_type == OnConflictAction::UPDATE) { - // When we potentially need to perform updates, we have to check that row is not updated twice - // that currently needs to be done for every chunk, which would add a huge bottleneck to parallelized insertion - parallel_streaming_insert = false; - } - unique_ptr insert; - if (use_batch_index && !parallel_streaming_insert) { - insert = make_uniq(op.types, op.table, op.column_index_map, std::move(op.bound_defaults), - op.estimated_cardinality); - } else { - insert = make_uniq( - op.types, op.table, op.column_index_map, std::move(op.bound_defaults), std::move(op.expressions), - std::move(op.set_columns), std::move(op.set_types), op.estimated_cardinality, op.return_chunk, - parallel_streaming_insert && num_threads > 1, op.action_type, std::move(op.on_conflict_condition), - std::move(op.do_update_condition), std::move(op.on_conflict_filter), std::move(op.columns_to_fetch)); - } - D_ASSERT(plan); - insert->children.push_back(std::move(plan)); - return insert; -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalInsert &op) { - unique_ptr plan; - if (!op.children.empty()) { - D_ASSERT(op.children.size() == 1); - plan = CreatePlan(*op.children[0]); - } - dependencies.AddDependency(op.table); - return op.table.catalog.PlanInsert(context, op, std::move(plan)); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalLimit &op) { - 
D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); - - unique_ptr limit; - if (!PreserveInsertionOrder(*plan)) { - // use parallel streaming limit if insertion order is not important - limit = make_uniq(op.types, (idx_t)op.limit_val, op.offset_val, std::move(op.limit), - std::move(op.offset), op.estimated_cardinality, true); - } else { - // maintaining insertion order is important - if (UseBatchIndex(*plan)) { - // source supports batch index: use parallel batch limit - limit = make_uniq(op.types, (idx_t)op.limit_val, op.offset_val, std::move(op.limit), - std::move(op.offset), op.estimated_cardinality); - } else { - // source does not support batch index: use a non-parallel streaming limit - limit = make_uniq(op.types, (idx_t)op.limit_val, op.offset_val, std::move(op.limit), - std::move(op.offset), op.estimated_cardinality, false); - } - } - - limit->children.push_back(std::move(plan)); - return limit; -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalLimitPercent &op) { - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); - - auto limit = make_uniq(op.types, op.limit_percent, op.offset_val, std::move(op.limit), - std::move(op.offset), op.estimated_cardinality); - limit->children.push_back(std::move(plan)); - return std::move(limit); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalOrder &op) { - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); - if (!op.orders.empty()) { - vector projections; - if (op.projections.empty()) { - for (idx_t i = 0; i < plan->types.size(); i++) { - projections.push_back(i); - } - } else { - projections = std::move(op.projections); - } - auto order = - make_uniq(op.types, std::move(op.orders), std::move(projections), op.estimated_cardinality); - order->children.push_back(std::move(plan)); - plan = std::move(order); - } - return 
plan; -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalPivot &op) { - D_ASSERT(op.children.size() == 1); - auto child_plan = CreatePlan(*op.children[0]); - auto pivot = make_uniq(std::move(op.types), std::move(child_plan), std::move(op.bound_pivot)); - return std::move(pivot); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalPositionalJoin &op) { - D_ASSERT(op.children.size() == 2); - - auto left = CreatePlan(*op.children[0]); - auto right = CreatePlan(*op.children[1]); - switch (left->type) { - case PhysicalOperatorType::TABLE_SCAN: - case PhysicalOperatorType::POSITIONAL_SCAN: - switch (right->type) { - case PhysicalOperatorType::TABLE_SCAN: - case PhysicalOperatorType::POSITIONAL_SCAN: - return make_uniq(op.types, std::move(left), std::move(right)); - default: - break; - } - default: - break; - } - - return make_uniq(op.types, std::move(left), std::move(right), op.estimated_cardinality); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalPragma &op) { - return make_uniq(op.function, op.info, op.estimated_cardinality); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalPrepare &op) { - D_ASSERT(op.children.size() <= 1); - - // generate physical plan - if (!op.children.empty()) { - auto plan = CreatePlan(*op.children[0]); - op.prepared->types = plan->types; - op.prepared->plan = std::move(plan); - } - - return make_uniq(op.name, std::move(op.prepared), op.estimated_cardinality); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalProjection &op) { - D_ASSERT(op.children.size() == 1); - auto plan = CreatePlan(*op.children[0]); - -#ifdef DEBUG - for (auto &expr : op.expressions) { - D_ASSERT(!expr->IsWindow()); - D_ASSERT(!expr->IsAggregate()); - } -#endif - if 
(plan->types.size() == op.types.size()) { - // check if this projection can be omitted entirely - // this happens if a projection simply emits the columns in the same order - // e.g. PROJECTION(#0, #1, #2, #3, ...) - bool omit_projection = true; - for (idx_t i = 0; i < op.types.size(); i++) { - if (op.expressions[i]->type == ExpressionType::BOUND_REF) { - auto &bound_ref = op.expressions[i]->Cast(); - if (bound_ref.index == i) { - continue; - } - } - omit_projection = false; - break; - } - if (omit_projection) { - // the projection only directly projects the child' columns: omit it entirely - return plan; - } - } - - auto projection = make_uniq(op.types, std::move(op.expressions), op.estimated_cardinality); - projection->children.push_back(std::move(plan)); - return std::move(projection); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalRecursiveCTE &op) { - D_ASSERT(op.children.size() == 2); - - // Create the working_table that the PhysicalRecursiveCTE will use for evaluation. - auto working_table = std::make_shared(context, op.types); - - // Add the ColumnDataCollection to the context of this PhysicalPlanGenerator - recursive_cte_tables[op.table_index] = working_table; - - auto left = CreatePlan(*op.children[0]); - auto right = CreatePlan(*op.children[1]); - - auto cte = make_uniq(op.ctename, op.table_index, op.types, op.union_all, std::move(left), - std::move(right), op.estimated_cardinality); - cte->working_table = working_table; - - return std::move(cte); -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCTERef &op) { - D_ASSERT(op.children.empty()); - - // Check if this LogicalCTERef is supposed to scan a materialized CTE. - if (op.materialized_cte == CTEMaterialize::CTE_MATERIALIZE_ALWAYS) { - // Lookup if there is a materialized CTE for the cte_index. 
- auto materialized_cte = materialized_ctes.find(op.cte_index); - - // If this check fails, this is a reference to a materialized recursive CTE. - if (materialized_cte != materialized_ctes.end()) { - auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::CTE_SCAN, - op.estimated_cardinality, op.cte_index); - - auto cte = recursive_cte_tables.find(op.cte_index); - if (cte == recursive_cte_tables.end()) { - throw InvalidInputException("Referenced materialized CTE does not exist."); - } - chunk_scan->collection = cte->second.get(); - - return std::move(chunk_scan); - } - } - - // CreatePlan of a LogicalRecursiveCTE must have happened before. - auto cte = recursive_cte_tables.find(op.cte_index); - if (cte == recursive_cte_tables.end()) { - throw InvalidInputException("Referenced recursive CTE does not exist."); - } - - auto chunk_scan = make_uniq( - cte->second.get()->Types(), PhysicalOperatorType::RECURSIVE_CTE_SCAN, op.estimated_cardinality, op.cte_index); - - chunk_scan->collection = cte->second.get(); - return std::move(chunk_scan); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalReset &op) { - return make_uniq(op.name, op.scope, op.estimated_cardinality); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSample &op) { - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); - - unique_ptr sample; - switch (op.sample_options->method) { - case SampleMethod::RESERVOIR_SAMPLE: - sample = make_uniq(op.types, std::move(op.sample_options), op.estimated_cardinality); - break; - case SampleMethod::SYSTEM_SAMPLE: - case SampleMethod::BERNOULLI_SAMPLE: - if (!op.sample_options->is_percentage) { - throw ParserException("Sample method %s cannot be used with a discrete sample count, either switch to " - "reservoir sampling or use a sample_size", - EnumUtil::ToString(op.sample_options->method)); - } - sample = 
make_uniq(op.types, op.sample_options->method, - op.sample_options->sample_size.GetValue(), - op.sample_options->seed, op.estimated_cardinality); - break; - default: - throw InternalException("Unimplemented sample method"); - } - sample->children.push_back(std::move(plan)); - return sample; -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSet &op) { - return make_uniq(op.name, op.value, op.scope, op.estimated_cardinality); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSetOperation &op) { - D_ASSERT(op.children.size() == 2); - - auto left = CreatePlan(*op.children[0]); - auto right = CreatePlan(*op.children[1]); - - if (left->GetTypes() != right->GetTypes()) { - throw InvalidInputException("Type mismatch for SET OPERATION"); - } - - switch (op.type) { - case LogicalOperatorType::LOGICAL_UNION: - // UNION - return make_uniq(op.types, std::move(left), std::move(right), op.estimated_cardinality); - default: { - // EXCEPT/INTERSECT - D_ASSERT(op.type == LogicalOperatorType::LOGICAL_EXCEPT || op.type == LogicalOperatorType::LOGICAL_INTERSECT); - auto &types = left->GetTypes(); - vector conditions; - // create equality condition for all columns - for (idx_t i = 0; i < types.size(); i++) { - JoinCondition cond; - cond.left = make_uniq(types[i], i); - cond.right = make_uniq(types[i], i); - cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; - conditions.push_back(std::move(cond)); - } - // EXCEPT is ANTI join - // INTERSECT is SEMI join - PerfectHashJoinStats join_stats; // used in inner joins only - JoinType join_type = op.type == LogicalOperatorType::LOGICAL_EXCEPT ? 
JoinType::ANTI : JoinType::SEMI; - return make_uniq(op, std::move(left), std::move(right), std::move(conditions), join_type, - op.estimated_cardinality, join_stats); - } - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalShow &op) { - DataChunk output; - output.Initialize(Allocator::Get(context), op.types); - - auto collection = make_uniq(context, op.types); - ColumnDataAppendState append_state; - collection->InitializeAppend(append_state); - for (idx_t column_idx = 0; column_idx < op.types_select.size(); column_idx++) { - auto type = op.types_select[column_idx]; - auto &name = op.aliases[column_idx]; - - // "name", TypeId::VARCHAR - output.SetValue(0, output.size(), Value(name)); - // "type", TypeId::VARCHAR - output.SetValue(1, output.size(), Value(type.ToString())); - // "null", TypeId::VARCHAR - output.SetValue(2, output.size(), Value("YES")); - // "pk", TypeId::BOOL - output.SetValue(3, output.size(), Value()); - // "dflt_value", TypeId::VARCHAR - output.SetValue(4, output.size(), Value()); - // "extra", TypeId::VARCHAR - output.SetValue(5, output.size(), Value()); - - output.SetCardinality(output.size() + 1); - if (output.size() == STANDARD_VECTOR_SIZE) { - collection->Append(append_state, output); - output.Reset(); - } - } - - collection->Append(append_state, output); - - // create a chunk scan to output the result - auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::COLUMN_DATA_SCAN, - op.estimated_cardinality, std::move(collection)); - return std::move(chunk_scan); -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSimple &op) { - switch (op.type) { - case LogicalOperatorType::LOGICAL_ALTER: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_DROP: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - 
case LogicalOperatorType::LOGICAL_TRANSACTION: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_VACUUM: { - auto result = make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - if (!op.children.empty()) { - auto child = CreatePlan(*op.children[0]); - result->children.push_back(std::move(child)); - } - return std::move(result); - } - case LogicalOperatorType::LOGICAL_LOAD: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_ATTACH: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_DETACH: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - default: - throw NotImplementedException("Unimplemented type for logical simple operator"); - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalTopN &op) { - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); - - auto top_n = - make_uniq(op.types, std::move(op.orders), (idx_t)op.limit, op.offset, op.estimated_cardinality); - top_n->children.push_back(std::move(plan)); - return std::move(top_n); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalUnnest &op) { - D_ASSERT(op.children.size() == 1); - auto plan = CreatePlan(*op.children[0]); - auto unnest = make_uniq(op.types, std::move(op.expressions), op.estimated_cardinality); - unnest->children.push_back(std::move(plan)); - return std::move(unnest); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr DuckCatalog::PlanUpdate(ClientContext &context, LogicalUpdate &op, - unique_ptr plan) { - auto update = - make_uniq(op.types, op.table, op.table.GetStorage(), op.columns, std::move(op.expressions), - std::move(op.bound_defaults), 
op.estimated_cardinality, op.return_chunk); - - update->update_is_del_and_insert = op.update_is_del_and_insert; - update->children.push_back(std::move(plan)); - return std::move(update); -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalUpdate &op) { - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); - - dependencies.AddDependency(op.table); - return op.table.catalog.PlanUpdate(context, op, std::move(plan)); -} - -} // namespace duckdb - - - - - - - - -#include - -namespace duckdb { - -static bool IsStreamingWindow(unique_ptr &expr) { - auto &wexpr = expr->Cast(); - if (!wexpr.partitions.empty() || !wexpr.orders.empty() || wexpr.ignore_nulls) { - return false; - } - switch (wexpr.type) { - // TODO: add more expression types here? - case ExpressionType::WINDOW_AGGREGATE: - // We can stream aggregates if they are "running totals" and don't use filters - return wexpr.start == WindowBoundary::UNBOUNDED_PRECEDING && wexpr.end == WindowBoundary::CURRENT_ROW_ROWS && - !wexpr.filter_expr; - case ExpressionType::WINDOW_FIRST_VALUE: - case ExpressionType::WINDOW_PERCENT_RANK: - case ExpressionType::WINDOW_RANK: - case ExpressionType::WINDOW_RANK_DENSE: - case ExpressionType::WINDOW_ROW_NUMBER: - return true; - default: - return false; - } -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalWindow &op) { - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); -#ifdef DEBUG - for (auto &expr : op.expressions) { - D_ASSERT(expr->IsWindow()); - } -#endif - - op.estimated_cardinality = op.EstimateCardinality(context); - - // Slice types - auto types = op.types; - const auto output_idx = types.size() - op.expressions.size(); - types.resize(output_idx); - - // Identify streaming windows - vector blocking_windows; - vector streaming_windows; - for (idx_t expr_idx = 0; expr_idx < op.expressions.size(); expr_idx++) { - if (IsStreamingWindow(op.expressions[expr_idx])) { - streaming_windows.push_back(expr_idx); - } 
else { - blocking_windows.push_back(expr_idx); - } - } - - // Process the window functions by sharing the partition/order definitions - vector evaluation_order; - while (!blocking_windows.empty() || !streaming_windows.empty()) { - const bool process_streaming = blocking_windows.empty(); - auto &remaining = process_streaming ? streaming_windows : blocking_windows; - - // Find all functions that share the partitioning of the first remaining expression - const auto over_idx = remaining[0]; - auto &over_expr = op.expressions[over_idx]->Cast(); - - vector matching; - vector unprocessed; - for (const auto &expr_idx : remaining) { - D_ASSERT(op.expressions[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &wexpr = op.expressions[expr_idx]->Cast(); - if (over_expr.KeysAreCompatible(wexpr)) { - matching.emplace_back(expr_idx); - } else { - unprocessed.emplace_back(expr_idx); - } - } - remaining.swap(unprocessed); - - // Extract the matching expressions - vector> select_list; - for (const auto &expr_idx : matching) { - select_list.emplace_back(std::move(op.expressions[expr_idx])); - types.emplace_back(op.types[output_idx + expr_idx]); - } - - // Chain the new window operator on top of the plan - unique_ptr window; - if (process_streaming) { - window = make_uniq(types, std::move(select_list), op.estimated_cardinality); - } else { - window = make_uniq(types, std::move(select_list), op.estimated_cardinality); - } - window->children.push_back(std::move(plan)); - plan = std::move(window); - - // Remember the projection order if we changed it - if (!streaming_windows.empty() || !blocking_windows.empty() || !evaluation_order.empty()) { - evaluation_order.insert(evaluation_order.end(), matching.begin(), matching.end()); - } - } - - // Put everything back into place if it moved - if (!evaluation_order.empty()) { - vector> select_list(op.types.size()); - // The inputs don't move - for (idx_t i = 0; i < output_idx; ++i) { - select_list[i] = 
make_uniq(op.types[i], i); - } - // The outputs have been rearranged - for (idx_t i = 0; i < evaluation_order.size(); ++i) { - const auto expr_idx = evaluation_order[i] + output_idx; - select_list[expr_idx] = make_uniq(op.types[expr_idx], i + output_idx); - } - auto proj = make_uniq(op.types, std::move(select_list), op.estimated_cardinality); - proj->children.push_back(std::move(plan)); - plan = std::move(proj); - } - - return plan; -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -class DependencyExtractor : public LogicalOperatorVisitor { -public: - explicit DependencyExtractor(DependencyList &dependencies) : dependencies(dependencies) { - } - -protected: - unique_ptr VisitReplace(BoundFunctionExpression &expr, unique_ptr *expr_ptr) override { - // extract dependencies from the bound function expression - if (expr.function.dependency) { - expr.function.dependency(expr, dependencies); - } - return nullptr; - } - -private: - DependencyList &dependencies; -}; - -PhysicalPlanGenerator::PhysicalPlanGenerator(ClientContext &context) : context(context) { -} - -PhysicalPlanGenerator::~PhysicalPlanGenerator() { -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(unique_ptr op) { - auto &profiler = QueryProfiler::Get(context); - - // first resolve column references - profiler.StartPhase("column_binding"); - ColumnBindingResolver resolver; - resolver.VisitOperator(*op); - profiler.EndPhase(); - - // now resolve types of all the operators - profiler.StartPhase("resolve_types"); - op->ResolveOperatorTypes(); - profiler.EndPhase(); - - // extract dependencies from the logical plan - DependencyExtractor extractor(dependencies); - extractor.VisitOperator(*op); - - // then create the main physical plan - profiler.StartPhase("create_plan"); - auto plan = CreatePlan(*op); - profiler.EndPhase(); - - plan->Verify(); - return plan; -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalOperator &op) { - op.estimated_cardinality = 
op.EstimateCardinality(context); - unique_ptr plan = nullptr; - - switch (op.type) { - case LogicalOperatorType::LOGICAL_GET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_PROJECTION: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_EMPTY_RESULT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_FILTER: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_WINDOW: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_UNNEST: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_LIMIT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_SAMPLE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_ORDER_BY: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_TOP_N: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_COPY_TO_FILE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_DUMMY_SCAN: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_ANY_JOIN: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_UNION: - case LogicalOperatorType::LOGICAL_EXCEPT: - case LogicalOperatorType::LOGICAL_INTERSECT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_INSERT: - 
plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_DELETE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_CHUNK_GET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_DELIM_GET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_EXPRESSION_GET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_UPDATE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_CREATE_TABLE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_CREATE_INDEX: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_EXPLAIN: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_SHOW: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_DISTINCT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_PREPARE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_EXECUTE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_CREATE_VIEW: - case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: - case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: - case LogicalOperatorType::LOGICAL_CREATE_MACRO: - case LogicalOperatorType::LOGICAL_CREATE_TYPE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_PRAGMA: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_TRANSACTION: - case LogicalOperatorType::LOGICAL_ALTER: - case LogicalOperatorType::LOGICAL_DROP: - case LogicalOperatorType::LOGICAL_VACUUM: - case LogicalOperatorType::LOGICAL_LOAD: - case LogicalOperatorType::LOGICAL_ATTACH: - case LogicalOperatorType::LOGICAL_DETACH: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: - plan = CreatePlan(op.Cast()); - 
break; - case LogicalOperatorType::LOGICAL_CTE_REF: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_EXPORT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_SET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_RESET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_PIVOT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: - plan = op.Cast().CreatePlan(context, *this); - - if (!plan) { - throw InternalException("Missing PhysicalOperator for Extension Operator"); - } - break; - case LogicalOperatorType::LOGICAL_JOIN: - case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: - case LogicalOperatorType::LOGICAL_INVALID: { - throw NotImplementedException("Unimplemented logical operator type!"); - } - } - if (!plan) { - throw InternalException("Physical plan generator - no plan generated"); - } - - plan->estimated_cardinality = op.estimated_cardinality; - - return plan; -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -RadixPartitionedHashTable::RadixPartitionedHashTable(GroupingSet &grouping_set_p, const GroupedAggregateData &op_p) - : grouping_set(grouping_set_p), op(op_p) { - auto groups_count = op.GroupCount(); - for (idx_t i = 0; i < groups_count; i++) { - if (grouping_set.find(i) == grouping_set.end()) { - null_groups.push_back(i); - } - } - if (grouping_set.empty()) { - // Fake a single group with a constant value for aggregation without groups - group_types.emplace_back(LogicalType::TINYINT); - } - for (auto &entry : grouping_set) { - D_ASSERT(entry < op.group_types.size()); - group_types.push_back(op.group_types[entry]); - } - SetGroupingValues(); - - auto group_types_copy = group_types; - group_types_copy.emplace_back(LogicalType::HASH); - layout.Initialize(std::move(group_types_copy), AggregateObject::CreateAggregateObjects(op.bindings)); -} - -void 
RadixPartitionedHashTable::SetGroupingValues() { - // Compute the GROUPING values: - // For each parameter to the GROUPING clause, we check if the hash table groups on this particular group - // If it does, we return 0, otherwise we return 1 - // We then use bitshifts to combine these values - auto &grouping_functions = op.GetGroupingFunctions(); - for (auto &grouping : grouping_functions) { - int64_t grouping_value = 0; - D_ASSERT(grouping.size() < sizeof(int64_t) * 8); - for (idx_t i = 0; i < grouping.size(); i++) { - if (grouping_set.find(grouping[i]) == grouping_set.end()) { - // We don't group on this value! - grouping_value += (int64_t)1 << (grouping.size() - (i + 1)); - } - } - grouping_values.push_back(Value::BIGINT(grouping_value)); - } -} - -const TupleDataLayout &RadixPartitionedHashTable::GetLayout() const { - return layout; -} - -unique_ptr RadixPartitionedHashTable::CreateHT(ClientContext &context, const idx_t capacity, - const idx_t radix_bits) const { - return make_uniq(context, BufferAllocator::Get(context), group_types, op.payload_types, - op.bindings, capacity, radix_bits); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -struct AggregatePartition { - explicit AggregatePartition(unique_ptr data_p) : data(std::move(data_p)), finalized(false) { - } - unique_ptr data; - atomic finalized; -}; - -class RadixHTGlobalSinkState; - -struct RadixHTConfig { -public: - explicit RadixHTConfig(ClientContext &context, RadixHTGlobalSinkState &sink); - - void SetRadixBits(idx_t radix_bits_p); - bool SetRadixBitsToExternal(); - idx_t GetRadixBits() const; - -private: - void SetRadixBitsInternal(const idx_t radix_bits_p, bool external); - static idx_t InitialSinkRadixBits(ClientContext &context); - static idx_t MaximumSinkRadixBits(ClientContext &context); - static idx_t ExternalRadixBits(const idx_t &maximum_sink_radix_bits_p); - static 
idx_t SinkCapacity(ClientContext &context); - -private: - //! Assume (1 << 15) = 32KB L1 cache per core, divided by two because hyperthreading - static constexpr const idx_t L1_CACHE_SIZE = 32768 / 2; - //! Assume (1 << 20) = 1MB L2 cache per core, divided by two because hyperthreading - static constexpr const idx_t L2_CACHE_SIZE = 1048576 / 2; - //! Assume (1 << 20) + (1 << 19) = 1.5MB L3 cache per core (shared), divided by two because hyperthreading - static constexpr const idx_t L3_CACHE_SIZE = 1572864 / 2; - - //! Sink radix bits to initialize with - static constexpr const idx_t MAXIMUM_INITIAL_SINK_RADIX_BITS = 3; - //! Maximum Sink radix bits (independent of threads) - static constexpr const idx_t MAXIMUM_FINAL_SINK_RADIX_BITS = 7; - //! By how many radix bits to increment if we go external - static constexpr const idx_t EXTERNAL_RADIX_BITS_INCREMENT = 3; - - //! The global sink state - RadixHTGlobalSinkState &sink; - //! Current thread-global sink radix bits - atomic sink_radix_bits; - //! Maximum Sink radix bits (set based on number of threads) - const idx_t maximum_sink_radix_bits; - //! Radix bits if we go external - const idx_t external_radix_bits; - -public: - //! Capacity of HTs during the Sink - const idx_t sink_capacity; - - //! If we fill this many blocks per partition, we trigger a repartition - static constexpr const double BLOCK_FILL_FACTOR = 1.8; - //! By how many bits to repartition if a repartition is triggered - static constexpr const idx_t REPARTITION_RADIX_BITS = 2; -}; - -class RadixHTGlobalSinkState : public GlobalSinkState { -public: - RadixHTGlobalSinkState(ClientContext &context, const RadixPartitionedHashTable &radix_ht); - - //! Destroys aggregate states (if multi-scan) - ~RadixHTGlobalSinkState() override; - void Destroy(); - -public: - //! The radix HT - const RadixPartitionedHashTable &radix_ht; - //! Config for partitioning - RadixHTConfig config; - - //! Whether we've called Finalize - bool finalized; - //! 
Whether we are doing an external aggregation - atomic external; - //! Threads that have called Sink - atomic active_threads; - //! If any thread has called combine - atomic any_combined; - - //! Lock for uncombined_data/stored_allocators - mutex lock; - //! Uncombined partitioned data that will be put into the AggregatePartitions - unique_ptr uncombined_data; - //! Allocators used during the Sink/Finalize - vector> stored_allocators; - - //! Partitions that are finalized during GetData - vector> partitions; - - //! For synchronizing finalize tasks - atomic finalize_idx; - - //! Pin properties when scanning - TupleDataPinProperties scan_pin_properties; - //! Total count before combining - idx_t count_before_combining; -}; - -RadixHTGlobalSinkState::RadixHTGlobalSinkState(ClientContext &context, const RadixPartitionedHashTable &radix_ht_p) - : radix_ht(radix_ht_p), config(context, *this), finalized(false), external(false), active_threads(0), - any_combined(false), finalize_idx(0), scan_pin_properties(TupleDataPinProperties::DESTROY_AFTER_DONE), - count_before_combining(0) { -} - -RadixHTGlobalSinkState::~RadixHTGlobalSinkState() { - Destroy(); -} - -// LCOV_EXCL_START -void RadixHTGlobalSinkState::Destroy() { - if (scan_pin_properties == TupleDataPinProperties::DESTROY_AFTER_DONE || count_before_combining == 0 || - partitions.empty()) { - // Already destroyed / empty - return; - } - - TupleDataLayout layout = partitions[0]->data->GetLayout().Copy(); - if (!layout.HasDestructor()) { - return; // No destructors, exit - } - - // There are aggregates with destructors: Call the destructor for each of the aggregates - RowOperationsState row_state(*stored_allocators.back()); - for (auto &partition : partitions) { - auto &data_collection = *partition->data; - if (data_collection.Count() == 0) { - continue; - } - TupleDataChunkIterator iterator(data_collection, TupleDataPinProperties::DESTROY_AFTER_DONE, false); - auto &row_locations = iterator.GetChunkState().row_locations; 
- do { - RowOperations::DestroyStates(row_state, layout, row_locations, iterator.GetCurrentChunkCount()); - } while (iterator.Next()); - data_collection.Reset(); - } -} -// LCOV_EXCL_STOP - -RadixHTConfig::RadixHTConfig(ClientContext &context, RadixHTGlobalSinkState &sink_p) - : sink(sink_p), sink_radix_bits(InitialSinkRadixBits(context)), - maximum_sink_radix_bits(MaximumSinkRadixBits(context)), - external_radix_bits(ExternalRadixBits(maximum_sink_radix_bits)), sink_capacity(SinkCapacity(context)) { -} - -void RadixHTConfig::SetRadixBits(idx_t radix_bits_p) { - SetRadixBitsInternal(MinValue(radix_bits_p, maximum_sink_radix_bits), false); -} - -bool RadixHTConfig::SetRadixBitsToExternal() { - SetRadixBitsInternal(external_radix_bits, true); - return sink.external; -} - -idx_t RadixHTConfig::GetRadixBits() const { - return sink_radix_bits; -} - -void RadixHTConfig::SetRadixBitsInternal(const idx_t radix_bits_p, bool external) { - if (sink_radix_bits >= radix_bits_p || sink.any_combined) { - return; - } - - lock_guard guard(sink.lock); - if (sink_radix_bits >= radix_bits_p || sink.any_combined) { - return; - } - - if (external) { - sink.external = true; - } - sink_radix_bits = radix_bits_p; - return; -} - -idx_t RadixHTConfig::InitialSinkRadixBits(ClientContext &context) { - const idx_t active_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - return MinValue(RadixPartitioning::RadixBits(NextPowerOfTwo(active_threads)), MAXIMUM_INITIAL_SINK_RADIX_BITS); -} - -idx_t RadixHTConfig::MaximumSinkRadixBits(ClientContext &context) { - const idx_t active_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - return MinValue(RadixPartitioning::RadixBits(NextPowerOfTwo(active_threads)), MAXIMUM_FINAL_SINK_RADIX_BITS); -} - -idx_t RadixHTConfig::ExternalRadixBits(const idx_t &maximum_sink_radix_bits_p) { - return MinValue(maximum_sink_radix_bits_p + EXTERNAL_RADIX_BITS_INCREMENT, MAXIMUM_FINAL_SINK_RADIX_BITS); -} - -idx_t 
RadixHTConfig::SinkCapacity(ClientContext &context) { - // Get active and maximum number of threads - const idx_t active_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - const auto max_threads = DBConfig::GetSystemMaxThreads(FileSystem::GetFileSystem(context)); - - // Compute cache size per active thread (assuming cache is shared) - const auto total_shared_cache_size = max_threads * L3_CACHE_SIZE; - const auto cache_per_active_thread = L1_CACHE_SIZE + L2_CACHE_SIZE + total_shared_cache_size / active_threads; - - // Divide cache per active thread by entry size, round up to next power of two, to get capacity - const auto size_per_entry = sizeof(aggr_ht_entry_t) * GroupedAggregateHashTable::LOAD_FACTOR; - const auto capacity = NextPowerOfTwo(cache_per_active_thread / size_per_entry); - - // Capacity must be at least the minimum capacity - return MaxValue(capacity, GroupedAggregateHashTable::InitialCapacity()); -} - -class RadixHTLocalSinkState : public LocalSinkState { -public: - RadixHTLocalSinkState(ClientContext &context, const RadixPartitionedHashTable &radix_ht); - -public: - //! Thread-local HT that is re-used after abandoning - unique_ptr ht; - //! Chunk with group columns - DataChunk group_chunk; - - //! 
Data that is abandoned ends up here (only if we're doing external aggregation) - unique_ptr abandoned_data; -}; - -RadixHTLocalSinkState::RadixHTLocalSinkState(ClientContext &, const RadixPartitionedHashTable &radix_ht) { - // If there are no groups we create a fake group so everything has the same group - group_chunk.InitializeEmpty(radix_ht.group_types); - if (radix_ht.grouping_set.empty()) { - group_chunk.data[0].Reference(Value::TINYINT(42)); - } -} - -unique_ptr RadixPartitionedHashTable::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr RadixPartitionedHashTable::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - -void RadixPartitionedHashTable::PopulateGroupChunk(DataChunk &group_chunk, DataChunk &input_chunk) const { - idx_t chunk_index = 0; - // Populate the group_chunk - for (auto &group_idx : grouping_set) { - // Retrieve the expression containing the index in the input chunk - auto &group = op.groups[group_idx]; - D_ASSERT(group->type == ExpressionType::BOUND_REF); - auto &bound_ref_expr = group->Cast(); - // Reference from input_chunk[group.index] -> group_chunk[chunk_index] - group_chunk.data[chunk_index++].Reference(input_chunk.data[bound_ref_expr.index]); - } - group_chunk.SetCardinality(input_chunk.size()); - group_chunk.Verify(); -} - -bool MaybeRepartition(ClientContext &context, RadixHTGlobalSinkState &gstate, RadixHTLocalSinkState &lstate) { - auto &config = gstate.config; - auto &ht = *lstate.ht; - auto &partitioned_data = ht.GetPartitionedData(); - - // Check if we're approaching the memory limit - const idx_t n_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - const idx_t limit = BufferManager::GetBufferManager(context).GetMaxMemory(); - const idx_t thread_limit = 0.6 * limit / n_threads; - if (ht.GetPartitionedData()->SizeInBytes() > thread_limit || context.config.force_external) { - if 
(gstate.config.SetRadixBitsToExternal()) { - // We're approaching the memory limit, unpin the data - if (!lstate.abandoned_data) { - lstate.abandoned_data = make_uniq( - BufferManager::GetBufferManager(context), gstate.radix_ht.GetLayout(), config.GetRadixBits(), - gstate.radix_ht.GetLayout().ColumnCount() - 1); - } - - ht.UnpinData(); - partitioned_data->Repartition(*lstate.abandoned_data); - ht.SetRadixBits(gstate.config.GetRadixBits()); - ht.InitializePartitionedData(); - return true; - } - } - - const auto partition_count = partitioned_data->PartitionCount(); - const auto current_radix_bits = RadixPartitioning::RadixBits(partition_count); - D_ASSERT(current_radix_bits <= config.GetRadixBits()); - - const auto row_size_per_partition = - partitioned_data->Count() * partitioned_data->GetLayout().GetRowWidth() / partition_count; - if (row_size_per_partition > config.BLOCK_FILL_FACTOR * Storage::BLOCK_SIZE) { - // We crossed our block filling threshold, try to increment radix bits - config.SetRadixBits(current_radix_bits + config.REPARTITION_RADIX_BITS); - } - - const auto global_radix_bits = config.GetRadixBits(); - if (current_radix_bits == global_radix_bits) { - return false; // We're already on the right number of radix bits - } - - // We're out-of-sync with the global radix bits, repartition - ht.UnpinData(); - auto old_partitioned_data = std::move(partitioned_data); - ht.SetRadixBits(global_radix_bits); - ht.InitializePartitionedData(); - old_partitioned_data->Repartition(*ht.GetPartitionedData()); - return true; -} - -void RadixPartitionedHashTable::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input, - DataChunk &payload_input, const unsafe_vector &filter) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - if (!lstate.ht) { - lstate.ht = CreateHT(context.client, gstate.config.sink_capacity, gstate.config.GetRadixBits()); - gstate.active_threads++; - } - - auto &group_chunk = 
lstate.group_chunk; - PopulateGroupChunk(group_chunk, chunk); - - auto &ht = *lstate.ht; - ht.AddChunk(group_chunk, payload_input, filter); - - if (ht.Count() + STANDARD_VECTOR_SIZE < ht.ResizeThreshold()) { - return; // We can fit another chunk - } - - if (gstate.active_threads > 2) { - // 'Reset' the HT without taking its data, we can just keep appending to the same collection - // This only works because we never resize the HT - ht.ClearPointerTable(); - ht.ResetCount(); - // We don't do this when running with 1 or 2 threads, it only makes sense when there's many threads - } - - // Check if we need to repartition - auto repartitioned = MaybeRepartition(context.client, gstate, lstate); - - if (repartitioned && ht.Count() != 0) { - // We repartitioned, but we didn't clear the pointer table / reset the count because we're on 1 or 2 threads - ht.ClearPointerTable(); - ht.ResetCount(); - } - - // TODO: combine early and often -} - -void RadixPartitionedHashTable::Combine(ExecutionContext &context, GlobalSinkState &gstate_p, - LocalSinkState &lstate_p) const { - auto &gstate = gstate_p.Cast(); - auto &lstate = lstate_p.Cast(); - if (!lstate.ht) { - return; - } - - // Set any_combined, then check one last time whether we need to repartition - gstate.any_combined = true; - MaybeRepartition(context.client, gstate, lstate); - - auto &ht = *lstate.ht; - ht.UnpinData(); - - if (lstate.abandoned_data) { - D_ASSERT(gstate.external); - D_ASSERT(lstate.abandoned_data->PartitionCount() == lstate.ht->GetPartitionedData()->PartitionCount()); - D_ASSERT(lstate.abandoned_data->PartitionCount() == - RadixPartitioning::NumberOfPartitions(gstate.config.GetRadixBits())); - lstate.abandoned_data->Combine(*lstate.ht->GetPartitionedData()); - } else { - lstate.abandoned_data = std::move(ht.GetPartitionedData()); - } - - lock_guard guard(gstate.lock); - if (gstate.uncombined_data) { - gstate.uncombined_data->Combine(*lstate.abandoned_data); - } else { - gstate.uncombined_data = 
std::move(lstate.abandoned_data); - } - gstate.stored_allocators.emplace_back(ht.GetAggregateAllocator()); -} - -void RadixPartitionedHashTable::Finalize(ClientContext &, GlobalSinkState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - - if (gstate.uncombined_data) { - auto &uncombined_data = *gstate.uncombined_data; - gstate.count_before_combining = uncombined_data.Count(); - - // If true there is no need to combine, it was all done by a single thread in a single HT - const auto single_ht = !gstate.external && gstate.active_threads == 1; - - auto &uncombined_partition_data = uncombined_data.GetPartitions(); - const auto n_partitions = uncombined_partition_data.size(); - gstate.partitions.reserve(n_partitions); - for (idx_t i = 0; i < n_partitions; i++) { - gstate.partitions.emplace_back(make_uniq(std::move(uncombined_partition_data[i]))); - if (single_ht) { - gstate.finalize_idx++; - gstate.partitions.back()->finalized = true; - } - } - } else { - gstate.count_before_combining = 0; - } - - gstate.finalized = true; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -idx_t RadixPartitionedHashTable::NumberOfPartitions(GlobalSinkState &sink_p) const { - auto &sink = sink_p.Cast(); - return sink.partitions.size(); -} - -void RadixPartitionedHashTable::SetMultiScan(GlobalSinkState &sink_p) { - auto &sink = sink_p.Cast(); - sink.scan_pin_properties = TupleDataPinProperties::UNPIN_AFTER_DONE; -} - -enum class RadixHTSourceTaskType : uint8_t { NO_TASK, FINALIZE, SCAN }; - -class RadixHTLocalSourceState; - -class RadixHTGlobalSourceState : public GlobalSourceState { -public: - RadixHTGlobalSourceState(ClientContext &context, const RadixPartitionedHashTable &radix_ht); - - //! Assigns a task to a local source state - bool AssignTask(RadixHTGlobalSinkState &sink, RadixHTLocalSourceState &lstate); - -public: - //! 
The client context - ClientContext &context; - //! For synchronizing the source phase - atomic finished; - - //! Column ids for scanning - vector column_ids; - - //! For synchronizing scan tasks - atomic scan_idx; - atomic scan_done; -}; - -enum class RadixHTScanStatus : uint8_t { INIT, IN_PROGRESS, DONE }; - -class RadixHTLocalSourceState : public LocalSourceState { -public: - explicit RadixHTLocalSourceState(ExecutionContext &context, const RadixPartitionedHashTable &radix_ht); - -public: - //! Do the work this thread has been assigned - void ExecuteTask(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, DataChunk &chunk); - //! Whether this thread has finished the work it has been assigned - bool TaskFinished(); - -private: - //! Execute the finalize or scan task - void Finalize(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate); - void Scan(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, DataChunk &chunk); - -public: - //! Current task and index - RadixHTSourceTaskType task; - idx_t task_idx; - - //! Thread-local HT that is re-used to Finalize - unique_ptr ht; - //! Current status of a Scan - RadixHTScanStatus scan_status; - -private: - //! Allocator and layout for finalizing state - TupleDataLayout layout; - ArenaAllocator aggregate_allocator; - - //! 
State and chunk for scanning - TupleDataScanState scan_state; - DataChunk scan_chunk; -}; - -unique_ptr RadixPartitionedHashTable::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr RadixPartitionedHashTable::GetLocalSourceState(ExecutionContext &context) const { - return make_uniq(context, *this); -} - -RadixHTGlobalSourceState::RadixHTGlobalSourceState(ClientContext &context_p, const RadixPartitionedHashTable &radix_ht) - : context(context_p), finished(false), scan_idx(0), scan_done(0) { - for (column_t column_id = 0; column_id < radix_ht.group_types.size(); column_id++) { - column_ids.push_back(column_id); - } -} - -bool RadixHTGlobalSourceState::AssignTask(RadixHTGlobalSinkState &sink, RadixHTLocalSourceState &lstate) { - D_ASSERT(lstate.scan_status != RadixHTScanStatus::IN_PROGRESS); - - const auto n_partitions = sink.partitions.size(); - if (finished) { - return false; - } - // We first try to assign a Scan task, then a Finalize task if that didn't work, without using any locks - - // We need an atomic compare-and-swap to assign a Scan task, because we need to only increment - // the 'scan_idx' atomic if the 'finalize' of that partition is true, i.e., ready to be scanned - bool scan_assigned = true; - do { - lstate.task_idx = scan_idx.load(); - if (lstate.task_idx >= n_partitions || !sink.partitions[lstate.task_idx]->finalized) { - scan_assigned = false; - break; - } - } while (!std::atomic_compare_exchange_weak(&scan_idx, &lstate.task_idx, lstate.task_idx + 1)); - - if (scan_assigned) { - // We successfully assigned a Scan task - D_ASSERT(lstate.task_idx < n_partitions && sink.partitions[lstate.task_idx]->finalized); - lstate.task = RadixHTSourceTaskType::SCAN; - lstate.scan_status = RadixHTScanStatus::INIT; - return true; - } - - // We didn't assign a Scan task - if (sink.finalize_idx >= n_partitions) { - return false; // No finalize tasks left - } - - // We can just increment the atomic here, much 
simpler than assigning the scan task - lstate.task_idx = sink.finalize_idx++; - if (lstate.task_idx < n_partitions) { - // We successfully assigned a Finalize task - lstate.task = RadixHTSourceTaskType::FINALIZE; - return true; - } - - // We didn't manage to assign a Finalize task - return false; -} - -RadixHTLocalSourceState::RadixHTLocalSourceState(ExecutionContext &context, const RadixPartitionedHashTable &radix_ht) - : task(RadixHTSourceTaskType::NO_TASK), scan_status(RadixHTScanStatus::DONE), layout(radix_ht.GetLayout().Copy()), - aggregate_allocator(BufferAllocator::Get(context.client)) { - auto &allocator = BufferAllocator::Get(context.client); - auto scan_chunk_types = radix_ht.group_types; - for (auto &aggr_type : radix_ht.op.aggregate_return_types) { - scan_chunk_types.push_back(aggr_type); - } - scan_chunk.Initialize(allocator, scan_chunk_types); -} - -void RadixHTLocalSourceState::ExecuteTask(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, - DataChunk &chunk) { - switch (task) { - case RadixHTSourceTaskType::FINALIZE: - Finalize(sink, gstate); - break; - case RadixHTSourceTaskType::SCAN: - Scan(sink, gstate, chunk); - break; - default: - throw InternalException("Unexpected RadixHTSourceTaskType in ExecuteTask!"); - } -} - -void RadixHTLocalSourceState::Finalize(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate) { - D_ASSERT(task == RadixHTSourceTaskType::FINALIZE); - D_ASSERT(scan_status != RadixHTScanStatus::IN_PROGRESS); - - auto &partition = *sink.partitions[task_idx]; - if (partition.data->Count() == 0) { - partition.finalized = true; - return; - } - - if (!ht) { - // Create a HT with sufficient capacity - const auto capacity = GroupedAggregateHashTable::GetCapacityForCount(partition.data->Count()); - ht = sink.radix_ht.CreateHT(gstate.context, capacity, 0); - } else { - // We may want to resize here to the size of this partition, but for now we just assume uniform partition sizes - ht->InitializePartitionedData(); - 
ht->ClearPointerTable(); - ht->ResetCount(); - } - - // Now combine the uncombined data using this thread's HT - ht->Combine(*partition.data); - ht->UnpinData(); - - // Move the combined data back to the partition - partition.data = - make_uniq(BufferManager::GetBufferManager(gstate.context), sink.radix_ht.GetLayout()); - partition.data->Combine(*ht->GetPartitionedData()->GetPartitions()[0]); - - // Mark partition as ready to scan - partition.finalized = true; - - // Make sure this thread's aggregate allocator does not get lost - lock_guard guard(sink.lock); - sink.stored_allocators.emplace_back(ht->GetAggregateAllocator()); -} - -void RadixHTLocalSourceState::Scan(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, DataChunk &chunk) { - D_ASSERT(task == RadixHTSourceTaskType::SCAN); - D_ASSERT(scan_status != RadixHTScanStatus::DONE); - - auto &partition = *sink.partitions[task_idx]; - D_ASSERT(partition.finalized); - auto &data_collection = *partition.data; - - if (data_collection.Count() == 0) { - scan_status = RadixHTScanStatus::DONE; - if (++gstate.scan_done == sink.partitions.size()) { - gstate.finished = true; - } - return; - } - - if (scan_status == RadixHTScanStatus::INIT) { - data_collection.InitializeScan(scan_state, gstate.column_ids, sink.scan_pin_properties); - scan_status = RadixHTScanStatus::IN_PROGRESS; - } - - if (!data_collection.Scan(scan_state, scan_chunk)) { - scan_status = RadixHTScanStatus::DONE; - if (sink.scan_pin_properties == TupleDataPinProperties::DESTROY_AFTER_DONE) { - data_collection.Reset(); - } - return; - } - - if (data_collection.ScanComplete(scan_state)) { - if (++gstate.scan_done == sink.partitions.size()) { - gstate.finished = true; - } - } - - RowOperationsState row_state(aggregate_allocator); - const auto group_cols = layout.ColumnCount() - 1; - RowOperations::FinalizeStates(row_state, layout, scan_state.chunk_state.row_locations, scan_chunk, group_cols); - - if (sink.scan_pin_properties == 
TupleDataPinProperties::DESTROY_AFTER_DONE && layout.HasDestructor()) { - RowOperations::DestroyStates(row_state, layout, scan_state.chunk_state.row_locations, scan_chunk.size()); - } - - auto &radix_ht = sink.radix_ht; - idx_t chunk_index = 0; - for (auto &entry : radix_ht.grouping_set) { - chunk.data[entry].Reference(scan_chunk.data[chunk_index++]); - } - for (auto null_group : radix_ht.null_groups) { - chunk.data[null_group].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[null_group], true); - } - D_ASSERT(radix_ht.grouping_set.size() + radix_ht.null_groups.size() == radix_ht.op.GroupCount()); - for (idx_t col_idx = 0; col_idx < radix_ht.op.aggregates.size(); col_idx++) { - chunk.data[radix_ht.op.GroupCount() + col_idx].Reference( - scan_chunk.data[radix_ht.group_types.size() + col_idx]); - } - D_ASSERT(radix_ht.op.grouping_functions.size() == radix_ht.grouping_values.size()); - for (idx_t i = 0; i < radix_ht.op.grouping_functions.size(); i++) { - chunk.data[radix_ht.op.GroupCount() + radix_ht.op.aggregates.size() + i].Reference(radix_ht.grouping_values[i]); - } - chunk.SetCardinality(scan_chunk); - D_ASSERT(chunk.size() != 0); -} - -bool RadixHTLocalSourceState::TaskFinished() { - switch (task) { - case RadixHTSourceTaskType::FINALIZE: - return true; - case RadixHTSourceTaskType::SCAN: - return scan_status == RadixHTScanStatus::DONE; - default: - D_ASSERT(task == RadixHTSourceTaskType::NO_TASK); - return true; - } -} - -SourceResultType RadixPartitionedHashTable::GetData(ExecutionContext &context, DataChunk &chunk, - GlobalSinkState &sink_p, OperatorSourceInput &input) const { - auto &sink = sink_p.Cast(); - D_ASSERT(sink.finalized); - - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - D_ASSERT(sink.scan_pin_properties == TupleDataPinProperties::UNPIN_AFTER_DONE || - sink.scan_pin_properties == TupleDataPinProperties::DESTROY_AFTER_DONE); - - if (gstate.finished) { - return 
SourceResultType::FINISHED; - } - - if (sink.count_before_combining == 0) { - if (grouping_set.empty()) { - // Special case hack to sort out aggregating from empty intermediates for aggregations without groups - D_ASSERT(chunk.ColumnCount() == null_groups.size() + op.aggregates.size() + op.grouping_functions.size()); - // For each column in the aggregates, set to initial state - chunk.SetCardinality(1); - for (auto null_group : null_groups) { - chunk.data[null_group].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[null_group], true); - } - ArenaAllocator allocator(BufferAllocator::Get(context.client)); - for (idx_t i = 0; i < op.aggregates.size(); i++) { - D_ASSERT(op.aggregates[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); - auto &aggr = op.aggregates[i]->Cast(); - auto aggr_state = make_unsafe_uniq_array(aggr.function.state_size()); - aggr.function.initialize(aggr_state.get()); - - AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); - Vector state_vector(Value::POINTER(CastPointerToValue(aggr_state.get()))); - aggr.function.finalize(state_vector, aggr_input_data, chunk.data[null_groups.size() + i], 1, 0); - if (aggr.function.destructor) { - aggr.function.destructor(state_vector, aggr_input_data, 1); - } - } - // Place the grouping values (all the groups of the grouping_set condensed into a single value) - // Behind the null groups + aggregates - for (idx_t i = 0; i < op.grouping_functions.size(); i++) { - chunk.data[null_groups.size() + op.aggregates.size() + i].Reference(grouping_values[i]); - } - } - gstate.finished = true; - return SourceResultType::FINISHED; - } - - while (!gstate.finished && chunk.size() == 0) { - if (!lstate.TaskFinished() || gstate.AssignTask(sink, lstate)) { - lstate.ExecuteTask(sink, gstate, chunk); - } - } - - if (chunk.size() != 0) { - return SourceResultType::HAVE_MORE_OUTPUT; - } else { - return SourceResultType::FINISHED; - } -} - -} // namespace duckdb - - - 
-namespace duckdb { - -ReservoirSample::ReservoirSample(Allocator &allocator, idx_t sample_count, int64_t seed) - : BlockingSample(seed), sample_count(sample_count), reservoir(allocator) { -} - -void ReservoirSample::AddToReservoir(DataChunk &input) { - if (sample_count == 0) { - return; - } - // Input: A population V of n weighted items - // Output: A reservoir R with a size m - // 1: The first m items of V are inserted into R - // first we need to check if the reservoir already has "m" elements - if (reservoir.Count() < sample_count) { - if (FillReservoir(input) == 0) { - // entire chunk was consumed by reservoir - return; - } - } - // find the position of next_index relative to current_count - idx_t remaining = input.size(); - idx_t base_offset = 0; - while (true) { - idx_t offset = base_reservoir_sample.next_index - base_reservoir_sample.current_count; - if (offset >= remaining) { - // not in this chunk! increment current count and go to the next chunk - base_reservoir_sample.current_count += remaining; - return; - } - // in this chunk! replace the element - ReplaceElement(input, base_offset + offset); - // shift the chunk forward - remaining -= offset; - base_offset += offset; - } -} - -unique_ptr ReservoirSample::GetChunk() { - return reservoir.Fetch(); -} - -void ReservoirSample::ReplaceElement(DataChunk &input, idx_t index_in_chunk) { - // replace the entry in the reservoir - // 8. 
The item in R with the minimum key is replaced by item vi - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { - reservoir.SetValue(col_idx, base_reservoir_sample.min_entry, input.GetValue(col_idx, index_in_chunk)); - } - base_reservoir_sample.ReplaceElement(); -} - -idx_t ReservoirSample::FillReservoir(DataChunk &input) { - idx_t chunk_count = input.size(); - input.Flatten(); - - // we have not: append to the reservoir - idx_t required_count; - if (reservoir.Count() + chunk_count >= sample_count) { - // have to limit the count of the chunk - required_count = sample_count - reservoir.Count(); - } else { - // we copy the entire chunk - required_count = chunk_count; - } - // instead of copying we just change the pointer in the current chunk - input.SetCardinality(required_count); - reservoir.Append(input); - - base_reservoir_sample.InitializeReservoir(reservoir.Count(), sample_count); - - // check if there are still elements remaining - // this happens if we are on a boundary - // for example, input.size() is 1024, but our sample size is 10 - if (required_count == chunk_count) { - // we are done here - return 0; - } - // we still need to process a part of the chunk - // create a selection vector of the remaining elements - SelectionVector sel(STANDARD_VECTOR_SIZE); - for (idx_t i = required_count; i < chunk_count; i++) { - sel.set_index(i - required_count, i); - } - // slice the input vector and continue - input.Slice(sel, chunk_count - required_count); - return input.size(); -} - -ReservoirSamplePercentage::ReservoirSamplePercentage(Allocator &allocator, double percentage, int64_t seed) - : BlockingSample(seed), allocator(allocator), sample_percentage(percentage / 100.0), current_count(0), - is_finalized(false) { - reservoir_sample_size = idx_t(sample_percentage * RESERVOIR_THRESHOLD); - current_sample = make_uniq(allocator, reservoir_sample_size, random.NextRandomInteger()); -} - -void ReservoirSamplePercentage::AddToReservoir(DataChunk &input) { - 
if (current_count + input.size() > RESERVOIR_THRESHOLD) { - // we don't have enough space in our current reservoir - // first check what we still need to append to the current sample - idx_t append_to_current_sample_count = RESERVOIR_THRESHOLD - current_count; - idx_t append_to_next_sample = input.size() - append_to_current_sample_count; - if (append_to_current_sample_count > 0) { - // we have elements remaining, first add them to the current sample - if (append_to_next_sample > 0) { - // we need to also add to the next sample - DataChunk new_chunk; - new_chunk.InitializeEmpty(input.GetTypes()); - new_chunk.Slice(input, *FlatVector::IncrementalSelectionVector(), append_to_current_sample_count); - new_chunk.Flatten(); - current_sample->AddToReservoir(new_chunk); - } else { - input.Flatten(); - input.SetCardinality(append_to_current_sample_count); - current_sample->AddToReservoir(input); - } - } - if (append_to_next_sample > 0) { - // slice the input for the remainder - SelectionVector sel(append_to_next_sample); - for (idx_t i = 0; i < append_to_next_sample; i++) { - sel.set_index(i, append_to_current_sample_count + i); - } - input.Slice(sel, append_to_next_sample); - } - // now our first sample is filled: append it to the set of finished samples - finished_samples.push_back(std::move(current_sample)); - - // allocate a new sample, and potentially add the remainder of the current input to that sample - current_sample = make_uniq(allocator, reservoir_sample_size, random.NextRandomInteger()); - if (append_to_next_sample > 0) { - current_sample->AddToReservoir(input); - } - current_count = append_to_next_sample; - } else { - // we can just append to the current sample - current_count += input.size(); - current_sample->AddToReservoir(input); - } -} - -unique_ptr ReservoirSamplePercentage::GetChunk() { - if (!is_finalized) { - Finalize(); - } - while (!finished_samples.empty()) { - auto &front = finished_samples.front(); - auto chunk = front->GetChunk(); - if (chunk && 
chunk->size() > 0) { - return chunk; - } - // move to the next sample - finished_samples.erase(finished_samples.begin()); - } - return nullptr; -} - -void ReservoirSamplePercentage::Finalize() { - // need to finalize the current sample, if any - if (current_count > 0) { - // create a new sample - auto new_sample_size = idx_t(round(sample_percentage * current_count)); - auto new_sample = make_uniq(allocator, new_sample_size, random.NextRandomInteger()); - while (true) { - auto chunk = current_sample->GetChunk(); - if (!chunk || chunk->size() == 0) { - break; - } - new_sample->AddToReservoir(*chunk); - } - finished_samples.push_back(std::move(new_sample)); - } - is_finalized = true; -} - -BaseReservoirSampling::BaseReservoirSampling(int64_t seed) : random(seed) { - next_index = 0; - min_threshold = 0; - min_entry = 0; - current_count = 0; -} - -BaseReservoirSampling::BaseReservoirSampling() : BaseReservoirSampling(-1) { -} - -void BaseReservoirSampling::InitializeReservoir(idx_t cur_size, idx_t sample_size) { - //! 1: The first m items of V are inserted into R - //! first we need to check if the reservoir already has "m" elements - if (cur_size == sample_size) { - //! 2. For each item vi ∈ R: Calculate a key ki = random(0, 1) - //! we then define the threshold to enter the reservoir T_w as the minimum key of R - //! we use a priority queue to extract the minimum key in O(1) time - for (idx_t i = 0; i < sample_size; i++) { - double k_i = random.NextRandom(); - reservoir_weights.emplace(-k_i, i); - } - SetNextEntry(); - } -} - -void BaseReservoirSampling::SetNextEntry() { - //! 4. Let r = random(0, 1) and Xw = log(r) / log(T_w) - auto &min_key = reservoir_weights.top(); - double t_w = -min_key.first; - double r = random.NextRandom(); - double x_w = log(r) / log(t_w); - //! 5. From the current item vc skip items until item vi , such that: - //! 6. wc +wc+1 +···+wi−1 < Xw <= wc +wc+1 +···+wi−1 +wi - //! 
since all our weights are 1 (uniform sampling), we can just determine the amount of elements to skip - min_threshold = t_w; - min_entry = min_key.second; - next_index = MaxValue(1, idx_t(round(x_w))); - current_count = 0; -} - -void BaseReservoirSampling::ReplaceElement() { - //! replace the entry in the reservoir - //! pop the minimum entry - reservoir_weights.pop(); - //! now update the reservoir - //! 8. Let tw = Tw i , r2 = random(tw,1) and vi’s key: ki = (r2)1/wi - //! 9. The new threshold Tw is the new minimum key of R - //! we generate a random number between (min_threshold, 1) - double r2 = random.NextRandom(min_threshold, 1); - //! now we insert the new weight into the reservoir - reservoir_weights.emplace(-r2, min_entry); - //! we update the min entry with the new min entry in the reservoir - SetNextEntry(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -static idx_t FindNextStart(const ValidityMask &mask, idx_t l, const idx_t r, idx_t &n) { - if (mask.AllValid()) { - auto start = MinValue(l + n - 1, r); - n -= MinValue(n, r - l); - return start; - } - - while (l < r) { - // If l is aligned with the start of a block, and the block is blank, then skip forward one block. - idx_t entry_idx; - idx_t shift; - mask.GetEntryIndex(l, entry_idx, shift); - - const auto block = mask.GetValidityEntry(entry_idx); - if (mask.NoneValid(block) && !shift) { - l += ValidityMask::BITS_PER_VALUE; - continue; - } - - // Loop over the block - for (; shift < ValidityMask::BITS_PER_VALUE && l < r; ++shift, ++l) { - if (mask.RowIsValid(block, shift) && --n == 0) { - return MinValue(l, r); - } - } - } - - // Didn't find a start so return the end of the range - return r; -} - -static idx_t FindPrevStart(const ValidityMask &mask, const idx_t l, idx_t r, idx_t &n) { - if (mask.AllValid()) { - auto start = (r <= l + n) ? 
l : r - n; - n -= r - start; - return start; - } - - while (l < r) { - // If r is aligned with the start of a block, and the previous block is blank, - // then skip backwards one block. - idx_t entry_idx; - idx_t shift; - mask.GetEntryIndex(r - 1, entry_idx, shift); - - const auto block = mask.GetValidityEntry(entry_idx); - if (mask.NoneValid(block) && (shift + 1 == ValidityMask::BITS_PER_VALUE)) { - // r is nonzero (> l) and word aligned, so this will not underflow. - r -= ValidityMask::BITS_PER_VALUE; - continue; - } - - // Loop backwards over the block - // shift is probing r-1 >= l >= 0 - for (++shift; shift-- > 0; --r) { - if (mask.RowIsValid(block, shift) && --n == 0) { - return MaxValue(l, r - 1); - } - } - } - - // Didn't find a start so return the start of the range - return l; -} - -template -static T GetCell(const DataChunk &chunk, idx_t column, idx_t index) { - D_ASSERT(chunk.ColumnCount() > column); - auto &source = chunk.data[column]; - const auto data = FlatVector::GetData(source); - return data[index]; -} - -static bool CellIsNull(const DataChunk &chunk, idx_t column, idx_t index) { - D_ASSERT(chunk.ColumnCount() > column); - auto &source = chunk.data[column]; - return FlatVector::IsNull(source, index); -} - -static void CopyCell(const DataChunk &chunk, idx_t column, idx_t index, Vector &target, idx_t target_offset) { - D_ASSERT(chunk.ColumnCount() > column); - auto &source = chunk.data[column]; - VectorOperations::Copy(source, target, index + 1, index, target_offset); -} - -//===--------------------------------------------------------------------===// -// WindowColumnIterator -//===--------------------------------------------------------------------===// -template -struct WindowColumnIterator { - using iterator = WindowColumnIterator; - using iterator_category = std::random_access_iterator_tag; - using difference_type = std::ptrdiff_t; - using value_type = T; - using reference = T; - using pointer = idx_t; - - explicit WindowColumnIterator(const 
WindowInputColumn &coll_p, pointer pos_p = 0) : coll(&coll_p), pos(pos_p) { - } - - // Forward iterator - inline reference operator*() const { - return coll->GetCell(pos); - } - inline explicit operator pointer() const { - return pos; - } - - inline iterator &operator++() { - ++pos; - return *this; - } - inline iterator operator++(int) { - auto result = *this; - ++(*this); - return result; - } - - // Bidirectional iterator - inline iterator &operator--() { - --pos; - return *this; - } - inline iterator operator--(int) { - auto result = *this; - --(*this); - return result; - } - - // Random Access - inline iterator &operator+=(difference_type n) { - pos += n; - return *this; - } - inline iterator &operator-=(difference_type n) { - pos -= n; - return *this; - } - - inline reference operator[](difference_type m) const { - return coll->GetCell(pos + m); - } - - friend inline iterator &operator+(const iterator &a, difference_type n) { - return iterator(a.coll, a.pos + n); - } - - friend inline iterator &operator-(const iterator &a, difference_type n) { - return iterator(a.coll, a.pos - n); - } - - friend inline iterator &operator+(difference_type n, const iterator &a) { - return a + n; - } - friend inline difference_type operator-(const iterator &a, const iterator &b) { - return difference_type(a.pos - b.pos); - } - - friend inline bool operator==(const iterator &a, const iterator &b) { - return a.pos == b.pos; - } - friend inline bool operator!=(const iterator &a, const iterator &b) { - return a.pos != b.pos; - } - friend inline bool operator<(const iterator &a, const iterator &b) { - return a.pos < b.pos; - } - friend inline bool operator<=(const iterator &a, const iterator &b) { - return a.pos <= b.pos; - } - friend inline bool operator>(const iterator &a, const iterator &b) { - return a.pos > b.pos; - } - friend inline bool operator>=(const iterator &a, const iterator &b) { - return a.pos >= b.pos; - } - -private: - optional_ptr coll; - pointer pos; -}; - -template 
-struct OperationCompare : public std::function { - inline bool operator()(const T &lhs, const T &val) const { - return OP::template Operation(lhs, val); - } -}; - -template -static idx_t FindTypedRangeBound(const WindowInputColumn &over, const idx_t order_begin, const idx_t order_end, - WindowInputExpression &boundary, const idx_t chunk_idx, const FrameBounds &prev) { - D_ASSERT(!boundary.CellIsNull(chunk_idx)); - const auto val = boundary.GetCell(chunk_idx); - - OperationCompare comp; - WindowColumnIterator begin(over, order_begin); - WindowColumnIterator end(over, order_end); - - if (order_begin < prev.start && prev.start < order_end) { - const auto first = over.GetCell(prev.start); - if (!comp(val, first)) { - // prev.first <= val, so we can start further forward - begin += (prev.start - order_begin); - } - } - if (order_begin <= prev.end && prev.end < order_end) { - const auto second = over.GetCell(prev.end); - if (!comp(second, val)) { - // val <= prev.second, so we can end further back - // (prev.second is the largest peer) - end -= (order_end - prev.end - 1); - } - } - - if (FROM) { - return idx_t(std::lower_bound(begin, end, val, comp)); - } else { - return idx_t(std::upper_bound(begin, end, val, comp)); - } -} - -template -static idx_t FindRangeBound(const WindowInputColumn &over, const idx_t order_begin, const idx_t order_end, - WindowInputExpression &boundary, const idx_t chunk_idx, const FrameBounds &prev) { - D_ASSERT(boundary.chunk.ColumnCount() == 1); - D_ASSERT(boundary.chunk.data[0].GetType().InternalType() == over.input_expr.ptype); - - switch (over.input_expr.ptype) { - case PhysicalType::INT8: - return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - case PhysicalType::INT16: - return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - case PhysicalType::INT32: - return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - case PhysicalType::INT64: - return 
FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - case PhysicalType::UINT8: - return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - case PhysicalType::UINT16: - return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - case PhysicalType::UINT32: - return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - case PhysicalType::UINT64: - return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - case PhysicalType::INT128: - return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - case PhysicalType::FLOAT: - return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - case PhysicalType::DOUBLE: - return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - case PhysicalType::INTERVAL: - return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - default: - throw InternalException("Unsupported column type for RANGE"); - } -} - -template -static idx_t FindOrderedRangeBound(const WindowInputColumn &over, const OrderType range_sense, const idx_t order_begin, - const idx_t order_end, WindowInputExpression &boundary, const idx_t chunk_idx, - const FrameBounds &prev) { - switch (range_sense) { - case OrderType::ASCENDING: - return FindRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - case OrderType::DESCENDING: - return FindRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); - default: - throw InternalException("Unsupported ORDER BY sense for RANGE"); - } -} - -struct WindowBoundariesState { - static inline bool IsScalar(const unique_ptr &expr) { - return expr ? 
expr->IsScalar() : true; - } - - static inline bool BoundaryNeedsPeer(const WindowBoundary &boundary) { - switch (boundary) { - case WindowBoundary::CURRENT_ROW_RANGE: - case WindowBoundary::EXPR_PRECEDING_RANGE: - case WindowBoundary::EXPR_FOLLOWING_RANGE: - return true; - default: - return false; - } - } - - WindowBoundariesState(BoundWindowExpression &wexpr, const idx_t input_size); - - void Update(const idx_t row_idx, const WindowInputColumn &range_collection, const idx_t chunk_idx, - WindowInputExpression &boundary_start, WindowInputExpression &boundary_end, - const ValidityMask &partition_mask, const ValidityMask &order_mask); - - void Bounds(DataChunk &bounds, idx_t row_idx, const WindowInputColumn &range, const idx_t count, - WindowInputExpression &boundary_start, WindowInputExpression &boundary_end, - const ValidityMask &partition_mask, const ValidityMask &order_mask); - - // Cached lookups - const ExpressionType type; - const idx_t input_size; - const WindowBoundary start_boundary; - const WindowBoundary end_boundary; - const size_t partition_count; - const size_t order_count; - const OrderType range_sense; - const bool has_preceding_range; - const bool has_following_range; - const bool needs_peer; - - idx_t next_pos = 0; - idx_t partition_start = 0; - idx_t partition_end = 0; - idx_t peer_start = 0; - idx_t peer_end = 0; - idx_t valid_start = 0; - idx_t valid_end = 0; - int64_t window_start = -1; - int64_t window_end = -1; - FrameBounds prev; -}; - -//===--------------------------------------------------------------------===// -// WindowBoundariesState -//===--------------------------------------------------------------------===// -void WindowBoundariesState::Update(const idx_t row_idx, const WindowInputColumn &range_collection, - const idx_t chunk_idx, WindowInputExpression &boundary_start, - WindowInputExpression &boundary_end, const ValidityMask &partition_mask, - const ValidityMask &order_mask) { - - if (partition_count + order_count > 0) { - - // 
determine partition and peer group boundaries to ultimately figure out window size - const auto is_same_partition = !partition_mask.RowIsValidUnsafe(row_idx); - const auto is_peer = !order_mask.RowIsValidUnsafe(row_idx); - const auto is_jump = (next_pos != row_idx); - - // when the partition changes, recompute the boundaries - if (!is_same_partition || is_jump) { - if (is_jump) { - idx_t n = 1; - partition_start = FindPrevStart(partition_mask, 0, row_idx + 1, n); - n = 1; - peer_start = FindPrevStart(order_mask, 0, row_idx + 1, n); - } else { - partition_start = row_idx; - peer_start = row_idx; - } - - // find end of partition - partition_end = input_size; - if (partition_count) { - idx_t n = 1; - partition_end = FindNextStart(partition_mask, partition_start + 1, input_size, n); - } - - // Find valid ordering values for the new partition - // so we can exclude NULLs from RANGE expression computations - valid_start = partition_start; - valid_end = partition_end; - - if ((valid_start < valid_end) && has_preceding_range) { - // Exclude any leading NULLs - if (range_collection.CellIsNull(valid_start)) { - idx_t n = 1; - valid_start = FindNextStart(order_mask, valid_start + 1, valid_end, n); - } - } - - if ((valid_start < valid_end) && has_following_range) { - // Exclude any trailing NULLs - if (range_collection.CellIsNull(valid_end - 1)) { - idx_t n = 1; - valid_end = FindPrevStart(order_mask, valid_start, valid_end, n); - } - - // Reset range hints - prev.start = valid_start; - prev.end = valid_end; - } - } else if (!is_peer) { - peer_start = row_idx; - } - - if (needs_peer) { - peer_end = partition_end; - if (order_count) { - idx_t n = 1; - peer_end = FindNextStart(order_mask, peer_start + 1, partition_end, n); - } - } - - } else { - // OVER() - partition_end = input_size; - peer_end = partition_end; - } - next_pos = row_idx + 1; - - // determine window boundaries depending on the type of expression - window_start = -1; - window_end = -1; - - switch (start_boundary) 
{ - case WindowBoundary::UNBOUNDED_PRECEDING: - window_start = partition_start; - break; - case WindowBoundary::CURRENT_ROW_ROWS: - window_start = row_idx; - break; - case WindowBoundary::CURRENT_ROW_RANGE: - window_start = peer_start; - break; - case WindowBoundary::EXPR_PRECEDING_ROWS: { - if (!TrySubtractOperator::Operation(int64_t(row_idx), boundary_start.GetCell(chunk_idx), - window_start)) { - throw OutOfRangeException("Overflow computing ROWS PRECEDING start"); - } - break; - } - case WindowBoundary::EXPR_FOLLOWING_ROWS: { - if (!TryAddOperator::Operation(int64_t(row_idx), boundary_start.GetCell(chunk_idx), window_start)) { - throw OutOfRangeException("Overflow computing ROWS FOLLOWING start"); - } - break; - } - case WindowBoundary::EXPR_PRECEDING_RANGE: { - if (boundary_start.CellIsNull(chunk_idx)) { - window_start = peer_start; - } else { - prev.start = FindOrderedRangeBound(range_collection, range_sense, valid_start, row_idx, - boundary_start, chunk_idx, prev); - window_start = prev.start; - } - break; - } - case WindowBoundary::EXPR_FOLLOWING_RANGE: { - if (boundary_start.CellIsNull(chunk_idx)) { - window_start = peer_start; - } else { - prev.start = FindOrderedRangeBound(range_collection, range_sense, row_idx, valid_end, boundary_start, - chunk_idx, prev); - window_start = prev.start; - } - break; - } - default: - throw InternalException("Unsupported window start boundary"); - } - - switch (end_boundary) { - case WindowBoundary::CURRENT_ROW_ROWS: - window_end = row_idx + 1; - break; - case WindowBoundary::CURRENT_ROW_RANGE: - window_end = peer_end; - break; - case WindowBoundary::UNBOUNDED_FOLLOWING: - window_end = partition_end; - break; - case WindowBoundary::EXPR_PRECEDING_ROWS: - if (!TrySubtractOperator::Operation(int64_t(row_idx + 1), boundary_end.GetCell(chunk_idx), - window_end)) { - throw OutOfRangeException("Overflow computing ROWS PRECEDING end"); - } - break; - case WindowBoundary::EXPR_FOLLOWING_ROWS: - if 
(!TryAddOperator::Operation(int64_t(row_idx + 1), boundary_end.GetCell(chunk_idx), window_end)) { - throw OutOfRangeException("Overflow computing ROWS FOLLOWING end"); - } - break; - case WindowBoundary::EXPR_PRECEDING_RANGE: { - if (boundary_end.CellIsNull(chunk_idx)) { - window_end = peer_end; - } else { - prev.end = FindOrderedRangeBound(range_collection, range_sense, valid_start, row_idx, boundary_end, - chunk_idx, prev); - window_end = prev.end; - } - break; - } - case WindowBoundary::EXPR_FOLLOWING_RANGE: { - if (boundary_end.CellIsNull(chunk_idx)) { - window_end = peer_end; - } else { - prev.end = FindOrderedRangeBound(range_collection, range_sense, row_idx, valid_end, boundary_end, - chunk_idx, prev); - window_end = prev.end; - } - break; - } - default: - throw InternalException("Unsupported window end boundary"); - } - - // clamp windows to partitions if they should exceed - if (window_start < (int64_t)partition_start) { - window_start = partition_start; - } - if (window_start > (int64_t)partition_end) { - window_start = partition_end; - } - if (window_end < (int64_t)partition_start) { - window_end = partition_start; - } - if (window_end > (int64_t)partition_end) { - window_end = partition_end; - } - - if (window_start < 0 || window_end < 0) { - throw InternalException("Failed to compute window boundaries"); - } -} - -static bool HasPrecedingRange(BoundWindowExpression &wexpr) { - return (wexpr.start == WindowBoundary::EXPR_PRECEDING_RANGE || wexpr.end == WindowBoundary::EXPR_PRECEDING_RANGE); -} - -static bool HasFollowingRange(BoundWindowExpression &wexpr) { - return (wexpr.start == WindowBoundary::EXPR_FOLLOWING_RANGE || wexpr.end == WindowBoundary::EXPR_FOLLOWING_RANGE); -} - -WindowBoundariesState::WindowBoundariesState(BoundWindowExpression &wexpr, const idx_t input_size) - : type(wexpr.type), input_size(input_size), start_boundary(wexpr.start), end_boundary(wexpr.end), - partition_count(wexpr.partitions.size()), order_count(wexpr.orders.size()), - 
range_sense(wexpr.orders.empty() ? OrderType::INVALID : wexpr.orders[0].type), - has_preceding_range(HasPrecedingRange(wexpr)), has_following_range(HasFollowingRange(wexpr)), - needs_peer(BoundaryNeedsPeer(wexpr.end) || wexpr.type == ExpressionType::WINDOW_CUME_DIST) { -} - -void WindowBoundariesState::Bounds(DataChunk &bounds, idx_t row_idx, const WindowInputColumn &range, const idx_t count, - WindowInputExpression &boundary_start, WindowInputExpression &boundary_end, - const ValidityMask &partition_mask, const ValidityMask &order_mask) { - bounds.Reset(); - D_ASSERT(bounds.ColumnCount() == 6); - auto partition_begin_data = FlatVector::GetData(bounds.data[PARTITION_BEGIN]); - auto partition_end_data = FlatVector::GetData(bounds.data[PARTITION_END]); - auto peer_begin_data = FlatVector::GetData(bounds.data[PEER_BEGIN]); - auto peer_end_data = FlatVector::GetData(bounds.data[PEER_END]); - auto window_begin_data = FlatVector::GetData(bounds.data[WINDOW_BEGIN]); - auto window_end_data = FlatVector::GetData(bounds.data[WINDOW_END]); - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - Update(row_idx, range, chunk_idx, boundary_start, boundary_end, partition_mask, order_mask); - *partition_begin_data++ = partition_start; - *partition_end_data++ = partition_end; - if (needs_peer) { - *peer_begin_data++ = peer_start; - *peer_end_data++ = peer_end; - } - *window_begin_data++ = window_start; - *window_end_data++ = window_end; - } - bounds.SetCardinality(count); -} - -//===--------------------------------------------------------------------===// -// WindowExecutorBoundsState -//===--------------------------------------------------------------------===// -class WindowExecutorBoundsState : public WindowExecutorState { -public: - WindowExecutorBoundsState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t count, - const ValidityMask &partition_mask_p, const ValidityMask &order_mask_p); - ~WindowExecutorBoundsState() override { - } - - 
virtual void UpdateBounds(idx_t row_idx, DataChunk &input_chunk, const WindowInputColumn &range); - - // Frame management - const ValidityMask &partition_mask; - const ValidityMask &order_mask; - DataChunk bounds; - WindowBoundariesState state; - - // evaluate boundaries if present. Parser has checked boundary types. - WindowInputExpression boundary_start; - WindowInputExpression boundary_end; -}; - -WindowExecutorBoundsState::WindowExecutorBoundsState(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask_p, - const ValidityMask &order_mask_p) - : partition_mask(partition_mask_p), order_mask(order_mask_p), state(wexpr, payload_count), - boundary_start(wexpr.start_expr.get(), context), boundary_end(wexpr.end_expr.get(), context) { - vector bounds_types(6, LogicalType(LogicalTypeId::UBIGINT)); - bounds.Initialize(Allocator::Get(context), bounds_types); -} - -void WindowExecutorBoundsState::UpdateBounds(idx_t row_idx, DataChunk &input_chunk, const WindowInputColumn &range) { - // Evaluate the row-level arguments - boundary_start.Execute(input_chunk); - boundary_end.Execute(input_chunk); - - const auto count = input_chunk.size(); - bounds.Reset(); - state.Bounds(bounds, row_idx, range, count, boundary_start, boundary_end, partition_mask, order_mask); -} - -//===--------------------------------------------------------------------===// -// WindowExecutor -//===--------------------------------------------------------------------===// -static void PrepareInputExpressions(vector> &exprs, ExpressionExecutor &executor, - DataChunk &chunk) { - if (exprs.empty()) { - return; - } - - vector types; - for (idx_t expr_idx = 0; expr_idx < exprs.size(); ++expr_idx) { - types.push_back(exprs[expr_idx]->return_type); - executor.AddExpression(*exprs[expr_idx]); - } - - if (!types.empty()) { - auto &allocator = executor.GetAllocator(); - chunk.Initialize(allocator, types); - } -} - 
-WindowExecutor::WindowExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : wexpr(wexpr), context(context), payload_count(payload_count), partition_mask(partition_mask), - order_mask(order_mask), payload_collection(), payload_executor(context), - range((HasPrecedingRange(wexpr) || HasFollowingRange(wexpr)) ? wexpr.orders[0].expression.get() : nullptr, - context, payload_count) { - // TODO: child may be a scalar, don't need to materialize the whole collection then - - // evaluate inner expressions of window functions, could be more complex - PrepareInputExpressions(wexpr.children, payload_executor, payload_chunk); - - auto types = payload_chunk.GetTypes(); - if (!types.empty()) { - payload_collection.Initialize(Allocator::Get(context), types); - } -} - -unique_ptr WindowExecutor::GetExecutorState() const { - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); -} - -//===--------------------------------------------------------------------===// -// WindowAggregateExecutor -//===--------------------------------------------------------------------===// -bool WindowAggregateExecutor::IsConstantAggregate() { - if (!wexpr.aggregate) { - return false; - } - - // COUNT(*) is already handled efficiently by segment trees. - if (wexpr.children.empty()) { - return false; - } - - /* - The default framing option is RANGE UNBOUNDED PRECEDING, which - is the same as RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT - ROW; it sets the frame to be all rows from the partition start - up through the current row's last peer (a row that the window's - ORDER BY clause considers equivalent to the current row; all - rows are peers if there is no ORDER BY). 
In general, UNBOUNDED - PRECEDING means that the frame starts with the first row of the - partition, and similarly UNBOUNDED FOLLOWING means that the - frame ends with the last row of the partition, regardless of - RANGE, ROWS or GROUPS mode. In ROWS mode, CURRENT ROW means that - the frame starts or ends with the current row; but in RANGE or - GROUPS mode it means that the frame starts or ends with the - current row's first or last peer in the ORDER BY ordering. The - offset PRECEDING and offset FOLLOWING options vary in meaning - depending on the frame mode. - */ - switch (wexpr.start) { - case WindowBoundary::UNBOUNDED_PRECEDING: - break; - case WindowBoundary::CURRENT_ROW_RANGE: - if (!wexpr.orders.empty()) { - return false; - } - break; - default: - return false; - } - - switch (wexpr.end) { - case WindowBoundary::UNBOUNDED_FOLLOWING: - break; - case WindowBoundary::CURRENT_ROW_RANGE: - if (!wexpr.orders.empty()) { - return false; - } - break; - default: - return false; - } - - return true; -} - -bool WindowAggregateExecutor::IsCustomAggregate() { - if (!wexpr.aggregate) { - return false; - } - - if (!AggregateObject(wexpr).function.window) { - return false; - } - - return (mode < WindowAggregationMode::COMBINE); -} - -void WindowExecutor::Evaluate(idx_t row_idx, DataChunk &input_chunk, Vector &result, - WindowExecutorState &lstate) const { - auto &lbstate = lstate.Cast(); - lbstate.UpdateBounds(row_idx, input_chunk, range); - - const auto count = input_chunk.size(); - EvaluateInternal(lstate, result, count, row_idx); - - result.Verify(count); -} - -WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t count, const ValidityMask &partition_mask, - const ValidityMask &order_mask, WindowAggregationMode mode) - : WindowExecutor(wexpr, context, count, partition_mask, order_mask), mode(mode), filter_executor(context) { - // TODO we could evaluate those expressions in parallel - - // Check for constant 
aggregate - if (IsConstantAggregate()) { - aggregator = - make_uniq(AggregateObject(wexpr), wexpr.return_type, partition_mask, count); - } else if (IsCustomAggregate()) { - aggregator = make_uniq(AggregateObject(wexpr), wexpr.return_type, count); - } else if (wexpr.aggregate) { - // build a segment tree for frame-adhering aggregates - // see http://www.vldb.org/pvldb/vol8/p1058-leis.pdf - aggregator = make_uniq(AggregateObject(wexpr), wexpr.return_type, count, mode); - } - - // evaluate the FILTER clause and stuff it into a large mask for compactness and reuse - if (wexpr.filter_expr) { - filter_executor.AddExpression(*wexpr.filter_expr); - filter_sel.Initialize(STANDARD_VECTOR_SIZE); - } -} - -void WindowAggregateExecutor::Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count) { - idx_t filtered = 0; - SelectionVector *filtering = nullptr; - if (wexpr.filter_expr) { - filtering = &filter_sel; - filtered = filter_executor.SelectExpression(input_chunk, filter_sel); - } - - if (!wexpr.children.empty()) { - payload_chunk.Reset(); - payload_executor.Execute(input_chunk, payload_chunk); - payload_chunk.Verify(); - } else if (aggregator) { - // Zero-argument aggregate (e.g., COUNT(*) - payload_chunk.SetCardinality(input_chunk); - } - - D_ASSERT(aggregator); - aggregator->Sink(payload_chunk, filtering, filtered); - - WindowExecutor::Sink(input_chunk, input_idx, total_count); -} - -void WindowAggregateExecutor::Finalize() { - D_ASSERT(aggregator); - aggregator->Finalize(); -} - -class WindowAggregateState : public WindowExecutorBoundsState { -public: - WindowAggregateState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask, - const WindowAggregator &aggregator) - : WindowExecutorBoundsState(wexpr, context, payload_count, partition_mask, order_mask), - aggregator_state(aggregator.GetLocalState()) { - } - -public: - unique_ptr aggregator_state; - - void 
NextRank(idx_t partition_begin, idx_t peer_begin, idx_t row_idx); -}; - -unique_ptr WindowAggregateExecutor::GetExecutorState() const { - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask, *aggregator); -} - -void WindowAggregateExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &lastate = lstate.Cast(); - D_ASSERT(aggregator); - auto window_begin = FlatVector::GetData(lastate.bounds.data[WINDOW_BEGIN]); - auto window_end = FlatVector::GetData(lastate.bounds.data[WINDOW_END]); - aggregator->Evaluate(*lastate.aggregator_state, window_begin, window_end, result, count); -} - -//===--------------------------------------------------------------------===// -// WindowRowNumberExecutor -//===--------------------------------------------------------------------===// -WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &lbstate = lstate.Cast(); - auto partition_begin = FlatVector::GetData(lbstate.bounds.data[PARTITION_BEGIN]); - auto rdata = FlatVector::GetData(result); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - rdata[i] = row_idx - partition_begin[i] + 1; - } -} - -//===--------------------------------------------------------------------===// -// WindowPeerState -//===--------------------------------------------------------------------===// -class WindowPeerState : public WindowExecutorBoundsState { -public: - WindowPeerState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutorBoundsState(wexpr, 
context, payload_count, partition_mask, order_mask) { - } - -public: - uint64_t dense_rank = 1; - uint64_t rank_equal = 0; - uint64_t rank = 1; - - void NextRank(idx_t partition_begin, idx_t peer_begin, idx_t row_idx); -}; - -void WindowPeerState::NextRank(idx_t partition_begin, idx_t peer_begin, idx_t row_idx) { - if (partition_begin == row_idx) { - dense_rank = 1; - rank = 1; - rank_equal = 0; - } else if (peer_begin == row_idx) { - dense_rank++; - rank += rank_equal; - rank_equal = 0; - } - rank_equal++; -} - -WindowRankExecutor::WindowRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -unique_ptr WindowRankExecutor::GetExecutorState() const { - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); -} - -void WindowRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &lpeer = lstate.Cast(); - auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); - auto peer_begin = FlatVector::GetData(lpeer.bounds.data[PEER_BEGIN]); - auto rdata = FlatVector::GetData(result); - - // Reset to "previous" row - lpeer.rank = (peer_begin[0] - partition_begin[0]) + 1; - lpeer.rank_equal = (row_idx - peer_begin[0]); - - for (idx_t i = 0; i < count; ++i, ++row_idx) { - lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); - rdata[i] = lpeer.rank; - } -} - -WindowDenseRankExecutor::WindowDenseRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -unique_ptr WindowDenseRankExecutor::GetExecutorState() const { - return make_uniq(wexpr, context, payload_count, partition_mask, 
order_mask); -} - -void WindowDenseRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &lpeer = lstate.Cast(); - auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); - auto peer_begin = FlatVector::GetData(lpeer.bounds.data[PEER_BEGIN]); - auto rdata = FlatVector::GetData(result); - - // Reset to "previous" row - lpeer.rank = (peer_begin[0] - partition_begin[0]) + 1; - lpeer.rank_equal = (row_idx - peer_begin[0]); - - // The previous dense rank is the number of order mask bits in [partition_begin, row_idx) - lpeer.dense_rank = 0; - - auto order_begin = partition_begin[0]; - idx_t begin_idx; - idx_t begin_offset; - order_mask.GetEntryIndex(order_begin, begin_idx, begin_offset); - - auto order_end = row_idx; - idx_t end_idx; - idx_t end_offset; - order_mask.GetEntryIndex(order_end, end_idx, end_offset); - - // If they are in the same entry, just loop - if (begin_idx == end_idx) { - const auto entry = order_mask.GetValidityEntry(begin_idx); - for (; begin_offset < end_offset; ++begin_offset) { - lpeer.dense_rank += order_mask.RowIsValid(entry, begin_offset); - } - } else { - // Count the ragged bits at the start of the partition - if (begin_offset) { - const auto entry = order_mask.GetValidityEntry(begin_idx); - for (; begin_offset < order_mask.BITS_PER_VALUE; ++begin_offset) { - lpeer.dense_rank += order_mask.RowIsValid(entry, begin_offset); - ++order_begin; - } - ++begin_idx; - } - - // Count the the aligned bits. 
- ValidityMask tail_mask(order_mask.GetData() + begin_idx); - lpeer.dense_rank += tail_mask.CountValid(order_end - order_begin); - } - - for (idx_t i = 0; i < count; ++i, ++row_idx) { - lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); - rdata[i] = lpeer.dense_rank; - } -} - -WindowPercentRankExecutor::WindowPercentRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -unique_ptr WindowPercentRankExecutor::GetExecutorState() const { - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); -} - -void WindowPercentRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &lpeer = lstate.Cast(); - auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); - auto partition_end = FlatVector::GetData(lpeer.bounds.data[PARTITION_END]); - auto peer_begin = FlatVector::GetData(lpeer.bounds.data[PEER_BEGIN]); - auto rdata = FlatVector::GetData(result); - - // Reset to "previous" row - lpeer.rank = (peer_begin[0] - partition_begin[0]) + 1; - lpeer.rank_equal = (row_idx - peer_begin[0]); - - for (idx_t i = 0; i < count; ++i, ++row_idx) { - lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); - int64_t denom = partition_end[i] - partition_begin[i] - 1; - double percent_rank = denom > 0 ? 
((double)lpeer.rank - 1) / denom : 0; - rdata[i] = percent_rank; - } -} - -//===--------------------------------------------------------------------===// -// WindowCumeDistExecutor -//===--------------------------------------------------------------------===// -WindowCumeDistExecutor::WindowCumeDistExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -void WindowCumeDistExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &lbstate = lstate.Cast(); - auto partition_begin = FlatVector::GetData(lbstate.bounds.data[PARTITION_BEGIN]); - auto partition_end = FlatVector::GetData(lbstate.bounds.data[PARTITION_END]); - auto peer_end = FlatVector::GetData(lbstate.bounds.data[PEER_END]); - auto rdata = FlatVector::GetData(result); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - int64_t denom = partition_end[i] - partition_begin[i]; - double cume_dist = denom > 0 ? 
((double)(peer_end[i] - partition_begin[i])) / denom : 0; - rdata[i] = cume_dist; - } -} - -//===--------------------------------------------------------------------===// -// WindowValueExecutor -//===--------------------------------------------------------------------===// -WindowValueExecutor::WindowValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -void WindowValueExecutor::Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count) { - // Single pass over the input to produce the global data. - // Vectorisation for the win... 
- - // Set up a validity mask for IGNORE NULLS - bool check_nulls = false; - if (wexpr.ignore_nulls) { - switch (wexpr.type) { - case ExpressionType::WINDOW_LEAD: - case ExpressionType::WINDOW_LAG: - case ExpressionType::WINDOW_FIRST_VALUE: - case ExpressionType::WINDOW_LAST_VALUE: - case ExpressionType::WINDOW_NTH_VALUE: - check_nulls = true; - break; - default: - break; - } - } - - if (!wexpr.children.empty()) { - payload_chunk.Reset(); - payload_executor.Execute(input_chunk, payload_chunk); - payload_chunk.Verify(); - payload_collection.Append(payload_chunk, true); - - // process payload chunks while they are still piping hot - if (check_nulls) { - const auto count = input_chunk.size(); - - UnifiedVectorFormat vdata; - payload_chunk.data[0].ToUnifiedFormat(count, vdata); - if (!vdata.validity.AllValid()) { - // Lazily materialise the contents when we find the first NULL - if (ignore_nulls.AllValid()) { - ignore_nulls.Initialize(total_count); - } - // Write to the current position - if (input_idx % ValidityMask::BITS_PER_VALUE == 0) { - // If we are at the edge of an output entry, just copy the entries - auto dst = ignore_nulls.GetData() + ignore_nulls.EntryCount(input_idx); - auto src = vdata.validity.GetData(); - for (auto entry_count = vdata.validity.EntryCount(count); entry_count-- > 0;) { - *dst++ = *src++; - } - } else { - // If not, we have ragged data and need to copy one bit at a time. 
- for (idx_t i = 0; i < count; ++i) { - ignore_nulls.Set(input_idx + i, vdata.validity.RowIsValid(i)); - } - } - } - } - } - - WindowExecutor::Sink(input_chunk, input_idx, total_count); -} - -void WindowNtileExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - D_ASSERT(payload_collection.ColumnCount() == 1); - auto &lbstate = lstate.Cast(); - auto partition_begin = FlatVector::GetData(lbstate.bounds.data[PARTITION_BEGIN]); - auto partition_end = FlatVector::GetData(lbstate.bounds.data[PARTITION_END]); - auto rdata = FlatVector::GetData(result); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - if (CellIsNull(payload_collection, 0, row_idx)) { - FlatVector::SetNull(result, i, true); - } else { - auto n_param = GetCell(payload_collection, 0, row_idx); - if (n_param < 1) { - throw InvalidInputException("Argument for ntile must be greater than zero"); - } - // With thanks from SQLite's ntileValueFunc() - int64_t n_total = partition_end[i] - partition_begin[i]; - if (n_param > n_total) { - // more groups allowed than we have values - // map every entry to a unique group - n_param = n_total; - } - int64_t n_size = (n_total / n_param); - // find the row idx within the group - D_ASSERT(row_idx >= partition_begin[i]); - int64_t adjusted_row_idx = row_idx - partition_begin[i]; - // now compute the ntile - int64_t n_large = n_total - n_param * n_size; - int64_t i_small = n_large * (n_size + 1); - int64_t result_ntile; - - D_ASSERT((n_large * (n_size + 1) + (n_param - n_large) * n_size) == n_total); - - if (adjusted_row_idx < i_small) { - result_ntile = 1 + adjusted_row_idx / (n_size + 1); - } else { - result_ntile = 1 + n_large + (adjusted_row_idx - i_small) / n_size; - } - // result has to be between [1, NTILE] - D_ASSERT(result_ntile >= 1 && result_ntile <= n_param); - rdata[i] = result_ntile; - } - } -} - -//===--------------------------------------------------------------------===// -// WindowLeadLagState 
-//===--------------------------------------------------------------------===// -class WindowLeadLagState : public WindowExecutorBoundsState { -public: - WindowLeadLagState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutorBoundsState(wexpr, context, payload_count, partition_mask, order_mask), - leadlag_offset(wexpr.offset_expr.get(), context), leadlag_default(wexpr.default_expr.get(), context) { - } - - void UpdateBounds(idx_t row_idx, DataChunk &input_chunk, const WindowInputColumn &range) override; - -public: - // LEAD/LAG Evaluation - WindowInputExpression leadlag_offset; - WindowInputExpression leadlag_default; -}; - -void WindowLeadLagState::UpdateBounds(idx_t row_idx, DataChunk &input_chunk, const WindowInputColumn &range) { - // Evaluate the row-level arguments - leadlag_offset.Execute(input_chunk); - leadlag_default.Execute(input_chunk); - - WindowExecutorBoundsState::UpdateBounds(row_idx, input_chunk, range); -} - -WindowLeadLagExecutor::WindowLeadLagExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -unique_ptr WindowLeadLagExecutor::GetExecutorState() const { - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); -} - -void WindowLeadLagExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &llstate = lstate.Cast(); - - auto partition_begin = FlatVector::GetData(llstate.bounds.data[PARTITION_BEGIN]); - auto partition_end = FlatVector::GetData(llstate.bounds.data[PARTITION_END]); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - int64_t offset = 1; - if (wexpr.offset_expr) { - offset = llstate.leadlag_offset.GetCell(i); - } - int64_t val_idx = 
(int64_t)row_idx; - if (wexpr.type == ExpressionType::WINDOW_LEAD) { - val_idx = AddOperatorOverflowCheck::Operation(val_idx, offset); - } else { - val_idx = SubtractOperatorOverflowCheck::Operation(val_idx, offset); - } - - idx_t delta = 0; - if (val_idx < (int64_t)row_idx) { - // Count backwards - delta = idx_t(row_idx - val_idx); - val_idx = FindPrevStart(ignore_nulls, partition_begin[i], row_idx, delta); - } else if (val_idx > (int64_t)row_idx) { - delta = idx_t(val_idx - row_idx); - val_idx = FindNextStart(ignore_nulls, row_idx + 1, partition_end[i], delta); - } - // else offset is zero, so don't move. - - if (!delta) { - CopyCell(payload_collection, 0, val_idx, result, i); - } else if (wexpr.default_expr) { - llstate.leadlag_default.CopyCell(result, i); - } else { - FlatVector::SetNull(result, i, true); - } - } -} - -WindowFirstValueExecutor::WindowFirstValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -void WindowFirstValueExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &lbstate = lstate.Cast(); - auto window_begin = FlatVector::GetData(lbstate.bounds.data[WINDOW_BEGIN]); - auto window_end = FlatVector::GetData(lbstate.bounds.data[WINDOW_END]); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - if (window_begin[i] >= window_end[i]) { - FlatVector::SetNull(result, i, true); - continue; - } - // Same as NTH_VALUE(..., 1) - idx_t n = 1; - const auto first_idx = FindNextStart(ignore_nulls, window_begin[i], window_end[i], n); - if (!n) { - CopyCell(payload_collection, 0, first_idx, result, i); - } else { - FlatVector::SetNull(result, i, true); - } - } -} - -WindowLastValueExecutor::WindowLastValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const 
ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -void WindowLastValueExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &lbstate = lstate.Cast(); - auto window_begin = FlatVector::GetData(lbstate.bounds.data[WINDOW_BEGIN]); - auto window_end = FlatVector::GetData(lbstate.bounds.data[WINDOW_END]); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - if (window_begin[i] >= window_end[i]) { - FlatVector::SetNull(result, i, true); - continue; - } - idx_t n = 1; - const auto last_idx = FindPrevStart(ignore_nulls, window_begin[i], window_end[i], n); - if (!n) { - CopyCell(payload_collection, 0, last_idx, result, i); - } else { - FlatVector::SetNull(result, i, true); - } - } -} - -WindowNthValueExecutor::WindowNthValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -void WindowNthValueExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - D_ASSERT(payload_collection.ColumnCount() == 2); - - auto &lbstate = lstate.Cast(); - auto window_begin = FlatVector::GetData(lbstate.bounds.data[WINDOW_BEGIN]); - auto window_end = FlatVector::GetData(lbstate.bounds.data[WINDOW_END]); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - if (window_begin[i] >= window_end[i]) { - FlatVector::SetNull(result, i, true); - continue; - } - // Returns value evaluated at the row that is the n'th row of the window frame (counting from 1); - // returns NULL if there is no such row. 
- if (CellIsNull(payload_collection, 1, row_idx)) { - FlatVector::SetNull(result, i, true); - } else { - auto n_param = GetCell(payload_collection, 1, row_idx); - if (n_param < 1) { - FlatVector::SetNull(result, i, true); - } else { - auto n = idx_t(n_param); - const auto nth_index = FindNextStart(ignore_nulls, window_begin[i], window_end[i], n); - if (!n) { - CopyCell(payload_collection, 0, nth_index, result, i); - } else { - FlatVector::SetNull(result, i, true); - } - } - } - } -} - -} // namespace duckdb - - - - - - -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowAggregator -//===--------------------------------------------------------------------===// -WindowAggregatorState::WindowAggregatorState() : allocator(Allocator::DefaultAllocator()) { -} - -WindowAggregator::WindowAggregator(AggregateObject aggr, const LogicalType &result_type_p, idx_t partition_count_p) - : aggr(std::move(aggr)), result_type(result_type_p), partition_count(partition_count_p), - state_size(aggr.function.state_size()), filter_pos(0) { -} - -WindowAggregator::~WindowAggregator() { -} - -void WindowAggregator::Sink(DataChunk &payload_chunk, SelectionVector *filter_sel, idx_t filtered) { - if (!inputs.ColumnCount() && payload_chunk.ColumnCount()) { - inputs.Initialize(Allocator::DefaultAllocator(), payload_chunk.GetTypes()); - } - if (inputs.ColumnCount()) { - inputs.Append(payload_chunk, true); - } - if (filter_sel) { - // Lazy instantiation - if (!filter_mask.IsMaskSet()) { - // Start with all invalid and set the ones that pass - filter_bits.resize(ValidityMask::ValidityMaskSize(partition_count), 0); - filter_mask.Initialize(filter_bits.data()); - } - for (idx_t f = 0; f < filtered; ++f) { - filter_mask.SetValid(filter_pos + filter_sel->get_index(f)); - } - filter_pos += payload_chunk.size(); - } -} - -void WindowAggregator::Finalize() { -} - 
-//===--------------------------------------------------------------------===// -// WindowConstantAggregate -//===--------------------------------------------------------------------===// -WindowConstantAggregator::WindowConstantAggregator(AggregateObject aggr, const LogicalType &result_type, - const ValidityMask &partition_mask, const idx_t count) - : WindowAggregator(std::move(aggr), result_type, count), partition(0), row(0), state(state_size), - statep(Value::POINTER(CastPointerToValue(state.data()))), - statef(Value::POINTER(CastPointerToValue(state.data()))) { - - statef.SetVectorType(VectorType::FLAT_VECTOR); // Prevent conversion of results to constants - - // Locate the partition boundaries - if (partition_mask.AllValid()) { - partition_offsets.emplace_back(0); - } else { - idx_t entry_idx; - idx_t shift; - for (idx_t start = 0; start < count;) { - partition_mask.GetEntryIndex(start, entry_idx, shift); - - // If start is aligned with the start of a block, - // and the block is blank, then skip forward one block. 
- const auto block = partition_mask.GetValidityEntry(entry_idx); - if (partition_mask.NoneValid(block) && !shift) { - start += ValidityMask::BITS_PER_VALUE; - continue; - } - - // Loop over the block - for (; shift < ValidityMask::BITS_PER_VALUE && start < count; ++shift, ++start) { - if (partition_mask.RowIsValid(block, shift)) { - partition_offsets.emplace_back(start); - } - } - } - } - - // Initialise the vector for caching the results - results = make_uniq(result_type, partition_offsets.size()); - partition_offsets.emplace_back(count); - - // Create an aggregate state for intermediate aggregates - gstate = make_uniq(); - - // Start the first aggregate - AggregateInit(); -} - -void WindowConstantAggregator::AggregateInit() { - aggr.function.initialize(state.data()); -} - -void WindowConstantAggregator::AggegateFinal(Vector &result, idx_t rid) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), gstate->allocator); - aggr.function.finalize(statef, aggr_input_data, result, 1, rid); - - if (aggr.function.destructor) { - aggr.function.destructor(statef, aggr_input_data, 1); - } -} - -void WindowConstantAggregator::Sink(DataChunk &payload_chunk, SelectionVector *filter_sel, idx_t filtered) { - const auto chunk_begin = row; - const auto chunk_end = chunk_begin + payload_chunk.size(); - - if (!inputs.ColumnCount() && payload_chunk.ColumnCount()) { - inputs.Initialize(Allocator::DefaultAllocator(), payload_chunk.GetTypes()); - } - - AggregateInputData aggr_input_data(aggr.GetFunctionData(), gstate->allocator); - idx_t begin = 0; - idx_t filter_idx = 0; - auto partition_end = partition_offsets[partition + 1]; - while (row < chunk_end) { - if (row == partition_end) { - AggegateFinal(*results, partition++); - AggregateInit(); - partition_end = partition_offsets[partition + 1]; - } - partition_end = MinValue(partition_end, chunk_end); - auto end = partition_end - chunk_begin; - - inputs.Reset(); - if (filter_sel) { - // Slice to any filtered rows in [begin, end) - 
SelectionVector sel; - - // Find the first value in [begin, end) - for (; filter_idx < filtered; ++filter_idx) { - auto idx = filter_sel->get_index(filter_idx); - if (idx >= begin) { - break; - } - } - - // Find the first value in [end, filtered) - sel.Initialize(filter_sel->data() + filter_idx); - idx_t nsel = 0; - for (; filter_idx < filtered; ++filter_idx, ++nsel) { - auto idx = filter_sel->get_index(filter_idx); - if (idx >= end) { - break; - } - } - - if (nsel != inputs.size()) { - inputs.Slice(payload_chunk, sel, nsel); - } - } else { - // Slice to [begin, end) - if (begin) { - for (idx_t c = 0; c < payload_chunk.ColumnCount(); ++c) { - inputs.data[c].Slice(payload_chunk.data[c], begin, end); - } - } else { - inputs.Reference(payload_chunk); - } - inputs.SetCardinality(end - begin); - } - - // Aggregate the filtered rows into a single state - const auto count = inputs.size(); - if (aggr.function.simple_update) { - aggr.function.simple_update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), state.data(), count); - } else { - aggr.function.update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), statep, count); - } - - // Skip filtered rows too! - row += end - begin; - begin = end; - } -} - -void WindowConstantAggregator::Finalize() { - AggegateFinal(*results, partition++); -} - -class WindowConstantAggregatorState : public WindowAggregatorState { -public: - WindowConstantAggregatorState() : partition(0) { - matches.Initialize(); - } - ~WindowConstantAggregatorState() override { - } - -public: - //! The current result partition being read - idx_t partition; - //! 
Shared SV for evaluation - SelectionVector matches; -}; - -unique_ptr WindowConstantAggregator::GetLocalState() const { - return make_uniq(); -} - -void WindowConstantAggregator::Evaluate(WindowAggregatorState &lstate, const idx_t *begins, const idx_t *ends, - Vector &target, idx_t count) const { - // Chunk up the constants and copy them one at a time - auto &lcstate = lstate.Cast(); - idx_t matched = 0; - idx_t target_offset = 0; - for (idx_t i = 0; i < count; ++i) { - const auto begin = begins[i]; - // Find the partition containing [begin, end) - while (partition_offsets[lcstate.partition + 1] <= begin) { - // Flush the previous partition's data - if (matched) { - VectorOperations::Copy(*results, target, lcstate.matches, matched, 0, target_offset); - target_offset += matched; - matched = 0; - } - ++lcstate.partition; - } - - lcstate.matches.set_index(matched++, lcstate.partition); - } - - // Flush the last partition - if (matched) { - VectorOperations::Copy(*results, target, lcstate.matches, matched, 0, target_offset); - } -} - -//===--------------------------------------------------------------------===// -// WindowCustomAggregator -//===--------------------------------------------------------------------===// -WindowCustomAggregator::WindowCustomAggregator(AggregateObject aggr, const LogicalType &result_type, idx_t count) - : WindowAggregator(std::move(aggr), result_type, count) { -} - -WindowCustomAggregator::~WindowCustomAggregator() { -} - -class WindowCustomAggregatorState : public WindowAggregatorState { -public: - explicit WindowCustomAggregatorState(const AggregateObject &aggr, DataChunk &inputs); - ~WindowCustomAggregatorState() override; - -public: - //! The aggregate function - const AggregateObject &aggr; - //! The aggregate function - DataChunk &inputs; - //! Data pointer that contains a single state, shared by all the custom evaluators - vector state; - //! Reused result state container for the window functions - Vector statef; - //! 
The frame boundaries, used for the window functions - FrameBounds frame; -}; - -WindowCustomAggregatorState::WindowCustomAggregatorState(const AggregateObject &aggr, DataChunk &inputs) - : aggr(aggr), inputs(inputs), state(aggr.function.state_size()), - statef(Value::POINTER(CastPointerToValue(state.data()))), frame(0, 0) { - // if we have a frame-by-frame method, share the single state - aggr.function.initialize(state.data()); -} - -WindowCustomAggregatorState::~WindowCustomAggregatorState() { - if (aggr.function.destructor) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.destructor(statef, aggr_input_data, 1); - } -} - -unique_ptr WindowCustomAggregator::GetLocalState() const { - return make_uniq(aggr, const_cast(inputs)); -} - -void WindowCustomAggregator::Evaluate(WindowAggregatorState &lstate, const idx_t *begins, const idx_t *ends, - Vector &result, idx_t count) const { - // TODO: window should take a const Vector* - auto &lcstate = lstate.Cast(); - auto &frame = lcstate.frame; - auto params = lcstate.inputs.data.data(); - auto &rmask = FlatVector::Validity(result); - for (idx_t i = 0; i < count; ++i) { - const auto begin = begins[i]; - const auto end = ends[i]; - if (begin >= end) { - rmask.SetInvalid(i); - continue; - } - - // Frame boundaries - auto prev = frame; - frame = FrameBounds(begin, end); - - // Extract the range - AggregateInputData aggr_input_data(aggr.GetFunctionData(), lstate.allocator); - aggr.function.window(params, filter_mask, aggr_input_data, inputs.ColumnCount(), lcstate.state.data(), frame, - prev, result, i, 0); - } -} - -//===--------------------------------------------------------------------===// -// WindowSegmentTree -//===--------------------------------------------------------------------===// -WindowSegmentTree::WindowSegmentTree(AggregateObject aggr, const LogicalType &result_type, idx_t count, - WindowAggregationMode mode_p) - : WindowAggregator(std::move(aggr), result_type, count), 
internal_nodes(0), mode(mode_p) { -} - -void WindowSegmentTree::Finalize() { - gstate = GetLocalState(); - if (inputs.ColumnCount() > 0) { - if (aggr.function.combine && UseCombineAPI()) { - ConstructTree(); - } - } -} - -WindowSegmentTree::~WindowSegmentTree() { - if (!aggr.function.destructor) { - // nothing to destroy - return; - } - AggregateInputData aggr_input_data(aggr.GetFunctionData(), gstate->allocator); - // call the destructor for all the intermediate states - data_ptr_t address_data[STANDARD_VECTOR_SIZE]; - Vector addresses(LogicalType::POINTER, data_ptr_cast(address_data)); - idx_t count = 0; - for (idx_t i = 0; i < internal_nodes; i++) { - address_data[count++] = data_ptr_t(levels_flat_native.get() + i * state_size); - if (count == STANDARD_VECTOR_SIZE) { - aggr.function.destructor(addresses, aggr_input_data, count); - count = 0; - } - } - if (count > 0) { - aggr.function.destructor(addresses, aggr_input_data, count); - } -} - -class WindowSegmentTreeState : public WindowAggregatorState { -public: - WindowSegmentTreeState(const AggregateObject &aggr, DataChunk &inputs, const ValidityMask &filter_mask); - ~WindowSegmentTreeState() override; - - void FlushStates(bool combining); - void ExtractFrame(idx_t begin, idx_t end, data_ptr_t current_state); - void WindowSegmentValue(const WindowSegmentTree &tree, idx_t l_idx, idx_t begin, idx_t end, - data_ptr_t current_state); - void Finalize(Vector &result, idx_t count); - -public: - //! The aggregate function - const AggregateObject &aggr; - //! The aggregate function - DataChunk &inputs; - //! The filtered rows in inputs - const ValidityMask &filter_mask; - //! The size of a single aggregate state - const idx_t state_size; - //! Data pointer that contains a single state, used for intermediate window segment aggregation - vector state; - //! Input data chunk, used for leaf segment aggregation - DataChunk leaves; - //! The filtered rows in inputs. - SelectionVector filter_sel; - //! 
A vector of pointers to "state", used for intermediate window segment aggregation - Vector statep; - //! Reused state pointers for combining segment tree levels - Vector statel; - //! Reused result state container for the window functions - Vector statef; - //! Count of buffered values - idx_t flush_count; -}; - -WindowSegmentTreeState::WindowSegmentTreeState(const AggregateObject &aggr, DataChunk &inputs, - const ValidityMask &filter_mask) - : aggr(aggr), inputs(inputs), filter_mask(filter_mask), state_size(aggr.function.state_size()), - state(state_size * STANDARD_VECTOR_SIZE), statep(LogicalType::POINTER), statel(LogicalType::POINTER), - statef(LogicalType::POINTER), flush_count(0) { - if (inputs.ColumnCount() > 0) { - leaves.Initialize(Allocator::DefaultAllocator(), inputs.GetTypes()); - filter_sel.Initialize(); - } - - // Build the finalise vector that just points to the result states - data_ptr_t state_ptr = state.data(); - D_ASSERT(statef.GetVectorType() == VectorType::FLAT_VECTOR); - statef.SetVectorType(VectorType::CONSTANT_VECTOR); - statef.Flatten(STANDARD_VECTOR_SIZE); - auto fdata = FlatVector::GetData(statef); - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; ++i) { - fdata[i] = state_ptr; - state_ptr += state_size; - } -} - -WindowSegmentTreeState::~WindowSegmentTreeState() { -} - -unique_ptr WindowSegmentTree::GetLocalState() const { - return make_uniq(aggr, const_cast(inputs), filter_mask); -} - -void WindowSegmentTreeState::FlushStates(bool combining) { - if (!flush_count) { - return; - } - - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - if (combining) { - statel.Verify(flush_count); - aggr.function.combine(statel, statep, aggr_input_data, flush_count); - } else { - leaves.Reference(inputs); - leaves.Slice(filter_sel, flush_count); - aggr.function.update(&leaves.data[0], aggr_input_data, leaves.ColumnCount(), statep, flush_count); - } - - flush_count = 0; -} - -void WindowSegmentTreeState::ExtractFrame(idx_t begin, idx_t 
end, data_ptr_t state_ptr) { - const auto count = end - begin; - - // If we are not filtering, - // just update the shared dictionary selection to the range - // Otherwise set it to the input rows that pass the filter - auto states = FlatVector::GetData(statep); - if (filter_mask.AllValid()) { - for (idx_t i = 0; i < count; ++i) { - states[flush_count] = state_ptr; - filter_sel.set_index(flush_count++, begin + i); - if (flush_count >= STANDARD_VECTOR_SIZE) { - FlushStates(false); - } - } - } else { - for (idx_t i = begin; i < end; ++i) { - if (filter_mask.RowIsValid(i)) { - states[flush_count] = state_ptr; - filter_sel.set_index(flush_count++, i); - if (flush_count >= STANDARD_VECTOR_SIZE) { - FlushStates(false); - } - } - } - } -} - -void WindowSegmentTreeState::WindowSegmentValue(const WindowSegmentTree &tree, idx_t l_idx, idx_t begin, idx_t end, - data_ptr_t state_ptr) { - D_ASSERT(begin <= end); - if (begin == end || inputs.ColumnCount() == 0) { - return; - } - - const auto count = end - begin; - if (l_idx == 0) { - ExtractFrame(begin, end, state_ptr); - } else { - // find out where the states begin - auto begin_ptr = tree.levels_flat_native.get() + state_size * (begin + tree.levels_flat_start[l_idx - 1]); - // set up a vector of pointers that point towards the set of states - auto ldata = FlatVector::GetData(statel); - auto pdata = FlatVector::GetData(statep); - for (idx_t i = 0; i < count; i++) { - pdata[flush_count] = state_ptr; - ldata[flush_count++] = begin_ptr; - begin_ptr += state_size; - if (flush_count >= STANDARD_VECTOR_SIZE) { - FlushStates(true); - } - } - } -} -void WindowSegmentTreeState::Finalize(Vector &result, idx_t count) { - // Finalise the result aggregates - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.finalize(statef, aggr_input_data, result, count, 0); - - // Destruct the result aggregates - if (aggr.function.destructor) { - aggr.function.destructor(statef, aggr_input_data, count); - } -} - -void 
WindowSegmentTree::ConstructTree() { - D_ASSERT(inputs.ColumnCount() > 0); - - // Use a temporary scan state to build the tree - auto >state = gstate->Cast(); - - // compute space required to store internal nodes of segment tree - internal_nodes = 0; - idx_t level_nodes = inputs.size(); - do { - level_nodes = (level_nodes + (TREE_FANOUT - 1)) / TREE_FANOUT; - internal_nodes += level_nodes; - } while (level_nodes > 1); - levels_flat_native = make_unsafe_uniq_array(internal_nodes * state_size); - levels_flat_start.push_back(0); - - idx_t levels_flat_offset = 0; - idx_t level_current = 0; - // level 0 is data itself - idx_t level_size; - // iterate over the levels of the segment tree - while ((level_size = - (level_current == 0 ? inputs.size() : levels_flat_offset - levels_flat_start[level_current - 1])) > 1) { - for (idx_t pos = 0; pos < level_size; pos += TREE_FANOUT) { - // compute the aggregate for this entry in the segment tree - data_ptr_t state_ptr = levels_flat_native.get() + (levels_flat_offset * state_size); - aggr.function.initialize(state_ptr); - gtstate.WindowSegmentValue(*this, level_current, pos, MinValue(level_size, pos + TREE_FANOUT), state_ptr); - gtstate.FlushStates(level_current > 0); - - levels_flat_offset++; - } - - levels_flat_start.push_back(levels_flat_offset); - level_current++; - } - - // Corner case: single element in the window - if (levels_flat_offset == 0) { - aggr.function.initialize(levels_flat_native.get()); - } -} - -void WindowSegmentTree::Evaluate(WindowAggregatorState &lstate, const idx_t *begins, const idx_t *ends, Vector &result, - idx_t count) const { - auto <state = lstate.Cast(); - const auto cant_combine = (!aggr.function.combine || !UseCombineAPI()); - auto fdata = FlatVector::GetData(ltstate.statef); - - // First pass: aggregate the segment tree nodes - // Share adjacent identical states - // We do this first because we want to share only tree aggregations - idx_t prev_begin = 1; - idx_t prev_end = 0; - auto ldata = 
FlatVector::GetData(ltstate.statel); - auto pdata = FlatVector::GetData(ltstate.statep); - data_ptr_t prev_state = nullptr; - for (idx_t rid = 0; rid < count; ++rid) { - auto state_ptr = fdata[rid]; - aggr.function.initialize(state_ptr); - - if (cant_combine) { - // Make sure we initialise all states - continue; - } - - auto begin = begins[rid]; - auto end = ends[rid]; - if (begin >= end) { - continue; - } - - // Skip level 0 - idx_t l_idx = 0; - for (; l_idx < levels_flat_start.size() + 1; l_idx++) { - idx_t parent_begin = begin / TREE_FANOUT; - idx_t parent_end = end / TREE_FANOUT; - if (prev_state && l_idx == 1 && begin == prev_begin && end == prev_end) { - // Just combine the previous top level result - ldata[ltstate.flush_count] = prev_state; - pdata[ltstate.flush_count] = state_ptr; - if (++ltstate.flush_count >= STANDARD_VECTOR_SIZE) { - ltstate.FlushStates(true); - } - break; - } - - if (l_idx == 1) { - prev_state = state_ptr; - prev_begin = begin; - prev_end = end; - } - - if (parent_begin == parent_end) { - if (l_idx) { - ltstate.WindowSegmentValue(*this, l_idx, begin, end, state_ptr); - } - break; - } - idx_t group_begin = parent_begin * TREE_FANOUT; - if (begin != group_begin) { - if (l_idx) { - ltstate.WindowSegmentValue(*this, l_idx, begin, group_begin + TREE_FANOUT, state_ptr); - } - parent_begin++; - } - idx_t group_end = parent_end * TREE_FANOUT; - if (end != group_end) { - if (l_idx) { - ltstate.WindowSegmentValue(*this, l_idx, group_end, end, state_ptr); - } - } - begin = parent_begin; - end = parent_end; - } - } - ltstate.FlushStates(true); - - // Second pass: aggregate the ragged leaves - // (or everything if we can't combine) - for (idx_t rid = 0; rid < count; ++rid) { - auto state_ptr = fdata[rid]; - - const auto begin = begins[rid]; - const auto end = ends[rid]; - if (begin >= end) { - continue; - } - - // Aggregate everything at once if we can't combine states - idx_t parent_begin = begin / TREE_FANOUT; - idx_t parent_end = end / 
TREE_FANOUT; - if (parent_begin == parent_end || cant_combine) { - ltstate.WindowSegmentValue(*this, 0, begin, end, state_ptr); - continue; - } - - idx_t group_begin = parent_begin * TREE_FANOUT; - if (begin != group_begin) { - ltstate.WindowSegmentValue(*this, 0, begin, group_begin + TREE_FANOUT, state_ptr); - parent_begin++; - } - idx_t group_end = parent_end * TREE_FANOUT; - if (end != group_end) { - ltstate.WindowSegmentValue(*this, 0, group_end, end, state_ptr); - } - } - ltstate.FlushStates(false); - - ltstate.Finalize(result, count); - - // Set the validity mask on the invalid rows - auto &rmask = FlatVector::Validity(result); - for (idx_t rid = 0; rid < count; ++rid) { - const auto begin = begins[rid]; - const auto end = ends[rid]; - - if (begin >= end) { - rmask.SetInvalid(rid); - } - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -struct BaseCountFunction { - template - static void Initialize(STATE &state) { - state = 0; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target += source; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - target = state; - } -}; - -struct CountStarFunction : public BaseCountFunction { - template - static void Operation(STATE &state, AggregateInputData &, idx_t idx) { - state += 1; - } - - template - static void ConstantOperation(STATE &state, AggregateInputData &, idx_t count) { - state += count; - } - - template - static void Window(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, - idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, - Vector &result, idx_t rid, idx_t bias) { - D_ASSERT(input_count == 0); - auto data = FlatVector::GetData(result); - const auto begin = frame.start; - const auto end = frame.end; - // Slice to any filtered rows - if (!filter_mask.AllValid()) { - RESULT_TYPE filtered = 0; - for (auto i = begin; i < 
end; ++i) { - filtered += filter_mask.RowIsValid(i); - } - data[rid] = filtered; - } else { - data[rid] = end - begin; - } - } -}; - -struct CountFunction : public BaseCountFunction { - using STATE = int64_t; - - static void Operation(STATE &state) { - state += 1; - } - - static void ConstantOperation(STATE &state, idx_t count) { - state += count; - } - - static bool IgnoreNull() { - return true; - } - - static inline void CountFlatLoop(STATE **__restrict states, ValidityMask &mask, idx_t count) { - if (!mask.AllValid()) { - idx_t base_idx = 0; - auto entry_count = ValidityMask::EntryCount(count); - for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { - auto validity_entry = mask.GetValidityEntry(entry_idx); - idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); - if (ValidityMask::AllValid(validity_entry)) { - // all valid: perform operation - for (; base_idx < next; base_idx++) { - CountFunction::Operation(*states[base_idx]); - } - } else if (ValidityMask::NoneValid(validity_entry)) { - // nothing valid: skip all - base_idx = next; - continue; - } else { - // partially valid: need to check individual elements for validity - idx_t start = base_idx; - for (; base_idx < next; base_idx++) { - if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { - CountFunction::Operation(*states[base_idx]); - } - } - } - } - } else { - for (idx_t i = 0; i < count; i++) { - CountFunction::Operation(*states[i]); - } - } - } - - static inline void CountScatterLoop(STATE **__restrict states, const SelectionVector &isel, - const SelectionVector &ssel, ValidityMask &mask, idx_t count) { - if (!mask.AllValid()) { - // potential NULL values - for (idx_t i = 0; i < count; i++) { - auto idx = isel.get_index(i); - auto sidx = ssel.get_index(i); - if (mask.RowIsValid(idx)) { - CountFunction::Operation(*states[sidx]); - } - } - } else { - // quick path: no NULL values - for (idx_t i = 0; i < count; i++) { - auto sidx = ssel.get_index(i); - 
CountFunction::Operation(*states[sidx]); - } - } - } - - static void CountScatter(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, - idx_t count) { - auto &input = inputs[0]; - if (input.GetVectorType() == VectorType::FLAT_VECTOR && states.GetVectorType() == VectorType::FLAT_VECTOR) { - auto sdata = FlatVector::GetData(states); - CountFlatLoop(sdata, FlatVector::Validity(input), count); - } else { - UnifiedVectorFormat idata, sdata; - input.ToUnifiedFormat(count, idata); - states.ToUnifiedFormat(count, sdata); - CountScatterLoop(reinterpret_cast(sdata.data), *idata.sel, *sdata.sel, idata.validity, count); - } - } - - static inline void CountFlatUpdateLoop(STATE &result, ValidityMask &mask, idx_t count) { - idx_t base_idx = 0; - auto entry_count = ValidityMask::EntryCount(count); - for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { - auto validity_entry = mask.GetValidityEntry(entry_idx); - idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); - if (ValidityMask::AllValid(validity_entry)) { - // all valid - result += next - base_idx; - base_idx = next; - } else if (ValidityMask::NoneValid(validity_entry)) { - // nothing valid: skip all - base_idx = next; - continue; - } else { - // partially valid: need to check individual elements for validity - idx_t start = base_idx; - for (; base_idx < next; base_idx++) { - if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { - result++; - } - } - } - } - } - - static inline void CountUpdateLoop(STATE &result, ValidityMask &mask, idx_t count, - const SelectionVector &sel_vector) { - if (mask.AllValid()) { - // no NULL values - result += count; - return; - } - for (idx_t i = 0; i < count; i++) { - auto idx = sel_vector.get_index(i); - if (mask.RowIsValid(idx)) { - result++; - } - } - } - - static void CountUpdate(Vector inputs[], AggregateInputData &, idx_t input_count, data_ptr_t state_p, idx_t count) { - auto &input = inputs[0]; - auto &result 
= *reinterpret_cast(state_p); - switch (input.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - if (!ConstantVector::IsNull(input)) { - // if the constant is not null increment the state - result += count; - } - break; - } - case VectorType::FLAT_VECTOR: { - CountFlatUpdateLoop(result, FlatVector::Validity(input), count); - break; - } - case VectorType::SEQUENCE_VECTOR: { - // sequence vectors cannot have NULL values - result += count; - break; - } - default: { - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - CountUpdateLoop(result, idata.validity, count, *idata.sel); - break; - } - } - } -}; - -AggregateFunction CountFun::GetFunction() { - AggregateFunction fun({LogicalType(LogicalTypeId::ANY)}, LogicalType::BIGINT, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, CountFunction::CountScatter, - AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, - FunctionNullHandling::SPECIAL_HANDLING, CountFunction::CountUpdate); - fun.name = "count"; - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} - -AggregateFunction CountStarFun::GetFunction() { - auto fun = AggregateFunction::NullaryAggregate(LogicalType::BIGINT); - fun.name = "count_star"; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.window = CountStarFunction::Window; - return fun; -} - -unique_ptr CountPropagateStats(ClientContext &context, BoundAggregateExpression &expr, - AggregateStatisticsInput &input) { - if (!expr.IsDistinct() && !input.child_stats[0].CanHaveNull()) { - // count on a column without null values: use count star - expr.function = CountStarFun::GetFunction(); - expr.function.name = "count_star"; - expr.children.clear(); - } - return nullptr; -} - -void CountFun::RegisterFunction(BuiltinFunctions &set) { - AggregateFunction count_function = CountFun::GetFunction(); - count_function.statistics = CountPropagateStats; - AggregateFunctionSet count("count"); - 
count.AddFunction(count_function); - // the count function can also be called without arguments - count_function.arguments.clear(); - count_function.statistics = nullptr; - count_function.window = CountStarFunction::Window; - count.AddFunction(count_function); - set.AddFunction(count); -} - -void CountStarFun::RegisterFunction(BuiltinFunctions &set) { - AggregateFunctionSet count("count_star"); - count.AddFunction(CountStarFun::GetFunction()); - set.AddFunction(count); -} - -} // namespace duckdb - -#endif diff --git a/lib/duckdb-6.cpp b/lib/duckdb-6.cpp deleted file mode 100644 index ec0bccd0..00000000 --- a/lib/duckdb-6.cpp +++ /dev/null @@ -1,20127 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - - - - -namespace duckdb { - -template -struct FirstState { - T value; - bool is_set; - bool is_null; -}; - -struct FirstFunctionBase { - template - static void Initialize(STATE &state) { - state.is_set = false; - state.is_null = false; - } - - static bool IgnoreNull() { - return false; - } -}; - -template -struct FirstFunction : public FirstFunctionBase { - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - if (LAST || !state.is_set) { - if (!unary_input.RowIsValid()) { - if (!SKIP_NULLS) { - state.is_set = true; - } - state.is_null = true; - } else { - state.is_set = true; - state.is_null = false; - state.value = input; - } - } - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - Operation(state, input, unary_input); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!target.is_set) { - target = source; - } - } - - template - static void Finalize(STATE &state, T &target, 
AggregateFinalizeData &finalize_data) { - if (!state.is_set || state.is_null) { - finalize_data.ReturnNull(); - } else { - target = state.value; - } - } -}; - -template -struct FirstFunctionString : public FirstFunctionBase { - template - static void SetValue(STATE &state, AggregateInputData &input_data, string_t value, bool is_null) { - if (LAST && state.is_set) { - Destroy(state, input_data); - } - if (is_null) { - if (!SKIP_NULLS) { - state.is_set = true; - state.is_null = true; - } - } else { - state.is_set = true; - state.is_null = false; - if (value.IsInlined()) { - state.value = value; - } else { - // non-inlined string, need to allocate space for it - auto len = value.GetSize(); - auto ptr = new char[len]; - memcpy(ptr, value.GetData(), len); - - state.value = string_t(ptr, len); - } - } - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - if (LAST || !state.is_set) { - SetValue(state, unary_input.input, input, !unary_input.RowIsValid()); - } - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - Operation(state, input, unary_input); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { - if (source.is_set && (LAST || !target.is_set)) { - SetValue(target, input_data, source.value, source.is_null); - } - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.is_set || state.is_null) { - finalize_data.ReturnNull(); - } else { - target = StringVector::AddStringOrBlob(finalize_data.result, state.value); - } - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.is_set && !state.is_null && !state.value.IsInlined()) { - delete[] state.value.GetData(); - } - } -}; - -struct FirstStateVector { - Vector *value; -}; - -template -struct FirstVectorFunction { 
- template - static void Initialize(STATE &state) { - state.value = nullptr; - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.value) { - delete state.value; - } - } - static bool IgnoreNull() { - return SKIP_NULLS; - } - - template - static void SetValue(STATE &state, Vector &input, const idx_t idx) { - if (!state.value) { - state.value = new Vector(input.GetType()); - state.value->SetVectorType(VectorType::CONSTANT_VECTOR); - } - sel_t selv = idx; - SelectionVector sel(&selv); - VectorOperations::Copy(input, *state.value, sel, 1, 0, 0); - } - - static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { - auto &input = inputs[0]; - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - - auto states = UnifiedVectorFormat::GetData(sdata); - for (idx_t i = 0; i < count; i++) { - const auto idx = idata.sel->get_index(i); - if (SKIP_NULLS && !idata.validity.RowIsValid(idx)) { - continue; - } - auto &state = *states[sdata.sel->get_index(i)]; - if (LAST || !state.value) { - SetValue(state, input, i); - } - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (source.value && (LAST || !target.value)) { - SetValue(target, *source.value, 0); - } - } - - template - static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { - if (!state.value) { - finalize_data.ReturnNull(); - } else { - VectorOperations::Copy(*state.value, finalize_data.result, 1, 0, finalize_data.result_idx); - } - } - - static unique_ptr Bind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; - return nullptr; - } -}; - -template -static AggregateFunction GetFirstAggregateTemplated(LogicalType type) { - return 
AggregateFunction::UnaryAggregate, T, T, FirstFunction>(type, type); -} - -template -static AggregateFunction GetFirstFunction(const LogicalType &type); - -template -AggregateFunction GetDecimalFirstFunction(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::DECIMAL); - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetFirstFunction(LogicalType::SMALLINT); - case PhysicalType::INT32: - return GetFirstFunction(LogicalType::INTEGER); - case PhysicalType::INT64: - return GetFirstFunction(LogicalType::BIGINT); - default: - return GetFirstFunction(LogicalType::HUGEINT); - } -} - -template -static AggregateFunction GetFirstFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::TINYINT: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::SMALLINT: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::INTEGER: - case LogicalTypeId::DATE: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::BIGINT: - case LogicalTypeId::TIME: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::TIMESTAMP_TZ: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::UTINYINT: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::USMALLINT: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::UINTEGER: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::UBIGINT: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::HUGEINT: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::FLOAT: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::DOUBLE: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::INTERVAL: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::VARCHAR: - case LogicalTypeId::BLOB: - return AggregateFunction::UnaryAggregateDestructor, string_t, string_t, - 
FirstFunctionString>(type, type); - case LogicalTypeId::DECIMAL: { - type.Verify(); - AggregateFunction function = GetDecimalFirstFunction(type); - function.arguments[0] = type; - function.return_type = type; - // TODO set_key here? - return function; - } - default: { - using OP = FirstVectorFunction; - return AggregateFunction({type}, type, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, OP::Update, - AggregateFunction::StateCombine, - AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, - AggregateFunction::StateDestroy, nullptr, nullptr); - } - } -} - -AggregateFunction FirstFun::GetFunction(const LogicalType &type) { - auto fun = GetFirstFunction(type); - fun.name = "first"; - return fun; -} - -template -unique_ptr BindDecimalFirst(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto decimal_type = arguments[0]->return_type; - auto name = std::move(function.name); - function = GetFirstFunction(decimal_type); - function.name = std::move(name); - function.return_type = decimal_type; - return nullptr; -} - -template -static AggregateFunction GetFirstOperator(const LogicalType &type) { - if (type.id() == LogicalTypeId::DECIMAL) { - throw InternalException("FIXME: this shouldn't happen..."); - } - return GetFirstFunction(type); -} - -template -unique_ptr BindFirst(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto input_type = arguments[0]->return_type; - auto name = std::move(function.name); - function = GetFirstOperator(input_type); - function.name = std::move(name); - if (function.bind) { - return function.bind(context, function, arguments); - } else { - return nullptr; - } -} - -template -static void AddFirstOperator(AggregateFunctionSet &set) { - set.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, BindDecimalFirst)); - set.AddFunction(AggregateFunction({LogicalType::ANY}, 
LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, BindFirst)); -} - -void FirstFun::RegisterFunction(BuiltinFunctions &set) { - AggregateFunctionSet first("first"); - AggregateFunctionSet last("last"); - AggregateFunctionSet any_value("any_value"); - - AddFirstOperator(first); - AddFirstOperator(last); - AddFirstOperator(any_value); - - set.AddFunction(first); - first.name = "arbitrary"; - set.AddFunction(first); - - set.AddFunction(last); - - set.AddFunction(any_value); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -void BuiltinFunctions::RegisterDistributiveAggregates() { - Register(); - Register(); - Register(); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -struct SortedAggregateBindData : public FunctionData { - SortedAggregateBindData(ClientContext &context, BoundAggregateExpression &expr) - : buffer_manager(BufferManager::GetBufferManager(context)), function(expr.function), - bind_info(std::move(expr.bind_info)), threshold(ClientConfig::GetConfig(context).ordered_aggregate_threshold), - external(ClientConfig::GetConfig(context).force_external) { - auto &children = expr.children; - arg_types.reserve(children.size()); - for (const auto &child : children) { - arg_types.emplace_back(child->return_type); - } - auto &order_bys = *expr.order_bys; - sort_types.reserve(order_bys.orders.size()); - for (auto &order : order_bys.orders) { - orders.emplace_back(order.Copy()); - sort_types.emplace_back(order.expression->return_type); - } - sorted_on_args = (children.size() == order_bys.orders.size()); - for (size_t i = 0; sorted_on_args && i < children.size(); ++i) { - sorted_on_args = children[i]->Equals(*order_bys.orders[i].expression); - } - } - - SortedAggregateBindData(const SortedAggregateBindData &other) - : buffer_manager(other.buffer_manager), function(other.function), arg_types(other.arg_types), - sort_types(other.sort_types), sorted_on_args(other.sorted_on_args), threshold(other.threshold), - 
external(other.external) { - if (other.bind_info) { - bind_info = other.bind_info->Copy(); - } - for (auto &order : other.orders) { - orders.emplace_back(order.Copy()); - } - } - - unique_ptr Copy() const override { - return make_uniq(*this); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - if (bind_info && other.bind_info) { - if (!bind_info->Equals(*other.bind_info)) { - return false; - } - } else if (bind_info || other.bind_info) { - return false; - } - if (function != other.function) { - return false; - } - if (orders.size() != other.orders.size()) { - return false; - } - for (size_t i = 0; i < orders.size(); ++i) { - if (!orders[i].Equals(other.orders[i])) { - return false; - } - } - return true; - } - - BufferManager &buffer_manager; - AggregateFunction function; - vector arg_types; - unique_ptr bind_info; - - vector orders; - vector sort_types; - bool sorted_on_args; - - //! The sort flush threshold - const idx_t threshold; - const bool external; -}; - -struct SortedAggregateState { - //! Default buffer size, optimised for small group to avoid blowing out memory. - static const idx_t BUFFER_CAPACITY = 16; - - SortedAggregateState() : count(0), nsel(0), offset(0) { - } - - static inline void InitializeBuffer(DataChunk &chunk, const vector &types) { - if (!chunk.ColumnCount() && !types.empty()) { - chunk.Initialize(Allocator::DefaultAllocator(), types, BUFFER_CAPACITY); - } - } - - //! 
Make sure the buffer is large enough for slicing - static inline void ResetBuffer(DataChunk &chunk, const vector &types) { - chunk.Reset(); - chunk.Destroy(); - chunk.Initialize(Allocator::DefaultAllocator(), types); - } - - void Flush(const SortedAggregateBindData &order_bind) { - if (ordering) { - return; - } - - ordering = make_uniq(order_bind.buffer_manager, order_bind.sort_types); - InitializeBuffer(sort_buffer, order_bind.sort_types); - ordering->Append(sort_buffer); - ResetBuffer(sort_buffer, order_bind.sort_types); - - if (!order_bind.sorted_on_args) { - arguments = make_uniq(order_bind.buffer_manager, order_bind.arg_types); - InitializeBuffer(arg_buffer, order_bind.arg_types); - arguments->Append(arg_buffer); - ResetBuffer(arg_buffer, order_bind.arg_types); - } - } - - void Update(const SortedAggregateBindData &order_bind, DataChunk &sort_chunk, DataChunk &arg_chunk) { - count += sort_chunk.size(); - - // Lazy instantiation of the buffer chunks - InitializeBuffer(sort_buffer, order_bind.sort_types); - if (!order_bind.sorted_on_args) { - InitializeBuffer(arg_buffer, order_bind.arg_types); - } - - if (sort_chunk.size() + sort_buffer.size() > STANDARD_VECTOR_SIZE) { - Flush(order_bind); - } - if (arguments) { - ordering->Append(sort_chunk); - arguments->Append(arg_chunk); - } else if (ordering) { - ordering->Append(sort_chunk); - } else if (order_bind.sorted_on_args) { - sort_buffer.Append(sort_chunk, true); - } else { - sort_buffer.Append(sort_chunk, true); - arg_buffer.Append(arg_chunk, true); - } - } - - void UpdateSlice(const SortedAggregateBindData &order_bind, DataChunk &sort_inputs, DataChunk &arg_inputs) { - count += nsel; - - // Lazy instantiation of the buffer chunks - InitializeBuffer(sort_buffer, order_bind.sort_types); - if (!order_bind.sorted_on_args) { - InitializeBuffer(arg_buffer, order_bind.arg_types); - } - - if (nsel + sort_buffer.size() > STANDARD_VECTOR_SIZE) { - Flush(order_bind); - } - if (arguments) { - sort_buffer.Reset(); - 
sort_buffer.Slice(sort_inputs, sel, nsel); - ordering->Append(sort_buffer); - - arg_buffer.Reset(); - arg_buffer.Slice(arg_inputs, sel, nsel); - arguments->Append(arg_buffer); - } else if (ordering) { - sort_buffer.Reset(); - sort_buffer.Slice(sort_inputs, sel, nsel); - ordering->Append(sort_buffer); - } else if (order_bind.sorted_on_args) { - sort_buffer.Append(sort_inputs, true, &sel, nsel); - } else { - sort_buffer.Append(sort_inputs, true, &sel, nsel); - arg_buffer.Append(arg_inputs, true, &sel, nsel); - } - - nsel = 0; - offset = 0; - } - - void Combine(SortedAggregateBindData &order_bind, SortedAggregateState &other) { - if (other.arguments) { - // Force CDC if the other has it - Flush(order_bind); - ordering->Combine(*other.ordering); - arguments->Combine(*other.arguments); - count += other.count; - } else if (other.ordering) { - // Force CDC if the other has it - Flush(order_bind); - ordering->Combine(*other.ordering); - count += other.count; - } else if (other.sort_buffer.size()) { - Update(order_bind, other.sort_buffer, other.arg_buffer); - } - } - - void PrefixSortBuffer(DataChunk &prefixed) { - for (column_t col_idx = 0; col_idx < sort_buffer.ColumnCount(); ++col_idx) { - prefixed.data[col_idx + 1].Reference(sort_buffer.data[col_idx]); - } - prefixed.SetCardinality(sort_buffer); - } - - void Finalize(const SortedAggregateBindData &order_bind, DataChunk &prefixed, LocalSortState &local_sort) { - if (arguments) { - ColumnDataScanState sort_state; - ordering->InitializeScan(sort_state); - ColumnDataScanState arg_state; - arguments->InitializeScan(arg_state); - for (sort_buffer.Reset(); ordering->Scan(sort_state, sort_buffer); sort_buffer.Reset()) { - PrefixSortBuffer(prefixed); - arg_buffer.Reset(); - arguments->Scan(arg_state, arg_buffer); - local_sort.SinkChunk(prefixed, arg_buffer); - } - ordering->Reset(); - arguments->Reset(); - } else if (ordering) { - ColumnDataScanState sort_state; - ordering->InitializeScan(sort_state); - for (sort_buffer.Reset(); 
ordering->Scan(sort_state, sort_buffer); sort_buffer.Reset()) { - PrefixSortBuffer(prefixed); - local_sort.SinkChunk(prefixed, sort_buffer); - } - ordering->Reset(); - } else if (order_bind.sorted_on_args) { - PrefixSortBuffer(prefixed); - local_sort.SinkChunk(prefixed, sort_buffer); - } else { - PrefixSortBuffer(prefixed); - local_sort.SinkChunk(prefixed, arg_buffer); - } - } - - idx_t count; - unique_ptr arguments; - unique_ptr ordering; - - DataChunk sort_buffer; - DataChunk arg_buffer; - - // Selection for scattering - SelectionVector sel; - idx_t nsel; - idx_t offset; -}; - -struct SortedAggregateFunction { - template - static void Initialize(STATE &state) { - new (&state) STATE(); - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - state.~STATE(); - } - - static void ProjectInputs(Vector inputs[], const SortedAggregateBindData &order_bind, idx_t input_count, - idx_t count, DataChunk &arg_chunk, DataChunk &sort_chunk) { - idx_t col = 0; - - if (!order_bind.sorted_on_args) { - arg_chunk.InitializeEmpty(order_bind.arg_types); - for (auto &dst : arg_chunk.data) { - dst.Reference(inputs[col++]); - } - arg_chunk.SetCardinality(count); - } - - sort_chunk.InitializeEmpty(order_bind.sort_types); - for (auto &dst : sort_chunk.data) { - dst.Reference(inputs[col++]); - } - sort_chunk.SetCardinality(count); - } - - static void SimpleUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, - idx_t count) { - const auto order_bind = aggr_input_data.bind_data->Cast(); - DataChunk arg_chunk; - DataChunk sort_chunk; - ProjectInputs(inputs, order_bind, input_count, count, arg_chunk, sort_chunk); - - const auto order_state = reinterpret_cast(state); - order_state->Update(order_bind, sort_chunk, arg_chunk); - } - - static void ScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, - idx_t count) { - if (!count) { - return; - } - - // Append the 
arguments to the two sub-collections - const auto &order_bind = aggr_input_data.bind_data->Cast(); - DataChunk arg_inputs; - DataChunk sort_inputs; - ProjectInputs(inputs, order_bind, input_count, count, arg_inputs, sort_inputs); - - // We have to scatter the chunks one at a time - // so build a selection vector for each one. - UnifiedVectorFormat svdata; - states.ToUnifiedFormat(count, svdata); - - // Size the selection vector for each state. - auto sdata = UnifiedVectorFormat::GetDataNoConst(svdata); - for (idx_t i = 0; i < count; ++i) { - auto sidx = svdata.sel->get_index(i); - auto order_state = sdata[sidx]; - order_state->nsel++; - } - - // Build the selection vector for each state. - vector sel_data(count); - idx_t start = 0; - for (idx_t i = 0; i < count; ++i) { - auto sidx = svdata.sel->get_index(i); - auto order_state = sdata[sidx]; - if (!order_state->offset) { - // First one - order_state->offset = start; - order_state->sel.Initialize(sel_data.data() + order_state->offset); - start += order_state->nsel; - } - sel_data[order_state->offset++] = sidx; - } - - // Append nonempty slices to the arguments - for (idx_t i = 0; i < count; ++i) { - auto sidx = svdata.sel->get_index(i); - auto order_state = sdata[sidx]; - if (!order_state->nsel) { - continue; - } - - order_state->UpdateSlice(order_bind, sort_inputs, arg_inputs); - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - auto &order_bind = aggr_input_data.bind_data->Cast(); - auto &other = const_cast(source); - target.Combine(order_bind, other); - } - - static void Window(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, - idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, - Vector &result, idx_t rid, idx_t bias) { - throw InternalException("Sorted aggregates should not be generated for window clauses"); - } - - static void Finalize(Vector &states, AggregateInputData 
&aggr_input_data, Vector &result, idx_t count, - const idx_t offset) { - auto &order_bind = aggr_input_data.bind_data->Cast(); - auto &buffer_manager = order_bind.buffer_manager; - RowLayout payload_layout; - payload_layout.Initialize(order_bind.arg_types); - DataChunk chunk; - chunk.Initialize(Allocator::DefaultAllocator(), order_bind.arg_types); - DataChunk sliced; - sliced.Initialize(Allocator::DefaultAllocator(), order_bind.arg_types); - - // Reusable inner state - vector agg_state(order_bind.function.state_size()); - Vector agg_state_vec(Value::POINTER(CastPointerToValue(agg_state.data()))); - - // State variables - auto bind_info = order_bind.bind_info.get(); - ArenaAllocator allocator(Allocator::DefaultAllocator()); - AggregateInputData aggr_bind_info(bind_info, allocator); - - // Inner aggregate APIs - auto initialize = order_bind.function.initialize; - auto destructor = order_bind.function.destructor; - auto simple_update = order_bind.function.simple_update; - auto update = order_bind.function.update; - auto finalize = order_bind.function.finalize; - - auto sdata = FlatVector::GetData(states); - - vector state_unprocessed(count, 0); - for (idx_t i = 0; i < count; ++i) { - state_unprocessed[i] = sdata[i]->count; - } - - // Sort the input payloads on (state_idx ASC, orders) - vector orders; - orders.emplace_back(BoundOrderByNode(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, - make_uniq(Value::USMALLINT(0)))); - for (const auto &order : order_bind.orders) { - orders.emplace_back(order.Copy()); - } - - auto global_sort = make_uniq(buffer_manager, orders, payload_layout); - global_sort->external = order_bind.external; - auto local_sort = make_uniq(); - local_sort->Initialize(*global_sort, global_sort->buffer_manager); - - DataChunk prefixed; - prefixed.Initialize(Allocator::DefaultAllocator(), global_sort->sort_layout.logical_types); - - // Go through the states accumulating values to sort until we hit the sort threshold - idx_t unsorted_count = 0; - 
idx_t sorted = 0; - for (idx_t finalized = 0; finalized < count;) { - if (unsorted_count < order_bind.threshold) { - auto state = sdata[finalized]; - prefixed.Reset(); - prefixed.data[0].Reference(Value::USMALLINT(finalized)); - state->Finalize(order_bind, prefixed, *local_sort); - unsorted_count += state_unprocessed[finalized]; - - // Go to the next aggregate unless this is the last one - if (++finalized < count) { - continue; - } - } - - // If they were all empty (filtering) flush them - // (This can only happen on the last range) - if (!unsorted_count) { - break; - } - - // Sort all the data - global_sort->AddLocalState(*local_sort); - global_sort->PrepareMergePhase(); - while (global_sort->sorted_blocks.size() > 1) { - global_sort->InitializeMergeRound(); - MergeSorter merge_sorter(*global_sort, global_sort->buffer_manager); - merge_sorter.PerformInMergeRound(); - global_sort->CompleteMergeRound(false); - } - - auto scanner = make_uniq(*global_sort); - initialize(agg_state.data()); - while (scanner->Remaining()) { - chunk.Reset(); - scanner->Scan(chunk); - idx_t consumed = 0; - - // Distribute the scanned chunk to the aggregates - while (consumed < chunk.size()) { - // Find the next aggregate that needs data - for (; !state_unprocessed[sorted]; ++sorted) { - // Finalize a single value at the next offset - agg_state_vec.SetVectorType(states.GetVectorType()); - finalize(agg_state_vec, aggr_bind_info, result, 1, sorted + offset); - if (destructor) { - destructor(agg_state_vec, aggr_bind_info, 1); - } - - initialize(agg_state.data()); - } - const auto input_count = MinValue(state_unprocessed[sorted], chunk.size() - consumed); - for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { - sliced.data[col_idx].Slice(chunk.data[col_idx], consumed, consumed + input_count); - } - sliced.SetCardinality(input_count); - - // These are all simple updates, so use it if available - if (simple_update) { - simple_update(sliced.data.data(), aggr_bind_info, 
sliced.data.size(), agg_state.data(), - sliced.size()); - } else { - // We are only updating a constant state - agg_state_vec.SetVectorType(VectorType::CONSTANT_VECTOR); - update(sliced.data.data(), aggr_bind_info, sliced.data.size(), agg_state_vec, sliced.size()); - } - - consumed += input_count; - state_unprocessed[sorted] -= input_count; - } - } - - // Finalize the last state for this sort - agg_state_vec.SetVectorType(states.GetVectorType()); - finalize(agg_state_vec, aggr_bind_info, result, 1, sorted + offset); - if (destructor) { - destructor(agg_state_vec, aggr_bind_info, 1); - } - ++sorted; - - // Stop if we are done - if (finalized >= count) { - break; - } - - // Create a new sort - scanner.reset(); - global_sort = make_uniq(buffer_manager, orders, payload_layout); - global_sort->external = order_bind.external; - local_sort = make_uniq(); - local_sort->Initialize(*global_sort, global_sort->buffer_manager); - unsorted_count = 0; - } - - for (; sorted < count; ++sorted) { - initialize(agg_state.data()); - - // Finalize a single value at the next offset - agg_state_vec.SetVectorType(states.GetVectorType()); - finalize(agg_state_vec, aggr_bind_info, result, 1, sorted + offset); - - if (destructor) { - destructor(agg_state_vec, aggr_bind_info, 1); - } - } - - result.Verify(count); - } -}; - -void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateExpression &expr, - const vector> &groups) { - if (!expr.order_bys || expr.order_bys->orders.empty() || expr.children.empty()) { - // not a sorted aggregate: return - return; - } - if (context.config.enable_optimizer) { - // for each ORDER BY - check if it is actually necessary - // expressions that are in the groups do not need to be ORDERED BY - // `ORDER BY` on a group has no effect, because for each aggregate, the group is unique - // similarly, we only need to ORDER BY each aggregate once - expression_set_t seen_expressions; - for (auto &target : groups) { - seen_expressions.insert(*target); 
- } - vector new_order_nodes; - for (auto &order_node : expr.order_bys->orders) { - if (seen_expressions.find(*order_node.expression) != seen_expressions.end()) { - // we do not need to order by this node - continue; - } - seen_expressions.insert(*order_node.expression); - new_order_nodes.push_back(std::move(order_node)); - } - if (new_order_nodes.empty()) { - expr.order_bys.reset(); - return; - } - expr.order_bys->orders = std::move(new_order_nodes); - } - auto &bound_function = expr.function; - auto &children = expr.children; - auto &order_bys = *expr.order_bys; - auto sorted_bind = make_uniq(context, expr); - - if (!sorted_bind->sorted_on_args) { - // The arguments are the children plus the sort columns. - for (auto &order : order_bys.orders) { - children.emplace_back(std::move(order.expression)); - } - } - - vector arguments; - arguments.reserve(children.size()); - for (const auto &child : children) { - arguments.emplace_back(child->return_type); - } - - // Replace the aggregate with the wrapper - AggregateFunction ordered_aggregate( - bound_function.name, arguments, bound_function.return_type, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - SortedAggregateFunction::ScatterUpdate, - AggregateFunction::StateCombine, - SortedAggregateFunction::Finalize, bound_function.null_handling, SortedAggregateFunction::SimpleUpdate, nullptr, - AggregateFunction::StateDestroy, nullptr, - SortedAggregateFunction::Window); - - expr.function = std::move(ordered_aggregate); - expr.bind_info = std::move(sorted_bind); - expr.order_bys.reset(); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -BuiltinFunctions::BuiltinFunctions(CatalogTransaction transaction, Catalog &catalog) - : transaction(transaction), catalog(catalog) { -} - -BuiltinFunctions::~BuiltinFunctions() { -} - -void BuiltinFunctions::AddCollation(string name, ScalarFunction function, bool combinable, - bool not_required_for_equality) { - CreateCollationInfo 
info(std::move(name), std::move(function), combinable, not_required_for_equality); - info.internal = true; - catalog.CreateCollation(transaction, info); -} - -void BuiltinFunctions::AddFunction(AggregateFunctionSet set) { - CreateAggregateFunctionInfo info(std::move(set)); - info.internal = true; - catalog.CreateFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(AggregateFunction function) { - CreateAggregateFunctionInfo info(std::move(function)); - info.internal = true; - catalog.CreateFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(PragmaFunction function) { - CreatePragmaFunctionInfo info(std::move(function)); - info.internal = true; - catalog.CreatePragmaFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(const string &name, PragmaFunctionSet functions) { - CreatePragmaFunctionInfo info(name, std::move(functions)); - info.internal = true; - catalog.CreatePragmaFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(ScalarFunction function) { - CreateScalarFunctionInfo info(std::move(function)); - info.internal = true; - catalog.CreateFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(const vector &names, ScalarFunction function) { // NOLINT: false positive - for (auto &name : names) { - function.name = name; - AddFunction(function); - } -} - -void BuiltinFunctions::AddFunction(ScalarFunctionSet set) { - CreateScalarFunctionInfo info(std::move(set)); - info.internal = true; - catalog.CreateFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(TableFunction function) { - CreateTableFunctionInfo info(std::move(function)); - info.internal = true; - catalog.CreateTableFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(TableFunctionSet set) { - CreateTableFunctionInfo info(std::move(set)); - info.internal = true; - catalog.CreateTableFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(CopyFunction function) { - 
CreateCopyFunctionInfo info(std::move(function)); - info.internal = true; - catalog.CreateCopyFunction(transaction, info); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -BoundCastInfo DefaultCasts::BitCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - // Numerics - case LogicalTypeId::BOOLEAN: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::TINYINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::SMALLINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::INTEGER: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::BIGINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::UTINYINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::USMALLINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::UINTEGER: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::UBIGINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::HUGEINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::FLOAT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::DOUBLE: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - - case LogicalTypeId::BLOB: - return BoundCastInfo(&VectorCastHelpers::StringCast); - - case LogicalTypeId::VARCHAR: - return BoundCastInfo(&VectorCastHelpers::StringCast); - - default: - return DefaultCasts::TryVectorNullCast; - } -} - -} // namespace duckdb - - - -namespace duckdb { - -BoundCastInfo DefaultCasts::BlobCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // blob to varchar - return 
BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::AGGREGATE_STATE: - return DefaultCasts::ReinterpretCast; - case LogicalTypeId::BIT: - return BoundCastInfo(&VectorCastHelpers::StringCast); - - default: - return DefaultCasts::TryVectorNullCast; - } -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -BindCastInput::BindCastInput(CastFunctionSet &function_set, optional_ptr info, - optional_ptr context) - : function_set(function_set), info(info), context(context) { -} - -BoundCastInfo BindCastInput::GetCastFunction(const LogicalType &source, const LogicalType &target) { - GetCastFunctionInput input(context); - return function_set.GetCastFunction(source, target, input); -} - -BindCastFunction::BindCastFunction(bind_cast_function_t function_p, unique_ptr info_p) - : function(function_p), info(std::move(info_p)) { -} - -CastFunctionSet::CastFunctionSet() : map_info(nullptr) { - bind_functions.emplace_back(DefaultCasts::GetDefaultCastFunction); -} - -CastFunctionSet &CastFunctionSet::Get(ClientContext &context) { - return DBConfig::GetConfig(context).GetCastFunctions(); -} - -CastFunctionSet &CastFunctionSet::Get(DatabaseInstance &db) { - return DBConfig::GetConfig(db).GetCastFunctions(); -} - -BoundCastInfo CastFunctionSet::GetCastFunction(const LogicalType &source, const LogicalType &target, - GetCastFunctionInput &get_input) { - if (source == target) { - return DefaultCasts::NopCast; - } - // the first function is the default - // we iterate the set of bind functions backwards - for (idx_t i = bind_functions.size(); i > 0; i--) { - auto &bind_function = bind_functions[i - 1]; - BindCastInput input(*this, bind_function.info.get(), get_input.context); - auto result = bind_function.function(input, source, target); - if (result.function) { - // found a cast function! 
return it - return result; - } - } - // no cast found: return the default null cast - return DefaultCasts::TryVectorNullCast; -} - -struct MapCastNode { - MapCastNode(BoundCastInfo info, int64_t implicit_cast_cost) - : cast_info(std::move(info)), bind_function(nullptr), implicit_cast_cost(implicit_cast_cost) { - } - MapCastNode(bind_cast_function_t func, int64_t implicit_cast_cost) - : cast_info(nullptr), bind_function(func), implicit_cast_cost(implicit_cast_cost) { - } - - BoundCastInfo cast_info; - bind_cast_function_t bind_function; - int64_t implicit_cast_cost; -}; - -template -static auto RelaxedTypeMatch(type_map_t &map, const LogicalType &type) -> decltype(map.find(type)) { - D_ASSERT(map.find(type) == map.end()); // we shouldn't be here - switch (type.id()) { - case LogicalTypeId::LIST: - return map.find(LogicalType::LIST(LogicalType::ANY)); - case LogicalTypeId::STRUCT: - return map.find(LogicalType::STRUCT({{"any", LogicalType::ANY}})); - case LogicalTypeId::MAP: - for (auto it = map.begin(); it != map.end(); it++) { - const auto &entry_type = it->first; - if (entry_type.id() != LogicalTypeId::MAP) { - continue; - } - auto &entry_key_type = MapType::KeyType(entry_type); - auto &entry_val_type = MapType::ValueType(entry_type); - if ((entry_key_type == LogicalType::ANY || entry_key_type == MapType::KeyType(type)) && - (entry_val_type == LogicalType::ANY || entry_val_type == MapType::ValueType(type))) { - return it; - } - } - return map.end(); - case LogicalTypeId::UNION: - return map.find(LogicalType::UNION({{"any", LogicalType::ANY}})); - default: - return map.find(LogicalType::ANY); - } -} - -struct MapCastInfo : public BindCastInfo { -public: - const optional_ptr GetEntry(const LogicalType &source, const LogicalType &target) { - auto source_type_id_entry = casts.find(source.id()); - if (source_type_id_entry == casts.end()) { - source_type_id_entry = casts.find(LogicalTypeId::ANY); - if (source_type_id_entry == casts.end()) { - return nullptr; - } - } - - 
auto &source_type_entries = source_type_id_entry->second; - auto source_type_entry = source_type_entries.find(source); - if (source_type_entry == source_type_entries.end()) { - source_type_entry = RelaxedTypeMatch(source_type_entries, source); - if (source_type_entry == source_type_entries.end()) { - return nullptr; - } - } - - auto &target_type_id_entries = source_type_entry->second; - auto target_type_id_entry = target_type_id_entries.find(target.id()); - if (target_type_id_entry == target_type_id_entries.end()) { - target_type_id_entry = target_type_id_entries.find(LogicalTypeId::ANY); - if (target_type_id_entry == target_type_id_entries.end()) { - return nullptr; - } - } - - auto &target_type_entries = target_type_id_entry->second; - auto target_type_entry = target_type_entries.find(target); - if (target_type_entry == target_type_entries.end()) { - target_type_entry = RelaxedTypeMatch(target_type_entries, target); - if (target_type_entry == target_type_entries.end()) { - return nullptr; - } - } - - return &target_type_entry->second; - } - - void AddEntry(const LogicalType &source, const LogicalType &target, MapCastNode node) { - casts[source.id()][source][target.id()].insert(make_pair(target, std::move(node))); - } - -private: - type_id_map_t>>> casts; -}; - -int64_t CastFunctionSet::ImplicitCastCost(const LogicalType &source, const LogicalType &target) { - // check if a cast has been registered - if (map_info) { - auto entry = map_info->GetEntry(source, target); - if (entry) { - return entry->implicit_cast_cost; - } - } - // if not, fallback to the default implicit cast rules - return CastRules::ImplicitCast(source, target); -} - -BoundCastInfo MapCastFunction(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - D_ASSERT(input.info); - auto &map_info = input.info->Cast(); - auto entry = map_info.GetEntry(source, target); - if (entry) { - if (entry->bind_function) { - return entry->bind_function(input, source, target); - } - return 
entry->cast_info.Copy(); - } - return nullptr; -} - -void CastFunctionSet::RegisterCastFunction(const LogicalType &source, const LogicalType &target, BoundCastInfo function, - int64_t implicit_cast_cost) { - RegisterCastFunction(source, target, MapCastNode(std::move(function), implicit_cast_cost)); -} - -void CastFunctionSet::RegisterCastFunction(const LogicalType &source, const LogicalType &target, - bind_cast_function_t bind_function, int64_t implicit_cast_cost) { - RegisterCastFunction(source, target, MapCastNode(bind_function, implicit_cast_cost)); -} - -void CastFunctionSet::RegisterCastFunction(const LogicalType &source, const LogicalType &target, MapCastNode node) { - if (!map_info) { - // create the cast map and the cast map function - auto info = make_uniq(); - map_info = info.get(); - bind_functions.emplace_back(MapCastFunction, std::move(info)); - } - map_info->AddEntry(source, target, std::move(node)); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -template -static bool FromDecimalCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &source_type = source.GetType(); - auto width = DecimalType::GetWidth(source_type); - auto scale = DecimalType::GetScale(source_type); - switch (source_type.InternalType()) { - case PhysicalType::INT16: - return VectorCastHelpers::TemplatedDecimalCast( - source, result, count, parameters.error_message, width, scale); - case PhysicalType::INT32: - return VectorCastHelpers::TemplatedDecimalCast( - source, result, count, parameters.error_message, width, scale); - case PhysicalType::INT64: - return VectorCastHelpers::TemplatedDecimalCast( - source, result, count, parameters.error_message, width, scale); - case PhysicalType::INT128: - return VectorCastHelpers::TemplatedDecimalCast( - source, result, count, parameters.error_message, width, scale); - default: - throw InternalException("Unimplemented internal type for decimal"); - } -} - -template -struct DecimalScaleInput { - 
DecimalScaleInput(Vector &result_p, FACTOR_TYPE factor_p) : result(result_p), factor(factor_p) { - } - DecimalScaleInput(Vector &result_p, LIMIT_TYPE limit_p, FACTOR_TYPE factor_p, string *error_message_p, - uint8_t source_width_p, uint8_t source_scale_p) - : result(result_p), limit(limit_p), factor(factor_p), error_message(error_message_p), - source_width(source_width_p), source_scale(source_scale_p) { - } - - Vector &result; - LIMIT_TYPE limit; - FACTOR_TYPE factor; - bool all_converted = true; - string *error_message; - uint8_t source_width; - uint8_t source_scale; -}; - -struct DecimalScaleUpOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { - auto data = (DecimalScaleInput *)dataptr; - return Cast::Operation(input) * data->factor; - } -}; - -struct DecimalScaleUpCheckOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { - auto data = (DecimalScaleInput *)dataptr; - if (input >= data->limit || input <= -data->limit) { - auto error = StringUtil::Format("Casting value \"%s\" to type %s failed: value is out of range!", - Decimal::ToString(input, data->source_width, data->source_scale), - data->result.GetType().ToString()); - return HandleVectorCastError::Operation(std::move(error), mask, idx, data->error_message, - data->all_converted); - } - return Cast::Operation(input) * data->factor; - } -}; - -template -bool TemplatedDecimalScaleUp(Vector &source, Vector &result, idx_t count, string *error_message) { - auto source_scale = DecimalType::GetScale(source.GetType()); - auto source_width = DecimalType::GetWidth(source.GetType()); - auto result_scale = DecimalType::GetScale(result.GetType()); - auto result_width = DecimalType::GetWidth(result.GetType()); - D_ASSERT(result_scale >= source_scale); - idx_t scale_difference = result_scale - source_scale; - DEST multiply_factor = POWERS_DEST::POWERS_OF_TEN[scale_difference]; - idx_t 
target_width = result_width - scale_difference; - if (source_width < target_width) { - DecimalScaleInput input(result, multiply_factor); - // type will always fit: no need to check limit - UnaryExecutor::GenericExecute(source, result, count, &input); - return true; - } else { - // type might not fit: check limit - auto limit = POWERS_SOURCE::POWERS_OF_TEN[target_width]; - DecimalScaleInput input(result, limit, multiply_factor, error_message, source_width, - source_scale); - UnaryExecutor::GenericExecute(source, result, count, &input, - error_message); - return input.all_converted; - } -} - -struct DecimalScaleDownOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { - auto data = (DecimalScaleInput *)dataptr; - return Cast::Operation(input / data->factor); - } -}; - -struct DecimalScaleDownCheckOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { - auto data = (DecimalScaleInput *)dataptr; - if (input >= data->limit || input <= -data->limit) { - auto error = StringUtil::Format("Casting value \"%s\" to type %s failed: value is out of range!", - Decimal::ToString(input, data->source_width, data->source_scale), - data->result.GetType().ToString()); - return HandleVectorCastError::Operation(std::move(error), mask, idx, data->error_message, - data->all_converted); - } - return Cast::Operation(input / data->factor); - } -}; - -template -bool TemplatedDecimalScaleDown(Vector &source, Vector &result, idx_t count, string *error_message) { - auto source_scale = DecimalType::GetScale(source.GetType()); - auto source_width = DecimalType::GetWidth(source.GetType()); - auto result_scale = DecimalType::GetScale(result.GetType()); - auto result_width = DecimalType::GetWidth(result.GetType()); - D_ASSERT(result_scale < source_scale); - idx_t scale_difference = source_scale - result_scale; - idx_t target_width = result_width + scale_difference; - SOURCE 
divide_factor = POWERS_SOURCE::POWERS_OF_TEN[scale_difference]; - if (source_width < target_width) { - DecimalScaleInput input(result, divide_factor); - // type will always fit: no need to check limit - UnaryExecutor::GenericExecute(source, result, count, &input); - return true; - } else { - // type might not fit: check limit - auto limit = POWERS_SOURCE::POWERS_OF_TEN[target_width]; - DecimalScaleInput input(result, limit, divide_factor, error_message, source_width, source_scale); - UnaryExecutor::GenericExecute(source, result, count, &input, - error_message); - return input.all_converted; - } -} - -template -static bool DecimalDecimalCastSwitch(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto source_scale = DecimalType::GetScale(source.GetType()); - auto result_scale = DecimalType::GetScale(result.GetType()); - source.GetType().Verify(); - result.GetType().Verify(); - - // we need to either multiply or divide by the difference in scales - if (result_scale >= source_scale) { - // multiply - switch (result.GetType().InternalType()) { - case PhysicalType::INT16: - return TemplatedDecimalScaleUp(source, result, count, - parameters.error_message); - case PhysicalType::INT32: - return TemplatedDecimalScaleUp(source, result, count, - parameters.error_message); - case PhysicalType::INT64: - return TemplatedDecimalScaleUp(source, result, count, - parameters.error_message); - case PhysicalType::INT128: - return TemplatedDecimalScaleUp(source, result, count, - parameters.error_message); - default: - throw NotImplementedException("Unimplemented internal type for decimal"); - } - } else { - // divide - switch (result.GetType().InternalType()) { - case PhysicalType::INT16: - return TemplatedDecimalScaleDown(source, result, count, - parameters.error_message); - case PhysicalType::INT32: - return TemplatedDecimalScaleDown(source, result, count, - parameters.error_message); - case PhysicalType::INT64: - return TemplatedDecimalScaleDown(source, result, 
count, - parameters.error_message); - case PhysicalType::INT128: - return TemplatedDecimalScaleDown(source, result, count, - parameters.error_message); - default: - throw NotImplementedException("Unimplemented internal type for decimal"); - } - } -} - -struct DecimalCastInput { - DecimalCastInput(Vector &result_p, uint8_t width_p, uint8_t scale_p) - : result(result_p), width(width_p), scale(scale_p) { - } - - Vector &result; - uint8_t width; - uint8_t scale; -}; - -struct StringCastFromDecimalOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { - auto data = reinterpret_cast(dataptr); - return StringCastFromDecimal::Operation(input, data->width, data->scale, data->result); - } -}; - -template -static bool DecimalToStringCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &source_type = source.GetType(); - auto width = DecimalType::GetWidth(source_type); - auto scale = DecimalType::GetScale(source_type); - DecimalCastInput input(result, width, scale); - - UnaryExecutor::GenericExecute(source, result, count, (void *)&input); - return true; -} - -BoundCastInfo DefaultCasts::DecimalCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::BOOLEAN: - return FromDecimalCast; - case LogicalTypeId::TINYINT: - return FromDecimalCast; - case LogicalTypeId::SMALLINT: - return FromDecimalCast; - case LogicalTypeId::INTEGER: - return FromDecimalCast; - case LogicalTypeId::BIGINT: - return FromDecimalCast; - case LogicalTypeId::UTINYINT: - return FromDecimalCast; - case LogicalTypeId::USMALLINT: - return FromDecimalCast; - case LogicalTypeId::UINTEGER: - return FromDecimalCast; - case LogicalTypeId::UBIGINT: - return FromDecimalCast; - case LogicalTypeId::HUGEINT: - return FromDecimalCast; - case LogicalTypeId::DECIMAL: { - // decimal to decimal cast - // first we need 
to figure out the source and target internal types - switch (source.InternalType()) { - case PhysicalType::INT16: - return DecimalDecimalCastSwitch; - case PhysicalType::INT32: - return DecimalDecimalCastSwitch; - case PhysicalType::INT64: - return DecimalDecimalCastSwitch; - case PhysicalType::INT128: - return DecimalDecimalCastSwitch; - default: - throw NotImplementedException("Unimplemented internal type for decimal in decimal_decimal cast"); - } - } - case LogicalTypeId::FLOAT: - return FromDecimalCast; - case LogicalTypeId::DOUBLE: - return FromDecimalCast; - case LogicalTypeId::VARCHAR: { - switch (source.InternalType()) { - case PhysicalType::INT16: - return DecimalToStringCast; - case PhysicalType::INT32: - return DecimalToStringCast; - case PhysicalType::INT64: - return DecimalToStringCast; - case PhysicalType::INT128: - return DecimalToStringCast; - default: - throw InternalException("Unimplemented internal decimal type"); - } - } - default: - return DefaultCasts::TryVectorNullCast; - } -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -BindCastInfo::~BindCastInfo() { -} - -BoundCastData::~BoundCastData() { -} - -BoundCastInfo::BoundCastInfo(cast_function_t function_p, unique_ptr cast_data_p, - init_cast_local_state_t init_local_state_p) - : function(function_p), init_local_state(init_local_state_p), cast_data(std::move(cast_data_p)) { -} - -BoundCastInfo BoundCastInfo::Copy() const { - return BoundCastInfo(function, cast_data ? 
cast_data->Copy() : nullptr, init_local_state); -} - -bool DefaultCasts::NopCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - result.Reference(source); - return true; -} - -static string UnimplementedCastMessage(const LogicalType &source_type, const LogicalType &target_type) { - return StringUtil::Format("Unimplemented type for cast (%s -> %s)", source_type.ToString(), target_type.ToString()); -} - -// NULL cast only works if all values in source are NULL, otherwise an unimplemented cast exception is thrown -bool DefaultCasts::TryVectorNullCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - bool success = true; - if (VectorOperations::HasNotNull(source, count)) { - HandleCastError::AssignError(UnimplementedCastMessage(source.GetType(), result.GetType()), - parameters.error_message); - success = false; - } - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return success; -} - -bool DefaultCasts::ReinterpretCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - result.Reinterpret(source); - return true; -} - -static bool AggregateStateToBlobCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - if (result.GetType().id() != LogicalTypeId::BLOB) { - throw TypeMismatchException(source.GetType(), result.GetType(), - "Cannot cast AGGREGATE_STATE to anything but BLOB"); - } - result.Reinterpret(source); - return true; -} - -static bool NullTypeCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - // cast a NULL to another type, just copy the properties and change the type - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return true; -} - -BoundCastInfo DefaultCasts::GetDefaultCastFunction(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - D_ASSERT(source != target); - - // first check if were casting to a union - if (source.id() 
!= LogicalTypeId::UNION && source.id() != LogicalTypeId::SQLNULL && - target.id() == LogicalTypeId::UNION) { - return ImplicitToUnionCast(input, source, target); - } - - // else, switch on source type - switch (source.id()) { - case LogicalTypeId::BOOLEAN: - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - return NumericCastSwitch(input, source, target); - case LogicalTypeId::POINTER: - return PointerCastSwitch(input, source, target); - case LogicalTypeId::UUID: - return UUIDCastSwitch(input, source, target); - case LogicalTypeId::DECIMAL: - return DecimalCastSwitch(input, source, target); - case LogicalTypeId::DATE: - return DateCastSwitch(input, source, target); - case LogicalTypeId::TIME: - return TimeCastSwitch(input, source, target); - case LogicalTypeId::TIME_TZ: - return TimeTzCastSwitch(input, source, target); - case LogicalTypeId::TIMESTAMP: - return TimestampCastSwitch(input, source, target); - case LogicalTypeId::TIMESTAMP_TZ: - return TimestampTzCastSwitch(input, source, target); - case LogicalTypeId::TIMESTAMP_NS: - return TimestampNsCastSwitch(input, source, target); - case LogicalTypeId::TIMESTAMP_MS: - return TimestampMsCastSwitch(input, source, target); - case LogicalTypeId::TIMESTAMP_SEC: - return TimestampSecCastSwitch(input, source, target); - case LogicalTypeId::INTERVAL: - return IntervalCastSwitch(input, source, target); - case LogicalTypeId::VARCHAR: - return StringCastSwitch(input, source, target); - case LogicalTypeId::BLOB: - return BlobCastSwitch(input, source, target); - case LogicalTypeId::BIT: - return BitCastSwitch(input, source, target); - case LogicalTypeId::SQLNULL: - return NullTypeCast; - case LogicalTypeId::MAP: - return 
MapCastSwitch(input, source, target); - case LogicalTypeId::STRUCT: - return StructCastSwitch(input, source, target); - case LogicalTypeId::LIST: - return ListCastSwitch(input, source, target); - case LogicalTypeId::UNION: - return UnionCastSwitch(input, source, target); - case LogicalTypeId::ENUM: - return EnumCastSwitch(input, source, target); - case LogicalTypeId::AGGREGATE_STATE: - return AggregateStateToBlobCast; - default: - return nullptr; - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -template -bool EnumEnumCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - bool all_converted = true; - result.SetVectorType(VectorType::FLAT_VECTOR); - - auto &str_vec = EnumType::GetValuesInsertOrder(source.GetType()); - auto str_vec_ptr = FlatVector::GetData(str_vec); - - auto res_enum_type = result.GetType(); - - UnifiedVectorFormat vdata; - source.ToUnifiedFormat(count, vdata); - - auto source_data = UnifiedVectorFormat::GetData(vdata); - auto source_sel = vdata.sel; - auto source_mask = vdata.validity; - - auto result_data = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - - for (idx_t i = 0; i < count; i++) { - auto src_idx = source_sel->get_index(i); - if (!source_mask.RowIsValid(src_idx)) { - result_mask.SetInvalid(i); - continue; - } - auto key = EnumType::GetPos(res_enum_type, str_vec_ptr[source_data[src_idx]]); - if (key == -1) { - // key doesn't exist on result enum - if (!parameters.error_message) { - result_data[i] = HandleVectorCastError::Operation( - CastExceptionText(source_data[src_idx]), result_mask, i, - parameters.error_message, all_converted); - } else { - result_mask.SetInvalid(i); - } - continue; - } - result_data[i] = key; - } - return all_converted; -} - -template -BoundCastInfo EnumEnumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - switch (target.InternalType()) { - case PhysicalType::UINT8: - return EnumEnumCast; - case 
PhysicalType::UINT16: - return EnumEnumCast; - case PhysicalType::UINT32: - return EnumEnumCast; - default: - throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); - } -} - -template -static bool EnumToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &enum_dictionary = EnumType::GetValuesInsertOrder(source.GetType()); - auto dictionary_data = FlatVector::GetData(enum_dictionary); - auto result_data = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - - UnifiedVectorFormat vdata; - source.ToUnifiedFormat(count, vdata); - - auto source_data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { - auto source_idx = vdata.sel->get_index(i); - if (!vdata.validity.RowIsValid(source_idx)) { - result_mask.SetInvalid(i); - continue; - } - auto enum_idx = source_data[source_idx]; - result_data[i] = dictionary_data[enum_idx]; - } - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } else { - result.SetVectorType(VectorType::FLAT_VECTOR); - } - return true; -} - -struct EnumBoundCastData : public BoundCastData { - EnumBoundCastData(BoundCastInfo to_varchar_cast, BoundCastInfo from_varchar_cast) - : to_varchar_cast(std::move(to_varchar_cast)), from_varchar_cast(std::move(from_varchar_cast)) { - } - - BoundCastInfo to_varchar_cast; - BoundCastInfo from_varchar_cast; - -public: - unique_ptr Copy() const override { - return make_uniq(to_varchar_cast.Copy(), from_varchar_cast.Copy()); - } -}; - -unique_ptr BindEnumCast(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - auto to_varchar_cast = input.GetCastFunction(source, LogicalType::VARCHAR); - auto from_varchar_cast = input.GetCastFunction(LogicalType::VARCHAR, target); - return make_uniq(std::move(to_varchar_cast), std::move(from_varchar_cast)); -} - -struct EnumCastLocalState : public 
FunctionLocalState { -public: - unique_ptr to_varchar_local; - unique_ptr from_varchar_local; -}; - -static unique_ptr InitEnumCastLocalState(CastLocalStateParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto result = make_uniq(); - - if (cast_data.from_varchar_cast.init_local_state) { - CastLocalStateParameters from_varchar_params(parameters, cast_data.from_varchar_cast.cast_data); - result->from_varchar_local = cast_data.from_varchar_cast.init_local_state(from_varchar_params); - } - if (cast_data.to_varchar_cast.init_local_state) { - CastLocalStateParameters from_varchar_params(parameters, cast_data.to_varchar_cast.cast_data); - result->from_varchar_local = cast_data.to_varchar_cast.init_local_state(from_varchar_params); - } - return std::move(result); -} - -static bool EnumToAnyCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto &lstate = parameters.local_state->Cast(); - - Vector varchar_cast(LogicalType::VARCHAR, count); - - // cast to varchar - CastParameters to_varchar_params(parameters, cast_data.to_varchar_cast.cast_data, lstate.to_varchar_local); - cast_data.to_varchar_cast.function(source, varchar_cast, count, to_varchar_params); - - // cast from varchar to the target - CastParameters from_varchar_params(parameters, cast_data.from_varchar_cast.cast_data, lstate.from_varchar_local); - cast_data.from_varchar_cast.function(varchar_cast, result, count, from_varchar_params); - return true; -} - -BoundCastInfo DefaultCasts::EnumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - auto enum_physical_type = source.InternalType(); - switch (target.id()) { - case LogicalTypeId::ENUM: { - // This means they are both ENUMs, but of different types. 
- switch (enum_physical_type) { - case PhysicalType::UINT8: - return EnumEnumCastSwitch(input, source, target); - case PhysicalType::UINT16: - return EnumEnumCastSwitch(input, source, target); - case PhysicalType::UINT32: - return EnumEnumCastSwitch(input, source, target); - default: - throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); - } - } - case LogicalTypeId::VARCHAR: - switch (enum_physical_type) { - case PhysicalType::UINT8: - return EnumToVarcharCast; - case PhysicalType::UINT16: - return EnumToVarcharCast; - case PhysicalType::UINT32: - return EnumToVarcharCast; - default: - throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); - } - default: { - return BoundCastInfo(EnumToAnyCast, BindEnumCast(input, source, target), InitEnumCastLocalState); - } - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr ListBoundCastData::BindListToListCast(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - vector child_cast_info; - auto &source_child_type = ListType::GetChildType(source); - auto &result_child_type = ListType::GetChildType(target); - auto child_cast = input.GetCastFunction(source_child_type, result_child_type); - return make_uniq(std::move(child_cast)); -} - -unique_ptr ListBoundCastData::InitListLocalState(CastLocalStateParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - if (!cast_data.child_cast_info.init_local_state) { - return nullptr; - } - CastLocalStateParameters child_parameters(parameters, cast_data.child_cast_info.cast_data); - return cast_data.child_cast_info.init_local_state(child_parameters); -} - -bool ListCast::ListToListCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - - // only handle constant and flat vectors here for now - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { - 
result.SetVectorType(source.GetVectorType()); - ConstantVector::SetNull(result, ConstantVector::IsNull(source)); - - auto ldata = ConstantVector::GetData(source); - auto tdata = ConstantVector::GetData(result); - *tdata = *ldata; - } else { - source.Flatten(count); - result.SetVectorType(VectorType::FLAT_VECTOR); - FlatVector::SetValidity(result, FlatVector::Validity(source)); - - auto ldata = FlatVector::GetData(source); - auto tdata = FlatVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - tdata[i] = ldata[i]; - } - } - auto &source_cc = ListVector::GetEntry(source); - auto source_size = ListVector::GetListSize(source); - - ListVector::Reserve(result, source_size); - auto &append_vector = ListVector::GetEntry(result); - - CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); - bool all_succeeded = cast_data.child_cast_info.function(source_cc, append_vector, source_size, child_parameters); - ListVector::SetListSize(result, source_size); - D_ASSERT(ListVector::GetListSize(result) == source_size); - return all_succeeded; -} - -static bool ListToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; - // first cast the child vector to varchar - Vector varchar_list(LogicalType::LIST(LogicalType::VARCHAR), count); - ListCast::ListToListCast(source, varchar_list, count, parameters); - - // now construct the actual varchar vector - varchar_list.Flatten(count); - auto &child = ListVector::GetEntry(varchar_list); - auto list_data = FlatVector::GetData(varchar_list); - auto &validity = FlatVector::Validity(varchar_list); - - child.Flatten(count); - auto child_data = FlatVector::GetData(child); - auto &child_validity = FlatVector::Validity(child); - - auto result_data = FlatVector::GetData(result); - static constexpr const idx_t SEP_LENGTH = 2; - static constexpr const idx_t NULL_LENGTH = 4; - for (idx_t i = 0; 
i < count; i++) { - if (!validity.RowIsValid(i)) { - FlatVector::SetNull(result, i, true); - continue; - } - auto list = list_data[i]; - // figure out how long the result needs to be - idx_t list_length = 2; // "[" and "]" - for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { - auto idx = list.offset + list_idx; - if (list_idx > 0) { - list_length += SEP_LENGTH; // ", " - } - // string length, or "NULL" - list_length += child_validity.RowIsValid(idx) ? child_data[idx].GetSize() : NULL_LENGTH; - } - result_data[i] = StringVector::EmptyString(result, list_length); - auto dataptr = result_data[i].GetDataWriteable(); - auto offset = 0; - dataptr[offset++] = '['; - for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { - auto idx = list.offset + list_idx; - if (list_idx > 0) { - memcpy(dataptr + offset, ", ", SEP_LENGTH); - offset += SEP_LENGTH; - } - if (child_validity.RowIsValid(idx)) { - auto len = child_data[idx].GetSize(); - memcpy(dataptr + offset, child_data[idx].GetData(), len); - offset += len; - } else { - memcpy(dataptr + offset, "NULL", NULL_LENGTH); - offset += NULL_LENGTH; - } - } - dataptr[offset] = ']'; - result_data[i].Finalize(); - } - - if (constant) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - return true; -} - -BoundCastInfo DefaultCasts::ListCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - switch (target.id()) { - case LogicalTypeId::LIST: - return BoundCastInfo(ListCast::ListToListCast, ListBoundCastData::BindListToListCast(input, source, target), - ListBoundCastData::InitListLocalState); - case LogicalTypeId::VARCHAR: - return BoundCastInfo( - ListToVarcharCast, - ListBoundCastData::BindListToListCast(input, source, LogicalType::LIST(LogicalType::VARCHAR)), - ListBoundCastData::InitListLocalState); - default: - return DefaultCasts::TryVectorNullCast; - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr MapBoundCastData::BindMapToMapCast(BindCastInput 
&input, const LogicalType &source, - const LogicalType &target) { - vector child_cast_info; - auto source_key = MapType::KeyType(source); - auto target_key = MapType::KeyType(target); - auto source_val = MapType::ValueType(source); - auto target_val = MapType::ValueType(target); - auto key_cast = input.GetCastFunction(source_key, target_key); - auto value_cast = input.GetCastFunction(source_val, target_val); - return make_uniq(std::move(key_cast), std::move(value_cast)); -} - -static bool MapToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; - auto varchar_type = LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR); - Vector varchar_map(varchar_type, count); - - // since map's physical type is a list, the ListCast can be utilized - ListCast::ListToListCast(source, varchar_map, count, parameters); - - varchar_map.Flatten(count); - auto &validity = FlatVector::Validity(varchar_map); - auto &key_str = MapVector::GetKeys(varchar_map); - auto &val_str = MapVector::GetValues(varchar_map); - - key_str.Flatten(ListVector::GetListSize(source)); - val_str.Flatten(ListVector::GetListSize(source)); - - auto list_data = ListVector::GetData(varchar_map); - auto key_data = FlatVector::GetData(key_str); - auto val_data = FlatVector::GetData(val_str); - auto &key_validity = FlatVector::Validity(key_str); - auto &val_validity = FlatVector::Validity(val_str); - auto &struct_validity = FlatVector::Validity(ListVector::GetEntry(varchar_map)); - - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - if (!validity.RowIsValid(i)) { - FlatVector::SetNull(result, i, true); - continue; - } - auto list = list_data[i]; - string ret = "{"; - for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { - if (list_idx > 0) { - ret += ", "; - } - auto idx = list.offset + list_idx; - - if (!struct_validity.RowIsValid(idx)) { - ret += "NULL"; - 
continue; - } - if (!key_validity.RowIsValid(idx)) { - // throw InternalException("Error in map: key validity invalid?!"); - ret += "invalid"; - continue; - } - ret += key_data[idx].GetString(); - ret += "="; - ret += val_validity.RowIsValid(idx) ? val_data[idx].GetString() : "NULL"; - } - ret += "}"; - result_data[i] = StringVector::AddString(result, ret); - } - - if (constant) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - return true; -} - -BoundCastInfo DefaultCasts::MapCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - switch (target.id()) { - case LogicalTypeId::MAP: - return BoundCastInfo(ListCast::ListToListCast, ListBoundCastData::BindListToListCast(input, source, target), - ListBoundCastData::InitListLocalState); - case LogicalTypeId::VARCHAR: { - auto varchar_type = LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR); - return BoundCastInfo(MapToVarcharCast, ListBoundCastData::BindListToListCast(input, source, varchar_type), - ListBoundCastData::InitListLocalState); - } - default: - return TryVectorNullCast; - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -template -static BoundCastInfo InternalNumericCastSwitch(const LogicalType &source, const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::BOOLEAN: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::TINYINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::SMALLINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::INTEGER: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::BIGINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::UTINYINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::USMALLINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::UINTEGER: - return 
BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::UBIGINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::HUGEINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::FLOAT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::DOUBLE: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::DECIMAL: - return BoundCastInfo(&VectorCastHelpers::ToDecimalCast); - case LogicalTypeId::VARCHAR: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::BIT: - return BoundCastInfo(&VectorCastHelpers::StringCast); - default: - return DefaultCasts::TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::NumericCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - switch (source.id()) { - case LogicalTypeId::BOOLEAN: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::TINYINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::SMALLINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::INTEGER: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::BIGINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::UTINYINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::USMALLINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::UINTEGER: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::UBIGINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::HUGEINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::FLOAT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::DOUBLE: - return InternalNumericCastSwitch(source, target); - default: - throw InternalException("NumericCastSwitch called with non-numeric argument"); - } -} - -} // namespace duckdb - - - 
-namespace duckdb { - -BoundCastInfo DefaultCasts::PointerCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // pointer to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - default: - return nullptr; - } -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -template -bool StringEnumCastLoop(const string_t *source_data, ValidityMask &source_mask, const LogicalType &source_type, - T *result_data, ValidityMask &result_mask, const LogicalType &result_type, idx_t count, - string *error_message, const SelectionVector *sel) { - bool all_converted = true; - for (idx_t i = 0; i < count; i++) { - idx_t source_idx = i; - if (sel) { - source_idx = sel->get_index(i); - } - if (source_mask.RowIsValid(source_idx)) { - auto pos = EnumType::GetPos(result_type, source_data[source_idx]); - if (pos == -1) { - result_data[i] = - HandleVectorCastError::Operation(CastExceptionText(source_data[source_idx]), - result_mask, i, error_message, all_converted); - } else { - result_data[i] = pos; - } - } else { - result_mask.SetInvalid(i); - } - } - return all_converted; -} - -template -bool StringEnumCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - D_ASSERT(source.GetType().id() == LogicalTypeId::VARCHAR); - switch (source.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - - auto source_data = ConstantVector::GetData(source); - auto source_mask = ConstantVector::Validity(source); - auto result_data = ConstantVector::GetData(result); - auto &result_mask = ConstantVector::Validity(result); - - return StringEnumCastLoop(source_data, source_mask, source.GetType(), result_data, result_mask, - result.GetType(), 1, parameters.error_message, nullptr); - } - default: { - UnifiedVectorFormat vdata; - source.ToUnifiedFormat(count, vdata); - - 
result.SetVectorType(VectorType::FLAT_VECTOR); - - auto source_data = UnifiedVectorFormat::GetData(vdata); - auto source_sel = vdata.sel; - auto source_mask = vdata.validity; - auto result_data = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - - return StringEnumCastLoop(source_data, source_mask, source.GetType(), result_data, result_mask, - result.GetType(), count, parameters.error_message, source_sel); - } - } -} - -static BoundCastInfo VectorStringCastNumericSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::ENUM: { - switch (target.InternalType()) { - case PhysicalType::UINT8: - return StringEnumCast; - case PhysicalType::UINT16: - return StringEnumCast; - case PhysicalType::UINT32: - return StringEnumCast; - default: - throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); - } - } - case LogicalTypeId::BOOLEAN: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::TINYINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::SMALLINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::INTEGER: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::BIGINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::UTINYINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::USMALLINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::UINTEGER: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::UBIGINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::HUGEINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::FLOAT: - return 
BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::DOUBLE: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::INTERVAL: - return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); - case LogicalTypeId::DECIMAL: - return BoundCastInfo(&VectorCastHelpers::ToDecimalCast); - default: - return DefaultCasts::TryVectorNullCast; - } -} - -//===--------------------------------------------------------------------===// -// string -> list casting -//===--------------------------------------------------------------------===// -bool VectorStringToList::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, - Vector &result, ValidityMask &result_mask, idx_t count, - CastParameters ¶meters, const SelectionVector *sel) { - idx_t total_list_size = 0; - for (idx_t i = 0; i < count; i++) { - idx_t idx = i; - if (sel) { - idx = sel->get_index(i); - } - if (!source_mask.RowIsValid(idx)) { - continue; - } - total_list_size += VectorStringToList::CountPartsList(source_data[idx]); - } - - Vector varchar_vector(LogicalType::VARCHAR, total_list_size); - - ListVector::Reserve(result, total_list_size); - ListVector::SetListSize(result, total_list_size); - - auto list_data = ListVector::GetData(result); - auto child_data = FlatVector::GetData(varchar_vector); - - bool all_converted = true; - idx_t total = 0; - for (idx_t i = 0; i < count; i++) { - idx_t idx = i; - if (sel) { - idx = sel->get_index(i); - } - if (!source_mask.RowIsValid(idx)) { - result_mask.SetInvalid(i); - continue; - } - - list_data[i].offset = total; - if (!VectorStringToList::SplitStringList(source_data[idx], child_data, total, varchar_vector)) { - string text = "Type VARCHAR with value '" + source_data[idx].GetString() + - "' can't be cast to the destination type LIST"; - HandleVectorCastError::Operation(text, result_mask, idx, parameters.error_message, all_converted); - } - list_data[i].length = total - list_data[i].offset; // length 
is the amount of parts coming from this string - } - D_ASSERT(total_list_size == total); - - auto &result_child = ListVector::GetEntry(result); - auto &cast_data = parameters.cast_data->Cast(); - CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); - return cast_data.child_cast_info.function(varchar_vector, result_child, total_list_size, child_parameters) && - all_converted; -} - -static LogicalType InitVarcharStructType(const LogicalType &target) { - child_list_t child_types; - for (auto &child : StructType::GetChildTypes(target)) { - child_types.push_back(make_pair(child.first, LogicalType::VARCHAR)); - } - - return LogicalType::STRUCT(child_types); -} - -//===--------------------------------------------------------------------===// -// string -> struct casting -//===--------------------------------------------------------------------===// -bool VectorStringToStruct::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, - Vector &result, ValidityMask &result_mask, idx_t count, - CastParameters ¶meters, const SelectionVector *sel) { - auto varchar_struct_type = InitVarcharStructType(result.GetType()); - Vector varchar_vector(varchar_struct_type, count); - auto &child_vectors = StructVector::GetEntries(varchar_vector); - auto &result_children = StructVector::GetEntries(result); - - string_map_t child_names; - vector child_masks; - for (idx_t child_idx = 0; child_idx < result_children.size(); child_idx++) { - child_names.insert({StructType::GetChildName(result.GetType(), child_idx), child_idx}); - child_masks.emplace_back(&FlatVector::Validity(*child_vectors[child_idx])); - child_masks[child_idx]->SetAllInvalid(count); - } - - bool all_converted = true; - for (idx_t i = 0; i < count; i++) { - idx_t idx = i; - if (sel) { - idx = sel->get_index(i); - } - if (!source_mask.RowIsValid(idx)) { - result_mask.SetInvalid(i); - continue; - } - if 
(!VectorStringToStruct::SplitStruct(source_data[idx], child_vectors, i, child_names, child_masks)) { - string text = "Type VARCHAR with value '" + source_data[idx].GetString() + - "' can't be cast to the destination type STRUCT"; - for (auto &child_mask : child_masks) { - child_mask->SetInvalid(idx); // some values may have already been found and set valid - } - HandleVectorCastError::Operation(text, result_mask, idx, parameters.error_message, all_converted); - } - } - - auto &cast_data = parameters.cast_data->Cast(); - auto &lstate = parameters.local_state->Cast(); - D_ASSERT(cast_data.child_cast_info.size() == result_children.size()); - - for (idx_t child_idx = 0; child_idx < result_children.size(); child_idx++) { - auto &child_varchar_vector = *child_vectors[child_idx]; - auto &result_child_vector = *result_children[child_idx]; - auto &child_cast_info = cast_data.child_cast_info[child_idx]; - CastParameters child_parameters(parameters, child_cast_info.cast_data, lstate.local_states[child_idx]); - if (!child_cast_info.function(child_varchar_vector, result_child_vector, count, child_parameters)) { - all_converted = false; - } - } - return all_converted; -} - -//===--------------------------------------------------------------------===// -// string -> map casting -//===--------------------------------------------------------------------===// -unique_ptr InitMapCastLocalState(CastLocalStateParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto result = make_uniq(); - - if (cast_data.key_cast.init_local_state) { - CastLocalStateParameters child_params(parameters, cast_data.key_cast.cast_data); - result->key_state = cast_data.key_cast.init_local_state(child_params); - } - if (cast_data.value_cast.init_local_state) { - CastLocalStateParameters child_params(parameters, cast_data.value_cast.cast_data); - result->value_state = cast_data.value_cast.init_local_state(child_params); - } - return std::move(result); -} - -bool 
VectorStringToMap::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, - Vector &result, ValidityMask &result_mask, idx_t count, - CastParameters ¶meters, const SelectionVector *sel) { - idx_t total_elements = 0; - for (idx_t i = 0; i < count; i++) { - idx_t idx = i; - if (sel) { - idx = sel->get_index(i); - } - if (!source_mask.RowIsValid(idx)) { - continue; - } - total_elements += (VectorStringToMap::CountPartsMap(source_data[idx]) + 1) / 2; - } - - Vector varchar_key_vector(LogicalType::VARCHAR, total_elements); - Vector varchar_val_vector(LogicalType::VARCHAR, total_elements); - auto child_key_data = FlatVector::GetData(varchar_key_vector); - auto child_val_data = FlatVector::GetData(varchar_val_vector); - - ListVector::Reserve(result, total_elements); - ListVector::SetListSize(result, total_elements); - auto list_data = ListVector::GetData(result); - - bool all_converted = true; - idx_t total = 0; - for (idx_t i = 0; i < count; i++) { - idx_t idx = i; - if (sel) { - idx = sel->get_index(i); - } - if (!source_mask.RowIsValid(idx)) { - result_mask.SetInvalid(idx); - continue; - } - - list_data[i].offset = total; - if (!VectorStringToMap::SplitStringMap(source_data[idx], child_key_data, child_val_data, total, - varchar_key_vector, varchar_val_vector)) { - string text = "Type VARCHAR with value '" + source_data[idx].GetString() + - "' can't be cast to the destination type MAP"; - FlatVector::SetNull(result, idx, true); - HandleVectorCastError::Operation(text, result_mask, idx, parameters.error_message, all_converted); - } - list_data[i].length = total - list_data[i].offset; - } - D_ASSERT(total_elements == total); - - auto &result_key_child = MapVector::GetKeys(result); - auto &result_val_child = MapVector::GetValues(result); - auto &cast_data = parameters.cast_data->Cast(); - auto &lstate = parameters.local_state->Cast(); - - CastParameters key_params(parameters, cast_data.key_cast.cast_data, lstate.key_state); - if 
(!cast_data.key_cast.function(varchar_key_vector, result_key_child, total_elements, key_params)) { - all_converted = false; - } - CastParameters val_params(parameters, cast_data.value_cast.cast_data, lstate.value_state); - if (!cast_data.value_cast.function(varchar_val_vector, result_val_child, total_elements, val_params)) { - all_converted = false; - } - - auto &key_validity = FlatVector::Validity(result_key_child); - if (!all_converted) { - for (idx_t row_idx = 0; row_idx < count; row_idx++) { - if (!result_mask.RowIsValid(row_idx)) { - continue; - } - auto list = list_data[row_idx]; - for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { - auto idx = list.offset + list_idx; - if (!key_validity.RowIsValid(idx)) { - result_mask.SetInvalid(row_idx); - } - } - } - } - MapVector::MapConversionVerify(result, count); - return all_converted; -} - -template -bool StringToNestedTypeCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - D_ASSERT(source.GetType().id() == LogicalTypeId::VARCHAR); - - switch (source.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - auto source_data = ConstantVector::GetData(source); - auto &source_mask = ConstantVector::Validity(source); - auto &result_mask = FlatVector::Validity(result); - auto ret = T::StringToNestedTypeCastLoop(source_data, source_mask, result, result_mask, 1, parameters, nullptr); - result.SetVectorType(VectorType::CONSTANT_VECTOR); - return ret; - } - default: { - UnifiedVectorFormat unified_source; - - source.ToUnifiedFormat(count, unified_source); - auto source_sel = unified_source.sel; - auto source_data = UnifiedVectorFormat::GetData(unified_source); - auto &source_mask = unified_source.validity; - auto &result_mask = FlatVector::Validity(result); - - return T::StringToNestedTypeCastLoop(source_data, source_mask, result, result_mask, count, parameters, - source_sel); - } - } -} - -BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const LogicalType &source, - 
const LogicalType &target) { - switch (target.id()) { - case LogicalTypeId::DATE: - return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); - case LogicalTypeId::TIME: - return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); - case LogicalTypeId::TIME_TZ: - return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); - case LogicalTypeId::TIMESTAMP_NS: - return BoundCastInfo( - &VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::TIMESTAMP_SEC: - return BoundCastInfo( - &VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::TIMESTAMP_MS: - return BoundCastInfo( - &VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::BLOB: - return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); - case LogicalTypeId::BIT: - return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); - case LogicalTypeId::UUID: - return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); - case LogicalTypeId::SQLNULL: - return &DefaultCasts::TryVectorNullCast; - case LogicalTypeId::VARCHAR: - return &DefaultCasts::ReinterpretCast; - case LogicalTypeId::LIST: - // the second argument allows for a secondary casting function to be passed in the CastParameters - return BoundCastInfo( - &StringToNestedTypeCast, - ListBoundCastData::BindListToListCast(input, LogicalType::LIST(LogicalType::VARCHAR), target), - ListBoundCastData::InitListLocalState); - case LogicalTypeId::STRUCT: - return BoundCastInfo(&StringToNestedTypeCast, - StructBoundCastData::BindStructToStructCast(input, InitVarcharStructType(target), target), - StructBoundCastData::InitStructCastLocalState); - case LogicalTypeId::MAP: - return BoundCastInfo(&StringToNestedTypeCast, - MapBoundCastData::BindMapToMapCast( - input, LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR), target), - InitMapCastLocalState); - default: - return VectorStringCastNumericSwitch(input, 
source, target); - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr StructBoundCastData::BindStructToStructCast(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - vector child_cast_info; - auto &source_child_types = StructType::GetChildTypes(source); - auto &result_child_types = StructType::GetChildTypes(target); - - auto target_is_unnamed = StructType::IsUnnamed(target); - auto source_is_unnamed = StructType::IsUnnamed(source); - - if (source_child_types.size() != result_child_types.size()) { - throw TypeMismatchException(source, target, "Cannot cast STRUCTs of different size"); - } - for (idx_t i = 0; i < source_child_types.size(); i++) { - if (!target_is_unnamed && !source_is_unnamed && - !StringUtil::CIEquals(source_child_types[i].first, result_child_types[i].first)) { - throw TypeMismatchException(source, target, "Cannot cast STRUCTs with different names"); - } - auto child_cast = input.GetCastFunction(source_child_types[i].second, result_child_types[i].second); - child_cast_info.push_back(std::move(child_cast)); - } - return make_uniq(std::move(child_cast_info), target); -} - -unique_ptr StructBoundCastData::InitStructCastLocalState(CastLocalStateParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto result = make_uniq(); - - for (auto &entry : cast_data.child_cast_info) { - unique_ptr child_state; - if (entry.init_local_state) { - CastLocalStateParameters child_params(parameters, entry.cast_data); - child_state = entry.init_local_state(child_params); - } - result->local_states.push_back(std::move(child_state)); - } - return std::move(result); -} - -static bool StructToStructCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto &lstate = parameters.local_state->Cast(); - auto &source_child_types = StructType::GetChildTypes(source.GetType()); - auto &source_children = StructVector::GetEntries(source); - 
D_ASSERT(source_children.size() == StructType::GetChildTypes(result.GetType()).size()); - - auto &result_children = StructVector::GetEntries(result); - bool all_converted = true; - for (idx_t c_idx = 0; c_idx < source_child_types.size(); c_idx++) { - auto &result_child_vector = *result_children[c_idx]; - auto &source_child_vector = *source_children[c_idx]; - CastParameters child_parameters(parameters, cast_data.child_cast_info[c_idx].cast_data, - lstate.local_states[c_idx]); - if (!cast_data.child_cast_info[c_idx].function(source_child_vector, result_child_vector, count, - child_parameters)) { - all_converted = false; - } - } - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, ConstantVector::IsNull(source)); - } else { - source.Flatten(count); - FlatVector::Validity(result) = FlatVector::Validity(source); - } - return all_converted; -} - -static bool StructToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; - // first cast all child elements to varchar - auto &cast_data = parameters.cast_data->Cast(); - Vector varchar_struct(cast_data.target, count); - StructToStructCast(source, varchar_struct, count, parameters); - - // now construct the actual varchar vector - varchar_struct.Flatten(count); - auto &child_types = StructType::GetChildTypes(source.GetType()); - auto &children = StructVector::GetEntries(varchar_struct); - auto &validity = FlatVector::Validity(varchar_struct); - auto result_data = FlatVector::GetData(result); - static constexpr const idx_t SEP_LENGTH = 2; - static constexpr const idx_t NAME_SEP_LENGTH = 4; - static constexpr const idx_t NULL_LENGTH = 4; - for (idx_t i = 0; i < count; i++) { - if (!validity.RowIsValid(i)) { - FlatVector::SetNull(result, i, true); - continue; - } - idx_t string_length = 2; // {} - for (idx_t c = 0; c < children.size(); 
c++) { - if (c > 0) { - string_length += SEP_LENGTH; - } - children[c]->Flatten(count); - auto &child_validity = FlatVector::Validity(*children[c]); - auto data = FlatVector::GetData(*children[c]); - auto &name = child_types[c].first; - string_length += name.size() + NAME_SEP_LENGTH; // "'{name}': " - string_length += child_validity.RowIsValid(i) ? data[i].GetSize() : NULL_LENGTH; - } - result_data[i] = StringVector::EmptyString(result, string_length); - auto dataptr = result_data[i].GetDataWriteable(); - idx_t offset = 0; - dataptr[offset++] = '{'; - for (idx_t c = 0; c < children.size(); c++) { - if (c > 0) { - memcpy(dataptr + offset, ", ", SEP_LENGTH); - offset += SEP_LENGTH; - } - auto &child_validity = FlatVector::Validity(*children[c]); - auto data = FlatVector::GetData(*children[c]); - auto &name = child_types[c].first; - // "'{name}': " - dataptr[offset++] = '\''; - memcpy(dataptr + offset, name.c_str(), name.size()); - offset += name.size(); - dataptr[offset++] = '\''; - dataptr[offset++] = ':'; - dataptr[offset++] = ' '; - // value - if (child_validity.RowIsValid(i)) { - auto len = data[i].GetSize(); - memcpy(dataptr + offset, data[i].GetData(), len); - offset += len; - } else { - memcpy(dataptr + offset, "NULL", NULL_LENGTH); - offset += NULL_LENGTH; - } - } - dataptr[offset++] = '}'; - result_data[i].Finalize(); - } - - if (constant) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - return true; -} - -BoundCastInfo DefaultCasts::StructCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - switch (target.id()) { - case LogicalTypeId::STRUCT: - return BoundCastInfo(StructToStructCast, StructBoundCastData::BindStructToStructCast(input, source, target), - StructBoundCastData::InitStructCastLocalState); - case LogicalTypeId::VARCHAR: { - // bind a cast in which we convert all child entries to VARCHAR entries - auto &struct_children = StructType::GetChildTypes(source); - child_list_t varchar_children; - for 
(auto &child_entry : struct_children) { - varchar_children.push_back(make_pair(child_entry.first, LogicalType::VARCHAR)); - } - auto varchar_type = LogicalType::STRUCT(varchar_children); - return BoundCastInfo(StructToVarcharCast, - StructBoundCastData::BindStructToStructCast(input, source, varchar_type), - StructBoundCastData::InitStructCastLocalState); - } - default: - return TryVectorNullCast; - } -} - -} // namespace duckdb - - - -namespace duckdb { - -BoundCastInfo DefaultCasts::DateCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // date to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - // date to timestamp - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::TIMESTAMP_NS: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::TIMESTAMP_SEC: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::TIMESTAMP_MS: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimeCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // time to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::TIME_TZ: - // time to time with time zone - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimeTzCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // time with time zone to varchar - return 
BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::TIME: - // time with time zone to time - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimestampCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // timestamp to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::DATE: - // timestamp to date - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIME: - // timestamp to time - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIME_TZ: - // timestamp to time_tz - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_TZ: - // timestamp (us) to timestamp with time zone - return ReinterpretCast; - case LogicalTypeId::TIMESTAMP_NS: - // timestamp (us) to timestamp (ns) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_MS: - // timestamp (us) to timestamp (ms) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_SEC: - // timestamp (us) to timestamp (s) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimestampTzCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // timestamp with time zone to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::TIME_TZ: - // timestamp with time zone to time with time zone. 
- return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP: - // timestamp with time zone to timestamp (us) - return ReinterpretCast; - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimestampNsCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // timestamp (ns) to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::TIMESTAMP: - // timestamp (ns) to timestamp (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimestampMsCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // timestamp (ms) to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::TIMESTAMP: - // timestamp (ms) to timestamp (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_NS: - // timestamp (ms) to timestamp (ns) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimestampSecCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // timestamp (sec) to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::TIMESTAMP_MS: - // timestamp (s) to timestamp (ms) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP: - // timestamp (s) to timestamp (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_NS: - // timestamp (s) to timestamp 
(ns) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - default: - return TryVectorNullCast; - } -} -BoundCastInfo DefaultCasts::IntervalCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // time to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - default: - return TryVectorNullCast; - } -} - -} // namespace duckdb - - -namespace duckdb { - -bool StructToUnionCast::AllowImplicitCastFromStruct(const LogicalType &source, const LogicalType &target) { - if (source.id() != LogicalTypeId::STRUCT) { - return false; - } - auto target_fields = StructType::GetChildTypes(target); - auto fields = StructType::GetChildTypes(source); - if (target_fields.size() != fields.size()) { - // Struct should have the same amount of fields as the union - return false; - } - for (idx_t i = 0; i < target_fields.size(); i++) { - auto &target_field = target_fields[i].second; - auto &target_field_name = target_fields[i].first; - auto &field = fields[i].second; - auto &field_name = fields[i].first; - if (i == 0) { - // For the tag field we don't accept a type substitute as varchar - if (target_field != field) { - return false; - } - continue; - } - if (!StringUtil::CIEquals(target_field_name, field_name)) { - return false; - } - if (target_field != field && field != LogicalType::VARCHAR) { - // We allow the field to be VARCHAR, since unsupported types get cast to VARCHAR by EXPORT DATABASE (format - // PARQUET) i.e UNION(a BIT) becomes STRUCT(a VARCHAR) - return false; - } - } - return true; -} - -// Physical Cast execution - -bool StructToUnionCast::Cast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto &lstate = parameters.local_state->Cast(); - - D_ASSERT(source.GetType().id() == LogicalTypeId::STRUCT); - D_ASSERT(result.GetType().id() == 
LogicalTypeId::UNION); - D_ASSERT(cast_data.target.id() == LogicalTypeId::UNION); - - auto &source_children = StructVector::GetEntries(source); - auto &target_children = StructVector::GetEntries(result); - - for (idx_t i = 0; i < source_children.size(); i++) { - auto &result_child_vector = *target_children[i]; - auto &source_child_vector = *source_children[i]; - CastParameters child_parameters(parameters, cast_data.child_cast_info[i].cast_data, lstate.local_states[i]); - auto converted = - cast_data.child_cast_info[i].function(source_child_vector, result_child_vector, count, child_parameters); - (void)converted; - D_ASSERT(converted); - } - - auto check_tags = UnionVector::CheckUnionValidity(result, count); - switch (check_tags) { - case UnionInvalidReason::TAG_OUT_OF_RANGE: - throw ConversionException("One or more of the tags do not point to a valid union member"); - case UnionInvalidReason::VALIDITY_OVERLAP: - throw ConversionException("One or more rows in the produced UNION have validity set for more than 1 member"); - case UnionInvalidReason::TAG_MISMATCH: - throw ConversionException( - "One or more rows in the produced UNION have tags that don't point to the valid member"); - case UnionInvalidReason::VALID: - break; - default: - throw InternalException("Struct to union cast failed for unknown reason"); - } - - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, ConstantVector::IsNull(source)); - } else { - source.Flatten(count); - FlatVector::Validity(result) = FlatVector::Validity(source); - } - result.Verify(count); - return true; -} - -// Bind cast - -unique_ptr StructToUnionCast::BindData(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - vector child_cast_info; - D_ASSERT(source.id() == LogicalTypeId::STRUCT); - D_ASSERT(target.id() == LogicalTypeId::UNION); - - auto result_child_count = StructType::GetChildCount(target); - 
D_ASSERT(result_child_count == StructType::GetChildCount(source)); - - for (idx_t i = 0; i < result_child_count; i++) { - auto &source_child = StructType::GetChildType(source, i); - auto &target_child = StructType::GetChildType(target, i); - - auto child_cast = input.GetCastFunction(source_child, target_child); - child_cast_info.push_back(std::move(child_cast)); - } - return make_uniq(std::move(child_cast_info), target); -} - -BoundCastInfo StructToUnionCast::Bind(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - auto cast_data = StructToUnionCast::BindData(input, source, target); - return BoundCastInfo(&StructToUnionCast::Cast, std::move(cast_data), StructBoundCastData::InitStructCastLocalState); -} - -} // namespace duckdb - - - - -#include // for std::sort - -namespace duckdb { - -//-------------------------------------------------------------------------------------------------- -// ??? -> UNION -//-------------------------------------------------------------------------------------------------- -// if the source can be implicitly cast to a member of the target union, the cast is valid - -unique_ptr BindToUnionCast(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - D_ASSERT(target.id() == LogicalTypeId::UNION); - - vector candidates; - - for (idx_t member_idx = 0; member_idx < UnionType::GetMemberCount(target); member_idx++) { - auto member_type = UnionType::GetMemberType(target, member_idx); - auto member_name = UnionType::GetMemberName(target, member_idx); - auto member_cast_cost = input.function_set.ImplicitCastCost(source, member_type); - if (member_cast_cost != -1) { - auto member_cast_info = input.GetCastFunction(source, member_type); - candidates.emplace_back(member_idx, member_name, member_type, member_cast_cost, - std::move(member_cast_info)); - } - }; - - // no possible casts found! - if (candidates.empty()) { - auto message = StringUtil::Format( - "Type %s can't be cast as %s. 
%s can't be implicitly cast to any of the union member types: ", - source.ToString(), target.ToString(), source.ToString()); - - auto member_count = UnionType::GetMemberCount(target); - for (idx_t member_idx = 0; member_idx < member_count; member_idx++) { - auto member_type = UnionType::GetMemberType(target, member_idx); - message += member_type.ToString(); - if (member_idx < member_count - 1) { - message += ", "; - } - } - throw CastException(message); - } - - // sort the candidate casts by cost - std::sort(candidates.begin(), candidates.end(), UnionBoundCastData::SortByCostAscending); - - // select the lowest possible cost cast - auto &selected_cast = candidates[0]; - auto selected_cost = candidates[0].cost; - - // check if the cast is ambiguous (2 or more casts have the same cost) - if (candidates.size() > 1 && candidates[1].cost == selected_cost) { - - // collect all the ambiguous types - auto message = StringUtil::Format( - "Type %s can't be cast as %s. The cast is ambiguous, multiple possible members in target: ", source, - target); - for (size_t i = 0; i < candidates.size(); i++) { - if (candidates[i].cost == selected_cost) { - message += StringUtil::Format("'%s (%s)'", candidates[i].name, candidates[i].type.ToString()); - if (i < candidates.size() - 1) { - message += ", "; - } - } - } - message += ". 
Disambiguate the target type by using the 'union_value( := )' function to promote the " - "source value to a single member union before casting."; - throw CastException(message); - } - - // otherwise, return the selected cast - return make_uniq(std::move(selected_cast)); -} - -unique_ptr InitToUnionLocalState(CastLocalStateParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - if (!cast_data.member_cast_info.init_local_state) { - return nullptr; - } - CastLocalStateParameters child_parameters(parameters, cast_data.member_cast_info.cast_data); - return cast_data.member_cast_info.init_local_state(child_parameters); -} - -static bool ToUnionCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - D_ASSERT(result.GetType().id() == LogicalTypeId::UNION); - auto &cast_data = parameters.cast_data->Cast(); - auto &selected_member_vector = UnionVector::GetMember(result, cast_data.tag); - - CastParameters child_parameters(parameters, cast_data.member_cast_info.cast_data, parameters.local_state); - if (!cast_data.member_cast_info.function(source, selected_member_vector, count, child_parameters)) { - return false; - } - - // cast succeeded, create union vector - UnionVector::SetToMember(result, cast_data.tag, selected_member_vector, count, true); - - result.Verify(count); - - return true; -} - -BoundCastInfo DefaultCasts::ImplicitToUnionCast(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - - D_ASSERT(target.id() == LogicalTypeId::UNION); - if (StructToUnionCast::AllowImplicitCastFromStruct(source, target)) { - return StructToUnionCast::Bind(input, source, target); - } - auto cast_data = BindToUnionCast(input, source, target); - return BoundCastInfo(&ToUnionCast, std::move(cast_data), InitToUnionLocalState); -} - -//-------------------------------------------------------------------------------------------------- -// UNION -> UNION 
-//-------------------------------------------------------------------------------------------------- -// if the source member tags is a subset of the target member tags, and all the source members can be -// implicitly cast to the corresponding target members, the cast is valid. -// -// VALID: UNION(A, B) -> UNION(A, B, C) -// VALID: UNION(A, B) -> UNION(A, C) if B can be implicitly cast to C -// -// INVALID: UNION(A, B, C) -> UNION(A, B) -// INVALID: UNION(A, B) -> UNION(A, C) if B can't be implicitly cast to C -// INVALID: UNION(A, B, D) -> UNION(A, B, C) - -struct UnionUnionBoundCastData : public BoundCastData { - - // mapping from source member index to target member index - // these are always the same size as the source member count - // (since all source members must be present in the target) - vector tag_map; - vector member_casts; - - LogicalType target_type; - - UnionUnionBoundCastData(vector tag_map, vector member_casts, LogicalType target_type) - : tag_map(std::move(tag_map)), member_casts(std::move(member_casts)), target_type(std::move(target_type)) { - } - -public: - unique_ptr Copy() const override { - vector member_casts_copy; - for (auto &member_cast : member_casts) { - member_casts_copy.push_back(member_cast.Copy()); - } - return make_uniq(tag_map, std::move(member_casts_copy), target_type); - } -}; - -unique_ptr BindUnionToUnionCast(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - D_ASSERT(source.id() == LogicalTypeId::UNION); - D_ASSERT(target.id() == LogicalTypeId::UNION); - - auto source_member_count = UnionType::GetMemberCount(source); - - auto tag_map = vector(source_member_count); - auto member_casts = vector(); - - for (idx_t source_idx = 0; source_idx < source_member_count; source_idx++) { - auto &source_member_type = UnionType::GetMemberType(source, source_idx); - auto &source_member_name = UnionType::GetMemberName(source, source_idx); - - bool found = false; - for (idx_t target_idx = 0; target_idx < 
UnionType::GetMemberCount(target); target_idx++) { - auto &target_member_name = UnionType::GetMemberName(target, target_idx); - - // found a matching member - if (source_member_name == target_member_name) { - auto &target_member_type = UnionType::GetMemberType(target, target_idx); - tag_map[source_idx] = target_idx; - member_casts.push_back(input.GetCastFunction(source_member_type, target_member_type)); - found = true; - break; - } - } - if (!found) { - // no matching member tag found in the target set - auto message = - StringUtil::Format("Type %s can't be cast as %s. The member '%s' is not present in target union", - source.ToString(), target.ToString(), source_member_name); - throw CastException(message); - } - } - - return make_uniq(tag_map, std::move(member_casts), target); -} - -unique_ptr InitUnionToUnionLocalState(CastLocalStateParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto result = make_uniq(); - - for (auto &entry : cast_data.member_casts) { - unique_ptr child_state; - if (entry.init_local_state) { - CastLocalStateParameters child_params(parameters, entry.cast_data); - child_state = entry.init_local_state(child_params); - } - result->local_states.push_back(std::move(child_state)); - } - return std::move(result); -} - -static bool UnionToUnionCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto &lstate = parameters.local_state->Cast(); - - auto source_member_count = UnionType::GetMemberCount(source.GetType()); - auto target_member_count = UnionType::GetMemberCount(result.GetType()); - - auto target_member_is_mapped = vector(target_member_count); - - // Perform the casts from source to target members - for (idx_t member_idx = 0; member_idx < source_member_count; member_idx++) { - auto target_member_idx = cast_data.tag_map[member_idx]; - - auto &source_member_vector = UnionVector::GetMember(source, member_idx); - auto &target_member_vector = 
UnionVector::GetMember(result, target_member_idx); - auto &member_cast = cast_data.member_casts[member_idx]; - - CastParameters child_parameters(parameters, member_cast.cast_data, lstate.local_states[member_idx]); - if (!member_cast.function(source_member_vector, target_member_vector, count, child_parameters)) { - return false; - } - - target_member_is_mapped[target_member_idx] = true; - } - - // All member casts succeeded! - - // Set the unmapped target members to constant NULL. - // If we cast UNION(A, B) -> UNION(A, B, C) we need to invalidate C so that - // the invariants of the result union hold. (only member columns "selected" - // by the rowwise corresponding tag in the tag vector should be valid) - for (idx_t target_member_idx = 0; target_member_idx < target_member_count; target_member_idx++) { - if (!target_member_is_mapped[target_member_idx]) { - auto &target_member_vector = UnionVector::GetMember(result, target_member_idx); - target_member_vector.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(target_member_vector, true); - } - } - - // Update the tags in the result vector - auto &source_tag_vector = UnionVector::GetTags(source); - auto &result_tag_vector = UnionVector::GetTags(result); - - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // Constant vector case optimization - result.SetVectorType(VectorType::CONSTANT_VECTOR); - if (ConstantVector::IsNull(source)) { - ConstantVector::SetNull(result, true); - } else { - // map the tag - auto source_tag = ConstantVector::GetData(source_tag_vector)[0]; - auto mapped_tag = cast_data.tag_map[source_tag]; - ConstantVector::GetData(result_tag_vector)[0] = mapped_tag; - } - } else { - // Otherwise, use the unified vector format to access the source vector. - - // Ensure that all the result members are flat vectors - // This is not always the case, e.g. when a member is cast using the default TryNullCast function - // the resulting member vector will be a constant null vector. 
- for (idx_t target_member_idx = 0; target_member_idx < target_member_count; target_member_idx++) { - UnionVector::GetMember(result, target_member_idx).Flatten(count); - } - - // We assume that a union tag vector validity matches the union vector validity. - UnifiedVectorFormat source_tag_format; - source_tag_vector.ToUnifiedFormat(count, source_tag_format); - - for (idx_t row_idx = 0; row_idx < count; row_idx++) { - auto source_row_idx = source_tag_format.sel->get_index(row_idx); - if (source_tag_format.validity.RowIsValid(source_row_idx)) { - // map the tag - auto source_tag = (UnifiedVectorFormat::GetData(source_tag_format))[source_row_idx]; - auto target_tag = cast_data.tag_map[source_tag]; - FlatVector::GetData(result_tag_vector)[row_idx] = target_tag; - } else { - - // Issue: The members of the result is not always flatvectors - // In the case of TryNullCast, the result member is constant. - FlatVector::SetNull(result, row_idx, true); - } - } - } - - result.Verify(count); - - return true; -} - -static bool UnionToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; - // first cast all union members to varchar - auto &cast_data = parameters.cast_data->Cast(); - Vector varchar_union(cast_data.target_type, count); - - UnionToUnionCast(source, varchar_union, count, parameters); - - // now construct the actual varchar vector - varchar_union.Flatten(count); - auto &tag_vector = UnionVector::GetTags(source); - auto tag_vector_type = tag_vector.GetVectorType(); - if (tag_vector_type != VectorType::CONSTANT_VECTOR && tag_vector_type != VectorType::FLAT_VECTOR) { - tag_vector.Flatten(count); - } - - auto tags = FlatVector::GetData(tag_vector); - - auto &validity = FlatVector::Validity(varchar_union); - auto result_data = FlatVector::GetData(result); - - for (idx_t i = 0; i < count; i++) { - if (!validity.RowIsValid(i)) { - FlatVector::SetNull(result, i, true); - 
continue; - } - - auto &member = UnionVector::GetMember(varchar_union, tags[i]); - UnifiedVectorFormat member_vdata; - member.ToUnifiedFormat(count, member_vdata); - - auto mapped_idx = member_vdata.sel->get_index(i); - auto member_valid = member_vdata.validity.RowIsValid(mapped_idx); - if (member_valid) { - auto member_str = (UnifiedVectorFormat::GetData(member_vdata))[mapped_idx]; - result_data[i] = StringVector::AddString(result, member_str); - } else { - result_data[i] = StringVector::AddString(result, "NULL"); - } - } - - if (constant) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - - result.Verify(count); - return true; -} - -BoundCastInfo DefaultCasts::UnionCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - D_ASSERT(source.id() == LogicalTypeId::UNION); - switch (target.id()) { - case LogicalTypeId::VARCHAR: { - // bind a cast in which we convert all members to VARCHAR first - child_list_t varchar_members; - for (idx_t member_idx = 0; member_idx < UnionType::GetMemberCount(source); member_idx++) { - varchar_members.push_back(make_pair(UnionType::GetMemberName(source, member_idx), LogicalType::VARCHAR)); - } - auto varchar_type = LogicalType::UNION(std::move(varchar_members)); - return BoundCastInfo(UnionToVarcharCast, BindUnionToUnionCast(input, source, varchar_type), - InitUnionToUnionLocalState); - } - case LogicalTypeId::UNION: - return BoundCastInfo(UnionToUnionCast, BindUnionToUnionCast(input, source, target), InitUnionToUnionLocalState); - default: - return TryVectorNullCast; - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -BoundCastInfo DefaultCasts::UUIDCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // uuid to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - default: - return TryVectorNullCast; - } -} - -} // namespace duckdb - - 
-namespace duckdb { - -// ------- Helper functions for splitting string nested types ------- -static bool IsNull(const char *buf, idx_t start_pos, Vector &child, idx_t row_idx) { - if (buf[start_pos] == 'N' && buf[start_pos + 1] == 'U' && buf[start_pos + 2] == 'L' && buf[start_pos + 3] == 'L') { - FlatVector::SetNull(child, row_idx, true); - return true; - } - return false; -} - -inline static void SkipWhitespace(const char *buf, idx_t &pos, idx_t len) { - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } -} - -static bool SkipToCloseQuotes(idx_t &pos, const char *buf, idx_t &len) { - char quote = buf[pos]; - pos++; - bool escaped = false; - - while (pos < len) { - if (buf[pos] == '\\') { - escaped = !escaped; - } else { - if (buf[pos] == quote && !escaped) { - return true; - } - escaped = false; - } - pos++; - } - return false; -} - -static bool SkipToClose(idx_t &idx, const char *buf, idx_t &len, idx_t &lvl, char close_bracket) { - idx++; - - while (idx < len) { - if (buf[idx] == '"' || buf[idx] == '\'') { - if (!SkipToCloseQuotes(idx, buf, len)) { - return false; - } - } else if (buf[idx] == '{') { - if (!SkipToClose(idx, buf, len, lvl, '}')) { - return false; - } - } else if (buf[idx] == '[') { - if (!SkipToClose(idx, buf, len, lvl, ']')) { - return false; - } - lvl++; - } else if (buf[idx] == close_bracket) { - if (close_bracket == ']') { - lvl--; - } - return true; - } - idx++; - } - return false; -} - -static idx_t StringTrim(const char *buf, idx_t &start_pos, idx_t pos) { - idx_t trailing_whitespace = 0; - while (StringUtil::CharacterIsSpace(buf[pos - trailing_whitespace - 1])) { - trailing_whitespace++; - } - if ((buf[start_pos] == '"' && buf[pos - trailing_whitespace - 1] == '"') || - (buf[start_pos] == '\'' && buf[pos - trailing_whitespace - 1] == '\'')) { - start_pos++; - trailing_whitespace++; - } - return (pos - trailing_whitespace); -} - -struct CountPartOperation { - idx_t count = 0; - - bool HandleKey(const char *buf, idx_t 
start_pos, idx_t pos) { - count++; - return true; - } - void HandleValue(const char *buf, idx_t start_pos, idx_t pos) { - count++; - } -}; - -// ------- LIST SPLIT ------- -struct SplitStringListOperation { - SplitStringListOperation(string_t *child_data, idx_t &child_start, Vector &child) - : child_data(child_data), child_start(child_start), child(child) { - } - - string_t *child_data; - idx_t &child_start; - Vector &child; - - void HandleValue(const char *buf, idx_t start_pos, idx_t pos) { - if ((pos - start_pos) == 4 && IsNull(buf, start_pos, child, child_start)) { - child_start++; - return; - } - if (start_pos > pos) { - pos = start_pos; - } - child_data[child_start] = StringVector::AddString(child, buf + start_pos, pos - start_pos); - child_start++; - } -}; - -template -static bool SplitStringListInternal(const string_t &input, OP &state) { - const char *buf = input.GetData(); - idx_t len = input.GetSize(); - idx_t lvl = 1; - idx_t pos = 0; - bool seen_value = false; - - SkipWhitespace(buf, pos, len); - if (pos == len || buf[pos] != '[') { - return false; - } - - SkipWhitespace(buf, ++pos, len); - idx_t start_pos = pos; - while (pos < len) { - if (buf[pos] == '[') { - if (!SkipToClose(pos, buf, len, ++lvl, ']')) { - return false; - } - } else if ((buf[pos] == '"' || buf[pos] == '\'') && pos == start_pos) { - SkipToCloseQuotes(pos, buf, len); - } else if (buf[pos] == '{') { - idx_t struct_lvl = 0; - SkipToClose(pos, buf, len, struct_lvl, '}'); - } else if (buf[pos] == ',' || buf[pos] == ']') { - idx_t trailing_whitespace = 0; - while (StringUtil::CharacterIsSpace(buf[pos - trailing_whitespace - 1])) { - trailing_whitespace++; - } - if (buf[pos] != ']' || start_pos != pos || seen_value) { - state.HandleValue(buf, start_pos, pos - trailing_whitespace); - seen_value = true; - } - if (buf[pos] == ']') { - lvl--; - break; - } - SkipWhitespace(buf, ++pos, len); - start_pos = pos; - continue; - } - pos++; - } - SkipWhitespace(buf, ++pos, len); - return (pos == len && 
lvl == 0); -} - -bool VectorStringToList::SplitStringList(const string_t &input, string_t *child_data, idx_t &child_start, - Vector &child) { - SplitStringListOperation state(child_data, child_start, child); - return SplitStringListInternal(input, state); -} - -idx_t VectorStringToList::CountPartsList(const string_t &input) { - CountPartOperation state; - SplitStringListInternal(input, state); - return state.count; -} - -// ------- MAP SPLIT ------- -struct SplitStringMapOperation { - SplitStringMapOperation(string_t *child_key_data, string_t *child_val_data, idx_t &child_start, Vector &varchar_key, - Vector &varchar_val) - : child_key_data(child_key_data), child_val_data(child_val_data), child_start(child_start), - varchar_key(varchar_key), varchar_val(varchar_val) { - } - - string_t *child_key_data; - string_t *child_val_data; - idx_t &child_start; - Vector &varchar_key; - Vector &varchar_val; - - bool HandleKey(const char *buf, idx_t start_pos, idx_t pos) { - if ((pos - start_pos) == 4 && IsNull(buf, start_pos, varchar_key, child_start)) { - FlatVector::SetNull(varchar_val, child_start, true); - child_start++; - return false; - } - child_key_data[child_start] = StringVector::AddString(varchar_key, buf + start_pos, pos - start_pos); - return true; - } - - void HandleValue(const char *buf, idx_t start_pos, idx_t pos) { - if ((pos - start_pos) == 4 && IsNull(buf, start_pos, varchar_val, child_start)) { - child_start++; - return; - } - child_val_data[child_start] = StringVector::AddString(varchar_val, buf + start_pos, pos - start_pos); - child_start++; - } -}; - -template -static bool FindKeyOrValueMap(const char *buf, idx_t len, idx_t &pos, OP &state, bool key) { - auto start_pos = pos; - idx_t lvl = 0; - while (pos < len) { - if (buf[pos] == '"' || buf[pos] == '\'') { - SkipToCloseQuotes(pos, buf, len); - } else if (buf[pos] == '{') { - SkipToClose(pos, buf, len, lvl, '}'); - } else if (buf[pos] == '[') { - SkipToClose(pos, buf, len, lvl, ']'); - } else if (key && 
buf[pos] == '=') { - idx_t end_pos = StringTrim(buf, start_pos, pos); - return state.HandleKey(buf, start_pos, end_pos); // put string in KEY_child_vector - } else if (!key && (buf[pos] == ',' || buf[pos] == '}')) { - idx_t end_pos = StringTrim(buf, start_pos, pos); - state.HandleValue(buf, start_pos, end_pos); // put string in VALUE_child_vector - return true; - } - pos++; - } - return false; -} - -template -static bool SplitStringMapInternal(const string_t &input, OP &state) { - const char *buf = input.GetData(); - idx_t len = input.GetSize(); - idx_t pos = 0; - - SkipWhitespace(buf, pos, len); - if (pos == len || buf[pos] != '{') { - return false; - } - SkipWhitespace(buf, ++pos, len); - if (pos == len) { - return false; - } - if (buf[pos] == '}') { - SkipWhitespace(buf, ++pos, len); - return (pos == len); - } - while (pos < len) { - if (!FindKeyOrValueMap(buf, len, pos, state, true)) { - return false; - } - SkipWhitespace(buf, ++pos, len); - if (!FindKeyOrValueMap(buf, len, pos, state, false)) { - return false; - } - SkipWhitespace(buf, ++pos, len); - } - return true; -} - -bool VectorStringToMap::SplitStringMap(const string_t &input, string_t *child_key_data, string_t *child_val_data, - idx_t &child_start, Vector &varchar_key, Vector &varchar_val) { - SplitStringMapOperation state(child_key_data, child_val_data, child_start, varchar_key, varchar_val); - return SplitStringMapInternal(input, state); -} - -idx_t VectorStringToMap::CountPartsMap(const string_t &input) { - CountPartOperation state; - SplitStringMapInternal(input, state); - return state.count; -} - -// ------- STRUCT SPLIT ------- -static bool FindKeyStruct(const char *buf, idx_t len, idx_t &pos) { - while (pos < len) { - if (buf[pos] == ':') { - return true; - } - pos++; - } - return false; -} - -static bool FindValueStruct(const char *buf, idx_t len, idx_t &pos, Vector &varchar_child, idx_t &row_idx, - ValidityMask *child_mask) { - auto start_pos = pos; - idx_t lvl = 0; - while (pos < len) { - if 
(buf[pos] == '"' || buf[pos] == '\'') { - SkipToCloseQuotes(pos, buf, len); - } else if (buf[pos] == '{') { - SkipToClose(pos, buf, len, lvl, '}'); - } else if (buf[pos] == '[') { - SkipToClose(pos, buf, len, lvl, ']'); - } else if (buf[pos] == ',' || buf[pos] == '}') { - idx_t end_pos = StringTrim(buf, start_pos, pos); - if ((end_pos - start_pos) == 4 && IsNull(buf, start_pos, varchar_child, row_idx)) { - return true; - } - FlatVector::GetData(varchar_child)[row_idx] = - StringVector::AddString(varchar_child, buf + start_pos, end_pos - start_pos); - child_mask->SetValid(row_idx); // any child not set to valid will remain invalid - return true; - } - pos++; - } - return false; -} - -bool VectorStringToStruct::SplitStruct(const string_t &input, vector> &varchar_vectors, - idx_t &row_idx, string_map_t &child_names, - vector &child_masks) { - const char *buf = input.GetData(); - idx_t len = input.GetSize(); - idx_t pos = 0; - idx_t child_idx; - - SkipWhitespace(buf, pos, len); - if (pos == len || buf[pos] != '{') { - return false; - } - SkipWhitespace(buf, ++pos, len); - if (buf[pos] == '}') { - pos++; - } else { - while (pos < len) { - auto key_start = pos; - if (!FindKeyStruct(buf, len, pos)) { - return false; - } - auto key_end = StringTrim(buf, key_start, pos); - string_t found_key(buf + key_start, key_end - key_start); - - auto it = child_names.find(found_key); - if (it == child_names.end()) { - return false; // false key - } - child_idx = it->second; - SkipWhitespace(buf, ++pos, len); - if (!FindValueStruct(buf, len, pos, *varchar_vectors[child_idx], row_idx, child_masks[child_idx])) { - return false; - } - SkipWhitespace(buf, ++pos, len); - } - } - SkipWhitespace(buf, pos, len); - return (pos == len); -} - -} // namespace duckdb - - -namespace duckdb { - -//! 
The target type determines the preferred implicit casts -static int64_t TargetTypeCost(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::INTEGER: - return 103; - case LogicalTypeId::BIGINT: - return 101; - case LogicalTypeId::DOUBLE: - return 102; - case LogicalTypeId::HUGEINT: - return 120; - case LogicalTypeId::TIMESTAMP: - return 120; - case LogicalTypeId::VARCHAR: - return 149; - case LogicalTypeId::DECIMAL: - return 104; - case LogicalTypeId::STRUCT: - case LogicalTypeId::MAP: - case LogicalTypeId::LIST: - case LogicalTypeId::UNION: - return 160; - default: - return 110; - } -} - -static int64_t ImplicitCastTinyint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastSmallint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastInteger(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastBigint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastUTinyint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::USMALLINT: - case 
LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastUSmallint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastUInteger(const LogicalType &to) { - switch (to.id()) { - - case LogicalTypeId::UBIGINT: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastUBigint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastFloat(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::DOUBLE: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastDouble(const LogicalType &to) { - switch (to.id()) { - default: - return -1; - } -} - -static int64_t ImplicitCastDecimal(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastHugeint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return 
TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastDate(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::TIMESTAMP: - return TargetTypeCost(to); - default: - return -1; - } -} - -int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) { - if (from.id() == LogicalTypeId::SQLNULL) { - // NULL expression can be cast to anything - return TargetTypeCost(to); - } - if (from.id() == LogicalTypeId::UNKNOWN) { - // parameter expression can be cast to anything for no cost - return 0; - } - if (to.id() == LogicalTypeId::ANY) { - // anything can be cast to ANY type for (almost no) cost - return 1; - } - if (from.GetAlias() != to.GetAlias()) { - // if aliases are different, an implicit cast is not possible - return -1; - } - if (from.id() == LogicalTypeId::LIST && to.id() == LogicalTypeId::LIST) { - // Lists can be cast if their child types can be cast - auto child_cost = ImplicitCast(ListType::GetChildType(from), ListType::GetChildType(to)); - if (child_cost >= 100) { - // subtract one from the cost because we prefer LIST[X] -> LIST[VARCHAR] over LIST[X] -> VARCHAR - child_cost--; - } - return child_cost; - } - if (from.id() == to.id()) { - // arguments match: do nothing - return 0; - } - if (from.id() == LogicalTypeId::BLOB && to.id() == LogicalTypeId::VARCHAR) { - // Implicit cast not allowed from BLOB to VARCHAR - return -1; - } - if (to.id() == LogicalTypeId::VARCHAR) { - // everything can be cast to VARCHAR, but this cast has a high cost - return TargetTypeCost(to); - } - - if (from.id() == LogicalTypeId::UNION && to.id() == LogicalTypeId::UNION) { - // Unions can be cast if the source tags are a subset of the target tags - // in which case the most expensive cost is used - int cost = -1; - for (idx_t from_member_idx = 0; from_member_idx < UnionType::GetMemberCount(from); from_member_idx++) { - auto &from_member_name = UnionType::GetMemberName(from, from_member_idx); - - bool found = false; - for 
(idx_t to_member_idx = 0; to_member_idx < UnionType::GetMemberCount(to); to_member_idx++) { - auto &to_member_name = UnionType::GetMemberName(to, to_member_idx); - - if (from_member_name == to_member_name) { - auto &from_member_type = UnionType::GetMemberType(from, from_member_idx); - auto &to_member_type = UnionType::GetMemberType(to, to_member_idx); - - int child_cost = ImplicitCast(from_member_type, to_member_type); - if (child_cost > cost) { - cost = child_cost; - } - found = true; - break; - } - } - if (!found) { - return -1; - } - } - return cost; - } - - if (to.id() == LogicalTypeId::UNION) { - // check that the union type is fully resolved. - if (to.AuxInfo() == nullptr) { - return -1; - } - // every type can be implicitly be cast to a union if the source type is a member of the union - for (idx_t i = 0; i < UnionType::GetMemberCount(to); i++) { - auto member = UnionType::GetMemberType(to, i); - if (from == member) { - return 0; - } - } - } - - if ((from.id() == LogicalTypeId::TIMESTAMP_SEC || from.id() == LogicalTypeId::TIMESTAMP_MS || - from.id() == LogicalTypeId::TIMESTAMP_NS) && - to.id() == LogicalTypeId::TIMESTAMP) { - //! Any timestamp type can be converted to the default (us) type at low cost - return 101; - } - if ((to.id() == LogicalTypeId::TIMESTAMP_SEC || to.id() == LogicalTypeId::TIMESTAMP_MS || - to.id() == LogicalTypeId::TIMESTAMP_NS) && - from.id() == LogicalTypeId::TIMESTAMP) { - //! 
Any timestamp type can be converted to the default (us) type at low cost - return 100; - } - switch (from.id()) { - case LogicalTypeId::TINYINT: - return ImplicitCastTinyint(to); - case LogicalTypeId::SMALLINT: - return ImplicitCastSmallint(to); - case LogicalTypeId::INTEGER: - return ImplicitCastInteger(to); - case LogicalTypeId::BIGINT: - return ImplicitCastBigint(to); - case LogicalTypeId::UTINYINT: - return ImplicitCastUTinyint(to); - case LogicalTypeId::USMALLINT: - return ImplicitCastUSmallint(to); - case LogicalTypeId::UINTEGER: - return ImplicitCastUInteger(to); - case LogicalTypeId::UBIGINT: - return ImplicitCastUBigint(to); - case LogicalTypeId::HUGEINT: - return ImplicitCastHugeint(to); - case LogicalTypeId::FLOAT: - return ImplicitCastFloat(to); - case LogicalTypeId::DOUBLE: - return ImplicitCastDouble(to); - case LogicalTypeId::DATE: - return ImplicitCastDate(to); - case LogicalTypeId::DECIMAL: - return ImplicitCastDecimal(to); - default: - return -1; - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -typedef CompressionFunction (*get_compression_function_t)(PhysicalType type); -typedef bool (*compression_supports_type_t)(PhysicalType type); - -struct DefaultCompressionMethod { - CompressionType type; - get_compression_function_t get_function; - compression_supports_type_t supports_type; -}; - -static DefaultCompressionMethod internal_compression_methods[] = { - {CompressionType::COMPRESSION_CONSTANT, ConstantFun::GetFunction, ConstantFun::TypeIsSupported}, - {CompressionType::COMPRESSION_UNCOMPRESSED, UncompressedFun::GetFunction, UncompressedFun::TypeIsSupported}, - {CompressionType::COMPRESSION_RLE, RLEFun::GetFunction, RLEFun::TypeIsSupported}, - {CompressionType::COMPRESSION_BITPACKING, BitpackingFun::GetFunction, BitpackingFun::TypeIsSupported}, - {CompressionType::COMPRESSION_DICTIONARY, DictionaryCompressionFun::GetFunction, - DictionaryCompressionFun::TypeIsSupported}, - {CompressionType::COMPRESSION_CHIMP, 
ChimpCompressionFun::GetFunction, ChimpCompressionFun::TypeIsSupported}, - {CompressionType::COMPRESSION_PATAS, PatasCompressionFun::GetFunction, PatasCompressionFun::TypeIsSupported}, - {CompressionType::COMPRESSION_FSST, FSSTFun::GetFunction, FSSTFun::TypeIsSupported}, - {CompressionType::COMPRESSION_AUTO, nullptr, nullptr}}; - -static optional_ptr FindCompressionFunction(CompressionFunctionSet &set, CompressionType type, - PhysicalType data_type) { - auto &functions = set.functions; - auto comp_entry = functions.find(type); - if (comp_entry != functions.end()) { - auto &type_functions = comp_entry->second; - auto type_entry = type_functions.find(data_type); - if (type_entry != type_functions.end()) { - return &type_entry->second; - } - } - return nullptr; -} - -static optional_ptr LoadCompressionFunction(CompressionFunctionSet &set, CompressionType type, - PhysicalType data_type) { - for (idx_t index = 0; internal_compression_methods[index].get_function; index++) { - const auto &method = internal_compression_methods[index]; - if (method.type == type) { - // found the correct compression type - if (!method.supports_type(data_type)) { - // but it does not support this data type: bail out - return nullptr; - } - // the type is supported: create the function and insert it into the set - auto function = method.get_function(data_type); - set.functions[type].insert(make_pair(data_type, function)); - return FindCompressionFunction(set, type, data_type); - } - } - throw InternalException("Unsupported compression function type"); -} - -static void TryLoadCompression(DBConfig &config, vector> &result, CompressionType type, - PhysicalType data_type) { - auto function = config.GetCompressionFunction(type, data_type); - if (!function) { - return; - } - result.push_back(*function); -} - -vector> DBConfig::GetCompressionFunctions(PhysicalType data_type) { - vector> result; - TryLoadCompression(*this, result, CompressionType::COMPRESSION_UNCOMPRESSED, data_type); - 
TryLoadCompression(*this, result, CompressionType::COMPRESSION_RLE, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_BITPACKING, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_DICTIONARY, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_CHIMP, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_PATAS, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_FSST, data_type); - return result; -} - -optional_ptr DBConfig::GetCompressionFunction(CompressionType type, PhysicalType data_type) { - lock_guard l(compression_functions->lock); - // check if the function is already loaded - auto function = FindCompressionFunction(*compression_functions, type, data_type); - if (function) { - return function; - } - // else load the function - return LoadCompressionFunction(*compression_functions, type, data_type); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -FunctionData::~FunctionData() { -} - -bool FunctionData::Equals(const FunctionData *left, const FunctionData *right) { - if (left == right) { - return true; - } - if (!left || !right) { - return false; - } - return left->Equals(*right); -} - -TableFunctionData::~TableFunctionData() { -} - -unique_ptr TableFunctionData::Copy() const { - throw InternalException("Copy not supported for TableFunctionData"); -} - -bool TableFunctionData::Equals(const FunctionData &other) const { - return false; -} - -Function::Function(string name_p) : name(std::move(name_p)) { -} -Function::~Function() { -} - -SimpleFunction::SimpleFunction(string name_p, vector arguments_p, LogicalType varargs_p) - : Function(std::move(name_p)), arguments(std::move(arguments_p)), varargs(std::move(varargs_p)) { -} - -SimpleFunction::~SimpleFunction() { -} - -string SimpleFunction::ToString() const { - return Function::CallToString(name, arguments); -} - -bool SimpleFunction::HasVarArgs() const { - return 
varargs.id() != LogicalTypeId::INVALID; -} - -SimpleNamedParameterFunction::SimpleNamedParameterFunction(string name_p, vector arguments_p, - LogicalType varargs_p) - : SimpleFunction(std::move(name_p), std::move(arguments_p), std::move(varargs_p)) { -} - -SimpleNamedParameterFunction::~SimpleNamedParameterFunction() { -} - -string SimpleNamedParameterFunction::ToString() const { - return Function::CallToString(name, arguments, named_parameters); -} - -bool SimpleNamedParameterFunction::HasNamedParameters() const { - return !named_parameters.empty(); -} - -BaseScalarFunction::BaseScalarFunction(string name_p, vector arguments_p, LogicalType return_type_p, - FunctionSideEffects side_effects, LogicalType varargs_p, - FunctionNullHandling null_handling) - : SimpleFunction(std::move(name_p), std::move(arguments_p), std::move(varargs_p)), - return_type(std::move(return_type_p)), side_effects(side_effects), null_handling(null_handling) { -} - -BaseScalarFunction::~BaseScalarFunction() { -} - -string BaseScalarFunction::ToString() const { - return Function::CallToString(name, arguments, return_type); -} - -// add your initializer for new functions here -void BuiltinFunctions::Initialize() { - RegisterTableScanFunctions(); - RegisterSQLiteFunctions(); - RegisterReadFunctions(); - RegisterTableFunctions(); - RegisterArrowFunctions(); - - RegisterDistributiveAggregates(); - - RegisterCompressedMaterializationFunctions(); - - RegisterGenericFunctions(); - RegisterOperators(); - RegisterSequenceFunctions(); - RegisterStringFunctions(); - RegisterNestedFunctions(); - - RegisterPragmaFunctions(); - - // initialize collations - AddCollation("nocase", LowerFun::GetFunction(), true); - AddCollation("noaccent", StripAccentsFun::GetFunction()); - AddCollation("nfc", NFCNormalizeFun::GetFunction()); -} - -hash_t BaseScalarFunction::Hash() const { - hash_t hash = return_type.Hash(); - for (auto &arg : arguments) { - hash = duckdb::CombineHash(hash, arg.Hash()); - } - return hash; -} - 
-string Function::CallToString(const string &name, const vector &arguments) { - string result = name + "("; - result += StringUtil::Join(arguments, arguments.size(), ", ", - [](const LogicalType &argument) { return argument.ToString(); }); - return result + ")"; -} - -string Function::CallToString(const string &name, const vector &arguments, - const LogicalType &return_type) { - string result = CallToString(name, arguments); - result += " -> " + return_type.ToString(); - return result; -} - -string Function::CallToString(const string &name, const vector &arguments, - const named_parameter_type_map_t &named_parameters) { - vector input_arguments; - input_arguments.reserve(arguments.size() + named_parameters.size()); - for (auto &arg : arguments) { - input_arguments.push_back(arg.ToString()); - } - for (auto &kv : named_parameters) { - input_arguments.push_back(StringUtil::Format("%s : %s", kv.first, kv.second.ToString())); - } - return StringUtil::Format("%s(%s)", name, StringUtil::Join(input_arguments, ", ")); -} - -void Function::EraseArgument(SimpleFunction &bound_function, vector> &arguments, - idx_t argument_index) { - if (bound_function.original_arguments.empty()) { - bound_function.original_arguments = bound_function.arguments; - } - D_ASSERT(arguments.size() == bound_function.arguments.size()); - D_ASSERT(argument_index < arguments.size()); - arguments.erase(arguments.begin() + argument_index); - bound_function.arguments.erase(bound_function.arguments.begin() + argument_index); -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -FunctionBinder::FunctionBinder(ClientContext &context) : context(context) { -} - -int64_t FunctionBinder::BindVarArgsFunctionCost(const SimpleFunction &func, const vector &arguments) { - if (arguments.size() < func.arguments.size()) { - // not enough arguments to fulfill the non-vararg part of the function - return -1; - } - int64_t cost = 0; - for (idx_t i = 0; i < arguments.size(); i++) { - LogicalType 
arg_type = i < func.arguments.size() ? func.arguments[i] : func.varargs; - if (arguments[i] == arg_type) { - // arguments match: do nothing - continue; - } - int64_t cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(arguments[i], arg_type); - if (cast_cost >= 0) { - // we can implicitly cast, add the cost to the total cost - cost += cast_cost; - } else { - // we can't implicitly cast: throw an error - return -1; - } - } - return cost; -} - -int64_t FunctionBinder::BindFunctionCost(const SimpleFunction &func, const vector &arguments) { - if (func.HasVarArgs()) { - // special case varargs function - return BindVarArgsFunctionCost(func, arguments); - } - if (func.arguments.size() != arguments.size()) { - // invalid argument count: check the next function - return -1; - } - int64_t cost = 0; - for (idx_t i = 0; i < arguments.size(); i++) { - int64_t cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(arguments[i], func.arguments[i]); - if (cast_cost >= 0) { - // we can implicitly cast, add the cost to the total cost - cost += cast_cost; - } else { - // we can't implicitly cast: throw an error - return -1; - } - } - return cost; -} - -template -vector FunctionBinder::BindFunctionsFromArguments(const string &name, FunctionSet &functions, - const vector &arguments, string &error) { - idx_t best_function = DConstants::INVALID_INDEX; - int64_t lowest_cost = NumericLimits::Maximum(); - vector candidate_functions; - for (idx_t f_idx = 0; f_idx < functions.functions.size(); f_idx++) { - auto &func = functions.functions[f_idx]; - // check the arguments of the function - int64_t cost = BindFunctionCost(func, arguments); - if (cost < 0) { - // auto casting was not possible - continue; - } - if (cost == lowest_cost) { - candidate_functions.push_back(f_idx); - continue; - } - if (cost > lowest_cost) { - continue; - } - candidate_functions.clear(); - lowest_cost = cost; - best_function = f_idx; - } - if (best_function == DConstants::INVALID_INDEX) { - // no matching 
function was found, throw an error - string call_str = Function::CallToString(name, arguments); - string candidate_str = ""; - for (auto &f : functions.functions) { - candidate_str += "\t" + f.ToString() + "\n"; - } - error = StringUtil::Format("No function matches the given name and argument types '%s'. You might need to add " - "explicit type casts.\n\tCandidate functions:\n%s", - call_str, candidate_str); - return candidate_functions; - } - candidate_functions.push_back(best_function); - return candidate_functions; -} - -template -idx_t FunctionBinder::MultipleCandidateException(const string &name, FunctionSet &functions, - vector &candidate_functions, - const vector &arguments, string &error) { - D_ASSERT(functions.functions.size() > 1); - // there are multiple possible function definitions - // throw an exception explaining which overloads are there - string call_str = Function::CallToString(name, arguments); - string candidate_str = ""; - for (auto &conf : candidate_functions) { - T f = functions.GetFunctionByOffset(conf); - candidate_str += "\t" + f.ToString() + "\n"; - } - error = StringUtil::Format("Could not choose a best candidate function for the function call \"%s\". In order to " - "select one, please add explicit type casts.\n\tCandidate functions:\n%s", - call_str, candidate_str); - return DConstants::INVALID_INDEX; -} - -template -idx_t FunctionBinder::BindFunctionFromArguments(const string &name, FunctionSet &functions, - const vector &arguments, string &error) { - auto candidate_functions = BindFunctionsFromArguments(name, functions, arguments, error); - if (candidate_functions.empty()) { - // no candidates - return DConstants::INVALID_INDEX; - } - if (candidate_functions.size() > 1) { - // multiple candidates, check if there are any unknown arguments - bool has_parameters = false; - for (auto &arg_type : arguments) { - if (arg_type.id() == LogicalTypeId::UNKNOWN) { - //! there are! 
we could not resolve parameters in this case - throw ParameterNotResolvedException(); - } - } - if (!has_parameters) { - return MultipleCandidateException(name, functions, candidate_functions, arguments, error); - } - } - return candidate_functions[0]; -} - -idx_t FunctionBinder::BindFunction(const string &name, ScalarFunctionSet &functions, - const vector &arguments, string &error) { - return BindFunctionFromArguments(name, functions, arguments, error); -} - -idx_t FunctionBinder::BindFunction(const string &name, AggregateFunctionSet &functions, - const vector &arguments, string &error) { - return BindFunctionFromArguments(name, functions, arguments, error); -} - -idx_t FunctionBinder::BindFunction(const string &name, TableFunctionSet &functions, - const vector &arguments, string &error) { - return BindFunctionFromArguments(name, functions, arguments, error); -} - -idx_t FunctionBinder::BindFunction(const string &name, PragmaFunctionSet &functions, PragmaInfo &info, string &error) { - vector types; - for (auto &value : info.parameters) { - types.push_back(value.type()); - } - idx_t entry = BindFunctionFromArguments(name, functions, types, error); - if (entry == DConstants::INVALID_INDEX) { - throw BinderException(error); - } - auto candidate_function = functions.GetFunctionByOffset(entry); - // cast the input parameters - for (idx_t i = 0; i < info.parameters.size(); i++) { - auto target_type = - i < candidate_function.arguments.size() ? 
candidate_function.arguments[i] : candidate_function.varargs; - info.parameters[i] = info.parameters[i].CastAs(context, target_type); - } - return entry; -} - -vector FunctionBinder::GetLogicalTypesFromExpressions(vector> &arguments) { - vector types; - types.reserve(arguments.size()); - for (auto &argument : arguments) { - types.push_back(argument->return_type); - } - return types; -} - -idx_t FunctionBinder::BindFunction(const string &name, ScalarFunctionSet &functions, - vector> &arguments, string &error) { - auto types = GetLogicalTypesFromExpressions(arguments); - return BindFunction(name, functions, types, error); -} - -idx_t FunctionBinder::BindFunction(const string &name, AggregateFunctionSet &functions, - vector> &arguments, string &error) { - auto types = GetLogicalTypesFromExpressions(arguments); - return BindFunction(name, functions, types, error); -} - -idx_t FunctionBinder::BindFunction(const string &name, TableFunctionSet &functions, - vector> &arguments, string &error) { - auto types = GetLogicalTypesFromExpressions(arguments); - return BindFunction(name, functions, types, error); -} - -enum class LogicalTypeComparisonResult { IDENTICAL_TYPE, TARGET_IS_ANY, DIFFERENT_TYPES }; - -LogicalTypeComparisonResult RequiresCast(const LogicalType &source_type, const LogicalType &target_type) { - if (target_type.id() == LogicalTypeId::ANY) { - return LogicalTypeComparisonResult::TARGET_IS_ANY; - } - if (source_type == target_type) { - return LogicalTypeComparisonResult::IDENTICAL_TYPE; - } - if (source_type.id() == LogicalTypeId::LIST && target_type.id() == LogicalTypeId::LIST) { - return RequiresCast(ListType::GetChildType(source_type), ListType::GetChildType(target_type)); - } - return LogicalTypeComparisonResult::DIFFERENT_TYPES; -} - -void FunctionBinder::CastToFunctionArguments(SimpleFunction &function, vector> &children) { - for (idx_t i = 0; i < children.size(); i++) { - auto target_type = i < function.arguments.size() ? 
function.arguments[i] : function.varargs; - target_type.Verify(); - // don't cast lambda children, they get removed anyways - if (children[i]->return_type.id() == LogicalTypeId::LAMBDA) { - continue; - } - // check if the type of child matches the type of function argument - // if not we need to add a cast - auto cast_result = RequiresCast(children[i]->return_type, target_type); - // except for one special case: if the function accepts ANY argument - // in that case we don't add a cast - if (cast_result == LogicalTypeComparisonResult::DIFFERENT_TYPES) { - children[i] = BoundCastExpression::AddCastToType(context, std::move(children[i]), target_type); - } - } -} - -unique_ptr FunctionBinder::BindScalarFunction(const string &schema, const string &name, - vector> children, string &error, - bool is_operator, Binder *binder) { - // bind the function - auto &function = - Catalog::GetSystemCatalog(context).GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, schema, name); - D_ASSERT(function.type == CatalogType::SCALAR_FUNCTION_ENTRY); - return BindScalarFunction(function.Cast(), std::move(children), error, is_operator, - binder); -} - -unique_ptr FunctionBinder::BindScalarFunction(ScalarFunctionCatalogEntry &func, - vector> children, string &error, - bool is_operator, Binder *binder) { - // bind the function - idx_t best_function = BindFunction(func.name, func.functions, children, error); - if (best_function == DConstants::INVALID_INDEX) { - return nullptr; - } - - // found a matching function! 
- auto bound_function = func.functions.GetFunctionByOffset(best_function); - - if (bound_function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) { - for (auto &child : children) { - if (child->return_type == LogicalTypeId::SQLNULL) { - return make_uniq(Value(LogicalType::SQLNULL)); - } - if (!child->IsFoldable()) { - continue; - } - Value result; - if (!ExpressionExecutor::TryEvaluateScalar(context, *child, result)) { - continue; - } - if (result.IsNull()) { - return make_uniq(Value(LogicalType::SQLNULL)); - } - } - } - return BindScalarFunction(bound_function, std::move(children), is_operator); -} - -unique_ptr FunctionBinder::BindScalarFunction(ScalarFunction bound_function, - vector> children, - bool is_operator) { - unique_ptr bind_info; - if (bound_function.bind) { - bind_info = bound_function.bind(context, bound_function, children); - } - // check if we need to add casts to the children - CastToFunctionArguments(bound_function, children); - - // now create the function - auto return_type = bound_function.return_type; - return make_uniq(std::move(return_type), std::move(bound_function), std::move(children), - std::move(bind_info), is_operator); -} - -unique_ptr FunctionBinder::BindAggregateFunction(AggregateFunction bound_function, - vector> children, - unique_ptr filter, - AggregateType aggr_type) { - unique_ptr bind_info; - if (bound_function.bind) { - bind_info = bound_function.bind(context, bound_function, children); - // we may have lost some arguments in the bind - children.resize(MinValue(bound_function.arguments.size(), children.size())); - } - - // check if we need to add casts to the children - CastToFunctionArguments(bound_function, children); - - return make_uniq(std::move(bound_function), std::move(children), std::move(filter), - std::move(bind_info), aggr_type); -} - -} // namespace duckdb - - - -namespace duckdb { - -ScalarFunctionSet::ScalarFunctionSet() : FunctionSet("") { -} - -ScalarFunctionSet::ScalarFunctionSet(string name) 
: FunctionSet(std::move(name)) { -} - -ScalarFunctionSet::ScalarFunctionSet(ScalarFunction fun) : FunctionSet(std::move(fun.name)) { - functions.push_back(std::move(fun)); -} - -ScalarFunction ScalarFunctionSet::GetFunctionByArguments(ClientContext &context, const vector &arguments) { - string error; - FunctionBinder binder(context); - idx_t index = binder.BindFunction(name, *this, arguments, error); - if (index == DConstants::INVALID_INDEX) { - throw InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), - error); - } - return GetFunctionByOffset(index); -} - -AggregateFunctionSet::AggregateFunctionSet() : FunctionSet("") { -} - -AggregateFunctionSet::AggregateFunctionSet(string name) : FunctionSet(std::move(name)) { -} - -AggregateFunctionSet::AggregateFunctionSet(AggregateFunction fun) : FunctionSet(std::move(fun.name)) { - functions.push_back(std::move(fun)); -} - -AggregateFunction AggregateFunctionSet::GetFunctionByArguments(ClientContext &context, - const vector &arguments) { - string error; - FunctionBinder binder(context); - idx_t index = binder.BindFunction(name, *this, arguments, error); - if (index == DConstants::INVALID_INDEX) { - // check if the arguments are a prefix of any of the arguments - // this is used for functions such as quantile or string_agg that delete part of their arguments during bind - // FIXME: we should come up with a better solution here - for (auto &func : functions) { - if (arguments.size() >= func.arguments.size()) { - continue; - } - bool is_prefix = true; - for (idx_t k = 0; k < arguments.size(); k++) { - if (arguments[k] != func.arguments[k]) { - is_prefix = false; - break; - } - } - if (is_prefix) { - return func; - } - } - throw InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), - error); - } - return GetFunctionByOffset(index); -} - -TableFunctionSet::TableFunctionSet(string name) : FunctionSet(std::move(name)) { -} - 
-TableFunctionSet::TableFunctionSet(TableFunction fun) : FunctionSet(std::move(fun.name)) { - functions.push_back(std::move(fun)); -} - -TableFunction TableFunctionSet::GetFunctionByArguments(ClientContext &context, const vector &arguments) { - string error; - FunctionBinder binder(context); - idx_t index = binder.BindFunction(name, *this, arguments, error); - if (index == DConstants::INVALID_INDEX) { - throw InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), - error); - } - return GetFunctionByOffset(index); -} - -PragmaFunctionSet::PragmaFunctionSet(string name) : FunctionSet(std::move(name)) { -} - -PragmaFunctionSet::PragmaFunctionSet(PragmaFunction fun) : FunctionSet(std::move(fun.name)) { - functions.push_back(std::move(fun)); -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -// MacroFunction::MacroFunction(unique_ptr expression) : expression(std::move(expression)) {} - -MacroFunction::MacroFunction(MacroType type) : type(type) { -} - -string MacroFunction::ValidateArguments(MacroFunction ¯o_def, const string &name, FunctionExpression &function_expr, - vector> &positionals, - unordered_map> &defaults) { - - // separate positional and default arguments - for (auto &arg : function_expr.children) { - if (!arg->alias.empty()) { - // default argument - if (!macro_def.default_parameters.count(arg->alias)) { - return StringUtil::Format("Macro %s does not have default parameter %s!", name, arg->alias); - } else if (defaults.count(arg->alias)) { - return StringUtil::Format("Duplicate default parameters %s!", arg->alias); - } - defaults[arg->alias] = std::move(arg); - } else if (!defaults.empty()) { - return "Positional parameters cannot come after parameters with a default value!"; - } else { - // positional argument - positionals.push_back(std::move(arg)); - } - } - - // validate if the right number of arguments was supplied - string error; - auto ¶meters = macro_def.parameters; - if 
(parameters.size() != positionals.size()) { - error = StringUtil::Format( - "Macro function '%s(%s)' requires ", name, - StringUtil::Join(parameters, parameters.size(), ", ", [](const unique_ptr &p) { - return (p->Cast()).column_names[0]; - })); - error += parameters.size() == 1 ? "a single positional argument" - : StringUtil::Format("%i positional arguments", parameters.size()); - error += ", but "; - error += positionals.size() == 1 ? "a single positional argument was" - : StringUtil::Format("%i positional arguments were", positionals.size()); - error += " provided."; - return error; - } - - // Add the default values for parameters that have defaults, that were not explicitly assigned to - for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) { - auto ¶meter_name = it->first; - auto ¶meter_default = it->second; - if (!defaults.count(parameter_name)) { - // This parameter was not set yet, set it with the default value - defaults[parameter_name] = parameter_default->Copy(); - } - } - - return error; -} - -void MacroFunction::CopyProperties(MacroFunction &other) const { - other.type = type; - for (auto ¶m : parameters) { - other.parameters.push_back(param->Copy()); - } - for (auto &kv : default_parameters) { - other.default_parameters[kv.first] = kv.second->Copy(); - } -} - -string MacroFunction::ToSQL(const string &schema, const string &name) const { - vector param_strings; - for (auto ¶m : parameters) { - param_strings.push_back(param->ToString()); - } - for (auto &named_param : default_parameters) { - param_strings.push_back(StringUtil::Format("%s := %s", named_param.first, named_param.second->ToString())); - } - - return StringUtil::Format("CREATE MACRO %s.%s(%s) AS ", schema, name, StringUtil::Join(param_strings, ", ")); -} - -} // namespace duckdb - - - - - - - - - - - - - -#include - -namespace duckdb { - -static void PragmaEnableProfilingStatement(ClientContext &context, const FunctionParameters ¶meters) { - 
auto &config = ClientConfig::GetConfig(context); - config.enable_profiler = true; - config.emit_profiler_output = true; -} - -void RegisterEnableProfiling(BuiltinFunctions &set) { - PragmaFunctionSet functions(""); - functions.AddFunction(PragmaFunction::PragmaStatement(string(), PragmaEnableProfilingStatement)); - - set.AddFunction("enable_profile", functions); - set.AddFunction("enable_profiling", functions); -} - -static void PragmaDisableProfiling(ClientContext &context, const FunctionParameters ¶meters) { - auto &config = ClientConfig::GetConfig(context); - config.enable_profiler = false; -} - -static void PragmaEnableProgressBar(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).enable_progress_bar = true; -} - -static void PragmaDisableProgressBar(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).enable_progress_bar = false; -} - -static void PragmaEnablePrintProgressBar(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).print_progress_bar = true; -} - -static void PragmaDisablePrintProgressBar(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).print_progress_bar = false; -} - -static void PragmaEnableVerification(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).query_verification_enabled = true; - ClientConfig::GetConfig(context).verify_serializer = true; -} - -static void PragmaDisableVerification(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).query_verification_enabled = false; - ClientConfig::GetConfig(context).verify_serializer = false; -} - -static void PragmaVerifySerializer(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_serializer = true; -} - -static void PragmaDisableVerifySerializer(ClientContext &context, const 
FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_serializer = false; -} - -static void PragmaEnableExternalVerification(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_external = true; -} - -static void PragmaDisableExternalVerification(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_external = false; -} - -static void PragmaEnableForceParallelism(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_parallelism = true; -} - -static void PragmaEnableIndexJoin(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).enable_index_join = true; -} - -static void PragmaEnableForceIndexJoin(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).force_index_join = true; -} - -static void PragmaForceCheckpoint(ClientContext &context, const FunctionParameters ¶meters) { - DBConfig::GetConfig(context).options.force_checkpoint = true; -} - -static void PragmaDisableForceParallelism(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_parallelism = false; -} - -static void PragmaEnableObjectCache(ClientContext &context, const FunctionParameters ¶meters) { - DBConfig::GetConfig(context).options.object_cache_enable = true; -} - -static void PragmaDisableObjectCache(ClientContext &context, const FunctionParameters ¶meters) { - DBConfig::GetConfig(context).options.object_cache_enable = false; -} - -static void PragmaEnableCheckpointOnShutdown(ClientContext &context, const FunctionParameters ¶meters) { - DBConfig::GetConfig(context).options.checkpoint_on_shutdown = true; -} - -static void PragmaDisableCheckpointOnShutdown(ClientContext &context, const FunctionParameters ¶meters) { - DBConfig::GetConfig(context).options.checkpoint_on_shutdown = false; -} - -static void 
PragmaEnableOptimizer(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).enable_optimizer = true; -} - -static void PragmaDisableOptimizer(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).enable_optimizer = false; -} - -void PragmaFunctions::RegisterFunction(BuiltinFunctions &set) { - RegisterEnableProfiling(set); - - set.AddFunction(PragmaFunction::PragmaStatement("disable_profile", PragmaDisableProfiling)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_profiling", PragmaDisableProfiling)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_verification", PragmaEnableVerification)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_verification", PragmaDisableVerification)); - - set.AddFunction(PragmaFunction::PragmaStatement("verify_external", PragmaEnableExternalVerification)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_verify_external", PragmaDisableExternalVerification)); - - set.AddFunction(PragmaFunction::PragmaStatement("verify_serializer", PragmaVerifySerializer)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_verify_serializer", PragmaDisableVerifySerializer)); - - set.AddFunction(PragmaFunction::PragmaStatement("verify_parallelism", PragmaEnableForceParallelism)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_verify_parallelism", PragmaDisableForceParallelism)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_object_cache", PragmaEnableObjectCache)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_object_cache", PragmaDisableObjectCache)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_optimizer", PragmaEnableOptimizer)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_optimizer", PragmaDisableOptimizer)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_index_join", PragmaEnableIndexJoin)); - 
set.AddFunction(PragmaFunction::PragmaStatement("force_index_join", PragmaEnableForceIndexJoin)); - set.AddFunction(PragmaFunction::PragmaStatement("force_checkpoint", PragmaForceCheckpoint)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_progress_bar", PragmaEnableProgressBar)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_progress_bar", PragmaDisableProgressBar)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_print_progress_bar", PragmaEnablePrintProgressBar)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_print_progress_bar", PragmaDisablePrintProgressBar)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_checkpoint_on_shutdown", PragmaEnableCheckpointOnShutdown)); - set.AddFunction( - PragmaFunction::PragmaStatement("disable_checkpoint_on_shutdown", PragmaDisableCheckpointOnShutdown)); -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -string PragmaTableInfo(ClientContext &context, const FunctionParameters ¶meters) { - return StringUtil::Format("SELECT * FROM pragma_table_info('%s');", parameters.values[0].ToString()); -} - -string PragmaShowTables(ClientContext &context, const FunctionParameters ¶meters) { - // clang-format off - return R"EOF( - with "tables" as - ( - SELECT table_name as "name" - FROM duckdb_tables - where in_search_path(database_name, schema_name) - ), "views" as - ( - SELECT view_name as "name" - FROM duckdb_views - where in_search_path(database_name, schema_name) - ), db_objects as - ( - SELECT "name" FROM "tables" - UNION ALL - SELECT "name" FROM "views" - ) - SELECT "name" - FROM db_objects - ORDER BY "name";)EOF"; - // clang-format on -} - -string PragmaShowTablesExpanded(ClientContext &context, const FunctionParameters ¶meters) { - return R"( - SELECT - t.database_name AS database, - t.schema_name AS schema, - t.table_name AS name, - LIST(c.column_name order by c.column_index) AS column_names, - LIST(c.data_type order by c.column_index) AS 
column_types, - FIRST(t.temporary) AS temporary, - FROM duckdb_tables t - JOIN duckdb_columns c - USING (table_oid) - GROUP BY database, schema, name - - UNION ALL - - SELECT - v.database_name AS database, - v.schema_name AS schema, - v.view_name AS name, - LIST(c.column_name order by c.column_index) AS column_names, - LIST(c.data_type order by c.column_index) AS column_types, - FIRST(v.temporary) AS temporary, - FROM duckdb_views v - JOIN duckdb_columns c - ON (v.view_oid=c.table_oid) - GROUP BY database, schema, name - - ORDER BY database, schema, name - )"; -} - -string PragmaShowDatabases(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT database_name FROM duckdb_databases() WHERE NOT internal ORDER BY database_name;"; -} - -string PragmaAllProfiling(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT * FROM pragma_last_profiling_output() JOIN pragma_detailed_profiling_output() ON " - "(pragma_last_profiling_output.operator_id);"; -} - -string PragmaDatabaseList(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT * FROM pragma_database_list;"; -} - -string PragmaCollations(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT * FROM pragma_collations() ORDER BY 1;"; -} - -string PragmaFunctionsQuery(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT function_name AS name, upper(function_type) AS type, parameter_types AS parameters, varargs, " - "return_type, has_side_effects AS side_effects" - " FROM duckdb_functions()" - " WHERE function_type IN ('scalar', 'aggregate')" - " ORDER BY 1;"; -} - -string PragmaShow(ClientContext &context, const FunctionParameters ¶meters) { - // PRAGMA table_info but with some aliases - auto table = QualifiedName::Parse(parameters.values[0].ToString()); - - // clang-format off - string sql = R"( - SELECT - name AS "column_name", - type as "column_type", - CASE WHEN "notnull" THEN 'NO' ELSE 'YES' END AS 
"null", - (SELECT - MIN(CASE - WHEN constraint_type='PRIMARY KEY' THEN 'PRI' - WHEN constraint_type='UNIQUE' THEN 'UNI' - ELSE NULL END) - FROM duckdb_constraints() c - WHERE c.table_oid=cols.table_oid - AND list_contains(constraint_column_names, cols.column_name)) AS "key", - dflt_value AS "default", - NULL AS "extra" - FROM pragma_table_info('%func_param_table%') - LEFT JOIN duckdb_columns cols - ON cols.column_name = pragma_table_info.name - AND cols.table_name='%table_name%' - AND cols.schema_name='%table_schema%' - AND cols.database_name = '%table_database%' - ORDER BY column_index;)"; - // clang-format on - - sql = StringUtil::Replace(sql, "%func_param_table%", parameters.values[0].ToString()); - sql = StringUtil::Replace(sql, "%table_name%", table.name); - sql = StringUtil::Replace(sql, "%table_schema%", table.schema.empty() ? DEFAULT_SCHEMA : table.schema); - sql = StringUtil::Replace(sql, "%table_database%", - table.catalog.empty() ? DatabaseManager::GetDefaultDatabase(context) : table.catalog); - return sql; -} - -string PragmaVersion(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT * FROM pragma_version();"; -} - -string PragmaPlatform(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT * FROM pragma_platform();"; -} - -string PragmaImportDatabase(ClientContext &context, const FunctionParameters ¶meters) { - auto &config = DBConfig::GetConfig(context); - if (!config.options.enable_external_access) { - throw PermissionException("Import is disabled through configuration"); - } - auto &fs = FileSystem::GetFileSystem(context); - - string final_query; - // read the "shema.sql" and "load.sql" files - vector files = {"schema.sql", "load.sql"}; - for (auto &file : files) { - auto file_path = fs.JoinPath(parameters.values[0].ToString(), file); - auto handle = fs.OpenFile(file_path, FileFlags::FILE_FLAGS_READ, FileSystem::DEFAULT_LOCK, - FileSystem::DEFAULT_COMPRESSION); - auto fsize = fs.GetFileSize(*handle); - 
auto buffer = make_unsafe_uniq_array(fsize); - fs.Read(*handle, buffer.get(), fsize); - auto query = string(buffer.get(), fsize); - // Replace the placeholder with the path provided to IMPORT - if (file == "load.sql") { - Parser parser; - parser.ParseQuery(query); - auto copy_statements = std::move(parser.statements); - query.clear(); - for (auto &statement_p : copy_statements) { - D_ASSERT(statement_p->type == StatementType::COPY_STATEMENT); - auto &statement = statement_p->Cast(); - auto &info = *statement.info; - auto file_name = fs.ExtractName(info.file_path); - info.file_path = fs.JoinPath(parameters.values[0].ToString(), file_name); - query += statement.ToString() + ";"; - } - } - final_query += query; - } - return final_query; -} - -string PragmaDatabaseSize(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT * FROM pragma_database_size();"; -} - -string PragmaStorageInfo(ClientContext &context, const FunctionParameters ¶meters) { - return StringUtil::Format("SELECT * FROM pragma_storage_info('%s');", parameters.values[0].ToString()); -} - -string PragmaMetadataInfo(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT * FROM pragma_metadata_info();"; -} - -void PragmaQueries::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(PragmaFunction::PragmaCall("table_info", PragmaTableInfo, {LogicalType::VARCHAR})); - set.AddFunction(PragmaFunction::PragmaCall("storage_info", PragmaStorageInfo, {LogicalType::VARCHAR})); - set.AddFunction(PragmaFunction::PragmaCall("metadata_info", PragmaMetadataInfo, {})); - set.AddFunction(PragmaFunction::PragmaStatement("show_tables", PragmaShowTables)); - set.AddFunction(PragmaFunction::PragmaStatement("show_tables_expanded", PragmaShowTablesExpanded)); - set.AddFunction(PragmaFunction::PragmaStatement("show_databases", PragmaShowDatabases)); - set.AddFunction(PragmaFunction::PragmaStatement("database_list", PragmaDatabaseList)); - 
set.AddFunction(PragmaFunction::PragmaStatement("collations", PragmaCollations)); - set.AddFunction(PragmaFunction::PragmaCall("show", PragmaShow, {LogicalType::VARCHAR})); - set.AddFunction(PragmaFunction::PragmaStatement("version", PragmaVersion)); - set.AddFunction(PragmaFunction::PragmaStatement("platform", PragmaPlatform)); - set.AddFunction(PragmaFunction::PragmaStatement("database_size", PragmaDatabaseSize)); - set.AddFunction(PragmaFunction::PragmaStatement("functions", PragmaFunctionsQuery)); - set.AddFunction(PragmaFunction::PragmaCall("import_database", PragmaImportDatabase, {LogicalType::VARCHAR})); - set.AddFunction(PragmaFunction::PragmaStatement("all_profiling_output", PragmaAllProfiling)); -} - -} // namespace duckdb - - - -namespace duckdb { - -PragmaFunction::PragmaFunction(string name, PragmaType pragma_type, pragma_query_t query, pragma_function_t function, - vector arguments, LogicalType varargs) - : SimpleNamedParameterFunction(std::move(name), std::move(arguments), std::move(varargs)), type(pragma_type), - query(query), function(function) { -} - -PragmaFunction PragmaFunction::PragmaCall(const string &name, pragma_query_t query, vector arguments, - LogicalType varargs) { - return PragmaFunction(name, PragmaType::PRAGMA_CALL, query, nullptr, std::move(arguments), std::move(varargs)); -} - -PragmaFunction PragmaFunction::PragmaCall(const string &name, pragma_function_t function, vector arguments, - LogicalType varargs) { - return PragmaFunction(name, PragmaType::PRAGMA_CALL, nullptr, function, std::move(arguments), std::move(varargs)); -} - -PragmaFunction PragmaFunction::PragmaStatement(const string &name, pragma_query_t query) { - vector types; - return PragmaFunction(name, PragmaType::PRAGMA_STATEMENT, query, nullptr, std::move(types), LogicalType::INVALID); -} - -PragmaFunction PragmaFunction::PragmaStatement(const string &name, pragma_function_t function) { - vector types; - return PragmaFunction(name, PragmaType::PRAGMA_STATEMENT, 
nullptr, function, std::move(types), - LogicalType::INVALID); -} - -string PragmaFunction::ToString() const { - switch (type) { - case PragmaType::PRAGMA_STATEMENT: - return StringUtil::Format("PRAGMA %s", name); - case PragmaType::PRAGMA_CALL: { - return StringUtil::Format("PRAGMA %s", SimpleNamedParameterFunction::ToString()); - } - default: - return "UNKNOWN"; - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -static string IntegralCompressFunctionName(const LogicalType &result_type) { - return StringUtil::Format("__internal_compress_integral_%s", - StringUtil::Lower(LogicalTypeIdToString(result_type.id()))); -} - -template -struct TemplatedIntegralCompress { - static inline RESULT_TYPE Operation(const INPUT_TYPE &input, const INPUT_TYPE &min_val) { - D_ASSERT(min_val <= input); - return input - min_val; - } -}; - -template -struct TemplatedIntegralCompress { - static inline RESULT_TYPE Operation(const hugeint_t &input, const hugeint_t &min_val) { - D_ASSERT(min_val <= input); - return (input - min_val).lower; - } -}; - -template -static void IntegralCompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - D_ASSERT(args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR); - const auto min_val = ConstantVector::GetData(args.data[1])[0]; - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](const INPUT_TYPE &input) { - return TemplatedIntegralCompress::Operation(input, min_val); - }); -} - -template -static scalar_function_t GetIntegralCompressFunction(const LogicalType &input_type, const LogicalType &result_type) { - return IntegralCompressFunction; -} - -template -static scalar_function_t GetIntegralCompressFunctionResultSwitch(const LogicalType &input_type, - const LogicalType &result_type) { - switch (result_type.id()) { - case LogicalTypeId::UTINYINT: - return GetIntegralCompressFunction(input_type, result_type); - case LogicalTypeId::USMALLINT: - return 
GetIntegralCompressFunction(input_type, result_type); - case LogicalTypeId::UINTEGER: - return GetIntegralCompressFunction(input_type, result_type); - case LogicalTypeId::UBIGINT: - return GetIntegralCompressFunction(input_type, result_type); - default: - throw InternalException("Unexpected result type in GetIntegralCompressFunctionResultSwitch"); - } -} - -static scalar_function_t GetIntegralCompressFunctionInputSwitch(const LogicalType &input_type, - const LogicalType &result_type) { - switch (input_type.id()) { - case LogicalTypeId::SMALLINT: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::INTEGER: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::BIGINT: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::HUGEINT: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::USMALLINT: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::UINTEGER: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::UBIGINT: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - default: - throw InternalException("Unexpected input type in GetIntegralCompressFunctionInputSwitch"); - } -} - -static string IntegralDecompressFunctionName(const LogicalType &result_type) { - return StringUtil::Format("__internal_decompress_integral_%s", - StringUtil::Lower(LogicalTypeIdToString(result_type.id()))); -} - -template -static inline RESULT_TYPE TemplatedIntegralDecompress(const INPUT_TYPE &input, const RESULT_TYPE &min_val) { - return min_val + input; -} - -template -static void IntegralDecompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - D_ASSERT(args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR); - D_ASSERT(args.data[1].GetType() == 
result.GetType()); - const auto min_val = ConstantVector::GetData(args.data[1])[0]; - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](const INPUT_TYPE &input) { - return TemplatedIntegralDecompress(input, min_val); - }); -} - -template -static scalar_function_t GetIntegralDecompressFunction(const LogicalType &input_type, const LogicalType &result_type) { - return IntegralDecompressFunction; -} - -template -static scalar_function_t GetIntegralDecompressFunctionResultSwitch(const LogicalType &input_type, - const LogicalType &result_type) { - switch (result_type.id()) { - case LogicalTypeId::SMALLINT: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::INTEGER: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::BIGINT: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::HUGEINT: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::USMALLINT: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::UINTEGER: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::UBIGINT: - return GetIntegralDecompressFunction(input_type, result_type); - default: - throw InternalException("Unexpected input type in GetIntegralDecompressFunctionSetSwitch"); - } -} - -static scalar_function_t GetIntegralDecompressFunctionInputSwitch(const LogicalType &input_type, - const LogicalType &result_type) { - switch (input_type.id()) { - case LogicalTypeId::UTINYINT: - return GetIntegralDecompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::USMALLINT: - return GetIntegralDecompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::UINTEGER: - return GetIntegralDecompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::UBIGINT: - return GetIntegralDecompressFunctionResultSwitch(input_type, result_type); - default: - throw 
InternalException("Unexpected result type in GetIntegralDecompressFunctionInputSwitch"); - } -} - -static void CMIntegralSerialize(Serializer &serializer, const optional_ptr bind_data, - const ScalarFunction &function) { - serializer.WriteProperty(100, "arguments", function.arguments); - serializer.WriteProperty(101, "return_type", function.return_type); -} - -template -unique_ptr CMIntegralDeserialize(Deserializer &deserializer, ScalarFunction &function) { - function.arguments = deserializer.ReadProperty>(100, "arguments"); - auto return_type = deserializer.ReadProperty(101, "return_type"); - function.function = GET_FUNCTION(function.arguments[0], return_type); - return nullptr; -} - -ScalarFunction CMIntegralCompressFun::GetFunction(const LogicalType &input_type, const LogicalType &result_type) { - ScalarFunction result(IntegralCompressFunctionName(result_type), {input_type, input_type}, result_type, - GetIntegralCompressFunctionInputSwitch(input_type, result_type), - CompressedMaterializationFunctions::Bind); - result.serialize = CMIntegralSerialize; - result.deserialize = CMIntegralDeserialize; - return result; -} - -static ScalarFunctionSet GetIntegralCompressFunctionSet(const LogicalType &result_type) { - ScalarFunctionSet set(IntegralCompressFunctionName(result_type)); - for (const auto &input_type : LogicalType::Integral()) { - if (GetTypeIdSize(result_type.InternalType()) < GetTypeIdSize(input_type.InternalType())) { - set.AddFunction(CMIntegralCompressFun::GetFunction(input_type, result_type)); - } - } - return set; -} - -void CMIntegralCompressFun::RegisterFunction(BuiltinFunctions &set) { - for (const auto &result_type : CompressedMaterializationFunctions::IntegralTypes()) { - set.AddFunction(GetIntegralCompressFunctionSet(result_type)); - } -} - -ScalarFunction CMIntegralDecompressFun::GetFunction(const LogicalType &input_type, const LogicalType &result_type) { - ScalarFunction result(IntegralDecompressFunctionName(result_type), {input_type, 
result_type}, result_type, - GetIntegralDecompressFunctionInputSwitch(input_type, result_type), - CompressedMaterializationFunctions::Bind); - result.serialize = CMIntegralSerialize; - result.deserialize = CMIntegralDeserialize; - return result; -} - -static ScalarFunctionSet GetIntegralDecompressFunctionSet(const LogicalType &result_type) { - ScalarFunctionSet set(IntegralDecompressFunctionName(result_type)); - for (const auto &input_type : CompressedMaterializationFunctions::IntegralTypes()) { - if (GetTypeIdSize(result_type.InternalType()) > GetTypeIdSize(input_type.InternalType())) { - set.AddFunction(CMIntegralDecompressFun::GetFunction(input_type, result_type)); - } - } - return set; -} - -void CMIntegralDecompressFun::RegisterFunction(BuiltinFunctions &set) { - for (const auto &result_type : LogicalType::Integral()) { - if (GetTypeIdSize(result_type.InternalType()) > 1) { - set.AddFunction(GetIntegralDecompressFunctionSet(result_type)); - } - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -static string StringCompressFunctionName(const LogicalType &result_type) { - return StringUtil::Format("__internal_compress_string_%s", - StringUtil::Lower(LogicalTypeIdToString(result_type.id()))); -} - -template -static inline void TemplatedReverseMemCpy(const data_ptr_t __restrict &dest, const const_data_ptr_t __restrict &src) { - for (idx_t i = 0; i < LENGTH; i++) { - dest[i] = src[LENGTH - 1 - i]; - } -} - -static inline void ReverseMemCpy(const data_ptr_t __restrict &dest, const const_data_ptr_t __restrict &src, - const idx_t &length) { - for (idx_t i = 0; i < length; i++) { - dest[i] = src[length - 1 - i]; - } -} - -template -static inline RESULT_TYPE StringCompressInternal(const string_t &input) { - RESULT_TYPE result; - const auto result_ptr = data_ptr_cast(&result); - if (sizeof(RESULT_TYPE) <= string_t::INLINE_LENGTH) { - TemplatedReverseMemCpy(result_ptr, const_data_ptr_cast(input.GetPrefix())); - } else if (input.IsInlined()) { - static 
constexpr auto REMAINDER = sizeof(RESULT_TYPE) - string_t::INLINE_LENGTH; - TemplatedReverseMemCpy(result_ptr + REMAINDER, const_data_ptr_cast(input.GetPrefix())); - memset(result_ptr, '\0', REMAINDER); - } else { - const auto remainder = sizeof(RESULT_TYPE) - input.GetSize(); - ReverseMemCpy(result_ptr + remainder, data_ptr_cast(input.GetPointer()), input.GetSize()); - memset(result_ptr, '\0', remainder); - } - result_ptr[0] = input.GetSize(); - return result; -} - -template -static inline RESULT_TYPE StringCompress(const string_t &input) { - D_ASSERT(input.GetSize() < sizeof(RESULT_TYPE)); - return StringCompressInternal(input); -} - -template -static inline RESULT_TYPE MiniStringCompress(const string_t &input) { - if (sizeof(RESULT_TYPE) <= string_t::INLINE_LENGTH) { - return input.GetSize() + *const_data_ptr_cast(input.GetPrefix()); - } else if (input.GetSize() == 0) { - return 0; - } else { - return input.GetSize() + *const_data_ptr_cast(input.GetPointer()); - } -} - -template <> -inline uint8_t StringCompress(const string_t &input) { - D_ASSERT(input.GetSize() <= sizeof(uint8_t)); - return MiniStringCompress(input); -} - -template -static void StringCompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::Execute(args.data[0], result, args.size(), StringCompress); -} - -template -static scalar_function_t GetStringCompressFunction(const LogicalType &result_type) { - return StringCompressFunction; -} - -static scalar_function_t GetStringCompressFunctionSwitch(const LogicalType &result_type) { - switch (result_type.id()) { - case LogicalTypeId::UTINYINT: - return GetStringCompressFunction(result_type); - case LogicalTypeId::USMALLINT: - return GetStringCompressFunction(result_type); - case LogicalTypeId::UINTEGER: - return GetStringCompressFunction(result_type); - case LogicalTypeId::UBIGINT: - return GetStringCompressFunction(result_type); - case LogicalTypeId::HUGEINT: - return GetStringCompressFunction(result_type); - 
default: - throw InternalException("Unexpected type in GetStringCompressFunctionSwitch"); - } -} - -static string StringDecompressFunctionName() { - return "__internal_decompress_string"; -} - -struct StringDecompressLocalState : public FunctionLocalState { -public: - explicit StringDecompressLocalState(ClientContext &context) : allocator(Allocator::Get(context)) { - } - - static unique_ptr Init(ExpressionState &state, const BoundFunctionExpression &expr, - FunctionData *bind_data) { - return make_uniq(state.GetContext()); - } - -public: - ArenaAllocator allocator; -}; - -template -static inline string_t StringDecompress(const INPUT_TYPE &input, ArenaAllocator &allocator) { - const auto input_ptr = const_data_ptr_cast(&input); - string_t result(input_ptr[0]); - if (sizeof(INPUT_TYPE) <= string_t::INLINE_LENGTH) { - const auto result_ptr = data_ptr_cast(result.GetPrefixWriteable()); - TemplatedReverseMemCpy(result_ptr, input_ptr); - memset(result_ptr + sizeof(INPUT_TYPE) - 1, '\0', string_t::INLINE_LENGTH - sizeof(INPUT_TYPE) + 1); - } else if (result.GetSize() <= string_t::INLINE_LENGTH) { - static constexpr auto REMAINDER = sizeof(INPUT_TYPE) - string_t::INLINE_LENGTH; - const auto result_ptr = data_ptr_cast(result.GetPrefixWriteable()); - TemplatedReverseMemCpy(result_ptr, input_ptr + REMAINDER); - } else { - result.SetPointer(char_ptr_cast(allocator.Allocate(sizeof(INPUT_TYPE)))); - TemplatedReverseMemCpy(data_ptr_cast(result.GetPointer()), input_ptr); - memcpy(result.GetPrefixWriteable(), result.GetPointer(), string_t::PREFIX_LENGTH); - } - return result; -} - -template -static inline string_t MiniStringDecompress(const INPUT_TYPE &input, ArenaAllocator &allocator) { - if (input == 0) { - string_t result(uint32_t(0)); - memset(result.GetPrefixWriteable(), '\0', string_t::INLINE_BYTES); - return result; - } - - string_t result(1); - if (sizeof(INPUT_TYPE) <= string_t::INLINE_LENGTH) { - memset(result.GetPrefixWriteable(), '\0', string_t::INLINE_BYTES); - 
*data_ptr_cast(result.GetPrefixWriteable()) = input - 1; - } else { - result.SetPointer(char_ptr_cast(allocator.Allocate(1))); - *data_ptr_cast(result.GetPointer()) = input - 1; - memset(result.GetPrefixWriteable(), '\0', string_t::PREFIX_LENGTH); - *result.GetPrefixWriteable() = *result.GetPointer(); - } - return result; -} - -template <> -inline string_t StringDecompress(const uint8_t &input, ArenaAllocator &allocator) { - return MiniStringDecompress(input, allocator); -} - -template -static void StringDecompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &allocator = ExecuteFunctionState::GetFunctionState(state)->Cast().allocator; - allocator.Reset(); - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](const INPUT_TYPE &input) { - return StringDecompress(input, allocator); - }); -} - -template -static scalar_function_t GetStringDecompressFunction(const LogicalType &input_type) { - return StringDecompressFunction; -} - -static scalar_function_t GetStringDecompressFunctionSwitch(const LogicalType &input_type) { - switch (input_type.id()) { - case LogicalTypeId::UTINYINT: - return GetStringDecompressFunction(input_type); - case LogicalTypeId::USMALLINT: - return GetStringDecompressFunction(input_type); - case LogicalTypeId::UINTEGER: - return GetStringDecompressFunction(input_type); - case LogicalTypeId::UBIGINT: - return GetStringDecompressFunction(input_type); - case LogicalTypeId::HUGEINT: - return GetStringDecompressFunction(input_type); - default: - throw InternalException("Unexpected type in GetStringDecompressFunctionSwitch"); - } -} - -static void CMStringCompressSerialize(Serializer &serializer, const optional_ptr bind_data, - const ScalarFunction &function) { - serializer.WriteProperty(100, "arguments", function.arguments); - serializer.WriteProperty(101, "return_type", function.return_type); -} - -unique_ptr CMStringCompressDeserialize(Deserializer &deserializer, ScalarFunction &function) { - function.arguments 
= deserializer.ReadProperty>(100, "arguments"); - auto return_type = deserializer.ReadProperty(101, "return_type"); - function.function = GetStringCompressFunctionSwitch(return_type); - return nullptr; -} - -ScalarFunction CMStringCompressFun::GetFunction(const LogicalType &result_type) { - ScalarFunction result(StringCompressFunctionName(result_type), {LogicalType::VARCHAR}, result_type, - GetStringCompressFunctionSwitch(result_type), CompressedMaterializationFunctions::Bind); - result.serialize = CMStringCompressSerialize; - result.deserialize = CMStringCompressDeserialize; - return result; -} - -void CMStringCompressFun::RegisterFunction(BuiltinFunctions &set) { - for (const auto &result_type : CompressedMaterializationFunctions::StringTypes()) { - set.AddFunction(CMStringCompressFun::GetFunction(result_type)); - } -} - -static void CMStringDecompressSerialize(Serializer &serializer, const optional_ptr bind_data, - const ScalarFunction &function) { - serializer.WriteProperty(100, "arguments", function.arguments); -} - -unique_ptr CMStringDecompressDeserialize(Deserializer &deserializer, ScalarFunction &function) { - function.arguments = deserializer.ReadProperty>(100, "arguments"); - function.function = GetStringDecompressFunctionSwitch(function.arguments[0]); - return nullptr; -} - -ScalarFunction CMStringDecompressFun::GetFunction(const LogicalType &input_type) { - ScalarFunction result(StringDecompressFunctionName(), {input_type}, LogicalType::VARCHAR, - GetStringDecompressFunctionSwitch(input_type), CompressedMaterializationFunctions::Bind, - nullptr, nullptr, StringDecompressLocalState::Init); - result.serialize = CMStringDecompressSerialize; - result.deserialize = CMStringDecompressDeserialize; - return result; -} - -static ScalarFunctionSet GetStringDecompressFunctionSet() { - ScalarFunctionSet set(StringDecompressFunctionName()); - for (const auto &input_type : CompressedMaterializationFunctions::StringTypes()) { - 
set.AddFunction(CMStringDecompressFun::GetFunction(input_type)); - } - return set; -} - -void CMStringDecompressFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(GetStringDecompressFunctionSet()); -} - -} // namespace duckdb - - -namespace duckdb { - -const vector CompressedMaterializationFunctions::IntegralTypes() { - return {LogicalType::UTINYINT, LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT}; -} - -const vector CompressedMaterializationFunctions::StringTypes() { - return {LogicalType::UTINYINT, LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT, - LogicalType::HUGEINT}; -} - -// LCOV_EXCL_START -unique_ptr CompressedMaterializationFunctions::Bind(ClientContext &context, - ScalarFunction &bound_function, - vector> &arguments) { - throw BinderException("Compressed materialization functions are for internal use only!"); -} -// LCOV_EXCL_STOP - -void BuiltinFunctions::RegisterCompressedMaterializationFunctions() { - Register(); - Register(); - Register(); - Register(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -struct ConstantOrNullBindData : public FunctionData { - explicit ConstantOrNullBindData(Value val) : value(std::move(val)) { - } - - Value value; - -public: - unique_ptr Copy() const override { - return make_uniq(value); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return value == other.value; - } -}; - -static void ConstantOrNullFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - result.Reference(info.value); - for (idx_t idx = 1; idx < args.ColumnCount(); idx++) { - switch (args.data[idx].GetVectorType()) { - case VectorType::FLAT_VECTOR: { - auto &input_mask = FlatVector::Validity(args.data[idx]); - if (!input_mask.AllValid()) { - // there are null values: need to merge them into the result - result.Flatten(args.size()); - auto 
&result_mask = FlatVector::Validity(result); - result_mask.Combine(input_mask, args.size()); - } - break; - } - case VectorType::CONSTANT_VECTOR: { - if (ConstantVector::IsNull(args.data[idx])) { - // input is constant null, return constant null - result.Reference(info.value); - ConstantVector::SetNull(result, true); - return; - } - break; - } - default: { - UnifiedVectorFormat vdata; - args.data[idx].ToUnifiedFormat(args.size(), vdata); - if (!vdata.validity.AllValid()) { - result.Flatten(args.size()); - auto &result_mask = FlatVector::Validity(result); - for (idx_t i = 0; i < args.size(); i++) { - if (!vdata.validity.RowIsValid(vdata.sel->get_index(i))) { - result_mask.SetInvalid(i); - } - } - } - break; - } - } - } -} - -ScalarFunction ConstantOrNull::GetFunction(const LogicalType &return_type) { - return ScalarFunction("constant_or_null", {return_type, LogicalType::ANY}, return_type, ConstantOrNullFunction); -} - -unique_ptr ConstantOrNull::Bind(Value value) { - return make_uniq(std::move(value)); -} - -bool ConstantOrNull::IsConstantOrNull(BoundFunctionExpression &expr, const Value &val) { - if (expr.function.name != "constant_or_null") { - return false; - } - D_ASSERT(expr.bind_info); - auto &bind_data = expr.bind_info->Cast(); - D_ASSERT(bind_data.value.type() == val.type()); - return bind_data.value == val; -} - -unique_ptr ConstantOrNullBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments[0]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[0]->IsFoldable()) { - throw BinderException("ConstantOrNull requires a constant input"); - } - D_ASSERT(arguments.size() >= 2); - auto value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - bound_function.return_type = arguments[0]->return_type; - return make_uniq(std::move(value)); -} - -void ConstantOrNull::RegisterFunction(BuiltinFunctions &set) { - auto fun = ConstantOrNull::GetFunction(LogicalType::ANY); - fun.bind = 
ConstantOrNullBind; - fun.varargs = LogicalType::ANY; - set.AddFunction(fun); -} - -} // namespace duckdb - - -namespace duckdb { - -void BuiltinFunctions::RegisterGenericFunctions() { - Register(); - Register(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -static void ListContainsFunction(DataChunk &args, ExpressionState &state, Vector &result) { - (void)state; - return ListContainsOrPosition(args, result); -} - -static void ListPositionFunction(DataChunk &args, ExpressionState &state, Vector &result) { - (void)state; - return ListContainsOrPosition(args, result); -} - -template -static unique_ptr ListContainsOrPositionBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - - const auto &list = arguments[0]->return_type; // change to list - const auto &value = arguments[1]->return_type; - if (list.id() == LogicalTypeId::UNKNOWN) { - bound_function.return_type = RETURN_TYPE; - if (value.id() != LogicalTypeId::UNKNOWN) { - // only list is a parameter, cast it to a list of value type - bound_function.arguments[0] = LogicalType::LIST(value); - bound_function.arguments[1] = value; - } - } else if (value.id() == LogicalTypeId::UNKNOWN) { - // only value is a parameter: we expect the child type of list - auto const &child_type = ListType::GetChildType(list); - bound_function.arguments[0] = list; - bound_function.arguments[1] = child_type; - bound_function.return_type = RETURN_TYPE; - } else { - auto const &child_type = ListType::GetChildType(list); - auto max_child_type = LogicalType::MaxLogicalType(child_type, value); - auto list_type = LogicalType::LIST(max_child_type); - - bound_function.arguments[0] = list_type; - bound_function.arguments[1] = value == max_child_type ? 
value : max_child_type; - - // list_contains and list_position only differ in their return type - bound_function.return_type = RETURN_TYPE; - } - return make_uniq(bound_function.return_type); -} - -static unique_ptr ListContainsBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return ListContainsOrPositionBind(context, bound_function, arguments); -} - -static unique_ptr ListPositionBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return ListContainsOrPositionBind(context, bound_function, arguments); -} - -ScalarFunction ListContainsFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::ANY}, // argument list - LogicalType::BOOLEAN, // return type - ListContainsFunction, ListContainsBind, nullptr); -} - -ScalarFunction ListPositionFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::ANY}, // argument list - LogicalType::INTEGER, // return type - ListPositionFunction, ListPositionBind, nullptr); -} - -void ListContainsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction({"list_contains", "array_contains", "list_has", "array_has"}, GetFunction()); -} - -void ListPositionFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction({"list_position", "list_indexof", "array_position", "array_indexof"}, GetFunction()); -} -} // namespace duckdb - - - - - - -namespace duckdb { - -static void ListConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - auto count = args.size(); - - Vector &lhs = args.data[0]; - Vector &rhs = args.data[1]; - if (lhs.GetType().id() == LogicalTypeId::SQLNULL) { - result.Reference(rhs); - return; - } - if (rhs.GetType().id() == LogicalTypeId::SQLNULL) { - result.Reference(lhs); - return; - } - - UnifiedVectorFormat lhs_data; - UnifiedVectorFormat rhs_data; - lhs.ToUnifiedFormat(count, lhs_data); - 
rhs.ToUnifiedFormat(count, rhs_data); - auto lhs_entries = UnifiedVectorFormat::GetData(lhs_data); - auto rhs_entries = UnifiedVectorFormat::GetData(rhs_data); - - auto lhs_list_size = ListVector::GetListSize(lhs); - auto rhs_list_size = ListVector::GetListSize(rhs); - auto &lhs_child = ListVector::GetEntry(lhs); - auto &rhs_child = ListVector::GetEntry(rhs); - UnifiedVectorFormat lhs_child_data; - UnifiedVectorFormat rhs_child_data; - lhs_child.ToUnifiedFormat(lhs_list_size, lhs_child_data); - rhs_child.ToUnifiedFormat(rhs_list_size, rhs_child_data); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_entries = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - - idx_t offset = 0; - for (idx_t i = 0; i < count; i++) { - auto lhs_list_index = lhs_data.sel->get_index(i); - auto rhs_list_index = rhs_data.sel->get_index(i); - if (!lhs_data.validity.RowIsValid(lhs_list_index) && !rhs_data.validity.RowIsValid(rhs_list_index)) { - result_validity.SetInvalid(i); - continue; - } - result_entries[i].offset = offset; - result_entries[i].length = 0; - if (lhs_data.validity.RowIsValid(lhs_list_index)) { - const auto &lhs_entry = lhs_entries[lhs_list_index]; - result_entries[i].length += lhs_entry.length; - ListVector::Append(result, lhs_child, *lhs_child_data.sel, lhs_entry.offset + lhs_entry.length, - lhs_entry.offset); - } - if (rhs_data.validity.RowIsValid(rhs_list_index)) { - const auto &rhs_entry = rhs_entries[rhs_list_index]; - result_entries[i].length += rhs_entry.length; - ListVector::Append(result, rhs_child, *rhs_child_data.sel, rhs_entry.offset + rhs_entry.length, - rhs_entry.offset); - } - offset += result_entries[i].length; - } - D_ASSERT(ListVector::GetListSize(result) == offset); - - if (lhs.GetVectorType() == VectorType::CONSTANT_VECTOR && rhs.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr ListConcatBind(ClientContext &context, 
ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - - auto &lhs = arguments[0]->return_type; - auto &rhs = arguments[1]->return_type; - if (lhs.id() == LogicalTypeId::UNKNOWN || rhs.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } else if (lhs.id() == LogicalTypeId::SQLNULL || rhs.id() == LogicalTypeId::SQLNULL) { - // we mimic postgres behaviour: list_concat(NULL, my_list) = my_list - auto return_type = rhs.id() == LogicalTypeId::SQLNULL ? lhs : rhs; - bound_function.arguments[0] = return_type; - bound_function.arguments[1] = return_type; - bound_function.return_type = return_type; - } else { - D_ASSERT(lhs.id() == LogicalTypeId::LIST); - D_ASSERT(rhs.id() == LogicalTypeId::LIST); - - // Resolve list type - LogicalType child_type = LogicalType::SQLNULL; - for (const auto &argument : arguments) { - child_type = LogicalType::MaxLogicalType(child_type, ListType::GetChildType(argument->return_type)); - } - auto list_type = LogicalType::LIST(child_type); - - bound_function.arguments[0] = list_type; - bound_function.arguments[1] = list_type; - bound_function.return_type = list_type; - } - return make_uniq(bound_function.return_type); -} - -static unique_ptr ListConcatStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - D_ASSERT(child_stats.size() == 2); - - auto &left_stats = child_stats[0]; - auto &right_stats = child_stats[1]; - - auto stats = left_stats.ToUnique(); - stats->Merge(right_stats); - - return stats; -} - -ScalarFunction ListConcatFun::GetFunction() { - // the arguments and return types are actually set in the binder function - auto fun = ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::LIST(LogicalType::ANY)}, - LogicalType::LIST(LogicalType::ANY), ListConcatFunction, ListConcatBind, nullptr, - ListConcatStats); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -void 
ListConcatFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction({"list_concat", "list_cat", "array_concat", "array_cat"}, GetFunction()); -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -template -void ListExtractTemplate(idx_t count, UnifiedVectorFormat &list_data, UnifiedVectorFormat &offsets_data, - Vector &child_vector, idx_t list_size, Vector &result) { - UnifiedVectorFormat child_format; - child_vector.ToUnifiedFormat(list_size, child_format); - - T *result_data; - - result.SetVectorType(VectorType::FLAT_VECTOR); - if (!VALIDITY_ONLY) { - result_data = FlatVector::GetData(result); - } - auto &result_mask = FlatVector::Validity(result); - - // heap-ref once - if (HEAP_REF) { - StringVector::AddHeapReference(result, child_vector); - } - - // this is lifted from ExecuteGenericLoop because we can't push the list child data into this otherwise - // should have gone with GetValue perhaps - auto child_data = UnifiedVectorFormat::GetData(child_format); - for (idx_t i = 0; i < count; i++) { - auto list_index = list_data.sel->get_index(i); - auto offsets_index = offsets_data.sel->get_index(i); - if (!list_data.validity.RowIsValid(list_index)) { - result_mask.SetInvalid(i); - continue; - } - if (!offsets_data.validity.RowIsValid(offsets_index)) { - result_mask.SetInvalid(i); - continue; - } - auto list_entry = (UnifiedVectorFormat::GetData(list_data))[list_index]; - auto offsets_entry = (UnifiedVectorFormat::GetData(offsets_data))[offsets_index]; - - // 1-based indexing - if (offsets_entry == 0) { - result_mask.SetInvalid(i); - continue; - } - offsets_entry = (offsets_entry > 0) ? 
offsets_entry - 1 : offsets_entry; - - idx_t child_offset; - if (offsets_entry < 0) { - if (offsets_entry < -int64_t(list_entry.length)) { - result_mask.SetInvalid(i); - continue; - } - child_offset = list_entry.offset + list_entry.length + offsets_entry; - } else { - if ((idx_t)offsets_entry >= list_entry.length) { - result_mask.SetInvalid(i); - continue; - } - child_offset = list_entry.offset + offsets_entry; - } - auto child_index = child_format.sel->get_index(child_offset); - if (child_format.validity.RowIsValid(child_index)) { - if (!VALIDITY_ONLY) { - result_data[i] = child_data[child_index]; - } - } else { - result_mask.SetInvalid(i); - } - } - if (count == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} -static void ExecuteListExtractInternal(const idx_t count, UnifiedVectorFormat &list, UnifiedVectorFormat &offsets, - Vector &child_vector, idx_t list_size, Vector &result) { - D_ASSERT(child_vector.GetType() == result.GetType()); - switch (result.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::INT16: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::INT32: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::INT64: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::INT128: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::UINT8: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::UINT16: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::UINT32: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::UINT64: - 
ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::FLOAT: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::DOUBLE: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::VARCHAR: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::INTERVAL: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::STRUCT: { - auto &entries = StructVector::GetEntries(child_vector); - auto &result_entries = StructVector::GetEntries(result); - D_ASSERT(entries.size() == result_entries.size()); - // extract the child entries of the struct - for (idx_t i = 0; i < entries.size(); i++) { - ExecuteListExtractInternal(count, list, offsets, *entries[i], list_size, *result_entries[i]); - } - // extract the validity mask - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - } - case PhysicalType::LIST: { - // nested list: we have to reference the child - auto &child_child_list = ListVector::GetEntry(child_vector); - - ListVector::GetEntry(result).Reference(child_child_list); - ListVector::SetListSize(result, ListVector::GetListSize(child_vector)); - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - } - default: - throw NotImplementedException("Unimplemented type for LIST_EXTRACT"); - } -} - -static void ExecuteListExtract(Vector &result, Vector &list, Vector &offsets, const idx_t count) { - D_ASSERT(list.GetType().id() == LogicalTypeId::LIST); - UnifiedVectorFormat list_data; - UnifiedVectorFormat offsets_data; - - list.ToUnifiedFormat(count, list_data); - offsets.ToUnifiedFormat(count, offsets_data); - ExecuteListExtractInternal(count, list_data, offsets_data, ListVector::GetEntry(list), - ListVector::GetListSize(list), result); - result.Verify(count); 
-} - -static void ExecuteStringExtract(Vector &result, Vector &input_vector, Vector &subscript_vector, const idx_t count) { - BinaryExecutor::Execute( - input_vector, subscript_vector, result, count, [&](string_t input_string, int64_t subscript) { - return SubstringFun::SubstringUnicode(result, input_string, subscript, 1); - }); -} - -static void ListExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - auto count = args.size(); - - result.SetVectorType(VectorType::CONSTANT_VECTOR); - for (idx_t i = 0; i < args.ColumnCount(); i++) { - if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::FLAT_VECTOR); - } - } - - Vector &base = args.data[0]; - Vector &subscript = args.data[1]; - - switch (base.GetType().id()) { - case LogicalTypeId::LIST: - ExecuteListExtract(result, base, subscript, count); - break; - case LogicalTypeId::VARCHAR: - ExecuteStringExtract(result, base, subscript, count); - break; - case LogicalTypeId::SQLNULL: - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - break; - default: - throw NotImplementedException("Specifier type not implemented"); - } -} - -static unique_ptr ListExtractBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - D_ASSERT(LogicalTypeId::LIST == arguments[0]->return_type.id()); - // list extract returns the child type of the list as return type - bound_function.return_type = ListType::GetChildType(arguments[0]->return_type); - return make_uniq(bound_function.return_type); -} - -static unique_ptr ListExtractStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &list_child_stats = ListStats::GetChildStats(child_stats[0]); - auto child_copy = list_child_stats.Copy(); - // list_extract always pushes a NULL, since if the offset is out of range 
for a list it inserts a null - child_copy.Set(StatsInfo::CAN_HAVE_NULL_VALUES); - return child_copy.ToUnique(); -} - -void ListExtractFun::RegisterFunction(BuiltinFunctions &set) { - // the arguments and return types are actually set in the binder function - ScalarFunction lfun({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, LogicalType::ANY, - ListExtractFunction, ListExtractBind, nullptr, ListExtractStats); - - ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction); - - ScalarFunctionSet list_extract("list_extract"); - list_extract.AddFunction(lfun); - list_extract.AddFunction(sfun); - set.AddFunction(list_extract); - - ScalarFunctionSet list_element("list_element"); - list_element.AddFunction(lfun); - list_element.AddFunction(sfun); - set.AddFunction(list_element); - - ScalarFunctionSet array_extract("array_extract"); - array_extract.AddFunction(lfun); - array_extract.AddFunction(sfun); - array_extract.AddFunction(StructExtractFun::GetFunction()); - set.AddFunction(array_extract); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -void ListResizeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.data[1].GetType().id() == LogicalTypeId::UBIGINT); - if (result.GetType().id() == LogicalTypeId::SQLNULL) { - FlatVector::SetNull(result, 0, true); - return; - } - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - auto count = args.size(); - - result.SetVectorType(VectorType::FLAT_VECTOR); - - auto &lists = args.data[0]; - auto &child = ListVector::GetEntry(args.data[0]); - auto &new_sizes = args.data[1]; - - UnifiedVectorFormat list_data; - lists.ToUnifiedFormat(count, list_data); - auto list_entries = UnifiedVectorFormat::GetData(list_data); - - UnifiedVectorFormat new_size_data; - new_sizes.ToUnifiedFormat(count, new_size_data); - auto new_size_entries = UnifiedVectorFormat::GetData(new_size_data); - - UnifiedVectorFormat child_data; - 
child.ToUnifiedFormat(count, child_data); - - // Find the new size of the result child vector - idx_t new_child_size = 0; - for (idx_t i = 0; i < count; i++) { - auto index = new_size_data.sel->get_index(i); - if (new_size_data.validity.RowIsValid(index)) { - new_child_size += new_size_entries[index]; - } - } - - // Create the default vector if it exists - UnifiedVectorFormat default_data; - optional_ptr default_vector; - if (args.ColumnCount() == 3) { - default_vector = &args.data[2]; - default_vector->ToUnifiedFormat(count, default_data); - default_vector->SetVectorType(VectorType::CONSTANT_VECTOR); - } - - ListVector::Reserve(result, new_child_size); - ListVector::SetListSize(result, new_child_size); - - auto result_entries = FlatVector::GetData(result); - auto &result_child = ListVector::GetEntry(result); - - // for each lists in the args - idx_t result_child_offset = 0; - for (idx_t args_index = 0; args_index < count; args_index++) { - auto l_index = list_data.sel->get_index(args_index); - auto new_index = new_size_data.sel->get_index(args_index); - - // set null if lists is null - if (!list_data.validity.RowIsValid(l_index)) { - FlatVector::SetNull(result, args_index, true); - continue; - } - - idx_t new_size_entry = 0; - if (new_size_data.validity.RowIsValid(new_index)) { - new_size_entry = new_size_entries[new_index]; - } - - // find the smallest size between lists and new_sizes - auto values_to_copy = MinValue(list_entries[l_index].length, new_size_entry); - - // set the result entry - result_entries[args_index].offset = result_child_offset; - result_entries[args_index].length = new_size_entry; - - // copy the values from the child vector - VectorOperations::Copy(child, result_child, list_entries[l_index].offset + values_to_copy, - list_entries[l_index].offset, result_child_offset); - result_child_offset += values_to_copy; - - // set default value if it exists - idx_t def_index = 0; - if (args.ColumnCount() == 3) { - def_index = 
default_data.sel->get_index(args_index); - } - - // if the new size is larger than the old size, fill in the default value - if (values_to_copy < new_size_entry) { - if (default_vector && default_data.validity.RowIsValid(def_index)) { - VectorOperations::Copy(*default_vector, result_child, new_size_entry - values_to_copy, def_index, - result_child_offset); - result_child_offset += new_size_entry - values_to_copy; - } else { - for (idx_t j = values_to_copy; j < new_size_entry; j++) { - FlatVector::SetNull(result_child, result_child_offset, true); - result_child_offset++; - } - } - } - } - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr ListResizeBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2 || arguments.size() == 3); - bound_function.arguments[1] = LogicalType::UBIGINT; - - // first argument is constant NULL - if (arguments[0]->return_type == LogicalType::SQLNULL) { - bound_function.arguments[0] = LogicalType::SQLNULL; - bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(bound_function.return_type); - } - - // prepared statements - if (arguments[0]->return_type == LogicalType::UNKNOWN) { - bound_function.return_type = arguments[0]->return_type; - return make_uniq(bound_function.return_type); - } - - // default type does not match list type - if (bound_function.arguments.size() == 3 && - ListType::GetChildType(arguments[0]->return_type) != arguments[2]->return_type && - arguments[2]->return_type != LogicalTypeId::SQLNULL) { - bound_function.arguments[2] = ListType::GetChildType(arguments[0]->return_type); - } - - bound_function.return_type = arguments[0]->return_type; - return make_uniq(bound_function.return_type); -} - -void ListResizeFun::RegisterFunction(BuiltinFunctions &set) { - ScalarFunction sfun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY}, - LogicalType::LIST(LogicalTypeId::ANY), 
ListResizeFunction, ListResizeBind); - sfun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - - ScalarFunction dfun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY, LogicalTypeId::ANY}, - LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); - dfun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - - ScalarFunctionSet list_resize("list_resize"); - list_resize.AddFunction(sfun); - list_resize.AddFunction(dfun); - set.AddFunction(list_resize); - - ScalarFunctionSet array_resize("array_resize"); - array_resize.AddFunction(sfun); - array_resize.AddFunction(dfun); - set.AddFunction(array_resize); -} - -} // namespace duckdb - - -namespace duckdb { - -void BuiltinFunctions::RegisterNestedFunctions() { - Register(); - Register(); - Register(); - Register(); - Register(); - Register(); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// + [add] -//===--------------------------------------------------------------------===// -template <> -float AddOperator::Operation(float left, float right) { - auto result = left + right; - return result; -} - -template <> -double AddOperator::Operation(double left, double right) { - auto result = left + right; - return result; -} - -template <> -interval_t AddOperator::Operation(interval_t left, interval_t right) { - left.months = AddOperatorOverflowCheck::Operation(left.months, right.months); - left.days = AddOperatorOverflowCheck::Operation(left.days, right.days); - left.micros = AddOperatorOverflowCheck::Operation(left.micros, right.micros); - return left; -} - -template <> -date_t AddOperator::Operation(date_t left, int32_t right) { - if (!Value::IsFinite(left)) { - return left; - } - int32_t days; - if (!TryAddOperator::Operation(left.days, right, days)) { - throw OutOfRangeException("Date out of range"); - } - date_t result(days); - if (!Value::IsFinite(result)) { - throw 
OutOfRangeException("Date out of range"); - } - return result; -} - -template <> -date_t AddOperator::Operation(int32_t left, date_t right) { - return AddOperator::Operation(right, left); -} - -template <> -timestamp_t AddOperator::Operation(date_t left, dtime_t right) { - if (left == date_t::infinity()) { - return timestamp_t::infinity(); - } else if (left == date_t::ninfinity()) { - return timestamp_t::ninfinity(); - } - timestamp_t result; - if (!Timestamp::TryFromDatetime(left, right, result)) { - throw OutOfRangeException("Timestamp out of range"); - } - return result; -} - -template <> -timestamp_t AddOperator::Operation(dtime_t left, date_t right) { - return AddOperator::Operation(right, left); -} - -template <> -date_t AddOperator::Operation(date_t left, interval_t right) { - return Interval::Add(left, right); -} - -template <> -date_t AddOperator::Operation(interval_t left, date_t right) { - return AddOperator::Operation(right, left); -} - -template <> -timestamp_t AddOperator::Operation(timestamp_t left, interval_t right) { - return Interval::Add(left, right); -} - -template <> -timestamp_t AddOperator::Operation(interval_t left, timestamp_t right) { - return AddOperator::Operation(right, left); -} - -//===--------------------------------------------------------------------===// -// + [add] with overflow check -//===--------------------------------------------------------------------===// -struct OverflowCheckedAddition { - template - static inline bool Operation(SRCTYPE left, SRCTYPE right, SRCTYPE &result) { - UTYPE uresult = AddOperator::Operation(UTYPE(left), UTYPE(right)); - if (uresult < NumericLimits::Minimum() || uresult > NumericLimits::Maximum()) { - return false; - } - result = SRCTYPE(uresult); - return true; - } -}; - -template <> -bool TryAddOperator::Operation(uint8_t left, uint8_t right, uint8_t &result) { - return OverflowCheckedAddition::Operation(left, right, result); -} -template <> -bool TryAddOperator::Operation(uint16_t left, 
uint16_t right, uint16_t &result) { - return OverflowCheckedAddition::Operation(left, right, result); -} -template <> -bool TryAddOperator::Operation(uint32_t left, uint32_t right, uint32_t &result) { - return OverflowCheckedAddition::Operation(left, right, result); -} - -template <> -bool TryAddOperator::Operation(uint64_t left, uint64_t right, uint64_t &result) { - if (NumericLimits::Maximum() - left < right) { - return false; - } - return OverflowCheckedAddition::Operation(left, right, result); -} - -template <> -bool TryAddOperator::Operation(int8_t left, int8_t right, int8_t &result) { - return OverflowCheckedAddition::Operation(left, right, result); -} - -template <> -bool TryAddOperator::Operation(int16_t left, int16_t right, int16_t &result) { - return OverflowCheckedAddition::Operation(left, right, result); -} - -template <> -bool TryAddOperator::Operation(int32_t left, int32_t right, int32_t &result) { - return OverflowCheckedAddition::Operation(left, right, result); -} - -template <> -bool TryAddOperator::Operation(int64_t left, int64_t right, int64_t &result) { -#if (__GNUC__ >= 5) || defined(__clang__) - if (__builtin_add_overflow(left, right, &result)) { - return false; - } -#else - // https://blog.regehr.org/archives/1139 - result = int64_t((uint64_t)left + (uint64_t)right); - if ((left < 0 && right < 0 && result >= 0) || (left >= 0 && right >= 0 && result < 0)) { - return false; - } -#endif - return true; -} - -template <> -bool TryAddOperator::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { - if (!Hugeint::AddInPlace(left, right)) { - return false; - } - result = left; - return true; -} - -//===--------------------------------------------------------------------===// -// add decimal with overflow check -//===--------------------------------------------------------------------===// -template -bool TryDecimalAddTemplated(T left, T right, T &result) { - if (right < 0) { - if (min - right > left) { - return false; - } - } else { - if 
(max - right < left) { - return false; - } - } - result = left + right; - return true; -} - -template <> -bool TryDecimalAdd::Operation(int16_t left, int16_t right, int16_t &result) { - return TryDecimalAddTemplated(left, right, result); -} - -template <> -bool TryDecimalAdd::Operation(int32_t left, int32_t right, int32_t &result) { - return TryDecimalAddTemplated(left, right, result); -} - -template <> -bool TryDecimalAdd::Operation(int64_t left, int64_t right, int64_t &result) { - return TryDecimalAddTemplated(left, right, result); -} - -template <> -bool TryDecimalAdd::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { - result = left + right; - if (result <= -Hugeint::POWERS_OF_TEN[38] || result >= Hugeint::POWERS_OF_TEN[38]) { - return false; - } - return true; -} - -template <> -hugeint_t DecimalAddOverflowCheck::Operation(hugeint_t left, hugeint_t right) { - hugeint_t result; - if (!TryDecimalAdd::Operation(left, right, result)) { - throw OutOfRangeException("Overflow in addition of DECIMAL(38) (%s + %s);", left.ToString(), right.ToString()); - } - return result; -} - -//===--------------------------------------------------------------------===// -// add time operator -//===--------------------------------------------------------------------===// -template <> -dtime_t AddTimeOperator::Operation(dtime_t left, interval_t right) { - date_t date(0); - return Interval::Add(left, right, date); -} - -template <> -dtime_t AddTimeOperator::Operation(interval_t left, dtime_t right) { - return AddTimeOperator::Operation(right, left); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -template -static scalar_function_t GetScalarIntegerFunction(PhysicalType type) { - scalar_function_t function; - switch (type) { - case PhysicalType::INT8: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::INT16: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::INT32: - function = 
&ScalarFunction::BinaryFunction; - break; - case PhysicalType::INT64: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::INT128: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::UINT8: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::UINT16: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::UINT32: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::UINT64: - function = &ScalarFunction::BinaryFunction; - break; - default: - throw NotImplementedException("Unimplemented type for GetScalarBinaryFunction"); - } - return function; -} - -template -static scalar_function_t GetScalarBinaryFunction(PhysicalType type) { - scalar_function_t function; - switch (type) { - case PhysicalType::INT128: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::FLOAT: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::DOUBLE: - function = &ScalarFunction::BinaryFunction; - break; - default: - function = GetScalarIntegerFunction(type); - break; - } - return function; -} - -//===--------------------------------------------------------------------===// -// + [add] -//===--------------------------------------------------------------------===// -struct AddPropagateStatistics { - template - static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, - Value &new_max) { - T min, max; - // new min is min+min - if (!OP::Operation(NumericStats::GetMin(lstats), NumericStats::GetMin(rstats), min)) { - return true; - } - // new max is max+max - if (!OP::Operation(NumericStats::GetMax(lstats), NumericStats::GetMax(rstats), max)) { - return true; - } - new_min = Value::Numeric(type, min); - new_max = Value::Numeric(type, max); - return false; - } -}; - -struct SubtractPropagateStatistics { - template - static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics 
&rstats, Value &new_min, - Value &new_max) { - T min, max; - if (!OP::Operation(NumericStats::GetMin(lstats), NumericStats::GetMax(rstats), min)) { - return true; - } - if (!OP::Operation(NumericStats::GetMax(lstats), NumericStats::GetMin(rstats), max)) { - return true; - } - new_min = Value::Numeric(type, min); - new_max = Value::Numeric(type, max); - return false; - } -}; - -struct DecimalArithmeticBindData : public FunctionData { - DecimalArithmeticBindData() : check_overflow(true) { - } - - unique_ptr Copy() const override { - auto res = make_uniq(); - res->check_overflow = check_overflow; - return std::move(res); - } - - bool Equals(const FunctionData &other_p) const override { - auto other = other_p.Cast(); - return other.check_overflow == check_overflow; - } - - bool check_overflow; -}; - -template -static unique_ptr PropagateNumericStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() == 2); - // can only propagate stats if the children have stats - auto &lstats = child_stats[0]; - auto &rstats = child_stats[1]; - Value new_min, new_max; - bool potential_overflow = true; - if (NumericStats::HasMinMax(lstats) && NumericStats::HasMinMax(rstats)) { - switch (expr.return_type.InternalType()) { - case PhysicalType::INT8: - potential_overflow = - PROPAGATE::template Operation(expr.return_type, lstats, rstats, new_min, new_max); - break; - case PhysicalType::INT16: - potential_overflow = - PROPAGATE::template Operation(expr.return_type, lstats, rstats, new_min, new_max); - break; - case PhysicalType::INT32: - potential_overflow = - PROPAGATE::template Operation(expr.return_type, lstats, rstats, new_min, new_max); - break; - case PhysicalType::INT64: - potential_overflow = - PROPAGATE::template Operation(expr.return_type, lstats, rstats, new_min, new_max); - break; - default: - return nullptr; - } - } - if (potential_overflow) { - new_min = 
Value(expr.return_type); - new_max = Value(expr.return_type); - } else { - // no potential overflow: replace with non-overflowing operator - if (input.bind_data) { - auto &bind_data = input.bind_data->Cast(); - bind_data.check_overflow = false; - } - expr.function.function = GetScalarIntegerFunction(expr.return_type.InternalType()); - } - auto result = NumericStats::CreateEmpty(expr.return_type); - NumericStats::SetMin(result, new_min); - NumericStats::SetMax(result, new_max); - result.CombineValidity(lstats, rstats); - return result.ToUnique(); -} - -template -unique_ptr BindDecimalAddSubtract(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - auto bind_data = make_uniq(); - - // get the max width and scale of the input arguments - uint8_t max_width = 0, max_scale = 0, max_width_over_scale = 0; - for (idx_t i = 0; i < arguments.size(); i++) { - if (arguments[i]->return_type.id() == LogicalTypeId::UNKNOWN) { - continue; - } - uint8_t width, scale; - auto can_convert = arguments[i]->return_type.GetDecimalProperties(width, scale); - if (!can_convert) { - throw InternalException("Could not convert type %s to a decimal.", arguments[i]->return_type.ToString()); - } - max_width = MaxValue(width, max_width); - max_scale = MaxValue(scale, max_scale); - max_width_over_scale = MaxValue(width - scale, max_width_over_scale); - } - D_ASSERT(max_width > 0); - // for addition/subtraction, we add 1 to the width to ensure we don't overflow - auto required_width = MaxValue(max_scale + max_width_over_scale, max_width) + 1; - if (required_width > Decimal::MAX_WIDTH_INT64 && max_width <= Decimal::MAX_WIDTH_INT64) { - // we don't automatically promote past the hugeint boundary to avoid the large hugeint performance penalty - bind_data->check_overflow = true; - required_width = Decimal::MAX_WIDTH_INT64; - } - if (required_width > Decimal::MAX_WIDTH_DECIMAL) { - // target width does not fit in decimal at all: truncate the scale and perform overflow 
detection - bind_data->check_overflow = true; - required_width = Decimal::MAX_WIDTH_DECIMAL; - } - // arithmetic between two decimal arguments: check the types of the input arguments - LogicalType result_type = LogicalType::DECIMAL(required_width, max_scale); - // we cast all input types to the specified type - for (idx_t i = 0; i < arguments.size(); i++) { - // first check if the cast is necessary - // if the argument has a matching scale and internal type as the output type, no casting is necessary - auto &argument_type = arguments[i]->return_type; - uint8_t width, scale; - argument_type.GetDecimalProperties(width, scale); - if (scale == DecimalType::GetScale(result_type) && argument_type.InternalType() == result_type.InternalType()) { - bound_function.arguments[i] = argument_type; - } else { - bound_function.arguments[i] = result_type; - } - } - bound_function.return_type = result_type; - // now select the physical function to execute - if (bind_data->check_overflow) { - bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); - } else { - bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); - } - if (result_type.InternalType() != PhysicalType::INT128) { - if (IS_SUBTRACT) { - bound_function.statistics = - PropagateNumericStats; - } else { - bound_function.statistics = PropagateNumericStats; - } - } - return std::move(bind_data); -} - -static void SerializeDecimalArithmetic(Serializer &serializer, const optional_ptr bind_data_p, - const ScalarFunction &function) { - auto &bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "check_overflow", bind_data.check_overflow); - serializer.WriteProperty(101, "return_type", function.return_type); - serializer.WriteProperty(102, "arguments", function.arguments); -} - -// TODO this is partially duplicated from the bind -template -unique_ptr DeserializeDecimalArithmetic(Deserializer &deserializer, ScalarFunction &bound_function) { - - // // re-change the function 
pointers - auto check_overflow = deserializer.ReadProperty(100, "check_overflow"); - auto return_type = deserializer.ReadProperty(101, "return_type"); - auto arguments = deserializer.ReadProperty>(102, "arguments"); - if (check_overflow) { - bound_function.function = GetScalarBinaryFunction(return_type.InternalType()); - } else { - bound_function.function = GetScalarBinaryFunction(return_type.InternalType()); - } - bound_function.statistics = nullptr; // TODO we likely dont want to do stats prop again - bound_function.return_type = return_type; - bound_function.arguments = arguments; - - auto bind_data = make_uniq(); - bind_data->check_overflow = check_overflow; - return std::move(bind_data); -} - -unique_ptr NopDecimalBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - bound_function.return_type = arguments[0]->return_type; - bound_function.arguments[0] = arguments[0]->return_type; - return nullptr; -} - -ScalarFunction AddFun::GetFunction(const LogicalType &type) { - D_ASSERT(type.IsNumeric()); - if (type.id() == LogicalTypeId::DECIMAL) { - return ScalarFunction("+", {type}, type, ScalarFunction::NopFunction, NopDecimalBind); - } else { - return ScalarFunction("+", {type}, type, ScalarFunction::NopFunction); - } -} - -ScalarFunction AddFun::GetFunction(const LogicalType &left_type, const LogicalType &right_type) { - if (left_type.IsNumeric() && left_type.id() == right_type.id()) { - if (left_type.id() == LogicalTypeId::DECIMAL) { - auto function = ScalarFunction("+", {left_type, right_type}, left_type, nullptr, - BindDecimalAddSubtract); - function.serialize = SerializeDecimalArithmetic; - function.deserialize = DeserializeDecimalArithmetic; - return function; - } else if (left_type.IsIntegral()) { - return ScalarFunction("+", {left_type, right_type}, left_type, - GetScalarIntegerFunction(left_type.InternalType()), nullptr, - nullptr, PropagateNumericStats); - } else { - return ScalarFunction("+", {left_type, right_type}, 
left_type, - GetScalarBinaryFunction(left_type.InternalType())); - } - } - - switch (left_type.id()) { - case LogicalTypeId::DATE: - if (right_type.id() == LogicalTypeId::INTEGER) { - return ScalarFunction("+", {left_type, right_type}, LogicalType::DATE, - ScalarFunction::BinaryFunction); - } else if (right_type.id() == LogicalTypeId::INTERVAL) { - return ScalarFunction("+", {left_type, right_type}, LogicalType::DATE, - ScalarFunction::BinaryFunction); - } else if (right_type.id() == LogicalTypeId::TIME) { - return ScalarFunction("+", {left_type, right_type}, LogicalType::TIMESTAMP, - ScalarFunction::BinaryFunction); - } - break; - case LogicalTypeId::INTEGER: - if (right_type.id() == LogicalTypeId::DATE) { - return ScalarFunction("+", {left_type, right_type}, right_type, - ScalarFunction::BinaryFunction); - } - break; - case LogicalTypeId::INTERVAL: - if (right_type.id() == LogicalTypeId::INTERVAL) { - return ScalarFunction("+", {left_type, right_type}, LogicalType::INTERVAL, - ScalarFunction::BinaryFunction); - } else if (right_type.id() == LogicalTypeId::DATE) { - return ScalarFunction("+", {left_type, right_type}, LogicalType::DATE, - ScalarFunction::BinaryFunction); - } else if (right_type.id() == LogicalTypeId::TIME) { - return ScalarFunction("+", {left_type, right_type}, LogicalType::TIME, - ScalarFunction::BinaryFunction); - } else if (right_type.id() == LogicalTypeId::TIMESTAMP) { - return ScalarFunction("+", {left_type, right_type}, LogicalType::TIMESTAMP, - ScalarFunction::BinaryFunction); - } - break; - case LogicalTypeId::TIME: - if (right_type.id() == LogicalTypeId::INTERVAL) { - return ScalarFunction("+", {left_type, right_type}, LogicalType::TIME, - ScalarFunction::BinaryFunction); - } else if (right_type.id() == LogicalTypeId::DATE) { - return ScalarFunction("+", {left_type, right_type}, LogicalType::TIMESTAMP, - ScalarFunction::BinaryFunction); - } - break; - case LogicalTypeId::TIMESTAMP: - if (right_type.id() == LogicalTypeId::INTERVAL) { - 
return ScalarFunction("+", {left_type, right_type}, LogicalType::TIMESTAMP, - ScalarFunction::BinaryFunction); - } - break; - default: - break; - } - // LCOV_EXCL_START - throw NotImplementedException("AddFun for types %s, %s", EnumUtil::ToString(left_type.id()), - EnumUtil::ToString(right_type.id())); - // LCOV_EXCL_STOP -} - -void AddFun::RegisterFunction(BuiltinFunctions &set) { - ScalarFunctionSet functions("+"); - for (auto &type : LogicalType::Numeric()) { - // unary add function is a nop, but only exists for numeric types - functions.AddFunction(GetFunction(type)); - // binary add function adds two numbers together - functions.AddFunction(GetFunction(type, type)); - } - // we can add integers to dates - functions.AddFunction(GetFunction(LogicalType::DATE, LogicalType::INTEGER)); - functions.AddFunction(GetFunction(LogicalType::INTEGER, LogicalType::DATE)); - // we can add intervals together - functions.AddFunction(GetFunction(LogicalType::INTERVAL, LogicalType::INTERVAL)); - // we can add intervals to dates/times/timestamps - functions.AddFunction(GetFunction(LogicalType::DATE, LogicalType::INTERVAL)); - functions.AddFunction(GetFunction(LogicalType::INTERVAL, LogicalType::DATE)); - - functions.AddFunction(GetFunction(LogicalType::TIME, LogicalType::INTERVAL)); - functions.AddFunction(GetFunction(LogicalType::INTERVAL, LogicalType::TIME)); - - functions.AddFunction(GetFunction(LogicalType::TIMESTAMP, LogicalType::INTERVAL)); - functions.AddFunction(GetFunction(LogicalType::INTERVAL, LogicalType::TIMESTAMP)); - - // we can add times to dates - functions.AddFunction(GetFunction(LogicalType::TIME, LogicalType::DATE)); - functions.AddFunction(GetFunction(LogicalType::DATE, LogicalType::TIME)); - - // we can add lists together - functions.AddFunction(ListConcatFun::GetFunction()); - - set.AddFunction(functions); - - functions.name = "add"; - set.AddFunction(functions); -} - -//===--------------------------------------------------------------------===// -// - 
[subtract] -//===--------------------------------------------------------------------===// -struct NegateOperator { - template - static bool CanNegate(T input) { - using Limits = std::numeric_limits; - return !(Limits::is_integer && Limits::is_signed && Limits::lowest() == input); - } - - template - static inline TR Operation(TA input) { - auto cast = (TR)input; - if (!CanNegate(cast)) { - throw OutOfRangeException("Overflow in negation of integer!"); - } - return -cast; - } -}; - -template <> -bool NegateOperator::CanNegate(float input) { - return true; -} - -template <> -bool NegateOperator::CanNegate(double input) { - return true; -} - -template <> -interval_t NegateOperator::Operation(interval_t input) { - interval_t result; - result.months = NegateOperator::Operation(input.months); - result.days = NegateOperator::Operation(input.days); - result.micros = NegateOperator::Operation(input.micros); - return result; -} - -struct DecimalNegateBindData : public FunctionData { - DecimalNegateBindData() : bound_type(LogicalTypeId::INVALID) { - } - - unique_ptr Copy() const override { - auto res = make_uniq(); - res->bound_type = bound_type; - return std::move(res); - } - - bool Equals(const FunctionData &other_p) const override { - auto other = other_p.Cast(); - return other.bound_type == bound_type; - } - - LogicalTypeId bound_type; -}; - -unique_ptr DecimalNegateBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - auto bind_data = make_uniq(); - - auto &decimal_type = arguments[0]->return_type; - auto width = DecimalType::GetWidth(decimal_type); - if (width <= Decimal::MAX_WIDTH_INT16) { - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::SMALLINT); - } else if (width <= Decimal::MAX_WIDTH_INT32) { - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::INTEGER); - } else if (width <= Decimal::MAX_WIDTH_INT64) { - bound_function.function = 
ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::BIGINT); - } else { - D_ASSERT(width <= Decimal::MAX_WIDTH_INT128); - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::HUGEINT); - } - decimal_type.Verify(); - bound_function.arguments[0] = decimal_type; - bound_function.return_type = decimal_type; - return nullptr; -} - -struct NegatePropagateStatistics { - template - static bool Operation(LogicalType type, BaseStatistics &istats, Value &new_min, Value &new_max) { - auto max_value = NumericStats::GetMax(istats); - auto min_value = NumericStats::GetMin(istats); - if (!NegateOperator::CanNegate(min_value) || !NegateOperator::CanNegate(max_value)) { - return true; - } - // new min is -max - new_min = Value::Numeric(type, NegateOperator::Operation(max_value)); - // new max is -min - new_max = Value::Numeric(type, NegateOperator::Operation(min_value)); - return false; - } -}; - -static unique_ptr NegateBindStatistics(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() == 1); - // can only propagate stats if the children have stats - auto &istats = child_stats[0]; - Value new_min, new_max; - bool potential_overflow = true; - if (NumericStats::HasMinMax(istats)) { - switch (expr.return_type.InternalType()) { - case PhysicalType::INT8: - potential_overflow = - NegatePropagateStatistics::Operation(expr.return_type, istats, new_min, new_max); - break; - case PhysicalType::INT16: - potential_overflow = - NegatePropagateStatistics::Operation(expr.return_type, istats, new_min, new_max); - break; - case PhysicalType::INT32: - potential_overflow = - NegatePropagateStatistics::Operation(expr.return_type, istats, new_min, new_max); - break; - case PhysicalType::INT64: - potential_overflow = - NegatePropagateStatistics::Operation(expr.return_type, istats, new_min, new_max); - break; - default: - return nullptr; - } - } - if (potential_overflow) { 
- new_min = Value(expr.return_type); - new_max = Value(expr.return_type); - } - auto stats = NumericStats::CreateEmpty(expr.return_type); - NumericStats::SetMin(stats, new_min); - NumericStats::SetMax(stats, new_max); - stats.CopyValidity(istats); - return stats.ToUnique(); -} - -ScalarFunction SubtractFun::GetFunction(const LogicalType &type) { - if (type.id() == LogicalTypeId::INTERVAL) { - return ScalarFunction("-", {type}, type, ScalarFunction::UnaryFunction); - } else if (type.id() == LogicalTypeId::DECIMAL) { - return ScalarFunction("-", {type}, type, nullptr, DecimalNegateBind, nullptr, NegateBindStatistics); - } else { - D_ASSERT(type.IsNumeric()); - return ScalarFunction("-", {type}, type, ScalarFunction::GetScalarUnaryFunction(type), nullptr, - nullptr, NegateBindStatistics); - } -} - -ScalarFunction SubtractFun::GetFunction(const LogicalType &left_type, const LogicalType &right_type) { - if (left_type.IsNumeric() && left_type.id() == right_type.id()) { - if (left_type.id() == LogicalTypeId::DECIMAL) { - auto function = - ScalarFunction("-", {left_type, right_type}, left_type, nullptr, - BindDecimalAddSubtract); - function.serialize = SerializeDecimalArithmetic; - function.deserialize = DeserializeDecimalArithmetic; - return function; - } else if (left_type.IsIntegral()) { - return ScalarFunction( - "-", {left_type, right_type}, left_type, - GetScalarIntegerFunction(left_type.InternalType()), nullptr, nullptr, - PropagateNumericStats); - - } else { - return ScalarFunction("-", {left_type, right_type}, left_type, - GetScalarBinaryFunction(left_type.InternalType())); - } - } - - switch (left_type.id()) { - case LogicalTypeId::DATE: - if (right_type.id() == LogicalTypeId::DATE) { - return ScalarFunction("-", {left_type, right_type}, LogicalType::BIGINT, - ScalarFunction::BinaryFunction); - - } else if (right_type.id() == LogicalTypeId::INTEGER) { - return ScalarFunction("-", {left_type, right_type}, LogicalType::DATE, - ScalarFunction::BinaryFunction); - } 
else if (right_type.id() == LogicalTypeId::INTERVAL) { - return ScalarFunction("-", {left_type, right_type}, LogicalType::DATE, - ScalarFunction::BinaryFunction); - } - break; - case LogicalTypeId::TIMESTAMP: - if (right_type.id() == LogicalTypeId::TIMESTAMP) { - return ScalarFunction( - "-", {left_type, right_type}, LogicalType::INTERVAL, - ScalarFunction::BinaryFunction); - } else if (right_type.id() == LogicalTypeId::INTERVAL) { - return ScalarFunction( - "-", {left_type, right_type}, LogicalType::TIMESTAMP, - ScalarFunction::BinaryFunction); - } - break; - case LogicalTypeId::INTERVAL: - if (right_type.id() == LogicalTypeId::INTERVAL) { - return ScalarFunction("-", {left_type, right_type}, LogicalType::INTERVAL, - ScalarFunction::BinaryFunction); - } - break; - case LogicalTypeId::TIME: - if (right_type.id() == LogicalTypeId::INTERVAL) { - return ScalarFunction("-", {left_type, right_type}, LogicalType::TIME, - ScalarFunction::BinaryFunction); - } - break; - default: - break; - } - // LCOV_EXCL_START - throw NotImplementedException("SubtractFun for types %s, %s", EnumUtil::ToString(left_type.id()), - EnumUtil::ToString(right_type.id())); - // LCOV_EXCL_STOP -} - -void SubtractFun::RegisterFunction(BuiltinFunctions &set) { - ScalarFunctionSet functions("-"); - for (auto &type : LogicalType::Numeric()) { - // unary subtract function, negates the input (i.e. 
multiplies by -1) - functions.AddFunction(GetFunction(type)); - // binary subtract function "a - b", subtracts b from a - functions.AddFunction(GetFunction(type, type)); - } - // we can subtract dates from each other - functions.AddFunction(GetFunction(LogicalType::DATE, LogicalType::DATE)); - // we can subtract integers from dates - functions.AddFunction(GetFunction(LogicalType::DATE, LogicalType::INTEGER)); - // we can subtract timestamps from each other - functions.AddFunction(GetFunction(LogicalType::TIMESTAMP, LogicalType::TIMESTAMP)); - // we can subtract intervals from each other - functions.AddFunction(GetFunction(LogicalType::INTERVAL, LogicalType::INTERVAL)); - // we can subtract intervals from dates/times/timestamps, but not the other way around - functions.AddFunction(GetFunction(LogicalType::DATE, LogicalType::INTERVAL)); - functions.AddFunction(GetFunction(LogicalType::TIME, LogicalType::INTERVAL)); - functions.AddFunction(GetFunction(LogicalType::TIMESTAMP, LogicalType::INTERVAL)); - // we can negate intervals - functions.AddFunction(GetFunction(LogicalType::INTERVAL)); - set.AddFunction(functions); - - functions.name = "subtract"; - set.AddFunction(functions); -} - -//===--------------------------------------------------------------------===// -// * [multiply] -//===--------------------------------------------------------------------===// -struct MultiplyPropagateStatistics { - template - static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, - Value &new_max) { - // statistics propagation on the multiplication is slightly less straightforward because of negative numbers - // the new min/max depend on the signs of the input types - // if both are positive the result is [lmin * rmin][lmax * rmax] - // if lmin/lmax are negative the result is [lmin * rmax][lmax * rmin] - // etc - // rather than doing all this switcheroo we just multiply all combinations of lmin/lmax with rmin/rmax - // and check what the 
minimum/maximum value is - T lvals[] {NumericStats::GetMin(lstats), NumericStats::GetMax(lstats)}; - T rvals[] {NumericStats::GetMin(rstats), NumericStats::GetMax(rstats)}; - T min = NumericLimits::Maximum(); - T max = NumericLimits::Minimum(); - // multiplications - for (idx_t l = 0; l < 2; l++) { - for (idx_t r = 0; r < 2; r++) { - T result; - if (!OP::Operation(lvals[l], rvals[r], result)) { - // potential overflow - return true; - } - if (result < min) { - min = result; - } - if (result > max) { - max = result; - } - } - } - new_min = Value::Numeric(type, min); - new_max = Value::Numeric(type, max); - return false; - } -}; - -unique_ptr BindDecimalMultiply(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - auto bind_data = make_uniq(); - - uint8_t result_width = 0, result_scale = 0; - uint8_t max_width = 0; - for (idx_t i = 0; i < arguments.size(); i++) { - if (arguments[i]->return_type.id() == LogicalTypeId::UNKNOWN) { - continue; - } - uint8_t width, scale; - auto can_convert = arguments[i]->return_type.GetDecimalProperties(width, scale); - if (!can_convert) { - throw InternalException("Could not convert type %s to a decimal?", arguments[i]->return_type.ToString()); - } - if (width > max_width) { - max_width = width; - } - result_width += width; - result_scale += scale; - } - D_ASSERT(max_width > 0); - if (result_scale > Decimal::MAX_WIDTH_DECIMAL) { - throw OutOfRangeException( - "Needed scale %d to accurately represent the multiplication result, but this is out of range of the " - "DECIMAL type. Max scale is %d; could not perform an accurate multiplication. 
Either add a cast to DOUBLE, " - "or add an explicit cast to a decimal with a lower scale.", - result_scale, Decimal::MAX_WIDTH_DECIMAL); - } - if (result_width > Decimal::MAX_WIDTH_INT64 && max_width <= Decimal::MAX_WIDTH_INT64 && - result_scale < Decimal::MAX_WIDTH_INT64) { - bind_data->check_overflow = true; - result_width = Decimal::MAX_WIDTH_INT64; - } - if (result_width > Decimal::MAX_WIDTH_DECIMAL) { - bind_data->check_overflow = true; - result_width = Decimal::MAX_WIDTH_DECIMAL; - } - LogicalType result_type = LogicalType::DECIMAL(result_width, result_scale); - // since our scale is the summation of our input scales, we do not need to cast to the result scale - // however, we might need to cast to the correct internal type - for (idx_t i = 0; i < arguments.size(); i++) { - auto &argument_type = arguments[i]->return_type; - if (argument_type.InternalType() == result_type.InternalType()) { - bound_function.arguments[i] = argument_type; - } else { - uint8_t width, scale; - if (!argument_type.GetDecimalProperties(width, scale)) { - scale = 0; - } - - bound_function.arguments[i] = LogicalType::DECIMAL(result_width, scale); - } - } - result_type.Verify(); - bound_function.return_type = result_type; - // now select the physical function to execute - if (bind_data->check_overflow) { - bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); - } else { - bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); - } - if (result_type.InternalType() != PhysicalType::INT128) { - bound_function.statistics = - PropagateNumericStats; - } - return std::move(bind_data); -} - -void MultiplyFun::RegisterFunction(BuiltinFunctions &set) { - ScalarFunctionSet functions("*"); - for (auto &type : LogicalType::Numeric()) { - if (type.id() == LogicalTypeId::DECIMAL) { - ScalarFunction function({type, type}, type, nullptr, BindDecimalMultiply); - function.serialize = SerializeDecimalArithmetic; - function.deserialize = 
DeserializeDecimalArithmetic; - functions.AddFunction(function); - } else if (TypeIsIntegral(type.InternalType())) { - functions.AddFunction(ScalarFunction( - {type, type}, type, GetScalarIntegerFunction(type.InternalType()), - nullptr, nullptr, - PropagateNumericStats)); - } else { - functions.AddFunction( - ScalarFunction({type, type}, type, GetScalarBinaryFunction(type.InternalType()))); - } - } - functions.AddFunction( - ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, - ScalarFunction::BinaryFunction)); - functions.AddFunction( - ScalarFunction({LogicalType::BIGINT, LogicalType::INTERVAL}, LogicalType::INTERVAL, - ScalarFunction::BinaryFunction)); - set.AddFunction(functions); - - functions.name = "multiply"; - set.AddFunction(functions); -} - -//===--------------------------------------------------------------------===// -// / [divide] -//===--------------------------------------------------------------------===// -template <> -float DivideOperator::Operation(float left, float right) { - auto result = left / right; - return result; -} - -template <> -double DivideOperator::Operation(double left, double right) { - auto result = left / right; - return result; -} - -template <> -hugeint_t DivideOperator::Operation(hugeint_t left, hugeint_t right) { - if (right.lower == 0 && right.upper == 0) { - throw InternalException("Hugeint division by zero!"); - } - return left / right; -} - -template <> -interval_t DivideOperator::Operation(interval_t left, int64_t right) { - left.days /= right; - left.months /= right; - left.micros /= right; - return left; -} - -struct BinaryNumericDivideWrapper { - template - static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { - if (left == NumericLimits::Minimum() && right == -1) { - throw OutOfRangeException("Overflow in division of %d / %d", left, right); - } else if (right == 0) { - mask.SetInvalid(idx); - return left; - } else { - return 
OP::template Operation(left, right); - } - } - - static bool AddsNulls() { - return true; - } -}; - -struct BinaryZeroIsNullWrapper { - template - static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { - if (right == 0) { - mask.SetInvalid(idx); - return left; - } else { - return OP::template Operation(left, right); - } - } - - static bool AddsNulls() { - return true; - } -}; - -struct BinaryZeroIsNullHugeintWrapper { - template - static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { - if (right.upper == 0 && right.lower == 0) { - mask.SetInvalid(idx); - return left; - } else { - return OP::template Operation(left, right); - } - } - - static bool AddsNulls() { - return true; - } -}; - -template -static void BinaryScalarFunctionIgnoreZero(DataChunk &input, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute(input.data[0], input.data[1], result, input.size()); -} - -template -static scalar_function_t GetBinaryFunctionIgnoreZero(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return BinaryScalarFunctionIgnoreZero; - case LogicalTypeId::SMALLINT: - return BinaryScalarFunctionIgnoreZero; - case LogicalTypeId::INTEGER: - return BinaryScalarFunctionIgnoreZero; - case LogicalTypeId::BIGINT: - return BinaryScalarFunctionIgnoreZero; - case LogicalTypeId::UTINYINT: - return BinaryScalarFunctionIgnoreZero; - case LogicalTypeId::USMALLINT: - return BinaryScalarFunctionIgnoreZero; - case LogicalTypeId::UINTEGER: - return BinaryScalarFunctionIgnoreZero; - case LogicalTypeId::UBIGINT: - return BinaryScalarFunctionIgnoreZero; - case LogicalTypeId::HUGEINT: - return BinaryScalarFunctionIgnoreZero; - case LogicalTypeId::FLOAT: - return BinaryScalarFunctionIgnoreZero; - case LogicalTypeId::DOUBLE: - return BinaryScalarFunctionIgnoreZero; - default: - throw NotImplementedException("Unimplemented type for 
GetScalarUnaryFunction"); - } -} - -void DivideFun::RegisterFunction(BuiltinFunctions &set) { - ScalarFunctionSet fp_divide("/"); - fp_divide.AddFunction(ScalarFunction({LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT, - GetBinaryFunctionIgnoreZero(LogicalType::FLOAT))); - fp_divide.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, - GetBinaryFunctionIgnoreZero(LogicalType::DOUBLE))); - fp_divide.AddFunction( - ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, - BinaryScalarFunctionIgnoreZero)); - set.AddFunction(fp_divide); - - ScalarFunctionSet full_divide("//"); - for (auto &type : LogicalType::Numeric()) { - if (type.id() == LogicalTypeId::DECIMAL) { - continue; - } else { - full_divide.AddFunction( - ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero(type))); - } - } - set.AddFunction(full_divide); - - full_divide.name = "divide"; - set.AddFunction(full_divide); -} - -//===--------------------------------------------------------------------===// -// % [modulo] -//===--------------------------------------------------------------------===// -template <> -float ModuloOperator::Operation(float left, float right) { - D_ASSERT(right != 0); - auto result = std::fmod(left, right); - return result; -} - -template <> -double ModuloOperator::Operation(double left, double right) { - D_ASSERT(right != 0); - auto result = std::fmod(left, right); - return result; -} - -template <> -hugeint_t ModuloOperator::Operation(hugeint_t left, hugeint_t right) { - if (right.lower == 0 && right.upper == 0) { - throw InternalException("Hugeint division by zero!"); - } - return left % right; -} - -void ModFun::RegisterFunction(BuiltinFunctions &set) { - ScalarFunctionSet functions("%"); - for (auto &type : LogicalType::Numeric()) { - if (type.id() == LogicalTypeId::DECIMAL) { - continue; - } else { - functions.AddFunction( - ScalarFunction({type, type}, type, 
GetBinaryFunctionIgnoreZero(type))); - } - } - set.AddFunction(functions); - functions.name = "mod"; - set.AddFunction(functions); -} - -} // namespace duckdb - - - - - - - -#include -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// * [multiply] -//===--------------------------------------------------------------------===// -template <> -float MultiplyOperator::Operation(float left, float right) { - auto result = left * right; - return result; -} - -template <> -double MultiplyOperator::Operation(double left, double right) { - auto result = left * right; - return result; -} - -template <> -interval_t MultiplyOperator::Operation(interval_t left, int64_t right) { - left.months = MultiplyOperatorOverflowCheck::Operation(left.months, right); - left.days = MultiplyOperatorOverflowCheck::Operation(left.days, right); - left.micros = MultiplyOperatorOverflowCheck::Operation(left.micros, right); - return left; -} - -template <> -interval_t MultiplyOperator::Operation(int64_t left, interval_t right) { - return MultiplyOperator::Operation(right, left); -} - -//===--------------------------------------------------------------------===// -// * [multiply] with overflow check -//===--------------------------------------------------------------------===// -struct OverflowCheckedMultiply { - template - static inline bool Operation(SRCTYPE left, SRCTYPE right, SRCTYPE &result) { - UTYPE uresult = MultiplyOperator::Operation(UTYPE(left), UTYPE(right)); - if (uresult < NumericLimits::Minimum() || uresult > NumericLimits::Maximum()) { - return false; - } - result = SRCTYPE(uresult); - return true; - } -}; - -template <> -bool TryMultiplyOperator::Operation(uint8_t left, uint8_t right, uint8_t &result) { - return OverflowCheckedMultiply::Operation(left, right, result); -} -template <> -bool TryMultiplyOperator::Operation(uint16_t left, uint16_t right, uint16_t &result) { - return OverflowCheckedMultiply::Operation(left, 
right, result); -} -template <> -bool TryMultiplyOperator::Operation(uint32_t left, uint32_t right, uint32_t &result) { - return OverflowCheckedMultiply::Operation(left, right, result); -} -template <> -bool TryMultiplyOperator::Operation(uint64_t left, uint64_t right, uint64_t &result) { - if (left > right) { - std::swap(left, right); - } - if (left > NumericLimits::Maximum()) { - return false; - } - uint32_t c = right >> 32; - uint32_t d = NumericLimits::Maximum() & right; - uint64_t r = left * c; - uint64_t s = left * d; - if (r > NumericLimits::Maximum()) { - return false; - } - r <<= 32; - if (NumericLimits::Maximum() - s < r) { - return false; - } - return OverflowCheckedMultiply::Operation(left, right, result); -} - -template <> -bool TryMultiplyOperator::Operation(int8_t left, int8_t right, int8_t &result) { - return OverflowCheckedMultiply::Operation(left, right, result); -} - -template <> -bool TryMultiplyOperator::Operation(int16_t left, int16_t right, int16_t &result) { - return OverflowCheckedMultiply::Operation(left, right, result); -} - -template <> -bool TryMultiplyOperator::Operation(int32_t left, int32_t right, int32_t &result) { - return OverflowCheckedMultiply::Operation(left, right, result); -} - -template <> -bool TryMultiplyOperator::Operation(int64_t left, int64_t right, int64_t &result) { -#if (__GNUC__ >= 5) || defined(__clang__) - if (__builtin_mul_overflow(left, right, &result)) { - return false; - } -#else - if (left == std::numeric_limits::min()) { - if (right == 0) { - result = 0; - return true; - } - if (right == 1) { - result = left; - return true; - } - return false; - } - if (right == std::numeric_limits::min()) { - if (left == 0) { - result = 0; - return true; - } - if (left == 1) { - result = right; - return true; - } - return false; - } - uint64_t left_non_negative = uint64_t(std::abs(left)); - uint64_t right_non_negative = uint64_t(std::abs(right)); - // split values into 2 32-bit parts - uint64_t left_high_bits = 
left_non_negative >> 32; - uint64_t left_low_bits = left_non_negative & 0xffffffff; - uint64_t right_high_bits = right_non_negative >> 32; - uint64_t right_low_bits = right_non_negative & 0xffffffff; - - // check the high bits of both - // the high bits define the overflow - if (left_high_bits == 0) { - if (right_high_bits != 0) { - // only the right has high bits set - // multiply the high bits of right with the low bits of left - // multiply the low bits, and carry any overflow to the high bits - // then check for any overflow - auto low_low = left_low_bits * right_low_bits; - auto low_high = left_low_bits * right_high_bits; - auto high_bits = low_high + (low_low >> 32); - if (high_bits & 0xffffff80000000) { - // there is! abort - return false; - } - } - } else if (right_high_bits == 0) { - // only the left has high bits set - // multiply the high bits of left with the low bits of right - // multiply the low bits, and carry any overflow to the high bits - // then check for any overflow - auto low_low = left_low_bits * right_low_bits; - auto high_low = left_high_bits * right_low_bits; - auto high_bits = high_low + (low_low >> 32); - if (high_bits & 0xffffff80000000) { - // there is! abort - return false; - } - } else { - // both left and right have high bits set: guaranteed overflow - // abort! 
- return false; - } - // now we know that there is no overflow, we can just perform the multiplication - result = left * right; -#endif - return true; -} - -template <> -bool TryMultiplyOperator::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { - return Hugeint::TryMultiply(left, right, result); -} - -//===--------------------------------------------------------------------===// -// multiply decimal with overflow check -//===--------------------------------------------------------------------===// -template -bool TryDecimalMultiplyTemplated(T left, T right, T &result) { - if (!TryMultiplyOperator::Operation(left, right, result) || result < min || result > max) { - return false; - } - return true; -} - -template <> -bool TryDecimalMultiply::Operation(int16_t left, int16_t right, int16_t &result) { - return TryDecimalMultiplyTemplated(left, right, result); -} - -template <> -bool TryDecimalMultiply::Operation(int32_t left, int32_t right, int32_t &result) { - return TryDecimalMultiplyTemplated(left, right, result); -} - -template <> -bool TryDecimalMultiply::Operation(int64_t left, int64_t right, int64_t &result) { - return TryDecimalMultiplyTemplated(left, right, result); -} - -template <> -bool TryDecimalMultiply::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { - result = left * right; - if (result <= -Hugeint::POWERS_OF_TEN[38] || result >= Hugeint::POWERS_OF_TEN[38]) { - return false; - } - return true; -} - -template <> -hugeint_t DecimalMultiplyOverflowCheck::Operation(hugeint_t left, hugeint_t right) { - hugeint_t result; - if (!TryDecimalMultiply::Operation(left, right, result)) { - throw OutOfRangeException("Overflow in multiplication of DECIMAL(38) (%s * %s). 
You might want to add an " - "explicit cast to a decimal with a smaller scale.", - left.ToString(), right.ToString()); - } - return result; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// - [subtract] -//===--------------------------------------------------------------------===// -template <> -float SubtractOperator::Operation(float left, float right) { - auto result = left - right; - return result; -} - -template <> -double SubtractOperator::Operation(double left, double right) { - auto result = left - right; - return result; -} - -template <> -int64_t SubtractOperator::Operation(date_t left, date_t right) { - return int64_t(left.days) - int64_t(right.days); -} - -template <> -date_t SubtractOperator::Operation(date_t left, int32_t right) { - if (!Date::IsFinite(left)) { - return left; - } - int32_t days; - if (!TrySubtractOperator::Operation(left.days, right, days)) { - throw OutOfRangeException("Date out of range"); - } - - date_t result(days); - if (!Date::IsFinite(result)) { - throw OutOfRangeException("Date out of range"); - } - return result; -} - -template <> -interval_t SubtractOperator::Operation(interval_t left, interval_t right) { - interval_t result; - result.months = left.months - right.months; - result.days = left.days - right.days; - result.micros = left.micros - right.micros; - return result; -} - -template <> -date_t SubtractOperator::Operation(date_t left, interval_t right) { - return AddOperator::Operation(left, Interval::Invert(right)); -} - -template <> -timestamp_t SubtractOperator::Operation(timestamp_t left, interval_t right) { - return AddOperator::Operation(left, Interval::Invert(right)); -} - -template <> -interval_t SubtractOperator::Operation(timestamp_t left, timestamp_t right) { - return Interval::GetDifference(left, right); -} - -//===--------------------------------------------------------------------===// -// - [subtract] with 
overflow check -//===--------------------------------------------------------------------===// -struct OverflowCheckedSubtract { - template - static inline bool Operation(SRCTYPE left, SRCTYPE right, SRCTYPE &result) { - UTYPE uresult = SubtractOperator::Operation(UTYPE(left), UTYPE(right)); - if (uresult < NumericLimits::Minimum() || uresult > NumericLimits::Maximum()) { - return false; - } - result = SRCTYPE(uresult); - return true; - } -}; - -template <> -bool TrySubtractOperator::Operation(uint8_t left, uint8_t right, uint8_t &result) { - if (right > left) { - return false; - } - return OverflowCheckedSubtract::Operation(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(uint16_t left, uint16_t right, uint16_t &result) { - if (right > left) { - return false; - } - return OverflowCheckedSubtract::Operation(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(uint32_t left, uint32_t right, uint32_t &result) { - if (right > left) { - return false; - } - return OverflowCheckedSubtract::Operation(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(uint64_t left, uint64_t right, uint64_t &result) { - if (right > left) { - return false; - } - return OverflowCheckedSubtract::Operation(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(int8_t left, int8_t right, int8_t &result) { - return OverflowCheckedSubtract::Operation(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(int16_t left, int16_t right, int16_t &result) { - return OverflowCheckedSubtract::Operation(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(int32_t left, int32_t right, int32_t &result) { - return OverflowCheckedSubtract::Operation(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(int64_t left, int64_t right, int64_t &result) { -#if (__GNUC__ >= 5) || defined(__clang__) - if (__builtin_sub_overflow(left, right, &result)) 
{ - return false; - } -#else - if (right < 0) { - if (NumericLimits::Maximum() + right < left) { - return false; - } - } else { - if (NumericLimits::Minimum() + right > left) { - return false; - } - } - result = left - right; -#endif - return true; -} - -template <> -bool TrySubtractOperator::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { - result = left; - return Hugeint::SubtractInPlace(result, right); -} - -//===--------------------------------------------------------------------===// -// subtract decimal with overflow check -//===--------------------------------------------------------------------===// -template -bool TryDecimalSubtractTemplated(T left, T right, T &result) { - if (right < 0) { - if (max + right < left) { - return false; - } - } else { - if (min + right > left) { - return false; - } - } - result = left - right; - return true; -} - -template <> -bool TryDecimalSubtract::Operation(int16_t left, int16_t right, int16_t &result) { - return TryDecimalSubtractTemplated(left, right, result); -} - -template <> -bool TryDecimalSubtract::Operation(int32_t left, int32_t right, int32_t &result) { - return TryDecimalSubtractTemplated(left, right, result); -} - -template <> -bool TryDecimalSubtract::Operation(int64_t left, int64_t right, int64_t &result) { - return TryDecimalSubtractTemplated(left, right, result); -} - -template <> -bool TryDecimalSubtract::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { - result = left - right; - if (result <= -Hugeint::POWERS_OF_TEN[38] || result >= Hugeint::POWERS_OF_TEN[38]) { - return false; - } - return true; -} - -template <> -hugeint_t DecimalSubtractOverflowCheck::Operation(hugeint_t left, hugeint_t right) { - hugeint_t result; - if (!TryDecimalSubtract::Operation(left, right, result)) { - throw OutOfRangeException("Overflow in subtract of DECIMAL(38) (%s - %s);", left.ToString(), right.ToString()); - } - return result; -} - 
-//===--------------------------------------------------------------------===// -// subtract time operator -//===--------------------------------------------------------------------===// -template <> -dtime_t SubtractTimeOperator::Operation(dtime_t left, interval_t right) { - right.micros = -right.micros; - return AddTimeOperator::Operation(left, right); -} - -} // namespace duckdb - - - -namespace duckdb { - -void BuiltinFunctions::RegisterOperators() { - Register(); - Register(); - Register(); - Register(); - Register(); -} - -} // namespace duckdb - - -namespace duckdb { - -void BuiltinFunctions::RegisterPragmaFunctions() { - Register(); - Register(); -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -struct NextvalBindData : public FunctionData { - explicit NextvalBindData(optional_ptr sequence) : sequence(sequence) { - } - - //! The sequence to use for the nextval computation; only if the sequence is a constant - optional_ptr sequence; - - unique_ptr Copy() const override { - return make_uniq(sequence); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return sequence == other.sequence; - } -}; - -struct CurrentSequenceValueOperator { - static int64_t Operation(DuckTransaction &transaction, SequenceCatalogEntry &seq) { - lock_guard seqlock(seq.lock); - int64_t result; - if (seq.usage_count == 0u) { - throw SequenceException("currval: sequence is not yet defined in this session"); - } - result = seq.last_value; - return result; - } -}; - -struct NextSequenceValueOperator { - static int64_t Operation(DuckTransaction &transaction, SequenceCatalogEntry &seq) { - lock_guard seqlock(seq.lock); - int64_t result; - result = seq.counter; - bool overflow = !TryAddOperator::Operation(seq.counter, seq.increment, seq.counter); - if (seq.cycle) { - if (overflow) { - seq.counter = seq.increment < 0 ? 
seq.max_value : seq.min_value; - } else if (seq.counter < seq.min_value) { - seq.counter = seq.max_value; - } else if (seq.counter > seq.max_value) { - seq.counter = seq.min_value; - } - } else { - if (result < seq.min_value || (overflow && seq.increment < 0)) { - throw SequenceException("nextval: reached minimum value of sequence \"%s\" (%lld)", seq.name, - seq.min_value); - } - if (result > seq.max_value || overflow) { - throw SequenceException("nextval: reached maximum value of sequence \"%s\" (%lld)", seq.name, - seq.max_value); - } - } - seq.last_value = result; - seq.usage_count++; - if (!seq.temporary) { - transaction.sequence_usage[&seq] = SequenceValue(seq.usage_count, seq.counter); - } - return result; - } -}; - -SequenceCatalogEntry &BindSequence(ClientContext &context, const string &name) { - auto qname = QualifiedName::Parse(name); - // fetch the sequence from the catalog - Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); - return Catalog::GetEntry(context, qname.catalog, qname.schema, qname.name); -} - -template -static void NextValFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto &input = args.data[0]; - - auto &context = state.GetContext(); - if (info.sequence) { - auto &sequence = *info.sequence; - auto &transaction = DuckTransaction::Get(context, sequence.catalog); - // sequence to use is hard coded - // increment the sequence - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < args.size(); i++) { - // get the next value from the sequence - result_data[i] = OP::Operation(transaction, sequence); - } - } else { - // sequence to use comes from the input - UnaryExecutor::Execute(input, result, args.size(), [&](string_t value) { - // fetch the sequence from the catalog - auto &sequence = BindSequence(context, value.GetString()); - // finally get the next 
value from the sequence - auto &transaction = DuckTransaction::Get(context, sequence.catalog); - return OP::Operation(transaction, sequence); - }); - } -} - -static unique_ptr NextValBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - optional_ptr sequence; - if (arguments[0]->IsFoldable()) { - // parameter to nextval function is a foldable constant - // evaluate the constant and perform the catalog lookup already - auto seqname = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - if (!seqname.IsNull()) { - sequence = &BindSequence(context, seqname.ToString()); - } - } - return make_uniq(sequence); -} - -static void NextValDependency(BoundFunctionExpression &expr, DependencyList &dependencies) { - auto &info = expr.bind_info->Cast(); - if (info.sequence) { - dependencies.AddDependency(*info.sequence); - } -} - -void NextvalFun::RegisterFunction(BuiltinFunctions &set) { - ScalarFunction next_val("nextval", {LogicalType::VARCHAR}, LogicalType::BIGINT, - NextValFunction, NextValBind, NextValDependency); - next_val.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; - set.AddFunction(next_val); -} - -void CurrvalFun::RegisterFunction(BuiltinFunctions &set) { - ScalarFunction curr_val("currval", {LogicalType::VARCHAR}, LogicalType::BIGINT, - NextValFunction, NextValBind, NextValDependency); - curr_val.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; - set.AddFunction(curr_val); -} - -} // namespace duckdb - - -namespace duckdb { - -void BuiltinFunctions::RegisterSequenceFunctions() { - Register(); - Register(); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -idx_t StrfTimepecifierSize(StrTimeSpecifier specifier) { - switch (specifier) { - case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: - case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: - return 3; - case StrTimeSpecifier::WEEKDAY_DECIMAL: - return 1; - case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - case 
StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - case StrTimeSpecifier::HOUR_24_PADDED: - case StrTimeSpecifier::HOUR_12_PADDED: - case StrTimeSpecifier::MINUTE_PADDED: - case StrTimeSpecifier::SECOND_PADDED: - case StrTimeSpecifier::AM_PM: - case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: - case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: - return 2; - case StrTimeSpecifier::NANOSECOND_PADDED: - return 9; - case StrTimeSpecifier::MICROSECOND_PADDED: - return 6; - case StrTimeSpecifier::MILLISECOND_PADDED: - return 3; - case StrTimeSpecifier::DAY_OF_YEAR_PADDED: - return 3; - default: - return 0; - } -} - -void StrTimeFormat::AddLiteral(string literal) { - constant_size += literal.size(); - literals.push_back(std::move(literal)); -} - -void StrTimeFormat::AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) { - AddLiteral(std::move(preceding_literal)); - specifiers.push_back(specifier); -} - -void StrfTimeFormat::AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) { - is_date_specifier.push_back(IsDateSpecifier(specifier)); - idx_t specifier_size = StrfTimepecifierSize(specifier); - if (specifier_size == 0) { - // variable length specifier - var_length_specifiers.push_back(specifier); - } else { - // constant size specifier - constant_size += specifier_size; - } - StrTimeFormat::AddFormatSpecifier(std::move(preceding_literal), specifier); -} - -idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date, dtime_t time, int32_t utc_offset, - const char *tz_name) { - switch (specifier) { - case StrTimeSpecifier::FULL_WEEKDAY_NAME: - return Date::DAY_NAMES[Date::ExtractISODayOfTheWeek(date) % 7].GetSize(); - case StrTimeSpecifier::FULL_MONTH_NAME: - return Date::MONTH_NAMES[Date::ExtractMonth(date) - 1].GetSize(); - case StrTimeSpecifier::YEAR_DECIMAL: { - auto year = Date::ExtractYear(date); - // Be consistent with WriteStandardSpecifier - if (0 <= year && year <= 9999) { - return 4; - } else { - 
return NumericHelper::SignedLength(year); - } - } - case StrTimeSpecifier::MONTH_DECIMAL: { - idx_t len = 1; - auto month = Date::ExtractMonth(date); - len += month >= 10; - return len; - } - case StrTimeSpecifier::UTC_OFFSET: - // ±HH or ±HH:MM - return (utc_offset % 60) ? 6 : 3; - case StrTimeSpecifier::TZ_NAME: - if (tz_name) { - return strlen(tz_name); - } - // empty for now - return 0; - case StrTimeSpecifier::HOUR_24_DECIMAL: - case StrTimeSpecifier::HOUR_12_DECIMAL: - case StrTimeSpecifier::MINUTE_DECIMAL: - case StrTimeSpecifier::SECOND_DECIMAL: { - // time specifiers - idx_t len = 1; - int32_t hour, min, sec, msec; - Time::Convert(time, hour, min, sec, msec); - switch (specifier) { - case StrTimeSpecifier::HOUR_24_DECIMAL: - len += hour >= 10; - break; - case StrTimeSpecifier::HOUR_12_DECIMAL: - hour = hour % 12; - if (hour == 0) { - hour = 12; - } - len += hour >= 10; - break; - case StrTimeSpecifier::MINUTE_DECIMAL: - len += min >= 10; - break; - case StrTimeSpecifier::SECOND_DECIMAL: - len += sec >= 10; - break; - default: - throw InternalException("Time specifier mismatch"); - } - return len; - } - case StrTimeSpecifier::DAY_OF_MONTH: - return NumericHelper::UnsignedLength(Date::ExtractDay(date)); - case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: - return NumericHelper::UnsignedLength(Date::ExtractDayOfTheYear(date)); - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - return NumericHelper::UnsignedLength(AbsValue(Date::ExtractYear(date)) % 100); - default: - throw InternalException("Unimplemented specifier for GetSpecifierLength"); - } -} - -//! 
Returns the total length of the date formatted by this format specifier -idx_t StrfTimeFormat::GetLength(date_t date, dtime_t time, int32_t utc_offset, const char *tz_name) { - idx_t size = constant_size; - if (!var_length_specifiers.empty()) { - for (auto &specifier : var_length_specifiers) { - size += GetSpecifierLength(specifier, date, time, utc_offset, tz_name); - } - } - return size; -} - -char *StrfTimeFormat::WriteString(char *target, const string_t &str) { - idx_t size = str.GetSize(); - memcpy(target, str.GetData(), size); - return target + size; -} - -// write a value in the range of 0..99 unpadded (e.g. "1", "2", ... "98", "99") -char *StrfTimeFormat::Write2(char *target, uint8_t value) { - D_ASSERT(value < 100); - if (value >= 10) { - return WritePadded2(target, value); - } else { - *target = char(uint8_t('0') + value); - return target + 1; - } -} - -// write a value in the range of 0..99 padded to 2 digits -char *StrfTimeFormat::WritePadded2(char *target, uint32_t value) { - D_ASSERT(value < 100); - auto index = static_cast(value * 2); - *target++ = duckdb_fmt::internal::data::digits[index]; - *target++ = duckdb_fmt::internal::data::digits[index + 1]; - return target; -} - -// write a value in the range of 0..999 padded -char *StrfTimeFormat::WritePadded3(char *target, uint32_t value) { - D_ASSERT(value < 1000); - if (value >= 100) { - WritePadded2(target + 1, value % 100); - *target = char(uint8_t('0') + value / 100); - return target + 3; - } else { - *target = '0'; - target++; - return WritePadded2(target, value); - } -} - -// write a value in the range of 0..999999... 
padded to the given number of digits -char *StrfTimeFormat::WritePadded(char *target, uint32_t value, size_t padding) { - D_ASSERT(padding > 1); - if (padding % 2) { - int decimals = value % 1000; - WritePadded3(target + padding - 3, decimals); - value /= 1000; - padding -= 3; - } - for (size_t i = 0; i < padding / 2; i++) { - int decimals = value % 100; - WritePadded2(target + padding - 2 * (i + 1), decimals); - value /= 100; - } - return target + padding; -} - -bool StrfTimeFormat::IsDateSpecifier(StrTimeSpecifier specifier) { - switch (specifier) { - case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: - case StrTimeSpecifier::FULL_WEEKDAY_NAME: - case StrTimeSpecifier::DAY_OF_YEAR_PADDED: - case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: - case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: - case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: - case StrTimeSpecifier::WEEKDAY_DECIMAL: - return true; - default: - return false; - } -} - -char *StrfTimeFormat::WriteDateSpecifier(StrTimeSpecifier specifier, date_t date, char *target) { - switch (specifier) { - case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: { - auto dow = Date::ExtractISODayOfTheWeek(date); - target = WriteString(target, Date::DAY_NAMES_ABBREVIATED[dow % 7]); - break; - } - case StrTimeSpecifier::FULL_WEEKDAY_NAME: { - auto dow = Date::ExtractISODayOfTheWeek(date); - target = WriteString(target, Date::DAY_NAMES[dow % 7]); - break; - } - case StrTimeSpecifier::WEEKDAY_DECIMAL: { - auto dow = Date::ExtractISODayOfTheWeek(date); - *target = char('0' + uint8_t(dow % 7)); - target++; - break; - } - case StrTimeSpecifier::DAY_OF_YEAR_PADDED: { - int32_t doy = Date::ExtractDayOfTheYear(date); - target = WritePadded3(target, doy); - break; - } - case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: - target = WritePadded2(target, Date::ExtractWeekNumberRegular(date, true)); - break; - case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: - target = WritePadded2(target, Date::ExtractWeekNumberRegular(date, 
false)); - break; - case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: { - uint32_t doy = Date::ExtractDayOfTheYear(date); - target += NumericHelper::UnsignedLength(doy); - NumericHelper::FormatUnsigned(doy, target); - break; - } - default: - throw InternalException("Unimplemented date specifier for strftime"); - } - return target; -} - -char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t data[], const char *tz_name, - size_t tz_len, char *target) { - // data contains [0] year, [1] month, [2] day, [3] hour, [4] minute, [5] second, [6] msec, [7] utc - switch (specifier) { - case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - target = WritePadded2(target, data[2]); - break; - case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: { - auto &month_name = Date::MONTH_NAMES_ABBREVIATED[data[1] - 1]; - return WriteString(target, month_name); - } - case StrTimeSpecifier::FULL_MONTH_NAME: { - auto &month_name = Date::MONTH_NAMES[data[1] - 1]; - return WriteString(target, month_name); - } - case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - target = WritePadded2(target, data[1]); - break; - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - target = WritePadded2(target, AbsValue(data[0]) % 100); - break; - case StrTimeSpecifier::YEAR_DECIMAL: - if (data[0] >= 0 && data[0] <= 9999) { - target = WritePadded(target, data[0], 4); - } else { - int32_t year = data[0]; - if (data[0] < 0) { - *target = '-'; - year = -year; - target++; - } - auto len = NumericHelper::UnsignedLength(year); - NumericHelper::FormatUnsigned(year, target + len); - target += len; - } - break; - case StrTimeSpecifier::HOUR_24_PADDED: { - target = WritePadded2(target, data[3]); - break; - } - case StrTimeSpecifier::HOUR_12_PADDED: { - int hour = data[3] % 12; - if (hour == 0) { - hour = 12; - } - target = WritePadded2(target, hour); - break; - } - case StrTimeSpecifier::AM_PM: - *target++ = data[3] >= 12 ? 
'P' : 'A'; - *target++ = 'M'; - break; - case StrTimeSpecifier::MINUTE_PADDED: { - target = WritePadded2(target, data[4]); - break; - } - case StrTimeSpecifier::SECOND_PADDED: - target = WritePadded2(target, data[5]); - break; - case StrTimeSpecifier::NANOSECOND_PADDED: - target = WritePadded(target, data[6] * Interval::NANOS_PER_MICRO, 9); - break; - case StrTimeSpecifier::MICROSECOND_PADDED: - target = WritePadded(target, data[6], 6); - break; - case StrTimeSpecifier::MILLISECOND_PADDED: - target = WritePadded3(target, data[6] / Interval::MICROS_PER_MSEC); - break; - case StrTimeSpecifier::UTC_OFFSET: { - *target++ = (data[7] < 0) ? '-' : '+'; - - auto offset = abs(data[7]); - auto offset_hours = offset / Interval::MINS_PER_HOUR; - auto offset_minutes = offset % Interval::MINS_PER_HOUR; - target = WritePadded2(target, offset_hours); - if (offset_minutes) { - *target++ = ':'; - target = WritePadded2(target, offset_minutes); - } - break; - } - case StrTimeSpecifier::TZ_NAME: - if (tz_name) { - memcpy(target, tz_name, tz_len); - target += strlen(tz_name); - } - break; - case StrTimeSpecifier::DAY_OF_MONTH: { - target = Write2(target, data[2] % 100); - break; - } - case StrTimeSpecifier::MONTH_DECIMAL: { - target = Write2(target, data[1]); - break; - } - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: { - target = Write2(target, AbsValue(data[0]) % 100); - break; - } - case StrTimeSpecifier::HOUR_24_DECIMAL: { - target = Write2(target, data[3]); - break; - } - case StrTimeSpecifier::HOUR_12_DECIMAL: { - int hour = data[3] % 12; - if (hour == 0) { - hour = 12; - } - target = Write2(target, hour); - break; - } - case StrTimeSpecifier::MINUTE_DECIMAL: { - target = Write2(target, data[4]); - break; - } - case StrTimeSpecifier::SECOND_DECIMAL: { - target = Write2(target, data[5]); - break; - } - default: - throw InternalException("Unimplemented specifier for WriteStandardSpecifier in strftime"); - } - return target; -} - -void StrfTimeFormat::FormatString(date_t date, 
int32_t data[8], const char *tz_name, char *target) { - D_ASSERT(specifiers.size() + 1 == literals.size()); - idx_t i; - for (i = 0; i < specifiers.size(); i++) { - // first copy the current literal - memcpy(target, literals[i].c_str(), literals[i].size()); - target += literals[i].size(); - // now copy the specifier - if (is_date_specifier[i]) { - target = WriteDateSpecifier(specifiers[i], date, target); - } else { - auto tz_len = tz_name ? strlen(tz_name) : 0; - target = WriteStandardSpecifier(specifiers[i], data, tz_name, tz_len, target); - } - } - // copy the final literal into the target - memcpy(target, literals[i].c_str(), literals[i].size()); -} - -void StrfTimeFormat::FormatString(date_t date, dtime_t time, char *target) { - int32_t data[8]; // year, month, day, hour, min, sec, µs, offset - Date::Convert(date, data[0], data[1], data[2]); - Time::Convert(time, data[3], data[4], data[5], data[6]); - data[7] = 0; - - FormatString(date, data, nullptr, target); -} - -string StrfTimeFormat::Format(timestamp_t timestamp, const string &format_str) { - StrfTimeFormat format; - format.ParseFormatSpecifier(format_str, format); - - auto date = Timestamp::GetDate(timestamp); - auto time = Timestamp::GetTime(timestamp); - - auto len = format.GetLength(date, time, 0, nullptr); - auto result = make_unsafe_uniq_array(len); - format.FormatString(date, time, result.get()); - return string(result.get(), len); -} - -string StrTimeFormat::ParseFormatSpecifier(const string &format_string, StrTimeFormat &format) { - if (format_string.empty()) { - return "Empty format string"; - } - format.format_specifier = format_string; - format.specifiers.clear(); - format.literals.clear(); - format.numeric_width.clear(); - format.constant_size = 0; - idx_t pos = 0; - string current_literal; - for (idx_t i = 0; i < format_string.size(); i++) { - if (format_string[i] == '%') { - if (i + 1 == format_string.size()) { - return "Trailing format character %"; - } - if (i > pos) { - // push the 
previous string to the current literal - current_literal += format_string.substr(pos, i - pos); - } - char format_char = format_string[++i]; - if (format_char == '%') { - // special case: %% - // set the pos for the next literal and continue - pos = i; - continue; - } - StrTimeSpecifier specifier; - if (format_char == '-' && i + 1 < format_string.size()) { - format_char = format_string[++i]; - switch (format_char) { - case 'd': - specifier = StrTimeSpecifier::DAY_OF_MONTH; - break; - case 'm': - specifier = StrTimeSpecifier::MONTH_DECIMAL; - break; - case 'y': - specifier = StrTimeSpecifier::YEAR_WITHOUT_CENTURY; - break; - case 'H': - specifier = StrTimeSpecifier::HOUR_24_DECIMAL; - break; - case 'I': - specifier = StrTimeSpecifier::HOUR_12_DECIMAL; - break; - case 'M': - specifier = StrTimeSpecifier::MINUTE_DECIMAL; - break; - case 'S': - specifier = StrTimeSpecifier::SECOND_DECIMAL; - break; - case 'j': - specifier = StrTimeSpecifier::DAY_OF_YEAR_DECIMAL; - break; - default: - return "Unrecognized format for strftime/strptime: %-" + string(1, format_char); - } - } else { - switch (format_char) { - case 'a': - specifier = StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME; - break; - case 'A': - specifier = StrTimeSpecifier::FULL_WEEKDAY_NAME; - break; - case 'w': - specifier = StrTimeSpecifier::WEEKDAY_DECIMAL; - break; - case 'd': - specifier = StrTimeSpecifier::DAY_OF_MONTH_PADDED; - break; - case 'h': - case 'b': - specifier = StrTimeSpecifier::ABBREVIATED_MONTH_NAME; - break; - case 'B': - specifier = StrTimeSpecifier::FULL_MONTH_NAME; - break; - case 'm': - specifier = StrTimeSpecifier::MONTH_DECIMAL_PADDED; - break; - case 'y': - specifier = StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED; - break; - case 'Y': - specifier = StrTimeSpecifier::YEAR_DECIMAL; - break; - case 'H': - specifier = StrTimeSpecifier::HOUR_24_PADDED; - break; - case 'I': - specifier = StrTimeSpecifier::HOUR_12_PADDED; - break; - case 'p': - specifier = StrTimeSpecifier::AM_PM; - break; - case 
'M': - specifier = StrTimeSpecifier::MINUTE_PADDED; - break; - case 'S': - specifier = StrTimeSpecifier::SECOND_PADDED; - break; - case 'n': - specifier = StrTimeSpecifier::NANOSECOND_PADDED; - break; - case 'f': - specifier = StrTimeSpecifier::MICROSECOND_PADDED; - break; - case 'g': - specifier = StrTimeSpecifier::MILLISECOND_PADDED; - break; - case 'z': - specifier = StrTimeSpecifier::UTC_OFFSET; - break; - case 'Z': - specifier = StrTimeSpecifier::TZ_NAME; - break; - case 'j': - specifier = StrTimeSpecifier::DAY_OF_YEAR_PADDED; - break; - case 'U': - specifier = StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST; - break; - case 'W': - specifier = StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST; - break; - case 'c': - case 'x': - case 'X': - case 'T': { - string subformat; - if (format_char == 'c') { - // %c: Locale’s appropriate date and time representation. - // we push the ISO timestamp representation here - subformat = "%Y-%m-%d %H:%M:%S"; - } else if (format_char == 'x') { - // %x - Locale’s appropriate date representation. - // we push the ISO date format here - subformat = "%Y-%m-%d"; - } else if (format_char == 'X' || format_char == 'T') { - // %X - Locale’s appropriate time representation. 
// (continuation of StrTimeFormat::ParseFormatSpecifier — the function head precedes this chunk)
				// we push the ISO time format here
				subformat = "%H:%M:%S";
			}
			// parse the subformat in a separate format specifier
			StrfTimeFormat locale_format;
			string error = StrTimeFormat::ParseFormatSpecifier(subformat, locale_format);
			D_ASSERT(error.empty());
			// add the previous literal to the first literal of the subformat
			locale_format.literals[0] = std::move(current_literal) + locale_format.literals[0];
			current_literal = "";
			// now push the subformat into the current format specifier
			for (idx_t i = 0; i < locale_format.specifiers.size(); i++) {
				format.AddFormatSpecifier(std::move(locale_format.literals[i]), locale_format.specifiers[i]);
			}
			pos = i + 1;
			continue;
		}
		default:
			return "Unrecognized format for strftime/strptime: %" + string(1, format_char);
		}
	}
	// register the specifier together with the literal text that preceded it
	format.AddFormatSpecifier(std::move(current_literal), specifier);
	current_literal = "";
	pos = i + 1;
	}
}
	// add the final literal
	if (pos < format_string.size()) {
		current_literal += format_string.substr(pos, format_string.size() - pos);
	}
	format.AddLiteral(std::move(current_literal));
	// empty string signals success; a non-empty return is the error message
	return string();
}

//! Formats a DATE vector to VARCHAR using this strftime format.
//! Infinite dates cannot be rendered and produce NULL in the result.
void StrfTimeFormat::ConvertDateVector(Vector &input, Vector &result, idx_t count) {
	D_ASSERT(input.GetType().id() == LogicalTypeId::DATE);
	D_ASSERT(result.GetType().id() == LogicalTypeId::VARCHAR);
	// NOTE(review): template arguments were lost in extraction; restored as <date_t, string_t> — confirm
	UnaryExecutor::ExecuteWithNulls<date_t, string_t>(input, result, count,
	                                                  [&](date_t input, ValidityMask &mask, idx_t idx) {
		if (Date::IsFinite(input)) {
			// dates have no time component: format with midnight
			dtime_t time(0);
			idx_t len = GetLength(input, time, 0, nullptr);
			string_t target = StringVector::EmptyString(result, len);
			FormatString(input, time, target.GetDataWriteable());
			target.Finalize();
			return target;
		} else {
			// infinite date: no string representation, emit NULL
			mask.SetInvalid(idx);
			return string_t();
		}
	});
}

//! Formats a TIMESTAMP (or TIMESTAMP_TZ) vector to VARCHAR using this strftime format.
//! Infinite timestamps cannot be rendered and produce NULL in the result.
void StrfTimeFormat::ConvertTimestampVector(Vector &input, Vector &result, idx_t count) {
	D_ASSERT(input.GetType().id() == LogicalTypeId::TIMESTAMP || input.GetType().id() == LogicalTypeId::TIMESTAMP_TZ);
	D_ASSERT(result.GetType().id() == LogicalTypeId::VARCHAR);
	// NOTE(review): template arguments were lost in extraction; restored as <timestamp_t, string_t> — confirm
	UnaryExecutor::ExecuteWithNulls<timestamp_t, string_t>(
	    input, result, count, [&](timestamp_t input, ValidityMask &mask, idx_t idx) {
		    if (Timestamp::IsFinite(input)) {
			    // split into date and time components, then render both
			    date_t date;
			    dtime_t time;
			    Timestamp::Convert(input, date, time);
			    idx_t len = GetLength(date, time, 0, nullptr);
			    string_t target = StringVector::EmptyString(result, len);
			    FormatString(date, time, target.GetDataWriteable());
			    target.Finalize();
			    return target;
		    } else {
			    // infinite timestamp: no string representation, emit NULL
			    mask.SetInvalid(idx);
			    return string_t();
		    }
	    });
}

//! Records the specifier and, in parallel, the maximum digit width strptime
//! should consume for it (-1 for non-numeric specifiers).
void StrpTimeFormat::AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) {
	numeric_width.push_back(NumericSpecifierWidth(specifier));
	StrTimeFormat::AddFormatSpecifier(std::move(preceding_literal), specifier);
}

//! Returns the maximum number of digits a numeric specifier may occupy in the
//! input (e.g. 2 for %d, 4 for %Y, 6 for %f), or -1 if the specifier is not numeric.
int StrpTimeFormat::NumericSpecifierWidth(StrTimeSpecifier specifier) {
	switch (specifier) {
	case StrTimeSpecifier::WEEKDAY_DECIMAL:
		return 1;
	case StrTimeSpecifier::DAY_OF_MONTH_PADDED:
	case StrTimeSpecifier::DAY_OF_MONTH:
	case StrTimeSpecifier::MONTH_DECIMAL_PADDED:
	case StrTimeSpecifier::MONTH_DECIMAL:
	case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED:
	case StrTimeSpecifier::YEAR_WITHOUT_CENTURY:
	case StrTimeSpecifier::HOUR_24_PADDED:
	case StrTimeSpecifier::HOUR_24_DECIMAL:
	case StrTimeSpecifier::HOUR_12_PADDED:
	case StrTimeSpecifier::HOUR_12_DECIMAL:
	case StrTimeSpecifier::MINUTE_PADDED:
	case StrTimeSpecifier::MINUTE_DECIMAL:
	case StrTimeSpecifier::SECOND_PADDED:
	case StrTimeSpecifier::SECOND_DECIMAL:
	case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST:
	case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST:
		return 2;
	case StrTimeSpecifier::MILLISECOND_PADDED:
	case StrTimeSpecifier::DAY_OF_YEAR_PADDED:
	case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL:
		return 3;
	case StrTimeSpecifier::YEAR_DECIMAL:
		return 4;
	case StrTimeSpecifier::MICROSECOND_PADDED:
		return 6;
	case StrTimeSpecifier::NANOSECOND_PADDED:
		return 9;
	default:
		// non-numeric specifier (names, AM/PM, time zone, ...)
		return -1;
	}
}

//! Tracks whether an AM/PM marker was parsed, so the hour can be adjusted at the end.
enum class TimeSpecifierAMOrPM : uint8_t { TIME_SPECIFIER_NONE = 0, TIME_SPECIFIER_AM = 1, TIME_SPECIFIER_PM = 2 };

//! Case-insensitively matches data[pos..] against each entry of `collection`
//! (e.g. month or weekday names). On a match, advances `pos` past the entry and
//! returns its index; returns -1 if nothing matches.
int32_t StrpTimeFormat::TryParseCollection(const char *data, idx_t &pos, idx_t size, const string_t collection[],
                                           idx_t collection_count) const {
	for (idx_t c = 0; c < collection_count; c++) {
		auto &entry = collection[c];
		auto entry_data = entry.GetData();
		auto entry_size = entry.GetSize();
		// check if this entry matches
		if (pos + entry_size > size) {
			// too big: can't match
			continue;
		}
		// compare the characters
		idx_t i;
		for (i = 0; i < entry_size; i++) {
			if (std::tolower(entry_data[i]) != std::tolower(data[pos + i])) {
				break;
			}
		}
		if (i == entry_size) {
			// full match
			pos += entry_size;
			return c;
		}
	}
	return -1;
}

//! Parses a timestamp using the given specifier
//! result.data layout: [0]=year [1]=month [2]=day [3]=hour [4]=minute
//! [5]=second [6]=microseconds [7]=UTC offset in minutes (see ToDate/ToTimestamp below)
bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const {
	auto &result_data = result.data;
	auto &error_message = result.error_message;
	auto &error_position = result.error_position;

	// initialize the result (strptime defaults: 1900-01-01 00:00:00)
	result_data[0] = 1900;
	result_data[1] = 1;
	result_data[2] = 1;
	result_data[3] = 0;
	result_data[4] = 0;
	result_data[5] = 0;
	result_data[6] = 0;
	result_data[7] = 0;

	auto data = str.GetData();
	idx_t size = str.GetSize();
	// skip leading spaces
	while (StringUtil::CharacterIsSpace(*data)) {
		data++;
		size--;
	}
	idx_t pos = 0;
	TimeSpecifierAMOrPM ampm = TimeSpecifierAMOrPM::TIME_SPECIFIER_NONE;

	// Year offset state (Year+W/j): which specifier, if any, determines the date
	// relative to the year; resolved after the main loop.
	auto offset_specifier = StrTimeSpecifier::WEEKDAY_DECIMAL;
	uint64_t weekno = 0;
	uint64_t weekday = 0;
	uint64_t yearday = 0;

	for (idx_t i = 0;; i++) {
		D_ASSERT(i < literals.size());
		// first compare the literal
		const auto &literal = literals[i];
		for (size_t l = 0; l < literal.size();) {
			// Match runs of spaces to runs of spaces.
			if (StringUtil::CharacterIsSpace(literal[l])) {
				if (!StringUtil::CharacterIsSpace(data[pos])) {
					error_message = "Space does not match, expected " + literals[i];
					error_position = pos;
					return false;
				}
				for (++pos; pos < size && StringUtil::CharacterIsSpace(data[pos]); ++pos) {
					continue;
				}
				for (++l; l < literal.size() && StringUtil::CharacterIsSpace(literal[l]); ++l) {
					continue;
				}
				continue;
			}
			// literal does not match
			if (data[pos++] != literal[l++]) {
				error_message = "Literal does not match, expected " + literal;
				error_position = pos;
				return false;
			}
		}
		if (i == specifiers.size()) {
			// there is one more literal than there are specifiers: done
			break;
		}
		// now parse the specifier
		if (numeric_width[i] > 0) {
			// numeric specifier: parse a number, consuming at most numeric_width[i] digits
			uint64_t number = 0;
			size_t start_pos = pos;
			size_t end_pos = start_pos + numeric_width[i];
			while (pos < size && pos < end_pos && StringUtil::CharacterIsDigit(data[pos])) {
				number = number * 10 + data[pos] - '0';
				pos++;
			}
			if (pos == start_pos) {
				// expected a number here
				error_message = "Expected a number";
				error_position = start_pos;
				return false;
			}
			switch (specifiers[i]) {
			case StrTimeSpecifier::DAY_OF_MONTH_PADDED:
			case StrTimeSpecifier::DAY_OF_MONTH:
				if (number < 1 || number > 31) {
					error_message = "Day out of range, expected a value between 1 and 31";
					error_position = start_pos;
					return false;
				}
				// day of the month
				result_data[2] = number;
				offset_specifier = specifiers[i];
				break;
			case StrTimeSpecifier::MONTH_DECIMAL_PADDED:
			case StrTimeSpecifier::MONTH_DECIMAL:
				if (number < 1 || number > 12) {
					error_message = "Month out of range, expected a value between 1 and 12";
					error_position = start_pos;
					return false;
				}
				// month number
				result_data[1] = number;
				offset_specifier = specifiers[i];
				break;
			case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED:
			case StrTimeSpecifier::YEAR_WITHOUT_CENTURY:
				// year without century..
				// Python uses 69 as a crossover point (i.e. >= 69 is 19.., < 69 is 20..)
				if (number >= 100) {
					// %y only supports numbers between [0..99]
					error_message = "Year without century out of range, expected a value between 0 and 99";
					error_position = start_pos;
					return false;
				}
				if (number >= 69) {
					result_data[0] = int32_t(1900 + number);
				} else {
					result_data[0] = int32_t(2000 + number);
				}
				break;
			case StrTimeSpecifier::YEAR_DECIMAL:
				// year as full number
				result_data[0] = number;
				break;
			case StrTimeSpecifier::HOUR_24_PADDED:
			case StrTimeSpecifier::HOUR_24_DECIMAL:
				if (number >= 24) {
					error_message = "Hour out of range, expected a value between 0 and 23";
					error_position = start_pos;
					return false;
				}
				// hour as full number
				result_data[3] = number;
				break;
			case StrTimeSpecifier::HOUR_12_PADDED:
			case StrTimeSpecifier::HOUR_12_DECIMAL:
				if (number < 1 || number > 12) {
					error_message = "Hour12 out of range, expected a value between 1 and 12";
					error_position = start_pos;
					return false;
				}
				// 12-hour number: start off by just storing the number (AM/PM fixup happens at the end)
				result_data[3] = number;
				break;
			case StrTimeSpecifier::MINUTE_PADDED:
			case StrTimeSpecifier::MINUTE_DECIMAL:
				if (number >= 60) {
					error_message = "Minutes out of range, expected a value between 0 and 59";
					error_position = start_pos;
					return false;
				}
				// minutes
				result_data[4] = number;
				break;
			case StrTimeSpecifier::SECOND_PADDED:
			case StrTimeSpecifier::SECOND_DECIMAL:
				if (number >= 60) {
					error_message = "Seconds out of range, expected a value between 0 and 59";
					error_position = start_pos;
					return false;
				}
				// seconds
				result_data[5] = number;
				break;
			case StrTimeSpecifier::NANOSECOND_PADDED:
				D_ASSERT(number < Interval::NANOS_PER_SEC); // enforced by the length of the number
				// microseconds (rounded)
				result_data[6] = (number + Interval::NANOS_PER_MICRO / 2) / Interval::NANOS_PER_MICRO;
				break;
			case StrTimeSpecifier::MICROSECOND_PADDED:
				D_ASSERT(number < Interval::MICROS_PER_SEC); // enforced by the length of the number
				// microseconds
				result_data[6] = number;
				break;
			case StrTimeSpecifier::MILLISECOND_PADDED:
				D_ASSERT(number < Interval::MSECS_PER_SEC); // enforced by the length of the number
				// microseconds
				result_data[6] = number * Interval::MICROS_PER_MSEC;
				break;
			case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST:
			case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST:
				// m/d overrides WU/w but does not conflict
				switch (offset_specifier) {
				case StrTimeSpecifier::DAY_OF_MONTH_PADDED:
				case StrTimeSpecifier::DAY_OF_MONTH:
				case StrTimeSpecifier::MONTH_DECIMAL_PADDED:
				case StrTimeSpecifier::MONTH_DECIMAL:
					// Just validate, don't use
					break;
				case StrTimeSpecifier::WEEKDAY_DECIMAL:
					// First offset specifier
					offset_specifier = specifiers[i];
					break;
				default:
					error_message = "Multiple year offsets specified";
					error_position = start_pos;
					return false;
				}
				if (number > 53) {
					error_message = "Week out of range, expected a value between 0 and 53";
					error_position = start_pos;
					return false;
				}
				weekno = number;
				break;
			case StrTimeSpecifier::WEEKDAY_DECIMAL:
				if (number > 6) {
					error_message = "Weekday out of range, expected a value between 0 and 6";
					error_position = start_pos;
					return false;
				}
				weekday = number;
				break;
			case StrTimeSpecifier::DAY_OF_YEAR_PADDED:
			case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL:
				// m/d overrides j but does not conflict
				switch (offset_specifier) {
				case StrTimeSpecifier::DAY_OF_MONTH_PADDED:
				case StrTimeSpecifier::DAY_OF_MONTH:
				case StrTimeSpecifier::MONTH_DECIMAL_PADDED:
				case StrTimeSpecifier::MONTH_DECIMAL:
					// Just validate, don't use
					break;
				case StrTimeSpecifier::WEEKDAY_DECIMAL:
					// First offset specifier
					offset_specifier = specifiers[i];
					break;
				default:
					error_message = "Multiple year offsets specified";
					error_position = start_pos;
					return false;
				}
				if (number < 1 || number > 366) {
					error_message = "Year day out of range, expected a value between 1 and 366";
					error_position = start_pos;
					return false;
				}
				yearday = number;
				break;
			default:
				throw NotImplementedException("Unsupported specifier for strptime");
			}
		} else {
			switch (specifiers[i]) {
			case StrTimeSpecifier::AM_PM: {
				// parse the next 2 characters
				if (pos + 2 > size) {
					// no characters left to parse
					error_message = "Expected AM/PM";
					error_position = pos;
					return false;
				}
				char pa_char = char(std::tolower(data[pos]));
				char m_char = char(std::tolower(data[pos + 1]));
				if (m_char != 'm') {
					error_message = "Expected AM/PM";
					error_position = pos;
					return false;
				}
				if (pa_char == 'p') {
					ampm = TimeSpecifierAMOrPM::TIME_SPECIFIER_PM;
				} else if (pa_char == 'a') {
					ampm = TimeSpecifierAMOrPM::TIME_SPECIFIER_AM;
				} else {
					error_message = "Expected AM/PM";
					error_position = pos;
					return false;
				}
				pos += 2;
				break;
			}
			// we parse weekday names, but we don't use them as information
			case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME:
				if (TryParseCollection(data, pos, size, Date::DAY_NAMES_ABBREVIATED, 7) < 0) {
					error_message = "Expected an abbreviated day name (Mon, Tue, Wed, Thu, Fri, Sat, Sun)";
					error_position = pos;
					return false;
				}
				break;
			case StrTimeSpecifier::FULL_WEEKDAY_NAME:
				if (TryParseCollection(data, pos, size, Date::DAY_NAMES, 7) < 0) {
					error_message = "Expected a full day name (Monday, Tuesday, etc...)";
					error_position = pos;
					return false;
				}
				break;
			case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: {
				int32_t month = TryParseCollection(data, pos, size, Date::MONTH_NAMES_ABBREVIATED, 12);
				if (month < 0) {
					error_message = "Expected an abbreviated month name (Jan, Feb, Mar, etc..)";
					error_position = pos;
					return false;
				}
				result_data[1] = month + 1;
				break;
			}
			case StrTimeSpecifier::FULL_MONTH_NAME: {
				int32_t month = TryParseCollection(data, pos, size, Date::MONTH_NAMES, 12);
				if (month < 0) {
					error_message = "Expected a full month name (January, February, etc...)";
					error_position = pos;
					return false;
				}
				result_data[1] = month + 1;
				break;
			}
			case StrTimeSpecifier::UTC_OFFSET: {
				int hour_offset, minute_offset;
				if (!Timestamp::TryParseUTCOffset(data, pos, size, hour_offset, minute_offset)) {
					error_message = "Expected +HH[MM] or -HH[MM]";
					error_position = pos;
					return false;
				}
				// store the offset as a single number of minutes
				result_data[7] = hour_offset * Interval::MINS_PER_HOUR + minute_offset;
				break;
			}
			case StrTimeSpecifier::TZ_NAME: {
				// skip leading spaces
				while (pos < size && StringUtil::CharacterIsSpace(data[pos])) {
					pos++;
				}
				const auto tz_begin = data + pos;
				// stop when we encounter a non-tz character
				while (pos < size && Timestamp::CharacterIsTimeZone(data[pos])) {
					pos++;
				}
				const auto tz_end = data + pos;
				// Can't fully validate without a list - caller's responsibility.
				// But tz must not be empty.
				if (tz_end == tz_begin) {
					error_message = "Empty Time Zone name";
					error_position = tz_begin - data;
					return false;
				}
				result.tz.assign(tz_begin, tz_end);
				break;
			}
			default:
				throw NotImplementedException("Unsupported specifier for strptime");
			}
		}
	}
	// skip trailing spaces
	while (pos < size && StringUtil::CharacterIsSpace(data[pos])) {
		pos++;
	}
	if (pos != size) {
		error_message = "Full specifier did not match: trailing characters";
		error_position = pos;
		return false;
	}
	if (ampm != TimeSpecifierAMOrPM::TIME_SPECIFIER_NONE) {
		if (result_data[3] > 12) {
			error_message =
			    "Invalid hour: " + to_string(result_data[3]) + " AM/PM, expected an hour within the range [0..12]";
			return false;
		}
		// adjust the hours based on the AM or PM specifier
		if (ampm == TimeSpecifierAMOrPM::TIME_SPECIFIER_AM) {
			// AM: 12AM=0, 1AM=1, 2AM=2, ..., 11AM=11
			if (result_data[3] == 12) {
				result_data[3] = 0;
			}
		} else {
			// PM: 12PM=12, 1PM=13, 2PM=14, ..., 11PM=23
			if (result_data[3] != 12) {
				result_data[3] += 12;
			}
		}
	}
	// resolve the year-relative offsets (%U/%W with %w, or %j) into year/month/day
	switch (offset_specifier) {
	case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST:
	case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: {
		// Adjust weekday to be 0-based for the week type
		weekday = (weekday + 7 - int(offset_specifier == StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST)) % 7;
		// Get the start of week 1, move back 7 days and then weekno * 7 + weekday gives the date
		const auto jan1 = Date::FromDate(result_data[0], 1, 1);
		auto yeardate = Date::GetMondayOfCurrentWeek(jan1);
		yeardate -= int(offset_specifier == StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST);
		// Is there a week 0?
		yeardate -= 7 * int(yeardate >= jan1);
		yeardate += weekno * 7 + weekday;
		Date::Convert(yeardate, result_data[0], result_data[1], result_data[2]);
		break;
	}
	case StrTimeSpecifier::DAY_OF_YEAR_PADDED:
	case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: {
		auto yeardate = Date::FromDate(result_data[0], 1, 1);
		yeardate += yearday - 1;
		Date::Convert(yeardate, result_data[0], result_data[1], result_data[2]);
		break;
	}
	case StrTimeSpecifier::DAY_OF_MONTH_PADDED:
	case StrTimeSpecifier::DAY_OF_MONTH:
	case StrTimeSpecifier::MONTH_DECIMAL_PADDED:
	case StrTimeSpecifier::MONTH_DECIMAL:
		// m/d overrides UWw/j
		break;
	default:
		D_ASSERT(offset_specifier == StrTimeSpecifier::WEEKDAY_DECIMAL);
		break;
	}

	return true;
}

//! Convenience entry point: compiles `format_string` and parses `text` with it,
//! throwing InvalidInputException on either failure.
StrpTimeFormat::ParseResult StrpTimeFormat::Parse(const string &format_string, const string &text) {
	StrpTimeFormat format;
	format.format_specifier = format_string;
	string error = StrTimeFormat::ParseFormatSpecifier(format_string, format);
	if (!error.empty()) {
		throw InvalidInputException("Failed to parse format specifier %s: %s", format_string, error);
	}
	StrpTimeFormat::ParseResult result;
	if (!format.Parse(text, result)) {
		throw InvalidInputException("Failed to parse string \"%s\" with format specifier \"%s\"", text, format_string);
	}
	return result;
}

// (return type of StrpTimeFormat::FormatStrpTimeError, whose body continues below)
string
// (continuation: the `string` return type is on the previous line)
//! Renders the offending input with a caret under the error position, or an
//! empty string if the position is unknown (INVALID_INDEX).
StrpTimeFormat::FormatStrpTimeError(const string &input, idx_t position) {
	if (position == DConstants::INVALID_INDEX) {
		return string();
	}
	return input + "\n" + string(position, ' ') + "^";
}

//! Converts the parsed fields [year, month, day] to a date_t (throws on invalid dates).
date_t StrpTimeFormat::ParseResult::ToDate() {
	return Date::FromDate(data[0], data[1], data[2]);
}

//! Non-throwing variant of ToDate.
bool StrpTimeFormat::ParseResult::TryToDate(date_t &result) {
	return Date::TryFromDate(data[0], data[1], data[2], result);
}

//! Converts the parsed fields to a timestamp, applying the UTC offset stored in
//! data[7] (minutes) by subtracting it from the parsed wall-clock time.
timestamp_t StrpTimeFormat::ParseResult::ToTimestamp() {
	date_t date = Date::FromDate(data[0], data[1], data[2]);
	const auto hour_offset = data[7] / Interval::MINS_PER_HOUR;
	const auto mins_offset = data[7] % Interval::MINS_PER_HOUR;
	dtime_t time = Time::FromTime(data[3] - hour_offset, data[4] - mins_offset, data[5], data[6]);
	return Timestamp::FromDatetime(date, time);
}

//! Non-throwing variant of ToTimestamp.
bool StrpTimeFormat::ParseResult::TryToTimestamp(timestamp_t &result) {
	date_t date;
	if (!TryToDate(date)) {
		return false;
	}
	const auto hour_offset = data[7] / Interval::MINS_PER_HOUR;
	const auto mins_offset = data[7] % Interval::MINS_PER_HOUR;
	dtime_t time = Time::FromTime(data[3] - hour_offset, data[4] - mins_offset, data[5], data[6]);
	return Timestamp::TryFromDatetime(date, time, result);
}

//! Builds the full user-facing error message for a failed strptime parse.
string StrpTimeFormat::ParseResult::FormatError(string_t input, const string &format_specifier) {
	return StringUtil::Format("Could not parse string \"%s\" according to format specifier \"%s\"\n%s\nError: %s",
	                          input.GetString(), format_specifier,
	                          FormatStrpTimeError(input.GetString(), error_position), error_message);
}

//! Parses `input` as a DATE; on failure fills `error_message` and returns false.
bool StrpTimeFormat::TryParseDate(string_t input, date_t &result, string &error_message) const {
	ParseResult parse_result;
	if (!Parse(input, parse_result)) {
		error_message = parse_result.FormatError(input, format_specifier);
		return false;
	}
	return parse_result.TryToDate(result);
}

//! Parses `input` as a TIMESTAMP; on failure fills `error_message` and returns false.
bool StrpTimeFormat::TryParseTimestamp(string_t input, timestamp_t &result, string &error_message) const {
	ParseResult parse_result;
	if (!Parse(input, parse_result)) {
		error_message = parse_result.FormatError(input, format_specifier);
		return false;
	}
	return parse_result.TryToTimestamp(result);
}

//! Throwing variant of TryParseDate.
date_t StrpTimeFormat::ParseDate(string_t input) {
	ParseResult result;
	if (!Parse(input, result)) {
		throw InvalidInputException(result.FormatError(input, format_specifier));
	}
	return result.ToDate();
}

//! Throwing variant of TryParseTimestamp.
timestamp_t StrpTimeFormat::ParseTimestamp(string_t input) {
	ParseResult result;
	if (!Parse(input, result)) {
		throw InvalidInputException(result.FormatError(input, format_specifier));
	}
	return result.ToTimestamp();
}

} // namespace duckdb

// NOTE(review): the include target was lost in extraction; this section uses
// utf8proc_* functions below, so it is presumably utf8proc — confirm against upstream.
#include "utf8proc.hpp"

namespace duckdb {

// Byte-level ASCII uppercase map: identity everywhere except 'a'..'z' (97..122),
// which map to 'A'..'Z' (65..90). Bytes >= 128 are passed through unchanged.
uint8_t UpperFun::ascii_to_upper_map[] = {
    0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,
    22,  23,  24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,
    44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
    66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,
    88,  89,  90,  91,  92,  93,  94,  95,  96,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
    78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  123, 124, 125, 126, 127, 128, 129, 130, 131,
    132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
    154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
    176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
    198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
    220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241,
    242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254};
// Byte-level ASCII lowercase map: identity everywhere except 'A'..'Z' (65..90),
// which map to 'a'..'z' (97..122). Bytes >= 128 are passed through unchanged.
uint8_t LowerFun::ascii_to_lower_map[] = {
    0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,
    22,  23,  24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,
    44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  97,
    98,  99,  100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
    120, 121, 122, 91,  92,  93,  94,  95,  96,  97,  98,  99,  100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
    110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
    132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
    154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
    176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
    198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
    220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241,
    242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254};

// Byte-wise case conversion via the lookup tables above; output length equals
// input length (ASCII only — used when statistics prove the input has no unicode).
// NOTE(review): the <bool IS_UPPER> template header was lost in extraction; restored — confirm.
template <bool IS_UPPER>
static string_t ASCIICaseConvert(Vector &result, const char *input_data, idx_t input_length) {
	idx_t output_length = input_length;
	auto result_str = StringVector::EmptyString(result, output_length);
	auto result_data = result_str.GetDataWriteable();
	for (idx_t i = 0; i < input_length; i++) {
		result_data[i] = IS_UPPER ? UpperFun::ascii_to_upper_map[uint8_t(input_data[i])]
		                          : LowerFun::ascii_to_lower_map[uint8_t(input_data[i])];
	}
	result_str.Finalize();
	return result_str;
}

// Computes the byte length of the case-converted UTF-8 string; the converted
// codepoint may have a different encoded length than the original.
template <bool IS_UPPER>
static idx_t GetResultLength(const char *input_data, idx_t input_length) {
	idx_t output_length = 0;
	for (idx_t i = 0; i < input_length;) {
		if (input_data[i] & 0x80) {
			// unicode
			int sz = 0;
			int codepoint = utf8proc_codepoint(input_data + i, sz);
			int converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint);
			int new_sz = utf8proc_codepoint_length(converted_codepoint);
			D_ASSERT(new_sz >= 0);
			output_length += new_sz;
			i += sz;
		} else {
			// ascii
			output_length++;
			i++;
		}
	}
	return output_length;
}

// Writes the case-converted UTF-8 bytes into result_data, which must have been
// sized with GetResultLength for the same input.
template <bool IS_UPPER>
static void CaseConvert(const char *input_data, idx_t input_length, char *result_data) {
	for (idx_t i = 0; i < input_length;) {
		if (input_data[i] & 0x80) {
			// non-ascii character
			int sz = 0, new_sz = 0;
			int codepoint = utf8proc_codepoint(input_data + i, sz);
			int converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint);
			auto success = utf8proc_codepoint_to_utf8(converted_codepoint, new_sz, result_data);
			D_ASSERT(success);
			(void)success;
			result_data += new_sz;
			i += sz;
		} else {
			// ascii
			*result_data = IS_UPPER ? UpperFun::ascii_to_upper_map[uint8_t(input_data[i])]
			                        : LowerFun::ascii_to_lower_map[uint8_t(input_data[i])];
			result_data++;
			i++;
		}
	}
}

//! Public helper: byte length of the lowercased version of the input.
// NOTE(review): explicit <false> template arguments lost in extraction; restored — confirm.
idx_t LowerFun::LowerLength(const char *input_data, idx_t input_length) {
	return GetResultLength<false>(input_data, input_length);
}

//! Public helper: lowercases input into result_data (sized via LowerLength).
void LowerFun::LowerCase(const char *input_data, idx_t input_length, char *result_data) {
	CaseConvert<false>(input_data, input_length, result_data);
}

// Full unicode-aware case conversion: measure, allocate, convert.
template <bool IS_UPPER>
static string_t UnicodeCaseConvert(Vector &result, const char *input_data, idx_t input_length) {
	// first figure out the output length
	idx_t output_length = GetResultLength<IS_UPPER>(input_data, input_length);
	auto result_str = StringVector::EmptyString(result, output_length);
	auto result_data = result_str.GetDataWriteable();

	CaseConvert<IS_UPPER>(input_data, input_length, result_data);
	result_str.Finalize();
	return result_str;
}

// Operator adapter for UnaryExecutor::ExecuteString (unicode path).
template <bool IS_UPPER>
struct CaseConvertOperator {
	template <class INPUT_TYPE, class RESULT_TYPE>
	static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) {
		auto input_data = input.GetData();
		auto input_length = input.GetSize();
		return UnicodeCaseConvert<IS_UPPER>(result, input_data, input_length);
	}
};
// Scalar function body for upper/lower (unicode-aware path).
// NOTE(review): template headers and explicit template arguments in this section
// were lost in extraction; restored from the surrounding usage — confirm upstream.
template <bool IS_UPPER>
static void CaseConvertFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	UnaryExecutor::ExecuteString<string_t, string_t, CaseConvertOperator<IS_UPPER>>(args.data[0], result, args.size());
}

// Operator adapter for UnaryExecutor::ExecuteString (ASCII-only fast path).
template <bool IS_UPPER>
struct CaseConvertOperatorASCII {
	template <class INPUT_TYPE, class RESULT_TYPE>
	static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) {
		auto input_data = input.GetData();
		auto input_length = input.GetSize();
		return ASCIICaseConvert<IS_UPPER>(result, input_data, input_length);
	}
};

// Scalar function body for upper/lower when the input is known to be pure ASCII.
template <bool IS_UPPER>
static void CaseConvertFunctionASCII(DataChunk &args, ExpressionState &state, Vector &result) {
	UnaryExecutor::ExecuteString<string_t, string_t, CaseConvertOperatorASCII<IS_UPPER>>(args.data[0], result,
	                                                                                    args.size());
}

// Statistics propagation hook: if the child's statistics prove the string cannot
// contain unicode, swap in the cheaper ASCII implementation.
template <bool IS_UPPER>
static unique_ptr<BaseStatistics> CaseConvertPropagateStats(ClientContext &context, FunctionStatisticsInput &input) {
	auto &child_stats = input.child_stats;
	auto &expr = input.expr;
	D_ASSERT(child_stats.size() == 1);
	// can only propagate stats if the children have stats
	if (!StringStats::CanContainUnicode(child_stats[0])) {
		expr.function.function = CaseConvertFunctionASCII<IS_UPPER>;
	}
	return nullptr;
}

ScalarFunction LowerFun::GetFunction() {
	return ScalarFunction("lower", {LogicalType::VARCHAR}, LogicalType::VARCHAR, CaseConvertFunction<false>, nullptr,
	                      nullptr, CaseConvertPropagateStats<false>);
}

void LowerFun::RegisterFunction(BuiltinFunctions &set) {
	set.AddFunction({"lower", "lcase"}, LowerFun::GetFunction());
}

void UpperFun::RegisterFunction(BuiltinFunctions &set) {
	set.AddFunction({"upper", "ucase"},
	                ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, CaseConvertFunction<true>, nullptr,
	                               nullptr, CaseConvertPropagateStats<true>));
}

} // namespace duckdb

// NOTE(review): include target lost in extraction; this section uses memcpy — presumably <string.h>; confirm.
#include <string.h>

namespace duckdb {

// concat(a, b, ...): two-pass implementation — first compute per-row output
// lengths (constant columns contribute a shared length), then allocate each
// result string once and append every column into it.
static void ConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	result.SetVectorType(VectorType::CONSTANT_VECTOR);
	// iterate over the vectors to count how large the final string will be
	idx_t constant_lengths = 0;
	vector<idx_t> result_lengths(args.size(), 0);
	for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) {
		auto &input = args.data[col_idx];
		D_ASSERT(input.GetType().id() == LogicalTypeId::VARCHAR);
		if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) {
			if (ConstantVector::IsNull(input)) {
				// constant null, skip (concat treats NULL as empty string)
				continue;
			}
			auto input_data = ConstantVector::GetData<string_t>(input);
			constant_lengths += input_data->GetSize();
		} else {
			// non-constant vector: set the result type to a flat vector
			result.SetVectorType(VectorType::FLAT_VECTOR);
			// now get the lengths of each of the input elements
			UnifiedVectorFormat vdata;
			input.ToUnifiedFormat(args.size(), vdata);

			auto input_data = UnifiedVectorFormat::GetData<string_t>(vdata);
			// now add the length of each vector to the result length
			for (idx_t i = 0; i < args.size(); i++) {
				auto idx = vdata.sel->get_index(i);
				if (!vdata.validity.RowIsValid(idx)) {
					continue;
				}
				result_lengths[i] += input_data[idx].GetSize();
			}
		}
	}

	// first we allocate the empty strings for each of the values
	auto result_data = FlatVector::GetData<string_t>(result);
	for (idx_t i = 0; i < args.size(); i++) {
		// allocate an empty string of the required size
		idx_t str_length = constant_lengths + result_lengths[i];
		result_data[i] = StringVector::EmptyString(result, str_length);
		// we reuse the result_lengths vector to store the currently appended size
		result_lengths[i] = 0;
	}

	// now that the empty space for the strings has been allocated, perform the concatenation
	for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) {
		auto &input = args.data[col_idx];

		// loop over the vector and concat to all results
		if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) {
			// constant vector
			if (ConstantVector::IsNull(input)) {
				// constant null, skip
				continue;
			}
			// append the constant vector to each of the strings
			auto input_data = ConstantVector::GetData<string_t>(input);
			auto input_ptr = input_data->GetData();
			auto input_len = input_data->GetSize();
			for (idx_t i = 0; i < args.size(); i++) {
				memcpy(result_data[i].GetDataWriteable() + result_lengths[i], input_ptr, input_len);
				result_lengths[i] += input_len;
			}
		} else {
			// standard vector
			UnifiedVectorFormat idata;
			input.ToUnifiedFormat(args.size(), idata);

			auto input_data = UnifiedVectorFormat::GetData<string_t>(idata);
			for (idx_t i = 0; i < args.size(); i++) {
				auto idx = idata.sel->get_index(i);
				if (!idata.validity.RowIsValid(idx)) {
					continue;
				}
				auto input_ptr = input_data[idx].GetData();
				auto input_len = input_data[idx].GetSize();
				memcpy(result_data[i].GetDataWriteable() + result_lengths[i], input_ptr, input_len);
				result_lengths[i] += input_len;
			}
		}
	}
	for (idx_t i = 0; i < args.size(); i++) {
		result_data[i].Finalize();
	}
}

// The || operator: binary string concatenation (NULL in → NULL out, handled by the executor).
static void ConcatOperator(DataChunk &args, ExpressionState &state, Vector &result) {
	BinaryExecutor::Execute<string_t, string_t, string_t>(
	    args.data[0], args.data[1], result, args.size(), [&](string_t a, string_t b) {
		    auto a_data = a.GetData();
		    auto b_data = b.GetData();
		    auto a_length = a.GetSize();
		    auto b_length = b.GetSize();

		    auto target_length = a_length + b_length;
		    auto target = StringVector::EmptyString(result, target_length);
		    auto target_data = target.GetDataWriteable();

		    memcpy(target_data, a_data, a_length);
		    memcpy(target_data + a_length, b_data, b_length);
		    target.Finalize();
		    return target;
	    });
}

// concat_ws core: `rsel` selects the rows with a valid separator; NULL arguments
// are skipped entirely (no separator emitted for them). Two-pass like ConcatFunction.
static void TemplatedConcatWS(DataChunk &args, const string_t *sep_data, const SelectionVector &sep_sel,
                              const SelectionVector &rsel, idx_t count, Vector &result) {
	vector<idx_t> result_lengths(args.size(), 0);
	vector<bool> has_results(args.size(), false);
	// NOTE(review): template argument lost in extraction; restored as <UnifiedVectorFormat> — confirm
	auto orrified_data = make_unsafe_uniq_array<UnifiedVectorFormat>(args.ColumnCount() - 1);
	for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) {
		args.data[col_idx].ToUnifiedFormat(args.size(), orrified_data[col_idx - 1]);
	}

	// first figure out the lengths
	for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) {
		auto &idata = orrified_data[col_idx - 1];

		auto input_data = UnifiedVectorFormat::GetData<string_t>(idata);
		for (idx_t i = 0; i < count; i++) {
			auto ridx = rsel.get_index(i);
			auto sep_idx = sep_sel.get_index(ridx);
			auto idx = idata.sel->get_index(ridx);
			if (!idata.validity.RowIsValid(idx)) {
				continue;
			}
			if (has_results[ridx]) {
				// a previous column already wrote to this row: account for the separator
				result_lengths[ridx] += sep_data[sep_idx].GetSize();
			}
			result_lengths[ridx] += input_data[idx].GetSize();
			has_results[ridx] = true;
		}
	}

	// first we allocate the empty strings for each of the values
	auto result_data = FlatVector::GetData<string_t>(result);
	for (idx_t i = 0; i < count; i++) {
		auto ridx = rsel.get_index(i);
		// allocate an empty string of the required size
		result_data[ridx] = StringVector::EmptyString(result, result_lengths[ridx]);
		// we reuse the result_lengths vector to store the currently appended size
		result_lengths[ridx] = 0;
		has_results[ridx] = false;
	}

	// now that the empty space for the strings has been allocated, perform the concatenation
	for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) {
		auto &idata = orrified_data[col_idx - 1];
		auto input_data = UnifiedVectorFormat::GetData<string_t>(idata);
		for (idx_t i = 0; i < count; i++) {
			auto ridx = rsel.get_index(i);
			auto sep_idx = sep_sel.get_index(ridx);
			auto idx = idata.sel->get_index(ridx);
			if (!idata.validity.RowIsValid(idx)) {
				continue;
			}
			if (has_results[ridx]) {
				auto sep_size = sep_data[sep_idx].GetSize();
				auto sep_ptr = sep_data[sep_idx].GetData();
				memcpy(result_data[ridx].GetDataWriteable() + result_lengths[ridx], sep_ptr, sep_size);
				result_lengths[ridx] += sep_size;
			}
			auto input_ptr = input_data[idx].GetData();
			auto input_len = input_data[idx].GetSize();
			memcpy(result_data[ridx].GetDataWriteable() + result_lengths[ridx], input_ptr, input_len);
			result_lengths[ridx] += input_len;
			has_results[ridx] = true;
		}
	}
	for (idx_t i = 0; i < count; i++) {
		auto ridx = rsel.get_index(i);
		result_data[ridx].Finalize();
	}
}

// concat_ws(sep, a, b, ...): rows with a NULL separator produce NULL; otherwise
// dispatch to TemplatedConcatWS over the rows with a valid separator.
static void ConcatWSFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	auto &separator = args.data[0];
	UnifiedVectorFormat vdata;
	separator.ToUnifiedFormat(args.size(), vdata);

	result.SetVectorType(VectorType::CONSTANT_VECTOR);
	for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) {
		if (args.data[col_idx].GetVectorType() != VectorType::CONSTANT_VECTOR) {
			result.SetVectorType(VectorType::FLAT_VECTOR);
			break;
		}
	}
	switch (separator.GetVectorType()) {
	case VectorType::CONSTANT_VECTOR: {
		if (ConstantVector::IsNull(separator)) {
			// constant NULL as separator: return constant NULL vector
			result.SetVectorType(VectorType::CONSTANT_VECTOR);
			ConstantVector::SetNull(result, true);
			return;
		}
		// no null values
		auto sel = FlatVector::IncrementalSelectionVector();
		TemplatedConcatWS(args, UnifiedVectorFormat::GetData<string_t>(vdata), *vdata.sel, *sel, args.size(), result);
		return;
	}
	default: {
		// default case: loop over nullmask and create a non-null selection vector
		idx_t not_null_count = 0;
		SelectionVector not_null_vector(STANDARD_VECTOR_SIZE);
		auto &result_mask = FlatVector::Validity(result);
		for (idx_t i = 0; i < args.size(); i++) {
			if (!vdata.validity.RowIsValid(vdata.sel->get_index(i))) {
				result_mask.SetInvalid(i);
			} else {
				not_null_vector.set_index(not_null_count++, i);
			}
		}
		TemplatedConcatWS(args, UnifiedVectorFormat::GetData<string_t>(vdata), *vdata.sel, not_null_vector,
		                  not_null_count, result);
		return;
	}
	}
}

void ConcatFun::RegisterFunction(BuiltinFunctions &set) {
	// the concat operator and concat function have different behavior regarding NULLs
	// this is strange but seems consistent with postgresql and mysql
	// (sqlite does not support the concat function, only the concat operator)

	// the concat operator behaves as one would expect: any NULL value present results in a NULL
	// i.e.
NULL || 'hello' = NULL - // the concat function, however, treats NULL values as an empty string - // i.e. concat(NULL, 'hello') = 'hello' - // concat_ws functions similarly to the concat function, except the result is NULL if the separator is NULL - // if the separator is not NULL, however, NULL values are counted as empty string - // there is one separate rule: there are no separators added between NULL values - // so the NULL value and empty string are different! - // e.g.: - // concat_ws(',', NULL, NULL) = "" - // concat_ws(',', '', '') = "," - ScalarFunction concat = ScalarFunction("concat", {LogicalType::VARCHAR}, LogicalType::VARCHAR, ConcatFunction); - concat.varargs = LogicalType::VARCHAR; - concat.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - set.AddFunction(concat); - - ScalarFunctionSet concat_op("||"); - concat_op.AddFunction( - ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, ConcatOperator)); - concat_op.AddFunction(ScalarFunction({LogicalType::BLOB, LogicalType::BLOB}, LogicalType::BLOB, ConcatOperator)); - concat_op.AddFunction(ListConcatFun::GetFunction()); - for (auto &fun : concat_op.functions) { - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - } - set.AddFunction(concat_op); - - ScalarFunction concat_ws = ScalarFunction("concat_ws", {LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::VARCHAR, ConcatWSFunction); - concat_ws.varargs = LogicalType::VARCHAR; - concat_ws.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - set.AddFunction(concat_ws); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -template -static idx_t ContainsUnaligned(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, - idx_t base_offset) { - if (NEEDLE_SIZE > haystack_size) { - // needle is bigger than haystack: haystack cannot contain needle - return DConstants::INVALID_INDEX; - } - // contains for a small unaligned needle (3/5/6/7 bytes) - // we perform 
unsigned integer comparisons to check for equality of the entire needle in a single comparison - // this implementation is inspired by the memmem implementation of freebsd - - // first we set up the needle and the first NEEDLE_SIZE characters of the haystack as UNSIGNED integers - UNSIGNED needle_entry = 0; - UNSIGNED haystack_entry = 0; - const UNSIGNED start = (sizeof(UNSIGNED) * 8) - 8; - const UNSIGNED shift = (sizeof(UNSIGNED) - NEEDLE_SIZE) * 8; - for (int i = 0; i < NEEDLE_SIZE; i++) { - needle_entry |= UNSIGNED(needle[i]) << UNSIGNED(start - i * 8); - haystack_entry |= UNSIGNED(haystack[i]) << UNSIGNED(start - i * 8); - } - // now we perform the actual search - for (idx_t offset = NEEDLE_SIZE; offset < haystack_size; offset++) { - // for this position we first compare the haystack with the needle - if (haystack_entry == needle_entry) { - return base_offset + offset - NEEDLE_SIZE; - } - // now we adjust the haystack entry by - // (1) removing the left-most character (shift by 8) - // (2) adding the next character (bitwise or, with potential shift) - // this shift is only necessary if the needle size is not aligned with the unsigned integer size - // (e.g. 
needle size 3, unsigned integer size 4, we need to shift by 1) - haystack_entry = (haystack_entry << 8) | ((UNSIGNED(haystack[offset])) << shift); - } - if (haystack_entry == needle_entry) { - return base_offset + haystack_size - NEEDLE_SIZE; - } - return DConstants::INVALID_INDEX; -} - -template -static idx_t ContainsAligned(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, - idx_t base_offset) { - if (sizeof(UNSIGNED) > haystack_size) { - // needle is bigger than haystack: haystack cannot contain needle - return DConstants::INVALID_INDEX; - } - // contains for a small needle aligned with unsigned integer (2/4/8) - // similar to ContainsUnaligned, but simpler because we only need to do a reinterpret cast - auto needle_entry = Load(needle); - for (idx_t offset = 0; offset <= haystack_size - sizeof(UNSIGNED); offset++) { - // for this position we first compare the haystack with the needle - auto haystack_entry = Load(haystack + offset); - if (needle_entry == haystack_entry) { - return base_offset + offset; - } - } - return DConstants::INVALID_INDEX; -} - -idx_t ContainsGeneric(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, - idx_t needle_size, idx_t base_offset) { - if (needle_size > haystack_size) { - // needle is bigger than haystack: haystack cannot contain needle - return DConstants::INVALID_INDEX; - } - // this implementation is inspired by Raphael Javaux's faststrstr (https://github.com/RaphaelJ/fast_strstr) - // generic contains; note that we can't use strstr because we don't have null-terminated strings anymore - // we keep track of a shifting window sum of all characters with window size equal to needle_size - // this shifting sum is used to avoid calling into memcmp; - // we only need to call into memcmp when the window sum is equal to the needle sum - // when that happens, the characters are potentially the same and we call into memcmp to check if they are - uint32_t sums_diff = 0; - for 
(idx_t i = 0; i < needle_size; i++) { - sums_diff += haystack[i]; - sums_diff -= needle[i]; - } - idx_t offset = 0; - while (true) { - if (sums_diff == 0 && haystack[offset] == needle[0]) { - if (memcmp(haystack + offset, needle, needle_size) == 0) { - return base_offset + offset; - } - } - if (offset >= haystack_size - needle_size) { - return DConstants::INVALID_INDEX; - } - sums_diff -= haystack[offset]; - sums_diff += haystack[offset + needle_size]; - offset++; - } -} - -idx_t ContainsFun::Find(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, - idx_t needle_size) { - D_ASSERT(needle_size > 0); - // start off by performing a memchr to find the first character of the - auto location = memchr(haystack, needle[0], haystack_size); - if (location == nullptr) { - return DConstants::INVALID_INDEX; - } - idx_t base_offset = const_uchar_ptr_cast(location) - haystack; - haystack_size -= base_offset; - haystack = const_uchar_ptr_cast(location); - // switch algorithm depending on needle size - switch (needle_size) { - case 1: - return base_offset; - case 2: - return ContainsAligned(haystack, haystack_size, needle, base_offset); - case 3: - return ContainsUnaligned(haystack, haystack_size, needle, base_offset); - case 4: - return ContainsAligned(haystack, haystack_size, needle, base_offset); - case 5: - return ContainsUnaligned(haystack, haystack_size, needle, base_offset); - case 6: - return ContainsUnaligned(haystack, haystack_size, needle, base_offset); - case 7: - return ContainsUnaligned(haystack, haystack_size, needle, base_offset); - case 8: - return ContainsAligned(haystack, haystack_size, needle, base_offset); - default: - return ContainsGeneric(haystack, haystack_size, needle, needle_size, base_offset); - } -} - -idx_t ContainsFun::Find(const string_t &haystack_s, const string_t &needle_s) { - auto haystack = const_uchar_ptr_cast(haystack_s.GetData()); - auto haystack_size = haystack_s.GetSize(); - auto needle = 
const_uchar_ptr_cast(needle_s.GetData()); - auto needle_size = needle_s.GetSize(); - if (needle_size == 0) { - // empty needle: always true - return 0; - } - return ContainsFun::Find(haystack, haystack_size, needle, needle_size); -} - -struct ContainsOperator { - template - static inline TR Operation(TA left, TB right) { - return ContainsFun::Find(left, right) != DConstants::INVALID_INDEX; - } -}; - -ScalarFunction ContainsFun::GetFunction() { - return ScalarFunction("contains", // name of the function - {LogicalType::VARCHAR, LogicalType::VARCHAR}, // argument list - LogicalType::BOOLEAN, // return type - ScalarFunction::BinaryFunction); -} - -void ContainsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(GetFunction()); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -// length returns the number of unicode codepoints -struct StringLengthOperator { - template - static inline TR Operation(TA input) { - return LengthFun::Length(input); - } -}; - -struct GraphemeCountOperator { - template - static inline TR Operation(TA input) { - return LengthFun::GraphemeCount(input); - } -}; - -struct ArrayLengthOperator { - template - static inline TR Operation(TA input) { - return input.length; - } -}; - -struct ArrayLengthBinaryOperator { - template - static inline TR Operation(TA input, TB dimension) { - if (dimension != 1) { - throw NotImplementedException("array_length for dimensions other than 1 not implemented"); - } - return input.length; - } -}; - -// strlen returns the size in bytes -struct StrLenOperator { - template - static inline TR Operation(TA input) { - return input.GetSize(); - } -}; - -struct OctetLenOperator { - template - static inline TR Operation(TA input) { - return Bit::OctetLength(input); - } -}; - -// bitlen returns the size in bits -struct BitLenOperator { - template - static inline TR Operation(TA input) { - return 8 * input.GetSize(); - } -}; - -// bitstringlen returns the amount of bits in a bitstring -struct 
BitStringLenOperator { - template - static inline TR Operation(TA input) { - return Bit::BitLength(input); - } -}; - -static unique_ptr LengthPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() == 1); - // can only propagate stats if the children have stats - if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = ScalarFunction::UnaryFunction; - } - return nullptr; -} - -static unique_ptr ListLengthBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments[0]->HasParameter()) { - throw ParameterNotResolvedException(); - } - bound_function.arguments[0] = arguments[0]->return_type; - return nullptr; -} - -void LengthFun::RegisterFunction(BuiltinFunctions &set) { - ScalarFunction array_length_unary = - ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction, ListLengthBind); - ScalarFunctionSet length("length"); - length.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction, nullptr, - nullptr, LengthPropagateStats)); - length.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction)); - length.AddFunction(array_length_unary); - set.AddFunction(length); - length.name = "len"; - set.AddFunction(length); - - ScalarFunctionSet length_grapheme("length_grapheme"); - length_grapheme.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction, - nullptr, nullptr, LengthPropagateStats)); - set.AddFunction(length_grapheme); - - ScalarFunctionSet array_length("array_length"); - array_length.AddFunction(array_length_unary); - array_length.AddFunction(ScalarFunction( - {LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, LogicalType::BIGINT, - ScalarFunction::BinaryFunction, ListLengthBind)); - 
set.AddFunction(array_length); - - set.AddFunction(ScalarFunction("strlen", {LogicalType::VARCHAR}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction)); - ScalarFunctionSet bit_length("bit_length"); - bit_length.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction)); - bit_length.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction)); - set.AddFunction(bit_length); - // length for BLOB type - ScalarFunctionSet octet_length("octet_length"); - octet_length.AddFunction(ScalarFunction({LogicalType::BLOB}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction)); - octet_length.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction)); - set.AddFunction(octet_length); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -struct StandardCharacterReader { - static char Operation(const char *data, idx_t pos) { - return data[pos]; - } -}; - -struct ASCIILCaseReader { - static char Operation(const char *data, idx_t pos) { - return (char)LowerFun::ascii_to_lower_map[(uint8_t)data[pos]]; - } -}; - -template -bool TemplatedLikeOperator(const char *sdata, idx_t slen, const char *pdata, idx_t plen, char escape) { - idx_t pidx = 0; - idx_t sidx = 0; - for (; pidx < plen && sidx < slen; pidx++) { - char pchar = READER::Operation(pdata, pidx); - char schar = READER::Operation(sdata, sidx); - if (HAS_ESCAPE && pchar == escape) { - pidx++; - if (pidx == plen) { - throw SyntaxException("Like pattern must not end with escape character!"); - } - if (pdata[pidx] != schar) { - return false; - } - sidx++; - } else if (pchar == UNDERSCORE) { - sidx++; - } else if (pchar == PERCENTAGE) { - pidx++; - while (pidx < plen && pdata[pidx] == PERCENTAGE) { - pidx++; - } - if (pidx == plen) { - return true; /* tail is acceptable */ - } - for (; sidx < slen; sidx++) { - if (TemplatedLikeOperator( - sdata + sidx, slen - sidx, pdata + pidx, plen - 
pidx, escape)) { - return true; - } - } - return false; - } else if (pchar == schar) { - sidx++; - } else { - return false; - } - } - while (pidx < plen && pdata[pidx] == PERCENTAGE) { - pidx++; - } - return pidx == plen && sidx == slen; -} - -struct LikeSegment { - explicit LikeSegment(string pattern) : pattern(std::move(pattern)) { - } - - string pattern; -}; - -struct LikeMatcher : public FunctionData { - LikeMatcher(string like_pattern_p, vector segments, bool has_start_percentage, bool has_end_percentage) - : like_pattern(std::move(like_pattern_p)), segments(std::move(segments)), - has_start_percentage(has_start_percentage), has_end_percentage(has_end_percentage) { - } - - bool Match(string_t &str) { - auto str_data = const_uchar_ptr_cast(str.GetData()); - auto str_len = str.GetSize(); - idx_t segment_idx = 0; - idx_t end_idx = segments.size() - 1; - if (!has_start_percentage) { - // no start sample_size: match the first part of the string directly - auto &segment = segments[0]; - if (str_len < segment.pattern.size()) { - return false; - } - if (memcmp(str_data, segment.pattern.c_str(), segment.pattern.size()) != 0) { - return false; - } - str_data += segment.pattern.size(); - str_len -= segment.pattern.size(); - segment_idx++; - if (segments.size() == 1) { - // only one segment, and it matches - // we have a match if there is an end sample_size, OR if the memcmp was an exact match (remaining str is - // empty) - return has_end_percentage || str_len == 0; - } - } - // main match loop: for every segment in the middle, use Contains to find the needle in the haystack - for (; segment_idx < end_idx; segment_idx++) { - auto &segment = segments[segment_idx]; - // find the pattern of the current segment - idx_t next_offset = ContainsFun::Find(str_data, str_len, const_uchar_ptr_cast(segment.pattern.c_str()), - segment.pattern.size()); - if (next_offset == DConstants::INVALID_INDEX) { - // could not find this pattern in the string: no match - return false; - } - idx_t 
offset = next_offset + segment.pattern.size(); - str_data += offset; - str_len -= offset; - } - if (!has_end_percentage) { - end_idx--; - // no end sample_size: match the final segment now - auto &segment = segments.back(); - if (str_len < segment.pattern.size()) { - return false; - } - if (memcmp(str_data + str_len - segment.pattern.size(), segment.pattern.c_str(), segment.pattern.size()) != - 0) { - return false; - } - return true; - } else { - auto &segment = segments.back(); - // find the pattern of the current segment - idx_t next_offset = ContainsFun::Find(str_data, str_len, const_uchar_ptr_cast(segment.pattern.c_str()), - segment.pattern.size()); - return next_offset != DConstants::INVALID_INDEX; - } - } - - static unique_ptr CreateLikeMatcher(string like_pattern, char escape = '\0') { - vector segments; - idx_t last_non_pattern = 0; - bool has_start_percentage = false; - bool has_end_percentage = false; - for (idx_t i = 0; i < like_pattern.size(); i++) { - auto ch = like_pattern[i]; - if (ch == escape || ch == '%' || ch == '_') { - // special character, push a constant pattern - if (i > last_non_pattern) { - segments.emplace_back(like_pattern.substr(last_non_pattern, i - last_non_pattern)); - } - last_non_pattern = i + 1; - if (ch == escape || ch == '_') { - // escape or underscore: could not create efficient like matcher - // FIXME: we could handle escaped percentages here - return nullptr; - } else { - // sample_size - if (i == 0) { - has_start_percentage = true; - } - if (i + 1 == like_pattern.size()) { - has_end_percentage = true; - } - } - } - } - if (last_non_pattern < like_pattern.size()) { - segments.emplace_back(like_pattern.substr(last_non_pattern, like_pattern.size() - last_non_pattern)); - } - if (segments.empty()) { - return nullptr; - } - return make_uniq(std::move(like_pattern), std::move(segments), has_start_percentage, - has_end_percentage); - } - - unique_ptr Copy() const override { - return make_uniq(like_pattern, segments, 
has_start_percentage, has_end_percentage); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return like_pattern == other.like_pattern; - } - -private: - string like_pattern; - vector segments; - bool has_start_percentage; - bool has_end_percentage; -}; - -static unique_ptr LikeBindFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // pattern is the second argument. If its constant, we can already prepare the pattern and store it for later. - D_ASSERT(arguments.size() == 2 || arguments.size() == 3); - if (arguments[1]->IsFoldable()) { - Value pattern_str = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - return LikeMatcher::CreateLikeMatcher(pattern_str.ToString()); - } - return nullptr; -} - -bool LikeOperatorFunction(const char *s, idx_t slen, const char *pattern, idx_t plen, char escape) { - return TemplatedLikeOperator<'%', '_', true>(s, slen, pattern, plen, escape); -} - -bool LikeOperatorFunction(const char *s, idx_t slen, const char *pattern, idx_t plen) { - return TemplatedLikeOperator<'%', '_', false>(s, slen, pattern, plen, '\0'); -} - -bool LikeOperatorFunction(string_t &s, string_t &pat) { - return LikeOperatorFunction(s.GetData(), s.GetSize(), pat.GetData(), pat.GetSize()); -} - -bool LikeOperatorFunction(string_t &s, string_t &pat, char escape) { - return LikeOperatorFunction(s.GetData(), s.GetSize(), pat.GetData(), pat.GetSize(), escape); -} - -bool LikeFun::Glob(const char *string, idx_t slen, const char *pattern, idx_t plen, bool allow_question_mark) { - idx_t sidx = 0; - idx_t pidx = 0; -main_loop : { - // main matching loop - while (sidx < slen && pidx < plen) { - char s = string[sidx]; - char p = pattern[pidx]; - switch (p) { - case '*': { - // asterisk: match any set of characters - // skip any subsequent asterisks - pidx++; - while (pidx < plen && pattern[pidx] == '*') { - pidx++; - } - // if the asterisk is the last character, the pattern 
always matches - if (pidx == plen) { - return true; - } - // recursively match the remainder of the pattern - for (; sidx < slen; sidx++) { - if (LikeFun::Glob(string + sidx, slen - sidx, pattern + pidx, plen - pidx)) { - return true; - } - } - return false; - } - case '?': - // when enabled: matches anything but null - if (allow_question_mark) { - break; - } - DUCKDB_EXPLICIT_FALLTHROUGH; - case '[': - pidx++; - goto parse_bracket; - case '\\': - // escape character, next character needs to match literally - pidx++; - // check that we still have a character remaining - if (pidx == plen) { - return false; - } - p = pattern[pidx]; - if (s != p) { - return false; - } - break; - default: - // not a control character: characters need to match literally - if (s != p) { - return false; - } - break; - } - sidx++; - pidx++; - } - while (pidx < plen && pattern[pidx] == '*') { - pidx++; - } - // we are finished only if we have consumed the full pattern - return pidx == plen && sidx == slen; -} -parse_bracket : { - // inside a bracket - if (pidx == plen) { - return false; - } - // check the first character - // if it is an exclamation mark we need to invert our logic - char p = pattern[pidx]; - char s = string[sidx]; - bool invert = false; - if (p == '!') { - invert = true; - pidx++; - } - bool found_match = invert; - idx_t start_pos = pidx; - bool found_closing_bracket = false; - // now check the remainder of the pattern - while (pidx < plen) { - p = pattern[pidx]; - // if the first character is a closing bracket, we match it literally - // otherwise it indicates an end of bracket - if (p == ']' && pidx > start_pos) { - // end of bracket found: we are done - found_closing_bracket = true; - pidx++; - break; - } - // we either match a range (a-b) or a single character (a) - // check if the next character is a dash - if (pidx + 1 == plen) { - // no next character! - break; - } - bool matches; - if (pattern[pidx + 1] == '-') { - // range! 
find the next character in the range - if (pidx + 2 == plen) { - break; - } - char next_char = pattern[pidx + 2]; - // check if the current character is within the range - matches = s >= p && s <= next_char; - // shift the pattern forward past the range - pidx += 3; - } else { - // no range! perform a direct match - matches = p == s; - // shift the pattern forward past the character - pidx++; - } - if (found_match == invert && matches) { - // found a match! set the found_matches flag - // we keep on pattern matching after this until we reach the end bracket - // however, we don't need to update the found_match flag anymore - found_match = !invert; - } - } - if (!found_closing_bracket) { - // no end of bracket: invalid pattern - return false; - } - if (!found_match) { - // did not match the bracket: return false; - return false; - } - // finished the bracket matching: move forward - sidx++; - goto main_loop; -} -} - -static char GetEscapeChar(string_t escape) { - // Only one escape character should be allowed - if (escape.GetSize() > 1) { - throw SyntaxException("Invalid escape string. Escape string must be empty or one character."); - } - return escape.GetSize() == 0 ? 
'\0' : *escape.GetData(); -} - -struct LikeEscapeOperator { - template - static inline bool Operation(TA str, TB pattern, TC escape) { - char escape_char = GetEscapeChar(escape); - return LikeOperatorFunction(str.GetData(), str.GetSize(), pattern.GetData(), pattern.GetSize(), escape_char); - } -}; - -struct NotLikeEscapeOperator { - template - static inline bool Operation(TA str, TB pattern, TC escape) { - return !LikeEscapeOperator::Operation(str, pattern, escape); - } -}; - -struct LikeOperator { - template - static inline TR Operation(TA str, TB pattern) { - return LikeOperatorFunction(str, pattern); - } -}; - -bool ILikeOperatorFunction(string_t &str, string_t &pattern, char escape = '\0') { - auto str_data = str.GetData(); - auto str_size = str.GetSize(); - auto pat_data = pattern.GetData(); - auto pat_size = pattern.GetSize(); - - // lowercase both the str and the pattern - idx_t str_llength = LowerFun::LowerLength(str_data, str_size); - auto str_ldata = make_unsafe_uniq_array(str_llength); - LowerFun::LowerCase(str_data, str_size, str_ldata.get()); - - idx_t pat_llength = LowerFun::LowerLength(pat_data, pat_size); - auto pat_ldata = make_unsafe_uniq_array(pat_llength); - LowerFun::LowerCase(pat_data, pat_size, pat_ldata.get()); - string_t str_lcase(str_ldata.get(), str_llength); - string_t pat_lcase(pat_ldata.get(), pat_llength); - return LikeOperatorFunction(str_lcase, pat_lcase, escape); -} - -struct ILikeEscapeOperator { - template - static inline bool Operation(TA str, TB pattern, TC escape) { - char escape_char = GetEscapeChar(escape); - return ILikeOperatorFunction(str, pattern, escape_char); - } -}; - -struct NotILikeEscapeOperator { - template - static inline bool Operation(TA str, TB pattern, TC escape) { - return !ILikeEscapeOperator::Operation(str, pattern, escape); - } -}; - -struct ILikeOperator { - template - static inline TR Operation(TA str, TB pattern) { - return ILikeOperatorFunction(str, pattern); - } -}; - -struct NotLikeOperator { - 
template - static inline TR Operation(TA str, TB pattern) { - return !LikeOperatorFunction(str, pattern); - } -}; - -struct NotILikeOperator { - template - static inline TR Operation(TA str, TB pattern) { - return !ILikeOperator::Operation(str, pattern); - } -}; - -struct ILikeOperatorASCII { - template - static inline TR Operation(TA str, TB pattern) { - return TemplatedLikeOperator<'%', '_', false, ASCIILCaseReader>(str.GetData(), str.GetSize(), pattern.GetData(), - pattern.GetSize(), '\0'); - } -}; - -struct NotILikeOperatorASCII { - template - static inline TR Operation(TA str, TB pattern) { - return !ILikeOperatorASCII::Operation(str, pattern); - } -}; - -struct GlobOperator { - template - static inline TR Operation(TA str, TB pattern) { - return LikeFun::Glob(str.GetData(), str.GetSize(), pattern.GetData(), pattern.GetSize()); - } -}; - -// This can be moved to the scalar_function class -template -static void LikeEscapeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str = args.data[0]; - auto &pattern = args.data[1]; - auto &escape = args.data[2]; - - TernaryExecutor::Execute( - str, pattern, escape, result, args.size(), FUNC::template Operation); -} - -template -static unique_ptr ILikePropagateStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() >= 1); - // can only propagate stats if the children have stats - if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = ScalarFunction::BinaryFunction; - } - return nullptr; -} - -template -static void RegularLikeFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - if (func_expr.bind_info) { - auto &matcher = func_expr.bind_info->Cast(); - // use fast like matcher - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](string_t input) { - return INVERT ? 
!matcher.Match(input) : matcher.Match(input); - }); - } else { - // use generic like matcher - BinaryExecutor::ExecuteStandard(input.data[0], input.data[1], result, - input.size()); - } -} -void LikeFun::RegisterFunction(BuiltinFunctions &set) { - // like - set.AddFunction(GetLikeFunction()); - // not like - set.AddFunction(ScalarFunction("!~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - RegularLikeFunction, LikeBindFunction)); - // glob - set.AddFunction(ScalarFunction("~~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - ScalarFunction::BinaryFunction)); - // ilike - set.AddFunction(ScalarFunction("~~*", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - ScalarFunction::BinaryFunction, nullptr, - nullptr, ILikePropagateStats)); - // not ilike - set.AddFunction(ScalarFunction("!~~*", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - ScalarFunction::BinaryFunction, nullptr, - nullptr, ILikePropagateStats)); -} - -ScalarFunction LikeFun::GetLikeFunction() { - return ScalarFunction("~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - RegularLikeFunction, LikeBindFunction); -} - -void LikeEscapeFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(GetLikeEscapeFun()); - set.AddFunction({"not_like_escape"}, - ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::BOOLEAN, LikeEscapeFunction)); - - set.AddFunction({"ilike_escape"}, ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::BOOLEAN, LikeEscapeFunction)); - set.AddFunction({"not_ilike_escape"}, - ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::BOOLEAN, LikeEscapeFunction)); -} - -ScalarFunction LikeEscapeFun::GetLikeEscapeFun() { - return ScalarFunction("like_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, - 
LogicalType::BOOLEAN, LikeEscapeFunction); -} -} // namespace duckdb - - - - -namespace duckdb { - -struct NFCNormalizeOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto input_data = input.GetData(); - auto input_length = input.GetSize(); - if (StripAccentsFun::IsAscii(input_data, input_length)) { - return input; - } - auto normalized_str = Utf8Proc::Normalize(input_data, input_length); - D_ASSERT(normalized_str); - auto result_str = StringVector::AddString(result, normalized_str); - free(normalized_str); - return result_str; - } -}; - -static void NFCNormalizeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); - StringVector::AddHeapReference(result, args.data[0]); -} - -ScalarFunction NFCNormalizeFun::GetFunction() { - return ScalarFunction("nfc_normalize", {LogicalType::VARCHAR}, LogicalType::VARCHAR, NFCNormalizeFunction); -} - -void NFCNormalizeFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(NFCNormalizeFun::GetFunction()); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -static bool PrefixFunction(const string_t &str, const string_t &pattern); - -struct PrefixOperator { - template - static inline TR Operation(TA left, TB right) { - return PrefixFunction(left, right); - } -}; -static bool PrefixFunction(const string_t &str, const string_t &pattern) { - auto str_length = str.GetSize(); - auto patt_length = pattern.GetSize(); - if (patt_length > str_length) { - return false; - } - if (patt_length <= string_t::PREFIX_LENGTH) { - // short prefix - if (patt_length == 0) { - // length = 0, return true - return true; - } - - // prefix early out - const char *str_pref = str.GetPrefix(); - const char *patt_pref = pattern.GetPrefix(); - for (idx_t i = 0; i < patt_length; ++i) { - if (str_pref[i] != patt_pref[i]) { - return false; - } - } - return true; - } else { - // prefix 
early out - const char *str_pref = str.GetPrefix(); - const char *patt_pref = pattern.GetPrefix(); - for (idx_t i = 0; i < string_t::PREFIX_LENGTH; ++i) { - if (str_pref[i] != patt_pref[i]) { - // early out - return false; - } - } - // compare the rest of the prefix - const char *str_data = str.GetData(); - const char *patt_data = pattern.GetData(); - D_ASSERT(patt_length <= str_length); - for (idx_t i = string_t::PREFIX_LENGTH; i < patt_length; ++i) { - if (str_data[i] != patt_data[i]) { - return false; - } - } - return true; - } -} - -ScalarFunction PrefixFun::GetFunction() { - return ScalarFunction("prefix", // name of the function - {LogicalType::VARCHAR, LogicalType::VARCHAR}, // argument list - LogicalType::BOOLEAN, // return type - ScalarFunction::BinaryFunction); -} - -void PrefixFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(GetFunction()); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -using regexp_util::CreateStringPiece; -using regexp_util::Extract; -using regexp_util::ParseRegexOptions; -using regexp_util::TryParseConstantPattern; - -unique_ptr -RegexpExtractAll::InitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) { - auto &info = bind_data->Cast(); - if (info.constant_pattern) { - return make_uniq(info, true); - } - return nullptr; -} - -// Forwards startpos automatically -bool ExtractAll(duckdb_re2::StringPiece &input, duckdb_re2::RE2 &pattern, idx_t *startpos, - duckdb_re2::StringPiece *groups, int ngroups) { - - D_ASSERT(pattern.ok()); - D_ASSERT(pattern.NumberOfCapturingGroups() == ngroups); - - if (!pattern.Match(input, *startpos, input.size(), pattern.Anchored(), groups, ngroups + 1)) { - return false; - } - idx_t consumed = static_cast(groups[0].end() - (input.begin() + *startpos)); - if (!consumed) { - // Empty match found, have to manually forward the input - // to avoid an infinite loop - // FIXME: support unicode characters - consumed++; - while (*startpos + 
consumed < input.length() && !LengthFun::IsCharacter(input[*startpos + consumed])) { - consumed++; - } - } - *startpos += consumed; - return true; -} - -void ExtractSingleTuple(const string_t &string, duckdb_re2::RE2 &pattern, int32_t group, RegexStringPieceArgs &args, - Vector &result, idx_t row) { - auto input = CreateStringPiece(string); - - auto &child_vector = ListVector::GetEntry(result); - auto list_content = FlatVector::GetData(child_vector); - auto &child_validity = FlatVector::Validity(child_vector); - - auto current_list_size = ListVector::GetListSize(result); - auto current_list_capacity = ListVector::GetListCapacity(result); - - auto result_data = FlatVector::GetData(result); - auto &list_entry = result_data[row]; - list_entry.offset = current_list_size; - - if (group < 0) { - list_entry.length = 0; - return; - } - // If the requested group index is out of bounds - // we want to throw only if there is a match - bool throw_on_group_found = (idx_t)group > args.size; - - idx_t startpos = 0; - for (idx_t iteration = 0; ExtractAll(input, pattern, &startpos, args.group_buffer, args.size); iteration++) { - if (!iteration && throw_on_group_found) { - throw InvalidInputException("Pattern has %d groups. 
Cannot access group %d", args.size, group); - } - - // Make sure we have enough room for the new entries - if (current_list_size + 1 >= current_list_capacity) { - ListVector::Reserve(result, current_list_capacity * 2); - current_list_capacity = ListVector::GetListCapacity(result); - list_content = FlatVector::GetData(child_vector); - } - - // Write the captured groups into the list-child vector - auto &match_group = args.group_buffer[group]; - - idx_t child_idx = current_list_size; - if (match_group.empty()) { - // This group was not matched - list_content[child_idx] = string_t(string.GetData(), 0); - if (match_group.begin() == nullptr) { - // This group is optional - child_validity.SetInvalid(child_idx); - } - } else { - // Every group is a substring of the original, we can find out the offset using the pointer - // the 'match_group' address is guaranteed to be bigger than that of the source - D_ASSERT(const_char_ptr_cast(match_group.begin()) >= string.GetData()); - idx_t offset = match_group.begin() - string.GetData(); - list_content[child_idx] = string_t(string.GetData() + offset, match_group.size()); - } - current_list_size++; - if (startpos > input.size()) { - // Empty match found at the end of the string - break; - } - } - list_entry.length = current_list_size - list_entry.offset; - ListVector::SetListSize(result, current_list_size); -} - -int32_t GetGroupIndex(DataChunk &args, idx_t row, int32_t &result) { - if (args.ColumnCount() < 3) { - result = 0; - return true; - } - UnifiedVectorFormat format; - args.data[2].ToUnifiedFormat(args.size(), format); - idx_t index = format.sel->get_index(row); - if (!format.validity.RowIsValid(index)) { - return false; - } - result = UnifiedVectorFormat::GetData(format)[index]; - return true; -} - -duckdb_re2::RE2 &GetPattern(const RegexpBaseBindData &info, ExpressionState &state, - unique_ptr &pattern_p) { - if (info.constant_pattern) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - return 
lstate.constant_pattern; - } - D_ASSERT(pattern_p); - return *pattern_p; -} - -RegexStringPieceArgs &GetGroupsBuffer(const RegexpBaseBindData &info, ExpressionState &state, - unique_ptr &groups_p) { - if (info.constant_pattern) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - return lstate.group_buffer; - } - D_ASSERT(groups_p); - return *groups_p; -} - -void RegexpExtractAll::Execute(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); - - auto &strings = args.data[0]; - auto &patterns = args.data[1]; - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - auto &output_child = ListVector::GetEntry(result); - - UnifiedVectorFormat strings_data; - strings.ToUnifiedFormat(args.size(), strings_data); - - UnifiedVectorFormat pattern_data; - patterns.ToUnifiedFormat(args.size(), pattern_data); - - ListVector::Reserve(result, STANDARD_VECTOR_SIZE); - // Reference the 'strings' StringBuffer, because we won't need to allocate new data - // for the result, all returned strings are substrings of the originals - output_child.SetAuxiliary(strings.GetAuxiliary()); - - // Avoid doing extra work if all the inputs are constant - idx_t tuple_count = args.AllConstant() ? 
1 : args.size(); - - unique_ptr non_const_args; - unique_ptr stored_re; - if (!info.constant_pattern) { - non_const_args = make_uniq(); - } else { - // Verify that the constant pattern is valid - auto &re = GetPattern(info, state, stored_re); - auto group_count_p = re.NumberOfCapturingGroups(); - if (group_count_p == -1) { - throw InvalidInputException("Pattern failed to parse, error: '%s'", re.error()); - } - } - - for (idx_t row = 0; row < tuple_count; row++) { - bool pattern_valid = true; - if (!info.constant_pattern) { - // Check if the pattern is NULL or not, - // and compile the pattern if it's not constant - auto pattern_idx = pattern_data.sel->get_index(row); - if (!pattern_data.validity.RowIsValid(pattern_idx)) { - pattern_valid = false; - } else { - auto &pattern_p = UnifiedVectorFormat::GetData(pattern_data)[pattern_idx]; - auto pattern_strpiece = CreateStringPiece(pattern_p); - stored_re = make_uniq(pattern_strpiece, info.options); - - // Increase the size of the args buffer if needed - auto group_count_p = stored_re->NumberOfCapturingGroups(); - if (group_count_p == -1) { - throw InvalidInputException("Pattern failed to parse, error: '%s'", stored_re->error()); - } - non_const_args->SetSize(group_count_p); - } - } - - auto string_idx = strings_data.sel->get_index(row); - int32_t group_index; - if (!pattern_valid || !strings_data.validity.RowIsValid(string_idx) || !GetGroupIndex(args, row, group_index)) { - // If something is NULL, the result is NULL - // FIXME: do we even need 'SPECIAL_HANDLING'? 
- auto result_data = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - result_data[row].length = 0; - result_data[row].offset = ListVector::GetListSize(result); - result_validity.SetInvalid(row); - continue; - } - - auto &re = GetPattern(info, state, stored_re); - auto &groups = GetGroupsBuffer(info, state, non_const_args); - auto &string = UnifiedVectorFormat::GetData(strings_data)[string_idx]; - ExtractSingleTuple(string, re, group_index, groups, result, row); - } - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -unique_ptr RegexpExtractAll::Bind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(arguments.size() >= 2); - - duckdb_re2::RE2::Options options; - - string constant_string; - bool constant_pattern = TryParseConstantPattern(context, *arguments[1], constant_string); - - if (arguments.size() >= 4) { - ParseRegexOptions(context, *arguments[3], options); - } - return make_uniq(options, std::move(constant_string), constant_pattern, ""); -} - -} // namespace duckdb - - - -namespace duckdb { - -namespace regexp_util { - -bool TryParseConstantPattern(ClientContext &context, Expression &expr, string &constant_string) { - if (!expr.IsFoldable()) { - return false; - } - Value pattern_str = ExpressionExecutor::EvaluateScalar(context, expr); - if (!pattern_str.IsNull() && pattern_str.type().id() == LogicalTypeId::VARCHAR) { - constant_string = StringValue::Get(pattern_str); - return true; - } - return false; -} - -void ParseRegexOptions(const string &options, duckdb_re2::RE2::Options &result, bool *global_replace) { - for (idx_t i = 0; i < options.size(); i++) { - switch (options[i]) { - case 'c': - // case-sensitive matching - result.set_case_sensitive(true); - break; - case 'i': - // case-insensitive matching - result.set_case_sensitive(false); - break; - case 'l': - // literal matching - result.set_literal(true); - break; - case 'm': - case 'n': 
- case 'p': - // newline-sensitive matching - result.set_dot_nl(false); - break; - case 's': - // non-newline-sensitive matching - result.set_dot_nl(true); - break; - case 'g': - // global replace, only available for regexp_replace - if (global_replace) { - *global_replace = true; - } else { - throw InvalidInputException("Option 'g' (global replace) is only valid for regexp_replace"); - } - break; - case ' ': - case '\t': - case '\n': - // ignore whitespace - break; - default: - throw InvalidInputException("Unrecognized Regex option %c", options[i]); - } - } -} - -void ParseRegexOptions(ClientContext &context, Expression &expr, RE2::Options &target, bool *global_replace) { - if (expr.HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!expr.IsFoldable()) { - throw InvalidInputException("Regex options field must be a constant"); - } - Value options_str = ExpressionExecutor::EvaluateScalar(context, expr); - if (options_str.IsNull()) { - throw InvalidInputException("Regex options field must not be NULL"); - } - if (options_str.type().id() != LogicalTypeId::VARCHAR) { - throw InvalidInputException("Regex options field must be a string"); - } - ParseRegexOptions(StringValue::Get(options_str), target, global_replace); -} - -} // namespace regexp_util - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -using regexp_util::CreateStringPiece; -using regexp_util::Extract; -using regexp_util::ParseRegexOptions; -using regexp_util::TryParseConstantPattern; - -static bool RegexOptionsEquals(const duckdb_re2::RE2::Options &opt_a, const duckdb_re2::RE2::Options &opt_b) { - return opt_a.case_sensitive() == opt_b.case_sensitive(); -} - -RegexpBaseBindData::RegexpBaseBindData() : constant_pattern(false) { -} -RegexpBaseBindData::RegexpBaseBindData(duckdb_re2::RE2::Options options, string constant_string_p, - bool constant_pattern) - : options(options), constant_string(std::move(constant_string_p)), constant_pattern(constant_pattern) { -} - 
-RegexpBaseBindData::~RegexpBaseBindData() { -} - -bool RegexpBaseBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return constant_pattern == other.constant_pattern && constant_string == other.constant_string && - RegexOptionsEquals(options, other.options); -} - -unique_ptr RegexInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, - FunctionData *bind_data) { - auto &info = bind_data->Cast(); - if (info.constant_pattern) { - return make_uniq(info); - } - return nullptr; -} - -//===--------------------------------------------------------------------===// -// Regexp Matches -//===--------------------------------------------------------------------===// -RegexpMatchesBindData::RegexpMatchesBindData(duckdb_re2::RE2::Options options, string constant_string_p, - bool constant_pattern) - : RegexpBaseBindData(options, std::move(constant_string_p), constant_pattern) { - if (constant_pattern) { - auto pattern = make_uniq(constant_string, options); - if (!pattern->ok()) { - throw Exception(pattern->error()); - } - - range_success = pattern->PossibleMatchRange(&range_min, &range_max, 1000); - } else { - range_success = false; - } -} - -RegexpMatchesBindData::RegexpMatchesBindData(duckdb_re2::RE2::Options options, string constant_string_p, - bool constant_pattern, string range_min_p, string range_max_p, - bool range_success) - : RegexpBaseBindData(options, std::move(constant_string_p), constant_pattern), range_min(std::move(range_min_p)), - range_max(std::move(range_max_p)), range_success(range_success) { -} - -unique_ptr RegexpMatchesBindData::Copy() const { - return make_uniq(options, constant_string, constant_pattern, range_min, range_max, - range_success); -} - -unique_ptr RegexpMatchesBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // pattern is the second argument. If its constant, we can already prepare the pattern and store it for later. 
- D_ASSERT(arguments.size() == 2 || arguments.size() == 3); - RE2::Options options; - options.set_log_errors(false); - if (arguments.size() == 3) { - ParseRegexOptions(context, *arguments[2], options); - } - - string constant_string; - bool constant_pattern; - constant_pattern = TryParseConstantPattern(context, *arguments[1], constant_string); - return make_uniq(options, std::move(constant_string), constant_pattern); -} - -struct RegexPartialMatch { - static inline bool Operation(const duckdb_re2::StringPiece &input, duckdb_re2::RE2 &re) { - return duckdb_re2::RE2::PartialMatch(input, re); - } -}; - -struct RegexFullMatch { - static inline bool Operation(const duckdb_re2::StringPiece &input, duckdb_re2::RE2 &re) { - return duckdb_re2::RE2::FullMatch(input, re); - } -}; - -template -static void RegexpMatchesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &strings = args.data[0]; - auto &patterns = args.data[1]; - - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - if (info.constant_pattern) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - UnaryExecutor::Execute(strings, result, args.size(), [&](string_t input) { - return OP::Operation(CreateStringPiece(input), lstate.constant_pattern); - }); - } else { - BinaryExecutor::Execute(strings, patterns, result, args.size(), - [&](string_t input, string_t pattern) { - RE2 re(CreateStringPiece(pattern), info.options); - if (!re.ok()) { - throw Exception(re.error()); - } - return OP::Operation(CreateStringPiece(input), re); - }); - } -} - -//===--------------------------------------------------------------------===// -// Regexp Replace -//===--------------------------------------------------------------------===// -RegexpReplaceBindData::RegexpReplaceBindData() : global_replace(false) { -} - -RegexpReplaceBindData::RegexpReplaceBindData(duckdb_re2::RE2::Options options, string constant_string_p, - bool constant_pattern, bool 
global_replace) - : RegexpBaseBindData(options, std::move(constant_string_p), constant_pattern), global_replace(global_replace) { -} - -unique_ptr RegexpReplaceBindData::Copy() const { - auto copy = make_uniq(options, constant_string, constant_pattern, global_replace); - return std::move(copy); -} - -bool RegexpReplaceBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return RegexpBaseBindData::Equals(other) && global_replace == other.global_replace; -} - -static unique_ptr RegexReplaceBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto data = make_uniq(); - - data->constant_pattern = TryParseConstantPattern(context, *arguments[1], data->constant_string); - if (arguments.size() == 4) { - ParseRegexOptions(context, *arguments[3], data->options, &data->global_replace); - } - data->options.set_log_errors(false); - return std::move(data); -} - -static void RegexReplaceFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - auto &strings = args.data[0]; - auto &patterns = args.data[1]; - auto &replaces = args.data[2]; - - if (info.constant_pattern) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - BinaryExecutor::Execute( - strings, replaces, result, args.size(), [&](string_t input, string_t replace) { - std::string sstring = input.GetString(); - if (info.global_replace) { - RE2::GlobalReplace(&sstring, lstate.constant_pattern, CreateStringPiece(replace)); - } else { - RE2::Replace(&sstring, lstate.constant_pattern, CreateStringPiece(replace)); - } - return StringVector::AddString(result, sstring); - }); - } else { - TernaryExecutor::Execute( - strings, patterns, replaces, result, args.size(), [&](string_t input, string_t pattern, string_t replace) { - RE2 re(CreateStringPiece(pattern), info.options); - std::string sstring = input.GetString(); - if (info.global_replace) { 
- RE2::GlobalReplace(&sstring, re, CreateStringPiece(replace)); - } else { - RE2::Replace(&sstring, re, CreateStringPiece(replace)); - } - return StringVector::AddString(result, sstring); - }); - } -} - -//===--------------------------------------------------------------------===// -// Regexp Extract -//===--------------------------------------------------------------------===// -RegexpExtractBindData::RegexpExtractBindData() { -} - -RegexpExtractBindData::RegexpExtractBindData(duckdb_re2::RE2::Options options, string constant_string_p, - bool constant_pattern, string group_string_p) - : RegexpBaseBindData(options, std::move(constant_string_p), constant_pattern), - group_string(std::move(group_string_p)), rewrite(group_string) { -} - -unique_ptr RegexpExtractBindData::Copy() const { - return make_uniq(options, constant_string, constant_pattern, group_string); -} - -bool RegexpExtractBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return RegexpBaseBindData::Equals(other) && group_string == other.group_string; -} - -static void RegexExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); - - auto &strings = args.data[0]; - auto &patterns = args.data[1]; - if (info.constant_pattern) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - UnaryExecutor::Execute(strings, result, args.size(), [&](string_t input) { - return Extract(input, result, lstate.constant_pattern, info.rewrite); - }); - } else { - BinaryExecutor::Execute(strings, patterns, result, args.size(), - [&](string_t input, string_t pattern) { - RE2 re(CreateStringPiece(pattern), info.options); - return Extract(input, result, re, info.rewrite); - }); - } -} - -//===--------------------------------------------------------------------===// -// Regexp Extract Struct -//===--------------------------------------------------------------------===// 
-static void RegexExtractStructFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - - const auto count = args.size(); - auto &input = args.data[0]; - - auto &child_entries = StructVector::GetEntries(result); - const auto groupSize = child_entries.size(); - // Reference the 'input' StringBuffer, because we won't need to allocate new data - // for the result, all returned strings are substrings of the originals - for (auto &child_entry : child_entries) { - child_entry->SetAuxiliary(input.GetAuxiliary()); - } - - vector argv(groupSize); - vector groups(groupSize); - vector ws(groupSize); - for (size_t i = 0; i < groupSize; ++i) { - groups[i] = &argv[i]; - argv[i] = &ws[i]; - } - - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - - if (ConstantVector::IsNull(input)) { - ConstantVector::SetNull(result, true); - } else { - ConstantVector::SetNull(result, false); - auto idata = ConstantVector::GetData(input); - auto str = CreateStringPiece(idata[0]); - auto match = duckdb_re2::RE2::PartialMatchN(str, lstate.constant_pattern, groups.data(), groups.size()); - for (size_t col = 0; col < child_entries.size(); ++col) { - auto &child_entry = child_entries[col]; - ConstantVector::SetNull(*child_entry, false); - auto &extracted = ws[col]; - auto cdata = ConstantVector::GetData(*child_entry); - cdata[0] = string_t(extracted.data(), match ? 
extracted.size() : 0); - } - } - } else { - UnifiedVectorFormat iunified; - input.ToUnifiedFormat(count, iunified); - - const auto &ivalidity = iunified.validity; - auto idata = UnifiedVectorFormat::GetData(iunified); - - // Start with a valid flat vector - result.SetVectorType(VectorType::FLAT_VECTOR); - - // Start with valid children - for (size_t col = 0; col < child_entries.size(); ++col) { - auto &child_entry = child_entries[col]; - child_entry->SetVectorType(VectorType::FLAT_VECTOR); - } - - for (idx_t i = 0; i < count; ++i) { - const auto idx = iunified.sel->get_index(i); - if (ivalidity.RowIsValid(idx)) { - auto str = CreateStringPiece(idata[idx]); - auto match = duckdb_re2::RE2::PartialMatchN(str, lstate.constant_pattern, groups.data(), groups.size()); - for (size_t col = 0; col < child_entries.size(); ++col) { - auto &child_entry = child_entries[col]; - auto cdata = FlatVector::GetData(*child_entry); - auto &extracted = ws[col]; - cdata[i] = string_t(extracted.data(), match ? extracted.size() : 0); - } - } else { - FlatVector::SetNull(result, i, true); - } - } - } -} - -static unique_ptr RegexExtractBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(arguments.size() >= 2); - - duckdb_re2::RE2::Options options; - - string constant_string; - bool constant_pattern = TryParseConstantPattern(context, *arguments[1], constant_string); - - if (arguments.size() >= 4) { - ParseRegexOptions(context, *arguments[3], options); - } - - string group_string = "\\0"; - if (arguments.size() >= 3) { - if (arguments[2]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[2]->IsFoldable()) { - throw InvalidInputException("Group specification field must be a constant!"); - } - Value group = ExpressionExecutor::EvaluateScalar(context, *arguments[2]); - if (group.IsNull()) { - group_string = ""; - } else if (group.type().id() == LogicalTypeId::LIST) { - if (!constant_pattern) { - throw BinderException("%s 
with LIST requires a constant pattern", bound_function.name); - } - auto &list_children = ListValue::GetChildren(group); - if (list_children.empty()) { - throw BinderException("%s requires non-empty lists of capture names", bound_function.name); - } - case_insensitive_set_t name_collision_set; - child_list_t struct_children; - for (const auto &child : list_children) { - if (child.IsNull()) { - throw BinderException("NULL group name in %s", bound_function.name); - } - const auto group_name = child.ToString(); - if (name_collision_set.find(group_name) != name_collision_set.end()) { - throw BinderException("Duplicate group name \"%s\" in %s", group_name, bound_function.name); - } - name_collision_set.insert(group_name); - struct_children.emplace_back(make_pair(group_name, LogicalType::VARCHAR)); - } - bound_function.return_type = LogicalType::STRUCT(struct_children); - - duckdb_re2::StringPiece constant_piece(constant_string.c_str(), constant_string.size()); - RE2 constant_pattern(constant_piece, options); - if (size_t(constant_pattern.NumberOfCapturingGroups()) < list_children.size()) { - throw BinderException("Not enough group names in %s", bound_function.name); - } - } else { - auto group_idx = group.GetValue(); - if (group_idx < 0 || group_idx > 9) { - throw InvalidInputException("Group index must be between 0 and 9!"); - } - group_string = "\\" + to_string(group_idx); - } - } - - return make_uniq(options, std::move(constant_string), constant_pattern, - std::move(group_string)); -} - -void RegexpFun::RegisterFunction(BuiltinFunctions &set) { - ScalarFunctionSet regexp_full_match("regexp_full_match"); - regexp_full_match.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, RegexpMatchesFunction, - RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, - FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - regexp_full_match.AddFunction(ScalarFunction( - 
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - RegexpMatchesFunction, RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, - LogicalType::INVALID, FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - - ScalarFunctionSet regexp_partial_match("regexp_matches"); - regexp_partial_match.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, RegexpMatchesFunction, - RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, - FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - regexp_partial_match.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - RegexpMatchesFunction, RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, - LogicalType::INVALID, FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - - ScalarFunctionSet regexp_replace("regexp_replace"); - regexp_replace.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::VARCHAR, RegexReplaceFunction, RegexReplaceBind, nullptr, - nullptr, RegexInitLocalState)); - regexp_replace.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - RegexReplaceFunction, RegexReplaceBind, nullptr, nullptr, RegexInitLocalState)); - - ScalarFunctionSet regexp_extract("regexp_extract"); - regexp_extract.AddFunction( - ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, RegexExtractFunction, - RegexExtractBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, - FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - regexp_extract.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::VARCHAR, 
RegexExtractFunction, - RegexExtractBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, - FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - regexp_extract.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, - RegexExtractFunction, RegexExtractBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, - FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - // REGEXP_EXTRACT(, , [[, ]...]) - regexp_extract.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::VARCHAR)}, LogicalType::VARCHAR, - RegexExtractStructFunction, RegexExtractBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, - FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - // REGEXP_EXTRACT(, , [[, ]...], ) - regexp_extract.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::VARCHAR), LogicalType::VARCHAR}, - LogicalType::VARCHAR, RegexExtractStructFunction, RegexExtractBind, nullptr, nullptr, RegexInitLocalState, - LogicalType::INVALID, FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - - ScalarFunctionSet regexp_extract_all("regexp_extract_all"); - regexp_extract_all.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::LIST(LogicalType::VARCHAR), - RegexpExtractAll::Execute, RegexpExtractAll::Bind, nullptr, nullptr, RegexpExtractAll::InitLocalState, - LogicalType::INVALID, FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - regexp_extract_all.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::LIST(LogicalType::VARCHAR), - RegexpExtractAll::Execute, RegexpExtractAll::Bind, nullptr, nullptr, RegexpExtractAll::InitLocalState, - 
LogicalType::INVALID, FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - regexp_extract_all.AddFunction( - ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, - LogicalType::LIST(LogicalType::VARCHAR), RegexpExtractAll::Execute, RegexpExtractAll::Bind, - nullptr, nullptr, RegexpExtractAll::InitLocalState, LogicalType::INVALID, - FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); - - set.AddFunction(regexp_full_match); - set.AddFunction(regexp_partial_match); - set.AddFunction(regexp_replace); - set.AddFunction(regexp_extract); - set.AddFunction(regexp_extract_all); -} - -} // namespace duckdb - - - - -namespace duckdb { - -bool StripAccentsFun::IsAscii(const char *input, idx_t n) { - for (idx_t i = 0; i < n; i++) { - if (input[i] & 0x80) { - // non-ascii character - return false; - } - } - return true; -} - -struct StripAccentsOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - if (StripAccentsFun::IsAscii(input.GetData(), input.GetSize())) { - return input; - } - - // non-ascii, perform collation - auto stripped = utf8proc_remove_accents((const utf8proc_uint8_t *)input.GetData(), input.GetSize()); - auto result_str = StringVector::AddString(result, const_char_ptr_cast(stripped)); - free(stripped); - return result_str; - } -}; - -static void StripAccentsFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); - StringVector::AddHeapReference(result, args.data[0]); -} - -ScalarFunction StripAccentsFun::GetFunction() { - return ScalarFunction("strip_accents", {LogicalType::VARCHAR}, LogicalType::VARCHAR, StripAccentsFunction); -} - -void StripAccentsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(StripAccentsFun::GetFunction()); -} - -} // namespace duckdb - - - - - - - - - - - 
-namespace duckdb { - -static const int64_t SUPPORTED_UPPER_BOUND = NumericLimits::Maximum(); -static const int64_t SUPPORTED_LOWER_BOUND = -SUPPORTED_UPPER_BOUND - 1; - -static inline void AssertInSupportedRange(idx_t input_size, int64_t offset, int64_t length) { - - if (input_size > (uint64_t)SUPPORTED_UPPER_BOUND) { - throw OutOfRangeException("Substring input size is too large (> %d)", SUPPORTED_UPPER_BOUND); - } - if (offset < SUPPORTED_LOWER_BOUND) { - throw OutOfRangeException("Substring offset outside of supported range (< %d)", SUPPORTED_LOWER_BOUND); - } - if (offset > SUPPORTED_UPPER_BOUND) { - throw OutOfRangeException("Substring offset outside of supported range (> %d)", SUPPORTED_UPPER_BOUND); - } - if (length < SUPPORTED_LOWER_BOUND) { - throw OutOfRangeException("Substring length outside of supported range (< %d)", SUPPORTED_LOWER_BOUND); - } - if (length > SUPPORTED_UPPER_BOUND) { - throw OutOfRangeException("Substring length outside of supported range (> %d)", SUPPORTED_UPPER_BOUND); - } -} - -string_t SubstringEmptyString(Vector &result) { - auto result_string = StringVector::EmptyString(result, 0); - result_string.Finalize(); - return result_string; -} - -string_t SubstringSlice(Vector &result, const char *input_data, int64_t offset, int64_t length) { - auto result_string = StringVector::EmptyString(result, length); - auto result_data = result_string.GetDataWriteable(); - memcpy(result_data, input_data + offset, length); - result_string.Finalize(); - return result_string; -} - -// compute start and end characters from the given input size and offset/length -bool SubstringStartEnd(int64_t input_size, int64_t offset, int64_t length, int64_t &start, int64_t &end) { - if (length == 0) { - return false; - } - if (offset > 0) { - // positive offset: scan from start - start = MinValue(input_size, offset - 1); - } else if (offset < 0) { - // negative offset: scan from end (i.e. 
start = end + offset) - start = MaxValue(input_size + offset, 0); - } else { - // offset = 0: special case, we start 1 character BEHIND the first character - start = 0; - length--; - if (length <= 0) { - return false; - } - } - if (length > 0) { - // positive length: go forward (i.e. end = start + offset) - end = MinValue(input_size, start + length); - } else { - // negative length: go backwards (i.e. end = start, start = start + length) - end = start; - start = MaxValue(0, start + length); - } - if (start == end) { - return false; - } - D_ASSERT(start < end); - return true; -} - -string_t SubstringASCII(Vector &result, string_t input, int64_t offset, int64_t length) { - auto input_data = input.GetData(); - auto input_size = input.GetSize(); - - AssertInSupportedRange(input_size, offset, length); - - int64_t start, end; - if (!SubstringStartEnd(input_size, offset, length, start, end)) { - return SubstringEmptyString(result); - } - return SubstringSlice(result, input_data, start, end - start); -} - -string_t SubstringFun::SubstringUnicode(Vector &result, string_t input, int64_t offset, int64_t length) { - auto input_data = input.GetData(); - auto input_size = input.GetSize(); - - AssertInSupportedRange(input_size, offset, length); - - if (length == 0) { - return SubstringEmptyString(result); - } - // first figure out which direction we need to scan - idx_t start_pos; - idx_t end_pos; - if (offset < 0) { - start_pos = 0; - end_pos = DConstants::INVALID_INDEX; - - // negative offset: scan backwards - int64_t start, end; - - // we express start and end as unicode codepoints from the back - offset--; - if (length < 0) { - // negative length - start = -offset - length; - end = -offset; - } else { - // positive length - start = -offset; - end = -offset - length; - } - if (end <= 0) { - end_pos = input_size; - } - int64_t current_character = 0; - for (idx_t i = input_size; i > 0; i--) { - if (LengthFun::IsCharacter(input_data[i - 1])) { - current_character++; - if 
(current_character == start) { - start_pos = i; - break; - } else if (current_character == end) { - end_pos = i; - } - } - } - while (!LengthFun::IsCharacter(input_data[start_pos])) { - start_pos++; - } - while (end_pos < input_size && !LengthFun::IsCharacter(input_data[end_pos])) { - end_pos++; - } - - if (end_pos == DConstants::INVALID_INDEX) { - return SubstringEmptyString(result); - } - } else { - start_pos = DConstants::INVALID_INDEX; - end_pos = input_size; - - // positive offset: scan forwards - int64_t start, end; - - // we express start and end as unicode codepoints from the front - offset--; - if (length < 0) { - // negative length - start = MaxValue(0, offset + length); - end = offset; - } else { - // positive length - start = MaxValue(0, offset); - end = offset + length; - } - - int64_t current_character = 0; - for (idx_t i = 0; i < input_size; i++) { - if (LengthFun::IsCharacter(input_data[i])) { - if (current_character == start) { - start_pos = i; - } else if (current_character == end) { - end_pos = i; - break; - } - current_character++; - } - } - if (start_pos == DConstants::INVALID_INDEX || end == 0 || end <= start) { - return SubstringEmptyString(result); - } - } - D_ASSERT(end_pos >= start_pos); - // after we have found these, we can slice the substring - return SubstringSlice(result, input_data, start_pos, end_pos - start_pos); -} - -string_t SubstringFun::SubstringGrapheme(Vector &result, string_t input, int64_t offset, int64_t length) { - auto input_data = input.GetData(); - auto input_size = input.GetSize(); - - AssertInSupportedRange(input_size, offset, length); - - // we don't know yet if the substring is ascii, but we assume it is (for now) - // first get the start and end as if this was an ascii string - int64_t start, end; - if (!SubstringStartEnd(input_size, offset, length, start, end)) { - return SubstringEmptyString(result); - } - - // now check if all the characters between 0 and end are ascii characters - // note that we scan one 
further to check for a potential combining diacritics (e.g. i + diacritic is ï) - bool is_ascii = true; - idx_t ascii_end = MinValue(end + 1, input_size); - for (idx_t i = 0; i < ascii_end; i++) { - if (input_data[i] & 0x80) { - // found a non-ascii character: eek - is_ascii = false; - break; - } - } - if (is_ascii) { - // all characters are ascii, we can just slice the substring - return SubstringSlice(result, input_data, start, end - start); - } - // if the characters are not ascii, we need to scan grapheme clusters - // first figure out which direction we need to scan - // offset = 0 case is taken care of in SubstringStartEnd - if (offset < 0) { - // negative offset, this case is more difficult - // we first need to count the number of characters in the string - idx_t num_characters = 0; - utf8proc_grapheme_callback(input_data, input_size, [&](size_t start, size_t end) { - num_characters++; - return true; - }); - // now call substring start and end again, but with the number of unicode characters this time - SubstringStartEnd(num_characters, offset, length, start, end); - } - - // now scan the graphemes of the string to find the positions of the start and end characters - int64_t current_character = 0; - idx_t start_pos = DConstants::INVALID_INDEX, end_pos = input_size; - utf8proc_grapheme_callback(input_data, input_size, [&](size_t gstart, size_t gend) { - if (current_character == start) { - start_pos = gstart; - } else if (current_character == end) { - end_pos = gstart; - return false; - } - current_character++; - return true; - }); - if (start_pos == DConstants::INVALID_INDEX) { - return SubstringEmptyString(result); - } - // after we have found these, we can slice the substring - return SubstringSlice(result, input_data, start_pos, end_pos - start_pos); -} - -struct SubstringUnicodeOp { - static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) { - return SubstringFun::SubstringUnicode(result, input, offset, length); - } -}; 
- -struct SubstringGraphemeOp { - static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) { - return SubstringFun::SubstringGrapheme(result, input, offset, length); - } -}; - -template -static void SubstringFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input_vector = args.data[0]; - auto &offset_vector = args.data[1]; - if (args.ColumnCount() == 3) { - auto &length_vector = args.data[2]; - - TernaryExecutor::Execute( - input_vector, offset_vector, length_vector, result, args.size(), - [&](string_t input_string, int64_t offset, int64_t length) { - return OP::Substring(result, input_string, offset, length); - }); - } else { - BinaryExecutor::Execute( - input_vector, offset_vector, result, args.size(), [&](string_t input_string, int64_t offset) { - return OP::Substring(result, input_string, offset, NumericLimits::Maximum()); - }); - } -} - -static void SubstringFunctionASCII(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input_vector = args.data[0]; - auto &offset_vector = args.data[1]; - if (args.ColumnCount() == 3) { - auto &length_vector = args.data[2]; - - TernaryExecutor::Execute( - input_vector, offset_vector, length_vector, result, args.size(), - [&](string_t input_string, int64_t offset, int64_t length) { - return SubstringASCII(result, input_string, offset, length); - }); - } else { - BinaryExecutor::Execute( - input_vector, offset_vector, result, args.size(), [&](string_t input_string, int64_t offset) { - return SubstringASCII(result, input_string, offset, NumericLimits::Maximum()); - }); - } -} - -static unique_ptr SubstringPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - // can only propagate stats if the children have stats - // we only care about the stats of the first child (i.e. 
the string) - if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = SubstringFunctionASCII; - } - return nullptr; -} - -void SubstringFun::RegisterFunction(BuiltinFunctions &set) { - ScalarFunctionSet substr("substring"); - substr.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::VARCHAR, SubstringFunction, nullptr, nullptr, - SubstringPropagateStats)); - substr.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, - SubstringFunction, nullptr, nullptr, - SubstringPropagateStats)); - set.AddFunction(substr); - substr.name = "substr"; - set.AddFunction(substr); - - ScalarFunctionSet substr_grapheme("substring_grapheme"); - substr_grapheme.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::VARCHAR, SubstringFunction, nullptr, - nullptr, SubstringPropagateStats)); - substr_grapheme.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, - SubstringFunction, nullptr, nullptr, - SubstringPropagateStats)); - set.AddFunction(substr_grapheme); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -static bool SuffixFunction(const string_t &str, const string_t &suffix); - -struct SuffixOperator { - template - static inline TR Operation(TA left, TB right) { - return SuffixFunction(left, right); - } -}; - -static bool SuffixFunction(const string_t &str, const string_t &suffix) { - auto suffix_size = suffix.GetSize(); - auto str_size = str.GetSize(); - if (suffix_size > str_size) { - return false; - } - - auto suffix_data = suffix.GetData(); - auto str_data = str.GetData(); - int32_t suf_idx = suffix_size - 1; - idx_t str_idx = str_size - 1; - for (; suf_idx >= 0; --suf_idx, --str_idx) { - if (suffix_data[suf_idx] != str_data[str_idx]) { - return false; - } - } - return true; -} - -ScalarFunction SuffixFun::GetFunction() { - return 
ScalarFunction("suffix", // name of the function - {LogicalType::VARCHAR, LogicalType::VARCHAR}, // argument list - LogicalType::BOOLEAN, // return type - ScalarFunction::BinaryFunction); -} - -void SuffixFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction({"suffix", "ends_with"}, GetFunction()); -} - -} // namespace duckdb - - -namespace duckdb { - -void BuiltinFunctions::RegisterStringFunctions() { - Register(); - Register(); - Register(); - Register(); - Register(); - Register(); - Register(); - Register(); - Register(); - Register(); - Register(); - Register(); - Register(); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -struct StructExtractBindData : public FunctionData { - StructExtractBindData(string key, idx_t index, LogicalType type) - : key(std::move(key)), index(index), type(std::move(type)) { - } - - string key; - idx_t index; - LogicalType type; - -public: - unique_ptr Copy() const override { - return make_uniq(key, index, type); - } - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return key == other.key && index == other.index && type == other.type; - } -}; - -static void StructExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - // this should be guaranteed by the binder - auto &vec = args.data[0]; - - vec.Verify(args.size()); - auto &children = StructVector::GetEntries(vec); - D_ASSERT(info.index < children.size()); - auto &struct_child = children[info.index]; - result.Reference(*struct_child); - result.Verify(args.size()); -} - -static unique_ptr StructExtractBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - D_ASSERT(LogicalTypeId::STRUCT == 
arguments[0]->return_type.id()); - auto &struct_children = StructType::GetChildTypes(arguments[0]->return_type); - if (struct_children.empty()) { - throw InternalException("Can't extract something from an empty struct"); - } - bound_function.arguments[0] = arguments[0]->return_type; - - auto &key_child = arguments[1]; - if (key_child->HasParameter()) { - throw ParameterNotResolvedException(); - } - - if (key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) { - throw BinderException("Key name for struct_extract needs to be a constant string"); - } - Value key_val = ExpressionExecutor::EvaluateScalar(context, *key_child); - D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR); - auto &key_str = StringValue::Get(key_val); - if (key_val.IsNull() || key_str.empty()) { - throw BinderException("Key name for struct_extract needs to be neither NULL nor empty"); - } - string key = StringUtil::Lower(key_str); - - LogicalType return_type; - idx_t key_index = 0; - bool found_key = false; - - for (size_t i = 0; i < struct_children.size(); i++) { - auto &child = struct_children[i]; - if (StringUtil::Lower(child.first) == key) { - found_key = true; - key_index = i; - return_type = child.second; - break; - } - } - - if (!found_key) { - vector candidates; - candidates.reserve(struct_children.size()); - for (auto &struct_child : struct_children) { - candidates.push_back(struct_child.first); - } - auto closest_settings = StringUtil::TopNLevenshtein(candidates, key); - auto message = StringUtil::CandidatesMessage(closest_settings, "Candidate Entries"); - throw BinderException("Could not find key \"%s\" in struct\n%s", key, message); - } - - bound_function.return_type = return_type; - return make_uniq(std::move(key), key_index, std::move(return_type)); -} - -static unique_ptr PropagateStructExtractStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &bind_data = input.bind_data; - - auto &info = 
bind_data->Cast(); - auto struct_child_stats = StructStats::GetChildStats(child_stats[0]); - return struct_child_stats[info.index].ToUnique(); -} - -ScalarFunction StructExtractFun::GetFunction() { - return ScalarFunction("struct_extract", {LogicalTypeId::STRUCT, LogicalType::VARCHAR}, LogicalType::ANY, - StructExtractFunction, StructExtractBind, nullptr, PropagateStructExtractStats); -} - -void StructExtractFun::RegisterFunction(BuiltinFunctions &set) { - // the arguments and return types are actually set in the binder function - auto fun = GetFunction(); - set.AddFunction(fun); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -// aggregate state export -struct ExportAggregateBindData : public FunctionData { - AggregateFunction aggr; - idx_t state_size; - - explicit ExportAggregateBindData(AggregateFunction aggr_p, idx_t state_size_p) - : aggr(std::move(aggr_p)), state_size(state_size_p) { - } - - unique_ptr Copy() const override { - return make_uniq(aggr, state_size); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return aggr == other.aggr && state_size == other.state_size; - } - - static ExportAggregateBindData &GetFrom(ExpressionState &state) { - auto &func_expr = state.expr.Cast(); - return func_expr.bind_info->Cast(); - } -}; - -struct CombineState : public FunctionLocalState { - idx_t state_size; - - unsafe_unique_array state_buffer0, state_buffer1; - Vector state_vector0, state_vector1; - - ArenaAllocator allocator; - - explicit CombineState(idx_t state_size_p) - : state_size(state_size_p), state_buffer0(make_unsafe_uniq_array(state_size_p)), - state_buffer1(make_unsafe_uniq_array(state_size_p)), - state_vector0(Value::POINTER(CastPointerToValue(state_buffer0.get()))), - state_vector1(Value::POINTER(CastPointerToValue(state_buffer1.get()))), - allocator(Allocator::DefaultAllocator()) { - } -}; - -static unique_ptr InitCombineState(ExpressionState &state, const BoundFunctionExpression 
&expr, - FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - return make_uniq(bind_data.state_size); -} - -struct FinalizeState : public FunctionLocalState { - idx_t state_size; - unsafe_unique_array state_buffer; - Vector addresses; - - ArenaAllocator allocator; - - explicit FinalizeState(idx_t state_size_p) - : state_size(state_size_p), - state_buffer(make_unsafe_uniq_array(STANDARD_VECTOR_SIZE * AlignValue(state_size_p))), - addresses(LogicalType::POINTER), allocator(Allocator::DefaultAllocator()) { - } -}; - -static unique_ptr InitFinalizeState(ExpressionState &state, const BoundFunctionExpression &expr, - FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - return make_uniq(bind_data.state_size); -} - -static void AggregateStateFinalize(DataChunk &input, ExpressionState &state_p, Vector &result) { - auto &bind_data = ExportAggregateBindData::GetFrom(state_p); - auto &local_state = ExecuteFunctionState::GetFunctionState(state_p)->Cast(); - local_state.allocator.Reset(); - - D_ASSERT(bind_data.state_size == bind_data.aggr.state_size()); - D_ASSERT(input.data.size() == 1); - D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE); - auto aligned_state_size = AlignValue(bind_data.state_size); - - auto state_vec_ptr = FlatVector::GetData(local_state.addresses); - - UnifiedVectorFormat state_data; - input.data[0].ToUnifiedFormat(input.size(), state_data); - for (idx_t i = 0; i < input.size(); i++) { - auto state_idx = state_data.sel->get_index(i); - auto state_entry = UnifiedVectorFormat::GetData(state_data) + state_idx; - auto target_ptr = char_ptr_cast(local_state.state_buffer.get()) + aligned_state_size * i; - - if (state_data.validity.RowIsValid(state_idx)) { - D_ASSERT(state_entry->GetSize() == bind_data.state_size); - memcpy((void *)target_ptr, state_entry->GetData(), bind_data.state_size); - } else { - // create a dummy state because finalize does not understand NULLs in its input - // we put the NULL 
back in explicitly below - bind_data.aggr.initialize(data_ptr_cast(target_ptr)); - } - state_vec_ptr[i] = data_ptr_cast(target_ptr); - } - - AggregateInputData aggr_input_data(nullptr, local_state.allocator); - bind_data.aggr.finalize(local_state.addresses, aggr_input_data, result, input.size(), 0); - - for (idx_t i = 0; i < input.size(); i++) { - auto state_idx = state_data.sel->get_index(i); - if (!state_data.validity.RowIsValid(state_idx)) { - FlatVector::SetNull(result, i, true); - } - } -} - -static void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &result) { - auto &bind_data = ExportAggregateBindData::GetFrom(state_p); - auto &local_state = ExecuteFunctionState::GetFunctionState(state_p)->Cast(); - local_state.allocator.Reset(); - - D_ASSERT(bind_data.state_size == bind_data.aggr.state_size()); - - D_ASSERT(input.data.size() == 2); - D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE); - D_ASSERT(input.data[0].GetType() == result.GetType()); - - if (input.data[0].GetType().InternalType() != input.data[1].GetType().InternalType()) { - throw IOException("Aggregate state combine type mismatch, expect %s, got %s", - input.data[0].GetType().ToString(), input.data[1].GetType().ToString()); - } - - UnifiedVectorFormat state0_data, state1_data; - input.data[0].ToUnifiedFormat(input.size(), state0_data); - input.data[1].ToUnifiedFormat(input.size(), state1_data); - - auto result_ptr = FlatVector::GetData(result); - - for (idx_t i = 0; i < input.size(); i++) { - auto state0_idx = state0_data.sel->get_index(i); - auto state1_idx = state1_data.sel->get_index(i); - - auto &state0 = UnifiedVectorFormat::GetData(state0_data)[state0_idx]; - auto &state1 = UnifiedVectorFormat::GetData(state1_data)[state1_idx]; - - // if both are NULL, we return NULL. 
If either of them is not, the result is that one - if (!state0_data.validity.RowIsValid(state0_idx) && !state1_data.validity.RowIsValid(state1_idx)) { - FlatVector::SetNull(result, i, true); - continue; - } - if (state0_data.validity.RowIsValid(state0_idx) && !state1_data.validity.RowIsValid(state1_idx)) { - result_ptr[i] = - StringVector::AddStringOrBlob(result, const_char_ptr_cast(state0.GetData()), bind_data.state_size); - continue; - } - if (!state0_data.validity.RowIsValid(state0_idx) && state1_data.validity.RowIsValid(state1_idx)) { - result_ptr[i] = - StringVector::AddStringOrBlob(result, const_char_ptr_cast(state1.GetData()), bind_data.state_size); - continue; - } - - // we actually have to combine - if (state0.GetSize() != bind_data.state_size || state1.GetSize() != bind_data.state_size) { - throw IOException("Aggregate state size mismatch, expect %llu, got %llu and %llu", bind_data.state_size, - state0.GetSize(), state1.GetSize()); - } - - memcpy(local_state.state_buffer0.get(), state0.GetData(), bind_data.state_size); - memcpy(local_state.state_buffer1.get(), state1.GetData(), bind_data.state_size); - - AggregateInputData aggr_input_data(nullptr, local_state.allocator); - bind_data.aggr.combine(local_state.state_vector0, local_state.state_vector1, aggr_input_data, 1); - - result_ptr[i] = StringVector::AddStringOrBlob(result, const_char_ptr_cast(local_state.state_buffer1.get()), - bind_data.state_size); - } -} - -static unique_ptr BindAggregateState(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - // grab the aggregate type and bind the aggregate again - - // the aggregate name and types are in the logical type of the aggregate state, make sure its sane - auto &arg_return_type = arguments[0]->return_type; - for (auto &arg_type : bound_function.arguments) { - arg_type = arg_return_type; - } - - if (arg_return_type.id() != LogicalTypeId::AGGREGATE_STATE) { - throw BinderException("Can only FINALIZE aggregate state, not 
%s", arg_return_type.ToString()); - } - // combine - if (arguments.size() == 2 && arguments[0]->return_type != arguments[1]->return_type && - arguments[1]->return_type.id() != LogicalTypeId::BLOB) { - throw BinderException("Cannot COMBINE aggregate states from different functions, %s <> %s", - arguments[0]->return_type.ToString(), arguments[1]->return_type.ToString()); - } - - // following error states are only reachable when someone messes up creating the state_type - // which is impossible from SQL - - auto state_type = AggregateStateType::GetStateType(arg_return_type); - - // now we can look up the function in the catalog again and bind it - auto &func = Catalog::GetSystemCatalog(context).GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, - DEFAULT_SCHEMA, state_type.function_name); - if (func.type != CatalogType::AGGREGATE_FUNCTION_ENTRY) { - throw InternalException("Could not find aggregate %s", state_type.function_name); - } - auto &aggr = func.Cast(); - - string error; - - FunctionBinder function_binder(context); - idx_t best_function = - function_binder.BindFunction(aggr.name, aggr.functions, state_type.bound_argument_types, error); - if (best_function == DConstants::INVALID_INDEX) { - throw InternalException("Could not re-bind exported aggregate %s: %s", state_type.function_name, error); - } - auto bound_aggr = aggr.functions.GetFunctionByOffset(best_function); - if (bound_aggr.bind) { - // FIXME: this is really hacky - // but the aggregate state export needs a rework around how it handles more complex aggregates anyway - vector> args; - args.reserve(state_type.bound_argument_types.size()); - for (auto &arg_type : state_type.bound_argument_types) { - args.push_back(make_uniq(Value(arg_type))); - } - auto bind_info = bound_aggr.bind(context, bound_aggr, args); - if (bind_info) { - throw BinderException("Aggregate function with bind info not supported yet in aggregate state export"); - } - } - - if (bound_aggr.return_type != state_type.return_type || 
bound_aggr.arguments != state_type.bound_argument_types) { - throw InternalException("Type mismatch for exported aggregate %s", state_type.function_name); - } - - if (bound_function.name == "finalize") { - bound_function.return_type = bound_aggr.return_type; - } else { - D_ASSERT(bound_function.name == "combine"); - bound_function.return_type = arg_return_type; - } - - return make_uniq(bound_aggr, bound_aggr.state_size()); -} - -static void ExportAggregateFinalize(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count, - idx_t offset) { - D_ASSERT(offset == 0); - auto &bind_data = aggr_input_data.bind_data->Cast(); - auto state_size = bind_data.aggregate->function.state_size(); - auto blob_ptr = FlatVector::GetData(result); - auto addresses_ptr = FlatVector::GetData(state); - for (idx_t row_idx = 0; row_idx < count; row_idx++) { - auto data_ptr = addresses_ptr[row_idx]; - blob_ptr[row_idx] = StringVector::AddStringOrBlob(result, const_char_ptr_cast(data_ptr), state_size); - } -} - -ExportAggregateFunctionBindData::ExportAggregateFunctionBindData(unique_ptr aggregate_p) { - D_ASSERT(aggregate_p->type == ExpressionType::BOUND_AGGREGATE); - aggregate = unique_ptr_cast(std::move(aggregate_p)); -} - -unique_ptr ExportAggregateFunctionBindData::Copy() const { - return make_uniq(aggregate->Copy()); -} - -bool ExportAggregateFunctionBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return aggregate->Equals(*other.aggregate); -} - -static void ExportStateAggregateSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - throw NotImplementedException("FIXME: export state serialize"); -} - -static unique_ptr ExportStateAggregateDeserialize(Deserializer &deserializer, - AggregateFunction &function) { - throw NotImplementedException("FIXME: export state deserialize"); -} - -static void ExportStateScalarSerialize(Serializer &serializer, const optional_ptr bind_data_p, 
- const ScalarFunction &function) { - throw NotImplementedException("FIXME: export state serialize"); -} - -static unique_ptr ExportStateScalarDeserialize(Deserializer &deserializer, ScalarFunction &function) { - throw NotImplementedException("FIXME: export state deserialize"); -} - -unique_ptr -ExportAggregateFunction::Bind(unique_ptr child_aggregate) { - auto &bound_function = child_aggregate->function; - if (!bound_function.combine) { - throw BinderException("Cannot use EXPORT_STATE for non-combinable function %s", bound_function.name); - } - if (bound_function.bind) { - throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom binders"); - } - if (bound_function.destructor) { - throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom destructors"); - } - // this should be required - D_ASSERT(bound_function.state_size); - D_ASSERT(bound_function.finalize); - - D_ASSERT(child_aggregate->function.return_type.id() != LogicalTypeId::INVALID); -#ifdef DEBUG - for (auto &arg_type : child_aggregate->function.arguments) { - D_ASSERT(arg_type.id() != LogicalTypeId::INVALID); - } -#endif - auto export_bind_data = make_uniq(child_aggregate->Copy()); - aggregate_state_t state_type(child_aggregate->function.name, child_aggregate->function.return_type, - child_aggregate->function.arguments); - auto return_type = LogicalType::AGGREGATE_STATE(std::move(state_type)); - - auto export_function = - AggregateFunction("aggregate_state_export_" + bound_function.name, bound_function.arguments, return_type, - bound_function.state_size, bound_function.initialize, bound_function.update, - bound_function.combine, ExportAggregateFinalize, bound_function.simple_update, - /* can't bind this again */ nullptr, /* no dynamic state yet */ nullptr, - /* can't propagate statistics */ nullptr, nullptr); - export_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - export_function.serialize = ExportStateAggregateSerialize; - 
export_function.deserialize = ExportStateAggregateDeserialize; - - return make_uniq(export_function, std::move(child_aggregate->children), - std::move(child_aggregate->filter), std::move(export_bind_data), - child_aggregate->aggr_type); -} - -ScalarFunction ExportAggregateFunction::GetFinalize() { - auto result = ScalarFunction("finalize", {LogicalTypeId::AGGREGATE_STATE}, LogicalTypeId::INVALID, - AggregateStateFinalize, BindAggregateState, nullptr, nullptr, InitFinalizeState); - result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - result.serialize = ExportStateScalarSerialize; - result.deserialize = ExportStateScalarDeserialize; - return result; -} - -ScalarFunction ExportAggregateFunction::GetCombine() { - auto result = - ScalarFunction("combine", {LogicalTypeId::AGGREGATE_STATE, LogicalTypeId::ANY}, LogicalTypeId::AGGREGATE_STATE, - AggregateStateCombine, BindAggregateState, nullptr, nullptr, InitCombineState); - result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - result.serialize = ExportStateScalarSerialize; - result.deserialize = ExportStateScalarDeserialize; - return result; -} - -void ExportAggregateFunction::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(ExportAggregateFunction::GetCombine()); - set.AddFunction(ExportAggregateFunction::GetFinalize()); -} - -} // namespace duckdb - - -namespace duckdb { - -FunctionLocalState::~FunctionLocalState() { -} - -ScalarFunction::ScalarFunction(string name, vector arguments, LogicalType return_type, - scalar_function_t function, bind_scalar_function_t bind, - dependency_function_t dependency, function_statistics_t statistics, - init_local_state_t init_local_state, LogicalType varargs, - FunctionSideEffects side_effects, FunctionNullHandling null_handling) - : BaseScalarFunction(std::move(name), std::move(arguments), std::move(return_type), side_effects, - std::move(varargs), null_handling), - function(std::move(function)), bind(bind), init_local_state(init_local_state), 
dependency(dependency), - statistics(statistics), serialize(nullptr), deserialize(nullptr) { -} - -ScalarFunction::ScalarFunction(vector arguments, LogicalType return_type, scalar_function_t function, - bind_scalar_function_t bind, dependency_function_t dependency, - function_statistics_t statistics, init_local_state_t init_local_state, - LogicalType varargs, FunctionSideEffects side_effects, - FunctionNullHandling null_handling) - : ScalarFunction(string(), std::move(arguments), std::move(return_type), std::move(function), bind, dependency, - statistics, init_local_state, std::move(varargs), side_effects, null_handling) { -} - -bool ScalarFunction::operator==(const ScalarFunction &rhs) const { - return name == rhs.name && arguments == rhs.arguments && return_type == rhs.return_type && varargs == rhs.varargs && - bind == rhs.bind && dependency == rhs.dependency && statistics == rhs.statistics; -} - -bool ScalarFunction::operator!=(const ScalarFunction &rhs) const { - return !(*this == rhs); -} - -bool ScalarFunction::Equal(const ScalarFunction &rhs) const { - // number of types - if (this->arguments.size() != rhs.arguments.size()) { - return false; - } - // argument types - for (idx_t i = 0; i < this->arguments.size(); ++i) { - if (this->arguments[i] != rhs.arguments[i]) { - return false; - } - } - // return type - if (this->return_type != rhs.return_type) { - return false; - } - // varargs - if (this->varargs != rhs.varargs) { - return false; - } - - return true; // they are equal -} - -void ScalarFunction::NopFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() >= 1); - result.Reference(input.data[0]); -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/function/scalar_macro_function.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - 
-ScalarMacroFunction::ScalarMacroFunction(unique_ptr expression) - : MacroFunction(MacroType::SCALAR_MACRO), expression(std::move(expression)) { -} - -ScalarMacroFunction::ScalarMacroFunction(void) : MacroFunction(MacroType::SCALAR_MACRO) { -} - -unique_ptr ScalarMacroFunction::Copy() const { - auto result = make_uniq(); - result->expression = expression->Copy(); - CopyProperties(*result); - - return std::move(result); -} - -void RemoveQualificationRecursive(unique_ptr &expr) { - if (expr->GetExpressionType() == ExpressionType::COLUMN_REF) { - auto &col_ref = expr->Cast(); - auto &col_names = col_ref.column_names; - if (col_names.size() == 2 && col_names[0].find(DummyBinding::DUMMY_NAME) != string::npos) { - col_names.erase(col_names.begin()); - } - } else { - ParsedExpressionIterator::EnumerateChildren( - *expr, [](unique_ptr &child) { RemoveQualificationRecursive(child); }); - } -} - -string ScalarMacroFunction::ToSQL(const string &schema, const string &name) const { - // In case of nested macro's we need to fix it a bit - auto expression_copy = expression->Copy(); - RemoveQualificationRecursive(expression_copy); - return MacroFunction::ToSQL(schema, name) + StringUtil::Format("(%s);", expression_copy->ToString()); -} - -} // namespace duckdb - - - - -namespace duckdb { - -void ArrowTableType::AddColumn(idx_t index, unique_ptr type) { - D_ASSERT(arrow_convert_data.find(index) == arrow_convert_data.end()); - arrow_convert_data.emplace(std::make_pair(index, std::move(type))); -} - -const arrow_column_map_t &ArrowTableType::GetColumns() const { - return arrow_convert_data; -} - -void ArrowType::AddChild(unique_ptr child) { - children.emplace_back(std::move(child)); -} - -void ArrowType::AssignChildren(vector> children) { - D_ASSERT(this->children.empty()); - this->children = std::move(children); -} - -void ArrowType::SetDictionary(unique_ptr dictionary) { - D_ASSERT(!this->dictionary_type); - dictionary_type = std::move(dictionary); -} - -const ArrowType 
&ArrowType::GetDictionary() const { - D_ASSERT(dictionary_type); - return *dictionary_type; -} - -const LogicalType &ArrowType::GetDuckType() const { - return type; -} - -ArrowVariableSizeType ArrowType::GetSizeType() const { - return size_type; -} - -ArrowDateTimeType ArrowType::GetDateTimeType() const { - return date_time_precision; -} - -const ArrowType &ArrowType::operator[](idx_t index) const { - D_ASSERT(index < children.size()); - return *children[index]; -} - -idx_t ArrowType::FixedSize() const { - D_ASSERT(size_type == ArrowVariableSizeType::FIXED_SIZE); - return fixed_size; -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -unique_ptr ArrowTableFunction::GetArrowLogicalType(ArrowSchema &schema) { - auto format = string(schema.format); - if (format == "n") { - return make_uniq(LogicalType::SQLNULL); - } else if (format == "b") { - return make_uniq(LogicalType::BOOLEAN); - } else if (format == "c") { - return make_uniq(LogicalType::TINYINT); - } else if (format == "s") { - return make_uniq(LogicalType::SMALLINT); - } else if (format == "i") { - return make_uniq(LogicalType::INTEGER); - } else if (format == "l") { - return make_uniq(LogicalType::BIGINT); - } else if (format == "C") { - return make_uniq(LogicalType::UTINYINT); - } else if (format == "S") { - return make_uniq(LogicalType::USMALLINT); - } else if (format == "I") { - return make_uniq(LogicalType::UINTEGER); - } else if (format == "L") { - return make_uniq(LogicalType::UBIGINT); - } else if (format == "f") { - return make_uniq(LogicalType::FLOAT); - } else if (format == "g") { - return make_uniq(LogicalType::DOUBLE); - } else if (format[0] == 'd') { //! 
this can be either decimal128 or decimal 256 (e.g., d:38,0) - std::string parameters = format.substr(format.find(':')); - uint8_t width = std::stoi(parameters.substr(1, parameters.find(','))); - uint8_t scale = std::stoi(parameters.substr(parameters.find(',') + 1)); - if (width > 38) { - throw NotImplementedException("Unsupported Internal Arrow Type for Decimal %s", format); - } - return make_uniq(LogicalType::DECIMAL(width, scale)); - } else if (format == "u") { - return make_uniq(LogicalType::VARCHAR, ArrowVariableSizeType::NORMAL); - } else if (format == "U") { - return make_uniq(LogicalType::VARCHAR, ArrowVariableSizeType::SUPER_SIZE); - } else if (format == "tsn:") { - return make_uniq(LogicalTypeId::TIMESTAMP_NS); - } else if (format == "tsu:") { - return make_uniq(LogicalTypeId::TIMESTAMP); - } else if (format == "tsm:") { - return make_uniq(LogicalTypeId::TIMESTAMP_MS); - } else if (format == "tss:") { - return make_uniq(LogicalTypeId::TIMESTAMP_SEC); - } else if (format == "tdD") { - return make_uniq(LogicalType::DATE, ArrowDateTimeType::DAYS); - } else if (format == "tdm") { - return make_uniq(LogicalType::DATE, ArrowDateTimeType::MILLISECONDS); - } else if (format == "tts") { - return make_uniq(LogicalType::TIME, ArrowDateTimeType::SECONDS); - } else if (format == "ttm") { - return make_uniq(LogicalType::TIME, ArrowDateTimeType::MILLISECONDS); - } else if (format == "ttu") { - return make_uniq(LogicalType::TIME, ArrowDateTimeType::MICROSECONDS); - } else if (format == "ttn") { - return make_uniq(LogicalType::TIME, ArrowDateTimeType::NANOSECONDS); - } else if (format == "tDs") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::SECONDS); - } else if (format == "tDm") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::MILLISECONDS); - } else if (format == "tDu") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::MICROSECONDS); - } else if (format == "tDn") { - return make_uniq(LogicalType::INTERVAL, 
ArrowDateTimeType::NANOSECONDS); - } else if (format == "tiD") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::DAYS); - } else if (format == "tiM") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::MONTHS); - } else if (format == "tin") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::MONTH_DAY_NANO); - } else if (format == "+l") { - auto child_type = GetArrowLogicalType(*schema.children[0]); - auto list_type = - make_uniq(LogicalType::LIST(child_type->GetDuckType()), ArrowVariableSizeType::NORMAL); - list_type->AddChild(std::move(child_type)); - return list_type; - } else if (format == "+L") { - auto child_type = GetArrowLogicalType(*schema.children[0]); - auto list_type = - make_uniq(LogicalType::LIST(child_type->GetDuckType()), ArrowVariableSizeType::SUPER_SIZE); - list_type->AddChild(std::move(child_type)); - return list_type; - } else if (format[0] == '+' && format[1] == 'w') { - std::string parameters = format.substr(format.find(':') + 1); - idx_t fixed_size = std::stoi(parameters); - auto child_type = GetArrowLogicalType(*schema.children[0]); - auto list_type = make_uniq(LogicalType::LIST(child_type->GetDuckType()), fixed_size); - list_type->AddChild(std::move(child_type)); - return list_type; - } else if (format == "+s") { - child_list_t child_types; - vector> children; - for (idx_t type_idx = 0; type_idx < (idx_t)schema.n_children; type_idx++) { - children.emplace_back(GetArrowLogicalType(*schema.children[type_idx])); - child_types.emplace_back(schema.children[type_idx]->name, children.back()->GetDuckType()); - } - auto struct_type = make_uniq(LogicalType::STRUCT(std::move(child_types))); - struct_type->AssignChildren(std::move(children)); - return struct_type; - } else if (format[0] == '+' && format[1] == 'u') { - if (format[2] != 's') { - throw NotImplementedException("Unsupported Internal Arrow Type: \"%c\" Union", format[2]); - } - D_ASSERT(format[3] == ':'); - - std::string prefix = "+us:"; - // TODO: what are 
these type ids actually for? - auto type_ids = StringUtil::Split(format.substr(prefix.size()), ','); - - child_list_t members; - vector> children; - for (idx_t type_idx = 0; type_idx < (idx_t)schema.n_children; type_idx++) { - auto type = schema.children[type_idx]; - - children.emplace_back(GetArrowLogicalType(*type)); - members.emplace_back(type->name, children.back()->GetDuckType()); - } - - auto union_type = make_uniq(LogicalType::UNION(members)); - union_type->AssignChildren(std::move(children)); - return union_type; - } else if (format == "+m") { - auto &arrow_struct_type = *schema.children[0]; - D_ASSERT(arrow_struct_type.n_children == 2); - auto key_type = GetArrowLogicalType(*arrow_struct_type.children[0]); - auto value_type = GetArrowLogicalType(*arrow_struct_type.children[1]); - auto map_type = make_uniq(LogicalType::MAP(key_type->GetDuckType(), value_type->GetDuckType()), - ArrowVariableSizeType::NORMAL); - child_list_t key_value; - key_value.emplace_back(std::make_pair("key", key_type->GetDuckType())); - key_value.emplace_back(std::make_pair("value", value_type->GetDuckType())); - - auto inner_struct = - make_uniq(LogicalType::STRUCT(std::move(key_value)), ArrowVariableSizeType::NORMAL); - vector> children; - children.reserve(2); - children.push_back(std::move(key_type)); - children.push_back(std::move(value_type)); - inner_struct->AssignChildren(std::move(children)); - map_type->AddChild(std::move(inner_struct)); - return map_type; - } else if (format == "z") { - return make_uniq(LogicalType::BLOB, ArrowVariableSizeType::NORMAL); - } else if (format == "Z") { - return make_uniq(LogicalType::BLOB, ArrowVariableSizeType::SUPER_SIZE); - } else if (format[0] == 'w') { - std::string parameters = format.substr(format.find(':') + 1); - idx_t fixed_size = std::stoi(parameters); - return make_uniq(LogicalType::BLOB, fixed_size); - } else if (format[0] == 't' && format[1] == 's') { - // Timestamp with Timezone - // TODO right now we just get the UTC value. 
We probably want to support this properly in the future - if (format[2] == 'n') { - return make_uniq(LogicalType::TIMESTAMP_TZ, ArrowDateTimeType::NANOSECONDS); - } else if (format[2] == 'u') { - return make_uniq(LogicalType::TIMESTAMP_TZ, ArrowDateTimeType::MICROSECONDS); - } else if (format[2] == 'm') { - return make_uniq(LogicalType::TIMESTAMP_TZ, ArrowDateTimeType::MILLISECONDS); - } else if (format[2] == 's') { - return make_uniq(LogicalType::TIMESTAMP_TZ, ArrowDateTimeType::SECONDS); - } else { - throw NotImplementedException(" Timestamptz precision of not accepted"); - } - } else { - throw NotImplementedException("Unsupported Internal Arrow Type %s", format); - } -} - -void ArrowTableFunction::RenameArrowColumns(vector &names) { - unordered_map name_map; - for (auto &column_name : names) { - // put it all lower_case - auto low_column_name = StringUtil::Lower(column_name); - if (name_map.find(low_column_name) == name_map.end()) { - // Name does not exist yet - name_map[low_column_name]++; - } else { - // Name already exists, we add _x where x is the repetition number - string new_column_name = column_name + "_" + std::to_string(name_map[low_column_name]); - auto new_column_name_low = StringUtil::Lower(new_column_name); - while (name_map.find(new_column_name_low) != name_map.end()) { - // This name is already here due to a previous definition - name_map[low_column_name]++; - new_column_name = column_name + "_" + std::to_string(name_map[low_column_name]); - new_column_name_low = StringUtil::Lower(new_column_name); - } - column_name = new_column_name; - name_map[new_column_name_low]++; - } - } -} - -void ArrowTableFunction::PopulateArrowTableType(ArrowTableType &arrow_table, ArrowSchemaWrapper &schema_p, - vector &names, vector &return_types) { - for (idx_t col_idx = 0; col_idx < (idx_t)schema_p.arrow_schema.n_children; col_idx++) { - auto &schema = *schema_p.arrow_schema.children[col_idx]; - if (!schema.release) { - throw InvalidInputException("arrow_scan: 
released schema passed"); - } - auto arrow_type = GetArrowLogicalType(schema); - if (schema.dictionary) { - auto logical_type = arrow_type->GetDuckType(); - auto dictionary = GetArrowLogicalType(*schema.dictionary); - return_types.emplace_back(dictionary->GetDuckType()); - // The dictionary might have different attributes (size type, datetime precision, etc..) - arrow_type->SetDictionary(std::move(dictionary)); - } else { - return_types.emplace_back(arrow_type->GetDuckType()); - } - arrow_table.AddColumn(col_idx, std::move(arrow_type)); - auto format = string(schema.format); - auto name = string(schema.name); - if (name.empty()) { - name = string("v") + to_string(col_idx); - } - names.push_back(name); - } -} - -unique_ptr ArrowTableFunction::ArrowScanBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - if (input.inputs[0].IsNull() || input.inputs[1].IsNull() || input.inputs[2].IsNull()) { - throw BinderException("arrow_scan: pointers cannot be null"); - } - - auto stream_factory_ptr = input.inputs[0].GetPointer(); - auto stream_factory_produce = (stream_factory_produce_t)input.inputs[1].GetPointer(); // NOLINT - auto stream_factory_get_schema = (stream_factory_get_schema_t)input.inputs[2].GetPointer(); // NOLINT - - auto res = make_uniq(stream_factory_produce, stream_factory_ptr); - - auto &data = *res; - stream_factory_get_schema(stream_factory_ptr, data.schema_root); - PopulateArrowTableType(res->arrow_table, data.schema_root, names, return_types); - RenameArrowColumns(names); - res->all_types = return_types; - return std::move(res); -} - -unique_ptr ProduceArrowScan(const ArrowScanFunctionData &function, - const vector &column_ids, TableFilterSet *filters) { - //! 
Generate Projection Pushdown Vector - ArrowStreamParameters parameters; - D_ASSERT(!column_ids.empty()); - for (idx_t idx = 0; idx < column_ids.size(); idx++) { - auto col_idx = column_ids[idx]; - if (col_idx != COLUMN_IDENTIFIER_ROW_ID) { - auto &schema = *function.schema_root.arrow_schema.children[col_idx]; - parameters.projected_columns.projection_map[idx] = schema.name; - parameters.projected_columns.columns.emplace_back(schema.name); - } - } - parameters.filters = filters; - return function.scanner_producer(function.stream_factory_ptr, parameters); -} - -idx_t ArrowTableFunction::ArrowScanMaxThreads(ClientContext &context, const FunctionData *bind_data_p) { - return context.db->NumberOfThreads(); -} - -bool ArrowTableFunction::ArrowScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, - ArrowScanLocalState &state, ArrowScanGlobalState ¶llel_state) { - lock_guard parallel_lock(parallel_state.main_mutex); - if (parallel_state.done) { - return false; - } - state.chunk_offset = 0; - state.batch_index = ++parallel_state.batch_index; - - auto current_chunk = parallel_state.stream->GetNextChunk(); - while (current_chunk->arrow_array.length == 0 && current_chunk->arrow_array.release) { - current_chunk = parallel_state.stream->GetNextChunk(); - } - state.chunk = std::move(current_chunk); - //! have we run out of chunks? 
we are done - if (!state.chunk->arrow_array.release) { - parallel_state.done = true; - return false; - } - return true; -} - -unique_ptr ArrowTableFunction::ArrowScanInitGlobal(ClientContext &context, - TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - auto result = make_uniq(); - result->stream = ProduceArrowScan(bind_data, input.column_ids, input.filters.get()); - result->max_threads = ArrowScanMaxThreads(context, input.bind_data.get()); - if (input.CanRemoveFilterColumns()) { - result->projection_ids = input.projection_ids; - for (const auto &col_idx : input.column_ids) { - if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { - result->scanned_types.emplace_back(LogicalType::ROW_TYPE); - } else { - result->scanned_types.push_back(bind_data.all_types[col_idx]); - } - } - } - return std::move(result); -} - -unique_ptr -ArrowTableFunction::ArrowScanInitLocalInternal(ClientContext &context, TableFunctionInitInput &input, - GlobalTableFunctionState *global_state_p) { - auto &global_state = global_state_p->Cast(); - auto current_chunk = make_uniq(); - auto result = make_uniq(std::move(current_chunk)); - result->column_ids = input.column_ids; - result->filters = input.filters.get(); - if (input.CanRemoveFilterColumns()) { - auto &asgs = global_state_p->Cast(); - result->all_columns.Initialize(context, asgs.scanned_types); - } - if (!ArrowScanParallelStateNext(context, input.bind_data.get(), *result, global_state)) { - return nullptr; - } - return std::move(result); -} - -unique_ptr ArrowTableFunction::ArrowScanInitLocal(ExecutionContext &context, - TableFunctionInitInput &input, - GlobalTableFunctionState *global_state_p) { - return ArrowScanInitLocalInternal(context.client, input, global_state_p); -} - -void ArrowTableFunction::ArrowScanFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - if (!data_p.local_state) { - return; - } - auto &data = data_p.bind_data->CastNoConst(); // FIXME - auto &state = 
data_p.local_state->Cast(); - auto &global_state = data_p.global_state->Cast(); - - //! Out of tuples in this chunk - if (state.chunk_offset >= (idx_t)state.chunk->arrow_array.length) { - if (!ArrowScanParallelStateNext(context, data_p.bind_data.get(), state, global_state)) { - return; - } - } - int64_t output_size = MinValue(STANDARD_VECTOR_SIZE, state.chunk->arrow_array.length - state.chunk_offset); - data.lines_read += output_size; - if (global_state.CanRemoveFilterColumns()) { - state.all_columns.Reset(); - state.all_columns.SetCardinality(output_size); - ArrowToDuckDB(state, data.arrow_table.GetColumns(), state.all_columns, data.lines_read - output_size); - output.ReferenceColumns(state.all_columns, global_state.projection_ids); - } else { - output.SetCardinality(output_size); - ArrowToDuckDB(state, data.arrow_table.GetColumns(), output, data.lines_read - output_size); - } - - output.Verify(); - state.chunk_offset += output.size(); -} - -unique_ptr ArrowTableFunction::ArrowScanCardinality(ClientContext &context, const FunctionData *data) { - return make_uniq(); -} - -idx_t ArrowTableFunction::ArrowGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p, - LocalTableFunctionState *local_state, - GlobalTableFunctionState *global_state) { - auto &state = local_state->Cast(); - return state.batch_index; -} - -void ArrowTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction arrow("arrow_scan", {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER}, - ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, ArrowScanInitLocal); - arrow.cardinality = ArrowScanCardinality; - arrow.get_batch_index = ArrowGetBatchIndex; - arrow.projection_pushdown = true; - arrow.filter_pushdown = true; - arrow.filter_prune = true; - set.AddFunction(arrow); - - TableFunction arrow_dumb("arrow_scan_dumb", {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER}, - ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, 
ArrowScanInitLocal); - arrow_dumb.cardinality = ArrowScanCardinality; - arrow_dumb.get_batch_index = ArrowGetBatchIndex; - arrow_dumb.projection_pushdown = false; - arrow_dumb.filter_pushdown = false; - arrow_dumb.filter_prune = false; - set.AddFunction(arrow_dumb); -} - -void BuiltinFunctions::RegisterArrowFunctions() { - ArrowTableFunction::RegisterFunction(*this); -} -} // namespace duckdb - - - - - - - -namespace duckdb { - -static void ShiftRight(unsigned char *ar, int size, int shift) { - int carry = 0; - while (shift--) { - for (int i = size - 1; i >= 0; --i) { - int next = (ar[i] & 1) ? 0x80 : 0; - ar[i] = carry | (ar[i] >> 1); - carry = next; - } - } -} - -template -T *ArrowBufferData(ArrowArray &array, idx_t buffer_idx) { - return (T *)array.buffers[buffer_idx]; // NOLINT -} - -static void GetValidityMask(ValidityMask &mask, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, - int64_t nested_offset = -1, bool add_null = false) { - // In certains we don't need to or cannot copy arrow's validity mask to duckdb. - // - // The conditions where we do want to copy arrow's mask to duckdb are: - // 1. nulls exist - // 2. n_buffers > 0, meaning the array's arrow type is not `null` - // 3. the validity buffer (the first buffer) is not a nullptr - if (array.null_count != 0 && array.n_buffers > 0 && array.buffers[0]) { - auto bit_offset = scan_state.chunk_offset + array.offset; - if (nested_offset != -1) { - bit_offset = nested_offset; - } - mask.EnsureWritable(); -#if STANDARD_VECTOR_SIZE > 64 - auto n_bitmask_bytes = (size + 8 - 1) / 8; - if (bit_offset % 8 == 0) { - //! just memcpy nullmask - memcpy((void *)mask.GetData(), ArrowBufferData(array, 0) + bit_offset / 8, n_bitmask_bytes); - } else { - //! need to re-align nullmask - vector temp_nullmask(n_bitmask_bytes + 1); - memcpy(temp_nullmask.data(), ArrowBufferData(array, 0) + bit_offset / 8, n_bitmask_bytes + 1); - ShiftRight(temp_nullmask.data(), n_bitmask_bytes + 1, - bit_offset % 8); //! 
why this has to be a right shift is a mystery to me - memcpy((void *)mask.GetData(), data_ptr_cast(temp_nullmask.data()), n_bitmask_bytes); - } -#else - auto byte_offset = bit_offset / 8; - auto source_data = ArrowBufferData(array, 0); - bit_offset %= 8; - for (idx_t i = 0; i < size; i++) { - mask.Set(i, source_data[byte_offset] & (1 << bit_offset)); - bit_offset++; - if (bit_offset == 8) { - bit_offset = 0; - byte_offset++; - } - } -#endif - } - if (add_null) { - //! We are setting a validity mask of the data part of dictionary vector - //! For some reason, Nulls are allowed to be indexes, hence we need to set the last element here to be null - //! We might have to resize the mask - mask.Resize(size, size + 1); - mask.SetInvalid(size); - } -} - -static void SetValidityMask(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, - int64_t nested_offset, bool add_null = false) { - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); - auto &mask = FlatVector::Validity(vector); - GetValidityMask(mask, array, scan_state, size, nested_offset, add_null); -} - -static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset = -1, - ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0); - -static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset, ValidityMask *parent_mask) { - auto size_type = arrow_type.GetSizeType(); - idx_t list_size = 0; - SetValidityMask(vector, array, scan_state, size, nested_offset); - idx_t start_offset = 0; - idx_t cur_offset = 0; - if (size_type == ArrowVariableSizeType::FIXED_SIZE) { - auto fixed_size = arrow_type.FixedSize(); - //! 
Have to check validity mask before setting this up - idx_t offset = (scan_state.chunk_offset + array.offset) * fixed_size; - if (nested_offset != -1) { - offset = fixed_size * nested_offset; - } - start_offset = offset; - auto list_data = FlatVector::GetData(vector); - for (idx_t i = 0; i < size; i++) { - auto &le = list_data[i]; - le.offset = cur_offset; - le.length = fixed_size; - cur_offset += fixed_size; - } - list_size = start_offset + cur_offset; - } else if (size_type == ArrowVariableSizeType::NORMAL) { - auto offsets = ArrowBufferData(array, 1) + array.offset + scan_state.chunk_offset; - if (nested_offset != -1) { - offsets = ArrowBufferData(array, 1) + nested_offset; - } - start_offset = offsets[0]; - auto list_data = FlatVector::GetData(vector); - for (idx_t i = 0; i < size; i++) { - auto &le = list_data[i]; - le.offset = cur_offset; - le.length = offsets[i + 1] - offsets[i]; - cur_offset += le.length; - } - list_size = offsets[size]; - } else { - auto offsets = ArrowBufferData(array, 1) + array.offset + scan_state.chunk_offset; - if (nested_offset != -1) { - offsets = ArrowBufferData(array, 1) + nested_offset; - } - start_offset = offsets[0]; - auto list_data = FlatVector::GetData(vector); - for (idx_t i = 0; i < size; i++) { - auto &le = list_data[i]; - le.offset = cur_offset; - le.length = offsets[i + 1] - offsets[i]; - cur_offset += le.length; - } - list_size = offsets[size]; - } - list_size -= start_offset; - ListVector::Reserve(vector, list_size); - ListVector::SetListSize(vector, list_size); - auto &child_vector = ListVector::GetEntry(vector); - SetValidityMask(child_vector, *array.children[0], scan_state, list_size, start_offset); - auto &list_mask = FlatVector::Validity(vector); - if (parent_mask) { - //! 
Since this List is owned by a struct we must guarantee their validity map matches on Null - if (!parent_mask->AllValid()) { - for (idx_t i = 0; i < size; i++) { - if (!parent_mask->RowIsValid(i)) { - list_mask.SetInvalid(i); - } - } - } - } - if (list_size == 0 && start_offset == 0) { - ColumnArrowToDuckDB(child_vector, *array.children[0], scan_state, list_size, arrow_type[0], -1); - } else { - ColumnArrowToDuckDB(child_vector, *array.children[0], scan_state, list_size, arrow_type[0], start_offset); - } -} - -static void ArrowToDuckDBBlob(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset) { - auto size_type = arrow_type.GetSizeType(); - SetValidityMask(vector, array, scan_state, size, nested_offset); - if (size_type == ArrowVariableSizeType::FIXED_SIZE) { - auto fixed_size = arrow_type.FixedSize(); - //! Have to check validity mask before setting this up - idx_t offset = (scan_state.chunk_offset + array.offset) * fixed_size; - if (nested_offset != -1) { - offset = fixed_size * nested_offset; - } - auto cdata = ArrowBufferData(array, 1); - for (idx_t row_idx = 0; row_idx < size; row_idx++) { - if (FlatVector::IsNull(vector, row_idx)) { - continue; - } - auto bptr = cdata + offset; - auto blob_len = fixed_size; - FlatVector::GetData(vector)[row_idx] = StringVector::AddStringOrBlob(vector, bptr, blob_len); - offset += blob_len; - } - } else if (size_type == ArrowVariableSizeType::NORMAL) { - auto offsets = ArrowBufferData(array, 1) + array.offset + scan_state.chunk_offset; - if (nested_offset != -1) { - offsets = ArrowBufferData(array, 1) + array.offset + nested_offset; - } - auto cdata = ArrowBufferData(array, 2); - for (idx_t row_idx = 0; row_idx < size; row_idx++) { - if (FlatVector::IsNull(vector, row_idx)) { - continue; - } - auto bptr = cdata + offsets[row_idx]; - auto blob_len = offsets[row_idx + 1] - offsets[row_idx]; - FlatVector::GetData(vector)[row_idx] = 
StringVector::AddStringOrBlob(vector, bptr, blob_len); - } - } else { - //! Check if last offset is higher than max uint32 - if (ArrowBufferData(array, 1)[array.length] > NumericLimits::Maximum()) { // LCOV_EXCL_START - throw ConversionException("DuckDB does not support Blobs over 4GB"); - } // LCOV_EXCL_STOP - auto offsets = ArrowBufferData(array, 1) + array.offset + scan_state.chunk_offset; - if (nested_offset != -1) { - offsets = ArrowBufferData(array, 1) + array.offset + nested_offset; - } - auto cdata = ArrowBufferData(array, 2); - for (idx_t row_idx = 0; row_idx < size; row_idx++) { - if (FlatVector::IsNull(vector, row_idx)) { - continue; - } - auto bptr = cdata + offsets[row_idx]; - auto blob_len = offsets[row_idx + 1] - offsets[row_idx]; - FlatVector::GetData(vector)[row_idx] = StringVector::AddStringOrBlob(vector, bptr, blob_len); - } - } -} - -static void ArrowToDuckDBMapVerify(Vector &vector, idx_t count) { - auto valid_check = MapVector::CheckMapValidity(vector, count); - switch (valid_check) { - case MapInvalidReason::VALID: - break; - case MapInvalidReason::DUPLICATE_KEY: { - throw InvalidInputException("Arrow map contains duplicate key, which isn't supported by DuckDB map type"); - } - case MapInvalidReason::NULL_KEY: { - throw InvalidInputException("Arrow map contains NULL as map key, which isn't supported by DuckDB map type"); - } - case MapInvalidReason::NULL_KEY_LIST: { - throw InvalidInputException("Arrow map contains NULL as key list, which isn't supported by DuckDB map type"); - } - default: { - throw InternalException("MapInvalidReason not implemented"); - } - } -} - -template -static void SetVectorString(Vector &vector, idx_t size, char *cdata, T *offsets) { - auto strings = FlatVector::GetData(vector); - for (idx_t row_idx = 0; row_idx < size; row_idx++) { - if (FlatVector::IsNull(vector, row_idx)) { - continue; - } - auto cptr = cdata + offsets[row_idx]; - auto str_len = offsets[row_idx + 1] - offsets[row_idx]; - if (str_len > 
NumericLimits::Maximum()) { // LCOV_EXCL_START - throw ConversionException("DuckDB does not support Strings over 4GB"); - } // LCOV_EXCL_STOP - strings[row_idx] = string_t(cptr, str_len); - } -} - -static void DirectConversion(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, int64_t nested_offset, - uint64_t parent_offset) { - auto internal_type = GetTypeIdSize(vector.GetType().InternalType()); - auto data_ptr = - ArrowBufferData(array, 1) + internal_type * (scan_state.chunk_offset + array.offset + parent_offset); - if (nested_offset != -1) { - data_ptr = ArrowBufferData(array, 1) + internal_type * (array.offset + nested_offset + parent_offset); - } - FlatVector::SetData(vector, data_ptr); -} - -template -static void TimeConversion(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, int64_t nested_offset, - idx_t size, int64_t conversion) { - auto tgt_ptr = FlatVector::GetData(vector); - auto &validity_mask = FlatVector::Validity(vector); - auto src_ptr = (T *)array.buffers[1] + scan_state.chunk_offset + array.offset; - if (nested_offset != -1) { - src_ptr = (T *)array.buffers[1] + nested_offset + array.offset; - } - for (idx_t row = 0; row < size; row++) { - if (!validity_mask.RowIsValid(row)) { - continue; - } - if (!TryMultiplyOperator::Operation((int64_t)src_ptr[row], conversion, tgt_ptr[row].micros)) { - throw ConversionException("Could not convert Time to Microsecond"); - } - } -} - -static void TimestampTZConversion(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, - int64_t nested_offset, idx_t size, int64_t conversion) { - auto tgt_ptr = FlatVector::GetData(vector); - auto &validity_mask = FlatVector::Validity(vector); - auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; - if (nested_offset != -1) { - src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; - } - for (idx_t row = 0; row < size; row++) { - if (!validity_mask.RowIsValid(row)) { - continue; - } - 
if (!TryMultiplyOperator::Operation(src_ptr[row], conversion, tgt_ptr[row].value)) { - throw ConversionException("Could not convert TimestampTZ to Microsecond"); - } - } -} - -static void IntervalConversionUs(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, - int64_t nested_offset, idx_t size, int64_t conversion) { - auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; - if (nested_offset != -1) { - src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; - } - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row].days = 0; - tgt_ptr[row].months = 0; - if (!TryMultiplyOperator::Operation(src_ptr[row], conversion, tgt_ptr[row].micros)) { - throw ConversionException("Could not convert Interval to Microsecond"); - } - } -} - -static void IntervalConversionMonths(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, - int64_t nested_offset, idx_t size) { - auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; - if (nested_offset != -1) { - src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; - } - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row].days = 0; - tgt_ptr[row].micros = 0; - tgt_ptr[row].months = src_ptr[row]; - } -} - -static void IntervalConversionMonthDayNanos(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, - int64_t nested_offset, idx_t size) { - auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; - if (nested_offset != -1) { - src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; - } - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row].days = src_ptr[row].days; - tgt_ptr[row].micros = src_ptr[row].nanoseconds / Interval::NANOS_PER_MICRO; - tgt_ptr[row].months = src_ptr[row].months; - } -} - -static void 
ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset, ValidityMask *parent_mask, - uint64_t parent_offset) { - switch (vector.GetType().id()) { - case LogicalTypeId::SQLNULL: - vector.Reference(Value()); - break; - case LogicalTypeId::BOOLEAN: { - //! Arrow bit-packs boolean values - //! Lets first figure out where we are in the source array - auto src_ptr = ArrowBufferData(array, 1) + (scan_state.chunk_offset + array.offset) / 8; - - if (nested_offset != -1) { - src_ptr = ArrowBufferData(array, 1) + (nested_offset + array.offset) / 8; - } - auto tgt_ptr = (uint8_t *)FlatVector::GetData(vector); - int src_pos = 0; - idx_t cur_bit = scan_state.chunk_offset % 8; - if (nested_offset != -1) { - cur_bit = nested_offset % 8; - } - for (idx_t row = 0; row < size; row++) { - if ((src_ptr[src_pos] & (1 << cur_bit)) == 0) { - tgt_ptr[row] = 0; - } else { - tgt_ptr[row] = 1; - } - cur_bit++; - if (cur_bit == 8) { - src_pos++; - cur_bit = 0; - } - } - break; - } - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: { - DirectConversion(vector, array, scan_state, nested_offset, parent_offset); - break; - } - case LogicalTypeId::VARCHAR: { - auto size_type = arrow_type.GetSizeType(); - auto cdata = ArrowBufferData(array, 2); - if (size_type == ArrowVariableSizeType::SUPER_SIZE) { - auto offsets = ArrowBufferData(array, 1) + array.offset + scan_state.chunk_offset; - if (nested_offset != -1) { - offsets = ArrowBufferData(array, 1) + array.offset + 
nested_offset; - } - SetVectorString(vector, size, cdata, offsets); - } else { - auto offsets = ArrowBufferData(array, 1) + array.offset + scan_state.chunk_offset; - if (nested_offset != -1) { - offsets = ArrowBufferData(array, 1) + array.offset + nested_offset; - } - SetVectorString(vector, size, cdata, offsets); - } - break; - } - case LogicalTypeId::DATE: { - - auto precision = arrow_type.GetDateTimeType(); - switch (precision) { - case ArrowDateTimeType::DAYS: { - DirectConversion(vector, array, scan_state, nested_offset, parent_offset); - break; - } - case ArrowDateTimeType::MILLISECONDS: { - //! convert date from nanoseconds to days - auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; - if (nested_offset != -1) { - src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; - } - auto tgt_ptr = FlatVector::GetData(vector); - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row] = date_t(int64_t(src_ptr[row]) / static_cast(1000 * 60 * 60 * 24)); - } - break; - } - default: - throw NotImplementedException("Unsupported precision for Date Type "); - } - break; - } - case LogicalTypeId::TIME: { - auto precision = arrow_type.GetDateTimeType(); - switch (precision) { - case ArrowDateTimeType::SECONDS: { - TimeConversion(vector, array, scan_state, nested_offset, size, 1000000); - break; - } - case ArrowDateTimeType::MILLISECONDS: { - TimeConversion(vector, array, scan_state, nested_offset, size, 1000); - break; - } - case ArrowDateTimeType::MICROSECONDS: { - TimeConversion(vector, array, scan_state, nested_offset, size, 1); - break; - } - case ArrowDateTimeType::NANOSECONDS: { - auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; - if (nested_offset != -1) { - src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; - } - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row].micros = src_ptr[row] / 1000; - } - break; - } - default: - 
throw NotImplementedException("Unsupported precision for Time Type "); - } - break; - } - case LogicalTypeId::TIMESTAMP_TZ: { - auto precision = arrow_type.GetDateTimeType(); - switch (precision) { - case ArrowDateTimeType::SECONDS: { - TimestampTZConversion(vector, array, scan_state, nested_offset, size, 1000000); - break; - } - case ArrowDateTimeType::MILLISECONDS: { - TimestampTZConversion(vector, array, scan_state, nested_offset, size, 1000); - break; - } - case ArrowDateTimeType::MICROSECONDS: { - DirectConversion(vector, array, scan_state, nested_offset, parent_offset); - break; - } - case ArrowDateTimeType::NANOSECONDS: { - auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; - if (nested_offset != -1) { - src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; - } - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row].value = src_ptr[row] / 1000; - } - break; - } - default: - throw NotImplementedException("Unsupported precision for TimestampTZ Type "); - } - break; - } - case LogicalTypeId::INTERVAL: { - auto precision = arrow_type.GetDateTimeType(); - switch (precision) { - case ArrowDateTimeType::SECONDS: { - IntervalConversionUs(vector, array, scan_state, nested_offset, size, 1000000); - break; - } - case ArrowDateTimeType::DAYS: - case ArrowDateTimeType::MILLISECONDS: { - IntervalConversionUs(vector, array, scan_state, nested_offset, size, 1000); - break; - } - case ArrowDateTimeType::MICROSECONDS: { - IntervalConversionUs(vector, array, scan_state, nested_offset, size, 1); - break; - } - case ArrowDateTimeType::NANOSECONDS: { - auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; - if (nested_offset != -1) { - src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; - } - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row].micros = src_ptr[row] / 1000; - tgt_ptr[row].days = 
0; - tgt_ptr[row].months = 0; - } - break; - } - case ArrowDateTimeType::MONTHS: { - IntervalConversionMonths(vector, array, scan_state, nested_offset, size); - break; - } - case ArrowDateTimeType::MONTH_DAY_NANO: { - IntervalConversionMonthDayNanos(vector, array, scan_state, nested_offset, size); - break; - } - default: - throw NotImplementedException("Unsupported precision for Interval/Duration Type "); - } - break; - } - case LogicalTypeId::DECIMAL: { - auto val_mask = FlatVector::Validity(vector); - //! We have to convert from INT128 - auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; - if (nested_offset != -1) { - src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; - } - switch (vector.GetType().InternalType()) { - case PhysicalType::INT16: { - auto tgt_ptr = FlatVector::GetData(vector); - for (idx_t row = 0; row < size; row++) { - if (val_mask.RowIsValid(row)) { - auto result = Hugeint::TryCast(src_ptr[row], tgt_ptr[row]); - D_ASSERT(result); - (void)result; - } - } - break; - } - case PhysicalType::INT32: { - auto tgt_ptr = FlatVector::GetData(vector); - for (idx_t row = 0; row < size; row++) { - if (val_mask.RowIsValid(row)) { - auto result = Hugeint::TryCast(src_ptr[row], tgt_ptr[row]); - D_ASSERT(result); - (void)result; - } - } - break; - } - case PhysicalType::INT64: { - auto tgt_ptr = FlatVector::GetData(vector); - for (idx_t row = 0; row < size; row++) { - if (val_mask.RowIsValid(row)) { - auto result = Hugeint::TryCast(src_ptr[row], tgt_ptr[row]); - D_ASSERT(result); - (void)result; - } - } - break; - } - case PhysicalType::INT128: { - FlatVector::SetData(vector, - ArrowBufferData(array, 1) + GetTypeIdSize(vector.GetType().InternalType()) * - (scan_state.chunk_offset + array.offset)); - break; - } - default: - throw NotImplementedException("Unsupported physical type for Decimal: %s", - TypeIdToString(vector.GetType().InternalType())); - } - break; - } - case LogicalTypeId::BLOB: { - 
ArrowToDuckDBBlob(vector, array, scan_state, size, arrow_type, nested_offset); - break; - } - case LogicalTypeId::LIST: { - ArrowToDuckDBList(vector, array, scan_state, size, arrow_type, nested_offset, parent_mask); - break; - } - case LogicalTypeId::MAP: { - ArrowToDuckDBList(vector, array, scan_state, size, arrow_type, nested_offset, parent_mask); - ArrowToDuckDBMapVerify(vector, size); - break; - } - case LogicalTypeId::STRUCT: { - //! Fill the children - auto &child_entries = StructVector::GetEntries(vector); - auto &struct_validity_mask = FlatVector::Validity(vector); - for (idx_t type_idx = 0; type_idx < static_cast(array.n_children); type_idx++) { - SetValidityMask(*child_entries[type_idx], *array.children[type_idx], scan_state, size, nested_offset); - if (!struct_validity_mask.AllValid()) { - auto &child_validity_mark = FlatVector::Validity(*child_entries[type_idx]); - for (idx_t i = 0; i < size; i++) { - if (!struct_validity_mask.RowIsValid(i)) { - child_validity_mark.SetInvalid(i); - } - } - } - ColumnArrowToDuckDB(*child_entries[type_idx], *array.children[type_idx], scan_state, size, - arrow_type[type_idx], nested_offset, &struct_validity_mask, array.offset); - } - break; - } - case LogicalTypeId::UNION: { - auto type_ids = ArrowBufferData(array, array.n_buffers == 1 ? 
0 : 1); - D_ASSERT(type_ids); - auto members = UnionType::CopyMemberTypes(vector.GetType()); - - auto &validity_mask = FlatVector::Validity(vector); - - duckdb::vector children; - for (idx_t type_idx = 0; type_idx < static_cast(array.n_children); type_idx++) { - Vector child(members[type_idx].second); - auto arrow_array = array.children[type_idx]; - - SetValidityMask(child, *arrow_array, scan_state, size, nested_offset); - - ColumnArrowToDuckDB(child, *arrow_array, scan_state, size, arrow_type, nested_offset, &validity_mask); - - children.push_back(std::move(child)); - } - - for (idx_t row_idx = 0; row_idx < size; row_idx++) { - auto tag = type_ids[row_idx]; - - auto out_of_range = tag < 0 || tag >= array.n_children; - if (out_of_range) { - throw InvalidInputException("Arrow union tag out of range: %d", tag); - } - - const Value &value = children[tag].GetValue(row_idx); - vector.SetValue(row_idx, value.IsNull() ? Value() : Value::UNION(members, tag, value)); - } - - break; - } - default: - throw NotImplementedException("Unsupported type for arrow conversion: %s", vector.GetType().ToString()); - } -} - -template -static void SetSelectionVectorLoop(SelectionVector &sel, data_ptr_t indices_p, idx_t size) { - auto indices = reinterpret_cast(indices_p); - for (idx_t row = 0; row < size; row++) { - sel.set_index(row, indices[row]); - } -} - -template -static void SetSelectionVectorLoopWithChecks(SelectionVector &sel, data_ptr_t indices_p, idx_t size) { - - auto indices = reinterpret_cast(indices_p); - for (idx_t row = 0; row < size; row++) { - if (indices[row] > NumericLimits::Maximum()) { - throw ConversionException("DuckDB only supports indices that fit on an uint32"); - } - sel.set_index(row, indices[row]); - } -} - -template -static void SetMaskedSelectionVectorLoop(SelectionVector &sel, data_ptr_t indices_p, idx_t size, ValidityMask &mask, - idx_t last_element_pos) { - auto indices = reinterpret_cast(indices_p); - for (idx_t row = 0; row < size; row++) { - if 
(mask.RowIsValid(row)) { - sel.set_index(row, indices[row]); - } else { - //! Need to point out to last element - sel.set_index(row, last_element_pos); - } - } -} - -static void SetSelectionVector(SelectionVector &sel, data_ptr_t indices_p, LogicalType &logical_type, idx_t size, - ValidityMask *mask = nullptr, idx_t last_element_pos = 0) { - sel.Initialize(size); - - if (mask) { - switch (logical_type.id()) { - case LogicalTypeId::UTINYINT: - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::TINYINT: - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::USMALLINT: - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::SMALLINT: - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::UINTEGER: - if (last_element_pos > NumericLimits::Maximum()) { - //! Its guaranteed that our indices will point to the last element, so just throw an error - throw ConversionException("DuckDB only supports indices that fit on an uint32"); - } - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::INTEGER: - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::UBIGINT: - if (last_element_pos > NumericLimits::Maximum()) { - //! Its guaranteed that our indices will point to the last element, so just throw an error - throw ConversionException("DuckDB only supports indices that fit on an uint32"); - } - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::BIGINT: - if (last_element_pos > NumericLimits::Maximum()) { - //! 
Its guaranteed that our indices will point to the last element, so just throw an error - throw ConversionException("DuckDB only supports indices that fit on an uint32"); - } - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - - default: - throw NotImplementedException("(Arrow) Unsupported type for selection vectors %s", logical_type.ToString()); - } - - } else { - switch (logical_type.id()) { - case LogicalTypeId::UTINYINT: - SetSelectionVectorLoop(sel, indices_p, size); - break; - case LogicalTypeId::TINYINT: - SetSelectionVectorLoop(sel, indices_p, size); - break; - case LogicalTypeId::USMALLINT: - SetSelectionVectorLoop(sel, indices_p, size); - break; - case LogicalTypeId::SMALLINT: - SetSelectionVectorLoop(sel, indices_p, size); - break; - case LogicalTypeId::UINTEGER: - SetSelectionVectorLoop(sel, indices_p, size); - break; - case LogicalTypeId::INTEGER: - SetSelectionVectorLoop(sel, indices_p, size); - break; - case LogicalTypeId::UBIGINT: - if (last_element_pos > NumericLimits::Maximum()) { - //! We need to check if our indexes fit in a uint32_t - SetSelectionVectorLoopWithChecks(sel, indices_p, size); - } else { - SetSelectionVectorLoop(sel, indices_p, size); - } - break; - case LogicalTypeId::BIGINT: - if (last_element_pos > NumericLimits::Maximum()) { - //! We need to check if our indexes fit in a uint32_t - SetSelectionVectorLoopWithChecks(sel, indices_p, size); - } else { - SetSelectionVectorLoop(sel, indices_p, size); - } - break; - default: - throw ConversionException("(Arrow) Unsupported type for selection vectors %s", logical_type.ToString()); - } - } -} - -static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, - idx_t size, const ArrowType &arrow_type, idx_t col_idx) { - SelectionVector sel; - auto &dict_vectors = scan_state.arrow_dictionary_vectors; - if (!dict_vectors.count(col_idx)) { - //! 
We need to set the dictionary data for this column - auto base_vector = make_uniq(vector.GetType(), array.dictionary->length); - SetValidityMask(*base_vector, *array.dictionary, scan_state, array.dictionary->length, 0, array.null_count > 0); - ColumnArrowToDuckDB(*base_vector, *array.dictionary, scan_state, array.dictionary->length, - arrow_type.GetDictionary()); - dict_vectors[col_idx] = std::move(base_vector); - } - auto dictionary_type = arrow_type.GetDuckType(); - //! Get Pointer to Indices of Dictionary - auto indices = ArrowBufferData(array, 1) + - GetTypeIdSize(dictionary_type.InternalType()) * (scan_state.chunk_offset + array.offset); - if (array.null_count > 0) { - ValidityMask indices_validity; - GetValidityMask(indices_validity, array, scan_state, size); - SetSelectionVector(sel, indices, dictionary_type, size, &indices_validity, array.dictionary->length); - } else { - SetSelectionVector(sel, indices, dictionary_type, size); - } - vector.Slice(*dict_vectors[col_idx], sel, size); -} - -void ArrowTableFunction::ArrowToDuckDB(ArrowScanLocalState &scan_state, const arrow_column_map_t &arrow_convert_data, - DataChunk &output, idx_t start, bool arrow_scan_is_projected) { - for (idx_t idx = 0; idx < output.ColumnCount(); idx++) { - auto col_idx = scan_state.column_ids[idx]; - - // If projection was not pushed down into the arrow scanner, but projection pushdown is enabled on the - // table function, we need to use original column ids here. - auto arrow_array_idx = arrow_scan_is_projected ? 
idx : col_idx; - - if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { - // This column is skipped by the projection pushdown - continue; - } - - auto &array = *scan_state.chunk->arrow_array.children[arrow_array_idx]; - if (!array.release) { - throw InvalidInputException("arrow_scan: released array passed"); - } - if (array.length != scan_state.chunk->arrow_array.length) { - throw InvalidInputException("arrow_scan: array length mismatch"); - } - // Make sure this Vector keeps the Arrow chunk alive in case we can zero-copy the data - if (scan_state.arrow_owned_data.find(idx) == scan_state.arrow_owned_data.end()) { - auto arrow_data = make_shared(); - arrow_data->arrow_array = scan_state.chunk->arrow_array; - scan_state.chunk->arrow_array.release = nullptr; - scan_state.arrow_owned_data[idx] = arrow_data; - } - - output.data[idx].GetBuffer()->SetAuxiliaryData(make_uniq(scan_state.arrow_owned_data[idx])); - - D_ASSERT(arrow_convert_data.find(col_idx) != arrow_convert_data.end()); - auto &arrow_type = *arrow_convert_data.at(col_idx); - if (array.dictionary) { - ColumnArrowToDuckDBDictionary(output.data[idx], array, scan_state, output.size(), arrow_type, col_idx); - } else { - SetValidityMask(output.data[idx], array, scan_state, output.size(), -1); - ColumnArrowToDuckDB(output.data[idx], array, scan_state, output.size(), arrow_type); - } - } -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -struct CheckpointBindData : public FunctionData { - explicit CheckpointBindData(optional_ptr db) : db(db) { - } - - optional_ptr db; - -public: - unique_ptr Copy() const override { - return make_uniq(db); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return db == other.db; - } -}; - -static unique_ptr CheckpointBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - return_types.emplace_back(LogicalType::BOOLEAN); - names.emplace_back("Success"); - - optional_ptr db; - auto 
&db_manager = DatabaseManager::Get(context); - if (!input.inputs.empty()) { - if (input.inputs[0].IsNull()) { - throw BinderException("Database cannot be NULL"); - } - auto &db_name = StringValue::Get(input.inputs[0]); - db = db_manager.GetDatabase(context, db_name); - if (!db) { - throw BinderException("Database \"%s\" not found", db_name); - } - } else { - db = db_manager.GetDatabase(context, DatabaseManager::GetDefaultDatabase(context)); - } - return make_uniq(db); -} - -template -static void TemplatedCheckpointFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &transaction_manager = TransactionManager::Get(*bind_data.db.get_mutable()); - transaction_manager.Checkpoint(context, FORCE); -} - -void CheckpointFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunctionSet checkpoint("checkpoint"); - checkpoint.AddFunction(TableFunction({}, TemplatedCheckpointFunction, CheckpointBind)); - checkpoint.AddFunction(TableFunction({LogicalType::VARCHAR}, TemplatedCheckpointFunction, CheckpointBind)); - set.AddFunction(checkpoint); - - TableFunctionSet force_checkpoint("force_checkpoint"); - force_checkpoint.AddFunction(TableFunction({}, TemplatedCheckpointFunction, CheckpointBind)); - force_checkpoint.AddFunction( - TableFunction({LogicalType::VARCHAR}, TemplatedCheckpointFunction, CheckpointBind)); - set.AddFunction(force_checkpoint); -} - -} // namespace duckdb - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -void AreOptionsEqual(char &str_1, char &str_2, const string &name_str_1, const string &name_str_2) { - if (str_1 == '\0' || str_2 == '\0') { - return; - } - if (str_1 == str_2) { - throw BinderException("%s must not appear in the %s specification and vice versa", name_str_1, name_str_2); - } -} - -void SubstringDetection(char &str_1, string &str_2, const string &name_str_1, const string &name_str_2) { - if (str_1 == '\0' || str_2.empty()) { - return; - } - if 
(str_2.find(str_1) != string::npos) { - throw BinderException("%s must not appear in the %s specification and vice versa", name_str_1, name_str_2); - } -} - -//===--------------------------------------------------------------------===// -// Bind -//===--------------------------------------------------------------------===// -void WriteQuoteOrEscape(WriteStream &writer, char quote_or_escape) { - if (quote_or_escape != '\0') { - writer.Write(quote_or_escape); - } -} - -void BaseCSVData::Finalize() { - // verify that the options are correct in the final pass - if (options.dialect_options.state_machine_options.escape == '\0') { - options.dialect_options.state_machine_options.escape = options.dialect_options.state_machine_options.quote; - } - // escape and delimiter must not be substrings of each other - if (options.has_delimiter && options.has_escape) { - AreOptionsEqual(options.dialect_options.state_machine_options.delimiter, - options.dialect_options.state_machine_options.escape, "DELIMITER", "ESCAPE"); - } - // delimiter and quote must not be substrings of each other - if (options.has_quote && options.has_delimiter) { - AreOptionsEqual(options.dialect_options.state_machine_options.quote, - options.dialect_options.state_machine_options.delimiter, "DELIMITER", "QUOTE"); - } - // escape and quote must not be substrings of each other (but can be the same) - if (options.dialect_options.state_machine_options.quote != options.dialect_options.state_machine_options.escape && - options.has_quote && options.has_escape) { - AreOptionsEqual(options.dialect_options.state_machine_options.quote, - options.dialect_options.state_machine_options.escape, "QUOTE", "ESCAPE"); - } - if (!options.null_str.empty()) { - // null string and delimiter must not be substrings of each other - if (options.has_delimiter) { - SubstringDetection(options.dialect_options.state_machine_options.delimiter, options.null_str, "DELIMITER", - "NULL"); - } - // quote/escape and nullstr must not be substrings of 
each other - if (options.has_quote) { - SubstringDetection(options.dialect_options.state_machine_options.quote, options.null_str, "QUOTE", "NULL"); - } - if (options.has_escape) { - SubstringDetection(options.dialect_options.state_machine_options.escape, options.null_str, "ESCAPE", - "NULL"); - } - } - - if (!options.prefix.empty() || !options.suffix.empty()) { - if (options.prefix.empty() || options.suffix.empty()) { - throw BinderException("COPY ... (FORMAT CSV) must have both PREFIX and SUFFIX, or none at all"); - } - if (options.dialect_options.header) { - throw BinderException("COPY ... (FORMAT CSV)'s HEADER cannot be combined with PREFIX/SUFFIX"); - } - } -} - -static unique_ptr WriteCSVBind(ClientContext &context, CopyInfo &info, vector &names, - vector &sql_types) { - auto bind_data = make_uniq(info.file_path, sql_types, names); - - // check all the options in the copy info - for (auto &option : info.options) { - auto loption = StringUtil::Lower(option.first); - auto &set = option.second; - bind_data->options.SetWriteOption(loption, ConvertVectorToValue(std::move(set))); - } - // verify the parsed options - if (bind_data->options.force_quote.empty()) { - // no FORCE_QUOTE specified: initialize to false - bind_data->options.force_quote.resize(names.size(), false); - } - bind_data->Finalize(); - - bind_data->requires_quotes = make_unsafe_uniq_array(256); - memset(bind_data->requires_quotes.get(), 0, sizeof(bool) * 256); - bind_data->requires_quotes['\n'] = true; - bind_data->requires_quotes['\r'] = true; - bind_data->requires_quotes[bind_data->options.dialect_options.state_machine_options.delimiter] = true; - bind_data->requires_quotes[bind_data->options.dialect_options.state_machine_options.quote] = true; - - if (!bind_data->options.write_newline.empty()) { - bind_data->newline = bind_data->options.write_newline; - } - return std::move(bind_data); -} - -static unique_ptr ReadCSVBind(ClientContext &context, CopyInfo &info, vector &expected_names, - vector 
&expected_types) { - auto bind_data = make_uniq(); - bind_data->csv_types = expected_types; - bind_data->csv_names = expected_names; - bind_data->return_types = expected_types; - bind_data->return_names = expected_names; - bind_data->files = MultiFileReader::GetFileList(context, Value(info.file_path), "CSV"); - - auto &options = bind_data->options; - - // check all the options in the copy info - for (auto &option : info.options) { - auto loption = StringUtil::Lower(option.first); - auto &set = option.second; - options.SetReadOption(loption, ConvertVectorToValue(set), expected_names); - } - // verify the parsed options - if (options.force_not_null.empty()) { - // no FORCE_QUOTE specified: initialize to false - options.force_not_null.resize(expected_types.size(), false); - } - - // Look for rejects table options last - named_parameter_map_t options_map; - for (auto &option : info.options) { - options_map[option.first] = ConvertVectorToValue(std::move(option.second)); - } - options.file_path = bind_data->files[0]; - options.name_list = expected_names; - options.sql_type_list = expected_types; - for (idx_t i = 0; i < expected_types.size(); i++) { - options.sql_types_per_column[expected_names[i]] = i; - } - - bind_data->FinalizeRead(context); - - if (options.auto_detect) { - // We must run the sniffer. 
- auto file_handle = BaseCSVReader::OpenCSV(context, options); - auto buffer_manager = make_shared(context, std::move(file_handle), options); - CSVSniffer sniffer(options, buffer_manager, bind_data->state_machine_cache); - auto sniffer_result = sniffer.SniffCSV(); - bind_data->csv_types = sniffer_result.return_types; - bind_data->csv_names = sniffer_result.names; - bind_data->return_types = sniffer_result.return_types; - bind_data->return_names = sniffer_result.names; - } - return std::move(bind_data); -} - -//===--------------------------------------------------------------------===// -// Helper writing functions -//===--------------------------------------------------------------------===// -static string AddEscapes(char &to_be_escaped, const char &escape, const string &val) { - idx_t i = 0; - string new_val = ""; - idx_t found = val.find(to_be_escaped); - - while (found != string::npos) { - while (i < found) { - new_val += val[i]; - i++; - } - if (escape != '\0') { - new_val += escape; - found = val.find(to_be_escaped, found + 1); - } - } - while (i < val.length()) { - new_val += val[i]; - i++; - } - return new_val; -} - -static bool RequiresQuotes(WriteCSVData &csv_data, const char *str, idx_t len) { - auto &options = csv_data.options; - // check if the string is equal to the null string - if (len == options.null_str.size() && memcmp(str, options.null_str.c_str(), len) == 0) { - return true; - } - auto str_data = reinterpret_cast(str); - for (idx_t i = 0; i < len; i++) { - if (csv_data.requires_quotes[str_data[i]]) { - // this byte requires quotes - write a quoted string - return true; - } - } - // no newline, quote or delimiter in the string - // no quoting or escaping necessary - return false; -} - -static void WriteQuotedString(WriteStream &writer, WriteCSVData &csv_data, const char *str, idx_t len, - bool force_quote) { - auto &options = csv_data.options; - if (!force_quote) { - // force quote is disabled: check if we need to add quotes anyway - force_quote 
= RequiresQuotes(csv_data, str, len); - } - if (force_quote) { - // quoting is enabled: we might need to escape things in the string - bool requires_escape = false; - // simple CSV - // do a single loop to check for a quote or escape value - for (idx_t i = 0; i < len; i++) { - if (str[i] == options.dialect_options.state_machine_options.quote || - str[i] == options.dialect_options.state_machine_options.escape) { - requires_escape = true; - break; - } - } - - if (!requires_escape) { - // fast path: no need to escape anything - WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.quote); - writer.WriteData(const_data_ptr_cast(str), len); - WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.quote); - return; - } - - // slow path: need to add escapes - string new_val(str, len); - new_val = AddEscapes(options.dialect_options.state_machine_options.escape, - options.dialect_options.state_machine_options.escape, new_val); - if (options.dialect_options.state_machine_options.escape != - options.dialect_options.state_machine_options.quote) { - // need to escape quotes separately - new_val = AddEscapes(options.dialect_options.state_machine_options.quote, - options.dialect_options.state_machine_options.escape, new_val); - } - WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.quote); - writer.WriteData(const_data_ptr_cast(new_val.c_str()), new_val.size()); - WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.quote); - } else { - writer.WriteData(const_data_ptr_cast(str), len); - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -struct LocalWriteCSVData : public LocalFunctionData { - //! The thread-local buffer to write data into - MemoryStream stream; - //! A chunk with VARCHAR columns to cast intermediates into - DataChunk cast_chunk; - //! 
If we've written any rows yet, allows us to prevent a trailing comma when writing JSON ARRAY - bool written_anything = false; -}; - -struct GlobalWriteCSVData : public GlobalFunctionData { - GlobalWriteCSVData(FileSystem &fs, const string &file_path, FileCompressionType compression) - : fs(fs), written_anything(false) { - handle = fs.OpenFile(file_path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW, - FileLockType::WRITE_LOCK, compression); - } - - //! Write generic data, e.g., CSV header - void WriteData(const_data_ptr_t data, idx_t size) { - lock_guard flock(lock); - handle->Write((void *)data, size); - } - - void WriteData(const char *data, idx_t size) { - WriteData(const_data_ptr_cast(data), size); - } - - //! Write rows - void WriteRows(const_data_ptr_t data, idx_t size, const string &newline) { - lock_guard flock(lock); - if (written_anything) { - handle->Write((void *)newline.c_str(), newline.length()); - } else { - written_anything = true; - } - handle->Write((void *)data, size); - } - - FileSystem &fs; - //! The mutex for writing to the physical file - mutex lock; - //! The file handle to write to - unique_ptr handle; - //! 
If we've written any rows yet, allows us to prevent a trailing comma when writing JSON ARRAY - bool written_anything; -}; - -static unique_ptr WriteCSVInitializeLocal(ExecutionContext &context, FunctionData &bind_data) { - auto &csv_data = bind_data.Cast(); - auto local_data = make_uniq(); - - // create the chunk with VARCHAR types - vector types; - types.resize(csv_data.options.name_list.size(), LogicalType::VARCHAR); - - local_data->cast_chunk.Initialize(Allocator::Get(context.client), types); - return std::move(local_data); -} - -static unique_ptr WriteCSVInitializeGlobal(ClientContext &context, FunctionData &bind_data, - const string &file_path) { - auto &csv_data = bind_data.Cast(); - auto &options = csv_data.options; - auto global_data = - make_uniq(FileSystem::GetFileSystem(context), file_path, options.compression); - - if (!options.prefix.empty()) { - global_data->WriteData(options.prefix.c_str(), options.prefix.size()); - } - - if (!(options.has_header && !options.dialect_options.header)) { - MemoryStream stream; - // write the header line to the file - for (idx_t i = 0; i < csv_data.options.name_list.size(); i++) { - if (i != 0) { - WriteQuoteOrEscape(stream, options.dialect_options.state_machine_options.delimiter); - } - WriteQuotedString(stream, csv_data, csv_data.options.name_list[i].c_str(), - csv_data.options.name_list[i].size(), false); - } - stream.WriteData(const_data_ptr_cast(csv_data.newline.c_str()), csv_data.newline.size()); - - global_data->WriteData(stream.GetData(), stream.GetPosition()); - } - - return std::move(global_data); -} - -static void WriteCSVChunkInternal(ClientContext &context, FunctionData &bind_data, DataChunk &cast_chunk, - MemoryStream &writer, DataChunk &input, bool &written_anything) { - auto &csv_data = bind_data.Cast(); - auto &options = csv_data.options; - - // first cast the columns of the chunk to varchar - cast_chunk.Reset(); - cast_chunk.SetCardinality(input); - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); 
col_idx++) { - if (csv_data.sql_types[col_idx].id() == LogicalTypeId::VARCHAR) { - // VARCHAR, just reinterpret (cannot reference, because LogicalTypeId::VARCHAR is used by the JSON type too) - cast_chunk.data[col_idx].Reinterpret(input.data[col_idx]); - } else if (options.dialect_options.has_format[LogicalTypeId::DATE] && - csv_data.sql_types[col_idx].id() == LogicalTypeId::DATE) { - // use the date format to cast the chunk - csv_data.options.write_date_format[LogicalTypeId::DATE].ConvertDateVector( - input.data[col_idx], cast_chunk.data[col_idx], input.size()); - } else if (options.dialect_options.has_format[LogicalTypeId::TIMESTAMP] && - (csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP || - csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP_TZ)) { - // use the timestamp format to cast the chunk - csv_data.options.write_date_format[LogicalTypeId::TIMESTAMP].ConvertTimestampVector( - input.data[col_idx], cast_chunk.data[col_idx], input.size()); - } else { - // non varchar column, perform the cast - VectorOperations::Cast(context, input.data[col_idx], cast_chunk.data[col_idx], input.size()); - } - } - - cast_chunk.Flatten(); - // now loop over the vectors and output the values - for (idx_t row_idx = 0; row_idx < cast_chunk.size(); row_idx++) { - if (row_idx == 0 && !written_anything) { - written_anything = true; - } else { - writer.WriteData(const_data_ptr_cast(csv_data.newline.c_str()), csv_data.newline.size()); - } - // write values - for (idx_t col_idx = 0; col_idx < cast_chunk.ColumnCount(); col_idx++) { - if (col_idx != 0) { - WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.delimiter); - } - if (FlatVector::IsNull(cast_chunk.data[col_idx], row_idx)) { - // write null value - writer.WriteData(const_data_ptr_cast(options.null_str.c_str()), options.null_str.size()); - continue; - } - - // non-null value, fetch the string value from the cast chunk - auto str_data = FlatVector::GetData(cast_chunk.data[col_idx]); - // 
FIXME: we could gain some performance here by checking for certain types if they ever require quotes - // (e.g. integers only require quotes if the delimiter is a number, decimals only require quotes if the - // delimiter is a number or "." character) - WriteQuotedString(writer, csv_data, str_data[row_idx].GetData(), str_data[row_idx].GetSize(), - csv_data.options.force_quote[col_idx]); - } - } -} - -static void WriteCSVSink(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, - LocalFunctionData &lstate, DataChunk &input) { - auto &csv_data = bind_data.Cast(); - auto &local_data = lstate.Cast(); - auto &global_state = gstate.Cast(); - - // write data into the local buffer - WriteCSVChunkInternal(context.client, bind_data, local_data.cast_chunk, local_data.stream, input, - local_data.written_anything); - - // check if we should flush what we have currently written - auto &writer = local_data.stream; - if (writer.GetPosition() >= csv_data.flush_size) { - global_state.WriteRows(writer.GetData(), writer.GetPosition(), csv_data.newline); - writer.Rewind(); - local_data.written_anything = false; - } -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -static void WriteCSVCombine(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, - LocalFunctionData &lstate) { - auto &local_data = lstate.Cast(); - auto &global_state = gstate.Cast(); - auto &csv_data = bind_data.Cast(); - auto &writer = local_data.stream; - // flush the local writer - if (local_data.written_anything) { - global_state.WriteRows(writer.GetData(), writer.GetPosition(), csv_data.newline); - writer.Rewind(); - } -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -void WriteCSVFinalize(ClientContext &context, 
FunctionData &bind_data, GlobalFunctionData &gstate) { - auto &global_state = gstate.Cast(); - auto &csv_data = bind_data.Cast(); - auto &options = csv_data.options; - - MemoryStream stream; - if (!options.suffix.empty()) { - stream.WriteData(const_data_ptr_cast(options.suffix.c_str()), options.suffix.size()); - } else if (global_state.written_anything) { - stream.WriteData(const_data_ptr_cast(csv_data.newline.c_str()), csv_data.newline.size()); - } - global_state.WriteData(stream.GetData(), stream.GetPosition()); - - global_state.handle->Close(); - global_state.handle.reset(); -} - -//===--------------------------------------------------------------------===// -// Execution Mode -//===--------------------------------------------------------------------===// -CopyFunctionExecutionMode WriteCSVExecutionMode(bool preserve_insertion_order, bool supports_batch_index) { - if (!preserve_insertion_order) { - return CopyFunctionExecutionMode::PARALLEL_COPY_TO_FILE; - } - if (supports_batch_index) { - return CopyFunctionExecutionMode::BATCH_COPY_TO_FILE; - } - return CopyFunctionExecutionMode::REGULAR_COPY_TO_FILE; -} -//===--------------------------------------------------------------------===// -// Prepare Batch -//===--------------------------------------------------------------------===// -struct WriteCSVBatchData : public PreparedBatchData { - //! 
The thread-local buffer to write data into - MemoryStream stream; -}; - -unique_ptr WriteCSVPrepareBatch(ClientContext &context, FunctionData &bind_data, - GlobalFunctionData &gstate, - unique_ptr collection) { - auto &csv_data = bind_data.Cast(); - - // create the cast chunk with VARCHAR types - vector types; - types.resize(csv_data.options.name_list.size(), LogicalType::VARCHAR); - DataChunk cast_chunk; - cast_chunk.Initialize(Allocator::Get(context), types); - - // write CSV chunks to the batch data - bool written_anything = false; - auto batch = make_uniq(); - for (auto &chunk : collection->Chunks()) { - WriteCSVChunkInternal(context, bind_data, cast_chunk, batch->stream, chunk, written_anything); - } - return std::move(batch); -} - -//===--------------------------------------------------------------------===// -// Flush Batch -//===--------------------------------------------------------------------===// -void WriteCSVFlushBatch(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, - PreparedBatchData &batch) { - auto &csv_batch = batch.Cast(); - auto &global_state = gstate.Cast(); - auto &csv_data = bind_data.Cast(); - auto &writer = csv_batch.stream; - global_state.WriteRows(writer.GetData(), writer.GetPosition(), csv_data.newline); - writer.Rewind(); -} - -void CSVCopyFunction::RegisterFunction(BuiltinFunctions &set) { - CopyFunction info("csv"); - info.copy_to_bind = WriteCSVBind; - info.copy_to_initialize_local = WriteCSVInitializeLocal; - info.copy_to_initialize_global = WriteCSVInitializeGlobal; - info.copy_to_sink = WriteCSVSink; - info.copy_to_combine = WriteCSVCombine; - info.copy_to_finalize = WriteCSVFinalize; - info.execution_mode = WriteCSVExecutionMode; - info.prepare_batch = WriteCSVPrepareBatch; - info.flush_batch = WriteCSVFlushBatch; - - info.copy_from_bind = ReadCSVBind; - info.copy_from_function = ReadCSVTableFunction::GetFunction(); - - info.extension = "csv"; - - set.AddFunction(info); -} - -} // namespace duckdb - 
- - - - - - -namespace duckdb { - -struct GlobFunctionBindData : public TableFunctionData { - vector files; -}; - -static unique_ptr GlobFunctionBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - result->files = MultiFileReader::GetFileList(context, input.inputs[0], "Globbing", FileGlobOptions::ALLOW_EMPTY); - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("file"); - return std::move(result); -} - -struct GlobFunctionState : public GlobalTableFunctionState { - GlobFunctionState() : current_idx(0) { - } - - idx_t current_idx; -}; - -static unique_ptr GlobFunctionInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static void GlobFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = data_p.global_state->Cast(); - - idx_t count = 0; - idx_t next_idx = MinValue(state.current_idx + STANDARD_VECTOR_SIZE, bind_data.files.size()); - for (; state.current_idx < next_idx; state.current_idx++) { - output.data[0].SetValue(count, bind_data.files[state.current_idx]); - count++; - } - output.SetCardinality(count); -} - -void GlobTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction glob_function("glob", {LogicalType::VARCHAR}, GlobFunction, GlobFunctionBind, GlobFunctionInit); - set.AddFunction(MultiFileReader::CreateFunctionSet(glob_function)); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -struct PragmaDetailedProfilingOutputOperatorData : public GlobalTableFunctionState { - explicit PragmaDetailedProfilingOutputOperatorData() : initialized(false) { - } - - ColumnDataScanState scan_state; - bool initialized; -}; - -struct PragmaDetailedProfilingOutputData : public TableFunctionData { - explicit PragmaDetailedProfilingOutputData(vector &types) : types(types) { - } - unique_ptr collection; - vector types; -}; - 
-static unique_ptr PragmaDetailedProfilingOutputBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, - vector &names) { - names.emplace_back("OPERATOR_ID"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("ANNOTATION"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("ID"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("NAME"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("TIME"); - return_types.emplace_back(LogicalType::DOUBLE); - - names.emplace_back("CYCLES_PER_TUPLE"); - return_types.emplace_back(LogicalType::DOUBLE); - - names.emplace_back("SAMPLE_SIZE"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("INPUT_SIZE"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("EXTRA_INFO"); - return_types.emplace_back(LogicalType::VARCHAR); - - return make_uniq(return_types); -} - -unique_ptr PragmaDetailedProfilingOutputInit(ClientContext &context, - TableFunctionInitInput &input) { - return make_uniq(); -} - -// Insert a row into the given datachunk -static void SetValue(DataChunk &output, int index, int op_id, string annotation, int id, string name, double time, - int sample_counter, int tuple_counter, string extra_info) { - output.SetValue(0, index, op_id); - output.SetValue(1, index, std::move(annotation)); - output.SetValue(2, index, id); - output.SetValue(3, index, std::move(name)); -#if defined(RDTSC) - output.SetValue(4, index, Value(nullptr)); - output.SetValue(5, index, time); -#else - output.SetValue(4, index, time); - output.SetValue(5, index, Value(nullptr)); - -#endif - output.SetValue(6, index, sample_counter); - output.SetValue(7, index, tuple_counter); - output.SetValue(8, index, std::move(extra_info)); -} - -static void ExtractFunctions(ColumnDataCollection &collection, ExpressionInfo &info, DataChunk &chunk, int op_id, - int &fun_id) { - if (info.hasfunction) { - 
D_ASSERT(info.sample_tuples_count != 0); - SetValue(chunk, chunk.size(), op_id, "Function", fun_id++, info.function_name, - int(info.function_time) / double(info.sample_tuples_count), info.sample_tuples_count, - info.tuples_count, ""); - - chunk.SetCardinality(chunk.size() + 1); - if (chunk.size() == STANDARD_VECTOR_SIZE) { - collection.Append(chunk); - chunk.Reset(); - } - } - if (info.children.empty()) { - return; - } - // extract the children of this node - for (auto &child : info.children) { - ExtractFunctions(collection, *child, chunk, op_id, fun_id); - } -} - -static void PragmaDetailedProfilingOutputFunction(ClientContext &context, TableFunctionInput &data_p, - DataChunk &output) { - auto &state = data_p.global_state->Cast(); - auto &data = data_p.bind_data->CastNoConst(); - - if (!state.initialized) { - // create a ColumnDataCollection - auto collection = make_uniq(context, data.types); - - // create a chunk - DataChunk chunk; - chunk.Initialize(context, data.types); - - // Initialize ids - int operator_counter = 1; - int function_counter = 1; - int expression_counter = 1; - auto &client_data = ClientData::Get(context); - if (client_data.query_profiler_history->GetPrevProfilers().empty()) { - return; - } - // For each Operator - auto &tree_map = client_data.query_profiler_history->GetPrevProfilers().back().second->GetTreeMap(); - for (auto op : tree_map) { - // For each Expression Executor - for (auto &expr_executor : op.second.get().info.executors_info) { - // For each Expression tree - if (!expr_executor) { - continue; - } - for (auto &expr_timer : expr_executor->roots) { - D_ASSERT(expr_timer->sample_tuples_count != 0); - SetValue(chunk, chunk.size(), operator_counter, "ExpressionRoot", expression_counter++, - // Sometimes, cycle counter is not accurate, too big or too small. 
return 0 for - // those cases - expr_timer->name, int(expr_timer->time) / double(expr_timer->sample_tuples_count), - expr_timer->sample_tuples_count, expr_timer->tuples_count, expr_timer->extra_info); - // Increment cardinality - chunk.SetCardinality(chunk.size() + 1); - // Check whether data chunk is full or not - if (chunk.size() == STANDARD_VECTOR_SIZE) { - collection->Append(chunk); - chunk.Reset(); - } - // Extract all functions inside the tree - ExtractFunctions(*collection, *expr_timer->root, chunk, operator_counter, function_counter); - } - } - operator_counter++; - } - collection->Append(chunk); - data.collection = std::move(collection); - data.collection->InitializeScan(state.scan_state); - state.initialized = true; - } - - data.collection->Scan(state.scan_state, output); -} - -void PragmaDetailedProfilingOutput::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("pragma_detailed_profiling_output", {}, PragmaDetailedProfilingOutputFunction, - PragmaDetailedProfilingOutputBind, PragmaDetailedProfilingOutputInit)); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -struct PragmaLastProfilingOutputOperatorData : public GlobalTableFunctionState { - PragmaLastProfilingOutputOperatorData() : initialized(false) { - } - - ColumnDataScanState scan_state; - bool initialized; -}; - -struct PragmaLastProfilingOutputData : public TableFunctionData { - explicit PragmaLastProfilingOutputData(vector &types) : types(types) { - } - unique_ptr collection; - vector types; -}; - -static unique_ptr PragmaLastProfilingOutputBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, - vector &names) { - names.emplace_back("OPERATOR_ID"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("NAME"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("TIME"); - return_types.emplace_back(LogicalType::DOUBLE); - - names.emplace_back("CARDINALITY"); - 
return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("DESCRIPTION"); - return_types.emplace_back(LogicalType::VARCHAR); - - return make_uniq(return_types); -} - -static void SetValue(DataChunk &output, int index, int op_id, string name, double time, int64_t car, - string description) { - output.SetValue(0, index, op_id); - output.SetValue(1, index, std::move(name)); - output.SetValue(2, index, time); - output.SetValue(3, index, car); - output.SetValue(4, index, std::move(description)); -} - -unique_ptr PragmaLastProfilingOutputInit(ClientContext &context, - TableFunctionInitInput &input) { - return make_uniq(); -} - -static void PragmaLastProfilingOutputFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &state = data_p.global_state->Cast(); - auto &data = data_p.bind_data->CastNoConst(); - if (!state.initialized) { - // create a ColumnDataCollection - auto collection = make_uniq(context, data.types); - - DataChunk chunk; - chunk.Initialize(context, data.types); - int operator_counter = 1; - auto &client_data = ClientData::Get(context); - if (!client_data.query_profiler_history->GetPrevProfilers().empty()) { - auto &tree_map = client_data.query_profiler_history->GetPrevProfilers().back().second->GetTreeMap(); - for (auto op : tree_map) { - auto &tree_info = op.second.get(); - SetValue(chunk, chunk.size(), operator_counter++, tree_info.name, tree_info.info.time, - tree_info.info.elements, " "); - chunk.SetCardinality(chunk.size() + 1); - if (chunk.size() == STANDARD_VECTOR_SIZE) { - collection->Append(chunk); - chunk.Reset(); - } - } - } - collection->Append(chunk); - data.collection = std::move(collection); - data.collection->InitializeScan(state.scan_state); - state.initialized = true; - } - - data.collection->Scan(state.scan_state, output); -} - -void PragmaLastProfilingOutput::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("pragma_last_profiling_output", {}, 
PragmaLastProfilingOutputFunction, - PragmaLastProfilingOutputBind, PragmaLastProfilingOutputInit)); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Range (integers) -//===--------------------------------------------------------------------===// -struct RangeFunctionBindData : public TableFunctionData { - hugeint_t start; - hugeint_t end; - hugeint_t increment; - -public: - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return other.start == start && other.end == end && other.increment == increment; - } -}; - -template -static void GenerateRangeParameters(const vector &inputs, RangeFunctionBindData &result) { - for (auto &input : inputs) { - if (input.IsNull()) { - result.start = GENERATE_SERIES ? 1 : 0; - result.end = 0; - result.increment = 1; - return; - } - } - if (inputs.size() < 2) { - // single argument: only the end is specified - result.start = 0; - result.end = inputs[0].GetValue(); - } else { - // two arguments: first two arguments are start and end - result.start = inputs[0].GetValue(); - result.end = inputs[1].GetValue(); - } - if (inputs.size() < 3) { - result.increment = 1; - } else { - result.increment = inputs[2].GetValue(); - } - if (result.increment == 0) { - throw BinderException("interval cannot be 0!"); - } - if (result.start > result.end && result.increment > 0) { - throw BinderException("start is bigger than end, but increment is positive: cannot generate infinite series"); - } else if (result.start < result.end && result.increment < 0) { - throw BinderException("start is smaller than end, but increment is negative: cannot generate infinite series"); - } -} - -template -static unique_ptr RangeFunctionBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - auto &inputs = input.inputs; - GenerateRangeParameters(inputs, 
*result); - - return_types.emplace_back(LogicalType::BIGINT); - if (GENERATE_SERIES) { - // generate_series has inclusive bounds on the RHS - if (result->increment < 0) { - result->end = result->end - 1; - } else { - result->end = result->end + 1; - } - names.emplace_back("generate_series"); - } else { - names.emplace_back("range"); - } - return std::move(result); -} - -struct RangeFunctionState : public GlobalTableFunctionState { - RangeFunctionState() : current_idx(0) { - } - - int64_t current_idx; -}; - -static unique_ptr RangeFunctionInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static void RangeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = data_p.global_state->Cast(); - - auto increment = bind_data.increment; - auto end = bind_data.end; - hugeint_t current_value = bind_data.start + increment * state.current_idx; - int64_t current_value_i64; - if (!Hugeint::TryCast(current_value, current_value_i64)) { - return; - } - int64_t offset = increment < 0 ? 
1 : -1; - idx_t remaining = MinValue(Hugeint::Cast((end - current_value + (increment + offset)) / increment), - STANDARD_VECTOR_SIZE); - // set the result vector as a sequence vector - output.data[0].Sequence(current_value_i64, Hugeint::Cast(increment), remaining); - // increment the index pointer by the remaining count - state.current_idx += remaining; - output.SetCardinality(remaining); -} - -unique_ptr RangeCardinality(ClientContext &context, const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - idx_t cardinality = Hugeint::Cast((bind_data.end - bind_data.start) / bind_data.increment); - return make_uniq(cardinality, cardinality); -} - -//===--------------------------------------------------------------------===// -// Range (timestamp) -//===--------------------------------------------------------------------===// -struct RangeDateTimeBindData : public TableFunctionData { - timestamp_t start; - timestamp_t end; - interval_t increment; - bool inclusive_bound; - bool greater_than_check; - -public: - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return other.start == start && other.end == end && other.increment == increment && - other.inclusive_bound == inclusive_bound && other.greater_than_check == greater_than_check; - } - - bool Finished(timestamp_t current_value) const { - if (greater_than_check) { - if (inclusive_bound) { - return current_value > end; - } else { - return current_value >= end; - } - } else { - if (inclusive_bound) { - return current_value < end; - } else { - return current_value <= end; - } - } - } -}; - -template -static unique_ptr RangeDateTimeBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - auto &inputs = input.inputs; - D_ASSERT(inputs.size() == 3); - result->start = inputs[0].GetValue(); - result->end = inputs[1].GetValue(); - result->increment = inputs[2].GetValue(); - - // Infinities 
either cause errors or infinite loops, so just ban them - if (!Timestamp::IsFinite(result->start) || !Timestamp::IsFinite(result->end)) { - throw BinderException("RANGE with infinite bounds is not supported"); - } - - if (result->increment.months == 0 && result->increment.days == 0 && result->increment.micros == 0) { - throw BinderException("interval cannot be 0!"); - } - // all elements should point in the same direction - if (result->increment.months > 0 || result->increment.days > 0 || result->increment.micros > 0) { - if (result->increment.months < 0 || result->increment.days < 0 || result->increment.micros < 0) { - throw BinderException("RANGE with composite interval that has mixed signs is not supported"); - } - result->greater_than_check = true; - if (result->start > result->end) { - throw BinderException( - "start is bigger than end, but increment is positive: cannot generate infinite series"); - } - } else { - result->greater_than_check = false; - if (result->start < result->end) { - throw BinderException( - "start is smaller than end, but increment is negative: cannot generate infinite series"); - } - } - return_types.push_back(inputs[0].type()); - if (GENERATE_SERIES) { - // generate_series has inclusive bounds on the RHS - result->inclusive_bound = true; - names.emplace_back("generate_series"); - } else { - result->inclusive_bound = false; - names.emplace_back("range"); - } - return std::move(result); -} - -struct RangeDateTimeState : public GlobalTableFunctionState { - explicit RangeDateTimeState(timestamp_t start_p) : current_state(start_p) { - } - - timestamp_t current_state; - bool finished = false; -}; - -static unique_ptr RangeDateTimeInit(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - return make_uniq(bind_data.start); -} - -static void RangeDateTimeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = 
data_p.global_state->Cast(); - if (state.finished) { - return; - } - - idx_t size = 0; - auto data = FlatVector::GetData(output.data[0]); - while (true) { - data[size++] = state.current_state; - state.current_state = - AddOperator::Operation(state.current_state, bind_data.increment); - if (bind_data.Finished(state.current_state)) { - state.finished = true; - break; - } - if (size >= STANDARD_VECTOR_SIZE) { - break; - } - } - output.SetCardinality(size); -} - -void RangeTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunctionSet range("range"); - - TableFunction range_function({LogicalType::BIGINT}, RangeFunction, RangeFunctionBind, RangeFunctionInit); - range_function.cardinality = RangeCardinality; - - // single argument range: (end) - implicit start = 0 and increment = 1 - range.AddFunction(range_function); - // two arguments range: (start, end) - implicit increment = 1 - range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT}; - range.AddFunction(range_function); - // three arguments range: (start, end, increment) - range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}; - range.AddFunction(range_function); - range.AddFunction(TableFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, - RangeDateTimeFunction, RangeDateTimeBind, RangeDateTimeInit)); - set.AddFunction(range); - // generate_series: similar to range, but inclusive instead of exclusive bounds on the RHS - TableFunctionSet generate_series("generate_series"); - range_function.bind = RangeFunctionBind; - range_function.arguments = {LogicalType::BIGINT}; - generate_series.AddFunction(range_function); - range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT}; - generate_series.AddFunction(range_function); - range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}; - generate_series.AddFunction(range_function); - 
generate_series.AddFunction(TableFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, - RangeDateTimeFunction, RangeDateTimeBind, RangeDateTimeInit)); - set.AddFunction(generate_series); -} - -void BuiltinFunctions::RegisterTableFunctions() { - CheckpointFunction::RegisterFunction(*this); - GlobTableFunction::RegisterFunction(*this); - RangeTableFunction::RegisterFunction(*this); - RepeatTableFunction::RegisterFunction(*this); - SummaryTableFunction::RegisterFunction(*this); - UnnestTableFunction::RegisterFunction(*this); - RepeatRowTableFunction::RegisterFunction(*this); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -unique_ptr ReadCSV::OpenCSV(const string &file_path, FileCompressionType compression, - ClientContext &context) { - auto &fs = FileSystem::GetFileSystem(context); - auto &allocator = BufferAllocator::Get(context); - return CSVFileHandle::OpenFile(fs, allocator, file_path, compression); -} - -void ReadCSVData::FinalizeRead(ClientContext &context) { - BaseCSVData::Finalize(); - // Here we identify if we can run this CSV file on parallel or not. - bool not_supported_options = options.null_padding; - - auto number_of_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - //! 
If we have many csv files, we run single-threaded on each file and parallelize on the number of files - bool many_csv_files = files.size() > 1 && int64_t(files.size() * 2) >= number_of_threads; - if (options.parallel_mode != ParallelMode::PARALLEL && many_csv_files) { - single_threaded = true; - } - if (options.parallel_mode == ParallelMode::SINGLE_THREADED || not_supported_options || - options.dialect_options.new_line == NewLineIdentifier::MIX) { - // not supported for parallel CSV reading - single_threaded = true; - } - - // Validate rejects_table options - if (!options.rejects_table_name.empty()) { - if (!options.ignore_errors) { - throw BinderException("REJECTS_TABLE option is only supported when IGNORE_ERRORS is set to true"); - } - if (options.file_options.union_by_name) { - throw BinderException("REJECTS_TABLE option is not supported when UNION_BY_NAME is set to true"); - } - } - - if (!options.rejects_recovery_columns.empty()) { - if (options.rejects_table_name.empty()) { - throw BinderException( - "REJECTS_RECOVERY_COLUMNS option is only supported when REJECTS_TABLE is set to a table name"); - } - for (auto &recovery_col : options.rejects_recovery_columns) { - bool found = false; - for (idx_t col_idx = 0; col_idx < return_names.size(); col_idx++) { - if (StringUtil::CIEquals(return_names[col_idx], recovery_col)) { - options.rejects_recovery_column_ids.push_back(col_idx); - found = true; - break; - } - } - if (!found) { - throw BinderException("Unsupported parameter for REJECTS_RECOVERY_COLUMNS: column \"%s\" not found", - recovery_col); - } - } - } - - if (options.rejects_limit != 0) { - if (options.rejects_table_name.empty()) { - throw BinderException("REJECTS_LIMIT option is only supported when REJECTS_TABLE is set to a table name"); - } - } -} - -static unique_ptr ReadCSVBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - - auto result = make_uniq(); - auto &options = result->options; - result->files = 
MultiFileReader::GetFileList(context, input.inputs[0], "CSV"); - - options.FromNamedParameters(input.named_parameters, context, return_types, names); - bool explicitly_set_columns = options.explicitly_set_columns; - - options.file_options.AutoDetectHivePartitioning(result->files, context); - - if (!options.auto_detect && return_types.empty()) { - throw BinderException("read_csv requires columns to be specified through the 'columns' option. Use " - "read_csv_auto or set read_csv(..., " - "AUTO_DETECT=TRUE) to automatically guess columns."); - } - if (options.auto_detect) { - options.file_path = result->files[0]; - // Initialize Buffer Manager and Sniffer - auto file_handle = BaseCSVReader::OpenCSV(context, options); - result->buffer_manager = make_shared(context, std::move(file_handle), options); - CSVSniffer sniffer(options, result->buffer_manager, result->state_machine_cache, explicitly_set_columns); - auto sniffer_result = sniffer.SniffCSV(); - if (names.empty()) { - names = sniffer_result.names; - return_types = sniffer_result.return_types; - } else { - if (explicitly_set_columns) { - // The user has influenced the names, can't assume they are valid anymore - if (return_types.size() != names.size()) { - throw BinderException("The amount of names specified (%d) and the observed amount of types (%d) in " - "the file don't match", - names.size(), return_types.size()); - } - } else { - D_ASSERT(return_types.size() == names.size()); - } - } - - } else { - D_ASSERT(return_types.size() == names.size()); - } - result->csv_types = return_types; - result->csv_names = names; - - if (options.file_options.union_by_name) { - result->reader_bind = - MultiFileReader::BindUnionReader(context, return_types, names, *result, options); - if (result->union_readers.size() > 1) { - result->column_info.emplace_back(result->csv_names, result->csv_types); - for (idx_t i = 1; i < result->union_readers.size(); i++) { - result->column_info.emplace_back(result->union_readers[i]->names, - 
result->union_readers[i]->return_types); - } - } - if (!options.sql_types_per_column.empty()) { - auto exception = BufferedCSVReader::ColumnTypesError(options.sql_types_per_column, names); - if (!exception.empty()) { - throw BinderException(exception); - } - } - } else { - result->reader_bind = MultiFileReader::BindOptions(options.file_options, result->files, return_types, names); - } - result->return_types = return_types; - result->return_names = names; - result->FinalizeRead(context); - - return std::move(result); -} - -static unique_ptr ReadCSVAutoBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - input.named_parameters["auto_detect"] = Value::BOOLEAN(true); - return ReadCSVBind(context, input, return_types, names); -} - -//===--------------------------------------------------------------------===// -// Parallel CSV Reader CSV Global State -//===--------------------------------------------------------------------===// - -struct ParallelCSVGlobalState : public GlobalTableFunctionState { -public: - ParallelCSVGlobalState(ClientContext &context, shared_ptr buffer_manager_p, - const CSVReaderOptions &options, idx_t system_threads_p, const vector &files_path_p, - bool force_parallelism_p, vector column_ids_p) - : buffer_manager(std::move(buffer_manager_p)), system_threads(system_threads_p), - force_parallelism(force_parallelism_p), column_ids(std::move(column_ids_p)), - line_info(main_mutex, batch_to_tuple_end, tuple_start, tuple_end) { - current_file_path = files_path_p[0]; - CSVFileHandle *file_handle_ptr; - - if (!buffer_manager || (options.skip_rows_set && options.dialect_options.skip_rows > 0) || - buffer_manager->file_handle->GetFilePath() != current_file_path) { - // If our buffers are too small, and we skip too many rows there is a chance things will go over-buffer - // for now don't reuse the buffer manager - buffer_manager.reset(); - file_handle = ReadCSV::OpenCSV(current_file_path, options.compression, 
context); - file_handle_ptr = file_handle.get(); - } else { - file_handle_ptr = buffer_manager->file_handle.get(); - } - - file_size = file_handle_ptr->FileSize(); - first_file_size = file_size; - on_disk_file = file_handle_ptr->OnDiskFile(); - bytes_read = 0; - running_threads = MaxThreads(); - - // Initialize all the book-keeping variables - auto file_count = files_path_p.size(); - line_info.current_batches.resize(file_count); - line_info.lines_read.resize(file_count); - line_info.lines_errored.resize(file_count); - tuple_start.resize(file_count); - tuple_end.resize(file_count); - tuple_end_to_batch.resize(file_count); - batch_to_tuple_end.resize(file_count); - - // Initialize the lines read - line_info.lines_read[0][0] = options.dialect_options.skip_rows; - if (options.has_header && options.dialect_options.header) { - line_info.lines_read[0][0]++; - } - first_position = options.dialect_options.true_start; - next_byte = options.dialect_options.true_start; - } - explicit ParallelCSVGlobalState(idx_t system_threads_p) - : system_threads(system_threads_p), line_info(main_mutex, batch_to_tuple_end, tuple_start, tuple_end) { - running_threads = MaxThreads(); - } - - ~ParallelCSVGlobalState() override { - } - - //! How many bytes were read up to this point - atomic bytes_read; - //! Size of current file - idx_t file_size; - -public: - idx_t MaxThreads() const override; - //! Updates the CSV reader with the next buffer to read. Returns false if no more buffers are available. - bool Next(ClientContext &context, const ReadCSVData &bind_data, unique_ptr &reader); - //! 
Verify if the CSV File was read correctly - void Verify(); - - void UpdateVerification(VerificationPositions positions, idx_t file_number, idx_t batch_idx); - - void UpdateLinesRead(CSVBufferRead &buffer_read, idx_t file_idx); - - void DecrementThread(); - - bool Finished(); - - double GetProgress(const ReadCSVData &bind_data) const { - idx_t total_files = bind_data.files.size(); - - // get the progress WITHIN the current file - double progress; - if (file_size == 0) { - progress = 1.0; - } else { - progress = double(bytes_read) / double(file_size); - } - // now get the total percentage of files read - double percentage = double(file_index - 1) / total_files; - percentage += (double(1) / double(total_files)) * progress; - return percentage * 100; - } - -private: - //! File Handle for current file - shared_ptr buffer_manager; - - //! The index of the next file to read (i.e. current file + 1) - idx_t file_index = 1; - string current_file_path; - - //! Mutex to lock when getting next batch of bytes (Parallel Only) - mutex main_mutex; - //! Byte set from for last thread - idx_t next_byte = 0; - //! Size of first file - idx_t first_file_size = 0; - //! Whether or not this is an on-disk file - bool on_disk_file = true; - //! Basically max number of threads in DuckDB - idx_t system_threads; - //! Current batch index - idx_t batch_index = 0; - idx_t local_batch_index = 0; - - //! Forces parallelism for small CSV Files, should only be used for testing. - bool force_parallelism = false; - //! First Position of First Buffer - idx_t first_position = 0; - //! Current File Number - idx_t max_tuple_end = 0; - //! The vector stores positions where threads ended the last line they read in the CSV File, and the set stores - //! Positions where they started reading the first line. - vector> tuple_end; - vector> tuple_start; - //! Tuple end to batch - vector> tuple_end_to_batch; - //! Batch to Tuple End - vector> batch_to_tuple_end; - idx_t running_threads = 0; - //! 
The column ids to read - vector column_ids; - //! Line Info used in error messages - LineInfo line_info; - //! Current Buffer index - idx_t cur_buffer_idx = 0; - //! Only used if we don't run auto_detection first - unique_ptr file_handle; -}; - -idx_t ParallelCSVGlobalState::MaxThreads() const { - if (force_parallelism || !on_disk_file) { - return system_threads; - } - idx_t one_mb = 1000000; // We initialize max one thread per Mb - idx_t threads_per_mb = first_file_size / one_mb + 1; - if (threads_per_mb < system_threads || threads_per_mb == 1) { - return threads_per_mb; - } - - return system_threads; -} - -void ParallelCSVGlobalState::DecrementThread() { - lock_guard parallel_lock(main_mutex); - D_ASSERT(running_threads > 0); - running_threads--; -} - -bool ParallelCSVGlobalState::Finished() { - lock_guard parallel_lock(main_mutex); - return running_threads == 0; -} - -void ParallelCSVGlobalState::Verify() { - // All threads are done, we run some magic sweet verification code - lock_guard parallel_lock(main_mutex); - if (running_threads == 0) { - D_ASSERT(tuple_end.size() == tuple_start.size()); - for (idx_t i = 0; i < tuple_start.size(); i++) { - auto ¤t_tuple_end = tuple_end[i]; - auto ¤t_tuple_start = tuple_start[i]; - // figure out max value of last_pos - if (current_tuple_end.empty()) { - return; - } - auto max_value = *max_element(std::begin(current_tuple_end), std::end(current_tuple_end)); - for (idx_t tpl_idx = 0; tpl_idx < current_tuple_end.size(); tpl_idx++) { - auto last_pos = current_tuple_end[tpl_idx]; - auto first_pos = current_tuple_start.find(last_pos); - if (first_pos == current_tuple_start.end()) { - // this might be necessary due to carriage returns outside buffer scopes. 
- first_pos = current_tuple_start.find(last_pos + 1); - } - if (first_pos == current_tuple_start.end() && last_pos != max_value) { - auto batch_idx = tuple_end_to_batch[i][last_pos]; - auto problematic_line = line_info.GetLine(batch_idx); - throw InvalidInputException( - "CSV File not supported for multithreading. This can be a problematic line in your CSV File or " - "that this CSV can't be read in Parallel. Please, inspect if the line %llu is correct. If so, " - "please run single-threaded CSV Reading by setting parallel=false in the read_csv call.", - problematic_line); - } - } - } - } -} - -void LineInfo::Verify(idx_t file_idx, idx_t batch_idx, idx_t cur_first_pos) { - auto &tuple_start_set = tuple_start[file_idx]; - auto &processed_batches = batch_to_tuple_end[file_idx]; - auto &tuple_end_vec = tuple_end[file_idx]; - bool has_error = false; - idx_t problematic_line; - if (batch_idx == 0 || tuple_start_set.empty()) { - return; - } - for (idx_t cur_batch = 0; cur_batch < batch_idx - 1; cur_batch++) { - auto cur_end = tuple_end_vec[processed_batches[cur_batch]]; - auto first_pos = tuple_start_set.find(cur_end); - if (first_pos == tuple_start_set.end()) { - has_error = true; - problematic_line = GetLine(cur_batch); - break; - } - } - if (!has_error) { - auto cur_end = tuple_end_vec[processed_batches[batch_idx - 1]]; - if (cur_end != cur_first_pos) { - has_error = true; - problematic_line = GetLine(batch_idx); - } - } - if (has_error) { - throw InvalidInputException( - "CSV File not supported for multithreading. This can be a problematic line in your CSV File or " - "that this CSV can't be read in Parallel. Please, inspect if the line %llu is correct. 
If so, " - "please run single-threaded CSV Reading by setting parallel=false in the read_csv call.", - problematic_line); - } -} -bool ParallelCSVGlobalState::Next(ClientContext &context, const ReadCSVData &bind_data, - unique_ptr &reader) { - lock_guard parallel_lock(main_mutex); - if (!buffer_manager && file_handle) { - buffer_manager = make_shared(context, std::move(file_handle), bind_data.options); - } - if (!buffer_manager) { - return false; - } - auto current_buffer = buffer_manager->GetBuffer(cur_buffer_idx); - auto next_buffer = buffer_manager->GetBuffer(cur_buffer_idx + 1); - - if (!current_buffer) { - // This means we are done with the current file, we need to go to the next one (if exists). - if (file_index < bind_data.files.size()) { - current_file_path = bind_data.files[file_index]; - file_handle = ReadCSV::OpenCSV(current_file_path, bind_data.options.compression, context); - buffer_manager = - make_shared(context, std::move(file_handle), bind_data.options, file_index); - cur_buffer_idx = 0; - first_position = 0; - local_batch_index = 0; - - line_info.lines_read[file_index++][local_batch_index] = (bind_data.options.has_header ? 1 : 0); - - current_buffer = buffer_manager->GetBuffer(cur_buffer_idx); - next_buffer = buffer_manager->GetBuffer(cur_buffer_idx + 1); - } else { - // We are done scanning. 
- reader.reset(); - return false; - } - } - // set up the current buffer - line_info.current_batches[file_index - 1].insert(local_batch_index); - idx_t bytes_per_local_state = current_buffer->actual_size / MaxThreads() + 1; - auto result = make_uniq( - buffer_manager->GetBuffer(cur_buffer_idx), buffer_manager->GetBuffer(cur_buffer_idx + 1), next_byte, - next_byte + bytes_per_local_state, batch_index++, local_batch_index++, &line_info); - // move the byte index of the CSV reader to the next buffer - next_byte += bytes_per_local_state; - if (next_byte >= current_buffer->actual_size) { - // We replace the current buffer with the next buffer - next_byte = 0; - bytes_read += current_buffer->actual_size; - current_buffer = std::move(next_buffer); - cur_buffer_idx++; - if (current_buffer) { - // Next buffer gets the next-next buffer - next_buffer = buffer_manager->GetBuffer(cur_buffer_idx + 1); - } - } - if (!reader || reader->options.file_path != current_file_path) { - // we either don't have a reader, or the reader was created for a different file - // we need to create a new reader and instantiate it - if (file_index > 0 && file_index <= bind_data.union_readers.size() && bind_data.union_readers[file_index - 1]) { - // we are doing UNION BY NAME - fetch the options from the union reader for this file - auto &union_reader = *bind_data.union_readers[file_index - 1]; - reader = make_uniq(context, union_reader.options, std::move(result), first_position, - union_reader.GetTypes(), file_index - 1); - reader->names = union_reader.GetNames(); - } else if (file_index <= bind_data.column_info.size()) { - // Serialized Union By name - reader = make_uniq(context, bind_data.options, std::move(result), first_position, - bind_data.column_info[file_index - 1].types, file_index - 1); - reader->names = bind_data.column_info[file_index - 1].names; - } else { - // regular file - use the standard options - if (!result) { - return false; - } - reader = make_uniq(context, bind_data.options, 
std::move(result), first_position, - bind_data.csv_types, file_index - 1); - reader->names = bind_data.csv_names; - } - reader->options.file_path = current_file_path; - MultiFileReader::InitializeReader(*reader, bind_data.options.file_options, bind_data.reader_bind, - bind_data.return_types, bind_data.return_names, column_ids, nullptr, - bind_data.files.front(), context); - } else { - // update the current reader - reader->SetBufferRead(std::move(result)); - } - - return true; -} -void ParallelCSVGlobalState::UpdateVerification(VerificationPositions positions, idx_t file_number_p, idx_t batch_idx) { - lock_guard parallel_lock(main_mutex); - if (positions.end_of_last_line > max_tuple_end) { - max_tuple_end = positions.end_of_last_line; - } - tuple_end_to_batch[file_number_p][positions.end_of_last_line] = batch_idx; - batch_to_tuple_end[file_number_p][batch_idx] = tuple_end[file_number_p].size(); - tuple_start[file_number_p].insert(positions.beginning_of_first_line); - tuple_end[file_number_p].push_back(positions.end_of_last_line); -} - -void ParallelCSVGlobalState::UpdateLinesRead(CSVBufferRead &buffer_read, idx_t file_idx) { - auto batch_idx = buffer_read.local_batch_index; - auto lines_read = buffer_read.lines_read; - lock_guard parallel_lock(main_mutex); - line_info.current_batches[file_idx].erase(batch_idx); - line_info.lines_read[file_idx][batch_idx] += lines_read; -} - -bool LineInfo::CanItGetLine(idx_t file_idx, idx_t batch_idx) { - lock_guard parallel_lock(main_mutex); - if (current_batches.empty() || done) { - return true; - } - if (file_idx >= current_batches.size() || current_batches[file_idx].empty()) { - return true; - } - auto min_value = *current_batches[file_idx].begin(); - if (min_value >= batch_idx) { - return true; - } - return false; -} - -void LineInfo::Increment(idx_t file_idx, idx_t batch_idx) { - auto parallel_lock = duckdb::make_uniq>(main_mutex); - lines_errored[file_idx][batch_idx]++; -} - -// Returns the 1-indexed line number -idx_t 
LineInfo::GetLine(idx_t batch_idx, idx_t line_error, idx_t file_idx, idx_t cur_start, bool verify, - bool stop_at_first) { - unique_ptr> parallel_lock; - if (!verify) { - parallel_lock = duckdb::make_uniq>(main_mutex); - } - idx_t line_count = 0; - - if (!stop_at_first) { - // Figure out the amount of lines read in the current file - for (idx_t cur_batch_idx = 0; cur_batch_idx <= batch_idx; cur_batch_idx++) { - if (cur_batch_idx < batch_idx) { - line_count += lines_errored[file_idx][cur_batch_idx]; - } - line_count += lines_read[file_idx][cur_batch_idx]; - } - return line_count + line_error + 1; - } - - // Otherwise, check if we already have an error on another thread - if (done) { - // line count is 0-indexed, but we want to return 1-indexed - return first_line + 1; - } - for (idx_t i = 0; i <= batch_idx; i++) { - if (lines_read[file_idx].find(i) == lines_read[file_idx].end() && i != batch_idx) { - throw InternalException("Missing batch index on Parallel CSV Reader GetLine"); - } - line_count += lines_read[file_idx][i]; - } - - // before we are done, if this is not a call in Verify() we must check Verify up to this batch - if (!verify) { - Verify(file_idx, batch_idx, cur_start); - } - done = true; - first_line = line_count + line_error; - // line count is 0-indexed, but we want to return 1-indexed - return first_line + 1; -} - -static unique_ptr ParallelCSVInitGlobal(ClientContext &context, - TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->CastNoConst(); - if (bind_data.files.empty()) { - // This can happen when a filename based filter pushdown has eliminated all possible files for this scan. 
- return make_uniq(context.db->NumberOfThreads()); - } - bind_data.options.file_path = bind_data.files[0]; - auto buffer_manager = bind_data.buffer_manager; - return make_uniq(context, buffer_manager, bind_data.options, context.db->NumberOfThreads(), - bind_data.files, ClientConfig::GetConfig(context).verify_parallelism, - input.column_ids); -} - -//===--------------------------------------------------------------------===// -// Read CSV Local State -//===--------------------------------------------------------------------===// -struct ParallelCSVLocalState : public LocalTableFunctionState { -public: - explicit ParallelCSVLocalState(unique_ptr csv_reader_p) : csv_reader(std::move(csv_reader_p)) { - } - - //! The CSV reader - unique_ptr csv_reader; - CSVBufferRead previous_buffer; - bool done = false; -}; - -unique_ptr ParallelReadCSVInitLocal(ExecutionContext &context, TableFunctionInitInput &input, - GlobalTableFunctionState *global_state_p) { - auto &csv_data = input.bind_data->Cast(); - auto &global_state = global_state_p->Cast(); - unique_ptr csv_reader; - auto has_next = global_state.Next(context.client, csv_data, csv_reader); - if (!has_next) { - global_state.DecrementThread(); - csv_reader.reset(); - } - return make_uniq(std::move(csv_reader)); -} - -static void ParallelReadCSVFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &csv_global_state = data_p.global_state->Cast(); - auto &csv_local_state = data_p.local_state->Cast(); - - if (!csv_local_state.csv_reader) { - // no csv_reader was set, this can happen when a filename-based filter has filtered out all possible files - return; - } - - do { - if (output.size() != 0) { - MultiFileReader::FinalizeChunk(bind_data.reader_bind, csv_local_state.csv_reader->reader_data, output); - break; - } - if (csv_local_state.csv_reader->finished) { - auto verification_updates = csv_local_state.csv_reader->GetVerificationPositions(); - 
csv_global_state.UpdateVerification(verification_updates, - csv_local_state.csv_reader->buffer->buffer->file_idx, - csv_local_state.csv_reader->buffer->local_batch_index); - csv_global_state.UpdateLinesRead(*csv_local_state.csv_reader->buffer, csv_local_state.csv_reader->file_idx); - auto has_next = csv_global_state.Next(context, bind_data, csv_local_state.csv_reader); - if (csv_local_state.csv_reader) { - csv_local_state.csv_reader->linenr = 0; - } - if (!has_next) { - csv_global_state.DecrementThread(); - break; - } - } - csv_local_state.csv_reader->ParseCSV(output); - - } while (true); - if (csv_global_state.Finished()) { - csv_global_state.Verify(); - } -} - -//===--------------------------------------------------------------------===// -// Single-Threaded CSV Reader -//===--------------------------------------------------------------------===// -struct SingleThreadedCSVState : public GlobalTableFunctionState { - explicit SingleThreadedCSVState(idx_t total_files) : total_files(total_files), next_file(0), progress_in_files(0) { - } - - mutex csv_lock; - unique_ptr initial_reader; - //! The total number of files to read from - idx_t total_files; - //! The index of the next file to read (i.e. current file + 1) - atomic next_file; - //! How far along we are in reading the current set of open files - //! This goes from [0...next_file] * 100 - atomic progress_in_files; - //! The set of SQL types - vector csv_types; - //! The set of SQL names to be read from the file - vector csv_names; - //! 
The column ids to read - vector column_ids; - - idx_t MaxThreads() const override { - return total_files; - } - - double GetProgress(const ReadCSVData &bind_data) const { - D_ASSERT(total_files == bind_data.files.size()); - D_ASSERT(progress_in_files <= total_files * 100); - return (double(progress_in_files) / double(total_files)); - } - - unique_ptr GetCSVReader(ClientContext &context, ReadCSVData &bind_data, idx_t &file_index, - idx_t &total_size) { - return GetCSVReaderInternal(context, bind_data, file_index, total_size); - } - -private: - unique_ptr GetCSVReaderInternal(ClientContext &context, ReadCSVData &bind_data, - idx_t &file_index, idx_t &total_size) { - CSVReaderOptions options; - { - lock_guard l(csv_lock); - if (initial_reader) { - total_size = initial_reader->file_handle ? initial_reader->file_handle->FileSize() : 0; - return std::move(initial_reader); - } - if (next_file >= total_files) { - return nullptr; - } - options = bind_data.options; - file_index = next_file; - next_file++; - } - // reuse csv_readers was created during binding - unique_ptr result; - if (file_index < bind_data.union_readers.size() && bind_data.union_readers[file_index]) { - result = std::move(bind_data.union_readers[file_index]); - } else { - auto union_by_name = options.file_options.union_by_name; - options.file_path = bind_data.files[file_index]; - result = make_uniq(context, std::move(options), csv_types); - if (!union_by_name) { - result->names = csv_names; - } - MultiFileReader::InitializeReader(*result, bind_data.options.file_options, bind_data.reader_bind, - bind_data.return_types, bind_data.return_names, column_ids, nullptr, - bind_data.files.front(), context); - } - total_size = result->file_handle->FileSize(); - return result; - } -}; - -struct SingleThreadedCSVLocalState : public LocalTableFunctionState { -public: - explicit SingleThreadedCSVLocalState() : bytes_read(0), total_size(0), current_progress(0), file_index(0) { - } - - //! 
The CSV reader - unique_ptr csv_reader; - //! The current amount of bytes read by this reader - idx_t bytes_read; - //! The total amount of bytes in the file - idx_t total_size; - //! The current progress from 0..100 - idx_t current_progress; - //! The file index of this reader - idx_t file_index; -}; - -static unique_ptr SingleThreadedCSVInit(ClientContext &context, - TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->CastNoConst(); - auto result = make_uniq(bind_data.files.size()); - if (bind_data.files.empty()) { - // This can happen when a filename based filter pushdown has eliminated all possible files for this scan. - return std::move(result); - } else { - bind_data.options.file_path = bind_data.files[0]; - result->initial_reader = make_uniq(context, bind_data.options, bind_data.csv_types); - if (!bind_data.options.file_options.union_by_name) { - result->initial_reader->names = bind_data.csv_names; - } - if (bind_data.options.auto_detect) { - bind_data.options = result->initial_reader->options; - } - } - MultiFileReader::InitializeReader(*result->initial_reader, bind_data.options.file_options, bind_data.reader_bind, - bind_data.return_types, bind_data.return_names, input.column_ids, input.filters, - bind_data.files.front(), context); - for (auto &reader : bind_data.union_readers) { - if (!reader) { - continue; - } - MultiFileReader::InitializeReader(*reader, bind_data.options.file_options, bind_data.reader_bind, - bind_data.return_types, bind_data.return_names, input.column_ids, - input.filters, bind_data.files.front(), context); - } - result->column_ids = input.column_ids; - - if (!bind_data.options.file_options.union_by_name) { - // if we are reading multiple files - run auto-detect only on the first file - // UNLESS union by name is turned on - in that case we assume that different files have different schemas - // as such, we need to re-run the auto detection on each file - bind_data.options.auto_detect = false; - } - result->csv_types 
= bind_data.csv_types; - result->csv_names = bind_data.csv_names; - result->next_file = 1; - return std::move(result); -} - -unique_ptr SingleThreadedReadCSVInitLocal(ExecutionContext &context, - TableFunctionInitInput &input, - GlobalTableFunctionState *global_state_p) { - auto &bind_data = input.bind_data->CastNoConst(); - auto &data = global_state_p->Cast(); - auto result = make_uniq(); - result->csv_reader = data.GetCSVReader(context.client, bind_data, result->file_index, result->total_size); - return std::move(result); -} - -static void SingleThreadedCSVFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->CastNoConst(); - auto &data = data_p.global_state->Cast(); - auto &lstate = data_p.local_state->Cast(); - if (!lstate.csv_reader) { - // no csv_reader was set, this can happen when a filename-based filter has filtered out all possible files - return; - } - - do { - lstate.csv_reader->ParseCSV(output); - // update the number of bytes read - D_ASSERT(lstate.bytes_read <= lstate.csv_reader->bytes_in_chunk); - auto bytes_read = MinValue(lstate.total_size, lstate.csv_reader->bytes_in_chunk); - auto current_progress = lstate.total_size == 0 ? 
100 : 100 * bytes_read / lstate.total_size; - if (current_progress > lstate.current_progress) { - if (current_progress > 100) { - throw InternalException("Progress should never exceed 100"); - } - data.progress_in_files += current_progress - lstate.current_progress; - lstate.current_progress = current_progress; - } - if (output.size() == 0) { - // exhausted this file, but we might have more files we can read - auto csv_reader = data.GetCSVReader(context, bind_data, lstate.file_index, lstate.total_size); - // add any left-over progress for this file to the progress bar - if (lstate.current_progress < 100) { - data.progress_in_files += 100 - lstate.current_progress; - } - // reset the current progress - lstate.current_progress = 0; - lstate.bytes_read = 0; - lstate.csv_reader = std::move(csv_reader); - if (!lstate.csv_reader) { - // no more files - we are done - return; - } - lstate.bytes_read = 0; - } else { - MultiFileReader::FinalizeChunk(bind_data.reader_bind, lstate.csv_reader->reader_data, output); - break; - } - } while (true); -} - -//===--------------------------------------------------------------------===// -// Read CSV Functions -//===--------------------------------------------------------------------===// -static unique_ptr ReadCSVInitGlobal(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - - // Create the temporary rejects table - auto rejects_table = bind_data.options.rejects_table_name; - if (!rejects_table.empty()) { - CSVRejectsTable::GetOrCreate(context, rejects_table)->InitializeTable(context, bind_data); - } - if (bind_data.single_threaded) { - return SingleThreadedCSVInit(context, input); - } else { - return ParallelCSVInitGlobal(context, input); - } -} - -unique_ptr ReadCSVInitLocal(ExecutionContext &context, TableFunctionInitInput &input, - GlobalTableFunctionState *global_state_p) { - auto &csv_data = input.bind_data->Cast(); - if (csv_data.single_threaded) { - return 
SingleThreadedReadCSVInitLocal(context, input, global_state_p); - } else { - return ParallelReadCSVInitLocal(context, input, global_state_p); - } -} - -static void ReadCSVFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - if (bind_data.single_threaded) { - SingleThreadedCSVFunction(context, data_p, output); - } else { - ParallelReadCSVFunction(context, data_p, output); - } -} - -static idx_t CSVReaderGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p, - LocalTableFunctionState *local_state, GlobalTableFunctionState *global_state) { - auto &bind_data = bind_data_p->Cast(); - if (bind_data.single_threaded) { - auto &data = local_state->Cast(); - return data.file_index; - } - auto &data = local_state->Cast(); - return data.csv_reader->buffer->batch_index; -} - -static void ReadCSVAddNamedParameters(TableFunction &table_function) { - table_function.named_parameters["sep"] = LogicalType::VARCHAR; - table_function.named_parameters["delim"] = LogicalType::VARCHAR; - table_function.named_parameters["quote"] = LogicalType::VARCHAR; - table_function.named_parameters["new_line"] = LogicalType::VARCHAR; - table_function.named_parameters["escape"] = LogicalType::VARCHAR; - table_function.named_parameters["nullstr"] = LogicalType::VARCHAR; - table_function.named_parameters["columns"] = LogicalType::ANY; - table_function.named_parameters["auto_type_candidates"] = LogicalType::ANY; - table_function.named_parameters["header"] = LogicalType::BOOLEAN; - table_function.named_parameters["auto_detect"] = LogicalType::BOOLEAN; - table_function.named_parameters["sample_size"] = LogicalType::BIGINT; - table_function.named_parameters["all_varchar"] = LogicalType::BOOLEAN; - table_function.named_parameters["dateformat"] = LogicalType::VARCHAR; - table_function.named_parameters["timestampformat"] = LogicalType::VARCHAR; - table_function.named_parameters["normalize_names"] = LogicalType::BOOLEAN; 
- table_function.named_parameters["compression"] = LogicalType::VARCHAR; - table_function.named_parameters["skip"] = LogicalType::BIGINT; - table_function.named_parameters["max_line_size"] = LogicalType::VARCHAR; - table_function.named_parameters["maximum_line_size"] = LogicalType::VARCHAR; - table_function.named_parameters["ignore_errors"] = LogicalType::BOOLEAN; - table_function.named_parameters["rejects_table"] = LogicalType::VARCHAR; - table_function.named_parameters["rejects_limit"] = LogicalType::BIGINT; - table_function.named_parameters["rejects_recovery_columns"] = LogicalType::LIST(LogicalType::VARCHAR); - table_function.named_parameters["buffer_size"] = LogicalType::UBIGINT; - table_function.named_parameters["decimal_separator"] = LogicalType::VARCHAR; - table_function.named_parameters["parallel"] = LogicalType::BOOLEAN; - table_function.named_parameters["null_padding"] = LogicalType::BOOLEAN; - table_function.named_parameters["allow_quoted_nulls"] = LogicalType::BOOLEAN; - table_function.named_parameters["column_types"] = LogicalType::ANY; - table_function.named_parameters["dtypes"] = LogicalType::ANY; - table_function.named_parameters["types"] = LogicalType::ANY; - table_function.named_parameters["names"] = LogicalType::LIST(LogicalType::VARCHAR); - table_function.named_parameters["column_names"] = LogicalType::LIST(LogicalType::VARCHAR); - MultiFileReader::AddParameters(table_function); -} - -double CSVReaderProgress(ClientContext &context, const FunctionData *bind_data_p, - const GlobalTableFunctionState *global_state) { - auto &bind_data = bind_data_p->Cast(); - if (bind_data.single_threaded) { - auto &data = global_state->Cast(); - return data.GetProgress(bind_data); - } else { - auto &data = global_state->Cast(); - return data.GetProgress(bind_data); - } -} - -void CSVComplexFilterPushdown(ClientContext &context, LogicalGet &get, FunctionData *bind_data_p, - vector> &filters) { - auto &data = bind_data_p->Cast(); - auto reset_reader = - 
MultiFileReader::ComplexFilterPushdown(context, data.files, data.options.file_options, get, filters); - if (reset_reader) { - MultiFileReader::PruneReaders(data); - } -} - -unique_ptr CSVReaderCardinality(ClientContext &context, const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - idx_t per_file_cardinality = 0; - if (bind_data.buffer_manager && bind_data.buffer_manager->file_handle) { - auto estimated_row_width = (bind_data.csv_types.size() * 5); - per_file_cardinality = bind_data.buffer_manager->file_handle->FileSize() / estimated_row_width; - } else { - // determined through the scientific method as the average amount of rows in a CSV file - per_file_cardinality = 42; - } - return make_uniq(bind_data.files.size() * per_file_cardinality); -} - -static void CSVReaderSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const TableFunction &function) { - auto &bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "extra_info", function.extra_info); - serializer.WriteProperty(101, "csv_data", &bind_data); -} - -static unique_ptr CSVReaderDeserialize(Deserializer &deserializer, TableFunction &function) { - unique_ptr result; - deserializer.ReadProperty(100, "extra_info", function.extra_info); - deserializer.ReadProperty(101, "csv_data", result); - return std::move(result); -} - -TableFunction ReadCSVTableFunction::GetFunction() { - TableFunction read_csv("read_csv", {LogicalType::VARCHAR}, ReadCSVFunction, ReadCSVBind, ReadCSVInitGlobal, - ReadCSVInitLocal); - read_csv.table_scan_progress = CSVReaderProgress; - read_csv.pushdown_complex_filter = CSVComplexFilterPushdown; - read_csv.serialize = CSVReaderSerialize; - read_csv.deserialize = CSVReaderDeserialize; - read_csv.get_batch_index = CSVReaderGetBatchIndex; - read_csv.cardinality = CSVReaderCardinality; - read_csv.projection_pushdown = true; - ReadCSVAddNamedParameters(read_csv); - return read_csv; -} - -TableFunction ReadCSVTableFunction::GetAutoFunction() { - 
auto read_csv_auto = ReadCSVTableFunction::GetFunction(); - read_csv_auto.name = "read_csv_auto"; - read_csv_auto.bind = ReadCSVAutoBind; - return read_csv_auto; -} - -void ReadCSVTableFunction::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(MultiFileReader::CreateFunctionSet(ReadCSVTableFunction::GetFunction())); - set.AddFunction(MultiFileReader::CreateFunctionSet(ReadCSVTableFunction::GetAutoFunction())); -} - -unique_ptr ReadCSVReplacement(ClientContext &context, const string &table_name, ReplacementScanData *data) { - auto lower_name = StringUtil::Lower(table_name); - // remove any compression - if (StringUtil::EndsWith(lower_name, ".gz")) { - lower_name = lower_name.substr(0, lower_name.size() - 3); - } else if (StringUtil::EndsWith(lower_name, ".zst")) { - if (!Catalog::TryAutoLoad(context, "parquet")) { - throw MissingExtensionException("parquet extension is required for reading zst compressed file"); - } - lower_name = lower_name.substr(0, lower_name.size() - 4); - } - if (!StringUtil::EndsWith(lower_name, ".csv") && !StringUtil::Contains(lower_name, ".csv?") && - !StringUtil::EndsWith(lower_name, ".tsv") && !StringUtil::Contains(lower_name, ".tsv?")) { - return nullptr; - } - auto table_function = make_uniq(); - vector> children; - children.push_back(make_uniq(Value(table_name))); - table_function->function = make_uniq("read_csv_auto", std::move(children)); - - if (!FileSystem::HasGlob(table_name)) { - auto &fs = FileSystem::GetFileSystem(context); - table_function->alias = fs.ExtractBaseName(table_name); - } - - return std::move(table_function); -} - -void BuiltinFunctions::RegisterReadFunctions() { - CSVCopyFunction::RegisterFunction(*this); - ReadCSVTableFunction::RegisterFunction(*this); - auto &config = DBConfig::GetConfig(*transaction.db); - config.replacement_scans.emplace_back(ReadCSVReplacement); -} - -} // namespace duckdb - - - -namespace duckdb { - -struct RepeatFunctionData : public TableFunctionData { - RepeatFunctionData(Value 
value, idx_t target_count) : value(std::move(value)), target_count(target_count) { - } - - Value value; - idx_t target_count; -}; - -struct RepeatOperatorData : public GlobalTableFunctionState { - RepeatOperatorData() : current_count(0) { - } - idx_t current_count; -}; - -static unique_ptr RepeatBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - // the repeat function returns the type of the first argument - auto &inputs = input.inputs; - return_types.push_back(inputs[0].type()); - names.push_back(inputs[0].ToString()); - if (inputs[1].IsNull()) { - throw BinderException("Repeat second parameter cannot be NULL"); - } - return make_uniq(inputs[0], inputs[1].GetValue()); -} - -static unique_ptr RepeatInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static void RepeatFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = data_p.global_state->Cast(); - - idx_t remaining = MinValue(bind_data.target_count - state.current_count, STANDARD_VECTOR_SIZE); - output.data[0].Reference(bind_data.value); - output.SetCardinality(remaining); - state.current_count += remaining; -} - -static unique_ptr RepeatCardinality(ClientContext &context, const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - return make_uniq(bind_data.target_count, bind_data.target_count); -} - -void RepeatTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction repeat("repeat", {LogicalType::ANY, LogicalType::BIGINT}, RepeatFunction, RepeatBind, RepeatInit); - repeat.cardinality = RepeatCardinality; - set.AddFunction(repeat); -} - -} // namespace duckdb - - - -namespace duckdb { - -struct RepeatRowFunctionData : public TableFunctionData { - RepeatRowFunctionData(vector values, idx_t target_count) - : values(std::move(values)), target_count(target_count) { - } - - const vector values; - idx_t 
target_count; -}; - -struct RepeatRowOperatorData : public GlobalTableFunctionState { - RepeatRowOperatorData() : current_count(0) { - } - idx_t current_count; -}; - -static unique_ptr RepeatRowBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto &inputs = input.inputs; - for (idx_t input_idx = 0; input_idx < inputs.size(); input_idx++) { - return_types.push_back(inputs[input_idx].type()); - names.push_back("column" + std::to_string(input_idx)); - } - auto entry = input.named_parameters.find("num_rows"); - if (entry == input.named_parameters.end()) { - throw BinderException("repeat_rows requires num_rows to be specified"); - } - if (inputs.empty()) { - throw BinderException("repeat_rows requires at least one column to be specified"); - } - return make_uniq(inputs, entry->second.GetValue()); -} - -static unique_ptr RepeatRowInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static void RepeatRowFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = data_p.global_state->Cast(); - - idx_t remaining = MinValue(bind_data.target_count - state.current_count, STANDARD_VECTOR_SIZE); - for (idx_t val_idx = 0; val_idx < bind_data.values.size(); val_idx++) { - output.data[val_idx].Reference(bind_data.values[val_idx]); - } - output.SetCardinality(remaining); - state.current_count += remaining; -} - -static unique_ptr RepeatRowCardinality(ClientContext &context, const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - return make_uniq(bind_data.target_count, bind_data.target_count); -} - -void RepeatRowTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction repeat_row("repeat_row", {}, RepeatRowFunction, RepeatRowBind, RepeatRowInit); - repeat_row.varargs = LogicalType::ANY; - repeat_row.named_parameters["num_rows"] = LogicalType::BIGINT; - 
repeat_row.cardinality = RepeatRowCardinality; - set.AddFunction(repeat_row); -} - -} // namespace duckdb - - - - - -// this function makes not that much sense on its own but is a demo for table-parameter table-producing functions - -namespace duckdb { - -static unique_ptr SummaryFunctionBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("summary"); - - for (idx_t i = 0; i < input.input_table_types.size(); i++) { - return_types.push_back(input.input_table_types[i]); - names.emplace_back(input.input_table_names[i]); - } - - return make_uniq(); -} - -static OperatorResultType SummaryFunction(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, - DataChunk &output) { - output.SetCardinality(input.size()); - - for (idx_t row_idx = 0; row_idx < input.size(); row_idx++) { - string summary_val = "["; - - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { - summary_val += input.GetValue(col_idx, row_idx).ToString(); - if (col_idx < input.ColumnCount() - 1) { - summary_val += ", "; - } - } - summary_val += "]"; - output.SetValue(0, row_idx, Value(summary_val)); - } - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { - output.data[col_idx + 1].Reference(input.data[col_idx]); - } - return OperatorResultType::NEED_MORE_INPUT; -} - -void SummaryTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction summary_function("summary", {LogicalType::TABLE}, nullptr, SummaryFunctionBind); - summary_function.in_out_function = SummaryFunction; - set.AddFunction(summary_function); -} - -} // namespace duckdb - - - - - - - - - - -#include - -namespace duckdb { - -struct DuckDBColumnsData : public GlobalTableFunctionState { - DuckDBColumnsData() : offset(0), column_offset(0) { - } - - vector> entries; - idx_t offset; - idx_t column_offset; -}; - -static unique_ptr DuckDBColumnsBind(ClientContext 
&context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("table_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("table_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("column_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("column_index"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("internal"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("column_default"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("is_nullable"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("data_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("data_type_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("character_maximum_length"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("numeric_precision"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("numeric_precision_radix"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("numeric_scale"); - return_types.emplace_back(LogicalType::INTEGER); - - return nullptr; -} - -unique_ptr DuckDBColumnsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and views and collect them - auto schemas = Catalog::GetAllSchemas(context); - for (auto &schema : schemas) { - schema.get().Scan(context, CatalogType::TABLE_ENTRY, - 
[&](CatalogEntry &entry) { result->entries.push_back(entry); }); - } - return std::move(result); -} - -class ColumnHelper { -public: - static unique_ptr Create(CatalogEntry &entry); - - virtual ~ColumnHelper() { - } - - virtual StandardEntry &Entry() = 0; - virtual idx_t NumColumns() = 0; - virtual const string &ColumnName(idx_t col) = 0; - virtual const LogicalType &ColumnType(idx_t col) = 0; - virtual const Value ColumnDefault(idx_t col) = 0; - virtual bool IsNullable(idx_t col) = 0; - - void WriteColumns(idx_t index, idx_t start_col, idx_t end_col, DataChunk &output); -}; - -class TableColumnHelper : public ColumnHelper { -public: - explicit TableColumnHelper(TableCatalogEntry &entry) : entry(entry) { - for (auto &constraint : entry.GetConstraints()) { - if (constraint->type == ConstraintType::NOT_NULL) { - auto ¬_null = *reinterpret_cast(constraint.get()); - not_null_cols.insert(not_null.index.index); - } - } - } - - StandardEntry &Entry() override { - return entry; - } - idx_t NumColumns() override { - return entry.GetColumns().LogicalColumnCount(); - } - const string &ColumnName(idx_t col) override { - return entry.GetColumn(LogicalIndex(col)).Name(); - } - const LogicalType &ColumnType(idx_t col) override { - return entry.GetColumn(LogicalIndex(col)).Type(); - } - const Value ColumnDefault(idx_t col) override { - auto &column = entry.GetColumn(LogicalIndex(col)); - if (column.Generated()) { - return Value(column.GeneratedExpression().ToString()); - } else if (column.DefaultValue()) { - return Value(column.DefaultValue()->ToString()); - } - return Value(); - } - bool IsNullable(idx_t col) override { - return not_null_cols.find(col) == not_null_cols.end(); - } - -private: - TableCatalogEntry &entry; - std::set not_null_cols; -}; - -class ViewColumnHelper : public ColumnHelper { -public: - explicit ViewColumnHelper(ViewCatalogEntry &entry) : entry(entry) { - } - - StandardEntry &Entry() override { - return entry; - } - idx_t NumColumns() override { - return 
entry.types.size(); - } - const string &ColumnName(idx_t col) override { - return entry.aliases[col]; - } - const LogicalType &ColumnType(idx_t col) override { - return entry.types[col]; - } - const Value ColumnDefault(idx_t col) override { - return Value(); - } - bool IsNullable(idx_t col) override { - return true; - } - -private: - ViewCatalogEntry &entry; -}; - -unique_ptr ColumnHelper::Create(CatalogEntry &entry) { - switch (entry.type) { - case CatalogType::TABLE_ENTRY: - return make_uniq(entry.Cast()); - case CatalogType::VIEW_ENTRY: - return make_uniq(entry.Cast()); - default: - throw NotImplementedException("Unsupported catalog type for duckdb_columns"); - } -} - -void ColumnHelper::WriteColumns(idx_t start_index, idx_t start_col, idx_t end_col, DataChunk &output) { - for (idx_t i = start_col; i < end_col; i++) { - auto index = start_index + (i - start_col); - auto &entry = Entry(); - - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, index, entry.catalog.GetName()); - // database_oid, BIGINT - output.SetValue(col++, index, Value::BIGINT(entry.catalog.GetOid())); - // schema_name, VARCHAR - output.SetValue(col++, index, entry.schema.name); - // schema_oid, BIGINT - output.SetValue(col++, index, Value::BIGINT(entry.schema.oid)); - // table_name, VARCHAR - output.SetValue(col++, index, entry.name); - // table_oid, BIGINT - output.SetValue(col++, index, Value::BIGINT(entry.oid)); - // column_name, VARCHAR - output.SetValue(col++, index, Value(ColumnName(i))); - // column_index, INTEGER - output.SetValue(col++, index, Value::INTEGER(i + 1)); - // internal, BOOLEAN - output.SetValue(col++, index, Value::BOOLEAN(entry.internal)); - // column_default, VARCHAR - output.SetValue(col++, index, Value(ColumnDefault(i))); - // is_nullable, BOOLEAN - output.SetValue(col++, index, Value::BOOLEAN(IsNullable(i))); - // data_type, VARCHAR - const LogicalType &type = ColumnType(i); - output.SetValue(col++, index, Value(type.ToString())); - // data_type_id, 
BIGINT - output.SetValue(col++, index, Value::BIGINT(int(type.id()))); - if (type == LogicalType::VARCHAR) { - // FIXME: need check constraints in place to set this correctly - // character_maximum_length, INTEGER - output.SetValue(col++, index, Value()); - } else { - // "character_maximum_length", PhysicalType::INTEGER - output.SetValue(col++, index, Value()); - } - - Value numeric_precision, numeric_scale, numeric_precision_radix; - switch (type.id()) { - case LogicalTypeId::DECIMAL: - numeric_precision = Value::INTEGER(DecimalType::GetWidth(type)); - numeric_scale = Value::INTEGER(DecimalType::GetScale(type)); - numeric_precision_radix = Value::INTEGER(10); - break; - case LogicalTypeId::HUGEINT: - numeric_precision = Value::INTEGER(128); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - case LogicalTypeId::BIGINT: - numeric_precision = Value::INTEGER(64); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - case LogicalTypeId::INTEGER: - numeric_precision = Value::INTEGER(32); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - case LogicalTypeId::SMALLINT: - numeric_precision = Value::INTEGER(16); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - case LogicalTypeId::TINYINT: - numeric_precision = Value::INTEGER(8); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - case LogicalTypeId::FLOAT: - numeric_precision = Value::INTEGER(24); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - case LogicalTypeId::DOUBLE: - numeric_precision = Value::INTEGER(53); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - default: - numeric_precision = Value(); - numeric_scale = Value(); - numeric_precision_radix = Value(); - break; - } - - // numeric_precision, INTEGER - 
output.SetValue(col++, index, numeric_precision); - // numeric_precision_radix, INTEGER - output.SetValue(col++, index, numeric_precision_radix); - // numeric_scale, INTEGER - output.SetValue(col++, index, numeric_scale); - } -} - -void DuckDBColumnsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - - // We need to track the offset of the relation we're writing as well as the last column - // we wrote from that relation (if any); it's possible that we can fill up the output - // with a partial list of columns from a relation and will need to pick up processing the - // next chunk at the same spot. - idx_t next = data.offset; - idx_t column_offset = data.column_offset; - idx_t index = 0; - while (next < data.entries.size() && index < STANDARD_VECTOR_SIZE) { - auto column_helper = ColumnHelper::Create(data.entries[next].get()); - idx_t columns = column_helper->NumColumns(); - - // Check to see if we are going to exceed the maximum index for a DataChunk - if (index + (columns - column_offset) > STANDARD_VECTOR_SIZE) { - idx_t column_limit = column_offset + (STANDARD_VECTOR_SIZE - index); - output.SetCardinality(STANDARD_VECTOR_SIZE); - column_helper->WriteColumns(index, column_offset, column_limit, output); - - // Make the current column limit the column offset when we process the next chunk - column_offset = column_limit; - break; - } else { - // Otherwise, write all of the columns from the current relation and - // then move on to the next one. 
- output.SetCardinality(index + (columns - column_offset)); - column_helper->WriteColumns(index, column_offset, columns, output); - index += columns - column_offset; - next++; - column_offset = 0; - } - } - data.offset = next; - data.column_offset = column_offset; -} - -void DuckDBColumnsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_columns", {}, DuckDBColumnsFunction, DuckDBColumnsBind, DuckDBColumnsInit)); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - -namespace duckdb { - -struct UniqueKeyInfo { - string schema; - string table; - vector columns; - - bool operator==(const UniqueKeyInfo &other) const { - return (schema == other.schema) && (table == other.table) && (columns == other.columns); - } -}; - -} // namespace duckdb - -namespace std { - -template <> -struct hash { - template - static size_t ComputeHash(const X &x) { - return hash()(x); - } - - size_t operator()(const duckdb::UniqueKeyInfo &j) const { - D_ASSERT(j.columns.size() > 0); - return ComputeHash(j.schema) + ComputeHash(j.table) + ComputeHash(j.columns[0].index); - } -}; - -} // namespace std - -namespace duckdb { - -struct DuckDBConstraintsData : public GlobalTableFunctionState { - DuckDBConstraintsData() : offset(0), constraint_offset(0), unique_constraint_offset(0) { - } - - vector> entries; - idx_t offset; - idx_t constraint_offset; - idx_t unique_constraint_offset; - unordered_map known_fk_unique_constraint_offsets; -}; - -static unique_ptr DuckDBConstraintsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - 
names.emplace_back("table_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("table_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("constraint_index"); - return_types.emplace_back(LogicalType::BIGINT); - - // CHECK, PRIMARY KEY or UNIQUE - names.emplace_back("constraint_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("constraint_text"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("expression"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("constraint_column_indexes"); - return_types.push_back(LogicalType::LIST(LogicalType::BIGINT)); - - names.emplace_back("constraint_column_names"); - return_types.push_back(LogicalType::LIST(LogicalType::VARCHAR)); - - return nullptr; -} - -unique_ptr DuckDBConstraintsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and collect them - auto schemas = Catalog::GetAllSchemas(context); - - for (auto &schema : schemas) { - vector> entries; - - schema.get().Scan(context, CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) { - if (entry.type == CatalogType::TABLE_ENTRY) { - entries.push_back(entry); - } - }); - - sort(entries.begin(), entries.end(), [&](CatalogEntry &x, CatalogEntry &y) { return (x.name < y.name); }); - - result->entries.insert(result->entries.end(), entries.begin(), entries.end()); - }; - - return std::move(result); -} - -void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = 
data.entries[data.offset].get(); - D_ASSERT(entry.type == CatalogType::TABLE_ENTRY); - - auto &table = entry.Cast(); - auto &constraints = table.GetConstraints(); - bool is_duck_table = table.IsDuckTable(); - for (; data.constraint_offset < constraints.size() && count < STANDARD_VECTOR_SIZE; data.constraint_offset++) { - auto &constraint = constraints[data.constraint_offset]; - // return values: - // constraint_type, VARCHAR - // Processing this first due to shortcut (early continue) - string constraint_type; - switch (constraint->type) { - case ConstraintType::CHECK: - constraint_type = "CHECK"; - break; - case ConstraintType::UNIQUE: { - auto &unique = constraint->Cast(); - constraint_type = unique.is_primary_key ? "PRIMARY KEY" : "UNIQUE"; - break; - } - case ConstraintType::NOT_NULL: - constraint_type = "NOT NULL"; - break; - case ConstraintType::FOREIGN_KEY: { - if (!is_duck_table) { - continue; - } - auto &bound_constraints = table.GetBoundConstraints(); - auto &bound_foreign_key = bound_constraints[data.constraint_offset]->Cast(); - if (bound_foreign_key.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE) { - // Those are already covered by PRIMARY KEY and UNIQUE entries - continue; - } - constraint_type = "FOREIGN KEY"; - break; - } - default: - throw NotImplementedException("Unimplemented constraint for duckdb_constraints"); - } - - idx_t col = 0; - // database_name, LogicalType::VARCHAR - output.SetValue(col++, count, Value(table.schema.catalog.GetName())); - // database_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(table.schema.catalog.GetOid())); - // schema_name, LogicalType::VARCHAR - output.SetValue(col++, count, Value(table.schema.name)); - // schema_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(table.schema.oid)); - // table_name, LogicalType::VARCHAR - output.SetValue(col++, count, Value(table.name)); - // table_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(table.oid)); 
- - // constraint_index, BIGINT - UniqueKeyInfo uk_info; - - if (is_duck_table) { - auto &bound_constraint = *table.GetBoundConstraints()[data.constraint_offset]; - switch (bound_constraint.type) { - case ConstraintType::UNIQUE: { - auto &bound_unique = bound_constraint.Cast(); - uk_info = {table.schema.name, table.name, bound_unique.keys}; - break; - } - case ConstraintType::FOREIGN_KEY: { - const auto &bound_foreign_key = bound_constraint.Cast(); - const auto &info = bound_foreign_key.info; - // find the other table - auto table_entry = Catalog::GetEntry( - context, table.catalog.GetName(), info.schema, info.table, OnEntryNotFound::RETURN_NULL); - if (!table_entry) { - throw InternalException("dukdb_constraints: entry %s.%s referenced in foreign key not found", - info.schema, info.table); - } - vector index; - for (auto &key : info.pk_keys) { - index.push_back(table_entry->GetColumns().PhysicalToLogical(key)); - } - uk_info = {table_entry->schema.name, table_entry->name, index}; - break; - } - default: - break; - } - } - - if (uk_info.columns.empty()) { - output.SetValue(col++, count, Value::BIGINT(data.unique_constraint_offset++)); - } else { - auto known_unique_constraint_offset = data.known_fk_unique_constraint_offsets.find(uk_info); - if (known_unique_constraint_offset == data.known_fk_unique_constraint_offsets.end()) { - data.known_fk_unique_constraint_offsets.insert(make_pair(uk_info, data.unique_constraint_offset)); - output.SetValue(col++, count, Value::BIGINT(data.unique_constraint_offset)); - data.unique_constraint_offset++; - } else { - output.SetValue(col++, count, Value::BIGINT(known_unique_constraint_offset->second)); - } - } - output.SetValue(col++, count, Value(constraint_type)); - - // constraint_text, VARCHAR - output.SetValue(col++, count, Value(constraint->ToString())); - - // expression, VARCHAR - Value expression_text; - if (constraint->type == ConstraintType::CHECK) { - auto &check = constraint->Cast(); - expression_text = 
Value(check.expression->ToString()); - } - output.SetValue(col++, count, expression_text); - - vector column_index_list; - if (is_duck_table) { - auto &bound_constraint = *table.GetBoundConstraints()[data.constraint_offset]; - switch (bound_constraint.type) { - case ConstraintType::CHECK: { - auto &bound_check = bound_constraint.Cast(); - for (auto &col_idx : bound_check.bound_columns) { - column_index_list.push_back(table.GetColumns().PhysicalToLogical(col_idx)); - } - break; - } - case ConstraintType::UNIQUE: { - auto &bound_unique = bound_constraint.Cast(); - for (auto &col_idx : bound_unique.keys) { - column_index_list.push_back(col_idx); - } - break; - } - case ConstraintType::NOT_NULL: { - auto &bound_not_null = bound_constraint.Cast(); - column_index_list.push_back(table.GetColumns().PhysicalToLogical(bound_not_null.index)); - break; - } - case ConstraintType::FOREIGN_KEY: { - auto &bound_foreign_key = bound_constraint.Cast(); - for (auto &col_idx : bound_foreign_key.info.fk_keys) { - column_index_list.push_back(table.GetColumns().PhysicalToLogical(col_idx)); - } - break; - } - default: - throw NotImplementedException("Unimplemented constraint for duckdb_constraints"); - } - } - - vector index_list; - vector column_name_list; - for (auto column_index : column_index_list) { - index_list.push_back(Value::BIGINT(column_index.index)); - column_name_list.emplace_back(table.GetColumn(column_index).Name()); - } - - // constraint_column_indexes, LIST - output.SetValue(col++, count, Value::LIST(LogicalType::BIGINT, std::move(index_list))); - - // constraint_column_names, LIST - output.SetValue(col++, count, Value::LIST(LogicalType::VARCHAR, std::move(column_name_list))); - - count++; - } - if (data.constraint_offset >= constraints.size()) { - data.constraint_offset = 0; - data.offset++; - } - } - output.SetCardinality(count); -} - -void DuckDBConstraintsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_constraints", {}, 
DuckDBConstraintsFunction, DuckDBConstraintsBind, - DuckDBConstraintsInit)); -} - -} // namespace duckdb - - - - -namespace duckdb { - -struct DuckDBDatabasesData : public GlobalTableFunctionState { - DuckDBDatabasesData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -static unique_ptr DuckDBDatabasesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("path"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("internal"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("type"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBDatabasesInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and collect them and collect them - auto &db_manager = DatabaseManager::Get(context); - result->entries = db_manager.GetDatabases(context); - return std::move(result); -} - -void DuckDBDatabasesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset++]; - - auto &attached = entry.get().Cast(); - // return values: - - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, count, attached.GetName()); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(attached.oid)); - // path, VARCHAR - bool is_internal = attached.IsSystem() || 
attached.IsTemporary(); - Value db_path; - if (!is_internal) { - bool in_memory = attached.GetCatalog().InMemory(); - if (!in_memory) { - db_path = Value(attached.GetCatalog().GetDBPath()); - } - } - output.SetValue(col++, count, db_path); - // internal, BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(is_internal)); - // type, VARCHAR - output.SetValue(col++, count, Value(attached.GetCatalog().GetCatalogType())); - - count++; - } - output.SetCardinality(count); -} - -void DuckDBDatabasesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("duckdb_databases", {}, DuckDBDatabasesFunction, DuckDBDatabasesBind, DuckDBDatabasesInit)); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -struct DependencyInformation { - DependencyInformation(CatalogEntry &object, CatalogEntry &dependent, DependencyType type) - : object(object), dependent(dependent), type(type) { - } - - CatalogEntry &object; - CatalogEntry &dependent; - DependencyType type; -}; - -struct DuckDBDependenciesData : public GlobalTableFunctionState { - DuckDBDependenciesData() : offset(0) { - } - - vector entries; - idx_t offset; -}; - -static unique_ptr DuckDBDependenciesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("classid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("objid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("objsubid"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("refclassid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("refobjid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("refobjsubid"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("deptype"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBDependenciesInit(ClientContext &context, 
TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas and collect them - auto &catalog = Catalog::GetCatalog(context, INVALID_CATALOG); - if (catalog.IsDuckCatalog()) { - auto &duck_catalog = catalog.Cast(); - auto &dependency_manager = duck_catalog.GetDependencyManager(); - dependency_manager.Scan([&](CatalogEntry &obj, CatalogEntry &dependent, DependencyType type) { - result->entries.emplace_back(obj, dependent, type); - }); - } - - return std::move(result); -} - -void DuckDBDependenciesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset]; - - // return values: - // classid, LogicalType::BIGINT - output.SetValue(0, count, Value::BIGINT(0)); - // objid, LogicalType::BIGINT - output.SetValue(1, count, Value::BIGINT(entry.object.oid)); - // objsubid, LogicalType::INTEGER - output.SetValue(2, count, Value::INTEGER(0)); - // refclassid, LogicalType::BIGINT - output.SetValue(3, count, Value::BIGINT(0)); - // refobjid, LogicalType::BIGINT - output.SetValue(4, count, Value::BIGINT(entry.dependent.oid)); - // refobjsubid, LogicalType::INTEGER - output.SetValue(5, count, Value::INTEGER(0)); - // deptype, LogicalType::VARCHAR - string dependency_type_str; - switch (entry.type) { - case DependencyType::DEPENDENCY_REGULAR: - dependency_type_str = "n"; - break; - case DependencyType::DEPENDENCY_AUTOMATIC: - dependency_type_str = "a"; - break; - default: - throw NotImplementedException("Unimplemented dependency type"); - } - output.SetValue(6, count, Value(dependency_type_str)); - - data.offset++; - count++; - } - output.SetCardinality(count); -} - -void 
DuckDBDependenciesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_dependencies", {}, DuckDBDependenciesFunction, DuckDBDependenciesBind, - DuckDBDependenciesInit)); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -struct ExtensionInformation { - string name; - bool loaded = false; - bool installed = false; - string file_path; - string description; - vector aliases; -}; - -struct DuckDBExtensionsData : public GlobalTableFunctionState { - DuckDBExtensionsData() : offset(0) { - } - - vector entries; - idx_t offset; -}; - -static unique_ptr DuckDBExtensionsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("extension_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("loaded"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("installed"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("install_path"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("description"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("aliases"); - return_types.emplace_back(LogicalType::LIST(LogicalType::VARCHAR)); - - return nullptr; -} - -unique_ptr DuckDBExtensionsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - auto &fs = FileSystem::GetFileSystem(context); - auto &db = DatabaseInstance::GetDatabase(context); - - map installed_extensions; - auto extension_count = ExtensionHelper::DefaultExtensionCount(); - auto alias_count = ExtensionHelper::ExtensionAliasCount(); - for (idx_t i = 0; i < extension_count; i++) { - auto extension = ExtensionHelper::GetDefaultExtension(i); - ExtensionInformation info; - info.name = extension.name; - info.installed = extension.statically_loaded; - info.loaded = false; - info.file_path = extension.statically_loaded ? 
"(BUILT-IN)" : string(); - info.description = extension.description; - for (idx_t k = 0; k < alias_count; k++) { - auto alias = ExtensionHelper::GetExtensionAlias(k); - if (info.name == alias.extension) { - info.aliases.emplace_back(alias.alias); - } - } - installed_extensions[info.name] = std::move(info); - } -#ifndef WASM_LOADABLE_EXTENSIONS - // scan the install directory for installed extensions - auto ext_directory = ExtensionHelper::ExtensionDirectory(context); - fs.ListFiles(ext_directory, [&](const string &path, bool is_directory) { - if (!StringUtil::EndsWith(path, ".duckdb_extension")) { - return; - } - ExtensionInformation info; - info.name = fs.ExtractBaseName(path); - info.loaded = false; - info.file_path = fs.JoinPath(ext_directory, path); - auto entry = installed_extensions.find(info.name); - if (entry == installed_extensions.end()) { - installed_extensions[info.name] = std::move(info); - } else { - if (!entry->second.loaded) { - entry->second.file_path = info.file_path; - } - entry->second.installed = true; - } - }); -#endif - // now check the list of currently loaded extensions - auto &loaded_extensions = db.LoadedExtensions(); - for (auto &ext_name : loaded_extensions) { - auto entry = installed_extensions.find(ext_name); - if (entry == installed_extensions.end()) { - ExtensionInformation info; - info.name = ext_name; - info.loaded = true; - installed_extensions[ext_name] = std::move(info); - } else { - entry->second.loaded = true; - } - } - - result->entries.reserve(installed_extensions.size()); - for (auto &kv : installed_extensions) { - result->entries.push_back(std::move(kv.second)); - } - return std::move(result); -} - -void DuckDBExtensionsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining 
columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset]; - - // return values: - // extension_name LogicalType::VARCHAR - output.SetValue(0, count, Value(entry.name)); - // loaded LogicalType::BOOLEAN - output.SetValue(1, count, Value::BOOLEAN(entry.loaded)); - // installed LogicalType::BOOLEAN - output.SetValue(2, count, !entry.installed && entry.loaded ? Value() : Value::BOOLEAN(entry.installed)); - // install_path LogicalType::VARCHAR - output.SetValue(3, count, Value(entry.file_path)); - // description LogicalType::VARCHAR - output.SetValue(4, count, Value(entry.description)); - // aliases LogicalType::LIST(LogicalType::VARCHAR) - output.SetValue(5, count, Value::LIST(LogicalType::VARCHAR, entry.aliases)); - - data.offset++; - count++; - } - output.SetCardinality(count); -} - -void DuckDBExtensionsFun::RegisterFunction(BuiltinFunctions &set) { - TableFunctionSet functions("duckdb_extensions"); - functions.AddFunction(TableFunction({}, DuckDBExtensionsFunction, DuckDBExtensionsBind, DuckDBExtensionsInit)); - set.AddFunction(functions); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - -namespace duckdb { - -struct DuckDBFunctionsData : public GlobalTableFunctionState { - DuckDBFunctionsData() : offset(0), offset_in_entry(0) { - } - - vector> entries; - idx_t offset; - idx_t offset_in_entry; -}; - -static unique_ptr DuckDBFunctionsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("function_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("function_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("description"); - 
return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("return_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("parameters"); - return_types.push_back(LogicalType::LIST(LogicalType::VARCHAR)); - - names.emplace_back("parameter_types"); - return_types.push_back(LogicalType::LIST(LogicalType::VARCHAR)); - - names.emplace_back("varargs"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("macro_definition"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("has_side_effects"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("internal"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("function_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("example"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -static void ExtractFunctionsFromSchema(ClientContext &context, SchemaCatalogEntry &schema, - DuckDBFunctionsData &result) { - schema.Scan(context, CatalogType::SCALAR_FUNCTION_ENTRY, - [&](CatalogEntry &entry) { result.entries.push_back(entry); }); - schema.Scan(context, CatalogType::TABLE_FUNCTION_ENTRY, - [&](CatalogEntry &entry) { result.entries.push_back(entry); }); - schema.Scan(context, CatalogType::PRAGMA_FUNCTION_ENTRY, - [&](CatalogEntry &entry) { result.entries.push_back(entry); }); -} - -unique_ptr DuckDBFunctionsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and collect them and collect them - auto schemas = Catalog::GetAllSchemas(context); - for (auto &schema : schemas) { - ExtractFunctionsFromSchema(context, schema.get(), *result); - }; - - std::sort(result->entries.begin(), result->entries.end(), - [&](reference a, reference b) { - return (int32_t)a.get().type < (int32_t)b.get().type; - }); - return std::move(result); -} - -struct ScalarFunctionExtractor { - static 
idx_t FunctionCount(ScalarFunctionCatalogEntry &entry) { - return entry.functions.Size(); - } - - static Value GetFunctionType() { - return Value("scalar"); - } - - static Value GetReturnType(ScalarFunctionCatalogEntry &entry, idx_t offset) { - return Value(entry.functions.GetFunctionByOffset(offset).return_type.ToString()); - } - - static vector GetParameters(ScalarFunctionCatalogEntry &entry, idx_t offset) { - vector results; - for (idx_t i = 0; i < entry.functions.GetFunctionByOffset(offset).arguments.size(); i++) { - results.emplace_back("col" + to_string(i)); - } - return results; - } - - static Value GetParameterTypes(ScalarFunctionCatalogEntry &entry, idx_t offset) { - vector results; - auto fun = entry.functions.GetFunctionByOffset(offset); - for (idx_t i = 0; i < fun.arguments.size(); i++) { - results.emplace_back(fun.arguments[i].ToString()); - } - return Value::LIST(LogicalType::VARCHAR, std::move(results)); - } - - static Value GetVarArgs(ScalarFunctionCatalogEntry &entry, idx_t offset) { - auto fun = entry.functions.GetFunctionByOffset(offset); - return !fun.HasVarArgs() ? 
Value() : Value(fun.varargs.ToString()); - } - - static Value GetMacroDefinition(ScalarFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static Value HasSideEffects(ScalarFunctionCatalogEntry &entry, idx_t offset) { - return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).side_effects == - FunctionSideEffects::HAS_SIDE_EFFECTS); - } -}; - -struct AggregateFunctionExtractor { - static idx_t FunctionCount(AggregateFunctionCatalogEntry &entry) { - return entry.functions.Size(); - } - - static Value GetFunctionType() { - return Value("aggregate"); - } - - static Value GetReturnType(AggregateFunctionCatalogEntry &entry, idx_t offset) { - return Value(entry.functions.GetFunctionByOffset(offset).return_type.ToString()); - } - - static vector GetParameters(AggregateFunctionCatalogEntry &entry, idx_t offset) { - vector results; - for (idx_t i = 0; i < entry.functions.GetFunctionByOffset(offset).arguments.size(); i++) { - results.emplace_back("col" + to_string(i)); - } - return results; - } - - static Value GetParameterTypes(AggregateFunctionCatalogEntry &entry, idx_t offset) { - vector results; - auto fun = entry.functions.GetFunctionByOffset(offset); - for (idx_t i = 0; i < fun.arguments.size(); i++) { - results.emplace_back(fun.arguments[i].ToString()); - } - return Value::LIST(LogicalType::VARCHAR, std::move(results)); - } - - static Value GetVarArgs(AggregateFunctionCatalogEntry &entry, idx_t offset) { - auto fun = entry.functions.GetFunctionByOffset(offset); - return !fun.HasVarArgs() ? 
Value() : Value(fun.varargs.ToString()); - } - - static Value GetMacroDefinition(AggregateFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static Value HasSideEffects(AggregateFunctionCatalogEntry &entry, idx_t offset) { - return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).side_effects == - FunctionSideEffects::HAS_SIDE_EFFECTS); - } -}; - -struct MacroExtractor { - static idx_t FunctionCount(ScalarMacroCatalogEntry &entry) { - return 1; - } - - static Value GetFunctionType() { - return Value("macro"); - } - - static Value GetReturnType(ScalarMacroCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static vector GetParameters(ScalarMacroCatalogEntry &entry, idx_t offset) { - vector results; - for (auto ¶m : entry.function->parameters) { - D_ASSERT(param->type == ExpressionType::COLUMN_REF); - auto &colref = param->Cast(); - results.emplace_back(colref.GetColumnName()); - } - for (auto ¶m_entry : entry.function->default_parameters) { - results.emplace_back(param_entry.first); - } - return results; - } - - static Value GetParameterTypes(ScalarMacroCatalogEntry &entry, idx_t offset) { - vector results; - for (idx_t i = 0; i < entry.function->parameters.size(); i++) { - results.emplace_back(LogicalType::VARCHAR); - } - for (idx_t i = 0; i < entry.function->default_parameters.size(); i++) { - results.emplace_back(LogicalType::VARCHAR); - } - return Value::LIST(LogicalType::VARCHAR, std::move(results)); - } - - static Value GetVarArgs(ScalarMacroCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static Value GetMacroDefinition(ScalarMacroCatalogEntry &entry, idx_t offset) { - D_ASSERT(entry.function->type == MacroType::SCALAR_MACRO); - auto &func = entry.function->Cast(); - return func.expression->ToString(); - } - - static Value HasSideEffects(ScalarMacroCatalogEntry &entry, idx_t offset) { - return Value(); - } -}; - -struct TableMacroExtractor { - static idx_t FunctionCount(TableMacroCatalogEntry &entry) { - 
return 1; - } - - static Value GetFunctionType() { - return Value("table_macro"); - } - - static Value GetReturnType(TableMacroCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static vector GetParameters(TableMacroCatalogEntry &entry, idx_t offset) { - vector results; - for (auto ¶m : entry.function->parameters) { - D_ASSERT(param->type == ExpressionType::COLUMN_REF); - auto &colref = param->Cast(); - results.emplace_back(colref.GetColumnName()); - } - for (auto ¶m_entry : entry.function->default_parameters) { - results.emplace_back(param_entry.first); - } - return results; - } - - static Value GetParameterTypes(TableMacroCatalogEntry &entry, idx_t offset) { - vector results; - for (idx_t i = 0; i < entry.function->parameters.size(); i++) { - results.emplace_back(LogicalType::VARCHAR); - } - for (idx_t i = 0; i < entry.function->default_parameters.size(); i++) { - results.emplace_back(LogicalType::VARCHAR); - } - return Value::LIST(LogicalType::VARCHAR, std::move(results)); - } - - static Value GetVarArgs(TableMacroCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static Value GetMacroDefinition(TableMacroCatalogEntry &entry, idx_t offset) { - if (entry.function->type == MacroType::SCALAR_MACRO) { - auto &func = entry.function->Cast(); - return func.expression->ToString(); - } - return Value(); - } - - static Value HasSideEffects(TableMacroCatalogEntry &entry, idx_t offset) { - return Value(); - } -}; - -struct TableFunctionExtractor { - static idx_t FunctionCount(TableFunctionCatalogEntry &entry) { - return entry.functions.Size(); - } - - static Value GetFunctionType() { - return Value("table"); - } - - static Value GetReturnType(TableFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static vector GetParameters(TableFunctionCatalogEntry &entry, idx_t offset) { - vector results; - auto fun = entry.functions.GetFunctionByOffset(offset); - for (idx_t i = 0; i < fun.arguments.size(); i++) { - results.emplace_back("col" + 
to_string(i)); - } - for (auto ¶m : fun.named_parameters) { - results.emplace_back(param.first); - } - return results; - } - - static Value GetParameterTypes(TableFunctionCatalogEntry &entry, idx_t offset) { - vector results; - auto fun = entry.functions.GetFunctionByOffset(offset); - - for (idx_t i = 0; i < fun.arguments.size(); i++) { - results.emplace_back(fun.arguments[i].ToString()); - } - for (auto ¶m : fun.named_parameters) { - results.emplace_back(param.second.ToString()); - } - return Value::LIST(LogicalType::VARCHAR, std::move(results)); - } - - static Value GetVarArgs(TableFunctionCatalogEntry &entry, idx_t offset) { - auto fun = entry.functions.GetFunctionByOffset(offset); - return !fun.HasVarArgs() ? Value() : Value(fun.varargs.ToString()); - } - - static Value GetMacroDefinition(TableFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static Value HasSideEffects(TableFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } -}; - -struct PragmaFunctionExtractor { - static idx_t FunctionCount(PragmaFunctionCatalogEntry &entry) { - return entry.functions.Size(); - } - - static Value GetFunctionType() { - return Value("pragma"); - } - - static Value GetReturnType(PragmaFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static vector GetParameters(PragmaFunctionCatalogEntry &entry, idx_t offset) { - vector results; - auto fun = entry.functions.GetFunctionByOffset(offset); - - for (idx_t i = 0; i < fun.arguments.size(); i++) { - results.emplace_back("col" + to_string(i)); - } - for (auto ¶m : fun.named_parameters) { - results.emplace_back(param.first); - } - return results; - } - - static Value GetParameterTypes(PragmaFunctionCatalogEntry &entry, idx_t offset) { - vector results; - auto fun = entry.functions.GetFunctionByOffset(offset); - - for (idx_t i = 0; i < fun.arguments.size(); i++) { - results.emplace_back(fun.arguments[i].ToString()); - } - for (auto ¶m : fun.named_parameters) { - 
results.emplace_back(param.second.ToString()); - } - return Value::LIST(LogicalType::VARCHAR, std::move(results)); - } - - static Value GetVarArgs(PragmaFunctionCatalogEntry &entry, idx_t offset) { - auto fun = entry.functions.GetFunctionByOffset(offset); - return !fun.HasVarArgs() ? Value() : Value(fun.varargs.ToString()); - } - - static Value GetMacroDefinition(PragmaFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static Value HasSideEffects(PragmaFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } -}; - -template -bool ExtractFunctionData(FunctionEntry &entry, idx_t function_idx, DataChunk &output, idx_t output_offset) { - auto &function = entry.Cast(); - idx_t col = 0; - - // database_name, LogicalType::VARCHAR - output.SetValue(col++, output_offset, Value(function.schema.catalog.GetName())); - - // schema_name, LogicalType::VARCHAR - output.SetValue(col++, output_offset, Value(function.schema.name)); - - // function_name, LogicalType::VARCHAR - output.SetValue(col++, output_offset, Value(function.name)); - - // function_type, LogicalType::VARCHAR - output.SetValue(col++, output_offset, Value(OP::GetFunctionType())); - - // function_description, LogicalType::VARCHAR - output.SetValue(col++, output_offset, entry.description.empty() ? 
Value() : entry.description); - - // return_type, LogicalType::VARCHAR - output.SetValue(col++, output_offset, OP::GetReturnType(function, function_idx)); - - // parameters, LogicalType::LIST(LogicalType::VARCHAR) - auto parameters = OP::GetParameters(function, function_idx); - for (idx_t param_idx = 0; param_idx < function.parameter_names.size() && param_idx < parameters.size(); - param_idx++) { - parameters[param_idx] = Value(function.parameter_names[param_idx]); - } - output.SetValue(col++, output_offset, Value::LIST(LogicalType::VARCHAR, std::move(parameters))); - - // parameter_types, LogicalType::LIST(LogicalType::VARCHAR) - output.SetValue(col++, output_offset, OP::GetParameterTypes(function, function_idx)); - - // varargs, LogicalType::VARCHAR - output.SetValue(col++, output_offset, OP::GetVarArgs(function, function_idx)); - - // macro_definition, LogicalType::VARCHAR - output.SetValue(col++, output_offset, OP::GetMacroDefinition(function, function_idx)); - - // has_side_effects, LogicalType::BOOLEAN - output.SetValue(col++, output_offset, OP::HasSideEffects(function, function_idx)); - - // internal, LogicalType::BOOLEAN - output.SetValue(col++, output_offset, Value::BOOLEAN(function.internal)); - - // function_oid, LogicalType::BIGINT - output.SetValue(col++, output_offset, Value::BIGINT(function.oid)); - - // example, LogicalType::VARCHAR - output.SetValue(col++, output_offset, entry.example.empty() ? 
Value() : entry.example); - - return function_idx + 1 == OP::FunctionCount(function); -} - -void DuckDBFunctionsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset].get().Cast(); - bool finished; - - switch (entry.type) { - case CatalogType::SCALAR_FUNCTION_ENTRY: - finished = ExtractFunctionData( - entry, data.offset_in_entry, output, count); - break; - case CatalogType::AGGREGATE_FUNCTION_ENTRY: - finished = ExtractFunctionData( - entry, data.offset_in_entry, output, count); - break; - case CatalogType::TABLE_MACRO_ENTRY: - finished = ExtractFunctionData(entry, data.offset_in_entry, - output, count); - break; - - case CatalogType::MACRO_ENTRY: - finished = ExtractFunctionData(entry, data.offset_in_entry, output, - count); - break; - case CatalogType::TABLE_FUNCTION_ENTRY: - finished = ExtractFunctionData( - entry, data.offset_in_entry, output, count); - break; - case CatalogType::PRAGMA_FUNCTION_ENTRY: - finished = ExtractFunctionData( - entry, data.offset_in_entry, output, count); - break; - default: - throw InternalException("FIXME: unrecognized function type in duckdb_functions"); - } - if (finished) { - // finished with this function, move to the next function - data.offset++; - data.offset_in_entry = 0; - } else { - // more functions remain - data.offset_in_entry++; - } - count++; - } - output.SetCardinality(count); -} - -void DuckDBFunctionsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("duckdb_functions", {}, DuckDBFunctionsFunction, DuckDBFunctionsBind, DuckDBFunctionsInit)); -} - -} // namespace duckdb - - - - - - - - - - - 
-namespace duckdb { - -struct DuckDBIndexesData : public GlobalTableFunctionState { - DuckDBIndexesData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -static unique_ptr DuckDBIndexesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("index_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("index_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("table_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("table_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("is_unique"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("is_primary"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("expressions"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("sql"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBIndexesInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and collect them and collect them - auto schemas = Catalog::GetAllSchemas(context); - for (auto &schema : schemas) { - schema.get().Scan(context, CatalogType::INDEX_ENTRY, - [&](CatalogEntry &entry) { result->entries.push_back(entry); }); - }; - return std::move(result); -} - -void DuckDBIndexesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= 
data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset++].get(); - - auto &index = entry.Cast(); - // return values: - - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, count, index.catalog.GetName()); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(index.catalog.GetOid())); - // schema_name, VARCHAR - output.SetValue(col++, count, Value(index.schema.name)); - // schema_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(index.schema.oid)); - // index_name, VARCHAR - output.SetValue(col++, count, Value(index.name)); - // index_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(index.oid)); - // find the table in the catalog - auto &table_entry = - index.schema.catalog.GetEntry(context, index.GetSchemaName(), index.GetTableName()); - // table_name, VARCHAR - output.SetValue(col++, count, Value(table_entry.name)); - // table_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(table_entry.oid)); - if (index.index) { - // is_unique, BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(index.index->IsUnique())); - // is_primary, BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(index.index->IsPrimary())); - } else { - output.SetValue(col++, count, Value()); - output.SetValue(col++, count, Value()); - } - // expressions, VARCHAR - output.SetValue(col++, count, Value()); - // sql, VARCHAR - auto sql = index.ToSQL(); - output.SetValue(col++, count, sql.empty() ? 
Value() : Value(std::move(sql))); - - count++; - } - output.SetCardinality(count); -} - -void DuckDBIndexesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_indexes", {}, DuckDBIndexesFunction, DuckDBIndexesBind, DuckDBIndexesInit)); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct DuckDBKeywordsData : public GlobalTableFunctionState { - DuckDBKeywordsData() : offset(0) { - } - - vector entries; - idx_t offset; -}; - -static unique_ptr DuckDBKeywordsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("keyword_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("keyword_category"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBKeywordsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - result->entries = Parser::KeywordList(); - return std::move(result); -} - -void DuckDBKeywordsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset++]; - - // keyword_name, VARCHAR - output.SetValue(0, count, Value(entry.name)); - // keyword_category, VARCHAR - string category_name; - switch (entry.category) { - case KeywordCategory::KEYWORD_RESERVED: - category_name = "reserved"; - break; - case KeywordCategory::KEYWORD_UNRESERVED: - category_name = "unreserved"; - break; - case KeywordCategory::KEYWORD_TYPE_FUNC: - category_name = "type_function"; - break; - case KeywordCategory::KEYWORD_COL_NAME: - category_name = "column_name"; - break; - default: - 
throw InternalException("Unrecognized keyword category"); - } - output.SetValue(1, count, Value(std::move(category_name))); - - count++; - } - output.SetCardinality(count); -} - -void DuckDBKeywordsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("duckdb_keywords", {}, DuckDBKeywordsFunction, DuckDBKeywordsBind, DuckDBKeywordsInit)); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -struct DuckDBSchemasData : public GlobalTableFunctionState { - DuckDBSchemasData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -static unique_ptr DuckDBSchemasBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("internal"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("sql"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBSchemasInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas and collect them - result->entries = Catalog::GetAllSchemas(context); - - return std::move(result); -} - -void DuckDBSchemasFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset].get(); - - // return 
values: - idx_t col = 0; - // "oid", PhysicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(entry.oid)); - // database_name, VARCHAR - output.SetValue(col++, count, entry.catalog.GetName()); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(entry.catalog.GetOid())); - // "schema_name", PhysicalType::VARCHAR - output.SetValue(col++, count, Value(entry.name)); - // "internal", PhysicalType::BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(entry.internal)); - // "sql", PhysicalType::VARCHAR - output.SetValue(col++, count, Value()); - - data.offset++; - count++; - } - output.SetCardinality(count); -} - -void DuckDBSchemasFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_schemas", {}, DuckDBSchemasFunction, DuckDBSchemasBind, DuckDBSchemasInit)); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -struct DuckDBSequencesData : public GlobalTableFunctionState { - DuckDBSequencesData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -static unique_ptr DuckDBSequencesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("sequence_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("sequence_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("temporary"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("start_value"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("min_value"); - return_types.emplace_back(LogicalType::BIGINT); - - 
names.emplace_back("max_value"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("increment_by"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("cycle"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("last_value"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("sql"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBSequencesInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and collect themand collect them - auto schemas = Catalog::GetAllSchemas(context); - for (auto &schema : schemas) { - schema.get().Scan(context, CatalogType::SEQUENCE_ENTRY, - [&](CatalogEntry &entry) { result->entries.push_back(entry.Cast()); }); - }; - return std::move(result); -} - -void DuckDBSequencesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &seq = data.entries[data.offset++].get(); - - // return values: - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, count, seq.catalog.GetName()); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq.catalog.GetOid())); - // schema_name, VARCHAR - output.SetValue(col++, count, Value(seq.schema.name)); - // schema_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq.schema.oid)); - // sequence_name, VARCHAR - output.SetValue(col++, count, Value(seq.name)); - // sequence_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq.oid)); - // temporary, BOOLEAN - output.SetValue(col++, count, 
Value::BOOLEAN(seq.temporary)); - // start_value, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq.start_value)); - // min_value, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq.min_value)); - // max_value, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq.max_value)); - // increment_by, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq.increment)); - // cycle, BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(seq.cycle)); - // last_value, BIGINT - output.SetValue(col++, count, seq.usage_count == 0 ? Value() : Value::BOOLEAN(seq.last_value)); - // sql, LogicalType::VARCHAR - output.SetValue(col++, count, Value(seq.ToSQL())); - - count++; - } - output.SetCardinality(count); -} - -void DuckDBSequencesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("duckdb_sequences", {}, DuckDBSequencesFunction, DuckDBSequencesBind, DuckDBSequencesInit)); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct DuckDBSettingValue { - string name; - string value; - string description; - string input_type; -}; - -struct DuckDBSettingsData : public GlobalTableFunctionState { - DuckDBSettingsData() : offset(0) { - } - - vector settings; - idx_t offset; -}; - -static unique_ptr DuckDBSettingsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("value"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("description"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("input_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBSettingsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - auto &config = DBConfig::GetConfig(context); - auto options_count = DBConfig::GetOptionCount(); - for (idx_t i = 0; i < options_count; 
i++) { - auto option = DBConfig::GetOptionByIndex(i); - D_ASSERT(option); - DuckDBSettingValue value; - value.name = option->name; - value.value = option->get_setting(context).ToString(); - value.description = option->description; - value.input_type = EnumUtil::ToString(option->parameter_type); - - result->settings.push_back(std::move(value)); - } - for (auto &ext_param : config.extension_parameters) { - Value setting_val; - string setting_str_val; - if (context.TryGetCurrentSetting(ext_param.first, setting_val)) { - setting_str_val = setting_val.ToString(); - } - DuckDBSettingValue value; - value.name = ext_param.first; - value.value = std::move(setting_str_val); - value.description = ext_param.second.description; - value.input_type = ext_param.second.type.ToString(); - - result->settings.push_back(std::move(value)); - } - return std::move(result); -} - -void DuckDBSettingsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.settings.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.settings.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.settings[data.offset++]; - - // return values: - // name, LogicalType::VARCHAR - output.SetValue(0, count, Value(entry.name)); - // value, LogicalType::VARCHAR - output.SetValue(1, count, Value(entry.value)); - // description, LogicalType::VARCHAR - output.SetValue(2, count, Value(entry.description)); - // input_type, LogicalType::VARCHAR - output.SetValue(3, count, Value(entry.input_type)); - count++; - } - output.SetCardinality(count); -} - -void DuckDBSettingsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("duckdb_settings", {}, DuckDBSettingsFunction, DuckDBSettingsBind, DuckDBSettingsInit)); -} - -} // namespace duckdb - - - - - - - - - 
- - - - -namespace duckdb { - -struct DuckDBTablesData : public GlobalTableFunctionState { - DuckDBTablesData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -static unique_ptr DuckDBTablesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("table_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("table_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("internal"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("temporary"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("has_primary_key"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("estimated_size"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("column_count"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("index_count"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("check_constraint_count"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("sql"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBTablesInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and collect themand collect them - auto schemas = Catalog::GetAllSchemas(context); - for (auto &schema : schemas) { - schema.get().Scan(context, CatalogType::TABLE_ENTRY, - [&](CatalogEntry &entry) { result->entries.push_back(entry); }); - }; - return 
std::move(result); -} - -static bool TableHasPrimaryKey(TableCatalogEntry &table) { - for (auto &constraint : table.GetConstraints()) { - if (constraint->type == ConstraintType::UNIQUE) { - auto &unique = constraint->Cast(); - if (unique.is_primary_key) { - return true; - } - } - } - return false; -} - -static idx_t CheckConstraintCount(TableCatalogEntry &table) { - idx_t check_count = 0; - for (auto &constraint : table.GetConstraints()) { - if (constraint->type == ConstraintType::CHECK) { - check_count++; - } - } - return check_count; -} - -void DuckDBTablesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset++].get(); - - if (entry.type != CatalogType::TABLE_ENTRY) { - continue; - } - auto &table = entry.Cast(); - auto storage_info = table.GetStorageInfo(context); - // return values: - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, count, table.catalog.GetName()); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(table.catalog.GetOid())); - // schema_name, LogicalType::VARCHAR - output.SetValue(col++, count, Value(table.schema.name)); - // schema_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(table.schema.oid)); - // table_name, LogicalType::VARCHAR - output.SetValue(col++, count, Value(table.name)); - // table_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(table.oid)); - // internal, LogicalType::BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(table.internal)); - // temporary, LogicalType::BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(table.temporary)); - // 
has_primary_key, LogicalType::BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(TableHasPrimaryKey(table))); - // estimated_size, LogicalType::BIGINT - Value card_val = - storage_info.cardinality == DConstants::INVALID_INDEX ? Value() : Value::BIGINT(storage_info.cardinality); - output.SetValue(col++, count, card_val); - // column_count, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(table.GetColumns().LogicalColumnCount())); - // index_count, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(storage_info.index_info.size())); - // check_constraint_count, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(CheckConstraintCount(table))); - // sql, LogicalType::VARCHAR - output.SetValue(col++, count, Value(table.ToSQL())); - - count++; - } - output.SetCardinality(count); -} - -void DuckDBTablesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_tables", {}, DuckDBTablesFunction, DuckDBTablesBind, DuckDBTablesInit)); -} - -} // namespace duckdb - - - -namespace duckdb { - -struct DuckDBTemporaryFilesData : public GlobalTableFunctionState { - DuckDBTemporaryFilesData() : offset(0) { - } - - vector entries; - idx_t offset; -}; - -static unique_ptr DuckDBTemporaryFilesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("path"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("size"); - return_types.emplace_back(LogicalType::BIGINT); - - return nullptr; -} - -unique_ptr DuckDBTemporaryFilesInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - result->entries = BufferManager::GetBufferManager(context).GetTemporaryFiles(); - return std::move(result); -} - -void DuckDBTemporaryFilesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // 
finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset++]; - // return values: - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, count, entry.path); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(entry.size)); - count++; - } - output.SetCardinality(count); -} - -void DuckDBTemporaryFilesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_temporary_files", {}, DuckDBTemporaryFilesFunction, DuckDBTemporaryFilesBind, - DuckDBTemporaryFilesInit)); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -struct DuckDBTypesData : public GlobalTableFunctionState { - DuckDBTypesData() : offset(0) { - } - - vector> entries; - idx_t offset; - unordered_set oids; -}; - -static unique_ptr DuckDBTypesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("type_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("type_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("type_size"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("logical_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - // NUMERIC, STRING, DATETIME, BOOLEAN, COMPOSITE, USER - names.emplace_back("type_category"); - return_types.emplace_back(LogicalType::VARCHAR); - - 
names.emplace_back("internal"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("labels"); - return_types.emplace_back(LogicalType::LIST(LogicalType::VARCHAR)); - - return nullptr; -} - -unique_ptr DuckDBTypesInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - auto schemas = Catalog::GetAllSchemas(context); - for (auto &schema : schemas) { - schema.get().Scan(context, CatalogType::TYPE_ENTRY, - [&](CatalogEntry &entry) { result->entries.push_back(entry.Cast()); }); - }; - return std::move(result); -} - -void DuckDBTypesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &type_entry = data.entries[data.offset++].get(); - auto &type = type_entry.user_type; - - // return values: - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, count, type_entry.catalog.GetName()); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(type_entry.catalog.GetOid())); - // schema_name, LogicalType::VARCHAR - output.SetValue(col++, count, Value(type_entry.schema.name)); - // schema_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(type_entry.schema.oid)); - // type_oid, BIGINT - int64_t oid; - if (type_entry.internal) { - oid = int64_t(type.id()); - } else { - oid = type_entry.oid; - } - Value oid_val; - if (data.oids.find(oid) == data.oids.end()) { - data.oids.insert(oid); - oid_val = Value::BIGINT(oid); - } else { - oid_val = Value(); - } - output.SetValue(col++, count, oid_val); - // type_name, VARCHAR - output.SetValue(col++, count, Value(type_entry.name)); - // type_size, BIGINT - auto internal_type = 
type.InternalType(); - output.SetValue(col++, count, - internal_type == PhysicalType::INVALID ? Value() : Value::BIGINT(GetTypeIdSize(internal_type))); - // logical_type, VARCHAR - output.SetValue(col++, count, Value(EnumUtil::ToString(type.id()))); - // type_category, VARCHAR - string category; - switch (type.id()) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::DECIMAL: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::HUGEINT: - category = "NUMERIC"; - break; - case LogicalTypeId::DATE: - case LogicalTypeId::TIME: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::INTERVAL: - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::TIMESTAMP_TZ: - category = "DATETIME"; - break; - case LogicalTypeId::CHAR: - case LogicalTypeId::VARCHAR: - category = "STRING"; - break; - case LogicalTypeId::BOOLEAN: - category = "BOOLEAN"; - break; - case LogicalTypeId::STRUCT: - case LogicalTypeId::LIST: - case LogicalTypeId::MAP: - case LogicalTypeId::UNION: - category = "COMPOSITE"; - break; - default: - break; - } - output.SetValue(col++, count, category.empty() ? 
Value() : Value(category)); - // internal, BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(type_entry.internal)); - // labels, VARCHAR[] - if (type.id() == LogicalTypeId::ENUM && type.AuxInfo()) { - auto data = FlatVector::GetData(EnumType::GetValuesInsertOrder(type)); - idx_t size = EnumType::GetSize(type); - - vector labels; - for (idx_t i = 0; i < size; i++) { - labels.emplace_back(data[i]); - } - - output.SetValue(col++, count, Value::LIST(labels)); - } else { - output.SetValue(col++, count, Value()); - } - - count++; - } - output.SetCardinality(count); -} - -void DuckDBTypesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_types", {}, DuckDBTypesFunction, DuckDBTypesBind, DuckDBTypesInit)); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -struct DuckDBViewsData : public GlobalTableFunctionState { - DuckDBViewsData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -static unique_ptr DuckDBViewsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("view_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("view_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("internal"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("temporary"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("column_count"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("sql"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - 
-unique_ptr DuckDBViewsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and collect them and collect them - auto schemas = Catalog::GetAllSchemas(context); - for (auto &schema : schemas) { - schema.get().Scan(context, CatalogType::VIEW_ENTRY, - [&](CatalogEntry &entry) { result->entries.push_back(entry); }); - }; - return std::move(result); -} - -void DuckDBViewsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset++].get(); - - if (entry.type != CatalogType::VIEW_ENTRY) { - continue; - } - auto &view = entry.Cast(); - - // return values: - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, count, view.catalog.GetName()); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(view.catalog.GetOid())); - // schema_name, LogicalType::VARCHAR - output.SetValue(col++, count, Value(view.schema.name)); - // schema_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(view.schema.oid)); - // view_name, LogicalType::VARCHAR - output.SetValue(col++, count, Value(view.name)); - // view_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(view.oid)); - // internal, LogicalType::BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(view.internal)); - // temporary, LogicalType::BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(view.temporary)); - // column_count, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(view.types.size())); - // sql, LogicalType::VARCHAR - output.SetValue(col++, count, Value(view.ToSQL())); - - 
count++; - } - output.SetCardinality(count); -} - -void DuckDBViewsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_views", {}, DuckDBViewsFunction, DuckDBViewsBind, DuckDBViewsInit)); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -struct PragmaCollateData : public GlobalTableFunctionState { - PragmaCollateData() : offset(0) { - } - - vector entries; - idx_t offset; -}; - -static unique_ptr PragmaCollateBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("collname"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr PragmaCollateInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - auto schemas = Catalog::GetAllSchemas(context); - for (auto schema : schemas) { - schema.get().Scan(context, CatalogType::COLLATION_ENTRY, - [&](CatalogEntry &entry) { result->entries.push_back(entry.name); }); - } - return std::move(result); -} - -static void PragmaCollateFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - idx_t next = MinValue(data.offset + STANDARD_VECTOR_SIZE, data.entries.size()); - output.SetCardinality(next - data.offset); - for (idx_t i = data.offset; i < next; i++) { - auto index = i - data.offset; - output.SetValue(0, index, Value(data.entries[i])); - } - - data.offset = next; -} - -void PragmaCollations::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("pragma_collations", {}, PragmaCollateFunction, PragmaCollateBind, PragmaCollateInit)); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -struct PragmaDatabaseSizeData : public GlobalTableFunctionState { - PragmaDatabaseSizeData() : index(0) { - } - - idx_t index; - vector> databases; - Value 
memory_usage; - Value memory_limit; -}; - -static unique_ptr PragmaDatabaseSizeBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_size"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("block_size"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("total_blocks"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("used_blocks"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("free_blocks"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("wal_size"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("memory_usage"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("memory_limit"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr PragmaDatabaseSizeInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - result->databases = DatabaseManager::Get(context).GetDatabases(context); - auto &buffer_manager = BufferManager::GetBufferManager(context); - result->memory_usage = Value(StringUtil::BytesToHumanReadableString(buffer_manager.GetUsedMemory())); - auto max_memory = buffer_manager.GetMaxMemory(); - result->memory_limit = - max_memory == (idx_t)-1 ? 
Value("Unlimited") : Value(StringUtil::BytesToHumanReadableString(max_memory)); - - return std::move(result); -} - -void PragmaDatabaseSizeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - idx_t row = 0; - for (; data.index < data.databases.size() && row < STANDARD_VECTOR_SIZE; data.index++) { - auto &db = data.databases[data.index].get(); - if (db.IsSystem() || db.IsTemporary()) { - continue; - } - auto ds = db.GetCatalog().GetDatabaseSize(context); - idx_t col = 0; - output.data[col++].SetValue(row, Value(db.GetName())); - output.data[col++].SetValue(row, Value(StringUtil::BytesToHumanReadableString(ds.bytes))); - output.data[col++].SetValue(row, Value::BIGINT(ds.block_size)); - output.data[col++].SetValue(row, Value::BIGINT(ds.total_blocks)); - output.data[col++].SetValue(row, Value::BIGINT(ds.used_blocks)); - output.data[col++].SetValue(row, Value::BIGINT(ds.free_blocks)); - output.data[col++].SetValue( - row, ds.wal_size == idx_t(-1) ? 
Value() : Value(StringUtil::BytesToHumanReadableString(ds.wal_size))); - output.data[col++].SetValue(row, data.memory_usage); - output.data[col++].SetValue(row, data.memory_limit); - row++; - } - output.SetCardinality(row); -} - -void PragmaDatabaseSize::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("pragma_database_size", {}, PragmaDatabaseSizeFunction, PragmaDatabaseSizeBind, - PragmaDatabaseSizeInit)); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct PragmaMetadataFunctionData : public TableFunctionData { - explicit PragmaMetadataFunctionData() { - } - - vector metadata_info; -}; - -struct PragmaMetadataOperatorData : public GlobalTableFunctionState { - PragmaMetadataOperatorData() : offset(0) { - } - - idx_t offset; -}; - -static unique_ptr PragmaMetadataInfoBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("block_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("total_blocks"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("free_blocks"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("free_list"); - return_types.emplace_back(LogicalType::LIST(LogicalType::BIGINT)); - - string db_name = - input.inputs.empty() ? 
DatabaseManager::GetDefaultDatabase(context) : StringValue::Get(input.inputs[0]); - auto &catalog = Catalog::GetCatalog(context, db_name); - auto result = make_uniq(); - result->metadata_info = catalog.GetMetadataInfo(context); - return std::move(result); -} - -unique_ptr PragmaMetadataInfoInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static void PragmaMetadataInfoFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &data = data_p.global_state->Cast(); - idx_t count = 0; - while (data.offset < bind_data.metadata_info.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = bind_data.metadata_info[data.offset++]; - - idx_t col_idx = 0; - // block_id - output.SetValue(col_idx++, count, Value::BIGINT(entry.block_id)); - // total_blocks - output.SetValue(col_idx++, count, Value::BIGINT(entry.total_blocks)); - // free_blocks - output.SetValue(col_idx++, count, Value::BIGINT(entry.free_list.size())); - // free_list - vector list_values; - for (auto &free_id : entry.free_list) { - list_values.push_back(Value::BIGINT(free_id)); - } - output.SetValue(col_idx++, count, Value::LIST(LogicalType::BIGINT, std::move(list_values))); - count++; - } - output.SetCardinality(count); -} - -void PragmaMetadataInfo::RegisterFunction(BuiltinFunctions &set) { - TableFunctionSet metadata_info("pragma_metadata_info"); - metadata_info.AddFunction( - TableFunction({}, PragmaMetadataInfoFunction, PragmaMetadataInfoBind, PragmaMetadataInfoInit)); - metadata_info.AddFunction(TableFunction({LogicalType::VARCHAR}, PragmaMetadataInfoFunction, PragmaMetadataInfoBind, - PragmaMetadataInfoInit)); - set.AddFunction(metadata_info); -} - -} // namespace duckdb - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -struct PragmaStorageFunctionData : public TableFunctionData { - explicit PragmaStorageFunctionData(TableCatalogEntry &table_entry) : table_entry(table_entry) { 
- } - - TableCatalogEntry &table_entry; - vector column_segments_info; -}; - -struct PragmaStorageOperatorData : public GlobalTableFunctionState { - PragmaStorageOperatorData() : offset(0) { - } - - idx_t offset; -}; - -static unique_ptr PragmaStorageInfoBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("row_group_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("column_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("column_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("column_path"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("segment_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("segment_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("start"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("count"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("compression"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("stats"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("has_updates"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("persistent"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("block_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("block_offset"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("segment_info"); - return_types.emplace_back(LogicalType::VARCHAR); - - auto qname = QualifiedName::Parse(input.inputs[0].GetValue()); - - // look up the table name in the catalog - Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); - auto &table_entry = Catalog::GetEntry(context, qname.catalog, qname.schema, qname.name); - auto result = make_uniq(table_entry); - 
result->column_segments_info = table_entry.GetColumnSegmentInfo(); - return std::move(result); -} - -unique_ptr PragmaStorageInfoInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static void PragmaStorageInfoFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &data = data_p.global_state->Cast(); - idx_t count = 0; - auto &columns = bind_data.table_entry.GetColumns(); - while (data.offset < bind_data.column_segments_info.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = bind_data.column_segments_info[data.offset++]; - - idx_t col_idx = 0; - // row_group_id - output.SetValue(col_idx++, count, Value::BIGINT(entry.row_group_index)); - // column_name - auto &col = columns.GetColumn(PhysicalIndex(entry.column_id)); - output.SetValue(col_idx++, count, Value(col.Name())); - // column_id - output.SetValue(col_idx++, count, Value::BIGINT(entry.column_id)); - // column_path - output.SetValue(col_idx++, count, Value(entry.column_path)); - // segment_id - output.SetValue(col_idx++, count, Value::BIGINT(entry.segment_idx)); - // segment_type - output.SetValue(col_idx++, count, Value(entry.segment_type)); - // start - output.SetValue(col_idx++, count, Value::BIGINT(entry.segment_start)); - // count - output.SetValue(col_idx++, count, Value::BIGINT(entry.segment_count)); - // compression - output.SetValue(col_idx++, count, Value(entry.compression_type)); - // stats - output.SetValue(col_idx++, count, Value(entry.segment_stats)); - // has_updates - output.SetValue(col_idx++, count, Value::BOOLEAN(entry.has_updates)); - // persistent - output.SetValue(col_idx++, count, Value::BOOLEAN(entry.persistent)); - // block_id - // block_offset - if (entry.persistent) { - output.SetValue(col_idx++, count, Value::BIGINT(entry.block_id)); - output.SetValue(col_idx++, count, Value::BIGINT(entry.block_offset)); - } else { - output.SetValue(col_idx++, count, Value()); 
- output.SetValue(col_idx++, count, Value()); - } - // segment_info - output.SetValue(col_idx++, count, Value(entry.segment_info)); - count++; - } - output.SetCardinality(count); -} - -void PragmaStorageInfo::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("pragma_storage_info", {LogicalType::VARCHAR}, PragmaStorageInfoFunction, - PragmaStorageInfoBind, PragmaStorageInfoInit)); -} - -} // namespace duckdb - - - - - - - - - - - - - - -#include - -namespace duckdb { - -struct PragmaTableFunctionData : public TableFunctionData { - explicit PragmaTableFunctionData(CatalogEntry &entry_p) : entry(entry_p) { - } - - CatalogEntry &entry; -}; - -struct PragmaTableOperatorData : public GlobalTableFunctionState { - PragmaTableOperatorData() : offset(0) { - } - idx_t offset; -}; - -static unique_ptr PragmaTableInfoBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - - names.emplace_back("cid"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("notnull"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("dflt_value"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("pk"); - return_types.emplace_back(LogicalType::BOOLEAN); - - auto qname = QualifiedName::Parse(input.inputs[0].GetValue()); - - // look up the table name in the catalog - Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); - auto &entry = Catalog::GetEntry(context, CatalogType::TABLE_ENTRY, qname.catalog, qname.schema, qname.name); - return make_uniq(entry); -} - -unique_ptr PragmaTableInfoInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static void CheckConstraints(TableCatalogEntry &table, const ColumnDefinition &column, bool &out_not_null, - bool 
&out_pk) { - out_not_null = false; - out_pk = false; - // check all constraints - // FIXME: this is pretty inefficient, it probably doesn't matter - for (auto &constraint : table.GetConstraints()) { - switch (constraint->type) { - case ConstraintType::NOT_NULL: { - auto ¬_null = constraint->Cast(); - if (not_null.index == column.Logical()) { - out_not_null = true; - } - break; - } - case ConstraintType::UNIQUE: { - auto &unique = constraint->Cast(); - if (unique.is_primary_key) { - if (unique.index == column.Logical()) { - out_pk = true; - } - if (std::find(unique.columns.begin(), unique.columns.end(), column.GetName()) != unique.columns.end()) { - out_pk = true; - } - } - break; - } - default: - break; - } - } -} - -static void PragmaTableInfoTable(PragmaTableOperatorData &data, TableCatalogEntry &table, DataChunk &output) { - if (data.offset >= table.GetColumns().LogicalColumnCount()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t next = MinValue(data.offset + STANDARD_VECTOR_SIZE, table.GetColumns().LogicalColumnCount()); - output.SetCardinality(next - data.offset); - - for (idx_t i = data.offset; i < next; i++) { - bool not_null, pk; - auto index = i - data.offset; - auto &column = table.GetColumn(LogicalIndex(i)); - D_ASSERT(column.Oid() < (idx_t)NumericLimits::Maximum()); - CheckConstraints(table, column, not_null, pk); - - // return values: - // "cid", PhysicalType::INT32 - output.SetValue(0, index, Value::INTEGER((int32_t)column.Oid())); - // "name", PhysicalType::VARCHAR - output.SetValue(1, index, Value(column.Name())); - // "type", PhysicalType::VARCHAR - output.SetValue(2, index, Value(column.Type().ToString())); - // "notnull", PhysicalType::BOOL - output.SetValue(3, index, Value::BOOLEAN(not_null)); - // "dflt_value", PhysicalType::VARCHAR - Value def_value = column.DefaultValue() ? 
Value(column.DefaultValue()->ToString()) : Value(); - output.SetValue(4, index, def_value); - // "pk", PhysicalType::BOOL - output.SetValue(5, index, Value::BOOLEAN(pk)); - } - data.offset = next; -} - -static void PragmaTableInfoView(PragmaTableOperatorData &data, ViewCatalogEntry &view, DataChunk &output) { - if (data.offset >= view.types.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t next = MinValue(data.offset + STANDARD_VECTOR_SIZE, view.types.size()); - output.SetCardinality(next - data.offset); - - for (idx_t i = data.offset; i < next; i++) { - auto index = i - data.offset; - auto type = view.types[i]; - auto &name = view.aliases[i]; - // return values: - // "cid", PhysicalType::INT32 - - output.SetValue(0, index, Value::INTEGER((int32_t)i)); - // "name", PhysicalType::VARCHAR - output.SetValue(1, index, Value(name)); - // "type", PhysicalType::VARCHAR - output.SetValue(2, index, Value(type.ToString())); - // "notnull", PhysicalType::BOOL - output.SetValue(3, index, Value::BOOLEAN(false)); - // "dflt_value", PhysicalType::VARCHAR - output.SetValue(4, index, Value()); - // "pk", PhysicalType::BOOL - output.SetValue(5, index, Value::BOOLEAN(false)); - } - data.offset = next; -} - -static void PragmaTableInfoFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = data_p.global_state->Cast(); - switch (bind_data.entry.type) { - case CatalogType::TABLE_ENTRY: - PragmaTableInfoTable(state, bind_data.entry.Cast(), output); - break; - case CatalogType::VIEW_ENTRY: - PragmaTableInfoView(state, bind_data.entry.Cast(), output); - break; - default: - throw NotImplementedException("Unimplemented catalog type for pragma_table_info"); - } -} - -void PragmaTableInfo::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("pragma_table_info", 
{LogicalType::VARCHAR}, PragmaTableInfoFunction, - PragmaTableInfoBind, PragmaTableInfoInit)); -} - -} // namespace duckdb - - - - - - -#include -#include - -namespace duckdb { - -struct TestAllTypesData : public GlobalTableFunctionState { - TestAllTypesData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -vector TestAllTypesFun::GetTestTypes(bool use_large_enum) { - vector result; - // scalar types/numerics - result.emplace_back(LogicalType::BOOLEAN, "bool"); - result.emplace_back(LogicalType::TINYINT, "tinyint"); - result.emplace_back(LogicalType::SMALLINT, "smallint"); - result.emplace_back(LogicalType::INTEGER, "int"); - result.emplace_back(LogicalType::BIGINT, "bigint"); - result.emplace_back(LogicalType::HUGEINT, "hugeint"); - result.emplace_back(LogicalType::UTINYINT, "utinyint"); - result.emplace_back(LogicalType::USMALLINT, "usmallint"); - result.emplace_back(LogicalType::UINTEGER, "uint"); - result.emplace_back(LogicalType::UBIGINT, "ubigint"); - result.emplace_back(LogicalType::DATE, "date"); - result.emplace_back(LogicalType::TIME, "time"); - result.emplace_back(LogicalType::TIMESTAMP, "timestamp"); - result.emplace_back(LogicalType::TIMESTAMP_S, "timestamp_s"); - result.emplace_back(LogicalType::TIMESTAMP_MS, "timestamp_ms"); - result.emplace_back(LogicalType::TIMESTAMP_NS, "timestamp_ns"); - result.emplace_back(LogicalType::TIME_TZ, "time_tz"); - result.emplace_back(LogicalType::TIMESTAMP_TZ, "timestamp_tz"); - result.emplace_back(LogicalType::FLOAT, "float"); - result.emplace_back(LogicalType::DOUBLE, "double"); - result.emplace_back(LogicalType::DECIMAL(4, 1), "dec_4_1"); - result.emplace_back(LogicalType::DECIMAL(9, 4), "dec_9_4"); - result.emplace_back(LogicalType::DECIMAL(18, 6), "dec_18_6"); - result.emplace_back(LogicalType::DECIMAL(38, 10), "dec38_10"); - result.emplace_back(LogicalType::UUID, "uuid"); - - // interval - interval_t min_interval; - min_interval.months = 0; - min_interval.days = 0; - min_interval.micros = 0; - - 
interval_t max_interval; - max_interval.months = 999; - max_interval.days = 999; - max_interval.micros = 999999999; - result.emplace_back(LogicalType::INTERVAL, "interval", Value::INTERVAL(min_interval), - Value::INTERVAL(max_interval)); - // strings/blobs/bitstrings - result.emplace_back(LogicalType::VARCHAR, "varchar", Value("🦆🦆🦆🦆🦆🦆"), - Value(string("goo\x00se", 6))); - result.emplace_back(LogicalType::BLOB, "blob", Value::BLOB("thisisalongblob\\x00withnullbytes"), - Value::BLOB("\\x00\\x00\\x00a")); - result.emplace_back(LogicalType::BIT, "bit", Value::BIT("0010001001011100010101011010111"), Value::BIT("10101")); - - // enums - Vector small_enum(LogicalType::VARCHAR, 2); - auto small_enum_ptr = FlatVector::GetData(small_enum); - small_enum_ptr[0] = StringVector::AddStringOrBlob(small_enum, "DUCK_DUCK_ENUM"); - small_enum_ptr[1] = StringVector::AddStringOrBlob(small_enum, "GOOSE"); - result.emplace_back(LogicalType::ENUM(small_enum, 2), "small_enum"); - - Vector medium_enum(LogicalType::VARCHAR, 300); - auto medium_enum_ptr = FlatVector::GetData(medium_enum); - for (idx_t i = 0; i < 300; i++) { - medium_enum_ptr[i] = StringVector::AddStringOrBlob(medium_enum, string("enum_") + to_string(i)); - } - result.emplace_back(LogicalType::ENUM(medium_enum, 300), "medium_enum"); - - if (use_large_enum) { - // this is a big one... 
not sure if we should push this one here, but it's required for completeness - Vector large_enum(LogicalType::VARCHAR, 70000); - auto large_enum_ptr = FlatVector::GetData(large_enum); - for (idx_t i = 0; i < 70000; i++) { - large_enum_ptr[i] = StringVector::AddStringOrBlob(large_enum, string("enum_") + to_string(i)); - } - result.emplace_back(LogicalType::ENUM(large_enum, 70000), "large_enum"); - } else { - Vector large_enum(LogicalType::VARCHAR, 2); - auto large_enum_ptr = FlatVector::GetData(large_enum); - large_enum_ptr[0] = StringVector::AddStringOrBlob(large_enum, string("enum_") + to_string(0)); - large_enum_ptr[1] = StringVector::AddStringOrBlob(large_enum, string("enum_") + to_string(69999)); - result.emplace_back(LogicalType::ENUM(large_enum, 2), "large_enum"); - } - - // arrays - auto int_list_type = LogicalType::LIST(LogicalType::INTEGER); - auto empty_int_list = Value::EMPTYLIST(LogicalType::INTEGER); - auto int_list = Value::LIST({Value::INTEGER(42), Value::INTEGER(999), Value(LogicalType::INTEGER), - Value(LogicalType::INTEGER), Value::INTEGER(-42)}); - result.emplace_back(int_list_type, "int_array", empty_int_list, int_list); - - auto double_list_type = LogicalType::LIST(LogicalType::DOUBLE); - auto empty_double_list = Value::EMPTYLIST(LogicalType::DOUBLE); - auto double_list = Value::LIST( - {Value::DOUBLE(42), Value::DOUBLE(NAN), Value::DOUBLE(std::numeric_limits::infinity()), - Value::DOUBLE(-std::numeric_limits::infinity()), Value(LogicalType::DOUBLE), Value::DOUBLE(-42)}); - result.emplace_back(double_list_type, "double_array", empty_double_list, double_list); - - auto date_list_type = LogicalType::LIST(LogicalType::DATE); - auto empty_date_list = Value::EMPTYLIST(LogicalType::DATE); - auto date_list = - Value::LIST({Value::DATE(date_t()), Value::DATE(date_t::infinity()), Value::DATE(date_t::ninfinity()), - Value(LogicalType::DATE), Value::DATE(Date::FromString("2022-05-12"))}); - result.emplace_back(date_list_type, "date_array", 
empty_date_list, date_list); - - auto timestamp_list_type = LogicalType::LIST(LogicalType::TIMESTAMP); - auto empty_timestamp_list = Value::EMPTYLIST(LogicalType::TIMESTAMP); - auto timestamp_list = Value::LIST({Value::TIMESTAMP(timestamp_t()), Value::TIMESTAMP(timestamp_t::infinity()), - Value::TIMESTAMP(timestamp_t::ninfinity()), Value(LogicalType::TIMESTAMP), - Value::TIMESTAMP(Timestamp::FromString("2022-05-12 16:23:45"))}); - result.emplace_back(timestamp_list_type, "timestamp_array", empty_timestamp_list, timestamp_list); - - auto timestamptz_list_type = LogicalType::LIST(LogicalType::TIMESTAMP_TZ); - auto empty_timestamptz_list = Value::EMPTYLIST(LogicalType::TIMESTAMP_TZ); - auto timestamptz_list = Value::LIST({Value::TIMESTAMPTZ(timestamp_t()), Value::TIMESTAMPTZ(timestamp_t::infinity()), - Value::TIMESTAMPTZ(timestamp_t::ninfinity()), Value(LogicalType::TIMESTAMP_TZ), - Value::TIMESTAMPTZ(Timestamp::FromString("2022-05-12 16:23:45-07"))}); - result.emplace_back(timestamptz_list_type, "timestamptz_array", empty_timestamptz_list, timestamptz_list); - - auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); - auto empty_varchar_list = Value::EMPTYLIST(LogicalType::VARCHAR); - auto varchar_list = - Value::LIST({Value("🦆🦆🦆🦆🦆🦆"), Value("goose"), Value(LogicalType::VARCHAR), Value("")}); - result.emplace_back(varchar_list_type, "varchar_array", empty_varchar_list, varchar_list); - - // nested arrays - auto nested_list_type = LogicalType::LIST(int_list_type); - auto empty_nested_list = Value::EMPTYLIST(int_list_type); - auto nested_int_list = Value::LIST({empty_int_list, int_list, Value(int_list_type), empty_int_list, int_list}); - result.emplace_back(nested_list_type, "nested_int_array", empty_nested_list, nested_int_list); - - // structs - child_list_t struct_type_list; - struct_type_list.push_back(make_pair("a", LogicalType::INTEGER)); - struct_type_list.push_back(make_pair("b", LogicalType::VARCHAR)); - auto struct_type = 
LogicalType::STRUCT(struct_type_list); - - child_list_t min_struct_list; - min_struct_list.push_back(make_pair("a", Value(LogicalType::INTEGER))); - min_struct_list.push_back(make_pair("b", Value(LogicalType::VARCHAR))); - auto min_struct_val = Value::STRUCT(std::move(min_struct_list)); - - child_list_t max_struct_list; - max_struct_list.push_back(make_pair("a", Value::INTEGER(42))); - max_struct_list.push_back(make_pair("b", Value("🦆🦆🦆🦆🦆🦆"))); - auto max_struct_val = Value::STRUCT(std::move(max_struct_list)); - - result.emplace_back(struct_type, "struct", min_struct_val, max_struct_val); - - // structs with lists - child_list_t struct_list_type_list; - struct_list_type_list.push_back(make_pair("a", int_list_type)); - struct_list_type_list.push_back(make_pair("b", varchar_list_type)); - auto struct_list_type = LogicalType::STRUCT(struct_list_type_list); - - child_list_t min_struct_vl_list; - min_struct_vl_list.push_back(make_pair("a", Value(int_list_type))); - min_struct_vl_list.push_back(make_pair("b", Value(varchar_list_type))); - auto min_struct_val_list = Value::STRUCT(std::move(min_struct_vl_list)); - - child_list_t max_struct_vl_list; - max_struct_vl_list.push_back(make_pair("a", int_list)); - max_struct_vl_list.push_back(make_pair("b", varchar_list)); - auto max_struct_val_list = Value::STRUCT(std::move(max_struct_vl_list)); - - result.emplace_back(struct_list_type, "struct_of_arrays", std::move(min_struct_val_list), - std::move(max_struct_val_list)); - - // array of structs - auto array_of_structs_type = LogicalType::LIST(struct_type); - auto min_array_of_struct_val = Value::EMPTYLIST(struct_type); - auto max_array_of_struct_val = Value::LIST({min_struct_val, max_struct_val, Value(struct_type)}); - result.emplace_back(array_of_structs_type, "array_of_structs", std::move(min_array_of_struct_val), - std::move(max_array_of_struct_val)); - - // map - auto map_type = LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR); - auto min_map_value = 
Value::MAP(ListType::GetChildType(map_type), vector()); - - child_list_t map_struct1; - map_struct1.push_back(make_pair("key", Value("key1"))); - map_struct1.push_back(make_pair("value", Value("🦆🦆🦆🦆🦆🦆"))); - child_list_t map_struct2; - map_struct2.push_back(make_pair("key", Value("key2"))); - map_struct2.push_back(make_pair("key", Value("goose"))); - - vector map_values; - map_values.push_back(Value::STRUCT(map_struct1)); - map_values.push_back(Value::STRUCT(map_struct2)); - - auto max_map_value = Value::MAP(ListType::GetChildType(map_type), map_values); - result.emplace_back(map_type, "map", std::move(min_map_value), std::move(max_map_value)); - - // union - child_list_t members = {{"name", LogicalType::VARCHAR}, {"age", LogicalType::SMALLINT}}; - auto union_type = LogicalType::UNION(members); - const Value &min = Value::UNION(members, 0, Value("Frank")); - const Value &max = Value::UNION(members, 1, Value::SMALLINT(5)); - result.emplace_back(union_type, "union", min, max); - - return result; -} - -struct TestAllTypesBindData : public TableFunctionData { - vector test_types; -}; - -static unique_ptr TestAllTypesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - bool use_large_enum = false; - auto entry = input.named_parameters.find("use_large_enum"); - if (entry != input.named_parameters.end()) { - use_large_enum = BooleanValue::Get(entry->second); - } - result->test_types = TestAllTypesFun::GetTestTypes(use_large_enum); - for (auto &test_type : result->test_types) { - return_types.push_back(test_type.type); - names.push_back(test_type.name); - } - return std::move(result); -} - -unique_ptr TestAllTypesInit(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - auto result = make_uniq(); - // 3 rows: min, max and NULL - result->entries.resize(3); - // initialize the values - for (auto &test_type : bind_data.test_types) { - 
result->entries[0].push_back(test_type.min_value); - result->entries[1].push_back(test_type.max_value); - result->entries[2].emplace_back(test_type.type); - } - return std::move(result); -} - -void TestAllTypesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &vals = data.entries[data.offset++]; - for (idx_t col_idx = 0; col_idx < vals.size(); col_idx++) { - output.SetValue(col_idx, count, vals[col_idx]); - } - count++; - } - output.SetCardinality(count); -} - -void TestAllTypesFun::RegisterFunction(BuiltinFunctions &set) { - TableFunction test_all_types("test_all_types", {}, TestAllTypesFunction, TestAllTypesBind, TestAllTypesInit); - test_all_types.named_parameters["use_large_enum"] = LogicalType::BOOLEAN; - set.AddFunction(test_all_types); -} - -} // namespace duckdb - -#endif diff --git a/lib/duckdb-7.cpp b/lib/duckdb-7.cpp deleted file mode 100644 index 1c8058dc..00000000 --- a/lib/duckdb-7.cpp +++ /dev/null @@ -1,20600 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - - - -namespace duckdb { - -// FLAT, CONSTANT, DICTIONARY, SEQUENCE -struct TestVectorBindData : public TableFunctionData { - vector types; - bool all_flat = false; -}; - -struct TestVectorTypesData : public GlobalTableFunctionState { - TestVectorTypesData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -struct TestVectorInfo { - TestVectorInfo(const vector &types, const map &test_type_map, - vector> &entries) - : 
types(types), test_type_map(test_type_map), entries(entries) { - } - - const vector &types; - const map &test_type_map; - vector> &entries; -}; - -struct TestGeneratedValues { -public: - void AddColumn(vector values) { - if (!column_values.empty() && column_values[0].size() != values.size()) { - throw InternalException("Size mismatch when adding a column to TestGeneratedValues"); - } - column_values.push_back(std::move(values)); - } - - const Value &GetValue(idx_t row, idx_t column) const { - return column_values[column][row]; - } - - idx_t Rows() const { - return column_values.empty() ? 0 : column_values[0].size(); - } - - idx_t Columns() const { - return column_values.size(); - } - -private: - vector> column_values; -}; - -struct TestVectorFlat { - static constexpr const idx_t TEST_VECTOR_CARDINALITY = 3; - - static vector GenerateValues(TestVectorInfo &info, const LogicalType &type) { - vector result; - switch (type.InternalType()) { - case PhysicalType::STRUCT: { - vector> struct_children; - auto &child_types = StructType::GetChildTypes(type); - - struct_children.resize(TEST_VECTOR_CARDINALITY); - for (auto &child_type : child_types) { - auto child_values = GenerateValues(info, child_type.second); - - for (idx_t i = 0; i < child_values.size(); i++) { - struct_children[i].push_back(make_pair(child_type.first, std::move(child_values[i]))); - } - } - for (auto &struct_child : struct_children) { - result.push_back(Value::STRUCT(std::move(struct_child))); - } - break; - } - case PhysicalType::LIST: { - auto &child_type = ListType::GetChildType(type); - auto child_values = GenerateValues(info, child_type); - - result.push_back(Value::LIST(child_type, {child_values[0], child_values[1]})); - result.push_back(Value::LIST(child_type, {})); - result.push_back(Value::LIST(child_type, {child_values[2]})); - break; - } - default: { - auto entry = info.test_type_map.find(type.id()); - if (entry == info.test_type_map.end()) { - throw NotImplementedException("Unimplemented type 
for test_vector_types %s", type.ToString()); - } - result.push_back(entry->second.min_value); - result.push_back(entry->second.max_value); - result.emplace_back(type); - break; - } - } - return result; - } - - static TestGeneratedValues GenerateValues(TestVectorInfo &info) { - // generate the values for each column - TestGeneratedValues generated_values; - for (auto &type : info.types) { - generated_values.AddColumn(GenerateValues(info, type)); - } - return generated_values; - } - - static void Generate(TestVectorInfo &info) { - auto result_values = GenerateValues(info); - for (idx_t cur_row = 0; cur_row < result_values.Rows(); cur_row += STANDARD_VECTOR_SIZE) { - auto result = make_uniq(); - result->Initialize(Allocator::DefaultAllocator(), info.types); - auto cardinality = MinValue(STANDARD_VECTOR_SIZE, result_values.Rows() - cur_row); - for (idx_t c = 0; c < info.types.size(); c++) { - for (idx_t i = 0; i < cardinality; i++) { - result->data[c].SetValue(i, result_values.GetValue(cur_row + i, c)); - } - } - result->SetCardinality(cardinality); - info.entries.push_back(std::move(result)); - } - } -}; - -struct TestVectorConstant { - static void Generate(TestVectorInfo &info) { - auto values = TestVectorFlat::GenerateValues(info); - for (idx_t cur_row = 0; cur_row < TestVectorFlat::TEST_VECTOR_CARDINALITY; cur_row += STANDARD_VECTOR_SIZE) { - auto result = make_uniq(); - result->Initialize(Allocator::DefaultAllocator(), info.types); - auto cardinality = MinValue(STANDARD_VECTOR_SIZE, TestVectorFlat::TEST_VECTOR_CARDINALITY - cur_row); - for (idx_t c = 0; c < info.types.size(); c++) { - result->data[c].SetValue(0, values.GetValue(0, c)); - result->data[c].SetVectorType(VectorType::CONSTANT_VECTOR); - } - result->SetCardinality(cardinality); - - info.entries.push_back(std::move(result)); - } - } -}; - -struct TestVectorSequence { - static void GenerateVector(TestVectorInfo &info, const LogicalType &type, Vector &result) { - D_ASSERT(type == result.GetType()); - 
switch (type.id()) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - result.Sequence(3, 2, 3); - return; - default: - break; - } - switch (type.InternalType()) { - case PhysicalType::STRUCT: { - auto &child_entries = StructVector::GetEntries(result); - for (auto &child_entry : child_entries) { - GenerateVector(info, child_entry->GetType(), *child_entry); - } - break; - } - case PhysicalType::LIST: { - auto data = FlatVector::GetData(result); - data[0].offset = 0; - data[0].length = 2; - data[1].offset = 2; - data[1].length = 0; - data[2].offset = 2; - data[2].length = 1; - - GenerateVector(info, ListType::GetChildType(type), ListVector::GetEntry(result)); - ListVector::SetListSize(result, 3); - break; - } - default: { - auto entry = info.test_type_map.find(type.id()); - if (entry == info.test_type_map.end()) { - throw NotImplementedException("Unimplemented type for test_vector_types %s", type.ToString()); - } - result.SetValue(0, entry->second.min_value); - result.SetValue(1, entry->second.max_value); - result.SetValue(2, Value(type)); - break; - } - } - } - - static void Generate(TestVectorInfo &info) { -#if STANDARD_VECTOR_SIZE > 2 - auto result = make_uniq(); - result->Initialize(Allocator::DefaultAllocator(), info.types); - - for (idx_t c = 0; c < info.types.size(); c++) { - GenerateVector(info, info.types[c], result->data[c]); - } - result->SetCardinality(3); - info.entries.push_back(std::move(result)); -#endif - } -}; - -struct TestVectorDictionary { - static void Generate(TestVectorInfo &info) { - idx_t current_chunk = info.entries.size(); - - unordered_set slice_entries {1, 2}; - - TestVectorFlat::Generate(info); - idx_t current_idx = 0; - for (idx_t i = current_chunk; i < info.entries.size(); i++) { - auto &chunk = *info.entries[i]; - SelectionVector 
sel(STANDARD_VECTOR_SIZE); - idx_t sel_idx = 0; - for (idx_t k = 0; k < chunk.size(); k++) { - if (slice_entries.count(current_idx + k) > 0) { - sel.set_index(sel_idx++, k); - } - } - chunk.Slice(sel, sel_idx); - current_idx += chunk.size(); - } - } -}; - -static unique_ptr TestVectorTypesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - for (idx_t i = 0; i < input.inputs.size(); i++) { - string name = "test_vector"; - if (i > 0) { - name += to_string(i + 1); - } - auto &input_val = input.inputs[i]; - names.emplace_back(name); - return_types.push_back(input_val.type()); - result->types.push_back(input_val.type()); - } - for (auto &entry : input.named_parameters) { - if (entry.first == "all_flat") { - result->all_flat = BooleanValue::Get(entry.second); - } else { - throw InternalException("Unrecognized named parameter for test_vector_types"); - } - } - return std::move(result); -} - -unique_ptr TestVectorTypesInit(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - - auto result = make_uniq(); - - auto test_types = TestAllTypesFun::GetTestTypes(); - - map test_type_map; - for (auto &test_type : test_types) { - test_type_map.insert(make_pair(test_type.type.id(), std::move(test_type))); - } - - TestVectorInfo info(bind_data.types, test_type_map, result->entries); - TestVectorFlat::Generate(info); - TestVectorConstant::Generate(info); - TestVectorDictionary::Generate(info); - TestVectorSequence::Generate(info); - for (auto &entry : result->entries) { - entry->Verify(); - } - if (bind_data.all_flat) { - for (auto &entry : result->entries) { - entry->Flatten(); - entry->Verify(); - } - } - return std::move(result); -} - -void TestVectorTypesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - 
return; - } - output.Reference(*data.entries[data.offset]); - data.offset++; -} - -void TestVectorTypesFun::RegisterFunction(BuiltinFunctions &set) { - TableFunction test_vector_types("test_vector_types", {LogicalType::ANY}, TestVectorTypesFunction, - TestVectorTypesBind, TestVectorTypesInit); - test_vector_types.varargs = LogicalType::ANY; - test_vector_types.named_parameters["all_flat"] = LogicalType::BOOLEAN; - - set.AddFunction(std::move(test_vector_types)); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -void BuiltinFunctions::RegisterSQLiteFunctions() { - PragmaVersion::RegisterFunction(*this); - PragmaPlatform::RegisterFunction(*this); - PragmaCollations::RegisterFunction(*this); - PragmaTableInfo::RegisterFunction(*this); - PragmaStorageInfo::RegisterFunction(*this); - PragmaMetadataInfo::RegisterFunction(*this); - PragmaDatabaseSize::RegisterFunction(*this); - PragmaLastProfilingOutput::RegisterFunction(*this); - PragmaDetailedProfilingOutput::RegisterFunction(*this); - - DuckDBColumnsFun::RegisterFunction(*this); - DuckDBConstraintsFun::RegisterFunction(*this); - DuckDBDatabasesFun::RegisterFunction(*this); - DuckDBFunctionsFun::RegisterFunction(*this); - DuckDBKeywordsFun::RegisterFunction(*this); - DuckDBIndexesFun::RegisterFunction(*this); - DuckDBSchemasFun::RegisterFunction(*this); - DuckDBDependenciesFun::RegisterFunction(*this); - DuckDBExtensionsFun::RegisterFunction(*this); - DuckDBSequencesFun::RegisterFunction(*this); - DuckDBSettingsFun::RegisterFunction(*this); - DuckDBTablesFun::RegisterFunction(*this); - DuckDBTemporaryFilesFun::RegisterFunction(*this); - DuckDBTypesFun::RegisterFunction(*this); - DuckDBViewsFun::RegisterFunction(*this); - TestAllTypesFun::RegisterFunction(*this); - TestVectorTypesFun::RegisterFunction(*this); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Table Scan 
-//===--------------------------------------------------------------------===// -bool TableScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, - LocalTableFunctionState *local_state, GlobalTableFunctionState *gstate); - -struct TableScanLocalState : public LocalTableFunctionState { - //! The current position in the scan - TableScanState scan_state; - //! The DataChunk containing all read columns (even filter columns that are immediately removed) - DataChunk all_columns; -}; - -static storage_t GetStorageIndex(TableCatalogEntry &table, column_t column_id) { - if (column_id == DConstants::INVALID_INDEX) { - return column_id; - } - auto &col = table.GetColumn(LogicalIndex(column_id)); - return col.StorageOid(); -} - -struct TableScanGlobalState : public GlobalTableFunctionState { - TableScanGlobalState(ClientContext &context, const FunctionData *bind_data_p) { - D_ASSERT(bind_data_p); - auto &bind_data = bind_data_p->Cast(); - max_threads = bind_data.table.GetStorage().MaxThreads(context); - } - - ParallelTableScanState state; - idx_t max_threads; - - vector projection_ids; - vector scanned_types; - - idx_t MaxThreads() const override { - return max_threads; - } - - bool CanRemoveFilterColumns() const { - return !projection_ids.empty(); - } -}; - -static unique_ptr TableScanInitLocal(ExecutionContext &context, TableFunctionInitInput &input, - GlobalTableFunctionState *gstate) { - auto result = make_uniq(); - auto &bind_data = input.bind_data->Cast(); - vector column_ids = input.column_ids; - for (auto &col : column_ids) { - auto storage_idx = GetStorageIndex(bind_data.table, col); - col = storage_idx; - } - result->scan_state.Initialize(std::move(column_ids), input.filters.get()); - TableScanParallelStateNext(context.client, input.bind_data.get(), result.get(), gstate); - if (input.CanRemoveFilterColumns()) { - auto &tsgs = gstate->Cast(); - result->all_columns.Initialize(context.client, tsgs.scanned_types); - } - return std::move(result); 
-} - -unique_ptr TableScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { - - D_ASSERT(input.bind_data); - auto &bind_data = input.bind_data->Cast(); - auto result = make_uniq(context, input.bind_data.get()); - bind_data.table.GetStorage().InitializeParallelScan(context, result->state); - if (input.CanRemoveFilterColumns()) { - result->projection_ids = input.projection_ids; - const auto &columns = bind_data.table.GetColumns(); - for (const auto &col_idx : input.column_ids) { - if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { - result->scanned_types.emplace_back(LogicalType::ROW_TYPE); - } else { - result->scanned_types.push_back(columns.GetColumn(LogicalIndex(col_idx)).Type()); - } - } - } - return std::move(result); -} - -static unique_ptr TableScanStatistics(ClientContext &context, const FunctionData *bind_data_p, - column_t column_id) { - auto &bind_data = bind_data_p->Cast(); - auto &local_storage = LocalStorage::Get(context, bind_data.table.catalog); - if (local_storage.Find(bind_data.table.GetStorage())) { - // we don't emit any statistics for tables that have outstanding transaction-local data - return nullptr; - } - return bind_data.table.GetStatistics(context, column_id); -} - -static void TableScanFunc(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &gstate = data_p.global_state->Cast(); - auto &state = data_p.local_state->Cast(); - auto &transaction = DuckTransaction::Get(context, bind_data.table.catalog); - auto &storage = bind_data.table.GetStorage(); - do { - if (bind_data.is_create_index) { - storage.CreateIndexScan(state.scan_state, output, - TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED); - } else if (gstate.CanRemoveFilterColumns()) { - state.all_columns.Reset(); - storage.Scan(transaction, state.all_columns, state.scan_state); - output.ReferenceColumns(state.all_columns, gstate.projection_ids); - } else { - storage.Scan(transaction, 
output, state.scan_state); - } - if (output.size() > 0) { - return; - } - if (!TableScanParallelStateNext(context, data_p.bind_data.get(), data_p.local_state.get(), - data_p.global_state.get())) { - return; - } - } while (true); -} - -bool TableScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, - LocalTableFunctionState *local_state, GlobalTableFunctionState *global_state) { - auto &bind_data = bind_data_p->Cast(); - auto ¶llel_state = global_state->Cast(); - auto &state = local_state->Cast(); - auto &storage = bind_data.table.GetStorage(); - - return storage.NextParallelScan(context, parallel_state.state, state.scan_state); -} - -double TableScanProgress(ClientContext &context, const FunctionData *bind_data_p, - const GlobalTableFunctionState *gstate_p) { - auto &bind_data = bind_data_p->Cast(); - auto &gstate = gstate_p->Cast(); - auto &storage = bind_data.table.GetStorage(); - idx_t total_rows = storage.GetTotalRows(); - if (total_rows == 0) { - //! Table is either empty or smaller than a vector size, so it is finished - return 100; - } - idx_t scanned_rows = gstate.state.scan_state.processed_rows; - scanned_rows += gstate.state.local_state.processed_rows; - auto percentage = 100 * (double(scanned_rows) / total_rows); - if (percentage > 100) { - //! In case the last chunk has less elements than STANDARD_VECTOR_SIZE, if our percentage is over 100 - //! It means we finished this table. 
- return 100; - } - return percentage; -} - -idx_t TableScanGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p, - LocalTableFunctionState *local_state, GlobalTableFunctionState *gstate_p) { - auto &state = local_state->Cast(); - if (state.scan_state.table_state.row_group) { - return state.scan_state.table_state.batch_index; - } - if (state.scan_state.local_state.row_group) { - return state.scan_state.table_state.batch_index + state.scan_state.local_state.batch_index; - } - return 0; -} - -BindInfo TableScanGetBindInfo(const FunctionData *bind_data) { - return BindInfo(ScanType::TABLE); -} - -void TableScanDependency(DependencyList &entries, const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - entries.AddDependency(bind_data.table); -} - -unique_ptr TableScanCardinality(ClientContext &context, const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - auto &local_storage = LocalStorage::Get(context, bind_data.table.catalog); - auto &storage = bind_data.table.GetStorage(); - idx_t estimated_cardinality = storage.info->cardinality + local_storage.AddedRows(bind_data.table.GetStorage()); - return make_uniq(storage.info->cardinality, estimated_cardinality); -} - -//===--------------------------------------------------------------------===// -// Index Scan -//===--------------------------------------------------------------------===// -struct IndexScanGlobalState : public GlobalTableFunctionState { - explicit IndexScanGlobalState(data_ptr_t row_id_data) : row_ids(LogicalType::ROW_TYPE, row_id_data) { - } - - Vector row_ids; - ColumnFetchState fetch_state; - TableScanState local_storage_state; - vector column_ids; - bool finished; -}; - -static unique_ptr IndexScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - data_ptr_t row_id_data = nullptr; - if (!bind_data.result_ids.empty()) { - row_id_data = (data_ptr_t)&bind_data.result_ids[0]; // 
NOLINT - this is not pretty - } - auto result = make_uniq(row_id_data); - auto &local_storage = LocalStorage::Get(context, bind_data.table.catalog); - - result->column_ids.reserve(input.column_ids.size()); - for (auto &id : input.column_ids) { - result->column_ids.push_back(GetStorageIndex(bind_data.table, id)); - } - result->local_storage_state.Initialize(result->column_ids, input.filters.get()); - local_storage.InitializeScan(bind_data.table.GetStorage(), result->local_storage_state.local_state, input.filters); - - result->finished = false; - return std::move(result); -} - -static void IndexScanFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = data_p.global_state->Cast(); - auto &transaction = DuckTransaction::Get(context, bind_data.table.catalog); - auto &local_storage = LocalStorage::Get(transaction); - - if (!state.finished) { - bind_data.table.GetStorage().Fetch(transaction, output, state.column_ids, state.row_ids, - bind_data.result_ids.size(), state.fetch_state); - state.finished = true; - } - if (output.size() == 0) { - local_storage.Scan(state.local_storage_state.local_state, state.column_ids, output); - } -} - -static void RewriteIndexExpression(Index &index, LogicalGet &get, Expression &expr, bool &rewrite_possible) { - if (expr.type == ExpressionType::BOUND_COLUMN_REF) { - auto &bound_colref = expr.Cast(); - // bound column ref: rewrite to fit in the current set of bound column ids - bound_colref.binding.table_index = get.table_index; - column_t referenced_column = index.column_ids[bound_colref.binding.column_index]; - // search for the referenced column in the set of column_ids - for (idx_t i = 0; i < get.column_ids.size(); i++) { - if (get.column_ids[i] == referenced_column) { - bound_colref.binding.column_index = i; - return; - } - } - // column id not found in bound columns in the LogicalGet: rewrite not possible - rewrite_possible = false; - } - 
ExpressionIterator::EnumerateChildren( - expr, [&](Expression &child) { RewriteIndexExpression(index, get, child, rewrite_possible); }); -} - -void TableScanPushdownComplexFilter(ClientContext &context, LogicalGet &get, FunctionData *bind_data_p, - vector> &filters) { - auto &bind_data = bind_data_p->Cast(); - auto &table = bind_data.table; - auto &storage = table.GetStorage(); - - auto &config = ClientConfig::GetConfig(context); - if (!config.enable_optimizer) { - // we only push index scans if the optimizer is enabled - return; - } - if (bind_data.is_index_scan) { - return; - } - if (!get.table_filters.filters.empty()) { - // if there were filters before we can't convert this to an index scan - return; - } - if (!get.projection_ids.empty()) { - // if columns were pruned by RemoveUnusedColumns we can't convert this to an index scan, - // because index scan does not support filter_prune (yet) - return; - } - if (filters.empty()) { - // no indexes or no filters: skip the pushdown - return; - } - // behold - storage.info->indexes.Scan([&](Index &index) { - // first rewrite the index expression so the ColumnBindings align with the column bindings of the current table - - if (index.unbound_expressions.size() > 1) { - // NOTE: index scans are not (yet) supported for compound index keys - return false; - } - - auto index_expression = index.unbound_expressions[0]->Copy(); - bool rewrite_possible = true; - RewriteIndexExpression(index, get, *index_expression, rewrite_possible); - if (!rewrite_possible) { - // could not rewrite! 
- return false; - } - - Value low_value, high_value, equal_value; - ExpressionType low_comparison_type = ExpressionType::INVALID, high_comparison_type = ExpressionType::INVALID; - // try to find a matching index for any of the filter expressions - for (auto &filter : filters) { - auto &expr = *filter; - - // create a matcher for a comparison with a constant - ComparisonExpressionMatcher matcher; - // match on a comparison type - matcher.expr_type = make_uniq(); - // match on a constant comparison with the indexed expression - matcher.matchers.push_back(make_uniq(*index_expression)); - matcher.matchers.push_back(make_uniq()); - - matcher.policy = SetMatcher::Policy::UNORDERED; - - vector> bindings; - if (matcher.Match(expr, bindings)) { - // range or equality comparison with constant value - // we can use our index here - // bindings[0] = the expression - // bindings[1] = the index expression - // bindings[2] = the constant - auto &comparison = bindings[0].get().Cast(); - auto constant_value = bindings[2].get().Cast().value; - auto comparison_type = comparison.type; - if (comparison.left->type == ExpressionType::VALUE_CONSTANT) { - // the expression is on the right side, we flip them around - comparison_type = FlipComparisonExpression(comparison_type); - } - if (comparison_type == ExpressionType::COMPARE_EQUAL) { - // equality value - // equality overrides any other bounds so we just break here - equal_value = constant_value; - break; - } else if (comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO || - comparison_type == ExpressionType::COMPARE_GREATERTHAN) { - // greater than means this is a lower bound - low_value = constant_value; - low_comparison_type = comparison_type; - } else { - // smaller than means this is an upper bound - high_value = constant_value; - high_comparison_type = comparison_type; - } - } else if (expr.type == ExpressionType::COMPARE_BETWEEN) { - // BETWEEN expression - auto &between = expr.Cast(); - if 
(!between.input->Equals(*index_expression)) { - // expression doesn't match the current index expression - continue; - } - if (between.lower->type != ExpressionType::VALUE_CONSTANT || - between.upper->type != ExpressionType::VALUE_CONSTANT) { - // not a constant comparison - continue; - } - low_value = (between.lower->Cast()).value; - low_comparison_type = between.lower_inclusive ? ExpressionType::COMPARE_GREATERTHANOREQUALTO - : ExpressionType::COMPARE_GREATERTHAN; - high_value = (between.upper->Cast()).value; - high_comparison_type = between.upper_inclusive ? ExpressionType::COMPARE_LESSTHANOREQUALTO - : ExpressionType::COMPARE_LESSTHAN; - break; - } - } - if (!equal_value.IsNull() || !low_value.IsNull() || !high_value.IsNull()) { - // we can scan this index using this predicate: try a scan - auto &transaction = Transaction::Get(context, bind_data.table.catalog); - unique_ptr index_state; - if (!equal_value.IsNull()) { - // equality predicate - index_state = - index.InitializeScanSinglePredicate(transaction, equal_value, ExpressionType::COMPARE_EQUAL); - } else if (!low_value.IsNull() && !high_value.IsNull()) { - // two-sided predicate - index_state = index.InitializeScanTwoPredicates(transaction, low_value, low_comparison_type, high_value, - high_comparison_type); - } else if (!low_value.IsNull()) { - // less than predicate - index_state = index.InitializeScanSinglePredicate(transaction, low_value, low_comparison_type); - } else { - D_ASSERT(!high_value.IsNull()); - index_state = index.InitializeScanSinglePredicate(transaction, high_value, high_comparison_type); - } - if (index.Scan(transaction, storage, *index_state, STANDARD_VECTOR_SIZE, bind_data.result_ids)) { - // use an index scan! 
- bind_data.is_index_scan = true; - get.function = TableScanFunction::GetIndexScanFunction(); - } else { - bind_data.result_ids.clear(); - } - return true; - } - return false; - }); -} - -string TableScanToString(const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - string result = bind_data.table.name; - return result; -} - -static void TableScanSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const TableFunction &function) { - auto &bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "catalog", bind_data.table.schema.catalog.GetName()); - serializer.WriteProperty(101, "schema", bind_data.table.schema.name); - serializer.WriteProperty(102, "table", bind_data.table.name); - serializer.WriteProperty(103, "is_index_scan", bind_data.is_index_scan); - serializer.WriteProperty(104, "is_create_index", bind_data.is_create_index); - serializer.WriteProperty(105, "result_ids", bind_data.result_ids); -} - -static unique_ptr TableScanDeserialize(Deserializer &deserializer, TableFunction &function) { - auto catalog = deserializer.ReadProperty(100, "catalog"); - auto schema = deserializer.ReadProperty(101, "schema"); - auto table = deserializer.ReadProperty(102, "table"); - auto &catalog_entry = - Catalog::GetEntry(deserializer.Get(), catalog, schema, table); - if (catalog_entry.type != CatalogType::TABLE_ENTRY) { - throw SerializationException("Cant find table for %s.%s", schema, table); - } - auto result = make_uniq(catalog_entry.Cast()); - deserializer.ReadProperty(103, "is_index_scan", result->is_index_scan); - deserializer.ReadProperty(104, "is_create_index", result->is_create_index); - deserializer.ReadProperty(105, "result_ids", result->result_ids); - return std::move(result); -} - -TableFunction TableScanFunction::GetIndexScanFunction() { - TableFunction scan_function("index_scan", {}, IndexScanFunction); - scan_function.init_local = nullptr; - scan_function.init_global = IndexScanInitGlobal; - 
scan_function.statistics = TableScanStatistics; - scan_function.dependency = TableScanDependency; - scan_function.cardinality = TableScanCardinality; - scan_function.pushdown_complex_filter = nullptr; - scan_function.to_string = TableScanToString; - scan_function.table_scan_progress = nullptr; - scan_function.get_batch_index = nullptr; - scan_function.projection_pushdown = true; - scan_function.filter_pushdown = false; - scan_function.serialize = TableScanSerialize; - scan_function.deserialize = TableScanDeserialize; - return scan_function; -} - -TableFunction TableScanFunction::GetFunction() { - TableFunction scan_function("seq_scan", {}, TableScanFunc); - scan_function.init_local = TableScanInitLocal; - scan_function.init_global = TableScanInitGlobal; - scan_function.statistics = TableScanStatistics; - scan_function.dependency = TableScanDependency; - scan_function.cardinality = TableScanCardinality; - scan_function.pushdown_complex_filter = TableScanPushdownComplexFilter; - scan_function.to_string = TableScanToString; - scan_function.table_scan_progress = TableScanProgress; - scan_function.get_batch_index = TableScanGetBatchIndex; - scan_function.get_batch_info = TableScanGetBindInfo; - scan_function.projection_pushdown = true; - scan_function.filter_pushdown = true; - scan_function.filter_prune = true; - scan_function.serialize = TableScanSerialize; - scan_function.deserialize = TableScanDeserialize; - return scan_function; -} - -optional_ptr TableScanFunction::GetTableEntry(const TableFunction &function, - const optional_ptr bind_data_p) { - if (function.function != TableScanFunc || !bind_data_p) { - return nullptr; - } - auto &bind_data = bind_data_p->Cast(); - return &bind_data.table; -} - -void TableScanFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunctionSet table_scan_set("seq_scan"); - table_scan_set.AddFunction(GetFunction()); - set.AddFunction(std::move(table_scan_set)); - - set.AddFunction(GetIndexScanFunction()); -} - -void 
BuiltinFunctions::RegisterTableScanFunctions() { - TableScanFunction::RegisterFunction(*this); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -struct UnnestBindData : public FunctionData { - explicit UnnestBindData(LogicalType input_type_p) : input_type(std::move(input_type_p)) { - } - - LogicalType input_type; - -public: - unique_ptr Copy() const override { - return make_uniq(input_type); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return input_type == other.input_type; - } -}; - -struct UnnestGlobalState : public GlobalTableFunctionState { - UnnestGlobalState() { - } - - vector> select_list; - - idx_t MaxThreads() const override { - return GlobalTableFunctionState::MAX_THREADS; - } -}; - -struct UnnestLocalState : public LocalTableFunctionState { - UnnestLocalState() { - } - - unique_ptr operator_state; -}; - -static unique_ptr UnnestBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - if (input.input_table_types.size() != 1 || input.input_table_types[0].id() != LogicalTypeId::LIST) { - throw BinderException("UNNEST requires a single list as input"); - } - return_types.push_back(ListType::GetChildType(input.input_table_types[0])); - names.push_back(input.input_table_names[0]); - return make_uniq(input.input_table_types[0]); -} - -static unique_ptr UnnestLocalInit(ExecutionContext &context, TableFunctionInitInput &input, - GlobalTableFunctionState *global_state) { - auto &gstate = global_state->Cast(); - - auto result = make_uniq(); - result->operator_state = PhysicalUnnest::GetState(context, gstate.select_list); - return std::move(result); -} - -static unique_ptr UnnestInit(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - auto result = make_uniq(); - auto ref = make_uniq(bind_data.input_type, 0); - auto bound_unnest = make_uniq(ListType::GetChildType(bind_data.input_type)); - 
bound_unnest->child = std::move(ref); - result->select_list.push_back(std::move(bound_unnest)); - return std::move(result); -} - -static OperatorResultType UnnestFunction(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, - DataChunk &output) { - auto &state = data_p.global_state->Cast(); - auto &lstate = data_p.local_state->Cast(); - return PhysicalUnnest::ExecuteInternal(context, input, output, *lstate.operator_state, state.select_list, false); -} - -void UnnestTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction unnest_function("unnest", {LogicalTypeId::TABLE}, nullptr, UnnestBind, UnnestInit, UnnestLocalInit); - unnest_function.in_out_function = UnnestFunction; - set.AddFunction(unnest_function); -} - -} // namespace duckdb - - - - -#include - -namespace duckdb { - -struct PragmaVersionData : public GlobalTableFunctionState { - PragmaVersionData() : finished(false) { - } - - bool finished; -}; - -static unique_ptr PragmaVersionBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("library_version"); - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("source_id"); - return_types.emplace_back(LogicalType::VARCHAR); - return nullptr; -} - -static unique_ptr PragmaVersionInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static void PragmaVersionFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.finished) { - // finished returning values - return; - } - output.SetCardinality(1); - output.SetValue(0, 0, DuckDB::LibraryVersion()); - output.SetValue(1, 0, DuckDB::SourceID()); - data.finished = true; -} - -void PragmaVersion::RegisterFunction(BuiltinFunctions &set) { - TableFunction pragma_version("pragma_version", {}, PragmaVersionFunction); - pragma_version.bind = PragmaVersionBind; - pragma_version.init_global = 
PragmaVersionInit; - set.AddFunction(pragma_version); -} - -idx_t DuckDB::StandardVectorSize() { - return STANDARD_VECTOR_SIZE; -} - -const char *DuckDB::SourceID() { - return DUCKDB_SOURCE_ID; -} - -const char *DuckDB::LibraryVersion() { - return DUCKDB_VERSION; -} - -string DuckDB::Platform() { -#if defined(DUCKDB_CUSTOM_PLATFORM) - return DUCKDB_QUOTE_DEFINE(DUCKDB_CUSTOM_PLATFORM); -#endif -#if defined(DUCKDB_WASM_VERSION) - // DuckDB-Wasm requires CUSTOM_PLATFORM to be defined - static_assert(0, "DUCKDB_WASM_VERSION should rely on CUSTOM_PLATFORM being provided"); -#endif - string os = "linux"; -#if INTPTR_MAX == INT64_MAX - string arch = "amd64"; -#elif INTPTR_MAX == INT32_MAX - string arch = "i686"; -#else -#error Unknown pointer size or missing size macros! -#endif - string postfix = ""; - -#ifdef _WIN32 - os = "windows"; -#elif defined(__APPLE__) - os = "osx"; -#endif -#if defined(__aarch64__) || defined(__ARM_ARCH_ISA_A64) - arch = "arm64"; -#endif - -#if !defined(_GLIBCXX_USE_CXX11_ABI) || _GLIBCXX_USE_CXX11_ABI == 0 - if (os == "linux") { - postfix = "_gcc4"; - } -#endif -#if defined(__ANDROID__) - postfix += "_android"; // using + because it may also be gcc4 -#endif -#ifdef __MINGW32__ - postfix = "_mingw"; -#endif -// this is used for the windows R builds which use a separate build environment -#ifdef DUCKDB_PLATFORM_RTOOLS - postfix = "_rtools"; -#endif - return os + "_" + arch + postfix; -} - -struct PragmaPlatformData : public GlobalTableFunctionState { - PragmaPlatformData() : finished(false) { - } - - bool finished; -}; - -static unique_ptr PragmaPlatformBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("platform"); - return_types.emplace_back(LogicalType::VARCHAR); - return nullptr; -} - -static unique_ptr PragmaPlatformInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static void PragmaPlatformFunction(ClientContext &context, 
TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.finished) { - // finished returning values - return; - } - output.SetCardinality(1); - output.SetValue(0, 0, DuckDB::Platform()); - data.finished = true; -} - -void PragmaPlatform::RegisterFunction(BuiltinFunctions &set) { - TableFunction pragma_platform("pragma_platform", {}, PragmaPlatformFunction); - pragma_platform.bind = PragmaPlatformBind; - pragma_platform.init_global = PragmaPlatformInit; - set.AddFunction(pragma_platform); -} - -} // namespace duckdb - - -namespace duckdb { - -GlobalTableFunctionState::~GlobalTableFunctionState() { -} - -LocalTableFunctionState::~LocalTableFunctionState() { -} - -TableFunctionInfo::~TableFunctionInfo() { -} - -TableFunction::TableFunction(string name, vector arguments, table_function_t function, - table_function_bind_t bind, table_function_init_global_t init_global, - table_function_init_local_t init_local) - : SimpleNamedParameterFunction(std::move(name), std::move(arguments)), bind(bind), bind_replace(nullptr), - init_global(init_global), init_local(init_local), function(function), in_out_function(nullptr), - in_out_function_final(nullptr), statistics(nullptr), dependency(nullptr), cardinality(nullptr), - pushdown_complex_filter(nullptr), to_string(nullptr), table_scan_progress(nullptr), get_batch_index(nullptr), - get_batch_info(nullptr), serialize(nullptr), deserialize(nullptr), projection_pushdown(false), - filter_pushdown(false), filter_prune(false) { -} - -TableFunction::TableFunction(const vector &arguments, table_function_t function, - table_function_bind_t bind, table_function_init_global_t init_global, - table_function_init_local_t init_local) - : TableFunction(string(), arguments, function, bind, init_global, init_local) { -} -TableFunction::TableFunction() - : SimpleNamedParameterFunction("", {}), bind(nullptr), bind_replace(nullptr), init_global(nullptr), - init_local(nullptr), function(nullptr), 
in_out_function(nullptr), statistics(nullptr), dependency(nullptr), - cardinality(nullptr), pushdown_complex_filter(nullptr), to_string(nullptr), table_scan_progress(nullptr), - get_batch_index(nullptr), get_batch_info(nullptr), serialize(nullptr), deserialize(nullptr), - projection_pushdown(false), filter_pushdown(false), filter_prune(false) { -} - -bool TableFunction::Equal(const TableFunction &rhs) const { - // number of types - if (this->arguments.size() != rhs.arguments.size()) { - return false; - } - // argument types - for (idx_t i = 0; i < this->arguments.size(); ++i) { - if (this->arguments[i] != rhs.arguments[i]) { - return false; - } - } - // varargs - if (this->varargs != rhs.varargs) { - return false; - } - - return true; // they are equal -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/function/table_macro_function.hpp -// -// -//===----------------------------------------------------------------------===// -//! 
The SelectStatement of the view - - - - - -namespace duckdb { - -TableMacroFunction::TableMacroFunction(unique_ptr query_node) - : MacroFunction(MacroType::TABLE_MACRO), query_node(std::move(query_node)) { -} - -TableMacroFunction::TableMacroFunction(void) : MacroFunction(MacroType::TABLE_MACRO) { -} - -unique_ptr TableMacroFunction::Copy() const { - auto result = make_uniq(); - result->query_node = query_node->Copy(); - this->CopyProperties(*result); - return std::move(result); -} - -string TableMacroFunction::ToSQL(const string &schema, const string &name) const { - return MacroFunction::ToSQL(schema, name) + StringUtil::Format("TABLE (%s);", query_node->ToString()); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -void UDFWrapper::RegisterFunction(string name, vector args, LogicalType ret_type, - scalar_function_t udf_function, ClientContext &context, LogicalType varargs) { - - ScalarFunction scalar_function(std::move(name), std::move(args), std::move(ret_type), std::move(udf_function)); - scalar_function.varargs = std::move(varargs); - scalar_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - CreateScalarFunctionInfo info(scalar_function); - info.schema = DEFAULT_SCHEMA; - context.RegisterFunction(info); -} - -void UDFWrapper::RegisterAggrFunction(AggregateFunction aggr_function, ClientContext &context, LogicalType varargs) { - aggr_function.varargs = std::move(varargs); - CreateAggregateFunctionInfo info(std::move(aggr_function)); - context.RegisterFunction(info); -} - -} // namespace duckdb - - - - - - - - - - - - - - - -namespace duckdb { - -BaseAppender::BaseAppender(Allocator &allocator, AppenderType type_p) - : allocator(allocator), column(0), appender_type(type_p) { -} - -BaseAppender::BaseAppender(Allocator &allocator_p, vector types_p, AppenderType type_p) - : allocator(allocator_p), types(std::move(types_p)), collection(make_uniq(allocator, types)), - column(0), appender_type(type_p) { - InitializeChunk(); -} - 
-BaseAppender::~BaseAppender() { -} - -void BaseAppender::Destructor() { - if (Exception::UncaughtException()) { - return; - } - // flush any remaining chunks, but only if we are not cleaning up the appender as part of an exception stack unwind - // wrapped in a try/catch because Close() can throw if the table was dropped in the meantime - try { - Close(); - } catch (...) { - } -} - -InternalAppender::InternalAppender(ClientContext &context_p, TableCatalogEntry &table_p) - : BaseAppender(Allocator::DefaultAllocator(), table_p.GetTypes(), AppenderType::PHYSICAL), context(context_p), - table(table_p) { -} - -InternalAppender::~InternalAppender() { - Destructor(); -} - -Appender::Appender(Connection &con, const string &schema_name, const string &table_name) - : BaseAppender(Allocator::DefaultAllocator(), AppenderType::LOGICAL), context(con.context) { - description = con.TableInfo(schema_name, table_name); - if (!description) { - // table could not be found - throw CatalogException(StringUtil::Format("Table \"%s.%s\" could not be found", schema_name, table_name)); - } - for (auto &column : description->columns) { - types.push_back(column.Type()); - } - InitializeChunk(); - collection = make_uniq(allocator, types); -} - -Appender::Appender(Connection &con, const string &table_name) : Appender(con, DEFAULT_SCHEMA, table_name) { -} - -Appender::~Appender() { - Destructor(); -} - -void BaseAppender::InitializeChunk() { - chunk.Initialize(allocator, types); -} - -void BaseAppender::BeginRow() { -} - -void BaseAppender::EndRow() { - // check that all rows have been appended to - if (column != chunk.ColumnCount()) { - throw InvalidInputException("Call to EndRow before all rows have been appended to!"); - } - column = 0; - chunk.SetCardinality(chunk.size() + 1); - if (chunk.size() >= STANDARD_VECTOR_SIZE) { - FlushChunk(); - } -} - -template -void BaseAppender::AppendValueInternal(Vector &col, SRC input) { - FlatVector::GetData(col)[chunk.size()] = Cast::Operation(input); -} - 
-template -void BaseAppender::AppendDecimalValueInternal(Vector &col, SRC input) { - switch (appender_type) { - case AppenderType::LOGICAL: { - auto &type = col.GetType(); - D_ASSERT(type.id() == LogicalTypeId::DECIMAL); - auto width = DecimalType::GetWidth(type); - auto scale = DecimalType::GetScale(type); - TryCastToDecimal::Operation(input, FlatVector::GetData(col)[chunk.size()], nullptr, width, - scale); - return; - } - case AppenderType::PHYSICAL: { - AppendValueInternal(col, input); - return; - } - default: - throw InternalException("Type not implemented for AppenderType"); - } -} - -template -void BaseAppender::AppendValueInternal(T input) { - if (column >= types.size()) { - throw InvalidInputException("Too many appends for chunk!"); - } - auto &col = chunk.data[column]; - switch (col.GetType().id()) { - case LogicalTypeId::BOOLEAN: - AppendValueInternal(col, input); - break; - case LogicalTypeId::UTINYINT: - AppendValueInternal(col, input); - break; - case LogicalTypeId::TINYINT: - AppendValueInternal(col, input); - break; - case LogicalTypeId::USMALLINT: - AppendValueInternal(col, input); - break; - case LogicalTypeId::SMALLINT: - AppendValueInternal(col, input); - break; - case LogicalTypeId::UINTEGER: - AppendValueInternal(col, input); - break; - case LogicalTypeId::INTEGER: - AppendValueInternal(col, input); - break; - case LogicalTypeId::UBIGINT: - AppendValueInternal(col, input); - break; - case LogicalTypeId::BIGINT: - AppendValueInternal(col, input); - break; - case LogicalTypeId::HUGEINT: - AppendValueInternal(col, input); - break; - case LogicalTypeId::FLOAT: - AppendValueInternal(col, input); - break; - case LogicalTypeId::DOUBLE: - AppendValueInternal(col, input); - break; - case LogicalTypeId::DECIMAL: - switch (col.GetType().InternalType()) { - case PhysicalType::INT16: - AppendDecimalValueInternal(col, input); - break; - case PhysicalType::INT32: - AppendDecimalValueInternal(col, input); - break; - case PhysicalType::INT64: - 
AppendDecimalValueInternal(col, input); - break; - case PhysicalType::INT128: - AppendDecimalValueInternal(col, input); - break; - default: - throw InternalException("Internal type not recognized for Decimal"); - } - break; - case LogicalTypeId::DATE: - AppendValueInternal(col, input); - break; - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - AppendValueInternal(col, input); - break; - case LogicalTypeId::TIME: - AppendValueInternal(col, input); - break; - case LogicalTypeId::TIME_TZ: - AppendValueInternal(col, input); - break; - case LogicalTypeId::INTERVAL: - AppendValueInternal(col, input); - break; - case LogicalTypeId::VARCHAR: - FlatVector::GetData(col)[chunk.size()] = StringCast::Operation(input, col); - break; - default: - AppendValue(Value::CreateValue(input)); - return; - } - column++; -} - -template <> -void BaseAppender::Append(bool value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(int8_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(int16_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(int32_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(int64_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(hugeint_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(uint8_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(uint16_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(uint32_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(uint64_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(const char *value) { - AppendValueInternal(string_t(value)); -} - -void BaseAppender::Append(const char *value, uint32_t length) { - AppendValueInternal(string_t(value, length)); -} - -template <> -void 
BaseAppender::Append(string_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(float value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(double value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(date_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(dtime_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(timestamp_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(interval_t value) { - AppendValueInternal(value); -} - -template <> -void BaseAppender::Append(Value value) { // NOLINT: template shtuff - if (column >= chunk.ColumnCount()) { - throw InvalidInputException("Too many appends for chunk!"); - } - AppendValue(value); -} - -template <> -void BaseAppender::Append(std::nullptr_t value) { - if (column >= chunk.ColumnCount()) { - throw InvalidInputException("Too many appends for chunk!"); - } - auto &col = chunk.data[column++]; - FlatVector::SetNull(col, chunk.size(), true); -} - -void BaseAppender::AppendValue(const Value &value) { - chunk.SetValue(column, chunk.size(), value); - column++; -} - -void BaseAppender::AppendDataChunk(DataChunk &chunk) { - if (chunk.GetTypes() != types) { - throw InvalidInputException("Type mismatch in Append DataChunk and the types required for appender"); - } - collection->Append(chunk); - if (collection->Count() >= FLUSH_COUNT) { - Flush(); - } -} - -void BaseAppender::FlushChunk() { - if (chunk.size() == 0) { - return; - } - collection->Append(chunk); - chunk.Reset(); - if (collection->Count() >= FLUSH_COUNT) { - Flush(); - } -} - -void BaseAppender::Flush() { - // check that all vectors have the same length before appending - if (column != 0) { - throw InvalidInputException("Failed to Flush appender: incomplete append to row!"); - } - - FlushChunk(); - if (collection->Count() == 0) { - return; - } - FlushInternal(*collection); 
- - collection->Reset(); - column = 0; -} - -void Appender::FlushInternal(ColumnDataCollection &collection) { - context->Append(*description, collection); -} - -void InternalAppender::FlushInternal(ColumnDataCollection &collection) { - table.GetStorage().LocalAppend(table, context, collection); -} - -void BaseAppender::Close() { - if (column == 0 || column == types.size()) { - Flush(); - } -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -AttachedDatabase::AttachedDatabase(DatabaseInstance &db, AttachedDatabaseType type) - : CatalogEntry(CatalogType::DATABASE_ENTRY, - type == AttachedDatabaseType::SYSTEM_DATABASE ? SYSTEM_CATALOG : TEMP_CATALOG, 0), - db(db), type(type) { - D_ASSERT(type == AttachedDatabaseType::TEMP_DATABASE || type == AttachedDatabaseType::SYSTEM_DATABASE); - if (type == AttachedDatabaseType::TEMP_DATABASE) { - storage = make_uniq(*this, ":memory:", false); - } - catalog = make_uniq(*this); - transaction_manager = make_uniq(*this); - internal = true; -} - -AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, string name_p, string file_path_p, - AccessMode access_mode) - : CatalogEntry(CatalogType::DATABASE_ENTRY, catalog_p, std::move(name_p)), db(db), - type(access_mode == AccessMode::READ_ONLY ? AttachedDatabaseType::READ_ONLY_DATABASE - : AttachedDatabaseType::READ_WRITE_DATABASE), - parent_catalog(&catalog_p) { - storage = make_uniq(*this, std::move(file_path_p), access_mode == AccessMode::READ_ONLY); - catalog = make_uniq(*this); - transaction_manager = make_uniq(*this); - internal = true; -} - -AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, StorageExtension &storage_extension, - string name_p, AttachInfo &info, AccessMode access_mode) - : CatalogEntry(CatalogType::DATABASE_ENTRY, catalog_p, std::move(name_p)), db(db), - type(access_mode == AccessMode::READ_ONLY ? 
AttachedDatabaseType::READ_ONLY_DATABASE - : AttachedDatabaseType::READ_WRITE_DATABASE), - parent_catalog(&catalog_p) { - catalog = storage_extension.attach(storage_extension.storage_info.get(), *this, name, info, access_mode); - if (!catalog) { - throw InternalException("AttachedDatabase - attach function did not return a catalog"); - } - transaction_manager = - storage_extension.create_transaction_manager(storage_extension.storage_info.get(), *this, *catalog); - if (!transaction_manager) { - throw InternalException( - "AttachedDatabase - create_transaction_manager function did not return a transaction manager"); - } - internal = true; -} - -AttachedDatabase::~AttachedDatabase() { - if (Exception::UncaughtException()) { - return; - } - if (!storage) { - return; - } - - // shutting down: attempt to checkpoint the database - // but only if we are not cleaning up as part of an exception unwind - try { - if (!storage->InMemory()) { - auto &config = DBConfig::GetConfig(db); - if (!config.options.checkpoint_on_shutdown) { - return; - } - storage->CreateCheckpoint(true); - } - } catch (...) 
{ - } -} - -bool AttachedDatabase::IsSystem() const { - D_ASSERT(!storage || type != AttachedDatabaseType::SYSTEM_DATABASE); - return type == AttachedDatabaseType::SYSTEM_DATABASE; -} - -bool AttachedDatabase::IsTemporary() const { - return type == AttachedDatabaseType::TEMP_DATABASE; -} -bool AttachedDatabase::IsReadOnly() const { - return type == AttachedDatabaseType::READ_ONLY_DATABASE; -} - -string AttachedDatabase::ExtractDatabaseName(const string &dbpath, FileSystem &fs) { - if (dbpath.empty() || dbpath == ":memory:") { - return "memory"; - } - return fs.ExtractBaseName(dbpath); -} - -void AttachedDatabase::Initialize() { - if (IsSystem()) { - catalog->Initialize(true); - } else { - catalog->Initialize(false); - } - if (storage) { - storage->Initialize(); - } -} - -StorageManager &AttachedDatabase::GetStorageManager() { - if (!storage) { - throw InternalException("Internal system catalog does not have storage"); - } - return *storage; -} - -Catalog &AttachedDatabase::GetCatalog() { - return *catalog; -} - -TransactionManager &AttachedDatabase::GetTransactionManager() { - return *transaction_manager; -} - -Catalog &AttachedDatabase::ParentCatalog() { - return *parent_catalog; -} - -bool AttachedDatabase::IsInitialDatabase() const { - return is_initial_database; -} - -void AttachedDatabase::SetInitialDatabase() { - is_initial_database = true; -} - -} // namespace duckdb - - -using duckdb::Appender; -using duckdb::AppenderWrapper; -using duckdb::Connection; -using duckdb::date_t; -using duckdb::dtime_t; -using duckdb::hugeint_t; -using duckdb::interval_t; -using duckdb::string_t; -using duckdb::timestamp_t; - -duckdb_state duckdb_appender_create(duckdb_connection connection, const char *schema, const char *table, - duckdb_appender *out_appender) { - Connection *conn = reinterpret_cast(connection); - - if (!connection || !table || !out_appender) { - return DuckDBError; - } - if (schema == nullptr) { - schema = DEFAULT_SCHEMA; - } - auto wrapper = new 
AppenderWrapper(); - *out_appender = (duckdb_appender)wrapper; - try { - wrapper->appender = duckdb::make_uniq(*conn, schema, table); - } catch (std::exception &ex) { - wrapper->error = ex.what(); - return DuckDBError; - } catch (...) { // LCOV_EXCL_START - wrapper->error = "Unknown create appender error"; - return DuckDBError; - } // LCOV_EXCL_STOP - return DuckDBSuccess; -} - -duckdb_state duckdb_appender_destroy(duckdb_appender *appender) { - if (!appender || !*appender) { - return DuckDBError; - } - duckdb_appender_close(*appender); - auto wrapper = reinterpret_cast(*appender); - if (wrapper) { - delete wrapper; - } - *appender = nullptr; - return DuckDBSuccess; -} - -template -duckdb_state duckdb_appender_run_function(duckdb_appender appender, FUN &&function) { - if (!appender) { - return DuckDBError; - } - auto wrapper = reinterpret_cast(appender); - if (!wrapper->appender) { - return DuckDBError; - } - try { - function(*wrapper->appender); - } catch (std::exception &ex) { - wrapper->error = ex.what(); - return DuckDBError; - } catch (...) 
{ // LCOV_EXCL_START - wrapper->error = "Unknown error"; - return DuckDBError; - } // LCOV_EXCL_STOP - return DuckDBSuccess; -} - -const char *duckdb_appender_error(duckdb_appender appender) { - if (!appender) { - return nullptr; - } - auto wrapper = reinterpret_cast(appender); - if (wrapper->error.empty()) { - return nullptr; - } - return wrapper->error.c_str(); -} - -duckdb_state duckdb_appender_begin_row(duckdb_appender appender) { - return DuckDBSuccess; -} - -duckdb_state duckdb_appender_end_row(duckdb_appender appender) { - return duckdb_appender_run_function(appender, [&](Appender &appender) { appender.EndRow(); }); -} - -template -duckdb_state duckdb_append_internal(duckdb_appender appender, T value) { - if (!appender) { - return DuckDBError; - } - auto *appender_instance = reinterpret_cast(appender); - try { - appender_instance->appender->Append(value); - } catch (std::exception &ex) { - appender_instance->error = ex.what(); - return DuckDBError; - } catch (...) { - return DuckDBError; - } - return DuckDBSuccess; -} - -duckdb_state duckdb_append_bool(duckdb_appender appender, bool value) { - return duckdb_append_internal(appender, value); -} - -duckdb_state duckdb_append_int8(duckdb_appender appender, int8_t value) { - return duckdb_append_internal(appender, value); -} - -duckdb_state duckdb_append_int16(duckdb_appender appender, int16_t value) { - return duckdb_append_internal(appender, value); -} - -duckdb_state duckdb_append_int32(duckdb_appender appender, int32_t value) { - return duckdb_append_internal(appender, value); -} - -duckdb_state duckdb_append_int64(duckdb_appender appender, int64_t value) { - return duckdb_append_internal(appender, value); -} - -duckdb_state duckdb_append_hugeint(duckdb_appender appender, duckdb_hugeint value) { - hugeint_t internal; - internal.lower = value.lower; - internal.upper = value.upper; - return duckdb_append_internal(appender, internal); -} - -duckdb_state duckdb_append_uint8(duckdb_appender appender, uint8_t 
value) {
	return duckdb_append_internal(appender, value);
}

duckdb_state duckdb_append_uint16(duckdb_appender appender, uint16_t value) {
	return duckdb_append_internal(appender, value);
}

duckdb_state duckdb_append_uint32(duckdb_appender appender, uint32_t value) {
	return duckdb_append_internal(appender, value);
}

duckdb_state duckdb_append_uint64(duckdb_appender appender, uint64_t value) {
	return duckdb_append_internal(appender, value);
}

duckdb_state duckdb_append_float(duckdb_appender appender, float value) {
	return duckdb_append_internal(appender, value);
}

duckdb_state duckdb_append_double(duckdb_appender appender, double value) {
	return duckdb_append_internal(appender, value);
}

// duckdb_date stores days since 1970-01-01; wrap in the internal date_t.
duckdb_state duckdb_append_date(duckdb_appender appender, duckdb_date value) {
	return duckdb_append_internal(appender, date_t(value.days));
}

// duckdb_time stores microseconds since 00:00:00.
duckdb_state duckdb_append_time(duckdb_appender appender, duckdb_time value) {
	return duckdb_append_internal(appender, dtime_t(value.micros));
}

// duckdb_timestamp stores microseconds since the epoch.
duckdb_state duckdb_append_timestamp(duckdb_appender appender, duckdb_timestamp value) {
	return duckdb_append_internal(appender, timestamp_t(value.micros));
}

duckdb_state duckdb_append_interval(duckdb_appender appender, duckdb_interval value) {
	interval_t interval;
	interval.months = value.months;
	interval.days = value.days;
	interval.micros = value.micros;
	return duckdb_append_internal(appender, interval);
}

duckdb_state duckdb_append_null(duckdb_appender appender) {
	return duckdb_append_internal(appender, nullptr);
}

// Null-terminated string append; length is derived from the terminator.
duckdb_state duckdb_append_varchar(duckdb_appender appender, const char *val) {
	return duckdb_append_internal(appender, val);
}

// Explicit-length string append; `val` need not be null-terminated.
duckdb_state duckdb_append_varchar_length(duckdb_appender appender, const char *val, idx_t length) {
	return duckdb_append_internal(appender, string_t(val, length));
}
duckdb_state duckdb_append_blob(duckdb_appender appender, const void *data, idx_t length) {
	auto value =
duckdb::Value::BLOB((duckdb::const_data_ptr_t)data, length);
	// BLOB goes through the Value-based path; the bytes are copied into `value`.
	return duckdb_append_internal(appender, value);
}

duckdb_state duckdb_appender_flush(duckdb_appender appender) {
	return duckdb_appender_run_function(appender, [&](Appender &appender) { appender.Flush(); });
}

duckdb_state duckdb_appender_close(duckdb_appender appender) {
	return duckdb_appender_run_function(appender, [&](Appender &appender) { appender.Close(); });
}

// Appends an entire data chunk; the chunk's types must match the appender's.
duckdb_state duckdb_append_data_chunk(duckdb_appender appender, duckdb_data_chunk chunk) {
	if (!chunk) {
		return DuckDBError;
	}
	auto data_chunk = (duckdb::DataChunk *)chunk;
	return duckdb_appender_run_function(appender, [&](Appender &appender) { appender.AppendDataChunk(*data_chunk); });
}

using duckdb::ArrowConverter;
using duckdb::ArrowResultWrapper;
using duckdb::Connection;
using duckdb::DataChunk;
using duckdb::LogicalType;
using duckdb::MaterializedQueryResult;
using duckdb::PreparedStatementWrapper;
using duckdb::QueryResult;
using duckdb::QueryResultType;

// Runs a query and returns the result wrapped for Arrow consumption.
// NOTE(review): no null checks on `connection`/`out_result` here — callers are
// expected to pass valid pointers; confirm against the C API contract.
duckdb_state duckdb_query_arrow(duckdb_connection connection, const char *query, duckdb_arrow *out_result) {
	Connection *conn = (Connection *)connection;
	auto wrapper = new ArrowResultWrapper();
	wrapper->result = conn->Query(query);
	*out_result = (duckdb_arrow)wrapper;
	return !wrapper->result->HasError() ?
DuckDBSuccess : DuckDBError; -} - -duckdb_state duckdb_query_arrow_schema(duckdb_arrow result, duckdb_arrow_schema *out_schema) { - if (!out_schema) { - return DuckDBSuccess; - } - auto wrapper = reinterpret_cast(result); - ArrowConverter::ToArrowSchema((ArrowSchema *)*out_schema, wrapper->result->types, wrapper->result->names, - wrapper->options); - return DuckDBSuccess; -} - -duckdb_state duckdb_prepared_arrow_schema(duckdb_prepared_statement prepared, duckdb_arrow_schema *out_schema) { - if (!out_schema) { - return DuckDBSuccess; - } - auto wrapper = reinterpret_cast(prepared); - if (!wrapper || !wrapper->statement || !wrapper->statement->data) { - return DuckDBError; - } - auto properties = wrapper->statement->context->GetClientProperties(); - duckdb::vector prepared_types; - duckdb::vector prepared_names; - - auto count = wrapper->statement->data->properties.parameter_count; - for (idx_t i = 0; i < count; i++) { - // Every prepared parameter type is UNKNOWN, which we need to map to NULL according to the spec of - // 'AdbcStatementGetParameterSchema' - auto type = LogicalType::SQLNULL; - - // FIXME: we don't support named parameters yet, but when we do, this needs to be updated - auto name = std::to_string(i); - prepared_types.push_back(std::move(type)); - prepared_names.push_back(name); - } - - auto result_schema = (ArrowSchema *)*out_schema; - if (!result_schema) { - return DuckDBError; - } - - if (result_schema->release) { - // Need to release the existing schema before we overwrite it - result_schema->release(result_schema); - result_schema->release = nullptr; - } - - ArrowConverter::ToArrowSchema(result_schema, prepared_types, prepared_names, properties); - return DuckDBSuccess; -} - -duckdb_state duckdb_query_arrow_array(duckdb_arrow result, duckdb_arrow_array *out_array) { - if (!out_array) { - return DuckDBSuccess; - } - auto wrapper = reinterpret_cast(result); - auto success = wrapper->result->TryFetch(wrapper->current_chunk, 
wrapper->result->GetErrorObject()); - if (!success) { // LCOV_EXCL_START - return DuckDBError; - } // LCOV_EXCL_STOP - if (!wrapper->current_chunk || wrapper->current_chunk->size() == 0) { - return DuckDBSuccess; - } - ArrowConverter::ToArrowArray(*wrapper->current_chunk, reinterpret_cast(*out_array), wrapper->options); - return DuckDBSuccess; -} - -idx_t duckdb_arrow_row_count(duckdb_arrow result) { - auto wrapper = reinterpret_cast(result); - if (wrapper->result->HasError()) { - return 0; - } - return wrapper->result->RowCount(); -} - -idx_t duckdb_arrow_column_count(duckdb_arrow result) { - auto wrapper = reinterpret_cast(result); - return wrapper->result->ColumnCount(); -} - -idx_t duckdb_arrow_rows_changed(duckdb_arrow result) { - auto wrapper = reinterpret_cast(result); - if (wrapper->result->HasError()) { - return 0; - } - idx_t rows_changed = 0; - auto &collection = wrapper->result->Collection(); - idx_t row_count = collection.Count(); - if (row_count > 0 && wrapper->result->properties.return_type == duckdb::StatementReturnType::CHANGED_ROWS) { - auto rows = collection.GetRows(); - D_ASSERT(row_count == 1); - D_ASSERT(rows.size() == 1); - rows_changed = rows[0].GetValue(0).GetValue(); - } - return rows_changed; -} - -const char *duckdb_query_arrow_error(duckdb_arrow result) { - auto wrapper = reinterpret_cast(result); - return wrapper->result->GetError().c_str(); -} - -void duckdb_destroy_arrow(duckdb_arrow *result) { - if (*result) { - auto wrapper = reinterpret_cast(*result); - delete wrapper; - *result = nullptr; - } -} - -duckdb_state duckdb_execute_prepared_arrow(duckdb_prepared_statement prepared_statement, duckdb_arrow *out_result) { - auto wrapper = reinterpret_cast(prepared_statement); - if (!wrapper || !wrapper->statement || wrapper->statement->HasError() || !out_result) { - return DuckDBError; - } - auto arrow_wrapper = new ArrowResultWrapper(); - arrow_wrapper->options = wrapper->statement->context->GetClientProperties(); - - auto result = 
wrapper->statement->Execute(wrapper->values, false); - D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT); - arrow_wrapper->result = duckdb::unique_ptr_cast(std::move(result)); - *out_result = reinterpret_cast(arrow_wrapper); - return !arrow_wrapper->result->HasError() ? DuckDBSuccess : DuckDBError; -} - -namespace arrow_array_stream_wrapper { -namespace { -struct PrivateData { - ArrowSchema *schema; - ArrowArray *array; - bool done = false; -}; - -// LCOV_EXCL_START -// This function is never called, but used to set ArrowSchema's release functions to a non-null NOOP. -void EmptySchemaRelease(ArrowSchema *) { -} -// LCOV_EXCL_STOP - -void EmptyArrayRelease(ArrowArray *) { -} - -void EmptyStreamRelease(ArrowArrayStream *) { -} - -void FactoryGetSchema(uintptr_t stream_factory_ptr, duckdb::ArrowSchemaWrapper &schema) { - auto stream = reinterpret_cast(stream_factory_ptr); - stream->get_schema(stream, &schema.arrow_schema); - - // Need to nullify the root schema's release function here, because streams don't allow us to set the release - // function. For the schema's children, we nullify the release functions in `duckdb_arrow_scan`, so we don't need to - // handle them again here. We set this to nullptr and not EmptySchemaRelease to prevent ArrowSchemaWrapper's - // destructor from destroying the schema (it's the caller's responsibility). 
- schema.arrow_schema.release = nullptr; -} - -int GetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out) { - auto private_data = static_cast((stream->private_data)); - if (private_data->schema == nullptr) { - return DuckDBError; - } - - *out = *private_data->schema; - out->release = EmptySchemaRelease; - return DuckDBSuccess; -} - -int GetNext(struct ArrowArrayStream *stream, struct ArrowArray *out) { - auto private_data = static_cast((stream->private_data)); - *out = *private_data->array; - if (private_data->done) { - out->release = nullptr; - } else { - out->release = EmptyArrayRelease; - } - - private_data->done = true; - return DuckDBSuccess; -} - -duckdb::unique_ptr FactoryGetNext(uintptr_t stream_factory_ptr, - duckdb::ArrowStreamParameters ¶meters) { - auto stream = reinterpret_cast(stream_factory_ptr); - auto ret = duckdb::make_uniq(); - ret->arrow_array_stream = *stream; - ret->arrow_array_stream.release = EmptyStreamRelease; - return ret; -} - -// LCOV_EXCL_START -// This function is never be called, because it's used to construct a stream wrapping around a caller-supplied -// ArrowArray. Thus, the stream itself cannot produce an error. -const char *GetLastError(struct ArrowArrayStream *stream) { - return nullptr; -} -// LCOV_EXCL_STOP - -void Release(struct ArrowArrayStream *stream) { - if (stream->private_data != nullptr) { - delete reinterpret_cast(stream->private_data); - } - - stream->private_data = nullptr; - stream->release = nullptr; -} - -duckdb_state Ingest(duckdb_connection connection, const char *table_name, struct ArrowArrayStream *input) { - try { - auto cconn = reinterpret_cast(connection); - cconn - ->TableFunction("arrow_scan", {duckdb::Value::POINTER((uintptr_t)input), - duckdb::Value::POINTER((uintptr_t)FactoryGetNext), - duckdb::Value::POINTER((uintptr_t)FactoryGetSchema)}) - ->CreateView(table_name, true, false); - } catch (...) { // LCOV_EXCL_START - // Tried covering this in tests, but it proved harder than expected. 
At the time of writing: - // - Passing any name to `CreateView` worked without throwing an exception - // - Passing a null Arrow array worked without throwing an exception - // - Passing an invalid schema (without any columns) led to an InternalException with SIGABRT, which is meant to - // be un-catchable. This case likely needs to be handled gracefully within `arrow_scan`. - // Ref: https://discord.com/channels/909674491309850675/921100573732909107/1115230468699336785 - return DuckDBError; - } // LCOV_EXCL_STOP - - return DuckDBSuccess; -} -} // namespace -} // namespace arrow_array_stream_wrapper - -duckdb_state duckdb_arrow_scan(duckdb_connection connection, const char *table_name, duckdb_arrow_stream arrow) { - auto stream = reinterpret_cast(arrow); - - // Backup release functions - we nullify children schema release functions because we don't want to release on - // behalf of the caller, downstream in our code. Note that Arrow releases target immediate children, but aren't - // recursive. So we only back up immediate children here and restore their functions. - ArrowSchema schema; - if (stream->get_schema(stream, &schema) == DuckDBError) { - return DuckDBError; - } - - typedef void (*release_fn_t)(ArrowSchema *); - std::vector release_fns(schema.n_children); - for (int64_t i = 0; i < schema.n_children; i++) { - auto child = schema.children[i]; - release_fns[i] = child->release; - child->release = arrow_array_stream_wrapper::EmptySchemaRelease; - } - - auto ret = arrow_array_stream_wrapper::Ingest(connection, table_name, stream); - - // Restore release functions. 
- for (int64_t i = 0; i < schema.n_children; i++) { - schema.children[i]->release = release_fns[i]; - } - - return ret; -} - -duckdb_state duckdb_arrow_array_scan(duckdb_connection connection, const char *table_name, - duckdb_arrow_schema arrow_schema, duckdb_arrow_array arrow_array, - duckdb_arrow_stream *out_stream) { - auto private_data = new arrow_array_stream_wrapper::PrivateData; - private_data->schema = reinterpret_cast(arrow_schema); - private_data->array = reinterpret_cast(arrow_array); - private_data->done = false; - - ArrowArrayStream *stream = new ArrowArrayStream; - *out_stream = reinterpret_cast(stream); - stream->get_schema = arrow_array_stream_wrapper::GetSchema; - stream->get_next = arrow_array_stream_wrapper::GetNext; - stream->get_last_error = arrow_array_stream_wrapper::GetLastError; - stream->release = arrow_array_stream_wrapper::Release; - stream->private_data = private_data; - - return duckdb_arrow_scan(connection, table_name, reinterpret_cast(stream)); -} - - - -namespace duckdb { - -//! 
DECIMAL -> VARCHAR -template <> -bool CastDecimalCInternal(duckdb_result *source, duckdb_string &result, idx_t col, idx_t row) { - auto result_data = (duckdb::DuckDBResultData *)source->internal_data; - auto &query_result = result_data->result; - auto &source_type = query_result->types[col]; - auto width = duckdb::DecimalType::GetWidth(source_type); - auto scale = duckdb::DecimalType::GetScale(source_type); - duckdb::Vector result_vec(duckdb::LogicalType::VARCHAR, false, false); - duckdb::string_t result_string; - void *source_address = UnsafeFetchPtr(source, col, row); - switch (source_type.InternalType()) { - case duckdb::PhysicalType::INT16: - result_string = duckdb::StringCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), - width, scale, result_vec); - break; - case duckdb::PhysicalType::INT32: - result_string = duckdb::StringCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), - width, scale, result_vec); - break; - case duckdb::PhysicalType::INT64: - result_string = duckdb::StringCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), - width, scale, result_vec); - break; - case duckdb::PhysicalType::INT128: - result_string = duckdb::StringCastFromDecimal::Operation( - UnsafeFetchFromPtr(source_address), width, scale, result_vec); - break; - default: - throw duckdb::InternalException("Unimplemented internal type for decimal"); - } - result.data = reinterpret_cast(duckdb_malloc(sizeof(char) * (result_string.GetSize() + 1))); - memcpy(result.data, result_string.GetData(), result_string.GetSize()); - result.data[result_string.GetSize()] = '\0'; - result.size = result_string.GetSize(); - return true; -} - -template -duckdb_hugeint FetchInternals(void *source_address) { - throw duckdb::NotImplementedException("FetchInternals not implemented for internal type"); -} - -template <> -duckdb_hugeint FetchInternals(void *source_address) { - duckdb_hugeint result; - int16_t intermediate_result; - - if 
(!TryCast::Operation(UnsafeFetchFromPtr(source_address), intermediate_result)) { - intermediate_result = FetchDefaultValue::Operation(); - } - hugeint_t hugeint_result = Hugeint::Cast(intermediate_result); - result.lower = hugeint_result.lower; - result.upper = hugeint_result.upper; - return result; -} -template <> -duckdb_hugeint FetchInternals(void *source_address) { - duckdb_hugeint result; - int32_t intermediate_result; - - if (!TryCast::Operation(UnsafeFetchFromPtr(source_address), intermediate_result)) { - intermediate_result = FetchDefaultValue::Operation(); - } - hugeint_t hugeint_result = Hugeint::Cast(intermediate_result); - result.lower = hugeint_result.lower; - result.upper = hugeint_result.upper; - return result; -} -template <> -duckdb_hugeint FetchInternals(void *source_address) { - duckdb_hugeint result; - int64_t intermediate_result; - - if (!TryCast::Operation(UnsafeFetchFromPtr(source_address), intermediate_result)) { - intermediate_result = FetchDefaultValue::Operation(); - } - hugeint_t hugeint_result = Hugeint::Cast(intermediate_result); - result.lower = hugeint_result.lower; - result.upper = hugeint_result.upper; - return result; -} -template <> -duckdb_hugeint FetchInternals(void *source_address) { - duckdb_hugeint result; - hugeint_t intermediate_result; - - if (!TryCast::Operation(UnsafeFetchFromPtr(source_address), intermediate_result)) { - intermediate_result = FetchDefaultValue::Operation(); - } - result.lower = intermediate_result.lower; - result.upper = intermediate_result.upper; - return result; -} - -//! 
DECIMAL -> DECIMAL (internal fetch) -template <> -bool CastDecimalCInternal(duckdb_result *source, duckdb_decimal &result, idx_t col, idx_t row) { - auto result_data = (duckdb::DuckDBResultData *)source->internal_data; - result_data->result->types[col].GetDecimalProperties(result.width, result.scale); - auto source_address = UnsafeFetchPtr(source, col, row); - - if (result.width > duckdb::Decimal::MAX_WIDTH_INT64) { - result.value = FetchInternals(source_address); - } else if (result.width > duckdb::Decimal::MAX_WIDTH_INT32) { - result.value = FetchInternals(source_address); - } else if (result.width > duckdb::Decimal::MAX_WIDTH_INT16) { - result.value = FetchInternals(source_address); - } else { - result.value = FetchInternals(source_address); - } - return true; -} - -} // namespace duckdb - - -namespace duckdb { - -template <> -duckdb_decimal FetchDefaultValue::Operation() { - duckdb_decimal result; - result.scale = 0; - result.width = 0; - result.value = {0, 0}; - return result; -} - -template <> -date_t FetchDefaultValue::Operation() { - date_t result; - result.days = 0; - return result; -} - -template <> -dtime_t FetchDefaultValue::Operation() { - dtime_t result; - result.micros = 0; - return result; -} - -template <> -timestamp_t FetchDefaultValue::Operation() { - timestamp_t result; - result.value = 0; - return result; -} - -template <> -interval_t FetchDefaultValue::Operation() { - interval_t result; - result.months = 0; - result.days = 0; - result.micros = 0; - return result; -} - -template <> -char *FetchDefaultValue::Operation() { - return nullptr; -} - -template <> -duckdb_string FetchDefaultValue::Operation() { - duckdb_string result; - result.data = nullptr; - result.size = 0; - return result; -} - -template <> -duckdb_blob FetchDefaultValue::Operation() { - duckdb_blob result; - result.data = nullptr; - result.size = 0; - return result; -} - -//===--------------------------------------------------------------------===// -// Blob Casts 
-//===--------------------------------------------------------------------===// - -template <> -bool FromCBlobCastWrapper::Operation(duckdb_blob input, duckdb_string &result) { - string_t input_str(const_char_ptr_cast(input.data), input.size); - return ToCStringCastWrapper::template Operation(input_str, result); -} - -} // namespace duckdb - -bool CanUseDeprecatedFetch(duckdb_result *result, idx_t col, idx_t row) { - if (!result) { - return false; - } - if (!duckdb::deprecated_materialize_result(result)) { - return false; - } - if (col >= result->__deprecated_column_count || row >= result->__deprecated_row_count) { - return false; - } - return true; -} - -bool CanFetchValue(duckdb_result *result, idx_t col, idx_t row) { - if (!CanUseDeprecatedFetch(result, col, row)) { - return false; - } - if (result->__deprecated_columns[col].__deprecated_nullmask[row]) { - return false; - } - return true; -} - - - - -using duckdb::DBConfig; -using duckdb::Value; - -// config -duckdb_state duckdb_create_config(duckdb_config *out_config) { - if (!out_config) { - return DuckDBError; - } - DBConfig *config; - try { - config = new DBConfig(); - } catch (...) { // LCOV_EXCL_START - return DuckDBError; - } // LCOV_EXCL_STOP - *out_config = reinterpret_cast(config); - return DuckDBSuccess; -} - -size_t duckdb_config_count() { - return DBConfig::GetOptionCount(); -} - -duckdb_state duckdb_get_config_flag(size_t index, const char **out_name, const char **out_description) { - auto option = DBConfig::GetOptionByIndex(index); - if (!option) { - return DuckDBError; - } - if (out_name) { - *out_name = option->name; - } - if (out_description) { - *out_description = option->description; - } - return DuckDBSuccess; -} - -duckdb_state duckdb_set_config(duckdb_config config, const char *name, const char *option) { - if (!config || !name || !option) { - return DuckDBError; - } - - try { - auto db_config = (DBConfig *)config; - db_config->SetOptionByName(name, Value(option)); - } catch (...) 
{ - return DuckDBError; - } - return DuckDBSuccess; -} - -void duckdb_destroy_config(duckdb_config *config) { - if (!config) { - return; - } - if (*config) { - auto db_config = (DBConfig *)*config; - delete db_config; - *config = nullptr; - } -} - - - - -#include - -duckdb_data_chunk duckdb_create_data_chunk(duckdb_logical_type *ctypes, idx_t column_count) { - if (!ctypes) { - return nullptr; - } - duckdb::vector types; - for (idx_t i = 0; i < column_count; i++) { - auto ltype = reinterpret_cast(ctypes[i]); - types.push_back(*ltype); - } - - auto result = new duckdb::DataChunk(); - result->Initialize(duckdb::Allocator::DefaultAllocator(), types); - return reinterpret_cast(result); -} - -void duckdb_destroy_data_chunk(duckdb_data_chunk *chunk) { - if (chunk && *chunk) { - auto dchunk = reinterpret_cast(*chunk); - delete dchunk; - *chunk = nullptr; - } -} - -void duckdb_data_chunk_reset(duckdb_data_chunk chunk) { - if (!chunk) { - return; - } - auto dchunk = reinterpret_cast(chunk); - dchunk->Reset(); -} - -idx_t duckdb_data_chunk_get_column_count(duckdb_data_chunk chunk) { - if (!chunk) { - return 0; - } - auto dchunk = reinterpret_cast(chunk); - return dchunk->ColumnCount(); -} - -duckdb_vector duckdb_data_chunk_get_vector(duckdb_data_chunk chunk, idx_t col_idx) { - if (!chunk || col_idx >= duckdb_data_chunk_get_column_count(chunk)) { - return nullptr; - } - auto dchunk = reinterpret_cast(chunk); - return reinterpret_cast(&dchunk->data[col_idx]); -} - -idx_t duckdb_data_chunk_get_size(duckdb_data_chunk chunk) { - if (!chunk) { - return 0; - } - auto dchunk = reinterpret_cast(chunk); - return dchunk->size(); -} - -void duckdb_data_chunk_set_size(duckdb_data_chunk chunk, idx_t size) { - if (!chunk) { - return; - } - auto dchunk = reinterpret_cast(chunk); - dchunk->SetCardinality(size); -} - -duckdb_logical_type duckdb_vector_get_column_type(duckdb_vector vector) { - if (!vector) { - return nullptr; - } - auto v = reinterpret_cast(vector); - return 
reinterpret_cast(new duckdb::LogicalType(v->GetType())); -} - -void *duckdb_vector_get_data(duckdb_vector vector) { - if (!vector) { - return nullptr; - } - auto v = reinterpret_cast(vector); - return duckdb::FlatVector::GetData(*v); -} - -uint64_t *duckdb_vector_get_validity(duckdb_vector vector) { - if (!vector) { - return nullptr; - } - auto v = reinterpret_cast(vector); - return duckdb::FlatVector::Validity(*v).GetData(); -} - -void duckdb_vector_ensure_validity_writable(duckdb_vector vector) { - if (!vector) { - return; - } - auto v = reinterpret_cast(vector); - auto &validity = duckdb::FlatVector::Validity(*v); - validity.EnsureWritable(); -} - -void duckdb_vector_assign_string_element(duckdb_vector vector, idx_t index, const char *str) { - duckdb_vector_assign_string_element_len(vector, index, str, strlen(str)); -} - -void duckdb_vector_assign_string_element_len(duckdb_vector vector, idx_t index, const char *str, idx_t str_len) { - if (!vector) { - return; - } - auto v = reinterpret_cast(vector); - auto data = duckdb::FlatVector::GetData(*v); - data[index] = duckdb::StringVector::AddString(*v, str, str_len); -} - -duckdb_vector duckdb_list_vector_get_child(duckdb_vector vector) { - if (!vector) { - return nullptr; - } - auto v = reinterpret_cast(vector); - return reinterpret_cast(&duckdb::ListVector::GetEntry(*v)); -} - -idx_t duckdb_list_vector_get_size(duckdb_vector vector) { - if (!vector) { - return 0; - } - auto v = reinterpret_cast(vector); - return duckdb::ListVector::GetListSize(*v); -} - -duckdb_state duckdb_list_vector_set_size(duckdb_vector vector, idx_t size) { - if (!vector) { - return duckdb_state::DuckDBError; - } - auto v = reinterpret_cast(vector); - duckdb::ListVector::SetListSize(*v, size); - return duckdb_state::DuckDBSuccess; -} - -duckdb_state duckdb_list_vector_reserve(duckdb_vector vector, idx_t required_capacity) { - if (!vector) { - return duckdb_state::DuckDBError; - } - auto v = reinterpret_cast(vector); - 
duckdb::ListVector::Reserve(*v, required_capacity); - return duckdb_state::DuckDBSuccess; -} - -duckdb_vector duckdb_struct_vector_get_child(duckdb_vector vector, idx_t index) { - if (!vector) { - return nullptr; - } - auto v = reinterpret_cast(vector); - return reinterpret_cast(duckdb::StructVector::GetEntries(*v)[index].get()); -} - -bool duckdb_validity_row_is_valid(uint64_t *validity, idx_t row) { - if (!validity) { - return true; - } - idx_t entry_idx = row / 64; - idx_t idx_in_entry = row % 64; - return validity[entry_idx] & ((idx_t)1 << idx_in_entry); -} - -void duckdb_validity_set_row_validity(uint64_t *validity, idx_t row, bool valid) { - if (valid) { - duckdb_validity_set_row_valid(validity, row); - } else { - duckdb_validity_set_row_invalid(validity, row); - } -} - -void duckdb_validity_set_row_invalid(uint64_t *validity, idx_t row) { - if (!validity) { - return; - } - idx_t entry_idx = row / 64; - idx_t idx_in_entry = row % 64; - validity[entry_idx] &= ~((uint64_t)1 << idx_in_entry); -} - -void duckdb_validity_set_row_valid(uint64_t *validity, idx_t row) { - if (!validity) { - return; - } - idx_t entry_idx = row / 64; - idx_t idx_in_entry = row % 64; - validity[entry_idx] |= (uint64_t)1 << idx_in_entry; -} - - - - - -using duckdb::Date; -using duckdb::Time; -using duckdb::Timestamp; - -using duckdb::date_t; -using duckdb::dtime_t; -using duckdb::timestamp_t; - -duckdb_date_struct duckdb_from_date(duckdb_date date) { - int32_t year, month, day; - Date::Convert(date_t(date.days), year, month, day); - - duckdb_date_struct result; - result.year = year; - result.month = month; - result.day = day; - return result; -} - -duckdb_date duckdb_to_date(duckdb_date_struct date) { - duckdb_date result; - result.days = Date::FromDate(date.year, date.month, date.day).days; - return result; -} - -duckdb_time_struct duckdb_from_time(duckdb_time time) { - int32_t hour, minute, second, micros; - Time::Convert(dtime_t(time.micros), hour, minute, second, micros); - - 
duckdb_time_struct result; - result.hour = hour; - result.min = minute; - result.sec = second; - result.micros = micros; - return result; -} - -duckdb_time duckdb_to_time(duckdb_time_struct time) { - duckdb_time result; - result.micros = Time::FromTime(time.hour, time.min, time.sec, time.micros).micros; - return result; -} - -duckdb_timestamp_struct duckdb_from_timestamp(duckdb_timestamp ts) { - date_t date; - dtime_t time; - Timestamp::Convert(timestamp_t(ts.micros), date, time); - - duckdb_date ddate; - ddate.days = date.days; - - duckdb_time dtime; - dtime.micros = time.micros; - - duckdb_timestamp_struct result; - result.date = duckdb_from_date(ddate); - result.time = duckdb_from_time(dtime); - return result; -} - -duckdb_timestamp duckdb_to_timestamp(duckdb_timestamp_struct ts) { - date_t date = date_t(duckdb_to_date(ts.date).days); - dtime_t time = dtime_t(duckdb_to_time(ts.time).micros); - - duckdb_timestamp result; - result.micros = Timestamp::FromDatetime(date, time).value; - return result; -} - - -using duckdb::Connection; -using duckdb::DatabaseData; -using duckdb::DBConfig; -using duckdb::DuckDB; - -duckdb_state duckdb_open_ext(const char *path, duckdb_database *out, duckdb_config config, char **error) { - auto wrapper = new DatabaseData(); - try { - auto db_config = (DBConfig *)config; - wrapper->database = duckdb::make_uniq(path, db_config); - } catch (std::exception &ex) { - if (error) { - *error = strdup(ex.what()); - } - delete wrapper; - return DuckDBError; - } catch (...) 
{ // LCOV_EXCL_START - if (error) { - *error = strdup("Unknown error"); - } - delete wrapper; - return DuckDBError; - } // LCOV_EXCL_STOP - *out = (duckdb_database)wrapper; - return DuckDBSuccess; -} - -duckdb_state duckdb_open(const char *path, duckdb_database *out) { - return duckdb_open_ext(path, out, nullptr, nullptr); -} - -void duckdb_close(duckdb_database *database) { - if (database && *database) { - auto wrapper = reinterpret_cast(*database); - delete wrapper; - *database = nullptr; - } -} - -duckdb_state duckdb_connect(duckdb_database database, duckdb_connection *out) { - if (!database || !out) { - return DuckDBError; - } - auto wrapper = reinterpret_cast(database); - Connection *connection; - try { - connection = new Connection(*wrapper->database); - } catch (...) { // LCOV_EXCL_START - return DuckDBError; - } // LCOV_EXCL_STOP - *out = (duckdb_connection)connection; - return DuckDBSuccess; -} - -void duckdb_interrupt(duckdb_connection connection) { - if (!connection) { - return; - } - Connection *conn = reinterpret_cast(connection); - conn->Interrupt(); -} - -double duckdb_query_progress(duckdb_connection connection) { - if (!connection) { - return -1; - } - Connection *conn = reinterpret_cast(connection); - return conn->context->GetProgress(); -} - -void duckdb_disconnect(duckdb_connection *connection) { - if (connection && *connection) { - Connection *conn = reinterpret_cast(*connection); - delete conn; - *connection = nullptr; - } -} - -duckdb_state duckdb_query(duckdb_connection connection, const char *query, duckdb_result *out) { - Connection *conn = reinterpret_cast(connection); - auto result = conn->Query(query); - return duckdb_translate_result(std::move(result), out); -} - -const char *duckdb_library_version() { - return DuckDB::LibraryVersion(); -} - - -void duckdb_destroy_value(duckdb_value *value) { - if (value && *value) { - auto val = reinterpret_cast(*value); - delete val; - *value = nullptr; - } -} - -duckdb_value 
duckdb_create_varchar_length(const char *text, idx_t length) { - return reinterpret_cast(new duckdb::Value(std::string(text, length))); -} - -duckdb_value duckdb_create_varchar(const char *text) { - return duckdb_create_varchar_length(text, strlen(text)); -} - -duckdb_value duckdb_create_int64(int64_t input) { - auto val = duckdb::Value::BIGINT(input); - return reinterpret_cast(new duckdb::Value(val)); -} - -char *duckdb_get_varchar(duckdb_value value) { - auto val = reinterpret_cast(value); - auto str_val = val->DefaultCastAs(duckdb::LogicalType::VARCHAR); - auto &str = duckdb::StringValue::Get(str_val); - - auto result = reinterpret_cast(malloc(sizeof(char) * (str.size() + 1))); - memcpy(result, str.c_str(), str.size()); - result[str.size()] = '\0'; - return result; -} - -int64_t duckdb_get_int64(duckdb_value value) { - auto val = reinterpret_cast(value); - if (!val->DefaultTryCastAs(duckdb::LogicalType::BIGINT)) { - return 0; - } - return duckdb::BigIntValue::Get(*val); -} - - -namespace duckdb { - -LogicalTypeId ConvertCTypeToCPP(duckdb_type c_type) { - switch (c_type) { - case DUCKDB_TYPE_BOOLEAN: - return LogicalTypeId::BOOLEAN; - case DUCKDB_TYPE_TINYINT: - return LogicalTypeId::TINYINT; - case DUCKDB_TYPE_SMALLINT: - return LogicalTypeId::SMALLINT; - case DUCKDB_TYPE_INTEGER: - return LogicalTypeId::INTEGER; - case DUCKDB_TYPE_BIGINT: - return LogicalTypeId::BIGINT; - case DUCKDB_TYPE_UTINYINT: - return LogicalTypeId::UTINYINT; - case DUCKDB_TYPE_USMALLINT: - return LogicalTypeId::USMALLINT; - case DUCKDB_TYPE_UINTEGER: - return LogicalTypeId::UINTEGER; - case DUCKDB_TYPE_UBIGINT: - return LogicalTypeId::UBIGINT; - case DUCKDB_TYPE_HUGEINT: - return LogicalTypeId::HUGEINT; - case DUCKDB_TYPE_FLOAT: - return LogicalTypeId::FLOAT; - case DUCKDB_TYPE_DOUBLE: - return LogicalTypeId::DOUBLE; - case DUCKDB_TYPE_TIMESTAMP: - return LogicalTypeId::TIMESTAMP; - case DUCKDB_TYPE_DATE: - return LogicalTypeId::DATE; - case DUCKDB_TYPE_TIME: - return 
LogicalTypeId::TIME; - case DUCKDB_TYPE_VARCHAR: - return LogicalTypeId::VARCHAR; - case DUCKDB_TYPE_BLOB: - return LogicalTypeId::BLOB; - case DUCKDB_TYPE_INTERVAL: - return LogicalTypeId::INTERVAL; - case DUCKDB_TYPE_TIMESTAMP_S: - return LogicalTypeId::TIMESTAMP_SEC; - case DUCKDB_TYPE_TIMESTAMP_MS: - return LogicalTypeId::TIMESTAMP_MS; - case DUCKDB_TYPE_TIMESTAMP_NS: - return LogicalTypeId::TIMESTAMP_NS; - case DUCKDB_TYPE_UUID: - return LogicalTypeId::UUID; - default: // LCOV_EXCL_START - D_ASSERT(0); - return LogicalTypeId::INVALID; - } // LCOV_EXCL_STOP -} - -duckdb_type ConvertCPPTypeToC(const LogicalType &sql_type) { - switch (sql_type.id()) { - case LogicalTypeId::BOOLEAN: - return DUCKDB_TYPE_BOOLEAN; - case LogicalTypeId::TINYINT: - return DUCKDB_TYPE_TINYINT; - case LogicalTypeId::SMALLINT: - return DUCKDB_TYPE_SMALLINT; - case LogicalTypeId::INTEGER: - return DUCKDB_TYPE_INTEGER; - case LogicalTypeId::BIGINT: - return DUCKDB_TYPE_BIGINT; - case LogicalTypeId::UTINYINT: - return DUCKDB_TYPE_UTINYINT; - case LogicalTypeId::USMALLINT: - return DUCKDB_TYPE_USMALLINT; - case LogicalTypeId::UINTEGER: - return DUCKDB_TYPE_UINTEGER; - case LogicalTypeId::UBIGINT: - return DUCKDB_TYPE_UBIGINT; - case LogicalTypeId::HUGEINT: - return DUCKDB_TYPE_HUGEINT; - case LogicalTypeId::FLOAT: - return DUCKDB_TYPE_FLOAT; - case LogicalTypeId::DOUBLE: - return DUCKDB_TYPE_DOUBLE; - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return DUCKDB_TYPE_TIMESTAMP; - case LogicalTypeId::TIMESTAMP_SEC: - return DUCKDB_TYPE_TIMESTAMP_S; - case LogicalTypeId::TIMESTAMP_MS: - return DUCKDB_TYPE_TIMESTAMP_MS; - case LogicalTypeId::TIMESTAMP_NS: - return DUCKDB_TYPE_TIMESTAMP_NS; - case LogicalTypeId::DATE: - return DUCKDB_TYPE_DATE; - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return DUCKDB_TYPE_TIME; - case LogicalTypeId::VARCHAR: - return DUCKDB_TYPE_VARCHAR; - case LogicalTypeId::BLOB: - return DUCKDB_TYPE_BLOB; - case LogicalTypeId::BIT: - 
return DUCKDB_TYPE_BIT; - case LogicalTypeId::INTERVAL: - return DUCKDB_TYPE_INTERVAL; - case LogicalTypeId::DECIMAL: - return DUCKDB_TYPE_DECIMAL; - case LogicalTypeId::ENUM: - return DUCKDB_TYPE_ENUM; - case LogicalTypeId::LIST: - return DUCKDB_TYPE_LIST; - case LogicalTypeId::STRUCT: - return DUCKDB_TYPE_STRUCT; - case LogicalTypeId::MAP: - return DUCKDB_TYPE_MAP; - case LogicalTypeId::UNION: - return DUCKDB_TYPE_UNION; - case LogicalTypeId::UUID: - return DUCKDB_TYPE_UUID; - default: // LCOV_EXCL_START - D_ASSERT(0); - return DUCKDB_TYPE_INVALID; - } // LCOV_EXCL_STOP -} - -idx_t GetCTypeSize(duckdb_type type) { - switch (type) { - case DUCKDB_TYPE_BOOLEAN: - return sizeof(bool); - case DUCKDB_TYPE_TINYINT: - return sizeof(int8_t); - case DUCKDB_TYPE_SMALLINT: - return sizeof(int16_t); - case DUCKDB_TYPE_INTEGER: - return sizeof(int32_t); - case DUCKDB_TYPE_BIGINT: - return sizeof(int64_t); - case DUCKDB_TYPE_UTINYINT: - return sizeof(uint8_t); - case DUCKDB_TYPE_USMALLINT: - return sizeof(uint16_t); - case DUCKDB_TYPE_UINTEGER: - return sizeof(uint32_t); - case DUCKDB_TYPE_UBIGINT: - return sizeof(uint64_t); - case DUCKDB_TYPE_HUGEINT: - case DUCKDB_TYPE_UUID: - return sizeof(duckdb_hugeint); - case DUCKDB_TYPE_FLOAT: - return sizeof(float); - case DUCKDB_TYPE_DOUBLE: - return sizeof(double); - case DUCKDB_TYPE_DATE: - return sizeof(duckdb_date); - case DUCKDB_TYPE_TIME: - return sizeof(duckdb_time); - case DUCKDB_TYPE_TIMESTAMP: - case DUCKDB_TYPE_TIMESTAMP_S: - case DUCKDB_TYPE_TIMESTAMP_MS: - case DUCKDB_TYPE_TIMESTAMP_NS: - return sizeof(duckdb_timestamp); - case DUCKDB_TYPE_VARCHAR: - return sizeof(const char *); - case DUCKDB_TYPE_BLOB: - return sizeof(duckdb_blob); - case DUCKDB_TYPE_INTERVAL: - return sizeof(duckdb_interval); - case DUCKDB_TYPE_DECIMAL: - return sizeof(duckdb_hugeint); - default: // LCOV_EXCL_START - // unsupported type - D_ASSERT(0); - return sizeof(const char *); - } // LCOV_EXCL_STOP -} - -} // namespace duckdb - -void 
*duckdb_malloc(size_t size) { - return malloc(size); -} - -void duckdb_free(void *ptr) { - free(ptr); -} - -idx_t duckdb_vector_size() { - return STANDARD_VECTOR_SIZE; -} - -bool duckdb_string_is_inlined(duckdb_string_t string_p) { - static_assert(sizeof(duckdb_string_t) == sizeof(duckdb::string_t), - "duckdb_string_t should have the same memory layout as duckdb::string_t"); - auto &string = *(duckdb::string_t *)(&string_p); - return string.IsInlined(); -} - - - - - - - -using duckdb::Hugeint; -using duckdb::hugeint_t; -using duckdb::Value; - -double duckdb_hugeint_to_double(duckdb_hugeint val) { - hugeint_t internal; - internal.lower = val.lower; - internal.upper = val.upper; - return Hugeint::Cast(internal); -} - -static duckdb_decimal to_decimal_cast(double val, uint8_t width, uint8_t scale) { - if (width > duckdb::Decimal::MAX_WIDTH_INT64) { - return duckdb::TryCastToDecimalCInternal>(val, width, scale); - } - if (width > duckdb::Decimal::MAX_WIDTH_INT32) { - return duckdb::TryCastToDecimalCInternal>(val, width, scale); - } - if (width > duckdb::Decimal::MAX_WIDTH_INT16) { - return duckdb::TryCastToDecimalCInternal>(val, width, scale); - } - return duckdb::TryCastToDecimalCInternal>(val, width, scale); -} - -duckdb_decimal duckdb_double_to_decimal(double val, uint8_t width, uint8_t scale) { - if (scale > width || width > duckdb::Decimal::MAX_WIDTH_INT128) { - return duckdb::FetchDefaultValue::Operation(); - } - return to_decimal_cast(val, width, scale); -} - -duckdb_hugeint duckdb_double_to_hugeint(double val) { - hugeint_t internal_result; - if (!Value::DoubleIsFinite(val) || !Hugeint::TryConvert(val, internal_result)) { - internal_result.lower = 0; - internal_result.upper = 0; - } - - duckdb_hugeint result; - result.lower = internal_result.lower; - result.upper = internal_result.upper; - return result; -} - -double duckdb_decimal_to_double(duckdb_decimal val) { - double result; - hugeint_t value; - value.lower = val.value.lower; - value.upper = 
val.value.upper; - duckdb::TryCastFromDecimal::Operation(value, result, nullptr, val.width, val.scale); - return result; -} - - -static bool AssertLogicalTypeId(duckdb_logical_type type, duckdb::LogicalTypeId type_id) { - if (!type) { - return false; - } - auto <ype = *(reinterpret_cast(type)); - if (ltype.id() != type_id) { - return false; - } - return true; -} - -static bool AssertInternalType(duckdb_logical_type type, duckdb::PhysicalType physical_type) { - if (!type) { - return false; - } - auto <ype = *(reinterpret_cast(type)); - if (ltype.InternalType() != physical_type) { - return false; - } - return true; -} - -duckdb_logical_type duckdb_create_logical_type(duckdb_type type) { - return reinterpret_cast(new duckdb::LogicalType(duckdb::ConvertCTypeToCPP(type))); -} - -duckdb_logical_type duckdb_create_list_type(duckdb_logical_type type) { - if (!type) { - return nullptr; - } - duckdb::LogicalType *ltype = new duckdb::LogicalType; - *ltype = duckdb::LogicalType::LIST(*reinterpret_cast(type)); - return reinterpret_cast(ltype); -} - -duckdb_logical_type duckdb_create_union_type(duckdb_logical_type member_types_p, const char **member_names, - idx_t member_count) { - if (!member_types_p || !member_names) { - return nullptr; - } - duckdb::LogicalType *member_types = reinterpret_cast(member_types_p); - duckdb::LogicalType *mtype = new duckdb::LogicalType; - duckdb::child_list_t members; - - for (idx_t i = 0; i < member_count; i++) { - members.push_back(make_pair(member_names[i], member_types[i])); - } - *mtype = duckdb::LogicalType::UNION(members); - return reinterpret_cast(mtype); -} - -duckdb_logical_type duckdb_create_struct_type(duckdb_logical_type *member_types_p, const char **member_names, - idx_t member_count) { - if (!member_types_p || !member_names) { - return nullptr; - } - duckdb::LogicalType **member_types = (duckdb::LogicalType **)member_types_p; - for (idx_t i = 0; i < member_count; i++) { - if (!member_names[i] || !member_types[i]) { - return nullptr; 
- } - } - - duckdb::LogicalType *mtype = new duckdb::LogicalType; - duckdb::child_list_t members; - - for (idx_t i = 0; i < member_count; i++) { - members.push_back(make_pair(member_names[i], *member_types[i])); - } - *mtype = duckdb::LogicalType::STRUCT(members); - return reinterpret_cast(mtype); -} - -duckdb_logical_type duckdb_create_map_type(duckdb_logical_type key_type, duckdb_logical_type value_type) { - if (!key_type || !value_type) { - return nullptr; - } - duckdb::LogicalType *mtype = new duckdb::LogicalType; - *mtype = duckdb::LogicalType::MAP(*reinterpret_cast(key_type), - *reinterpret_cast(value_type)); - return reinterpret_cast(mtype); -} - -duckdb_logical_type duckdb_create_decimal_type(uint8_t width, uint8_t scale) { - return reinterpret_cast(new duckdb::LogicalType(duckdb::LogicalType::DECIMAL(width, scale))); -} - -duckdb_type duckdb_get_type_id(duckdb_logical_type type) { - if (!type) { - return DUCKDB_TYPE_INVALID; - } - auto ltype = reinterpret_cast(type); - return duckdb::ConvertCPPTypeToC(*ltype); -} - -void duckdb_destroy_logical_type(duckdb_logical_type *type) { - if (type && *type) { - auto ltype = reinterpret_cast(*type); - delete ltype; - *type = nullptr; - } -} - -uint8_t duckdb_decimal_width(duckdb_logical_type type) { - if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::DECIMAL)) { - return 0; - } - auto <ype = *(reinterpret_cast(type)); - return duckdb::DecimalType::GetWidth(ltype); -} - -uint8_t duckdb_decimal_scale(duckdb_logical_type type) { - if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::DECIMAL)) { - return 0; - } - auto <ype = *(reinterpret_cast(type)); - return duckdb::DecimalType::GetScale(ltype); -} - -duckdb_type duckdb_decimal_internal_type(duckdb_logical_type type) { - if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::DECIMAL)) { - return DUCKDB_TYPE_INVALID; - } - auto <ype = *(reinterpret_cast(type)); - switch (ltype.InternalType()) { - case duckdb::PhysicalType::INT16: - return DUCKDB_TYPE_SMALLINT; - 
case duckdb::PhysicalType::INT32: - return DUCKDB_TYPE_INTEGER; - case duckdb::PhysicalType::INT64: - return DUCKDB_TYPE_BIGINT; - case duckdb::PhysicalType::INT128: - return DUCKDB_TYPE_HUGEINT; - default: - return DUCKDB_TYPE_INVALID; - } -} - -duckdb_type duckdb_enum_internal_type(duckdb_logical_type type) { - if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::ENUM)) { - return DUCKDB_TYPE_INVALID; - } - auto <ype = *(reinterpret_cast(type)); - switch (ltype.InternalType()) { - case duckdb::PhysicalType::UINT8: - return DUCKDB_TYPE_UTINYINT; - case duckdb::PhysicalType::UINT16: - return DUCKDB_TYPE_USMALLINT; - case duckdb::PhysicalType::UINT32: - return DUCKDB_TYPE_UINTEGER; - default: - return DUCKDB_TYPE_INVALID; - } -} - -uint32_t duckdb_enum_dictionary_size(duckdb_logical_type type) { - if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::ENUM)) { - return 0; - } - auto <ype = *(reinterpret_cast(type)); - return duckdb::EnumType::GetSize(ltype); -} - -char *duckdb_enum_dictionary_value(duckdb_logical_type type, idx_t index) { - if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::ENUM)) { - return nullptr; - } - auto <ype = *(reinterpret_cast(type)); - auto &vector = duckdb::EnumType::GetValuesInsertOrder(ltype); - auto value = vector.GetValue(index); - return strdup(duckdb::StringValue::Get(value).c_str()); -} - -duckdb_logical_type duckdb_list_type_child_type(duckdb_logical_type type) { - if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::LIST) && - !AssertLogicalTypeId(type, duckdb::LogicalTypeId::MAP)) { - return nullptr; - } - auto <ype = *(reinterpret_cast(type)); - if (ltype.id() != duckdb::LogicalTypeId::LIST && ltype.id() != duckdb::LogicalTypeId::MAP) { - return nullptr; - } - return reinterpret_cast(new duckdb::LogicalType(duckdb::ListType::GetChildType(ltype))); -} - -duckdb_logical_type duckdb_map_type_key_type(duckdb_logical_type type) { - if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::MAP)) { - return nullptr; - } - auto &mtype 
= *(reinterpret_cast(type)); - if (mtype.id() != duckdb::LogicalTypeId::MAP) { - return nullptr; - } - return reinterpret_cast(new duckdb::LogicalType(duckdb::MapType::KeyType(mtype))); -} - -duckdb_logical_type duckdb_map_type_value_type(duckdb_logical_type type) { - if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::MAP)) { - return nullptr; - } - auto &mtype = *(reinterpret_cast(type)); - if (mtype.id() != duckdb::LogicalTypeId::MAP) { - return nullptr; - } - return reinterpret_cast(new duckdb::LogicalType(duckdb::MapType::ValueType(mtype))); -} - -idx_t duckdb_struct_type_child_count(duckdb_logical_type type) { - if (!AssertInternalType(type, duckdb::PhysicalType::STRUCT)) { - return 0; - } - auto <ype = *(reinterpret_cast(type)); - return duckdb::StructType::GetChildCount(ltype); -} - -idx_t duckdb_union_type_member_count(duckdb_logical_type type) { - if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::UNION)) { - return 0; - } - idx_t member_count = duckdb_struct_type_child_count(type); - if (member_count != 0) { - member_count--; - } - return member_count; -} - -char *duckdb_union_type_member_name(duckdb_logical_type type, idx_t index) { - if (!AssertInternalType(type, duckdb::PhysicalType::STRUCT)) { - return nullptr; - } - if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::UNION)) { - return nullptr; - } - auto <ype = *(reinterpret_cast(type)); - return strdup(duckdb::UnionType::GetMemberName(ltype, index).c_str()); -} - -duckdb_logical_type duckdb_union_type_member_type(duckdb_logical_type type, idx_t index) { - if (!AssertInternalType(type, duckdb::PhysicalType::STRUCT)) { - return nullptr; - } - if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::UNION)) { - return nullptr; - } - auto <ype = *(reinterpret_cast(type)); - return reinterpret_cast( - new duckdb::LogicalType(duckdb::UnionType::GetMemberType(ltype, index))); -} - -char *duckdb_struct_type_child_name(duckdb_logical_type type, idx_t index) { - if (!AssertInternalType(type, 
duckdb::PhysicalType::STRUCT)) { - return nullptr; - } - auto <ype = *(reinterpret_cast(type)); - return strdup(duckdb::StructType::GetChildName(ltype, index).c_str()); -} - -duckdb_logical_type duckdb_struct_type_child_type(duckdb_logical_type type, idx_t index) { - if (!AssertInternalType(type, duckdb::PhysicalType::STRUCT)) { - return nullptr; - } - auto <ype = *(reinterpret_cast(type)); - if (ltype.InternalType() != duckdb::PhysicalType::STRUCT) { - return nullptr; - } - return reinterpret_cast( - new duckdb::LogicalType(duckdb::StructType::GetChildType(ltype, index))); -} - - - - - - - -using duckdb::case_insensitive_map_t; -using duckdb::make_uniq; -using duckdb::optional_ptr; -using duckdb::PendingExecutionResult; -using duckdb::PendingQueryResult; -using duckdb::PendingStatementWrapper; -using duckdb::PreparedStatementWrapper; -using duckdb::Value; - -duckdb_state duckdb_pending_prepared_internal(duckdb_prepared_statement prepared_statement, - duckdb_pending_result *out_result, bool allow_streaming) { - if (!prepared_statement || !out_result) { - return DuckDBError; - } - auto wrapper = reinterpret_cast(prepared_statement); - auto result = new PendingStatementWrapper(); - result->allow_streaming = allow_streaming; - - try { - result->statement = wrapper->statement->PendingQuery(wrapper->values, allow_streaming); - } catch (const duckdb::Exception &ex) { - result->statement = make_uniq(duckdb::PreservedError(ex)); - } catch (std::exception &ex) { - result->statement = make_uniq(duckdb::PreservedError(ex)); - } - duckdb_state return_value = !result->statement->HasError() ? 
DuckDBSuccess : DuckDBError; - *out_result = reinterpret_cast(result); - - return return_value; -} - -duckdb_state duckdb_pending_prepared(duckdb_prepared_statement prepared_statement, duckdb_pending_result *out_result) { - return duckdb_pending_prepared_internal(prepared_statement, out_result, false); -} - -duckdb_state duckdb_pending_prepared_streaming(duckdb_prepared_statement prepared_statement, - duckdb_pending_result *out_result) { - return duckdb_pending_prepared_internal(prepared_statement, out_result, true); -} - -void duckdb_destroy_pending(duckdb_pending_result *pending_result) { - if (!pending_result || !*pending_result) { - return; - } - auto wrapper = reinterpret_cast(*pending_result); - if (wrapper->statement) { - wrapper->statement->Close(); - } - delete wrapper; - *pending_result = nullptr; -} - -const char *duckdb_pending_error(duckdb_pending_result pending_result) { - if (!pending_result) { - return nullptr; - } - auto wrapper = reinterpret_cast(pending_result); - if (!wrapper->statement) { - return nullptr; - } - return wrapper->statement->GetError().c_str(); -} - -duckdb_pending_state duckdb_pending_execute_task(duckdb_pending_result pending_result) { - if (!pending_result) { - return DUCKDB_PENDING_ERROR; - } - auto wrapper = reinterpret_cast(pending_result); - if (!wrapper->statement) { - return DUCKDB_PENDING_ERROR; - } - if (wrapper->statement->HasError()) { - return DUCKDB_PENDING_ERROR; - } - PendingExecutionResult return_value; - try { - return_value = wrapper->statement->ExecuteTask(); - } catch (const duckdb::Exception &ex) { - wrapper->statement->SetError(duckdb::PreservedError(ex)); - return DUCKDB_PENDING_ERROR; - } catch (std::exception &ex) { - wrapper->statement->SetError(duckdb::PreservedError(ex)); - return DUCKDB_PENDING_ERROR; - } - switch (return_value) { - case PendingExecutionResult::RESULT_READY: - return DUCKDB_PENDING_RESULT_READY; - case PendingExecutionResult::NO_TASKS_AVAILABLE: - return 
DUCKDB_PENDING_NO_TASKS_AVAILABLE; - case PendingExecutionResult::RESULT_NOT_READY: - return DUCKDB_PENDING_RESULT_NOT_READY; - default: - return DUCKDB_PENDING_ERROR; - } -} - -bool duckdb_pending_execution_is_finished(duckdb_pending_state pending_state) { - switch (pending_state) { - case DUCKDB_PENDING_RESULT_READY: - return PendingQueryResult::IsFinished(PendingExecutionResult::RESULT_READY); - case DUCKDB_PENDING_NO_TASKS_AVAILABLE: - return PendingQueryResult::IsFinished(PendingExecutionResult::NO_TASKS_AVAILABLE); - case DUCKDB_PENDING_RESULT_NOT_READY: - return PendingQueryResult::IsFinished(PendingExecutionResult::RESULT_NOT_READY); - case DUCKDB_PENDING_ERROR: - return PendingQueryResult::IsFinished(PendingExecutionResult::EXECUTION_ERROR); - default: - return PendingQueryResult::IsFinished(PendingExecutionResult::EXECUTION_ERROR); - } -} - -duckdb_state duckdb_execute_pending(duckdb_pending_result pending_result, duckdb_result *out_result) { - if (!pending_result || !out_result) { - return DuckDBError; - } - auto wrapper = reinterpret_cast(pending_result); - if (!wrapper->statement) { - return DuckDBError; - } - - duckdb::unique_ptr result; - result = wrapper->statement->Execute(); - wrapper->statement.reset(); - return duckdb_translate_result(std::move(result), out_result); -} - - - - - - - -using duckdb::case_insensitive_map_t; -using duckdb::Connection; -using duckdb::date_t; -using duckdb::dtime_t; -using duckdb::ExtractStatementsWrapper; -using duckdb::hugeint_t; -using duckdb::LogicalType; -using duckdb::MaterializedQueryResult; -using duckdb::optional_ptr; -using duckdb::PreparedStatementWrapper; -using duckdb::QueryResultType; -using duckdb::StringUtil; -using duckdb::timestamp_t; -using duckdb::Value; - -idx_t duckdb_extract_statements(duckdb_connection connection, const char *query, - duckdb_extracted_statements *out_extracted_statements) { - if (!connection || !query || !out_extracted_statements) { - return 0; - } - auto wrapper = new 
ExtractStatementsWrapper(); - Connection *conn = reinterpret_cast(connection); - try { - wrapper->statements = conn->ExtractStatements(query); - } catch (const duckdb::ParserException &e) { - wrapper->error = e.what(); - } - - *out_extracted_statements = (duckdb_extracted_statements)wrapper; - return wrapper->statements.size(); -} - -duckdb_state duckdb_prepare_extracted_statement(duckdb_connection connection, - duckdb_extracted_statements extracted_statements, idx_t index, - duckdb_prepared_statement *out_prepared_statement) { - Connection *conn = reinterpret_cast(connection); - auto source_wrapper = (ExtractStatementsWrapper *)extracted_statements; - - if (!connection || !out_prepared_statement || index >= source_wrapper->statements.size()) { - return DuckDBError; - } - auto wrapper = new PreparedStatementWrapper(); - wrapper->statement = conn->Prepare(std::move(source_wrapper->statements[index])); - - *out_prepared_statement = (duckdb_prepared_statement)wrapper; - return wrapper->statement->HasError() ? DuckDBError : DuckDBSuccess; -} - -const char *duckdb_extract_statements_error(duckdb_extracted_statements extracted_statements) { - auto wrapper = (ExtractStatementsWrapper *)extracted_statements; - if (!wrapper || wrapper->error.empty()) { - return nullptr; - } - return wrapper->error.c_str(); -} - -duckdb_state duckdb_prepare(duckdb_connection connection, const char *query, - duckdb_prepared_statement *out_prepared_statement) { - if (!connection || !query || !out_prepared_statement) { - return DuckDBError; - } - auto wrapper = new PreparedStatementWrapper(); - Connection *conn = reinterpret_cast(connection); - wrapper->statement = conn->Prepare(query); - *out_prepared_statement = (duckdb_prepared_statement)wrapper; - return !wrapper->statement->HasError() ? 
DuckDBSuccess : DuckDBError; -} - -const char *duckdb_prepare_error(duckdb_prepared_statement prepared_statement) { - auto wrapper = reinterpret_cast(prepared_statement); - if (!wrapper || !wrapper->statement || !wrapper->statement->HasError()) { - return nullptr; - } - return wrapper->statement->error.Message().c_str(); -} - -idx_t duckdb_nparams(duckdb_prepared_statement prepared_statement) { - auto wrapper = reinterpret_cast(prepared_statement); - if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { - return 0; - } - return wrapper->statement->n_param; -} - -static duckdb::string duckdb_parameter_name_internal(duckdb_prepared_statement prepared_statement, idx_t index) { - auto wrapper = (PreparedStatementWrapper *)prepared_statement; - if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { - return duckdb::string(); - } - if (index > wrapper->statement->n_param) { - return duckdb::string(); - } - for (auto &item : wrapper->statement->named_param_map) { - auto &identifier = item.first; - auto ¶m_idx = item.second; - if (param_idx == index) { - // Found the matching parameter - return identifier; - } - } - // No parameter was found with this index - return duckdb::string(); -} - -const char *duckdb_parameter_name(duckdb_prepared_statement prepared_statement, idx_t index) { - auto identifier = duckdb_parameter_name_internal(prepared_statement, index); - if (identifier == duckdb::string()) { - return NULL; - } - return strdup(identifier.c_str()); -} - -duckdb_type duckdb_param_type(duckdb_prepared_statement prepared_statement, idx_t param_idx) { - auto wrapper = reinterpret_cast(prepared_statement); - if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { - return DUCKDB_TYPE_INVALID; - } - LogicalType param_type; - auto identifier = std::to_string(param_idx); - if (wrapper->statement->data->TryGetType(identifier, param_type)) { - return ConvertCPPTypeToC(param_type); - } - // The value_map is gone after 
executing the prepared statement - // See if this is the case and we still have a value registered for it - auto it = wrapper->values.find(identifier); - if (it != wrapper->values.end()) { - return ConvertCPPTypeToC(it->second.type()); - } - return DUCKDB_TYPE_INVALID; -} - -duckdb_state duckdb_clear_bindings(duckdb_prepared_statement prepared_statement) { - auto wrapper = reinterpret_cast(prepared_statement); - if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { - return DuckDBError; - } - wrapper->values.clear(); - return DuckDBSuccess; -} - -duckdb_state duckdb_bind_value(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_value val) { - auto value = reinterpret_cast(val); - auto wrapper = reinterpret_cast(prepared_statement); - if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { - return DuckDBError; - } - if (param_idx <= 0 || param_idx > wrapper->statement->n_param) { - wrapper->statement->error = - duckdb::InvalidInputException("Can not bind to parameter number %d, statement only has %d parameter(s)", - param_idx, wrapper->statement->n_param); - return DuckDBError; - } - auto identifier = duckdb_parameter_name_internal(prepared_statement, param_idx); - wrapper->values[identifier] = *value; - return DuckDBSuccess; -} - -duckdb_state duckdb_bind_parameter_index(duckdb_prepared_statement prepared_statement, idx_t *param_idx_out, - const char *name_p) { - auto wrapper = (PreparedStatementWrapper *)prepared_statement; - if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { - return DuckDBError; - } - if (!name_p || !param_idx_out) { - return DuckDBError; - } - auto name = std::string(name_p); - for (auto &pair : wrapper->statement->named_param_map) { - if (duckdb::StringUtil::CIEquals(pair.first, name)) { - *param_idx_out = pair.second; - return DuckDBSuccess; - } - } - return DuckDBError; -} - -duckdb_state duckdb_bind_boolean(duckdb_prepared_statement prepared_statement, idx_t 
param_idx, bool val) { - auto value = Value::BOOLEAN(val); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_int8(duckdb_prepared_statement prepared_statement, idx_t param_idx, int8_t val) { - auto value = Value::TINYINT(val); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_int16(duckdb_prepared_statement prepared_statement, idx_t param_idx, int16_t val) { - auto value = Value::SMALLINT(val); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_int32(duckdb_prepared_statement prepared_statement, idx_t param_idx, int32_t val) { - auto value = Value::INTEGER(val); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_int64(duckdb_prepared_statement prepared_statement, idx_t param_idx, int64_t val) { - auto value = Value::BIGINT(val); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -static hugeint_t duckdb_internal_hugeint(duckdb_hugeint val) { - hugeint_t internal; - internal.lower = val.lower; - internal.upper = val.upper; - return internal; -} - -duckdb_state duckdb_bind_hugeint(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_hugeint val) { - auto value = Value::HUGEINT(duckdb_internal_hugeint(val)); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_uint8(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint8_t val) { - auto value = Value::UTINYINT(val); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_uint16(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint16_t val) { - auto value = Value::USMALLINT(val); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state 
duckdb_bind_uint32(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint32_t val) { - auto value = Value::UINTEGER(val); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_uint64(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint64_t val) { - auto value = Value::UBIGINT(val); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_float(duckdb_prepared_statement prepared_statement, idx_t param_idx, float val) { - auto value = Value::FLOAT(val); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_double(duckdb_prepared_statement prepared_statement, idx_t param_idx, double val) { - auto value = Value::DOUBLE(val); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_date(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_date val) { - auto value = Value::DATE(date_t(val.days)); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_time(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_time val) { - auto value = Value::TIME(dtime_t(val.micros)); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_timestamp(duckdb_prepared_statement prepared_statement, idx_t param_idx, - duckdb_timestamp val) { - auto value = Value::TIMESTAMP(timestamp_t(val.micros)); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_interval(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_interval val) { - auto value = Value::INTERVAL(val.months, val.days, val.micros); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state 
duckdb_bind_varchar(duckdb_prepared_statement prepared_statement, idx_t param_idx, const char *val) { - try { - auto value = Value(val); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); - } catch (...) { - return DuckDBError; - } -} - -duckdb_state duckdb_bind_varchar_length(duckdb_prepared_statement prepared_statement, idx_t param_idx, const char *val, - idx_t length) { - try { - auto value = Value(std::string(val, length)); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); - } catch (...) { - return DuckDBError; - } -} - -duckdb_state duckdb_bind_decimal(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_decimal val) { - auto hugeint_val = duckdb_internal_hugeint(val.value); - if (val.width > duckdb::Decimal::MAX_WIDTH_INT64) { - auto value = Value::DECIMAL(hugeint_val, val.width, val.scale); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); - } - auto value = hugeint_val.lower; - auto duck_val = Value::DECIMAL((int64_t)value, val.width, val.scale); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&duck_val); -} - -duckdb_state duckdb_bind_blob(duckdb_prepared_statement prepared_statement, idx_t param_idx, const void *data, - idx_t length) { - auto value = Value::BLOB(duckdb::const_data_ptr_cast(data), length); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_bind_null(duckdb_prepared_statement prepared_statement, idx_t param_idx) { - auto value = Value(); - return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); -} - -duckdb_state duckdb_execute_prepared(duckdb_prepared_statement prepared_statement, duckdb_result *out_result) { - auto wrapper = reinterpret_cast(prepared_statement); - if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { - return DuckDBError; - } - - auto result = wrapper->statement->Execute(wrapper->values, 
false); - return duckdb_translate_result(std::move(result), out_result); -} - -template -void duckdb_destroy(void **wrapper) { - if (!wrapper) { - return; - } - - auto casted = (T *)*wrapper; - if (casted) { - delete casted; - } - *wrapper = nullptr; -} - -void duckdb_destroy_extracted(duckdb_extracted_statements *extracted_statements) { - duckdb_destroy(reinterpret_cast(extracted_statements)); -} - -void duckdb_destroy_prepare(duckdb_prepared_statement *prepared_statement) { - duckdb_destroy(reinterpret_cast(prepared_statement)); -} - - - - - - -namespace duckdb { - -struct CAPIReplacementScanData : public ReplacementScanData { - ~CAPIReplacementScanData() { - if (delete_callback) { - delete_callback(extra_data); - } - } - - duckdb_replacement_callback_t callback; - void *extra_data; - duckdb_delete_callback_t delete_callback; -}; - -struct CAPIReplacementScanInfo { - CAPIReplacementScanInfo(CAPIReplacementScanData *data) : data(data) { - } - - CAPIReplacementScanData *data; - string function_name; - vector parameters; - string error; -}; - -unique_ptr duckdb_capi_replacement_callback(ClientContext &context, const string &table_name, - ReplacementScanData *data) { - auto &scan_data = reinterpret_cast(*data); - - CAPIReplacementScanInfo info(&scan_data); - scan_data.callback((duckdb_replacement_scan_info)&info, table_name.c_str(), scan_data.extra_data); - if (!info.error.empty()) { - throw BinderException("Error in replacement scan: %s\n", info.error); - } - if (info.function_name.empty()) { - // no function provided: bail-out - return nullptr; - } - auto table_function = make_uniq(); - vector> children; - for (auto ¶m : info.parameters) { - children.push_back(make_uniq(std::move(param))); - } - table_function->function = make_uniq(info.function_name, std::move(children)); - return std::move(table_function); -} - -} // namespace duckdb - -void duckdb_add_replacement_scan(duckdb_database db, duckdb_replacement_callback_t replacement, void *extra_data, - 
duckdb_delete_callback_t delete_callback) { - if (!db || !replacement) { - return; - } - auto wrapper = reinterpret_cast(db); - auto scan_info = duckdb::make_uniq(); - scan_info->callback = replacement; - scan_info->extra_data = extra_data; - scan_info->delete_callback = delete_callback; - - auto &config = duckdb::DBConfig::GetConfig(*wrapper->database->instance); - config.replacement_scans.push_back( - duckdb::ReplacementScan(duckdb::duckdb_capi_replacement_callback, std::move(scan_info))); -} - -void duckdb_replacement_scan_set_function_name(duckdb_replacement_scan_info info_p, const char *function_name) { - if (!info_p || !function_name) { - return; - } - auto info = reinterpret_cast(info_p); - info->function_name = function_name; -} - -void duckdb_replacement_scan_add_parameter(duckdb_replacement_scan_info info_p, duckdb_value parameter) { - if (!info_p || !parameter) { - return; - } - auto info = reinterpret_cast(info_p); - auto val = reinterpret_cast(parameter); - info->parameters.push_back(*val); -} - -void duckdb_replacement_scan_set_error(duckdb_replacement_scan_info info_p, const char *error) { - if (!info_p || !error) { - return; - } - auto info = reinterpret_cast(info_p); - info->error = error; -} - - - - -namespace duckdb { - -struct CBaseConverter { - template - static void NullConvert(DST &target) { - } -}; -struct CStandardConverter : public CBaseConverter { - template - static DST Convert(SRC input) { - return input; - } -}; - -struct CStringConverter { - template - static DST Convert(SRC input) { - auto result = char_ptr_cast(duckdb_malloc(input.GetSize() + 1)); - assert(result); - memcpy((void *)result, input.GetData(), input.GetSize()); - auto write_arr = char_ptr_cast(result); - write_arr[input.GetSize()] = '\0'; - return result; - } - - template - static void NullConvert(DST &target) { - target = nullptr; - } -}; - -struct CBlobConverter { - template - static DST Convert(SRC input) { - duckdb_blob result; - result.data = 
char_ptr_cast(duckdb_malloc(input.GetSize())); - result.size = input.GetSize(); - assert(result.data); - memcpy(result.data, input.GetData(), input.GetSize()); - return result; - } - - template - static void NullConvert(DST &target) { - target.data = nullptr; - target.size = 0; - } -}; - -struct CTimestampMsConverter : public CBaseConverter { - template - static DST Convert(SRC input) { - return Timestamp::FromEpochMs(input.value); - } -}; - -struct CTimestampNsConverter : public CBaseConverter { - template - static DST Convert(SRC input) { - return Timestamp::FromEpochNanoSeconds(input.value); - } -}; - -struct CTimestampSecConverter : public CBaseConverter { - template - static DST Convert(SRC input) { - return Timestamp::FromEpochSeconds(input.value); - } -}; - -struct CHugeintConverter : public CBaseConverter { - template - static DST Convert(SRC input) { - duckdb_hugeint result; - result.lower = input.lower; - result.upper = input.upper; - return result; - } -}; - -struct CIntervalConverter : public CBaseConverter { - template - static DST Convert(SRC input) { - duckdb_interval result; - result.days = input.days; - result.months = input.months; - result.micros = input.micros; - return result; - } -}; - -template -struct CDecimalConverter : public CBaseConverter { - template - static DST Convert(SRC input) { - duckdb_hugeint result; - result.lower = input; - result.upper = 0; - return result; - } -}; - -template -void WriteData(duckdb_column *column, ColumnDataCollection &source, const vector &column_ids) { - idx_t row = 0; - auto target = (DST *)column->__deprecated_data; - for (auto &input : source.Chunks(column_ids)) { - auto source = FlatVector::GetData(input.data[0]); - auto &mask = FlatVector::Validity(input.data[0]); - - for (idx_t k = 0; k < input.size(); k++, row++) { - if (!mask.RowIsValid(k)) { - OP::template NullConvert(target[row]); - } else { - target[row] = OP::template Convert(source[k]); - } - } - } -} - -duckdb_state 
deprecated_duckdb_translate_column(MaterializedQueryResult &result, duckdb_column *column, idx_t col) { - D_ASSERT(!result.HasError()); - auto &collection = result.Collection(); - idx_t row_count = collection.Count(); - column->__deprecated_nullmask = (bool *)duckdb_malloc(sizeof(bool) * collection.Count()); - column->__deprecated_data = duckdb_malloc(GetCTypeSize(column->__deprecated_type) * row_count); - if (!column->__deprecated_nullmask || !column->__deprecated_data) { // LCOV_EXCL_START - // malloc failure - return DuckDBError; - } // LCOV_EXCL_STOP - - vector column_ids {col}; - // first convert the nullmask - { - idx_t row = 0; - for (auto &input : collection.Chunks(column_ids)) { - for (idx_t k = 0; k < input.size(); k++) { - column->__deprecated_nullmask[row++] = FlatVector::IsNull(input.data[0], k); - } - } - } - // then write the data - switch (result.types[col].id()) { - case LogicalTypeId::BOOLEAN: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::TINYINT: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::SMALLINT: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::INTEGER: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::BIGINT: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::UTINYINT: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::USMALLINT: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::UINTEGER: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::UBIGINT: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::FLOAT: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::DOUBLE: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::DATE: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::TIME: - WriteData(column, collection, column_ids); - break; - 
case LogicalTypeId::TIME_TZ: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - WriteData(column, collection, column_ids); - break; - case LogicalTypeId::VARCHAR: { - WriteData(column, collection, column_ids); - break; - } - case LogicalTypeId::BLOB: { - WriteData(column, collection, column_ids); - break; - } - case LogicalTypeId::TIMESTAMP_NS: { - WriteData(column, collection, column_ids); - break; - } - case LogicalTypeId::TIMESTAMP_MS: { - WriteData(column, collection, column_ids); - break; - } - case LogicalTypeId::TIMESTAMP_SEC: { - WriteData(column, collection, column_ids); - break; - } - case LogicalTypeId::HUGEINT: { - WriteData(column, collection, column_ids); - break; - } - case LogicalTypeId::INTERVAL: { - WriteData(column, collection, column_ids); - break; - } - case LogicalTypeId::DECIMAL: { - // get data - switch (result.types[col].InternalType()) { - case PhysicalType::INT16: { - WriteData>(column, collection, column_ids); - break; - } - case PhysicalType::INT32: { - WriteData>(column, collection, column_ids); - break; - } - case PhysicalType::INT64: { - WriteData>(column, collection, column_ids); - break; - } - case PhysicalType::INT128: { - WriteData(column, collection, column_ids); - break; - } - default: - throw std::runtime_error("Unsupported physical type for Decimal" + - TypeIdToString(result.types[col].InternalType())); - } - break; - } - default: // LCOV_EXCL_START - return DuckDBError; - } // LCOV_EXCL_STOP - return DuckDBSuccess; -} - -duckdb_state duckdb_translate_result(unique_ptr result_p, duckdb_result *out) { - auto &result = *result_p; - D_ASSERT(result_p); - if (!out) { - // no result to write to, only return the status - return !result.HasError() ? 
DuckDBSuccess : DuckDBError; - } - - memset(out, 0, sizeof(duckdb_result)); - - // initialize the result_data object - auto result_data = new DuckDBResultData(); - result_data->result = std::move(result_p); - result_data->result_set_type = CAPIResultSetType::CAPI_RESULT_TYPE_NONE; - out->internal_data = result_data; - - if (result.HasError()) { - // write the error message - out->__deprecated_error_message = (char *)result.GetError().c_str(); // NOLINT - return DuckDBError; - } - // copy the data - // first write the meta data - out->__deprecated_column_count = result.ColumnCount(); - out->__deprecated_rows_changed = 0; - return DuckDBSuccess; -} - -bool deprecated_materialize_result(duckdb_result *result) { - if (!result) { - return false; - } - auto result_data = reinterpret_cast(result->internal_data); - if (result_data->result->HasError()) { - return false; - } - if (result_data->result_set_type == CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED) { - // already materialized into deprecated result format - return true; - } - if (result_data->result_set_type == CAPIResultSetType::CAPI_RESULT_TYPE_MATERIALIZED) { - // already used as a new result set - return false; - } - if (result_data->result_set_type == CAPIResultSetType::CAPI_RESULT_TYPE_STREAMING) { - // already used as a streaming result - return false; - } - // materialize as deprecated result set - result_data->result_set_type = CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED; - auto column_count = result_data->result->ColumnCount(); - result->__deprecated_columns = (duckdb_column *)duckdb_malloc(sizeof(duckdb_column) * column_count); - if (!result->__deprecated_columns) { // LCOV_EXCL_START - // malloc failure - return DuckDBError; - } // LCOV_EXCL_STOP - if (result_data->result->type == QueryResultType::STREAM_RESULT) { - // if we are dealing with a stream result, convert it to a materialized result first - auto &stream_result = (StreamQueryResult &)*result_data->result; - result_data->result = 
stream_result.Materialize(); - } - D_ASSERT(result_data->result->type == QueryResultType::MATERIALIZED_RESULT); - auto &materialized = reinterpret_cast(*result_data->result); - - // convert the result to a materialized result - // zero initialize the columns (so we can cleanly delete it in case a malloc fails) - memset(result->__deprecated_columns, 0, sizeof(duckdb_column) * column_count); - for (idx_t i = 0; i < column_count; i++) { - result->__deprecated_columns[i].__deprecated_type = ConvertCPPTypeToC(result_data->result->types[i]); - result->__deprecated_columns[i].__deprecated_name = (char *)result_data->result->names[i].c_str(); // NOLINT - } - result->__deprecated_row_count = materialized.RowCount(); - if (result->__deprecated_row_count > 0 && - materialized.properties.return_type == StatementReturnType::CHANGED_ROWS) { - // update total changes - auto row_changes = materialized.GetValue(0, 0); - if (!row_changes.IsNull() && row_changes.DefaultTryCastAs(LogicalType::BIGINT)) { - result->__deprecated_rows_changed = row_changes.GetValue(); - } - } - // now write the data - for (idx_t col = 0; col < column_count; col++) { - auto state = deprecated_duckdb_translate_column(materialized, &result->__deprecated_columns[col], col); - if (state != DuckDBSuccess) { - return false; - } - } - return true; -} - -} // namespace duckdb - -static void DuckdbDestroyColumn(duckdb_column column, idx_t count) { - if (column.__deprecated_data) { - if (column.__deprecated_type == DUCKDB_TYPE_VARCHAR) { - // varchar, delete individual strings - auto data = reinterpret_cast(column.__deprecated_data); - for (idx_t i = 0; i < count; i++) { - if (data[i]) { - duckdb_free(data[i]); - } - } - } else if (column.__deprecated_type == DUCKDB_TYPE_BLOB) { - // blob, delete individual blobs - auto data = reinterpret_cast(column.__deprecated_data); - for (idx_t i = 0; i < count; i++) { - if (data[i].data) { - duckdb_free((void *)data[i].data); - } - } - } - 
duckdb_free(column.__deprecated_data); - } - if (column.__deprecated_nullmask) { - duckdb_free(column.__deprecated_nullmask); - } -} - -void duckdb_destroy_result(duckdb_result *result) { - if (result->__deprecated_columns) { - for (idx_t i = 0; i < result->__deprecated_column_count; i++) { - DuckdbDestroyColumn(result->__deprecated_columns[i], result->__deprecated_row_count); - } - duckdb_free(result->__deprecated_columns); - } - if (result->internal_data) { - auto result_data = reinterpret_cast(result->internal_data); - delete result_data; - } - memset(result, 0, sizeof(duckdb_result)); -} - -const char *duckdb_column_name(duckdb_result *result, idx_t col) { - if (!result || col >= duckdb_column_count(result)) { - return nullptr; - } - auto &result_data = *(reinterpret_cast(result->internal_data)); - return result_data.result->names[col].c_str(); -} - -duckdb_type duckdb_column_type(duckdb_result *result, idx_t col) { - if (!result || col >= duckdb_column_count(result)) { - return DUCKDB_TYPE_INVALID; - } - auto &result_data = *(reinterpret_cast(result->internal_data)); - return duckdb::ConvertCPPTypeToC(result_data.result->types[col]); -} - -duckdb_logical_type duckdb_column_logical_type(duckdb_result *result, idx_t col) { - if (!result || col >= duckdb_column_count(result)) { - return nullptr; - } - auto &result_data = *(reinterpret_cast(result->internal_data)); - return reinterpret_cast(new duckdb::LogicalType(result_data.result->types[col])); -} - -idx_t duckdb_column_count(duckdb_result *result) { - if (!result) { - return 0; - } - auto &result_data = *(reinterpret_cast(result->internal_data)); - return result_data.result->ColumnCount(); -} - -idx_t duckdb_row_count(duckdb_result *result) { - if (!result) { - return 0; - } - auto &result_data = *(reinterpret_cast(result->internal_data)); - if (result_data.result->type == duckdb::QueryResultType::STREAM_RESULT) { - // We can't know the row count beforehand - return 0; - } - auto &materialized = 
reinterpret_cast(*result_data.result); - return materialized.RowCount(); -} - -idx_t duckdb_rows_changed(duckdb_result *result) { - if (!result) { - return 0; - } - if (!duckdb::deprecated_materialize_result(result)) { - return 0; - } - return result->__deprecated_rows_changed; -} - -void *duckdb_column_data(duckdb_result *result, idx_t col) { - if (!result || col >= result->__deprecated_column_count) { - return nullptr; - } - if (!duckdb::deprecated_materialize_result(result)) { - return nullptr; - } - return result->__deprecated_columns[col].__deprecated_data; -} - -bool *duckdb_nullmask_data(duckdb_result *result, idx_t col) { - if (!result || col >= result->__deprecated_column_count) { - return nullptr; - } - if (!duckdb::deprecated_materialize_result(result)) { - return nullptr; - } - return result->__deprecated_columns[col].__deprecated_nullmask; -} - -const char *duckdb_result_error(duckdb_result *result) { - if (!result) { - return nullptr; - } - auto &result_data = *(reinterpret_cast(result->internal_data)); - return !result_data.result->HasError() ? nullptr : result_data.result->GetError().c_str(); -} - -idx_t duckdb_result_chunk_count(duckdb_result result) { - if (!result.internal_data) { - return 0; - } - auto &result_data = *(reinterpret_cast(result.internal_data)); - if (result_data.result_set_type == duckdb::CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED) { - return 0; - } - if (result_data.result->type != duckdb::QueryResultType::MATERIALIZED_RESULT) { - // Can't know beforehand how many chunks are returned. 
- return 0; - } - auto &materialized = reinterpret_cast(*result_data.result); - return materialized.Collection().ChunkCount(); -} - -duckdb_data_chunk duckdb_result_get_chunk(duckdb_result result, idx_t chunk_idx) { - if (!result.internal_data) { - return nullptr; - } - auto &result_data = *(reinterpret_cast(result.internal_data)); - if (result_data.result_set_type == duckdb::CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED) { - return nullptr; - } - if (result_data.result->type != duckdb::QueryResultType::MATERIALIZED_RESULT) { - // This API is only supported for materialized query results - return nullptr; - } - result_data.result_set_type = duckdb::CAPIResultSetType::CAPI_RESULT_TYPE_MATERIALIZED; - auto &materialized = reinterpret_cast(*result_data.result); - auto &collection = materialized.Collection(); - if (chunk_idx >= collection.ChunkCount()) { - return nullptr; - } - auto chunk = duckdb::make_uniq(); - chunk->Initialize(duckdb::Allocator::DefaultAllocator(), collection.Types()); - collection.FetchChunk(chunk_idx, *chunk); - return reinterpret_cast(chunk.release()); -} - -bool duckdb_result_is_streaming(duckdb_result result) { - if (!result.internal_data) { - return false; - } - if (duckdb_result_error(&result) != nullptr) { - return false; - } - auto &result_data = *(reinterpret_cast(result.internal_data)); - return result_data.result->type == duckdb::QueryResultType::STREAM_RESULT; -} - - - - -duckdb_data_chunk duckdb_stream_fetch_chunk(duckdb_result result) { - if (!result.internal_data) { - return nullptr; - } - auto &result_data = *((duckdb::DuckDBResultData *)result.internal_data); - if (result_data.result_set_type == duckdb::CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED) { - return nullptr; - } - if (result_data.result->type != duckdb::QueryResultType::STREAM_RESULT) { - // We can only fetch from a StreamQueryResult - return nullptr; - } - result_data.result_set_type = duckdb::CAPIResultSetType::CAPI_RESULT_TYPE_STREAMING; - auto &streaming = 
(duckdb::StreamQueryResult &)*result_data.result; - if (!streaming.IsOpen()) { - return nullptr; - } - // FetchRaw ? Do we care about flattening them? - auto chunk = streaming.Fetch(); - return reinterpret_cast(chunk.release()); -} - - - - - - - -namespace duckdb { - -struct CTableFunctionInfo : public TableFunctionInfo { - ~CTableFunctionInfo() { - if (extra_info && delete_callback) { - delete_callback(extra_info); - } - extra_info = nullptr; - delete_callback = nullptr; - } - - duckdb_table_function_bind_t bind = nullptr; - duckdb_table_function_init_t init = nullptr; - duckdb_table_function_init_t local_init = nullptr; - duckdb_table_function_t function = nullptr; - void *extra_info = nullptr; - duckdb_delete_callback_t delete_callback = nullptr; -}; - -struct CTableBindData : public TableFunctionData { - CTableBindData(CTableFunctionInfo &info) : info(info) { - } - ~CTableBindData() { - if (bind_data && delete_callback) { - delete_callback(bind_data); - } - bind_data = nullptr; - delete_callback = nullptr; - } - - CTableFunctionInfo &info; - void *bind_data = nullptr; - duckdb_delete_callback_t delete_callback = nullptr; - unique_ptr stats; -}; - -struct CTableInternalBindInfo { - CTableInternalBindInfo(ClientContext &context, TableFunctionBindInput &input, vector &return_types, - vector &names, CTableBindData &bind_data, CTableFunctionInfo &function_info) - : context(context), input(input), return_types(return_types), names(names), bind_data(bind_data), - function_info(function_info), success(true) { - } - - ClientContext &context; - TableFunctionBindInput &input; - vector &return_types; - vector &names; - CTableBindData &bind_data; - CTableFunctionInfo &function_info; - bool success; - string error; -}; - -struct CTableInitData { - ~CTableInitData() { - if (init_data && delete_callback) { - delete_callback(init_data); - } - init_data = nullptr; - delete_callback = nullptr; - } - - void *init_data = nullptr; - duckdb_delete_callback_t delete_callback = 
nullptr; - idx_t max_threads = 1; -}; - -struct CTableGlobalInitData : public GlobalTableFunctionState { - CTableInitData init_data; - - idx_t MaxThreads() const override { - return init_data.max_threads; - } -}; - -struct CTableLocalInitData : public LocalTableFunctionState { - CTableInitData init_data; -}; - -struct CTableInternalInitInfo { - CTableInternalInitInfo(const CTableBindData &bind_data, CTableInitData &init_data, - const vector &column_ids, optional_ptr filters) - : bind_data(bind_data), init_data(init_data), column_ids(column_ids), filters(filters), success(true) { - } - - const CTableBindData &bind_data; - CTableInitData &init_data; - const vector &column_ids; - optional_ptr filters; - bool success; - string error; -}; - -struct CTableInternalFunctionInfo { - CTableInternalFunctionInfo(const CTableBindData &bind_data, CTableInitData &init_data, CTableInitData &local_data) - : bind_data(bind_data), init_data(init_data), local_data(local_data), success(true) { - } - - const CTableBindData &bind_data; - CTableInitData &init_data; - CTableInitData &local_data; - bool success; - string error; -}; - -unique_ptr CTableFunctionBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto &info = input.info->Cast(); - D_ASSERT(info.bind && info.function && info.init); - auto result = make_uniq(info); - CTableInternalBindInfo bind_info(context, input, return_types, names, *result, info); - info.bind(&bind_info); - if (!bind_info.success) { - throw Exception(bind_info.error); - } - - return std::move(result); -} - -unique_ptr CTableFunctionInit(ClientContext &context, TableFunctionInitInput &data_p) { - auto &bind_data = data_p.bind_data->Cast(); - auto result = make_uniq(); - - CTableInternalInitInfo init_info(bind_data, result->init_data, data_p.column_ids, data_p.filters); - bind_data.info.init(&init_info); - if (!init_info.success) { - throw Exception(init_info.error); - } - return std::move(result); -} - 
-unique_ptr CTableFunctionLocalInit(ExecutionContext &context, TableFunctionInitInput &data_p, - GlobalTableFunctionState *gstate) { - auto &bind_data = data_p.bind_data->Cast(); - auto result = make_uniq(); - if (!bind_data.info.local_init) { - return std::move(result); - } - - CTableInternalInitInfo init_info(bind_data, result->init_data, data_p.column_ids, data_p.filters); - bind_data.info.local_init(&init_info); - if (!init_info.success) { - throw Exception(init_info.error); - } - return std::move(result); -} - -unique_ptr CTableFunctionCardinality(ClientContext &context, const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - if (!bind_data.stats) { - return nullptr; - } - return make_uniq(*bind_data.stats); -} - -void CTableFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &global_data = (CTableGlobalInitData &)*data_p.global_state; - auto &local_data = (CTableLocalInitData &)*data_p.local_state; - CTableInternalFunctionInfo function_info(bind_data, global_data.init_data, local_data.init_data); - bind_data.info.function(&function_info, reinterpret_cast(&output)); - if (!function_info.success) { - throw Exception(function_info.error); - } -} - -} // namespace duckdb - -//===--------------------------------------------------------------------===// -// Table Function -//===--------------------------------------------------------------------===// -duckdb_table_function duckdb_create_table_function() { - auto function = new duckdb::TableFunction("", {}, duckdb::CTableFunction, duckdb::CTableFunctionBind, - duckdb::CTableFunctionInit, duckdb::CTableFunctionLocalInit); - function->function_info = duckdb::make_shared(); - function->cardinality = duckdb::CTableFunctionCardinality; - return function; -} - -void duckdb_destroy_table_function(duckdb_table_function *function) { - if (function && *function) { - auto tf = (duckdb::TableFunction *)*function; - delete 
tf; - *function = nullptr; - } -} - -void duckdb_table_function_set_name(duckdb_table_function function, const char *name) { - if (!function || !name) { - return; - } - auto tf = (duckdb::TableFunction *)function; - tf->name = name; -} - -void duckdb_table_function_add_parameter(duckdb_table_function function, duckdb_logical_type type) { - if (!function || !type) { - return; - } - auto tf = (duckdb::TableFunction *)function; - auto logical_type = (duckdb::LogicalType *)type; - tf->arguments.push_back(*logical_type); -} - -void duckdb_table_function_add_named_parameter(duckdb_table_function function, const char *name, - duckdb_logical_type type) { - if (!function || !type) { - return; - } - auto tf = (duckdb::TableFunction *)function; - auto logical_type = (duckdb::LogicalType *)type; - tf->named_parameters.insert({name, *logical_type}); -} - -void duckdb_table_function_set_extra_info(duckdb_table_function function, void *extra_info, - duckdb_delete_callback_t destroy) { - if (!function) { - return; - } - auto tf = (duckdb::TableFunction *)function; - auto info = (duckdb::CTableFunctionInfo *)tf->function_info.get(); - info->extra_info = extra_info; - info->delete_callback = destroy; -} - -void duckdb_table_function_set_bind(duckdb_table_function function, duckdb_table_function_bind_t bind) { - if (!function || !bind) { - return; - } - auto tf = (duckdb::TableFunction *)function; - auto info = (duckdb::CTableFunctionInfo *)tf->function_info.get(); - info->bind = bind; -} - -void duckdb_table_function_set_init(duckdb_table_function function, duckdb_table_function_init_t init) { - if (!function || !init) { - return; - } - auto tf = (duckdb::TableFunction *)function; - auto info = (duckdb::CTableFunctionInfo *)tf->function_info.get(); - info->init = init; -} - -void duckdb_table_function_set_local_init(duckdb_table_function function, duckdb_table_function_init_t init) { - if (!function || !init) { - return; - } - auto tf = (duckdb::TableFunction *)function; - auto info 
= (duckdb::CTableFunctionInfo *)tf->function_info.get(); - info->local_init = init; -} - -void duckdb_table_function_set_function(duckdb_table_function table_function, duckdb_table_function_t function) { - if (!table_function || !function) { - return; - } - auto tf = (duckdb::TableFunction *)table_function; - auto info = (duckdb::CTableFunctionInfo *)tf->function_info.get(); - info->function = function; -} - -void duckdb_table_function_supports_projection_pushdown(duckdb_table_function table_function, bool pushdown) { - if (!table_function) { - return; - } - auto tf = (duckdb::TableFunction *)table_function; - tf->projection_pushdown = pushdown; -} - -duckdb_state duckdb_register_table_function(duckdb_connection connection, duckdb_table_function function) { - if (!connection || !function) { - return DuckDBError; - } - auto con = (duckdb::Connection *)connection; - auto tf = (duckdb::TableFunction *)function; - auto info = (duckdb::CTableFunctionInfo *)tf->function_info.get(); - if (tf->name.empty() || !info->bind || !info->init || !info->function) { - return DuckDBError; - } - con->context->RunFunctionInTransaction([&]() { - auto &catalog = duckdb::Catalog::GetSystemCatalog(*con->context); - duckdb::CreateTableFunctionInfo tf_info(*tf); - - // create the function in the catalog - catalog.CreateTableFunction(*con->context, tf_info); - }); - return DuckDBSuccess; -} - -//===--------------------------------------------------------------------===// -// Bind Interface -//===--------------------------------------------------------------------===// -void *duckdb_bind_get_extra_info(duckdb_bind_info info) { - if (!info) { - return nullptr; - } - auto bind_info = (duckdb::CTableInternalBindInfo *)info; - return bind_info->function_info.extra_info; -} - -void duckdb_bind_add_result_column(duckdb_bind_info info, const char *name, duckdb_logical_type type) { - if (!info || !name || !type) { - return; - } - auto bind_info = (duckdb::CTableInternalBindInfo *)info; - 
bind_info->names.push_back(name); - bind_info->return_types.push_back(*(reinterpret_cast(type))); -} - -idx_t duckdb_bind_get_parameter_count(duckdb_bind_info info) { - if (!info) { - return 0; - } - auto bind_info = (duckdb::CTableInternalBindInfo *)info; - return bind_info->input.inputs.size(); -} - -duckdb_value duckdb_bind_get_parameter(duckdb_bind_info info, idx_t index) { - if (!info || index >= duckdb_bind_get_parameter_count(info)) { - return nullptr; - } - auto bind_info = (duckdb::CTableInternalBindInfo *)info; - return reinterpret_cast(new duckdb::Value(bind_info->input.inputs[index])); -} - -duckdb_value duckdb_bind_get_named_parameter(duckdb_bind_info info, const char *name) { - if (!info || !name) { - return nullptr; - } - auto bind_info = (duckdb::CTableInternalBindInfo *)info; - auto t = bind_info->input.named_parameters.find(name); - if (t == bind_info->input.named_parameters.end()) { - return nullptr; - } else { - return reinterpret_cast(new duckdb::Value(t->second)); - } -} - -void duckdb_bind_set_bind_data(duckdb_bind_info info, void *bind_data, duckdb_delete_callback_t destroy) { - if (!info) { - return; - } - auto bind_info = (duckdb::CTableInternalBindInfo *)info; - bind_info->bind_data.bind_data = bind_data; - bind_info->bind_data.delete_callback = destroy; -} - -void duckdb_bind_set_cardinality(duckdb_bind_info info, idx_t cardinality, bool is_exact) { - if (!info) { - return; - } - auto bind_info = (duckdb::CTableInternalBindInfo *)info; - if (is_exact) { - bind_info->bind_data.stats = duckdb::make_uniq(cardinality); - } else { - bind_info->bind_data.stats = duckdb::make_uniq(cardinality, cardinality); - } -} - -void duckdb_bind_set_error(duckdb_bind_info info, const char *error) { - if (!info || !error) { - return; - } - auto function_info = (duckdb::CTableInternalBindInfo *)info; - function_info->error = error; - function_info->success = false; -} - -//===--------------------------------------------------------------------===// -// Init 
Interface -//===--------------------------------------------------------------------===// -void *duckdb_init_get_extra_info(duckdb_init_info info) { - if (!info) { - return nullptr; - } - auto init_info = (duckdb::CTableInternalInitInfo *)info; - return init_info->bind_data.info.extra_info; -} - -void *duckdb_init_get_bind_data(duckdb_init_info info) { - if (!info) { - return nullptr; - } - auto init_info = (duckdb::CTableInternalInitInfo *)info; - return init_info->bind_data.bind_data; -} - -void duckdb_init_set_init_data(duckdb_init_info info, void *init_data, duckdb_delete_callback_t destroy) { - if (!info) { - return; - } - auto init_info = (duckdb::CTableInternalInitInfo *)info; - init_info->init_data.init_data = init_data; - init_info->init_data.delete_callback = destroy; -} - -void duckdb_init_set_error(duckdb_init_info info, const char *error) { - if (!info || !error) { - return; - } - auto function_info = (duckdb::CTableInternalInitInfo *)info; - function_info->error = error; - function_info->success = false; -} - -idx_t duckdb_init_get_column_count(duckdb_init_info info) { - if (!info) { - return 0; - } - auto function_info = (duckdb::CTableInternalInitInfo *)info; - return function_info->column_ids.size(); -} - -idx_t duckdb_init_get_column_index(duckdb_init_info info, idx_t column_index) { - if (!info) { - return 0; - } - auto function_info = (duckdb::CTableInternalInitInfo *)info; - if (column_index >= function_info->column_ids.size()) { - return 0; - } - return function_info->column_ids[column_index]; -} - -void duckdb_init_set_max_threads(duckdb_init_info info, idx_t max_threads) { - if (!info) { - return; - } - auto function_info = (duckdb::CTableInternalInitInfo *)info; - function_info->init_data.max_threads = max_threads; -} - -//===--------------------------------------------------------------------===// -// Function Interface -//===--------------------------------------------------------------------===// -void 
*duckdb_function_get_extra_info(duckdb_function_info info) { - if (!info) { - return nullptr; - } - auto function_info = (duckdb::CTableInternalFunctionInfo *)info; - return function_info->bind_data.info.extra_info; -} - -void *duckdb_function_get_bind_data(duckdb_function_info info) { - if (!info) { - return nullptr; - } - auto function_info = (duckdb::CTableInternalFunctionInfo *)info; - return function_info->bind_data.bind_data; -} - -void *duckdb_function_get_init_data(duckdb_function_info info) { - if (!info) { - return nullptr; - } - auto function_info = (duckdb::CTableInternalFunctionInfo *)info; - return function_info->init_data.init_data; -} - -void *duckdb_function_get_local_init_data(duckdb_function_info info) { - if (!info) { - return nullptr; - } - auto function_info = (duckdb::CTableInternalFunctionInfo *)info; - return function_info->local_data.init_data; -} - -void duckdb_function_set_error(duckdb_function_info info, const char *error) { - if (!info || !error) { - return; - } - auto function_info = (duckdb::CTableInternalFunctionInfo *)info; - function_info->error = error; - function_info->success = false; -} - - - -using duckdb::DatabaseData; - -struct CAPITaskState { - CAPITaskState(duckdb::DatabaseInstance &db) - : db(db), marker(duckdb::make_uniq>(true)), execute_count(0) { - } - - duckdb::DatabaseInstance &db; - duckdb::unique_ptr> marker; - duckdb::atomic execute_count; -}; - -void duckdb_execute_tasks(duckdb_database database, idx_t max_tasks) { - if (!database) { - return; - } - auto wrapper = (DatabaseData *)database; - auto &scheduler = duckdb::TaskScheduler::GetScheduler(*wrapper->database->instance); - scheduler.ExecuteTasks(max_tasks); -} - -duckdb_task_state duckdb_create_task_state(duckdb_database database) { - if (!database) { - return nullptr; - } - auto wrapper = (DatabaseData *)database; - auto state = new CAPITaskState(*wrapper->database->instance); - return state; -} - -void duckdb_execute_tasks_state(duckdb_task_state state_p) 
{ - if (!state_p) { - return; - } - auto state = (CAPITaskState *)state_p; - auto &scheduler = duckdb::TaskScheduler::GetScheduler(state->db); - state->execute_count++; - scheduler.ExecuteForever(state->marker.get()); -} - -idx_t duckdb_execute_n_tasks_state(duckdb_task_state state_p, idx_t max_tasks) { - if (!state_p) { - return 0; - } - auto state = (CAPITaskState *)state_p; - auto &scheduler = duckdb::TaskScheduler::GetScheduler(state->db); - return scheduler.ExecuteTasks(state->marker.get(), max_tasks); -} - -void duckdb_finish_execution(duckdb_task_state state_p) { - if (!state_p) { - return; - } - auto state = (CAPITaskState *)state_p; - *state->marker = false; - if (state->execute_count > 0) { - // signal to the threads to wake up - auto &scheduler = duckdb::TaskScheduler::GetScheduler(state->db); - scheduler.Signal(state->execute_count); - } -} - -bool duckdb_task_state_is_finished(duckdb_task_state state_p) { - if (!state_p) { - return false; - } - auto state = (CAPITaskState *)state_p; - return !(*state->marker); -} - -void duckdb_destroy_task_state(duckdb_task_state state_p) { - if (!state_p) { - return; - } - auto state = (CAPITaskState *)state_p; - delete state; -} - -bool duckdb_execution_is_finished(duckdb_connection con) { - if (!con) { - return false; - } - duckdb::Connection *conn = (duckdb::Connection *)con; - return conn->context->ExecutionIsFinished(); -} - - - - - - - - -#include - -using duckdb::date_t; -using duckdb::dtime_t; -using duckdb::FetchDefaultValue; -using duckdb::GetInternalCValue; -using duckdb::hugeint_t; -using duckdb::interval_t; -using duckdb::StringCast; -using duckdb::timestamp_t; -using duckdb::ToCStringCastWrapper; -using duckdb::UnsafeFetch; - -bool duckdb_value_boolean(duckdb_result *result, idx_t col, idx_t row) { - return GetInternalCValue(result, col, row); -} - -int8_t duckdb_value_int8(duckdb_result *result, idx_t col, idx_t row) { - return GetInternalCValue(result, col, row); -} - -int16_t 
duckdb_value_int16(duckdb_result *result, idx_t col, idx_t row) { - return GetInternalCValue(result, col, row); -} - -int32_t duckdb_value_int32(duckdb_result *result, idx_t col, idx_t row) { - return GetInternalCValue(result, col, row); -} - -int64_t duckdb_value_int64(duckdb_result *result, idx_t col, idx_t row) { - return GetInternalCValue(result, col, row); -} - -static bool ResultIsDecimal(duckdb_result *result, idx_t col) { - if (!result) { - return false; - } - if (!result->internal_data) { - return false; - } - auto result_data = (duckdb::DuckDBResultData *)result->internal_data; - auto &query_result = result_data->result; - auto &source_type = query_result->types[col]; - return source_type.id() == duckdb::LogicalTypeId::DECIMAL; -} - -duckdb_decimal duckdb_value_decimal(duckdb_result *result, idx_t col, idx_t row) { - if (!CanFetchValue(result, col, row) || !ResultIsDecimal(result, col)) { - return FetchDefaultValue::Operation(); - } - - return GetInternalCValue(result, col, row); -} - -duckdb_hugeint duckdb_value_hugeint(duckdb_result *result, idx_t col, idx_t row) { - duckdb_hugeint result_value; - auto internal_value = GetInternalCValue(result, col, row); - result_value.lower = internal_value.lower; - result_value.upper = internal_value.upper; - return result_value; -} - -uint8_t duckdb_value_uint8(duckdb_result *result, idx_t col, idx_t row) { - return GetInternalCValue(result, col, row); -} - -uint16_t duckdb_value_uint16(duckdb_result *result, idx_t col, idx_t row) { - return GetInternalCValue(result, col, row); -} - -uint32_t duckdb_value_uint32(duckdb_result *result, idx_t col, idx_t row) { - return GetInternalCValue(result, col, row); -} - -uint64_t duckdb_value_uint64(duckdb_result *result, idx_t col, idx_t row) { - return GetInternalCValue(result, col, row); -} - -float duckdb_value_float(duckdb_result *result, idx_t col, idx_t row) { - return GetInternalCValue(result, col, row); -} - -double duckdb_value_double(duckdb_result *result, idx_t col, 
idx_t row) { - return GetInternalCValue(result, col, row); -} - -duckdb_date duckdb_value_date(duckdb_result *result, idx_t col, idx_t row) { - duckdb_date result_value; - result_value.days = GetInternalCValue(result, col, row).days; - return result_value; -} - -duckdb_time duckdb_value_time(duckdb_result *result, idx_t col, idx_t row) { - duckdb_time result_value; - result_value.micros = GetInternalCValue(result, col, row).micros; - return result_value; -} - -duckdb_timestamp duckdb_value_timestamp(duckdb_result *result, idx_t col, idx_t row) { - duckdb_timestamp result_value; - result_value.micros = GetInternalCValue(result, col, row).value; - return result_value; -} - -duckdb_interval duckdb_value_interval(duckdb_result *result, idx_t col, idx_t row) { - duckdb_interval result_value; - auto ival = GetInternalCValue(result, col, row); - result_value.months = ival.months; - result_value.days = ival.days; - result_value.micros = ival.micros; - return result_value; -} - -char *duckdb_value_varchar(duckdb_result *result, idx_t col, idx_t row) { - return duckdb_value_string(result, col, row).data; -} - -duckdb_string duckdb_value_string(duckdb_result *result, idx_t col, idx_t row) { - return GetInternalCValue>(result, col, row); -} - -char *duckdb_value_varchar_internal(duckdb_result *result, idx_t col, idx_t row) { - return duckdb_value_string_internal(result, col, row).data; -} - -duckdb_string duckdb_value_string_internal(duckdb_result *result, idx_t col, idx_t row) { - if (!CanFetchValue(result, col, row)) { - return FetchDefaultValue::Operation(); - } - if (duckdb_column_type(result, col) != DUCKDB_TYPE_VARCHAR) { - return FetchDefaultValue::Operation(); - } - // FIXME: this obviously does not work when there are null bytes in the string - // we need to remove the deprecated C result materialization to get that to work correctly - // since the deprecated C result materialization stores strings as null-terminated - duckdb_string res; - res.data = 
UnsafeFetch(result, col, row); - res.size = strlen(res.data); - return res; -} - -duckdb_blob duckdb_value_blob(duckdb_result *result, idx_t col, idx_t row) { - if (CanFetchValue(result, col, row) && result->__deprecated_columns[col].__deprecated_type == DUCKDB_TYPE_BLOB) { - auto internal_result = UnsafeFetch(result, col, row); - - duckdb_blob result_blob; - result_blob.data = malloc(internal_result.size); - result_blob.size = internal_result.size; - memcpy(result_blob.data, internal_result.data, internal_result.size); - return result_blob; - } - return FetchDefaultValue::Operation(); -} - -bool duckdb_value_is_null(duckdb_result *result, idx_t col, idx_t row) { - if (!CanUseDeprecatedFetch(result, col, row)) { - return false; - } - return result->__deprecated_columns[col].__deprecated_nullmask[row]; -} - - - - -namespace duckdb { - -QueryResultChunkScanState::QueryResultChunkScanState(QueryResult &result) : ChunkScanState(), result(result) { -} - -QueryResultChunkScanState::~QueryResultChunkScanState() { -} - -bool QueryResultChunkScanState::InternalLoad(PreservedError &error) { - D_ASSERT(!finished); - if (result.type == QueryResultType::STREAM_RESULT) { - auto &stream_result = result.Cast(); - if (!stream_result.IsOpen()) { - return true; - } - } - return result.TryFetch(current_chunk, error); -} - -bool QueryResultChunkScanState::HasError() const { - return result.HasError(); -} - -PreservedError &QueryResultChunkScanState::GetError() { - D_ASSERT(result.HasError()); - return result.GetErrorObject(); -} - -const vector &QueryResultChunkScanState::Types() const { - return result.types; -} - -const vector &QueryResultChunkScanState::Names() const { - return result.names; -} - -bool QueryResultChunkScanState::LoadNextChunk(PreservedError &error) { - if (finished) { - return !finished; - } - auto load_result = InternalLoad(error); - if (!load_result) { - finished = true; - } - offset = 0; - return !finished; -} - -} // namespace duckdb - - - -namespace duckdb { - 
-ChunkScanState::ChunkScanState() { -} - -ChunkScanState::~ChunkScanState() { -} - -idx_t ChunkScanState::CurrentOffset() const { - return offset; -} - -void ChunkScanState::IncreaseOffset(idx_t increment, bool unsafe) { - D_ASSERT(unsafe || increment <= RemainingInChunk()); - offset += increment; -} - -bool ChunkScanState::ChunkIsEmpty() const { - return !current_chunk || current_chunk->size() == 0; -} - -bool ChunkScanState::Finished() const { - return finished; -} - -bool ChunkScanState::ScanStarted() const { - return !ChunkIsEmpty(); -} - -DataChunk &ChunkScanState::CurrentChunk() { - // Scan must already be started - D_ASSERT(current_chunk); - return *current_chunk; -} - -idx_t ChunkScanState::RemainingInChunk() const { - if (ChunkIsEmpty()) { - return 0; - } - D_ASSERT(current_chunk); - D_ASSERT(offset <= current_chunk->size()); - return current_chunk->size() - offset; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -struct ActiveQueryContext { - //! The query that is currently being executed - string query; - //! The currently open result - BaseQueryResult *open_result = nullptr; - //! Prepared statement data - shared_ptr prepared; - //! The query executor - unique_ptr executor; - //! 
The progress bar - unique_ptr progress_bar; -}; - -ClientContext::ClientContext(shared_ptr database) - : db(std::move(database)), interrupted(false), client_data(make_uniq(*this)), transaction(*this) { -} - -ClientContext::~ClientContext() { - if (Exception::UncaughtException()) { - return; - } - // destroy the client context and rollback if there is an active transaction - // but only if we are not destroying this client context as part of an exception stack unwind - Destroy(); -} - -unique_ptr ClientContext::LockContext() { - return make_uniq(context_lock); -} - -void ClientContext::Destroy() { - auto lock = LockContext(); - if (transaction.HasActiveTransaction()) { - transaction.ResetActiveQuery(); - if (!transaction.IsAutoCommit()) { - transaction.Rollback(); - } - } - CleanupInternal(*lock); -} - -unique_ptr ClientContext::Fetch(ClientContextLock &lock, StreamQueryResult &result) { - D_ASSERT(IsActiveResult(lock, &result)); - D_ASSERT(active_query->executor); - return FetchInternal(lock, *active_query->executor, result); -} - -unique_ptr ClientContext::FetchInternal(ClientContextLock &lock, Executor &executor, - BaseQueryResult &result) { - bool invalidate_query = true; - try { - // fetch the chunk and return it - auto chunk = executor.FetchChunk(); - if (!chunk || chunk->size() == 0) { - CleanupInternal(lock, &result); - } - return chunk; - } catch (StandardException &ex) { - // standard exceptions do not invalidate the current transaction - result.SetError(PreservedError(ex)); - invalidate_query = false; - } catch (FatalException &ex) { - // fatal exceptions invalidate the entire database - result.SetError(PreservedError(ex)); - auto &db = DatabaseInstance::GetDatabase(*this); - ValidChecker::Invalidate(db, ex.what()); - } catch (const Exception &ex) { - result.SetError(PreservedError(ex)); - } catch (std::exception &ex) { - result.SetError(PreservedError(ex)); - } catch (...) 
{ // LCOV_EXCL_START - result.SetError(PreservedError("Unhandled exception in FetchInternal")); - } // LCOV_EXCL_STOP - CleanupInternal(lock, &result, invalidate_query); - return nullptr; -} - -void ClientContext::BeginTransactionInternal(ClientContextLock &lock, bool requires_valid_transaction) { - // check if we are on AutoCommit. In this case we should start a transaction - D_ASSERT(!active_query); - auto &db = DatabaseInstance::GetDatabase(*this); - if (ValidChecker::IsInvalidated(db)) { - throw FatalException(ErrorManager::FormatException(*this, ErrorType::INVALIDATED_DATABASE, - ValidChecker::InvalidatedMessage(db))); - } - if (requires_valid_transaction && transaction.HasActiveTransaction() && - ValidChecker::IsInvalidated(transaction.ActiveTransaction())) { - throw Exception(ErrorManager::FormatException(*this, ErrorType::INVALIDATED_TRANSACTION)); - } - active_query = make_uniq(); - if (transaction.IsAutoCommit()) { - transaction.BeginTransaction(); - } -} - -void ClientContext::BeginQueryInternal(ClientContextLock &lock, const string &query) { - BeginTransactionInternal(lock, false); - LogQueryInternal(lock, query); - active_query->query = query; - query_progress = -1; - transaction.SetActiveQuery(db->GetDatabaseManager().GetNewQueryNumber()); -} - -PreservedError ClientContext::EndQueryInternal(ClientContextLock &lock, bool success, bool invalidate_transaction) { - client_data->profiler->EndQuery(); - - if (client_data->http_state) { - client_data->http_state->Reset(); - } - - // Notify any registered state of query end - for (auto const &s : registered_state) { - s.second->QueryEnd(); - } - - D_ASSERT(active_query.get()); - active_query.reset(); - query_progress = -1; - PreservedError error; - try { - if (transaction.HasActiveTransaction()) { - // Move the query profiler into the history - auto &prev_profilers = client_data->query_profiler_history->GetPrevProfilers(); - prev_profilers.emplace_back(transaction.GetActiveQuery(), 
std::move(client_data->profiler)); - // Reinitialize the query profiler - client_data->profiler = make_shared(*this); - // Propagate settings of the saved query into the new profiler. - client_data->profiler->Propagate(*prev_profilers.back().second); - if (prev_profilers.size() >= client_data->query_profiler_history->GetPrevProfilersSize()) { - prev_profilers.pop_front(); - } - - transaction.ResetActiveQuery(); - if (transaction.IsAutoCommit()) { - if (success) { - transaction.Commit(); - } else { - transaction.Rollback(); - } - } else if (invalidate_transaction) { - D_ASSERT(!success); - ValidChecker::Invalidate(ActiveTransaction(), "Failed to commit"); - } - } - } catch (FatalException &ex) { - auto &db = DatabaseInstance::GetDatabase(*this); - ValidChecker::Invalidate(db, ex.what()); - error = PreservedError(ex); - } catch (const Exception &ex) { - error = PreservedError(ex); - } catch (std::exception &ex) { - error = PreservedError(ex); - } catch (...) { // LCOV_EXCL_START - error = PreservedError("Unhandled exception!"); - } // LCOV_EXCL_STOP - return error; -} - -void ClientContext::CleanupInternal(ClientContextLock &lock, BaseQueryResult *result, bool invalidate_transaction) { - client_data->http_state = make_shared(); - if (!active_query) { - // no query currently active - return; - } - if (active_query->executor) { - active_query->executor->CancelTasks(); - } - active_query->progress_bar.reset(); - - auto error = EndQueryInternal(lock, result ? 
!result->HasError() : false, invalidate_transaction); - if (result && !result->HasError()) { - // if an error occurred while committing report it in the result - result->SetError(error); - } - D_ASSERT(!active_query); -} - -Executor &ClientContext::GetExecutor() { - D_ASSERT(active_query); - D_ASSERT(active_query->executor); - return *active_query->executor; -} - -const string &ClientContext::GetCurrentQuery() { - D_ASSERT(active_query); - return active_query->query; -} - -unique_ptr ClientContext::FetchResultInternal(ClientContextLock &lock, PendingQueryResult &pending) { - D_ASSERT(active_query); - D_ASSERT(active_query->open_result == &pending); - D_ASSERT(active_query->prepared); - auto &executor = GetExecutor(); - auto &prepared = *active_query->prepared; - bool create_stream_result = prepared.properties.allow_stream_result && pending.allow_stream_result; - if (create_stream_result) { - D_ASSERT(!executor.HasResultCollector()); - active_query->progress_bar.reset(); - query_progress = -1; - - // successfully compiled SELECT clause, and it is the last statement - // return a StreamQueryResult so the client can call Fetch() on it and stream the result - auto stream_result = make_uniq(pending.statement_type, pending.properties, - shared_from_this(), pending.types, pending.names); - active_query->open_result = stream_result.get(); - return std::move(stream_result); - } - unique_ptr result; - if (executor.HasResultCollector()) { - // we have a result collector - fetch the result directly from the result collector - result = executor.GetResult(); - CleanupInternal(lock, result.get(), false); - } else { - // no result collector - create a materialized result by continuously fetching - auto result_collection = make_uniq(Allocator::DefaultAllocator(), pending.types); - D_ASSERT(!result_collection->Types().empty()); - auto materialized_result = - make_uniq(pending.statement_type, pending.properties, pending.names, - std::move(result_collection), GetClientProperties()); - 
- auto &collection = materialized_result->Collection(); - D_ASSERT(!collection.Types().empty()); - ColumnDataAppendState append_state; - collection.InitializeAppend(append_state); - while (true) { - auto chunk = FetchInternal(lock, GetExecutor(), *materialized_result); - if (!chunk || chunk->size() == 0) { - break; - } -#ifdef DEBUG - for (idx_t i = 0; i < chunk->ColumnCount(); i++) { - if (pending.types[i].id() == LogicalTypeId::VARCHAR) { - chunk->data[i].UTFVerify(chunk->size()); - } - } -#endif - collection.Append(append_state, *chunk); - } - result = std::move(materialized_result); - } - return result; -} - -static bool IsExplainAnalyze(SQLStatement *statement) { - if (!statement) { - return false; - } - if (statement->type != StatementType::EXPLAIN_STATEMENT) { - return false; - } - auto &explain = statement->Cast(); - return explain.explain_type == ExplainType::EXPLAIN_ANALYZE; -} - -shared_ptr -ClientContext::CreatePreparedStatement(ClientContextLock &lock, const string &query, unique_ptr statement, - optional_ptr> values) { - StatementType statement_type = statement->type; - auto result = make_shared(statement_type); - - auto &profiler = QueryProfiler::Get(*this); - profiler.StartQuery(query, IsExplainAnalyze(statement.get()), true); - profiler.StartPhase("planner"); - Planner planner(*this); - if (values) { - auto ¶meter_values = *values; - for (auto &value : parameter_values) { - planner.parameter_data.emplace(value.first, BoundParameterData(value.second)); - } - } - - client_data->http_state = make_shared(); - planner.CreatePlan(std::move(statement)); - D_ASSERT(planner.plan || !planner.properties.bound_all_parameters); - profiler.EndPhase(); - - auto plan = std::move(planner.plan); - // extract the result column names from the plan - result->properties = planner.properties; - result->names = planner.names; - result->types = planner.types; - result->value_map = std::move(planner.value_map); - result->catalog_version = 
MetaTransaction::Get(*this).catalog_version; - - if (!planner.properties.bound_all_parameters) { - return result; - } -#ifdef DEBUG - plan->Verify(*this); -#endif - if (config.enable_optimizer && plan->RequireOptimizer()) { - profiler.StartPhase("optimizer"); - Optimizer optimizer(*planner.binder, *this); - plan = optimizer.Optimize(std::move(plan)); - D_ASSERT(plan); - profiler.EndPhase(); - -#ifdef DEBUG - plan->Verify(*this); -#endif - } - - profiler.StartPhase("physical_planner"); - // now convert logical query plan into a physical query plan - PhysicalPlanGenerator physical_planner(*this); - auto physical_plan = physical_planner.CreatePlan(std::move(plan)); - profiler.EndPhase(); - -#ifdef DEBUG - D_ASSERT(!physical_plan->ToString().empty()); -#endif - result->plan = std::move(physical_plan); - return result; -} - -double ClientContext::GetProgress() { - return query_progress.load(); -} - -unique_ptr ClientContext::PendingPreparedStatement(ClientContextLock &lock, - shared_ptr statement_p, - const PendingQueryParameters ¶meters) { - D_ASSERT(active_query); - auto &statement = *statement_p; - if (ValidChecker::IsInvalidated(ActiveTransaction()) && statement.properties.requires_valid_transaction) { - throw Exception(ErrorManager::FormatException(*this, ErrorType::INVALIDATED_TRANSACTION)); - } - auto &transaction = MetaTransaction::Get(*this); - auto &manager = DatabaseManager::Get(*this); - for (auto &modified_database : statement.properties.modified_databases) { - auto entry = manager.GetDatabase(*this, modified_database); - if (!entry) { - throw InternalException("Database \"%s\" not found", modified_database); - } - if (entry->IsReadOnly()) { - throw Exception(StringUtil::Format( - "Cannot execute statement of type \"%s\" on database \"%s\" which is attached in read-only mode!", - StatementTypeToString(statement.statement_type), modified_database)); - } - transaction.ModifyDatabase(*entry); - } - - // bind the bound values before execution - 
case_insensitive_map_t owned_values; - if (parameters.parameters) { - auto ¶ms = *parameters.parameters; - for (auto &val : params) { - owned_values.emplace(val); - } - } - statement.Bind(std::move(owned_values)); - - active_query->executor = make_uniq(*this); - auto &executor = *active_query->executor; - if (config.enable_progress_bar) { - progress_bar_display_create_func_t display_create_func = nullptr; - if (config.print_progress_bar) { - // If a custom display is set, use that, otherwise just use the default - display_create_func = - config.display_create_func ? config.display_create_func : ProgressBar::DefaultProgressBarDisplay; - } - active_query->progress_bar = make_uniq(executor, config.wait_time, display_create_func); - active_query->progress_bar->Start(); - query_progress = 0; - } - auto stream_result = parameters.allow_stream_result && statement.properties.allow_stream_result; - if (!stream_result && statement.properties.return_type == StatementReturnType::QUERY_RESULT) { - unique_ptr collector; - auto &config = ClientConfig::GetConfig(*this); - auto get_method = - config.result_collector ? 
config.result_collector : PhysicalResultCollector::GetResultCollector; - collector = get_method(*this, statement); - D_ASSERT(collector->type == PhysicalOperatorType::RESULT_COLLECTOR); - executor.Initialize(std::move(collector)); - } else { - executor.Initialize(*statement.plan); - } - auto types = executor.GetTypes(); - D_ASSERT(types == statement.types); - D_ASSERT(!active_query->open_result); - - auto pending_result = - make_uniq(shared_from_this(), *statement_p, std::move(types), stream_result); - active_query->prepared = std::move(statement_p); - active_query->open_result = pending_result.get(); - return pending_result; -} - -PendingExecutionResult ClientContext::ExecuteTaskInternal(ClientContextLock &lock, PendingQueryResult &result) { - D_ASSERT(active_query); - D_ASSERT(active_query->open_result == &result); - try { - auto result = active_query->executor->ExecuteTask(); - if (active_query->progress_bar) { - active_query->progress_bar->Update(result == PendingExecutionResult::RESULT_READY); - query_progress = active_query->progress_bar->GetCurrentPercentage(); - } - return result; - } catch (FatalException &ex) { - // fatal exceptions invalidate the entire database - result.SetError(PreservedError(ex)); - auto &db = DatabaseInstance::GetDatabase(*this); - ValidChecker::Invalidate(db, ex.what()); - } catch (const Exception &ex) { - result.SetError(PreservedError(ex)); - } catch (std::exception &ex) { - result.SetError(PreservedError(ex)); - } catch (...) { // LCOV_EXCL_START - result.SetError(PreservedError("Unhandled exception in ExecuteTaskInternal")); - } // LCOV_EXCL_STOP - EndQueryInternal(lock, false, true); - return PendingExecutionResult::EXECUTION_ERROR; -} - -void ClientContext::InitialCleanup(ClientContextLock &lock) { - //! 
Cleanup any open results and reset the interrupted flag - CleanupInternal(lock); - interrupted = false; -} - -vector> ClientContext::ParseStatements(const string &query) { - auto lock = LockContext(); - return ParseStatementsInternal(*lock, query); -} - -vector> ClientContext::ParseStatementsInternal(ClientContextLock &lock, const string &query) { - Parser parser(GetParserOptions()); - parser.ParseQuery(query); - - PragmaHandler handler(*this); - handler.HandlePragmaStatements(lock, parser.statements); - - return std::move(parser.statements); -} - -void ClientContext::HandlePragmaStatements(vector> &statements) { - auto lock = LockContext(); - - PragmaHandler handler(*this); - handler.HandlePragmaStatements(*lock, statements); -} - -unique_ptr ClientContext::ExtractPlan(const string &query) { - auto lock = LockContext(); - - auto statements = ParseStatementsInternal(*lock, query); - if (statements.size() != 1) { - throw Exception("ExtractPlan can only prepare a single statement"); - } - - unique_ptr plan; - client_data->http_state = make_shared(); - RunFunctionInTransactionInternal(*lock, [&]() { - Planner planner(*this); - planner.CreatePlan(std::move(statements[0])); - D_ASSERT(planner.plan); - - plan = std::move(planner.plan); - - if (config.enable_optimizer) { - Optimizer optimizer(*planner.binder, *this); - plan = optimizer.Optimize(std::move(plan)); - } - - ColumnBindingResolver resolver; - resolver.Verify(*plan); - resolver.VisitOperator(*plan); - - plan->ResolveOperatorTypes(); - }); - return plan; -} - -unique_ptr ClientContext::PrepareInternal(ClientContextLock &lock, - unique_ptr statement) { - auto n_param = statement->n_param; - auto named_param_map = std::move(statement->named_param_map); - auto statement_query = statement->query; - shared_ptr prepared_data; - auto unbound_statement = statement->Copy(); - RunFunctionInTransactionInternal( - lock, [&]() { prepared_data = CreatePreparedStatement(lock, statement_query, std::move(statement)); }, false); - 
prepared_data->unbound_statement = std::move(unbound_statement); - return make_uniq(shared_from_this(), std::move(prepared_data), std::move(statement_query), - n_param, std::move(named_param_map)); -} - -unique_ptr ClientContext::Prepare(unique_ptr statement) { - auto lock = LockContext(); - // prepare the query - try { - InitialCleanup(*lock); - return PrepareInternal(*lock, std::move(statement)); - } catch (const Exception &ex) { - return make_uniq(PreservedError(ex)); - } catch (std::exception &ex) { - return make_uniq(PreservedError(ex)); - } -} - -unique_ptr ClientContext::Prepare(const string &query) { - auto lock = LockContext(); - // prepare the query - try { - InitialCleanup(*lock); - - // first parse the query - auto statements = ParseStatementsInternal(*lock, query); - if (statements.empty()) { - throw Exception("No statement to prepare!"); - } - if (statements.size() > 1) { - throw Exception("Cannot prepare multiple statements at once!"); - } - return PrepareInternal(*lock, std::move(statements[0])); - } catch (const Exception &ex) { - return make_uniq(PreservedError(ex)); - } catch (std::exception &ex) { - return make_uniq(PreservedError(ex)); - } -} - -unique_ptr ClientContext::PendingQueryPreparedInternal(ClientContextLock &lock, const string &query, - shared_ptr &prepared, - const PendingQueryParameters ¶meters) { - try { - InitialCleanup(lock); - } catch (const Exception &ex) { - return make_uniq(PreservedError(ex)); - } catch (std::exception &ex) { - return make_uniq(PreservedError(ex)); - } - return PendingStatementOrPreparedStatementInternal(lock, query, nullptr, prepared, parameters); -} - -unique_ptr ClientContext::PendingQuery(const string &query, - shared_ptr &prepared, - const PendingQueryParameters ¶meters) { - auto lock = LockContext(); - return PendingQueryPreparedInternal(*lock, query, prepared, parameters); -} - -unique_ptr ClientContext::Execute(const string &query, shared_ptr &prepared, - const PendingQueryParameters ¶meters) { - 
auto lock = LockContext(); - auto pending = PendingQueryPreparedInternal(*lock, query, prepared, parameters); - if (pending->HasError()) { - return make_uniq(pending->GetErrorObject()); - } - return pending->ExecuteInternal(*lock); -} - -unique_ptr ClientContext::Execute(const string &query, shared_ptr &prepared, - case_insensitive_map_t &values, bool allow_stream_result) { - PendingQueryParameters parameters; - parameters.parameters = &values; - parameters.allow_stream_result = allow_stream_result; - return Execute(query, prepared, parameters); -} - -unique_ptr ClientContext::PendingStatementInternal(ClientContextLock &lock, const string &query, - unique_ptr statement, - const PendingQueryParameters ¶meters) { - // prepare the query for execution - auto prepared = CreatePreparedStatement(lock, query, std::move(statement), parameters.parameters); - idx_t parameter_count = !parameters.parameters ? 0 : parameters.parameters->size(); - if (prepared->properties.parameter_count > 0 && parameter_count == 0) { - string error_message = StringUtil::Format("Expected %lld parameters, but none were supplied", - prepared->properties.parameter_count); - return make_uniq(PreservedError(error_message)); - } - if (!prepared->properties.bound_all_parameters) { - return make_uniq(PreservedError("Not all parameters were bound")); - } - // execute the prepared statement - return PendingPreparedStatement(lock, std::move(prepared), parameters); -} - -unique_ptr ClientContext::RunStatementInternal(ClientContextLock &lock, const string &query, - unique_ptr statement, - bool allow_stream_result, bool verify) { - PendingQueryParameters parameters; - parameters.allow_stream_result = allow_stream_result; - auto pending = PendingQueryInternal(lock, std::move(statement), parameters, verify); - if (pending->HasError()) { - return make_uniq(pending->GetErrorObject()); - } - return ExecutePendingQueryInternal(lock, *pending); -} - -bool ClientContext::IsActiveResult(ClientContextLock &lock, 
BaseQueryResult *result) { - if (!active_query) { - return false; - } - return active_query->open_result == result; -} - -unique_ptr ClientContext::PendingStatementOrPreparedStatementInternal( - ClientContextLock &lock, const string &query, unique_ptr statement, - shared_ptr &prepared, const PendingQueryParameters ¶meters) { - // check if we are on AutoCommit. In this case we should start a transaction. - if (statement && config.AnyVerification()) { - // query verification is enabled - // create a copy of the statement, and use the copy - // this way we verify that the copy correctly copies all properties - auto copied_statement = statement->Copy(); - switch (statement->type) { - case StatementType::SELECT_STATEMENT: { - // in case this is a select query, we verify the original statement - PreservedError error; - try { - error = VerifyQuery(lock, query, std::move(statement)); - } catch (const Exception &ex) { - error = PreservedError(ex); - } catch (std::exception &ex) { - error = PreservedError(ex); - } - if (error) { - // error in verifying query - return make_uniq(error); - } - statement = std::move(copied_statement); - break; - } -#ifndef DUCKDB_ALTERNATIVE_VERIFY - case StatementType::COPY_STATEMENT: - case StatementType::INSERT_STATEMENT: - case StatementType::DELETE_STATEMENT: - case StatementType::UPDATE_STATEMENT: { - Parser parser; - PreservedError error; - try { - parser.ParseQuery(statement->ToString()); - } catch (const Exception &ex) { - error = PreservedError(ex); - } catch (std::exception &ex) { - error = PreservedError(ex); - } - if (error) { - // error in verifying query - return make_uniq(error); - } - statement = std::move(parser.statements[0]); - break; - } -#endif - default: - statement = std::move(copied_statement); - break; - } - } - return PendingStatementOrPreparedStatement(lock, query, std::move(statement), prepared, parameters); -} - -unique_ptr ClientContext::PendingStatementOrPreparedStatement( - ClientContextLock &lock, const string 
&query, unique_ptr statement, - shared_ptr &prepared, const PendingQueryParameters ¶meters) { - unique_ptr result; - - try { - BeginQueryInternal(lock, query); - } catch (FatalException &ex) { - // fatal exceptions invalidate the entire database - auto &db = DatabaseInstance::GetDatabase(*this); - ValidChecker::Invalidate(db, ex.what()); - result = make_uniq(PreservedError(ex)); - return result; - } catch (const Exception &ex) { - return make_uniq(PreservedError(ex)); - } catch (std::exception &ex) { - return make_uniq(PreservedError(ex)); - } - // start the profiler - auto &profiler = QueryProfiler::Get(*this); - profiler.StartQuery(query, IsExplainAnalyze(statement ? statement.get() : prepared->unbound_statement.get())); - - bool invalidate_query = true; - try { - if (statement) { - result = PendingStatementInternal(lock, query, std::move(statement), parameters); - } else { - if (prepared->RequireRebind(*this, parameters.parameters)) { - // catalog was modified: rebind the statement before execution - auto new_prepared = - CreatePreparedStatement(lock, query, prepared->unbound_statement->Copy(), parameters.parameters); - D_ASSERT(new_prepared->properties.bound_all_parameters); - new_prepared->unbound_statement = std::move(prepared->unbound_statement); - prepared = std::move(new_prepared); - prepared->properties.bound_all_parameters = false; - } - result = PendingPreparedStatement(lock, prepared, parameters); - } - } catch (StandardException &ex) { - // standard exceptions do not invalidate the current transaction - result = make_uniq(PreservedError(ex)); - invalidate_query = false; - } catch (FatalException &ex) { - // fatal exceptions invalidate the entire database - if (!config.query_verification_enabled) { - auto &db = DatabaseInstance::GetDatabase(*this); - ValidChecker::Invalidate(db, ex.what()); - } - result = make_uniq(PreservedError(ex)); - } catch (const Exception &ex) { - // other types of exceptions do invalidate the current transaction - result = 
make_uniq(PreservedError(ex)); - } catch (std::exception &ex) { - // other types of exceptions do invalidate the current transaction - result = make_uniq(PreservedError(ex)); - } - if (result->HasError()) { - // query failed: abort now - EndQueryInternal(lock, false, invalidate_query); - return result; - } - D_ASSERT(active_query->open_result == result.get()); - return result; -} - -void ClientContext::LogQueryInternal(ClientContextLock &, const string &query) { - if (!client_data->log_query_writer) { -#ifdef DUCKDB_FORCE_QUERY_LOG - try { - string log_path(DUCKDB_FORCE_QUERY_LOG); - client_data->log_query_writer = - make_uniq(FileSystem::GetFileSystem(*this), log_path, - BufferedFileWriter::DEFAULT_OPEN_FLAGS, client_data->file_opener.get()); - } catch (...) { - return; - } -#else - return; -#endif - } - // log query path is set: log the query - client_data->log_query_writer->WriteData(const_data_ptr_cast(query.c_str()), query.size()); - client_data->log_query_writer->WriteData(const_data_ptr_cast("\n"), 1); - client_data->log_query_writer->Flush(); - client_data->log_query_writer->Sync(); -} - -unique_ptr ClientContext::Query(unique_ptr statement, bool allow_stream_result) { - auto pending_query = PendingQuery(std::move(statement), allow_stream_result); - if (pending_query->HasError()) { - return make_uniq(pending_query->GetErrorObject()); - } - return pending_query->Execute(); -} - -unique_ptr ClientContext::Query(const string &query, bool allow_stream_result) { - auto lock = LockContext(); - - PreservedError error; - vector> statements; - if (!ParseStatements(*lock, query, statements, error)) { - return make_uniq(std::move(error)); - } - if (statements.empty()) { - // no statements, return empty successful result - StatementProperties properties; - vector names; - auto collection = make_uniq(Allocator::DefaultAllocator()); - return make_uniq(StatementType::INVALID_STATEMENT, properties, std::move(names), - std::move(collection), GetClientProperties()); - } - - 
unique_ptr result; - QueryResult *last_result = nullptr; - bool last_had_result = false; - for (idx_t i = 0; i < statements.size(); i++) { - auto &statement = statements[i]; - bool is_last_statement = i + 1 == statements.size(); - PendingQueryParameters parameters; - parameters.allow_stream_result = allow_stream_result && is_last_statement; - auto pending_query = PendingQueryInternal(*lock, std::move(statement), parameters); - auto has_result = pending_query->properties.return_type == StatementReturnType::QUERY_RESULT; - unique_ptr current_result; - if (pending_query->HasError()) { - current_result = make_uniq(pending_query->GetErrorObject()); - } else { - current_result = ExecutePendingQueryInternal(*lock, *pending_query); - } - // now append the result to the list of results - if (!last_result || !last_had_result) { - // first result of the query - result = std::move(current_result); - last_result = result.get(); - last_had_result = has_result; - } else { - // later results; attach to the result chain - // but only if there is a result - if (!has_result) { - continue; - } - last_result->next = std::move(current_result); - last_result = last_result->next.get(); - } - } - return result; -} - -bool ClientContext::ParseStatements(ClientContextLock &lock, const string &query, - vector> &result, PreservedError &error) { - try { - InitialCleanup(lock); - // parse the query and transform it into a set of statements - result = ParseStatementsInternal(lock, query); - return true; - } catch (const Exception &ex) { - error = PreservedError(ex); - return false; - } catch (std::exception &ex) { - error = PreservedError(ex); - return false; - } -} - -unique_ptr ClientContext::PendingQuery(const string &query, bool allow_stream_result) { - auto lock = LockContext(); - - PreservedError error; - vector> statements; - if (!ParseStatements(*lock, query, statements, error)) { - return make_uniq(std::move(error)); - } - if (statements.size() != 1) { - return 
make_uniq(PreservedError("PendingQuery can only take a single statement")); - } - PendingQueryParameters parameters; - parameters.allow_stream_result = allow_stream_result; - return PendingQueryInternal(*lock, std::move(statements[0]), parameters); -} - -unique_ptr ClientContext::PendingQuery(unique_ptr statement, - bool allow_stream_result) { - auto lock = LockContext(); - PendingQueryParameters parameters; - parameters.allow_stream_result = allow_stream_result; - return PendingQueryInternal(*lock, std::move(statement), parameters); -} - -unique_ptr ClientContext::PendingQueryInternal(ClientContextLock &lock, - unique_ptr statement, - const PendingQueryParameters ¶meters, - bool verify) { - auto query = statement->query; - shared_ptr prepared; - if (verify) { - return PendingStatementOrPreparedStatementInternal(lock, query, std::move(statement), prepared, parameters); - } else { - return PendingStatementOrPreparedStatement(lock, query, std::move(statement), prepared, parameters); - } -} - -unique_ptr ClientContext::ExecutePendingQueryInternal(ClientContextLock &lock, PendingQueryResult &query) { - return query.ExecuteInternal(lock); -} - -void ClientContext::Interrupt() { - interrupted = true; -} - -void ClientContext::EnableProfiling() { - auto lock = LockContext(); - auto &config = ClientConfig::GetConfig(*this); - config.enable_profiler = true; - config.emit_profiler_output = true; -} - -void ClientContext::DisableProfiling() { - auto lock = LockContext(); - auto &config = ClientConfig::GetConfig(*this); - config.enable_profiler = false; -} - -void ClientContext::RegisterFunction(CreateFunctionInfo &info) { - RunFunctionInTransaction([&]() { - auto existing_function = Catalog::GetEntry(*this, INVALID_CATALOG, info.schema, - info.name, OnEntryNotFound::RETURN_NULL); - if (existing_function) { - auto &new_info = info.Cast(); - if (new_info.functions.MergeFunctionSet(existing_function->functions)) { - // function info was updated from catalog entry, rewrite is 
needed - info.on_conflict = OnCreateConflict::REPLACE_ON_CONFLICT; - } - } - // create function - auto &catalog = Catalog::GetSystemCatalog(*this); - catalog.CreateFunction(*this, info); - }); -} - -void ClientContext::RunFunctionInTransactionInternal(ClientContextLock &lock, const std::function &fun, - bool requires_valid_transaction) { - if (requires_valid_transaction && transaction.HasActiveTransaction() && - ValidChecker::IsInvalidated(ActiveTransaction())) { - throw TransactionException(ErrorManager::FormatException(*this, ErrorType::INVALIDATED_TRANSACTION)); - } - // check if we are on AutoCommit. In this case we should start a transaction - bool require_new_transaction = transaction.IsAutoCommit() && !transaction.HasActiveTransaction(); - if (require_new_transaction) { - D_ASSERT(!active_query); - transaction.BeginTransaction(); - } - try { - fun(); - } catch (StandardException &ex) { - if (require_new_transaction) { - transaction.Rollback(); - } - throw; - } catch (FatalException &ex) { - auto &db = DatabaseInstance::GetDatabase(*this); - ValidChecker::Invalidate(db, ex.what()); - throw; - } catch (std::exception &ex) { - if (require_new_transaction) { - transaction.Rollback(); - } else { - ValidChecker::Invalidate(ActiveTransaction(), ex.what()); - } - throw; - } - if (require_new_transaction) { - transaction.Commit(); - } -} - -void ClientContext::RunFunctionInTransaction(const std::function &fun, bool requires_valid_transaction) { - auto lock = LockContext(); - RunFunctionInTransactionInternal(*lock, fun, requires_valid_transaction); -} - -unique_ptr ClientContext::TableInfo(const string &schema_name, const string &table_name) { - unique_ptr result; - RunFunctionInTransaction([&]() { - // obtain the table info - auto table = Catalog::GetEntry(*this, INVALID_CATALOG, schema_name, table_name, - OnEntryNotFound::RETURN_NULL); - if (!table) { - return; - } - // write the table info to the result - result = make_uniq(); - result->schema = schema_name; - 
result->table = table_name; - for (auto &column : table->GetColumns().Logical()) { - result->columns.emplace_back(column.Name(), column.Type()); - } - }); - return result; -} - -void ClientContext::Append(TableDescription &description, ColumnDataCollection &collection) { - RunFunctionInTransaction([&]() { - auto &table_entry = - Catalog::GetEntry(*this, INVALID_CATALOG, description.schema, description.table); - // verify that the table columns and types match up - if (description.columns.size() != table_entry.GetColumns().PhysicalColumnCount()) { - throw Exception("Failed to append: table entry has different number of columns!"); - } - for (idx_t i = 0; i < description.columns.size(); i++) { - if (description.columns[i].Type() != table_entry.GetColumns().GetColumn(PhysicalIndex(i)).Type()) { - throw Exception("Failed to append: table entry has different number of columns!"); - } - } - table_entry.GetStorage().LocalAppend(table_entry, *this, collection); - }); -} - -void ClientContext::TryBindRelation(Relation &relation, vector &result_columns) { -#ifdef DEBUG - D_ASSERT(!relation.GetAlias().empty()); - D_ASSERT(!relation.ToString().empty()); -#endif - client_data->http_state = make_shared(); - RunFunctionInTransaction([&]() { - // bind the expressions - auto binder = Binder::CreateBinder(*this); - auto result = relation.Bind(*binder); - D_ASSERT(result.names.size() == result.types.size()); - - result_columns.reserve(result_columns.size() + result.names.size()); - for (idx_t i = 0; i < result.names.size(); i++) { - result_columns.emplace_back(result.names[i], result.types[i]); - } - }); -} - -unordered_set ClientContext::GetTableNames(const string &query) { - auto lock = LockContext(); - - auto statements = ParseStatementsInternal(*lock, query); - if (statements.size() != 1) { - throw InvalidInputException("Expected a single statement"); - } - - unordered_set result; - RunFunctionInTransactionInternal(*lock, [&]() { - // bind the expressions - auto binder = 
Binder::CreateBinder(*this); - binder->SetBindingMode(BindingMode::EXTRACT_NAMES); - binder->Bind(*statements[0]); - result = binder->GetTableNames(); - }); - return result; -} - -unique_ptr ClientContext::PendingQueryInternal(ClientContextLock &lock, - const shared_ptr &relation, - bool allow_stream_result) { - InitialCleanup(lock); - - string query; - if (config.query_verification_enabled) { - // run the ToString method of any relation we run, mostly to ensure it doesn't crash - relation->ToString(); - relation->GetAlias(); - if (relation->IsReadOnly()) { - // verify read only statements by running a select statement - auto select = make_uniq(); - select->node = relation->GetQueryNode(); - RunStatementInternal(lock, query, std::move(select), false); - } - } - - auto relation_stmt = make_uniq(relation); - PendingQueryParameters parameters; - parameters.allow_stream_result = allow_stream_result; - return PendingQueryInternal(lock, std::move(relation_stmt), parameters); -} - -unique_ptr ClientContext::PendingQuery(const shared_ptr &relation, - bool allow_stream_result) { - auto lock = LockContext(); - return PendingQueryInternal(*lock, relation, allow_stream_result); -} - -unique_ptr ClientContext::Execute(const shared_ptr &relation) { - auto lock = LockContext(); - auto &expected_columns = relation->Columns(); - auto pending = PendingQueryInternal(*lock, relation, false); - if (!pending->success) { - return make_uniq(pending->GetErrorObject()); - } - - unique_ptr result; - result = ExecutePendingQueryInternal(*lock, *pending); - if (result->HasError()) { - return result; - } - // verify that the result types and result names of the query match the expected result types/names - if (result->types.size() == expected_columns.size()) { - bool mismatch = false; - for (idx_t i = 0; i < result->types.size(); i++) { - if (result->types[i] != expected_columns[i].Type() || result->names[i] != expected_columns[i].Name()) { - mismatch = true; - break; - } - } - if (!mismatch) { 
- // all is as expected: return the result - return result; - } - } - // result mismatch - string err_str = "Result mismatch in query!\nExpected the following columns: ["; - for (idx_t i = 0; i < expected_columns.size(); i++) { - if (i > 0) { - err_str += ", "; - } - err_str += expected_columns[i].Name() + " " + expected_columns[i].Type().ToString(); - } - err_str += "]\nBut result contained the following: "; - for (idx_t i = 0; i < result->types.size(); i++) { - err_str += i == 0 ? "[" : ", "; - err_str += result->names[i] + " " + result->types[i].ToString(); - } - err_str += "]"; - return make_uniq(PreservedError(err_str)); -} - -bool ClientContext::TryGetCurrentSetting(const std::string &key, Value &result) { - // first check the built-in settings - auto &db_config = DBConfig::GetConfig(*this); - auto option = db_config.GetOptionByName(key); - if (option) { - result = option->get_setting(*this); - return true; - } - - // check the client session values - const auto &session_config_map = config.set_variables; - - auto session_value = session_config_map.find(key); - bool found_session_value = session_value != session_config_map.end(); - if (found_session_value) { - result = session_value->second; - return true; - } - // finally check the global session values - return db->TryGetCurrentSetting(key, result); -} - -ParserOptions ClientContext::GetParserOptions() const { - auto &client_config = ClientConfig::GetConfig(*this); - ParserOptions options; - options.preserve_identifier_case = client_config.preserve_identifier_case; - options.integer_division = client_config.integer_division; - options.max_expression_depth = client_config.max_expression_depth; - options.extensions = &DBConfig::GetConfig(*this).parser_extensions; - return options; -} - -ClientProperties ClientContext::GetClientProperties() const { - string timezone = "UTC"; - Value result; - // 1) Check Set Variable - auto &client_config = ClientConfig::GetConfig(*this); - auto tz_config = 
client_config.set_variables.find("timezone"); - if (tz_config == client_config.set_variables.end()) { - // 2) Check for Default Value - auto default_value = db->config.extension_parameters.find("timezone"); - if (default_value != db->config.extension_parameters.end()) { - timezone = default_value->second.default_value.GetValue(); - } - } else { - timezone = tz_config->second.GetValue(); - } - return {timezone, db->config.options.arrow_offset_size}; -} - -bool ClientContext::ExecutionIsFinished() { - if (!active_query || !active_query->executor) { - return false; - } - return active_query->executor->ExecutionIsFinished(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -bool ClientContextFileOpener::TryGetCurrentSetting(const string &key, Value &result) { - return context.TryGetCurrentSetting(key, result); -} - -// LCOV_EXCL_START -bool ClientContextFileOpener::TryGetCurrentSetting(const string &key, Value &result, FileOpenerInfo &) { - return context.TryGetCurrentSetting(key, result); -} - -ClientContext *FileOpener::TryGetClientContext(FileOpener *opener) { - if (!opener) { - return nullptr; - } - return opener->TryGetClientContext(); -} - -bool FileOpener::TryGetCurrentSetting(FileOpener *opener, const string &key, Value &result) { - if (!opener) { - return false; - } - return opener->TryGetCurrentSetting(key, result); -} - -bool FileOpener::TryGetCurrentSetting(FileOpener *opener, const string &key, Value &result, FileOpenerInfo &info) { - if (!opener) { - return false; - } - return opener->TryGetCurrentSetting(key, result, info); -} - -bool FileOpener::TryGetCurrentSetting(const string &key, Value &result, FileOpenerInfo &info) { - return this->TryGetCurrentSetting(key, result); -} -// LCOV_EXCL_STOP -} // namespace duckdb - - - - - - - - - - - - - - - -namespace duckdb { - -class ClientFileSystem : public OpenerFileSystem { -public: - explicit ClientFileSystem(ClientContext &context_p) : context(context_p) { - } - - FileSystem &GetFileSystem() 
const override { - auto &config = DBConfig::GetConfig(context); - return *config.file_system; - } - optional_ptr GetOpener() const override { - return ClientData::Get(context).file_opener.get(); - } - -private: - ClientContext &context; -}; - -ClientData::ClientData(ClientContext &context) : catalog_search_path(make_uniq(context)) { - auto &db = DatabaseInstance::GetDatabase(context); - profiler = make_shared(context); - query_profiler_history = make_uniq(); - temporary_objects = make_shared(db, AttachedDatabaseType::TEMP_DATABASE); - temporary_objects->oid = DatabaseManager::Get(db).ModifyCatalog(); - random_engine = make_uniq(); - file_opener = make_uniq(context); - client_file_system = make_uniq(context); - temporary_objects->Initialize(); -} -ClientData::~ClientData() { -} - -ClientData &ClientData::Get(ClientContext &context) { - return *context.client_data; -} - -RandomEngine &RandomEngine::Get(ClientContext &context) { - return *ClientData::Get(context).random_engine; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -static void ThrowIfExceptionIsInternal(StatementVerifier &verifier) { - if (!verifier.materialized_result) { - return; - } - auto &result = *verifier.materialized_result; - if (!result.HasError()) { - return; - } - auto &error = result.GetErrorObject(); - if (error.Type() == ExceptionType::INTERNAL) { - error.Throw(); - } -} - -PreservedError ClientContext::VerifyQuery(ClientContextLock &lock, const string &query, - unique_ptr statement) { - D_ASSERT(statement->type == StatementType::SELECT_STATEMENT); - // Aggressive query verification - - // The purpose of this function is to test correctness of otherwise hard to test features: - // Copy() of statements and expressions - // Serialize()/Deserialize() of expressions - // Hash() of expressions - // Equality() of statements and expressions - // ToString() of statements and expressions - // Correctness of plans both with and without optimizers - - const auto &stmt = *statement; - 
vector> statement_verifiers; - unique_ptr prepared_statement_verifier; - if (config.query_verification_enabled) { - statement_verifiers.emplace_back(StatementVerifier::Create(VerificationType::COPIED, stmt)); - statement_verifiers.emplace_back(StatementVerifier::Create(VerificationType::DESERIALIZED, stmt)); - statement_verifiers.emplace_back(StatementVerifier::Create(VerificationType::UNOPTIMIZED, stmt)); - prepared_statement_verifier = StatementVerifier::Create(VerificationType::PREPARED, stmt); -#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE - // This verification is quite slow, so we only run it for the async sink/source debug mode - statement_verifiers.emplace_back(StatementVerifier::Create(VerificationType::NO_OPERATOR_CACHING, stmt)); -#endif - } - if (config.verify_external) { - statement_verifiers.emplace_back(StatementVerifier::Create(VerificationType::EXTERNAL, stmt)); - } - - auto original = make_uniq(std::move(statement)); - for (auto &verifier : statement_verifiers) { - original->CheckExpressions(*verifier); - } - original->CheckExpressions(); - - // See below - auto statement_copy_for_explain = stmt.Copy(); - - // Save settings - bool optimizer_enabled = config.enable_optimizer; - bool profiling_is_enabled = config.enable_profiler; - bool force_external = config.force_external; - - // Disable profiling if it is enabled - if (profiling_is_enabled) { - config.enable_profiler = false; - } - - // Execute the original statement - bool any_failed = original->Run(*this, query, [&](const string &q, unique_ptr s) { - return RunStatementInternal(lock, q, std::move(s), false, false); - }); - if (!any_failed) { - statement_verifiers.emplace_back( - StatementVerifier::Create(VerificationType::PARSED, *statement_copy_for_explain)); - } - // Execute the verifiers - for (auto &verifier : statement_verifiers) { - bool failed = verifier->Run(*this, query, [&](const string &q, unique_ptr s) { - return RunStatementInternal(lock, q, std::move(s), false, false); - }); - any_failed 
= any_failed || failed; - } - - if (!any_failed && prepared_statement_verifier) { - // If none failed, we execute the prepared statement verifier - bool failed = prepared_statement_verifier->Run(*this, query, [&](const string &q, unique_ptr s) { - return RunStatementInternal(lock, q, std::move(s), false, false); - }); - if (!failed) { - // PreparedStatementVerifier fails if it runs into a ParameterNotAllowedException, which is OK - statement_verifiers.push_back(std::move(prepared_statement_verifier)); - } else { - // If it does fail, let's make sure it's not an internal exception - ThrowIfExceptionIsInternal(*prepared_statement_verifier); - } - } else { - if (ValidChecker::IsInvalidated(*db)) { - return original->materialized_result->GetErrorObject(); - } - } - - // Restore config setting - config.enable_optimizer = optimizer_enabled; - config.force_external = force_external; - - // Check explain, only if q does not already contain EXPLAIN - if (original->materialized_result->success) { - auto explain_q = "EXPLAIN " + query; - auto explain_stmt = make_uniq(std::move(statement_copy_for_explain)); - try { - RunStatementInternal(lock, explain_q, std::move(explain_stmt), false, false); - } catch (std::exception &ex) { // LCOV_EXCL_START - interrupted = false; - return PreservedError("EXPLAIN failed but query did not (" + string(ex.what()) + ")"); - } // LCOV_EXCL_STOP - -#ifdef DUCKDB_VERIFY_BOX_RENDERER - // this is pretty slow, so disabled by default - // test the box renderer on the result - // we mostly care that this does not crash - RandomEngine random; - BoxRendererConfig config; - // test with a random width - config.max_width = random.NextRandomInteger() % 500; - BoxRenderer renderer(config); - renderer.ToString(*this, original->materialized_result->names, original->materialized_result->Collection()); -#endif - } - - // Restore profiler setting - if (profiling_is_enabled) { - config.enable_profiler = true; - } - - // Now compare the results - // The results of 
all runs should be identical - for (auto &verifier : statement_verifiers) { - auto result = original->CompareResults(*verifier); - if (!result.empty()) { - return PreservedError(result); - } - } - - return PreservedError(); -} - -} // namespace duckdb - - - - - - - -#ifndef DUCKDB_NO_THREADS - -#endif - -#include -#include - -namespace duckdb { - -#ifdef DEBUG -bool DBConfigOptions::debug_print_bindings = false; -#endif - -#define DUCKDB_GLOBAL(_PARAM) \ - { \ - _PARAM::Name, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, nullptr, _PARAM::ResetGlobal, \ - nullptr, _PARAM::GetSetting \ - } -#define DUCKDB_GLOBAL_ALIAS(_ALIAS, _PARAM) \ - { \ - _ALIAS, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, nullptr, _PARAM::ResetGlobal, nullptr, \ - _PARAM::GetSetting \ - } - -#define DUCKDB_LOCAL(_PARAM) \ - { \ - _PARAM::Name, _PARAM::Description, _PARAM::InputType, nullptr, _PARAM::SetLocal, nullptr, _PARAM::ResetLocal, \ - _PARAM::GetSetting \ - } -#define DUCKDB_LOCAL_ALIAS(_ALIAS, _PARAM) \ - { \ - _ALIAS, _PARAM::Description, _PARAM::InputType, nullptr, _PARAM::SetLocal, nullptr, _PARAM::ResetLocal, \ - _PARAM::GetSetting \ - } - -#define DUCKDB_GLOBAL_LOCAL(_PARAM) \ - { \ - _PARAM::Name, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, _PARAM::SetLocal, \ - _PARAM::ResetGlobal, _PARAM::ResetLocal, _PARAM::GetSetting \ - } -#define DUCKDB_GLOBAL_LOCAL_ALIAS(_ALIAS, _PARAM) \ - { \ - _ALIAS, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, _PARAM::SetLocal, _PARAM::ResetGlobal, \ - _PARAM::ResetLocal, _PARAM::GetSetting \ - } -#define FINAL_SETTING \ - { nullptr, nullptr, LogicalTypeId::INVALID, nullptr, nullptr, nullptr, nullptr, nullptr } - -static ConfigurationOption internal_options[] = {DUCKDB_GLOBAL(AccessModeSetting), - DUCKDB_GLOBAL(CheckpointThresholdSetting), - DUCKDB_GLOBAL(DebugCheckpointAbort), - DUCKDB_LOCAL(DebugForceExternal), - DUCKDB_LOCAL(DebugForceNoCrossProduct), - DUCKDB_LOCAL(DebugAsOfIEJoin), - 
DUCKDB_LOCAL(PreferRangeJoins), - DUCKDB_GLOBAL(DebugWindowMode), - DUCKDB_GLOBAL_LOCAL(DefaultCollationSetting), - DUCKDB_GLOBAL(DefaultOrderSetting), - DUCKDB_GLOBAL(DefaultNullOrderSetting), - DUCKDB_GLOBAL(DisabledFileSystemsSetting), - DUCKDB_GLOBAL(DisabledOptimizersSetting), - DUCKDB_GLOBAL(EnableExternalAccessSetting), - DUCKDB_GLOBAL(EnableFSSTVectors), - DUCKDB_GLOBAL(AllowUnsignedExtensionsSetting), - DUCKDB_LOCAL(CustomExtensionRepository), - DUCKDB_LOCAL(AutoloadExtensionRepository), - DUCKDB_GLOBAL(AutoinstallKnownExtensions), - DUCKDB_GLOBAL(AutoloadKnownExtensions), - DUCKDB_GLOBAL(EnableObjectCacheSetting), - DUCKDB_GLOBAL(EnableHTTPMetadataCacheSetting), - DUCKDB_LOCAL(EnableProfilingSetting), - DUCKDB_LOCAL(EnableProgressBarSetting), - DUCKDB_LOCAL(EnableProgressBarPrintSetting), - DUCKDB_LOCAL(ExplainOutputSetting), - DUCKDB_GLOBAL(ExtensionDirectorySetting), - DUCKDB_GLOBAL(ExternalThreadsSetting), - DUCKDB_LOCAL(FileSearchPathSetting), - DUCKDB_GLOBAL(ForceCompressionSetting), - DUCKDB_GLOBAL(ForceBitpackingModeSetting), - DUCKDB_LOCAL(HomeDirectorySetting), - DUCKDB_LOCAL(LogQueryPathSetting), - DUCKDB_GLOBAL(LockConfigurationSetting), - DUCKDB_GLOBAL(ImmediateTransactionModeSetting), - DUCKDB_LOCAL(IntegerDivisionSetting), - DUCKDB_LOCAL(MaximumExpressionDepthSetting), - DUCKDB_GLOBAL(MaximumMemorySetting), - DUCKDB_GLOBAL_ALIAS("memory_limit", MaximumMemorySetting), - DUCKDB_GLOBAL_ALIAS("null_order", DefaultNullOrderSetting), - DUCKDB_LOCAL(OrderedAggregateThreshold), - DUCKDB_GLOBAL(PasswordSetting), - DUCKDB_LOCAL(PerfectHashThresholdSetting), - DUCKDB_LOCAL(PivotFilterThreshold), - DUCKDB_LOCAL(PivotLimitSetting), - DUCKDB_LOCAL(PreserveIdentifierCase), - DUCKDB_GLOBAL(PreserveInsertionOrder), - DUCKDB_LOCAL(ProfilerHistorySize), - DUCKDB_LOCAL(ProfileOutputSetting), - DUCKDB_LOCAL(ProfilingModeSetting), - DUCKDB_LOCAL_ALIAS("profiling_output", ProfileOutputSetting), - DUCKDB_LOCAL(ProgressBarTimeSetting), - DUCKDB_LOCAL(SchemaSetting), 
- DUCKDB_LOCAL(SearchPathSetting), - DUCKDB_GLOBAL(TempDirectorySetting), - DUCKDB_GLOBAL(ThreadsSetting), - DUCKDB_GLOBAL(UsernameSetting), - DUCKDB_GLOBAL(ExportLargeBufferArrow), - DUCKDB_GLOBAL_ALIAS("user", UsernameSetting), - DUCKDB_GLOBAL_ALIAS("wal_autocheckpoint", CheckpointThresholdSetting), - DUCKDB_GLOBAL_ALIAS("worker_threads", ThreadsSetting), - DUCKDB_GLOBAL(FlushAllocatorSetting), - FINAL_SETTING}; - -vector DBConfig::GetOptions() { - vector options; - for (idx_t index = 0; internal_options[index].name; index++) { - options.push_back(internal_options[index]); - } - return options; -} - -idx_t DBConfig::GetOptionCount() { - idx_t count = 0; - for (idx_t index = 0; internal_options[index].name; index++) { - count++; - } - return count; -} - -vector DBConfig::GetOptionNames() { - vector names; - for (idx_t i = 0, option_count = DBConfig::GetOptionCount(); i < option_count; i++) { - names.emplace_back(DBConfig::GetOptionByIndex(i)->name); - } - return names; -} - -ConfigurationOption *DBConfig::GetOptionByIndex(idx_t target_index) { - for (idx_t index = 0; internal_options[index].name; index++) { - if (index == target_index) { - return internal_options + index; - } - } - return nullptr; -} - -ConfigurationOption *DBConfig::GetOptionByName(const string &name) { - auto lname = StringUtil::Lower(name); - for (idx_t index = 0; internal_options[index].name; index++) { - D_ASSERT(StringUtil::Lower(internal_options[index].name) == string(internal_options[index].name)); - if (internal_options[index].name == lname) { - return internal_options + index; - } - } - return nullptr; -} - -void DBConfig::SetOption(const ConfigurationOption &option, const Value &value) { - SetOption(nullptr, option, value); -} - -void DBConfig::SetOptionByName(const string &name, const Value &value) { - auto option = DBConfig::GetOptionByName(name); - if (option) { - SetOption(*option, value); - } else { - options.unrecognized_options[name] = value; - } -} - -void 
DBConfig::SetOption(DatabaseInstance *db, const ConfigurationOption &option, const Value &value) { - lock_guard l(config_lock); - if (!option.set_global) { - throw InvalidInputException("Could not set option \"%s\" as a global option", option.name); - } - D_ASSERT(option.reset_global); - Value input = value.DefaultCastAs(option.parameter_type); - option.set_global(db, *this, input); -} - -void DBConfig::ResetOption(DatabaseInstance *db, const ConfigurationOption &option) { - lock_guard l(config_lock); - if (!option.reset_global) { - throw InternalException("Could not reset option \"%s\" as a global option", option.name); - } - D_ASSERT(option.set_global); - option.reset_global(db, *this); -} - -void DBConfig::SetOption(const string &name, Value value) { - lock_guard l(config_lock); - options.set_variables[name] = std::move(value); -} - -void DBConfig::ResetOption(const string &name) { - lock_guard l(config_lock); - auto extension_option = extension_parameters.find(name); - D_ASSERT(extension_option != extension_parameters.end()); - auto &default_value = extension_option->second.default_value; - if (!default_value.IsNull()) { - // Default is not NULL, override the setting - options.set_variables[name] = default_value; - } else { - // Otherwise just remove it from the 'set_variables' map - options.set_variables.erase(name); - } -} - -void DBConfig::AddExtensionOption(const string &name, string description, LogicalType parameter, - const Value &default_value, set_option_callback_t function) { - extension_parameters.insert( - make_pair(name, ExtensionOption(std::move(description), std::move(parameter), function, default_value))); - if (!default_value.IsNull()) { - // Default value is set, insert it into the 'set_variables' list - options.set_variables[name] = default_value; - } -} - -CastFunctionSet &DBConfig::GetCastFunctions() { - return *cast_functions; -} - -void DBConfig::SetDefaultMaxMemory() { - auto memory = FileSystem::GetAvailableMemory(); - if (memory != 
DConstants::INVALID_INDEX) { - options.maximum_memory = memory * 8 / 10; - } -} - -idx_t CGroupBandwidthQuota(idx_t physical_cores, FileSystem &fs) { - static constexpr const char *CPU_MAX = "/sys/fs/cgroup/cpu.max"; - static constexpr const char *CFS_QUOTA = "/sys/fs/cgroup/cpu/cpu.cfs_quota_us"; - static constexpr const char *CFS_PERIOD = "/sys/fs/cgroup/cpu/cpu.cfs_period_us"; - - int64_t quota, period; - char byte_buffer[1000]; - unique_ptr handle; - int64_t read_bytes; - - if (fs.FileExists(CPU_MAX)) { - // cgroup v2 - // https://www.kernel.org/doc/html/latest/admin-guide/cgroup-v2.html - handle = - fs.OpenFile(CPU_MAX, FileFlags::FILE_FLAGS_READ, FileSystem::DEFAULT_LOCK, FileSystem::DEFAULT_COMPRESSION); - read_bytes = fs.Read(*handle, (void *)byte_buffer, 999); - byte_buffer[read_bytes] = '\0'; - if (std::sscanf(byte_buffer, "%" SCNd64 " %" SCNd64 "", "a, &period) != 2) { - return physical_cores; - } - } else if (fs.FileExists(CFS_QUOTA) && fs.FileExists(CFS_PERIOD)) { - // cgroup v1 - // https://www.kernel.org/doc/html/latest/scheduler/sched-bwc.html#management - - // Read the quota, this indicates how many microseconds the CPU can be utilized by this cgroup per period - handle = fs.OpenFile(CFS_QUOTA, FileFlags::FILE_FLAGS_READ, FileSystem::DEFAULT_LOCK, - FileSystem::DEFAULT_COMPRESSION); - read_bytes = fs.Read(*handle, (void *)byte_buffer, 999); - byte_buffer[read_bytes] = '\0'; - if (std::sscanf(byte_buffer, "%" SCNd64 "", "a) != 1) { - return physical_cores; - } - - // Read the time period, a cgroup can utilize the CPU up to quota microseconds every period - handle = fs.OpenFile(CFS_PERIOD, FileFlags::FILE_FLAGS_READ, FileSystem::DEFAULT_LOCK, - FileSystem::DEFAULT_COMPRESSION); - read_bytes = fs.Read(*handle, (void *)byte_buffer, 999); - byte_buffer[read_bytes] = '\0'; - if (std::sscanf(byte_buffer, "%" SCNd64 "", &period) != 1) { - return physical_cores; - } - } else { - // No cgroup quota - return physical_cores; - } - if (quota > 0 && period > 0) 
{ - return idx_t(std::ceil((double)quota / (double)period)); - } else { - return physical_cores; - } -} - -idx_t DBConfig::GetSystemMaxThreads(FileSystem &fs) { -#ifndef DUCKDB_NO_THREADS - idx_t physical_cores = std::thread::hardware_concurrency(); -#ifdef __linux__ - auto cores_available_per_period = CGroupBandwidthQuota(physical_cores, fs); - return MaxValue(cores_available_per_period, 1); -#else - return physical_cores; -#endif -#else - return 1; -#endif -} - -void DBConfig::SetDefaultMaxThreads() { -#ifndef DUCKDB_NO_THREADS - options.maximum_threads = GetSystemMaxThreads(*file_system); -#else - options.maximum_threads = 1; -#endif -} - -idx_t DBConfig::ParseMemoryLimit(const string &arg) { - if (arg[0] == '-' || arg == "null" || arg == "none") { - return DConstants::INVALID_INDEX; - } - // split based on the number/non-number - idx_t idx = 0; - while (StringUtil::CharacterIsSpace(arg[idx])) { - idx++; - } - idx_t num_start = idx; - while ((arg[idx] >= '0' && arg[idx] <= '9') || arg[idx] == '.' || arg[idx] == 'e' || arg[idx] == 'E' || - arg[idx] == '-') { - idx++; - } - if (idx == num_start) { - throw ParserException("Memory limit must have a number (e.g. SET memory_limit=1GB"); - } - string number = arg.substr(num_start, idx - num_start); - - // try to parse the number - double limit = Cast::Operation(string_t(number)); - - // now parse the memory limit unit (e.g. 
bytes, gb, etc) - while (StringUtil::CharacterIsSpace(arg[idx])) { - idx++; - } - idx_t start = idx; - while (idx < arg.size() && !StringUtil::CharacterIsSpace(arg[idx])) { - idx++; - } - if (limit < 0) { - // limit < 0, set limit to infinite - return (idx_t)-1; - } - string unit = StringUtil::Lower(arg.substr(start, idx - start)); - idx_t multiplier; - if (unit == "byte" || unit == "bytes" || unit == "b") { - multiplier = 1; - } else if (unit == "kilobyte" || unit == "kilobytes" || unit == "kb" || unit == "k") { - multiplier = 1000LL; - } else if (unit == "megabyte" || unit == "megabytes" || unit == "mb" || unit == "m") { - multiplier = 1000LL * 1000LL; - } else if (unit == "gigabyte" || unit == "gigabytes" || unit == "gb" || unit == "g") { - multiplier = 1000LL * 1000LL * 1000LL; - } else if (unit == "terabyte" || unit == "terabytes" || unit == "tb" || unit == "t") { - multiplier = 1000LL * 1000LL * 1000LL * 1000LL; - } else { - throw ParserException("Unknown unit for memory_limit: %s (expected: b, mb, gb or tb)", unit); - } - return (idx_t)multiplier * limit; -} - -// Right now we only really care about access mode when comparing DBConfigs -bool DBConfigOptions::operator==(const DBConfigOptions &other) const { - return other.access_mode == access_mode; -} - -bool DBConfig::operator==(const DBConfig &other) { - return other.options == options; -} - -bool DBConfig::operator!=(const DBConfig &other) { - return !(other.options == options); -} - -OrderType DBConfig::ResolveOrder(OrderType order_type) const { - if (order_type != OrderType::ORDER_DEFAULT) { - return order_type; - } - return options.default_order_type; -} - -OrderByNullType DBConfig::ResolveNullOrder(OrderType order_type, OrderByNullType null_type) const { - if (null_type != OrderByNullType::ORDER_DEFAULT) { - return null_type; - } - switch (options.default_null_order) { - case DefaultOrderByNullType::NULLS_FIRST: - return OrderByNullType::NULLS_FIRST; - case DefaultOrderByNullType::NULLS_LAST: - return 
OrderByNullType::NULLS_LAST; - case DefaultOrderByNullType::NULLS_FIRST_ON_ASC_LAST_ON_DESC: - return order_type == OrderType::ASCENDING ? OrderByNullType::NULLS_FIRST : OrderByNullType::NULLS_LAST; - case DefaultOrderByNullType::NULLS_LAST_ON_ASC_FIRST_ON_DESC: - return order_type == OrderType::ASCENDING ? OrderByNullType::NULLS_LAST : OrderByNullType::NULLS_FIRST; - default: - throw InternalException("Unknown null order setting"); - } -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -Connection::Connection(DatabaseInstance &database) : context(make_shared(database.shared_from_this())) { - ConnectionManager::Get(database).AddConnection(*context); -#ifdef DEBUG - EnableProfiling(); - context->config.emit_profiler_output = false; -#endif -} - -Connection::Connection(DuckDB &database) : Connection(*database.instance) { -} - -Connection::~Connection() { - ConnectionManager::Get(*context->db).RemoveConnection(*context); -} - -string Connection::GetProfilingInformation(ProfilerPrintFormat format) { - auto &profiler = QueryProfiler::Get(*context); - if (format == ProfilerPrintFormat::JSON) { - return profiler.ToJSON(); - } else { - return profiler.QueryTreeToString(); - } -} - -void Connection::Interrupt() { - context->Interrupt(); -} - -void Connection::EnableProfiling() { - context->EnableProfiling(); -} - -void Connection::DisableProfiling() { - context->DisableProfiling(); -} - -void Connection::EnableQueryVerification() { - ClientConfig::GetConfig(*context).query_verification_enabled = true; -} - -void Connection::DisableQueryVerification() { - ClientConfig::GetConfig(*context).query_verification_enabled = false; -} - -void Connection::ForceParallelism() { - ClientConfig::GetConfig(*context).verify_parallelism = true; -} - -unique_ptr Connection::SendQuery(const string &query) { - return context->Query(query, true); -} - -unique_ptr Connection::Query(const string &query) { - auto result = context->Query(query, false); - 
D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT); - return unique_ptr_cast(std::move(result)); -} - -DUCKDB_API string Connection::GetSubstrait(const string &query) { - vector params; - params.emplace_back(query); - auto result = TableFunction("get_substrait", params)->Execute(); - auto protobuf = result->FetchRaw()->GetValue(0, 0); - return protobuf.GetValueUnsafe().GetString(); -} - -DUCKDB_API unique_ptr Connection::FromSubstrait(const string &proto) { - vector params; - params.emplace_back(Value::BLOB_RAW(proto)); - return TableFunction("from_substrait", params)->Execute(); -} - -DUCKDB_API string Connection::GetSubstraitJSON(const string &query) { - vector params; - params.emplace_back(query); - auto result = TableFunction("get_substrait_json", params)->Execute(); - auto protobuf = result->FetchRaw()->GetValue(0, 0); - return protobuf.GetValueUnsafe().GetString(); -} - -DUCKDB_API unique_ptr Connection::FromSubstraitJSON(const string &json) { - vector params; - params.emplace_back(json); - return TableFunction("from_substrait_json", params)->Execute(); -} - -unique_ptr Connection::Query(unique_ptr statement) { - auto result = context->Query(std::move(statement), false); - D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT); - return unique_ptr_cast(std::move(result)); -} - -unique_ptr Connection::PendingQuery(const string &query, bool allow_stream_result) { - return context->PendingQuery(query, allow_stream_result); -} - -unique_ptr Connection::PendingQuery(unique_ptr statement, bool allow_stream_result) { - return context->PendingQuery(std::move(statement), allow_stream_result); -} - -unique_ptr Connection::Prepare(const string &query) { - return context->Prepare(query); -} - -unique_ptr Connection::Prepare(unique_ptr statement) { - return context->Prepare(std::move(statement)); -} - -unique_ptr Connection::QueryParamsRecursive(const string &query, vector &values) { - auto statement = Prepare(query); - if (statement->HasError()) { - 
return make_uniq(statement->error); - } - return statement->Execute(values, false); -} - -unique_ptr Connection::TableInfo(const string &table_name) { - return TableInfo(INVALID_SCHEMA, table_name); -} - -unique_ptr Connection::TableInfo(const string &schema_name, const string &table_name) { - return context->TableInfo(schema_name, table_name); -} - -vector> Connection::ExtractStatements(const string &query) { - return context->ParseStatements(query); -} - -unique_ptr Connection::ExtractPlan(const string &query) { - return context->ExtractPlan(query); -} - -void Connection::Append(TableDescription &description, DataChunk &chunk) { - if (chunk.size() == 0) { - return; - } - ColumnDataCollection collection(Allocator::Get(*context), chunk.GetTypes()); - collection.Append(chunk); - Append(description, collection); -} - -void Connection::Append(TableDescription &description, ColumnDataCollection &collection) { - context->Append(description, collection); -} - -shared_ptr Connection::Table(const string &table_name) { - return Table(DEFAULT_SCHEMA, table_name); -} - -shared_ptr Connection::Table(const string &schema_name, const string &table_name) { - auto table_info = TableInfo(schema_name, table_name); - if (!table_info) { - throw CatalogException("Table '%s' does not exist!", table_name); - } - return make_shared(context, std::move(table_info)); -} - -shared_ptr Connection::View(const string &tname) { - return View(DEFAULT_SCHEMA, tname); -} - -shared_ptr Connection::View(const string &schema_name, const string &table_name) { - return make_shared(context, schema_name, table_name); -} - -shared_ptr Connection::TableFunction(const string &fname) { - vector values; - named_parameter_map_t named_parameters; - return TableFunction(fname, values, named_parameters); -} - -shared_ptr Connection::TableFunction(const string &fname, const vector &values, - const named_parameter_map_t &named_parameters) { - return make_shared(context, fname, values, named_parameters); -} - 
-shared_ptr Connection::TableFunction(const string &fname, const vector &values) { - return make_shared(context, fname, values); -} - -shared_ptr Connection::Values(const vector> &values) { - vector column_names; - return Values(values, column_names); -} - -shared_ptr Connection::Values(const vector> &values, const vector &column_names, - const string &alias) { - return make_shared(context, values, column_names, alias); -} - -shared_ptr Connection::Values(const string &values) { - vector column_names; - return Values(values, column_names); -} - -shared_ptr Connection::Values(const string &values, const vector &column_names, const string &alias) { - return make_shared(context, values, column_names, alias); -} - -shared_ptr Connection::ReadCSV(const string &csv_file) { - named_parameter_map_t options; - return ReadCSV(csv_file, std::move(options)); -} - -shared_ptr Connection::ReadCSV(const string &csv_file, named_parameter_map_t &&options) { - return make_shared(context, csv_file, std::move(options)); -} - -shared_ptr Connection::ReadCSV(const string &csv_file, const vector &columns) { - // parse columns - vector column_list; - for (auto &column : columns) { - auto col_list = Parser::ParseColumnList(column, context->GetParserOptions()); - if (col_list.LogicalColumnCount() != 1) { - throw ParserException("Expected a single column definition"); - } - column_list.push_back(std::move(col_list.GetColumnMutable(LogicalIndex(0)))); - } - return make_shared(context, csv_file, std::move(column_list)); -} - -shared_ptr Connection::ReadParquet(const string &parquet_file, bool binary_as_string) { - vector params; - params.emplace_back(parquet_file); - named_parameter_map_t named_parameters({{"binary_as_string", Value::BOOLEAN(binary_as_string)}}); - return TableFunction("parquet_scan", params, named_parameters)->Alias(parquet_file); -} - -unordered_set Connection::GetTableNames(const string &query) { - return context->GetTableNames(query); -} - -shared_ptr 
Connection::RelationFromQuery(const string &query, const string &alias, const string &error) { - return RelationFromQuery(QueryRelation::ParseStatement(*context, query, error), alias); -} - -shared_ptr Connection::RelationFromQuery(unique_ptr select_stmt, const string &alias) { - return make_shared(context, std::move(select_stmt), alias); -} - -void Connection::BeginTransaction() { - auto result = Query("BEGIN TRANSACTION"); - if (result->HasError()) { - result->ThrowError(); - } -} - -void Connection::Commit() { - auto result = Query("COMMIT"); - if (result->HasError()) { - result->ThrowError(); - } -} - -void Connection::Rollback() { - auto result = Query("ROLLBACK"); - if (result->HasError()) { - result->ThrowError(); - } -} - -void Connection::SetAutoCommit(bool auto_commit) { - context->transaction.SetAutoCommit(auto_commit); -} - -bool Connection::IsAutoCommit() { - return context->transaction.IsAutoCommit(); -} -bool Connection::HasActiveTransaction() { - return context->transaction.HasActiveTransaction(); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - -#ifndef DUCKDB_NO_THREADS - -#endif - -namespace duckdb { - -DBConfig::DBConfig() { - compression_functions = make_uniq(); - cast_functions = make_uniq(); - error_manager = make_uniq(); -} - -DBConfig::DBConfig(std::unordered_map &config_dict, bool read_only) : DBConfig::DBConfig() { - if (read_only) { - options.access_mode = AccessMode::READ_ONLY; - } - for (auto &kv : config_dict) { - string key = kv.first; - string val = kv.second; - auto opt_val = Value(val); - DBConfig::SetOptionByName(key, opt_val); - } -} - -DBConfig::~DBConfig() { -} - -DatabaseInstance::DatabaseInstance() { -} - -DatabaseInstance::~DatabaseInstance() { -} - -BufferManager &BufferManager::GetBufferManager(DatabaseInstance &db) { - return db.GetBufferManager(); -} - -BufferManager &BufferManager::GetBufferManager(AttachedDatabase &db) { - return BufferManager::GetBufferManager(db.GetDatabase()); -} - 
-DatabaseInstance &DatabaseInstance::GetDatabase(ClientContext &context) { - return *context.db; -} - -DatabaseManager &DatabaseInstance::GetDatabaseManager() { - if (!db_manager) { - throw InternalException("Missing DB manager"); - } - return *db_manager; -} - -Catalog &Catalog::GetSystemCatalog(DatabaseInstance &db) { - return db.GetDatabaseManager().GetSystemCatalog(); -} - -Catalog &Catalog::GetCatalog(AttachedDatabase &db) { - return db.GetCatalog(); -} - -FileSystem &FileSystem::GetFileSystem(DatabaseInstance &db) { - return db.GetFileSystem(); -} - -FileSystem &FileSystem::Get(AttachedDatabase &db) { - return FileSystem::GetFileSystem(db.GetDatabase()); -} - -DBConfig &DBConfig::GetConfig(DatabaseInstance &db) { - return db.config; -} - -ClientConfig &ClientConfig::GetConfig(ClientContext &context) { - return context.config; -} - -DBConfig &DBConfig::Get(AttachedDatabase &db) { - return DBConfig::GetConfig(db.GetDatabase()); -} - -const DBConfig &DBConfig::GetConfig(const DatabaseInstance &db) { - return db.config; -} - -const ClientConfig &ClientConfig::GetConfig(const ClientContext &context) { - return context.config; -} - -TransactionManager &TransactionManager::Get(AttachedDatabase &db) { - return db.GetTransactionManager(); -} - -ConnectionManager &ConnectionManager::Get(DatabaseInstance &db) { - return db.GetConnectionManager(); -} - -ClientContext *ConnectionManager::GetConnection(DatabaseInstance *db) { - for (auto &conn : connections) { - if (conn.first->db.get() == db) { - return conn.first; - } - } - return nullptr; -} - -ConnectionManager &ConnectionManager::Get(ClientContext &context) { - return ConnectionManager::Get(DatabaseInstance::GetDatabase(context)); -} - -duckdb::unique_ptr DatabaseInstance::CreateAttachedDatabase(AttachInfo &info, const string &type, - AccessMode access_mode) { - duckdb::unique_ptr attached_database; - if (!type.empty()) { - // find the storage extension - auto extension_name = 
ExtensionHelper::ApplyExtensionAlias(type); - auto entry = config.storage_extensions.find(extension_name); - if (entry == config.storage_extensions.end()) { - throw BinderException("Unrecognized storage type \"%s\"", type); - } - - if (entry->second->attach != nullptr && entry->second->create_transaction_manager != nullptr) { - // use storage extension to create the initial database - attached_database = make_uniq(*this, Catalog::GetSystemCatalog(*this), *entry->second, - info.name, info, access_mode); - } else { - attached_database = - make_uniq(*this, Catalog::GetSystemCatalog(*this), info.name, info.path, access_mode); - } - } else { - // check if this is an in-memory database or not - attached_database = - make_uniq(*this, Catalog::GetSystemCatalog(*this), info.name, info.path, access_mode); - } - return attached_database; -} - -void DatabaseInstance::CreateMainDatabase() { - AttachInfo info; - info.name = AttachedDatabase::ExtractDatabaseName(config.options.database_path, GetFileSystem()); - info.path = config.options.database_path; - - auto attached_database = CreateAttachedDatabase(info, config.options.database_type, config.options.access_mode); - auto initial_database = attached_database.get(); - { - Connection con(*this); - con.BeginTransaction(); - db_manager->AddDatabase(*con.context, std::move(attached_database)); - con.Commit(); - } - - // initialize the database - initial_database->SetInitialDatabase(); - initial_database->Initialize(); -} - -void ThrowExtensionSetUnrecognizedOptions(const unordered_map &unrecognized_options) { - auto unrecognized_options_iter = unrecognized_options.begin(); - string unrecognized_option_keys = unrecognized_options_iter->first; - while (++unrecognized_options_iter != unrecognized_options.end()) { - unrecognized_option_keys = "," + unrecognized_options_iter->first; - } - throw InvalidInputException("Unrecognized configuration property \"%s\"", unrecognized_option_keys); -} - -void DatabaseInstance::Initialize(const char 
*database_path, DBConfig *user_config) { - DBConfig default_config; - DBConfig *config_ptr = &default_config; - if (user_config) { - config_ptr = user_config; - } - - if (config_ptr->options.temporary_directory.empty() && database_path) { - // no directory specified: use default temp path - config_ptr->options.temporary_directory = string(database_path) + ".tmp"; - - // special treatment for in-memory mode - if (strcmp(database_path, ":memory:") == 0) { - config_ptr->options.temporary_directory = ".tmp"; - } - } - - if (database_path) { - config_ptr->options.database_path = database_path; - } else { - config_ptr->options.database_path.clear(); - } - Configure(*config_ptr); - - if (user_config && !user_config->options.use_temporary_directory) { - // temporary directories explicitly disabled - config.options.temporary_directory = string(); - } - - db_manager = make_uniq(*this); - buffer_manager = make_uniq(*this, config.options.temporary_directory); - scheduler = make_uniq(*this); - object_cache = make_uniq(); - connection_manager = make_uniq(); - - // check if we are opening a standard DuckDB database or an extension database - if (config.options.database_type.empty()) { - auto path_and_type = DBPathAndType::Parse(config.options.database_path, config); - config.options.database_type = path_and_type.type; - config.options.database_path = path_and_type.path; - } - - // initialize the system catalog - db_manager->InitializeSystemCatalog(); - - if (!config.options.database_type.empty()) { - // if we are opening an extension database - load the extension - if (!config.file_system) { - throw InternalException("No file system!?"); - } - ExtensionHelper::LoadExternalExtension(*this, *config.file_system, config.options.database_type, nullptr); - } - - if (!config.options.unrecognized_options.empty()) { - ThrowExtensionSetUnrecognizedOptions(config.options.unrecognized_options); - } - - if (!db_manager->HasDefaultDatabase()) { - CreateMainDatabase(); - } - - // only increase 
thread count after storage init because we get races on catalog otherwise - scheduler->SetThreads(config.options.maximum_threads); -} - -DuckDB::DuckDB(const char *path, DBConfig *new_config) : instance(make_shared()) { - instance->Initialize(path, new_config); - if (instance->config.options.load_extensions) { - ExtensionHelper::LoadAllExtensions(*this); - } -} - -DuckDB::DuckDB(const string &path, DBConfig *config) : DuckDB(path.c_str(), config) { -} - -DuckDB::DuckDB(DatabaseInstance &instance_p) : instance(instance_p.shared_from_this()) { -} - -DuckDB::~DuckDB() { -} - -BufferManager &DatabaseInstance::GetBufferManager() { - return *buffer_manager; -} - -BufferPool &DatabaseInstance::GetBufferPool() { - return *config.buffer_pool; -} - -DatabaseManager &DatabaseManager::Get(DatabaseInstance &db) { - return db.GetDatabaseManager(); -} - -DatabaseManager &DatabaseManager::Get(ClientContext &db) { - return DatabaseManager::Get(*db.db); -} - -TaskScheduler &DatabaseInstance::GetScheduler() { - return *scheduler; -} - -ObjectCache &DatabaseInstance::GetObjectCache() { - return *object_cache; -} - -FileSystem &DatabaseInstance::GetFileSystem() { - return *config.file_system; -} - -ConnectionManager &DatabaseInstance::GetConnectionManager() { - return *connection_manager; -} - -FileSystem &DuckDB::GetFileSystem() { - return instance->GetFileSystem(); -} - -Allocator &Allocator::Get(ClientContext &context) { - return Allocator::Get(*context.db); -} - -Allocator &Allocator::Get(DatabaseInstance &db) { - return *db.config.allocator; -} - -Allocator &Allocator::Get(AttachedDatabase &db) { - return Allocator::Get(db.GetDatabase()); -} - -void DatabaseInstance::Configure(DBConfig &new_config) { - config.options = new_config.options; - if (config.options.access_mode == AccessMode::UNDEFINED) { - config.options.access_mode = AccessMode::READ_WRITE; - } - if (new_config.file_system) { - config.file_system = std::move(new_config.file_system); - } else { - config.file_system = 
make_uniq(); - } - if (config.options.maximum_memory == (idx_t)-1) { - config.SetDefaultMaxMemory(); - } - if (new_config.options.maximum_threads == (idx_t)-1) { - config.SetDefaultMaxThreads(); - } - config.allocator = std::move(new_config.allocator); - if (!config.allocator) { - config.allocator = make_uniq(); - } - config.replacement_scans = std::move(new_config.replacement_scans); - config.parser_extensions = std::move(new_config.parser_extensions); - config.error_manager = std::move(new_config.error_manager); - if (!config.error_manager) { - config.error_manager = make_uniq(); - } - if (!config.default_allocator) { - config.default_allocator = Allocator::DefaultAllocatorReference(); - } - if (new_config.buffer_pool) { - config.buffer_pool = std::move(new_config.buffer_pool); - } else { - config.buffer_pool = make_shared(config.options.maximum_memory); - } -} - -DBConfig &DBConfig::GetConfig(ClientContext &context) { - return context.db->config; -} - -const DBConfig &DBConfig::GetConfig(const ClientContext &context) { - return context.db->config; -} - -idx_t DatabaseInstance::NumberOfThreads() { - return scheduler->NumberOfThreads(); -} - -const unordered_set &DatabaseInstance::LoadedExtensions() { - return loaded_extensions; -} - -idx_t DuckDB::NumberOfThreads() { - return instance->NumberOfThreads(); -} - -bool DatabaseInstance::ExtensionIsLoaded(const std::string &name) { - auto extension_name = ExtensionHelper::GetExtensionName(name); - return loaded_extensions.find(extension_name) != loaded_extensions.end(); -} - -bool DuckDB::ExtensionIsLoaded(const std::string &name) { - return instance->ExtensionIsLoaded(name); -} - -void DatabaseInstance::SetExtensionLoaded(const std::string &name) { - auto extension_name = ExtensionHelper::GetExtensionName(name); - loaded_extensions.insert(extension_name); - - auto &callbacks = DBConfig::GetConfig(*this).extension_callbacks; - for (auto &callback : callbacks) { - callback->OnExtensionLoaded(*this, name); - } -} - 
-bool DatabaseInstance::TryGetCurrentSetting(const std::string &key, Value &result) { - // check the session values - auto &db_config = DBConfig::GetConfig(*this); - const auto &global_config_map = db_config.options.set_variables; - - auto global_value = global_config_map.find(key); - bool found_global_value = global_value != global_config_map.end(); - if (!found_global_value) { - return false; - } - result = global_value->second; - return true; -} - -ValidChecker &DatabaseInstance::GetValidChecker() { - return db_validity; -} - -ValidChecker &ValidChecker::Get(DatabaseInstance &db) { - return db.GetValidChecker(); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -DatabaseManager::DatabaseManager(DatabaseInstance &db) : catalog_version(0), current_query_number(1) { - system = make_uniq(db); - databases = make_uniq(system->GetCatalog()); -} - -DatabaseManager::~DatabaseManager() { -} - -DatabaseManager &DatabaseManager::Get(AttachedDatabase &db) { - return DatabaseManager::Get(db.GetDatabase()); -} - -void DatabaseManager::InitializeSystemCatalog() { - system->Initialize(); -} - -optional_ptr DatabaseManager::GetDatabase(ClientContext &context, const string &name) { - if (StringUtil::Lower(name) == TEMP_CATALOG) { - return context.client_data->temporary_objects.get(); - } - return reinterpret_cast(databases->GetEntry(context, name).get()); -} - -void DatabaseManager::AddDatabase(ClientContext &context, unique_ptr db_instance) { - auto name = db_instance->GetName(); - db_instance->oid = ModifyCatalog(); - DependencyList dependencies; - if (default_database.empty()) { - default_database = name; - } - if (!databases->CreateEntry(context, name, std::move(db_instance), dependencies)) { - throw BinderException("Failed to attach database: database with name \"%s\" already exists", name); - } -} - -void DatabaseManager::DetachDatabase(ClientContext &context, const string &name, OnEntryNotFound if_not_found) { - if (GetDefaultDatabase(context) == name) { - 
throw BinderException("Cannot detach database \"%s\" because it is the default database. Select a different " - "database using `USE` to allow detaching this database", - name); - } - if (!databases->DropEntry(context, name, false, true)) { - if (if_not_found == OnEntryNotFound::THROW_EXCEPTION) { - throw BinderException("Failed to detach database with name \"%s\": database not found", name); - } - } -} - -optional_ptr DatabaseManager::GetDatabaseFromPath(ClientContext &context, const string &path) { - auto databases = GetDatabases(context); - for (auto &db_ref : databases) { - auto &db = db_ref.get(); - if (db.IsSystem()) { - continue; - } - auto &catalog = Catalog::GetCatalog(db); - if (catalog.InMemory()) { - continue; - } - auto db_path = catalog.GetDBPath(); - if (StringUtil::CIEquals(path, db_path)) { - return &db; - } - } - return nullptr; -} - -const string &DatabaseManager::GetDefaultDatabase(ClientContext &context) { - auto &config = ClientData::Get(context); - auto &default_entry = config.catalog_search_path->GetDefault(); - if (IsInvalidCatalog(default_entry.catalog)) { - auto &result = DatabaseManager::Get(context).default_database; - if (result.empty()) { - throw InternalException("Calling DatabaseManager::GetDefaultDatabase with no default database set"); - } - return result; - } - return default_entry.catalog; -} - -// LCOV_EXCL_START -void DatabaseManager::SetDefaultDatabase(ClientContext &context, const string &new_value) { - auto db_entry = GetDatabase(context, new_value); - - if (!db_entry) { - throw InternalException("Database \"%s\" not found", new_value); - } else if (db_entry->IsTemporary()) { - throw InternalException("Cannot set the default database to a temporary database"); - } else if (db_entry->IsSystem()) { - throw InternalException("Cannot set the default database to a system database"); - } - - default_database = new_value; -} -// LCOV_EXCL_STOP - -vector> DatabaseManager::GetDatabases(ClientContext &context) { - vector> result; - 
databases->Scan(context, [&](CatalogEntry &entry) { result.push_back(entry.Cast()); }); - result.push_back(*system); - result.push_back(*context.client_data->temporary_objects); - return result; -} - -Catalog &DatabaseManager::GetSystemCatalog() { - D_ASSERT(system); - return system->GetCatalog(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -DBPathAndType DBPathAndType::Parse(const string &combined_path, const DBConfig &config) { - auto extension = ExtensionHelper::ExtractExtensionPrefixFromPath(combined_path); - if (!extension.empty()) { - // path is prefixed with an extension - remove it - auto path = StringUtil::Replace(combined_path, extension + ":", ""); - auto type = ExtensionHelper::ApplyExtensionAlias(extension); - return {path, type}; - } - // if there isn't - check the magic bytes of the file (if any) - auto file_type = MagicBytes::CheckMagicBytes(config.file_system.get(), combined_path); - if (file_type == DataFileType::SQLITE_FILE) { - return {combined_path, "sqlite"}; - } - return {combined_path, string()}; -} -} // namespace duckdb - - - -namespace duckdb { - -string GetDBAbsolutePath(const string &database_p, FileSystem &fs) { - auto database = FileSystem::ExpandPath(database_p, nullptr); - if (database.empty()) { - return ":memory:"; - } - if (database.rfind(":memory:", 0) == 0) { - // this is a memory db, just return it. 
		return database;
	}
	if (!ExtensionHelper::ExtractExtensionPrefixFromPath(database).empty()) {
		// this database path is handled by a replacement open and is not a file path
		return database;
	}
	if (fs.IsPathAbsolute(database)) {
		return fs.NormalizeAbsolutePath(database);
	}
	return fs.NormalizeAbsolutePath(fs.JoinPath(FileSystem::GetWorkingDirectory(), database));
}

// Looks up a cached instance for `database` (after canonicalizing the path).
// Returns an empty pointer when no live instance exists; throws when a live
// instance was opened with a different configuration. Expired cache entries
// are erased on the way. Callers are expected to hold cache_lock (see GetInstance).
// NOTE(review): template arguments on shared_ptr/lock_guard below were lost in
// extraction; tokens kept verbatim.
shared_ptr DBInstanceCache::GetInstanceInternal(const string &database, const DBConfig &config) {
	shared_ptr db_instance;

	auto local_fs = FileSystem::CreateLocal();
	auto abs_database_path = GetDBAbsolutePath(database, *local_fs);
	if (db_instances.find(abs_database_path) != db_instances.end()) {
		// the cache holds weak references — lock() yields null once the DB has died
		db_instance = db_instances[abs_database_path].lock();
		if (db_instance) {
			if (db_instance->instance->config != config) {
				throw duckdb::ConnectionException(
				    "Can't open a connection to same database file with a different configuration "
				    "than existing connections");
			}
		} else {
			// clean-up
			db_instances.erase(abs_database_path);
		}
	}
	return db_instance;
}

// Thread-safe wrapper around GetInstanceInternal.
shared_ptr DBInstanceCache::GetInstance(const string &database, const DBConfig &config) {
	lock_guard l(cache_lock);
	return GetInstanceInternal(database, config);
}

// Creates a new database instance for `database`, optionally registering it in
// the cache. Throws when an instance for the canonical path already exists.
// Callers are expected to hold cache_lock (see CreateInstance).
// (the function body continues past the end of this chunk)
shared_ptr DBInstanceCache::CreateInstanceInternal(const string &database, DBConfig &config,
                                                   bool cache_instance) {
	string abs_database_path;
	if (config.file_system) {
		abs_database_path = GetDBAbsolutePath(database, *config.file_system);
	} else {
		auto tmp_fs = FileSystem::CreateLocal();
		abs_database_path = GetDBAbsolutePath(database, *tmp_fs);
	}
	if (db_instances.find(abs_database_path) != db_instances.end()) {
		throw duckdb::Exception(ExceptionType::CONNECTION,
		                        "Instance with path: " + abs_database_path + " already exists.");
	}
	// Creates new instance
	string instance_path = abs_database_path;
	// any ":memory:"-prefixed path collapses to plain ":memory:"
	if (abs_database_path.rfind(":memory:", 0) == 0) {
		instance_path = ":memory:";
	}
	auto db_instance =
	    make_shared(instance_path, &config);
	if (cache_instance) {
		// store a weak reference so the cache itself does not keep the DB alive
		db_instances[abs_database_path] = db_instance;
	}
	return db_instance;
}

// Thread-safe wrapper around CreateInstanceInternal.
shared_ptr DBInstanceCache::CreateInstance(const string &database, DBConfig &config, bool cache_instance) {
	lock_guard l(cache_lock);
	return CreateInstanceInternal(database, config, cache_instance);
}

// Returns a cached live instance when caching is enabled and one exists,
// otherwise creates (and optionally caches) a fresh one — all under a single
// lock, so the lookup+create pair is atomic.
shared_ptr DBInstanceCache::GetOrCreateInstance(const string &database, DBConfig &config_dict,
                                                bool cache_instance) {
	lock_guard l(cache_lock);
	if (cache_instance) {
		auto instance = GetInstanceInternal(database, config_dict);
		if (instance) {
			return instance;
		}
	}
	return CreateInstanceInternal(database, config_dict, cache_instance);
}

} // namespace duckdb

namespace duckdb {

// Pairs an ErrorType with its default printf-style message template.
struct DefaultError {
	ErrorType type;
	const char *error;
};

// Built-in error message templates, terminated by an INVALID sentinel entry.
// NOTE(review): FormatExceptionRecursive indexes this array by int(error_type),
// so the order here must match the ErrorType enum — confirm before reordering.
// (the array continues past the end of this chunk)
static DefaultError internal_errors[] = {
    {ErrorType::UNSIGNED_EXTENSION,
     "Extension \"%s\" could not be loaded because its signature is either missing or invalid and unsigned extensions "
     "are disabled by configuration (allow_unsigned_extensions)"},
    {ErrorType::INVALIDATED_TRANSACTION, "Current transaction is aborted (please ROLLBACK)"},
    {ErrorType::INVALIDATED_DATABASE, "Failed: database has been invalidated because of a previous fatal error. 
The "
                                      "database must be restarted prior to being used again.\nOriginal error: \"%s\""},
    {ErrorType::INVALID, nullptr}};

// Renders the message template for `error_type` with the given format values.
// A template registered via AddCustomError overrides the built-in table;
// otherwise internal_errors is indexed directly by the enum's integer value.
// (NOTE(review): the vector's template argument was lost in extraction)
string ErrorManager::FormatExceptionRecursive(ErrorType error_type, vector &values) {
	if (error_type >= ErrorType::ERROR_COUNT) {
		throw InternalException("Invalid error type passed to ErrorManager::FormatError");
	}
	auto entry = custom_errors.find(error_type);
	string error;
	if (entry == custom_errors.end()) {
		// error was not overwritten
		error = internal_errors[int(error_type)].error;
	} else {
		// error was overwritten
		error = entry->second;
	}
	return ExceptionFormatValue::Format(error, values);
}

// Builds a human-readable description of why `input` is invalid UTF-8.
// Re-analyzes the string; if it turns out to be valid, says so instead.
string ErrorManager::InvalidUnicodeError(const string &input, const string &context) {
	UnicodeInvalidReason reason;
	size_t pos;
	auto unicode = Utf8Proc::Analyze(const_char_ptr_cast(input.c_str()), input.size(), &reason, &pos);
	if (unicode != UnicodeType::INVALID) {
		return "Invalid unicode error thrown but no invalid unicode detected in " + context;
	}
	string base_message;
	switch (reason) {
	case UnicodeInvalidReason::BYTE_MISMATCH:
		base_message = "Invalid unicode (byte sequence mismatch)";
		break;
	case UnicodeInvalidReason::INVALID_UNICODE:
		base_message = "Invalid unicode";
		break;
	default:
		break;
	}
	return base_message + " detected in " + context;
}

// Registers an override template for `type`.
// NOTE(review): insert() keeps an existing entry — a second call for the same
// type is silently ignored; confirm that this is intended.
void ErrorManager::AddCustomError(ErrorType type, string new_error) {
	custom_errors.insert(make_pair(type, std::move(new_error)));
}

// Accessors for the ErrorManager owned by the client / database configuration.
ErrorManager &ErrorManager::Get(ClientContext &context) {
	return *DBConfig::GetConfig(context).error_manager;
}

ErrorManager &ErrorManager::Get(DatabaseInstance &context) {
	return *DBConfig::GetConfig(context).error_manager;
}

} // namespace duckdb

namespace duckdb {

// Short aliases resolving to the extension that actually implements them
// (e.g. "s3" -> "httpfs"); terminated by a {nullptr, nullptr} sentinel.
// (the array continues past the end of this chunk)
static ExtensionAlias internal_aliases[] = {{"http", "httpfs"}, // httpfs
                                            {"https", "httpfs"},
                                            {"md", "motherduck"}, // motherduck
                                            {"s3", "httpfs"},
                                            {"postgres", "postgres_scanner"}, // postgres
                                            {"sqlite", 
"sqlite_scanner"}, // sqlite - {"sqlite3", "sqlite_scanner"}, - {nullptr, nullptr}}; - -idx_t ExtensionHelper::ExtensionAliasCount() { - idx_t index; - for (index = 0; internal_aliases[index].alias != nullptr; index++) { - } - return index; -} - -ExtensionAlias ExtensionHelper::GetExtensionAlias(idx_t index) { - D_ASSERT(index < ExtensionAliasCount()); - return internal_aliases[index]; -} - -string ExtensionHelper::ApplyExtensionAlias(string extension_name) { - auto lname = StringUtil::Lower(extension_name); - for (idx_t index = 0; internal_aliases[index].alias; index++) { - if (lname == internal_aliases[index].alias) { - return internal_aliases[index].extension; - } - } - return extension_name; -} - -} // namespace duckdb - - - - - - - - -// Note that c++ preprocessor doesn't have a nice way to clean this up so we need to set the defines we use to false -// explicitly when they are undefined -#ifndef DUCKDB_EXTENSION_ICU_LINKED -#define DUCKDB_EXTENSION_ICU_LINKED false -#endif - -#ifndef DUCKDB_EXTENSION_EXCEL_LINKED -#define DUCKDB_EXTENSION_EXCEL_LINKED false -#endif - -#ifndef DUCKDB_EXTENSION_PARQUET_LINKED -#define DUCKDB_EXTENSION_PARQUET_LINKED false -#endif - -#ifndef DUCKDB_EXTENSION_TPCH_LINKED -#define DUCKDB_EXTENSION_TPCH_LINKED false -#endif - -#ifndef DUCKDB_EXTENSION_TPCDS_LINKED -#define DUCKDB_EXTENSION_TPCDS_LINKED false -#endif - -#ifndef DUCKDB_EXTENSION_FTS_LINKED -#define DUCKDB_EXTENSION_FTS_LINKED false -#endif - -#ifndef DUCKDB_EXTENSION_HTTPFS_LINKED -#define DUCKDB_EXTENSION_HTTPFS_LINKED false -#endif - -#ifndef DUCKDB_EXTENSION_JSON_LINKED -#define DUCKDB_EXTENSION_JSON_LINKED false -#endif - -#ifndef DUCKDB_EXTENSION_JEMALLOC_LINKED -#define DUCKDB_EXTENSION_JEMALLOC_LINKED false -#endif - -#ifndef DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED -#define DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED false -#endif - -// Load the generated header file containing our list of extension headers -#if defined(GENERATED_EXTENSION_HEADERS) && 
GENERATED_EXTENSION_HEADERS && !defined(DUCKDB_AMALGAMATION) - -#else -// TODO: rewrite package_build.py to allow also loading out-of-tree extensions in non-cmake builds, after that -// these can be removed -#if DUCKDB_EXTENSION_ICU_LINKED -#include "icu_extension.hpp" -#endif - -#if DUCKDB_EXTENSION_EXCEL_LINKED -#include "excel_extension.hpp" -#endif - -#if DUCKDB_EXTENSION_PARQUET_LINKED -#include "parquet_extension.hpp" -#endif - -#if DUCKDB_EXTENSION_TPCH_LINKED -#include "tpch_extension.hpp" -#endif - -#if DUCKDB_EXTENSION_TPCDS_LINKED -#include "tpcds_extension.hpp" -#endif - -#if DUCKDB_EXTENSION_FTS_LINKED -#include "fts_extension.hpp" -#endif - -#if DUCKDB_EXTENSION_HTTPFS_LINKED -#include "httpfs_extension.hpp" -#endif - -#if DUCKDB_EXTENSION_JSON_LINKED -#include "json_extension.hpp" -#endif - -#if DUCKDB_EXTENSION_JEMALLOC_LINKED -#include "jemalloc_extension.hpp" -#endif - -#if DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED -#include "autocomplete_extension.hpp" -#endif -#endif - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Default Extensions -//===--------------------------------------------------------------------===// -static DefaultExtension internal_extensions[] = { - {"icu", "Adds support for time zones and collations using the ICU library", DUCKDB_EXTENSION_ICU_LINKED}, - {"excel", "Adds support for Excel-like format strings", DUCKDB_EXTENSION_EXCEL_LINKED}, - {"parquet", "Adds support for reading and writing parquet files", DUCKDB_EXTENSION_PARQUET_LINKED}, - {"tpch", "Adds TPC-H data generation and query support", DUCKDB_EXTENSION_TPCH_LINKED}, - {"tpcds", "Adds TPC-DS data generation and query support", DUCKDB_EXTENSION_TPCDS_LINKED}, - {"fts", "Adds support for Full-Text Search Indexes", DUCKDB_EXTENSION_FTS_LINKED}, - {"httpfs", "Adds support for reading and writing files over a HTTP(S) connection", DUCKDB_EXTENSION_HTTPFS_LINKED}, - {"json", "Adds support for JSON operations", 
DUCKDB_EXTENSION_JSON_LINKED},
    {"jemalloc", "Overwrites system allocator with JEMalloc", DUCKDB_EXTENSION_JEMALLOC_LINKED},
    {"autocomplete", "Adds support for autocomplete in the shell", DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED},
    {"motherduck", "Enables motherduck integration with the system", false},
    {"sqlite_scanner", "Adds support for reading SQLite database files", false},
    {"postgres_scanner", "Adds support for reading from a Postgres database", false},
    {"inet", "Adds support for IP-related data types and functions", false},
    {"spatial", "Geospatial extension that adds support for working with spatial data and functions", false},
    {"substrait", "Adds support for the Substrait integration", false},
    {"aws", "Provides features that depend on the AWS SDK", false},
    {"arrow", "A zero-copy data integration between Apache Arrow and DuckDB", false},
    {"azure", "Adds a filesystem abstraction for Azure blob storage to DuckDB", false},
    {"iceberg", "Adds support for Apache Iceberg", false},
    {"visualizer", "Creates an HTML-based visualization of the query plan", false},
    {nullptr, nullptr, false}};

// Number of entries in internal_extensions (excluding the null sentinel).
idx_t ExtensionHelper::DefaultExtensionCount() {
	idx_t index;
	for (index = 0; internal_extensions[index].name != nullptr; index++) {
	}
	return index;
}

// Returns the index-th built-in extension descriptor; index must be in range.
DefaultExtension ExtensionHelper::GetDefaultExtension(idx_t index) {
	D_ASSERT(index < DefaultExtensionCount());
	return internal_extensions[index];
}

//===--------------------------------------------------------------------===//
// Allow Auto-Install Extensions
//===--------------------------------------------------------------------===//
// Extensions that may be installed automatically (null-terminated list).
static const char *auto_install[] = {"motherduck", "postgres_scanner", "sqlite_scanner", nullptr};

// TODO: unify with new autoload mechanism
// True when `extension` (compared case-insensitively) is in the auto_install list.
bool ExtensionHelper::AllowAutoInstall(const string &extension) {
	auto lcase = StringUtil::Lower(extension);
	for (idx_t i = 0; auto_install[i]; i++) {
		if (lcase == auto_install[i]) {
			return true;
		}
	}
	return false;
}

// True when `ext_name` may be loaded automatically: extension loading must be
// compiled in, and the name must appear in AUTOLOADABLE_EXTENSIONS.
bool ExtensionHelper::CanAutoloadExtension(const string &ext_name) {
#ifdef DUCKDB_DISABLE_EXTENSION_LOAD
	return false;
#endif

	if (ext_name.empty()) {
		return false;
	}
	for (const auto &ext : AUTOLOADABLE_EXTENSIONS) {
		if (ext_name == ext) {
			return true;
		}
	}
	return false;
}

// Appends an INSTALL/LOAD (or enable-autoload) hint to `base_error`, tailored
// to whether the extension is autoloadable and which autoload settings are off.
// Returns base_error unchanged when everything is already enabled.
string ExtensionHelper::AddExtensionInstallHintToErrorMsg(ClientContext &context, const string &base_error,
                                                          const string &extension_name) {
	auto &dbconfig = DBConfig::GetConfig(context);
	string install_hint;

	if (!ExtensionHelper::CanAutoloadExtension(extension_name)) {
		install_hint = "Please try installing and loading the " + extension_name + " extension:\nINSTALL " +
		               extension_name + ";\nLOAD " + extension_name + ";\n\n";
	} else if (!dbconfig.options.autoload_known_extensions) {
		install_hint =
		    "Please try installing and loading the " + extension_name + " extension by running:\nINSTALL " +
		    extension_name + ";\nLOAD " + extension_name +
		    ";\n\nAlternatively, consider enabling auto-install "
		    "and auto-load by running:\nSET autoinstall_known_extensions=1;\nSET autoload_known_extensions=1;";
	} else if (!dbconfig.options.autoinstall_known_extensions) {
		install_hint =
		    "Please try installing the " + extension_name + " extension by running:\nINSTALL " + extension_name +
		    ";\n\nAlternatively, consider enabling autoinstall by running:\nSET autoinstall_known_extensions=1;";
	}

	if (!install_hint.empty()) {
		return base_error + "\n\n" + install_hint;
	}

	return base_error;
}

// Best-effort autoload: optionally auto-installs first, then loads; never
// throws — any failure is reported as `false`.
// (the function body continues past the end of this chunk)
bool ExtensionHelper::TryAutoLoadExtension(ClientContext &context, const string &extension_name) noexcept {
	auto &dbconfig = DBConfig::GetConfig(context);
	try {
		if (dbconfig.options.autoinstall_known_extensions) {
			ExtensionHelper::InstallExtension(context, extension_name, false,
			                                  context.config.autoinstall_extension_repo);
		}
		ExtensionHelper::LoadExternalExtension(context, extension_name);
		return true;
	} catch (...) 
	{
		return false;
	}
	// NOTE(review): unreachable — both the try and the catch paths return above.
	return false;
}

// Autoload that propagates failures: optionally auto-installs (skipped under
// WASM), then loads; a failure surfaces as an AutoloadException wrapping the
// original error.
void ExtensionHelper::AutoLoadExtension(ClientContext &context, const string &extension_name) {
	auto &dbconfig = DBConfig::GetConfig(context);
	try {
#ifndef DUCKDB_WASM
		if (dbconfig.options.autoinstall_known_extensions) {
			ExtensionHelper::InstallExtension(context, extension_name, false,
			                                  context.config.autoinstall_extension_repo);
		}
#endif
		ExtensionHelper::LoadExternalExtension(context, extension_name);
	} catch (Exception &e) {
		throw AutoloadException(extension_name, e);
	}
}

//===--------------------------------------------------------------------===//
// Load Statically Compiled Extension
//===--------------------------------------------------------------------===//
// Loads every statically linked in-tree extension (plus, for generated-header
// builds, the generated linked_extensions list).
void ExtensionHelper::LoadAllExtensions(DuckDB &db) {
	// The in-tree extensions that we check. Non-cmake builds are currently limited to these for static linking
	// TODO: rewrite package_build.py to allow also loading out-of-tree extensions in non-cmake builds, after that
	// these can be removed
	// (NOTE(review): the unordered_set's template argument was lost in extraction)
	unordered_set extensions {"parquet", "icu", "tpch", "tpcds", "fts", "httpfs", "visualizer",
	                          "json", "excel", "sqlsmith", "inet", "jemalloc", "autocomplete"};
	for (auto &ext : extensions) {
		LoadExtensionInternal(db, ext, true);
	}

#if defined(GENERATED_EXTENSION_HEADERS) && GENERATED_EXTENSION_HEADERS
	for (auto &ext : linked_extensions) {
		LoadExtensionInternal(db, ext, true);
	}
#endif
}

// Public entry point: load a single statically compiled extension by name.
ExtensionLoadResult ExtensionHelper::LoadExtension(DuckDB &db, const std::string &extension) {
	return LoadExtensionInternal(db, extension, false);
}

// Dispatches to the appropriate loading mechanism: test-only remote
// INSTALL/LOAD, test-only loadable-path LOAD, generated-header static loading,
// or the legacy per-extension #if ladder.
// (the function body continues past the end of this chunk)
ExtensionLoadResult ExtensionHelper::LoadExtensionInternal(DuckDB &db, const std::string &extension,
                                                           bool initial_load) {
#ifdef DUCKDB_TEST_REMOTE_INSTALL
	if (!initial_load && StringUtil::Contains(DUCKDB_TEST_REMOTE_INSTALL, extension)) {
		Connection con(db);
		auto result = con.Query("INSTALL " + extension);
		if (result->HasError()) {
			result->Print();
			return 
ExtensionLoadResult::EXTENSION_UNKNOWN; - } - result = con.Query("LOAD " + extension); - if (result->HasError()) { - result->Print(); - return ExtensionLoadResult::EXTENSION_UNKNOWN; - } - return ExtensionLoadResult::LOADED_EXTENSION; - } -#endif - -#ifdef DUCKDB_EXTENSIONS_TEST_WITH_LOADABLE - // Note: weird comma's are on purpose to do easy string contains on a list of extension names - if (!initial_load && StringUtil::Contains(DUCKDB_EXTENSIONS_TEST_WITH_LOADABLE, "," + extension + ",")) { - Connection con(db); - auto result = con.Query((string) "LOAD '" + DUCKDB_EXTENSIONS_BUILD_PATH + "/" + extension + "/" + extension + - ".duckdb_extension'"); - if (result->HasError()) { - result->Print(); - return ExtensionLoadResult::EXTENSION_UNKNOWN; - } - return ExtensionLoadResult::LOADED_EXTENSION; - } -#endif - - // This is the main extension loading mechanism that loads the extension that are statically linked. -#if defined(GENERATED_EXTENSION_HEADERS) && GENERATED_EXTENSION_HEADERS - if (TryLoadLinkedExtension(db, extension)) { - return ExtensionLoadResult::LOADED_EXTENSION; - } else { - return ExtensionLoadResult::NOT_LOADED; - } -#endif - - // This is the fallback to the "old" extension loading mechanism for non-cmake builds - // TODO: rewrite package_build.py to allow also loading out-of-tree extensions in non-cmake builds - if (extension == "parquet") { -#if DUCKDB_EXTENSION_PARQUET_LINKED - db.LoadExtension(); -#else - // parquet extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "icu") { -#if DUCKDB_EXTENSION_ICU_LINKED - db.LoadExtension(); -#else - // icu extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "tpch") { -#if DUCKDB_EXTENSION_TPCH_LINKED - db.LoadExtension(); -#else - // icu extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == 
"tpcds") { -#if DUCKDB_EXTENSION_TPCDS_LINKED - db.LoadExtension(); -#else - // icu extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "fts") { -#if DUCKDB_EXTENSION_FTS_LINKED -// db.LoadExtension(); -#else - // fts extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "httpfs") { -#if DUCKDB_EXTENSION_HTTPFS_LINKED - db.LoadExtension(); -#else - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "visualizer") { -#if DUCKDB_EXTENSION_VISUALIZER_LINKED - db.LoadExtension(); -#else - // visualizer extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "json") { -#if DUCKDB_EXTENSION_JSON_LINKED - db.LoadExtension(); -#else - // json extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "excel") { -#if DUCKDB_EXTENSION_EXCEL_LINKED - db.LoadExtension(); -#else - // excel extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "sqlsmith") { -#if DUCKDB_EXTENSION_SQLSMITH_LINKED - db.LoadExtension(); -#else - // excel extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "jemalloc") { -#if DUCKDB_EXTENSION_JEMALLOC_LINKED - db.LoadExtension(); -#else - // jemalloc extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "autocomplete") { -#if DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED - db.LoadExtension(); -#else - // autocomplete extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "inet") { -#if DUCKDB_EXTENSION_INET_LINKED - db.LoadExtension(); -#else - // inet extension required but not build: 
skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } - - return ExtensionLoadResult::LOADED_EXTENSION; -} - -static vector public_keys = { - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA6aZuHUa1cLR9YDDYaEfi -UDbWY8m2t7b71S+k1ZkXfHqu+5drAxm+dIDzdOHOKZSIdwnJbT3sSqwFoG6PlXF3 -g3dsJjax5qESIhbVvf98nyipwNINxoyHCkcCIPkX17QP2xpnT7V59+CqcfDJXLqB -ymjqoFSlaH8dUCHybM4OXlWnAtVHW/nmw0khF8CetcWn4LxaTUHptByaBz8CasSs -gWpXgSfaHc3R9eArsYhtsVFGyL/DEWgkEHWolxY3Llenhgm/zOf3s7PsAMe7EJX4 -qlSgiXE6OVBXnqd85z4k20lCw/LAOe5hoTMmRWXIj74MudWe2U91J6GrrGEZa7zT -7QIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAq8Gg1S/LI6ApMAYsFc9m -PrkFIY+nc0LXSpxm77twU8D5M0Xkz/Av4f88DQmj1OE3164bEtR7sl7xDPZojFHj -YYyucJxEI97l5OU1d3Pc1BdKXL4+mnW5FlUGj218u8qD+G1hrkySXQkrUzIjPPNw -o6knF3G/xqQF+KI+tc7ajnTni8CAlnUSxfnstycqbVS86m238PLASVPK9/SmIRgO -XCEV+ZNMlerq8EwsW4cJPHH0oNVMcaG+QT4z79roW1rbJghn9ubAVdQU6VLUAikI -b8keUyY+D0XdY9DpDBeiorb1qPYt8BPLOAQrIUAw1CgpMM9KFp9TNvW47KcG4bcB -dQIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyYATA9KOQ0Azf97QAPfY -Jc/WeZyE4E1qlRgKWKqNtYSXZqk5At0V7w2ntAWtYSpczFrVepCJ0oPMDpZTigEr -NgOgfo5LEhPx5XmtCf62xY/xL3kgtfz9Mm5TBkuQy4KwY4z1npGr4NYYDXtF7kkf -LQE+FnD8Yr4E0wHBib7ey7aeeKWmwqvUjzDqG+TzaqwzO/RCUsSctqSS0t1oo2hv -4q1ofanUXsV8MXk/ujtgxu7WkVvfiSpK1zRazgeZjcrQFO9qL/pla0vBUxa1U8He -GMLnL0oRfcMg7yKrbIMrvlEl2ZmiR9im44dXJWfY42quObwr1PuEkEoCMcMisSWl -jwIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4RvbWx3zLblDHH/lGUF5 -Q512MT+v3YPriuibROMllv8WiCLAMeJ0QXbVaIzBOeHDeLx8yvoZZN+TENKxtT6u -IfMMneUzxHBqy0AQNfIsSsOnG5nqoeE/AwbS6VqCdH1aLfoCoPffacHYa0XvTcsi -aVlZfr+UzJS+ty8pRmFVi1UKSOADDdK8XfIovJl/zMP2TxYX2Y3fnjeLtl8Sqs2e -P+eHDoy7Wi4EPTyY7tNTCfxwKNHn1HQ5yrv5dgvMxFWIWXGz24yikFvtwLGHe8uJ -Wi+fBX+0PF0diZ6pIthZ149VU8qCqYAXjgpxZ0EZdrsiF6Ewz0cfg20SYApFcmW4 
-pwIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyhd5AfwrUohG3O4DE0K9 -O3FmgB7zE4aDnkL8UUfGCh5kdP8q7ewMjekY+c6LwWOmpdJpSwqhfV1q5ZU1l6rk -3hlt03LO3sgs28kcfOVH15hqfxts6Sg5KcRjxStE50ORmXGwXDcS9vqkJ60J1EHA -lcZqbCRSO73ZPLhdepfd0/C6tM0L7Ge6cAE62/MTmYNGv8fDzwQr/kYIJMdoS8Zp -thRpctFZJtPs3b0fffZA/TCLVKMvEVgTWs48751qKid7N/Lm/iEGx/tOf4o23Nec -Pz1IQaGLP+UOLVQbqQBHJWNOqigm7kWhDgs3N4YagWgxPEQ0WVLtFji/ZjlKZc7h -dwIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAnFDg3LhyV6BVE2Z3zQvN -6urrKvPhygTa5+wIPGwYTzJ8DfGALqlsX3VOXMvcJTca6SbuwwkoXHuSU5wQxfcs -bt4jTXD3NIoRwQPl+D9IbgIMuX0ACl27rJmr/f9zkY7qui4k1X82pQkxBe+/qJ4r -TBwVNONVx1fekTMnSCEhwg5yU3TNbkObu0qlQeJfuMWLDQbW/8v/qfr/Nz0JqHDN -yYKfKvFMlORxyJYiOyeOsbzNGEhkGQGOmKhRUhS35kD+oA0jqwPwMCM9O4kFg/L8 -iZbpBBX2By1K3msejWMRAewTOyPas6YMQOYq9BMmWQqzVtG5xcaSJwN/YnMpJyqb -sQIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1z0RU8vGrfEkrscEoZKA -GiOcGh2EMcKwjQpl4nKuR9H4o/dg+CZregVSHg7MP2f8mhLZZyoFev49oWOV4Rmi -qs99UNxm7DyKW1fF1ovowsUW5lsDoKYLvpuzHo0s4laiV4AnIYP7tHGLdzsnK2Os -Cp5dSuMwKHPZ9N25hXxFB/dRrAdIiXHvbSqr4N29XzfQloQpL3bGHLKY6guFHluH -X5dJ9eirVakWWou7BR2rnD0k9vER6oRdVnJ6YKb5uhWEOQ3NmV961oyr+uiDTcep -qqtGHWuFhENixtiWGjFJJcACwqxEAW3bz9lyrfnPDsHSW/rlQVDIAkik+fOp+R7L -kQIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAxwO27e1vnbNcpiDg7Wwx -K/w5aEGukXotu3529ieq+O39H0+Bak4vIbzGhDUh3/ElmxaFMAs4PYrWe/hc2WFD -H4JCOoFIn4y9gQeE855DGGFgeIVd1BnSs5S+5wUEMxLNyHdHSmINN6FsoZ535iUg -KdYjRh1iZevezg7ln8o/O36uthu925ehFBXSy6jLJgQlwmq0KxZJE0OAZhuDBM60 -MtIunNa/e5y+Gw3GknFwtRLmn/nEckZx1nEtepYvvUa7UGy+8KuGuhOerCZTutbG -k8liCVgGenRve8unA2LrBbpL+AUf3CrZU/uAxxTqWmw6Z/S6TeW5ozeeyOCh8ii6 -TwIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- 
-MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsGIFOfIQ4RI5qu4klOxf -ge6eXwBMAkuTXyhyIIJDtE8CurnwQvUXVlt+Kf0SfuIFW6MY5ErcWE/vMFbc81IR -9wByOAAV2CTyiLGZT63uE8pN6FSHd6yGYCLjXd3P3cnP3Qj5pBncpLuAUDfHG4wP -bs9jIADw3HysD+eCNja8p7ZC7CzWxTcO7HsEu9deAAU19YywdpagXvQ0pJ9zV5qU -jrHxBygl31t6TmmX+3d+azjGu9Hu36E+5wcSOOhuwAFXDejb40Ixv53ItJ3fZzzH -PF2nj9sQvQ8c5ptjyOvQCBRdqkEWXIVHClxqWb+o59pDIh1G0UGcmiDN7K9Gz5HA -ZQIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAt9uUnlW/CoUXT68yaZh9 -SeXHzGRCPNEI98Tara+dgYxDX1z7nfOh8o15liT0QsAzx34EewZOxcKCNiV/dZX5 -z4clCkD8uUbZut6IVx8Eu+7Qcd5jZthRc6hQrN9Ltv7ZQEh7KGXOHa53kT2K01ws -4jbVmd/7Nx7y0Yyqhja01pIu/CUaTkODfQxBXwriLdIzp7y/iJeF/TLqCwZWHKQx -QOZnsPEveB1F00Va9MeAtTlXFUJ/TQXquqTjeLj4HuIRtbyuNgWoc0JyF+mcafAl -bnrNEBIfxZhAT81aUCIAzRJp6AqfdeZxnZ/WwohtZQZLXAxFQPTWCcP+Z9M7OIQL -WwIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA56NhfACkeCyZM07l2wmd -iTp24E2tLLKU3iByKlIRWRAvXsOejRMJTHTNHWa3cQ7uLP++Tf2St7ksNsyPMNZy -9QRTLNCYr9rN9loLwdb2sMWxFBwwzCaAOTahGI7GJQy30UB7FEND0X/5U2rZvQij -Q6K+O4aa+K9M5qyOHNMmXywmTnAgWKNaNxQHPRtD2+dSj60T6zXdtIuCrPfcNGg5 -gj07qWGEXX83V/L7nSqCiIVYg/wqds1x52Yjk1nhXYNBTqlnhmOd8LynGxz/sXC7 -h2Q9XsHjXIChW4FHyLIOl6b4zPMBSxzCigYm3QZJWfAkZv5PBRtnq7vhYOLHzLQj -CwIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAmfPLe0IWGYC0MZC6YiM3 -QGfhT6zSKB0I2DW44nlBlWUcF+32jW2bFJtgE76qGGKFeU4kJBWYr99ufHoAodNg -M1Ehl/JfQ5KmbC1WIqnFTrgbmqJde79jeCvCpbFLuqnzidwO1PbXDbfRFQcgWaXT -mDVLNNVmLxA0GkCv+kydE2gtcOD9BDceg7F/56TDvclyI5QqAnjE2XIRMPZlXQP4 -oF2kgz4Cn7LxLHYmkU2sS9NYLzHoyUqFplWlxkQjA4eQ0neutV1Ydmc1IX8W7R38 -A7nFtaT8iI8w6Vkv7ijYN6xf5cVBPKZ3Dv7AdwPet86JD5mf5v+r7iwg5xl3r77Z -iwIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAoB1kWsX8YmCcFOD9ilBY 
-xK076HmUAN026uJ8JpmU9Hz+QT1FNXOsnj1h2G6U6btYVIdHUTHy/BvAumrDKqRz -qcEAzCuhxUjPjss54a/Zqu6nQcoIPHuG/Er39oZHIVkPR1WCvWj8wmyYv6T//dPH -unO6tW29sXXxS+J1Gah6vpbtJw1pI/liah1DZzb13KWPDI6ZzviTNnW4S05r6js/ -30He+Yud6aywrdaP/7G90qcrteEFcjFy4Xf+5vG960oKoGoDplwX5poay1oCP9tb -g8AC8VSRAGi3oviTeSWZcrLXS8AtJhGvF48cXQj2q+8YeVKVDpH6fPQxJ9Sh9aeU -awIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4NTMAIYIlCMID00ufy/I -AZXc8pocDx9N1Q5x5/cL3aIpLmx02AKo9BvTJaJuHiTjlwYhPtlhIrHV4HUVTkOX -sISp8B8v9i2I1RIvCTAcvy3gcH6rdRWZ0cdTUiMEqnnxBX9zdzl8oMzZcyauv19D -BeqJvzflIT96b8g8K3mvgJHs9a1j9f0gN8FuTA0c52DouKnrh8UwH7mlrumYerJw -6goJGQuK1HEOt6bcQuvogkbgJWOoEYwjNrPwQvIcP4wyrgSnOHg1yXOFE84oVynJ -czQEOz9ke42I3h8wrnQxilEYBVo2uX8MenqTyfGnE32lPRt3Wv1iEVQls8Cxiuy2 -CQIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3bUtfp66OtRyvIF/oucn -id8mo7gvbNEH04QMLO3Ok43dlWgWI3hekJAqOYc0mvoI5anqr98h8FI7aCYZm/bY -vpz0I1aXBaEPh3aWh8f/w9HME7ykBvmhMe3J+VFGWWL4eswfRl//GCtnSMBzDFhM -SaQOTvADWHkC0njeI5yXjf/lNm6fMACP1cnhuvCtnx7VP/DAtvUk9usDKG56MJnZ -UoVM3HHjbJeRwxCdlSWe12ilCdwMRKSDY92Hk38/zBLenH04C3HRQLjBGewACUmx -uvNInehZ4kSYFGa+7UxBxFtzJhlKzGR73qUjpWzZivCe1K0WfRVP5IWsKNCCESJ/ -nQIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyV2dE/CRUAUE8ybq/DoS -Lc7QlYXh04K+McbhN724TbHahLTuDk5mR5TAunA8Nea4euRzknKdMFAz1eh9gyy3 -5x4UfXQW1fIZqNo6WNrGxYJgWAXU+pov+OvxsMQWzqS4jrTHDHbblCCLKp1akwJk -aFNyqgjAL373PcqXC+XAn8vHx4xHFoFP5lq4lLcJCOW5ee9v9El3w0USLwS+t1cF -RY3kuV6Njlr4zsRH9iM6/zaSuCALYWJ/JrPEurSJXzFZnWsvn6aQdeNeAn08+z0F -k2NwaauEo0xmLqzqTRGzjHqKKmeefN3/+M/FN2FrApDlxWQfhD2Y3USdAiN547Nj -1wIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAvm2+kTrEQWZXuxhWzBdl -PCbQGqbrukbeS6JKSlQLJDC8ayZIxFxatqg1Q8UPyv89MVRsHOGlG1OqFaOEtPjQ 
-Oo6j/moFwB4GPyJhJHOGpCKa4CLB5clhfDCLJw6ty7PcDU3T6yW4X4Qc5k4LRRWy -yzC8lVHfBdarN+1iEe0ALMOGoeiJjVn6i/AFxktRwgd8njqv/oWQyfjJZXkNMsb6 -7ZDxNVAUrp/WXpE4Kq694bB9xa/pWsqv7FjQJUgTnEzvbN+qXnVPtA7dHcOYYJ8Z -SbrJUfHrf8TS5B54AiopFpWG+hIbjqqdigqabBqFpmjiRDZgDy4zJJj52xJZMnrp -rwIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwEAcVmY3589O02pLA22f -MlarLyJUgy0BeJDG5AUsi17ct8sHZzRiv9zKQVCBk1CtZY//jyqnrM7iCBLWsyby -TiTOtGYHHApaLnNjjtaHdQ6zplhbc3g2XLy+4ab8GNKG3zc8iXpsQM6r+JO5n9pm -V9vollz9dkFxS9l+1P17lZdIgCh9O3EIFJv5QCd5c9l2ezHAan2OhkWhiDtldnH/ -MfRXbz7X5sqlwWLa/jhPtvY45x7dZaCHGqNzbupQZs0vHnAVdDu3vAWDmT/3sXHG -vmGxswKA9tPU0prSvQWLz4LUCnGi/cC5R+fiu+fovFM/BwvaGtqBFIF/1oWVq7bZ -4wIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA25qGwNO1+qHygC8mjm8L -3I66mV/IzslgBDHC91mE8YcI5Fq0sdrtsbUhK3z89wIN/zOhbHX0NEiXm2GxUnsI -vb5tDZXAh7AbTnXTMVbxO/e/8sPLUiObGjDvjVzyzrxOeG87yK/oIiilwk9wTsIb -wMn2Grj4ht9gVKx3oGHYV7STNdWBlzSaJj4Ou7+5M1InjPDRFZG1K31D2d3IHByX -lmcRPZtPFTa5C1uVJw00fI4F4uEFlPclZQlR5yA0G9v+0uDgLcjIUB4eqwMthUWc -dHhlmrPp04LI19eksWHCtG30RzmUaxDiIC7J2Ut0zHDqUe7aXn8tOVI7dE9tTKQD -KQIDAQAB ------END PUBLIC KEY----- -)", - R"( ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA7EC2bx7aRnf3TcRg5gmw -QOKNCUheCelK8hoXLMsKSJqmufyJ+IHUejpXGOpvyYRbACiJ5GiNcww20MVpTBU7 -YESWB2QSU2eEJJXMq84qsZSO8WGmAuKpUckI+hNHKQYJBEDOougV6/vVVEm5c5bc -SLWQo0+/ciQ21Zwz5SwimX8ep1YpqYirO04gcyGZzAfGboXRvdUwA+1bZvuUXdKC -4zsCw2QALlcVpzPwjB5mqA/3a+SPgdLAiLOwWXFDRMnQw44UjsnPJFoXgEZiUpZm -EMS5gLv50CzQqJXK9mNzPuYXNUIc4Pw4ssVWe0OfN3Od90gl5uFUwk/G9lWSYnBN -3wIDAQAB ------END PUBLIC KEY----- -)"}; - -const vector ExtensionHelper::GetPublicKeys() { - return public_keys; -} - -} // namespace duckdb - - - - - -#ifndef DISABLE_DUCKDB_REMOTE_INSTALL -#ifndef DUCKDB_DISABLE_EXTENSION_LOAD - -#endif -#endif - - -#include - -namespace duckdb { - 
//===--------------------------------------------------------------------===//
// Install Extension
//===--------------------------------------------------------------------===//
// Ensures a version tag carries a leading 'v' (e.g. "0.8.1" -> "v0.8.1").
const string ExtensionHelper::NormalizeVersionTag(const string &version_tag) {
	if (version_tag.length() > 0 && version_tag[0] != 'v') {
		return "v" + version_tag;
	}
	return version_tag;
}

// A version tag is a release unless it contains the "-dev" marker.
bool ExtensionHelper::IsRelease(const string &version_tag) {
	return !StringUtil::Contains(version_tag, "-dev");
}

// Directory component identifying this build: the wasm version when defined,
// else the normalized release tag, else the source commit id for dev builds.
const string ExtensionHelper::GetVersionDirectoryName() {
#ifdef DUCKDB_WASM_VERSION
	return DUCKDB_QUOTE_DEFINE(DUCKDB_WASM_VERSION);
#endif
	if (IsRelease(DuckDB::LibraryVersion())) {
		return NormalizeVersionTag(DuckDB::LibraryVersion());
	} else {
		return DuckDB::SourceID();
	}
}

// Path components of the extension directory relative to its base directory:
// .duckdb/extensions/<version>/<platform>
// (NOTE(review): the vector's template argument was lost in extraction)
const vector ExtensionHelper::PathComponents() {
	return vector {".duckdb", "extensions", GetVersionDirectoryName(), DuckDB::Platform()};
}

// Resolves (and creates when missing) the directory extensions are installed
// into: either config.options.extension_directory or the home directory,
// extended with PathComponents(). Not supported under duckdb-wasm.
string ExtensionHelper::ExtensionDirectory(DBConfig &config, FileSystem &fs) {
#ifdef WASM_LOADABLE_EXTENSIONS
	throw PermissionException("ExtensionDirectory functionality is not supported in duckdb-wasm");
#endif
	string extension_directory;
	if (!config.options.extension_directory.empty()) { // create the extension directory if not present
		extension_directory = config.options.extension_directory;
		// TODO this should probably live in the FileSystem
		// convert random separators to platform-canonic
		extension_directory = fs.ConvertSeparators(extension_directory);
		// expand ~ in extension directory
		extension_directory = fs.ExpandPath(extension_directory);
		if (!fs.DirectoryExists(extension_directory)) {
			auto sep = fs.PathSeparator(extension_directory);
			auto splits = StringUtil::Split(extension_directory, sep);
			D_ASSERT(!splits.empty());
			string extension_directory_prefix;
			if (StringUtil::StartsWith(extension_directory, sep)) {
				extension_directory_prefix = sep; // this is swallowed by Split otherwise
			}
			// create each missing path prefix in turn (mkdir -p behaviour)
			for (auto &split : splits) {
				extension_directory_prefix = extension_directory_prefix + split + sep;
				if (!fs.DirectoryExists(extension_directory_prefix)) {
					fs.CreateDirectory(extension_directory_prefix);
				}
			}
		}
	} else { // otherwise default to home
		string home_directory = fs.GetHomeDirectory();
		// exception if the home directory does not exist, don't create whatever we think is home
		if (!fs.DirectoryExists(home_directory)) {
			throw IOException("Can't find the home directory at '%s'\nSpecify a home directory using the SET "
			                  "home_directory='/path/to/dir' option.",
			                  home_directory);
		}
		extension_directory = home_directory;
	}
	D_ASSERT(fs.DirectoryExists(extension_directory));

	auto path_components = PathComponents();
	for (auto &path_ele : path_components) {
		extension_directory = fs.JoinPath(extension_directory, path_ele);
		if (!fs.DirectoryExists(extension_directory)) {
			fs.CreateDirectory(extension_directory);
		}
	}
	return extension_directory;
}

// Convenience overload pulling config and filesystem from the client context.
string ExtensionHelper::ExtensionDirectory(ClientContext &context) {
	auto &config = DBConfig::GetConfig(context);
	auto &fs = FileSystem::GetFileSystem(context);
	return ExtensionDirectory(config, fs);
}

// Fills `message` with a "Candidate extensions" suggestion list (Levenshtein-
// closest known extension names and aliases) for `extension_name`.
// Returns true when the name is itself a known extension.
// (the function body continues past the end of this chunk)
bool ExtensionHelper::CreateSuggestions(const string &extension_name, string &message) {
	vector candidates;
	for (idx_t ext_count = ExtensionHelper::DefaultExtensionCount(), i = 0; i < ext_count; i++) {
		candidates.emplace_back(ExtensionHelper::GetDefaultExtension(i).name);
	}
	for (idx_t ext_count = ExtensionHelper::ExtensionAliasCount(), i = 0; i < ext_count; i++) {
		candidates.emplace_back(ExtensionHelper::GetExtensionAlias(i).alias);
	}
	auto closest_extensions = StringUtil::TopNLevenshtein(candidates, extension_name);
	message = StringUtil::CandidatesMessage(closest_extensions, "Candidate extensions");
	for (auto &closest : closest_extensions) {
		if (closest == extension_name) {
			message = "Extension \"" + extension_name + "\" is an existing extension.\n";
			return 
true; - } - } - return false; -} - -void ExtensionHelper::InstallExtension(DBConfig &config, FileSystem &fs, const string &extension, bool force_install, - const string &repository) { -#ifdef WASM_LOADABLE_EXTENSIONS - // Install is currently a no-op - return; -#endif - string local_path = ExtensionDirectory(config, fs); - InstallExtensionInternal(config, nullptr, fs, local_path, extension, force_install, repository); -} - -void ExtensionHelper::InstallExtension(ClientContext &context, const string &extension, bool force_install, - const string &repository) { -#ifdef WASM_LOADABLE_EXTENSIONS - // Install is currently a no-op - return; -#endif - auto &config = DBConfig::GetConfig(context); - auto &fs = FileSystem::GetFileSystem(context); - string local_path = ExtensionDirectory(context); - auto &client_config = ClientConfig::GetConfig(context); - InstallExtensionInternal(config, &client_config, fs, local_path, extension, force_install, repository); -} - -unsafe_unique_array ReadExtensionFileFromDisk(FileSystem &fs, const string &path, idx_t &file_size) { - auto source_file = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ); - file_size = source_file->GetFileSize(); - auto in_buffer = make_unsafe_uniq_array(file_size); - source_file->Read(in_buffer.get(), file_size); - source_file->Close(); - return in_buffer; -} - -void WriteExtensionFileToDisk(FileSystem &fs, const string &path, void *data, idx_t data_size) { - auto target_file = fs.OpenFile(path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_APPEND | - FileFlags::FILE_FLAGS_FILE_CREATE_NEW); - target_file->Write(data, data_size); - target_file->Close(); - target_file.reset(); -} - -string ExtensionHelper::ExtensionUrlTemplate(optional_ptr client_config, const string &repository) { - string versioned_path = "/${REVISION}/${PLATFORM}/${NAME}.duckdb_extension"; -#ifdef WASM_LOADABLE_EXTENSIONS - string default_endpoint = "https://extensions.duckdb.org"; - versioned_path = "/duckdb-wasm" + versioned_path + ".wasm"; 
-#else - string default_endpoint = "http://extensions.duckdb.org"; - versioned_path = versioned_path + ".gz"; -#endif - string custom_endpoint = client_config ? client_config->custom_extension_repo : string(); - string endpoint; - if (!repository.empty()) { - endpoint = repository; - } else if (!custom_endpoint.empty()) { - endpoint = custom_endpoint; - } else { - endpoint = default_endpoint; - } - string url_template = endpoint + versioned_path; - return url_template; -} - -string ExtensionHelper::ExtensionFinalizeUrlTemplate(const string &url_template, const string &extension_name) { - auto url = StringUtil::Replace(url_template, "${REVISION}", GetVersionDirectoryName()); - url = StringUtil::Replace(url, "${PLATFORM}", DuckDB::Platform()); - url = StringUtil::Replace(url, "${NAME}", extension_name); - return url; -} - -void ExtensionHelper::InstallExtensionInternal(DBConfig &config, ClientConfig *client_config, FileSystem &fs, - const string &local_path, const string &extension, bool force_install, - const string &repository) { -#ifdef DUCKDB_DISABLE_EXTENSION_LOAD - throw PermissionException("Installing external extensions is disabled through a compile time flag"); -#else - if (!config.options.enable_external_access) { - throw PermissionException("Installing extensions is disabled through configuration"); - } - auto extension_name = ApplyExtensionAlias(fs.ExtractBaseName(extension)); - - string local_extension_path = fs.JoinPath(local_path, extension_name + ".duckdb_extension"); - if (fs.FileExists(local_extension_path) && !force_install) { - return; - } - - auto uuid = UUID::ToString(UUID::GenerateRandomUUID()); - string temp_path = local_extension_path + ".tmp-" + uuid; - if (fs.FileExists(temp_path)) { - fs.RemoveFile(temp_path); - } - auto is_http_url = StringUtil::Contains(extension, "http://"); - if (fs.FileExists(extension)) { - idx_t file_size; - auto in_buffer = ReadExtensionFileFromDisk(fs, extension, file_size); - WriteExtensionFileToDisk(fs, 
temp_path, in_buffer.get(), file_size); - - if (fs.FileExists(local_extension_path) && force_install) { - fs.RemoveFile(local_extension_path); - } - fs.MoveFile(temp_path, local_extension_path); - return; - } else if (StringUtil::Contains(extension, "/") && !is_http_url) { - throw IOException("Failed to read extension from \"%s\": no such file", extension); - } - -#ifdef DISABLE_DUCKDB_REMOTE_INSTALL - throw BinderException("Remote extension installation is disabled through configuration"); -#else - - string url_template = ExtensionUrlTemplate(client_config, repository); - - if (is_http_url) { - url_template = extension; - extension_name = ""; - } - - string url = ExtensionFinalizeUrlTemplate(url_template, extension_name); - - string no_http = StringUtil::Replace(url, "http://", ""); - - idx_t next = no_http.find('/', 0); - if (next == string::npos) { - throw IOException("No slash in URL template"); - } - - // Special case to install extension from a local file, useful for testing - if (!StringUtil::Contains(url_template, "http://")) { - string file = fs.ConvertSeparators(url); - if (!fs.FileExists(file)) { - // check for non-gzipped variant - file = file.substr(0, file.size() - 3); - if (!fs.FileExists(file)) { - throw IOException("Failed to copy local extension \"%s\" at PATH \"%s\"\n", extension_name, file); - } - } - auto read_handle = fs.OpenFile(file, FileFlags::FILE_FLAGS_READ); - auto test_data = std::unique_ptr {new unsigned char[read_handle->GetFileSize()]}; - read_handle->Read(test_data.get(), read_handle->GetFileSize()); - WriteExtensionFileToDisk(fs, temp_path, (void *)test_data.get(), read_handle->GetFileSize()); - - if (fs.FileExists(local_extension_path) && force_install) { - fs.RemoveFile(local_extension_path); - } - fs.MoveFile(temp_path, local_extension_path); - return; - } - - // Push the substring [last, next) on to splits - auto hostname_without_http = no_http.substr(0, next); - auto url_local_part = no_http.substr(next); - - auto url_base = 
"http://" + hostname_without_http; - duckdb_httplib::Client cli(url_base.c_str()); - - duckdb_httplib::Headers headers = {{"User-Agent", StringUtil::Format("DuckDB %s %s %s", DuckDB::LibraryVersion(), - DuckDB::SourceID(), DuckDB::Platform())}}; - - auto res = cli.Get(url_local_part.c_str(), headers); - - if (!res || res->status != 200) { - // create suggestions - string message; - auto exact_match = ExtensionHelper::CreateSuggestions(extension_name, message); - if (exact_match) { - message += "\nAre you using a development build? In this case, extensions might not (yet) be uploaded."; - } - if (res.error() == duckdb_httplib::Error::Success) { - throw HTTPException(res.value(), "Failed to download extension \"%s\" at URL \"%s%s\"\n%s", extension_name, - url_base, url_local_part, message); - } else { - throw IOException("Failed to download extension \"%s\" at URL \"%s%s\"\n%s (ERROR %s)", extension_name, - url_base, url_local_part, message, to_string(res.error())); - } - } - auto decompressed_body = GZipFileSystem::UncompressGZIPString(res->body); - - WriteExtensionFileToDisk(fs, temp_path, (void *)decompressed_body.data(), decompressed_body.size()); - - if (fs.FileExists(local_extension_path) && force_install) { - fs.RemoveFile(local_extension_path); - } - fs.MoveFile(temp_path, local_extension_path); -#endif -#endif -} - -} // namespace duckdb - - - - - - -#ifndef DUCKDB_NO_THREADS -#include -#endif // DUCKDB_NO_THREADS - -#ifdef WASM_LOADABLE_EXTENSIONS -#include -#endif - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Load External Extension -//===--------------------------------------------------------------------===// -#ifndef DUCKDB_DISABLE_EXTENSION_LOAD -typedef void (*ext_init_fun_t)(DatabaseInstance &); -typedef const char *(*ext_version_fun_t)(void); -typedef bool (*ext_is_storage_t)(void); - -template -static T LoadFunctionFromDLL(void *dll, const string &function_name, const string &filename) { 
- auto function = dlsym(dll, function_name.c_str()); - if (!function) { - throw IOException("File \"%s\" did not contain function \"%s\": %s", filename, function_name, GetDLError()); - } - return (T)function; -} - -static void ComputeSHA256String(const std::string &to_hash, std::string *res) { - // Invoke MbedTls function to actually compute sha256 - *res = duckdb_mbedtls::MbedTlsWrapper::ComputeSha256Hash(to_hash); -} - -static void ComputeSHA256FileSegment(FileHandle *handle, const idx_t start, const idx_t end, std::string *res) { - idx_t iter = start; - const idx_t segment_size = 1024 * 8; - - duckdb_mbedtls::MbedTlsWrapper::SHA256State state; - - std::string to_hash; - while (iter < end) { - idx_t len = std::min(end - iter, segment_size); - to_hash.resize(len); - handle->Read((void *)to_hash.data(), len, iter); - - state.AddString(to_hash); - - iter += segment_size; - } - - *res = state.Finalize(); -} -#endif - -bool ExtensionHelper::TryInitialLoad(DBConfig &config, FileSystem &fs, const string &extension, - ExtensionInitResult &result, string &error, - optional_ptr client_config) { -#ifdef DUCKDB_DISABLE_EXTENSION_LOAD - throw PermissionException("Loading external extensions is disabled through a compile time flag"); -#else - if (!config.options.enable_external_access) { - throw PermissionException("Loading external extensions is disabled through configuration"); - } - auto filename = fs.ConvertSeparators(extension); - - // shorthand case - if (!ExtensionHelper::IsFullPath(extension)) { - string extension_name = ApplyExtensionAlias(extension); -#ifdef WASM_LOADABLE_EXTENSIONS - string url_template = ExtensionUrlTemplate(client_config, ""); - string url = ExtensionFinalizeUrlTemplate(url_template, extension_name); - - char *str = (char *)EM_ASM_PTR( - { - var jsString = ((typeof runtime == 'object') && runtime && (typeof runtime.whereToLoad == 'function') && - runtime.whereToLoad) - ? 
runtime.whereToLoad(UTF8ToString($0)) - : (UTF8ToString($1)); - var lengthBytes = lengthBytesUTF8(jsString) + 1; - // 'jsString.length' would return the length of the string as UTF-16 - // units, but Emscripten C strings operate as UTF-8. - var stringOnWasmHeap = _malloc(lengthBytes); - stringToUTF8(jsString, stringOnWasmHeap, lengthBytes); - return stringOnWasmHeap; - }, - filename.c_str(), url.c_str()); - std::string address(str); - free(str); - - filename = address; -#else - - string local_path = - !config.options.extension_directory.empty() ? config.options.extension_directory : fs.GetHomeDirectory(); - - // convert random separators to platform-canonic - local_path = fs.ConvertSeparators(local_path); - // expand ~ in extension directory - local_path = fs.ExpandPath(local_path); - auto path_components = PathComponents(); - for (auto &path_ele : path_components) { - local_path = fs.JoinPath(local_path, path_ele); - } - filename = fs.JoinPath(local_path, extension_name + ".duckdb_extension"); -#endif - } - if (!fs.FileExists(filename)) { - string message; - bool exact_match = ExtensionHelper::CreateSuggestions(extension, message); - if (exact_match) { - message += "\nInstall it first using \"INSTALL " + extension + "\"."; - } - error = StringUtil::Format("Extension \"%s\" not found.\n%s", filename, message); - return false; - } - if (!config.options.allow_unsigned_extensions) { - auto handle = fs.OpenFile(filename, FileFlags::FILE_FLAGS_READ); - - // signature is the last 256 bytes of the file - - string signature; - signature.resize(256); - - auto signature_offset = handle->GetFileSize() - signature.size(); - - const idx_t maxLenChunks = 1024ULL * 1024ULL; - const idx_t numChunks = (signature_offset + maxLenChunks - 1) / maxLenChunks; - std::vector hash_chunks(numChunks); - std::vector splits(numChunks + 1); - - for (idx_t i = 0; i < numChunks; i++) { - splits[i] = maxLenChunks * i; - } - splits.back() = signature_offset; - -#ifndef DUCKDB_NO_THREADS - 
std::vector threads; - threads.reserve(numChunks); - for (idx_t i = 0; i < numChunks; i++) { - threads.emplace_back(ComputeSHA256FileSegment, handle.get(), splits[i], splits[i + 1], &hash_chunks[i]); - } - - for (auto &thread : threads) { - thread.join(); - } -#else - for (idx_t i = 0; i < numChunks; i++) { - ComputeSHA256FileSegment(handle.get(), splits[i], splits[i + 1], &hash_chunks[i]); - } -#endif // DUCKDB_NO_THREADS - - string hash_concatenation; - hash_concatenation.reserve(32 * numChunks); // 256 bits -> 32 bytes per chunk - - for (auto &hash_chunk : hash_chunks) { - hash_concatenation += hash_chunk; - } - - string two_level_hash; - ComputeSHA256String(hash_concatenation, &two_level_hash); - - // TODO maybe we should do a stream read / hash update here - handle->Read((void *)signature.data(), signature.size(), signature_offset); - - bool any_valid = false; - for (auto &key : ExtensionHelper::GetPublicKeys()) { - if (duckdb_mbedtls::MbedTlsWrapper::IsValidSha256Signature(key, signature, two_level_hash)) { - any_valid = true; - break; - } - } - if (!any_valid) { - throw IOException(config.error_manager->FormatException(ErrorType::UNSIGNED_EXTENSION, filename)); - } - } - auto basename = fs.ExtractBaseName(filename); - -#ifdef WASM_LOADABLE_EXTENSIONS - EM_ASM( - { - // Next few lines should argubly in separate JavaScript-land function call - // TODO: move them out / have them configurable - const xhr = new XMLHttpRequest(); - xhr.open("GET", UTF8ToString($0), false); - xhr.responseType = "arraybuffer"; - xhr.send(null); - var uInt8Array = xhr.response; - WebAssembly.validate(uInt8Array); - console.log('Loading extension ', UTF8ToString($1)); - - // Here we add the uInt8Array to Emscripten's filesystem, for it to be found by dlopen - FS.writeFile(UTF8ToString($1), new Uint8Array(uInt8Array)); - }, - filename.c_str(), basename.c_str()); - auto dopen_from = basename; -#else - auto dopen_from = filename; -#endif - - auto lib_hdl = dlopen(dopen_from.c_str(), 
RTLD_NOW | RTLD_LOCAL); - if (!lib_hdl) { - throw IOException("Extension \"%s\" could not be loaded: %s", filename, GetDLError()); - } - - ext_version_fun_t version_fun; - auto version_fun_name = basename + "_version"; - - version_fun = LoadFunctionFromDLL(lib_hdl, version_fun_name, filename); - - std::string engine_version = std::string(DuckDB::LibraryVersion()); - - auto version_fun_result = (*version_fun)(); - if (version_fun_result == nullptr) { - throw InvalidInputException("Extension \"%s\" returned a nullptr", filename); - } - std::string extension_version = std::string(version_fun_result); - - // Trim v's if necessary - std::string extension_version_trimmed = extension_version; - std::string engine_version_trimmed = engine_version; - if (extension_version.length() > 0 && extension_version[0] == 'v') { - extension_version_trimmed = extension_version.substr(1); - } - if (engine_version.length() > 0 && engine_version[0] == 'v') { - engine_version_trimmed = engine_version.substr(1); - } - - if (extension_version_trimmed != engine_version_trimmed) { - throw InvalidInputException("Extension \"%s\" version (%s) does not match DuckDB version (%s)", filename, - extension_version, engine_version); - } - - result.basename = basename; - result.filename = filename; - result.lib_hdl = lib_hdl; - return true; -#endif -} - -ExtensionInitResult ExtensionHelper::InitialLoad(DBConfig &config, FileSystem &fs, const string &extension, - optional_ptr client_config) { - string error; - ExtensionInitResult result; - if (!TryInitialLoad(config, fs, extension, result, error, client_config)) { - if (!ExtensionHelper::AllowAutoInstall(extension)) { - throw IOException(error); - } - // the extension load failed - try installing the extension - ExtensionHelper::InstallExtension(config, fs, extension, false); - // try loading again - if (!TryInitialLoad(config, fs, extension, result, error, client_config)) { - throw IOException(error); - } - } - return result; -} - -bool 
ExtensionHelper::IsFullPath(const string &extension) { - return StringUtil::Contains(extension, ".") || StringUtil::Contains(extension, "/") || - StringUtil::Contains(extension, "\\"); -} - -string ExtensionHelper::GetExtensionName(const string &original_name) { - auto extension = StringUtil::Lower(original_name); - if (!IsFullPath(extension)) { - return ExtensionHelper::ApplyExtensionAlias(extension); - } - auto splits = StringUtil::Split(StringUtil::Replace(extension, "\\", "/"), '/'); - if (splits.empty()) { - return ExtensionHelper::ApplyExtensionAlias(extension); - } - splits = StringUtil::Split(splits.back(), '.'); - if (splits.empty()) { - return ExtensionHelper::ApplyExtensionAlias(extension); - } - return ExtensionHelper::ApplyExtensionAlias(splits.front()); -} - -void ExtensionHelper::LoadExternalExtension(DatabaseInstance &db, FileSystem &fs, const string &extension, - optional_ptr client_config) { - if (db.ExtensionIsLoaded(extension)) { - return; - } -#ifdef DUCKDB_DISABLE_EXTENSION_LOAD - throw PermissionException("Loading external extensions is disabled through a compile time flag"); -#else - auto res = InitialLoad(DBConfig::GetConfig(db), fs, extension, client_config); - auto init_fun_name = res.basename + "_init"; - - ext_init_fun_t init_fun; - init_fun = LoadFunctionFromDLL(res.lib_hdl, init_fun_name, res.filename); - - try { - (*init_fun)(db); - } catch (std::exception &e) { - throw InvalidInputException("Initialization function \"%s\" from file \"%s\" threw an exception: \"%s\"", - init_fun_name, res.filename, e.what()); - } - - db.SetExtensionLoaded(extension); -#endif -} - -void ExtensionHelper::LoadExternalExtension(ClientContext &context, const string &extension) { - LoadExternalExtension(DatabaseInstance::GetDatabase(context), FileSystem::GetFileSystem(context), extension, - &ClientConfig::GetConfig(context)); -} - -string ExtensionHelper::ExtractExtensionPrefixFromPath(const string &path) { - auto first_colon = path.find(':'); - if 
(first_colon == string::npos || first_colon < 2) { // needs to be at least two characters because windows c: ... - return ""; - } - auto extension = path.substr(0, first_colon); - - if (path.substr(first_colon, 3) == "://") { - // these are not extensions - return ""; - } - - D_ASSERT(extension.size() > 1); - // needs to be alphanumeric - for (auto &ch : extension) { - if (!isalnum(ch) && ch != '_') { - return ""; - } - } - return extension; -} - -} // namespace duckdb - - - - - - - - - - - - - - - -namespace duckdb { - -void ExtensionUtil::RegisterFunction(DatabaseInstance &db, ScalarFunctionSet set) { - D_ASSERT(!set.name.empty()); - CreateScalarFunctionInfo info(std::move(set)); - auto &system_catalog = Catalog::GetSystemCatalog(db); - auto data = CatalogTransaction::GetSystemTransaction(db); - system_catalog.CreateFunction(data, info); -} - -void ExtensionUtil::RegisterFunction(DatabaseInstance &db, ScalarFunction function) { - D_ASSERT(!function.name.empty()); - ScalarFunctionSet set(function.name); - set.AddFunction(std::move(function)); - RegisterFunction(db, std::move(set)); -} - -void ExtensionUtil::RegisterFunction(DatabaseInstance &db, AggregateFunction function) { - D_ASSERT(!function.name.empty()); - AggregateFunctionSet set(function.name); - set.AddFunction(std::move(function)); - RegisterFunction(db, std::move(set)); -} - -void ExtensionUtil::RegisterFunction(DatabaseInstance &db, AggregateFunctionSet set) { - D_ASSERT(!set.name.empty()); - CreateAggregateFunctionInfo info(std::move(set)); - auto &system_catalog = Catalog::GetSystemCatalog(db); - auto data = CatalogTransaction::GetSystemTransaction(db); - system_catalog.CreateFunction(data, info); -} - -void ExtensionUtil::RegisterFunction(DatabaseInstance &db, TableFunction function) { - D_ASSERT(!function.name.empty()); - TableFunctionSet set(function.name); - set.AddFunction(std::move(function)); - RegisterFunction(db, std::move(set)); -} - -void ExtensionUtil::RegisterFunction(DatabaseInstance 
&db, TableFunctionSet function) { - D_ASSERT(!function.name.empty()); - CreateTableFunctionInfo info(std::move(function)); - auto &system_catalog = Catalog::GetSystemCatalog(db); - auto data = CatalogTransaction::GetSystemTransaction(db); - system_catalog.CreateFunction(data, info); -} - -void ExtensionUtil::RegisterFunction(DatabaseInstance &db, PragmaFunction function) { - D_ASSERT(!function.name.empty()); - PragmaFunctionSet set(function.name); - set.AddFunction(std::move(function)); - RegisterFunction(db, std::move(set)); -} - -void ExtensionUtil::RegisterFunction(DatabaseInstance &db, PragmaFunctionSet function) { - D_ASSERT(!function.name.empty()); - auto function_name = function.name; - CreatePragmaFunctionInfo info(std::move(function_name), std::move(function)); - auto &system_catalog = Catalog::GetSystemCatalog(db); - auto data = CatalogTransaction::GetSystemTransaction(db); - system_catalog.CreatePragmaFunction(data, info); -} - -void ExtensionUtil::RegisterFunction(DatabaseInstance &db, CopyFunction function) { - CreateCopyFunctionInfo info(std::move(function)); - auto &system_catalog = Catalog::GetSystemCatalog(db); - auto data = CatalogTransaction::GetSystemTransaction(db); - system_catalog.CreateCopyFunction(data, info); -} - -void ExtensionUtil::RegisterFunction(DatabaseInstance &db, CreateMacroInfo &info) { - auto &system_catalog = Catalog::GetSystemCatalog(db); - auto data = CatalogTransaction::GetSystemTransaction(db); - system_catalog.CreateFunction(data, info); -} - -void ExtensionUtil::RegisterCollation(DatabaseInstance &db, CreateCollationInfo &info) { - auto &system_catalog = Catalog::GetSystemCatalog(db); - auto data = CatalogTransaction::GetSystemTransaction(db); - info.on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; - system_catalog.CreateCollation(data, info); -} - -void ExtensionUtil::AddFunctionOverload(DatabaseInstance &db, ScalarFunction function) { - auto &scalar_function = ExtensionUtil::GetFunction(db, function.name); - 
scalar_function.functions.AddFunction(std::move(function)); -} - -void ExtensionUtil::AddFunctionOverload(DatabaseInstance &db, ScalarFunctionSet functions) { // NOLINT - D_ASSERT(!functions.name.empty()); - auto &scalar_function = ExtensionUtil::GetFunction(db, functions.name); - for (auto &function : functions.functions) { - function.name = functions.name; - scalar_function.functions.AddFunction(std::move(function)); - } -} - -void ExtensionUtil::AddFunctionOverload(DatabaseInstance &db, TableFunctionSet functions) { // NOLINT - auto &table_function = ExtensionUtil::GetTableFunction(db, functions.name); - for (auto &function : functions.functions) { - function.name = functions.name; - table_function.functions.AddFunction(std::move(function)); - } -} - -ScalarFunctionCatalogEntry &ExtensionUtil::GetFunction(DatabaseInstance &db, const string &name) { - D_ASSERT(!name.empty()); - auto &system_catalog = Catalog::GetSystemCatalog(db); - auto data = CatalogTransaction::GetSystemTransaction(db); - auto &schema = system_catalog.GetSchema(data, DEFAULT_SCHEMA); - auto catalog_entry = schema.GetEntry(data, CatalogType::SCALAR_FUNCTION_ENTRY, name); - if (!catalog_entry) { - throw InvalidInputException("Function with name \"%s\" not found in ExtensionUtil::GetFunction", name); - } - return catalog_entry->Cast(); -} - -TableFunctionCatalogEntry &ExtensionUtil::GetTableFunction(DatabaseInstance &db, const string &name) { - D_ASSERT(!name.empty()); - auto &system_catalog = Catalog::GetSystemCatalog(db); - auto data = CatalogTransaction::GetSystemTransaction(db); - auto &schema = system_catalog.GetSchema(data, DEFAULT_SCHEMA); - auto catalog_entry = schema.GetEntry(data, CatalogType::TABLE_FUNCTION_ENTRY, name); - if (!catalog_entry) { - throw InvalidInputException("Function with name \"%s\" not found in ExtensionUtil::GetTableFunction", name); - } - return catalog_entry->Cast(); -} - -void ExtensionUtil::RegisterType(DatabaseInstance &db, string type_name, LogicalType type) { 
- D_ASSERT(!type_name.empty()); - CreateTypeInfo info(std::move(type_name), std::move(type)); - info.temporary = true; - info.internal = true; - auto &system_catalog = Catalog::GetSystemCatalog(db); - auto data = CatalogTransaction::GetSystemTransaction(db); - system_catalog.CreateType(data, info); -} - -void ExtensionUtil::RegisterCastFunction(DatabaseInstance &db, const LogicalType &source, const LogicalType &target, - BoundCastInfo function, int64_t implicit_cast_cost) { - auto &config = DBConfig::GetConfig(db); - auto &casts = config.GetCastFunctions(); - casts.RegisterCastFunction(source, target, std::move(function), implicit_cast_cost); -} - -} // namespace duckdb - - -namespace duckdb { - -Extension::~Extension() { -} - -} // namespace duckdb - - - - - -namespace duckdb { - -MaterializedQueryResult::MaterializedQueryResult(StatementType statement_type, StatementProperties properties, - vector names_p, unique_ptr collection_p, - ClientProperties client_properties) - : QueryResult(QueryResultType::MATERIALIZED_RESULT, statement_type, std::move(properties), collection_p->Types(), - std::move(names_p), std::move(client_properties)), - collection(std::move(collection_p)), scan_initialized(false) { -} - -MaterializedQueryResult::MaterializedQueryResult(PreservedError error) - : QueryResult(QueryResultType::MATERIALIZED_RESULT, std::move(error)), scan_initialized(false) { -} - -string MaterializedQueryResult::ToString() { - string result; - if (success) { - result = HeaderToString(); - result += "[ Rows: " + to_string(collection->Count()) + "]\n"; - auto &coll = Collection(); - for (auto &row : coll.Rows()) { - for (idx_t col_idx = 0; col_idx < coll.ColumnCount(); col_idx++) { - if (col_idx > 0) { - result += "\t"; - } - auto val = row.GetValue(col_idx); - result += val.IsNull() ? 
"NULL" : StringUtil::Replace(val.ToString(), string("\0", 1), "\\0"); - } - result += "\n"; - } - result += "\n"; - } else { - result = GetError() + "\n"; - } - return result; -} - -string MaterializedQueryResult::ToBox(ClientContext &context, const BoxRendererConfig &config) { - if (!success) { - return GetError() + "\n"; - } - if (!collection) { - return "Internal error - result was successful but there was no collection"; - } - BoxRenderer renderer(config); - return renderer.ToString(context, names, Collection()); -} - -Value MaterializedQueryResult::GetValue(idx_t column, idx_t index) { - if (!row_collection) { - row_collection = make_uniq(collection->GetRows()); - } - return row_collection->GetValue(column, index); -} - -idx_t MaterializedQueryResult::RowCount() const { - return collection ? collection->Count() : 0; -} - -ColumnDataCollection &MaterializedQueryResult::Collection() { - if (HasError()) { - throw InvalidInputException("Attempting to get collection from an unsuccessful query result\n: Error %s", - GetError()); - } - if (!collection) { - throw InternalException("Missing collection from materialized query result"); - } - return *collection; -} - -unique_ptr MaterializedQueryResult::Fetch() { - return FetchRaw(); -} - -unique_ptr MaterializedQueryResult::FetchRaw() { - if (HasError()) { - throw InvalidInputException("Attempting to fetch from an unsuccessful query result\nError: %s", GetError()); - } - auto result = make_uniq(); - collection->InitializeScanChunk(*result); - if (!scan_initialized) { - // we disallow zero copy so the chunk is independently usable even after the result is destroyed - collection->InitializeScan(scan_state, ColumnDataScanProperties::DISALLOW_ZERO_COPY); - scan_initialized = true; - } - collection->Scan(scan_state, *result); - if (result->size() == 0) { - return nullptr; - } - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -PendingQueryResult::PendingQueryResult(shared_ptr context_p, 
PreparedStatementData &statement, - vector types_p, bool allow_stream_result) - : BaseQueryResult(QueryResultType::PENDING_RESULT, statement.statement_type, statement.properties, - std::move(types_p), statement.names), - context(std::move(context_p)), allow_stream_result(allow_stream_result) { -} - -PendingQueryResult::PendingQueryResult(PreservedError error) - : BaseQueryResult(QueryResultType::PENDING_RESULT, std::move(error)) { -} - -PendingQueryResult::~PendingQueryResult() { -} - -unique_ptr PendingQueryResult::LockContext() { - if (!context) { - if (HasError()) { - throw InvalidInputException( - "Attempting to execute an unsuccessful or closed pending query result\nError: %s", GetError()); - } - throw InvalidInputException("Attempting to execute an unsuccessful or closed pending query result"); - } - return context->LockContext(); -} - -void PendingQueryResult::CheckExecutableInternal(ClientContextLock &lock) { - bool invalidated = HasError() || !context; - if (!invalidated) { - invalidated = !context->IsActiveResult(lock, this); - } - if (invalidated) { - if (HasError()) { - throw InvalidInputException( - "Attempting to execute an unsuccessful or closed pending query result\nError: %s", GetError()); - } - throw InvalidInputException("Attempting to execute an unsuccessful or closed pending query result"); - } -} - -PendingExecutionResult PendingQueryResult::ExecuteTask() { - auto lock = LockContext(); - return ExecuteTaskInternal(*lock); -} - -PendingExecutionResult PendingQueryResult::ExecuteTaskInternal(ClientContextLock &lock) { - CheckExecutableInternal(lock); - return context->ExecuteTaskInternal(lock, *this); -} - -unique_ptr PendingQueryResult::ExecuteInternal(ClientContextLock &lock) { - CheckExecutableInternal(lock); - // Busy wait while execution is not finished - while (!IsFinished(ExecuteTaskInternal(lock))) { - } - if (HasError()) { - return make_uniq(error); - } - auto result = context->FetchResultInternal(lock, *this); - Close(); - return 
result; -} - -unique_ptr PendingQueryResult::Execute() { - auto lock = LockContext(); - return ExecuteInternal(*lock); -} - -void PendingQueryResult::Close() { - context.reset(); -} - -bool PendingQueryResult::IsFinished(PendingExecutionResult result) { - if (result == PendingExecutionResult::RESULT_READY || result == PendingExecutionResult::EXECUTION_ERROR) { - return true; - } - return false; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -PreparedStatement::PreparedStatement(shared_ptr context, shared_ptr data_p, - string query, idx_t n_param, case_insensitive_map_t named_param_map_p) - : context(std::move(context)), data(std::move(data_p)), query(std::move(query)), success(true), n_param(n_param), - named_param_map(std::move(named_param_map_p)) { - D_ASSERT(data || !success); -} - -PreparedStatement::PreparedStatement(PreservedError error) : context(nullptr), success(false), error(std::move(error)) { -} - -PreparedStatement::~PreparedStatement() { -} - -const string &PreparedStatement::GetError() { - D_ASSERT(HasError()); - return error.Message(); -} - -PreservedError &PreparedStatement::GetErrorObject() { - return error; -} - -bool PreparedStatement::HasError() const { - return !success; -} - -idx_t PreparedStatement::ColumnCount() { - D_ASSERT(data); - return data->types.size(); -} - -StatementType PreparedStatement::GetStatementType() { - D_ASSERT(data); - return data->statement_type; -} - -StatementProperties PreparedStatement::GetStatementProperties() { - D_ASSERT(data); - return data->properties; -} - -const vector &PreparedStatement::GetTypes() { - D_ASSERT(data); - return data->types; -} - -const vector &PreparedStatement::GetNames() { - D_ASSERT(data); - return data->names; -} - -case_insensitive_map_t PreparedStatement::GetExpectedParameterTypes() const { - D_ASSERT(data); - case_insensitive_map_t expected_types(data->value_map.size()); - for (auto &it : data->value_map) { - auto &identifier = it.first; - 
D_ASSERT(data->value_map.count(identifier)); - D_ASSERT(it.second); - expected_types[identifier] = it.second->GetValue().type(); - } - return expected_types; -} - -unique_ptr PreparedStatement::Execute(case_insensitive_map_t &named_values, - bool allow_stream_result) { - auto pending = PendingQuery(named_values, allow_stream_result); - if (pending->HasError()) { - return make_uniq(pending->GetErrorObject()); - } - return pending->Execute(); -} - -unique_ptr PreparedStatement::Execute(vector &values, bool allow_stream_result) { - auto pending = PendingQuery(values, allow_stream_result); - if (pending->HasError()) { - return make_uniq(pending->GetErrorObject()); - } - return pending->Execute(); -} - -unique_ptr PreparedStatement::PendingQuery(vector &values, bool allow_stream_result) { - case_insensitive_map_t named_values; - for (idx_t i = 0; i < values.size(); i++) { - auto &val = values[i]; - named_values[std::to_string(i + 1)] = val; - } - return PendingQuery(named_values, allow_stream_result); -} - -unique_ptr PreparedStatement::PendingQuery(case_insensitive_map_t &named_values, - bool allow_stream_result) { - if (!success) { - auto exception = InvalidInputException("Attempting to execute an unsuccessfully prepared statement!"); - return make_uniq(PreservedError(exception)); - } - PendingQueryParameters parameters; - parameters.parameters = &named_values; - - try { - VerifyParameters(named_values, named_param_map); - } catch (const Exception &ex) { - return make_uniq(PreservedError(ex)); - } - - D_ASSERT(data); - parameters.allow_stream_result = allow_stream_result && data->properties.allow_stream_result; - auto result = context->PendingQuery(query, data, parameters); - // The result should not contain any reference to the 'vector parameters.parameters' - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -PreparedStatementData::PreparedStatementData(StatementType type) : statement_type(type) { -} - 
-PreparedStatementData::~PreparedStatementData() { -} - -void PreparedStatementData::CheckParameterCount(idx_t parameter_count) { - const auto required = properties.parameter_count; - if (parameter_count != required) { - throw BinderException("Parameter/argument count mismatch for prepared statement. Expected %llu, got %llu", - required, parameter_count); - } -} - -bool PreparedStatementData::RequireRebind(ClientContext &context, optional_ptr> values) { - idx_t count = values ? values->size() : 0; - CheckParameterCount(count); - if (!unbound_statement) { - // no unbound statement!? cannot rebind? - return false; - } - if (!properties.bound_all_parameters) { - // parameters not yet bound: query always requires a rebind - return true; - } - if (Catalog::GetSystemCatalog(context).GetCatalogVersion() != catalog_version) { - //! context is out of bounds - return true; - } - for (auto &it : value_map) { - auto &identifier = it.first; - auto lookup = values->find(identifier); - D_ASSERT(lookup != values->end()); - if (lookup->second.type() != it.second->return_type) { - return true; - } - } - return false; -} - -void PreparedStatementData::Bind(case_insensitive_map_t values) { - // set parameters - D_ASSERT(!unbound_statement || unbound_statement->n_param == properties.parameter_count); - CheckParameterCount(values.size()); - - // bind the required values - for (auto &it : value_map) { - const string &identifier = it.first; - auto lookup = values.find(identifier); - if (lookup == values.end()) { - throw BinderException("Could not find parameter with identifier %s", identifier); - } - D_ASSERT(it.second); - auto &value = lookup->second; - if (!value.DefaultTryCastAs(it.second->return_type)) { - throw BinderException( - "Type mismatch for binding parameter with identifier %s, expected type %s but got type %s", identifier, - it.second->return_type.ToString().c_str(), value.type().ToString().c_str()); - } - it.second->SetValue(value); - } -} - -bool 
PreparedStatementData::TryGetType(const string &identifier, LogicalType &result) { - auto it = value_map.find(identifier); - if (it == value_map.end()) { - return false; - } - if (it->second->return_type.id() != LogicalTypeId::INVALID) { - result = it->second->return_type; - } else { - result = it->second->GetValue().type(); - } - return true; -} - -LogicalType PreparedStatementData::GetType(const string &identifier) { - LogicalType result; - if (!TryGetType(identifier, result)) { - throw BinderException("Could not find parameter identified with: %s", identifier); - } - return result; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - -#include -#include - -namespace duckdb { - -QueryProfiler::QueryProfiler(ClientContext &context_p) - : context(context_p), running(false), query_requires_profiling(false), is_explain_analyze(false) { -} - -bool QueryProfiler::IsEnabled() const { - return is_explain_analyze ? true : ClientConfig::GetConfig(context).enable_profiler; -} - -bool QueryProfiler::IsDetailedEnabled() const { - return is_explain_analyze ? false : ClientConfig::GetConfig(context).enable_detailed_profiling; -} - -ProfilerPrintFormat QueryProfiler::GetPrintFormat() const { - return ClientConfig::GetConfig(context).profiler_print_format; -} - -bool QueryProfiler::PrintOptimizerOutput() const { - return GetPrintFormat() == ProfilerPrintFormat::QUERY_TREE_OPTIMIZER || IsDetailedEnabled(); -} - -string QueryProfiler::GetSaveLocation() const { - return is_explain_analyze ? 
string() : ClientConfig::GetConfig(context).profiler_save_location; -} - -QueryProfiler &QueryProfiler::Get(ClientContext &context) { - return *ClientData::Get(context).profiler; -} - -void QueryProfiler::StartQuery(string query, bool is_explain_analyze, bool start_at_optimizer) { - if (is_explain_analyze) { - StartExplainAnalyze(); - } - if (!IsEnabled()) { - return; - } - if (start_at_optimizer && !PrintOptimizerOutput()) { - // This is the StartQuery call before the optimizer, but we don't have to print optimizer output - return; - } - if (running) { - // Called while already running: this should only happen when we print optimizer output - D_ASSERT(PrintOptimizerOutput()); - return; - } - this->running = true; - this->query = std::move(query); - tree_map.clear(); - root = nullptr; - phase_timings.clear(); - phase_stack.clear(); - - main_query.Start(); -} - -bool QueryProfiler::OperatorRequiresProfiling(PhysicalOperatorType op_type) { - switch (op_type) { - case PhysicalOperatorType::ORDER_BY: - case PhysicalOperatorType::RESERVOIR_SAMPLE: - case PhysicalOperatorType::STREAMING_SAMPLE: - case PhysicalOperatorType::LIMIT: - case PhysicalOperatorType::LIMIT_PERCENT: - case PhysicalOperatorType::STREAMING_LIMIT: - case PhysicalOperatorType::TOP_N: - case PhysicalOperatorType::WINDOW: - case PhysicalOperatorType::UNNEST: - case PhysicalOperatorType::UNGROUPED_AGGREGATE: - case PhysicalOperatorType::HASH_GROUP_BY: - case PhysicalOperatorType::FILTER: - case PhysicalOperatorType::PROJECTION: - case PhysicalOperatorType::COPY_TO_FILE: - case PhysicalOperatorType::TABLE_SCAN: - case PhysicalOperatorType::CHUNK_SCAN: - case PhysicalOperatorType::DELIM_SCAN: - case PhysicalOperatorType::EXPRESSION_SCAN: - case PhysicalOperatorType::BLOCKWISE_NL_JOIN: - case PhysicalOperatorType::NESTED_LOOP_JOIN: - case PhysicalOperatorType::HASH_JOIN: - case PhysicalOperatorType::CROSS_PRODUCT: - case PhysicalOperatorType::PIECEWISE_MERGE_JOIN: - case PhysicalOperatorType::IE_JOIN: - 
case PhysicalOperatorType::DELIM_JOIN: - case PhysicalOperatorType::UNION: - case PhysicalOperatorType::RECURSIVE_CTE: - case PhysicalOperatorType::EMPTY_RESULT: - return true; - default: - return false; - } -} - -void QueryProfiler::Finalize(TreeNode &node) { - for (auto &child : node.children) { - Finalize(*child); - if (node.type == PhysicalOperatorType::UNION) { - node.info.elements += child->info.elements; - } - } -} - -void QueryProfiler::StartExplainAnalyze() { - this->is_explain_analyze = true; -} - -void QueryProfiler::EndQuery() { - lock_guard guard(flush_lock); - if (!IsEnabled() || !running) { - return; - } - - main_query.End(); - if (root) { - Finalize(*root); - } - this->running = false; - // print or output the query profiling after termination - // EXPLAIN ANALYSE should not be outputted by the profiler - if (IsEnabled() && !is_explain_analyze) { - string query_info = ToString(); - auto save_location = GetSaveLocation(); - if (!ClientConfig::GetConfig(context).emit_profiler_output) { - // disable output - } else if (save_location.empty()) { - Printer::Print(query_info); - Printer::Print("\n"); - } else { - WriteToFile(save_location.c_str(), query_info); - } - } - this->is_explain_analyze = false; -} -string QueryProfiler::ToString() const { - const auto format = GetPrintFormat(); - switch (format) { - case ProfilerPrintFormat::QUERY_TREE: - case ProfilerPrintFormat::QUERY_TREE_OPTIMIZER: - return QueryTreeToString(); - case ProfilerPrintFormat::JSON: - return ToJSON(); - default: - throw InternalException("Unknown ProfilerPrintFormat \"%s\"", format); - } -} - -void QueryProfiler::StartPhase(string new_phase) { - if (!IsEnabled() || !running) { - return; - } - - if (!phase_stack.empty()) { - // there are active phases - phase_profiler.End(); - // add the timing to all phases prior to this one - string prefix = ""; - for (auto &phase : phase_stack) { - phase_timings[phase] += phase_profiler.Elapsed(); - prefix += phase + " > "; - } - // when there 
are previous phases, we prefix the current phase with those phases - new_phase = prefix + new_phase; - } - - // start a new phase - phase_stack.push_back(new_phase); - // restart the timer - phase_profiler.Start(); -} - -void QueryProfiler::EndPhase() { - if (!IsEnabled() || !running) { - return; - } - D_ASSERT(phase_stack.size() > 0); - - // end the timer - phase_profiler.End(); - // add the timing to all currently active phases - for (auto &phase : phase_stack) { - phase_timings[phase] += phase_profiler.Elapsed(); - } - // now remove the last added phase - phase_stack.pop_back(); - - if (!phase_stack.empty()) { - phase_profiler.Start(); - } -} - -void QueryProfiler::Initialize(const PhysicalOperator &root_op) { - if (!IsEnabled() || !running) { - return; - } - this->query_requires_profiling = false; - this->root = CreateTree(root_op); - if (!query_requires_profiling) { - // query does not require profiling: disable profiling for this query - this->running = false; - tree_map.clear(); - root = nullptr; - phase_timings.clear(); - phase_stack.clear(); - } -} - -OperatorProfiler::OperatorProfiler(bool enabled_p) : enabled(enabled_p), active_operator(nullptr) { -} - -void OperatorProfiler::StartOperator(optional_ptr phys_op) { - if (!enabled) { - return; - } - - if (active_operator) { - throw InternalException("OperatorProfiler: Attempting to call StartOperator while another operator is active"); - } - - active_operator = phys_op; - - // start timing for current element - op.Start(); -} - -void OperatorProfiler::EndOperator(optional_ptr chunk) { - if (!enabled) { - return; - } - - if (!active_operator) { - throw InternalException("OperatorProfiler: Attempting to call EndOperator while another operator is active"); - } - - // finish timing for the current element - op.End(); - - AddTiming(*active_operator, op.Elapsed(), chunk ? 
chunk->size() : 0); - active_operator = nullptr; -} - -void OperatorProfiler::AddTiming(const PhysicalOperator &op, double time, idx_t elements) { - if (!enabled) { - return; - } - if (!Value::DoubleIsFinite(time)) { - return; - } - auto entry = timings.find(op); - if (entry == timings.end()) { - // add new entry - timings[op] = OperatorInformation(time, elements); - } else { - // add to existing entry - entry->second.time += time; - entry->second.elements += elements; - } -} -void OperatorProfiler::Flush(const PhysicalOperator &phys_op, ExpressionExecutor &expression_executor, - const string &name, int id) { - auto entry = timings.find(phys_op); - if (entry == timings.end()) { - return; - } - auto &operator_timing = timings.find(phys_op)->second; - if (int(operator_timing.executors_info.size()) <= id) { - operator_timing.executors_info.resize(id + 1); - } - operator_timing.executors_info[id] = make_uniq(expression_executor, name, id); - operator_timing.name = phys_op.GetName(); -} - -void QueryProfiler::Flush(OperatorProfiler &profiler) { - lock_guard guard(flush_lock); - if (!IsEnabled() || !running) { - return; - } - for (auto &node : profiler.timings) { - auto &op = node.first.get(); - auto entry = tree_map.find(op); - D_ASSERT(entry != tree_map.end()); - auto &tree_node = entry->second.get(); - - tree_node.info.time += node.second.time; - tree_node.info.elements += node.second.elements; - if (!IsDetailedEnabled()) { - continue; - } - for (auto &info : node.second.executors_info) { - if (!info) { - continue; - } - auto info_id = info->id; - if (int32_t(tree_node.info.executors_info.size()) <= info_id) { - tree_node.info.executors_info.resize(info_id + 1); - } - tree_node.info.executors_info[info_id] = std::move(info); - } - } - profiler.timings.clear(); -} - -static string DrawPadded(const string &str, idx_t width) { - if (str.size() > width) { - return str.substr(0, width); - } else { - width -= str.size(); - int half_spaces = width / 2; - int extra_left_space 
= width % 2 != 0 ? 1 : 0; - return string(half_spaces + extra_left_space, ' ') + str + string(half_spaces, ' '); - } -} - -static string RenderTitleCase(string str) { - str = StringUtil::Lower(str); - str[0] = toupper(str[0]); - for (idx_t i = 0; i < str.size(); i++) { - if (str[i] == '_') { - str[i] = ' '; - if (i + 1 < str.size()) { - str[i + 1] = toupper(str[i + 1]); - } - } - } - return str; -} - -static string RenderTiming(double timing) { - string timing_s; - if (timing >= 1) { - timing_s = StringUtil::Format("%.2f", timing); - } else if (timing >= 0.1) { - timing_s = StringUtil::Format("%.3f", timing); - } else { - timing_s = StringUtil::Format("%.4f", timing); - } - return timing_s + "s"; -} - -string QueryProfiler::QueryTreeToString() const { - std::stringstream str; - QueryTreeToStream(str); - return str.str(); -} - -void QueryProfiler::QueryTreeToStream(std::ostream &ss) const { - if (!IsEnabled()) { - ss << "Query profiling is disabled. Call " - "Connection::EnableProfiling() to enable profiling!"; - return; - } - ss << "┌─────────────────────────────────────┐\n"; - ss << "│┌───────────────────────────────────┐│\n"; - ss << "││ Query Profiling Information ││\n"; - ss << "│└───────────────────────────────────┘│\n"; - ss << "└─────────────────────────────────────┘\n"; - ss << StringUtil::Replace(query, "\n", " ") + "\n"; - - // checking the tree to ensure the query is really empty - // the query string is empty when a logical plan is deserialized - if (query.empty() && !root) { - return; - } - - if (context.client_data->http_state && !context.client_data->http_state->IsEmpty()) { - string read = - "in: " + StringUtil::BytesToHumanReadableString(context.client_data->http_state->total_bytes_received); - string written = - "out: " + StringUtil::BytesToHumanReadableString(context.client_data->http_state->total_bytes_sent); - string head = "#HEAD: " + to_string(context.client_data->http_state->head_count); - string get = "#GET: " + 
to_string(context.client_data->http_state->get_count); - string put = "#PUT: " + to_string(context.client_data->http_state->put_count); - string post = "#POST: " + to_string(context.client_data->http_state->post_count); - - constexpr idx_t TOTAL_BOX_WIDTH = 39; - ss << "┌─────────────────────────────────────┐\n"; - ss << "│┌───────────────────────────────────┐│\n"; - ss << "││ HTTP Stats: ││\n"; - ss << "││ ││\n"; - ss << "││" + DrawPadded(read, TOTAL_BOX_WIDTH - 4) + "││\n"; - ss << "││" + DrawPadded(written, TOTAL_BOX_WIDTH - 4) + "││\n"; - ss << "││" + DrawPadded(head, TOTAL_BOX_WIDTH - 4) + "││\n"; - ss << "││" + DrawPadded(get, TOTAL_BOX_WIDTH - 4) + "││\n"; - ss << "││" + DrawPadded(put, TOTAL_BOX_WIDTH - 4) + "││\n"; - ss << "││" + DrawPadded(post, TOTAL_BOX_WIDTH - 4) + "││\n"; - ss << "│└───────────────────────────────────┘│\n"; - ss << "└─────────────────────────────────────┘\n"; - } - - constexpr idx_t TOTAL_BOX_WIDTH = 39; - ss << "┌─────────────────────────────────────┐\n"; - ss << "│┌───────────────────────────────────┐│\n"; - string total_time = "Total Time: " + RenderTiming(main_query.Elapsed()); - ss << "││" + DrawPadded(total_time, TOTAL_BOX_WIDTH - 4) + "││\n"; - ss << "│└───────────────────────────────────┘│\n"; - ss << "└─────────────────────────────────────┘\n"; - // print phase timings - if (PrintOptimizerOutput()) { - bool has_previous_phase = false; - for (const auto &entry : GetOrderedPhaseTimings()) { - if (!StringUtil::Contains(entry.first, " > ")) { - // primary phase! 
- if (has_previous_phase) { - ss << "│└───────────────────────────────────┘│\n"; - ss << "└─────────────────────────────────────┘\n"; - } - ss << "┌─────────────────────────────────────┐\n"; - ss << "│" + - DrawPadded(RenderTitleCase(entry.first) + ": " + RenderTiming(entry.second), - TOTAL_BOX_WIDTH - 2) + - "│\n"; - ss << "│┌───────────────────────────────────┐│\n"; - has_previous_phase = true; - } else { - string entry_name = StringUtil::Split(entry.first, " > ")[1]; - ss << "││" + - DrawPadded(RenderTitleCase(entry_name) + ": " + RenderTiming(entry.second), - TOTAL_BOX_WIDTH - 4) + - "││\n"; - } - } - if (has_previous_phase) { - ss << "│└───────────────────────────────────┘│\n"; - ss << "└─────────────────────────────────────┘\n"; - } - } - // render the main operator tree - if (root) { - Render(*root, ss); - } -} - -static string JSONSanitize(const string &text) { - string result; - result.reserve(text.size()); - for (idx_t i = 0; i < text.size(); i++) { - switch (text[i]) { - case '\b': - result += "\\b"; - break; - case '\f': - result += "\\f"; - break; - case '\n': - result += "\\n"; - break; - case '\r': - result += "\\r"; - break; - case '\t': - result += "\\t"; - break; - case '"': - result += "\\\""; - break; - case '\\': - result += "\\\\"; - break; - default: - result += text[i]; - break; - } - } - return result; -} - -// Print a row -static void PrintRow(std::ostream &ss, const string &annotation, int id, const string &name, double time, - int sample_counter, int tuple_counter, const string &extra_info, int depth) { - ss << string(depth * 3, ' ') << " {\n"; - ss << string(depth * 3, ' ') << " \"annotation\": \"" + JSONSanitize(annotation) + "\",\n"; - ss << string(depth * 3, ' ') << " \"id\": " + to_string(id) + ",\n"; - ss << string(depth * 3, ' ') << " \"name\": \"" + JSONSanitize(name) + "\",\n"; -#if defined(RDTSC) - ss << string(depth * 3, ' ') << " \"timing\": \"NULL\" ,\n"; - ss << string(depth * 3, ' ') << " \"cycles_per_tuple\": " + 
StringUtil::Format("%.4f", time) + ",\n"; -#else - ss << string(depth * 3, ' ') << " \"timing\":" + to_string(time) + ",\n"; - ss << string(depth * 3, ' ') << " \"cycles_per_tuple\": \"NULL\" ,\n"; -#endif - ss << string(depth * 3, ' ') << " \"sample_size\": " << to_string(sample_counter) + ",\n"; - ss << string(depth * 3, ' ') << " \"input_size\": " << to_string(tuple_counter) + ",\n"; - ss << string(depth * 3, ' ') << " \"extra_info\": \"" << JSONSanitize(extra_info) + "\"\n"; - ss << string(depth * 3, ' ') << " },\n"; -} - -static void ExtractFunctions(std::ostream &ss, ExpressionInfo &info, int &fun_id, int depth) { - if (info.hasfunction) { - double time = info.sample_tuples_count == 0 ? 0 : int(info.function_time) / double(info.sample_tuples_count); - PrintRow(ss, "Function", fun_id++, info.function_name, time, info.sample_tuples_count, info.tuples_count, "", - depth); - } - if (info.children.empty()) { - return; - } - // extract the children of this node - for (auto &child : info.children) { - ExtractFunctions(ss, *child, fun_id, depth); - } -} - -static void ToJSONRecursive(QueryProfiler::TreeNode &node, std::ostream &ss, int depth = 1) { - ss << string(depth * 3, ' ') << " {\n"; - ss << string(depth * 3, ' ') << " \"name\": \"" + JSONSanitize(node.name) + "\",\n"; - ss << string(depth * 3, ' ') << " \"timing\":" + to_string(node.info.time) + ",\n"; - ss << string(depth * 3, ' ') << " \"cardinality\":" + to_string(node.info.elements) + ",\n"; - ss << string(depth * 3, ' ') << " \"extra_info\": \"" + JSONSanitize(node.extra_info) + "\",\n"; - ss << string(depth * 3, ' ') << " \"timings\": ["; - int32_t function_counter = 1; - int32_t expression_counter = 1; - ss << "\n "; - for (auto &expr_executor : node.info.executors_info) { - // For each Expression tree - if (!expr_executor) { - continue; - } - for (auto &expr_timer : expr_executor->roots) { - double time = expr_timer->sample_tuples_count == 0 - ? 
0 - : double(expr_timer->time) / double(expr_timer->sample_tuples_count); - PrintRow(ss, "ExpressionRoot", expression_counter++, expr_timer->name, time, - expr_timer->sample_tuples_count, expr_timer->tuples_count, expr_timer->extra_info, depth + 1); - // Extract all functions inside the tree - ExtractFunctions(ss, *expr_timer->root, function_counter, depth + 1); - } - } - ss.seekp(-2, ss.cur); - ss << "\n"; - ss << string(depth * 3, ' ') << " ],\n"; - ss << string(depth * 3, ' ') << " \"children\": [\n"; - if (node.children.empty()) { - ss << string(depth * 3, ' ') << " ]\n"; - } else { - for (idx_t i = 0; i < node.children.size(); i++) { - if (i > 0) { - ss << ",\n"; - } - ToJSONRecursive(*node.children[i], ss, depth + 1); - } - ss << string(depth * 3, ' ') << " ]\n"; - } - ss << string(depth * 3, ' ') << " }\n"; -} - -string QueryProfiler::ToJSON() const { - if (!IsEnabled()) { - return "{ \"result\": \"disabled\" }\n"; - } - if (query.empty() && !root) { - return "{ \"result\": \"empty\" }\n"; - } - if (!root) { - return "{ \"result\": \"error\" }\n"; - } - std::stringstream ss; - ss << "{\n"; - ss << " \"name\": \"Query\", \n"; - ss << " \"result\": " + to_string(main_query.Elapsed()) + ",\n"; - ss << " \"timing\": " + to_string(main_query.Elapsed()) + ",\n"; - ss << " \"cardinality\": " + to_string(root->info.elements) + ",\n"; - // JSON cannot have literal control characters in string literals - string extra_info = JSONSanitize(query); - ss << " \"extra-info\": \"" + extra_info + "\", \n"; - // print the phase timings - ss << " \"timings\": [\n"; - const auto &ordered_phase_timings = GetOrderedPhaseTimings(); - for (idx_t i = 0; i < ordered_phase_timings.size(); i++) { - if (i > 0) { - ss << ",\n"; - } - ss << " {\n"; - ss << " \"annotation\": \"" + ordered_phase_timings[i].first + "\", \n"; - ss << " \"timing\": " + to_string(ordered_phase_timings[i].second) + "\n"; - ss << " }"; - } - ss << "\n"; - ss << " ],\n"; - // recursively print the physical operator 
tree - ss << " \"children\": [\n"; - ToJSONRecursive(*root, ss); - ss << " ]\n"; - ss << "}"; - return ss.str(); -} - -void QueryProfiler::WriteToFile(const char *path, string &info) const { - ofstream out(path); - out << info; - out.close(); - // throw an IO exception if it fails to write the file - if (out.fail()) { - throw IOException(strerror(errno)); - } -} - -unique_ptr QueryProfiler::CreateTree(const PhysicalOperator &root, idx_t depth) { - if (OperatorRequiresProfiling(root.type)) { - this->query_requires_profiling = true; - } - auto node = make_uniq(); - node->type = root.type; - node->name = root.GetName(); - node->extra_info = root.ParamsToString(); - node->depth = depth; - tree_map.insert(make_pair(reference(root), reference(*node))); - auto children = root.GetChildren(); - for (auto &child : children) { - auto child_node = CreateTree(child.get(), depth + 1); - node->children.push_back(std::move(child_node)); - } - return node; -} - -void QueryProfiler::Render(const QueryProfiler::TreeNode &node, std::ostream &ss) const { - TreeRenderer renderer; - if (IsDetailedEnabled()) { - renderer.EnableDetailed(); - } else { - renderer.EnableStandard(); - } - renderer.Render(node, ss); -} - -void QueryProfiler::Print() { - Printer::Print(QueryTreeToString()); -} - -vector QueryProfiler::GetOrderedPhaseTimings() const { - vector result; - // first sort the phases alphabetically - vector phases; - for (auto &entry : phase_timings) { - phases.push_back(entry.first); - } - std::sort(phases.begin(), phases.end()); - for (const auto &phase : phases) { - auto entry = phase_timings.find(phase); - D_ASSERT(entry != phase_timings.end()); - result.emplace_back(entry->first, entry->second); - } - return result; -} -void QueryProfiler::Propagate(QueryProfiler &qp) { -} - -void ExpressionInfo::ExtractExpressionsRecursive(unique_ptr &state) { - if (state->child_states.empty()) { - return; - } - // extract the children of this node - for (auto &child : state->child_states) { - 
auto expr_info = make_uniq(); - if (child->expr.expression_class == ExpressionClass::BOUND_FUNCTION) { - expr_info->hasfunction = true; - expr_info->function_name = child->expr.Cast().function.ToString(); - expr_info->function_time = child->profiler.time; - expr_info->sample_tuples_count = child->profiler.sample_tuples_count; - expr_info->tuples_count = child->profiler.tuples_count; - } - expr_info->ExtractExpressionsRecursive(child); - children.push_back(std::move(expr_info)); - } - return; -} - -ExpressionExecutorInfo::ExpressionExecutorInfo(ExpressionExecutor &executor, const string &name, int id) : id(id) { - // Extract Expression Root Information from ExpressionExecutorStats - for (auto &state : executor.GetStates()) { - roots.push_back(make_uniq(*state, name)); - } -} - -ExpressionRootInfo::ExpressionRootInfo(ExpressionExecutorState &state, string name) - : current_count(state.profiler.current_count), sample_count(state.profiler.sample_count), - sample_tuples_count(state.profiler.sample_tuples_count), tuples_count(state.profiler.tuples_count), - name("expression"), time(state.profiler.time) { - // Use the name of expression-tree as extra-info - extra_info = std::move(name); - auto expression_info_p = make_uniq(); - // Maybe root has a function - if (state.root_state->expr.expression_class == ExpressionClass::BOUND_FUNCTION) { - expression_info_p->hasfunction = true; - expression_info_p->function_name = (state.root_state->expr.Cast()).function.name; - expression_info_p->function_time = state.root_state->profiler.time; - expression_info_p->sample_tuples_count = state.root_state->profiler.sample_tuples_count; - expression_info_p->tuples_count = state.root_state->profiler.tuples_count; - } - expression_info_p->ExtractExpressionsRecursive(state.root_state); - root = std::move(expression_info_p); -} -} // namespace duckdb - - - - - - -namespace duckdb { - -BaseQueryResult::BaseQueryResult(QueryResultType type, StatementType statement_type, StatementProperties 
properties_p, - vector types_p, vector names_p) - : type(type), statement_type(statement_type), properties(std::move(properties_p)), types(std::move(types_p)), - names(std::move(names_p)), success(true) { - D_ASSERT(types.size() == names.size()); -} - -BaseQueryResult::BaseQueryResult(QueryResultType type, PreservedError error) - : type(type), success(false), error(std::move(error)) { -} - -BaseQueryResult::~BaseQueryResult() { -} - -void BaseQueryResult::ThrowError(const string &prepended_message) const { - D_ASSERT(HasError()); - error.Throw(prepended_message); -} - -void BaseQueryResult::SetError(PreservedError error) { - success = !error; - this->error = std::move(error); -} - -bool BaseQueryResult::HasError() const { - D_ASSERT((bool)error == !success); - return !success; -} - -const ExceptionType &BaseQueryResult::GetErrorType() const { - return error.Type(); -} - -const std::string &BaseQueryResult::GetError() { - D_ASSERT(HasError()); - return error.Message(); -} - -PreservedError &BaseQueryResult::GetErrorObject() { - return error; -} - -idx_t BaseQueryResult::ColumnCount() { - return types.size(); -} - -QueryResult::QueryResult(QueryResultType type, StatementType statement_type, StatementProperties properties, - vector types_p, vector names_p, ClientProperties client_properties_p) - : BaseQueryResult(type, statement_type, std::move(properties), std::move(types_p), std::move(names_p)), - client_properties(std::move(client_properties_p)) { -} - -QueryResult::QueryResult(QueryResultType type, PreservedError error) - : BaseQueryResult(type, std::move(error)), client_properties("UTC", ArrowOffsetSize::REGULAR) { -} - -QueryResult::~QueryResult() { -} - -const string &QueryResult::ColumnName(idx_t index) const { - D_ASSERT(index < names.size()); - return names[index]; -} - -string QueryResult::ToBox(ClientContext &context, const BoxRendererConfig &config) { - return ToString(); -} - -unique_ptr QueryResult::Fetch() { - auto chunk = FetchRaw(); - if (!chunk) { - 
return nullptr; - } - chunk->Flatten(); - return chunk; -} - -bool QueryResult::Equals(QueryResult &other) { // LCOV_EXCL_START - // first compare the success state of the results - if (success != other.success) { - return false; - } - if (!success) { - return error == other.error; - } - // compare names - if (names != other.names) { - return false; - } - // compare types - if (types != other.types) { - return false; - } - // now compare the actual values - // fetch chunks - unique_ptr lchunk, rchunk; - idx_t lindex = 0, rindex = 0; - while (true) { - if (!lchunk || lindex == lchunk->size()) { - lchunk = Fetch(); - lindex = 0; - } - if (!rchunk || rindex == rchunk->size()) { - rchunk = other.Fetch(); - rindex = 0; - } - if (!lchunk && !rchunk) { - return true; - } - if (!lchunk || !rchunk) { - return false; - } - if (lchunk->size() == 0 && rchunk->size() == 0) { - return true; - } - D_ASSERT(lchunk->ColumnCount() == rchunk->ColumnCount()); - for (; lindex < lchunk->size() && rindex < rchunk->size(); lindex++, rindex++) { - for (idx_t col = 0; col < rchunk->ColumnCount(); col++) { - auto lvalue = lchunk->GetValue(col, lindex); - auto rvalue = rchunk->GetValue(col, rindex); - if (lvalue.IsNull() && rvalue.IsNull()) { - continue; - } - if (lvalue.IsNull() != rvalue.IsNull()) { - return false; - } - if (lvalue != rvalue) { - return false; - } - } - } - } -} // LCOV_EXCL_STOP - -void QueryResult::Print() { - Printer::Print(ToString()); -} - -string QueryResult::HeaderToString() { - string result; - for (auto &name : names) { - result += name + "\t"; - } - result += "\n"; - for (auto &type : types) { - result += type.ToString() + "\t"; - } - result += "\n"; - return result; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -AggregateRelation::AggregateRelation(shared_ptr child_p, - vector> parsed_expressions) - : Relation(child_p->context, RelationType::AGGREGATE_RELATION), expressions(std::move(parsed_expressions)), - child(std::move(child_p)) { - // bind the 
expressions - context.GetContext()->TryBindRelation(*this, this->columns); -} - -AggregateRelation::AggregateRelation(shared_ptr child_p, - vector> parsed_expressions, GroupByNode groups_p) - : Relation(child_p->context, RelationType::AGGREGATE_RELATION), expressions(std::move(parsed_expressions)), - groups(std::move(groups_p)), child(std::move(child_p)) { - // bind the expressions - context.GetContext()->TryBindRelation(*this, this->columns); -} - -AggregateRelation::AggregateRelation(shared_ptr child_p, - vector> parsed_expressions, - vector> groups_p) - : Relation(child_p->context, RelationType::AGGREGATE_RELATION), expressions(std::move(parsed_expressions)), - child(std::move(child_p)) { - if (!groups_p.empty()) { - // explicit groups provided: use standard handling - GroupingSet grouping_set; - for (idx_t i = 0; i < groups_p.size(); i++) { - groups.group_expressions.push_back(std::move(groups_p[i])); - grouping_set.insert(i); - } - groups.grouping_sets.push_back(std::move(grouping_set)); - } - // bind the expressions - context.GetContext()->TryBindRelation(*this, this->columns); -} - -unique_ptr AggregateRelation::GetQueryNode() { - auto child_ptr = child.get(); - while (child_ptr->InheritsColumnBindings()) { - child_ptr = child_ptr->ChildRelation(); - } - unique_ptr result; - if (child_ptr->type == RelationType::JOIN_RELATION) { - // child node is a join: push projection into the child query node - result = child->GetQueryNode(); - } else { - // child node is not a join: create a new select node and push the child as a table reference - auto select = make_uniq(); - select->from_table = child->GetTableRef(); - result = std::move(select); - } - D_ASSERT(result->type == QueryNodeType::SELECT_NODE); - auto &select_node = result->Cast(); - if (!groups.group_expressions.empty()) { - select_node.aggregate_handling = AggregateHandling::STANDARD_HANDLING; - select_node.groups = groups.Copy(); - } else { - // no groups provided: automatically figure out groups (if any) 
- select_node.aggregate_handling = AggregateHandling::FORCE_AGGREGATES; - } - select_node.select_list.clear(); - for (auto &expr : expressions) { - select_node.select_list.push_back(expr->Copy()); - } - return result; -} - -string AggregateRelation::GetAlias() { - return child->GetAlias(); -} - -const vector &AggregateRelation::Columns() { - return columns; -} - -string AggregateRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Aggregate ["; - for (idx_t i = 0; i < expressions.size(); i++) { - if (i != 0) { - str += ", "; - } - str += expressions[i]->ToString(); - } - str += "]\n"; - return str + child->ToString(depth + 1); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -CreateTableRelation::CreateTableRelation(shared_ptr child_p, string schema_name, string table_name) - : Relation(child_p->context, RelationType::CREATE_TABLE_RELATION), child(std::move(child_p)), - schema_name(std::move(schema_name)), table_name(std::move(table_name)) { - context.GetContext()->TryBindRelation(*this, this->columns); -} - -BoundStatement CreateTableRelation::Bind(Binder &binder) { - auto select = make_uniq(); - select->node = child->GetQueryNode(); - - CreateStatement stmt; - auto info = make_uniq(); - info->schema = schema_name; - info->table = table_name; - info->query = std::move(select); - info->on_conflict = OnCreateConflict::ERROR_ON_CONFLICT; - stmt.info = std::move(info); - return binder.Bind(stmt.Cast()); -} - -const vector &CreateTableRelation::Columns() { - return columns; -} - -string CreateTableRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Create Table\n"; - return str + child->ToString(depth + 1); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -CreateViewRelation::CreateViewRelation(shared_ptr child_p, string view_name_p, bool replace_p, - bool temporary_p) - : Relation(child_p->context, RelationType::CREATE_VIEW_RELATION), child(std::move(child_p)), - view_name(std::move(view_name_p)), 
replace(replace_p), temporary(temporary_p) { - context.GetContext()->TryBindRelation(*this, this->columns); -} - -CreateViewRelation::CreateViewRelation(shared_ptr child_p, string schema_name_p, string view_name_p, - bool replace_p, bool temporary_p) - : Relation(child_p->context, RelationType::CREATE_VIEW_RELATION), child(std::move(child_p)), - schema_name(std::move(schema_name_p)), view_name(std::move(view_name_p)), replace(replace_p), - temporary(temporary_p) { - context.GetContext()->TryBindRelation(*this, this->columns); -} - -BoundStatement CreateViewRelation::Bind(Binder &binder) { - auto select = make_uniq(); - select->node = child->GetQueryNode(); - - CreateStatement stmt; - auto info = make_uniq(); - info->query = std::move(select); - info->view_name = view_name; - info->temporary = temporary; - info->schema = schema_name; - info->on_conflict = replace ? OnCreateConflict::REPLACE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; - stmt.info = std::move(info); - return binder.Bind(stmt.Cast()); -} - -const vector &CreateViewRelation::Columns() { - return columns; -} - -string CreateViewRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Create View\n"; - return str + child->ToString(depth + 1); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -CrossProductRelation::CrossProductRelation(shared_ptr left_p, shared_ptr right_p, - JoinRefType ref_type) - : Relation(left_p->context, RelationType::CROSS_PRODUCT_RELATION), left(std::move(left_p)), - right(std::move(right_p)), ref_type(ref_type) { - if (left->context.GetContext() != right->context.GetContext()) { - throw Exception("Cannot combine LEFT and RIGHT relations of different connections!"); - } - context.GetContext()->TryBindRelation(*this, this->columns); -} - -unique_ptr CrossProductRelation::GetQueryNode() { - auto result = make_uniq(); - result->select_list.push_back(make_uniq()); - result->from_table = GetTableRef(); - return std::move(result); -} - -unique_ptr 
CrossProductRelation::GetTableRef() { - auto cross_product_ref = make_uniq(ref_type); - cross_product_ref->left = left->GetTableRef(); - cross_product_ref->right = right->GetTableRef(); - return std::move(cross_product_ref); -} - -const vector &CrossProductRelation::Columns() { - return this->columns; -} - -string CrossProductRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth); - str = "Cross Product"; - return str + "\n" + left->ToString(depth + 1) + right->ToString(depth + 1); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -DeleteRelation::DeleteRelation(ClientContextWrapper &context, unique_ptr condition_p, - string schema_name_p, string table_name_p) - : Relation(context, RelationType::DELETE_RELATION), condition(std::move(condition_p)), - schema_name(std::move(schema_name_p)), table_name(std::move(table_name_p)) { - context.GetContext()->TryBindRelation(*this, this->columns); -} - -BoundStatement DeleteRelation::Bind(Binder &binder) { - auto basetable = make_uniq(); - basetable->schema_name = schema_name; - basetable->table_name = table_name; - - DeleteStatement stmt; - stmt.condition = condition ? 
condition->Copy() : nullptr; - stmt.table = std::move(basetable); - return binder.Bind(stmt.Cast()); -} - -const vector &DeleteRelation::Columns() { - return columns; -} - -string DeleteRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "DELETE FROM " + table_name; - if (condition) { - str += " WHERE " + condition->ToString(); - } - return str; -} - -} // namespace duckdb - - - - -namespace duckdb { - -DistinctRelation::DistinctRelation(shared_ptr child_p) - : Relation(child_p->context, RelationType::DISTINCT_RELATION), child(std::move(child_p)) { - D_ASSERT(child.get() != this); - vector dummy_columns; - context.GetContext()->TryBindRelation(*this, dummy_columns); -} - -unique_ptr DistinctRelation::GetQueryNode() { - auto child_node = child->GetQueryNode(); - child_node->AddDistinct(); - return child_node; -} - -string DistinctRelation::GetAlias() { - return child->GetAlias(); -} - -const vector &DistinctRelation::Columns() { - return child->Columns(); -} - -string DistinctRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Distinct\n"; - return str + child->ToString(depth + 1); - ; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -ExplainRelation::ExplainRelation(shared_ptr child_p, ExplainType type) - : Relation(child_p->context, RelationType::EXPLAIN_RELATION), child(std::move(child_p)), type(type) { - context.GetContext()->TryBindRelation(*this, this->columns); -} - -BoundStatement ExplainRelation::Bind(Binder &binder) { - auto select = make_uniq(); - select->node = child->GetQueryNode(); - ExplainStatement explain(std::move(select), type); - return binder.Bind(explain.Cast()); -} - -const vector &ExplainRelation::Columns() { - return columns; -} - -string ExplainRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Explain\n"; - return str + child->ToString(depth + 1); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -FilterRelation::FilterRelation(shared_ptr 
child_p, unique_ptr condition_p) - : Relation(child_p->context, RelationType::FILTER_RELATION), condition(std::move(condition_p)), - child(std::move(child_p)) { - D_ASSERT(child.get() != this); - vector dummy_columns; - context.GetContext()->TryBindRelation(*this, dummy_columns); -} - -unique_ptr FilterRelation::GetQueryNode() { - auto child_ptr = child.get(); - while (child_ptr->InheritsColumnBindings()) { - child_ptr = child_ptr->ChildRelation(); - } - if (child_ptr->type == RelationType::JOIN_RELATION) { - // child node is a join: push filter into WHERE clause of select node - auto child_node = child->GetQueryNode(); - D_ASSERT(child_node->type == QueryNodeType::SELECT_NODE); - auto &select_node = child_node->Cast(); - if (!select_node.where_clause) { - select_node.where_clause = condition->Copy(); - } else { - select_node.where_clause = make_uniq( - ExpressionType::CONJUNCTION_AND, std::move(select_node.where_clause), condition->Copy()); - } - return child_node; - } else { - auto result = make_uniq(); - result->select_list.push_back(make_uniq()); - result->from_table = child->GetTableRef(); - result->where_clause = condition->Copy(); - return std::move(result); - } -} - -string FilterRelation::GetAlias() { - return child->GetAlias(); -} - -const vector &FilterRelation::Columns() { - return child->Columns(); -} - -string FilterRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Filter [" + condition->ToString() + "]\n"; - return str + child->ToString(depth + 1); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -InsertRelation::InsertRelation(shared_ptr child_p, string schema_name, string table_name) - : Relation(child_p->context, RelationType::INSERT_RELATION), child(std::move(child_p)), - schema_name(std::move(schema_name)), table_name(std::move(table_name)) { - context.GetContext()->TryBindRelation(*this, this->columns); -} - -BoundStatement InsertRelation::Bind(Binder &binder) { - InsertStatement stmt; - auto select = 
make_uniq(); - select->node = child->GetQueryNode(); - - stmt.schema = schema_name; - stmt.table = table_name; - stmt.select_statement = std::move(select); - return binder.Bind(stmt.Cast()); -} - -const vector &InsertRelation::Columns() { - return columns; -} - -string InsertRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Insert\n"; - return str + child->ToString(depth + 1); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -JoinRelation::JoinRelation(shared_ptr left_p, shared_ptr right_p, - unique_ptr condition_p, JoinType type, JoinRefType join_ref_type) - : Relation(left_p->context, RelationType::JOIN_RELATION), left(std::move(left_p)), right(std::move(right_p)), - condition(std::move(condition_p)), join_type(type), join_ref_type(join_ref_type) { - if (left->context.GetContext() != right->context.GetContext()) { - throw Exception("Cannot combine LEFT and RIGHT relations of different connections!"); - } - context.GetContext()->TryBindRelation(*this, this->columns); -} - -JoinRelation::JoinRelation(shared_ptr left_p, shared_ptr right_p, vector using_columns_p, - JoinType type, JoinRefType join_ref_type) - : Relation(left_p->context, RelationType::JOIN_RELATION), left(std::move(left_p)), right(std::move(right_p)), - using_columns(std::move(using_columns_p)), join_type(type), join_ref_type(join_ref_type) { - if (left->context.GetContext() != right->context.GetContext()) { - throw Exception("Cannot combine LEFT and RIGHT relations of different connections!"); - } - context.GetContext()->TryBindRelation(*this, this->columns); -} - -unique_ptr JoinRelation::GetQueryNode() { - auto result = make_uniq(); - result->select_list.push_back(make_uniq()); - result->from_table = GetTableRef(); - return std::move(result); -} - -unique_ptr JoinRelation::GetTableRef() { - auto join_ref = make_uniq(join_ref_type); - join_ref->left = left->GetTableRef(); - join_ref->right = right->GetTableRef(); - if (condition) { - join_ref->condition = 
condition->Copy(); - } - join_ref->using_columns = using_columns; - join_ref->type = join_type; - return std::move(join_ref); -} - -const vector &JoinRelation::Columns() { - return this->columns; -} - -string JoinRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth); - str += "Join " + EnumUtil::ToString(join_ref_type) + " " + EnumUtil::ToString(join_type); - if (condition) { - str += " " + condition->GetName(); - } - - return str + "\n" + left->ToString(depth + 1) + "\n" + right->ToString(depth + 1); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -LimitRelation::LimitRelation(shared_ptr child_p, int64_t limit, int64_t offset) - : Relation(child_p->context, RelationType::PROJECTION_RELATION), limit(limit), offset(offset), - child(std::move(child_p)) { - D_ASSERT(child.get() != this); -} - -unique_ptr LimitRelation::GetQueryNode() { - auto child_node = child->GetQueryNode(); - auto limit_node = make_uniq(); - if (limit >= 0) { - limit_node->limit = make_uniq(Value::BIGINT(limit)); - } - if (offset > 0) { - limit_node->offset = make_uniq(Value::BIGINT(offset)); - } - - child_node->modifiers.push_back(std::move(limit_node)); - return child_node; -} - -string LimitRelation::GetAlias() { - return child->GetAlias(); -} - -const vector &LimitRelation::Columns() { - return child->Columns(); -} - -string LimitRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Limit " + to_string(limit); - if (offset > 0) { - str += " Offset " + to_string(offset); - } - str += "\n"; - return str + child->ToString(depth + 1); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -OrderRelation::OrderRelation(shared_ptr child_p, vector orders) - : Relation(child_p->context, RelationType::ORDER_RELATION), orders(std::move(orders)), child(std::move(child_p)) { - D_ASSERT(child.get() != this); - // bind the expressions - context.GetContext()->TryBindRelation(*this, this->columns); -} - -unique_ptr OrderRelation::GetQueryNode() { - 
auto select = make_uniq(); - select->from_table = child->GetTableRef(); - select->select_list.push_back(make_uniq()); - auto order_node = make_uniq(); - for (idx_t i = 0; i < orders.size(); i++) { - order_node->orders.emplace_back(orders[i].type, orders[i].null_order, orders[i].expression->Copy()); - } - select->modifiers.push_back(std::move(order_node)); - return std::move(select); -} - -string OrderRelation::GetAlias() { - return child->GetAlias(); -} - -const vector &OrderRelation::Columns() { - return columns; -} - -string OrderRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Order ["; - for (idx_t i = 0; i < orders.size(); i++) { - if (i != 0) { - str += ", "; - } - str += orders[i].expression->ToString() + (orders[i].type == OrderType::ASCENDING ? " ASC" : " DESC"); - } - str += "]\n"; - return str + child->ToString(depth + 1); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -ProjectionRelation::ProjectionRelation(shared_ptr child_p, - vector> parsed_expressions, vector aliases) - : Relation(child_p->context, RelationType::PROJECTION_RELATION), expressions(std::move(parsed_expressions)), - child(std::move(child_p)) { - if (!aliases.empty()) { - if (aliases.size() != expressions.size()) { - throw ParserException("Aliases list length must match expression list length!"); - } - for (idx_t i = 0; i < aliases.size(); i++) { - expressions[i]->alias = aliases[i]; - } - } - // bind the expressions - context.GetContext()->TryBindRelation(*this, this->columns); -} - -unique_ptr ProjectionRelation::GetQueryNode() { - auto child_ptr = child.get(); - while (child_ptr->InheritsColumnBindings()) { - child_ptr = child_ptr->ChildRelation(); - } - unique_ptr result; - if (child_ptr->type == RelationType::JOIN_RELATION) { - // child node is a join: push projection into the child query node - result = child->GetQueryNode(); - } else { - // child node is not a join: create a new select node and push the child as a table reference - auto 
select = make_uniq(); - select->from_table = child->GetTableRef(); - result = std::move(select); - } - D_ASSERT(result->type == QueryNodeType::SELECT_NODE); - auto &select_node = result->Cast(); - select_node.aggregate_handling = AggregateHandling::NO_AGGREGATES_ALLOWED; - select_node.select_list.clear(); - for (auto &expr : expressions) { - select_node.select_list.push_back(expr->Copy()); - } - return result; -} - -string ProjectionRelation::GetAlias() { - return child->GetAlias(); -} - -const vector &ProjectionRelation::Columns() { - return columns; -} - -string ProjectionRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Projection ["; - for (idx_t i = 0; i < expressions.size(); i++) { - if (i != 0) { - str += ", "; - } - str += expressions[i]->ToString() + " as " + expressions[i]->alias; - } - str += "]\n"; - return str + child->ToString(depth + 1); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -QueryRelation::QueryRelation(const std::shared_ptr &context, unique_ptr select_stmt_p, - string alias_p) - : Relation(context, RelationType::QUERY_RELATION), select_stmt(std::move(select_stmt_p)), - alias(std::move(alias_p)) { - context->TryBindRelation(*this, this->columns); -} - -QueryRelation::~QueryRelation() { -} - -unique_ptr QueryRelation::ParseStatement(ClientContext &context, const string &query, - const string &error) { - Parser parser(context.GetParserOptions()); - parser.ParseQuery(query); - if (parser.statements.size() != 1) { - throw ParserException(error); - } - if (parser.statements[0]->type != StatementType::SELECT_STATEMENT) { - throw ParserException(error); - } - return unique_ptr_cast(std::move(parser.statements[0])); -} - -unique_ptr QueryRelation::GetSelectStatement() { - return unique_ptr_cast(select_stmt->Copy()); -} - -unique_ptr QueryRelation::GetQueryNode() { - auto select = GetSelectStatement(); - return std::move(select->node); -} - -unique_ptr QueryRelation::GetTableRef() { - auto subquery_ref = 
make_uniq(GetSelectStatement(), GetAlias()); - return std::move(subquery_ref); -} - -string QueryRelation::GetAlias() { - return alias; -} - -const vector &QueryRelation::Columns() { - return columns; -} - -string QueryRelation::ToString(idx_t depth) { - return RenderWhitespace(depth) + "Subquery"; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - -namespace duckdb { - -ReadCSVRelation::ReadCSVRelation(const shared_ptr &context, const string &csv_file, - vector columns_p, string alias_p) - : TableFunctionRelation(context, "read_csv", {Value(csv_file)}, nullptr, false), alias(std::move(alias_p)), - auto_detect(false) { - - if (alias.empty()) { - alias = StringUtil::Split(csv_file, ".")[0]; - } - - columns = std::move(columns_p); - - child_list_t column_names; - for (idx_t i = 0; i < columns.size(); i++) { - column_names.push_back(make_pair(columns[i].Name(), Value(columns[i].Type().ToString()))); - } - - AddNamedParameter("columns", Value::STRUCT(std::move(column_names))); -} - -ReadCSVRelation::ReadCSVRelation(const std::shared_ptr &context, const string &csv_file, - named_parameter_map_t &&options, string alias_p) - : TableFunctionRelation(context, "read_csv_auto", {Value(csv_file)}, nullptr, false), alias(std::move(alias_p)), - auto_detect(true) { - - if (alias.empty()) { - alias = StringUtil::Split(csv_file, ".")[0]; - } - - auto files = MultiFileReader::GetFileList(*context, csv_file, "CSV"); - D_ASSERT(!files.empty()); - - auto &file_name = files[0]; - options["auto_detect"] = Value::BOOLEAN(true); - CSVReaderOptions csv_options; - csv_options.file_path = file_name; - vector empty; - - vector unused_types; - vector unused_names; - csv_options.FromNamedParameters(options, *context, unused_types, unused_names); - // Run the auto-detect, populating the options with the detected settings - - auto bm_file_handle = BaseCSVReader::OpenCSV(*context, csv_options); - auto buffer_manager = make_shared(*context, std::move(bm_file_handle), csv_options); - 
CSVStateMachineCache state_machine_cache; - CSVSniffer sniffer(csv_options, buffer_manager, state_machine_cache); - auto sniffer_result = sniffer.SniffCSV(); - auto &types = sniffer_result.return_types; - auto &names = sniffer_result.names; - for (idx_t i = 0; i < types.size(); i++) { - columns.emplace_back(names[i], types[i]); - } - - //! Capture the options potentially set/altered by the auto detection phase - csv_options.ToNamedParameters(options); - - // No need to auto-detect again - options["auto_detect"] = Value::BOOLEAN(false); - SetNamedParameters(std::move(options)); -} - -string ReadCSVRelation::GetAlias() { - return alias; -} - -} // namespace duckdb - - -namespace duckdb { - -ReadJSONRelation::ReadJSONRelation(const shared_ptr &context, string json_file_p, - named_parameter_map_t options, bool auto_detect, string alias_p) - : TableFunctionRelation(context, auto_detect ? "read_json_auto" : "read_json", {Value(json_file_p)}, - std::move(options)), - json_file(std::move(json_file_p)), alias(std::move(alias_p)) { - - if (alias.empty()) { - alias = StringUtil::Split(json_file, ".")[0]; - } -} - -string ReadJSONRelation::GetAlias() { - return alias; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -SetOpRelation::SetOpRelation(shared_ptr left_p, shared_ptr right_p, SetOperationType setop_type_p) - : Relation(left_p->context, RelationType::SET_OPERATION_RELATION), left(std::move(left_p)), - right(std::move(right_p)), setop_type(setop_type_p) { - if (left->context.GetContext() != right->context.GetContext()) { - throw Exception("Cannot combine LEFT and RIGHT relations of different connections!"); - } - context.GetContext()->TryBindRelation(*this, this->columns); -} - -unique_ptr SetOpRelation::GetQueryNode() { - auto result = make_uniq(); - if (setop_type == SetOperationType::EXCEPT || setop_type == SetOperationType::INTERSECT) { - result->modifiers.push_back(make_uniq()); - } - result->left = left->GetQueryNode(); - result->right = 
right->GetQueryNode(); - result->setop_type = setop_type; - return std::move(result); -} - -string SetOpRelation::GetAlias() { - return left->GetAlias(); -} - -const vector &SetOpRelation::Columns() { - return this->columns; -} - -string SetOpRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth); - switch (setop_type) { - case SetOperationType::UNION: - str += "Union"; - break; - case SetOperationType::EXCEPT: - str += "Except"; - break; - case SetOperationType::INTERSECT: - str += "Intersect"; - break; - default: - throw InternalException("Unknown setop type"); - } - return str + "\n" + left->ToString(depth + 1) + right->ToString(depth + 1); -} - -} // namespace duckdb - - - - -namespace duckdb { - -SubqueryRelation::SubqueryRelation(shared_ptr child_p, string alias_p) - : Relation(child_p->context, RelationType::SUBQUERY_RELATION), child(std::move(child_p)), - alias(std::move(alias_p)) { - D_ASSERT(child.get() != this); - vector dummy_columns; - context.GetContext()->TryBindRelation(*this, dummy_columns); -} - -unique_ptr SubqueryRelation::GetQueryNode() { - return child->GetQueryNode(); -} - -string SubqueryRelation::GetAlias() { - return alias; -} - -const vector &SubqueryRelation::Columns() { - return child->Columns(); -} - -string SubqueryRelation::ToString(idx_t depth) { - return child->ToString(depth); -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -void TableFunctionRelation::AddNamedParameter(const string &name, Value argument) { - named_parameters[name] = std::move(argument); -} - -void TableFunctionRelation::SetNamedParameters(named_parameter_map_t &&options) { - D_ASSERT(named_parameters.empty()); - named_parameters = std::move(options); -} - -TableFunctionRelation::TableFunctionRelation(const shared_ptr &context, string name_p, - vector parameters_p, named_parameter_map_t named_parameters, - shared_ptr input_relation_p, bool auto_init) - : Relation(context, RelationType::TABLE_FUNCTION_RELATION), 
name(std::move(name_p)), - parameters(std::move(parameters_p)), named_parameters(std::move(named_parameters)), - input_relation(std::move(input_relation_p)), auto_initialize(auto_init) { - InitializeColumns(); -} - -TableFunctionRelation::TableFunctionRelation(const shared_ptr &context, string name_p, - vector parameters_p, shared_ptr input_relation_p, - bool auto_init) - : Relation(context, RelationType::TABLE_FUNCTION_RELATION), name(std::move(name_p)), - parameters(std::move(parameters_p)), input_relation(std::move(input_relation_p)), auto_initialize(auto_init) { - InitializeColumns(); -} - -void TableFunctionRelation::InitializeColumns() { - if (!auto_initialize) { - return; - } - context.GetContext()->TryBindRelation(*this, this->columns); -} - -unique_ptr TableFunctionRelation::GetQueryNode() { - auto result = make_uniq(); - result->select_list.push_back(make_uniq()); - result->from_table = GetTableRef(); - return std::move(result); -} - -unique_ptr TableFunctionRelation::GetTableRef() { - vector> children; - if (input_relation) { // input relation becomes first parameter if present, always - auto subquery = make_uniq(); - subquery->subquery = make_uniq(); - subquery->subquery->node = input_relation->GetQueryNode(); - subquery->subquery_type = SubqueryType::SCALAR; - children.push_back(std::move(subquery)); - } - for (auto ¶meter : parameters) { - children.push_back(make_uniq(parameter)); - } - - for (auto ¶meter : named_parameters) { - // Hackity-hack some comparisons with column refs - // This is all but pretty, basically the named parameter is the column, the table is empty because that's what - // the function binder likes - auto column_ref = make_uniq(parameter.first); - auto constant_value = make_uniq(parameter.second); - auto comparison = make_uniq(ExpressionType::COMPARE_EQUAL, std::move(column_ref), - std::move(constant_value)); - children.push_back(std::move(comparison)); - } - - auto table_function = make_uniq(); - auto function = make_uniq(name, 
std::move(children)); - table_function->function = std::move(function); - return std::move(table_function); -} - -string TableFunctionRelation::GetAlias() { - return name; -} - -const vector &TableFunctionRelation::Columns() { - return columns; -} - -string TableFunctionRelation::ToString(idx_t depth) { - string function_call = name + "("; - for (idx_t i = 0; i < parameters.size(); i++) { - if (i > 0) { - function_call += ", "; - } - function_call += parameters[i].ToString(); - } - function_call += ")"; - return RenderWhitespace(depth) + function_call; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -TableRelation::TableRelation(const std::shared_ptr &context, unique_ptr description) - : Relation(context, RelationType::TABLE_RELATION), description(std::move(description)) { -} - -unique_ptr TableRelation::GetQueryNode() { - auto result = make_uniq(); - result->select_list.push_back(make_uniq()); - result->from_table = GetTableRef(); - return std::move(result); -} - -unique_ptr TableRelation::GetTableRef() { - auto table_ref = make_uniq(); - table_ref->schema_name = description->schema; - table_ref->table_name = description->table; - return std::move(table_ref); -} - -string TableRelation::GetAlias() { - return description->table; -} - -const vector &TableRelation::Columns() { - return description->columns; -} - -string TableRelation::ToString(idx_t depth) { - return RenderWhitespace(depth) + "Scan Table [" + description->table + "]"; -} - -static unique_ptr ParseCondition(ClientContext &context, const string &condition) { - if (!condition.empty()) { - auto expression_list = Parser::ParseExpressionList(condition, context.GetParserOptions()); - if (expression_list.size() != 1) { - throw ParserException("Expected a single expression as filter condition"); - } - return std::move(expression_list[0]); - } else { - return nullptr; - } -} - -void TableRelation::Update(const string &update_list, const string &condition) { - vector update_columns; - 
vector> expressions; - auto cond = ParseCondition(*context.GetContext(), condition); - Parser::ParseUpdateList(update_list, update_columns, expressions, context.GetContext()->GetParserOptions()); - auto update = make_shared(context, std::move(cond), description->schema, description->table, - std::move(update_columns), std::move(expressions)); - update->Execute(); -} - -void TableRelation::Delete(const string &condition) { - auto cond = ParseCondition(*context.GetContext(), condition); - auto del = make_shared(context, std::move(cond), description->schema, description->table); - del->Execute(); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -UpdateRelation::UpdateRelation(ClientContextWrapper &context, unique_ptr condition_p, - string schema_name_p, string table_name_p, vector update_columns_p, - vector> expressions_p) - : Relation(context, RelationType::UPDATE_RELATION), condition(std::move(condition_p)), - schema_name(std::move(schema_name_p)), table_name(std::move(table_name_p)), - update_columns(std::move(update_columns_p)), expressions(std::move(expressions_p)) { - D_ASSERT(update_columns.size() == expressions.size()); - context.GetContext()->TryBindRelation(*this, this->columns); -} - -BoundStatement UpdateRelation::Bind(Binder &binder) { - auto basetable = make_uniq(); - basetable->schema_name = schema_name; - basetable->table_name = table_name; - - UpdateStatement stmt; - stmt.set_info = make_uniq(); - - stmt.set_info->condition = condition ? 
condition->Copy() : nullptr; - stmt.table = std::move(basetable); - stmt.set_info->columns = update_columns; - for (auto &expr : expressions) { - stmt.set_info->expressions.push_back(expr->Copy()); - } - return binder.Bind(stmt.Cast()); -} - -const vector &UpdateRelation::Columns() { - return columns; -} - -string UpdateRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "UPDATE " + table_name + " SET\n"; - for (idx_t i = 0; i < expressions.size(); i++) { - str += update_columns[i] + " = " + expressions[i]->ToString() + "\n"; - } - if (condition) { - str += "WHERE " + condition->ToString() + "\n"; - } - return str; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -ValueRelation::ValueRelation(const std::shared_ptr &context, const vector> &values, - vector names_p, string alias_p) - : Relation(context, RelationType::VALUE_LIST_RELATION), names(std::move(names_p)), alias(std::move(alias_p)) { - // create constant expressions for the values - for (idx_t row_idx = 0; row_idx < values.size(); row_idx++) { - auto &list = values[row_idx]; - vector> expressions; - for (idx_t col_idx = 0; col_idx < list.size(); col_idx++) { - expressions.push_back(make_uniq(list[col_idx])); - } - this->expressions.push_back(std::move(expressions)); - } - context->TryBindRelation(*this, this->columns); -} - -ValueRelation::ValueRelation(const std::shared_ptr &context, const string &values_list, - vector names_p, string alias_p) - : Relation(context, RelationType::VALUE_LIST_RELATION), names(std::move(names_p)), alias(std::move(alias_p)) { - this->expressions = Parser::ParseValuesList(values_list, context->GetParserOptions()); - context->TryBindRelation(*this, this->columns); -} - -unique_ptr ValueRelation::GetQueryNode() { - auto result = make_uniq(); - result->select_list.push_back(make_uniq()); - result->from_table = GetTableRef(); - return std::move(result); -} - -unique_ptr ValueRelation::GetTableRef() { - auto table_ref = make_uniq(); - // set the 
expected types/names - if (columns.empty()) { - // no columns yet: only set up names - for (idx_t i = 0; i < names.size(); i++) { - table_ref->expected_names.push_back(names[i]); - } - } else { - for (idx_t i = 0; i < columns.size(); i++) { - table_ref->expected_names.push_back(columns[i].Name()); - table_ref->expected_types.push_back(columns[i].Type()); - D_ASSERT(names.size() == 0 || columns[i].Name() == names[i]); - } - } - // copy the expressions - for (auto &expr_list : expressions) { - vector> copied_list; - copied_list.reserve(expr_list.size()); - for (auto &expr : expr_list) { - copied_list.push_back(expr->Copy()); - } - table_ref->values.push_back(std::move(copied_list)); - } - table_ref->alias = GetAlias(); - return std::move(table_ref); -} - -string ValueRelation::GetAlias() { - return alias; -} - -const vector &ValueRelation::Columns() { - return columns; -} - -string ValueRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Values "; - for (idx_t row_idx = 0; row_idx < expressions.size(); row_idx++) { - auto &list = expressions[row_idx]; - str += row_idx > 0 ? ", (" : "("; - for (idx_t col_idx = 0; col_idx < list.size(); col_idx++) { - str += col_idx > 0 ? 
", " : ""; - str += list[col_idx]->ToString(); - } - str += ")"; - } - str += "\n"; - return str; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -ViewRelation::ViewRelation(const std::shared_ptr &context, string schema_name_p, string view_name_p) - : Relation(context, RelationType::VIEW_RELATION), schema_name(std::move(schema_name_p)), - view_name(std::move(view_name_p)) { - context->TryBindRelation(*this, this->columns); -} - -unique_ptr ViewRelation::GetQueryNode() { - auto result = make_uniq(); - result->select_list.push_back(make_uniq()); - result->from_table = GetTableRef(); - return std::move(result); -} - -unique_ptr ViewRelation::GetTableRef() { - auto table_ref = make_uniq(); - table_ref->schema_name = schema_name; - table_ref->table_name = view_name; - return std::move(table_ref); -} - -string ViewRelation::GetAlias() { - return view_name; -} - -const vector &ViewRelation::Columns() { - return columns; -} - -string ViewRelation::ToString(idx_t depth) { - return RenderWhitespace(depth) + "View [" + view_name + "]"; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -WriteCSVRelation::WriteCSVRelation(shared_ptr child_p, string csv_file_p, - case_insensitive_map_t> options_p) - : Relation(child_p->context, RelationType::WRITE_CSV_RELATION), child(std::move(child_p)), - csv_file(std::move(csv_file_p)), options(std::move(options_p)) { - context.GetContext()->TryBindRelation(*this, this->columns); -} - -BoundStatement WriteCSVRelation::Bind(Binder &binder) { - CopyStatement copy; - copy.select_statement = child->GetQueryNode(); - auto info = make_uniq(); - info->is_from = false; - info->file_path = csv_file; - info->format = "csv"; - info->options = options; - copy.info = std::move(info); - return binder.Bind(copy.Cast()); -} - -const vector &WriteCSVRelation::Columns() { - return columns; -} - -string WriteCSVRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Write To CSV [" + csv_file + "]\n"; - return str 
+ child->ToString(depth + 1); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -WriteParquetRelation::WriteParquetRelation(shared_ptr child_p, string parquet_file_p, - case_insensitive_map_t> options_p) - : Relation(child_p->context, RelationType::WRITE_PARQUET_RELATION), child(std::move(child_p)), - parquet_file(std::move(parquet_file_p)), options(std::move(options_p)) { - context.GetContext()->TryBindRelation(*this, this->columns); -} - -BoundStatement WriteParquetRelation::Bind(Binder &binder) { - CopyStatement copy; - copy.select_statement = child->GetQueryNode(); - auto info = make_uniq(); - info->is_from = false; - info->file_path = parquet_file; - info->format = "parquet"; - info->options = options; - copy.info = std::move(info); - return binder.Bind(copy.Cast()); -} - -const vector &WriteParquetRelation::Columns() { - return columns; -} - -string WriteParquetRelation::ToString(idx_t depth) { - string str = RenderWhitespace(depth) + "Write To Parquet [" + parquet_file + "]\n"; - return str + child->ToString(depth + 1); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -shared_ptr Relation::Project(const string &select_list) { - return Project(select_list, vector()); -} - -shared_ptr Relation::Project(const string &expression, const string &alias) { - return Project(expression, vector({alias})); -} - -shared_ptr Relation::Project(const string &select_list, const vector &aliases) { - auto expressions = Parser::ParseExpressionList(select_list, context.GetContext()->GetParserOptions()); - return make_shared(shared_from_this(), std::move(expressions), aliases); -} - -shared_ptr Relation::Project(const vector &expressions) { - vector aliases; - return Project(expressions, aliases); -} - -shared_ptr Relation::Project(vector> expressions, - const vector &aliases) { - return make_shared(shared_from_this(), std::move(expressions), aliases); -} - -static vector> 
StringListToExpressionList(ClientContext &context, - const vector &expressions) { - if (expressions.empty()) { - throw ParserException("Zero expressions provided"); - } - vector> result_list; - for (auto &expr : expressions) { - auto expression_list = Parser::ParseExpressionList(expr, context.GetParserOptions()); - if (expression_list.size() != 1) { - throw ParserException("Expected a single expression in the expression list"); - } - result_list.push_back(std::move(expression_list[0])); - } - return result_list; -} - -shared_ptr Relation::Project(const vector &expressions, const vector &aliases) { - auto result_list = StringListToExpressionList(*context.GetContext(), expressions); - return make_shared(shared_from_this(), std::move(result_list), aliases); -} - -shared_ptr Relation::Filter(const string &expression) { - auto expression_list = Parser::ParseExpressionList(expression, context.GetContext()->GetParserOptions()); - if (expression_list.size() != 1) { - throw ParserException("Expected a single expression as filter condition"); - } - return Filter(std::move(expression_list[0])); -} - -shared_ptr Relation::Filter(unique_ptr expression) { - return make_shared(shared_from_this(), std::move(expression)); -} - -shared_ptr Relation::Filter(const vector &expressions) { - // if there are multiple expressions, we AND them together - auto expression_list = StringListToExpressionList(*context.GetContext(), expressions); - D_ASSERT(!expression_list.empty()); - - auto expr = std::move(expression_list[0]); - for (idx_t i = 1; i < expression_list.size(); i++) { - expr = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(expr), - std::move(expression_list[i])); - } - return make_shared(shared_from_this(), std::move(expr)); -} - -shared_ptr Relation::Limit(int64_t limit, int64_t offset) { - return make_shared(shared_from_this(), limit, offset); -} - -shared_ptr Relation::Order(const string &expression) { - auto order_list = Parser::ParseOrderList(expression, 
context.GetContext()->GetParserOptions()); - return Order(std::move(order_list)); -} - -shared_ptr Relation::Order(vector order_list) { - return make_shared(shared_from_this(), std::move(order_list)); -} - -shared_ptr Relation::Order(const vector &expressions) { - if (expressions.empty()) { - throw ParserException("Zero ORDER BY expressions provided"); - } - vector order_list; - for (auto &expression : expressions) { - auto inner_list = Parser::ParseOrderList(expression, context.GetContext()->GetParserOptions()); - if (inner_list.size() != 1) { - throw ParserException("Expected a single ORDER BY expression in the expression list"); - } - order_list.push_back(std::move(inner_list[0])); - } - return Order(std::move(order_list)); -} - -shared_ptr Relation::Join(const shared_ptr &other, const string &condition, JoinType type, - JoinRefType ref_type) { - auto expression_list = Parser::ParseExpressionList(condition, context.GetContext()->GetParserOptions()); - D_ASSERT(!expression_list.empty()); - return Join(other, std::move(expression_list), type, ref_type); -} - -shared_ptr Relation::Join(const shared_ptr &other, - vector> expression_list, JoinType type, - JoinRefType ref_type) { - if (expression_list.size() > 1 || expression_list[0]->type == ExpressionType::COLUMN_REF) { - // multiple columns or single column ref: the condition is a USING list - vector using_columns; - for (auto &expr : expression_list) { - if (expr->type != ExpressionType::COLUMN_REF) { - throw ParserException("Expected a single expression as join condition"); - } - auto &colref = expr->Cast(); - if (colref.IsQualified()) { - throw ParserException("Expected unqualified column for column in USING clause"); - } - using_columns.push_back(colref.column_names[0]); - } - return make_shared(shared_from_this(), other, std::move(using_columns), type, ref_type); - } else { - // single expression that is not a column reference: use the expression as a join condition - return make_shared(shared_from_this(), 
other, std::move(expression_list[0]), type, ref_type); - } -} - -shared_ptr Relation::CrossProduct(const shared_ptr &other, JoinRefType join_ref_type) { - return make_shared(shared_from_this(), other, join_ref_type); -} - -shared_ptr Relation::Union(const shared_ptr &other) { - return make_shared(shared_from_this(), other, SetOperationType::UNION); -} - -shared_ptr Relation::Except(const shared_ptr &other) { - return make_shared(shared_from_this(), other, SetOperationType::EXCEPT); -} - -shared_ptr Relation::Intersect(const shared_ptr &other) { - return make_shared(shared_from_this(), other, SetOperationType::INTERSECT); -} - -shared_ptr Relation::Distinct() { - return make_shared(shared_from_this()); -} - -shared_ptr Relation::Alias(const string &alias) { - return make_shared(shared_from_this(), alias); -} - -shared_ptr Relation::Aggregate(const string &aggregate_list) { - auto expression_list = Parser::ParseExpressionList(aggregate_list, context.GetContext()->GetParserOptions()); - return make_shared(shared_from_this(), std::move(expression_list)); -} - -shared_ptr Relation::Aggregate(const string &aggregate_list, const string &group_list) { - auto expression_list = Parser::ParseExpressionList(aggregate_list, context.GetContext()->GetParserOptions()); - auto groups = Parser::ParseGroupByList(group_list, context.GetContext()->GetParserOptions()); - return make_shared(shared_from_this(), std::move(expression_list), std::move(groups)); -} - -shared_ptr Relation::Aggregate(const vector &aggregates) { - auto aggregate_list = StringListToExpressionList(*context.GetContext(), aggregates); - return make_shared(shared_from_this(), std::move(aggregate_list)); -} - -shared_ptr Relation::Aggregate(const vector &aggregates, const vector &groups) { - auto aggregate_list = StringUtil::Join(aggregates, ", "); - auto group_list = StringUtil::Join(groups, ", "); - return this->Aggregate(aggregate_list, group_list); -} - -shared_ptr Relation::Aggregate(vector> expressions, const 
string &group_list) { - auto groups = Parser::ParseGroupByList(group_list, context.GetContext()->GetParserOptions()); - return make_shared(shared_from_this(), std::move(expressions), std::move(groups)); -} - -string Relation::GetAlias() { - return "relation"; -} - -unique_ptr Relation::GetTableRef() { - auto select = make_uniq(); - select->node = GetQueryNode(); - return make_uniq(std::move(select), GetAlias()); -} - -unique_ptr Relation::Execute() { - return context.GetContext()->Execute(shared_from_this()); -} - -unique_ptr Relation::ExecuteOrThrow() { - auto res = Execute(); - D_ASSERT(res); - if (res->HasError()) { - res->ThrowError(); - } - return res; -} - -BoundStatement Relation::Bind(Binder &binder) { - SelectStatement stmt; - stmt.node = GetQueryNode(); - return binder.Bind(stmt.Cast()); -} - -shared_ptr Relation::InsertRel(const string &schema_name, const string &table_name) { - return make_shared(shared_from_this(), schema_name, table_name); -} - -void Relation::Insert(const string &table_name) { - Insert(INVALID_SCHEMA, table_name); -} - -void Relation::Insert(const string &schema_name, const string &table_name) { - auto insert = InsertRel(schema_name, table_name); - auto res = insert->Execute(); - if (res->HasError()) { - const string prepended_message = "Failed to insert into table '" + table_name + "': "; - res->ThrowError(prepended_message); - } -} - -void Relation::Insert(const vector> &values) { - vector column_names; - auto rel = make_shared(context.GetContext(), values, std::move(column_names), "values"); - rel->Insert(GetAlias()); -} - -shared_ptr Relation::CreateRel(const string &schema_name, const string &table_name) { - return make_shared(shared_from_this(), schema_name, table_name); -} - -void Relation::Create(const string &table_name) { - Create(INVALID_SCHEMA, table_name); -} - -void Relation::Create(const string &schema_name, const string &table_name) { - auto create = CreateRel(schema_name, table_name); - auto res = create->Execute(); 
- if (res->HasError()) { - const string prepended_message = "Failed to create table '" + table_name + "': "; - res->ThrowError(prepended_message); - } -} - -shared_ptr Relation::WriteCSVRel(const string &csv_file, case_insensitive_map_t> options) { - return std::make_shared(shared_from_this(), csv_file, std::move(options)); -} - -void Relation::WriteCSV(const string &csv_file, case_insensitive_map_t> options) { - auto write_csv = WriteCSVRel(csv_file, std::move(options)); - auto res = write_csv->Execute(); - if (res->HasError()) { - const string prepended_message = "Failed to write '" + csv_file + "': "; - res->ThrowError(prepended_message); - } -} - -shared_ptr Relation::WriteParquetRel(const string &parquet_file, - case_insensitive_map_t> options) { - auto write_parquet = - std::make_shared(shared_from_this(), parquet_file, std::move(options)); - return std::move(write_parquet); -} - -void Relation::WriteParquet(const string &parquet_file, case_insensitive_map_t> options) { - auto write_parquet = WriteParquetRel(parquet_file, std::move(options)); - auto res = write_parquet->Execute(); - if (res->HasError()) { - const string prepended_message = "Failed to write '" + parquet_file + "': "; - res->ThrowError(prepended_message); - } -} - -shared_ptr Relation::CreateView(const string &name, bool replace, bool temporary) { - return CreateView(INVALID_SCHEMA, name, replace, temporary); -} - -shared_ptr Relation::CreateView(const string &schema_name, const string &name, bool replace, bool temporary) { - auto view = make_shared(shared_from_this(), schema_name, name, replace, temporary); - auto res = view->Execute(); - if (res->HasError()) { - const string prepended_message = "Failed to create view '" + name + "': "; - res->ThrowError(prepended_message); - } - return shared_from_this(); -} - -unique_ptr Relation::Query(const string &sql) { - return context.GetContext()->Query(sql, false); -} - -unique_ptr Relation::Query(const string &name, const string &sql) { - 
CreateView(name); - return Query(sql); -} - -unique_ptr Relation::Explain(ExplainType type) { - auto explain = make_shared(shared_from_this(), type); - return explain->Execute(); -} - -void Relation::Update(const string &update, const string &condition) { - throw Exception("UPDATE can only be used on base tables!"); -} - -void Relation::Delete(const string &condition) { - throw Exception("DELETE can only be used on base tables!"); -} - -shared_ptr Relation::TableFunction(const std::string &fname, const vector &values, - const named_parameter_map_t &named_parameters) { - return make_shared(context.GetContext(), fname, values, named_parameters, - shared_from_this()); -} - -shared_ptr Relation::TableFunction(const std::string &fname, const vector &values) { - return make_shared(context.GetContext(), fname, values, shared_from_this()); -} - -string Relation::ToString() { - string str; - str += "---------------------\n"; - str += "--- Relation Tree ---\n"; - str += "---------------------\n"; - str += ToString(0); - str += "\n\n"; - str += "---------------------\n"; - str += "-- Result Columns --\n"; - str += "---------------------\n"; - auto &cols = Columns(); - for (idx_t i = 0; i < cols.size(); i++) { - str += "- " + cols[i].Name() + " (" + cols[i].Type().ToString() + ")\n"; - } - return str; -} - -// LCOV_EXCL_START -unique_ptr Relation::GetQueryNode() { - throw InternalException("Cannot create a query node from this node type"); -} - -void Relation::Head(idx_t limit) { - auto limit_node = Limit(limit); - limit_node->Execute()->Print(); -} -// LCOV_EXCL_STOP - -void Relation::Print() { - Printer::Print(ToString()); -} - -string Relation::RenderWhitespace(idx_t depth) { - return string(depth * 2, ' '); -} - -vector> Relation::GetAllDependencies() { - vector> all_dependencies; - Relation *cur = this; - while (cur) { - if (cur->extra_dependencies) { - all_dependencies.push_back(cur->extra_dependencies); - } - cur = cur->ChildRelation(); - } - return all_dependencies; -} 
- -} // namespace duckdb - - - - - - - - - - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Access Mode -//===--------------------------------------------------------------------===// -void AccessModeSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - if (db) { - throw InvalidInputException("Cannot change access_mode setting while database is running - it must be set when " - "opening or attaching the database"); - } - auto parameter = StringUtil::Lower(input.ToString()); - if (parameter == "automatic") { - config.options.access_mode = AccessMode::AUTOMATIC; - } else if (parameter == "read_only") { - config.options.access_mode = AccessMode::READ_ONLY; - } else if (parameter == "read_write") { - config.options.access_mode = AccessMode::READ_WRITE; - } else { - throw InvalidInputException( - "Unrecognized parameter for option ACCESS_MODE \"%s\". Expected READ_ONLY or READ_WRITE.", parameter); - } -} - -void AccessModeSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.access_mode = DBConfig().options.access_mode; -} - -Value AccessModeSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - switch (config.options.access_mode) { - case AccessMode::AUTOMATIC: - return "automatic"; - case AccessMode::READ_ONLY: - return "read_only"; - case AccessMode::READ_WRITE: - return "read_write"; - default: - throw InternalException("Unknown access mode setting"); - } -} - -//===--------------------------------------------------------------------===// -// Checkpoint Threshold -//===--------------------------------------------------------------------===// -void CheckpointThresholdSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - idx_t new_limit = DBConfig::ParseMemoryLimit(input.ToString()); - config.options.checkpoint_wal_size = new_limit; -} - -void 
CheckpointThresholdSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.checkpoint_wal_size = DBConfig().options.checkpoint_wal_size; -} - -Value CheckpointThresholdSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value(StringUtil::BytesToHumanReadableString(config.options.checkpoint_wal_size)); -} - -//===--------------------------------------------------------------------===// -// Debug Checkpoint Abort -//===--------------------------------------------------------------------===// -void DebugCheckpointAbort::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto checkpoint_abort = StringUtil::Lower(input.ToString()); - if (checkpoint_abort == "none") { - config.options.checkpoint_abort = CheckpointAbort::NO_ABORT; - } else if (checkpoint_abort == "before_truncate") { - config.options.checkpoint_abort = CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE; - } else if (checkpoint_abort == "before_header") { - config.options.checkpoint_abort = CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER; - } else if (checkpoint_abort == "after_free_list_write") { - config.options.checkpoint_abort = CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE; - } else { - throw ParserException( - "Unrecognized option for PRAGMA debug_checkpoint_abort, expected none, before_truncate or before_header"); - } -} - -void DebugCheckpointAbort::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.checkpoint_abort = DBConfig().options.checkpoint_abort; -} - -Value DebugCheckpointAbort::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(*context.db); - auto setting = config.options.checkpoint_abort; - switch (setting) { - case CheckpointAbort::NO_ABORT: - return "none"; - case CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE: - return "before_truncate"; - case CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER: - return "before_header"; - case 
CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE: - return "after_free_list_write"; - default: - throw InternalException("Type not implemented for CheckpointAbort"); - } -} - -//===--------------------------------------------------------------------===// -// Debug Force External -//===--------------------------------------------------------------------===// -void DebugForceExternal::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).force_external = ClientConfig().force_external; -} - -void DebugForceExternal::SetLocal(ClientContext &context, const Value &input) { - ClientConfig::GetConfig(context).force_external = input.GetValue(); -} - -Value DebugForceExternal::GetSetting(ClientContext &context) { - return Value::BOOLEAN(ClientConfig::GetConfig(context).force_external); -} - -//===--------------------------------------------------------------------===// -// Debug Force NoCrossProduct -//===--------------------------------------------------------------------===// -void DebugForceNoCrossProduct::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).force_no_cross_product = ClientConfig().force_no_cross_product; -} - -void DebugForceNoCrossProduct::SetLocal(ClientContext &context, const Value &input) { - ClientConfig::GetConfig(context).force_no_cross_product = input.GetValue(); -} - -Value DebugForceNoCrossProduct::GetSetting(ClientContext &context) { - return Value::BOOLEAN(ClientConfig::GetConfig(context).force_no_cross_product); -} - -//===--------------------------------------------------------------------===// -// Ordered Aggregate Threshold -//===--------------------------------------------------------------------===// -void OrderedAggregateThreshold::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).ordered_aggregate_threshold = ClientConfig().ordered_aggregate_threshold; -} - -void OrderedAggregateThreshold::SetLocal(ClientContext &context, const Value &input) { - const auto param = 
input.GetValue(); - if (param <= 0) { - throw ParserException("Invalid option for PRAGMA ordered_aggregate_threshold, value must be positive"); - } - ClientConfig::GetConfig(context).ordered_aggregate_threshold = param; -} - -Value OrderedAggregateThreshold::GetSetting(ClientContext &context) { - return Value::UBIGINT(ClientConfig::GetConfig(context).ordered_aggregate_threshold); -} - -//===--------------------------------------------------------------------===// -// Debug Window Mode -//===--------------------------------------------------------------------===// -void DebugWindowMode::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto param = StringUtil::Lower(input.ToString()); - if (param == "window") { - config.options.window_mode = WindowAggregationMode::WINDOW; - } else if (param == "combine") { - config.options.window_mode = WindowAggregationMode::COMBINE; - } else if (param == "separate") { - config.options.window_mode = WindowAggregationMode::SEPARATE; - } else { - throw ParserException("Unrecognized option for PRAGMA debug_window_mode, expected window, combine or separate"); - } -} - -void DebugWindowMode::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.window_mode = DBConfig().options.window_mode; -} - -Value DebugWindowMode::GetSetting(ClientContext &context) { - return Value(); -} - -//===--------------------------------------------------------------------===// -// Debug AsOf Join -//===--------------------------------------------------------------------===// -void DebugAsOfIEJoin::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).force_asof_iejoin = ClientConfig().force_asof_iejoin; -} - -void DebugAsOfIEJoin::SetLocal(ClientContext &context, const Value &input) { - ClientConfig::GetConfig(context).force_asof_iejoin = input.GetValue(); -} - -Value DebugAsOfIEJoin::GetSetting(ClientContext &context) { - return Value::BOOLEAN(ClientConfig::GetConfig(context).force_asof_iejoin); -} 
- -//===--------------------------------------------------------------------===// -// Prefer Range Joins -//===--------------------------------------------------------------------===// -void PreferRangeJoins::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).prefer_range_joins = ClientConfig().prefer_range_joins; -} - -void PreferRangeJoins::SetLocal(ClientContext &context, const Value &input) { - ClientConfig::GetConfig(context).prefer_range_joins = input.GetValue(); -} - -Value PreferRangeJoins::GetSetting(ClientContext &context) { - return Value::BOOLEAN(ClientConfig::GetConfig(context).prefer_range_joins); -} - -//===--------------------------------------------------------------------===// -// Default Collation -//===--------------------------------------------------------------------===// -void DefaultCollationSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto parameter = StringUtil::Lower(input.ToString()); - config.options.collation = parameter; -} - -void DefaultCollationSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.collation = DBConfig().options.collation; -} - -void DefaultCollationSetting::ResetLocal(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - config.options.collation = DBConfig().options.collation; -} - -void DefaultCollationSetting::SetLocal(ClientContext &context, const Value &input) { - auto parameter = input.ToString(); - // bind the collation to verify that it exists - ExpressionBinder::TestCollation(context, parameter); - auto &config = DBConfig::GetConfig(context); - config.options.collation = parameter; -} - -Value DefaultCollationSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value(config.options.collation); -} - -//===--------------------------------------------------------------------===// -// Default Order 
-//===--------------------------------------------------------------------===// -void DefaultOrderSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto parameter = StringUtil::Lower(input.ToString()); - if (parameter == "ascending" || parameter == "asc") { - config.options.default_order_type = OrderType::ASCENDING; - } else if (parameter == "descending" || parameter == "desc") { - config.options.default_order_type = OrderType::DESCENDING; - } else { - throw InvalidInputException("Unrecognized parameter for option DEFAULT_ORDER \"%s\". Expected ASC or DESC.", - parameter); - } -} - -void DefaultOrderSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.default_order_type = DBConfig().options.default_order_type; -} - -Value DefaultOrderSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - switch (config.options.default_order_type) { - case OrderType::ASCENDING: - return "asc"; - case OrderType::DESCENDING: - return "desc"; - default: - throw InternalException("Unknown order type setting"); - } -} - -//===--------------------------------------------------------------------===// -// Default Null Order -//===--------------------------------------------------------------------===// -void DefaultNullOrderSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto parameter = StringUtil::Lower(input.ToString()); - - if (parameter == "nulls_first" || parameter == "nulls first" || parameter == "null first" || parameter == "first") { - config.options.default_null_order = DefaultOrderByNullType::NULLS_FIRST; - } else if (parameter == "nulls_last" || parameter == "nulls last" || parameter == "null last" || - parameter == "last") { - config.options.default_null_order = DefaultOrderByNullType::NULLS_LAST; - } else if (parameter == "nulls_first_on_asc_last_on_desc" || parameter == "sqlite" || parameter == "mysql") { - config.options.default_null_order = 
DefaultOrderByNullType::NULLS_FIRST_ON_ASC_LAST_ON_DESC; - } else if (parameter == "nulls_last_on_asc_first_on_desc" || parameter == "postgres") { - config.options.default_null_order = DefaultOrderByNullType::NULLS_LAST_ON_ASC_FIRST_ON_DESC; - } else { - throw ParserException("Unrecognized parameter for option NULL_ORDER \"%s\", expected either NULLS FIRST, NULLS " - "LAST, SQLite, MySQL or Postgres", - parameter); - } -} - -void DefaultNullOrderSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.default_null_order = DBConfig().options.default_null_order; -} - -Value DefaultNullOrderSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - switch (config.options.default_null_order) { - case DefaultOrderByNullType::NULLS_FIRST: - return "nulls_first"; - case DefaultOrderByNullType::NULLS_LAST: - return "nulls_last"; - case DefaultOrderByNullType::NULLS_FIRST_ON_ASC_LAST_ON_DESC: - return "nulls_first_on_asc_last_on_desc"; - case DefaultOrderByNullType::NULLS_LAST_ON_ASC_FIRST_ON_DESC: - return "nulls_last_on_asc_first_on_desc"; - default: - throw InternalException("Unknown null order setting"); - } -} - -//===--------------------------------------------------------------------===// -// Disabled File Systems -//===--------------------------------------------------------------------===// -void DisabledFileSystemsSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - if (!db) { - throw InternalException("disabled_filesystems can only be set in an active database"); - } - auto &fs = FileSystem::GetFileSystem(*db); - auto list = StringUtil::Split(input.ToString(), ","); - fs.SetDisabledFileSystems(list); -} - -void DisabledFileSystemsSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - if (!db) { - throw InternalException("disabled_filesystems can only be set in an active database"); - } - auto &fs = FileSystem::GetFileSystem(*db); - 
fs.SetDisabledFileSystems(vector()); -} - -Value DisabledFileSystemsSetting::GetSetting(ClientContext &context) { - return Value(""); -} - -//===--------------------------------------------------------------------===// -// Disabled Optimizer -//===--------------------------------------------------------------------===// -void DisabledOptimizersSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto list = StringUtil::Split(input.ToString(), ","); - set disabled_optimizers; - for (auto &entry : list) { - auto param = StringUtil::Lower(entry); - StringUtil::Trim(param); - if (param.empty()) { - continue; - } - disabled_optimizers.insert(OptimizerTypeFromString(param)); - } - config.options.disabled_optimizers = std::move(disabled_optimizers); -} - -void DisabledOptimizersSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.disabled_optimizers = DBConfig().options.disabled_optimizers; -} - -Value DisabledOptimizersSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - string result; - for (auto &optimizer : config.options.disabled_optimizers) { - if (!result.empty()) { - result += ","; - } - result += OptimizerTypeToString(optimizer); - } - return Value(result); -} - -//===--------------------------------------------------------------------===// -// Enable External Access -//===--------------------------------------------------------------------===// -void EnableExternalAccessSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto new_value = input.GetValue(); - if (db && new_value) { - throw InvalidInputException("Cannot change enable_external_access setting while database is running"); - } - config.options.enable_external_access = new_value; -} - -void EnableExternalAccessSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - if (db) { - throw InvalidInputException("Cannot change enable_external_access setting while database is 
running"); - } - config.options.enable_external_access = DBConfig().options.enable_external_access; -} - -Value EnableExternalAccessSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.enable_external_access); -} - -//===--------------------------------------------------------------------===// -// Enable FSST Vectors -//===--------------------------------------------------------------------===// -void EnableFSSTVectors::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.enable_fsst_vectors = input.GetValue(); -} - -void EnableFSSTVectors::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.enable_fsst_vectors = DBConfig().options.enable_fsst_vectors; -} - -Value EnableFSSTVectors::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.enable_fsst_vectors); -} - -//===--------------------------------------------------------------------===// -// Allow Unsigned Extensions -//===--------------------------------------------------------------------===// -void AllowUnsignedExtensionsSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto new_value = input.GetValue(); - if (db && new_value) { - throw InvalidInputException("Cannot change allow_unsigned_extensions setting while database is running"); - } - config.options.allow_unsigned_extensions = new_value; -} - -void AllowUnsignedExtensionsSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - if (db) { - throw InvalidInputException("Cannot change allow_unsigned_extensions setting while database is running"); - } - config.options.allow_unsigned_extensions = DBConfig().options.allow_unsigned_extensions; -} - -Value AllowUnsignedExtensionsSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return 
Value::BOOLEAN(config.options.allow_unsigned_extensions); -} - -//===--------------------------------------------------------------------===// -// Enable Object Cache -//===--------------------------------------------------------------------===// -void EnableObjectCacheSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.object_cache_enable = input.GetValue(); -} - -void EnableObjectCacheSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.object_cache_enable = DBConfig().options.object_cache_enable; -} - -Value EnableObjectCacheSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.object_cache_enable); -} - -//===--------------------------------------------------------------------===// -// Enable HTTP Metadata Cache -//===--------------------------------------------------------------------===// -void EnableHTTPMetadataCacheSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.http_metadata_cache_enable = input.GetValue(); -} - -void EnableHTTPMetadataCacheSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.http_metadata_cache_enable = DBConfig().options.http_metadata_cache_enable; -} - -Value EnableHTTPMetadataCacheSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.http_metadata_cache_enable); -} - -//===--------------------------------------------------------------------===// -// Enable Profiling -//===--------------------------------------------------------------------===// -void EnableProfilingSetting::ResetLocal(ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - config.profiler_print_format = ClientConfig().profiler_print_format; - config.enable_profiler = ClientConfig().enable_profiler; - config.emit_profiler_output = 
ClientConfig().emit_profiler_output; -} - -void EnableProfilingSetting::SetLocal(ClientContext &context, const Value &input) { - auto parameter = StringUtil::Lower(input.ToString()); - - auto &config = ClientConfig::GetConfig(context); - if (parameter == "json") { - config.profiler_print_format = ProfilerPrintFormat::JSON; - } else if (parameter == "query_tree") { - config.profiler_print_format = ProfilerPrintFormat::QUERY_TREE; - } else if (parameter == "query_tree_optimizer") { - config.profiler_print_format = ProfilerPrintFormat::QUERY_TREE_OPTIMIZER; - } else { - throw ParserException( - "Unrecognized print format %s, supported formats: [json, query_tree, query_tree_optimizer]", parameter); - } - config.enable_profiler = true; - config.emit_profiler_output = true; -} - -Value EnableProfilingSetting::GetSetting(ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - if (!config.enable_profiler) { - return Value(); - } - switch (config.profiler_print_format) { - case ProfilerPrintFormat::JSON: - return Value("json"); - case ProfilerPrintFormat::QUERY_TREE: - return Value("query_tree"); - case ProfilerPrintFormat::QUERY_TREE_OPTIMIZER: - return Value("query_tree_optimizer"); - default: - throw InternalException("Unsupported profiler print format"); - } -} - -//===--------------------------------------------------------------------===// -// Custom Extension Repository -//===--------------------------------------------------------------------===// -void CustomExtensionRepository::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).custom_extension_repo = ClientConfig().custom_extension_repo; -} - -void CustomExtensionRepository::SetLocal(ClientContext &context, const Value &input) { - ClientConfig::GetConfig(context).custom_extension_repo = StringUtil::Lower(input.ToString()); -} - -Value CustomExtensionRepository::GetSetting(ClientContext &context) { - return Value(ClientConfig::GetConfig(context).custom_extension_repo); 
-} - -//===--------------------------------------------------------------------===// -// Autoload Extension Repository -//===--------------------------------------------------------------------===// -void AutoloadExtensionRepository::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).autoinstall_extension_repo = ClientConfig().autoinstall_extension_repo; -} - -void AutoloadExtensionRepository::SetLocal(ClientContext &context, const Value &input) { - ClientConfig::GetConfig(context).autoinstall_extension_repo = StringUtil::Lower(input.ToString()); -} - -Value AutoloadExtensionRepository::GetSetting(ClientContext &context) { - return Value(ClientConfig::GetConfig(context).autoinstall_extension_repo); -} - -//===--------------------------------------------------------------------===// -// Autoinstall Known Extensions -//===--------------------------------------------------------------------===// -void AutoinstallKnownExtensions::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.autoinstall_known_extensions = input.GetValue(); -} - -void AutoinstallKnownExtensions::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.autoinstall_known_extensions = DBConfig().options.autoinstall_known_extensions; -} - -Value AutoinstallKnownExtensions::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.autoinstall_known_extensions); -} -//===--------------------------------------------------------------------===// -// Autoload Known Extensions -//===--------------------------------------------------------------------===// -void AutoloadKnownExtensions::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.autoload_known_extensions = input.GetValue(); -} - -void AutoloadKnownExtensions::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.autoload_known_extensions = 
DBConfig().options.autoload_known_extensions; -} - -Value AutoloadKnownExtensions::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.autoload_known_extensions); -} - -//===--------------------------------------------------------------------===// -// Enable Progress Bar -//===--------------------------------------------------------------------===// -void EnableProgressBarSetting::ResetLocal(ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - ProgressBar::SystemOverrideCheck(config); - config.enable_progress_bar = ClientConfig().enable_progress_bar; -} - -void EnableProgressBarSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - ProgressBar::SystemOverrideCheck(config); - config.enable_progress_bar = input.GetValue(); -} - -Value EnableProgressBarSetting::GetSetting(ClientContext &context) { - return Value::BOOLEAN(ClientConfig::GetConfig(context).enable_progress_bar); -} - -//===--------------------------------------------------------------------===// -// Enable Progress Bar Print -//===--------------------------------------------------------------------===// -void EnableProgressBarPrintSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - ProgressBar::SystemOverrideCheck(config); - config.print_progress_bar = input.GetValue(); -} - -void EnableProgressBarPrintSetting::ResetLocal(ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - ProgressBar::SystemOverrideCheck(config); - config.print_progress_bar = ClientConfig().print_progress_bar; -} - -Value EnableProgressBarPrintSetting::GetSetting(ClientContext &context) { - return Value::BOOLEAN(ClientConfig::GetConfig(context).print_progress_bar); -} - -//===--------------------------------------------------------------------===// -// Explain Output 
-//===--------------------------------------------------------------------===// -void ExplainOutputSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).explain_output_type = ClientConfig().explain_output_type; -} - -void ExplainOutputSetting::SetLocal(ClientContext &context, const Value &input) { - auto parameter = StringUtil::Lower(input.ToString()); - if (parameter == "all") { - ClientConfig::GetConfig(context).explain_output_type = ExplainOutputType::ALL; - } else if (parameter == "optimized_only") { - ClientConfig::GetConfig(context).explain_output_type = ExplainOutputType::OPTIMIZED_ONLY; - } else if (parameter == "physical_only") { - ClientConfig::GetConfig(context).explain_output_type = ExplainOutputType::PHYSICAL_ONLY; - } else { - throw ParserException("Unrecognized output type \"%s\", expected either ALL, OPTIMIZED_ONLY or PHYSICAL_ONLY", - parameter); - } -} - -Value ExplainOutputSetting::GetSetting(ClientContext &context) { - switch (ClientConfig::GetConfig(context).explain_output_type) { - case ExplainOutputType::ALL: - return "all"; - case ExplainOutputType::OPTIMIZED_ONLY: - return "optimized_only"; - case ExplainOutputType::PHYSICAL_ONLY: - return "physical_only"; - default: - throw InternalException("Unrecognized explain output type"); - } -} - -//===--------------------------------------------------------------------===// -// Extension Directory Setting -//===--------------------------------------------------------------------===// -void ExtensionDirectorySetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto new_directory = input.ToString(); - config.options.extension_directory = input.ToString(); -} - -void ExtensionDirectorySetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.extension_directory = DBConfig().options.extension_directory; -} - -Value ExtensionDirectorySetting::GetSetting(ClientContext &context) { - return 
Value(DBConfig::GetConfig(context).options.extension_directory); -} - -//===--------------------------------------------------------------------===// -// External Threads Setting -//===--------------------------------------------------------------------===// -void ExternalThreadsSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.external_threads = input.GetValue(); -} - -void ExternalThreadsSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.external_threads = DBConfig().options.external_threads; -} - -Value ExternalThreadsSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BIGINT(config.options.external_threads); -} - -//===--------------------------------------------------------------------===// -// File Search Path -//===--------------------------------------------------------------------===// -void FileSearchPathSetting::ResetLocal(ClientContext &context) { - auto &client_data = ClientData::Get(context); - client_data.file_search_path.clear(); -} - -void FileSearchPathSetting::SetLocal(ClientContext &context, const Value &input) { - auto parameter = input.ToString(); - auto &client_data = ClientData::Get(context); - client_data.file_search_path = parameter; -} - -Value FileSearchPathSetting::GetSetting(ClientContext &context) { - auto &client_data = ClientData::Get(context); - return Value(client_data.file_search_path); -} - -//===--------------------------------------------------------------------===// -// Force Compression -//===--------------------------------------------------------------------===// -void ForceCompressionSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto compression = StringUtil::Lower(input.ToString()); - if (compression == "none" || compression == "auto") { - config.options.force_compression = CompressionType::COMPRESSION_AUTO; - } else { - auto compression_type = 
CompressionTypeFromString(compression); - if (compression_type == CompressionType::COMPRESSION_AUTO) { - auto compression_types = StringUtil::Join(ListCompressionTypes(), ", "); - throw ParserException("Unrecognized option for PRAGMA force_compression, expected %s", compression_types); - } - config.options.force_compression = compression_type; - } -} - -void ForceCompressionSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.force_compression = DBConfig().options.force_compression; -} - -Value ForceCompressionSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(*context.db); - return CompressionTypeToString(config.options.force_compression); -} - -//===--------------------------------------------------------------------===// -// Force Bitpacking mode -//===--------------------------------------------------------------------===// -void ForceBitpackingModeSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto mode_str = StringUtil::Lower(input.ToString()); - auto mode = BitpackingModeFromString(mode_str); - if (mode == BitpackingMode::INVALID) { - throw ParserException("Unrecognized option for force_bitpacking_mode, expected none, constant, constant_delta, " - "delta_for, or for"); - } - config.options.force_bitpacking_mode = mode; -} - -void ForceBitpackingModeSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.force_bitpacking_mode = DBConfig().options.force_bitpacking_mode; -} - -Value ForceBitpackingModeSetting::GetSetting(ClientContext &context) { - return Value(BitpackingModeToString(context.db->config.options.force_bitpacking_mode)); -} - -//===--------------------------------------------------------------------===// -// Home Directory -//===--------------------------------------------------------------------===// -void HomeDirectorySetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).home_directory = 
ClientConfig().home_directory; -} - -void HomeDirectorySetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.home_directory = input.IsNull() ? string() : input.ToString(); -} - -Value HomeDirectorySetting::GetSetting(ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value(config.home_directory); -} - -//===--------------------------------------------------------------------===// -// Integer Division -//===--------------------------------------------------------------------===// -void IntegerDivisionSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).integer_division = ClientConfig().integer_division; -} - -void IntegerDivisionSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.integer_division = input.GetValue(); -} - -Value IntegerDivisionSetting::GetSetting(ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value(config.integer_division); -} - -//===--------------------------------------------------------------------===// -// Log Query Path -//===--------------------------------------------------------------------===// -void LogQueryPathSetting::ResetLocal(ClientContext &context) { - auto &client_data = ClientData::Get(context); - // TODO: verify that this does the right thing - client_data.log_query_writer = std::move(ClientData(context).log_query_writer); -} - -void LogQueryPathSetting::SetLocal(ClientContext &context, const Value &input) { - auto &client_data = ClientData::Get(context); - auto path = input.ToString(); - if (path.empty()) { - // empty path: clean up query writer - client_data.log_query_writer = nullptr; - } else { - client_data.log_query_writer = make_uniq(FileSystem::GetFileSystem(context), path, - BufferedFileWriter::DEFAULT_OPEN_FLAGS); - } -} - -Value LogQueryPathSetting::GetSetting(ClientContext 
&context) { - auto &client_data = ClientData::Get(context); - return client_data.log_query_writer ? Value(client_data.log_query_writer->path) : Value(); -} - -//===--------------------------------------------------------------------===// -// Lock Configuration -//===--------------------------------------------------------------------===// -void LockConfigurationSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto new_value = input.GetValue(); - config.options.lock_configuration = new_value; -} - -void LockConfigurationSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.lock_configuration = DBConfig().options.lock_configuration; -} - -Value LockConfigurationSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.lock_configuration); -} - -//===--------------------------------------------------------------------===// -// Immediate Transaction Mode -//===--------------------------------------------------------------------===// -void ImmediateTransactionModeSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.immediate_transaction_mode = BooleanValue::Get(input); -} - -void ImmediateTransactionModeSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.immediate_transaction_mode = DBConfig().options.immediate_transaction_mode; -} - -Value ImmediateTransactionModeSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.immediate_transaction_mode); -} - -//===--------------------------------------------------------------------===// -// Maximum Expression Depth -//===--------------------------------------------------------------------===// -void MaximumExpressionDepthSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).max_expression_depth = 
ClientConfig().max_expression_depth; -} - -void MaximumExpressionDepthSetting::SetLocal(ClientContext &context, const Value &input) { - ClientConfig::GetConfig(context).max_expression_depth = input.GetValue(); -} - -Value MaximumExpressionDepthSetting::GetSetting(ClientContext &context) { - return Value::UBIGINT(ClientConfig::GetConfig(context).max_expression_depth); -} - -//===--------------------------------------------------------------------===// -// Maximum Memory -//===--------------------------------------------------------------------===// -void MaximumMemorySetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.maximum_memory = DBConfig::ParseMemoryLimit(input.ToString()); - if (db) { - BufferManager::GetBufferManager(*db).SetLimit(config.options.maximum_memory); - } -} - -void MaximumMemorySetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.SetDefaultMaxMemory(); -} - -Value MaximumMemorySetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value(StringUtil::BytesToHumanReadableString(config.options.maximum_memory)); -} - -//===--------------------------------------------------------------------===// -// Password Setting -//===--------------------------------------------------------------------===// -void PasswordSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - // nop -} - -void PasswordSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - // nop -} - -Value PasswordSetting::GetSetting(ClientContext &context) { - return Value(); -} - -//===--------------------------------------------------------------------===// -// Perfect Hash Threshold -//===--------------------------------------------------------------------===// -void PerfectHashThresholdSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).perfect_ht_threshold = ClientConfig().perfect_ht_threshold; -} - -void 
PerfectHashThresholdSetting::SetLocal(ClientContext &context, const Value &input) { - auto bits = input.GetValue(); - if (bits < 0 || bits > 32) { - throw ParserException("Perfect HT threshold out of range: should be within range 0 - 32"); - } - ClientConfig::GetConfig(context).perfect_ht_threshold = bits; -} - -Value PerfectHashThresholdSetting::GetSetting(ClientContext &context) { - return Value::BIGINT(ClientConfig::GetConfig(context).perfect_ht_threshold); -} - -//===--------------------------------------------------------------------===// -// Pivot Filter Threshold -//===--------------------------------------------------------------------===// -void PivotFilterThreshold::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).pivot_filter_threshold = ClientConfig().pivot_filter_threshold; -} - -void PivotFilterThreshold::SetLocal(ClientContext &context, const Value &input) { - ClientConfig::GetConfig(context).pivot_filter_threshold = input.GetValue(); -} - -Value PivotFilterThreshold::GetSetting(ClientContext &context) { - return Value::BIGINT(ClientConfig::GetConfig(context).pivot_filter_threshold); -} - -//===--------------------------------------------------------------------===// -// Pivot Limit -//===--------------------------------------------------------------------===// -void PivotLimitSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).pivot_limit = ClientConfig().pivot_limit; -} - -void PivotLimitSetting::SetLocal(ClientContext &context, const Value &input) { - ClientConfig::GetConfig(context).pivot_limit = input.GetValue(); -} - -Value PivotLimitSetting::GetSetting(ClientContext &context) { - return Value::BIGINT(ClientConfig::GetConfig(context).pivot_limit); -} - -//===--------------------------------------------------------------------===// -// PreserveIdentifierCase -//===--------------------------------------------------------------------===// -void PreserveIdentifierCase::ResetLocal(ClientContext 
&context) { - ClientConfig::GetConfig(context).preserve_identifier_case = ClientConfig().preserve_identifier_case; -} - -void PreserveIdentifierCase::SetLocal(ClientContext &context, const Value &input) { - ClientConfig::GetConfig(context).preserve_identifier_case = input.GetValue(); -} - -Value PreserveIdentifierCase::GetSetting(ClientContext &context) { - return Value::BOOLEAN(ClientConfig::GetConfig(context).preserve_identifier_case); -} - -//===--------------------------------------------------------------------===// -// PreserveInsertionOrder -//===--------------------------------------------------------------------===// -void PreserveInsertionOrder::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.preserve_insertion_order = input.GetValue(); -} - -void PreserveInsertionOrder::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.preserve_insertion_order = DBConfig().options.preserve_insertion_order; -} - -Value PreserveInsertionOrder::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.preserve_insertion_order); -} - -//===--------------------------------------------------------------------===// -// ExportLargeBufferArrow -//===--------------------------------------------------------------------===// -void ExportLargeBufferArrow::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto export_large_buffers_arrow = input.GetValue(); - - config.options.arrow_offset_size = export_large_buffers_arrow ? 
ArrowOffsetSize::LARGE : ArrowOffsetSize::REGULAR; -} - -void ExportLargeBufferArrow::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.arrow_offset_size = DBConfig().options.arrow_offset_size; -} - -Value ExportLargeBufferArrow::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - bool export_large_buffers_arrow = config.options.arrow_offset_size == ArrowOffsetSize::LARGE; - return Value::BOOLEAN(export_large_buffers_arrow); -} - -//===--------------------------------------------------------------------===// -// Profiler History Size -//===--------------------------------------------------------------------===// -void ProfilerHistorySize::ResetLocal(ClientContext &context) { - auto &client_data = ClientData::Get(context); - client_data.query_profiler_history->ResetProfilerHistorySize(); -} - -void ProfilerHistorySize::SetLocal(ClientContext &context, const Value &input) { - auto size = input.GetValue(); - if (size <= 0) { - throw ParserException("Size should be >= 0"); - } - auto &client_data = ClientData::Get(context); - client_data.query_profiler_history->SetProfilerHistorySize(size); -} - -Value ProfilerHistorySize::GetSetting(ClientContext &context) { - return Value(); -} - -//===--------------------------------------------------------------------===// -// Profile Output -//===--------------------------------------------------------------------===// -void ProfileOutputSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).profiler_save_location = ClientConfig().profiler_save_location; -} - -void ProfileOutputSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - auto parameter = input.ToString(); - config.profiler_save_location = parameter; -} - -Value ProfileOutputSetting::GetSetting(ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value(config.profiler_save_location); -} - 
-//===--------------------------------------------------------------------===// -// Profiling Mode -//===--------------------------------------------------------------------===// -void ProfilingModeSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).enable_profiler = ClientConfig().enable_profiler; - ClientConfig::GetConfig(context).enable_detailed_profiling = ClientConfig().enable_detailed_profiling; - ClientConfig::GetConfig(context).emit_profiler_output = ClientConfig().emit_profiler_output; -} - -void ProfilingModeSetting::SetLocal(ClientContext &context, const Value &input) { - auto parameter = StringUtil::Lower(input.ToString()); - auto &config = ClientConfig::GetConfig(context); - if (parameter == "standard") { - config.enable_profiler = true; - config.enable_detailed_profiling = false; - config.emit_profiler_output = true; - } else if (parameter == "detailed") { - config.enable_profiler = true; - config.enable_detailed_profiling = true; - config.emit_profiler_output = true; - } else { - throw ParserException("Unrecognized profiling mode \"%s\", supported formats: [standard, detailed]", parameter); - } -} - -Value ProfilingModeSetting::GetSetting(ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - if (!config.enable_profiler) { - return Value(); - } - return Value(config.enable_detailed_profiling ? 
"detailed" : "standard"); -} - -//===--------------------------------------------------------------------===// -// Progress Bar Time -//===--------------------------------------------------------------------===// -void ProgressBarTimeSetting::ResetLocal(ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - ProgressBar::SystemOverrideCheck(config); - config.wait_time = ClientConfig().wait_time; - config.enable_progress_bar = ClientConfig().enable_progress_bar; -} - -void ProgressBarTimeSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - ProgressBar::SystemOverrideCheck(config); - config.wait_time = input.GetValue(); - config.enable_progress_bar = true; -} - -Value ProgressBarTimeSetting::GetSetting(ClientContext &context) { - return Value::BIGINT(ClientConfig::GetConfig(context).wait_time); -} - -//===--------------------------------------------------------------------===// -// Schema -//===--------------------------------------------------------------------===// -void SchemaSetting::ResetLocal(ClientContext &context) { - // FIXME: catalog_search_path is controlled by both SchemaSetting and SearchPathSetting - auto &client_data = ClientData::Get(context); - client_data.catalog_search_path->Reset(); -} - -void SchemaSetting::SetLocal(ClientContext &context, const Value &input) { - auto parameter = input.ToString(); - auto &client_data = ClientData::Get(context); - client_data.catalog_search_path->Set(CatalogSearchEntry::Parse(parameter), CatalogSetPathType::SET_SCHEMA); -} - -Value SchemaSetting::GetSetting(ClientContext &context) { - auto &client_data = ClientData::Get(context); - return client_data.catalog_search_path->GetDefault().schema; -} - -//===--------------------------------------------------------------------===// -// Search Path -//===--------------------------------------------------------------------===// -void SearchPathSetting::ResetLocal(ClientContext &context) 
{ - // FIXME: catalog_search_path is controlled by both SchemaSetting and SearchPathSetting - auto &client_data = ClientData::Get(context); - client_data.catalog_search_path->Reset(); -} - -void SearchPathSetting::SetLocal(ClientContext &context, const Value &input) { - auto parameter = input.ToString(); - auto &client_data = ClientData::Get(context); - client_data.catalog_search_path->Set(CatalogSearchEntry::ParseList(parameter), CatalogSetPathType::SET_SCHEMAS); -} - -Value SearchPathSetting::GetSetting(ClientContext &context) { - auto &client_data = ClientData::Get(context); - auto &set_paths = client_data.catalog_search_path->GetSetPaths(); - return Value(CatalogSearchEntry::ListToString(set_paths)); -} - -//===--------------------------------------------------------------------===// -// Temp Directory -//===--------------------------------------------------------------------===// -void TempDirectorySetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.temporary_directory = input.ToString(); - config.options.use_temporary_directory = !config.options.temporary_directory.empty(); - if (db) { - auto &buffer_manager = BufferManager::GetBufferManager(*db); - buffer_manager.SetTemporaryDirectory(config.options.temporary_directory); - } -} - -void TempDirectorySetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.temporary_directory = DBConfig().options.temporary_directory; - config.options.use_temporary_directory = DBConfig().options.use_temporary_directory; - if (db) { - auto &buffer_manager = BufferManager::GetBufferManager(*db); - buffer_manager.SetTemporaryDirectory(config.options.temporary_directory); - } -} - -Value TempDirectorySetting::GetSetting(ClientContext &context) { - auto &buffer_manager = BufferManager::GetBufferManager(context); - return Value(buffer_manager.GetTemporaryDirectory()); -} - -//===--------------------------------------------------------------------===// -// Threads 
Setting -//===--------------------------------------------------------------------===// -void ThreadsSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.maximum_threads = input.GetValue(); - if (db) { - TaskScheduler::GetScheduler(*db).SetThreads(config.options.maximum_threads); - } -} - -void ThreadsSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.SetDefaultMaxThreads(); -} - -Value ThreadsSetting::GetSetting(ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BIGINT(config.options.maximum_threads); -} - -//===--------------------------------------------------------------------===// -// Username Setting -//===--------------------------------------------------------------------===// -void UsernameSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - // nop -} - -void UsernameSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - // nop -} - -Value UsernameSetting::GetSetting(ClientContext &context) { - return Value(); -} - -//===--------------------------------------------------------------------===// -// Allocator Flush Threshold -//===--------------------------------------------------------------------===// -void FlushAllocatorSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.allocator_flush_threshold = DBConfig::ParseMemoryLimit(input.ToString()); - if (db) { - TaskScheduler::GetScheduler(*db).SetAllocatorFlushTreshold(config.options.allocator_flush_threshold); - } -} - -void FlushAllocatorSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.allocator_flush_threshold = DBConfig().options.allocator_flush_threshold; - if (db) { - TaskScheduler::GetScheduler(*db).SetAllocatorFlushTreshold(config.options.allocator_flush_threshold); - } -} - -Value FlushAllocatorSetting::GetSetting(ClientContext &context) { - auto &config = 
DBConfig::GetConfig(context); - return Value(StringUtil::BytesToHumanReadableString(config.options.allocator_flush_threshold)); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -StreamQueryResult::StreamQueryResult(StatementType statement_type, StatementProperties properties, - shared_ptr context_p, vector types, - vector names) - : QueryResult(QueryResultType::STREAM_RESULT, statement_type, std::move(properties), std::move(types), - std::move(names), context_p->GetClientProperties()), - context(std::move(context_p)) { - D_ASSERT(context); -} - -StreamQueryResult::~StreamQueryResult() { -} - -string StreamQueryResult::ToString() { - string result; - if (success) { - result = HeaderToString(); - result += "[[STREAM RESULT]]"; - } else { - result = GetError() + "\n"; - } - return result; -} - -unique_ptr StreamQueryResult::LockContext() { - if (!context) { - string error_str = "Attempting to execute an unsuccessful or closed pending query result"; - if (HasError()) { - error_str += StringUtil::Format("\nError: %s", GetError()); - } - throw InvalidInputException(error_str); - } - return context->LockContext(); -} - -void StreamQueryResult::CheckExecutableInternal(ClientContextLock &lock) { - if (!IsOpenInternal(lock)) { - string error_str = "Attempting to execute an unsuccessful or closed pending query result"; - if (HasError()) { - error_str += StringUtil::Format("\nError: %s", GetError()); - } - throw InvalidInputException(error_str); - } -} - -unique_ptr StreamQueryResult::FetchRaw() { - unique_ptr chunk; - { - auto lock = LockContext(); - CheckExecutableInternal(*lock); - chunk = context->Fetch(*lock, *this); - } - if (!chunk || chunk->ColumnCount() == 0 || chunk->size() == 0) { - Close(); - return nullptr; - } - return chunk; -} - -unique_ptr StreamQueryResult::Materialize() { - if (HasError() || !context) { - return make_uniq(GetErrorObject()); - } - auto collection = make_uniq(Allocator::DefaultAllocator(), types); - - ColumnDataAppendState 
append_state; - collection->InitializeAppend(append_state); - while (true) { - auto chunk = Fetch(); - if (!chunk || chunk->size() == 0) { - break; - } - collection->Append(append_state, *chunk); - } - auto result = - make_uniq(statement_type, properties, names, std::move(collection), client_properties); - if (HasError()) { - return make_uniq(GetErrorObject()); - } - return result; -} - -bool StreamQueryResult::IsOpenInternal(ClientContextLock &lock) { - bool invalidated = !success || !context; - if (!invalidated) { - invalidated = !context->IsActiveResult(lock, this); - } - return !invalidated; -} - -bool StreamQueryResult::IsOpen() { - if (!success || !context) { - return false; - } - auto lock = LockContext(); - return IsOpenInternal(*lock); -} - -void StreamQueryResult::Close() { - context.reset(); -} - -} // namespace duckdb - - -namespace duckdb { - -ValidChecker::ValidChecker() : is_invalidated(false) { -} - -void ValidChecker::Invalidate(string error) { - lock_guard l(invalidate_lock); - this->is_invalidated = true; - this->invalidated_msg = std::move(error); -} - -bool ValidChecker::IsInvalidated() { - return this->is_invalidated; -} - -string ValidChecker::InvalidatedMessage() { - lock_guard l(invalidate_lock); - return invalidated_msg; -} -} // namespace duckdb - - - - -namespace duckdb { - -ReplacementBinding::ReplacementBinding(ColumnBinding old_binding, ColumnBinding new_binding) - : old_binding(old_binding), new_binding(new_binding), replace_type(false) { -} - -ReplacementBinding::ReplacementBinding(ColumnBinding old_binding, ColumnBinding new_binding, LogicalType new_type) - : old_binding(old_binding), new_binding(new_binding), replace_type(true), new_type(std::move(new_type)) { -} - -ColumnBindingReplacer::ColumnBindingReplacer() { -} - -void ColumnBindingReplacer::VisitOperator(LogicalOperator &op) { - if (stop_operator && stop_operator.get() == &op) { - return; - } - VisitOperatorChildren(op); - VisitOperatorExpressions(op); -} - -void 
ColumnBindingReplacer::VisitExpression(unique_ptr *expression) { - auto &expr = *expression; - if (expr->expression_class == ExpressionClass::BOUND_COLUMN_REF) { - auto &bound_column_ref = expr->Cast(); - for (const auto &replace_binding : replacement_bindings) { - if (bound_column_ref.binding == replace_binding.old_binding) { - bound_column_ref.binding = replace_binding.new_binding; - if (replace_binding.replace_type) { - bound_column_ref.return_type = replace_binding.new_type; - } - } - } - } - - VisitExpressionChildren(**expression); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -void ColumnLifetimeAnalyzer::ExtractUnusedColumnBindings(vector bindings, - column_binding_set_t &unused_bindings) { - for (idx_t i = 0; i < bindings.size(); i++) { - if (column_references.find(bindings[i]) == column_references.end()) { - unused_bindings.insert(bindings[i]); - } - } -} - -void ColumnLifetimeAnalyzer::GenerateProjectionMap(vector bindings, - column_binding_set_t &unused_bindings, - vector &projection_map) { - projection_map.clear(); - if (unused_bindings.empty()) { - return; - } - // now iterate over the result bindings of the child - for (idx_t i = 0; i < bindings.size(); i++) { - // if this binding does not belong to the unused bindings, add it to the projection map - if (unused_bindings.find(bindings[i]) == unused_bindings.end()) { - projection_map.push_back(i); - } - } - if (projection_map.size() == bindings.size()) { - projection_map.clear(); - } -} - -void ColumnLifetimeAnalyzer::StandardVisitOperator(LogicalOperator &op) { - LogicalOperatorVisitor::VisitOperatorExpressions(op); - if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { - // visit the duplicate eliminated columns on the LHS, if any - auto &delim_join = op.Cast(); - for (auto &expr : delim_join.duplicate_eliminated_columns) { - VisitExpression(&expr); - } - } - LogicalOperatorVisitor::VisitOperatorChildren(op); -} - -void ColumnLifetimeAnalyzer::VisitOperator(LogicalOperator &op) { - 
switch (op.type) { - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { - // FIXME: groups that are not referenced can be removed from projection - // recurse into the children of the aggregate - ColumnLifetimeAnalyzer analyzer; - analyzer.VisitOperatorExpressions(op); - analyzer.VisitOperator(*op.children[0]); - return; - } - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { - if (everything_referenced) { - break; - } - auto &comp_join = op.Cast(); - if (comp_join.join_type == JoinType::MARK || comp_join.join_type == JoinType::SEMI || - comp_join.join_type == JoinType::ANTI) { - break; - } - // FIXME for now, we only push into the projection map for equality (hash) joins - // FIXME: add projection to LHS as well - bool has_equality = false; - for (auto &cond : comp_join.conditions) { - if (cond.comparison == ExpressionType::COMPARE_EQUAL) { - has_equality = true; - break; - } - } - if (!has_equality) { - break; - } - // visit current operator expressions so they are added to the referenced_columns - LogicalOperatorVisitor::VisitOperatorExpressions(op); - - column_binding_set_t unused_bindings; - auto old_op_bindings = op.GetColumnBindings(); - ExtractUnusedColumnBindings(op.children[1]->GetColumnBindings(), unused_bindings); - - // now recurse into the filter and its children - LogicalOperatorVisitor::VisitOperatorChildren(op); - - // then generate the projection map - GenerateProjectionMap(op.children[1]->GetColumnBindings(), unused_bindings, comp_join.right_projection_map); - return; - } - case LogicalOperatorType::LOGICAL_UNION: - case LogicalOperatorType::LOGICAL_EXCEPT: - case LogicalOperatorType::LOGICAL_INTERSECT: - // for set operations we don't remove anything, just recursively visit the children - // FIXME: for UNION we can remove unreferenced columns as long as everything_referenced is false (i.e. 
we - // encounter a UNION node that is not preceded by a DISTINCT) - for (auto &child : op.children) { - ColumnLifetimeAnalyzer analyzer(true); - analyzer.VisitOperator(*child); - } - return; - case LogicalOperatorType::LOGICAL_PROJECTION: { - // then recurse into the children of this projection - ColumnLifetimeAnalyzer analyzer; - analyzer.VisitOperatorExpressions(op); - analyzer.VisitOperator(*op.children[0]); - return; - } - case LogicalOperatorType::LOGICAL_DISTINCT: { - // distinct, all projected columns are used for the DISTINCT computation - // mark all columns as used and continue to the children - // FIXME: DISTINCT with expression list does not implicitly reference everything - everything_referenced = true; - break; - } - case LogicalOperatorType::LOGICAL_FILTER: { - auto &filter = op.Cast(); - if (everything_referenced) { - break; - } - // first visit operator expressions to populate referenced columns - LogicalOperatorVisitor::VisitOperatorExpressions(op); - // filter, figure out which columns are not needed after the filter - column_binding_set_t unused_bindings; - ExtractUnusedColumnBindings(op.children[0]->GetColumnBindings(), unused_bindings); - - // now recurse into the filter and its children - LogicalOperatorVisitor::VisitOperatorChildren(op); - - // then generate the projection map - GenerateProjectionMap(op.children[0]->GetColumnBindings(), unused_bindings, filter.projection_map); - return; - } - default: - break; - } - StandardVisitOperator(op); -} - -unique_ptr ColumnLifetimeAnalyzer::VisitReplace(BoundColumnRefExpression &expr, - unique_ptr *expr_ptr) { - column_references.insert(expr.binding); - return nullptr; -} - -unique_ptr ColumnLifetimeAnalyzer::VisitReplace(BoundReferenceExpression &expr, - unique_ptr *expr_ptr) { - // BoundReferenceExpression should not be used here yet, they only belong in the physical plan - throw InternalException("BoundReferenceExpression should not be used here yet!"); -} - -} // namespace duckdb - - - - - - - 
-namespace duckdb { - -void CommonAggregateOptimizer::VisitOperator(LogicalOperator &op) { - LogicalOperatorVisitor::VisitOperator(op); - switch (op.type) { - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - ExtractCommonAggregates(op.Cast()); - break; - default: - break; - } -} - -unique_ptr CommonAggregateOptimizer::VisitReplace(BoundColumnRefExpression &expr, - unique_ptr *expr_ptr) { - // check if this column ref points to an aggregate that was remapped; if it does we remap it - auto entry = aggregate_map.find(expr.binding); - if (entry != aggregate_map.end()) { - expr.binding = entry->second; - } - return nullptr; -} - -void CommonAggregateOptimizer::ExtractCommonAggregates(LogicalAggregate &aggr) { - expression_map_t aggregate_remap; - idx_t total_erased = 0; - for (idx_t i = 0; i < aggr.expressions.size(); i++) { - idx_t original_index = i + total_erased; - auto entry = aggregate_remap.find(*aggr.expressions[i]); - if (entry == aggregate_remap.end()) { - // aggregate does not exist yet: add it to the map - aggregate_remap[*aggr.expressions[i]] = i; - if (i != original_index) { - // this aggregate is not erased, however an agregate BEFORE it has been erased - // so we need to remap this aggregaet - ColumnBinding original_binding(aggr.aggregate_index, original_index); - ColumnBinding new_binding(aggr.aggregate_index, i); - aggregate_map[original_binding] = new_binding; - } - } else { - // aggregate already exists! 
we can remove this entry - total_erased++; - aggr.expressions.erase(aggr.expressions.begin() + i); - i--; - // we need to remap any references to this aggregate so they point to the other aggregate - ColumnBinding original_binding(aggr.aggregate_index, original_index); - ColumnBinding new_binding(aggr.aggregate_index, entry->second); - aggregate_map[original_binding] = new_binding; - } - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -void CompressedMaterialization::CompressAggregate(unique_ptr &op) { - auto &aggregate = op->Cast(); - auto &groups = aggregate.groups; - column_binding_set_t group_binding_set; - for (const auto &group : groups) { - if (group->type != ExpressionType::BOUND_COLUMN_REF) { - continue; - } - auto &colref = group->Cast(); - if (group_binding_set.find(colref.binding) != group_binding_set.end()) { - return; // Duplicate group - don't compress - } - group_binding_set.insert(colref.binding); - } - auto &group_stats = aggregate.group_stats; - - // No need to compress if there are no groups/stats - if (groups.empty() || group_stats.empty()) { - return; - } - D_ASSERT(groups.size() == group_stats.size()); - - // Find all bindings referenced by non-colref expressions in the groups - // These are excluded from compression by projection - // But we can try to compress the expression directly - column_binding_set_t referenced_bindings; - vector group_bindings(groups.size(), ColumnBinding()); - vector needs_decompression(groups.size(), false); - vector> stored_group_stats; - stored_group_stats.resize(groups.size()); - for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { - auto &group_expr = *groups[group_idx]; - if (group_expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - auto &colref = group_expr.Cast(); - group_bindings[group_idx] = colref.binding; - continue; // Will be compressed generically - } - - // Mark the bindings referenced by the non-colref expression so they won't be modified - 
GetReferencedBindings(group_expr, referenced_bindings); - - // The non-colref expression won't be compressed generically, so try to compress it here - if (!group_stats[group_idx]) { - continue; // Can't compress without stats - } - - // Try to compress, if successful, replace the expression - auto compress_expr = GetCompressExpression(group_expr.Copy(), *group_stats[group_idx]); - if (compress_expr) { - needs_decompression[group_idx] = true; - stored_group_stats[group_idx] = std::move(group_stats[group_idx]); - groups[group_idx] = std::move(compress_expr->expression); - group_stats[group_idx] = std::move(compress_expr->stats); - } - } - - // Anything referenced in the aggregate functions is also excluded - for (idx_t expr_idx = 0; expr_idx < aggregate.expressions.size(); expr_idx++) { - const auto &expr = *aggregate.expressions[expr_idx]; - D_ASSERT(expr.type == ExpressionType::BOUND_AGGREGATE); - const auto &aggr_expr = expr.Cast(); - for (const auto &child : aggr_expr.children) { - GetReferencedBindings(*child, referenced_bindings); - } - if (aggr_expr.filter) { - GetReferencedBindings(*aggr_expr.filter, referenced_bindings); - } - if (aggr_expr.order_bys) { - for (const auto &order : aggr_expr.order_bys->orders) { - const auto &order_expr = *order.expression; - if (order_expr.type != ExpressionType::BOUND_COLUMN_REF) { - GetReferencedBindings(order_expr, referenced_bindings); - } - } - } - } - - // Create info for compression - CompressedMaterializationInfo info(*op, {0}, referenced_bindings); - - // Create binding mapping - const auto bindings_out = aggregate.GetColumnBindings(); - const auto &types = aggregate.types; - for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { - // Aggregate changes bindings as it has a table idx - CMBindingInfo binding_info(bindings_out[group_idx], types[group_idx]); - binding_info.needs_decompression = needs_decompression[group_idx]; - if (needs_decompression[group_idx]) { - // Compressed non-generically - auto 
entry = info.binding_map.emplace(bindings_out[group_idx], std::move(binding_info)); - entry.first->second.stats = std::move(stored_group_stats[group_idx]); - } else if (group_bindings[group_idx] != ColumnBinding()) { - info.binding_map.emplace(group_bindings[group_idx], std::move(binding_info)); - } - } - - // Now try to compress - CreateProjections(op, info); - - // Update aggregate statistics - UpdateAggregateStats(op); -} - -void CompressedMaterialization::UpdateAggregateStats(unique_ptr &op) { - if (op->type != LogicalOperatorType::LOGICAL_PROJECTION) { - return; - } - - // Update aggregate group stats if compressed - auto &compressed_aggregate = op->children[0]->Cast(); - auto &groups = compressed_aggregate.groups; - auto &group_stats = compressed_aggregate.group_stats; - - for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { - auto &group_expr = *groups[group_idx]; - if (group_expr.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { - continue; - } - auto &colref = group_expr.Cast(); - if (!group_stats[group_idx]) { - continue; - } - if (colref.return_type == group_stats[group_idx]->GetType()) { - continue; - } - auto it = statistics_map.find(colref.binding); - if (it != statistics_map.end() && it->second) { - group_stats[group_idx] = it->second->ToUnique(); - } - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -void CompressedMaterialization::CompressDistinct(unique_ptr &op) { - auto &distinct = op->Cast(); - auto &distinct_targets = distinct.distinct_targets; - - column_binding_set_t referenced_bindings; - for (auto &target : distinct_targets) { - if (target->type != ExpressionType::BOUND_COLUMN_REF) { // LCOV_EXCL_START - GetReferencedBindings(*target, referenced_bindings); - } // LCOV_EXCL_STOP - } - - if (distinct.order_by) { - for (auto &order : distinct.order_by->orders) { - if (order.expression->type != ExpressionType::BOUND_COLUMN_REF) { // LCOV_EXCL_START - GetReferencedBindings(*order.expression, 
referenced_bindings); - } // LCOV_EXCL_STOP - } - } - - // Create info for compression - CompressedMaterializationInfo info(*op, {0}, referenced_bindings); - - // Create binding mapping - const auto bindings = distinct.GetColumnBindings(); - const auto &types = distinct.types; - D_ASSERT(bindings.size() == types.size()); - for (idx_t col_idx = 0; col_idx < bindings.size(); col_idx++) { - // Distinct does not change bindings, input binding is output binding - info.binding_map.emplace(bindings[col_idx], CMBindingInfo(bindings[col_idx], types[col_idx])); - } - - // Now try to compress - CreateProjections(op, info); -} - -} // namespace duckdb - - - - -namespace duckdb { - -void CompressedMaterialization::CompressOrder(unique_ptr &op) { - auto &order = op->Cast(); - - // Find all bindings referenced by non-colref expressions in the order nodes - // These are excluded from compression by projection - // But we can try to compress the expression directly - column_binding_set_t referenced_bindings; - for (idx_t order_node_idx = 0; order_node_idx < order.orders.size(); order_node_idx++) { - auto &bound_order = order.orders[order_node_idx]; - auto &order_expression = *bound_order.expression; - if (order_expression.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - continue; // Will be compressed generically - } - - // Mark the bindings referenced by the non-colref expression so they won't be modified - GetReferencedBindings(order_expression, referenced_bindings); - } - - // Create info for compression - CompressedMaterializationInfo info(*op, {0}, referenced_bindings); - - // Create binding mapping - const auto bindings = order.GetColumnBindings(); - const auto &types = order.types; - D_ASSERT(bindings.size() == types.size()); - for (idx_t col_idx = 0; col_idx < bindings.size(); col_idx++) { - // Order does not change bindings, input binding is output binding - info.binding_map.emplace(bindings[col_idx], CMBindingInfo(bindings[col_idx], types[col_idx])); - } - - 
// Now try to compress - CreateProjections(op, info); - - // Update order statistics - UpdateOrderStats(op); -} - -void CompressedMaterialization::UpdateOrderStats(unique_ptr &op) { - if (op->type != LogicalOperatorType::LOGICAL_PROJECTION) { - return; - } - - // Update order stats if compressed - auto &compressed_order = op->children[0]->Cast(); - for (idx_t order_node_idx = 0; order_node_idx < compressed_order.orders.size(); order_node_idx++) { - auto &bound_order = compressed_order.orders[order_node_idx]; - auto &order_expression = *bound_order.expression; - if (order_expression.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { - continue; - } - auto &colref = order_expression.Cast(); - auto it = statistics_map.find(colref.binding); - if (it != statistics_map.end() && it->second) { - bound_order.stats = it->second->ToUnique(); - } - } -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -CMChildInfo::CMChildInfo(LogicalOperator &op, const column_binding_set_t &referenced_bindings) - : bindings_before(op.GetColumnBindings()), types(op.types), can_compress(bindings_before.size(), true) { - for (const auto &binding : referenced_bindings) { - for (idx_t binding_idx = 0; binding_idx < bindings_before.size(); binding_idx++) { - if (binding == bindings_before[binding_idx]) { - can_compress[binding_idx] = false; - } - } - } -} - -CMBindingInfo::CMBindingInfo(ColumnBinding binding_p, const LogicalType &type_p) - : binding(binding_p), type(type_p), needs_decompression(false) { -} - -CompressedMaterializationInfo::CompressedMaterializationInfo(LogicalOperator &op, vector &&child_idxs_p, - const column_binding_set_t &referenced_bindings) - : child_idxs(child_idxs_p) { - child_info.reserve(child_idxs.size()); - for (const auto &child_idx : child_idxs) { - child_info.emplace_back(*op.children[child_idx], referenced_bindings); - } -} - -CompressExpression::CompressExpression(unique_ptr expression_p, unique_ptr stats_p) - : 
expression(std::move(expression_p)), stats(std::move(stats_p)) { -} - -CompressedMaterialization::CompressedMaterialization(ClientContext &context_p, Binder &binder_p, - statistics_map_t &&statistics_map_p) - : context(context_p), binder(binder_p), statistics_map(std::move(statistics_map_p)) { -} - -void CompressedMaterialization::GetReferencedBindings(const Expression &expression, - column_binding_set_t &referenced_bindings) { - if (expression.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - const auto &col_ref = expression.Cast(); - referenced_bindings.insert(col_ref.binding); - } else { - ExpressionIterator::EnumerateChildren( - expression, [&](const Expression &child) { GetReferencedBindings(child, referenced_bindings); }); - } -} - -void CompressedMaterialization::UpdateBindingInfo(CompressedMaterializationInfo &info, const ColumnBinding &binding, - bool needs_decompression) { - auto &binding_map = info.binding_map; - auto binding_it = binding_map.find(binding); - if (binding_it == binding_map.end()) { - return; - } - - auto &binding_info = binding_it->second; - binding_info.needs_decompression = needs_decompression; - auto stats_it = statistics_map.find(binding); - if (stats_it != statistics_map.end()) { - binding_info.stats = statistics_map[binding]->ToUnique(); - } -} - -void CompressedMaterialization::Compress(unique_ptr &op) { - root = op.get(); - root->ResolveOperatorTypes(); - - CompressInternal(op); -} - -void CompressedMaterialization::CompressInternal(unique_ptr &op) { - if (TopN::CanOptimize(*op)) { // Let's not mess with the TopN optimizer - CompressInternal(op->children[0]->children[0]); - return; - } - - for (auto &child : op->children) { - CompressInternal(child); - } - - switch (op->type) { - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - CompressAggregate(op); - break; - case LogicalOperatorType::LOGICAL_DISTINCT: - CompressDistinct(op); - break; - case LogicalOperatorType::LOGICAL_ORDER_BY: - CompressOrder(op); - 
break; - default: - return; - } -} - -void CompressedMaterialization::CreateProjections(unique_ptr &op, - CompressedMaterializationInfo &info) { - auto &materializing_op = *op; - - bool compressed_anything = false; - for (idx_t i = 0; i < info.child_idxs.size(); i++) { - auto &child_info = info.child_info[i]; - vector> compress_exprs; - if (TryCompressChild(info, child_info, compress_exprs)) { - // We can compress: Create a projection on top of the child operator - const auto child_idx = info.child_idxs[i]; - CreateCompressProjection(materializing_op.children[child_idx], std::move(compress_exprs), info, child_info); - compressed_anything = true; - } - } - - if (compressed_anything) { - CreateDecompressProjection(op, info); - } -} - -bool CompressedMaterialization::TryCompressChild(CompressedMaterializationInfo &info, const CMChildInfo &child_info, - vector> &compress_exprs) { - // Try to compress each of the column bindings of the child - bool compressed_anything = false; - for (idx_t child_i = 0; child_i < child_info.bindings_before.size(); child_i++) { - const auto child_binding = child_info.bindings_before[child_i]; - const auto &child_type = child_info.types[child_i]; - const auto &can_compress = child_info.can_compress[child_i]; - auto compress_expr = GetCompressExpression(child_binding, child_type, can_compress); - bool compressed = false; - if (compress_expr) { // We compressed, mark the outgoing binding in need of decompression - compress_exprs.emplace_back(std::move(compress_expr)); - compressed = true; - } else { // We did not compress, just push a colref - auto colref_expr = make_uniq(child_type, child_binding); - auto it = statistics_map.find(colref_expr->binding); - unique_ptr colref_stats = it != statistics_map.end() ? 
it->second->ToUnique() : nullptr; - compress_exprs.emplace_back(make_uniq(std::move(colref_expr), std::move(colref_stats))); - } - UpdateBindingInfo(info, child_binding, compressed); - compressed_anything = compressed_anything || compressed; - } - if (!compressed_anything) { - // If we compressed anything non-generically, we still need to decompress - for (const auto &entry : info.binding_map) { - compressed_anything = compressed_anything || entry.second.needs_decompression; - } - } - return compressed_anything; -} - -void CompressedMaterialization::CreateCompressProjection(unique_ptr &child_op, - vector> &&compress_exprs, - CompressedMaterializationInfo &info, CMChildInfo &child_info) { - // Replace child op with a projection - vector> projections; - projections.reserve(compress_exprs.size()); - for (auto &compress_expr : compress_exprs) { - projections.emplace_back(std::move(compress_expr->expression)); - } - const auto table_index = binder.GenerateTableIndex(); - auto compress_projection = make_uniq(table_index, std::move(projections)); - compression_table_indices.insert(table_index); - compress_projection->ResolveOperatorTypes(); - - compress_projection->children.emplace_back(std::move(child_op)); - child_op = std::move(compress_projection); - - // Get the new bindings and types - child_info.bindings_after = child_op->GetColumnBindings(); - const auto &new_types = child_op->types; - - // Initialize a ColumnBindingReplacer with the new bindings and types - ColumnBindingReplacer replacer; - auto &replacement_bindings = replacer.replacement_bindings; - for (idx_t col_idx = 0; col_idx < child_info.bindings_before.size(); col_idx++) { - const auto &old_binding = child_info.bindings_before[col_idx]; - const auto &new_binding = child_info.bindings_after[col_idx]; - const auto &new_type = new_types[col_idx]; - replacement_bindings.emplace_back(old_binding, new_binding, new_type); - - // Remove the old binding from the statistics map - statistics_map.erase(old_binding); 
- } - - // Make sure we skip the compress operator when replacing bindings - replacer.stop_operator = child_op.get(); - - // Make the plan consistent again - replacer.VisitOperator(*root); - - // Replace in/out exprs in the binding map too - auto &binding_map = info.binding_map; - for (auto &replacement_binding : replacement_bindings) { - auto it = binding_map.find(replacement_binding.old_binding); - if (it == binding_map.end()) { - continue; - } - auto &binding_info = it->second; - if (binding_info.binding == replacement_binding.old_binding) { - binding_info.binding = replacement_binding.new_binding; - } - - if (it->first == replacement_binding.old_binding) { - auto binding_info_local = std::move(binding_info); - binding_map.erase(it); - binding_map.emplace(replacement_binding.new_binding, std::move(binding_info_local)); - } - } - - // Add projection stats to statistics map - for (idx_t col_idx = 0; col_idx < child_info.bindings_after.size(); col_idx++) { - const auto &binding = child_info.bindings_after[col_idx]; - auto &stats = compress_exprs[col_idx]->stats; - statistics_map.emplace(binding, std::move(stats)); - } -} - -void CompressedMaterialization::CreateDecompressProjection(unique_ptr &op, - CompressedMaterializationInfo &info) { - const auto bindings = op->GetColumnBindings(); - op->ResolveOperatorTypes(); - const auto &types = op->types; - - // Create decompress expressions for everything we compressed - auto &binding_map = info.binding_map; - vector> decompress_exprs; - vector> statistics; - for (idx_t col_idx = 0; col_idx < bindings.size(); col_idx++) { - const auto &binding = bindings[col_idx]; - auto decompress_expr = make_uniq_base(types[col_idx], binding); - optional_ptr stats; - for (auto &entry : binding_map) { - auto &binding_info = entry.second; - if (binding_info.binding != binding) { - continue; - } - stats = binding_info.stats.get(); - if (binding_info.needs_decompression) { - decompress_expr = 
GetDecompressExpression(std::move(decompress_expr), binding_info.type, *stats); - } - } - statistics.push_back(stats); - decompress_exprs.emplace_back(std::move(decompress_expr)); - } - - // Replace op with a projection - const auto table_index = binder.GenerateTableIndex(); - auto decompress_projection = make_uniq(table_index, std::move(decompress_exprs)); - decompression_table_indices.insert(table_index); - - decompress_projection->children.emplace_back(std::move(op)); - op = std::move(decompress_projection); - - // Check if we're placing a projection on top of the root - if (op->children[0].get() == root.get()) { - root = op.get(); - return; - } - - // Get the new bindings and types - auto new_bindings = op->GetColumnBindings(); - op->ResolveOperatorTypes(); - auto &new_types = op->types; - - // Initialize a ColumnBindingReplacer with the new bindings and types - ColumnBindingReplacer replacer; - auto &replacement_bindings = replacer.replacement_bindings; - for (idx_t col_idx = 0; col_idx < bindings.size(); col_idx++) { - const auto &old_binding = bindings[col_idx]; - const auto &new_binding = new_bindings[col_idx]; - const auto &new_type = new_types[col_idx]; - replacement_bindings.emplace_back(old_binding, new_binding, new_type); - - if (statistics[col_idx]) { - statistics_map[new_binding] = statistics[col_idx]->ToUnique(); - } - } - - // Make sure we skip the decompress operator when replacing bindings - replacer.stop_operator = op.get(); - - // Make the plan consistent again - replacer.VisitOperator(*root); -} - -unique_ptr CompressedMaterialization::GetCompressExpression(const ColumnBinding &binding, - const LogicalType &type, - const bool &can_compress) { - auto it = statistics_map.find(binding); - if (can_compress && it != statistics_map.end() && it->second) { - auto input = make_uniq(type, binding); - const auto &stats = *it->second; - return GetCompressExpression(std::move(input), stats); - } - return nullptr; -} - -unique_ptr 
CompressedMaterialization::GetCompressExpression(unique_ptr input, - const BaseStatistics &stats) { - const auto &type = input->return_type; - if (type != stats.GetType()) { // LCOV_EXCL_START - return nullptr; - } // LCOV_EXCL_STOP - if (type.IsIntegral()) { - return GetIntegralCompress(std::move(input), stats); - } else if (type.id() == LogicalTypeId::VARCHAR) { - return GetStringCompress(std::move(input), stats); - } - return nullptr; -} - -static Value GetIntegralRangeValue(ClientContext &context, const LogicalType &type, const BaseStatistics &stats) { - auto min = NumericStats::Min(stats); - auto max = NumericStats::Max(stats); - - vector> arguments; - arguments.emplace_back(make_uniq(max)); - arguments.emplace_back(make_uniq(min)); - BoundFunctionExpression sub(type, SubtractFun::GetFunction(type, type), std::move(arguments), nullptr); - - Value result; - if (ExpressionExecutor::TryEvaluateScalar(context, sub, result)) { - return result; - } else { - // Couldn't evaluate: Return max hugeint as range so GetIntegralCompress will return nullptr - return Value::HUGEINT(NumericLimits::Maximum()); - } -} - -unique_ptr CompressedMaterialization::GetIntegralCompress(unique_ptr input, - const BaseStatistics &stats) { - const auto &type = input->return_type; - if (GetTypeIdSize(type.InternalType()) == 1 || !NumericStats::HasMinMax(stats)) { - return nullptr; - } - - // Get range and cast to UBIGINT (might fail for HUGEINT, in which case we just return) - Value range_value = GetIntegralRangeValue(context, type, stats); - if (!range_value.DefaultTryCastAs(LogicalType::UBIGINT)) { - return nullptr; - } - - // Get the smallest type that the range can fit into - const auto range = UBigIntValue::Get(range_value); - LogicalType cast_type; - if (range <= NumericLimits().Maximum()) { - cast_type = LogicalType::UTINYINT; - } else if (range <= NumericLimits().Maximum()) { - cast_type = LogicalType::USMALLINT; - } else if (range <= NumericLimits().Maximum()) { - cast_type = 
LogicalType::UINTEGER; - } else { - D_ASSERT(range <= NumericLimits().Maximum()); - cast_type = LogicalType::UBIGINT; - } - - // Check if type that fits the range is smaller than the input type - if (GetTypeIdSize(cast_type.InternalType()) == GetTypeIdSize(type.InternalType())) { - return nullptr; - } - D_ASSERT(GetTypeIdSize(cast_type.InternalType()) < GetTypeIdSize(type.InternalType())); - - // Compressing will yield a benefit - auto compress_function = CMIntegralCompressFun::GetFunction(type, cast_type); - vector> arguments; - arguments.emplace_back(std::move(input)); - arguments.emplace_back(make_uniq(NumericStats::Min(stats))); - auto compress_expr = - make_uniq(cast_type, compress_function, std::move(arguments), nullptr); - - auto compress_stats = BaseStatistics::CreateEmpty(cast_type); - compress_stats.CopyBase(stats); - NumericStats::SetMin(compress_stats, Value(0).DefaultCastAs(cast_type)); - NumericStats::SetMax(compress_stats, range_value.DefaultCastAs(cast_type)); - - return make_uniq(std::move(compress_expr), compress_stats.ToUnique()); -} - -unique_ptr CompressedMaterialization::GetStringCompress(unique_ptr input, - const BaseStatistics &stats) { - if (!StringStats::HasMaxStringLength(stats)) { - return nullptr; - } - - const auto max_string_length = StringStats::MaxStringLength(stats); - LogicalType cast_type = LogicalType::INVALID; - for (const auto &compressed_type : CompressedMaterializationFunctions::StringTypes()) { - if (max_string_length < GetTypeIdSize(compressed_type.InternalType())) { - cast_type = compressed_type; - break; - } - } - if (cast_type == LogicalType::INVALID) { - return nullptr; - } - - auto compress_stats = BaseStatistics::CreateEmpty(cast_type); - compress_stats.CopyBase(stats); - if (cast_type.id() == LogicalTypeId::USMALLINT) { - auto min_string = StringStats::Min(stats); - auto max_string = StringStats::Max(stats); - - uint8_t min_numeric = 0; - if (max_string_length != 0 && min_string.length() != 0) { - min_numeric = 
*reinterpret_cast(min_string.c_str()); - } - uint8_t max_numeric = 0; - if (max_string_length != 0 && max_string.length() != 0) { - max_numeric = *reinterpret_cast(max_string.c_str()); - } - - Value min_val = Value::USMALLINT(min_numeric); - Value max_val = Value::USMALLINT(max_numeric + 1); - if (max_numeric < NumericLimits::Maximum()) { - cast_type = LogicalType::UTINYINT; - compress_stats = BaseStatistics::CreateEmpty(cast_type); - compress_stats.CopyBase(stats); - min_val = Value::UTINYINT(min_numeric); - max_val = Value::UTINYINT(max_numeric + 1); - } - - NumericStats::SetMin(compress_stats, min_val); - NumericStats::SetMax(compress_stats, max_val); - } - - auto compress_function = CMStringCompressFun::GetFunction(cast_type); - vector> arguments; - arguments.emplace_back(std::move(input)); - auto compress_expr = - make_uniq(cast_type, compress_function, std::move(arguments), nullptr); - return make_uniq(std::move(compress_expr), compress_stats.ToUnique()); -} - -unique_ptr CompressedMaterialization::GetDecompressExpression(unique_ptr input, - const LogicalType &result_type, - const BaseStatistics &stats) { - const auto &type = result_type; - if (TypeIsIntegral(type.InternalType())) { - return GetIntegralDecompress(std::move(input), result_type, stats); - } else if (type.id() == LogicalTypeId::VARCHAR) { - return GetStringDecompress(std::move(input), stats); - } else { - throw InternalException("Type other than integral/string marked for decompression!"); - } -} - -unique_ptr CompressedMaterialization::GetIntegralDecompress(unique_ptr input, - const LogicalType &result_type, - const BaseStatistics &stats) { - D_ASSERT(NumericStats::HasMinMax(stats)); - auto decompress_function = CMIntegralDecompressFun::GetFunction(input->return_type, result_type); - vector> arguments; - arguments.emplace_back(std::move(input)); - arguments.emplace_back(make_uniq(NumericStats::Min(stats))); - return make_uniq(result_type, decompress_function, std::move(arguments), nullptr); -} 
- -unique_ptr CompressedMaterialization::GetStringDecompress(unique_ptr input, - const BaseStatistics &stats) { - D_ASSERT(StringStats::HasMaxStringLength(stats)); - auto decompress_function = CMStringDecompressFun::GetFunction(input->return_type); - vector> arguments; - arguments.emplace_back(std::move(input)); - return make_uniq(decompress_function.return_type, decompress_function, - std::move(arguments), nullptr); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -//! The CSENode contains information about a common subexpression; how many times it occurs, and the column index in the -//! underlying projection -struct CSENode { - idx_t count; - idx_t column_index; - - CSENode() : count(1), column_index(DConstants::INVALID_INDEX) { - } -}; - -//! The CSEReplacementState -struct CSEReplacementState { - //! The projection index of the new projection - idx_t projection_index; - //! Map of expression -> CSENode - expression_map_t expression_count; - //! Map of column bindings to column indexes in the projection expression list - column_binding_map_t column_map; - //! The set of expressions of the resulting projection - vector> expressions; - //! 
Cached expressions that are kept around so the expression_map always contains valid expressions - vector> cached_expressions; -}; - -void CommonSubExpressionOptimizer::VisitOperator(LogicalOperator &op) { - switch (op.type) { - case LogicalOperatorType::LOGICAL_PROJECTION: - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - ExtractCommonSubExpresions(op); - break; - default: - break; - } - LogicalOperatorVisitor::VisitOperator(op); -} - -void CommonSubExpressionOptimizer::CountExpressions(Expression &expr, CSEReplacementState &state) { - // we only consider expressions with children for CSE elimination - switch (expr.expression_class) { - case ExpressionClass::BOUND_COLUMN_REF: - case ExpressionClass::BOUND_CONSTANT: - case ExpressionClass::BOUND_PARAMETER: - // skip conjunctions and case, since short-circuiting might be incorrectly disabled otherwise - case ExpressionClass::BOUND_CONJUNCTION: - case ExpressionClass::BOUND_CASE: - return; - default: - break; - } - if (expr.expression_class != ExpressionClass::BOUND_AGGREGATE && !expr.HasSideEffects()) { - // we can't move aggregates to a projection, so we only consider the children of the aggregate - auto node = state.expression_count.find(expr); - if (node == state.expression_count.end()) { - // first time we encounter this expression, insert this node with [count = 1] - state.expression_count[expr] = CSENode(); - } else { - // we encountered this expression before, increment the occurrence count - node->second.count++; - } - } - // recursively count the children - ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { CountExpressions(child, state); }); -} - -void CommonSubExpressionOptimizer::PerformCSEReplacement(unique_ptr &expr_ptr, CSEReplacementState &state) { - Expression &expr = *expr_ptr; - if (expr.expression_class == ExpressionClass::BOUND_COLUMN_REF) { - auto &bound_column_ref = expr.Cast(); - // bound column ref, check if this one has already been recorded in the expression 
list - auto column_entry = state.column_map.find(bound_column_ref.binding); - if (column_entry == state.column_map.end()) { - // not there yet: push the expression - idx_t new_column_index = state.expressions.size(); - state.column_map[bound_column_ref.binding] = new_column_index; - state.expressions.push_back(make_uniq( - bound_column_ref.alias, bound_column_ref.return_type, bound_column_ref.binding)); - bound_column_ref.binding = ColumnBinding(state.projection_index, new_column_index); - } else { - // else: just update the column binding! - bound_column_ref.binding = ColumnBinding(state.projection_index, column_entry->second); - } - return; - } - // check if this child is eligible for CSE elimination - bool can_cse = expr.expression_class != ExpressionClass::BOUND_CONJUNCTION && - expr.expression_class != ExpressionClass::BOUND_CASE; - if (can_cse && state.expression_count.find(expr) != state.expression_count.end()) { - auto &node = state.expression_count[expr]; - if (node.count > 1) { - // this expression occurs more than once! 
push it into the projection - // check if it has already been pushed into the projection - auto alias = expr.alias; - auto type = expr.return_type; - if (node.column_index == DConstants::INVALID_INDEX) { - // has not been pushed yet: push it - node.column_index = state.expressions.size(); - state.expressions.push_back(std::move(expr_ptr)); - } else { - state.cached_expressions.push_back(std::move(expr_ptr)); - } - // replace the original expression with a bound column ref - expr_ptr = make_uniq(alias, type, - ColumnBinding(state.projection_index, node.column_index)); - return; - } - } - // this expression only occurs once, we can't perform CSE elimination - // look into the children to see if we can replace them - ExpressionIterator::EnumerateChildren(expr, - [&](unique_ptr &child) { PerformCSEReplacement(child, state); }); -} - -void CommonSubExpressionOptimizer::ExtractCommonSubExpresions(LogicalOperator &op) { - D_ASSERT(op.children.size() == 1); - - // first we count for each expression with children how many types it occurs - CSEReplacementState state; - LogicalOperatorVisitor::EnumerateExpressions( - op, [&](unique_ptr *child) { CountExpressions(**child, state); }); - // check if there are any expressions to extract - bool perform_replacement = false; - for (auto &expr : state.expression_count) { - if (expr.second.count > 1) { - perform_replacement = true; - break; - } - } - if (!perform_replacement) { - // no CSEs to extract - return; - } - state.projection_index = binder.GenerateTableIndex(); - // we found common subexpressions to extract - // now we iterate over all the expressions and perform the actual CSE elimination - - LogicalOperatorVisitor::EnumerateExpressions( - op, [&](unique_ptr *child) { PerformCSEReplacement(*child, state); }); - D_ASSERT(state.expressions.size() > 0); - // create a projection node as the child of this node - auto projection = make_uniq(state.projection_index, std::move(state.expressions)); - 
projection->children.push_back(std::move(op.children[0])); - op.children[0] = std::move(projection); -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -struct DelimCandidate { -public: - explicit DelimCandidate(unique_ptr &op, LogicalComparisonJoin &delim_join) - : op(op), delim_join(delim_join), delim_get_count(0) { - } - -public: - unique_ptr &op; - LogicalComparisonJoin &delim_join; - vector>> joins; - idx_t delim_get_count; -}; - -static bool IsEqualityJoinCondition(const JoinCondition &cond) { - switch (cond.comparison) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - return true; - default: - return false; - } -} - -unique_ptr Deliminator::Optimize(unique_ptr op) { - root = op; - - vector candidates; - FindCandidates(op, candidates); - - for (auto &candidate : candidates) { - auto &delim_join = candidate.delim_join; - - bool all_removed = true; - bool all_equality_conditions = true; - for (auto &join : candidate.joins) { - all_removed = - RemoveJoinWithDelimGet(delim_join, candidate.delim_get_count, join, all_equality_conditions) && - all_removed; - } - - // Change type if there are no more duplicate-eliminated columns - if (candidate.joins.size() == candidate.delim_get_count && all_removed) { - delim_join.type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN; - delim_join.duplicate_eliminated_columns.clear(); - if (all_equality_conditions) { - for (auto &cond : delim_join.conditions) { - if (IsEqualityJoinCondition(cond)) { - cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; - } - } - } - } - } - - return op; -} - -void Deliminator::FindCandidates(unique_ptr &op, vector &candidates) { - // Search children before adding, so the deepest candidates get added first - for (auto &child : op->children) { - FindCandidates(child, candidates); - } - - if (op->type != LogicalOperatorType::LOGICAL_DELIM_JOIN) { - return; - } - - candidates.emplace_back(op, op->Cast()); - auto &candidate = 
candidates.back(); - - // DelimGets are in the RHS - FindJoinWithDelimGet(op->children[1], candidate); -} - -static bool OperatorIsDelimGet(LogicalOperator &op) { - if (op.type == LogicalOperatorType::LOGICAL_DELIM_GET) { - return true; - } - if (op.type == LogicalOperatorType::LOGICAL_FILTER && - op.children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET) { - return true; - } - return false; -} - -void Deliminator::FindJoinWithDelimGet(unique_ptr &op, DelimCandidate &candidate) { - if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { - FindJoinWithDelimGet(op->children[0], candidate); - } else if (op->type == LogicalOperatorType::LOGICAL_DELIM_GET) { - candidate.delim_get_count++; - } else { - for (auto &child : op->children) { - FindJoinWithDelimGet(child, candidate); - } - } - - if (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN && - (OperatorIsDelimGet(*op->children[0]) || OperatorIsDelimGet(*op->children[1]))) { - candidate.joins.emplace_back(op); - } -} - -static bool ChildJoinTypeCanBeDeliminated(JoinType &join_type) { - switch (join_type) { - case JoinType::INNER: - case JoinType::SEMI: - return true; - default: - return false; - } -} - -bool Deliminator::RemoveJoinWithDelimGet(LogicalComparisonJoin &delim_join, const idx_t delim_get_count, - unique_ptr &join, bool &all_equality_conditions) { - auto &comparison_join = join->Cast(); - if (!ChildJoinTypeCanBeDeliminated(comparison_join.join_type)) { - return false; - } - - // Get the index (left or right) of the DelimGet side of the join - const idx_t delim_idx = OperatorIsDelimGet(*join->children[0]) ? 0 : 1; - - // Get the filter (if any) - optional_ptr filter; - vector> filter_expressions; - if (join->children[delim_idx]->type == LogicalOperatorType::LOGICAL_FILTER) { - filter = &join->children[delim_idx]->Cast(); - for (auto &expr : filter->expressions) { - filter_expressions.emplace_back(expr->Copy()); - } - } - - auto &delim_get = (filter ? 
filter->children[0] : join->children[delim_idx])->Cast(); - if (comparison_join.conditions.size() != delim_get.chunk_types.size()) { - return false; // Joining with DelimGet adds new information - } - - // Check if joining with the DelimGet is redundant, and collect relevant column information - ColumnBindingReplacer replacer; - auto &replacement_bindings = replacer.replacement_bindings; - for (auto &cond : comparison_join.conditions) { - all_equality_conditions = all_equality_conditions && IsEqualityJoinCondition(cond); - auto &delim_side = delim_idx == 0 ? *cond.left : *cond.right; - auto &other_side = delim_idx == 0 ? *cond.right : *cond.left; - if (delim_side.type != ExpressionType::BOUND_COLUMN_REF || - other_side.type != ExpressionType::BOUND_COLUMN_REF) { - return false; - } - auto &delim_colref = delim_side.Cast(); - auto &other_colref = other_side.Cast(); - replacement_bindings.emplace_back(delim_colref.binding, other_colref.binding); - - if (cond.comparison != ExpressionType::COMPARE_NOT_DISTINCT_FROM) { - auto is_not_null_expr = - make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); - is_not_null_expr->children.push_back(other_side.Copy()); - filter_expressions.push_back(std::move(is_not_null_expr)); - } - } - - if (!all_equality_conditions && - !RemoveInequalityJoinWithDelimGet(delim_join, delim_get_count, join, replacement_bindings)) { - return false; - } - - unique_ptr replacement_op = std::move(comparison_join.children[1 - delim_idx]); - if (!filter_expressions.empty()) { // Create filter if necessary - auto new_filter = make_uniq(); - new_filter->expressions = std::move(filter_expressions); - new_filter->children.emplace_back(std::move(replacement_op)); - replacement_op = std::move(new_filter); - } - - join = std::move(replacement_op); - - // TODO: Maybe go from delim join instead to save work - replacer.VisitOperator(*root); - return true; -} - -static bool InequalityDelimJoinCanBeEliminated(JoinType &join_type) { - return 
join_type == JoinType::ANTI || join_type == JoinType::MARK || join_type == JoinType::SEMI || - join_type == JoinType::SINGLE; -} - -bool FindAndReplaceBindings(vector &traced_bindings, const vector> &expressions, - const vector ¤t_bindings) { - for (auto &binding : traced_bindings) { - idx_t current_idx; - for (current_idx = 0; current_idx < expressions.size(); current_idx++) { - if (binding == current_bindings[current_idx]) { - break; - } - } - - if (current_idx == expressions.size() || expressions[current_idx]->type != ExpressionType::BOUND_COLUMN_REF) { - return false; // Didn't find / can't deal with non-colref - } - - auto &colref = expressions[current_idx]->Cast(); - binding = colref.binding; - } - return true; -} - -bool Deliminator::RemoveInequalityJoinWithDelimGet(LogicalComparisonJoin &delim_join, const idx_t delim_get_count, - unique_ptr &join, - const vector &replacement_bindings) { - auto &comparison_join = join->Cast(); - auto &delim_conditions = delim_join.conditions; - const auto &join_conditions = comparison_join.conditions; - if (delim_get_count != 1 || !InequalityDelimJoinCanBeEliminated(delim_join.join_type) || - delim_conditions.size() != join_conditions.size()) { - return false; - } - - // TODO: we cannot perform the optimization here because our pure inequality joins don't implement - // JoinType::SINGLE yet - if (delim_join.join_type == JoinType::SINGLE) { - bool has_one_equality = false; - for (auto &cond : join_conditions) { - has_one_equality = has_one_equality || IsEqualityJoinCondition(cond); - } - if (!has_one_equality) { - return false; - } - } - - // We only support colref's - vector traced_bindings; - for (const auto &cond : delim_conditions) { - if (cond.right->type != ExpressionType::BOUND_COLUMN_REF) { - return false; - } - auto &colref = cond.right->Cast(); - traced_bindings.emplace_back(colref.binding); - } - - // Now we trace down the bindings to the join (for now, we only trace it through a few operators) - reference 
current_op = *delim_join.children[1]; - while (¤t_op.get() != join.get()) { - if (current_op.get().children.size() != 1) { - return false; - } - - switch (current_op.get().type) { - case LogicalOperatorType::LOGICAL_PROJECTION: - FindAndReplaceBindings(traced_bindings, current_op.get().expressions, current_op.get().GetColumnBindings()); - break; - case LogicalOperatorType::LOGICAL_FILTER: - break; // Doesn't change bindings - default: - return false; - } - current_op = *current_op.get().children[0]; - } - - // Get the index (left or right) of the DelimGet side of the join - const idx_t delim_idx = OperatorIsDelimGet(*join->children[0]) ? 0 : 1; - - bool found_all = true; - for (idx_t cond_idx = 0; cond_idx < delim_conditions.size(); cond_idx++) { - auto &delim_condition = delim_conditions[cond_idx]; - const auto &traced_binding = traced_bindings[cond_idx]; - - bool found = false; - for (auto &join_condition : join_conditions) { - auto &delim_side = delim_idx == 0 ? *join_condition.left : *join_condition.right; - auto &colref = delim_side.Cast(); - if (colref.binding == traced_binding) { - delim_condition.comparison = FlipComparisonExpression(join_condition.comparison); - found = true; - break; - } - } - found_all = found_all && found; - } - - return found_all; -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr ExpressionHeuristics::Rewrite(unique_ptr op) { - VisitOperator(*op); - return op; -} - -void ExpressionHeuristics::VisitOperator(LogicalOperator &op) { - if (op.type == LogicalOperatorType::LOGICAL_FILTER) { - // reorder all filter expressions - if (op.expressions.size() > 1) { - ReorderExpressions(op.expressions); - } - } - - // traverse recursively through the operator tree - VisitOperatorChildren(op); - VisitOperatorExpressions(op); -} - -unique_ptr ExpressionHeuristics::VisitReplace(BoundConjunctionExpression &expr, - unique_ptr *expr_ptr) { - ReorderExpressions(expr.children); - return nullptr; -} - -void 
ExpressionHeuristics::ReorderExpressions(vector> &expressions) { - - struct ExpressionCosts { - unique_ptr expr; - idx_t cost; - - bool operator==(const ExpressionCosts &p) const { - return cost == p.cost; - } - bool operator<(const ExpressionCosts &p) const { - return cost < p.cost; - } - }; - - vector expression_costs; - expression_costs.reserve(expressions.size()); - // iterate expressions, get cost for each one - for (idx_t i = 0; i < expressions.size(); i++) { - idx_t cost = Cost(*expressions[i]); - expression_costs.push_back({std::move(expressions[i]), cost}); - } - - // sort by cost and put back in place - sort(expression_costs.begin(), expression_costs.end()); - for (idx_t i = 0; i < expression_costs.size(); i++) { - expressions[i] = std::move(expression_costs[i].expr); - } -} - -idx_t ExpressionHeuristics::ExpressionCost(BoundBetweenExpression &expr) { - return Cost(*expr.input) + Cost(*expr.lower) + Cost(*expr.upper) + 10; -} - -idx_t ExpressionHeuristics::ExpressionCost(BoundCaseExpression &expr) { - // CASE WHEN check THEN result_if_true ELSE result_if_false END - idx_t case_cost = 0; - for (auto &case_check : expr.case_checks) { - case_cost += Cost(*case_check.then_expr); - case_cost += Cost(*case_check.when_expr); - } - case_cost += Cost(*expr.else_expr); - return case_cost; -} - -idx_t ExpressionHeuristics::ExpressionCost(BoundCastExpression &expr) { - // OPERATOR_CAST - // determine cast cost by comparing cast_expr.source_type and cast_expr_target_type - idx_t cast_cost = 0; - if (expr.return_type != expr.source_type()) { - // if cast from or to varchar - // TODO: we might want to add more cases - if (expr.return_type.id() == LogicalTypeId::VARCHAR || expr.source_type().id() == LogicalTypeId::VARCHAR || - expr.return_type.id() == LogicalTypeId::BLOB || expr.source_type().id() == LogicalTypeId::BLOB) { - cast_cost = 200; - } else { - cast_cost = 5; - } - } - return Cost(*expr.child) + cast_cost; -} - -idx_t 
ExpressionHeuristics::ExpressionCost(BoundComparisonExpression &expr) { - // COMPARE_EQUAL, COMPARE_NOTEQUAL, COMPARE_GREATERTHAN, COMPARE_GREATERTHANOREQUALTO, COMPARE_LESSTHAN, - // COMPARE_LESSTHANOREQUALTO - return Cost(*expr.left) + 5 + Cost(*expr.right); -} - -idx_t ExpressionHeuristics::ExpressionCost(BoundConjunctionExpression &expr) { - // CONJUNCTION_AND, CONJUNCTION_OR - idx_t cost = 5; - for (auto &child : expr.children) { - cost += Cost(*child); - } - return cost; -} - -idx_t ExpressionHeuristics::ExpressionCost(BoundFunctionExpression &expr) { - idx_t cost_children = 0; - for (auto &child : expr.children) { - cost_children += Cost(*child); - } - - auto cost_function = function_costs.find(expr.function.name); - if (cost_function != function_costs.end()) { - return cost_children + cost_function->second; - } else { - return cost_children + 1000; - } -} - -idx_t ExpressionHeuristics::ExpressionCost(BoundOperatorExpression &expr, ExpressionType &expr_type) { - idx_t sum = 0; - for (auto &child : expr.children) { - sum += Cost(*child); - } - - // OPERATOR_IS_NULL, OPERATOR_IS_NOT_NULL - if (expr_type == ExpressionType::OPERATOR_IS_NULL || expr_type == ExpressionType::OPERATOR_IS_NOT_NULL) { - return sum + 5; - } else if (expr_type == ExpressionType::COMPARE_IN || expr_type == ExpressionType::COMPARE_NOT_IN) { - // COMPARE_IN, COMPARE_NOT_IN - return sum + (expr.children.size() - 1) * 100; - } else if (expr_type == ExpressionType::OPERATOR_NOT) { - // OPERATOR_NOT - return sum + 10; // TODO: evaluate via measured runtimes - } else { - return sum + 1000; - } -} - -idx_t ExpressionHeuristics::ExpressionCost(PhysicalType return_type, idx_t multiplier) { - // TODO: ajust values according to benchmark results - switch (return_type) { - case PhysicalType::VARCHAR: - return 5 * multiplier; - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - return 2 * multiplier; - default: - return 1 * multiplier; - } -} - -idx_t ExpressionHeuristics::Cost(Expression &expr) 
{ - switch (expr.expression_class) { - case ExpressionClass::BOUND_CASE: { - auto &case_expr = expr.Cast(); - return ExpressionCost(case_expr); - } - case ExpressionClass::BOUND_BETWEEN: { - auto &between_expr = expr.Cast(); - return ExpressionCost(between_expr); - } - case ExpressionClass::BOUND_CAST: { - auto &cast_expr = expr.Cast(); - return ExpressionCost(cast_expr); - } - case ExpressionClass::BOUND_COMPARISON: { - auto &comp_expr = expr.Cast(); - return ExpressionCost(comp_expr); - } - case ExpressionClass::BOUND_CONJUNCTION: { - auto &conj_expr = expr.Cast(); - return ExpressionCost(conj_expr); - } - case ExpressionClass::BOUND_FUNCTION: { - auto &func_expr = expr.Cast(); - return ExpressionCost(func_expr); - } - case ExpressionClass::BOUND_OPERATOR: { - auto &op_expr = expr.Cast(); - return ExpressionCost(op_expr, expr.type); - } - case ExpressionClass::BOUND_COLUMN_REF: { - auto &col_expr = expr.Cast(); - return ExpressionCost(col_expr.return_type.InternalType(), 8); - } - case ExpressionClass::BOUND_CONSTANT: { - auto &const_expr = expr.Cast(); - return ExpressionCost(const_expr.return_type.InternalType(), 1); - } - case ExpressionClass::BOUND_PARAMETER: { - auto &const_expr = expr.Cast(); - return ExpressionCost(const_expr.return_type.InternalType(), 1); - } - case ExpressionClass::BOUND_REF: { - auto &col_expr = expr.Cast(); - return ExpressionCost(col_expr.return_type.InternalType(), 8); - } - default: { - break; - } - } - - // return a very high value if nothing matches - return 1000; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -unique_ptr ExpressionRewriter::ApplyRules(LogicalOperator &op, const vector> &rules, - unique_ptr expr, bool &changes_made, bool is_root) { - for (auto &rule : rules) { - vector> bindings; - if (rule.get().root->Match(*expr, bindings)) { - // the rule matches! 
try to apply it - bool rule_made_change = false; - auto result = rule.get().Apply(op, bindings, rule_made_change, is_root); - if (result) { - changes_made = true; - // the base node changed: the rule applied changes - // rerun on the new node - return ExpressionRewriter::ApplyRules(op, rules, std::move(result), changes_made); - } else if (rule_made_change) { - changes_made = true; - // the base node didn't change, but changes were made, rerun - return expr; - } - // else nothing changed, continue to the next rule - continue; - } - } - // no changes could be made to this node - // recursively run on the children of this node - ExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child) { - child = ExpressionRewriter::ApplyRules(op, rules, std::move(child), changes_made); - }); - return expr; -} - -unique_ptr ExpressionRewriter::ConstantOrNull(unique_ptr child, Value value) { - vector> children; - children.push_back(make_uniq(value)); - children.push_back(std::move(child)); - return ConstantOrNull(std::move(children), std::move(value)); -} - -unique_ptr ExpressionRewriter::ConstantOrNull(vector> children, Value value) { - auto type = value.type(); - children.insert(children.begin(), make_uniq(value)); - return make_uniq(type, ConstantOrNull::GetFunction(type), std::move(children), - ConstantOrNull::Bind(std::move(value))); -} - -void ExpressionRewriter::VisitOperator(LogicalOperator &op) { - VisitOperatorChildren(op); - this->op = &op; - - to_apply_rules.clear(); - for (auto &rule : rules) { - if (rule->logical_root && !rule->logical_root->Match(op.type)) { - // this rule does not apply to this type of LogicalOperator - continue; - } - to_apply_rules.push_back(*rule); - } - if (to_apply_rules.empty()) { - // no rules to apply on this node - return; - } - - VisitOperatorExpressions(op); - - // if it is a LogicalFilter, we split up filter conjunctions again - if (op.type == LogicalOperatorType::LOGICAL_FILTER) { - auto &filter = op.Cast(); - 
filter.SplitPredicates(); - } -} - -void ExpressionRewriter::VisitExpression(unique_ptr *expression) { - bool changes_made; - do { - changes_made = false; - *expression = ExpressionRewriter::ApplyRules(*op, to_apply_rules, std::move(*expression), changes_made, true); - } while (changes_made); -} - -ClientContext &Rule::GetContext() const { - return rewriter.context; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - -namespace duckdb { - -using ExpressionValueInformation = FilterCombiner::ExpressionValueInformation; - -ValueComparisonResult CompareValueInformation(ExpressionValueInformation &left, ExpressionValueInformation &right); - -FilterCombiner::FilterCombiner(ClientContext &context) : context(context) { -} - -FilterCombiner::FilterCombiner(Optimizer &optimizer) : FilterCombiner(optimizer.context) { -} - -Expression &FilterCombiner::GetNode(Expression &expr) { - auto entry = stored_expressions.find(expr); - if (entry != stored_expressions.end()) { - // expression already exists: return a reference to the stored expression - return *entry->second; - } - // expression does not exist yet: create a copy and store it - auto copy = expr.Copy(); - auto ©_ref = *copy; - D_ASSERT(stored_expressions.find(copy_ref) == stored_expressions.end()); - stored_expressions[copy_ref] = std::move(copy); - return copy_ref; -} - -idx_t FilterCombiner::GetEquivalenceSet(Expression &expr) { - D_ASSERT(stored_expressions.find(expr) != stored_expressions.end()); - D_ASSERT(stored_expressions.find(expr)->second.get() == &expr); - auto entry = equivalence_set_map.find(expr); - if (entry == equivalence_set_map.end()) { - idx_t index = set_index++; - equivalence_set_map[expr] = index; - equivalence_map[index].push_back(expr); - constant_values.insert(make_pair(index, vector())); - return index; - } else { - return entry->second; - } -} - -FilterResult FilterCombiner::AddConstantComparison(vector &info_list, - ExpressionValueInformation info) { - if (info.constant.IsNull()) { - 
return FilterResult::UNSATISFIABLE; - } - for (idx_t i = 0; i < info_list.size(); i++) { - auto comparison = CompareValueInformation(info_list[i], info); - switch (comparison) { - case ValueComparisonResult::PRUNE_LEFT: - // prune the entry from the info list - info_list.erase(info_list.begin() + i); - i--; - break; - case ValueComparisonResult::PRUNE_RIGHT: - // prune the current info - return FilterResult::SUCCESS; - case ValueComparisonResult::UNSATISFIABLE_CONDITION: - // combination of filters is unsatisfiable: prune the entire branch - return FilterResult::UNSATISFIABLE; - default: - // prune nothing, move to the next condition - break; - } - } - // finally add the entry to the list - info_list.push_back(info); - return FilterResult::SUCCESS; -} - -FilterResult FilterCombiner::AddFilter(unique_ptr expr) { - // LookUpConjunctions(expr.get()); - // try to push the filter into the combiner - auto result = AddFilter(*expr); - if (result == FilterResult::UNSUPPORTED) { - // unsupported filter, push into remaining filters - remaining_filters.push_back(std::move(expr)); - return FilterResult::SUCCESS; - } - return result; -} - -void FilterCombiner::GenerateFilters(const std::function filter)> &callback) { - // first loop over the remaining filters - for (auto &filter : remaining_filters) { - callback(std::move(filter)); - } - remaining_filters.clear(); - // now loop over the equivalence sets - for (auto &entry : equivalence_map) { - auto equivalence_set = entry.first; - auto &entries = entry.second; - auto &constant_list = constant_values.find(equivalence_set)->second; - // for each entry generate an equality expression comparing to each other - for (idx_t i = 0; i < entries.size(); i++) { - for (idx_t k = i + 1; k < entries.size(); k++) { - auto comparison = make_uniq( - ExpressionType::COMPARE_EQUAL, entries[i].get().Copy(), entries[k].get().Copy()); - callback(std::move(comparison)); - } - // for each entry also create a comparison with each constant - int 
lower_index = -1; - int upper_index = -1; - bool lower_inclusive = false; - bool upper_inclusive = false; - for (idx_t k = 0; k < constant_list.size(); k++) { - auto &info = constant_list[k]; - if (info.comparison_type == ExpressionType::COMPARE_GREATERTHAN || - info.comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO) { - lower_index = k; - lower_inclusive = info.comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO; - } else if (info.comparison_type == ExpressionType::COMPARE_LESSTHAN || - info.comparison_type == ExpressionType::COMPARE_LESSTHANOREQUALTO) { - upper_index = k; - upper_inclusive = info.comparison_type == ExpressionType::COMPARE_LESSTHANOREQUALTO; - } else { - auto constant = make_uniq(info.constant); - auto comparison = make_uniq( - info.comparison_type, entries[i].get().Copy(), std::move(constant)); - callback(std::move(comparison)); - } - } - if (lower_index >= 0 && upper_index >= 0) { - // found both lower and upper index, create a BETWEEN expression - auto lower_constant = make_uniq(constant_list[lower_index].constant); - auto upper_constant = make_uniq(constant_list[upper_index].constant); - auto between = - make_uniq(entries[i].get().Copy(), std::move(lower_constant), - std::move(upper_constant), lower_inclusive, upper_inclusive); - callback(std::move(between)); - } else if (lower_index >= 0) { - // only lower index found, create simple comparison expression - auto constant = make_uniq(constant_list[lower_index].constant); - auto comparison = make_uniq(constant_list[lower_index].comparison_type, - entries[i].get().Copy(), std::move(constant)); - callback(std::move(comparison)); - } else if (upper_index >= 0) { - // only upper index found, create simple comparison expression - auto constant = make_uniq(constant_list[upper_index].constant); - auto comparison = make_uniq(constant_list[upper_index].comparison_type, - entries[i].get().Copy(), std::move(constant)); - callback(std::move(comparison)); - } - } - } - 
stored_expressions.clear(); - equivalence_set_map.clear(); - constant_values.clear(); - equivalence_map.clear(); -} - -bool FilterCombiner::HasFilters() { - bool has_filters = false; - GenerateFilters([&](unique_ptr child) { has_filters = true; }); - return has_filters; -} - -// unordered_map> MergeAnd(unordered_map> &f_1, -// unordered_map> &f_2) { -// unordered_map> result; -// for (auto &f : f_1) { -// auto it = f_2.find(f.first); -// if (it == f_2.end()) { -// result[f.first] = f.second; -// } else { -// Value *min = nullptr, *max = nullptr; -// if (it->second.first && f.second.first) { -// if (*f.second.first > *it->second.first) { -// min = f.second.first; -// } else { -// min = it->second.first; -// } - -// } else if (it->second.first) { -// min = it->second.first; -// } else if (f.second.first) { -// min = f.second.first; -// } else { -// min = nullptr; -// } -// if (it->second.second && f.second.second) { -// if (*f.second.second < *it->second.second) { -// max = f.second.second; -// } else { -// max = it->second.second; -// } -// } else if (it->second.second) { -// max = it->second.second; -// } else if (f.second.second) { -// max = f.second.second; -// } else { -// max = nullptr; -// } -// result[f.first] = {min, max}; -// f_2.erase(f.first); -// } -// } -// for (auto &f : f_2) { -// result[f.first] = f.second; -// } -// return result; -// } - -// unordered_map> MergeOr(unordered_map> &f_1, -// unordered_map> &f_2) { -// unordered_map> result; -// for (auto &f : f_1) { -// auto it = f_2.find(f.first); -// if (it != f_2.end()) { -// Value *min = nullptr, *max = nullptr; -// if (it->second.first && f.second.first) { -// if (*f.second.first < *it->second.first) { -// min = f.second.first; -// } else { -// min = it->second.first; -// } -// } -// if (it->second.second && f.second.second) { -// if (*f.second.second > *it->second.second) { -// max = f.second.second; -// } else { -// max = it->second.second; -// } -// } -// result[f.first] = {min, max}; -// 
f_2.erase(f.first); -// } -// } -// return result; -// } - -// unordered_map> -// FilterCombiner::FindZonemapChecks(vector &column_ids, unordered_set ¬_constants, Expression *filter) -// { unordered_map> checks; switch (filter->type) { case -// ExpressionType::CONJUNCTION_OR: { -// //! For a filter to -// auto &or_exp = filter->Cast(); -// checks = FindZonemapChecks(column_ids, not_constants, or_exp.children[0].get()); -// for (size_t i = 1; i < or_exp.children.size(); ++i) { -// auto child_check = FindZonemapChecks(column_ids, not_constants, or_exp.children[i].get()); -// checks = MergeOr(checks, child_check); -// } -// return checks; -// } -// case ExpressionType::CONJUNCTION_AND: { -// auto &and_exp = filter->Cast(); -// checks = FindZonemapChecks(column_ids, not_constants, and_exp.children[0].get()); -// for (size_t i = 1; i < and_exp.children.size(); ++i) { -// auto child_check = FindZonemapChecks(column_ids, not_constants, and_exp.children[i].get()); -// checks = MergeAnd(checks, child_check); -// } -// return checks; -// } -// case ExpressionType::COMPARE_IN: { -// auto &comp_in_exp = filter->Cast(); -// if (comp_in_exp.children[0]->type == ExpressionType::BOUND_COLUMN_REF) { -// Value *min = nullptr, *max = nullptr; -// auto &column_ref = comp_in_exp.children[0]->Cast(); -// for (size_t i {1}; i < comp_in_exp.children.size(); i++) { -// if (comp_in_exp.children[i]->type != ExpressionType::VALUE_CONSTANT) { -// //! 
This indicates the column has a comparison that is not with a constant -// not_constants.insert(column_ids[column_ref.binding.column_index]); -// break; -// } else { -// auto &const_value_expr = comp_in_exp.children[i]->Cast(); -// if (const_value_expr.value.IsNull()) { -// return checks; -// } -// if (!min && !max) { -// min = &const_value_expr.value; -// max = min; -// } else { -// if (*min > const_value_expr.value) { -// min = &const_value_expr.value; -// } -// if (*max < const_value_expr.value) { -// max = &const_value_expr.value; -// } -// } -// } -// } -// checks[column_ids[column_ref.binding.column_index]] = {min, max}; -// } -// return checks; -// } -// case ExpressionType::COMPARE_EQUAL: { -// auto &comp_exp = filter->Cast(); -// if ((comp_exp.left->expression_class == ExpressionClass::BOUND_COLUMN_REF && -// comp_exp.right->expression_class == ExpressionClass::BOUND_CONSTANT)) { -// auto &column_ref = comp_exp.left->Cast(); -// auto &constant_value_expr = comp_exp.right->Cast(); -// checks[column_ids[column_ref.binding.column_index]] = {&constant_value_expr.value, -// &constant_value_expr.value}; -// } -// if ((comp_exp.left->expression_class == ExpressionClass::BOUND_CONSTANT && -// comp_exp.right->expression_class == ExpressionClass::BOUND_COLUMN_REF)) { -// auto &column_ref = comp_exp.right->Cast(); -// auto &constant_value_expr = comp_exp.left->Cast(); -// checks[column_ids[column_ref.binding.column_index]] = {&constant_value_expr.value, -// &constant_value_expr.value}; -// } -// return checks; -// } -// case ExpressionType::COMPARE_LESSTHAN: -// case ExpressionType::COMPARE_LESSTHANOREQUALTO: { -// auto &comp_exp = filter->Cast(); -// if ((comp_exp.left->expression_class == ExpressionClass::BOUND_COLUMN_REF && -// comp_exp.right->expression_class == ExpressionClass::BOUND_CONSTANT)) { -// auto &column_ref = comp_exp.left->Cast(); -// auto &constant_value_expr = comp_exp.right->Cast(); -// checks[column_ids[column_ref.binding.column_index]] = 
{nullptr, &constant_value_expr.value}; -// } -// if ((comp_exp.left->expression_class == ExpressionClass::BOUND_CONSTANT && -// comp_exp.right->expression_class == ExpressionClass::BOUND_COLUMN_REF)) { -// auto &column_ref = comp_exp.right->Cast(); -// auto &constant_value_expr = comp_exp.left->Cast(); -// checks[column_ids[column_ref.binding.column_index]] = {&constant_value_expr.value, nullptr}; -// } -// return checks; -// } -// case ExpressionType::COMPARE_GREATERTHANOREQUALTO: -// case ExpressionType::COMPARE_GREATERTHAN: { -// auto &comp_exp = filter->Cast(); -// if ((comp_exp.left->expression_class == ExpressionClass::BOUND_COLUMN_REF && -// comp_exp.right->expression_class == ExpressionClass::BOUND_CONSTANT)) { -// auto &column_ref = comp_exp.left->Cast(); -// auto &constant_value_expr = comp_exp.right->Cast(); -// checks[column_ids[column_ref.binding.column_index]] = {&constant_value_expr.value, nullptr}; -// } -// if ((comp_exp.left->expression_class == ExpressionClass::BOUND_CONSTANT && -// comp_exp.right->expression_class == ExpressionClass::BOUND_COLUMN_REF)) { -// auto &column_ref = comp_exp.right->Cast(); -// auto &constant_value_expr = comp_exp.left->Cast(); -// checks[column_ids[column_ref.binding.column_index]] = {nullptr, &constant_value_expr.value}; -// } -// return checks; -// } -// default: -// return checks; -// } -// } - -// vector FilterCombiner::GenerateZonemapChecks(vector &column_ids, -// vector &pushed_filters) { -// vector zonemap_checks; -// unordered_set not_constants; -// //! We go through the remaining filters and capture their min max -// if (remaining_filters.empty()) { -// return zonemap_checks; -// } - -// auto checks = FindZonemapChecks(column_ids, not_constants, remaining_filters[0].get()); -// for (size_t i = 1; i < remaining_filters.size(); ++i) { -// auto child_check = FindZonemapChecks(column_ids, not_constants, remaining_filters[i].get()); -// checks = MergeAnd(checks, child_check); -// } -// //! 
We construct the equivalent filters -// for (auto not_constant : not_constants) { -// checks.erase(not_constant); -// } -// for (const auto &pushed_filter : pushed_filters) { -// checks.erase(column_ids[pushed_filter.column_index]); -// } -// for (const auto &check : checks) { -// if (check.second.first) { -// zonemap_checks.emplace_back(check.second.first->Copy(), ExpressionType::COMPARE_GREATERTHANOREQUALTO, -// check.first); -// } -// if (check.second.second) { -// zonemap_checks.emplace_back(check.second.second->Copy(), ExpressionType::COMPARE_LESSTHANOREQUALTO, -// check.first); -// } -// } -// return zonemap_checks; -// } - -TableFilterSet FilterCombiner::GenerateTableScanFilters(vector &column_ids) { - TableFilterSet table_filters; - //! First, we figure the filters that have constant expressions that we can push down to the table scan - for (auto &constant_value : constant_values) { - if (!constant_value.second.empty()) { - auto filter_exp = equivalence_map.end(); - if ((constant_value.second[0].comparison_type == ExpressionType::COMPARE_EQUAL || - constant_value.second[0].comparison_type == ExpressionType::COMPARE_GREATERTHAN || - constant_value.second[0].comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO || - constant_value.second[0].comparison_type == ExpressionType::COMPARE_LESSTHAN || - constant_value.second[0].comparison_type == ExpressionType::COMPARE_LESSTHANOREQUALTO) && - (TypeIsNumeric(constant_value.second[0].constant.type().InternalType()) || - constant_value.second[0].constant.type().InternalType() == PhysicalType::VARCHAR || - constant_value.second[0].constant.type().InternalType() == PhysicalType::BOOL)) { - //! 
Here we check if these filters are column references - filter_exp = equivalence_map.find(constant_value.first); - if (filter_exp->second.size() == 1 && - filter_exp->second[0].get().type == ExpressionType::BOUND_COLUMN_REF) { - auto &filter_col_exp = filter_exp->second[0].get().Cast(); - auto column_index = column_ids[filter_col_exp.binding.column_index]; - if (column_index == COLUMN_IDENTIFIER_ROW_ID) { - break; - } - auto equivalence_set = filter_exp->first; - auto &entries = filter_exp->second; - auto &constant_list = constant_values.find(equivalence_set)->second; - // for each entry generate an equality expression comparing to each other - for (idx_t i = 0; i < entries.size(); i++) { - // for each entry also create a comparison with each constant - for (idx_t k = 0; k < constant_list.size(); k++) { - auto constant_filter = make_uniq(constant_value.second[k].comparison_type, - constant_value.second[k].constant); - table_filters.PushFilter(column_index, std::move(constant_filter)); - } - table_filters.PushFilter(column_index, make_uniq()); - } - equivalence_map.erase(filter_exp); - } - } - } - } - //! Here we look for LIKE or IN filters - for (idx_t rem_fil_idx = 0; rem_fil_idx < remaining_filters.size(); rem_fil_idx++) { - auto &remaining_filter = remaining_filters[rem_fil_idx]; - if (remaining_filter->expression_class == ExpressionClass::BOUND_FUNCTION) { - auto &func = remaining_filter->Cast(); - if (func.function.name == "prefix" && - func.children[0]->expression_class == ExpressionClass::BOUND_COLUMN_REF && - func.children[1]->type == ExpressionType::VALUE_CONSTANT) { - //! This is a like function. - auto &column_ref = func.children[0]->Cast(); - auto &constant_value_expr = func.children[1]->Cast(); - auto like_string = StringValue::Get(constant_value_expr.value); - if (like_string.empty()) { - continue; - } - auto column_index = column_ids[column_ref.binding.column_index]; - //! 
Here the like must be transformed to a BOUND COMPARISON geq le - auto lower_bound = - make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, Value(like_string)); - like_string[like_string.size() - 1]++; - auto upper_bound = make_uniq(ExpressionType::COMPARE_LESSTHAN, Value(like_string)); - table_filters.PushFilter(column_index, std::move(lower_bound)); - table_filters.PushFilter(column_index, std::move(upper_bound)); - table_filters.PushFilter(column_index, make_uniq()); - } - if (func.function.name == "~~" && func.children[0]->expression_class == ExpressionClass::BOUND_COLUMN_REF && - func.children[1]->type == ExpressionType::VALUE_CONSTANT) { - //! This is a like function. - auto &column_ref = func.children[0]->Cast(); - auto &constant_value_expr = func.children[1]->Cast(); - auto &like_string = StringValue::Get(constant_value_expr.value); - if (like_string[0] == '%' || like_string[0] == '_') { - //! We have no prefix so nothing to pushdown - break; - } - string prefix; - bool equality = true; - for (char const &c : like_string) { - if (c == '%' || c == '_') { - equality = false; - break; - } - prefix += c; - } - auto column_index = column_ids[column_ref.binding.column_index]; - if (equality) { - //! Here the like can be transformed to an equality query - auto equal_filter = make_uniq(ExpressionType::COMPARE_EQUAL, Value(prefix)); - table_filters.PushFilter(column_index, std::move(equal_filter)); - table_filters.PushFilter(column_index, make_uniq()); - } else { - //! 
Here the like must be transformed to a BOUND COMPARISON geq le - auto lower_bound = - make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, Value(prefix)); - prefix[prefix.size() - 1]++; - auto upper_bound = make_uniq(ExpressionType::COMPARE_LESSTHAN, Value(prefix)); - table_filters.PushFilter(column_index, std::move(lower_bound)); - table_filters.PushFilter(column_index, std::move(upper_bound)); - table_filters.PushFilter(column_index, make_uniq()); - } - } - } else if (remaining_filter->type == ExpressionType::COMPARE_IN) { - auto &func = remaining_filter->Cast(); - vector in_values; - D_ASSERT(func.children.size() > 1); - if (func.children[0]->expression_class != ExpressionClass::BOUND_COLUMN_REF) { - continue; - } - auto &column_ref = func.children[0]->Cast(); - auto column_index = column_ids[column_ref.binding.column_index]; - if (column_index == COLUMN_IDENTIFIER_ROW_ID) { - break; - } - //! check if all children are const expr - bool children_constant = true; - for (size_t i {1}; i < func.children.size(); i++) { - if (func.children[i]->type != ExpressionType::VALUE_CONSTANT) { - children_constant = false; - } - } - if (!children_constant) { - continue; - } - auto &fst_const_value_expr = func.children[1]->Cast(); - auto &type = fst_const_value_expr.value.type(); - - //! Check if values are consecutive, if yes transform them to >= <= (only for integers) - // e.g. 
if we have x IN (1, 2, 3, 4, 5) we transform this into x >= 1 AND x <= 5 - if (!type.IsIntegral()) { - continue; - } - - bool can_simplify_in_clause = true; - for (idx_t i = 1; i < func.children.size(); i++) { - auto &const_value_expr = func.children[i]->Cast(); - if (const_value_expr.value.IsNull()) { - can_simplify_in_clause = false; - break; - } - in_values.push_back(const_value_expr.value.GetValue()); - } - if (!can_simplify_in_clause || in_values.empty()) { - continue; - } - - sort(in_values.begin(), in_values.end()); - - for (idx_t in_val_idx = 1; in_val_idx < in_values.size(); in_val_idx++) { - if (in_values[in_val_idx] - in_values[in_val_idx - 1] > 1) { - can_simplify_in_clause = false; - break; - } - } - if (!can_simplify_in_clause) { - continue; - } - auto lower_bound = make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, - Value::Numeric(type, in_values.front())); - auto upper_bound = make_uniq(ExpressionType::COMPARE_LESSTHANOREQUALTO, - Value::Numeric(type, in_values.back())); - table_filters.PushFilter(column_index, std::move(lower_bound)); - table_filters.PushFilter(column_index, std::move(upper_bound)); - table_filters.PushFilter(column_index, make_uniq()); - - remaining_filters.erase(remaining_filters.begin() + rem_fil_idx); - } - } - - // GenerateORFilters(table_filters, column_ids); - - return table_filters; -} - -static bool IsGreaterThan(ExpressionType type) { - return type == ExpressionType::COMPARE_GREATERTHAN || type == ExpressionType::COMPARE_GREATERTHANOREQUALTO; -} - -static bool IsLessThan(ExpressionType type) { - return type == ExpressionType::COMPARE_LESSTHAN || type == ExpressionType::COMPARE_LESSTHANOREQUALTO; -} - -FilterResult FilterCombiner::AddBoundComparisonFilter(Expression &expr) { - auto &comparison = expr.Cast(); - if (comparison.type != ExpressionType::COMPARE_LESSTHAN && - comparison.type != ExpressionType::COMPARE_LESSTHANOREQUALTO && - comparison.type != ExpressionType::COMPARE_GREATERTHAN && - comparison.type != 
ExpressionType::COMPARE_GREATERTHANOREQUALTO && - comparison.type != ExpressionType::COMPARE_EQUAL && comparison.type != ExpressionType::COMPARE_NOTEQUAL) { - // only support [>, >=, <, <=, ==, !=] expressions - return FilterResult::UNSUPPORTED; - } - // check if one of the sides is a scalar value - bool left_is_scalar = comparison.left->IsFoldable(); - bool right_is_scalar = comparison.right->IsFoldable(); - if (left_is_scalar || right_is_scalar) { - // comparison with scalar - auto &node = GetNode(left_is_scalar ? *comparison.right : *comparison.left); - idx_t equivalence_set = GetEquivalenceSet(node); - auto &scalar = left_is_scalar ? comparison.left : comparison.right; - Value constant_value; - if (!ExpressionExecutor::TryEvaluateScalar(context, *scalar, constant_value)) { - return FilterResult::UNSATISFIABLE; - } - if (constant_value.IsNull()) { - // comparisons with null are always null (i.e. will never result in rows) - return FilterResult::UNSATISFIABLE; - } - - // create the ExpressionValueInformation - ExpressionValueInformation info; - info.comparison_type = left_is_scalar ? FlipComparisonExpression(comparison.type) : comparison.type; - info.constant = constant_value; - - // get the current bucket of constant values - D_ASSERT(constant_values.find(equivalence_set) != constant_values.end()); - auto &info_list = constant_values.find(equivalence_set)->second; - D_ASSERT(node.return_type == info.constant.type()); - // check the existing constant comparisons to see if we can do any pruning - auto ret = AddConstantComparison(info_list, info); - - auto &non_scalar = left_is_scalar ? 
*comparison.right : *comparison.left; - auto transitive_filter = FindTransitiveFilter(non_scalar); - if (transitive_filter != nullptr) { - // try to add transitive filters - if (AddTransitiveFilters(transitive_filter->Cast()) == - FilterResult::UNSUPPORTED) { - // in case of unsuccessful re-add filter into remaining ones - remaining_filters.push_back(std::move(transitive_filter)); - } - } - return ret; - } else { - // comparison between two non-scalars - // only handle comparisons for now - if (expr.type != ExpressionType::COMPARE_EQUAL) { - if (IsGreaterThan(expr.type) || IsLessThan(expr.type)) { - return AddTransitiveFilters(comparison); - } - return FilterResult::UNSUPPORTED; - } - // get the LHS and RHS nodes - auto &left_node = GetNode(*comparison.left); - auto &right_node = GetNode(*comparison.right); - if (left_node.Equals(right_node)) { - return FilterResult::UNSUPPORTED; - } - // get the equivalence sets of the LHS and RHS - auto left_equivalence_set = GetEquivalenceSet(left_node); - auto right_equivalence_set = GetEquivalenceSet(right_node); - if (left_equivalence_set == right_equivalence_set) { - // this equality filter already exists, prune it - return FilterResult::SUCCESS; - } - // add the right bucket into the left bucket - D_ASSERT(equivalence_map.find(left_equivalence_set) != equivalence_map.end()); - D_ASSERT(equivalence_map.find(right_equivalence_set) != equivalence_map.end()); - - auto &left_bucket = equivalence_map.find(left_equivalence_set)->second; - auto &right_bucket = equivalence_map.find(right_equivalence_set)->second; - for (auto &right_expr : right_bucket) { - // rewrite the equivalence set mapping for this node - equivalence_set_map[right_expr] = left_equivalence_set; - // add the node to the left bucket - left_bucket.push_back(right_expr); - } - // now add all constant values from the right bucket to the left bucket - D_ASSERT(constant_values.find(left_equivalence_set) != constant_values.end()); - 
D_ASSERT(constant_values.find(right_equivalence_set) != constant_values.end()); - auto &left_constant_bucket = constant_values.find(left_equivalence_set)->second; - auto &right_constant_bucket = constant_values.find(right_equivalence_set)->second; - for (auto &right_constant : right_constant_bucket) { - if (AddConstantComparison(left_constant_bucket, right_constant) == FilterResult::UNSATISFIABLE) { - return FilterResult::UNSATISFIABLE; - } - } - } - return FilterResult::SUCCESS; -} - -FilterResult FilterCombiner::AddFilter(Expression &expr) { - if (expr.HasParameter()) { - return FilterResult::UNSUPPORTED; - } - if (expr.IsFoldable()) { - // scalar condition, evaluate it - Value result; - if (!ExpressionExecutor::TryEvaluateScalar(context, expr, result)) { - return FilterResult::UNSUPPORTED; - } - result = result.DefaultCastAs(LogicalType::BOOLEAN); - // check if the filter passes - if (result.IsNull() || !BooleanValue::Get(result)) { - // the filter does not pass the scalar test, create an empty result - return FilterResult::UNSATISFIABLE; - } else { - // the filter passes the scalar test, just remove the condition - return FilterResult::SUCCESS; - } - } - D_ASSERT(!expr.IsFoldable()); - if (expr.GetExpressionClass() == ExpressionClass::BOUND_BETWEEN) { - auto &comparison = expr.Cast(); - //! check if one of the sides is a scalar value - bool lower_is_scalar = comparison.lower->IsFoldable(); - bool upper_is_scalar = comparison.upper->IsFoldable(); - if (lower_is_scalar || upper_is_scalar) { - //! 
comparison with scalar - break apart - auto &node = GetNode(*comparison.input); - idx_t equivalence_set = GetEquivalenceSet(node); - auto result = FilterResult::UNSATISFIABLE; - - if (lower_is_scalar) { - auto scalar = comparison.lower.get(); - Value constant_value; - if (!ExpressionExecutor::TryEvaluateScalar(context, *scalar, constant_value)) { - return FilterResult::UNSUPPORTED; - } - - // create the ExpressionValueInformation - ExpressionValueInformation info; - if (comparison.lower_inclusive) { - info.comparison_type = ExpressionType::COMPARE_GREATERTHANOREQUALTO; - } else { - info.comparison_type = ExpressionType::COMPARE_GREATERTHAN; - } - info.constant = constant_value; - - // get the current bucket of constant values - D_ASSERT(constant_values.find(equivalence_set) != constant_values.end()); - auto &info_list = constant_values.find(equivalence_set)->second; - // check the existing constant comparisons to see if we can do any pruning - result = AddConstantComparison(info_list, info); - } else { - D_ASSERT(upper_is_scalar); - const auto type = comparison.upper_inclusive ? 
ExpressionType::COMPARE_LESSTHANOREQUALTO - : ExpressionType::COMPARE_LESSTHAN; - auto left = comparison.lower->Copy(); - auto right = comparison.input->Copy(); - auto lower_comp = make_uniq(type, std::move(left), std::move(right)); - result = AddBoundComparisonFilter(*lower_comp); - } - - // Stop if we failed - if (result != FilterResult::SUCCESS) { - return result; - } - - if (upper_is_scalar) { - auto scalar = comparison.upper.get(); - Value constant_value; - if (!ExpressionExecutor::TryEvaluateScalar(context, *scalar, constant_value)) { - return FilterResult::UNSUPPORTED; - } - - // create the ExpressionValueInformation - ExpressionValueInformation info; - if (comparison.upper_inclusive) { - info.comparison_type = ExpressionType::COMPARE_LESSTHANOREQUALTO; - } else { - info.comparison_type = ExpressionType::COMPARE_LESSTHAN; - } - info.constant = constant_value; - - // get the current bucket of constant values - D_ASSERT(constant_values.find(equivalence_set) != constant_values.end()); - // check the existing constant comparisons to see if we can do any pruning - result = AddConstantComparison(constant_values.find(equivalence_set)->second, info); - } else { - D_ASSERT(lower_is_scalar); - const auto type = comparison.upper_inclusive ? 
ExpressionType::COMPARE_LESSTHANOREQUALTO - : ExpressionType::COMPARE_LESSTHAN; - auto left = comparison.input->Copy(); - auto right = comparison.upper->Copy(); - auto upper_comp = make_uniq(type, std::move(left), std::move(right)); - result = AddBoundComparisonFilter(*upper_comp); - } - - return result; - } - } else if (expr.GetExpressionClass() == ExpressionClass::BOUND_COMPARISON) { - return AddBoundComparisonFilter(expr); - } - // only comparisons supported for now - return FilterResult::UNSUPPORTED; -} - -/* - * Create and add new transitive filters from a two non-scalar filter such as j > i, j >= i, j < i, and j <= i - * It's missing to create another method to add transitive filters from scalar filters, e.g, i > 10 - */ -FilterResult FilterCombiner::AddTransitiveFilters(BoundComparisonExpression &comparison) { - D_ASSERT(IsGreaterThan(comparison.type) || IsLessThan(comparison.type)); - // get the LHS and RHS nodes - auto &left_node = GetNode(*comparison.left); - reference right_node = GetNode(*comparison.right); - // In case with filters like CAST(i) = j and i = 5 we replace the COLUMN_REF i with the constant 5 - if (right_node.get().type == ExpressionType::OPERATOR_CAST) { - auto &bound_cast_expr = right_node.get().Cast(); - if (bound_cast_expr.child->type == ExpressionType::BOUND_COLUMN_REF) { - auto &col_ref = bound_cast_expr.child->Cast(); - for (auto &stored_exp : stored_expressions) { - if (stored_exp.first.get().type == ExpressionType::BOUND_COLUMN_REF) { - auto &st_col_ref = stored_exp.second->Cast(); - if (st_col_ref.binding == col_ref.binding && - bound_cast_expr.return_type == stored_exp.second->return_type) { - bound_cast_expr.child = stored_exp.second->Copy(); - right_node = GetNode(*bound_cast_expr.child); - break; - } - } - } - } - } - - if (left_node.Equals(right_node)) { - return FilterResult::UNSUPPORTED; - } - // get the equivalence sets of the LHS and RHS - idx_t left_equivalence_set = GetEquivalenceSet(left_node); - idx_t 
right_equivalence_set = GetEquivalenceSet(right_node); - if (left_equivalence_set == right_equivalence_set) { - // this equality filter already exists, prune it - return FilterResult::SUCCESS; - } - - vector &left_constants = constant_values.find(left_equivalence_set)->second; - vector &right_constants = constant_values.find(right_equivalence_set)->second; - bool is_successful = false; - bool is_inserted = false; - // read every constant filters already inserted for the right scalar variable - // and see if we can create new transitive filters, e.g., there is already a filter i > 10, - // suppose that we have now the j >= i, then we can infer a new filter j > 10 - for (const auto &right_constant : right_constants) { - ExpressionValueInformation info; - info.constant = right_constant.constant; - // there is already an equality filter, e.g., i = 10 - if (right_constant.comparison_type == ExpressionType::COMPARE_EQUAL) { - // create filter j [>, >=, <, <=] 10 - // suppose the new comparison is j >= i and we have already a filter i = 10, - // then we create a new filter j >= 10 - // and the filter j >= i can be pruned by not adding it into the remaining filters - info.comparison_type = comparison.type; - } else if ((comparison.type == ExpressionType::COMPARE_GREATERTHANOREQUALTO && - IsGreaterThan(right_constant.comparison_type)) || - (comparison.type == ExpressionType::COMPARE_LESSTHANOREQUALTO && - IsLessThan(right_constant.comparison_type))) { - // filters (j >= i AND i [>, >=] 10) OR (j <= i AND i [<, <=] 10) - // create filter j [>, >=] 10 and add the filter j [>=, <=] i into the remaining filters - info.comparison_type = right_constant.comparison_type; // create filter j [>, >=, <, <=] 10 - if (!is_inserted) { - // Add the filter j >= i in the remaing filters - auto filter = make_uniq(comparison.type, comparison.left->Copy(), - comparison.right->Copy()); - remaining_filters.push_back(std::move(filter)); - is_inserted = true; - } - } else if ((comparison.type == 
ExpressionType::COMPARE_GREATERTHAN && - IsGreaterThan(right_constant.comparison_type)) || - (comparison.type == ExpressionType::COMPARE_LESSTHAN && - IsLessThan(right_constant.comparison_type))) { - // filters (j > i AND i [>, >=] 10) OR j < i AND i [<, <=] 10 - // create filter j [>, <] 10 and add the filter j [>, <] i into the remaining filters - // the comparisons j > i and j < i are more restrictive - info.comparison_type = comparison.type; - if (!is_inserted) { - // Add the filter j [>, <] i - auto filter = make_uniq(comparison.type, comparison.left->Copy(), - comparison.right->Copy()); - remaining_filters.push_back(std::move(filter)); - is_inserted = true; - } - } else { - // we cannot add a new filter - continue; - } - // Add the new filer into the left set - if (AddConstantComparison(left_constants, info) == FilterResult::UNSATISFIABLE) { - return FilterResult::UNSATISFIABLE; - } - is_successful = true; - } - if (is_successful) { - // now check for remaining trasitive filters from the left column - auto transitive_filter = FindTransitiveFilter(*comparison.left); - if (transitive_filter != nullptr) { - // try to add transitive filters - if (AddTransitiveFilters(transitive_filter->Cast()) == - FilterResult::UNSUPPORTED) { - // in case of unsuccessful re-add filter into remaining ones - remaining_filters.push_back(std::move(transitive_filter)); - } - } - return FilterResult::SUCCESS; - } - - return FilterResult::UNSUPPORTED; -} - -/* - * Find a transitive filter already inserted into the remaining filters - * Check for a match between the right column of bound comparisons and the expression, - * then removes the bound comparison from the remaining filters and returns it - */ -unique_ptr FilterCombiner::FindTransitiveFilter(Expression &expr) { - // We only check for bound column ref - if (expr.type != ExpressionType::BOUND_COLUMN_REF) { - return nullptr; - } - for (idx_t i = 0; i < remaining_filters.size(); i++) { - if 
(remaining_filters[i]->GetExpressionClass() == ExpressionClass::BOUND_COMPARISON) { - auto &comparison = remaining_filters[i]->Cast(); - if (expr.Equals(*comparison.right) && comparison.type != ExpressionType::COMPARE_NOTEQUAL) { - auto filter = std::move(remaining_filters[i]); - remaining_filters.erase(remaining_filters.begin() + i); - return filter; - } - } - } - return nullptr; -} - -ValueComparisonResult InvertValueComparisonResult(ValueComparisonResult result) { - if (result == ValueComparisonResult::PRUNE_RIGHT) { - return ValueComparisonResult::PRUNE_LEFT; - } - if (result == ValueComparisonResult::PRUNE_LEFT) { - return ValueComparisonResult::PRUNE_RIGHT; - } - return result; -} - -ValueComparisonResult CompareValueInformation(ExpressionValueInformation &left, ExpressionValueInformation &right) { - if (left.comparison_type == ExpressionType::COMPARE_EQUAL) { - // left is COMPARE_EQUAL, we can either - // (1) prune the right side or - // (2) return UNSATISFIABLE - bool prune_right_side = false; - switch (right.comparison_type) { - case ExpressionType::COMPARE_LESSTHAN: - prune_right_side = left.constant < right.constant; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - prune_right_side = left.constant <= right.constant; - break; - case ExpressionType::COMPARE_GREATERTHAN: - prune_right_side = left.constant > right.constant; - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - prune_right_side = left.constant >= right.constant; - break; - case ExpressionType::COMPARE_NOTEQUAL: - prune_right_side = left.constant != right.constant; - break; - default: - D_ASSERT(right.comparison_type == ExpressionType::COMPARE_EQUAL); - prune_right_side = left.constant == right.constant; - break; - } - if (prune_right_side) { - return ValueComparisonResult::PRUNE_RIGHT; - } else { - return ValueComparisonResult::UNSATISFIABLE_CONDITION; - } - } else if (right.comparison_type == ExpressionType::COMPARE_EQUAL) { - // right is COMPARE_EQUAL - return 
InvertValueComparisonResult(CompareValueInformation(right, left)); - } else if (left.comparison_type == ExpressionType::COMPARE_NOTEQUAL) { - // left is COMPARE_NOTEQUAL, we can either - // (1) prune the left side or - // (2) not prune anything - bool prune_left_side = false; - switch (right.comparison_type) { - case ExpressionType::COMPARE_LESSTHAN: - prune_left_side = left.constant >= right.constant; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - prune_left_side = left.constant > right.constant; - break; - case ExpressionType::COMPARE_GREATERTHAN: - prune_left_side = left.constant <= right.constant; - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - prune_left_side = left.constant < right.constant; - break; - default: - D_ASSERT(right.comparison_type == ExpressionType::COMPARE_NOTEQUAL); - prune_left_side = left.constant == right.constant; - break; - } - if (prune_left_side) { - return ValueComparisonResult::PRUNE_LEFT; - } else { - return ValueComparisonResult::PRUNE_NOTHING; - } - } else if (right.comparison_type == ExpressionType::COMPARE_NOTEQUAL) { - return InvertValueComparisonResult(CompareValueInformation(right, left)); - } else if (IsGreaterThan(left.comparison_type) && IsGreaterThan(right.comparison_type)) { - // both comparisons are [>], we can either - // (1) prune the left side or - // (2) prune the right side - if (left.constant > right.constant) { - // left constant is more selective, prune right - return ValueComparisonResult::PRUNE_RIGHT; - } else if (left.constant < right.constant) { - // right constant is more selective, prune left - return ValueComparisonResult::PRUNE_LEFT; - } else { - // constants are equivalent - // however we can still have the scenario where one is [>=] and the other is [>] - // we want to prune the [>=] because [>] is more selective - // if left is [>=] we prune the left, else we prune the right - if (left.comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO) { - return 
ValueComparisonResult::PRUNE_LEFT; - } else { - return ValueComparisonResult::PRUNE_RIGHT; - } - } - } else if (IsLessThan(left.comparison_type) && IsLessThan(right.comparison_type)) { - // both comparisons are [<], we can either - // (1) prune the left side or - // (2) prune the right side - if (left.constant < right.constant) { - // left constant is more selective, prune right - return ValueComparisonResult::PRUNE_RIGHT; - } else if (left.constant > right.constant) { - // right constant is more selective, prune left - return ValueComparisonResult::PRUNE_LEFT; - } else { - // constants are equivalent - // however we can still have the scenario where one is [<=] and the other is [<] - // we want to prune the [<=] because [<] is more selective - // if left is [<=] we prune the left, else we prune the right - if (left.comparison_type == ExpressionType::COMPARE_LESSTHANOREQUALTO) { - return ValueComparisonResult::PRUNE_LEFT; - } else { - return ValueComparisonResult::PRUNE_RIGHT; - } - } - } else if (IsLessThan(left.comparison_type)) { - D_ASSERT(IsGreaterThan(right.comparison_type)); - // left is [<] and right is [>], in this case we can either - // (1) prune nothing or - // (2) return UNSATISFIABLE - // the SMALLER THAN constant has to be greater than the BIGGER THAN constant - if (left.constant >= right.constant) { - return ValueComparisonResult::PRUNE_NOTHING; - } else { - return ValueComparisonResult::UNSATISFIABLE_CONDITION; - } - } else { - // left is [>] and right is [<] or [!=] - D_ASSERT(IsLessThan(right.comparison_type) && IsGreaterThan(left.comparison_type)); - return InvertValueComparisonResult(CompareValueInformation(right, left)); - } -} -// -// void FilterCombiner::LookUpConjunctions(Expression *expr) { -// if (expr->GetExpressionType() == ExpressionType::CONJUNCTION_OR) { -// auto root_or_expr = (BoundConjunctionExpression *)expr; -// for (const auto &entry : map_col_conjunctions) { -// for (const auto &conjs_to_push : entry.second) { -// if 
(conjs_to_push->root_or->Equals(root_or_expr)) { -// return; -// } -// } -// } -// -// cur_root_or = root_or_expr; -// cur_conjunction = root_or_expr; -// cur_colref_to_push = nullptr; -// if (!BFSLookUpConjunctions(cur_root_or)) { -// if (cur_colref_to_push) { -// auto entry = map_col_conjunctions.find(cur_colref_to_push); -// auto &vec_conjs_to_push = entry->second; -// if (vec_conjs_to_push.size() == 1) { -// map_col_conjunctions.erase(entry); -// return; -// } -// vec_conjs_to_push.pop_back(); -// } -// } -// return; -// } -// -// // Verify if the expression has a column already pushed down by other OR expression -// VerifyOrsToPush(*expr); -//} -// -// bool FilterCombiner::BFSLookUpConjunctions(BoundConjunctionExpression *conjunction) { -// vector conjunctions_to_visit; -// -// for (auto &child : conjunction->children) { -// switch (child->GetExpressionClass()) { -// case ExpressionClass::BOUND_CONJUNCTION: { -// auto child_conjunction = (BoundConjunctionExpression *)child.get(); -// conjunctions_to_visit.emplace_back(child_conjunction); -// break; -// } -// case ExpressionClass::BOUND_COMPARISON: { -// if (!UpdateConjunctionFilter((BoundComparisonExpression *)child.get())) { -// return false; -// } -// break; -// } -// default: { -// return false; -// } -// } -// } -// -// for (auto child_conjunction : conjunctions_to_visit) { -// cur_conjunction = child_conjunction; -// // traverse child conjuction -// if (!BFSLookUpConjunctions(child_conjunction)) { -// return false; -// } -// } -// return true; -//} -// -// void FilterCombiner::VerifyOrsToPush(Expression &expr) { -// if (expr.type == ExpressionType::BOUND_COLUMN_REF) { -// auto colref = (BoundColumnRefExpression *)&expr; -// auto entry = map_col_conjunctions.find(colref); -// if (entry == map_col_conjunctions.end()) { -// return; -// } -// map_col_conjunctions.erase(entry); -// } -// ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { VerifyOrsToPush(child); }); -//} -// -// bool 
FilterCombiner::UpdateConjunctionFilter(BoundComparisonExpression *comparison_expr) { -// bool left_is_scalar = comparison_expr->left->IsFoldable(); -// bool right_is_scalar = comparison_expr->right->IsFoldable(); -// -// Expression *non_scalar_expr; -// if (left_is_scalar || right_is_scalar) { -// // only support comparison with scalar -// non_scalar_expr = left_is_scalar ? comparison_expr->right.get() : comparison_expr->left.get(); -// -// if (non_scalar_expr->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { -// return UpdateFilterByColumn((BoundColumnRefExpression *)non_scalar_expr, comparison_expr); -// } -// } -// -// return false; -//} -// -// bool FilterCombiner::UpdateFilterByColumn(BoundColumnRefExpression *column_ref, -// BoundComparisonExpression *comparison_expr) { -// if (cur_colref_to_push == nullptr) { -// cur_colref_to_push = column_ref; -// -// auto or_conjunction = make_uniq(ExpressionType::CONJUNCTION_OR); -// or_conjunction->children.emplace_back(comparison_expr->Copy()); -// -// unique_ptr conjs_to_push = make_uniq(); -// conjs_to_push->conjunctions.emplace_back(std::move(or_conjunction)); -// conjs_to_push->root_or = cur_root_or; -// -// auto &&vec_col_conjs = map_col_conjunctions[column_ref]; -// vec_col_conjs.emplace_back(std::move(conjs_to_push)); -// vec_colref_insertion_order.emplace_back(column_ref); -// return true; -// } -// -// auto entry = map_col_conjunctions.find(cur_colref_to_push); -// D_ASSERT(entry != map_col_conjunctions.end()); -// auto &conjunctions_to_push = entry->second.back(); -// -// if (!cur_colref_to_push->Equals(column_ref)) { -// // check for multiple colunms in the same root OR node -// if (cur_root_or == cur_conjunction) { -// return false; -// } -// // found an AND using a different column, we should stop the look up -// if (cur_conjunction->GetExpressionType() == ExpressionType::CONJUNCTION_AND) { -// return false; -// } -// -// // found a different column, AND conditions cannot be preserved anymore 
-// conjunctions_to_push->preserve_and = false; -// return true; -// } -// -// auto &last_conjunction = conjunctions_to_push->conjunctions.back(); -// if (cur_conjunction->GetExpressionType() == last_conjunction->GetExpressionType()) { -// last_conjunction->children.emplace_back(comparison_expr->Copy()); -// } else { -// auto new_conjunction = make_uniq(cur_conjunction->GetExpressionType()); -// new_conjunction->children.emplace_back(comparison_expr->Copy()); -// conjunctions_to_push->conjunctions.emplace_back(std::move(new_conjunction)); -// } -// return true; -//} -// -// void FilterCombiner::GenerateORFilters(TableFilterSet &table_filter, vector &column_ids) { -// for (const auto colref : vec_colref_insertion_order) { -// auto column_index = column_ids[colref->binding.column_index]; -// if (column_index == COLUMN_IDENTIFIER_ROW_ID) { -// break; -// } -// -// for (const auto &conjunctions_to_push : map_col_conjunctions[colref]) { -// // root OR filter to push into the TableFilter -// auto root_or_filter = make_uniq(); -// // variable to hold the last conjuntion filter pointer -// // the next filter will be added into it, i.e., we create a chain of conjunction filters -// ConjunctionFilter *last_conj_filter = root_or_filter.get(); -// -// for (auto &conjunction : conjunctions_to_push->conjunctions) { -// if (conjunction->GetExpressionType() == ExpressionType::CONJUNCTION_AND && -// conjunctions_to_push->preserve_and) { -// GenerateConjunctionFilter(conjunction.get(), last_conj_filter); -// } else { -// GenerateConjunctionFilter(conjunction.get(), last_conj_filter); -// } -// } -// table_filter.PushFilter(column_index, std::move(root_or_filter)); -// } -// } -// map_col_conjunctions.clear(); -// vec_colref_insertion_order.clear(); -//} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr FilterPullup::Rewrite(unique_ptr op) { - switch (op->type) { - case LogicalOperatorType::LOGICAL_FILTER: - return PullupFilter(std::move(op)); - case 
LogicalOperatorType::LOGICAL_PROJECTION: - return PullupProjection(std::move(op)); - case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: - return PullupCrossProduct(std::move(op)); - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - case LogicalOperatorType::LOGICAL_ANY_JOIN: - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - return PullupJoin(std::move(op)); - case LogicalOperatorType::LOGICAL_INTERSECT: - case LogicalOperatorType::LOGICAL_EXCEPT: - return PullupSetOperation(std::move(op)); - case LogicalOperatorType::LOGICAL_DISTINCT: - case LogicalOperatorType::LOGICAL_ORDER_BY: { - // we can just pull directly through these operations without any rewriting - op->children[0] = Rewrite(std::move(op->children[0])); - return op; - } - default: - return FinishPullup(std::move(op)); - } -} - -unique_ptr FilterPullup::PullupJoin(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || - op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN || op->type == LogicalOperatorType::LOGICAL_ANY_JOIN || - op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN); - auto &join = op->Cast(); - - switch (join.join_type) { - case JoinType::INNER: - return PullupInnerJoin(std::move(op)); - case JoinType::LEFT: - case JoinType::ANTI: - case JoinType::SEMI: { - return PullupFromLeft(std::move(op)); - } - default: - // unsupported join type: call children pull up - return FinishPullup(std::move(op)); - } -} - -unique_ptr FilterPullup::PullupInnerJoin(unique_ptr op) { - D_ASSERT(op->Cast().join_type == JoinType::INNER); - if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { - return op; - } - return PullupBothSide(std::move(op)); -} - -unique_ptr FilterPullup::PullupCrossProduct(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_CROSS_PRODUCT); - return PullupBothSide(std::move(op)); -} - -unique_ptr FilterPullup::GeneratePullupFilter(unique_ptr child, - vector> &expressions) { - unique_ptr 
filter = make_uniq(); - for (idx_t i = 0; i < expressions.size(); ++i) { - filter->expressions.push_back(std::move(expressions[i])); - } - expressions.clear(); - filter->children.push_back(std::move(child)); - return std::move(filter); -} - -unique_ptr FilterPullup::FinishPullup(unique_ptr op) { - // unhandled type, first perform filter pushdown in its children - for (idx_t i = 0; i < op->children.size(); i++) { - FilterPullup pullup; - op->children[i] = pullup.Rewrite(std::move(op->children[i])); - } - // now pull up any existing filters - if (filters_expr_pullup.empty()) { - // no filters to pull up - return op; - } - return GeneratePullupFilter(std::move(op), filters_expr_pullup); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -using Filter = FilterPushdown::Filter; - -FilterPushdown::FilterPushdown(Optimizer &optimizer) : optimizer(optimizer), combiner(optimizer.context) { -} - -unique_ptr FilterPushdown::Rewrite(unique_ptr op) { - D_ASSERT(!combiner.HasFilters()); - switch (op->type) { - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - return PushdownAggregate(std::move(op)); - case LogicalOperatorType::LOGICAL_FILTER: - return PushdownFilter(std::move(op)); - case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: - return PushdownCrossProduct(std::move(op)); - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - case LogicalOperatorType::LOGICAL_ANY_JOIN: - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - return PushdownJoin(std::move(op)); - case LogicalOperatorType::LOGICAL_PROJECTION: - return PushdownProjection(std::move(op)); - case LogicalOperatorType::LOGICAL_INTERSECT: - case LogicalOperatorType::LOGICAL_EXCEPT: - case LogicalOperatorType::LOGICAL_UNION: - return PushdownSetOperation(std::move(op)); - case LogicalOperatorType::LOGICAL_DISTINCT: - case LogicalOperatorType::LOGICAL_ORDER_BY: { - // we can just push directly through these operations without any rewriting - 
op->children[0] = Rewrite(std::move(op->children[0])); - return op; - } - case LogicalOperatorType::LOGICAL_GET: - return PushdownGet(std::move(op)); - case LogicalOperatorType::LOGICAL_LIMIT: - return PushdownLimit(std::move(op)); - default: - return FinishPushdown(std::move(op)); - } -} - -ClientContext &FilterPushdown::GetContext() { - return optimizer.GetContext(); -} - -unique_ptr FilterPushdown::PushdownJoin(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || - op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN || op->type == LogicalOperatorType::LOGICAL_ANY_JOIN || - op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN); - auto &join = op->Cast(); - if (!join.left_projection_map.empty() || !join.right_projection_map.empty()) { - // cannot push down further otherwise the projection maps won't be preserved - return FinishPushdown(std::move(op)); - } - - unordered_set left_bindings, right_bindings; - LogicalJoin::GetTableReferences(*op->children[0], left_bindings); - LogicalJoin::GetTableReferences(*op->children[1], right_bindings); - - switch (join.join_type) { - case JoinType::INNER: - return PushdownInnerJoin(std::move(op), left_bindings, right_bindings); - case JoinType::LEFT: - return PushdownLeftJoin(std::move(op), left_bindings, right_bindings); - case JoinType::MARK: - return PushdownMarkJoin(std::move(op), left_bindings, right_bindings); - case JoinType::SINGLE: - return PushdownSingleJoin(std::move(op), left_bindings, right_bindings); - default: - // unsupported join type: stop pushing down - return FinishPushdown(std::move(op)); - } -} -void FilterPushdown::PushFilters() { - for (auto &f : filters) { - auto result = combiner.AddFilter(std::move(f->filter)); - D_ASSERT(result != FilterResult::UNSUPPORTED); - (void)result; - } - filters.clear(); -} - -FilterResult FilterPushdown::AddFilter(unique_ptr expr) { - PushFilters(); - // split up the filters by AND predicate - vector> expressions; - 
expressions.push_back(std::move(expr)); - LogicalFilter::SplitPredicates(expressions); - // push the filters into the combiner - for (auto &child_expr : expressions) { - if (combiner.AddFilter(std::move(child_expr)) == FilterResult::UNSATISFIABLE) { - return FilterResult::UNSATISFIABLE; - } - } - return FilterResult::SUCCESS; -} - -void FilterPushdown::GenerateFilters() { - if (!filters.empty()) { - D_ASSERT(!combiner.HasFilters()); - return; - } - combiner.GenerateFilters([&](unique_ptr filter) { - auto f = make_uniq(); - f->filter = std::move(filter); - f->ExtractBindings(); - filters.push_back(std::move(f)); - }); -} - -unique_ptr FilterPushdown::AddLogicalFilter(unique_ptr op, - vector> expressions) { - if (expressions.empty()) { - // No left expressions, so needn't to add an extra filter operator. - return op; - } - auto filter = make_uniq(); - filter->expressions = std::move(expressions); - filter->children.push_back(std::move(op)); - return std::move(filter); -} - -unique_ptr FilterPushdown::PushFinalFilters(unique_ptr op) { - vector> expressions; - for (auto &f : filters) { - expressions.push_back(std::move(f->filter)); - } - - return AddLogicalFilter(std::move(op), std::move(expressions)); -} - -unique_ptr FilterPushdown::FinishPushdown(unique_ptr op) { - // unhandled type, first perform filter pushdown in its children - for (auto &child : op->children) { - FilterPushdown pushdown(optimizer); - child = pushdown.Rewrite(std::move(child)); - } - // now push any existing filters - return PushFinalFilters(std::move(op)); -} - -void FilterPushdown::Filter::ExtractBindings() { - bindings.clear(); - LogicalJoin::GetExpressionBindings(*filter, bindings); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -unique_ptr InClauseRewriter::Rewrite(unique_ptr op) { - if (op->children.size() == 1) { - root = std::move(op->children[0]); - VisitOperatorExpressions(*op); - op->children[0] = std::move(root); - } - - for (auto &child : op->children) { - 
child = Rewrite(std::move(child)); - } - return op; -} - -unique_ptr InClauseRewriter::VisitReplace(BoundOperatorExpression &expr, unique_ptr *expr_ptr) { - if (expr.type != ExpressionType::COMPARE_IN && expr.type != ExpressionType::COMPARE_NOT_IN) { - return nullptr; - } - D_ASSERT(root); - auto in_type = expr.children[0]->return_type; - bool is_regular_in = expr.type == ExpressionType::COMPARE_IN; - bool all_scalar = true; - // IN clause with many children: try to generate a mark join that replaces this IN expression - // we can only do this if the expressions in the expression list are scalar - for (idx_t i = 1; i < expr.children.size(); i++) { - if (!expr.children[i]->IsFoldable()) { - // non-scalar expression - all_scalar = false; - } - } - if (expr.children.size() == 2) { - // only one child - // IN: turn into X = 1 - // NOT IN: turn into X <> 1 - return make_uniq(is_regular_in ? ExpressionType::COMPARE_EQUAL - : ExpressionType::COMPARE_NOTEQUAL, - std::move(expr.children[0]), std::move(expr.children[1])); - } - if (expr.children.size() < 6 || !all_scalar) { - // low amount of children or not all scalar - // IN: turn into (X = 1 OR X = 2 OR X = 3...) - // NOT IN: turn into (X <> 1 AND X <> 2 AND X <> 3 ...) - auto conjunction = make_uniq(is_regular_in ? ExpressionType::CONJUNCTION_OR - : ExpressionType::CONJUNCTION_AND); - for (idx_t i = 1; i < expr.children.size(); i++) { - conjunction->children.push_back(make_uniq( - is_regular_in ? 
ExpressionType::COMPARE_EQUAL : ExpressionType::COMPARE_NOTEQUAL, - expr.children[0]->Copy(), std::move(expr.children[i]))); - } - return std::move(conjunction); - } - // IN clause with many constant children - // generate a mark join that replaces this IN expression - // first generate a ColumnDataCollection from the set of expressions - vector types = {in_type}; - auto collection = make_uniq(context, types); - ColumnDataAppendState append_state; - collection->InitializeAppend(append_state); - - DataChunk chunk; - chunk.Initialize(context, types); - for (idx_t i = 1; i < expr.children.size(); i++) { - // resolve this expression to a constant - auto value = ExpressionExecutor::EvaluateScalar(context, *expr.children[i]); - idx_t index = chunk.size(); - chunk.SetCardinality(chunk.size() + 1); - chunk.SetValue(0, index, value); - if (chunk.size() == STANDARD_VECTOR_SIZE || i + 1 == expr.children.size()) { - // chunk full: append to chunk collection - collection->Append(append_state, chunk); - chunk.Reset(); - } - } - // now generate a ChunkGet that scans this collection - auto chunk_index = optimizer.binder.GenerateTableIndex(); - auto chunk_scan = make_uniq(chunk_index, types, std::move(collection)); - - // then we generate the MARK join with the chunk scan on the RHS - auto join = make_uniq(JoinType::MARK); - join->mark_index = chunk_index; - join->AddChild(std::move(root)); - join->AddChild(std::move(chunk_scan)); - // create the JOIN condition - JoinCondition cond; - cond.left = std::move(expr.children[0]); - - cond.right = make_uniq(in_type, ColumnBinding(chunk_index, 0)); - cond.comparison = ExpressionType::COMPARE_EQUAL; - join->conditions.push_back(std::move(cond)); - root = std::move(join); - - // we replace the original subquery with a BoundColumnRefExpression referring to the mark column - unique_ptr result = - make_uniq("IN (...)", LogicalType::BOOLEAN, ColumnBinding(chunk_index, 0)); - if (!is_regular_in) { - // NOT IN: invert - auto invert = 
make_uniq(ExpressionType::OPERATOR_NOT, LogicalType::BOOLEAN); - invert->children.push_back(std::move(result)); - result = std::move(invert); - } - return result; -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -// The filter was made on top of a logical sample or other projection, -// but no specific columns are referenced. See issue 4978 number 4. -bool CardinalityEstimator::EmptyFilter(FilterInfo &filter_info) { - if (!filter_info.left_set && !filter_info.right_set) { - return true; - } - return false; -} - -void CardinalityEstimator::AddRelationTdom(FilterInfo &filter_info) { - D_ASSERT(filter_info.set.count >= 1); - for (const RelationsToTDom &r2tdom : relations_to_tdoms) { - auto &i_set = r2tdom.equivalent_relations; - if (i_set.find(filter_info.left_binding) != i_set.end()) { - // found an equivalent filter - return; - } - } - - auto key = ColumnBinding(filter_info.left_binding.table_index, filter_info.left_binding.column_index); - RelationsToTDom new_r2tdom(column_binding_set_t({key})); - - relations_to_tdoms.emplace_back(new_r2tdom); -} - -bool CardinalityEstimator::SingleColumnFilter(FilterInfo &filter_info) { - if (filter_info.left_set && filter_info.right_set) { - // Both set - return false; - } - if (EmptyFilter(filter_info)) { - return false; - } - return true; -} - -vector CardinalityEstimator::DetermineMatchingEquivalentSets(FilterInfo *filter_info) { - vector matching_equivalent_sets; - auto equivalent_relation_index = 0; - - for (const RelationsToTDom &r2tdom : relations_to_tdoms) { - auto &i_set = r2tdom.equivalent_relations; - if (i_set.find(filter_info->left_binding) != i_set.end()) { - matching_equivalent_sets.push_back(equivalent_relation_index); - } else if (i_set.find(filter_info->right_binding) != i_set.end()) { - // don't add both left and right to the matching_equivalent_sets - // since both left and right get added to that index anyway. 
- matching_equivalent_sets.push_back(equivalent_relation_index); - } - equivalent_relation_index++; - } - return matching_equivalent_sets; -} - -void CardinalityEstimator::AddToEquivalenceSets(FilterInfo *filter_info, vector matching_equivalent_sets) { - D_ASSERT(matching_equivalent_sets.size() <= 2); - if (matching_equivalent_sets.size() > 1) { - // an equivalence relation is connecting two sets of equivalence relations - // so push all relations from the second set into the first. Later we will delete - // the second set. - for (ColumnBinding i : relations_to_tdoms.at(matching_equivalent_sets[1]).equivalent_relations) { - relations_to_tdoms.at(matching_equivalent_sets[0]).equivalent_relations.insert(i); - } - for (auto &column_name : relations_to_tdoms.at(matching_equivalent_sets[1]).column_names) { - relations_to_tdoms.at(matching_equivalent_sets[0]).column_names.push_back(column_name); - } - relations_to_tdoms.at(matching_equivalent_sets[1]).equivalent_relations.clear(); - relations_to_tdoms.at(matching_equivalent_sets[1]).column_names.clear(); - relations_to_tdoms.at(matching_equivalent_sets[0]).filters.push_back(filter_info); - // add all values of one set to the other, delete the empty one - } else if (matching_equivalent_sets.size() == 1) { - auto &tdom_i = relations_to_tdoms.at(matching_equivalent_sets.at(0)); - tdom_i.equivalent_relations.insert(filter_info->left_binding); - tdom_i.equivalent_relations.insert(filter_info->right_binding); - tdom_i.filters.push_back(filter_info); - } else if (matching_equivalent_sets.empty()) { - column_binding_set_t tmp; - tmp.insert(filter_info->left_binding); - tmp.insert(filter_info->right_binding); - relations_to_tdoms.emplace_back(tmp); - relations_to_tdoms.back().filters.push_back(filter_info); - } -} - -void CardinalityEstimator::InitEquivalentRelations(const vector> &filter_infos) { - // For each filter, we fill keep track of the index of the equivalent relation set - // the left and right relation needs to be 
added to. - for (auto &filter : filter_infos) { - if (SingleColumnFilter(*filter)) { - // Filter on one relation, (i.e string or range filter on a column). - // Grab the first relation and add it to the equivalence_relations - AddRelationTdom(*filter); - continue; - } else if (EmptyFilter(*filter)) { - continue; - } - D_ASSERT(filter->left_set->count >= 1); - D_ASSERT(filter->right_set->count >= 1); - - auto matching_equivalent_sets = DetermineMatchingEquivalentSets(filter.get()); - AddToEquivalenceSets(filter.get(), matching_equivalent_sets); - } - RemoveEmptyTotalDomains(); -} - -void CardinalityEstimator::RemoveEmptyTotalDomains() { - auto remove_start = std::remove_if(relations_to_tdoms.begin(), relations_to_tdoms.end(), - [](RelationsToTDom &r_2_tdom) { return r_2_tdom.equivalent_relations.empty(); }); - relations_to_tdoms.erase(remove_start, relations_to_tdoms.end()); -} - -void UpdateDenom(Subgraph2Denominator &relation_2_denom, RelationsToTDom &relation_to_tdom) { - relation_2_denom.denom *= relation_to_tdom.has_tdom_hll ? 
relation_to_tdom.tdom_hll : relation_to_tdom.tdom_no_hll; -} - -void FindSubgraphMatchAndMerge(Subgraph2Denominator &merge_to, idx_t find_me, - vector::iterator subgraph, - vector::iterator end) { - for (; subgraph != end; subgraph++) { - if (subgraph->relations.count(find_me) >= 1) { - for (auto &relation : subgraph->relations) { - merge_to.relations.insert(relation); - } - subgraph->relations.clear(); - merge_to.denom *= subgraph->denom; - return; - } - } -} - -template <> -double CardinalityEstimator::EstimateCardinalityWithSet(JoinRelationSet &new_set) { - - if (relation_set_2_cardinality.find(new_set.ToString()) != relation_set_2_cardinality.end()) { - return relation_set_2_cardinality[new_set.ToString()].cardinality_before_filters; - } - double numerator = 1; - unordered_set actual_set; - - for (idx_t i = 0; i < new_set.count; i++) { - auto &single_node_set = set_manager.GetJoinRelation(new_set.relations[i]); - auto card_helper = relation_set_2_cardinality[single_node_set.ToString()]; - numerator *= card_helper.cardinality_before_filters; - actual_set.insert(new_set.relations[i]); - } - - vector subgraphs; - bool done = false; - bool found_match = false; - - // Finding the denominator is tricky. You need to go through the tdoms in decreasing order - // Then loop through all filters in the equivalence set of the tdom to see if both the - // left and right relations are in the new set, if so you can use that filter. - // You must also make sure that the filters all relations in the given set, so we use subgraphs - // that should eventually merge into one connected graph that joins all the relations - // TODO: Implement a method to cache subgraphs so you don't have to build them up every - // time the cardinality of a new set is requested - - // relations_to_tdoms has already been sorted. - for (auto &relation_2_tdom : relations_to_tdoms) { - // loop through each filter in the tdom. 
- if (done) { - break; - } - for (auto &filter : relation_2_tdom.filters) { - if (actual_set.count(filter->left_binding.table_index) == 0 || - actual_set.count(filter->right_binding.table_index) == 0) { - continue; - } - // the join filter is on relations in the new set. - found_match = false; - vector::iterator it; - for (it = subgraphs.begin(); it != subgraphs.end(); it++) { - auto left_in = it->relations.count(filter->left_binding.table_index); - auto right_in = it->relations.count(filter->right_binding.table_index); - if (left_in && right_in) { - // if both left and right bindings are in the subgraph, continue. - // This means another filter is connecting relations already in the - // subgraph it, but it has a tdom that is less, and we don't care. - found_match = true; - continue; - } - if (!left_in && !right_in) { - // if both left and right bindings are *not* in the subgraph, continue - // without finding a match. This will trigger the process to add a new - // subgraph - continue; - } - idx_t find_table; - if (left_in) { - find_table = filter->right_binding.table_index; - } else { - D_ASSERT(right_in); - find_table = filter->left_binding.table_index; - } - auto next_subgraph = it + 1; - // iterate through other subgraphs and merge. - FindSubgraphMatchAndMerge(*it, find_table, next_subgraph, subgraphs.end()); - // Now insert the right binding and update denominator with the - // tdom of the filter - it->relations.insert(find_table); - UpdateDenom(*it, relation_2_tdom); - found_match = true; - break; - } - // means that the filter joins relations in the given set, but there is no - // connection to any subgraph in subgraphs. Add a new subgraph, and maybe later there will be - // a connection. 
- if (!found_match) { - subgraphs.emplace_back(); - auto &subgraph = subgraphs.back(); - subgraph.relations.insert(filter->left_binding.table_index); - subgraph.relations.insert(filter->right_binding.table_index); - UpdateDenom(subgraph, relation_2_tdom); - } - auto remove_start = std::remove_if(subgraphs.begin(), subgraphs.end(), - [](Subgraph2Denominator &s) { return s.relations.empty(); }); - subgraphs.erase(remove_start, subgraphs.end()); - - if (subgraphs.size() == 1 && subgraphs.at(0).relations.size() == new_set.count) { - // You have found enough filters to connect the relations. These are guaranteed - // to be the filters with the highest Tdoms. - done = true; - break; - } - } - } - double denom = 1; - // TODO: It's possible cross-products were added and are not present in the filters in the relation_2_tdom - // structures. When that's the case, multiply the denom structures that have no intersection - for (auto &match : subgraphs) { - denom *= match.denom; - } - // can happen if a table has cardinality 0, or a tdom is set to 0 - if (denom == 0) { - denom = 1; - } - auto result = numerator / denom; - auto new_entry = CardinalityHelper((double)result, 1); - relation_set_2_cardinality[new_set.ToString()] = new_entry; - return result; -} - -template <> -idx_t CardinalityEstimator::EstimateCardinalityWithSet(JoinRelationSet &new_set) { - auto cardinality_as_double = EstimateCardinalityWithSet(new_set); - auto max = NumericLimits::Maximum(); - if (cardinality_as_double > max) { - return max; - } - return (idx_t)cardinality_as_double; -} - -bool SortTdoms(const RelationsToTDom &a, const RelationsToTDom &b) { - if (a.has_tdom_hll && b.has_tdom_hll) { - return a.tdom_hll > b.tdom_hll; - } - if (a.has_tdom_hll) { - return a.tdom_hll > b.tdom_no_hll; - } - if (b.has_tdom_hll) { - return a.tdom_no_hll > b.tdom_hll; - } - return a.tdom_no_hll > b.tdom_no_hll; -} - -void CardinalityEstimator::InitCardinalityEstimatorProps(optional_ptr set, RelationStats &stats) { - // 
Get the join relation set - D_ASSERT(stats.stats_initialized); - auto relation_cardinality = stats.cardinality; - auto relation_filter = stats.filter_strength; - - auto card_helper = CardinalityHelper(relation_cardinality, relation_filter); - relation_set_2_cardinality[set->ToString()] = card_helper; - - UpdateTotalDomains(set, stats); - - // sort relations from greatest tdom to lowest tdom. - std::sort(relations_to_tdoms.begin(), relations_to_tdoms.end(), SortTdoms); -} - -void CardinalityEstimator::UpdateTotalDomains(optional_ptr set, RelationStats &stats) { - D_ASSERT(set->count == 1); - auto relation_id = set->relations[0]; - //! Initialize the distinct count for all columns used in joins with the current relation. - // D_ASSERT(stats.column_distinct_count.size() >= 1); - - for (idx_t i = 0; i < stats.column_distinct_count.size(); i++) { - //! for every column used in a filter in the relation, get the distinct count via HLL, or assume it to be - //! the cardinality - // Update the relation_to_tdom set with the estimated distinct count (or tdom) calculated above - auto key = ColumnBinding(relation_id, i); - for (auto &relation_to_tdom : relations_to_tdoms) { - column_binding_set_t i_set = relation_to_tdom.equivalent_relations; - if (i_set.find(key) == i_set.end()) { - continue; - } - auto distinct_count = stats.column_distinct_count.at(i); - if (distinct_count.from_hll && relation_to_tdom.has_tdom_hll) { - relation_to_tdom.tdom_hll = MaxValue(relation_to_tdom.tdom_hll, distinct_count.distinct_count); - } else if (distinct_count.from_hll && !relation_to_tdom.has_tdom_hll) { - relation_to_tdom.has_tdom_hll = true; - relation_to_tdom.tdom_hll = distinct_count.distinct_count; - } else { - relation_to_tdom.tdom_no_hll = MinValue(distinct_count.distinct_count, relation_to_tdom.tdom_no_hll); - } - break; - } - } -} - -// LCOV_EXCL_START - -void CardinalityEstimator::AddRelationNamesToTdoms(vector &stats) { -#ifdef DEBUG - for (auto &total_domain : relations_to_tdoms) { 
- for (auto &binding : total_domain.equivalent_relations) { - D_ASSERT(binding.table_index < stats.size()); - D_ASSERT(binding.column_index < stats.at(binding.table_index).column_names.size()); - string column_name = stats.at(binding.table_index).column_names.at(binding.column_index); - total_domain.column_names.push_back(column_name); - } - } -#endif -} - -void CardinalityEstimator::PrintRelationToTdomInfo() { - for (auto &total_domain : relations_to_tdoms) { - string domain = "Following columns have the same distinct count: "; - for (auto &column_name : total_domain.column_names) { - domain += column_name + ", "; - } - bool have_hll = total_domain.has_tdom_hll; - domain += "\n TOTAL DOMAIN = " + to_string(have_hll ? total_domain.tdom_hll : total_domain.tdom_no_hll); - Printer::Print(domain); - } -} - -// LCOV_EXCL_STOP - -} // namespace duckdb - - - -#include - -namespace duckdb { - -CostModel::CostModel(QueryGraphManager &query_graph_manager) - : query_graph_manager(query_graph_manager), cardinality_estimator() { -} - -double CostModel::ComputeCost(JoinNode &left, JoinNode &right) { - auto &combination = query_graph_manager.set_manager.Union(left.set, right.set); - auto join_card = cardinality_estimator.EstimateCardinalityWithSet(combination); - auto join_cost = join_card; - return join_cost + left.cost + right.cost; -} - -} // namespace duckdb - - - -namespace duckdb { - -template <> -double EstimatedProperties::GetCardinality() const { - return cardinality; -} - -template <> -idx_t EstimatedProperties::GetCardinality() const { - auto max_idx_t = NumericLimits::Maximum() - 10000; - return MinValue(cardinality, max_idx_t); -} - -template <> -double EstimatedProperties::GetCost() const { - return cost; -} - -template <> -idx_t EstimatedProperties::GetCost() const { - auto max_idx_t = NumericLimits::Maximum() - 10000; - return MinValue(cost, max_idx_t); -} - -void EstimatedProperties::SetCardinality(double new_card) { - cardinality = new_card; -} - -void 
EstimatedProperties::SetCost(double new_cost) { - cost = new_cost; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -JoinNode::JoinNode(JoinRelationSet &set) : set(set) { -} - -JoinNode::JoinNode(JoinRelationSet &set, optional_ptr info, JoinNode &left, JoinNode &right, double cost) - : set(set), info(info), left(&left), right(&right), cost(cost) { -} - -unique_ptr EstimatedProperties::Copy() { - auto result = make_uniq(cardinality, cost); - return result; -} - -string JoinNode::ToString() { - string result = "-------------------------------\n"; - result += set.ToString() + "\n"; - result += "cost = " + to_string(cost) + "\n"; - result += "left = \n"; - if (left) { - result += left->ToString(); - } - result += "right = \n"; - if (right) { - result += right->ToString(); - } - return result; -} - -} // namespace duckdb - - - - - - - - - - -#include -#include - -namespace duckdb { - -static bool HasJoin(LogicalOperator *op) { - while (!op->children.empty()) { - if (op->children.size() == 1) { - op = op->children[0].get(); - } - if (op->children.size() == 2) { - return true; - } - } - return false; -} - -unique_ptr JoinOrderOptimizer::Optimize(unique_ptr plan, - optional_ptr stats) { - - // make sure query graph manager has not extracted a relation graph already - LogicalOperator *op = plan.get(); - - // extract the relations that go into the hyper graph. - // We optimize the children of any non-reorderable operations we come across. - bool reorderable = query_graph_manager.Build(*op); - - // get relation_stats here since the reconstruction process will move all of the relations. - auto relation_stats = query_graph_manager.relation_manager.GetRelationStats(); - unique_ptr new_logical_plan = nullptr; - - if (reorderable) { - // query graph now has filters and relations - auto cost_model = CostModel(query_graph_manager); - - // Initialize a plan enumerator. 
- auto plan_enumerator = - PlanEnumerator(query_graph_manager, cost_model, query_graph_manager.GetQueryGraphEdges()); - - // Initialize the leaf/single node plans - plan_enumerator.InitLeafPlans(); - - // Ask the plan enumerator to enumerate a number of join orders - auto final_plan = plan_enumerator.SolveJoinOrder(); - // TODO: add in the check that if no plan exists, you have to add a cross product. - - // now reconstruct a logical plan from the query graph plan - new_logical_plan = query_graph_manager.Reconstruct(std::move(plan), *final_plan); - } else { - new_logical_plan = std::move(plan); - if (relation_stats.size() == 1) { - new_logical_plan->estimated_cardinality = relation_stats.at(0).cardinality; - } - } - - // only perform left right optimizations when stats is null (means we have the top level optimize call) - // Don't check reorderability because non-reorderable joins will result in 1 relation, but we can - // still switch the children. - // TODO: put this in a different optimizer maybe? - if (stats == nullptr && HasJoin(new_logical_plan.get())) { - new_logical_plan = query_graph_manager.LeftRightOptimizations(std::move(new_logical_plan)); - } - - // Propagate up a stats object from the top of the new_logical_plan if stats exist. 
- if (stats) { - auto cardinality = new_logical_plan->EstimateCardinality(context); - auto bindings = new_logical_plan->GetColumnBindings(); - auto new_stats = RelationStatisticsHelper::CombineStatsOfReorderableOperator(bindings, relation_stats); - new_stats.cardinality = MaxValue(cardinality, new_stats.cardinality); - RelationStatisticsHelper::CopyRelationStats(*stats, new_stats); - } - - return new_logical_plan; -} - -} // namespace duckdb - - - - -#include - -namespace duckdb { - -using JoinRelationTreeNode = JoinRelationSetManager::JoinRelationTreeNode; - -// LCOV_EXCL_START -string JoinRelationSet::ToString() const { - string result = "["; - result += StringUtil::Join(relations, count, ", ", [](const idx_t &relation) { return to_string(relation); }); - result += "]"; - return result; -} -// LCOV_EXCL_STOP - -//! Returns true if sub is a subset of super -bool JoinRelationSet::IsSubset(JoinRelationSet &super, JoinRelationSet &sub) { - D_ASSERT(sub.count > 0); - if (sub.count > super.count) { - return false; - } - idx_t j = 0; - for (idx_t i = 0; i < super.count; i++) { - if (sub.relations[j] == super.relations[i]) { - j++; - if (j == sub.count) { - return true; - } - } - } - return false; -} - -JoinRelationSet &JoinRelationSetManager::GetJoinRelation(unsafe_unique_array relations, idx_t count) { - // now look it up in the tree - reference info(root); - for (idx_t i = 0; i < count; i++) { - auto entry = info.get().children.find(relations[i]); - if (entry == info.get().children.end()) { - // node not found, create it - auto insert_it = info.get().children.insert(make_pair(relations[i], make_uniq())); - entry = insert_it.first; - } - // move to the next node - info = *entry->second; - } - // now check if the JoinRelationSet has already been created - if (!info.get().relation) { - // if it hasn't we need to create it - info.get().relation = make_uniq(std::move(relations), count); - } - return *info.get().relation; -} - -//! 
Create or get a JoinRelationSet from a single node with the given index -JoinRelationSet &JoinRelationSetManager::GetJoinRelation(idx_t index) { - // create a sorted vector of the relations - auto relations = make_unsafe_uniq_array(1); - relations[0] = index; - idx_t count = 1; - return GetJoinRelation(std::move(relations), count); -} - -JoinRelationSet &JoinRelationSetManager::GetJoinRelation(const unordered_set &bindings) { - // create a sorted vector of the relations - unsafe_unique_array relations = bindings.empty() ? nullptr : make_unsafe_uniq_array(bindings.size()); - idx_t count = 0; - for (auto &entry : bindings) { - relations[count++] = entry; - } - std::sort(relations.get(), relations.get() + count); - return GetJoinRelation(std::move(relations), count); -} - -JoinRelationSet &JoinRelationSetManager::Union(JoinRelationSet &left, JoinRelationSet &right) { - auto relations = make_unsafe_uniq_array(left.count + right.count); - idx_t count = 0; - // move through the left and right relations, eliminating duplicates - idx_t i = 0, j = 0; - while (true) { - if (i == left.count) { - // exhausted left relation, add remaining of right relation - for (; j < right.count; j++) { - relations[count++] = right.relations[j]; - } - break; - } else if (j == right.count) { - // exhausted right relation, add remaining of left - for (; i < left.count; i++) { - relations[count++] = left.relations[i]; - } - break; - } else if (left.relations[i] < right.relations[j]) { - // left is smaller, progress left and add it to the set - relations[count++] = left.relations[i]; - i++; - } else { - D_ASSERT(left.relations[i] > right.relations[j]); - // right is smaller, progress right and add it to the set - relations[count++] = right.relations[j]; - j++; - } - } - return GetJoinRelation(std::move(relations), count); -} - -// JoinRelationSet *JoinRelationSetManager::Difference(JoinRelationSet *left, JoinRelationSet *right) { -// auto relations = unsafe_unique_array(new idx_t[left->count]); 
-// idx_t count = 0; -// // move through the left and right relations -// idx_t i = 0, j = 0; -// while (true) { -// if (i == left->count) { -// // exhausted left relation, we are done -// break; -// } else if (j == right->count) { -// // exhausted right relation, add remaining of left -// for (; i < left->count; i++) { -// relations[count++] = left->relations[i]; -// } -// break; -// } else if (left->relations[i] == right->relations[j]) { -// // equivalent, add nothing -// i++; -// j++; -// } else if (left->relations[i] < right->relations[j]) { -// // left is smaller, progress left and add it to the set -// relations[count++] = left->relations[i]; -// i++; -// } else { -// // right is smaller, progress right -// j++; -// } -// } -// return GetJoinRelation(std::move(relations), count); -// } - -} // namespace duckdb - - - - - -namespace duckdb { - -bool PlanEnumerator::NodeInFullPlan(JoinNode &node) { - return join_nodes_in_full_plan.find(node.set.ToString()) != join_nodes_in_full_plan.end(); -} - -void PlanEnumerator::UpdateJoinNodesInFullPlan(JoinNode &node) { - if (node.set.count == query_graph_manager.relation_manager.NumRelations()) { - join_nodes_in_full_plan.clear(); - } - if (node.set.count < query_graph_manager.relation_manager.NumRelations()) { - join_nodes_in_full_plan.insert(node.set.ToString()); - } - if (node.left) { - UpdateJoinNodesInFullPlan(*node.left); - } - if (node.right) { - UpdateJoinNodesInFullPlan(*node.right); - } -} - -static vector> AddSuperSets(const vector> ¤t, - const vector &all_neighbors) { - vector> ret; - - for (const auto &neighbor_set : current) { - auto max_val = std::max_element(neighbor_set.begin(), neighbor_set.end()); - for (const auto &neighbor : all_neighbors) { - if (*max_val >= neighbor) { - continue; - } - if (neighbor_set.count(neighbor) == 0) { - unordered_set new_set; - for (auto &n : neighbor_set) { - new_set.insert(n); - } - new_set.insert(neighbor); - ret.push_back(new_set); - } - } - } - - return ret; -} - -//! 
Update the exclusion set with all entries in the subgraph -static void UpdateExclusionSet(optional_ptr node, unordered_set &exclusion_set) { - for (idx_t i = 0; i < node->count; i++) { - exclusion_set.insert(node->relations[i]); - } -} - -// works by first creating all sets with cardinality 1 -// then iterates over each previously created group of subsets and will only add a neighbor if the neighbor -// is greater than all relations in the set. -static vector> GetAllNeighborSets(vector neighbors) { - vector> ret; - sort(neighbors.begin(), neighbors.end()); - vector> added; - for (auto &neighbor : neighbors) { - added.push_back(unordered_set({neighbor})); - ret.push_back(unordered_set({neighbor})); - } - do { - added = AddSuperSets(added, neighbors); - for (auto &d : added) { - ret.push_back(d); - } - } while (!added.empty()); -#if DEBUG - // drive by test to make sure we have an accurate amount of - // subsets, and that each neighbor is in a correct amount - // of those subsets. - D_ASSERT(ret.size() == pow(2, neighbors.size()) - 1); - for (auto &n : neighbors) { - idx_t count = 0; - for (auto &set : ret) { - if (set.count(n) >= 1) { - count += 1; - } - } - D_ASSERT(count == pow(2, neighbors.size() - 1)); - } -#endif - return ret; -} - -void PlanEnumerator::GenerateCrossProducts() { - // generate a set of cross products to combine the currently available plans into a full join plan - // we create edges between every relation with a high cost - for (idx_t i = 0; i < query_graph_manager.relation_manager.NumRelations(); i++) { - auto &left = query_graph_manager.set_manager.GetJoinRelation(i); - for (idx_t j = 0; j < query_graph_manager.relation_manager.NumRelations(); j++) { - if (i != j) { - auto &right = query_graph_manager.set_manager.GetJoinRelation(j); - query_graph_manager.CreateQueryGraphCrossProduct(left, right); - } - } - } - // Now that the query graph has new edges, we need to re-initialize our query graph. 
- // TODO: do we need to initialize our qyery graph again? - // query_graph = query_graph_manager.GetQueryGraph(); -} - -//! Create a new JoinTree node by joining together two previous JoinTree nodes -unique_ptr PlanEnumerator::CreateJoinTree(JoinRelationSet &set, - const vector> &possible_connections, - JoinNode &left, JoinNode &right) { - // for the hash join we want the right side (build side) to have the smallest cardinality - // also just a heuristic but for now... - // FIXME: we should probably actually benchmark that as well - // FIXME: should consider different join algorithms, should we pick a join algorithm here as well? (probably) - optional_ptr best_connection = nullptr; - - // cross products are techincally still connections, but the filter expression is a null_ptr - if (!possible_connections.empty()) { - best_connection = &possible_connections.back().get(); - } - - auto cost = cost_model.ComputeCost(left, right); - auto result = make_uniq(set, best_connection, left, right, cost); - result->cardinality = cost_model.cardinality_estimator.EstimateCardinalityWithSet(set); - return result; -} - -JoinNode &PlanEnumerator::EmitPair(JoinRelationSet &left, JoinRelationSet &right, - const vector> &info) { - // get the left and right join plans - auto left_plan = plans.find(left); - auto right_plan = plans.find(right); - if (left_plan == plans.end() || right_plan == plans.end()) { - throw InternalException("No left or right plan: internal error in join order optimizer"); - } - auto &new_set = query_graph_manager.set_manager.Union(left, right); - // create the join tree based on combining the two plans - auto new_plan = CreateJoinTree(new_set, info, *left_plan->second, *right_plan->second); - // check if this plan is the optimal plan we found for this set of relations - auto entry = plans.find(new_set); - auto new_cost = new_plan->cost; - double old_cost = NumericLimits::Maximum(); - if (entry != plans.end()) { - old_cost = entry->second->cost; - } - if (entry == 
plans.end() || new_cost < old_cost) { - // the new plan costs less than the old plan. Update our DP tree and cost tree - auto &result = *new_plan; - - if (full_plan_found && - join_nodes_in_full_plan.find(new_plan->set.ToString()) != join_nodes_in_full_plan.end()) { - must_update_full_plan = true; - } - if (new_set.count == query_graph_manager.relation_manager.NumRelations()) { - full_plan_found = true; - // If we find a full plan, we need to keep track of which nodes are in the full plan. - // It's possible the DP algorithm updates a node in the current full plan, then moves on - // to the SolveApproximately. SolveApproximately may find a full plan with a higher cost than - // what SolveExactly found. In this case, we revert to the SolveExactly plan, but it is - // possible to get use-after-free errors if the SolveApproximately algorithm updated some (but not all) - // nodes in the SolveExactly plan - // If we know a node in the full plan is updated, we can prevent ourselves from exiting the - // DP algorithm until the last plan updated is a full plan - UpdateJoinNodesInFullPlan(result); - if (must_update_full_plan) { - must_update_full_plan = false; - } - } - - D_ASSERT(new_plan); - plans[new_set] = std::move(new_plan); - return result; - } - return *entry->second; -} - -bool PlanEnumerator::TryEmitPair(JoinRelationSet &left, JoinRelationSet &right, - const vector> &info) { - pairs++; - // If a full plan is created, it's possible a node in the plan gets updated. When this happens, make sure you keep - // emitting pairs until you emit another final plan. Another final plan is guaranteed to be produced because of - // our symmetry guarantees. 
- if (pairs >= 10000 && !must_update_full_plan) { - // when the amount of pairs gets too large we exit the dynamic programming and resort to a greedy algorithm - // FIXME: simple heuristic currently - // at 10K pairs stop searching exactly and switch to heuristic - return false; - } - EmitPair(left, right, info); - return true; -} - -bool PlanEnumerator::EmitCSG(JoinRelationSet &node) { - if (node.count == query_graph_manager.relation_manager.NumRelations()) { - return true; - } - // create the exclusion set as everything inside the subgraph AND anything with members BELOW it - unordered_set exclusion_set; - for (idx_t i = 0; i < node.relations[0]; i++) { - exclusion_set.insert(i); - } - UpdateExclusionSet(&node, exclusion_set); - // find the neighbors given this exclusion set - auto neighbors = query_graph.GetNeighbors(node, exclusion_set); - if (neighbors.empty()) { - return true; - } - - //! Neighbors should be reversed when iterating over them. - std::sort(neighbors.begin(), neighbors.end(), std::greater_equal()); - for (idx_t i = 0; i < neighbors.size() - 1; i++) { - D_ASSERT(neighbors[i] > neighbors[i + 1]); - } - - // Dphyp paper missiing this. - // Because we are traversing in reverse order, we need to add neighbors whose number is smaller than the current - // node to exclusion_set - // This avoids duplicated enumeration - unordered_set new_exclusion_set = exclusion_set; - for (idx_t i = 0; i < neighbors.size(); ++i) { - D_ASSERT(new_exclusion_set.find(neighbors[i]) == new_exclusion_set.end()); - new_exclusion_set.insert(neighbors[i]); - } - - for (auto neighbor : neighbors) { - // since the GetNeighbors only returns the smallest element in a list, the entry might not be connected to - // (only!) 
this neighbor, hence we have to do a connectedness check before we can emit it - auto &neighbor_relation = query_graph_manager.set_manager.GetJoinRelation(neighbor); - auto connections = query_graph.GetConnections(node, neighbor_relation); - if (!connections.empty()) { - if (!TryEmitPair(node, neighbor_relation, connections)) { - return false; - } - } - - if (!EnumerateCmpRecursive(node, neighbor_relation, new_exclusion_set)) { - return false; - } - - new_exclusion_set.erase(neighbor); - } - return true; -} - -bool PlanEnumerator::EnumerateCmpRecursive(JoinRelationSet &left, JoinRelationSet &right, - unordered_set &exclusion_set) { - // get the neighbors of the second relation under the exclusion set - auto neighbors = query_graph.GetNeighbors(right, exclusion_set); - if (neighbors.empty()) { - return true; - } - - auto all_subset = GetAllNeighborSets(neighbors); - vector> union_sets; - union_sets.reserve(all_subset.size()); - for (const auto &rel_set : all_subset) { - auto &neighbor = query_graph_manager.set_manager.GetJoinRelation(rel_set); - // emit the combinations of this node and its neighbors - auto &combined_set = query_graph_manager.set_manager.Union(right, neighbor); - // If combined_set.count == right.count, This means we found a neighbor that has been present before - // This means we didn't set exclusion_set correctly. 
- D_ASSERT(combined_set.count > right.count); - if (plans.find(combined_set) != plans.end()) { - auto connections = query_graph.GetConnections(left, combined_set); - if (!connections.empty()) { - if (!TryEmitPair(left, combined_set, connections)) { - return false; - } - } - } - union_sets.push_back(combined_set); - } - - unordered_set new_exclusion_set = exclusion_set; - for (const auto &neighbor : neighbors) { - new_exclusion_set.insert(neighbor); - } - - // recursively enumerate the sets - for (idx_t i = 0; i < union_sets.size(); i++) { - // updated the set of excluded entries with this neighbor - if (!EnumerateCmpRecursive(left, union_sets[i], new_exclusion_set)) { - return false; - } - } - return true; -} - -bool PlanEnumerator::EnumerateCSGRecursive(JoinRelationSet &node, unordered_set &exclusion_set) { - // find neighbors of S under the exclusion set - auto neighbors = query_graph.GetNeighbors(node, exclusion_set); - if (neighbors.empty()) { - return true; - } - - auto all_subset = GetAllNeighborSets(neighbors); - vector> union_sets; - union_sets.reserve(all_subset.size()); - for (const auto &rel_set : all_subset) { - auto &neighbor = query_graph_manager.set_manager.GetJoinRelation(rel_set); - // emit the combinations of this node and its neighbors - auto &new_set = query_graph_manager.set_manager.Union(node, neighbor); - D_ASSERT(new_set.count > node.count); - if (plans.find(new_set) != plans.end()) { - if (!EmitCSG(new_set)) { - return false; - } - } - union_sets.push_back(new_set); - } - - unordered_set new_exclusion_set = exclusion_set; - for (const auto &neighbor : neighbors) { - new_exclusion_set.insert(neighbor); - } - - // recursively enumerate the sets - for (idx_t i = 0; i < union_sets.size(); i++) { - // updated the set of excluded entries with this neighbor - if (!EnumerateCSGRecursive(union_sets[i], new_exclusion_set)) { - return false; - } - } - return true; -} - -bool PlanEnumerator::SolveJoinOrderExactly() { - // now we perform the actual 
dynamic programming to compute the final result - // we enumerate over all the possible pairs in the neighborhood - for (idx_t i = query_graph_manager.relation_manager.NumRelations(); i > 0; i--) { - // for every node in the set, we consider it as the start node once - auto &start_node = query_graph_manager.set_manager.GetJoinRelation(i - 1); - // emit the start node - if (!EmitCSG(start_node)) { - return false; - } - // initialize the set of exclusion_set as all the nodes with a number below this - unordered_set exclusion_set; - for (idx_t j = 0; j < i; j++) { - exclusion_set.insert(j); - } - // then we recursively search for neighbors that do not belong to the banned entries - if (!EnumerateCSGRecursive(start_node, exclusion_set)) { - return false; - } - } - return true; -} - -void PlanEnumerator::UpdateDPTree(JoinNode &new_plan) { - if (!NodeInFullPlan(new_plan)) { - // if the new node is not in the full plan, feel free to return - // because you won't be updating the full plan. - return; - } - auto &new_set = new_plan.set; - // now update every plan that uses this plan - unordered_set exclusion_set; - for (idx_t i = 0; i < new_set.count; i++) { - exclusion_set.insert(new_set.relations[i]); - } - auto neighbors = query_graph.GetNeighbors(new_set, exclusion_set); - auto all_neighbors = GetAllNeighborSets(neighbors); - for (const auto &neighbor : all_neighbors) { - auto &neighbor_relation = query_graph_manager.set_manager.GetJoinRelation(neighbor); - auto &combined_set = query_graph_manager.set_manager.Union(new_set, neighbor_relation); - - auto combined_set_plan = plans.find(combined_set); - if (combined_set_plan == plans.end()) { - continue; - } - - double combined_set_plan_cost = combined_set_plan->second->cost; // combined_set_plan->second->GetCost(); - auto connections = query_graph.GetConnections(new_set, neighbor_relation); - // recurse and update up the tree if the combined set produces a plan with a lower cost - // only recurse on neighbor relations that 
have plans. - auto right_plan = plans.find(neighbor_relation); - if (right_plan == plans.end()) { - continue; - } - auto &updated_plan = EmitPair(new_set, neighbor_relation, connections); - // <= because the child node has already been replaced. You need to - // replace the parent node as well in this case - if (updated_plan.cost < combined_set_plan_cost) { - UpdateDPTree(updated_plan); - } - } -} - -void PlanEnumerator::SolveJoinOrderApproximately() { - // at this point, we exited the dynamic programming but did not compute the final join order because it took too - // long instead, we use a greedy heuristic to obtain a join ordering now we use Greedy Operator Ordering to - // construct the result tree first we start out with all the base relations (the to-be-joined relations) - vector> join_relations; // T in the paper - for (idx_t i = 0; i < query_graph_manager.relation_manager.NumRelations(); i++) { - join_relations.push_back(query_graph_manager.set_manager.GetJoinRelation(i)); - } - while (join_relations.size() > 1) { - // now in every step of the algorithm, we greedily pick the join between the to-be-joined relations that has the - // smallest cost. This is O(r^2) per step, and every step will reduce the total amount of relations to-be-joined - // by 1, so the total cost is O(r^3) in the amount of relations - idx_t best_left = 0, best_right = 0; - optional_ptr best_connection; - for (idx_t i = 0; i < join_relations.size(); i++) { - auto left = join_relations[i]; - for (idx_t j = i + 1; j < join_relations.size(); j++) { - auto right = join_relations[j]; - // check if we can connect these two relations - auto connection = query_graph.GetConnections(left, right); - if (!connection.empty()) { - // we can check the cost of this connection - auto &node = EmitPair(left, right, connection); - - // update the DP tree in case a plan created by the DP algorithm uses the node - // that was potentially just updated by EmitPair. 
You will get a use-after-free - // error if future plans rely on the old node that was just replaced. - // if node in FullPath, then updateDP tree. - UpdateDPTree(node); - - if (!best_connection || node.cost < best_connection->cost) { - // best pair found so far - best_connection = &node; - best_left = i; - best_right = j; - } - } - } - } - if (!best_connection) { - // could not find a connection, but we were not done with finding a completed plan - // we have to add a cross product; we add it between the two smallest relations - optional_ptr smallest_plans[2]; - idx_t smallest_index[2]; - D_ASSERT(join_relations.size() >= 2); - - // first just add the first two join relations. It doesn't matter the cost as the JOO - // will swap them on estimated cardinality anyway. - for (idx_t i = 0; i < 2; i++) { - auto current_plan = plans[join_relations[i]].get(); - smallest_plans[i] = current_plan; - smallest_index[i] = i; - } - - // if there are any other join relations that don't have connections - // add them if they have lower estimated cardinality. - for (idx_t i = 2; i < join_relations.size(); i++) { - // get the plan for this relation - auto current_plan = plans[join_relations[i].get()].get(); - // check if the cardinality is smaller than the smallest two found so far - for (idx_t j = 0; j < 2; j++) { - if (!smallest_plans[j] || smallest_plans[j]->cost > current_plan->cost) { - smallest_plans[j] = current_plan; - smallest_index[j] = i; - break; - } - } - } - if (!smallest_plans[0] || !smallest_plans[1]) { - throw InternalException("Internal error in join order optimizer"); - } - D_ASSERT(smallest_plans[0] && smallest_plans[1]); - D_ASSERT(smallest_index[0] != smallest_index[1]); - auto &left = smallest_plans[0]->set; - auto &right = smallest_plans[1]->set; - // create a cross product edge (i.e. 
edge with empty filter) between these two sets in the query graph - query_graph_manager.CreateQueryGraphCrossProduct(left, right); - // now emit the pair and continue with the algorithm - auto connections = query_graph.GetConnections(left, right); - D_ASSERT(!connections.empty()); - - best_connection = &EmitPair(left, right, connections); - best_left = smallest_index[0]; - best_right = smallest_index[1]; - - UpdateDPTree(*best_connection); - // the code below assumes best_right > best_left - if (best_left > best_right) { - std::swap(best_left, best_right); - } - } - // now update the to-be-checked pairs - // remove left and right, and add the combination - - // important to erase the biggest element first - // if we erase the smallest element first the index of the biggest element changes - D_ASSERT(best_right > best_left); - join_relations.erase(join_relations.begin() + best_right); - join_relations.erase(join_relations.begin() + best_left); - join_relations.push_back(best_connection->set); - } -} - -void PlanEnumerator::InitLeafPlans() { - // First we initialize each of the single-node plans with themselves and with their cardinalities these are the leaf - // nodes of the join tree NOTE: we can just use pointers to JoinRelationSet* here because the GetJoinRelation - // function ensures that a unique combination of relations will have a unique JoinRelationSet object. - // first initialize equivalent relations based on the filters - auto relation_stats = query_graph_manager.relation_manager.GetRelationStats(); - - cost_model.cardinality_estimator.InitEquivalentRelations(query_graph_manager.GetFilterBindings()); - cost_model.cardinality_estimator.AddRelationNamesToTdoms(relation_stats); - - // then update the total domains based on the cardinalities of each relation. 
- for (idx_t i = 0; i < relation_stats.size(); i++) { - auto stats = relation_stats.at(i); - auto &relation_set = query_graph_manager.set_manager.GetJoinRelation(i); - auto join_node = make_uniq(relation_set); - join_node->cost = 0; - join_node->cardinality = stats.cardinality; - plans[relation_set] = std::move(join_node); - cost_model.cardinality_estimator.InitCardinalityEstimatorProps(&relation_set, stats); - } -} - -// the plan enumeration is a straight implementation of the paper "Dynamic Programming Strikes Back" by Guido -// Moerkotte and Thomas Neumannn, see that paper for additional info/documentation bonus slides: -// https://db.in.tum.de/teaching/ws1415/queryopt/chapter3.pdf?lang=de -unique_ptr PlanEnumerator::SolveJoinOrder() { - bool force_no_cross_product = query_graph_manager.context.config.force_no_cross_product; - // first try to solve the join order exactly - if (!SolveJoinOrderExactly()) { - // otherwise, if that times out we resort to a greedy algorithm - SolveJoinOrderApproximately(); - } - - // now the optimal join path should have been found - // get it from the node - unordered_set bindings; - for (idx_t i = 0; i < query_graph_manager.relation_manager.NumRelations(); i++) { - bindings.insert(i); - } - auto &total_relation = query_graph_manager.set_manager.GetJoinRelation(bindings); - auto final_plan = plans.find(total_relation); - if (final_plan == plans.end()) { - // could not find the final plan - // this should only happen in case the sets are actually disjunct - // in this case we need to generate cross product to connect the disjoint sets - if (force_no_cross_product) { - throw InvalidInputException( - "Query requires a cross-product, but 'force_no_cross_product' PRAGMA is enabled"); - } - GenerateCrossProducts(); - //! 
solve the join order again, returning the final plan - return SolveJoinOrder(); - } - return std::move(final_plan->second); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -using QueryEdge = QueryGraphEdges::QueryEdge; - -// LCOV_EXCL_START -static string QueryEdgeToString(const QueryEdge *info, vector prefix) { - string result = ""; - string source = "["; - for (idx_t i = 0; i < prefix.size(); i++) { - source += to_string(prefix[i]) + (i < prefix.size() - 1 ? ", " : ""); - } - source += "]"; - for (auto &entry : info->neighbors) { - result += StringUtil::Format("%s -> %s\n", source.c_str(), entry->neighbor->ToString().c_str()); - } - for (auto &entry : info->children) { - vector new_prefix = prefix; - new_prefix.push_back(entry.first); - result += QueryEdgeToString(entry.second.get(), new_prefix); - } - return result; -} - -string QueryGraphEdges::ToString() const { - return QueryEdgeToString(&root, {}); -} - -void QueryGraphEdges::Print() { - Printer::Print(ToString()); -} -// LCOV_EXCL_STOP - -optional_ptr QueryGraphEdges::GetQueryEdge(JoinRelationSet &left) { - D_ASSERT(left.count > 0); - // find the EdgeInfo corresponding to the left set - optional_ptr info(&root); - for (idx_t i = 0; i < left.count; i++) { - auto entry = info.get()->children.find(left.relations[i]); - if (entry == info.get()->children.end()) { - // node not found, create it - auto insert_it = info.get()->children.insert(make_pair(left.relations[i], make_uniq())); - entry = insert_it.first; - } - // move to the next node - info = entry->second; - } - return info; -} - -void QueryGraphEdges::CreateEdge(JoinRelationSet &left, JoinRelationSet &right, optional_ptr filter_info) { - D_ASSERT(left.count > 0 && right.count > 0); - // find the EdgeInfo corresponding to the left set - auto info = GetQueryEdge(left); - // now insert the edge to the right relation, if it does not exist - for (idx_t i = 0; i < info->neighbors.size(); i++) { - if (info->neighbors[i]->neighbor == &right) { - if 
(filter_info) { - // neighbor already exists just add the filter, if we have any - info->neighbors[i]->filters.push_back(filter_info); - } - return; - } - } - // neighbor does not exist, create it - auto n = make_uniq(&right); - // if the edge represents a cross product, filter_info is null. The easiest way then to determine - // if an edge is for a cross product is if the filters are empty - if (info && filter_info) { - n->filters.push_back(filter_info); - } - info->neighbors.push_back(std::move(n)); -} - -void QueryGraphEdges::EnumerateNeighborsDFS(JoinRelationSet &node, reference info, idx_t index, - const std::function &callback) const { - - for (auto &neighbor : info.get().neighbors) { - if (callback(*neighbor)) { - return; - } - } - - for (idx_t node_index = index; node_index < node.count; ++node_index) { - auto iter = info.get().children.find(node.relations[node_index]); - if (iter != info.get().children.end()) { - reference new_info = *iter->second; - EnumerateNeighborsDFS(node, new_info, node_index + 1, callback); - } - } -} - -void QueryGraphEdges::EnumerateNeighbors(JoinRelationSet &node, - const std::function &callback) const { - for (idx_t j = 0; j < node.count; j++) { - auto iter = root.children.find(node.relations[j]); - if (iter != root.children.end()) { - reference new_info = *iter->second; - EnumerateNeighborsDFS(node, new_info, j + 1, callback); - } - } -} - -//! 
Returns true if a JoinRelationSet is banned by the list of exclusion_set, false otherwise -static bool JoinRelationSetIsExcluded(optional_ptr node, unordered_set &exclusion_set) { - return exclusion_set.find(node->relations[0]) != exclusion_set.end(); -} - -const vector QueryGraphEdges::GetNeighbors(JoinRelationSet &node, unordered_set &exclusion_set) const { - unordered_set result; - EnumerateNeighbors(node, [&](NeighborInfo &info) -> bool { - if (!JoinRelationSetIsExcluded(info.neighbor, exclusion_set)) { - // add the smallest node of the neighbor to the set - result.insert(info.neighbor->relations[0]); - } - return false; - }); - vector neighbors; - neighbors.insert(neighbors.end(), result.begin(), result.end()); - return neighbors; -} - -const vector> QueryGraphEdges::GetConnections(JoinRelationSet &node, - JoinRelationSet &other) const { - vector> connections; - EnumerateNeighbors(node, [&](NeighborInfo &info) -> bool { - if (JoinRelationSet::IsSubset(other, *info.neighbor)) { - connections.push_back(info); - } - return false; - }); - return connections; -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -//! Returns true if A and B are disjoint, false otherwise -template -static bool Disjoint(const unordered_set &a, const unordered_set &b) { - return std::all_of(a.begin(), a.end(), [&b](typename std::unordered_set::const_reference entry) { - return b.find(entry) == b.end(); - }); -} - -bool QueryGraphManager::Build(LogicalOperator &op) { - vector> filter_operators; - // have the relation manager extract the join relations and create a reference list of all the - // filter operators. - auto can_reorder = relation_manager.ExtractJoinRelations(op, filter_operators); - auto num_relations = relation_manager.NumRelations(); - if (num_relations <= 1 || !can_reorder) { - // nothing to optimize/reorder - return false; - } - // extract the edges of the hypergraph, creating a list of filters and their associated bindings. 
- filters_and_bindings = relation_manager.ExtractEdges(op, filter_operators, set_manager); - // Create the query_graph hyper edges - CreateHyperGraphEdges(); - return true; -} - -void QueryGraphManager::GetColumnBinding(Expression &expression, ColumnBinding &binding) { - if (expression.type == ExpressionType::BOUND_COLUMN_REF) { - // Here you have a filter on a single column in a table. Return a binding for the column - // being filtered on so the filter estimator knows what HLL count to pull - auto &colref = expression.Cast(); - D_ASSERT(colref.depth == 0); - D_ASSERT(colref.binding.table_index != DConstants::INVALID_INDEX); - // map the base table index to the relation index used by the JoinOrderOptimizer - D_ASSERT(relation_manager.relation_mapping.find(colref.binding.table_index) != - relation_manager.relation_mapping.end()); - binding = - ColumnBinding(relation_manager.relation_mapping[colref.binding.table_index], colref.binding.column_index); - } - // TODO: handle inequality filters with functions. 
- ExpressionIterator::EnumerateChildren(expression, [&](Expression &expr) { GetColumnBinding(expr, binding); }); -} - -const vector> &QueryGraphManager::GetFilterBindings() const { - return filters_and_bindings; -} - -static unique_ptr PushFilter(unique_ptr node, unique_ptr expr) { - // push an expression into a filter - // first check if we have any filter to push it into - if (node->type != LogicalOperatorType::LOGICAL_FILTER) { - // we don't, we need to create one - auto filter = make_uniq(); - filter->children.push_back(std::move(node)); - node = std::move(filter); - } - // push the filter into the LogicalFilter - D_ASSERT(node->type == LogicalOperatorType::LOGICAL_FILTER); - auto &filter = node->Cast(); - filter.expressions.push_back(std::move(expr)); - return node; -} - -void QueryGraphManager::CreateHyperGraphEdges() { - // create potential edges from the comparisons - for (auto &filter_info : filters_and_bindings) { - auto &filter = filter_info->filter; - // now check if it can be used as a join predicate - if (filter->GetExpressionClass() == ExpressionClass::BOUND_COMPARISON) { - auto &comparison = filter->Cast(); - // extract the bindings that are required for the left and right side of the comparison - unordered_set left_bindings, right_bindings; - relation_manager.ExtractBindings(*comparison.left, left_bindings); - relation_manager.ExtractBindings(*comparison.right, right_bindings); - GetColumnBinding(*comparison.left, filter_info->left_binding); - GetColumnBinding(*comparison.right, filter_info->right_binding); - if (!left_bindings.empty() && !right_bindings.empty()) { - // both the left and the right side have bindings - // first create the relation sets, if they do not exist - filter_info->left_set = &set_manager.GetJoinRelation(left_bindings); - filter_info->right_set = &set_manager.GetJoinRelation(right_bindings); - // we can only create a meaningful edge if the sets are not exactly the same - if (filter_info->left_set != filter_info->right_set) { 
- // check if the sets are disjoint - if (Disjoint(left_bindings, right_bindings)) { - // they are disjoint, we only need to create one set of edges in the join graph - query_graph.CreateEdge(*filter_info->left_set, *filter_info->right_set, filter_info); - query_graph.CreateEdge(*filter_info->right_set, *filter_info->left_set, filter_info); - } else { - continue; - } - continue; - } - } - } - } -} - -static unique_ptr ExtractJoinRelation(unique_ptr &rel) { - auto &children = rel->parent->children; - for (idx_t i = 0; i < children.size(); i++) { - if (children[i].get() == &rel->op) { - // found it! take ownership o/**/f it from the parent - auto result = std::move(children[i]); - children.erase(children.begin() + i); - return result; - } - } - throw Exception("Could not find relation in parent node (?)"); -} - -unique_ptr QueryGraphManager::Reconstruct(unique_ptr plan, JoinNode &node) { - return RewritePlan(std::move(plan), node); -} - -GenerateJoinRelation QueryGraphManager::GenerateJoins(vector> &extracted_relations, - JoinNode &node) { - optional_ptr left_node; - optional_ptr right_node; - optional_ptr result_relation; - unique_ptr result_operator; - if (node.left && node.right && node.info) { - // generate the left and right children - auto left = GenerateJoins(extracted_relations, *node.left); - auto right = GenerateJoins(extracted_relations, *node.right); - - if (node.info->filters.empty()) { - // no filters, create a cross product - result_operator = LogicalCrossProduct::Create(std::move(left.op), std::move(right.op)); - } else { - // we have filters, create a join node - auto join = make_uniq(JoinType::INNER); - // Here we optimize build side probe side. Our build side is the right side - // So the right plans should have lower cardinalities. 
- join->children.push_back(std::move(left.op)); - join->children.push_back(std::move(right.op)); - - // set the join conditions from the join node - for (auto &filter_ref : node.info->filters) { - auto f = filter_ref.get(); - // extract the filter from the operator it originally belonged to - D_ASSERT(filters_and_bindings[f->filter_index]->filter); - auto &filter_and_binding = filters_and_bindings.at(f->filter_index); - auto condition = std::move(filter_and_binding->filter); - // now create the actual join condition - D_ASSERT((JoinRelationSet::IsSubset(*left.set, *f->left_set) && - JoinRelationSet::IsSubset(*right.set, *f->right_set)) || - (JoinRelationSet::IsSubset(*left.set, *f->right_set) && - JoinRelationSet::IsSubset(*right.set, *f->left_set))); - JoinCondition cond; - D_ASSERT(condition->GetExpressionClass() == ExpressionClass::BOUND_COMPARISON); - auto &comparison = condition->Cast(); - - // we need to figure out which side is which by looking at the relations available to us - bool invert = !JoinRelationSet::IsSubset(*left.set, *f->left_set); - cond.left = !invert ? std::move(comparison.left) : std::move(comparison.right); - cond.right = !invert ? 
std::move(comparison.right) : std::move(comparison.left); - cond.comparison = condition->type; - - if (invert) { - // reverse comparison expression if we reverse the order of the children - cond.comparison = FlipComparisonExpression(cond.comparison); - } - join->conditions.push_back(std::move(cond)); - } - D_ASSERT(!join->conditions.empty()); - result_operator = std::move(join); - } - left_node = left.set; - right_node = right.set; - result_relation = &set_manager.Union(*left.set, *right.set); - } else { - // base node, get the entry from the list of extracted relations - D_ASSERT(node.set.count == 1); - D_ASSERT(extracted_relations[node.set.relations[0]]); - result_relation = &node.set; - result_operator = std::move(extracted_relations[node.set.relations[0]]); - } - // TODO: this is where estimated properties start coming into play. - // when creating the result operator, we should ask the cost model and cardinality estimator what - // the cost and cardinality are - // result_operator->estimated_props = node.estimated_props->Copy(); - result_operator->estimated_cardinality = node.cardinality; - result_operator->has_estimated_cardinality = true; - if (result_operator->type == LogicalOperatorType::LOGICAL_FILTER && - result_operator->children[0]->type == LogicalOperatorType::LOGICAL_GET) { - // FILTER on top of GET, add estimated properties to both - // auto &filter_props = *result_operator->estimated_props; - auto &child_operator = *result_operator->children[0]; - child_operator.estimated_cardinality = node.cardinality; - child_operator.has_estimated_cardinality = true; - } - // check if we should do a pushdown on this node - // basically, any remaining filter that is a subset of the current relation will no longer be used in joins - // hence we should push it here - for (auto &filter_info : filters_and_bindings) { - // check if the filter has already been extracted - auto &info = *filter_info; - if (filters_and_bindings[info.filter_index]->filter) { - // now check 
if the filter is a subset of the current relation - // note that infos with an empty relation set are a special case and we do not push them down - if (info.set.count > 0 && JoinRelationSet::IsSubset(*result_relation, info.set)) { - auto &filter_and_binding = filters_and_bindings[info.filter_index]; - auto filter = std::move(filter_and_binding->filter); - // if it is, we can push the filter - // we can push it either into a join or as a filter - // check if we are in a join or in a base table - if (!left_node || !info.left_set) { - // base table or non-comparison expression, push it as a filter - result_operator = PushFilter(std::move(result_operator), std::move(filter)); - continue; - } - // the node below us is a join or cross product and the expression is a comparison - // check if the nodes can be split up into left/right - bool found_subset = false; - bool invert = false; - if (JoinRelationSet::IsSubset(*left_node, *info.left_set) && - JoinRelationSet::IsSubset(*right_node, *info.right_set)) { - found_subset = true; - } else if (JoinRelationSet::IsSubset(*right_node, *info.left_set) && - JoinRelationSet::IsSubset(*left_node, *info.right_set)) { - invert = true; - found_subset = true; - } - if (!found_subset) { - // could not be split up into left/right - result_operator = PushFilter(std::move(result_operator), std::move(filter)); - continue; - } - // create the join condition - JoinCondition cond; - D_ASSERT(filter->GetExpressionClass() == ExpressionClass::BOUND_COMPARISON); - auto &comparison = filter->Cast(); - // we need to figure out which side is which by looking at the relations available to us - cond.left = !invert ? std::move(comparison.left) : std::move(comparison.right); - cond.right = !invert ? 
std::move(comparison.right) : std::move(comparison.left); - cond.comparison = comparison.type; - if (invert) { - // reverse comparison expression if we reverse the order of the children - cond.comparison = FlipComparisonExpression(comparison.type); - } - // now find the join to push it into - auto node = result_operator.get(); - if (node->type == LogicalOperatorType::LOGICAL_FILTER) { - node = node->children[0].get(); - } - if (node->type == LogicalOperatorType::LOGICAL_CROSS_PRODUCT) { - // turn into comparison join - auto comp_join = make_uniq(JoinType::INNER); - comp_join->children.push_back(std::move(node->children[0])); - comp_join->children.push_back(std::move(node->children[1])); - comp_join->conditions.push_back(std::move(cond)); - if (node == result_operator.get()) { - result_operator = std::move(comp_join); - } else { - D_ASSERT(result_operator->type == LogicalOperatorType::LOGICAL_FILTER); - result_operator->children[0] = std::move(comp_join); - } - } else { - D_ASSERT(node->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || - node->type == LogicalOperatorType::LOGICAL_ASOF_JOIN); - auto &comp_join = node->Cast(); - comp_join.conditions.push_back(std::move(cond)); - } - } - } - } - auto result = GenerateJoinRelation(result_relation, std::move(result_operator)); - return result; -} - -const QueryGraphEdges &QueryGraphManager::GetQueryGraphEdges() const { - return query_graph; -} - -void QueryGraphManager::CreateQueryGraphCrossProduct(JoinRelationSet &left, JoinRelationSet &right) { - query_graph.CreateEdge(left, right, nullptr); - query_graph.CreateEdge(right, left, nullptr); -} - -unique_ptr QueryGraphManager::RewritePlan(unique_ptr plan, JoinNode &node) { - // now we have to rewrite the plan - bool root_is_join = plan->children.size() > 1; - - // first we will extract all relations from the main plan - vector> extracted_relations; - extracted_relations.reserve(relation_manager.NumRelations()); - for (auto &relation : 
relation_manager.GetRelations()) { - extracted_relations.push_back(ExtractJoinRelation(relation)); - } - - // now we generate the actual joins - auto join_tree = GenerateJoins(extracted_relations, node); - // perform the final pushdown of remaining filters - for (auto &filter : filters_and_bindings) { - // check if the filter has already been extracted - if (filter->filter) { - // if not we need to push it - join_tree.op = PushFilter(std::move(join_tree.op), std::move(filter->filter)); - } - } - - // find the first join in the relation to know where to place this node - if (root_is_join) { - // first node is the join, return it immediately - return std::move(join_tree.op); - } - D_ASSERT(plan->children.size() == 1); - // have to move up through the relations - auto op = plan.get(); - auto parent = plan.get(); - while (op->type != LogicalOperatorType::LOGICAL_CROSS_PRODUCT && - op->type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN && - op->type != LogicalOperatorType::LOGICAL_ASOF_JOIN) { - D_ASSERT(op->children.size() == 1); - parent = op; - op = op->children[0].get(); - } - // have to replace at this node - parent->children[0] = std::move(join_tree.op); - return plan; -} - -bool QueryGraphManager::LeftCardLessThanRight(LogicalOperator &op) { - D_ASSERT(op.children.size() == 2); - if (op.children[0]->has_estimated_cardinality && op.children[1]->has_estimated_cardinality) { - return op.children[0]->estimated_cardinality < op.children[1]->estimated_cardinality; - } - return op.children[0]->EstimateCardinality(context) < op.children[1]->EstimateCardinality(context); -} - -unique_ptr QueryGraphManager::LeftRightOptimizations(unique_ptr input_op) { - auto op = input_op.get(); - // pass through single child operators - while (!op->children.empty()) { - if (op->children.size() == 2) { - switch (op->type) { - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { - auto &join = op->Cast(); - if (join.join_type == JoinType::INNER) { - if (LeftCardLessThanRight(*op)) { - 
std::swap(op->children[0], op->children[1]); - for (auto &cond : join.conditions) { - std::swap(cond.left, cond.right); - cond.comparison = FlipComparisonExpression(cond.comparison); - } - } - } else if (join.join_type == JoinType::LEFT && join.right_projection_map.empty()) { - auto lhs_cardinality = join.children[0]->EstimateCardinality(context); - auto rhs_cardinality = join.children[1]->EstimateCardinality(context); - if (rhs_cardinality > lhs_cardinality * 2) { - join.join_type = JoinType::RIGHT; - std::swap(join.children[0], join.children[1]); - for (auto &cond : join.conditions) { - std::swap(cond.left, cond.right); - cond.comparison = FlipComparisonExpression(cond.comparison); - } - } - } - break; - } - case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: { - if (LeftCardLessThanRight(*op)) { - std::swap(op->children[0], op->children[1]); - } - break; - } - case LogicalOperatorType::LOGICAL_ANY_JOIN: { - auto &join = op->Cast(); - if (join.join_type == JoinType::LEFT && join.right_projection_map.empty()) { - auto lhs_cardinality = join.children[0]->EstimateCardinality(context); - auto rhs_cardinality = join.children[1]->EstimateCardinality(context); - if (rhs_cardinality > lhs_cardinality * 2) { - join.join_type = JoinType::RIGHT; - std::swap(join.children[0], join.children[1]); - } - } else if (join.join_type == JoinType::INNER && LeftCardLessThanRight(*op)) { - std::swap(join.children[0], join.children[1]); - } - break; - } - default: - break; - } - op->children[0] = LeftRightOptimizations(std::move(op->children[0])); - op->children[1] = LeftRightOptimizations(std::move(op->children[1])); - // break from while loop - break; - } - if (op->children.size() == 1) { - op = op->children[0].get(); - } - } - return input_op; -} - -} // namespace duckdb - - - - - - - - - - -#include - -namespace duckdb { - -const vector RelationManager::GetRelationStats() { - vector ret; - for (idx_t i = 0; i < relations.size(); i++) { - ret.push_back(relations[i]->stats); - } - return 
ret; -} - -vector> RelationManager::GetRelations() { - return std::move(relations); -} - -idx_t RelationManager::NumRelations() { - return relations.size(); -} - -void RelationManager::AddAggregateRelation(LogicalOperator &op, optional_ptr parent, - const RelationStats &stats) { - auto relation = make_uniq(op, parent, stats); - auto relation_id = relations.size(); - - auto table_indexes = op.GetTableIndex(); - for (auto &index : table_indexes) { - D_ASSERT(relation_mapping.find(index) == relation_mapping.end()); - relation_mapping[index] = relation_id; - } - relations.push_back(std::move(relation)); -} - -void RelationManager::AddRelation(LogicalOperator &op, optional_ptr parent, - const RelationStats &stats) { - - // if parent is null, then this is a root relation - // if parent is not null, it should have multiple children - D_ASSERT(!parent || parent->children.size() >= 2); - auto relation = make_uniq(op, parent, stats); - auto relation_id = relations.size(); - - auto table_indexes = op.GetTableIndex(); - if (table_indexes.empty()) { - // relation represents a non-reorderable relation, most likely a join relation - // Get the tables referenced in the non-reorderable relation and add them to the relation mapping - // This should all table references, even if there are nested non-reorderable joins. 
- unordered_set table_references; - LogicalJoin::GetTableReferences(op, table_references); - D_ASSERT(table_references.size() > 0); - for (auto &reference : table_references) { - D_ASSERT(relation_mapping.find(reference) == relation_mapping.end()); - relation_mapping[reference] = relation_id; - } - } else { - // Relations should never return more than 1 table index - D_ASSERT(table_indexes.size() == 1); - idx_t table_index = table_indexes.at(0); - D_ASSERT(relation_mapping.find(table_index) == relation_mapping.end()); - relation_mapping[table_index] = relation_id; - } - relations.push_back(std::move(relation)); -} - -static bool OperatorNeedsRelation(LogicalOperatorType op_type) { - switch (op_type) { - case LogicalOperatorType::LOGICAL_PROJECTION: - case LogicalOperatorType::LOGICAL_EXPRESSION_GET: - case LogicalOperatorType::LOGICAL_GET: - case LogicalOperatorType::LOGICAL_DELIM_GET: - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - case LogicalOperatorType::LOGICAL_WINDOW: - return true; - default: - return false; - } -} - -static bool OperatorIsNonReorderable(LogicalOperatorType op_type) { - switch (op_type) { - case LogicalOperatorType::LOGICAL_UNION: - case LogicalOperatorType::LOGICAL_EXCEPT: - case LogicalOperatorType::LOGICAL_INTERSECT: - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - case LogicalOperatorType::LOGICAL_ANY_JOIN: - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - return true; - default: - return false; - } -} - -static bool HasNonReorderableChild(LogicalOperator &op) { - LogicalOperator *tmp = &op; - while (tmp->children.size() == 1) { - if (OperatorNeedsRelation(tmp->type) || OperatorIsNonReorderable(tmp->type)) { - return true; - } - tmp = tmp->children[0].get(); - if (tmp->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { - auto &join = tmp->Cast(); - if (join.join_type != JoinType::INNER) { - return true; - } - } - } - return tmp->children.empty(); -} - -bool RelationManager::ExtractJoinRelations(LogicalOperator 
&input_op, - vector> &filter_operators, - optional_ptr parent) { - LogicalOperator *op = &input_op; - vector> datasource_filters; - // pass through single child operators - while (op->children.size() == 1 && !OperatorNeedsRelation(op->type)) { - if (op->type == LogicalOperatorType::LOGICAL_FILTER) { - if (HasNonReorderableChild(*op)) { - datasource_filters.push_back(*op); - } - filter_operators.push_back(*op); - } - if (op->type == LogicalOperatorType::LOGICAL_SHOW) { - return false; - } - op = op->children[0].get(); - } - bool non_reorderable_operation = false; - if (OperatorIsNonReorderable(op->type)) { - // set operation, optimize separately in children - non_reorderable_operation = true; - } - - if (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { - auto &join = op->Cast(); - if (join.join_type == JoinType::INNER) { - // extract join conditions from inner join - filter_operators.push_back(*op); - } else { - non_reorderable_operation = true; - } - } - if (non_reorderable_operation) { - // we encountered a non-reordable operation (setop or non-inner join) - // we do not reorder non-inner joins yet, however we do want to expand the potential join graph around them - // non-inner joins are also tricky because we can't freely make conditions through them - // e.g. suppose we have (left LEFT OUTER JOIN right WHERE right IS NOT NULL), the join can generate - // new NULL values in the right side, so pushing this condition through the join leads to incorrect results - // for this reason, we just start a new JoinOptimizer pass in each of the children of the join - // stats.cardinality will be initiated to highest cardinality of the children. 
- vector children_stats; - for (auto &child : op->children) { - auto stats = RelationStats(); - JoinOrderOptimizer optimizer(context); - child = optimizer.Optimize(std::move(child), &stats); - children_stats.push_back(stats); - } - - auto combined_stats = RelationStatisticsHelper::CombineStatsOfNonReorderableOperator(*op, children_stats); - if (!datasource_filters.empty()) { - combined_stats.cardinality = - (idx_t)MaxValue(combined_stats.cardinality * RelationStatisticsHelper::DEFAULT_SELECTIVITY, (double)1); - } - AddRelation(input_op, parent, combined_stats); - return true; - } - - switch (op->type) { - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { - // optimize children - RelationStats child_stats; - JoinOrderOptimizer optimizer(context); - op->children[0] = optimizer.Optimize(std::move(op->children[0]), &child_stats); - auto &aggr = op->Cast(); - auto operator_stats = RelationStatisticsHelper::ExtractAggregationStats(aggr, child_stats); - AddAggregateRelation(input_op, parent, operator_stats); - return true; - } - case LogicalOperatorType::LOGICAL_WINDOW: { - // optimize children - RelationStats child_stats; - JoinOrderOptimizer optimizer(context); - op->children[0] = optimizer.Optimize(std::move(op->children[0]), &child_stats); - auto &window = op->Cast(); - auto operator_stats = RelationStatisticsHelper::ExtractWindowStats(window, child_stats); - AddAggregateRelation(input_op, parent, operator_stats); - return true; - } - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: { - // Adding relations to the current join order optimizer - bool can_reorder_left = ExtractJoinRelations(*op->children[0], filter_operators, op); - bool can_reorder_right = ExtractJoinRelations(*op->children[1], filter_operators, op); - return can_reorder_left && can_reorder_right; - } - case LogicalOperatorType::LOGICAL_DUMMY_SCAN: { - auto &dummy_scan = op->Cast(); - auto stats = 
RelationStatisticsHelper::ExtractDummyScanStats(dummy_scan, context); - AddRelation(input_op, parent, stats); - return true; - } - case LogicalOperatorType::LOGICAL_EXPRESSION_GET: { - // base table scan, add to set of relations. - // create empty stats for dummy scan or logical expression get - auto &expression_get = op->Cast(); - auto stats = RelationStatisticsHelper::ExtractExpressionGetStats(expression_get, context); - AddRelation(input_op, parent, stats); - return true; - } - case LogicalOperatorType::LOGICAL_GET: { - // TODO: Get stats from a logical GET - auto &get = op->Cast(); - auto stats = RelationStatisticsHelper::ExtractGetStats(get, context); - // if there is another logical filter that could not be pushed down into the - // table scan, apply another selectivity. - if (!datasource_filters.empty()) { - stats.cardinality = - (idx_t)MaxValue(stats.cardinality * RelationStatisticsHelper::DEFAULT_SELECTIVITY, (double)1); - } - AddRelation(input_op, parent, stats); - return true; - } - case LogicalOperatorType::LOGICAL_DELIM_GET: { - auto &delim_get = op->Cast(); - auto stats = RelationStatisticsHelper::ExtractDelimGetStats(delim_get, context); - AddRelation(input_op, parent, stats); - return true; - } - case LogicalOperatorType::LOGICAL_PROJECTION: { - auto child_stats = RelationStats(); - // optimize the child and copy the stats - JoinOrderOptimizer optimizer(context); - op->children[0] = optimizer.Optimize(std::move(op->children[0]), &child_stats); - auto &proj = op->Cast(); - // Projection can create columns so we need to add them here - auto proj_stats = RelationStatisticsHelper::ExtractProjectionStats(proj, child_stats); - AddRelation(input_op, parent, proj_stats); - return true; - } - default: - return false; - } -} - -bool RelationManager::ExtractBindings(Expression &expression, unordered_set &bindings) { - if (expression.type == ExpressionType::BOUND_COLUMN_REF) { - auto &colref = expression.Cast(); - D_ASSERT(colref.depth == 0); - 
D_ASSERT(colref.binding.table_index != DConstants::INVALID_INDEX); - // map the base table index to the relation index used by the JoinOrderOptimizer - if (expression.alias == "SUBQUERY" && - relation_mapping.find(colref.binding.table_index) == relation_mapping.end()) { - // most likely a BoundSubqueryExpression that was created from an uncorrelated subquery - // Here we return true and don't fill the bindings, the expression can be reordered. - // A filter will be created using this expression, and pushed back on top of the parent - // operator during plan reconstruction - return true; - } - D_ASSERT(relation_mapping.find(colref.binding.table_index) != relation_mapping.end()); - bindings.insert(relation_mapping[colref.binding.table_index]); - } - if (expression.type == ExpressionType::BOUND_REF) { - // bound expression - bindings.clear(); - return false; - } - D_ASSERT(expression.type != ExpressionType::SUBQUERY); - bool can_reorder = true; - ExpressionIterator::EnumerateChildren(expression, [&](Expression &expr) { - if (!ExtractBindings(expr, bindings)) { - can_reorder = false; - return; - } - }); - return can_reorder; -} - -vector> RelationManager::ExtractEdges(LogicalOperator &op, - vector> &filter_operators, - JoinRelationSetManager &set_manager) { - // now that we know we are going to perform join ordering we actually extract the filters, eliminating duplicate - // filters in the process - vector> filters_and_bindings; - expression_set_t filter_set; - for (auto &filter_op : filter_operators) { - auto &f_op = filter_op.get(); - if (f_op.type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || - f_op.type == LogicalOperatorType::LOGICAL_ASOF_JOIN) { - auto &join = f_op.Cast(); - D_ASSERT(join.join_type == JoinType::INNER); - D_ASSERT(join.expressions.empty()); - for (auto &cond : join.conditions) { - auto comparison = - make_uniq(cond.comparison, std::move(cond.left), std::move(cond.right)); - if (filter_set.find(*comparison) == filter_set.end()) { - 
filter_set.insert(*comparison); - unordered_set bindings; - ExtractBindings(*comparison, bindings); - auto &set = set_manager.GetJoinRelation(bindings); - auto filter_info = make_uniq(std::move(comparison), set, filters_and_bindings.size()); - filters_and_bindings.push_back(std::move(filter_info)); - } - } - join.conditions.clear(); - } else { - for (auto &expression : f_op.expressions) { - if (filter_set.find(*expression) == filter_set.end()) { - filter_set.insert(*expression); - unordered_set bindings; - ExtractBindings(*expression, bindings); - auto &set = set_manager.GetJoinRelation(bindings); - auto filter_info = make_uniq(std::move(expression), set, filters_and_bindings.size()); - filters_and_bindings.push_back(std::move(filter_info)); - } - } - f_op.expressions.clear(); - } - } - - return filters_and_bindings; -} - -// LCOV_EXCL_START - -void RelationManager::PrintRelationStats() { -#ifdef DEBUG - string to_print; - for (idx_t i = 0; i < relations.size(); i++) { - auto &relation = relations.at(i); - auto &stats = relation->stats; - D_ASSERT(stats.column_names.size() == stats.column_distinct_count.size()); - for (idx_t i = 0; i < stats.column_names.size(); i++) { - to_print = stats.column_names.at(i) + " has estimated distinct count " + - to_string(stats.column_distinct_count.at(i).distinct_count); - Printer::Print(to_print); - } - to_print = stats.table_name + " has estimated cardinality " + to_string(stats.cardinality); - to_print += " and relation id " + to_string(i) + "\n"; - Printer::Print(to_print); - } -#endif -} - -// LCOV_EXCL_STOP - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -static ExpressionBinding GetChildColumnBinding(Expression &expr) { - auto ret = ExpressionBinding(); - switch (expr.expression_class) { - case ExpressionClass::BOUND_FUNCTION: { - // TODO: Other expression classes that can have 0 children? - auto &func = expr.Cast(); - // no children some sort of gen_random_uuid() or equivalent. 
- if (func.children.empty()) { - ret.found_expression = true; - ret.expression_is_constant = true; - return ret; - } - break; - } - case ExpressionClass::BOUND_COLUMN_REF: { - ret.found_expression = true; - auto &new_col_ref = expr.Cast(); - ret.child_binding = ColumnBinding(new_col_ref.binding.table_index, new_col_ref.binding.column_index); - return ret; - } - case ExpressionClass::BOUND_LAMBDA_REF: - case ExpressionClass::BOUND_CONSTANT: - case ExpressionClass::BOUND_DEFAULT: - case ExpressionClass::BOUND_PARAMETER: - case ExpressionClass::BOUND_REF: - ret.found_expression = true; - ret.expression_is_constant = true; - return ret; - default: - break; - } - ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &child) { - auto recursive_result = GetChildColumnBinding(*child); - if (recursive_result.found_expression) { - ret = recursive_result; - } - }); - // we didn't find a Bound Column Ref - return ret; -} - -RelationStats RelationStatisticsHelper::ExtractGetStats(LogicalGet &get, ClientContext &context) { - auto return_stats = RelationStats(); - - auto base_table_cardinality = get.EstimateCardinality(context); - auto cardinality_after_filters = base_table_cardinality; - unique_ptr column_statistics; - - auto table_thing = get.GetTable(); - auto name = string("some table"); - if (table_thing) { - name = table_thing->name; - return_stats.table_name = name; - } - - // if we can get the catalog table, then our column statistics will be accurate - // parquet readers etc. will still return statistics, but they initialize distinct column - // counts to 0. - // TODO: fix this, some file formats can encode distinct counts, we don't want to rely on - // getting a catalog table to know that we can use statistics. - bool have_catalog_table_statistics = false; - if (get.GetTable()) { - have_catalog_table_statistics = true; - } - - // first push back basic distinct counts for each column (if we have them). 
- for (idx_t i = 0; i < get.column_ids.size(); i++) { - bool have_distinct_count_stats = false; - if (get.function.statistics) { - column_statistics = get.function.statistics(context, get.bind_data.get(), get.column_ids[i]); - if (column_statistics && have_catalog_table_statistics) { - auto column_distinct_count = DistinctCount({column_statistics->GetDistinctCount(), true}); - return_stats.column_distinct_count.push_back(column_distinct_count); - return_stats.column_names.push_back(name + "." + get.names.at(get.column_ids.at(i))); - have_distinct_count_stats = true; - } - } - if (!have_distinct_count_stats) { - // currently treating the cardinality as the distinct count. - // the cardinality estimator will update these distinct counts based - // on the extra columns that are joined on. - auto column_distinct_count = DistinctCount({cardinality_after_filters, false}); - return_stats.column_distinct_count.push_back(column_distinct_count); - auto column_name = string("column"); - if (get.column_ids.at(i) < get.names.size()) { - column_name = get.names.at(get.column_ids.at(i)); - } - return_stats.column_names.push_back(get.GetName() + "." 
+ column_name); - } - } - - if (!get.table_filters.filters.empty()) { - column_statistics = nullptr; - for (auto &it : get.table_filters.filters) { - if (get.bind_data && get.function.name.compare("seq_scan") == 0) { - auto &table_scan_bind_data = get.bind_data->Cast(); - column_statistics = get.function.statistics(context, &table_scan_bind_data, it.first); - } - - if (column_statistics && it.second->filter_type == TableFilterType::CONJUNCTION_AND) { - auto &filter = it.second->Cast(); - idx_t cardinality_with_and_filter = RelationStatisticsHelper::InspectConjunctionAND( - base_table_cardinality, it.first, filter, *column_statistics); - cardinality_after_filters = MinValue(cardinality_after_filters, cardinality_with_and_filter); - } - } - // if the above code didn't find an equality filter (i.e country_code = "[us]") - // and there are other table filters (i.e cost > 50), use default selectivity. - bool has_equality_filter = (cardinality_after_filters != base_table_cardinality); - if (!has_equality_filter && !get.table_filters.filters.empty()) { - cardinality_after_filters = - MaxValue(base_table_cardinality * RelationStatisticsHelper::DEFAULT_SELECTIVITY, 1); - } - if (base_table_cardinality == 0) { - cardinality_after_filters = 0; - } - } - return_stats.cardinality = cardinality_after_filters; - // update the estimated cardinality of the get as well. - // This is not updated during plan reconstruction. 
- get.estimated_cardinality = cardinality_after_filters; - get.has_estimated_cardinality = true; - D_ASSERT(base_table_cardinality >= cardinality_after_filters); - return_stats.stats_initialized = true; - return return_stats; -} - -RelationStats RelationStatisticsHelper::ExtractDelimGetStats(LogicalDelimGet &delim_get, ClientContext &context) { - RelationStats stats; - stats.table_name = delim_get.GetName(); - idx_t card = delim_get.EstimateCardinality(context); - stats.cardinality = card; - stats.stats_initialized = true; - for (auto &binding : delim_get.GetColumnBindings()) { - stats.column_distinct_count.push_back(DistinctCount({1, false})); - stats.column_names.push_back("column" + to_string(binding.column_index)); - } - return stats; -} - -RelationStats RelationStatisticsHelper::ExtractProjectionStats(LogicalProjection &proj, RelationStats &child_stats) { - auto proj_stats = RelationStats(); - proj_stats.cardinality = child_stats.cardinality; - proj_stats.table_name = proj.GetName(); - for (auto &expr : proj.expressions) { - proj_stats.column_names.push_back(expr->GetName()); - auto res = GetChildColumnBinding(*expr); - D_ASSERT(res.found_expression); - if (res.expression_is_constant) { - proj_stats.column_distinct_count.push_back(DistinctCount({1, true})); - } else { - auto column_index = res.child_binding.column_index; - if (column_index >= child_stats.column_distinct_count.size() && expr->ToString() == "count_star()") { - // only one value for a count star - proj_stats.column_distinct_count.push_back(DistinctCount({1, true})); - } else { - // TODO: add this back in - // D_ASSERT(column_index < stats.column_distinct_count.size()); - if (column_index < child_stats.column_distinct_count.size()) { - proj_stats.column_distinct_count.push_back(child_stats.column_distinct_count.at(column_index)); - } else { - proj_stats.column_distinct_count.push_back(DistinctCount({proj_stats.cardinality, false})); - } - } - } - } - proj_stats.stats_initialized = true; - return 
proj_stats; -} - -RelationStats RelationStatisticsHelper::ExtractDummyScanStats(LogicalDummyScan &dummy_scan, ClientContext &context) { - auto stats = RelationStats(); - idx_t card = dummy_scan.EstimateCardinality(context); - stats.cardinality = card; - for (idx_t i = 0; i < dummy_scan.GetColumnBindings().size(); i++) { - stats.column_distinct_count.push_back(DistinctCount({card, false})); - stats.column_names.push_back("dummy_scan_column"); - } - stats.stats_initialized = true; - stats.table_name = "dummy scan"; - return stats; -} - -void RelationStatisticsHelper::CopyRelationStats(RelationStats &to, const RelationStats &from) { - to.column_distinct_count = from.column_distinct_count; - to.column_names = from.column_names; - to.cardinality = from.cardinality; - to.table_name = from.table_name; - to.stats_initialized = from.stats_initialized; -} - -RelationStats RelationStatisticsHelper::CombineStatsOfReorderableOperator(vector &bindings, - vector relation_stats) { - RelationStats stats; - idx_t max_card = 0; - for (auto &child_stats : relation_stats) { - for (idx_t i = 0; i < child_stats.column_distinct_count.size(); i++) { - stats.column_distinct_count.push_back(child_stats.column_distinct_count.at(i)); - stats.column_names.push_back(child_stats.column_names.at(i)); - } - stats.table_name += "joined with " + child_stats.table_name; - max_card = MaxValue(max_card, child_stats.cardinality); - } - stats.stats_initialized = true; - stats.cardinality = max_card; - return stats; -} - -RelationStats RelationStatisticsHelper::CombineStatsOfNonReorderableOperator(LogicalOperator &op, - vector child_stats) { - D_ASSERT(child_stats.size() == 2); - RelationStats ret; - idx_t child_1_card = child_stats[0].stats_initialized ? child_stats[0].cardinality : 0; - idx_t child_2_card = child_stats[1].stats_initialized ? 
child_stats[1].cardinality : 0; - ret.cardinality = MaxValue(child_1_card, child_2_card); - ret.stats_initialized = true; - ret.filter_strength = 1; - ret.table_name = child_stats[0].table_name + " joined with " + child_stats[1].table_name; - for (auto &stats : child_stats) { - // MARK joins are nonreorderable. They won't return initialized stats - // continue in this case. - if (!stats.stats_initialized) { - continue; - } - for (auto &distinct_count : stats.column_distinct_count) { - ret.column_distinct_count.push_back(distinct_count); - } - for (auto &column_name : stats.column_names) { - ret.column_names.push_back(column_name); - } - } - return ret; -} - -RelationStats RelationStatisticsHelper::ExtractExpressionGetStats(LogicalExpressionGet &expression_get, - ClientContext &context) { - auto stats = RelationStats(); - idx_t card = expression_get.EstimateCardinality(context); - stats.cardinality = card; - for (idx_t i = 0; i < expression_get.GetColumnBindings().size(); i++) { - stats.column_distinct_count.push_back(DistinctCount({card, false})); - stats.column_names.push_back("expression_get_column"); - } - stats.stats_initialized = true; - stats.table_name = "expression_get"; - return stats; -} - -RelationStats RelationStatisticsHelper::ExtractWindowStats(LogicalWindow &window, RelationStats &child_stats) { - RelationStats stats; - stats.cardinality = child_stats.cardinality; - stats.column_distinct_count = child_stats.column_distinct_count; - stats.column_names = child_stats.column_names; - stats.stats_initialized = true; - auto num_child_columns = window.GetColumnBindings().size(); - - for (idx_t column_index = child_stats.column_distinct_count.size(); column_index < num_child_columns; - column_index++) { - stats.column_distinct_count.push_back(DistinctCount({child_stats.cardinality, false})); - stats.column_names.push_back("window"); - } - return stats; -} - -RelationStats RelationStatisticsHelper::ExtractAggregationStats(LogicalAggregate &aggr, RelationStats 
&child_stats) { - RelationStats stats; - // TODO: look at child distinct count to better estimate cardinality. - stats.cardinality = child_stats.cardinality; - stats.column_distinct_count = child_stats.column_distinct_count; - stats.column_names = child_stats.column_names; - stats.stats_initialized = true; - auto num_child_columns = aggr.GetColumnBindings().size(); - - for (idx_t column_index = child_stats.column_distinct_count.size(); column_index < num_child_columns; - column_index++) { - stats.column_distinct_count.push_back(DistinctCount({child_stats.cardinality, false})); - stats.column_names.push_back("aggregate"); - } - return stats; -} - -idx_t RelationStatisticsHelper::InspectConjunctionAND(idx_t cardinality, idx_t column_index, - ConjunctionAndFilter &filter, BaseStatistics &base_stats) { - auto cardinality_after_filters = cardinality; - for (auto &child_filter : filter.child_filters) { - if (child_filter->filter_type != TableFilterType::CONSTANT_COMPARISON) { - continue; - } - auto &comparison_filter = child_filter->Cast(); - if (comparison_filter.comparison_type != ExpressionType::COMPARE_EQUAL) { - continue; - } - auto column_count = base_stats.GetDistinctCount(); - auto filtered_card = cardinality; - // column_count = 0 when there is no column count (i.e parquet scans) - if (column_count > 0) { - // we want the ceil of cardinality/column_count. We also want to avoid compiler errors - filtered_card = (cardinality + column_count - 1) / column_count; - cardinality_after_filters = filtered_card; - } - } - return cardinality_after_filters; -} - -// TODO: Currently only simple AND filters are pushed into table scans. 
-// When OR filters are pushed this function can be added -// idx_t RelationStatisticsHelper::InspectConjunctionOR(idx_t cardinality, idx_t column_index, ConjunctionOrFilter -// &filter, -// BaseStatistics &base_stats) { -// auto has_equality_filter = false; -// auto cardinality_after_filters = cardinality; -// for (auto &child_filter : filter.child_filters) { -// if (child_filter->filter_type != TableFilterType::CONSTANT_COMPARISON) { -// continue; -// } -// auto &comparison_filter = child_filter->Cast(); -// if (comparison_filter.comparison_type == ExpressionType::COMPARE_EQUAL) { -// auto column_count = base_stats.GetDistinctCount(); -// auto increment = MaxValue(((cardinality + column_count - 1) / column_count), 1); -// if (has_equality_filter) { -// cardinality_after_filters += increment; -// } else { -// cardinality_after_filters = increment; -// } -// has_equality_filter = true; -// } -// if (child_filter->filter_type == TableFilterType::CONJUNCTION_AND) { -// auto &and_filter = child_filter->Cast(); -// cardinality_after_filters = RelationStatisticsHelper::InspectConjunctionAND( -// cardinality_after_filters, column_index, and_filter, base_stats); -// continue; -// } -// } -// D_ASSERT(cardinality_after_filters > 0); -// return cardinality_after_filters; -//} - -} // namespace duckdb - - - - -namespace duckdb { - -bool ExpressionMatcher::Match(Expression &expr, vector> &bindings) { - if (type && !type->Match(expr.return_type)) { - return false; - } - if (expr_type && !expr_type->Match(expr.type)) { - return false; - } - if (expr_class != ExpressionClass::INVALID && expr_class != expr.GetExpressionClass()) { - return false; - } - bindings.push_back(expr); - return true; -} - -bool ExpressionEqualityMatcher::Match(Expression &expr, vector> &bindings) { - if (!expr.Equals(expression)) { - return false; - } - bindings.push_back(expr); - return true; -} - -bool CaseExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { - if 
(!ExpressionMatcher::Match(expr_p, bindings)) { - return false; - } - return true; -} - -bool ComparisonExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { - if (!ExpressionMatcher::Match(expr_p, bindings)) { - return false; - } - auto &expr = expr_p.Cast(); - vector> expressions; - expressions.push_back(*expr.left); - expressions.push_back(*expr.right); - return SetMatcher::Match(matchers, expressions, bindings, policy); -} - -bool CastExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { - if (!ExpressionMatcher::Match(expr_p, bindings)) { - return false; - } - if (!matcher) { - return true; - } - auto &expr = expr_p.Cast(); - return matcher->Match(*expr.child, bindings); -} - -bool InClauseExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { - if (!ExpressionMatcher::Match(expr_p, bindings)) { - return false; - } - auto &expr = expr_p.Cast(); - if (expr.type != ExpressionType::COMPARE_IN || expr.type == ExpressionType::COMPARE_NOT_IN) { - return false; - } - return SetMatcher::Match(matchers, expr.children, bindings, policy); -} - -bool ConjunctionExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { - if (!ExpressionMatcher::Match(expr_p, bindings)) { - return false; - } - auto &expr = expr_p.Cast(); - if (!SetMatcher::Match(matchers, expr.children, bindings, policy)) { - return false; - } - return true; -} - -bool FunctionExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { - if (!ExpressionMatcher::Match(expr_p, bindings)) { - return false; - } - auto &expr = expr_p.Cast(); - if (!FunctionMatcher::Match(function, expr.function.name)) { - return false; - } - if (!SetMatcher::Match(matchers, expr.children, bindings, policy)) { - return false; - } - return true; -} - -bool FoldableConstantMatcher::Match(Expression &expr, vector> &bindings) { - // we match on ANY expression that is a scalar expression - if (!expr.IsFoldable()) { - return false; - } - bindings.push_back(expr); - return true; -} - -} 
// namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -Optimizer::Optimizer(Binder &binder, ClientContext &context) : context(context), binder(binder), rewriter(context) { - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - rewriter.rules.push_back(make_uniq(rewriter)); - -#ifdef DEBUG - for (auto &rule : rewriter.rules) { - // root not defined in rule - D_ASSERT(rule->root); - } -#endif -} - -ClientContext &Optimizer::GetContext() { - return context; -} - -void Optimizer::RunOptimizer(OptimizerType type, const std::function &callback) { - auto &config = DBConfig::GetConfig(context); - if (config.options.disabled_optimizers.find(type) != config.options.disabled_optimizers.end()) { - // optimizer is marked as disabled: skip - return; - } - auto &profiler = QueryProfiler::Get(context); - profiler.StartPhase(OptimizerTypeToString(type)); - callback(); - profiler.EndPhase(); - if (plan) { - Verify(*plan); - } -} - -void Optimizer::Verify(LogicalOperator &op) { - ColumnBindingResolver::Verify(op); -} - -unique_ptr Optimizer::Optimize(unique_ptr plan_p) { - Verify(*plan_p); - - switch (plan_p->type) { - case LogicalOperatorType::LOGICAL_TRANSACTION: - return plan_p; // skip optimizing simple & often-occurring plans unaffected by rewrites - default: - break; - } - - this->plan = std::move(plan_p); - // 
first we perform expression rewrites using the ExpressionRewriter - // this does not change the logical plan structure, but only simplifies the expression trees - RunOptimizer(OptimizerType::EXPRESSION_REWRITER, [&]() { rewriter.VisitOperator(*plan); }); - - // perform filter pullup - RunOptimizer(OptimizerType::FILTER_PULLUP, [&]() { - FilterPullup filter_pullup; - plan = filter_pullup.Rewrite(std::move(plan)); - }); - - // perform filter pushdown - RunOptimizer(OptimizerType::FILTER_PUSHDOWN, [&]() { - FilterPushdown filter_pushdown(*this); - plan = filter_pushdown.Rewrite(std::move(plan)); - }); - - RunOptimizer(OptimizerType::REGEX_RANGE, [&]() { - RegexRangeFilter regex_opt; - plan = regex_opt.Rewrite(std::move(plan)); - }); - - RunOptimizer(OptimizerType::IN_CLAUSE, [&]() { - InClauseRewriter ic_rewriter(context, *this); - plan = ic_rewriter.Rewrite(std::move(plan)); - }); - - // removes any redundant DelimGets/DelimJoins - RunOptimizer(OptimizerType::DELIMINATOR, [&]() { - Deliminator deliminator; - plan = deliminator.Optimize(std::move(plan)); - }); - - // then we perform the join ordering optimization - // this also rewrites cross products + filters into joins and performs filter pushdowns - RunOptimizer(OptimizerType::JOIN_ORDER, [&]() { - JoinOrderOptimizer optimizer(context); - plan = optimizer.Optimize(std::move(plan)); - }); - - // rewrites UNNESTs in DelimJoins by moving them to the projection - RunOptimizer(OptimizerType::UNNEST_REWRITER, [&]() { - UnnestRewriter unnest_rewriter; - plan = unnest_rewriter.Optimize(std::move(plan)); - }); - - // removes unused columns - RunOptimizer(OptimizerType::UNUSED_COLUMNS, [&]() { - RemoveUnusedColumns unused(binder, context, true); - unused.VisitOperator(*plan); - }); - - // Remove duplicate groups from aggregates - RunOptimizer(OptimizerType::DUPLICATE_GROUPS, [&]() { - RemoveDuplicateGroups remove; - remove.VisitOperator(*plan); - }); - - // then we extract common subexpressions inside the different 
operators - RunOptimizer(OptimizerType::COMMON_SUBEXPRESSIONS, [&]() { - CommonSubExpressionOptimizer cse_optimizer(binder); - cse_optimizer.VisitOperator(*plan); - }); - - // creates projection maps so unused columns are projected out early - RunOptimizer(OptimizerType::COLUMN_LIFETIME, [&]() { - ColumnLifetimeAnalyzer column_lifetime(true); - column_lifetime.VisitOperator(*plan); - }); - - // perform statistics propagation - column_binding_map_t> statistics_map; - RunOptimizer(OptimizerType::STATISTICS_PROPAGATION, [&]() { - StatisticsPropagator propagator(*this); - propagator.PropagateStatistics(plan); - statistics_map = propagator.GetStatisticsMap(); - }); - - // remove duplicate aggregates - RunOptimizer(OptimizerType::COMMON_AGGREGATE, [&]() { - CommonAggregateOptimizer common_aggregate; - common_aggregate.VisitOperator(*plan); - }); - - // creates projection maps so unused columns are projected out early - RunOptimizer(OptimizerType::COLUMN_LIFETIME, [&]() { - ColumnLifetimeAnalyzer column_lifetime(true); - column_lifetime.VisitOperator(*plan); - }); - - // compress data based on statistics for materializing operators - RunOptimizer(OptimizerType::COMPRESSED_MATERIALIZATION, [&]() { - CompressedMaterialization compressed_materialization(context, binder, std::move(statistics_map)); - compressed_materialization.Compress(plan); - }); - - // transform ORDER BY + LIMIT to TopN - RunOptimizer(OptimizerType::TOP_N, [&]() { - TopN topn; - plan = topn.Optimize(std::move(plan)); - }); - - // apply simple expression heuristics to get an initial reordering - RunOptimizer(OptimizerType::REORDER_FILTER, [&]() { - ExpressionHeuristics expression_heuristics(*this); - plan = expression_heuristics.Rewrite(std::move(plan)); - }); - - for (auto &optimizer_extension : DBConfig::GetConfig(context).optimizer_extensions) { - RunOptimizer(OptimizerType::EXTENSION, [&]() { - optimizer_extension.optimize_function(context, optimizer_extension.optimizer_info.get(), plan); - }); - } - - 
Planner::VerifyPlan(context, plan); - - return std::move(plan); -} - -} // namespace duckdb - - -namespace duckdb { - -unique_ptr FilterPullup::PullupBothSide(unique_ptr op) { - FilterPullup left_pullup(true, can_add_column); - FilterPullup right_pullup(true, can_add_column); - op->children[0] = left_pullup.Rewrite(std::move(op->children[0])); - op->children[1] = right_pullup.Rewrite(std::move(op->children[1])); - D_ASSERT(left_pullup.can_add_column == can_add_column); - D_ASSERT(right_pullup.can_add_column == can_add_column); - - // merging filter expressions - for (idx_t i = 0; i < right_pullup.filters_expr_pullup.size(); ++i) { - left_pullup.filters_expr_pullup.push_back(std::move(right_pullup.filters_expr_pullup[i])); - } - - if (!left_pullup.filters_expr_pullup.empty()) { - return GeneratePullupFilter(std::move(op), left_pullup.filters_expr_pullup); - } - return op; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr FilterPullup::PullupFilter(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_FILTER); - - auto &filter = op->Cast(); - if (can_pullup && filter.projection_map.empty()) { - unique_ptr child = std::move(op->children[0]); - child = Rewrite(std::move(child)); - // moving filter's expressions - for (idx_t i = 0; i < op->expressions.size(); ++i) { - filters_expr_pullup.push_back(std::move(op->expressions[i])); - } - return child; - } - op->children[0] = Rewrite(std::move(op->children[0])); - return op; -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr FilterPullup::PullupFromLeft(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || - op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN || op->type == LogicalOperatorType::LOGICAL_ANY_JOIN || - op->type == LogicalOperatorType::LOGICAL_EXCEPT || op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN); - - FilterPullup left_pullup(true, can_add_column); - FilterPullup right_pullup(false, can_add_column); - - 
op->children[0] = left_pullup.Rewrite(std::move(op->children[0])); - op->children[1] = right_pullup.Rewrite(std::move(op->children[1])); - - // check only for filters from the LHS - if (!left_pullup.filters_expr_pullup.empty() && right_pullup.filters_expr_pullup.empty()) { - return GeneratePullupFilter(std::move(op), left_pullup.filters_expr_pullup); - } - return op; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -static void RevertFilterPullup(LogicalProjection &proj, vector> &expressions) { - unique_ptr filter = make_uniq(); - for (idx_t i = 0; i < expressions.size(); ++i) { - filter->expressions.push_back(std::move(expressions[i])); - } - expressions.clear(); - filter->children.push_back(std::move(proj.children[0])); - proj.children[0] = std::move(filter); -} - -static void ReplaceExpressionBinding(vector> &proj_expressions, Expression &expr, - idx_t proj_table_idx) { - if (expr.type == ExpressionType::BOUND_COLUMN_REF) { - bool found_proj_col = false; - BoundColumnRefExpression &colref = expr.Cast(); - // find the corresponding column index in the projection expressions - for (idx_t proj_idx = 0; proj_idx < proj_expressions.size(); proj_idx++) { - auto &proj_expr = *proj_expressions[proj_idx]; - if (proj_expr.type == ExpressionType::BOUND_COLUMN_REF) { - if (colref.Equals(proj_expr)) { - colref.binding.table_index = proj_table_idx; - colref.binding.column_index = proj_idx; - found_proj_col = true; - break; - } - } - } - if (!found_proj_col) { - // Project a new column - auto new_colref = colref.Copy(); - colref.binding.table_index = proj_table_idx; - colref.binding.column_index = proj_expressions.size(); - proj_expressions.push_back(std::move(new_colref)); - } - } - ExpressionIterator::EnumerateChildren( - expr, [&](Expression &child) { return ReplaceExpressionBinding(proj_expressions, child, proj_table_idx); }); -} - -void FilterPullup::ProjectSetOperation(LogicalProjection &proj) { - vector> copy_proj_expressions; - // copying the project 
expressions, it's useful whether we should revert the filter pullup - for (idx_t i = 0; i < proj.expressions.size(); ++i) { - copy_proj_expressions.push_back(proj.expressions[i]->Copy()); - } - - // Replace filter expression bindings, when need we add new columns into the copied projection expression - vector> changed_filter_expressions; - for (idx_t i = 0; i < filters_expr_pullup.size(); ++i) { - auto copy_filter_expr = filters_expr_pullup[i]->Copy(); - ReplaceExpressionBinding(copy_proj_expressions, (Expression &)*copy_filter_expr, proj.table_index); - changed_filter_expressions.push_back(std::move(copy_filter_expr)); - } - - /// Case new columns were added into the projection - // we must skip filter pullup because adding new columns to these operators will change the result - if (copy_proj_expressions.size() > proj.expressions.size()) { - RevertFilterPullup(proj, filters_expr_pullup); - return; - } - - // now we must replace the filter bindings - D_ASSERT(filters_expr_pullup.size() == changed_filter_expressions.size()); - for (idx_t i = 0; i < filters_expr_pullup.size(); ++i) { - filters_expr_pullup[i] = std::move(changed_filter_expressions[i]); - } -} - -unique_ptr FilterPullup::PullupProjection(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_PROJECTION); - op->children[0] = Rewrite(std::move(op->children[0])); - if (!filters_expr_pullup.empty()) { - auto &proj = op->Cast(); - // INTERSECT, EXCEPT, and DISTINCT - if (!can_add_column) { - // special treatment for operators that cannot add columns, e.g., INTERSECT, EXCEPT, and DISTINCT - ProjectSetOperation(proj); - return op; - } - - for (idx_t i = 0; i < filters_expr_pullup.size(); ++i) { - auto &expr = (Expression &)*filters_expr_pullup[i]; - ReplaceExpressionBinding(proj.expressions, expr, proj.table_index); - } - } - return op; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -static void ReplaceFilterTableIndex(Expression &expr, LogicalSetOperation &setop) { - if (expr.type 
== ExpressionType::BOUND_COLUMN_REF) { - auto &colref = expr.Cast(); - D_ASSERT(colref.depth == 0); - - colref.binding.table_index = setop.table_index; - return; - } - ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { ReplaceFilterTableIndex(child, setop); }); -} - -unique_ptr FilterPullup::PullupSetOperation(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_INTERSECT || op->type == LogicalOperatorType::LOGICAL_EXCEPT); - can_add_column = false; - can_pullup = true; - if (op->type == LogicalOperatorType::LOGICAL_INTERSECT) { - op = PullupBothSide(std::move(op)); - } else { - // EXCEPT only pull ups from LHS - op = PullupFromLeft(std::move(op)); - } - if (op->type == LogicalOperatorType::LOGICAL_FILTER) { - auto &filter = op->Cast(); - auto &setop = filter.children[0]->Cast(); - for (idx_t i = 0; i < filter.expressions.size(); ++i) { - ReplaceFilterTableIndex(*filter.expressions[i], setop); - } - } - return op; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -using Filter = FilterPushdown::Filter; - -static void ExtractFilterBindings(Expression &expr, vector &bindings) { - if (expr.type == ExpressionType::BOUND_COLUMN_REF) { - auto &colref = expr.Cast(); - bindings.push_back(colref.binding); - } - ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { ExtractFilterBindings(child, bindings); }); -} - -static unique_ptr ReplaceGroupBindings(LogicalAggregate &proj, unique_ptr expr) { - if (expr->type == ExpressionType::BOUND_COLUMN_REF) { - auto &colref = expr->Cast(); - D_ASSERT(colref.binding.table_index == proj.group_index); - D_ASSERT(colref.binding.column_index < proj.groups.size()); - D_ASSERT(colref.depth == 0); - // replace the binding with a copy to the expression at the referenced index - return proj.groups[colref.binding.column_index]->Copy(); - } - ExpressionIterator::EnumerateChildren( - *expr, [&](unique_ptr &child) { child = ReplaceGroupBindings(proj, std::move(child)); }); - return 
expr; -} - -unique_ptr FilterPushdown::PushdownAggregate(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY); - auto &aggr = op->Cast(); - - // pushdown into AGGREGATE and GROUP BY - // we cannot push expressions that refer to the aggregate - FilterPushdown child_pushdown(optimizer); - for (idx_t i = 0; i < filters.size(); i++) { - auto &f = *filters[i]; - if (f.bindings.find(aggr.aggregate_index) != f.bindings.end()) { - // filter on aggregate: cannot pushdown - continue; - } - if (f.bindings.find(aggr.groupings_index) != f.bindings.end()) { - // filter on GROUPINGS function: cannot pushdown - continue; - } - // no aggregate! we are filtering on a group - // we can only push this down if the filter is in all grouping sets - vector bindings; - ExtractFilterBindings(*f.filter, bindings); - - bool can_pushdown_filter = true; - if (aggr.grouping_sets.empty()) { - // empty grouping set - we cannot pushdown the filter - can_pushdown_filter = false; - } - for (auto &grp : aggr.grouping_sets) { - // check for each of the grouping sets if they contain all groups - if (bindings.empty()) { - // we can never push down empty grouping sets - can_pushdown_filter = false; - break; - } - for (auto &binding : bindings) { - if (grp.find(binding.column_index) == grp.end()) { - can_pushdown_filter = false; - break; - } - } - if (!can_pushdown_filter) { - break; - } - } - if (!can_pushdown_filter) { - continue; - } - // no aggregate! 
we can push this down - // rewrite any group bindings within the filter - f.filter = ReplaceGroupBindings(aggr, std::move(f.filter)); - // add the filter to the child node - if (child_pushdown.AddFilter(std::move(f.filter)) == FilterResult::UNSATISFIABLE) { - // filter statically evaluates to false, strip tree - return make_uniq(std::move(op)); - } - // erase the filter from here - filters.erase(filters.begin() + i); - i--; - } - child_pushdown.GenerateFilters(); - - op->children[0] = child_pushdown.Rewrite(std::move(op->children[0])); - return FinishPushdown(std::move(op)); -} - -} // namespace duckdb - - - - -namespace duckdb { - -using Filter = FilterPushdown::Filter; - -unique_ptr FilterPushdown::PushdownCrossProduct(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_CROSS_PRODUCT); - FilterPushdown left_pushdown(optimizer), right_pushdown(optimizer); - vector> join_expressions; - unordered_set left_bindings, right_bindings; - if (!filters.empty()) { - // check to see into which side we should push the filters - // first get the LHS and RHS bindings - LogicalJoin::GetTableReferences(*op->children[0], left_bindings); - LogicalJoin::GetTableReferences(*op->children[1], right_bindings); - // now check the set of filters - for (auto &f : filters) { - auto side = JoinSide::GetJoinSide(f->bindings, left_bindings, right_bindings); - if (side == JoinSide::LEFT) { - // bindings match left side: push into left - left_pushdown.filters.push_back(std::move(f)); - } else if (side == JoinSide::RIGHT) { - // bindings match right side: push into right - right_pushdown.filters.push_back(std::move(f)); - } else { - D_ASSERT(side == JoinSide::BOTH || side == JoinSide::NONE); - // bindings match both: turn into join condition - join_expressions.push_back(std::move(f->filter)); - } - } - } - - op->children[0] = left_pushdown.Rewrite(std::move(op->children[0])); - op->children[1] = right_pushdown.Rewrite(std::move(op->children[1])); - - if 
(!join_expressions.empty()) { - // join conditions found: turn into inner join - // extract join conditions - vector conditions; - vector> arbitrary_expressions; - auto join_type = JoinType::INNER; - LogicalComparisonJoin::ExtractJoinConditions(GetContext(), join_type, op->children[0], op->children[1], - left_bindings, right_bindings, join_expressions, conditions, - arbitrary_expressions); - // create the join from the join conditions - return LogicalComparisonJoin::CreateJoin(GetContext(), JoinType::INNER, JoinRefType::REGULAR, - std::move(op->children[0]), std::move(op->children[1]), - std::move(conditions), std::move(arbitrary_expressions)); - } else { - // no join conditions found: keep as cross product - return op; - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -using Filter = FilterPushdown::Filter; - -unique_ptr FilterPushdown::PushdownFilter(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_FILTER); - auto &filter = op->Cast(); - if (!filter.projection_map.empty()) { - return FinishPushdown(std::move(op)); - } - // filter: gather the filters and remove the filter from the set of operations - for (auto &expression : filter.expressions) { - if (AddFilter(std::move(expression)) == FilterResult::UNSATISFIABLE) { - // filter statically evaluates to false, strip tree - return make_uniq(std::move(op)); - } - } - GenerateFilters(); - return Rewrite(std::move(filter.children[0])); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -unique_ptr FilterPushdown::PushdownGet(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_GET); - auto &get = op->Cast(); - - if (get.function.pushdown_complex_filter || get.function.filter_pushdown) { - // this scan supports some form of filter push-down - // check if there are any parameters - // if there are, invalidate them to force a re-bind on execution - for (auto &filter : filters) { - if (filter->filter->HasParameter()) { - // there is a parameter in the filters! 
invalidate it - BoundParameterExpression::InvalidateRecursive(*filter->filter); - } - } - } - if (get.function.pushdown_complex_filter) { - // for the remaining filters, check if we can push any of them into the scan as well - vector> expressions; - expressions.reserve(filters.size()); - for (auto &filter : filters) { - expressions.push_back(std::move(filter->filter)); - } - filters.clear(); - - get.function.pushdown_complex_filter(optimizer.context, get, get.bind_data.get(), expressions); - - if (expressions.empty()) { - return op; - } - // re-generate the filters - for (auto &expr : expressions) { - auto f = make_uniq(); - f->filter = std::move(expr); - f->ExtractBindings(); - filters.push_back(std::move(f)); - } - } - - if (!get.table_filters.filters.empty() || !get.function.filter_pushdown) { - // the table function does not support filter pushdown: push a LogicalFilter on top - return FinishPushdown(std::move(op)); - } - PushFilters(); - - //! We generate the table filters that will be executed during the table scan - //! Right now this only executes simple AND filters - get.table_filters = combiner.GenerateTableScanFilters(get.column_ids); - - // //! For more complex filters if all filters to a column are constants we generate a min max boundary used to - // check - // //! the zonemaps. - // auto zonemap_checks = combiner.GenerateZonemapChecks(get.column_ids, get.table_filters); - - // for (auto &f : get.table_filters) { - // f.column_index = get.column_ids[f.column_index]; - // } - - // //! Use zonemap checks as table filters for pre-processing - // for (auto &zonemap_check : zonemap_checks) { - // if (zonemap_check.column_index != COLUMN_IDENTIFIER_ROW_ID) { - // get.table_filters.push_back(zonemap_check); - // } - // } - - GenerateFilters(); - - //! 
Now we try to pushdown the remaining filters to perform zonemap checking - return FinishPushdown(std::move(op)); -} - -} // namespace duckdb - -#endif diff --git a/lib/duckdb-8.cpp b/lib/duckdb-8.cpp deleted file mode 100644 index b1165c62..00000000 --- a/lib/duckdb-8.cpp +++ /dev/null @@ -1,22093 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - - - - - -namespace duckdb { - -using Filter = FilterPushdown::Filter; - -unique_ptr FilterPushdown::PushdownInnerJoin(unique_ptr op, - unordered_set &left_bindings, - unordered_set &right_bindings) { - auto &join = op->Cast(); - D_ASSERT(join.join_type == JoinType::INNER); - if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { - return FinishPushdown(std::move(op)); - } - // inner join: gather all the conditions of the inner join and add to the filter list - if (op->type == LogicalOperatorType::LOGICAL_ANY_JOIN) { - auto &any_join = join.Cast(); - // any join: only one filter to add - if (AddFilter(std::move(any_join.condition)) == FilterResult::UNSATISFIABLE) { - // filter statically evaluates to false, strip tree - return make_uniq(std::move(op)); - } - } else if (op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN) { - // Don't mess with non-standard condition interpretations - return FinishPushdown(std::move(op)); - } else { - // comparison join - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN); - auto &comp_join = join.Cast(); - // turn the conditions into filters - for (auto &i : comp_join.conditions) { - auto condition = JoinCondition::CreateExpression(std::move(i)); - if (AddFilter(std::move(condition)) == FilterResult::UNSATISFIABLE) { - // filter statically evaluates to false, strip tree - return make_uniq(std::move(op)); - } - } - } - GenerateFilters(); - - // turn the inner join into a 
cross product - auto cross_product = make_uniq(std::move(op->children[0]), std::move(op->children[1])); - // then push down cross product - return PushdownCrossProduct(std::move(cross_product)); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -using Filter = FilterPushdown::Filter; - -static unique_ptr ReplaceColRefWithNull(unique_ptr expr, unordered_set &right_bindings) { - if (expr->type == ExpressionType::BOUND_COLUMN_REF) { - auto &bound_colref = expr->Cast(); - if (right_bindings.find(bound_colref.binding.table_index) != right_bindings.end()) { - // bound colref belongs to RHS - // replace it with a constant NULL - return make_uniq(Value(expr->return_type)); - } - return expr; - } - ExpressionIterator::EnumerateChildren( - *expr, [&](unique_ptr &child) { child = ReplaceColRefWithNull(std::move(child), right_bindings); }); - return expr; -} - -static bool FilterRemovesNull(ClientContext &context, ExpressionRewriter &rewriter, Expression *expr, - unordered_set &right_bindings) { - // make a copy of the expression - auto copy = expr->Copy(); - // replace all BoundColumnRef expressions frmo the RHS with NULL constants in the copied expression - copy = ReplaceColRefWithNull(std::move(copy), right_bindings); - - // attempt to flatten the expression by running the expression rewriter on it - auto filter = make_uniq(); - filter->expressions.push_back(std::move(copy)); - rewriter.VisitOperator(*filter); - - // check if all expressions are foldable - for (idx_t i = 0; i < filter->expressions.size(); i++) { - if (!filter->expressions[i]->IsFoldable()) { - return false; - } - // we flattened the result into a scalar, check if it is FALSE or NULL - auto val = - ExpressionExecutor::EvaluateScalar(context, *filter->expressions[i]).DefaultCastAs(LogicalType::BOOLEAN); - // if the result of the expression with all expressions replaced with NULL is "NULL" or "false" - // then any extra entries generated by the LEFT OUTER JOIN will be filtered out! 
- // hence the LEFT OUTER JOIN is equivalent to an inner join - if (val.IsNull() || !BooleanValue::Get(val)) { - return true; - } - } - return false; -} - -unique_ptr FilterPushdown::PushdownLeftJoin(unique_ptr op, - unordered_set &left_bindings, - unordered_set &right_bindings) { - auto &join = op->Cast(); - if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { - return FinishPushdown(std::move(op)); - } - FilterPushdown left_pushdown(optimizer), right_pushdown(optimizer); - // for a comparison join we create a FilterCombiner that checks if we can push conditions on LHS join conditions - // into the RHS of the join - FilterCombiner filter_combiner(optimizer); - const auto isComparison = (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || - op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN); - if (isComparison) { - // add all comparison conditions - auto &comparison_join = op->Cast(); - for (auto &cond : comparison_join.conditions) { - filter_combiner.AddFilter( - make_uniq(cond.comparison, cond.left->Copy(), cond.right->Copy())); - } - } - // now check the set of filters - for (idx_t i = 0; i < filters.size(); i++) { - auto side = JoinSide::GetJoinSide(filters[i]->bindings, left_bindings, right_bindings); - if (side == JoinSide::LEFT) { - // bindings match left side - // we can push the filter into the left side - if (isComparison) { - // we MIGHT be able to push it down the RHS as well, but only if it is a comparison that matches the - // join predicates we use the FilterCombiner to figure this out add the expression to the FilterCombiner - filter_combiner.AddFilter(filters[i]->filter->Copy()); - } - left_pushdown.filters.push_back(std::move(filters[i])); - // erase the filter from the list of filters - filters.erase(filters.begin() + i); - i--; - } else { - // bindings match right side or both sides: we cannot directly push it into the right - // however, if the filter removes rows with null values from the RHS we can turn the left outer join - // 
in an inner join, and then push down as we would push down an inner join - if (FilterRemovesNull(optimizer.context, optimizer.rewriter, filters[i]->filter.get(), right_bindings)) { - // the filter removes NULL values, turn it into an inner join - join.join_type = JoinType::INNER; - // now we can do more pushdown - // move all filters we added to the left_pushdown back into the filter list - for (auto &left_filter : left_pushdown.filters) { - filters.push_back(std::move(left_filter)); - } - // now push down the inner join - return PushdownInnerJoin(std::move(op), left_bindings, right_bindings); - } - } - } - // finally we check the FilterCombiner to see if there are any predicates we can push into the RHS - // we only added (1) predicates that have JoinSide::BOTH from the conditions, and - // (2) predicates that have JoinSide::LEFT from the filters - // we check now if this combination generated any new filters that are only on JoinSide::RIGHT - // this happens if, e.g. a join condition is (i=a) and there is a filter (i=500), we can then push the filter - // (a=500) into the RHS - filter_combiner.GenerateFilters([&](unique_ptr filter) { - if (JoinSide::GetJoinSide(*filter, left_bindings, right_bindings) == JoinSide::RIGHT) { - right_pushdown.AddFilter(std::move(filter)); - } - }); - right_pushdown.GenerateFilters(); - op->children[0] = left_pushdown.Rewrite(std::move(op->children[0])); - op->children[1] = right_pushdown.Rewrite(std::move(op->children[1])); - return PushFinalFilters(std::move(op)); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr FilterPushdown::PushdownLimit(unique_ptr op) { - auto &limit = op->Cast(); - - if (!limit.limit && limit.limit_val == 0) { - return make_uniq(std::move(op)); - } - - return FinishPushdown(std::move(op)); -} - -} // namespace duckdb - - - - -namespace duckdb { - -using Filter = FilterPushdown::Filter; - -unique_ptr FilterPushdown::PushdownMarkJoin(unique_ptr op, - unordered_set &left_bindings, - 
unordered_set &right_bindings) { - auto &join = op->Cast(); - auto &comp_join = op->Cast(); - D_ASSERT(join.join_type == JoinType::MARK); - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || - op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN || op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN); - - right_bindings.insert(comp_join.mark_index); - FilterPushdown left_pushdown(optimizer), right_pushdown(optimizer); -#ifdef DEBUG - bool simplified_mark_join = false; -#endif - // now check the set of filters - for (idx_t i = 0; i < filters.size(); i++) { - auto side = JoinSide::GetJoinSide(filters[i]->bindings, left_bindings, right_bindings); - if (side == JoinSide::LEFT) { - // bindings match left side: push into left - left_pushdown.filters.push_back(std::move(filters[i])); - // erase the filter from the list of filters - filters.erase(filters.begin() + i); - i--; - } else if (side == JoinSide::RIGHT) { -#ifdef DEBUG - D_ASSERT(!simplified_mark_join); -#endif - // this filter references the marker - // we can turn this into a SEMI join if the filter is on only the marker - if (filters[i]->filter->type == ExpressionType::BOUND_COLUMN_REF) { - // filter just references the marker: turn into semi join -#ifdef DEBUG - simplified_mark_join = true; -#endif - join.join_type = JoinType::SEMI; - filters.erase(filters.begin() + i); - i--; - continue; - } - // if the filter is on NOT(marker) AND the join conditions are all set to "null_values_are_equal" we can - // turn this into an ANTI join if all join conditions have null_values_are_equal=true, then the result of - // the MARK join is always TRUE or FALSE, and never NULL this happens in the case of a correlated EXISTS - // clause - if (filters[i]->filter->type == ExpressionType::OPERATOR_NOT) { - auto &op_expr = filters[i]->filter->Cast(); - if (op_expr.children[0]->type == ExpressionType::BOUND_COLUMN_REF) { - // the filter is NOT(marker), check the join conditions - bool all_null_values_are_equal = 
true; - for (auto &cond : comp_join.conditions) { - if (cond.comparison != ExpressionType::COMPARE_DISTINCT_FROM && - cond.comparison != ExpressionType::COMPARE_NOT_DISTINCT_FROM) { - all_null_values_are_equal = false; - break; - } - } - if (all_null_values_are_equal) { -#ifdef DEBUG - simplified_mark_join = true; -#endif - // all null values are equal, convert to ANTI join - join.join_type = JoinType::ANTI; - filters.erase(filters.begin() + i); - i--; - continue; - } - } - } - } - } - op->children[0] = left_pushdown.Rewrite(std::move(op->children[0])); - op->children[1] = right_pushdown.Rewrite(std::move(op->children[1])); - return PushFinalFilters(std::move(op)); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -static bool HasSideEffects(LogicalProjection &proj, const unique_ptr &expr) { - if (expr->type == ExpressionType::BOUND_COLUMN_REF) { - auto &colref = expr->Cast(); - D_ASSERT(colref.binding.table_index == proj.table_index); - D_ASSERT(colref.binding.column_index < proj.expressions.size()); - D_ASSERT(colref.depth == 0); - if (proj.expressions[colref.binding.column_index]->HasSideEffects()) { - return true; - } - return false; - } - bool has_side_effects = false; - ExpressionIterator::EnumerateChildren( - *expr, [&](unique_ptr &child) { has_side_effects |= HasSideEffects(proj, child); }); - return has_side_effects; -} - -static unique_ptr ReplaceProjectionBindings(LogicalProjection &proj, unique_ptr expr) { - if (expr->type == ExpressionType::BOUND_COLUMN_REF) { - auto &colref = expr->Cast(); - D_ASSERT(colref.binding.table_index == proj.table_index); - D_ASSERT(colref.binding.column_index < proj.expressions.size()); - D_ASSERT(colref.depth == 0); - // replace the binding with a copy to the expression at the referenced index - return proj.expressions[colref.binding.column_index]->Copy(); - } - ExpressionIterator::EnumerateChildren( - *expr, [&](unique_ptr &child) { child = ReplaceProjectionBindings(proj, std::move(child)); }); - return expr; 
-} - -unique_ptr FilterPushdown::PushdownProjection(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_PROJECTION); - auto &proj = op->Cast(); - // push filter through logical projection - // all the BoundColumnRefExpressions in the filter should refer to the LogicalProjection - // we can rewrite them by replacing those references with the expression of the LogicalProjection node - FilterPushdown child_pushdown(optimizer); - // There are some expressions can not be pushed down. We should keep them - // and add an extra filter operator. - vector> remain_expressions; - for (auto &filter : filters) { - auto &f = *filter; - D_ASSERT(f.bindings.size() <= 1); - bool has_side_effects = HasSideEffects(proj, f.filter); - if (has_side_effects) { - // We can't push down related expressions if the column in the - // expression is generated by the functions which have side effects - remain_expressions.push_back(std::move(f.filter)); - } else { - // rewrite the bindings within this subquery - f.filter = ReplaceProjectionBindings(proj, std::move(f.filter)); - // add the filter to the child pushdown - if (child_pushdown.AddFilter(std::move(f.filter)) == FilterResult::UNSATISFIABLE) { - // filter statically evaluates to false, strip tree - return make_uniq(std::move(op)); - } - } - } - child_pushdown.GenerateFilters(); - // now push into children - op->children[0] = child_pushdown.Rewrite(std::move(op->children[0])); - if (op->children[0]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT) { - // child returns an empty result: generate an empty result here too - return make_uniq(std::move(op)); - } - return AddLogicalFilter(std::move(op), std::move(remain_expressions)); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -using Filter = FilterPushdown::Filter; - -static void ReplaceSetOpBindings(vector &bindings, Filter &filter, Expression &expr, - LogicalSetOperation &setop) { - if (expr.type == ExpressionType::BOUND_COLUMN_REF) { - auto &colref = 
expr.Cast(); - D_ASSERT(colref.binding.table_index == setop.table_index); - D_ASSERT(colref.depth == 0); - - // rewrite the binding by looking into the bound_tables list of the subquery - colref.binding = bindings[colref.binding.column_index]; - filter.bindings.insert(colref.binding.table_index); - return; - } - ExpressionIterator::EnumerateChildren( - expr, [&](Expression &child) { ReplaceSetOpBindings(bindings, filter, child, setop); }); -} - -unique_ptr FilterPushdown::PushdownSetOperation(unique_ptr op) { - D_ASSERT(op->type == LogicalOperatorType::LOGICAL_UNION || op->type == LogicalOperatorType::LOGICAL_EXCEPT || - op->type == LogicalOperatorType::LOGICAL_INTERSECT); - auto &setop = op->Cast(); - - D_ASSERT(op->children.size() == 2); - auto left_bindings = op->children[0]->GetColumnBindings(); - auto right_bindings = op->children[1]->GetColumnBindings(); - if (left_bindings.size() != right_bindings.size()) { - throw InternalException("Filter pushdown - set operation LHS and RHS have incompatible counts"); - } - - // pushdown into set operation, we can duplicate the condition and pushdown the expressions into both sides - FilterPushdown left_pushdown(optimizer), right_pushdown(optimizer); - for (idx_t i = 0; i < filters.size(); i++) { - // first create a copy of the filter - auto right_filter = make_uniq(); - right_filter->filter = filters[i]->filter->Copy(); - - // in the original filter, rewrite references to the result of the union into references to the left_index - ReplaceSetOpBindings(left_bindings, *filters[i], *filters[i]->filter, setop); - // in the copied filter, rewrite references to the result of the union into references to the right_index - ReplaceSetOpBindings(right_bindings, *right_filter, *right_filter->filter, setop); - - // extract bindings again - filters[i]->ExtractBindings(); - right_filter->ExtractBindings(); - - // move the filters into the child pushdown nodes - left_pushdown.filters.push_back(std::move(filters[i])); - 
right_pushdown.filters.push_back(std::move(right_filter)); - } - - op->children[0] = left_pushdown.Rewrite(std::move(op->children[0])); - op->children[1] = right_pushdown.Rewrite(std::move(op->children[1])); - - bool left_empty = op->children[0]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT; - bool right_empty = op->children[1]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT; - if (left_empty && right_empty) { - // both empty: return empty result - return make_uniq(std::move(op)); - } - if (left_empty) { - // left child is empty result - switch (op->type) { - case LogicalOperatorType::LOGICAL_UNION: - if (op->children[1]->type == LogicalOperatorType::LOGICAL_PROJECTION) { - // union with empty left side: return right child - auto &projection = op->children[1]->Cast(); - projection.table_index = setop.table_index; - return std::move(op->children[1]); - } - break; - case LogicalOperatorType::LOGICAL_EXCEPT: - // except: if left child is empty, return empty result - case LogicalOperatorType::LOGICAL_INTERSECT: - // intersect: if any child is empty, return empty result itself - return make_uniq(std::move(op)); - default: - throw InternalException("Unsupported set operation"); - } - } else if (right_empty) { - // right child is empty result - switch (op->type) { - case LogicalOperatorType::LOGICAL_UNION: - case LogicalOperatorType::LOGICAL_EXCEPT: - if (op->children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION) { - // union or except with empty right child: return left child - auto &projection = op->children[0]->Cast(); - projection.table_index = setop.table_index; - return std::move(op->children[0]); - } - break; - case LogicalOperatorType::LOGICAL_INTERSECT: - // intersect: if any child is empty, return empty result itself - return make_uniq(std::move(op)); - default: - throw InternalException("Unsupported set operation"); - } - } - return op; -} - -} // namespace duckdb - - - -namespace duckdb { - -using Filter = FilterPushdown::Filter; - -unique_ptr 
FilterPushdown::PushdownSingleJoin(unique_ptr op, - unordered_set &left_bindings, - unordered_set &right_bindings) { - D_ASSERT(op->Cast().join_type == JoinType::SINGLE); - FilterPushdown left_pushdown(optimizer), right_pushdown(optimizer); - // now check the set of filters - for (idx_t i = 0; i < filters.size(); i++) { - auto side = JoinSide::GetJoinSide(filters[i]->bindings, left_bindings, right_bindings); - if (side == JoinSide::LEFT) { - // bindings match left side: push into left - left_pushdown.filters.push_back(std::move(filters[i])); - // erase the filter from the list of filters - filters.erase(filters.begin() + i); - i--; - } - } - op->children[0] = left_pushdown.Rewrite(std::move(op->children[0])); - op->children[1] = right_pushdown.Rewrite(std::move(op->children[1])); - return PushFinalFilters(std::move(op)); -} - -} // namespace duckdb - - - - - - - - - - - - - - - -namespace duckdb { - -unique_ptr RegexRangeFilter::Rewrite(unique_ptr op) { - - for (idx_t child_idx = 0; child_idx < op->children.size(); child_idx++) { - op->children[child_idx] = Rewrite(std::move(op->children[child_idx])); - } - - if (op->type != LogicalOperatorType::LOGICAL_FILTER) { - return op; - } - - auto new_filter = make_uniq(); - - for (auto &expr : op->expressions) { - if (expr->type == ExpressionType::BOUND_FUNCTION) { - auto &func = expr->Cast(); - if (func.function.name != "regexp_full_match" || func.children.size() != 2) { - continue; - } - auto &info = func.bind_info->Cast(); - if (!info.range_success) { - continue; - } - auto filter_left = make_uniq( - ExpressionType::COMPARE_GREATERTHANOREQUALTO, func.children[0]->Copy(), - make_uniq(Value::BLOB_RAW(info.range_min))); - auto filter_right = make_uniq( - ExpressionType::COMPARE_LESSTHANOREQUALTO, func.children[0]->Copy(), - make_uniq(Value::BLOB_RAW(info.range_max))); - auto filter_expr = make_uniq(ExpressionType::CONJUNCTION_AND, - std::move(filter_left), std::move(filter_right)); - - 
new_filter->expressions.push_back(std::move(filter_expr)); - } - } - - if (!new_filter->expressions.empty()) { - new_filter->children = std::move(op->children); - op->children.clear(); - op->children.push_back(std::move(new_filter)); - } - - return op; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -void RemoveDuplicateGroups::VisitOperator(LogicalOperator &op) { - switch (op.type) { - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - VisitAggregate(op.Cast()); - break; - default: - break; - } - LogicalOperatorVisitor::VisitOperatorExpressions(op); - LogicalOperatorVisitor::VisitOperatorChildren(op); -} - -void RemoveDuplicateGroups::VisitAggregate(LogicalAggregate &aggr) { - if (!aggr.grouping_functions.empty()) { - return; - } - - auto &groups = aggr.groups; - - column_binding_map_t duplicate_map; - vector> duplicates; - for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { - const auto &group = groups[group_idx]; - if (group->type != ExpressionType::BOUND_COLUMN_REF) { - continue; - } - const auto &colref = group->Cast(); - const auto &binding = colref.binding; - const auto it = duplicate_map.find(binding); - if (it == duplicate_map.end()) { - duplicate_map.emplace(binding, group_idx); - } else { - duplicates.emplace_back(it->second, group_idx); - } - } - - if (duplicates.empty()) { - return; - } - - // Sort duplicates by max duplicate group idx, because we want to remove groups from the back - sort(duplicates.begin(), duplicates.end(), - [](const pair &lhs, const pair &rhs) { return lhs.second > rhs.second; }); - - // Now we want to remove the duplicates, but this alters the column bindings coming out of the aggregate, - // so we keep track of how they shift and do another round of column binding replacements - column_binding_map_t group_binding_map; - for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { - group_binding_map.emplace(ColumnBinding(aggr.group_index, group_idx), - ColumnBinding(aggr.group_index, 
group_idx)); - } - - for (idx_t duplicate_idx = 0; duplicate_idx < duplicates.size(); duplicate_idx++) { - const auto &duplicate = duplicates[duplicate_idx]; - const auto &remaining_idx = duplicate.first; - const auto &removed_idx = duplicate.second; - - // Store expression and remove it from groups - stored_expressions.emplace_back(std::move(groups[removed_idx])); - groups.erase(groups.begin() + removed_idx); - - // This optimizer should run before statistics propagation, so this should be empty - // If it runs after, then group_stats should be updated too - D_ASSERT(aggr.group_stats.empty()); - - // Remove from grouping sets too - for (auto &grouping_set : aggr.grouping_sets) { - // Replace removed group with duplicate remaining group - if (grouping_set.erase(removed_idx) != 0) { - grouping_set.insert(remaining_idx); - } - - // Indices shifted: Reinsert groups in the set with group_idx - 1 - vector group_indices_to_reinsert; - for (auto &entry : grouping_set) { - if (entry > removed_idx) { - group_indices_to_reinsert.emplace_back(entry); - } - } - for (const auto group_idx : group_indices_to_reinsert) { - grouping_set.erase(group_idx); - } - for (const auto group_idx : group_indices_to_reinsert) { - grouping_set.insert(group_idx - 1); - } - } - - // Update mapping - auto it = group_binding_map.find(ColumnBinding(aggr.group_index, removed_idx)); - D_ASSERT(it != group_binding_map.end()); - it->second.column_index = remaining_idx; - - for (auto &map_entry : group_binding_map) { - auto &new_binding = map_entry.second; - if (new_binding.column_index > removed_idx) { - new_binding.column_index--; - } - } - } - - // Replace all references to the old group binding with the new group binding - for (const auto &map_entry : group_binding_map) { - auto it = column_references.find(map_entry.first); - if (it != column_references.end()) { - for (auto expr : it->second) { - expr.get().binding = map_entry.second; - } - } - } -} - -unique_ptr 
RemoveDuplicateGroups::VisitReplace(BoundColumnRefExpression &expr, - unique_ptr *expr_ptr) { - // add a column reference - column_references[expr.binding].push_back(expr); - return nullptr; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -void RemoveUnusedColumns::ReplaceBinding(ColumnBinding current_binding, ColumnBinding new_binding) { - auto colrefs = column_references.find(current_binding); - if (colrefs != column_references.end()) { - for (auto &colref : colrefs->second) { - D_ASSERT(colref->binding == current_binding); - colref->binding = new_binding; - } - } -} - -template -void RemoveUnusedColumns::ClearUnusedExpressions(vector &list, idx_t table_idx, bool replace) { - idx_t offset = 0; - for (idx_t col_idx = 0; col_idx < list.size(); col_idx++) { - auto current_binding = ColumnBinding(table_idx, col_idx + offset); - auto entry = column_references.find(current_binding); - if (entry == column_references.end()) { - // this entry is not referred to, erase it from the set of expressions - list.erase(list.begin() + col_idx); - offset++; - col_idx--; - } else if (offset > 0 && replace) { - // column is used but the ColumnBinding has changed because of removed columns - ReplaceBinding(current_binding, ColumnBinding(table_idx, col_idx)); - } - } -} - -void RemoveUnusedColumns::VisitOperator(LogicalOperator &op) { - switch (op.type) { - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { - // aggregate - if (!everything_referenced) { - // FIXME: groups that are not referenced need to stay -> but they don't need to be scanned and output! 
- auto &aggr = op.Cast(); - ClearUnusedExpressions(aggr.expressions, aggr.aggregate_index); - if (aggr.expressions.empty() && aggr.groups.empty()) { - // removed all expressions from the aggregate: push a COUNT(*) - auto count_star_fun = CountStarFun::GetFunction(); - FunctionBinder function_binder(context); - aggr.expressions.push_back( - function_binder.BindAggregateFunction(count_star_fun, {}, nullptr, AggregateType::NON_DISTINCT)); - } - } - - // then recurse into the children of the aggregate - RemoveUnusedColumns remove(binder, context); - remove.VisitOperatorExpressions(op); - remove.VisitOperator(*op.children[0]); - return; - } - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { - if (!everything_referenced) { - auto &comp_join = op.Cast(); - - if (comp_join.join_type != JoinType::INNER) { - break; - } - // for inner joins with equality predicates in the form of (X=Y) - // we can replace any references to the RHS (Y) to references to the LHS (X) - // this reduces the amount of columns we need to extract from the join hash table - for (auto &cond : comp_join.conditions) { - if (cond.comparison == ExpressionType::COMPARE_EQUAL) { - if (cond.left->expression_class == ExpressionClass::BOUND_COLUMN_REF && - cond.right->expression_class == ExpressionClass::BOUND_COLUMN_REF) { - // comparison join between two bound column refs - // we can replace any reference to the RHS (build-side) with a reference to the LHS (probe-side) - auto &lhs_col = cond.left->Cast(); - auto &rhs_col = cond.right->Cast(); - // if there are any columns that refer to the RHS, - auto colrefs = column_references.find(rhs_col.binding); - if (colrefs != column_references.end()) { - for (auto &entry : colrefs->second) { - entry->binding = lhs_col.binding; - column_references[lhs_col.binding].push_back(entry); - } - column_references.erase(rhs_col.binding); - } - } - } - } - } - break; - } - case 
LogicalOperatorType::LOGICAL_ANY_JOIN: - break; - case LogicalOperatorType::LOGICAL_UNION: - if (!everything_referenced) { - // for UNION we can remove unreferenced columns as long as everything_referenced is false (i.e. we - // encounter a UNION node that is not preceded by a DISTINCT) - // this happens when UNION ALL is used - auto &setop = op.Cast(); - vector entries; - for (idx_t i = 0; i < setop.column_count; i++) { - entries.push_back(i); - } - ClearUnusedExpressions(entries, setop.table_index); - if (entries.size() < setop.column_count) { - if (entries.empty()) { - // no columns referenced: this happens in the case of a COUNT(*) - // extract the first column - entries.push_back(0); - } - // columns were cleared - setop.column_count = entries.size(); - - for (idx_t child_idx = 0; child_idx < op.children.size(); child_idx++) { - RemoveUnusedColumns remove(binder, context, true); - auto &child = op.children[child_idx]; - - // we push a projection under this child that references the required columns of the union - child->ResolveOperatorTypes(); - auto bindings = child->GetColumnBindings(); - vector> expressions; - expressions.reserve(entries.size()); - for (auto &column_idx : entries) { - expressions.push_back( - make_uniq(child->types[column_idx], bindings[column_idx])); - } - auto new_projection = - make_uniq(binder.GenerateTableIndex(), std::move(expressions)); - new_projection->children.push_back(std::move(child)); - op.children[child_idx] = std::move(new_projection); - - remove.VisitOperator(*op.children[child_idx]); - } - return; - } - } - for (auto &child : op.children) { - RemoveUnusedColumns remove(binder, context, true); - remove.VisitOperator(*child); - } - return; - case LogicalOperatorType::LOGICAL_EXCEPT: - case LogicalOperatorType::LOGICAL_INTERSECT: - // for INTERSECT/EXCEPT operations we can't remove anything, just recursively visit the children - for (auto &child : op.children) { - RemoveUnusedColumns remove(binder, context, true); - 
remove.VisitOperator(*child); - } - return; - case LogicalOperatorType::LOGICAL_ORDER_BY: - if (!everything_referenced) { - auto &order = op.Cast(); - D_ASSERT(order.projections.empty()); // should not yet be set - const auto all_bindings = order.GetColumnBindings(); - - for (idx_t col_idx = 0; col_idx < all_bindings.size(); col_idx++) { - if (column_references.find(all_bindings[col_idx]) != column_references.end()) { - order.projections.push_back(col_idx); - } - } - } - for (auto &child : op.children) { - RemoveUnusedColumns remove(binder, context, true); - remove.VisitOperator(*child); - } - return; - case LogicalOperatorType::LOGICAL_PROJECTION: { - if (!everything_referenced) { - auto &proj = op.Cast(); - ClearUnusedExpressions(proj.expressions, proj.table_index); - - if (proj.expressions.empty()) { - // nothing references the projected expressions - // this happens in the case of e.g. EXISTS(SELECT * FROM ...) - // in this case we only need to project a single constant - proj.expressions.push_back(make_uniq(Value::INTEGER(42))); - } - } - // then recurse into the children of this projection - RemoveUnusedColumns remove(binder, context); - remove.VisitOperatorExpressions(op); - remove.VisitOperator(*op.children[0]); - return; - } - case LogicalOperatorType::LOGICAL_INSERT: - case LogicalOperatorType::LOGICAL_UPDATE: - case LogicalOperatorType::LOGICAL_DELETE: { - //! When RETURNING is used, a PROJECTION is the top level operator for INSERTS, UPDATES, and DELETES - //! We still need to project all values from these operators so the projection - //! on top of them can select from only the table values being inserted. - //! TODO: Push down the projections from the returning statement - //! 
TODO: Be careful because you might be adding expressions when a user returns * - RemoveUnusedColumns remove(binder, context, true); - remove.VisitOperatorExpressions(op); - remove.VisitOperator(*op.children[0]); - return; - } - case LogicalOperatorType::LOGICAL_GET: - LogicalOperatorVisitor::VisitOperatorExpressions(op); - if (!everything_referenced) { - auto &get = op.Cast(); - if (!get.function.projection_pushdown) { - return; - } - - // Create "selection vector" of all column ids - vector proj_sel; - for (idx_t col_idx = 0; col_idx < get.column_ids.size(); col_idx++) { - proj_sel.push_back(col_idx); - } - // Create a copy that we can use to match ids later - auto col_sel = proj_sel; - // Clear unused ids, exclude filter columns that are projected out immediately - ClearUnusedExpressions(proj_sel, get.table_index, false); - - // for every table filter, push a column binding into the column references map to prevent the column from - // being projected out - for (auto &filter : get.table_filters.filters) { - idx_t index = DConstants::INVALID_INDEX; - for (idx_t i = 0; i < get.column_ids.size(); i++) { - if (get.column_ids[i] == filter.first) { - index = i; - break; - } - } - if (index == DConstants::INVALID_INDEX) { - throw InternalException("Could not find column index for table filter"); - } - ColumnBinding filter_binding(get.table_index, index); - if (column_references.find(filter_binding) == column_references.end()) { - column_references.insert(make_pair(filter_binding, vector())); - } - } - - // Clear unused ids, include filter columns that are projected out immediately - ClearUnusedExpressions(col_sel, get.table_index); - - // Now set the column ids in the LogicalGet using the "selection vector" - vector column_ids; - column_ids.reserve(col_sel.size()); - for (auto col_sel_idx : col_sel) { - column_ids.push_back(get.column_ids[col_sel_idx]); - } - get.column_ids = std::move(column_ids); - - if (get.function.filter_prune) { - // Now set the projection cols by 
matching the "selection vector" that excludes filter columns - // with the "selection vector" that includes filter columns - idx_t col_idx = 0; - for (auto proj_sel_idx : proj_sel) { - for (; col_idx < col_sel.size(); col_idx++) { - if (proj_sel_idx == col_sel[col_idx]) { - get.projection_ids.push_back(col_idx); - break; - } - } - } - } - - if (get.column_ids.empty()) { - // this generally means we are only interested in whether or not anything exists in the table (e.g. - // EXISTS(SELECT * FROM tbl)) in this case, we just scan the row identifier column as it means we do not - // need to read any of the columns - get.column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); - } - } - return; - case LogicalOperatorType::LOGICAL_FILTER: { - auto &filter = op.Cast(); - if (!filter.projection_map.empty()) { - // if we have any entries in the filter projection map don't prune any columns - // FIXME: we can do something more clever here - everything_referenced = true; - } - break; - } - case LogicalOperatorType::LOGICAL_DISTINCT: { - // distinct, all projected columns are used for the DISTINCT computation - // mark all columns as used and continue to the children - // FIXME: DISTINCT with expression list does not implicitly reference everything - everything_referenced = true; - break; - } - case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: { - everything_referenced = true; - break; - } - case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: { - everything_referenced = true; - break; - } - case LogicalOperatorType::LOGICAL_CTE_REF: { - everything_referenced = true; - break; - } - case LogicalOperatorType::LOGICAL_PIVOT: { - everything_referenced = true; - break; - } - default: - break; - } - LogicalOperatorVisitor::VisitOperatorExpressions(op); - LogicalOperatorVisitor::VisitOperatorChildren(op); -} - -unique_ptr RemoveUnusedColumns::VisitReplace(BoundColumnRefExpression &expr, - unique_ptr *expr_ptr) { - // add a column reference - column_references[expr.binding].push_back(&expr); - 
return nullptr; -} - -unique_ptr RemoveUnusedColumns::VisitReplace(BoundReferenceExpression &expr, - unique_ptr *expr_ptr) { - // BoundReferenceExpression should not be used here yet, they only belong in the physical plan - throw InternalException("BoundReferenceExpression should not be used here yet!"); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -ArithmeticSimplificationRule::ArithmeticSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on an OperatorExpression that has a ConstantExpression as child - auto op = make_uniq(); - op->matchers.push_back(make_uniq()); - op->matchers.push_back(make_uniq()); - op->policy = SetMatcher::Policy::SOME; - // we only match on simple arithmetic expressions (+, -, *, /) - op->function = make_uniq(unordered_set {"+", "-", "*", "//"}); - // and only with numeric results - op->type = make_uniq(); - op->matchers[0]->type = make_uniq(); - op->matchers[1]->type = make_uniq(); - root = std::move(op); -} - -unique_ptr ArithmeticSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &root = bindings[0].get().Cast(); - auto &constant = bindings[1].get().Cast(); - int constant_child = root.children[0].get() == &constant ? 
0 : 1; - D_ASSERT(root.children.size() == 2); - (void)root; - // any arithmetic operator involving NULL is always NULL - if (constant.value.IsNull()) { - return make_uniq(Value(root.return_type)); - } - auto &func_name = root.function.name; - if (func_name == "+") { - if (constant.value == 0) { - // addition with 0 - // we can remove the entire operator and replace it with the non-constant child - return std::move(root.children[1 - constant_child]); - } - } else if (func_name == "-") { - if (constant_child == 1 && constant.value == 0) { - // subtraction by 0 - // we can remove the entire operator and replace it with the non-constant child - return std::move(root.children[1 - constant_child]); - } - } else if (func_name == "*") { - if (constant.value == 1) { - // multiply with 1, replace with non-constant child - return std::move(root.children[1 - constant_child]); - } else if (constant.value == 0) { - // multiply by zero: replace with constant or null - return ExpressionRewriter::ConstantOrNull(std::move(root.children[1 - constant_child]), - Value::Numeric(root.return_type, 0)); - } - } else if (func_name == "//") { - if (constant_child == 1) { - if (constant.value == 1) { - // divide by 1, replace with non-constant child - return std::move(root.children[1 - constant_child]); - } else if (constant.value == 0) { - // divide by 0, replace with NULL - return make_uniq(Value(root.return_type)); - } - } - } else { - throw InternalException("Unrecognized function name in ArithmeticSimplificationRule"); - } - return nullptr; -} -} // namespace duckdb - - - - - -namespace duckdb { - -CaseSimplificationRule::CaseSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on a CaseExpression that has a ConstantExpression as a check - auto op = make_uniq(); - root = std::move(op); -} - -unique_ptr CaseSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &root = bindings[0].get().Cast(); - for (idx_t 
i = 0; i < root.case_checks.size(); i++) { - auto &case_check = root.case_checks[i]; - if (case_check.when_expr->IsFoldable()) { - // the WHEN check is a foldable expression - // use an ExpressionExecutor to execute the expression - auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), *case_check.when_expr); - - // fold based on the constant condition - auto condition = constant_value.DefaultCastAs(LogicalType::BOOLEAN); - if (condition.IsNull() || !BooleanValue::Get(condition)) { - // the condition is always false: remove this case check - root.case_checks.erase(root.case_checks.begin() + i); - i--; - } else { - // the condition is always true - // move the THEN clause to the ELSE of the case - root.else_expr = std::move(case_check.then_expr); - // remove this case check and any case checks after this one - root.case_checks.erase(root.case_checks.begin() + i, root.case_checks.end()); - break; - } - } - } - if (root.case_checks.empty()) { - // no case checks left: return the ELSE expression - return std::move(root.else_expr); - } - return nullptr; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -ComparisonSimplificationRule::ComparisonSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on a ComparisonExpression that has a ConstantExpression as a check - auto op = make_uniq(); - op->matchers.push_back(make_uniq()); - op->policy = SetMatcher::Policy::SOME; - root = std::move(op); -} - -unique_ptr ComparisonSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &expr = bindings[0].get().Cast(); - auto &constant_expr = bindings[1].get(); - bool column_ref_left = expr.left.get() != &constant_expr; - auto column_ref_expr = !column_ref_left ? 
expr.right.get() : expr.left.get(); - // the constant_expr is a scalar expression that we have to fold - // use an ExpressionExecutor to execute the expression - D_ASSERT(constant_expr.IsFoldable()); - Value constant_value; - if (!ExpressionExecutor::TryEvaluateScalar(GetContext(), constant_expr, constant_value)) { - return nullptr; - } - if (constant_value.IsNull() && !(expr.type == ExpressionType::COMPARE_NOT_DISTINCT_FROM || - expr.type == ExpressionType::COMPARE_DISTINCT_FROM)) { - // comparison with constant NULL, return NULL - return make_uniq(Value(LogicalType::BOOLEAN)); - } - if (column_ref_expr->expression_class == ExpressionClass::BOUND_CAST) { - //! Here we check if we can apply the expression on the constant side - //! We can do this if the cast itself is invertible and casting the constant is - //! invertible in practice. - auto &cast_expression = column_ref_expr->Cast(); - auto target_type = cast_expression.source_type(); - if (!BoundCastExpression::CastIsInvertible(target_type, cast_expression.return_type)) { - return nullptr; - } - - // Can we cast the constant at all? - string error_message; - Value cast_constant; - auto new_constant = constant_value.DefaultTryCastAs(target_type, cast_constant, &error_message, true); - if (!new_constant) { - return nullptr; - } - - // Is the constant cast invertible? - if (!cast_constant.IsNull() && - !BoundCastExpression::CastIsInvertible(cast_expression.return_type, target_type)) { - // Is it actually invertible? - Value uncast_constant; - if (!cast_constant.DefaultTryCastAs(constant_value.type(), uncast_constant, &error_message, true) || - uncast_constant != constant_value) { - return nullptr; - } - } - - //! 
We can cast, now we change our column_ref_expression from an operator cast to a column reference - auto child_expression = std::move(cast_expression.child); - auto new_constant_expr = make_uniq(cast_constant); - if (column_ref_left) { - expr.left = std::move(child_expression); - expr.right = std::move(new_constant_expr); - } else { - expr.left = std::move(new_constant_expr); - expr.right = std::move(child_expression); - } - } - return nullptr; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -ConjunctionSimplificationRule::ConjunctionSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on a ComparisonExpression that has a ConstantExpression as a check - auto op = make_uniq(); - op->matchers.push_back(make_uniq()); - op->policy = SetMatcher::Policy::SOME; - root = std::move(op); -} - -unique_ptr ConjunctionSimplificationRule::RemoveExpression(BoundConjunctionExpression &conj, - const Expression &expr) { - for (idx_t i = 0; i < conj.children.size(); i++) { - if (conj.children[i].get() == &expr) { - // erase the expression - conj.children.erase(conj.children.begin() + i); - break; - } - } - if (conj.children.size() == 1) { - // one expression remaining: simply return that expression and erase the conjunction - return std::move(conj.children[0]); - } - return nullptr; -} - -unique_ptr ConjunctionSimplificationRule::Apply(LogicalOperator &op, - vector> &bindings, bool &changes_made, - bool is_root) { - auto &conjunction = bindings[0].get().Cast(); - auto &constant_expr = bindings[1].get(); - // the constant_expr is a scalar expression that we have to fold - // use an ExpressionExecutor to execute the expression - D_ASSERT(constant_expr.IsFoldable()); - Value constant_value; - if (!ExpressionExecutor::TryEvaluateScalar(GetContext(), constant_expr, constant_value)) { - return nullptr; - } - constant_value = constant_value.DefaultCastAs(LogicalType::BOOLEAN); - if (constant_value.IsNull()) { - // we can't simplify conjunctions with a 
constant NULL - return nullptr; - } - if (conjunction.type == ExpressionType::CONJUNCTION_AND) { - if (!BooleanValue::Get(constant_value)) { - // FALSE in AND, result of expression is false - return make_uniq(Value::BOOLEAN(false)); - } else { - // TRUE in AND, remove the expression from the set - return RemoveExpression(conjunction, constant_expr); - } - } else { - D_ASSERT(conjunction.type == ExpressionType::CONJUNCTION_OR); - if (!BooleanValue::Get(constant_value)) { - // FALSE in OR, remove the expression from the set - return RemoveExpression(conjunction, constant_expr); - } else { - // TRUE in OR, result of expression is true - return make_uniq(Value::BOOLEAN(true)); - } - } -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -//! The ConstantFoldingExpressionMatcher matches on any scalar expression (i.e. Expression::IsFoldable is true) -class ConstantFoldingExpressionMatcher : public FoldableConstantMatcher { -public: - bool Match(Expression &expr, vector> &bindings) override { - // we also do not match on ConstantExpressions, because we cannot fold those any further - if (expr.type == ExpressionType::VALUE_CONSTANT) { - return false; - } - return FoldableConstantMatcher::Match(expr, bindings); - } -}; - -ConstantFoldingRule::ConstantFoldingRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - auto op = make_uniq(); - root = std::move(op); -} - -unique_ptr ConstantFoldingRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &root = bindings[0].get(); - // the root is a scalar expression that we have to fold - D_ASSERT(root.IsFoldable() && root.type != ExpressionType::VALUE_CONSTANT); - - // use an ExpressionExecutor to execute the expression - Value result_value; - if (!ExpressionExecutor::TryEvaluateScalar(GetContext(), root, result_value)) { - return nullptr; - } - D_ASSERT(result_value.type().InternalType() == root.return_type.InternalType()); - // now get the value from the result vector and insert 
it back into the plan as a constant expression - return make_uniq(result_value); -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -DatePartSimplificationRule::DatePartSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - auto func = make_uniq(); - func->function = make_uniq("date_part"); - func->matchers.push_back(make_uniq()); - func->matchers.push_back(make_uniq()); - func->policy = SetMatcher::Policy::ORDERED; - root = std::move(func); -} - -unique_ptr DatePartSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &date_part = bindings[0].get().Cast(); - auto &constant_expr = bindings[1].get().Cast(); - auto &constant = constant_expr.value; - - if (constant.IsNull()) { - // NULL specifier: return constant NULL - return make_uniq(Value(date_part.return_type)); - } - // otherwise check the specifier - auto specifier = GetDatePartSpecifier(StringValue::Get(constant)); - string new_function_name; - switch (specifier) { - case DatePartSpecifier::YEAR: - new_function_name = "year"; - break; - case DatePartSpecifier::MONTH: - new_function_name = "month"; - break; - case DatePartSpecifier::DAY: - new_function_name = "day"; - break; - case DatePartSpecifier::DECADE: - new_function_name = "decade"; - break; - case DatePartSpecifier::CENTURY: - new_function_name = "century"; - break; - case DatePartSpecifier::MILLENNIUM: - new_function_name = "millennium"; - break; - case DatePartSpecifier::QUARTER: - new_function_name = "quarter"; - break; - case DatePartSpecifier::WEEK: - new_function_name = "week"; - break; - case DatePartSpecifier::YEARWEEK: - new_function_name = "yearweek"; - break; - case DatePartSpecifier::DOW: - new_function_name = "dayofweek"; - break; - case DatePartSpecifier::ISODOW: - new_function_name = "isodow"; - break; - case DatePartSpecifier::DOY: - new_function_name = "dayofyear"; - break; - case DatePartSpecifier::MICROSECONDS: - new_function_name = 
"microsecond"; - break; - case DatePartSpecifier::MILLISECONDS: - new_function_name = "millisecond"; - break; - case DatePartSpecifier::SECOND: - new_function_name = "second"; - break; - case DatePartSpecifier::MINUTE: - new_function_name = "minute"; - break; - case DatePartSpecifier::HOUR: - new_function_name = "hour"; - break; - default: - return nullptr; - } - // found a replacement function: bind it - vector> children; - children.push_back(std::move(date_part.children[1])); - - string error; - FunctionBinder binder(rewriter.context); - auto function = binder.BindScalarFunction(DEFAULT_SCHEMA, new_function_name, std::move(children), error, false); - if (!function) { - throw BinderException(error); - } - return function; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -DistributivityRule::DistributivityRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // we match on an OR expression within a LogicalFilter node - root = make_uniq(); - root->expr_type = make_uniq(ExpressionType::CONJUNCTION_OR); -} - -void DistributivityRule::AddExpressionSet(Expression &expr, expression_set_t &set) { - if (expr.type == ExpressionType::CONJUNCTION_AND) { - auto &and_expr = expr.Cast(); - for (auto &child : and_expr.children) { - set.insert(*child); - } - } else { - set.insert(expr); - } -} - -unique_ptr DistributivityRule::ExtractExpression(BoundConjunctionExpression &conj, idx_t idx, - Expression &expr) { - auto &child = conj.children[idx]; - unique_ptr result; - if (child->type == ExpressionType::CONJUNCTION_AND) { - // AND, remove expression from the list - auto &and_expr = child->Cast(); - for (idx_t i = 0; i < and_expr.children.size(); i++) { - if (and_expr.children[i]->Equals(expr)) { - result = std::move(and_expr.children[i]); - and_expr.children.erase(and_expr.children.begin() + i); - break; - } - } - if (and_expr.children.size() == 1) { - conj.children[idx] = std::move(and_expr.children[0]); - } - } else { - // not an AND node! 
remove the entire expression - // this happens in the case of e.g. (X AND B) OR X - D_ASSERT(child->Equals(expr)); - result = std::move(child); - conj.children[idx] = nullptr; - } - D_ASSERT(result); - return result; -} - -unique_ptr DistributivityRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &initial_or = bindings[0].get().Cast(); - - // we want to find expressions that occur in each of the children of the OR - // i.e. (X AND A) OR (X AND B) => X occurs in all branches - // first, for the initial child, we create an expression set of which expressions occur - // this is our initial candidate set (in the example: [X, A]) - expression_set_t candidate_set; - AddExpressionSet(*initial_or.children[0], candidate_set); - // now for each of the remaining children, we create a set again and intersect them - // in our example: the second set would be [X, B] - // the intersection would leave [X] - for (idx_t i = 1; i < initial_or.children.size(); i++) { - expression_set_t next_set; - AddExpressionSet(*initial_or.children[i], next_set); - expression_set_t intersect_result; - for (auto &expr : candidate_set) { - if (next_set.find(expr) != next_set.end()) { - intersect_result.insert(expr); - } - } - candidate_set = intersect_result; - } - if (candidate_set.empty()) { - // nothing found: abort - return nullptr; - } - // now for each of the remaining expressions in the candidate set we know that it is contained in all branches of - // the OR - auto new_root = make_uniq(ExpressionType::CONJUNCTION_AND); - for (auto &expr : candidate_set) { - D_ASSERT(initial_or.children.size() > 0); - - // extract the expression from the first child of the OR - auto result = ExtractExpression(initial_or, 0, expr.get()); - // now for the subsequent expressions, simply remove the expression - for (idx_t i = 1; i < initial_or.children.size(); i++) { - ExtractExpression(initial_or, i, *result); - } - // now we add the expression to the new root - 
new_root->children.push_back(std::move(result)); - } - - // check if we completely erased one of the children of the OR - // this happens if we have an OR in the form of "X OR (X AND A)" - // the left child will be completely empty, as it only contains common expressions - // in this case, any other children are not useful: - // X OR (X AND A) is the same as "X" - // since (1) only tuples that do not qualify "X" will not pass this predicate - // and (2) all tuples that qualify "X" will pass this predicate - for (idx_t i = 0; i < initial_or.children.size(); i++) { - if (!initial_or.children[i]) { - if (new_root->children.size() <= 1) { - return std::move(new_root->children[0]); - } else { - return std::move(new_root); - } - } - } - // finally we need to add the remaining expressions in the OR to the new root - if (initial_or.children.size() == 1) { - // one child: skip the OR entirely and only add the single child - new_root->children.push_back(std::move(initial_or.children[0])); - } else if (initial_or.children.size() > 1) { - // multiple children still remain: push them into a new OR and add that to the new root - auto new_or = make_uniq(ExpressionType::CONJUNCTION_OR); - for (auto &child : initial_or.children) { - new_or->children.push_back(std::move(child)); - } - new_root->children.push_back(std::move(new_or)); - } - // finally return the new root - if (new_root->children.size() == 1) { - return std::move(new_root->children[0]); - } - return std::move(new_root); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -EmptyNeedleRemovalRule::EmptyNeedleRemovalRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on a FunctionExpression that has a foldable ConstantExpression - auto func = make_uniq(); - func->matchers.push_back(make_uniq()); - func->matchers.push_back(make_uniq()); - func->policy = SetMatcher::Policy::SOME; - - unordered_set functions = {"prefix", "contains", "suffix"}; - func->function = make_uniq(functions); - root = 
std::move(func); -} - -unique_ptr EmptyNeedleRemovalRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &root = bindings[0].get().Cast(); - D_ASSERT(root.children.size() == 2); - auto &prefix_expr = bindings[2].get(); - - // the constant_expr is a scalar expression that we have to fold - if (!prefix_expr.IsFoldable()) { - return nullptr; - } - D_ASSERT(root.return_type.id() == LogicalTypeId::BOOLEAN); - - auto prefix_value = ExpressionExecutor::EvaluateScalar(GetContext(), prefix_expr); - - if (prefix_value.IsNull()) { - return make_uniq(Value(LogicalType::BOOLEAN)); - } - - D_ASSERT(prefix_value.type() == prefix_expr.return_type); - auto &needle_string = StringValue::Get(prefix_value); - - // PREFIX('xyz', '') is TRUE - // PREFIX(NULL, '') is NULL - // so rewrite PREFIX(x, '') to TRUE_OR_NULL(x) - if (needle_string.empty()) { - return ExpressionRewriter::ConstantOrNull(std::move(root.children[0]), Value::BOOLEAN(true)); - } - return nullptr; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -EnumComparisonRule::EnumComparisonRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on a ComparisonExpression that is an Equality and has a VARCHAR and ENUM as its children - auto op = make_uniq(); - // Enum requires expression to be root - op->expr_type = make_uniq(ExpressionType::COMPARE_EQUAL); - for (idx_t i = 0; i < 2; i++) { - auto child = make_uniq(); - child->type = make_uniq(LogicalTypeId::VARCHAR); - child->matcher = make_uniq(); - child->matcher->type = make_uniq(LogicalTypeId::ENUM); - op->matchers.push_back(std::move(child)); - } - root = std::move(op); -} - -bool AreMatchesPossible(LogicalType &left, LogicalType &right) { - LogicalType *small_enum, *big_enum; - if (EnumType::GetSize(left) < EnumType::GetSize(right)) { - small_enum = &left; - big_enum = &right; - } else { - small_enum = &right; - big_enum = &left; - } - auto &string_vec = EnumType::GetValuesInsertOrder(*small_enum); - 
auto string_vec_ptr = FlatVector::GetData(string_vec); - auto size = EnumType::GetSize(*small_enum); - for (idx_t i = 0; i < size; i++) { - auto key = string_vec_ptr[i].GetString(); - if (EnumType::GetPos(*big_enum, key) != -1) { - return true; - } - } - return false; -} -unique_ptr EnumComparisonRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - - auto &root = bindings[0].get().Cast(); - auto &left_child = bindings[1].get().Cast(); - auto &right_child = bindings[3].get().Cast(); - - if (!AreMatchesPossible(left_child.child->return_type, right_child.child->return_type)) { - vector> children; - children.push_back(std::move(root.left)); - children.push_back(std::move(root.right)); - return ExpressionRewriter::ConstantOrNull(std::move(children), Value::BOOLEAN(false)); - } - - if (!is_root || op.type != LogicalOperatorType::LOGICAL_FILTER) { - return nullptr; - } - - auto cast_left_to_right = - BoundCastExpression::AddDefaultCastToType(std::move(left_child.child), right_child.child->return_type, true); - return make_uniq(root.type, std::move(cast_left_to_right), std::move(right_child.child)); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -EqualOrNullSimplification::EqualOrNullSimplification(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on OR conjunction - auto op = make_uniq(); - op->expr_type = make_uniq(ExpressionType::CONJUNCTION_OR); - op->policy = SetMatcher::Policy::SOME; - - // equi comparison on one side - auto equal_child = make_uniq(); - equal_child->expr_type = make_uniq(ExpressionType::COMPARE_EQUAL); - equal_child->policy = SetMatcher::Policy::SOME; - op->matchers.push_back(std::move(equal_child)); - - // AND conjuction on the other - auto and_child = make_uniq(); - and_child->expr_type = make_uniq(ExpressionType::CONJUNCTION_AND); - and_child->policy = SetMatcher::Policy::SOME; - - // IS NULL tests inside AND - auto isnull_child = make_uniq(); - isnull_child->expr_type = 
make_uniq(ExpressionType::OPERATOR_IS_NULL); - // I could try to use std::make_uniq for a copy, but it's available from C++14 only - auto isnull_child2 = make_uniq(); - isnull_child2->expr_type = make_uniq(ExpressionType::OPERATOR_IS_NULL); - and_child->matchers.push_back(std::move(isnull_child)); - and_child->matchers.push_back(std::move(isnull_child2)); - - op->matchers.push_back(std::move(and_child)); - root = std::move(op); -} - -// a=b OR (a IS NULL AND b IS NULL) to a IS NOT DISTINCT FROM b -static unique_ptr TryRewriteEqualOrIsNull(Expression &equal_expr, Expression &and_expr) { - if (equal_expr.type != ExpressionType::COMPARE_EQUAL || and_expr.type != ExpressionType::CONJUNCTION_AND) { - return nullptr; - } - - auto &equal_cast = equal_expr.Cast(); - auto &and_cast = and_expr.Cast(); - - if (and_cast.children.size() != 2) { - return nullptr; - } - - // Make sure on the AND conjuction the relevant conditions appear - auto &a_exp = *equal_cast.left; - auto &b_exp = *equal_cast.right; - bool a_is_null_found = false; - bool b_is_null_found = false; - - for (const auto &item : and_cast.children) { - auto &next_exp = *item; - - if (next_exp.type == ExpressionType::OPERATOR_IS_NULL) { - auto &next_exp_cast = next_exp.Cast(); - auto &child = *next_exp_cast.children[0]; - - // Test for equality on both 'a' and 'b' expressions - if (Expression::Equals(child, a_exp)) { - a_is_null_found = true; - } else if (Expression::Equals(child, b_exp)) { - b_is_null_found = true; - } else { - return nullptr; - } - } else { - return nullptr; - } - } - if (a_is_null_found && b_is_null_found) { - return make_uniq(ExpressionType::COMPARE_NOT_DISTINCT_FROM, - std::move(equal_cast.left), std::move(equal_cast.right)); - } - return nullptr; -} - -unique_ptr EqualOrNullSimplification::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - const Expression &or_exp = bindings[0].get(); - - if (or_exp.type != ExpressionType::CONJUNCTION_OR) { - return nullptr; 
- } - - const auto &or_exp_cast = or_exp.Cast(); - - if (or_exp_cast.children.size() != 2) { - return nullptr; - } - - auto &left_exp = *or_exp_cast.children[0]; - auto &right_exp = *or_exp_cast.children[1]; - // Test for: a=b OR (a IS NULL AND b IS NULL) - auto first_try = TryRewriteEqualOrIsNull(left_exp, right_exp); - if (first_try) { - return first_try; - } - // Test for: (a IS NULL AND b IS NULL) OR a=b - return TryRewriteEqualOrIsNull(right_exp, left_exp); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -InClauseSimplificationRule::InClauseSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on InClauseExpression that has a ConstantExpression as a check - auto op = make_uniq(); - op->policy = SetMatcher::Policy::SOME; - root = std::move(op); -} - -unique_ptr InClauseSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &expr = bindings[0].get().Cast(); - if (expr.children[0]->expression_class != ExpressionClass::BOUND_CAST) { - return nullptr; - } - auto &cast_expression = expr.children[0]->Cast(); - if (cast_expression.child->expression_class != ExpressionClass::BOUND_COLUMN_REF) { - return nullptr; - } - //! Here we check if we can apply the expression on the constant side - auto target_type = cast_expression.source_type(); - if (!BoundCastExpression::CastIsInvertible(cast_expression.return_type, target_type)) { - return nullptr; - } - vector> cast_list; - //! 
First check if we can cast all children - for (size_t i = 1; i < expr.children.size(); i++) { - if (expr.children[i]->expression_class != ExpressionClass::BOUND_CONSTANT) { - return nullptr; - } - D_ASSERT(expr.children[i]->IsFoldable()); - auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), *expr.children[i]); - auto new_constant = constant_value.DefaultTryCastAs(target_type); - if (!new_constant) { - return nullptr; - } else { - auto new_constant_expr = make_uniq(constant_value); - cast_list.push_back(std::move(new_constant_expr)); - } - } - //! We can cast, so we move the new constant - for (size_t i = 1; i < expr.children.size(); i++) { - expr.children[i] = std::move(cast_list[i - 1]); - - // expr->children[i] = std::move(new_constant_expr); - } - //! We can cast the full list, so we move the column - expr.children[0] = std::move(cast_expression.child); - return nullptr; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -LikeOptimizationRule::LikeOptimizationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on a FunctionExpression that has a foldable ConstantExpression - auto func = make_uniq(); - func->matchers.push_back(make_uniq()); - func->matchers.push_back(make_uniq()); - func->policy = SetMatcher::Policy::ORDERED; - // we match on LIKE ("~~") and NOT LIKE ("!~~") - func->function = make_uniq(unordered_set {"!~~", "~~"}); - root = std::move(func); -} - -static bool PatternIsConstant(const string &pattern) { - for (idx_t i = 0; i < pattern.size(); i++) { - if (pattern[i] == '%' || pattern[i] == '_') { - return false; - } - } - return true; -} - -static bool PatternIsPrefix(const string &pattern) { - idx_t i; - for (i = pattern.size(); i > 0; i--) { - if (pattern[i - 1] != '%') { - break; - } - } - if (i == pattern.size()) { - // no trailing % - // cannot be a prefix - return false; - } - // continue to look in the string - // if there is a % or _ in the string (besides at the very end) this is not a prefix 
match - for (; i > 0; i--) { - if (pattern[i - 1] == '%' || pattern[i - 1] == '_') { - return false; - } - } - return true; -} - -static bool PatternIsSuffix(const string &pattern) { - idx_t i; - for (i = 0; i < pattern.size(); i++) { - if (pattern[i] != '%') { - break; - } - } - if (i == 0) { - // no leading % - // cannot be a suffix - return false; - } - // continue to look in the string - // if there is a % or _ in the string (besides at the beginning) this is not a suffix match - for (; i < pattern.size(); i++) { - if (pattern[i] == '%' || pattern[i] == '_') { - return false; - } - } - return true; -} - -static bool PatternIsContains(const string &pattern) { - idx_t start; - idx_t end; - for (start = 0; start < pattern.size(); start++) { - if (pattern[start] != '%') { - break; - } - } - for (end = pattern.size(); end > 0; end--) { - if (pattern[end - 1] != '%') { - break; - } - } - if (start == 0 || end == pattern.size()) { - // contains requires both a leading AND a trailing % - return false; - } - // check if there are any other special characters in the string - // if there is a % or _ in the string (besides at the beginning/end) this is not a contains match - for (idx_t i = start; i < end; i++) { - if (pattern[i] == '%' || pattern[i] == '_') { - return false; - } - } - return true; -} - -unique_ptr LikeOptimizationRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &root = bindings[0].get().Cast(); - auto &constant_expr = bindings[2].get().Cast(); - D_ASSERT(root.children.size() == 2); - - if (constant_expr.value.IsNull()) { - return make_uniq(Value(root.return_type)); - } - - // the constant_expr is a scalar expression that we have to fold - if (!constant_expr.IsFoldable()) { - return nullptr; - } - - auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), constant_expr); - D_ASSERT(constant_value.type() == constant_expr.return_type); - auto &patt_str = StringValue::Get(constant_value); - - bool 
is_not_like = root.function.name == "!~~"; - if (PatternIsConstant(patt_str)) { - // Pattern is constant - return make_uniq(is_not_like ? ExpressionType::COMPARE_NOTEQUAL - : ExpressionType::COMPARE_EQUAL, - std::move(root.children[0]), std::move(root.children[1])); - } else if (PatternIsPrefix(patt_str)) { - // Prefix LIKE pattern : [^%_]*[%]+, ignoring underscore - return ApplyRule(root, PrefixFun::GetFunction(), patt_str, is_not_like); - } else if (PatternIsSuffix(patt_str)) { - // Suffix LIKE pattern: [%]+[^%_]*, ignoring underscore - return ApplyRule(root, SuffixFun::GetFunction(), patt_str, is_not_like); - } else if (PatternIsContains(patt_str)) { - // Contains LIKE pattern: [%]+[^%_]*[%]+, ignoring underscore - return ApplyRule(root, ContainsFun::GetFunction(), patt_str, is_not_like); - } - return nullptr; -} - -unique_ptr LikeOptimizationRule::ApplyRule(BoundFunctionExpression &expr, ScalarFunction function, - string pattern, bool is_not_like) { - // replace LIKE by an optimized function - unique_ptr result; - auto new_function = - make_uniq(expr.return_type, std::move(function), std::move(expr.children), nullptr); - - // removing "%" from the pattern - pattern.erase(std::remove(pattern.begin(), pattern.end(), '%'), pattern.end()); - - new_function->children[1] = make_uniq(Value(std::move(pattern))); - - result = std::move(new_function); - if (is_not_like) { - auto negation = make_uniq(ExpressionType::OPERATOR_NOT, LogicalType::BOOLEAN); - negation->children.push_back(std::move(result)); - result = std::move(negation); - } - - return result; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - auto op = make_uniq(); - op->matchers.push_back(make_uniq()); - op->policy = SetMatcher::Policy::UNORDERED; - - auto arithmetic = make_uniq(); - // we handle multiplication, addition and subtraction because those are "easy" - // integer division makes the division 
case difficult - // e.g. [x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules - arithmetic->function = make_uniq(unordered_set {"+", "-", "*"}); - // we match only on integral numeric types - arithmetic->type = make_uniq(); - arithmetic->matchers.push_back(make_uniq()); - arithmetic->matchers.push_back(make_uniq()); - arithmetic->policy = SetMatcher::Policy::SOME; - op->matchers.push_back(std::move(arithmetic)); - root = std::move(op); -} - -unique_ptr MoveConstantsRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &comparison = bindings[0].get().Cast(); - auto &outer_constant = bindings[1].get().Cast(); - auto &arithmetic = bindings[2].get().Cast(); - auto &inner_constant = bindings[3].get().Cast(); - if (!TypeIsIntegral(arithmetic.return_type.InternalType())) { - return nullptr; - } - if (inner_constant.value.IsNull() || outer_constant.value.IsNull()) { - return make_uniq(Value(comparison.return_type)); - } - auto &constant_type = outer_constant.return_type; - hugeint_t outer_value = IntegralValue::Get(outer_constant.value); - hugeint_t inner_value = IntegralValue::Get(inner_constant.value); - - idx_t arithmetic_child_index = arithmetic.children[0].get() == &inner_constant ? 
1 : 0; - auto &op_type = arithmetic.function.name; - if (op_type == "+") { - // [x + 1 COMP 10] OR [1 + x COMP 10] - // order does not matter in addition: - // simply change right side to 10-1 (outer_constant - inner_constant) - if (!Hugeint::SubtractInPlace(outer_value, inner_value)) { - return nullptr; - } - auto result_value = Value::HUGEINT(outer_value); - if (!result_value.DefaultTryCastAs(constant_type)) { - if (comparison.type != ExpressionType::COMPARE_EQUAL) { - return nullptr; - } - // if the cast is not possible then the comparison is not possible - // for example, if we have x + 5 = 3, where x is an unsigned number, we will get x = -2 - // since this is not possible we can remove the entire branch here - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(false)); - } - outer_constant.value = std::move(result_value); - } else if (op_type == "-") { - // [x - 1 COMP 10] O R [1 - x COMP 10] - // order matters in subtraction: - if (arithmetic_child_index == 0) { - // [x - 1 COMP 10] - // change right side to 10+1 (outer_constant + inner_constant) - if (!Hugeint::AddInPlace(outer_value, inner_value)) { - return nullptr; - } - auto result_value = Value::HUGEINT(outer_value); - if (!result_value.DefaultTryCastAs(constant_type)) { - // if the cast is not possible then an equality comparison is not possible - if (comparison.type != ExpressionType::COMPARE_EQUAL) { - return nullptr; - } - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(false)); - } - outer_constant.value = std::move(result_value); - } else { - // [1 - x COMP 10] - // change right side to 1-10=-9 - if (!Hugeint::SubtractInPlace(inner_value, outer_value)) { - return nullptr; - } - auto result_value = Value::HUGEINT(inner_value); - if (!result_value.DefaultTryCastAs(constant_type)) { - // if the cast is not possible then an equality comparison is not possible - if 
(comparison.type != ExpressionType::COMPARE_EQUAL) { - return nullptr; - } - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(false)); - } - outer_constant.value = std::move(result_value); - // in this case, we should also flip the comparison - // e.g. if we have [4 - x < 2] then we should have [x > 2] - comparison.type = FlipComparisonExpression(comparison.type); - } - } else { - D_ASSERT(op_type == "*"); - // [x * 2 COMP 10] OR [2 * x COMP 10] - // order does not matter in multiplication: - // change right side to 10/2 (outer_constant / inner_constant) - // but ONLY if outer_constant is cleanly divisible by the inner_constant - if (inner_value == 0) { - // x * 0, the result is either 0 or NULL - // we let the arithmetic_simplification rule take care of simplifying this first - return nullptr; - } - if (outer_value % inner_value != 0) { - // not cleanly divisible - bool is_equality = comparison.type == ExpressionType::COMPARE_EQUAL; - bool is_inequality = comparison.type == ExpressionType::COMPARE_NOTEQUAL; - if (is_equality || is_inequality) { - // we know the values are not equal - // the result will be either FALSE or NULL (if COMPARE_EQUAL) - // or TRUE or NULL (if COMPARE_NOTEQUAL) - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(is_inequality)); - } else { - // not cleanly divisible and we are doing > >= < <=, skip the simplification for now - return nullptr; - } - } - if (inner_value < 0) { - // multiply by negative value, need to flip expression - comparison.type = FlipComparisonExpression(comparison.type); - } - // else divide the RHS by the LHS - // we need to do a range check on the cast even though we do a division - // because e.g. 
-128 / -1 = 128, which is out of range - auto result_value = Value::HUGEINT(outer_value / inner_value); - if (!result_value.DefaultTryCastAs(constant_type)) { - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(false)); - } - outer_constant.value = std::move(result_value); - } - // replace left side with x - // first extract x from the arithmetic expression - auto arithmetic_child = std::move(arithmetic.children[arithmetic_child_index]); - // then place in the comparison - if (comparison.left.get() == &outer_constant) { - comparison.right = std::move(arithmetic_child); - } else { - comparison.left = std::move(arithmetic_child); - } - changes_made = true; - return nullptr; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -OrderedAggregateOptimizer::OrderedAggregateOptimizer(ExpressionRewriter &rewriter) : Rule(rewriter) { - // we match on an OR expression within a LogicalFilter node - root = make_uniq(); - root->expr_class = ExpressionClass::BOUND_AGGREGATE; -} - -unique_ptr OrderedAggregateOptimizer::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &aggr = bindings[0].get().Cast(); - if (!aggr.order_bys) { - // no ORDER BYs defined - return nullptr; - } - if (aggr.function.order_dependent == AggregateOrderDependent::NOT_ORDER_DEPENDENT) { - // not an order dependent aggregate but we have an ORDER BY clause - remove it - aggr.order_bys.reset(); - changes_made = true; - return nullptr; - } - return nullptr; -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -RegexOptimizationRule::RegexOptimizationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - auto func = make_uniq(); - func->function = make_uniq("regexp_matches"); - func->policy = SetMatcher::Policy::SOME_ORDERED; - func->matchers.push_back(make_uniq()); - func->matchers.push_back(make_uniq()); - - root = std::move(func); -} - -struct LikeString { - bool exists = true; - 
bool escaped = false; - string like_string = ""; -}; - -static void AddCharacter(char chr, LikeString &ret, bool contains) { - // if we are not converting into a contains, and the string has LIKE special characters - // then don't return a possible LIKE match - // same if the character is a control character - if (iscntrl(chr) || (!contains && (chr == '%' || chr == '_'))) { - ret.exists = false; - return; - } - auto run_as_str {chr}; - ret.like_string += run_as_str; -} - -static LikeString GetLikeStringEscaped(duckdb_re2::Regexp *regexp, bool contains = false) { - D_ASSERT(regexp->op() == duckdb_re2::kRegexpLiteralString || regexp->op() == duckdb_re2::kRegexpLiteral); - LikeString ret; - - if (regexp->parse_flags() & duckdb_re2::Regexp::FoldCase || - !(regexp->parse_flags() & duckdb_re2::Regexp::OneLine)) { - // parse flags can turn on and off within a regex match, return no optimization - // For now, we just don't optimize if these every turn on. - // TODO: logic to attempt the optimization, then if the parse flags change, then abort - ret.exists = false; - return ret; - } - - // case insensitivity may be on now, but it can also turn off. 
- if (regexp->op() == duckdb_re2::kRegexpLiteralString) { - auto nrunes = (idx_t)regexp->nrunes(); - auto runes = regexp->runes(); - for (idx_t i = 0; i < nrunes; i++) { - char chr = toascii(runes[i]); - AddCharacter(chr, ret, contains); - if (!ret.exists) { - return ret; - } - } - } else { - auto rune = regexp->rune(); - char chr = toascii(rune); - AddCharacter(chr, ret, contains); - } - D_ASSERT(ret.like_string.size() >= 1 || !ret.exists); - return ret; -} - -static LikeString LikeMatchFromRegex(duckdb_re2::RE2 &pattern) { - LikeString ret = LikeString(); - auto num_subs = pattern.Regexp()->nsub(); - auto subs = pattern.Regexp()->sub(); - auto cur_sub_index = 0; - while (cur_sub_index < num_subs) { - switch (subs[cur_sub_index]->op()) { - case duckdb_re2::kRegexpAnyChar: - if (cur_sub_index == 0) { - ret.like_string += "%"; - } - ret.like_string += "_"; - if (cur_sub_index + 1 == num_subs) { - ret.like_string += "%"; - } - break; - case duckdb_re2::kRegexpStar: - // .* is a Star operator is a anyChar operator as a child. - // any other child operator would represent a pattern LIKE cannot match. 
- if (subs[cur_sub_index]->nsub() == 1 && subs[cur_sub_index]->sub()[0]->op() == duckdb_re2::kRegexpAnyChar) { - ret.like_string += "%"; - break; - } - ret.exists = false; - return ret; - case duckdb_re2::kRegexpLiteralString: - case duckdb_re2::kRegexpLiteral: { - // if this is the only matching op, we should have directly called - // GetEscapedLikeString - D_ASSERT(!(cur_sub_index == 0 && cur_sub_index + 1 == num_subs)); - if (cur_sub_index == 0) { - ret.like_string += "%"; - } - // if the kRegexpLiteral or kRegexpLiteralString is the only op to match - // the string can directly be converted into a contains - LikeString escaped_like_string = GetLikeStringEscaped(subs[cur_sub_index], false); - if (!escaped_like_string.exists) { - return escaped_like_string; - } - ret.like_string += escaped_like_string.like_string; - ret.escaped = escaped_like_string.escaped; - if (cur_sub_index + 1 == num_subs) { - ret.like_string += "%"; - } - break; - } - case duckdb_re2::kRegexpEndText: - case duckdb_re2::kRegexpEmptyMatch: - case duckdb_re2::kRegexpBeginText: { - break; - } - default: - // some other regexp op that doesn't have an equivalent to a like string - // return false; - ret.exists = false; - return ret; - } - cur_sub_index += 1; - } - return ret; -} - -unique_ptr RegexOptimizationRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &root = bindings[0].get().Cast(); - auto &constant_expr = bindings[2].get().Cast(); - D_ASSERT(root.children.size() == 2 || root.children.size() == 3); - auto regexp_bind_data = root.bind_info.get()->Cast(); - - auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), constant_expr); - D_ASSERT(constant_value.type() == constant_expr.return_type); - auto patt_str = StringValue::Get(constant_value); - - duckdb_re2::RE2::Options parsed_options = regexp_bind_data.options; - - if (constant_expr.value.IsNull()) { - return make_uniq(Value(root.return_type)); - } - - // the constant_expr is 
a scalar expression that we have to fold - if (!constant_expr.IsFoldable()) { - return nullptr; - }; - - duckdb_re2::RE2 pattern(patt_str, parsed_options); - if (!pattern.ok()) { - return nullptr; // this should fail somewhere else - } - - LikeString like_string; - // check for a like string. If we can convert it to a like string, the like string - // optimizer will further optimize suffix and prefix things. - if (pattern.Regexp()->op() == duckdb_re2::kRegexpLiteralString || - pattern.Regexp()->op() == duckdb_re2::kRegexpLiteral) { - // convert to contains. - LikeString escaped_like_string = GetLikeStringEscaped(pattern.Regexp(), true); - if (!escaped_like_string.exists) { - return nullptr; - } - auto parameter = make_uniq(Value(std::move(escaped_like_string.like_string))); - auto contains = make_uniq(root.return_type, ContainsFun::GetFunction(), - std::move(root.children), nullptr); - contains->children[1] = std::move(parameter); - - return std::move(contains); - } else if (pattern.Regexp()->op() == duckdb_re2::kRegexpConcat) { - like_string = LikeMatchFromRegex(pattern); - } else { - like_string.exists = false; - } - - if (!like_string.exists) { - return nullptr; - } - - // if regexp had options, remove them so the new Like Expression can be matched for other optimizers. 
- if (root.children.size() == 3) { - root.children.pop_back(); - D_ASSERT(root.children.size() == 2); - } - - auto like_expression = make_uniq(root.return_type, LikeFun::GetLikeFunction(), - std::move(root.children), nullptr); - auto parameter = make_uniq(Value(std::move(like_string.like_string))); - like_expression->children[1] = std::move(parameter); - return std::move(like_expression); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateExpression(BoundAggregateExpression &aggr, - unique_ptr *expr_ptr) { - vector stats; - stats.reserve(aggr.children.size()); - for (auto &child : aggr.children) { - auto stat = PropagateExpression(child); - if (!stat) { - stats.push_back(BaseStatistics::CreateUnknown(child->return_type)); - } else { - stats.push_back(stat->Copy()); - } - } - if (!aggr.function.statistics) { - return nullptr; - } - AggregateStatisticsInput input(aggr.bind_info.get(), stats, node_stats.get()); - return aggr.function.statistics(context, aggr, input); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateExpression(BoundBetweenExpression &between, - unique_ptr *expr_ptr) { - // propagate in all the children - auto input_stats = PropagateExpression(between.input); - auto lower_stats = PropagateExpression(between.lower); - auto upper_stats = PropagateExpression(between.upper); - if (!input_stats) { - return nullptr; - } - auto lower_comparison = between.LowerComparisonType(); - auto upper_comparison = between.UpperComparisonType(); - // propagate the comparisons - auto lower_prune = FilterPropagateResult::NO_PRUNING_POSSIBLE; - auto upper_prune = FilterPropagateResult::NO_PRUNING_POSSIBLE; - if (lower_stats) { - lower_prune = PropagateComparison(*input_stats, *lower_stats, lower_comparison); - } - if (upper_stats) { - upper_prune = PropagateComparison(*input_stats, *upper_stats, upper_comparison); - } - if (lower_prune == 
FilterPropagateResult::FILTER_ALWAYS_TRUE && - upper_prune == FilterPropagateResult::FILTER_ALWAYS_TRUE) { - // both filters are always true: replace the between expression with a constant true - *expr_ptr = make_uniq(Value::BOOLEAN(true)); - } else if (lower_prune == FilterPropagateResult::FILTER_ALWAYS_FALSE || - upper_prune == FilterPropagateResult::FILTER_ALWAYS_FALSE) { - // either one of the filters is always false: replace the between expression with a constant false - *expr_ptr = make_uniq(Value::BOOLEAN(false)); - } else if (lower_prune == FilterPropagateResult::FILTER_FALSE_OR_NULL || - upper_prune == FilterPropagateResult::FILTER_FALSE_OR_NULL) { - // either one of the filters is false or null: replace with a constant or null (false) - vector> children; - children.push_back(std::move(between.input)); - children.push_back(std::move(between.lower)); - children.push_back(std::move(between.upper)); - *expr_ptr = ExpressionRewriter::ConstantOrNull(std::move(children), Value::BOOLEAN(false)); - } else if (lower_prune == FilterPropagateResult::FILTER_TRUE_OR_NULL && - upper_prune == FilterPropagateResult::FILTER_TRUE_OR_NULL) { - // both filters are true or null: replace with a true or null - vector> children; - children.push_back(std::move(between.input)); - children.push_back(std::move(between.lower)); - children.push_back(std::move(between.upper)); - *expr_ptr = ExpressionRewriter::ConstantOrNull(std::move(children), Value::BOOLEAN(true)); - } else if (lower_prune == FilterPropagateResult::FILTER_ALWAYS_TRUE) { - // lower filter is always true: replace with upper comparison - *expr_ptr = - make_uniq(upper_comparison, std::move(between.input), std::move(between.upper)); - } else if (upper_prune == FilterPropagateResult::FILTER_ALWAYS_TRUE) { - // upper filter is always true: replace with lower comparison - *expr_ptr = - make_uniq(lower_comparison, std::move(between.input), std::move(between.lower)); - } - return nullptr; -} - -} // namespace duckdb - - - 
-namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateExpression(BoundCaseExpression &bound_case, - unique_ptr *expr_ptr) { - // propagate in all the children - auto result_stats = PropagateExpression(bound_case.else_expr); - for (auto &case_check : bound_case.case_checks) { - PropagateExpression(case_check.when_expr); - auto then_stats = PropagateExpression(case_check.then_expr); - if (!then_stats) { - result_stats.reset(); - } else if (result_stats) { - result_stats->Merge(*then_stats); - } - } - return result_stats; -} - -} // namespace duckdb - - - -namespace duckdb { - -static unique_ptr StatisticsOperationsNumericNumericCast(const BaseStatistics &input, - const LogicalType &target) { - if (!NumericStats::HasMinMax(input)) { - return nullptr; - } - Value min = NumericStats::Min(input); - Value max = NumericStats::Max(input); - if (!min.DefaultTryCastAs(target) || !max.DefaultTryCastAs(target)) { - // overflow in cast: bailout - return nullptr; - } - auto result = NumericStats::CreateEmpty(target); - result.CopyBase(input); - NumericStats::SetMin(result, min); - NumericStats::SetMax(result, max); - return result.ToUnique(); -} - -static unique_ptr StatisticsNumericCastSwitch(const BaseStatistics &input, const LogicalType &target) { - // Downcasting timestamps to times is not a truncation operation - switch (target.id()) { - case LogicalTypeId::TIME: - switch (input.GetType().id()) { - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return nullptr; - default: - break; - } - default: - break; - } - - switch (target.InternalType()) { - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - case PhysicalType::INT128: - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - return StatisticsOperationsNumericNumericCast(input, target); - default: - return nullptr; - } -} - -unique_ptr StatisticsPropagator::PropagateExpression(BoundCastExpression &cast, - unique_ptr *expr_ptr) { 
- auto child_stats = PropagateExpression(cast.child); - if (!child_stats) { - return nullptr; - } - unique_ptr result_stats; - switch (cast.child->return_type.InternalType()) { - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - case PhysicalType::INT128: - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - result_stats = StatisticsNumericCastSwitch(*child_stats, cast.return_type); - break; - default: - return nullptr; - } - if (cast.try_cast && result_stats) { - result_stats->Set(StatsInfo::CAN_HAVE_NULL_VALUES); - } - return result_stats; -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateExpression(BoundColumnRefExpression &colref, - unique_ptr *expr_ptr) { - auto stats = statistics_map.find(colref.binding); - if (stats == statistics_map.end()) { - return nullptr; - } - return stats->second->ToUnique(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -FilterPropagateResult StatisticsPropagator::PropagateComparison(BaseStatistics &lstats, BaseStatistics &rstats, - ExpressionType comparison) { - // only handle numerics for now - switch (lstats.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::UINT8: - case PhysicalType::UINT16: - case PhysicalType::UINT32: - case PhysicalType::UINT64: - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - case PhysicalType::INT128: - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - break; - default: - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - if (!NumericStats::HasMinMax(lstats) || !NumericStats::HasMinMax(rstats)) { - // no stats available: nothing to prune - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - // the result of the propagation depend on whether or not either side has null values - // if there are null values present, we cannot say whether or not - bool has_null = 
lstats.CanHaveNull() || rstats.CanHaveNull(); - switch (comparison) { - case ExpressionType::COMPARE_EQUAL: - // l = r, if l.min > r.max or r.min > l.max equality is not possible - if (NumericStats::Min(lstats) > NumericStats::Max(rstats) || - NumericStats::Min(rstats) > NumericStats::Max(lstats)) { - return has_null ? FilterPropagateResult::FILTER_FALSE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_FALSE; - } else { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - case ExpressionType::COMPARE_GREATERTHAN: - // l > r - if (NumericStats::Min(lstats) > NumericStats::Max(rstats)) { - // if l.min > r.max, it is always true ONLY if neither side contains nulls - return has_null ? FilterPropagateResult::FILTER_TRUE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - // if r.min is bigger or equal to l.max, the filter is always false - if (NumericStats::Min(rstats) >= NumericStats::Max(lstats)) { - return has_null ? FilterPropagateResult::FILTER_FALSE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - // l >= r - if (NumericStats::Min(lstats) >= NumericStats::Max(rstats)) { - // if l.min >= r.max, it is always true ONLY if neither side contains nulls - return has_null ? FilterPropagateResult::FILTER_TRUE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - // if r.min > l.max, the filter is always false - if (NumericStats::Min(rstats) > NumericStats::Max(lstats)) { - return has_null ? FilterPropagateResult::FILTER_FALSE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - case ExpressionType::COMPARE_LESSTHAN: - // l < r - if (NumericStats::Max(lstats) < NumericStats::Min(rstats)) { - // if l.max < r.min, it is always true ONLY if neither side contains nulls - return has_null ? 
FilterPropagateResult::FILTER_TRUE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - // if l.min >= rstats.max, the filter is always false - if (NumericStats::Min(lstats) >= NumericStats::Max(rstats)) { - return has_null ? FilterPropagateResult::FILTER_FALSE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - // l <= r - if (NumericStats::Max(lstats) <= NumericStats::Min(rstats)) { - // if l.max <= r.min, it is always true ONLY if neither side contains nulls - return has_null ? FilterPropagateResult::FILTER_TRUE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - // if l.min > rstats.max, the filter is always false - if (NumericStats::Min(lstats) > NumericStats::Max(rstats)) { - return has_null ? FilterPropagateResult::FILTER_FALSE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - default: - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } -} - -unique_ptr StatisticsPropagator::PropagateExpression(BoundComparisonExpression &expr, - unique_ptr *expr_ptr) { - auto left_stats = PropagateExpression(expr.left); - auto right_stats = PropagateExpression(expr.right); - if (!left_stats || !right_stats) { - return nullptr; - } - // propagate the statistics of the comparison operator - auto propagate_result = PropagateComparison(*left_stats, *right_stats, expr.type); - switch (propagate_result) { - case FilterPropagateResult::FILTER_ALWAYS_TRUE: - *expr_ptr = make_uniq(Value::BOOLEAN(true)); - return PropagateExpression(*expr_ptr); - case FilterPropagateResult::FILTER_ALWAYS_FALSE: - *expr_ptr = make_uniq(Value::BOOLEAN(false)); - return PropagateExpression(*expr_ptr); - case FilterPropagateResult::FILTER_TRUE_OR_NULL: { - vector> children; - children.push_back(std::move(expr.left)); - children.push_back(std::move(expr.right)); - *expr_ptr = 
ExpressionRewriter::ConstantOrNull(std::move(children), Value::BOOLEAN(true)); - return nullptr; - } - case FilterPropagateResult::FILTER_FALSE_OR_NULL: { - vector> children; - children.push_back(std::move(expr.left)); - children.push_back(std::move(expr.right)); - *expr_ptr = ExpressionRewriter::ConstantOrNull(std::move(children), Value::BOOLEAN(false)); - return nullptr; - } - default: - // FIXME: we can propagate nulls here, i.e. this expression will have nulls only if left and right has nulls - return nullptr; - } -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateExpression(BoundConjunctionExpression &expr, - unique_ptr *expr_ptr) { - auto is_and = expr.type == ExpressionType::CONJUNCTION_AND; - for (idx_t expr_idx = 0; expr_idx < expr.children.size(); expr_idx++) { - auto &child = expr.children[expr_idx]; - auto stats = PropagateExpression(child); - if (!child->IsFoldable()) { - continue; - } - // we have a constant in a conjunction - // we (1) either prune the child - // or (2) replace the entire conjunction with a constant - auto constant = ExpressionExecutor::EvaluateScalar(context, *child); - if (constant.IsNull()) { - continue; - } - auto b = BooleanValue::Get(constant); - bool prune_child = false; - bool constant_value = true; - if (b) { - // true - if (is_and) { - // true in and: prune child - prune_child = true; - } else { - // true in OR: replace with TRUE - constant_value = true; - } - } else { - // false - if (is_and) { - // false in AND: replace with FALSE - constant_value = false; - } else { - // false in OR: prune child - prune_child = true; - } - } - if (prune_child) { - expr.children.erase(expr.children.begin() + expr_idx); - expr_idx--; - continue; - } - *expr_ptr = make_uniq(Value::BOOLEAN(constant_value)); - return PropagateExpression(*expr_ptr); - } - if (expr.children.empty()) { - // if there are no children left, replace the conjunction with TRUE (for AND) or FALSE (for OR) - 
*expr_ptr = make_uniq(Value::BOOLEAN(is_and)); - return PropagateExpression(*expr_ptr); - } else if (expr.children.size() == 1) { - // if there is one child left, replace the conjunction with that one child - *expr_ptr = std::move(expr.children[0]); - } - return nullptr; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::StatisticsFromValue(const Value &input) { - return BaseStatistics::FromConstant(input).ToUnique(); -} - -unique_ptr StatisticsPropagator::PropagateExpression(BoundConstantExpression &constant, - unique_ptr *expr_ptr) { - return StatisticsFromValue(constant.value); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateExpression(BoundFunctionExpression &func, - unique_ptr *expr_ptr) { - vector stats; - stats.reserve(func.children.size()); - for (idx_t i = 0; i < func.children.size(); i++) { - auto stat = PropagateExpression(func.children[i]); - if (!stat) { - stats.push_back(BaseStatistics::CreateUnknown(func.children[i]->return_type)); - } else { - stats.push_back(stat->Copy()); - } - } - if (!func.function.statistics) { - return nullptr; - } - FunctionStatisticsInput input(func, func.bind_info.get(), stats, expr_ptr); - return func.function.statistics(context, input); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateExpression(BoundOperatorExpression &expr, - unique_ptr *expr_ptr) { - bool all_have_stats = true; - vector> child_stats; - child_stats.reserve(expr.children.size()); - for (auto &child : expr.children) { - auto stats = PropagateExpression(child); - if (!stats) { - all_have_stats = false; - } - child_stats.push_back(std::move(stats)); - } - if (!all_have_stats) { - return nullptr; - } - switch (expr.type) { - case ExpressionType::OPERATOR_COALESCE: - // COALESCE, merge stats of all children - for (idx_t i = 0; i < expr.children.size(); i++) { - D_ASSERT(child_stats[i]); - if 
(!child_stats[i]->CanHaveNoNull()) { - // this child is always NULL, we can remove it from the coalesce - // UNLESS there is only one node remaining - if (expr.children.size() > 1) { - expr.children.erase(expr.children.begin() + i); - child_stats.erase(child_stats.begin() + i); - i--; - } - } else if (!child_stats[i]->CanHaveNull()) { - // coalesce child cannot have NULL entries - // this is the last coalesce node that influences the result - // we can erase any children after this node - if (i + 1 < expr.children.size()) { - expr.children.erase(expr.children.begin() + i + 1, expr.children.end()); - child_stats.erase(child_stats.begin() + i + 1, child_stats.end()); - } - break; - } - } - D_ASSERT(!expr.children.empty()); - D_ASSERT(expr.children.size() == child_stats.size()); - if (expr.children.size() == 1) { - // coalesce of one entry: simply return that entry - *expr_ptr = std::move(expr.children[0]); - } else { - // coalesce of multiple entries - // merge the stats - for (idx_t i = 1; i < expr.children.size(); i++) { - child_stats[0]->Merge(*child_stats[i]); - } - } - return std::move(child_stats[0]); - case ExpressionType::OPERATOR_IS_NULL: - if (!child_stats[0]->CanHaveNull()) { - // child has no null values: x IS NULL will always be false - *expr_ptr = make_uniq(Value::BOOLEAN(false)); - return PropagateExpression(*expr_ptr); - } - if (!child_stats[0]->CanHaveNoNull()) { - // child has no valid values: x IS NULL will always be true - *expr_ptr = make_uniq(Value::BOOLEAN(true)); - return PropagateExpression(*expr_ptr); - } - return nullptr; - case ExpressionType::OPERATOR_IS_NOT_NULL: - if (!child_stats[0]->CanHaveNull()) { - // child has no null values: x IS NOT NULL will always be true - *expr_ptr = make_uniq(Value::BOOLEAN(true)); - return PropagateExpression(*expr_ptr); - } - if (!child_stats[0]->CanHaveNoNull()) { - // child has no valid values: x IS NOT NULL will always be false - *expr_ptr = make_uniq(Value::BOOLEAN(false)); - return 
PropagateExpression(*expr_ptr); - } - return nullptr; - default: - return nullptr; - } -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateStatistics(LogicalAggregate &aggr, - unique_ptr *node_ptr) { - // first propagate statistics in the child node - node_stats = PropagateStatistics(aggr.children[0]); - - // handle the groups: simply propagate statistics and assign the stats to the group binding - aggr.group_stats.resize(aggr.groups.size()); - for (idx_t group_idx = 0; group_idx < aggr.groups.size(); group_idx++) { - auto stats = PropagateExpression(aggr.groups[group_idx]); - aggr.group_stats[group_idx] = stats ? stats->ToUnique() : nullptr; - if (!stats) { - continue; - } - if (aggr.grouping_sets.size() > 1) { - // aggregates with multiple grouping sets can introduce NULL values to certain groups - // FIXME: actually figure out WHICH groups can have null values introduced - stats->Set(StatsInfo::CAN_HAVE_NULL_VALUES); - continue; - } - ColumnBinding group_binding(aggr.group_index, group_idx); - statistics_map[group_binding] = std::move(stats); - } - // propagate statistics in the aggregates - for (idx_t aggregate_idx = 0; aggregate_idx < aggr.expressions.size(); aggregate_idx++) { - auto stats = PropagateExpression(aggr.expressions[aggregate_idx]); - if (!stats) { - continue; - } - ColumnBinding aggregate_binding(aggr.aggregate_index, aggregate_idx); - statistics_map[aggregate_binding] = std::move(stats); - } - // the max cardinality of an aggregate is the max cardinality of the input (i.e. 
when every row is a unique group) - return std::move(node_stats); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateStatistics(LogicalCrossProduct &cp, - unique_ptr *node_ptr) { - // first propagate statistics in the child node - auto left_stats = PropagateStatistics(cp.children[0]); - auto right_stats = PropagateStatistics(cp.children[1]); - if (!left_stats || !right_stats) { - return nullptr; - } - MultiplyCardinalities(left_stats, *right_stats); - return left_stats; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -static bool IsCompareDistinct(ExpressionType type) { - return type == ExpressionType::COMPARE_DISTINCT_FROM || type == ExpressionType::COMPARE_NOT_DISTINCT_FROM; -} - -bool StatisticsPropagator::ExpressionIsConstant(Expression &expr, const Value &val) { - if (expr.GetExpressionClass() != ExpressionClass::BOUND_CONSTANT) { - return false; - } - auto &bound_constant = expr.Cast(); - D_ASSERT(bound_constant.value.type() == val.type()); - return Value::NotDistinctFrom(bound_constant.value, val); -} - -bool StatisticsPropagator::ExpressionIsConstantOrNull(Expression &expr, const Value &val) { - if (expr.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { - return false; - } - auto &bound_function = expr.Cast(); - return ConstantOrNull::IsConstantOrNull(bound_function, val); -} - -void StatisticsPropagator::SetStatisticsNotNull(ColumnBinding binding) { - auto entry = statistics_map.find(binding); - if (entry == statistics_map.end()) { - return; - } - entry->second->Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); -} - -void StatisticsPropagator::UpdateFilterStatistics(BaseStatistics &stats, ExpressionType comparison_type, - const Value &constant) { - // regular comparisons removes all null values - if (!IsCompareDistinct(comparison_type)) { - stats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); - } - if (!stats.GetType().IsNumeric()) { - // don't handle non-numeric columns here (yet) - return; - 
} - if (!NumericStats::HasMinMax(stats)) { - // no stats available: skip this - return; - } - switch (comparison_type) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - // X < constant OR X <= constant - // max becomes the constant - NumericStats::SetMax(stats, constant); - break; - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - // X > constant OR X >= constant - // min becomes the constant - NumericStats::SetMin(stats, constant); - break; - case ExpressionType::COMPARE_EQUAL: - // X = constant - // both min and max become the constant - NumericStats::SetMin(stats, constant); - NumericStats::SetMax(stats, constant); - break; - default: - break; - } -} - -void StatisticsPropagator::UpdateFilterStatistics(BaseStatistics &lstats, BaseStatistics &rstats, - ExpressionType comparison_type) { - // regular comparisons removes all null values - if (!IsCompareDistinct(comparison_type)) { - lstats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); - rstats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); - } - D_ASSERT(lstats.GetType() == rstats.GetType()); - if (!lstats.GetType().IsNumeric()) { - // don't handle non-numeric columns here (yet) - return; - } - if (!NumericStats::HasMinMax(lstats) || !NumericStats::HasMinMax(rstats)) { - // no stats available: skip this - return; - } - switch (comparison_type) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - // LEFT < RIGHT OR LEFT <= RIGHT - // we know that every value of left is smaller (or equal to) every value in right - // i.e. 
if we have left = [-50, 250] and right = [-100, 100] - - // we know that left.max is AT MOST equal to right.max - // because any value in left that is BIGGER than right.max will not pass the filter - if (NumericStats::Max(lstats) > NumericStats::Max(rstats)) { - NumericStats::SetMax(lstats, NumericStats::Max(rstats)); - } - - // we also know that right.min is AT MOST equal to left.min - // because any value in right that is SMALLER than left.min will not pass the filter - if (NumericStats::Min(rstats) < NumericStats::Min(lstats)) { - NumericStats::SetMin(rstats, NumericStats::Min(lstats)); - } - // so in our example, the bounds get updated as follows: - // left: [-50, 100], right: [-50, 100] - break; - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - // LEFT > RIGHT OR LEFT >= RIGHT - // we know that every value of left is bigger (or equal to) every value in right - // this is essentially the inverse of the less than (or equal to) scenario - if (NumericStats::Max(rstats) > NumericStats::Max(lstats)) { - NumericStats::SetMax(rstats, NumericStats::Max(lstats)); - } - if (NumericStats::Min(lstats) < NumericStats::Min(rstats)) { - NumericStats::SetMin(lstats, NumericStats::Min(rstats)); - } - break; - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - // LEFT = RIGHT - // only the tightest bounds pass - // so if we have e.g. 
left = [-50, 250] and right = [-100, 100] - // the tighest bounds are [-50, 100] - // select the highest min - if (NumericStats::Min(lstats) > NumericStats::Min(rstats)) { - NumericStats::SetMin(rstats, NumericStats::Min(lstats)); - } else { - NumericStats::SetMin(lstats, NumericStats::Min(rstats)); - } - // select the lowest max - if (NumericStats::Max(lstats) < NumericStats::Max(rstats)) { - NumericStats::SetMax(rstats, NumericStats::Max(lstats)); - } else { - NumericStats::SetMax(lstats, NumericStats::Max(rstats)); - } - break; - default: - break; - } -} - -void StatisticsPropagator::UpdateFilterStatistics(Expression &left, Expression &right, ExpressionType comparison_type) { - // first check if either side is a bound column ref - // any column ref involved in a comparison will not be null after the comparison - bool compare_distinct = IsCompareDistinct(comparison_type); - if (!compare_distinct && left.type == ExpressionType::BOUND_COLUMN_REF) { - SetStatisticsNotNull((left.Cast()).binding); - } - if (!compare_distinct && right.type == ExpressionType::BOUND_COLUMN_REF) { - SetStatisticsNotNull((right.Cast()).binding); - } - // check if this is a comparison between a constant and a column ref - optional_ptr constant; - optional_ptr columnref; - if (left.type == ExpressionType::VALUE_CONSTANT && right.type == ExpressionType::BOUND_COLUMN_REF) { - constant = &left.Cast(); - columnref = &right.Cast(); - comparison_type = FlipComparisonExpression(comparison_type); - } else if (left.type == ExpressionType::BOUND_COLUMN_REF && right.type == ExpressionType::VALUE_CONSTANT) { - columnref = &left.Cast(); - constant = &right.Cast(); - } else if (left.type == ExpressionType::BOUND_COLUMN_REF && right.type == ExpressionType::BOUND_COLUMN_REF) { - // comparison between two column refs - auto &left_column_ref = left.Cast(); - auto &right_column_ref = right.Cast(); - auto lentry = statistics_map.find(left_column_ref.binding); - auto rentry = 
statistics_map.find(right_column_ref.binding); - if (lentry == statistics_map.end() || rentry == statistics_map.end()) { - return; - } - UpdateFilterStatistics(*lentry->second, *rentry->second, comparison_type); - } else { - // unsupported filter - return; - } - if (constant && columnref) { - // comparison between columnref - auto entry = statistics_map.find(columnref->binding); - if (entry == statistics_map.end()) { - return; - } - UpdateFilterStatistics(*entry->second, comparison_type, constant->value); - } -} - -void StatisticsPropagator::UpdateFilterStatistics(Expression &condition) { - // in filters, we check for constant comparisons with bound columns - // if we find a comparison in the form of e.g. "i=3", we can update our statistics for that column - switch (condition.GetExpressionClass()) { - case ExpressionClass::BOUND_BETWEEN: { - auto &between = condition.Cast(); - UpdateFilterStatistics(*between.input, *between.lower, between.LowerComparisonType()); - UpdateFilterStatistics(*between.input, *between.upper, between.UpperComparisonType()); - break; - } - case ExpressionClass::BOUND_COMPARISON: { - auto &comparison = condition.Cast(); - UpdateFilterStatistics(*comparison.left, *comparison.right, comparison.type); - break; - } - default: - break; - } -} - -unique_ptr StatisticsPropagator::PropagateStatistics(LogicalFilter &filter, - unique_ptr *node_ptr) { - // first propagate to the child - node_stats = PropagateStatistics(filter.children[0]); - if (filter.children[0]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT) { - ReplaceWithEmptyResult(*node_ptr); - return make_uniq(0, 0); - } - - // then propagate to each of the expressions - for (idx_t i = 0; i < filter.expressions.size(); i++) { - auto &condition = filter.expressions[i]; - PropagateExpression(condition); - - if (ExpressionIsConstant(*condition, Value::BOOLEAN(true))) { - // filter is always true; it is useless to execute it - // erase this condition - 
filter.expressions.erase(filter.expressions.begin() + i); - i--; - if (filter.expressions.empty()) { - // all conditions have been erased: remove the entire filter - *node_ptr = std::move(filter.children[0]); - break; - } - } else if (ExpressionIsConstant(*condition, Value::BOOLEAN(false)) || - ExpressionIsConstantOrNull(*condition, Value::BOOLEAN(false))) { - // filter is always false or null; this entire filter should be replaced by an empty result block - ReplaceWithEmptyResult(*node_ptr); - return make_uniq(0, 0); - } else { - // cannot prune this filter: propagate statistics from the filter - UpdateFilterStatistics(*condition); - } - } - // the max cardinality of a filter is the cardinality of the input (i.e. no tuples get filtered) - return std::move(node_stats); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -FilterPropagateResult StatisticsPropagator::PropagateTableFilter(BaseStatistics &stats, TableFilter &filter) { - return filter.CheckStatistics(stats); -} - -void StatisticsPropagator::UpdateFilterStatistics(BaseStatistics &input, TableFilter &filter) { - // FIXME: update stats... 
- switch (filter.filter_type) { - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction_and = filter.Cast(); - for (auto &child_filter : conjunction_and.child_filters) { - UpdateFilterStatistics(input, *child_filter); - } - break; - } - case TableFilterType::CONSTANT_COMPARISON: { - auto &constant_filter = filter.Cast(); - UpdateFilterStatistics(input, constant_filter.comparison_type, constant_filter.constant); - break; - } - default: - break; - } -} - -unique_ptr StatisticsPropagator::PropagateStatistics(LogicalGet &get, - unique_ptr *node_ptr) { - if (get.function.cardinality) { - node_stats = get.function.cardinality(context, get.bind_data.get()); - } - if (!get.function.statistics) { - // no column statistics to get - return std::move(node_stats); - } - for (idx_t i = 0; i < get.column_ids.size(); i++) { - auto stats = get.function.statistics(context, get.bind_data.get(), get.column_ids[i]); - if (stats) { - ColumnBinding binding(get.table_index, i); - statistics_map.insert(make_pair(binding, std::move(stats))); - } - } - // push table filters into the statistics - vector column_indexes; - column_indexes.reserve(get.table_filters.filters.size()); - for (auto &kv : get.table_filters.filters) { - column_indexes.push_back(kv.first); - } - - for (auto &table_filter_column : column_indexes) { - idx_t column_index; - for (column_index = 0; column_index < get.column_ids.size(); column_index++) { - if (get.column_ids[column_index] == table_filter_column) { - break; - } - } - D_ASSERT(column_index < get.column_ids.size()); - D_ASSERT(get.column_ids[column_index] == table_filter_column); - - // find the stats - ColumnBinding stats_binding(get.table_index, column_index); - auto entry = statistics_map.find(stats_binding); - if (entry == statistics_map.end()) { - // no stats for this entry - continue; - } - auto &stats = *entry->second; - - // fetch the table filter - D_ASSERT(get.table_filters.filters.count(table_filter_column) > 0); - auto &filter = 
get.table_filters.filters[table_filter_column]; - auto propagate_result = PropagateTableFilter(stats, *filter); - switch (propagate_result) { - case FilterPropagateResult::FILTER_ALWAYS_TRUE: - // filter is always true; it is useless to execute it - // erase this condition - get.table_filters.filters.erase(table_filter_column); - break; - case FilterPropagateResult::FILTER_FALSE_OR_NULL: - case FilterPropagateResult::FILTER_ALWAYS_FALSE: - // filter is always false; this entire filter should be replaced by an empty result block - ReplaceWithEmptyResult(*node_ptr); - return make_uniq(0, 0); - default: - // general case: filter can be true or false, update this columns' statistics - UpdateFilterStatistics(stats, *filter); - break; - } - } - return std::move(node_stats); -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -void StatisticsPropagator::PropagateStatistics(LogicalComparisonJoin &join, unique_ptr *node_ptr) { - for (idx_t i = 0; i < join.conditions.size(); i++) { - auto &condition = join.conditions[i]; - const auto stats_left = PropagateExpression(condition.left); - const auto stats_right = PropagateExpression(condition.right); - if (stats_left && stats_right) { - if ((condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM || - condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) && - stats_left->CanHaveNull() && stats_right->CanHaveNull()) { - // null values are equal in this join, and both sides can have null values - // nothing to do here - continue; - } - auto prune_result = PropagateComparison(*stats_left, *stats_right, condition.comparison); - // Add stats to logical_join for perfect hash join - join.join_stats.push_back(stats_left->ToUnique()); - join.join_stats.push_back(stats_right->ToUnique()); - switch (prune_result) { - case FilterPropagateResult::FILTER_FALSE_OR_NULL: - case FilterPropagateResult::FILTER_ALWAYS_FALSE: - // filter is always false or null, none of the join conditions matter - switch 
(join.join_type) { - case JoinType::SEMI: - case JoinType::INNER: - // semi or inner join on false; entire node can be pruned - ReplaceWithEmptyResult(*node_ptr); - return; - case JoinType::ANTI: { - // when the right child has data, return the left child - // when the right child has no data, return an empty set - auto limit = make_uniq(1, 0, nullptr, nullptr); - limit->AddChild(std::move(join.children[1])); - auto cross_product = LogicalCrossProduct::Create(std::move(join.children[0]), std::move(limit)); - *node_ptr = std::move(cross_product); - return; - } - case JoinType::LEFT: - // anti/left outer join: replace right side with empty node - ReplaceWithEmptyResult(join.children[1]); - return; - case JoinType::RIGHT: - // right outer join: replace left side with empty node - ReplaceWithEmptyResult(join.children[0]); - return; - default: - // other join types: can't do much meaningful with this information - // full outer join requires both sides anyway; we can skip the execution of the actual join, but eh - // mark/single join requires knowing if the rhs has null values or not - break; - } - break; - case FilterPropagateResult::FILTER_ALWAYS_TRUE: - // filter is always true - if (join.conditions.size() > 1) { - // there are multiple conditions: erase this condition - join.conditions.erase(join.conditions.begin() + i); - // remove the corresponding statistics - join.join_stats.clear(); - i--; - continue; - } else { - // this is the only condition and it is always true: all conditions are true - switch (join.join_type) { - case JoinType::SEMI: { - // when the right child has data, return the left child - // when the right child has no data, return an empty set - auto limit = make_uniq(1, 0, nullptr, nullptr); - limit->AddChild(std::move(join.children[1])); - auto cross_product = LogicalCrossProduct::Create(std::move(join.children[0]), std::move(limit)); - *node_ptr = std::move(cross_product); - return; - } - case JoinType::INNER: { - // inner, replace with cross 
product - auto cross_product = - LogicalCrossProduct::Create(std::move(join.children[0]), std::move(join.children[1])); - *node_ptr = std::move(cross_product); - return; - } - case JoinType::ANTI: - // anti join on true: empty result - ReplaceWithEmptyResult(*node_ptr); - return; - default: - // we don't handle mark/single join here yet - break; - } - } - break; - default: - break; - } - } - // after we have propagated, we can update the statistics on both sides - // note that it is fine to do this now, even if the same column is used again later - // e.g. if we have i=j AND i=k, and the stats for j and k are disjoint, we know there are no results - // so if we have e.g. i: [0, 100], j: [0, 25], k: [75, 100] - // we can set i: [0, 25] after the first comparison, and statically determine that the second comparison is fals - - // note that we can't update statistics the same for all join types - // mark and single joins don't filter any tuples -> so there is no propagation possible - // anti joins have inverse statistics propagation - // (i.e. if we have an anti join on i: [0, 100] and j: [0, 25], the resulting stats are i:[25,100]) - // for now we don't handle anti joins - if (condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM || - condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { - // skip update when null values are equal (for now?) 
- continue; - } - switch (join.join_type) { - case JoinType::INNER: - case JoinType::SEMI: { - UpdateFilterStatistics(*condition.left, *condition.right, condition.comparison); - auto updated_stats_left = PropagateExpression(condition.left); - auto updated_stats_right = PropagateExpression(condition.right); - - // Try to push lhs stats down rhs and vice versa - if (!context.config.force_index_join && stats_left && stats_right && updated_stats_left && - updated_stats_right && condition.left->type == ExpressionType::BOUND_COLUMN_REF && - condition.right->type == ExpressionType::BOUND_COLUMN_REF) { - CreateFilterFromJoinStats(join.children[0], condition.left, *stats_left, *updated_stats_left); - CreateFilterFromJoinStats(join.children[1], condition.right, *stats_right, *updated_stats_right); - } - - // Update join_stats when is already part of the join - if (join.join_stats.size() == 2) { - join.join_stats[0] = std::move(updated_stats_left); - join.join_stats[1] = std::move(updated_stats_right); - } - break; - } - default: - break; - } - } -} - -void StatisticsPropagator::PropagateStatistics(LogicalAnyJoin &join, unique_ptr *node_ptr) { - // propagate the expression into the join condition - PropagateExpression(join.condition); -} - -void StatisticsPropagator::MultiplyCardinalities(unique_ptr &stats, NodeStatistics &new_stats) { - if (!stats->has_estimated_cardinality || !new_stats.has_estimated_cardinality || !stats->has_max_cardinality || - !new_stats.has_max_cardinality) { - stats = nullptr; - return; - } - stats->estimated_cardinality = MaxValue(stats->estimated_cardinality, new_stats.estimated_cardinality); - auto new_max = Hugeint::Multiply(stats->max_cardinality, new_stats.max_cardinality); - if (new_max < NumericLimits::Maximum()) { - int64_t result; - if (!Hugeint::TryCast(new_max, result)) { - throw InternalException("Overflow in cast in statistics propagation"); - } - D_ASSERT(result >= 0); - stats->max_cardinality = idx_t(result); - } else { - stats = 
nullptr; - } -} - -unique_ptr StatisticsPropagator::PropagateStatistics(LogicalJoin &join, - unique_ptr *node_ptr) { - // first propagate through the children of the join - node_stats = PropagateStatistics(join.children[0]); - for (idx_t child_idx = 1; child_idx < join.children.size(); child_idx++) { - auto child_stats = PropagateStatistics(join.children[child_idx]); - if (!child_stats) { - node_stats = nullptr; - } else if (node_stats) { - MultiplyCardinalities(node_stats, *child_stats); - } - } - - auto join_type = join.join_type; - // depending on the join type, we might need to alter the statistics - // LEFT, FULL, RIGHT OUTER and SINGLE joins can introduce null values - // this requires us to alter the statistics after this point in the query plan - bool adds_null_on_left = IsRightOuterJoin(join_type); - bool adds_null_on_right = IsLeftOuterJoin(join_type) || join_type == JoinType::SINGLE; - - vector left_bindings, right_bindings; - if (adds_null_on_left) { - left_bindings = join.children[0]->GetColumnBindings(); - } - if (adds_null_on_right) { - right_bindings = join.children[1]->GetColumnBindings(); - } - - // then propagate into the join conditions - switch (join.type) { - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - PropagateStatistics(join.Cast(), node_ptr); - break; - case LogicalOperatorType::LOGICAL_ANY_JOIN: - PropagateStatistics(join.Cast(), node_ptr); - break; - default: - break; - } - - if (adds_null_on_right) { - // left or full outer join: set IsNull() to true for all rhs statistics - for (auto &binding : right_bindings) { - auto stats = statistics_map.find(binding); - if (stats != statistics_map.end()) { - stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES); - } - } - } - if (adds_null_on_left) { - // right or full outer join: set IsNull() to true for all lhs statistics - for (auto &binding : left_bindings) { - auto stats = 
statistics_map.find(binding); - if (stats != statistics_map.end()) { - stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES); - } - } - } - return std::move(node_stats); -} - -static void MaxCardinalities(unique_ptr &stats, NodeStatistics &new_stats) { - if (!stats->has_estimated_cardinality || !new_stats.has_estimated_cardinality || !stats->has_max_cardinality || - !new_stats.has_max_cardinality) { - stats = nullptr; - return; - } - stats->estimated_cardinality = MaxValue(stats->estimated_cardinality, new_stats.estimated_cardinality); - stats->max_cardinality = MaxValue(stats->max_cardinality, new_stats.max_cardinality); -} - -unique_ptr StatisticsPropagator::PropagateStatistics(LogicalPositionalJoin &join, - unique_ptr *node_ptr) { - D_ASSERT(join.type == LogicalOperatorType::LOGICAL_POSITIONAL_JOIN); - - // first propagate through the children of the join - node_stats = PropagateStatistics(join.children[0]); - for (idx_t child_idx = 1; child_idx < join.children.size(); child_idx++) { - auto child_stats = PropagateStatistics(join.children[child_idx]); - if (!child_stats) { - node_stats = nullptr; - } else if (node_stats) { - if (!node_stats->has_estimated_cardinality || !child_stats->has_estimated_cardinality || - !node_stats->has_max_cardinality || !child_stats->has_max_cardinality) { - node_stats = nullptr; - } else { - MaxCardinalities(node_stats, *child_stats); - } - } - } - - // No conditions. 
- - // Positional Joins are always FULL OUTER - - // set IsNull() to true for all lhs statistics - auto left_bindings = join.children[0]->GetColumnBindings(); - for (auto &binding : left_bindings) { - auto stats = statistics_map.find(binding); - if (stats != statistics_map.end()) { - stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES); - } - } - - // set IsNull() to true for all rhs statistics - auto right_bindings = join.children[1]->GetColumnBindings(); - for (auto &binding : right_bindings) { - auto stats = statistics_map.find(binding); - if (stats != statistics_map.end()) { - stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES); - } - } - - return std::move(node_stats); -} - -void StatisticsPropagator::CreateFilterFromJoinStats(unique_ptr &child, unique_ptr &expr, - const BaseStatistics &stats_before, - const BaseStatistics &stats_after) { - // Only do this for integral colref's that have stats - if (expr->type != ExpressionType::BOUND_COLUMN_REF || !expr->return_type.IsIntegral() || - !NumericStats::HasMinMax(stats_before) || !NumericStats::HasMinMax(stats_after)) { - return; - } - - // Retrieve min/max - auto min_before = NumericStats::Min(stats_before); - auto max_before = NumericStats::Max(stats_before); - auto min_after = NumericStats::Min(stats_after); - auto max_after = NumericStats::Max(stats_after); - - vector> filter_exprs; - if (min_after > min_before) { - filter_exprs.emplace_back( - make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, expr->Copy(), - make_uniq(std::move(min_after)))); - } - if (max_after < max_before) { - filter_exprs.emplace_back( - make_uniq(ExpressionType::COMPARE_LESSTHANOREQUALTO, expr->Copy(), - make_uniq(std::move(max_after)))); - } - - if (filter_exprs.empty()) { - return; - } - - auto filter = make_uniq(); - filter->children.emplace_back(std::move(child)); - child = std::move(filter); - - for (auto &filter_expr : filter_exprs) { - child->expressions.emplace_back(std::move(filter_expr)); - } - - FilterPushdown 
filter_pushdown(optimizer); - child = filter_pushdown.Rewrite(std::move(child)); - PropagateExpression(expr); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateStatistics(LogicalLimit &limit, - unique_ptr *node_ptr) { - // propagate statistics in the child node - PropagateStatistics(limit.children[0]); - // return the node stats, with as expected cardinality the amount specified in the limit - return make_uniq(limit.limit_val, limit.limit_val); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateStatistics(LogicalOrder &order, - unique_ptr *node_ptr) { - // first propagate to the child - node_stats = PropagateStatistics(order.children[0]); - - // then propagate to each of the order expressions - for (auto &bound_order : order.orders) { - bound_order.stats = PropagateExpression(bound_order.expression); - } - return std::move(node_stats); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateStatistics(LogicalProjection &proj, - unique_ptr *node_ptr) { - // first propagate to the child - node_stats = PropagateStatistics(proj.children[0]); - if (proj.children[0]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT) { - ReplaceWithEmptyResult(*node_ptr); - return std::move(node_stats); - } - // then propagate to each of the expressions - for (idx_t i = 0; i < proj.expressions.size(); i++) { - auto stats = PropagateExpression(proj.expressions[i]); - if (stats) { - ColumnBinding binding(proj.table_index, i); - statistics_map.insert(make_pair(binding, std::move(stats))); - } - } - return std::move(node_stats); -} - -} // namespace duckdb - - - -namespace duckdb { - -void StatisticsPropagator::AddCardinalities(unique_ptr &stats, NodeStatistics &new_stats) { - if (!stats->has_estimated_cardinality || !new_stats.has_estimated_cardinality || !stats->has_max_cardinality || - !new_stats.has_max_cardinality) { - stats = nullptr; - return; - 
} - stats->estimated_cardinality += new_stats.estimated_cardinality; - auto new_max = Hugeint::Add(stats->max_cardinality, new_stats.max_cardinality); - if (new_max < NumericLimits::Maximum()) { - int64_t result; - if (!Hugeint::TryCast(new_max, result)) { - throw InternalException("Overflow in cast in statistics propagation"); - } - D_ASSERT(result >= 0); - stats->max_cardinality = idx_t(result); - } else { - stats = nullptr; - } -} - -unique_ptr StatisticsPropagator::PropagateStatistics(LogicalSetOperation &setop, - unique_ptr *node_ptr) { - // first propagate statistics in the child nodes - auto left_stats = PropagateStatistics(setop.children[0]); - auto right_stats = PropagateStatistics(setop.children[1]); - - // now fetch the column bindings on both sides - auto left_bindings = setop.children[0]->GetColumnBindings(); - auto right_bindings = setop.children[1]->GetColumnBindings(); - - D_ASSERT(left_bindings.size() == right_bindings.size()); - D_ASSERT(left_bindings.size() == setop.column_count); - for (idx_t i = 0; i < setop.column_count; i++) { - // for each column binding, we fetch the statistics from both the lhs and the rhs - auto left_entry = statistics_map.find(left_bindings[i]); - auto right_entry = statistics_map.find(right_bindings[i]); - if (left_entry == statistics_map.end() || right_entry == statistics_map.end()) { - // no statistics on one of the sides: can't propagate stats - continue; - } - unique_ptr new_stats; - switch (setop.type) { - case LogicalOperatorType::LOGICAL_UNION: - // union: merge the stats of the LHS and RHS together - new_stats = left_entry->second->ToUnique(); - new_stats->Merge(*right_entry->second); - break; - case LogicalOperatorType::LOGICAL_EXCEPT: - // except: use the stats of the LHS - new_stats = left_entry->second->ToUnique(); - break; - case LogicalOperatorType::LOGICAL_INTERSECT: - // intersect: intersect the two stats - // FIXME: for now we just use the stats of the LHS, as this is correct - // however, the stats can 
be further refined to the minimal subset of the LHS and RHS - new_stats = left_entry->second->ToUnique(); - break; - default: - throw InternalException("Unsupported setop type"); - } - ColumnBinding binding(setop.table_index, i); - statistics_map[binding] = std::move(new_stats); - } - if (!left_stats || !right_stats) { - return nullptr; - } - if (setop.type == LogicalOperatorType::LOGICAL_UNION) { - AddCardinalities(left_stats, *right_stats); - } - return left_stats; -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr StatisticsPropagator::PropagateStatistics(LogicalWindow &window, - unique_ptr *node_ptr) { - // first propagate to the child - node_stats = PropagateStatistics(window.children[0]); - - // then propagate to each of the order expressions - for (auto &window_expr : window.expressions) { - auto over_expr = reinterpret_cast(window_expr.get()); - for (auto &expr : over_expr->partitions) { - over_expr->partitions_stats.push_back(PropagateExpression(expr)); - } - for (auto &bound_order : over_expr->orders) { - bound_order.stats = PropagateExpression(bound_order.expression); - } - } - return std::move(node_stats); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -StatisticsPropagator::StatisticsPropagator(Optimizer &optimizer_p) - : optimizer(optimizer_p), context(optimizer.context) { -} - -void StatisticsPropagator::ReplaceWithEmptyResult(unique_ptr &node) { - node = make_uniq(std::move(node)); -} - -unique_ptr StatisticsPropagator::PropagateChildren(LogicalOperator &node, - unique_ptr *node_ptr) { - for (idx_t child_idx = 0; child_idx < node.children.size(); child_idx++) { - PropagateStatistics(node.children[child_idx]); - } - return nullptr; -} - -unique_ptr StatisticsPropagator::PropagateStatistics(LogicalOperator &node, - unique_ptr *node_ptr) { - switch (node.type) { - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - return PropagateStatistics(node.Cast(), node_ptr); - case 
LogicalOperatorType::LOGICAL_CROSS_PRODUCT: - return PropagateStatistics(node.Cast(), node_ptr); - case LogicalOperatorType::LOGICAL_FILTER: - return PropagateStatistics(node.Cast(), node_ptr); - case LogicalOperatorType::LOGICAL_GET: - return PropagateStatistics(node.Cast(), node_ptr); - case LogicalOperatorType::LOGICAL_PROJECTION: - return PropagateStatistics(node.Cast(), node_ptr); - case LogicalOperatorType::LOGICAL_ANY_JOIN: - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - case LogicalOperatorType::LOGICAL_JOIN: - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - return PropagateStatistics(node.Cast(), node_ptr); - case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: - return PropagateStatistics(node.Cast(), node_ptr); - case LogicalOperatorType::LOGICAL_UNION: - case LogicalOperatorType::LOGICAL_EXCEPT: - case LogicalOperatorType::LOGICAL_INTERSECT: - return PropagateStatistics(node.Cast(), node_ptr); - case LogicalOperatorType::LOGICAL_ORDER_BY: - return PropagateStatistics(node.Cast(), node_ptr); - case LogicalOperatorType::LOGICAL_WINDOW: - return PropagateStatistics(node.Cast(), node_ptr); - default: - return PropagateChildren(node, node_ptr); - } -} - -unique_ptr StatisticsPropagator::PropagateStatistics(unique_ptr &node_ptr) { - return PropagateStatistics(*node_ptr, &node_ptr); -} - -unique_ptr StatisticsPropagator::PropagateExpression(Expression &expr, - unique_ptr *expr_ptr) { - switch (expr.GetExpressionClass()) { - case ExpressionClass::BOUND_AGGREGATE: - return PropagateExpression(expr.Cast(), expr_ptr); - case ExpressionClass::BOUND_BETWEEN: - return PropagateExpression(expr.Cast(), expr_ptr); - case ExpressionClass::BOUND_CASE: - return PropagateExpression(expr.Cast(), expr_ptr); - case ExpressionClass::BOUND_CONJUNCTION: - return PropagateExpression(expr.Cast(), expr_ptr); - case ExpressionClass::BOUND_FUNCTION: - return PropagateExpression(expr.Cast(), expr_ptr); - case 
ExpressionClass::BOUND_CAST: - return PropagateExpression(expr.Cast(), expr_ptr); - case ExpressionClass::BOUND_COMPARISON: - return PropagateExpression(expr.Cast(), expr_ptr); - case ExpressionClass::BOUND_CONSTANT: - return PropagateExpression(expr.Cast(), expr_ptr); - case ExpressionClass::BOUND_COLUMN_REF: - return PropagateExpression(expr.Cast(), expr_ptr); - case ExpressionClass::BOUND_OPERATOR: - return PropagateExpression(expr.Cast(), expr_ptr); - default: - break; - } - ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &child) { PropagateExpression(child); }); - return nullptr; -} - -unique_ptr StatisticsPropagator::PropagateExpression(unique_ptr &expr) { - auto stats = PropagateExpression(*expr, &expr); - if (ClientConfig::GetConfig(context).query_verification_enabled && stats) { - expr->verification_stats = stats->ToUnique(); - } - return stats; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -bool TopN::CanOptimize(LogicalOperator &op) { - if (op.type == LogicalOperatorType::LOGICAL_LIMIT && - op.children[0]->type == LogicalOperatorType::LOGICAL_ORDER_BY) { - auto &limit = op.Cast(); - - // When there are some expressions in the limit operator, - // we shouldn't use this optimizations. Because of the expressions - // will be lost when it convert to TopN operator. 
- if (limit.limit || limit.offset) { - return false; - } - - // This optimization doesn't apply when OFFSET is present without LIMIT - // Or if offset is not constant - if (limit.limit_val != NumericLimits::Maximum() || limit.offset) { - return true; - } - } - return false; -} - -unique_ptr TopN::Optimize(unique_ptr op) { - if (CanOptimize(*op)) { - auto &limit = op->Cast(); - auto &order_by = (op->children[0])->Cast(); - - auto topn = make_uniq(std::move(order_by.orders), limit.limit_val, limit.offset_val); - topn->AddChild(std::move(order_by.children[0])); - op = std::move(topn); - } else { - for (auto &child : op->children) { - child = Optimize(std::move(child)); - } - } - return op; -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -void UnnestRewriterPlanUpdater::VisitOperator(LogicalOperator &op) { - VisitOperatorChildren(op); - VisitOperatorExpressions(op); -} - -void UnnestRewriterPlanUpdater::VisitExpression(unique_ptr *expression) { - auto &expr = *expression; - - if (expr->expression_class == ExpressionClass::BOUND_COLUMN_REF) { - auto &bound_column_ref = expr->Cast(); - for (idx_t i = 0; i < replace_bindings.size(); i++) { - if (bound_column_ref.binding == replace_bindings[i].old_binding) { - bound_column_ref.binding = replace_bindings[i].new_binding; - break; - } - } - } - - VisitExpressionChildren(**expression); -} - -unique_ptr UnnestRewriter::Optimize(unique_ptr op) { - - UnnestRewriterPlanUpdater updater; - vector *> candidates; - FindCandidates(&op, candidates); - - // rewrite the plan and update the bindings - for (auto &candidate : candidates) { - - // rearrange the logical operators - if (RewriteCandidate(candidate)) { - updater.overwritten_tbl_idx = overwritten_tbl_idx; - // update the bindings of the BOUND_UNNEST expression - UpdateBoundUnnestBindings(updater, candidate); - // update the sequence of LOGICAL_PROJECTION(s) - UpdateRHSBindings(&op, candidate, updater); - // reset - delim_columns.clear(); - 
lhs_bindings.clear(); - } - } - - return op; -} - -void UnnestRewriter::FindCandidates(unique_ptr *op_ptr, - vector *> &candidates) { - auto op = op_ptr->get(); - // search children before adding, so that we add candidates bottom-up - for (auto &child : op->children) { - FindCandidates(&child, candidates); - } - - // search for operator that has a LOGICAL_DELIM_JOIN as its child - if (op->children.size() != 1) { - return; - } - if (op->children[0]->type != LogicalOperatorType::LOGICAL_DELIM_JOIN) { - return; - } - - // found a delim join - auto &delim_join = op->children[0]->Cast(); - // only support INNER delim joins - if (delim_join.join_type != JoinType::INNER) { - return; - } - // INNER delim join must have exactly one condition - if (delim_join.conditions.size() != 1) { - return; - } - - // LHS child is a window - if (delim_join.children[0]->type != LogicalOperatorType::LOGICAL_WINDOW) { - return; - } - - // RHS child must be projection(s) followed by an UNNEST - auto curr_op = &delim_join.children[1]; - while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { - if (curr_op->get()->children.size() != 1) { - break; - } - curr_op = &curr_op->get()->children[0]; - } - - if (curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST) { - candidates.push_back(op_ptr); - } -} - -bool UnnestRewriter::RewriteCandidate(unique_ptr *candidate) { - - auto &topmost_op = (LogicalOperator &)**candidate; - if (topmost_op.type != LogicalOperatorType::LOGICAL_PROJECTION && - topmost_op.type != LogicalOperatorType::LOGICAL_WINDOW && - topmost_op.type != LogicalOperatorType::LOGICAL_FILTER && - topmost_op.type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY && - topmost_op.type != LogicalOperatorType::LOGICAL_UNNEST) { - return false; - } - - // get the LOGICAL_DELIM_JOIN, which is a child of the candidate - D_ASSERT(topmost_op.children.size() == 1); - auto &delim_join = *(topmost_op.children[0]); - D_ASSERT(delim_join.type == 
LogicalOperatorType::LOGICAL_DELIM_JOIN); - GetDelimColumns(delim_join); - - // LHS of the LOGICAL_DELIM_JOIN is a LOGICAL_WINDOW that contains a LOGICAL_PROJECTION - // this lhs_proj later becomes the child of the UNNEST - auto &window = *delim_join.children[0]; - auto &lhs_op = window.children[0]; - GetLHSExpressions(*lhs_op); - - // find the LOGICAL_UNNEST - // and get the path down to the LOGICAL_UNNEST - vector *> path_to_unnest; - auto curr_op = &(delim_join.children[1]); - while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { - path_to_unnest.push_back(curr_op); - curr_op = &curr_op->get()->children[0]; - } - - // store the table index of the child of the LOGICAL_UNNEST - // then update the plan by making the lhs_proj the child of the LOGICAL_UNNEST - D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); - auto &unnest = curr_op->get()->Cast(); - D_ASSERT(unnest.children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET); - overwritten_tbl_idx = unnest.children[0]->Cast().table_index; - - D_ASSERT(!unnest.children.empty()); - auto &delim_get = unnest.children[0]->Cast(); - D_ASSERT(delim_get.chunk_types.size() > 1); - distinct_unnest_count = delim_get.chunk_types.size(); - unnest.children[0] = std::move(lhs_op); - - // replace the LOGICAL_DELIM_JOIN with its RHS child operator - topmost_op.children[0] = std::move(*path_to_unnest.front()); - return true; -} - -void UnnestRewriter::UpdateRHSBindings(unique_ptr *plan_ptr, unique_ptr *candidate, - UnnestRewriterPlanUpdater &updater) { - - auto &topmost_op = (LogicalOperator &)**candidate; - idx_t shift = lhs_bindings.size(); - - vector *> path_to_unnest; - auto curr_op = &(topmost_op.children[0]); - while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { - - path_to_unnest.push_back(curr_op); - D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); - auto &proj = curr_op->get()->Cast(); - - // pop the unnest columns and the delim index - 
D_ASSERT(proj.expressions.size() > distinct_unnest_count); - for (idx_t i = 0; i < distinct_unnest_count; i++) { - proj.expressions.pop_back(); - } - - // store all shifted current bindings - idx_t tbl_idx = proj.table_index; - for (idx_t i = 0; i < proj.expressions.size(); i++) { - ReplaceBinding replace_binding(ColumnBinding(tbl_idx, i), ColumnBinding(tbl_idx, i + shift)); - updater.replace_bindings.push_back(replace_binding); - } - - curr_op = &curr_op->get()->children[0]; - } - - // update all bindings by shifting them - updater.VisitOperator(*plan_ptr->get()); - updater.replace_bindings.clear(); - - // update all bindings coming from the LHS to RHS bindings - D_ASSERT(topmost_op.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION); - auto &top_proj = topmost_op.children[0]->Cast(); - for (idx_t i = 0; i < lhs_bindings.size(); i++) { - ReplaceBinding replace_binding(lhs_bindings[i].binding, ColumnBinding(top_proj.table_index, i)); - updater.replace_bindings.push_back(replace_binding); - } - - // temporarily remove the BOUND_UNNESTs and the child of the LOGICAL_UNNEST from the plan - D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); - auto &unnest = curr_op->get()->Cast(); - vector> temp_bound_unnests; - for (auto &temp_bound_unnest : unnest.expressions) { - temp_bound_unnests.push_back(std::move(temp_bound_unnest)); - } - D_ASSERT(unnest.children.size() == 1); - auto temp_unnest_child = std::move(unnest.children[0]); - unnest.expressions.clear(); - unnest.children.clear(); - // update the bindings of the plan - updater.VisitOperator(*plan_ptr->get()); - updater.replace_bindings.clear(); - // add the children again - for (auto &temp_bound_unnest : temp_bound_unnests) { - unnest.expressions.push_back(std::move(temp_bound_unnest)); - } - unnest.children.push_back(std::move(temp_unnest_child)); - - // add the LHS expressions to each LOGICAL_PROJECTION - for (idx_t i = path_to_unnest.size(); i > 0; i--) { - - D_ASSERT(path_to_unnest[i - 
1]->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); - auto &proj = path_to_unnest[i - 1]->get()->Cast(); - - // temporarily store the existing expressions - vector> existing_expressions; - for (idx_t expr_idx = 0; expr_idx < proj.expressions.size(); expr_idx++) { - existing_expressions.push_back(std::move(proj.expressions[expr_idx])); - } - - proj.expressions.clear(); - - // add the new expressions - for (idx_t expr_idx = 0; expr_idx < lhs_bindings.size(); expr_idx++) { - auto new_expr = make_uniq( - lhs_bindings[expr_idx].alias, lhs_bindings[expr_idx].type, lhs_bindings[expr_idx].binding); - proj.expressions.push_back(std::move(new_expr)); - - // update the table index - lhs_bindings[expr_idx].binding.table_index = proj.table_index; - lhs_bindings[expr_idx].binding.column_index = expr_idx; - } - - // add the existing expressions again - for (idx_t expr_idx = 0; expr_idx < existing_expressions.size(); expr_idx++) { - proj.expressions.push_back(std::move(existing_expressions[expr_idx])); - } - } -} - -void UnnestRewriter::UpdateBoundUnnestBindings(UnnestRewriterPlanUpdater &updater, - unique_ptr *candidate) { - - auto &topmost_op = (LogicalOperator &)**candidate; - - // traverse LOGICAL_PROJECTION(s) - auto curr_op = &(topmost_op.children[0]); - while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { - curr_op = &curr_op->get()->children[0]; - } - - // found the LOGICAL_UNNEST - D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); - auto &unnest = curr_op->get()->Cast(); - - D_ASSERT(unnest.children.size() == 1); - auto unnest_cols = unnest.children[0]->GetColumnBindings(); - - for (idx_t i = 0; i < delim_columns.size(); i++) { - auto delim_binding = delim_columns[i]; - - auto unnest_it = unnest_cols.begin(); - while (unnest_it != unnest_cols.end()) { - auto unnest_binding = *unnest_it; - - if (delim_binding.table_index == unnest_binding.table_index) { - unnest_binding.table_index = overwritten_tbl_idx; - 
unnest_binding.column_index++; - updater.replace_bindings.emplace_back(unnest_binding, delim_binding); - unnest_cols.erase(unnest_it); - break; - } - unnest_it++; - } - } - - // update bindings - for (auto &unnest_expr : unnest.expressions) { - updater.VisitExpression(&unnest_expr); - } - updater.replace_bindings.clear(); -} - -void UnnestRewriter::GetDelimColumns(LogicalOperator &op) { - - D_ASSERT(op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); - auto &delim_join = op.Cast(); - for (idx_t i = 0; i < delim_join.duplicate_eliminated_columns.size(); i++) { - auto &expr = *delim_join.duplicate_eliminated_columns[i]; - D_ASSERT(expr.type == ExpressionType::BOUND_COLUMN_REF); - auto &bound_colref_expr = expr.Cast(); - delim_columns.push_back(bound_colref_expr.binding); - } -} - -void UnnestRewriter::GetLHSExpressions(LogicalOperator &op) { - - op.ResolveOperatorTypes(); - auto col_bindings = op.GetColumnBindings(); - D_ASSERT(op.types.size() == col_bindings.size()); - - bool set_alias = false; - // we can easily extract the alias for LOGICAL_PROJECTION(s) - if (op.type == LogicalOperatorType::LOGICAL_PROJECTION) { - auto &proj = op.Cast(); - if (proj.expressions.size() == op.types.size()) { - set_alias = true; - } - } - - for (idx_t i = 0; i < op.types.size(); i++) { - lhs_bindings.emplace_back(col_bindings[i], op.types[i]); - if (set_alias) { - auto &proj = op.Cast(); - lhs_bindings.back().alias = proj.expressions[i]->alias; - } - } -} - -} // namespace duckdb - - -namespace duckdb { - -BasePipelineEvent::BasePipelineEvent(shared_ptr pipeline_p) - : Event(pipeline_p->executor), pipeline(std::move(pipeline_p)) { -} - -BasePipelineEvent::BasePipelineEvent(Pipeline &pipeline_p) - : Event(pipeline_p.executor), pipeline(pipeline_p.shared_from_this()) { -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -Event::Event(Executor &executor_p) - : executor(executor_p), finished_tasks(0), total_tasks(0), finished_dependencies(0), total_dependencies(0), - 
finished(false) { -} - -void Event::CompleteDependency() { - idx_t current_finished = ++finished_dependencies; - D_ASSERT(current_finished <= total_dependencies); - if (current_finished == total_dependencies) { - // all dependencies have been completed: schedule the event - D_ASSERT(total_tasks == 0); - Schedule(); - if (total_tasks == 0) { - Finish(); - } - } -} - -void Event::Finish() { - D_ASSERT(!finished); - FinishEvent(); - finished = true; - // finished processing the pipeline, now we can schedule pipelines that depend on this pipeline - for (auto &parent_entry : parents) { - auto parent = parent_entry.lock(); - if (!parent) { // LCOV_EXCL_START - continue; - } // LCOV_EXCL_STOP - // mark a dependency as completed for each of the parents - parent->CompleteDependency(); - } - FinalizeFinish(); -} - -void Event::AddDependency(Event &event) { - total_dependencies++; - event.parents.push_back(weak_ptr(shared_from_this())); -#ifdef DEBUG - event.parents_raw.push_back(this); -#endif -} - -const vector &Event::GetParentsVerification() const { - D_ASSERT(parents.size() == parents_raw.size()); - return parents_raw; -} - -void Event::FinishTask() { - D_ASSERT(finished_tasks.load() < total_tasks.load()); - idx_t current_tasks = total_tasks; - idx_t current_finished = ++finished_tasks; - D_ASSERT(current_finished <= current_tasks); - if (current_finished == current_tasks) { - Finish(); - } -} - -void Event::InsertEvent(shared_ptr replacement_event) { - replacement_event->parents = std::move(parents); -#ifdef DEBUG - replacement_event->parents_raw = std::move(parents_raw); -#endif - replacement_event->AddDependency(*this); - executor.AddEvent(std::move(replacement_event)); -} - -void Event::SetTasks(vector> tasks) { - auto &ts = TaskScheduler::GetScheduler(executor.context); - D_ASSERT(total_tasks == 0); - D_ASSERT(!tasks.empty()); - this->total_tasks = tasks.size(); - for (auto &task : tasks) { - ts.ScheduleTask(executor.GetToken(), std::move(task)); - } -} - -} // 
namespace duckdb - - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -Executor::Executor(ClientContext &context) : context(context) { -} - -Executor::~Executor() { -} - -Executor &Executor::Get(ClientContext &context) { - return context.GetExecutor(); -} - -void Executor::AddEvent(shared_ptr event) { - lock_guard elock(executor_lock); - if (cancelled) { - return; - } - events.push_back(std::move(event)); -} - -struct PipelineEventStack { - PipelineEventStack(Event &pipeline_initialize_event, Event &pipeline_event, Event &pipeline_finish_event, - Event &pipeline_complete_event) - : pipeline_initialize_event(pipeline_initialize_event), pipeline_event(pipeline_event), - pipeline_finish_event(pipeline_finish_event), pipeline_complete_event(pipeline_complete_event) { - } - - Event &pipeline_initialize_event; - Event &pipeline_event; - Event &pipeline_finish_event; - Event &pipeline_complete_event; -}; - -using event_map_t = reference_map_t; - -struct ScheduleEventData { - ScheduleEventData(const vector> &meta_pipelines, vector> &events, - bool initial_schedule) - : meta_pipelines(meta_pipelines), events(events), initial_schedule(initial_schedule) { - } - - const vector> &meta_pipelines; - vector> &events; - bool initial_schedule; - event_map_t event_map; -}; - -void Executor::SchedulePipeline(const shared_ptr &meta_pipeline, ScheduleEventData &event_data) { - D_ASSERT(meta_pipeline); - auto &events = event_data.events; - auto &event_map = event_data.event_map; - - // create events/stack for the base pipeline - auto base_pipeline = meta_pipeline->GetBasePipeline(); - auto base_initialize_event = make_shared(base_pipeline); - auto base_event = make_shared(base_pipeline); - auto base_finish_event = make_shared(base_pipeline); - auto base_complete_event = make_shared(base_pipeline->executor, event_data.initial_schedule); - PipelineEventStack base_stack(*base_initialize_event, *base_event, *base_finish_event, *base_complete_event); - 
events.push_back(std::move(base_initialize_event)); - events.push_back(std::move(base_event)); - events.push_back(std::move(base_finish_event)); - events.push_back(std::move(base_complete_event)); - - // dependencies: initialize -> event -> finish -> complete - base_stack.pipeline_event.AddDependency(base_stack.pipeline_initialize_event); - base_stack.pipeline_finish_event.AddDependency(base_stack.pipeline_event); - base_stack.pipeline_complete_event.AddDependency(base_stack.pipeline_finish_event); - - // create an event and stack for all pipelines in the MetaPipeline - vector> pipelines; - meta_pipeline->GetPipelines(pipelines, false); - for (idx_t i = 1; i < pipelines.size(); i++) { // loop starts at 1 because 0 is the base pipeline - auto &pipeline = pipelines[i]; - D_ASSERT(pipeline); - - // create events/stack for this pipeline - auto pipeline_event = make_shared(pipeline); - - auto finish_group = meta_pipeline->GetFinishGroup(pipeline.get()); - if (finish_group) { - // this pipeline is part of a finish group - const auto group_entry = event_map.find(*finish_group.get()); - D_ASSERT(group_entry != event_map.end()); - auto &group_stack = group_entry->second; - PipelineEventStack pipeline_stack(base_stack.pipeline_initialize_event, *pipeline_event, - group_stack.pipeline_finish_event, base_stack.pipeline_complete_event); - - // dependencies: base_finish -> pipeline_event -> group_finish - pipeline_stack.pipeline_event.AddDependency(base_stack.pipeline_finish_event); - group_stack.pipeline_finish_event.AddDependency(pipeline_stack.pipeline_event); - - // add pipeline stack to event map - event_map.insert(make_pair(reference(*pipeline), pipeline_stack)); - } else if (meta_pipeline->HasFinishEvent(pipeline.get())) { - // this pipeline has its own finish event (despite going into the same sink - Finalize twice!) 
- auto pipeline_finish_event = make_shared(pipeline); - PipelineEventStack pipeline_stack(base_stack.pipeline_initialize_event, *pipeline_event, - *pipeline_finish_event, base_stack.pipeline_complete_event); - events.push_back(std::move(pipeline_finish_event)); - - // dependencies: base_finish -> pipeline_event -> pipeline_finish -> base_complete - pipeline_stack.pipeline_event.AddDependency(base_stack.pipeline_finish_event); - pipeline_stack.pipeline_finish_event.AddDependency(pipeline_stack.pipeline_event); - base_stack.pipeline_complete_event.AddDependency(pipeline_stack.pipeline_finish_event); - - // add pipeline stack to event map - event_map.insert(make_pair(reference(*pipeline), pipeline_stack)); - - } else { - // no additional finish event - PipelineEventStack pipeline_stack(base_stack.pipeline_initialize_event, *pipeline_event, - base_stack.pipeline_finish_event, base_stack.pipeline_complete_event); - - // dependencies: base_initialize -> pipeline_event -> base_finish - pipeline_stack.pipeline_event.AddDependency(base_stack.pipeline_initialize_event); - base_stack.pipeline_finish_event.AddDependency(pipeline_stack.pipeline_event); - - // add pipeline stack to event map - event_map.insert(make_pair(reference(*pipeline), pipeline_stack)); - } - events.push_back(std::move(pipeline_event)); - } - - // add base stack to the event data too - event_map.insert(make_pair(reference(*base_pipeline), base_stack)); - - // set up the dependencies within this MetaPipeline - for (auto &pipeline : pipelines) { - auto source = pipeline->GetSource(); - if (source->type == PhysicalOperatorType::TABLE_SCAN) { - // we have to reset the source here (in the main thread), because some of our clients (looking at you, R) - // do not like it when threads other than the main thread call into R, for e.g., arrow scans - pipeline->ResetSource(true); - } - - auto dependencies = meta_pipeline->GetDependencies(pipeline.get()); - if (!dependencies) { - continue; - } - auto root_entry = 
event_map.find(*pipeline); - D_ASSERT(root_entry != event_map.end()); - auto &pipeline_stack = root_entry->second; - for (auto &dependency : *dependencies) { - auto event_entry = event_map.find(*dependency); - D_ASSERT(event_entry != event_map.end()); - auto &dependency_stack = event_entry->second; - pipeline_stack.pipeline_event.AddDependency(dependency_stack.pipeline_event); - } - } -} - -void Executor::ScheduleEventsInternal(ScheduleEventData &event_data) { - auto &events = event_data.events; - D_ASSERT(events.empty()); - - // create all the required pipeline events - for (auto &pipeline : event_data.meta_pipelines) { - SchedulePipeline(pipeline, event_data); - } - - // set up the dependencies across MetaPipelines - auto &event_map = event_data.event_map; - for (auto &entry : event_map) { - auto &pipeline = entry.first.get(); - for (auto &dependency : pipeline.dependencies) { - auto dep = dependency.lock(); - D_ASSERT(dep); - auto event_map_entry = event_map.find(*dep); - D_ASSERT(event_map_entry != event_map.end()); - auto &dep_entry = event_map_entry->second; - entry.second.pipeline_event.AddDependency(dep_entry.pipeline_complete_event); - } - } - - // verify that we have no cyclic dependencies - VerifyScheduledEvents(event_data); - - // schedule the pipelines that do not have dependencies - for (auto &event : events) { - if (!event->HasDependencies()) { - event->Schedule(); - } - } -} - -void Executor::ScheduleEvents(const vector> &meta_pipelines) { - ScheduleEventData event_data(meta_pipelines, events, true); - ScheduleEventsInternal(event_data); -} - -void Executor::VerifyScheduledEvents(const ScheduleEventData &event_data) { -#ifdef DEBUG - const idx_t count = event_data.events.size(); - vector vertices; - vertices.reserve(count); - for (const auto &event : event_data.events) { - vertices.push_back(event.get()); - } - vector visited(count, false); - vector recursion_stack(count, false); - for (idx_t i = 0; i < count; i++) { - 
VerifyScheduledEventsInternal(i, vertices, visited, recursion_stack); - } -#endif -} - -void Executor::VerifyScheduledEventsInternal(const idx_t vertex, const vector &vertices, vector &visited, - vector &recursion_stack) { - D_ASSERT(!recursion_stack[vertex]); // this vertex is in the recursion stack: circular dependency! - if (visited[vertex]) { - return; // early out: we already visited this vertex - } - - auto &parents = vertices[vertex]->GetParentsVerification(); - if (parents.empty()) { - return; // early out: outgoing edges - } - - // create a vector the indices of the adjacent events - vector adjacent; - const idx_t count = vertices.size(); - for (auto parent : parents) { - idx_t i; - for (i = 0; i < count; i++) { - if (vertices[i] == parent) { - adjacent.push_back(i); - break; - } - } - D_ASSERT(i != count); // dependency must be in there somewhere - } - - // mark vertex as visited and add to recursion stack - visited[vertex] = true; - recursion_stack[vertex] = true; - - // recurse into adjacent vertices - for (const auto &i : adjacent) { - VerifyScheduledEventsInternal(i, vertices, visited, recursion_stack); - } - - // remove vertex from recursion stack - recursion_stack[vertex] = false; -} - -void Executor::AddRecursiveCTE(PhysicalOperator &rec_cte) { - recursive_ctes.push_back(rec_cte); -} - -void Executor::AddMaterializedCTE(PhysicalOperator &mat_cte) { - materialized_ctes.push_back(mat_cte); -} - -void Executor::ReschedulePipelines(const vector> &pipelines_p, - vector> &events_p) { - ScheduleEventData event_data(pipelines_p, events_p, false); - ScheduleEventsInternal(event_data); -} - -bool Executor::NextExecutor() { - if (root_pipeline_idx >= root_pipelines.size()) { - return false; - } - root_pipelines[root_pipeline_idx]->Reset(); - root_executor = make_uniq(context, *root_pipelines[root_pipeline_idx]); - root_pipeline_idx++; - return true; -} - -void Executor::VerifyPipeline(Pipeline &pipeline) { - D_ASSERT(!pipeline.ToString().empty()); - auto 
operators = pipeline.GetOperators(); - for (auto &other_pipeline : pipelines) { - auto other_operators = other_pipeline->GetOperators(); - for (idx_t op_idx = 0; op_idx < operators.size(); op_idx++) { - for (idx_t other_idx = 0; other_idx < other_operators.size(); other_idx++) { - auto &left = operators[op_idx].get(); - auto &right = other_operators[other_idx].get(); - if (left.Equals(right)) { - D_ASSERT(right.Equals(left)); - } else { - D_ASSERT(!right.Equals(left)); - } - } - } - } -} - -void Executor::VerifyPipelines() { -#ifdef DEBUG - for (auto &pipeline : pipelines) { - VerifyPipeline(*pipeline); - } -#endif -} - -void Executor::Initialize(unique_ptr physical_plan) { - Reset(); - owned_plan = std::move(physical_plan); - InitializeInternal(*owned_plan); -} - -void Executor::Initialize(PhysicalOperator &plan) { - Reset(); - InitializeInternal(plan); -} - -void Executor::InitializeInternal(PhysicalOperator &plan) { - - auto &scheduler = TaskScheduler::GetScheduler(context); - { - lock_guard elock(executor_lock); - physical_plan = &plan; - - this->profiler = ClientData::Get(context).profiler; - profiler->Initialize(plan); - this->producer = scheduler.CreateProducer(); - - // build and ready the pipelines - PipelineBuildState state; - auto root_pipeline = make_shared(*this, state, nullptr); - root_pipeline->Build(*physical_plan); - root_pipeline->Ready(); - - // ready recursive cte pipelines too - for (auto &rec_cte_ref : recursive_ctes) { - auto &rec_cte = rec_cte_ref.get().Cast(); - rec_cte.recursive_meta_pipeline->Ready(); - } - - // ready materialized cte pipelines too - for (auto &mat_cte_ref : materialized_ctes) { - auto &mat_cte = mat_cte_ref.get().Cast(); - mat_cte.recursive_meta_pipeline->Ready(); - } - - // set root pipelines, i.e., all pipelines that end in the final sink - root_pipeline->GetPipelines(root_pipelines, false); - root_pipeline_idx = 0; - - // collect all meta-pipelines from the root pipeline - vector> to_schedule; - 
root_pipeline->GetMetaPipelines(to_schedule, true, true); - - // number of 'PipelineCompleteEvent's is equal to the number of meta pipelines, so we have to set it here - total_pipelines = to_schedule.size(); - - // collect all pipelines from the root pipelines (recursively) for the progress bar and verify them - root_pipeline->GetPipelines(pipelines, true); - - // finally, verify and schedule - VerifyPipelines(); - ScheduleEvents(to_schedule); - } -} - -void Executor::CancelTasks() { - task.reset(); - // we do this by creating weak pointers to all pipelines - // then clearing our references to the pipelines - // and waiting until all pipelines have been destroyed - vector> weak_references; - { - lock_guard elock(executor_lock); - weak_references.reserve(pipelines.size()); - cancelled = true; - for (auto &pipeline : pipelines) { - weak_references.push_back(weak_ptr(pipeline)); - } - for (auto &rec_cte_ref : recursive_ctes) { - auto &rec_cte = rec_cte_ref.get().Cast(); - rec_cte.recursive_meta_pipeline.reset(); - } - for (auto &mat_cte_ref : materialized_ctes) { - auto &mat_cte = mat_cte_ref.get().Cast(); - mat_cte.recursive_meta_pipeline.reset(); - } - pipelines.clear(); - root_pipelines.clear(); - to_be_rescheduled_tasks.clear(); - events.clear(); - } - WorkOnTasks(); - for (auto &weak_ref : weak_references) { - while (true) { - auto weak = weak_ref.lock(); - if (!weak) { - break; - } - } - } -} - -void Executor::WorkOnTasks() { - auto &scheduler = TaskScheduler::GetScheduler(context); - - shared_ptr task; - while (scheduler.GetTaskFromProducer(*producer, task)) { - auto res = task->Execute(TaskExecutionMode::PROCESS_ALL); - if (res == TaskExecutionResult::TASK_BLOCKED) { - task->Deschedule(); - } - task.reset(); - } -} - -void Executor::RescheduleTask(shared_ptr &task) { - // This function will spin lock until the task provided is added to the to_be_rescheduled_tasks - while (true) { - lock_guard l(executor_lock); - if (cancelled) { - return; - } - auto entry = 
to_be_rescheduled_tasks.find(task.get()); - if (entry != to_be_rescheduled_tasks.end()) { - auto &scheduler = TaskScheduler::GetScheduler(context); - to_be_rescheduled_tasks.erase(task.get()); - scheduler.ScheduleTask(GetToken(), task); - break; - } - } -} - -void Executor::AddToBeRescheduled(shared_ptr &task) { - lock_guard l(executor_lock); - if (cancelled) { - return; - } - if (to_be_rescheduled_tasks.find(task.get()) != to_be_rescheduled_tasks.end()) { - return; - } - to_be_rescheduled_tasks[task.get()] = std::move(task); -} - -bool Executor::ExecutionIsFinished() { - return completed_pipelines >= total_pipelines || HasError(); -} - -PendingExecutionResult Executor::ExecuteTask() { - // Only executor should return NO_TASKS_AVAILABLE - D_ASSERT(execution_result != PendingExecutionResult::NO_TASKS_AVAILABLE); - if (execution_result != PendingExecutionResult::RESULT_NOT_READY) { - return execution_result; - } - // check if there are any incomplete pipelines - auto &scheduler = TaskScheduler::GetScheduler(context); - while (completed_pipelines < total_pipelines) { - // there are! 
if we don't already have a task, fetch one - if (!task) { - scheduler.GetTaskFromProducer(*producer, task); - } - if (!task && !HasError()) { - // there are no tasks to be scheduled and there are tasks blocked - return PendingExecutionResult::NO_TASKS_AVAILABLE; - } - if (task) { - // if we have a task, partially process it - auto result = task->Execute(TaskExecutionMode::PROCESS_PARTIAL); - if (result == TaskExecutionResult::TASK_BLOCKED) { - task->Deschedule(); - task.reset(); - } else if (result == TaskExecutionResult::TASK_FINISHED) { - // if the task is finished, clean it up - task.reset(); - } - } - if (!HasError()) { - // we (partially) processed a task and no exceptions were thrown - // give back control to the caller - return PendingExecutionResult::RESULT_NOT_READY; - } - execution_result = PendingExecutionResult::EXECUTION_ERROR; - - // an exception has occurred executing one of the pipelines - // we need to cancel all tasks associated with this executor - CancelTasks(); - ThrowException(); - } - D_ASSERT(!task); - - lock_guard elock(executor_lock); - pipelines.clear(); - NextExecutor(); - if (HasError()) { // LCOV_EXCL_START - // an exception has occurred executing one of the pipelines - execution_result = PendingExecutionResult::EXECUTION_ERROR; - ThrowException(); - } // LCOV_EXCL_STOP - execution_result = PendingExecutionResult::RESULT_READY; - return execution_result; -} - -void Executor::Reset() { - lock_guard elock(executor_lock); - physical_plan = nullptr; - cancelled = false; - owned_plan.reset(); - root_executor.reset(); - root_pipelines.clear(); - root_pipeline_idx = 0; - completed_pipelines = 0; - total_pipelines = 0; - exceptions.clear(); - pipelines.clear(); - events.clear(); - to_be_rescheduled_tasks.clear(); - execution_result = PendingExecutionResult::RESULT_NOT_READY; -} - -shared_ptr Executor::CreateChildPipeline(Pipeline ¤t, PhysicalOperator &op) { - D_ASSERT(!current.operators.empty()); - D_ASSERT(op.IsSource()); - // found another 
operator that is a source, schedule a child pipeline - // 'op' is the source, and the sink is the same - auto child_pipeline = make_shared(*this); - child_pipeline->sink = current.sink; - child_pipeline->source = &op; - - // the child pipeline has the same operators up until 'op' - for (auto current_op : current.operators) { - if (¤t_op.get() == &op) { - break; - } - child_pipeline->operators.push_back(current_op); - } - - return child_pipeline; -} - -vector Executor::GetTypes() { - D_ASSERT(physical_plan); - return physical_plan->GetTypes(); -} - -void Executor::PushError(PreservedError exception) { - lock_guard elock(error_lock); - // interrupt execution of any other pipelines that belong to this executor - context.interrupted = true; - // push the exception onto the stack - exceptions.push_back(std::move(exception)); -} - -bool Executor::HasError() { - lock_guard elock(error_lock); - return !exceptions.empty(); -} - -void Executor::ThrowException() { - lock_guard elock(error_lock); - D_ASSERT(!exceptions.empty()); - auto &entry = exceptions[0]; - entry.Throw(); -} - -void Executor::Flush(ThreadContext &tcontext) { - profiler->Flush(tcontext.profiler); -} - -bool Executor::GetPipelinesProgress(double ¤t_progress) { // LCOV_EXCL_START - lock_guard elock(executor_lock); - - vector progress; - vector cardinality; - idx_t total_cardinality = 0; - for (auto &pipeline : pipelines) { - double child_percentage; - idx_t child_cardinality; - - if (!pipeline->GetProgress(child_percentage, child_cardinality)) { - return false; - } - progress.push_back(child_percentage); - cardinality.push_back(child_cardinality); - total_cardinality += child_cardinality; - } - current_progress = 0; - if (total_cardinality == 0) { - return true; - } - for (size_t i = 0; i < progress.size(); i++) { - current_progress += progress[i] * double(cardinality[i]) / double(total_cardinality); - } - return true; -} // LCOV_EXCL_STOP - -bool Executor::HasResultCollector() { - return physical_plan->type 
== PhysicalOperatorType::RESULT_COLLECTOR; -} - -unique_ptr Executor::GetResult() { - D_ASSERT(HasResultCollector()); - auto &result_collector = physical_plan->Cast(); - D_ASSERT(result_collector.sink_state); - return result_collector.GetResult(*result_collector.sink_state); -} - -unique_ptr Executor::FetchChunk() { - D_ASSERT(physical_plan); - - auto chunk = make_uniq(); - root_executor->InitializeChunk(*chunk); - while (true) { - root_executor->ExecutePull(*chunk); - if (chunk->size() == 0) { - root_executor->PullFinalize(); - if (NextExecutor()) { - continue; - } - break; - } else { - break; - } - } - return chunk; -} - -} // namespace duckdb - - - - -namespace duckdb { - -ExecutorTask::ExecutorTask(Executor &executor_p) : executor(executor_p) { -} - -ExecutorTask::ExecutorTask(ClientContext &context) : ExecutorTask(Executor::Get(context)) { -} - -ExecutorTask::~ExecutorTask() { -} - -void ExecutorTask::Deschedule() { - auto this_ptr = shared_from_this(); - executor.AddToBeRescheduled(this_ptr); -} - -void ExecutorTask::Reschedule() { - auto this_ptr = shared_from_this(); - executor.RescheduleTask(this_ptr); -} - -TaskExecutionResult ExecutorTask::Execute(TaskExecutionMode mode) { - try { - return ExecuteTask(mode); - } catch (Exception &ex) { - executor.PushError(PreservedError(ex)); - } catch (std::exception &ex) { - executor.PushError(PreservedError(ex)); - } catch (...) 
{ // LCOV_EXCL_START - executor.PushError(PreservedError("Unknown exception in Finalize!")); - } // LCOV_EXCL_STOP - return TaskExecutionResult::TASK_ERROR; -} - -} // namespace duckdb - - - - - -#include - -namespace duckdb { - -InterruptState::InterruptState() : mode(InterruptMode::NO_INTERRUPTS) { -} -InterruptState::InterruptState(weak_ptr task) : mode(InterruptMode::TASK), current_task(std::move(task)) { -} -InterruptState::InterruptState(weak_ptr signal_state_p) - : mode(InterruptMode::BLOCKING), signal_state(std::move(signal_state_p)) { -} - -void InterruptState::Callback() const { - if (mode == InterruptMode::TASK) { - auto task = current_task.lock(); - - if (!task) { - return; - } - - task->Reschedule(); - } else if (mode == InterruptMode::BLOCKING) { - auto signal_state_l = signal_state.lock(); - - if (!signal_state_l) { - return; - } - - // Signal the caller, who is currently blocked - signal_state_l->Signal(); - } else { - throw InternalException("Callback made on InterruptState without valid interrupt mode specified"); - } -} - -void InterruptDoneSignalState::Signal() { - { - unique_lock lck {lock}; - done = true; - } - cv.notify_all(); -} - -void InterruptDoneSignalState::Await() { - std::unique_lock lck(lock); - cv.wait(lck, [&]() { return done; }); - - // Reset after signal received - done = false; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -MetaPipeline::MetaPipeline(Executor &executor_p, PipelineBuildState &state_p, PhysicalOperator *sink_p) - : executor(executor_p), state(state_p), sink(sink_p), recursive_cte(false), next_batch_index(0) { - CreatePipeline(); -} - -Executor &MetaPipeline::GetExecutor() const { - return executor; -} - -PipelineBuildState &MetaPipeline::GetState() const { - return state; -} - -optional_ptr MetaPipeline::GetSink() const { - return sink; -} - -shared_ptr &MetaPipeline::GetBasePipeline() { - return pipelines[0]; -} - -void MetaPipeline::GetPipelines(vector> &result, bool recursive) { - 
result.insert(result.end(), pipelines.begin(), pipelines.end()); - if (recursive) { - for (auto &child : children) { - child->GetPipelines(result, true); - } - } -} - -void MetaPipeline::GetMetaPipelines(vector> &result, bool recursive, bool skip) { - if (!skip) { - result.push_back(shared_from_this()); - } - if (recursive) { - for (auto &child : children) { - child->GetMetaPipelines(result, true, false); - } - } -} - -const vector *MetaPipeline::GetDependencies(Pipeline *dependant) const { - auto it = dependencies.find(dependant); - if (it == dependencies.end()) { - return nullptr; - } else { - return &it->second; - } -} - -bool MetaPipeline::HasRecursiveCTE() const { - return recursive_cte; -} - -void MetaPipeline::SetRecursiveCTE() { - recursive_cte = true; -} - -void MetaPipeline::AssignNextBatchIndex(Pipeline *pipeline) { - pipeline->base_batch_index = next_batch_index++ * PipelineBuildState::BATCH_INCREMENT; -} - -void MetaPipeline::Build(PhysicalOperator &op) { - D_ASSERT(pipelines.size() == 1); - D_ASSERT(children.empty()); - op.BuildPipelines(*pipelines.back(), *this); -} - -void MetaPipeline::Ready() { - for (auto &pipeline : pipelines) { - pipeline->Ready(); - } - for (auto &child : children) { - child->Ready(); - } -} - -MetaPipeline &MetaPipeline::CreateChildMetaPipeline(Pipeline ¤t, PhysicalOperator &op) { - children.push_back(make_shared(executor, state, &op)); - auto child_meta_pipeline = children.back().get(); - // child MetaPipeline must finish completely before this MetaPipeline can start - current.AddDependency(child_meta_pipeline->GetBasePipeline()); - // child meta pipeline is part of the recursive CTE too - child_meta_pipeline->recursive_cte = recursive_cte; - return *child_meta_pipeline; -} - -Pipeline *MetaPipeline::CreatePipeline() { - pipelines.emplace_back(make_shared(executor)); - state.SetPipelineSink(*pipelines.back(), sink, next_batch_index++); - return pipelines.back().get(); -} - -void MetaPipeline::AddDependenciesFrom(Pipeline 
*dependant, Pipeline *start, bool including) { - // find 'start' - auto it = pipelines.begin(); - for (; it->get() != start; it++) { - } - - if (!including) { - it++; - } - - // collect pipelines that were created from then - vector created_pipelines; - for (; it != pipelines.end(); it++) { - if (it->get() == dependant) { - // cannot depend on itself - continue; - } - created_pipelines.push_back(it->get()); - } - - // add them to the dependencies - auto &deps = dependencies[dependant]; - deps.insert(deps.begin(), created_pipelines.begin(), created_pipelines.end()); -} - -void MetaPipeline::AddFinishEvent(Pipeline *pipeline) { - D_ASSERT(finish_pipelines.find(pipeline) == finish_pipelines.end()); - finish_pipelines.insert(pipeline); - - // add all pipelines that were added since 'pipeline' was added (including 'pipeline') to the finish group - auto it = pipelines.begin(); - for (; it->get() != pipeline; it++) { - } - it++; - for (; it != pipelines.end(); it++) { - finish_map.emplace(it->get(), pipeline); - } -} - -bool MetaPipeline::HasFinishEvent(Pipeline *pipeline) const { - return finish_pipelines.find(pipeline) != finish_pipelines.end(); -} - -optional_ptr MetaPipeline::GetFinishGroup(Pipeline *pipeline) const { - auto it = finish_map.find(pipeline); - return it == finish_map.end() ? 
nullptr : it->second; -} - -Pipeline *MetaPipeline::CreateUnionPipeline(Pipeline ¤t, bool order_matters) { - // create the union pipeline (batch index 0, should be set correctly afterwards) - auto union_pipeline = CreatePipeline(); - state.SetPipelineOperators(*union_pipeline, state.GetPipelineOperators(current)); - state.SetPipelineSink(*union_pipeline, sink, 0); - - // 'union_pipeline' inherits ALL dependencies of 'current' (within this MetaPipeline, and across MetaPipelines) - union_pipeline->dependencies = current.dependencies; - auto current_deps = GetDependencies(¤t); - if (current_deps) { - dependencies[union_pipeline] = *current_deps; - } - - if (order_matters) { - // if we need to preserve order, or if the sink is not parallel, we set a dependency - dependencies[union_pipeline].push_back(¤t); - } - - return union_pipeline; -} - -void MetaPipeline::CreateChildPipeline(Pipeline ¤t, PhysicalOperator &op, Pipeline *last_pipeline) { - // rule 2: 'current' must be fully built (down to the source) before creating the child pipeline - D_ASSERT(current.source); - - // create the child pipeline (same batch index) - pipelines.emplace_back(state.CreateChildPipeline(executor, current, op)); - auto child_pipeline = pipelines.back().get(); - child_pipeline->base_batch_index = current.base_batch_index; - - // child pipeline has a dependency (within this MetaPipeline on all pipelines that were scheduled - // between 'current' and now (including 'current') - set them up - dependencies[child_pipeline].push_back(¤t); - AddDependenciesFrom(child_pipeline, last_pipeline, false); - D_ASSERT(!GetDependencies(child_pipeline)->empty()); -} - -} // namespace duckdb - - - - - - - - - - - - - - - -namespace duckdb { - -class PipelineTask : public ExecutorTask { - static constexpr const idx_t PARTIAL_CHUNK_COUNT = 50; - -public: - explicit PipelineTask(Pipeline &pipeline_p, shared_ptr event_p) - : ExecutorTask(pipeline_p.executor), pipeline(pipeline_p), event(std::move(event_p)) { - } 
- - Pipeline &pipeline; - shared_ptr event; - unique_ptr pipeline_executor; - -public: - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - if (!pipeline_executor) { - pipeline_executor = make_uniq(pipeline.GetClientContext(), pipeline); - } - - pipeline_executor->SetTaskForInterrupts(shared_from_this()); - - if (mode == TaskExecutionMode::PROCESS_PARTIAL) { - auto res = pipeline_executor->Execute(PARTIAL_CHUNK_COUNT); - - switch (res) { - case PipelineExecuteResult::NOT_FINISHED: - return TaskExecutionResult::TASK_NOT_FINISHED; - case PipelineExecuteResult::INTERRUPTED: - return TaskExecutionResult::TASK_BLOCKED; - case PipelineExecuteResult::FINISHED: - break; - } - } else { - auto res = pipeline_executor->Execute(); - switch (res) { - case PipelineExecuteResult::NOT_FINISHED: - throw InternalException("Execute without limit should not return NOT_FINISHED"); - case PipelineExecuteResult::INTERRUPTED: - return TaskExecutionResult::TASK_BLOCKED; - case PipelineExecuteResult::FINISHED: - break; - } - } - - event->FinishTask(); - pipeline_executor.reset(); - return TaskExecutionResult::TASK_FINISHED; - } -}; - -Pipeline::Pipeline(Executor &executor_p) - : executor(executor_p), ready(false), initialized(false), source(nullptr), sink(nullptr) { -} - -ClientContext &Pipeline::GetClientContext() { - return executor.context; -} - -bool Pipeline::GetProgress(double ¤t_percentage, idx_t &source_cardinality) { - D_ASSERT(source); - source_cardinality = source->estimated_cardinality; - if (!initialized) { - current_percentage = 0; - return true; - } - auto &client = executor.context; - current_percentage = source->GetProgress(client, *source_state); - return current_percentage >= 0; -} - -void Pipeline::ScheduleSequentialTask(shared_ptr &event) { - vector> tasks; - tasks.push_back(make_uniq(*this, event)); - event->SetTasks(std::move(tasks)); -} - -bool Pipeline::ScheduleParallel(shared_ptr &event) { - // check if the sink, source and all intermediate 
operators support parallelism - if (!sink->ParallelSink()) { - return false; - } - if (!source->ParallelSource()) { - return false; - } - for (auto &op_ref : operators) { - auto &op = op_ref.get(); - if (!op.ParallelOperator()) { - return false; - } - } - if (sink->RequiresBatchIndex()) { - if (!source->SupportsBatchIndex()) { - throw InternalException( - "Attempting to schedule a pipeline where the sink requires batch index but source does not support it"); - } - } - idx_t max_threads = source_state->MaxThreads(); - return LaunchScanTasks(event, max_threads); -} - -bool Pipeline::IsOrderDependent() const { - auto &config = DBConfig::GetConfig(executor.context); - if (source) { - auto source_order = source->SourceOrder(); - if (source_order == OrderPreservationType::FIXED_ORDER) { - return true; - } - if (source_order == OrderPreservationType::NO_ORDER) { - return false; - } - } - for (auto &op_ref : operators) { - auto &op = op_ref.get(); - if (op.OperatorOrder() == OrderPreservationType::NO_ORDER) { - return false; - } - if (op.OperatorOrder() == OrderPreservationType::FIXED_ORDER) { - return true; - } - } - if (!config.options.preserve_insertion_order) { - return false; - } - if (sink && sink->SinkOrderDependent()) { - return true; - } - return false; -} - -void Pipeline::Schedule(shared_ptr &event) { - D_ASSERT(ready); - D_ASSERT(sink); - Reset(); - if (!ScheduleParallel(event)) { - // could not parallelize this pipeline: push a sequential task instead - ScheduleSequentialTask(event); - } -} - -bool Pipeline::LaunchScanTasks(shared_ptr &event, idx_t max_threads) { - // split the scan up into parts and schedule the parts - auto &scheduler = TaskScheduler::GetScheduler(executor.context); - idx_t active_threads = scheduler.NumberOfThreads(); - if (max_threads > active_threads) { - max_threads = active_threads; - } - if (max_threads <= 1) { - // too small to parallelize - return false; - } - - // launch a task for every thread - vector> tasks; - for (idx_t i = 0; i 
< max_threads; i++) { - tasks.push_back(make_uniq(*this, event)); - } - event->SetTasks(std::move(tasks)); - return true; -} - -void Pipeline::ResetSink() { - if (sink) { - if (!sink->IsSink()) { - throw InternalException("Sink of pipeline does not have IsSink set"); - } - lock_guard guard(sink->lock); - if (!sink->sink_state) { - sink->sink_state = sink->GetGlobalSinkState(GetClientContext()); - } - } -} - -void Pipeline::Reset() { - ResetSink(); - for (auto &op_ref : operators) { - auto &op = op_ref.get(); - lock_guard guard(op.lock); - if (!op.op_state) { - op.op_state = op.GetGlobalOperatorState(GetClientContext()); - } - } - ResetSource(false); - // we no longer reset source here because this function is no longer guaranteed to be called by the main thread - // source reset needs to be called by the main thread because resetting a source may call into clients like R - initialized = true; -} - -void Pipeline::ResetSource(bool force) { - if (source && !source->IsSource()) { - throw InternalException("Source of pipeline does not have IsSource set"); - } - if (force || !source_state) { - source_state = source->GetGlobalSourceState(GetClientContext()); - } -} - -void Pipeline::Ready() { - if (ready) { - return; - } - ready = true; - std::reverse(operators.begin(), operators.end()); -} - -void Pipeline::AddDependency(shared_ptr &pipeline) { - D_ASSERT(pipeline); - dependencies.push_back(weak_ptr(pipeline)); - pipeline->parents.push_back(weak_ptr(shared_from_this())); -} - -string Pipeline::ToString() const { - TreeRenderer renderer; - return renderer.ToString(*this); -} - -void Pipeline::Print() const { - Printer::Print(ToString()); -} - -void Pipeline::PrintDependencies() const { - for (auto &dep : dependencies) { - shared_ptr(dep)->Print(); - } -} - -vector> Pipeline::GetOperators() { - vector> result; - D_ASSERT(source); - result.push_back(*source); - for (auto &op : operators) { - result.push_back(op.get()); - } - if (sink) { - result.push_back(*sink); - } - 
return result; -} - -vector> Pipeline::GetOperators() const { - vector> result; - D_ASSERT(source); - result.push_back(*source); - for (auto &op : operators) { - result.push_back(op.get()); - } - if (sink) { - result.push_back(*sink); - } - return result; -} - -void Pipeline::ClearSource() { - source_state.reset(); - batch_indexes.clear(); -} - -idx_t Pipeline::RegisterNewBatchIndex() { - lock_guard l(batch_lock); - idx_t minimum = batch_indexes.empty() ? base_batch_index : *batch_indexes.begin(); - batch_indexes.insert(minimum); - return minimum; -} - -idx_t Pipeline::UpdateBatchIndex(idx_t old_index, idx_t new_index) { - lock_guard l(batch_lock); - if (new_index < *batch_indexes.begin()) { - throw InternalException("Processing batch index %llu, but previous min batch index was %llu", new_index, - *batch_indexes.begin()); - } - auto entry = batch_indexes.find(old_index); - if (entry == batch_indexes.end()) { - throw InternalException("Batch index %llu was not found in set of active batch indexes", old_index); - } - batch_indexes.erase(entry); - batch_indexes.insert(new_index); - return *batch_indexes.begin(); -} -//===--------------------------------------------------------------------===// -// Pipeline Build State -//===--------------------------------------------------------------------===// -void PipelineBuildState::SetPipelineSource(Pipeline &pipeline, PhysicalOperator &op) { - pipeline.source = &op; -} - -void PipelineBuildState::SetPipelineSink(Pipeline &pipeline, optional_ptr op, - idx_t sink_pipeline_count) { - pipeline.sink = op; - // set the base batch index of this pipeline based on how many other pipelines have this node as their sink - pipeline.base_batch_index = BATCH_INCREMENT * sink_pipeline_count; -} - -void PipelineBuildState::AddPipelineOperator(Pipeline &pipeline, PhysicalOperator &op) { - pipeline.operators.push_back(op); -} - -optional_ptr PipelineBuildState::GetPipelineSource(Pipeline &pipeline) { - return pipeline.source; -} - -optional_ptr 
PipelineBuildState::GetPipelineSink(Pipeline &pipeline) { - return pipeline.sink; -} - -void PipelineBuildState::SetPipelineOperators(Pipeline &pipeline, vector> operators) { - pipeline.operators = std::move(operators); -} - -shared_ptr PipelineBuildState::CreateChildPipeline(Executor &executor, Pipeline &pipeline, - PhysicalOperator &op) { - return executor.CreateChildPipeline(pipeline, op); -} - -vector> PipelineBuildState::GetPipelineOperators(Pipeline &pipeline) { - return pipeline.operators; -} - -} // namespace duckdb - - - -namespace duckdb { - -PipelineCompleteEvent::PipelineCompleteEvent(Executor &executor, bool complete_pipeline_p) - : Event(executor), complete_pipeline(complete_pipeline_p) { -} - -void PipelineCompleteEvent::Schedule() { -} - -void PipelineCompleteEvent::FinalizeFinish() { - if (complete_pipeline) { - executor.CompletePipeline(); - } -} - -} // namespace duckdb - - - -namespace duckdb { - -PipelineEvent::PipelineEvent(shared_ptr pipeline_p) : BasePipelineEvent(std::move(pipeline_p)) { -} - -void PipelineEvent::Schedule() { - auto event = shared_from_this(); - auto &executor = pipeline->executor; - try { - pipeline->Schedule(event); - D_ASSERT(total_tasks > 0); - } catch (Exception &ex) { - executor.PushError(PreservedError(ex)); - } catch (std::exception &ex) { - executor.PushError(PreservedError(ex)); - } catch (...) 
{ // LCOV_EXCL_START - executor.PushError(PreservedError("Unknown exception in Finalize!")); - } // LCOV_EXCL_STOP -} - -void PipelineEvent::FinishEvent() { -} - -} // namespace duckdb - - - - -#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE -#include -#include -#endif - -namespace duckdb { - -PipelineExecutor::PipelineExecutor(ClientContext &context_p, Pipeline &pipeline_p) - : pipeline(pipeline_p), thread(context_p), context(context_p, thread, &pipeline_p) { - D_ASSERT(pipeline.source_state); - if (pipeline.sink) { - local_sink_state = pipeline.sink->GetLocalSinkState(context); - requires_batch_index = pipeline.sink->RequiresBatchIndex() && pipeline.source->SupportsBatchIndex(); - if (requires_batch_index) { - auto &partition_info = local_sink_state->partition_info; - D_ASSERT(!partition_info.batch_index.IsValid()); - // batch index is not set yet - initialize before fetching anything - partition_info.batch_index = pipeline.RegisterNewBatchIndex(); - partition_info.min_batch_index = partition_info.batch_index; - } - } - local_source_state = pipeline.source->GetLocalSourceState(context, *pipeline.source_state); - - intermediate_chunks.reserve(pipeline.operators.size()); - intermediate_states.reserve(pipeline.operators.size()); - for (idx_t i = 0; i < pipeline.operators.size(); i++) { - auto &prev_operator = i == 0 ? 
*pipeline.source : pipeline.operators[i - 1].get(); - auto ¤t_operator = pipeline.operators[i].get(); - - auto chunk = make_uniq(); - chunk->Initialize(Allocator::Get(context.client), prev_operator.GetTypes()); - intermediate_chunks.push_back(std::move(chunk)); - - auto op_state = current_operator.GetOperatorState(context); - intermediate_states.push_back(std::move(op_state)); - - if (current_operator.IsSink() && current_operator.sink_state->state == SinkFinalizeType::NO_OUTPUT_POSSIBLE) { - // one of the operators has already figured out no output is possible - // we can skip executing the pipeline - FinishProcessing(); - } - } - InitializeChunk(final_chunk); -} - -bool PipelineExecutor::TryFlushCachingOperators() { - if (!started_flushing) { - // Remainder of this method assumes any in process operators are from flushing - D_ASSERT(in_process_operators.empty()); - started_flushing = true; - flushing_idx = IsFinished() ? idx_t(finished_processing_idx) : 0; - } - - // Go over each operator and keep flushing them using `FinalExecute` until empty - while (flushing_idx < pipeline.operators.size()) { - if (!pipeline.operators[flushing_idx].get().RequiresFinalExecute()) { - flushing_idx++; - continue; - } - - // This slightly awkward way of increasing the flushing idx is to make the code re-entrant: We need to call this - // method again in the case of a Sink returning BLOCKED. - if (!should_flush_current_idx && in_process_operators.empty()) { - should_flush_current_idx = true; - flushing_idx++; - continue; - } - - auto &curr_chunk = - flushing_idx + 1 >= intermediate_chunks.size() ? 
final_chunk : *intermediate_chunks[flushing_idx + 1]; - auto ¤t_operator = pipeline.operators[flushing_idx].get(); - - OperatorFinalizeResultType finalize_result; - OperatorResultType push_result; - - if (in_process_operators.empty()) { - curr_chunk.Reset(); - StartOperator(current_operator); - finalize_result = current_operator.FinalExecute(context, curr_chunk, *current_operator.op_state, - *intermediate_states[flushing_idx]); - EndOperator(current_operator, &curr_chunk); - } else { - // Reset flag and reflush the last chunk we were flushing. - finalize_result = OperatorFinalizeResultType::HAVE_MORE_OUTPUT; - } - - push_result = ExecutePushInternal(curr_chunk, flushing_idx + 1); - - if (finalize_result == OperatorFinalizeResultType::HAVE_MORE_OUTPUT) { - should_flush_current_idx = true; - } else { - should_flush_current_idx = false; - } - - if (push_result == OperatorResultType::BLOCKED) { - remaining_sink_chunk = true; - return false; - } else if (push_result == OperatorResultType::FINISHED) { - break; - } - } - return true; -} - -PipelineExecuteResult PipelineExecutor::Execute(idx_t max_chunks) { - D_ASSERT(pipeline.sink); - auto &source_chunk = pipeline.operators.empty() ? final_chunk : *intermediate_chunks[0]; - for (idx_t i = 0; i < max_chunks; i++) { - if (context.client.interrupted) { - throw InterruptException(); - } - - OperatorResultType result; - if (exhausted_source && done_flushing && !remaining_sink_chunk && in_process_operators.empty()) { - break; - } else if (remaining_sink_chunk) { - // The pipeline was interrupted by the Sink. We should retry sinking the final chunk. - result = ExecutePushInternal(final_chunk); - remaining_sink_chunk = false; - } else if (!in_process_operators.empty() && !started_flushing) { - // The pipeline was interrupted by the Sink when pushing a source chunk through the pipeline. 
We need to - // re-push the same source chunk through the pipeline because there are in_process operators, meaning that - // the result for the pipeline - D_ASSERT(source_chunk.size() > 0); - result = ExecutePushInternal(source_chunk); - } else if (exhausted_source && !done_flushing) { - // The source was exhausted, try flushing all operators - auto flush_completed = TryFlushCachingOperators(); - if (flush_completed) { - done_flushing = true; - break; - } else { - return PipelineExecuteResult::INTERRUPTED; - } - } else if (!exhausted_source) { - // "Regular" path: fetch a chunk from the source and push it through the pipeline - source_chunk.Reset(); - SourceResultType source_result = FetchFromSource(source_chunk); - - if (source_result == SourceResultType::BLOCKED) { - return PipelineExecuteResult::INTERRUPTED; - } - - if (source_result == SourceResultType::FINISHED) { - exhausted_source = true; - if (source_chunk.size() == 0) { - continue; - } - } - result = ExecutePushInternal(source_chunk); - } else { - throw InternalException("Unexpected state reached in pipeline executor"); - } - - // SINK INTERRUPT - if (result == OperatorResultType::BLOCKED) { - remaining_sink_chunk = true; - return PipelineExecuteResult::INTERRUPTED; - } - - if (result == OperatorResultType::FINISHED) { - break; - } - } - - if ((!exhausted_source || !done_flushing) && !IsFinished()) { - return PipelineExecuteResult::NOT_FINISHED; - } - - return PushFinalize(); -} - -PipelineExecuteResult PipelineExecutor::Execute() { - return Execute(NumericLimits::Maximum()); -} - -OperatorResultType PipelineExecutor::ExecutePush(DataChunk &input) { // LCOV_EXCL_START - return ExecutePushInternal(input); -} // LCOV_EXCL_STOP - -void PipelineExecutor::FinishProcessing(int32_t operator_idx) { - finished_processing_idx = operator_idx < 0 ? 
NumericLimits::Maximum() : operator_idx; - in_process_operators = stack(); -} - -bool PipelineExecutor::IsFinished() { - return finished_processing_idx >= 0; -} - -OperatorResultType PipelineExecutor::ExecutePushInternal(DataChunk &input, idx_t initial_idx) { - D_ASSERT(pipeline.sink); - if (input.size() == 0) { // LCOV_EXCL_START - return OperatorResultType::NEED_MORE_INPUT; - } // LCOV_EXCL_STOP - - // this loop will continuously push the input chunk through the pipeline as long as: - // - the OperatorResultType for the Execute is HAVE_MORE_OUTPUT - // - the Sink doesn't block - while (true) { - OperatorResultType result; - // Note: if input is the final_chunk, we don't do any executing, the chunk just needs to be sinked - if (&input != &final_chunk) { - final_chunk.Reset(); - result = Execute(input, final_chunk, initial_idx); - if (result == OperatorResultType::FINISHED) { - return OperatorResultType::FINISHED; - } - } else { - result = OperatorResultType::NEED_MORE_INPUT; - } - auto &sink_chunk = final_chunk; - if (sink_chunk.size() > 0) { - StartOperator(*pipeline.sink); - D_ASSERT(pipeline.sink); - D_ASSERT(pipeline.sink->sink_state); - OperatorSinkInput sink_input {*pipeline.sink->sink_state, *local_sink_state, interrupt_state}; - - auto sink_result = Sink(sink_chunk, sink_input); - - EndOperator(*pipeline.sink, nullptr); - - if (sink_result == SinkResultType::BLOCKED) { - return OperatorResultType::BLOCKED; - } else if (sink_result == SinkResultType::FINISHED) { - FinishProcessing(); - return OperatorResultType::FINISHED; - } - } - if (result == OperatorResultType::NEED_MORE_INPUT) { - return OperatorResultType::NEED_MORE_INPUT; - } - } -} - -PipelineExecuteResult PipelineExecutor::PushFinalize() { - if (finalized) { - throw InternalException("Calling PushFinalize on a pipeline that has been finalized already"); - } - - D_ASSERT(local_sink_state); - - // Run the combine for the sink - OperatorSinkCombineInput combine_input {*pipeline.sink->sink_state, 
*local_sink_state, interrupt_state}; - -#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE - if (debug_blocked_combine_count < debug_blocked_target_count) { - debug_blocked_combine_count++; - - auto &callback_state = combine_input.interrupt_state; - std::thread rewake_thread([callback_state] { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - callback_state.Callback(); - }); - rewake_thread.detach(); - - return PipelineExecuteResult::INTERRUPTED; - } -#endif - auto result = pipeline.sink->Combine(context, combine_input); - - if (result == SinkCombineResultType::BLOCKED) { - return PipelineExecuteResult::INTERRUPTED; - } - - finalized = true; - // flush all query profiler info - for (idx_t i = 0; i < intermediate_states.size(); i++) { - intermediate_states[i]->Finalize(pipeline.operators[i].get(), context); - } - pipeline.executor.Flush(thread); - local_sink_state.reset(); - - return PipelineExecuteResult::FINISHED; -} - -// TODO: Refactoring the StreamingQueryResult to use Push-based execution should eliminate the need for this code -void PipelineExecutor::ExecutePull(DataChunk &result) { - if (IsFinished()) { - return; - } - auto &executor = pipeline.executor; - try { - D_ASSERT(!pipeline.sink); - auto &source_chunk = pipeline.operators.empty() ? result : *intermediate_chunks[0]; - while (result.size() == 0 && !exhausted_source) { - if (in_process_operators.empty()) { - source_chunk.Reset(); - - auto done_signal = make_shared(); - interrupt_state = InterruptState(done_signal); - SourceResultType source_result; - - // Repeatedly try to fetch from the source until it doesn't block. Note that it may block multiple times - while (true) { - source_result = FetchFromSource(source_chunk); - - // No interrupt happened, all good. 
- if (source_result != SourceResultType::BLOCKED) { - break; - } - - // Busy wait for async callback from source operator - done_signal->Await(); - } - - if (source_result == SourceResultType::FINISHED) { - exhausted_source = true; - if (source_chunk.size() == 0) { - break; - } - } - } - if (!pipeline.operators.empty()) { - auto state = Execute(source_chunk, result); - if (state == OperatorResultType::FINISHED) { - break; - } - } - } - } catch (const Exception &ex) { // LCOV_EXCL_START - if (executor.HasError()) { - executor.ThrowException(); - } - throw; - } catch (std::exception &ex) { - if (executor.HasError()) { - executor.ThrowException(); - } - throw; - } catch (...) { - if (executor.HasError()) { - executor.ThrowException(); - } - throw; - } // LCOV_EXCL_STOP -} - -void PipelineExecutor::PullFinalize() { - if (finalized) { - throw InternalException("Calling PullFinalize on a pipeline that has been finalized already"); - } - finalized = true; - pipeline.executor.Flush(thread); -} - -void PipelineExecutor::GoToSource(idx_t ¤t_idx, idx_t initial_idx) { - // we go back to the first operator (the source) - current_idx = initial_idx; - if (!in_process_operators.empty()) { - // ... 
UNLESS there is an in process operator - // if there is an in-process operator, we start executing at the latest one - // for example, if we have a join operator that has tuples left, we first need to emit those tuples - current_idx = in_process_operators.top(); - in_process_operators.pop(); - } - D_ASSERT(current_idx >= initial_idx); -} - -OperatorResultType PipelineExecutor::Execute(DataChunk &input, DataChunk &result, idx_t initial_idx) { - if (input.size() == 0) { // LCOV_EXCL_START - return OperatorResultType::NEED_MORE_INPUT; - } // LCOV_EXCL_STOP - D_ASSERT(!pipeline.operators.empty()); - - idx_t current_idx; - GoToSource(current_idx, initial_idx); - if (current_idx == initial_idx) { - current_idx++; - } - if (current_idx > pipeline.operators.size()) { - result.Reference(input); - return OperatorResultType::NEED_MORE_INPUT; - } - while (true) { - if (context.client.interrupted) { - throw InterruptException(); - } - // now figure out where to put the chunk - // if current_idx is the last possible index (>= operators.size()) we write to the result - // otherwise we write to an intermediate chunk - auto current_intermediate = current_idx; - auto ¤t_chunk = - current_intermediate >= intermediate_chunks.size() ? result : *intermediate_chunks[current_intermediate]; - current_chunk.Reset(); - if (current_idx == initial_idx) { - // we went back to the source: we need more input - return OperatorResultType::NEED_MORE_INPUT; - } else { - auto &prev_chunk = - current_intermediate == initial_idx + 1 ? 
input : *intermediate_chunks[current_intermediate - 1]; - auto operator_idx = current_idx - 1; - auto ¤t_operator = pipeline.operators[operator_idx].get(); - - // if current_idx > source_idx, we pass the previous operators' output through the Execute of the current - // operator - StartOperator(current_operator); - auto result = current_operator.Execute(context, prev_chunk, current_chunk, *current_operator.op_state, - *intermediate_states[current_intermediate - 1]); - EndOperator(current_operator, ¤t_chunk); - if (result == OperatorResultType::HAVE_MORE_OUTPUT) { - // more data remains in this operator - // push in-process marker - in_process_operators.push(current_idx); - } else if (result == OperatorResultType::FINISHED) { - D_ASSERT(current_chunk.size() == 0); - FinishProcessing(current_idx); - return OperatorResultType::FINISHED; - } - current_chunk.Verify(); - } - - if (current_chunk.size() == 0) { - // no output from this operator! - if (current_idx == initial_idx) { - // if we got no output from the scan, we are done - break; - } else { - // if we got no output from an intermediate op - // we go back and try to pull data from the source again - GoToSource(current_idx, initial_idx); - continue; - } - } else { - // we got output! continue to the next operator - current_idx++; - if (current_idx > pipeline.operators.size()) { - // if we got output and are at the last operator, we are finished executing for this output chunk - // return the data and push it into the chunk - break; - } - } - } - return in_process_operators.empty() ? OperatorResultType::NEED_MORE_INPUT : OperatorResultType::HAVE_MORE_OUTPUT; -} - -void PipelineExecutor::SetTaskForInterrupts(weak_ptr current_task) { - interrupt_state = InterruptState(std::move(current_task)); -} - -SourceResultType PipelineExecutor::GetData(DataChunk &chunk, OperatorSourceInput &input) { - //! 
Testing feature to enable async source on every operator -#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE - if (debug_blocked_source_count < debug_blocked_target_count) { - debug_blocked_source_count++; - - auto &callback_state = input.interrupt_state; - std::thread rewake_thread([callback_state] { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - callback_state.Callback(); - }); - rewake_thread.detach(); - - return SourceResultType::BLOCKED; - } -#endif - - return pipeline.source->GetData(context, chunk, input); -} - -SinkResultType PipelineExecutor::Sink(DataChunk &chunk, OperatorSinkInput &input) { - //! Testing feature to enable async sink on every operator -#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE - if (debug_blocked_sink_count < debug_blocked_target_count) { - debug_blocked_sink_count++; - - auto &callback_state = input.interrupt_state; - std::thread rewake_thread([callback_state] { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - callback_state.Callback(); - }); - rewake_thread.detach(); - - return SinkResultType::BLOCKED; - } -#endif - return pipeline.sink->Sink(context, chunk, input); -} - -SourceResultType PipelineExecutor::FetchFromSource(DataChunk &result) { - StartOperator(*pipeline.source); - - OperatorSourceInput source_input = {*pipeline.source_state, *local_source_state, interrupt_state}; - auto res = GetData(result, source_input); - - // Ensures Sinks only return empty results when Blocking or Finished - D_ASSERT(res != SourceResultType::BLOCKED || result.size() == 0); - - if (requires_batch_index && res != SourceResultType::BLOCKED) { - idx_t next_batch_index; - if (result.size() == 0) { - next_batch_index = NumericLimits::Maximum(); - } else { - next_batch_index = - pipeline.source->GetBatchIndex(context, result, *pipeline.source_state, *local_source_state); - // we start with the base_batch_index as a valid starting value. 
Make sure that next batch is called below - next_batch_index += pipeline.base_batch_index + 1; - } - auto &partition_info = local_sink_state->partition_info; - if (next_batch_index != partition_info.batch_index.GetIndex()) { - // batch index has changed - update it - if (partition_info.batch_index.GetIndex() > next_batch_index) { - throw InternalException( - "Pipeline batch index - gotten lower batch index %llu (down from previous batch index of %llu)", - next_batch_index, partition_info.batch_index.GetIndex()); - } - auto current_batch = partition_info.batch_index.GetIndex(); - partition_info.batch_index = next_batch_index; - // call NextBatch before updating min_batch_index to provide the opportunity to flush the previous batch - pipeline.sink->NextBatch(context, *pipeline.sink->sink_state, *local_sink_state); - partition_info.min_batch_index = pipeline.UpdateBatchIndex(current_batch, next_batch_index); - } - } - - EndOperator(*pipeline.source, &result); - - return res; -} - -void PipelineExecutor::InitializeChunk(DataChunk &chunk) { - auto &last_op = pipeline.operators.empty() ? *pipeline.source : pipeline.operators.back().get(); - chunk.Initialize(Allocator::DefaultAllocator(), last_op.GetTypes()); -} - -void PipelineExecutor::StartOperator(PhysicalOperator &op) { - if (context.client.interrupted) { - throw InterruptException(); - } - context.thread.profiler.StartOperator(&op); -} - -void PipelineExecutor::EndOperator(PhysicalOperator &op, optional_ptr chunk) { - context.thread.profiler.EndOperator(chunk); - - if (chunk) { - chunk->Verify(); - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -//! The PipelineFinishTask calls Finalize on the sink. Note that this is a single-threaded operation, but is executed -//! in a task to allow the Finalize call to block (e.g. 
for async I/O) -class PipelineFinishTask : public ExecutorTask { -public: - explicit PipelineFinishTask(Pipeline &pipeline_p, shared_ptr event_p) - : ExecutorTask(pipeline_p.executor), pipeline(pipeline_p), event(std::move(event_p)) { - } - - Pipeline &pipeline; - shared_ptr event; - -public: - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - auto sink = pipeline.GetSink(); - InterruptState interrupt_state(shared_from_this()); - OperatorSinkFinalizeInput finalize_input {*sink->sink_state, interrupt_state}; - -#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE - if (debug_blocked_count < debug_blocked_target_count) { - debug_blocked_count++; - - auto &callback_state = interrupt_state; - std::thread rewake_thread([callback_state] { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - callback_state.Callback(); - }); - rewake_thread.detach(); - - return TaskExecutionResult::TASK_BLOCKED; - } -#endif - auto sink_state = sink->Finalize(pipeline, *event, executor.context, finalize_input); - - if (sink_state == SinkFinalizeType::BLOCKED) { - return TaskExecutionResult::TASK_BLOCKED; - } - - sink->sink_state->state = sink_state; - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; - } - -private: -#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE - //! Debugging state: number of times blocked - int debug_blocked_count = 0; - //! 
Number of times the Finalize will block before actually returning data - int debug_blocked_target_count = 10; -#endif -}; - -PipelineFinishEvent::PipelineFinishEvent(shared_ptr pipeline_p) : BasePipelineEvent(std::move(pipeline_p)) { -} - -void PipelineFinishEvent::Schedule() { - vector> tasks; - tasks.push_back(make_uniq(*pipeline, shared_from_this())); - SetTasks(std::move(tasks)); -} - -void PipelineFinishEvent::FinishEvent() { -} - -} // namespace duckdb - - - - -namespace duckdb { - -PipelineInitializeEvent::PipelineInitializeEvent(shared_ptr pipeline_p) - : BasePipelineEvent(std::move(pipeline_p)) { -} - -class PipelineInitializeTask : public ExecutorTask { -public: - explicit PipelineInitializeTask(Pipeline &pipeline_p, shared_ptr event_p) - : ExecutorTask(pipeline_p.executor), pipeline(pipeline_p), event(std::move(event_p)) { - } - - Pipeline &pipeline; - shared_ptr event; - -public: - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - pipeline.ResetSink(); - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; - } -}; - -void PipelineInitializeEvent::Schedule() { - // needs to spawn a task to get the chain of tasks for the query plan going - vector> tasks; - tasks.push_back(make_uniq(*pipeline, shared_from_this())); - SetTasks(std::move(tasks)); -} - -void PipelineInitializeEvent::FinishEvent() { -} - -} // namespace duckdb - - - - - - - -#ifndef DUCKDB_NO_THREADS - - - -#include -#else -#include -#endif - -namespace duckdb { - -struct SchedulerThread { -#ifndef DUCKDB_NO_THREADS - explicit SchedulerThread(unique_ptr thread_p) : internal_thread(std::move(thread_p)) { - } - - unique_ptr internal_thread; -#endif -}; - -#ifndef DUCKDB_NO_THREADS -typedef duckdb_moodycamel::ConcurrentQueue> concurrent_queue_t; -typedef duckdb_moodycamel::LightweightSemaphore lightweight_semaphore_t; - -struct ConcurrentQueue { - concurrent_queue_t q; - lightweight_semaphore_t semaphore; - - void Enqueue(ProducerToken &token, shared_ptr task); - 
bool DequeueFromProducer(ProducerToken &token, shared_ptr &task); -}; - -struct QueueProducerToken { - explicit QueueProducerToken(ConcurrentQueue &queue) : queue_token(queue.q) { - } - - duckdb_moodycamel::ProducerToken queue_token; -}; - -void ConcurrentQueue::Enqueue(ProducerToken &token, shared_ptr task) { - lock_guard producer_lock(token.producer_lock); - if (q.enqueue(token.token->queue_token, std::move(task))) { - semaphore.signal(); - } else { - throw InternalException("Could not schedule task!"); - } -} - -bool ConcurrentQueue::DequeueFromProducer(ProducerToken &token, shared_ptr &task) { - lock_guard producer_lock(token.producer_lock); - return q.try_dequeue_from_producer(token.token->queue_token, task); -} - -#else -struct ConcurrentQueue { - std::queue> q; - mutex qlock; - - void Enqueue(ProducerToken &token, shared_ptr task); - bool DequeueFromProducer(ProducerToken &token, shared_ptr &task); -}; - -void ConcurrentQueue::Enqueue(ProducerToken &token, shared_ptr task) { - lock_guard lock(qlock); - q.push(std::move(task)); -} - -bool ConcurrentQueue::DequeueFromProducer(ProducerToken &token, shared_ptr &task) { - lock_guard lock(qlock); - if (q.empty()) { - return false; - } - task = std::move(q.front()); - q.pop(); - return true; -} - -struct QueueProducerToken { - QueueProducerToken(ConcurrentQueue &queue) { - } -}; -#endif - -ProducerToken::ProducerToken(TaskScheduler &scheduler, unique_ptr token) - : scheduler(scheduler), token(std::move(token)) { -} - -ProducerToken::~ProducerToken() { -} - -TaskScheduler::TaskScheduler(DatabaseInstance &db) - : db(db), queue(make_uniq()), - allocator_flush_threshold(db.config.options.allocator_flush_threshold) { -} - -TaskScheduler::~TaskScheduler() { -#ifndef DUCKDB_NO_THREADS - SetThreadsInternal(1); -#endif -} - -TaskScheduler &TaskScheduler::GetScheduler(ClientContext &context) { - return TaskScheduler::GetScheduler(DatabaseInstance::GetDatabase(context)); -} - -TaskScheduler 
&TaskScheduler::GetScheduler(DatabaseInstance &db) { - return db.GetScheduler(); -} - -unique_ptr TaskScheduler::CreateProducer() { - auto token = make_uniq(*queue); - return make_uniq(*this, std::move(token)); -} - -void TaskScheduler::ScheduleTask(ProducerToken &token, shared_ptr task) { - // Enqueue a task for the given producer token and signal any sleeping threads - queue->Enqueue(token, std::move(task)); -} - -bool TaskScheduler::GetTaskFromProducer(ProducerToken &token, shared_ptr &task) { - return queue->DequeueFromProducer(token, task); -} - -void TaskScheduler::ExecuteForever(atomic *marker) { -#ifndef DUCKDB_NO_THREADS - shared_ptr task; - // loop until the marker is set to false - while (*marker) { - // wait for a signal with a timeout - queue->semaphore.wait(); - if (queue->q.try_dequeue(task)) { - auto execute_result = task->Execute(TaskExecutionMode::PROCESS_ALL); - - switch (execute_result) { - case TaskExecutionResult::TASK_FINISHED: - case TaskExecutionResult::TASK_ERROR: - task.reset(); - break; - case TaskExecutionResult::TASK_NOT_FINISHED: - throw InternalException("Task should not return TASK_NOT_FINISHED in PROCESS_ALL mode"); - case TaskExecutionResult::TASK_BLOCKED: - task->Deschedule(); - task.reset(); - break; - } - - // Flushes the outstanding allocator's outstanding allocations - Allocator::ThreadFlush(allocator_flush_threshold); - } - } -#else - throw NotImplementedException("DuckDB was compiled without threads! 
Background thread loop is not allowed."); -#endif -} - -idx_t TaskScheduler::ExecuteTasks(atomic *marker, idx_t max_tasks) { -#ifndef DUCKDB_NO_THREADS - idx_t completed_tasks = 0; - // loop until the marker is set to false - while (*marker && completed_tasks < max_tasks) { - shared_ptr task; - if (!queue->q.try_dequeue(task)) { - return completed_tasks; - } - auto execute_result = task->Execute(TaskExecutionMode::PROCESS_ALL); - - switch (execute_result) { - case TaskExecutionResult::TASK_FINISHED: - case TaskExecutionResult::TASK_ERROR: - task.reset(); - completed_tasks++; - break; - case TaskExecutionResult::TASK_NOT_FINISHED: - throw InternalException("Task should not return TASK_NOT_FINISHED in PROCESS_ALL mode"); - case TaskExecutionResult::TASK_BLOCKED: - task->Deschedule(); - task.reset(); - break; - } - } - return completed_tasks; -#else - throw NotImplementedException("DuckDB was compiled without threads! Background thread loop is not allowed."); -#endif -} - -void TaskScheduler::ExecuteTasks(idx_t max_tasks) { -#ifndef DUCKDB_NO_THREADS - shared_ptr task; - for (idx_t i = 0; i < max_tasks; i++) { - queue->semaphore.wait(TASK_TIMEOUT_USECS); - if (!queue->q.try_dequeue(task)) { - return; - } - try { - auto execute_result = task->Execute(TaskExecutionMode::PROCESS_ALL); - switch (execute_result) { - case TaskExecutionResult::TASK_FINISHED: - case TaskExecutionResult::TASK_ERROR: - task.reset(); - break; - case TaskExecutionResult::TASK_NOT_FINISHED: - throw InternalException("Task should not return TASK_NOT_FINISHED in PROCESS_ALL mode"); - case TaskExecutionResult::TASK_BLOCKED: - task->Deschedule(); - task.reset(); - break; - } - } catch (...) { - return; - } - } -#else - throw NotImplementedException("DuckDB was compiled without threads! 
Background thread loop is not allowed."); -#endif -} - -#ifndef DUCKDB_NO_THREADS -static void ThreadExecuteTasks(TaskScheduler *scheduler, atomic *marker) { - scheduler->ExecuteForever(marker); -} -#endif - -int32_t TaskScheduler::NumberOfThreads() { - lock_guard t(thread_lock); - auto &config = DBConfig::GetConfig(db); - return threads.size() + config.options.external_threads + 1; -} - -void TaskScheduler::SetThreads(int32_t n) { -#ifndef DUCKDB_NO_THREADS - lock_guard t(thread_lock); - if (n < 1) { - throw SyntaxException("Must have at least 1 thread!"); - } - SetThreadsInternal(n); -#else - if (n != 1) { - throw NotImplementedException("DuckDB was compiled without threads! Setting threads > 1 is not allowed."); - } -#endif -} - -void TaskScheduler::SetAllocatorFlushTreshold(idx_t threshold) { -} - -void TaskScheduler::Signal(idx_t n) { -#ifndef DUCKDB_NO_THREADS - queue->semaphore.signal(n); -#endif -} - -void TaskScheduler::YieldThread() { -#ifndef DUCKDB_NO_THREADS - std::this_thread::yield(); -#endif -} - -void TaskScheduler::SetThreadsInternal(int32_t n) { -#ifndef DUCKDB_NO_THREADS - if (threads.size() == idx_t(n - 1)) { - return; - } - idx_t new_thread_count = n - 1; - if (threads.size() > new_thread_count) { - // we are reducing the number of threads: clear all threads first - for (idx_t i = 0; i < threads.size(); i++) { - *markers[i] = false; - } - Signal(threads.size()); - // now join the threads to ensure they are fully stopped before erasing them - for (idx_t i = 0; i < threads.size(); i++) { - threads[i]->internal_thread->join(); - } - // erase the threads/markers - threads.clear(); - markers.clear(); - } - if (threads.size() < new_thread_count) { - // we are increasing the number of threads: launch them and run tasks on them - idx_t create_new_threads = new_thread_count - threads.size(); - for (idx_t i = 0; i < create_new_threads; i++) { - // launch a thread and assign it a cancellation marker - auto marker = unique_ptr>(new atomic(true)); - auto 
worker_thread = make_uniq(ThreadExecuteTasks, this, marker.get()); - auto thread_wrapper = make_uniq(std::move(worker_thread)); - - threads.push_back(std::move(thread_wrapper)); - markers.push_back(std::move(marker)); - } - } -#endif -} - -} // namespace duckdb - - - - -namespace duckdb { - -ThreadContext::ThreadContext(ClientContext &context) : profiler(QueryProfiler::Get(context).IsEnabled()) { -} - -} // namespace duckdb - - - - - -namespace duckdb { - -void BaseExpression::Print() const { - Printer::Print(ToString()); -} - -string BaseExpression::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return ToString(); - } -#endif - return !alias.empty() ? alias : ToString(); -} - -bool BaseExpression::Equals(const BaseExpression &other) const { - if (expression_class != other.expression_class || type != other.type) { - return false; - } - return true; -} - -void BaseExpression::Verify() const { -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -ColumnDefinition::ColumnDefinition(string name_p, LogicalType type_p) - : name(std::move(name_p)), type(std::move(type_p)) { -} - -ColumnDefinition::ColumnDefinition(string name_p, LogicalType type_p, unique_ptr expression, - TableColumnType category) - : name(std::move(name_p)), type(std::move(type_p)), category(category), expression(std::move(expression)) { -} - -ColumnDefinition ColumnDefinition::Copy() const { - ColumnDefinition copy(name, type); - copy.oid = oid; - copy.storage_oid = storage_oid; - copy.expression = expression ? 
expression->Copy() : nullptr; - copy.compression_type = compression_type; - copy.category = category; - return copy; -} - -const unique_ptr &ColumnDefinition::DefaultValue() const { - if (Generated()) { - throw InternalException("Calling DefaultValue() on a generated column"); - } - return expression; -} - -void ColumnDefinition::SetDefaultValue(unique_ptr default_value) { - if (Generated()) { - throw InternalException("Calling SetDefaultValue() on a generated column"); - } - this->expression = std::move(default_value); -} - -const LogicalType &ColumnDefinition::Type() const { - return type; -} - -LogicalType &ColumnDefinition::TypeMutable() { - return type; -} - -void ColumnDefinition::SetType(const LogicalType &type) { - this->type = type; -} - -const string &ColumnDefinition::Name() const { - return name; -} - -void ColumnDefinition::SetName(const string &name) { - this->name = name; -} - -const duckdb::CompressionType &ColumnDefinition::CompressionType() const { - return compression_type; -} - -void ColumnDefinition::SetCompressionType(duckdb::CompressionType compression_type) { - this->compression_type = compression_type; -} - -const storage_t &ColumnDefinition::StorageOid() const { - return storage_oid; -} - -LogicalIndex ColumnDefinition::Logical() const { - return LogicalIndex(oid); -} - -PhysicalIndex ColumnDefinition::Physical() const { - return PhysicalIndex(storage_oid); -} - -void ColumnDefinition::SetStorageOid(storage_t storage_oid) { - this->storage_oid = storage_oid; -} - -const column_t &ColumnDefinition::Oid() const { - return oid; -} - -void ColumnDefinition::SetOid(column_t oid) { - this->oid = oid; -} - -const TableColumnType &ColumnDefinition::Category() const { - return category; -} - -bool ColumnDefinition::Generated() const { - return category == TableColumnType::GENERATED; -} - -//===--------------------------------------------------------------------===// -// Generated Columns (VIRTUAL) 
-//===--------------------------------------------------------------------===// - -static void VerifyColumnRefs(ParsedExpression &expr) { - if (expr.type == ExpressionType::COLUMN_REF) { - auto &column_ref = expr.Cast(); - if (column_ref.IsQualified()) { - throw ParserException( - "Qualified (tbl.name) column references are not allowed inside of generated column expressions"); - } - } - ParsedExpressionIterator::EnumerateChildren( - expr, [&](const ParsedExpression &child) { VerifyColumnRefs((ParsedExpression &)child); }); -} - -static void InnerGetListOfDependencies(ParsedExpression &expr, vector &dependencies) { - if (expr.type == ExpressionType::COLUMN_REF) { - auto columnref = expr.Cast(); - auto &name = columnref.GetColumnName(); - dependencies.push_back(name); - } - ParsedExpressionIterator::EnumerateChildren(expr, [&](const ParsedExpression &child) { - if (expr.type == ExpressionType::LAMBDA) { - throw NotImplementedException("Lambda functions are currently not supported in generated columns."); - } - InnerGetListOfDependencies((ParsedExpression &)child, dependencies); - }); -} - -void ColumnDefinition::GetListOfDependencies(vector &dependencies) const { - D_ASSERT(Generated()); - InnerGetListOfDependencies(*expression, dependencies); -} - -string ColumnDefinition::GetName() const { - return name; -} - -LogicalType ColumnDefinition::GetType() const { - return type; -} - -void ColumnDefinition::SetGeneratedExpression(unique_ptr new_expr) { - category = TableColumnType::GENERATED; - - if (new_expr->HasSubquery()) { - throw ParserException("Expression of generated column \"%s\" contains a subquery, which isn't allowed", name); - } - - VerifyColumnRefs(*new_expr); - if (type.id() == LogicalTypeId::ANY) { - expression = std::move(new_expr); - return; - } - // Always wrap the expression in a cast, that way we can always update the cast when we change the type - // Except if the type is LogicalType::ANY (no type specified) - expression = make_uniq_base(type, 
std::move(new_expr)); -} - -void ColumnDefinition::ChangeGeneratedExpressionType(const LogicalType &type) { - D_ASSERT(Generated()); - // First time the type is set, add a cast around the expression - D_ASSERT(this->type.id() == LogicalTypeId::ANY); - expression = make_uniq_base(type, std::move(expression)); - // Every generated expression should be wrapped in a cast on creation - // D_ASSERT(generated_expression->type == ExpressionType::OPERATOR_CAST); - // auto &cast_expr = generated_expression->Cast(); - // auto base_expr = std::move(cast_expr.child); - // generated_expression = make_uniq_base(type, std::move(base_expr)); -} - -const ParsedExpression &ColumnDefinition::GeneratedExpression() const { - D_ASSERT(Generated()); - return *expression; -} - -ParsedExpression &ColumnDefinition::GeneratedExpressionMutable() { - D_ASSERT(Generated()); - return *expression; -} - -} // namespace duckdb - - - - -namespace duckdb { - -ColumnList::ColumnList(bool allow_duplicate_names) : allow_duplicate_names(allow_duplicate_names) { -} - -ColumnList::ColumnList(vector columns, bool allow_duplicate_names) - : allow_duplicate_names(allow_duplicate_names) { - for (auto &col : columns) { - AddColumn(std::move(col)); - } -} - -void ColumnList::AddColumn(ColumnDefinition column) { - auto oid = columns.size(); - if (!column.Generated()) { - column.SetStorageOid(physical_columns.size()); - physical_columns.push_back(oid); - } else { - column.SetStorageOid(DConstants::INVALID_INDEX); - } - column.SetOid(columns.size()); - AddToNameMap(column); - columns.push_back(std::move(column)); -} - -void ColumnList::Finalize() { - // add the "rowid" alias, if there is no rowid column specified in the table - if (name_map.find("rowid") == name_map.end()) { - name_map["rowid"] = COLUMN_IDENTIFIER_ROW_ID; - } -} - -void ColumnList::AddToNameMap(ColumnDefinition &col) { - if (allow_duplicate_names) { - idx_t index = 1; - string base_name = col.Name(); - while (name_map.find(col.Name()) != 
name_map.end()) { - col.SetName(base_name + ":" + to_string(index++)); - } - } else { - if (name_map.find(col.Name()) != name_map.end()) { - throw CatalogException("Column with name %s already exists!", col.Name()); - } - } - name_map[col.Name()] = col.Oid(); -} - -ColumnDefinition &ColumnList::GetColumnMutable(LogicalIndex logical) { - if (logical.index >= columns.size()) { - throw InternalException("Logical column index %lld out of range", logical.index); - } - return columns[logical.index]; -} - -ColumnDefinition &ColumnList::GetColumnMutable(PhysicalIndex physical) { - if (physical.index >= physical_columns.size()) { - throw InternalException("Physical column index %lld out of range", physical.index); - } - auto logical_index = physical_columns[physical.index]; - D_ASSERT(logical_index < columns.size()); - return columns[logical_index]; -} - -ColumnDefinition &ColumnList::GetColumnMutable(const string &name) { - auto entry = name_map.find(name); - if (entry == name_map.end()) { - throw InternalException("Column with name \"%s\" does not exist", name); - } - auto logical_index = entry->second; - D_ASSERT(logical_index < columns.size()); - return columns[logical_index]; -} - -const ColumnDefinition &ColumnList::GetColumn(LogicalIndex logical) const { - if (logical.index >= columns.size()) { - throw InternalException("Logical column index %lld out of range", logical.index); - } - return columns[logical.index]; -} - -const ColumnDefinition &ColumnList::GetColumn(PhysicalIndex physical) const { - if (physical.index >= physical_columns.size()) { - throw InternalException("Physical column index %lld out of range", physical.index); - } - auto logical_index = physical_columns[physical.index]; - D_ASSERT(logical_index < columns.size()); - return columns[logical_index]; -} - -const ColumnDefinition &ColumnList::GetColumn(const string &name) const { - auto entry = name_map.find(name); - if (entry == name_map.end()) { - throw InternalException("Column with name \"%s\" does 
not exist", name); - } - auto logical_index = entry->second; - D_ASSERT(logical_index < columns.size()); - return columns[logical_index]; -} - -vector ColumnList::GetColumnNames() const { - vector names; - names.reserve(columns.size()); - for (auto &column : columns) { - names.push_back(column.Name()); - } - return names; -} - -vector ColumnList::GetColumnTypes() const { - vector types; - types.reserve(columns.size()); - for (auto &column : columns) { - types.push_back(column.Type()); - } - return types; -} - -bool ColumnList::ColumnExists(const string &name) const { - auto entry = name_map.find(name); - return entry != name_map.end(); -} - -PhysicalIndex ColumnList::LogicalToPhysical(LogicalIndex logical) const { - auto &column = GetColumn(logical); - if (column.Generated()) { - throw InternalException("Column at position %d is not a physical column", logical.index); - } - return column.Physical(); -} - -LogicalIndex ColumnList::PhysicalToLogical(PhysicalIndex index) const { - auto &column = GetColumn(index); - return column.Logical(); -} - -LogicalIndex ColumnList::GetColumnIndex(string &column_name) const { - auto entry = name_map.find(column_name); - if (entry == name_map.end()) { - return LogicalIndex(DConstants::INVALID_INDEX); - } - if (entry->second == COLUMN_IDENTIFIER_ROW_ID) { - column_name = "rowid"; - return LogicalIndex(COLUMN_IDENTIFIER_ROW_ID); - } - column_name = columns[entry->second].Name(); - return LogicalIndex(entry->second); -} - -ColumnList ColumnList::Copy() const { - ColumnList result(allow_duplicate_names); - for (auto &col : columns) { - result.AddColumn(col.Copy()); - } - return result; -} - -ColumnList::ColumnListIterator ColumnList::Logical() const { - return ColumnListIterator(*this, false); -} - -ColumnList::ColumnListIterator ColumnList::Physical() const { - return ColumnListIterator(*this, true); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -Constraint::Constraint(ConstraintType type) : type(type) { -} - 
-Constraint::~Constraint() { -} - -void Constraint::Print() const { - Printer::Print(ToString()); -} - -} // namespace duckdb - - -namespace duckdb { - -CheckConstraint::CheckConstraint(unique_ptr expression) - : Constraint(ConstraintType::CHECK), expression(std::move(expression)) { -} - -string CheckConstraint::ToString() const { - return "CHECK(" + expression->ToString() + ")"; -} - -unique_ptr CheckConstraint::Copy() const { - return make_uniq(expression->Copy()); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -ForeignKeyConstraint::ForeignKeyConstraint() : Constraint(ConstraintType::FOREIGN_KEY) { -} - -ForeignKeyConstraint::ForeignKeyConstraint(vector pk_columns, vector fk_columns, ForeignKeyInfo info) - : Constraint(ConstraintType::FOREIGN_KEY), pk_columns(std::move(pk_columns)), fk_columns(std::move(fk_columns)), - info(std::move(info)) { -} - -string ForeignKeyConstraint::ToString() const { - if (info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { - string base = "FOREIGN KEY ("; - - for (idx_t i = 0; i < fk_columns.size(); i++) { - if (i > 0) { - base += ", "; - } - base += KeywordHelper::WriteOptionallyQuoted(fk_columns[i]); - } - base += ") REFERENCES "; - if (!info.schema.empty()) { - base += info.schema; - base += "."; - } - base += info.table; - base += "("; - - for (idx_t i = 0; i < pk_columns.size(); i++) { - if (i > 0) { - base += ", "; - } - base += KeywordHelper::WriteOptionallyQuoted(pk_columns[i]); - } - base += ")"; - - return base; - } - - return ""; -} - -unique_ptr ForeignKeyConstraint::Copy() const { - return make_uniq(pk_columns, fk_columns, info); -} - -} // namespace duckdb - - -namespace duckdb { - -NotNullConstraint::NotNullConstraint(LogicalIndex index) : Constraint(ConstraintType::NOT_NULL), index(index) { -} - -NotNullConstraint::~NotNullConstraint() { -} - -string NotNullConstraint::ToString() const { - return "NOT NULL"; -} - -unique_ptr NotNullConstraint::Copy() const { - return make_uniq(index); -} - -} // 
namespace duckdb - - - - - -namespace duckdb { - -UniqueConstraint::UniqueConstraint() : Constraint(ConstraintType::UNIQUE), index(DConstants::INVALID_INDEX) { -} - -UniqueConstraint::UniqueConstraint(LogicalIndex index, bool is_primary_key) - : Constraint(ConstraintType::UNIQUE), index(index), is_primary_key(is_primary_key) { -} -UniqueConstraint::UniqueConstraint(vector columns, bool is_primary_key) - : Constraint(ConstraintType::UNIQUE), index(DConstants::INVALID_INDEX), columns(std::move(columns)), - is_primary_key(is_primary_key) { -} - -string UniqueConstraint::ToString() const { - string base = is_primary_key ? "PRIMARY KEY(" : "UNIQUE("; - for (idx_t i = 0; i < columns.size(); i++) { - if (i > 0) { - base += ", "; - } - base += KeywordHelper::WriteOptionallyQuoted(columns[i]); - } - return base + ")"; -} - -unique_ptr UniqueConstraint::Copy() const { - if (index.index == DConstants::INVALID_INDEX) { - return make_uniq(columns, is_primary_key); - } else { - auto result = make_uniq(index, is_primary_key); - result->columns = columns; - return std::move(result); - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -BetweenExpression::BetweenExpression(unique_ptr input_p, unique_ptr lower_p, - unique_ptr upper_p) - : ParsedExpression(ExpressionType::COMPARE_BETWEEN, ExpressionClass::BETWEEN), input(std::move(input_p)), - lower(std::move(lower_p)), upper(std::move(upper_p)) { -} - -BetweenExpression::BetweenExpression() : BetweenExpression(nullptr, nullptr, nullptr) { -} - -string BetweenExpression::ToString() const { - return ToString(*this); -} - -bool BetweenExpression::Equal(const BetweenExpression &a, const BetweenExpression &b) { - if (!a.input->Equals(*b.input)) { - return false; - } - if (!a.lower->Equals(*b.lower)) { - return false; - } - if (!a.upper->Equals(*b.upper)) { - return false; - } - return true; -} - -unique_ptr BetweenExpression::Copy() const { - auto copy = make_uniq(input->Copy(), lower->Copy(), upper->Copy()); - 
copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -CaseExpression::CaseExpression() : ParsedExpression(ExpressionType::CASE_EXPR, ExpressionClass::CASE) { -} - -string CaseExpression::ToString() const { - return ToString(*this); -} - -bool CaseExpression::Equal(const CaseExpression &a, const CaseExpression &b) { - if (a.case_checks.size() != b.case_checks.size()) { - return false; - } - for (idx_t i = 0; i < a.case_checks.size(); i++) { - if (!a.case_checks[i].when_expr->Equals(*b.case_checks[i].when_expr)) { - return false; - } - if (!a.case_checks[i].then_expr->Equals(*b.case_checks[i].then_expr)) { - return false; - } - } - if (!a.else_expr->Equals(*b.else_expr)) { - return false; - } - return true; -} - -unique_ptr CaseExpression::Copy() const { - auto copy = make_uniq(); - copy->CopyProperties(*this); - for (auto &check : case_checks) { - CaseCheck new_check; - new_check.when_expr = check.when_expr->Copy(); - new_check.then_expr = check.then_expr->Copy(); - copy->case_checks.push_back(std::move(new_check)); - } - copy->else_expr = else_expr->Copy(); - return std::move(copy); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -CastExpression::CastExpression(LogicalType target, unique_ptr child, bool try_cast_p) - : ParsedExpression(ExpressionType::OPERATOR_CAST, ExpressionClass::CAST), cast_type(std::move(target)), - try_cast(try_cast_p) { - D_ASSERT(child); - this->child = std::move(child); -} - -CastExpression::CastExpression() : ParsedExpression(ExpressionType::OPERATOR_CAST, ExpressionClass::CAST) { -} - -string CastExpression::ToString() const { - return ToString(*this); -} - -bool CastExpression::Equal(const CastExpression &a, const CastExpression &b) { - if (!a.child->Equals(*b.child)) { - return false; - } - if (a.cast_type != b.cast_type) { - return false; - } - if (a.try_cast != b.try_cast) { - return false; - } - return true; -} - -unique_ptr CastExpression::Copy() 
const { - auto copy = make_uniq(cast_type, child->Copy(), try_cast); - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -CollateExpression::CollateExpression(string collation_p, unique_ptr child) - : ParsedExpression(ExpressionType::COLLATE, ExpressionClass::COLLATE), collation(std::move(collation_p)) { - D_ASSERT(child); - this->child = std::move(child); -} - -CollateExpression::CollateExpression() : ParsedExpression(ExpressionType::COLLATE, ExpressionClass::COLLATE) { -} - -string CollateExpression::ToString() const { - return StringUtil::Format("%s COLLATE %s", child->ToString(), SQLIdentifier(collation)); -} - -bool CollateExpression::Equal(const CollateExpression &a, const CollateExpression &b) { - if (!a.child->Equals(*b.child)) { - return false; - } - if (a.collation != b.collation) { - return false; - } - return true; -} - -unique_ptr CollateExpression::Copy() const { - auto copy = make_uniq(collation, child->Copy()); - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -ColumnRefExpression::ColumnRefExpression() : ParsedExpression(ExpressionType::COLUMN_REF, ExpressionClass::COLUMN_REF) { -} - -ColumnRefExpression::ColumnRefExpression(string column_name, string table_name) - : ColumnRefExpression(table_name.empty() ? 
vector {std::move(column_name)} - : vector {std::move(table_name), std::move(column_name)}) { -} - -ColumnRefExpression::ColumnRefExpression(string column_name) - : ColumnRefExpression(vector {std::move(column_name)}) { -} - -ColumnRefExpression::ColumnRefExpression(vector column_names_p) - : ParsedExpression(ExpressionType::COLUMN_REF, ExpressionClass::COLUMN_REF), - column_names(std::move(column_names_p)) { -#ifdef DEBUG - for (auto &col_name : column_names) { - D_ASSERT(!col_name.empty()); - } -#endif -} - -bool ColumnRefExpression::IsQualified() const { - return column_names.size() > 1; -} - -const string &ColumnRefExpression::GetColumnName() const { - D_ASSERT(column_names.size() <= 4); - return column_names.back(); -} - -const string &ColumnRefExpression::GetTableName() const { - D_ASSERT(column_names.size() >= 2 && column_names.size() <= 4); - if (column_names.size() == 4) { - return column_names[2]; - } - if (column_names.size() == 3) { - return column_names[1]; - } - return column_names[0]; -} - -string ColumnRefExpression::GetName() const { - return !alias.empty() ? 
alias : column_names.back(); -} - -string ColumnRefExpression::ToString() const { - string result; - for (idx_t i = 0; i < column_names.size(); i++) { - if (i > 0) { - result += "."; - } - result += KeywordHelper::WriteOptionallyQuoted(column_names[i]); - } - return result; -} - -bool ColumnRefExpression::Equal(const ColumnRefExpression &a, const ColumnRefExpression &b) { - if (a.column_names.size() != b.column_names.size()) { - return false; - } - for (idx_t i = 0; i < a.column_names.size(); i++) { - if (!StringUtil::CIEquals(a.column_names[i], b.column_names[i])) { - return false; - } - } - return true; -} - -hash_t ColumnRefExpression::Hash() const { - hash_t result = ParsedExpression::Hash(); - for (auto &column_name : column_names) { - result = CombineHash(result, StringUtil::CIHash(column_name)); - } - return result; -} - -unique_ptr ColumnRefExpression::Copy() const { - auto copy = make_uniq(column_names); - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -ComparisonExpression::ComparisonExpression(ExpressionType type) : ParsedExpression(type, ExpressionClass::COMPARISON) { -} - -ComparisonExpression::ComparisonExpression(ExpressionType type, unique_ptr left, - unique_ptr right) - : ParsedExpression(type, ExpressionClass::COMPARISON), left(std::move(left)), right(std::move(right)) { -} - -string ComparisonExpression::ToString() const { - return ToString(*this); -} - -bool ComparisonExpression::Equal(const ComparisonExpression &a, const ComparisonExpression &b) { - if (!a.left->Equals(*b.left)) { - return false; - } - if (!a.right->Equals(*b.right)) { - return false; - } - return true; -} - -unique_ptr ComparisonExpression::Copy() const { - auto copy = make_uniq(type, left->Copy(), right->Copy()); - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -ConjunctionExpression::ConjunctionExpression(ExpressionType type) - : 
ParsedExpression(type, ExpressionClass::CONJUNCTION) { -} - -ConjunctionExpression::ConjunctionExpression(ExpressionType type, vector> children) - : ParsedExpression(type, ExpressionClass::CONJUNCTION) { - for (auto &child : children) { - AddExpression(std::move(child)); - } -} - -ConjunctionExpression::ConjunctionExpression(ExpressionType type, unique_ptr left, - unique_ptr right) - : ParsedExpression(type, ExpressionClass::CONJUNCTION) { - AddExpression(std::move(left)); - AddExpression(std::move(right)); -} - -void ConjunctionExpression::AddExpression(unique_ptr expr) { - if (expr->type == type) { - // expr is a conjunction of the same type: merge the expression lists together - auto &other = expr->Cast(); - for (auto &child : other.children) { - children.push_back(std::move(child)); - } - } else { - children.push_back(std::move(expr)); - } -} - -string ConjunctionExpression::ToString() const { - return ToString(*this); -} - -bool ConjunctionExpression::Equal(const ConjunctionExpression &a, const ConjunctionExpression &b) { - return ExpressionUtil::SetEquals(a.children, b.children); -} - -unique_ptr ConjunctionExpression::Copy() const { - vector> copy_children; - copy_children.reserve(children.size()); - for (auto &expr : children) { - copy_children.push_back(expr->Copy()); - } - - auto copy = make_uniq(type, std::move(copy_children)); - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -ConstantExpression::ConstantExpression() : ParsedExpression(ExpressionType::VALUE_CONSTANT, ExpressionClass::CONSTANT) { -} - -ConstantExpression::ConstantExpression(Value val) - : ParsedExpression(ExpressionType::VALUE_CONSTANT, ExpressionClass::CONSTANT), value(std::move(val)) { -} - -string ConstantExpression::ToString() const { - return value.ToSQLString(); -} - -bool ConstantExpression::Equal(const ConstantExpression &a, const ConstantExpression &b) { - return a.value.type() == b.value.type() && 
!ValueOperations::DistinctFrom(a.value, b.value); -} - -hash_t ConstantExpression::Hash() const { - return value.Hash(); -} - -unique_ptr ConstantExpression::Copy() const { - auto copy = make_uniq(value); - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -DefaultExpression::DefaultExpression() : ParsedExpression(ExpressionType::VALUE_DEFAULT, ExpressionClass::DEFAULT) { -} - -string DefaultExpression::ToString() const { - return "DEFAULT"; -} - -unique_ptr DefaultExpression::Copy() const { - auto copy = make_uniq(); - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - -#include - - - - - - - -namespace duckdb { - -FunctionExpression::FunctionExpression() : ParsedExpression(ExpressionType::FUNCTION, ExpressionClass::FUNCTION) { -} - -FunctionExpression::FunctionExpression(string catalog, string schema, const string &function_name, - vector> children_p, - unique_ptr filter, unique_ptr order_bys_p, - bool distinct, bool is_operator, bool export_state_p) - : ParsedExpression(ExpressionType::FUNCTION, ExpressionClass::FUNCTION), catalog(std::move(catalog)), - schema(std::move(schema)), function_name(StringUtil::Lower(function_name)), is_operator(is_operator), - children(std::move(children_p)), distinct(distinct), filter(std::move(filter)), order_bys(std::move(order_bys_p)), - export_state(export_state_p) { - D_ASSERT(!function_name.empty()); - if (!order_bys) { - order_bys = make_uniq(); - } -} - -FunctionExpression::FunctionExpression(const string &function_name, vector> children_p, - unique_ptr filter, unique_ptr order_bys, - bool distinct, bool is_operator, bool export_state_p) - : FunctionExpression(INVALID_CATALOG, INVALID_SCHEMA, function_name, std::move(children_p), std::move(filter), - std::move(order_bys), distinct, is_operator, export_state_p) { -} - -string FunctionExpression::ToString() const { - return ToString(*this, schema, function_name, 
is_operator, distinct, - filter.get(), order_bys.get(), export_state, true); -} - -bool FunctionExpression::Equal(const FunctionExpression &a, const FunctionExpression &b) { - if (a.catalog != b.catalog || a.schema != b.schema || a.function_name != b.function_name || - b.distinct != a.distinct) { - return false; - } - if (b.children.size() != a.children.size()) { - return false; - } - for (idx_t i = 0; i < a.children.size(); i++) { - if (!a.children[i]->Equals(*b.children[i])) { - return false; - } - } - if (!ParsedExpression::Equals(a.filter, b.filter)) { - return false; - } - if (!OrderModifier::Equals(a.order_bys, b.order_bys)) { - return false; - } - if (a.export_state != b.export_state) { - return false; - } - return true; -} - -hash_t FunctionExpression::Hash() const { - hash_t result = ParsedExpression::Hash(); - result = CombineHash(result, duckdb::Hash(schema.c_str())); - result = CombineHash(result, duckdb::Hash(function_name.c_str())); - result = CombineHash(result, duckdb::Hash(distinct)); - result = CombineHash(result, duckdb::Hash(export_state)); - return result; -} - -unique_ptr FunctionExpression::Copy() const { - vector> copy_children; - unique_ptr filter_copy; - copy_children.reserve(children.size()); - for (auto &child : children) { - copy_children.push_back(child->Copy()); - } - if (filter) { - filter_copy = filter->Copy(); - } - auto order_copy = order_bys ? 
unique_ptr_cast(order_bys->Copy()) : nullptr; - auto copy = - make_uniq(catalog, schema, function_name, std::move(copy_children), std::move(filter_copy), - std::move(order_copy), distinct, is_operator, export_state); - copy->CopyProperties(*this); - return std::move(copy); -} - -void FunctionExpression::Verify() const { - D_ASSERT(!function_name.empty()); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -LambdaExpression::LambdaExpression() : ParsedExpression(ExpressionType::LAMBDA, ExpressionClass::LAMBDA) { -} - -LambdaExpression::LambdaExpression(unique_ptr lhs, unique_ptr expr) - : ParsedExpression(ExpressionType::LAMBDA, ExpressionClass::LAMBDA), lhs(std::move(lhs)), expr(std::move(expr)) { -} - -string LambdaExpression::ToString() const { - return "(" + lhs->ToString() + " -> " + expr->ToString() + ")"; -} - -bool LambdaExpression::Equal(const LambdaExpression &a, const LambdaExpression &b) { - return a.lhs->Equals(*b.lhs) && a.expr->Equals(*b.expr); -} - -hash_t LambdaExpression::Hash() const { - hash_t result = lhs->Hash(); - ParsedExpression::Hash(); - result = CombineHash(result, expr->Hash()); - return result; -} - -unique_ptr LambdaExpression::Copy() const { - auto copy = make_uniq(lhs->Copy(), expr->Copy()); - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -OperatorExpression::OperatorExpression(ExpressionType type, unique_ptr left, - unique_ptr right) - : ParsedExpression(type, ExpressionClass::OPERATOR) { - if (left) { - children.push_back(std::move(left)); - } - if (right) { - children.push_back(std::move(right)); - } -} - -OperatorExpression::OperatorExpression(ExpressionType type, vector> children) - : ParsedExpression(type, ExpressionClass::OPERATOR), children(std::move(children)) { -} - -string OperatorExpression::ToString() const { - return ToString(*this); -} - -bool OperatorExpression::Equal(const OperatorExpression &a, const OperatorExpression &b) { - if 
(a.children.size() != b.children.size()) { - return false; - } - for (idx_t i = 0; i < a.children.size(); i++) { - if (!a.children[i]->Equals(*b.children[i])) { - return false; - } - } - return true; -} - -unique_ptr OperatorExpression::Copy() const { - auto copy = make_uniq(type); - copy->CopyProperties(*this); - for (auto &it : children) { - copy->children.push_back(it->Copy()); - } - return std::move(copy); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -ParameterExpression::ParameterExpression() - : ParsedExpression(ExpressionType::VALUE_PARAMETER, ExpressionClass::PARAMETER) { -} - -string ParameterExpression::ToString() const { - return "$" + identifier; -} - -unique_ptr ParameterExpression::Copy() const { - auto copy = make_uniq(); - copy->identifier = identifier; - copy->CopyProperties(*this); - return std::move(copy); -} - -bool ParameterExpression::Equal(const ParameterExpression &a, const ParameterExpression &b) { - return StringUtil::CIEquals(a.identifier, b.identifier); -} - -hash_t ParameterExpression::Hash() const { - hash_t result = ParsedExpression::Hash(); - return CombineHash(duckdb::Hash(identifier.c_str(), identifier.size()), result); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -PositionalReferenceExpression::PositionalReferenceExpression() - : ParsedExpression(ExpressionType::POSITIONAL_REFERENCE, ExpressionClass::POSITIONAL_REFERENCE) { -} - -PositionalReferenceExpression::PositionalReferenceExpression(idx_t index) - : ParsedExpression(ExpressionType::POSITIONAL_REFERENCE, ExpressionClass::POSITIONAL_REFERENCE), index(index) { -} - -string PositionalReferenceExpression::ToString() const { - return "#" + to_string(index); -} - -bool PositionalReferenceExpression::Equal(const PositionalReferenceExpression &a, - const PositionalReferenceExpression &b) { - return a.index == b.index; -} - -unique_ptr PositionalReferenceExpression::Copy() const { - auto copy = make_uniq(index); - 
copy->CopyProperties(*this); - return std::move(copy); -} - -hash_t PositionalReferenceExpression::Hash() const { - hash_t result = ParsedExpression::Hash(); - return CombineHash(duckdb::Hash(index), result); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -StarExpression::StarExpression(string relation_name_p) - : ParsedExpression(ExpressionType::STAR, ExpressionClass::STAR), relation_name(std::move(relation_name_p)) { -} - -string StarExpression::ToString() const { - if (expr) { - D_ASSERT(columns); - return "COLUMNS(" + expr->ToString() + ")"; - } - string result; - if (columns) { - result += "COLUMNS("; - } - result += relation_name.empty() ? "*" : relation_name + ".*"; - if (!exclude_list.empty()) { - result += " EXCLUDE ("; - bool first_entry = true; - for (auto &entry : exclude_list) { - if (!first_entry) { - result += ", "; - } - result += entry; - first_entry = false; - } - result += ")"; - } - if (!replace_list.empty()) { - result += " REPLACE ("; - bool first_entry = true; - for (auto &entry : replace_list) { - if (!first_entry) { - result += ", "; - } - result += entry.second->ToString(); - result += " AS "; - result += entry.first; - first_entry = false; - } - result += ")"; - } - if (columns) { - result += ")"; - } - return result; -} - -bool StarExpression::Equal(const StarExpression &a, const StarExpression &b) { - if (a.relation_name != b.relation_name || a.exclude_list != b.exclude_list) { - return false; - } - if (a.columns != b.columns) { - return false; - } - if (a.replace_list.size() != b.replace_list.size()) { - return false; - } - for (auto &entry : a.replace_list) { - auto other_entry = b.replace_list.find(entry.first); - if (other_entry == b.replace_list.end()) { - return false; - } - if (!entry.second->Equals(*other_entry->second)) { - return false; - } - } - if (!ParsedExpression::Equals(a.expr, b.expr)) { - return false; - } - return true; -} - -unique_ptr StarExpression::Copy() const { - auto copy = 
make_uniq(relation_name); - copy->exclude_list = exclude_list; - for (auto &entry : replace_list) { - copy->replace_list[entry.first] = entry.second->Copy(); - } - copy->columns = columns; - copy->expr = expr ? expr->Copy() : nullptr; - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -SubqueryExpression::SubqueryExpression() - : ParsedExpression(ExpressionType::SUBQUERY, ExpressionClass::SUBQUERY), subquery_type(SubqueryType::INVALID), - comparison_type(ExpressionType::INVALID) { -} - -string SubqueryExpression::ToString() const { - switch (subquery_type) { - case SubqueryType::ANY: - return "(" + child->ToString() + " " + ExpressionTypeToOperator(comparison_type) + " ANY(" + - subquery->ToString() + "))"; - case SubqueryType::EXISTS: - return "EXISTS(" + subquery->ToString() + ")"; - case SubqueryType::NOT_EXISTS: - return "NOT EXISTS(" + subquery->ToString() + ")"; - case SubqueryType::SCALAR: - return "(" + subquery->ToString() + ")"; - default: - throw InternalException("Unrecognized type for subquery"); - } -} - -bool SubqueryExpression::Equal(const SubqueryExpression &a, const SubqueryExpression &b) { - if (!a.subquery || !b.subquery) { - return false; - } - if (!ParsedExpression::Equals(a.child, b.child)) { - return false; - } - return a.comparison_type == b.comparison_type && a.subquery_type == b.subquery_type && - a.subquery->Equals(*b.subquery); -} - -unique_ptr SubqueryExpression::Copy() const { - auto copy = make_uniq(); - copy->CopyProperties(*this); - copy->subquery = unique_ptr_cast(subquery->Copy()); - copy->subquery_type = subquery_type; - copy->child = child ? 
child->Copy() : nullptr; - copy->comparison_type = comparison_type; - return std::move(copy); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -WindowExpression::WindowExpression(ExpressionType type) : ParsedExpression(type, ExpressionClass::WINDOW) { -} - -WindowExpression::WindowExpression(ExpressionType type, string catalog_name, string schema, const string &function_name) - : ParsedExpression(type, ExpressionClass::WINDOW), catalog(std::move(catalog_name)), schema(std::move(schema)), - function_name(StringUtil::Lower(function_name)), ignore_nulls(false) { - switch (type) { - case ExpressionType::WINDOW_AGGREGATE: - case ExpressionType::WINDOW_ROW_NUMBER: - case ExpressionType::WINDOW_FIRST_VALUE: - case ExpressionType::WINDOW_LAST_VALUE: - case ExpressionType::WINDOW_NTH_VALUE: - case ExpressionType::WINDOW_RANK: - case ExpressionType::WINDOW_RANK_DENSE: - case ExpressionType::WINDOW_PERCENT_RANK: - case ExpressionType::WINDOW_CUME_DIST: - case ExpressionType::WINDOW_LEAD: - case ExpressionType::WINDOW_LAG: - case ExpressionType::WINDOW_NTILE: - break; - default: - throw NotImplementedException("Window aggregate type %s not supported", ExpressionTypeToString(type).c_str()); - } -} - -ExpressionType WindowExpression::WindowToExpressionType(string &fun_name) { - if (fun_name == "rank") { - return ExpressionType::WINDOW_RANK; - } else if (fun_name == "rank_dense" || fun_name == "dense_rank") { - return ExpressionType::WINDOW_RANK_DENSE; - } else if (fun_name == "percent_rank") { - return ExpressionType::WINDOW_PERCENT_RANK; - } else if (fun_name == "row_number") { - return ExpressionType::WINDOW_ROW_NUMBER; - } else if (fun_name == "first_value" || fun_name == "first") { - return ExpressionType::WINDOW_FIRST_VALUE; - } else if (fun_name == "last_value" || fun_name == "last") { - return ExpressionType::WINDOW_LAST_VALUE; - } else if (fun_name == "nth_value") { - return ExpressionType::WINDOW_NTH_VALUE; - } else if (fun_name == "cume_dist") { - 
return ExpressionType::WINDOW_CUME_DIST; - } else if (fun_name == "lead") { - return ExpressionType::WINDOW_LEAD; - } else if (fun_name == "lag") { - return ExpressionType::WINDOW_LAG; - } else if (fun_name == "ntile") { - return ExpressionType::WINDOW_NTILE; - } - return ExpressionType::WINDOW_AGGREGATE; -} - -string WindowExpression::ToString() const { - return ToString(*this, schema, function_name); -} - -bool WindowExpression::Equal(const WindowExpression &a, const WindowExpression &b) { - // check if the child expressions are equivalent - if (a.ignore_nulls != b.ignore_nulls) { - return false; - } - if (!ParsedExpression::ListEquals(a.children, b.children)) { - return false; - } - if (a.start != b.start || a.end != b.end) { - return false; - } - // check if the framing expressions are equivalentbind_ - if (!ParsedExpression::Equals(a.start_expr, b.start_expr) || !ParsedExpression::Equals(a.end_expr, b.end_expr) || - !ParsedExpression::Equals(a.offset_expr, b.offset_expr) || - !ParsedExpression::Equals(a.default_expr, b.default_expr)) { - return false; - } - - // check if the partitions are equivalent - if (!ParsedExpression::ListEquals(a.partitions, b.partitions)) { - return false; - } - // check if the orderings are equivalent - if (a.orders.size() != b.orders.size()) { - return false; - } - for (idx_t i = 0; i < a.orders.size(); i++) { - if (a.orders[i].type != b.orders[i].type) { - return false; - } - if (!a.orders[i].expression->Equals(*b.orders[i].expression)) { - return false; - } - } - // check if the filter clauses are equivalent - if (!ParsedExpression::Equals(a.filter_expr, b.filter_expr)) { - return false; - } - - return true; -} - -unique_ptr WindowExpression::Copy() const { - auto new_window = make_uniq(type, catalog, schema, function_name); - new_window->CopyProperties(*this); - - for (auto &child : children) { - new_window->children.push_back(child->Copy()); - } - - for (auto &e : partitions) { - new_window->partitions.push_back(e->Copy()); - } 
- - for (auto &o : orders) { - new_window->orders.emplace_back(o.type, o.null_order, o.expression->Copy()); - } - - new_window->filter_expr = filter_expr ? filter_expr->Copy() : nullptr; - - new_window->start = start; - new_window->end = end; - new_window->start_expr = start_expr ? start_expr->Copy() : nullptr; - new_window->end_expr = end_expr ? end_expr->Copy() : nullptr; - new_window->offset_expr = offset_expr ? offset_expr->Copy() : nullptr; - new_window->default_expr = default_expr ? default_expr->Copy() : nullptr; - new_window->ignore_nulls = ignore_nulls; - - return std::move(new_window); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -template -bool ExpressionUtil::ExpressionListEquals(const vector> &a, const vector> &b) { - if (a.size() != b.size()) { - return false; - } - for (idx_t i = 0; i < a.size(); i++) { - if (!(*a[i] == *b[i])) { - return false; - } - } - return true; -} - -template -bool ExpressionUtil::ExpressionSetEquals(const vector> &a, const vector> &b) { - if (a.size() != b.size()) { - return false; - } - // we create a map of expression -> count for the left side - // we keep the count because the same expression can occur multiple times (e.g. "1 AND 1" is legal) - // in this case we track the following value: map["Constant(1)"] = 2 - EXPRESSION_MAP map; - for (idx_t i = 0; i < a.size(); i++) { - map[*a[i]]++; - } - // now on the right side we reduce the counts again - // if the conjunctions are identical, all the counts will be 0 after the - for (auto &expr : b) { - auto entry = map.find(*expr); - // first we check if we can find the expression in the map at all - if (entry == map.end()) { - return false; - } - // if we found it we check the count; if the count is already 0 we return false - // this happens if e.g. 
the left side contains "1 AND X", and the right side contains "1 AND 1" - // "1" is contained in the map, however, the right side contains the expression twice - // hence we know the children are not identical in this case because the LHS and RHS have a different count for - // the Constant(1) expression - if (entry->second == 0) { - return false; - } - entry->second--; - } - return true; -} - -bool ExpressionUtil::ListEquals(const vector> &a, - const vector> &b) { - return ExpressionListEquals(a, b); -} - -bool ExpressionUtil::ListEquals(const vector> &a, const vector> &b) { - return ExpressionListEquals(a, b); -} - -bool ExpressionUtil::SetEquals(const vector> &a, - const vector> &b) { - return ExpressionSetEquals>(a, b); -} - -bool ExpressionUtil::SetEquals(const vector> &a, const vector> &b) { - return ExpressionSetEquals>(a, b); -} - -} // namespace duckdb - - - - -namespace duckdb { - -bool KeywordHelper::IsKeyword(const string &text) { - return Parser::IsKeyword(text); -} - -bool KeywordHelper::RequiresQuotes(const string &text, bool allow_caps) { - for (size_t i = 0; i < text.size(); i++) { - if (i > 0 && (text[i] >= '0' && text[i] <= '9')) { - continue; - } - if (text[i] >= 'a' && text[i] <= 'z') { - continue; - } - if (allow_caps) { - if (text[i] >= 'A' && text[i] <= 'Z') { - continue; - } - } - if (text[i] == '_') { - continue; - } - return true; - } - return IsKeyword(text); -} - -string KeywordHelper::EscapeQuotes(const string &text, char quote) { - return StringUtil::Replace(text, string(1, quote), string(2, quote)); -} - -string KeywordHelper::WriteQuoted(const string &text, char quote) { - // 1. Escapes all occurences of 'quote' by doubling them (escape in SQL) - // 2. 
Adds quotes around the string - return string(1, quote) + EscapeQuotes(text, quote) + string(1, quote); -} - -string KeywordHelper::WriteOptionallyQuoted(const string &text, char quote, bool allow_caps) { - if (!RequiresQuotes(text, allow_caps)) { - return text; - } - return WriteQuoted(text, quote); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -AlterInfo::AlterInfo(AlterType type, string catalog_p, string schema_p, string name_p, OnEntryNotFound if_not_found) - : ParseInfo(TYPE), type(type), if_not_found(if_not_found), catalog(std::move(catalog_p)), - schema(std::move(schema_p)), name(std::move(name_p)), allow_internal(false) { -} - -AlterInfo::AlterInfo(AlterType type) : ParseInfo(TYPE), type(type) { -} - -AlterInfo::~AlterInfo() { -} - -AlterEntryData AlterInfo::GetAlterEntryData() const { - AlterEntryData data; - data.catalog = catalog; - data.schema = schema; - data.name = name; - data.if_not_found = if_not_found; - return data; -} - -} // namespace duckdb - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// AlterScalarFunctionInfo -//===--------------------------------------------------------------------===// -AlterScalarFunctionInfo::AlterScalarFunctionInfo(AlterScalarFunctionType type, AlterEntryData data) - : AlterInfo(AlterType::ALTER_SCALAR_FUNCTION, std::move(data.catalog), std::move(data.schema), std::move(data.name), - data.if_not_found), - alter_scalar_function_type(type) { -} -AlterScalarFunctionInfo::~AlterScalarFunctionInfo() { -} - -CatalogType AlterScalarFunctionInfo::GetCatalogType() const { - return CatalogType::SCALAR_FUNCTION_ENTRY; -} - -//===--------------------------------------------------------------------===// -// AddScalarFunctionOverloadInfo -//===--------------------------------------------------------------------===// -AddScalarFunctionOverloadInfo::AddScalarFunctionOverloadInfo(AlterEntryData data, ScalarFunctionSet new_overloads_p) - : 
AlterScalarFunctionInfo(AlterScalarFunctionType::ADD_FUNCTION_OVERLOADS, std::move(data)), - new_overloads(std::move(new_overloads_p)) { - this->allow_internal = true; -} - -AddScalarFunctionOverloadInfo::~AddScalarFunctionOverloadInfo() { -} - -unique_ptr AddScalarFunctionOverloadInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), new_overloads); -} - -} // namespace duckdb - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// AlterTableFunctionInfo -//===--------------------------------------------------------------------===// -AlterTableFunctionInfo::AlterTableFunctionInfo(AlterTableFunctionType type, AlterEntryData data) - : AlterInfo(AlterType::ALTER_TABLE_FUNCTION, std::move(data.catalog), std::move(data.schema), std::move(data.name), - data.if_not_found), - alter_table_function_type(type) { -} -AlterTableFunctionInfo::~AlterTableFunctionInfo() { -} - -CatalogType AlterTableFunctionInfo::GetCatalogType() const { - return CatalogType::TABLE_FUNCTION_ENTRY; -} - -//===--------------------------------------------------------------------===// -// AddTableFunctionOverloadInfo -//===--------------------------------------------------------------------===// -AddTableFunctionOverloadInfo::AddTableFunctionOverloadInfo(AlterEntryData data, TableFunctionSet new_overloads_p) - : AlterTableFunctionInfo(AlterTableFunctionType::ADD_FUNCTION_OVERLOADS, std::move(data)), - new_overloads(std::move(new_overloads_p)) { - this->allow_internal = true; -} - -AddTableFunctionOverloadInfo::~AddTableFunctionOverloadInfo() { -} - -unique_ptr AddTableFunctionOverloadInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), new_overloads); -} - -} // namespace duckdb - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// ChangeOwnershipInfo -//===--------------------------------------------------------------------===// 
-ChangeOwnershipInfo::ChangeOwnershipInfo(CatalogType entry_catalog_type, string entry_catalog_p, string entry_schema_p, - string entry_name_p, string owner_schema_p, string owner_name_p, - OnEntryNotFound if_not_found) - : AlterInfo(AlterType::CHANGE_OWNERSHIP, std::move(entry_catalog_p), std::move(entry_schema_p), - std::move(entry_name_p), if_not_found), - entry_catalog_type(entry_catalog_type), owner_schema(std::move(owner_schema_p)), - owner_name(std::move(owner_name_p)) { -} - -CatalogType ChangeOwnershipInfo::GetCatalogType() const { - return entry_catalog_type; -} - -unique_ptr ChangeOwnershipInfo::Copy() const { - return make_uniq_base(entry_catalog_type, catalog, schema, name, owner_schema, - owner_name, if_not_found); -} - -//===--------------------------------------------------------------------===// -// AlterTableInfo -//===--------------------------------------------------------------------===// -AlterTableInfo::AlterTableInfo(AlterTableType type) : AlterInfo(AlterType::ALTER_TABLE), alter_table_type(type) { -} - -AlterTableInfo::AlterTableInfo(AlterTableType type, AlterEntryData data) - : AlterInfo(AlterType::ALTER_TABLE, std::move(data.catalog), std::move(data.schema), std::move(data.name), - data.if_not_found), - alter_table_type(type) { -} -AlterTableInfo::~AlterTableInfo() { -} - -CatalogType AlterTableInfo::GetCatalogType() const { - return CatalogType::TABLE_ENTRY; -} -//===--------------------------------------------------------------------===// -// RenameColumnInfo -//===--------------------------------------------------------------------===// -RenameColumnInfo::RenameColumnInfo(AlterEntryData data, string old_name_p, string new_name_p) - : AlterTableInfo(AlterTableType::RENAME_COLUMN, std::move(data)), old_name(std::move(old_name_p)), - new_name(std::move(new_name_p)) { -} - -RenameColumnInfo::RenameColumnInfo() : AlterTableInfo(AlterTableType::RENAME_COLUMN) { -} - -RenameColumnInfo::~RenameColumnInfo() { -} - -unique_ptr 
RenameColumnInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), old_name, new_name); -} - -//===--------------------------------------------------------------------===// -// RenameTableInfo -//===--------------------------------------------------------------------===// -RenameTableInfo::RenameTableInfo() : AlterTableInfo(AlterTableType::RENAME_TABLE) { -} - -RenameTableInfo::RenameTableInfo(AlterEntryData data, string new_name_p) - : AlterTableInfo(AlterTableType::RENAME_TABLE, std::move(data)), new_table_name(std::move(new_name_p)) { -} - -RenameTableInfo::~RenameTableInfo() { -} - -unique_ptr RenameTableInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), new_table_name); -} - -//===--------------------------------------------------------------------===// -// AddColumnInfo -//===--------------------------------------------------------------------===// -AddColumnInfo::AddColumnInfo(ColumnDefinition new_column_p) - : AlterTableInfo(AlterTableType::ADD_COLUMN), new_column(std::move(new_column_p)) { -} - -AddColumnInfo::AddColumnInfo(AlterEntryData data, ColumnDefinition new_column, bool if_column_not_exists) - : AlterTableInfo(AlterTableType::ADD_COLUMN, std::move(data)), new_column(std::move(new_column)), - if_column_not_exists(if_column_not_exists) { -} - -AddColumnInfo::~AddColumnInfo() { -} - -unique_ptr AddColumnInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), new_column.Copy(), if_column_not_exists); -} - -//===--------------------------------------------------------------------===// -// RemoveColumnInfo -//===--------------------------------------------------------------------===// -RemoveColumnInfo::RemoveColumnInfo() : AlterTableInfo(AlterTableType::REMOVE_COLUMN) { -} - -RemoveColumnInfo::RemoveColumnInfo(AlterEntryData data, string removed_column, bool if_column_exists, bool cascade) - : AlterTableInfo(AlterTableType::REMOVE_COLUMN, std::move(data)), removed_column(std::move(removed_column)), - 
if_column_exists(if_column_exists), cascade(cascade) { -} -RemoveColumnInfo::~RemoveColumnInfo() { -} - -unique_ptr RemoveColumnInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), removed_column, if_column_exists, cascade); -} - -//===--------------------------------------------------------------------===// -// ChangeColumnTypeInfo -//===--------------------------------------------------------------------===// -ChangeColumnTypeInfo::ChangeColumnTypeInfo() : AlterTableInfo(AlterTableType::ALTER_COLUMN_TYPE) { -} - -ChangeColumnTypeInfo::ChangeColumnTypeInfo(AlterEntryData data, string column_name, LogicalType target_type, - unique_ptr expression) - : AlterTableInfo(AlterTableType::ALTER_COLUMN_TYPE, std::move(data)), column_name(std::move(column_name)), - target_type(std::move(target_type)), expression(std::move(expression)) { -} -ChangeColumnTypeInfo::~ChangeColumnTypeInfo() { -} - -unique_ptr ChangeColumnTypeInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), column_name, target_type, - expression->Copy()); -} - -//===--------------------------------------------------------------------===// -// SetDefaultInfo -//===--------------------------------------------------------------------===// -SetDefaultInfo::SetDefaultInfo() : AlterTableInfo(AlterTableType::SET_DEFAULT) { -} - -SetDefaultInfo::SetDefaultInfo(AlterEntryData data, string column_name_p, unique_ptr new_default) - : AlterTableInfo(AlterTableType::SET_DEFAULT, std::move(data)), column_name(std::move(column_name_p)), - expression(std::move(new_default)) { -} -SetDefaultInfo::~SetDefaultInfo() { -} - -unique_ptr SetDefaultInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), column_name, - expression ? 
expression->Copy() : nullptr); -} - -//===--------------------------------------------------------------------===// -// SetNotNullInfo -//===--------------------------------------------------------------------===// -SetNotNullInfo::SetNotNullInfo() : AlterTableInfo(AlterTableType::SET_NOT_NULL) { -} - -SetNotNullInfo::SetNotNullInfo(AlterEntryData data, string column_name_p) - : AlterTableInfo(AlterTableType::SET_NOT_NULL, std::move(data)), column_name(std::move(column_name_p)) { -} -SetNotNullInfo::~SetNotNullInfo() { -} - -unique_ptr SetNotNullInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), column_name); -} - -//===--------------------------------------------------------------------===// -// DropNotNullInfo -//===--------------------------------------------------------------------===// -DropNotNullInfo::DropNotNullInfo() : AlterTableInfo(AlterTableType::DROP_NOT_NULL) { -} - -DropNotNullInfo::DropNotNullInfo(AlterEntryData data, string column_name_p) - : AlterTableInfo(AlterTableType::DROP_NOT_NULL, std::move(data)), column_name(std::move(column_name_p)) { -} -DropNotNullInfo::~DropNotNullInfo() { -} - -unique_ptr DropNotNullInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), column_name); -} - -//===--------------------------------------------------------------------===// -// AlterForeignKeyInfo -//===--------------------------------------------------------------------===// -AlterForeignKeyInfo::AlterForeignKeyInfo() : AlterTableInfo(AlterTableType::FOREIGN_KEY_CONSTRAINT) { -} - -AlterForeignKeyInfo::AlterForeignKeyInfo(AlterEntryData data, string fk_table, vector pk_columns, - vector fk_columns, vector pk_keys, - vector fk_keys, AlterForeignKeyType type_p) - : AlterTableInfo(AlterTableType::FOREIGN_KEY_CONSTRAINT, std::move(data)), fk_table(std::move(fk_table)), - pk_columns(std::move(pk_columns)), fk_columns(std::move(fk_columns)), pk_keys(std::move(pk_keys)), - fk_keys(std::move(fk_keys)), type(type_p) { -} 
-AlterForeignKeyInfo::~AlterForeignKeyInfo() { -} - -unique_ptr AlterForeignKeyInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), fk_table, pk_columns, fk_columns, - pk_keys, fk_keys, type); -} - -//===--------------------------------------------------------------------===// -// Alter View -//===--------------------------------------------------------------------===// -AlterViewInfo::AlterViewInfo(AlterViewType type) : AlterInfo(AlterType::ALTER_VIEW), alter_view_type(type) { -} - -AlterViewInfo::AlterViewInfo(AlterViewType type, AlterEntryData data) - : AlterInfo(AlterType::ALTER_VIEW, std::move(data.catalog), std::move(data.schema), std::move(data.name), - data.if_not_found), - alter_view_type(type) { -} -AlterViewInfo::~AlterViewInfo() { -} - -CatalogType AlterViewInfo::GetCatalogType() const { - return CatalogType::VIEW_ENTRY; -} - -//===--------------------------------------------------------------------===// -// RenameViewInfo -//===--------------------------------------------------------------------===// -RenameViewInfo::RenameViewInfo() : AlterViewInfo(AlterViewType::RENAME_VIEW) { -} -RenameViewInfo::RenameViewInfo(AlterEntryData data, string new_name_p) - : AlterViewInfo(AlterViewType::RENAME_VIEW, std::move(data)), new_view_name(std::move(new_name_p)) { -} -RenameViewInfo::~RenameViewInfo() { -} - -unique_ptr RenameViewInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), new_view_name); -} - -} // namespace duckdb - - -namespace duckdb { - -unique_ptr AttachInfo::Copy() const { - auto result = make_uniq(); - result->name = name; - result->path = path; - result->options = options; - return result; -} - -} // namespace duckdb - - -namespace duckdb { - -CreateAggregateFunctionInfo::CreateAggregateFunctionInfo(AggregateFunction function) - : CreateFunctionInfo(CatalogType::AGGREGATE_FUNCTION_ENTRY), functions(function.name) { - name = function.name; - functions.AddFunction(std::move(function)); - internal = true; -} - 
-CreateAggregateFunctionInfo::CreateAggregateFunctionInfo(AggregateFunctionSet set) - : CreateFunctionInfo(CatalogType::AGGREGATE_FUNCTION_ENTRY), functions(std::move(set)) { - name = functions.name; - for (auto &func : functions.functions) { - func.name = functions.name; - } - internal = true; -} - -unique_ptr CreateAggregateFunctionInfo::Copy() const { - auto result = make_uniq(functions); - CopyProperties(*result); - return std::move(result); -} - -} // namespace duckdb - - -namespace duckdb { - -CreateCollationInfo::CreateCollationInfo(string name_p, ScalarFunction function_p, bool combinable_p, - bool not_required_for_equality_p) - : CreateInfo(CatalogType::COLLATION_ENTRY), function(std::move(function_p)), combinable(combinable_p), - not_required_for_equality(not_required_for_equality_p) { - this->name = std::move(name_p); - internal = true; -} - -unique_ptr CreateCollationInfo::Copy() const { - auto result = make_uniq(name, function, combinable, not_required_for_equality); - CopyProperties(*result); - return std::move(result); -} - -} // namespace duckdb - - -namespace duckdb { - -CreateCopyFunctionInfo::CreateCopyFunctionInfo(CopyFunction function_p) - : CreateInfo(CatalogType::COPY_FUNCTION_ENTRY), function(std::move(function_p)) { - this->name = function.name; - internal = true; -} - -unique_ptr CreateCopyFunctionInfo::Copy() const { - auto result = make_uniq(function); - CopyProperties(*result); - return std::move(result); -} - -} // namespace duckdb - - -namespace duckdb { - -unique_ptr CreateIndexInfo::Copy() const { - auto result = make_uniq(); - CopyProperties(*result); - - result->index_type = index_type; - result->index_name = index_name; - result->constraint_type = constraint_type; - result->table = table; - for (auto &expr : expressions) { - result->expressions.push_back(expr->Copy()); - } - for (auto &expr : parsed_expressions) { - result->parsed_expressions.push_back(expr->Copy()); - } - - result->scan_types = scan_types; - result->names = 
names; - result->column_ids = column_ids; - result->options = options; - return std::move(result); -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -void CreateInfo::CopyProperties(CreateInfo &other) const { - other.type = type; - other.catalog = catalog; - other.schema = schema; - other.on_conflict = on_conflict; - other.temporary = temporary; - other.internal = internal; - other.sql = sql; -} - -unique_ptr CreateInfo::GetAlterInfo() const { - throw NotImplementedException("GetAlterInfo not implemented for this type"); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -CreateMacroInfo::CreateMacroInfo(CatalogType type) : CreateFunctionInfo(type, INVALID_SCHEMA) { -} - -unique_ptr CreateMacroInfo::Copy() const { - auto result = make_uniq(type); - result->function = function->Copy(); - result->name = name; - CopyProperties(*result); - return std::move(result); -} - -} // namespace duckdb - - -namespace duckdb { - -CreatePragmaFunctionInfo::CreatePragmaFunctionInfo(PragmaFunction function) - : CreateFunctionInfo(CatalogType::PRAGMA_FUNCTION_ENTRY), functions(function.name) { - name = function.name; - functions.AddFunction(std::move(function)); - internal = true; -} -CreatePragmaFunctionInfo::CreatePragmaFunctionInfo(string name, PragmaFunctionSet functions_p) - : CreateFunctionInfo(CatalogType::PRAGMA_FUNCTION_ENTRY), functions(std::move(functions_p)) { - this->name = std::move(name); - internal = true; -} - -unique_ptr CreatePragmaFunctionInfo::Copy() const { - auto result = make_uniq(functions.name, functions); - CopyProperties(*result); - return std::move(result); -} - -} // namespace duckdb - - - -namespace duckdb { - -CreateScalarFunctionInfo::CreateScalarFunctionInfo(ScalarFunction function) - : CreateFunctionInfo(CatalogType::SCALAR_FUNCTION_ENTRY), functions(function.name) { - name = function.name; - functions.AddFunction(std::move(function)); - internal = true; -} 
-CreateScalarFunctionInfo::CreateScalarFunctionInfo(ScalarFunctionSet set) - : CreateFunctionInfo(CatalogType::SCALAR_FUNCTION_ENTRY), functions(std::move(set)) { - name = functions.name; - for (auto &func : functions.functions) { - func.name = functions.name; - } - internal = true; -} - -unique_ptr CreateScalarFunctionInfo::Copy() const { - ScalarFunctionSet set(name); - set.functions = functions.functions; - auto result = make_uniq(std::move(set)); - CopyProperties(*result); - return std::move(result); -} - -unique_ptr CreateScalarFunctionInfo::GetAlterInfo() const { - return make_uniq_base( - AlterEntryData(catalog, schema, name, OnEntryNotFound::RETURN_NULL), functions); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -CreateSequenceInfo::CreateSequenceInfo() - : CreateInfo(CatalogType::SEQUENCE_ENTRY, INVALID_SCHEMA), name(string()), usage_count(0), increment(1), - min_value(1), max_value(NumericLimits::Maximum()), start_value(1), cycle(false) { -} - -unique_ptr CreateSequenceInfo::Copy() const { - auto result = make_uniq(); - CopyProperties(*result); - result->name = name; - result->schema = schema; - result->usage_count = usage_count; - result->increment = increment; - result->min_value = min_value; - result->max_value = max_value; - result->start_value = start_value; - result->cycle = cycle; - return std::move(result); -} - -} // namespace duckdb - - - -namespace duckdb { - -CreateTableFunctionInfo::CreateTableFunctionInfo(TableFunction function) - : CreateFunctionInfo(CatalogType::TABLE_FUNCTION_ENTRY), functions(function.name) { - name = function.name; - functions.AddFunction(std::move(function)); - internal = true; -} -CreateTableFunctionInfo::CreateTableFunctionInfo(TableFunctionSet set) - : CreateFunctionInfo(CatalogType::TABLE_FUNCTION_ENTRY), functions(std::move(set)) { - name = functions.name; - for (auto &func : functions.functions) { - func.name = functions.name; - } - internal = true; -} - -unique_ptr CreateTableFunctionInfo::Copy() 
const { - TableFunctionSet set(name); - set.functions = functions.functions; - auto result = make_uniq(std::move(set)); - CopyProperties(*result); - return std::move(result); -} - -unique_ptr CreateTableFunctionInfo::GetAlterInfo() const { - return make_uniq_base( - AlterEntryData(catalog, schema, name, OnEntryNotFound::RETURN_NULL), functions); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -CreateTableInfo::CreateTableInfo() : CreateInfo(CatalogType::TABLE_ENTRY, INVALID_SCHEMA) { -} - -CreateTableInfo::CreateTableInfo(string catalog_p, string schema_p, string name_p) - : CreateInfo(CatalogType::TABLE_ENTRY, std::move(schema_p), std::move(catalog_p)), table(std::move(name_p)) { -} - -CreateTableInfo::CreateTableInfo(SchemaCatalogEntry &schema, string name_p) - : CreateTableInfo(schema.catalog.GetName(), schema.name, std::move(name_p)) { -} - -unique_ptr CreateTableInfo::Copy() const { - auto result = make_uniq(catalog, schema, table); - CopyProperties(*result); - result->columns = columns.Copy(); - for (auto &constraint : constraints) { - result->constraints.push_back(constraint->Copy()); - } - if (query) { - result->query = unique_ptr_cast(query->Copy()); - } - return std::move(result); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -CreateTypeInfo::CreateTypeInfo() : CreateInfo(CatalogType::TYPE_ENTRY) { -} -CreateTypeInfo::CreateTypeInfo(string name_p, LogicalType type_p) - : CreateInfo(CatalogType::TYPE_ENTRY), name(std::move(name_p)), type(std::move(type_p)) { -} - -unique_ptr CreateTypeInfo::Copy() const { - auto result = make_uniq(); - CopyProperties(*result); - result->name = name; - result->type = type; - if (query) { - result->query = query->Copy(); - } - return std::move(result); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -CreateViewInfo::CreateViewInfo() : CreateInfo(CatalogType::VIEW_ENTRY, INVALID_SCHEMA) { -} -CreateViewInfo::CreateViewInfo(string catalog_p, string schema_p, string view_name_p) - : 
CreateInfo(CatalogType::VIEW_ENTRY, std::move(schema_p), std::move(catalog_p)), - view_name(std::move(view_name_p)) { -} - -CreateViewInfo::CreateViewInfo(SchemaCatalogEntry &schema, string view_name) - : CreateViewInfo(schema.catalog.GetName(), schema.name, std::move(view_name)) { -} - -unique_ptr CreateViewInfo::Copy() const { - auto result = make_uniq(catalog, schema, view_name); - CopyProperties(*result); - result->aliases = aliases; - result->types = types; - result->query = unique_ptr_cast(query->Copy()); - return std::move(result); -} - -unique_ptr CreateViewInfo::FromSelect(ClientContext &context, unique_ptr info) { - D_ASSERT(info); - D_ASSERT(!info->view_name.empty()); - D_ASSERT(!info->sql.empty()); - D_ASSERT(!info->query); - - Parser parser; - parser.ParseQuery(info->sql); - if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { - throw BinderException( - "Failed to create view from SQL string - \"%s\" - statement did not contain a single SELECT statement", - info->sql); - } - D_ASSERT(parser.statements.size() == 1 && parser.statements[0]->type == StatementType::SELECT_STATEMENT); - info->query = unique_ptr_cast(std::move(parser.statements[0])); - - auto binder = Binder::CreateBinder(context); - binder->BindCreateViewInfo(*info); - - return info; -} - -unique_ptr CreateViewInfo::FromCreateView(ClientContext &context, const string &sql) { - D_ASSERT(!sql.empty()); - - // parse the SQL statement - Parser parser; - parser.ParseQuery(sql); - - if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::CREATE_STATEMENT) { - throw BinderException( - "Failed to create view from SQL string - \"%s\" - statement did not contain a single CREATE VIEW statement", - sql); - } - auto &create_statement = parser.statements[0]->Cast(); - if (create_statement.info->type != CatalogType::VIEW_ENTRY) { - throw BinderException( - "Failed to create view from SQL string - \"%s\" - view did not contain a 
CREATE VIEW statement", sql); - } - - auto result = unique_ptr_cast(std::move(create_statement.info)); - - auto binder = Binder::CreateBinder(context); - binder->BindCreateViewInfo(*result); - - return result; -} - -} // namespace duckdb - - -namespace duckdb { - -DetachInfo::DetachInfo() : ParseInfo(TYPE) { -} - -unique_ptr DetachInfo::Copy() const { - auto result = make_uniq(); - result->name = name; - result->if_not_found = if_not_found; - return result; -} - -} // namespace duckdb - - -namespace duckdb { - -DropInfo::DropInfo() : ParseInfo(TYPE), catalog(INVALID_CATALOG), schema(INVALID_SCHEMA), cascade(false) { -} - -unique_ptr DropInfo::Copy() const { - auto result = make_uniq(); - result->type = type; - result->catalog = catalog; - result->schema = schema; - result->name = name; - result->if_not_found = if_not_found; - result->cascade = cascade; - result->allow_drop_internal = allow_drop_internal; - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -// **DEPRECATED**: Use EnumUtil directly instead. 
-string SampleMethodToString(SampleMethod method) { - return EnumUtil::ToString(method); -} - -unique_ptr SampleOptions::Copy() { - auto result = make_uniq(); - result->sample_size = sample_size; - result->is_percentage = is_percentage; - result->method = method; - result->seed = seed; - return result; -} - -bool SampleOptions::Equals(SampleOptions *a, SampleOptions *b) { - if (a == b) { - return true; - } - if (!a || !b) { - return false; - } - if (a->sample_size != b->sample_size || a->is_percentage != b->is_percentage || a->method != b->method || - a->seed != b->seed) { - return false; - } - return true; -} - -} // namespace duckdb - - -namespace duckdb { - -TransactionInfo::TransactionInfo() : ParseInfo(TYPE) { -} - -TransactionInfo::TransactionInfo(TransactionType type) : ParseInfo(TYPE), type(type) { -} - -} // namespace duckdb - - -namespace duckdb { - -VacuumInfo::VacuumInfo(VacuumOptions options) : ParseInfo(TYPE), options(options), has_table(false) { -} - -unique_ptr VacuumInfo::Copy() { - auto result = make_uniq(options); - result->has_table = has_table; - if (has_table) { - result->ref = ref->Copy(); - } - return result; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -bool ParsedExpression::IsAggregate() const { - bool is_aggregate = false; - ParsedExpressionIterator::EnumerateChildren( - *this, [&](const ParsedExpression &child) { is_aggregate |= child.IsAggregate(); }); - return is_aggregate; -} - -bool ParsedExpression::IsWindow() const { - bool is_window = false; - ParsedExpressionIterator::EnumerateChildren(*this, - [&](const ParsedExpression &child) { is_window |= child.IsWindow(); }); - return is_window; -} - -bool ParsedExpression::IsScalar() const { - bool is_scalar = true; - ParsedExpressionIterator::EnumerateChildren(*this, [&](const ParsedExpression &child) { - if (!child.IsScalar()) { - is_scalar = false; - } - }); - return is_scalar; -} - -bool ParsedExpression::HasParameter() const { - bool has_parameter = false; - 
ParsedExpressionIterator::EnumerateChildren( - *this, [&](const ParsedExpression &child) { has_parameter |= child.HasParameter(); }); - return has_parameter; -} - -bool ParsedExpression::HasSubquery() const { - bool has_subquery = false; - ParsedExpressionIterator::EnumerateChildren( - *this, [&](const ParsedExpression &child) { has_subquery |= child.HasSubquery(); }); - return has_subquery; -} - -bool ParsedExpression::Equals(const BaseExpression &other) const { - if (!BaseExpression::Equals(other)) { - return false; - } - switch (expression_class) { - case ExpressionClass::BETWEEN: - return BetweenExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::CASE: - return CaseExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::CAST: - return CastExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::COLLATE: - return CollateExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::COLUMN_REF: - return ColumnRefExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::COMPARISON: - return ComparisonExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::CONJUNCTION: - return ConjunctionExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::CONSTANT: - return ConstantExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::DEFAULT: - return true; - case ExpressionClass::FUNCTION: - return FunctionExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::LAMBDA: - return LambdaExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::OPERATOR: - return OperatorExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::PARAMETER: - return ParameterExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::POSITIONAL_REFERENCE: - return PositionalReferenceExpression::Equal(Cast(), - other.Cast()); - case ExpressionClass::STAR: - return StarExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::SUBQUERY: - return SubqueryExpression::Equal(Cast(), other.Cast()); - 
case ExpressionClass::WINDOW: - return WindowExpression::Equal(Cast(), other.Cast()); - default: - throw SerializationException("Unsupported type for expression comparison!"); - } -} - -hash_t ParsedExpression::Hash() const { - hash_t hash = duckdb::Hash((uint32_t)type); - ParsedExpressionIterator::EnumerateChildren( - *this, [&](const ParsedExpression &child) { hash = CombineHash(child.Hash(), hash); }); - return hash; -} - -bool ParsedExpression::Equals(const unique_ptr &left, const unique_ptr &right) { - if (left.get() == right.get()) { - return true; - } - if (!left || !right) { - return false; - } - return left->Equals(*right); -} - -bool ParsedExpression::ListEquals(const vector> &left, - const vector> &right) { - return ExpressionUtil::ListEquals(left, right); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -void ParsedExpressionIterator::EnumerateChildren(const ParsedExpression &expression, - const std::function &callback) { - EnumerateChildren((ParsedExpression &)expression, [&](unique_ptr &child) { - D_ASSERT(child); - callback(*child); - }); -} - -void ParsedExpressionIterator::EnumerateChildren(ParsedExpression &expr, - const std::function &callback) { - EnumerateChildren(expr, [&](unique_ptr &child) { - D_ASSERT(child); - callback(*child); - }); -} - -void ParsedExpressionIterator::EnumerateChildren( - ParsedExpression &expr, const std::function &child)> &callback) { - switch (expr.expression_class) { - case ExpressionClass::BETWEEN: { - auto &cast_expr = expr.Cast(); - callback(cast_expr.input); - callback(cast_expr.lower); - callback(cast_expr.upper); - break; - } - case ExpressionClass::CASE: { - auto &case_expr = expr.Cast(); - for (auto &check : case_expr.case_checks) { - callback(check.when_expr); - callback(check.then_expr); - } - callback(case_expr.else_expr); - break; - } - case ExpressionClass::CAST: { - auto &cast_expr = expr.Cast(); - callback(cast_expr.child); - break; - } - case ExpressionClass::COLLATE: { - auto 
&cast_expr = expr.Cast(); - callback(cast_expr.child); - break; - } - case ExpressionClass::COMPARISON: { - auto &comp_expr = expr.Cast(); - callback(comp_expr.left); - callback(comp_expr.right); - break; - } - case ExpressionClass::CONJUNCTION: { - auto &conj_expr = expr.Cast(); - for (auto &child : conj_expr.children) { - callback(child); - } - break; - } - - case ExpressionClass::FUNCTION: { - auto &func_expr = expr.Cast(); - for (auto &child : func_expr.children) { - callback(child); - } - if (func_expr.filter) { - callback(func_expr.filter); - } - if (func_expr.order_bys) { - for (auto &order : func_expr.order_bys->orders) { - callback(order.expression); - } - } - break; - } - case ExpressionClass::LAMBDA: { - auto &lambda_expr = expr.Cast(); - callback(lambda_expr.lhs); - callback(lambda_expr.expr); - break; - } - case ExpressionClass::OPERATOR: { - auto &op_expr = expr.Cast(); - for (auto &child : op_expr.children) { - callback(child); - } - break; - } - case ExpressionClass::STAR: { - auto &star_expr = expr.Cast(); - if (star_expr.expr) { - callback(star_expr.expr); - } - break; - } - case ExpressionClass::SUBQUERY: { - auto &subquery_expr = expr.Cast(); - if (subquery_expr.child) { - callback(subquery_expr.child); - } - break; - } - case ExpressionClass::WINDOW: { - auto &window_expr = expr.Cast(); - for (auto &partition : window_expr.partitions) { - callback(partition); - } - for (auto &order : window_expr.orders) { - callback(order.expression); - } - for (auto &child : window_expr.children) { - callback(child); - } - if (window_expr.filter_expr) { - callback(window_expr.filter_expr); - } - if (window_expr.start_expr) { - callback(window_expr.start_expr); - } - if (window_expr.end_expr) { - callback(window_expr.end_expr); - } - if (window_expr.offset_expr) { - callback(window_expr.offset_expr); - } - if (window_expr.default_expr) { - callback(window_expr.default_expr); - } - break; - } - case ExpressionClass::BOUND_EXPRESSION: - case 
ExpressionClass::COLUMN_REF: - case ExpressionClass::CONSTANT: - case ExpressionClass::DEFAULT: - case ExpressionClass::PARAMETER: - case ExpressionClass::POSITIONAL_REFERENCE: - // these node types have no children - break; - default: - // called on non ParsedExpression type! - throw NotImplementedException("Unimplemented expression class"); - } -} - -void ParsedExpressionIterator::EnumerateQueryNodeModifiers( - QueryNode &node, const std::function &child)> &callback) { - - for (auto &modifier : node.modifiers) { - switch (modifier->type) { - case ResultModifierType::LIMIT_MODIFIER: { - auto &limit_modifier = modifier->Cast(); - if (limit_modifier.limit) { - callback(limit_modifier.limit); - } - if (limit_modifier.offset) { - callback(limit_modifier.offset); - } - } break; - - case ResultModifierType::LIMIT_PERCENT_MODIFIER: { - auto &limit_modifier = modifier->Cast(); - if (limit_modifier.limit) { - callback(limit_modifier.limit); - } - if (limit_modifier.offset) { - callback(limit_modifier.offset); - } - } break; - - case ResultModifierType::ORDER_MODIFIER: { - auto &order_modifier = modifier->Cast(); - for (auto &order : order_modifier.orders) { - callback(order.expression); - } - } break; - - case ResultModifierType::DISTINCT_MODIFIER: { - auto &distinct_modifier = modifier->Cast(); - for (auto &target : distinct_modifier.distinct_on_targets) { - callback(target); - } - } break; - - // do nothing - default: - break; - } - } -} - -void ParsedExpressionIterator::EnumerateTableRefChildren( - TableRef &ref, const std::function &child)> &callback) { - switch (ref.type) { - case TableReferenceType::EXPRESSION_LIST: { - auto &el_ref = ref.Cast(); - for (idx_t i = 0; i < el_ref.values.size(); i++) { - for (idx_t j = 0; j < el_ref.values[i].size(); j++) { - callback(el_ref.values[i][j]); - } - } - break; - } - case TableReferenceType::JOIN: { - auto &j_ref = ref.Cast(); - EnumerateTableRefChildren(*j_ref.left, callback); - EnumerateTableRefChildren(*j_ref.right, 
callback); - if (j_ref.condition) { - callback(j_ref.condition); - } - break; - } - case TableReferenceType::PIVOT: { - auto &p_ref = ref.Cast(); - EnumerateTableRefChildren(*p_ref.source, callback); - for (auto &aggr : p_ref.aggregates) { - callback(aggr); - } - break; - } - case TableReferenceType::SUBQUERY: { - auto &sq_ref = ref.Cast(); - EnumerateQueryNodeChildren(*sq_ref.subquery->node, callback); - break; - } - case TableReferenceType::TABLE_FUNCTION: { - auto &tf_ref = ref.Cast(); - callback(tf_ref.function); - break; - } - case TableReferenceType::BASE_TABLE: - case TableReferenceType::EMPTY: - // these TableRefs do not need to be unfolded - break; - case TableReferenceType::INVALID: - case TableReferenceType::CTE: - throw NotImplementedException("TableRef type not implemented for traversal"); - } -} - -void ParsedExpressionIterator::EnumerateQueryNodeChildren( - QueryNode &node, const std::function &child)> &callback) { - switch (node.type) { - case QueryNodeType::RECURSIVE_CTE_NODE: { - auto &rcte_node = node.Cast(); - EnumerateQueryNodeChildren(*rcte_node.left, callback); - EnumerateQueryNodeChildren(*rcte_node.right, callback); - break; - } - case QueryNodeType::CTE_NODE: { - auto &cte_node = node.Cast(); - EnumerateQueryNodeChildren(*cte_node.query, callback); - EnumerateQueryNodeChildren(*cte_node.child, callback); - break; - } - case QueryNodeType::SELECT_NODE: { - auto &sel_node = node.Cast(); - for (idx_t i = 0; i < sel_node.select_list.size(); i++) { - callback(sel_node.select_list[i]); - } - for (idx_t i = 0; i < sel_node.groups.group_expressions.size(); i++) { - callback(sel_node.groups.group_expressions[i]); - } - if (sel_node.where_clause) { - callback(sel_node.where_clause); - } - if (sel_node.having) { - callback(sel_node.having); - } - if (sel_node.qualify) { - callback(sel_node.qualify); - } - - EnumerateTableRefChildren(*sel_node.from_table.get(), callback); - break; - } - case QueryNodeType::SET_OPERATION_NODE: { - auto &setop_node = 
node.Cast(); - EnumerateQueryNodeChildren(*setop_node.left, callback); - EnumerateQueryNodeChildren(*setop_node.right, callback); - break; - } - default: - throw NotImplementedException("QueryNode type not implemented for traversal"); - } - - if (!node.modifiers.empty()) { - EnumerateQueryNodeModifiers(node, callback); - } - - for (auto &kv : node.cte_map.map) { - EnumerateQueryNodeChildren(*kv.second->query->node, callback); - } -} - -} // namespace duckdb - - - - - - - - - - - - - - - - -namespace duckdb { - -Parser::Parser(ParserOptions options_p) : options(options_p) { -} - -struct UnicodeSpace { - UnicodeSpace(idx_t pos, idx_t bytes) : pos(pos), bytes(bytes) { - } - - idx_t pos; - idx_t bytes; -}; - -static bool ReplaceUnicodeSpaces(const string &query, string &new_query, vector &unicode_spaces) { - if (unicode_spaces.empty()) { - // no unicode spaces found - return false; - } - idx_t prev = 0; - for (auto &usp : unicode_spaces) { - new_query += query.substr(prev, usp.pos - prev); - new_query += " "; - prev = usp.pos + usp.bytes; - } - new_query += query.substr(prev, query.size() - prev); - return true; -} - -// This function strips unicode space characters from the query and replaces them with regular spaces -// It returns true if any unicode space characters were found and stripped -// See here for a list of unicode space characters - https://jkorpela.fi/chars/spaces.html -bool Parser::StripUnicodeSpaces(const string &query_str, string &new_query) { - const idx_t NBSP_LEN = 2; - const idx_t USP_LEN = 3; - idx_t pos = 0; - unsigned char quote; - vector unicode_spaces; - auto query = const_uchar_ptr_cast(query_str.c_str()); - auto qsize = query_str.size(); - -regular: - for (; pos + 2 < qsize; pos++) { - if (query[pos] == 0xC2) { - if (query[pos + 1] == 0xA0) { - // U+00A0 - C2A0 - unicode_spaces.emplace_back(pos, NBSP_LEN); - } - } - if (query[pos] == 0xE2) { - if (query[pos + 1] == 0x80) { - if (query[pos + 2] >= 0x80 && query[pos + 2] <= 0x8B) { - // U+2000 
to U+200B - // E28080 - E2808B - unicode_spaces.emplace_back(pos, USP_LEN); - } else if (query[pos + 2] == 0xAF) { - // U+202F - E280AF - unicode_spaces.emplace_back(pos, USP_LEN); - } - } else if (query[pos + 1] == 0x81) { - if (query[pos + 2] == 0x9F) { - // U+205F - E2819f - unicode_spaces.emplace_back(pos, USP_LEN); - } else if (query[pos + 2] == 0xA0) { - // U+2060 - E281A0 - unicode_spaces.emplace_back(pos, USP_LEN); - } - } - } else if (query[pos] == 0xE3) { - if (query[pos + 1] == 0x80 && query[pos + 2] == 0x80) { - // U+3000 - E38080 - unicode_spaces.emplace_back(pos, USP_LEN); - } - } else if (query[pos] == 0xEF) { - if (query[pos + 1] == 0xBB && query[pos + 2] == 0xBF) { - // U+FEFF - EFBBBF - unicode_spaces.emplace_back(pos, USP_LEN); - } - } else if (query[pos] == '"' || query[pos] == '\'') { - quote = query[pos]; - pos++; - goto in_quotes; - } else if (query[pos] == '-' && query[pos + 1] == '-') { - goto in_comment; - } - } - goto end; -in_quotes: - for (; pos + 1 < qsize; pos++) { - if (query[pos] == quote) { - if (query[pos + 1] == quote) { - // escaped quote - pos++; - continue; - } - pos++; - goto regular; - } - } - goto end; -in_comment: - for (; pos < qsize; pos++) { - if (query[pos] == '\n' || query[pos] == '\r') { - goto regular; - } - } - goto end; -end: - return ReplaceUnicodeSpaces(query_str, new_query, unicode_spaces); -} - -vector SplitQueryStringIntoStatements(const string &query) { - // Break sql string down into sql statements using the tokenizer - vector query_statements; - auto tokens = Parser::Tokenize(query); - auto next_statement_start = 0; - for (idx_t i = 1; i < tokens.size(); ++i) { - auto &t_prev = tokens[i - 1]; - auto &t = tokens[i]; - if (t_prev.type == SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR) { - // LCOV_EXCL_START - for (idx_t c = t_prev.start; c <= t.start; ++c) { - if (query.c_str()[c] == ';') { - query_statements.emplace_back(query.substr(next_statement_start, t.start - next_statement_start)); - 
next_statement_start = tokens[i].start; - } - } - // LCOV_EXCL_STOP - } - } - query_statements.emplace_back(query.substr(next_statement_start, query.size() - next_statement_start)); - return query_statements; -} - -void Parser::ParseQuery(const string &query) { - Transformer transformer(options); - string parser_error; - { - // check if there are any unicode spaces in the string - string new_query; - if (StripUnicodeSpaces(query, new_query)) { - // there are - strip the unicode spaces and re-run the query - ParseQuery(new_query); - return; - } - } - { - PostgresParser::SetPreserveIdentifierCase(options.preserve_identifier_case); - bool parsing_succeed = false; - // Creating a new scope to prevent multiple PostgresParser destructors being called - // which led to some memory issues - { - PostgresParser parser; - parser.Parse(query); - if (parser.success) { - if (!parser.parse_tree) { - // empty statement - return; - } - - // if it succeeded, we transform the Postgres parse tree into a list of - // SQLStatements - transformer.TransformParseTree(parser.parse_tree, statements); - parsing_succeed = true; - } else { - parser_error = QueryErrorContext::Format(query, parser.error_message, parser.error_location - 1); - } - } - // If DuckDB fails to parse the entire sql string, break the string down into individual statements - // using ';' as the delimiter so that parser extensions can parse the statement - if (parsing_succeed) { - // no-op - // return here would require refactoring into another function. o.w. 
will just no-op in order to run wrap up - // code at the end of this function - } else if (!options.extensions || options.extensions->empty()) { - throw ParserException(parser_error); - } else { - // split sql string into statements and re-parse using extension - auto query_statements = SplitQueryStringIntoStatements(query); - auto stmt_loc = 0; - for (auto const &query_statement : query_statements) { - string another_parser_error; - // Creating a new scope to allow extensions to use PostgresParser, which is not reentrant - { - PostgresParser another_parser; - another_parser.Parse(query_statement); - // LCOV_EXCL_START - // first see if DuckDB can parse this individual query statement - if (another_parser.success) { - if (!another_parser.parse_tree) { - // empty statement - continue; - } - transformer.TransformParseTree(another_parser.parse_tree, statements); - // important to set in the case of a mixture of DDB and parser ext statements - statements.back()->stmt_length = query_statement.size() - 1; - statements.back()->stmt_location = stmt_loc; - stmt_loc += query_statement.size(); - continue; - } else { - another_parser_error = QueryErrorContext::Format(query, another_parser.error_message, - another_parser.error_location - 1); - } - } // LCOV_EXCL_STOP - // LCOV_EXCL_START - // let extensions parse the statement which DuckDB failed to parse - bool parsed_single_statement = false; - for (auto &ext : *options.extensions) { - D_ASSERT(!parsed_single_statement); - D_ASSERT(ext.parse_function); - auto result = ext.parse_function(ext.parser_info.get(), query_statement); - if (result.type == ParserExtensionResultType::PARSE_SUCCESSFUL) { - auto statement = make_uniq(ext, std::move(result.parse_data)); - statement->stmt_length = query_statement.size() - 1; - statement->stmt_location = stmt_loc; - stmt_loc += query_statement.size(); - statements.push_back(std::move(statement)); - parsed_single_statement = true; - break; - } else if (result.type == 
ParserExtensionResultType::DISPLAY_EXTENSION_ERROR) { - throw ParserException(result.error); - } else { - // We move to the next one! - } - } - if (!parsed_single_statement) { - throw ParserException(parser_error); - } // LCOV_EXCL_STOP - } - } - } - if (!statements.empty()) { - auto &last_statement = statements.back(); - last_statement->stmt_length = query.size() - last_statement->stmt_location; - for (auto &statement : statements) { - statement->query = query; - if (statement->type == StatementType::CREATE_STATEMENT) { - auto &create = statement->Cast(); - create.info->sql = query.substr(statement->stmt_location, statement->stmt_length); - } - } - } -} - -vector Parser::Tokenize(const string &query) { - auto pg_tokens = PostgresParser::Tokenize(query); - vector result; - result.reserve(pg_tokens.size()); - for (auto &pg_token : pg_tokens) { - SimplifiedToken token; - switch (pg_token.type) { - case duckdb_libpgquery::PGSimplifiedTokenType::PG_SIMPLIFIED_TOKEN_IDENTIFIER: - token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_IDENTIFIER; - break; - case duckdb_libpgquery::PGSimplifiedTokenType::PG_SIMPLIFIED_TOKEN_NUMERIC_CONSTANT: - token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_NUMERIC_CONSTANT; - break; - case duckdb_libpgquery::PGSimplifiedTokenType::PG_SIMPLIFIED_TOKEN_STRING_CONSTANT: - token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_STRING_CONSTANT; - break; - case duckdb_libpgquery::PGSimplifiedTokenType::PG_SIMPLIFIED_TOKEN_OPERATOR: - token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR; - break; - case duckdb_libpgquery::PGSimplifiedTokenType::PG_SIMPLIFIED_TOKEN_KEYWORD: - token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_KEYWORD; - break; - // comments are not supported by our tokenizer right now - case duckdb_libpgquery::PGSimplifiedTokenType::PG_SIMPLIFIED_TOKEN_COMMENT: // LCOV_EXCL_START - token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_COMMENT; - break; - default: - throw InternalException("Unrecognized token category"); - } // 
LCOV_EXCL_STOP - token.start = pg_token.start; - result.push_back(token); - } - return result; -} - -bool Parser::IsKeyword(const string &text) { - return PostgresParser::IsKeyword(text); -} - -vector Parser::KeywordList() { - auto keywords = PostgresParser::KeywordList(); - vector result; - for (auto &kw : keywords) { - ParserKeyword res; - res.name = kw.text; - switch (kw.category) { - case duckdb_libpgquery::PGKeywordCategory::PG_KEYWORD_RESERVED: - res.category = KeywordCategory::KEYWORD_RESERVED; - break; - case duckdb_libpgquery::PGKeywordCategory::PG_KEYWORD_UNRESERVED: - res.category = KeywordCategory::KEYWORD_UNRESERVED; - break; - case duckdb_libpgquery::PGKeywordCategory::PG_KEYWORD_TYPE_FUNC: - res.category = KeywordCategory::KEYWORD_TYPE_FUNC; - break; - case duckdb_libpgquery::PGKeywordCategory::PG_KEYWORD_COL_NAME: - res.category = KeywordCategory::KEYWORD_COL_NAME; - break; - default: - throw InternalException("Unrecognized keyword category"); - } - result.push_back(res); - } - return result; -} - -vector> Parser::ParseExpressionList(const string &select_list, ParserOptions options) { - // construct a mock query prefixed with SELECT - string mock_query = "SELECT " + select_list; - // parse the query - Parser parser(options); - parser.ParseQuery(mock_query); - // check the statements - if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { - throw ParserException("Expected a single SELECT statement"); - } - auto &select = parser.statements[0]->Cast(); - if (select.node->type != QueryNodeType::SELECT_NODE) { - throw ParserException("Expected a single SELECT node"); - } - auto &select_node = select.node->Cast(); - return std::move(select_node.select_list); -} - -GroupByNode Parser::ParseGroupByList(const string &group_by, ParserOptions options) { - // construct a mock SELECT query with our group_by expressions - string mock_query = StringUtil::Format("SELECT 42 GROUP BY %s", group_by); - // parse the query 
- Parser parser(options); - parser.ParseQuery(mock_query); - // check the result - if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { - throw ParserException("Expected a single SELECT statement"); - } - auto &select = parser.statements[0]->Cast(); - D_ASSERT(select.node->type == QueryNodeType::SELECT_NODE); - auto &select_node = select.node->Cast(); - return std::move(select_node.groups); -} - -vector Parser::ParseOrderList(const string &select_list, ParserOptions options) { - // construct a mock query - string mock_query = "SELECT * FROM tbl ORDER BY " + select_list; - // parse the query - Parser parser(options); - parser.ParseQuery(mock_query); - // check the statements - if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { - throw ParserException("Expected a single SELECT statement"); - } - auto &select = parser.statements[0]->Cast(); - D_ASSERT(select.node->type == QueryNodeType::SELECT_NODE); - auto &select_node = select.node->Cast(); - if (select_node.modifiers.empty() || select_node.modifiers[0]->type != ResultModifierType::ORDER_MODIFIER || - select_node.modifiers.size() != 1) { - throw ParserException("Expected a single ORDER clause"); - } - auto &order = select_node.modifiers[0]->Cast(); - return std::move(order.orders); -} - -void Parser::ParseUpdateList(const string &update_list, vector &update_columns, - vector> &expressions, ParserOptions options) { - // construct a mock query - string mock_query = "UPDATE tbl SET " + update_list; - // parse the query - Parser parser(options); - parser.ParseQuery(mock_query); - // check the statements - if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::UPDATE_STATEMENT) { - throw ParserException("Expected a single UPDATE statement"); - } - auto &update = parser.statements[0]->Cast(); - update_columns = std::move(update.set_info->columns); - expressions = 
std::move(update.set_info->expressions); -} - -vector>> Parser::ParseValuesList(const string &value_list, ParserOptions options) { - // construct a mock query - string mock_query = "VALUES " + value_list; - // parse the query - Parser parser(options); - parser.ParseQuery(mock_query); - // check the statements - if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { - throw ParserException("Expected a single SELECT statement"); - } - auto &select = parser.statements[0]->Cast(); - if (select.node->type != QueryNodeType::SELECT_NODE) { - throw ParserException("Expected a single SELECT node"); - } - auto &select_node = select.node->Cast(); - if (!select_node.from_table || select_node.from_table->type != TableReferenceType::EXPRESSION_LIST) { - throw ParserException("Expected a single VALUES statement"); - } - auto &values_list = select_node.from_table->Cast(); - return std::move(values_list.values); -} - -ColumnList Parser::ParseColumnList(const string &column_list, ParserOptions options) { - string mock_query = "CREATE TABLE blabla (" + column_list + ")"; - Parser parser(options); - parser.ParseQuery(mock_query); - if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::CREATE_STATEMENT) { - throw ParserException("Expected a single CREATE statement"); - } - auto &create = parser.statements[0]->Cast(); - if (create.info->type != CatalogType::TABLE_ENTRY) { - throw InternalException("Expected a single CREATE TABLE statement"); - } - auto &info = create.info->Cast(); - return std::move(info.columns); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -string QueryErrorContext::Format(const string &query, const string &error_message, int error_loc) { - if (error_loc < 0 || size_t(error_loc) >= query.size()) { - // no location in query provided - return error_message; - } - idx_t error_location = idx_t(error_loc); - // count the line numbers until the error location - // and set the start 
position as the first character of that line - idx_t start_pos = 0; - idx_t line_number = 1; - for (idx_t i = 0; i < error_location; i++) { - if (StringUtil::CharacterIsNewline(query[i])) { - line_number++; - start_pos = i + 1; - } - } - // now find either the next newline token after the query, or find the end of string - // this is the initial end position - idx_t end_pos = query.size(); - for (idx_t i = error_location; i < query.size(); i++) { - if (StringUtil::CharacterIsNewline(query[i])) { - end_pos = i; - break; - } - } - // now start scanning from the start pos - // we want to figure out the start and end pos of what we are going to render - // we want to render at most 80 characters in total, with the error_location located in the middle - const char *buf = query.c_str() + start_pos; - idx_t len = end_pos - start_pos; - vector render_widths; - vector positions; - if (Utf8Proc::IsValid(buf, len)) { - // for unicode awareness, we traverse the graphemes of the current line and keep track of their render widths - // and of their position in the string - for (idx_t cpos = 0; cpos < len;) { - auto char_render_width = Utf8Proc::RenderWidth(buf, len, cpos); - positions.push_back(cpos); - render_widths.push_back(char_render_width); - cpos = Utf8Proc::NextGraphemeCluster(buf, len, cpos); - } - } else { // LCOV_EXCL_START - // invalid utf-8, we can't do much at this point - // we just assume every character is a character, and every character has a render width of 1 - for (idx_t cpos = 0; cpos < len; cpos++) { - positions.push_back(cpos); - render_widths.push_back(1); - } - } // LCOV_EXCL_STOP - // now we want to find the (unicode aware) start and end position - idx_t epos = 0; - // start by finding the error location inside the array - for (idx_t i = 0; i < positions.size(); i++) { - if (positions[i] >= (error_location - start_pos)) { - epos = i; - break; - } - } - bool truncate_beginning = false; - bool truncate_end = false; - idx_t spos = 0; - // now we iterate 
backwards from the error location - // we show max 40 render width before the error location - idx_t current_render_width = 0; - for (idx_t i = epos; i > 0; i--) { - current_render_width += render_widths[i]; - if (current_render_width >= 40) { - truncate_beginning = true; - start_pos = positions[i]; - spos = i; - break; - } - } - // now do the same, but going forward - current_render_width = 0; - for (idx_t i = epos; i < positions.size(); i++) { - current_render_width += render_widths[i]; - if (current_render_width >= 40) { - truncate_end = true; - end_pos = positions[i]; - break; - } - } - string line_indicator = "LINE " + to_string(line_number) + ": "; - string begin_trunc = truncate_beginning ? "..." : ""; - string end_trunc = truncate_end ? "..." : ""; - - // get the render width of the error indicator (i.e. how many spaces we need to insert before the ^) - idx_t error_render_width = 0; - for (idx_t i = spos; i < epos; i++) { - error_render_width += render_widths[i]; - } - error_render_width += line_indicator.size() + begin_trunc.size(); - - // now first print the error message plus the current line (or a subset of the line) - string result = error_message; - result += "\n" + line_indicator + begin_trunc + query.substr(start_pos, end_pos - start_pos) + end_trunc; - // print an arrow pointing at the error location - result += "\n" + string(error_render_width, ' ') + "^"; - return result; -} - -string QueryErrorContext::FormatErrorRecursive(const string &msg, vector &values) { - string error_message = values.empty() ? 
msg : ExceptionFormatValue::Format(msg, values); - if (!statement || query_location >= statement->query.size()) { - // no statement provided or query location out of range - return error_message; - } - return Format(statement->query, error_message, query_location); -} - -} // namespace duckdb - - - - -namespace duckdb { - -string CTENode::ToString() const { - string result; - result += child->ToString(); - return result; -} - -bool CTENode::Equals(const QueryNode *other_p) const { - if (!QueryNode::Equals(other_p)) { - return false; - } - if (this == other_p) { - return true; - } - auto &other = other_p->Cast(); - - if (!query->Equals(other.query.get())) { - return false; - } - if (!child->Equals(other.child.get())) { - return false; - } - return true; -} - -unique_ptr CTENode::Copy() const { - auto result = make_uniq(); - result->ctename = ctename; - result->query = query->Copy(); - result->child = child->Copy(); - result->aliases = aliases; - this->CopyProperties(*result); - return std::move(result); -} - -} // namespace duckdb - - - - -namespace duckdb { - -string RecursiveCTENode::ToString() const { - string result; - result += "(" + left->ToString() + ")"; - result += " UNION "; - if (union_all) { - result += " ALL "; - } - result += "(" + right->ToString() + ")"; - return result; -} - -bool RecursiveCTENode::Equals(const QueryNode *other_p) const { - if (!QueryNode::Equals(other_p)) { - return false; - } - if (this == other_p) { - return true; - } - auto &other = other_p->Cast(); - - if (other.union_all != union_all) { - return false; - } - if (!left->Equals(other.left.get())) { - return false; - } - if (!right->Equals(other.right.get())) { - return false; - } - return true; -} - -unique_ptr RecursiveCTENode::Copy() const { - auto result = make_uniq(); - result->ctename = ctename; - result->union_all = union_all; - result->left = left->Copy(); - result->right = right->Copy(); - result->aliases = aliases; - this->CopyProperties(*result); - return 
std::move(result); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -SelectNode::SelectNode() - : QueryNode(QueryNodeType::SELECT_NODE), aggregate_handling(AggregateHandling::STANDARD_HANDLING) { -} - -string SelectNode::ToString() const { - string result; - result = cte_map.ToString(); - result += "SELECT "; - - // search for a distinct modifier - for (idx_t modifier_idx = 0; modifier_idx < modifiers.size(); modifier_idx++) { - if (modifiers[modifier_idx]->type == ResultModifierType::DISTINCT_MODIFIER) { - auto &distinct_modifier = modifiers[modifier_idx]->Cast(); - result += "DISTINCT "; - if (!distinct_modifier.distinct_on_targets.empty()) { - result += "ON ("; - for (idx_t k = 0; k < distinct_modifier.distinct_on_targets.size(); k++) { - if (k > 0) { - result += ", "; - } - result += distinct_modifier.distinct_on_targets[k]->ToString(); - } - result += ") "; - } - } - } - for (idx_t i = 0; i < select_list.size(); i++) { - if (i > 0) { - result += ", "; - } - result += select_list[i]->ToString(); - if (!select_list[i]->alias.empty()) { - result += StringUtil::Format(" AS %s", SQLIdentifier(select_list[i]->alias)); - } - } - if (from_table && from_table->type != TableReferenceType::EMPTY) { - result += " FROM " + from_table->ToString(); - } - if (where_clause) { - result += " WHERE " + where_clause->ToString(); - } - if (!groups.grouping_sets.empty()) { - result += " GROUP BY "; - // if we are dealing with multiple grouping sets, we have to add a few additional brackets - bool grouping_sets = groups.grouping_sets.size() > 1; - if (grouping_sets) { - result += "GROUPING SETS ("; - } - for (idx_t i = 0; i < groups.grouping_sets.size(); i++) { - auto &grouping_set = groups.grouping_sets[i]; - if (i > 0) { - result += ","; - } - if (grouping_set.empty()) { - result += "()"; - continue; - } - if (grouping_sets) { - result += "("; - } - bool first = true; - for (auto &grp : grouping_set) { - if (!first) { - result += ", "; - } - result += 
groups.group_expressions[grp]->ToString(); - first = false; - } - if (grouping_sets) { - result += ")"; - } - } - if (grouping_sets) { - result += ")"; - } - } else if (aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { - result += " GROUP BY ALL"; - } - if (having) { - result += " HAVING " + having->ToString(); - } - if (qualify) { - result += " QUALIFY " + qualify->ToString(); - } - if (sample) { - result += " USING SAMPLE "; - result += sample->sample_size.ToString(); - if (sample->is_percentage) { - result += "%"; - } - result += " (" + EnumUtil::ToString(sample->method); - if (sample->seed >= 0) { - result += ", " + std::to_string(sample->seed); - } - result += ")"; - } - return result + ResultModifiersToString(); -} - -bool SelectNode::Equals(const QueryNode *other_p) const { - if (!QueryNode::Equals(other_p)) { - return false; - } - if (this == other_p) { - return true; - } - auto &other = other_p->Cast(); - - // SELECT - if (!ExpressionUtil::ListEquals(select_list, other.select_list)) { - return false; - } - // FROM - if (!TableRef::Equals(from_table, other.from_table)) { - return false; - } - // WHERE - if (!ParsedExpression::Equals(where_clause, other.where_clause)) { - return false; - } - // GROUP BY - if (!ParsedExpression::ListEquals(groups.group_expressions, other.groups.group_expressions)) { - return false; - } - if (groups.grouping_sets != other.groups.grouping_sets) { - return false; - } - if (!SampleOptions::Equals(sample.get(), other.sample.get())) { - return false; - } - // HAVING - if (!ParsedExpression::Equals(having, other.having)) { - return false; - } - // QUALIFY - if (!ParsedExpression::Equals(qualify, other.qualify)) { - return false; - } - return true; -} - -unique_ptr SelectNode::Copy() const { - auto result = make_uniq(); - for (auto &child : select_list) { - result->select_list.push_back(child->Copy()); - } - result->from_table = from_table ? from_table->Copy() : nullptr; - result->where_clause = where_clause ? 
where_clause->Copy() : nullptr; - // groups - for (auto &group : groups.group_expressions) { - result->groups.group_expressions.push_back(group->Copy()); - } - result->groups.grouping_sets = groups.grouping_sets; - result->aggregate_handling = aggregate_handling; - result->having = having ? having->Copy() : nullptr; - result->qualify = qualify ? qualify->Copy() : nullptr; - result->sample = sample ? sample->Copy() : nullptr; - this->CopyProperties(*result); - return std::move(result); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -string SetOperationNode::ToString() const { - string result; - result = cte_map.ToString(); - result += "(" + left->ToString() + ") "; - bool is_distinct = false; - for (idx_t modifier_idx = 0; modifier_idx < modifiers.size(); modifier_idx++) { - if (modifiers[modifier_idx]->type == ResultModifierType::DISTINCT_MODIFIER) { - is_distinct = true; - break; - } - } - - switch (setop_type) { - case SetOperationType::UNION: - result += is_distinct ? "UNION" : "UNION ALL"; - break; - case SetOperationType::UNION_BY_NAME: - result += is_distinct ? 
"UNION BY NAME" : "UNION ALL BY NAME"; - break; - case SetOperationType::EXCEPT: - D_ASSERT(is_distinct); - result += "EXCEPT"; - break; - case SetOperationType::INTERSECT: - D_ASSERT(is_distinct); - result += "INTERSECT"; - break; - default: - throw InternalException("Unsupported set operation type"); - } - result += " (" + right->ToString() + ")"; - return result + ResultModifiersToString(); -} - -bool SetOperationNode::Equals(const QueryNode *other_p) const { - if (!QueryNode::Equals(other_p)) { - return false; - } - if (this == other_p) { - return true; - } - auto &other = other_p->Cast(); - if (setop_type != other.setop_type) { - return false; - } - if (!left->Equals(other.left.get())) { - return false; - } - if (!right->Equals(other.right.get())) { - return false; - } - return true; -} - -unique_ptr SetOperationNode::Copy() const { - auto result = make_uniq(); - result->setop_type = setop_type; - result->left = left->Copy(); - result->right = right->Copy(); - this->CopyProperties(*result); - return std::move(result); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -CommonTableExpressionMap::CommonTableExpressionMap() { -} - -CommonTableExpressionMap CommonTableExpressionMap::Copy() const { - CommonTableExpressionMap res; - for (auto &kv : this->map) { - auto kv_info = make_uniq(); - for (auto &al : kv.second->aliases) { - kv_info->aliases.push_back(al); - } - kv_info->query = unique_ptr_cast(kv.second->query->Copy()); - kv_info->materialized = kv.second->materialized; - res.map[kv.first] = std::move(kv_info); - } - return res; -} - -string CommonTableExpressionMap::ToString() const { - if (map.empty()) { - return string(); - } - // check if there are any recursive CTEs - bool has_recursive = false; - for (auto &kv : map) { - if (kv.second->query->node->type == QueryNodeType::RECURSIVE_CTE_NODE) { - has_recursive = true; - break; - } - } - string result = "WITH "; - if (has_recursive) { - result += "RECURSIVE "; - } - bool first_cte = 
true; - for (auto &kv : map) { - if (!first_cte) { - result += ", "; - } - auto &cte = *kv.second; - result += KeywordHelper::WriteOptionallyQuoted(kv.first); - if (!cte.aliases.empty()) { - result += " ("; - for (idx_t k = 0; k < cte.aliases.size(); k++) { - if (k > 0) { - result += ", "; - } - result += KeywordHelper::WriteOptionallyQuoted(cte.aliases[k]); - } - result += ")"; - } - if (kv.second->materialized == CTEMaterialize::CTE_MATERIALIZE_ALWAYS) { - result += " AS MATERIALIZED ("; - } else if (kv.second->materialized == CTEMaterialize::CTE_MATERIALIZE_NEVER) { - result += " AS NOT MATERIALIZED ("; - } else { - result += " AS ("; - } - result += cte.query->ToString(); - result += ")"; - first_cte = false; - } - return result; -} - -string QueryNode::ResultModifiersToString() const { - string result; - for (idx_t modifier_idx = 0; modifier_idx < modifiers.size(); modifier_idx++) { - auto &modifier = *modifiers[modifier_idx]; - if (modifier.type == ResultModifierType::ORDER_MODIFIER) { - auto &order_modifier = modifier.Cast(); - result += " ORDER BY "; - for (idx_t k = 0; k < order_modifier.orders.size(); k++) { - if (k > 0) { - result += ", "; - } - result += order_modifier.orders[k].ToString(); - } - } else if (modifier.type == ResultModifierType::LIMIT_MODIFIER) { - auto &limit_modifier = modifier.Cast(); - if (limit_modifier.limit) { - result += " LIMIT " + limit_modifier.limit->ToString(); - } - if (limit_modifier.offset) { - result += " OFFSET " + limit_modifier.offset->ToString(); - } - } else if (modifier.type == ResultModifierType::LIMIT_PERCENT_MODIFIER) { - auto &limit_p_modifier = modifier.Cast(); - if (limit_p_modifier.limit) { - result += " LIMIT (" + limit_p_modifier.limit->ToString() + ") %"; - } - if (limit_p_modifier.offset) { - result += " OFFSET " + limit_p_modifier.offset->ToString(); - } - } - } - return result; -} - -bool QueryNode::Equals(const QueryNode *other) const { - if (!other) { - return false; - } - if (this == other) { - 
return true; - } - if (other->type != this->type) { - return false; - } - - if (modifiers.size() != other->modifiers.size()) { - return false; - } - for (idx_t i = 0; i < modifiers.size(); i++) { - if (!modifiers[i]->Equals(*other->modifiers[i])) { - return false; - } - } - // WITH clauses (CTEs) - if (cte_map.map.size() != other->cte_map.map.size()) { - return false; - } - for (auto &entry : cte_map.map) { - auto other_entry = other->cte_map.map.find(entry.first); - if (other_entry == other->cte_map.map.end()) { - return false; - } - if (entry.second->aliases != other_entry->second->aliases) { - return false; - } - if (!entry.second->query->Equals(*other_entry->second->query)) { - return false; - } - } - return other->type == type; -} - -void QueryNode::CopyProperties(QueryNode &other) const { - for (auto &modifier : modifiers) { - other.modifiers.push_back(modifier->Copy()); - } - for (auto &kv : cte_map.map) { - auto kv_info = make_uniq(); - for (auto &al : kv.second->aliases) { - kv_info->aliases.push_back(al); - } - kv_info->query = unique_ptr_cast(kv.second->query->Copy()); - kv_info->materialized = kv.second->materialized; - other.cte_map.map[kv.first] = std::move(kv_info); - } -} - -void QueryNode::AddDistinct() { - // check if we already have a DISTINCT modifier - for (idx_t modifier_idx = modifiers.size(); modifier_idx > 0; modifier_idx--) { - auto &modifier = *modifiers[modifier_idx - 1]; - if (modifier.type == ResultModifierType::DISTINCT_MODIFIER) { - auto &distinct_modifier = modifier.Cast(); - if (distinct_modifier.distinct_on_targets.empty()) { - // we have a DISTINCT without an ON clause - this distinct does not need to be added - return; - } - } else if (modifier.type == ResultModifierType::LIMIT_MODIFIER || - modifier.type == ResultModifierType::LIMIT_PERCENT_MODIFIER) { - // we encountered a LIMIT or LIMIT PERCENT - these change the result of DISTINCT, so we do need to push a - // DISTINCT relation - break; - } - } - 
modifiers.push_back(make_uniq()); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -bool ResultModifier::Equals(const ResultModifier &other) const { - return type == other.type; -} - -bool LimitModifier::Equals(const ResultModifier &other_p) const { - if (!ResultModifier::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (!ParsedExpression::Equals(limit, other.limit)) { - return false; - } - if (!ParsedExpression::Equals(offset, other.offset)) { - return false; - } - return true; -} - -unique_ptr LimitModifier::Copy() const { - auto copy = make_uniq(); - if (limit) { - copy->limit = limit->Copy(); - } - if (offset) { - copy->offset = offset->Copy(); - } - return std::move(copy); -} - -bool DistinctModifier::Equals(const ResultModifier &other_p) const { - if (!ResultModifier::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (!ExpressionUtil::ListEquals(distinct_on_targets, other.distinct_on_targets)) { - return false; - } - return true; -} - -unique_ptr DistinctModifier::Copy() const { - auto copy = make_uniq(); - for (auto &expr : distinct_on_targets) { - copy->distinct_on_targets.push_back(expr->Copy()); - } - return std::move(copy); -} - -bool OrderModifier::Equals(const ResultModifier &other_p) const { - if (!ResultModifier::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (orders.size() != other.orders.size()) { - return false; - } - for (idx_t i = 0; i < orders.size(); i++) { - if (orders[i].type != other.orders[i].type) { - return false; - } - if (!BaseExpression::Equals(*orders[i].expression, *other.orders[i].expression)) { - return false; - } - } - return true; -} - -bool OrderModifier::Equals(const unique_ptr &left, const unique_ptr &right) { - if (left.get() == right.get()) { - return true; - } - if (!left || !right) { - return false; - } - return left->Equals(*right); -} - -unique_ptr OrderModifier::Copy() const { - auto copy = make_uniq(); - for (auto &order : 
orders) { - copy->orders.emplace_back(order.type, order.null_order, order.expression->Copy()); - } - return std::move(copy); -} - -string OrderByNode::ToString() const { - auto str = expression->ToString(); - switch (type) { - case OrderType::ASCENDING: - str += " ASC"; - break; - case OrderType::DESCENDING: - str += " DESC"; - break; - default: - break; - } - - switch (null_order) { - case OrderByNullType::NULLS_FIRST: - str += " NULLS FIRST"; - break; - case OrderByNullType::NULLS_LAST: - str += " NULLS LAST"; - break; - default: - break; - } - return str; -} - -bool LimitPercentModifier::Equals(const ResultModifier &other_p) const { - if (!ResultModifier::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (!ParsedExpression::Equals(limit, other.limit)) { - return false; - } - if (!ParsedExpression::Equals(offset, other.offset)) { - return false; - } - return true; -} - -unique_ptr LimitPercentModifier::Copy() const { - auto copy = make_uniq(); - if (limit) { - copy->limit = limit->Copy(); - } - if (offset) { - copy->offset = offset->Copy(); - } - return std::move(copy); -} - -} // namespace duckdb - - -namespace duckdb { - -AlterStatement::AlterStatement() : SQLStatement(StatementType::ALTER_STATEMENT) { -} - -AlterStatement::AlterStatement(const AlterStatement &other) : SQLStatement(other), info(other.info->Copy()) { -} - -unique_ptr AlterStatement::Copy() const { - return unique_ptr(new AlterStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -AttachStatement::AttachStatement() : SQLStatement(StatementType::ATTACH_STATEMENT) { -} - -AttachStatement::AttachStatement(const AttachStatement &other) : SQLStatement(other), info(other.info->Copy()) { -} - -unique_ptr AttachStatement::Copy() const { - return unique_ptr(new AttachStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -CallStatement::CallStatement() : SQLStatement(StatementType::CALL_STATEMENT) { -} - -CallStatement::CallStatement(const 
CallStatement &other) : SQLStatement(other), function(other.function->Copy()) { -} - -unique_ptr CallStatement::Copy() const { - return unique_ptr(new CallStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -CopyStatement::CopyStatement() : SQLStatement(StatementType::COPY_STATEMENT), info(make_uniq()) { -} - -CopyStatement::CopyStatement(const CopyStatement &other) : SQLStatement(other), info(other.info->Copy()) { - if (other.select_statement) { - select_statement = other.select_statement->Copy(); - } -} - -string CopyStatement::CopyOptionsToString(const string &format, - const case_insensitive_map_t> &options) const { - if (format.empty() && options.empty()) { - return string(); - } - string result; - - result += " ("; - if (!format.empty()) { - result += " FORMAT "; - result += format; - } - for (auto it = options.begin(); it != options.end(); it++) { - if (!format.empty() || it != options.begin()) { - result += ", "; - } - auto &name = it->first; - auto &values = it->second; - - result += name + " "; - if (values.empty()) { - // Options like HEADER don't need an explicit value - // just providing the name already sets it to true - } else if (values.size() == 1) { - result += values[0].ToSQLString(); - } else { - result += "( "; - for (idx_t i = 0; i < values.size(); i++) { - if (i) { - result += ", "; - } - result += values[i].ToSQLString(); - } - result += " )"; - } - } - result += " )"; - return result; -} - -// COPY table-name (c1, c2, ..) -string TablePart(const CopyInfo &info) { - string result; - - if (!info.catalog.empty()) { - result += KeywordHelper::WriteOptionallyQuoted(info.catalog) + "."; - } - if (!info.schema.empty()) { - result += KeywordHelper::WriteOptionallyQuoted(info.schema) + "."; - } - D_ASSERT(!info.table.empty()); - result += KeywordHelper::WriteOptionallyQuoted(info.table); - - // (c1, c2, ..) 
- if (!info.select_list.empty()) { - result += " ("; - for (idx_t i = 0; i < info.select_list.size(); i++) { - if (i > 0) { - result += ", "; - } - result += KeywordHelper::WriteOptionallyQuoted(info.select_list[i]); - } - result += " )"; - } - return result; -} - -string CopyStatement::ToString() const { - string result; - - result += "COPY "; - if (info->is_from) { - D_ASSERT(!select_statement); - result += TablePart(*info); - result += " FROM"; - result += StringUtil::Format(" %s", SQLString(info->file_path)); - result += CopyOptionsToString(info->format, info->options); - } else { - if (select_statement) { - // COPY (select-node) TO ... - result += "(" + select_statement->ToString() + ")"; - } else { - result += TablePart(*info); - } - result += " TO "; - result += StringUtil::Format("%s", SQLString(info->file_path)); - result += CopyOptionsToString(info->format, info->options); - } - return result; -} - -unique_ptr CopyStatement::Copy() const { - return unique_ptr(new CopyStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -CreateStatement::CreateStatement() : SQLStatement(StatementType::CREATE_STATEMENT) { -} - -CreateStatement::CreateStatement(const CreateStatement &other) : SQLStatement(other), info(other.info->Copy()) { -} - -unique_ptr CreateStatement::Copy() const { - return unique_ptr(new CreateStatement(*this)); -} - -} // namespace duckdb - - - -namespace duckdb { - -DeleteStatement::DeleteStatement() : SQLStatement(StatementType::DELETE_STATEMENT) { -} - -DeleteStatement::DeleteStatement(const DeleteStatement &other) : SQLStatement(other), table(other.table->Copy()) { - if (other.condition) { - condition = other.condition->Copy(); - } - for (const auto &using_clause : other.using_clauses) { - using_clauses.push_back(using_clause->Copy()); - } - for (auto &expr : other.returning_list) { - returning_list.emplace_back(expr->Copy()); - } - cte_map = other.cte_map.Copy(); -} - -string DeleteStatement::ToString() const { - string 
result; - result = cte_map.ToString(); - result += "DELETE FROM "; - result += table->ToString(); - if (!using_clauses.empty()) { - result += " USING "; - for (idx_t i = 0; i < using_clauses.size(); i++) { - if (i > 0) { - result += ", "; - } - result += using_clauses[i]->ToString(); - } - } - if (condition) { - result += " WHERE " + condition->ToString(); - } - - if (!returning_list.empty()) { - result += " RETURNING "; - for (idx_t i = 0; i < returning_list.size(); i++) { - if (i > 0) { - result += ", "; - } - result += returning_list[i]->ToString(); - } - } - return result; -} - -unique_ptr DeleteStatement::Copy() const { - return unique_ptr(new DeleteStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -DetachStatement::DetachStatement() : SQLStatement(StatementType::DETACH_STATEMENT) { -} - -DetachStatement::DetachStatement(const DetachStatement &other) : SQLStatement(other), info(other.info->Copy()) { -} - -unique_ptr DetachStatement::Copy() const { - return unique_ptr(new DetachStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -DropStatement::DropStatement() : SQLStatement(StatementType::DROP_STATEMENT), info(make_uniq()) { -} - -DropStatement::DropStatement(const DropStatement &other) : SQLStatement(other), info(other.info->Copy()) { -} - -unique_ptr DropStatement::Copy() const { - return unique_ptr(new DropStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -ExecuteStatement::ExecuteStatement() : SQLStatement(StatementType::EXECUTE_STATEMENT) { -} - -ExecuteStatement::ExecuteStatement(const ExecuteStatement &other) : SQLStatement(other), name(other.name) { - for (const auto &item : other.named_values) { - named_values.emplace(std::make_pair(item.first, item.second->Copy())); - } -} - -unique_ptr ExecuteStatement::Copy() const { - return unique_ptr(new ExecuteStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -ExplainStatement::ExplainStatement(unique_ptr stmt, ExplainType 
explain_type) - : SQLStatement(StatementType::EXPLAIN_STATEMENT), stmt(std::move(stmt)), explain_type(explain_type) { -} - -ExplainStatement::ExplainStatement(const ExplainStatement &other) - : SQLStatement(other), stmt(other.stmt->Copy()), explain_type(other.explain_type) { -} - -unique_ptr ExplainStatement::Copy() const { - return unique_ptr(new ExplainStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -ExportStatement::ExportStatement(unique_ptr info) - : SQLStatement(StatementType::EXPORT_STATEMENT), info(std::move(info)) { -} - -ExportStatement::ExportStatement(const ExportStatement &other) : SQLStatement(other), info(other.info->Copy()) { -} - -unique_ptr ExportStatement::Copy() const { - return unique_ptr(new ExportStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -ExtensionStatement::ExtensionStatement(ParserExtension extension_p, unique_ptr parse_data_p) - : SQLStatement(StatementType::EXTENSION_STATEMENT), extension(std::move(extension_p)), - parse_data(std::move(parse_data_p)) { -} - -unique_ptr ExtensionStatement::Copy() const { - return make_uniq(extension, parse_data->Copy()); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -OnConflictInfo::OnConflictInfo() : action_type(OnConflictAction::THROW) { -} - -OnConflictInfo::OnConflictInfo(const OnConflictInfo &other) - : action_type(other.action_type), indexed_columns(other.indexed_columns) { - if (other.set_info) { - set_info = other.set_info->Copy(); - } - if (other.condition) { - condition = other.condition->Copy(); - } -} - -unique_ptr OnConflictInfo::Copy() const { - return unique_ptr(new OnConflictInfo(*this)); -} - -InsertStatement::InsertStatement() - : SQLStatement(StatementType::INSERT_STATEMENT), schema(DEFAULT_SCHEMA), catalog(INVALID_CATALOG) { -} - -InsertStatement::InsertStatement(const InsertStatement &other) - : SQLStatement(other), select_statement(unique_ptr_cast( - other.select_statement ? 
other.select_statement->Copy() : nullptr)), - columns(other.columns), table(other.table), schema(other.schema), catalog(other.catalog), - default_values(other.default_values), column_order(other.column_order) { - cte_map = other.cte_map.Copy(); - for (auto &expr : other.returning_list) { - returning_list.emplace_back(expr->Copy()); - } - if (other.table_ref) { - table_ref = other.table_ref->Copy(); - } - if (other.on_conflict_info) { - on_conflict_info = other.on_conflict_info->Copy(); - } -} - -string InsertStatement::OnConflictActionToString(OnConflictAction action) { - switch (action) { - case OnConflictAction::NOTHING: - return "DO NOTHING"; - case OnConflictAction::REPLACE: - case OnConflictAction::UPDATE: - return "DO UPDATE"; - case OnConflictAction::THROW: - // Explicitly left empty, for ToString purposes - return ""; - default: { - throw NotImplementedException("type not implemented for OnConflictActionType"); - } - } -} - -string InsertStatement::ToString() const { - bool or_replace_shorthand_set = false; - string result; - - result = cte_map.ToString(); - result += "INSERT"; - if (on_conflict_info && on_conflict_info->action_type == OnConflictAction::REPLACE) { - or_replace_shorthand_set = true; - result += " OR REPLACE"; - } - result += " INTO "; - if (!catalog.empty()) { - result += KeywordHelper::WriteOptionallyQuoted(catalog) + "."; - } - if (!schema.empty()) { - result += KeywordHelper::WriteOptionallyQuoted(schema) + "."; - } - result += KeywordHelper::WriteOptionallyQuoted(table); - // Write the (optional) alias of the insert target - if (table_ref && !table_ref->alias.empty()) { - result += StringUtil::Format(" AS %s", KeywordHelper::WriteOptionallyQuoted(table_ref->alias)); - } - if (column_order == InsertColumnOrder::INSERT_BY_NAME) { - result += " BY NAME"; - } - if (!columns.empty()) { - result += " ("; - for (idx_t i = 0; i < columns.size(); i++) { - if (i > 0) { - result += ", "; - } - result += 
KeywordHelper::WriteOptionallyQuoted(columns[i]); - } - result += " )"; - } - result += " "; - auto values_list = GetValuesList(); - if (values_list) { - D_ASSERT(!default_values); - values_list->alias = string(); - result += values_list->ToString(); - } else if (select_statement) { - D_ASSERT(!default_values); - result += select_statement->ToString(); - } else { - D_ASSERT(default_values); - result += "DEFAULT VALUES"; - } - if (!or_replace_shorthand_set && on_conflict_info) { - auto &conflict_info = *on_conflict_info; - result += " ON CONFLICT "; - // (optional) conflict target - if (!conflict_info.indexed_columns.empty()) { - result += "("; - auto &columns = conflict_info.indexed_columns; - for (auto it = columns.begin(); it != columns.end();) { - result += StringUtil::Lower(*it); - if (++it != columns.end()) { - result += ", "; - } - } - result += " )"; - } - - // (optional) where clause - if (conflict_info.condition) { - result += " WHERE " + conflict_info.condition->ToString(); - } - result += " " + OnConflictActionToString(conflict_info.action_type); - if (conflict_info.set_info) { - D_ASSERT(conflict_info.action_type == OnConflictAction::UPDATE); - result += " SET "; - auto &set_info = *conflict_info.set_info; - D_ASSERT(set_info.columns.size() == set_info.expressions.size()); - // SET = - for (idx_t i = 0; i < set_info.columns.size(); i++) { - auto &column = set_info.columns[i]; - auto &expr = set_info.expressions[i]; - if (i) { - result += ", "; - } - result += StringUtil::Lower(column) + " = " + expr->ToString(); - } - // (optional) where clause - if (set_info.condition) { - result += " WHERE " + set_info.condition->ToString(); - } - } - } - if (!returning_list.empty()) { - result += " RETURNING "; - for (idx_t i = 0; i < returning_list.size(); i++) { - if (i > 0) { - result += ", "; - } - result += returning_list[i]->ToString(); - } - } - return result; -} - -unique_ptr InsertStatement::Copy() const { - return unique_ptr(new InsertStatement(*this)); -} 
- -optional_ptr InsertStatement::GetValuesList() const { - if (!select_statement) { - return nullptr; - } - if (select_statement->node->type != QueryNodeType::SELECT_NODE) { - return nullptr; - } - auto &node = select_statement->node->Cast(); - if (node.where_clause || node.qualify || node.having) { - return nullptr; - } - if (!node.cte_map.map.empty()) { - return nullptr; - } - if (!node.groups.grouping_sets.empty()) { - return nullptr; - } - if (node.aggregate_handling != AggregateHandling::STANDARD_HANDLING) { - return nullptr; - } - if (node.select_list.size() != 1 || node.select_list[0]->type != ExpressionType::STAR) { - return nullptr; - } - if (!node.from_table || node.from_table->type != TableReferenceType::EXPRESSION_LIST) { - return nullptr; - } - return &node.from_table->Cast(); -} - -} // namespace duckdb - - -namespace duckdb { - -LoadStatement::LoadStatement() : SQLStatement(StatementType::LOAD_STATEMENT) { -} - -LoadStatement::LoadStatement(const LoadStatement &other) : SQLStatement(other), info(other.info->Copy()) { -} - -unique_ptr LoadStatement::Copy() const { - return unique_ptr(new LoadStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -MultiStatement::MultiStatement() : SQLStatement(StatementType::MULTI_STATEMENT) { -} - -MultiStatement::MultiStatement(const MultiStatement &other) : SQLStatement(other) { - for (auto &stmt : other.statements) { - statements.push_back(stmt->Copy()); - } -} - -unique_ptr MultiStatement::Copy() const { - return unique_ptr(new MultiStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -PragmaStatement::PragmaStatement() : SQLStatement(StatementType::PRAGMA_STATEMENT), info(make_uniq()) { -} - -PragmaStatement::PragmaStatement(const PragmaStatement &other) : SQLStatement(other), info(other.info->Copy()) { -} - -unique_ptr PragmaStatement::Copy() const { - return unique_ptr(new PragmaStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - 
-PrepareStatement::PrepareStatement() : SQLStatement(StatementType::PREPARE_STATEMENT), statement(nullptr), name("") { -} - -PrepareStatement::PrepareStatement(const PrepareStatement &other) - : SQLStatement(other), statement(other.statement->Copy()), name(other.name) { -} - -unique_ptr PrepareStatement::Copy() const { - return unique_ptr(new PrepareStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -RelationStatement::RelationStatement(shared_ptr relation) - : SQLStatement(StatementType::RELATION_STATEMENT), relation(std::move(relation)) { -} - -unique_ptr RelationStatement::Copy() const { - return unique_ptr(new RelationStatement(*this)); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -SelectStatement::SelectStatement(const SelectStatement &other) : SQLStatement(other), node(other.node->Copy()) { -} - -unique_ptr SelectStatement::Copy() const { - return unique_ptr(new SelectStatement(*this)); -} - -bool SelectStatement::Equals(const SQLStatement &other_p) const { - if (type != other_p.type) { - return false; - } - auto &other = other_p.Cast(); - return node->Equals(other.node.get()); -} - -string SelectStatement::ToString() const { - return node->ToString(); -} - -} // namespace duckdb - - -namespace duckdb { - -SetStatement::SetStatement(std::string name_p, SetScope scope_p, SetType type_p) - : SQLStatement(StatementType::SET_STATEMENT), name(std::move(name_p)), scope(scope_p), set_type(type_p) { -} - -unique_ptr SetStatement::Copy() const { - return unique_ptr(new SetStatement(*this)); -} - -// Set Variable - -SetVariableStatement::SetVariableStatement(std::string name_p, Value value_p, SetScope scope_p) - : SetStatement(std::move(name_p), scope_p, SetType::SET), value(std::move(value_p)) { -} - -unique_ptr SetVariableStatement::Copy() const { - return unique_ptr(new SetVariableStatement(*this)); -} - -// Reset Variable - -ResetVariableStatement::ResetVariableStatement(std::string name_p, SetScope scope_p) - : 
SetStatement(std::move(name_p), scope_p, SetType::RESET) { -} - -} // namespace duckdb - - -namespace duckdb { - -ShowStatement::ShowStatement() : SQLStatement(StatementType::SHOW_STATEMENT), info(make_uniq()) { -} - -ShowStatement::ShowStatement(const ShowStatement &other) : SQLStatement(other), info(other.info->Copy()) { -} - -unique_ptr ShowStatement::Copy() const { - return unique_ptr(new ShowStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -TransactionStatement::TransactionStatement(TransactionType type) - : SQLStatement(StatementType::TRANSACTION_STATEMENT), info(make_uniq(type)) { -} - -TransactionStatement::TransactionStatement(const TransactionStatement &other) - : SQLStatement(other), info(make_uniq(other.info->type)) { -} - -unique_ptr TransactionStatement::Copy() const { - return unique_ptr(new TransactionStatement(*this)); -} - -} // namespace duckdb - - - -namespace duckdb { - -UpdateSetInfo::UpdateSetInfo() { -} - -UpdateSetInfo::UpdateSetInfo(const UpdateSetInfo &other) : columns(other.columns) { - if (other.condition) { - condition = other.condition->Copy(); - } - for (auto &expr : other.expressions) { - expressions.emplace_back(expr->Copy()); - } -} - -unique_ptr UpdateSetInfo::Copy() const { - return unique_ptr(new UpdateSetInfo(*this)); -} - -UpdateStatement::UpdateStatement() : SQLStatement(StatementType::UPDATE_STATEMENT) { -} - -UpdateStatement::UpdateStatement(const UpdateStatement &other) - : SQLStatement(other), table(other.table->Copy()), set_info(other.set_info->Copy()) { - if (other.from_table) { - from_table = other.from_table->Copy(); - } - for (auto &expr : other.returning_list) { - returning_list.emplace_back(expr->Copy()); - } - cte_map = other.cte_map.Copy(); -} - -string UpdateStatement::ToString() const { - D_ASSERT(set_info); - auto &condition = set_info->condition; - auto &columns = set_info->columns; - auto &expressions = set_info->expressions; - - string result; - result = cte_map.ToString(); - result 
+= "UPDATE "; - result += table->ToString(); - result += " SET "; - D_ASSERT(columns.size() == expressions.size()); - for (idx_t i = 0; i < columns.size(); i++) { - if (i > 0) { - result += ", "; - } - result += KeywordHelper::WriteOptionallyQuoted(columns[i]); - result += " = "; - result += expressions[i]->ToString(); - } - if (from_table) { - result += " FROM " + from_table->ToString(); - } - if (condition) { - result += " WHERE " + condition->ToString(); - } - if (!returning_list.empty()) { - result += " RETURNING "; - for (idx_t i = 0; i < returning_list.size(); i++) { - if (i > 0) { - result += ", "; - } - result += returning_list[i]->ToString(); - } - } - return result; -} - -unique_ptr UpdateStatement::Copy() const { - return unique_ptr(new UpdateStatement(*this)); -} - -} // namespace duckdb - - -namespace duckdb { - -VacuumStatement::VacuumStatement(const VacuumOptions &options) - : SQLStatement(StatementType::VACUUM_STATEMENT), info(make_uniq(options)) { -} - -VacuumStatement::VacuumStatement(const VacuumStatement &other) : SQLStatement(other), info(other.info->Copy()) { -} - -unique_ptr VacuumStatement::Copy() const { - return unique_ptr(new VacuumStatement(*this)); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -string BaseTableRef::ToString() const { - string result; - result += catalog_name.empty() ? "" : (KeywordHelper::WriteOptionallyQuoted(catalog_name) + "."); - result += schema_name.empty() ? 
"" : (KeywordHelper::WriteOptionallyQuoted(schema_name) + "."); - result += KeywordHelper::WriteOptionallyQuoted(table_name); - return BaseToString(result, column_name_alias); -} - -bool BaseTableRef::Equals(const TableRef &other_p) const { - if (!TableRef::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return other.catalog_name == catalog_name && other.schema_name == schema_name && other.table_name == table_name && - column_name_alias == other.column_name_alias; -} - -unique_ptr BaseTableRef::Copy() { - auto copy = make_uniq(); - - copy->catalog_name = catalog_name; - copy->schema_name = schema_name; - copy->table_name = table_name; - copy->column_name_alias = column_name_alias; - CopyProperties(*copy); - - return std::move(copy); -} -} // namespace duckdb - - -namespace duckdb { - -string EmptyTableRef::ToString() const { - return ""; -} - -bool EmptyTableRef::Equals(const TableRef &other) const { - return TableRef::Equals(other); -} - -unique_ptr EmptyTableRef::Copy() { - return make_uniq(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -string ExpressionListRef::ToString() const { - D_ASSERT(!values.empty()); - string result = "(VALUES "; - for (idx_t row_idx = 0; row_idx < values.size(); row_idx++) { - if (row_idx > 0) { - result += ", "; - } - auto &row = values[row_idx]; - result += "("; - for (idx_t col_idx = 0; col_idx < row.size(); col_idx++) { - if (col_idx > 0) { - result += ", "; - } - result += row[col_idx]->ToString(); - } - result += ")"; - } - result += ")"; - return BaseToString(result, expected_names); -} - -bool ExpressionListRef::Equals(const TableRef &other_p) const { - if (!TableRef::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (values.size() != other.values.size()) { - return false; - } - for (idx_t i = 0; i < values.size(); i++) { - if (values[i].size() != other.values[i].size()) { - return false; - } - for (idx_t j = 0; j < values[i].size(); j++) { - if 
(!values[i][j]->Equals(*other.values[i][j])) { - return false; - } - } - } - return true; -} - -unique_ptr ExpressionListRef::Copy() { - // value list - auto result = make_uniq(); - for (auto &val_list : values) { - vector> new_val_list; - new_val_list.reserve(val_list.size()); - for (auto &val : val_list) { - new_val_list.push_back(val->Copy()); - } - result->values.push_back(std::move(new_val_list)); - } - result->expected_names = expected_names; - result->expected_types = expected_types; - CopyProperties(*result); - return std::move(result); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -string JoinRef::ToString() const { - string result; - result = left->ToString() + " "; - switch (ref_type) { - case JoinRefType::REGULAR: - result += EnumUtil::ToString(type) + " JOIN "; - break; - case JoinRefType::NATURAL: - result += "NATURAL "; - result += EnumUtil::ToString(type) + " JOIN "; - break; - case JoinRefType::ASOF: - result += "ASOF "; - result += EnumUtil::ToString(type) + " JOIN "; - break; - case JoinRefType::CROSS: - result += ", "; - break; - case JoinRefType::POSITIONAL: - result += "POSITIONAL JOIN "; - break; - case JoinRefType::DEPENDENT: - result += "DEPENDENT JOIN "; - break; - } - result += right->ToString(); - if (condition) { - D_ASSERT(using_columns.empty()); - result += " ON ("; - result += condition->ToString(); - result += ")"; - } else if (!using_columns.empty()) { - result += " USING ("; - for (idx_t i = 0; i < using_columns.size(); i++) { - if (i > 0) { - result += ", "; - } - result += using_columns[i]; - } - result += ")"; - } - return result; -} - -bool JoinRef::Equals(const TableRef &other_p) const { - if (!TableRef::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (using_columns.size() != other.using_columns.size()) { - return false; - } - for (idx_t i = 0; i < using_columns.size(); i++) { - if (using_columns[i] != other.using_columns[i]) { - return false; - } - } - return 
left->Equals(*other.left) && right->Equals(*other.right) && - ParsedExpression::Equals(condition, other.condition) && type == other.type; -} - -unique_ptr JoinRef::Copy() { - auto copy = make_uniq(ref_type); - copy->left = left->Copy(); - copy->right = right->Copy(); - if (condition) { - copy->condition = condition->Copy(); - } - copy->type = type; - copy->ref_type = ref_type; - copy->alias = alias; - copy->using_columns = using_columns; - return std::move(copy); -} - -} // namespace duckdb - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// PivotColumn -//===--------------------------------------------------------------------===// -string PivotColumn::ToString() const { - string result; - if (!unpivot_names.empty()) { - D_ASSERT(pivot_expressions.empty()); - // unpivot - if (unpivot_names.size() == 1) { - result += KeywordHelper::WriteOptionallyQuoted(unpivot_names[0]); - } else { - result += "("; - for (idx_t n = 0; n < unpivot_names.size(); n++) { - if (n > 0) { - result += ", "; - } - result += KeywordHelper::WriteOptionallyQuoted(unpivot_names[n]); - } - result += ")"; - } - } else if (!pivot_expressions.empty()) { - // pivot - result += "("; - for (idx_t n = 0; n < pivot_expressions.size(); n++) { - if (n > 0) { - result += ", "; - } - result += pivot_expressions[n]->ToString(); - } - result += ")"; - } - result += " IN "; - if (pivot_enum.empty()) { - result += "("; - for (idx_t e = 0; e < entries.size(); e++) { - auto &entry = entries[e]; - if (e > 0) { - result += ", "; - } - if (entry.star_expr) { - D_ASSERT(entry.values.empty()); - result += entry.star_expr->ToString(); - } else if (entry.values.size() == 1) { - result += entry.values[0].ToSQLString(); - } else { - result += "("; - for (idx_t v = 0; v < entry.values.size(); v++) { - if (v > 0) { - result += ", "; - } - result += entry.values[v].ToSQLString(); - } - result += ")"; - } - if (!entry.alias.empty()) { - result += " AS " + 
KeywordHelper::WriteOptionallyQuoted(entry.alias); - } - } - result += ")"; - } else { - result += KeywordHelper::WriteOptionallyQuoted(pivot_enum); - } - return result; -} - -bool PivotColumnEntry::Equals(const PivotColumnEntry &other) const { - if (alias != other.alias) { - return false; - } - if (values.size() != other.values.size()) { - return false; - } - for (idx_t i = 0; i < values.size(); i++) { - if (!Value::NotDistinctFrom(values[i], other.values[i])) { - return false; - } - } - return true; -} - -bool PivotColumn::Equals(const PivotColumn &other) const { - if (!ExpressionUtil::ListEquals(pivot_expressions, other.pivot_expressions)) { - return false; - } - if (other.unpivot_names != unpivot_names) { - return false; - } - if (other.pivot_enum != pivot_enum) { - return false; - } - if (other.entries.size() != entries.size()) { - return false; - } - for (idx_t i = 0; i < entries.size(); i++) { - if (!entries[i].Equals(other.entries[i])) { - return false; - } - } - return true; -} - -PivotColumn PivotColumn::Copy() const { - PivotColumn result; - for (auto &expr : pivot_expressions) { - result.pivot_expressions.push_back(expr->Copy()); - } - result.unpivot_names = unpivot_names; - for (auto &entry : entries) { - result.entries.push_back(entry.Copy()); - } - result.pivot_enum = pivot_enum; - return result; -} - -//===--------------------------------------------------------------------===// -// PivotColumnEntry -//===--------------------------------------------------------------------===// -PivotColumnEntry PivotColumnEntry::Copy() const { - PivotColumnEntry result; - result.values = values; - result.star_expr = star_expr ? 
star_expr->Copy() : nullptr; - result.alias = alias; - return result; -} - -//===--------------------------------------------------------------------===// -// PivotRef -//===--------------------------------------------------------------------===// -string PivotRef::ToString() const { - string result; - result = source->ToString(); - if (!aggregates.empty()) { - // pivot - result += " PIVOT ("; - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - if (aggr_idx > 0) { - result += ", "; - } - result += aggregates[aggr_idx]->ToString(); - if (!aggregates[aggr_idx]->alias.empty()) { - result += " AS " + KeywordHelper::WriteOptionallyQuoted(aggregates[aggr_idx]->alias); - } - } - } else { - // unpivot - result += " UNPIVOT "; - if (include_nulls) { - result += "INCLUDE NULLS "; - } - result += "("; - if (unpivot_names.size() == 1) { - result += KeywordHelper::WriteOptionallyQuoted(unpivot_names[0]); - } else { - result += "("; - for (idx_t n = 0; n < unpivot_names.size(); n++) { - if (n > 0) { - result += ", "; - } - result += KeywordHelper::WriteOptionallyQuoted(unpivot_names[n]); - } - result += ")"; - } - } - result += " FOR"; - for (auto &pivot : pivots) { - result += " "; - result += pivot.ToString(); - } - if (!groups.empty()) { - result += " GROUP BY "; - for (idx_t i = 0; i < groups.size(); i++) { - if (i > 0) { - result += ", "; - } - result += groups[i]; - } - } - result += ")"; - if (!alias.empty()) { - result += " AS " + KeywordHelper::WriteOptionallyQuoted(alias); - if (!column_name_alias.empty()) { - result += "("; - for (idx_t i = 0; i < column_name_alias.size(); i++) { - if (i > 0) { - result += ", "; - } - result += KeywordHelper::WriteOptionallyQuoted(column_name_alias[i]); - } - result += ")"; - } - } - return result; -} - -bool PivotRef::Equals(const TableRef &other_p) const { - if (!TableRef::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (!source->Equals(*other.source)) { - return false; - } - if 
(!ParsedExpression::ListEquals(aggregates, other.aggregates)) { - return false; - } - if (pivots.size() != other.pivots.size()) { - return false; - } - for (idx_t i = 0; i < pivots.size(); i++) { - if (!pivots[i].Equals(other.pivots[i])) { - return false; - } - } - if (unpivot_names != other.unpivot_names) { - return false; - } - if (alias != other.alias) { - return false; - } - if (groups != other.groups) { - return false; - } - if (include_nulls != other.include_nulls) { - return false; - } - return true; -} - -unique_ptr PivotRef::Copy() { - auto copy = make_uniq(); - copy->source = source->Copy(); - for (auto &aggr : aggregates) { - copy->aggregates.push_back(aggr->Copy()); - } - copy->unpivot_names = unpivot_names; - for (auto &entry : pivots) { - copy->pivots.push_back(entry.Copy()); - } - copy->groups = groups; - copy->column_name_alias = column_name_alias; - copy->include_nulls = include_nulls; - copy->alias = alias; - return std::move(copy); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -string SubqueryRef::ToString() const { - string result = "(" + subquery->ToString() + ")"; - return BaseToString(result, column_name_alias); -} - -SubqueryRef::SubqueryRef() : TableRef(TableReferenceType::SUBQUERY) { -} - -SubqueryRef::SubqueryRef(unique_ptr subquery_p, string alias_p) - : TableRef(TableReferenceType::SUBQUERY), subquery(std::move(subquery_p)) { - this->alias = std::move(alias_p); -} - -bool SubqueryRef::Equals(const TableRef &other_p) const { - if (!TableRef::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return subquery->Equals(*other.subquery); -} - -unique_ptr SubqueryRef::Copy() { - auto copy = make_uniq(unique_ptr_cast(subquery->Copy()), alias); - copy->column_name_alias = column_name_alias; - CopyProperties(*copy); - return std::move(copy); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -TableFunctionRef::TableFunctionRef() : TableRef(TableReferenceType::TABLE_FUNCTION) { -} - -string 
TableFunctionRef::ToString() const { - return BaseToString(function->ToString(), column_name_alias); -} - -bool TableFunctionRef::Equals(const TableRef &other_p) const { - if (!TableRef::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return function->Equals(*other.function); -} - -unique_ptr TableFunctionRef::Copy() { - auto copy = make_uniq(); - - copy->function = function->Copy(); - copy->column_name_alias = column_name_alias; - CopyProperties(*copy); - - return std::move(copy); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -string TableRef::BaseToString(string result) const { - vector column_name_alias; - return BaseToString(std::move(result), column_name_alias); -} - -string TableRef::BaseToString(string result, const vector &column_name_alias) const { - if (!alias.empty()) { - result += StringUtil::Format(" AS %s", SQLIdentifier(alias)); - } - if (!column_name_alias.empty()) { - D_ASSERT(!alias.empty()); - result += "("; - for (idx_t i = 0; i < column_name_alias.size(); i++) { - if (i > 0) { - result += ", "; - } - result += KeywordHelper::WriteOptionallyQuoted(column_name_alias[i]); - } - result += ")"; - } - if (sample) { - result += " TABLESAMPLE " + EnumUtil::ToString(sample->method); - result += "(" + sample->sample_size.ToString() + " " + string(sample->is_percentage ? "PERCENT" : "ROWS") + ")"; - if (sample->seed >= 0) { - result += "REPEATABLE (" + to_string(sample->seed) + ")"; - } - } - - return result; -} - -bool TableRef::Equals(const TableRef &other) const { - return type == other.type && alias == other.alias && SampleOptions::Equals(sample.get(), other.sample.get()); -} - -void TableRef::CopyProperties(TableRef &target) const { - D_ASSERT(type == target.type); - target.alias = alias; - target.query_location = query_location; - target.sample = sample ? 
sample->Copy() : nullptr; -} - -void TableRef::Print() { - Printer::Print(ToString()); -} - -bool TableRef::Equals(const unique_ptr &left, const unique_ptr &right) { - if (left.get() == right.get()) { - return true; - } - if (!left || !right) { - return false; - } - return left->Equals(*right); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -static void ParseSchemaTableNameFK(duckdb_libpgquery::PGRangeVar *input, ForeignKeyInfo &fk_info) { - if (input->catalogname) { - throw ParserException("FOREIGN KEY constraints cannot be defined cross-database"); - } - if (input->schemaname) { - fk_info.schema = input->schemaname; - } else { - fk_info.schema = ""; - }; - fk_info.table = input->relname; -} - -static bool ForeignKeyActionSupported(char action) { - switch (action) { - case PG_FKCONSTR_ACTION_NOACTION: - case PG_FKCONSTR_ACTION_RESTRICT: - return true; - case PG_FKCONSTR_ACTION_CASCADE: - case PG_FKCONSTR_ACTION_SETDEFAULT: - case PG_FKCONSTR_ACTION_SETNULL: - return false; - default: - D_ASSERT(false); - } - return false; -} - -static unique_ptr -TransformForeignKeyConstraint(duckdb_libpgquery::PGConstraint *constraint, - optional_ptr override_fk_column = nullptr) { - D_ASSERT(constraint); - if (!ForeignKeyActionSupported(constraint->fk_upd_action) || - !ForeignKeyActionSupported(constraint->fk_del_action)) { - throw ParserException("FOREIGN KEY constraints cannot use CASCADE, SET NULL or SET DEFAULT"); - } - ForeignKeyInfo fk_info; - fk_info.type = ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE; - ParseSchemaTableNameFK(constraint->pktable, fk_info); - vector pk_columns, fk_columns; - if (override_fk_column) { - D_ASSERT(!constraint->fk_attrs); - fk_columns.emplace_back(*override_fk_column); - } else if (constraint->fk_attrs) { - for (auto kc = constraint->fk_attrs->head; kc; kc = kc->next) { - fk_columns.emplace_back(reinterpret_cast(kc->data.ptr_value)->val.str); - } - } - if (constraint->pk_attrs) { - for (auto kc = constraint->pk_attrs->head; kc; kc = 
kc->next) { - pk_columns.emplace_back(reinterpret_cast(kc->data.ptr_value)->val.str); - } - } - if (!pk_columns.empty() && pk_columns.size() != fk_columns.size()) { - throw ParserException("The number of referencing and referenced columns for foreign keys must be the same"); - } - if (fk_columns.empty()) { - throw ParserException("The set of referencing and referenced columns for foreign keys must be not empty"); - } - return make_uniq(pk_columns, fk_columns, std::move(fk_info)); -} - -unique_ptr Transformer::TransformConstraint(duckdb_libpgquery::PGListCell *cell) { - auto constraint = reinterpret_cast(cell->data.ptr_value); - D_ASSERT(constraint); - switch (constraint->contype) { - case duckdb_libpgquery::PG_CONSTR_UNIQUE: - case duckdb_libpgquery::PG_CONSTR_PRIMARY: { - bool is_primary_key = constraint->contype == duckdb_libpgquery::PG_CONSTR_PRIMARY; - vector columns; - for (auto kc = constraint->keys->head; kc; kc = kc->next) { - columns.emplace_back(reinterpret_cast(kc->data.ptr_value)->val.str); - } - return make_uniq(columns, is_primary_key); - } - case duckdb_libpgquery::PG_CONSTR_CHECK: { - auto expression = TransformExpression(constraint->raw_expr); - if (expression->HasSubquery()) { - throw ParserException("subqueries prohibited in CHECK constraints"); - } - return make_uniq(TransformExpression(constraint->raw_expr)); - } - case duckdb_libpgquery::PG_CONSTR_FOREIGN: - return TransformForeignKeyConstraint(constraint); - - default: - throw NotImplementedException("Constraint type not handled yet!"); - } -} - -unique_ptr Transformer::TransformConstraint(duckdb_libpgquery::PGListCell *cell, ColumnDefinition &column, - idx_t index) { - auto constraint = reinterpret_cast(cell->data.ptr_value); - D_ASSERT(constraint); - switch (constraint->contype) { - case duckdb_libpgquery::PG_CONSTR_NOTNULL: - return make_uniq(LogicalIndex(index)); - case duckdb_libpgquery::PG_CONSTR_CHECK: - return TransformConstraint(cell); - case duckdb_libpgquery::PG_CONSTR_PRIMARY: - 
return make_uniq(LogicalIndex(index), true); - case duckdb_libpgquery::PG_CONSTR_UNIQUE: - return make_uniq(LogicalIndex(index), false); - case duckdb_libpgquery::PG_CONSTR_NULL: - return nullptr; - case duckdb_libpgquery::PG_CONSTR_GENERATED_VIRTUAL: { - if (column.DefaultValue()) { - throw InvalidInputException("DEFAULT constraint on GENERATED column \"%s\" is not allowed", column.Name()); - } - column.SetGeneratedExpression(TransformExpression(constraint->raw_expr)); - return nullptr; - } - case duckdb_libpgquery::PG_CONSTR_GENERATED_STORED: - throw InvalidInputException("Can not create a STORED generated column!"); - case duckdb_libpgquery::PG_CONSTR_DEFAULT: - column.SetDefaultValue(TransformExpression(constraint->raw_expr)); - return nullptr; - case duckdb_libpgquery::PG_CONSTR_COMPRESSION: - column.SetCompressionType(CompressionTypeFromString(constraint->compression_name)); - if (column.CompressionType() == CompressionType::COMPRESSION_AUTO) { - throw ParserException("Unrecognized option for column compression, expected none, uncompressed, rle, " - "dictionary, pfor, bitpacking or fsst"); - } - return nullptr; - case duckdb_libpgquery::PG_CONSTR_FOREIGN: - return TransformForeignKeyConstraint(constraint, &column.Name()); - default: - throw NotImplementedException("Constraint not implemented!"); - } -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformArrayAccess(duckdb_libpgquery::PGAIndirection &indirection_node) { - // transform the source expression - unique_ptr result; - result = TransformExpression(indirection_node.arg); - - // now go over the indices - // note that a single indirection node can contain multiple indices - // this happens for e.g. more complex accesses (e.g. 
(foo).field1[42]) - idx_t list_size = 0; - for (auto node = indirection_node.indirection->head; node != nullptr; node = node->next) { - auto target = reinterpret_cast(node->data.ptr_value); - D_ASSERT(target); - - switch (target->type) { - case duckdb_libpgquery::T_PGAIndices: { - // index access (either slice or extract) - auto index = PGPointerCast(target); - vector> children; - children.push_back(std::move(result)); - if (index->is_slice) { - // slice - // if either the lower or upper bound is not specified, we use an empty const list so that we can - // handle it in the execution - unique_ptr lower = - index->lidx ? TransformExpression(index->lidx) - : make_uniq(Value::LIST(LogicalType::INTEGER, vector())); - children.push_back(std::move(lower)); - unique_ptr upper = - index->uidx ? TransformExpression(index->uidx) - : make_uniq(Value::LIST(LogicalType::INTEGER, vector())); - children.push_back(std::move(upper)); - if (index->step) { - children.push_back(TransformExpression(index->step)); - } - result = make_uniq(ExpressionType::ARRAY_SLICE, std::move(children)); - } else { - // array access - D_ASSERT(!index->lidx); - D_ASSERT(index->uidx); - children.push_back(TransformExpression(index->uidx)); - result = make_uniq(ExpressionType::ARRAY_EXTRACT, std::move(children)); - } - break; - } - case duckdb_libpgquery::T_PGString: { - auto val = PGPointerCast(target); - vector> children; - children.push_back(std::move(result)); - children.push_back(TransformValue(*val)); - result = make_uniq(ExpressionType::STRUCT_EXTRACT, std::move(children)); - break; - } - case duckdb_libpgquery::T_PGFuncCall: { - auto func = PGPointerCast(target); - auto function = TransformFuncCall(*func); - if (function->type != ExpressionType::FUNCTION) { - throw ParserException("%s.%s() call must be a function", result->ToString(), function->ToString()); - } - auto &f = function->Cast(); - f.children.insert(f.children.begin(), std::move(result)); - result = std::move(function); - break; - } - 
default: - throw NotImplementedException("Unimplemented subscript type"); - } - list_size++; - StackCheck(list_size); - } - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformBoolExpr(duckdb_libpgquery::PGBoolExpr &root) { - unique_ptr result; - for (auto node = root.args->head; node != nullptr; node = node->next) { - auto next = TransformExpression(PGPointerCast(node->data.ptr_value)); - - switch (root.boolop) { - case duckdb_libpgquery::PG_AND_EXPR: { - if (!result) { - result = std::move(next); - } else { - result = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(result), - std::move(next)); - } - break; - } - case duckdb_libpgquery::PG_OR_EXPR: { - if (!result) { - result = std::move(next); - } else { - result = make_uniq(ExpressionType::CONJUNCTION_OR, std::move(result), - std::move(next)); - } - break; - } - case duckdb_libpgquery::PG_NOT_EXPR: { - if (next->type == ExpressionType::COMPARE_IN) { - // convert COMPARE_IN to COMPARE_NOT_IN - next->type = ExpressionType::COMPARE_NOT_IN; - result = std::move(next); - } else if (next->type >= ExpressionType::COMPARE_EQUAL && - next->type <= ExpressionType::COMPARE_GREATERTHANOREQUALTO) { - // NOT on a comparison: we can negate the comparison - // e.g. 
NOT(x > y) is equivalent to x <= y - next->type = NegateComparisonExpression(next->type); - result = std::move(next); - } else { - result = make_uniq(ExpressionType::OPERATOR_NOT, std::move(next)); - } - break; - } - } - } - return result; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformBooleanTest(duckdb_libpgquery::PGBooleanTest &node) { - auto argument = TransformExpression(PGPointerCast(node.arg)); - - auto expr_true = make_uniq(Value::BOOLEAN(true)); - auto expr_false = make_uniq(Value::BOOLEAN(false)); - // we cast the argument to bool to remove ambiguity wrt function binding on the comparision - auto cast_argument = make_uniq(LogicalType::BOOLEAN, argument->Copy()); - - switch (node.booltesttype) { - case duckdb_libpgquery::PGBoolTestType::PG_IS_TRUE: - return make_uniq(ExpressionType::COMPARE_NOT_DISTINCT_FROM, std::move(cast_argument), - std::move(expr_true)); - case duckdb_libpgquery::PGBoolTestType::IS_NOT_TRUE: - return make_uniq(ExpressionType::COMPARE_DISTINCT_FROM, std::move(cast_argument), - std::move(expr_true)); - case duckdb_libpgquery::PGBoolTestType::IS_FALSE: - return make_uniq(ExpressionType::COMPARE_NOT_DISTINCT_FROM, std::move(cast_argument), - std::move(expr_false)); - case duckdb_libpgquery::PGBoolTestType::IS_NOT_FALSE: - return make_uniq(ExpressionType::COMPARE_DISTINCT_FROM, std::move(cast_argument), - std::move(expr_false)); - case duckdb_libpgquery::PGBoolTestType::IS_UNKNOWN: // IS NULL - return make_uniq(ExpressionType::OPERATOR_IS_NULL, std::move(argument)); - case duckdb_libpgquery::PGBoolTestType::IS_NOT_UNKNOWN: // IS NOT NULL - return make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, std::move(argument)); - default: - throw NotImplementedException("Unknown boolean test type %d", node.booltesttype); - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformCase(duckdb_libpgquery::PGCaseExpr &root) { - auto case_node = make_uniq(); - auto 
root_arg = TransformExpression(PGPointerCast(root.arg)); - for (auto cell = root.args->head; cell != nullptr; cell = cell->next) { - CaseCheck case_check; - - auto w = PGPointerCast(cell->data.ptr_value); - auto test_raw = TransformExpression(PGPointerCast(w->expr)); - unique_ptr test; - if (root_arg) { - case_check.when_expr = - make_uniq(ExpressionType::COMPARE_EQUAL, root_arg->Copy(), std::move(test_raw)); - } else { - case_check.when_expr = std::move(test_raw); - } - case_check.then_expr = TransformExpression(PGPointerCast(w->result)); - case_node->case_checks.push_back(std::move(case_check)); - } - - if (root.defresult) { - case_node->else_expr = TransformExpression(PGPointerCast(root.defresult)); - } else { - case_node->else_expr = make_uniq(Value(LogicalType::SQLNULL)); - } - return std::move(case_node); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformTypeCast(duckdb_libpgquery::PGTypeCast &root) { - // get the type to cast to - auto type_name = root.typeName; - LogicalType target_type = TransformTypeName(*type_name); - - // check for a constant BLOB value, then return ConstantExpression with BLOB - if (!root.tryCast && target_type == LogicalType::BLOB && root.arg->type == duckdb_libpgquery::T_PGAConst) { - auto c = PGPointerCast(root.arg); - if (c->val.type == duckdb_libpgquery::T_PGString) { - return make_uniq(Value::BLOB(string(c->val.val.str))); - } - } - // transform the expression node - auto expression = TransformExpression(root.arg); - bool try_cast = root.tryCast; - - // now create a cast operation - return make_uniq(target_type, std::move(expression), try_cast); -} - -} // namespace duckdb - - - -namespace duckdb { - -// COALESCE(a,b,c) returns the first argument that is NOT NULL, so -// rewrite into CASE(a IS NOT NULL, a, CASE(b IS NOT NULL, b, c)) -unique_ptr Transformer::TransformCoalesce(duckdb_libpgquery::PGAExpr &root) { - auto coalesce_args = PGPointerCast(root.lexpr); - 
D_ASSERT(coalesce_args->length > 0); // parser ensures this already - - auto coalesce_op = make_uniq(ExpressionType::OPERATOR_COALESCE); - for (auto cell = coalesce_args->head; cell; cell = cell->next) { - // get the value of the COALESCE - auto value_expr = TransformExpression(PGPointerCast(cell->data.ptr_value)); - coalesce_op->children.push_back(std::move(value_expr)); - } - return std::move(coalesce_op); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformStarExpression(duckdb_libpgquery::PGAStar &star) { - auto result = make_uniq(star.relation ? star.relation : string()); - if (star.except_list) { - for (auto head = star.except_list->head; head; head = head->next) { - auto value = PGPointerCast(head->data.ptr_value); - D_ASSERT(value->type == duckdb_libpgquery::T_PGString); - string exclude_entry = value->val.str; - if (result->exclude_list.find(exclude_entry) != result->exclude_list.end()) { - throw ParserException("Duplicate entry \"%s\" in EXCLUDE list", exclude_entry); - } - result->exclude_list.insert(std::move(exclude_entry)); - } - } - if (star.replace_list) { - for (auto head = star.replace_list->head; head; head = head->next) { - auto list = PGPointerCast(head->data.ptr_value); - D_ASSERT(list->length == 2); - auto replace_expression = - TransformExpression(PGPointerCast(list->head->data.ptr_value)); - auto value = PGPointerCast(list->tail->data.ptr_value); - D_ASSERT(value->type == duckdb_libpgquery::T_PGString); - string exclude_entry = value->val.str; - if (result->replace_list.find(exclude_entry) != result->replace_list.end()) { - throw ParserException("Duplicate entry \"%s\" in REPLACE list", exclude_entry); - } - if (result->exclude_list.find(exclude_entry) != result->exclude_list.end()) { - throw ParserException("Column \"%s\" cannot occur in both EXCEPT and REPLACE list", exclude_entry); - } - result->replace_list.insert(make_pair(std::move(exclude_entry), std::move(replace_expression))); - } - } - 
if (star.expr) { - D_ASSERT(star.columns); - D_ASSERT(result->relation_name.empty()); - D_ASSERT(result->exclude_list.empty()); - D_ASSERT(result->replace_list.empty()); - result->expr = TransformExpression(star.expr); - if (result->expr->type == ExpressionType::STAR) { - auto &child_star = result->expr->Cast(); - result->exclude_list = std::move(child_star.exclude_list); - result->replace_list = std::move(child_star.replace_list); - result->expr.reset(); - } else if (result->expr->type == ExpressionType::LAMBDA) { - vector> children; - children.push_back(make_uniq()); - children.push_back(std::move(result->expr)); - auto list_filter = make_uniq("list_filter", std::move(children)); - result->expr = std::move(list_filter); - } - } - result->columns = star.columns; - result->query_location = star.location; - return std::move(result); -} - -unique_ptr Transformer::TransformColumnRef(duckdb_libpgquery::PGColumnRef &root) { - auto fields = root.fields; - auto head_node = PGPointerCast(fields->head->data.ptr_value); - switch (head_node->type) { - case duckdb_libpgquery::T_PGString: { - if (fields->length < 1) { - throw InternalException("Unexpected field length"); - } - vector column_names; - for (auto node = fields->head; node; node = node->next) { - column_names.emplace_back(PGPointerCast(node->data.ptr_value)->val.str); - } - auto colref = make_uniq(std::move(column_names)); - colref->query_location = root.location; - return std::move(colref); - } - case duckdb_libpgquery::T_PGAStar: { - return TransformStarExpression(PGCast(*head_node)); - } - default: - throw NotImplementedException("ColumnRef not implemented!"); - } -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformValue(duckdb_libpgquery::PGValue val) { - switch (val.type) { - case duckdb_libpgquery::T_PGInteger: - D_ASSERT(val.val.ival <= NumericLimits::Maximum()); - return make_uniq(Value::INTEGER((int32_t)val.val.ival)); - case 
duckdb_libpgquery::T_PGBitString: // FIXME: this should actually convert to BLOB - case duckdb_libpgquery::T_PGString: - return make_uniq(Value(string(val.val.str))); - case duckdb_libpgquery::T_PGFloat: { - string_t str_val(val.val.str); - bool try_cast_as_integer = true; - bool try_cast_as_decimal = true; - int decimal_position = -1; - for (idx_t i = 0; i < str_val.GetSize(); i++) { - if (val.val.str[i] == '.') { - // decimal point: cast as either decimal or double - try_cast_as_integer = false; - decimal_position = i; - } - if (val.val.str[i] == 'e' || val.val.str[i] == 'E') { - // found exponent, cast as double - try_cast_as_integer = false; - try_cast_as_decimal = false; - } - } - if (try_cast_as_integer) { - int64_t bigint_value; - // try to cast as bigint first - if (TryCast::Operation(str_val, bigint_value)) { - // successfully cast to bigint: bigint value - return make_uniq(Value::BIGINT(bigint_value)); - } - hugeint_t hugeint_value; - // if that is not successful; try to cast as hugeint - if (TryCast::Operation(str_val, hugeint_value)) { - // successfully cast to bigint: bigint value - return make_uniq(Value::HUGEINT(hugeint_value)); - } - } - idx_t decimal_offset = val.val.str[0] == '-' ? 
3 : 2; - if (try_cast_as_decimal && decimal_position >= 0 && - str_val.GetSize() < Decimal::MAX_WIDTH_DECIMAL + decimal_offset) { - // figure out the width/scale based on the decimal position - auto width = uint8_t(str_val.GetSize() - 1); - auto scale = uint8_t(width - decimal_position); - if (val.val.str[0] == '-') { - width--; - } - if (width <= Decimal::MAX_WIDTH_DECIMAL) { - // we can cast the value as a decimal - Value val = Value(str_val); - val = val.DefaultCastAs(LogicalType::DECIMAL(width, scale)); - return make_uniq(std::move(val)); - } - } - // if there is a decimal or the value is too big to cast as either hugeint or bigint - double dbl_value = Cast::Operation(str_val); - return make_uniq(Value::DOUBLE(dbl_value)); - } - case duckdb_libpgquery::T_PGNull: - return make_uniq(Value(LogicalType::SQLNULL)); - default: - throw NotImplementedException("Value not implemented!"); - } -} - -unique_ptr Transformer::TransformConstant(duckdb_libpgquery::PGAConst &c) { - return TransformValue(c.val); -} - -bool Transformer::ConstructConstantFromExpression(const ParsedExpression &expr, Value &value) { - // We have to construct it like this because we don't have the ClientContext for binding/executing the expr here - switch (expr.type) { - case ExpressionType::FUNCTION: { - auto &function = expr.Cast(); - if (function.function_name == "struct_pack") { - unordered_set unique_names; - child_list_t values; - values.reserve(function.children.size()); - for (const auto &child : function.children) { - if (!unique_names.insert(child->alias).second) { - throw BinderException("Duplicate struct entry name \"%s\"", child->alias); - } - Value child_value; - if (!ConstructConstantFromExpression(*child, child_value)) { - return false; - } - values.emplace_back(child->alias, std::move(child_value)); - } - value = Value::STRUCT(std::move(values)); - return true; - } else { - return false; - } - } - case ExpressionType::VALUE_CONSTANT: { - auto &constant = expr.Cast(); - value = 
constant.value; - return true; - } - case ExpressionType::OPERATOR_CAST: { - auto &cast = expr.Cast(); - Value dummy_value; - if (!ConstructConstantFromExpression(*cast.child, dummy_value)) { - return false; - } - - string error_message; - if (!dummy_value.DefaultTryCastAs(cast.cast_type, value, &error_message)) { - throw ConversionException("Unable to cast %s to %s", dummy_value.ToString(), - EnumUtil::ToString(cast.cast_type.id())); - } - return true; - } - default: - return false; - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformResTarget(duckdb_libpgquery::PGResTarget &root) { - auto expr = TransformExpression(root.val); - if (!expr) { - return nullptr; - } - if (root.name) { - expr->alias = string(root.name); - } - return expr; -} - -unique_ptr Transformer::TransformNamedArg(duckdb_libpgquery::PGNamedArgExpr &root) { - - auto expr = TransformExpression(PGPointerCast(root.arg)); - if (root.name) { - expr->alias = string(root.name); - } - return expr; -} - -unique_ptr Transformer::TransformExpression(duckdb_libpgquery::PGNode &node) { - - auto stack_checker = StackCheck(); - - switch (node.type) { - case duckdb_libpgquery::T_PGColumnRef: - return TransformColumnRef(PGCast(node)); - case duckdb_libpgquery::T_PGAConst: - return TransformConstant(PGCast(node)); - case duckdb_libpgquery::T_PGAExpr: - return TransformAExpr(PGCast(node)); - case duckdb_libpgquery::T_PGFuncCall: - return TransformFuncCall(PGCast(node)); - case duckdb_libpgquery::T_PGBoolExpr: - return TransformBoolExpr(PGCast(node)); - case duckdb_libpgquery::T_PGTypeCast: - return TransformTypeCast(PGCast(node)); - case duckdb_libpgquery::T_PGCaseExpr: - return TransformCase(PGCast(node)); - case duckdb_libpgquery::T_PGSubLink: - return TransformSubquery(PGCast(node)); - case duckdb_libpgquery::T_PGCoalesceExpr: - return TransformCoalesce(PGCast(node)); - case duckdb_libpgquery::T_PGNullTest: - return TransformNullTest(PGCast(node)); - case 
duckdb_libpgquery::T_PGResTarget: - return TransformResTarget(PGCast(node)); - case duckdb_libpgquery::T_PGParamRef: - return TransformParamRef(PGCast(node)); - case duckdb_libpgquery::T_PGNamedArgExpr: - return TransformNamedArg(PGCast(node)); - case duckdb_libpgquery::T_PGSQLValueFunction: - return TransformSQLValueFunction(PGCast(node)); - case duckdb_libpgquery::T_PGSetToDefault: - return make_uniq(); - case duckdb_libpgquery::T_PGCollateClause: - return TransformCollateExpr(PGCast(node)); - case duckdb_libpgquery::T_PGIntervalConstant: - return TransformInterval(PGCast(node)); - case duckdb_libpgquery::T_PGLambdaFunction: - return TransformLambda(PGCast(node)); - case duckdb_libpgquery::T_PGAIndirection: - return TransformArrayAccess(PGCast(node)); - case duckdb_libpgquery::T_PGPositionalReference: - return TransformPositionalReference(PGCast(node)); - case duckdb_libpgquery::T_PGGroupingFunc: - return TransformGroupingFunction(PGCast(node)); - case duckdb_libpgquery::T_PGAStar: - return TransformStarExpression(PGCast(node)); - case duckdb_libpgquery::T_PGBooleanTest: - return TransformBooleanTest(PGCast(node)); - case duckdb_libpgquery::T_PGMultiAssignRef: - return TransformMultiAssignRef(PGCast(node)); - - default: - throw NotImplementedException("Expression type %s (%d)", NodetypeToString(node.type), (int)node.type); - } -} - -unique_ptr Transformer::TransformExpression(optional_ptr node) { - if (!node) { - return nullptr; - } - return TransformExpression(*node); -} - -void Transformer::TransformExpressionList(duckdb_libpgquery::PGList &list, - vector> &result) { - for (auto node = list.head; node != nullptr; node = node->next) { - auto target = PGPointerCast(node->data.ptr_value); - - auto expr = TransformExpression(*target); - result.push_back(std::move(expr)); - } -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -void Transformer::TransformWindowDef(duckdb_libpgquery::PGWindowDef &window_spec, WindowExpression &expr, - const 
char *window_name) { - // next: partitioning/ordering expressions - if (window_spec.partitionClause) { - if (window_name && !expr.partitions.empty()) { - throw ParserException("Cannot override PARTITION BY clause of window \"%s\"", window_name); - } - TransformExpressionList(*window_spec.partitionClause, expr.partitions); - } - if (window_spec.orderClause) { - if (window_name && !expr.orders.empty()) { - throw ParserException("Cannot override ORDER BY clause of window \"%s\"", window_name); - } - TransformOrderBy(window_spec.orderClause, expr.orders); - } -} - -void Transformer::TransformWindowFrame(duckdb_libpgquery::PGWindowDef &window_spec, WindowExpression &expr) { - // finally: specifics of bounds - expr.start_expr = TransformExpression(window_spec.startOffset); - expr.end_expr = TransformExpression(window_spec.endOffset); - - if ((window_spec.frameOptions & FRAMEOPTION_END_UNBOUNDED_PRECEDING) || - (window_spec.frameOptions & FRAMEOPTION_START_UNBOUNDED_FOLLOWING)) { - throw InternalException( - "Window frames starting with unbounded following or ending in unbounded preceding make no sense"); - } - - const bool rangeMode = (window_spec.frameOptions & FRAMEOPTION_RANGE) != 0; - if (window_spec.frameOptions & FRAMEOPTION_START_UNBOUNDED_PRECEDING) { - expr.start = WindowBoundary::UNBOUNDED_PRECEDING; - } else if (window_spec.frameOptions & FRAMEOPTION_START_VALUE_PRECEDING) { - expr.start = rangeMode ? WindowBoundary::EXPR_PRECEDING_RANGE : WindowBoundary::EXPR_PRECEDING_ROWS; - } else if (window_spec.frameOptions & FRAMEOPTION_START_VALUE_FOLLOWING) { - expr.start = rangeMode ? WindowBoundary::EXPR_FOLLOWING_RANGE : WindowBoundary::EXPR_FOLLOWING_ROWS; - } else if (window_spec.frameOptions & FRAMEOPTION_START_CURRENT_ROW) { - expr.start = rangeMode ? 
WindowBoundary::CURRENT_ROW_RANGE : WindowBoundary::CURRENT_ROW_ROWS; - } - - if (window_spec.frameOptions & FRAMEOPTION_END_UNBOUNDED_FOLLOWING) { - expr.end = WindowBoundary::UNBOUNDED_FOLLOWING; - } else if (window_spec.frameOptions & FRAMEOPTION_END_VALUE_PRECEDING) { - expr.end = rangeMode ? WindowBoundary::EXPR_PRECEDING_RANGE : WindowBoundary::EXPR_PRECEDING_ROWS; - } else if (window_spec.frameOptions & FRAMEOPTION_END_VALUE_FOLLOWING) { - expr.end = rangeMode ? WindowBoundary::EXPR_FOLLOWING_RANGE : WindowBoundary::EXPR_FOLLOWING_ROWS; - } else if (window_spec.frameOptions & FRAMEOPTION_END_CURRENT_ROW) { - expr.end = rangeMode ? WindowBoundary::CURRENT_ROW_RANGE : WindowBoundary::CURRENT_ROW_ROWS; - } - - D_ASSERT(expr.start != WindowBoundary::INVALID && expr.end != WindowBoundary::INVALID); - if (((window_spec.frameOptions & (FRAMEOPTION_START_VALUE_PRECEDING | FRAMEOPTION_START_VALUE_FOLLOWING)) && - !expr.start_expr) || - ((window_spec.frameOptions & (FRAMEOPTION_END_VALUE_PRECEDING | FRAMEOPTION_END_VALUE_FOLLOWING)) && - !expr.end_expr)) { - throw InternalException("Failed to transform window boundary expression"); - } -} - -bool Transformer::ExpressionIsEmptyStar(ParsedExpression &expr) { - if (expr.expression_class != ExpressionClass::STAR) { - return false; - } - auto &star = expr.Cast(); - if (!star.columns && star.exclude_list.empty() && star.replace_list.empty()) { - return true; - } - return false; -} - -bool Transformer::InWindowDefinition() { - if (in_window_definition) { - return true; - } - if (parent) { - return parent->InWindowDefinition(); - } - return false; -} - -unique_ptr Transformer::TransformFuncCall(duckdb_libpgquery::PGFuncCall &root) { - auto name = root.funcname; - string catalog, schema, function_name; - if (name->length == 3) { - // catalog + schema + name - catalog = PGPointerCast(name->head->data.ptr_value)->val.str; - schema = PGPointerCast(name->head->next->data.ptr_value)->val.str; - function_name = 
PGPointerCast(name->head->next->next->data.ptr_value)->val.str; - } else if (name->length == 2) { - // schema + name - catalog = INVALID_CATALOG; - schema = PGPointerCast(name->head->data.ptr_value)->val.str; - function_name = PGPointerCast(name->head->next->data.ptr_value)->val.str; - } else if (name->length == 1) { - // unqualified name - catalog = INVALID_CATALOG; - schema = INVALID_SCHEMA; - function_name = PGPointerCast(name->head->data.ptr_value)->val.str; - } else { - throw ParserException("TransformFuncCall - Expected 1, 2 or 3 qualifications"); - } - - // transform children - vector> children; - if (root.args) { - TransformExpressionList(*root.args, children); - } - if (children.size() == 1 && ExpressionIsEmptyStar(*children[0]) && !root.agg_distinct && !root.agg_order) { - // COUNT(*) gets translated into COUNT() - children.clear(); - } - - auto lowercase_name = StringUtil::Lower(function_name); - if (root.over) { - if (InWindowDefinition()) { - throw ParserException("window functions are not allowed in window definitions"); - } - - const auto win_fun_type = WindowExpression::WindowToExpressionType(lowercase_name); - if (win_fun_type == ExpressionType::INVALID) { - throw InternalException("Unknown/unsupported window function"); - } - - if (root.agg_distinct) { - throw ParserException("DISTINCT is not implemented for window functions!"); - } - - if (root.agg_order) { - throw ParserException("ORDER BY is not implemented for window functions!"); - } - - if (win_fun_type != ExpressionType::WINDOW_AGGREGATE && root.agg_filter) { - throw ParserException("FILTER is not implemented for non-aggregate window functions!"); - } - if (root.export_state) { - throw ParserException("EXPORT_STATE is not supported for window functions!"); - } - - if (win_fun_type == ExpressionType::WINDOW_AGGREGATE && root.agg_ignore_nulls) { - throw ParserException("IGNORE NULLS is not supported for windowed aggregates"); - } - - auto expr = make_uniq(win_fun_type, std::move(catalog), 
std::move(schema), lowercase_name); - expr->ignore_nulls = root.agg_ignore_nulls; - - if (root.agg_filter) { - auto filter_expr = TransformExpression(root.agg_filter); - expr->filter_expr = std::move(filter_expr); - } - - if (win_fun_type == ExpressionType::WINDOW_AGGREGATE) { - expr->children = std::move(children); - } else { - if (!children.empty()) { - expr->children.push_back(std::move(children[0])); - } - if (win_fun_type == ExpressionType::WINDOW_LEAD || win_fun_type == ExpressionType::WINDOW_LAG) { - if (children.size() > 1) { - expr->offset_expr = std::move(children[1]); - } - if (children.size() > 2) { - expr->default_expr = std::move(children[2]); - } - if (children.size() > 3) { - throw ParserException("Incorrect number of parameters for function %s", lowercase_name); - } - } else if (win_fun_type == ExpressionType::WINDOW_NTH_VALUE) { - if (children.size() > 1) { - expr->children.push_back(std::move(children[1])); - } - if (children.size() > 2) { - throw ParserException("Incorrect number of parameters for function %s", lowercase_name); - } - } else { - if (children.size() > 1) { - throw ParserException("Incorrect number of parameters for function %s", lowercase_name); - } - } - } - auto window_spec = PGPointerCast(root.over); - if (window_spec->name) { - auto it = window_clauses.find(StringUtil::Lower(string(window_spec->name))); - if (it == window_clauses.end()) { - throw ParserException("window \"%s\" does not exist", window_spec->name); - } - window_spec = it->second; - D_ASSERT(window_spec); - } - auto window_ref = window_spec; - auto window_name = window_ref->refname; - if (window_ref->refname) { - auto it = window_clauses.find(StringUtil::Lower(string(window_spec->refname))); - if (it == window_clauses.end()) { - throw ParserException("window \"%s\" does not exist", window_spec->refname); - } - window_ref = it->second; - D_ASSERT(window_ref); - if (window_ref->startOffset || window_ref->endOffset || window_ref->frameOptions != 
FRAMEOPTION_DEFAULTS) { - throw ParserException("cannot copy window \"%s\" because it has a frame clause", window_spec->refname); - } - } - in_window_definition = true; - TransformWindowDef(*window_ref, *expr); - if (window_ref != window_spec) { - TransformWindowDef(*window_spec, *expr, window_name); - } - TransformWindowFrame(*window_spec, *expr); - in_window_definition = false; - expr->query_location = root.location; - return std::move(expr); - } - - if (root.agg_ignore_nulls) { - throw ParserException("IGNORE NULLS is not supported for non-window functions"); - } - - unique_ptr filter_expr; - if (root.agg_filter) { - filter_expr = TransformExpression(root.agg_filter); - } - - auto order_bys = make_uniq(); - TransformOrderBy(root.agg_order, order_bys->orders); - - // Ordered aggregates can be either WITHIN GROUP or after the function arguments - if (root.agg_within_group) { - // https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-ORDEREDSET-TABLE - // Since we implement "ordered aggregates" without sorting, - // we map all the ones we support to the corresponding aggregate function. 
- if (order_bys->orders.size() != 1) { - throw ParserException("Cannot use multiple ORDER BY clauses with WITHIN GROUP"); - } - if (lowercase_name == "percentile_cont") { - if (children.size() != 1) { - throw ParserException("Wrong number of arguments for PERCENTILE_CONT"); - } - lowercase_name = "quantile_cont"; - } else if (lowercase_name == "percentile_disc") { - if (children.size() != 1) { - throw ParserException("Wrong number of arguments for PERCENTILE_DISC"); - } - lowercase_name = "quantile_disc"; - } else if (lowercase_name == "mode") { - if (!children.empty()) { - throw ParserException("Wrong number of arguments for MODE"); - } - lowercase_name = "mode"; - } else { - throw ParserException("Unknown ordered aggregate \"%s\".", function_name); - } - } - - // star gets eaten in the parser - if (lowercase_name == "count" && children.empty()) { - lowercase_name = "count_star"; - } - - if (lowercase_name == "if") { - if (children.size() != 3) { - throw ParserException("Wrong number of arguments to IF."); - } - auto expr = make_uniq(); - CaseCheck check; - check.when_expr = std::move(children[0]); - check.then_expr = std::move(children[1]); - expr->case_checks.push_back(std::move(check)); - expr->else_expr = std::move(children[2]); - return std::move(expr); - } else if (lowercase_name == "construct_array") { - auto construct_array = make_uniq(ExpressionType::ARRAY_CONSTRUCTOR); - construct_array->children = std::move(children); - return std::move(construct_array); - } else if (lowercase_name == "ifnull") { - if (children.size() != 2) { - throw ParserException("Wrong number of arguments to IFNULL."); - } - - // Two-argument COALESCE - auto coalesce_op = make_uniq(ExpressionType::OPERATOR_COALESCE); - coalesce_op->children.push_back(std::move(children[0])); - coalesce_op->children.push_back(std::move(children[1])); - return std::move(coalesce_op); - } else if (lowercase_name == "list" && order_bys->orders.size() == 1) { - // list(expr ORDER BY expr ) => 
list_sort(list(expr), , ) - if (children.size() != 1) { - throw ParserException("Wrong number of arguments to LIST."); - } - auto arg_expr = children[0].get(); - auto &order_by = order_bys->orders[0]; - if (arg_expr->Equals(*order_by.expression)) { - auto sense = make_uniq(EnumUtil::ToChars(order_by.type)); - auto nulls = make_uniq(EnumUtil::ToChars(order_by.null_order)); - order_bys = nullptr; - auto unordered = make_uniq(catalog, schema, lowercase_name.c_str(), std::move(children), - std::move(filter_expr), std::move(order_bys), - root.agg_distinct, false, root.export_state); - lowercase_name = "list_sort"; - order_bys.reset(); // NOLINT - filter_expr.reset(); // NOLINT - children.clear(); // NOLINT - root.agg_distinct = false; - children.emplace_back(std::move(unordered)); - children.emplace_back(std::move(sense)); - children.emplace_back(std::move(nulls)); - } - } - - auto function = make_uniq(std::move(catalog), std::move(schema), lowercase_name.c_str(), - std::move(children), std::move(filter_expr), std::move(order_bys), - root.agg_distinct, false, root.export_state); - function->query_location = root.location; - - return std::move(function); -} - -unique_ptr Transformer::TransformSQLValueFunction(duckdb_libpgquery::PGSQLValueFunction &node) { - throw InternalException("SQL value functions should not be emitted by the parser"); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformGroupingFunction(duckdb_libpgquery::PGGroupingFunc &grouping) { - auto op = make_uniq(ExpressionType::GROUPING_FUNCTION); - for (auto node = grouping.args->head; node; node = node->next) { - auto n = PGPointerCast(node->data.ptr_value); - op->children.push_back(TransformExpression(n)); - } - op->query_location = grouping.location; - return std::move(op); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformInterval(duckdb_libpgquery::PGIntervalConstant &node) { - // handle post-fix notation of INTERVAL - 
- // three scenarios - // interval (expr) year - // interval 'string' year - // interval int year - unique_ptr expr; - switch (node.val_type) { - case duckdb_libpgquery::T_PGAExpr: - expr = TransformExpression(node.eval); - break; - case duckdb_libpgquery::T_PGString: - expr = make_uniq(Value(node.sval)); - break; - case duckdb_libpgquery::T_PGInteger: - expr = make_uniq(Value(node.ival)); - break; - default: - throw InternalException("Unsupported interval transformation"); - } - - if (!node.typmods) { - return make_uniq(LogicalType::INTERVAL, std::move(expr)); - } - - int32_t mask = PGPointerCast(node.typmods->head->data.ptr_value)->val.val.ival; - // these seemingly random constants are from datetime.hpp - // they are copied here to avoid having to include this header - // the bitshift is from the function INTERVAL_MASK in the parser - constexpr int32_t MONTH_MASK = 1 << 1; - constexpr int32_t YEAR_MASK = 1 << 2; - constexpr int32_t DAY_MASK = 1 << 3; - constexpr int32_t HOUR_MASK = 1 << 10; - constexpr int32_t MINUTE_MASK = 1 << 11; - constexpr int32_t SECOND_MASK = 1 << 12; - constexpr int32_t MILLISECOND_MASK = 1 << 13; - constexpr int32_t MICROSECOND_MASK = 1 << 14; - - // we need to check certain combinations - // because certain interval masks (e.g. 
INTERVAL '10' HOURS TO DAYS) set multiple bits - // for now we don't support all of the combined ones - // (we might add support if someone complains about it) - - string fname; - LogicalType target_type; - if (mask & YEAR_MASK && mask & MONTH_MASK) { - // DAY TO HOUR - throw ParserException("YEAR TO MONTH is not supported"); - } else if (mask & DAY_MASK && mask & HOUR_MASK) { - // DAY TO HOUR - throw ParserException("DAY TO HOUR is not supported"); - } else if (mask & DAY_MASK && mask & MINUTE_MASK) { - // DAY TO MINUTE - throw ParserException("DAY TO MINUTE is not supported"); - } else if (mask & DAY_MASK && mask & SECOND_MASK) { - // DAY TO SECOND - throw ParserException("DAY TO SECOND is not supported"); - } else if (mask & HOUR_MASK && mask & MINUTE_MASK) { - // DAY TO SECOND - throw ParserException("HOUR TO MINUTE is not supported"); - } else if (mask & HOUR_MASK && mask & SECOND_MASK) { - // DAY TO SECOND - throw ParserException("HOUR TO SECOND is not supported"); - } else if (mask & MINUTE_MASK && mask & SECOND_MASK) { - // DAY TO SECOND - throw ParserException("MINUTE TO SECOND is not supported"); - } else if (mask & YEAR_MASK) { - // YEAR - fname = "to_years"; - target_type = LogicalType::INTEGER; - } else if (mask & MONTH_MASK) { - // MONTH - fname = "to_months"; - target_type = LogicalType::INTEGER; - } else if (mask & DAY_MASK) { - // DAY - fname = "to_days"; - target_type = LogicalType::INTEGER; - } else if (mask & HOUR_MASK) { - // HOUR - fname = "to_hours"; - target_type = LogicalType::BIGINT; - } else if (mask & MINUTE_MASK) { - // MINUTE - fname = "to_minutes"; - target_type = LogicalType::BIGINT; - } else if (mask & SECOND_MASK) { - // SECOND - fname = "to_seconds"; - target_type = LogicalType::BIGINT; - } else if (mask & MILLISECOND_MASK) { - // MILLISECOND - fname = "to_milliseconds"; - target_type = LogicalType::BIGINT; - } else if (mask & MICROSECOND_MASK) { - // SECOND - fname = "to_microseconds"; - target_type = LogicalType::BIGINT; - } 
else { - throw InternalException("Unsupported interval post-fix"); - } - // first push a cast to the target type - expr = make_uniq(target_type, std::move(expr)); - // now push the operation - vector> children; - children.push_back(std::move(expr)); - return make_uniq(fname, std::move(children)); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformNullTest(duckdb_libpgquery::PGNullTest &root) { - auto arg = TransformExpression(PGPointerCast(root.arg)); - if (root.argisrow) { - throw NotImplementedException("IS NULL argisrow"); - } - ExpressionType expr_type = (root.nulltesttype == duckdb_libpgquery::PG_IS_NULL) - ? ExpressionType::OPERATOR_IS_NULL - : ExpressionType::OPERATOR_IS_NOT_NULL; - - return unique_ptr(new OperatorExpression(expr_type, std::move(arg))); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformLambda(duckdb_libpgquery::PGLambdaFunction &node) { - D_ASSERT(node.lhs); - D_ASSERT(node.rhs); - - auto lhs = TransformExpression(node.lhs); - auto rhs = TransformExpression(node.rhs); - D_ASSERT(lhs); - D_ASSERT(rhs); - return make_uniq(std::move(lhs), std::move(rhs)); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformMultiAssignRef(duckdb_libpgquery::PGMultiAssignRef &root) { - // Multi assignment for the ROW function - if (root.source->type == duckdb_libpgquery::T_PGFuncCall) { - auto func = PGCast(*root.source); - - // Explicitly only allow ROW function - char const *function_name = - PGPointerCast(func.funcname->tail->data.ptr_value)->val.str; - if (function_name == nullptr || strlen(function_name) != 3 || strncmp(function_name, "row", 3) != 0) { - return TransformExpression(root.source); - } - - // Too many columns (ie. 
(x, y) = (1, 2, 3) ) - if (root.ncolumns < func.args->length) { - throw ParserException( - "Could not perform multiple assignment, target only expects %d values, %d were provided", root.ncolumns, - func.args->length); - } - - // Get the expression corresponding with the current column - idx_t idx = 1; - auto list = func.args->head; - while (list && idx < static_cast(root.colno)) { - list = list->next; - ++idx; - } - - // Not enough columns (ie. (x, y, z) = (1, 2) ) - if (!list) { - throw ParserException( - "Could not perform multiple assignment, target expects %d values, only %d were provided", root.ncolumns, - func.args->length); - } - return TransformExpression(reinterpret_cast(list->data.ptr_value)); - } - return TransformExpression(root.source); -} - -} // namespace duckdb - - - - - - - - - - - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformUnaryOperator(const string &op, unique_ptr child) { - vector> children; - children.push_back(std::move(child)); - - // built-in operator function - auto result = make_uniq(op, std::move(children)); - result->is_operator = true; - return std::move(result); -} - -unique_ptr Transformer::TransformBinaryOperator(string op, unique_ptr left, - unique_ptr right) { - vector> children; - children.push_back(std::move(left)); - children.push_back(std::move(right)); - - if (options.integer_division && op == "/") { - op = "//"; - } - if (op == "~" || op == "!~") { - // rewrite 'asdf' SIMILAR TO '.*sd.*' into regexp_full_match('asdf', '.*sd.*') - bool invert_similar = op == "!~"; - - auto result = make_uniq("regexp_full_match", std::move(children)); - if (invert_similar) { - return make_uniq(ExpressionType::OPERATOR_NOT, std::move(result)); - } else { - return std::move(result); - } - } else { - auto target_type = OperatorToExpressionType(op); - if (target_type != ExpressionType::INVALID) { - // built-in comparison operator - return make_uniq(target_type, std::move(children[0]), std::move(children[1])); - } - // not a 
special operator: convert to a function expression - auto result = make_uniq(std::move(op), std::move(children)); - result->is_operator = true; - return std::move(result); - } -} - -unique_ptr Transformer::TransformAExprInternal(duckdb_libpgquery::PGAExpr &root) { - auto name = string(PGPointerCast(root.name->head->data.ptr_value)->val.str); - - switch (root.kind) { - case duckdb_libpgquery::PG_AEXPR_OP_ALL: - case duckdb_libpgquery::PG_AEXPR_OP_ANY: { - // left=ANY(right) - // we turn this into left=ANY((SELECT UNNEST(right))) - auto left_expr = TransformExpression(root.lexpr); - auto right_expr = TransformExpression(root.rexpr); - - auto subquery_expr = make_uniq(); - auto select_statement = make_uniq(); - auto select_node = make_uniq(); - vector> children; - children.push_back(std::move(right_expr)); - - select_node->select_list.push_back(make_uniq("UNNEST", std::move(children))); - select_node->from_table = make_uniq(); - select_statement->node = std::move(select_node); - subquery_expr->subquery = std::move(select_statement); - subquery_expr->subquery_type = SubqueryType::ANY; - subquery_expr->child = std::move(left_expr); - subquery_expr->comparison_type = OperatorToExpressionType(name); - subquery_expr->query_location = root.location; - if (subquery_expr->comparison_type == ExpressionType::INVALID) { - throw ParserException("Unsupported comparison \"%s\" for ANY/ALL subquery", name); - } - - if (root.kind == duckdb_libpgquery::PG_AEXPR_OP_ALL) { - // ALL sublink is equivalent to NOT(ANY) with inverted comparison - // e.g. 
[= ALL()] is equivalent to [NOT(<> ANY())] - // first invert the comparison type - subquery_expr->comparison_type = NegateComparisonExpression(subquery_expr->comparison_type); - return make_uniq(ExpressionType::OPERATOR_NOT, std::move(subquery_expr)); - } - return std::move(subquery_expr); - } - case duckdb_libpgquery::PG_AEXPR_IN: { - auto left_expr = TransformExpression(root.lexpr); - ExpressionType operator_type; - // this looks very odd, but seems to be the way to find out its NOT IN - if (name == "<>") { - // NOT IN - operator_type = ExpressionType::COMPARE_NOT_IN; - } else { - // IN - operator_type = ExpressionType::COMPARE_IN; - } - auto result = make_uniq(operator_type, std::move(left_expr)); - result->query_location = root.location; - TransformExpressionList(*PGPointerCast(root.rexpr), result->children); - return std::move(result); - } - // rewrite NULLIF(a, b) into CASE WHEN a=b THEN NULL ELSE a END - case duckdb_libpgquery::PG_AEXPR_NULLIF: { - vector> children; - children.push_back(TransformExpression(root.lexpr)); - children.push_back(TransformExpression(root.rexpr)); - return make_uniq("nullif", std::move(children)); - } - // rewrite (NOT) X BETWEEN A AND B into (NOT) AND(GREATERTHANOREQUALTO(X, - // A), LESSTHANOREQUALTO(X, B)) - case duckdb_libpgquery::PG_AEXPR_BETWEEN: - case duckdb_libpgquery::PG_AEXPR_NOT_BETWEEN: { - auto between_args = PGPointerCast(root.rexpr); - if (between_args->length != 2 || !between_args->head->data.ptr_value || !between_args->tail->data.ptr_value) { - throw InternalException("(NOT) BETWEEN needs two args"); - } - - auto input = TransformExpression(root.lexpr); - auto between_left = - TransformExpression(PGPointerCast(between_args->head->data.ptr_value)); - auto between_right = - TransformExpression(PGPointerCast(between_args->tail->data.ptr_value)); - - auto compare_between = - make_uniq(std::move(input), std::move(between_left), std::move(between_right)); - if (root.kind == duckdb_libpgquery::PG_AEXPR_BETWEEN) { - 
return std::move(compare_between); - } else { - return make_uniq(ExpressionType::OPERATOR_NOT, std::move(compare_between)); - } - } - // rewrite SIMILAR TO into regexp_full_match('asdf', '.*sd.*') - case duckdb_libpgquery::PG_AEXPR_SIMILAR: { - auto left_expr = TransformExpression(root.lexpr); - auto right_expr = TransformExpression(root.rexpr); - - vector> children; - children.push_back(std::move(left_expr)); - - auto &similar_func = right_expr->Cast(); - D_ASSERT(similar_func.function_name == "similar_escape"); - D_ASSERT(similar_func.children.size() == 2); - if (similar_func.children[1]->type != ExpressionType::VALUE_CONSTANT) { - throw NotImplementedException("Custom escape in SIMILAR TO"); - } - auto &constant = similar_func.children[1]->Cast(); - if (!constant.value.IsNull()) { - throw NotImplementedException("Custom escape in SIMILAR TO"); - } - // take the child of the similar_func - children.push_back(std::move(similar_func.children[0])); - - // this looks very odd, but seems to be the way to find out its NOT IN - bool invert_similar = false; - if (name == "!~") { - // NOT SIMILAR TO - invert_similar = true; - } - const auto regex_function = "regexp_full_match"; - auto result = make_uniq(regex_function, std::move(children)); - - if (invert_similar) { - return make_uniq(ExpressionType::OPERATOR_NOT, std::move(result)); - } else { - return std::move(result); - } - } - case duckdb_libpgquery::PG_AEXPR_NOT_DISTINCT: { - auto left_expr = TransformExpression(root.lexpr); - auto right_expr = TransformExpression(root.rexpr); - return make_uniq(ExpressionType::COMPARE_NOT_DISTINCT_FROM, std::move(left_expr), - std::move(right_expr)); - } - case duckdb_libpgquery::PG_AEXPR_DISTINCT: { - auto left_expr = TransformExpression(root.lexpr); - auto right_expr = TransformExpression(root.rexpr); - return make_uniq(ExpressionType::COMPARE_DISTINCT_FROM, std::move(left_expr), - std::move(right_expr)); - } - - default: - break; - } - auto left_expr = 
TransformExpression(root.lexpr); - auto right_expr = TransformExpression(root.rexpr); - - if (!left_expr) { - // prefix operator - return TransformUnaryOperator(name, std::move(right_expr)); - } else if (!right_expr) { - // postfix operator, only ! is currently supported - return TransformUnaryOperator(name + "__postfix", std::move(left_expr)); - } else { - return TransformBinaryOperator(std::move(name), std::move(left_expr), std::move(right_expr)); - } -} - -unique_ptr Transformer::TransformAExpr(duckdb_libpgquery::PGAExpr &root) { - auto result = TransformAExprInternal(root); - if (result) { - result->query_location = root.location; - } - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -namespace { - -struct PreparedParam { - PreparedParamType type; - string identifier; -}; - -} // namespace - -static PreparedParam GetParameterIdentifier(duckdb_libpgquery::PGParamRef &node) { - PreparedParam param; - if (node.name) { - param.type = PreparedParamType::NAMED; - param.identifier = node.name; - return param; - } - if (node.number < 0) { - throw ParserException("Parameter numbers cannot be negative"); - } - param.identifier = StringUtil::Format("%d", node.number); - param.type = node.number == 0 ? 
PreparedParamType::AUTO_INCREMENT : PreparedParamType::POSITIONAL; - return param; -} - -unique_ptr Transformer::TransformParamRef(duckdb_libpgquery::PGParamRef &node) { - auto expr = make_uniq(); - - auto param = GetParameterIdentifier(node); - idx_t known_param_index = DConstants::INVALID_INDEX; - // This is a named parameter, try to find an entry for it - GetParam(param.identifier, known_param_index, param.type); - - if (known_param_index == DConstants::INVALID_INDEX) { - // We have not seen this parameter before - if (node.number != 0) { - // Preserve the parameter number - known_param_index = node.number; - } else { - known_param_index = ParamCount() + 1; - if (!node.name) { - param.identifier = StringUtil::Format("%d", known_param_index); - } - } - - if (!named_param_map.count(param.identifier)) { - // Add it to the named parameter map so we can find it next time it's referenced - SetParam(param.identifier, known_param_index, param.type); - } - } - - expr->identifier = param.identifier; - idx_t new_param_count = MaxValue(ParamCount(), known_param_index); - SetParamCount(new_param_count); - return std::move(expr); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformPositionalReference(duckdb_libpgquery::PGPositionalReference &node) { - if (node.position <= 0) { - throw ParserException("Positional reference node needs to be >= 1"); - } - auto result = make_uniq(node.position); - result->query_location = node.location; - return std::move(result); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformSubquery(duckdb_libpgquery::PGSubLink &root) { - auto subquery_expr = make_uniq(); - - subquery_expr->subquery = TransformSelect(root.subselect); - D_ASSERT(subquery_expr->subquery); - D_ASSERT(subquery_expr->subquery->node->GetSelectList().size() > 0); - - switch (root.subLinkType) { - case duckdb_libpgquery::PG_EXISTS_SUBLINK: { - subquery_expr->subquery_type = SubqueryType::EXISTS; - 
break; - } - case duckdb_libpgquery::PG_ANY_SUBLINK: - case duckdb_libpgquery::PG_ALL_SUBLINK: { - // comparison with ANY() or ALL() - subquery_expr->subquery_type = SubqueryType::ANY; - subquery_expr->child = TransformExpression(root.testexpr); - // get the operator name - if (!root.operName) { - // simple IN - subquery_expr->comparison_type = ExpressionType::COMPARE_EQUAL; - } else { - auto operator_name = - string((PGPointerCast(root.operName->head->data.ptr_value))->val.str); - subquery_expr->comparison_type = OperatorToExpressionType(operator_name); - } - if (subquery_expr->comparison_type != ExpressionType::COMPARE_EQUAL && - subquery_expr->comparison_type != ExpressionType::COMPARE_NOTEQUAL && - subquery_expr->comparison_type != ExpressionType::COMPARE_GREATERTHAN && - subquery_expr->comparison_type != ExpressionType::COMPARE_GREATERTHANOREQUALTO && - subquery_expr->comparison_type != ExpressionType::COMPARE_LESSTHAN && - subquery_expr->comparison_type != ExpressionType::COMPARE_LESSTHANOREQUALTO) { - throw ParserException("ANY and ALL operators require one of =,<>,>,<,>=,<= comparisons!"); - } - if (root.subLinkType == duckdb_libpgquery::PG_ALL_SUBLINK) { - // ALL sublink is equivalent to NOT(ANY) with inverted comparison - // e.g. [= ALL()] is equivalent to [NOT(<> ANY())] - // first invert the comparison type - subquery_expr->comparison_type = NegateComparisonExpression(subquery_expr->comparison_type); - return make_uniq(ExpressionType::OPERATOR_NOT, std::move(subquery_expr)); - } - break; - } - case duckdb_libpgquery::PG_EXPR_SUBLINK: { - // return a single scalar value from the subquery - // no child expression to compare to - subquery_expr->subquery_type = SubqueryType::SCALAR; - break; - } - case duckdb_libpgquery::PG_ARRAY_SUBLINK: { - auto subquery_table_alias = "__subquery"; - auto subquery_column_alias = "__arr_element"; - - // ARRAY expression - // wrap subquery into "SELECT CASE WHEN ARRAY_AGG(i) IS NULL THEN [] ELSE ARRAY_AGG(i) END FROM (...) 
tbl(i)" - auto select_node = make_uniq(); - - // ARRAY_AGG(i) - vector> children; - children.push_back( - make_uniq_base(subquery_column_alias, subquery_table_alias)); - auto aggr = make_uniq("array_agg", std::move(children)); - // ARRAY_AGG(i) IS NULL - auto agg_is_null = make_uniq(ExpressionType::OPERATOR_IS_NULL, aggr->Copy()); - // empty list - vector> list_children; - auto empty_list = make_uniq("list_value", std::move(list_children)); - // CASE - auto case_expr = make_uniq(); - CaseCheck check; - check.when_expr = std::move(agg_is_null); - check.then_expr = std::move(empty_list); - case_expr->case_checks.push_back(std::move(check)); - case_expr->else_expr = std::move(aggr); - - select_node->select_list.push_back(std::move(case_expr)); - - // FROM (...) tbl(i) - auto child_subquery = make_uniq(std::move(subquery_expr->subquery), subquery_table_alias); - child_subquery->column_name_alias.emplace_back(subquery_column_alias); - select_node->from_table = std::move(child_subquery); - - auto new_subquery = make_uniq(); - new_subquery->node = std::move(select_node); - subquery_expr->subquery = std::move(new_subquery); - - subquery_expr->subquery_type = SubqueryType::SCALAR; - break; - } - default: - throw NotImplementedException("Subquery of type %d not implemented\n", (int)root.subLinkType); - } - subquery_expr->query_location = root.location; - return std::move(subquery_expr); -} - -} // namespace duckdb - - -namespace duckdb { - -std::string Transformer::NodetypeToString(duckdb_libpgquery::PGNodeTag type) { // LCOV_EXCL_START - switch (type) { - case duckdb_libpgquery::T_PGInvalid: - return "T_Invalid"; - case duckdb_libpgquery::T_PGIndexInfo: - return "T_IndexInfo"; - case duckdb_libpgquery::T_PGExprContext: - return "T_ExprContext"; - case duckdb_libpgquery::T_PGProjectionInfo: - return "T_ProjectionInfo"; - case duckdb_libpgquery::T_PGJunkFilter: - return "T_JunkFilter"; - case duckdb_libpgquery::T_PGResultRelInfo: - return "T_ResultRelInfo"; - case 
duckdb_libpgquery::T_PGEState: - return "T_EState"; - case duckdb_libpgquery::T_PGTupleTableSlot: - return "T_TupleTableSlot"; - case duckdb_libpgquery::T_PGPlan: - return "T_Plan"; - case duckdb_libpgquery::T_PGResult: - return "T_Result"; - case duckdb_libpgquery::T_PGProjectSet: - return "T_ProjectSet"; - case duckdb_libpgquery::T_PGModifyTable: - return "T_ModifyTable"; - case duckdb_libpgquery::T_PGAppend: - return "T_Append"; - case duckdb_libpgquery::T_PGMergeAppend: - return "T_MergeAppend"; - case duckdb_libpgquery::T_PGRecursiveUnion: - return "T_RecursiveUnion"; - case duckdb_libpgquery::T_PGBitmapAnd: - return "T_BitmapAnd"; - case duckdb_libpgquery::T_PGBitmapOr: - return "T_BitmapOr"; - case duckdb_libpgquery::T_PGScan: - return "T_Scan"; - case duckdb_libpgquery::T_PGSeqScan: - return "T_SeqScan"; - case duckdb_libpgquery::T_PGSampleScan: - return "T_SampleScan"; - case duckdb_libpgquery::T_PGIndexScan: - return "T_IndexScan"; - case duckdb_libpgquery::T_PGIndexOnlyScan: - return "T_IndexOnlyScan"; - case duckdb_libpgquery::T_PGBitmapIndexScan: - return "T_BitmapIndexScan"; - case duckdb_libpgquery::T_PGBitmapHeapScan: - return "T_BitmapHeapScan"; - case duckdb_libpgquery::T_PGTidScan: - return "T_TidScan"; - case duckdb_libpgquery::T_PGSubqueryScan: - return "T_SubqueryScan"; - case duckdb_libpgquery::T_PGFunctionScan: - return "T_FunctionScan"; - case duckdb_libpgquery::T_PGValuesScan: - return "T_ValuesScan"; - case duckdb_libpgquery::T_PGTableFuncScan: - return "T_TableFuncScan"; - case duckdb_libpgquery::T_PGCteScan: - return "T_CteScan"; - case duckdb_libpgquery::T_PGNamedTuplestoreScan: - return "T_NamedTuplestoreScan"; - case duckdb_libpgquery::T_PGWorkTableScan: - return "T_WorkTableScan"; - case duckdb_libpgquery::T_PGForeignScan: - return "T_ForeignScan"; - case duckdb_libpgquery::T_PGCustomScan: - return "T_CustomScan"; - case duckdb_libpgquery::T_PGJoin: - return "T_Join"; - case duckdb_libpgquery::T_PGNestLoop: - return "T_NestLoop"; - 
case duckdb_libpgquery::T_PGMergeJoin: - return "T_MergeJoin"; - case duckdb_libpgquery::T_PGHashJoin: - return "T_HashJoin"; - case duckdb_libpgquery::T_PGMaterial: - return "T_Material"; - case duckdb_libpgquery::T_PGSort: - return "T_Sort"; - case duckdb_libpgquery::T_PGGroup: - return "T_Group"; - case duckdb_libpgquery::T_PGAgg: - return "T_Agg"; - case duckdb_libpgquery::T_PGWindowAgg: - return "T_WindowAgg"; - case duckdb_libpgquery::T_PGUnique: - return "T_Unique"; - case duckdb_libpgquery::T_PGGather: - return "T_Gather"; - case duckdb_libpgquery::T_PGGatherMerge: - return "T_GatherMerge"; - case duckdb_libpgquery::T_PGHash: - return "T_Hash"; - case duckdb_libpgquery::T_PGSetOp: - return "T_SetOp"; - case duckdb_libpgquery::T_PGLockRows: - return "T_LockRows"; - case duckdb_libpgquery::T_PGLimit: - return "T_Limit"; - case duckdb_libpgquery::T_PGNestLoopParam: - return "T_NestLoopParam"; - case duckdb_libpgquery::T_PGPlanRowMark: - return "T_PlanRowMark"; - case duckdb_libpgquery::T_PGPlanInvalItem: - return "T_PlanInvalItem"; - case duckdb_libpgquery::T_PGPlanState: - return "T_PlanState"; - case duckdb_libpgquery::T_PGResultState: - return "T_ResultState"; - case duckdb_libpgquery::T_PGProjectSetState: - return "T_ProjectSetState"; - case duckdb_libpgquery::T_PGModifyTableState: - return "T_ModifyTableState"; - case duckdb_libpgquery::T_PGAppendState: - return "T_AppendState"; - case duckdb_libpgquery::T_PGMergeAppendState: - return "T_MergeAppendState"; - case duckdb_libpgquery::T_PGRecursiveUnionState: - return "T_RecursiveUnionState"; - case duckdb_libpgquery::T_PGBitmapAndState: - return "T_BitmapAndState"; - case duckdb_libpgquery::T_PGBitmapOrState: - return "T_BitmapOrState"; - case duckdb_libpgquery::T_PGScanState: - return "T_ScanState"; - case duckdb_libpgquery::T_PGSeqScanState: - return "T_SeqScanState"; - case duckdb_libpgquery::T_PGSampleScanState: - return "T_SampleScanState"; - case duckdb_libpgquery::T_PGIndexScanState: - return 
"T_IndexScanState"; - case duckdb_libpgquery::T_PGIndexOnlyScanState: - return "T_IndexOnlyScanState"; - case duckdb_libpgquery::T_PGBitmapIndexScanState: - return "T_BitmapIndexScanState"; - case duckdb_libpgquery::T_PGBitmapHeapScanState: - return "T_BitmapHeapScanState"; - case duckdb_libpgquery::T_PGTidScanState: - return "T_TidScanState"; - case duckdb_libpgquery::T_PGSubqueryScanState: - return "T_SubqueryScanState"; - case duckdb_libpgquery::T_PGFunctionScanState: - return "T_FunctionScanState"; - case duckdb_libpgquery::T_PGTableFuncScanState: - return "T_TableFuncScanState"; - case duckdb_libpgquery::T_PGValuesScanState: - return "T_ValuesScanState"; - case duckdb_libpgquery::T_PGCteScanState: - return "T_CteScanState"; - case duckdb_libpgquery::T_PGNamedTuplestoreScanState: - return "T_NamedTuplestoreScanState"; - case duckdb_libpgquery::T_PGWorkTableScanState: - return "T_WorkTableScanState"; - case duckdb_libpgquery::T_PGForeignScanState: - return "T_ForeignScanState"; - case duckdb_libpgquery::T_PGCustomScanState: - return "T_CustomScanState"; - case duckdb_libpgquery::T_PGJoinState: - return "T_JoinState"; - case duckdb_libpgquery::T_PGNestLoopState: - return "T_NestLoopState"; - case duckdb_libpgquery::T_PGMergeJoinState: - return "T_MergeJoinState"; - case duckdb_libpgquery::T_PGHashJoinState: - return "T_HashJoinState"; - case duckdb_libpgquery::T_PGMaterialState: - return "T_MaterialState"; - case duckdb_libpgquery::T_PGSortState: - return "T_SortState"; - case duckdb_libpgquery::T_PGGroupState: - return "T_GroupState"; - case duckdb_libpgquery::T_PGAggState: - return "T_AggState"; - case duckdb_libpgquery::T_PGWindowAggState: - return "T_WindowAggState"; - case duckdb_libpgquery::T_PGUniqueState: - return "T_UniqueState"; - case duckdb_libpgquery::T_PGGatherState: - return "T_GatherState"; - case duckdb_libpgquery::T_PGGatherMergeState: - return "T_GatherMergeState"; - case duckdb_libpgquery::T_PGHashState: - return "T_HashState"; - case 
duckdb_libpgquery::T_PGSetOpState: - return "T_SetOpState"; - case duckdb_libpgquery::T_PGLockRowsState: - return "T_LockRowsState"; - case duckdb_libpgquery::T_PGLimitState: - return "T_LimitState"; - case duckdb_libpgquery::T_PGAlias: - return "T_Alias"; - case duckdb_libpgquery::T_PGRangeVar: - return "T_RangeVar"; - case duckdb_libpgquery::T_PGTableFunc: - return "T_TableFunc"; - case duckdb_libpgquery::T_PGExpr: - return "T_Expr"; - case duckdb_libpgquery::T_PGVar: - return "T_Var"; - case duckdb_libpgquery::T_PGConst: - return "T_Const"; - case duckdb_libpgquery::T_PGParam: - return "T_Param"; - case duckdb_libpgquery::T_PGAggref: - return "T_Aggref"; - case duckdb_libpgquery::T_PGGroupingFunc: - return "T_GroupingFunc"; - case duckdb_libpgquery::T_PGWindowFunc: - return "T_WindowFunc"; - case duckdb_libpgquery::T_PGArrayRef: - return "T_ArrayRef"; - case duckdb_libpgquery::T_PGFuncExpr: - return "T_FuncExpr"; - case duckdb_libpgquery::T_PGNamedArgExpr: - return "T_NamedArgExpr"; - case duckdb_libpgquery::T_PGOpExpr: - return "T_OpExpr"; - case duckdb_libpgquery::T_PGDistinctExpr: - return "T_DistinctExpr"; - case duckdb_libpgquery::T_PGNullIfExpr: - return "T_NullIfExpr"; - case duckdb_libpgquery::T_PGScalarArrayOpExpr: - return "T_ScalarArrayOpExpr"; - case duckdb_libpgquery::T_PGBoolExpr: - return "T_BoolExpr"; - case duckdb_libpgquery::T_PGSubLink: - return "T_SubLink"; - case duckdb_libpgquery::T_PGSubPlan: - return "T_SubPlan"; - case duckdb_libpgquery::T_PGAlternativeSubPlan: - return "T_AlternativeSubPlan"; - case duckdb_libpgquery::T_PGFieldSelect: - return "T_FieldSelect"; - case duckdb_libpgquery::T_PGFieldStore: - return "T_FieldStore"; - case duckdb_libpgquery::T_PGRelabelType: - return "T_RelabelType"; - case duckdb_libpgquery::T_PGCoerceViaIO: - return "T_CoerceViaIO"; - case duckdb_libpgquery::T_PGArrayCoerceExpr: - return "T_ArrayCoerceExpr"; - case duckdb_libpgquery::T_PGConvertRowtypeExpr: - return "T_ConvertRowtypeExpr"; - case 
duckdb_libpgquery::T_PGCollateExpr: - return "T_CollateExpr"; - case duckdb_libpgquery::T_PGCaseExpr: - return "T_CaseExpr"; - case duckdb_libpgquery::T_PGCaseWhen: - return "T_CaseWhen"; - case duckdb_libpgquery::T_PGCaseTestExpr: - return "T_CaseTestExpr"; - case duckdb_libpgquery::T_PGArrayExpr: - return "T_ArrayExpr"; - case duckdb_libpgquery::T_PGRowExpr: - return "T_RowExpr"; - case duckdb_libpgquery::T_PGRowCompareExpr: - return "T_RowCompareExpr"; - case duckdb_libpgquery::T_PGCoalesceExpr: - return "T_CoalesceExpr"; - case duckdb_libpgquery::T_PGMinMaxExpr: - return "T_MinMaxExpr"; - case duckdb_libpgquery::T_PGSQLValueFunction: - return "T_SQLValueFunction"; - case duckdb_libpgquery::T_PGXmlExpr: - return "T_XmlExpr"; - case duckdb_libpgquery::T_PGNullTest: - return "T_NullTest"; - case duckdb_libpgquery::T_PGBooleanTest: - return "T_BooleanTest"; - case duckdb_libpgquery::T_PGCoerceToDomain: - return "T_CoerceToDomain"; - case duckdb_libpgquery::T_PGCoerceToDomainValue: - return "T_CoerceToDomainValue"; - case duckdb_libpgquery::T_PGSetToDefault: - return "T_SetToDefault"; - case duckdb_libpgquery::T_PGCurrentOfExpr: - return "T_CurrentOfExpr"; - case duckdb_libpgquery::T_PGNextValueExpr: - return "T_NextValueExpr"; - case duckdb_libpgquery::T_PGInferenceElem: - return "T_InferenceElem"; - case duckdb_libpgquery::T_PGTargetEntry: - return "T_TargetEntry"; - case duckdb_libpgquery::T_PGRangeTblRef: - return "T_RangeTblRef"; - case duckdb_libpgquery::T_PGJoinExpr: - return "T_JoinExpr"; - case duckdb_libpgquery::T_PGFromExpr: - return "T_FromExpr"; - case duckdb_libpgquery::T_PGOnConflictExpr: - return "T_OnConflictExpr"; - case duckdb_libpgquery::T_PGIntoClause: - return "T_IntoClause"; - case duckdb_libpgquery::T_PGExprState: - return "T_ExprState"; - case duckdb_libpgquery::T_PGAggrefExprState: - return "T_AggrefExprState"; - case duckdb_libpgquery::T_PGWindowFuncExprState: - return "T_WindowFuncExprState"; - case duckdb_libpgquery::T_PGSetExprState: - 
return "T_SetExprState"; - case duckdb_libpgquery::T_PGSubPlanState: - return "T_SubPlanState"; - case duckdb_libpgquery::T_PGAlternativeSubPlanState: - return "T_AlternativeSubPlanState"; - case duckdb_libpgquery::T_PGDomainConstraintState: - return "T_DomainConstraintState"; - case duckdb_libpgquery::T_PGPlannerInfo: - return "T_PlannerInfo"; - case duckdb_libpgquery::T_PGPlannerGlobal: - return "T_PlannerGlobal"; - case duckdb_libpgquery::T_PGRelOptInfo: - return "T_RelOptInfo"; - case duckdb_libpgquery::T_PGIndexOptInfo: - return "T_IndexOptInfo"; - case duckdb_libpgquery::T_PGForeignKeyOptInfo: - return "T_ForeignKeyOptInfo"; - case duckdb_libpgquery::T_PGParamPathInfo: - return "T_ParamPathInfo"; - case duckdb_libpgquery::T_PGPath: - return "T_Path"; - case duckdb_libpgquery::T_PGIndexPath: - return "T_IndexPath"; - case duckdb_libpgquery::T_PGBitmapHeapPath: - return "T_BitmapHeapPath"; - case duckdb_libpgquery::T_PGBitmapAndPath: - return "T_BitmapAndPath"; - case duckdb_libpgquery::T_PGBitmapOrPath: - return "T_BitmapOrPath"; - case duckdb_libpgquery::T_PGTidPath: - return "T_TidPath"; - case duckdb_libpgquery::T_PGSubqueryScanPath: - return "T_SubqueryScanPath"; - case duckdb_libpgquery::T_PGForeignPath: - return "T_ForeignPath"; - case duckdb_libpgquery::T_PGCustomPath: - return "T_CustomPath"; - case duckdb_libpgquery::T_PGNestPath: - return "T_NestPath"; - case duckdb_libpgquery::T_PGMergePath: - return "T_MergePath"; - case duckdb_libpgquery::T_PGHashPath: - return "T_HashPath"; - case duckdb_libpgquery::T_PGAppendPath: - return "T_AppendPath"; - case duckdb_libpgquery::T_PGMergeAppendPath: - return "T_MergeAppendPath"; - case duckdb_libpgquery::T_PGResultPath: - return "T_ResultPath"; - case duckdb_libpgquery::T_PGMaterialPath: - return "T_MaterialPath"; - case duckdb_libpgquery::T_PGUniquePath: - return "T_UniquePath"; - case duckdb_libpgquery::T_PGGatherPath: - return "T_GatherPath"; - case duckdb_libpgquery::T_PGGatherMergePath: - return 
"T_GatherMergePath"; - case duckdb_libpgquery::T_PGProjectionPath: - return "T_ProjectionPath"; - case duckdb_libpgquery::T_PGProjectSetPath: - return "T_ProjectSetPath"; - case duckdb_libpgquery::T_PGSortPath: - return "T_SortPath"; - case duckdb_libpgquery::T_PGGroupPath: - return "T_GroupPath"; - case duckdb_libpgquery::T_PGUpperUniquePath: - return "T_UpperUniquePath"; - case duckdb_libpgquery::T_PGAggPath: - return "T_AggPath"; - case duckdb_libpgquery::T_PGGroupingSetsPath: - return "T_GroupingSetsPath"; - case duckdb_libpgquery::T_PGMinMaxAggPath: - return "T_MinMaxAggPath"; - case duckdb_libpgquery::T_PGWindowAggPath: - return "T_WindowAggPath"; - case duckdb_libpgquery::T_PGSetOpPath: - return "T_SetOpPath"; - case duckdb_libpgquery::T_PGRecursiveUnionPath: - return "T_RecursiveUnionPath"; - case duckdb_libpgquery::T_PGLockRowsPath: - return "T_LockRowsPath"; - case duckdb_libpgquery::T_PGModifyTablePath: - return "T_ModifyTablePath"; - case duckdb_libpgquery::T_PGLimitPath: - return "T_LimitPath"; - case duckdb_libpgquery::T_PGEquivalenceClass: - return "T_EquivalenceClass"; - case duckdb_libpgquery::T_PGEquivalenceMember: - return "T_EquivalenceMember"; - case duckdb_libpgquery::T_PGPathKey: - return "T_PathKey"; - case duckdb_libpgquery::T_PGPathTarget: - return "T_PathTarget"; - case duckdb_libpgquery::T_PGRestrictInfo: - return "T_RestrictInfo"; - case duckdb_libpgquery::T_PGPlaceHolderVar: - return "T_PlaceHolderVar"; - case duckdb_libpgquery::T_PGSpecialJoinInfo: - return "T_SpecialJoinInfo"; - case duckdb_libpgquery::T_PGAppendRelInfo: - return "T_AppendRelInfo"; - case duckdb_libpgquery::T_PGPartitionedChildRelInfo: - return "T_PartitionedChildRelInfo"; - case duckdb_libpgquery::T_PGPlaceHolderInfo: - return "T_PlaceHolderInfo"; - case duckdb_libpgquery::T_PGMinMaxAggInfo: - return "T_MinMaxAggInfo"; - case duckdb_libpgquery::T_PGPlannerParamItem: - return "T_PlannerParamItem"; - case duckdb_libpgquery::T_PGRollupData: - return "T_RollupData"; - 
case duckdb_libpgquery::T_PGGroupingSetData: - return "T_GroupingSetData"; - case duckdb_libpgquery::T_PGStatisticExtInfo: - return "T_StatisticExtInfo"; - case duckdb_libpgquery::T_PGMemoryContext: - return "T_MemoryContext"; - case duckdb_libpgquery::T_PGAllocSetContext: - return "T_AllocSetContext"; - case duckdb_libpgquery::T_PGSlabContext: - return "T_SlabContext"; - case duckdb_libpgquery::T_PGValue: - return "T_Value"; - case duckdb_libpgquery::T_PGInteger: - return "T_Integer"; - case duckdb_libpgquery::T_PGFloat: - return "T_Float"; - case duckdb_libpgquery::T_PGString: - return "T_String"; - case duckdb_libpgquery::T_PGBitString: - return "T_BitString"; - case duckdb_libpgquery::T_PGNull: - return "T_Null"; - case duckdb_libpgquery::T_PGList: - return "T_List"; - case duckdb_libpgquery::T_PGIntList: - return "T_IntList"; - case duckdb_libpgquery::T_PGOidList: - return "T_OidList"; - case duckdb_libpgquery::T_PGExtensibleNode: - return "T_ExtensibleNode"; - case duckdb_libpgquery::T_PGRawStmt: - return "T_RawStmt"; - case duckdb_libpgquery::T_PGQuery: - return "T_Query"; - case duckdb_libpgquery::T_PGPlannedStmt: - return "T_PlannedStmt"; - case duckdb_libpgquery::T_PGInsertStmt: - return "T_InsertStmt"; - case duckdb_libpgquery::T_PGDeleteStmt: - return "T_DeleteStmt"; - case duckdb_libpgquery::T_PGUpdateStmt: - return "T_UpdateStmt"; - case duckdb_libpgquery::T_PGSelectStmt: - return "T_SelectStmt"; - case duckdb_libpgquery::T_PGAlterTableStmt: - return "T_AlterTableStmt"; - case duckdb_libpgquery::T_PGAlterTableCmd: - return "T_AlterTableCmd"; - case duckdb_libpgquery::T_PGAlterDomainStmt: - return "T_AlterDomainStmt"; - case duckdb_libpgquery::T_PGSetOperationStmt: - return "T_SetOperationStmt"; - case duckdb_libpgquery::T_PGGrantStmt: - return "T_GrantStmt"; - case duckdb_libpgquery::T_PGGrantRoleStmt: - return "T_GrantRoleStmt"; - case duckdb_libpgquery::T_PGAlterDefaultPrivilegesStmt: - return "T_AlterDefaultPrivilegesStmt"; - case 
duckdb_libpgquery::T_PGClosePortalStmt: - return "T_ClosePortalStmt"; - case duckdb_libpgquery::T_PGClusterStmt: - return "T_ClusterStmt"; - case duckdb_libpgquery::T_PGCopyStmt: - return "T_CopyStmt"; - case duckdb_libpgquery::T_PGCreateStmt: - return "T_CreateStmt"; - case duckdb_libpgquery::T_PGDefineStmt: - return "T_DefineStmt"; - case duckdb_libpgquery::T_PGDropStmt: - return "T_DropStmt"; - case duckdb_libpgquery::T_PGTruncateStmt: - return "T_TruncateStmt"; - case duckdb_libpgquery::T_PGCommentStmt: - return "T_CommentStmt"; - case duckdb_libpgquery::T_PGFetchStmt: - return "T_FetchStmt"; - case duckdb_libpgquery::T_PGIndexStmt: - return "T_IndexStmt"; - case duckdb_libpgquery::T_PGCreateFunctionStmt: - return "T_CreateFunctionStmt"; - case duckdb_libpgquery::T_PGAlterFunctionStmt: - return "T_AlterFunctionStmt"; - case duckdb_libpgquery::T_PGDoStmt: - return "T_DoStmt"; - case duckdb_libpgquery::T_PGRenameStmt: - return "T_RenameStmt"; - case duckdb_libpgquery::T_PGRuleStmt: - return "T_RuleStmt"; - case duckdb_libpgquery::T_PGNotifyStmt: - return "T_NotifyStmt"; - case duckdb_libpgquery::T_PGListenStmt: - return "T_ListenStmt"; - case duckdb_libpgquery::T_PGUnlistenStmt: - return "T_UnlistenStmt"; - case duckdb_libpgquery::T_PGTransactionStmt: - return "T_TransactionStmt"; - case duckdb_libpgquery::T_PGViewStmt: - return "T_ViewStmt"; - case duckdb_libpgquery::T_PGLoadStmt: - return "T_LoadStmt"; - case duckdb_libpgquery::T_PGCreateDomainStmt: - return "T_CreateDomainStmt"; - case duckdb_libpgquery::T_PGCreatedbStmt: - return "T_CreatedbStmt"; - case duckdb_libpgquery::T_PGDropdbStmt: - return "T_DropdbStmt"; - case duckdb_libpgquery::T_PGVacuumStmt: - return "T_VacuumStmt"; - case duckdb_libpgquery::T_PGExplainStmt: - return "T_ExplainStmt"; - case duckdb_libpgquery::T_PGCreateTableAsStmt: - return "T_CreateTableAsStmt"; - case duckdb_libpgquery::T_PGCreateSeqStmt: - return "T_CreateSeqStmt"; - case duckdb_libpgquery::T_PGAlterSeqStmt: - return 
"T_AlterSeqStmt"; - case duckdb_libpgquery::T_PGVariableSetStmt: - return "T_VariableSetStmt"; - case duckdb_libpgquery::T_PGVariableShowStmt: - return "T_VariableShowStmt"; - case duckdb_libpgquery::T_PGVariableShowSelectStmt: - return "T_VariableShowSelectStmt"; - case duckdb_libpgquery::T_PGDiscardStmt: - return "T_DiscardStmt"; - case duckdb_libpgquery::T_PGCreateTrigStmt: - return "T_CreateTrigStmt"; - case duckdb_libpgquery::T_PGCreatePLangStmt: - return "T_CreatePLangStmt"; - case duckdb_libpgquery::T_PGCreateRoleStmt: - return "T_CreateRoleStmt"; - case duckdb_libpgquery::T_PGAlterRoleStmt: - return "T_AlterRoleStmt"; - case duckdb_libpgquery::T_PGDropRoleStmt: - return "T_DropRoleStmt"; - case duckdb_libpgquery::T_PGLockStmt: - return "T_LockStmt"; - case duckdb_libpgquery::T_PGConstraintsSetStmt: - return "T_ConstraintsSetStmt"; - case duckdb_libpgquery::T_PGReindexStmt: - return "T_ReindexStmt"; - case duckdb_libpgquery::T_PGCheckPointStmt: - return "T_CheckPointStmt"; - case duckdb_libpgquery::T_PGCreateSchemaStmt: - return "T_CreateSchemaStmt"; - case duckdb_libpgquery::T_PGAlterDatabaseStmt: - return "T_AlterDatabaseStmt"; - case duckdb_libpgquery::T_PGAlterDatabaseSetStmt: - return "T_AlterDatabaseSetStmt"; - case duckdb_libpgquery::T_PGAlterRoleSetStmt: - return "T_AlterRoleSetStmt"; - case duckdb_libpgquery::T_PGCreateConversionStmt: - return "T_CreateConversionStmt"; - case duckdb_libpgquery::T_PGCreateCastStmt: - return "T_CreateCastStmt"; - case duckdb_libpgquery::T_PGCreateOpClassStmt: - return "T_CreateOpClassStmt"; - case duckdb_libpgquery::T_PGCreateOpFamilyStmt: - return "T_CreateOpFamilyStmt"; - case duckdb_libpgquery::T_PGAlterOpFamilyStmt: - return "T_AlterOpFamilyStmt"; - case duckdb_libpgquery::T_PGPrepareStmt: - return "T_PrepareStmt"; - case duckdb_libpgquery::T_PGExecuteStmt: - return "T_ExecuteStmt"; - case duckdb_libpgquery::T_PGCallStmt: - return "T_CallStmt"; - case duckdb_libpgquery::T_PGDeallocateStmt: - return 
"T_DeallocateStmt"; - case duckdb_libpgquery::T_PGDeclareCursorStmt: - return "T_DeclareCursorStmt"; - case duckdb_libpgquery::T_PGCreateTableSpaceStmt: - return "T_CreateTableSpaceStmt"; - case duckdb_libpgquery::T_PGDropTableSpaceStmt: - return "T_DropTableSpaceStmt"; - case duckdb_libpgquery::T_PGAlterObjectDependsStmt: - return "T_AlterObjectDependsStmt"; - case duckdb_libpgquery::T_PGAlterObjectSchemaStmt: - return "T_AlterObjectSchemaStmt"; - case duckdb_libpgquery::T_PGAlterOwnerStmt: - return "T_AlterOwnerStmt"; - case duckdb_libpgquery::T_PGAlterOperatorStmt: - return "T_AlterOperatorStmt"; - case duckdb_libpgquery::T_PGDropOwnedStmt: - return "T_DropOwnedStmt"; - case duckdb_libpgquery::T_PGReassignOwnedStmt: - return "T_ReassignOwnedStmt"; - case duckdb_libpgquery::T_PGCompositeTypeStmt: - return "T_CompositeTypeStmt"; - case duckdb_libpgquery::T_PGCreateTypeStmt: - return "T_CreateTypeStmt"; - case duckdb_libpgquery::T_PGCreateRangeStmt: - return "T_CreateRangeStmt"; - case duckdb_libpgquery::T_PGAlterEnumStmt: - return "T_AlterEnumStmt"; - case duckdb_libpgquery::T_PGAlterTSDictionaryStmt: - return "T_AlterTSDictionaryStmt"; - case duckdb_libpgquery::T_PGAlterTSConfigurationStmt: - return "T_AlterTSConfigurationStmt"; - case duckdb_libpgquery::T_PGCreateFdwStmt: - return "T_CreateFdwStmt"; - case duckdb_libpgquery::T_PGAlterFdwStmt: - return "T_AlterFdwStmt"; - case duckdb_libpgquery::T_PGCreateForeignServerStmt: - return "T_CreateForeignServerStmt"; - case duckdb_libpgquery::T_PGAlterForeignServerStmt: - return "T_AlterForeignServerStmt"; - case duckdb_libpgquery::T_PGCreateUserMappingStmt: - return "T_CreateUserMappingStmt"; - case duckdb_libpgquery::T_PGAlterUserMappingStmt: - return "T_AlterUserMappingStmt"; - case duckdb_libpgquery::T_PGDropUserMappingStmt: - return "T_DropUserMappingStmt"; - case duckdb_libpgquery::T_PGAlterTableSpaceOptionsStmt: - return "T_AlterTableSpaceOptionsStmt"; - case duckdb_libpgquery::T_PGAlterTableMoveAllStmt: - 
return "T_AlterTableMoveAllStmt"; - case duckdb_libpgquery::T_PGSecLabelStmt: - return "T_SecLabelStmt"; - case duckdb_libpgquery::T_PGCreateForeignTableStmt: - return "T_CreateForeignTableStmt"; - case duckdb_libpgquery::T_PGImportForeignSchemaStmt: - return "T_ImportForeignSchemaStmt"; - case duckdb_libpgquery::T_PGCreateExtensionStmt: - return "T_CreateExtensionStmt"; - case duckdb_libpgquery::T_PGAlterExtensionStmt: - return "T_AlterExtensionStmt"; - case duckdb_libpgquery::T_PGAlterExtensionContentsStmt: - return "T_AlterExtensionContentsStmt"; - case duckdb_libpgquery::T_PGCreateEventTrigStmt: - return "T_CreateEventTrigStmt"; - case duckdb_libpgquery::T_PGAlterEventTrigStmt: - return "T_AlterEventTrigStmt"; - case duckdb_libpgquery::T_PGRefreshMatViewStmt: - return "T_RefreshMatViewStmt"; - case duckdb_libpgquery::T_PGReplicaIdentityStmt: - return "T_ReplicaIdentityStmt"; - case duckdb_libpgquery::T_PGAlterSystemStmt: - return "T_AlterSystemStmt"; - case duckdb_libpgquery::T_PGCreatePolicyStmt: - return "T_CreatePolicyStmt"; - case duckdb_libpgquery::T_PGAlterPolicyStmt: - return "T_AlterPolicyStmt"; - case duckdb_libpgquery::T_PGCreateTransformStmt: - return "T_CreateTransformStmt"; - case duckdb_libpgquery::T_PGCreateAmStmt: - return "T_CreateAmStmt"; - case duckdb_libpgquery::T_PGCreatePublicationStmt: - return "T_CreatePublicationStmt"; - case duckdb_libpgquery::T_PGAlterPublicationStmt: - return "T_AlterPublicationStmt"; - case duckdb_libpgquery::T_PGCreateSubscriptionStmt: - return "T_CreateSubscriptionStmt"; - case duckdb_libpgquery::T_PGAlterSubscriptionStmt: - return "T_AlterSubscriptionStmt"; - case duckdb_libpgquery::T_PGDropSubscriptionStmt: - return "T_DropSubscriptionStmt"; - case duckdb_libpgquery::T_PGCreateStatsStmt: - return "T_CreateStatsStmt"; - case duckdb_libpgquery::T_PGAlterCollationStmt: - return "T_AlterCollationStmt"; - case duckdb_libpgquery::T_PGAExpr: - return "TAExpr"; - case duckdb_libpgquery::T_PGColumnRef: - return 
"T_ColumnRef"; - case duckdb_libpgquery::T_PGParamRef: - return "T_ParamRef"; - case duckdb_libpgquery::T_PGAConst: - return "TAConst"; - case duckdb_libpgquery::T_PGFuncCall: - return "T_FuncCall"; - case duckdb_libpgquery::T_PGAStar: - return "TAStar"; - case duckdb_libpgquery::T_PGAIndices: - return "TAIndices"; - case duckdb_libpgquery::T_PGAIndirection: - return "TAIndirection"; - case duckdb_libpgquery::T_PGAArrayExpr: - return "TAArrayExpr"; - case duckdb_libpgquery::T_PGResTarget: - return "T_ResTarget"; - case duckdb_libpgquery::T_PGMultiAssignRef: - return "T_MultiAssignRef"; - case duckdb_libpgquery::T_PGTypeCast: - return "T_TypeCast"; - case duckdb_libpgquery::T_PGCollateClause: - return "T_CollateClause"; - case duckdb_libpgquery::T_PGSortBy: - return "T_SortBy"; - case duckdb_libpgquery::T_PGWindowDef: - return "T_WindowDef"; - case duckdb_libpgquery::T_PGRangeSubselect: - return "T_RangeSubselect"; - case duckdb_libpgquery::T_PGRangeFunction: - return "T_RangeFunction"; - case duckdb_libpgquery::T_PGRangeTableSample: - return "T_RangeTableSample"; - case duckdb_libpgquery::T_PGRangeTableFunc: - return "T_RangeTableFunc"; - case duckdb_libpgquery::T_PGRangeTableFuncCol: - return "T_RangeTableFuncCol"; - case duckdb_libpgquery::T_PGTypeName: - return "T_TypeName"; - case duckdb_libpgquery::T_PGColumnDef: - return "T_ColumnDef"; - case duckdb_libpgquery::T_PGIndexElem: - return "T_IndexElem"; - case duckdb_libpgquery::T_PGConstraint: - return "T_Constraint"; - case duckdb_libpgquery::T_PGDefElem: - return "T_DefElem"; - case duckdb_libpgquery::T_PGRangeTblEntry: - return "T_RangeTblEntry"; - case duckdb_libpgquery::T_PGRangeTblFunction: - return "T_RangeTblFunction"; - case duckdb_libpgquery::T_PGTableSampleClause: - return "T_TableSampleClause"; - case duckdb_libpgquery::T_PGWithCheckOption: - return "T_WithCheckOption"; - case duckdb_libpgquery::T_PGSortGroupClause: - return "T_SortGroupClause"; - case duckdb_libpgquery::T_PGGroupingSet: - return 
"T_GroupingSet"; - case duckdb_libpgquery::T_PGWindowClause: - return "T_WindowClause"; - case duckdb_libpgquery::T_PGObjectWithArgs: - return "T_ObjectWithArgs"; - case duckdb_libpgquery::T_PGAccessPriv: - return "T_AccessPriv"; - case duckdb_libpgquery::T_PGCreateOpClassItem: - return "T_CreateOpClassItem"; - case duckdb_libpgquery::T_PGTableLikeClause: - return "T_TableLikeClause"; - case duckdb_libpgquery::T_PGFunctionParameter: - return "T_FunctionParameter"; - case duckdb_libpgquery::T_PGLockingClause: - return "T_LockingClause"; - case duckdb_libpgquery::T_PGRowMarkClause: - return "T_RowMarkClause"; - case duckdb_libpgquery::T_PGXmlSerialize: - return "T_XmlSerialize"; - case duckdb_libpgquery::T_PGWithClause: - return "T_WithClause"; - case duckdb_libpgquery::T_PGInferClause: - return "T_InferClause"; - case duckdb_libpgquery::T_PGOnConflictClause: - return "T_OnConflictClause"; - case duckdb_libpgquery::T_PGCommonTableExpr: - return "T_CommonTableExpr"; - case duckdb_libpgquery::T_PGRoleSpec: - return "T_RoleSpec"; - case duckdb_libpgquery::T_PGTriggerTransition: - return "T_TriggerTransition"; - case duckdb_libpgquery::T_PGPartitionElem: - return "T_PartitionElem"; - case duckdb_libpgquery::T_PGPartitionSpec: - return "T_PartitionSpec"; - case duckdb_libpgquery::T_PGPartitionBoundSpec: - return "T_PartitionBoundSpec"; - case duckdb_libpgquery::T_PGPartitionRangeDatum: - return "T_PartitionRangeDatum"; - case duckdb_libpgquery::T_PGPartitionCmd: - return "T_PartitionCmd"; - case duckdb_libpgquery::T_PGIdentifySystemCmd: - return "T_IdentifySystemCmd"; - case duckdb_libpgquery::T_PGBaseBackupCmd: - return "T_BaseBackupCmd"; - case duckdb_libpgquery::T_PGCreateReplicationSlotCmd: - return "T_CreateReplicationSlotCmd"; - case duckdb_libpgquery::T_PGDropReplicationSlotCmd: - return "T_DropReplicationSlotCmd"; - case duckdb_libpgquery::T_PGStartReplicationCmd: - return "T_StartReplicationCmd"; - case duckdb_libpgquery::T_PGTimeLineHistoryCmd: - return 
"T_TimeLineHistoryCmd"; - case duckdb_libpgquery::T_PGSQLCmd: - return "T_SQLCmd"; - case duckdb_libpgquery::T_PGTriggerData: - return "T_TriggerData"; - case duckdb_libpgquery::T_PGEventTriggerData: - return "T_EventTriggerData"; - case duckdb_libpgquery::T_PGReturnSetInfo: - return "T_ReturnSetInfo"; - case duckdb_libpgquery::T_PGWindowObjectData: - return "T_WindowObjectData"; - case duckdb_libpgquery::T_PGTIDBitmap: - return "T_TIDBitmap"; - case duckdb_libpgquery::T_PGInlineCodeBlock: - return "T_InlineCodeBlock"; - case duckdb_libpgquery::T_PGFdwRoutine: - return "T_FdwRoutine"; - case duckdb_libpgquery::T_PGIndexAmRoutine: - return "T_IndexAmRoutine"; - case duckdb_libpgquery::T_PGTsmRoutine: - return "T_TsmRoutine"; - case duckdb_libpgquery::T_PGForeignKeyCacheInfo: - return "T_ForeignKeyCacheInfo"; - case duckdb_libpgquery::T_PGAttachStmt: - return "T_PGAttachStmt"; - case duckdb_libpgquery::T_PGUseStmt: - return "T_PGUseStmt"; - default: - return "(UNKNOWN)"; - } -} // LCOV_EXCL_STOP - -} // namespace duckdb - - -namespace duckdb { - -vector Transformer::TransformStringList(duckdb_libpgquery::PGList *list) { - vector result; - if (!list) { - return result; - } - for (auto node = list->head; node != nullptr; node = node->next) { - result.emplace_back(reinterpret_cast(node->data.ptr_value)->val.str); - } - return result; -} - -string Transformer::TransformAlias(duckdb_libpgquery::PGAlias *root, vector &column_name_alias) { - if (!root) { - return ""; - } - column_name_alias = TransformStringList(root->colnames); - return root->aliasname; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -unique_ptr CommonTableExpressionInfo::Copy() { - auto result = make_uniq(); - result->aliases = aliases; - result->query = unique_ptr_cast(query->Copy()); - result->materialized = materialized; - return result; -} - -void Transformer::ExtractCTEsRecursive(CommonTableExpressionMap &cte_map) { - for (auto &cte_entry : stored_cte_map) { - for (auto &entry : 
cte_entry->map) { - auto found_entry = cte_map.map.find(entry.first); - if (found_entry != cte_map.map.end()) { - // entry already present - use top-most entry - continue; - } - cte_map.map[entry.first] = entry.second->Copy(); - } - } - if (parent) { - parent->ExtractCTEsRecursive(cte_map); - } -} - -void Transformer::TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map, - vector> &materialized_ctes) { - // TODO: might need to update in case of future lawsuit - stored_cte_map.push_back(&cte_map); - - D_ASSERT(de_with_clause.ctes); - for (auto cte_ele = de_with_clause.ctes->head; cte_ele != nullptr; cte_ele = cte_ele->next) { - auto info = make_uniq(); - - auto &cte = *PGPointerCast(cte_ele->data.ptr_value); - if (cte.aliascolnames) { - for (auto node = cte.aliascolnames->head; node != nullptr; node = node->next) { - info->aliases.emplace_back( - reinterpret_cast(node->data.ptr_value)->val.str); - } - } - // lets throw some errors on unsupported features early - if (cte.ctecolnames) { - throw NotImplementedException("Column name setting not supported in CTEs"); - } - if (cte.ctecoltypes) { - throw NotImplementedException("Column type setting not supported in CTEs"); - } - if (cte.ctecoltypmods) { - throw NotImplementedException("Column type modification not supported in CTEs"); - } - if (cte.ctecolcollations) { - throw NotImplementedException("CTE collations not supported"); - } - // we need a query - if (!cte.ctequery || cte.ctequery->type != duckdb_libpgquery::T_PGSelectStmt) { - throw NotImplementedException("A CTE needs a SELECT"); - } - - // CTE transformation can either result in inlining for non recursive CTEs, or in recursive CTE bindings - // otherwise. 
- if (cte.cterecursive || de_with_clause.recursive) { - info->query = TransformRecursiveCTE(cte, *info); - } else { - Transformer cte_transformer(*this); - info->query = - cte_transformer.TransformSelect(*PGPointerCast(cte.ctequery)); - } - D_ASSERT(info->query); - auto cte_name = string(cte.ctename); - - auto it = cte_map.map.find(cte_name); - if (it != cte_map.map.end()) { - // can't have two CTEs with same name - throw ParserException("Duplicate CTE name \"%s\"", cte_name); - } - -#ifdef DUCKDB_ALTERNATIVE_VERIFY - if (cte.ctematerialized == duckdb_libpgquery::PGCTEMaterializeDefault) { -#else - if (cte.ctematerialized == duckdb_libpgquery::PGCTEMaterializeAlways) { -#endif - auto materialize = make_uniq(); - materialize->query = info->query->node->Copy(); - materialize->ctename = cte_name; - materialize->aliases = info->aliases; - materialized_ctes.push_back(std::move(materialize)); - - info->materialized = CTEMaterialize::CTE_MATERIALIZE_ALWAYS; - } - - cte_map.map[cte_name] = std::move(info); - } -} - -unique_ptr Transformer::TransformRecursiveCTE(duckdb_libpgquery::PGCommonTableExpr &cte, - CommonTableExpressionInfo &info) { - auto &stmt = *PGPointerCast(cte.ctequery); - - unique_ptr select; - switch (stmt.op) { - case duckdb_libpgquery::PG_SETOP_UNION: - case duckdb_libpgquery::PG_SETOP_EXCEPT: - case duckdb_libpgquery::PG_SETOP_INTERSECT: { - select = make_uniq(); - select->node = make_uniq_base(); - auto &result = select->node->Cast(); - result.ctename = string(cte.ctename); - result.union_all = stmt.all; - result.left = TransformSelectNode(*PGPointerCast(stmt.larg)); - result.right = TransformSelectNode(*PGPointerCast(stmt.rarg)); - result.aliases = info.aliases; - if (stmt.op != duckdb_libpgquery::PG_SETOP_UNION) { - throw ParserException("Unsupported setop type for recursive CTE: only UNION or UNION ALL are supported"); - } - break; - } - default: - // This CTE is not recursive. Fallback to regular query transformation. 
- return TransformSelect(*PGPointerCast(cte.ctequery)); - } - - if (stmt.limitCount || stmt.limitOffset) { - throw ParserException("LIMIT or OFFSET in a recursive query is not allowed"); - } - if (stmt.sortClause) { - throw ParserException("ORDER BY in a recursive query is not allowed"); - } - return select; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -static void CheckGroupingSetMax(idx_t count) { - static constexpr const idx_t MAX_GROUPING_SETS = 65535; - if (count > MAX_GROUPING_SETS) { - throw ParserException("Maximum grouping set count of %d exceeded", MAX_GROUPING_SETS); - } -} - -static void CheckGroupingSetCubes(idx_t current_count, idx_t cube_count) { - idx_t combinations = 1; - for (idx_t i = 0; i < cube_count; i++) { - combinations *= 2; - CheckGroupingSetMax(current_count + combinations); - } -} - -struct GroupingExpressionMap { - parsed_expression_map_t map; -}; - -static GroupingSet VectorToGroupingSet(vector &indexes) { - GroupingSet result; - for (idx_t i = 0; i < indexes.size(); i++) { - result.insert(indexes[i]); - } - return result; -} - -static void MergeGroupingSet(GroupingSet &result, GroupingSet &other) { - CheckGroupingSetMax(result.size() + other.size()); - result.insert(other.begin(), other.end()); -} - -void Transformer::AddGroupByExpression(unique_ptr expression, GroupingExpressionMap &map, - GroupByNode &result, vector &result_set) { - if (expression->type == ExpressionType::FUNCTION) { - auto &func = expression->Cast(); - if (func.function_name == "row") { - for (auto &child : func.children) { - AddGroupByExpression(std::move(child), map, result, result_set); - } - return; - } - } - auto entry = map.map.find(*expression); - idx_t result_idx; - if (entry == map.map.end()) { - result_idx = result.group_expressions.size(); - map.map[*expression] = result_idx; - result.group_expressions.push_back(std::move(expression)); - } else { - result_idx = entry->second; - } - result_set.push_back(result_idx); -} - -static void 
AddCubeSets(const GroupingSet ¤t_set, vector &result_set, - vector &result_sets, idx_t start_idx = 0) { - CheckGroupingSetMax(result_sets.size()); - result_sets.push_back(current_set); - for (idx_t k = start_idx; k < result_set.size(); k++) { - auto child_set = current_set; - MergeGroupingSet(child_set, result_set[k]); - AddCubeSets(child_set, result_set, result_sets, k + 1); - } -} - -void Transformer::TransformGroupByExpression(duckdb_libpgquery::PGNode &n, GroupingExpressionMap &map, - GroupByNode &result, vector &indexes) { - auto expression = TransformExpression(n); - AddGroupByExpression(std::move(expression), map, result, indexes); -} - -// If one GROUPING SETS clause is nested inside another, -// the effect is the same as if all the elements of the inner clause had been written directly in the outer clause. -void Transformer::TransformGroupByNode(duckdb_libpgquery::PGNode &n, GroupingExpressionMap &map, SelectNode &result, - vector &result_sets) { - if (n.type == duckdb_libpgquery::T_PGGroupingSet) { - auto &grouping_set = PGCast(n); - switch (grouping_set.kind) { - case duckdb_libpgquery::GROUPING_SET_EMPTY: - result_sets.emplace_back(); - break; - case duckdb_libpgquery::GROUPING_SET_ALL: { - result.aggregate_handling = AggregateHandling::FORCE_AGGREGATES; - break; - } - case duckdb_libpgquery::GROUPING_SET_SETS: { - for (auto node = grouping_set.content->head; node; node = node->next) { - auto pg_node = PGPointerCast(node->data.ptr_value); - TransformGroupByNode(*pg_node, map, result, result_sets); - } - break; - } - case duckdb_libpgquery::GROUPING_SET_ROLLUP: { - vector rollup_sets; - for (auto node = grouping_set.content->head; node; node = node->next) { - auto pg_node = PGPointerCast(node->data.ptr_value); - vector rollup_set; - TransformGroupByExpression(*pg_node, map, result.groups, rollup_set); - rollup_sets.push_back(VectorToGroupingSet(rollup_set)); - } - // generate the subsets of the rollup set and add them to the grouping sets - GroupingSet 
current_set; - result_sets.push_back(current_set); - for (idx_t i = 0; i < rollup_sets.size(); i++) { - MergeGroupingSet(current_set, rollup_sets[i]); - result_sets.push_back(current_set); - } - break; - } - case duckdb_libpgquery::GROUPING_SET_CUBE: { - vector cube_sets; - for (auto node = grouping_set.content->head; node; node = node->next) { - auto pg_node = PGPointerCast(node->data.ptr_value); - vector cube_set; - TransformGroupByExpression(*pg_node, map, result.groups, cube_set); - cube_sets.push_back(VectorToGroupingSet(cube_set)); - } - // generate the subsets of the rollup set and add them to the grouping sets - CheckGroupingSetCubes(result_sets.size(), cube_sets.size()); - - GroupingSet current_set; - AddCubeSets(current_set, cube_sets, result_sets, 0); - break; - } - default: - throw InternalException("Unsupported GROUPING SET type %d", grouping_set.kind); - } - } else { - vector indexes; - TransformGroupByExpression(n, map, result.groups, indexes); - result_sets.push_back(VectorToGroupingSet(indexes)); - } -} - -// If multiple grouping items are specified in a single GROUP BY clause, -// then the final list of grouping sets is the cross product of the individual items. 
-bool Transformer::TransformGroupBy(optional_ptr group, SelectNode &select_node) { - if (!group) { - return false; - } - auto &result = select_node.groups; - GroupingExpressionMap map; - for (auto node = group->head; node != nullptr; node = node->next) { - auto n = PGPointerCast(node->data.ptr_value); - vector result_sets; - TransformGroupByNode(*n, map, select_node, result_sets); - CheckGroupingSetMax(result_sets.size()); - if (result.grouping_sets.empty()) { - // no grouping sets yet: use the current set of grouping sets - result.grouping_sets = std::move(result_sets); - } else { - // compute the cross product - vector new_sets; - idx_t grouping_set_count = result.grouping_sets.size() * result_sets.size(); - CheckGroupingSetMax(grouping_set_count); - new_sets.reserve(grouping_set_count); - for (idx_t current_idx = 0; current_idx < result.grouping_sets.size(); current_idx++) { - auto ¤t_set = result.grouping_sets[current_idx]; - for (idx_t new_idx = 0; new_idx < result_sets.size(); new_idx++) { - auto &new_set = result_sets[new_idx]; - GroupingSet set; - set.insert(current_set.begin(), current_set.end()); - set.insert(new_set.begin(), new_set.end()); - new_sets.push_back(std::move(set)); - } - } - result.grouping_sets = std::move(new_sets); - } - } - if (result.group_expressions.size() == 1 && result.grouping_sets.size() == 1 && - ExpressionIsEmptyStar(*result.group_expressions[0])) { - // GROUP BY * - result.group_expressions.clear(); - result.grouping_sets.clear(); - select_node.aggregate_handling = AggregateHandling::FORCE_AGGREGATES; - } - return true; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -bool Transformer::TransformOrderBy(duckdb_libpgquery::PGList *order, vector &result) { - if (!order) { - return false; - } - - for (auto node = order->head; node != nullptr; node = node->next) { - auto temp = reinterpret_cast(node->data.ptr_value); - if (temp->type == duckdb_libpgquery::T_PGSortBy) { - OrderType type; - OrderByNullType null_order; - 
auto sort = reinterpret_cast(temp); - auto target = sort->node; - if (sort->sortby_dir == duckdb_libpgquery::PG_SORTBY_DEFAULT) { - type = OrderType::ORDER_DEFAULT; - } else if (sort->sortby_dir == duckdb_libpgquery::PG_SORTBY_ASC) { - type = OrderType::ASCENDING; - } else if (sort->sortby_dir == duckdb_libpgquery::PG_SORTBY_DESC) { - type = OrderType::DESCENDING; - } else { - throw NotImplementedException("Unimplemented order by type"); - } - if (sort->sortby_nulls == duckdb_libpgquery::PG_SORTBY_NULLS_DEFAULT) { - null_order = OrderByNullType::ORDER_DEFAULT; - } else if (sort->sortby_nulls == duckdb_libpgquery::PG_SORTBY_NULLS_FIRST) { - null_order = OrderByNullType::NULLS_FIRST; - } else if (sort->sortby_nulls == duckdb_libpgquery::PG_SORTBY_NULLS_LAST) { - null_order = OrderByNullType::NULLS_LAST; - } else { - throw NotImplementedException("Unimplemented order by type"); - } - auto order_expression = TransformExpression(target); - result.emplace_back(type, null_order, std::move(order_expression)); - } else { - throw NotImplementedException("ORDER BY list member type %d\n", temp->type); - } - } - return true; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -static SampleMethod GetSampleMethod(const string &method) { - auto lmethod = StringUtil::Lower(method); - if (lmethod == "system") { - return SampleMethod::SYSTEM_SAMPLE; - } else if (lmethod == "bernoulli") { - return SampleMethod::BERNOULLI_SAMPLE; - } else if (lmethod == "reservoir") { - return SampleMethod::RESERVOIR_SAMPLE; - } else { - throw ParserException("Unrecognized sampling method %s, expected system, bernoulli or reservoir", method); - } -} - -unique_ptr Transformer::TransformSampleOptions(optional_ptr options) { - if (!options) { - return nullptr; - } - auto result = make_uniq(); - auto &sample_options = PGCast(*options); - auto &sample_size = *PGPointerCast(sample_options.sample_size); - auto sample_value = TransformValue(sample_size.sample_size)->value; - result->is_percentage = 
sample_size.is_percentage; - if (sample_size.is_percentage) { - // sample size is given in sample_size: use system sampling - auto percentage = sample_value.GetValue(); - if (percentage < 0 || percentage > 100) { - throw ParserException("Sample sample_size %llf out of range, must be between 0 and 100", percentage); - } - result->sample_size = Value::DOUBLE(percentage); - result->method = SampleMethod::SYSTEM_SAMPLE; - } else { - // sample size is given in rows: use reservoir sampling - auto rows = sample_value.GetValue(); - if (rows < 0) { - throw ParserException("Sample rows %lld out of range, must be bigger than or equal to 0", rows); - } - result->sample_size = Value::BIGINT(rows); - result->method = SampleMethod::RESERVOIR_SAMPLE; - } - if (sample_options.method) { - result->method = GetSampleMethod(sample_options.method); - } - if (sample_options.has_seed) { - result->seed = sample_options.seed; - } - return result; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -LogicalType Transformer::TransformTypeName(duckdb_libpgquery::PGTypeName &type_name) { - if (type_name.type != duckdb_libpgquery::T_PGTypeName) { - throw ParserException("Expected a type"); - } - auto stack_checker = StackCheck(); - - auto name = PGPointerCast(type_name.names->tail->data.ptr_value)->val.str; - // transform it to the SQL type - LogicalTypeId base_type = TransformStringToLogicalTypeId(name); - - LogicalType result_type; - if (base_type == LogicalTypeId::LIST) { - throw ParserException("LIST is not valid as a stand-alone type"); - } else if (base_type == LogicalTypeId::ENUM) { - if (!type_name.typmods || type_name.typmods->length == 0) { - throw ParserException("Enum needs a set of entries"); - } - Vector enum_vector(LogicalType::VARCHAR, type_name.typmods->length); - auto string_data = FlatVector::GetData(enum_vector); - idx_t pos = 0; - for (auto node = type_name.typmods->head; node; node = node->next) { - auto constant_value = PGPointerCast(node->data.ptr_value); - 
if (constant_value->type != duckdb_libpgquery::T_PGAConst || - constant_value->val.type != duckdb_libpgquery::T_PGString) { - throw ParserException("Enum type requires a set of strings as type modifiers"); - } - string_data[pos++] = StringVector::AddString(enum_vector, constant_value->val.val.str); - } - return LogicalType::ENUM(enum_vector, type_name.typmods->length); - } else if (base_type == LogicalTypeId::STRUCT) { - if (!type_name.typmods || type_name.typmods->length == 0) { - throw ParserException("Struct needs a name and entries"); - } - child_list_t children; - case_insensitive_set_t name_collision_set; - - for (auto node = type_name.typmods->head; node; node = node->next) { - auto &type_val = *PGPointerCast(node->data.ptr_value); - if (type_val.length != 2) { - throw ParserException("Struct entry needs an entry name and a type name"); - } - - auto entry_name_node = PGPointerCast(type_val.head->data.ptr_value); - D_ASSERT(entry_name_node->type == duckdb_libpgquery::T_PGString); - auto entry_type_node = PGPointerCast(type_val.tail->data.ptr_value); - D_ASSERT(entry_type_node->type == duckdb_libpgquery::T_PGTypeName); - - auto entry_name = string(entry_name_node->val.str); - D_ASSERT(!entry_name.empty()); - - if (name_collision_set.find(entry_name) != name_collision_set.end()) { - throw ParserException("Duplicate struct entry name \"%s\"", entry_name); - } - name_collision_set.insert(entry_name); - auto entry_type = TransformTypeName(*entry_type_node); - - children.push_back(make_pair(entry_name, entry_type)); - } - D_ASSERT(!children.empty()); - result_type = LogicalType::STRUCT(children); - - } else if (base_type == LogicalTypeId::MAP) { - if (!type_name.typmods || type_name.typmods->length != 2) { - throw ParserException("Map type needs exactly two entries, key and value type"); - } - auto key_type = - TransformTypeName(*PGPointerCast(type_name.typmods->head->data.ptr_value)); - auto value_type = - 
TransformTypeName(*PGPointerCast(type_name.typmods->tail->data.ptr_value)); - - result_type = LogicalType::MAP(std::move(key_type), std::move(value_type)); - } else if (base_type == LogicalTypeId::UNION) { - if (!type_name.typmods || type_name.typmods->length == 0) { - throw ParserException("Union type needs at least one member"); - } - if (type_name.typmods->length > (int)UnionType::MAX_UNION_MEMBERS) { - throw ParserException("Union types can have at most %d members", UnionType::MAX_UNION_MEMBERS); - } - - child_list_t children; - case_insensitive_set_t name_collision_set; - - for (auto node = type_name.typmods->head; node; node = node->next) { - auto &type_val = *PGPointerCast(node->data.ptr_value); - if (type_val.length != 2) { - throw ParserException("Union type member needs a tag name and a type name"); - } - - auto entry_name_node = PGPointerCast(type_val.head->data.ptr_value); - D_ASSERT(entry_name_node->type == duckdb_libpgquery::T_PGString); - auto entry_type_node = PGPointerCast(type_val.tail->data.ptr_value); - D_ASSERT(entry_type_node->type == duckdb_libpgquery::T_PGTypeName); - - auto entry_name = string(entry_name_node->val.str); - D_ASSERT(!entry_name.empty()); - - if (name_collision_set.find(entry_name) != name_collision_set.end()) { - throw ParserException("Duplicate union type tag name \"%s\"", entry_name); - } - - name_collision_set.insert(entry_name); - - auto entry_type = TransformTypeName(*entry_type_node); - children.push_back(make_pair(entry_name, entry_type)); - } - D_ASSERT(!children.empty()); - result_type = LogicalType::UNION(std::move(children)); - } else { - int64_t width, scale; - if (base_type == LogicalTypeId::DECIMAL) { - // default decimal width/scale - width = 18; - scale = 3; - } else { - width = 0; - scale = 0; - } - // check any modifiers - int modifier_idx = 0; - if (type_name.typmods) { - for (auto node = type_name.typmods->head; node; node = node->next) { - auto &const_val = *PGPointerCast(node->data.ptr_value); - if 
(const_val.type != duckdb_libpgquery::T_PGAConst || - const_val.val.type != duckdb_libpgquery::T_PGInteger) { - throw ParserException("Expected an integer constant as type modifier"); - } - if (const_val.val.val.ival < 0) { - throw ParserException("Negative modifier not supported"); - } - if (modifier_idx == 0) { - width = const_val.val.val.ival; - if (base_type == LogicalTypeId::BIT && const_val.location != -1) { - width = 0; - } - } else if (modifier_idx == 1) { - scale = const_val.val.val.ival; - } else { - throw ParserException("A maximum of two modifiers is supported"); - } - modifier_idx++; - } - } - switch (base_type) { - case LogicalTypeId::VARCHAR: - if (modifier_idx > 1) { - throw ParserException("VARCHAR only supports a single modifier"); - } - // FIXME: create CHECK constraint based on varchar width - width = 0; - result_type = LogicalType::VARCHAR; - break; - case LogicalTypeId::DECIMAL: - if (modifier_idx == 1) { - // only width is provided: set scale to 0 - scale = 0; - } - if (width <= 0 || width > Decimal::MAX_WIDTH_DECIMAL) { - throw ParserException("Width must be between 1 and %d!", (int)Decimal::MAX_WIDTH_DECIMAL); - } - if (scale > width) { - throw ParserException("Scale cannot be bigger than width"); - } - result_type = LogicalType::DECIMAL(width, scale); - break; - case LogicalTypeId::INTERVAL: - if (modifier_idx > 1) { - throw ParserException("INTERVAL only supports a single modifier"); - } - width = 0; - result_type = LogicalType::INTERVAL; - break; - case LogicalTypeId::USER: { - string user_type_name {name}; - result_type = LogicalType::USER(user_type_name); - break; - } - case LogicalTypeId::BIT: { - if (!width && type_name.typmods) { - throw ParserException("Type %s does not support any modifiers!", LogicalType(base_type).ToString()); - } - result_type = LogicalType(base_type); - break; - } - case LogicalTypeId::TIMESTAMP: - if (modifier_idx == 0) { - result_type = LogicalType::TIMESTAMP; - } else { - if (modifier_idx > 1) { - throw 
ParserException("TIMESTAMP only supports a single modifier"); - } - if (width > 10) { - throw ParserException("TIMESTAMP only supports until nano-second precision (9)"); - } - if (width == 0) { - result_type = LogicalType::TIMESTAMP_S; - } else if (width <= 3) { - result_type = LogicalType::TIMESTAMP_MS; - } else if (width <= 6) { - result_type = LogicalType::TIMESTAMP; - } else { - result_type = LogicalType::TIMESTAMP_NS; - } - } - break; - default: - if (modifier_idx > 0) { - throw ParserException("Type %s does not support any modifiers!", LogicalType(base_type).ToString()); - } - result_type = LogicalType(base_type); - break; - } - } - if (type_name.arrayBounds) { - // array bounds: turn the type into a list - idx_t extra_stack = 0; - for (auto cell = type_name.arrayBounds->head; cell != nullptr; cell = cell->next) { - result_type = LogicalType::LIST(result_type); - StackCheck(extra_stack++); - } - } - return result_type; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformAlterSequence(duckdb_libpgquery::PGAlterSeqStmt &stmt) { - auto result = make_uniq(); - - auto qname = TransformQualifiedName(*stmt.sequence); - auto sequence_catalog = qname.catalog; - auto sequence_schema = qname.schema; - auto sequence_name = qname.name; - - if (!stmt.options) { - throw InternalException("Expected an argument for ALTER SEQUENCE."); - } - - unordered_set used; - duckdb_libpgquery::PGListCell *cell; - for_each_cell(cell, stmt.options->head) { - auto def_elem = PGPointerCast(cell->data.ptr_value); - string opt_name = string(def_elem->defname); - - if (opt_name == "owned_by") { - if (used.find(SequenceInfo::SEQ_OWN) != used.end()) { - throw ParserException("Owned by value should be passed as most once"); - } - used.insert(SequenceInfo::SEQ_OWN); - - auto val = PGPointerCast(def_elem->arg); - if (!val) { - throw InternalException("Expected an argument for option %s", opt_name); - } - D_ASSERT(val); - if (val->type != 
duckdb_libpgquery::T_PGList) { - throw InternalException("Expected a string argument for option %s", opt_name); - } - auto opt_values = vector(); - - for (auto c = val->head; c != nullptr; c = lnext(c)) { - auto target = PGPointerCast(c->data.ptr_value); - opt_values.emplace_back(target->name); - } - D_ASSERT(!opt_values.empty()); - string owner_schema = INVALID_SCHEMA; - string owner_name; - if (opt_values.size() == 2) { - owner_schema = opt_values[0]; - owner_name = opt_values[1]; - } else if (opt_values.size() == 1) { - owner_schema = DEFAULT_SCHEMA; - owner_name = opt_values[0]; - } else { - throw InternalException("Wrong argument for %s. Expected either . or ", opt_name); - } - auto info = make_uniq(CatalogType::SEQUENCE_ENTRY, sequence_catalog, sequence_schema, - sequence_name, owner_schema, owner_name, - TransformOnEntryNotFound(stmt.missing_ok)); - result->info = std::move(info); - } else { - throw NotImplementedException("ALTER SEQUENCE option not supported yet!"); - } - } - result->info->if_not_found = TransformOnEntryNotFound(stmt.missing_ok); - return result; -} -} // namespace duckdb - - - - - - -namespace duckdb { - -OnEntryNotFound Transformer::TransformOnEntryNotFound(bool missing_ok) { - return missing_ok ? 
OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; -} - -unique_ptr Transformer::TransformAlter(duckdb_libpgquery::PGAlterTableStmt &stmt) { - D_ASSERT(stmt.relation); - - if (stmt.cmds->length != 1) { - throw ParserException("Only one ALTER command per statement is supported"); - } - - auto result = make_uniq(); - auto qname = TransformQualifiedName(*stmt.relation); - - // first we check the type of ALTER - for (auto c = stmt.cmds->head; c != nullptr; c = c->next) { - auto command = reinterpret_cast(lfirst(c)); - AlterEntryData data(qname.catalog, qname.schema, qname.name, TransformOnEntryNotFound(stmt.missing_ok)); - // TODO: Include more options for command->subtype - switch (command->subtype) { - case duckdb_libpgquery::PG_AT_AddColumn: { - auto cdef = PGPointerCast(command->def); - - if (stmt.relkind != duckdb_libpgquery::PG_OBJECT_TABLE) { - throw ParserException("Adding columns is only supported for tables"); - } - if (cdef->category == duckdb_libpgquery::COL_GENERATED) { - throw ParserException("Adding generated columns after table creation is not supported yet"); - } - auto centry = TransformColumnDefinition(*cdef); - - if (cdef->constraints) { - for (auto constr = cdef->constraints->head; constr != nullptr; constr = constr->next) { - auto constraint = TransformConstraint(constr, centry, 0); - if (!constraint) { - continue; - } - throw ParserException("Adding columns with constraints not yet supported"); - } - } - result->info = make_uniq(std::move(data), std::move(centry), command->missing_ok); - break; - } - case duckdb_libpgquery::PG_AT_DropColumn: { - bool cascade = command->behavior == duckdb_libpgquery::PG_DROP_CASCADE; - - if (stmt.relkind != duckdb_libpgquery::PG_OBJECT_TABLE) { - throw ParserException("Dropping columns is only supported for tables"); - } - result->info = make_uniq(std::move(data), command->name, command->missing_ok, cascade); - break; - } - case duckdb_libpgquery::PG_AT_ColumnDefault: { - auto expr = 
TransformExpression(command->def); - - if (stmt.relkind != duckdb_libpgquery::PG_OBJECT_TABLE) { - throw ParserException("Alter column's default is only supported for tables"); - } - result->info = make_uniq(std::move(data), command->name, std::move(expr)); - break; - } - case duckdb_libpgquery::PG_AT_AlterColumnType: { - auto cdef = PGPointerCast(command->def); - auto column_definition = TransformColumnDefinition(*cdef); - unique_ptr expr; - - if (stmt.relkind != duckdb_libpgquery::PG_OBJECT_TABLE) { - throw ParserException("Alter column's type is only supported for tables"); - } - if (cdef->raw_default) { - expr = TransformExpression(cdef->raw_default); - } else { - auto colref = make_uniq(command->name); - expr = make_uniq(column_definition.Type(), std::move(colref)); - } - result->info = make_uniq(std::move(data), command->name, column_definition.Type(), - std::move(expr)); - break; - } - case duckdb_libpgquery::PG_AT_SetNotNull: { - result->info = make_uniq(std::move(data), command->name); - break; - } - case duckdb_libpgquery::PG_AT_DropNotNull: { - result->info = make_uniq(std::move(data), command->name); - break; - } - case duckdb_libpgquery::PG_AT_DropConstraint: - default: - throw NotImplementedException("ALTER TABLE option not supported yet!"); - } - } - - return result; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformAttach(duckdb_libpgquery::PGAttachStmt &stmt) { - auto result = make_uniq(); - auto info = make_uniq(); - info->name = stmt.name ? 
stmt.name : string(); - info->path = stmt.path; - - if (stmt.options) { - duckdb_libpgquery::PGListCell *cell; - for_each_cell(cell, stmt.options->head) { - auto def_elem = PGPointerCast(cell->data.ptr_value); - Value val; - if (def_elem->arg) { - val = TransformValue(*PGPointerCast(def_elem->arg))->value; - } else { - val = Value::BOOLEAN(true); - } - info->options[StringUtil::Lower(def_elem->defname)] = std::move(val); - } - } - result->info = std::move(info); - return result; -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformCall(duckdb_libpgquery::PGCallStmt &stmt) { - auto result = make_uniq(); - result->function = TransformFuncCall(*PGPointerCast(stmt.func)); - return result; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformCheckpoint(duckdb_libpgquery::PGCheckPointStmt &stmt) { - vector> children; - // transform into "CALL checkpoint()" or "CALL force_checkpoint()" - auto checkpoint_name = stmt.force ? "force_checkpoint" : "checkpoint"; - auto result = make_uniq(); - auto function = make_uniq(checkpoint_name, std::move(children)); - if (stmt.name) { - function->children.push_back(make_uniq(Value(stmt.name))); - } - result->function = std::move(function); - return std::move(result); -} - -} // namespace duckdb - - - - - - - - - -#include - -namespace duckdb { - -void Transformer::TransformCopyOptions(CopyInfo &info, optional_ptr options) { - if (!options) { - return; - } - - // iterate over each option - duckdb_libpgquery::PGListCell *cell; - for_each_cell(cell, options->head) { - auto def_elem = PGPointerCast(cell->data.ptr_value); - if (StringUtil::Lower(def_elem->defname) == "format") { - // format specifier: interpret this option - auto format_val = PGPointerCast(def_elem->arg); - if (!format_val || format_val->type != duckdb_libpgquery::T_PGString) { - throw ParserException("Unsupported parameter type for FORMAT: expected e.g. 
FORMAT 'csv', 'parquet'"); - } - info.format = StringUtil::Lower(format_val->val.str); - continue; - } - // otherwise - if (info.options.find(def_elem->defname) != info.options.end()) { - throw ParserException("Unexpected duplicate option \"%s\"", def_elem->defname); - } - if (!def_elem->arg) { - info.options[def_elem->defname] = vector(); - continue; - } - switch (def_elem->arg->type) { - case duckdb_libpgquery::T_PGList: { - auto column_list = PGPointerCast(def_elem->arg); - for (auto c = column_list->head; c != nullptr; c = lnext(c)) { - auto target = PGPointerCast(c->data.ptr_value); - info.options[def_elem->defname].push_back(Value(target->name)); - } - break; - } - case duckdb_libpgquery::T_PGAStar: - info.options[def_elem->defname].push_back(Value("*")); - break; - case duckdb_libpgquery::T_PGFuncCall: { - auto func_call = PGPointerCast(def_elem->arg); - auto func_expr = TransformFuncCall(*func_call); - - Value value; - if (!Transformer::ConstructConstantFromExpression(*func_expr, value)) { - throw ParserException("Unsupported expression in COPY options: %s", func_expr->ToString()); - } - info.options[def_elem->defname].push_back(std::move(value)); - break; - } - default: { - auto val = PGPointerCast(def_elem->arg); - info.options[def_elem->defname].push_back(TransformValue(*val)->value); - break; - } - } - } -} - -unique_ptr Transformer::TransformCopy(duckdb_libpgquery::PGCopyStmt &stmt) { - auto result = make_uniq(); - auto &info = *result->info; - - // get file_path and is_from - info.is_from = stmt.is_from; - if (!stmt.filename) { - // stdin/stdout - info.file_path = info.is_from ? 
"/dev/stdin" : "/dev/stdout"; - } else { - // copy to a file - info.file_path = stmt.filename; - } - if (StringUtil::EndsWith(info.file_path, ".parquet")) { - info.format = "parquet"; - } else if (StringUtil::EndsWith(info.file_path, ".json") || StringUtil::EndsWith(info.file_path, ".ndjson")) { - info.format = "json"; - } else { - info.format = "csv"; - } - - // get select_list - if (stmt.attlist) { - for (auto n = stmt.attlist->head; n != nullptr; n = n->next) { - auto target = PGPointerCast(n->data.ptr_value); - if (target->name) { - info.select_list.emplace_back(target->name); - } - } - } - - if (stmt.relation) { - auto ref = TransformRangeVar(*stmt.relation); - auto &table = ref->Cast(); - info.table = table.table_name; - info.schema = table.schema_name; - info.catalog = table.catalog_name; - } else { - result->select_statement = TransformSelectNode(*PGPointerCast(stmt.query)); - } - - // handle the different options of the COPY statement - TransformCopyOptions(info, stmt.options); - - return result; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformCreateFunction(duckdb_libpgquery::PGCreateFunctionStmt &stmt) { - D_ASSERT(stmt.type == duckdb_libpgquery::T_PGCreateFunctionStmt); - D_ASSERT(stmt.function || stmt.query); - - auto result = make_uniq(); - auto qname = TransformQualifiedName(*stmt.name); - - unique_ptr macro_func; - - // function can be null here - if (stmt.function) { - auto expression = TransformExpression(stmt.function); - macro_func = make_uniq(std::move(expression)); - } else if (stmt.query) { - auto query_node = - TransformSelect(*PGPointerCast(stmt.query), true)->node->Copy(); - macro_func = make_uniq(std::move(query_node)); - } - PivotEntryCheck("macro"); - - auto info = make_uniq(stmt.function ? 
CatalogType::MACRO_ENTRY : CatalogType::TABLE_MACRO_ENTRY); - info->catalog = qname.catalog; - info->schema = qname.schema; - info->name = qname.name; - - // temporary macro - switch (stmt.name->relpersistence) { - case duckdb_libpgquery::PG_RELPERSISTENCE_TEMP: - info->temporary = true; - break; - case duckdb_libpgquery::PG_RELPERSISTENCE_UNLOGGED: - throw ParserException("Unlogged flag not supported for macros: '%s'", qname.name); - break; - case duckdb_libpgquery::RELPERSISTENCE_PERMANENT: - info->temporary = false; - break; - } - - // what to do on conflict - info->on_conflict = TransformOnConflict(stmt.onconflict); - - if (stmt.params) { - vector> parameters; - TransformExpressionList(*stmt.params, parameters); - for (auto ¶m : parameters) { - if (param->type == ExpressionType::VALUE_CONSTANT) { - // parameters with default value (must have an alias) - if (param->alias.empty()) { - throw ParserException("Invalid parameter: '%s'", param->ToString()); - } - if (macro_func->default_parameters.find(param->alias) != macro_func->default_parameters.end()) { - throw ParserException("Duplicate default parameter: '%s'", param->alias); - } - macro_func->default_parameters[param->alias] = std::move(param); - } else if (param->GetExpressionClass() == ExpressionClass::COLUMN_REF) { - // positional parameters - if (!macro_func->default_parameters.empty()) { - throw ParserException("Positional parameters cannot come after parameters with a default value!"); - } - macro_func->parameters.push_back(std::move(param)); - } else { - throw ParserException("Invalid parameter: '%s'", param->ToString()); - } - } - } - - info->function = std::move(macro_func); - result->info = std::move(info); - - return result; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -vector> Transformer::TransformIndexParameters(duckdb_libpgquery::PGList &list, - const string &relation_name) { - vector> expressions; - for (auto cell = list.head; cell != nullptr; cell = cell->next) { - auto 
index_element = PGPointerCast(cell->data.ptr_value); - if (index_element->collation) { - throw NotImplementedException("Index with collation not supported yet!"); - } - if (index_element->opclass) { - throw NotImplementedException("Index with opclass not supported yet!"); - } - - if (index_element->name) { - // create a column reference expression - expressions.push_back(make_uniq(index_element->name, relation_name)); - } else { - // parse the index expression - D_ASSERT(index_element->expr); - expressions.push_back(TransformExpression(index_element->expr)); - } - } - return expressions; -} - -unique_ptr Transformer::TransformCreateIndex(duckdb_libpgquery::PGIndexStmt &stmt) { - auto result = make_uniq(); - auto info = make_uniq(); - if (stmt.unique) { - info->constraint_type = IndexConstraintType::UNIQUE; - } else { - info->constraint_type = IndexConstraintType::NONE; - } - - info->on_conflict = TransformOnConflict(stmt.onconflict); - - info->expressions = TransformIndexParameters(*stmt.indexParams, stmt.relation->relname); - - auto index_type_name = StringUtil::Upper(string(stmt.accessMethod)); - - if (index_type_name == "ART") { - info->index_type = IndexType::ART; - } else { - info->index_type = IndexType::EXTENSION; - } - - info->index_type_name = index_type_name; - - if (stmt.relation->schemaname) { - info->schema = stmt.relation->schemaname; - } - if (stmt.relation->catalogname) { - info->catalog = stmt.relation->catalogname; - } - info->table = stmt.relation->relname; - if (stmt.idxname) { - info->index_name = stmt.idxname; - } else { - throw NotImplementedException("Index without a name not supported yet!"); - } - - // Parse the options list - if (stmt.options) { - duckdb_libpgquery::PGListCell *cell; - for_each_cell(cell, stmt.options->head) { - auto def_elem = PGPointerCast(cell->data.ptr_value); - Value val; - if (def_elem->arg) { - val = TransformValue(*PGPointerCast(def_elem->arg))->value; - } else { - val = Value::BOOLEAN(true); - } - 
info->options[StringUtil::Lower(def_elem->defname)] = std::move(val); - } - } - - for (auto &expr : info->expressions) { - info->parsed_expressions.emplace_back(expr->Copy()); - } - result->info = std::move(info); - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformCreateSchema(duckdb_libpgquery::PGCreateSchemaStmt &stmt) { - auto result = make_uniq(); - auto info = make_uniq(); - - D_ASSERT(stmt.schemaname); - info->catalog = stmt.catalogname ? stmt.catalogname : INVALID_CATALOG; - info->schema = stmt.schemaname; - info->on_conflict = TransformOnConflict(stmt.onconflict); - - if (stmt.schemaElts) { - // schema elements - for (auto cell = stmt.schemaElts->head; cell != nullptr; cell = cell->next) { - auto node = PGPointerCast(cell->data.ptr_value); - switch (node->type) { - case duckdb_libpgquery::T_PGCreateStmt: - case duckdb_libpgquery::T_PGViewStmt: - default: - throw NotImplementedException("Schema element not supported yet!"); - } - } - } - result->info = std::move(info); - return result; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformCreateSequence(duckdb_libpgquery::PGCreateSeqStmt &stmt) { - auto result = make_uniq(); - auto info = make_uniq(); - - auto qname = TransformQualifiedName(*stmt.sequence); - info->catalog = qname.catalog; - info->schema = qname.schema; - info->name = qname.name; - - if (stmt.options) { - unordered_set used; - duckdb_libpgquery::PGListCell *cell = nullptr; - for_each_cell(cell, stmt.options->head) { - auto def_elem = PGPointerCast(cell->data.ptr_value); - string opt_name = string(def_elem->defname); - auto val = PGPointerCast(def_elem->arg); - bool nodef = def_elem->defaction == duckdb_libpgquery::PG_DEFELEM_UNSPEC && !val; // e.g. 
NO MINVALUE - int64_t opt_value = 0; - - if (val) { - if (val->type == duckdb_libpgquery::T_PGInteger) { - opt_value = val->val.ival; - } else if (val->type == duckdb_libpgquery::T_PGFloat) { - if (!TryCast::Operation(string_t(val->val.str), opt_value, true)) { - throw ParserException("Expected an integer argument for option %s", opt_name); - } - } else { - throw ParserException("Expected an integer argument for option %s", opt_name); - } - } - if (opt_name == "increment") { - if (used.find(SequenceInfo::SEQ_INC) != used.end()) { - throw ParserException("Increment value should be passed as most once"); - } - used.insert(SequenceInfo::SEQ_INC); - if (nodef) { - continue; - } - - info->increment = opt_value; - if (info->increment == 0) { - throw ParserException("Increment must not be zero"); - } - if (info->increment < 0) { - info->start_value = info->max_value = -1; - info->min_value = NumericLimits::Minimum(); - } else { - info->start_value = info->min_value = 1; - info->max_value = NumericLimits::Maximum(); - } - } else if (opt_name == "minvalue") { - if (used.find(SequenceInfo::SEQ_MIN) != used.end()) { - throw ParserException("Minvalue should be passed as most once"); - } - used.insert(SequenceInfo::SEQ_MIN); - if (nodef) { - continue; - } - - info->min_value = opt_value; - if (info->increment > 0) { - info->start_value = info->min_value; - } - } else if (opt_name == "maxvalue") { - if (used.find(SequenceInfo::SEQ_MAX) != used.end()) { - throw ParserException("Maxvalue should be passed as most once"); - } - used.insert(SequenceInfo::SEQ_MAX); - if (nodef) { - continue; - } - - info->max_value = opt_value; - if (info->increment < 0) { - info->start_value = info->max_value; - } - } else if (opt_name == "start") { - if (used.find(SequenceInfo::SEQ_START) != used.end()) { - throw ParserException("Start value should be passed as most once"); - } - used.insert(SequenceInfo::SEQ_START); - if (nodef) { - continue; - } - - info->start_value = opt_value; - } else if 
(opt_name == "cycle") { - if (used.find(SequenceInfo::SEQ_CYCLE) != used.end()) { - throw ParserException("Cycle value should be passed as most once"); - } - used.insert(SequenceInfo::SEQ_CYCLE); - if (nodef) { - continue; - } - - info->cycle = opt_value > 0; - } else { - throw ParserException("Unrecognized option \"%s\" for CREATE SEQUENCE", opt_name); - } - } - } - info->temporary = !stmt.sequence->relpersistence; - info->on_conflict = TransformOnConflict(stmt.onconflict); - if (info->max_value <= info->min_value) { - throw ParserException("MINVALUE (%lld) must be less than MAXVALUE (%lld)", info->min_value, info->max_value); - } - if (info->start_value < info->min_value) { - throw ParserException("START value (%lld) cannot be less than MINVALUE (%lld)", info->start_value, - info->min_value); - } - if (info->start_value > info->max_value) { - throw ParserException("START value (%lld) cannot be greater than MAXVALUE (%lld)", info->start_value, - info->max_value); - } - result->info = std::move(info); - return result; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -string Transformer::TransformCollation(optional_ptr collate) { - if (!collate) { - return string(); - } - string collation; - for (auto c = collate->collname->head; c != nullptr; c = lnext(c)) { - auto pgvalue = PGPointerCast(c->data.ptr_value); - if (pgvalue->type != duckdb_libpgquery::T_PGString) { - throw ParserException("Expected a string as collation type!"); - } - auto collation_argument = string(pgvalue->val.str); - if (collation.empty()) { - collation = collation_argument; - } else { - collation += "." 
+ collation_argument; - } - } - return collation; -} - -OnCreateConflict Transformer::TransformOnConflict(duckdb_libpgquery::PGOnCreateConflict conflict) { - switch (conflict) { - case duckdb_libpgquery::PG_ERROR_ON_CONFLICT: - return OnCreateConflict::ERROR_ON_CONFLICT; - case duckdb_libpgquery::PG_IGNORE_ON_CONFLICT: - return OnCreateConflict::IGNORE_ON_CONFLICT; - case duckdb_libpgquery::PG_REPLACE_ON_CONFLICT: - return OnCreateConflict::REPLACE_ON_CONFLICT; - default: - throw InternalException("Unrecognized OnConflict type"); - } -} - -unique_ptr Transformer::TransformCollateExpr(duckdb_libpgquery::PGCollateClause &collate) { - auto child = TransformExpression(collate.arg); - auto collation = TransformCollation(&collate); - return make_uniq(collation, std::move(child)); -} - -ColumnDefinition Transformer::TransformColumnDefinition(duckdb_libpgquery::PGColumnDef &cdef) { - string colname; - if (cdef.colname) { - colname = cdef.colname; - } - bool optional_type = cdef.category == duckdb_libpgquery::COL_GENERATED; - LogicalType target_type = (optional_type && !cdef.typeName) ? 
LogicalType::ANY : TransformTypeName(*cdef.typeName); - if (cdef.collClause) { - if (cdef.category == duckdb_libpgquery::COL_GENERATED) { - throw ParserException("Collations are not supported on generated columns"); - } - if (target_type.id() != LogicalTypeId::VARCHAR) { - throw ParserException("Only VARCHAR columns can have collations!"); - } - target_type = LogicalType::VARCHAR_COLLATION(TransformCollation(cdef.collClause)); - } - - return ColumnDefinition(colname, target_type); -} - -unique_ptr Transformer::TransformCreateTable(duckdb_libpgquery::PGCreateStmt &stmt) { - auto result = make_uniq(); - auto info = make_uniq(); - - if (stmt.inhRelations) { - throw NotImplementedException("inherited relations not implemented"); - } - D_ASSERT(stmt.relation); - - info->catalog = INVALID_CATALOG; - auto qname = TransformQualifiedName(*stmt.relation); - info->catalog = qname.catalog; - info->schema = qname.schema; - info->table = qname.name; - info->on_conflict = TransformOnConflict(stmt.onconflict); - info->temporary = - stmt.relation->relpersistence == duckdb_libpgquery::PGPostgresRelPersistence::PG_RELPERSISTENCE_TEMP; - - if (info->temporary && stmt.oncommit != duckdb_libpgquery::PGOnCommitAction::PG_ONCOMMIT_PRESERVE_ROWS && - stmt.oncommit != duckdb_libpgquery::PGOnCommitAction::PG_ONCOMMIT_NOOP) { - throw NotImplementedException("Only ON COMMIT PRESERVE ROWS is supported"); - } - if (!stmt.tableElts) { - throw ParserException("Table must have at least one column!"); - } - - idx_t column_count = 0; - for (auto c = stmt.tableElts->head; c != nullptr; c = lnext(c)) { - auto node = PGPointerCast(c->data.ptr_value); - switch (node->type) { - case duckdb_libpgquery::T_PGColumnDef: { - auto cdef = PGPointerCast(c->data.ptr_value); - auto centry = TransformColumnDefinition(*cdef); - if (cdef->constraints) { - for (auto constr = cdef->constraints->head; constr != nullptr; constr = constr->next) { - auto constraint = TransformConstraint(constr, centry, 
info->columns.LogicalColumnCount()); - if (constraint) { - info->constraints.push_back(std::move(constraint)); - } - } - } - info->columns.AddColumn(std::move(centry)); - column_count++; - break; - } - case duckdb_libpgquery::T_PGConstraint: { - info->constraints.push_back(TransformConstraint(c)); - break; - } - default: - throw NotImplementedException("ColumnDef type not handled yet"); - } - } - - if (!column_count) { - throw ParserException("Table must have at least one column!"); - } - - result->info = std::move(info); - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformCreateTableAs(duckdb_libpgquery::PGCreateTableAsStmt &stmt) { - if (stmt.relkind == duckdb_libpgquery::PG_OBJECT_MATVIEW) { - throw NotImplementedException("Materialized view not implemented"); - } - if (stmt.is_select_into || stmt.into->colNames || stmt.into->options) { - throw NotImplementedException("Unimplemented features for CREATE TABLE as"); - } - auto qname = TransformQualifiedName(*stmt.into->rel); - if (stmt.query->type != duckdb_libpgquery::T_PGSelectStmt) { - throw ParserException("CREATE TABLE AS requires a SELECT clause"); - } - auto query = TransformSelect(stmt.query, false); - - auto result = make_uniq(); - auto info = make_uniq(); - info->catalog = qname.catalog; - info->schema = qname.schema; - info->table = qname.name; - info->on_conflict = TransformOnConflict(stmt.onconflict); - info->temporary = - stmt.into->rel->relpersistence == duckdb_libpgquery::PGPostgresRelPersistence::PG_RELPERSISTENCE_TEMP; - info->query = std::move(query); - result->info = std::move(info); - return result; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -Vector Transformer::PGListToVector(optional_ptr column_list, idx_t &size) { - if (!column_list) { - Vector result(LogicalType::VARCHAR); - return result; - } - // First we discover the size of this list - for (auto c = column_list->head; c != nullptr; c = lnext(c)) { - size++; - } 
- - Vector result(LogicalType::VARCHAR, size); - auto result_ptr = FlatVector::GetData(result); - - size = 0; - for (auto c = column_list->head; c != nullptr; c = lnext(c)) { - auto &type_val = *PGPointerCast(c->data.ptr_value); - auto &entry_value_node = type_val.val; - if (entry_value_node.type != duckdb_libpgquery::T_PGString) { - throw ParserException("Expected a string constant as value"); - } - - auto entry_value = string(entry_value_node.val.str); - D_ASSERT(!entry_value.empty()); - result_ptr[size++] = StringVector::AddStringOrBlob(result, entry_value); - } - return result; -} - -unique_ptr Transformer::TransformCreateType(duckdb_libpgquery::PGCreateTypeStmt &stmt) { - auto result = make_uniq(); - auto info = make_uniq(); - - auto qualified_name = TransformQualifiedName(*stmt.typeName); - info->catalog = qualified_name.catalog; - info->schema = qualified_name.schema; - info->name = qualified_name.name; - - switch (stmt.kind) { - case duckdb_libpgquery::PG_NEWTYPE_ENUM: { - info->internal = false; - if (stmt.query) { - // CREATE TYPE mood AS ENUM (SELECT ...) 
- D_ASSERT(stmt.vals == nullptr); - auto query = TransformSelect(stmt.query, false); - info->query = std::move(query); - info->type = LogicalType::INVALID; - } else { - D_ASSERT(stmt.query == nullptr); - idx_t size = 0; - auto ordered_array = PGListToVector(stmt.vals, size); - info->type = LogicalType::ENUM(ordered_array, size); - } - } break; - - case duckdb_libpgquery::PG_NEWTYPE_ALIAS: { - LogicalType target_type = TransformTypeName(*stmt.ofType); - info->type = target_type; - } break; - - default: - throw InternalException("Unknown kind of new type"); - } - result->info = std::move(info); - return result; -} -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformCreateView(duckdb_libpgquery::PGViewStmt &stmt) { - D_ASSERT(stmt.type == duckdb_libpgquery::T_PGViewStmt); - D_ASSERT(stmt.view); - - auto result = make_uniq(); - auto info = make_uniq(); - - auto qname = TransformQualifiedName(*stmt.view); - info->catalog = qname.catalog; - info->schema = qname.schema; - info->view_name = qname.name; - info->temporary = !stmt.view->relpersistence; - if (info->temporary && IsInvalidCatalog(info->catalog)) { - info->catalog = TEMP_CATALOG; - } - info->on_conflict = TransformOnConflict(stmt.onconflict); - - info->query = TransformSelect(*PGPointerCast(stmt.query), false); - - PivotEntryCheck("view"); - - if (stmt.aliases && stmt.aliases->length > 0) { - for (auto c = stmt.aliases->head; c != nullptr; c = lnext(c)) { - auto val = PGPointerCast(c->data.ptr_value); - switch (val->type) { - case duckdb_libpgquery::T_PGString: { - info->aliases.emplace_back(val->val.str); - break; - } - default: - throw NotImplementedException("View projection type"); - } - } - if (info->aliases.empty()) { - throw ParserException("Need at least one column name in CREATE VIEW projection list"); - } - } - - if (stmt.options && stmt.options->length > 0) { - throw NotImplementedException("VIEW options"); - } - - if (stmt.withCheckOption != 
duckdb_libpgquery::PGViewCheckOption::PG_NO_CHECK_OPTION) { - throw NotImplementedException("VIEW CHECK options"); - } - result->info = std::move(info); - return result; -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformDelete(duckdb_libpgquery::PGDeleteStmt &stmt) { - auto result = make_uniq(); - vector> materialized_ctes; - if (stmt.withClause) { - TransformCTE(*PGPointerCast(stmt.withClause), result->cte_map, - materialized_ctes); - if (!materialized_ctes.empty()) { - throw NotImplementedException("Materialized CTEs are not implemented for delete."); - } - } - - result->condition = TransformExpression(stmt.whereClause); - result->table = TransformRangeVar(*stmt.relation); - if (result->table->type != TableReferenceType::BASE_TABLE) { - throw Exception("Can only delete from base tables!"); - } - if (stmt.usingClause) { - for (auto n = stmt.usingClause->head; n != nullptr; n = n->next) { - auto target = PGPointerCast(n->data.ptr_value); - auto using_entry = TransformTableRefNode(*target); - result->using_clauses.push_back(std::move(using_entry)); - } - } - - if (stmt.returningList) { - TransformExpressionList(*stmt.returningList, result->returning_list); - } - return result; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformDetach(duckdb_libpgquery::PGDetachStmt &stmt) { - auto result = make_uniq(); - auto info = make_uniq(); - info->name = stmt.db_name; - info->if_not_found = TransformOnEntryNotFound(stmt.missing_ok); - - result->info = std::move(info); - return result; -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformDrop(duckdb_libpgquery::PGDropStmt &stmt) { - auto result = make_uniq(); - auto &info = *result->info.get(); - if (stmt.objects->length != 1) { - throw NotImplementedException("Can only drop one object at a time"); - } - switch (stmt.removeType) { - case duckdb_libpgquery::PG_OBJECT_TABLE: - info.type = 
CatalogType::TABLE_ENTRY; - break; - case duckdb_libpgquery::PG_OBJECT_SCHEMA: - info.type = CatalogType::SCHEMA_ENTRY; - break; - case duckdb_libpgquery::PG_OBJECT_INDEX: - info.type = CatalogType::INDEX_ENTRY; - break; - case duckdb_libpgquery::PG_OBJECT_VIEW: - info.type = CatalogType::VIEW_ENTRY; - break; - case duckdb_libpgquery::PG_OBJECT_SEQUENCE: - info.type = CatalogType::SEQUENCE_ENTRY; - break; - case duckdb_libpgquery::PG_OBJECT_FUNCTION: - info.type = CatalogType::MACRO_ENTRY; - break; - case duckdb_libpgquery::PG_OBJECT_TABLE_MACRO: - info.type = CatalogType::TABLE_MACRO_ENTRY; - break; - case duckdb_libpgquery::PG_OBJECT_TYPE: - info.type = CatalogType::TYPE_ENTRY; - break; - default: - throw NotImplementedException("Cannot drop this type yet"); - } - - switch (stmt.removeType) { - case duckdb_libpgquery::PG_OBJECT_TYPE: { - auto view_list = PGPointerCast(stmt.objects); - auto target = PGPointerCast(view_list->head->data.ptr_value); - info.name = PGPointerCast(target->names->tail->data.ptr_value)->val.str; - break; - } - case duckdb_libpgquery::PG_OBJECT_SCHEMA: { - auto view_list = PGPointerCast(stmt.objects->head->data.ptr_value); - if (view_list->length == 2) { - info.catalog = PGPointerCast(view_list->head->data.ptr_value)->val.str; - info.name = PGPointerCast(view_list->head->next->data.ptr_value)->val.str; - } else if (view_list->length == 1) { - info.name = PGPointerCast(view_list->head->data.ptr_value)->val.str; - } else { - throw ParserException("Expected \"catalog.schema\" or \"schema\""); - } - break; - } - default: { - auto view_list = PGPointerCast(stmt.objects->head->data.ptr_value); - if (view_list->length == 3) { - info.catalog = PGPointerCast(view_list->head->data.ptr_value)->val.str; - info.schema = PGPointerCast(view_list->head->next->data.ptr_value)->val.str; - info.name = PGPointerCast(view_list->head->next->next->data.ptr_value)->val.str; - } else if (view_list->length == 2) { - info.schema = 
PGPointerCast(view_list->head->data.ptr_value)->val.str; - info.name = PGPointerCast(view_list->head->next->data.ptr_value)->val.str; - } else if (view_list->length == 1) { - info.name = PGPointerCast(view_list->head->data.ptr_value)->val.str; - } else { - throw ParserException("Expected \"catalog.schema.name\", \"schema.name\"or \"name\""); - } - break; - } - } - info.cascade = stmt.behavior == duckdb_libpgquery::PGDropBehavior::PG_DROP_CASCADE; - info.if_not_found = TransformOnEntryNotFound(stmt.missing_ok); - return std::move(result); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformExplain(duckdb_libpgquery::PGExplainStmt &stmt) { - auto explain_type = ExplainType::EXPLAIN_STANDARD; - if (stmt.options) { - for (auto n = stmt.options->head; n; n = n->next) { - auto def_elem = PGPointerCast(n->data.ptr_value)->defname; - string elem(def_elem); - if (elem == "analyze") { - explain_type = ExplainType::EXPLAIN_ANALYZE; - } else { - throw NotImplementedException("Unimplemented explain type: %s", elem); - } - } - } - return make_uniq(TransformStatement(*stmt.query), explain_type); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformExport(duckdb_libpgquery::PGExportStmt &stmt) { - auto info = make_uniq(); - info->file_path = stmt.filename; - info->format = "csv"; - info->is_from = false; - // handle export options - TransformCopyOptions(*info, stmt.options); - - auto result = make_uniq(std::move(info)); - if (stmt.database) { - result->database = stmt.database; - } - return result; -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformImport(duckdb_libpgquery::PGImportStmt &stmt) { - auto result = make_uniq(); - result->info->name = "import_database"; - result->info->parameters.emplace_back(stmt.filename); - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformValuesList(duckdb_libpgquery::PGList 
*list) { - auto result = make_uniq(); - for (auto value_list = list->head; value_list != nullptr; value_list = value_list->next) { - auto target = PGPointerCast(value_list->data.ptr_value); - - vector> insert_values; - TransformExpressionList(*target, insert_values); - if (!result->values.empty()) { - if (result->values[0].size() != insert_values.size()) { - throw ParserException("VALUES lists must all be the same length"); - } - } - result->values.push_back(std::move(insert_values)); - } - result->alias = "valueslist"; - return std::move(result); -} - -unique_ptr Transformer::TransformInsert(duckdb_libpgquery::PGInsertStmt &stmt) { - auto result = make_uniq(); - vector> materialized_ctes; - if (stmt.withClause) { - TransformCTE(*PGPointerCast(stmt.withClause), result->cte_map, - materialized_ctes); - if (!materialized_ctes.empty()) { - throw NotImplementedException("Materialized CTEs are not implemented for insert."); - } - } - - // first check if there are any columns specified - if (stmt.cols) { - for (auto c = stmt.cols->head; c != nullptr; c = lnext(c)) { - auto target = PGPointerCast(c->data.ptr_value); - result->columns.emplace_back(target->name); - } - } - - // Grab and transform the returning columns from the parser. 
- if (stmt.returningList) { - TransformExpressionList(*stmt.returningList, result->returning_list); - } - if (stmt.selectStmt) { - result->select_statement = TransformSelect(stmt.selectStmt, false); - } else { - result->default_values = true; - } - - auto qname = TransformQualifiedName(*stmt.relation); - result->table = qname.name; - result->schema = qname.schema; - - if (stmt.onConflictClause) { - if (stmt.onConflictAlias != duckdb_libpgquery::PG_ONCONFLICT_ALIAS_NONE) { - // OR REPLACE | OR IGNORE are shorthands for the ON CONFLICT clause - throw ParserException("You can not provide both OR REPLACE|IGNORE and an ON CONFLICT clause, please remove " - "the first if you want to have more granual control"); - } - result->on_conflict_info = TransformOnConflictClause(stmt.onConflictClause, result->schema); - result->table_ref = TransformRangeVar(*stmt.relation); - } - if (stmt.onConflictAlias != duckdb_libpgquery::PG_ONCONFLICT_ALIAS_NONE) { - D_ASSERT(!stmt.onConflictClause); - result->on_conflict_info = DummyOnConflictClause(stmt.onConflictAlias, result->schema); - result->table_ref = TransformRangeVar(*stmt.relation); - } - switch (stmt.insert_column_order) { - case duckdb_libpgquery::PG_INSERT_BY_POSITION: - result->column_order = InsertColumnOrder::INSERT_BY_POSITION; - break; - case duckdb_libpgquery::PG_INSERT_BY_NAME: - result->column_order = InsertColumnOrder::INSERT_BY_NAME; - break; - default: - throw InternalException("Unrecognized insert column order in TransformInsert"); - } - result->catalog = qname.catalog; - return result; -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformLoad(duckdb_libpgquery::PGLoadStmt &stmt) { - D_ASSERT(stmt.type == duckdb_libpgquery::T_PGLoadStmt); - - auto load_stmt = make_uniq(); - auto load_info = make_uniq(); - load_info->filename = std::string(stmt.filename); - load_info->repository = std::string(stmt.repository); - switch (stmt.load_type) { - case 
duckdb_libpgquery::PG_LOAD_TYPE_LOAD: - load_info->load_type = LoadType::LOAD; - break; - case duckdb_libpgquery::PG_LOAD_TYPE_INSTALL: - load_info->load_type = LoadType::INSTALL; - break; - case duckdb_libpgquery::PG_LOAD_TYPE_FORCE_INSTALL: - load_info->load_type = LoadType::FORCE_INSTALL; - break; - } - load_stmt->info = std::move(load_info); - return load_stmt; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - -namespace duckdb { - -void Transformer::AddPivotEntry(string enum_name, unique_ptr base, unique_ptr column, - unique_ptr subquery) { - if (parent) { - parent->AddPivotEntry(std::move(enum_name), std::move(base), std::move(column), std::move(subquery)); - return; - } - auto result = make_uniq(); - result->enum_name = std::move(enum_name); - result->base = std::move(base); - result->column = std::move(column); - result->subquery = std::move(subquery); - - pivot_entries.push_back(std::move(result)); -} - -bool Transformer::HasPivotEntries() { - return !GetPivotEntries().empty(); -} - -idx_t Transformer::PivotEntryCount() { - return GetPivotEntries().size(); -} - -vector> &Transformer::GetPivotEntries() { - if (parent) { - return parent->GetPivotEntries(); - } - return pivot_entries; -} - -void Transformer::PivotEntryCheck(const string &type) { - auto &entries = GetPivotEntries(); - if (!entries.empty()) { - throw ParserException( - "PIVOT statements with pivot elements extracted from the data cannot be used in %ss.\nIn order to use " - "PIVOT in a %s the PIVOT values must be manually specified, e.g.:\nPIVOT ... 
ON %s IN (val1, val2, ...)", - type, type, entries[0]->column->ToString()); - } -} -unique_ptr Transformer::GenerateCreateEnumStmt(unique_ptr entry) { - auto result = make_uniq(); - auto info = make_uniq(); - - info->temporary = true; - info->internal = false; - info->catalog = INVALID_CATALOG; - info->schema = INVALID_SCHEMA; - info->name = std::move(entry->enum_name); - info->on_conflict = OnCreateConflict::REPLACE_ON_CONFLICT; - - // generate the query that will result in the enum creation - unique_ptr subselect; - if (!entry->subquery) { - auto select_node = std::move(entry->base); - auto columnref = entry->column->Copy(); - auto cast = make_uniq(LogicalType::VARCHAR, std::move(columnref)); - select_node->select_list.push_back(std::move(cast)); - - auto is_not_null = - make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, std::move(entry->column)); - select_node->where_clause = std::move(is_not_null); - - // order by the column - select_node->modifiers.push_back(make_uniq()); - auto modifier = make_uniq(); - modifier->orders.emplace_back(OrderType::ASCENDING, OrderByNullType::ORDER_DEFAULT, - make_uniq(Value::INTEGER(1))); - select_node->modifiers.push_back(std::move(modifier)); - subselect = std::move(select_node); - } else { - subselect = std::move(entry->subquery); - } - - auto select = make_uniq(); - select->node = std::move(subselect); - info->query = std::move(select); - info->type = LogicalType::INVALID; - - result->info = std::move(info); - return std::move(result); -} - -// unique_ptr GenerateDropEnumStmt(string enum_name) { -// auto result = make_uniq(); -// result->info->if_exists = true; -// result->info->schema = INVALID_SCHEMA; -// result->info->catalog = INVALID_CATALOG; -// result->info->name = std::move(enum_name); -// result->info->type = CatalogType::TYPE_ENTRY; -// return std::move(result); -//} - -unique_ptr Transformer::CreatePivotStatement(unique_ptr statement) { - auto result = make_uniq(); - for (auto &pivot : pivot_entries) { - 
result->statements.push_back(GenerateCreateEnumStmt(std::move(pivot))); - } - result->statements.push_back(std::move(statement)); - // FIXME: drop the types again!? - // for(auto &pivot : pivot_entries) { - // result->statements.push_back(GenerateDropEnumStmt(std::move(pivot->enum_name))); - // } - return std::move(result); -} - -unique_ptr Transformer::TransformPivotStatement(duckdb_libpgquery::PGSelectStmt &select) { - auto pivot = select.pivot; - auto source = TransformTableRefNode(*pivot->source); - - auto select_node = make_uniq(); - vector> materialized_ctes; - // handle the CTEs - if (select.withClause) { - TransformCTE(*PGPointerCast(select.withClause), select_node->cte_map, - materialized_ctes); - } - if (!pivot->columns) { - // no pivot columns - not actually a pivot - select_node->from_table = std::move(source); - if (pivot->groups) { - auto groups = TransformStringList(pivot->groups); - GroupingSet set; - for (idx_t gr = 0; gr < groups.size(); gr++) { - auto &group = groups[gr]; - auto colref = make_uniq(group); - select_node->select_list.push_back(colref->Copy()); - select_node->groups.group_expressions.push_back(std::move(colref)); - set.insert(gr); - } - select_node->groups.grouping_sets.push_back(std::move(set)); - } - if (pivot->aggrs) { - TransformExpressionList(*pivot->aggrs, select_node->select_list); - } - return std::move(select_node); - } - - // generate CREATE TYPE statements for each of the columns that do not have an IN list - auto columns = TransformPivotList(*pivot->columns); - auto pivot_idx = PivotEntryCount(); - for (idx_t c = 0; c < columns.size(); c++) { - auto &col = columns[c]; - if (!col.pivot_enum.empty() || !col.entries.empty()) { - continue; - } - if (col.pivot_expressions.size() != 1) { - throw InternalException("PIVOT statement with multiple names in pivot entry!?"); - } - auto enum_name = "__pivot_enum_" + std::to_string(pivot_idx) + "_" + std::to_string(c); - - auto new_select = make_uniq(); - 
ExtractCTEsRecursive(new_select->cte_map); - new_select->from_table = source->Copy(); - AddPivotEntry(enum_name, std::move(new_select), col.pivot_expressions[0]->Copy(), std::move(col.subquery)); - col.pivot_enum = enum_name; - } - - // generate the actual query, including the pivot - select_node->select_list.push_back(make_uniq()); - - auto pivot_ref = make_uniq(); - pivot_ref->source = std::move(source); - if (pivot->unpivots) { - pivot_ref->unpivot_names = TransformStringList(pivot->unpivots); - } else { - if (pivot->aggrs) { - TransformExpressionList(*pivot->aggrs, pivot_ref->aggregates); - } else { - // pivot but no aggregates specified - push a count star - vector> children; - auto function = make_uniq("count_star", std::move(children)); - pivot_ref->aggregates.push_back(std::move(function)); - } - } - if (pivot->groups) { - pivot_ref->groups = TransformStringList(pivot->groups); - } - pivot_ref->pivots = std::move(columns); - select_node->from_table = std::move(pivot_ref); - // transform order by/limit modifiers - TransformModifiers(select, *select_node); - - auto node = Transformer::TransformMaterializedCTE(std::move(select_node), materialized_ctes); - - return node; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformPragma(duckdb_libpgquery::PGPragmaStmt &stmt) { - auto result = make_uniq(); - auto &info = *result->info; - - info.name = stmt.name; - // parse the arguments, if any - if (stmt.args) { - for (auto cell = stmt.args->head; cell != nullptr; cell = cell->next) { - auto node = PGPointerCast(cell->data.ptr_value); - auto expr = TransformExpression(node); - - if (expr->type == ExpressionType::COMPARE_EQUAL) { - auto &comp = expr->Cast(); - if (comp.left->type != ExpressionType::COLUMN_REF) { - throw ParserException("Named parameter requires a column reference on the LHS"); - } - auto &columnref = comp.left->Cast(); - - Value rhs_value; - if 
(!Transformer::ConstructConstantFromExpression(*comp.right, rhs_value)) { - throw ParserException("Named parameter requires a constant on the RHS"); - } - - info.named_parameters[columnref.GetName()] = rhs_value; - } else if (node->type == duckdb_libpgquery::T_PGAConst) { - auto constant = TransformConstant(*PGPointerCast(node.get())); - info.parameters.push_back((constant->Cast()).value); - } else if (expr->type == ExpressionType::COLUMN_REF) { - auto &colref = expr->Cast(); - if (!colref.IsQualified()) { - info.parameters.emplace_back(colref.GetColumnName()); - } else { - info.parameters.emplace_back(expr->ToString()); - } - } else { - info.parameters.emplace_back(expr->ToString()); - } - } - } - // now parse the pragma type - switch (stmt.kind) { - case duckdb_libpgquery::PG_PRAGMA_TYPE_NOTHING: { - if (!info.parameters.empty() || !info.named_parameters.empty()) { - throw InternalException("PRAGMA statement that is not a call or assignment cannot contain parameters"); - } - break; - case duckdb_libpgquery::PG_PRAGMA_TYPE_ASSIGNMENT: - if (info.parameters.size() != 1) { - throw InternalException("PRAGMA statement with assignment should contain exactly one parameter"); - } - if (!info.named_parameters.empty()) { - throw InternalException("PRAGMA statement with assignment cannot have named parameters"); - } - // SQLite does not distinguish between: - // "PRAGMA table_info='integers'" - // "PRAGMA table_info('integers')" - // for compatibility, any pragmas that match the SQLite ones are parsed as calls - case_insensitive_set_t sqlite_compat_pragmas {"table_info"}; - if (sqlite_compat_pragmas.find(info.name) != sqlite_compat_pragmas.end()) { - break; - } - auto set_statement = make_uniq(info.name, info.parameters[0], SetScope::AUTOMATIC); - return std::move(set_statement); - } - case duckdb_libpgquery::PG_PRAGMA_TYPE_CALL: - break; - default: - throw InternalException("Unknown pragma type"); - } - - return std::move(result); -} - -} // namespace duckdb - - - - - - - 
-namespace duckdb { - -unique_ptr Transformer::TransformPrepare(duckdb_libpgquery::PGPrepareStmt &stmt) { - if (stmt.argtypes && stmt.argtypes->length > 0) { - throw NotImplementedException("Prepared statement argument types are not supported, use CAST"); - } - - auto result = make_uniq(); - result->name = string(stmt.name); - result->statement = TransformStatement(*stmt.query); - SetParamCount(0); - - return result; -} - -static string NotAcceptedExpressionException() { - return "Only scalar parameters, named parameters or NULL supported for EXECUTE"; -} - -unique_ptr Transformer::TransformExecute(duckdb_libpgquery::PGExecuteStmt &stmt) { - auto result = make_uniq(); - result->name = string(stmt.name); - - vector> intermediate_values; - if (stmt.params) { - TransformExpressionList(*stmt.params, intermediate_values); - } - - idx_t param_idx = 0; - for (idx_t i = 0; i < intermediate_values.size(); i++) { - auto &expr = intermediate_values[i]; - if (!expr->IsScalar()) { - throw InvalidInputException(NotAcceptedExpressionException()); - } - if (!expr->alias.empty() && param_idx != 0) { - // Found unnamed parameters mixed with named parameters - throw NotImplementedException("Mixing named parameters and positional parameters is not supported yet"); - } - auto param_name = expr->alias; - if (expr->alias.empty()) { - param_name = std::to_string(param_idx + 1); - if (param_idx != i) { - throw NotImplementedException("Mixing named parameters and positional parameters is not supported yet"); - } - param_idx++; - } - expr->alias.clear(); - result->named_values[param_name] = std::move(expr); - } - intermediate_values.clear(); - return result; -} - -unique_ptr Transformer::TransformDeallocate(duckdb_libpgquery::PGDeallocateStmt &stmt) { - if (!stmt.name) { - throw ParserException("DEALLOCATE requires a name"); - } - - auto result = make_uniq(); - result->info->type = CatalogType::PREPARED_STATEMENT; - result->info->name = string(stmt.name); - return result; -} - -} // 
namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformRename(duckdb_libpgquery::PGRenameStmt &stmt) { - if (!stmt.relation) { - throw NotImplementedException("Altering schemas is not yet supported"); - } - - unique_ptr info; - - AlterEntryData data; - data.if_not_found = TransformOnEntryNotFound(stmt.missing_ok); - data.catalog = stmt.relation->catalogname ? stmt.relation->catalogname : INVALID_CATALOG; - data.schema = stmt.relation->schemaname ? stmt.relation->schemaname : INVALID_SCHEMA; - if (stmt.relation->relname) { - data.name = stmt.relation->relname; - } - // first we check the type of ALTER - switch (stmt.renameType) { - case duckdb_libpgquery::PG_OBJECT_COLUMN: { - // change column name - - // get the old name and the new name - string old_name = stmt.subname; - string new_name = stmt.newname; - info = make_uniq(std::move(data), old_name, new_name); - break; - } - case duckdb_libpgquery::PG_OBJECT_TABLE: { - // change table name - string new_name = stmt.newname; - info = make_uniq(std::move(data), new_name); - break; - } - case duckdb_libpgquery::PG_OBJECT_VIEW: { - // change view name - string new_name = stmt.newname; - info = make_uniq(std::move(data), new_name); - break; - } - case duckdb_libpgquery::PG_OBJECT_DATABASE: - default: - throw NotImplementedException("Schema element not supported yet!"); - } - D_ASSERT(info); - - auto result = make_uniq(); - result->info = std::move(info); - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformSelectNode(duckdb_libpgquery::PGSelectStmt &select) { - if (select.pivot) { - return TransformPivotStatement(select); - } else { - return TransformSelectInternal(select); - } -} - -unique_ptr Transformer::TransformSelect(duckdb_libpgquery::PGSelectStmt &select, bool is_select) { - auto result = make_uniq(); - - // Both Insert/Create Table As uses this. 
- if (is_select) { - if (select.intoClause) { - throw ParserException("SELECT INTO not supported!"); - } - if (select.lockingClause) { - throw ParserException("SELECT locking clause is not supported!"); - } - } - - result->node = TransformSelectNode(select); - return result; -} - -unique_ptr Transformer::TransformSelect(optional_ptr node, bool is_select) { - return TransformSelect(PGCast(*node), is_select); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -void Transformer::TransformModifiers(duckdb_libpgquery::PGSelectStmt &stmt, QueryNode &node) { - // transform the common properties - // both the set operations and the regular select can have an ORDER BY/LIMIT attached to them - vector orders; - TransformOrderBy(stmt.sortClause, orders); - if (!orders.empty()) { - auto order_modifier = make_uniq(); - order_modifier->orders = std::move(orders); - node.modifiers.push_back(std::move(order_modifier)); - } - if (stmt.limitCount || stmt.limitOffset) { - if (stmt.limitCount && stmt.limitCount->type == duckdb_libpgquery::T_PGLimitPercent) { - auto limit_percent_modifier = make_uniq(); - auto expr_node = PGPointerCast(stmt.limitCount)->limit_percent; - limit_percent_modifier->limit = TransformExpression(expr_node); - if (stmt.limitOffset) { - limit_percent_modifier->offset = TransformExpression(stmt.limitOffset); - } - node.modifiers.push_back(std::move(limit_percent_modifier)); - } else { - auto limit_modifier = make_uniq(); - if (stmt.limitCount) { - limit_modifier->limit = TransformExpression(stmt.limitCount); - } - if (stmt.limitOffset) { - limit_modifier->offset = TransformExpression(stmt.limitOffset); - } - node.modifiers.push_back(std::move(limit_modifier)); - } - } -} - -unique_ptr Transformer::TransformSelectInternal(duckdb_libpgquery::PGSelectStmt &stmt) { - D_ASSERT(stmt.type == duckdb_libpgquery::T_PGSelectStmt); - auto stack_checker = StackCheck(); - - unique_ptr node; - vector> materialized_ctes; - - switch (stmt.op) { - case 
duckdb_libpgquery::PG_SETOP_NONE: { - node = make_uniq(); - auto &result = node->Cast(); - if (stmt.withClause) { - TransformCTE(*PGPointerCast(stmt.withClause), node->cte_map, - materialized_ctes); - } - if (stmt.windowClause) { - for (auto window_ele = stmt.windowClause->head; window_ele != nullptr; window_ele = window_ele->next) { - auto window_def = PGPointerCast(window_ele->data.ptr_value); - D_ASSERT(window_def); - D_ASSERT(window_def->name); - string window_name(window_def->name); - auto it = window_clauses.find(window_name); - if (it != window_clauses.end()) { - throw ParserException("window \"%s\" is already defined", window_name); - } - window_clauses[window_name] = window_def.get(); - } - } - - // checks distinct clause - if (stmt.distinctClause != nullptr) { - auto modifier = make_uniq(); - // checks distinct on clause - auto target = PGPointerCast(stmt.distinctClause->head->data.ptr_value); - if (target) { - // add the columns defined in the ON clause to the select list - TransformExpressionList(*stmt.distinctClause, modifier->distinct_on_targets); - } - result.modifiers.push_back(std::move(modifier)); - } - - // do this early so the value lists also have a `FROM` - if (stmt.valuesLists) { - // VALUES list, create an ExpressionList - D_ASSERT(!stmt.fromClause); - result.from_table = TransformValuesList(stmt.valuesLists); - result.select_list.push_back(make_uniq()); - } else { - if (!stmt.targetList) { - throw ParserException("SELECT clause without selection list"); - } - // select list - TransformExpressionList(*stmt.targetList, result.select_list); - result.from_table = TransformFrom(stmt.fromClause); - } - - // where - result.where_clause = TransformExpression(stmt.whereClause); - // group by - TransformGroupBy(stmt.groupClause, result); - // having - result.having = TransformExpression(stmt.havingClause); - // qualify - result.qualify = TransformExpression(stmt.qualifyClause); - // sample - result.sample = TransformSampleOptions(stmt.sampleOptions); 
- break; - } - case duckdb_libpgquery::PG_SETOP_UNION: - case duckdb_libpgquery::PG_SETOP_EXCEPT: - case duckdb_libpgquery::PG_SETOP_INTERSECT: - case duckdb_libpgquery::PG_SETOP_UNION_BY_NAME: { - node = make_uniq(); - auto &result = node->Cast(); - if (stmt.withClause) { - TransformCTE(*PGPointerCast(stmt.withClause), node->cte_map, - materialized_ctes); - } - result.left = TransformSelectNode(*stmt.larg); - result.right = TransformSelectNode(*stmt.rarg); - if (!result.left || !result.right) { - throw Exception("Failed to transform setop children."); - } - - bool select_distinct = true; - switch (stmt.op) { - case duckdb_libpgquery::PG_SETOP_UNION: - select_distinct = !stmt.all; - result.setop_type = SetOperationType::UNION; - break; - case duckdb_libpgquery::PG_SETOP_EXCEPT: - result.setop_type = SetOperationType::EXCEPT; - break; - case duckdb_libpgquery::PG_SETOP_INTERSECT: - result.setop_type = SetOperationType::INTERSECT; - break; - case duckdb_libpgquery::PG_SETOP_UNION_BY_NAME: - select_distinct = !stmt.all; - result.setop_type = SetOperationType::UNION_BY_NAME; - break; - default: - throw Exception("Unexpected setop type"); - } - if (select_distinct) { - result.modifiers.push_back(make_uniq()); - } - if (stmt.sampleOptions) { - throw ParserException("SAMPLE clause is only allowed in regular SELECT statements"); - } - break; - } - default: - throw NotImplementedException("Statement type %d not implemented!", stmt.op); - } - - TransformModifiers(stmt, *node); - - // Handle materialized CTEs - node = Transformer::TransformMaterializedCTE(std::move(node), materialized_ctes); - - return node; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -namespace { - -SetScope ToSetScope(duckdb_libpgquery::VariableSetScope pg_scope) { - switch (pg_scope) { - case duckdb_libpgquery::VariableSetScope::VAR_SET_SCOPE_LOCAL: - return SetScope::LOCAL; - case duckdb_libpgquery::VariableSetScope::VAR_SET_SCOPE_SESSION: - return SetScope::SESSION; - case 
duckdb_libpgquery::VariableSetScope::VAR_SET_SCOPE_GLOBAL: - return SetScope::GLOBAL; - case duckdb_libpgquery::VariableSetScope::VAR_SET_SCOPE_DEFAULT: - return SetScope::AUTOMATIC; - default: - throw InternalException("Unexpected pg_scope: %d", pg_scope); - } -} - -SetType ToSetType(duckdb_libpgquery::VariableSetKind pg_kind) { - switch (pg_kind) { - case duckdb_libpgquery::VariableSetKind::VAR_SET_VALUE: - return SetType::SET; - case duckdb_libpgquery::VariableSetKind::VAR_RESET: - return SetType::RESET; - default: - throw NotImplementedException("Can only SET or RESET a variable"); - } -} - -} // namespace - -unique_ptr Transformer::TransformSetVariable(duckdb_libpgquery::PGVariableSetStmt &stmt) { - D_ASSERT(stmt.kind == duckdb_libpgquery::VariableSetKind::VAR_SET_VALUE); - - if (stmt.scope == duckdb_libpgquery::VariableSetScope::VAR_SET_SCOPE_LOCAL) { - throw NotImplementedException("SET LOCAL is not implemented."); - } - - auto name = std::string(stmt.name); - D_ASSERT(!name.empty()); // parser protect us! - if (stmt.args->length != 1) { - throw ParserException("SET needs a single scalar value parameter"); - } - D_ASSERT(stmt.args->head && stmt.args->head->data.ptr_value); - auto const_val = PGPointerCast(stmt.args->head->data.ptr_value); - D_ASSERT(const_val->type == duckdb_libpgquery::T_PGAConst); - - auto value = TransformValue(const_val->val)->value; - return make_uniq(name, value, ToSetScope(stmt.scope)); -} - -unique_ptr Transformer::TransformResetVariable(duckdb_libpgquery::PGVariableSetStmt &stmt) { - D_ASSERT(stmt.kind == duckdb_libpgquery::VariableSetKind::VAR_RESET); - - if (stmt.scope == duckdb_libpgquery::VariableSetScope::VAR_SET_SCOPE_LOCAL) { - throw NotImplementedException("RESET LOCAL is not implemented."); - } - - auto name = std::string(stmt.name); - D_ASSERT(!name.empty()); // parser protect us! 
- - return make_uniq(name, ToSetScope(stmt.scope)); -} - -unique_ptr Transformer::TransformSet(duckdb_libpgquery::PGVariableSetStmt &stmt) { - D_ASSERT(stmt.type == duckdb_libpgquery::T_PGVariableSetStmt); - - SetType set_type = ToSetType(stmt.kind); - - switch (set_type) { - case SetType::SET: - return TransformSetVariable(stmt); - case SetType::RESET: - return TransformResetVariable(stmt); - default: - throw NotImplementedException("Type not implemented for SetType"); - } -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -static void TransformShowName(unique_ptr &result, const string &name) { - auto &info = *result->info; - auto lname = StringUtil::Lower(name); - - if (lname == "\"databases\"") { - info.name = "show_databases"; - } else if (lname == "\"tables\"") { - // show all tables - info.name = "show_tables"; - } else if (lname == "__show_tables_expanded") { - info.name = "show_tables_expanded"; - } else { - // show one specific table - info.name = "show"; - info.parameters.emplace_back(name); - } -} - -unique_ptr Transformer::TransformShow(duckdb_libpgquery::PGVariableShowStmt &stmt) { - // we transform SHOW x into PRAGMA SHOW('x') - if (stmt.is_summary) { - auto result = make_uniq(); - auto &info = *result->info; - info.is_summary = stmt.is_summary; - - auto select = make_uniq(); - select->select_list.push_back(make_uniq()); - auto basetable = make_uniq(); - auto qualified_name = QualifiedName::Parse(stmt.name); - basetable->schema_name = qualified_name.schema; - basetable->table_name = qualified_name.name; - select->from_table = std::move(basetable); - - info.query = std::move(select); - return std::move(result); - } - - auto result = make_uniq(); - - auto show_name = stmt.name; - TransformShowName(result, show_name); - return std::move(result); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformShowSelect(duckdb_libpgquery::PGVariableShowSelectStmt &stmt) { - // we capture the select statement of 
SHOW - auto select_stmt = PGPointerCast(stmt.stmt); - - auto result = make_uniq(); - auto &info = *result->info; - info.is_summary = stmt.is_summary; - - info.query = TransformSelectNode(*select_stmt); - - return result; -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformTransaction(duckdb_libpgquery::PGTransactionStmt &stmt) { - switch (stmt.kind) { - case duckdb_libpgquery::PG_TRANS_STMT_BEGIN: - case duckdb_libpgquery::PG_TRANS_STMT_START: - return make_uniq(TransactionType::BEGIN_TRANSACTION); - case duckdb_libpgquery::PG_TRANS_STMT_COMMIT: - return make_uniq(TransactionType::COMMIT); - case duckdb_libpgquery::PG_TRANS_STMT_ROLLBACK: - return make_uniq(TransactionType::ROLLBACK); - default: - throw NotImplementedException("Transaction type %d not implemented yet", stmt.kind); - } -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformUpdateSetInfo(duckdb_libpgquery::PGList *target_list, - duckdb_libpgquery::PGNode *where_clause) { - auto result = make_uniq(); - - auto root = target_list; - for (auto cell = root->head; cell != nullptr; cell = cell->next) { - auto target = PGPointerCast(cell->data.ptr_value); - result->columns.emplace_back(target->name); - result->expressions.push_back(TransformExpression(target->val)); - } - result->condition = TransformExpression(where_clause); - return result; -} - -unique_ptr Transformer::TransformUpdate(duckdb_libpgquery::PGUpdateStmt &stmt) { - auto result = make_uniq(); - vector> materialized_ctes; - if (stmt.withClause) { - TransformCTE(*PGPointerCast(stmt.withClause), result->cte_map, - materialized_ctes); - if (!materialized_ctes.empty()) { - throw NotImplementedException("Materialized CTEs are not implemented for update."); - } - } - - result->table = TransformRangeVar(*stmt.relation); - if (stmt.fromClause) { - result->from_table = TransformFrom(stmt.fromClause); - } - - result->set_info = TransformUpdateSetInfo(stmt.targetList, 
stmt.whereClause); - - // Grab and transform the returning columns from the parser. - if (stmt.returningList) { - TransformExpressionList(*stmt.returningList, result->returning_list); - } - - return result; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -OnConflictAction TransformOnConflictAction(duckdb_libpgquery::PGOnConflictClause *on_conflict) { - if (!on_conflict) { - return OnConflictAction::THROW; - } - switch (on_conflict->action) { - case duckdb_libpgquery::PG_ONCONFLICT_NONE: - return OnConflictAction::THROW; - case duckdb_libpgquery::PG_ONCONFLICT_NOTHING: - return OnConflictAction::NOTHING; - case duckdb_libpgquery::PG_ONCONFLICT_UPDATE: - return OnConflictAction::UPDATE; - default: - throw InternalException("Type not implemented for OnConflictAction"); - } -} - -vector Transformer::TransformConflictTarget(duckdb_libpgquery::PGList &list) { - vector columns; - for (auto cell = list.head; cell != nullptr; cell = cell->next) { - auto index_element = PGPointerCast(cell->data.ptr_value); - if (index_element->collation) { - throw NotImplementedException("Index with collation not supported yet!"); - } - if (index_element->opclass) { - throw NotImplementedException("Index with opclass not supported yet!"); - } - if (!index_element->name) { - throw NotImplementedException("Non-column index element not supported yet!"); - } - if (index_element->nulls_ordering) { - throw NotImplementedException("Index with null_ordering not supported yet!"); - } - if (index_element->ordering) { - throw NotImplementedException("Index with ordering not supported yet!"); - } - columns.emplace_back(index_element->name); - } - return columns; -} - -unique_ptr Transformer::DummyOnConflictClause(duckdb_libpgquery::PGOnConflictActionAlias type, - const string &relname) { - switch (type) { - case duckdb_libpgquery::PGOnConflictActionAlias::PG_ONCONFLICT_ALIAS_REPLACE: { - // This can not be fully resolved yet until the bind stage - auto result = make_uniq(); - 
result->action_type = OnConflictAction::REPLACE; - return result; - } - case duckdb_libpgquery::PGOnConflictActionAlias::PG_ONCONFLICT_ALIAS_IGNORE: { - // We can just fully replace this with DO NOTHING, and be done with it - auto result = make_uniq(); - result->action_type = OnConflictAction::NOTHING; - return result; - } - default: { - throw InternalException("Type not implemented for PGOnConflictActionAlias"); - } - } -} - -unique_ptr Transformer::TransformOnConflictClause(duckdb_libpgquery::PGOnConflictClause *node, - const string &relname) { - auto stmt = reinterpret_cast(node); - D_ASSERT(stmt); - - auto result = make_uniq(); - result->action_type = TransformOnConflictAction(stmt); - if (stmt->infer) { - // A filter for the ON CONFLICT ... is specified - if (stmt->infer->indexElems) { - // Columns are specified - result->indexed_columns = TransformConflictTarget(*stmt->infer->indexElems); - if (stmt->infer->whereClause) { - result->condition = TransformExpression(stmt->infer->whereClause); - } - } else { - throw NotImplementedException("ON CONSTRAINT conflict target is not supported yet"); - } - } - - if (result->action_type == OnConflictAction::UPDATE) { - result->set_info = TransformUpdateSetInfo(stmt->targetList, stmt->whereClause); - } - return result; -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformUse(duckdb_libpgquery::PGUseStmt &stmt) { - auto qualified_name = TransformQualifiedName(*stmt.name); - if (!IsInvalidCatalog(qualified_name.catalog)) { - throw ParserException("Expected \"USE database\" or \"USE database.schema\""); - } - string name; - if (IsInvalidSchema(qualified_name.schema)) { - name = qualified_name.name; - } else { - name = qualified_name.schema + "." 
+ qualified_name.name; - } - return make_uniq("schema", std::move(name), SetScope::AUTOMATIC); -} - -} // namespace duckdb - - - -namespace duckdb { - -VacuumOptions ParseOptions(int options) { - VacuumOptions result; - if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_VACUUM) { - result.vacuum = true; - } - if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_ANALYZE) { - result.analyze = true; - } - if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_VERBOSE) { - throw NotImplementedException("Verbose vacuum option"); - } - if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_FREEZE) { - throw NotImplementedException("Freeze vacuum option"); - } - if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_FULL) { - throw NotImplementedException("Full vacuum option"); - } - if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_NOWAIT) { - throw NotImplementedException("No Wait vacuum option"); - } - if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_SKIPTOAST) { - throw NotImplementedException("Skip Toast vacuum option"); - } - if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_DISABLE_PAGE_SKIPPING) { - throw NotImplementedException("Disable Page Skipping vacuum option"); - } - return result; -} - -unique_ptr Transformer::TransformVacuum(duckdb_libpgquery::PGVacuumStmt &stmt) { - auto result = make_uniq(ParseOptions(stmt.options)); - - if (stmt.relation) { - result->info->ref = TransformRangeVar(*stmt.relation); - result->info->has_table = true; - } - - if (stmt.va_cols) { - D_ASSERT(result->info->has_table); - for (auto col_node = stmt.va_cols->head; col_node != nullptr; col_node = col_node->next) { - result->info->columns.emplace_back( - reinterpret_cast(col_node->data.ptr_value)->val.str); - } - } - return std::move(result); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformRangeVar(duckdb_libpgquery::PGRangeVar &root) { - auto result = make_uniq(); - - result->alias = 
TransformAlias(root.alias, result->column_name_alias); - if (root.relname) { - result->table_name = root.relname; - } - if (root.catalogname) { - result->catalog_name = root.catalogname; - } - if (root.schemaname) { - result->schema_name = root.schemaname; - } - if (root.sample) { - result->sample = TransformSampleOptions(root.sample); - } - result->query_location = root.location; - return std::move(result); -} - -QualifiedName Transformer::TransformQualifiedName(duckdb_libpgquery::PGRangeVar &root) { - QualifiedName qname; - if (root.catalogname) { - qname.catalog = root.catalogname; - } else { - qname.catalog = INVALID_CATALOG; - } - if (root.schemaname) { - qname.schema = root.schemaname; - } else { - qname.schema = INVALID_SCHEMA; - } - if (root.relname) { - qname.name = root.relname; - } else { - qname.name = string(); - } - return qname; -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformFrom(optional_ptr root) { - if (!root) { - return make_uniq(); - } - - if (root->length > 1) { - // Cross Product - auto result = make_uniq(JoinRefType::CROSS); - JoinRef *cur_root = result.get(); - idx_t list_size = 0; - for (auto node = root->head; node != nullptr; node = node->next) { - auto n = PGPointerCast(node->data.ptr_value); - unique_ptr next = TransformTableRefNode(*n); - if (!cur_root->left) { - cur_root->left = std::move(next); - } else if (!cur_root->right) { - cur_root->right = std::move(next); - } else { - auto old_res = std::move(result); - result = make_uniq(JoinRefType::CROSS); - result->left = std::move(old_res); - result->right = std::move(next); - cur_root = result.get(); - } - list_size++; - StackCheck(list_size); - } - return std::move(result); - } - - auto n = PGPointerCast(root->head->data.ptr_value); - return TransformTableRefNode(*n); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr Transformer::TransformJoin(duckdb_libpgquery::PGJoinExpr &root) { - auto result = 
make_uniq(JoinRefType::REGULAR); - switch (root.jointype) { - case duckdb_libpgquery::PG_JOIN_INNER: { - result->type = JoinType::INNER; - break; - } - case duckdb_libpgquery::PG_JOIN_LEFT: { - result->type = JoinType::LEFT; - break; - } - case duckdb_libpgquery::PG_JOIN_FULL: { - result->type = JoinType::OUTER; - break; - } - case duckdb_libpgquery::PG_JOIN_RIGHT: { - result->type = JoinType::RIGHT; - break; - } - case duckdb_libpgquery::PG_JOIN_SEMI: { - result->type = JoinType::SEMI; - break; - } - case duckdb_libpgquery::PG_JOIN_ANTI: { - result->type = JoinType::ANTI; - break; - } - case duckdb_libpgquery::PG_JOIN_POSITION: { - result->ref_type = JoinRefType::POSITIONAL; - break; - } - default: { - throw NotImplementedException("Join type %d not supported\n", root.jointype); - } - } - - // Check the type of left arg and right arg before transform - result->left = TransformTableRefNode(*root.larg); - result->right = TransformTableRefNode(*root.rarg); - switch (root.joinreftype) { - case duckdb_libpgquery::PG_JOIN_NATURAL: - result->ref_type = JoinRefType::NATURAL; - break; - case duckdb_libpgquery::PG_JOIN_ASOF: - result->ref_type = JoinRefType::ASOF; - break; - default: - break; - } - result->query_location = root.location; - - if (root.usingClause && root.usingClause->length > 0) { - // usingClause is a list of strings - for (auto node = root.usingClause->head; node != nullptr; node = node->next) { - auto target = reinterpret_cast(node->data.ptr_value); - D_ASSERT(target->type == duckdb_libpgquery::T_PGString); - auto column_name = string(reinterpret_cast(target)->val.str); - result->using_columns.push_back(column_name); - } - return std::move(result); - } - - if (!root.quals && result->using_columns.empty() && result->ref_type == JoinRefType::REGULAR) { // CROSS PRODUCT - result->ref_type = JoinRefType::CROSS; - } - result->condition = TransformExpression(root.quals); - return std::move(result); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - 
-static void TransformPivotInList(unique_ptr &expr, PivotColumnEntry &entry, bool root_entry = true) { - if (expr->type == ExpressionType::COLUMN_REF) { - auto &colref = expr->Cast(); - if (colref.IsQualified()) { - throw ParserException("PIVOT IN list cannot contain qualified column references"); - } - entry.values.emplace_back(colref.GetColumnName()); - } else if (expr->type == ExpressionType::VALUE_CONSTANT) { - auto &constant_expr = expr->Cast(); - entry.values.push_back(std::move(constant_expr.value)); - } else if (root_entry && expr->type == ExpressionType::FUNCTION) { - auto &function = expr->Cast(); - if (function.function_name != "row") { - throw ParserException("PIVOT IN list must contain columns or lists of columns"); - } - for (auto &child : function.children) { - TransformPivotInList(child, entry, false); - } - } else if (root_entry && expr->type == ExpressionType::STAR) { - entry.star_expr = std::move(expr); - } else { - throw ParserException("PIVOT IN list must contain columns or lists of columns"); - } -} - -PivotColumn Transformer::TransformPivotColumn(duckdb_libpgquery::PGPivot &pivot) { - PivotColumn col; - if (pivot.pivot_columns) { - TransformExpressionList(*pivot.pivot_columns, col.pivot_expressions); - for (auto &expr : col.pivot_expressions) { - if (expr->IsScalar()) { - throw ParserException("Cannot pivot on constant value \"%s\"", expr->ToString()); - } - if (expr->HasSubquery()) { - throw ParserException("Cannot pivot on subquery \"%s\"", expr->ToString()); - } - } - } else if (pivot.unpivot_columns) { - col.unpivot_names = TransformStringList(pivot.unpivot_columns); - } else { - throw InternalException("Either pivot_columns or unpivot_columns must be defined"); - } - if (pivot.pivot_value) { - for (auto node = pivot.pivot_value->head; node != nullptr; node = node->next) { - auto n = PGPointerCast(node->data.ptr_value); - auto expr = TransformExpression(n); - PivotColumnEntry entry; - entry.alias = expr->alias; - 
TransformPivotInList(expr, entry); - col.entries.push_back(std::move(entry)); - } - } - if (pivot.subquery) { - col.subquery = TransformSelectNode(*PGPointerCast(pivot.subquery)); - } - if (pivot.pivot_enum) { - col.pivot_enum = pivot.pivot_enum; - } - return col; -} - -vector Transformer::TransformPivotList(duckdb_libpgquery::PGList &list) { - vector result; - for (auto node = list.head; node != nullptr; node = node->next) { - auto pivot = PGPointerCast(node->data.ptr_value); - result.push_back(TransformPivotColumn(*pivot)); - } - return result; -} - -unique_ptr Transformer::TransformPivot(duckdb_libpgquery::PGPivotExpr &root) { - auto result = make_uniq(); - result->source = TransformTableRefNode(*root.source); - if (root.aggrs) { - TransformExpressionList(*root.aggrs, result->aggregates); - } - if (root.unpivots) { - result->unpivot_names = TransformStringList(root.unpivots); - } - result->pivots = TransformPivotList(*root.pivots); - if (!result->unpivot_names.empty() && result->pivots.size() > 1) { - throw ParserException("UNPIVOT requires a single pivot element"); - } - if (root.groups) { - result->groups = TransformStringList(root.groups); - } - for (auto &pivot : result->pivots) { - idx_t expected_size; - bool is_pivot = result->unpivot_names.empty(); - if (!result->unpivot_names.empty()) { - // unpivot - if (pivot.unpivot_names.size() != 1) { - throw ParserException("UNPIVOT requires a single column name for the PIVOT IN clause"); - } - D_ASSERT(pivot.pivot_expressions.empty()); - expected_size = pivot.entries[0].values.size(); - } else { - // pivot - expected_size = pivot.pivot_expressions.size(); - D_ASSERT(pivot.unpivot_names.empty()); - } - for (auto &entry : pivot.entries) { - if (entry.star_expr && is_pivot) { - throw ParserException("PIVOT IN list must contain columns or lists of columns - star expressions are " - "only supported for UNPIVOT"); - } - if (entry.values.size() != expected_size) { - throw ParserException("PIVOT IN list - inconsistent 
amount of rows - expected %d but got %d", - expected_size, entry.values.size()); - } - } - } - result->include_nulls = root.include_nulls; - result->alias = TransformAlias(root.alias, result->column_name_alias); - return std::move(result); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Transformer::TransformRangeSubselect(duckdb_libpgquery::PGRangeSubselect &root) { - Transformer subquery_transformer(*this); - auto subquery = subquery_transformer.TransformSelect(root.subquery); - if (!subquery) { - return nullptr; - } - auto result = make_uniq(std::move(subquery)); - result->alias = TransformAlias(root.alias, result->column_name_alias); - if (root.sample) { - result->sample = TransformSampleOptions(root.sample); - } - return std::move(result); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformRangeFunction(duckdb_libpgquery::PGRangeFunction &root) { - if (root.ordinality) { - throw NotImplementedException("WITH ORDINALITY not implemented"); - } - if (root.is_rowsfrom) { - throw NotImplementedException("ROWS FROM() not implemented"); - } - if (root.functions->length != 1) { - throw NotImplementedException("Need exactly one function"); - } - auto function_sublist = PGPointerCast(root.functions->head->data.ptr_value); - D_ASSERT(function_sublist->length == 2); - auto call_tree = PGPointerCast(function_sublist->head->data.ptr_value); - auto coldef = function_sublist->head->next->data.ptr_value; - - if (coldef) { - throw NotImplementedException("Explicit column definition not supported yet"); - } - // transform the function call - auto result = make_uniq(); - switch (call_tree->type) { - case duckdb_libpgquery::T_PGFuncCall: { - auto func_call = PGPointerCast(call_tree.get()); - result->function = TransformFuncCall(*func_call); - result->query_location = func_call->location; - break; - } - case duckdb_libpgquery::T_PGSQLValueFunction: - result->function = - 
TransformSQLValueFunction(*PGPointerCast(call_tree.get())); - break; - default: - throw ParserException("Not a function call or value function"); - } - result->alias = TransformAlias(root.alias, result->column_name_alias); - if (root.sample) { - result->sample = TransformSampleOptions(root.sample); - } - return std::move(result); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Transformer::TransformTableRefNode(duckdb_libpgquery::PGNode &n) { - auto stack_checker = StackCheck(); - - switch (n.type) { - case duckdb_libpgquery::T_PGRangeVar: - return TransformRangeVar(PGCast(n)); - case duckdb_libpgquery::T_PGJoinExpr: - return TransformJoin(PGCast(n)); - case duckdb_libpgquery::T_PGRangeSubselect: - return TransformRangeSubselect(PGCast(n)); - case duckdb_libpgquery::T_PGRangeFunction: - return TransformRangeFunction(PGCast(n)); - case duckdb_libpgquery::T_PGPivotExpr: - return TransformPivot(PGCast(n)); - default: - throw NotImplementedException("From Type %d not supported", n.type); - } -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -Transformer::Transformer(ParserOptions &options) - : parent(nullptr), options(options), stack_depth(DConstants::INVALID_INDEX) { -} - -Transformer::Transformer(Transformer &parent) - : parent(&parent), options(parent.options), stack_depth(DConstants::INVALID_INDEX) { -} - -Transformer::~Transformer() { -} - -void Transformer::Clear() { - SetParamCount(0); - pivot_entries.clear(); -} - -bool Transformer::TransformParseTree(duckdb_libpgquery::PGList *tree, vector> &statements) { - InitializeStackCheck(); - for (auto entry = tree->head; entry != nullptr; entry = entry->next) { - Clear(); - auto n = PGPointerCast(entry->data.ptr_value); - auto stmt = TransformStatement(*n); - D_ASSERT(stmt); - if (HasPivotEntries()) { - stmt = CreatePivotStatement(std::move(stmt)); - } - stmt->n_param = ParamCount(); - statements.push_back(std::move(stmt)); - } - return true; -} - -void 
Transformer::InitializeStackCheck() { - stack_depth = 0; -} - -StackChecker Transformer::StackCheck(idx_t extra_stack) { - auto &root = RootTransformer(); - D_ASSERT(root.stack_depth != DConstants::INVALID_INDEX); - if (root.stack_depth + extra_stack >= options.max_expression_depth) { - throw ParserException("Max expression depth limit of %lld exceeded. Use \"SET max_expression_depth TO x\" to " - "increase the maximum expression depth.", - options.max_expression_depth); - } - return StackChecker(root, extra_stack); -} - -unique_ptr Transformer::TransformStatement(duckdb_libpgquery::PGNode &stmt) { - auto result = TransformStatementInternal(stmt); - result->n_param = ParamCount(); - if (!named_param_map.empty()) { - // Avoid overriding a previous move with nothing - result->named_param_map = std::move(named_param_map); - } - return result; -} - -Transformer &Transformer::RootTransformer() { - reference node = *this; - while (node.get().parent) { - node = *node.get().parent; - } - return node.get(); -} - -const Transformer &Transformer::RootTransformer() const { - reference node = *this; - while (node.get().parent) { - node = *node.get().parent; - } - return node.get(); -} - -idx_t Transformer::ParamCount() const { - auto &root = RootTransformer(); - return root.prepared_statement_parameter_index; -} - -void Transformer::SetParamCount(idx_t new_count) { - auto &root = RootTransformer(); - root.prepared_statement_parameter_index = new_count; -} - -static void ParamTypeCheck(PreparedParamType last_type, PreparedParamType new_type) { - // Mixing positional/auto-increment and named parameters is not supported - if (last_type == PreparedParamType::INVALID) { - return; - } - if (last_type == PreparedParamType::NAMED) { - if (new_type != PreparedParamType::NAMED) { - throw NotImplementedException("Mixing named and positional parameters is not supported yet"); - } - } - if (last_type != PreparedParamType::NAMED) { - if (new_type == PreparedParamType::NAMED) { - throw 
NotImplementedException("Mixing named and positional parameters is not supported yet"); - } - } -} - -void Transformer::SetParam(const string &identifier, idx_t index, PreparedParamType type) { - auto &root = RootTransformer(); - ParamTypeCheck(root.last_param_type, type); - root.last_param_type = type; - D_ASSERT(!root.named_param_map.count(identifier)); - root.named_param_map[identifier] = index; -} - -bool Transformer::GetParam(const string &identifier, idx_t &index, PreparedParamType type) { - auto &root = RootTransformer(); - ParamTypeCheck(root.last_param_type, type); - auto entry = root.named_param_map.find(identifier); - if (entry == root.named_param_map.end()) { - return false; - } - index = entry->second; - return true; -} - -unique_ptr Transformer::TransformStatementInternal(duckdb_libpgquery::PGNode &stmt) { - switch (stmt.type) { - case duckdb_libpgquery::T_PGRawStmt: { - auto &raw_stmt = PGCast(stmt); - auto result = TransformStatement(*raw_stmt.stmt); - if (result) { - result->stmt_location = raw_stmt.stmt_location; - result->stmt_length = raw_stmt.stmt_len; - } - return result; - } - case duckdb_libpgquery::T_PGSelectStmt: - return TransformSelect(PGCast(stmt)); - case duckdb_libpgquery::T_PGCreateStmt: - return TransformCreateTable(PGCast(stmt)); - case duckdb_libpgquery::T_PGCreateSchemaStmt: - return TransformCreateSchema(PGCast(stmt)); - case duckdb_libpgquery::T_PGViewStmt: - return TransformCreateView(PGCast(stmt)); - case duckdb_libpgquery::T_PGCreateSeqStmt: - return TransformCreateSequence(PGCast(stmt)); - case duckdb_libpgquery::T_PGCreateFunctionStmt: - return TransformCreateFunction(PGCast(stmt)); - case duckdb_libpgquery::T_PGDropStmt: - return TransformDrop(PGCast(stmt)); - case duckdb_libpgquery::T_PGInsertStmt: - return TransformInsert(PGCast(stmt)); - case duckdb_libpgquery::T_PGCopyStmt: - return TransformCopy(PGCast(stmt)); - case duckdb_libpgquery::T_PGTransactionStmt: - return TransformTransaction(PGCast(stmt)); - case 
duckdb_libpgquery::T_PGDeleteStmt: - return TransformDelete(PGCast(stmt)); - case duckdb_libpgquery::T_PGUpdateStmt: - return TransformUpdate(PGCast(stmt)); - case duckdb_libpgquery::T_PGIndexStmt: - return TransformCreateIndex(PGCast(stmt)); - case duckdb_libpgquery::T_PGAlterTableStmt: - return TransformAlter(PGCast(stmt)); - case duckdb_libpgquery::T_PGRenameStmt: - return TransformRename(PGCast(stmt)); - case duckdb_libpgquery::T_PGPrepareStmt: - return TransformPrepare(PGCast(stmt)); - case duckdb_libpgquery::T_PGExecuteStmt: - return TransformExecute(PGCast(stmt)); - case duckdb_libpgquery::T_PGDeallocateStmt: - return TransformDeallocate(PGCast(stmt)); - case duckdb_libpgquery::T_PGCreateTableAsStmt: - return TransformCreateTableAs(PGCast(stmt)); - case duckdb_libpgquery::T_PGPragmaStmt: - return TransformPragma(PGCast(stmt)); - case duckdb_libpgquery::T_PGExportStmt: - return TransformExport(PGCast(stmt)); - case duckdb_libpgquery::T_PGImportStmt: - return TransformImport(PGCast(stmt)); - case duckdb_libpgquery::T_PGExplainStmt: - return TransformExplain(PGCast(stmt)); - case duckdb_libpgquery::T_PGVacuumStmt: - return TransformVacuum(PGCast(stmt)); - case duckdb_libpgquery::T_PGVariableShowStmt: - return TransformShow(PGCast(stmt)); - case duckdb_libpgquery::T_PGVariableShowSelectStmt: - return TransformShowSelect(PGCast(stmt)); - case duckdb_libpgquery::T_PGCallStmt: - return TransformCall(PGCast(stmt)); - case duckdb_libpgquery::T_PGVariableSetStmt: - return TransformSet(PGCast(stmt)); - case duckdb_libpgquery::T_PGCheckPointStmt: - return TransformCheckpoint(PGCast(stmt)); - case duckdb_libpgquery::T_PGLoadStmt: - return TransformLoad(PGCast(stmt)); - case duckdb_libpgquery::T_PGCreateTypeStmt: - return TransformCreateType(PGCast(stmt)); - case duckdb_libpgquery::T_PGAlterSeqStmt: - return TransformAlterSequence(PGCast(stmt)); - case duckdb_libpgquery::T_PGAttachStmt: - return TransformAttach(PGCast(stmt)); - case duckdb_libpgquery::T_PGDetachStmt: - 
return TransformDetach(PGCast(stmt)); - case duckdb_libpgquery::T_PGUseStmt: - return TransformUse(PGCast(stmt)); - default: - throw NotImplementedException(NodetypeToString(stmt.type)); - } -} - -unique_ptr Transformer::TransformMaterializedCTE(unique_ptr root, - vector> &materialized_ctes) { - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->cte_map = root->cte_map.Copy(); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); - } - - return root; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -string BindContext::GetMatchingBinding(const string &column_name) { - string result; - for (auto &kv : bindings) { - auto binding = kv.second.get(); - auto is_using_binding = GetUsingBinding(column_name, kv.first); - if (is_using_binding) { - continue; - } - if (binding->HasMatchingBinding(column_name)) { - if (!result.empty() || is_using_binding) { - throw BinderException("Ambiguous reference to column name \"%s\" (use: \"%s.%s\" " - "or \"%s.%s\")", - column_name, result, column_name, kv.first, column_name); - } - result = kv.first; - } - } - return result; -} - -vector BindContext::GetSimilarBindings(const string &column_name) { - vector> scores; - for (auto &kv : bindings) { - auto binding = kv.second.get(); - for (auto &name : binding->names) { - idx_t distance = StringUtil::SimilarityScore(name, column_name); - scores.emplace_back(binding->alias + "." 
+ name, distance); - } - } - return StringUtil::TopNStrings(scores); -} - -void BindContext::AddUsingBinding(const string &column_name, UsingColumnSet &set) { - using_columns[column_name].insert(set); -} - -void BindContext::AddUsingBindingSet(unique_ptr set) { - using_column_sets.push_back(std::move(set)); -} - -optional_ptr BindContext::GetUsingBinding(const string &column_name) { - auto entry = using_columns.find(column_name); - if (entry == using_columns.end()) { - return nullptr; - } - auto &using_bindings = entry->second; - if (using_bindings.size() > 1) { - string error = "Ambiguous column reference: column \"" + column_name + "\" can refer to either:\n"; - for (auto &using_set_ref : using_bindings) { - auto &using_set = using_set_ref.get(); - string result_bindings; - for (auto &binding : using_set.bindings) { - if (result_bindings.empty()) { - result_bindings = "["; - } else { - result_bindings += ", "; - } - result_bindings += binding; - result_bindings += "."; - result_bindings += GetActualColumnName(binding, column_name); - } - error += result_bindings + "]"; - } - throw BinderException(error); - } - for (auto &using_set : using_bindings) { - return &using_set.get(); - } - throw InternalException("Using binding found but no entries"); -} - -optional_ptr BindContext::GetUsingBinding(const string &column_name, const string &binding_name) { - if (binding_name.empty()) { - throw InternalException("GetUsingBinding: expected non-empty binding_name"); - } - auto entry = using_columns.find(column_name); - if (entry == using_columns.end()) { - return nullptr; - } - auto &using_bindings = entry->second; - for (auto &using_set_ref : using_bindings) { - auto &using_set = using_set_ref.get(); - auto &bindings = using_set.bindings; - if (bindings.find(binding_name) != bindings.end()) { - return &using_set; - } - } - return nullptr; -} - -void BindContext::RemoveUsingBinding(const string &column_name, UsingColumnSet &set) { - auto entry = 
using_columns.find(column_name); - if (entry == using_columns.end()) { - throw InternalException("Attempting to remove using binding that is not there"); - } - auto &bindings = entry->second; - if (bindings.find(set) != bindings.end()) { - bindings.erase(set); - } - if (bindings.empty()) { - using_columns.erase(column_name); - } -} - -void BindContext::TransferUsingBinding(BindContext ¤t_context, optional_ptr current_set, - UsingColumnSet &new_set, const string &binding, const string &using_column) { - AddUsingBinding(using_column, new_set); - if (current_set) { - current_context.RemoveUsingBinding(using_column, *current_set); - } -} - -string BindContext::GetActualColumnName(const string &binding_name, const string &column_name) { - string error; - auto binding = GetBinding(binding_name, error); - if (!binding) { - throw InternalException("No binding with name \"%s\"", binding_name); - } - column_t binding_index; - if (!binding->TryGetBindingIndex(column_name, binding_index)) { // LCOV_EXCL_START - throw InternalException("Binding with name \"%s\" does not have a column named \"%s\"", binding_name, - column_name); - } // LCOV_EXCL_STOP - return binding->names[binding_index]; -} - -unordered_set BindContext::GetMatchingBindings(const string &column_name) { - unordered_set result; - for (auto &kv : bindings) { - auto binding = kv.second.get(); - if (binding->HasMatchingBinding(column_name)) { - result.insert(kv.first); - } - } - return result; -} - -unique_ptr BindContext::ExpandGeneratedColumn(const string &table_name, const string &column_name) { - string error_message; - - auto binding = GetBinding(table_name, error_message); - D_ASSERT(binding); - auto &table_binding = binding->Cast(); - auto result = table_binding.ExpandGeneratedColumn(column_name); - result->alias = column_name; - return result; -} - -unique_ptr BindContext::CreateColumnReference(const string &table_name, const string &column_name) { - string schema_name; - return 
CreateColumnReference(schema_name, table_name, column_name); -} - -static bool ColumnIsGenerated(Binding &binding, column_t index) { - if (binding.binding_type != BindingType::TABLE) { - return false; - } - auto &table_binding = binding.Cast(); - auto catalog_entry = table_binding.GetStandardEntry(); - if (!catalog_entry) { - return false; - } - if (index == COLUMN_IDENTIFIER_ROW_ID) { - return false; - } - D_ASSERT(catalog_entry->type == CatalogType::TABLE_ENTRY); - auto &table_entry = catalog_entry->Cast(); - return table_entry.GetColumn(LogicalIndex(index)).Generated(); -} - -unique_ptr BindContext::CreateColumnReference(const string &catalog_name, const string &schema_name, - const string &table_name, const string &column_name) { - string error_message; - vector names; - if (!catalog_name.empty()) { - names.push_back(catalog_name); - } - if (!schema_name.empty()) { - names.push_back(schema_name); - } - names.push_back(table_name); - names.push_back(column_name); - - auto result = make_uniq(std::move(names)); - auto binding = GetBinding(table_name, error_message); - if (!binding) { - return std::move(result); - } - auto column_index = binding->GetBindingIndex(column_name); - if (ColumnIsGenerated(*binding, column_index)) { - return ExpandGeneratedColumn(table_name, column_name); - } else if (column_index < binding->names.size() && binding->names[column_index] != column_name) { - // because of case insensitivity in the binder we rename the column to the original name - // as it appears in the binding itself - result->alias = binding->names[column_index]; - } - return std::move(result); -} - -unique_ptr BindContext::CreateColumnReference(const string &schema_name, const string &table_name, - const string &column_name) { - string catalog_name; - return CreateColumnReference(catalog_name, schema_name, table_name, column_name); -} - -optional_ptr BindContext::GetCTEBinding(const string &ctename) { - auto match = cte_bindings.find(ctename); - if (match == 
cte_bindings.end()) { - return nullptr; - } - return match->second.get(); -} - -optional_ptr BindContext::GetBinding(const string &name, string &out_error) { - auto match = bindings.find(name); - if (match == bindings.end()) { - // alias not found in this BindContext - vector candidates; - for (auto &kv : bindings) { - candidates.push_back(kv.first); - } - string candidate_str = - StringUtil::CandidatesMessage(StringUtil::TopNLevenshtein(candidates, name), "Candidate tables"); - out_error = StringUtil::Format("Referenced table \"%s\" not found!%s", name, candidate_str); - return nullptr; - } - return match->second.get(); -} - -BindResult BindContext::BindColumn(ColumnRefExpression &colref, idx_t depth) { - if (!colref.IsQualified()) { - throw InternalException("Could not bind alias \"%s\"!", colref.GetColumnName()); - } - - string error; - auto binding = GetBinding(colref.GetTableName(), error); - if (!binding) { - return BindResult(error); - } - return binding->Bind(colref, depth); -} - -string BindContext::BindColumn(PositionalReferenceExpression &ref, string &table_name, string &column_name) { - idx_t total_columns = 0; - idx_t current_position = ref.index - 1; - for (auto &entry : bindings_list) { - auto &binding = entry.get(); - idx_t entry_column_count = binding.names.size(); - if (ref.index == 0) { - // this is a row id - table_name = binding.alias; - column_name = "rowid"; - return string(); - } - if (current_position < entry_column_count) { - table_name = binding.alias; - column_name = binding.names[current_position]; - return string(); - } else { - total_columns += entry_column_count; - current_position -= entry_column_count; - } - } - return StringUtil::Format("Positional reference %d out of range (total %d columns)", ref.index, total_columns); -} - -unique_ptr BindContext::PositionToColumn(PositionalReferenceExpression &ref) { - string table_name, column_name; - - string error = BindColumn(ref, table_name, column_name); - if (!error.empty()) { - throw 
BinderException(error); - } - return make_uniq(column_name, table_name); -} - -bool BindContext::CheckExclusionList(StarExpression &expr, const string &column_name, - vector> &new_select_list, - case_insensitive_set_t &excluded_columns) { - if (expr.exclude_list.find(column_name) != expr.exclude_list.end()) { - excluded_columns.insert(column_name); - return true; - } - auto entry = expr.replace_list.find(column_name); - if (entry != expr.replace_list.end()) { - auto new_entry = entry->second->Copy(); - new_entry->alias = entry->first; - excluded_columns.insert(entry->first); - new_select_list.push_back(std::move(new_entry)); - return true; - } - return false; -} - -void BindContext::GenerateAllColumnExpressions(StarExpression &expr, - vector> &new_select_list) { - if (bindings_list.empty()) { - throw BinderException("* expression without FROM clause!"); - } - case_insensitive_set_t excluded_columns; - if (expr.relation_name.empty()) { - // SELECT * case - // bind all expressions of each table in-order - reference_set_t handled_using_columns; - for (auto &entry : bindings_list) { - auto &binding = entry.get(); - for (auto &column_name : binding.names) { - if (CheckExclusionList(expr, column_name, new_select_list, excluded_columns)) { - continue; - } - // check if this column is a USING column - auto using_binding_ptr = GetUsingBinding(column_name, binding.alias); - if (using_binding_ptr) { - auto &using_binding = *using_binding_ptr; - // it is! - // check if we have already emitted the using column - if (handled_using_columns.find(using_binding) != handled_using_columns.end()) { - // we have! bail out - continue; - } - // we have not! 
output the using column - if (using_binding.primary_binding.empty()) { - // no primary binding: output a coalesce - auto coalesce = make_uniq(ExpressionType::OPERATOR_COALESCE); - for (auto &child_binding : using_binding.bindings) { - coalesce->children.push_back(make_uniq(column_name, child_binding)); - } - coalesce->alias = column_name; - new_select_list.push_back(std::move(coalesce)); - } else { - // primary binding: output the qualified column ref - new_select_list.push_back( - make_uniq(column_name, using_binding.primary_binding)); - } - handled_using_columns.insert(using_binding); - continue; - } - new_select_list.push_back(make_uniq(column_name, binding.alias)); - } - } - } else { - // SELECT tbl.* case - // SELECT struct.* case - string error; - auto binding = GetBinding(expr.relation_name, error); - bool is_struct_ref = false; - if (!binding) { - auto binding_name = GetMatchingBinding(expr.relation_name); - if (binding_name.empty()) { - throw BinderException(error); - } - binding = bindings[binding_name].get(); - is_struct_ref = true; - } - - if (is_struct_ref) { - auto col_idx = binding->GetBindingIndex(expr.relation_name); - auto col_type = binding->types[col_idx]; - if (col_type.id() != LogicalTypeId::STRUCT) { - throw BinderException(StringUtil::Format( - "Cannot extract field from expression \"%s\" because it is not a struct", expr.ToString())); - } - auto &struct_children = StructType::GetChildTypes(col_type); - vector column_names(3); - column_names[0] = binding->alias; - column_names[1] = expr.relation_name; - for (auto &child : struct_children) { - if (CheckExclusionList(expr, child.first, new_select_list, excluded_columns)) { - continue; - } - column_names[2] = child.first; - new_select_list.push_back(make_uniq(column_names)); - } - } else { - for (auto &column_name : binding->names) { - if (CheckExclusionList(expr, column_name, new_select_list, excluded_columns)) { - continue; - } - - new_select_list.push_back(make_uniq(column_name, 
binding->alias)); - } - } - } - for (auto &excluded : expr.exclude_list) { - if (excluded_columns.find(excluded) == excluded_columns.end()) { - throw BinderException("Column \"%s\" in EXCLUDE list not found in %s", excluded, - expr.relation_name.empty() ? "FROM clause" : expr.relation_name.c_str()); - } - } - for (auto &entry : expr.replace_list) { - if (excluded_columns.find(entry.first) == excluded_columns.end()) { - throw BinderException("Column \"%s\" in REPLACE list not found in %s", entry.first, - expr.relation_name.empty() ? "FROM clause" : expr.relation_name.c_str()); - } - } -} - -void BindContext::GetTypesAndNames(vector &result_names, vector &result_types) { - for (auto &binding_entry : bindings_list) { - auto &binding = binding_entry.get(); - D_ASSERT(binding.names.size() == binding.types.size()); - for (idx_t i = 0; i < binding.names.size(); i++) { - result_names.push_back(binding.names[i]); - result_types.push_back(binding.types[i]); - } - } -} - -void BindContext::AddBinding(const string &alias, unique_ptr binding) { - if (bindings.find(alias) != bindings.end()) { - throw BinderException("Duplicate alias \"%s\" in query!", alias); - } - bindings_list.push_back(*binding); - bindings[alias] = std::move(binding); -} - -void BindContext::AddBaseTable(idx_t index, const string &alias, const vector &names, - const vector &types, vector &bound_column_ids, - StandardEntry *entry, bool add_row_id) { - AddBinding(alias, make_uniq(alias, types, names, bound_column_ids, entry, index, add_row_id)); -} - -void BindContext::AddTableFunction(idx_t index, const string &alias, const vector &names, - const vector &types, vector &bound_column_ids, - StandardEntry *entry) { - AddBinding(alias, make_uniq(alias, types, names, bound_column_ids, entry, index)); -} - -static string AddColumnNameToBinding(const string &base_name, case_insensitive_set_t ¤t_names) { - idx_t index = 1; - string name = base_name; - while (current_names.find(name) != current_names.end()) { - name = 
base_name + ":" + std::to_string(index++); - } - current_names.insert(name); - return name; -} - -vector BindContext::AliasColumnNames(const string &table_name, const vector &names, - const vector &column_aliases) { - vector result; - if (column_aliases.size() > names.size()) { - throw BinderException("table \"%s\" has %lld columns available but %lld columns specified", table_name, - names.size(), column_aliases.size()); - } - case_insensitive_set_t current_names; - // use any provided column aliases first - for (idx_t i = 0; i < column_aliases.size(); i++) { - result.push_back(AddColumnNameToBinding(column_aliases[i], current_names)); - } - // if not enough aliases were provided, use the default names for remaining columns - for (idx_t i = column_aliases.size(); i < names.size(); i++) { - result.push_back(AddColumnNameToBinding(names[i], current_names)); - } - return result; -} - -void BindContext::AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery) { - auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); - AddGenericBinding(index, alias, names, subquery.types); -} - -void BindContext::AddEntryBinding(idx_t index, const string &alias, const vector &names, - const vector &types, StandardEntry &entry) { - AddBinding(alias, make_uniq(alias, types, names, index, entry)); -} - -void BindContext::AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery, - ViewCatalogEntry *view) { - auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); - AddEntryBinding(index, alias, names, subquery.types, view->Cast()); -} - -void BindContext::AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundQueryNode &subquery) { - auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); - AddGenericBinding(index, alias, names, subquery.types); -} - -void BindContext::AddGenericBinding(idx_t index, const string &alias, const vector &names, - 
const vector &types) { - AddBinding(alias, make_uniq(BindingType::BASE, alias, types, names, index)); -} - -void BindContext::AddCTEBinding(idx_t index, const string &alias, const vector &names, - const vector &types) { - auto binding = make_shared(BindingType::BASE, alias, types, names, index); - - if (cte_bindings.find(alias) != cte_bindings.end()) { - throw BinderException("Duplicate alias \"%s\" in query!", alias); - } - cte_bindings[alias] = std::move(binding); - cte_references[alias] = std::make_shared(0); -} - -void BindContext::AddContext(BindContext other) { - for (auto &binding : other.bindings) { - if (bindings.find(binding.first) != bindings.end()) { - throw BinderException("Duplicate alias \"%s\" in query!", binding.first); - } - bindings[binding.first] = std::move(binding.second); - } - for (auto &binding : other.bindings_list) { - bindings_list.push_back(binding); - } - for (auto &entry : other.using_columns) { - for (auto &alias : entry.second) { -#ifdef DEBUG - for (auto &other_alias : using_columns[entry.first]) { - for (auto &col : alias.get().bindings) { - D_ASSERT(other_alias.get().bindings.find(col) == other_alias.get().bindings.end()); - } - } -#endif - using_columns[entry.first].insert(alias); - } - } -} - -void BindContext::RemoveContext(vector> &other_bindings_list) { - for (auto &other_binding : other_bindings_list) { - auto it = std::remove_if(bindings_list.begin(), bindings_list.end(), [other_binding](reference x) { - return x.get().alias == other_binding.get().alias; - }); - bindings_list.erase(it, bindings_list.end()); - } - - for (auto &other_binding : other_bindings_list) { - auto &alias = other_binding.get().alias; - if (bindings.find(alias) != bindings.end()) { - bindings.erase(alias); - } - } -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - -namespace duckdb { - -static Value NegatePercentileValue(const Value &v, const bool desc) { - if (v.IsNull()) { - return v; - } - - const auto frac = v.GetValue(); - if (frac < 0 
|| frac > 1) { - throw BinderException("PERCENTILEs can only take parameters in the range [0, 1]"); - } - - if (!desc) { - return v; - } - - const auto &type = v.type(); - switch (type.id()) { - case LogicalTypeId::DECIMAL: { - // Negate DECIMALs as DECIMAL. - const auto integral = IntegralValue::Get(v); - const auto width = DecimalType::GetWidth(type); - const auto scale = DecimalType::GetScale(type); - switch (type.InternalType()) { - case PhysicalType::INT16: - return Value::DECIMAL(Cast::Operation(-integral), width, scale); - case PhysicalType::INT32: - return Value::DECIMAL(Cast::Operation(-integral), width, scale); - case PhysicalType::INT64: - return Value::DECIMAL(Cast::Operation(-integral), width, scale); - case PhysicalType::INT128: - return Value::DECIMAL(-integral, width, scale); - default: - throw InternalException("Unknown DECIMAL type"); - } - } - default: - // Everything else can just be a DOUBLE - return Value::DOUBLE(-v.GetValue()); - } -} - -static void NegatePercentileFractions(ClientContext &context, unique_ptr &fractions, bool desc) { - D_ASSERT(fractions.get()); - D_ASSERT(fractions->expression_class == ExpressionClass::BOUND_EXPRESSION); - auto &bound = BoundExpression::GetExpression(*fractions); - - if (!bound->IsFoldable()) { - return; - } - - Value value = ExpressionExecutor::EvaluateScalar(context, *bound); - if (value.type().id() == LogicalTypeId::LIST) { - vector values; - for (const auto &element_val : ListValue::GetChildren(value)) { - values.push_back(NegatePercentileValue(element_val, desc)); - } - if (values.empty()) { - throw BinderException("Empty list in percentile not allowed"); - } - bound = make_uniq(Value::LIST(values)); - } else { - bound = make_uniq(NegatePercentileValue(value, desc)); - } -} - -BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFunctionCatalogEntry &func, idx_t depth) { - // first bind the child of the aggregate expression (if any) - this->bound_aggregate = true; - unique_ptr 
bound_filter; - AggregateBinder aggregate_binder(binder, context); - string error; - - // Now we bind the filter (if any) - if (aggr.filter) { - aggregate_binder.BindChild(aggr.filter, 0, error); - } - - // Handle ordered-set aggregates by moving the single ORDER BY expression to the front of the children. - // https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-ORDEREDSET-TABLE - bool ordered_set_agg = false; - bool negate_fractions = false; - if (aggr.order_bys && aggr.order_bys->orders.size() == 1) { - const auto &func_name = aggr.function_name; - ordered_set_agg = (func_name == "quantile_cont" || func_name == "quantile_disc" || - (func_name == "mode" && aggr.children.empty())); - - if (ordered_set_agg) { - auto &config = DBConfig::GetConfig(context); - const auto &order = aggr.order_bys->orders[0]; - const auto sense = - (order.type == OrderType::ORDER_DEFAULT) ? config.options.default_order_type : order.type; - negate_fractions = (sense == OrderType::DESCENDING); - } - } - - for (auto &child : aggr.children) { - aggregate_binder.BindChild(child, 0, error); - // We have to negate the fractions for PERCENTILE_XXXX DESC - if (error.empty() && ordered_set_agg) { - NegatePercentileFractions(context, child, negate_fractions); - } - } - - // Bind the ORDER BYs, if any - if (aggr.order_bys && !aggr.order_bys->orders.empty()) { - for (auto &order : aggr.order_bys->orders) { - aggregate_binder.BindChild(order.expression, 0, error); - } - } - - if (!error.empty()) { - // failed to bind child - if (aggregate_binder.HasBoundColumns()) { - for (idx_t i = 0; i < aggr.children.size(); i++) { - // however, we bound columns! 
- // that means this aggregation belongs to this node - // check if we have to resolve any errors by binding with parent binders - bool success = aggregate_binder.BindCorrelatedColumns(aggr.children[i]); - // if there is still an error after this, we could not successfully bind the aggregate - if (!success) { - throw BinderException(error); - } - auto &bound_expr = BoundExpression::GetExpression(*aggr.children[i]); - ExtractCorrelatedExpressions(binder, *bound_expr); - } - if (aggr.filter) { - bool success = aggregate_binder.BindCorrelatedColumns(aggr.filter); - // if there is still an error after this, we could not successfully bind the aggregate - if (!success) { - throw BinderException(error); - } - auto &bound_expr = BoundExpression::GetExpression(*aggr.filter); - ExtractCorrelatedExpressions(binder, *bound_expr); - } - if (aggr.order_bys && !aggr.order_bys->orders.empty()) { - for (auto &order : aggr.order_bys->orders) { - bool success = aggregate_binder.BindCorrelatedColumns(order.expression); - if (!success) { - throw BinderException(error); - } - auto &bound_expr = BoundExpression::GetExpression(*order.expression); - ExtractCorrelatedExpressions(binder, *bound_expr); - } - } - } else { - // we didn't bind columns, try again in children - return BindResult(error); - } - } else if (depth > 0 && !aggregate_binder.HasBoundColumns()) { - return BindResult("Aggregate with only constant parameters has to be bound in the root subquery"); - } - - if (aggr.filter) { - auto &child = BoundExpression::GetExpression(*aggr.filter); - bound_filter = BoundCastExpression::AddCastToType(context, std::move(child), LogicalType::BOOLEAN); - } - - // all children bound successfully - // extract the children and types - vector types; - vector arguments; - vector> children; - - if (ordered_set_agg) { - const bool order_sensitive = (aggr.function_name == "mode"); - for (auto &order : aggr.order_bys->orders) { - auto &child = BoundExpression::GetExpression(*order.expression); - 
types.push_back(child->return_type); - arguments.push_back(child->return_type); - if (order_sensitive) { - children.push_back(child->Copy()); - } else { - children.push_back(std::move(child)); - } - } - if (!order_sensitive) { - aggr.order_bys->orders.clear(); - } - } - - for (idx_t i = 0; i < aggr.children.size(); i++) { - auto &child = BoundExpression::GetExpression(*aggr.children[i]); - types.push_back(child->return_type); - arguments.push_back(child->return_type); - children.push_back(std::move(child)); - } - - // bind the aggregate - FunctionBinder function_binder(context); - idx_t best_function = function_binder.BindFunction(func.name, func.functions, types, error); - if (best_function == DConstants::INVALID_INDEX) { - throw BinderException(binder.FormatError(aggr, error)); - } - // found a matching function! - auto bound_function = func.functions.GetFunctionByOffset(best_function); - - // Bind any sort columns, unless the aggregate is order-insensitive - unique_ptr order_bys; - if (!aggr.order_bys->orders.empty()) { - order_bys = make_uniq(); - auto &config = DBConfig::GetConfig(context); - for (auto &order : aggr.order_bys->orders) { - auto &order_expr = BoundExpression::GetExpression(*order.expression); - const auto sense = config.ResolveOrder(order.type); - const auto null_order = config.ResolveNullOrder(sense, order.null_order); - order_bys->orders.emplace_back(sense, null_order, std::move(order_expr)); - } - } - - auto aggregate = - function_binder.BindAggregateFunction(bound_function, std::move(children), std::move(bound_filter), - aggr.distinct ? 
AggregateType::DISTINCT : AggregateType::NON_DISTINCT); - if (aggr.export_state) { - aggregate = ExportAggregateFunction::Bind(std::move(aggregate)); - } - aggregate->order_bys = std::move(order_bys); - - // check for all the aggregates if this aggregate already exists - idx_t aggr_index; - auto entry = node.aggregate_map.find(*aggregate); - if (entry == node.aggregate_map.end()) { - // new aggregate: insert into aggregate list - aggr_index = node.aggregates.size(); - node.aggregate_map[*aggregate] = aggr_index; - node.aggregates.push_back(std::move(aggregate)); - } else { - // duplicate aggregate: simplify refer to this aggregate - aggr_index = entry->second; - } - - // now create a column reference referring to the aggregate - auto colref = make_uniq( - aggr.alias.empty() ? node.aggregates[aggr_index]->ToString() : aggr.alias, - node.aggregates[aggr_index]->return_type, ColumnBinding(node.aggregate_index, aggr_index), depth); - // move the aggregate expression into the set of bound aggregates - return BindResult(std::move(colref)); -} -} // namespace duckdb - - - - - - - - -namespace duckdb { - -BindResult ExpressionBinder::BindExpression(BetweenExpression &expr, idx_t depth) { - // first try to bind the children of the case expression - string error; - BindChild(expr.input, depth, error); - BindChild(expr.lower, depth, error); - BindChild(expr.upper, depth, error); - if (!error.empty()) { - return BindResult(error); - } - // the children have been successfully resolved - auto &input = BoundExpression::GetExpression(*expr.input); - auto &lower = BoundExpression::GetExpression(*expr.lower); - auto &upper = BoundExpression::GetExpression(*expr.upper); - - auto input_sql_type = input->return_type; - auto lower_sql_type = lower->return_type; - auto upper_sql_type = upper->return_type; - - // cast the input types to the same type - // now obtain the result type of the input types - auto input_type = BoundComparisonExpression::BindComparison(input_sql_type, 
lower_sql_type); - input_type = BoundComparisonExpression::BindComparison(input_type, upper_sql_type); - // add casts (if necessary) - input = BoundCastExpression::AddCastToType(context, std::move(input), input_type); - lower = BoundCastExpression::AddCastToType(context, std::move(lower), input_type); - upper = BoundCastExpression::AddCastToType(context, std::move(upper), input_type); - if (input_type.id() == LogicalTypeId::VARCHAR) { - // handle collation - auto collation = StringType::GetCollation(input_type); - input = PushCollation(context, std::move(input), collation, false); - lower = PushCollation(context, std::move(lower), collation, false); - upper = PushCollation(context, std::move(upper), collation, false); - } - if (!input->HasSideEffects() && !input->HasParameter() && !input->HasSubquery()) { - // the expression does not have side effects and can be copied: create two comparisons - // the reason we do this is that individual comparisons are easier to handle in optimizers - // if both comparisons remain they will be folded together again into a single BETWEEN in the optimizer - auto left_compare = make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, - input->Copy(), std::move(lower)); - auto right_compare = make_uniq(ExpressionType::COMPARE_LESSTHANOREQUALTO, - std::move(input), std::move(upper)); - return BindResult(make_uniq(ExpressionType::CONJUNCTION_AND, - std::move(left_compare), std::move(right_compare))); - } else { - // expression has side effects: we cannot duplicate it - // create a bound_between directly - return BindResult( - make_uniq(std::move(input), std::move(lower), std::move(upper), true, true)); - } -} - -} // namespace duckdb - - - - - -namespace duckdb { - -BindResult ExpressionBinder::BindExpression(CaseExpression &expr, idx_t depth) { - // first try to bind the children of the case expression - string error; - for (auto &check : expr.case_checks) { - BindChild(check.when_expr, depth, error); - BindChild(check.then_expr, depth, 
error); - } - BindChild(expr.else_expr, depth, error); - if (!error.empty()) { - return BindResult(error); - } - // the children have been successfully resolved - // figure out the result type of the CASE expression - auto &else_expr = BoundExpression::GetExpression(*expr.else_expr); - auto return_type = else_expr->return_type; - for (auto &check : expr.case_checks) { - auto &then_expr = BoundExpression::GetExpression(*check.then_expr); - return_type = LogicalType::MaxLogicalType(return_type, then_expr->return_type); - } - - // bind all the individual components of the CASE statement - auto result = make_uniq(return_type); - for (idx_t i = 0; i < expr.case_checks.size(); i++) { - auto &check = expr.case_checks[i]; - auto &when_expr = BoundExpression::GetExpression(*check.when_expr); - auto &then_expr = BoundExpression::GetExpression(*check.then_expr); - BoundCaseCheck result_check; - result_check.when_expr = - BoundCastExpression::AddCastToType(context, std::move(when_expr), LogicalType::BOOLEAN); - result_check.then_expr = BoundCastExpression::AddCastToType(context, std::move(then_expr), return_type); - result->case_checks.push_back(std::move(result_check)); - } - result->else_expr = BoundCastExpression::AddCastToType(context, std::move(else_expr), return_type); - return BindResult(std::move(result)); -} -} // namespace duckdb - - - - - - -namespace duckdb { - -BindResult ExpressionBinder::BindExpression(CastExpression &expr, idx_t depth) { - // first try to bind the child of the cast expression - string error = Bind(expr.child, depth); - if (!error.empty()) { - return BindResult(error); - } - // FIXME: We can also implement 'hello'::schema.custom_type; and pass by the schema down here. 
- // Right now just considering its DEFAULT_SCHEMA always - Binder::BindLogicalType(context, expr.cast_type); - // the children have been successfully resolved - auto &child = BoundExpression::GetExpression(*expr.child); - if (expr.try_cast) { - if (child->return_type == expr.cast_type) { - // no cast required: type matches - return BindResult(std::move(child)); - } - child = BoundCastExpression::AddCastToType(context, std::move(child), expr.cast_type, true); - } else { - // otherwise add a cast to the target type - child = BoundCastExpression::AddCastToType(context, std::move(child), expr.cast_type); - } - return BindResult(std::move(child)); -} -} // namespace duckdb - - - - -namespace duckdb { - -BindResult ExpressionBinder::BindExpression(CollateExpression &expr, idx_t depth) { - // first try to bind the child of the cast expression - string error = Bind(expr.child, depth); - if (!error.empty()) { - return BindResult(error); - } - auto &child = BoundExpression::GetExpression(*expr.child); - if (child->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (child->return_type.id() != LogicalTypeId::VARCHAR) { - throw BinderException("collations are only supported for type varchar"); - } - // Validate the collation, but don't use it - PushCollation(context, child->Copy(), expr.collation, false); - child->return_type = LogicalType::VARCHAR_COLLATION(expr.collation); - return BindResult(std::move(child)); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - -namespace duckdb { - -string GetSQLValueFunctionName(const string &column_name) { - auto lcase = StringUtil::Lower(column_name); - if (lcase == "current_catalog") { - return "current_catalog"; - } else if (lcase == "current_date") { - return "current_date"; - } else if (lcase == "current_schema") { - return "current_schema"; - } else if (lcase == "current_role") { - return "current_role"; - } else if (lcase == "current_time") { - return "get_current_time"; - } else if (lcase == 
"current_timestamp") { - return "get_current_timestamp"; - } else if (lcase == "current_user") { - return "current_user"; - } else if (lcase == "localtime") { - return "current_localtime"; - } else if (lcase == "localtimestamp") { - return "current_localtimestamp"; - } else if (lcase == "session_user") { - return "session_user"; - } else if (lcase == "user") { - return "user"; - } - return string(); -} - -unique_ptr ExpressionBinder::GetSQLValueFunction(const string &column_name) { - auto value_function = GetSQLValueFunctionName(column_name); - if (value_function.empty()) { - return nullptr; - } - - vector> children; - return make_uniq(value_function, std::move(children)); -} - -unique_ptr ExpressionBinder::QualifyColumnName(const string &column_name, string &error_message) { - auto using_binding = binder.bind_context.GetUsingBinding(column_name); - if (using_binding) { - // we are referencing a USING column - // check if we can refer to one of the base columns directly - unique_ptr expression; - if (!using_binding->primary_binding.empty()) { - // we can! just assign the table name and re-bind - return binder.bind_context.CreateColumnReference(using_binding->primary_binding, column_name); - } else { - // // we cannot! 
we need to bind this as a coalesce between all the relevant columns - auto coalesce = make_uniq(ExpressionType::OPERATOR_COALESCE); - coalesce->children.reserve(using_binding->bindings.size()); - for (auto &entry : using_binding->bindings) { - coalesce->children.push_back(make_uniq(column_name, entry)); - } - return std::move(coalesce); - } - } - - // find a binding that contains this - string table_name = binder.bind_context.GetMatchingBinding(column_name); - - // throw an error if a macro conflicts with a column name - auto is_macro_column = false; - if (binder.macro_binding != nullptr && binder.macro_binding->HasMatchingBinding(column_name)) { - is_macro_column = true; - if (!table_name.empty()) { - throw BinderException("Conflicting column names for column " + column_name + "!"); - } - } - - if (lambda_bindings) { - for (idx_t i = 0; i < lambda_bindings->size(); i++) { - if ((*lambda_bindings)[i].HasMatchingBinding(column_name)) { - - // throw an error if a lambda conflicts with a column name or a macro - if (!table_name.empty() || is_macro_column) { - throw BinderException("Conflicting column names for column " + column_name + "!"); - } - - D_ASSERT(!(*lambda_bindings)[i].alias.empty()); - return make_uniq(column_name, (*lambda_bindings)[i].alias); - } - } - } - - if (is_macro_column) { - D_ASSERT(!binder.macro_binding->alias.empty()); - return make_uniq(column_name, binder.macro_binding->alias); - } - // see if it's a column - if (table_name.empty()) { - // column was not found - check if it is a SQL value function - auto value_function = GetSQLValueFunction(column_name); - if (value_function) { - return value_function; - } - // it's not, find candidates and error - auto similar_bindings = binder.bind_context.GetSimilarBindings(column_name); - string candidate_str = StringUtil::CandidatesMessage(similar_bindings, "Candidate bindings"); - error_message = - StringUtil::Format("Referenced column \"%s\" not found in FROM clause!%s", column_name, candidate_str); - 
return nullptr; - } - return binder.bind_context.CreateColumnReference(table_name, column_name); -} - -void ExpressionBinder::QualifyColumnNames(unique_ptr &expr) { - switch (expr->type) { - case ExpressionType::COLUMN_REF: { - auto &colref = expr->Cast(); - string error_message; - auto new_expr = QualifyColumnName(colref, error_message); - if (new_expr) { - if (!expr->alias.empty()) { - new_expr->alias = expr->alias; - } - new_expr->query_location = colref.query_location; - expr = std::move(new_expr); - } - break; - } - case ExpressionType::POSITIONAL_REFERENCE: { - auto &ref = expr->Cast(); - if (ref.alias.empty()) { - string table_name, column_name; - auto error = binder.bind_context.BindColumn(ref, table_name, column_name); - if (error.empty()) { - ref.alias = column_name; - } - } - break; - } - default: - break; - } - ParsedExpressionIterator::EnumerateChildren( - *expr, [&](unique_ptr &child) { QualifyColumnNames(child); }); -} - -void ExpressionBinder::QualifyColumnNames(Binder &binder, unique_ptr &expr) { - WhereBinder where_binder(binder, binder.context); - where_binder.QualifyColumnNames(expr); -} - -unique_ptr ExpressionBinder::CreateStructExtract(unique_ptr base, - string field_name) { - - // we need to transform the struct extract if it is inside a lambda expression - // because we cannot bind to an existing table, so we remove the dummy table also - if (lambda_bindings && base->type == ExpressionType::COLUMN_REF) { - auto &lambda_column_ref = base->Cast(); - D_ASSERT(!lambda_column_ref.column_names.empty()); - - if (lambda_column_ref.column_names[0].find(DummyBinding::DUMMY_NAME) != string::npos) { - D_ASSERT(lambda_column_ref.column_names.size() == 2); - auto lambda_param_name = lambda_column_ref.column_names.back(); - lambda_column_ref.column_names.clear(); - lambda_column_ref.column_names.push_back(lambda_param_name); - } - } - - vector> children; - children.push_back(std::move(base)); - 
children.push_back(make_uniq_base(Value(std::move(field_name)))); - auto extract_fun = make_uniq(ExpressionType::STRUCT_EXTRACT, std::move(children)); - return std::move(extract_fun); -} - -unique_ptr ExpressionBinder::CreateStructPack(ColumnRefExpression &colref) { - D_ASSERT(colref.column_names.size() <= 3); - string error_message; - auto &table_name = colref.column_names.back(); - auto binding = binder.bind_context.GetBinding(table_name, error_message); - if (!binding) { - return nullptr; - } - if (colref.column_names.size() >= 2) { - // "schema_name.table_name" - auto catalog_entry = binding->GetStandardEntry(); - if (!catalog_entry) { - return nullptr; - } - if (catalog_entry->name != table_name) { - return nullptr; - } - if (colref.column_names.size() == 2) { - auto &qualifier = colref.column_names[0]; - if (catalog_entry->catalog.GetName() != qualifier && catalog_entry->schema.name != qualifier) { - return nullptr; - } - } else if (colref.column_names.size() == 3) { - auto &catalog_name = colref.column_names[0]; - auto &schema_name = colref.column_names[1]; - if (catalog_entry->catalog.GetName() != catalog_name || catalog_entry->schema.name != schema_name) { - return nullptr; - } - } else { - throw InternalException("Expected 2 or 3 column names for CreateStructPack"); - } - } - // We found the table, now create the struct_pack expression - vector> child_expressions; - child_expressions.reserve(binding->names.size()); - for (const auto &column_name : binding->names) { - child_expressions.push_back(make_uniq(column_name, table_name)); - } - return make_uniq("struct_pack", std::move(child_expressions)); -} - -unique_ptr ExpressionBinder::QualifyColumnName(ColumnRefExpression &colref, string &error_message) { - idx_t column_parts = colref.column_names.size(); - // column names can have an arbitrary amount of dots - // here is how the resolution works: - if (column_parts == 1) { - // no dots (i.e. 
"part1") - // -> part1 refers to a column - // check if we can qualify the column name with the table name - auto qualified_colref = QualifyColumnName(colref.GetColumnName(), error_message); - if (qualified_colref) { - // we could: return it - return qualified_colref; - } - // we could not! Try creating an implicit struct_pack - return CreateStructPack(colref); - } else if (column_parts == 2) { - // one dot (i.e. "part1.part2") - // EITHER: - // -> part1 is a table, part2 is a column - // -> part1 is a column, part2 is a property of that column (i.e. struct_extract) - - // first check if part1 is a table, and part2 is a standard column - if (binder.HasMatchingBinding(colref.column_names[0], colref.column_names[1], error_message)) { - // it is! return the colref directly - return binder.bind_context.CreateColumnReference(colref.column_names[0], colref.column_names[1]); - } else { - // otherwise check if we can turn this into a struct extract - auto new_colref = make_uniq(colref.column_names[0]); - string other_error; - auto qualified_colref = QualifyColumnName(colref.column_names[0], other_error); - if (qualified_colref) { - // we could: create a struct extract - return CreateStructExtract(std::move(qualified_colref), colref.column_names[1]); - } - // we could not! Try creating an implicit struct_pack - return CreateStructPack(colref); - } - } else { - // two or more dots (i.e. "part1.part2.part3.part4...") - // -> part1 is a catalog, part2 is a schema, part3 is a table, part4 is a column name, part 5 and beyond are - // struct fields - // -> part1 is a catalog, part2 is a table, part3 is a column name, part4 and beyond are struct fields - // -> part1 is a schema, part2 is a table, part3 is a column name, part4 and beyond are struct fields - // -> part1 is a table, part2 is a column name, part3 and beyond are struct fields - // -> part1 is a column, part2 and beyond are struct fields - - // we always prefer the most top-level view - // i.e. 
in case of multiple resolution options, we resolve in order: - // -> 1. resolve "part1" as a catalog - // -> 2. resolve "part1" as a schema - // -> 3. resolve "part1" as a table - // -> 4. resolve "part1" as a column - - unique_ptr result_expr; - idx_t struct_extract_start; - // first check if part1 is a catalog - if (colref.column_names.size() > 3 && - binder.HasMatchingBinding(colref.column_names[0], colref.column_names[1], colref.column_names[2], - colref.column_names[3], error_message)) { - // part1 is a catalog - the column reference is "catalog.schema.table.column" - result_expr = binder.bind_context.CreateColumnReference(colref.column_names[0], colref.column_names[1], - colref.column_names[2], colref.column_names[3]); - struct_extract_start = 4; - } else if (binder.HasMatchingBinding(colref.column_names[0], INVALID_SCHEMA, colref.column_names[1], - colref.column_names[2], error_message)) { - // part1 is a catalog - the column reference is "catalog.table.column" - result_expr = binder.bind_context.CreateColumnReference(colref.column_names[0], INVALID_SCHEMA, - colref.column_names[1], colref.column_names[2]); - struct_extract_start = 3; - } else if (binder.HasMatchingBinding(colref.column_names[0], colref.column_names[1], colref.column_names[2], - error_message)) { - // part1 is a schema - the column reference is "schema.table.column" - // any additional fields are turned into struct_extract calls - result_expr = binder.bind_context.CreateColumnReference(colref.column_names[0], colref.column_names[1], - colref.column_names[2]); - struct_extract_start = 3; - } else if (binder.HasMatchingBinding(colref.column_names[0], colref.column_names[1], error_message)) { - // part1 is a table - // the column reference is "table.column" - // any additional fields are turned into struct_extract calls - result_expr = binder.bind_context.CreateColumnReference(colref.column_names[0], colref.column_names[1]); - struct_extract_start = 2; - } else { - // part1 could be a column - 
string col_error; - result_expr = QualifyColumnName(colref.column_names[0], col_error); - if (!result_expr) { - // it is not! Try creating an implicit struct_pack - return CreateStructPack(colref); - } - // it is! add the struct extract calls - struct_extract_start = 1; - } - for (idx_t i = struct_extract_start; i < colref.column_names.size(); i++) { - result_expr = CreateStructExtract(std::move(result_expr), colref.column_names[i]); - } - return result_expr; - } -} - -BindResult ExpressionBinder::BindExpression(ColumnRefExpression &colref_p, idx_t depth) { - if (binder.GetBindingMode() == BindingMode::EXTRACT_NAMES) { - return BindResult(make_uniq(Value(LogicalType::SQLNULL))); - } - string error_message; - auto expr = QualifyColumnName(colref_p, error_message); - if (!expr) { - return BindResult(binder.FormatError(colref_p, error_message)); - } - expr->query_location = colref_p.query_location; - - // a generated column returns a generated expression, a struct on a column returns a struct extract - if (expr->type != ExpressionType::COLUMN_REF) { - auto alias = expr->alias; - auto result = BindExpression(expr, depth); - if (result.expression) { - result.expression->alias = std::move(alias); - } - return result; - } - - auto &colref = expr->Cast(); - D_ASSERT(colref.IsQualified()); - auto &table_name = colref.GetTableName(); - - // individual column reference - // resolve to either a base table or a subquery expression - // if it was a macro parameter, let macro_binding bind it to the argument - // if it was a lambda parameter, let lambda_bindings bind it to the argument - - BindResult result; - - auto found_lambda_binding = false; - if (lambda_bindings) { - for (idx_t i = 0; i < lambda_bindings->size(); i++) { - if (table_name == (*lambda_bindings)[i].alias) { - result = (*lambda_bindings)[i].Bind(colref, i, depth); - found_lambda_binding = true; - break; - } - } - } - - if (!found_lambda_binding) { - if (binder.macro_binding && table_name == 
binder.macro_binding->alias) { - result = binder.macro_binding->Bind(colref, depth); - } else { - result = binder.bind_context.BindColumn(colref, depth); - } - } - - if (!result.HasError()) { - BoundColumnReferenceInfo ref; - ref.name = colref.column_names.back(); - ref.query_location = colref.query_location; - bound_columns.push_back(std::move(ref)); - } else { - result.error = binder.FormatError(colref_p, result.error); - } - return result; -} - -bool ExpressionBinder::QualifyColumnAlias(const ColumnRefExpression &colref) { - // Only BaseSelectBinder will have a valid col alias map, - // otherwise just return false - return false; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - -namespace duckdb { - -unique_ptr ExpressionBinder::PushCollation(ClientContext &context, unique_ptr source, - const string &collation_p, bool equality_only) { - // replace default collation with system collation - string collation; - if (collation_p.empty()) { - collation = DBConfig::GetConfig(context).options.collation; - } else { - collation = collation_p; - } - collation = StringUtil::Lower(collation); - // bind the collation - if (collation.empty() || collation == "binary" || collation == "c" || collation == "posix") { - // binary collation: just skip - return source; - } - auto &catalog = Catalog::GetSystemCatalog(context); - auto splits = StringUtil::Split(StringUtil::Lower(collation), "."); - vector> entries; - for (auto &collation_argument : splits) { - auto &collation_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, collation_argument); - if (collation_entry.combinable) { - entries.insert(entries.begin(), collation_entry); - } else { - if (!entries.empty() && !entries.back().get().combinable) { - throw BinderException("Cannot combine collation types \"%s\" and \"%s\"", entries.back().get().name, - collation_entry.name); - } - entries.push_back(collation_entry); - } - } - for (auto &entry : entries) { - auto &collation_entry = entry.get(); - if (equality_only && 
collation_entry.not_required_for_equality) { - continue; - } - vector> children; - children.push_back(std::move(source)); - - FunctionBinder function_binder(context); - auto function = function_binder.BindScalarFunction(collation_entry.function, std::move(children)); - source = std::move(function); - } - return source; -} - -void ExpressionBinder::TestCollation(ClientContext &context, const string &collation) { - PushCollation(context, make_uniq(Value("")), collation); -} - -LogicalType BoundComparisonExpression::BindComparison(LogicalType left_type, LogicalType right_type) { - auto result_type = LogicalType::MaxLogicalType(left_type, right_type); - switch (result_type.id()) { - case LogicalTypeId::DECIMAL: { - // result is a decimal: we need the maximum width and the maximum scale over width - vector argument_types = {left_type, right_type}; - uint8_t max_width = 0, max_scale = 0, max_width_over_scale = 0; - for (idx_t i = 0; i < argument_types.size(); i++) { - uint8_t width, scale; - auto can_convert = argument_types[i].GetDecimalProperties(width, scale); - if (!can_convert) { - return result_type; - } - max_width = MaxValue(width, max_width); - max_scale = MaxValue(scale, max_scale); - max_width_over_scale = MaxValue(width - scale, max_width_over_scale); - } - max_width = MaxValue(max_scale + max_width_over_scale, max_width); - if (max_width > Decimal::MAX_WIDTH_DECIMAL) { - // target width does not fit in decimal: truncate the scale (if possible) to try and make it fit - max_width = Decimal::MAX_WIDTH_DECIMAL; - } - return LogicalType::DECIMAL(max_width, max_scale); - } - case LogicalTypeId::VARCHAR: - // for comparison with strings, we prefer to bind to the numeric types - if (left_type.IsNumeric() || left_type.id() == LogicalTypeId::BOOLEAN) { - return left_type; - } else if (right_type.IsNumeric() || right_type.id() == LogicalTypeId::BOOLEAN) { - return right_type; - } else { - // else: check if collations are compatible - auto left_collation = 
StringType::GetCollation(left_type); - auto right_collation = StringType::GetCollation(right_type); - if (!left_collation.empty() && !right_collation.empty() && left_collation != right_collation) { - throw BinderException("Cannot combine types with different collation!"); - } - } - return result_type; - default: - return result_type; - } -} - -BindResult ExpressionBinder::BindExpression(ComparisonExpression &expr, idx_t depth) { - // first try to bind the children of the case expression - string error; - BindChild(expr.left, depth, error); - BindChild(expr.right, depth, error); - if (!error.empty()) { - return BindResult(error); - } - - // the children have been successfully resolved - auto &left = BoundExpression::GetExpression(*expr.left); - auto &right = BoundExpression::GetExpression(*expr.right); - auto left_sql_type = left->return_type; - auto right_sql_type = right->return_type; - // cast the input types to the same type - // now obtain the result type of the input types - auto input_type = BoundComparisonExpression::BindComparison(left_sql_type, right_sql_type); - // add casts (if necessary) - left = BoundCastExpression::AddCastToType(context, std::move(left), input_type, - input_type.id() == LogicalTypeId::ENUM); - right = BoundCastExpression::AddCastToType(context, std::move(right), input_type, - input_type.id() == LogicalTypeId::ENUM); - - if (input_type.id() == LogicalTypeId::VARCHAR) { - // handle collation - auto collation = StringType::GetCollation(input_type); - left = PushCollation(context, std::move(left), collation, expr.type == ExpressionType::COMPARE_EQUAL); - right = PushCollation(context, std::move(right), collation, expr.type == ExpressionType::COMPARE_EQUAL); - } - // now create the bound comparison expression - return BindResult(make_uniq(expr.type, std::move(left), std::move(right))); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -BindResult ExpressionBinder::BindExpression(ConjunctionExpression &expr, idx_t depth) { - // 
first try to bind the children of the case expression - string error; - for (idx_t i = 0; i < expr.children.size(); i++) { - BindChild(expr.children[i], depth, error); - } - if (!error.empty()) { - return BindResult(error); - } - // the children have been successfully resolved - // cast the input types to boolean (if necessary) - // and construct the bound conjunction expression - auto result = make_uniq(expr.type); - for (auto &child_expr : expr.children) { - auto &child = BoundExpression::GetExpression(*child_expr); - result->children.push_back(BoundCastExpression::AddCastToType(context, std::move(child), LogicalType::BOOLEAN)); - } - // now create the bound conjunction expression - return BindResult(std::move(result)); -} - -} // namespace duckdb - - - - -namespace duckdb { - -BindResult ExpressionBinder::BindExpression(ConstantExpression &expr, idx_t depth) { - return BindResult(make_uniq(expr.value)); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - -namespace duckdb { - -BindResult ExpressionBinder::BindExpression(FunctionExpression &function, idx_t depth, - unique_ptr &expr_ptr) { - // lookup the function in the catalog - QueryErrorContext error_context(binder.root_statement, function.query_location); - auto func = Catalog::GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, function.catalog, function.schema, - function.function_name, OnEntryNotFound::RETURN_NULL, error_context); - if (!func) { - // function was not found - check if we this is a table function - auto table_func = - Catalog::GetEntry(context, CatalogType::TABLE_FUNCTION_ENTRY, function.catalog, function.schema, - function.function_name, OnEntryNotFound::RETURN_NULL, error_context); - if (table_func) { - throw BinderException(binder.FormatError( - function, - StringUtil::Format("Function \"%s\" is a table function but it was used as a scalar function. 
This " - "function has to be called in a FROM clause (similar to a table).", - function.function_name))); - } - // not a table function - check if the schema is set - if (!function.schema.empty()) { - // the schema is set - check if we can turn this the schema into a column ref - string error; - unique_ptr colref; - if (function.catalog.empty()) { - colref = make_uniq(function.schema); - } else { - colref = make_uniq(function.schema, function.catalog); - } - auto new_colref = QualifyColumnName(*colref, error); - bool is_col = error.empty() ? true : false; - bool is_col_alias = QualifyColumnAlias(*colref); - - if (is_col || is_col_alias) { - // we can! transform this into a function call on the column - // i.e. "x.lower()" becomes "lower(x)" - function.children.insert(function.children.begin(), std::move(colref)); - function.catalog = INVALID_CATALOG; - function.schema = INVALID_SCHEMA; - } - } - // rebind the function - func = Catalog::GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, function.catalog, function.schema, - function.function_name, OnEntryNotFound::THROW_EXCEPTION, error_context); - } - - if (func->type != CatalogType::AGGREGATE_FUNCTION_ENTRY && - (function.distinct || function.filter || !function.order_bys->orders.empty())) { - throw InvalidInputException("Function \"%s\" is a %s. 
\"DISTINCT\", \"FILTER\", and \"ORDER BY\" are only " - "applicable to aggregate functions.", - function.function_name, CatalogTypeToString(func->type)); - } - - switch (func->type) { - case CatalogType::SCALAR_FUNCTION_ENTRY: { - // scalar function - - // check for lambda parameters, ignore ->> operator (JSON extension) - bool try_bind_lambda = false; - if (function.function_name != "->>") { - for (auto &child : function.children) { - if (child->expression_class == ExpressionClass::LAMBDA) { - try_bind_lambda = true; - } - } - } - - if (try_bind_lambda) { - auto result = BindLambdaFunction(function, func->Cast(), depth); - if (!result.HasError()) { - // Lambda bind successful - return result; - } - } - - // other scalar function - return BindFunction(function, func->Cast(), depth); - } - case CatalogType::MACRO_ENTRY: - // macro function - return BindMacro(function, func->Cast(), depth, expr_ptr); - default: - // aggregate function - return BindAggregate(function, func->Cast(), depth); - } -} - -BindResult ExpressionBinder::BindFunction(FunctionExpression &function, ScalarFunctionCatalogEntry &func, idx_t depth) { - - // bind the children of the function expression - string error; - - // bind of each child - for (idx_t i = 0; i < function.children.size(); i++) { - BindChild(function.children[i], depth, error); - } - - if (!error.empty()) { - return BindResult(error); - } - if (binder.GetBindingMode() == BindingMode::EXTRACT_NAMES) { - return BindResult(make_uniq(Value(LogicalType::SQLNULL))); - } - - // all children bound successfully - // extract the children and types - vector> children; - for (idx_t i = 0; i < function.children.size(); i++) { - auto &child = BoundExpression::GetExpression(*function.children[i]); - children.push_back(std::move(child)); - } - - FunctionBinder function_binder(context); - unique_ptr result = - function_binder.BindScalarFunction(func, std::move(children), error, function.is_operator, &binder); - if (!result) { - throw 
BinderException(binder.FormatError(function, error)); - } - return BindResult(std::move(result)); -} - -BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, ScalarFunctionCatalogEntry &func, - idx_t depth) { - - // bind the children of the function expression - string error; - - if (function.children.size() != 2) { - return BindResult("Invalid function arguments!"); - } - D_ASSERT(function.children[1]->GetExpressionClass() == ExpressionClass::LAMBDA); - - // bind the list parameter - BindChild(function.children[0], depth, error); - if (!error.empty()) { - return BindResult(error); - } - - // get the logical type of the children of the list - auto &list_child = BoundExpression::GetExpression(*function.children[0]); - if (list_child->return_type.id() != LogicalTypeId::LIST && list_child->return_type.id() != LogicalTypeId::SQLNULL && - list_child->return_type.id() != LogicalTypeId::UNKNOWN) { - return BindResult(" Invalid LIST argument to " + function.function_name + "!"); - } - - LogicalType list_child_type = list_child->return_type.id(); - if (list_child->return_type.id() != LogicalTypeId::SQLNULL && - list_child->return_type.id() != LogicalTypeId::UNKNOWN) { - list_child_type = ListType::GetChildType(list_child->return_type); - } - - // bind the lambda parameter - auto &lambda_expr = function.children[1]->Cast(); - BindResult bind_lambda_result = BindExpression(lambda_expr, depth, true, list_child_type); - - if (bind_lambda_result.HasError()) { - error = bind_lambda_result.error; - } else { - // successfully bound: replace the node with a BoundExpression - auto alias = function.children[1]->alias; - bind_lambda_result.expression->alias = alias; - if (!alias.empty()) { - bind_lambda_result.expression->alias = alias; - } - function.children[1] = make_uniq(std::move(bind_lambda_result.expression)); - } - - if (!error.empty()) { - return BindResult(error); - } - if (binder.GetBindingMode() == BindingMode::EXTRACT_NAMES) { - return 
BindResult(make_uniq(Value(LogicalType::SQLNULL))); - } - - // all children bound successfully - // extract the children and types - vector> children; - for (idx_t i = 0; i < function.children.size(); i++) { - auto &child = BoundExpression::GetExpression(*function.children[i]); - children.push_back(std::move(child)); - } - - // capture the (lambda) columns - auto &bound_lambda_expr = children.back()->Cast(); - CaptureLambdaColumns(bound_lambda_expr.captures, list_child_type, bound_lambda_expr.lambda_expr); - - FunctionBinder function_binder(context); - unique_ptr result = - function_binder.BindScalarFunction(func, std::move(children), error, function.is_operator, &binder); - if (!result) { - throw BinderException(binder.FormatError(function, error)); - } - - auto &bound_function_expr = result->Cast(); - D_ASSERT(bound_function_expr.children.size() == 2); - - // remove the lambda expression from the children - auto lambda = std::move(bound_function_expr.children.back()); - bound_function_expr.children.pop_back(); - auto &bound_lambda = lambda->Cast(); - - // push back (in reverse order) any nested lambda parameters so that we can later use them in the lambda expression - // (rhs) - if (lambda_bindings) { - for (idx_t i = lambda_bindings->size(); i > 0; i--) { - - idx_t lambda_index = lambda_bindings->size() - i + 1; - auto &binding = (*lambda_bindings)[i - 1]; - - D_ASSERT(binding.names.size() == 1); - D_ASSERT(binding.types.size() == 1); - - auto bound_lambda_param = - make_uniq(binding.names[0], binding.types[0], lambda_index); - bound_function_expr.children.push_back(std::move(bound_lambda_param)); - } - } - - // push back the captures into the children vector and the correct return types into the bound_function arguments - for (auto &capture : bound_lambda.captures) { - bound_function_expr.children.push_back(std::move(capture)); - } - - return BindResult(std::move(result)); -} - -BindResult ExpressionBinder::BindAggregate(FunctionExpression &expr, 
AggregateFunctionCatalogEntry &function, - idx_t depth) { - return BindResult(binder.FormatError(expr, UnsupportedAggregateMessage())); -} - -BindResult ExpressionBinder::BindUnnest(FunctionExpression &expr, idx_t depth, bool root_expression) { - return BindResult(binder.FormatError(expr, UnsupportedUnnestMessage())); -} - -string ExpressionBinder::UnsupportedAggregateMessage() { - return "Aggregate functions are not supported here"; -} - -string ExpressionBinder::UnsupportedUnnestMessage() { - return "UNNEST not supported here"; -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -BindResult ExpressionBinder::BindExpression(LambdaExpression &expr, idx_t depth, const bool is_lambda, - const LogicalType &list_child_type) { - - if (!is_lambda) { - // this is for binding JSON - auto lhs_expr = expr.lhs->Copy(); - OperatorExpression arrow_expr(ExpressionType::ARROW, std::move(lhs_expr), expr.expr->Copy()); - return BindExpression(arrow_expr, depth); - } - - // binding the lambda expression - D_ASSERT(expr.lhs); - if (expr.lhs->expression_class != ExpressionClass::FUNCTION && - expr.lhs->expression_class != ExpressionClass::COLUMN_REF) { - throw BinderException( - "Invalid parameter list! Parameters must be comma-separated column names, e.g. 
x or (x, y)."); - } - - // move the lambda parameters to the params vector - if (expr.lhs->expression_class == ExpressionClass::COLUMN_REF) { - expr.params.push_back(std::move(expr.lhs)); - } else { - auto &func_expr = expr.lhs->Cast(); - for (idx_t i = 0; i < func_expr.children.size(); i++) { - expr.params.push_back(std::move(func_expr.children[i])); - } - } - D_ASSERT(!expr.params.empty()); - - // create dummy columns for the lambda parameters (lhs) - vector column_types; - vector column_names; - vector params_strings; - - // positional parameters as column references - for (idx_t i = 0; i < expr.params.size(); i++) { - if (expr.params[i]->GetExpressionClass() != ExpressionClass::COLUMN_REF) { - throw BinderException("Parameter must be a column name."); - } - - auto column_ref = expr.params[i]->Cast(); - if (column_ref.IsQualified()) { - throw BinderException("Invalid parameter name '%s': must be unqualified", column_ref.ToString()); - } - - column_types.emplace_back(list_child_type); - column_names.push_back(column_ref.GetColumnName()); - params_strings.push_back(expr.params[i]->ToString()); - } - - // base table alias - auto params_alias = StringUtil::Join(params_strings, ", "); - if (params_strings.size() > 1) { - params_alias = "(" + params_alias + ")"; - } - - // create a lambda binding and push it to the lambda bindings vector - vector local_bindings; - if (!lambda_bindings) { - lambda_bindings = &local_bindings; - } - DummyBinding new_lambda_binding(column_types, column_names, params_alias); - lambda_bindings->push_back(new_lambda_binding); - - // bind the parameter expressions - for (idx_t i = 0; i < expr.params.size(); i++) { - auto result = BindExpression(expr.params[i], depth, false); - if (result.HasError()) { - throw InternalException("Error during lambda binding: %s", result.error); - } - } - - auto result = BindExpression(expr.expr, depth, false); - lambda_bindings->pop_back(); - - // successfully bound a subtree of nested lambdas, set this to 
nullptr in case other parts of the - // query also contain lambdas - if (lambda_bindings->empty()) { - lambda_bindings = nullptr; - } - - if (result.HasError()) { - throw BinderException(result.error); - } - - return BindResult(make_uniq(ExpressionType::LAMBDA, LogicalType::LAMBDA, - std::move(result.expression), params_strings.size())); -} - -void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr &original, - unique_ptr &replacement, - vector> &captures, - LogicalType &list_child_type) { - - // check if the original expression is a lambda parameter - if (original->expression_class == ExpressionClass::BOUND_LAMBDA_REF) { - - // determine if this is the lambda parameter - auto &bound_lambda_ref = original->Cast(); - auto alias = bound_lambda_ref.alias; - - if (lambda_bindings && bound_lambda_ref.lambda_index != lambda_bindings->size()) { - - D_ASSERT(bound_lambda_ref.lambda_index < lambda_bindings->size()); - auto &lambda_binding = (*lambda_bindings)[bound_lambda_ref.lambda_index]; - - D_ASSERT(lambda_binding.names.size() == 1); - D_ASSERT(lambda_binding.types.size() == 1); - // refers to a lambda parameter outside of the current lambda function - replacement = - make_uniq(lambda_binding.names[0], lambda_binding.types[0], - lambda_bindings->size() - bound_lambda_ref.lambda_index + 1); - - } else { - // refers to current lambda parameter - replacement = make_uniq(alias, list_child_type, 0); - } - - } else { - // always at least the current lambda parameter - idx_t index_offset = 1; - if (lambda_bindings) { - index_offset += lambda_bindings->size(); - } - - // this is not a lambda parameter, so we need to create a new argument for the arguments vector - replacement = make_uniq(original->alias, original->return_type, - captures.size() + index_offset + 1); - captures.push_back(std::move(original)); - } -} - -void ExpressionBinder::CaptureLambdaColumns(vector> &captures, LogicalType &list_child_type, - unique_ptr &expr) { - - if (expr->expression_class == 
ExpressionClass::BOUND_SUBQUERY) { - throw InvalidInputException("Subqueries are not supported in lambda expressions!"); - } - - // these expression classes do not have children, transform them - if (expr->expression_class == ExpressionClass::BOUND_CONSTANT || - expr->expression_class == ExpressionClass::BOUND_COLUMN_REF || - expr->expression_class == ExpressionClass::BOUND_PARAMETER || - expr->expression_class == ExpressionClass::BOUND_LAMBDA_REF) { - - // move the expr because we are going to replace it - auto original = std::move(expr); - unique_ptr replacement; - - TransformCapturedLambdaColumn(original, replacement, captures, list_child_type); - - // replace the expression - expr = std::move(replacement); - - } else { - // recursively enumerate the children of the expression - ExpressionIterator::EnumerateChildren( - *expr, [&](unique_ptr &child) { CaptureLambdaColumns(captures, list_child_type, child); }); - } - - expr->Verify(); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -void ExpressionBinder::ReplaceMacroParametersRecursive(unique_ptr &expr) { - switch (expr->GetExpressionClass()) { - case ExpressionClass::COLUMN_REF: { - // if expr is a parameter, replace it with its argument - auto &colref = expr->Cast(); - bool bind_macro_parameter = false; - if (colref.IsQualified()) { - bind_macro_parameter = false; - if (colref.GetTableName().find(DummyBinding::DUMMY_NAME) != string::npos) { - bind_macro_parameter = true; - } - } else { - bind_macro_parameter = macro_binding->HasMatchingBinding(colref.GetColumnName()); - } - if (bind_macro_parameter) { - D_ASSERT(macro_binding->HasMatchingBinding(colref.GetColumnName())); - expr = macro_binding->ParamToArg(colref); - } - return; - } - case ExpressionClass::SUBQUERY: { - // replacing parameters within a subquery is slightly different - auto &sq = (expr->Cast()).subquery; - ParsedExpressionIterator::EnumerateQueryNodeChildren( - *sq->node, [&](unique_ptr &child) { 
ReplaceMacroParametersRecursive(child); }); - break; - } - default: // fall through - break; - } - // unfold child expressions - ParsedExpressionIterator::EnumerateChildren( - *expr, [&](unique_ptr &child) { ReplaceMacroParametersRecursive(child); }); -} - -BindResult ExpressionBinder::BindMacro(FunctionExpression &function, ScalarMacroCatalogEntry ¯o_func, idx_t depth, - unique_ptr &expr) { - // recast function so we can access the scalar member function->expression - auto ¯o_def = macro_func.function->Cast(); - - // validate the arguments and separate positional and default arguments - vector> positionals; - unordered_map> defaults; - - string error = - MacroFunction::ValidateArguments(*macro_func.function, macro_func.name, function, positionals, defaults); - if (!error.empty()) { - throw BinderException(binder.FormatError(*expr, error)); - } - - // create a MacroBinding to bind this macro's parameters to its arguments - vector types; - vector names; - // positional parameters - for (idx_t i = 0; i < macro_def.parameters.size(); i++) { - types.emplace_back(LogicalType::SQLNULL); - auto ¶m = macro_def.parameters[i]->Cast(); - names.push_back(param.GetColumnName()); - } - // default parameters - for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) { - types.emplace_back(LogicalType::SQLNULL); - names.push_back(it->first); - // now push the defaults into the positionals - positionals.push_back(std::move(defaults[it->first])); - } - auto new_macro_binding = make_uniq(types, names, macro_func.name); - new_macro_binding->arguments = &positionals; - macro_binding = new_macro_binding.get(); - - // replace current expression with stored macro expression - expr = macro_def.expression->Copy(); - - // now replace the parameters - ReplaceMacroParametersRecursive(expr); - - // bind the unfolded macro - return BindExpression(expr, depth); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -static LogicalType 
ResolveNotType(OperatorExpression &op, vector> &children) { - // NOT expression, cast child to BOOLEAN - D_ASSERT(children.size() == 1); - children[0] = BoundCastExpression::AddDefaultCastToType(std::move(children[0]), LogicalType::BOOLEAN); - return LogicalType(LogicalTypeId::BOOLEAN); -} - -static LogicalType ResolveInType(OperatorExpression &op, vector> &children) { - if (children.empty()) { - throw InternalException("IN requires at least a single child node"); - } - // get the maximum type from the children - LogicalType max_type = children[0]->return_type; - bool any_varchar = children[0]->return_type == LogicalType::VARCHAR; - bool any_enum = children[0]->return_type.id() == LogicalTypeId::ENUM; - for (idx_t i = 1; i < children.size(); i++) { - max_type = LogicalType::MaxLogicalType(max_type, children[i]->return_type); - if (children[i]->return_type == LogicalType::VARCHAR) { - any_varchar = true; - } - if (children[i]->return_type.id() == LogicalTypeId::ENUM) { - any_enum = true; - } - } - if (any_varchar && any_enum) { - // For the coalesce function, we must be sure we always upcast the parameters to VARCHAR, if there are at least - // one enum and one varchar - max_type = LogicalType::VARCHAR; - } - - // cast all children to the same type - for (idx_t i = 0; i < children.size(); i++) { - children[i] = BoundCastExpression::AddDefaultCastToType(std::move(children[i]), max_type); - } - // (NOT) IN always returns a boolean - return LogicalType::BOOLEAN; -} - -static LogicalType ResolveOperatorType(OperatorExpression &op, vector> &children) { - switch (op.type) { - case ExpressionType::OPERATOR_IS_NULL: - case ExpressionType::OPERATOR_IS_NOT_NULL: - // IS (NOT) NULL always returns a boolean, and does not cast its children - if (!children[0]->return_type.IsValid()) { - throw ParameterNotResolvedException(); - } - return LogicalType::BOOLEAN; - case ExpressionType::COMPARE_IN: - case ExpressionType::COMPARE_NOT_IN: - return ResolveInType(op, children); - case 
ExpressionType::OPERATOR_COALESCE: { - ResolveInType(op, children); - return children[0]->return_type; - } - case ExpressionType::OPERATOR_NOT: - return ResolveNotType(op, children); - default: - throw InternalException("Unrecognized expression type for ResolveOperatorType"); - } -} - -BindResult ExpressionBinder::BindGroupingFunction(OperatorExpression &op, idx_t depth) { - return BindResult("GROUPING function is not supported here"); -} - -BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) { - if (op.type == ExpressionType::GROUPING_FUNCTION) { - return BindGroupingFunction(op, depth); - } - // bind the children of the operator expression - string error; - for (idx_t i = 0; i < op.children.size(); i++) { - BindChild(op.children[i], depth, error); - } - if (!error.empty()) { - return BindResult(error); - } - // all children bound successfully - string function_name; - switch (op.type) { - case ExpressionType::ARRAY_EXTRACT: { - D_ASSERT(op.children[0]->expression_class == ExpressionClass::BOUND_EXPRESSION); - auto &b_exp = BoundExpression::GetExpression(*op.children[0]); - if (b_exp->return_type.id() == LogicalTypeId::MAP) { - function_name = "map_extract"; - } else { - function_name = "array_extract"; - } - break; - } - case ExpressionType::ARRAY_SLICE: - function_name = "array_slice"; - break; - case ExpressionType::STRUCT_EXTRACT: { - D_ASSERT(op.children.size() == 2); - D_ASSERT(op.children[0]->expression_class == ExpressionClass::BOUND_EXPRESSION); - D_ASSERT(op.children[1]->expression_class == ExpressionClass::BOUND_EXPRESSION); - auto &extract_exp = BoundExpression::GetExpression(*op.children[0]); - auto &name_exp = BoundExpression::GetExpression(*op.children[1]); - auto extract_expr_type = extract_exp->return_type.id(); - if (extract_expr_type != LogicalTypeId::STRUCT && extract_expr_type != LogicalTypeId::UNION && - extract_expr_type != LogicalTypeId::SQLNULL) { - return BindResult(StringUtil::Format( - "Cannot extract field 
%s from expression \"%s\" because it is not a struct or a union", - name_exp->ToString(), extract_exp->ToString())); - } - function_name = extract_expr_type == LogicalTypeId::UNION ? "union_extract" : "struct_extract"; - break; - } - case ExpressionType::ARRAY_CONSTRUCTOR: - function_name = "list_value"; - break; - case ExpressionType::ARROW: - function_name = "json_extract"; - break; - default: - break; - } - if (!function_name.empty()) { - auto function = make_uniq_base(function_name, std::move(op.children)); - return BindExpression(function, depth, false); - } - - vector> children; - for (idx_t i = 0; i < op.children.size(); i++) { - D_ASSERT(op.children[i]->expression_class == ExpressionClass::BOUND_EXPRESSION); - children.push_back(std::move(BoundExpression::GetExpression(*op.children[i]))); - } - // now resolve the types - LogicalType result_type = ResolveOperatorType(op, children); - if (op.type == ExpressionType::OPERATOR_COALESCE) { - if (children.empty()) { - throw BinderException("COALESCE needs at least one child"); - } - if (children.size() == 1) { - return BindResult(std::move(children[0])); - } - } - - auto result = make_uniq(op.type, result_type); - for (auto &child : children) { - result->children.push_back(std::move(child)); - } - return BindResult(std::move(result)); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -BindResult ExpressionBinder::BindExpression(ParameterExpression &expr, idx_t depth) { - if (!binder.parameters) { - throw BinderException("Unexpected prepared parameter. This type of statement can't be prepared!"); - } - auto parameter_id = expr.identifier; - - D_ASSERT(binder.parameters); - // Check if a parameter value has already been supplied - auto ¶meter_data = binder.parameters->GetParameterData(); - auto param_data_it = parameter_data.find(parameter_id); - if (param_data_it != parameter_data.end()) { - // it has! 
emit a constant directly - auto &data = param_data_it->second; - auto constant = make_uniq(data.GetValue()); - constant->alias = expr.alias; - constant->return_type = binder.parameters->GetReturnType(parameter_id); - return BindResult(std::move(constant)); - } - - auto bound_parameter = binder.parameters->BindParameterExpression(expr); - return BindResult(std::move(bound_parameter)); -} - -} // namespace duckdb - - - - -namespace duckdb { - -BindResult ExpressionBinder::BindPositionalReference(unique_ptr &expr, idx_t depth, - bool root_expression) { - auto &ref = expr->Cast(); - if (depth != 0) { - throw InternalException("Positional reference expression could not be bound"); - } - // replace the positional reference with a column - auto column = binder.bind_context.PositionToColumn(ref); - expr = std::move(column); - return BindExpression(expr, depth, root_expression); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -string GetColumnsStringValue(ParsedExpression &expr) { - if (expr.type == ExpressionType::COLUMN_REF) { - auto &colref = expr.Cast(); - return colref.GetColumnName(); - } else { - return expr.ToString(); - } -} - -bool Binder::FindStarExpression(unique_ptr &expr, StarExpression **star, bool is_root, - bool in_columns) { - bool has_star = false; - if (expr->GetExpressionClass() == ExpressionClass::STAR) { - auto ¤t_star = expr->Cast(); - if (!current_star.columns) { - if (is_root) { - *star = ¤t_star; - return true; - } - if (!in_columns) { - throw BinderException( - "STAR expression is only allowed as the root element of an expression. 
Use COLUMNS(*) instead."); - } - // star expression inside a COLUMNS - convert to a constant list - if (!current_star.replace_list.empty()) { - throw BinderException( - "STAR expression with REPLACE list is only allowed as the root element of COLUMNS"); - } - vector> star_list; - bind_context.GenerateAllColumnExpressions(current_star, star_list); - - vector values; - values.reserve(star_list.size()); - for (auto &expr : star_list) { - values.emplace_back(GetColumnsStringValue(*expr)); - } - D_ASSERT(!values.empty()); - - expr = make_uniq(Value::LIST(LogicalType::VARCHAR, values)); - return true; - } - if (in_columns) { - throw BinderException("COLUMNS expression is not allowed inside another COLUMNS expression"); - } - in_columns = true; - if (*star) { - // we can have multiple - if (!(*star)->Equals(current_star)) { - throw BinderException( - FormatError(*expr, "Multiple different STAR/COLUMNS in the same expression are not supported")); - } - return true; - } - *star = ¤t_star; - has_star = true; - } - ParsedExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child_expr) { - if (FindStarExpression(child_expr, star, false, in_columns)) { - has_star = true; - } - }); - return has_star; -} - -void Binder::ReplaceStarExpression(unique_ptr &expr, unique_ptr &replacement) { - D_ASSERT(expr); - if (expr->GetExpressionClass() == ExpressionClass::STAR) { - D_ASSERT(replacement); - expr = replacement->Copy(); - return; - } - ParsedExpressionIterator::EnumerateChildren( - *expr, [&](unique_ptr &child_expr) { ReplaceStarExpression(child_expr, replacement); }); -} - -void Binder::ExpandStarExpression(unique_ptr expr, - vector> &new_select_list) { - StarExpression *star = nullptr; - if (!FindStarExpression(expr, &star, true, false)) { - // no star expression: add it as-is - D_ASSERT(!star); - new_select_list.push_back(std::move(expr)); - return; - } - D_ASSERT(star); - vector> star_list; - // we have star expressions! 
expand the list of star expressions - bind_context.GenerateAllColumnExpressions(*star, star_list); - - if (star->expr) { - // COLUMNS with an expression - // two options: - // VARCHAR parameter <- this is a regular expression - // LIST of VARCHAR parameters <- this is a set of columns - TableFunctionBinder binder(*this, context); - auto child = star->expr->Copy(); - auto result = binder.Bind(child); - if (!result->IsFoldable()) { - // cannot resolve parameters here - if (star->expr->HasParameter()) { - throw ParameterNotResolvedException(); - } else { - throw BinderException("Unsupported expression in COLUMNS"); - } - } - auto val = ExpressionExecutor::EvaluateScalar(context, *result); - if (val.type().id() == LogicalTypeId::VARCHAR) { - // regex - if (val.IsNull()) { - throw BinderException("COLUMNS does not support NULL as regex argument"); - } - auto ®ex_str = StringValue::Get(val); - duckdb_re2::RE2 regex(regex_str); - if (!regex.error().empty()) { - auto err = StringUtil::Format("Failed to compile regex \"%s\": %s", regex_str, regex.error()); - throw BinderException(FormatError(*star, err)); - } - vector> new_list; - for (idx_t i = 0; i < star_list.size(); i++) { - auto &colref = star_list[i]->Cast(); - if (!RE2::PartialMatch(colref.GetColumnName(), regex)) { - continue; - } - new_list.push_back(std::move(star_list[i])); - } - if (new_list.empty()) { - auto err = StringUtil::Format("No matching columns found that match regex \"%s\"", regex_str); - throw BinderException(FormatError(*star, err)); - } - star_list = std::move(new_list); - } else if (val.type().id() == LogicalTypeId::LIST && - ListType::GetChildType(val.type()).id() == LogicalTypeId::VARCHAR) { - // list of varchar columns - if (val.IsNull() || ListValue::GetChildren(val).empty()) { - auto err = - StringUtil::Format("Star expression \"%s\" resulted in an empty set of columns", star->ToString()); - throw BinderException(FormatError(*star, err)); - } - auto &children = ListValue::GetChildren(val); - 
vector> new_list; - // scan the list for all selected columns and construct a lookup table - case_insensitive_map_t selected_set; - for (auto &child : children) { - selected_set.insert(make_pair(StringValue::Get(child), false)); - } - // now check the list of all possible expressions and select which ones make it in - for (auto &expr : star_list) { - auto str = GetColumnsStringValue(*expr); - auto entry = selected_set.find(str); - if (entry != selected_set.end()) { - new_list.push_back(std::move(expr)); - entry->second = true; - } - } - // check if all expressions found a match - for (auto &entry : selected_set) { - if (!entry.second) { - throw BinderException("Column \"%s\" was selected but was not found in the FROM clause", - entry.first); - } - } - star_list = std::move(new_list); - } else { - throw BinderException(FormatError( - *star, "COLUMNS expects either a VARCHAR argument (regex) or a LIST of VARCHAR (list of columns)")); - } - } - - // now perform the replacement - for (idx_t i = 0; i < star_list.size(); i++) { - auto new_expr = expr->Copy(); - ReplaceStarExpression(new_expr, star_list[i]); - new_select_list.push_back(std::move(new_expr)); - } -} - -void Binder::ExpandStarExpressions(vector> &select_list, - vector> &new_select_list) { - for (auto &select_element : select_list) { - ExpandStarExpression(std::move(select_element), new_select_list); - } -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -class BoundSubqueryNode : public QueryNode { -public: - static constexpr const QueryNodeType TYPE = QueryNodeType::BOUND_SUBQUERY_NODE; - -public: - BoundSubqueryNode(shared_ptr subquery_binder, unique_ptr bound_node, - unique_ptr subquery) - : QueryNode(QueryNodeType::BOUND_SUBQUERY_NODE), subquery_binder(std::move(subquery_binder)), - bound_node(std::move(bound_node)), subquery(std::move(subquery)) { - } - - shared_ptr subquery_binder; - unique_ptr bound_node; - unique_ptr subquery; - - const vector> &GetSelectList() const override { - throw 
InternalException("Cannot get select list of bound subquery node"); - } - - string ToString() const override { - throw InternalException("Cannot ToString bound subquery node"); - } - unique_ptr Copy() const override { - throw InternalException("Cannot copy bound subquery node"); - } - - void Serialize(Serializer &serializer) const override { - throw InternalException("Cannot serialize bound subquery node"); - } -}; - -BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t depth) { - if (expr.subquery->node->type != QueryNodeType::BOUND_SUBQUERY_NODE) { - D_ASSERT(depth == 0); - // first bind the actual subquery in a new binder - auto subquery_binder = Binder::CreateBinder(context, &binder); - subquery_binder->can_contain_nulls = true; - auto bound_node = subquery_binder->BindNode(*expr.subquery->node); - // check the correlated columns of the subquery for correlated columns with depth > 1 - for (idx_t i = 0; i < subquery_binder->correlated_columns.size(); i++) { - CorrelatedColumnInfo corr = subquery_binder->correlated_columns[i]; - if (corr.depth > 1) { - // depth > 1, the column references the query ABOVE the current one - // add to the set of correlated columns for THIS query - corr.depth -= 1; - binder.AddCorrelatedColumn(corr); - } - } - if (expr.subquery_type != SubqueryType::EXISTS && bound_node->types.size() > 1) { - throw BinderException(binder.FormatError( - expr, StringUtil::Format("Subquery returns %zu columns - expected 1", bound_node->types.size()))); - } - auto prior_subquery = std::move(expr.subquery); - expr.subquery = make_uniq(); - expr.subquery->node = - make_uniq(std::move(subquery_binder), std::move(bound_node), std::move(prior_subquery)); - } - // now bind the child node of the subquery - if (expr.child) { - // first bind the children of the subquery, if any - string error = Bind(expr.child, depth); - if (!error.empty()) { - return BindResult(error); - } - } - // both binding the child and binding the subquery was 
successful - D_ASSERT(expr.subquery->node->type == QueryNodeType::BOUND_SUBQUERY_NODE); - auto &bound_subquery = expr.subquery->node->Cast(); - auto subquery_binder = std::move(bound_subquery.subquery_binder); - auto bound_node = std::move(bound_subquery.bound_node); - LogicalType return_type = - expr.subquery_type == SubqueryType::SCALAR ? bound_node->types[0] : LogicalType(LogicalTypeId::BOOLEAN); - if (return_type.id() == LogicalTypeId::UNKNOWN) { - return_type = LogicalType::SQLNULL; - } - - auto result = make_uniq(return_type); - if (expr.subquery_type == SubqueryType::ANY) { - // ANY comparison - // cast child and subquery child to equivalent types - D_ASSERT(bound_node->types.size() == 1); - auto &child = BoundExpression::GetExpression(*expr.child); - auto compare_type = LogicalType::MaxLogicalType(child->return_type, bound_node->types[0]); - child = BoundCastExpression::AddCastToType(context, std::move(child), compare_type); - result->child_type = bound_node->types[0]; - result->child_target = compare_type; - result->child = std::move(child); - } - result->binder = std::move(subquery_binder); - result->subquery = std::move(bound_node); - result->subquery_type = expr.subquery_type; - result->comparison_type = expr.comparison_type; - - return BindResult(std::move(result)); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - -namespace duckdb { - -unique_ptr CreateBoundStructExtract(ClientContext &context, unique_ptr expr, string key) { - vector> arguments; - arguments.push_back(std::move(expr)); - arguments.push_back(make_uniq(Value(key))); - auto extract_function = StructExtractFun::GetFunction(); - auto bind_info = extract_function.bind(context, extract_function, arguments); - auto return_type = extract_function.return_type; - auto result = make_uniq(return_type, std::move(extract_function), std::move(arguments), - std::move(bind_info)); - result->alias = std::move(key); - return std::move(result); -} - -BindResult 
SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, bool root_expression) { - // bind the children of the function expression - if (depth > 0) { - return BindResult(binder.FormatError(function, "UNNEST() for correlated expressions is not supported yet")); - } - string error; - if (function.children.empty()) { - return BindResult(binder.FormatError(function, "UNNEST() requires a single argument")); - } - idx_t max_depth = 1; - if (function.children.size() != 1) { - bool has_parameter = false; - bool supported_argument = false; - for (idx_t i = 1; i < function.children.size(); i++) { - if (has_parameter) { - return BindResult(binder.FormatError(function, "UNNEST() only supports a single additional argument")); - } - if (function.children[i]->HasParameter()) { - throw ParameterNotAllowedException("Parameter not allowed in unnest parameter"); - } - if (!function.children[i]->IsScalar()) { - break; - } - auto alias = function.children[i]->alias; - BindChild(function.children[i], depth, error); - if (!error.empty()) { - return BindResult(error); - } - auto &const_child = BoundExpression::GetExpression(*function.children[i]); - auto value = ExpressionExecutor::EvaluateScalar(context, *const_child, true); - if (alias == "recursive") { - auto recursive = value.GetValue(); - if (recursive) { - max_depth = NumericLimits::Maximum(); - } - } else if (alias == "max_depth") { - max_depth = value.GetValue(); - if (max_depth == 0) { - throw BinderException("UNNEST cannot have a max depth of 0"); - } - } else if (!alias.empty()) { - throw BinderException("Unsupported parameter \"%s\" for unnest", alias); - } else { - break; - } - has_parameter = true; - supported_argument = true; - } - if (!supported_argument) { - return BindResult(binder.FormatError(function, "UNNEST - unsupported extra argument, unnest only supports " - "recursive := [true/false] or max_depth := #")); - } - } - unnest_level++; - BindChild(function.children[0], depth, error); - if (!error.empty()) { 
- // failed to bind - // try to bind correlated columns manually - if (!BindCorrelatedColumns(function.children[0])) { - return BindResult(error); - } - auto &bound_expr = BoundExpression::GetExpression(*function.children[0]); - ExtractCorrelatedExpressions(binder, *bound_expr); - } - auto &child = BoundExpression::GetExpression(*function.children[0]); - auto &child_type = child->return_type; - unnest_level--; - - if (unnest_level > 0) { - throw BinderException( - "Nested UNNEST calls are not supported - use UNNEST(x, recursive := true) to unnest multiple levels"); - } - - switch (child_type.id()) { - case LogicalTypeId::UNKNOWN: - throw ParameterNotResolvedException(); - case LogicalTypeId::LIST: - case LogicalTypeId::STRUCT: - case LogicalTypeId::SQLNULL: - break; - default: - return BindResult(binder.FormatError(function, "UNNEST() can only be applied to lists, structs and NULL")); - } - - idx_t list_unnests; - idx_t struct_unnests = 0; - - auto unnest_expr = std::move(child); - if (child_type.id() == LogicalTypeId::SQLNULL) { - list_unnests = 1; - } else { - // first do all of the list unnests - auto type = child_type; - list_unnests = 0; - while (type.id() == LogicalTypeId::LIST) { - type = ListType::GetChildType(type); - list_unnests++; - if (list_unnests >= max_depth) { - break; - } - } - // unnest structs all the way afterwards, if there are any - if (type.id() == LogicalTypeId::STRUCT) { - struct_unnests = max_depth - list_unnests; - } - } - if (struct_unnests > 0 && !root_expression) { - return BindResult(binder.FormatError( - function, "UNNEST() on a struct column can only be applied as the root element of a SELECT expression")); - } - // perform all of the list unnests first - auto return_type = child_type; - for (idx_t current_depth = 0; current_depth < list_unnests; current_depth++) { - if (return_type.id() == LogicalTypeId::LIST) { - return_type = ListType::GetChildType(return_type); - } - auto result = make_uniq(return_type); - result->child = 
std::move(unnest_expr); - auto alias = function.alias.empty() ? result->ToString() : function.alias; - - auto current_level = unnest_level + list_unnests - current_depth - 1; - auto entry = node.unnests.find(current_level); - idx_t unnest_table_index; - idx_t unnest_column_index; - if (entry == node.unnests.end()) { - BoundUnnestNode unnest_node; - unnest_node.index = binder.GenerateTableIndex(); - unnest_node.expressions.push_back(std::move(result)); - unnest_table_index = unnest_node.index; - unnest_column_index = 0; - node.unnests.insert(make_pair(current_level, std::move(unnest_node))); - } else { - unnest_table_index = entry->second.index; - unnest_column_index = entry->second.expressions.size(); - entry->second.expressions.push_back(std::move(result)); - } - // now create a column reference referring to the unnest - unnest_expr = make_uniq( - std::move(alias), return_type, ColumnBinding(unnest_table_index, unnest_column_index), depth); - } - // now perform struct unnests, if any - if (struct_unnests > 0) { - vector> struct_expressions; - struct_expressions.push_back(std::move(unnest_expr)); - - for (idx_t i = 0; i < struct_unnests; i++) { - vector> new_expressions; - // check if there are any structs left - bool has_structs = false; - for (auto &expr : struct_expressions) { - if (expr->return_type.id() == LogicalTypeId::STRUCT) { - // struct! 
push a struct_extract - auto &child_types = StructType::GetChildTypes(expr->return_type); - for (auto &entry : child_types) { - new_expressions.push_back(CreateBoundStructExtract(context, expr->Copy(), entry.first)); - } - has_structs = true; - } else { - // not a struct - push as-is - new_expressions.push_back(std::move(expr)); - } - } - struct_expressions = std::move(new_expressions); - if (!has_structs) { - break; - } - } - expanded_expressions = std::move(struct_expressions); - unnest_expr = make_uniq(Value(42)); - } - return BindResult(std::move(unnest_expr)); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - -namespace duckdb { - -static LogicalType ResolveWindowExpressionType(ExpressionType window_type, const vector &child_types) { - - idx_t param_count; - switch (window_type) { - case ExpressionType::WINDOW_RANK: - case ExpressionType::WINDOW_RANK_DENSE: - case ExpressionType::WINDOW_ROW_NUMBER: - case ExpressionType::WINDOW_PERCENT_RANK: - case ExpressionType::WINDOW_CUME_DIST: - param_count = 0; - break; - case ExpressionType::WINDOW_NTILE: - case ExpressionType::WINDOW_FIRST_VALUE: - case ExpressionType::WINDOW_LAST_VALUE: - case ExpressionType::WINDOW_LEAD: - case ExpressionType::WINDOW_LAG: - param_count = 1; - break; - case ExpressionType::WINDOW_NTH_VALUE: - param_count = 2; - break; - default: - throw InternalException("Unrecognized window expression type " + ExpressionTypeToString(window_type)); - } - if (child_types.size() != param_count) { - throw BinderException("%s needs %d parameter%s, got %d", ExpressionTypeToString(window_type), param_count, - param_count == 1 ? 
"" : "s", child_types.size()); - } - switch (window_type) { - case ExpressionType::WINDOW_PERCENT_RANK: - case ExpressionType::WINDOW_CUME_DIST: - return LogicalType(LogicalTypeId::DOUBLE); - case ExpressionType::WINDOW_ROW_NUMBER: - case ExpressionType::WINDOW_RANK: - case ExpressionType::WINDOW_RANK_DENSE: - case ExpressionType::WINDOW_NTILE: - return LogicalType::BIGINT; - case ExpressionType::WINDOW_NTH_VALUE: - case ExpressionType::WINDOW_FIRST_VALUE: - case ExpressionType::WINDOW_LAST_VALUE: - case ExpressionType::WINDOW_LEAD: - case ExpressionType::WINDOW_LAG: - return child_types[0]; - default: - throw InternalException("Unrecognized window expression type " + ExpressionTypeToString(window_type)); - } -} - -static unique_ptr GetExpression(unique_ptr &expr) { - if (!expr) { - return nullptr; - } - D_ASSERT(expr.get()); - D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION); - return std::move(BoundExpression::GetExpression(*expr)); -} - -static unique_ptr CastWindowExpression(unique_ptr &expr, const LogicalType &type) { - if (!expr) { - return nullptr; - } - D_ASSERT(expr.get()); - D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION); - - auto &bound = BoundExpression::GetExpression(*expr); - bound = BoundCastExpression::AddDefaultCastToType(std::move(bound), type); - - return std::move(bound); -} - -static LogicalType BindRangeExpression(ClientContext &context, const string &name, unique_ptr &expr, - unique_ptr &order_expr) { - - vector> children; - - D_ASSERT(order_expr.get()); - D_ASSERT(order_expr->expression_class == ExpressionClass::BOUND_EXPRESSION); - auto &bound_order = BoundExpression::GetExpression(*order_expr); - children.emplace_back(bound_order->Copy()); - - D_ASSERT(expr.get()); - D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION); - auto &bound = BoundExpression::GetExpression(*expr); - children.emplace_back(std::move(bound)); - - string error; - FunctionBinder function_binder(context); - 
auto function = function_binder.BindScalarFunction(DEFAULT_SCHEMA, name, std::move(children), error, true); - if (!function) { - throw BinderException(error); - } - bound = std::move(function); - return bound->return_type; -} - -BindResult BaseSelectBinder::BindWindow(WindowExpression &window, idx_t depth) { - auto name = window.GetName(); - - QueryErrorContext error_context(binder.GetRootStatement(), window.query_location); - if (inside_window) { - throw BinderException(error_context.FormatError("window function calls cannot be nested")); - } - if (depth > 0) { - throw BinderException(error_context.FormatError("correlated columns in window functions not supported")); - } - // If we have range expressions, then only one order by clause is allowed. - if ((window.start == WindowBoundary::EXPR_PRECEDING_RANGE || window.start == WindowBoundary::EXPR_FOLLOWING_RANGE || - window.end == WindowBoundary::EXPR_PRECEDING_RANGE || window.end == WindowBoundary::EXPR_FOLLOWING_RANGE) && - window.orders.size() != 1) { - throw BinderException(error_context.FormatError("RANGE frames must have only one ORDER BY expression")); - } - // bind inside the children of the window function - // we set the inside_window flag to true to prevent binding nested window functions - this->inside_window = true; - string error; - for (auto &child : window.children) { - BindChild(child, depth, error); - } - for (auto &child : window.partitions) { - BindChild(child, depth, error); - } - for (auto &order : window.orders) { - BindChild(order.expression, depth, error); - } - BindChild(window.filter_expr, depth, error); - BindChild(window.start_expr, depth, error); - BindChild(window.end_expr, depth, error); - BindChild(window.offset_expr, depth, error); - BindChild(window.default_expr, depth, error); - - this->inside_window = false; - if (!error.empty()) { - // failed to bind children of window function - return BindResult(error); - } - // successfully bound all children: create bound window function - 
vector types; - vector> children; - for (auto &child : window.children) { - D_ASSERT(child.get()); - D_ASSERT(child->expression_class == ExpressionClass::BOUND_EXPRESSION); - auto &bound = BoundExpression::GetExpression(*child); - // Add casts for positional arguments - const auto argno = children.size(); - switch (window.type) { - case ExpressionType::WINDOW_NTILE: - // ntile(bigint) - if (argno == 0) { - bound = BoundCastExpression::AddCastToType(context, std::move(bound), LogicalType::BIGINT); - } - break; - case ExpressionType::WINDOW_NTH_VALUE: - // nth_value(, index) - if (argno == 1) { - bound = BoundCastExpression::AddCastToType(context, std::move(bound), LogicalType::BIGINT); - } - default: - break; - } - types.push_back(bound->return_type); - children.push_back(std::move(bound)); - } - // Determine the function type. - LogicalType sql_type; - unique_ptr aggregate; - unique_ptr bind_info; - if (window.type == ExpressionType::WINDOW_AGGREGATE) { - // Look up the aggregate function in the catalog - auto &func = Catalog::GetEntry(context, window.catalog, window.schema, - window.function_name, error_context); - D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); - - // bind the aggregate - string error; - FunctionBinder function_binder(context); - auto best_function = function_binder.BindFunction(func.name, func.functions, types, error); - if (best_function == DConstants::INVALID_INDEX) { - throw BinderException(binder.FormatError(window, error)); - } - // found a matching function! 
bind it as an aggregate - auto bound_function = func.functions.GetFunctionByOffset(best_function); - auto bound_aggregate = function_binder.BindAggregateFunction(bound_function, std::move(children)); - // create the aggregate - aggregate = make_uniq(bound_aggregate->function); - bind_info = std::move(bound_aggregate->bind_info); - children = std::move(bound_aggregate->children); - sql_type = bound_aggregate->return_type; - } else { - // fetch the child of the non-aggregate window function (if any) - sql_type = ResolveWindowExpressionType(window.type, types); - } - auto result = make_uniq(window.type, sql_type, std::move(aggregate), std::move(bind_info)); - result->children = std::move(children); - for (auto &child : window.partitions) { - result->partitions.push_back(GetExpression(child)); - } - result->ignore_nulls = window.ignore_nulls; - - // Convert RANGE boundary expressions to ORDER +/- expressions. - // Note that PRECEEDING and FOLLOWING refer to the sequential order in the frame, - // not the natural ordering of the type. This means that the offset arithmetic must be reversed - // for ORDER BY DESC. - auto &config = DBConfig::GetConfig(context); - auto range_sense = OrderType::INVALID; - LogicalType start_type = LogicalType::BIGINT; - if (window.start == WindowBoundary::EXPR_PRECEDING_RANGE) { - D_ASSERT(window.orders.size() == 1); - range_sense = config.ResolveOrder(window.orders[0].type); - const auto name = (range_sense == OrderType::ASCENDING) ? "-" : "+"; - start_type = BindRangeExpression(context, name, window.start_expr, window.orders[0].expression); - } else if (window.start == WindowBoundary::EXPR_FOLLOWING_RANGE) { - D_ASSERT(window.orders.size() == 1); - range_sense = config.ResolveOrder(window.orders[0].type); - const auto name = (range_sense == OrderType::ASCENDING) ? 
"+" : "-"; - start_type = BindRangeExpression(context, name, window.start_expr, window.orders[0].expression); - } - - LogicalType end_type = LogicalType::BIGINT; - if (window.end == WindowBoundary::EXPR_PRECEDING_RANGE) { - D_ASSERT(window.orders.size() == 1); - range_sense = config.ResolveOrder(window.orders[0].type); - const auto name = (range_sense == OrderType::ASCENDING) ? "-" : "+"; - end_type = BindRangeExpression(context, name, window.end_expr, window.orders[0].expression); - } else if (window.end == WindowBoundary::EXPR_FOLLOWING_RANGE) { - D_ASSERT(window.orders.size() == 1); - range_sense = config.ResolveOrder(window.orders[0].type); - const auto name = (range_sense == OrderType::ASCENDING) ? "+" : "-"; - end_type = BindRangeExpression(context, name, window.end_expr, window.orders[0].expression); - } - - // Cast ORDER and boundary expressions to the same type - if (range_sense != OrderType::INVALID) { - D_ASSERT(window.orders.size() == 1); - - auto &order_expr = window.orders[0].expression; - D_ASSERT(order_expr.get()); - D_ASSERT(order_expr->expression_class == ExpressionClass::BOUND_EXPRESSION); - auto &bound_order = BoundExpression::GetExpression(*order_expr); - auto order_type = bound_order->return_type; - if (window.start_expr) { - order_type = LogicalType::MaxLogicalType(order_type, start_type); - } - if (window.end_expr) { - order_type = LogicalType::MaxLogicalType(order_type, end_type); - } - - // Cast all three to match - bound_order = BoundCastExpression::AddCastToType(context, std::move(bound_order), order_type); - start_type = end_type = order_type; - } - - for (auto &order : window.orders) { - auto type = config.ResolveOrder(order.type); - auto null_order = config.ResolveNullOrder(type, order.null_order); - auto expression = GetExpression(order.expression); - result->orders.emplace_back(type, null_order, std::move(expression)); - } - - result->filter_expr = CastWindowExpression(window.filter_expr, LogicalType::BOOLEAN); - - 
result->start_expr = CastWindowExpression(window.start_expr, start_type); - result->end_expr = CastWindowExpression(window.end_expr, end_type); - result->offset_expr = CastWindowExpression(window.offset_expr, LogicalType::BIGINT); - result->default_expr = CastWindowExpression(window.default_expr, result->return_type); - result->start = window.start; - result->end = window.end; - - // create a BoundColumnRef that references this entry - auto colref = make_uniq(std::move(name), result->return_type, - ColumnBinding(node.window_index, node.windows.size()), depth); - // move the WINDOW expression into the set of bound windows - node.windows.push_back(std::move(result)); - return BindResult(std::move(colref)); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -unique_ptr Binder::BindNode(CTENode &statement) { - auto result = make_uniq(); - - // first recursively visit the materialized CTE operations - // the left side is visited first and is added to the BindContext of the right side - D_ASSERT(statement.query); - D_ASSERT(statement.child); - - result->ctename = statement.ctename; - result->setop_index = GenerateTableIndex(); - - result->query_binder = Binder::CreateBinder(context, this); - result->query = result->query_binder->BindNode(*statement.query); - - // the result types of the CTE are the types of the LHS - result->types = result->query->types; - // names are picked from the LHS, unless aliases are explicitly specified - result->names = result->query->names; - for (idx_t i = 0; i < statement.aliases.size() && i < result->names.size(); i++) { - result->names[i] = statement.aliases[i]; - } - - // This allows the right side to reference the CTE - bind_context.AddGenericBinding(result->setop_index, statement.ctename, result->names, result->types); - - result->child_binder = Binder::CreateBinder(context, this); - - // Move all modifiers to the child node. 
- for (auto &modifier : statement.modifiers) { - statement.child->modifiers.push_back(std::move(modifier)); - } - - statement.modifiers.clear(); - - // Add bindings of left side to temporary CTE bindings context - result->child_binder->bind_context.AddCTEBinding(result->setop_index, statement.ctename, result->names, - result->types); - result->child = result->child_binder->BindNode(*statement.child); - - // the result types of the CTE are the types of the LHS - result->types = result->child->types; - // names are picked from the LHS, unless aliases are explicitly specified - result->names = result->child->names; - for (idx_t i = 0; i < statement.aliases.size() && i < result->names.size(); i++) { - result->names[i] = statement.aliases[i]; - } - - MoveCorrelatedExpressions(*result->query_binder); - MoveCorrelatedExpressions(*result->child_binder); - - return std::move(result); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -unique_ptr Binder::BindNode(RecursiveCTENode &statement) { - auto result = make_uniq(); - - // first recursively visit the recursive CTE operations - // the left side is visited first and is added to the BindContext of the right side - D_ASSERT(statement.left); - D_ASSERT(statement.right); - - result->ctename = statement.ctename; - result->union_all = statement.union_all; - result->setop_index = GenerateTableIndex(); - - result->left_binder = Binder::CreateBinder(context, this); - result->left = result->left_binder->BindNode(*statement.left); - - // the result types of the CTE are the types of the LHS - result->types = result->left->types; - // names are picked from the LHS, unless aliases are explicitly specified - result->names = result->left->names; - for (idx_t i = 0; i < statement.aliases.size() && i < result->names.size(); i++) { - result->names[i] = statement.aliases[i]; - } - - // This allows the right side to reference the CTE recursively - bind_context.AddGenericBinding(result->setop_index, statement.ctename, 
result->names, result->types); - - result->right_binder = Binder::CreateBinder(context, this); - - // Add bindings of left side to temporary CTE bindings context - result->right_binder->bind_context.AddCTEBinding(result->setop_index, statement.ctename, result->names, - result->types); - result->right = result->right_binder->BindNode(*statement.right); - - // move the correlated expressions from the child binders to this binder - MoveCorrelatedExpressions(*result->left_binder); - MoveCorrelatedExpressions(*result->right_binder); - - // now both sides have been bound we can resolve types - if (result->left->types.size() != result->right->types.size()) { - throw BinderException("Set operations can only apply to expressions with the " - "same number of result columns"); - } - - if (!statement.modifiers.empty()) { - throw NotImplementedException("FIXME: bind modifiers in recursive CTE"); - } - - return std::move(result); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -unique_ptr Binder::BindOrderExpression(OrderBinder &order_binder, unique_ptr expr) { - // we treat the Distinct list as a order by - auto bound_expr = order_binder.Bind(std::move(expr)); - if (!bound_expr) { - // DISTINCT ON non-integer constant - // remove the expression from the DISTINCT ON list - return nullptr; - } - D_ASSERT(bound_expr->type == ExpressionType::BOUND_COLUMN_REF); - return bound_expr; -} - -unique_ptr Binder::BindDelimiter(ClientContext &context, OrderBinder &order_binder, - unique_ptr delimiter, const LogicalType &type, - Value &delimiter_value) { - auto new_binder = Binder::CreateBinder(context, this, true); - if (delimiter->HasSubquery()) { - if (!order_binder.HasExtraList()) { - throw BinderException("Subquery in LIMIT/OFFSET not supported in set operation"); - } - return order_binder.CreateExtraReference(std::move(delimiter)); - } - ExpressionBinder expr_binder(*new_binder, context); - expr_binder.target_type = type; - auto 
expr = expr_binder.Bind(delimiter); - if (expr->IsFoldable()) { - //! this is a constant - delimiter_value = ExpressionExecutor::EvaluateScalar(context, *expr).CastAs(context, type); - return nullptr; - } - if (!new_binder->correlated_columns.empty()) { - throw BinderException("Correlated columns not supported in LIMIT/OFFSET"); - } - // move any correlated columns to this binder - MoveCorrelatedExpressions(*new_binder); - return expr; -} - -duckdb::unique_ptr Binder::BindLimit(OrderBinder &order_binder, LimitModifier &limit_mod) { - auto result = make_uniq(); - if (limit_mod.limit) { - Value val; - result->limit = BindDelimiter(context, order_binder, std::move(limit_mod.limit), LogicalType::BIGINT, val); - if (!result->limit) { - result->limit_val = val.IsNull() ? NumericLimits::Maximum() : val.GetValue(); - if (result->limit_val < 0) { - throw BinderException("LIMIT cannot be negative"); - } - } - } - if (limit_mod.offset) { - Value val; - result->offset = BindDelimiter(context, order_binder, std::move(limit_mod.offset), LogicalType::BIGINT, val); - if (!result->offset) { - result->offset_val = val.IsNull() ? 0 : val.GetValue(); - if (result->offset_val < 0) { - throw BinderException("OFFSET cannot be negative"); - } - } - } - return std::move(result); -} - -unique_ptr Binder::BindLimitPercent(OrderBinder &order_binder, LimitPercentModifier &limit_mod) { - auto result = make_uniq(); - if (limit_mod.limit) { - Value val; - result->limit = BindDelimiter(context, order_binder, std::move(limit_mod.limit), LogicalType::DOUBLE, val); - if (!result->limit) { - result->limit_percent = val.IsNull() ? 100 : val.GetValue(); - if (result->limit_percent < 0.0) { - throw Exception("Limit percentage can't be negative value"); - } - } - } - if (limit_mod.offset) { - Value val; - result->offset = BindDelimiter(context, order_binder, std::move(limit_mod.offset), LogicalType::BIGINT, val); - if (!result->offset) { - result->offset_val = val.IsNull() ? 
0 : val.GetValue(); - } - } - return std::move(result); -} - -void Binder::BindModifiers(OrderBinder &order_binder, QueryNode &statement, BoundQueryNode &result) { - for (auto &mod : statement.modifiers) { - unique_ptr bound_modifier; - switch (mod->type) { - case ResultModifierType::DISTINCT_MODIFIER: { - auto &distinct = mod->Cast(); - auto bound_distinct = make_uniq(); - bound_distinct->distinct_type = - distinct.distinct_on_targets.empty() ? DistinctType::DISTINCT : DistinctType::DISTINCT_ON; - if (distinct.distinct_on_targets.empty()) { - for (idx_t i = 0; i < result.names.size(); i++) { - distinct.distinct_on_targets.push_back(make_uniq(Value::INTEGER(1 + i))); - } - } - for (auto &distinct_on_target : distinct.distinct_on_targets) { - auto expr = BindOrderExpression(order_binder, std::move(distinct_on_target)); - if (!expr) { - continue; - } - bound_distinct->target_distincts.push_back(std::move(expr)); - } - bound_modifier = std::move(bound_distinct); - break; - } - case ResultModifierType::ORDER_MODIFIER: { - auto &order = mod->Cast(); - auto bound_order = make_uniq(); - auto &config = DBConfig::GetConfig(context); - D_ASSERT(!order.orders.empty()); - auto &order_binders = order_binder.GetBinders(); - if (order.orders.size() == 1 && order.orders[0].expression->type == ExpressionType::STAR) { - auto &star = order.orders[0].expression->Cast(); - if (star.exclude_list.empty() && star.replace_list.empty() && !star.expr) { - // ORDER BY ALL - // replace the order list with the all elements in the SELECT list - auto order_type = order.orders[0].type; - auto null_order = order.orders[0].null_order; - - vector new_orders; - for (idx_t i = 0; i < order_binder.MaxCount(); i++) { - new_orders.emplace_back(order_type, null_order, - make_uniq(Value::INTEGER(i + 1))); - } - order.orders = std::move(new_orders); - } - } - for (auto &order_node : order.orders) { - vector> order_list; - order_binders[0]->ExpandStarExpression(std::move(order_node.expression), order_list); - 
- auto type = config.ResolveOrder(order_node.type); - auto null_order = config.ResolveNullOrder(type, order_node.null_order); - for (auto &order_expr : order_list) { - auto bound_expr = BindOrderExpression(order_binder, std::move(order_expr)); - if (!bound_expr) { - continue; - } - bound_order->orders.emplace_back(type, null_order, std::move(bound_expr)); - } - } - if (!bound_order->orders.empty()) { - bound_modifier = std::move(bound_order); - } - break; - } - case ResultModifierType::LIMIT_MODIFIER: - bound_modifier = BindLimit(order_binder, mod->Cast()); - break; - case ResultModifierType::LIMIT_PERCENT_MODIFIER: - bound_modifier = BindLimitPercent(order_binder, mod->Cast()); - break; - default: - throw Exception("Unsupported result modifier"); - } - if (bound_modifier) { - result.modifiers.push_back(std::move(bound_modifier)); - } - } -} - -static void AssignReturnType(unique_ptr &expr, const vector &sql_types) { - if (!expr) { - return; - } - if (expr->type != ExpressionType::BOUND_COLUMN_REF) { - return; - } - auto &bound_colref = expr->Cast(); - bound_colref.return_type = sql_types[bound_colref.binding.column_index]; -} - -void Binder::BindModifierTypes(BoundQueryNode &result, const vector &sql_types, idx_t projection_index) { - for (auto &bound_mod : result.modifiers) { - switch (bound_mod->type) { - case ResultModifierType::DISTINCT_MODIFIER: { - auto &distinct = bound_mod->Cast(); - D_ASSERT(!distinct.target_distincts.empty()); - // set types of distinct targets - for (auto &expr : distinct.target_distincts) { - D_ASSERT(expr->type == ExpressionType::BOUND_COLUMN_REF); - auto &bound_colref = expr->Cast(); - if (bound_colref.binding.column_index == DConstants::INVALID_INDEX) { - throw BinderException("Ambiguous name in DISTINCT ON!"); - } - D_ASSERT(bound_colref.binding.column_index < sql_types.size()); - bound_colref.return_type = sql_types[bound_colref.binding.column_index]; - } - for (auto &target_distinct : distinct.target_distincts) { - auto 
&bound_colref = target_distinct->Cast(); - const auto &sql_type = sql_types[bound_colref.binding.column_index]; - if (sql_type.id() == LogicalTypeId::VARCHAR) { - target_distinct = ExpressionBinder::PushCollation(context, std::move(target_distinct), - StringType::GetCollation(sql_type), true); - } - } - break; - } - case ResultModifierType::LIMIT_MODIFIER: { - auto &limit = bound_mod->Cast(); - AssignReturnType(limit.limit, sql_types); - AssignReturnType(limit.offset, sql_types); - break; - } - case ResultModifierType::LIMIT_PERCENT_MODIFIER: { - auto &limit = bound_mod->Cast(); - AssignReturnType(limit.limit, sql_types); - AssignReturnType(limit.offset, sql_types); - break; - } - case ResultModifierType::ORDER_MODIFIER: { - auto &order = bound_mod->Cast(); - for (auto &order_node : order.orders) { - auto &expr = order_node.expression; - D_ASSERT(expr->type == ExpressionType::BOUND_COLUMN_REF); - auto &bound_colref = expr->Cast(); - if (bound_colref.binding.column_index == DConstants::INVALID_INDEX) { - throw BinderException("Ambiguous name in ORDER BY!"); - } - D_ASSERT(bound_colref.binding.column_index < sql_types.size()); - const auto &sql_type = sql_types[bound_colref.binding.column_index]; - bound_colref.return_type = sql_types[bound_colref.binding.column_index]; - if (sql_type.id() == LogicalTypeId::VARCHAR) { - order_node.expression = ExpressionBinder::PushCollation(context, std::move(order_node.expression), - StringType::GetCollation(sql_type)); - } - } - break; - } - default: - break; - } - } -} - -unique_ptr Binder::BindNode(SelectNode &statement) { - D_ASSERT(statement.from_table); - // first bind the FROM table statement - auto from = std::move(statement.from_table); - auto from_table = Bind(*from); - return BindSelectNode(statement, std::move(from_table)); -} - -void Binder::BindWhereStarExpression(unique_ptr &expr) { - // expand any expressions in the upper AND recursively - if (expr->type == ExpressionType::CONJUNCTION_AND) { - auto &conj = 
expr->Cast(); - for (auto &child : conj.children) { - BindWhereStarExpression(child); - } - return; - } - if (expr->type == ExpressionType::STAR) { - auto &star = expr->Cast(); - if (!star.columns) { - throw ParserException("STAR expression is not allowed in the WHERE clause. Use COLUMNS(*) instead."); - } - } - // expand the stars for this expression - vector> new_conditions; - ExpandStarExpression(std::move(expr), new_conditions); - if (new_conditions.empty()) { - throw ParserException("COLUMNS expansion resulted in empty set of columns"); - } - - // set up an AND conjunction between the expanded conditions - expr = std::move(new_conditions[0]); - for (idx_t i = 1; i < new_conditions.size(); i++) { - auto and_conj = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(expr), - std::move(new_conditions[i])); - expr = std::move(and_conj); - } -} - -unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ptr from_table) { - D_ASSERT(from_table); - D_ASSERT(!statement.from_table); - auto result = make_uniq(); - result->projection_index = GenerateTableIndex(); - result->group_index = GenerateTableIndex(); - result->aggregate_index = GenerateTableIndex(); - result->groupings_index = GenerateTableIndex(); - result->window_index = GenerateTableIndex(); - result->prune_index = GenerateTableIndex(); - - result->from_table = std::move(from_table); - // bind the sample clause - if (statement.sample) { - result->sample_options = std::move(statement.sample); - } - - // visit the select list and expand any "*" statements - vector> new_select_list; - ExpandStarExpressions(statement.select_list, new_select_list); - - if (new_select_list.empty()) { - throw BinderException("SELECT list is empty after resolving * expressions!"); - } - statement.select_list = std::move(new_select_list); - - // create a mapping of (alias -> index) and a mapping of (Expression -> index) for the SELECT list - case_insensitive_map_t alias_map; - parsed_expression_map_t projection_map; - for 
(idx_t i = 0; i < statement.select_list.size(); i++) { - auto &expr = statement.select_list[i]; - result->names.push_back(expr->GetName()); - ExpressionBinder::QualifyColumnNames(*this, expr); - if (!expr->alias.empty()) { - alias_map[expr->alias] = i; - result->names[i] = expr->alias; - } - projection_map[*expr] = i; - result->original_expressions.push_back(expr->Copy()); - } - result->column_count = statement.select_list.size(); - - // first visit the WHERE clause - // the WHERE clause happens before the GROUP BY, PROJECTION or HAVING clauses - if (statement.where_clause) { - // bind any star expressions in the WHERE clause - BindWhereStarExpression(statement.where_clause); - - ColumnAliasBinder alias_binder(*result, alias_map); - WhereBinder where_binder(*this, context, &alias_binder); - unique_ptr condition = std::move(statement.where_clause); - result->where_clause = where_binder.Bind(condition); - } - - // now bind all the result modifiers; including DISTINCT and ORDER BY targets - OrderBinder order_binder({this}, result->projection_index, statement, alias_map, projection_map); - BindModifiers(order_binder, statement, *result); - - vector> unbound_groups; - BoundGroupInformation info; - auto &group_expressions = statement.groups.group_expressions; - if (!group_expressions.empty()) { - // the statement has a GROUP BY clause, bind it - unbound_groups.resize(group_expressions.size()); - GroupBinder group_binder(*this, context, statement, result->group_index, alias_map, info.alias_map); - for (idx_t i = 0; i < group_expressions.size(); i++) { - - // we keep a copy of the unbound expression; - // we keep the unbound copy around to check for group references in the SELECT and HAVING clause - // the reason we want the unbound copy is because we want to figure out whether an expression - // is a group reference BEFORE binding in the SELECT/HAVING binder - group_binder.unbound_expression = group_expressions[i]->Copy(); - group_binder.bind_index = i; - - // bind the 
groups - LogicalType group_type; - auto bound_expr = group_binder.Bind(group_expressions[i], &group_type); - D_ASSERT(bound_expr->return_type.id() != LogicalTypeId::INVALID); - - // find out whether the expression contains a subquery, it can't be copied if so - auto &bound_expr_ref = *bound_expr; - bool contains_subquery = bound_expr_ref.HasSubquery(); - - // push a potential collation, if necessary - auto collated_expr = ExpressionBinder::PushCollation(context, std::move(bound_expr), - StringType::GetCollation(group_type), true); - if (!contains_subquery && !collated_expr->Equals(bound_expr_ref)) { - // if there is a collation on a group x, we should group by the collated expr, - // but also push a first(x) aggregate in case x is selected (uncollated) - info.collated_groups[i] = result->aggregates.size(); - - auto first_fun = FirstFun::GetFunction(LogicalType::VARCHAR); - vector> first_children; - // FIXME: would be better to just refer to this expression, but for now we copy - first_children.push_back(bound_expr_ref.Copy()); - - FunctionBinder function_binder(context); - auto function = function_binder.BindAggregateFunction(first_fun, std::move(first_children)); - result->aggregates.push_back(std::move(function)); - } - result->groups.group_expressions.push_back(std::move(collated_expr)); - - // in the unbound expression we DO bind the table names of any ColumnRefs - // we do this to make sure that "table.a" and "a" are treated the same - // if we wouldn't do this then (SELECT test.a FROM test GROUP BY a) would not work because "test.a" <> "a" - // hence we convert "a" -> "test.a" in the unbound expression - unbound_groups[i] = std::move(group_binder.unbound_expression); - ExpressionBinder::QualifyColumnNames(*this, unbound_groups[i]); - info.map[*unbound_groups[i]] = i; - } - } - result->groups.grouping_sets = std::move(statement.groups.grouping_sets); - - // bind the HAVING clause, if any - if (statement.having) { - HavingBinder having_binder(*this, context, 
*result, info, alias_map, statement.aggregate_handling); - ExpressionBinder::QualifyColumnNames(*this, statement.having); - result->having = having_binder.Bind(statement.having); - } - - // bind the QUALIFY clause, if any - unique_ptr qualify_binder; - if (statement.qualify) { - if (statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { - throw BinderException("Combining QUALIFY with GROUP BY ALL is not supported yet"); - } - qualify_binder = make_uniq(*this, context, *result, info, alias_map); - ExpressionBinder::QualifyColumnNames(*this, statement.qualify); - result->qualify = qualify_binder->Bind(statement.qualify); - if (qualify_binder->HasBoundColumns() && qualify_binder->BoundAggregates()) { - throw BinderException("Cannot mix aggregates with non-aggregated columns!"); - } - } - - // after that, we bind to the SELECT list - SelectBinder select_binder(*this, context, *result, info, alias_map); - vector internal_sql_types; - vector group_by_all_indexes; - vector new_names; - for (idx_t i = 0; i < statement.select_list.size(); i++) { - bool is_window = statement.select_list[i]->IsWindow(); - idx_t unnest_count = result->unnests.size(); - LogicalType result_type; - auto expr = select_binder.Bind(statement.select_list[i], &result_type, true); - bool is_original_column = i < result->column_count; - bool can_group_by_all = - statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES && is_original_column; - if (select_binder.HasExpandedExpressions()) { - if (!is_original_column) { - throw InternalException("Only original columns can have expanded expressions"); - } - if (statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { - throw BinderException("UNNEST of struct cannot be combined with GROUP BY ALL"); - } - auto &struct_expressions = select_binder.ExpandedExpressions(); - D_ASSERT(!struct_expressions.empty()); - for (auto &struct_expr : struct_expressions) { - new_names.push_back(struct_expr->GetName()); - 
result->types.push_back(struct_expr->return_type); - result->select_list.push_back(std::move(struct_expr)); - } - struct_expressions.clear(); - continue; - } - if (can_group_by_all && select_binder.HasBoundColumns()) { - if (select_binder.BoundAggregates()) { - throw BinderException("Cannot mix aggregates with non-aggregated columns!"); - } - if (is_window) { - throw BinderException("Cannot group on a window clause"); - } - if (result->unnests.size() > unnest_count) { - throw BinderException("Cannot group on an UNNEST or UNLIST clause"); - } - // we are forcing aggregates, and the node has columns bound - // this entry becomes a group - group_by_all_indexes.push_back(i); - } - result->select_list.push_back(std::move(expr)); - if (is_original_column) { - new_names.push_back(std::move(result->names[i])); - result->types.push_back(result_type); - } - internal_sql_types.push_back(result_type); - if (can_group_by_all) { - select_binder.ResetBindings(); - } - } - // push the GROUP BY ALL expressions into the group set - for (auto &group_by_all_index : group_by_all_indexes) { - auto &expr = result->select_list[group_by_all_index]; - auto group_ref = make_uniq( - expr->return_type, ColumnBinding(result->group_index, result->groups.group_expressions.size())); - result->groups.group_expressions.push_back(std::move(expr)); - expr = std::move(group_ref); - } - result->column_count = new_names.size(); - result->names = std::move(new_names); - result->need_prune = result->select_list.size() > result->column_count; - - // in the normal select binder, we bind columns as if there is no aggregation - // i.e. 
in the query [SELECT i, SUM(i) FROM integers;] the "i" will be bound as a normal column - // since we have an aggregation, we need to either (1) throw an error, or (2) wrap the column in a FIRST() aggregate - // we choose the former one [CONTROVERSIAL: this is the PostgreSQL behavior] - if (!result->groups.group_expressions.empty() || !result->aggregates.empty() || statement.having || - !result->groups.grouping_sets.empty()) { - if (statement.aggregate_handling == AggregateHandling::NO_AGGREGATES_ALLOWED) { - throw BinderException("Aggregates cannot be present in a Project relation!"); - } else { - vector> to_check_binders; - to_check_binders.push_back(select_binder); - if (qualify_binder) { - to_check_binders.push_back(*qualify_binder); - } - for (auto &binder : to_check_binders) { - auto &sel_binder = binder.get(); - if (!sel_binder.HasBoundColumns()) { - continue; - } - auto &bound_columns = sel_binder.GetBoundColumns(); - string error; - error = "column \"%s\" must appear in the GROUP BY clause or must be part of an aggregate function."; - if (statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { - error += "\nGROUP BY ALL will only group entries in the SELECT list. 
Add it to the SELECT list or " - "GROUP BY this entry explicitly."; - } else { - error += - "\nEither add it to the GROUP BY list, or use \"ANY_VALUE(%s)\" if the exact value of \"%s\" " - "is not important."; - } - throw BinderException(FormatError(bound_columns[0].query_location, error, bound_columns[0].name, - bound_columns[0].name, bound_columns[0].name)); - } - } - } - - // QUALIFY clause requires at least one window function to be specified in at least one of the SELECT column list or - // the filter predicate of the QUALIFY clause - if (statement.qualify && result->windows.empty()) { - throw BinderException("at least one window function must appear in the SELECT column or QUALIFY clause"); - } - - // now that the SELECT list is bound, we set the types of DISTINCT/ORDER BY expressions - BindModifierTypes(*result, internal_sql_types, result->projection_index); - return std::move(result); -} - -} // namespace duckdb - -#endif diff --git a/lib/duckdb-9.cpp b/lib/duckdb-9.cpp deleted file mode 100644 index be56c894..00000000 --- a/lib/duckdb-9.cpp +++ /dev/null @@ -1,21153 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - - - - - - - - - - - -namespace duckdb { - -static void GatherAliases(BoundQueryNode &node, case_insensitive_map_t &aliases, - parsed_expression_map_t &expressions, const vector &reorder_idx) { - if (node.type == QueryNodeType::SET_OPERATION_NODE) { - // setop, recurse - auto &setop = node.Cast(); - - // create new reorder index - if (setop.setop_type == SetOperationType::UNION_BY_NAME) { - vector new_left_reorder_idx(setop.left_reorder_idx.size()); - vector new_right_reorder_idx(setop.right_reorder_idx.size()); - for (idx_t i = 0; i < setop.left_reorder_idx.size(); ++i) { - new_left_reorder_idx[i] = reorder_idx[setop.left_reorder_idx[i]]; - } - 
- for (idx_t i = 0; i < setop.right_reorder_idx.size(); ++i) { - new_right_reorder_idx[i] = reorder_idx[setop.right_reorder_idx[i]]; - } - - // use new reorder index - GatherAliases(*setop.left, aliases, expressions, new_left_reorder_idx); - GatherAliases(*setop.right, aliases, expressions, new_right_reorder_idx); - return; - } - - GatherAliases(*setop.left, aliases, expressions, reorder_idx); - GatherAliases(*setop.right, aliases, expressions, reorder_idx); - } else { - // query node - D_ASSERT(node.type == QueryNodeType::SELECT_NODE); - auto &select = node.Cast(); - // fill the alias lists - for (idx_t i = 0; i < select.names.size(); i++) { - auto &name = select.names[i]; - auto &expr = select.original_expressions[i]; - // first check if the alias is already in there - auto entry = aliases.find(name); - - idx_t index = reorder_idx[i]; - - if (entry != aliases.end()) { - // the alias already exists - // check if there is a conflict - - if (entry->second != index) { - // there is a conflict - // we place "-1" in the aliases map at this location - // "-1" signifies that there is an ambiguous reference - aliases[name] = DConstants::INVALID_INDEX; - } - } else { - // the alias is not in there yet, just assign it - aliases[name] = index; - } - // now check if the node is already in the set of expressions - auto expr_entry = expressions.find(*expr); - if (expr_entry != expressions.end()) { - // the node is in there - // repeat the same as with the alias: if there is an ambiguity we insert "-1" - if (expr_entry->second != index) { - expressions[*expr] = DConstants::INVALID_INDEX; - } - } else { - // not in there yet, just place it in there - expressions[*expr] = index; - } - } - } -} - -static void BuildUnionByNameInfo(BoundSetOperationNode &result, bool can_contain_nulls) { - D_ASSERT(result.setop_type == SetOperationType::UNION_BY_NAME); - case_insensitive_map_t left_names_map; - case_insensitive_map_t right_names_map; - - BoundQueryNode *left_node = result.left.get(); 
- BoundQueryNode *right_node = result.right.get(); - - // Build a name_map to use to check if a name exists - // We throw a binder exception if two same name in the SELECT list - for (idx_t i = 0; i < left_node->names.size(); ++i) { - if (left_names_map.find(left_node->names[i]) != left_names_map.end()) { - throw BinderException("UNION(ALL) BY NAME operation doesn't support same name in SELECT list"); - } - left_names_map[left_node->names[i]] = i; - } - - for (idx_t i = 0; i < right_node->names.size(); ++i) { - if (right_names_map.find(right_node->names[i]) != right_names_map.end()) { - throw BinderException("UNION(ALL) BY NAME operation doesn't support same name in SELECT list"); - } - if (left_names_map.find(right_node->names[i]) == left_names_map.end()) { - result.names.push_back(right_node->names[i]); - } - right_names_map[right_node->names[i]] = i; - } - - idx_t new_size = result.names.size(); - bool need_reorder = false; - vector left_reorder_idx(left_node->names.size()); - vector right_reorder_idx(right_node->names.size()); - - // Construct return type and reorder_idxs - // reorder_idxs is used to gather correct alias_map - // and expression_map in GatherAlias(...) 
- for (idx_t i = 0; i < new_size; ++i) { - auto left_index = left_names_map.find(result.names[i]); - auto right_index = right_names_map.find(result.names[i]); - bool left_exist = left_index != left_names_map.end(); - bool right_exist = right_index != right_names_map.end(); - LogicalType result_type; - if (left_exist && right_exist) { - result_type = LogicalType::MaxLogicalType(left_node->types[left_index->second], - right_node->types[right_index->second]); - if (left_index->second != i || right_index->second != i) { - need_reorder = true; - } - left_reorder_idx[left_index->second] = i; - right_reorder_idx[right_index->second] = i; - } else if (left_exist) { - result_type = left_node->types[left_index->second]; - need_reorder = true; - left_reorder_idx[left_index->second] = i; - } else { - D_ASSERT(right_exist); - result_type = right_node->types[right_index->second]; - need_reorder = true; - right_reorder_idx[right_index->second] = i; - } - - if (!can_contain_nulls) { - if (ExpressionBinder::ContainsNullType(result_type)) { - result_type = ExpressionBinder::ExchangeNullType(result_type); - } - } - - result.types.push_back(result_type); - } - - result.left_reorder_idx = std::move(left_reorder_idx); - result.right_reorder_idx = std::move(right_reorder_idx); - - // If reorder is required, collect reorder expressions for push projection - // into the two child nodes of union node - if (need_reorder) { - for (idx_t i = 0; i < new_size; ++i) { - auto left_index = left_names_map.find(result.names[i]); - auto right_index = right_names_map.find(result.names[i]); - bool left_exist = left_index != left_names_map.end(); - bool right_exist = right_index != right_names_map.end(); - unique_ptr left_reorder_expr; - unique_ptr right_reorder_expr; - if (left_exist && right_exist) { - left_reorder_expr = make_uniq( - left_node->types[left_index->second], ColumnBinding(left_node->GetRootIndex(), left_index->second)); - right_reorder_expr = - 
make_uniq(right_node->types[right_index->second], - ColumnBinding(right_node->GetRootIndex(), right_index->second)); - } else if (left_exist) { - left_reorder_expr = make_uniq( - left_node->types[left_index->second], ColumnBinding(left_node->GetRootIndex(), left_index->second)); - // create null value here - right_reorder_expr = make_uniq(Value(result.types[i])); - } else { - D_ASSERT(right_exist); - left_reorder_expr = make_uniq(Value(result.types[i])); - right_reorder_expr = - make_uniq(right_node->types[right_index->second], - ColumnBinding(right_node->GetRootIndex(), right_index->second)); - } - result.left_reorder_exprs.push_back(std::move(left_reorder_expr)); - result.right_reorder_exprs.push_back(std::move(right_reorder_expr)); - } - } -} - -unique_ptr Binder::BindNode(SetOperationNode &statement) { - auto result = make_uniq(); - result->setop_type = statement.setop_type; - - // first recursively visit the set operations - // both the left and right sides have an independent BindContext and Binder - D_ASSERT(statement.left); - D_ASSERT(statement.right); - - result->setop_index = GenerateTableIndex(); - - result->left_binder = Binder::CreateBinder(context, this); - result->left_binder->can_contain_nulls = true; - result->left = result->left_binder->BindNode(*statement.left); - result->right_binder = Binder::CreateBinder(context, this); - result->right_binder->can_contain_nulls = true; - result->right = result->right_binder->BindNode(*statement.right); - - result->names = result->left->names; - - // move the correlated expressions from the child binders to this binder - MoveCorrelatedExpressions(*result->left_binder); - MoveCorrelatedExpressions(*result->right_binder); - - // now both sides have been bound we can resolve types - if (result->setop_type != SetOperationType::UNION_BY_NAME && - result->left->types.size() != result->right->types.size()) { - throw BinderException("Set operations can only apply to expressions with the " - "same number of result 
columns"); - } - - if (result->setop_type == SetOperationType::UNION_BY_NAME) { - BuildUnionByNameInfo(*result, can_contain_nulls); - - } else { - // figure out the types of the setop result by picking the max of both - for (idx_t i = 0; i < result->left->types.size(); i++) { - auto result_type = LogicalType::MaxLogicalType(result->left->types[i], result->right->types[i]); - if (!can_contain_nulls) { - if (ExpressionBinder::ContainsNullType(result_type)) { - result_type = ExpressionBinder::ExchangeNullType(result_type); - } - } - result->types.push_back(result_type); - } - } - - if (!statement.modifiers.empty()) { - // handle the ORDER BY/DISTINCT clauses - - // we recursively visit the children of this node to extract aliases and expressions that can be referenced - // in the ORDER BY - case_insensitive_map_t alias_map; - parsed_expression_map_t expression_map; - - if (result->setop_type == SetOperationType::UNION_BY_NAME) { - GatherAliases(*result->left, alias_map, expression_map, result->left_reorder_idx); - GatherAliases(*result->right, alias_map, expression_map, result->right_reorder_idx); - } else { - vector reorder_idx; - for (idx_t i = 0; i < result->names.size(); i++) { - reorder_idx.push_back(i); - } - GatherAliases(*result, alias_map, expression_map, reorder_idx); - } - // now we perform the actual resolution of the ORDER BY/DISTINCT expressions - OrderBinder order_binder({result->left_binder.get(), result->right_binder.get()}, result->setop_index, - alias_map, expression_map, result->names.size()); - BindModifiers(order_binder, statement, *result); - } - - // finally bind the types of the ORDER/DISTINCT clause expressions - BindModifierTypes(*result, result->types, result->setop_index); - return std::move(result); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - -namespace duckdb { - -unique_ptr Binder::BindTableMacro(FunctionExpression &function, TableMacroCatalogEntry ¯o_func, - idx_t depth) { - - auto ¯o_def = macro_func.function->Cast(); 
- auto node = macro_def.query_node->Copy(); - - // auto ¯o_def = *macro_func->function; - - // validate the arguments and separate positional and default arguments - vector> positionals; - unordered_map> defaults; - string error = - MacroFunction::ValidateArguments(*macro_func.function, macro_func.name, function, positionals, defaults); - if (!error.empty()) { - // cannot use error below as binder rnot in scope - // return BindResult(binder. FormatError(*expr->get(), error)); - throw BinderException(FormatError(function, error)); - } - - // create a MacroBinding to bind this macro's parameters to its arguments - vector types; - vector names; - // positional parameters - for (idx_t i = 0; i < macro_def.parameters.size(); i++) { - types.emplace_back(LogicalType::SQLNULL); - auto ¶m = macro_def.parameters[i]->Cast(); - names.push_back(param.GetColumnName()); - } - // default parameters - for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) { - types.emplace_back(LogicalType::SQLNULL); - names.push_back(it->first); - // now push the defaults into the positionals - positionals.push_back(std::move(defaults[it->first])); - } - auto new_macro_binding = make_uniq(types, names, macro_func.name); - new_macro_binding->arguments = &positionals; - - // We need an ExpressionBinder so that we can call ExpressionBinder::ReplaceMacroParametersRecursive() - auto eb = ExpressionBinder(*this, this->context); - - eb.macro_binding = new_macro_binding.get(); - - /* Does it all goes throu every expression in a selectstmt */ - ParsedExpressionIterator::EnumerateQueryNodeChildren( - *node, [&](unique_ptr &child) { eb.ReplaceMacroParametersRecursive(child); }); - - return node; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundCTENode &node) { - // Generate the logical plan for the cte_query and child. 
- auto cte_query = CreatePlan(*node.query); - auto cte_child = CreatePlan(*node.child); - - auto root = make_uniq(node.ctename, node.setop_index, node.types.size(), - std::move(cte_query), std::move(cte_child)); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = - node.child_binder->has_unplanned_dependent_joins || node.query_binder->has_unplanned_dependent_joins; - - return VisitQueryNode(node, std::move(root)); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -unique_ptr Binder::VisitQueryNode(BoundQueryNode &node, unique_ptr root) { - D_ASSERT(root); - for (auto &mod : node.modifiers) { - switch (mod->type) { - case ResultModifierType::DISTINCT_MODIFIER: { - auto &bound = mod->Cast(); - auto distinct = make_uniq(std::move(bound.target_distincts), bound.distinct_type); - distinct->AddChild(std::move(root)); - root = std::move(distinct); - break; - } - case ResultModifierType::ORDER_MODIFIER: { - auto &bound = mod->Cast(); - if (root->type == LogicalOperatorType::LOGICAL_DISTINCT) { - auto &distinct = root->Cast(); - if (distinct.distinct_type == DistinctType::DISTINCT_ON) { - auto order_by = make_uniq(); - for (auto &order_node : bound.orders) { - order_by->orders.push_back(order_node.Copy()); - } - distinct.order_by = std::move(order_by); - } - } - auto order = make_uniq(std::move(bound.orders)); - order->AddChild(std::move(root)); - root = std::move(order); - break; - } - case ResultModifierType::LIMIT_MODIFIER: { - auto &bound = mod->Cast(); - auto limit = make_uniq(bound.limit_val, bound.offset_val, std::move(bound.limit), - std::move(bound.offset)); - limit->AddChild(std::move(root)); - root = std::move(limit); - break; - } - case ResultModifierType::LIMIT_PERCENT_MODIFIER: { - auto &bound = mod->Cast(); - auto limit = make_uniq(bound.limit_percent, bound.offset_val, std::move(bound.limit), - std::move(bound.offset)); - limit->AddChild(std::move(root)); - root = std::move(limit); - 
break; - } - default: - throw BinderException("Unimplemented modifier type!"); - } - } - return root; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundRecursiveCTENode &node) { - // Generate the logical plan for the left and right sides of the set operation - node.left_binder->is_outside_flattened = is_outside_flattened; - node.right_binder->is_outside_flattened = is_outside_flattened; - - auto left_node = node.left_binder->CreatePlan(*node.left); - auto right_node = node.right_binder->CreatePlan(*node.right); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = - node.left_binder->has_unplanned_dependent_joins || node.right_binder->has_unplanned_dependent_joins; - - // for both the left and right sides, cast them to the same types - left_node = CastLogicalOperatorToTypes(node.left->types, node.types, std::move(left_node)); - right_node = CastLogicalOperatorToTypes(node.right->types, node.types, std::move(right_node)); - - if (!node.right_binder->bind_context.cte_references[node.ctename] || - *node.right_binder->bind_context.cte_references[node.ctename] == 0) { - auto root = make_uniq(node.setop_index, node.types.size(), std::move(left_node), - std::move(right_node), LogicalOperatorType::LOGICAL_UNION); - return VisitQueryNode(node, std::move(root)); - } - auto root = make_uniq(node.ctename, node.setop_index, node.types.size(), node.union_all, - std::move(left_node), std::move(right_node)); - - return VisitQueryNode(node, std::move(root)); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -unique_ptr Binder::PlanFilter(unique_ptr condition, unique_ptr root) { - PlanSubqueries(condition, root); - auto filter = make_uniq(std::move(condition)); - filter->AddChild(std::move(root)); - return std::move(filter); -} - -unique_ptr Binder::CreatePlan(BoundSelectNode &statement) { - unique_ptr root; - D_ASSERT(statement.from_table); - root = 
CreatePlan(*statement.from_table); - D_ASSERT(root); - - // plan the sample clause - if (statement.sample_options) { - root = make_uniq(std::move(statement.sample_options), std::move(root)); - } - - if (statement.where_clause) { - root = PlanFilter(std::move(statement.where_clause), std::move(root)); - } - - if (!statement.aggregates.empty() || !statement.groups.group_expressions.empty()) { - if (!statement.groups.group_expressions.empty()) { - // visit the groups - for (auto &group : statement.groups.group_expressions) { - PlanSubqueries(group, root); - } - } - // now visit all aggregate expressions - for (auto &expr : statement.aggregates) { - PlanSubqueries(expr, root); - } - // finally create the aggregate node with the group_index and aggregate_index as obtained from the binder - auto aggregate = make_uniq(statement.group_index, statement.aggregate_index, - std::move(statement.aggregates)); - aggregate->groups = std::move(statement.groups.group_expressions); - aggregate->groupings_index = statement.groupings_index; - aggregate->grouping_sets = std::move(statement.groups.grouping_sets); - aggregate->grouping_functions = std::move(statement.grouping_functions); - - aggregate->AddChild(std::move(root)); - root = std::move(aggregate); - } else if (!statement.groups.grouping_sets.empty()) { - // edge case: we have grouping sets but no groups or aggregates - // this can only happen if we have e.g. 
select 1 from tbl group by (); - // just output a dummy scan - root = make_uniq_base(statement.group_index); - } - - if (statement.having) { - PlanSubqueries(statement.having, root); - auto having = make_uniq(std::move(statement.having)); - - having->AddChild(std::move(root)); - root = std::move(having); - } - - if (!statement.windows.empty()) { - auto win = make_uniq(statement.window_index); - win->expressions = std::move(statement.windows); - // visit the window expressions - for (auto &expr : win->expressions) { - PlanSubqueries(expr, root); - } - D_ASSERT(!win->expressions.empty()); - win->AddChild(std::move(root)); - root = std::move(win); - } - - if (statement.qualify) { - PlanSubqueries(statement.qualify, root); - auto qualify = make_uniq(std::move(statement.qualify)); - - qualify->AddChild(std::move(root)); - root = std::move(qualify); - } - - for (idx_t i = statement.unnests.size(); i > 0; i--) { - auto unnest_level = i - 1; - auto entry = statement.unnests.find(unnest_level); - if (entry == statement.unnests.end()) { - throw InternalException("unnests specified at level %d but none were found", unnest_level); - } - auto &unnest_node = entry->second; - auto unnest = make_uniq(unnest_node.index); - unnest->expressions = std::move(unnest_node.expressions); - // visit the unnest expressions - for (auto &expr : unnest->expressions) { - PlanSubqueries(expr, root); - } - D_ASSERT(!unnest->expressions.empty()); - unnest->AddChild(std::move(root)); - root = std::move(unnest); - } - - for (auto &expr : statement.select_list) { - PlanSubqueries(expr, root); - } - - auto proj = make_uniq(statement.projection_index, std::move(statement.select_list)); - auto &projection = *proj; - proj->AddChild(std::move(root)); - root = std::move(proj); - - // finish the plan by handling the elements of the QueryNode - root = VisitQueryNode(statement, std::move(root)); - - // add a prune node if necessary - if (statement.need_prune) { - D_ASSERT(root); - vector> prune_expressions; - 
for (idx_t i = 0; i < statement.column_count; i++) { - prune_expressions.push_back(make_uniq( - projection.expressions[i]->return_type, ColumnBinding(statement.projection_index, i))); - } - auto prune = make_uniq(statement.prune_index, std::move(prune_expressions)); - prune->AddChild(std::move(root)); - root = std::move(prune); - } - return root; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -// Optionally push a PROJECTION operator -unique_ptr Binder::CastLogicalOperatorToTypes(vector &source_types, - vector &target_types, - unique_ptr op) { - D_ASSERT(op); - // first check if we even need to cast - D_ASSERT(source_types.size() == target_types.size()); - if (source_types == target_types) { - // source and target types are equal: don't need to cast - return op; - } - // otherwise add casts - auto node = op.get(); - if (node->type == LogicalOperatorType::LOGICAL_PROJECTION) { - // "node" is a projection; we can just do the casts in there - D_ASSERT(node->expressions.size() == source_types.size()); - // add the casts to the selection list - for (idx_t i = 0; i < target_types.size(); i++) { - if (source_types[i] != target_types[i]) { - // differing types, have to add a cast - string alias = node->expressions[i]->alias; - node->expressions[i] = - BoundCastExpression::AddCastToType(context, std::move(node->expressions[i]), target_types[i]); - node->expressions[i]->alias = alias; - } - } - return op; - } else { - // found a non-projection operator - // push a new projection containing the casts - - // fetch the set of column bindings - auto setop_columns = op->GetColumnBindings(); - D_ASSERT(setop_columns.size() == source_types.size()); - - // now generate the expression list - vector> select_list; - for (idx_t i = 0; i < target_types.size(); i++) { - unique_ptr result = make_uniq(source_types[i], setop_columns[i]); - if (source_types[i] != target_types[i]) { - // add a cast only if the source and target types are not equivalent - result = 
BoundCastExpression::AddCastToType(context, std::move(result), target_types[i]); - } - select_list.push_back(std::move(result)); - } - auto projection = make_uniq(GenerateTableIndex(), std::move(select_list)); - projection->children.push_back(std::move(op)); - return std::move(projection); - } -} - -unique_ptr Binder::CreatePlan(BoundSetOperationNode &node) { - // Generate the logical plan for the left and right sides of the set operation - node.left_binder->is_outside_flattened = is_outside_flattened; - node.right_binder->is_outside_flattened = is_outside_flattened; - - auto left_node = node.left_binder->CreatePlan(*node.left); - auto right_node = node.right_binder->CreatePlan(*node.right); - - // Add a new projection to child node - D_ASSERT(node.left_reorder_exprs.size() == node.right_reorder_exprs.size()); - if (!node.left_reorder_exprs.empty()) { - D_ASSERT(node.setop_type == SetOperationType::UNION_BY_NAME); - vector left_types; - vector right_types; - // We are going to add a new projection operator, so collect the type - // of reorder exprs in order to call CastLogicalOperatorToTypes() - for (idx_t i = 0; i < node.left_reorder_exprs.size(); ++i) { - left_types.push_back(node.left_reorder_exprs[i]->return_type); - right_types.push_back(node.right_reorder_exprs[i]->return_type); - } - - auto left_projection = make_uniq(GenerateTableIndex(), std::move(node.left_reorder_exprs)); - left_projection->children.push_back(std::move(left_node)); - left_node = std::move(left_projection); - - auto right_projection = make_uniq(GenerateTableIndex(), std::move(node.right_reorder_exprs)); - right_projection->children.push_back(std::move(right_node)); - right_node = std::move(right_projection); - - left_node = CastLogicalOperatorToTypes(left_types, node.types, std::move(left_node)); - right_node = CastLogicalOperatorToTypes(right_types, node.types, std::move(right_node)); - } else { - left_node = CastLogicalOperatorToTypes(node.left->types, node.types, std::move(left_node)); 
- right_node = CastLogicalOperatorToTypes(node.right->types, node.types, std::move(right_node)); - } - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = - node.left_binder->has_unplanned_dependent_joins || node.right_binder->has_unplanned_dependent_joins; - - // create actual logical ops for setops - LogicalOperatorType logical_type; - switch (node.setop_type) { - case SetOperationType::UNION: - case SetOperationType::UNION_BY_NAME: - logical_type = LogicalOperatorType::LOGICAL_UNION; - break; - case SetOperationType::EXCEPT: - logical_type = LogicalOperatorType::LOGICAL_EXCEPT; - break; - default: - D_ASSERT(node.setop_type == SetOperationType::INTERSECT); - logical_type = LogicalOperatorType::LOGICAL_INTERSECT; - break; - } - - auto root = make_uniq(node.setop_index, node.types.size(), std::move(left_node), - std::move(right_node), logical_type); - - return VisitQueryNode(node, std::move(root)); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubqueryExpression &expr, - unique_ptr &root, - unique_ptr plan) { - D_ASSERT(!expr.IsCorrelated()); - switch (expr.subquery_type) { - case SubqueryType::EXISTS: { - // uncorrelated EXISTS - // we only care about existence, hence we push a LIMIT 1 operator - auto limit = make_uniq(1, 0, nullptr, nullptr); - limit->AddChild(std::move(plan)); - plan = std::move(limit); - - // now we push a COUNT(*) aggregate onto the limit, this will be either 0 or 1 (EXISTS or NOT EXISTS) - auto count_star_fun = CountStarFun::GetFunction(); - - FunctionBinder function_binder(binder.context); - auto count_star = - function_binder.BindAggregateFunction(count_star_fun, {}, nullptr, AggregateType::NON_DISTINCT); - auto idx_type = count_star->return_type; - vector> aggregate_list; - aggregate_list.push_back(std::move(count_star)); - auto aggregate_index = binder.GenerateTableIndex(); - 
auto aggregate = - make_uniq(binder.GenerateTableIndex(), aggregate_index, std::move(aggregate_list)); - aggregate->AddChild(std::move(plan)); - plan = std::move(aggregate); - - // now we push a projection with a comparison to 1 - auto left_child = make_uniq(idx_type, ColumnBinding(aggregate_index, 0)); - auto right_child = make_uniq(Value::Numeric(idx_type, 1)); - auto comparison = make_uniq(ExpressionType::COMPARE_EQUAL, std::move(left_child), - std::move(right_child)); - - vector> projection_list; - projection_list.push_back(std::move(comparison)); - auto projection_index = binder.GenerateTableIndex(); - auto projection = make_uniq(projection_index, std::move(projection_list)); - projection->AddChild(std::move(plan)); - plan = std::move(projection); - - // we add it to the main query by adding a cross product - // FIXME: should use something else besides cross product as we always add only one scalar constant - root = LogicalCrossProduct::Create(std::move(root), std::move(plan)); - - // we replace the original subquery with a ColumnRefExpression referring to the result of the projection (either - // TRUE or FALSE) - return make_uniq(expr.GetName(), LogicalType::BOOLEAN, - ColumnBinding(projection_index, 0)); - } - case SubqueryType::SCALAR: { - // uncorrelated scalar, we want to return the first entry - // figure out the table index of the bound table of the entry which we want to return - auto bindings = plan->GetColumnBindings(); - D_ASSERT(bindings.size() == 1); - idx_t table_idx = bindings[0].table_index; - - // in the uncorrelated case we are only interested in the first result of the query - // hence we simply push a LIMIT 1 to get the first row of the subquery - auto limit = make_uniq(1, 0, nullptr, nullptr); - limit->AddChild(std::move(plan)); - plan = std::move(limit); - - // we push an aggregate that returns the FIRST element - vector> expressions; - auto bound = make_uniq(expr.return_type, ColumnBinding(table_idx, 0)); - vector> first_children; - 
first_children.push_back(std::move(bound)); - - FunctionBinder function_binder(binder.context); - auto first_agg = function_binder.BindAggregateFunction( - FirstFun::GetFunction(expr.return_type), std::move(first_children), nullptr, AggregateType::NON_DISTINCT); - - expressions.push_back(std::move(first_agg)); - auto aggr_index = binder.GenerateTableIndex(); - auto aggr = make_uniq(binder.GenerateTableIndex(), aggr_index, std::move(expressions)); - aggr->AddChild(std::move(plan)); - plan = std::move(aggr); - - // in the uncorrelated case, we add the value to the main query through a cross product - // FIXME: should use something else besides cross product as we always add only one scalar constant and cross - // product is not optimized for this. - D_ASSERT(root); - root = LogicalCrossProduct::Create(std::move(root), std::move(plan)); - - // we replace the original subquery with a BoundColumnRefExpression referring to the first result of the - // aggregation - return make_uniq(expr.GetName(), expr.return_type, ColumnBinding(aggr_index, 0)); - } - default: { - D_ASSERT(expr.subquery_type == SubqueryType::ANY); - // we generate a MARK join that results in either (TRUE, FALSE or NULL) - // subquery has NULL values -> result is (TRUE or NULL) - // subquery has no NULL values -> result is (TRUE, FALSE or NULL [if input is NULL]) - // fetch the column bindings - auto plan_columns = plan->GetColumnBindings(); - - // then we generate the MARK join with the subquery - idx_t mark_index = binder.GenerateTableIndex(); - auto join = make_uniq(JoinType::MARK); - join->mark_index = mark_index; - join->AddChild(std::move(root)); - join->AddChild(std::move(plan)); - // create the JOIN condition - JoinCondition cond; - cond.left = std::move(expr.child); - cond.right = BoundCastExpression::AddDefaultCastToType( - make_uniq(expr.child_type, plan_columns[0]), expr.child_target); - cond.comparison = expr.comparison_type; - join->conditions.push_back(std::move(cond)); - root = 
std::move(join); - - // we replace the original subquery with a BoundColumnRefExpression referring to the mark column - return make_uniq(expr.GetName(), expr.return_type, ColumnBinding(mark_index, 0)); - } - } -} - -static unique_ptr -CreateDuplicateEliminatedJoin(const vector &correlated_columns, JoinType join_type, - unique_ptr original_plan, bool perform_delim) { - auto delim_join = make_uniq(join_type, LogicalOperatorType::LOGICAL_DELIM_JOIN); - if (!perform_delim) { - // if we are not performing a delim join, we push a row_number() OVER() window operator on the LHS - // and perform all duplicate elimination on that row number instead - D_ASSERT(correlated_columns[0].type.id() == LogicalTypeId::BIGINT); - auto window = make_uniq(correlated_columns[0].binding.table_index); - auto row_number = - make_uniq(ExpressionType::WINDOW_ROW_NUMBER, LogicalType::BIGINT, nullptr, nullptr); - row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; - row_number->end = WindowBoundary::CURRENT_ROW_ROWS; - row_number->alias = "delim_index"; - window->expressions.push_back(std::move(row_number)); - window->AddChild(std::move(original_plan)); - original_plan = std::move(window); - } - delim_join->AddChild(std::move(original_plan)); - for (idx_t i = 0; i < correlated_columns.size(); i++) { - auto &col = correlated_columns[i]; - delim_join->duplicate_eliminated_columns.push_back(make_uniq(col.type, col.binding)); - delim_join->mark_types.push_back(col.type); - } - return delim_join; -} - -static void CreateDelimJoinConditions(LogicalComparisonJoin &delim_join, - const vector &correlated_columns, - vector bindings, idx_t base_offset, bool perform_delim) { - auto col_count = perform_delim ? 
correlated_columns.size() : 1; - for (idx_t i = 0; i < col_count; i++) { - auto &col = correlated_columns[i]; - auto binding_idx = base_offset + i; - if (binding_idx >= bindings.size()) { - throw InternalException("Delim join - binding index out of range"); - } - JoinCondition cond; - cond.left = make_uniq(col.name, col.type, col.binding); - cond.right = make_uniq(col.name, col.type, bindings[binding_idx]); - cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; - delim_join.conditions.push_back(std::move(cond)); - } -} - -static bool PerformDelimOnType(const LogicalType &type) { - if (type.InternalType() == PhysicalType::LIST) { - return false; - } - if (type.InternalType() == PhysicalType::STRUCT) { - for (auto &entry : StructType::GetChildTypes(type)) { - if (!PerformDelimOnType(entry.second)) { - return false; - } - } - } - return true; -} - -static bool PerformDuplicateElimination(Binder &binder, vector &correlated_columns) { - if (!ClientConfig::GetConfig(binder.context).enable_optimizer) { - // if optimizations are disabled we always do a delim join - return true; - } - bool perform_delim = true; - for (auto &col : correlated_columns) { - if (!PerformDelimOnType(col.type)) { - perform_delim = false; - break; - } - } - if (perform_delim) { - return true; - } - auto binding = ColumnBinding(binder.GenerateTableIndex(), 0); - auto type = LogicalType::BIGINT; - auto name = "delim_index"; - CorrelatedColumnInfo info(binding, type, name, 0); - correlated_columns.insert(correlated_columns.begin(), std::move(info)); - return false; -} - -static unique_ptr PlanCorrelatedSubquery(Binder &binder, BoundSubqueryExpression &expr, - unique_ptr &root, - unique_ptr plan) { - auto &correlated_columns = expr.binder->correlated_columns; - // FIXME: there should be a way of disabling decorrelation for ANY queries as well, but not for now... - bool perform_delim = - expr.subquery_type == SubqueryType::ANY ? 
true : PerformDuplicateElimination(binder, correlated_columns); - D_ASSERT(expr.IsCorrelated()); - // correlated subquery - // for a more in-depth explanation of this code, read the paper "Unnesting Arbitrary Subqueries" - // we handle three types of correlated subqueries: Scalar, EXISTS and ANY - // all three cases are very similar with some minor changes (mainly the type of join performed at the end) - switch (expr.subquery_type) { - case SubqueryType::SCALAR: { - // correlated SCALAR query - // first push a DUPLICATE ELIMINATED join - // a duplicate eliminated join creates a duplicate eliminated copy of the LHS - // and pushes it into any DUPLICATE_ELIMINATED SCAN operators on the RHS - - // in the SCALAR case, we create a SINGLE join (because we are only interested in obtaining the value) - // NULL values are equal in this join because we join on the correlated columns ONLY - // and e.g. in the query: SELECT (SELECT 42 FROM integers WHERE i1.i IS NULL LIMIT 1) FROM integers i1; - // the input value NULL will generate the value 42, and we need to join NULL on the LHS with NULL on the RHS - // the left side is the original plan - // this is the side that will be duplicate eliminated and pushed into the RHS - auto delim_join = - CreateDuplicateEliminatedJoin(correlated_columns, JoinType::SINGLE, std::move(root), perform_delim); - - // the right side initially is a DEPENDENT join between the duplicate eliminated scan and the subquery - // HOWEVER: we do not explicitly create the dependent join - // instead, we eliminate the dependent join by pushing it down into the right side of the plan - FlattenDependentJoins flatten(binder, correlated_columns, perform_delim); - - // first we check which logical operators have correlated expressions in the first place - flatten.DetectCorrelatedExpressions(plan.get()); - // now we push the dependent join down - auto dependent_join = flatten.PushDownDependentJoin(std::move(plan)); - - // now the dependent join is fully eliminated 
- // we only need to create the join conditions between the LHS and the RHS - // fetch the set of columns - auto plan_columns = dependent_join->GetColumnBindings(); - - // now create the join conditions - CreateDelimJoinConditions(*delim_join, correlated_columns, plan_columns, flatten.delim_offset, perform_delim); - delim_join->AddChild(std::move(dependent_join)); - root = std::move(delim_join); - // finally push the BoundColumnRefExpression referring to the data element returned by the join - return make_uniq(expr.GetName(), expr.return_type, plan_columns[flatten.data_offset]); - } - case SubqueryType::EXISTS: { - // correlated EXISTS query - // this query is similar to the correlated SCALAR query, except we use a MARK join here - idx_t mark_index = binder.GenerateTableIndex(); - auto delim_join = - CreateDuplicateEliminatedJoin(correlated_columns, JoinType::MARK, std::move(root), perform_delim); - delim_join->mark_index = mark_index; - // RHS - FlattenDependentJoins flatten(binder, correlated_columns, perform_delim, true); - flatten.DetectCorrelatedExpressions(plan.get()); - auto dependent_join = flatten.PushDownDependentJoin(std::move(plan)); - - // fetch the set of columns - auto plan_columns = dependent_join->GetColumnBindings(); - - // now we create the join conditions between the dependent join and the original table - CreateDelimJoinConditions(*delim_join, correlated_columns, plan_columns, flatten.delim_offset, perform_delim); - delim_join->AddChild(std::move(dependent_join)); - root = std::move(delim_join); - // finally push the BoundColumnRefExpression referring to the marker - return make_uniq(expr.GetName(), expr.return_type, ColumnBinding(mark_index, 0)); - } - default: { - D_ASSERT(expr.subquery_type == SubqueryType::ANY); - // correlated ANY query - // this query is similar to the correlated SCALAR query - // however, in this case we push a correlated MARK join - // note that in this join null values are NOT equal for ALL columns, but ONLY for the 
correlated columns - // the correlated mark join handles this case by itself - // as the MARK join has one extra join condition (the original condition, of the ANY expression, e.g. - // [i=ANY(...)]) - idx_t mark_index = binder.GenerateTableIndex(); - auto delim_join = - CreateDuplicateEliminatedJoin(correlated_columns, JoinType::MARK, std::move(root), perform_delim); - delim_join->mark_index = mark_index; - // RHS - FlattenDependentJoins flatten(binder, correlated_columns, true, true); - flatten.DetectCorrelatedExpressions(plan.get()); - auto dependent_join = flatten.PushDownDependentJoin(std::move(plan)); - - // fetch the columns - auto plan_columns = dependent_join->GetColumnBindings(); - - // now we create the join conditions between the dependent join and the original table - CreateDelimJoinConditions(*delim_join, correlated_columns, plan_columns, flatten.delim_offset, perform_delim); - // add the actual condition based on the ANY/ALL predicate - JoinCondition compare_cond; - compare_cond.left = std::move(expr.child); - compare_cond.right = BoundCastExpression::AddDefaultCastToType( - make_uniq(expr.child_type, plan_columns[0]), expr.child_target); - compare_cond.comparison = expr.comparison_type; - delim_join->conditions.push_back(std::move(compare_cond)); - - delim_join->AddChild(std::move(dependent_join)); - root = std::move(delim_join); - // finally push the BoundColumnRefExpression referring to the marker - return make_uniq(expr.GetName(), expr.return_type, ColumnBinding(mark_index, 0)); - } - } -} - -void RecursiveDependentJoinPlanner::VisitOperator(LogicalOperator &op) { - if (!op.children.empty()) { - root = std::move(op.children[0]); - D_ASSERT(root); - if (root->type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { - // Found a dependent join, flatten it - auto &new_root = root->Cast(); - root = binder.PlanLateralJoin(std::move(new_root.children[0]), std::move(new_root.children[1]), - new_root.correlated_columns, new_root.join_type, - 
std::move(new_root.join_condition)); - } - VisitOperatorExpressions(op); - op.children[0] = std::move(root); - for (idx_t i = 0; i < op.children.size(); i++) { - D_ASSERT(op.children[i]); - VisitOperator(*op.children[i]); - } - } -} - -unique_ptr RecursiveDependentJoinPlanner::VisitReplace(BoundSubqueryExpression &expr, - unique_ptr *expr_ptr) { - return binder.PlanSubquery(expr, root); -} - -unique_ptr Binder::PlanSubquery(BoundSubqueryExpression &expr, unique_ptr &root) { - D_ASSERT(root); - // first we translate the QueryNode of the subquery into a logical plan - // note that we do not plan nested subqueries yet - auto sub_binder = Binder::CreateBinder(context, this); - sub_binder->is_outside_flattened = false; - auto subquery_root = sub_binder->CreatePlan(*expr.subquery); - D_ASSERT(subquery_root); - - // now we actually flatten the subquery - auto plan = std::move(subquery_root); - - unique_ptr result_expression; - if (!expr.IsCorrelated()) { - result_expression = PlanUncorrelatedSubquery(*this, expr, root, std::move(plan)); - } else { - result_expression = PlanCorrelatedSubquery(*this, expr, root, std::move(plan)); - } - // finally, we recursively plan the nested subqueries (if there are any) - if (sub_binder->has_unplanned_dependent_joins) { - RecursiveDependentJoinPlanner plan(*this); - plan.VisitOperator(*root); - } - return result_expression; -} - -void Binder::PlanSubqueries(unique_ptr &expr_ptr, unique_ptr &root) { - if (!expr_ptr) { - return; - } - auto &expr = *expr_ptr; - // first visit the children of the node, if any - ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &expr) { PlanSubqueries(expr, root); }); - - // check if this is a subquery node - if (expr.expression_class == ExpressionClass::BOUND_SUBQUERY) { - auto &subquery = expr.Cast(); - // subquery node! 
plan it - if (subquery.IsCorrelated() && !is_outside_flattened) { - // detected a nested correlated subquery - // we don't plan it yet here, we are currently planning a subquery - // nested subqueries will only be planned AFTER the current subquery has been flattened entirely - has_unplanned_dependent_joins = true; - return; - } - expr_ptr = PlanSubquery(subquery, root); - } -} - -unique_ptr Binder::PlanLateralJoin(unique_ptr left, unique_ptr right, - vector &correlated_columns, - JoinType join_type, unique_ptr condition) { - // scan the right operator for correlated columns - // correlated LATERAL JOIN - vector conditions; - vector> arbitrary_expressions; - if (condition) { - // extract join conditions, if there are any - LogicalComparisonJoin::ExtractJoinConditions(context, join_type, left, right, std::move(condition), conditions, - arbitrary_expressions); - } - - auto perform_delim = PerformDuplicateElimination(*this, correlated_columns); - auto delim_join = CreateDuplicateEliminatedJoin(correlated_columns, join_type, std::move(left), perform_delim); - - FlattenDependentJoins flatten(*this, correlated_columns, perform_delim); - - // first we check which logical operators have correlated expressions in the first place - flatten.DetectCorrelatedExpressions(right.get(), true); - // now we push the dependent join down - auto dependent_join = flatten.PushDownDependentJoin(std::move(right)); - - // now the dependent join is fully eliminated - // we only need to create the join conditions between the LHS and the RHS - // fetch the set of columns - auto plan_columns = dependent_join->GetColumnBindings(); - - // now create the join conditions - // start off with the conditions that were passed in (if any) - D_ASSERT(delim_join->conditions.empty()); - delim_join->conditions = std::move(conditions); - // then add the delim join conditions - CreateDelimJoinConditions(*delim_join, correlated_columns, plan_columns, flatten.delim_offset, perform_delim); - 
delim_join->AddChild(std::move(dependent_join)); - - // check if there are any arbitrary expressions left - if (!arbitrary_expressions.empty()) { - // we can only evaluate scalar arbitrary expressions for inner joins - if (join_type != JoinType::INNER) { - throw BinderException( - "Join condition for non-inner LATERAL JOIN must be a comparison between the left and right side"); - } - auto filter = make_uniq(); - filter->expressions = std::move(arbitrary_expressions); - filter->AddChild(std::move(delim_join)); - return std::move(filter); - } - return std::move(delim_join); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -BoundStatement Binder::Bind(AttachStatement &stmt) { - BoundStatement result; - result.types = {LogicalType::BOOLEAN}; - result.names = {"Success"}; - - result.plan = make_uniq(LogicalOperatorType::LOGICAL_ATTACH, std::move(stmt.info)); - properties.allow_stream_result = false; - properties.return_type = StatementReturnType::NOTHING; - return result; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -BoundStatement Binder::Bind(CallStatement &stmt) { - BoundStatement result; - - TableFunctionRef ref; - ref.function = std::move(stmt.function); - - auto bound_func = Bind(ref); - auto &bound_table_func = bound_func->Cast(); - ; - auto &get = bound_table_func.get->Cast(); - D_ASSERT(get.returned_types.size() > 0); - for (idx_t i = 0; i < get.returned_types.size(); i++) { - get.column_ids.push_back(i); - } - - result.types = get.returned_types; - result.names = get.names; - result.plan = CreatePlan(*bound_func); - properties.return_type = StatementReturnType::QUERY_RESULT; - return result; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -vector GetUniqueNames(const vector &original_names) { - unordered_set name_set; - vector unique_names; - unique_names.reserve(original_names.size()); - - for (auto &name : original_names) { - auto insert_result = name_set.insert(name); - 
if (insert_result.second == false) { - // Could not be inserted, name already exists - idx_t index = 1; - string postfixed_name; - while (true) { - postfixed_name = StringUtil::Format("%s:%d", name, index); - auto res = name_set.insert(postfixed_name); - if (!res.second) { - index++; - continue; - } - break; - } - unique_names.push_back(postfixed_name); - } else { - unique_names.push_back(name); - } - } - return unique_names; -} - -BoundStatement Binder::BindCopyTo(CopyStatement &stmt) { - // COPY TO a file - auto &config = DBConfig::GetConfig(context); - if (!config.options.enable_external_access) { - throw PermissionException("COPY TO is disabled by configuration"); - } - BoundStatement result; - result.types = {LogicalType::BIGINT}; - result.names = {"Count"}; - - // lookup the format in the catalog - auto ©_function = - Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, stmt.info->format); - if (copy_function.function.plan) { - // plan rewrite COPY TO - return copy_function.function.plan(*this, stmt); - } - - // bind the select statement - auto select_node = Bind(*stmt.select_statement); - - if (!copy_function.function.copy_to_bind) { - throw NotImplementedException("COPY TO is not supported for FORMAT \"%s\"", stmt.info->format); - } - bool use_tmp_file = true; - bool overwrite_or_ignore = false; - FilenamePattern filename_pattern; - bool user_set_use_tmp_file = false; - bool per_thread_output = false; - vector partition_cols; - - auto original_options = stmt.info->options; - stmt.info->options.clear(); - - for (auto &option : original_options) { - auto loption = StringUtil::Lower(option.first); - if (loption == "use_tmp_file") { - use_tmp_file = - option.second.empty() || option.second[0].CastAs(context, LogicalType::BOOLEAN).GetValue(); - user_set_use_tmp_file = true; - continue; - } - if (loption == "overwrite_or_ignore") { - overwrite_or_ignore = - option.second.empty() || option.second[0].CastAs(context, LogicalType::BOOLEAN).GetValue(); - 
continue; - } - if (loption == "filename_pattern") { - if (option.second.empty()) { - throw IOException("FILENAME_PATTERN cannot be empty"); - } - filename_pattern.SetFilenamePattern( - option.second[0].CastAs(context, LogicalType::VARCHAR).GetValue()); - continue; - } - - if (loption == "per_thread_output") { - per_thread_output = - option.second.empty() || option.second[0].CastAs(context, LogicalType::BOOLEAN).GetValue(); - continue; - } - if (loption == "partition_by") { - auto converted = ConvertVectorToValue(std::move(option.second)); - partition_cols = ParseColumnsOrdered(converted, select_node.names, loption); - continue; - } - stmt.info->options[option.first] = option.second; - } - if (user_set_use_tmp_file && per_thread_output) { - throw NotImplementedException("Can't combine USE_TMP_FILE and PER_THREAD_OUTPUT for COPY"); - } - if (user_set_use_tmp_file && !partition_cols.empty()) { - throw NotImplementedException("Can't combine USE_TMP_FILE and PARTITION_BY for COPY"); - } - if (per_thread_output && !partition_cols.empty()) { - throw NotImplementedException("Can't combine PER_THREAD_OUTPUT and PARTITION_BY for COPY"); - } - bool is_remote_file = config.file_system->IsRemoteFile(stmt.info->file_path); - if (is_remote_file) { - use_tmp_file = false; - } else { - bool is_file_and_exists = config.file_system->FileExists(stmt.info->file_path); - bool is_stdout = stmt.info->file_path == "/dev/stdout"; - if (!user_set_use_tmp_file) { - use_tmp_file = is_file_and_exists && !per_thread_output && partition_cols.empty() && !is_stdout; - } - } - - auto unique_column_names = GetUniqueNames(select_node.names); - - auto function_data = - copy_function.function.copy_to_bind(context, *stmt.info, unique_column_names, select_node.types); - // now create the copy information - auto copy = make_uniq(copy_function.function, std::move(function_data)); - copy->file_path = stmt.info->file_path; - copy->use_tmp_file = use_tmp_file; - copy->overwrite_or_ignore = 
overwrite_or_ignore; - copy->filename_pattern = filename_pattern; - copy->per_thread_output = per_thread_output; - copy->partition_output = !partition_cols.empty(); - copy->partition_columns = std::move(partition_cols); - - copy->names = unique_column_names; - copy->expected_types = select_node.types; - - copy->AddChild(std::move(select_node.plan)); - - result.plan = std::move(copy); - - return result; -} - -BoundStatement Binder::BindCopyFrom(CopyStatement &stmt) { - auto &config = DBConfig::GetConfig(context); - if (!config.options.enable_external_access) { - throw PermissionException("COPY FROM is disabled by configuration"); - } - BoundStatement result; - result.types = {LogicalType::BIGINT}; - result.names = {"Count"}; - - if (stmt.info->table.empty()) { - throw ParserException("COPY FROM requires a table name to be specified"); - } - // COPY FROM a file - // generate an insert statement for the the to-be-inserted table - InsertStatement insert; - insert.table = stmt.info->table; - insert.schema = stmt.info->schema; - insert.catalog = stmt.info->catalog; - insert.columns = stmt.info->select_list; - - // bind the insert statement to the base table - auto insert_statement = Bind(insert); - D_ASSERT(insert_statement.plan->type == LogicalOperatorType::LOGICAL_INSERT); - - auto &bound_insert = insert_statement.plan->Cast(); - - // lookup the format in the catalog - auto &catalog = Catalog::GetSystemCatalog(context); - auto ©_function = catalog.GetEntry(context, DEFAULT_SCHEMA, stmt.info->format); - if (!copy_function.function.copy_from_bind) { - throw NotImplementedException("COPY FROM is not supported for FORMAT \"%s\"", stmt.info->format); - } - // lookup the table to copy into - BindSchemaOrCatalog(stmt.info->catalog, stmt.info->schema); - auto &table = - Catalog::GetEntry(context, stmt.info->catalog, stmt.info->schema, stmt.info->table); - vector expected_names; - if (!bound_insert.column_index_map.empty()) { - 
expected_names.resize(bound_insert.expected_types.size()); - for (auto &col : table.GetColumns().Physical()) { - auto i = col.Physical(); - if (bound_insert.column_index_map[i] != DConstants::INVALID_INDEX) { - expected_names[bound_insert.column_index_map[i]] = col.Name(); - } - } - } else { - expected_names.reserve(bound_insert.expected_types.size()); - for (auto &col : table.GetColumns().Physical()) { - expected_names.push_back(col.Name()); - } - } - - auto function_data = - copy_function.function.copy_from_bind(context, *stmt.info, expected_names, bound_insert.expected_types); - auto get = make_uniq(GenerateTableIndex(), copy_function.function.copy_from_function, - std::move(function_data), bound_insert.expected_types, expected_names); - for (idx_t i = 0; i < bound_insert.expected_types.size(); i++) { - get->column_ids.push_back(i); - } - insert_statement.plan->children.push_back(std::move(get)); - result.plan = std::move(insert_statement.plan); - return result; -} - -BoundStatement Binder::Bind(CopyStatement &stmt) { - if (!stmt.info->is_from && !stmt.select_statement) { - // copy table into file without a query - // generate SELECT * FROM table; - auto ref = make_uniq(); - ref->catalog_name = stmt.info->catalog; - ref->schema_name = stmt.info->schema; - ref->table_name = stmt.info->table; - - auto statement = make_uniq(); - statement->from_table = std::move(ref); - if (!stmt.info->select_list.empty()) { - for (auto &name : stmt.info->select_list) { - statement->select_list.push_back(make_uniq(name)); - } - } else { - statement->select_list.push_back(make_uniq()); - } - stmt.select_statement = std::move(statement); - } - properties.allow_stream_result = false; - properties.return_type = StatementReturnType::CHANGED_ROWS; - if (stmt.info->is_from) { - return BindCopyFrom(stmt); - } else { - return BindCopyTo(stmt); - } -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -void 
Binder::BindSchemaOrCatalog(ClientContext &context, string &catalog, string &schema) { - if (catalog.empty() && !schema.empty()) { - // schema is specified - but catalog is not - // try searching for the catalog instead - auto &db_manager = DatabaseManager::Get(context); - auto database = db_manager.GetDatabase(context, schema); - if (database) { - // we have a database with this name - // check if there is a schema - auto schema_obj = Catalog::GetSchema(context, INVALID_CATALOG, schema, OnEntryNotFound::RETURN_NULL); - if (schema_obj) { - auto &attached = schema_obj->catalog.GetAttached(); - throw BinderException( - "Ambiguous reference to catalog or schema \"%s\" - use a fully qualified path like \"%s.%s\"", - schema, attached.GetName(), schema); - } - catalog = schema; - schema = string(); - } - } -} - -void Binder::BindSchemaOrCatalog(string &catalog, string &schema) { - BindSchemaOrCatalog(context, catalog, schema); -} - -SchemaCatalogEntry &Binder::BindSchema(CreateInfo &info) { - BindSchemaOrCatalog(info.catalog, info.schema); - if (IsInvalidCatalog(info.catalog) && info.temporary) { - info.catalog = TEMP_CATALOG; - } - auto &search_path = ClientData::Get(context).catalog_search_path; - if (IsInvalidCatalog(info.catalog) && IsInvalidSchema(info.schema)) { - auto &default_entry = search_path->GetDefault(); - info.catalog = default_entry.catalog; - info.schema = default_entry.schema; - } else if (IsInvalidSchema(info.schema)) { - info.schema = search_path->GetDefaultSchema(info.catalog); - } else if (IsInvalidCatalog(info.catalog)) { - info.catalog = search_path->GetDefaultCatalog(info.schema); - } - if (IsInvalidCatalog(info.catalog)) { - info.catalog = DatabaseManager::GetDefaultDatabase(context); - } - if (!info.temporary) { - // non-temporary create: not read only - if (info.catalog == TEMP_CATALOG) { - throw ParserException("Only TEMPORARY table names can use the \"%s\" catalog", TEMP_CATALOG); - } - } else { - if (info.catalog != TEMP_CATALOG) { - throw 
ParserException("TEMPORARY table names can *only* use the \"%s\" catalog", TEMP_CATALOG); - } - } - // fetch the schema in which we want to create the object - auto &schema_obj = Catalog::GetSchema(context, info.catalog, info.schema); - D_ASSERT(schema_obj.type == CatalogType::SCHEMA_ENTRY); - info.schema = schema_obj.name; - if (!info.temporary) { - properties.modified_databases.insert(schema_obj.catalog.GetName()); - } - return schema_obj; -} - -SchemaCatalogEntry &Binder::BindCreateSchema(CreateInfo &info) { - auto &schema = BindSchema(info); - if (schema.catalog.IsSystemCatalog()) { - throw BinderException("Cannot create entry in system catalog"); - } - return schema; -} - -void Binder::BindCreateViewInfo(CreateViewInfo &base) { - // bind the view as if it were a query so we can catch errors - // note that we bind the original, and replace the original with a copy - auto view_binder = Binder::CreateBinder(context); - view_binder->can_contain_nulls = true; - - auto copy = base.query->Copy(); - auto query_node = view_binder->Bind(*base.query); - base.query = unique_ptr_cast(std::move(copy)); - if (base.aliases.size() > query_node.names.size()) { - throw BinderException("More VIEW aliases than columns in query result"); - } - // fill up the aliases with the remaining names of the bound query - base.aliases.reserve(query_node.names.size()); - for (idx_t i = base.aliases.size(); i < query_node.names.size(); i++) { - base.aliases.push_back(query_node.names[i]); - } - base.types = query_node.types; -} - -SchemaCatalogEntry &Binder::BindCreateFunctionInfo(CreateInfo &info) { - auto &base = info.Cast(); - auto &scalar_function = base.function->Cast(); - - if (scalar_function.expression->HasParameter()) { - throw BinderException("Parameter expressions within macro's are not supported!"); - } - - // create macro binding in order to bind the function - vector dummy_types; - vector dummy_names; - // positional parameters - for (idx_t i = 0; i < 
base.function->parameters.size(); i++) { - auto param = base.function->parameters[i]->Cast(); - if (param.IsQualified()) { - throw BinderException("Invalid parameter name '%s': must be unqualified", param.ToString()); - } - dummy_types.emplace_back(LogicalType::SQLNULL); - dummy_names.push_back(param.GetColumnName()); - } - // default parameters - for (auto it = base.function->default_parameters.begin(); it != base.function->default_parameters.end(); it++) { - auto &val = it->second->Cast(); - dummy_types.push_back(val.value.type()); - dummy_names.push_back(it->first); - } - auto this_macro_binding = make_uniq(dummy_types, dummy_names, base.name); - macro_binding = this_macro_binding.get(); - ExpressionBinder::QualifyColumnNames(*this, scalar_function.expression); - - // create a copy of the expression because we do not want to alter the original - auto expression = scalar_function.expression->Copy(); - - // bind it to verify the function was defined correctly - string error; - auto sel_node = make_uniq(); - auto group_info = make_uniq(); - SelectBinder binder(*this, context, *sel_node, *group_info); - error = binder.Bind(expression, 0, false); - - if (!error.empty()) { - throw BinderException(error); - } - - return BindCreateSchema(info); -} - -void Binder::BindLogicalType(ClientContext &context, LogicalType &type, optional_ptr catalog, - const string &schema) { - if (type.id() == LogicalTypeId::LIST || type.id() == LogicalTypeId::MAP) { - auto child_type = ListType::GetChildType(type); - BindLogicalType(context, child_type, catalog, schema); - auto alias = type.GetAlias(); - if (type.id() == LogicalTypeId::LIST) { - type = LogicalType::LIST(child_type); - } else { - D_ASSERT(child_type.id() == LogicalTypeId::STRUCT); // map must be list of structs - type = LogicalType::MAP(child_type); - } - - type.SetAlias(alias); - } else if (type.id() == LogicalTypeId::STRUCT) { - auto child_types = StructType::GetChildTypes(type); - for (auto &child_type : child_types) { - 
BindLogicalType(context, child_type.second, catalog, schema); - } - // Generate new Struct Type - auto alias = type.GetAlias(); - type = LogicalType::STRUCT(child_types); - type.SetAlias(alias); - } else if (type.id() == LogicalTypeId::UNION) { - auto member_types = UnionType::CopyMemberTypes(type); - for (auto &member_type : member_types) { - BindLogicalType(context, member_type.second, catalog, schema); - } - // Generate new Union Type - auto alias = type.GetAlias(); - type = LogicalType::UNION(member_types); - type.SetAlias(alias); - } else if (type.id() == LogicalTypeId::USER) { - auto user_type_name = UserType::GetTypeName(type); - if (catalog) { - // The search order is: - // 1) In the same schema as the table - // 2) In the same catalog - // 3) System catalog - type = catalog->GetType(context, schema, user_type_name, OnEntryNotFound::RETURN_NULL); - - if (type.id() == LogicalTypeId::INVALID) { - type = catalog->GetType(context, INVALID_SCHEMA, user_type_name, OnEntryNotFound::RETURN_NULL); - } - - if (type.id() == LogicalTypeId::INVALID) { - type = Catalog::GetType(context, INVALID_CATALOG, schema, user_type_name); - } - } else { - type = Catalog::GetType(context, INVALID_CATALOG, schema, user_type_name); - } - BindLogicalType(context, type, catalog, schema); - } -} - -static void FindMatchingPrimaryKeyColumns(const ColumnList &columns, const vector> &constraints, - ForeignKeyConstraint &fk) { - // find the matching primary key constraint - bool found_constraint = false; - // if no columns are defined, we will automatically try to bind to the primary key - bool find_primary_key = fk.pk_columns.empty(); - for (auto &constr : constraints) { - if (constr->type != ConstraintType::UNIQUE) { - continue; - } - auto &unique = constr->Cast(); - if (find_primary_key && !unique.is_primary_key) { - continue; - } - found_constraint = true; - - vector pk_names; - if (unique.index.index != DConstants::INVALID_INDEX) { - 
pk_names.push_back(columns.GetColumn(LogicalIndex(unique.index)).Name()); - } else { - pk_names = unique.columns; - } - if (find_primary_key) { - // found matching primary key - if (pk_names.size() != fk.fk_columns.size()) { - auto pk_name_str = StringUtil::Join(pk_names, ","); - auto fk_name_str = StringUtil::Join(fk.fk_columns, ","); - throw BinderException( - "Failed to create foreign key: number of referencing (%s) and referenced columns (%s) differ", - fk_name_str, pk_name_str); - } - fk.pk_columns = pk_names; - return; - } - if (pk_names.size() != fk.fk_columns.size()) { - // the number of referencing and referenced columns for foreign keys must be the same - continue; - } - bool equals = true; - for (idx_t i = 0; i < fk.pk_columns.size(); i++) { - if (!StringUtil::CIEquals(fk.pk_columns[i], pk_names[i])) { - equals = false; - break; - } - } - if (!equals) { - continue; - } - // found match - return; - } - // no match found! examine why - if (!found_constraint) { - // no unique constraint or primary key - string search_term = find_primary_key ? 
"primary key" : "primary key or unique constraint"; - throw BinderException("Failed to create foreign key: there is no %s for referenced table \"%s\"", search_term, - fk.info.table); - } - // check if all the columns exist - for (auto &name : fk.pk_columns) { - bool found = columns.ColumnExists(name); - if (!found) { - throw BinderException( - "Failed to create foreign key: referenced table \"%s\" does not have a column named \"%s\"", - fk.info.table, name); - } - } - auto fk_names = StringUtil::Join(fk.pk_columns, ","); - throw BinderException("Failed to create foreign key: referenced table \"%s\" does not have a primary key or unique " - "constraint on the columns %s", - fk.info.table, fk_names); -} - -static void FindForeignKeyIndexes(const ColumnList &columns, const vector &names, - vector &indexes) { - D_ASSERT(indexes.empty()); - D_ASSERT(!names.empty()); - for (auto &name : names) { - if (!columns.ColumnExists(name)) { - throw BinderException("column \"%s\" named in key does not exist", name); - } - auto &column = columns.GetColumn(name); - if (column.Generated()) { - throw BinderException("Failed to create foreign key: referenced column \"%s\" is a generated column", - column.Name()); - } - indexes.push_back(column.Physical()); - } -} - -static void CheckForeignKeyTypes(const ColumnList &pk_columns, const ColumnList &fk_columns, ForeignKeyConstraint &fk) { - D_ASSERT(fk.info.pk_keys.size() == fk.info.fk_keys.size()); - for (idx_t c_idx = 0; c_idx < fk.info.pk_keys.size(); c_idx++) { - auto &pk_col = pk_columns.GetColumn(fk.info.pk_keys[c_idx]); - auto &fk_col = fk_columns.GetColumn(fk.info.fk_keys[c_idx]); - if (pk_col.Type() != fk_col.Type()) { - throw BinderException("Failed to create foreign key: incompatible types between column \"%s\" (\"%s\") and " - "column \"%s\" (\"%s\")", - pk_col.Name(), pk_col.Type().ToString(), fk_col.Name(), fk_col.Type().ToString()); - } - } -} - -void ExpressionContainsGeneratedColumn(const ParsedExpression &expr, const 
unordered_set &gcols, - bool &contains_gcol) { - if (contains_gcol) { - return; - } - if (expr.type == ExpressionType::COLUMN_REF) { - auto &column_ref = expr.Cast(); - auto &name = column_ref.GetColumnName(); - if (gcols.count(name)) { - contains_gcol = true; - return; - } - } - ParsedExpressionIterator::EnumerateChildren( - expr, [&](const ParsedExpression &child) { ExpressionContainsGeneratedColumn(child, gcols, contains_gcol); }); -} - -static bool AnyConstraintReferencesGeneratedColumn(CreateTableInfo &table_info) { - unordered_set generated_columns; - for (auto &col : table_info.columns.Logical()) { - if (!col.Generated()) { - continue; - } - generated_columns.insert(col.Name()); - } - if (generated_columns.empty()) { - return false; - } - - for (auto &constr : table_info.constraints) { - switch (constr->type) { - case ConstraintType::CHECK: { - auto &constraint = constr->Cast(); - auto &expr = constraint.expression; - bool contains_generated_column = false; - ExpressionContainsGeneratedColumn(*expr, generated_columns, contains_generated_column); - if (contains_generated_column) { - return true; - } - break; - } - case ConstraintType::NOT_NULL: { - auto &constraint = constr->Cast(); - if (table_info.columns.GetColumn(constraint.index).Generated()) { - return true; - } - break; - } - case ConstraintType::UNIQUE: { - auto &constraint = constr->Cast(); - auto index = constraint.index; - if (index.index == DConstants::INVALID_INDEX) { - for (auto &col : constraint.columns) { - if (generated_columns.count(col)) { - return true; - } - } - } else { - if (table_info.columns.GetColumn(index).Generated()) { - return true; - } - } - break; - } - case ConstraintType::FOREIGN_KEY: { - // If it contained a generated column, an exception would have been thrown inside AddDataTableIndex earlier - break; - } - default: { - throw NotImplementedException("ConstraintType not implemented"); - } - } - } - return false; -} - -unique_ptr DuckCatalog::BindCreateIndex(Binder &binder, 
CreateStatement &stmt, - TableCatalogEntry &table, unique_ptr plan) { - D_ASSERT(plan->type == LogicalOperatorType::LOGICAL_GET); - auto &base = stmt.info->Cast(); - - auto &get = plan->Cast(); - // bind the index expressions - IndexBinder index_binder(binder, binder.context); - vector> expressions; - expressions.reserve(base.expressions.size()); - for (auto &expr : base.expressions) { - expressions.push_back(index_binder.Bind(expr)); - } - - auto create_index_info = unique_ptr_cast(std::move(stmt.info)); - for (auto &column_id : get.column_ids) { - if (column_id == COLUMN_IDENTIFIER_ROW_ID) { - throw BinderException("Cannot create an index on the rowid!"); - } - create_index_info->scan_types.push_back(get.returned_types[column_id]); - } - create_index_info->scan_types.emplace_back(LogicalType::ROW_TYPE); - create_index_info->names = get.names; - create_index_info->column_ids = get.column_ids; - auto &bind_data = get.bind_data->Cast(); - bind_data.is_create_index = true; - get.column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); - - // the logical CREATE INDEX also needs all fields to scan the referenced table - auto result = make_uniq(std::move(create_index_info), std::move(expressions), table); - result->children.push_back(std::move(plan)); - return std::move(result); -} - -BoundStatement Binder::Bind(CreateStatement &stmt) { - BoundStatement result; - result.names = {"Count"}; - result.types = {LogicalType::BIGINT}; - - auto catalog_type = stmt.info->type; - switch (catalog_type) { - case CatalogType::SCHEMA_ENTRY: - result.plan = make_uniq(LogicalOperatorType::LOGICAL_CREATE_SCHEMA, std::move(stmt.info)); - break; - case CatalogType::VIEW_ENTRY: { - auto &base = stmt.info->Cast(); - // bind the schema - auto &schema = BindCreateSchema(*stmt.info); - BindCreateViewInfo(base); - result.plan = make_uniq(LogicalOperatorType::LOGICAL_CREATE_VIEW, std::move(stmt.info), &schema); - break; - } - case CatalogType::SEQUENCE_ENTRY: { - auto &schema = 
BindCreateSchema(*stmt.info); - result.plan = - make_uniq(LogicalOperatorType::LOGICAL_CREATE_SEQUENCE, std::move(stmt.info), &schema); - break; - } - case CatalogType::TABLE_MACRO_ENTRY: { - auto &schema = BindCreateSchema(*stmt.info); - result.plan = - make_uniq(LogicalOperatorType::LOGICAL_CREATE_MACRO, std::move(stmt.info), &schema); - break; - } - case CatalogType::MACRO_ENTRY: { - auto &schema = BindCreateFunctionInfo(*stmt.info); - result.plan = - make_uniq(LogicalOperatorType::LOGICAL_CREATE_MACRO, std::move(stmt.info), &schema); - break; - } - case CatalogType::INDEX_ENTRY: { - auto &base = stmt.info->Cast(); - - // visit the table reference - auto table_ref = make_uniq(); - table_ref->catalog_name = base.catalog; - table_ref->schema_name = base.schema; - table_ref->table_name = base.table; - - auto bound_table = Bind(*table_ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw BinderException("Can only create an index over a base table!"); - } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; - if (table.temporary) { - stmt.info->temporary = true; - } - // create a plan over the bound table - auto plan = CreatePlan(*bound_table); - if (plan->type != LogicalOperatorType::LOGICAL_GET) { - throw BinderException("Cannot create index on a view!"); - } - result.plan = table.catalog.BindCreateIndex(*this, stmt, table, std::move(plan)); - break; - } - case CatalogType::TABLE_ENTRY: { - auto &create_info = stmt.info->Cast(); - // If there is a foreign key constraint, resolve primary key column's index from primary key column's name - reference_set_t fk_schemas; - for (idx_t i = 0; i < create_info.constraints.size(); i++) { - auto &cond = create_info.constraints[i]; - if (cond->type != ConstraintType::FOREIGN_KEY) { - continue; - } - auto &fk = cond->Cast(); - if (fk.info.type != ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { - continue; - } - D_ASSERT(fk.info.pk_keys.empty()); - D_ASSERT(fk.info.fk_keys.empty()); 
- FindForeignKeyIndexes(create_info.columns, fk.fk_columns, fk.info.fk_keys); - if (StringUtil::CIEquals(create_info.table, fk.info.table)) { - // self-referential foreign key constraint - fk.info.type = ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE; - FindMatchingPrimaryKeyColumns(create_info.columns, create_info.constraints, fk); - FindForeignKeyIndexes(create_info.columns, fk.pk_columns, fk.info.pk_keys); - CheckForeignKeyTypes(create_info.columns, create_info.columns, fk); - } else { - // have to resolve referenced table - auto &pk_table_entry_ptr = - Catalog::GetEntry(context, INVALID_CATALOG, fk.info.schema, fk.info.table); - fk_schemas.insert(pk_table_entry_ptr.schema); - FindMatchingPrimaryKeyColumns(pk_table_entry_ptr.GetColumns(), pk_table_entry_ptr.GetConstraints(), fk); - FindForeignKeyIndexes(pk_table_entry_ptr.GetColumns(), fk.pk_columns, fk.info.pk_keys); - CheckForeignKeyTypes(pk_table_entry_ptr.GetColumns(), create_info.columns, fk); - auto &storage = pk_table_entry_ptr.GetStorage(); - auto index = storage.info->indexes.FindForeignKeyIndex(fk.info.pk_keys, - ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE); - if (!index) { - auto fk_column_names = StringUtil::Join(fk.pk_columns, ","); - throw BinderException("Failed to create foreign key on %s(%s): no UNIQUE or PRIMARY KEY constraint " - "present on these columns", - pk_table_entry_ptr.name, fk_column_names); - } - } - D_ASSERT(fk.info.pk_keys.size() == fk.info.fk_keys.size()); - D_ASSERT(fk.info.pk_keys.size() == fk.pk_columns.size()); - D_ASSERT(fk.info.fk_keys.size() == fk.fk_columns.size()); - } - if (AnyConstraintReferencesGeneratedColumn(create_info)) { - throw BinderException("Constraints on generated columns are not supported yet"); - } - auto bound_info = BindCreateTableInfo(std::move(stmt.info)); - auto root = std::move(bound_info->query); - for (auto &fk_schema : fk_schemas) { - if (&fk_schema.get() != &bound_info->schema) { - throw BinderException("Creating foreign keys across different 
schemas or catalogs is not supported"); - } - } - - // create the logical operator - auto &schema = bound_info->schema; - auto create_table = make_uniq(schema, std::move(bound_info)); - if (root) { - // CREATE TABLE AS - properties.return_type = StatementReturnType::CHANGED_ROWS; - create_table->children.push_back(std::move(root)); - } - result.plan = std::move(create_table); - break; - } - case CatalogType::TYPE_ENTRY: { - auto &schema = BindCreateSchema(*stmt.info); - auto &create_type_info = stmt.info->Cast(); - result.plan = make_uniq(LogicalOperatorType::LOGICAL_CREATE_TYPE, std::move(stmt.info), &schema); - if (create_type_info.query) { - // CREATE TYPE mood AS ENUM (SELECT 'happy') - auto query_obj = Bind(*create_type_info.query); - auto query = std::move(query_obj.plan); - create_type_info.query.reset(); - - auto &sql_types = query_obj.types; - if (sql_types.size() != 1) { - // add cast expression? - throw BinderException("The query must return a single column"); - } - if (sql_types[0].id() != LogicalType::VARCHAR) { - // push a projection casting to varchar - vector> select_list; - auto ref = make_uniq(sql_types[0], query->GetColumnBindings()[0]); - auto cast_expr = BoundCastExpression::AddCastToType(context, std::move(ref), LogicalType::VARCHAR); - select_list.push_back(std::move(cast_expr)); - auto proj = make_uniq(GenerateTableIndex(), std::move(select_list)); - proj->AddChild(std::move(query)); - query = std::move(proj); - } - - result.plan->AddChild(std::move(query)); - } else if (create_type_info.type.id() == LogicalTypeId::USER) { - // two cases: - // 1: create a type with a non-existant type as source, catalog.GetType(...) will throw exception. - // 2: create a type alias with a custom type. - // eg. 
CREATE TYPE a AS INT; CREATE TYPE b AS a; - // We set b to be an alias for the underlying type of a - auto inner_type = Catalog::GetType(context, schema.catalog.GetName(), schema.name, - UserType::GetTypeName(create_type_info.type)); - inner_type.SetAlias(create_type_info.name); - create_type_info.type = inner_type; - } - break; - } - default: - throw Exception("Unrecognized type!"); - } - properties.return_type = StatementReturnType::NOTHING; - properties.allow_stream_result = false; - return result; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -static void CreateColumnDependencyManager(BoundCreateTableInfo &info) { - auto &base = info.base->Cast(); - for (auto &col : base.columns.Logical()) { - if (!col.Generated()) { - continue; - } - info.column_dependency_manager.AddGeneratedColumn(col, base.columns); - } -} - -static void BindCheckConstraint(Binder &binder, BoundCreateTableInfo &info, const unique_ptr &cond) { - auto &base = info.base->Cast(); - - auto bound_constraint = make_uniq(); - // check constraint: bind the expression - CheckBinder check_binder(binder, binder.context, base.table, base.columns, bound_constraint->bound_columns); - auto &check = cond->Cast(); - // create a copy of the unbound expression because the binding destroys the constraint - auto unbound_expression = check.expression->Copy(); - // now bind the constraint and create a new BoundCheckConstraint - bound_constraint->expression = check_binder.Bind(check.expression); - info.bound_constraints.push_back(std::move(bound_constraint)); - // move the unbound constraint back into the original check expression - check.expression = std::move(unbound_expression); -} - -static void BindConstraints(Binder &binder, BoundCreateTableInfo &info) { - auto &base = info.base->Cast(); - - bool has_primary_key = false; - logical_index_set_t not_null_columns; - vector primary_keys; - for (idx_t i = 0; i < base.constraints.size(); i++) { - auto &cond 
= base.constraints[i]; - switch (cond->type) { - case ConstraintType::CHECK: { - BindCheckConstraint(binder, info, cond); - break; - } - case ConstraintType::NOT_NULL: { - auto ¬_null = cond->Cast(); - auto &col = base.columns.GetColumn(LogicalIndex(not_null.index)); - info.bound_constraints.push_back(make_uniq(PhysicalIndex(col.StorageOid()))); - not_null_columns.insert(not_null.index); - break; - } - case ConstraintType::UNIQUE: { - auto &unique = cond->Cast(); - // have to resolve columns of the unique constraint - vector keys; - logical_index_set_t key_set; - if (unique.index.index != DConstants::INVALID_INDEX) { - D_ASSERT(unique.index.index < base.columns.LogicalColumnCount()); - // unique constraint is given by single index - unique.columns.push_back(base.columns.GetColumn(unique.index).Name()); - keys.push_back(unique.index); - key_set.insert(unique.index); - } else { - // unique constraint is given by list of names - // have to resolve names - D_ASSERT(!unique.columns.empty()); - for (auto &keyname : unique.columns) { - if (!base.columns.ColumnExists(keyname)) { - throw ParserException("column \"%s\" named in key does not exist", keyname); - } - auto &column = base.columns.GetColumn(keyname); - auto column_index = column.Logical(); - if (key_set.find(column_index) != key_set.end()) { - throw ParserException("column \"%s\" appears twice in " - "primary key constraint", - keyname); - } - keys.push_back(column_index); - key_set.insert(column_index); - } - } - - if (unique.is_primary_key) { - // we can only have one primary key per table - if (has_primary_key) { - throw ParserException("table \"%s\" has more than one primary key", base.table); - } - has_primary_key = true; - primary_keys = keys; - } - info.bound_constraints.push_back( - make_uniq(std::move(keys), std::move(key_set), unique.is_primary_key)); - break; - } - case ConstraintType::FOREIGN_KEY: { - auto &fk = cond->Cast(); - D_ASSERT((fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE && 
!fk.info.pk_keys.empty()) || - (fk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE && !fk.info.pk_keys.empty()) || - fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE); - physical_index_set_t fk_key_set, pk_key_set; - for (idx_t i = 0; i < fk.info.pk_keys.size(); i++) { - if (pk_key_set.find(fk.info.pk_keys[i]) != pk_key_set.end()) { - throw BinderException("Duplicate primary key referenced in FOREIGN KEY constraint"); - } - pk_key_set.insert(fk.info.pk_keys[i]); - } - for (idx_t i = 0; i < fk.info.fk_keys.size(); i++) { - if (fk_key_set.find(fk.info.fk_keys[i]) != fk_key_set.end()) { - throw BinderException("Duplicate key specified in FOREIGN KEY constraint"); - } - fk_key_set.insert(fk.info.fk_keys[i]); - } - info.bound_constraints.push_back( - make_uniq(fk.info, std::move(pk_key_set), std::move(fk_key_set))); - break; - } - default: - throw NotImplementedException("unrecognized constraint type in bind"); - } - } - if (has_primary_key) { - // if there is a primary key index, also create a NOT NULL constraint for each of the columns - for (auto &column_index : primary_keys) { - if (not_null_columns.count(column_index)) { - //! 
No need to create a NotNullConstraint, it's already present - continue; - } - auto physical_index = base.columns.LogicalToPhysical(column_index); - base.constraints.push_back(make_uniq(column_index)); - info.bound_constraints.push_back(make_uniq(physical_index)); - } - } -} - -void Binder::BindGeneratedColumns(BoundCreateTableInfo &info) { - auto &base = info.base->Cast(); - - vector names; - vector types; - - D_ASSERT(base.type == CatalogType::TABLE_ENTRY); - for (auto &col : base.columns.Logical()) { - names.push_back(col.Name()); - types.push_back(col.Type()); - } - auto table_index = GenerateTableIndex(); - - // Create a new binder because we dont need (or want) these bindings in this scope - auto binder = Binder::CreateBinder(context); - binder->bind_context.AddGenericBinding(table_index, base.table, names, types); - auto expr_binder = ExpressionBinder(*binder, context); - string ignore; - auto table_binding = binder->bind_context.GetBinding(base.table, ignore); - D_ASSERT(table_binding && ignore.empty()); - - auto bind_order = info.column_dependency_manager.GetBindOrder(base.columns); - logical_index_set_t bound_indices; - - while (!bind_order.empty()) { - auto i = bind_order.top(); - bind_order.pop(); - auto &col = base.columns.GetColumnMutable(i); - - //! Already bound this previously - //! This can not be optimized out of the GetBindOrder function - //! These occurrences happen because we need to make sure that ALL dependencies of a column are resolved before - //! 
it gets resolved - if (bound_indices.count(i)) { - continue; - } - D_ASSERT(col.Generated()); - auto expression = col.GeneratedExpression().Copy(); - - auto bound_expression = expr_binder.Bind(expression); - D_ASSERT(bound_expression); - D_ASSERT(!bound_expression->HasSubquery()); - if (col.Type().id() == LogicalTypeId::ANY) { - // Do this before changing the type, so we know it's the first time the type is set - col.ChangeGeneratedExpressionType(bound_expression->return_type); - col.SetType(bound_expression->return_type); - - // Update the type in the binding, for future expansions - string ignore; - table_binding->types[i.index] = col.Type(); - } - bound_indices.insert(i); - } -} - -void Binder::BindDefaultValues(const ColumnList &columns, vector> &bound_defaults) { - for (auto &column : columns.Physical()) { - unique_ptr bound_default; - if (column.DefaultValue()) { - // we bind a copy of the DEFAULT value because binding is destructive - // and we want to keep the original expression around for serialization - auto default_copy = column.DefaultValue()->Copy(); - ConstantBinder default_binder(*this, context, "DEFAULT value"); - default_binder.target_type = column.Type(); - bound_default = default_binder.Bind(default_copy); - } else { - // no default value specified: push a default value of constant null - bound_default = make_uniq(Value(column.Type())); - } - bound_defaults.push_back(std::move(bound_default)); - } -} - -static void ExtractExpressionDependencies(Expression &expr, DependencyList &dependencies) { - if (expr.type == ExpressionType::BOUND_FUNCTION) { - auto &function = expr.Cast(); - if (function.function.dependency) { - function.function.dependency(function, dependencies); - } - } - ExpressionIterator::EnumerateChildren( - expr, [&](Expression &child) { ExtractExpressionDependencies(child, dependencies); }); -} - -static void ExtractDependencies(BoundCreateTableInfo &info) { - for (auto &default_value : info.bound_defaults) { - if (default_value) { 
- ExtractExpressionDependencies(*default_value, info.dependencies); - } - } - for (auto &constraint : info.bound_constraints) { - if (constraint->type == ConstraintType::CHECK) { - auto &bound_check = constraint->Cast(); - ExtractExpressionDependencies(*bound_check.expression, info.dependencies); - } - } -} -unique_ptr Binder::BindCreateTableInfo(unique_ptr info, SchemaCatalogEntry &schema) { - auto &base = info->Cast(); - auto result = make_uniq(schema, std::move(info)); - if (base.query) { - // construct the result object - auto query_obj = Bind(*base.query); - base.query.reset(); - result->query = std::move(query_obj.plan); - - // construct the set of columns based on the names and types of the query - auto &names = query_obj.names; - auto &sql_types = query_obj.types; - D_ASSERT(names.size() == sql_types.size()); - base.columns.SetAllowDuplicates(true); - for (idx_t i = 0; i < names.size(); i++) { - base.columns.AddColumn(ColumnDefinition(names[i], sql_types[i])); - } - CreateColumnDependencyManager(*result); - // bind the generated column expressions - BindGeneratedColumns(*result); - } else { - CreateColumnDependencyManager(*result); - // bind the generated column expressions - BindGeneratedColumns(*result); - // bind any constraints - BindConstraints(*this, *result); - // bind the default values - BindDefaultValues(base.columns, result->bound_defaults); - } - // extract dependencies from any default values or CHECK constraints - ExtractDependencies(*result); - - if (base.columns.PhysicalColumnCount() == 0) { - throw BinderException("Creating a table without physical (non-generated) columns is not supported"); - } - // bind collations to detect any unsupported collation errors - for (idx_t i = 0; i < base.columns.PhysicalColumnCount(); i++) { - auto &column = base.columns.GetColumnMutable(PhysicalIndex(i)); - if (column.Type().id() == LogicalTypeId::VARCHAR) { - ExpressionBinder::TestCollation(context, StringType::GetCollation(column.Type())); - } - 
BindLogicalType(context, column.TypeMutable(), &result->schema.catalog); - } - result->dependencies.VerifyDependencies(schema.catalog, result->Base().table); - properties.allow_stream_result = false; - return result; -} - -unique_ptr Binder::BindCreateTableInfo(unique_ptr info) { - auto &base = info->Cast(); - auto &schema = BindCreateSchema(base); - return BindCreateTableInfo(std::move(info), schema); -} - -vector> Binder::BindCreateIndexExpressions(TableCatalogEntry &table, CreateIndexInfo &info) { - auto index_binder = IndexBinder(*this, this->context, &table, &info); - vector> expressions; - expressions.reserve(info.expressions.size()); - for (auto &expr : info.expressions) { - expressions.push_back(index_binder.Bind(expr)); - } - - return expressions; -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -BoundStatement Binder::Bind(DeleteStatement &stmt) { - BoundStatement result; - - // visit the table reference - auto bound_table = Bind(*stmt.table); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw BinderException("Can only delete from base table!"); - } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; - - auto root = CreatePlan(*bound_table); - auto &get = root->Cast(); - D_ASSERT(root->type == LogicalOperatorType::LOGICAL_GET); - - if (!table.temporary) { - // delete from persistent table: not read only! 
- properties.modified_databases.insert(table.catalog.GetName()); - } - - // Add CTEs as bindable - AddCTEMap(stmt.cte_map); - - // plan any tables from the various using clauses - if (!stmt.using_clauses.empty()) { - unique_ptr child_operator; - for (auto &using_clause : stmt.using_clauses) { - // bind the using clause - auto using_binder = Binder::CreateBinder(context, this); - auto bound_node = using_binder->Bind(*using_clause); - auto op = CreatePlan(*bound_node); - if (child_operator) { - // already bound a child: create a cross product to unify the two - child_operator = LogicalCrossProduct::Create(std::move(child_operator), std::move(op)); - } else { - child_operator = std::move(op); - } - bind_context.AddContext(std::move(using_binder->bind_context)); - } - if (child_operator) { - root = LogicalCrossProduct::Create(std::move(root), std::move(child_operator)); - } - } - - // project any additional columns required for the condition - unique_ptr condition; - if (stmt.condition) { - WhereBinder binder(*this, context); - condition = binder.Bind(stmt.condition); - - PlanSubqueries(condition, root); - auto filter = make_uniq(std::move(condition)); - filter->AddChild(std::move(root)); - root = std::move(filter); - } - // create the delete node - auto del = make_uniq(table, GenerateTableIndex()); - del->AddChild(std::move(root)); - - // set up the delete expression - del->expressions.push_back(make_uniq( - LogicalType::ROW_TYPE, ColumnBinding(get.table_index, get.column_ids.size()))); - get.column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); - - if (!stmt.returning_list.empty()) { - del->return_chunk = true; - - auto update_table_index = GenerateTableIndex(); - del->table_index = update_table_index; - - unique_ptr del_as_logicaloperator = std::move(del); - return BindReturning(std::move(stmt.returning_list), table, stmt.table->alias, update_table_index, - std::move(del_as_logicaloperator), std::move(result)); - } - result.plan = std::move(del); - result.names = 
{"Count"}; - result.types = {LogicalType::BIGINT}; - properties.allow_stream_result = false; - properties.return_type = StatementReturnType::CHANGED_ROWS; - - return result; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -BoundStatement Binder::Bind(DetachStatement &stmt) { - BoundStatement result; - - result.plan = make_uniq(LogicalOperatorType::LOGICAL_DETACH, std::move(stmt.info)); - result.names = {"Success"}; - result.types = {LogicalType::BOOLEAN}; - properties.allow_stream_result = false; - properties.return_type = StatementReturnType::NOTHING; - return result; -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -BoundStatement Binder::Bind(DropStatement &stmt) { - BoundStatement result; - - switch (stmt.info->type) { - case CatalogType::PREPARED_STATEMENT: - // dropping prepared statements is always possible - // it also does not require a valid transaction - properties.requires_valid_transaction = false; - break; - case CatalogType::SCHEMA_ENTRY: { - // dropping a schema is never read-only because there are no temporary schemas - auto &catalog = Catalog::GetCatalog(context, stmt.info->catalog); - properties.modified_databases.insert(catalog.GetName()); - break; - } - case CatalogType::VIEW_ENTRY: - case CatalogType::SEQUENCE_ENTRY: - case CatalogType::MACRO_ENTRY: - case CatalogType::TABLE_MACRO_ENTRY: - case CatalogType::INDEX_ENTRY: - case CatalogType::TABLE_ENTRY: - case CatalogType::TYPE_ENTRY: { - BindSchemaOrCatalog(stmt.info->catalog, stmt.info->schema); - auto entry = Catalog::GetEntry(context, stmt.info->type, stmt.info->catalog, stmt.info->schema, stmt.info->name, - OnEntryNotFound::RETURN_NULL); - if (!entry) { - break; - } - if (entry->internal) { - throw CatalogException("Cannot drop internal catalog entry \"%s\"!", entry->name); - } - stmt.info->catalog = entry->ParentCatalog().GetName(); - if (!entry->temporary) { - // we can only drop temporary tables in read-only mode - 
properties.modified_databases.insert(stmt.info->catalog); - } - stmt.info->schema = entry->ParentSchema().name; - break; - } - default: - throw BinderException("Unknown catalog type for drop statement!"); - } - result.plan = make_uniq(LogicalOperatorType::LOGICAL_DROP, std::move(stmt.info)); - result.names = {"Success"}; - result.types = {LogicalType::BOOLEAN}; - properties.allow_stream_result = false; - properties.return_type = StatementReturnType::NOTHING; - return result; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -BoundStatement Binder::Bind(ExecuteStatement &stmt) { - auto parameter_count = stmt.n_param; - - // bind the prepared statement - auto &client_data = ClientData::Get(context); - - auto entry = client_data.prepared_statements.find(stmt.name); - if (entry == client_data.prepared_statements.end()) { - throw BinderException("Prepared statement \"%s\" does not exist", stmt.name); - } - - // check if we need to rebind the prepared statement - // this happens if the catalog changes, since in this case e.g. 
tables we relied on may have been deleted - auto prepared = entry->second; - auto &named_param_map = prepared->unbound_statement->named_param_map; - - PreparedStatement::VerifyParameters(stmt.named_values, named_param_map); - - auto &mapped_named_values = stmt.named_values; - // bind any supplied parameters - case_insensitive_map_t bind_values; - auto constant_binder = Binder::CreateBinder(context); - constant_binder->SetCanContainNulls(true); - for (auto &pair : mapped_named_values) { - ConstantBinder cbinder(*constant_binder, context, "EXECUTE statement"); - auto bound_expr = cbinder.Bind(pair.second); - - Value value = ExpressionExecutor::EvaluateScalar(context, *bound_expr, true); - bind_values[pair.first] = std::move(value); - } - unique_ptr rebound_plan; - - if (prepared->RequireRebind(context, &bind_values)) { - // catalog was modified or statement does not have clear types: rebind the statement before running the execute - Planner prepared_planner(context); - for (auto &pair : bind_values) { - prepared_planner.parameter_data.emplace(std::make_pair(pair.first, BoundParameterData(pair.second))); - } - prepared = prepared_planner.PrepareSQLStatement(entry->second->unbound_statement->Copy()); - rebound_plan = std::move(prepared_planner.plan); - D_ASSERT(prepared->properties.bound_all_parameters); - this->bound_tables = prepared_planner.binder->bound_tables; - } - // copy the properties of the prepared statement into the planner - this->properties = prepared->properties; - this->properties.parameter_count = parameter_count; - BoundStatement result; - result.names = prepared->names; - result.types = prepared->types; - - prepared->Bind(std::move(bind_values)); - if (rebound_plan) { - auto execute_plan = make_uniq(std::move(prepared)); - execute_plan->children.push_back(std::move(rebound_plan)); - result.plan = std::move(execute_plan); - } else { - result.plan = make_uniq(std::move(prepared)); - } - return result; -} - -} // namespace duckdb - - - - -namespace 
duckdb { - -BoundStatement Binder::Bind(ExplainStatement &stmt) { - BoundStatement result; - - // bind the underlying statement - auto plan = Bind(*stmt.stmt); - // get the unoptimized logical plan, and create the explain statement - auto logical_plan_unopt = plan.plan->ToString(); - auto explain = make_uniq(std::move(plan.plan), stmt.explain_type); - explain->logical_plan_unopt = logical_plan_unopt; - - result.plan = std::move(explain); - result.names = {"explain_key", "explain_value"}; - result.types = {LogicalType::VARCHAR, LogicalType::VARCHAR}; - properties.return_type = StatementReturnType::QUERY_RESULT; - return result; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -//! Sanitizes a string to have only low case chars and underscores -string SanitizeExportIdentifier(const string &str) { - // Copy the original string to result - string result(str); - - for (idx_t i = 0; i < str.length(); ++i) { - auto c = str[i]; - if (c >= 'a' && c <= 'z') { - // If it is lower case just continue - continue; - } - - if (c >= 'A' && c <= 'Z') { - // To lowercase - result[i] = tolower(c); - } else { - // Substitute to underscore - result[i] = '_'; - } - } - - return result; -} - -bool ReferencedTableIsOrdered(string &referenced_table, catalog_entry_vector_t &ordered) { - for (auto &entry : ordered) { - auto &table_entry = entry.get().Cast(); - if (StringUtil::CIEquals(table_entry.name, referenced_table)) { - // The referenced table is already ordered - return true; - } - } - return false; -} - -void ScanForeignKeyTable(catalog_entry_vector_t &ordered, catalog_entry_vector_t &unordered, - bool move_primary_keys_only) { - catalog_entry_vector_t remaining; - - for (auto &entry : unordered) { - auto &table_entry = entry.get().Cast(); - bool move_to_ordered = true; - auto &constraints = table_entry.GetConstraints(); - - for (auto &cond : constraints) { - if (cond->type != ConstraintType::FOREIGN_KEY) { - continue; - } - auto &fk = 
cond->Cast(); - if (fk.info.type != ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { - continue; - } - - if (move_primary_keys_only) { - // This table references a table, don't move it yet - move_to_ordered = false; - break; - } else if (!ReferencedTableIsOrdered(fk.info.table, ordered)) { - // The table that it references isn't ordered yet - move_to_ordered = false; - break; - } - } - - if (move_to_ordered) { - ordered.push_back(table_entry); - } else { - remaining.push_back(table_entry); - } - } - unordered = remaining; -} - -void ReorderTableEntries(catalog_entry_vector_t &tables) { - catalog_entry_vector_t ordered; - catalog_entry_vector_t unordered = tables; - // First only move the tables that don't have any dependencies - ScanForeignKeyTable(ordered, unordered, true); - while (!unordered.empty()) { - // Now we will start moving tables that have foreign key constraints - // if the tables they reference are already moved - ScanForeignKeyTable(ordered, unordered, false); - } - tables = ordered; -} - -string CreateFileName(const string &id_suffix, TableCatalogEntry &table, const string &extension) { - auto name = SanitizeExportIdentifier(table.name); - if (table.schema.name == DEFAULT_SCHEMA) { - return StringUtil::Format("%s%s.%s", name, id_suffix, extension); - } - auto schema = SanitizeExportIdentifier(table.schema.name); - return StringUtil::Format("%s_%s%s.%s", schema, name, id_suffix, extension); -} - -static bool IsSupported(CopyTypeSupport support_level) { - // For export purposes we don't want to lose information, so we only accept fully supported types - return support_level == CopyTypeSupport::SUPPORTED; -} - -static LogicalType AlterLogicalType(const LogicalType &original, copy_supports_type_t type_check) { - D_ASSERT(type_check); - auto id = original.id(); - switch (id) { - case LogicalTypeId::LIST: { - auto child = AlterLogicalType(ListType::GetChildType(original), type_check); - return LogicalType::LIST(child); - } - case LogicalTypeId::STRUCT: { - 
auto &original_children = StructType::GetChildTypes(original); - child_list_t new_children; - for (auto &child : original_children) { - auto &child_name = child.first; - auto &child_type = child.second; - - LogicalType new_type; - if (!IsSupported(type_check(child_type))) { - new_type = AlterLogicalType(child_type, type_check); - } else { - new_type = child_type; - } - new_children.push_back(std::make_pair(child_name, new_type)); - } - return LogicalType::STRUCT(std::move(new_children)); - } - case LogicalTypeId::UNION: { - auto member_count = UnionType::GetMemberCount(original); - child_list_t new_children; - for (idx_t i = 0; i < member_count; i++) { - auto &child_name = UnionType::GetMemberName(original, i); - auto &child_type = UnionType::GetMemberType(original, i); - - LogicalType new_type; - if (!IsSupported(type_check(child_type))) { - new_type = AlterLogicalType(child_type, type_check); - } else { - new_type = child_type; - } - - new_children.push_back(std::make_pair(child_name, new_type)); - } - return LogicalType::UNION(std::move(new_children)); - } - case LogicalTypeId::MAP: { - auto &key_type = MapType::KeyType(original); - auto &value_type = MapType::ValueType(original); - - LogicalType new_key_type; - LogicalType new_value_type; - if (!IsSupported(type_check(key_type))) { - new_key_type = AlterLogicalType(key_type, type_check); - } else { - new_key_type = key_type; - } - - if (!IsSupported(type_check(value_type))) { - new_value_type = AlterLogicalType(value_type, type_check); - } else { - new_value_type = value_type; - } - return LogicalType::MAP(new_key_type, new_value_type); - } - default: { - D_ASSERT(!IsSupported(type_check(original))); - return LogicalType::VARCHAR; - } - } -} - -static bool NeedsCast(LogicalType &type, copy_supports_type_t type_check) { - if (!type_check) { - return false; - } - if (IsSupported(type_check(type))) { - // The type is supported in it's entirety, no cast is required - return false; - } - // Change the type to 
something that is supported - type = AlterLogicalType(type, type_check); - return true; -} - -static unique_ptr CreateSelectStatement(CopyStatement &stmt, child_list_t &select_list, - copy_supports_type_t type_check) { - auto ref = make_uniq(); - ref->catalog_name = stmt.info->catalog; - ref->schema_name = stmt.info->schema; - ref->table_name = stmt.info->table; - - auto statement = make_uniq(); - statement->from_table = std::move(ref); - - vector> expressions; - for (auto &col : select_list) { - auto &name = col.first; - auto &type = col.second; - - auto expression = make_uniq_base(name); - if (NeedsCast(type, type_check)) { - // Add a cast to a type supported by the copy function - expression = make_uniq_base(type, std::move(expression)); - } - expressions.push_back(std::move(expression)); - } - - statement->select_list = std::move(expressions); - return std::move(statement); -} - -BoundStatement Binder::Bind(ExportStatement &stmt) { - // COPY TO a file - auto &config = DBConfig::GetConfig(context); - if (!config.options.enable_external_access) { - throw PermissionException("COPY TO is disabled through configuration"); - } - BoundStatement result; - result.types = {LogicalType::BOOLEAN}; - result.names = {"Success"}; - - // lookup the format in the catalog - auto ©_function = - Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, stmt.info->format); - if (!copy_function.function.copy_to_bind && !copy_function.function.plan) { - throw NotImplementedException("COPY TO is not supported for FORMAT \"%s\"", stmt.info->format); - } - - // gather a list of all the tables - string catalog = stmt.database.empty() ? 
INVALID_CATALOG : stmt.database; - catalog_entry_vector_t tables; - auto schemas = Catalog::GetSchemas(context, catalog); - for (auto &schema : schemas) { - schema.get().Scan(context, CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) { - if (entry.type == CatalogType::TABLE_ENTRY) { - tables.push_back(entry.Cast()); - } - }); - } - - // reorder tables because of foreign key constraint - ReorderTableEntries(tables); - - // now generate the COPY statements for each of the tables - auto &fs = FileSystem::GetFileSystem(context); - unique_ptr child_operator; - - BoundExportData exported_tables; - - unordered_set table_name_index; - for (auto &t : tables) { - auto &table = t.get().Cast(); - auto info = make_uniq(); - // we copy the options supplied to the EXPORT - info->format = stmt.info->format; - info->options = stmt.info->options; - // set up the file name for the COPY TO - - idx_t id = 0; - while (true) { - string id_suffix = id == 0 ? string() : "_" + to_string(id); - auto name = CreateFileName(id_suffix, table, copy_function.function.extension); - auto directory = stmt.info->file_path; - auto full_path = fs.JoinPath(directory, name); - info->file_path = full_path; - auto insert_result = table_name_index.insert(info->file_path); - if (insert_result.second == true) { - // this name was not yet taken: take it - break; - } - id++; - } - info->is_from = false; - info->catalog = catalog; - info->schema = table.schema.name; - info->table = table.name; - - // We can not export generated columns - child_list_t select_list; - - for (auto &col : table.GetColumns().Physical()) { - select_list.push_back(std::make_pair(col.Name(), col.Type())); - } - - ExportedTableData exported_data; - exported_data.database_name = catalog; - exported_data.table_name = info->table; - exported_data.schema_name = info->schema; - - exported_data.file_path = info->file_path; - - ExportedTableInfo table_info(table, std::move(exported_data)); - exported_tables.data.push_back(table_info); - id++; - 
- // generate the copy statement and bind it - CopyStatement copy_stmt; - copy_stmt.info = std::move(info); - copy_stmt.select_statement = - CreateSelectStatement(copy_stmt, select_list, copy_function.function.supports_type); - - auto copy_binder = Binder::CreateBinder(context, this); - auto bound_statement = copy_binder->Bind(copy_stmt); - auto plan = std::move(bound_statement.plan); - - if (child_operator) { - // use UNION ALL to combine the individual copy statements into a single node - auto copy_union = make_uniq(GenerateTableIndex(), 1, std::move(child_operator), - std::move(plan), LogicalOperatorType::LOGICAL_UNION); - child_operator = std::move(copy_union); - } else { - child_operator = std::move(plan); - } - } - - // try to create the directory, if it doesn't exist yet - // a bit hacky to do it here, but we need to create the directory BEFORE the copy statements run - if (!fs.DirectoryExists(stmt.info->file_path)) { - fs.CreateDirectory(stmt.info->file_path); - } - - // create the export node - auto export_node = make_uniq(copy_function.function, std::move(stmt.info), exported_tables); - - if (child_operator) { - export_node->children.push_back(std::move(child_operator)); - } - - result.plan = std::move(export_node); - properties.allow_stream_result = false; - properties.return_type = StatementReturnType::NOTHING; - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -BoundStatement Binder::Bind(ExtensionStatement &stmt) { - BoundStatement result; - - // perform the planning of the function - D_ASSERT(stmt.extension.plan_function); - auto parse_result = - stmt.extension.plan_function(stmt.extension.parser_info.get(), context, std::move(stmt.parse_data)); - - properties.modified_databases = parse_result.modified_databases; - properties.requires_valid_transaction = parse_result.requires_valid_transaction; - properties.return_type = parse_result.return_type; - - // create the plan as a scan of the given table function - result.plan = 
BindTableFunction(parse_result.function, std::move(parse_result.parameters)); - D_ASSERT(result.plan->type == LogicalOperatorType::LOGICAL_GET); - auto &get = result.plan->Cast(); - result.names = get.names; - result.types = get.returned_types; - get.column_ids.clear(); - for (idx_t i = 0; i < get.returned_types.size(); i++) { - get.column_ids.push_back(i); - } - return result; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -static void CheckInsertColumnCountMismatch(int64_t expected_columns, int64_t result_columns, bool columns_provided, - const char *tname) { - if (result_columns != expected_columns) { - string msg = StringUtil::Format(!columns_provided ? "table %s has %lld columns but %lld values were supplied" - : "Column name/value mismatch for insert on %s: " - "expected %lld columns but %lld values were supplied", - tname, expected_columns, result_columns); - throw BinderException(msg); - } -} - -unique_ptr ExpandDefaultExpression(const ColumnDefinition &column) { - if (column.DefaultValue()) { - return column.DefaultValue()->Copy(); - } else { - return make_uniq(Value(column.Type())); - } -} - -void ReplaceDefaultExpression(unique_ptr &expr, const ColumnDefinition &column) { - D_ASSERT(expr->type == ExpressionType::VALUE_DEFAULT); - expr = ExpandDefaultExpression(column); -} - -void QualifyColumnReferences(unique_ptr &expr, const string &table_name) { - // To avoid ambiguity with 'excluded', we explicitly qualify all column references - if (expr->type == ExpressionType::COLUMN_REF) { - auto &column_ref = expr->Cast(); - if (column_ref.IsQualified()) { - return; - } - auto column_name = column_ref.GetColumnName(); - expr = make_uniq(column_name, table_name); - } - ParsedExpressionIterator::EnumerateChildren( - *expr, [&](unique_ptr &child) { QualifyColumnReferences(child, table_name); }); -} - -// Replace binding.table_index with 'dest' if it's 'source' -void ReplaceColumnBindings(Expression &expr, 
idx_t source, idx_t dest) { - if (expr.type == ExpressionType::BOUND_COLUMN_REF) { - auto &bound_columnref = expr.Cast(); - if (bound_columnref.binding.table_index == source) { - bound_columnref.binding.table_index = dest; - } - } - ExpressionIterator::EnumerateChildren( - expr, [&](unique_ptr &child) { ReplaceColumnBindings(*child, source, dest); }); -} - -void Binder::BindDoUpdateSetExpressions(const string &table_alias, LogicalInsert &insert, UpdateSetInfo &set_info, - TableCatalogEntry &table, TableStorageInfo &storage_info) { - D_ASSERT(insert.children.size() == 1); - D_ASSERT(insert.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION); - - vector logical_column_ids; - vector column_names; - D_ASSERT(set_info.columns.size() == set_info.expressions.size()); - - for (idx_t i = 0; i < set_info.columns.size(); i++) { - auto &colname = set_info.columns[i]; - auto &expr = set_info.expressions[i]; - if (!table.ColumnExists(colname)) { - throw BinderException("Referenced update column %s not found in table!", colname); - } - auto &column = table.GetColumn(colname); - if (column.Generated()) { - throw BinderException("Cant update column \"%s\" because it is a generated column!", column.Name()); - } - if (std::find(insert.set_columns.begin(), insert.set_columns.end(), column.Physical()) != - insert.set_columns.end()) { - throw BinderException("Multiple assignments to same column \"%s\"", colname); - } - insert.set_columns.push_back(column.Physical()); - logical_column_ids.push_back(column.Oid()); - insert.set_types.push_back(column.Type()); - column_names.push_back(colname); - if (expr->type == ExpressionType::VALUE_DEFAULT) { - expr = ExpandDefaultExpression(column); - } - UpdateBinder binder(*this, context); - binder.target_type = column.Type(); - - // Avoid ambiguity issues - QualifyColumnReferences(expr, table_alias); - - auto bound_expr = binder.Bind(expr); - D_ASSERT(bound_expr); - if (bound_expr->expression_class == ExpressionClass::BOUND_SUBQUERY) { - 
throw BinderException("Expression in the DO UPDATE SET clause can not be a subquery"); - } - - insert.expressions.push_back(std::move(bound_expr)); - } - - // Figure out which columns are indexed on - unordered_set indexed_columns; - for (auto &index : storage_info.index_info) { - for (auto &column_id : index.column_set) { - indexed_columns.insert(column_id); - } - } - - // Verify that none of the columns that are targeted with a SET expression are indexed on - for (idx_t i = 0; i < logical_column_ids.size(); i++) { - auto &column = logical_column_ids[i]; - if (indexed_columns.count(column)) { - throw BinderException("Can not assign to column '%s' because it has a UNIQUE/PRIMARY KEY constraint", - column_names[i]); - } - } -} - -unique_ptr CreateSetInfoForReplace(TableCatalogEntry &table, InsertStatement &insert, - TableStorageInfo &storage_info) { - auto set_info = make_uniq(); - - auto &columns = set_info->columns; - // Figure out which columns are indexed on - - unordered_set indexed_columns; - for (auto &index : storage_info.index_info) { - for (auto &column_id : index.column_set) { - indexed_columns.insert(column_id); - } - } - - auto &column_list = table.GetColumns(); - if (insert.columns.empty()) { - for (auto &column : column_list.Physical()) { - auto &name = column.Name(); - // FIXME: can these column names be aliased somehow? 
- if (indexed_columns.count(column.Oid())) { - continue; - } - columns.push_back(name); - } - } else { - // a list of columns was explicitly supplied, only update those - for (auto &name : insert.columns) { - auto &column = column_list.GetColumn(name); - if (indexed_columns.count(column.Oid())) { - continue; - } - columns.push_back(name); - } - } - - // Create 'excluded' qualified column references of these columns - for (auto &column : columns) { - set_info->expressions.push_back(make_uniq(column, "excluded")); - } - - return set_info; -} - -void Binder::BindOnConflictClause(LogicalInsert &insert, TableCatalogEntry &table, InsertStatement &stmt) { - if (!stmt.on_conflict_info) { - insert.action_type = OnConflictAction::THROW; - return; - } - D_ASSERT(stmt.table_ref->type == TableReferenceType::BASE_TABLE); - - // visit the table reference - auto bound_table = Bind(*stmt.table_ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw BinderException("Can only update base table!"); - } - - auto &table_ref = stmt.table_ref->Cast(); - const string &table_alias = !table_ref.alias.empty() ? 
table_ref.alias : table_ref.table_name; - - auto &on_conflict = *stmt.on_conflict_info; - D_ASSERT(on_conflict.action_type != OnConflictAction::THROW); - insert.action_type = on_conflict.action_type; - - // obtain the table storage info - auto storage_info = table.GetStorageInfo(context); - - auto &columns = table.GetColumns(); - if (!on_conflict.indexed_columns.empty()) { - // Bind the ON CONFLICT () - - // create a mapping of (list index) -> (column index) - case_insensitive_map_t specified_columns; - for (idx_t i = 0; i < on_conflict.indexed_columns.size(); i++) { - specified_columns[on_conflict.indexed_columns[i]] = i; - auto column_index = table.GetColumnIndex(on_conflict.indexed_columns[i]); - if (column_index.index == COLUMN_IDENTIFIER_ROW_ID) { - throw BinderException("Cannot specify ROWID as ON CONFLICT target"); - } - auto &col = columns.GetColumn(column_index); - if (col.Generated()) { - throw BinderException("Cannot specify a generated column as ON CONFLICT target"); - } - } - for (auto &col : columns.Physical()) { - auto entry = specified_columns.find(col.Name()); - if (entry != specified_columns.end()) { - // column was specified, set to the index - insert.on_conflict_filter.insert(col.Oid()); - } - } - bool index_references_columns = false; - for (auto &index : storage_info.index_info) { - if (!index.is_unique) { - continue; - } - bool index_matches = insert.on_conflict_filter == index.column_set; - if (index_matches) { - index_references_columns = true; - break; - } - } - if (!index_references_columns) { - // Same as before, this is essentially a no-op, turning this into a DO THROW instead - // But since this makes no logical sense, it's probably better to throw an error - throw BinderException( - "The specified columns as conflict target are not referenced by a UNIQUE/PRIMARY KEY CONSTRAINT"); - } - } else { - // When omitting the conflict target, the ON CONFLICT applies to every UNIQUE/PRIMARY KEY on the table - - // We check if there are any 
constraints on the table, if there aren't we throw an error. - idx_t found_matching_indexes = 0; - for (auto &index : storage_info.index_info) { - if (!index.is_unique) { - continue; - } - // does this work with multi-column indexes? - auto &indexed_columns = index.column_set; - for (auto &column : table.GetColumns().Physical()) { - if (indexed_columns.count(column.Physical().index)) { - found_matching_indexes++; - } - } - } - if (!found_matching_indexes) { - throw BinderException( - "There are no UNIQUE/PRIMARY KEY Indexes that refer to this table, ON CONFLICT is a no-op"); - } - if (insert.action_type != OnConflictAction::NOTHING && found_matching_indexes != 1) { - // When no conflict target is provided, and the action type is UPDATE, - // we only allow the operation when only a single Index exists - throw BinderException("Conflict target has to be provided for a DO UPDATE operation when the table has " - "multiple UNIQUE/PRIMARY KEY constraints"); - } - } - - // add the 'excluded' dummy table binding - AddTableName("excluded"); - // add a bind context entry for it - auto excluded_index = GenerateTableIndex(); - insert.excluded_table_index = excluded_index; - auto table_column_names = columns.GetColumnNames(); - auto table_column_types = columns.GetColumnTypes(); - bind_context.AddGenericBinding(excluded_index, "excluded", table_column_names, table_column_types); - - if (on_conflict.condition) { - // Avoid ambiguity between binding and 'excluded' - QualifyColumnReferences(on_conflict.condition, table_alias); - // Bind the ON CONFLICT ... 
WHERE clause - WhereBinder where_binder(*this, context); - auto condition = where_binder.Bind(on_conflict.condition); - if (condition && condition->expression_class == ExpressionClass::BOUND_SUBQUERY) { - throw BinderException("conflict_target WHERE clause can not be a subquery"); - } - insert.on_conflict_condition = std::move(condition); - } - - auto bindings = insert.children[0]->GetColumnBindings(); - idx_t projection_index = DConstants::INVALID_INDEX; - vector> *insert_child_operators; - insert_child_operators = &insert.children; - while (projection_index == DConstants::INVALID_INDEX) { - if (insert_child_operators->empty()) { - // No further children to visit - break; - } - D_ASSERT(insert_child_operators->size() >= 1); - auto ¤t_child = (*insert_child_operators)[0]; - auto table_indices = current_child->GetTableIndex(); - if (table_indices.empty()) { - // This operator does not have a table index to refer to, we have to visit its children - insert_child_operators = ¤t_child->children; - continue; - } - projection_index = table_indices[0]; - } - if (projection_index == DConstants::INVALID_INDEX) { - throw InternalException("Could not locate a table_index from the children of the insert"); - } - - string unused; - auto original_binding = bind_context.GetBinding(table_alias, unused); - D_ASSERT(original_binding); - - auto table_index = original_binding->index; - - // Replace any column bindings to refer to the projection table_index, rather than the source table - if (insert.on_conflict_condition) { - ReplaceColumnBindings(*insert.on_conflict_condition, table_index, projection_index); - } - - if (insert.action_type == OnConflictAction::REPLACE) { - D_ASSERT(on_conflict.set_info == nullptr); - on_conflict.set_info = CreateSetInfoForReplace(table, stmt, storage_info); - insert.action_type = OnConflictAction::UPDATE; - } - if (on_conflict.set_info && on_conflict.set_info->columns.empty()) { - // if we are doing INSERT OR REPLACE on a table with no columns outside 
of the primary key column - // convert to INSERT OR IGNORE - insert.action_type = OnConflictAction::NOTHING; - } - if (insert.action_type == OnConflictAction::NOTHING) { - if (!insert.on_conflict_condition) { - return; - } - // Get the column_ids we need to fetch later on from the conflicting tuples - // of the original table, to execute the expressions - D_ASSERT(original_binding->binding_type == BindingType::TABLE); - auto &table_binding = original_binding->Cast(); - insert.columns_to_fetch = table_binding.GetBoundColumnIds(); - return; - } - - D_ASSERT(on_conflict.set_info); - auto &set_info = *on_conflict.set_info; - D_ASSERT(set_info.columns.size() == set_info.expressions.size()); - - if (set_info.condition) { - // Avoid ambiguity between binding and 'excluded' - QualifyColumnReferences(set_info.condition, table_alias); - // Bind the SET ... WHERE clause - WhereBinder where_binder(*this, context); - auto condition = where_binder.Bind(set_info.condition); - if (condition && condition->expression_class == ExpressionClass::BOUND_SUBQUERY) { - throw BinderException("conflict_target WHERE clause can not be a subquery"); - } - insert.do_update_condition = std::move(condition); - } - - BindDoUpdateSetExpressions(table_alias, insert, set_info, table, storage_info); - - // Get the column_ids we need to fetch later on from the conflicting tuples - // of the original table, to execute the expressions - D_ASSERT(original_binding->binding_type == BindingType::TABLE); - auto &table_binding = original_binding->Cast(); - insert.columns_to_fetch = table_binding.GetBoundColumnIds(); - - // Replace the column bindings to refer to the child operator - for (auto &expr : insert.expressions) { - // Change the non-excluded column references to refer to the projection index - ReplaceColumnBindings(*expr, table_index, projection_index); - } - // Do the same for the (optional) DO UPDATE condition - if (insert.do_update_condition) { - ReplaceColumnBindings(*insert.do_update_condition, 
table_index, projection_index); - } -} - -BoundStatement Binder::Bind(InsertStatement &stmt) { - BoundStatement result; - result.names = {"Count"}; - result.types = {LogicalType::BIGINT}; - - BindSchemaOrCatalog(stmt.catalog, stmt.schema); - auto &table = Catalog::GetEntry(context, stmt.catalog, stmt.schema, stmt.table); - if (!table.temporary) { - // inserting into a non-temporary table: alters underlying database - properties.modified_databases.insert(table.catalog.GetName()); - } - - auto insert = make_uniq(table, GenerateTableIndex()); - // Add CTEs as bindable - AddCTEMap(stmt.cte_map); - - auto values_list = stmt.GetValuesList(); - - // bind the root select node (if any) - BoundStatement root_select; - if (stmt.column_order == InsertColumnOrder::INSERT_BY_NAME) { - if (values_list) { - throw BinderException("INSERT BY NAME can only be used when inserting from a SELECT statement"); - } - if (!stmt.columns.empty()) { - throw BinderException("INSERT BY NAME cannot be combined with an explicit column list"); - } - D_ASSERT(stmt.select_statement); - // INSERT BY NAME - generate the columns from the names of the SELECT statement - auto select_binder = Binder::CreateBinder(context, this); - root_select = select_binder->Bind(*stmt.select_statement); - MoveCorrelatedExpressions(*select_binder); - - stmt.columns = root_select.names; - } - - vector named_column_map; - if (!stmt.columns.empty() || stmt.default_values) { - // insertion statement specifies column list - - // create a mapping of (list index) -> (column index) - case_insensitive_map_t column_name_map; - for (idx_t i = 0; i < stmt.columns.size(); i++) { - auto entry = column_name_map.insert(make_pair(stmt.columns[i], i)); - if (!entry.second) { - throw BinderException("Duplicate column name \"%s\" in INSERT", stmt.columns[i]); - } - column_name_map[stmt.columns[i]] = i; - auto column_index = table.GetColumnIndex(stmt.columns[i]); - if (column_index.index == COLUMN_IDENTIFIER_ROW_ID) { - throw 
BinderException("Cannot explicitly insert values into rowid column"); - } - auto &col = table.GetColumn(column_index); - if (col.Generated()) { - throw BinderException("Cannot insert into a generated column"); - } - insert->expected_types.push_back(col.Type()); - named_column_map.push_back(column_index); - } - for (auto &col : table.GetColumns().Physical()) { - auto entry = column_name_map.find(col.Name()); - if (entry == column_name_map.end()) { - // column not specified, set index to DConstants::INVALID_INDEX - insert->column_index_map.push_back(DConstants::INVALID_INDEX); - } else { - // column was specified, set to the index - insert->column_index_map.push_back(entry->second); - } - } - } else { - // insert by position and no columns specified - insertion into all columns of the table - // intentionally don't populate 'column_index_map' as an indication of this - for (auto &col : table.GetColumns().Physical()) { - named_column_map.push_back(col.Logical()); - insert->expected_types.push_back(col.Type()); - } - } - - // bind the default values - BindDefaultValues(table.GetColumns(), insert->bound_defaults); - if (!stmt.select_statement && !stmt.default_values) { - result.plan = std::move(insert); - return result; - } - // Exclude the generated columns from this amount - idx_t expected_columns = stmt.columns.empty() ? table.GetColumns().PhysicalColumnCount() : stmt.columns.size(); - - // special case: check if we are inserting from a VALUES statement - if (values_list) { - auto &expr_list = values_list->Cast(); - expr_list.expected_types.resize(expected_columns); - expr_list.expected_names.resize(expected_columns); - - D_ASSERT(expr_list.values.size() > 0); - CheckInsertColumnCountMismatch(expected_columns, expr_list.values[0].size(), !stmt.columns.empty(), - table.name.c_str()); - - // VALUES list! 
- for (idx_t col_idx = 0; col_idx < expected_columns; col_idx++) { - D_ASSERT(named_column_map.size() >= col_idx); - auto &table_col_idx = named_column_map[col_idx]; - - // set the expected types as the types for the INSERT statement - auto &column = table.GetColumn(table_col_idx); - expr_list.expected_types[col_idx] = column.Type(); - expr_list.expected_names[col_idx] = column.Name(); - - // now replace any DEFAULT values with the corresponding default expression - for (idx_t list_idx = 0; list_idx < expr_list.values.size(); list_idx++) { - if (expr_list.values[list_idx][col_idx]->type == ExpressionType::VALUE_DEFAULT) { - // DEFAULT value! replace the entry - ReplaceDefaultExpression(expr_list.values[list_idx][col_idx], column); - } - } - } - } - - // parse select statement and add to logical plan - unique_ptr root; - if (stmt.select_statement) { - if (stmt.column_order == InsertColumnOrder::INSERT_BY_POSITION) { - auto select_binder = Binder::CreateBinder(context, this); - root_select = select_binder->Bind(*stmt.select_statement); - MoveCorrelatedExpressions(*select_binder); - } - // inserting from a select - check if the column count matches - CheckInsertColumnCountMismatch(expected_columns, root_select.types.size(), !stmt.columns.empty(), - table.name.c_str()); - - root = CastLogicalOperatorToTypes(root_select.types, insert->expected_types, std::move(root_select.plan)); - } else { - root = make_uniq(GenerateTableIndex()); - } - insert->AddChild(std::move(root)); - - BindOnConflictClause(*insert, table, stmt); - - if (!stmt.returning_list.empty()) { - insert->return_chunk = true; - result.types.clear(); - result.names.clear(); - auto insert_table_index = GenerateTableIndex(); - insert->table_index = insert_table_index; - unique_ptr index_as_logicaloperator = std::move(insert); - - return BindReturning(std::move(stmt.returning_list), table, stmt.table_ref ? 
stmt.table_ref->alias : string(), - insert_table_index, std::move(index_as_logicaloperator), std::move(result)); - } - - D_ASSERT(result.types.size() == result.names.size()); - result.plan = std::move(insert); - properties.allow_stream_result = false; - properties.return_type = StatementReturnType::CHANGED_ROWS; - return result; -} - -} // namespace duckdb - - - -#include - -namespace duckdb { - -BoundStatement Binder::Bind(LoadStatement &stmt) { - BoundStatement result; - result.types = {LogicalType::BOOLEAN}; - result.names = {"Success"}; - - result.plan = make_uniq(LogicalOperatorType::LOGICAL_LOAD, std::move(stmt.info)); - properties.allow_stream_result = false; - properties.return_type = StatementReturnType::NOTHING; - return result; -} - -} // namespace duckdb - - -#include - -namespace duckdb { - -idx_t GetMaxTableIndex(LogicalOperator &op) { - idx_t result = 0; - for (auto &child : op.children) { - auto max_child_index = GetMaxTableIndex(*child); - result = MaxValue(result, max_child_index); - } - auto indexes = op.GetTableIndex(); - for (auto &index : indexes) { - result = MaxValue(result, index); - } - return result; -} - -BoundStatement Binder::Bind(LogicalPlanStatement &stmt) { - BoundStatement result; - result.types = stmt.plan->types; - for (idx_t i = 0; i < result.types.size(); i++) { - result.names.push_back(StringUtil::Format("col%d", i)); - } - result.plan = std::move(stmt.plan); - properties.allow_stream_result = true; - properties.return_type = StatementReturnType::QUERY_RESULT; // TODO could also be something else - - if (parent) { - throw InternalException("LogicalPlanStatement should be bound in root binder"); - } - bound_tables = GetMaxTableIndex(*result.plan) + 1; - return result; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -BoundStatement Binder::Bind(PragmaStatement &stmt) { - // bind the pragma function - auto &entry = - Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, stmt.info->name); - string error; - 
FunctionBinder function_binder(context); - idx_t bound_idx = function_binder.BindFunction(entry.name, entry.functions, *stmt.info, error); - if (bound_idx == DConstants::INVALID_INDEX) { - throw BinderException(FormatError(stmt.stmt_location, error)); - } - auto bound_function = entry.functions.GetFunctionByOffset(bound_idx); - if (!bound_function.function) { - throw BinderException("PRAGMA function does not have a function specified"); - } - - // bind and check named params - QueryErrorContext error_context(root_statement, stmt.stmt_location); - BindNamedParameters(bound_function.named_parameters, stmt.info->named_parameters, error_context, - bound_function.name); - - BoundStatement result; - result.names = {"Success"}; - result.types = {LogicalType::BOOLEAN}; - result.plan = make_uniq(bound_function, *stmt.info); - properties.return_type = StatementReturnType::QUERY_RESULT; - return result; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -BoundStatement Binder::Bind(PrepareStatement &stmt) { - Planner prepared_planner(context); - auto prepared_data = prepared_planner.PrepareSQLStatement(std::move(stmt.statement)); - this->bound_tables = prepared_planner.binder->bound_tables; - - auto prepare = make_uniq(stmt.name, std::move(prepared_data), std::move(prepared_planner.plan)); - // we can always prepare, even if the transaction has been invalidated - // this is required because most clients ALWAYS invoke prepared statements - properties.requires_valid_transaction = false; - properties.allow_stream_result = false; - properties.bound_all_parameters = true; - properties.parameter_count = 0; - properties.return_type = StatementReturnType::NOTHING; - - BoundStatement result; - result.names = {"Success"}; - result.types = {LogicalType::BOOLEAN}; - result.plan = std::move(prepare); - return result; -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -BoundStatement Binder::Bind(RelationStatement &stmt) { - return stmt.relation->Bind(*this); -} - -} 
// namespace duckdb - - - - -namespace duckdb { - -BoundStatement Binder::Bind(SelectStatement &stmt) { - properties.allow_stream_result = true; - properties.return_type = StatementReturnType::QUERY_RESULT; - return Bind(*stmt.node); -} - -} // namespace duckdb - - - - -#include - -namespace duckdb { - -BoundStatement Binder::Bind(SetVariableStatement &stmt) { - BoundStatement result; - result.types = {LogicalType::BOOLEAN}; - result.names = {"Success"}; - - result.plan = make_uniq(stmt.name, stmt.value, stmt.scope); - properties.return_type = StatementReturnType::NOTHING; - return result; -} - -BoundStatement Binder::Bind(ResetVariableStatement &stmt) { - BoundStatement result; - result.types = {LogicalType::BOOLEAN}; - result.names = {"Success"}; - - result.plan = make_uniq(stmt.name, stmt.scope); - properties.return_type = StatementReturnType::NOTHING; - return result; -} - -BoundStatement Binder::Bind(SetStatement &stmt) { - switch (stmt.set_type) { - case SetType::SET: { - auto &set_stmt = stmt.Cast(); - return Bind(set_stmt); - } - case SetType::RESET: { - auto &set_stmt = stmt.Cast(); - return Bind(set_stmt); - } - default: - throw NotImplementedException("Type not implemented for SetType"); - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -BoundStatement Binder::Bind(ShowStatement &stmt) { - BoundStatement result; - - if (stmt.info->is_summary) { - return BindSummarize(stmt); - } - auto plan = Bind(*stmt.info->query); - stmt.info->types = plan.types; - stmt.info->aliases = plan.names; - - auto show = make_uniq(std::move(plan.plan)); - show->types_select = plan.types; - show->aliases = plan.names; - - result.plan = std::move(show); - - result.names = {"column_name", "column_type", "null", "key", "default", "extra"}; - result.types = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, - LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}; - properties.return_type = StatementReturnType::QUERY_RESULT; - return result; 
-} - -} // namespace duckdb - - - - - - - - - -//! This file contains the binder definitions for statements that do not need to be bound at all and only require a -//! straightforward conversion - -namespace duckdb { - -BoundStatement Binder::Bind(AlterStatement &stmt) { - BoundStatement result; - result.names = {"Success"}; - result.types = {LogicalType::BOOLEAN}; - BindSchemaOrCatalog(stmt.info->catalog, stmt.info->schema); - auto entry = Catalog::GetEntry(context, stmt.info->GetCatalogType(), stmt.info->catalog, stmt.info->schema, - stmt.info->name, stmt.info->if_not_found); - if (entry) { - auto &catalog = entry->ParentCatalog(); - if (!entry->temporary) { - // we can only alter temporary tables/views in read-only mode - properties.modified_databases.insert(catalog.GetName()); - } - stmt.info->catalog = catalog.GetName(); - stmt.info->schema = entry->ParentSchema().name; - } - result.plan = make_uniq(LogicalOperatorType::LOGICAL_ALTER, std::move(stmt.info)); - properties.return_type = StatementReturnType::NOTHING; - return result; -} - -BoundStatement Binder::Bind(TransactionStatement &stmt) { - // transaction statements do not require a valid transaction - properties.requires_valid_transaction = stmt.info->type == TransactionType::BEGIN_TRANSACTION; - - BoundStatement result; - result.names = {"Success"}; - result.types = {LogicalType::BOOLEAN}; - result.plan = make_uniq(LogicalOperatorType::LOGICAL_TRANSACTION, std::move(stmt.info)); - properties.return_type = StatementReturnType::NOTHING; - return result; -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -static unique_ptr SummarizeWrapUnnest(vector> &children, - const string &alias) { - auto list_function = make_uniq("list_value", std::move(children)); - vector> unnest_children; - unnest_children.push_back(std::move(list_function)); - auto unnest_function = make_uniq("unnest", std::move(unnest_children)); - unnest_function->alias = alias; - return std::move(unnest_function); -} - -static 
unique_ptr SummarizeCreateAggregate(const string &aggregate, string column_name) { - vector> children; - children.push_back(make_uniq(std::move(column_name))); - auto aggregate_function = make_uniq(aggregate, std::move(children)); - auto cast_function = make_uniq(LogicalType::VARCHAR, std::move(aggregate_function)); - return std::move(cast_function); -} - -static unique_ptr SummarizeCreateAggregate(const string &aggregate, string column_name, - const Value &modifier) { - vector> children; - children.push_back(make_uniq(std::move(column_name))); - children.push_back(make_uniq(modifier)); - auto aggregate_function = make_uniq(aggregate, std::move(children)); - auto cast_function = make_uniq(LogicalType::VARCHAR, std::move(aggregate_function)); - return std::move(cast_function); -} - -static unique_ptr SummarizeCreateCountStar() { - vector> children; - auto aggregate_function = make_uniq("count_star", std::move(children)); - return std::move(aggregate_function); -} - -static unique_ptr SummarizeCreateBinaryFunction(const string &op, unique_ptr left, - unique_ptr right) { - vector> children; - children.push_back(std::move(left)); - children.push_back(std::move(right)); - auto binary_function = make_uniq(op, std::move(children)); - return std::move(binary_function); -} - -static unique_ptr SummarizeCreateNullPercentage(string column_name) { - auto count_star = make_uniq(LogicalType::DOUBLE, SummarizeCreateCountStar()); - auto count = - make_uniq(LogicalType::DOUBLE, SummarizeCreateAggregate("count", std::move(column_name))); - auto null_percentage = SummarizeCreateBinaryFunction("/", std::move(count), std::move(count_star)); - auto negate_x = - SummarizeCreateBinaryFunction("-", make_uniq(Value::DOUBLE(1)), std::move(null_percentage)); - auto percentage_x = - SummarizeCreateBinaryFunction("*", std::move(negate_x), make_uniq(Value::DOUBLE(100))); - auto round_x = SummarizeCreateBinaryFunction("round", std::move(percentage_x), - make_uniq(Value::INTEGER(2))); - auto 
concat_x = - SummarizeCreateBinaryFunction("concat", std::move(round_x), make_uniq(Value("%"))); - - return concat_x; -} - -BoundStatement Binder::BindSummarize(ShowStatement &stmt) { - auto query_copy = stmt.info->query->Copy(); - - // we bind the plan once in a child-node to figure out the column names and column types - auto child_binder = Binder::CreateBinder(context); - auto plan = child_binder->Bind(*stmt.info->query); - D_ASSERT(plan.types.size() == plan.names.size()); - vector> name_children; - vector> type_children; - vector> min_children; - vector> max_children; - vector> unique_children; - vector> avg_children; - vector> std_children; - vector> q25_children; - vector> q50_children; - vector> q75_children; - vector> count_children; - vector> null_percentage_children; - auto select = make_uniq(); - select->node = std::move(query_copy); - for (idx_t i = 0; i < plan.names.size(); i++) { - name_children.push_back(make_uniq(Value(plan.names[i]))); - type_children.push_back(make_uniq(Value(plan.types[i].ToString()))); - min_children.push_back(SummarizeCreateAggregate("min", plan.names[i])); - max_children.push_back(SummarizeCreateAggregate("max", plan.names[i])); - unique_children.push_back(SummarizeCreateAggregate("approx_count_distinct", plan.names[i])); - if (plan.types[i].IsNumeric()) { - avg_children.push_back(SummarizeCreateAggregate("avg", plan.names[i])); - std_children.push_back(SummarizeCreateAggregate("stddev", plan.names[i])); - q25_children.push_back(SummarizeCreateAggregate("approx_quantile", plan.names[i], Value::FLOAT(0.25))); - q50_children.push_back(SummarizeCreateAggregate("approx_quantile", plan.names[i], Value::FLOAT(0.50))); - q75_children.push_back(SummarizeCreateAggregate("approx_quantile", plan.names[i], Value::FLOAT(0.75))); - } else { - avg_children.push_back(make_uniq(Value())); - std_children.push_back(make_uniq(Value())); - q25_children.push_back(make_uniq(Value())); - q50_children.push_back(make_uniq(Value())); - 
q75_children.push_back(make_uniq(Value())); - } - count_children.push_back(SummarizeCreateCountStar()); - null_percentage_children.push_back(SummarizeCreateNullPercentage(plan.names[i])); - } - auto subquery_ref = make_uniq(std::move(select), "summarize_tbl"); - subquery_ref->column_name_alias = plan.names; - - auto select_node = make_uniq(); - select_node->select_list.push_back(SummarizeWrapUnnest(name_children, "column_name")); - select_node->select_list.push_back(SummarizeWrapUnnest(type_children, "column_type")); - select_node->select_list.push_back(SummarizeWrapUnnest(min_children, "min")); - select_node->select_list.push_back(SummarizeWrapUnnest(max_children, "max")); - select_node->select_list.push_back(SummarizeWrapUnnest(unique_children, "approx_unique")); - select_node->select_list.push_back(SummarizeWrapUnnest(avg_children, "avg")); - select_node->select_list.push_back(SummarizeWrapUnnest(std_children, "std")); - select_node->select_list.push_back(SummarizeWrapUnnest(q25_children, "q25")); - select_node->select_list.push_back(SummarizeWrapUnnest(q50_children, "q50")); - select_node->select_list.push_back(SummarizeWrapUnnest(q75_children, "q75")); - select_node->select_list.push_back(SummarizeWrapUnnest(count_children, "count")); - select_node->select_list.push_back(SummarizeWrapUnnest(null_percentage_children, "null_percentage")); - select_node->from_table = std::move(subquery_ref); - - properties.return_type = StatementReturnType::QUERY_RESULT; - return Bind(*select_node); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -// This creates a LogicalProjection and moves 'root' into it as a child -// unless there are no expressions to project, in which case it just returns 'root' -unique_ptr Binder::BindUpdateSet(LogicalOperator &op, unique_ptr root, - UpdateSetInfo &set_info, TableCatalogEntry &table, - vector &columns) { - auto proj_index = GenerateTableIndex(); - - vector> projection_expressions; - 
D_ASSERT(set_info.columns.size() == set_info.expressions.size()); - for (idx_t i = 0; i < set_info.columns.size(); i++) { - auto &colname = set_info.columns[i]; - auto &expr = set_info.expressions[i]; - if (!table.ColumnExists(colname)) { - throw BinderException("Referenced update column %s not found in table!", colname); - } - auto &column = table.GetColumn(colname); - if (column.Generated()) { - throw BinderException("Cant update column \"%s\" because it is a generated column!", column.Name()); - } - if (std::find(columns.begin(), columns.end(), column.Physical()) != columns.end()) { - throw BinderException("Multiple assignments to same column \"%s\"", colname); - } - columns.push_back(column.Physical()); - if (expr->type == ExpressionType::VALUE_DEFAULT) { - op.expressions.push_back(make_uniq(column.Type())); - } else { - UpdateBinder binder(*this, context); - binder.target_type = column.Type(); - auto bound_expr = binder.Bind(expr); - PlanSubqueries(bound_expr, root); - - op.expressions.push_back(make_uniq( - bound_expr->return_type, ColumnBinding(proj_index, projection_expressions.size()))); - projection_expressions.push_back(std::move(bound_expr)); - } - } - if (op.type != LogicalOperatorType::LOGICAL_UPDATE && projection_expressions.empty()) { - return root; - } - // now create the projection - auto proj = make_uniq(proj_index, std::move(projection_expressions)); - proj->AddChild(std::move(root)); - return unique_ptr_cast(std::move(proj)); -} - -BoundStatement Binder::Bind(UpdateStatement &stmt) { - BoundStatement result; - unique_ptr root; - - // visit the table reference - auto bound_table = Bind(*stmt.table); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw BinderException("Can only update base table!"); - } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; - - // Add CTEs as bindable - AddCTEMap(stmt.cte_map); - - optional_ptr get; - if (stmt.from_table) { - auto from_binder = 
Binder::CreateBinder(context, this); - BoundJoinRef bound_crossproduct(JoinRefType::CROSS); - bound_crossproduct.left = std::move(bound_table); - bound_crossproduct.right = from_binder->Bind(*stmt.from_table); - root = CreatePlan(bound_crossproduct); - get = &root->children[0]->Cast(); - bind_context.AddContext(std::move(from_binder->bind_context)); - } else { - root = CreatePlan(*bound_table); - get = &root->Cast(); - } - - if (!table.temporary) { - // update of persistent table: not read only! - properties.modified_databases.insert(table.catalog.GetName()); - } - auto update = make_uniq(table); - - // set return_chunk boolean early because it needs uses update_is_del_and_insert logic - if (!stmt.returning_list.empty()) { - update->return_chunk = true; - } - // bind the default values - BindDefaultValues(table.GetColumns(), update->bound_defaults); - - // project any additional columns required for the condition/expressions - if (stmt.set_info->condition) { - WhereBinder binder(*this, context); - auto condition = binder.Bind(stmt.set_info->condition); - - PlanSubqueries(condition, root); - auto filter = make_uniq(std::move(condition)); - filter->AddChild(std::move(root)); - root = std::move(filter); - } - - D_ASSERT(stmt.set_info); - D_ASSERT(stmt.set_info->columns.size() == stmt.set_info->expressions.size()); - - auto proj_tmp = BindUpdateSet(*update, std::move(root), *stmt.set_info, table, update->columns); - D_ASSERT(proj_tmp->type == LogicalOperatorType::LOGICAL_PROJECTION); - auto proj = unique_ptr_cast(std::move(proj_tmp)); - - // bind any extra columns necessary for CHECK constraints or indexes - table.BindUpdateConstraints(*get, *proj, *update, context); - - // finally add the row id column to the projection list - proj->expressions.push_back(make_uniq( - LogicalType::ROW_TYPE, ColumnBinding(get->table_index, get->column_ids.size()))); - get->column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); - - // set the projection as child of the update node and finalize 
the result - update->AddChild(std::move(proj)); - - auto update_table_index = GenerateTableIndex(); - update->table_index = update_table_index; - if (!stmt.returning_list.empty()) { - unique_ptr update_as_logicaloperator = std::move(update); - - return BindReturning(std::move(stmt.returning_list), table, stmt.table->alias, update_table_index, - std::move(update_as_logicaloperator), std::move(result)); - } - - result.names = {"Count"}; - result.types = {LogicalType::BIGINT}; - result.plan = std::move(update); - properties.allow_stream_result = false; - properties.return_type = StatementReturnType::CHANGED_ROWS; - return result; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -BoundStatement Binder::Bind(VacuumStatement &stmt) { - BoundStatement result; - - unique_ptr root; - - if (stmt.info->has_table) { - D_ASSERT(!stmt.info->table); - D_ASSERT(stmt.info->column_id_map.empty()); - auto bound_table = Bind(*stmt.info->ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw InvalidInputException("Can only vacuum/analyze base tables!"); - } - auto ref = unique_ptr_cast(std::move(bound_table)); - auto &table = ref->table; - stmt.info->table = &table; - - auto &columns = stmt.info->columns; - vector> select_list; - if (columns.empty()) { - // Empty means ALL columns should be vacuumed/analyzed - auto &get = ref->get->Cast(); - columns.insert(columns.end(), get.names.begin(), get.names.end()); - } - - case_insensitive_set_t column_name_set; - vector non_generated_column_names; - for (auto &col_name : columns) { - if (column_name_set.count(col_name) > 0) { - throw BinderException("Vacuum the same column twice(same name in column name list)"); - } - column_name_set.insert(col_name); - if (!table.ColumnExists(col_name)) { - throw BinderException("Column with name \"%s\" does not exist", col_name); - } - auto &col = table.GetColumn(col_name); - // ignore generated column - if (col.Generated()) { - continue; - } - 
non_generated_column_names.push_back(col_name); - ColumnRefExpression colref(col_name, table.name); - auto result = bind_context.BindColumn(colref, 0); - if (result.HasError()) { - throw BinderException(result.error); - } - select_list.push_back(std::move(result.expression)); - } - stmt.info->columns = std::move(non_generated_column_names); - if (!select_list.empty()) { - auto table_scan = CreatePlan(*ref); - D_ASSERT(table_scan->type == LogicalOperatorType::LOGICAL_GET); - - auto &get = table_scan->Cast(); - - D_ASSERT(select_list.size() == get.column_ids.size()); - D_ASSERT(stmt.info->columns.size() == get.column_ids.size()); - for (idx_t i = 0; i < get.column_ids.size(); i++) { - stmt.info->column_id_map[i] = - table.GetColumns().LogicalToPhysical(LogicalIndex(get.column_ids[i])).index; - } - - auto projection = make_uniq(GenerateTableIndex(), std::move(select_list)); - projection->children.push_back(std::move(table_scan)); - - root = std::move(projection); - } else { - // eg. CREATE TABLE test (x AS (1)); - // ANALYZE test; - // Make it not a SINK so it doesn't have to do anything - stmt.info->has_table = false; - } - } - auto vacuum = make_uniq(LogicalOperatorType::LOGICAL_VACUUM, std::move(stmt.info)); - if (root) { - vacuum->children.push_back(std::move(root)); - } - - result.names = {"Success"}; - result.types = {LogicalType::BOOLEAN}; - result.plan = std::move(vacuum); - properties.return_type = StatementReturnType::NOTHING; - return result; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - -namespace duckdb { - -static bool TryLoadExtensionForReplacementScan(ClientContext &context, const string &table_name) { - auto lower_name = StringUtil::Lower(table_name); - auto &dbconfig = DBConfig::GetConfig(context); - - if (!dbconfig.options.autoload_known_extensions) { - return false; - } - - for (const auto &entry : EXTENSION_FILE_POSTFIXES) { - if (StringUtil::EndsWith(lower_name, entry.name)) { - ExtensionHelper::AutoLoadExtension(context, 
entry.extension); - return true; - } - } - - for (const auto &entry : EXTENSION_FILE_CONTAINS) { - if (StringUtil::Contains(lower_name, entry.name)) { - ExtensionHelper::AutoLoadExtension(context, entry.extension); - return true; - } - } - - return false; -} - -unique_ptr Binder::BindWithReplacementScan(ClientContext &context, const string &table_name, - BaseTableRef &ref) { - auto &config = DBConfig::GetConfig(context); - if (context.config.use_replacement_scans) { - for (auto &scan : config.replacement_scans) { - auto replacement_function = scan.function(context, table_name, scan.data.get()); - if (replacement_function) { - if (!ref.alias.empty()) { - // user-provided alias overrides the default alias - replacement_function->alias = ref.alias; - } else if (replacement_function->alias.empty()) { - // if the replacement scan itself did not provide an alias we use the table name - replacement_function->alias = ref.table_name; - } - if (replacement_function->type == TableReferenceType::TABLE_FUNCTION) { - auto &table_function = replacement_function->Cast(); - table_function.column_name_alias = ref.column_name_alias; - } else if (replacement_function->type == TableReferenceType::SUBQUERY) { - auto &subquery = replacement_function->Cast(); - subquery.column_name_alias = ref.column_name_alias; - } else { - throw InternalException("Replacement scan should return either a table function or a subquery"); - } - return Bind(*replacement_function); - } - } - } - - return nullptr; -} - -unique_ptr Binder::Bind(BaseTableRef &ref) { - QueryErrorContext error_context(root_statement, ref.query_location); - // CTEs and views are also referred to using BaseTableRefs, hence need to distinguish here - // check if the table name refers to a CTE - - // CTE name should never be qualified (i.e. 
schema_name should be empty) - optional_ptr found_cte = nullptr; - if (ref.schema_name.empty()) { - found_cte = FindCTE(ref.table_name, ref.table_name == alias); - } - - if (found_cte) { - // Check if there is a CTE binding in the BindContext - auto &cte = *found_cte; - auto ctebinding = bind_context.GetCTEBinding(ref.table_name); - if (!ctebinding) { - if (CTEIsAlreadyBound(cte)) { - throw BinderException( - "Circular reference to CTE \"%s\", There are two possible solutions. \n1. use WITH RECURSIVE to " - "use recursive CTEs. \n2. If " - "you want to use the TABLE name \"%s\" the same as the CTE name, please explicitly add " - "\"SCHEMA\" before table name. You can try \"main.%s\" (main is the duckdb default schema)", - ref.table_name, ref.table_name, ref.table_name); - } - // Move CTE to subquery and bind recursively - SubqueryRef subquery(unique_ptr_cast(cte.query->Copy())); - subquery.alias = ref.alias.empty() ? ref.table_name : ref.alias; - subquery.column_name_alias = cte.aliases; - for (idx_t i = 0; i < ref.column_name_alias.size(); i++) { - if (i < subquery.column_name_alias.size()) { - subquery.column_name_alias[i] = ref.column_name_alias[i]; - } else { - subquery.column_name_alias.push_back(ref.column_name_alias[i]); - } - } - return Bind(subquery, found_cte); - } else { - // There is a CTE binding in the BindContext. - // This can only be the case if there is a recursive CTE, - // or a materialized CTE present. - auto index = GenerateTableIndex(); - auto materialized = cte.materialized; - if (materialized == CTEMaterialize::CTE_MATERIALIZE_DEFAULT) { -#ifdef DUCKDB_ALTERNATIVE_VERIFY - materialized = CTEMaterialize::CTE_MATERIALIZE_ALWAYS; -#else - materialized = CTEMaterialize::CTE_MATERIALIZE_NEVER; -#endif - } - auto result = make_uniq(index, ctebinding->index, materialized); - auto b = ctebinding; - auto alias = ref.alias.empty() ? 
ref.table_name : ref.alias; - auto names = BindContext::AliasColumnNames(alias, b->names, ref.column_name_alias); - - bind_context.AddGenericBinding(index, alias, names, b->types); - // Update references to CTE - auto cteref = bind_context.cte_references[ref.table_name]; - (*cteref)++; - - result->types = b->types; - result->bound_columns = std::move(names); - return std::move(result); - } - } - // not a CTE - // extract a table or view from the catalog - BindSchemaOrCatalog(ref.catalog_name, ref.schema_name); - auto table_or_view = Catalog::GetEntry(context, CatalogType::TABLE_ENTRY, ref.catalog_name, ref.schema_name, - ref.table_name, OnEntryNotFound::RETURN_NULL, error_context); - // we still didn't find the table - if (GetBindingMode() == BindingMode::EXTRACT_NAMES) { - if (!table_or_view || table_or_view->type == CatalogType::TABLE_ENTRY) { - // if we are in EXTRACT_NAMES, we create a dummy table ref - AddTableName(ref.table_name); - - // add a bind context entry - auto table_index = GenerateTableIndex(); - auto alias = ref.alias.empty() ? ref.table_name : ref.alias; - vector types {LogicalType::INTEGER}; - vector names {"__dummy_col" + to_string(table_index)}; - bind_context.AddGenericBinding(table_index, alias, names, types); - return make_uniq_base(table_index); - } - } - if (!table_or_view) { - string table_name = ref.catalog_name; - if (!ref.schema_name.empty()) { - table_name += (!table_name.empty() ? "." : "") + ref.schema_name; - } - table_name += (!table_name.empty() ? "." 
: "") + ref.table_name; - // table could not be found: try to bind a replacement scan - // Try replacement scan bind - auto replacement_scan_bind_result = BindWithReplacementScan(context, table_name, ref); - if (replacement_scan_bind_result) { - return replacement_scan_bind_result; - } - - // Try autoloading an extension, then retry the replacement scan bind - auto extension_loaded = TryLoadExtensionForReplacementScan(context, table_name); - if (extension_loaded) { - replacement_scan_bind_result = BindWithReplacementScan(context, table_name, ref); - if (replacement_scan_bind_result) { - return replacement_scan_bind_result; - } - } - - // could not find an alternative: bind again to get the error - table_or_view = Catalog::GetEntry(context, CatalogType::TABLE_ENTRY, ref.catalog_name, ref.schema_name, - ref.table_name, OnEntryNotFound::THROW_EXCEPTION, error_context); - } - switch (table_or_view->type) { - case CatalogType::TABLE_ENTRY: { - // base table: create the BoundBaseTableRef node - auto table_index = GenerateTableIndex(); - auto &table = table_or_view->Cast(); - - unique_ptr bind_data; - auto scan_function = table.GetScanFunction(context, bind_data); - auto alias = ref.alias.empty() ? 
ref.table_name : ref.alias; - // TODO: bundle the type and name vector in a struct (e.g PackedColumnMetadata) - vector table_types; - vector table_names; - vector table_categories; - - vector return_types; - vector return_names; - for (auto &col : table.GetColumns().Logical()) { - table_types.push_back(col.Type()); - table_names.push_back(col.Name()); - return_types.push_back(col.Type()); - return_names.push_back(col.Name()); - } - table_names = BindContext::AliasColumnNames(alias, table_names, ref.column_name_alias); - - auto logical_get = make_uniq(table_index, scan_function, std::move(bind_data), - std::move(return_types), std::move(return_names)); - bind_context.AddBaseTable(table_index, alias, table_names, table_types, logical_get->column_ids, - logical_get->GetTable().get()); - return make_uniq_base(table, std::move(logical_get)); - } - case CatalogType::VIEW_ENTRY: { - // the node is a view: get the query that the view represents - auto &view_catalog_entry = table_or_view->Cast(); - // We need to use a new binder for the view that doesn't reference any CTEs - // defined for this binder so there are no collisions between the CTEs defined - // for the view and for the current query - bool inherit_ctes = false; - auto view_binder = Binder::CreateBinder(context, this, inherit_ctes); - view_binder->can_contain_nulls = true; - SubqueryRef subquery(unique_ptr_cast(view_catalog_entry.query->Copy())); - subquery.alias = ref.alias.empty() ? 
ref.table_name : ref.alias; - subquery.column_name_alias = - BindContext::AliasColumnNames(subquery.alias, view_catalog_entry.aliases, ref.column_name_alias); - // bind the child subquery - view_binder->AddBoundView(view_catalog_entry); - auto bound_child = view_binder->Bind(subquery); - if (!view_binder->correlated_columns.empty()) { - throw BinderException("Contents of view were altered - view bound correlated columns"); - } - - D_ASSERT(bound_child->type == TableReferenceType::SUBQUERY); - // verify that the types and names match up with the expected types and names - auto &bound_subquery = bound_child->Cast(); - if (GetBindingMode() != BindingMode::EXTRACT_NAMES && - bound_subquery.subquery->types != view_catalog_entry.types) { - throw BinderException("Contents of view were altered: types don't match!"); - } - bind_context.AddView(bound_subquery.subquery->GetRootIndex(), subquery.alias, subquery, - *bound_subquery.subquery, &view_catalog_entry); - return bound_child; - } - default: - throw InternalException("Catalog entry type"); - } -} -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Binder::Bind(EmptyTableRef &ref) { - return make_uniq(GenerateTableIndex()); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -unique_ptr Binder::Bind(ExpressionListRef &expr) { - auto result = make_uniq(); - result->types = expr.expected_types; - result->names = expr.expected_names; - // bind value list - InsertBinder binder(*this, context); - binder.target_type = LogicalType(LogicalTypeId::INVALID); - for (idx_t list_idx = 0; list_idx < expr.values.size(); list_idx++) { - auto &expression_list = expr.values[list_idx]; - if (result->names.empty()) { - // no names provided, generate them - for (idx_t val_idx = 0; val_idx < expression_list.size(); val_idx++) { - result->names.push_back("col" + to_string(val_idx)); - } - } - - vector> list; - for (idx_t val_idx = 0; val_idx < expression_list.size(); val_idx++) { - if (!result->types.empty()) { - 
D_ASSERT(result->types.size() == expression_list.size()); - binder.target_type = result->types[val_idx]; - } - auto expr = binder.Bind(expression_list[val_idx]); - list.push_back(std::move(expr)); - } - result->values.push_back(std::move(list)); - } - if (result->types.empty() && !expr.values.empty()) { - // there are no types specified - // we have to figure out the result types - // for each column, we iterate over all of the expressions and select the max logical type - // we initialize all types to SQLNULL - result->types.resize(expr.values[0].size(), LogicalType::SQLNULL); - // now loop over the lists and select the max logical type - for (idx_t list_idx = 0; list_idx < result->values.size(); list_idx++) { - auto &list = result->values[list_idx]; - for (idx_t val_idx = 0; val_idx < list.size(); val_idx++) { - result->types[val_idx] = - LogicalType::MaxLogicalType(result->types[val_idx], list[val_idx]->return_type); - } - } - // finally do another loop over the expressions and add casts where required - for (idx_t list_idx = 0; list_idx < result->values.size(); list_idx++) { - auto &list = result->values[list_idx]; - for (idx_t val_idx = 0; val_idx < list.size(); val_idx++) { - list[val_idx] = - BoundCastExpression::AddCastToType(context, std::move(list[val_idx]), result->types[val_idx]); - } - } - } - result->bind_index = GenerateTableIndex(); - bind_context.AddGenericBinding(result->bind_index, expr.alias, result->names, result->types); - return std::move(result); -} - -} // namespace duckdb - - - - - - - - - - - - - - - -namespace duckdb { - -static unique_ptr BindColumn(Binder &binder, ClientContext &context, const string &alias, - const string &column_name) { - auto expr = make_uniq_base(column_name, alias); - ExpressionBinder expr_binder(binder, context); - auto result = expr_binder.Bind(expr); - return make_uniq(std::move(result)); -} - -static unique_ptr AddCondition(ClientContext &context, Binder &left_binder, Binder &right_binder, - const string 
&left_alias, const string &right_alias, - const string &column_name, ExpressionType type) { - ExpressionBinder expr_binder(left_binder, context); - auto left = BindColumn(left_binder, context, left_alias, column_name); - auto right = BindColumn(right_binder, context, right_alias, column_name); - return make_uniq(type, std::move(left), std::move(right)); -} - -bool Binder::TryFindBinding(const string &using_column, const string &join_side, string &result) { - // for each using column, get the matching binding - auto bindings = bind_context.GetMatchingBindings(using_column); - if (bindings.empty()) { - return false; - } - // find the join binding - for (auto &binding : bindings) { - if (!result.empty()) { - string error = "Column name \""; - error += using_column; - error += "\" is ambiguous: it exists more than once on "; - error += join_side; - error += " side of join.\nCandidates:"; - for (auto &binding : bindings) { - error += "\n\t"; - error += binding; - error += "."; - error += bind_context.GetActualColumnName(binding, using_column); - } - throw BinderException(error); - } else { - result = binding; - } - } - return true; -} - -string Binder::FindBinding(const string &using_column, const string &join_side) { - string result; - if (!TryFindBinding(using_column, join_side, result)) { - throw BinderException("Column \"%s\" does not exist on %s side of join!", using_column, join_side); - } - return result; -} - -static void AddUsingBindings(UsingColumnSet &set, optional_ptr input_set, const string &input_binding) { - if (input_set) { - for (auto &entry : input_set->bindings) { - set.bindings.insert(entry); - } - } else { - set.bindings.insert(input_binding); - } -} - -static void SetPrimaryBinding(UsingColumnSet &set, JoinType join_type, const string &left_binding, - const string &right_binding) { - switch (join_type) { - case JoinType::LEFT: - case JoinType::INNER: - case JoinType::SEMI: - case JoinType::ANTI: - set.primary_binding = left_binding; - break; - case 
JoinType::RIGHT: - set.primary_binding = right_binding; - break; - default: - break; - } -} - -string Binder::RetrieveUsingBinding(Binder ¤t_binder, optional_ptr current_set, - const string &using_column, const string &join_side) { - string binding; - if (!current_set) { - binding = current_binder.FindBinding(using_column, join_side); - } else { - binding = current_set->primary_binding; - } - return binding; -} - -static vector RemoveDuplicateUsingColumns(const vector &using_columns) { - vector result; - case_insensitive_set_t handled_columns; - for (auto &using_column : using_columns) { - if (handled_columns.find(using_column) == handled_columns.end()) { - handled_columns.insert(using_column); - result.push_back(using_column); - } - } - return result; -} - -unique_ptr Binder::Bind(JoinRef &ref) { - auto result = make_uniq(ref.ref_type); - result->left_binder = Binder::CreateBinder(context, this); - result->right_binder = Binder::CreateBinder(context, this); - auto &left_binder = *result->left_binder; - auto &right_binder = *result->right_binder; - - result->type = ref.type; - result->left = left_binder.Bind(*ref.left); - { - LateralBinder binder(left_binder, context); - result->right = right_binder.Bind(*ref.right); - bool is_lateral = false; - // Store the correlated columns in the right binder in bound ref for planning of LATERALs - // Ignore the correlated columns in the left binder, flattening handles those correlations - result->correlated_columns = right_binder.correlated_columns; - // Find correlations for the current join - for (auto &cor_col : result->correlated_columns) { - if (cor_col.depth == 1) { - // Depth 1 indicates columns binding from the left indicating a lateral join - is_lateral = true; - break; - } - } - result->lateral = is_lateral; - if (result->lateral) { - // lateral join: can only be an INNER or LEFT join - if (ref.type != JoinType::INNER && ref.type != JoinType::LEFT) { - throw BinderException("The combining JOIN type must be INNER or 
LEFT for a LATERAL reference"); - } - } - } - - vector> extra_conditions; - vector extra_using_columns; - switch (ref.ref_type) { - case JoinRefType::NATURAL: { - // natural join, figure out which column names are present in both sides of the join - // first bind the left hand side and get a list of all the tables and column names - case_insensitive_set_t lhs_columns; - auto &lhs_binding_list = left_binder.bind_context.GetBindingsList(); - for (auto &binding : lhs_binding_list) { - for (auto &column_name : binding.get().names) { - lhs_columns.insert(column_name); - } - } - // now bind the rhs - for (auto &column_name : lhs_columns) { - auto right_using_binding = right_binder.bind_context.GetUsingBinding(column_name); - - string right_binding; - // loop over the set of lhs columns, and figure out if there is a table in the rhs with the same name - if (!right_using_binding) { - if (!right_binder.TryFindBinding(column_name, "right", right_binding)) { - // no match found for this column on the rhs: skip - continue; - } - } - extra_using_columns.push_back(column_name); - } - if (extra_using_columns.empty()) { - // no matching bindings found in natural join: throw an exception - string error_msg = "No columns found to join on in NATURAL JOIN.\n"; - error_msg += "Use CROSS JOIN if you intended for this to be a cross-product."; - // gather all left/right candidates - string left_candidates, right_candidates; - auto &rhs_binding_list = right_binder.bind_context.GetBindingsList(); - for (auto &binding_ref : lhs_binding_list) { - auto &binding = binding_ref.get(); - for (auto &column_name : binding.names) { - if (!left_candidates.empty()) { - left_candidates += ", "; - } - left_candidates += binding.alias + "." + column_name; - } - } - for (auto &binding_ref : rhs_binding_list) { - auto &binding = binding_ref.get(); - for (auto &column_name : binding.names) { - if (!right_candidates.empty()) { - right_candidates += ", "; - } - right_candidates += binding.alias + "." 
+ column_name; - } - } - error_msg += "\n Left candidates: " + left_candidates; - error_msg += "\n Right candidates: " + right_candidates; - throw BinderException(FormatError(ref, error_msg)); - } - break; - } - case JoinRefType::REGULAR: - case JoinRefType::ASOF: - if (!ref.using_columns.empty()) { - // USING columns - D_ASSERT(!result->condition); - extra_using_columns = ref.using_columns; - } - break; - - case JoinRefType::CROSS: - case JoinRefType::POSITIONAL: - case JoinRefType::DEPENDENT: - break; - } - extra_using_columns = RemoveDuplicateUsingColumns(extra_using_columns); - - if (!extra_using_columns.empty()) { - vector> left_using_bindings; - vector> right_using_bindings; - for (idx_t i = 0; i < extra_using_columns.size(); i++) { - auto &using_column = extra_using_columns[i]; - // we check if there is ALREADY a using column of the same name in the left and right set - // this can happen if we chain USING clauses - // e.g. x JOIN y USING (c) JOIN z USING (c) - auto left_using_binding = left_binder.bind_context.GetUsingBinding(using_column); - auto right_using_binding = right_binder.bind_context.GetUsingBinding(using_column); - if (!left_using_binding) { - left_binder.bind_context.GetMatchingBinding(using_column); - } - if (!right_using_binding) { - right_binder.bind_context.GetMatchingBinding(using_column); - } - left_using_bindings.push_back(left_using_binding); - right_using_bindings.push_back(right_using_binding); - } - - for (idx_t i = 0; i < extra_using_columns.size(); i++) { - auto &using_column = extra_using_columns[i]; - string left_binding; - string right_binding; - - auto set = make_uniq(); - auto &left_using_binding = left_using_bindings[i]; - auto &right_using_binding = right_using_bindings[i]; - left_binding = RetrieveUsingBinding(left_binder, left_using_binding, using_column, "left"); - right_binding = RetrieveUsingBinding(right_binder, right_using_binding, using_column, "right"); - - // Last column of ASOF JOIN ... 
USING is >= - const auto type = (ref.ref_type == JoinRefType::ASOF && i == extra_using_columns.size() - 1) - ? ExpressionType::COMPARE_GREATERTHANOREQUALTO - : ExpressionType::COMPARE_EQUAL; - - extra_conditions.push_back( - AddCondition(context, left_binder, right_binder, left_binding, right_binding, using_column, type)); - - AddUsingBindings(*set, left_using_binding, left_binding); - AddUsingBindings(*set, right_using_binding, right_binding); - SetPrimaryBinding(*set, ref.type, left_binding, right_binding); - bind_context.TransferUsingBinding(left_binder.bind_context, left_using_binding, *set, left_binding, - using_column); - bind_context.TransferUsingBinding(right_binder.bind_context, right_using_binding, *set, right_binding, - using_column); - AddUsingBindingSet(std::move(set)); - } - } - - auto right_bindings_list_copy = right_binder.bind_context.GetBindingsList(); - - bind_context.AddContext(std::move(left_binder.bind_context)); - bind_context.AddContext(std::move(right_binder.bind_context)); - - // Update the correlated columns for the parent binder - // For the left binder, depth >= 1 indicates correlations from the parent binder - for (const auto &col : left_binder.correlated_columns) { - if (col.depth >= 1) { - AddCorrelatedColumn(col); - } - } - // For the right binder, depth > 1 indicates correlations from the parent binder - // (depth = 1 indicates correlations from the left side of the join) - for (auto col : right_binder.correlated_columns) { - if (col.depth > 1) { - // Decrement the depth to account for the effect of the lateral binder - col.depth--; - AddCorrelatedColumn(col); - } - } - - for (auto &condition : extra_conditions) { - if (ref.condition) { - ref.condition = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(ref.condition), - std::move(condition)); - } else { - ref.condition = std::move(condition); - } - } - if (ref.condition) { - WhereBinder binder(*this, context); - result->condition = binder.Bind(ref.condition); - } - - if 
(result->type == JoinType::SEMI || result->type == JoinType::ANTI) { - bind_context.RemoveContext(right_bindings_list_copy); - } - - return std::move(result); -} - -} // namespace duckdb - - -namespace duckdb { - -void Binder::BindNamedParameters(named_parameter_type_map_t &types, named_parameter_map_t &values, - QueryErrorContext &error_context, string &func_name) { - for (auto &kv : values) { - auto entry = types.find(kv.first); - if (entry == types.end()) { - // create a list of named parameters for the error - string named_params; - for (auto &kv : types) { - named_params += " "; - named_params += kv.first; - named_params += " "; - named_params += kv.second.ToString(); - named_params += "\n"; - } - string error_msg; - if (named_params.empty()) { - error_msg = "Function does not accept any named parameters."; - } else { - error_msg = "Candidates:\n" + named_params; - } - throw BinderException(error_context.FormatError("Invalid named parameter \"%s\" for function %s\n%s", - kv.first, func_name, error_msg)); - } - if (entry->second.id() != LogicalTypeId::ANY) { - kv.second = kv.second.DefaultCastAs(entry->second); - } - } -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -static void ConstructPivots(PivotRef &ref, vector &pivot_values, idx_t pivot_idx = 0, - const PivotValueElement ¤t_value = PivotValueElement()) { - auto &pivot = ref.pivots[pivot_idx]; - bool last_pivot = pivot_idx + 1 == ref.pivots.size(); - for (auto &entry : pivot.entries) { - PivotValueElement new_value = current_value; - string name = entry.alias; - D_ASSERT(entry.values.size() == pivot.pivot_expressions.size()); - for (idx_t v = 0; v < entry.values.size(); v++) { - auto &value = entry.values[v]; - new_value.values.push_back(value); - if (entry.alias.empty()) { - if (name.empty()) { - name = value.ToString(); - } else { - name += "_" + value.ToString(); - } - } - } - if (!current_value.name.empty()) { - new_value.name = current_value.name + "_" + 
name; - } else { - new_value.name = std::move(name); - } - if (last_pivot) { - pivot_values.push_back(std::move(new_value)); - } else { - // need to recurse - ConstructPivots(ref, pivot_values, pivot_idx + 1, new_value); - } - } -} - -static void ExtractPivotExpressions(ParsedExpression &expr, case_insensitive_set_t &handled_columns) { - if (expr.type == ExpressionType::COLUMN_REF) { - auto &child_colref = expr.Cast(); - if (child_colref.IsQualified()) { - throw BinderException("PIVOT expression cannot contain qualified columns"); - } - handled_columns.insert(child_colref.GetColumnName()); - } - ParsedExpressionIterator::EnumerateChildren( - expr, [&](ParsedExpression &child) { ExtractPivotExpressions(child, handled_columns); }); -} - -static unique_ptr ConstructInitialGrouping(PivotRef &ref, vector> all_columns, - const case_insensitive_set_t &handled_columns) { - auto subquery = make_uniq(); - subquery->from_table = std::move(ref.source); - if (ref.groups.empty()) { - // if rows are not specified any columns that are not pivoted/aggregated on are added to the GROUP BY clause - for (auto &entry : all_columns) { - if (entry->type != ExpressionType::COLUMN_REF) { - throw InternalException("Unexpected child of pivot source - not a ColumnRef"); - } - auto &columnref = entry->Cast(); - if (handled_columns.find(columnref.GetColumnName()) == handled_columns.end()) { - // not handled - add to grouping set - subquery->groups.group_expressions.push_back( - make_uniq(Value::INTEGER(subquery->select_list.size() + 1))); - subquery->select_list.push_back(make_uniq(columnref.GetColumnName())); - } - } - } else { - // if rows are specified only the columns mentioned in rows are added as groups - for (auto &row : ref.groups) { - subquery->groups.group_expressions.push_back( - make_uniq(Value::INTEGER(subquery->select_list.size() + 1))); - subquery->select_list.push_back(make_uniq(row)); - } - } - return subquery; -} - -static unique_ptr PivotFilteredAggregate(PivotRef &ref, 
vector> all_columns, - const case_insensitive_set_t &handled_columns, - vector pivot_values) { - auto subquery = ConstructInitialGrouping(ref, std::move(all_columns), handled_columns); - - // push the filtered aggregates - for (auto &pivot_value : pivot_values) { - unique_ptr filter; - idx_t pivot_value_idx = 0; - for (auto &pivot_column : ref.pivots) { - for (auto &pivot_expr : pivot_column.pivot_expressions) { - auto column_ref = make_uniq(LogicalType::VARCHAR, pivot_expr->Copy()); - auto constant_value = make_uniq(pivot_value.values[pivot_value_idx++]); - auto comp_expr = make_uniq(ExpressionType::COMPARE_NOT_DISTINCT_FROM, - std::move(column_ref), std::move(constant_value)); - if (filter) { - filter = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(filter), - std::move(comp_expr)); - } else { - filter = std::move(comp_expr); - } - } - } - for (auto &aggregate : ref.aggregates) { - auto copied_aggr = aggregate->Copy(); - auto &aggr = copied_aggr->Cast(); - aggr.filter = filter->Copy(); - auto &aggr_name = aggregate->alias; - auto name = pivot_value.name; - if (ref.aggregates.size() > 1 || !aggr_name.empty()) { - // if there are multiple aggregates specified we add the name of the aggregate as well - name += "_" + (aggr_name.empty() ? 
aggregate->GetName() : aggr_name); - } - aggr.alias = name; - subquery->select_list.push_back(std::move(copied_aggr)); - } - } - return subquery; -} - -struct PivotBindState { - vector internal_group_names; - vector group_names; - vector aggregate_names; - vector internal_aggregate_names; -}; - -static unique_ptr PivotInitialAggregate(PivotBindState &bind_state, PivotRef &ref, - vector> all_columns, - const case_insensitive_set_t &handled_columns) { - auto subquery_stage1 = ConstructInitialGrouping(ref, std::move(all_columns), handled_columns); - - idx_t group_count = 0; - for (auto &expr : subquery_stage1->select_list) { - bind_state.group_names.push_back(expr->GetName()); - if (expr->alias.empty()) { - expr->alias = "__internal_pivot_group" + std::to_string(++group_count); - } - bind_state.internal_group_names.push_back(expr->alias); - } - // group by all of the pivot values - idx_t pivot_count = 0; - for (auto &pivot_column : ref.pivots) { - for (auto &pivot_expr : pivot_column.pivot_expressions) { - if (pivot_expr->alias.empty()) { - pivot_expr->alias = "__internal_pivot_ref" + std::to_string(++pivot_count); - } - auto pivot_alias = pivot_expr->alias; - subquery_stage1->groups.group_expressions.push_back( - make_uniq(Value::INTEGER(subquery_stage1->select_list.size() + 1))); - subquery_stage1->select_list.push_back(std::move(pivot_expr)); - pivot_expr = make_uniq(std::move(pivot_alias)); - } - } - idx_t aggregate_count = 0; - // finally add the aggregates - for (auto &aggregate : ref.aggregates) { - auto aggregate_alias = "__internal_pivot_aggregate" + std::to_string(++aggregate_count); - bind_state.aggregate_names.push_back(aggregate->alias); - bind_state.internal_aggregate_names.push_back(aggregate_alias); - aggregate->alias = std::move(aggregate_alias); - subquery_stage1->select_list.push_back(std::move(aggregate)); - } - return subquery_stage1; -} - -unique_ptr ConstructPivotExpression(unique_ptr pivot_expr) { - auto cast = make_uniq(LogicalType::VARCHAR, 
std::move(pivot_expr)); - vector> coalesce_children; - coalesce_children.push_back(std::move(cast)); - coalesce_children.push_back(make_uniq(Value("NULL"))); - auto coalesce = make_uniq(ExpressionType::OPERATOR_COALESCE, std::move(coalesce_children)); - return std::move(coalesce); -} - -static unique_ptr PivotListAggregate(PivotBindState &bind_state, PivotRef &ref, - unique_ptr subquery_stage1) { - auto subquery_stage2 = make_uniq(); - // wrap the subquery of stage 1 - auto subquery_select = make_uniq(); - subquery_select->node = std::move(subquery_stage1); - auto subquery_ref = make_uniq(std::move(subquery_select)); - - // add all of the groups - for (idx_t gr = 0; gr < bind_state.internal_group_names.size(); gr++) { - subquery_stage2->groups.group_expressions.push_back( - make_uniq(Value::INTEGER(subquery_stage2->select_list.size() + 1))); - auto group_reference = make_uniq(bind_state.internal_group_names[gr]); - group_reference->alias = bind_state.internal_group_names[gr]; - subquery_stage2->select_list.push_back(std::move(group_reference)); - } - - // construct the list aggregates - for (idx_t aggr = 0; aggr < bind_state.internal_aggregate_names.size(); aggr++) { - auto colref = make_uniq(bind_state.internal_aggregate_names[aggr]); - vector> list_children; - list_children.push_back(std::move(colref)); - auto aggregate = make_uniq("list", std::move(list_children)); - aggregate->alias = bind_state.internal_aggregate_names[aggr]; - subquery_stage2->select_list.push_back(std::move(aggregate)); - } - // construct the pivot list - auto pivot_name = "__internal_pivot_name"; - unique_ptr expr; - for (auto &pivot : ref.pivots) { - for (auto &pivot_expr : pivot.pivot_expressions) { - // coalesce(pivot::VARCHAR, 'NULL') - auto coalesce = ConstructPivotExpression(std::move(pivot_expr)); - if (!expr) { - expr = std::move(coalesce); - } else { - // string concat - vector> concat_children; - concat_children.push_back(std::move(expr)); - 
concat_children.push_back(make_uniq(Value("_"))); - concat_children.push_back(std::move(coalesce)); - auto concat = make_uniq("concat", std::move(concat_children)); - expr = std::move(concat); - } - } - } - // list(coalesce) - vector> list_children; - list_children.push_back(std::move(expr)); - auto aggregate = make_uniq("list", std::move(list_children)); - - aggregate->alias = pivot_name; - subquery_stage2->select_list.push_back(std::move(aggregate)); - - subquery_stage2->from_table = std::move(subquery_ref); - return subquery_stage2; -} - -static unique_ptr PivotFinalOperator(PivotBindState &bind_state, PivotRef &ref, - unique_ptr subquery, - vector pivot_values) { - auto final_pivot_operator = make_uniq(); - // wrap the subquery of stage 1 - auto subquery_select = make_uniq(); - subquery_select->node = std::move(subquery); - auto subquery_ref = make_uniq(std::move(subquery_select)); - - auto bound_pivot = make_uniq(); - bound_pivot->bound_pivot_values = std::move(pivot_values); - bound_pivot->bound_group_names = std::move(bind_state.group_names); - bound_pivot->bound_aggregate_names = std::move(bind_state.aggregate_names); - bound_pivot->source = std::move(subquery_ref); - - final_pivot_operator->select_list.push_back(make_uniq()); - final_pivot_operator->from_table = std::move(bound_pivot); - return final_pivot_operator; -} - -void ExtractPivotAggregates(BoundTableRef &node, vector> &aggregates) { - if (node.type != TableReferenceType::SUBQUERY) { - throw InternalException("Pivot - Expected a subquery"); - } - auto &subq = node.Cast(); - if (subq.subquery->type != QueryNodeType::SELECT_NODE) { - throw InternalException("Pivot - Expected a select node"); - } - auto &select = subq.subquery->Cast(); - if (select.from_table->type != TableReferenceType::SUBQUERY) { - throw InternalException("Pivot - Expected another subquery"); - } - auto &subq2 = select.from_table->Cast(); - if (subq2.subquery->type != QueryNodeType::SELECT_NODE) { - throw InternalException("Pivot 
- Expected another select node"); - } - auto &select2 = subq2.subquery->Cast(); - for (auto &aggr : select2.aggregates) { - aggregates.push_back(aggr->Copy()); - } -} - -unique_ptr Binder::BindBoundPivot(PivotRef &ref) { - // bind the child table in a child binder - auto result = make_uniq(); - result->bind_index = GenerateTableIndex(); - result->child_binder = Binder::CreateBinder(context, this); - result->child = result->child_binder->Bind(*ref.source); - - auto &aggregates = result->bound_pivot.aggregates; - ExtractPivotAggregates(*result->child, aggregates); - if (aggregates.size() != ref.bound_aggregate_names.size()) { - throw InternalException("Pivot aggregate count mismatch (expected %llu, found %llu)", - ref.bound_aggregate_names.size(), aggregates.size()); - } - - vector child_names; - vector child_types; - result->child_binder->bind_context.GetTypesAndNames(child_names, child_types); - - vector names; - vector types; - // emit the groups - for (idx_t i = 0; i < ref.bound_group_names.size(); i++) { - names.push_back(ref.bound_group_names[i]); - types.push_back(child_types[i]); - } - // emit the pivot columns - for (auto &pivot_value : ref.bound_pivot_values) { - for (idx_t aggr_idx = 0; aggr_idx < ref.bound_aggregate_names.size(); aggr_idx++) { - auto &aggr = aggregates[aggr_idx]; - auto &aggr_name = ref.bound_aggregate_names[aggr_idx]; - auto name = pivot_value.name; - if (aggregates.size() > 1 || !aggr_name.empty()) { - // if there are multiple aggregates specified we add the name of the aggregate as well - name += "_" + (aggr_name.empty() ? 
aggr->GetName() : aggr_name); - } - string pivot_str; - for (auto &value : pivot_value.values) { - auto str = value.ToString(); - if (pivot_str.empty()) { - pivot_str = std::move(str); - } else { - pivot_str += "_" + str; - } - } - result->bound_pivot.pivot_values.push_back(std::move(pivot_str)); - names.push_back(std::move(name)); - types.push_back(aggr->return_type); - } - } - result->bound_pivot.group_count = ref.bound_group_names.size(); - result->bound_pivot.types = types; - auto subquery_alias = ref.alias.empty() ? "__unnamed_pivot" : ref.alias; - bind_context.AddGenericBinding(result->bind_index, subquery_alias, names, types); - MoveCorrelatedExpressions(*result->child_binder); - return std::move(result); -} - -unique_ptr Binder::BindPivot(PivotRef &ref, vector> all_columns) { - // keep track of the columns by which we pivot/aggregate - // any columns which are not pivoted/aggregated on are added to the GROUP BY clause - case_insensitive_set_t handled_columns; - // parse the aggregate, and extract the referenced columns from the aggregate - for (auto &aggr : ref.aggregates) { - if (aggr->type != ExpressionType::FUNCTION) { - throw BinderException(FormatError(*aggr, "Pivot expression must be an aggregate")); - } - if (aggr->HasSubquery()) { - throw BinderException(FormatError(*aggr, "Pivot expression cannot contain subqueries")); - } - if (aggr->IsWindow()) { - throw BinderException(FormatError(*aggr, "Pivot expression cannot contain window functions")); - } - // bind the function as an aggregate to ensure it is an aggregate and not a scalar function - auto &aggr_function = aggr->Cast(); - (void)Catalog::GetEntry(context, aggr_function.catalog, aggr_function.schema, - aggr_function.function_name); - ExtractPivotExpressions(*aggr, handled_columns); - } - - // first add all pivots to the set of handled columns, and check for duplicates - idx_t total_pivots = 1; - for (auto &pivot : ref.pivots) { - if (!pivot.pivot_enum.empty()) { - auto type = 
Catalog::GetType(context, INVALID_CATALOG, INVALID_SCHEMA, pivot.pivot_enum); - if (type.id() != LogicalTypeId::ENUM) { - throw BinderException( - FormatError(ref, StringUtil::Format("Pivot must reference an ENUM type: \"%s\" is of type \"%s\"", - pivot.pivot_enum, type.ToString()))); - } - auto enum_size = EnumType::GetSize(type); - for (idx_t i = 0; i < enum_size; i++) { - auto enum_value = EnumType::GetValue(Value::ENUM(i, type)); - PivotColumnEntry entry; - entry.values.emplace_back(enum_value); - entry.alias = std::move(enum_value); - pivot.entries.push_back(std::move(entry)); - } - } - total_pivots *= pivot.entries.size(); - // add the pivoted column to the columns that have been handled - for (auto &pivot_name : pivot.pivot_expressions) { - ExtractPivotExpressions(*pivot_name, handled_columns); - } - value_set_t pivots; - for (auto &entry : pivot.entries) { - D_ASSERT(!entry.star_expr); - Value val; - if (entry.values.size() == 1) { - val = entry.values[0]; - } else { - val = Value::LIST(LogicalType::VARCHAR, entry.values); - } - if (pivots.find(val) != pivots.end()) { - throw BinderException(FormatError( - ref, StringUtil::Format("The value \"%s\" was specified multiple times in the IN clause", - val.ToString()))); - } - if (entry.values.size() != pivot.pivot_expressions.size()) { - throw ParserException("PIVOT IN list - inconsistent amount of rows - expected %d but got %d", - pivot.pivot_expressions.size(), entry.values.size()); - } - pivots.insert(val); - } - } - auto &client_config = ClientConfig::GetConfig(context); - auto pivot_limit = client_config.pivot_limit; - if (total_pivots >= pivot_limit) { - throw BinderException("Pivot column limit of %llu exceeded. Use SET pivot_limit=X to increase the limit.", - client_config.pivot_limit); - } - - // construct the required pivot values recursively - vector pivot_values; - ConstructPivots(ref, pivot_values); - - unique_ptr pivot_node; - // pivots have three components - // - the pivots (i.e. 
future column names) - // - the groups (i.e. the future row names - // - the aggregates (i.e. the values of the pivot columns) - - // we have two ways of executing a pivot statement - // (1) the straightforward manner of filtered aggregates SUM(..) FILTER (pivot_value=X) - // (2) computing the aggregates once, then using LIST to group the aggregates together with the PIVOT operator - // -> filtered aggregates are faster when there are FEW pivot values - // -> LIST is faster when there are MANY pivot values - // we switch dynamically based on the number of pivots to compute - if (pivot_values.size() <= client_config.pivot_filter_threshold) { - // use a set of filtered aggregates - pivot_node = PivotFilteredAggregate(ref, std::move(all_columns), handled_columns, std::move(pivot_values)); - } else { - // executing a pivot statement happens in three stages - // 1) execute the query "SELECT {groups}, {pivots}, {aggregates} FROM {from_clause} GROUP BY {groups}, {pivots} - // this computes all values that are required in the final result, but not yet in the correct orientation - // 2) execute the query "SELECT {groups}, LIST({pivots}), LIST({aggregates}) FROM [Q1] GROUP BY {groups} - // this pushes all pivots and aggregates that belong to a specific group together in an aligned manner - // 3) push a PIVOT operator, that performs the actual pivoting of the values into the different columns - - PivotBindState bind_state; - // Pivot Stage 1 - // SELECT {groups}, {pivots}, {aggregates} FROM {from_clause} GROUP BY {groups}, {pivots} - auto subquery_stage1 = PivotInitialAggregate(bind_state, ref, std::move(all_columns), handled_columns); - - // Pivot stage 2 - // SELECT {groups}, LIST({pivots}), LIST({aggregates}) FROM [Q1] GROUP BY {groups} - auto subquery_stage2 = PivotListAggregate(bind_state, ref, std::move(subquery_stage1)); - - // Pivot stage 3 - // construct the final pivot operator - pivot_node = PivotFinalOperator(bind_state, ref, std::move(subquery_stage2), 
std::move(pivot_values)); - } - return pivot_node; -} - -unique_ptr Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, - vector> all_columns, - unique_ptr &where_clause) { - D_ASSERT(ref.groups.empty()); - D_ASSERT(ref.pivots.size() == 1); - - unique_ptr expr; - auto select_node = make_uniq(); - select_node->from_table = std::move(ref.source); - - // handle the pivot - auto &unpivot = ref.pivots[0]; - - // handle star expressions in any entries - vector new_entries; - for (auto &entry : unpivot.entries) { - if (entry.star_expr) { - D_ASSERT(entry.values.empty()); - vector> star_columns; - child_binder.ExpandStarExpression(std::move(entry.star_expr), star_columns); - - for (auto &col : star_columns) { - if (col->type != ExpressionType::COLUMN_REF) { - throw InternalException("Unexpected child of unpivot star - not a ColumnRef"); - } - auto &columnref = col->Cast(); - PivotColumnEntry new_entry; - new_entry.values.emplace_back(columnref.GetColumnName()); - new_entry.alias = columnref.GetColumnName(); - new_entries.push_back(std::move(new_entry)); - } - } else { - new_entries.push_back(std::move(entry)); - } - } - unpivot.entries = std::move(new_entries); - - case_insensitive_set_t handled_columns; - case_insensitive_map_t name_map; - for (auto &entry : unpivot.entries) { - for (auto &value : entry.values) { - handled_columns.insert(value.ToString()); - } - } - - for (auto &col_expr : all_columns) { - if (col_expr->type != ExpressionType::COLUMN_REF) { - throw InternalException("Unexpected child of pivot source - not a ColumnRef"); - } - auto &columnref = col_expr->Cast(); - auto &column_name = columnref.GetColumnName(); - auto entry = handled_columns.find(column_name); - if (entry == handled_columns.end()) { - // not handled - add to the set of regularly selected columns - select_node->select_list.push_back(std::move(col_expr)); - } else { - name_map[column_name] = column_name; - handled_columns.erase(entry); - } - } - if (!handled_columns.empty()) { - for 
(auto &entry : handled_columns) { - throw BinderException("Column \"%s\" referenced in UNPIVOT but no matching entry was found in the table", - entry); - } - } - vector unpivot_names; - for (auto &entry : unpivot.entries) { - string generated_name; - for (auto &val : entry.values) { - auto name_entry = name_map.find(val.ToString()); - if (name_entry == name_map.end()) { - throw InternalException("Unpivot - could not find column name in name map"); - } - if (!generated_name.empty()) { - generated_name += "_"; - } - generated_name += name_entry->second; - } - unpivot_names.emplace_back(!entry.alias.empty() ? entry.alias : generated_name); - } - vector>> unpivot_expressions; - for (idx_t v_idx = 1; v_idx < unpivot.entries.size(); v_idx++) { - if (unpivot.entries[v_idx].values.size() != unpivot.entries[0].values.size()) { - throw BinderException( - "UNPIVOT value count mismatch - entry has %llu values, but expected all entries to have %llu values", - unpivot.entries[v_idx].values.size(), unpivot.entries[0].values.size()); - } - } - - for (idx_t v_idx = 0; v_idx < unpivot.entries[0].values.size(); v_idx++) { - vector> expressions; - expressions.reserve(unpivot.entries.size()); - for (auto &entry : unpivot.entries) { - expressions.push_back(make_uniq(entry.values[v_idx].ToString())); - } - unpivot_expressions.push_back(std::move(expressions)); - } - - // construct the UNNEST expression for the set of names (constant) - auto unpivot_list = Value::LIST(LogicalType::VARCHAR, std::move(unpivot_names)); - auto unpivot_name_expr = make_uniq(std::move(unpivot_list)); - vector> unnest_name_children; - unnest_name_children.push_back(std::move(unpivot_name_expr)); - auto unnest_name_expr = make_uniq("unnest", std::move(unnest_name_children)); - unnest_name_expr->alias = unpivot.unpivot_names[0]; - select_node->select_list.push_back(std::move(unnest_name_expr)); - - // construct the UNNEST expression for the set of unpivoted columns - if (ref.unpivot_names.size() != 
unpivot_expressions.size()) { - throw BinderException("UNPIVOT name count mismatch - got %d names but %d expressions", ref.unpivot_names.size(), - unpivot_expressions.size()); - } - for (idx_t i = 0; i < unpivot_expressions.size(); i++) { - auto list_expr = make_uniq("list_value", std::move(unpivot_expressions[i])); - vector> unnest_val_children; - unnest_val_children.push_back(std::move(list_expr)); - auto unnest_val_expr = make_uniq("unnest", std::move(unnest_val_children)); - auto unnest_name = i < ref.column_name_alias.size() ? ref.column_name_alias[i] : ref.unpivot_names[i]; - unnest_val_expr->alias = unnest_name; - select_node->select_list.push_back(std::move(unnest_val_expr)); - if (!ref.include_nulls) { - // if we are running with EXCLUDE NULLS we need to add an IS NOT NULL filter - auto colref = make_uniq(unnest_name); - auto filter = make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, std::move(colref)); - if (where_clause) { - where_clause = make_uniq(ExpressionType::CONJUNCTION_AND, - std::move(where_clause), std::move(filter)); - } else { - where_clause = std::move(filter); - } - } - } - return select_node; -} - -unique_ptr Binder::Bind(PivotRef &ref) { - if (!ref.source) { - throw InternalException("Pivot without a source!?"); - } - if (!ref.bound_pivot_values.empty() || !ref.bound_group_names.empty() || !ref.bound_aggregate_names.empty()) { - // bound pivot - return BindBoundPivot(ref); - } - - // bind the source of the pivot - // we need to do this to be able to expand star expressions - if (ref.source->type == TableReferenceType::SUBQUERY && ref.source->alias.empty()) { - ref.source->alias = "__internal_pivot_alias_" + to_string(GenerateTableIndex()); - } - auto copied_source = ref.source->Copy(); - auto star_binder = Binder::CreateBinder(context, this); - star_binder->Bind(*copied_source); - - // figure out the set of column names that are in the source of the pivot - vector> all_columns; - star_binder->ExpandStarExpression(make_uniq(), all_columns); 
- - unique_ptr select_node; - unique_ptr where_clause; - if (!ref.aggregates.empty()) { - select_node = BindPivot(ref, std::move(all_columns)); - } else { - select_node = BindUnpivot(*star_binder, ref, std::move(all_columns), where_clause); - } - // bind the generated select node - auto child_binder = Binder::CreateBinder(context, this); - auto bound_select_node = child_binder->BindNode(*select_node); - auto root_index = bound_select_node->GetRootIndex(); - BoundQueryNode *bound_select_ptr = bound_select_node.get(); - - unique_ptr result; - MoveCorrelatedExpressions(*child_binder); - result = make_uniq(std::move(child_binder), std::move(bound_select_node)); - auto subquery_alias = ref.alias.empty() ? "__unnamed_pivot" : ref.alias; - SubqueryRef subquery_ref(nullptr, subquery_alias); - subquery_ref.column_name_alias = std::move(ref.column_name_alias); - if (where_clause) { - // if a WHERE clause was provided - bind a subquery holding the WHERE clause - // we need to bind a new subquery here because the WHERE clause has to be applied AFTER the unnest - child_binder = Binder::CreateBinder(context, this); - child_binder->bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, *bound_select_ptr); - auto where_query = make_uniq(); - where_query->select_list.push_back(make_uniq()); - where_query->where_clause = std::move(where_clause); - bound_select_node = child_binder->BindSelectNode(*where_query, std::move(result)); - bound_select_ptr = bound_select_node.get(); - root_index = bound_select_node->GetRootIndex(); - result = make_uniq(std::move(child_binder), std::move(bound_select_node)); - } - bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, *bound_select_ptr); - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Binder::Bind(SubqueryRef &ref, optional_ptr cte) { - auto binder = Binder::CreateBinder(context, this); - binder->can_contain_nulls = true; - if (cte) { - binder->bound_ctes.insert(*cte); - } 
- binder->alias = ref.alias.empty() ? "unnamed_subquery" : ref.alias; - auto subquery = binder->BindNode(*ref.subquery->node); - idx_t bind_index = subquery->GetRootIndex(); - string subquery_alias; - if (ref.alias.empty()) { - subquery_alias = "unnamed_subquery" + to_string(bind_index); - } else { - subquery_alias = ref.alias; - } - auto result = make_uniq(std::move(binder), std::move(subquery)); - bind_context.AddSubquery(bind_index, subquery_alias, ref, *result->subquery); - MoveCorrelatedExpressions(*result->binder); - return std::move(result); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -static bool IsTableInTableOutFunction(TableFunctionCatalogEntry &table_function) { - auto fun = table_function.functions.GetFunctionByOffset(0); - return table_function.functions.Size() == 1 && fun.arguments.size() == 1 && - fun.arguments[0].id() == LogicalTypeId::TABLE; -} - -bool Binder::BindTableInTableOutFunction(vector> &expressions, - unique_ptr &subquery, string &error) { - auto binder = Binder::CreateBinder(this->context, this, true); - unique_ptr subquery_node; - if (expressions.size() == 1 && expressions[0]->type == ExpressionType::SUBQUERY) { - // general case: argument is a subquery, bind it as part of the node - auto &se = expressions[0]->Cast(); - subquery_node = std::move(se.subquery->node); - } else { - // special case: non-subquery parameter to table-in table-out function - // generate a subquery and bind that (i.e. 
UNNEST([1,2,3]) becomes UNNEST((SELECT [1,2,3])) - auto select_node = make_uniq(); - select_node->select_list = std::move(expressions); - select_node->from_table = make_uniq(); - subquery_node = std::move(select_node); - } - auto node = binder->BindNode(*subquery_node); - subquery = make_uniq(std::move(binder), std::move(node)); - MoveCorrelatedExpressions(*subquery->binder); - return true; -} - -bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_function, - vector> &expressions, - vector &arguments, vector ¶meters, - named_parameter_map_t &named_parameters, - unique_ptr &subquery, string &error) { - if (IsTableInTableOutFunction(table_function)) { - // special case binding for table-in table-out function - arguments.emplace_back(LogicalTypeId::TABLE); - return BindTableInTableOutFunction(expressions, subquery, error); - } - bool seen_subquery = false; - for (auto &child : expressions) { - string parameter_name; - - // hack to make named parameters work - if (child->type == ExpressionType::COMPARE_EQUAL) { - // comparison, check if the LHS is a columnref - auto &comp = child->Cast(); - if (comp.left->type == ExpressionType::COLUMN_REF) { - auto &colref = comp.left->Cast(); - if (!colref.IsQualified()) { - parameter_name = colref.GetColumnName(); - child = std::move(comp.right); - } - } - } - if (child->type == ExpressionType::SUBQUERY) { - auto fun = table_function.functions.GetFunctionByOffset(0); - if (table_function.functions.Size() != 1 || fun.arguments.empty() || - fun.arguments[0].id() != LogicalTypeId::TABLE) { - throw BinderException( - "Only table-in-out functions can have subquery parameters - %s only accepts constant parameters", - fun.name); - } - // this separate subquery binding path is only used by python_map - // FIXME: this should be unified with `BindTableInTableOutFunction` above - if (seen_subquery) { - error = "Table function can have at most one subquery parameter "; - return false; - } - auto binder = 
Binder::CreateBinder(this->context, this, true); - auto &se = child->Cast(); - auto node = binder->BindNode(*se.subquery->node); - subquery = make_uniq(std::move(binder), std::move(node)); - seen_subquery = true; - arguments.emplace_back(LogicalTypeId::TABLE); - parameters.emplace_back( - Value(LogicalType::INVALID)); // this is a dummy value so the lengths of arguments and parameter match - continue; - } - - TableFunctionBinder binder(*this, context); - LogicalType sql_type; - auto expr = binder.Bind(child, &sql_type); - if (expr->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!expr->IsScalar()) { - // should have been eliminated before - throw InternalException("Table function requires a constant parameter"); - } - auto constant = ExpressionExecutor::EvaluateScalar(context, *expr, true); - if (parameter_name.empty()) { - // unnamed parameter - if (!named_parameters.empty()) { - error = "Unnamed parameters cannot come after named parameters"; - return false; - } - arguments.emplace_back(sql_type); - parameters.emplace_back(std::move(constant)); - } else { - named_parameters[parameter_name] = std::move(constant); - } - } - return true; -} - -unique_ptr -Binder::BindTableFunctionInternal(TableFunction &table_function, const string &function_name, vector parameters, - named_parameter_map_t named_parameters, vector input_table_types, - vector input_table_names, const vector &column_name_alias, - unique_ptr external_dependency) { - auto bind_index = GenerateTableIndex(); - // perform the binding - unique_ptr bind_data; - vector return_types; - vector return_names; - if (table_function.bind || table_function.bind_replace) { - TableFunctionBindInput bind_input(parameters, named_parameters, input_table_types, input_table_names, - table_function.function_info.get()); - if (table_function.bind_replace) { - auto new_plan = table_function.bind_replace(context, bind_input); - if (new_plan != nullptr) { - return CreatePlan(*Bind(*new_plan)); - } else if 
(!table_function.bind) { - throw BinderException("Failed to bind \"%s\": nullptr returned from bind_replace without bind function", - table_function.name); - } - } - bind_data = table_function.bind(context, bind_input, return_types, return_names); - if (table_function.name == "pandas_scan" || table_function.name == "arrow_scan") { - auto &arrow_bind = bind_data->Cast(); - arrow_bind.external_dependency = std::move(external_dependency); - } - if (table_function.name == "read_csv" || table_function.name == "read_csv_auto") { - auto &csv_bind = bind_data->Cast(); - if (csv_bind.single_threaded) { - table_function.extra_info = "(Single-Threaded)"; - } else { - table_function.extra_info = "(Multi-Threaded)"; - } - } - } else { - throw InvalidInputException("Cannot call function \"%s\" directly - it has no bind function", - table_function.name); - } - if (return_types.size() != return_names.size()) { - throw InternalException("Failed to bind \"%s\": return_types/names must have same size", table_function.name); - } - if (return_types.empty()) { - throw InternalException("Failed to bind \"%s\": Table function must return at least one column", - table_function.name); - } - // overwrite the names with any supplied aliases - for (idx_t i = 0; i < column_name_alias.size() && i < return_names.size(); i++) { - return_names[i] = column_name_alias[i]; - } - for (idx_t i = 0; i < return_names.size(); i++) { - if (return_names[i].empty()) { - return_names[i] = "C" + to_string(i); - } - } - - auto get = make_uniq(bind_index, table_function, std::move(bind_data), return_types, return_names); - get->parameters = parameters; - get->named_parameters = named_parameters; - get->input_table_types = input_table_types; - get->input_table_names = input_table_names; - if (table_function.in_out_function && !table_function.projection_pushdown) { - get->column_ids.reserve(return_types.size()); - for (idx_t i = 0; i < return_types.size(); i++) { - get->column_ids.push_back(i); - } - } - // now add 
the table function to the bind context so its columns can be bound - bind_context.AddTableFunction(bind_index, function_name, return_names, return_types, get->column_ids, - get->GetTable().get()); - return std::move(get); -} - -unique_ptr Binder::BindTableFunction(TableFunction &function, vector parameters) { - named_parameter_map_t named_parameters; - vector input_table_types; - vector input_table_names; - vector column_name_aliases; - return BindTableFunctionInternal(function, function.name, std::move(parameters), std::move(named_parameters), - std::move(input_table_types), std::move(input_table_names), column_name_aliases, - nullptr); -} - -unique_ptr Binder::Bind(TableFunctionRef &ref) { - QueryErrorContext error_context(root_statement, ref.query_location); - - D_ASSERT(ref.function->type == ExpressionType::FUNCTION); - auto &fexpr = ref.function->Cast(); - - // fetch the function from the catalog - auto &func_catalog = Catalog::GetEntry(context, CatalogType::TABLE_FUNCTION_ENTRY, fexpr.catalog, fexpr.schema, - fexpr.function_name, error_context); - - if (func_catalog.type == CatalogType::TABLE_MACRO_ENTRY) { - auto ¯o_func = func_catalog.Cast(); - auto query_node = BindTableMacro(fexpr, macro_func, 0); - D_ASSERT(query_node); - - auto binder = Binder::CreateBinder(context, this); - binder->can_contain_nulls = true; - - binder->alias = ref.alias.empty() ? "unnamed_query" : ref.alias; - auto query = binder->BindNode(*query_node); - - idx_t bind_index = query->GetRootIndex(); - // string alias; - string alias = (ref.alias.empty() ? 
"unnamed_query" + to_string(bind_index) : ref.alias); - - auto result = make_uniq(std::move(binder), std::move(query)); - // remember ref here is TableFunctionRef and NOT base class - bind_context.AddSubquery(bind_index, alias, ref, *result->subquery); - MoveCorrelatedExpressions(*result->binder); - return std::move(result); - } - D_ASSERT(func_catalog.type == CatalogType::TABLE_FUNCTION_ENTRY); - auto &function = func_catalog.Cast(); - - // evaluate the input parameters to the function - vector arguments; - vector parameters; - named_parameter_map_t named_parameters; - unique_ptr subquery; - string error; - if (!BindTableFunctionParameters(function, fexpr.children, arguments, parameters, named_parameters, subquery, - error)) { - throw BinderException(FormatError(ref, error)); - } - - // select the function based on the input parameters - FunctionBinder function_binder(context); - idx_t best_function_idx = function_binder.BindFunction(function.name, function.functions, arguments, error); - if (best_function_idx == DConstants::INVALID_INDEX) { - throw BinderException(FormatError(ref, error)); - } - auto table_function = function.functions.GetFunctionByOffset(best_function_idx); - - // now check the named parameters - BindNamedParameters(table_function.named_parameters, named_parameters, error_context, table_function.name); - - // cast the parameters to the type of the function - for (idx_t i = 0; i < arguments.size(); i++) { - auto target_type = i < table_function.arguments.size() ? 
table_function.arguments[i] : table_function.varargs; - - if (target_type != LogicalType::ANY && target_type != LogicalType::TABLE && - target_type != LogicalType::POINTER && target_type.id() != LogicalTypeId::LIST) { - parameters[i] = parameters[i].CastAs(context, target_type); - } - } - - vector input_table_types; - vector input_table_names; - - if (subquery) { - input_table_types = subquery->subquery->types; - input_table_names = subquery->subquery->names; - } - auto get = BindTableFunctionInternal(table_function, ref.alias.empty() ? fexpr.function_name : ref.alias, - std::move(parameters), std::move(named_parameters), - std::move(input_table_types), std::move(input_table_names), - ref.column_name_alias, std::move(ref.external_dependency)); - if (subquery) { - get->children.push_back(Binder::CreatePlan(*subquery)); - } - - return make_uniq_base(std::move(get)); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundBaseTableRef &ref) { - return std::move(ref.get); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundCTERef &ref) { - auto index = ref.bind_index; - - vector types; - types.reserve(ref.types.size()); - for (auto &type : ref.types) { - types.push_back(type); - } - - return make_uniq(index, ref.cte_index, types, ref.bound_columns, ref.materialized_cte); -} - -} // namespace duckdb - - - - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundEmptyTableRef &ref) { - return make_uniq(ref.bind_index); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundExpressionListRef &ref) { - auto root = make_uniq_base(GenerateTableIndex()); - // values list, first plan any subqueries in the list - for (auto &expr_list : ref.values) { - for (auto &expr : expr_list) { - PlanSubqueries(expr, root); - } - } - // now create a LogicalExpressionGet from the set of expressions - // fetch the types - vector types; - for (auto &expr : ref.values[0]) 
{ - types.push_back(expr->return_type); - } - auto expr_get = make_uniq(ref.bind_index, types, std::move(ref.values)); - expr_get->AddChild(std::move(root)); - return std::move(expr_get); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -//! Create a JoinCondition from a comparison -static bool CreateJoinCondition(Expression &expr, const unordered_set &left_bindings, - const unordered_set &right_bindings, vector &conditions) { - // comparison - auto &comparison = expr.Cast(); - auto left_side = JoinSide::GetJoinSide(*comparison.left, left_bindings, right_bindings); - auto right_side = JoinSide::GetJoinSide(*comparison.right, left_bindings, right_bindings); - if (left_side != JoinSide::BOTH && right_side != JoinSide::BOTH) { - // join condition can be divided in a left/right side - JoinCondition condition; - condition.comparison = expr.type; - auto left = std::move(comparison.left); - auto right = std::move(comparison.right); - if (left_side == JoinSide::RIGHT) { - // left = right, right = left, flip the comparison symbol and reverse sides - swap(left, right); - condition.comparison = FlipComparisonExpression(expr.type); - } - condition.left = std::move(left); - condition.right = std::move(right); - conditions.push_back(std::move(condition)); - return true; - } - return false; -} - -void LogicalComparisonJoin::ExtractJoinConditions( - ClientContext &context, JoinType type, unique_ptr &left_child, - unique_ptr &right_child, const unordered_set &left_bindings, - const unordered_set &right_bindings, vector> &expressions, - vector &conditions, vector> &arbitrary_expressions) { - - for (auto &expr : expressions) { - auto total_side = JoinSide::GetJoinSide(*expr, left_bindings, right_bindings); - if (total_side != JoinSide::BOTH) { - // join condition does not reference both sides, add it as filter under the join - if (type == JoinType::LEFT && total_side == JoinSide::RIGHT) { - // filter is on RHS and the join is a LEFT OUTER 
join, we can push it in the right child - if (right_child->type != LogicalOperatorType::LOGICAL_FILTER) { - // not a filter yet, push a new empty filter - auto filter = make_uniq(); - filter->AddChild(std::move(right_child)); - right_child = std::move(filter); - } - // push the expression into the filter - auto &filter = right_child->Cast(); - filter.expressions.push_back(std::move(expr)); - continue; - } - // if the join is a LEFT JOIN and the join expression constantly evaluates to TRUE, - // then we do not add it to the arbitrary expressions - if (type == JoinType::LEFT && expr->IsFoldable()) { - Value result; - ExpressionExecutor::TryEvaluateScalar(context, *expr, result); - if (!result.IsNull() && result == Value(true)) { - continue; - } - } - } else if (expr->type == ExpressionType::COMPARE_EQUAL || expr->type == ExpressionType::COMPARE_NOTEQUAL || - expr->type == ExpressionType::COMPARE_BOUNDARY_START || - expr->type == ExpressionType::COMPARE_LESSTHAN || - expr->type == ExpressionType::COMPARE_GREATERTHAN || - expr->type == ExpressionType::COMPARE_LESSTHANOREQUALTO || - expr->type == ExpressionType::COMPARE_GREATERTHANOREQUALTO || - expr->type == ExpressionType::COMPARE_BOUNDARY_START || - expr->type == ExpressionType::COMPARE_NOT_DISTINCT_FROM || - expr->type == ExpressionType::COMPARE_DISTINCT_FROM) - - { - // comparison, check if we can create a comparison JoinCondition - if (CreateJoinCondition(*expr, left_bindings, right_bindings, conditions)) { - // successfully created the join condition - continue; - } - } - arbitrary_expressions.push_back(std::move(expr)); - } -} - -void LogicalComparisonJoin::ExtractJoinConditions(ClientContext &context, JoinType type, - unique_ptr &left_child, - unique_ptr &right_child, - vector> &expressions, - vector &conditions, - vector> &arbitrary_expressions) { - unordered_set left_bindings, right_bindings; - LogicalJoin::GetTableReferences(*left_child, left_bindings); - LogicalJoin::GetTableReferences(*right_child, 
right_bindings); - return ExtractJoinConditions(context, type, left_child, right_child, left_bindings, right_bindings, expressions, - conditions, arbitrary_expressions); -} - -void LogicalComparisonJoin::ExtractJoinConditions(ClientContext &context, JoinType type, - unique_ptr &left_child, - unique_ptr &right_child, - unique_ptr condition, vector &conditions, - vector> &arbitrary_expressions) { - // split the expressions by the AND clause - vector> expressions; - expressions.push_back(std::move(condition)); - LogicalFilter::SplitPredicates(expressions); - return ExtractJoinConditions(context, type, left_child, right_child, expressions, conditions, - arbitrary_expressions); -} - -unique_ptr LogicalComparisonJoin::CreateJoin(ClientContext &context, JoinType type, - JoinRefType reftype, - unique_ptr left_child, - unique_ptr right_child, - vector conditions, - vector> arbitrary_expressions) { - // Validate the conditions - bool need_to_consider_arbitrary_expressions = true; - switch (reftype) { - case JoinRefType::ASOF: { - need_to_consider_arbitrary_expressions = false; - auto asof_idx = conditions.size(); - for (size_t c = 0; c < conditions.size(); ++c) { - auto &cond = conditions[c]; - switch (cond.comparison) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_LESSTHAN: - if (asof_idx < conditions.size()) { - throw BinderException("Multiple ASOF JOIN inequalities"); - } - asof_idx = c; - break; - default: - throw BinderException("Invalid ASOF JOIN comparison"); - } - } - if (asof_idx == conditions.size()) { - throw BinderException("Missing ASOF JOIN inequality"); - } - break; - } - default: - break; - } - - if (type == JoinType::INNER && reftype == JoinRefType::REGULAR) { - // for inner joins we can push arbitrary expressions as a filter - // 
here we prefer to create a comparison join if possible - // that way we can use the much faster hash join to process the main join - // rather than doing a nested loop join to handle arbitrary expressions - - // for left and full outer joins we HAVE to process all join conditions - // because pushing a filter will lead to an incorrect result, as non-matching tuples cannot be filtered out - need_to_consider_arbitrary_expressions = false; - } - if ((need_to_consider_arbitrary_expressions && !arbitrary_expressions.empty()) || conditions.empty()) { - if (arbitrary_expressions.empty()) { - // all conditions were pushed down, add TRUE predicate - arbitrary_expressions.push_back(make_uniq(Value::BOOLEAN(true))); - } - for (auto &condition : conditions) { - arbitrary_expressions.push_back(JoinCondition::CreateExpression(std::move(condition))); - } - // if we get here we could not create any JoinConditions - // turn this into an arbitrary expression join - auto any_join = make_uniq(type); - // create the condition - any_join->children.push_back(std::move(left_child)); - any_join->children.push_back(std::move(right_child)); - // AND all the arbitrary expressions together - // do the same with any remaining conditions - any_join->condition = std::move(arbitrary_expressions[0]); - for (idx_t i = 1; i < arbitrary_expressions.size(); i++) { - any_join->condition = make_uniq( - ExpressionType::CONJUNCTION_AND, std::move(any_join->condition), std::move(arbitrary_expressions[i])); - } - return std::move(any_join); - } else { - // we successfully converted expressions into JoinConditions - // create a LogicalComparisonJoin - auto logical_type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN; - if (reftype == JoinRefType::ASOF) { - logical_type = LogicalOperatorType::LOGICAL_ASOF_JOIN; - } - auto comp_join = make_uniq(type, logical_type); - comp_join->conditions = std::move(conditions); - comp_join->children.push_back(std::move(left_child)); - 
comp_join->children.push_back(std::move(right_child)); - if (!arbitrary_expressions.empty()) { - // we have some arbitrary expressions as well - // add them to a filter - auto filter = make_uniq(); - for (auto &expr : arbitrary_expressions) { - filter->expressions.push_back(std::move(expr)); - } - LogicalFilter::SplitPredicates(filter->expressions); - filter->children.push_back(std::move(comp_join)); - return std::move(filter); - } - return std::move(comp_join); - } -} - -static bool HasCorrelatedColumns(Expression &expression) { - if (expression.type == ExpressionType::BOUND_COLUMN_REF) { - auto &colref = expression.Cast(); - if (colref.depth > 0) { - return true; - } - } - bool has_correlated_columns = false; - ExpressionIterator::EnumerateChildren(expression, [&](Expression &child) { - if (HasCorrelatedColumns(child)) { - has_correlated_columns = true; - } - }); - return has_correlated_columns; -} - -unique_ptr LogicalComparisonJoin::CreateJoin(ClientContext &context, JoinType type, - JoinRefType reftype, - unique_ptr left_child, - unique_ptr right_child, - unique_ptr condition) { - vector conditions; - vector> arbitrary_expressions; - LogicalComparisonJoin::ExtractJoinConditions(context, type, left_child, right_child, std::move(condition), - conditions, arbitrary_expressions); - return LogicalComparisonJoin::CreateJoin(context, type, reftype, std::move(left_child), std::move(right_child), - std::move(conditions), std::move(arbitrary_expressions)); -} - -unique_ptr Binder::CreatePlan(BoundJoinRef &ref) { - auto old_is_outside_flattened = is_outside_flattened; - // Plan laterals from outermost to innermost - if (ref.lateral) { - // Set the flag to ensure that children do not flatten before the root - is_outside_flattened = false; - } - auto left = CreatePlan(*ref.left); - auto right = CreatePlan(*ref.right); - is_outside_flattened = old_is_outside_flattened; - - // For joins, depth of the bindings will be one higher on the right because of the lateral binder - // 
If the current join does not have correlations between left and right, then the right bindings - // have depth 1 too high and can be reduced by 1 throughout - if (!ref.lateral && !ref.correlated_columns.empty()) { - LateralBinder::ReduceExpressionDepth(*right, ref.correlated_columns); - } - - if (ref.type == JoinType::RIGHT && ref.ref_type != JoinRefType::ASOF && - ClientConfig::GetConfig(context).enable_optimizer) { - // we turn any right outer joins into left outer joins for optimization purposes - // they are the same but with sides flipped, so treating them the same simplifies life - ref.type = JoinType::LEFT; - std::swap(left, right); - } - if (ref.lateral) { - if (!is_outside_flattened) { - // If outer dependent joins is yet to be flattened, only plan the lateral - has_unplanned_dependent_joins = true; - return LogicalDependentJoin::Create(std::move(left), std::move(right), ref.correlated_columns, ref.type, - std::move(ref.condition)); - } else { - // All outer dependent joins have been planned and flattened, so plan and flatten lateral and recursively - // plan the children - auto new_plan = PlanLateralJoin(std::move(left), std::move(right), ref.correlated_columns, ref.type, - std::move(ref.condition)); - if (has_unplanned_dependent_joins) { - RecursiveDependentJoinPlanner plan(*this); - plan.VisitOperator(*new_plan); - } - return new_plan; - } - } - switch (ref.ref_type) { - case JoinRefType::CROSS: - return LogicalCrossProduct::Create(std::move(left), std::move(right)); - case JoinRefType::POSITIONAL: - return LogicalPositionalJoin::Create(std::move(left), std::move(right)); - default: - break; - } - if (ref.type == JoinType::INNER && (ref.condition->HasSubquery() || HasCorrelatedColumns(*ref.condition)) && - ref.ref_type == JoinRefType::REGULAR) { - // inner join, generate a cross product + filter - // this will be later turned into a proper join by the join order optimizer - auto root = LogicalCrossProduct::Create(std::move(left), std::move(right)); - - 
auto filter = make_uniq(std::move(ref.condition)); - // visit the expressions in the filter - for (auto &expression : filter->expressions) { - PlanSubqueries(expression, root); - } - filter->AddChild(std::move(root)); - return std::move(filter); - } - - // now create the join operator from the join condition - auto result = LogicalComparisonJoin::CreateJoin(context, ref.type, ref.ref_type, std::move(left), std::move(right), - std::move(ref.condition)); - - optional_ptr join; - if (result->type == LogicalOperatorType::LOGICAL_FILTER) { - join = result->children[0].get(); - } else { - join = result.get(); - } - for (auto &child : join->children) { - if (child->type == LogicalOperatorType::LOGICAL_FILTER) { - auto &filter = child->Cast(); - for (auto &expr : filter.expressions) { - PlanSubqueries(expr, filter.children[0]); - } - } - } - - // we visit the expressions depending on the type of join - switch (join->type) { - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { - // comparison join - // in this join we visit the expressions on the LHS with the LHS as root node - // and the expressions on the RHS with the RHS as root node - auto &comp_join = join->Cast(); - for (idx_t i = 0; i < comp_join.conditions.size(); i++) { - PlanSubqueries(comp_join.conditions[i].left, comp_join.children[0]); - PlanSubqueries(comp_join.conditions[i].right, comp_join.children[1]); - } - break; - } - case LogicalOperatorType::LOGICAL_ANY_JOIN: { - auto &any_join = join->Cast(); - // for the any join we just visit the condition - if (any_join.condition->HasSubquery()) { - throw NotImplementedException("Cannot perform non-inner join on subquery!"); - } - break; - } - default: - break; - } - return result; -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundPivotRef &ref) { - auto subquery = ref.child_binder->CreatePlan(*ref.child); - - auto result = make_uniq(ref.bind_index, std::move(subquery), 
std::move(ref.bound_pivot)); - return std::move(result); -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundSubqueryRef &ref) { - // generate the logical plan for the subquery - // this happens separately from the current LogicalPlan generation - ref.binder->is_outside_flattened = is_outside_flattened; - auto subquery = ref.binder->CreatePlan(*ref.subquery); - if (ref.binder->has_unplanned_dependent_joins) { - has_unplanned_dependent_joins = true; - } - return subquery; -} - -} // namespace duckdb - - - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundTableFunction &ref) { - return std::move(ref.get); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -Binder *Binder::GetRootBinder() { - Binder *root = this; - while (root->parent) { - root = root->parent.get(); - } - return root; -} - -idx_t Binder::GetBinderDepth() const { - const Binder *root = this; - idx_t depth = 1; - while (root->parent) { - depth++; - root = root->parent.get(); - } - return depth; -} - -shared_ptr Binder::CreateBinder(ClientContext &context, optional_ptr parent, bool inherit_ctes) { - auto depth = parent ? parent->GetBinderDepth() : 0; - if (depth > context.config.max_expression_depth) { - throw BinderException("Max expression depth limit of %lld exceeded. Use \"SET max_expression_depth TO x\" to " - "increase the maximum expression depth.", - context.config.max_expression_depth); - } - return make_shared(true, context, parent ? parent->shared_from_this() : nullptr, inherit_ctes); -} - -Binder::Binder(bool, ClientContext &context, shared_ptr parent_p, bool inherit_ctes_p) - : context(context), parent(std::move(parent_p)), bound_tables(0), inherit_ctes(inherit_ctes_p) { - if (parent) { - - // We have to inherit macro and lambda parameter bindings and from the parent binder, if there is a parent. 
- macro_binding = parent->macro_binding; - lambda_bindings = parent->lambda_bindings; - - if (inherit_ctes) { - // We have to inherit CTE bindings from the parent bind_context, if there is a parent. - bind_context.SetCTEBindings(parent->bind_context.GetCTEBindings()); - bind_context.cte_references = parent->bind_context.cte_references; - parameters = parent->parameters; - } - } -} - -BoundStatement Binder::Bind(SQLStatement &statement) { - root_statement = &statement; - switch (statement.type) { - case StatementType::SELECT_STATEMENT: - return Bind(statement.Cast()); - case StatementType::INSERT_STATEMENT: - return Bind(statement.Cast()); - case StatementType::COPY_STATEMENT: - return Bind(statement.Cast()); - case StatementType::DELETE_STATEMENT: - return Bind(statement.Cast()); - case StatementType::UPDATE_STATEMENT: - return Bind(statement.Cast()); - case StatementType::RELATION_STATEMENT: - return Bind(statement.Cast()); - case StatementType::CREATE_STATEMENT: - return Bind(statement.Cast()); - case StatementType::DROP_STATEMENT: - return Bind(statement.Cast()); - case StatementType::ALTER_STATEMENT: - return Bind(statement.Cast()); - case StatementType::TRANSACTION_STATEMENT: - return Bind(statement.Cast()); - case StatementType::PRAGMA_STATEMENT: - return Bind(statement.Cast()); - case StatementType::EXPLAIN_STATEMENT: - return Bind(statement.Cast()); - case StatementType::VACUUM_STATEMENT: - return Bind(statement.Cast()); - case StatementType::SHOW_STATEMENT: - return Bind(statement.Cast()); - case StatementType::CALL_STATEMENT: - return Bind(statement.Cast()); - case StatementType::EXPORT_STATEMENT: - return Bind(statement.Cast()); - case StatementType::SET_STATEMENT: - return Bind(statement.Cast()); - case StatementType::LOAD_STATEMENT: - return Bind(statement.Cast()); - case StatementType::EXTENSION_STATEMENT: - return Bind(statement.Cast()); - case StatementType::PREPARE_STATEMENT: - return Bind(statement.Cast()); - case StatementType::EXECUTE_STATEMENT: 
- return Bind(statement.Cast()); - case StatementType::LOGICAL_PLAN_STATEMENT: - return Bind(statement.Cast()); - case StatementType::ATTACH_STATEMENT: - return Bind(statement.Cast()); - case StatementType::DETACH_STATEMENT: - return Bind(statement.Cast()); - default: // LCOV_EXCL_START - throw NotImplementedException("Unimplemented statement type \"%s\" for Bind", - StatementTypeToString(statement.type)); - } // LCOV_EXCL_STOP -} - -void Binder::AddCTEMap(CommonTableExpressionMap &cte_map) { - for (auto &cte_it : cte_map.map) { - AddCTE(cte_it.first, *cte_it.second); - } -} - -unique_ptr Binder::BindNode(QueryNode &node) { - // first we visit the set of CTEs and add them to the bind context - AddCTEMap(node.cte_map); - // now we bind the node - unique_ptr result; - switch (node.type) { - case QueryNodeType::SELECT_NODE: - result = BindNode(node.Cast()); - break; - case QueryNodeType::RECURSIVE_CTE_NODE: - result = BindNode(node.Cast()); - break; - case QueryNodeType::CTE_NODE: - result = BindNode(node.Cast()); - break; - default: - D_ASSERT(node.type == QueryNodeType::SET_OPERATION_NODE); - result = BindNode(node.Cast()); - break; - } - return result; -} - -BoundStatement Binder::Bind(QueryNode &node) { - auto bound_node = BindNode(node); - - BoundStatement result; - result.names = bound_node->names; - result.types = bound_node->types; - - // and plan it - result.plan = CreatePlan(*bound_node); - return result; -} - -unique_ptr Binder::CreatePlan(BoundQueryNode &node) { - switch (node.type) { - case QueryNodeType::SELECT_NODE: - return CreatePlan(node.Cast()); - case QueryNodeType::SET_OPERATION_NODE: - return CreatePlan(node.Cast()); - case QueryNodeType::RECURSIVE_CTE_NODE: - return CreatePlan(node.Cast()); - case QueryNodeType::CTE_NODE: - return CreatePlan(node.Cast()); - default: - throw InternalException("Unsupported bound query node type"); - } -} - -unique_ptr Binder::Bind(TableRef &ref) { - unique_ptr result; - switch (ref.type) { - case 
TableReferenceType::BASE_TABLE: - result = Bind(ref.Cast()); - break; - case TableReferenceType::JOIN: - result = Bind(ref.Cast()); - break; - case TableReferenceType::SUBQUERY: - result = Bind(ref.Cast()); - break; - case TableReferenceType::EMPTY: - result = Bind(ref.Cast()); - break; - case TableReferenceType::TABLE_FUNCTION: - result = Bind(ref.Cast()); - break; - case TableReferenceType::EXPRESSION_LIST: - result = Bind(ref.Cast()); - break; - case TableReferenceType::PIVOT: - result = Bind(ref.Cast()); - break; - case TableReferenceType::CTE: - case TableReferenceType::INVALID: - default: - throw InternalException("Unknown table ref type"); - } - result->sample = std::move(ref.sample); - return result; -} - -unique_ptr Binder::CreatePlan(BoundTableRef &ref) { - unique_ptr root; - switch (ref.type) { - case TableReferenceType::BASE_TABLE: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::SUBQUERY: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::JOIN: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::TABLE_FUNCTION: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::EMPTY: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::EXPRESSION_LIST: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::CTE: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::PIVOT: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::INVALID: - default: - throw InternalException("Unsupported bound table ref type"); - } - // plan the sample clause - if (ref.sample) { - root = make_uniq(std::move(ref.sample), std::move(root)); - } - return root; -} - -void Binder::AddCTE(const string &name, CommonTableExpressionInfo &info) { - D_ASSERT(!name.empty()); - auto entry = CTE_bindings.find(name); - if (entry != CTE_bindings.end()) { - throw InternalException("Duplicate CTE \"%s\" in query!", name); - } - CTE_bindings.insert(make_pair(name, 
reference(info))); -} - -optional_ptr Binder::FindCTE(const string &name, bool skip) { - auto entry = CTE_bindings.find(name); - if (entry != CTE_bindings.end()) { - if (!skip || entry->second.get().query->node->type == QueryNodeType::RECURSIVE_CTE_NODE) { - return &entry->second.get(); - } - } - if (parent && inherit_ctes) { - return parent->FindCTE(name, name == alias); - } - return nullptr; -} - -bool Binder::CTEIsAlreadyBound(CommonTableExpressionInfo &cte) { - if (bound_ctes.find(cte) != bound_ctes.end()) { - return true; - } - if (parent && inherit_ctes) { - return parent->CTEIsAlreadyBound(cte); - } - return false; -} - -void Binder::AddBoundView(ViewCatalogEntry &view) { - // check if the view is already bound - auto current = this; - while (current) { - if (current->bound_views.find(view) != current->bound_views.end()) { - throw BinderException("infinite recursion detected: attempting to recursively bind view \"%s\"", view.name); - } - current = current->parent.get(); - } - bound_views.insert(view); -} - -idx_t Binder::GenerateTableIndex() { - auto root_binder = GetRootBinder(); - return root_binder->bound_tables++; -} - -void Binder::PushExpressionBinder(ExpressionBinder &binder) { - GetActiveBinders().push_back(binder); -} - -void Binder::PopExpressionBinder() { - D_ASSERT(HasActiveBinder()); - GetActiveBinders().pop_back(); -} - -void Binder::SetActiveBinder(ExpressionBinder &binder) { - D_ASSERT(HasActiveBinder()); - GetActiveBinders().back() = binder; -} - -ExpressionBinder &Binder::GetActiveBinder() { - return GetActiveBinders().back(); -} - -bool Binder::HasActiveBinder() { - return !GetActiveBinders().empty(); -} - -vector> &Binder::GetActiveBinders() { - auto root_binder = GetRootBinder(); - return root_binder->active_binders; -} - -void Binder::AddUsingBindingSet(unique_ptr set) { - auto root_binder = GetRootBinder(); - root_binder->bind_context.AddUsingBindingSet(std::move(set)); -} - -void Binder::MoveCorrelatedExpressions(Binder &other) { - 
MergeCorrelatedColumns(other.correlated_columns); - other.correlated_columns.clear(); -} - -void Binder::MergeCorrelatedColumns(vector &other) { - for (idx_t i = 0; i < other.size(); i++) { - AddCorrelatedColumn(other[i]); - } -} - -void Binder::AddCorrelatedColumn(const CorrelatedColumnInfo &info) { - // we only add correlated columns to the list if they are not already there - if (std::find(correlated_columns.begin(), correlated_columns.end(), info) == correlated_columns.end()) { - correlated_columns.push_back(info); - } -} - -bool Binder::HasMatchingBinding(const string &table_name, const string &column_name, string &error_message) { - string empty_schema; - return HasMatchingBinding(empty_schema, table_name, column_name, error_message); -} - -bool Binder::HasMatchingBinding(const string &schema_name, const string &table_name, const string &column_name, - string &error_message) { - string empty_catalog; - return HasMatchingBinding(empty_catalog, schema_name, table_name, column_name, error_message); -} - -bool Binder::HasMatchingBinding(const string &catalog_name, const string &schema_name, const string &table_name, - const string &column_name, string &error_message) { - optional_ptr binding; - D_ASSERT(!lambda_bindings); - if (macro_binding && table_name == macro_binding->alias) { - binding = optional_ptr(macro_binding.get()); - } else { - binding = bind_context.GetBinding(table_name, error_message); - } - - if (!binding) { - return false; - } - if (!catalog_name.empty() || !schema_name.empty()) { - auto catalog_entry = binding->GetStandardEntry(); - if (!catalog_entry) { - return false; - } - if (!catalog_name.empty() && catalog_entry->catalog.GetName() != catalog_name) { - return false; - } - if (!schema_name.empty() && catalog_entry->schema.name != schema_name) { - return false; - } - if (catalog_entry->name != table_name) { - return false; - } - } - bool binding_found; - binding_found = binding->HasMatchingBinding(column_name); - if (!binding_found) { - 
error_message = binding->ColumnNotFoundError(column_name); - } - return binding_found; -} - -void Binder::SetBindingMode(BindingMode mode) { - auto root_binder = GetRootBinder(); - // FIXME: this used to also set the 'mode' for the current binder, was that necessary? - root_binder->mode = mode; -} - -BindingMode Binder::GetBindingMode() { - auto root_binder = GetRootBinder(); - return root_binder->mode; -} - -void Binder::SetCanContainNulls(bool can_contain_nulls_p) { - can_contain_nulls = can_contain_nulls_p; -} - -void Binder::AddTableName(string table_name) { - auto root_binder = GetRootBinder(); - root_binder->table_names.insert(std::move(table_name)); -} - -const unordered_set &Binder::GetTableNames() { - auto root_binder = GetRootBinder(); - return root_binder->table_names; -} - -string Binder::FormatError(ParsedExpression &expr_context, const string &message) { - return FormatError(expr_context.query_location, message); -} - -string Binder::FormatError(TableRef &ref_context, const string &message) { - return FormatError(ref_context.query_location, message); -} - -string Binder::FormatErrorRecursive(idx_t query_location, const string &message, vector &values) { - QueryErrorContext context(root_statement, query_location); - return context.FormatErrorRecursive(message, values); -} - -// FIXME: this is extremely naive -void VerifyNotExcluded(ParsedExpression &expr) { - if (expr.type == ExpressionType::COLUMN_REF) { - auto &column_ref = expr.Cast(); - if (!column_ref.IsQualified()) { - return; - } - auto &table_name = column_ref.GetTableName(); - if (table_name == "excluded") { - throw NotImplementedException("'excluded' qualified columns are not supported in the RETURNING clause yet"); - } - return; - } - ParsedExpressionIterator::EnumerateChildren( - expr, [&](const ParsedExpression &child) { VerifyNotExcluded((ParsedExpression &)child); }); -} - -BoundStatement Binder::BindReturning(vector> returning_list, TableCatalogEntry &table, - const string &alias, idx_t 
update_table_index, - unique_ptr child_operator, BoundStatement result) { - - vector types; - vector names; - - auto binder = Binder::CreateBinder(context); - - vector bound_columns; - idx_t column_count = 0; - for (auto &col : table.GetColumns().Logical()) { - names.push_back(col.Name()); - types.push_back(col.Type()); - if (!col.Generated()) { - bound_columns.push_back(column_count); - } - column_count++; - } - - binder->bind_context.AddBaseTable(update_table_index, alias.empty() ? table.name : alias, names, types, - bound_columns, &table, false); - ReturningBinder returning_binder(*binder, context); - - vector> projection_expressions; - LogicalType result_type; - vector> new_returning_list; - binder->ExpandStarExpressions(returning_list, new_returning_list); - for (auto &returning_expr : new_returning_list) { - VerifyNotExcluded(*returning_expr); - auto expr = returning_binder.Bind(returning_expr, &result_type); - result.names.push_back(expr->GetName()); - result.types.push_back(result_type); - projection_expressions.push_back(std::move(expr)); - } - - auto projection = make_uniq(GenerateTableIndex(), std::move(projection_expressions)); - projection->AddChild(std::move(child_operator)); - D_ASSERT(result.types.size() == result.names.size()); - result.plan = std::move(projection); - // If an insert/delete/update statement returns data, there are sometimes issues with streaming results - // where the data modification doesn't take place until the streamed result is exhausted. Once a row is - // returned, it should be guaranteed that the row has been inserted. 
- // see https://github.com/duckdb/duckdb/issues/8310 - properties.allow_stream_result = false; - properties.return_type = StatementReturnType::QUERY_RESULT; - return result; -} - -} // namespace duckdb - - - - -namespace duckdb { - -BoundParameterMap::BoundParameterMap(case_insensitive_map_t ¶meter_data) - : parameter_data(parameter_data) { -} - -LogicalType BoundParameterMap::GetReturnType(const string &identifier) { - D_ASSERT(!identifier.empty()); - auto it = parameter_data.find(identifier); - if (it == parameter_data.end()) { - return LogicalTypeId::UNKNOWN; - } - return it->second.return_type; -} - -bound_parameter_map_t *BoundParameterMap::GetParametersPtr() { - return ¶meters; -} - -const bound_parameter_map_t &BoundParameterMap::GetParameters() { - return parameters; -} - -const case_insensitive_map_t &BoundParameterMap::GetParameterData() { - return parameter_data; -} - -shared_ptr BoundParameterMap::CreateOrGetData(const string &identifier) { - auto entry = parameters.find(identifier); - if (entry == parameters.end()) { - // no entry yet: create a new one - auto data = make_shared(); - data->return_type = GetReturnType(identifier); - - CreateNewParameter(identifier, data); - return data; - } - return entry->second; -} - -unique_ptr BoundParameterMap::BindParameterExpression(ParameterExpression &expr) { - auto &identifier = expr.identifier; - auto return_type = GetReturnType(identifier); - - D_ASSERT(!parameter_data.count(identifier)); - - // No value has been supplied yet, - // We return a shared pointer to an object that will get populated wtih a Value later - // When the BoundParameterExpression get executed, this will be used to get the corresponding value - auto param_data = CreateOrGetData(identifier); - auto bound_expr = make_uniq(identifier); - bound_expr->parameter_data = param_data; - bound_expr->return_type = return_type; - bound_expr->alias = expr.alias; - return bound_expr; -} - -void BoundParameterMap::CreateNewParameter(const string &id, 
const shared_ptr ¶m_data) { - D_ASSERT(!parameters.count(id)); - parameters.emplace(std::make_pair(id, param_data)); -} - -} // namespace duckdb - - -namespace duckdb { - -BoundResultModifier::BoundResultModifier(ResultModifierType type) : type(type) { -} - -BoundResultModifier::~BoundResultModifier() { -} - -BoundOrderByNode::BoundOrderByNode(OrderType type, OrderByNullType null_order, unique_ptr expression) - : type(type), null_order(null_order), expression(std::move(expression)) { -} -BoundOrderByNode::BoundOrderByNode(OrderType type, OrderByNullType null_order, unique_ptr expression, - unique_ptr stats) - : type(type), null_order(null_order), expression(std::move(expression)), stats(std::move(stats)) { -} - -BoundOrderByNode BoundOrderByNode::Copy() const { - if (stats) { - return BoundOrderByNode(type, null_order, expression->Copy(), stats->ToUnique()); - } else { - return BoundOrderByNode(type, null_order, expression->Copy()); - } -} - -bool BoundOrderByNode::Equals(const BoundOrderByNode &other) const { - if (type != other.type || null_order != other.null_order) { - return false; - } - if (!expression->Equals(*other.expression)) { - return false; - } - - return true; -} - -string BoundOrderByNode::ToString() const { - auto str = expression->ToString(); - switch (type) { - case OrderType::ASCENDING: - str += " ASC"; - break; - case OrderType::DESCENDING: - str += " DESC"; - break; - default: - break; - } - - switch (null_order) { - case OrderByNullType::NULLS_FIRST: - str += " NULLS FIRST"; - break; - case OrderByNullType::NULLS_LAST: - str += " NULLS LAST"; - break; - default: - break; - } - return str; -} - -unique_ptr BoundOrderModifier::Copy() const { - auto result = make_uniq(); - for (auto &order : orders) { - result->orders.push_back(order.Copy()); - } - return result; -} - -bool BoundOrderModifier::Equals(const BoundOrderModifier &left, const BoundOrderModifier &right) { - if (left.orders.size() != right.orders.size()) { - return false; - } - for 
(idx_t i = 0; i < left.orders.size(); i++) { - if (!left.orders[i].Equals(right.orders[i])) { - return false; - } - } - return true; -} - -bool BoundOrderModifier::Equals(const unique_ptr &left, - const unique_ptr &right) { - if (left.get() == right.get()) { - return true; - } - if (!left || !right) { - return false; - } - return BoundOrderModifier::Equals(*left, *right); -} - -BoundLimitModifier::BoundLimitModifier() : BoundResultModifier(ResultModifierType::LIMIT_MODIFIER) { -} - -BoundOrderModifier::BoundOrderModifier() : BoundResultModifier(ResultModifierType::ORDER_MODIFIER) { -} - -BoundDistinctModifier::BoundDistinctModifier() : BoundResultModifier(ResultModifierType::DISTINCT_MODIFIER) { -} - -BoundLimitPercentModifier::BoundLimitPercentModifier() - : BoundResultModifier(ResultModifierType::LIMIT_PERCENT_MODIFIER) { -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -BoundAggregateExpression::BoundAggregateExpression(AggregateFunction function, vector> children, - unique_ptr filter, unique_ptr bind_info, - AggregateType aggr_type) - : Expression(ExpressionType::BOUND_AGGREGATE, ExpressionClass::BOUND_AGGREGATE, function.return_type), - function(std::move(function)), children(std::move(children)), bind_info(std::move(bind_info)), - aggr_type(aggr_type), filter(std::move(filter)) { - D_ASSERT(!this->function.name.empty()); -} - -string BoundAggregateExpression::ToString() const { - return FunctionExpression::ToString( - *this, string(), function.name, false, IsDistinct(), filter.get(), order_bys.get()); -} - -hash_t BoundAggregateExpression::Hash() const { - hash_t result = Expression::Hash(); - result = CombineHash(result, function.Hash()); - result = CombineHash(result, duckdb::Hash(IsDistinct())); - return result; -} - -bool BoundAggregateExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (other.aggr_type != aggr_type) { - return false; - } 
- if (other.function != function) { - return false; - } - if (children.size() != other.children.size()) { - return false; - } - if (!Expression::Equals(other.filter, filter)) { - return false; - } - for (idx_t i = 0; i < children.size(); i++) { - if (!Expression::Equals(*children[i], *other.children[i])) { - return false; - } - } - if (!FunctionData::Equals(bind_info.get(), other.bind_info.get())) { - return false; - } - if (!BoundOrderModifier::Equals(order_bys, other.order_bys)) { - return false; - } - return true; -} - -bool BoundAggregateExpression::PropagatesNullValues() const { - return function.null_handling == FunctionNullHandling::SPECIAL_HANDLING ? false - : Expression::PropagatesNullValues(); -} - -unique_ptr BoundAggregateExpression::Copy() { - vector> new_children; - new_children.reserve(children.size()); - for (auto &child : children) { - new_children.push_back(child->Copy()); - } - auto new_bind_info = bind_info ? bind_info->Copy() : nullptr; - auto new_filter = filter ? filter->Copy() : nullptr; - auto copy = make_uniq(function, std::move(new_children), std::move(new_filter), - std::move(new_bind_info), aggr_type); - copy->CopyProperties(*this); - copy->order_bys = order_bys ? 
order_bys->Copy() : nullptr; - return std::move(copy); -} - -void BoundAggregateExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WriteProperty(200, "return_type", return_type); - serializer.WriteProperty(201, "children", children); - FunctionSerializer::Serialize(serializer, function, bind_info.get()); - serializer.WriteProperty(203, "aggregate_type", aggr_type); - serializer.WritePropertyWithDefault(204, "filter", filter, unique_ptr()); - serializer.WritePropertyWithDefault(205, "order_bys", order_bys, unique_ptr()); -} - -unique_ptr BoundAggregateExpression::Deserialize(Deserializer &deserializer) { - auto return_type = deserializer.ReadProperty(200, "return_type"); - auto children = deserializer.ReadProperty>>(201, "children"); - auto entry = FunctionSerializer::Deserialize( - deserializer, CatalogType::AGGREGATE_FUNCTION_ENTRY, children, std::move(return_type)); - auto aggregate_type = deserializer.ReadProperty(203, "aggregate_type"); - auto filter = deserializer.ReadPropertyWithDefault>(204, "filter", unique_ptr()); - auto result = make_uniq(std::move(entry.first), std::move(children), std::move(filter), - std::move(entry.second), aggregate_type); - deserializer.ReadPropertyWithDefault(205, "order_bys", result->order_bys, unique_ptr()); - return std::move(result); -} - -} // namespace duckdb - - - -namespace duckdb { - -BoundBetweenExpression::BoundBetweenExpression() - : Expression(ExpressionType::COMPARE_BETWEEN, ExpressionClass::BOUND_BETWEEN, LogicalType::BOOLEAN) { -} - -BoundBetweenExpression::BoundBetweenExpression(unique_ptr input, unique_ptr lower, - unique_ptr upper, bool lower_inclusive, bool upper_inclusive) - : Expression(ExpressionType::COMPARE_BETWEEN, ExpressionClass::BOUND_BETWEEN, LogicalType::BOOLEAN), - input(std::move(input)), lower(std::move(lower)), upper(std::move(upper)), lower_inclusive(lower_inclusive), - upper_inclusive(upper_inclusive) { -} - -string 
BoundBetweenExpression::ToString() const { - return BetweenExpression::ToString(*this); -} - -bool BoundBetweenExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (!Expression::Equals(*input, *other.input)) { - return false; - } - if (!Expression::Equals(*lower, *other.lower)) { - return false; - } - if (!Expression::Equals(*upper, *other.upper)) { - return false; - } - return lower_inclusive == other.lower_inclusive && upper_inclusive == other.upper_inclusive; -} - -unique_ptr BoundBetweenExpression::Copy() { - auto copy = make_uniq(input->Copy(), lower->Copy(), upper->Copy(), lower_inclusive, - upper_inclusive); - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - -namespace duckdb { - -BoundCaseExpression::BoundCaseExpression(LogicalType type) - : Expression(ExpressionType::CASE_EXPR, ExpressionClass::BOUND_CASE, std::move(type)) { -} - -BoundCaseExpression::BoundCaseExpression(unique_ptr when_expr, unique_ptr then_expr, - unique_ptr else_expr_p) - : Expression(ExpressionType::CASE_EXPR, ExpressionClass::BOUND_CASE, then_expr->return_type), - else_expr(std::move(else_expr_p)) { - BoundCaseCheck check; - check.when_expr = std::move(when_expr); - check.then_expr = std::move(then_expr); - case_checks.push_back(std::move(check)); -} - -string BoundCaseExpression::ToString() const { - return CaseExpression::ToString(*this); -} - -bool BoundCaseExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (case_checks.size() != other.case_checks.size()) { - return false; - } - for (idx_t i = 0; i < case_checks.size(); i++) { - if (!Expression::Equals(*case_checks[i].when_expr, *other.case_checks[i].when_expr)) { - return false; - } - if (!Expression::Equals(*case_checks[i].then_expr, *other.case_checks[i].then_expr)) { - return false; - } - } 
- if (!Expression::Equals(*else_expr, *other.else_expr)) { - return false; - } - return true; -} - -unique_ptr BoundCaseExpression::Copy() { - auto new_case = make_uniq(return_type); - for (auto &check : case_checks) { - BoundCaseCheck new_check; - new_check.when_expr = check.when_expr->Copy(); - new_check.then_expr = check.then_expr->Copy(); - new_case->case_checks.push_back(std::move(new_check)); - } - new_case->else_expr = else_expr->Copy(); - - new_case->CopyProperties(*this); - return std::move(new_case); -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -static BoundCastInfo BindCastFunction(ClientContext &context, const LogicalType &source, const LogicalType &target) { - auto &cast_functions = DBConfig::GetConfig(context).GetCastFunctions(); - GetCastFunctionInput input(context); - return cast_functions.GetCastFunction(source, target, input); -} - -BoundCastExpression::BoundCastExpression(unique_ptr child_p, LogicalType target_type_p, - BoundCastInfo bound_cast_p, bool try_cast_p) - : Expression(ExpressionType::OPERATOR_CAST, ExpressionClass::BOUND_CAST, std::move(target_type_p)), - child(std::move(child_p)), try_cast(try_cast_p), bound_cast(std::move(bound_cast_p)) { -} - -BoundCastExpression::BoundCastExpression(ClientContext &context, unique_ptr child_p, - LogicalType target_type_p) - : Expression(ExpressionType::OPERATOR_CAST, ExpressionClass::BOUND_CAST, std::move(target_type_p)), - child(std::move(child_p)), try_cast(false), - bound_cast(BindCastFunction(context, child->return_type, return_type)) { -} - -unique_ptr AddCastExpressionInternal(unique_ptr expr, const LogicalType &target_type, - BoundCastInfo bound_cast, bool try_cast) { - if (expr->return_type == target_type) { - return expr; - } - auto &expr_type = expr->return_type; - if (target_type.id() == LogicalTypeId::LIST && expr_type.id() == LogicalTypeId::LIST) { - auto &target_list = ListType::GetChildType(target_type); - auto &expr_list = ListType::GetChildType(expr_type); - if 
(target_list.id() == LogicalTypeId::ANY || expr_list == target_list) { - return expr; - } - } - return make_uniq(std::move(expr), target_type, std::move(bound_cast), try_cast); -} - -unique_ptr AddCastToTypeInternal(unique_ptr expr, const LogicalType &target_type, - CastFunctionSet &cast_functions, GetCastFunctionInput &get_input, - bool try_cast) { - D_ASSERT(expr); - if (expr->expression_class == ExpressionClass::BOUND_PARAMETER) { - auto ¶meter = expr->Cast(); - if (!target_type.IsValid()) { - // invalidate the parameter - parameter.parameter_data->return_type = LogicalType::INVALID; - parameter.return_type = target_type; - return expr; - } - if (parameter.parameter_data->return_type.id() == LogicalTypeId::INVALID) { - // we don't know the type of this parameter - parameter.return_type = target_type; - return expr; - } - if (parameter.parameter_data->return_type.id() == LogicalTypeId::UNKNOWN) { - // prepared statement parameter cast - but there is no type, convert the type - parameter.parameter_data->return_type = target_type; - parameter.return_type = target_type; - return expr; - } - // prepared statement parameter already has a type - if (parameter.parameter_data->return_type == target_type) { - // this type! 
we are done - parameter.return_type = parameter.parameter_data->return_type; - return expr; - } - // invalidate the type - parameter.parameter_data->return_type = LogicalType::INVALID; - parameter.return_type = target_type; - return expr; - } else if (expr->expression_class == ExpressionClass::BOUND_DEFAULT) { - D_ASSERT(target_type.IsValid()); - auto &def = expr->Cast(); - def.return_type = target_type; - } - if (!target_type.IsValid()) { - return expr; - } - - auto cast_function = cast_functions.GetCastFunction(expr->return_type, target_type, get_input); - return AddCastExpressionInternal(std::move(expr), target_type, std::move(cast_function), try_cast); -} - -unique_ptr BoundCastExpression::AddDefaultCastToType(unique_ptr expr, - const LogicalType &target_type, bool try_cast) { - CastFunctionSet default_set; - GetCastFunctionInput get_input; - return AddCastToTypeInternal(std::move(expr), target_type, default_set, get_input, try_cast); -} - -unique_ptr BoundCastExpression::AddCastToType(ClientContext &context, unique_ptr expr, - const LogicalType &target_type, bool try_cast) { - auto &cast_functions = DBConfig::GetConfig(context).GetCastFunctions(); - GetCastFunctionInput get_input(context); - return AddCastToTypeInternal(std::move(expr), target_type, cast_functions, get_input, try_cast); -} - -bool BoundCastExpression::CastIsInvertible(const LogicalType &source_type, const LogicalType &target_type) { - D_ASSERT(source_type.IsValid() && target_type.IsValid()); - if (source_type.id() == LogicalTypeId::BOOLEAN || target_type.id() == LogicalTypeId::BOOLEAN) { - return false; - } - if (source_type.id() == LogicalTypeId::FLOAT || target_type.id() == LogicalTypeId::FLOAT) { - return false; - } - if (source_type.id() == LogicalTypeId::DOUBLE || target_type.id() == LogicalTypeId::DOUBLE) { - return false; - } - if (source_type.id() == LogicalTypeId::DECIMAL || target_type.id() == LogicalTypeId::DECIMAL) { - uint8_t source_width, target_width; - uint8_t source_scale, 
target_scale; - // cast to or from decimal - // cast is only invertible if the cast is strictly widening - if (!source_type.GetDecimalProperties(source_width, source_scale)) { - return false; - } - if (!target_type.GetDecimalProperties(target_width, target_scale)) { - return false; - } - if (target_scale < source_scale) { - return false; - } - return true; - } - if (source_type.id() == LogicalTypeId::TIMESTAMP || source_type.id() == LogicalTypeId::TIMESTAMP_TZ) { - switch (target_type.id()) { - case LogicalTypeId::DATE: - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return false; - default: - break; - } - } - if (source_type.id() == LogicalTypeId::VARCHAR) { - switch (target_type.id()) { - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_TZ: - return true; - default: - return false; - } - } - if (target_type.id() == LogicalTypeId::VARCHAR) { - switch (source_type.id()) { - case LogicalTypeId::DATE: - case LogicalTypeId::TIME: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::TIMESTAMP_TZ: - return true; - default: - return false; - } - } - return true; -} - -string BoundCastExpression::ToString() const { - return (try_cast ? 
"TRY_CAST(" : "CAST(") + child->GetName() + " AS " + return_type.ToString() + ")"; -} - -bool BoundCastExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (!Expression::Equals(*child, *other.child)) { - return false; - } - if (try_cast != other.try_cast) { - return false; - } - return true; -} - -unique_ptr BoundCastExpression::Copy() { - auto copy = make_uniq(child->Copy(), return_type, bound_cast.Copy(), try_cast); - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -BoundColumnRefExpression::BoundColumnRefExpression(string alias_p, LogicalType type, ColumnBinding binding, idx_t depth) - : Expression(ExpressionType::BOUND_COLUMN_REF, ExpressionClass::BOUND_COLUMN_REF, std::move(type)), - binding(binding), depth(depth) { - this->alias = std::move(alias_p); -} - -BoundColumnRefExpression::BoundColumnRefExpression(LogicalType type, ColumnBinding binding, idx_t depth) - : BoundColumnRefExpression(string(), std::move(type), binding, depth) { -} - -unique_ptr BoundColumnRefExpression::Copy() { - return make_uniq(alias, return_type, binding, depth); -} - -hash_t BoundColumnRefExpression::Hash() const { - auto result = Expression::Hash(); - result = CombineHash(result, duckdb::Hash(binding.column_index)); - result = CombineHash(result, duckdb::Hash(binding.table_index)); - return CombineHash(result, duckdb::Hash(depth)); -} - -bool BoundColumnRefExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return other.binding == binding && other.depth == depth; -} - -string BoundColumnRefExpression::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return binding.ToString(); - } -#endif - return Expression::GetName(); -} - -string BoundColumnRefExpression::ToString() const { -#ifdef DEBUG - 
if (DBConfigOptions::debug_print_bindings) { - return binding.ToString(); - } -#endif - if (!alias.empty()) { - return alias; - } - return binding.ToString(); -} - -} // namespace duckdb - - - -namespace duckdb { - -BoundComparisonExpression::BoundComparisonExpression(ExpressionType type, unique_ptr left, - unique_ptr right) - : Expression(type, ExpressionClass::BOUND_COMPARISON, LogicalType::BOOLEAN), left(std::move(left)), - right(std::move(right)) { -} - -string BoundComparisonExpression::ToString() const { - return ComparisonExpression::ToString(*this); -} - -bool BoundComparisonExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (!Expression::Equals(*left, *other.left)) { - return false; - } - if (!Expression::Equals(*right, *other.right)) { - return false; - } - return true; -} - -unique_ptr BoundComparisonExpression::Copy() { - auto copy = make_uniq(type, left->Copy(), right->Copy()); - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - - -namespace duckdb { - -BoundConjunctionExpression::BoundConjunctionExpression(ExpressionType type) - : Expression(type, ExpressionClass::BOUND_CONJUNCTION, LogicalType::BOOLEAN) { -} - -BoundConjunctionExpression::BoundConjunctionExpression(ExpressionType type, unique_ptr left, - unique_ptr right) - : BoundConjunctionExpression(type) { - children.push_back(std::move(left)); - children.push_back(std::move(right)); -} - -string BoundConjunctionExpression::ToString() const { - return ConjunctionExpression::ToString(*this); -} - -bool BoundConjunctionExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return ExpressionUtil::SetEquals(children, other.children); -} - -bool BoundConjunctionExpression::PropagatesNullValues() const { - return false; -} - -unique_ptr BoundConjunctionExpression::Copy() 
{ - auto copy = make_uniq(type); - for (auto &expr : children) { - copy->children.push_back(expr->Copy()); - } - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - - - -namespace duckdb { - -BoundConstantExpression::BoundConstantExpression(Value value_p) - : Expression(ExpressionType::VALUE_CONSTANT, ExpressionClass::BOUND_CONSTANT, value_p.type()), - value(std::move(value_p)) { -} - -string BoundConstantExpression::ToString() const { - return value.ToSQLString(); -} - -bool BoundConstantExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return value.type() == other.value.type() && !ValueOperations::DistinctFrom(value, other.value); -} - -hash_t BoundConstantExpression::Hash() const { - hash_t result = Expression::Hash(); - return CombineHash(value.Hash(), result); -} - -unique_ptr BoundConstantExpression::Copy() { - auto copy = make_uniq(value); - copy->CopyProperties(*this); - return std::move(copy); -} - -} // namespace duckdb - - -namespace duckdb { - -BoundExpression::BoundExpression(unique_ptr expr_p) - : ParsedExpression(ExpressionType::INVALID, ExpressionClass::BOUND_EXPRESSION), expr(std::move(expr_p)) { - this->alias = expr->alias; -} - -unique_ptr &BoundExpression::GetExpression(ParsedExpression &expr) { - auto &bound_expr = expr.Cast(); - if (!bound_expr.expr) { - throw InternalException("BoundExpression::GetExpression called on empty bound expression"); - } - return bound_expr.expr; -} - -string BoundExpression::ToString() const { - if (!expr) { - throw InternalException("ToString(): BoundExpression does not have a child"); - } - return expr->ToString(); -} - -bool BoundExpression::Equals(const BaseExpression &other) const { - return false; -} -hash_t BoundExpression::Hash() const { - return 0; -} - -unique_ptr BoundExpression::Copy() const { - throw SerializationException("Cannot copy or serialize bound expression"); 
-} - -void BoundExpression::Serialize(Serializer &serializer) const { - throw SerializationException("Cannot copy or serialize bound expression"); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -BoundFunctionExpression::BoundFunctionExpression(LogicalType return_type, ScalarFunction bound_function, - vector> arguments, - unique_ptr bind_info, bool is_operator) - : Expression(ExpressionType::BOUND_FUNCTION, ExpressionClass::BOUND_FUNCTION, std::move(return_type)), - function(std::move(bound_function)), children(std::move(arguments)), bind_info(std::move(bind_info)), - is_operator(is_operator) { - D_ASSERT(!function.name.empty()); -} - -bool BoundFunctionExpression::HasSideEffects() const { - return function.side_effects == FunctionSideEffects::HAS_SIDE_EFFECTS ? true : Expression::HasSideEffects(); -} - -bool BoundFunctionExpression::IsFoldable() const { - // functions with side effects cannot be folded: they have to be executed once for every row - return function.side_effects == FunctionSideEffects::HAS_SIDE_EFFECTS ? false : Expression::IsFoldable(); -} - -string BoundFunctionExpression::ToString() const { - return FunctionExpression::ToString(*this, string(), function.name, - is_operator); -} -bool BoundFunctionExpression::PropagatesNullValues() const { - return function.null_handling == FunctionNullHandling::SPECIAL_HANDLING ? 
false - : Expression::PropagatesNullValues(); -} - -hash_t BoundFunctionExpression::Hash() const { - hash_t result = Expression::Hash(); - return CombineHash(result, function.Hash()); -} - -bool BoundFunctionExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (other.function != function) { - return false; - } - if (!Expression::ListEquals(children, other.children)) { - return false; - } - if (!FunctionData::Equals(bind_info.get(), other.bind_info.get())) { - return false; - } - return true; -} - -unique_ptr BoundFunctionExpression::Copy() { - vector> new_children; - new_children.reserve(children.size()); - for (auto &child : children) { - new_children.push_back(child->Copy()); - } - unique_ptr new_bind_info = bind_info ? bind_info->Copy() : nullptr; - - auto copy = make_uniq(return_type, function, std::move(new_children), - std::move(new_bind_info), is_operator); - copy->CopyProperties(*this); - return std::move(copy); -} - -void BoundFunctionExpression::Verify() const { - D_ASSERT(!function.name.empty()); -} - -void BoundFunctionExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WriteProperty(200, "return_type", return_type); - serializer.WriteProperty(201, "children", children); - FunctionSerializer::Serialize(serializer, function, bind_info.get()); - serializer.WriteProperty(202, "is_operator", is_operator); -} - -unique_ptr BoundFunctionExpression::Deserialize(Deserializer &deserializer) { - auto return_type = deserializer.ReadProperty(200, "return_type"); - auto children = deserializer.ReadProperty>>(201, "children"); - auto entry = FunctionSerializer::Deserialize( - deserializer, CatalogType::SCALAR_FUNCTION_ENTRY, children, return_type); - auto result = make_uniq(std::move(return_type), std::move(entry.first), - std::move(children), std::move(entry.second)); - deserializer.ReadProperty(202, "is_operator", 
result->is_operator); - return std::move(result); -} - -} // namespace duckdb - - - -namespace duckdb { - -BoundLambdaExpression::BoundLambdaExpression(ExpressionType type_p, LogicalType return_type_p, - unique_ptr lambda_expr_p, idx_t parameter_count_p) - : Expression(type_p, ExpressionClass::BOUND_LAMBDA, std::move(return_type_p)), - lambda_expr(std::move(lambda_expr_p)), parameter_count(parameter_count_p) { -} - -string BoundLambdaExpression::ToString() const { - return lambda_expr->ToString(); -} - -bool BoundLambdaExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (!Expression::Equals(*lambda_expr, *other.lambda_expr)) { - return false; - } - if (!Expression::ListEquals(captures, other.captures)) { - return false; - } - if (parameter_count != other.parameter_count) { - return false; - } - return true; -} - -unique_ptr BoundLambdaExpression::Copy() { - auto copy = make_uniq(type, return_type, lambda_expr->Copy(), parameter_count); - for (auto &capture : captures) { - copy->captures.push_back(capture->Copy()); - } - return std::move(copy); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -BoundLambdaRefExpression::BoundLambdaRefExpression(string alias_p, LogicalType type, ColumnBinding binding, - idx_t lambda_index, idx_t depth) - : Expression(ExpressionType::BOUND_LAMBDA_REF, ExpressionClass::BOUND_LAMBDA_REF, std::move(type)), - binding(binding), lambda_index(lambda_index), depth(depth) { - this->alias = std::move(alias_p); -} - -BoundLambdaRefExpression::BoundLambdaRefExpression(LogicalType type, ColumnBinding binding, idx_t lambda_index, - idx_t depth) - : BoundLambdaRefExpression(string(), std::move(type), binding, lambda_index, depth) { -} - -unique_ptr BoundLambdaRefExpression::Copy() { - return make_uniq(alias, return_type, binding, lambda_index, depth); -} - -hash_t BoundLambdaRefExpression::Hash() const { - auto result = Expression::Hash(); 
- result = CombineHash(result, duckdb::Hash(lambda_index)); - result = CombineHash(result, duckdb::Hash(binding.column_index)); - result = CombineHash(result, duckdb::Hash(binding.table_index)); - return CombineHash(result, duckdb::Hash(depth)); -} - -bool BoundLambdaRefExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return other.binding == binding && other.lambda_index == lambda_index && other.depth == depth; -} - -string BoundLambdaRefExpression::ToString() const { - if (!alias.empty()) { - return alias; - } - return "#[" + to_string(binding.table_index) + "." + to_string(binding.column_index) + "." + - to_string(lambda_index) + "]"; -} - -} // namespace duckdb - - - - -namespace duckdb { - -BoundOperatorExpression::BoundOperatorExpression(ExpressionType type, LogicalType return_type) - : Expression(type, ExpressionClass::BOUND_OPERATOR, std::move(return_type)) { -} - -string BoundOperatorExpression::ToString() const { - return OperatorExpression::ToString(*this); -} - -bool BoundOperatorExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (!Expression::ListEquals(children, other.children)) { - return false; - } - return true; -} - -unique_ptr BoundOperatorExpression::Copy() { - auto copy = make_uniq(type, return_type); - copy->CopyProperties(*this); - for (auto &child : children) { - copy->children.push_back(child->Copy()); - } - return std::move(copy); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -BoundParameterExpression::BoundParameterExpression(const string &identifier) - : Expression(ExpressionType::VALUE_PARAMETER, ExpressionClass::BOUND_PARAMETER, - LogicalType(LogicalTypeId::UNKNOWN)), - identifier(identifier) { -} - -BoundParameterExpression::BoundParameterExpression(bound_parameter_map_t &global_parameter_set, string identifier, - 
LogicalType return_type, - shared_ptr parameter_data) - : Expression(ExpressionType::VALUE_PARAMETER, ExpressionClass::BOUND_PARAMETER, std::move(return_type)), - identifier(std::move(identifier)) { - // check if we have already deserialized a parameter with this number - auto entry = global_parameter_set.find(this->identifier); - if (entry == global_parameter_set.end()) { - // we have not - store the entry we deserialized from this parameter expression - global_parameter_set[this->identifier] = parameter_data; - } else { - // we have! use the previously deserialized entry - parameter_data = entry->second; - } - this->parameter_data = std::move(parameter_data); -} - -void BoundParameterExpression::Invalidate(Expression &expr) { - if (expr.type != ExpressionType::VALUE_PARAMETER) { - throw InternalException("BoundParameterExpression::Invalidate requires a parameter as input"); - } - auto &bound_parameter = expr.Cast(); - bound_parameter.return_type = LogicalTypeId::SQLNULL; - bound_parameter.parameter_data->return_type = LogicalTypeId::INVALID; -} - -void BoundParameterExpression::InvalidateRecursive(Expression &expr) { - if (expr.type == ExpressionType::VALUE_PARAMETER) { - Invalidate(expr); - return; - } - ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { InvalidateRecursive(child); }); -} - -bool BoundParameterExpression::IsScalar() const { - return true; -} -bool BoundParameterExpression::HasParameter() const { - return true; -} -bool BoundParameterExpression::IsFoldable() const { - return false; -} - -string BoundParameterExpression::ToString() const { - return "$" + identifier; -} - -bool BoundParameterExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return StringUtil::CIEquals(identifier, other.identifier); -} - -hash_t BoundParameterExpression::Hash() const { - hash_t result = Expression::Hash(); - result = 
CombineHash(duckdb::Hash(identifier.c_str(), identifier.size()), result); - return result; -} - -unique_ptr BoundParameterExpression::Copy() { - auto result = make_uniq(identifier); - result->parameter_data = parameter_data; - result->return_type = return_type; - result->CopyProperties(*this); - return std::move(result); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -BoundReferenceExpression::BoundReferenceExpression(string alias, LogicalType type, idx_t index) - : Expression(ExpressionType::BOUND_REF, ExpressionClass::BOUND_REF, std::move(type)), index(index) { - this->alias = std::move(alias); -} -BoundReferenceExpression::BoundReferenceExpression(LogicalType type, idx_t index) - : BoundReferenceExpression(string(), std::move(type), index) { -} - -string BoundReferenceExpression::ToString() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return "#" + to_string(index); - } -#endif - if (!alias.empty()) { - return alias; - } - return "#" + to_string(index); -} - -bool BoundReferenceExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return other.index == index; -} - -hash_t BoundReferenceExpression::Hash() const { - return CombineHash(Expression::Hash(), duckdb::Hash(index)); -} - -unique_ptr BoundReferenceExpression::Copy() { - return make_uniq(alias, return_type, index); -} - -} // namespace duckdb - - - - -namespace duckdb { - -BoundSubqueryExpression::BoundSubqueryExpression(LogicalType return_type) - : Expression(ExpressionType::SUBQUERY, ExpressionClass::BOUND_SUBQUERY, std::move(return_type)) { -} - -string BoundSubqueryExpression::ToString() const { - return "SUBQUERY"; -} - -bool BoundSubqueryExpression::Equals(const BaseExpression &other_p) const { - // equality between bound subqueries not implemented currently - return false; -} - -unique_ptr BoundSubqueryExpression::Copy() { - throw SerializationException("Cannot copy 
BoundSubqueryExpression"); -} - -bool BoundSubqueryExpression::PropagatesNullValues() const { - // TODO this can be optimized further by checking the actual subquery node - return false; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -BoundUnnestExpression::BoundUnnestExpression(LogicalType return_type) - : Expression(ExpressionType::BOUND_UNNEST, ExpressionClass::BOUND_UNNEST, std::move(return_type)) { -} - -bool BoundUnnestExpression::IsFoldable() const { - return false; -} - -string BoundUnnestExpression::ToString() const { - return "UNNEST(" + child->ToString() + ")"; -} - -hash_t BoundUnnestExpression::Hash() const { - hash_t result = Expression::Hash(); - return CombineHash(result, duckdb::Hash("unnest")); -} - -bool BoundUnnestExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (!Expression::Equals(*child, *other.child)) { - return false; - } - return true; -} - -unique_ptr BoundUnnestExpression::Copy() { - auto copy = make_uniq(return_type); - copy->child = child->Copy(); - return std::move(copy); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -BoundWindowExpression::BoundWindowExpression(ExpressionType type, LogicalType return_type, - unique_ptr aggregate, - unique_ptr bind_info) - : Expression(type, ExpressionClass::BOUND_WINDOW, std::move(return_type)), aggregate(std::move(aggregate)), - bind_info(std::move(bind_info)), ignore_nulls(false) { -} - -string BoundWindowExpression::ToString() const { - string function_name = aggregate.get() ? 
aggregate->name : ExpressionTypeToString(type); - return WindowExpression::ToString(*this, string(), - function_name); -} - -bool BoundWindowExpression::Equals(const BaseExpression &other_p) const { - if (!Expression::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - - if (ignore_nulls != other.ignore_nulls) { - return false; - } - if (start != other.start || end != other.end) { - return false; - } - // check if the child expressions are equivalent - if (!Expression::ListEquals(children, other.children)) { - return false; - } - // check if the filter expressions are equivalent - if (!Expression::Equals(filter_expr, other.filter_expr)) { - return false; - } - - // check if the framing expressions are equivalent - if (!Expression::Equals(start_expr, other.start_expr) || !Expression::Equals(end_expr, other.end_expr) || - !Expression::Equals(offset_expr, other.offset_expr) || !Expression::Equals(default_expr, other.default_expr)) { - return false; - } - - return KeysAreCompatible(other); -} - -bool BoundWindowExpression::KeysAreCompatible(const BoundWindowExpression &other) const { - // check if the partitions are equivalent - if (!Expression::ListEquals(partitions, other.partitions)) { - return false; - } - // check if the orderings are equivalent - if (orders.size() != other.orders.size()) { - return false; - } - for (idx_t i = 0; i < orders.size(); i++) { - if (!orders[i].Equals(other.orders[i])) { - return false; - } - } - return true; -} - -unique_ptr BoundWindowExpression::Copy() { - auto new_window = make_uniq(type, return_type, nullptr, nullptr); - new_window->CopyProperties(*this); - - if (aggregate) { - new_window->aggregate = make_uniq(*aggregate); - } - if (bind_info) { - new_window->bind_info = bind_info->Copy(); - } - for (auto &child : children) { - new_window->children.push_back(child->Copy()); - } - for (auto &e : partitions) { - new_window->partitions.push_back(e->Copy()); - } - for (auto &ps : partitions_stats) { - if (ps) { - 
new_window->partitions_stats.push_back(ps->ToUnique()); - } else { - new_window->partitions_stats.push_back(nullptr); - } - } - for (auto &o : orders) { - new_window->orders.emplace_back(o.type, o.null_order, o.expression->Copy()); - } - - new_window->filter_expr = filter_expr ? filter_expr->Copy() : nullptr; - - new_window->start = start; - new_window->end = end; - new_window->start_expr = start_expr ? start_expr->Copy() : nullptr; - new_window->end_expr = end_expr ? end_expr->Copy() : nullptr; - new_window->offset_expr = offset_expr ? offset_expr->Copy() : nullptr; - new_window->default_expr = default_expr ? default_expr->Copy() : nullptr; - new_window->ignore_nulls = ignore_nulls; - - return std::move(new_window); -} - -void BoundWindowExpression::Serialize(Serializer &serializer) const { - Expression::Serialize(serializer); - serializer.WriteProperty(200, "return_type", return_type); - serializer.WriteProperty(201, "children", children); - if (type == ExpressionType::WINDOW_AGGREGATE) { - D_ASSERT(aggregate); - FunctionSerializer::Serialize(serializer, *aggregate, bind_info.get()); - } - serializer.WriteProperty(202, "partitions", partitions); - serializer.WriteProperty(203, "orders", orders); - serializer.WritePropertyWithDefault(204, "filters", filter_expr, unique_ptr()); - serializer.WriteProperty(205, "ignore_nulls", ignore_nulls); - serializer.WriteProperty(206, "start", start); - serializer.WriteProperty(207, "end", end); - serializer.WritePropertyWithDefault(208, "start_expr", start_expr, unique_ptr()); - serializer.WritePropertyWithDefault(209, "end_expr", end_expr, unique_ptr()); - serializer.WritePropertyWithDefault(210, "offset_expr", offset_expr, unique_ptr()); - serializer.WritePropertyWithDefault(211, "default_expr", default_expr, unique_ptr()); -} - -unique_ptr BoundWindowExpression::Deserialize(Deserializer &deserializer) { - auto expression_type = deserializer.Get(); - auto return_type = deserializer.ReadProperty(200, "return_type"); - auto 
children = deserializer.ReadProperty>>(201, "children"); - unique_ptr aggregate; - unique_ptr bind_info; - if (expression_type == ExpressionType::WINDOW_AGGREGATE) { - auto entry = FunctionSerializer::Deserialize( - deserializer, CatalogType::AGGREGATE_FUNCTION_ENTRY, children, return_type); - aggregate = make_uniq(std::move(entry.first)); - bind_info = std::move(entry.second); - } - auto result = - make_uniq(expression_type, return_type, std::move(aggregate), std::move(bind_info)); - result->children = std::move(children); - deserializer.ReadProperty(202, "partitions", result->partitions); - deserializer.ReadProperty(203, "orders", result->orders); - deserializer.ReadPropertyWithDefault(204, "filters", result->filter_expr, unique_ptr()); - deserializer.ReadProperty(205, "ignore_nulls", result->ignore_nulls); - deserializer.ReadProperty(206, "start", result->start); - deserializer.ReadProperty(207, "end", result->end); - deserializer.ReadPropertyWithDefault(208, "start_expr", result->start_expr, unique_ptr()); - deserializer.ReadPropertyWithDefault(209, "end_expr", result->end_expr, unique_ptr()); - deserializer.ReadPropertyWithDefault(210, "offset_expr", result->offset_expr, unique_ptr()); - deserializer.ReadPropertyWithDefault(211, "default_expr", result->default_expr, unique_ptr()); - return std::move(result); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -Expression::Expression(ExpressionType type, ExpressionClass expression_class, LogicalType return_type) - : BaseExpression(type, expression_class), return_type(std::move(return_type)) { -} - -Expression::~Expression() { -} - -bool Expression::IsAggregate() const { - bool is_aggregate = false; - ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { is_aggregate |= child.IsAggregate(); }); - return is_aggregate; -} - -bool Expression::IsWindow() const { - bool is_window = false; - ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { is_window |= 
child.IsWindow(); }); - return is_window; -} - -bool Expression::IsScalar() const { - bool is_scalar = true; - ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { - if (!child.IsScalar()) { - is_scalar = false; - } - }); - return is_scalar; -} - -bool Expression::HasSideEffects() const { - bool has_side_effects = false; - ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { - if (child.HasSideEffects()) { - has_side_effects = true; - } - }); - return has_side_effects; -} - -bool Expression::PropagatesNullValues() const { - if (type == ExpressionType::OPERATOR_IS_NULL || type == ExpressionType::OPERATOR_IS_NOT_NULL || - type == ExpressionType::COMPARE_NOT_DISTINCT_FROM || type == ExpressionType::COMPARE_DISTINCT_FROM || - type == ExpressionType::CONJUNCTION_OR || type == ExpressionType::CONJUNCTION_AND || - type == ExpressionType::OPERATOR_COALESCE) { - return false; - } - bool propagate_null_values = true; - ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { - if (!child.PropagatesNullValues()) { - propagate_null_values = false; - } - }); - return propagate_null_values; -} - -bool Expression::IsFoldable() const { - bool is_foldable = true; - ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { - if (!child.IsFoldable()) { - is_foldable = false; - } - }); - return is_foldable; -} - -bool Expression::HasParameter() const { - bool has_parameter = false; - ExpressionIterator::EnumerateChildren(*this, - [&](const Expression &child) { has_parameter |= child.HasParameter(); }); - return has_parameter; -} - -bool Expression::HasSubquery() const { - bool has_subquery = false; - ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { has_subquery |= child.HasSubquery(); }); - return has_subquery; -} - -hash_t Expression::Hash() const { - hash_t hash = duckdb::Hash((uint32_t)type); - hash = CombineHash(hash, return_type.Hash()); - 
ExpressionIterator::EnumerateChildren(*this, - [&](const Expression &child) { hash = CombineHash(child.Hash(), hash); }); - return hash; -} - -bool Expression::Equals(const unique_ptr &left, const unique_ptr &right) { - if (left.get() == right.get()) { - return true; - } - if (!left || !right) { - return false; - } - return left->Equals(*right); -} - -bool Expression::ListEquals(const vector> &left, const vector> &right) { - return ExpressionUtil::ListEquals(left, right); -} - -} // namespace duckdb - - - - -namespace duckdb { - -AggregateBinder::AggregateBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context, true) { -} - -BindResult AggregateBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.expression_class) { - case ExpressionClass::WINDOW: - throw ParserException("aggregate function calls cannot contain window function calls"); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string AggregateBinder::UnsupportedAggregateMessage() { - return "aggregate function calls cannot be nested"; -} -} // namespace duckdb - - - - - - -namespace duckdb { - -AlterBinder::AlterBinder(Binder &binder, ClientContext &context, TableCatalogEntry &table, - vector &bound_columns, LogicalType target_type) - : ExpressionBinder(binder, context), table(table), bound_columns(bound_columns) { - this->target_type = std::move(target_type); -} - -BindResult AlterBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.GetExpressionClass()) { - case ExpressionClass::WINDOW: - return BindResult("window functions are not allowed in alter statement"); - case ExpressionClass::SUBQUERY: - return BindResult("cannot use subquery in alter statement"); - case ExpressionClass::COLUMN_REF: - return BindColumn(expr.Cast()); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string 
AlterBinder::UnsupportedAggregateMessage() { - return "aggregate functions are not allowed in alter statement"; -} - -BindResult AlterBinder::BindColumn(ColumnRefExpression &colref) { - if (colref.column_names.size() > 1) { - return BindQualifiedColumnName(colref, table.name); - } - auto idx = table.GetColumnIndex(colref.column_names[0], true); - if (!idx.IsValid()) { - throw BinderException("Table does not contain column %s referenced in alter statement!", - colref.column_names[0]); - } - if (table.GetColumn(idx).Generated()) { - throw BinderException("Using generated columns in alter statement not supported"); - } - bound_columns.push_back(idx); - return BindResult(make_uniq(table.GetColumn(idx).Type(), bound_columns.size() - 1)); -} - -} // namespace duckdb - - - - - - - - - - - - - -namespace duckdb { - -BaseSelectBinder::BaseSelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, - BoundGroupInformation &info, case_insensitive_map_t alias_map) - : ExpressionBinder(binder, context), inside_window(false), node(node), info(info), alias_map(std::move(alias_map)) { -} - -BaseSelectBinder::BaseSelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, - BoundGroupInformation &info) - : BaseSelectBinder(binder, context, node, info, case_insensitive_map_t()) { -} - -BindResult BaseSelectBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - // check if the expression binds to one of the groups - auto group_index = TryBindGroup(expr, depth); - if (group_index != DConstants::INVALID_INDEX) { - return BindGroup(expr, depth, group_index); - } - switch (expr.expression_class) { - case ExpressionClass::COLUMN_REF: - return BindColumnRef(expr_ptr, depth); - case ExpressionClass::DEFAULT: - return BindResult("SELECT clause cannot contain DEFAULT clause"); - case ExpressionClass::WINDOW: - return BindWindow(expr.Cast(), depth); - default: - return 
ExpressionBinder::BindExpression(expr_ptr, depth, root_expression); - } -} - -idx_t BaseSelectBinder::TryBindGroup(ParsedExpression &expr, idx_t depth) { - // first check the group alias map, if expr is a ColumnRefExpression - if (expr.type == ExpressionType::COLUMN_REF) { - auto &colref = expr.Cast(); - if (!colref.IsQualified()) { - auto alias_entry = info.alias_map.find(colref.column_names[0]); - if (alias_entry != info.alias_map.end()) { - // found entry! - return alias_entry->second; - } - } - } - // no alias reference found - // check the list of group columns for a match - auto entry = info.map.find(expr); - if (entry != info.map.end()) { - return entry->second; - } -#ifdef DEBUG - for (auto entry : info.map) { - D_ASSERT(!entry.first.get().Equals(expr)); - D_ASSERT(!expr.Equals(entry.first.get())); - } -#endif - return DConstants::INVALID_INDEX; -} - -BindResult BaseSelectBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth) { - // first try to bind the column reference regularly - auto result = ExpressionBinder::BindExpression(expr_ptr, depth); - if (!result.HasError()) { - return result; - } - // binding failed - // check in the alias map - auto &colref = (expr_ptr.get())->Cast(); - if (!colref.IsQualified()) { - auto alias_entry = alias_map.find(colref.column_names[0]); - if (alias_entry != alias_map.end()) { - // found entry! - auto index = alias_entry->second; - if (index >= node.select_list.size()) { - throw BinderException("Column \"%s\" referenced that exists in the SELECT clause - but this column " - "cannot be referenced before it is defined", - colref.column_names[0]); - } - if (node.select_list[index]->HasSideEffects()) { - throw BinderException("Alias \"%s\" referenced in a SELECT clause - but the expression has side " - "effects. This is not yet supported.", - colref.column_names[0]); - } - if (node.select_list[index]->HasSubquery()) { - throw BinderException("Alias \"%s\" referenced in a SELECT clause - but the expression has a subquery." 
- " This is not yet supported.", - colref.column_names[0]); - } - auto result = BindResult(node.select_list[index]->Copy()); - if (result.expression->type == ExpressionType::BOUND_COLUMN_REF) { - auto &result_expr = result.expression->Cast(); - result_expr.depth = depth; - } - return result; - } - } - // entry was not found in the alias map: return the original error - return result; -} - -BindResult BaseSelectBinder::BindGroupingFunction(OperatorExpression &op, idx_t depth) { - if (op.children.empty()) { - throw InternalException("GROUPING requires at least one child"); - } - if (node.groups.group_expressions.empty()) { - return BindResult(binder.FormatError(op, "GROUPING statement cannot be used without groups")); - } - if (op.children.size() >= 64) { - return BindResult(binder.FormatError(op, "GROUPING statement cannot have more than 64 groups")); - } - vector group_indexes; - group_indexes.reserve(op.children.size()); - for (auto &child : op.children) { - ExpressionBinder::QualifyColumnNames(binder, child); - auto idx = TryBindGroup(*child, depth); - if (idx == DConstants::INVALID_INDEX) { - return BindResult(binder.FormatError( - op, StringUtil::Format("GROUPING child \"%s\" must be a grouping column", child->GetName()))); - } - group_indexes.push_back(idx); - } - auto col_idx = node.grouping_functions.size(); - node.grouping_functions.push_back(std::move(group_indexes)); - return BindResult(make_uniq(op.GetName(), LogicalType::BIGINT, - ColumnBinding(node.groupings_index, col_idx), depth)); -} - -BindResult BaseSelectBinder::BindGroup(ParsedExpression &expr, idx_t depth, idx_t group_index) { - auto it = info.collated_groups.find(group_index); - if (it != info.collated_groups.end()) { - // This is an implicitly collated group, so we need to refer to the first() aggregate - const auto &aggr_index = it->second; - return BindResult(make_uniq(expr.GetName(), node.aggregates[aggr_index]->return_type, - ColumnBinding(node.aggregate_index, aggr_index), depth)); - } 
else { - auto &group = node.groups.group_expressions[group_index]; - return BindResult(make_uniq(expr.GetName(), group->return_type, - ColumnBinding(node.group_index, group_index), depth)); - } -} - -bool BaseSelectBinder::QualifyColumnAlias(const ColumnRefExpression &colref) { - if (!colref.IsQualified()) { - return alias_map.find(colref.column_names[0]) != alias_map.end() ? true : false; - } - return false; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -CheckBinder::CheckBinder(Binder &binder, ClientContext &context, string table_p, const ColumnList &columns, - physical_index_set_t &bound_columns) - : ExpressionBinder(binder, context), table(std::move(table_p)), columns(columns), bound_columns(bound_columns) { - target_type = LogicalType::INTEGER; -} - -BindResult CheckBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.GetExpressionClass()) { - case ExpressionClass::WINDOW: - return BindResult("window functions are not allowed in check constraints"); - case ExpressionClass::SUBQUERY: - return BindResult("cannot use subquery in check constraint"); - case ExpressionClass::COLUMN_REF: - return BindCheckColumn(expr.Cast()); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string CheckBinder::UnsupportedAggregateMessage() { - return "aggregate functions are not allowed in check constraints"; -} - -BindResult ExpressionBinder::BindQualifiedColumnName(ColumnRefExpression &colref, const string &table_name) { - idx_t struct_start = 0; - if (colref.column_names[0] == table_name) { - struct_start++; - } - auto result = make_uniq_base(colref.column_names.back()); - for (idx_t i = struct_start; i + 1 < colref.column_names.size(); i++) { - result = CreateStructExtract(std::move(result), colref.column_names[i]); - } - return BindExpression(result, 0); -} - -BindResult CheckBinder::BindCheckColumn(ColumnRefExpression &colref) { - - // if this is a lambda 
parameters, then we temporarily add a BoundLambdaRef, - // which we capture and remove later - if (lambda_bindings) { - for (idx_t i = 0; i < lambda_bindings->size(); i++) { - if (colref.GetColumnName() == (*lambda_bindings)[i].dummy_name) { - // FIXME: support lambdas in CHECK constraints - // FIXME: like so: return (*lambda_bindings)[i].Bind(colref, i, depth); - throw NotImplementedException("Lambda functions are currently not supported in CHECK constraints."); - } - } - } - - if (colref.column_names.size() > 1) { - return BindQualifiedColumnName(colref, table); - } - if (!columns.ColumnExists(colref.column_names[0])) { - throw BinderException("Table does not contain column %s referenced in check constraint!", - colref.column_names[0]); - } - auto &col = columns.GetColumn(colref.column_names[0]); - if (col.Generated()) { - auto bound_expression = col.GeneratedExpression().Copy(); - return BindExpression(bound_expression, 0, false); - } - bound_columns.insert(col.Physical()); - D_ASSERT(col.StorageOid() != DConstants::INVALID_INDEX); - return BindResult(make_uniq(col.Type(), col.StorageOid())); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -ColumnAliasBinder::ColumnAliasBinder(BoundSelectNode &node, const case_insensitive_map_t &alias_map) - : node(node), alias_map(alias_map), visited_select_indexes() { -} - -BindResult ColumnAliasBinder::BindAlias(ExpressionBinder &enclosing_binder, ColumnRefExpression &expr, idx_t depth, - bool root_expression) { - if (expr.IsQualified()) { - return BindResult(StringUtil::Format("Alias %s cannot be qualified.", expr.ToString())); - } - - auto alias_entry = alias_map.find(expr.column_names[0]); - if (alias_entry == alias_map.end()) { - return BindResult(StringUtil::Format("Alias %s is not found.", expr.ToString())); - } - - if (visited_select_indexes.find(alias_entry->second) != visited_select_indexes.end()) { - return BindResult("Cannot resolve self-referential alias"); - } - - // found an alias: bind the 
alias expression - auto expression = node.original_expressions[alias_entry->second]->Copy(); - visited_select_indexes.insert(alias_entry->second); - - // since the alias has been found, pass a depth of 0. See Issue 4978 (#16) - // ColumnAliasBinders are only in Having, Qualify and Where Binders - auto result = enclosing_binder.BindExpression(expression, 0, root_expression); - visited_select_indexes.erase(alias_entry->second); - return result; -} - -} // namespace duckdb - - - -namespace duckdb { - -ConstantBinder::ConstantBinder(Binder &binder, ClientContext &context, string clause) - : ExpressionBinder(binder, context), clause(std::move(clause)) { -} - -BindResult ConstantBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.GetExpressionClass()) { - case ExpressionClass::COLUMN_REF: { - auto &colref = expr.Cast(); - if (!colref.IsQualified()) { - auto value_function = GetSQLValueFunction(colref.GetColumnName()); - if (value_function) { - expr_ptr = std::move(value_function); - return BindExpression(expr_ptr, depth, root_expression); - } - } - return BindResult(clause + " cannot contain column names"); - } - case ExpressionClass::SUBQUERY: - throw BinderException(clause + " cannot contain subqueries"); - case ExpressionClass::DEFAULT: - return BindResult(clause + " cannot contain DEFAULT clause"); - case ExpressionClass::WINDOW: - return BindResult(clause + " cannot contain window functions!"); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string ConstantBinder::UnsupportedAggregateMessage() { - return clause + " cannot contain aggregates!"; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -GroupBinder::GroupBinder(Binder &binder, ClientContext &context, SelectNode &node, idx_t group_index, - case_insensitive_map_t &alias_map, case_insensitive_map_t &group_alias_map) - : ExpressionBinder(binder, context), node(node), alias_map(alias_map), 
group_alias_map(group_alias_map), - group_index(group_index) { -} - -BindResult GroupBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - if (root_expression && depth == 0) { - switch (expr.expression_class) { - case ExpressionClass::COLUMN_REF: - return BindColumnRef(expr.Cast()); - case ExpressionClass::CONSTANT: - return BindConstant(expr.Cast()); - case ExpressionClass::PARAMETER: - throw ParameterNotAllowedException("Parameter not supported in GROUP BY clause"); - default: - break; - } - } - switch (expr.expression_class) { - case ExpressionClass::DEFAULT: - return BindResult("GROUP BY clause cannot contain DEFAULT clause"); - case ExpressionClass::WINDOW: - return BindResult("GROUP BY clause cannot contain window functions!"); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string GroupBinder::UnsupportedAggregateMessage() { - return "GROUP BY clause cannot contain aggregates!"; -} - -BindResult GroupBinder::BindSelectRef(idx_t entry) { - if (used_aliases.find(entry) != used_aliases.end()) { - // the alias has already been bound to before! - // this happens if we group on the same alias twice - // e.g. 
GROUP BY k, k or GROUP BY 1, 1 - // in this case, we can just replace the grouping with a constant since the second grouping has no effect - // (the constant grouping will be optimized out later) - return BindResult(make_uniq(Value::INTEGER(42))); - } - if (entry >= node.select_list.size()) { - throw BinderException("GROUP BY term out of range - should be between 1 and %d", (int)node.select_list.size()); - } - // we replace the root expression, also replace the unbound expression - unbound_expression = node.select_list[entry]->Copy(); - // move the expression that this refers to here and bind it - auto select_entry = std::move(node.select_list[entry]); - auto binding = Bind(select_entry, nullptr, false); - // now replace the original expression in the select list with a reference to this group - group_alias_map[to_string(entry)] = bind_index; - node.select_list[entry] = make_uniq(to_string(entry)); - // insert into the set of used aliases - used_aliases.insert(entry); - return BindResult(std::move(binding)); -} - -BindResult GroupBinder::BindConstant(ConstantExpression &constant) { - // constant as root expression - if (!constant.value.type().IsIntegral()) { - // non-integral expression, we just leave the constant here. - return ExpressionBinder::BindExpression(constant, 0); - } - // INTEGER constant: we use the integer as an index into the select list (e.g. 
GROUP BY 1) - auto index = (idx_t)constant.value.GetValue(); - return BindSelectRef(index - 1); -} - -BindResult GroupBinder::BindColumnRef(ColumnRefExpression &colref) { - // columns in GROUP BY clauses: - // FIRST refer to the original tables, and - // THEN if no match is found refer to aliases in the SELECT list - // THEN if no match is found, refer to outer queries - - // first try to bind to the base columns (original tables) - auto result = ExpressionBinder::BindExpression(colref, 0); - if (result.HasError()) { - if (colref.IsQualified()) { - // explicit table name: not an alias reference - return result; - } - // failed to bind the column and the node is the root expression with depth = 0 - // check if refers to an alias in the select clause - auto alias_name = colref.column_names[0]; - auto entry = alias_map.find(alias_name); - if (entry == alias_map.end()) { - // no matching alias found - return result; - } - result = BindResult(BindSelectRef(entry->second)); - if (!result.HasError()) { - group_alias_map[alias_name] = bind_index; - } - } - return result; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -HavingBinder::HavingBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info, - case_insensitive_map_t &alias_map, AggregateHandling aggregate_handling) - : BaseSelectBinder(binder, context, node, info), column_alias_binder(node, alias_map), - aggregate_handling(aggregate_handling) { - target_type = LogicalType(LogicalTypeId::BOOLEAN); -} - -BindResult HavingBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = expr_ptr->Cast(); - auto alias_result = column_alias_binder.BindAlias(*this, expr, depth, root_expression); - if (!alias_result.HasError()) { - if (depth > 0) { - throw BinderException("Having clause cannot reference alias in correlated subquery"); - } - return alias_result; - } - if (aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { - if 
(depth > 0) { - throw BinderException("Having clause cannot reference column in correlated subquery and group by all"); - } - auto expr = duckdb::BaseSelectBinder::BindExpression(expr_ptr, depth); - if (expr.HasError()) { - return expr; - } - auto group_ref = make_uniq( - expr.expression->return_type, ColumnBinding(node.group_index, node.groups.group_expressions.size())); - node.groups.group_expressions.push_back(std::move(expr.expression)); - return BindResult(std::move(group_ref)); - } - return BindResult(StringUtil::Format( - "column %s must appear in the GROUP BY clause or be used in an aggregate function", expr.ToString())); -} - -BindResult HavingBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - // check if the expression binds to one of the groups - auto group_index = TryBindGroup(expr, depth); - if (group_index != DConstants::INVALID_INDEX) { - return BindGroup(expr, depth, group_index); - } - switch (expr.expression_class) { - case ExpressionClass::WINDOW: - return BindResult("HAVING clause cannot contain window functions!"); - case ExpressionClass::COLUMN_REF: - return BindColumnRef(expr_ptr, depth, root_expression); - default: - return duckdb::BaseSelectBinder::BindExpression(expr_ptr, depth); - } -} - -} // namespace duckdb - - - - - - - -namespace duckdb { - -IndexBinder::IndexBinder(Binder &binder, ClientContext &context, optional_ptr table, - optional_ptr info) - : ExpressionBinder(binder, context), table(table), info(info) { -} - -BindResult IndexBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.expression_class) { - case ExpressionClass::WINDOW: - return BindResult("window functions are not allowed in index expressions"); - case ExpressionClass::SUBQUERY: - return BindResult("cannot use subquery in index expressions"); - case ExpressionClass::COLUMN_REF: { - if (table) { - // WAL replay - // we assume that the parsed 
expressions have qualified column names - // and that the columns exist in the table - auto &col_ref = expr.Cast(); - auto col_idx = table->GetColumnIndex(col_ref.column_names.back()); - auto col_type = table->GetColumn(col_idx).GetType(); - - // find the col_idx in the index.column_ids - auto col_id_idx = DConstants::INVALID_INDEX; - for (idx_t i = 0; i < info->column_ids.size(); i++) { - if (col_idx.index == info->column_ids[i]) { - col_id_idx = i; - } - } - - if (col_id_idx == DConstants::INVALID_INDEX) { - throw InternalException("failed to replay CREATE INDEX statement - column id not found"); - } - return BindResult( - make_uniq(col_ref.GetColumnName(), col_type, ColumnBinding(0, col_id_idx))); - } - return ExpressionBinder::BindExpression(expr_ptr, depth); - } - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string IndexBinder::UnsupportedAggregateMessage() { - return "aggregate functions are not allowed in index expressions"; -} - -} // namespace duckdb - - - - -namespace duckdb { - -InsertBinder::InsertBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { -} - -BindResult InsertBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.GetExpressionClass()) { - case ExpressionClass::DEFAULT: - return BindResult("DEFAULT is not allowed here!"); - case ExpressionClass::WINDOW: - return BindResult("INSERT statement cannot contain window functions!"); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string InsertBinder::UnsupportedAggregateMessage() { - return "INSERT statement cannot contain aggregates!"; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -LateralBinder::LateralBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { -} - -void LateralBinder::ExtractCorrelatedColumns(Expression &expr) { - if (expr.type == ExpressionType::BOUND_COLUMN_REF) { - auto 
&bound_colref = expr.Cast(); - if (bound_colref.depth > 0) { - // add the correlated column info - CorrelatedColumnInfo info(bound_colref); - if (std::find(correlated_columns.begin(), correlated_columns.end(), info) == correlated_columns.end()) { - correlated_columns.push_back(std::move(info)); - } - } - } - ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { ExtractCorrelatedColumns(child); }); -} - -BindResult LateralBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - if (depth == 0) { - throw InternalException("Lateral binder can only bind correlated columns"); - } - auto result = ExpressionBinder::BindExpression(expr_ptr, depth); - if (result.HasError()) { - return result; - } - ExtractCorrelatedColumns(*result.expression); - return result; -} - -BindResult LateralBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.GetExpressionClass()) { - case ExpressionClass::DEFAULT: - return BindResult("LATERAL join cannot contain DEFAULT clause"); - case ExpressionClass::WINDOW: - return BindResult("LATERAL join cannot contain window functions!"); - case ExpressionClass::COLUMN_REF: - return BindColumnRef(expr_ptr, depth, root_expression); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string LateralBinder::UnsupportedAggregateMessage() { - return "LATERAL join cannot contain aggregates!"; -} - -class ExpressionDepthReducer : public LogicalOperatorVisitor { -public: - explicit ExpressionDepthReducer(const vector &correlated) : correlated_columns(correlated) { - } - -protected: - void ReduceColumnRefDepth(BoundColumnRefExpression &expr) { - // don't need to reduce this - if (expr.depth == 0) { - return; - } - for (auto &correlated : correlated_columns) { - if (correlated.binding == expr.binding) { - D_ASSERT(expr.depth > 1); - expr.depth--; - break; - } - } - } - - unique_ptr VisitReplace(BoundColumnRefExpression &expr, 
unique_ptr *expr_ptr) override { - ReduceColumnRefDepth(expr); - return nullptr; - } - - void ReduceExpressionSubquery(BoundSubqueryExpression &expr) { - for (auto &s_correlated : expr.binder->correlated_columns) { - for (auto &correlated : correlated_columns) { - if (correlated == s_correlated) { - s_correlated.depth--; - break; - } - } - } - } - - void ReduceExpressionDepth(Expression &expr) { - if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - ReduceColumnRefDepth(expr.Cast()); - } - if (expr.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY) { - auto &subquery_ref = expr.Cast(); - ReduceExpressionSubquery(expr.Cast()); - // Recursively update the depth in the bindings of the children nodes - ExpressionIterator::EnumerateQueryNodeChildren( - *subquery_ref.subquery, [&](Expression &child_expr) { ReduceExpressionDepth(child_expr); }); - } - } - - unique_ptr VisitReplace(BoundSubqueryExpression &expr, unique_ptr *expr_ptr) override { - ReduceExpressionSubquery(expr); - ExpressionIterator::EnumerateQueryNodeChildren( - *expr.subquery, [&](Expression &child_expr) { ReduceExpressionDepth(child_expr); }); - return nullptr; - } - - const vector &correlated_columns; -}; - -void LateralBinder::ReduceExpressionDepth(LogicalOperator &op, const vector &correlated) { - ExpressionDepthReducer depth_reducer(correlated); - depth_reducer.VisitOperator(op); -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -OrderBinder::OrderBinder(vector binders, idx_t projection_index, case_insensitive_map_t &alias_map, - parsed_expression_map_t &projection_map, idx_t max_count) - : binders(std::move(binders)), projection_index(projection_index), max_count(max_count), extra_list(nullptr), - alias_map(alias_map), projection_map(projection_map) { -} -OrderBinder::OrderBinder(vector binders, idx_t projection_index, SelectNode &node, - case_insensitive_map_t &alias_map, parsed_expression_map_t &projection_map) - : binders(std::move(binders)), 
projection_index(projection_index), alias_map(alias_map), - projection_map(projection_map) { - this->max_count = node.select_list.size(); - this->extra_list = &node.select_list; -} - -unique_ptr OrderBinder::CreateProjectionReference(ParsedExpression &expr, idx_t index) { - string alias; - if (extra_list && index < extra_list->size()) { - alias = extra_list->at(index)->ToString(); - } else { - if (!expr.alias.empty()) { - alias = expr.alias; - } - } - return make_uniq(std::move(alias), LogicalType::INVALID, - ColumnBinding(projection_index, index)); -} - -unique_ptr OrderBinder::CreateExtraReference(unique_ptr expr) { - if (!extra_list) { - throw InternalException("CreateExtraReference called without extra_list"); - } - projection_map[*expr] = extra_list->size(); - auto result = CreateProjectionReference(*expr, extra_list->size()); - extra_list->push_back(std::move(expr)); - return result; -} - -unique_ptr OrderBinder::BindConstant(ParsedExpression &expr, const Value &val) { - // ORDER BY a constant - if (!val.type().IsIntegral()) { - // non-integral expression, we just leave the constant here. - // ORDER BY has no effect - // CONTROVERSIAL: maybe we should throw an error - return nullptr; - } - // INTEGER constant: we use the integer as an index into the select list (e.g. 
ORDER BY 1) - auto index = (idx_t)val.GetValue(); - if (index < 1 || index > max_count) { - throw BinderException("ORDER term out of range - should be between 1 and %lld", (idx_t)max_count); - } - return CreateProjectionReference(expr, index - 1); -} - -unique_ptr OrderBinder::Bind(unique_ptr expr) { - // in the ORDER BY clause we do not bind children - // we bind ONLY to the select list - // if there is no matching entry in the SELECT list already, we add the expression to the SELECT list and refer the - // new expression the new entry will then be bound later during the binding of the SELECT list we also don't do type - // resolution here: this only happens after the SELECT list has been bound - switch (expr->expression_class) { - case ExpressionClass::CONSTANT: { - // ORDER BY constant - // is the ORDER BY expression a constant integer? (e.g. ORDER BY 1) - auto &constant = expr->Cast(); - return BindConstant(*expr, constant.value); - } - case ExpressionClass::COLUMN_REF: { - // COLUMN REF expression - // check if we can bind it to an alias in the select list - auto &colref = expr->Cast(); - // if there is an explicit table name we can't bind to an alias - if (colref.IsQualified()) { - break; - } - // check the alias list - auto entry = alias_map.find(colref.column_names[0]); - if (entry != alias_map.end()) { - // it does! 
point it to that entry - return CreateProjectionReference(*expr, entry->second); - } - break; - } - case ExpressionClass::POSITIONAL_REFERENCE: { - auto &posref = expr->Cast(); - if (posref.index < 1 || posref.index > max_count) { - throw BinderException("ORDER term out of range - should be between 1 and %lld", (idx_t)max_count); - } - return CreateProjectionReference(*expr, posref.index - 1); - } - case ExpressionClass::PARAMETER: { - throw ParameterNotAllowedException("Parameter not supported in ORDER BY clause"); - } - default: - break; - } - // general case - // first bind the table names of this entry - for (auto &binder : binders) { - ExpressionBinder::QualifyColumnNames(*binder, expr); - } - // first check if the ORDER BY clause already points to an entry in the projection list - auto entry = projection_map.find(*expr); - if (entry != projection_map.end()) { - if (entry->second == DConstants::INVALID_INDEX) { - throw BinderException("Ambiguous reference to column"); - } - // there is a matching entry in the projection list - // just point to that entry - return CreateProjectionReference(*expr, entry->second); - } - if (!extra_list) { - // no extra list specified: we cannot push an extra ORDER BY clause - throw BinderException("Could not ORDER BY column \"%s\": add the expression/function to every SELECT, or move " - "the UNION into a FROM clause.", - expr->ToString()); - } - // otherwise we need to push the ORDER BY entry into the select list - return CreateExtraReference(std::move(expr)); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -QualifyBinder::QualifyBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info, - case_insensitive_map_t &alias_map) - : BaseSelectBinder(binder, context, node, info), column_alias_binder(node, alias_map) { - target_type = LogicalType(LogicalTypeId::BOOLEAN); -} - -BindResult QualifyBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - 
auto &expr = expr_ptr->Cast(); - auto result = duckdb::BaseSelectBinder::BindExpression(expr_ptr, depth); - if (!result.HasError()) { - return result; - } - - auto alias_result = column_alias_binder.BindAlias(*this, expr, depth, root_expression); - if (!alias_result.HasError()) { - return alias_result; - } - - return BindResult(StringUtil::Format("Referenced column %s not found in FROM clause and can't find in alias map.", - expr.ToString())); -} - -BindResult QualifyBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - // check if the expression binds to one of the groups - auto group_index = TryBindGroup(expr, depth); - if (group_index != DConstants::INVALID_INDEX) { - return BindGroup(expr, depth, group_index); - } - switch (expr.expression_class) { - case ExpressionClass::WINDOW: - return BindWindow(expr.Cast(), depth); - case ExpressionClass::COLUMN_REF: - return BindColumnRef(expr_ptr, depth, root_expression); - default: - return duckdb::BaseSelectBinder::BindExpression(expr_ptr, depth); - } -} - -} // namespace duckdb - - -namespace duckdb { - -RelationBinder::RelationBinder(Binder &binder, ClientContext &context, string op) - : ExpressionBinder(binder, context), op(std::move(op)) { -} - -BindResult RelationBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.expression_class) { - case ExpressionClass::AGGREGATE: - return BindResult("aggregate functions are not allowed in " + op); - case ExpressionClass::DEFAULT: - return BindResult(op + " cannot contain DEFAULT clause"); - case ExpressionClass::SUBQUERY: - return BindResult("subqueries are not allowed in " + op); - case ExpressionClass::WINDOW: - return BindResult("window functions are not allowed in " + op); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string RelationBinder::UnsupportedAggregateMessage() { - return "aggregate functions are not 
allowed in " + op; -} - -} // namespace duckdb - - - - -namespace duckdb { - -ReturningBinder::ReturningBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { -} - -BindResult ReturningBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.GetExpressionClass()) { - case ExpressionClass::SUBQUERY: - return BindResult("SUBQUERY is not supported in returning statements"); - case ExpressionClass::BOUND_SUBQUERY: - return BindResult("BOUND SUBQUERY is not supported in returning statements"); - case ExpressionClass::COLUMN_REF: - return ExpressionBinder::BindExpression(expr_ptr, depth); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -} // namespace duckdb - - -namespace duckdb { - -SelectBinder::SelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info, - case_insensitive_map_t alias_map) - : BaseSelectBinder(binder, context, node, info, std::move(alias_map)) { -} - -SelectBinder::SelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info) - : SelectBinder(binder, context, node, info, case_insensitive_map_t()) { -} - -} // namespace duckdb - - - - - -namespace duckdb { - -TableFunctionBinder::TableFunctionBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { -} - -BindResult TableFunctionBinder::BindColumnReference(ColumnRefExpression &expr, idx_t depth, bool root_expression) { - - // if this is a lambda parameters, then we temporarily add a BoundLambdaRef, - // which we capture and remove later - if (lambda_bindings) { - auto &colref = expr.Cast(); - for (idx_t i = 0; i < lambda_bindings->size(); i++) { - if (colref.GetColumnName() == (*lambda_bindings)[i].dummy_name) { - return (*lambda_bindings)[i].Bind(colref, i, depth); - } - } - } - auto value_function = ExpressionBinder::GetSQLValueFunction(expr.GetColumnName()); - 
if (value_function) { - return BindExpression(value_function, depth, root_expression); - } - - auto result_name = StringUtil::Join(expr.column_names, "."); - return BindResult(make_uniq(Value(result_name))); -} - -BindResult TableFunctionBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, - bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.GetExpressionClass()) { - case ExpressionClass::COLUMN_REF: - return BindColumnReference(expr.Cast(), depth, root_expression); - case ExpressionClass::SUBQUERY: - throw BinderException("Table function cannot contain subqueries"); - case ExpressionClass::DEFAULT: - return BindResult("Table function cannot contain DEFAULT clause"); - case ExpressionClass::WINDOW: - return BindResult("Table function cannot contain window functions!"); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string TableFunctionBinder::UnsupportedAggregateMessage() { - return "Table function cannot contain aggregates!"; -} - -} // namespace duckdb - - -namespace duckdb { - -UpdateBinder::UpdateBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { -} - -BindResult UpdateBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.expression_class) { - case ExpressionClass::WINDOW: - return BindResult("window functions are not allowed in UPDATE"); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string UpdateBinder::UnsupportedAggregateMessage() { - return "aggregate functions are not allowed in UPDATE"; -} - -} // namespace duckdb - - - - -namespace duckdb { - -WhereBinder::WhereBinder(Binder &binder, ClientContext &context, optional_ptr column_alias_binder) - : ExpressionBinder(binder, context), column_alias_binder(column_alias_binder) { - target_type = LogicalType(LogicalTypeId::BOOLEAN); -} - -BindResult WhereBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool 
root_expression) { - auto &expr = expr_ptr->Cast(); - auto result = ExpressionBinder::BindExpression(expr_ptr, depth); - if (!result.HasError() || !column_alias_binder) { - return result; - } - - BindResult alias_result = column_alias_binder->BindAlias(*this, expr, depth, root_expression); - // This code path cannot be exercised at thispoint. #1547 might change that. - if (!alias_result.HasError()) { - return alias_result; - } - - return result; -} - -BindResult WhereBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.GetExpressionClass()) { - case ExpressionClass::DEFAULT: - return BindResult("WHERE clause cannot contain DEFAULT clause"); - case ExpressionClass::WINDOW: - return BindResult("WHERE clause cannot contain window functions!"); - case ExpressionClass::COLUMN_REF: - return BindColumnRef(expr_ptr, depth, root_expression); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string WhereBinder::UnsupportedAggregateMessage() { - return "WHERE clause cannot contain aggregates!"; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -ExpressionBinder::ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder) - : binder(binder), context(context) { - InitializeStackCheck(); - if (replace_binder) { - stored_binder = &binder.GetActiveBinder(); - binder.SetActiveBinder(*this); - } else { - binder.PushExpressionBinder(*this); - } -} - -ExpressionBinder::~ExpressionBinder() { - if (binder.HasActiveBinder()) { - if (stored_binder) { - binder.SetActiveBinder(*stored_binder); - } else { - binder.PopExpressionBinder(); - } - } -} - -void ExpressionBinder::InitializeStackCheck() { - if (binder.HasActiveBinder()) { - stack_depth = binder.GetActiveBinder().stack_depth; - } else { - stack_depth = 0; - } -} - -StackChecker ExpressionBinder::StackCheck(const ParsedExpression &expr, idx_t extra_stack) { - D_ASSERT(stack_depth != 
DConstants::INVALID_INDEX); - if (stack_depth + extra_stack >= MAXIMUM_STACK_DEPTH) { - throw BinderException("Maximum recursion depth exceeded (Maximum: %llu) while binding \"%s\"", - MAXIMUM_STACK_DEPTH, expr.ToString()); - } - return StackChecker(*this, extra_stack); -} - -BindResult ExpressionBinder::BindExpression(unique_ptr &expr, idx_t depth, bool root_expression) { - auto stack_checker = StackCheck(*expr); - - auto &expr_ref = *expr; - switch (expr_ref.expression_class) { - case ExpressionClass::BETWEEN: - return BindExpression(expr_ref.Cast(), depth); - case ExpressionClass::CASE: - return BindExpression(expr_ref.Cast(), depth); - case ExpressionClass::CAST: - return BindExpression(expr_ref.Cast(), depth); - case ExpressionClass::COLLATE: - return BindExpression(expr_ref.Cast(), depth); - case ExpressionClass::COLUMN_REF: - return BindExpression(expr_ref.Cast(), depth); - case ExpressionClass::COMPARISON: - return BindExpression(expr_ref.Cast(), depth); - case ExpressionClass::CONJUNCTION: - return BindExpression(expr_ref.Cast(), depth); - case ExpressionClass::CONSTANT: - return BindExpression(expr_ref.Cast(), depth); - case ExpressionClass::FUNCTION: { - auto &function = expr_ref.Cast(); - if (function.function_name == "unnest" || function.function_name == "unlist") { - // special case, not in catalog - return BindUnnest(function, depth, root_expression); - } - // binding function expression has extra parameter needed for macro's - return BindExpression(function, depth, expr); - } - case ExpressionClass::LAMBDA: - return BindExpression(expr_ref.Cast(), depth, false, LogicalTypeId::INVALID); - case ExpressionClass::OPERATOR: - return BindExpression(expr_ref.Cast(), depth); - case ExpressionClass::SUBQUERY: - return BindExpression(expr_ref.Cast(), depth); - case ExpressionClass::PARAMETER: - return BindExpression(expr_ref.Cast(), depth); - case ExpressionClass::POSITIONAL_REFERENCE: { - return BindPositionalReference(expr, depth, root_expression); - } - 
case ExpressionClass::STAR: - return BindResult(binder.FormatError(expr_ref, "STAR expression is not supported here")); - default: - throw NotImplementedException("Unimplemented expression class"); - } -} - -bool ExpressionBinder::BindCorrelatedColumns(unique_ptr &expr) { - // try to bind in one of the outer queries, if the binding error occurred in a subquery - auto &active_binders = binder.GetActiveBinders(); - // make a copy of the set of binders, so we can restore it later - auto binders = active_binders; - - // we already failed with the current binder - active_binders.pop_back(); - idx_t depth = 1; - bool success = false; - - while (!active_binders.empty()) { - auto &next_binder = active_binders.back().get(); - ExpressionBinder::QualifyColumnNames(next_binder.binder, expr); - auto bind_result = next_binder.Bind(expr, depth); - if (bind_result.empty()) { - success = true; - break; - } - depth++; - active_binders.pop_back(); - } - active_binders = binders; - return success; -} - -void ExpressionBinder::BindChild(unique_ptr &expr, idx_t depth, string &error) { - if (expr) { - string bind_error = Bind(expr, depth); - if (error.empty()) { - error = bind_error; - } - } -} - -void ExpressionBinder::ExtractCorrelatedExpressions(Binder &binder, Expression &expr) { - if (expr.type == ExpressionType::BOUND_COLUMN_REF) { - auto &bound_colref = expr.Cast(); - if (bound_colref.depth > 0) { - binder.AddCorrelatedColumn(CorrelatedColumnInfo(bound_colref)); - } - } - ExpressionIterator::EnumerateChildren(expr, - [&](Expression &child) { ExtractCorrelatedExpressions(binder, child); }); -} - -bool ExpressionBinder::ContainsType(const LogicalType &type, LogicalTypeId target) { - if (type.id() == target) { - return true; - } - switch (type.id()) { - case LogicalTypeId::STRUCT: { - auto child_count = StructType::GetChildCount(type); - for (idx_t i = 0; i < child_count; i++) { - if (ContainsType(StructType::GetChildType(type, i), target)) { - return true; - } - } - return false; - 
} - case LogicalTypeId::UNION: { - auto member_count = UnionType::GetMemberCount(type); - for (idx_t i = 0; i < member_count; i++) { - if (ContainsType(UnionType::GetMemberType(type, i), target)) { - return true; - } - } - return false; - } - case LogicalTypeId::LIST: - case LogicalTypeId::MAP: - return ContainsType(ListType::GetChildType(type), target); - default: - return false; - } -} - -LogicalType ExpressionBinder::ExchangeType(const LogicalType &type, LogicalTypeId target, LogicalType new_type) { - if (type.id() == target) { - return new_type; - } - switch (type.id()) { - case LogicalTypeId::STRUCT: { - // we make a copy of the child types of the struct here - auto child_types = StructType::GetChildTypes(type); - for (auto &child_type : child_types) { - child_type.second = ExchangeType(child_type.second, target, new_type); - } - return LogicalType::STRUCT(child_types); - } - case LogicalTypeId::UNION: { - auto member_types = UnionType::CopyMemberTypes(type); - for (auto &member_type : member_types) { - member_type.second = ExchangeType(member_type.second, target, new_type); - } - return LogicalType::UNION(std::move(member_types)); - } - case LogicalTypeId::LIST: - return LogicalType::LIST(ExchangeType(ListType::GetChildType(type), target, new_type)); - case LogicalTypeId::MAP: - return LogicalType::MAP(ExchangeType(ListType::GetChildType(type), target, new_type)); - default: - return type; - } -} - -bool ExpressionBinder::ContainsNullType(const LogicalType &type) { - return ContainsType(type, LogicalTypeId::SQLNULL); -} - -LogicalType ExpressionBinder::ExchangeNullType(const LogicalType &type) { - return ExchangeType(type, LogicalTypeId::SQLNULL, LogicalType::INTEGER); -} - -unique_ptr ExpressionBinder::Bind(unique_ptr &expr, optional_ptr result_type, - bool root_expression) { - // bind the main expression - auto error_msg = Bind(expr, 0, root_expression); - if (!error_msg.empty()) { - // failed to bind: try to bind correlated columns in the expression (if 
any) - bool success = BindCorrelatedColumns(expr); - if (!success) { - throw BinderException(error_msg); - } - auto &bound_expr = expr->Cast(); - ExtractCorrelatedExpressions(binder, *bound_expr.expr); - } - auto &bound_expr = expr->Cast(); - unique_ptr result = std::move(bound_expr.expr); - if (target_type.id() != LogicalTypeId::INVALID) { - // the binder has a specific target type: add a cast to that type - result = BoundCastExpression::AddCastToType(context, std::move(result), target_type); - } else { - if (!binder.can_contain_nulls) { - // SQL NULL type is only used internally in the binder - // cast to INTEGER if we encounter it outside of the binder - if (ContainsNullType(result->return_type)) { - auto exchanged_type = ExchangeNullType(result->return_type); - result = BoundCastExpression::AddCastToType(context, std::move(result), exchanged_type); - } - } - if (result->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - } - if (result_type) { - *result_type = result->return_type; - } - return result; -} - -string ExpressionBinder::Bind(unique_ptr &expr, idx_t depth, bool root_expression) { - // bind the node, but only if it has not been bound yet - auto &expression = *expr; - auto alias = expression.alias; - if (expression.GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION) { - // already bound, don't bind it again - return string(); - } - // bind the expression - BindResult result = BindExpression(expr, depth, root_expression); - if (result.HasError()) { - return result.error; - } - // successfully bound: replace the node with a BoundExpression - expr = make_uniq(std::move(result.expression)); - auto &be = expr->Cast(); - be.alias = alias; - if (!alias.empty()) { - be.expr->alias = alias; - } - return string(); -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -void ExpressionIterator::EnumerateChildren(const Expression &expr, - const std::function &callback) { - EnumerateChildren((Expression 
&)expr, [&](unique_ptr &child) { callback(*child); }); -} - -void ExpressionIterator::EnumerateChildren(Expression &expr, const std::function &callback) { - EnumerateChildren(expr, [&](unique_ptr &child) { callback(*child); }); -} - -void ExpressionIterator::EnumerateChildren(Expression &expr, - const std::function &child)> &callback) { - switch (expr.expression_class) { - case ExpressionClass::BOUND_AGGREGATE: { - auto &aggr_expr = expr.Cast(); - for (auto &child : aggr_expr.children) { - callback(child); - } - if (aggr_expr.filter) { - callback(aggr_expr.filter); - } - if (aggr_expr.order_bys) { - for (auto &order : aggr_expr.order_bys->orders) { - callback(order.expression); - } - } - break; - } - case ExpressionClass::BOUND_BETWEEN: { - auto &between_expr = expr.Cast(); - callback(between_expr.input); - callback(between_expr.lower); - callback(between_expr.upper); - break; - } - case ExpressionClass::BOUND_CASE: { - auto &case_expr = expr.Cast(); - for (auto &case_check : case_expr.case_checks) { - callback(case_check.when_expr); - callback(case_check.then_expr); - } - callback(case_expr.else_expr); - break; - } - case ExpressionClass::BOUND_CAST: { - auto &cast_expr = expr.Cast(); - callback(cast_expr.child); - break; - } - case ExpressionClass::BOUND_COMPARISON: { - auto &comp_expr = expr.Cast(); - callback(comp_expr.left); - callback(comp_expr.right); - break; - } - case ExpressionClass::BOUND_CONJUNCTION: { - auto &conj_expr = expr.Cast(); - for (auto &child : conj_expr.children) { - callback(child); - } - break; - } - case ExpressionClass::BOUND_FUNCTION: { - auto &func_expr = expr.Cast(); - for (auto &child : func_expr.children) { - callback(child); - } - break; - } - case ExpressionClass::BOUND_OPERATOR: { - auto &op_expr = expr.Cast(); - for (auto &child : op_expr.children) { - callback(child); - } - break; - } - case ExpressionClass::BOUND_SUBQUERY: { - auto &subquery_expr = expr.Cast(); - if (subquery_expr.child) { - callback(subquery_expr.child); - } 
- break; - } - case ExpressionClass::BOUND_WINDOW: { - auto &window_expr = expr.Cast(); - for (auto &partition : window_expr.partitions) { - callback(partition); - } - for (auto &order : window_expr.orders) { - callback(order.expression); - } - for (auto &child : window_expr.children) { - callback(child); - } - if (window_expr.filter_expr) { - callback(window_expr.filter_expr); - } - if (window_expr.start_expr) { - callback(window_expr.start_expr); - } - if (window_expr.end_expr) { - callback(window_expr.end_expr); - } - if (window_expr.offset_expr) { - callback(window_expr.offset_expr); - } - if (window_expr.default_expr) { - callback(window_expr.default_expr); - } - break; - } - case ExpressionClass::BOUND_UNNEST: { - auto &unnest_expr = expr.Cast(); - callback(unnest_expr.child); - break; - } - case ExpressionClass::BOUND_COLUMN_REF: - case ExpressionClass::BOUND_LAMBDA_REF: - case ExpressionClass::BOUND_CONSTANT: - case ExpressionClass::BOUND_DEFAULT: - case ExpressionClass::BOUND_PARAMETER: - case ExpressionClass::BOUND_REF: - // these node types have no children - break; - default: - throw InternalException("ExpressionIterator used on unbound expression"); - } -} - -void ExpressionIterator::EnumerateExpression(unique_ptr &expr, - const std::function &callback) { - if (!expr) { - return; - } - callback(*expr); - ExpressionIterator::EnumerateChildren(*expr, - [&](unique_ptr &child) { EnumerateExpression(child, callback); }); -} - -void ExpressionIterator::EnumerateTableRefChildren(BoundTableRef &ref, - const std::function &callback) { - switch (ref.type) { - case TableReferenceType::EXPRESSION_LIST: { - auto &bound_expr_list = ref.Cast(); - for (auto &expr_list : bound_expr_list.values) { - for (auto &expr : expr_list) { - EnumerateExpression(expr, callback); - } - } - break; - } - case TableReferenceType::JOIN: { - auto &bound_join = ref.Cast(); - if (bound_join.condition) { - EnumerateExpression(bound_join.condition, callback); - } - 
EnumerateTableRefChildren(*bound_join.left, callback); - EnumerateTableRefChildren(*bound_join.right, callback); - break; - } - case TableReferenceType::SUBQUERY: { - auto &bound_subquery = ref.Cast(); - EnumerateQueryNodeChildren(*bound_subquery.subquery, callback); - break; - } - case TableReferenceType::TABLE_FUNCTION: - case TableReferenceType::EMPTY: - case TableReferenceType::BASE_TABLE: - case TableReferenceType::CTE: - break; - default: - throw NotImplementedException("Unimplemented table reference type in ExpressionIterator"); - } -} - -void ExpressionIterator::EnumerateQueryNodeChildren(BoundQueryNode &node, - const std::function &callback) { - switch (node.type) { - case QueryNodeType::SET_OPERATION_NODE: { - auto &bound_setop = node.Cast(); - EnumerateQueryNodeChildren(*bound_setop.left, callback); - EnumerateQueryNodeChildren(*bound_setop.right, callback); - break; - } - case QueryNodeType::RECURSIVE_CTE_NODE: { - auto &cte_node = node.Cast(); - EnumerateQueryNodeChildren(*cte_node.left, callback); - EnumerateQueryNodeChildren(*cte_node.right, callback); - break; - } - case QueryNodeType::CTE_NODE: { - auto &cte_node = node.Cast(); - EnumerateQueryNodeChildren(*cte_node.child, callback); - break; - } - case QueryNodeType::SELECT_NODE: { - auto &bound_select = node.Cast(); - for (auto &expr : bound_select.select_list) { - EnumerateExpression(expr, callback); - } - EnumerateExpression(bound_select.where_clause, callback); - for (auto &expr : bound_select.groups.group_expressions) { - EnumerateExpression(expr, callback); - } - EnumerateExpression(bound_select.having, callback); - for (auto &expr : bound_select.aggregates) { - EnumerateExpression(expr, callback); - } - for (auto &entry : bound_select.unnests) { - for (auto &expr : entry.second.expressions) { - EnumerateExpression(expr, callback); - } - } - for (auto &expr : bound_select.windows) { - EnumerateExpression(expr, callback); - } - if (bound_select.from_table) { - 
EnumerateTableRefChildren(*bound_select.from_table, callback); - } - break; - } - default: - throw NotImplementedException("Unimplemented query node in ExpressionIterator"); - } - for (idx_t i = 0; i < node.modifiers.size(); i++) { - switch (node.modifiers[i]->type) { - case ResultModifierType::DISTINCT_MODIFIER: - for (auto &expr : node.modifiers[i]->Cast().target_distincts) { - EnumerateExpression(expr, callback); - } - break; - case ResultModifierType::ORDER_MODIFIER: - for (auto &order : node.modifiers[i]->Cast().orders) { - EnumerateExpression(order.expression, callback); - } - break; - default: - break; - } - } -} - -} // namespace duckdb - - -namespace duckdb { - -ConjunctionOrFilter::ConjunctionOrFilter() : ConjunctionFilter(TableFilterType::CONJUNCTION_OR) { -} - -FilterPropagateResult ConjunctionOrFilter::CheckStatistics(BaseStatistics &stats) { - // the OR filter is true if ANY of the children is true - D_ASSERT(!child_filters.empty()); - for (auto &filter : child_filters) { - auto prune_result = filter->CheckStatistics(stats); - if (prune_result == FilterPropagateResult::NO_PRUNING_POSSIBLE) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } else if (prune_result == FilterPropagateResult::FILTER_ALWAYS_TRUE) { - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - } - return FilterPropagateResult::FILTER_ALWAYS_FALSE; -} - -string ConjunctionOrFilter::ToString(const string &column_name) { - string result; - for (idx_t i = 0; i < child_filters.size(); i++) { - if (i > 0) { - result += " OR "; - } - result += child_filters[i]->ToString(column_name); - } - return result; -} - -bool ConjunctionOrFilter::Equals(const TableFilter &other_p) const { - if (!ConjunctionFilter::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (other.child_filters.size() != child_filters.size()) { - return false; - } - for (idx_t i = 0; i < other.child_filters.size(); i++) { - if (!child_filters[i]->Equals(*other.child_filters[i])) { - return 
false; - } - } - return true; -} - -ConjunctionAndFilter::ConjunctionAndFilter() : ConjunctionFilter(TableFilterType::CONJUNCTION_AND) { -} - -FilterPropagateResult ConjunctionAndFilter::CheckStatistics(BaseStatistics &stats) { - // the AND filter is true if ALL of the children is true - D_ASSERT(!child_filters.empty()); - auto result = FilterPropagateResult::FILTER_ALWAYS_TRUE; - for (auto &filter : child_filters) { - auto prune_result = filter->CheckStatistics(stats); - if (prune_result == FilterPropagateResult::FILTER_ALWAYS_FALSE) { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } else if (prune_result != result) { - result = FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - } - return result; -} - -string ConjunctionAndFilter::ToString(const string &column_name) { - string result; - for (idx_t i = 0; i < child_filters.size(); i++) { - if (i > 0) { - result += " AND "; - } - result += child_filters[i]->ToString(column_name); - } - return result; -} - -bool ConjunctionAndFilter::Equals(const TableFilter &other_p) const { - if (!ConjunctionFilter::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (other.child_filters.size() != child_filters.size()) { - return false; - } - for (idx_t i = 0; i < other.child_filters.size(); i++) { - if (!child_filters[i]->Equals(*other.child_filters[i])) { - return false; - } - } - return true; -} - -} // namespace duckdb - - - -namespace duckdb { - -ConstantFilter::ConstantFilter(ExpressionType comparison_type_p, Value constant_p) - : TableFilter(TableFilterType::CONSTANT_COMPARISON), comparison_type(comparison_type_p), - constant(std::move(constant_p)) { -} - -FilterPropagateResult ConstantFilter::CheckStatistics(BaseStatistics &stats) { - D_ASSERT(constant.type().id() == stats.GetType().id()); - switch (constant.type().InternalType()) { - case PhysicalType::UINT8: - case PhysicalType::UINT16: - case PhysicalType::UINT32: - case PhysicalType::UINT64: - case PhysicalType::INT8: - case 
PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - case PhysicalType::INT128: - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - return NumericStats::CheckZonemap(stats, comparison_type, constant); - case PhysicalType::VARCHAR: - return StringStats::CheckZonemap(stats, comparison_type, StringValue::Get(constant)); - default: - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } -} - -string ConstantFilter::ToString(const string &column_name) { - return column_name + ExpressionTypeToOperator(comparison_type) + constant.ToString(); -} - -bool ConstantFilter::Equals(const TableFilter &other_p) const { - if (!TableFilter::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return other.comparison_type == comparison_type && other.constant == constant; -} - -} // namespace duckdb - - - -namespace duckdb { - -IsNullFilter::IsNullFilter() : TableFilter(TableFilterType::IS_NULL) { -} - -FilterPropagateResult IsNullFilter::CheckStatistics(BaseStatistics &stats) { - if (!stats.CanHaveNull()) { - // no null values are possible: always false - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - if (!stats.CanHaveNoNull()) { - // no non-null values are possible: always true - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - return FilterPropagateResult::NO_PRUNING_POSSIBLE; -} - -string IsNullFilter::ToString(const string &column_name) { - return column_name + "IS NULL"; -} - -IsNotNullFilter::IsNotNullFilter() : TableFilter(TableFilterType::IS_NOT_NULL) { -} - -FilterPropagateResult IsNotNullFilter::CheckStatistics(BaseStatistics &stats) { - if (!stats.CanHaveNoNull()) { - // no non-null values are possible: always false - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - if (!stats.CanHaveNull()) { - // no null values are possible: always true - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - return FilterPropagateResult::NO_PRUNING_POSSIBLE; -} - -string IsNotNullFilter::ToString(const 
string &column_name) { - return column_name + " IS NOT NULL"; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -unique_ptr JoinCondition::CreateExpression(JoinCondition cond) { - auto bound_comparison = - make_uniq(cond.comparison, std::move(cond.left), std::move(cond.right)); - return std::move(bound_comparison); -} - -unique_ptr JoinCondition::CreateExpression(vector conditions) { - unique_ptr result; - for (auto &cond : conditions) { - auto expr = CreateExpression(std::move(cond)); - if (!result) { - result = std::move(expr); - } else { - auto conj = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(expr), - std::move(result)); - result = std::move(conj); - } - } - return result; -} - -JoinSide JoinSide::CombineJoinSide(JoinSide left, JoinSide right) { - if (left == JoinSide::NONE) { - return right; - } - if (right == JoinSide::NONE) { - return left; - } - if (left != right) { - return JoinSide::BOTH; - } - return left; -} - -JoinSide JoinSide::GetJoinSide(idx_t table_binding, const unordered_set &left_bindings, - const unordered_set &right_bindings) { - if (left_bindings.find(table_binding) != left_bindings.end()) { - // column references table on left side - D_ASSERT(right_bindings.find(table_binding) == right_bindings.end()); - return JoinSide::LEFT; - } else { - // column references table on right side - D_ASSERT(right_bindings.find(table_binding) != right_bindings.end()); - return JoinSide::RIGHT; - } -} - -JoinSide JoinSide::GetJoinSide(Expression &expression, const unordered_set &left_bindings, - const unordered_set &right_bindings) { - if (expression.type == ExpressionType::BOUND_COLUMN_REF) { - auto &colref = expression.Cast(); - if (colref.depth > 0) { - throw Exception("Non-inner join on correlated columns not supported"); - } - return GetJoinSide(colref.binding.table_index, left_bindings, right_bindings); - } - D_ASSERT(expression.type != ExpressionType::BOUND_REF); - if (expression.type == ExpressionType::SUBQUERY) { - 
D_ASSERT(expression.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY); - auto &subquery = expression.Cast(); - JoinSide side = JoinSide::NONE; - if (subquery.child) { - side = GetJoinSide(*subquery.child, left_bindings, right_bindings); - } - // correlated subquery, check the side of each of correlated columns in the subquery - for (auto &corr : subquery.binder->correlated_columns) { - if (corr.depth > 1) { - // correlated column has depth > 1 - // it does not refer to any table in the current set of bindings - return JoinSide::BOTH; - } - auto correlated_side = GetJoinSide(corr.binding.table_index, left_bindings, right_bindings); - side = CombineJoinSide(side, correlated_side); - } - return side; - } - JoinSide join_side = JoinSide::NONE; - ExpressionIterator::EnumerateChildren(expression, [&](Expression &child) { - auto child_side = GetJoinSide(child, left_bindings, right_bindings); - join_side = CombineJoinSide(child_side, join_side); - }); - return join_side; -} - -JoinSide JoinSide::GetJoinSide(const unordered_set &bindings, const unordered_set &left_bindings, - const unordered_set &right_bindings) { - JoinSide side = JoinSide::NONE; - for (auto binding : bindings) { - side = CombineJoinSide(side, GetJoinSide(binding, left_bindings, right_bindings)); - } - return side; -} - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -LogicalOperator::LogicalOperator(LogicalOperatorType type) - : type(type), estimated_cardinality(0), has_estimated_cardinality(false) { -} - -LogicalOperator::LogicalOperator(LogicalOperatorType type, vector> expressions) - : type(type), expressions(std::move(expressions)), estimated_cardinality(0), has_estimated_cardinality(false) { -} - -LogicalOperator::~LogicalOperator() { -} - -vector LogicalOperator::GetColumnBindings() { - return {ColumnBinding(0, 0)}; -} - -string LogicalOperator::GetName() const { - return LogicalOperatorToString(type); -} - -string LogicalOperator::ParamsToString() const { - string result; 
- for (idx_t i = 0; i < expressions.size(); i++) { - if (i > 0) { - result += "\n"; - } - result += expressions[i]->GetName(); - } - return result; -} - -void LogicalOperator::ResolveOperatorTypes() { - - types.clear(); - // first resolve child types - for (auto &child : children) { - child->ResolveOperatorTypes(); - } - // now resolve the types for this operator - ResolveTypes(); - D_ASSERT(types.size() == GetColumnBindings().size()); -} - -vector LogicalOperator::GenerateColumnBindings(idx_t table_idx, idx_t column_count) { - vector result; - result.reserve(column_count); - for (idx_t i = 0; i < column_count; i++) { - result.emplace_back(table_idx, i); - } - return result; -} - -vector LogicalOperator::MapTypes(const vector &types, const vector &projection_map) { - if (projection_map.empty()) { - return types; - } else { - vector result_types; - result_types.reserve(projection_map.size()); - for (auto index : projection_map) { - result_types.push_back(types[index]); - } - return result_types; - } -} - -vector LogicalOperator::MapBindings(const vector &bindings, - const vector &projection_map) { - if (projection_map.empty()) { - return bindings; - } else { - vector result_bindings; - result_bindings.reserve(projection_map.size()); - for (auto index : projection_map) { - D_ASSERT(index < bindings.size()); - result_bindings.push_back(bindings[index]); - } - return result_bindings; - } -} - -string LogicalOperator::ToString() const { - TreeRenderer renderer; - return renderer.ToString(*this); -} - -void LogicalOperator::Verify(ClientContext &context) { -#ifdef DEBUG - // verify expressions - for (idx_t expr_idx = 0; expr_idx < expressions.size(); expr_idx++) { - auto str = expressions[expr_idx]->ToString(); - // verify that we can (correctly) copy this expression - auto copy = expressions[expr_idx]->Copy(); - auto original_hash = expressions[expr_idx]->Hash(); - auto copy_hash = copy->Hash(); - // copy should be identical to original - 
D_ASSERT(expressions[expr_idx]->ToString() == copy->ToString()); - D_ASSERT(original_hash == copy_hash); - D_ASSERT(Expression::Equals(expressions[expr_idx], copy)); - - for (idx_t other_idx = 0; other_idx < expr_idx; other_idx++) { - // comparison with other expressions - auto other_hash = expressions[other_idx]->Hash(); - bool expr_equal = Expression::Equals(expressions[expr_idx], expressions[other_idx]); - if (original_hash != other_hash) { - // if the hashes are not equal the expressions should not be equal either - D_ASSERT(!expr_equal); - } - } - D_ASSERT(!str.empty()); - - // verify that serialization + deserialization round-trips correctly - if (expressions[expr_idx]->HasParameter()) { - continue; - } - MemoryStream stream; - // We are serializing a query plan - try { - BinarySerializer::Serialize(*expressions[expr_idx], stream); - } catch (NotImplementedException &ex) { - // ignore for now (FIXME) - continue; - } - // Rewind the stream - stream.Rewind(); - - bound_parameter_map_t parameters; - auto deserialized_expression = BinaryDeserializer::Deserialize(stream, context, parameters); - - // FIXME: expressions might not be equal yet because of statistics propagation - continue; - D_ASSERT(Expression::Equals(expressions[expr_idx], deserialized_expression)); - D_ASSERT(expressions[expr_idx]->Hash() == deserialized_expression->Hash()); - } - D_ASSERT(!ToString().empty()); - for (auto &child : children) { - child->Verify(context); - } -#endif -} - -void LogicalOperator::AddChild(unique_ptr child) { - D_ASSERT(child); - children.push_back(std::move(child)); -} - -idx_t LogicalOperator::EstimateCardinality(ClientContext &context) { - // simple estimator, just take the max of the children - if (has_estimated_cardinality) { - return estimated_cardinality; - } - idx_t max_cardinality = 0; - for (auto &child : children) { - max_cardinality = MaxValue(child->EstimateCardinality(context), max_cardinality); - } - has_estimated_cardinality = true; - 
estimated_cardinality = max_cardinality; - return estimated_cardinality; -} - -void LogicalOperator::Print() { - Printer::Print(ToString()); -} - -vector LogicalOperator::GetTableIndex() const { - return vector {}; -} - -unique_ptr LogicalOperator::Copy(ClientContext &context) const { - MemoryStream stream; - BinarySerializer serializer(stream); - try { - serializer.Begin(); - this->Serialize(serializer); - serializer.End(); - } catch (NotImplementedException &ex) { - throw NotImplementedException("Logical Operator Copy requires the logical operator and all of its children to " - "be serializable: " + - std::string(ex.what())); - } - stream.Rewind(); - bound_parameter_map_t parameters; - auto op_copy = BinaryDeserializer::Deserialize(stream, context, parameters); - return op_copy; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -void LogicalOperatorVisitor::VisitOperator(LogicalOperator &op) { - VisitOperatorChildren(op); - VisitOperatorExpressions(op); -} - -void LogicalOperatorVisitor::VisitOperatorChildren(LogicalOperator &op) { - for (auto &child : op.children) { - VisitOperator(*child); - } -} - -void LogicalOperatorVisitor::EnumerateExpressions(LogicalOperator &op, - const std::function *child)> &callback) { - - switch (op.type) { - case LogicalOperatorType::LOGICAL_EXPRESSION_GET: { - auto &get = op.Cast(); - for (auto &expr_list : get.expressions) { - for (auto &expr : expr_list) { - callback(&expr); - } - } - break; - } - case LogicalOperatorType::LOGICAL_ORDER_BY: { - auto &order = op.Cast(); - for (auto &node : order.orders) { - callback(&node.expression); - } - break; - } - case LogicalOperatorType::LOGICAL_TOP_N: { - auto &order = op.Cast(); - for (auto &node : order.orders) { - callback(&node.expression); - } - break; - } - case LogicalOperatorType::LOGICAL_DISTINCT: { - auto &distinct = op.Cast(); - for (auto &target : distinct.distinct_targets) { - callback(&target); - } - if (distinct.order_by) { - for (auto &order : 
distinct.order_by->orders) { - callback(&order.expression); - } - } - break; - } - case LogicalOperatorType::LOGICAL_INSERT: { - auto &insert = op.Cast(); - if (insert.on_conflict_condition) { - callback(&insert.on_conflict_condition); - } - if (insert.do_update_condition) { - callback(&insert.do_update_condition); - } - break; - } - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { - auto &join = op.Cast(); - for (auto &expr : join.duplicate_eliminated_columns) { - callback(&expr); - } - for (auto &cond : join.conditions) { - callback(&cond.left); - callback(&cond.right); - } - break; - } - case LogicalOperatorType::LOGICAL_ANY_JOIN: { - auto &join = op.Cast(); - callback(&join.condition); - break; - } - case LogicalOperatorType::LOGICAL_LIMIT: { - auto &limit = op.Cast(); - if (limit.limit) { - callback(&limit.limit); - } - if (limit.offset) { - callback(&limit.offset); - } - break; - } - case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: { - auto &limit = op.Cast(); - if (limit.limit) { - callback(&limit.limit); - } - if (limit.offset) { - callback(&limit.offset); - } - break; - } - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { - auto &aggr = op.Cast(); - for (auto &group : aggr.groups) { - callback(&group); - } - break; - } - default: - break; - } - for (auto &expression : op.expressions) { - callback(&expression); - } -} - -void LogicalOperatorVisitor::VisitOperatorExpressions(LogicalOperator &op) { - LogicalOperatorVisitor::EnumerateExpressions(op, [&](unique_ptr *child) { VisitExpression(child); }); -} - -void LogicalOperatorVisitor::VisitExpression(unique_ptr *expression) { - auto &expr = **expression; - unique_ptr result; - switch (expr.GetExpressionClass()) { - case ExpressionClass::BOUND_AGGREGATE: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_BETWEEN: 
- result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_CASE: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_CAST: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_COLUMN_REF: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_COMPARISON: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_CONJUNCTION: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_CONSTANT: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_FUNCTION: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_SUBQUERY: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_OPERATOR: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_PARAMETER: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_REF: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_DEFAULT: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_WINDOW: - result = VisitReplace(expr.Cast(), expression); - break; - case ExpressionClass::BOUND_UNNEST: - result = VisitReplace(expr.Cast(), expression); - break; - default: - throw InternalException("Unrecognized expression type in logical operator visitor"); - } - if (result) { - *expression = std::move(result); - } else { - // visit the children of this node - VisitExpressionChildren(expr); - } -} - -void LogicalOperatorVisitor::VisitExpressionChildren(Expression &expr) { - ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &expr) { VisitExpression(&expr); }); -} - -// these are all default methods that can be overriden -// we don't care about coverage here -// LCOV_EXCL_START -unique_ptr 
LogicalOperatorVisitor::VisitReplace(BoundAggregateExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundBetweenExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundCaseExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundCastExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundColumnRefExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundComparisonExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundConjunctionExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundConstantExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundDefaultExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundFunctionExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundOperatorExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundParameterExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundReferenceExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundSubqueryExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundWindowExpression &expr, - unique_ptr *expr_ptr) { - return nullptr; -} - -unique_ptr LogicalOperatorVisitor::VisitReplace(BoundUnnestExpression &expr, - 
unique_ptr *expr_ptr) { - return nullptr; -} - -// LCOV_EXCL_STOP - -} // namespace duckdb - - - - - -namespace duckdb { - -LogicalAggregate::LogicalAggregate(idx_t group_index, idx_t aggregate_index, vector> select_list) - : LogicalOperator(LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY, std::move(select_list)), - group_index(group_index), aggregate_index(aggregate_index), groupings_index(DConstants::INVALID_INDEX) { -} - -void LogicalAggregate::ResolveTypes() { - D_ASSERT(groupings_index != DConstants::INVALID_INDEX || grouping_functions.empty()); - for (auto &expr : groups) { - types.push_back(expr->return_type); - } - // get the chunk types from the projection list - for (auto &expr : expressions) { - types.push_back(expr->return_type); - } - for (idx_t i = 0; i < grouping_functions.size(); i++) { - types.emplace_back(LogicalType::BIGINT); - } -} - -vector LogicalAggregate::GetColumnBindings() { - D_ASSERT(groupings_index != DConstants::INVALID_INDEX || grouping_functions.empty()); - vector result; - result.reserve(groups.size() + expressions.size() + grouping_functions.size()); - for (idx_t i = 0; i < groups.size(); i++) { - result.emplace_back(group_index, i); - } - for (idx_t i = 0; i < expressions.size(); i++) { - result.emplace_back(aggregate_index, i); - } - for (idx_t i = 0; i < grouping_functions.size(); i++) { - result.emplace_back(groupings_index, i); - } - return result; -} - -string LogicalAggregate::ParamsToString() const { - string result; - for (idx_t i = 0; i < groups.size(); i++) { - if (i > 0) { - result += "\n"; - } - result += groups[i]->GetName(); - } - for (idx_t i = 0; i < expressions.size(); i++) { - if (i > 0 || !groups.empty()) { - result += "\n"; - } - result += expressions[i]->GetName(); - } - return result; -} - -idx_t LogicalAggregate::EstimateCardinality(ClientContext &context) { - if (groups.empty()) { - // ungrouped aggregate - return 1; - } - return LogicalOperator::EstimateCardinality(context); -} - -vector 
LogicalAggregate::GetTableIndex() const { - vector result {group_index, aggregate_index}; - if (groupings_index != DConstants::INVALID_INDEX) { - result.push_back(groupings_index); - } - return result; -} - -string LogicalAggregate::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + - StringUtil::Format(" #%llu, #%llu, #%llu", group_index, aggregate_index, groupings_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - -namespace duckdb { - -LogicalAnyJoin::LogicalAnyJoin(JoinType type) : LogicalJoin(type, LogicalOperatorType::LOGICAL_ANY_JOIN) { -} - -string LogicalAnyJoin::ParamsToString() const { - return condition->ToString(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -LogicalColumnDataGet::LogicalColumnDataGet(idx_t table_index, vector types, - unique_ptr collection) - : LogicalOperator(LogicalOperatorType::LOGICAL_CHUNK_GET), table_index(table_index), - collection(std::move(collection)) { - D_ASSERT(types.size() > 0); - chunk_types = std::move(types); -} - -vector LogicalColumnDataGet::GetColumnBindings() { - return GenerateColumnBindings(table_index, chunk_types.size()); -} - -vector LogicalColumnDataGet::GetTableIndex() const { - return vector {table_index}; -} - -string LogicalColumnDataGet::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - - - -namespace duckdb { - -LogicalComparisonJoin::LogicalComparisonJoin(JoinType join_type, LogicalOperatorType logical_type) - : LogicalJoin(join_type, logical_type) { -} - -string LogicalComparisonJoin::ParamsToString() const { - string result = EnumUtil::ToChars(join_type); - for (auto &condition : conditions) { - result += "\n"; - auto expr = - make_uniq(condition.comparison, condition.left->Copy(), 
condition.right->Copy()); - result += expr->ToString(); - } - - return result; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -void LogicalCopyToFile::Serialize(Serializer &serializer) const { - throw SerializationException("LogicalCopyToFile not implemented yet"); -} - -unique_ptr LogicalCopyToFile::Deserialize(Deserializer &deserializer) { - throw SerializationException("LogicalCopyToFile not implemented yet"); -} - -idx_t LogicalCopyToFile::EstimateCardinality(ClientContext &context) { - return 1; -} - -} // namespace duckdb - - -namespace duckdb { - -LogicalCreate::LogicalCreate(LogicalOperatorType type, unique_ptr info, - optional_ptr schema) - : LogicalOperator(type), schema(schema), info(std::move(info)) { -} - -LogicalCreate::LogicalCreate(LogicalOperatorType type, ClientContext &context, unique_ptr info_p) - : LogicalOperator(type), info(std::move(info_p)) { - this->schema = Catalog::GetSchema(context, info->catalog, info->schema, OnEntryNotFound::RETURN_NULL); -} - -idx_t LogicalCreate::EstimateCardinality(ClientContext &context) { - return 1; -} - -void LogicalCreate::ResolveTypes() { - types.emplace_back(LogicalType::BIGINT); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -LogicalCreateIndex::LogicalCreateIndex(unique_ptr info_p, vector> expressions_p, - TableCatalogEntry &table_p) - : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_INDEX), info(std::move(info_p)), table(table_p) { - - for (auto &expr : expressions_p) { - this->unbound_expressions.push_back(expr->Copy()); - } - this->expressions = std::move(expressions_p); - - if (info->column_ids.empty()) { - throw BinderException("CREATE INDEX does not refer to any columns in the base table!"); - } -} - -LogicalCreateIndex::LogicalCreateIndex(ClientContext &context, unique_ptr info_p, - vector> expressions_p) - : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_INDEX), - info(unique_ptr_cast(std::move(info_p))), table(BindTable(context, *info)) { - for (auto &expr : 
expressions_p) { - this->unbound_expressions.push_back(expr->Copy()); - } - this->expressions = std::move(expressions_p); -} - -void LogicalCreateIndex::ResolveTypes() { - types.emplace_back(LogicalType::BIGINT); -} - -TableCatalogEntry &LogicalCreateIndex::BindTable(ClientContext &context, CreateIndexInfo &info) { - auto &catalog = info.catalog; - auto &schema = info.schema; - auto &table_name = info.table; - return Catalog::GetEntry(context, catalog, schema, table_name); -} - -} // namespace duckdb - - -namespace duckdb { - -LogicalCreateTable::LogicalCreateTable(SchemaCatalogEntry &schema, unique_ptr info) - : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_TABLE), schema(schema), info(std::move(info)) { -} - -LogicalCreateTable::LogicalCreateTable(ClientContext &context, unique_ptr unbound_info) - : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_TABLE), - schema(Catalog::GetSchema(context, unbound_info->catalog, unbound_info->schema)) { - D_ASSERT(unbound_info->type == CatalogType::TABLE_ENTRY); - auto binder = Binder::CreateBinder(context); - info = binder->BindCreateTableInfo(unique_ptr_cast(std::move(unbound_info))); -} - -idx_t LogicalCreateTable::EstimateCardinality(ClientContext &context) { - return 1; -} - -void LogicalCreateTable::ResolveTypes() { - types.emplace_back(LogicalType::BIGINT); -} - -} // namespace duckdb - - -namespace duckdb { - -LogicalCrossProduct::LogicalCrossProduct(unique_ptr left, unique_ptr right) - : LogicalUnconditionalJoin(LogicalOperatorType::LOGICAL_CROSS_PRODUCT, std::move(left), std::move(right)) { -} - -unique_ptr LogicalCrossProduct::Create(unique_ptr left, - unique_ptr right) { - if (left->type == LogicalOperatorType::LOGICAL_DUMMY_SCAN) { - return right; - } - if (right->type == LogicalOperatorType::LOGICAL_DUMMY_SCAN) { - return left; - } - return make_uniq(std::move(left), std::move(right)); -} - -} // namespace duckdb - - - - -namespace duckdb { - -vector LogicalCTERef::GetTableIndex() const { - return vector 
{table_index}; -} - -string LogicalCTERef::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -LogicalDelete::LogicalDelete(TableCatalogEntry &table, idx_t table_index) - : LogicalOperator(LogicalOperatorType::LOGICAL_DELETE), table(table), table_index(table_index), - return_chunk(false) { -} - -LogicalDelete::LogicalDelete(ClientContext &context, const unique_ptr &table_info) - : LogicalOperator(LogicalOperatorType::LOGICAL_DELETE), - table(Catalog::GetEntry(context, table_info->catalog, table_info->schema, - dynamic_cast(*table_info).table)) { -} - -idx_t LogicalDelete::EstimateCardinality(ClientContext &context) { - return return_chunk ? LogicalOperator::EstimateCardinality(context) : 1; -} - -vector LogicalDelete::GetTableIndex() const { - return vector {table_index}; -} - -vector LogicalDelete::GetColumnBindings() { - if (return_chunk) { - return GenerateColumnBindings(table_index, table.GetTypes().size()); - } - return {ColumnBinding(0, 0)}; -} - -void LogicalDelete::ResolveTypes() { - if (return_chunk) { - types = table.GetTypes(); - } else { - types.emplace_back(LogicalType::BIGINT); - } -} - -string LogicalDelete::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - - - -namespace duckdb { - -vector LogicalDelimGet::GetTableIndex() const { - return vector {table_index}; -} - -string LogicalDelimGet::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - -namespace duckdb { 
- -LogicalDependentJoin::LogicalDependentJoin(unique_ptr left, unique_ptr right, - vector correlated_columns, JoinType type, - unique_ptr condition) - : LogicalComparisonJoin(type, LogicalOperatorType::LOGICAL_DEPENDENT_JOIN), join_condition(std::move(condition)), - correlated_columns(std::move(correlated_columns)) { - children.push_back(std::move(left)); - children.push_back(std::move(right)); -} - -unique_ptr LogicalDependentJoin::Create(unique_ptr left, - unique_ptr right, - vector correlated_columns, JoinType type, - unique_ptr condition) { - return make_uniq(std::move(left), std::move(right), std::move(correlated_columns), type, - std::move(condition)); -} - -} // namespace duckdb - - - -namespace duckdb { - -LogicalDistinct::LogicalDistinct(DistinctType distinct_type) - : LogicalOperator(LogicalOperatorType::LOGICAL_DISTINCT), distinct_type(distinct_type) { -} -LogicalDistinct::LogicalDistinct(vector> targets, DistinctType distinct_type) - : LogicalOperator(LogicalOperatorType::LOGICAL_DISTINCT), distinct_type(distinct_type), - distinct_targets(std::move(targets)) { -} - -string LogicalDistinct::ParamsToString() const { - string result = LogicalOperator::ParamsToString(); - if (!distinct_targets.empty()) { - result += StringUtil::Join(distinct_targets, distinct_targets.size(), "\n", - [](const unique_ptr &child) { return child->GetName(); }); - } - - return result; -} - -void LogicalDistinct::ResolveTypes() { - types = children[0]->types; -} - -} // namespace duckdb - - - - -namespace duckdb { - -vector LogicalDummyScan::GetTableIndex() const { - return vector {table_index}; -} - -string LogicalDummyScan::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - -namespace duckdb { - -LogicalEmptyResult::LogicalEmptyResult(unique_ptr op) - : 
LogicalOperator(LogicalOperatorType::LOGICAL_EMPTY_RESULT) { - - this->bindings = op->GetColumnBindings(); - - op->ResolveOperatorTypes(); - this->return_types = op->types; -} - -LogicalEmptyResult::LogicalEmptyResult() : LogicalOperator(LogicalOperatorType::LOGICAL_EMPTY_RESULT) { -} - -} // namespace duckdb - - - - -namespace duckdb { - -vector LogicalExpressionGet::GetTableIndex() const { - return vector {table_index}; -} - -string LogicalExpressionGet::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -void LogicalExtensionOperator::ResolveColumnBindings(ColumnBindingResolver &res, vector &bindings) { - // general case - // first visit the children of this operator - for (auto &child : children) { - res.VisitOperator(*child); - } - // now visit the expressions of this operator to resolve any bound column references - for (auto &expression : expressions) { - res.VisitExpression(&expression); - } - // finally update the current set of bindings to the current set of column bindings - bindings = GetColumnBindings(); -} - -void LogicalExtensionOperator::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WriteProperty(200, "extension_name", GetExtensionName()); -} - -unique_ptr LogicalExtensionOperator::Deserialize(Deserializer &deserializer) { - auto &config = DBConfig::GetConfig(deserializer.Get()); - auto extension_name = deserializer.ReadProperty(200, "extension_name"); - for (auto &extension : config.operator_extensions) { - if (extension->GetName() == extension_name) { - return extension->Deserialize(deserializer); - } - } - throw SerializationException("No deserialization method exists for extension: " + extension_name); -} - -string LogicalExtensionOperator::GetExtensionName() const { - throw 
SerializationException("LogicalExtensionOperator::GetExtensionName not implemented which is required for " - "serializing extension operators"); -} - -} // namespace duckdb - - - -namespace duckdb { - -LogicalFilter::LogicalFilter(unique_ptr expression) : LogicalOperator(LogicalOperatorType::LOGICAL_FILTER) { - expressions.push_back(std::move(expression)); - SplitPredicates(expressions); -} - -LogicalFilter::LogicalFilter() : LogicalOperator(LogicalOperatorType::LOGICAL_FILTER) { -} - -void LogicalFilter::ResolveTypes() { - types = MapTypes(children[0]->types, projection_map); -} - -vector LogicalFilter::GetColumnBindings() { - return MapBindings(children[0]->GetColumnBindings(), projection_map); -} - -// Split the predicates separated by AND statements -// These are the predicates that are safe to push down because all of them MUST -// be true -bool LogicalFilter::SplitPredicates(vector> &expressions) { - bool found_conjunction = false; - for (idx_t i = 0; i < expressions.size(); i++) { - if (expressions[i]->type == ExpressionType::CONJUNCTION_AND) { - auto &conjunction = expressions[i]->Cast(); - found_conjunction = true; - // AND expression, append the other children - for (idx_t k = 1; k < conjunction.children.size(); k++) { - expressions.push_back(std::move(conjunction.children[k])); - } - // replace this expression with the first child of the conjunction - expressions[i] = std::move(conjunction.children[0]); - // we move back by one so the right child is checked again - // in case it is an AND expression as well - i--; - } - } - return found_conjunction; -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -LogicalGet::LogicalGet() : LogicalOperator(LogicalOperatorType::LOGICAL_GET) { -} - -LogicalGet::LogicalGet(idx_t table_index, TableFunction function, unique_ptr bind_data, - vector returned_types, vector returned_names) - : LogicalOperator(LogicalOperatorType::LOGICAL_GET), table_index(table_index), function(std::move(function)), - 
bind_data(std::move(bind_data)), returned_types(std::move(returned_types)), names(std::move(returned_names)), - extra_info() { -} - -optional_ptr LogicalGet::GetTable() const { - return TableScanFunction::GetTableEntry(function, bind_data.get()); -} - -string LogicalGet::ParamsToString() const { - string result = ""; - for (auto &kv : table_filters.filters) { - auto &column_index = kv.first; - auto &filter = kv.second; - if (column_index < names.size()) { - result += filter->ToString(names[column_index]); - } - result += "\n"; - } - if (!extra_info.file_filters.empty()) { - result += "\n[INFOSEPARATOR]\n"; - result += "File Filters: " + extra_info.file_filters; - } - if (!function.to_string) { - return result; - } - return result + "\n" + function.to_string(bind_data.get()); -} - -vector LogicalGet::GetColumnBindings() { - if (column_ids.empty()) { - return {ColumnBinding(table_index, 0)}; - } - vector result; - if (projection_ids.empty()) { - for (idx_t col_idx = 0; col_idx < column_ids.size(); col_idx++) { - result.emplace_back(table_index, col_idx); - } - } else { - for (auto proj_id : projection_ids) { - result.emplace_back(table_index, proj_id); - } - } - if (!projected_input.empty()) { - if (children.size() != 1) { - throw InternalException("LogicalGet::project_input can only be set for table-in-out functions"); - } - auto child_bindings = children[0]->GetColumnBindings(); - for (auto entry : projected_input) { - D_ASSERT(entry < child_bindings.size()); - result.emplace_back(child_bindings[entry]); - } - } - return result; -} - -void LogicalGet::ResolveTypes() { - if (column_ids.empty()) { - column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); - } - - if (projection_ids.empty()) { - for (auto &index : column_ids) { - if (index == COLUMN_IDENTIFIER_ROW_ID) { - types.emplace_back(LogicalType::ROW_TYPE); - } else { - types.push_back(returned_types[index]); - } - } - } else { - for (auto &proj_index : projection_ids) { - auto &index = column_ids[proj_index]; - if 
(index == COLUMN_IDENTIFIER_ROW_ID) { - types.emplace_back(LogicalType::ROW_TYPE); - } else { - types.push_back(returned_types[index]); - } - } - } - if (!projected_input.empty()) { - if (children.size() != 1) { - throw InternalException("LogicalGet::project_input can only be set for table-in-out functions"); - } - for (auto entry : projected_input) { - D_ASSERT(entry < children[0]->types.size()); - types.push_back(children[0]->types[entry]); - } - } -} - -idx_t LogicalGet::EstimateCardinality(ClientContext &context) { - // join order optimizer does better cardinality estimation. - if (has_estimated_cardinality) { - return estimated_cardinality; - } - if (function.cardinality) { - auto node_stats = function.cardinality(context, bind_data.get()); - if (node_stats && node_stats->has_estimated_cardinality) { - return node_stats->estimated_cardinality; - } - } - return 1; -} - -void LogicalGet::Serialize(Serializer &serializer) const { - LogicalOperator::Serialize(serializer); - serializer.WriteProperty(200, "table_index", table_index); - serializer.WriteProperty(201, "returned_types", returned_types); - serializer.WriteProperty(202, "names", names); - serializer.WriteProperty(203, "column_ids", column_ids); - serializer.WriteProperty(204, "projection_ids", projection_ids); - serializer.WriteProperty(205, "table_filters", table_filters); - FunctionSerializer::Serialize(serializer, function, bind_data.get()); - if (!function.serialize) { - D_ASSERT(!function.serialize); - // no serialize method: serialize input values and named_parameters for rebinding purposes - serializer.WriteProperty(206, "parameters", parameters); - serializer.WriteProperty(207, "named_parameters", named_parameters); - serializer.WriteProperty(208, "input_table_types", input_table_types); - serializer.WriteProperty(209, "input_table_names", input_table_names); - } - serializer.WriteProperty(210, "projected_input", projected_input); -} - -unique_ptr LogicalGet::Deserialize(Deserializer 
&deserializer) { - auto result = unique_ptr(new LogicalGet()); - deserializer.ReadProperty(200, "table_index", result->table_index); - deserializer.ReadProperty(201, "returned_types", result->returned_types); - deserializer.ReadProperty(202, "names", result->names); - deserializer.ReadProperty(203, "column_ids", result->column_ids); - deserializer.ReadProperty(204, "projection_ids", result->projection_ids); - deserializer.ReadProperty(205, "table_filters", result->table_filters); - auto entry = FunctionSerializer::DeserializeBase( - deserializer, CatalogType::TABLE_FUNCTION_ENTRY); - result->function = entry.first; - auto &function = result->function; - auto has_serialize = entry.second; - - unique_ptr bind_data; - if (!has_serialize) { - deserializer.ReadProperty(206, "parameters", result->parameters); - deserializer.ReadProperty(207, "named_parameters", result->named_parameters); - deserializer.ReadProperty(208, "input_table_types", result->input_table_types); - deserializer.ReadProperty(209, "input_table_names", result->input_table_names); - TableFunctionBindInput input(result->parameters, result->named_parameters, result->input_table_types, - result->input_table_names, function.function_info.get()); - - vector bind_return_types; - vector bind_names; - if (!function.bind) { - throw InternalException("Table function \"%s\" has neither bind nor (de)serialize", function.name); - } - bind_data = function.bind(deserializer.Get(), input, bind_return_types, bind_names); - if (result->returned_types != bind_return_types) { - throw SerializationException( - "Table function deserialization failure - bind returned different return types than were serialized"); - } - // names can actually be different because of aliases - only the sizes cannot be different - if (result->names.size() != bind_names.size()) { - throw SerializationException( - "Table function deserialization failure - bind returned different returned names than were serialized"); - } - } else { - bind_data = 
FunctionSerializer::FunctionDeserialize(deserializer, function); - } - result->bind_data = std::move(bind_data); - deserializer.ReadProperty(210, "projected_input", result->projected_input); - return std::move(result); -} - -vector LogicalGet::GetTableIndex() const { - return vector {table_index}; -} - -string LogicalGet::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return StringUtil::Upper(function.name) + StringUtil::Format(" #%llu", table_index); - } -#endif - return StringUtil::Upper(function.name); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -LogicalInsert::LogicalInsert(TableCatalogEntry &table, idx_t table_index) - : LogicalOperator(LogicalOperatorType::LOGICAL_INSERT), table(table), table_index(table_index), return_chunk(false), - action_type(OnConflictAction::THROW) { -} - -LogicalInsert::LogicalInsert(ClientContext &context, const unique_ptr table_info) - : LogicalOperator(LogicalOperatorType::LOGICAL_INSERT), - table(Catalog::GetEntry(context, table_info->catalog, table_info->schema, - dynamic_cast(*table_info).table)) { -} - -idx_t LogicalInsert::EstimateCardinality(ClientContext &context) { - return return_chunk ? 
LogicalOperator::EstimateCardinality(context) : 1; -} - -vector LogicalInsert::GetTableIndex() const { - return vector {table_index}; -} - -vector LogicalInsert::GetColumnBindings() { - if (return_chunk) { - return GenerateColumnBindings(table_index, table.GetTypes().size()); - } - return {ColumnBinding(0, 0)}; -} - -void LogicalInsert::ResolveTypes() { - if (return_chunk) { - types = table.GetTypes(); - } else { - types.emplace_back(LogicalType::BIGINT); - } -} - -string LogicalInsert::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -LogicalJoin::LogicalJoin(JoinType join_type, LogicalOperatorType logical_type) - : LogicalOperator(logical_type), join_type(join_type) { -} - -vector LogicalJoin::GetColumnBindings() { - auto left_bindings = MapBindings(children[0]->GetColumnBindings(), left_projection_map); - if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { - // for SEMI and ANTI join we only project the left hand side - return left_bindings; - } - if (join_type == JoinType::MARK) { - // for MARK join we project the left hand side plus the MARK column - left_bindings.emplace_back(mark_index, 0); - return left_bindings; - } - // for other join types we project both the LHS and the RHS - auto right_bindings = MapBindings(children[1]->GetColumnBindings(), right_projection_map); - left_bindings.insert(left_bindings.end(), right_bindings.begin(), right_bindings.end()); - return left_bindings; -} - -void LogicalJoin::ResolveTypes() { - types = MapTypes(children[0]->types, left_projection_map); - if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { - // for SEMI and ANTI join we only project the left hand side - return; - } - if (join_type == JoinType::MARK) { - // for MARK join we project the left hand side, plus a BOOLEAN 
column indicating the MARK - types.emplace_back(LogicalType::BOOLEAN); - return; - } - // for any other join we project both sides - auto right_types = MapTypes(children[1]->types, right_projection_map); - types.insert(types.end(), right_types.begin(), right_types.end()); -} - -void LogicalJoin::GetTableReferences(LogicalOperator &op, unordered_set &bindings) { - auto column_bindings = op.GetColumnBindings(); - for (auto binding : column_bindings) { - bindings.insert(binding.table_index); - } -} - -void LogicalJoin::GetExpressionBindings(Expression &expr, unordered_set &bindings) { - if (expr.type == ExpressionType::BOUND_COLUMN_REF) { - auto &colref = expr.Cast(); - D_ASSERT(colref.depth == 0); - bindings.insert(colref.binding.table_index); - } - ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { GetExpressionBindings(child, bindings); }); -} - -} // namespace duckdb - - -namespace duckdb { - -LogicalLimit::LogicalLimit(int64_t limit_val, int64_t offset_val, unique_ptr limit, - unique_ptr offset) - : LogicalOperator(LogicalOperatorType::LOGICAL_LIMIT), limit_val(limit_val), offset_val(offset_val), - limit(std::move(limit)), offset(std::move(offset)) { -} - -vector LogicalLimit::GetColumnBindings() { - return children[0]->GetColumnBindings(); -} - -idx_t LogicalLimit::EstimateCardinality(ClientContext &context) { - auto child_cardinality = children[0]->EstimateCardinality(context); - if (limit_val >= 0 && idx_t(limit_val) < child_cardinality) { - child_cardinality = limit_val; - } - return child_cardinality; -} - -void LogicalLimit::ResolveTypes() { - types = children[0]->types; -} - -} // namespace duckdb - -#include - -namespace duckdb { - -idx_t LogicalLimitPercent::EstimateCardinality(ClientContext &context) { - auto child_cardinality = LogicalOperator::EstimateCardinality(context); - if ((limit_percent < 0 || limit_percent > 100) || std::isnan(limit_percent)) { - return child_cardinality; - } - return idx_t(child_cardinality * (limit_percent 
/ 100.0)); -} - -} // namespace duckdb - - -namespace duckdb { - -vector LogicalMaterializedCTE::GetTableIndex() const { - return vector {table_index}; -} - -} // namespace duckdb - - -namespace duckdb { - -LogicalOrder::LogicalOrder(vector orders) - : LogicalOperator(LogicalOperatorType::LOGICAL_ORDER_BY), orders(std::move(orders)) { -} - -vector LogicalOrder::GetColumnBindings() { - auto child_bindings = children[0]->GetColumnBindings(); - if (projections.empty()) { - return child_bindings; - } - - vector result; - for (auto &col_idx : projections) { - result.push_back(child_bindings[col_idx]); - } - return result; -} - -string LogicalOrder::ParamsToString() const { - string result = "ORDERS:\n"; - for (idx_t i = 0; i < orders.size(); i++) { - if (i > 0) { - result += "\n"; - } - result += orders[i].expression->GetName(); - } - return result; -} - -void LogicalOrder::ResolveTypes() { - const auto child_types = children[0]->types; - if (projections.empty()) { - types = child_types; - } else { - for (auto &col_idx : projections) { - types.push_back(child_types[col_idx]); - } - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -LogicalPivot::LogicalPivot() : LogicalOperator(LogicalOperatorType::LOGICAL_PIVOT) { -} - -LogicalPivot::LogicalPivot(idx_t pivot_idx, unique_ptr plan, BoundPivotInfo info_p) - : LogicalOperator(LogicalOperatorType::LOGICAL_PIVOT), pivot_index(pivot_idx), bound_pivot(std::move(info_p)) { - D_ASSERT(plan); - children.push_back(std::move(plan)); -} - -vector LogicalPivot::GetColumnBindings() { - vector result; - for (idx_t i = 0; i < bound_pivot.types.size(); i++) { - result.emplace_back(pivot_index, i); - } - return result; -} - -vector LogicalPivot::GetTableIndex() const { - return vector {pivot_index}; -} - -void LogicalPivot::ResolveTypes() { - this->types = bound_pivot.types; -} - -string LogicalPivot::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + 
StringUtil::Format(" #%llu", pivot_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - -namespace duckdb { - -LogicalPositionalJoin::LogicalPositionalJoin(unique_ptr left, unique_ptr right) - : LogicalUnconditionalJoin(LogicalOperatorType::LOGICAL_POSITIONAL_JOIN, std::move(left), std::move(right)) { -} - -unique_ptr LogicalPositionalJoin::Create(unique_ptr left, - unique_ptr right) { - if (left->type == LogicalOperatorType::LOGICAL_DUMMY_SCAN) { - return right; - } - if (right->type == LogicalOperatorType::LOGICAL_DUMMY_SCAN) { - return left; - } - return make_uniq(std::move(left), std::move(right)); -} - -} // namespace duckdb - - -namespace duckdb { - -idx_t LogicalPragma::EstimateCardinality(ClientContext &context) { - return 1; -} - -} // namespace duckdb - - -namespace duckdb { - -idx_t LogicalPrepare::EstimateCardinality(ClientContext &context) { - return 1; -} - -} // namespace duckdb - - - - -namespace duckdb { - -LogicalProjection::LogicalProjection(idx_t table_index, vector> select_list) - : LogicalOperator(LogicalOperatorType::LOGICAL_PROJECTION, std::move(select_list)), table_index(table_index) { -} - -vector LogicalProjection::GetColumnBindings() { - return GenerateColumnBindings(table_index, expressions.size()); -} - -void LogicalProjection::ResolveTypes() { - for (auto &expr : expressions) { - types.push_back(expr->return_type); - } -} - -vector LogicalProjection::GetTableIndex() const { - return vector {table_index}; -} - -string LogicalProjection::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - - - -namespace duckdb { - -vector LogicalRecursiveCTE::GetTableIndex() const { - return vector {table_index}; -} - -string LogicalRecursiveCTE::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return 
LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - -namespace duckdb { - -idx_t LogicalReset::EstimateCardinality(ClientContext &context) { - return 1; -} - -} // namespace duckdb - - -namespace duckdb { - -LogicalSample::LogicalSample() : LogicalOperator(LogicalOperatorType::LOGICAL_SAMPLE) { -} - -LogicalSample::LogicalSample(unique_ptr sample_options_p, unique_ptr child) - : LogicalOperator(LogicalOperatorType::LOGICAL_SAMPLE), sample_options(std::move(sample_options_p)) { - children.push_back(std::move(child)); -} - -vector LogicalSample::GetColumnBindings() { - return children[0]->GetColumnBindings(); -} - -idx_t LogicalSample::EstimateCardinality(ClientContext &context) { - auto child_cardinality = children[0]->EstimateCardinality(context); - if (sample_options->is_percentage) { - double sample_cardinality = - double(child_cardinality) * (sample_options->sample_size.GetValue() / 100.0); - if (sample_cardinality > double(child_cardinality)) { - return child_cardinality; - } - return idx_t(sample_cardinality); - } else { - auto sample_size = sample_options->sample_size.GetValue(); - if (sample_size < child_cardinality) { - return sample_size; - } - } - return child_cardinality; -} - -void LogicalSample::ResolveTypes() { - types = children[0]->types; -} - -} // namespace duckdb - - -namespace duckdb { - -idx_t LogicalSet::EstimateCardinality(ClientContext &context) { - return 1; -} - -} // namespace duckdb - - - - -namespace duckdb { - -vector LogicalSetOperation::GetTableIndex() const { - return vector {table_index}; -} - -string LogicalSetOperation::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -idx_t 
LogicalSimple::EstimateCardinality(ClientContext &context) { - return 1; -} - -} // namespace duckdb - - -namespace duckdb { - -idx_t LogicalTopN::EstimateCardinality(ClientContext &context) { - auto child_cardinality = LogicalOperator::EstimateCardinality(context); - if (limit >= 0 && child_cardinality < idx_t(limit)) { - return limit; - } - return child_cardinality; -} - -} // namespace duckdb - - -namespace duckdb { - -LogicalUnconditionalJoin::LogicalUnconditionalJoin(LogicalOperatorType logical_type, unique_ptr left, - unique_ptr right) - : LogicalOperator(logical_type) { - D_ASSERT(left); - D_ASSERT(right); - children.push_back(std::move(left)); - children.push_back(std::move(right)); -} - -vector LogicalUnconditionalJoin::GetColumnBindings() { - auto left_bindings = children[0]->GetColumnBindings(); - auto right_bindings = children[1]->GetColumnBindings(); - left_bindings.insert(left_bindings.end(), right_bindings.begin(), right_bindings.end()); - return left_bindings; -} - -void LogicalUnconditionalJoin::ResolveTypes() { - types.insert(types.end(), children[0]->types.begin(), children[0]->types.end()); - types.insert(types.end(), children[1]->types.begin(), children[1]->types.end()); -} - -} // namespace duckdb - - - - -namespace duckdb { - -vector LogicalUnnest::GetColumnBindings() { - auto child_bindings = children[0]->GetColumnBindings(); - for (idx_t i = 0; i < expressions.size(); i++) { - child_bindings.emplace_back(unnest_index, i); - } - return child_bindings; -} - -void LogicalUnnest::ResolveTypes() { - types.insert(types.end(), children[0]->types.begin(), children[0]->types.end()); - for (auto &expr : expressions) { - types.push_back(expr->return_type); - } -} - -vector LogicalUnnest::GetTableIndex() const { - return vector {unnest_index}; -} - -string LogicalUnnest::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + StringUtil::Format(" #%llu", unnest_index); - } -#endif - return 
LogicalOperator::GetName(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -LogicalUpdate::LogicalUpdate(TableCatalogEntry &table) - : LogicalOperator(LogicalOperatorType::LOGICAL_UPDATE), table(table), table_index(0), return_chunk(false) { -} - -LogicalUpdate::LogicalUpdate(ClientContext &context, const unique_ptr &table_info) - : LogicalOperator(LogicalOperatorType::LOGICAL_UPDATE), - table(Catalog::GetEntry(context, table_info->catalog, table_info->schema, - dynamic_cast(*table_info).table)) { -} - -idx_t LogicalUpdate::EstimateCardinality(ClientContext &context) { - return return_chunk ? LogicalOperator::EstimateCardinality(context) : 1; -} - -vector LogicalUpdate::GetColumnBindings() { - if (return_chunk) { - return GenerateColumnBindings(table_index, table.GetTypes().size()); - } - return {ColumnBinding(0, 0)}; -} - -void LogicalUpdate::ResolveTypes() { - if (return_chunk) { - types = table.GetTypes(); - } else { - types.emplace_back(LogicalType::BIGINT); - } -} - -string LogicalUpdate::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); - } -#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - - - -namespace duckdb { - -vector LogicalWindow::GetColumnBindings() { - auto child_bindings = children[0]->GetColumnBindings(); - for (idx_t i = 0; i < expressions.size(); i++) { - child_bindings.emplace_back(window_index, i); - } - return child_bindings; -} - -void LogicalWindow::ResolveTypes() { - types.insert(types.end(), children[0]->types.begin(), children[0]->types.end()); - for (auto &expr : expressions) { - types.push_back(expr->return_type); - } -} - -vector LogicalWindow::GetTableIndex() const { - return vector {window_index}; -} - -string LogicalWindow::GetName() const { -#ifdef DEBUG - if (DBConfigOptions::debug_print_bindings) { - return LogicalOperator::GetName() + StringUtil::Format(" #%llu", window_index); - } 
-#endif - return LogicalOperator::GetName(); -} - -} // namespace duckdb - - - - - - - - - - - - - - - -namespace duckdb { - -Planner::Planner(ClientContext &context) : binder(Binder::CreateBinder(context)), context(context) { -} - -static void CheckTreeDepth(const LogicalOperator &op, idx_t max_depth, idx_t depth = 0) { - if (depth >= max_depth) { - throw ParserException("Maximum tree depth of %lld exceeded in logical planner", max_depth); - } - for (auto &child : op.children) { - CheckTreeDepth(*child, max_depth, depth + 1); - } -} - -void Planner::CreatePlan(SQLStatement &statement) { - auto &profiler = QueryProfiler::Get(context); - auto parameter_count = statement.n_param; - - BoundParameterMap bound_parameters(parameter_data); - - // first bind the tables and columns to the catalog - bool parameters_resolved = true; - try { - profiler.StartPhase("binder"); - binder->parameters = &bound_parameters; - auto bound_statement = binder->Bind(statement); - profiler.EndPhase(); - - this->names = bound_statement.names; - this->types = bound_statement.types; - this->plan = std::move(bound_statement.plan); - - auto max_tree_depth = ClientConfig::GetConfig(context).max_expression_depth; - CheckTreeDepth(*plan, max_tree_depth); - } catch (const ParameterNotResolvedException &ex) { - // parameter types could not be resolved - this->names = {"unknown"}; - this->types = {LogicalTypeId::UNKNOWN}; - this->plan = nullptr; - parameters_resolved = false; - } catch (const Exception &ex) { - auto &config = DBConfig::GetConfig(context); - - this->plan = nullptr; - for (auto &extension_op : config.operator_extensions) { - auto bound_statement = - extension_op->Bind(context, *this->binder, extension_op->operator_info.get(), statement); - if (bound_statement.plan != nullptr) { - this->names = bound_statement.names; - this->types = bound_statement.types; - this->plan = std::move(bound_statement.plan); - break; - } - } - - if (!this->plan) { - throw; - } - } catch (std::exception &ex) { - 
throw; - } - this->properties = binder->properties; - this->properties.parameter_count = parameter_count; - properties.bound_all_parameters = parameters_resolved; - - Planner::VerifyPlan(context, plan, bound_parameters.GetParametersPtr()); - - // set up a map of parameter number -> value entries - for (auto &kv : bound_parameters.GetParameters()) { - auto &identifier = kv.first; - auto ¶m = kv.second; - // check if the type of the parameter could be resolved - if (!param->return_type.IsValid()) { - properties.bound_all_parameters = false; - continue; - } - param->SetValue(Value(param->return_type)); - value_map[identifier] = param; - } -} - -shared_ptr Planner::PrepareSQLStatement(unique_ptr statement) { - auto copied_statement = statement->Copy(); - // create a plan of the underlying statement - CreatePlan(std::move(statement)); - // now create the logical prepare - auto prepared_data = make_shared(copied_statement->type); - prepared_data->unbound_statement = std::move(copied_statement); - prepared_data->names = names; - prepared_data->types = types; - prepared_data->value_map = std::move(value_map); - prepared_data->properties = properties; - prepared_data->catalog_version = MetaTransaction::Get(context).catalog_version; - return prepared_data; -} - -void Planner::CreatePlan(unique_ptr statement) { - D_ASSERT(statement); - switch (statement->type) { - case StatementType::SELECT_STATEMENT: - case StatementType::INSERT_STATEMENT: - case StatementType::COPY_STATEMENT: - case StatementType::DELETE_STATEMENT: - case StatementType::UPDATE_STATEMENT: - case StatementType::CREATE_STATEMENT: - case StatementType::DROP_STATEMENT: - case StatementType::ALTER_STATEMENT: - case StatementType::TRANSACTION_STATEMENT: - case StatementType::EXPLAIN_STATEMENT: - case StatementType::VACUUM_STATEMENT: - case StatementType::RELATION_STATEMENT: - case StatementType::CALL_STATEMENT: - case StatementType::EXPORT_STATEMENT: - case StatementType::PRAGMA_STATEMENT: - case 
StatementType::SHOW_STATEMENT: - case StatementType::SET_STATEMENT: - case StatementType::LOAD_STATEMENT: - case StatementType::EXTENSION_STATEMENT: - case StatementType::PREPARE_STATEMENT: - case StatementType::EXECUTE_STATEMENT: - case StatementType::LOGICAL_PLAN_STATEMENT: - case StatementType::ATTACH_STATEMENT: - case StatementType::DETACH_STATEMENT: - CreatePlan(*statement); - break; - default: - throw NotImplementedException("Cannot plan statement of type %s!", StatementTypeToString(statement->type)); - } -} - -static bool OperatorSupportsSerialization(LogicalOperator &op) { - for (auto &child : op.children) { - if (!OperatorSupportsSerialization(*child)) { - return false; - } - } - return op.SupportSerialization(); -} - -void Planner::VerifyPlan(ClientContext &context, unique_ptr &op, - optional_ptr map) { -#ifdef DUCKDB_ALTERNATIVE_VERIFY - // if alternate verification is enabled we run the original operator - return; -#endif - if (!op || !ClientConfig::GetConfig(context).verify_serializer) { - return; - } - //! 
SELECT only for now - if (!OperatorSupportsSerialization(*op)) { - return; - } - - // format (de)serialization of this operator - try { - MemoryStream stream; - BinarySerializer::Serialize(*op, stream, true); - stream.Rewind(); - bound_parameter_map_t parameters; - auto new_plan = BinaryDeserializer::Deserialize(stream, context, parameters); - - if (map) { - *map = std::move(parameters); - } - op = std::move(new_plan); - } catch (SerializationException &ex) { - // pass - } catch (NotImplementedException &ex) { - // pass - } -} - -} // namespace duckdb - - - - - - - - - - - - - - - - -namespace duckdb { - -PragmaHandler::PragmaHandler(ClientContext &context) : context(context) { -} - -void PragmaHandler::HandlePragmaStatementsInternal(vector> &statements) { - vector> new_statements; - for (idx_t i = 0; i < statements.size(); i++) { - if (statements[i]->type == StatementType::MULTI_STATEMENT) { - auto &multi_statement = statements[i]->Cast(); - for (auto &stmt : multi_statement.statements) { - statements.push_back(std::move(stmt)); - } - continue; - } - if (statements[i]->type == StatementType::PRAGMA_STATEMENT) { - // PRAGMA statement: check if we need to replace it by a new set of statements - PragmaHandler handler(context); - string new_query; - bool expanded = handler.HandlePragma(statements[i].get(), new_query); - if (expanded) { - // this PRAGMA statement gets replaced by a new query string - // push the new query string through the parser again and add it to the transformer - Parser parser(context.GetParserOptions()); - parser.ParseQuery(new_query); - // insert the new statements and remove the old statement - for (idx_t j = 0; j < parser.statements.size(); j++) { - new_statements.push_back(std::move(parser.statements[j])); - } - continue; - } - } - new_statements.push_back(std::move(statements[i])); - } - statements = std::move(new_statements); -} - -void PragmaHandler::HandlePragmaStatements(ClientContextLock &lock, vector> &statements) { - // first check if 
there are any pragma statements - bool found_pragma = false; - for (idx_t i = 0; i < statements.size(); i++) { - if (statements[i]->type == StatementType::PRAGMA_STATEMENT || - statements[i]->type == StatementType::MULTI_STATEMENT) { - found_pragma = true; - break; - } - } - if (!found_pragma) { - // no pragmas: skip this step - return; - } - context.RunFunctionInTransactionInternal(lock, [&]() { HandlePragmaStatementsInternal(statements); }); -} - -bool PragmaHandler::HandlePragma(SQLStatement *statement, string &resulting_query) { // PragmaInfo &info - auto info = *(statement->Cast()).info; - auto &entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name); - string error; - - FunctionBinder function_binder(context); - idx_t bound_idx = function_binder.BindFunction(entry.name, entry.functions, info, error); - if (bound_idx == DConstants::INVALID_INDEX) { - throw BinderException(error); - } - auto bound_function = entry.functions.GetFunctionByOffset(bound_idx); - if (bound_function.query) { - QueryErrorContext error_context(statement, statement->stmt_location); - Binder::BindNamedParameters(bound_function.named_parameters, info.named_parameters, error_context, - bound_function.name); - FunctionParameters parameters {info.parameters, info.named_parameters}; - resulting_query = bound_function.query(context, parameters); - return true; - } - return false; -} - -} // namespace duckdb - - - - - - - - - - - - - - -namespace duckdb { - -FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const vector &correlated, - bool perform_delim, bool any_join) - : binder(binder), delim_offset(DConstants::INVALID_INDEX), correlated_columns(correlated), - perform_delim(perform_delim), any_join(any_join) { - for (idx_t i = 0; i < correlated_columns.size(); i++) { - auto &col = correlated_columns[i]; - correlated_map[col.binding] = i; - delim_types.push_back(col.type); - } -} - -bool FlattenDependentJoins::DetectCorrelatedExpressions(LogicalOperator *op, 
bool lateral, idx_t lateral_depth) { - - bool is_lateral_join = false; - - D_ASSERT(op); - // check if this entry has correlated expressions - if (op->type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { - is_lateral_join = true; - } - HasCorrelatedExpressions visitor(correlated_columns, lateral, lateral_depth); - visitor.VisitOperator(*op); - bool has_correlation = visitor.has_correlated_expressions; - int child_idx = 0; - // now visit the children of this entry and check if they have correlated expressions - for (auto &child : op->children) { - auto new_lateral_depth = lateral_depth; - if (is_lateral_join && child_idx == 1) { - new_lateral_depth = lateral_depth + 1; - } - // we OR the property with its children such that has_correlation is true if either - // (1) this node has a correlated expression or - // (2) one of its children has a correlated expression - if (DetectCorrelatedExpressions(child.get(), lateral, new_lateral_depth)) { - has_correlation = true; - } - child_idx++; - } - // set the entry in the map - has_correlated_expressions[op] = has_correlation; - return has_correlation; -} - -unique_ptr FlattenDependentJoins::PushDownDependentJoin(unique_ptr plan) { - bool propagate_null_values = true; - auto result = PushDownDependentJoinInternal(std::move(plan), propagate_null_values, 0); - if (!replacement_map.empty()) { - // check if we have to replace any COUNT aggregates into "CASE WHEN X IS NULL THEN 0 ELSE COUNT END" - RewriteCountAggregates aggr(replacement_map); - aggr.VisitOperator(*result); - } - return result; -} - -bool SubqueryDependentFilter(Expression *expr) { - if (expr->expression_class == ExpressionClass::BOUND_CONJUNCTION && - expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND) { - auto &bound_conjuction = expr->Cast(); - for (auto &child : bound_conjuction.children) { - if (SubqueryDependentFilter(child.get())) { - return true; - } - } - } - if (expr->expression_class == ExpressionClass::BOUND_SUBQUERY) { - return true; - } - 
return false; -} - -unique_ptr FlattenDependentJoins::PushDownDependentJoinInternal(unique_ptr plan, - bool &parent_propagate_null_values, - idx_t lateral_depth) { - // first check if the logical operator has correlated expressions - auto entry = has_correlated_expressions.find(plan.get()); - D_ASSERT(entry != has_correlated_expressions.end()); - if (!entry->second) { - // we reached a node without correlated expressions - // we can eliminate the dependent join now and create a simple cross product - // now create the duplicate eliminated scan for this node - auto left_columns = plan->GetColumnBindings().size(); - auto delim_index = binder.GenerateTableIndex(); - this->base_binding = ColumnBinding(delim_index, 0); - this->delim_offset = left_columns; - this->data_offset = 0; - auto delim_scan = make_uniq(delim_index, delim_types); - return LogicalCrossProduct::Create(std::move(plan), std::move(delim_scan)); - } - switch (plan->type) { - case LogicalOperatorType::LOGICAL_UNNEST: - case LogicalOperatorType::LOGICAL_FILTER: { - // filter - // first we flatten the dependent join in the child of the filter - for (auto &expr : plan->expressions) { - any_join |= SubqueryDependentFilter(expr.get()); - } - plan->children[0] = - PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); - - // then we replace any correlated expressions with the corresponding entry in the correlated_map - RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); - rewriter.VisitOperator(*plan); - return plan; - } - case LogicalOperatorType::LOGICAL_PROJECTION: { - // projection - // first we flatten the dependent join in the child of the projection - for (auto &expr : plan->expressions) { - parent_propagate_null_values &= expr->PropagatesNullValues(); - } - plan->children[0] = - PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); - - // then we replace any correlated 
expressions with the corresponding entry in the correlated_map - RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); - rewriter.VisitOperator(*plan); - // now we add all the columns of the delim_scan to the projection list - auto &proj = plan->Cast(); - for (idx_t i = 0; i < correlated_columns.size(); i++) { - auto &col = correlated_columns[i]; - auto colref = make_uniq( - col.name, col.type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); - plan->expressions.push_back(std::move(colref)); - } - - base_binding.table_index = proj.table_index; - this->delim_offset = base_binding.column_index = plan->expressions.size() - correlated_columns.size(); - this->data_offset = 0; - return plan; - } - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { - auto &aggr = plan->Cast(); - // aggregate and group by - // first we flatten the dependent join in the child of the projection - for (auto &expr : plan->expressions) { - parent_propagate_null_values &= expr->PropagatesNullValues(); - } - plan->children[0] = - PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); - // then we replace any correlated expressions with the corresponding entry in the correlated_map - RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); - rewriter.VisitOperator(*plan); - // now we add all the columns of the delim_scan to the grouping operators AND the projection list - idx_t delim_table_index; - idx_t delim_column_offset; - idx_t delim_data_offset; - auto new_group_count = perform_delim ? 
correlated_columns.size() : 1; - for (idx_t i = 0; i < new_group_count; i++) { - auto &col = correlated_columns[i]; - auto colref = make_uniq( - col.name, col.type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); - for (auto &set : aggr.grouping_sets) { - set.insert(aggr.groups.size()); - } - aggr.groups.push_back(std::move(colref)); - } - if (!perform_delim) { - // if we are not performing the duplicate elimination, we have only added the row_id column to the grouping - // operators in this case, we push a FIRST aggregate for each of the remaining expressions - delim_table_index = aggr.aggregate_index; - delim_column_offset = aggr.expressions.size(); - delim_data_offset = aggr.groups.size(); - for (idx_t i = 0; i < correlated_columns.size(); i++) { - auto &col = correlated_columns[i]; - auto first_aggregate = FirstFun::GetFunction(col.type); - auto colref = make_uniq( - col.name, col.type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); - vector> aggr_children; - aggr_children.push_back(std::move(colref)); - auto first_fun = - make_uniq(std::move(first_aggregate), std::move(aggr_children), nullptr, - nullptr, AggregateType::NON_DISTINCT); - aggr.expressions.push_back(std::move(first_fun)); - } - } else { - delim_table_index = aggr.group_index; - delim_column_offset = aggr.groups.size() - correlated_columns.size(); - delim_data_offset = aggr.groups.size(); - } - if (aggr.groups.size() == new_group_count) { - // we have to perform a LEFT OUTER JOIN between the result of this aggregate and the delim scan - // FIXME: this does not always have to be a LEFT OUTER JOIN, depending on whether aggr.expressions return - // NULL or a value - unique_ptr join = make_uniq(JoinType::INNER); - for (auto &aggr_exp : aggr.expressions) { - auto &b_aggr_exp = aggr_exp->Cast(); - if (!b_aggr_exp.PropagatesNullValues() || any_join || !parent_propagate_null_values) { - join = make_uniq(JoinType::LEFT); - break; - } - } - auto left_index = 
binder.GenerateTableIndex(); - auto delim_scan = make_uniq(left_index, delim_types); - join->children.push_back(std::move(delim_scan)); - join->children.push_back(std::move(plan)); - for (idx_t i = 0; i < new_group_count; i++) { - auto &col = correlated_columns[i]; - JoinCondition cond; - cond.left = make_uniq(col.name, col.type, ColumnBinding(left_index, i)); - cond.right = make_uniq( - correlated_columns[i].type, ColumnBinding(delim_table_index, delim_column_offset + i)); - cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; - join->conditions.push_back(std::move(cond)); - } - // for any COUNT aggregate we replace references to the column with: CASE WHEN COUNT(*) IS NULL THEN 0 - // ELSE COUNT(*) END - for (idx_t i = 0; i < aggr.expressions.size(); i++) { - D_ASSERT(aggr.expressions[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); - auto &bound = aggr.expressions[i]->Cast(); - vector arguments; - if (bound.function == CountFun::GetFunction() || bound.function == CountStarFun::GetFunction()) { - // have to replace this ColumnBinding with the CASE expression - replacement_map[ColumnBinding(aggr.aggregate_index, i)] = i; - } - } - // now we update the delim_index - base_binding.table_index = left_index; - this->delim_offset = base_binding.column_index = 0; - this->data_offset = 0; - return std::move(join); - } else { - // update the delim_index - base_binding.table_index = delim_table_index; - this->delim_offset = base_binding.column_index = delim_column_offset; - this->data_offset = delim_data_offset; - return plan; - } - } - case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: { - // cross product - // push into both sides of the plan - bool left_has_correlation = has_correlated_expressions.find(plan->children[0].get())->second; - bool right_has_correlation = has_correlated_expressions.find(plan->children[1].get())->second; - if (!right_has_correlation) { - // only left has correlation: push into left - plan->children[0] = 
PushDownDependentJoinInternal(std::move(plan->children[0]), - parent_propagate_null_values, lateral_depth); - return plan; - } - if (!left_has_correlation) { - // only right has correlation: push into right - plan->children[1] = PushDownDependentJoinInternal(std::move(plan->children[1]), - parent_propagate_null_values, lateral_depth); - return plan; - } - // both sides have correlation - // turn into an inner join - auto join = make_uniq(JoinType::INNER); - plan->children[0] = - PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); - auto left_binding = this->base_binding; - plan->children[1] = - PushDownDependentJoinInternal(std::move(plan->children[1]), parent_propagate_null_values, lateral_depth); - // add the correlated columns to the join conditions - for (idx_t i = 0; i < correlated_columns.size(); i++) { - JoinCondition cond; - cond.left = make_uniq( - correlated_columns[i].type, ColumnBinding(left_binding.table_index, left_binding.column_index + i)); - cond.right = make_uniq( - correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); - cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; - join->conditions.push_back(std::move(cond)); - } - join->children.push_back(std::move(plan->children[0])); - join->children.push_back(std::move(plan->children[1])); - return std::move(join); - } - case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: { - auto &dependent_join = plan->Cast(); - if (!((dependent_join.join_type == JoinType::INNER) || (dependent_join.join_type == JoinType::LEFT))) { - throw Exception("Dependent join can only be INNER or LEFT type"); - } - D_ASSERT(plan->children.size() == 2); - // Push all the bindings down to the left side so the right side knows where to refer DELIM_GET from - plan->children[0] = - PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); - - // Normal rewriter like in other 
joins - RewriteCorrelatedExpressions rewriter(this->base_binding, correlated_map, lateral_depth); - rewriter.VisitOperator(*plan); - - // Recursive rewriter to visit right side of lateral join and update bindings from left - RewriteCorrelatedExpressions recursive_rewriter(this->base_binding, correlated_map, lateral_depth + 1, true); - recursive_rewriter.VisitOperator(*plan->children[1]); - - return plan; - } - case LogicalOperatorType::LOGICAL_ANY_JOIN: - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { - auto &join = plan->Cast(); - D_ASSERT(plan->children.size() == 2); - // check the correlated expressions in the children of the join - bool left_has_correlation = has_correlated_expressions.find(plan->children[0].get())->second; - bool right_has_correlation = has_correlated_expressions.find(plan->children[1].get())->second; - - if (join.join_type == JoinType::INNER) { - // inner join - if (!right_has_correlation) { - // only left has correlation: push into left - plan->children[0] = PushDownDependentJoinInternal(std::move(plan->children[0]), - parent_propagate_null_values, lateral_depth); - // Remove the correlated columns coming from outside for current join node - return plan; - } - if (!left_has_correlation) { - // only right has correlation: push into right - plan->children[1] = PushDownDependentJoinInternal(std::move(plan->children[1]), - parent_propagate_null_values, lateral_depth); - // Remove the correlated columns coming from outside for current join node - return plan; - } - } else if (join.join_type == JoinType::LEFT) { - // left outer join - if (!right_has_correlation) { - // only left has correlation: push into left - plan->children[0] = PushDownDependentJoinInternal(std::move(plan->children[0]), - parent_propagate_null_values, lateral_depth); - // Remove the correlated columns coming from outside for current join node - return plan; - } - } else if (join.join_type == JoinType::RIGHT) { - // left 
outer join - if (!left_has_correlation) { - // only right has correlation: push into right - plan->children[1] = PushDownDependentJoinInternal(std::move(plan->children[1]), - parent_propagate_null_values, lateral_depth); - return plan; - } - } else if (join.join_type == JoinType::MARK) { - if (right_has_correlation) { - throw Exception("MARK join with correlation in RHS not supported"); - } - // push the child into the LHS - plan->children[0] = PushDownDependentJoinInternal(std::move(plan->children[0]), - parent_propagate_null_values, lateral_depth); - // rewrite expressions in the join conditions - RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); - rewriter.VisitOperator(*plan); - return plan; - } else { - throw Exception("Unsupported join type for flattening correlated subquery"); - } - // both sides have correlation - // push into both sides - plan->children[0] = - PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); - auto left_binding = this->base_binding; - plan->children[1] = - PushDownDependentJoinInternal(std::move(plan->children[1]), parent_propagate_null_values, lateral_depth); - auto right_binding = this->base_binding; - // NOTE: for OUTER JOINS it matters what the BASE BINDING is after the join - // for the LEFT OUTER JOIN, we want the LEFT side to be the base binding after we push - // because the RIGHT binding might contain NULL values - if (join.join_type == JoinType::LEFT) { - this->base_binding = left_binding; - } else if (join.join_type == JoinType::RIGHT) { - this->base_binding = right_binding; - } - // add the correlated columns to the join conditions - for (idx_t i = 0; i < correlated_columns.size(); i++) { - auto left = make_uniq( - correlated_columns[i].type, ColumnBinding(left_binding.table_index, left_binding.column_index + i)); - auto right = make_uniq( - correlated_columns[i].type, ColumnBinding(right_binding.table_index, right_binding.column_index + 
i)); - - if (join.type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || - join.type == LogicalOperatorType::LOGICAL_ASOF_JOIN) { - JoinCondition cond; - cond.left = std::move(left); - cond.right = std::move(right); - cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; - - auto &comparison_join = join.Cast(); - comparison_join.conditions.push_back(std::move(cond)); - } else { - auto &any_join = join.Cast(); - auto comparison = make_uniq(ExpressionType::COMPARE_NOT_DISTINCT_FROM, - std::move(left), std::move(right)); - auto conjunction = make_uniq( - ExpressionType::CONJUNCTION_AND, std::move(comparison), std::move(any_join.condition)); - any_join.condition = std::move(conjunction); - } - } - // then we replace any correlated expressions with the corresponding entry in the correlated_map - RewriteCorrelatedExpressions rewriter(right_binding, correlated_map, lateral_depth); - rewriter.VisitOperator(*plan); - return plan; - } - case LogicalOperatorType::LOGICAL_LIMIT: { - auto &limit = plan->Cast(); - if (limit.limit || limit.offset) { - throw ParserException("Non-constant limit or offset not supported in correlated subquery"); - } - auto rownum_alias = "limit_rownum"; - unique_ptr child; - unique_ptr order_by; - - // check if the direct child of this LIMIT node is an ORDER BY node, if so, keep it separate - // this is done for an optimization to avoid having to compute the total order - if (plan->children[0]->type == LogicalOperatorType::LOGICAL_ORDER_BY) { - order_by = unique_ptr_cast(std::move(plan->children[0])); - child = PushDownDependentJoinInternal(std::move(order_by->children[0]), parent_propagate_null_values, - lateral_depth); - } else { - child = PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, - lateral_depth); - } - auto child_column_count = child->GetColumnBindings().size(); - // we push a row_number() OVER (PARTITION BY [correlated columns]) - auto window_index = binder.GenerateTableIndex(); - auto 
window = make_uniq(window_index); - auto row_number = - make_uniq(ExpressionType::WINDOW_ROW_NUMBER, LogicalType::BIGINT, nullptr, nullptr); - auto partition_count = perform_delim ? correlated_columns.size() : 1; - for (idx_t i = 0; i < partition_count; i++) { - auto &col = correlated_columns[i]; - auto colref = make_uniq( - col.name, col.type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); - row_number->partitions.push_back(std::move(colref)); - } - if (order_by) { - // optimization: if there is an ORDER BY node followed by a LIMIT - // rather than computing the entire order, we push the ORDER BY expressions into the row_num computation - // this way, the order only needs to be computed per partition - row_number->orders = std::move(order_by->orders); - } - row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; - row_number->end = WindowBoundary::CURRENT_ROW_ROWS; - window->expressions.push_back(std::move(row_number)); - window->children.push_back(std::move(child)); - - // add a filter based on the row_number - // the filter we add is "row_number > offset AND row_number <= offset + limit" - auto filter = make_uniq(); - unique_ptr condition; - auto row_num_ref = - make_uniq(rownum_alias, LogicalType::BIGINT, ColumnBinding(window_index, 0)); - - int64_t upper_bound_limit = NumericLimits::Maximum(); - TryAddOperator::Operation(limit.offset_val, limit.limit_val, upper_bound_limit); - auto upper_bound = make_uniq(Value::BIGINT(upper_bound_limit)); - condition = make_uniq(ExpressionType::COMPARE_LESSTHANOREQUALTO, row_num_ref->Copy(), - std::move(upper_bound)); - // we only need to add "row_number >= offset + 1" if offset is bigger than 0 - if (limit.offset_val > 0) { - auto lower_bound = make_uniq(Value::BIGINT(limit.offset_val)); - auto lower_comp = make_uniq(ExpressionType::COMPARE_GREATERTHAN, - row_num_ref->Copy(), std::move(lower_bound)); - auto conj = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(lower_comp), - 
std::move(condition)); - condition = std::move(conj); - } - filter->expressions.push_back(std::move(condition)); - filter->children.push_back(std::move(window)); - // we prune away the row_number after the filter clause using the projection map - for (idx_t i = 0; i < child_column_count; i++) { - filter->projection_map.push_back(i); - } - return std::move(filter); - } - case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: { - // NOTE: limit percent could be supported in a manner similar to the LIMIT above - // but instead of filtering by an exact number of rows, the limit should be expressed as - // COUNT computed over the partition multiplied by the percentage - throw ParserException("Limit percent operator not supported in correlated subquery"); - } - case LogicalOperatorType::LOGICAL_WINDOW: { - auto &window = plan->Cast(); - // push into children - plan->children[0] = - PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); - // add the correlated columns to the PARTITION BY clauses in the Window - for (auto &expr : window.expressions) { - D_ASSERT(expr->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &w = expr->Cast(); - for (idx_t i = 0; i < correlated_columns.size(); i++) { - w.partitions.push_back(make_uniq( - correlated_columns[i].type, - ColumnBinding(base_binding.table_index, base_binding.column_index + i))); - } - } - return plan; - } - case LogicalOperatorType::LOGICAL_EXCEPT: - case LogicalOperatorType::LOGICAL_INTERSECT: - case LogicalOperatorType::LOGICAL_UNION: { - auto &setop = plan->Cast(); - // set operator, push into both children -#ifdef DEBUG - plan->children[0]->ResolveOperatorTypes(); - plan->children[1]->ResolveOperatorTypes(); - D_ASSERT(plan->children[0]->types == plan->children[1]->types); -#endif - plan->children[0] = PushDownDependentJoin(std::move(plan->children[0])); - plan->children[1] = PushDownDependentJoin(std::move(plan->children[1])); -#ifdef DEBUG - 
D_ASSERT(plan->children[0]->GetColumnBindings().size() == plan->children[1]->GetColumnBindings().size()); - plan->children[0]->ResolveOperatorTypes(); - plan->children[1]->ResolveOperatorTypes(); - D_ASSERT(plan->children[0]->types == plan->children[1]->types); -#endif - // we have to refer to the setop index now - base_binding.table_index = setop.table_index; - base_binding.column_index = setop.column_count; - setop.column_count += correlated_columns.size(); - return plan; - } - case LogicalOperatorType::LOGICAL_DISTINCT: { - auto &distinct = plan->Cast(); - // push down into child - distinct.children[0] = PushDownDependentJoin(std::move(distinct.children[0])); - // add all correlated columns to the distinct targets - for (idx_t i = 0; i < correlated_columns.size(); i++) { - distinct.distinct_targets.push_back(make_uniq( - correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i))); - } - return plan; - } - case LogicalOperatorType::LOGICAL_EXPRESSION_GET: { - // expression get - // first we flatten the dependent join in the child - plan->children[0] = - PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); - // then we replace any correlated expressions with the corresponding entry in the correlated_map - RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); - rewriter.VisitOperator(*plan); - // now we add all the correlated columns to each of the expressions of the expression scan - auto &expr_get = plan->Cast(); - for (idx_t i = 0; i < correlated_columns.size(); i++) { - for (auto &expr_list : expr_get.expressions) { - auto colref = make_uniq( - correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); - expr_list.push_back(std::move(colref)); - } - expr_get.expr_types.push_back(correlated_columns[i].type); - } - - base_binding.table_index = expr_get.table_index; - this->delim_offset = 
base_binding.column_index = expr_get.expr_types.size() - correlated_columns.size(); - this->data_offset = 0; - return plan; - } - case LogicalOperatorType::LOGICAL_PIVOT: - throw BinderException("PIVOT is not supported in correlated subqueries yet"); - case LogicalOperatorType::LOGICAL_ORDER_BY: - plan->children[0] = PushDownDependentJoin(std::move(plan->children[0])); - return plan; - case LogicalOperatorType::LOGICAL_GET: { - auto &get = plan->Cast(); - if (get.children.size() != 1) { - throw InternalException("Flatten dependent joins - logical get encountered without children"); - } - plan->children[0] = PushDownDependentJoin(std::move(plan->children[0])); - for (idx_t i = 0; i < correlated_columns.size(); i++) { - get.projected_input.push_back(this->delim_offset + i); - } - this->delim_offset = get.returned_types.size(); - this->data_offset = 0; - return plan; - } - case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: { - throw BinderException("Recursive CTEs not (yet) supported in correlated subquery"); - } - case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: { - throw BinderException("Materialized CTEs not (yet) supported in correlated subquery"); - } - case LogicalOperatorType::LOGICAL_DELIM_JOIN: { - throw BinderException("Nested lateral joins or lateral joins in correlated subqueries are not (yet) supported"); - } - case LogicalOperatorType::LOGICAL_SAMPLE: - throw BinderException("Sampling in correlated subqueries is not (yet) supported"); - default: - throw InternalException("Logical operator type \"%s\" for dependent join", LogicalOperatorToString(plan->type)); - } -} - -} // namespace duckdb - - - - - -#include - -namespace duckdb { - -HasCorrelatedExpressions::HasCorrelatedExpressions(const vector &correlated, bool lateral, - idx_t lateral_depth) - : has_correlated_expressions(false), lateral(lateral), correlated_columns(correlated), - lateral_depth(lateral_depth) { -} - -void HasCorrelatedExpressions::VisitOperator(LogicalOperator &op) { - 
VisitOperatorExpressions(op); -} - -unique_ptr HasCorrelatedExpressions::VisitReplace(BoundColumnRefExpression &expr, - unique_ptr *expr_ptr) { - // Indicates local correlations (all correlations within a child) for the root - if (expr.depth <= lateral_depth) { - return nullptr; - } - - // Should never happen - if (expr.depth > 1 + lateral_depth) { - if (lateral) { - throw BinderException("Invalid lateral depth encountered for an expression"); - } - throw InternalException("Expression with depth > 1 detected in non-lateral join"); - } - // Note: This is added, since we only want to set has_correlated_expressions to true when the - // BoundSubqueryExpression has the same bindings as one of the correlated_columns from the left hand side - // (correlated_columns is the correlated_columns from left hand side) - bool found_match = false; - for (idx_t i = 0; i < correlated_columns.size(); i++) { - if (correlated_columns[i].binding == expr.binding) { - found_match = true; - break; - } - } - // correlated column reference - D_ASSERT(expr.depth == lateral_depth + 1); - has_correlated_expressions = found_match; - return nullptr; -} - -unique_ptr HasCorrelatedExpressions::VisitReplace(BoundSubqueryExpression &expr, - unique_ptr *expr_ptr) { - if (!expr.IsCorrelated()) { - return nullptr; - } - // check if the subquery contains any of the correlated expressions that we are concerned about in this node - for (idx_t i = 0; i < correlated_columns.size(); i++) { - if (std::find(expr.binder->correlated_columns.begin(), expr.binder->correlated_columns.end(), - correlated_columns[i]) != expr.binder->correlated_columns.end()) { - has_correlated_expressions = true; - break; - } - } - return nullptr; -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -RewriteCorrelatedExpressions::RewriteCorrelatedExpressions(ColumnBinding base_binding, - column_binding_map_t &correlated_map, - idx_t lateral_depth, bool recursive_rewrite) - : base_binding(base_binding), 
correlated_map(correlated_map), lateral_depth(lateral_depth), - recursive_rewrite(recursive_rewrite) { -} - -void RewriteCorrelatedExpressions::VisitOperator(LogicalOperator &op) { - if (recursive_rewrite) { - // Update column bindings from left child of lateral to right child - if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { - D_ASSERT(op.children.size() == 2); - VisitOperator(*op.children[0]); - lateral_depth++; - VisitOperator(*op.children[1]); - lateral_depth--; - } else { - VisitOperatorChildren(op); - } - } - // update the bindings in the correlated columns of the dependendent join - if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { - auto &plan = op.Cast(); - for (auto &corr : plan.correlated_columns) { - auto entry = correlated_map.find(corr.binding); - if (entry != correlated_map.end()) { - corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); - } - } - } - VisitOperatorExpressions(op); -} - -unique_ptr RewriteCorrelatedExpressions::VisitReplace(BoundColumnRefExpression &expr, - unique_ptr *expr_ptr) { - if (expr.depth <= lateral_depth) { - // Indicates local correlations not relevant for the current the rewrite - return nullptr; - } - // correlated column reference - // replace with the entry referring to the duplicate eliminated scan - // if this assertion occurs it generally means the bindings are inappropriate set in the binder or - // we either missed to account for lateral binder or over-counted for the lateral binder - D_ASSERT(expr.depth == 1 + lateral_depth); - auto entry = correlated_map.find(expr.binding); - D_ASSERT(entry != correlated_map.end()); - - expr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); - if (recursive_rewrite) { - D_ASSERT(expr.depth > 1); - expr.depth--; - } else { - expr.depth = 0; - } - return nullptr; -} - -unique_ptr RewriteCorrelatedExpressions::VisitReplace(BoundSubqueryExpression &expr, - unique_ptr 
*expr_ptr) { - if (!expr.IsCorrelated()) { - return nullptr; - } - // subquery detected within this subquery - // recursively rewrite it using the RewriteCorrelatedRecursive class - RewriteCorrelatedRecursive rewrite(expr, base_binding, correlated_map); - rewrite.RewriteCorrelatedSubquery(expr); - return nullptr; -} - -RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedRecursive( - BoundSubqueryExpression &parent, ColumnBinding base_binding, column_binding_map_t &correlated_map) - : parent(parent), base_binding(base_binding), correlated_map(correlated_map) { -} - -void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteJoinRefRecursive(BoundTableRef &ref) { - // recursively rewrite bindings in the correlated columns for the table ref and all the children - if (ref.type == TableReferenceType::JOIN) { - auto &bound_join = ref.Cast(); - for (auto &corr : bound_join.correlated_columns) { - auto entry = correlated_map.find(corr.binding); - if (entry != correlated_map.end()) { - corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); - } - } - RewriteJoinRefRecursive(*bound_join.left); - RewriteJoinRefRecursive(*bound_join.right); - } -} - -void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedSubquery( - BoundSubqueryExpression &expr) { - // rewrite the binding in the correlated list of the subquery) - for (auto &corr : expr.binder->correlated_columns) { - auto entry = correlated_map.find(corr.binding); - if (entry != correlated_map.end()) { - corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); - } - } - // TODO: Cleanup and find a better way to do this - auto &node = *expr.subquery; - if (node.type == QueryNodeType::SELECT_NODE) { - // Found an unplanned select node, need to update column bindings correlated columns in the from tables - auto &bound_select = node.Cast(); - if (bound_select.from_table) { - BoundTableRef 
&table_ref = *bound_select.from_table; - RewriteJoinRefRecursive(table_ref); - } - } - // now rewrite any correlated BoundColumnRef expressions inside the subquery - ExpressionIterator::EnumerateQueryNodeChildren(*expr.subquery, - [&](Expression &child) { RewriteCorrelatedExpressions(child); }); -} - -void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedExpressions(Expression &child) { - if (child.type == ExpressionType::BOUND_COLUMN_REF) { - // bound column reference - auto &bound_colref = child.Cast(); - if (bound_colref.depth == 0) { - // not a correlated column, ignore - return; - } - // correlated column - // check the correlated map - auto entry = correlated_map.find(bound_colref.binding); - if (entry != correlated_map.end()) { - // we found the column in the correlated map! - // update the binding and reduce the depth by 1 - bound_colref.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); - bound_colref.depth--; - } - } else if (child.type == ExpressionType::SUBQUERY) { - // we encountered another subquery: rewrite recursively - D_ASSERT(child.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY); - auto &bound_subquery = child.Cast(); - RewriteCorrelatedRecursive rewrite(bound_subquery, base_binding, correlated_map); - rewrite.RewriteCorrelatedSubquery(bound_subquery); - } -} - -RewriteCountAggregates::RewriteCountAggregates(column_binding_map_t &replacement_map) - : replacement_map(replacement_map) { -} - -unique_ptr RewriteCountAggregates::VisitReplace(BoundColumnRefExpression &expr, - unique_ptr *expr_ptr) { - auto entry = replacement_map.find(expr.binding); - if (entry != replacement_map.end()) { - // reference to a COUNT(*) aggregate - // replace this with CASE WHEN COUNT(*) IS NULL THEN 0 ELSE COUNT(*) END - auto is_null = make_uniq(ExpressionType::OPERATOR_IS_NULL, LogicalType::BOOLEAN); - is_null->children.push_back(expr.Copy()); - auto check = std::move(is_null); - auto 
result_if_true = make_uniq(Value::Numeric(expr.return_type, 0)); - auto result_if_false = std::move(*expr_ptr); - return make_uniq(std::move(check), std::move(result_if_true), std::move(result_if_false)); - } - return nullptr; -} - -} // namespace duckdb - - - - - - - - - - - - - -#include - -namespace duckdb { - -Binding::Binding(BindingType binding_type, const string &alias, vector coltypes, vector colnames, - idx_t index) - : binding_type(binding_type), alias(alias), index(index), types(std::move(coltypes)), names(std::move(colnames)) { - D_ASSERT(types.size() == names.size()); - for (idx_t i = 0; i < names.size(); i++) { - auto &name = names[i]; - D_ASSERT(!name.empty()); - if (name_map.find(name) != name_map.end()) { - throw BinderException("table \"%s\" has duplicate column name \"%s\"", alias, name); - } - name_map[name] = i; - } -} - -bool Binding::TryGetBindingIndex(const string &column_name, column_t &result) { - auto entry = name_map.find(column_name); - if (entry == name_map.end()) { - return false; - } - auto column_info = entry->second; - result = column_info; - return true; -} - -column_t Binding::GetBindingIndex(const string &column_name) { - column_t result; - if (!TryGetBindingIndex(column_name, result)) { - throw InternalException("Binding index for column \"%s\" not found", column_name); - } - return result; -} - -bool Binding::HasMatchingBinding(const string &column_name) { - column_t result; - return TryGetBindingIndex(column_name, result); -} - -string Binding::ColumnNotFoundError(const string &column_name) const { - return StringUtil::Format("Values list \"%s\" does not have a column named \"%s\"", alias, column_name); -} - -BindResult Binding::Bind(ColumnRefExpression &colref, idx_t depth) { - column_t column_index; - bool success = false; - success = TryGetBindingIndex(colref.GetColumnName(), column_index); - if (!success) { - return BindResult(ColumnNotFoundError(colref.GetColumnName())); - } - ColumnBinding binding; - binding.table_index 
= index; - binding.column_index = column_index; - LogicalType sql_type = types[column_index]; - if (colref.alias.empty()) { - colref.alias = names[column_index]; - } - return BindResult(make_uniq(colref.GetName(), sql_type, binding, depth)); -} - -optional_ptr Binding::GetStandardEntry() { - return nullptr; -} - -EntryBinding::EntryBinding(const string &alias, vector types_p, vector names_p, idx_t index, - StandardEntry &entry) - : Binding(BindingType::CATALOG_ENTRY, alias, std::move(types_p), std::move(names_p), index), entry(entry) { -} - -optional_ptr EntryBinding::GetStandardEntry() { - return &entry; -} - -TableBinding::TableBinding(const string &alias, vector types_p, vector names_p, - vector &bound_column_ids, optional_ptr entry, idx_t index, - bool add_row_id) - : Binding(BindingType::TABLE, alias, std::move(types_p), std::move(names_p), index), - bound_column_ids(bound_column_ids), entry(entry) { - if (add_row_id) { - if (name_map.find("rowid") == name_map.end()) { - name_map["rowid"] = COLUMN_IDENTIFIER_ROW_ID; - } - } -} - -static void ReplaceAliases(ParsedExpression &expr, const ColumnList &list, - const unordered_map &alias_map) { - if (expr.type == ExpressionType::COLUMN_REF) { - auto &colref = expr.Cast(); - D_ASSERT(!colref.IsQualified()); - auto &col_names = colref.column_names; - D_ASSERT(col_names.size() == 1); - auto idx_entry = list.GetColumnIndex(col_names[0]); - auto &alias = alias_map.at(idx_entry.index); - col_names = {alias}; - } - ParsedExpressionIterator::EnumerateChildren( - expr, [&](const ParsedExpression &child) { ReplaceAliases((ParsedExpression &)child, list, alias_map); }); -} - -static void BakeTableName(ParsedExpression &expr, const string &table_name) { - if (expr.type == ExpressionType::COLUMN_REF) { - auto &colref = expr.Cast(); - D_ASSERT(!colref.IsQualified()); - auto &col_names = colref.column_names; - col_names.insert(col_names.begin(), table_name); - } - ParsedExpressionIterator::EnumerateChildren( - expr, [&](const 
ParsedExpression &child) { BakeTableName((ParsedExpression &)child, table_name); }); -} - -unique_ptr TableBinding::ExpandGeneratedColumn(const string &column_name) { - auto catalog_entry = GetStandardEntry(); - D_ASSERT(catalog_entry); // Should only be called on a TableBinding - - D_ASSERT(catalog_entry->type == CatalogType::TABLE_ENTRY); - auto &table_entry = catalog_entry->Cast(); - - // Get the index of the generated column - auto column_index = GetBindingIndex(column_name); - D_ASSERT(table_entry.GetColumn(LogicalIndex(column_index)).Generated()); - // Get a copy of the generated column - auto expression = table_entry.GetColumn(LogicalIndex(column_index)).GeneratedExpression().Copy(); - unordered_map alias_map; - for (auto &entry : name_map) { - alias_map[entry.second] = entry.first; - } - ReplaceAliases(*expression, table_entry.GetColumns(), alias_map); - BakeTableName(*expression, alias); - return (expression); -} - -const vector &TableBinding::GetBoundColumnIds() const { -#ifdef DEBUG - unordered_set column_ids; - for (auto &id : bound_column_ids) { - auto result = column_ids.insert(id); - // assert that all entries in the bound_column_ids are unique - D_ASSERT(result.second); - auto it = std::find_if(name_map.begin(), name_map.end(), - [&](const std::pair &it) { return it.second == id; }); - // assert that every id appears in the name_map - D_ASSERT(it != name_map.end()); - // the order that they appear in is not guaranteed to be sequential - } -#endif - return bound_column_ids; -} - -ColumnBinding TableBinding::GetColumnBinding(column_t column_index) { - auto &column_ids = bound_column_ids; - ColumnBinding binding; - - // Locate the column_id that matches the 'column_index' - auto it = std::find_if(column_ids.begin(), column_ids.end(), - [&](const column_t &id) -> bool { return id == column_index; }); - // Get the index of it - binding.column_index = std::distance(column_ids.begin(), it); - // If it wasn't found, add it - if (it == column_ids.end()) { - 
column_ids.push_back(column_index); - } - - binding.table_index = index; - return binding; -} - -BindResult TableBinding::Bind(ColumnRefExpression &colref, idx_t depth) { - auto &column_name = colref.GetColumnName(); - column_t column_index; - bool success = false; - success = TryGetBindingIndex(column_name, column_index); - if (!success) { - return BindResult(ColumnNotFoundError(column_name)); - } - auto entry = GetStandardEntry(); - if (entry && column_index != COLUMN_IDENTIFIER_ROW_ID) { - D_ASSERT(entry->type == CatalogType::TABLE_ENTRY); - // Either there is no table, or the columns category has to be standard - auto &table_entry = entry->Cast(); - auto &column_entry = table_entry.GetColumn(LogicalIndex(column_index)); - (void)table_entry; - (void)column_entry; - D_ASSERT(column_entry.Category() == TableColumnType::STANDARD); - } - // fetch the type of the column - LogicalType col_type; - if (column_index == COLUMN_IDENTIFIER_ROW_ID) { - // row id: BIGINT type - col_type = LogicalType::BIGINT; - } else { - // normal column: fetch type from base column - col_type = types[column_index]; - if (colref.alias.empty()) { - colref.alias = names[column_index]; - } - } - ColumnBinding binding = GetColumnBinding(column_index); - return BindResult(make_uniq(colref.GetName(), col_type, binding, depth)); -} - -optional_ptr TableBinding::GetStandardEntry() { - return entry; -} - -string TableBinding::ColumnNotFoundError(const string &column_name) const { - return StringUtil::Format("Table \"%s\" does not have a column named \"%s\"", alias, column_name); -} - -DummyBinding::DummyBinding(vector types_p, vector names_p, string dummy_name_p) - : Binding(BindingType::DUMMY, DummyBinding::DUMMY_NAME + dummy_name_p, std::move(types_p), std::move(names_p), - DConstants::INVALID_INDEX), - dummy_name(std::move(dummy_name_p)) { -} - -BindResult DummyBinding::Bind(ColumnRefExpression &colref, idx_t depth) { - column_t column_index; - if (!TryGetBindingIndex(colref.GetColumnName(), 
column_index)) { - throw InternalException("Column %s not found in bindings", colref.GetColumnName()); - } - ColumnBinding binding(index, column_index); - - // we are binding a parameter to create the dummy binding, no arguments are supplied - return BindResult(make_uniq(colref.GetName(), types[column_index], binding, depth)); -} - -BindResult DummyBinding::Bind(ColumnRefExpression &colref, idx_t lambda_index, idx_t depth) { - column_t column_index; - if (!TryGetBindingIndex(colref.GetColumnName(), column_index)) { - throw InternalException("Column %s not found in bindings", colref.GetColumnName()); - } - ColumnBinding binding(index, column_index); - return BindResult( - make_uniq(colref.GetName(), types[column_index], binding, lambda_index, depth)); -} - -unique_ptr DummyBinding::ParamToArg(ColumnRefExpression &colref) { - column_t column_index; - if (!TryGetBindingIndex(colref.GetColumnName(), column_index)) { - throw InternalException("Column %s not found in macro", colref.GetColumnName()); - } - auto arg = (*arguments)[column_index]->Copy(); - arg->alias = colref.alias; - return arg; -} - -} // namespace duckdb - - - - - -namespace duckdb { - -void TableFilterSet::PushFilter(idx_t column_index, unique_ptr filter) { - auto entry = filters.find(column_index); - if (entry == filters.end()) { - // no filter yet: push the filter directly - filters[column_index] = std::move(filter); - } else { - // there is already a filter: AND it together - if (entry->second->filter_type == TableFilterType::CONJUNCTION_AND) { - auto &and_filter = entry->second->Cast(); - and_filter.child_filters.push_back(std::move(filter)); - } else { - auto and_filter = make_uniq(); - and_filter->child_filters.push_back(std::move(entry->second)); - and_filter->child_filters.push_back(std::move(filter)); - filters[column_index] = std::move(and_filter); - } - } -} - -} // namespace duckdb - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// 
-// Arena Chunk -//===--------------------------------------------------------------------===// -ArenaChunk::ArenaChunk(Allocator &allocator, idx_t size) : current_position(0), maximum_size(size), prev(nullptr) { - D_ASSERT(size > 0); - data = allocator.Allocate(size); -} -ArenaChunk::~ArenaChunk() { - if (next) { - auto current_next = std::move(next); - while (current_next) { - current_next = std::move(current_next->next); - } - } -} - -//===--------------------------------------------------------------------===// -// Allocator Wrapper -//===--------------------------------------------------------------------===// -struct ArenaAllocatorData : public PrivateAllocatorData { - explicit ArenaAllocatorData(ArenaAllocator &allocator) : allocator(allocator) { - } - - ArenaAllocator &allocator; -}; - -static data_ptr_t ArenaAllocatorAllocate(PrivateAllocatorData *private_data, idx_t size) { - auto &allocator_data = private_data->Cast(); - return allocator_data.allocator.Allocate(size); -} - -static void ArenaAllocatorFree(PrivateAllocatorData *, data_ptr_t, idx_t) { - // nop -} - -static data_ptr_t ArenaAllocateReallocate(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t old_size, - idx_t size) { - auto &allocator_data = private_data->Cast(); - return allocator_data.allocator.Reallocate(pointer, old_size, size); -} -//===--------------------------------------------------------------------===// -// Arena Allocator -//===--------------------------------------------------------------------===// -ArenaAllocator::ArenaAllocator(Allocator &allocator, idx_t initial_capacity) - : allocator(allocator), arena_allocator(ArenaAllocatorAllocate, ArenaAllocatorFree, ArenaAllocateReallocate, - make_uniq(*this)) { - head = nullptr; - tail = nullptr; - current_capacity = initial_capacity; -} - -ArenaAllocator::~ArenaAllocator() { -} - -data_ptr_t ArenaAllocator::Allocate(idx_t len) { - D_ASSERT(!head || head->current_position <= head->maximum_size); - if (!head || 
head->current_position + len > head->maximum_size) { - do { - current_capacity *= 2; - } while (current_capacity < len); - auto new_chunk = make_unsafe_uniq(allocator, current_capacity); - if (head) { - head->prev = new_chunk.get(); - new_chunk->next = std::move(head); - } else { - tail = new_chunk.get(); - } - head = std::move(new_chunk); - } - D_ASSERT(head->current_position + len <= head->maximum_size); - auto result = head->data.get() + head->current_position; - head->current_position += len; - return result; -} - -data_ptr_t ArenaAllocator::Reallocate(data_ptr_t pointer, idx_t old_size, idx_t size) { - D_ASSERT(head); - if (old_size == size) { - // nothing to do - return pointer; - } - - auto head_ptr = head->data.get() + head->current_position; - int64_t diff = size - old_size; - if (pointer == head_ptr && (size < old_size || head->current_position + diff <= head->maximum_size)) { - // passed pointer is the head pointer, and the diff fits on the current chunk - head->current_position += diff; - return pointer; - } else { - // allocate new memory - auto result = Allocate(size); - memcpy(result, pointer, old_size); - return result; - } -} - -data_ptr_t ArenaAllocator::AllocateAligned(idx_t size) { - return Allocate(AlignValue(size)); -} - -data_ptr_t ArenaAllocator::ReallocateAligned(data_ptr_t pointer, idx_t old_size, idx_t size) { - return Reallocate(pointer, old_size, AlignValue(size)); -} - -void ArenaAllocator::Reset() { - if (head) { - // destroy all chunks except the current one - if (head->next) { - auto current_next = std::move(head->next); - while (current_next) { - current_next = std::move(current_next->next); - } - } - tail = head.get(); - - // reset the head - head->current_position = 0; - head->prev = nullptr; - } -} - -void ArenaAllocator::Destroy() { - head = nullptr; - tail = nullptr; - current_capacity = ARENA_ALLOCATOR_INITIAL_CAPACITY; -} - -void ArenaAllocator::Move(ArenaAllocator &other) { - D_ASSERT(!other.head); - other.tail = tail; - 
other.head = std::move(head); - other.current_capacity = current_capacity; - Destroy(); -} - -ArenaChunk *ArenaAllocator::GetHead() { - return head.get(); -} - -ArenaChunk *ArenaAllocator::GetTail() { - return tail; -} - -bool ArenaAllocator::IsEmpty() const { - return head == nullptr; -} - -idx_t ArenaAllocator::SizeInBytes() const { - idx_t total_size = 0; - if (!IsEmpty()) { - auto current = head.get(); - while (current != nullptr) { - total_size += current->current_position; - current = current->next.get(); - } - } - return total_size; -} - -} // namespace duckdb - - - -namespace duckdb { - -Block::Block(Allocator &allocator, block_id_t id) - : FileBuffer(allocator, FileBufferType::BLOCK, Storage::BLOCK_SIZE), id(id) { -} - -Block::Block(Allocator &allocator, block_id_t id, uint32_t internal_size) - : FileBuffer(allocator, FileBufferType::BLOCK, internal_size), id(id) { - D_ASSERT((AllocSize() & (Storage::SECTOR_SIZE - 1)) == 0); -} - -Block::Block(FileBuffer &source, block_id_t id) : FileBuffer(source, FileBufferType::BLOCK), id(id) { - D_ASSERT((AllocSize() & (Storage::SECTOR_SIZE - 1)) == 0); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -BlockHandle::BlockHandle(BlockManager &block_manager, block_id_t block_id_p) - : block_manager(block_manager), readers(0), block_id(block_id_p), buffer(nullptr), eviction_timestamp(0), - can_destroy(false), memory_charge(block_manager.buffer_manager.GetBufferPool()), unswizzled(nullptr) { - eviction_timestamp = 0; - state = BlockState::BLOCK_UNLOADED; - memory_usage = Storage::BLOCK_ALLOC_SIZE; -} - -BlockHandle::BlockHandle(BlockManager &block_manager, block_id_t block_id_p, unique_ptr buffer_p, - bool can_destroy_p, idx_t block_size, BufferPoolReservation &&reservation) - : block_manager(block_manager), readers(0), block_id(block_id_p), eviction_timestamp(0), can_destroy(can_destroy_p), - memory_charge(block_manager.buffer_manager.GetBufferPool()), unswizzled(nullptr) { - buffer = std::move(buffer_p); 
- state = BlockState::BLOCK_LOADED; - memory_usage = block_size; - memory_charge = std::move(reservation); -} - -BlockHandle::~BlockHandle() { // NOLINT: allow internal exceptions - // being destroyed, so any unswizzled pointers are just binary junk now. - unswizzled = nullptr; - auto &buffer_manager = block_manager.buffer_manager; - // no references remain to this block: erase - if (buffer && state == BlockState::BLOCK_LOADED) { - D_ASSERT(memory_charge.size > 0); - // the block is still loaded in memory: erase it - buffer.reset(); - memory_charge.Resize(0); - } else { - D_ASSERT(memory_charge.size == 0); - } - buffer_manager.GetBufferPool().PurgeQueue(); - block_manager.UnregisterBlock(block_id, can_destroy); -} - -unique_ptr AllocateBlock(BlockManager &block_manager, unique_ptr reusable_buffer, - block_id_t block_id) { - if (reusable_buffer) { - // re-usable buffer: re-use it - if (reusable_buffer->type == FileBufferType::BLOCK) { - // we can reuse the buffer entirely - auto &block = reinterpret_cast(*reusable_buffer); - block.id = block_id; - return unique_ptr_cast(std::move(reusable_buffer)); - } - auto block = block_manager.CreateBlock(block_id, reusable_buffer.get()); - reusable_buffer.reset(); - return block; - } else { - // no re-usable buffer: allocate a new block - return block_manager.CreateBlock(block_id, nullptr); - } -} - -BufferHandle BlockHandle::Load(shared_ptr &handle, unique_ptr reusable_buffer) { - if (handle->state == BlockState::BLOCK_LOADED) { - // already loaded - D_ASSERT(handle->buffer); - return BufferHandle(handle, handle->buffer.get()); - } - - auto &block_manager = handle->block_manager; - if (handle->block_id < MAXIMUM_BLOCK) { - auto block = AllocateBlock(block_manager, std::move(reusable_buffer), handle->block_id); - block_manager.Read(*block); - handle->buffer = std::move(block); - } else { - if (handle->can_destroy) { - return BufferHandle(); - } else { - handle->buffer = - 
block_manager.buffer_manager.ReadTemporaryBuffer(handle->block_id, std::move(reusable_buffer)); - } - } - handle->state = BlockState::BLOCK_LOADED; - return BufferHandle(handle, handle->buffer.get()); -} - -unique_ptr BlockHandle::UnloadAndTakeBlock() { - if (state == BlockState::BLOCK_UNLOADED) { - // already unloaded: nothing to do - return nullptr; - } - D_ASSERT(!unswizzled); - D_ASSERT(CanUnload()); - - if (block_id >= MAXIMUM_BLOCK && !can_destroy) { - // temporary block that cannot be destroyed: write to temporary file - block_manager.buffer_manager.WriteTemporaryBuffer(block_id, *buffer); - } - memory_charge.Resize(0); - state = BlockState::BLOCK_UNLOADED; - return std::move(buffer); -} - -void BlockHandle::Unload() { - auto block = UnloadAndTakeBlock(); - block.reset(); -} - -bool BlockHandle::CanUnload() { - if (state == BlockState::BLOCK_UNLOADED) { - // already unloaded - return false; - } - if (readers > 0) { - // there are active readers - return false; - } - if (block_id >= MAXIMUM_BLOCK && !can_destroy && !block_manager.buffer_manager.HasTemporaryDirectory()) { - // in order to unload this block we need to write it to a temporary buffer - // however, no temporary directory is specified! - // hence we cannot unload the block - return false; - } - return true; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -BlockManager::BlockManager(BufferManager &buffer_manager) - : buffer_manager(buffer_manager), metadata_manager(make_uniq(*this, buffer_manager)) { -} - -shared_ptr BlockManager::RegisterBlock(block_id_t block_id) { - lock_guard lock(blocks_lock); - // check if the block already exists - auto entry = blocks.find(block_id); - if (entry != blocks.end()) { - // already exists: check if it hasn't expired yet - auto existing_ptr = entry->second.lock(); - if (existing_ptr) { - //! it hasn't! 
return it - return existing_ptr; - } - } - // create a new block pointer for this block - auto result = make_shared(*this, block_id); - // register the block pointer in the set of blocks as a weak pointer - blocks[block_id] = weak_ptr(result); - return result; -} - -shared_ptr BlockManager::ConvertToPersistent(block_id_t block_id, shared_ptr old_block) { - // pin the old block to ensure we have it loaded in memory - auto old_handle = buffer_manager.Pin(old_block); - D_ASSERT(old_block->state == BlockState::BLOCK_LOADED); - D_ASSERT(old_block->buffer); - - // Temp buffers can be larger than the storage block size. But persistent buffers - // cannot. - D_ASSERT(old_block->buffer->AllocSize() <= Storage::BLOCK_ALLOC_SIZE); - - // register a block with the new block id - auto new_block = RegisterBlock(block_id); - D_ASSERT(new_block->state == BlockState::BLOCK_UNLOADED); - D_ASSERT(new_block->readers == 0); - - // move the data from the old block into data for the new block - new_block->state = BlockState::BLOCK_LOADED; - new_block->buffer = ConvertBlock(block_id, *old_block->buffer); - new_block->memory_usage = old_block->memory_usage; - new_block->memory_charge = std::move(old_block->memory_charge); - - // clear the old buffer and unload it - old_block->buffer.reset(); - old_block->state = BlockState::BLOCK_UNLOADED; - old_block->memory_usage = 0; - old_handle.Destroy(); - old_block.reset(); - - // persist the new block to disk - Write(*new_block->buffer, block_id); - - buffer_manager.GetBufferPool().AddToEvictionQueue(new_block); - - return new_block; -} - -void BlockManager::UnregisterBlock(block_id_t block_id, bool can_destroy) { - if (block_id >= MAXIMUM_BLOCK) { - // in-memory buffer: buffer could have been offloaded to disk: remove the file - buffer_manager.DeleteTemporaryFile(block_id); - } else { - lock_guard lock(blocks_lock); - // on-disk block: erase from list of blocks in manager - blocks.erase(block_id); - } -} - -MetadataManager 
&BlockManager::GetMetadataManager() { - return *metadata_manager; -} - -void BlockManager::Truncate() { -} - -} // namespace duckdb - - - - -namespace duckdb { - -BufferHandle::BufferHandle() : handle(nullptr), node(nullptr) { -} - -BufferHandle::BufferHandle(shared_ptr handle_p, FileBuffer *node_p) - : handle(std::move(handle_p)), node(node_p) { -} - -BufferHandle::BufferHandle(BufferHandle &&other) noexcept { - std::swap(node, other.node); - std::swap(handle, other.handle); -} - -BufferHandle &BufferHandle::operator=(BufferHandle &&other) noexcept { - std::swap(node, other.node); - std::swap(handle, other.handle); - return *this; -} - -BufferHandle::~BufferHandle() { - Destroy(); -} - -bool BufferHandle::IsValid() const { - return node != nullptr; -} - -void BufferHandle::Destroy() { - if (!handle || !IsValid()) { - return; - } - handle->block_manager.buffer_manager.Unpin(handle); - handle.reset(); - node = nullptr; -} - -FileBuffer &BufferHandle::GetFileBuffer() { - D_ASSERT(node); - return *node; -} - -} // namespace duckdb - - - - -namespace duckdb { - -typedef duckdb_moodycamel::ConcurrentQueue eviction_queue_t; - -struct EvictionQueue { - eviction_queue_t q; -}; - -bool BufferEvictionNode::CanUnload(BlockHandle &handle_p) { - if (timestamp != handle_p.eviction_timestamp) { - // handle was used in between - return false; - } - return handle_p.CanUnload(); -} - -shared_ptr BufferEvictionNode::TryGetBlockHandle() { - auto handle_p = handle.lock(); - if (!handle_p) { - // BlockHandle has been destroyed - return nullptr; - } - if (!CanUnload(*handle_p)) { - // handle was used in between - return nullptr; - } - // this is the latest node in the queue with this handle - return handle_p; -} - -BufferPool::BufferPool(idx_t maximum_memory) - : current_memory(0), maximum_memory(maximum_memory), queue(make_uniq()), queue_insertions(0) { -} -BufferPool::~BufferPool() { -} - -void BufferPool::AddToEvictionQueue(shared_ptr &handle) { - constexpr int INSERT_INTERVAL = 1024; 
- - D_ASSERT(handle->readers == 0); - handle->eviction_timestamp++; - // After each 1024 insertions, run through the queue and purge. - if ((++queue_insertions % INSERT_INTERVAL) == 0) { - PurgeQueue(); - } - queue->q.enqueue(BufferEvictionNode(weak_ptr(handle), handle->eviction_timestamp)); -} - -void BufferPool::IncreaseUsedMemory(idx_t size) { - current_memory += size; -} - -idx_t BufferPool::GetUsedMemory() { - return current_memory; -} -idx_t BufferPool::GetMaxMemory() { - return maximum_memory; -} - -BufferPool::EvictionResult BufferPool::EvictBlocks(idx_t extra_memory, idx_t memory_limit, - unique_ptr *buffer) { - BufferEvictionNode node; - TempBufferPoolReservation r(*this, extra_memory); - while (current_memory > memory_limit) { - // get a block to unpin from the queue - if (!queue->q.try_dequeue(node)) { - // Failed to reserve. Adjust size of temp reservation to 0. - r.Resize(0); - return {false, std::move(r)}; - } - // get a reference to the underlying block pointer - auto handle = node.TryGetBlockHandle(); - if (!handle) { - continue; - } - // we might be able to free this block: grab the mutex and check if we can free it - lock_guard lock(handle->lock); - if (!node.CanUnload(*handle)) { - // something changed in the mean-time, bail out - continue; - } - // hooray, we can unload the block - if (buffer && handle->buffer->AllocSize() == extra_memory) { - // we can actually re-use the memory directly! 
- *buffer = handle->UnloadAndTakeBlock(); - return {true, std::move(r)}; - } else { - // release the memory and mark the block as unloaded - handle->Unload(); - } - } - return {true, std::move(r)}; -} - -void BufferPool::PurgeQueue() { - BufferEvictionNode node; - while (true) { - if (!queue->q.try_dequeue(node)) { - break; - } - auto handle = node.TryGetBlockHandle(); - if (!handle) { - continue; - } else { - queue->q.enqueue(std::move(node)); - break; - } - } -} - -void BufferPool::SetLimit(idx_t limit, const char *exception_postscript) { - lock_guard l_lock(limit_lock); - // try to evict until the limit is reached - if (!EvictBlocks(0, limit).success) { - throw OutOfMemoryException( - "Failed to change memory limit to %lld: could not free up enough memory for the new limit%s", limit, - exception_postscript); - } - idx_t old_limit = maximum_memory; - // set the global maximum memory to the new limit if successful - maximum_memory = limit; - // evict again - if (!EvictBlocks(0, limit).success) { - // failed: go back to old limit - maximum_memory = old_limit; - throw OutOfMemoryException( - "Failed to change memory limit to %lld: could not free up enough memory for the new limit%s", limit, - exception_postscript); - } -} - -} // namespace duckdb - - - -namespace duckdb { - -BufferPoolReservation::BufferPoolReservation(BufferPool &pool) : pool(pool) { -} - -BufferPoolReservation::BufferPoolReservation(BufferPoolReservation &&src) noexcept : pool(src.pool) { - size = src.size; - src.size = 0; -} - -BufferPoolReservation &BufferPoolReservation::operator=(BufferPoolReservation &&src) noexcept { - size = src.size; - src.size = 0; - return *this; -} - -BufferPoolReservation::~BufferPoolReservation() { - D_ASSERT(size == 0); -} - -void BufferPoolReservation::Resize(idx_t new_size) { - int64_t delta = (int64_t)new_size - size; - pool.IncreaseUsedMemory(delta); - size = new_size; -} - -void BufferPoolReservation::Merge(BufferPoolReservation &&src) { - size += src.size; - 
src.size = 0; -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr BufferManager::CreateStandardBufferManager(DatabaseInstance &db, DBConfig &config) { - return make_uniq(db, config.options.temporary_directory); -} - -shared_ptr BufferManager::RegisterSmallMemory(idx_t block_size) { - throw NotImplementedException("This type of BufferManager can not create 'small-memory' blocks"); -} - -Allocator &BufferManager::GetBufferAllocator() { - throw NotImplementedException("This type of BufferManager does not have an Allocator"); -} - -void BufferManager::ReserveMemory(idx_t size) { - throw NotImplementedException("This type of BufferManager can not reserve memory"); -} -void BufferManager::FreeReservedMemory(idx_t size) { - throw NotImplementedException("This type of BufferManager can not free reserved memory"); -} - -void BufferManager::SetLimit(idx_t limit) { - throw NotImplementedException("This type of BufferManager can not set a limit"); -} - -vector BufferManager::GetTemporaryFiles() { - throw InternalException("This type of BufferManager does not allow temporary files"); -} - -const string &BufferManager::GetTemporaryDirectory() { - throw InternalException("This type of BufferManager does not allow a temporary directory"); -} - -BufferPool &BufferManager::GetBufferPool() { - throw InternalException("This type of BufferManager does not have a buffer pool"); -} - -void BufferManager::SetTemporaryDirectory(const string &new_dir) { - throw NotImplementedException("This type of BufferManager can not set a temporary directory"); -} - -DatabaseInstance &BufferManager::GetDatabase() { - throw NotImplementedException("This type of BufferManager is not linked to a DatabaseInstance"); -} - -bool BufferManager::HasTemporaryDirectory() const { - return false; -} - -unique_ptr BufferManager::ConstructManagedBuffer(idx_t size, unique_ptr &&source, - FileBufferType type) { - throw NotImplementedException("This type of BufferManager can not construct managed 
buffers"); -} - -// Protected methods - -void BufferManager::AddToEvictionQueue(shared_ptr &handle) { - throw NotImplementedException("This type of BufferManager does not support 'AddToEvictionQueue"); -} - -void BufferManager::WriteTemporaryBuffer(block_id_t block_id, FileBuffer &buffer) { - throw NotImplementedException("This type of BufferManager does not support 'WriteTemporaryBuffer"); -} - -unique_ptr BufferManager::ReadTemporaryBuffer(block_id_t id, unique_ptr buffer) { - throw NotImplementedException("This type of BufferManager does not support 'ReadTemporaryBuffer"); -} - -void BufferManager::DeleteTemporaryFile(block_id_t id) { - throw NotImplementedException("This type of BufferManager does not support 'DeleteTemporaryFile"); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -CompressionType RowGroupWriter::GetColumnCompressionType(idx_t i) { - return table.GetColumn(LogicalIndex(i)).CompressionType(); -} - -void RowGroupWriter::RegisterPartialBlock(PartialBlockAllocation &&allocation) { - partial_block_manager.RegisterPartialBlock(std::move(allocation)); -} - -PartialBlockAllocation RowGroupWriter::GetBlockAllocation(uint32_t segment_size) { - return partial_block_manager.GetBlockAllocation(segment_size); -} - -void SingleFileRowGroupWriter::WriteColumnDataPointers(ColumnCheckpointState &column_checkpoint_state, - Serializer &serializer) { - const auto &data_pointers = column_checkpoint_state.data_pointers; - serializer.WriteProperty(100, "data_pointers", data_pointers); -} - -MetadataWriter &SingleFileRowGroupWriter::GetPayloadWriter() { - return table_data_writer; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -TableDataReader::TableDataReader(MetadataReader &reader, BoundCreateTableInfo &info) : reader(reader), info(info) { - info.data = make_uniq(info.Base().columns.LogicalColumnCount()); -} - -void TableDataReader::ReadTableData() { - auto &columns = info.Base().columns; - D_ASSERT(!columns.empty()); - - // We 
stored the table statistics as a unit in FinalizeTable. - BinaryDeserializer stats_deserializer(reader); - stats_deserializer.Begin(); - info.data->table_stats.Deserialize(stats_deserializer, columns); - stats_deserializer.End(); - - // Deserialize the row group pointers (lazily, just set the count and the pointer to them for now) - info.data->row_group_count = reader.Read(); - info.data->block_pointer = reader.GetMetaBlockPointer(); -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -TableDataWriter::TableDataWriter(TableCatalogEntry &table_p) : table(table_p.Cast()) { - D_ASSERT(table_p.IsDuckTable()); -} - -TableDataWriter::~TableDataWriter() { -} - -void TableDataWriter::WriteTableData(Serializer &metadata_serializer) { - // start scanning the table and append the data to the uncompressed segments - table.GetStorage().Checkpoint(*this, metadata_serializer); -} - -CompressionType TableDataWriter::GetColumnCompressionType(idx_t i) { - return table.GetColumn(LogicalIndex(i)).CompressionType(); -} - -void TableDataWriter::AddRowGroup(RowGroupPointer &&row_group_pointer, unique_ptr &&writer) { - row_group_pointers.push_back(std::move(row_group_pointer)); - writer.reset(); -} - -SingleFileTableDataWriter::SingleFileTableDataWriter(SingleFileCheckpointWriter &checkpoint_manager, - TableCatalogEntry &table, MetadataWriter &table_data_writer) - : TableDataWriter(table), checkpoint_manager(checkpoint_manager), table_data_writer(table_data_writer) { -} - -unique_ptr SingleFileTableDataWriter::GetRowGroupWriter(RowGroup &row_group) { - return make_uniq(table, checkpoint_manager.partial_block_manager, table_data_writer); -} - -void SingleFileTableDataWriter::FinalizeTable(TableStatistics &&global_stats, DataTableInfo *info, - Serializer &metadata_serializer) { - // store the current position in the metadata writer - // this is where the row groups for this table start - auto pointer = table_data_writer.GetMetaBlockPointer(); - - // Serialize statistics as a 
single unit - BinarySerializer stats_serializer(table_data_writer); - stats_serializer.Begin(); - global_stats.Serialize(stats_serializer); - stats_serializer.End(); - - // now start writing the row group pointers to disk - table_data_writer.Write(row_group_pointers.size()); - idx_t total_rows = 0; - for (auto &row_group_pointer : row_group_pointers) { - auto row_group_count = row_group_pointer.row_start + row_group_pointer.tuple_count; - if (row_group_count > total_rows) { - total_rows = row_group_count; - } - - // Each RowGroup is its own unit - BinarySerializer row_group_serializer(table_data_writer); - row_group_serializer.Begin(); - RowGroup::Serialize(row_group_pointer, row_group_serializer); - row_group_serializer.End(); - } - - auto index_pointers = info->indexes.SerializeIndexes(table_data_writer); - - // Now begin the metadata as a unit - // Pointer to the table itself goes to the metadata stream. - metadata_serializer.WriteProperty(101, "table_pointer", pointer); - metadata_serializer.WriteProperty(102, "total_rows", total_rows); - metadata_serializer.WriteProperty(103, "index_pointers", index_pointers); -} - -} // namespace duckdb - - - - -namespace duckdb { - -WriteOverflowStringsToDisk::WriteOverflowStringsToDisk(BlockManager &block_manager) - : block_manager(block_manager), block_id(INVALID_BLOCK), offset(0) { -} - -WriteOverflowStringsToDisk::~WriteOverflowStringsToDisk() { - // verify that the overflow writer has been flushed - D_ASSERT(Exception::UncaughtException() || offset == 0); -} - -shared_ptr UncompressedStringSegmentState::GetHandle(BlockManager &manager, block_id_t block_id) { - lock_guard lock(block_lock); - auto entry = handles.find(block_id); - if (entry != handles.end()) { - return entry->second; - } - auto result = manager.RegisterBlock(block_id); - handles.insert(make_pair(block_id, result)); - return result; -} - -void UncompressedStringSegmentState::RegisterBlock(BlockManager &manager, block_id_t block_id) { - lock_guard 
lock(block_lock); - auto entry = handles.find(block_id); - if (entry != handles.end()) { - throw InternalException("UncompressedStringSegmentState::RegisterBlock - block id %llu already exists", - block_id); - } - auto result = manager.RegisterBlock(block_id); - handles.insert(make_pair(block_id, std::move(result))); - on_disk_blocks.push_back(block_id); -} - -void WriteOverflowStringsToDisk::WriteString(UncompressedStringSegmentState &state, string_t string, - block_id_t &result_block, int32_t &result_offset) { - auto &buffer_manager = block_manager.buffer_manager; - if (!handle.IsValid()) { - handle = buffer_manager.Allocate(Storage::BLOCK_SIZE); - } - // first write the length of the string - if (block_id == INVALID_BLOCK || offset + 2 * sizeof(uint32_t) >= STRING_SPACE) { - AllocateNewBlock(state, block_manager.GetFreeBlockId()); - } - result_block = block_id; - result_offset = offset; - - // write the length field - auto data_ptr = handle.Ptr(); - auto string_length = string.GetSize(); - Store(string_length, data_ptr + offset); - offset += sizeof(uint32_t); - - // now write the remainder of the string - auto strptr = string.GetData(); - uint32_t remaining = string_length; - while (remaining > 0) { - uint32_t to_write = MinValue(remaining, STRING_SPACE - offset); - if (to_write > 0) { - memcpy(data_ptr + offset, strptr, to_write); - - remaining -= to_write; - offset += to_write; - strptr += to_write; - } - if (remaining > 0) { - D_ASSERT(offset == WriteOverflowStringsToDisk::STRING_SPACE); - // there is still remaining stuff to write - // now write the current block to disk and allocate a new block - AllocateNewBlock(state, block_manager.GetFreeBlockId()); - } - } -} - -void WriteOverflowStringsToDisk::Flush() { - if (block_id != INVALID_BLOCK && offset > 0) { - // zero-initialize the empty part of the overflow string buffer (if any) - if (offset < STRING_SPACE) { - memset(handle.Ptr() + offset, 0, STRING_SPACE - offset); - } - // write to disk - 
block_manager.Write(handle.GetFileBuffer(), block_id); - } - block_id = INVALID_BLOCK; - offset = 0; -} - -void WriteOverflowStringsToDisk::AllocateNewBlock(UncompressedStringSegmentState &state, block_id_t new_block_id) { - if (block_id != INVALID_BLOCK) { - // there is an old block, write it first - // write the new block id at the end of the previous block - Store(new_block_id, handle.Ptr() + WriteOverflowStringsToDisk::STRING_SPACE); - Flush(); - } - offset = 0; - block_id = new_block_id; - state.RegisterBlock(block_manager, new_block_id); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -void ReorderTableEntries(catalog_entry_vector_t &tables); - -SingleFileCheckpointWriter::SingleFileCheckpointWriter(AttachedDatabase &db, BlockManager &block_manager) - : CheckpointWriter(db), partial_block_manager(block_manager, CheckpointType::FULL_CHECKPOINT) { -} - -BlockManager &SingleFileCheckpointWriter::GetBlockManager() { - auto &storage_manager = db.GetStorageManager().Cast(); - return *storage_manager.block_manager; -} - -MetadataWriter &SingleFileCheckpointWriter::GetMetadataWriter() { - return *metadata_writer; -} - -MetadataManager &SingleFileCheckpointWriter::GetMetadataManager() { - return GetBlockManager().GetMetadataManager(); -} - -unique_ptr SingleFileCheckpointWriter::GetTableDataWriter(TableCatalogEntry &table) { - return make_uniq(*this, table, *table_metadata_writer); -} - -static catalog_entry_vector_t GetCatalogEntries(vector> &schemas) { - catalog_entry_vector_t entries; - for (auto &schema_p : schemas) { - auto &schema = schema_p.get(); - entries.push_back(schema); - schema.Scan(CatalogType::TYPE_ENTRY, [&](CatalogEntry &entry) { - if (entry.internal) { - return; - } - entries.push_back(entry); - }); - - schema.Scan(CatalogType::SEQUENCE_ENTRY, [&](CatalogEntry &entry) { - if (entry.internal) { - return; - } - entries.push_back(entry); - }); - - catalog_entry_vector_t tables; - vector> 
views; - schema.Scan(CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) { - if (entry.internal) { - return; - } - if (entry.type == CatalogType::TABLE_ENTRY) { - tables.push_back(entry.Cast()); - } else if (entry.type == CatalogType::VIEW_ENTRY) { - views.push_back(entry.Cast()); - } else { - throw NotImplementedException("Catalog type for entries"); - } - }); - // Reorder tables because of foreign key constraint - ReorderTableEntries(tables); - for (auto &table : tables) { - entries.push_back(table.get()); - } - for (auto &view : views) { - entries.push_back(view.get()); - } - - schema.Scan(CatalogType::SCALAR_FUNCTION_ENTRY, [&](CatalogEntry &entry) { - if (entry.internal) { - return; - } - if (entry.type == CatalogType::MACRO_ENTRY) { - entries.push_back(entry); - } - }); - - schema.Scan(CatalogType::TABLE_FUNCTION_ENTRY, [&](CatalogEntry &entry) { - if (entry.internal) { - return; - } - if (entry.type == CatalogType::TABLE_MACRO_ENTRY) { - entries.push_back(entry); - } - }); - - schema.Scan(CatalogType::INDEX_ENTRY, [&](CatalogEntry &entry) { - D_ASSERT(!entry.internal); - entries.push_back(entry); - }); - } - return entries; -} - -void SingleFileCheckpointWriter::CreateCheckpoint() { - auto &config = DBConfig::Get(db); - auto &storage_manager = db.GetStorageManager().Cast(); - if (storage_manager.InMemory()) { - return; - } - // assert that the checkpoint manager hasn't been used before - D_ASSERT(!metadata_writer); - - auto &block_manager = GetBlockManager(); - auto &metadata_manager = GetMetadataManager(); - - //! 
Set up the writers for the checkpoints - metadata_writer = make_uniq(metadata_manager); - table_metadata_writer = make_uniq(metadata_manager); - - // get the id of the first meta block - auto meta_block = metadata_writer->GetMetaBlockPointer(); - - vector> schemas; - // we scan the set of committed schemas - auto &catalog = Catalog::GetCatalog(db).Cast(); - catalog.ScanSchemas([&](SchemaCatalogEntry &entry) { schemas.push_back(entry); }); - // write the actual data into the database - - // Create a serializer to write the checkpoint data - // The serialized format is roughly: - /* - { - schemas: [ - { - schema: , - custom_types: [ { type: }, ... ], - sequences: [ { sequence: }, ... ], - tables: [ { table: }, ... ], - views: [ { view: }, ... ], - macros: [ { macro: }, ... ], - table_macros: [ { table_macro: }, ... ], - indexes: [ { index: , root_offset }, ... ] - } - ] - } - */ - auto catalog_entries = GetCatalogEntries(schemas); - BinarySerializer serializer(*metadata_writer); - serializer.Begin(); - serializer.WriteList(100, "catalog_entries", catalog_entries.size(), [&](Serializer::List &list, idx_t i) { - auto &entry = catalog_entries[i]; - list.WriteObject([&](Serializer &obj) { WriteEntry(entry.get(), obj); }); - }); - serializer.End(); - - partial_block_manager.FlushPartialBlocks(); - metadata_writer->Flush(); - table_metadata_writer->Flush(); - - // write a checkpoint flag to the WAL - // this protects against the rare event that the database crashes AFTER writing the file, but BEFORE truncating the - // WAL we write an entry CHECKPOINT "meta_block_id" into the WAL upon loading, if we see there is an entry - // CHECKPOINT "meta_block_id", and the id MATCHES the head idin the file we know that the database was successfully - // checkpointed, so we know that we should avoid replaying the WAL to avoid duplicating data - auto wal = storage_manager.GetWriteAheadLog(); - wal->WriteCheckpoint(meta_block); - wal->Flush(); - - if (config.options.checkpoint_abort == 
CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER) { - throw FatalException("Checkpoint aborted before header write because of PRAGMA checkpoint_abort flag"); - } - - // finally write the updated header - DatabaseHeader header; - header.meta_block = meta_block.block_pointer; - block_manager.WriteHeader(header); - - if (config.options.checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE) { - throw FatalException("Checkpoint aborted before truncate because of PRAGMA checkpoint_abort flag"); - } - - // truncate the file - block_manager.Truncate(); - - // truncate the WAL - wal->Truncate(0); -} - -void CheckpointReader::LoadCheckpoint(ClientContext &context, MetadataReader &reader) { - BinaryDeserializer deserializer(reader); - deserializer.Begin(); - deserializer.ReadList(100, "catalog_entries", [&](Deserializer::List &list, idx_t i) { - return list.ReadObject([&](Deserializer &obj) { ReadEntry(context, obj); }); - }); - deserializer.End(); -} - -MetadataManager &SingleFileCheckpointReader::GetMetadataManager() { - return storage.block_manager->GetMetadataManager(); -} - -void SingleFileCheckpointReader::LoadFromStorage() { - auto &block_manager = *storage.block_manager; - auto &metadata_manager = GetMetadataManager(); - MetaBlockPointer meta_block(block_manager.GetMetaBlock(), 0); - if (!meta_block.IsValid()) { - // storage is empty - return; - } - - Connection con(storage.GetDatabase()); - con.BeginTransaction(); - // create the MetadataReader to read from the storage - MetadataReader reader(metadata_manager, meta_block); - // reader.SetContext(*con.context); - LoadCheckpoint(*con.context, reader); - con.Commit(); -} - -void CheckpointWriter::WriteEntry(CatalogEntry &entry, Serializer &serializer) { - serializer.WriteProperty(99, "catalog_type", entry.type); - - switch (entry.type) { - case CatalogType::SCHEMA_ENTRY: { - auto &schema = entry.Cast(); - WriteSchema(schema, serializer); - break; - } - case CatalogType::TYPE_ENTRY: { - auto &custom_type = 
entry.Cast(); - WriteType(custom_type, serializer); - break; - } - case CatalogType::SEQUENCE_ENTRY: { - auto &seq = entry.Cast(); - WriteSequence(seq, serializer); - break; - } - case CatalogType::TABLE_ENTRY: { - auto &table = entry.Cast(); - WriteTable(table, serializer); - break; - } - case CatalogType::VIEW_ENTRY: { - auto &view = entry.Cast(); - WriteView(view, serializer); - break; - } - case CatalogType::MACRO_ENTRY: { - auto ¯o = entry.Cast(); - WriteMacro(macro, serializer); - break; - } - case CatalogType::TABLE_MACRO_ENTRY: { - auto ¯o = entry.Cast(); - WriteTableMacro(macro, serializer); - break; - } - case CatalogType::INDEX_ENTRY: { - auto &index = entry.Cast(); - WriteIndex(index, serializer); - break; - } - default: - throw InternalException("Unrecognized catalog type in CheckpointWriter::WriteEntry"); - } -} - -//===--------------------------------------------------------------------===// -// Schema -//===--------------------------------------------------------------------===// -void CheckpointWriter::WriteSchema(SchemaCatalogEntry &schema, Serializer &serializer) { - // write the schema data - serializer.WriteProperty(100, "schema", &schema); -} - -void CheckpointReader::ReadEntry(ClientContext &context, Deserializer &deserializer) { - auto type = deserializer.ReadProperty(99, "type"); - - switch (type) { - case CatalogType::SCHEMA_ENTRY: { - ReadSchema(context, deserializer); - break; - } - case CatalogType::TYPE_ENTRY: { - ReadType(context, deserializer); - break; - } - case CatalogType::SEQUENCE_ENTRY: { - ReadSequence(context, deserializer); - break; - } - case CatalogType::TABLE_ENTRY: { - ReadTable(context, deserializer); - break; - } - case CatalogType::VIEW_ENTRY: { - ReadView(context, deserializer); - break; - } - case CatalogType::MACRO_ENTRY: { - ReadMacro(context, deserializer); - break; - } - case CatalogType::TABLE_MACRO_ENTRY: { - ReadTableMacro(context, deserializer); - break; - } - case CatalogType::INDEX_ENTRY: { - 
ReadIndex(context, deserializer); - break; - } - default: - throw InternalException("Unrecognized catalog type in CheckpointWriter::WriteEntry"); - } -} - -void CheckpointReader::ReadSchema(ClientContext &context, Deserializer &deserializer) { - // Read the schema and create it in the catalog - auto info = deserializer.ReadProperty>(100, "schema"); - auto &schema_info = info->Cast(); - - // we set create conflict to IGNORE_ON_CONFLICT, so that we can ignore a failure when recreating the main schema - schema_info.on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; - catalog.CreateSchema(context, schema_info); -} - -//===--------------------------------------------------------------------===// -// Views -//===--------------------------------------------------------------------===// -void CheckpointWriter::WriteView(ViewCatalogEntry &view, Serializer &serializer) { - serializer.WriteProperty(100, "view", &view); -} - -void CheckpointReader::ReadView(ClientContext &context, Deserializer &deserializer) { - auto info = deserializer.ReadProperty>(100, "view"); - auto &view_info = info->Cast(); - catalog.CreateView(context, view_info); -} - -//===--------------------------------------------------------------------===// -// Sequences -//===--------------------------------------------------------------------===// -void CheckpointWriter::WriteSequence(SequenceCatalogEntry &seq, Serializer &serializer) { - serializer.WriteProperty(100, "sequence", &seq); -} - -void CheckpointReader::ReadSequence(ClientContext &context, Deserializer &deserializer) { - auto info = deserializer.ReadProperty>(100, "sequence"); - auto &sequence_info = info->Cast(); - catalog.CreateSequence(context, sequence_info); -} - -//===--------------------------------------------------------------------===// -// Indexes -//===--------------------------------------------------------------------===// -void CheckpointWriter::WriteIndex(IndexCatalogEntry &index_catalog, Serializer &serializer) { - // The index 
data is written as part of WriteTableData. - // Here, we need only serialize the pointer to that data. - auto root_block_pointer = index_catalog.index->GetRootBlockPointer(); - serializer.WriteProperty(100, "index", &index_catalog); - serializer.WriteProperty(101, "root_block_pointer", root_block_pointer); -} - -void CheckpointReader::ReadIndex(ClientContext &context, Deserializer &deserializer) { - - // deserialize the index create info - auto create_info = deserializer.ReadProperty>(100, "index"); - auto &info = create_info->Cast(); - - // create the index in the catalog - auto &schema = catalog.GetSchema(context, create_info->schema); - auto &table = - catalog.GetEntry(context, CatalogType::TABLE_ENTRY, create_info->schema, info.table).Cast(); - - auto &index = schema.CreateIndex(context, info, table)->Cast(); - - index.info = table.GetStorage().info; - // insert the parsed expressions into the stored index so that we correctly (de)serialize it during consecutive - // checkpoints - for (auto &parsed_expr : info.parsed_expressions) { - index.parsed_expressions.push_back(parsed_expr->Copy()); - } - - // we deserialize the index lazily, i.e., we do not need to load any node information - // except the root block pointer - auto root_block_pointer = deserializer.ReadProperty(101, "root_block_pointer"); - - // obtain the parsed expressions of the ART from the index metadata - vector> parsed_expressions; - for (auto &parsed_expr : info.parsed_expressions) { - parsed_expressions.push_back(parsed_expr->Copy()); - } - D_ASSERT(!parsed_expressions.empty()); - - // add the table to the bind context to bind the parsed expressions - auto binder = Binder::CreateBinder(context); - vector column_types; - vector column_names; - for (auto &col : table.GetColumns().Logical()) { - column_types.push_back(col.Type()); - column_names.push_back(col.Name()); - } - - // create a binder to bind the parsed expressions - vector column_ids; - binder->bind_context.AddBaseTable(0, info.table, 
column_names, column_types, column_ids, &table); - IndexBinder idx_binder(*binder, context); - - // bind the parsed expressions to create unbound expressions - vector> unbound_expressions; - unbound_expressions.reserve(parsed_expressions.size()); - for (auto &expr : parsed_expressions) { - unbound_expressions.push_back(idx_binder.Bind(expr)); - } - - // create the index and add it to the storage - switch (info.index_type) { - case IndexType::ART: { - auto &storage = table.GetStorage(); - auto art = make_uniq(info.column_ids, TableIOManager::Get(storage), std::move(unbound_expressions), - info.constraint_type, storage.db, nullptr, root_block_pointer); - - index.index = art.get(); - storage.info->indexes.AddIndex(std::move(art)); - } break; - default: - throw InternalException("Unknown index type for ReadIndex"); - } -} - -//===--------------------------------------------------------------------===// -// Custom Types -//===--------------------------------------------------------------------===// -void CheckpointWriter::WriteType(TypeCatalogEntry &type, Serializer &serializer) { - serializer.WriteProperty(100, "type", &type); -} - -void CheckpointReader::ReadType(ClientContext &context, Deserializer &deserializer) { - auto info = deserializer.ReadProperty>(100, "type"); - auto &type_info = info->Cast(); - catalog.CreateType(context, type_info); -} - -//===--------------------------------------------------------------------===// -// Macro's -//===--------------------------------------------------------------------===// -void CheckpointWriter::WriteMacro(ScalarMacroCatalogEntry ¯o, Serializer &serializer) { - serializer.WriteProperty(100, "macro", ¯o); -} - -void CheckpointReader::ReadMacro(ClientContext &context, Deserializer &deserializer) { - auto info = deserializer.ReadProperty>(100, "macro"); - auto ¯o_info = info->Cast(); - catalog.CreateFunction(context, macro_info); -} - -void CheckpointWriter::WriteTableMacro(TableMacroCatalogEntry ¯o, Serializer &serializer) 
{ - serializer.WriteProperty(100, "table_macro", ¯o); -} - -void CheckpointReader::ReadTableMacro(ClientContext &context, Deserializer &deserializer) { - auto info = deserializer.ReadProperty>(100, "table_macro"); - auto ¯o_info = info->Cast(); - catalog.CreateFunction(context, macro_info); -} - -//===--------------------------------------------------------------------===// -// Table Metadata -//===--------------------------------------------------------------------===// -void CheckpointWriter::WriteTable(TableCatalogEntry &table, Serializer &serializer) { - // Write the table meta data - serializer.WriteProperty(100, "table", &table); - - // Write the table data - if (auto writer = GetTableDataWriter(table)) { - writer->WriteTableData(serializer); - } -} - -void CheckpointReader::ReadTable(ClientContext &context, Deserializer &deserializer) { - // deserialize the table meta data - auto info = deserializer.ReadProperty>(100, "table"); - auto binder = Binder::CreateBinder(context); - auto &schema = catalog.GetSchema(context, info->schema); - auto bound_info = binder->BindCreateTableInfo(std::move(info), schema); - - // now read the actual table data and place it into the create table info - ReadTableData(context, deserializer, *bound_info); - - // finally create the table in the catalog - catalog.CreateTable(context, *bound_info); -} - -void CheckpointReader::ReadTableData(ClientContext &context, Deserializer &deserializer, - BoundCreateTableInfo &bound_info) { - - // This is written in "SingleFileTableDataWriter::FinalizeTable" - auto table_pointer = deserializer.ReadProperty(101, "table_pointer"); - auto total_rows = deserializer.ReadProperty(102, "total_rows"); - auto index_pointers = deserializer.ReadProperty>(103, "index_pointers"); - - // FIXME: icky downcast to get the underlying MetadataReader - auto &binary_deserializer = dynamic_cast(deserializer); - auto &reader = dynamic_cast(binary_deserializer.GetStream()); - - MetadataReader 
table_data_reader(reader.GetMetadataManager(), table_pointer); - TableDataReader data_reader(table_data_reader, bound_info); - data_reader.ReadTableData(); - - bound_info.data->total_rows = total_rows; - bound_info.indexes = index_pointers; -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -static constexpr const idx_t BITPACKING_METADATA_GROUP_SIZE = STANDARD_VECTOR_SIZE > 512 ? STANDARD_VECTOR_SIZE : 2048; - -BitpackingMode BitpackingModeFromString(const string &str) { - auto mode = StringUtil::Lower(str); - if (mode == "auto" || mode == "none") { - return BitpackingMode::AUTO; - } else if (mode == "constant") { - return BitpackingMode::CONSTANT; - } else if (mode == "constant_delta") { - return BitpackingMode::CONSTANT_DELTA; - } else if (mode == "delta_for") { - return BitpackingMode::DELTA_FOR; - } else if (mode == "for") { - return BitpackingMode::FOR; - } else { - return BitpackingMode::INVALID; - } -} - -string BitpackingModeToString(const BitpackingMode &mode) { - switch (mode) { - case BitpackingMode::AUTO: - return "auto"; - case BitpackingMode::CONSTANT: - return "constant"; - case BitpackingMode::CONSTANT_DELTA: - return "constant_delta"; - case BitpackingMode::DELTA_FOR: - return "delta_for"; - case BitpackingMode::FOR: - return "for"; - default: - throw NotImplementedException("Unknown bitpacking mode: " + to_string((uint8_t)mode) + "\n"); - } -} - -typedef struct { - BitpackingMode mode; - uint32_t offset; -} bitpacking_metadata_t; - -typedef uint32_t bitpacking_metadata_encoded_t; - -static bitpacking_metadata_encoded_t EncodeMeta(bitpacking_metadata_t metadata) { - D_ASSERT(metadata.offset <= 16777215); // max uint24_t - bitpacking_metadata_encoded_t encoded_value = metadata.offset; - encoded_value |= (uint8_t)metadata.mode << 24; - return encoded_value; -} -static bitpacking_metadata_t DecodeMeta(bitpacking_metadata_encoded_t *metadata_encoded) { - bitpacking_metadata_t metadata; - metadata.mode = 
Load(data_ptr_cast(metadata_encoded) + 3); - metadata.offset = *metadata_encoded & 0x00FFFFFF; - return metadata; -} - -struct EmptyBitpackingWriter { - template - static void WriteConstant(T constant, idx_t count, void *data_ptr, bool all_invalid) { - } - template ::type> - static void WriteConstantDelta(T_S constant, T frame_of_reference, idx_t count, T *values, bool *validity, - void *data_ptr) { - } - template ::type> - static void WriteDeltaFor(T *values, bool *validity, bitpacking_width_t width, T frame_of_reference, - T_S delta_offset, T *original_values, idx_t count, void *data_ptr) { - } - template - static void WriteFor(T *values, bool *validity, bitpacking_width_t width, T frame_of_reference, idx_t count, - void *data_ptr) { - } -}; - -template ::type> -struct BitpackingState { -public: - BitpackingState() : compression_buffer_idx(0), total_size(0), data_ptr(nullptr) { - compression_buffer_internal[0] = T(0); - compression_buffer = &compression_buffer_internal[1]; - Reset(); - } - - // Extra val for delta encoding - T compression_buffer_internal[BITPACKING_METADATA_GROUP_SIZE + 1]; - T *compression_buffer; - T_S delta_buffer[BITPACKING_METADATA_GROUP_SIZE]; - bool compression_buffer_validity[BITPACKING_METADATA_GROUP_SIZE]; - idx_t compression_buffer_idx; - idx_t total_size; - - // Used to pass CompressionState ptr through the Bitpacking writer - void *data_ptr; - - // Stats on current compression buffer - T minimum; - T maximum; - T min_max_diff; - T_S minimum_delta; - T_S maximum_delta; - T_S min_max_delta_diff; - T_S delta_offset; - bool all_valid; - bool all_invalid; - - bool can_do_delta; - bool can_do_for; - - // Used to force a specific mode, useful in testing - BitpackingMode mode = BitpackingMode::AUTO; - -public: - void Reset() { - minimum = NumericLimits::Maximum(); - minimum_delta = NumericLimits::Maximum(); - maximum = NumericLimits::Minimum(); - maximum_delta = NumericLimits::Minimum(); - delta_offset = 0; - all_valid = true; - all_invalid 
= true; - can_do_delta = false; - can_do_for = false; - compression_buffer_idx = 0; - min_max_diff = 0; - min_max_delta_diff = 0; - } - - void CalculateFORStats() { - can_do_for = TrySubtractOperator::Operation(maximum, minimum, min_max_diff); - } - - void CalculateDeltaStats() { - // TODO: currently we dont support delta compression of values above NumericLimits::Maximum(), - // we could support this with some clever substract trickery? - if (maximum > static_cast(NumericLimits::Maximum())) { - return; - } - - // Don't delta encoding 1 value makes no sense - if (compression_buffer_idx < 2) { - return; - } - - // TODO: handle NULLS here? - // Currently we cannot handle nulls because we would need an additional step of patching for this. - // we could for example copy the last value on a null insert. This would help a bit, but not be optimal for - // large deltas since theres suddenly a zero then. Ideally we would insert a value that leads to a delta within - // the current domain of deltas however we dont know that domain here yet - if (!all_valid) { - return; - } - - // Note: since we dont allow any values over NumericLimits::Maximum(), all subtractions for unsigned types - // are guaranteed not to overflow - bool can_do_all = true; - if (NumericLimits::IsSigned()) { - T_S bogus; - can_do_all = TrySubtractOperator::Operation(static_cast(minimum), static_cast(maximum), bogus) && - TrySubtractOperator::Operation(static_cast(maximum), static_cast(minimum), bogus); - } - - // Calculate delta's - // compression_buffer pointer points one element ahead of the internal buffer making the use of signed index - // integer (-1) possible - D_ASSERT(compression_buffer_idx <= NumericLimits::Maximum()); - if (can_do_all) { - for (int64_t i = 0; i < static_cast(compression_buffer_idx); i++) { - delta_buffer[i] = static_cast(compression_buffer[i]) - static_cast(compression_buffer[i - 1]); - } - } else { - for (int64_t i = 0; i < static_cast(compression_buffer_idx); i++) { - auto 
success = - TrySubtractOperator::Operation(static_cast(compression_buffer[i]), - static_cast(compression_buffer[i - 1]), delta_buffer[i]); - if (!success) { - return; - } - } - } - - can_do_delta = true; - - for (idx_t i = 1; i < compression_buffer_idx; i++) { - maximum_delta = MaxValue(maximum_delta, delta_buffer[i]); - minimum_delta = MinValue(minimum_delta, delta_buffer[i]); - } - - // Since we can set the first value arbitrarily, we want to pick one from the current domain, note that - // we will store the original first value - this offset as the delta_offset to be able to decode this again. - delta_buffer[0] = minimum_delta; - - can_do_delta = can_do_delta && TrySubtractOperator::Operation(maximum_delta, minimum_delta, min_max_delta_diff); - can_do_delta = can_do_delta && TrySubtractOperator::Operation(static_cast(compression_buffer[0]), - minimum_delta, delta_offset); - } - - template - void SubtractFrameOfReference(T_INNER *buffer, T_INNER frame_of_reference) { - static_assert(IsIntegral::value, "Integral type required."); - for (idx_t i = 0; i < compression_buffer_idx; i++) { - buffer[i] -= static_cast::type>(frame_of_reference); - } - } - - template - bool Flush() { - if (compression_buffer_idx == 0) { - return true; - } - - if ((all_invalid || maximum == minimum) && (mode == BitpackingMode::AUTO || mode == BitpackingMode::CONSTANT)) { - OP::WriteConstant(maximum, compression_buffer_idx, data_ptr, all_invalid); - total_size += sizeof(T) + sizeof(bitpacking_metadata_encoded_t); - return true; - } - - CalculateFORStats(); - CalculateDeltaStats(); - - if (can_do_delta) { - if (maximum_delta == minimum_delta && mode != BitpackingMode::FOR && mode != BitpackingMode::DELTA_FOR) { - // FOR needs to be T (considering hugeint is bigger than idx_t) - T frame_of_reference = compression_buffer[0]; - - OP::WriteConstantDelta(maximum_delta, static_cast(frame_of_reference), compression_buffer_idx, - compression_buffer, compression_buffer_validity, data_ptr); - 
total_size += sizeof(T) + sizeof(T) + sizeof(bitpacking_metadata_encoded_t); - return true; - } - - // Check if delta has benefit - // bitwidth is calculated differently between signed and unsigned values, but considering we do not have - // an unsigned version of hugeint, we need to explicitly specify (through boolean) that we wish to calculate - // the unsigned minimum bit-width instead of relying on MakeUnsigned and IsSigned - auto delta_required_bitwidth = BitpackingPrimitives::MinimumBitWidth(min_max_delta_diff); - auto regular_required_bitwidth = BitpackingPrimitives::MinimumBitWidth(min_max_diff); - - if (delta_required_bitwidth < regular_required_bitwidth && mode != BitpackingMode::FOR) { - SubtractFrameOfReference(delta_buffer, minimum_delta); - - OP::WriteDeltaFor(reinterpret_cast(delta_buffer), compression_buffer_validity, - delta_required_bitwidth, static_cast(minimum_delta), delta_offset, - compression_buffer, compression_buffer_idx, data_ptr); - - total_size += BitpackingPrimitives::GetRequiredSize(compression_buffer_idx, delta_required_bitwidth); - total_size += sizeof(T); // FOR value - total_size += sizeof(T); // Delta offset value - total_size += AlignValue(sizeof(bitpacking_width_t)); // FOR value - - return true; - } - } - - if (can_do_for) { - auto width = BitpackingPrimitives::MinimumBitWidth(min_max_diff); - SubtractFrameOfReference(compression_buffer, minimum); - OP::WriteFor(compression_buffer, compression_buffer_validity, width, minimum, compression_buffer_idx, - data_ptr); - - total_size += BitpackingPrimitives::GetRequiredSize(compression_buffer_idx, width); - total_size += sizeof(T); // FOR value - total_size += AlignValue(sizeof(bitpacking_width_t)); - - return true; - } - - return false; - } - - template - bool Update(T value, bool is_valid) { - compression_buffer_validity[compression_buffer_idx] = is_valid; - all_valid = all_valid && is_valid; - all_invalid = all_invalid && !is_valid; - - if (is_valid) { - 
compression_buffer[compression_buffer_idx] = value; - minimum = MinValue(minimum, value); - maximum = MaxValue(maximum, value); - } - - compression_buffer_idx++; - - if (compression_buffer_idx == BITPACKING_METADATA_GROUP_SIZE) { - bool success = Flush(); - Reset(); - return success; - } - return true; - } -}; - -//===--------------------------------------------------------------------===// -// Analyze -//===--------------------------------------------------------------------===// -template -struct BitpackingAnalyzeState : public AnalyzeState { - BitpackingState state; -}; - -template -unique_ptr BitpackingInitAnalyze(ColumnData &col_data, PhysicalType type) { - auto &config = DBConfig::GetConfig(col_data.GetDatabase()); - - auto state = make_uniq>(); - state->state.mode = config.options.force_bitpacking_mode; - - return std::move(state); -} - -template -bool BitpackingAnalyze(AnalyzeState &state, Vector &input, idx_t count) { - auto &analyze_state = static_cast &>(state); - UnifiedVectorFormat vdata; - input.ToUnifiedFormat(count, vdata); - - auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - if (!analyze_state.state.template Update(data[idx], vdata.validity.RowIsValid(idx))) { - return false; - } - } - return true; -} - -template -idx_t BitpackingFinalAnalyze(AnalyzeState &state) { - auto &bitpacking_state = static_cast &>(state); - auto flush_result = bitpacking_state.state.template Flush(); - if (!flush_result) { - return DConstants::INVALID_INDEX; - } - return bitpacking_state.state.total_size; -} - -//===--------------------------------------------------------------------===// -// Compress -//===--------------------------------------------------------------------===// -template ::type> -struct BitpackingCompressState : public CompressionState { -public: - explicit BitpackingCompressState(ColumnDataCheckpointer &checkpointer) - : checkpointer(checkpointer), - 
function(checkpointer.GetCompressionFunction(CompressionType::COMPRESSION_BITPACKING)) { - CreateEmptySegment(checkpointer.GetRowGroup().start); - - state.data_ptr = reinterpret_cast(this); - - auto &config = DBConfig::GetConfig(checkpointer.GetDatabase()); - state.mode = config.options.force_bitpacking_mode; - } - - ColumnDataCheckpointer &checkpointer; - CompressionFunction &function; - unique_ptr current_segment; - BufferHandle handle; - - // Ptr to next free spot in segment; - data_ptr_t data_ptr; - // Ptr to next free spot for storing bitwidths and frame-of-references (growing downwards). - data_ptr_t metadata_ptr; - - BitpackingState state; - -public: - struct BitpackingWriter { - static void WriteConstant(T constant, idx_t count, void *data_ptr, bool all_invalid) { - auto state = reinterpret_cast *>(data_ptr); - - ReserveSpace(state, sizeof(T)); - WriteMetaData(state, BitpackingMode::CONSTANT); - WriteData(state->data_ptr, constant); - - UpdateStats(state, count); - } - - static void WriteConstantDelta(T_S constant, T frame_of_reference, idx_t count, T *values, bool *validity, - void *data_ptr) { - auto state = reinterpret_cast *>(data_ptr); - - ReserveSpace(state, 2 * sizeof(T)); - WriteMetaData(state, BitpackingMode::CONSTANT_DELTA); - WriteData(state->data_ptr, frame_of_reference); - WriteData(state->data_ptr, constant); - - UpdateStats(state, count); - } - static void WriteDeltaFor(T *values, bool *validity, bitpacking_width_t width, T frame_of_reference, - T_S delta_offset, T *original_values, idx_t count, void *data_ptr) { - auto state = reinterpret_cast *>(data_ptr); - - auto bp_size = BitpackingPrimitives::GetRequiredSize(count, width); - ReserveSpace(state, bp_size + 3 * sizeof(T)); - - WriteMetaData(state, BitpackingMode::DELTA_FOR); - WriteData(state->data_ptr, frame_of_reference); - WriteData(state->data_ptr, static_cast(width)); - WriteData(state->data_ptr, delta_offset); - - BitpackingPrimitives::PackBuffer(state->data_ptr, values, count, 
width); - state->data_ptr += bp_size; - - UpdateStats(state, count); - } - - static void WriteFor(T *values, bool *validity, bitpacking_width_t width, T frame_of_reference, idx_t count, - void *data_ptr) { - auto state = reinterpret_cast *>(data_ptr); - - auto bp_size = BitpackingPrimitives::GetRequiredSize(count, width); - ReserveSpace(state, bp_size + 2 * sizeof(T)); - - WriteMetaData(state, BitpackingMode::FOR); - WriteData(state->data_ptr, frame_of_reference); - WriteData(state->data_ptr, (T)width); - - BitpackingPrimitives::PackBuffer(state->data_ptr, values, count, width); - state->data_ptr += bp_size; - - UpdateStats(state, count); - } - - template - static void WriteData(data_ptr_t &ptr, T_OUT val) { - *reinterpret_cast(ptr) = val; - ptr += sizeof(T_OUT); - } - - static void WriteMetaData(BitpackingCompressState *state, BitpackingMode mode) { - bitpacking_metadata_t metadata {mode, (uint32_t)(state->data_ptr - state->handle.Ptr())}; - state->metadata_ptr -= sizeof(bitpacking_metadata_encoded_t); - Store(EncodeMeta(metadata), state->metadata_ptr); - } - - static void ReserveSpace(BitpackingCompressState *state, idx_t data_bytes) { - idx_t meta_bytes = sizeof(bitpacking_metadata_encoded_t); - state->FlushAndCreateSegmentIfFull(data_bytes, meta_bytes); - D_ASSERT(state->CanStore(data_bytes, meta_bytes)); - } - - static void UpdateStats(BitpackingCompressState *state, idx_t count) { - state->current_segment->count += count; - - if (WRITE_STATISTICS && !state->state.all_invalid) { - NumericStats::Update(state->current_segment->stats.statistics, state->state.minimum); - NumericStats::Update(state->current_segment->stats.statistics, state->state.maximum); - } - } - }; - - bool CanStore(idx_t data_bytes, idx_t meta_bytes) { - auto required_data_bytes = AlignValue((data_ptr + data_bytes) - data_ptr); - auto required_meta_bytes = Storage::BLOCK_SIZE - (metadata_ptr - data_ptr) + meta_bytes; - - return required_data_bytes + required_meta_bytes <= - Storage::BLOCK_SIZE 
- BitpackingPrimitives::BITPACKING_HEADER_SIZE; - } - - void CreateEmptySegment(idx_t row_start) { - auto &db = checkpointer.GetDatabase(); - auto &type = checkpointer.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, type, row_start); - compressed_segment->function = function; - current_segment = std::move(compressed_segment); - auto &buffer_manager = BufferManager::GetBufferManager(db); - handle = buffer_manager.Pin(current_segment->block); - - data_ptr = handle.Ptr() + BitpackingPrimitives::BITPACKING_HEADER_SIZE; - metadata_ptr = handle.Ptr() + Storage::BLOCK_SIZE; - } - - void Append(UnifiedVectorFormat &vdata, idx_t count) { - auto data = UnifiedVectorFormat::GetData(vdata); - - for (idx_t i = 0; i < count; i++) { - idx_t idx = vdata.sel->get_index(i); - state.template Update::BitpackingWriter>( - data[idx], vdata.validity.RowIsValid(idx)); - } - } - - void FlushAndCreateSegmentIfFull(idx_t required_data_bytes, idx_t required_meta_bytes) { - if (!CanStore(required_data_bytes, required_meta_bytes)) { - idx_t row_start = current_segment->start + current_segment->count; - FlushSegment(); - CreateEmptySegment(row_start); - } - } - - void FlushSegment() { - auto &state = checkpointer.GetCheckpointState(); - auto base_ptr = handle.Ptr(); - - // Compact the segment by moving the metadata next to the data. - idx_t metadata_offset = AlignValue(data_ptr - base_ptr); - idx_t metadata_size = base_ptr + Storage::BLOCK_SIZE - metadata_ptr; - idx_t total_segment_size = metadata_offset + metadata_size; - - // Asserting things are still sane here - if (!CanStore(0, 0)) { - throw InternalException("Error in bitpacking size calculation"); - } - - memmove(base_ptr + metadata_offset, metadata_ptr, metadata_size); - - // Store the offset of the metadata of the first group (which is at the highest address). 
- Store(metadata_offset + metadata_size, base_ptr); - handle.Destroy(); - - state.FlushSegment(std::move(current_segment), total_segment_size); - } - - void Finalize() { - state.template Flush::BitpackingWriter>(); - FlushSegment(); - current_segment.reset(); - } -}; - -template -unique_ptr BitpackingInitCompression(ColumnDataCheckpointer &checkpointer, - unique_ptr state) { - return make_uniq>(checkpointer); -} - -template -void BitpackingCompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { - auto &state = static_cast &>(state_p); - UnifiedVectorFormat vdata; - scan_vector.ToUnifiedFormat(count, vdata); - state.Append(vdata, count); -} - -template -void BitpackingFinalizeCompress(CompressionState &state_p) { - auto &state = static_cast &>(state_p); - state.Finalize(); -} - -//===--------------------------------------------------------------------===// -// Scan -//===--------------------------------------------------------------------===// -template -static void ApplyFrameOfReference(T *dst, T frame_of_reference, idx_t size) { - if (!frame_of_reference) { - return; - } - for (idx_t i = 0; i < size; i++) { - dst[i] += frame_of_reference; - } -} - -// Based on https://github.com/lemire/FastPFor (Apache License 2.0) -template -static T DeltaDecode(T *data, T previous_value, const size_t size) { - D_ASSERT(size >= 1); - - data[0] += previous_value; - - const size_t UnrollQty = 4; - const size_t sz0 = (size / UnrollQty) * UnrollQty; // equal to 0, if size < UnrollQty - size_t i = 1; - if (sz0 >= UnrollQty) { - T a = data[0]; - for (; i < sz0 - UnrollQty; i += UnrollQty) { - a = data[i] += a; - a = data[i + 1] += a; - a = data[i + 2] += a; - a = data[i + 3] += a; - } - } - for (; i != size; ++i) { - data[i] += data[i - 1]; - } - - return data[size - 1]; -} - -template ::type> -struct BitpackingScanState : public SegmentScanState { -public: - explicit BitpackingScanState(ColumnSegment &segment) : current_segment(segment) { - auto &buffer_manager = 
BufferManager::GetBufferManager(segment.db); - handle = buffer_manager.Pin(segment.block); - auto dataptr = handle.Ptr(); - - // load offset to bitpacking widths pointer - auto bitpacking_metadata_offset = Load(dataptr + segment.GetBlockOffset()); - bitpacking_metadata_ptr = - dataptr + segment.GetBlockOffset() + bitpacking_metadata_offset - sizeof(bitpacking_metadata_encoded_t); - - // load the first group - LoadNextGroup(); - } - - BufferHandle handle; - ColumnSegment ¤t_segment; - - T decompression_buffer[BITPACKING_METADATA_GROUP_SIZE]; - - bitpacking_metadata_t current_group; - - bitpacking_width_t current_width; - T current_frame_of_reference; - T current_constant; - T current_delta_offset; - - idx_t current_group_offset = 0; - data_ptr_t current_group_ptr; - data_ptr_t bitpacking_metadata_ptr; - -public: - //! Loads the metadata for the current metadata group. This will set bitpacking_metadata_ptr to the next group. - //! this will also load any metadata that is at the start of a compressed buffer (e.g. the width, for, or constant - //! 
value) depending on the bitpacking mode for that group - void LoadNextGroup() { - D_ASSERT(bitpacking_metadata_ptr > handle.Ptr() && - bitpacking_metadata_ptr < handle.Ptr() + Storage::BLOCK_SIZE); - current_group_offset = 0; - current_group = DecodeMeta(reinterpret_cast(bitpacking_metadata_ptr)); - - bitpacking_metadata_ptr -= sizeof(bitpacking_metadata_encoded_t); - current_group_ptr = GetPtr(current_group); - - // Read first value - switch (current_group.mode) { - case BitpackingMode::CONSTANT: - current_constant = *reinterpret_cast(current_group_ptr); - current_group_ptr += sizeof(T); - break; - case BitpackingMode::FOR: - case BitpackingMode::CONSTANT_DELTA: - case BitpackingMode::DELTA_FOR: - current_frame_of_reference = *reinterpret_cast(current_group_ptr); - current_group_ptr += sizeof(T); - break; - default: - throw InternalException("Invalid bitpacking mode"); - } - - // Read second value - switch (current_group.mode) { - case BitpackingMode::CONSTANT_DELTA: - current_constant = *reinterpret_cast(current_group_ptr); - current_group_ptr += sizeof(T); - break; - case BitpackingMode::FOR: - case BitpackingMode::DELTA_FOR: - current_width = (bitpacking_width_t)(*reinterpret_cast(current_group_ptr)); - current_group_ptr += MaxValue(sizeof(T), sizeof(bitpacking_width_t)); - break; - case BitpackingMode::CONSTANT: - break; - default: - throw InternalException("Invalid bitpacking mode"); - } - - // Read third value - if (current_group.mode == BitpackingMode::DELTA_FOR) { - current_delta_offset = *reinterpret_cast(current_group_ptr); - current_group_ptr += sizeof(T); - } - } - - void Skip(ColumnSegment &segment, idx_t skip_count) { - bool skip_sign_extend = true; - - idx_t skipped = 0; - while (skipped < skip_count) { - // Exhausted this metadata group, move pointers to next group and load metadata for next group. 
- if (current_group_offset >= BITPACKING_METADATA_GROUP_SIZE) { - LoadNextGroup(); - } - - idx_t offset_in_compression_group = - current_group_offset % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; - - if (current_group.mode == BitpackingMode::CONSTANT) { - idx_t remaining = skip_count - skipped; - idx_t to_skip = MinValue(remaining, BITPACKING_METADATA_GROUP_SIZE - current_group_offset); - skipped += to_skip; - current_group_offset += to_skip; - continue; - } - if (current_group.mode == BitpackingMode::CONSTANT_DELTA) { - idx_t remaining = skip_count - skipped; - idx_t to_skip = MinValue(remaining, BITPACKING_METADATA_GROUP_SIZE - current_group_offset); - skipped += to_skip; - current_group_offset += to_skip; - continue; - } - D_ASSERT(current_group.mode == BitpackingMode::FOR || current_group.mode == BitpackingMode::DELTA_FOR); - - idx_t to_skip = - MinValue(skip_count - skipped, - BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE - offset_in_compression_group); - // Calculate start of compression algorithm group - if (current_group.mode == BitpackingMode::DELTA_FOR) { - data_ptr_t current_position_ptr = current_group_ptr + current_group_offset * current_width / 8; - data_ptr_t decompression_group_start_pointer = - current_position_ptr - offset_in_compression_group * current_width / 8; - - BitpackingPrimitives::UnPackBlock(data_ptr_cast(decompression_buffer), - decompression_group_start_pointer, current_width, - skip_sign_extend); - - T *decompression_ptr = decompression_buffer + offset_in_compression_group; - ApplyFrameOfReference(reinterpret_cast(decompression_ptr), - static_cast(current_frame_of_reference), to_skip); - DeltaDecode(reinterpret_cast(decompression_ptr), static_cast(current_delta_offset), - to_skip); - current_delta_offset = decompression_ptr[to_skip - 1]; - } - - skipped += to_skip; - current_group_offset += to_skip; - } - } - - data_ptr_t GetPtr(bitpacking_metadata_t group) { - return handle.Ptr() + current_segment.GetBlockOffset() 
+ group.offset; - } -}; - -template -unique_ptr BitpackingInitScan(ColumnSegment &segment) { - auto result = make_uniq>(segment); - return std::move(result); -} - -//===--------------------------------------------------------------------===// -// Scan base data -//===--------------------------------------------------------------------===// -template ::type> -void BitpackingScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, - idx_t result_offset) { - auto &scan_state = static_cast &>(*state.scan_state); - - T *result_data = FlatVector::GetData(result); - result.SetVectorType(VectorType::FLAT_VECTOR); - - //! Because FOR offsets all our values to be 0 or above, we can always skip sign extension here - bool skip_sign_extend = true; - - idx_t scanned = 0; - while (scanned < scan_count) { - // Exhausted this metadata group, move pointers to next group and load metadata for next group. - if (scan_state.current_group_offset >= BITPACKING_METADATA_GROUP_SIZE) { - scan_state.LoadNextGroup(); - } - - idx_t offset_in_compression_group = - scan_state.current_group_offset % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; - - if (scan_state.current_group.mode == BitpackingMode::CONSTANT) { - idx_t remaining = scan_count - scanned; - idx_t to_scan = MinValue(remaining, BITPACKING_METADATA_GROUP_SIZE - scan_state.current_group_offset); - T *begin = result_data + result_offset + scanned; - T *end = begin + remaining; - std::fill(begin, end, scan_state.current_constant); - scanned += to_scan; - scan_state.current_group_offset += to_scan; - continue; - } - if (scan_state.current_group.mode == BitpackingMode::CONSTANT_DELTA) { - idx_t remaining = scan_count - scanned; - idx_t to_scan = MinValue(remaining, BITPACKING_METADATA_GROUP_SIZE - scan_state.current_group_offset); - T *target_ptr = result_data + result_offset + scanned; - - for (idx_t i = 0; i < to_scan; i++) { - target_ptr[i] = (static_cast(scan_state.current_group_offset + i) 
* scan_state.current_constant) + - scan_state.current_frame_of_reference; - } - - scanned += to_scan; - scan_state.current_group_offset += to_scan; - continue; - } - D_ASSERT(scan_state.current_group.mode == BitpackingMode::FOR || - scan_state.current_group.mode == BitpackingMode::DELTA_FOR); - - idx_t to_scan = MinValue(scan_count - scanned, BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE - - offset_in_compression_group); - // Calculate start of compression algorithm group - data_ptr_t current_position_ptr = - scan_state.current_group_ptr + scan_state.current_group_offset * scan_state.current_width / 8; - data_ptr_t decompression_group_start_pointer = - current_position_ptr - offset_in_compression_group * scan_state.current_width / 8; - - T *current_result_ptr = result_data + result_offset + scanned; - - if (to_scan == BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE && offset_in_compression_group == 0) { - // Decompress directly into result vector - BitpackingPrimitives::UnPackBlock(data_ptr_cast(current_result_ptr), decompression_group_start_pointer, - scan_state.current_width, skip_sign_extend); - } else { - // Decompress compression algorithm to buffer - BitpackingPrimitives::UnPackBlock(data_ptr_cast(scan_state.decompression_buffer), - decompression_group_start_pointer, scan_state.current_width, - skip_sign_extend); - - memcpy(current_result_ptr, scan_state.decompression_buffer + offset_in_compression_group, - to_scan * sizeof(T)); - } - - if (scan_state.current_group.mode == BitpackingMode::DELTA_FOR) { - ApplyFrameOfReference(reinterpret_cast(current_result_ptr), - static_cast(scan_state.current_frame_of_reference), to_scan); - DeltaDecode(reinterpret_cast(current_result_ptr), - static_cast(scan_state.current_delta_offset), to_scan); - scan_state.current_delta_offset = current_result_ptr[to_scan - 1]; - } else { - ApplyFrameOfReference(current_result_ptr, scan_state.current_frame_of_reference, to_scan); - } - - scanned += to_scan; - 
scan_state.current_group_offset += to_scan; - } -} - -template -void BitpackingScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { - BitpackingScanPartial(segment, state, scan_count, result, 0); -} - -//===--------------------------------------------------------------------===// -// Fetch -//===--------------------------------------------------------------------===// -template -void BitpackingFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, - idx_t result_idx) { - BitpackingScanState scan_state(segment); - scan_state.Skip(segment, row_id); - T *result_data = FlatVector::GetData(result); - T *current_result_ptr = result_data + result_idx; - - idx_t offset_in_compression_group = - scan_state.current_group_offset % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; - - data_ptr_t decompression_group_start_pointer = - scan_state.current_group_ptr + - (scan_state.current_group_offset - offset_in_compression_group) * scan_state.current_width / 8; - - //! 
Because FOR offsets all our values to be 0 or above, we can always skip sign extension here - bool skip_sign_extend = true; - - if (scan_state.current_group.mode == BitpackingMode::CONSTANT) { - *current_result_ptr = scan_state.current_constant; - return; - } - - if (scan_state.current_group.mode == BitpackingMode::CONSTANT_DELTA) { -#ifdef DEBUG - // overflow check - T result; - bool multiply = TryMultiplyOperator::Operation(static_cast(scan_state.current_group_offset), - scan_state.current_constant, result); - bool add = TryAddOperator::Operation(result, scan_state.current_frame_of_reference, result); - D_ASSERT(multiply && add); -#endif - *current_result_ptr = (static_cast(scan_state.current_group_offset) * scan_state.current_constant) + - scan_state.current_frame_of_reference; - return; - } - - D_ASSERT(scan_state.current_group.mode == BitpackingMode::FOR || - scan_state.current_group.mode == BitpackingMode::DELTA_FOR); - - BitpackingPrimitives::UnPackBlock(data_ptr_cast(scan_state.decompression_buffer), - decompression_group_start_pointer, scan_state.current_width, skip_sign_extend); - - *current_result_ptr = scan_state.decompression_buffer[offset_in_compression_group]; - *current_result_ptr += scan_state.current_frame_of_reference; - - if (scan_state.current_group.mode == BitpackingMode::DELTA_FOR) { - *current_result_ptr += scan_state.current_delta_offset; - } -} -template -void BitpackingSkip(ColumnSegment &segment, ColumnScanState &state, idx_t skip_count) { - auto &scan_state = static_cast &>(*state.scan_state); - scan_state.Skip(segment, skip_count); -} - -//===--------------------------------------------------------------------===// -// Get Function -//===--------------------------------------------------------------------===// -template -CompressionFunction GetBitpackingFunction(PhysicalType data_type) { - return CompressionFunction(CompressionType::COMPRESSION_BITPACKING, data_type, BitpackingInitAnalyze, - BitpackingAnalyze, BitpackingFinalAnalyze, - 
BitpackingInitCompression, BitpackingCompress, - BitpackingFinalizeCompress, BitpackingInitScan, - BitpackingScan, BitpackingScanPartial, BitpackingFetchRow, BitpackingSkip); -} - -CompressionFunction BitpackingFun::GetFunction(PhysicalType type) { - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return GetBitpackingFunction(type); - case PhysicalType::INT16: - return GetBitpackingFunction(type); - case PhysicalType::INT32: - return GetBitpackingFunction(type); - case PhysicalType::INT64: - return GetBitpackingFunction(type); - case PhysicalType::UINT8: - return GetBitpackingFunction(type); - case PhysicalType::UINT16: - return GetBitpackingFunction(type); - case PhysicalType::UINT32: - return GetBitpackingFunction(type); - case PhysicalType::UINT64: - return GetBitpackingFunction(type); - case PhysicalType::INT128: - return GetBitpackingFunction(type); - case PhysicalType::LIST: - return GetBitpackingFunction(type); - default: - throw InternalException("Unsupported type for Bitpacking"); - } -} - -bool BitpackingFun::TypeIsSupported(PhysicalType type) { - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - case PhysicalType::UINT8: - case PhysicalType::UINT16: - case PhysicalType::UINT32: - case PhysicalType::UINT64: - case PhysicalType::LIST: - case PhysicalType::INT128: - return true; - default: - return false; - } -} - -} // namespace duckdb - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Unpacking -//===--------------------------------------------------------------------===// - -static void UnpackSingle(const uint32_t *__restrict &in, hugeint_t *__restrict out, uint16_t delta, uint16_t shr) { - if (delta + shr < 32) { - *out = ((static_cast(in[0])) >> shr) % (hugeint_t(1) << delta); - } - - else if (delta + shr >= 32 && delta + shr < 64) { - *out = static_cast(in[0]) >> 
shr; - ++in; - - if (delta + shr > 32) { - const uint16_t NEXT_SHR = shr + delta - 32; - *out |= static_cast((*in) % (1U << NEXT_SHR)) << (32 - shr); - } - } - - else if (delta + shr >= 64 && delta + shr < 96) { - *out = static_cast(in[0]) >> shr; - *out |= static_cast(in[1]) << (32 - shr); - in += 2; - - if (delta + shr > 64) { - const uint16_t NEXT_SHR = delta + shr - 64; - *out |= static_cast((*in) % (1U << NEXT_SHR)) << (64 - shr); - } - } - - else if (delta + shr >= 96 && delta + shr < 128) { - *out = static_cast(in[0]) >> shr; - *out |= static_cast(in[1]) << (32 - shr); - *out |= static_cast(in[2]) << (64 - shr); - in += 3; - - if (delta + shr > 96) { - const uint16_t NEXT_SHR = delta + shr - 96; - *out |= static_cast((*in) % (1U << NEXT_SHR)) << (96 - shr); - } - } - - else if (delta + shr >= 128) { - *out = static_cast(in[0]) >> shr; - *out |= static_cast(in[1]) << (32 - shr); - *out |= static_cast(in[2]) << (64 - shr); - *out |= static_cast(in[3]) << (96 - shr); - in += 4; - - if (delta + shr > 128) { - const uint16_t NEXT_SHR = delta + shr - 128; - *out |= static_cast((*in) % (1U << NEXT_SHR)) << (128 - shr); - } - } -} - -static void UnpackLast(const uint32_t *__restrict &in, hugeint_t *__restrict out, uint16_t delta) { - const uint8_t LAST_IDX = 31; - const uint16_t SHIFT = (delta * 31) % 32; - out[LAST_IDX] = in[0] >> SHIFT; - if (delta > 32) { - out[LAST_IDX] |= static_cast(in[1]) << (32 - SHIFT); - } - if (delta > 64) { - out[LAST_IDX] |= static_cast(in[2]) << (64 - SHIFT); - } - if (delta > 96) { - out[LAST_IDX] |= static_cast(in[3]) << (96 - SHIFT); - } -} - -// Unpacks for specific deltas -static void UnpackDelta0(const uint32_t *__restrict in, hugeint_t *__restrict out) { - for (uint8_t i = 0; i < 32; ++i) { - out[i] = 0; - } -} - -static void UnpackDelta32(const uint32_t *__restrict in, hugeint_t *__restrict out) { - for (uint8_t k = 0; k < 32; ++k) { - out[k] = static_cast(in[k]); - } -} - -static void UnpackDelta64(const uint32_t *__restrict 
in, hugeint_t *__restrict out) { - for (uint8_t i = 0; i < 32; ++i) { - const uint8_t OFFSET = i * 2; - out[i] = in[OFFSET]; - out[i] |= static_cast(in[OFFSET + 1]) << 32; - } -} - -static void UnpackDelta96(const uint32_t *__restrict in, hugeint_t *__restrict out) { - for (uint8_t i = 0; i < 32; ++i) { - const uint8_t OFFSET = i * 3; - out[i] = in[OFFSET]; - out[i] |= static_cast(in[OFFSET + 1]) << 32; - out[i] |= static_cast(in[OFFSET + 2]) << 64; - } -} - -static void UnpackDelta128(const uint32_t *__restrict in, hugeint_t *__restrict out) { - for (uint8_t i = 0; i < 32; ++i) { - const uint8_t OFFSET = i * 4; - out[i] = in[OFFSET]; - out[i] |= static_cast(in[OFFSET + 1]) << 32; - out[i] |= static_cast(in[OFFSET + 2]) << 64; - out[i] |= static_cast(in[OFFSET + 3]) << 96; - } -} - -//===--------------------------------------------------------------------===// -// Packing -//===--------------------------------------------------------------------===// - -static void PackSingle(const hugeint_t in, uint32_t *__restrict &out, uint16_t delta, uint16_t shl, hugeint_t mask) { - if (delta + shl < 32) { - - if (shl == 0) { - out[0] = static_cast(in & mask); - } else { - out[0] |= static_cast((in & mask) << shl); - } - - } else if (delta + shl >= 32 && delta + shl < 64) { - - if (shl == 0) { - out[0] = static_cast(in & mask); - } else { - out[0] |= static_cast((in & mask) << shl); - } - ++out; - - if (delta + shl > 32) { - *out = static_cast((in & mask) >> (32 - shl)); - } - } - - else if (delta + shl >= 64 && delta + shl < 96) { - - if (shl == 0) { - out[0] = static_cast(in & mask); - } else { - out[0] |= static_cast(in << shl); - } - - out[1] = static_cast((in & mask) >> (32 - shl)); - out += 2; - - if (delta + shl > 64) { - *out = static_cast((in & mask) >> (64 - shl)); - } - } - - else if (delta + shl >= 96 && delta + shl < 128) { - if (shl == 0) { - out[0] = static_cast(in & mask); - } else { - out[0] |= static_cast(in << shl); - } - - out[1] = static_cast((in & mask) 
>> (32 - shl)); - out[2] = static_cast((in & mask) >> (64 - shl)); - out += 3; - - if (delta + shl > 96) { - *out = static_cast((in & mask) >> (96 - shl)); - } - } - - else if (delta + shl >= 128) { - // shl == 0 won't ever happen here considering a delta of 128 calls PackDelta128 - out[0] |= static_cast(in << shl); - out[1] = static_cast((in & mask) >> (32 - shl)); - out[2] = static_cast((in & mask) >> (64 - shl)); - out[3] = static_cast((in & mask) >> (96 - shl)); - out += 4; - - if (delta + shl > 128) { - *out = static_cast((in & mask) >> (128 - shl)); - } - } -} - -static void PackLast(const hugeint_t *__restrict in, uint32_t *__restrict out, uint16_t delta) { - const uint8_t LAST_IDX = 31; - const uint16_t SHIFT = (delta * 31) % 32; - out[0] |= static_cast(in[LAST_IDX] << SHIFT); - if (delta > 32) { - out[1] = static_cast(in[LAST_IDX] >> (32 - SHIFT)); - } - if (delta > 64) { - out[2] = static_cast(in[LAST_IDX] >> (64 - SHIFT)); - } - if (delta > 96) { - out[3] = static_cast(in[LAST_IDX] >> (96 - SHIFT)); - } -} - -// Packs for specific deltas -static void PackDelta32(const hugeint_t *__restrict in, uint32_t *__restrict out) { - for (uint8_t i = 0; i < 32; ++i) { - out[i] = static_cast(in[i]); - } -} - -static void PackDelta64(const hugeint_t *__restrict in, uint32_t *__restrict out) { - for (uint8_t i = 0; i < 32; ++i) { - const uint8_t OFFSET = 2 * i; - out[OFFSET] = static_cast(in[i]); - out[OFFSET + 1] = static_cast(in[i] >> 32); - } -} - -static void PackDelta96(const hugeint_t *__restrict in, uint32_t *__restrict out) { - for (uint8_t i = 0; i < 32; ++i) { - const uint8_t OFFSET = 3 * i; - out[OFFSET] = static_cast(in[i]); - out[OFFSET + 1] = static_cast(in[i] >> 32); - out[OFFSET + 2] = static_cast(in[i] >> 64); - } -} - -static void PackDelta128(const hugeint_t *__restrict in, uint32_t *__restrict out) { - for (uint8_t i = 0; i < 32; ++i) { - const uint8_t OFFSET = 4 * i; - out[OFFSET] = static_cast(in[i]); - out[OFFSET + 1] = static_cast(in[i] >> 32); 
- out[OFFSET + 2] = static_cast(in[i] >> 64); - out[OFFSET + 3] = static_cast(in[i] >> 96); - } -} - -//===--------------------------------------------------------------------===// -// HugeIntPacker -//===--------------------------------------------------------------------===// - -void HugeIntPacker::Pack(const hugeint_t *__restrict in, uint32_t *__restrict out, bitpacking_width_t width) { - D_ASSERT(width <= 128); - switch (width) { - case 0: - break; - case 32: - PackDelta32(in, out); - break; - case 64: - PackDelta64(in, out); - break; - case 96: - PackDelta96(in, out); - break; - case 128: - PackDelta128(in, out); - break; - default: - for (idx_t oindex = 0; oindex < BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE - 1; ++oindex) { - PackSingle(in[oindex], out, width, (width * oindex) % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE, - (hugeint_t(1) << width) - 1); - } - PackLast(in, out, width); - } -} - -void HugeIntPacker::Unpack(const uint32_t *__restrict in, hugeint_t *__restrict out, bitpacking_width_t width) { - D_ASSERT(width <= 128); - switch (width) { - case 0: - UnpackDelta0(in, out); - break; - case 32: - UnpackDelta32(in, out); - break; - case 64: - UnpackDelta64(in, out); - break; - case 96: - UnpackDelta96(in, out); - break; - case 128: - UnpackDelta128(in, out); - break; - default: - for (idx_t oindex = 0; oindex < BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE - 1; ++oindex) { - UnpackSingle(in, out + oindex, width, - (width * oindex) % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE); - } - UnpackLast(in, out, width); - } -} - -} // namespace duckdb - - -namespace duckdb { - -constexpr uint8_t BitReader::REMAINDER_MASKS[]; -constexpr uint8_t BitReader::MASKS[]; - -} // namespace duckdb - - - - - - - - - -namespace duckdb { - -template -CompressionFunction GetChimpFunction(PhysicalType data_type) { - return CompressionFunction(CompressionType::COMPRESSION_CHIMP, data_type, ChimpInitAnalyze, ChimpAnalyze, - 
ChimpFinalAnalyze, ChimpInitCompression, ChimpCompress, - ChimpFinalizeCompress, ChimpInitScan, ChimpScan, ChimpScanPartial, - ChimpFetchRow, ChimpSkip); -} - -CompressionFunction ChimpCompressionFun::GetFunction(PhysicalType type) { - switch (type) { - case PhysicalType::FLOAT: - return GetChimpFunction(type); - case PhysicalType::DOUBLE: - return GetChimpFunction(type); - default: - throw InternalException("Unsupported type for Chimp"); - } -} - -bool ChimpCompressionFun::TypeIsSupported(PhysicalType type) { - switch (type) { - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - return true; - default: - return false; - } -} - -} // namespace duckdb - - -namespace duckdb { - -constexpr uint8_t ChimpConstants::Compression::LEADING_ROUND[]; -constexpr uint8_t ChimpConstants::Compression::LEADING_REPRESENTATION[]; - -constexpr uint8_t ChimpConstants::Decompression::LEADING_REPRESENTATION[]; - -} // namespace duckdb - - -namespace duckdb { - -constexpr uint8_t FlagBufferConstants::MASKS[]; -constexpr uint8_t FlagBufferConstants::SHIFTS[]; - -} // namespace duckdb - - -namespace duckdb { - -constexpr uint32_t LeadingZeroBufferConstants::MASKS[]; -constexpr uint8_t LeadingZeroBufferConstants::SHIFTS[]; - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -// Abstract class for keeping compression state either for compression or size analysis -class DictionaryCompressionState : public CompressionState { -public: - bool UpdateState(Vector &scan_vector, idx_t count) { - UnifiedVectorFormat vdata; - scan_vector.ToUnifiedFormat(count, vdata); - auto data = UnifiedVectorFormat::GetData(vdata); - Verify(); - - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - size_t string_size = 0; - bool new_string = false; - auto row_is_valid = vdata.validity.RowIsValid(idx); - - if (row_is_valid) { - string_size = data[idx].GetSize(); - if (string_size >= StringUncompressed::STRING_BLOCK_LIMIT) { - // Big strings not implemented for 
dictionary compression - return false; - } - new_string = !LookupString(data[idx]); - } - - bool fits = CalculateSpaceRequirements(new_string, string_size); - if (!fits) { - Flush(); - new_string = true; - - fits = CalculateSpaceRequirements(new_string, string_size); - if (!fits) { - throw InternalException("Dictionary compression could not write to new segment"); - } - } - - if (!row_is_valid) { - AddNull(); - } else if (new_string) { - AddNewString(data[idx]); - } else { - AddLastLookup(); - } - - Verify(); - } - - return true; - } - -protected: - // Should verify the State - virtual void Verify() = 0; - // Performs a lookup of str, storing the result internally - virtual bool LookupString(string_t str) = 0; - // Add the most recently looked up str to compression state - virtual void AddLastLookup() = 0; - // Add string to the state that is known to not be seen yet - virtual void AddNewString(string_t str) = 0; - // Add a null value to the compression state - virtual void AddNull() = 0; - // Needs to be called before adding a value. Will return false if a flush is required first. 
- virtual bool CalculateSpaceRequirements(bool new_string, size_t string_size) = 0; - // Flush the segment to disk if compressing or reset the counters if analyzing - virtual void Flush(bool final = false) = 0; -}; - -typedef struct { - uint32_t dict_size; - uint32_t dict_end; - uint32_t index_buffer_offset; - uint32_t index_buffer_count; - uint32_t bitpacking_width; -} dictionary_compression_header_t; - -struct DictionaryCompressionStorage { - static constexpr float MINIMUM_COMPRESSION_RATIO = 1.2; - static constexpr uint16_t DICTIONARY_HEADER_SIZE = sizeof(dictionary_compression_header_t); - static constexpr size_t COMPACTION_FLUSH_LIMIT = (size_t)Storage::BLOCK_SIZE / 5 * 4; - - static unique_ptr StringInitAnalyze(ColumnData &col_data, PhysicalType type); - static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); - static idx_t StringFinalAnalyze(AnalyzeState &state_p); - - static unique_ptr InitCompression(ColumnDataCheckpointer &checkpointer, - unique_ptr state); - static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); - static void FinalizeCompress(CompressionState &state_p); - - static unique_ptr StringInitScan(ColumnSegment &segment); - template - static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, - idx_t result_offset); - static void StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result); - static void StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, - idx_t result_idx); - - static bool HasEnoughSpace(idx_t current_count, idx_t index_count, idx_t dict_size, - bitpacking_width_t packing_width); - static idx_t RequiredSpace(idx_t current_count, idx_t index_count, idx_t dict_size, - bitpacking_width_t packing_width); - - static StringDictionaryContainer GetDictionary(ColumnSegment &segment, BufferHandle &handle); - static void SetDictionary(ColumnSegment &segment, 
BufferHandle &handle, StringDictionaryContainer container); - static string_t FetchStringFromDict(ColumnSegment &segment, StringDictionaryContainer dict, data_ptr_t baseptr, - int32_t dict_offset, uint16_t string_len); - static uint16_t GetStringLength(uint32_t *index_buffer_ptr, sel_t index); -}; - -// Dictionary compression uses a combination of bitpacking and a dictionary to compress string segments. The data is -// stored across three buffers: the index buffer, the selection buffer and the dictionary. Firstly the Index buffer -// contains the offsets into the dictionary which are also used to determine the string lengths. Each value in the -// dictionary gets a single unique index in the index buffer. Secondly, the selection buffer maps the tuples to an index -// in the index buffer. The selection buffer is compressed with bitpacking. Finally, the dictionary contains simply all -// the unique strings without lenghts or null termination as we can deduce the lengths from the index buffer. The -// addition of the selection buffer is done for two reasons: firstly, to allow the scan to emit dictionary vectors by -// scanning the whole dictionary at once and then scanning the selection buffer for each emitted vector. Secondly, it -// allows for efficient bitpacking compression as the selection values should remain relatively small. 
-struct DictionaryCompressionCompressState : public DictionaryCompressionState { - explicit DictionaryCompressionCompressState(ColumnDataCheckpointer &checkpointer_p) - : checkpointer(checkpointer_p), - function(checkpointer.GetCompressionFunction(CompressionType::COMPRESSION_DICTIONARY)), - heap(BufferAllocator::Get(checkpointer.GetDatabase())) { - CreateEmptySegment(checkpointer.GetRowGroup().start); - } - - ColumnDataCheckpointer &checkpointer; - CompressionFunction &function; - - // State regarding current segment - unique_ptr current_segment; - BufferHandle current_handle; - StringDictionaryContainer current_dictionary; - data_ptr_t current_end_ptr; - - // Buffers and map for current segment - StringHeap heap; - string_map_t current_string_map; - vector index_buffer; - vector selection_buffer; - - bitpacking_width_t current_width = 0; - bitpacking_width_t next_width = 0; - - // Result of latest LookupString call - uint32_t latest_lookup_result; - -public: - void CreateEmptySegment(idx_t row_start) { - auto &db = checkpointer.GetDatabase(); - auto &type = checkpointer.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, type, row_start); - current_segment = std::move(compressed_segment); - - current_segment->function = function; - - // Reset the buffers and string map - current_string_map.clear(); - index_buffer.clear(); - index_buffer.push_back(0); // Reserve index 0 for null strings - selection_buffer.clear(); - - current_width = 0; - next_width = 0; - - // Reset the pointers into the current segment - auto &buffer_manager = BufferManager::GetBufferManager(checkpointer.GetDatabase()); - current_handle = buffer_manager.Pin(current_segment->block); - current_dictionary = DictionaryCompressionStorage::GetDictionary(*current_segment, current_handle); - current_end_ptr = current_handle.Ptr() + current_dictionary.end; - } - - void Verify() override { - current_dictionary.Verify(); - D_ASSERT(current_segment->count == 
selection_buffer.size()); - D_ASSERT(DictionaryCompressionStorage::HasEnoughSpace(current_segment->count.load(), index_buffer.size(), - current_dictionary.size, current_width)); - D_ASSERT(current_dictionary.end == Storage::BLOCK_SIZE); - D_ASSERT(index_buffer.size() == current_string_map.size() + 1); // +1 is for null value - } - - bool LookupString(string_t str) override { - auto search = current_string_map.find(str); - auto has_result = search != current_string_map.end(); - - if (has_result) { - latest_lookup_result = search->second; - } - return has_result; - } - - void AddNewString(string_t str) override { - UncompressedStringStorage::UpdateStringStats(current_segment->stats, str); - - // Copy string to dict - current_dictionary.size += str.GetSize(); - auto dict_pos = current_end_ptr - current_dictionary.size; - memcpy(dict_pos, str.GetData(), str.GetSize()); - current_dictionary.Verify(); - D_ASSERT(current_dictionary.end == Storage::BLOCK_SIZE); - - // Update buffers and map - index_buffer.push_back(current_dictionary.size); - selection_buffer.push_back(index_buffer.size() - 1); - if (str.IsInlined()) { - current_string_map.insert({str, index_buffer.size() - 1}); - } else { - current_string_map.insert({heap.AddBlob(str), index_buffer.size() - 1}); - } - DictionaryCompressionStorage::SetDictionary(*current_segment, current_handle, current_dictionary); - - current_width = next_width; - current_segment->count++; - } - - void AddNull() override { - selection_buffer.push_back(0); - current_segment->count++; - } - - void AddLastLookup() override { - selection_buffer.push_back(latest_lookup_result); - current_segment->count++; - } - - bool CalculateSpaceRequirements(bool new_string, size_t string_size) override { - if (new_string) { - next_width = BitpackingPrimitives::MinimumBitWidth(index_buffer.size() - 1 + new_string); - return DictionaryCompressionStorage::HasEnoughSpace(current_segment->count.load() + 1, - index_buffer.size() + 1, - current_dictionary.size + 
string_size, next_width); - } else { - return DictionaryCompressionStorage::HasEnoughSpace(current_segment->count.load() + 1, index_buffer.size(), - current_dictionary.size, current_width); - } - } - - void Flush(bool final = false) override { - auto next_start = current_segment->start + current_segment->count; - - auto segment_size = Finalize(); - auto &state = checkpointer.GetCheckpointState(); - state.FlushSegment(std::move(current_segment), segment_size); - - if (!final) { - CreateEmptySegment(next_start); - } - } - - idx_t Finalize() { - auto &buffer_manager = BufferManager::GetBufferManager(checkpointer.GetDatabase()); - auto handle = buffer_manager.Pin(current_segment->block); - D_ASSERT(current_dictionary.end == Storage::BLOCK_SIZE); - - // calculate sizes - auto compressed_selection_buffer_size = - BitpackingPrimitives::GetRequiredSize(current_segment->count, current_width); - auto index_buffer_size = index_buffer.size() * sizeof(uint32_t); - auto total_size = DictionaryCompressionStorage::DICTIONARY_HEADER_SIZE + compressed_selection_buffer_size + - index_buffer_size + current_dictionary.size; - - // calculate ptr and offsets - auto base_ptr = handle.Ptr(); - auto header_ptr = reinterpret_cast(base_ptr); - auto compressed_selection_buffer_offset = DictionaryCompressionStorage::DICTIONARY_HEADER_SIZE; - auto index_buffer_offset = compressed_selection_buffer_offset + compressed_selection_buffer_size; - - // Write compressed selection buffer - BitpackingPrimitives::PackBuffer(base_ptr + compressed_selection_buffer_offset, - (sel_t *)(selection_buffer.data()), current_segment->count, - current_width); - - // Write the index buffer - memcpy(base_ptr + index_buffer_offset, index_buffer.data(), index_buffer_size); - - // Store sizes and offsets in segment header - Store(index_buffer_offset, data_ptr_cast(&header_ptr->index_buffer_offset)); - Store(index_buffer.size(), data_ptr_cast(&header_ptr->index_buffer_count)); - Store((uint32_t)current_width, 
data_ptr_cast(&header_ptr->bitpacking_width)); - - D_ASSERT(current_width == BitpackingPrimitives::MinimumBitWidth(index_buffer.size() - 1)); - D_ASSERT(DictionaryCompressionStorage::HasEnoughSpace(current_segment->count, index_buffer.size(), - current_dictionary.size, current_width)); - D_ASSERT((uint64_t)*max_element(std::begin(selection_buffer), std::end(selection_buffer)) == - index_buffer.size() - 1); - - if (total_size >= DictionaryCompressionStorage::COMPACTION_FLUSH_LIMIT) { - // the block is full enough, don't bother moving around the dictionary - return Storage::BLOCK_SIZE; - } - // the block has space left: figure out how much space we can save - auto move_amount = Storage::BLOCK_SIZE - total_size; - // move the dictionary so it lines up exactly with the offsets - auto new_dictionary_offset = index_buffer_offset + index_buffer_size; - memmove(base_ptr + new_dictionary_offset, base_ptr + current_dictionary.end - current_dictionary.size, - current_dictionary.size); - current_dictionary.end -= move_amount; - D_ASSERT(current_dictionary.end == total_size); - // write the new dictionary (with the updated "end") - DictionaryCompressionStorage::SetDictionary(*current_segment, handle, current_dictionary); - return total_size; - } -}; - -//===--------------------------------------------------------------------===// -// Analyze -//===--------------------------------------------------------------------===// -struct DictionaryAnalyzeState : public DictionaryCompressionState { - DictionaryAnalyzeState() - : segment_count(0), current_tuple_count(0), current_unique_count(0), current_dict_size(0), current_width(0), - next_width(0) { - } - - size_t segment_count; - idx_t current_tuple_count; - idx_t current_unique_count; - size_t current_dict_size; - StringHeap heap; - string_set_t current_set; - bitpacking_width_t current_width; - bitpacking_width_t next_width; - - bool LookupString(string_t str) override { - return current_set.count(str); - } - - void 
AddNewString(string_t str) override { - current_tuple_count++; - current_unique_count++; - current_dict_size += str.GetSize(); - if (str.IsInlined()) { - current_set.insert(str); - } else { - current_set.insert(heap.AddBlob(str)); - } - current_width = next_width; - } - - void AddLastLookup() override { - current_tuple_count++; - } - - void AddNull() override { - current_tuple_count++; - } - - bool CalculateSpaceRequirements(bool new_string, size_t string_size) override { - if (new_string) { - next_width = - BitpackingPrimitives::MinimumBitWidth(current_unique_count + 2); // 1 for null, one for new string - return DictionaryCompressionStorage::HasEnoughSpace(current_tuple_count + 1, current_unique_count + 1, - current_dict_size + string_size, next_width); - } else { - return DictionaryCompressionStorage::HasEnoughSpace(current_tuple_count + 1, current_unique_count, - current_dict_size, current_width); - } - } - - void Flush(bool final = false) override { - segment_count++; - current_tuple_count = 0; - current_unique_count = 0; - current_dict_size = 0; - current_set.clear(); - } - void Verify() override {}; -}; - -struct DictionaryCompressionAnalyzeState : public AnalyzeState { - DictionaryCompressionAnalyzeState() : analyze_state(make_uniq()) { - } - - unique_ptr analyze_state; -}; - -unique_ptr DictionaryCompressionStorage::StringInitAnalyze(ColumnData &col_data, PhysicalType type) { - return make_uniq(); -} - -bool DictionaryCompressionStorage::StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { - auto &state = state_p.Cast(); - return state.analyze_state->UpdateState(input, count); -} - -idx_t DictionaryCompressionStorage::StringFinalAnalyze(AnalyzeState &state_p) { - auto &analyze_state = state_p.Cast(); - auto &state = *analyze_state.analyze_state; - - auto width = BitpackingPrimitives::MinimumBitWidth(state.current_unique_count + 1); - auto req_space = - RequiredSpace(state.current_tuple_count, state.current_unique_count, 
state.current_dict_size, width); - - return MINIMUM_COMPRESSION_RATIO * (state.segment_count * Storage::BLOCK_SIZE + req_space); -} - -//===--------------------------------------------------------------------===// -// Compress -//===--------------------------------------------------------------------===// -unique_ptr DictionaryCompressionStorage::InitCompression(ColumnDataCheckpointer &checkpointer, - unique_ptr state) { - return make_uniq(checkpointer); -} - -void DictionaryCompressionStorage::Compress(CompressionState &state_p, Vector &scan_vector, idx_t count) { - auto &state = state_p.Cast(); - state.UpdateState(scan_vector, count); -} - -void DictionaryCompressionStorage::FinalizeCompress(CompressionState &state_p) { - auto &state = state_p.Cast(); - state.Flush(true); -} - -//===--------------------------------------------------------------------===// -// Scan -//===--------------------------------------------------------------------===// -struct CompressedStringScanState : public StringScanState { - BufferHandle handle; - buffer_ptr dictionary; - bitpacking_width_t current_width; - buffer_ptr sel_vec; - idx_t sel_vec_size = 0; -}; - -unique_ptr DictionaryCompressionStorage::StringInitScan(ColumnSegment &segment) { - auto state = make_uniq(); - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - state->handle = buffer_manager.Pin(segment.block); - - auto baseptr = state->handle.Ptr() + segment.GetBlockOffset(); - - // Load header values - auto dict = DictionaryCompressionStorage::GetDictionary(segment, state->handle); - auto header_ptr = reinterpret_cast(baseptr); - auto index_buffer_offset = Load(data_ptr_cast(&header_ptr->index_buffer_offset)); - auto index_buffer_count = Load(data_ptr_cast(&header_ptr->index_buffer_count)); - state->current_width = (bitpacking_width_t)(Load(data_ptr_cast(&header_ptr->bitpacking_width))); - - auto index_buffer_ptr = reinterpret_cast(baseptr + index_buffer_offset); - - state->dictionary = 
make_buffer(segment.type, index_buffer_count); - auto dict_child_data = FlatVector::GetData(*(state->dictionary)); - - for (uint32_t i = 0; i < index_buffer_count; i++) { - // NOTE: the passing of dict_child_vector, will not be used, its for big strings - uint16_t str_len = GetStringLength(index_buffer_ptr, i); - dict_child_data[i] = FetchStringFromDict(segment, dict, baseptr, index_buffer_ptr[i], str_len); - } - - return std::move(state); -} - -//===--------------------------------------------------------------------===// -// Scan base data -//===--------------------------------------------------------------------===// -template -void DictionaryCompressionStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, - Vector &result, idx_t result_offset) { - // clear any previously locked buffers and get the primary buffer handle - auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); - - auto baseptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); - auto dict = DictionaryCompressionStorage::GetDictionary(segment, scan_state.handle); - - auto header_ptr = reinterpret_cast(baseptr); - auto index_buffer_offset = Load(data_ptr_cast(&header_ptr->index_buffer_offset)); - auto index_buffer_ptr = reinterpret_cast(baseptr + index_buffer_offset); - - auto base_data = data_ptr_cast(baseptr + DICTIONARY_HEADER_SIZE); - auto result_data = FlatVector::GetData(result); - - if (!ALLOW_DICT_VECTORS || scan_count != STANDARD_VECTOR_SIZE || - start % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE != 0) { - // Emit regular vector - - // Handling non-bitpacking-group-aligned start values; - idx_t start_offset = start % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; - - // We will scan in blocks of BITPACKING_ALGORITHM_GROUP_SIZE, so we may scan some extra values. 
- idx_t decompress_count = BitpackingPrimitives::RoundUpToAlgorithmGroupSize(scan_count + start_offset); - - // Create a decompression buffer of sufficient size if we don't already have one. - if (!scan_state.sel_vec || scan_state.sel_vec_size < decompress_count) { - scan_state.sel_vec_size = decompress_count; - scan_state.sel_vec = make_buffer(decompress_count); - } - - data_ptr_t src = &base_data[((start - start_offset) * scan_state.current_width) / 8]; - sel_t *sel_vec_ptr = scan_state.sel_vec->data(); - - BitpackingPrimitives::UnPackBuffer(data_ptr_cast(sel_vec_ptr), src, decompress_count, - scan_state.current_width); - - for (idx_t i = 0; i < scan_count; i++) { - // Lookup dict offset in index buffer - auto string_number = scan_state.sel_vec->get_index(i + start_offset); - auto dict_offset = index_buffer_ptr[string_number]; - uint16_t str_len = GetStringLength(index_buffer_ptr, string_number); - result_data[result_offset + i] = FetchStringFromDict(segment, dict, baseptr, dict_offset, str_len); - } - - } else { - D_ASSERT(start % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE == 0); - D_ASSERT(scan_count == STANDARD_VECTOR_SIZE); - D_ASSERT(result_offset == 0); - - idx_t decompress_count = BitpackingPrimitives::RoundUpToAlgorithmGroupSize(scan_count); - - // Create a selection vector of sufficient size if we don't already have one. 
- if (!scan_state.sel_vec || scan_state.sel_vec_size < decompress_count) { - scan_state.sel_vec_size = decompress_count; - scan_state.sel_vec = make_buffer(decompress_count); - } - - // Scanning 1024 values, emitting a dict vector - data_ptr_t dst = data_ptr_cast(scan_state.sel_vec->data()); - data_ptr_t src = data_ptr_cast(&base_data[(start * scan_state.current_width) / 8]); - - BitpackingPrimitives::UnPackBuffer(dst, src, scan_count, scan_state.current_width); - - result.Slice(*(scan_state.dictionary), *scan_state.sel_vec, scan_count); - } -} - -void DictionaryCompressionStorage::StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, - Vector &result) { - StringScanPartial(segment, state, scan_count, result, 0); -} - -//===--------------------------------------------------------------------===// -// Fetch -//===--------------------------------------------------------------------===// -void DictionaryCompressionStorage::StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, - Vector &result, idx_t result_idx) { - // fetch a single row from the string segment - // first pin the main buffer if it is not already pinned - auto &handle = state.GetOrInsertHandle(segment); - - auto baseptr = handle.Ptr() + segment.GetBlockOffset(); - auto header_ptr = reinterpret_cast(baseptr); - auto dict = DictionaryCompressionStorage::GetDictionary(segment, handle); - auto index_buffer_offset = Load(data_ptr_cast(&header_ptr->index_buffer_offset)); - auto width = (bitpacking_width_t)Load(data_ptr_cast(&header_ptr->bitpacking_width)); - auto index_buffer_ptr = reinterpret_cast(baseptr + index_buffer_offset); - auto base_data = data_ptr_cast(baseptr + DICTIONARY_HEADER_SIZE); - auto result_data = FlatVector::GetData(result); - - // Handling non-bitpacking-group-aligned start values; - idx_t start_offset = row_id % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; - - // Decompress part of selection buffer we need for this value. 
- sel_t decompression_buffer[BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE]; - data_ptr_t src = data_ptr_cast(&base_data[((row_id - start_offset) * width) / 8]); - BitpackingPrimitives::UnPackBuffer(data_ptr_cast(decompression_buffer), src, - BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE, width); - - auto selection_value = decompression_buffer[start_offset]; - auto dict_offset = index_buffer_ptr[selection_value]; - uint16_t str_len = GetStringLength(index_buffer_ptr, selection_value); - - result_data[result_idx] = FetchStringFromDict(segment, dict, baseptr, dict_offset, str_len); -} - -//===--------------------------------------------------------------------===// -// Helper Functions -//===--------------------------------------------------------------------===// -bool DictionaryCompressionStorage::HasEnoughSpace(idx_t current_count, idx_t index_count, idx_t dict_size, - bitpacking_width_t packing_width) { - return RequiredSpace(current_count, index_count, dict_size, packing_width) <= Storage::BLOCK_SIZE; -} - -idx_t DictionaryCompressionStorage::RequiredSpace(idx_t current_count, idx_t index_count, idx_t dict_size, - bitpacking_width_t packing_width) { - idx_t base_space = DICTIONARY_HEADER_SIZE + dict_size; - idx_t string_number_space = BitpackingPrimitives::GetRequiredSize(current_count, packing_width); - idx_t index_space = index_count * sizeof(uint32_t); - - idx_t used_space = base_space + index_space + string_number_space; - - return used_space; -} - -StringDictionaryContainer DictionaryCompressionStorage::GetDictionary(ColumnSegment &segment, BufferHandle &handle) { - auto header_ptr = reinterpret_cast(handle.Ptr() + segment.GetBlockOffset()); - StringDictionaryContainer container; - container.size = Load(data_ptr_cast(&header_ptr->dict_size)); - container.end = Load(data_ptr_cast(&header_ptr->dict_end)); - return container; -} - -void DictionaryCompressionStorage::SetDictionary(ColumnSegment &segment, BufferHandle &handle, - 
StringDictionaryContainer container) { - auto header_ptr = reinterpret_cast(handle.Ptr() + segment.GetBlockOffset()); - Store(container.size, data_ptr_cast(&header_ptr->dict_size)); - Store(container.end, data_ptr_cast(&header_ptr->dict_end)); -} - -string_t DictionaryCompressionStorage::FetchStringFromDict(ColumnSegment &segment, StringDictionaryContainer dict, - data_ptr_t baseptr, int32_t dict_offset, - uint16_t string_len) { - D_ASSERT(dict_offset >= 0 && dict_offset <= Storage::BLOCK_SIZE); - - if (dict_offset == 0) { - return string_t(nullptr, 0); - } - // normal string: read string from this block - auto dict_end = baseptr + dict.end; - auto dict_pos = dict_end - dict_offset; - - auto str_ptr = char_ptr_cast(dict_pos); - return string_t(str_ptr, string_len); -} - -uint16_t DictionaryCompressionStorage::GetStringLength(uint32_t *index_buffer_ptr, sel_t index) { - if (index == 0) { - return 0; - } else { - return index_buffer_ptr[index] - index_buffer_ptr[index - 1]; - } -} - -//===--------------------------------------------------------------------===// -// Get Function -//===--------------------------------------------------------------------===// -CompressionFunction DictionaryCompressionFun::GetFunction(PhysicalType data_type) { - return CompressionFunction( - CompressionType::COMPRESSION_DICTIONARY, data_type, DictionaryCompressionStorage ::StringInitAnalyze, - DictionaryCompressionStorage::StringAnalyze, DictionaryCompressionStorage::StringFinalAnalyze, - DictionaryCompressionStorage::InitCompression, DictionaryCompressionStorage::Compress, - DictionaryCompressionStorage::FinalizeCompress, DictionaryCompressionStorage::StringInitScan, - DictionaryCompressionStorage::StringScan, DictionaryCompressionStorage::StringScanPartial, - DictionaryCompressionStorage::StringFetchRow, UncompressedFunctions::EmptySkip); -} - -bool DictionaryCompressionFun::TypeIsSupported(PhysicalType type) { - return type == PhysicalType::VARCHAR; -} -} // namespace duckdb - - - - - 
- - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Analyze -//===--------------------------------------------------------------------===// -struct FixedSizeAnalyzeState : public AnalyzeState { - FixedSizeAnalyzeState() : count(0) { - } - - idx_t count; -}; - -unique_ptr FixedSizeInitAnalyze(ColumnData &col_data, PhysicalType type) { - return make_uniq(); -} - -bool FixedSizeAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { - auto &state = state_p.Cast(); - state.count += count; - return true; -} - -template -idx_t FixedSizeFinalAnalyze(AnalyzeState &state_p) { - auto &state = state_p.template Cast(); - return sizeof(T) * state.count; -} - -//===--------------------------------------------------------------------===// -// Compress -//===--------------------------------------------------------------------===// -struct UncompressedCompressState : public CompressionState { - explicit UncompressedCompressState(ColumnDataCheckpointer &checkpointer); - - ColumnDataCheckpointer &checkpointer; - unique_ptr current_segment; - ColumnAppendState append_state; - - virtual void CreateEmptySegment(idx_t row_start); - void FlushSegment(idx_t segment_size); - void Finalize(idx_t segment_size); -}; - -UncompressedCompressState::UncompressedCompressState(ColumnDataCheckpointer &checkpointer) - : checkpointer(checkpointer) { - UncompressedCompressState::CreateEmptySegment(checkpointer.GetRowGroup().start); -} - -void UncompressedCompressState::CreateEmptySegment(idx_t row_start) { - auto &db = checkpointer.GetDatabase(); - auto &type = checkpointer.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, type, row_start); - if (type.InternalType() == PhysicalType::VARCHAR) { - auto &state = compressed_segment->GetSegmentState()->Cast(); - state.overflow_writer = make_uniq(checkpointer.GetRowGroup().GetBlockManager()); - } - current_segment = std::move(compressed_segment); - 
current_segment->InitializeAppend(append_state); -} - -void UncompressedCompressState::FlushSegment(idx_t segment_size) { - auto &state = checkpointer.GetCheckpointState(); - if (current_segment->type.InternalType() == PhysicalType::VARCHAR) { - auto &segment_state = current_segment->GetSegmentState()->Cast(); - segment_state.overflow_writer->Flush(); - segment_state.overflow_writer.reset(); - } - state.FlushSegment(std::move(current_segment), segment_size); -} - -void UncompressedCompressState::Finalize(idx_t segment_size) { - FlushSegment(segment_size); - current_segment.reset(); -} - -unique_ptr UncompressedFunctions::InitCompression(ColumnDataCheckpointer &checkpointer, - unique_ptr state) { - return make_uniq(checkpointer); -} - -void UncompressedFunctions::Compress(CompressionState &state_p, Vector &data, idx_t count) { - auto &state = state_p.Cast(); - UnifiedVectorFormat vdata; - data.ToUnifiedFormat(count, vdata); - - idx_t offset = 0; - while (count > 0) { - idx_t appended = state.current_segment->Append(state.append_state, vdata, offset, count); - if (appended == count) { - // appended everything: finished - return; - } - auto next_start = state.current_segment->start + state.current_segment->count; - // the segment is full: flush it to disk - state.FlushSegment(state.current_segment->FinalizeAppend(state.append_state)); - - // now create a new segment and continue appending - state.CreateEmptySegment(next_start); - offset += appended; - count -= appended; - } -} - -void UncompressedFunctions::FinalizeCompress(CompressionState &state_p) { - auto &state = state_p.Cast(); - state.Finalize(state.current_segment->FinalizeAppend(state.append_state)); -} - -//===--------------------------------------------------------------------===// -// Scan -//===--------------------------------------------------------------------===// -struct FixedSizeScanState : public SegmentScanState { - BufferHandle handle; -}; - -unique_ptr FixedSizeInitScan(ColumnSegment &segment) { 
- auto result = make_uniq(); - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - result->handle = buffer_manager.Pin(segment.block); - return std::move(result); -} - -//===--------------------------------------------------------------------===// -// Scan base data -//===--------------------------------------------------------------------===// -template -void FixedSizeScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, - idx_t result_offset) { - auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); - - auto data = scan_state.handle.Ptr() + segment.GetBlockOffset(); - auto source_data = data + start * sizeof(T); - - // copy the data from the base table - result.SetVectorType(VectorType::FLAT_VECTOR); - memcpy(FlatVector::GetData(result) + result_offset * sizeof(T), source_data, scan_count * sizeof(T)); -} - -template -void FixedSizeScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { - auto &scan_state = state.scan_state->template Cast(); - auto start = segment.GetRelativeIndex(state.row_index); - - auto data = scan_state.handle.Ptr() + segment.GetBlockOffset(); - auto source_data = data + start * sizeof(T); - - result.SetVectorType(VectorType::FLAT_VECTOR); - FlatVector::SetData(result, source_data); -} - -//===--------------------------------------------------------------------===// -// Fetch -//===--------------------------------------------------------------------===// -template -void FixedSizeFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, - idx_t result_idx) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); - - // first fetch the data from the base table - auto data_ptr = handle.Ptr() + segment.GetBlockOffset() + row_id * sizeof(T); - - memcpy(FlatVector::GetData(result) + result_idx * sizeof(T), data_ptr, 
sizeof(T)); -} - -//===--------------------------------------------------------------------===// -// Append -//===--------------------------------------------------------------------===// -static unique_ptr FixedSizeInitAppend(ColumnSegment &segment) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); - return make_uniq(std::move(handle)); -} - -struct StandardFixedSizeAppend { - template - static void Append(SegmentStatistics &stats, data_ptr_t target, idx_t target_offset, UnifiedVectorFormat &adata, - idx_t offset, idx_t count) { - auto sdata = UnifiedVectorFormat::GetData(adata); - auto tdata = reinterpret_cast(target); - if (!adata.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto source_idx = adata.sel->get_index(offset + i); - auto target_idx = target_offset + i; - bool is_null = !adata.validity.RowIsValid(source_idx); - if (!is_null) { - NumericStats::Update(stats.statistics, sdata[source_idx]); - tdata[target_idx] = sdata[source_idx]; - } else { - // we insert a NullValue in the null gap for debuggability - // this value should never be used or read anywhere - tdata[target_idx] = NullValue(); - } - } - } else { - for (idx_t i = 0; i < count; i++) { - auto source_idx = adata.sel->get_index(offset + i); - auto target_idx = target_offset + i; - NumericStats::Update(stats.statistics, sdata[source_idx]); - tdata[target_idx] = sdata[source_idx]; - } - } - } -}; - -struct ListFixedSizeAppend { - template - static void Append(SegmentStatistics &stats, data_ptr_t target, idx_t target_offset, UnifiedVectorFormat &adata, - idx_t offset, idx_t count) { - auto sdata = UnifiedVectorFormat::GetData(adata); - auto tdata = reinterpret_cast(target); - for (idx_t i = 0; i < count; i++) { - auto source_idx = adata.sel->get_index(offset + i); - auto target_idx = target_offset + i; - tdata[target_idx] = sdata[source_idx]; - } - } -}; - -template -idx_t FixedSizeAppend(CompressionAppendState 
&append_state, ColumnSegment &segment, SegmentStatistics &stats, - UnifiedVectorFormat &data, idx_t offset, idx_t count) { - D_ASSERT(segment.GetBlockOffset() == 0); - - auto target_ptr = append_state.handle.Ptr(); - idx_t max_tuple_count = segment.SegmentSize() / sizeof(T); - idx_t copy_count = MinValue(count, max_tuple_count - segment.count); - - OP::template Append(stats, target_ptr, segment.count, data, offset, copy_count); - segment.count += copy_count; - return copy_count; -} - -template -idx_t FixedSizeFinalizeAppend(ColumnSegment &segment, SegmentStatistics &stats) { - return segment.count * sizeof(T); -} - -//===--------------------------------------------------------------------===// -// Get Function -//===--------------------------------------------------------------------===// -template -CompressionFunction FixedSizeGetFunction(PhysicalType data_type) { - return CompressionFunction(CompressionType::COMPRESSION_UNCOMPRESSED, data_type, FixedSizeInitAnalyze, - FixedSizeAnalyze, FixedSizeFinalAnalyze, UncompressedFunctions::InitCompression, - UncompressedFunctions::Compress, UncompressedFunctions::FinalizeCompress, - FixedSizeInitScan, FixedSizeScan, FixedSizeScanPartial, FixedSizeFetchRow, - UncompressedFunctions::EmptySkip, nullptr, FixedSizeInitAppend, - FixedSizeAppend, FixedSizeFinalizeAppend, nullptr); -} - -CompressionFunction FixedSizeUncompressed::GetFunction(PhysicalType data_type) { - switch (data_type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return FixedSizeGetFunction(data_type); - case PhysicalType::INT16: - return FixedSizeGetFunction(data_type); - case PhysicalType::INT32: - return FixedSizeGetFunction(data_type); - case PhysicalType::INT64: - return FixedSizeGetFunction(data_type); - case PhysicalType::UINT8: - return FixedSizeGetFunction(data_type); - case PhysicalType::UINT16: - return FixedSizeGetFunction(data_type); - case PhysicalType::UINT32: - return FixedSizeGetFunction(data_type); - case PhysicalType::UINT64: - 
return FixedSizeGetFunction(data_type); - case PhysicalType::INT128: - return FixedSizeGetFunction(data_type); - case PhysicalType::FLOAT: - return FixedSizeGetFunction(data_type); - case PhysicalType::DOUBLE: - return FixedSizeGetFunction(data_type); - case PhysicalType::INTERVAL: - return FixedSizeGetFunction(data_type); - case PhysicalType::LIST: - return FixedSizeGetFunction(data_type); - default: - throw InternalException("Unsupported type for FixedSizeUncompressed::GetFunction"); - } -} - -} // namespace duckdb - - - - - - - - - - - - -namespace duckdb { - -typedef struct { - uint32_t dict_size; - uint32_t dict_end; - uint32_t bitpacking_width; - uint32_t fsst_symbol_table_offset; -} fsst_compression_header_t; - -// Counts and offsets used during scanning/fetching -// | ColumnSegment to be scanned / fetched from | -// | untouched | bp align | unused d-values | to scan | bp align | untouched | -typedef struct BPDeltaDecodeOffsets { - idx_t delta_decode_start_row; // X - idx_t bitunpack_alignment_offset; // <---------> - idx_t bitunpack_start_row; // X - idx_t unused_delta_decoded_values; // <-----------------> - idx_t scan_offset; // <----------------------------> - idx_t total_delta_decode_count; // <--------------------------> - idx_t total_bitunpack_count; // <------------------------------------------------> -} bp_delta_offsets_t; - -struct FSSTStorage { - static constexpr size_t COMPACTION_FLUSH_LIMIT = (size_t)Storage::BLOCK_SIZE / 5 * 4; - static constexpr double MINIMUM_COMPRESSION_RATIO = 1.2; - static constexpr double ANALYSIS_SAMPLE_SIZE = 0.25; - - static unique_ptr StringInitAnalyze(ColumnData &col_data, PhysicalType type); - static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); - static idx_t StringFinalAnalyze(AnalyzeState &state_p); - - static unique_ptr InitCompression(ColumnDataCheckpointer &checkpointer, - unique_ptr analyze_state_p); - static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); 
- static void FinalizeCompress(CompressionState &state_p); - - static unique_ptr StringInitScan(ColumnSegment &segment); - template - static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, - idx_t result_offset); - static void StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result); - static void StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, - idx_t result_idx); - - static void SetDictionary(ColumnSegment &segment, BufferHandle &handle, StringDictionaryContainer container); - static StringDictionaryContainer GetDictionary(ColumnSegment &segment, BufferHandle &handle); - - static char *FetchStringPointer(StringDictionaryContainer dict, data_ptr_t baseptr, int32_t dict_offset); - static bp_delta_offsets_t CalculateBpDeltaOffsets(int64_t last_known_row, idx_t start, idx_t scan_count); - static bool ParseFSSTSegmentHeader(data_ptr_t base_ptr, duckdb_fsst_decoder_t *decoder_out, - bitpacking_width_t *width_out); -}; - -//===--------------------------------------------------------------------===// -// Analyze -//===--------------------------------------------------------------------===// -struct FSSTAnalyzeState : public AnalyzeState { - FSSTAnalyzeState() : count(0), fsst_string_total_size(0), empty_strings(0) { - } - - ~FSSTAnalyzeState() override { - if (fsst_encoder) { - duckdb_fsst_destroy(fsst_encoder); - } - } - - duckdb_fsst_encoder_t *fsst_encoder = nullptr; - idx_t count; - - StringHeap fsst_string_heap; - vector fsst_strings; - size_t fsst_string_total_size; - - RandomEngine random_engine; - bool have_valid_row = false; - - idx_t empty_strings; -}; - -unique_ptr FSSTStorage::StringInitAnalyze(ColumnData &col_data, PhysicalType type) { - return make_uniq(); -} - -bool FSSTStorage::StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { - auto &state = state_p.Cast(); - UnifiedVectorFormat vdata; - 
input.ToUnifiedFormat(count, vdata); - - state.count += count; - auto data = UnifiedVectorFormat::GetData(vdata); - - // Note that we ignore the sampling in case we have not found any valid strings yet, this solves the issue of - // not having seen any valid strings here leading to an empty fsst symbol table. - bool sample_selected = !state.have_valid_row || state.random_engine.NextRandom() < ANALYSIS_SAMPLE_SIZE; - - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - - if (!vdata.validity.RowIsValid(idx)) { - continue; - } - - // We need to check all strings for this, otherwise we run in to trouble during compression if we miss ones - auto string_size = data[idx].GetSize(); - if (string_size >= StringUncompressed::STRING_BLOCK_LIMIT) { - return false; - } - - if (!sample_selected) { - continue; - } - - if (string_size > 0) { - state.have_valid_row = true; - if (data[idx].IsInlined()) { - state.fsst_strings.push_back(data[idx]); - } else { - state.fsst_strings.emplace_back(state.fsst_string_heap.AddBlob(data[idx])); - } - state.fsst_string_total_size += string_size; - } else { - state.empty_strings++; - } - } - return true; -} - -idx_t FSSTStorage::StringFinalAnalyze(AnalyzeState &state_p) { - auto &state = state_p.Cast(); - - size_t compressed_dict_size = 0; - size_t max_compressed_string_length = 0; - - auto string_count = state.fsst_strings.size(); - - if (!string_count) { - return DConstants::INVALID_INDEX; - } - - size_t output_buffer_size = 7 + 2 * state.fsst_string_total_size; // size as specified in fsst.h - - vector fsst_string_sizes; - vector fsst_string_ptrs; - for (auto &str : state.fsst_strings) { - fsst_string_sizes.push_back(str.GetSize()); - fsst_string_ptrs.push_back((unsigned char *)str.GetData()); // NOLINT - } - - state.fsst_encoder = duckdb_fsst_create(string_count, &fsst_string_sizes[0], &fsst_string_ptrs[0], 0); - - // TODO: do we really need to encode to get a size estimate? 
- auto compressed_ptrs = vector(string_count, nullptr); - auto compressed_sizes = vector(string_count, 0); - unique_ptr compressed_buffer(new unsigned char[output_buffer_size]); - - auto res = - duckdb_fsst_compress(state.fsst_encoder, string_count, &fsst_string_sizes[0], &fsst_string_ptrs[0], - output_buffer_size, compressed_buffer.get(), &compressed_sizes[0], &compressed_ptrs[0]); - - if (string_count != res) { - throw std::runtime_error("FSST output buffer is too small unexpectedly"); - } - - // Sum and and Max compressed lengths - for (auto &size : compressed_sizes) { - compressed_dict_size += size; - max_compressed_string_length = MaxValue(max_compressed_string_length, size); - } - D_ASSERT(compressed_dict_size == (compressed_ptrs[res - 1] - compressed_ptrs[0]) + compressed_sizes[res - 1]); - - auto minimum_width = BitpackingPrimitives::MinimumBitWidth(max_compressed_string_length); - auto bitpacked_offsets_size = - BitpackingPrimitives::GetRequiredSize(string_count + state.empty_strings, minimum_width); - - auto estimated_base_size = (bitpacked_offsets_size + compressed_dict_size) * (1 / ANALYSIS_SAMPLE_SIZE); - auto num_blocks = estimated_base_size / (Storage::BLOCK_SIZE - sizeof(duckdb_fsst_decoder_t)); - auto symtable_size = num_blocks * sizeof(duckdb_fsst_decoder_t); - - auto estimated_size = estimated_base_size + symtable_size; - - return estimated_size * MINIMUM_COMPRESSION_RATIO; -} - -//===--------------------------------------------------------------------===// -// Compress -//===--------------------------------------------------------------------===// - -class FSSTCompressionState : public CompressionState { -public: - explicit FSSTCompressionState(ColumnDataCheckpointer &checkpointer) - : checkpointer(checkpointer), function(checkpointer.GetCompressionFunction(CompressionType::COMPRESSION_FSST)) { - CreateEmptySegment(checkpointer.GetRowGroup().start); - } - - ~FSSTCompressionState() override { - if (fsst_encoder) { - 
duckdb_fsst_destroy(fsst_encoder); - } - } - - void Reset() { - index_buffer.clear(); - current_width = 0; - max_compressed_string_length = 0; - last_fitting_size = 0; - - // Reset the pointers into the current segment - auto &buffer_manager = BufferManager::GetBufferManager(current_segment->db); - current_handle = buffer_manager.Pin(current_segment->block); - current_dictionary = FSSTStorage::GetDictionary(*current_segment, current_handle); - current_end_ptr = current_handle.Ptr() + current_dictionary.end; - } - - void CreateEmptySegment(idx_t row_start) { - auto &db = checkpointer.GetDatabase(); - auto &type = checkpointer.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, type, row_start); - current_segment = std::move(compressed_segment); - current_segment->function = function; - Reset(); - } - - void UpdateState(string_t uncompressed_string, unsigned char *compressed_string, size_t compressed_string_len) { - if (!HasEnoughSpace(compressed_string_len)) { - Flush(); - if (!HasEnoughSpace(compressed_string_len)) { - throw InternalException("FSST string compression failed due to insufficient space in empty block"); - }; - } - - UncompressedStringStorage::UpdateStringStats(current_segment->stats, uncompressed_string); - - // Write string into dictionary - current_dictionary.size += compressed_string_len; - auto dict_pos = current_end_ptr - current_dictionary.size; - memcpy(dict_pos, compressed_string, compressed_string_len); - current_dictionary.Verify(); - - // We just push the string length to effectively delta encode the strings - index_buffer.push_back(compressed_string_len); - - max_compressed_string_length = MaxValue(max_compressed_string_length, compressed_string_len); - - current_width = BitpackingPrimitives::MinimumBitWidth(max_compressed_string_length); - current_segment->count++; - } - - void AddNull() { - if (!HasEnoughSpace(0)) { - Flush(); - if (!HasEnoughSpace(0)) { - throw InternalException("FSST string compression 
failed due to insufficient space in empty block"); - }; - } - index_buffer.push_back(0); - current_segment->count++; - } - - void AddEmptyString() { - AddNull(); - UncompressedStringStorage::UpdateStringStats(current_segment->stats, ""); - } - - size_t GetRequiredSize(size_t string_len) { - bitpacking_width_t required_minimum_width; - if (string_len > max_compressed_string_length) { - required_minimum_width = BitpackingPrimitives::MinimumBitWidth(string_len); - } else { - required_minimum_width = current_width; - } - - size_t current_dict_size = current_dictionary.size; - idx_t current_string_count = index_buffer.size(); - - size_t dict_offsets_size = - BitpackingPrimitives::GetRequiredSize(current_string_count + 1, required_minimum_width); - - // TODO switch to a symbol table per RowGroup, saves a bit of space - return sizeof(fsst_compression_header_t) + current_dict_size + dict_offsets_size + string_len + - fsst_serialized_symbol_table_size; - } - - // Checks if there is enough space, if there is, sets last_fitting_size - bool HasEnoughSpace(size_t string_len) { - auto required_size = GetRequiredSize(string_len); - - if (required_size <= Storage::BLOCK_SIZE) { - last_fitting_size = required_size; - return true; - } - return false; - } - - void Flush(bool final = false) { - auto next_start = current_segment->start + current_segment->count; - - auto segment_size = Finalize(); - auto &state = checkpointer.GetCheckpointState(); - state.FlushSegment(std::move(current_segment), segment_size); - - if (!final) { - CreateEmptySegment(next_start); - } - } - - idx_t Finalize() { - auto &buffer_manager = BufferManager::GetBufferManager(current_segment->db); - auto handle = buffer_manager.Pin(current_segment->block); - D_ASSERT(current_dictionary.end == Storage::BLOCK_SIZE); - - // calculate sizes - auto compressed_index_buffer_size = - BitpackingPrimitives::GetRequiredSize(current_segment->count, current_width); - auto total_size = sizeof(fsst_compression_header_t) + 
compressed_index_buffer_size + current_dictionary.size + - fsst_serialized_symbol_table_size; - - if (total_size != last_fitting_size) { - throw InternalException("FSST string compression failed due to incorrect size calculation"); - } - - // calculate ptr and offsets - auto base_ptr = handle.Ptr(); - auto header_ptr = reinterpret_cast(base_ptr); - auto compressed_index_buffer_offset = sizeof(fsst_compression_header_t); - auto symbol_table_offset = compressed_index_buffer_offset + compressed_index_buffer_size; - - D_ASSERT(current_segment->count == index_buffer.size()); - BitpackingPrimitives::PackBuffer(base_ptr + compressed_index_buffer_offset, - reinterpret_cast(index_buffer.data()), - current_segment->count, current_width); - - // Write the fsst symbol table or nothing - if (fsst_encoder != nullptr) { - memcpy(base_ptr + symbol_table_offset, &fsst_serialized_symbol_table[0], fsst_serialized_symbol_table_size); - } else { - memset(base_ptr + symbol_table_offset, 0, fsst_serialized_symbol_table_size); - } - - Store(symbol_table_offset, data_ptr_cast(&header_ptr->fsst_symbol_table_offset)); - Store((uint32_t)current_width, data_ptr_cast(&header_ptr->bitpacking_width)); - - if (total_size >= FSSTStorage::COMPACTION_FLUSH_LIMIT) { - // the block is full enough, don't bother moving around the dictionary - return Storage::BLOCK_SIZE; - } - // the block has space left: figure out how much space we can save - auto move_amount = Storage::BLOCK_SIZE - total_size; - // move the dictionary so it lines up exactly with the offsets - auto new_dictionary_offset = symbol_table_offset + fsst_serialized_symbol_table_size; - memmove(base_ptr + new_dictionary_offset, base_ptr + current_dictionary.end - current_dictionary.size, - current_dictionary.size); - current_dictionary.end -= move_amount; - D_ASSERT(current_dictionary.end == total_size); - // write the new dictionary (with the updated "end") - FSSTStorage::SetDictionary(*current_segment, handle, current_dictionary); - - return 
total_size; - } - - ColumnDataCheckpointer &checkpointer; - CompressionFunction &function; - - // State regarding current segment - unique_ptr current_segment; - BufferHandle current_handle; - StringDictionaryContainer current_dictionary; - data_ptr_t current_end_ptr; - - // Buffers and map for current segment - vector index_buffer; - - size_t max_compressed_string_length; - bitpacking_width_t current_width; - idx_t last_fitting_size; - - duckdb_fsst_encoder_t *fsst_encoder = nullptr; - unsigned char fsst_serialized_symbol_table[sizeof(duckdb_fsst_decoder_t)]; - size_t fsst_serialized_symbol_table_size = sizeof(duckdb_fsst_decoder_t); -}; - -unique_ptr FSSTStorage::InitCompression(ColumnDataCheckpointer &checkpointer, - unique_ptr analyze_state_p) { - auto analyze_state = static_cast(analyze_state_p.get()); - auto compression_state = make_uniq(checkpointer); - - if (analyze_state->fsst_encoder == nullptr) { - throw InternalException("No encoder found during FSST compression"); - } - - compression_state->fsst_encoder = analyze_state->fsst_encoder; - compression_state->fsst_serialized_symbol_table_size = - duckdb_fsst_export(compression_state->fsst_encoder, &compression_state->fsst_serialized_symbol_table[0]); - analyze_state->fsst_encoder = nullptr; - - return std::move(compression_state); -} - -void FSSTStorage::Compress(CompressionState &state_p, Vector &scan_vector, idx_t count) { - auto &state = state_p.Cast(); - - // Get vector data - UnifiedVectorFormat vdata; - scan_vector.ToUnifiedFormat(count, vdata); - auto data = UnifiedVectorFormat::GetData(vdata); - - // Collect pointers to strings to compress - vector sizes_in; - vector strings_in; - size_t total_size = 0; - idx_t total_count = 0; - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - - // Note: we treat nulls and empty strings the same - if (!vdata.validity.RowIsValid(idx) || data[idx].GetSize() == 0) { - continue; - } - - total_count++; - total_size += data[idx].GetSize(); - 
sizes_in.push_back(data[idx].GetSize()); - strings_in.push_back((unsigned char *)data[idx].GetData()); // NOLINT - } - - // Only Nulls or empty strings in this vector, nothing to compress - if (total_count == 0) { - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - if (!vdata.validity.RowIsValid(idx)) { - state.AddNull(); - } else if (data[idx].GetSize() == 0) { - state.AddEmptyString(); - } else { - throw FatalException("FSST: no encoder found even though there are values to encode"); - } - } - return; - } - - // Compress buffers - size_t compress_buffer_size = MaxValue(total_size * 2 + 7, 1); - vector strings_out(total_count, nullptr); - vector sizes_out(total_count, 0); - vector compress_buffer(compress_buffer_size, 0); - - auto res = duckdb_fsst_compress( - state.fsst_encoder, /* IN: encoder obtained from duckdb_fsst_create(). */ - total_count, /* IN: number of strings in batch to compress. */ - &sizes_in[0], /* IN: byte-lengths of the inputs */ - &strings_in[0], /* IN: input string start pointers. */ - compress_buffer_size, /* IN: byte-length of output buffer. */ - &compress_buffer[0], /* OUT: memory buffer to put the compressed strings in (one after the other). */ - &sizes_out[0], /* OUT: byte-lengths of the compressed strings. */ - &strings_out[0] /* OUT: output string start pointers. Will all point into [output,output+size). 
*/ - ); - - if (res != total_count) { - throw FatalException("FSST compression failed to compress all strings"); - } - - // Push the compressed strings to the compression state one by one - idx_t compressed_idx = 0; - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - if (!vdata.validity.RowIsValid(idx)) { - state.AddNull(); - } else if (data[idx].GetSize() == 0) { - state.AddEmptyString(); - } else { - state.UpdateState(data[idx], strings_out[compressed_idx], sizes_out[compressed_idx]); - compressed_idx++; - } - } -} - -void FSSTStorage::FinalizeCompress(CompressionState &state_p) { - auto &state = state_p.Cast(); - state.Flush(true); -} - -//===--------------------------------------------------------------------===// -// Scan -//===--------------------------------------------------------------------===// -struct FSSTScanState : public StringScanState { - FSSTScanState() { - ResetStoredDelta(); - } - - buffer_ptr duckdb_fsst_decoder; - bitpacking_width_t current_width; - - // To speed up delta decoding we store the last index - uint32_t last_known_index; - int64_t last_known_row; - - void StoreLastDelta(uint32_t value, int64_t row) { - last_known_index = value; - last_known_row = row; - } - void ResetStoredDelta() { - last_known_index = 0; - last_known_row = -1; - } -}; - -unique_ptr FSSTStorage::StringInitScan(ColumnSegment &segment) { - auto state = make_uniq(); - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - state->handle = buffer_manager.Pin(segment.block); - auto base_ptr = state->handle.Ptr() + segment.GetBlockOffset(); - - state->duckdb_fsst_decoder = make_buffer(); - auto retval = ParseFSSTSegmentHeader( - base_ptr, reinterpret_cast(state->duckdb_fsst_decoder.get()), &state->current_width); - if (!retval) { - state->duckdb_fsst_decoder = nullptr; - } - - return std::move(state); -} - -void DeltaDecodeIndices(uint32_t *buffer_in, uint32_t *buffer_out, idx_t decode_count, uint32_t last_known_value) { - 
buffer_out[0] = buffer_in[0]; - buffer_out[0] += last_known_value; - for (idx_t i = 1; i < decode_count; i++) { - buffer_out[i] = buffer_in[i] + buffer_out[i - 1]; - } -} - -void BitUnpackRange(data_ptr_t src_ptr, data_ptr_t dst_ptr, idx_t count, idx_t row, bitpacking_width_t width) { - auto bitunpack_src_ptr = &src_ptr[(row * width) / 8]; - BitpackingPrimitives::UnPackBuffer(dst_ptr, bitunpack_src_ptr, count, width); -} - -//===--------------------------------------------------------------------===// -// Scan base data -//===--------------------------------------------------------------------===// -template -void FSSTStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, - idx_t result_offset) { - - auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); - - bool enable_fsst_vectors; - if (ALLOW_FSST_VECTORS) { - auto &config = DBConfig::GetConfig(segment.db); - enable_fsst_vectors = config.options.enable_fsst_vectors; - } else { - enable_fsst_vectors = false; - } - - auto baseptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); - auto dict = GetDictionary(segment, scan_state.handle); - auto base_data = data_ptr_cast(baseptr + sizeof(fsst_compression_header_t)); - string_t *result_data; - - if (scan_count == 0) { - return; - } - - if (enable_fsst_vectors) { - D_ASSERT(result_offset == 0); - if (scan_state.duckdb_fsst_decoder) { - D_ASSERT(result_offset == 0 || result.GetVectorType() == VectorType::FSST_VECTOR); - result.SetVectorType(VectorType::FSST_VECTOR); - FSSTVector::RegisterDecoder(result, scan_state.duckdb_fsst_decoder); - result_data = FSSTVector::GetCompressedData(result); - } else { - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - result_data = FlatVector::GetData(result); - } - } else { - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - result_data = FlatVector::GetData(result); - } - - if (start == 0 || 
scan_state.last_known_row >= (int64_t)start) { - scan_state.ResetStoredDelta(); - } - - auto offsets = CalculateBpDeltaOffsets(scan_state.last_known_row, start, scan_count); - - auto bitunpack_buffer = unique_ptr(new uint32_t[offsets.total_bitunpack_count]); - BitUnpackRange(base_data, data_ptr_cast(bitunpack_buffer.get()), offsets.total_bitunpack_count, - offsets.bitunpack_start_row, scan_state.current_width); - auto delta_decode_buffer = unique_ptr(new uint32_t[offsets.total_delta_decode_count]); - DeltaDecodeIndices(bitunpack_buffer.get() + offsets.bitunpack_alignment_offset, delta_decode_buffer.get(), - offsets.total_delta_decode_count, scan_state.last_known_index); - - if (enable_fsst_vectors) { - // Lookup decompressed offsets in dict - for (idx_t i = 0; i < scan_count; i++) { - uint32_t string_length = bitunpack_buffer[i + offsets.scan_offset]; - result_data[i] = UncompressedStringStorage::FetchStringFromDict( - segment, dict, result, baseptr, delta_decode_buffer[i + offsets.unused_delta_decoded_values], - string_length); - FSSTVector::SetCount(result, scan_count); - } - } else { - // Just decompress - for (idx_t i = 0; i < scan_count; i++) { - uint32_t str_len = bitunpack_buffer[i + offsets.scan_offset]; - auto str_ptr = FSSTStorage::FetchStringPointer( - dict, baseptr, delta_decode_buffer[i + offsets.unused_delta_decoded_values]); - - if (str_len > 0) { - result_data[i + result_offset] = - FSSTPrimitives::DecompressValue(scan_state.duckdb_fsst_decoder.get(), result, str_ptr, str_len); - } else { - result_data[i + result_offset] = string_t(nullptr, 0); - } - } - } - - scan_state.StoreLastDelta(delta_decode_buffer[scan_count + offsets.unused_delta_decoded_values - 1], - start + scan_count - 1); -} - -void FSSTStorage::StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { - StringScanPartial(segment, state, scan_count, result, 0); -} - -//===--------------------------------------------------------------------===// -// 
Fetch -//===--------------------------------------------------------------------===// -void FSSTStorage::StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, - idx_t result_idx) { - - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); - auto base_ptr = handle.Ptr() + segment.GetBlockOffset(); - auto base_data = data_ptr_cast(base_ptr + sizeof(fsst_compression_header_t)); - auto dict = GetDictionary(segment, handle); - - duckdb_fsst_decoder_t decoder; - bitpacking_width_t width; - auto have_symbol_table = ParseFSSTSegmentHeader(base_ptr, &decoder, &width); - - auto result_data = FlatVector::GetData(result); - - if (have_symbol_table) { - // We basically just do a scan of 1 which is kinda expensive as we need to repeatedly delta decode until we - // reach the row we want, we could consider a more clever caching trick if this is slow - auto offsets = CalculateBpDeltaOffsets(-1, row_id, 1); - - auto bitunpack_buffer = unique_ptr(new uint32_t[offsets.total_bitunpack_count]); - BitUnpackRange(base_data, data_ptr_cast(bitunpack_buffer.get()), offsets.total_bitunpack_count, - offsets.bitunpack_start_row, width); - auto delta_decode_buffer = unique_ptr(new uint32_t[offsets.total_delta_decode_count]); - DeltaDecodeIndices(bitunpack_buffer.get() + offsets.bitunpack_alignment_offset, delta_decode_buffer.get(), - offsets.total_delta_decode_count, 0); - - uint32_t string_length = bitunpack_buffer[offsets.scan_offset]; - - string_t compressed_string = UncompressedStringStorage::FetchStringFromDict( - segment, dict, result, base_ptr, delta_decode_buffer[offsets.unused_delta_decoded_values], string_length); - - result_data[result_idx] = FSSTPrimitives::DecompressValue((void *)&decoder, result, compressed_string.GetData(), - compressed_string.GetSize()); - } else { - // There's no fsst symtable, this only happens for empty strings or nulls, we can just emit an empty string - 
result_data[result_idx] = string_t(nullptr, 0); - } -} - -//===--------------------------------------------------------------------===// -// Get Function -//===--------------------------------------------------------------------===// -CompressionFunction FSSTFun::GetFunction(PhysicalType data_type) { - D_ASSERT(data_type == PhysicalType::VARCHAR); - return CompressionFunction( - CompressionType::COMPRESSION_FSST, data_type, FSSTStorage::StringInitAnalyze, FSSTStorage::StringAnalyze, - FSSTStorage::StringFinalAnalyze, FSSTStorage::InitCompression, FSSTStorage::Compress, - FSSTStorage::FinalizeCompress, FSSTStorage::StringInitScan, FSSTStorage::StringScan, - FSSTStorage::StringScanPartial, FSSTStorage::StringFetchRow, UncompressedFunctions::EmptySkip); -} - -bool FSSTFun::TypeIsSupported(PhysicalType type) { - return type == PhysicalType::VARCHAR; -} - -//===--------------------------------------------------------------------===// -// Helper Functions -//===--------------------------------------------------------------------===// -void FSSTStorage::SetDictionary(ColumnSegment &segment, BufferHandle &handle, StringDictionaryContainer container) { - auto header_ptr = reinterpret_cast(handle.Ptr() + segment.GetBlockOffset()); - Store(container.size, data_ptr_cast(&header_ptr->dict_size)); - Store(container.end, data_ptr_cast(&header_ptr->dict_end)); -} - -StringDictionaryContainer FSSTStorage::GetDictionary(ColumnSegment &segment, BufferHandle &handle) { - auto header_ptr = reinterpret_cast(handle.Ptr() + segment.GetBlockOffset()); - StringDictionaryContainer container; - container.size = Load(data_ptr_cast(&header_ptr->dict_size)); - container.end = Load(data_ptr_cast(&header_ptr->dict_end)); - return container; -} - -char *FSSTStorage::FetchStringPointer(StringDictionaryContainer dict, data_ptr_t baseptr, int32_t dict_offset) { - if (dict_offset == 0) { - return nullptr; - } - - auto dict_end = baseptr + dict.end; - auto dict_pos = dict_end - dict_offset; - return 
char_ptr_cast(dict_pos); -} - -// Returns false if no symbol table was found. This means all strings are either empty or null -bool FSSTStorage::ParseFSSTSegmentHeader(data_ptr_t base_ptr, duckdb_fsst_decoder_t *decoder_out, - bitpacking_width_t *width_out) { - auto header_ptr = reinterpret_cast(base_ptr); - auto fsst_symbol_table_offset = Load(data_ptr_cast(&header_ptr->fsst_symbol_table_offset)); - *width_out = (bitpacking_width_t)(Load(data_ptr_cast(&header_ptr->bitpacking_width))); - return duckdb_fsst_import(decoder_out, base_ptr + fsst_symbol_table_offset); -} - -// The calculation of offsets and counts while scanning or fetching is a bit tricky, for two reasons: -// - bitunpacking needs to be aligned to BITPACKING_ALGORITHM_GROUP_SIZE -// - delta decoding needs to decode from the last known value. -bp_delta_offsets_t FSSTStorage::CalculateBpDeltaOffsets(int64_t last_known_row, idx_t start, idx_t scan_count) { - D_ASSERT((idx_t)(last_known_row + 1) <= start); - bp_delta_offsets_t result; - - result.delta_decode_start_row = (idx_t)(last_known_row + 1); - result.bitunpack_alignment_offset = - result.delta_decode_start_row % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; - result.bitunpack_start_row = result.delta_decode_start_row - result.bitunpack_alignment_offset; - result.unused_delta_decoded_values = start - result.delta_decode_start_row; - result.scan_offset = result.bitunpack_alignment_offset + result.unused_delta_decoded_values; - result.total_delta_decode_count = scan_count + result.unused_delta_decoded_values; - result.total_bitunpack_count = - BitpackingPrimitives::RoundUpToAlgorithmGroupSize(scan_count + result.scan_offset); - - D_ASSERT(result.total_delta_decode_count + result.bitunpack_alignment_offset <= result.total_bitunpack_count); - return result; -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Scan 
-//===--------------------------------------------------------------------===// -unique_ptr ConstantInitScan(ColumnSegment &segment) { - return nullptr; -} - -//===--------------------------------------------------------------------===// -// Scan Partial -//===--------------------------------------------------------------------===// -void ConstantFillFunctionValidity(ColumnSegment &segment, Vector &result, idx_t start_idx, idx_t count) { - auto &stats = segment.stats.statistics; - if (stats.CanHaveNull()) { - auto &mask = FlatVector::Validity(result); - for (idx_t i = 0; i < count; i++) { - mask.SetInvalid(start_idx + i); - } - } -} - -template -void ConstantFillFunction(ColumnSegment &segment, Vector &result, idx_t start_idx, idx_t count) { - auto &nstats = segment.stats.statistics; - - auto data = FlatVector::GetData(result); - auto constant_value = NumericStats::GetMin(nstats); - for (idx_t i = 0; i < count; i++) { - data[start_idx + i] = constant_value; - } -} - -void ConstantScanPartialValidity(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, - idx_t result_offset) { - ConstantFillFunctionValidity(segment, result, result_offset, scan_count); -} - -template -void ConstantScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, - idx_t result_offset) { - ConstantFillFunction(segment, result, result_offset, scan_count); -} - -//===--------------------------------------------------------------------===// -// Scan base data -//===--------------------------------------------------------------------===// -void ConstantScanFunctionValidity(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { - auto &stats = segment.stats.statistics; - if (stats.CanHaveNull()) { - if (result.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - result.Flatten(scan_count); - 
ConstantFillFunctionValidity(segment, result, 0, scan_count); - } - } -} - -template -void ConstantScanFunction(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { - auto &nstats = segment.stats.statistics; - - auto data = FlatVector::GetData(result); - data[0] = NumericStats::GetMin(nstats); - result.SetVectorType(VectorType::CONSTANT_VECTOR); -} - -//===--------------------------------------------------------------------===// -// Fetch -//===--------------------------------------------------------------------===// -void ConstantFetchRowValidity(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, - idx_t result_idx) { - ConstantFillFunctionValidity(segment, result, result_idx, 1); -} - -template -void ConstantFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - ConstantFillFunction(segment, result, result_idx, 1); -} - -//===--------------------------------------------------------------------===// -// Get Function -//===--------------------------------------------------------------------===// -CompressionFunction ConstantGetFunctionValidity(PhysicalType data_type) { - D_ASSERT(data_type == PhysicalType::BIT); - return CompressionFunction(CompressionType::COMPRESSION_CONSTANT, data_type, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, ConstantInitScan, ConstantScanFunctionValidity, - ConstantScanPartialValidity, ConstantFetchRowValidity, UncompressedFunctions::EmptySkip); -} - -template -CompressionFunction ConstantGetFunction(PhysicalType data_type) { - return CompressionFunction(CompressionType::COMPRESSION_CONSTANT, data_type, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, ConstantInitScan, ConstantScanFunction, ConstantScanPartial, - ConstantFetchRow, UncompressedFunctions::EmptySkip); -} - -CompressionFunction ConstantFun::GetFunction(PhysicalType data_type) { - switch (data_type) { - case PhysicalType::BIT: - return 
ConstantGetFunctionValidity(data_type); - case PhysicalType::BOOL: - case PhysicalType::INT8: - return ConstantGetFunction(data_type); - case PhysicalType::INT16: - return ConstantGetFunction(data_type); - case PhysicalType::INT32: - return ConstantGetFunction(data_type); - case PhysicalType::INT64: - return ConstantGetFunction(data_type); - case PhysicalType::UINT8: - return ConstantGetFunction(data_type); - case PhysicalType::UINT16: - return ConstantGetFunction(data_type); - case PhysicalType::UINT32: - return ConstantGetFunction(data_type); - case PhysicalType::UINT64: - return ConstantGetFunction(data_type); - case PhysicalType::INT128: - return ConstantGetFunction(data_type); - case PhysicalType::FLOAT: - return ConstantGetFunction(data_type); - case PhysicalType::DOUBLE: - return ConstantGetFunction(data_type); - default: - throw InternalException("Unsupported type for ConstantUncompressed::GetFunction"); - } -} - -bool ConstantFun::TypeIsSupported(PhysicalType type) { - switch (type) { - case PhysicalType::BIT: - case PhysicalType::BOOL: - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - case PhysicalType::UINT8: - case PhysicalType::UINT16: - case PhysicalType::UINT32: - case PhysicalType::UINT64: - case PhysicalType::INT128: - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - return true; - default: - throw InternalException("Unsupported type for constant function"); - } -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - -#include - -namespace duckdb { - -template -CompressionFunction GetPatasFunction(PhysicalType data_type) { - throw NotImplementedException("GetPatasFunction not implemented for the given datatype"); -} - -template <> -CompressionFunction GetPatasFunction(PhysicalType data_type) { - return CompressionFunction(CompressionType::COMPRESSION_PATAS, data_type, PatasInitAnalyze, - PatasAnalyze, PatasFinalAnalyze, PatasInitCompression, - PatasCompress, 
PatasFinalizeCompress, PatasInitScan, - PatasScan, PatasScanPartial, PatasFetchRow, PatasSkip); -} - -template <> -CompressionFunction GetPatasFunction(PhysicalType data_type) { - return CompressionFunction(CompressionType::COMPRESSION_PATAS, data_type, PatasInitAnalyze, - PatasAnalyze, PatasFinalAnalyze, PatasInitCompression, - PatasCompress, PatasFinalizeCompress, PatasInitScan, - PatasScan, PatasScanPartial, PatasFetchRow, PatasSkip); -} - -CompressionFunction PatasCompressionFun::GetFunction(PhysicalType type) { - switch (type) { - case PhysicalType::FLOAT: - return GetPatasFunction(type); - case PhysicalType::DOUBLE: - return GetPatasFunction(type); - default: - throw InternalException("Unsupported type for Patas"); - } -} - -bool PatasCompressionFun::TypeIsSupported(PhysicalType type) { - switch (type) { - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - return true; - default: - return false; - } -} - -} // namespace duckdb - - - - - - - - - -#include - -namespace duckdb { - -using rle_count_t = uint16_t; - -//===--------------------------------------------------------------------===// -// Analyze -//===--------------------------------------------------------------------===// -struct EmptyRLEWriter { - template - static void Operation(VALUE_TYPE value, rle_count_t count, void *dataptr, bool is_null) { - } -}; - -template -struct RLEState { - RLEState() : seen_count(0), last_value(NullValue()), last_seen_count(0), dataptr(nullptr) { - } - - idx_t seen_count; - T last_value; - rle_count_t last_seen_count; - void *dataptr; - bool all_null = true; - -public: - template - void Flush() { - OP::template Operation(last_value, last_seen_count, dataptr, all_null); - } - - template - void Update(const T *data, ValidityMask &validity, idx_t idx) { - if (validity.RowIsValid(idx)) { - if (all_null) { - // no value seen yet - // assign the current value, and increment the seen_count - // note that we increment last_seen_count rather than setting it to 1 - // this 
is intentional: this is the first VALID value we see - // but it might not be the first value in case of nulls! - last_value = data[idx]; - seen_count++; - last_seen_count++; - all_null = false; - } else if (last_value == data[idx]) { - // the last value is identical to this value: increment the last_seen_count - last_seen_count++; - } else { - // the values are different - // issue the callback on the last value - Flush(); - - // increment the seen_count and put the new value into the RLE slot - last_value = data[idx]; - seen_count++; - last_seen_count = 1; - } - } else { - // NULL value: we merely increment the last_seen_count - last_seen_count++; - } - if (last_seen_count == NumericLimits::Maximum()) { - // we have seen the same value so many times in a row we are at the limit of what fits in our count - // write away the value and move to the next value - Flush(); - last_seen_count = 0; - seen_count++; - } - } -}; - -template -struct RLEAnalyzeState : public AnalyzeState { - RLEAnalyzeState() { - } - - RLEState state; -}; - -template -unique_ptr RLEInitAnalyze(ColumnData &col_data, PhysicalType type) { - return make_uniq>(); -} - -template -bool RLEAnalyze(AnalyzeState &state, Vector &input, idx_t count) { - auto &rle_state = state.template Cast>(); - UnifiedVectorFormat vdata; - input.ToUnifiedFormat(count, vdata); - - auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - rle_state.state.Update(data, vdata.validity, idx); - } - return true; -} - -template -idx_t RLEFinalAnalyze(AnalyzeState &state) { - auto &rle_state = state.template Cast>(); - return (sizeof(rle_count_t) + sizeof(T)) * rle_state.state.seen_count; -} - -//===--------------------------------------------------------------------===// -// Compress -//===--------------------------------------------------------------------===// -struct RLEConstants { - static constexpr const idx_t RLE_HEADER_SIZE = sizeof(uint64_t); -}; - 
-template -struct RLECompressState : public CompressionState { - struct RLEWriter { - template - static void Operation(VALUE_TYPE value, rle_count_t count, void *dataptr, bool is_null) { - auto state = reinterpret_cast *>(dataptr); - state->WriteValue(value, count, is_null); - } - }; - - static idx_t MaxRLECount() { - auto entry_size = sizeof(T) + sizeof(rle_count_t); - auto entry_count = (Storage::BLOCK_SIZE - RLEConstants::RLE_HEADER_SIZE) / entry_size; - auto max_vector_count = entry_count / STANDARD_VECTOR_SIZE; - return max_vector_count * STANDARD_VECTOR_SIZE; - } - - explicit RLECompressState(ColumnDataCheckpointer &checkpointer_p) - : checkpointer(checkpointer_p), - function(checkpointer.GetCompressionFunction(CompressionType::COMPRESSION_RLE)) { - CreateEmptySegment(checkpointer.GetRowGroup().start); - - state.dataptr = (void *)this; - max_rle_count = MaxRLECount(); - } - - void CreateEmptySegment(idx_t row_start) { - auto &db = checkpointer.GetDatabase(); - auto &type = checkpointer.GetType(); - auto column_segment = ColumnSegment::CreateTransientSegment(db, type, row_start); - column_segment->function = function; - current_segment = std::move(column_segment); - auto &buffer_manager = BufferManager::GetBufferManager(db); - handle = buffer_manager.Pin(current_segment->block); - } - - void Append(UnifiedVectorFormat &vdata, idx_t count) { - auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - state.template Update::RLEWriter>(data, vdata.validity, idx); - } - } - - void WriteValue(T value, rle_count_t count, bool is_null) { - // write the RLE entry - auto handle_ptr = handle.Ptr() + RLEConstants::RLE_HEADER_SIZE; - auto data_pointer = reinterpret_cast(handle_ptr); - auto index_pointer = reinterpret_cast(handle_ptr + max_rle_count * sizeof(T)); - data_pointer[entry_count] = value; - index_pointer[entry_count] = count; - entry_count++; - - // update meta data - if (WRITE_STATISTICS && 
!is_null) { - NumericStats::Update(current_segment->stats.statistics, value); - } - current_segment->count += count; - - if (entry_count == max_rle_count) { - // we have finished writing this segment: flush it and create a new segment - auto row_start = current_segment->start + current_segment->count; - FlushSegment(); - CreateEmptySegment(row_start); - entry_count = 0; - } - } - - void FlushSegment() { - // flush the segment - // we compact the segment by moving the counts so they are directly next to the values - idx_t counts_size = sizeof(rle_count_t) * entry_count; - idx_t original_rle_offset = RLEConstants::RLE_HEADER_SIZE + max_rle_count * sizeof(T); - idx_t minimal_rle_offset = AlignValue(RLEConstants::RLE_HEADER_SIZE + sizeof(T) * entry_count); - idx_t total_segment_size = minimal_rle_offset + counts_size; - auto data_ptr = handle.Ptr(); - memmove(data_ptr + minimal_rle_offset, data_ptr + original_rle_offset, counts_size); - // store the final RLE offset within the segment - Store(minimal_rle_offset, data_ptr); - handle.Destroy(); - - auto &state = checkpointer.GetCheckpointState(); - state.FlushSegment(std::move(current_segment), total_segment_size); - } - - void Finalize() { - state.template Flush::RLEWriter>(); - - FlushSegment(); - current_segment.reset(); - } - - ColumnDataCheckpointer &checkpointer; - CompressionFunction &function; - unique_ptr current_segment; - BufferHandle handle; - - RLEState state; - idx_t entry_count = 0; - idx_t max_rle_count; -}; - -template -unique_ptr RLEInitCompression(ColumnDataCheckpointer &checkpointer, unique_ptr state) { - return make_uniq>(checkpointer); -} - -template -void RLECompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { - auto &state = (RLECompressState &)state_p; - UnifiedVectorFormat vdata; - scan_vector.ToUnifiedFormat(count, vdata); - - state.Append(vdata, count); -} - -template -void RLEFinalizeCompress(CompressionState &state_p) { - auto &state = (RLECompressState &)state_p; - 
state.Finalize(); -} - -//===--------------------------------------------------------------------===// -// Scan -//===--------------------------------------------------------------------===// -template -struct RLEScanState : public SegmentScanState { - explicit RLEScanState(ColumnSegment &segment) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - handle = buffer_manager.Pin(segment.block); - entry_pos = 0; - position_in_entry = 0; - rle_count_offset = Load(handle.Ptr() + segment.GetBlockOffset()); - D_ASSERT(rle_count_offset <= Storage::BLOCK_SIZE); - } - - void Skip(ColumnSegment &segment, idx_t skip_count) { - auto data = handle.Ptr() + segment.GetBlockOffset(); - auto index_pointer = reinterpret_cast(data + rle_count_offset); - - for (idx_t i = 0; i < skip_count; i++) { - // assign the current value - position_in_entry++; - if (position_in_entry >= index_pointer[entry_pos]) { - // handled all entries in this RLE value - // move to the next entry - entry_pos++; - position_in_entry = 0; - } - } - } - - BufferHandle handle; - idx_t entry_pos; - idx_t position_in_entry; - uint32_t rle_count_offset; -}; - -template -unique_ptr RLEInitScan(ColumnSegment &segment) { - auto result = make_uniq>(segment); - return std::move(result); -} - -//===--------------------------------------------------------------------===// -// Scan base data -//===--------------------------------------------------------------------===// -template -void RLESkip(ColumnSegment &segment, ColumnScanState &state, idx_t skip_count) { - auto &scan_state = state.scan_state->Cast>(); - scan_state.Skip(segment, skip_count); -} - -template -static bool CanEmitConstantVector(idx_t position, idx_t run_length, idx_t scan_count) { - if (!ENTIRE_VECTOR) { - return false; - } - if (scan_count != STANDARD_VECTOR_SIZE) { - // Only when we can fill an entire Vector can we emit a ConstantVector, because subsequent scans require the - // input Vector to be flat - return false; - } - 
D_ASSERT(position < run_length); - auto remaining_in_run = run_length - position; - // The amount of values left in this run are equal or greater than the amount of values we need to scan - return remaining_in_run >= scan_count; -} - -template -inline static void ForwardToNextRun(RLEScanState &scan_state) { - // handled all entries in this RLE value - // move to the next entry - scan_state.entry_pos++; - scan_state.position_in_entry = 0; -} - -template -inline static bool ExhaustedRun(RLEScanState &scan_state, rle_count_t *index_pointer) { - return scan_state.position_in_entry >= index_pointer[scan_state.entry_pos]; -} - -template -static void RLEScanConstant(RLEScanState &scan_state, rle_count_t *index_pointer, T *data_pointer, idx_t scan_count, - Vector &result) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - auto result_data = ConstantVector::GetData(result); - result_data[0] = data_pointer[scan_state.entry_pos]; - scan_state.position_in_entry += scan_count; - if (ExhaustedRun(scan_state, index_pointer)) { - ForwardToNextRun(scan_state); - } - return; -} - -template -void RLEScanPartialInternal(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, - idx_t result_offset) { - auto &scan_state = state.scan_state->Cast>(); - - auto data = scan_state.handle.Ptr() + segment.GetBlockOffset(); - auto data_pointer = reinterpret_cast(data + RLEConstants::RLE_HEADER_SIZE); - auto index_pointer = reinterpret_cast(data + scan_state.rle_count_offset); - - // If we are scanning an entire Vector and it contains only a single run - if (CanEmitConstantVector(scan_state.position_in_entry, index_pointer[scan_state.entry_pos], - scan_count)) { - RLEScanConstant(scan_state, index_pointer, data_pointer, scan_count, result); - return; - } - - auto result_data = FlatVector::GetData(result); - result.SetVectorType(VectorType::FLAT_VECTOR); - for (idx_t i = 0; i < scan_count; i++) { - // assign the current value - result_data[result_offset + i] = 
data_pointer[scan_state.entry_pos]; - scan_state.position_in_entry++; - if (ExhaustedRun(scan_state, index_pointer)) { - ForwardToNextRun(scan_state); - } - } -} - -template -void RLEScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, - idx_t result_offset) { - return RLEScanPartialInternal(segment, state, scan_count, result, result_offset); -} - -template -void RLEScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { - RLEScanPartialInternal(segment, state, scan_count, result, 0); -} - -//===--------------------------------------------------------------------===// -// Fetch -//===--------------------------------------------------------------------===// -template -void RLEFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - RLEScanState scan_state(segment); - scan_state.Skip(segment, row_id); - - auto data = scan_state.handle.Ptr() + segment.GetBlockOffset(); - auto data_pointer = reinterpret_cast(data + RLEConstants::RLE_HEADER_SIZE); - auto result_data = FlatVector::GetData(result); - result_data[result_idx] = data_pointer[scan_state.entry_pos]; -} - -//===--------------------------------------------------------------------===// -// Get Function -//===--------------------------------------------------------------------===// -template -CompressionFunction GetRLEFunction(PhysicalType data_type) { - return CompressionFunction(CompressionType::COMPRESSION_RLE, data_type, RLEInitAnalyze, RLEAnalyze, - RLEFinalAnalyze, RLEInitCompression, - RLECompress, RLEFinalizeCompress, - RLEInitScan, RLEScan, RLEScanPartial, RLEFetchRow, RLESkip); -} - -CompressionFunction RLEFun::GetFunction(PhysicalType type) { - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return GetRLEFunction(type); - case PhysicalType::INT16: - return GetRLEFunction(type); - case PhysicalType::INT32: - return GetRLEFunction(type); - case 
PhysicalType::INT64: - return GetRLEFunction(type); - case PhysicalType::INT128: - return GetRLEFunction(type); - case PhysicalType::UINT8: - return GetRLEFunction(type); - case PhysicalType::UINT16: - return GetRLEFunction(type); - case PhysicalType::UINT32: - return GetRLEFunction(type); - case PhysicalType::UINT64: - return GetRLEFunction(type); - case PhysicalType::FLOAT: - return GetRLEFunction(type); - case PhysicalType::DOUBLE: - return GetRLEFunction(type); - case PhysicalType::LIST: - return GetRLEFunction(type); - default: - throw InternalException("Unsupported type for RLE"); - } -} - -bool RLEFun::TypeIsSupported(PhysicalType type) { - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - case PhysicalType::INT128: - case PhysicalType::UINT8: - case PhysicalType::UINT16: - case PhysicalType::UINT32: - case PhysicalType::UINT64: - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - case PhysicalType::LIST: - return true; - default: - return false; - } -} - -} // namespace duckdb - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Storage Class -//===--------------------------------------------------------------------===// -UncompressedStringSegmentState::~UncompressedStringSegmentState() { - while (head) { - // prevent deep recursion here - head = std::move(head->next); - } -} - -//===--------------------------------------------------------------------===// -// Analyze -//===--------------------------------------------------------------------===// -struct StringAnalyzeState : public AnalyzeState { - StringAnalyzeState() : count(0), total_string_size(0), overflow_strings(0) { - } - - idx_t count; - idx_t total_string_size; - idx_t overflow_strings; -}; - -unique_ptr UncompressedStringStorage::StringInitAnalyze(ColumnData &col_data, PhysicalType type) { - return make_uniq(); -} - 
-bool UncompressedStringStorage::StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { - auto &state = state_p.Cast(); - UnifiedVectorFormat vdata; - input.ToUnifiedFormat(count, vdata); - - state.count += count; - auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - if (vdata.validity.RowIsValid(idx)) { - auto string_size = data[idx].GetSize(); - state.total_string_size += string_size; - if (string_size >= StringUncompressed::STRING_BLOCK_LIMIT) { - state.overflow_strings++; - } - } - } - return true; -} - -idx_t UncompressedStringStorage::StringFinalAnalyze(AnalyzeState &state_p) { - auto &state = state_p.Cast(); - return state.count * sizeof(int32_t) + state.total_string_size + state.overflow_strings * BIG_STRING_MARKER_SIZE; -} - -//===--------------------------------------------------------------------===// -// Scan -//===--------------------------------------------------------------------===// -unique_ptr UncompressedStringStorage::StringInitScan(ColumnSegment &segment) { - auto result = make_uniq(); - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - result->handle = buffer_manager.Pin(segment.block); - return std::move(result); -} - -//===--------------------------------------------------------------------===// -// Scan base data -//===--------------------------------------------------------------------===// -void UncompressedStringStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, - Vector &result, idx_t result_offset) { - // clear any previously locked buffers and get the primary buffer handle - auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); - - auto baseptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); - auto dict = GetDictionary(segment, scan_state.handle); - auto base_data = reinterpret_cast(baseptr + DICTIONARY_HEADER_SIZE); - auto 
result_data = FlatVector::GetData(result); - - int32_t previous_offset = start > 0 ? base_data[start - 1] : 0; - - for (idx_t i = 0; i < scan_count; i++) { - // std::abs used since offsets can be negative to indicate big strings - uint32_t string_length = std::abs(base_data[start + i]) - std::abs(previous_offset); - result_data[result_offset + i] = - FetchStringFromDict(segment, dict, result, baseptr, base_data[start + i], string_length); - previous_offset = base_data[start + i]; - } -} - -void UncompressedStringStorage::StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, - Vector &result) { - StringScanPartial(segment, state, scan_count, result, 0); -} - -//===--------------------------------------------------------------------===// -// Fetch -//===--------------------------------------------------------------------===// -BufferHandle &ColumnFetchState::GetOrInsertHandle(ColumnSegment &segment) { - auto primary_id = segment.block->BlockId(); - - auto entry = handles.find(primary_id); - if (entry == handles.end()) { - // not pinned yet: pin it - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); - auto entry = handles.insert(make_pair(primary_id, std::move(handle))); - return entry.first->second; - } else { - // already pinned: use the pinned handle - return entry->second; - } -} - -void UncompressedStringStorage::StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, - Vector &result, idx_t result_idx) { - // fetch a single row from the string segment - // first pin the main buffer if it is not already pinned - auto &handle = state.GetOrInsertHandle(segment); - - auto baseptr = handle.Ptr() + segment.GetBlockOffset(); - auto dict = GetDictionary(segment, handle); - auto base_data = reinterpret_cast(baseptr + DICTIONARY_HEADER_SIZE); - auto result_data = FlatVector::GetData(result); - - auto dict_offset = base_data[row_id]; - uint32_t string_length; - 
if ((idx_t)row_id == 0) { - // edge case where this is the first string in the dict - string_length = std::abs(dict_offset); - } else { - string_length = std::abs(dict_offset) - std::abs(base_data[row_id - 1]); - } - result_data[result_idx] = FetchStringFromDict(segment, dict, result, baseptr, dict_offset, string_length); -} - -//===--------------------------------------------------------------------===// -// Append -//===--------------------------------------------------------------------===// -struct SerializedStringSegmentState : public ColumnSegmentState { - SerializedStringSegmentState() { - } - explicit SerializedStringSegmentState(vector blocks_p) : blocks(std::move(blocks_p)) { - } - - vector blocks; - - void Serialize(Serializer &serializer) const override { - serializer.WriteProperty(1, "overflow_blocks", blocks); - } -}; - -unique_ptr -UncompressedStringStorage::StringInitSegment(ColumnSegment &segment, block_id_t block_id, - optional_ptr segment_state) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - if (block_id == INVALID_BLOCK) { - auto handle = buffer_manager.Pin(segment.block); - StringDictionaryContainer dictionary; - dictionary.size = 0; - dictionary.end = segment.SegmentSize(); - SetDictionary(segment, handle, dictionary); - } - auto result = make_uniq(); - if (segment_state) { - auto &serialized_state = segment_state->Cast(); - result->on_disk_blocks = std::move(serialized_state.blocks); - } - return std::move(result); -} - -idx_t UncompressedStringStorage::FinalizeAppend(ColumnSegment &segment, SegmentStatistics &stats) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); - auto dict = GetDictionary(segment, handle); - D_ASSERT(dict.end == segment.SegmentSize()); - // compute the total size required to store this segment - auto offset_size = DICTIONARY_HEADER_SIZE + segment.count * sizeof(int32_t); - auto total_size = offset_size + dict.size; - if 
(total_size >= COMPACTION_FLUSH_LIMIT) { - // the block is full enough, don't bother moving around the dictionary - return segment.SegmentSize(); - } - // the block has space left: figure out how much space we can save - auto move_amount = segment.SegmentSize() - total_size; - // move the dictionary so it lines up exactly with the offsets - auto dataptr = handle.Ptr(); - memmove(dataptr + offset_size, dataptr + dict.end - dict.size, dict.size); - dict.end -= move_amount; - D_ASSERT(dict.end == total_size); - // write the new dictionary (with the updated "end") - SetDictionary(segment, handle, dict); - return total_size; -} - -//===--------------------------------------------------------------------===// -// Serialization & Cleanup -//===--------------------------------------------------------------------===// -unique_ptr UncompressedStringStorage::SerializeState(ColumnSegment &segment) { - auto &state = segment.GetSegmentState()->Cast(); - if (state.on_disk_blocks.empty()) { - // no on-disk blocks - nothing to write - return nullptr; - } - return make_uniq(state.on_disk_blocks); -} - -unique_ptr UncompressedStringStorage::DeserializeState(Deserializer &deserializer) { - auto result = make_uniq(); - deserializer.ReadProperty(1, "overflow_blocks", result->blocks); - return std::move(result); -} - -void UncompressedStringStorage::CleanupState(ColumnSegment &segment) { - auto &state = segment.GetSegmentState()->Cast(); - auto &block_manager = segment.GetBlockManager(); - for (auto &block_id : state.on_disk_blocks) { - block_manager.MarkBlockAsModified(block_id); - } -} - -//===--------------------------------------------------------------------===// -// Get Function -//===--------------------------------------------------------------------===// -CompressionFunction StringUncompressed::GetFunction(PhysicalType data_type) { - D_ASSERT(data_type == PhysicalType::VARCHAR); - return CompressionFunction(CompressionType::COMPRESSION_UNCOMPRESSED, data_type, - 
UncompressedStringStorage::StringInitAnalyze, UncompressedStringStorage::StringAnalyze, - UncompressedStringStorage::StringFinalAnalyze, UncompressedFunctions::InitCompression, - UncompressedFunctions::Compress, UncompressedFunctions::FinalizeCompress, - UncompressedStringStorage::StringInitScan, UncompressedStringStorage::StringScan, - UncompressedStringStorage::StringScanPartial, UncompressedStringStorage::StringFetchRow, - UncompressedFunctions::EmptySkip, UncompressedStringStorage::StringInitSegment, - UncompressedStringStorage::StringInitAppend, UncompressedStringStorage::StringAppend, - UncompressedStringStorage::FinalizeAppend, nullptr, - UncompressedStringStorage::SerializeState, UncompressedStringStorage::DeserializeState, - UncompressedStringStorage::CleanupState); -} - -//===--------------------------------------------------------------------===// -// Helper Functions -//===--------------------------------------------------------------------===// -void UncompressedStringStorage::SetDictionary(ColumnSegment &segment, BufferHandle &handle, - StringDictionaryContainer container) { - auto startptr = handle.Ptr() + segment.GetBlockOffset(); - Store(container.size, startptr); - Store(container.end, startptr + sizeof(uint32_t)); -} - -StringDictionaryContainer UncompressedStringStorage::GetDictionary(ColumnSegment &segment, BufferHandle &handle) { - auto startptr = handle.Ptr() + segment.GetBlockOffset(); - StringDictionaryContainer container; - container.size = Load(startptr); - container.end = Load(startptr + sizeof(uint32_t)); - return container; -} - -idx_t UncompressedStringStorage::RemainingSpace(ColumnSegment &segment, BufferHandle &handle) { - auto dictionary = GetDictionary(segment, handle); - D_ASSERT(dictionary.end == segment.SegmentSize()); - idx_t used_space = dictionary.size + segment.count * sizeof(int32_t) + DICTIONARY_HEADER_SIZE; - D_ASSERT(segment.SegmentSize() >= used_space); - return segment.SegmentSize() - used_space; -} - -void 
UncompressedStringStorage::WriteString(ColumnSegment &segment, string_t string, block_id_t &result_block, - int32_t &result_offset) { - auto &state = segment.GetSegmentState()->Cast(); - if (state.overflow_writer) { - // overflow writer is set: write string there - state.overflow_writer->WriteString(state, string, result_block, result_offset); - } else { - // default overflow behavior: use in-memory buffer to store the overflow string - WriteStringMemory(segment, string, result_block, result_offset); - } -} - -void UncompressedStringStorage::WriteStringMemory(ColumnSegment &segment, string_t string, block_id_t &result_block, - int32_t &result_offset) { - uint32_t total_length = string.GetSize() + sizeof(uint32_t); - shared_ptr block; - BufferHandle handle; - - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto &state = segment.GetSegmentState()->Cast(); - // check if the string fits in the current block - if (!state.head || state.head->offset + total_length >= state.head->size) { - // string does not fit, allocate space for it - // create a new string block - idx_t alloc_size = MaxValue(total_length, Storage::BLOCK_SIZE); - auto new_block = make_uniq(); - new_block->offset = 0; - new_block->size = alloc_size; - // allocate an in-memory buffer for it - handle = buffer_manager.Allocate(alloc_size, false, &block); - state.overflow_blocks.insert(make_pair(block->BlockId(), reference(*new_block))); - new_block->block = std::move(block); - new_block->next = std::move(state.head); - state.head = std::move(new_block); - } else { - // string fits, copy it into the current block - handle = buffer_manager.Pin(state.head->block); - } - - result_block = state.head->block->BlockId(); - result_offset = state.head->offset; - - // copy the string and the length there - auto ptr = handle.Ptr() + state.head->offset; - Store(string.GetSize(), ptr); - ptr += sizeof(uint32_t); - memcpy(ptr, string.GetData(), string.GetSize()); - state.head->offset += 
total_length; -} - -string_t UncompressedStringStorage::ReadOverflowString(ColumnSegment &segment, Vector &result, block_id_t block, - int32_t offset) { - D_ASSERT(block != INVALID_BLOCK); - D_ASSERT(offset < Storage::BLOCK_SIZE); - - auto &block_manager = segment.GetBlockManager(); - auto &buffer_manager = block_manager.buffer_manager; - auto &state = segment.GetSegmentState()->Cast(); - if (block < MAXIMUM_BLOCK) { - // read the overflow string from disk - // pin the initial handle and read the length - auto block_handle = state.GetHandle(block_manager, block); - auto handle = buffer_manager.Pin(block_handle); - - // read header - uint32_t length = Load(handle.Ptr() + offset); - uint32_t remaining = length; - offset += sizeof(uint32_t); - - // allocate a buffer to store the string - auto alloc_size = MaxValue(Storage::BLOCK_SIZE, length); - // allocate a buffer to store the compressed string - // TODO: profile this to check if we need to reuse buffer - auto target_handle = buffer_manager.Allocate(alloc_size); - auto target_ptr = target_handle.Ptr(); - - // now append the string to the single buffer - while (remaining > 0) { - idx_t to_write = MinValue(remaining, Storage::BLOCK_SIZE - sizeof(block_id_t) - offset); - memcpy(target_ptr, handle.Ptr() + offset, to_write); - remaining -= to_write; - offset += to_write; - target_ptr += to_write; - if (remaining > 0) { - // read the next block - block_id_t next_block = Load(handle.Ptr() + offset); - block_handle = state.GetHandle(block_manager, next_block); - handle = buffer_manager.Pin(block_handle); - offset = 0; - } - } - - auto final_buffer = target_handle.Ptr(); - StringVector::AddHandle(result, std::move(target_handle)); - return ReadString(final_buffer, 0, length); - } else { - // read the overflow string from memory - // first pin the handle, if it is not pinned yet - auto entry = state.overflow_blocks.find(block); - D_ASSERT(entry != state.overflow_blocks.end()); - auto handle = 
buffer_manager.Pin(entry->second.get().block); - auto final_buffer = handle.Ptr(); - StringVector::AddHandle(result, std::move(handle)); - return ReadStringWithLength(final_buffer, offset); - } -} - -string_t UncompressedStringStorage::ReadString(data_ptr_t target, int32_t offset, uint32_t string_length) { - auto ptr = target + offset; - auto str_ptr = char_ptr_cast(ptr); - return string_t(str_ptr, string_length); -} - -string_t UncompressedStringStorage::ReadStringWithLength(data_ptr_t target, int32_t offset) { - auto ptr = target + offset; - auto str_length = Load(ptr); - auto str_ptr = char_ptr_cast(ptr + sizeof(uint32_t)); - return string_t(str_ptr, str_length); -} - -void UncompressedStringStorage::WriteStringMarker(data_ptr_t target, block_id_t block_id, int32_t offset) { - memcpy(target, &block_id, sizeof(block_id_t)); - target += sizeof(block_id_t); - memcpy(target, &offset, sizeof(int32_t)); -} - -void UncompressedStringStorage::ReadStringMarker(data_ptr_t target, block_id_t &block_id, int32_t &offset) { - memcpy(&block_id, target, sizeof(block_id_t)); - target += sizeof(block_id_t); - memcpy(&offset, target, sizeof(int32_t)); -} - -string_location_t UncompressedStringStorage::FetchStringLocation(StringDictionaryContainer dict, data_ptr_t baseptr, - int32_t dict_offset) { - D_ASSERT(dict_offset >= -1 * Storage::BLOCK_SIZE && dict_offset <= Storage::BLOCK_SIZE); - if (dict_offset < 0) { - string_location_t result; - ReadStringMarker(baseptr + dict.end - (-1 * dict_offset), result.block_id, result.offset); - return result; - } else { - return string_location_t(INVALID_BLOCK, dict_offset); - } -} - -string_t UncompressedStringStorage::FetchStringFromDict(ColumnSegment &segment, StringDictionaryContainer dict, - Vector &result, data_ptr_t baseptr, int32_t dict_offset, - uint32_t string_length) { - // fetch base data - D_ASSERT(dict_offset <= Storage::BLOCK_SIZE); - string_location_t location = FetchStringLocation(dict, baseptr, dict_offset); - return 
FetchString(segment, dict, result, baseptr, location, string_length); -} - -string_t UncompressedStringStorage::FetchString(ColumnSegment &segment, StringDictionaryContainer dict, Vector &result, - data_ptr_t baseptr, string_location_t location, - uint32_t string_length) { - if (location.block_id != INVALID_BLOCK) { - // big string marker: read from separate block - return ReadOverflowString(segment, result, location.block_id, location.offset); - } else { - if (location.offset == 0) { - return string_t(nullptr, 0); - } - // normal string: read string from this block - auto dict_end = baseptr + dict.end; - auto dict_pos = dict_end - location.offset; - - auto str_ptr = char_ptr_cast(dict_pos); - return string_t(str_ptr, string_length); - } -} - -} // namespace duckdb - - - -namespace duckdb { - -CompressionFunction UncompressedFun::GetFunction(PhysicalType type) { - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - case PhysicalType::INT128: - case PhysicalType::UINT8: - case PhysicalType::UINT16: - case PhysicalType::UINT32: - case PhysicalType::UINT64: - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - case PhysicalType::LIST: - case PhysicalType::INTERVAL: - return FixedSizeUncompressed::GetFunction(type); - case PhysicalType::BIT: - return ValidityUncompressed::GetFunction(type); - case PhysicalType::VARCHAR: - return StringUncompressed::GetFunction(type); - default: - throw InternalException("Unsupported type for Uncompressed"); - } -} - -bool UncompressedFun::TypeIsSupported(PhysicalType type) { - return true; -} - -} // namespace duckdb - - - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Mask constants -//===--------------------------------------------------------------------===// -// LOWER_MASKS contains masks with all the lower bits set until a specific value -// 
LOWER_MASKS[0] has the 0 lowest bits set, i.e.: -// 0b0000000000000000000000000000000000000000000000000000000000000000, -// LOWER_MASKS[10] has the 10 lowest bits set, i.e.: -// 0b0000000000000000000000000000000000000000000000000000000111111111, -// etc... -// 0b0000000000000000000000000000000000000001111111111111111111111111, -// ... -// 0b0000000000000000000001111111111111111111111111111111111111111111, -// until LOWER_MASKS[64], which has all bits set: -// 0b1111111111111111111111111111111111111111111111111111111111111111 -// generated with this python snippet: -// for i in range(65): -// print(hex(int((64 - i) * '0' + i * '1', 2)) + ",") -const validity_t ValidityUncompressed::LOWER_MASKS[] = {0x0, - 0x1, - 0x3, - 0x7, - 0xf, - 0x1f, - 0x3f, - 0x7f, - 0xff, - 0x1ff, - 0x3ff, - 0x7ff, - 0xfff, - 0x1fff, - 0x3fff, - 0x7fff, - 0xffff, - 0x1ffff, - 0x3ffff, - 0x7ffff, - 0xfffff, - 0x1fffff, - 0x3fffff, - 0x7fffff, - 0xffffff, - 0x1ffffff, - 0x3ffffff, - 0x7ffffff, - 0xfffffff, - 0x1fffffff, - 0x3fffffff, - 0x7fffffff, - 0xffffffff, - 0x1ffffffff, - 0x3ffffffff, - 0x7ffffffff, - 0xfffffffff, - 0x1fffffffff, - 0x3fffffffff, - 0x7fffffffff, - 0xffffffffff, - 0x1ffffffffff, - 0x3ffffffffff, - 0x7ffffffffff, - 0xfffffffffff, - 0x1fffffffffff, - 0x3fffffffffff, - 0x7fffffffffff, - 0xffffffffffff, - 0x1ffffffffffff, - 0x3ffffffffffff, - 0x7ffffffffffff, - 0xfffffffffffff, - 0x1fffffffffffff, - 0x3fffffffffffff, - 0x7fffffffffffff, - 0xffffffffffffff, - 0x1ffffffffffffff, - 0x3ffffffffffffff, - 0x7ffffffffffffff, - 0xfffffffffffffff, - 0x1fffffffffffffff, - 0x3fffffffffffffff, - 0x7fffffffffffffff, - 0xffffffffffffffff}; - -// UPPER_MASKS contains masks with all the highest bits set until a specific value -// UPPER_MASKS[0] has the 0 highest bits set, i.e.: -// 0b0000000000000000000000000000000000000000000000000000000000000000, -// UPPER_MASKS[10] has the 10 highest bits set, i.e.: -// 0b1111111111110000000000000000000000000000000000000000000000000000, -// etc... 
-// 0b1111111111111111111111110000000000000000000000000000000000000000, -// ... -// 0b1111111111111111111111111111111111111110000000000000000000000000, -// until UPPER_MASKS[64], which has all bits set: -// 0b1111111111111111111111111111111111111111111111111111111111111111 -// generated with this python snippet: -// for i in range(65): -// print(hex(int(i * '1' + (64 - i) * '0', 2)) + ",") -const validity_t ValidityUncompressed::UPPER_MASKS[] = {0x0, - 0x8000000000000000, - 0xc000000000000000, - 0xe000000000000000, - 0xf000000000000000, - 0xf800000000000000, - 0xfc00000000000000, - 0xfe00000000000000, - 0xff00000000000000, - 0xff80000000000000, - 0xffc0000000000000, - 0xffe0000000000000, - 0xfff0000000000000, - 0xfff8000000000000, - 0xfffc000000000000, - 0xfffe000000000000, - 0xffff000000000000, - 0xffff800000000000, - 0xffffc00000000000, - 0xffffe00000000000, - 0xfffff00000000000, - 0xfffff80000000000, - 0xfffffc0000000000, - 0xfffffe0000000000, - 0xffffff0000000000, - 0xffffff8000000000, - 0xffffffc000000000, - 0xffffffe000000000, - 0xfffffff000000000, - 0xfffffff800000000, - 0xfffffffc00000000, - 0xfffffffe00000000, - 0xffffffff00000000, - 0xffffffff80000000, - 0xffffffffc0000000, - 0xffffffffe0000000, - 0xfffffffff0000000, - 0xfffffffff8000000, - 0xfffffffffc000000, - 0xfffffffffe000000, - 0xffffffffff000000, - 0xffffffffff800000, - 0xffffffffffc00000, - 0xffffffffffe00000, - 0xfffffffffff00000, - 0xfffffffffff80000, - 0xfffffffffffc0000, - 0xfffffffffffe0000, - 0xffffffffffff0000, - 0xffffffffffff8000, - 0xffffffffffffc000, - 0xffffffffffffe000, - 0xfffffffffffff000, - 0xfffffffffffff800, - 0xfffffffffffffc00, - 0xfffffffffffffe00, - 0xffffffffffffff00, - 0xffffffffffffff80, - 0xffffffffffffffc0, - 0xffffffffffffffe0, - 0xfffffffffffffff0, - 0xfffffffffffffff8, - 0xfffffffffffffffc, - 0xfffffffffffffffe, - 0xffffffffffffffff}; - -//===--------------------------------------------------------------------===// -// Analyze 
-//===--------------------------------------------------------------------===// -struct ValidityAnalyzeState : public AnalyzeState { - ValidityAnalyzeState() : count(0) { - } - - idx_t count; -}; - -unique_ptr ValidityInitAnalyze(ColumnData &col_data, PhysicalType type) { - return make_uniq(); -} - -bool ValidityAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { - auto &state = state_p.Cast(); - state.count += count; - return true; -} - -idx_t ValidityFinalAnalyze(AnalyzeState &state_p) { - auto &state = state_p.Cast(); - return (state.count + 7) / 8; -} - -//===--------------------------------------------------------------------===// -// Scan -//===--------------------------------------------------------------------===// -struct ValidityScanState : public SegmentScanState { - BufferHandle handle; - block_id_t block_id; -}; - -unique_ptr ValidityInitScan(ColumnSegment &segment) { - auto result = make_uniq(); - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - result->handle = buffer_manager.Pin(segment.block); - result->block_id = segment.block->BlockId(); - return std::move(result); -} - -//===--------------------------------------------------------------------===// -// Scan base data -//===--------------------------------------------------------------------===// -void ValidityScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, - idx_t result_offset) { - auto start = segment.GetRelativeIndex(state.row_index); - - static_assert(sizeof(validity_t) == sizeof(uint64_t), "validity_t should be 64-bit"); - auto &scan_state = state.scan_state->Cast(); - - auto &result_mask = FlatVector::Validity(result); - auto buffer_ptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); - D_ASSERT(scan_state.block_id == segment.block->BlockId()); - auto input_data = reinterpret_cast(buffer_ptr); - -#ifdef DEBUG - // this method relies on all the bits we are going to write to being set to valid - for (idx_t i = 0; 
i < scan_count; i++) { - D_ASSERT(result_mask.RowIsValid(result_offset + i)); - } -#endif -#if STANDARD_VECTOR_SIZE < 128 - // fallback for tiny vector sizes - // the bitwise ops we use below don't work if the vector size is too small - ValidityMask source_mask(input_data); - for (idx_t i = 0; i < scan_count; i++) { - if (!source_mask.RowIsValid(start + i)) { - if (result_mask.AllValid()) { - result_mask.Initialize(MaxValue(STANDARD_VECTOR_SIZE, result_offset + scan_count)); - } - result_mask.SetInvalid(result_offset + i); - } - } -#else - // the code below does what the fallback code above states, but using bitwise ops: - auto result_data = (validity_t *)result_mask.GetData(); - - // set up the initial positions - // we need to find the validity_entry to modify, together with the bit-index WITHIN the validity entry - idx_t result_entry = result_offset / ValidityMask::BITS_PER_VALUE; - idx_t result_idx = result_offset - result_entry * ValidityMask::BITS_PER_VALUE; - - // same for the input: find the validity_entry we are pulling from, together with the bit-index WITHIN that entry - idx_t input_entry = start / ValidityMask::BITS_PER_VALUE; - idx_t input_idx = start - input_entry * ValidityMask::BITS_PER_VALUE; - - // now start the bit games - idx_t pos = 0; - while (pos < scan_count) { - // these are the current validity entries we are dealing with - idx_t current_result_idx = result_entry; - idx_t offset; - validity_t input_mask = input_data[input_entry]; - - // construct the mask to AND together with the result - if (result_idx < input_idx) { - // we have to shift the input RIGHT if the result_idx is smaller than the input_idx - auto shift_amount = input_idx - result_idx; - D_ASSERT(shift_amount > 0 && shift_amount <= ValidityMask::BITS_PER_VALUE); - - input_mask = input_mask >> shift_amount; - - // now the upper "shift_amount" bits are set to 0 - // we need them to be set to 1 - // otherwise the subsequent bitwise & will modify values outside of the range of 
values we want to alter - input_mask |= ValidityUncompressed::UPPER_MASKS[shift_amount]; - - // after this, we move to the next input_entry - offset = ValidityMask::BITS_PER_VALUE - input_idx; - input_entry++; - input_idx = 0; - result_idx += offset; - } else if (result_idx > input_idx) { - // we have to shift the input LEFT if the result_idx is bigger than the input_idx - auto shift_amount = result_idx - input_idx; - D_ASSERT(shift_amount > 0 && shift_amount <= ValidityMask::BITS_PER_VALUE); - - // to avoid overflows, we set the upper "shift_amount" values to 0 first - input_mask = (input_mask & ~ValidityUncompressed::UPPER_MASKS[shift_amount]) << shift_amount; - - // now the lower "shift_amount" bits are set to 0 - // we need them to be set to 1 - // otherwise the subsequent bitwise & will modify values outside of the range of values we want to alter - input_mask |= ValidityUncompressed::LOWER_MASKS[shift_amount]; - - // after this, we move to the next result_entry - offset = ValidityMask::BITS_PER_VALUE - result_idx; - result_entry++; - result_idx = 0; - input_idx += offset; - } else { - // if the input_idx is equal to result_idx they are already aligned - // we just move to the next entry for both after this - offset = ValidityMask::BITS_PER_VALUE - result_idx; - input_entry++; - result_entry++; - result_idx = input_idx = 0; - } - // now we need to check if we should include the ENTIRE mask - // OR if we need to mask from the right side - pos += offset; - if (pos > scan_count) { - // we need to set any bits that are past the scan_count on the right-side to 1 - // this is required so we don't influence any bits that are not part of the scan - input_mask |= ValidityUncompressed::UPPER_MASKS[pos - scan_count]; - } - // now finally we can merge the input mask with the result mask - if (input_mask != ValidityMask::ValidityBuffer::MAX_ENTRY) { - if (!result_data) { - result_mask.Initialize(MaxValue(STANDARD_VECTOR_SIZE, result_offset + scan_count)); - result_data = 
(validity_t *)result_mask.GetData(); - } - result_data[current_result_idx] &= input_mask; - } - } -#endif - -#ifdef DEBUG - // verify that we actually accomplished the bitwise ops equivalent that we wanted to do - ValidityMask input_mask(input_data); - for (idx_t i = 0; i < scan_count; i++) { - D_ASSERT(result_mask.RowIsValid(result_offset + i) == input_mask.RowIsValid(start + i)); - } -#endif -} - -void ValidityScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { - result.Flatten(scan_count); - - auto start = segment.GetRelativeIndex(state.row_index); - if (start % ValidityMask::BITS_PER_VALUE == 0) { - auto &scan_state = state.scan_state->Cast(); - - // aligned scan: no need to do anything fancy - // note: this is only an optimization which avoids having to do messy bitshifting in the common case - // it is not required for correctness - auto &result_mask = FlatVector::Validity(result); - auto buffer_ptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); - D_ASSERT(scan_state.block_id == segment.block->BlockId()); - auto input_data = reinterpret_cast(buffer_ptr); - auto result_data = result_mask.GetData(); - idx_t start_offset = start / ValidityMask::BITS_PER_VALUE; - idx_t entry_scan_count = (scan_count + ValidityMask::BITS_PER_VALUE - 1) / ValidityMask::BITS_PER_VALUE; - for (idx_t i = 0; i < entry_scan_count; i++) { - auto input_entry = input_data[start_offset + i]; - if (!result_data && input_entry == ValidityMask::ValidityBuffer::MAX_ENTRY) { - continue; - } - if (!result_data) { - result_mask.Initialize(MaxValue(STANDARD_VECTOR_SIZE, scan_count)); - result_data = result_mask.GetData(); - } - result_data[i] = input_entry; - } - } else { - // unaligned scan: fall back to scan_partial which does bitshift tricks - ValidityScanPartial(segment, state, scan_count, result, 0); - } -} - -//===--------------------------------------------------------------------===// -// Fetch 
-//===--------------------------------------------------------------------===// -void ValidityFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - D_ASSERT(row_id >= 0 && row_id < row_t(segment.count)); - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); - auto dataptr = handle.Ptr() + segment.GetBlockOffset(); - ValidityMask mask(reinterpret_cast(dataptr)); - auto &result_mask = FlatVector::Validity(result); - if (!mask.RowIsValidUnsafe(row_id)) { - result_mask.SetInvalid(result_idx); - } -} - -//===--------------------------------------------------------------------===// -// Append -//===--------------------------------------------------------------------===// -static unique_ptr ValidityInitAppend(ColumnSegment &segment) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); - return make_uniq(std::move(handle)); -} - -unique_ptr ValidityInitSegment(ColumnSegment &segment, block_id_t block_id, - optional_ptr segment_state) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - if (block_id == INVALID_BLOCK) { - auto handle = buffer_manager.Pin(segment.block); - memset(handle.Ptr(), 0xFF, segment.SegmentSize()); - } - return nullptr; -} - -idx_t ValidityAppend(CompressionAppendState &append_state, ColumnSegment &segment, SegmentStatistics &stats, - UnifiedVectorFormat &data, idx_t offset, idx_t vcount) { - D_ASSERT(segment.GetBlockOffset() == 0); - auto &validity_stats = stats.statistics; - - auto max_tuples = segment.SegmentSize() / ValidityMask::STANDARD_MASK_SIZE * STANDARD_VECTOR_SIZE; - idx_t append_count = MinValue(vcount, max_tuples - segment.count); - if (data.validity.AllValid()) { - // no null values: skip append - segment.count += append_count; - validity_stats.SetHasNoNull(); - return append_count; - } - - ValidityMask 
mask(reinterpret_cast(append_state.handle.Ptr())); - for (idx_t i = 0; i < append_count; i++) { - auto idx = data.sel->get_index(offset + i); - if (!data.validity.RowIsValidUnsafe(idx)) { - mask.SetInvalidUnsafe(segment.count + i); - validity_stats.SetHasNull(); - } else { - validity_stats.SetHasNoNull(); - } - } - segment.count += append_count; - return append_count; -} - -idx_t ValidityFinalizeAppend(ColumnSegment &segment, SegmentStatistics &stats) { - return ((segment.count + STANDARD_VECTOR_SIZE - 1) / STANDARD_VECTOR_SIZE) * ValidityMask::STANDARD_MASK_SIZE; -} - -void ValidityRevertAppend(ColumnSegment &segment, idx_t start_row) { - idx_t start_bit = start_row - segment.start; - - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); - idx_t revert_start; - if (start_bit % 8 != 0) { - // handle sub-bit stuff (yay) - idx_t byte_pos = start_bit / 8; - idx_t bit_end = (byte_pos + 1) * 8; - ValidityMask mask(reinterpret_cast(handle.Ptr())); - for (idx_t i = start_bit; i < bit_end; i++) { - mask.SetValid(i); - } - revert_start = bit_end / 8; - } else { - revert_start = start_bit / 8; - } - // for the rest, we just memset - memset(handle.Ptr() + revert_start, 0xFF, segment.SegmentSize() - revert_start); -} - -//===--------------------------------------------------------------------===// -// Get Function -//===--------------------------------------------------------------------===// -CompressionFunction ValidityUncompressed::GetFunction(PhysicalType data_type) { - D_ASSERT(data_type == PhysicalType::BIT); - return CompressionFunction(CompressionType::COMPRESSION_UNCOMPRESSED, data_type, ValidityInitAnalyze, - ValidityAnalyze, ValidityFinalAnalyze, UncompressedFunctions::InitCompression, - UncompressedFunctions::Compress, UncompressedFunctions::FinalizeCompress, - ValidityInitScan, ValidityScan, ValidityScanPartial, ValidityFetchRow, - UncompressedFunctions::EmptySkip, ValidityInitSegment, 
ValidityInitAppend, - ValidityAppend, ValidityFinalizeAppend, ValidityRevertAppend); -} - -} // namespace duckdb - - - - - - -namespace duckdb { - -unique_ptr ColumnSegmentState::Deserialize(Deserializer &deserializer) { - auto compression_type = deserializer.Get(); - auto &db = deserializer.Get(); - auto &type = deserializer.Get(); - auto compression_function = DBConfig::GetConfig(db).GetCompressionFunction(compression_type, type.InternalType()); - if (!compression_function || !compression_function->deserialize_state) { - throw SerializationException("Deserializing a ColumnSegmentState but could not find deserialize method"); - } - return compression_function->deserialize_state(deserializer); -} - -} // namespace duckdb - - - - - - - - - - - - - - - - - - - - - - - - - - - - -namespace duckdb { - -DataTableInfo::DataTableInfo(AttachedDatabase &db, shared_ptr table_io_manager_p, string schema, - string table) - : db(db), table_io_manager(std::move(table_io_manager_p)), cardinality(0), schema(std::move(schema)), - table(std::move(table)) { -} - -bool DataTableInfo::IsTemporary() const { - return db.IsTemporary(); -} - -DataTable::DataTable(AttachedDatabase &db, shared_ptr table_io_manager_p, const string &schema, - const string &table, vector column_definitions_p, - unique_ptr data) - : info(make_shared(db, std::move(table_io_manager_p), schema, table)), - column_definitions(std::move(column_definitions_p)), db(db), is_root(true) { - // initialize the table with the existing data from disk, if any - auto types = GetTypes(); - this->row_groups = - make_shared(info, TableIOManager::Get(*this).GetBlockManagerForRowData(), types, 0); - if (data && data->row_group_count > 0) { - this->row_groups->Initialize(*data); - } else { - this->row_groups->InitializeEmpty(); - D_ASSERT(row_groups->GetTotalRows() == 0); - } - row_groups->Verify(); -} - -DataTable::DataTable(ClientContext &context, DataTable &parent, ColumnDefinition &new_column, Expression &default_value) - : 
info(parent.info), db(parent.db), is_root(true) { - // add the column definitions from this DataTable - for (auto &column_def : parent.column_definitions) { - column_definitions.emplace_back(column_def.Copy()); - } - column_definitions.emplace_back(new_column.Copy()); - // prevent any new tuples from being added to the parent - lock_guard parent_lock(parent.append_lock); - - this->row_groups = parent.row_groups->AddColumn(context, new_column, default_value); - - // also add this column to client local storage - auto &local_storage = LocalStorage::Get(context, db); - local_storage.AddColumn(parent, *this, new_column, default_value); - - // this table replaces the previous table, hence the parent is no longer the root DataTable - parent.is_root = false; -} - -DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t removed_column) - : info(parent.info), db(parent.db), is_root(true) { - // prevent any new tuples from being added to the parent - lock_guard parent_lock(parent.append_lock); - - for (auto &column_def : parent.column_definitions) { - column_definitions.emplace_back(column_def.Copy()); - } - // first check if there are any indexes that exist that point to the removed column - info->indexes.Scan([&](Index &index) { - for (auto &column_id : index.column_ids) { - if (column_id == removed_column) { - throw CatalogException("Cannot drop this column: an index depends on it!"); - } else if (column_id > removed_column) { - throw CatalogException("Cannot drop this column: an index depends on a column after it!"); - } - } - return false; - }); - - // erase the column definitions from this DataTable - D_ASSERT(removed_column < column_definitions.size()); - column_definitions.erase(column_definitions.begin() + removed_column); - - storage_t storage_idx = 0; - for (idx_t i = 0; i < column_definitions.size(); i++) { - auto &col = column_definitions[i]; - col.SetOid(i); - if (col.Generated()) { - continue; - } - col.SetStorageOid(storage_idx++); - } - - // 
alter the row_groups and remove the column from each of them - this->row_groups = parent.row_groups->RemoveColumn(removed_column); - - // scan the original table, and fill the new column with the transformed value - auto &local_storage = LocalStorage::Get(context, db); - local_storage.DropColumn(parent, *this, removed_column); - - // this table replaces the previous table, hence the parent is no longer the root DataTable - parent.is_root = false; -} - -// Alter column to add new constraint -DataTable::DataTable(ClientContext &context, DataTable &parent, unique_ptr constraint) - : info(parent.info), db(parent.db), row_groups(parent.row_groups), is_root(true) { - - lock_guard parent_lock(parent.append_lock); - for (auto &column_def : parent.column_definitions) { - column_definitions.emplace_back(column_def.Copy()); - } - - // Verify the new constraint against current persistent/local data - VerifyNewConstraint(context, parent, constraint.get()); - - // Get the local data ownership from old dt - auto &local_storage = LocalStorage::Get(context, db); - local_storage.MoveStorage(parent, *this); - // this table replaces the previous table, hence the parent is no longer the root DataTable - parent.is_root = false; -} - -DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t changed_idx, const LogicalType &target_type, - const vector &bound_columns, Expression &cast_expr) - : info(parent.info), db(parent.db), is_root(true) { - // prevent any tuples from being added to the parent - lock_guard lock(append_lock); - for (auto &column_def : parent.column_definitions) { - column_definitions.emplace_back(column_def.Copy()); - } - // first check if there are any indexes that exist that point to the changed column - info->indexes.Scan([&](Index &index) { - for (auto &column_id : index.column_ids) { - if (column_id == changed_idx) { - throw CatalogException("Cannot change the type of this column: an index depends on it!"); - } - } - return false; - }); - - // change 
the type in this DataTable - column_definitions[changed_idx].SetType(target_type); - - // set up the statistics for the table - // the column that had its type changed will have the new statistics computed during conversion - this->row_groups = parent.row_groups->AlterType(context, changed_idx, target_type, bound_columns, cast_expr); - - // scan the original table, and fill the new column with the transformed value - auto &local_storage = LocalStorage::Get(context, db); - local_storage.ChangeType(parent, *this, changed_idx, target_type, bound_columns, cast_expr); - - // this table replaces the previous table, hence the parent is no longer the root DataTable - parent.is_root = false; -} - -vector DataTable::GetTypes() { - vector types; - for (auto &it : column_definitions) { - types.push_back(it.Type()); - } - return types; -} - -TableIOManager &TableIOManager::Get(DataTable &table) { - return *table.info->table_io_manager; -} - -//===--------------------------------------------------------------------===// -// Scan -//===--------------------------------------------------------------------===// -void DataTable::InitializeScan(TableScanState &state, const vector &column_ids, - TableFilterSet *table_filters) { - state.Initialize(column_ids, table_filters); - row_groups->InitializeScan(state.table_state, column_ids, table_filters); -} - -void DataTable::InitializeScan(DuckTransaction &transaction, TableScanState &state, const vector &column_ids, - TableFilterSet *table_filters) { - InitializeScan(state, column_ids, table_filters); - auto &local_storage = LocalStorage::Get(transaction); - local_storage.InitializeScan(*this, state.local_state, table_filters); -} - -void DataTable::InitializeScanWithOffset(TableScanState &state, const vector &column_ids, idx_t start_row, - idx_t end_row) { - state.Initialize(column_ids); - row_groups->InitializeScanWithOffset(state.table_state, column_ids, start_row, end_row); -} - -idx_t DataTable::MaxThreads(ClientContext &context) { - 
idx_t parallel_scan_vector_count = Storage::ROW_GROUP_VECTOR_COUNT; - if (ClientConfig::GetConfig(context).verify_parallelism) { - parallel_scan_vector_count = 1; - } - idx_t parallel_scan_tuple_count = STANDARD_VECTOR_SIZE * parallel_scan_vector_count; - return GetTotalRows() / parallel_scan_tuple_count + 1; -} - -void DataTable::InitializeParallelScan(ClientContext &context, ParallelTableScanState &state) { - row_groups->InitializeParallelScan(state.scan_state); - - auto &local_storage = LocalStorage::Get(context, db); - local_storage.InitializeParallelScan(*this, state.local_state); -} - -bool DataTable::NextParallelScan(ClientContext &context, ParallelTableScanState &state, TableScanState &scan_state) { - if (row_groups->NextParallelScan(context, state.scan_state, scan_state.table_state)) { - return true; - } - scan_state.table_state.batch_index = state.scan_state.batch_index; - auto &local_storage = LocalStorage::Get(context, db); - if (local_storage.NextParallelScan(context, *this, state.local_state, scan_state.local_state)) { - return true; - } else { - // finished all scans: no more scans remaining - return false; - } -} - -void DataTable::Scan(DuckTransaction &transaction, DataChunk &result, TableScanState &state) { - // scan the persistent segments - if (state.table_state.Scan(transaction, result)) { - D_ASSERT(result.size() > 0); - return; - } - - // scan the transaction-local segments - auto &local_storage = LocalStorage::Get(transaction); - local_storage.Scan(state.local_state, state.GetColumnIds(), result); -} - -bool DataTable::CreateIndexScan(TableScanState &state, DataChunk &result, TableScanType type) { - return state.table_state.ScanCommitted(result, type); -} - -//===--------------------------------------------------------------------===// -// Fetch -//===--------------------------------------------------------------------===// -void DataTable::Fetch(DuckTransaction &transaction, DataChunk &result, const vector &column_ids, - const Vector 
&row_identifiers, idx_t fetch_count, ColumnFetchState &state) { - row_groups->Fetch(transaction, result, column_ids, row_identifiers, fetch_count, state); -} - -//===--------------------------------------------------------------------===// -// Append -//===--------------------------------------------------------------------===// -static void VerifyNotNullConstraint(TableCatalogEntry &table, Vector &vector, idx_t count, const string &col_name) { - if (!VectorOperations::HasNull(vector, count)) { - return; - } - - throw ConstraintException("NOT NULL constraint failed: %s.%s", table.name, col_name); -} - -// To avoid throwing an error at SELECT, instead this moves the error detection to INSERT -static void VerifyGeneratedExpressionSuccess(ClientContext &context, TableCatalogEntry &table, DataChunk &chunk, - Expression &expr, column_t index) { - auto &col = table.GetColumn(LogicalIndex(index)); - D_ASSERT(col.Generated()); - ExpressionExecutor executor(context, expr); - Vector result(col.Type()); - try { - executor.ExecuteExpression(chunk, result); - } catch (InternalException &ex) { - throw; - } catch (std::exception &ex) { - throw ConstraintException("Incorrect value for generated column \"%s %s AS (%s)\" : %s", col.Name(), - col.Type().ToString(), col.GeneratedExpression().ToString(), ex.what()); - } -} - -static void VerifyCheckConstraint(ClientContext &context, TableCatalogEntry &table, Expression &expr, - DataChunk &chunk) { - ExpressionExecutor executor(context, expr); - Vector result(LogicalType::INTEGER); - try { - executor.ExecuteExpression(chunk, result); - } catch (std::exception &ex) { - throw ConstraintException("CHECK constraint failed: %s (Error: %s)", table.name, ex.what()); - } catch (...) 
{ // LCOV_EXCL_START - throw ConstraintException("CHECK constraint failed: %s (Unknown Error)", table.name); - } // LCOV_EXCL_STOP - UnifiedVectorFormat vdata; - result.ToUnifiedFormat(chunk.size(), vdata); - - auto dataptr = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < chunk.size(); i++) { - auto idx = vdata.sel->get_index(i); - if (vdata.validity.RowIsValid(idx) && dataptr[idx] == 0) { - throw ConstraintException("CHECK constraint failed: %s", table.name); - } - } -} - -bool DataTable::IsForeignKeyIndex(const vector &fk_keys, Index &index, ForeignKeyType fk_type) { - if (fk_type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE ? !index.IsUnique() : !index.IsForeign()) { - return false; - } - if (fk_keys.size() != index.column_ids.size()) { - return false; - } - for (auto &fk_key : fk_keys) { - bool is_found = false; - for (auto &index_key : index.column_ids) { - if (fk_key.index == index_key) { - is_found = true; - break; - } - } - if (!is_found) { - return false; - } - } - return true; -} - -// Find the first index that is not null, and did not find a match -static idx_t FirstMissingMatch(const ManagedSelection &matches) { - idx_t match_idx = 0; - - for (idx_t i = 0; i < matches.Size(); i++) { - auto match = matches.IndexMapsToLocation(match_idx, i); - match_idx += match; - if (!match) { - // This index is missing in the matches vector - return i; - } - } - return DConstants::INVALID_INDEX; -} - -idx_t LocateErrorIndex(bool is_append, const ManagedSelection &matches) { - idx_t failed_index = DConstants::INVALID_INDEX; - if (!is_append) { - // We expected to find nothing, so the first error is the first match - failed_index = matches[0]; - } else { - // We expected to find matches for all of them, so the first missing match is the first error - return FirstMissingMatch(matches); - } - return failed_index; -} - -[[noreturn]] static void ThrowForeignKeyConstraintError(idx_t failed_index, bool is_append, Index &index, - DataChunk &input) { - auto 
verify_type = is_append ? VerifyExistenceType::APPEND_FK : VerifyExistenceType::DELETE_FK; - - D_ASSERT(failed_index != DConstants::INVALID_INDEX); - D_ASSERT(index.type == IndexType::ART); - auto &art_index = index.Cast(); - auto key_name = art_index.GenerateErrorKeyName(input, failed_index); - auto exception_msg = art_index.GenerateConstraintErrorMessage(verify_type, key_name); - throw ConstraintException(exception_msg); -} - -bool IsForeignKeyConstraintError(bool is_append, idx_t input_count, const ManagedSelection &matches) { - if (is_append) { - // We need to find a match for all of the values - return matches.Count() != input_count; - } else { - // We should not find any matches - return matches.Count() != 0; - } -} - -static bool IsAppend(VerifyExistenceType verify_type) { - return verify_type == VerifyExistenceType::APPEND_FK; -} - -void DataTable::VerifyForeignKeyConstraint(const BoundForeignKeyConstraint &bfk, ClientContext &context, - DataChunk &chunk, VerifyExistenceType verify_type) { - const vector *src_keys_ptr = &bfk.info.fk_keys; - const vector *dst_keys_ptr = &bfk.info.pk_keys; - - bool is_append = IsAppend(verify_type); - if (!is_append) { - src_keys_ptr = &bfk.info.pk_keys; - dst_keys_ptr = &bfk.info.fk_keys; - } - - auto &table_entry_ptr = - Catalog::GetEntry(context, INVALID_CATALOG, bfk.info.schema, bfk.info.table); - // make the data chunk to check - vector types; - for (auto &col : table_entry_ptr.GetColumns().Physical()) { - types.emplace_back(col.Type()); - } - DataChunk dst_chunk; - dst_chunk.InitializeEmpty(types); - for (idx_t i = 0; i < src_keys_ptr->size(); i++) { - dst_chunk.data[(*dst_keys_ptr)[i].index].Reference(chunk.data[(*src_keys_ptr)[i].index]); - } - dst_chunk.SetCardinality(chunk.size()); - auto &data_table = table_entry_ptr.GetStorage(); - - idx_t count = dst_chunk.size(); - if (count <= 0) { - return; - } - - // Set up a way to record conflicts, rather than directly throw on them - unordered_set empty_column_list; - 
ConflictInfo empty_conflict_info(empty_column_list, false); - ConflictManager regular_conflicts(verify_type, count, &empty_conflict_info); - ConflictManager transaction_conflicts(verify_type, count, &empty_conflict_info); - regular_conflicts.SetMode(ConflictManagerMode::SCAN); - transaction_conflicts.SetMode(ConflictManagerMode::SCAN); - - data_table.info->indexes.VerifyForeignKey(*dst_keys_ptr, dst_chunk, regular_conflicts); - regular_conflicts.Finalize(); - auto ®ular_matches = regular_conflicts.Conflicts(); - - // check if we can insert the chunk into the reference table's local storage - auto &local_storage = LocalStorage::Get(context, db); - bool error = IsForeignKeyConstraintError(is_append, count, regular_matches); - bool transaction_error = false; - bool transaction_check = local_storage.Find(data_table); - - if (transaction_check) { - auto &transact_index = local_storage.GetIndexes(data_table); - transact_index.VerifyForeignKey(*dst_keys_ptr, dst_chunk, transaction_conflicts); - transaction_conflicts.Finalize(); - auto &transaction_matches = transaction_conflicts.Conflicts(); - transaction_error = IsForeignKeyConstraintError(is_append, count, transaction_matches); - } - - if (!transaction_error && !error) { - // No error occurred; - return; - } - - // Some error occurred, and we likely want to throw - optional_ptr index; - optional_ptr transaction_index; - - auto fk_type = is_append ? 
ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE : ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE; - // check whether or not the chunk can be inserted or deleted into the referenced table' storage - index = data_table.info->indexes.FindForeignKeyIndex(*dst_keys_ptr, fk_type); - if (transaction_check) { - auto &transact_index = local_storage.GetIndexes(data_table); - // check whether or not the chunk can be inserted or deleted into the referenced table' storage - transaction_index = transact_index.FindForeignKeyIndex(*dst_keys_ptr, fk_type); - } - - if (!transaction_check) { - // Only local state is checked, throw the error - D_ASSERT(error); - auto failed_index = LocateErrorIndex(is_append, regular_matches); - D_ASSERT(failed_index != DConstants::INVALID_INDEX); - ThrowForeignKeyConstraintError(failed_index, is_append, *index, dst_chunk); - } - if (transaction_error && error && is_append) { - // When we want to do an append, we only throw if the foreign key does not exist in both transaction and local - // storage - auto &transaction_matches = transaction_conflicts.Conflicts(); - idx_t failed_index = DConstants::INVALID_INDEX; - idx_t regular_idx = 0; - idx_t transaction_idx = 0; - for (idx_t i = 0; i < count; i++) { - bool in_regular = regular_matches.IndexMapsToLocation(regular_idx, i); - regular_idx += in_regular; - bool in_transaction = transaction_matches.IndexMapsToLocation(transaction_idx, i); - transaction_idx += in_transaction; - - if (!in_regular && !in_transaction) { - // We need to find a match for all of the input values - // The failed index is i, it does not show up in either regular or transaction storage - failed_index = i; - break; - } - } - if (failed_index == DConstants::INVALID_INDEX) { - // We don't throw, every value was present in either regular or transaction storage - return; - } - ThrowForeignKeyConstraintError(failed_index, true, *index, dst_chunk); - } - if (!is_append && transaction_check) { - auto &transaction_matches = 
transaction_conflicts.Conflicts(); - if (error) { - auto failed_index = LocateErrorIndex(false, regular_matches); - D_ASSERT(failed_index != DConstants::INVALID_INDEX); - ThrowForeignKeyConstraintError(failed_index, false, *index, dst_chunk); - } else { - D_ASSERT(transaction_error); - D_ASSERT(transaction_matches.Count() != DConstants::INVALID_INDEX); - auto failed_index = LocateErrorIndex(false, transaction_matches); - D_ASSERT(failed_index != DConstants::INVALID_INDEX); - ThrowForeignKeyConstraintError(failed_index, false, *transaction_index, dst_chunk); - } - } -} - -void DataTable::VerifyAppendForeignKeyConstraint(const BoundForeignKeyConstraint &bfk, ClientContext &context, - DataChunk &chunk) { - VerifyForeignKeyConstraint(bfk, context, chunk, VerifyExistenceType::APPEND_FK); -} - -void DataTable::VerifyDeleteForeignKeyConstraint(const BoundForeignKeyConstraint &bfk, ClientContext &context, - DataChunk &chunk) { - VerifyForeignKeyConstraint(bfk, context, chunk, VerifyExistenceType::DELETE_FK); -} - -void DataTable::VerifyNewConstraint(ClientContext &context, DataTable &parent, const BoundConstraint *constraint) { - if (constraint->type != ConstraintType::NOT_NULL) { - throw NotImplementedException("FIXME: ALTER COLUMN with such constraint is not supported yet"); - } - - parent.row_groups->VerifyNewConstraint(parent, *constraint); - auto &local_storage = LocalStorage::Get(context, db); - local_storage.VerifyNewConstraint(parent, *constraint); -} - -bool HasUniqueIndexes(TableIndexList &list) { - bool has_unique_index = false; - list.Scan([&](Index &index) { - if (index.IsUnique()) { - return has_unique_index = true; - return true; - } - return false; - }); - return has_unique_index; -} - -void DataTable::VerifyUniqueIndexes(TableIndexList &indexes, ClientContext &context, DataChunk &chunk, - ConflictManager *conflict_manager) { - //! 
check whether or not the chunk can be inserted into the indexes - if (!conflict_manager) { - // Only need to verify that no unique constraints are violated - indexes.Scan([&](Index &index) { - if (!index.IsUnique()) { - return false; - } - index.VerifyAppend(chunk); - return false; - }); - return; - } - - D_ASSERT(conflict_manager); - // The conflict manager is only provided when a ON CONFLICT clause was provided to the INSERT statement - - idx_t matching_indexes = 0; - auto &conflict_info = conflict_manager->GetConflictInfo(); - // First we figure out how many indexes match our conflict target - // So we can optimize accordingly - indexes.Scan([&](Index &index) { - matching_indexes += conflict_info.ConflictTargetMatches(index); - return false; - }); - conflict_manager->SetMode(ConflictManagerMode::SCAN); - conflict_manager->SetIndexCount(matching_indexes); - // First we verify only the indexes that match our conflict target - unordered_set checked_indexes; - indexes.Scan([&](Index &index) { - if (!index.IsUnique()) { - return false; - } - if (conflict_info.ConflictTargetMatches(index)) { - index.VerifyAppend(chunk, *conflict_manager); - checked_indexes.insert(&index); - } - return false; - }); - - conflict_manager->SetMode(ConflictManagerMode::THROW); - // Then we scan the other indexes, throwing if they cause conflicts on tuples that were not found during - // the scan - indexes.Scan([&](Index &index) { - if (!index.IsUnique()) { - return false; - } - if (checked_indexes.count(&index)) { - // Already checked this constraint - return false; - } - index.VerifyAppend(chunk, *conflict_manager); - return false; - }); -} - -void DataTable::VerifyAppendConstraints(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, - ConflictManager *conflict_manager) { - if (table.HasGeneratedColumns()) { - // Verify that the generated columns expression work with the inserted values - auto binder = Binder::CreateBinder(context); - physical_index_set_t bound_columns; - 
CheckBinder generated_check_binder(*binder, context, table.name, table.GetColumns(), bound_columns); - for (auto &col : table.GetColumns().Logical()) { - if (!col.Generated()) { - continue; - } - D_ASSERT(col.Type().id() != LogicalTypeId::ANY); - generated_check_binder.target_type = col.Type(); - auto to_be_bound_expression = col.GeneratedExpression().Copy(); - auto bound_expression = generated_check_binder.Bind(to_be_bound_expression); - VerifyGeneratedExpressionSuccess(context, table, chunk, *bound_expression, col.Oid()); - } - } - - if (HasUniqueIndexes(info->indexes)) { - VerifyUniqueIndexes(info->indexes, context, chunk, conflict_manager); - } - - auto &constraints = table.GetConstraints(); - auto &bound_constraints = table.GetBoundConstraints(); - for (idx_t i = 0; i < bound_constraints.size(); i++) { - auto &base_constraint = constraints[i]; - auto &constraint = bound_constraints[i]; - switch (base_constraint->type) { - case ConstraintType::NOT_NULL: { - auto &bound_not_null = *reinterpret_cast(constraint.get()); - auto ¬_null = *reinterpret_cast(base_constraint.get()); - auto &col = table.GetColumns().GetColumn(LogicalIndex(not_null.index)); - VerifyNotNullConstraint(table, chunk.data[bound_not_null.index.index], chunk.size(), col.Name()); - break; - } - case ConstraintType::CHECK: { - auto &check = *reinterpret_cast(constraint.get()); - VerifyCheckConstraint(context, table, *check.expression, chunk); - break; - } - case ConstraintType::UNIQUE: { - // These were handled earlier on - break; - } - case ConstraintType::FOREIGN_KEY: { - auto &bfk = *reinterpret_cast(constraint.get()); - if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || - bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - VerifyAppendForeignKeyConstraint(bfk, context, chunk); - } - break; - } - default: - throw NotImplementedException("Constraint type not implemented!"); - } - } -} - -void DataTable::InitializeLocalAppend(LocalAppendState &state, ClientContext 
&context) { - if (!is_root) { - throw TransactionException("Transaction conflict: adding entries to a table that has been altered!"); - } - auto &local_storage = LocalStorage::Get(context, db); - local_storage.InitializeAppend(state, *this); -} - -void DataTable::LocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, - bool unsafe) { - if (chunk.size() == 0) { - return; - } - D_ASSERT(chunk.ColumnCount() == table.GetColumns().PhysicalColumnCount()); - if (!is_root) { - throw TransactionException("Transaction conflict: adding entries to a table that has been altered!"); - } - - chunk.Verify(); - - // verify any constraints on the new chunk - if (!unsafe) { - VerifyAppendConstraints(table, context, chunk); - } - - // append to the transaction local data - LocalStorage::Append(state, chunk); -} - -void DataTable::FinalizeLocalAppend(LocalAppendState &state) { - LocalStorage::FinalizeAppend(state); -} - -OptimisticDataWriter &DataTable::CreateOptimisticWriter(ClientContext &context) { - auto &local_storage = LocalStorage::Get(context, db); - return local_storage.CreateOptimisticWriter(*this); -} - -void DataTable::FinalizeOptimisticWriter(ClientContext &context, OptimisticDataWriter &writer) { - auto &local_storage = LocalStorage::Get(context, db); - local_storage.FinalizeOptimisticWriter(*this, writer); -} - -void DataTable::LocalMerge(ClientContext &context, RowGroupCollection &collection) { - auto &local_storage = LocalStorage::Get(context, db); - local_storage.LocalMerge(*this, collection); -} - -void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk) { - LocalAppendState append_state; - auto &storage = table.GetStorage(); - storage.InitializeLocalAppend(append_state, context); - storage.LocalAppend(append_state, table, context, chunk); - storage.FinalizeLocalAppend(append_state); -} - -void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, 
ColumnDataCollection &collection) { - LocalAppendState append_state; - auto &storage = table.GetStorage(); - storage.InitializeLocalAppend(append_state, context); - for (auto &chunk : collection.Chunks()) { - storage.LocalAppend(append_state, table, context, chunk); - } - storage.FinalizeLocalAppend(append_state); -} - -void DataTable::AppendLock(TableAppendState &state) { - state.append_lock = unique_lock(append_lock); - if (!is_root) { - throw TransactionException("Transaction conflict: adding entries to a table that has been altered!"); - } - state.row_start = row_groups->GetTotalRows(); - state.current_row = state.row_start; -} - -void DataTable::InitializeAppend(DuckTransaction &transaction, TableAppendState &state, idx_t append_count) { - // obtain the append lock for this table - if (!state.append_lock) { - throw InternalException("DataTable::AppendLock should be called before DataTable::InitializeAppend"); - } - row_groups->InitializeAppend(transaction, state, append_count); -} - -void DataTable::Append(DataChunk &chunk, TableAppendState &state) { - D_ASSERT(is_root); - row_groups->Append(chunk, state); -} - -void DataTable::ScanTableSegment(idx_t row_start, idx_t count, const std::function &function) { - if (count == 0) { - return; - } - idx_t end = row_start + count; - - vector column_ids; - vector types; - for (idx_t i = 0; i < this->column_definitions.size(); i++) { - auto &col = this->column_definitions[i]; - column_ids.push_back(i); - types.push_back(col.Type()); - } - DataChunk chunk; - chunk.Initialize(Allocator::Get(db), types); - - CreateIndexScanState state; - - InitializeScanWithOffset(state, column_ids, row_start, row_start + count); - auto row_start_aligned = state.table_state.row_group->start + state.table_state.vector_index * STANDARD_VECTOR_SIZE; - - idx_t current_row = row_start_aligned; - while (current_row < end) { - state.table_state.ScanCommitted(chunk, TableScanType::TABLE_SCAN_COMMITTED_ROWS); - if (chunk.size() == 0) { - break; - } 
- idx_t end_row = current_row + chunk.size(); - // start of chunk is current_row - // end of chunk is end_row - // figure out if we need to write the entire chunk or just part of it - idx_t chunk_start = MaxValue(current_row, row_start); - idx_t chunk_end = MinValue(end_row, end); - D_ASSERT(chunk_start < chunk_end); - idx_t chunk_count = chunk_end - chunk_start; - if (chunk_count != chunk.size()) { - D_ASSERT(chunk_count <= chunk.size()); - // need to slice the chunk before insert - idx_t start_in_chunk; - if (current_row >= row_start) { - start_in_chunk = 0; - } else { - start_in_chunk = row_start - current_row; - } - SelectionVector sel(start_in_chunk, chunk_count); - chunk.Slice(sel, chunk_count); - chunk.Verify(); - } - function(chunk); - chunk.Reset(); - current_row = end_row; - } -} - -void DataTable::MergeStorage(RowGroupCollection &data, TableIndexList &indexes) { - row_groups->MergeStorage(data); - row_groups->Verify(); -} - -void DataTable::WriteToLog(WriteAheadLog &log, idx_t row_start, idx_t count) { - if (log.skip_writing) { - return; - } - log.WriteSetTable(info->schema, info->table); - ScanTableSegment(row_start, count, [&](DataChunk &chunk) { log.WriteInsert(chunk); }); -} - -void DataTable::CommitAppend(transaction_t commit_id, idx_t row_start, idx_t count) { - lock_guard lock(append_lock); - row_groups->CommitAppend(commit_id, row_start, count); - info->cardinality += count; -} - -void DataTable::RevertAppendInternal(idx_t start_row) { - // adjust the cardinality - info->cardinality = start_row; - D_ASSERT(is_root); - // revert appends made to row_groups - row_groups->RevertAppendInternal(start_row); -} - -void DataTable::RevertAppend(idx_t start_row, idx_t count) { - lock_guard lock(append_lock); - - // revert any appends to indexes - if (!info->indexes.Empty()) { - idx_t current_row_base = start_row; - row_t row_data[STANDARD_VECTOR_SIZE]; - Vector row_identifiers(LogicalType::ROW_TYPE, data_ptr_cast(row_data)); - idx_t scan_count = 
MinValue(count, row_groups->GetTotalRows() - start_row); - ScanTableSegment(start_row, scan_count, [&](DataChunk &chunk) { - for (idx_t i = 0; i < chunk.size(); i++) { - row_data[i] = current_row_base + i; - } - info->indexes.Scan([&](Index &index) { - index.Delete(chunk, row_identifiers); - return false; - }); - current_row_base += chunk.size(); - }); - } - - // we need to vacuum the indexes to remove any buffers that are now empty - // due to reverting the appends - info->indexes.Scan([&](Index &index) { - index.Vacuum(); - return false; - }); - - // revert the data table append - RevertAppendInternal(start_row); -} - -//===--------------------------------------------------------------------===// -// Indexes -//===--------------------------------------------------------------------===// -PreservedError DataTable::AppendToIndexes(TableIndexList &indexes, DataChunk &chunk, row_t row_start) { - PreservedError error; - if (indexes.Empty()) { - return error; - } - // first generate the vector of row identifiers - Vector row_identifiers(LogicalType::ROW_TYPE); - VectorOperations::GenerateSequence(row_identifiers, chunk.size(), row_start, 1); - - vector already_appended; - bool append_failed = false; - // now append the entries to the indices - indexes.Scan([&](Index &index) { - try { - error = index.Append(chunk, row_identifiers); - } catch (Exception &ex) { - error = PreservedError(ex); - } catch (std::exception &ex) { - error = PreservedError(ex); - } - if (error) { - append_failed = true; - return true; - } - already_appended.push_back(&index); - return false; - }); - - if (append_failed) { - // constraint violation! 
- // remove any appended entries from previous indexes (if any) - for (auto *index : already_appended) { - index->Delete(chunk, row_identifiers); - } - } - return error; -} - -PreservedError DataTable::AppendToIndexes(DataChunk &chunk, row_t row_start) { - D_ASSERT(is_root); - return AppendToIndexes(info->indexes, chunk, row_start); -} - -void DataTable::RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, row_t row_start) { - D_ASSERT(is_root); - if (info->indexes.Empty()) { - return; - } - // first generate the vector of row identifiers - Vector row_identifiers(LogicalType::ROW_TYPE); - VectorOperations::GenerateSequence(row_identifiers, chunk.size(), row_start, 1); - - // now remove the entries from the indices - RemoveFromIndexes(state, chunk, row_identifiers); -} - -void DataTable::RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, Vector &row_identifiers) { - D_ASSERT(is_root); - info->indexes.Scan([&](Index &index) { - index.Delete(chunk, row_identifiers); - return false; - }); -} - -void DataTable::RemoveFromIndexes(Vector &row_identifiers, idx_t count) { - D_ASSERT(is_root); - row_groups->RemoveFromIndexes(info->indexes, row_identifiers, count); -} - -//===--------------------------------------------------------------------===// -// Delete -//===--------------------------------------------------------------------===// -static bool TableHasDeleteConstraints(TableCatalogEntry &table) { - auto &bound_constraints = table.GetBoundConstraints(); - for (auto &constraint : bound_constraints) { - switch (constraint->type) { - case ConstraintType::NOT_NULL: - case ConstraintType::CHECK: - case ConstraintType::UNIQUE: - break; - case ConstraintType::FOREIGN_KEY: { - auto &bfk = *reinterpret_cast(constraint.get()); - if (bfk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE || - bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - return true; - } - break; - } - default: - throw NotImplementedException("Constraint type not 
implemented!"); - } - } - return false; -} - -void DataTable::VerifyDeleteConstraints(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk) { - auto &bound_constraints = table.GetBoundConstraints(); - for (auto &constraint : bound_constraints) { - switch (constraint->type) { - case ConstraintType::NOT_NULL: - case ConstraintType::CHECK: - case ConstraintType::UNIQUE: - break; - case ConstraintType::FOREIGN_KEY: { - auto &bfk = *reinterpret_cast(constraint.get()); - if (bfk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE || - bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - VerifyDeleteForeignKeyConstraint(bfk, context, chunk); - } - break; - } - default: - throw NotImplementedException("Constraint type not implemented!"); - } - } -} - -idx_t DataTable::Delete(TableCatalogEntry &table, ClientContext &context, Vector &row_identifiers, idx_t count) { - D_ASSERT(row_identifiers.GetType().InternalType() == ROW_TYPE); - if (count == 0) { - return 0; - } - - auto &transaction = DuckTransaction::Get(context, db); - auto &local_storage = LocalStorage::Get(transaction); - bool has_delete_constraints = TableHasDeleteConstraints(table); - - row_identifiers.Flatten(count); - auto ids = FlatVector::GetData(row_identifiers); - - DataChunk verify_chunk; - vector col_ids; - vector types; - ColumnFetchState fetch_state; - if (has_delete_constraints) { - // initialize the chunk if there are any constraints to verify - for (idx_t i = 0; i < column_definitions.size(); i++) { - col_ids.push_back(column_definitions[i].StorageOid()); - types.emplace_back(column_definitions[i].Type()); - } - verify_chunk.Initialize(Allocator::Get(context), types); - } - idx_t pos = 0; - idx_t delete_count = 0; - while (pos < count) { - idx_t start = pos; - bool is_transaction_delete = ids[pos] >= MAX_ROW_ID; - // figure out which batch of rows to delete now - for (pos++; pos < count; pos++) { - bool row_is_transaction_delete = ids[pos] >= MAX_ROW_ID; - if 
(row_is_transaction_delete != is_transaction_delete) { - break; - } - } - idx_t current_offset = start; - idx_t current_count = pos - start; - - Vector offset_ids(row_identifiers, current_offset, pos); - if (is_transaction_delete) { - // transaction-local delete - if (has_delete_constraints) { - // perform the constraint verification - local_storage.FetchChunk(*this, offset_ids, current_count, col_ids, verify_chunk, fetch_state); - VerifyDeleteConstraints(table, context, verify_chunk); - } - delete_count += local_storage.Delete(*this, offset_ids, current_count); - } else { - // regular table delete - if (has_delete_constraints) { - // perform the constraint verification - Fetch(transaction, verify_chunk, col_ids, offset_ids, current_count, fetch_state); - VerifyDeleteConstraints(table, context, verify_chunk); - } - delete_count += row_groups->Delete(transaction, *this, ids + current_offset, current_count); - } - } - return delete_count; -} - -//===--------------------------------------------------------------------===// -// Update -//===--------------------------------------------------------------------===// -static void CreateMockChunk(vector &types, const vector &column_ids, DataChunk &chunk, - DataChunk &mock_chunk) { - // construct a mock DataChunk - mock_chunk.InitializeEmpty(types); - for (column_t i = 0; i < column_ids.size(); i++) { - mock_chunk.data[column_ids[i].index].Reference(chunk.data[i]); - } - mock_chunk.SetCardinality(chunk.size()); -} - -static bool CreateMockChunk(TableCatalogEntry &table, const vector &column_ids, - physical_index_set_t &desired_column_ids, DataChunk &chunk, DataChunk &mock_chunk) { - idx_t found_columns = 0; - // check whether the desired columns are present in the UPDATE clause - for (column_t i = 0; i < column_ids.size(); i++) { - if (desired_column_ids.find(column_ids[i]) != desired_column_ids.end()) { - found_columns++; - } - } - if (found_columns == 0) { - // no columns were found: no need to check the constraint again - 
return false; - } - if (found_columns != desired_column_ids.size()) { - // not all columns in UPDATE clause are present! - // this should not be triggered at all as the binder should add these columns - throw InternalException("Not all columns required for the CHECK constraint are present in the UPDATED chunk!"); - } - // construct a mock DataChunk - auto types = table.GetTypes(); - CreateMockChunk(types, column_ids, chunk, mock_chunk); - return true; -} - -void DataTable::VerifyUpdateConstraints(ClientContext &context, TableCatalogEntry &table, DataChunk &chunk, - const vector &column_ids) { - auto &constraints = table.GetConstraints(); - auto &bound_constraints = table.GetBoundConstraints(); - for (idx_t constr_idx = 0; constr_idx < bound_constraints.size(); constr_idx++) { - auto &base_constraint = constraints[constr_idx]; - auto &constraint = bound_constraints[constr_idx]; - switch (constraint->type) { - case ConstraintType::NOT_NULL: { - auto &bound_not_null = *reinterpret_cast(constraint.get()); - auto ¬_null = *reinterpret_cast(base_constraint.get()); - // check if the constraint is in the list of column_ids - for (idx_t col_idx = 0; col_idx < column_ids.size(); col_idx++) { - if (column_ids[col_idx] == bound_not_null.index) { - // found the column id: check the data in - auto &col = table.GetColumn(LogicalIndex(not_null.index)); - VerifyNotNullConstraint(table, chunk.data[col_idx], chunk.size(), col.Name()); - break; - } - } - break; - } - case ConstraintType::CHECK: { - auto &check = *reinterpret_cast(constraint.get()); - - DataChunk mock_chunk; - if (CreateMockChunk(table, column_ids, check.bound_columns, chunk, mock_chunk)) { - VerifyCheckConstraint(context, table, *check.expression, mock_chunk); - } - break; - } - case ConstraintType::UNIQUE: - case ConstraintType::FOREIGN_KEY: - break; - default: - throw NotImplementedException("Constraint type not implemented!"); - } - } - // update should not be called for indexed columns! 
- // instead update should have been rewritten to delete + update on higher layer -#ifdef DEBUG - info->indexes.Scan([&](Index &index) { - D_ASSERT(!index.IndexIsUpdated(column_ids)); - return false; - }); - -#endif -} - -void DataTable::Update(TableCatalogEntry &table, ClientContext &context, Vector &row_ids, - const vector &column_ids, DataChunk &updates) { - D_ASSERT(row_ids.GetType().InternalType() == ROW_TYPE); - D_ASSERT(column_ids.size() == updates.ColumnCount()); - updates.Verify(); - - auto count = updates.size(); - if (count == 0) { - return; - } - - if (!is_root) { - throw TransactionException("Transaction conflict: cannot update a table that has been altered!"); - } - - // first verify that no constraints are violated - VerifyUpdateConstraints(context, table, updates, column_ids); - - // now perform the actual update - Vector max_row_id_vec(Value::BIGINT(MAX_ROW_ID)); - Vector row_ids_slice(LogicalType::BIGINT); - DataChunk updates_slice; - updates_slice.InitializeEmpty(updates.GetTypes()); - - SelectionVector sel_local_update(count), sel_global_update(count); - auto n_local_update = VectorOperations::GreaterThanEquals(row_ids, max_row_id_vec, nullptr, count, - &sel_local_update, &sel_global_update); - auto n_global_update = count - n_local_update; - - // row id > MAX_ROW_ID? 
transaction-local storage - if (n_local_update > 0) { - updates_slice.Slice(updates, sel_local_update, n_local_update); - updates_slice.Flatten(); - row_ids_slice.Slice(row_ids, sel_local_update, n_local_update); - row_ids_slice.Flatten(n_local_update); - - LocalStorage::Get(context, db).Update(*this, row_ids_slice, column_ids, updates_slice); - } - - // otherwise global storage - if (n_global_update > 0) { - updates_slice.Slice(updates, sel_global_update, n_global_update); - updates_slice.Flatten(); - row_ids_slice.Slice(row_ids, sel_global_update, n_global_update); - row_ids_slice.Flatten(n_global_update); - - row_groups->Update(DuckTransaction::Get(context, db), FlatVector::GetData(row_ids_slice), column_ids, - updates_slice); - } -} - -void DataTable::UpdateColumn(TableCatalogEntry &table, ClientContext &context, Vector &row_ids, - const vector &column_path, DataChunk &updates) { - D_ASSERT(row_ids.GetType().InternalType() == ROW_TYPE); - D_ASSERT(updates.ColumnCount() == 1); - updates.Verify(); - if (updates.size() == 0) { - return; - } - - if (!is_root) { - throw TransactionException("Transaction conflict: cannot update a table that has been altered!"); - } - - // now perform the actual update - auto &transaction = DuckTransaction::Get(context, db); - - updates.Flatten(); - row_ids.Flatten(updates.size()); - row_groups->UpdateColumn(transaction, row_ids, column_path, updates); -} - -//===--------------------------------------------------------------------===// -// Index Scan -//===--------------------------------------------------------------------===// -void DataTable::InitializeWALCreateIndexScan(CreateIndexScanState &state, const vector &column_ids) { - // we grab the append lock to make sure nothing is appended until AFTER we finish the index scan - state.append_lock = std::unique_lock(append_lock); - InitializeScan(state, column_ids); -} - -void DataTable::WALAddIndex(ClientContext &context, unique_ptr index, - const vector> &expressions) { - - // if the 
data table is empty - if (row_groups->IsEmpty()) { - info->indexes.AddIndex(std::move(index)); - return; - } - - auto &allocator = Allocator::Get(db); - - // intermediate holds scanned chunks of the underlying data to create the index - DataChunk intermediate; - vector intermediate_types; - vector column_ids; - for (auto &it : column_definitions) { - intermediate_types.push_back(it.Type()); - column_ids.push_back(it.Oid()); - } - column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); - intermediate_types.emplace_back(LogicalType::ROW_TYPE); - - intermediate.Initialize(allocator, intermediate_types); - - // holds the result of executing the index expression on the intermediate chunks - DataChunk result; - result.Initialize(allocator, index->logical_types); - - // initialize an index scan - CreateIndexScanState state; - InitializeWALCreateIndexScan(state, column_ids); - - if (!is_root) { - throw InternalException("Error during WAL replay. Cannot add an index to a table that has been altered."); - } - - // now start incrementally building the index - { - IndexLock lock; - index->InitializeLock(lock); - - while (true) { - intermediate.Reset(); - result.Reset(); - // scan a new chunk from the table to index - CreateIndexScan(state, intermediate, TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED); - if (intermediate.size() == 0) { - // finished scanning for index creation - // release all locks - break; - } - // resolve the expressions for this chunk - index->ExecuteExpressions(intermediate, result); - - // insert into the index - auto error = index->Insert(lock, result, intermediate.data[intermediate.ColumnCount() - 1]); - if (error) { - throw InternalException("Error during WAL replay: %s", error.Message()); - } - } - } - - info->indexes.AddIndex(std::move(index)); -} - -//===--------------------------------------------------------------------===// -// Statistics -//===--------------------------------------------------------------------===// -unique_ptr 
DataTable::GetStatistics(ClientContext &context, column_t column_id) { - if (column_id == COLUMN_IDENTIFIER_ROW_ID) { - return nullptr; - } - return row_groups->CopyStats(column_id); -} - -void DataTable::SetDistinct(column_t column_id, unique_ptr distinct_stats) { - D_ASSERT(column_id != COLUMN_IDENTIFIER_ROW_ID); - row_groups->SetDistinct(column_id, std::move(distinct_stats)); -} - -//===--------------------------------------------------------------------===// -// Checkpoint -//===--------------------------------------------------------------------===// -void DataTable::Checkpoint(TableDataWriter &writer, Serializer &metadata_serializer) { - // checkpoint each individual row group - // FIXME: we might want to combine adjacent row groups in case they have had deletions... - TableStatistics global_stats; - row_groups->CopyStats(global_stats); - - row_groups->Checkpoint(writer, global_stats); - - // The rowgroup payload data has been written. Now write: - // column stats - // row-group pointers - // table pointer - // index data - writer.FinalizeTable(std::move(global_stats), info.get(), metadata_serializer); -} - -void DataTable::CommitDropColumn(idx_t index) { - row_groups->CommitDropColumn(index); -} - -idx_t DataTable::GetTotalRows() { - return row_groups->GetTotalRows(); -} - -void DataTable::CommitDropTable() { - // commit a drop of this table: mark all blocks as modified so they can be reclaimed later on - row_groups->CommitDropTable(); -} - -//===--------------------------------------------------------------------===// -// GetColumnSegmentInfo -//===--------------------------------------------------------------------===// -vector DataTable::GetColumnSegmentInfo() { - return row_groups->GetColumnSegmentInfo(); -} - -} // namespace duckdb - -#endif diff --git a/lib/duckdb-fastpforlib.cpp b/lib/duckdb-fastpforlib.cpp deleted file mode 100644 index 5f077ae2..00000000 --- a/lib/duckdb-fastpforlib.cpp +++ /dev/null @@ -1,1303 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE 
-// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #15 -// See the end of this file for a list - - - -#include -#include - -namespace duckdb_fastpforlib { -namespace internal { - -// Used for uint8_t, uint16_t and uint32_t -template -typename std::enable_if<(DELTA + SHR) < TYPE_SIZE>::type unpack_single_out(const TYPE *__restrict in, - TYPE *__restrict out) { - *out = ((*in) >> SHR) % (1 << DELTA); -} - -// Used for uint8_t, uint16_t and uint32_t -template -typename std::enable_if<(DELTA + SHR) >= TYPE_SIZE>::type unpack_single_out(const TYPE *__restrict &in, - TYPE *__restrict out) { - *out = (*in) >> SHR; - ++in; - - static const TYPE NEXT_SHR = SHR + DELTA - TYPE_SIZE; - *out |= ((*in) % (1U << NEXT_SHR)) << (TYPE_SIZE - SHR); -} - -template -typename std::enable_if<(DELTA + SHR) < 32>::type unpack_single_out(const uint32_t *__restrict in, - uint64_t *__restrict out) { - *out = ((static_cast(*in)) >> SHR) % (1ULL << DELTA); -} - -template -typename std::enable_if<(DELTA + SHR) >= 32 && (DELTA + SHR) < 64>::type -unpack_single_out(const uint32_t *__restrict &in, uint64_t *__restrict out) { - *out = static_cast(*in) >> SHR; - ++in; - if (DELTA + SHR > 32) { - static const uint8_t NEXT_SHR = SHR + DELTA - 32; - *out |= static_cast((*in) % (1U << NEXT_SHR)) << (32 - SHR); - } -} - -template -typename std::enable_if<(DELTA + SHR) >= 64>::type unpack_single_out(const uint32_t *__restrict &in, - uint64_t *__restrict out) { - *out = static_cast(*in) >> SHR; - ++in; - - *out |= static_cast(*in) << (32 - SHR); - ++in; - - if (DELTA + SHR > 64) { - static const uint8_t NEXT_SHR = DELTA + SHR - 64; - *out |= static_cast((*in) % (1U << NEXT_SHR)) << (64 - SHR); - } -} - -// Used for uint8_t, uint16_t and 
uint32_t -template - typename std::enable_if < DELTA + SHL::type pack_single_in(const TYPE in, TYPE *__restrict out) { - if (SHL == 0) { - *out = in & MASK; - } else { - *out |= (in & MASK) << SHL; - } -} - -// Used for uint8_t, uint16_t and uint32_t -template -typename std::enable_if= TYPE_SIZE>::type pack_single_in(const TYPE in, TYPE *__restrict &out) { - *out |= in << SHL; - ++out; - - if (DELTA + SHL > TYPE_SIZE) { - *out = (in & MASK) >> (TYPE_SIZE - SHL); - } -} - -template - typename std::enable_if < DELTA + SHL<32>::type pack_single_in64(const uint64_t in, uint32_t *__restrict out) { - if (SHL == 0) { - *out = static_cast(in & MASK); - } else { - *out |= (in & MASK) << SHL; - } -} -template - typename std::enable_if < DELTA + SHL >= 32 && - DELTA + SHL<64>::type pack_single_in64(const uint64_t in, uint32_t *__restrict &out) { - if (SHL == 0) { - *out = static_cast(in & MASK); - } else { - *out |= (in & MASK) << SHL; - } - - ++out; - - if (DELTA + SHL > 32) { - *out = static_cast((in & MASK) >> (32 - SHL)); - } -} -template -typename std::enable_if= 64>::type pack_single_in64(const uint64_t in, uint32_t *__restrict &out) { - *out |= in << SHL; - ++out; - - *out = static_cast((in & MASK) >> (32 - SHL)); - ++out; - - if (DELTA + SHL > 64) { - *out = (in & MASK) >> (64 - SHL); - } -} -template -struct Unroller8 { - static void Unpack(const uint8_t *__restrict &in, uint8_t *__restrict out) { - unpack_single_out(in, out + OINDEX); - - Unroller8::Unpack(in, out); - } - - static void Pack(const uint8_t *__restrict in, uint8_t *__restrict out) { - pack_single_in(in[OINDEX], out); - - Unroller8::Pack(in, out); - } - -};\ -template -struct Unroller8 { - enum { SHIFT = (DELTA * 7) % 8 }; - - static void Unpack(const uint8_t *__restrict in, uint8_t *__restrict out) { - out[7] = (*in) >> SHIFT; - } - - static void Pack(const uint8_t *__restrict in, uint8_t *__restrict out) { - *out |= (in[7] << SHIFT); - } -}; - -template -struct Unroller16 { - static void Unpack(const 
uint16_t *__restrict &in, uint16_t *__restrict out) { - unpack_single_out(in, out + OINDEX); - - Unroller16::Unpack(in, out); - } - - static void Pack(const uint16_t *__restrict in, uint16_t *__restrict out) { - pack_single_in(in[OINDEX], out); - - Unroller16::Pack(in, out); - } - -}; - -template -struct Unroller16 { - enum { SHIFT = (DELTA * 15) % 16 }; - - static void Unpack(const uint16_t *__restrict in, uint16_t *__restrict out) { - out[15] = (*in) >> SHIFT; - } - - static void Pack(const uint16_t *__restrict in, uint16_t *__restrict out) { - *out |= (in[15] << SHIFT); - } -}; - -template -struct Unroller { - static void Unpack(const uint32_t *__restrict &in, uint32_t *__restrict out) { - unpack_single_out(in, out + OINDEX); - - Unroller::Unpack(in, out); - } - - static void Unpack(const uint32_t *__restrict &in, uint64_t *__restrict out) { - unpack_single_out(in, out + OINDEX); - - Unroller::Unpack(in, out); - } - - static void Pack(const uint32_t *__restrict in, uint32_t *__restrict out) { - pack_single_in(in[OINDEX], out); - - Unroller::Pack(in, out); - } - - static void Pack(const uint64_t *__restrict in, uint32_t *__restrict out) { - pack_single_in64(in[OINDEX], out); - - Unroller::Pack(in, out); - } -}; - -template -struct Unroller { - enum { SHIFT = (DELTA * 31) % 32 }; - - static void Unpack(const uint32_t *__restrict in, uint32_t *__restrict out) { - out[31] = (*in) >> SHIFT; - } - - static void Unpack(const uint32_t *__restrict in, uint64_t *__restrict out) { - out[31] = (*in) >> SHIFT; - if (DELTA > 32) { - ++in; - out[31] |= static_cast(*in) << (32 - SHIFT); - } - } - - static void Pack(const uint32_t *__restrict in, uint32_t *__restrict out) { - *out |= (in[31] << SHIFT); - } - - static void Pack(const uint64_t *__restrict in, uint32_t *__restrict out) { - *out |= (in[31] << SHIFT); - if (DELTA > 32) { - ++out; - *out = static_cast(in[31] >> (32 - SHIFT)); - } - } -}; - -// Special cases -void __fastunpack0(const uint8_t *__restrict, uint8_t 
*__restrict out) { - for (uint8_t i = 0; i < 8; ++i) - *(out++) = 0; -} - -void __fastunpack0(const uint16_t *__restrict, uint16_t *__restrict out) { - for (uint16_t i = 0; i < 16; ++i) - *(out++) = 0; -} - -void __fastunpack0(const uint32_t *__restrict, uint32_t *__restrict out) { - for (uint32_t i = 0; i < 32; ++i) - *(out++) = 0; -} - -void __fastunpack0(const uint32_t *__restrict, uint64_t *__restrict out) { - for (uint32_t i = 0; i < 32; ++i) - *(out++) = 0; -} - -void __fastpack0(const uint8_t *__restrict, uint8_t *__restrict) { -} -void __fastpack0(const uint16_t *__restrict, uint16_t *__restrict) { -} -void __fastpack0(const uint32_t *__restrict, uint32_t *__restrict) { -} -void __fastpack0(const uint64_t *__restrict, uint32_t *__restrict) { -} - -// fastunpack for 8 bits -void __fastunpack1(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<1>::Unpack(in, out); -} - -void __fastunpack2(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<2>::Unpack(in, out); -} - -void __fastunpack3(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<3>::Unpack(in, out); -} - -void __fastunpack4(const uint8_t *__restrict in, uint8_t *__restrict out) { - for (uint8_t outer = 0; outer < 4; ++outer) { - for (uint8_t inwordpointer = 0; inwordpointer < 8; inwordpointer += 4) - *(out++) = ((*in) >> inwordpointer) % (1U << 4); - ++in; - } -} - -void __fastunpack5(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<5>::Unpack(in, out); -} - -void __fastunpack6(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<6>::Unpack(in, out); -} - -void __fastunpack7(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<7>::Unpack(in, out); -} - -void __fastunpack8(const uint8_t *__restrict in, uint8_t *__restrict out) { - for (int k = 0; k < 8; ++k) - out[k] = in[k]; -} - - -// fastunpack for 16 bits -void __fastunpack1(const uint16_t *__restrict in, uint16_t *__restrict out) { - 
Unroller16<1>::Unpack(in, out); -} - -void __fastunpack2(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<2>::Unpack(in, out); -} - -void __fastunpack3(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<3>::Unpack(in, out); -} - -void __fastunpack4(const uint16_t *__restrict in, uint16_t *__restrict out) { - for (uint16_t outer = 0; outer < 4; ++outer) { - for (uint16_t inwordpointer = 0; inwordpointer < 16; inwordpointer += 4) - *(out++) = ((*in) >> inwordpointer) % (1U << 4); - ++in; - } -} - -void __fastunpack5(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<5>::Unpack(in, out); -} - -void __fastunpack6(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<6>::Unpack(in, out); -} - -void __fastunpack7(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<7>::Unpack(in, out); -} - -void __fastunpack8(const uint16_t *__restrict in, uint16_t *__restrict out) { - for (uint16_t outer = 0; outer < 8; ++outer) { - for (uint16_t inwordpointer = 0; inwordpointer < 16; inwordpointer += 8) - *(out++) = ((*in) >> inwordpointer) % (1U << 8); - ++in; - } -} - -void __fastunpack9(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<9>::Unpack(in, out); -} - -void __fastunpack10(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<10>::Unpack(in, out); -} - -void __fastunpack11(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<11>::Unpack(in, out); -} - -void __fastunpack12(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<12>::Unpack(in, out); -} - -void __fastunpack13(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<13>::Unpack(in, out); -} - -void __fastunpack14(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<14>::Unpack(in, out); -} - -void __fastunpack15(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<15>::Unpack(in, 
out); -} - -void __fastunpack16(const uint16_t *__restrict in, uint16_t *__restrict out) { - for (int k = 0; k < 16; ++k) - out[k] = in[k]; -} - -// fastunpack for 32 bits -void __fastunpack1(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<1>::Unpack(in, out); -} - -void __fastunpack2(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<2>::Unpack(in, out); -} - -void __fastunpack3(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<3>::Unpack(in, out); -} - -void __fastunpack4(const uint32_t *__restrict in, uint32_t *__restrict out) { - for (uint32_t outer = 0; outer < 4; ++outer) { - for (uint32_t inwordpointer = 0; inwordpointer < 32; inwordpointer += 4) - *(out++) = ((*in) >> inwordpointer) % (1U << 4); - ++in; - } -} - -void __fastunpack5(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<5>::Unpack(in, out); -} - -void __fastunpack6(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<6>::Unpack(in, out); -} - -void __fastunpack7(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<7>::Unpack(in, out); -} - -void __fastunpack8(const uint32_t *__restrict in, uint32_t *__restrict out) { - for (uint32_t outer = 0; outer < 8; ++outer) { - for (uint32_t inwordpointer = 0; inwordpointer < 32; inwordpointer += 8) - *(out++) = ((*in) >> inwordpointer) % (1U << 8); - ++in; - } -} - -void __fastunpack9(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<9>::Unpack(in, out); -} - -void __fastunpack10(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<10>::Unpack(in, out); -} - -void __fastunpack11(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<11>::Unpack(in, out); -} - -void __fastunpack12(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<12>::Unpack(in, out); -} - -void __fastunpack13(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<13>::Unpack(in, out); -} - 
-void __fastunpack14(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<14>::Unpack(in, out); -} - -void __fastunpack15(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<15>::Unpack(in, out); -} - -void __fastunpack16(const uint32_t *__restrict in, uint32_t *__restrict out) { - for (uint32_t outer = 0; outer < 16; ++outer) { - for (uint32_t inwordpointer = 0; inwordpointer < 32; inwordpointer += 16) - *(out++) = ((*in) >> inwordpointer) % (1U << 16); - ++in; - } -} - -void __fastunpack17(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<17>::Unpack(in, out); -} - -void __fastunpack18(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<18>::Unpack(in, out); -} - -void __fastunpack19(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<19>::Unpack(in, out); -} - -void __fastunpack20(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<20>::Unpack(in, out); -} - -void __fastunpack21(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<21>::Unpack(in, out); -} - -void __fastunpack22(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<22>::Unpack(in, out); -} - -void __fastunpack23(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<23>::Unpack(in, out); -} - -void __fastunpack24(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<24>::Unpack(in, out); -} - -void __fastunpack25(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<25>::Unpack(in, out); -} - -void __fastunpack26(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<26>::Unpack(in, out); -} - -void __fastunpack27(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<27>::Unpack(in, out); -} - -void __fastunpack28(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<28>::Unpack(in, out); -} - -void __fastunpack29(const uint32_t *__restrict in, uint32_t *__restrict 
out) { - Unroller<29>::Unpack(in, out); -} - -void __fastunpack30(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<30>::Unpack(in, out); -} - -void __fastunpack31(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<31>::Unpack(in, out); -} - -void __fastunpack32(const uint32_t *__restrict in, uint32_t *__restrict out) { - for (int k = 0; k < 32; ++k) - out[k] = in[k]; -} - -// fastupack for 64 bits -void __fastunpack1(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<1>::Unpack(in, out); -} - -void __fastunpack2(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<2>::Unpack(in, out); -} - -void __fastunpack3(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<3>::Unpack(in, out); -} - -void __fastunpack4(const uint32_t *__restrict in, uint64_t *__restrict out) { - for (uint32_t outer = 0; outer < 4; ++outer) { - for (uint32_t inwordpointer = 0; inwordpointer < 32; inwordpointer += 4) - *(out++) = ((*in) >> inwordpointer) % (1U << 4); - ++in; - } -} - -void __fastunpack5(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<5>::Unpack(in, out); -} - -void __fastunpack6(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<6>::Unpack(in, out); -} - -void __fastunpack7(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<7>::Unpack(in, out); -} - -void __fastunpack8(const uint32_t *__restrict in, uint64_t *__restrict out) { - for (uint32_t outer = 0; outer < 8; ++outer) { - for (uint32_t inwordpointer = 0; inwordpointer < 32; inwordpointer += 8) { - *(out++) = ((*in) >> inwordpointer) % (1U << 8); - } - ++in; - } -} - -void __fastunpack9(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<9>::Unpack(in, out); -} - -void __fastunpack10(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<10>::Unpack(in, out); -} - -void __fastunpack11(const uint32_t *__restrict in, uint64_t *__restrict out) { 
- Unroller<11>::Unpack(in, out); -} - -void __fastunpack12(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<12>::Unpack(in, out); -} - -void __fastunpack13(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<13>::Unpack(in, out); -} - -void __fastunpack14(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<14>::Unpack(in, out); -} - -void __fastunpack15(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<15>::Unpack(in, out); -} - -void __fastunpack16(const uint32_t *__restrict in, uint64_t *__restrict out) { - for (uint32_t outer = 0; outer < 16; ++outer) { - for (uint32_t inwordpointer = 0; inwordpointer < 32; inwordpointer += 16) - *(out++) = ((*in) >> inwordpointer) % (1U << 16); - ++in; - } -} - -void __fastunpack17(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<17>::Unpack(in, out); -} - -void __fastunpack18(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<18>::Unpack(in, out); -} - -void __fastunpack19(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<19>::Unpack(in, out); -} - -void __fastunpack20(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<20>::Unpack(in, out); -} - -void __fastunpack21(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<21>::Unpack(in, out); -} - -void __fastunpack22(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<22>::Unpack(in, out); -} - -void __fastunpack23(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<23>::Unpack(in, out); -} - -void __fastunpack24(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<24>::Unpack(in, out); -} - -void __fastunpack25(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<25>::Unpack(in, out); -} - -void __fastunpack26(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<26>::Unpack(in, out); -} - -void __fastunpack27(const 
uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<27>::Unpack(in, out); -} - -void __fastunpack28(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<28>::Unpack(in, out); -} - -void __fastunpack29(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<29>::Unpack(in, out); -} - -void __fastunpack30(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<30>::Unpack(in, out); -} - -void __fastunpack31(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<31>::Unpack(in, out); -} - -void __fastunpack32(const uint32_t *__restrict in, uint64_t *__restrict out) { - for (int k = 0; k < 32; ++k) - out[k] = in[k]; -} - -void __fastunpack33(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<33>::Unpack(in, out); -} - -void __fastunpack34(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<34>::Unpack(in, out); -} - -void __fastunpack35(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<35>::Unpack(in, out); -} - -void __fastunpack36(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<36>::Unpack(in, out); -} - -void __fastunpack37(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<37>::Unpack(in, out); -} - -void __fastunpack38(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<38>::Unpack(in, out); -} - -void __fastunpack39(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<39>::Unpack(in, out); -} - -void __fastunpack40(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<40>::Unpack(in, out); -} - -void __fastunpack41(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<41>::Unpack(in, out); -} - -void __fastunpack42(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<42>::Unpack(in, out); -} - -void __fastunpack43(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<43>::Unpack(in, out); -} - 
-void __fastunpack44(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<44>::Unpack(in, out); -} - -void __fastunpack45(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<45>::Unpack(in, out); -} - -void __fastunpack46(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<46>::Unpack(in, out); -} - -void __fastunpack47(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<47>::Unpack(in, out); -} - -void __fastunpack48(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<48>::Unpack(in, out); -} - -void __fastunpack49(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<49>::Unpack(in, out); -} - -void __fastunpack50(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<50>::Unpack(in, out); -} - -void __fastunpack51(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<51>::Unpack(in, out); -} - -void __fastunpack52(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<52>::Unpack(in, out); -} - -void __fastunpack53(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<53>::Unpack(in, out); -} - -void __fastunpack54(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<54>::Unpack(in, out); -} - -void __fastunpack55(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<55>::Unpack(in, out); -} - -void __fastunpack56(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<56>::Unpack(in, out); -} - -void __fastunpack57(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<57>::Unpack(in, out); -} - -void __fastunpack58(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<58>::Unpack(in, out); -} - -void __fastunpack59(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<59>::Unpack(in, out); -} - -void __fastunpack60(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<60>::Unpack(in, 
out); -} - -void __fastunpack61(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<61>::Unpack(in, out); -} - -void __fastunpack62(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<62>::Unpack(in, out); -} - -void __fastunpack63(const uint32_t *__restrict in, uint64_t *__restrict out) { - Unroller<63>::Unpack(in, out); -} - -void __fastunpack64(const uint32_t *__restrict in, uint64_t *__restrict out) { - for (int k = 0; k < 32; ++k) { - out[k] = in[k * 2]; - out[k] |= static_cast(in[k * 2 + 1]) << 32; - } -} - -// fastpack for 8 bits - -void __fastpack1(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<1>::Pack(in, out); -} - -void __fastpack2(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<2>::Pack(in, out); -} - -void __fastpack3(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<3>::Pack(in, out); -} - -void __fastpack4(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<4>::Pack(in, out); -} - -void __fastpack5(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<5>::Pack(in, out); -} - -void __fastpack6(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<6>::Pack(in, out); -} - -void __fastpack7(const uint8_t *__restrict in, uint8_t *__restrict out) { - Unroller8<7>::Pack(in, out); -} - -void __fastpack8(const uint8_t *__restrict in, uint8_t *__restrict out) { - for (int k = 0; k < 8; ++k) - out[k] = in[k]; -} - -// fastpack for 16 bits - -void __fastpack1(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<1>::Pack(in, out); -} - -void __fastpack2(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<2>::Pack(in, out); -} - -void __fastpack3(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<3>::Pack(in, out); -} - -void __fastpack4(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<4>::Pack(in, out); -} - -void __fastpack5(const 
uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<5>::Pack(in, out); -} - -void __fastpack6(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<6>::Pack(in, out); -} - -void __fastpack7(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<7>::Pack(in, out); -} - -void __fastpack8(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<8>::Pack(in, out); -} - -void __fastpack9(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<9>::Pack(in, out); -} - -void __fastpack10(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<10>::Pack(in, out); -} - -void __fastpack11(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<11>::Pack(in, out); -} - -void __fastpack12(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<12>::Pack(in, out); -} - -void __fastpack13(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<13>::Pack(in, out); -} - -void __fastpack14(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<14>::Pack(in, out); -} - -void __fastpack15(const uint16_t *__restrict in, uint16_t *__restrict out) { - Unroller16<15>::Pack(in, out); -} - -void __fastpack16(const uint16_t *__restrict in, uint16_t *__restrict out) { - for (int k = 0; k < 16; ++k) - out[k] = in[k]; -} - - -// fastpack for 32 bits - -void __fastpack1(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<1>::Pack(in, out); -} - -void __fastpack2(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<2>::Pack(in, out); -} - -void __fastpack3(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<3>::Pack(in, out); -} - -void __fastpack4(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<4>::Pack(in, out); -} - -void __fastpack5(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<5>::Pack(in, out); -} - -void __fastpack6(const uint32_t 
*__restrict in, uint32_t *__restrict out) { - Unroller<6>::Pack(in, out); -} - -void __fastpack7(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<7>::Pack(in, out); -} - -void __fastpack8(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<8>::Pack(in, out); -} - -void __fastpack9(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<9>::Pack(in, out); -} - -void __fastpack10(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<10>::Pack(in, out); -} - -void __fastpack11(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<11>::Pack(in, out); -} - -void __fastpack12(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<12>::Pack(in, out); -} - -void __fastpack13(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<13>::Pack(in, out); -} - -void __fastpack14(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<14>::Pack(in, out); -} - -void __fastpack15(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<15>::Pack(in, out); -} - -void __fastpack16(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<16>::Pack(in, out); -} - -void __fastpack17(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<17>::Pack(in, out); -} - -void __fastpack18(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<18>::Pack(in, out); -} - -void __fastpack19(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<19>::Pack(in, out); -} - -void __fastpack20(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<20>::Pack(in, out); -} - -void __fastpack21(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<21>::Pack(in, out); -} - -void __fastpack22(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<22>::Pack(in, out); -} - -void __fastpack23(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<23>::Pack(in, 
out); -} - -void __fastpack24(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<24>::Pack(in, out); -} - -void __fastpack25(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<25>::Pack(in, out); -} - -void __fastpack26(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<26>::Pack(in, out); -} - -void __fastpack27(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<27>::Pack(in, out); -} - -void __fastpack28(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<28>::Pack(in, out); -} - -void __fastpack29(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<29>::Pack(in, out); -} - -void __fastpack30(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<30>::Pack(in, out); -} - -void __fastpack31(const uint32_t *__restrict in, uint32_t *__restrict out) { - Unroller<31>::Pack(in, out); -} - -void __fastpack32(const uint32_t *__restrict in, uint32_t *__restrict out) { - for (int k = 0; k < 32; ++k) - out[k] = in[k]; -} - -// fastpack for 64 bits - -void __fastpack1(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<1>::Pack(in, out); -} - -void __fastpack2(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<2>::Pack(in, out); -} - -void __fastpack3(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<3>::Pack(in, out); -} - -void __fastpack4(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<4>::Pack(in, out); -} - -void __fastpack5(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<5>::Pack(in, out); -} - -void __fastpack6(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<6>::Pack(in, out); -} - -void __fastpack7(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<7>::Pack(in, out); -} - -void __fastpack8(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<8>::Pack(in, out); -} - -void 
__fastpack9(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<9>::Pack(in, out); -} - -void __fastpack10(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<10>::Pack(in, out); -} - -void __fastpack11(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<11>::Pack(in, out); -} - -void __fastpack12(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<12>::Pack(in, out); -} - -void __fastpack13(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<13>::Pack(in, out); -} - -void __fastpack14(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<14>::Pack(in, out); -} - -void __fastpack15(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<15>::Pack(in, out); -} - -void __fastpack16(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<16>::Pack(in, out); -} - -void __fastpack17(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<17>::Pack(in, out); -} - -void __fastpack18(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<18>::Pack(in, out); -} - -void __fastpack19(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<19>::Pack(in, out); -} - -void __fastpack20(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<20>::Pack(in, out); -} - -void __fastpack21(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<21>::Pack(in, out); -} - -void __fastpack22(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<22>::Pack(in, out); -} - -void __fastpack23(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<23>::Pack(in, out); -} - -void __fastpack24(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<24>::Pack(in, out); -} - -void __fastpack25(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<25>::Pack(in, out); -} - -void __fastpack26(const uint64_t *__restrict in, uint32_t 
*__restrict out) { - Unroller<26>::Pack(in, out); -} - -void __fastpack27(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<27>::Pack(in, out); -} - -void __fastpack28(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<28>::Pack(in, out); -} - -void __fastpack29(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<29>::Pack(in, out); -} - -void __fastpack30(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<30>::Pack(in, out); -} - -void __fastpack31(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<31>::Pack(in, out); -} - -void __fastpack32(const uint64_t *__restrict in, uint32_t *__restrict out) { - for (int k = 0; k < 32; ++k) { - out[k] = static_cast(in[k]); - } -} - -void __fastpack33(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<33>::Pack(in, out); -} - -void __fastpack34(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<34>::Pack(in, out); -} - -void __fastpack35(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<35>::Pack(in, out); -} - -void __fastpack36(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<36>::Pack(in, out); -} - -void __fastpack37(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<37>::Pack(in, out); -} - -void __fastpack38(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<38>::Pack(in, out); -} - -void __fastpack39(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<39>::Pack(in, out); -} - -void __fastpack40(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<40>::Pack(in, out); -} - -void __fastpack41(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<41>::Pack(in, out); -} - -void __fastpack42(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<42>::Pack(in, out); -} - -void __fastpack43(const uint64_t *__restrict in, uint32_t *__restrict out) { - 
Unroller<43>::Pack(in, out); -} - -void __fastpack44(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<44>::Pack(in, out); -} - -void __fastpack45(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<45>::Pack(in, out); -} - -void __fastpack46(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<46>::Pack(in, out); -} - -void __fastpack47(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<47>::Pack(in, out); -} - -void __fastpack48(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<48>::Pack(in, out); -} - -void __fastpack49(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<49>::Pack(in, out); -} - -void __fastpack50(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<50>::Pack(in, out); -} - -void __fastpack51(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<51>::Pack(in, out); -} - -void __fastpack52(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<52>::Pack(in, out); -} - -void __fastpack53(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<53>::Pack(in, out); -} - -void __fastpack54(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<54>::Pack(in, out); -} - -void __fastpack55(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<55>::Pack(in, out); -} - -void __fastpack56(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<56>::Pack(in, out); -} - -void __fastpack57(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<57>::Pack(in, out); -} - -void __fastpack58(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<58>::Pack(in, out); -} - -void __fastpack59(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<59>::Pack(in, out); -} - -void __fastpack60(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<60>::Pack(in, out); -} - -void __fastpack61(const 
uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<61>::Pack(in, out); -} - -void __fastpack62(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<62>::Pack(in, out); -} - -void __fastpack63(const uint64_t *__restrict in, uint32_t *__restrict out) { - Unroller<63>::Pack(in, out); -} - -void __fastpack64(const uint64_t *__restrict in, uint32_t *__restrict out) { - for (int i = 0; i < 32; ++i) { - out[2 * i] = static_cast(in[i]); - out[2 * i + 1] = in[i] >> 32; - } -} -} // namespace internal -} // namespace duckdb_fastpforlib - - -// LICENSE_CHANGE_END - -#endif diff --git a/lib/duckdb-fmt.cpp b/lib/duckdb-fmt.cpp deleted file mode 100644 index fcf2639b..00000000 --- a/lib/duckdb-fmt.cpp +++ /dev/null @@ -1,1384 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #5 -// See the end of this file for a list - -// Formatting library for C++ -// -// Copyright (c) 2012 - 2016, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #5 -// See the end of this file for a list - -// Formatting library for C++ - implementation -// -// Copyright (c) 2012 - 2016, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. 
- -#ifndef FMT_FORMAT_INL_H_ -#define FMT_FORMAT_INL_H_ - - - -#include -#include -#include -#include -#include -#include // for std::memmove -#include -#if FMT_EXCEPTIONS -# define FMT_TRY try -# define FMT_CATCH(x) catch (x) -#else -# define FMT_TRY if (true) -# define FMT_CATCH(x) if (false) -#endif - -#ifdef _MSC_VER -# pragma warning(push) -# pragma warning(disable : 4702) // unreachable code -#endif - -// Dummy implementations of strerror_r and strerror_s called if corresponding -// system functions are not available. -inline duckdb_fmt::internal::null<> strerror_r(int, char*, ...) { return {}; } -inline duckdb_fmt::internal::null<> strerror_s(char*, std::size_t, ...) { return {}; } - -FMT_BEGIN_NAMESPACE -namespace internal { - -#ifndef _MSC_VER -# define FMT_SNPRINTF snprintf -#else // _MSC_VER -inline int fmt_snprintf(char* buffer, size_t size, const char* format, ...) { - va_list args; - va_start(args, format); - int result = vsnprintf_s(buffer, size, _TRUNCATE, format, args); - va_end(args); - return result; -} -# define FMT_SNPRINTF fmt_snprintf -#endif // _MSC_VER - -using format_func = void (*)(internal::buffer&, int, string_view); - -// A portable thread-safe version of strerror. -// Sets buffer to point to a string describing the error code. -// This can be either a pointer to a string stored in buffer, -// or a pointer to some static immutable string. -// Returns one of the following values: -// 0 - success -// ERANGE - buffer is not large enough to store the error message -// other - failure -// Buffer should be at least of size 1. -FMT_FUNC int safe_strerror(int error_code, char*& buffer, - std::size_t buffer_size) FMT_NOEXCEPT { - FMT_ASSERT(buffer != nullptr && buffer_size != 0, "invalid buffer"); - - class dispatcher { - private: - int error_code_; - char*& buffer_; - std::size_t buffer_size_; - - // A noop assignment operator to avoid bogus warnings. 
- void operator=(const dispatcher&) {} - - // Handle the result of XSI-compliant version of strerror_r. - int handle(int result) { - // glibc versions before 2.13 return result in errno. - return result == -1 ? errno : result; - } - - // Handle the result of GNU-specific version of strerror_r. - int handle(char* message) { - // If the buffer is full then the message is probably truncated. - if (message == buffer_ && strlen(buffer_) == buffer_size_ - 1) - return ERANGE; - buffer_ = message; - return 0; - } - - // Handle the case when strerror_r is not available. - int handle(internal::null<>) { - return fallback(strerror_s(buffer_, buffer_size_, error_code_)); - } - - // Fallback to strerror_s when strerror_r is not available. - int fallback(int result) { - // If the buffer is full then the message is probably truncated. - return result == 0 && strlen(buffer_) == buffer_size_ - 1 ? ERANGE - : result; - } - -#if !FMT_MSC_VER - // Fallback to strerror if strerror_r and strerror_s are not available. - int fallback(internal::null<>) { - errno = 0; - buffer_ = strerror(error_code_); - return errno; - } -#endif - - public: - dispatcher(int err_code, char*& buf, std::size_t buf_size) - : error_code_(err_code), buffer_(buf), buffer_size_(buf_size) {} - - int run() { return handle(strerror_r(error_code_, buffer_, buffer_size_)); } - }; - return dispatcher(error_code, buffer, buffer_size).run(); -} - -FMT_FUNC void format_error_code(internal::buffer& out, int error_code, - string_view message) FMT_NOEXCEPT { - // Report error code making sure that the output fits into - // inline_buffer_size to avoid dynamic memory allocation and potential - // bad_alloc. - out.resize(0); - static const char SEP[] = ": "; - static const char ERROR_STR[] = "error "; - // Subtract 2 to account for terminating null characters in SEP and ERROR_STR. 
- std::size_t error_code_size = sizeof(SEP) + sizeof(ERROR_STR) - 2; - auto abs_value = static_cast>(error_code); - if (internal::is_negative(error_code)) { - abs_value = 0 - abs_value; - ++error_code_size; - } - error_code_size += internal::to_unsigned(internal::count_digits(abs_value)); - internal::writer w(out); - if (message.size() <= inline_buffer_size - error_code_size) { - w.write(message); - w.write(SEP); - } - w.write(ERROR_STR); - w.write(error_code); - assert(out.size() <= inline_buffer_size); -} - -FMT_FUNC void report_error(format_func func, int error_code, - string_view message) FMT_NOEXCEPT { - memory_buffer full_message; - func(full_message, error_code, message); - /*// R does not allow us to have a reference to stderr even if we are not using it - // Don't use fwrite_fully because the latter may throw. - (void)std::fwrite(full_message.data(), full_message.size(), 1, stderr); - std::fputc('\n', stderr); - */ -} -} // namespace internal - -template -FMT_FUNC std::string internal::grouping_impl(locale_ref) { - return "\03"; -} -template -FMT_FUNC Char internal::thousands_sep_impl(locale_ref) { - return ','; -} -template -FMT_FUNC Char internal::decimal_point_impl(locale_ref) { - return '.'; -} - -namespace internal { - -template <> FMT_FUNC int count_digits<4>(internal::fallback_uintptr n) { - // fallback_uintptr is always stored in little endian. - int i = static_cast(sizeof(void*)) - 1; - while (i > 0 && n.value[i] == 0) --i; - auto char_digits = std::numeric_limits::digits / 4; - return i >= 0 ? 
i * char_digits + count_digits<4, unsigned>(n.value[i]) : 1; -} - -template -const char basic_data::digits[] = - "0001020304050607080910111213141516171819" - "2021222324252627282930313233343536373839" - "4041424344454647484950515253545556575859" - "6061626364656667686970717273747576777879" - "8081828384858687888990919293949596979899"; - -template -const char basic_data::hex_digits[] = "0123456789abcdef"; - -#define FMT_POWERS_OF_10(factor) \ - factor * 10, (factor)*100, (factor)*1000, (factor)*10000, (factor)*100000, \ - (factor)*1000000, (factor)*10000000, (factor)*100000000, \ - (factor)*1000000000 - -template -const uint64_t basic_data::powers_of_10_64[] = { - 1, FMT_POWERS_OF_10(1), FMT_POWERS_OF_10(1000000000ULL), - 10000000000000000000ULL}; - -template -const uint32_t basic_data::zero_or_powers_of_10_32[] = {0, - FMT_POWERS_OF_10(1)}; - -template -const uint64_t basic_data::zero_or_powers_of_10_64[] = { - 0, FMT_POWERS_OF_10(1), FMT_POWERS_OF_10(1000000000ULL), - 10000000000000000000ULL}; - -// Normalized 64-bit significands of pow(10, k), for k = -348, -340, ..., 340. -// These are generated by support/compute-powers.py. 
-template -const uint64_t basic_data::pow10_significands[] = { - 0xfa8fd5a0081c0288, 0xbaaee17fa23ebf76, 0x8b16fb203055ac76, - 0xcf42894a5dce35ea, 0x9a6bb0aa55653b2d, 0xe61acf033d1a45df, - 0xab70fe17c79ac6ca, 0xff77b1fcbebcdc4f, 0xbe5691ef416bd60c, - 0x8dd01fad907ffc3c, 0xd3515c2831559a83, 0x9d71ac8fada6c9b5, - 0xea9c227723ee8bcb, 0xaecc49914078536d, 0x823c12795db6ce57, - 0xc21094364dfb5637, 0x9096ea6f3848984f, 0xd77485cb25823ac7, - 0xa086cfcd97bf97f4, 0xef340a98172aace5, 0xb23867fb2a35b28e, - 0x84c8d4dfd2c63f3b, 0xc5dd44271ad3cdba, 0x936b9fcebb25c996, - 0xdbac6c247d62a584, 0xa3ab66580d5fdaf6, 0xf3e2f893dec3f126, - 0xb5b5ada8aaff80b8, 0x87625f056c7c4a8b, 0xc9bcff6034c13053, - 0x964e858c91ba2655, 0xdff9772470297ebd, 0xa6dfbd9fb8e5b88f, - 0xf8a95fcf88747d94, 0xb94470938fa89bcf, 0x8a08f0f8bf0f156b, - 0xcdb02555653131b6, 0x993fe2c6d07b7fac, 0xe45c10c42a2b3b06, - 0xaa242499697392d3, 0xfd87b5f28300ca0e, 0xbce5086492111aeb, - 0x8cbccc096f5088cc, 0xd1b71758e219652c, 0x9c40000000000000, - 0xe8d4a51000000000, 0xad78ebc5ac620000, 0x813f3978f8940984, - 0xc097ce7bc90715b3, 0x8f7e32ce7bea5c70, 0xd5d238a4abe98068, - 0x9f4f2726179a2245, 0xed63a231d4c4fb27, 0xb0de65388cc8ada8, - 0x83c7088e1aab65db, 0xc45d1df942711d9a, 0x924d692ca61be758, - 0xda01ee641a708dea, 0xa26da3999aef774a, 0xf209787bb47d6b85, - 0xb454e4a179dd1877, 0x865b86925b9bc5c2, 0xc83553c5c8965d3d, - 0x952ab45cfa97a0b3, 0xde469fbd99a05fe3, 0xa59bc234db398c25, - 0xf6c69a72a3989f5c, 0xb7dcbf5354e9bece, 0x88fcf317f22241e2, - 0xcc20ce9bd35c78a5, 0x98165af37b2153df, 0xe2a0b5dc971f303a, - 0xa8d9d1535ce3b396, 0xfb9b7cd9a4a7443c, 0xbb764c4ca7a44410, - 0x8bab8eefb6409c1a, 0xd01fef10a657842c, 0x9b10a4e5e9913129, - 0xe7109bfba19c0c9d, 0xac2820d9623bf429, 0x80444b5e7aa7cf85, - 0xbf21e44003acdd2d, 0x8e679c2f5e44ff8f, 0xd433179d9c8cb841, - 0x9e19db92b4e31ba9, 0xeb96bf6ebadf77d9, 0xaf87023b9bf0ee6b, -}; - -// Binary exponents of pow(10, k), for k = -348, -340, ..., 340, corresponding -// to significands above. 
-template -const int16_t basic_data::pow10_exponents[] = { - -1220, -1193, -1166, -1140, -1113, -1087, -1060, -1034, -1007, -980, -954, - -927, -901, -874, -847, -821, -794, -768, -741, -715, -688, -661, - -635, -608, -582, -555, -529, -502, -475, -449, -422, -396, -369, - -343, -316, -289, -263, -236, -210, -183, -157, -130, -103, -77, - -50, -24, 3, 30, 56, 83, 109, 136, 162, 189, 216, - 242, 269, 295, 322, 348, 375, 402, 428, 455, 481, 508, - 534, 561, 588, 614, 641, 667, 694, 720, 747, 774, 800, - 827, 853, 880, 907, 933, 960, 986, 1013, 1039, 1066}; - -template -const char basic_data::foreground_color[] = "\x1b[38;2;"; -template -const char basic_data::background_color[] = "\x1b[48;2;"; -template const char basic_data::reset_color[] = "\x1b[0m"; -template const wchar_t basic_data::wreset_color[] = L"\x1b[0m"; -template const char basic_data::signs[] = {0, '-', '+', ' '}; - -template struct bits { - static FMT_CONSTEXPR_DECL const int value = - static_cast(sizeof(T) * std::numeric_limits::digits); -}; - -class fp; -template fp normalize(fp value); - -// Lower (upper) boundary is a value half way between a floating-point value -// and its predecessor (successor). Boundaries have the same exponent as the -// value so only significands are stored. -struct boundaries { - uint64_t lower; - uint64_t upper; -}; - -// A handmade floating-point number f * pow(2, e). -class fp { - private: - using significand_type = uint64_t; - - // All sizes are in bits. - // Subtract 1 to account for an implicit most significant bit in the - // normalized form. 
- static FMT_CONSTEXPR_DECL const int double_significand_size = - std::numeric_limits::digits - 1; - static FMT_CONSTEXPR_DECL const uint64_t implicit_bit = - 1ULL << double_significand_size; - - public: - significand_type f; - int e; - - static FMT_CONSTEXPR_DECL const int significand_size = - bits::value; - - fp() : f(0), e(0) {} - fp(uint64_t f_val, int e_val) : f(f_val), e(e_val) {} - - // Constructs fp from an IEEE754 double. It is a template to prevent compile - // errors on platforms where double is not IEEE754. - template explicit fp(Double d) { assign(d); } - - // Normalizes the value converted from double and multiplied by (1 << SHIFT). - template friend fp normalize(fp value) { - // Handle subnormals. - const auto shifted_implicit_bit = fp::implicit_bit << SHIFT; - while ((value.f & shifted_implicit_bit) == 0) { - value.f <<= 1; - --value.e; - } - // Subtract 1 to account for hidden bit. - const auto offset = - fp::significand_size - fp::double_significand_size - SHIFT - 1; - value.f <<= offset; - value.e -= offset; - return value; - } - - // Assigns d to this and return true iff predecessor is closer than successor. - template - bool assign(Double d) { - // Assume double is in the format [sign][exponent][significand]. - using limits = std::numeric_limits; - const int exponent_size = - bits::value - double_significand_size - 1; // -1 for sign - const uint64_t significand_mask = implicit_bit - 1; - const uint64_t exponent_mask = (~0ULL >> 1) & ~significand_mask; - const int exponent_bias = (1 << exponent_size) - limits::max_exponent - 1; - auto u = bit_cast(d); - f = u & significand_mask; - auto biased_e = (u & exponent_mask) >> double_significand_size; - // Predecessor is closer if d is a normalized power of 2 (f == 0) other than - // the smallest normalized number (biased_e > 1). - bool is_predecessor_closer = f == 0 && biased_e > 1; - if (biased_e != 0) - f += implicit_bit; - else - biased_e = 1; // Subnormals use biased exponent 1 (min exponent). 
- e = static_cast(biased_e - exponent_bias - double_significand_size); - return is_predecessor_closer; - } - - template - bool assign(Double) { - *this = fp(); - return false; - } - - // Assigns d to this together with computing lower and upper boundaries, - // where a boundary is a value half way between the number and its predecessor - // (lower) or successor (upper). The upper boundary is normalized and lower - // has the same exponent but may be not normalized. - template boundaries assign_with_boundaries(Double d) { - bool is_lower_closer = assign(d); - fp lower = - is_lower_closer ? fp((f << 2) - 1, e - 2) : fp((f << 1) - 1, e - 1); - // 1 in normalize accounts for the exponent shift above. - fp upper = normalize<1>(fp((f << 1) + 1, e - 1)); - lower.f <<= lower.e - upper.e; - return boundaries{lower.f, upper.f}; - } - - template boundaries assign_float_with_boundaries(Double d) { - assign(d); - constexpr int min_normal_e = std::numeric_limits::min_exponent - - std::numeric_limits::digits; - significand_type half_ulp = 1 << (std::numeric_limits::digits - - std::numeric_limits::digits - 1); - if (min_normal_e > e) half_ulp <<= min_normal_e - e; - fp upper = normalize<0>(fp(f + half_ulp, e)); - fp lower = fp( - f - (half_ulp >> ((f == implicit_bit && e > min_normal_e) ? 1 : 0)), e); - lower.f <<= lower.e - upper.e; - return boundaries{lower.f, upper.f}; - } -}; - -inline bool operator==(fp x, fp y) { return x.f == y.f && x.e == y.e; } - -// Computes lhs * rhs / pow(2, 64) rounded to nearest with half-up tie breaking. -inline uint64_t multiply(uint64_t lhs, uint64_t rhs) { -#if FMT_USE_INT128 - auto product = static_cast<__uint128_t>(lhs) * rhs; - auto f = static_cast(product >> 64); - return (static_cast(product) & (1ULL << 63)) != 0 ? f + 1 : f; -#else - // Multiply 32-bit parts of significands. 
- uint64_t mask = (1ULL << 32) - 1; - uint64_t a = lhs >> 32, b = lhs & mask; - uint64_t c = rhs >> 32, d = rhs & mask; - uint64_t ac = a * c, bc = b * c, ad = a * d, bd = b * d; - // Compute mid 64-bit of result and round. - uint64_t mid = (bd >> 32) + (ad & mask) + (bc & mask) + (1U << 31); - return ac + (ad >> 32) + (bc >> 32) + (mid >> 32); -#endif -} - -inline fp operator*(fp x, fp y) { return {multiply(x.f, y.f), x.e + y.e + 64}; } - -// Returns a cached power of 10 `c_k = c_k.f * pow(2, c_k.e)` such that its -// (binary) exponent satisfies `min_exponent <= c_k.e <= min_exponent + 28`. -FMT_FUNC fp get_cached_power(int min_exponent, int& pow10_exponent) { - const uint64_t one_over_log2_10 = 0x4d104d42; // round(pow(2, 32) / log2(10)) - int index = static_cast( - static_cast( - (min_exponent + fp::significand_size - 1) * one_over_log2_10 + - ((uint64_t(1) << 32) - 1) // ceil - ) >> - 32 // arithmetic shift - ); - // Decimal exponent of the first (smallest) cached power of 10. - const int first_dec_exp = -348; - // Difference between 2 consecutive decimal exponents in cached powers of 10. - const int dec_exp_step = 8; - index = (index - first_dec_exp - 1) / dec_exp_step + 1; - pow10_exponent = first_dec_exp + index * dec_exp_step; - return {data::pow10_significands[index], data::pow10_exponents[index]}; -} - -// A simple accumulator to hold the sums of terms in bigint::square if uint128_t -// is not available. -struct accumulator { - uint64_t lower; - uint64_t upper; - - accumulator() : lower(0), upper(0) {} - explicit operator uint32_t() const { return static_cast(lower); } - - void operator+=(uint64_t n) { - lower += n; - if (lower < n) ++upper; - } - void operator>>=(int shift) { - assert(shift == 32); - (void)shift; - lower = (upper << 32) | (lower >> 32); - upper >>= 32; - } -}; - -class bigint { - private: - // A bigint is stored as an array of bigits (big digits), with bigit at index - // 0 being the least significant one. 
- using bigit = uint32_t; - using double_bigit = uint64_t; - enum { bigits_capacity = 32 }; - basic_memory_buffer bigits_; - int exp_; - - static FMT_CONSTEXPR_DECL const int bigit_bits = bits::value; - - friend struct formatter; - - void subtract_bigits(int index, bigit other, bigit& borrow) { - auto result = static_cast(bigits_[index]) - other - borrow; - bigits_[index] = static_cast(result); - borrow = static_cast(result >> (bigit_bits * 2 - 1)); - } - - void remove_leading_zeros() { - int num_bigits = static_cast(bigits_.size()) - 1; - while (num_bigits > 0 && bigits_[num_bigits] == 0) --num_bigits; - bigits_.resize(num_bigits + 1); - } - - // Computes *this -= other assuming aligned bigints and *this >= other. - void subtract_aligned(const bigint& other) { - FMT_ASSERT(other.exp_ >= exp_, "unaligned bigints"); - FMT_ASSERT(compare(*this, other) >= 0, ""); - bigit borrow = 0; - int i = other.exp_ - exp_; - for (int j = 0, n = static_cast(other.bigits_.size()); j != n; - ++i, ++j) { - subtract_bigits(i, other.bigits_[j], borrow); - } - while (borrow > 0) subtract_bigits(i, 0, borrow); - remove_leading_zeros(); - } - - void multiply(uint32_t value) { - const double_bigit wide_value = value; - bigit carry = 0; - for (size_t i = 0, n = bigits_.size(); i < n; ++i) { - double_bigit result = bigits_[i] * wide_value + carry; - bigits_[i] = static_cast(result); - carry = static_cast(result >> bigit_bits); - } - if (carry != 0) bigits_.push_back(carry); - } - - void multiply(uint64_t value) { - const bigit mask = ~bigit(0); - const double_bigit lower = value & mask; - const double_bigit upper = value >> bigit_bits; - double_bigit carry = 0; - for (size_t i = 0, n = bigits_.size(); i < n; ++i) { - double_bigit result = bigits_[i] * lower + (carry & mask); - carry = - bigits_[i] * upper + (result >> bigit_bits) + (carry >> bigit_bits); - bigits_[i] = static_cast(result); - } - while (carry != 0) { - bigits_.push_back(carry & mask); - carry >>= bigit_bits; - } - } - - 
public: - bigint() : exp_(0) {} - explicit bigint(uint64_t n) { assign(n); } - ~bigint() { assert(bigits_.capacity() <= bigits_capacity); } - - bigint(const bigint&) = delete; - void operator=(const bigint&) = delete; - - void assign(const bigint& other) { - bigits_.resize(other.bigits_.size()); - auto data = other.bigits_.data(); - std::copy(data, data + other.bigits_.size(), bigits_.data()); - exp_ = other.exp_; - } - - void assign(uint64_t n) { - int num_bigits = 0; - do { - bigits_[num_bigits++] = n & ~bigit(0); - n >>= bigit_bits; - } while (n != 0); - bigits_.resize(num_bigits); - exp_ = 0; - } - - int num_bigits() const { return static_cast(bigits_.size()) + exp_; } - - bigint& operator<<=(int shift) { - assert(shift >= 0); - exp_ += shift / bigit_bits; - shift %= bigit_bits; - if (shift == 0) return *this; - bigit carry = 0; - for (size_t i = 0, n = bigits_.size(); i < n; ++i) { - bigit c = bigits_[i] >> (bigit_bits - shift); - bigits_[i] = (bigits_[i] << shift) + carry; - carry = c; - } - if (carry != 0) bigits_.push_back(carry); - return *this; - } - - template bigint& operator*=(Int value) { - FMT_ASSERT(value > 0, ""); - multiply(uint32_or_64_or_128_t(value)); - return *this; - } - - friend int compare(const bigint& lhs, const bigint& rhs) { - int num_lhs_bigits = lhs.num_bigits(), num_rhs_bigits = rhs.num_bigits(); - if (num_lhs_bigits != num_rhs_bigits) - return num_lhs_bigits > num_rhs_bigits ? 1 : -1; - int i = static_cast(lhs.bigits_.size()) - 1; - int j = static_cast(rhs.bigits_.size()) - 1; - int end = i - j; - if (end < 0) end = 0; - for (; i >= end; --i, --j) { - bigit lhs_bigit = lhs.bigits_[i], rhs_bigit = rhs.bigits_[j]; - if (lhs_bigit != rhs_bigit) return lhs_bigit > rhs_bigit ? 1 : -1; - } - if (i != j) return i > j ? 1 : -1; - return 0; - } - - // Returns compare(lhs1 + lhs2, rhs). 
- friend int add_compare(const bigint& lhs1, const bigint& lhs2, - const bigint& rhs) { - int max_lhs_bigits = (std::max)(lhs1.num_bigits(), lhs2.num_bigits()); - int num_rhs_bigits = rhs.num_bigits(); - if (max_lhs_bigits + 1 < num_rhs_bigits) return -1; - if (max_lhs_bigits > num_rhs_bigits) return 1; - auto get_bigit = [](const bigint& n, int i) -> bigit { - return i >= n.exp_ && i < n.num_bigits() ? n.bigits_[i - n.exp_] : 0; - }; - double_bigit borrow = 0; - int min_exp = (std::min)((std::min)(lhs1.exp_, lhs2.exp_), rhs.exp_); - for (int i = num_rhs_bigits - 1; i >= min_exp; --i) { - double_bigit sum = - static_cast(get_bigit(lhs1, i)) + get_bigit(lhs2, i); - bigit rhs_bigit = get_bigit(rhs, i); - if (sum > rhs_bigit + borrow) return 1; - borrow = rhs_bigit + borrow - sum; - if (borrow > 1) return -1; - borrow <<= bigit_bits; - } - return borrow != 0 ? -1 : 0; - } - - // Assigns pow(10, exp) to this bigint. - void assign_pow10(int exp) { - assert(exp >= 0); - if (exp == 0) return assign(1); - // Find the top bit. - int bitmask = 1; - while (exp >= bitmask) bitmask <<= 1; - bitmask >>= 1; - // pow(10, exp) = pow(5, exp) * pow(2, exp). First compute pow(5, exp) by - // repeated squaring and multiplication. - assign(5); - bitmask >>= 1; - while (bitmask != 0) { - square(); - if ((exp & bitmask) != 0) *this *= 5; - bitmask >>= 1; - } - *this <<= exp; // Multiply by pow(2, exp) by shifting. - } - - void square() { - basic_memory_buffer n(std::move(bigits_)); - int num_bigits = static_cast(bigits_.size()); - int num_result_bigits = 2 * num_bigits; - bigits_.resize(num_result_bigits); - using accumulator_t = conditional_t; - auto sum = accumulator_t(); - for (int bigit_index = 0; bigit_index < num_bigits; ++bigit_index) { - // Compute bigit at position bigit_index of the result by adding - // cross-product terms n[i] * n[j] such that i + j == bigit_index. 
- for (int i = 0, j = bigit_index; j >= 0; ++i, --j) { - // Most terms are multiplied twice which can be optimized in the future. - sum += static_cast(n[i]) * n[j]; - } - bigits_[bigit_index] = static_cast(sum); - sum >>= bits::value; // Compute the carry. - } - // Do the same for the top half. - for (int bigit_index = num_bigits; bigit_index < num_result_bigits; - ++bigit_index) { - for (int j = num_bigits - 1, i = bigit_index - j; i < num_bigits;) - sum += static_cast(n[i++]) * n[j--]; - bigits_[bigit_index] = static_cast(sum); - sum >>= bits::value; - } - --num_result_bigits; - remove_leading_zeros(); - exp_ *= 2; - } - - // Divides this bignum by divisor, assigning the remainder to this and - // returning the quotient. - int divmod_assign(const bigint& divisor) { - FMT_ASSERT(this != &divisor, ""); - if (compare(*this, divisor) < 0) return 0; - int num_bigits = static_cast(bigits_.size()); - FMT_ASSERT(divisor.bigits_[divisor.bigits_.size() - 1] != 0, ""); - int exp_difference = exp_ - divisor.exp_; - if (exp_difference > 0) { - // Align bigints by adding trailing zeros to simplify subtraction. - bigits_.resize(num_bigits + exp_difference); - for (int i = num_bigits - 1, j = i + exp_difference; i >= 0; --i, --j) - bigits_[j] = bigits_[i]; - std::uninitialized_fill_n(bigits_.data(), exp_difference, 0); - exp_ -= exp_difference; - } - int quotient = 0; - do { - subtract_aligned(divisor); - ++quotient; - } while (compare(*this, divisor) >= 0); - return quotient; - } -}; - -enum round_direction { unknown, up, down }; - -// Given the divisor (normally a power of 10), the remainder = v % divisor for -// some number v and the error, returns whether v should be rounded up, down, or -// whether the rounding direction can't be determined due to error. -// error should be less than divisor / 2. -inline round_direction get_round_direction(uint64_t divisor, uint64_t remainder, - uint64_t error) { - FMT_ASSERT(remainder < divisor, ""); // divisor - remainder won't overflow. 
- FMT_ASSERT(error < divisor, ""); // divisor - error won't overflow. - FMT_ASSERT(error < divisor - error, ""); // error * 2 won't overflow. - // Round down if (remainder + error) * 2 <= divisor. - if (remainder <= divisor - remainder && error * 2 <= divisor - remainder * 2) - return down; - // Round up if (remainder - error) * 2 >= divisor. - if (remainder >= error && - remainder - error >= divisor - (remainder - error)) { - return up; - } - return unknown; -} - -namespace digits { -enum result { - more, // Generate more digits. - done, // Done generating digits. - error // Digit generation cancelled due to an error. -}; -} - -// Generates output using the Grisu digit-gen algorithm. -// error: the size of the region (lower, upper) outside of which numbers -// definitely do not round to value (Delta in Grisu3). -template -FMT_ALWAYS_INLINE digits::result grisu_gen_digits(fp value, uint64_t error, - int& exp, Handler& handler) { - const fp one(1ULL << -value.e, value.e); - // The integral part of scaled value (p1 in Grisu) = value / one. It cannot be - // zero because it contains a product of two 64-bit numbers with MSB set (due - // to normalization) - 1, shifted right by at most 60 bits. - auto integral = static_cast(value.f >> -one.e); - FMT_ASSERT(integral != 0, ""); - FMT_ASSERT(integral == value.f >> -one.e, ""); - // The fractional part of scaled value (p2 in Grisu) c = value % one. - uint64_t fractional = value.f & (one.f - 1); - exp = count_digits(integral); // kappa in Grisu. - // Divide by 10 to prevent overflow. - auto result = handler.on_start(data::powers_of_10_64[exp - 1] << -one.e, - value.f / 10, error * 10, exp); - if (result != digits::more) return result; - // Generate digits for the integral part. This can produce up to 10 digits. 
- do { - uint32_t digit = 0; - auto divmod_integral = [&](uint32_t divisor) { - digit = integral / divisor; - integral %= divisor; - }; - // This optimization by Milo Yip reduces the number of integer divisions by - // one per iteration. - switch (exp) { - case 10: - divmod_integral(1000000000); - break; - case 9: - divmod_integral(100000000); - break; - case 8: - divmod_integral(10000000); - break; - case 7: - divmod_integral(1000000); - break; - case 6: - divmod_integral(100000); - break; - case 5: - divmod_integral(10000); - break; - case 4: - divmod_integral(1000); - break; - case 3: - divmod_integral(100); - break; - case 2: - divmod_integral(10); - break; - case 1: - digit = integral; - integral = 0; - break; - default: - FMT_ASSERT(false, "invalid number of digits"); - } - --exp; - uint64_t remainder = - (static_cast(integral) << -one.e) + fractional; - result = handler.on_digit(static_cast('0' + digit), - data::powers_of_10_64[exp] << -one.e, remainder, - error, exp, true); - if (result != digits::more) return result; - } while (exp > 0); - // Generate digits for the fractional part. - for (;;) { - fractional *= 10; - error *= 10; - char digit = - static_cast('0' + static_cast(fractional >> -one.e)); - fractional &= one.f - 1; - --exp; - result = handler.on_digit(digit, one.f, fractional, error, exp, false); - if (result != digits::more) return result; - } -} - -// The fixed precision digit handler. -struct fixed_handler { - char* buf; - int size; - int precision; - int exp10; - bool fixed; - - digits::result on_start(uint64_t divisor, uint64_t remainder, uint64_t error, - int& exp) { - // Non-fixed formats require at least one digit and no precision adjustment. - if (!fixed) return digits::more; - // Adjust fixed precision by exponent because it is relative to decimal - // point. - precision += exp + exp10; - // Check if precision is satisfied just by leading zeros, e.g. - // format("{:.2f}", 0.001) gives "0.00" without generating any digits. 
- if (precision > 0) return digits::more; - if (precision < 0) return digits::done; - auto dir = get_round_direction(divisor, remainder, error); - if (dir == unknown) return digits::error; - buf[size++] = dir == up ? '1' : '0'; - return digits::done; - } - - digits::result on_digit(char digit, uint64_t divisor, uint64_t remainder, - uint64_t error, int, bool integral) { - FMT_ASSERT(remainder < divisor, ""); - buf[size++] = digit; - if (size < precision) return digits::more; - if (!integral) { - // Check if error * 2 < divisor with overflow prevention. - // The check is not needed for the integral part because error = 1 - // and divisor > (1 << 32) there. - if (error >= divisor || error >= divisor - error) return digits::error; - } else { - FMT_ASSERT(error == 1 && divisor > 2, ""); - } - auto dir = get_round_direction(divisor, remainder, error); - if (dir != up) return dir == down ? digits::done : digits::error; - ++buf[size - 1]; - for (int i = size - 1; i > 0 && buf[i] > '9'; --i) { - buf[i] = '0'; - ++buf[i - 1]; - } - if (buf[0] > '9') { - buf[0] = '1'; - buf[size++] = '0'; - } - return digits::done; - } -}; - -// The shortest representation digit handler. -struct grisu_shortest_handler { - char* buf; - int size; - // Distance between scaled value and upper bound (wp_W in Grisu3). - uint64_t diff; - - digits::result on_start(uint64_t, uint64_t, uint64_t, int&) { - return digits::more; - } - - // Decrement the generated number approaching value from above. - void round(uint64_t d, uint64_t divisor, uint64_t& remainder, - uint64_t error) { - while ( - remainder < d && error - remainder >= divisor && - (remainder + divisor < d || d - remainder >= remainder + divisor - d)) { - --buf[size - 1]; - remainder += divisor; - } - } - - // Implements Grisu's round_weed. 
- digits::result on_digit(char digit, uint64_t divisor, uint64_t remainder, - uint64_t error, int exp, bool integral) { - buf[size++] = digit; - if (remainder >= error) return digits::more; - uint64_t unit = integral ? 1 : data::powers_of_10_64[-exp]; - uint64_t up = (diff - 1) * unit; // wp_Wup - round(up, divisor, remainder, error); - uint64_t down = (diff + 1) * unit; // wp_Wdown - if (remainder < down && error - remainder >= divisor && - (remainder + divisor < down || - down - remainder > remainder + divisor - down)) { - return digits::error; - } - return 2 * unit <= remainder && remainder <= error - 4 * unit - ? digits::done - : digits::error; - } -}; - -// Formats value using a variation of the Fixed-Precision Positive -// Floating-Point Printout ((FPP)^2) algorithm by Steele & White: -// https://fmt.dev/p372-steele.pdf. -template -void fallback_format(Double d, buffer& buf, int& exp10) { - bigint numerator; // 2 * R in (FPP)^2. - bigint denominator; // 2 * S in (FPP)^2. - // lower and upper are differences between value and corresponding boundaries. - bigint lower; // (M^- in (FPP)^2). - bigint upper_store; // upper's value if different from lower. - bigint* upper = nullptr; // (M^+ in (FPP)^2). - fp value; - // Shift numerator and denominator by an extra bit or two (if lower boundary - // is closer) to make lower and upper integers. This eliminates multiplication - // by 2 during later computations. - // TODO: handle float - int shift = value.assign(d) ? 
2 : 1; - uint64_t significand = value.f << shift; - if (value.e >= 0) { - numerator.assign(significand); - numerator <<= value.e; - lower.assign(1); - lower <<= value.e; - if (shift != 1) { - upper_store.assign(1); - upper_store <<= value.e + 1; - upper = &upper_store; - } - denominator.assign_pow10(exp10); - denominator <<= 1; - } else if (exp10 < 0) { - numerator.assign_pow10(-exp10); - lower.assign(numerator); - if (shift != 1) { - upper_store.assign(numerator); - upper_store <<= 1; - upper = &upper_store; - } - numerator *= significand; - denominator.assign(1); - denominator <<= shift - value.e; - } else { - numerator.assign(significand); - denominator.assign_pow10(exp10); - denominator <<= shift - value.e; - lower.assign(1); - if (shift != 1) { - upper_store.assign(1ULL << 1); - upper = &upper_store; - } - } - if (!upper) upper = &lower; - // Invariant: value == (numerator / denominator) * pow(10, exp10). - bool even = (value.f & 1) == 0; - int num_digits = 0; - char* data = buf.data(); - for (;;) { - int digit = numerator.divmod_assign(denominator); - bool low = compare(numerator, lower) - even < 0; // numerator <[=] lower. - // numerator + upper >[=] pow10: - bool high = add_compare(numerator, *upper, denominator) + even > 0; - data[num_digits++] = static_cast('0' + digit); - if (low || high) { - if (!low) { - ++data[num_digits - 1]; - } else if (high) { - int result = add_compare(numerator, numerator, denominator); - // Round half to even. - if (result > 0 || (result == 0 && (digit % 2) != 0)) - ++data[num_digits - 1]; - } - buf.resize(num_digits); - exp10 -= num_digits - 1; - return; - } - numerator *= 10; - lower *= 10; - if (upper != &lower) *upper *= 10; - } -} - -// Formats value using the Grisu algorithm -// (https://www.cs.tufts.edu/~nr/cs257/archive/florian-loitsch/printf.pdf) -// if T is a IEEE754 binary32 or binary64 and snprintf otherwise. 
-template -int format_float(T value, int precision, float_specs specs, buffer& buf) { - static_assert(!std::is_same(), ""); - FMT_ASSERT(value >= 0, "value is negative"); - - const bool fixed = specs.format == float_format::fixed; - if (value <= 0) { // <= instead of == to silence a warning. - if (precision <= 0 || !fixed) { - buf.push_back('0'); - return 0; - } - buf.resize(to_unsigned(precision)); - std::uninitialized_fill_n(buf.data(), precision, '0'); - return -precision; - } - - if (!specs.use_grisu) return snprintf_float(value, precision, specs, buf); - - int exp = 0; - const int min_exp = -60; // alpha in Grisu. - int cached_exp10 = 0; // K in Grisu. - if (precision != -1) { - if (precision > 17) return snprintf_float(value, precision, specs, buf); - fp normalized = normalize(fp(value)); - const auto cached_pow = get_cached_power( - min_exp - (normalized.e + fp::significand_size), cached_exp10); - normalized = normalized * cached_pow; - fixed_handler handler{buf.data(), 0, precision, -cached_exp10, fixed}; - if (grisu_gen_digits(normalized, 1, exp, handler) == digits::error) - return snprintf_float(value, precision, specs, buf); - int num_digits = handler.size; - if (!fixed) { - // Remove trailing zeros. - while (num_digits > 0 && buf[num_digits - 1] == '0') { - --num_digits; - ++exp; - } - } - buf.resize(to_unsigned(num_digits)); - } else { - fp fp_value; - auto boundaries = specs.binary32 - ? fp_value.assign_float_with_boundaries(value) - : fp_value.assign_with_boundaries(value); - fp_value = normalize(fp_value); - // Find a cached power of 10 such that multiplying value by it will bring - // the exponent in the range [min_exp, -32]. - const fp cached_pow = get_cached_power( - min_exp - (fp_value.e + fp::significand_size), cached_exp10); - // Multiply value and boundaries by the cached power of 10. 
- fp_value = fp_value * cached_pow; - boundaries.lower = multiply(boundaries.lower, cached_pow.f); - boundaries.upper = multiply(boundaries.upper, cached_pow.f); - assert(min_exp <= fp_value.e && fp_value.e <= -32); - --boundaries.lower; // \tilde{M}^- - 1 ulp -> M^-_{\downarrow}. - ++boundaries.upper; // \tilde{M}^+ + 1 ulp -> M^+_{\uparrow}. - // Numbers outside of (lower, upper) definitely do not round to value. - grisu_shortest_handler handler{buf.data(), 0, - boundaries.upper - fp_value.f}; - auto result = - grisu_gen_digits(fp(boundaries.upper, fp_value.e), - boundaries.upper - boundaries.lower, exp, handler); - if (result == digits::error) { - exp += handler.size - cached_exp10 - 1; - fallback_format(value, buf, exp); - return exp; - } - buf.resize(to_unsigned(handler.size)); - } - return exp - cached_exp10; -} - -template -int snprintf_float(T value, int precision, float_specs specs, - buffer& buf) { - // Buffer capacity must be non-zero, otherwise MSVC's vsnprintf_s will fail. - FMT_ASSERT(buf.capacity() > buf.size(), "empty buffer"); - static_assert(!std::is_same(), ""); - - // Subtract 1 to account for the difference in precision since we use %e for - // both general and exponent format. - if (specs.format == float_format::general || - specs.format == float_format::exp) - precision = (precision >= 0 ? precision : 6) - 1; - - // Build the format string. - enum { max_format_size = 7 }; // Ths longest format is "%#.*Le". - char format[max_format_size]; - char* format_ptr = format; - *format_ptr++ = '%'; - if (specs.trailing_zeros) *format_ptr++ = '#'; - if (precision >= 0) { - *format_ptr++ = '.'; - *format_ptr++ = '*'; - } - if (std::is_same()) *format_ptr++ = 'L'; - *format_ptr++ = specs.format != float_format::hex - ? (specs.format == float_format::fixed ? 'f' : 'e') - : (specs.upper ? 'A' : 'a'); - *format_ptr = '\0'; - - // Format using snprintf. 
- auto offset = buf.size(); - for (;;) { - auto begin = buf.data() + offset; - auto capacity = buf.capacity() - offset; -#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION - if (precision > 100000) - throw std::runtime_error( - "fuzz mode - avoid large allocation inside snprintf"); -#endif - // Suppress the warning about a nonliteral format string. - auto snprintf_ptr = FMT_SNPRINTF; - int result = precision >= 0 - ? snprintf_ptr(begin, capacity, format, precision, value) - : snprintf_ptr(begin, capacity, format, value); - if (result < 0) { - buf.reserve(buf.capacity() + 1); // The buffer will grow exponentially. - continue; - } - unsigned size = to_unsigned(result); - // Size equal to capacity means that the last character was truncated. - if (size >= capacity) { - buf.reserve(size + offset + 1); // Add 1 for the terminating '\0'. - continue; - } - auto is_digit = [](char c) { return c >= '0' && c <= '9'; }; - if (specs.format == float_format::fixed) { - if (precision == 0) { - buf.resize(size); - return 0; - } - // Find and remove the decimal point. - auto end = begin + size, p = end; - do { - --p; - } while (is_digit(*p)); - int fraction_size = static_cast(end - p - 1); - std::memmove(p, p + 1, fraction_size); - buf.resize(size - 1); - return -fraction_size; - } - if (specs.format == float_format::hex) { - buf.resize(size + offset); - return 0; - } - // Find and parse the exponent. - auto end = begin + size, exp_pos = end; - do { - --exp_pos; - } while (*exp_pos != 'e'); - char sign = exp_pos[1]; - assert(sign == '+' || sign == '-'); - int exp = 0; - auto p = exp_pos + 2; // Skip 'e' and sign. - do { - assert(is_digit(*p)); - exp = exp * 10 + (*p++ - '0'); - } while (p != end); - if (sign == '-') exp = -exp; - int fraction_size = 0; - if (exp_pos != begin + 1) { - // Remove trailing zeros. - auto fraction_end = exp_pos - 1; - while (*fraction_end == '0') --fraction_end; - // Move the fractional part left to get rid of the decimal point. 
- fraction_size = static_cast(fraction_end - begin - 1); - std::memmove(begin + 1, begin + 2, fraction_size); - } - buf.resize(fraction_size + offset + 1); - return exp - fraction_size; - } -} -} // namespace internal - -template <> struct formatter { - format_parse_context::iterator parse(format_parse_context& ctx) { - return ctx.begin(); - } - - format_context::iterator format(const internal::bigint& n, - format_context& ctx) { - auto out = ctx.out(); - bool first = true; - for (auto i = n.bigits_.size(); i > 0; --i) { - auto value = n.bigits_[i - 1]; - if (first) { - out = format_to(out, "{:x}", value); - first = false; - continue; - } - out = format_to(out, "{:08x}", value); - } - if (n.exp_ > 0) - out = format_to(out, "p{}", n.exp_ * internal::bigint::bigit_bits); - return out; - } -}; - -FMT_FUNC void internal::error_handler::on_error(std::string message) { - FMT_THROW(duckdb::Exception(message)); -} - -FMT_END_NAMESPACE - -#ifdef _MSC_VER -# pragma warning(pop) -#endif - -#endif // FMT_FORMAT_INL_H_ - - -// LICENSE_CHANGE_END - - -FMT_BEGIN_NAMESPACE -namespace internal { - -template -int format_float(char* buf, std::size_t size, const char* format, int precision, - T value) { -#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION - if (precision > 100000) - throw std::runtime_error( - "fuzz mode - avoid large allocation inside snprintf"); -#endif - // Suppress the warning about nonliteral format string. - auto snprintf_ptr = FMT_SNPRINTF; - return precision < 0 ? snprintf_ptr(buf, size, format, value) - : snprintf_ptr(buf, size, format, precision, value); -} -struct sprintf_specs { - int precision; - char type; - bool alt : 1; - - template - constexpr sprintf_specs(basic_format_specs specs) - : precision(specs.precision), type(specs.type), alt(specs.alt) {} - - constexpr bool has_precision() const { return precision >= 0; } -}; - -// This is deprecated and is kept only to preserve ABI compatibility. 
-template -char* sprintf_format(Double value, internal::buffer& buf, - sprintf_specs specs) { - // Buffer capacity must be non-zero, otherwise MSVC's vsnprintf_s will fail. - FMT_ASSERT(buf.capacity() != 0, "empty buffer"); - - // Build format string. - enum { max_format_size = 10 }; // longest format: %#-*.*Lg - char format[max_format_size]; - char* format_ptr = format; - *format_ptr++ = '%'; - if (specs.alt || !specs.type) *format_ptr++ = '#'; - if (specs.precision >= 0) { - *format_ptr++ = '.'; - *format_ptr++ = '*'; - } - if (std::is_same::value) *format_ptr++ = 'L'; - - char type = specs.type; - - if (type == '%') - type = 'f'; - else if (type == 0 || type == 'n') - type = 'g'; -#if FMT_MSC_VER - if (type == 'F') { - // MSVC's printf doesn't support 'F'. - type = 'f'; - } -#endif - *format_ptr++ = type; - *format_ptr = '\0'; - - // Format using snprintf. - char* start = nullptr; - char* decimal_point_pos = nullptr; - for (;;) { - std::size_t buffer_size = buf.capacity(); - start = &buf[0]; - int result = - format_float(start, buffer_size, format, specs.precision, value); - if (result >= 0) { - unsigned n = internal::to_unsigned(result); - if (n < buf.capacity()) { - // Find the decimal point. - auto p = buf.data(), end = p + n; - if (*p == '+' || *p == '-') ++p; - if (specs.type != 'a' && specs.type != 'A') { - while (p < end && *p >= '0' && *p <= '9') ++p; - if (p < end && *p != 'e' && *p != 'E') { - decimal_point_pos = p; - if (!specs.type) { - // Keep only one trailing zero after the decimal point. - ++p; - if (*p == '0') ++p; - while (p != end && *p >= '1' && *p <= '9') ++p; - char* where = p; - while (p != end && *p == '0') ++p; - if (p == end || *p < '0' || *p > '9') { - if (p != end) std::memmove(where, p, to_unsigned(end - p)); - n -= static_cast(p - where); - } - } - } - } - buf.resize(n); - break; // The buffer is large enough - continue with formatting. 
- } - buf.reserve(n + 1); - } else { - // If result is negative we ask to increase the capacity by at least 1, - // but as std::vector, the buffer grows exponentially. - buf.reserve(buf.capacity() + 1); - } - } - return decimal_point_pos; -} -} // namespace internal - -template FMT_API char* internal::sprintf_format(double, internal::buffer&, - sprintf_specs); -template FMT_API char* internal::sprintf_format(long double, - internal::buffer&, - sprintf_specs); - -template struct FMT_API internal::basic_data; - -// Workaround a bug in MSVC2013 that prevents instantiation of format_float. -int (*instantiate_format_float)(double, int, internal::float_specs, - internal::buffer&) = - internal::format_float; - -// Explicit instantiations for char. - -template FMT_API std::string internal::grouping_impl(locale_ref); -template FMT_API char internal::thousands_sep_impl(locale_ref); -template FMT_API char internal::decimal_point_impl(locale_ref); - -template FMT_API void internal::buffer::append(const char*, const char*); - -template FMT_API void internal::arg_map::init( - const basic_format_args& args); - -template FMT_API std::string internal::vformat( - string_view, basic_format_args); - -template FMT_API format_context::iterator internal::vformat_to( - internal::buffer&, string_view, basic_format_args); - -template FMT_API int internal::snprintf_float(double, int, - internal::float_specs, - internal::buffer&); -template FMT_API int internal::snprintf_float(long double, int, - internal::float_specs, - internal::buffer&); -template FMT_API int internal::format_float(double, int, internal::float_specs, - internal::buffer&); -template FMT_API int internal::format_float(long double, int, - internal::float_specs, - internal::buffer&); - -// Explicit instantiations for wchar_t. 
- -template FMT_API std::string internal::grouping_impl(locale_ref); -template FMT_API wchar_t internal::thousands_sep_impl(locale_ref); -template FMT_API wchar_t internal::decimal_point_impl(locale_ref); - -template FMT_API void internal::buffer::append(const wchar_t*, - const wchar_t*); - -template FMT_API std::wstring internal::vformat( - wstring_view, basic_format_args); -FMT_END_NAMESPACE - - -// LICENSE_CHANGE_END - -#endif diff --git a/lib/duckdb-fsst.cpp b/lib/duckdb-fsst.cpp deleted file mode 100644 index 69808c4d..00000000 --- a/lib/duckdb-fsst.cpp +++ /dev/null @@ -1,1877 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #6 -// See the end of this file for a list - -// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): -// -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #6 -// See the end of this file for a list - -// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): -// -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-// -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace std; - - // the official FSST API -- also usable by C mortals - -/* unsigned integers */ -typedef uint8_t u8; -typedef uint16_t u16; -typedef uint32_t u32; -typedef uint64_t u64; - -inline uint64_t fsst_unaligned_load(u8 const* V) { - uint64_t Ret; - memcpy(&Ret, V, sizeof(uint64_t)); // compiler will generate efficient code (unaligned load, where possible) - return Ret; -} - -#define FSST_ENDIAN_MARKER ((u64) 1) -#define FSST_VERSION_20190218 20190218 -#define FSST_VERSION ((u64) FSST_VERSION_20190218) - -// "symbols" are character sequences (up to 8 bytes) -// A symbol is compressed into a "code" of, in principle, one byte. But, we added an exception mechanism: -// byte 255 followed by byte X represents the single-byte symbol X. Its code is 256+X. - -// we represent codes in u16 (not u8). 
12 bits code (of which 10 are used), 4 bits length -#define FSST_LEN_BITS 12 -#define FSST_CODE_BITS 9 -#define FSST_CODE_BASE 256UL /* first 256 codes [0,255] are pseudo codes: escaped bytes */ -#define FSST_CODE_MAX (1UL<=8) { - len = 8; - memcpy(val.str, input, 8); - } else { - memcpy(val.str, input, len); - } - set_code_len(FSST_CODE_MAX, len); - } - void set_code_len(u32 code, u32 len) { icl = (len<<28)|(code<<16)|((8-len)*8); } - - u32 length() const { return (u32) (icl >> 28); } - u16 code() const { return (icl >> 16) & FSST_CODE_MASK; } - u32 ignoredBits() const { return (u32) icl; } - - u8 first() const { assert( length() >= 1); return 0xFF & val.num; } - u16 first2() const { assert( length() >= 2); return 0xFFFF & val.num; } - -#define FSST_HASH_LOG2SIZE 10 -#define FSST_HASH_PRIME 2971215073LL -#define FSST_SHIFT 15 -#define FSST_HASH(w) (((w)*FSST_HASH_PRIME)^(((w)*FSST_HASH_PRIME)>>FSST_SHIFT)) - size_t hash() const { size_t v = 0xFFFFFF & val.num; return FSST_HASH(v); } // hash on the next 3 bytes -}; - -// Symbol that can be put in a queue, ordered on gain -struct QSymbol{ - Symbol symbol; - mutable u32 gain; // mutable because gain value should be ignored in find() on unordered_set of QSymbols - bool operator==(const QSymbol& other) const { return symbol.val.num == other.symbol.val.num && symbol.length() == other.symbol.length(); } -}; - -// we construct FSST symbol tables using a random sample of about 16KB (1<<14) -#define FSST_SAMPLETARGET (1<<14) -#define FSST_SAMPLEMAXSZ ((long) 2*FSST_SAMPLETARGET) - -// two phases of compression, before and after optimize(): -// -// (1) to encode values we probe (and maintain) three datastructures: -// - u16 byteCodes[65536] array at the position of the next byte (s.length==1) -// - u16 shortCodes[65536] array at the position of the next twobyte pattern (s.length==2) -// - Symbol hashtable[1024] (keyed by the next three bytes, ie for s.length>2), -// this search will yield a u16 code, it points into Symbol 
symbols[]. You always find a hit, because the first 256 codes are -// pseudo codes representing a single byte these will become escapes) -// -// (2) when we finished looking for the best symbol table we call optimize() to reshape it: -// - it renumbers the codes by length (first symbols of length 2,3,4,5,6,7,8; then 1 (starting from byteLim are symbols of length 1) -// length 2 codes for which no longer suffix symbol exists (< suffixLim) come first among the 2-byte codes -// (allows shortcut during compression) -// - for each two-byte combination, in all unused slots of shortCodes[], it enters the byteCode[] of the symbol corresponding -// to the first byte (if such a single-byte symbol exists). This allows us to just probe the next two bytes (if there is only one -// byte left in the string, there is still a terminator-byte added during compression) in shortCodes[]. That is, byteCodes[] -// and its codepath is no longer required. This makes compression faster. The reason we use byteCodes[] during symbolTable construction -// is that adding a new code/symbol is expensive (you have to touch shortCodes[] in 256 places). This optimization was -// hence added to make symbolTable construction faster. 
-// -// this final layout allows for the fastest compression code, only currently present in compressBulk - -// in the hash table, the icl field contains (low-to-high) ignoredBits:16,code:12,length:4 -#define FSST_ICL_FREE ((15<<28)|(((u32)FSST_CODE_MASK)<<16)) // high bits of icl (len=8,code=FSST_CODE_MASK) indicates free bucket - -// ignoredBits is (8-length)*8, which is the amount of high bits to zero in the input word before comparing with the hashtable key -// ..it could of course be computed from len during lookup, but storing it precomputed in some loose bits is faster -// -// the gain field is only used in the symbol queue that sorts symbols on gain - -struct SymbolTable { - static const u32 hashTabSize = 1<> (u8) s.icl); - return true; - } - bool add(Symbol s) { - assert(FSST_CODE_BASE + nSymbols < FSST_CODE_MAX); - u32 len = s.length(); - s.set_code_len(FSST_CODE_BASE + nSymbols, len); - if (len == 1) { - byteCodes[s.first()] = FSST_CODE_BASE + nSymbols + (1<> ((u8) hashTab[idx].icl)))) { - return (hashTab[idx].icl>>16) & FSST_CODE_MASK; // matched a long symbol - } - if (s.length() >= 2) { - u16 code = shortCodes[s.first2()] & FSST_CODE_MASK; - if (code >= FSST_CODE_BASE) return code; - } - return byteCodes[s.first()] & FSST_CODE_MASK; - } - u16 findLongestSymbol(u8* cur, u8* end) const { - return findLongestSymbol(Symbol(cur,end)); // represent the string as a temporary symbol - } - - // rationale for finalize: - // - during symbol table construction, we may create more than 256 codes, but bring it down to max 255 in the last makeTable() - // consequently we needed more than 8 bits during symbol table contruction, but can simplify the codes to single bytes in finalize() - // (this feature is in fact lo longer used, but could still be exploited: symbol construction creates no more than 255 symbols in each pass) - // - we not only reduce the amount of codes to <255, but also *reorder* the symbols and renumber their codes, for higher compression perf. 
- // we renumber codes so they are grouped by length, to allow optimized scalar string compression (byteLim and suffixLim optimizations). - // - we make the use of byteCode[] no longer necessary by inserting single-byte codes in the free spots of shortCodes[] - // Using shortCodes[] only makes compression faster. When creating the symbolTable, however, using shortCodes[] for the single-byte - // symbols is slow, as each insert touches 256 positions in it. This optimization was added when optimizing symbolTable construction time. - // - // In all, we change the layout and coding, as follows.. - // - // before finalize(): - // - The real symbols are symbols[256..256+nSymbols>. As we may have nSymbols > 255 - // - The first 256 codes are pseudo symbols (all escaped bytes) - // - // after finalize(): - // - table layout is symbols[0..nSymbols>, with nSymbols < 256. - // - Real codes are [0,nSymbols>. 8-th bit not set. - // - Escapes in shortCodes have the 8th bit set (value: 256+255=511). 255 because the code to be emitted is the escape byte 255 - // - symbols are grouped by length: 2,3,4,5,6,7,8, then 1 (single-byte codes last) - // the two-byte codes are split in two sections: - // - first section contains codes for symbols for which there is no longer symbol (no suffix). It allows an early-out during compression - // - // finally, shortCodes[] is modified to also encode all single-byte symbols (hence byteCodes[] is not required on a critical path anymore). 
- // - void finalize(u8 zeroTerminated) { - assert(nSymbols <= 255); - u8 newCode[256], rsum[8], byteLim = nSymbols - (lenHisto[0] - zeroTerminated); - - // compute running sum of code lengths (starting offsets for each length) - rsum[0] = byteLim; // 1-byte codes are highest - rsum[1] = zeroTerminated; - for(u32 i=1; i<7; i++) - rsum[i+1] = rsum[i] + lenHisto[i]; - - // determine the new code for each symbol, ordered by length (and splitting 2byte symbols into two classes around suffixLim) - suffixLim = rsum[1]; - symbols[newCode[0] = 0] = symbols[256]; // keep symbol 0 in place (for zeroTerminated cases only) - - for(u32 i=zeroTerminated, j=rsum[2]; i 1 && first2 == s2.first2()) // test if symbol k is a suffix of s - opt = 0; - } - newCode[i] = opt?suffixLim++:--j; // symbols without a larger suffix have a code < suffixLim - } else - newCode[i] = rsum[len-1]++; - s1.set_code_len(newCode[i],len); - symbols[newCode[i]] = s1; - } - // renumber the codes in byteCodes[] - for(u32 i=0; i<256; i++) - if ((byteCodes[i] & FSST_CODE_MASK) >= FSST_CODE_BASE) - byteCodes[i] = newCode[(u8) byteCodes[i]] + (1 << FSST_LEN_BITS); - else - byteCodes[i] = 511 + (1 << FSST_LEN_BITS); - - // renumber the codes in shortCodes[] - for(u32 i=0; i<65536; i++) - if ((shortCodes[i] & FSST_CODE_MASK) >= FSST_CODE_BASE) - shortCodes[i] = newCode[(u8) shortCodes[i]] + (shortCodes[i] & (15 << FSST_LEN_BITS)); - else - shortCodes[i] = byteCodes[i&0xFF]; - - // replace the symbols in the hash table - for(u32 i=0; i>8; - } - void count1Inc(u32 pos1) { - if (!count1Low[pos1]++) // increment high early (when low==0, not when low==255). This means (high > 0) <=> (cnt > 0) - count1High[pos1]++; //(0,0)->(1,1)->..->(255,1)->(0,1)->(1,2)->(2,2)->(3,2)..(255,2)->(0,2)->(1,3)->(2,3)... - } - void count2Inc(u32 pos1, u32 pos2) { - if (!count2Low[pos1][pos2]++) // increment high early (when low==0, not when low==255). 
This means (high > 0) <=> (cnt > 0) - // inc 4-bits high counter with 1<<0 (1) or 1<<4 (16) -- depending on whether pos2 is even or odd, repectively - count2High[pos1][(pos2)>>1] += 1 << (((pos2)&1)<<2); // we take our chances with overflow.. (4K maxval, on a 8K sample) - } - u32 count1GetNext(u32 &pos1) { // note: we will advance pos1 to the next nonzero counter in register range - // read 16-bits single symbol counter, split into two 8-bits numbers (count1Low, count1High), while skipping over zeros - u64 high = fsst_unaligned_load(&count1High[pos1]); - - u32 zero = high?(__builtin_ctzl(high)>>3):7UL; // number of zero bytes - high = (high >> (zero << 3)) & 255; // advance to nonzero counter - if (((pos1 += zero) >= FSST_CODE_MAX) || !high) // SKIP! advance pos2 - return 0; // all zero - - u32 low = count1Low[pos1]; - if (low) high--; // high is incremented early and low late, so decrement high (unless low==0) - return (u32) ((high << 8) + low); - } - u32 count2GetNext(u32 pos1, u32 &pos2) { // note: we will advance pos2 to the next nonzero counter in register range - // read 12-bits pairwise symbol counter, split into low 8-bits and high 4-bits number while skipping over zeros - u64 high = fsst_unaligned_load(&count2High[pos1][pos2>>1]); - high >>= ((pos2&1) << 2); // odd pos2: ignore the lowest 4 bits & we see only 15 counters - - u32 zero = high?(__builtin_ctzl(high)>>2):(15UL-(pos2&1UL)); // number of zero 4-bits counters - high = (high >> (zero << 2)) & 15; // advance to nonzero counter - if (((pos2 += zero) >= FSST_CODE_MAX) || !high) // SKIP! 
advance pos2 - return 0UL; // all zero - - u32 low = count2Low[pos1][pos2]; - if (low) high--; // high is incremented early and low late, so decrement high (unless low==0) - return (u32) ((high << 8) + low); - } - void backup1(u8 *buf) { - memcpy(buf, count1High, FSST_CODE_MAX); - memcpy(buf+FSST_CODE_MAX, count1Low, FSST_CODE_MAX); - } - void restore1(u8 *buf) { - memcpy(count1High, buf, FSST_CODE_MAX); - memcpy(count1Low, buf+FSST_CODE_MAX, FSST_CODE_MAX); - } -}; -#endif - - -#define FSST_BUFSZ (3<<19) // 768KB - -// an encoder is a symbolmap plus some bufferspace, needed during map construction as well as compression -struct Encoder { - shared_ptr symbolTable; // symbols, plus metadata and data structures for quick compression (shortCode,hashTab, etc) - union { - Counters counters; // for counting symbol occurences during map construction - u8 simdbuf[FSST_BUFSZ]; // for compression: SIMD string staging area 768KB = 256KB in + 512KB out (worst case for 256KB in) - }; -}; - -// job control integer representable in one 64bits SIMD lane: cur/end=input, out=output, pos=which string (2^9=512 per call) -struct SIMDjob { - u64 out:19,pos:9,end:18,cur:18; // cur/end is input offsets (2^18=256KB), out is output offset (2^19=512KB) -}; - -extern bool -duckdb_fsst_hasAVX512(); // runtime check for avx512 capability - -extern size_t -duckdb_fsst_compressAVX512( - SymbolTable &symbolTable, - u8* codeBase, // IN: base address for codes, i.e. compression output (points to simdbuf+256KB) - u8* symbolBase, // IN: base address for string bytes, i.e. compression input (points to simdbuf) - SIMDjob* input, // IN: input array (size n) with job information: what to encode, where to store it. - SIMDjob* output, // OUT: output array (size n) with job information: how much got encoded, end output pointer. 
- size_t n, // IN: size of arrays input and output (should be max 512) - size_t unroll); // IN: degree of SIMD unrolling - -// C++ fsst-compress function with some more control of how the compression happens (algorithm flavor, simd unroll degree) -size_t compressImpl(Encoder *encoder, size_t n, size_t lenIn[], u8 *strIn[], size_t size, u8 * output, size_t *lenOut, u8 *strOut[], bool noSuffixOpt, bool avoidBranch, int simd); -size_t compressAuto(Encoder *encoder, size_t n, size_t lenIn[], u8 *strIn[], size_t size, u8 * output, size_t *lenOut, u8 *strOut[], int simd); - - -// LICENSE_CHANGE_END - - -#if DUCKDB_FSST_ENABLE_INTRINSINCS && (defined(__x86_64__) || defined(_M_X64)) -#include - -#ifdef _WIN32 -bool duckdb_fsst_hasAVX512() { - int info[4]; - __cpuidex(info, 0x00000007, 0); - return (info[1]>>16)&1; -} -#else -#include -bool duckdb_fsst_hasAVX512() { - int info[4]; - __cpuid_count(0x00000007, 0, info[0], info[1], info[2], info[3]); - return (info[1]>>16)&1; -} -#endif -#else -bool duckdb_fsst_hasAVX512() { return false; } -#endif - -// BULK COMPRESSION OF STRINGS -// -// In one call of this function, we can compress 512 strings, each of maximum length 511 bytes. -// strings can be shorter than 511 bytes, no problem, but if they are longer we need to cut them up. -// -// In each iteration of the while loop, we find one code in each of the unroll*8 strings, i.e. (8,16,24 or 32) for resp. unroll=1,2,3,4 -// unroll3 performs best on my hardware -// -// In the worst case, each final encoded string occupies 512KB bytes (512*1024; with 1024=512xexception, exception = 2 bytes). -// - hence codeBase is a buffer of 512KB (needs 19 bits jobs), symbolBase of 256KB (needs 18 bits jobs). -// -// 'jobX' controls the encoding of each string and is therefore a u64 with format [out:19][pos:9][end:18][cur:18] (low-to-high bits) -// The field 'pos' tells which string we are processing (0..511). We need this info as strings will complete compressing out-of-order. 
-// -// Strings will have different lengths, and when a string is finished, we reload from the buffer of 512 input strings. -// This continues until we have less than (8,16,24 or 32; depending on unroll) strings left to process. -// - so 'processed' is the amount of strings we started processing and it is between [480,512]. -// Note that when we quit, there will still be some (<32) strings that we started to process but which are unfinished. -// - so 'unfinished' is that amount. These unfinished strings will be encoded further using the scalar method. -// -// Apart from the coded strings, we return in a output[] array of size 'processed' the job values of the 'finished' strings. -// In the following 'unfinished' slots (processed=finished+unfinished) we output the 'job' values of the unfinished strings. -// -// For the finished strings, we need [out:19] to see the compressed size and [pos:9] to see which string we refer to. -// For the unfinished strings, we need all fields of 'job' to continue the compression with scalar code (see SIMD code in compressBatch). -// -// THIS IS A SEPARATE CODE FILE NOT BECAUSE OF MY LOVE FOR MODULARIZED CODE BUT BECAUSE IT ALLOWS TO COMPILE IT WITH DIFFERENT FLAGS -// in particular, unrolling is crucial for gather/scatter performance, but requires registers. the #define all_* expressions however, -// will be detected to be constants by g++ -O2 and will be precomputed and placed into AVX512 registers - spoiling 9 of them. -// This reduces the effectiveness of unrolling, hence -O2 makes the loop perform worse than -O1 which skips this optimization. -// Assembly inspection confirmed that 3-way unroll with -O1 avoids needless load/stores. 
- -size_t duckdb_fsst_compressAVX512(SymbolTable &symbolTable, u8* codeBase, u8* symbolBase, SIMDjob *input, SIMDjob *output, size_t n, size_t unroll) { - size_t processed = 0; - // define some constants (all_x means that all 8 lanes contain 64-bits value X) -#if defined(__AVX512F__) and DUCKDB_FSST_ENABLE_INTRINSINCS - //__m512i all_suffixLim= _mm512_broadcastq_epi64(_mm_set1_epi64((__m64) (u64) symbolTable->suffixLim)); -- for variants b,c - __m512i all_MASK = _mm512_broadcastq_epi64(_mm_set1_epi64((__m64) (u64) -1)); - __m512i all_PRIME = _mm512_broadcastq_epi64(_mm_set1_epi64((__m64) (u64) FSST_HASH_PRIME)); - __m512i all_ICL_FREE = _mm512_broadcastq_epi64(_mm_set1_epi64((__m64) (u64) FSST_ICL_FREE)); -#define all_HASH _mm512_srli_epi64(all_MASK, 64-FSST_HASH_LOG2SIZE) -#define all_ONE _mm512_srli_epi64(all_MASK, 63) -#define all_M19 _mm512_srli_epi64(all_MASK, 45) -#define all_M18 _mm512_srli_epi64(all_MASK, 46) -#define all_M28 _mm512_srli_epi64(all_MASK, 36) -#define all_FFFFFF _mm512_srli_epi64(all_MASK, 40) -#define all_FFFF _mm512_srli_epi64(all_MASK, 48) -#define all_FF _mm512_srli_epi64(all_MASK, 56) - - SIMDjob *inputEnd = input+n; - assert(n >= unroll*8 && n <= 512); // should be close to 512 - __m512i job1, job2, job3, job4; // will contain current jobs, for each unroll 1,2,3,4 - __mmask8 loadmask1 = 255, loadmask2 = 255*(unroll>1), loadmask3 = 255*(unroll>2), loadmask4 = 255*(unroll>3); // 2b loaded new strings bitmask per unroll - u32 delta1 = 8, delta2 = 8*(unroll>1), delta3 = 8*(unroll>2), delta4 = 8*(unroll>3); // #new loads this SIMD iteration per unroll - - if (unroll >= 4) { - while (input+delta1+delta2+delta3+delta4 < inputEnd) { - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #6 -// See the end of this file for a list - -// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): -// this software is distributed under the MIT License 
(http://www.opensource.org/licenses/MIT): -// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): -// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): -// -// -// -// -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// -// -// -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, -// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, -// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, -// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to 
whom the Software is -// furnished to do so, subject to the following conditions: -// furnished to do so, subject to the following conditions: -// furnished to do so, subject to the following conditions: -// furnished to do so, subject to the following conditions: -// -// -// -// -// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// -// -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, E1PRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, E2PRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, E3PRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, E4PRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// -// -// -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst -// -// -// -// - // load new jobs in the empty lanes (initially, all lanes are empty, so loadmask1=11111111, delta1=8). - // load new jobs in the empty lanes (initially, all lanes are empty, so loadmask2=11111111, delta2=8). - // load new jobs in the empty lanes (initially, all lanes are empty, so loadmask3=11111111, delta3=8). - // load new jobs in the empty lanes (initially, all lanes are empty, so loadmask4=11111111, delta4=8). 
- job1 = _mm512_mask_expandloadu_epi64(job1, loadmask1, input); input += delta1; - job2 = _mm512_mask_expandloadu_epi64(job2, loadmask2, input); input += delta2; - job3 = _mm512_mask_expandloadu_epi64(job3, loadmask3, input); input += delta3; - job4 = _mm512_mask_expandloadu_epi64(job4, loadmask4, input); input += delta4; - // load the next 8 input string bytes (uncompressed data, aka 'symbols'). - // load the next 8 input string bytes (uncompressed data, aka 'symbols'). - // load the next 8 input string bytes (uncompressed data, aka 'symbols'). - // load the next 8 input string bytes (uncompressed data, aka 'symbols'). - __m512i word1 = _mm512_i64gather_epi64(_mm512_srli_epi64(job1, 46), symbolBase, 1); - __m512i word2 = _mm512_i64gather_epi64(_mm512_srli_epi64(job2, 46), symbolBase, 1); - __m512i word3 = _mm512_i64gather_epi64(_mm512_srli_epi64(job3, 46), symbolBase, 1); - __m512i word4 = _mm512_i64gather_epi64(_mm512_srli_epi64(job4, 46), symbolBase, 1); - // load 16-bits codes from the 2-byte-prefix keyed lookup table. It also store 1-byte codes in all free slots. - // load 16-bits codes from the 2-byte-prefix keyed lookup table. It also store 1-byte codes in all free slots. - // load 16-bits codes from the 2-byte-prefix keyed lookup table. It also store 1-byte codes in all free slots. - // load 16-bits codes from the 2-byte-prefix keyed lookup table. It also store 1-byte codes in all free slots. - // code1: Lowest 8 bits contain the code. Eleventh bit is whether it is an escaped code. Next 4 bits is length (2 or 1). - // code2: Lowest 8 bits contain the code. Eleventh bit is whether it is an escaped code. Next 4 bits is length (2 or 1). - // code3: Lowest 8 bits contain the code. Eleventh bit is whether it is an escaped code. Next 4 bits is length (2 or 1). - // code4: Lowest 8 bits contain the code. Eleventh bit is whether it is an escaped code. Next 4 bits is length (2 or 1). 
- __m512i code1 = _mm512_i64gather_epi64(_mm512_and_epi64(word1, all_FFFF), symbolTable.shortCodes, sizeof(u16)); - __m512i code2 = _mm512_i64gather_epi64(_mm512_and_epi64(word2, all_FFFF), symbolTable.shortCodes, sizeof(u16)); - __m512i code3 = _mm512_i64gather_epi64(_mm512_and_epi64(word3, all_FFFF), symbolTable.shortCodes, sizeof(u16)); - __m512i code4 = _mm512_i64gather_epi64(_mm512_and_epi64(word4, all_FFFF), symbolTable.shortCodes, sizeof(u16)); - // get the first three bytes of the string. - // get the first three bytes of the string. - // get the first three bytes of the string. - // get the first three bytes of the string. - __m512i pos1 = _mm512_mullo_epi64(_mm512_and_epi64(word1, all_FFFFFF), all_PRIME); - __m512i pos2 = _mm512_mullo_epi64(_mm512_and_epi64(word2, all_FFFFFF), all_PRIME); - __m512i pos3 = _mm512_mullo_epi64(_mm512_and_epi64(word3, all_FFFFFF), all_PRIME); - __m512i pos4 = _mm512_mullo_epi64(_mm512_and_epi64(word4, all_FFFFFF), all_PRIME); - // hash them into a random number: pos1 = pos1*PRIME; pos1 ^= pos1>>SHIFT - // hash them into a random number: pos2 = pos2*PRIME; pos2 ^= pos2>>SHIFT - // hash them into a random number: pos3 = pos3*PRIME; pos3 ^= pos3>>SHIFT - // hash them into a random number: pos4 = pos4*PRIME; pos4 ^= pos4>>SHIFT - pos1 = _mm512_slli_epi64(_mm512_and_epi64(_mm512_xor_epi64(pos1,_mm512_srli_epi64(pos1,FSST_SHIFT)), all_HASH), 4); - pos2 = _mm512_slli_epi64(_mm512_and_epi64(_mm512_xor_epi64(pos2,_mm512_srli_epi64(pos2,FSST_SHIFT)), all_HASH), 4); - pos3 = _mm512_slli_epi64(_mm512_and_epi64(_mm512_xor_epi64(pos3,_mm512_srli_epi64(pos3,FSST_SHIFT)), all_HASH), 4); - pos4 = _mm512_slli_epi64(_mm512_and_epi64(_mm512_xor_epi64(pos4,_mm512_srli_epi64(pos4,FSST_SHIFT)), all_HASH), 4); - // lookup in the 3-byte-prefix keyed hash table - // lookup in the 3-byte-prefix keyed hash table - // lookup in the 3-byte-prefix keyed hash table - // lookup in the 3-byte-prefix keyed hash table - __m512i icl1 = 
_mm512_i64gather_epi64(pos1, (((char*) symbolTable.hashTab) + 8), 1); - __m512i icl2 = _mm512_i64gather_epi64(pos2, (((char*) symbolTable.hashTab) + 8), 1); - __m512i icl3 = _mm512_i64gather_epi64(pos3, (((char*) symbolTable.hashTab) + 8), 1); - __m512i icl4 = _mm512_i64gather_epi64(pos4, (((char*) symbolTable.hashTab) + 8), 1); - // speculatively store the first input byte into the second position of the write1 register (in case it turns out to be an escaped byte). - // speculatively store the first input byte into the second position of the write2 register (in case it turns out to be an escaped byte). - // speculatively store the first input byte into the second position of the write3 register (in case it turns out to be an escaped byte). - // speculatively store the first input byte into the second position of the write4 register (in case it turns out to be an escaped byte). - __m512i write1 = _mm512_slli_epi64(_mm512_and_epi64(word1, all_FF), 8); - __m512i write2 = _mm512_slli_epi64(_mm512_and_epi64(word2, all_FF), 8); - __m512i write3 = _mm512_slli_epi64(_mm512_and_epi64(word3, all_FF), 8); - __m512i write4 = _mm512_slli_epi64(_mm512_and_epi64(word4, all_FF), 8); - // lookup just like the icl1 above, but loads the next 8 bytes. This fetches the actual string bytes in the hash table. - // lookup just like the icl2 above, but loads the next 8 bytes. This fetches the actual string bytes in the hash table. - // lookup just like the icl3 above, but loads the next 8 bytes. This fetches the actual string bytes in the hash table. - // lookup just like the icl4 above, but loads the next 8 bytes. This fetches the actual string bytes in the hash table. 
- __m512i symb1 = _mm512_i64gather_epi64(pos1, (((char*) symbolTable.hashTab) + 0), 1); - __m512i symb2 = _mm512_i64gather_epi64(pos2, (((char*) symbolTable.hashTab) + 0), 1); - __m512i symb3 = _mm512_i64gather_epi64(pos3, (((char*) symbolTable.hashTab) + 0), 1); - __m512i symb4 = _mm512_i64gather_epi64(pos4, (((char*) symbolTable.hashTab) + 0), 1); - // generate the FF..FF mask with an FF for each byte of the symbol (we need to AND the input with this to correctly check equality). - // generate the FF..FF mask with an FF for each byte of the symbol (we need to AND the input with this to correctly check equality). - // generate the FF..FF mask with an FF for each byte of the symbol (we need to AND the input with this to correctly check equality). - // generate the FF..FF mask with an FF for each byte of the symbol (we need to AND the input with this to correctly check equality). - pos1 = _mm512_srlv_epi64(all_MASK, _mm512_and_epi64(icl1, all_FF)); - pos2 = _mm512_srlv_epi64(all_MASK, _mm512_and_epi64(icl2, all_FF)); - pos3 = _mm512_srlv_epi64(all_MASK, _mm512_and_epi64(icl3, all_FF)); - pos4 = _mm512_srlv_epi64(all_MASK, _mm512_and_epi64(icl4, all_FF)); - // check symbol < |str| as well as whether it is an occupied slot (cmplt checks both conditions at once) and check string equality (cmpeq). - // check symbol < |str| as well as whether it is an occupied slot (cmplt checks both conditions at once) and check string equality (cmpeq). - // check symbol < |str| as well as whether it is an occupied slot (cmplt checks both conditions at once) and check string equality (cmpeq). - // check symbol < |str| as well as whether it is an occupied slot (cmplt checks both conditions at once) and check string equality (cmpeq). 
- __mmask8 match1 = _mm512_cmpeq_epi64_mask(symb1, _mm512_and_epi64(word1, pos1)) & _mm512_cmplt_epi64_mask(icl1, all_ICL_FREE); - __mmask8 match2 = _mm512_cmpeq_epi64_mask(symb2, _mm512_and_epi64(word2, pos2)) & _mm512_cmplt_epi64_mask(icl2, all_ICL_FREE); - __mmask8 match3 = _mm512_cmpeq_epi64_mask(symb3, _mm512_and_epi64(word3, pos3)) & _mm512_cmplt_epi64_mask(icl3, all_ICL_FREE); - __mmask8 match4 = _mm512_cmpeq_epi64_mask(symb4, _mm512_and_epi64(word4, pos4)) & _mm512_cmplt_epi64_mask(icl4, all_ICL_FREE); - // for the hits, overwrite the codes with what comes from the hash table (codes for symbols of length >=3). The rest stays with what shortCodes gave. - // for the hits, overwrite the codes with what comes from the hash table (codes for symbols of length >=3). The rest stays with what shortCodes gave. - // for the hits, overwrite the codes with what comes from the hash table (codes for symbols of length >=3). The rest stays with what shortCodes gave. - // for the hits, overwrite the codes with what comes from the hash table (codes for symbols of length >=3). The rest stays with what shortCodes gave. - code1 = _mm512_mask_mov_epi64(code1, match1, _mm512_srli_epi64(icl1, 16)); - code2 = _mm512_mask_mov_epi64(code2, match2, _mm512_srli_epi64(icl2, 16)); - code3 = _mm512_mask_mov_epi64(code3, match3, _mm512_srli_epi64(icl3, 16)); - code4 = _mm512_mask_mov_epi64(code4, match4, _mm512_srli_epi64(icl4, 16)); - // write out the code byte as the first output byte. Notice that this byte may also be the escape code 255 (for escapes) coming from shortCodes. - // write out the code byte as the first output byte. Notice that this byte may also be the escape code 255 (for escapes) coming from shortCodes. - // write out the code byte as the first output byte. Notice that this byte may also be the escape code 255 (for escapes) coming from shortCodes. - // write out the code byte as the first output byte. 
Notice that this byte may also be the escape code 255 (for escapes) coming from shortCodes. - write1 = _mm512_or_epi64(write1, _mm512_and_epi64(code1, all_FF)); - write2 = _mm512_or_epi64(write2, _mm512_and_epi64(code2, all_FF)); - write3 = _mm512_or_epi64(write3, _mm512_and_epi64(code3, all_FF)); - write4 = _mm512_or_epi64(write4, _mm512_and_epi64(code4, all_FF)); - // zip the irrelevant 6 bytes (just stay with the 2 relevant bytes containing the 16-bits code) - // zip the irrelevant 6 bytes (just stay with the 2 relevant bytes containing the 16-bits code) - // zip the irrelevant 6 bytes (just stay with the 2 relevant bytes containing the 16-bits code) - // zip the irrelevant 6 bytes (just stay with the 2 relevant bytes containing the 16-bits code) - code1 = _mm512_and_epi64(code1, all_FFFF); - code2 = _mm512_and_epi64(code2, all_FFFF); - code3 = _mm512_and_epi64(code3, all_FFFF); - code4 = _mm512_and_epi64(code4, all_FFFF); - // write out the compressed data. It writes 8 bytes, but only 1 byte is relevant :-(or 2 bytes are, in case of an escape code) - // write out the compressed data. It writes 8 bytes, but only 1 byte is relevant :-(or 2 bytes are, in case of an escape code) - // write out the compressed data. It writes 8 bytes, but only 1 byte is relevant :-(or 2 bytes are, in case of an escape code) - // write out the compressed data. 
It writes 8 bytes, but only 1 byte is relevant :-(or 2 bytes are, in case of an escape code) - _mm512_i64scatter_epi64(codeBase, _mm512_and_epi64(job1, all_M19), write1, 1); - _mm512_i64scatter_epi64(codeBase, _mm512_and_epi64(job2, all_M19), write2, 1); - _mm512_i64scatter_epi64(codeBase, _mm512_and_epi64(job3, all_M19), write3, 1); - _mm512_i64scatter_epi64(codeBase, _mm512_and_epi64(job4, all_M19), write4, 1); - // increase the job1.cur field in the job with the symbol length (for this, shift away 12 bits from the code) - // increase the job2.cur field in the job with the symbol length (for this, shift away 12 bits from the code) - // increase the job3.cur field in the job with the symbol length (for this, shift away 12 bits from the code) - // increase the job4.cur field in the job with the symbol length (for this, shift away 12 bits from the code) - job1 = _mm512_add_epi64(job1, _mm512_slli_epi64(_mm512_srli_epi64(code1, FSST_LEN_BITS), 46)); - job2 = _mm512_add_epi64(job2, _mm512_slli_epi64(_mm512_srli_epi64(code2, FSST_LEN_BITS), 46)); - job3 = _mm512_add_epi64(job3, _mm512_slli_epi64(_mm512_srli_epi64(code3, FSST_LEN_BITS), 46)); - job4 = _mm512_add_epi64(job4, _mm512_slli_epi64(_mm512_srli_epi64(code4, FSST_LEN_BITS), 46)); - // increase the job1.out' field with one, or two in case of an escape code (add 1 plus the escape bit, i.e the 8th) - // increase the job2.out' field with one, or two in case of an escape code (add 1 plus the escape bit, i.e the 8th) - // increase the job3.out' field with one, or two in case of an escape code (add 1 plus the escape bit, i.e the 8th) - // increase the job4.out' field with one, or two in case of an escape code (add 1 plus the escape bit, i.e the 8th) - job1 = _mm512_add_epi64(job1, _mm512_add_epi64(all_ONE, _mm512_and_epi64(_mm512_srli_epi64(code1, 8), all_ONE))); - job2 = _mm512_add_epi64(job2, _mm512_add_epi64(all_ONE, _mm512_and_epi64(_mm512_srli_epi64(code2, 8), all_ONE))); - job3 = _mm512_add_epi64(job3, 
_mm512_add_epi64(all_ONE, _mm512_and_epi64(_mm512_srli_epi64(code3, 8), all_ONE))); - job4 = _mm512_add_epi64(job4, _mm512_add_epi64(all_ONE, _mm512_and_epi64(_mm512_srli_epi64(code4, 8), all_ONE))); - // test which lanes are done now (job1.cur==job1.end), cur starts at bit 46, end starts at bit 28 (the highest 2x18 bits in the job1 register) - // test which lanes are done now (job2.cur==job2.end), cur starts at bit 46, end starts at bit 28 (the highest 2x18 bits in the job2 register) - // test which lanes are done now (job3.cur==job3.end), cur starts at bit 46, end starts at bit 28 (the highest 2x18 bits in the job3 register) - // test which lanes are done now (job4.cur==job4.end), cur starts at bit 46, end starts at bit 28 (the highest 2x18 bits in the job4 register) - loadmask1 = _mm512_cmpeq_epi64_mask(_mm512_srli_epi64(job1, 46), _mm512_and_epi64(_mm512_srli_epi64(job1, 28), all_M18)); - loadmask2 = _mm512_cmpeq_epi64_mask(_mm512_srli_epi64(job2, 46), _mm512_and_epi64(_mm512_srli_epi64(job2, 28), all_M18)); - loadmask3 = _mm512_cmpeq_epi64_mask(_mm512_srli_epi64(job3, 46), _mm512_and_epi64(_mm512_srli_epi64(job3, 28), all_M18)); - loadmask4 = _mm512_cmpeq_epi64_mask(_mm512_srli_epi64(job4, 46), _mm512_and_epi64(_mm512_srli_epi64(job4, 28), all_M18)); - // calculate the amount of lanes in job1 that are done - // calculate the amount of lanes in job2 that are done - // calculate the amount of lanes in job3 that are done - // calculate the amount of lanes in job4 that are done - delta1 = _mm_popcnt_u32((int) loadmask1); - delta2 = _mm_popcnt_u32((int) loadmask2); - delta3 = _mm_popcnt_u32((int) loadmask3); - delta4 = _mm_popcnt_u32((int) loadmask4); - // write out the job state for the lanes that are done (we need the final 'job1.out' value to compute the compressed string length) - // write out the job state for the lanes that are done (we need the final 'job2.out' value to compute the compressed string length) - // write out the job state for the lanes that are 
done (we need the final 'job3.out' value to compute the compressed string length) - // write out the job state for the lanes that are done (we need the final 'job4.out' value to compute the compressed string length) - _mm512_mask_compressstoreu_epi64(output, loadmask1, job1); output += delta1; - _mm512_mask_compressstoreu_epi64(output, loadmask2, job2); output += delta2; - _mm512_mask_compressstoreu_epi64(output, loadmask3, job3); output += delta3; - _mm512_mask_compressstoreu_epi64(output, loadmask4, job4); output += delta4; - - -// LICENSE_CHANGE_END - - } - } else if (unroll == 3) { - while (input+delta1+delta2+delta3 < inputEnd) { - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #6 -// See the end of this file for a list - -// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): -// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): -// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): -// -// -// -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// -// -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, -// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, -// (the "Software"), to deal in the Software without restriction, including 
without limitation the rights to use, copy, modify, -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// furnished to do so, subject to the following conditions: -// furnished to do so, subject to the following conditions: -// -// -// -// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, E1PRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, E2PRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, E3PRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// -// -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst -// -// -// - // load new jobs in the empty lanes (initially, all lanes are empty, so loadmask1=11111111, delta1=8). - // load new jobs in the empty lanes (initially, all lanes are empty, so loadmask2=11111111, delta2=8). - // load new jobs in the empty lanes (initially, all lanes are empty, so loadmask3=11111111, delta3=8). - job1 = _mm512_mask_expandloadu_epi64(job1, loadmask1, input); input += delta1; - job2 = _mm512_mask_expandloadu_epi64(job2, loadmask2, input); input += delta2; - job3 = _mm512_mask_expandloadu_epi64(job3, loadmask3, input); input += delta3; - // load the next 8 input string bytes (uncompressed data, aka 'symbols'). - // load the next 8 input string bytes (uncompressed data, aka 'symbols'). - // load the next 8 input string bytes (uncompressed data, aka 'symbols'). 
- __m512i word1 = _mm512_i64gather_epi64(_mm512_srli_epi64(job1, 46), symbolBase, 1); - __m512i word2 = _mm512_i64gather_epi64(_mm512_srli_epi64(job2, 46), symbolBase, 1); - __m512i word3 = _mm512_i64gather_epi64(_mm512_srli_epi64(job3, 46), symbolBase, 1); - // load 16-bits codes from the 2-byte-prefix keyed lookup table. It also store 1-byte codes in all free slots. - // load 16-bits codes from the 2-byte-prefix keyed lookup table. It also store 1-byte codes in all free slots. - // load 16-bits codes from the 2-byte-prefix keyed lookup table. It also store 1-byte codes in all free slots. - // code1: Lowest 8 bits contain the code. Eleventh bit is whether it is an escaped code. Next 4 bits is length (2 or 1). - // code2: Lowest 8 bits contain the code. Eleventh bit is whether it is an escaped code. Next 4 bits is length (2 or 1). - // code3: Lowest 8 bits contain the code. Eleventh bit is whether it is an escaped code. Next 4 bits is length (2 or 1). - __m512i code1 = _mm512_i64gather_epi64(_mm512_and_epi64(word1, all_FFFF), symbolTable.shortCodes, sizeof(u16)); - __m512i code2 = _mm512_i64gather_epi64(_mm512_and_epi64(word2, all_FFFF), symbolTable.shortCodes, sizeof(u16)); - __m512i code3 = _mm512_i64gather_epi64(_mm512_and_epi64(word3, all_FFFF), symbolTable.shortCodes, sizeof(u16)); - // get the first three bytes of the string. - // get the first three bytes of the string. - // get the first three bytes of the string. 
- __m512i pos1 = _mm512_mullo_epi64(_mm512_and_epi64(word1, all_FFFFFF), all_PRIME); - __m512i pos2 = _mm512_mullo_epi64(_mm512_and_epi64(word2, all_FFFFFF), all_PRIME); - __m512i pos3 = _mm512_mullo_epi64(_mm512_and_epi64(word3, all_FFFFFF), all_PRIME); - // hash them into a random number: pos1 = pos1*PRIME; pos1 ^= pos1>>SHIFT - // hash them into a random number: pos2 = pos2*PRIME; pos2 ^= pos2>>SHIFT - // hash them into a random number: pos3 = pos3*PRIME; pos3 ^= pos3>>SHIFT - pos1 = _mm512_slli_epi64(_mm512_and_epi64(_mm512_xor_epi64(pos1,_mm512_srli_epi64(pos1,FSST_SHIFT)), all_HASH), 4); - pos2 = _mm512_slli_epi64(_mm512_and_epi64(_mm512_xor_epi64(pos2,_mm512_srli_epi64(pos2,FSST_SHIFT)), all_HASH), 4); - pos3 = _mm512_slli_epi64(_mm512_and_epi64(_mm512_xor_epi64(pos3,_mm512_srli_epi64(pos3,FSST_SHIFT)), all_HASH), 4); - // lookup in the 3-byte-prefix keyed hash table - // lookup in the 3-byte-prefix keyed hash table - // lookup in the 3-byte-prefix keyed hash table - __m512i icl1 = _mm512_i64gather_epi64(pos1, (((char*) symbolTable.hashTab) + 8), 1); - __m512i icl2 = _mm512_i64gather_epi64(pos2, (((char*) symbolTable.hashTab) + 8), 1); - __m512i icl3 = _mm512_i64gather_epi64(pos3, (((char*) symbolTable.hashTab) + 8), 1); - // speculatively store the first input byte into the second position of the write1 register (in case it turns out to be an escaped byte). - // speculatively store the first input byte into the second position of the write2 register (in case it turns out to be an escaped byte). - // speculatively store the first input byte into the second position of the write3 register (in case it turns out to be an escaped byte). - __m512i write1 = _mm512_slli_epi64(_mm512_and_epi64(word1, all_FF), 8); - __m512i write2 = _mm512_slli_epi64(_mm512_and_epi64(word2, all_FF), 8); - __m512i write3 = _mm512_slli_epi64(_mm512_and_epi64(word3, all_FF), 8); - // lookup just like the icl1 above, but loads the next 8 bytes. 
This fetches the actual string bytes in the hash table. - // lookup just like the icl2 above, but loads the next 8 bytes. This fetches the actual string bytes in the hash table. - // lookup just like the icl3 above, but loads the next 8 bytes. This fetches the actual string bytes in the hash table. - __m512i symb1 = _mm512_i64gather_epi64(pos1, (((char*) symbolTable.hashTab) + 0), 1); - __m512i symb2 = _mm512_i64gather_epi64(pos2, (((char*) symbolTable.hashTab) + 0), 1); - __m512i symb3 = _mm512_i64gather_epi64(pos3, (((char*) symbolTable.hashTab) + 0), 1); - // generate the FF..FF mask with an FF for each byte of the symbol (we need to AND the input with this to correctly check equality). - // generate the FF..FF mask with an FF for each byte of the symbol (we need to AND the input with this to correctly check equality). - // generate the FF..FF mask with an FF for each byte of the symbol (we need to AND the input with this to correctly check equality). - pos1 = _mm512_srlv_epi64(all_MASK, _mm512_and_epi64(icl1, all_FF)); - pos2 = _mm512_srlv_epi64(all_MASK, _mm512_and_epi64(icl2, all_FF)); - pos3 = _mm512_srlv_epi64(all_MASK, _mm512_and_epi64(icl3, all_FF)); - // check symbol < |str| as well as whether it is an occupied slot (cmplt checks both conditions at once) and check string equality (cmpeq). - // check symbol < |str| as well as whether it is an occupied slot (cmplt checks both conditions at once) and check string equality (cmpeq). - // check symbol < |str| as well as whether it is an occupied slot (cmplt checks both conditions at once) and check string equality (cmpeq). 
- __mmask8 match1 = _mm512_cmpeq_epi64_mask(symb1, _mm512_and_epi64(word1, pos1)) & _mm512_cmplt_epi64_mask(icl1, all_ICL_FREE); - __mmask8 match2 = _mm512_cmpeq_epi64_mask(symb2, _mm512_and_epi64(word2, pos2)) & _mm512_cmplt_epi64_mask(icl2, all_ICL_FREE); - __mmask8 match3 = _mm512_cmpeq_epi64_mask(symb3, _mm512_and_epi64(word3, pos3)) & _mm512_cmplt_epi64_mask(icl3, all_ICL_FREE); - // for the hits, overwrite the codes with what comes from the hash table (codes for symbols of length >=3). The rest stays with what shortCodes gave. - // for the hits, overwrite the codes with what comes from the hash table (codes for symbols of length >=3). The rest stays with what shortCodes gave. - // for the hits, overwrite the codes with what comes from the hash table (codes for symbols of length >=3). The rest stays with what shortCodes gave. - code1 = _mm512_mask_mov_epi64(code1, match1, _mm512_srli_epi64(icl1, 16)); - code2 = _mm512_mask_mov_epi64(code2, match2, _mm512_srli_epi64(icl2, 16)); - code3 = _mm512_mask_mov_epi64(code3, match3, _mm512_srli_epi64(icl3, 16)); - // write out the code byte as the first output byte. Notice that this byte may also be the escape code 255 (for escapes) coming from shortCodes. - // write out the code byte as the first output byte. Notice that this byte may also be the escape code 255 (for escapes) coming from shortCodes. - // write out the code byte as the first output byte. Notice that this byte may also be the escape code 255 (for escapes) coming from shortCodes. 
- write1 = _mm512_or_epi64(write1, _mm512_and_epi64(code1, all_FF)); - write2 = _mm512_or_epi64(write2, _mm512_and_epi64(code2, all_FF)); - write3 = _mm512_or_epi64(write3, _mm512_and_epi64(code3, all_FF)); - // zip the irrelevant 6 bytes (just stay with the 2 relevant bytes containing the 16-bits code) - // zip the irrelevant 6 bytes (just stay with the 2 relevant bytes containing the 16-bits code) - // zip the irrelevant 6 bytes (just stay with the 2 relevant bytes containing the 16-bits code) - code1 = _mm512_and_epi64(code1, all_FFFF); - code2 = _mm512_and_epi64(code2, all_FFFF); - code3 = _mm512_and_epi64(code3, all_FFFF); - // write out the compressed data. It writes 8 bytes, but only 1 byte is relevant :-(or 2 bytes are, in case of an escape code) - // write out the compressed data. It writes 8 bytes, but only 1 byte is relevant :-(or 2 bytes are, in case of an escape code) - // write out the compressed data. It writes 8 bytes, but only 1 byte is relevant :-(or 2 bytes are, in case of an escape code) - _mm512_i64scatter_epi64(codeBase, _mm512_and_epi64(job1, all_M19), write1, 1); - _mm512_i64scatter_epi64(codeBase, _mm512_and_epi64(job2, all_M19), write2, 1); - _mm512_i64scatter_epi64(codeBase, _mm512_and_epi64(job3, all_M19), write3, 1); - // increase the job1.cur field in the job with the symbol length (for this, shift away 12 bits from the code) - // increase the job2.cur field in the job with the symbol length (for this, shift away 12 bits from the code) - // increase the job3.cur field in the job with the symbol length (for this, shift away 12 bits from the code) - job1 = _mm512_add_epi64(job1, _mm512_slli_epi64(_mm512_srli_epi64(code1, FSST_LEN_BITS), 46)); - job2 = _mm512_add_epi64(job2, _mm512_slli_epi64(_mm512_srli_epi64(code2, FSST_LEN_BITS), 46)); - job3 = _mm512_add_epi64(job3, _mm512_slli_epi64(_mm512_srli_epi64(code3, FSST_LEN_BITS), 46)); - // increase the job1.out' field with one, or two in case of an escape code (add 1 plus the escape bit, 
i.e the 8th) - // increase the job2.out' field with one, or two in case of an escape code (add 1 plus the escape bit, i.e the 8th) - // increase the job3.out' field with one, or two in case of an escape code (add 1 plus the escape bit, i.e the 8th) - job1 = _mm512_add_epi64(job1, _mm512_add_epi64(all_ONE, _mm512_and_epi64(_mm512_srli_epi64(code1, 8), all_ONE))); - job2 = _mm512_add_epi64(job2, _mm512_add_epi64(all_ONE, _mm512_and_epi64(_mm512_srli_epi64(code2, 8), all_ONE))); - job3 = _mm512_add_epi64(job3, _mm512_add_epi64(all_ONE, _mm512_and_epi64(_mm512_srli_epi64(code3, 8), all_ONE))); - // test which lanes are done now (job1.cur==job1.end), cur starts at bit 46, end starts at bit 28 (the highest 2x18 bits in the job1 register) - // test which lanes are done now (job2.cur==job2.end), cur starts at bit 46, end starts at bit 28 (the highest 2x18 bits in the job2 register) - // test which lanes are done now (job3.cur==job3.end), cur starts at bit 46, end starts at bit 28 (the highest 2x18 bits in the job3 register) - loadmask1 = _mm512_cmpeq_epi64_mask(_mm512_srli_epi64(job1, 46), _mm512_and_epi64(_mm512_srli_epi64(job1, 28), all_M18)); - loadmask2 = _mm512_cmpeq_epi64_mask(_mm512_srli_epi64(job2, 46), _mm512_and_epi64(_mm512_srli_epi64(job2, 28), all_M18)); - loadmask3 = _mm512_cmpeq_epi64_mask(_mm512_srli_epi64(job3, 46), _mm512_and_epi64(_mm512_srli_epi64(job3, 28), all_M18)); - // calculate the amount of lanes in job1 that are done - // calculate the amount of lanes in job2 that are done - // calculate the amount of lanes in job3 that are done - delta1 = _mm_popcnt_u32((int) loadmask1); - delta2 = _mm_popcnt_u32((int) loadmask2); - delta3 = _mm_popcnt_u32((int) loadmask3); - // write out the job state for the lanes that are done (we need the final 'job1.out' value to compute the compressed string length) - // write out the job state for the lanes that are done (we need the final 'job2.out' value to compute the compressed string length) - // write out the job 
state for the lanes that are done (we need the final 'job3.out' value to compute the compressed string length) - _mm512_mask_compressstoreu_epi64(output, loadmask1, job1); output += delta1; - _mm512_mask_compressstoreu_epi64(output, loadmask2, job2); output += delta2; - _mm512_mask_compressstoreu_epi64(output, loadmask3, job3); output += delta3; - - -// LICENSE_CHANGE_END - - } - } else if (unroll == 2) { - while (input+delta1+delta2 < inputEnd) { - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #6 -// See the end of this file for a list - -// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): -// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): -// -// -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, -// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// furnished to do so, subject to the following conditions: -// -// -// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
-// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, E1PRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, E2PRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst -// -// - // load new jobs in the empty lanes (initially, all lanes are empty, so loadmask1=11111111, delta1=8). - // load new jobs in the empty lanes (initially, all lanes are empty, so loadmask2=11111111, delta2=8). - job1 = _mm512_mask_expandloadu_epi64(job1, loadmask1, input); input += delta1; - job2 = _mm512_mask_expandloadu_epi64(job2, loadmask2, input); input += delta2; - // load the next 8 input string bytes (uncompressed data, aka 'symbols'). - // load the next 8 input string bytes (uncompressed data, aka 'symbols'). 
- __m512i word1 = _mm512_i64gather_epi64(_mm512_srli_epi64(job1, 46), symbolBase, 1); - __m512i word2 = _mm512_i64gather_epi64(_mm512_srli_epi64(job2, 46), symbolBase, 1); - // load 16-bits codes from the 2-byte-prefix keyed lookup table. It also store 1-byte codes in all free slots. - // load 16-bits codes from the 2-byte-prefix keyed lookup table. It also store 1-byte codes in all free slots. - // code1: Lowest 8 bits contain the code. Eleventh bit is whether it is an escaped code. Next 4 bits is length (2 or 1). - // code2: Lowest 8 bits contain the code. Eleventh bit is whether it is an escaped code. Next 4 bits is length (2 or 1). - __m512i code1 = _mm512_i64gather_epi64(_mm512_and_epi64(word1, all_FFFF), symbolTable.shortCodes, sizeof(u16)); - __m512i code2 = _mm512_i64gather_epi64(_mm512_and_epi64(word2, all_FFFF), symbolTable.shortCodes, sizeof(u16)); - // get the first three bytes of the string. - // get the first three bytes of the string. - __m512i pos1 = _mm512_mullo_epi64(_mm512_and_epi64(word1, all_FFFFFF), all_PRIME); - __m512i pos2 = _mm512_mullo_epi64(_mm512_and_epi64(word2, all_FFFFFF), all_PRIME); - // hash them into a random number: pos1 = pos1*PRIME; pos1 ^= pos1>>SHIFT - // hash them into a random number: pos2 = pos2*PRIME; pos2 ^= pos2>>SHIFT - pos1 = _mm512_slli_epi64(_mm512_and_epi64(_mm512_xor_epi64(pos1,_mm512_srli_epi64(pos1,FSST_SHIFT)), all_HASH), 4); - pos2 = _mm512_slli_epi64(_mm512_and_epi64(_mm512_xor_epi64(pos2,_mm512_srli_epi64(pos2,FSST_SHIFT)), all_HASH), 4); - // lookup in the 3-byte-prefix keyed hash table - // lookup in the 3-byte-prefix keyed hash table - __m512i icl1 = _mm512_i64gather_epi64(pos1, (((char*) symbolTable.hashTab) + 8), 1); - __m512i icl2 = _mm512_i64gather_epi64(pos2, (((char*) symbolTable.hashTab) + 8), 1); - // speculatively store the first input byte into the second position of the write1 register (in case it turns out to be an escaped byte). 
- // speculatively store the first input byte into the second position of the write2 register (in case it turns out to be an escaped byte). - __m512i write1 = _mm512_slli_epi64(_mm512_and_epi64(word1, all_FF), 8); - __m512i write2 = _mm512_slli_epi64(_mm512_and_epi64(word2, all_FF), 8); - // lookup just like the icl1 above, but loads the next 8 bytes. This fetches the actual string bytes in the hash table. - // lookup just like the icl2 above, but loads the next 8 bytes. This fetches the actual string bytes in the hash table. - __m512i symb1 = _mm512_i64gather_epi64(pos1, (((char*) symbolTable.hashTab) + 0), 1); - __m512i symb2 = _mm512_i64gather_epi64(pos2, (((char*) symbolTable.hashTab) + 0), 1); - // generate the FF..FF mask with an FF for each byte of the symbol (we need to AND the input with this to correctly check equality). - // generate the FF..FF mask with an FF for each byte of the symbol (we need to AND the input with this to correctly check equality). - pos1 = _mm512_srlv_epi64(all_MASK, _mm512_and_epi64(icl1, all_FF)); - pos2 = _mm512_srlv_epi64(all_MASK, _mm512_and_epi64(icl2, all_FF)); - // check symbol < |str| as well as whether it is an occupied slot (cmplt checks both conditions at once) and check string equality (cmpeq). - // check symbol < |str| as well as whether it is an occupied slot (cmplt checks both conditions at once) and check string equality (cmpeq). - __mmask8 match1 = _mm512_cmpeq_epi64_mask(symb1, _mm512_and_epi64(word1, pos1)) & _mm512_cmplt_epi64_mask(icl1, all_ICL_FREE); - __mmask8 match2 = _mm512_cmpeq_epi64_mask(symb2, _mm512_and_epi64(word2, pos2)) & _mm512_cmplt_epi64_mask(icl2, all_ICL_FREE); - // for the hits, overwrite the codes with what comes from the hash table (codes for symbols of length >=3). The rest stays with what shortCodes gave. - // for the hits, overwrite the codes with what comes from the hash table (codes for symbols of length >=3). The rest stays with what shortCodes gave. 
- code1 = _mm512_mask_mov_epi64(code1, match1, _mm512_srli_epi64(icl1, 16)); - code2 = _mm512_mask_mov_epi64(code2, match2, _mm512_srli_epi64(icl2, 16)); - // write out the code byte as the first output byte. Notice that this byte may also be the escape code 255 (for escapes) coming from shortCodes. - // write out the code byte as the first output byte. Notice that this byte may also be the escape code 255 (for escapes) coming from shortCodes. - write1 = _mm512_or_epi64(write1, _mm512_and_epi64(code1, all_FF)); - write2 = _mm512_or_epi64(write2, _mm512_and_epi64(code2, all_FF)); - // zip the irrelevant 6 bytes (just stay with the 2 relevant bytes containing the 16-bits code) - // zip the irrelevant 6 bytes (just stay with the 2 relevant bytes containing the 16-bits code) - code1 = _mm512_and_epi64(code1, all_FFFF); - code2 = _mm512_and_epi64(code2, all_FFFF); - // write out the compressed data. It writes 8 bytes, but only 1 byte is relevant :-(or 2 bytes are, in case of an escape code) - // write out the compressed data. 
It writes 8 bytes, but only 1 byte is relevant :-(or 2 bytes are, in case of an escape code) - _mm512_i64scatter_epi64(codeBase, _mm512_and_epi64(job1, all_M19), write1, 1); - _mm512_i64scatter_epi64(codeBase, _mm512_and_epi64(job2, all_M19), write2, 1); - // increase the job1.cur field in the job with the symbol length (for this, shift away 12 bits from the code) - // increase the job2.cur field in the job with the symbol length (for this, shift away 12 bits from the code) - job1 = _mm512_add_epi64(job1, _mm512_slli_epi64(_mm512_srli_epi64(code1, FSST_LEN_BITS), 46)); - job2 = _mm512_add_epi64(job2, _mm512_slli_epi64(_mm512_srli_epi64(code2, FSST_LEN_BITS), 46)); - // increase the job1.out' field with one, or two in case of an escape code (add 1 plus the escape bit, i.e the 8th) - // increase the job2.out' field with one, or two in case of an escape code (add 1 plus the escape bit, i.e the 8th) - job1 = _mm512_add_epi64(job1, _mm512_add_epi64(all_ONE, _mm512_and_epi64(_mm512_srli_epi64(code1, 8), all_ONE))); - job2 = _mm512_add_epi64(job2, _mm512_add_epi64(all_ONE, _mm512_and_epi64(_mm512_srli_epi64(code2, 8), all_ONE))); - // test which lanes are done now (job1.cur==job1.end), cur starts at bit 46, end starts at bit 28 (the highest 2x18 bits in the job1 register) - // test which lanes are done now (job2.cur==job2.end), cur starts at bit 46, end starts at bit 28 (the highest 2x18 bits in the job2 register) - loadmask1 = _mm512_cmpeq_epi64_mask(_mm512_srli_epi64(job1, 46), _mm512_and_epi64(_mm512_srli_epi64(job1, 28), all_M18)); - loadmask2 = _mm512_cmpeq_epi64_mask(_mm512_srli_epi64(job2, 46), _mm512_and_epi64(_mm512_srli_epi64(job2, 28), all_M18)); - // calculate the amount of lanes in job1 that are done - // calculate the amount of lanes in job2 that are done - delta1 = _mm_popcnt_u32((int) loadmask1); - delta2 = _mm_popcnt_u32((int) loadmask2); - // write out the job state for the lanes that are done (we need the final 'job1.out' value to compute the compressed 
string length) - // write out the job state for the lanes that are done (we need the final 'job2.out' value to compute the compressed string length) - _mm512_mask_compressstoreu_epi64(output, loadmask1, job1); output += delta1; - _mm512_mask_compressstoreu_epi64(output, loadmask2, job2); output += delta2; - - -// LICENSE_CHANGE_END - - } - } else { - while (input+delta1 < inputEnd) { - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #6 -// See the end of this file for a list - -// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): -// -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, E1PRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst -// - // load new jobs in the empty lanes (initially, all lanes are empty, so loadmask1=11111111, delta1=8). 
- job1 = _mm512_mask_expandloadu_epi64(job1, loadmask1, input); input += delta1; - // load the next 8 input string bytes (uncompressed data, aka 'symbols'). - __m512i word1 = _mm512_i64gather_epi64(_mm512_srli_epi64(job1, 46), symbolBase, 1); - // load 16-bits codes from the 2-byte-prefix keyed lookup table. It also store 1-byte codes in all free slots. - // code1: Lowest 8 bits contain the code. Eleventh bit is whether it is an escaped code. Next 4 bits is length (2 or 1). - __m512i code1 = _mm512_i64gather_epi64(_mm512_and_epi64(word1, all_FFFF), symbolTable.shortCodes, sizeof(u16)); - // get the first three bytes of the string. - __m512i pos1 = _mm512_mullo_epi64(_mm512_and_epi64(word1, all_FFFFFF), all_PRIME); - // hash them into a random number: pos1 = pos1*PRIME; pos1 ^= pos1>>SHIFT - pos1 = _mm512_slli_epi64(_mm512_and_epi64(_mm512_xor_epi64(pos1,_mm512_srli_epi64(pos1,FSST_SHIFT)), all_HASH), 4); - // lookup in the 3-byte-prefix keyed hash table - __m512i icl1 = _mm512_i64gather_epi64(pos1, (((char*) symbolTable.hashTab) + 8), 1); - // speculatively store the first input byte into the second position of the write1 register (in case it turns out to be an escaped byte). - __m512i write1 = _mm512_slli_epi64(_mm512_and_epi64(word1, all_FF), 8); - // lookup just like the icl1 above, but loads the next 8 bytes. This fetches the actual string bytes in the hash table. - __m512i symb1 = _mm512_i64gather_epi64(pos1, (((char*) symbolTable.hashTab) + 0), 1); - // generate the FF..FF mask with an FF for each byte of the symbol (we need to AND the input with this to correctly check equality). - pos1 = _mm512_srlv_epi64(all_MASK, _mm512_and_epi64(icl1, all_FF)); - // check symbol < |str| as well as whether it is an occupied slot (cmplt checks both conditions at once) and check string equality (cmpeq). 
- __mmask8 match1 = _mm512_cmpeq_epi64_mask(symb1, _mm512_and_epi64(word1, pos1)) & _mm512_cmplt_epi64_mask(icl1, all_ICL_FREE); - // for the hits, overwrite the codes with what comes from the hash table (codes for symbols of length >=3). The rest stays with what shortCodes gave. - code1 = _mm512_mask_mov_epi64(code1, match1, _mm512_srli_epi64(icl1, 16)); - // write out the code byte as the first output byte. Notice that this byte may also be the escape code 255 (for escapes) coming from shortCodes. - write1 = _mm512_or_epi64(write1, _mm512_and_epi64(code1, all_FF)); - // zip the irrelevant 6 bytes (just stay with the 2 relevant bytes containing the 16-bits code) - code1 = _mm512_and_epi64(code1, all_FFFF); - // write out the compressed data. It writes 8 bytes, but only 1 byte is relevant :-(or 2 bytes are, in case of an escape code) - _mm512_i64scatter_epi64(codeBase, _mm512_and_epi64(job1, all_M19), write1, 1); - // increase the job1.cur field in the job with the symbol length (for this, shift away 12 bits from the code) - job1 = _mm512_add_epi64(job1, _mm512_slli_epi64(_mm512_srli_epi64(code1, FSST_LEN_BITS), 46)); - // increase the job1.out' field with one, or two in case of an escape code (add 1 plus the escape bit, i.e the 8th) - job1 = _mm512_add_epi64(job1, _mm512_add_epi64(all_ONE, _mm512_and_epi64(_mm512_srli_epi64(code1, 8), all_ONE))); - // test which lanes are done now (job1.cur==job1.end), cur starts at bit 46, end starts at bit 28 (the highest 2x18 bits in the job1 register) - loadmask1 = _mm512_cmpeq_epi64_mask(_mm512_srli_epi64(job1, 46), _mm512_and_epi64(_mm512_srli_epi64(job1, 28), all_M18)); - // calculate the amount of lanes in job1 that are done - delta1 = _mm_popcnt_u32((int) loadmask1); - // write out the job state for the lanes that are done (we need the final 'job1.out' value to compute the compressed string length) - _mm512_mask_compressstoreu_epi64(output, loadmask1, job1); output += delta1; - - -// LICENSE_CHANGE_END - - } - } - - // 
flush the job states of the unfinished strings at the end of output[] - processed = n - (inputEnd - input); - u32 unfinished = 0; - if (unroll > 1) { - if (unroll > 2) { - if (unroll > 3) { - _mm512_mask_compressstoreu_epi64(output+unfinished, loadmask4=~loadmask4, job4); - unfinished += _mm_popcnt_u32((int) loadmask4); - } - _mm512_mask_compressstoreu_epi64(output+unfinished, loadmask3=~loadmask3, job3); - unfinished += _mm_popcnt_u32((int) loadmask3); - } - _mm512_mask_compressstoreu_epi64(output+unfinished, loadmask2=~loadmask2, job2); - unfinished += _mm_popcnt_u32((int) loadmask2); - } - _mm512_mask_compressstoreu_epi64(output+unfinished, loadmask1=~loadmask1, job1); -#else - (void) symbolTable; - (void) codeBase; - (void) symbolBase; - (void) input; - (void) output; - (void) n; - (void) unroll; -#endif - return processed; -} - - -// LICENSE_CHANGE_END - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #6 -// See the end of this file for a list - -// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): -// -// Copyright 2018-2020, CWI, TU Munich, FSU Jena -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files -// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, -// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst - - -Symbol concat(Symbol a, Symbol b) { - Symbol s; - u32 length = a.length()+b.length(); - if (length > Symbol::maxLength) length = Symbol::maxLength; - s.set_code_len(FSST_CODE_MASK, length); - s.val.num = (b.val.num << (8*a.length())) | a.val.num; - return s; -} - -namespace std { -template <> -class hash { -public: - size_t operator()(const QSymbol& q) const { - uint64_t k = q.symbol.val.num; - const uint64_t m = 0xc6a4a7935bd1e995; - const int r = 47; - uint64_t h = 0x8445d61a4e774912 ^ (8*m); - k *= m; - k ^= k >> r; - k *= m; - h ^= k; - h *= m; - h ^= h >> r; - h *= m; - h ^= h >> r; - return h; - } -}; -} - -bool isEscapeCode(u16 pos) { return pos < FSST_CODE_BASE; } - -std::ostream& operator<<(std::ostream& out, const Symbol& s) { - for (u32 i=0; i line, size_t len[], bool zeroTerminated=false) { - SymbolTable *st = new SymbolTable(), *bestTable = new SymbolTable(); - int bestGain = (int) -FSST_SAMPLEMAXSZ; // worst case (everything exception) - size_t sampleFrac = 128; - - // start by determining the terminator. 
We use the (lowest) most infrequent byte as terminator - st->zeroTerminated = zeroTerminated; - if (zeroTerminated) { - st->terminator = 0; // except in case of zeroTerminated mode, then byte 0 is terminator regardless frequency - } else { - u16 byteHisto[256]; - memset(byteHisto, 0, sizeof(byteHisto)); - for(size_t i=0; iterminator = 256; - while(i-- > 0) { - if (byteHisto[i] > minSize) continue; - st->terminator = i; - minSize = byteHisto[i]; - } - } - assert(st->terminator != 256); - - // a random number between 0 and 128 - auto rnd128 = [&](size_t i) { return 1 + (FSST_HASH((i+1UL)*sampleFrac)&127); }; - - // compress sample, and compute (pair-)frequencies - auto compressCount = [&](SymbolTable *st, Counters &counters) { // returns gain - int gain = 0; - - for(size_t i=0; i sampleFrac) continue; - } - if (cur < end) { - u8* start = cur; - u16 code2 = 255, code1 = st->findLongestSymbol(cur, end); - cur += st->symbols[code1].length(); - gain += (int) (st->symbols[code1].length()-(1+isEscapeCode(code1))); - while (true) { - // count single symbol (i.e. an option is not extending it) - counters.count1Inc(code1); - - // as an alternative, consider just using the next byte.. - if (st->symbols[code1].length() != 1) // .. 
but do not count single byte symbols doubly - counters.count1Inc(*start); - - if (cur==end) { - break; - } - - // now match a new symbol - start = cur; - if (curhashTabSize-1); - Symbol s = st->hashTab[idx]; - code2 = st->shortCodes[word & 0xFFFF] & FSST_CODE_MASK; - word &= (0xFFFFFFFFFFFFFFFF >> (u8) s.icl); - if ((s.icl < FSST_ICL_FREE) & (s.val.num == word)) { - code2 = s.code(); - cur += s.length(); - } else if (code2 >= FSST_CODE_BASE) { - cur += 2; - } else { - code2 = st->byteCodes[word & 0xFF] & FSST_CODE_MASK; - cur += 1; - } - } else { - code2 = st->findLongestSymbol(cur, end); - cur += st->symbols[code2].length(); - } - - // compute compressed output size - gain += ((int) (cur-start))-(1+isEscapeCode(code2)); - - // now count the subsequent two symbols we encode as an extension codesibility - if (sampleFrac < 128) { // no need to count pairs in final round - // consider the symbol that is the concatenation of the two last symbols - counters.count2Inc(code1, code2); - - // as an alternative, consider just extending with the next byte.. - if ((cur-start) > 1) // ..but do not count single byte extensions doubly - counters.count2Inc(code1, *start); - } - code1 = code2; - } - } - } - return gain; - }; - - auto makeTable = [&](SymbolTable *st, Counters &counters) { - // hashmap of c (needed because we can generate duplicate candidates) - unordered_set cands; - - // artificially make terminater the most frequent symbol so it gets included - u16 terminator = st->nSymbols?FSST_CODE_BASE:st->terminator; - counters.count1Set(terminator,65535); - - auto addOrInc = [&](unordered_set &cands, Symbol s, u64 count) { - if (count < (5*sampleFrac)/128) return; // improves both compression speed (less candidates), but also quality!! 
- QSymbol q; - q.symbol = s; - q.gain = count * s.length(); - auto it = cands.find(q); - if (it != cands.end()) { - q.gain += (*it).gain; - cands.erase(*it); - } - cands.insert(q); - }; - - // add candidate symbols based on counted frequency - for (u32 pos1=0; pos1nSymbols; pos1++) { - u32 cnt1 = counters.count1GetNext(pos1); // may advance pos1!! - if (!cnt1) continue; - - // heuristic: promoting single-byte symbols (*8) helps reduce exception rates and increases [de]compression speed - Symbol s1 = st->symbols[pos1]; - addOrInc(cands, s1, ((s1.length()==1)?8LL:1LL)*cnt1); - - if (sampleFrac >= 128 || // last round we do not create new (combined) symbols - s1.length() == Symbol::maxLength || // symbol cannot be extended - s1.val.str[0] == st->terminator) { // multi-byte symbols cannot contain the terminator byte - continue; - } - for (u32 pos2=0; pos2nSymbols; pos2++) { - u32 cnt2 = counters.count2GetNext(pos1, pos2); // may advance pos2!! - if (!cnt2) continue; - - // create a new symbol - Symbol s2 = st->symbols[pos2]; - Symbol s3 = concat(s1, s2); - if (s2.val.str[0] != st->terminator) // multi-byte symbols cannot contain the terminator byte - addOrInc(cands, s3, cnt2); - } - } - - // insert candidates into priority queue (by gain) - auto cmpGn = [](const QSymbol& q1, const QSymbol& q2) { return (q1.gain < q2.gain) || (q1.gain == q2.gain && q1.symbol.val.num > q2.symbol.val.num); }; - priority_queue,decltype(cmpGn)> pq(cmpGn); - for (auto& q : cands) - pq.push(q); - - // Create new symbol map using best candidates - st->clear(); - while (st->nSymbols < 255 && !pq.empty()) { - QSymbol q = pq.top(); - pq.pop(); - st->add(q.symbol); - } - }; - - u8 bestCounters[512*sizeof(u16)]; -#ifdef NONOPT_FSST - for(size_t frac : {127, 127, 127, 127, 127, 127, 127, 127, 127, 128}) { - sampleFrac = frac; -#else - for(sampleFrac=8; true; sampleFrac += 30) { -#endif - memset(&counters, 0, sizeof(Counters)); - long gain = compressCount(st, counters); - if (gain >= bestGain) { // a 
new best solution! - counters.backup1(bestCounters); - *bestTable = *st; bestGain = gain; - } - if (sampleFrac >= 128) break; // we do 5 rounds (sampleFrac=8,38,68,98,128) - makeTable(st, counters); - } - delete st; - counters.restore1(bestCounters); - makeTable(bestTable, counters); - bestTable->finalize(zeroTerminated); // renumber codes for more efficient compression - return bestTable; -} - -static inline size_t compressSIMD(SymbolTable &symbolTable, u8* symbolBase, size_t nlines, size_t len[], u8* line[], size_t size, u8* dst, size_t lenOut[], u8* strOut[], int unroll) { - size_t curLine = 0, inOff = 0, outOff = 0, batchPos = 0, empty = 0, budget = size; - u8 *lim = dst + size, *codeBase = symbolBase + (1<<18); // 512KB temp space for compressing 512 strings - SIMDjob input[512]; // combined offsets of input strings (cur,end), and string #id (pos) and output (dst) pointer - SIMDjob output[512]; // output are (pos:9,dst:19) end pointers (compute compressed length from this) - size_t jobLine[512]; // for which line in the input sequence was this job (needed because we may split a line into multiple jobs) - - while (curLine < nlines && outOff <= (1<<19)) { - size_t prevLine = curLine, chunk, curOff = 0; - - // bail out if the output buffer cannot hold the compressed next string fully - if (((len[curLine]-curOff)*2 + 7) > budget) break; // see below for the +7 - else budget -= (len[curLine]-curOff)*2; - - strOut[curLine] = (u8*) 0; - lenOut[curLine] = 0; - - do { - do { - chunk = len[curLine] - curOff; - if (chunk > 511) { - chunk = 511; // large strings need to be chopped up into segments of 511 bytes - } - // create a job in this batch - SIMDjob job; - job.cur = inOff; - job.end = job.cur + chunk; - job.pos = batchPos; - job.out = outOff; - - // worst case estimate for compressed size (+7 is for the scatter that writes extra 7 zeros) - outOff += 7 + 2*(size_t)(job.end - job.cur); // note, total size needed is 512*(511*2+7) bytes. 
- if (outOff > (1<<19)) break; // simdbuf may get full, stop before this chunk - - // register job in this batch - input[batchPos] = job; - jobLine[batchPos] = curLine; - - if (chunk == 0) { - empty++; // detect empty chunks -- SIMD code cannot handle empty strings, so they need to be filtered out - } else { - // copy string chunk into temp buffer - memcpy(symbolBase + inOff, line[curLine] + curOff, chunk); - inOff += chunk; - curOff += chunk; - symbolBase[inOff++] = (u8) symbolTable.terminator; // write an extra char at the end that will not be encoded - } - if (++batchPos == 512) break; - } while(curOff < len[curLine]); - - if ((batchPos == 512) || (outOff > (1<<19)) || (++curLine >= nlines)) { // cannot accumulate more? - if (batchPos-empty >= 32) { // if we have enough work, fire off fsst_compressAVX512 (32 is due to max 4x8 unrolling) - // radix-sort jobs on length (longest string first) - // -- this provides best load balancing and allows to skip empty jobs at the end - u16 sortpos[513]; - memset(sortpos, 0, sizeof(sortpos)); - - // calculate length histo - for(size_t i=0; i> (u8) s.icl); - if ((s.icl < FSST_ICL_FREE) && s.val.num == word) { - *out++ = (u8) s.code(); cur += s.length(); - } else { - // could be a 2-byte or 1-byte code, or miss - // handle everything with predication - *out = (u8) code; - out += 1+((code&FSST_CODE_BASE)>>8); - cur += (code>>FSST_LEN_BITS); - } - } - job.out = out - codeBase; - } - // postprocess job info - job.cur = 0; - job.end = job.out - input[job.pos].out; // misuse .end field as compressed size - job.out = input[job.pos].out; // reset offset to start of encoded string - input[job.pos] = job; - } - - // copy out the result data - for(size_t i=0; i> (u8) s.icl); - if ((s.icl < FSST_ICL_FREE) && s.val.num == word) { - *out++ = (u8) s.code(); cur += s.length(); - } else if (avoidBranch) { - // could be a 2-byte or 1-byte code, or miss - // handle everything with predication - *out = (u8) code; - out += 
1+((code&FSST_CODE_BASE)>>8); - cur += (code>>FSST_LEN_BITS); - } else if ((u8) code < byteLim) { - // 2 byte code after checking there is no longer pattern - *out++ = (u8) code; cur += 2; - } else { - // 1 byte code or miss. - *out = (u8) code; - out += 1+((code&FSST_CODE_BASE)>>8); // predicated - tested with a branch, that was always worse - cur++; - } - } - } - }; - - for(curLine=0; curLine 511) { - chunk = 511; // we need to compress in chunks of 511 in order to be byte-compatible with simd-compressed FSST - } - if ((2*chunk+7) > (size_t) (lim-out)) { - return curLine; // out of memory - } - // copy the string to the 511-byte buffer - memcpy(buf, cur, chunk); - buf[chunk] = (u8) symbolTable.terminator; - cur = buf; - end = cur + chunk; - - // based on symboltable stats, choose a variant that is nice to the branch predictor - if (noSuffixOpt) { - compressVariant(true,false); - } else if (avoidBranch) { - compressVariant(false,true); - } else { - compressVariant(false, false); - } - } while((curOff += chunk) < lenIn[curLine]); - lenOut[curLine] = (size_t) (out - strOut[curLine]); - } - return curLine; -} - -#define FSST_SAMPLELINE ((size_t) 512) - -// quickly select a uniformly random set of lines such that we have between [FSST_SAMPLETARGET,FSST_SAMPLEMAXSZ) string bytes -vector makeSample(u8* sampleBuf, u8* strIn[], size_t *lenIn, size_t nlines, - unique_ptr>& sample_len_out) { - size_t totSize = 0; - vector sample; - - for(size_t i=0; i>(new vector()); - sample_len_out->reserve(nlines + FSST_SAMPLEMAXSZ/FSST_SAMPLELINE); - - // This fails if we have a lot of small strings and a few big ones? 
- while(sampleBuf < sampleLim) { - // choose a non-empty line - sampleRnd = FSST_HASH(sampleRnd); - size_t linenr = sampleRnd % nlines; - while (lenIn[linenr] == 0) - if (++linenr == nlines) linenr = 0; - - // choose a chunk - size_t chunks = 1 + ((lenIn[linenr]-1) / FSST_SAMPLELINE); - sampleRnd = FSST_HASH(sampleRnd); - size_t chunk = FSST_SAMPLELINE*(sampleRnd % chunks); - - // add the chunk to the sample - size_t len = min(lenIn[linenr]-chunk,FSST_SAMPLELINE); - memcpy(sampleBuf, strIn[linenr]+chunk, len); - sample.push_back(sampleBuf); - - sample_len_out->push_back(len); - sampleBuf += len; - } - } - return sample; -} - -extern "C" duckdb_fsst_encoder_t* duckdb_fsst_create(size_t n, size_t lenIn[], u8 *strIn[], int zeroTerminated) { - u8* sampleBuf = new u8[FSST_SAMPLEMAXSZ]; - unique_ptr> sample_sizes; - vector sample = makeSample(sampleBuf, strIn, lenIn, n?n:1, sample_sizes); // careful handling of input to get a right-size and representative sample - Encoder *encoder = new Encoder(); - size_t* sampleLen = sample_sizes ? sample_sizes->data() : &lenIn[0]; - encoder->symbolTable = shared_ptr(buildSymbolTable(encoder->counters, sample, sampleLen, zeroTerminated)); - delete[] sampleBuf; - return (duckdb_fsst_encoder_t*) encoder; -} - -/* create another encoder instance, necessary to do multi-threaded encoding using the same symbol table */ -extern "C" duckdb_fsst_encoder_t* duckdb_fsst_duplicate(duckdb_fsst_encoder_t *encoder) { - Encoder *e = new Encoder(); - e->symbolTable = ((Encoder*)encoder)->symbolTable; // it is a shared_ptr - return (duckdb_fsst_encoder_t*) e; -} - -// export a symbol table in compact format. -extern "C" u32 duckdb_fsst_export(duckdb_fsst_encoder_t *encoder, u8 *buf) { - Encoder *e = (Encoder*) encoder; - // In ->version there is a versionnr, but we hide also suffixLim/terminator/nSymbols there. 
- // This is sufficient in principle to *reconstruct* a duckdb_fsst_encoder_t from a duckdb_fsst_decoder_t - // (such functionality could be useful to append compressed data to an existing block). - // - // However, the hash function in the encoder hash table is endian-sensitive, and given its - // 'lossy perfect' hashing scheme is *unable* to contain other-endian-produced symbol tables. - // Doing a endian-conversion during hashing will be slow and self-defeating. - // - // Overall, we could support reconstructing an encoder for incremental compression, but - // should enforce equal-endianness. Bit of a bummer. Not going there now. - // - // The version field is now there just for future-proofness, but not used yet - - // version allows keeping track of fsst versions, track endianness, and encoder reconstruction - u64 version = (FSST_VERSION << 32) | // version is 24 bits, most significant byte is 0 - (((u64) e->symbolTable->suffixLim) << 24) | - (((u64) e->symbolTable->terminator) << 16) | - (((u64) e->symbolTable->nSymbols) << 8) | - FSST_ENDIAN_MARKER; // least significant byte is nonzero - - /* do not assume unaligned reads here */ - memcpy(buf, &version, 8); - buf[8] = e->symbolTable->zeroTerminated; - for(u32 i=0; i<8; i++) - buf[9+i] = (u8) e->symbolTable->lenHisto[i]; - u32 pos = 17; - - // emit only the used bytes of the symbols - for(u32 i = e->symbolTable->zeroTerminated; i < e->symbolTable->nSymbols; i++) - for(u32 j = 0; j < e->symbolTable->symbols[i].length(); j++) - buf[pos++] = e->symbolTable->symbols[i].val.str[j]; // serialize used symbol bytes - - return pos; // length of what was serialized -} - -#define FSST_CORRUPT 32774747032022883 /* 7-byte number in little endian containing "corrupt" */ - -extern "C" u32 duckdb_fsst_import(duckdb_fsst_decoder_t *decoder, u8 *buf) { - u64 version = 0; - u32 code, pos = 17; - u8 lenHisto[8]; - - // version field (first 8 bytes) is now there just for future-proofness, unused still (skipped) - memcpy(&version, 
buf, 8); - if ((version>>32) != FSST_VERSION) return 0; - decoder->zeroTerminated = buf[8]&1; - memcpy(lenHisto, buf+9, 8); - - // in case of zero-terminated, first symbol is "" (zero always, may be overwritten) - decoder->len[0] = 1; - decoder->symbol[0] = 0; - - // we use lenHisto[0] as 1-byte symbol run length (at the end) - code = decoder->zeroTerminated; - if (decoder->zeroTerminated) lenHisto[0]--; // if zeroTerminated, then symbol "" aka 1-byte code=0, is not stored at the end - - // now get all symbols from the buffer - for(u32 l=1; l<=8; l++) { /* l = 1,2,3,4,5,6,7,8 */ - for(u32 i=0; i < lenHisto[(l&7) /* 1,2,3,4,5,6,7,0 */]; i++, code++) { - decoder->len[code] = (l&7)+1; /* len = 2,3,4,5,6,7,8,1 */ - decoder->symbol[code] = 0; - for(u32 j=0; jlen[code]; j++) - ((u8*) &decoder->symbol[code])[j] = buf[pos++]; // note this enforces 'little endian' symbols - } - } - if (decoder->zeroTerminated) lenHisto[0]++; - - // fill unused symbols with text "corrupt". Gives a chance to detect corrupted code sequences (if there are unused symbols). 
- while(code<255) { - decoder->symbol[code] = FSST_CORRUPT; - decoder->len[code++] = 8; - } - return pos; -} - -// runtime check for simd -inline size_t _compressImpl(Encoder *e, size_t nlines, size_t lenIn[], u8 *strIn[], size_t size, u8 *output, size_t *lenOut, u8 *strOut[], bool noSuffixOpt, bool avoidBranch, int simd) { -#ifndef NONOPT_FSST - if (simd && duckdb_fsst_hasAVX512()) - return compressSIMD(*e->symbolTable, e->simdbuf, nlines, lenIn, strIn, size, output, lenOut, strOut, simd); -#endif - (void) simd; - return compressBulk(*e->symbolTable, nlines, lenIn, strIn, size, output, lenOut, strOut, noSuffixOpt, avoidBranch); -} -size_t compressImpl(Encoder *e, size_t nlines, size_t lenIn[], u8 *strIn[], size_t size, u8 *output, size_t *lenOut, u8 *strOut[], bool noSuffixOpt, bool avoidBranch, int simd) { - return _compressImpl(e, nlines, lenIn, strIn, size, output, lenOut, strOut, noSuffixOpt, avoidBranch, simd); -} - -// adaptive choosing of scalar compression method based on symbol length histogram -inline size_t _compressAuto(Encoder *e, size_t nlines, size_t lenIn[], u8 *strIn[], size_t size, u8 *output, size_t *lenOut, u8 *strOut[], int simd) { - bool avoidBranch = false, noSuffixOpt = false; - if (100*e->symbolTable->lenHisto[1] > 65*e->symbolTable->nSymbols && 100*e->symbolTable->suffixLim > 95*e->symbolTable->lenHisto[1]) { - noSuffixOpt = true; - } else if ((e->symbolTable->lenHisto[0] > 24 && e->symbolTable->lenHisto[0] < 92) && - (e->symbolTable->lenHisto[0] < 43 || e->symbolTable->lenHisto[6] + e->symbolTable->lenHisto[7] < 29) && - (e->symbolTable->lenHisto[0] < 72 || e->symbolTable->lenHisto[2] < 72)) { - avoidBranch = true; - } - return _compressImpl(e, nlines, lenIn, strIn, size, output, lenOut, strOut, noSuffixOpt, avoidBranch, simd); -} -size_t compressAuto(Encoder *e, size_t nlines, size_t lenIn[], u8 *strIn[], size_t size, u8 *output, size_t *lenOut, u8 *strOut[], int simd) { - return _compressAuto(e, nlines, lenIn, strIn, size, output, 
lenOut, strOut, simd); -} - -// the main compression function (everything automatic) -extern "C" size_t duckdb_fsst_compress(duckdb_fsst_encoder_t *encoder, size_t nlines, size_t lenIn[], u8 *strIn[], size_t size, u8 *output, size_t *lenOut, u8 *strOut[]) { - // to be faster than scalar, simd needs 64 lines or more of length >=12; or fewer lines, but big ones (totLen > 32KB) - size_t totLen = accumulate(lenIn, lenIn+nlines, 0); - int simd = totLen > nlines*12 && (nlines > 64 || totLen > (size_t) 1<<15); - return _compressAuto((Encoder*) encoder, nlines, lenIn, strIn, size, output, lenOut, strOut, 3*simd); -} - -/* deallocate encoder */ -extern "C" void duckdb_fsst_destroy(duckdb_fsst_encoder_t* encoder) { - Encoder *e = (Encoder*) encoder; - delete e; -} - -/* very lazy implementation relying on export and import */ -extern "C" duckdb_fsst_decoder_t duckdb_fsst_decoder(duckdb_fsst_encoder_t *encoder) { - u8 buf[sizeof(duckdb_fsst_decoder_t)]; - u32 cnt1 = duckdb_fsst_export(encoder, buf); - duckdb_fsst_decoder_t decoder; - u32 cnt2 = duckdb_fsst_import(&decoder, buf); - assert(cnt1 == cnt2); (void) cnt1; (void) cnt2; - return decoder; -} - -// LICENSE_CHANGE_END - -#endif diff --git a/lib/duckdb-hyperloglog.cpp b/lib/duckdb-hyperloglog.cpp deleted file mode 100644 index 62b568a1..00000000 --- a/lib/duckdb-hyperloglog.cpp +++ /dev/null @@ -1,2721 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -#include "duckdb.hpp" -#include "duckdb-internal.hpp" -#ifndef DUCKDB_AMALGAMATION -#error header mismatch -#endif - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #1 -// See the end of this file for a list - -/* hyperloglog.c - Redis HyperLogLog probabilistic cardinality approximation. - * This file implements the algorithm and the exported Redis commands. - * - * Copyright (c) 2014, Salvatore Sanfilippo - * All rights reserved. 
- * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - - - - - - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #1 -// See the end of this file for a list - -/* SDSLib 2.0 -- A C dynamic strings library - * - * Copyright (c) 2006-2015, Salvatore Sanfilippo - * Copyright (c) 2015, Oran Agra - * Copyright (c) 2015, Redis Labs, Inc - * All rights reserved. 
- * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - -#ifndef __SDS_H -#define __SDS_H - - -#ifdef _MSC_VER -#define __attribute__(A) -#define ssize_t int64_t -#endif - -#define SDS_MAX_PREALLOC (1024*1024) - -#include -#include -#include - -namespace duckdb_hll { - - -typedef char *sds; - -/* Note: sdshdr5 is never used, we just access the flags byte directly. - * However is here to document the layout of type 5 SDS strings. 
*/ -struct __attribute__ ((__packed__)) sdshdr5 { - unsigned char flags; /* 3 lsb of type, and 5 msb of string length */ - char buf[1]; -}; -struct __attribute__ ((__packed__)) sdshdr8 { - uint8_t len; /* used */ - uint8_t alloc; /* excluding the header and null terminator */ - unsigned char flags; /* 3 lsb of type, 5 unused bits */ - char buf[1]; -}; -struct __attribute__ ((__packed__)) sdshdr16 { - uint16_t len; /* used */ - uint16_t alloc; /* excluding the header and null terminator */ - unsigned char flags; /* 3 lsb of type, 5 unused bits */ - char buf[1]; -}; -struct __attribute__ ((__packed__)) sdshdr32 { - uint32_t len; /* used */ - uint32_t alloc; /* excluding the header and null terminator */ - unsigned char flags; /* 3 lsb of type, 5 unused bits */ - char buf[1]; -}; -struct __attribute__ ((__packed__)) sdshdr64 { - uint64_t len; /* used */ - uint64_t alloc; /* excluding the header and null terminator */ - unsigned char flags; /* 3 lsb of type, 5 unused bits */ - char buf[1]; -}; - -#define SDS_TYPE_5 0 -#define SDS_TYPE_8 1 -#define SDS_TYPE_16 2 -#define SDS_TYPE_32 3 -#define SDS_TYPE_64 4 -#define SDS_TYPE_MASK 7 -#define SDS_TYPE_BITS 3 -#define SDS_HDR_VAR(T,s) struct sdshdr##T *sh = (struct sdshdr##T *)((s)-(sizeof(struct sdshdr##T))); -#define SDS_HDR(T,s) ((struct sdshdr##T *)((s)-(sizeof(struct sdshdr##T)))) -#define SDS_TYPE_5_LEN(f) ((f)>>SDS_TYPE_BITS) - -static inline size_t sdslen(const sds s) { - unsigned char flags = s[-1]; - switch(flags&SDS_TYPE_MASK) { - case SDS_TYPE_5: - return SDS_TYPE_5_LEN(flags); - case SDS_TYPE_8: - return SDS_HDR(8,s)->len; - case SDS_TYPE_16: - return SDS_HDR(16,s)->len; - case SDS_TYPE_32: - return SDS_HDR(32,s)->len; - case SDS_TYPE_64: - return SDS_HDR(64,s)->len; - } - return 0; -} - -static inline size_t sdsavail(const sds s) { - unsigned char flags = s[-1]; - switch(flags&SDS_TYPE_MASK) { - case SDS_TYPE_5: { - return 0; - } - case SDS_TYPE_8: { - SDS_HDR_VAR(8,s); - return sh->alloc - sh->len; - } - 
case SDS_TYPE_16: { - SDS_HDR_VAR(16,s); - return sh->alloc - sh->len; - } - case SDS_TYPE_32: { - SDS_HDR_VAR(32,s); - return sh->alloc - sh->len; - } - case SDS_TYPE_64: { - SDS_HDR_VAR(64,s); - return sh->alloc - sh->len; - } - } - return 0; -} - -static inline void sdssetlen(sds s, size_t newlen) { - unsigned char flags = s[-1]; - switch(flags&SDS_TYPE_MASK) { - case SDS_TYPE_5: - { - unsigned char *fp = ((unsigned char*)s)-1; - *fp = SDS_TYPE_5 | (newlen << SDS_TYPE_BITS); - } - break; - case SDS_TYPE_8: - SDS_HDR(8,s)->len = newlen; - break; - case SDS_TYPE_16: - SDS_HDR(16,s)->len = newlen; - break; - case SDS_TYPE_32: - SDS_HDR(32,s)->len = newlen; - break; - case SDS_TYPE_64: - SDS_HDR(64,s)->len = newlen; - break; - } -} - -static inline void sdsinclen(sds s, size_t inc) { - unsigned char flags = s[-1]; - switch(flags&SDS_TYPE_MASK) { - case SDS_TYPE_5: - { - unsigned char *fp = ((unsigned char*)s)-1; - unsigned char newlen = SDS_TYPE_5_LEN(flags)+inc; - *fp = SDS_TYPE_5 | (newlen << SDS_TYPE_BITS); - } - break; - case SDS_TYPE_8: - SDS_HDR(8,s)->len += inc; - break; - case SDS_TYPE_16: - SDS_HDR(16,s)->len += inc; - break; - case SDS_TYPE_32: - SDS_HDR(32,s)->len += inc; - break; - case SDS_TYPE_64: - SDS_HDR(64,s)->len += inc; - break; - } -} - -/* sdsalloc() = sdsavail() + sdslen() */ -static inline size_t sdsalloc(const sds s) { - unsigned char flags = s[-1]; - switch(flags&SDS_TYPE_MASK) { - case SDS_TYPE_5: - return SDS_TYPE_5_LEN(flags); - case SDS_TYPE_8: - return SDS_HDR(8,s)->alloc; - case SDS_TYPE_16: - return SDS_HDR(16,s)->alloc; - case SDS_TYPE_32: - return SDS_HDR(32,s)->alloc; - case SDS_TYPE_64: - return SDS_HDR(64,s)->alloc; - } - return 0; -} - -static inline void sdssetalloc(sds s, size_t newlen) { - unsigned char flags = s[-1]; - switch(flags&SDS_TYPE_MASK) { - case SDS_TYPE_5: - /* Nothing to do, this type has no total allocation info. 
*/ - break; - case SDS_TYPE_8: - SDS_HDR(8,s)->alloc = newlen; - break; - case SDS_TYPE_16: - SDS_HDR(16,s)->alloc = newlen; - break; - case SDS_TYPE_32: - SDS_HDR(32,s)->alloc = newlen; - break; - case SDS_TYPE_64: - SDS_HDR(64,s)->alloc = newlen; - break; - } -} - -sds sdsnewlen(const void *init, size_t initlen); -sds sdsnew(const char *init); -sds sdsempty(void); -sds sdsdup(const sds s); -void sdsfree(sds s); -sds sdsgrowzero(sds s, size_t len); -sds sdscatlen(sds s, const void *t, size_t len); -sds sdscat(sds s, const char *t); -sds sdscatsds(sds s, const sds t); -sds sdscpylen(sds s, const char *t, size_t len); -sds sdscpy(sds s, const char *t); - -sds sdscatvprintf(sds s, const char *fmt, va_list ap); -#ifdef __GNUC__ -sds sdscatprintf(sds s, const char *fmt, ...) - __attribute__((format(printf, 2, 3))); -#else -sds sdscatprintf(sds s, const char *fmt, ...); -#endif - -sds sdscatfmt(sds s, char const *fmt, ...); -sds sdstrim(sds s, const char *cset); -void sdsrange(sds s, ssize_t start, ssize_t end); -void sdsupdatelen(sds s); -void sdsclear(sds s); -int sdscmp(const sds s1, const sds s2); -sds *sdssplitlen(const char *s, ssize_t len, const char *sep, int seplen, int *count); -void sdsfreesplitres(sds *tokens, int count); -void sdstolower(sds s); -void sdstoupper(sds s); -sds sdsfromlonglong(long long value); -sds sdscatrepr(sds s, const char *p, size_t len); -sds *sdssplitargs(const char *line, int *argc); -sds sdsmapchars(sds s, const char *from, const char *to, size_t setlen); -sds sdsjoin(char **argv, int argc, char *sep); -sds sdsjoinsds(sds *argv, int argc, const char *sep, size_t seplen); - -/* Low level functions exposed to the user API */ -sds sdsMakeRoomFor(sds s, size_t addlen); -void sdsIncrLen(sds s, ssize_t incr); -sds sdsRemoveFreeSpace(sds s); -size_t sdsAllocSize(sds s); -void *sdsAllocPtr(sds s); - -/* Export the allocator used by SDS to the program using SDS. 
- * Sometimes the program SDS is linked to, may use a different set of - * allocators, but may want to allocate or free things that SDS will - * respectively free or allocate. */ -void *sds_malloc(size_t size); -void *sds_realloc(void *ptr, size_t size); -void sds_free(void *ptr); - -#ifdef REDIS_TEST -int sdsTest(int argc, char *argv[]); -#endif -} - - -#endif - -// LICENSE_CHANGE_END - - -#include -#include -#include -#include -#include -#include - - - -namespace duckdb_hll { - -#define HLL_SPARSE_MAX_BYTES 3000 - -/* The Redis HyperLogLog implementation is based on the following ideas: - * - * * The use of a 64 bit hash function as proposed in [1], in order to don't - * limited to cardinalities up to 10^9, at the cost of just 1 additional - * bit per register. - * * The use of 16384 6-bit registers for a great level of accuracy, using - * a total of 12k per key. - * * The use of the Redis string data type. No new type is introduced. - * * No attempt is made to compress the data structure as in [1]. Also the - * algorithm used is the original HyperLogLog Algorithm as in [2], with - * the only difference that a 64 bit hash function is used, so no correction - * is performed for values near 2^32 as in [1]. - * - * [1] Heule, Nunkesser, Hall: HyperLogLog in Practice: Algorithmic - * Engineering of a State of The Art Cardinality Estimation Algorithm. - * - * [2] P. Flajolet, Éric Fusy, O. Gandouet, and F. Meunier. Hyperloglog: The - * analysis of a near-optimal cardinality estimation algorithm. - * - * Redis uses two representations: - * - * 1) A "dense" representation where every entry is represented by - * a 6-bit integer. - * 2) A "sparse" representation using run length compression suitable - * for representing HyperLogLogs with many registers set to 0 in - * a memory efficient way. - * - * - * HLL header - * === - * - * Both the dense and sparse representation have a 16 byte header as follows: - * - * +------+---+-----+----------+ - * | HYLL | E | N/U | Cardin. 
| - * +------+---+-----+----------+ - * - * The first 4 bytes are a magic string set to the bytes "HYLL". - * "E" is one byte encoding, currently set to HLL_DENSE or - * HLL_SPARSE. N/U are three not used bytes. - * - * The "Cardin." field is a 64 bit integer stored in little endian format - * with the latest cardinality computed that can be reused if the data - * structure was not modified since the last computation (this is useful - * because there are high probabilities that HLLADD operations don't - * modify the actual data structure and hence the approximated cardinality). - * - * When the most significant bit in the most significant byte of the cached - * cardinality is set, it means that the data structure was modified and - * we can't reuse the cached value that must be recomputed. - * - * Dense representation - * === - * - * The dense representation used by Redis is the following: - * - * +--------+--------+--------+------// //--+ - * |11000000|22221111|33333322|55444444 .... | - * +--------+--------+--------+------// //--+ - * - * The 6 bits counters are encoded one after the other starting from the - * LSB to the MSB, and using the next bytes as needed. - * - * Sparse representation - * === - * - * The sparse representation encodes registers using a run length - * encoding composed of three opcodes, two using one byte, and one using - * of two bytes. The opcodes are called ZERO, XZERO and VAL. - * - * ZERO opcode is represented as 00xxxxxx. The 6-bit integer represented - * by the six bits 'xxxxxx', plus 1, means that there are N registers set - * to 0. This opcode can represent from 1 to 64 contiguous registers set - * to the value of 0. - * - * XZERO opcode is represented by two bytes 01xxxxxx yyyyyyyy. The 14-bit - * integer represented by the bits 'xxxxxx' as most significant bits and - * 'yyyyyyyy' as least significant bits, plus 1, means that there are N - * registers set to 0. 
This opcode can represent from 0 to 16384 contiguous - * registers set to the value of 0. - * - * VAL opcode is represented as 1vvvvvxx. It contains a 5-bit integer - * representing the value of a register, and a 2-bit integer representing - * the number of contiguous registers set to that value 'vvvvv'. - * To obtain the value and run length, the integers vvvvv and xx must be - * incremented by one. This opcode can represent values from 1 to 32, - * repeated from 1 to 4 times. - * - * The sparse representation can't represent registers with a value greater - * than 32, however it is very unlikely that we find such a register in an - * HLL with a cardinality where the sparse representation is still more - * memory efficient than the dense representation. When this happens the - * HLL is converted to the dense representation. - * - * The sparse representation is purely positional. For example a sparse - * representation of an empty HLL is just: XZERO:16384. - * - * An HLL having only 3 non-zero registers at position 1000, 1020, 1021 - * respectively set to 2, 3, 3, is represented by the following three - * opcodes: - * - * XZERO:1000 (Registers 0-999 are set to 0) - * VAL:2,1 (1 register set to value 2, that is register 1000) - * ZERO:19 (Registers 1001-1019 set to 0) - * VAL:3,2 (2 registers set to value 3, that is registers 1020,1021) - * XZERO:15362 (Registers 1022-16383 set to 0) - * - * In the example the sparse representation used just 7 bytes instead - * of 12k in order to represent the HLL registers. In general for low - * cardinality there is a big win in terms of space efficiency, traded - * with CPU time since the sparse representation is slower to access: - * - * The following table shows average cardinality vs bytes used, 100 - * samples per cardinality (when the set was not representable because - * of registers with too big value, the dense representation size was used - * as a sample). 
- * - * 100 267 - * 200 485 - * 300 678 - * 400 859 - * 500 1033 - * 600 1205 - * 700 1375 - * 800 1544 - * 900 1713 - * 1000 1882 - * 2000 3480 - * 3000 4879 - * 4000 6089 - * 5000 7138 - * 6000 8042 - * 7000 8823 - * 8000 9500 - * 9000 10088 - * 10000 10591 - * - * The dense representation uses 12288 bytes, so there is a big win up to - * a cardinality of ~2000-3000. For bigger cardinalities the constant times - * involved in updating the sparse representation is not justified by the - * memory savings. The exact maximum length of the sparse representation - * when this implementation switches to the dense representation is - * configured via the define server.hll_sparse_max_bytes. - */ - -struct hllhdr { - char magic[4]; /* "HYLL" */ - uint8_t encoding; /* HLL_DENSE or HLL_SPARSE. */ - uint8_t notused[3]; /* Reserved for future use, must be zero. */ - uint8_t card[8]; /* Cached cardinality, little endian. */ - uint8_t registers[1]; /* Data bytes. */ -}; - -/* The cached cardinality MSB is used to signal validity of the cached value. */ -#define HLL_INVALIDATE_CACHE(hdr) (hdr)->card[7] |= (1<<7) -#define HLL_VALID_CACHE(hdr) (((hdr)->card[7] & (1<<7)) == 0) - -#define HLL_P 12 /* The greater is P, the smaller the error. */ -#define HLL_Q (64-HLL_P) /* The number of bits of the hash value used for - determining the number of leading zeros. */ -#define HLL_REGISTERS (1< 6 - * - * Right shift b0 of 'fb' bits. - * - * +--------+ - * |11000000| <- Initial value of b0 - * |00000011| <- After right shift of 6 pos. - * +--------+ - * - * Left shift b1 of bits 8-fb bits (2 bits) - * - * +--------+ - * |22221111| <- Initial value of b1 - * |22111100| <- After left shift of 2 bits. 
- * +--------+ - * - * OR the two bits, and finally AND with 111111 (63 in decimal) to - * clean the higher order bits we are not interested in: - * - * +--------+ - * |00000011| <- b0 right shifted - * |22111100| <- b1 left shifted - * |22111111| <- b0 OR b1 - * | 111111| <- (b0 OR b1) AND 63, our value. - * +--------+ - * - * We can try with a different example, like pos = 0. In this case - * the 6-bit counter is actually contained in a single byte. - * - * b0 = 6 * pos / 8 = 0 - * - * +--------+ - * |11000000| <- Our byte at b0 - * +--------+ - * - * fb = 6 * pos % 8 = 0 - * - * So we right shift of 0 bits (no shift in practice) and - * left shift the next byte of 8 bits, even if we don't use it, - * but this has the effect of clearing the bits so the result - * will not be affacted after the OR. - * - * ------------------------------------------------------------------------- - * - * Setting the register is a bit more complex, let's assume that 'val' - * is the value we want to set, already in the right range. - * - * We need two steps, in one we need to clear the bits, and in the other - * we need to bitwise-OR the new bits. - * - * Let's try with 'pos' = 1, so our first byte at 'b' is 0, - * - * "fb" is 6 in this case. - * - * +--------+ - * |11000000| <- Our byte at b0 - * +--------+ - * - * To create a AND-mask to clear the bits about this position, we just - * initialize the mask with the value 63, left shift it of "fs" bits, - * and finally invert the result. - * - * +--------+ - * |00111111| <- "mask" starts at 63 - * |11000000| <- "mask" after left shift of "ls" bits. - * |00111111| <- "mask" after invert. - * +--------+ - * - * Now we can bitwise-AND the byte at "b" with the mask, and bitwise-OR - * it with "val" left-shifted of "ls" bits to set the new bits. 
- * - * Now let's focus on the next byte b1: - * - * +--------+ - * |22221111| <- Initial value of b1 - * +--------+ - * - * To build the AND mask we start again with the 63 value, right shift - * it by 8-fb bits, and invert it. - * - * +--------+ - * |00111111| <- "mask" set at 2&6-1 - * |00001111| <- "mask" after the right shift by 8-fb = 2 bits - * |11110000| <- "mask" after bitwise not. - * +--------+ - * - * Now we can mask it with b+1 to clear the old bits, and bitwise-OR - * with "val" left-shifted by "rs" bits to set the new value. - */ - -/* Note: if we access the last counter, we will also access the b+1 byte - * that is out of the array, but sds strings always have an implicit null - * term, so the byte exists, and we can skip the conditional (or the need - * to allocate 1 byte more explicitly). */ - -/* Store the value of the register at position 'regnum' into variable 'target'. - * 'p' is an array of unsigned bytes. */ -#define HLL_DENSE_GET_REGISTER(target,p,regnum) do { \ - uint8_t *_p = (uint8_t*) p; \ - unsigned long _byte = regnum*HLL_BITS/8; \ - unsigned long _fb = regnum*HLL_BITS&7; \ - unsigned long _fb8 = 8 - _fb; \ - unsigned long b0 = _p[_byte]; \ - unsigned long b1 = _p[_byte+1]; \ - target = ((b0 >> _fb) | (b1 << _fb8)) & HLL_REGISTER_MAX; \ -} while(0) - -/* Set the value of the register at position 'regnum' to 'val'. - * 'p' is an array of unsigned bytes. */ -#define HLL_DENSE_SET_REGISTER(p,regnum,val) do { \ - uint8_t *_p = (uint8_t*) p; \ - unsigned long _byte = regnum*HLL_BITS/8; \ - unsigned long _fb = regnum*HLL_BITS&7; \ - unsigned long _fb8 = 8 - _fb; \ - unsigned long _v = val; \ - _p[_byte] &= ~(HLL_REGISTER_MAX << _fb); \ - _p[_byte] |= _v << _fb; \ - _p[_byte+1] &= ~(HLL_REGISTER_MAX >> _fb8); \ - _p[_byte+1] |= _v >> _fb8; \ -} while(0) - -/* Macros to access the sparse representation. - * The macros parameter is expected to be an uint8_t pointer. 
*/ -#define HLL_SPARSE_XZERO_BIT 0x40 /* 01xxxxxx */ -#define HLL_SPARSE_VAL_BIT 0x80 /* 1vvvvvxx */ -#define HLL_SPARSE_IS_ZERO(p) (((*(p)) & 0xc0) == 0) /* 00xxxxxx */ -#define HLL_SPARSE_IS_XZERO(p) (((*(p)) & 0xc0) == HLL_SPARSE_XZERO_BIT) -#define HLL_SPARSE_IS_VAL(p) ((*(p)) & HLL_SPARSE_VAL_BIT) -#define HLL_SPARSE_ZERO_LEN(p) (((*(p)) & 0x3f)+1) -#define HLL_SPARSE_XZERO_LEN(p) (((((*(p)) & 0x3f) << 8) | (*((p)+1)))+1) -#define HLL_SPARSE_VAL_VALUE(p) ((((*(p)) >> 2) & 0x1f)+1) -#define HLL_SPARSE_VAL_LEN(p) (((*(p)) & 0x3)+1) -#define HLL_SPARSE_VAL_MAX_VALUE 32 -#define HLL_SPARSE_VAL_MAX_LEN 4 -#define HLL_SPARSE_ZERO_MAX_LEN 64 -#define HLL_SPARSE_XZERO_MAX_LEN 16384 -#define HLL_SPARSE_VAL_SET(p,val,len) do { \ - *(p) = (((val)-1)<<2|((len)-1))|HLL_SPARSE_VAL_BIT; \ -} while(0) -#define HLL_SPARSE_ZERO_SET(p,len) do { \ - *(p) = (len)-1; \ -} while(0) -#define HLL_SPARSE_XZERO_SET(p,len) do { \ - int _l = (len)-1; \ - *(p) = (_l>>8) | HLL_SPARSE_XZERO_BIT; \ - *((p)+1) = (_l&0xff); \ -} while(0) -#define HLL_ALPHA_INF 0.721347520444481703680 /* constant for 0.5/ln(2) */ - -/* ========================= HyperLogLog algorithm ========================= */ - -/* Our hash function is MurmurHash2, 64 bit version. - * It was modified for Redis in order to provide the same result in - * big and little endian archs (endian neutral). 
*/ -uint64_t MurmurHash64A (const void * key, int len, unsigned int seed) { - const uint64_t m = 0xc6a4a7935bd1e995; - const int r = 47; - uint64_t h = seed ^ (len * m); - const uint8_t *data = (const uint8_t *)key; - const uint8_t *end = data + (len-(len&7)); - - while(data != end) { - uint64_t k; - -#if (BYTE_ORDER == LITTLE_ENDIAN) - #ifdef USE_ALIGNED_ACCESS - memcpy(&k,data,sizeof(uint64_t)); - #else - k = *((uint64_t*)data); - #endif -#else - k = (uint64_t) data[0]; - k |= (uint64_t) data[1] << 8; - k |= (uint64_t) data[2] << 16; - k |= (uint64_t) data[3] << 24; - k |= (uint64_t) data[4] << 32; - k |= (uint64_t) data[5] << 40; - k |= (uint64_t) data[6] << 48; - k |= (uint64_t) data[7] << 56; -#endif - - k *= m; - k ^= k >> r; - k *= m; - h ^= k; - h *= m; - data += 8; - } - - switch(len & 7) { - case 7: h ^= (uint64_t)data[6] << 48; /* fall-thru */ - case 6: h ^= (uint64_t)data[5] << 40; /* fall-thru */ - case 5: h ^= (uint64_t)data[4] << 32; /* fall-thru */ - case 4: h ^= (uint64_t)data[3] << 24; /* fall-thru */ - case 3: h ^= (uint64_t)data[2] << 16; /* fall-thru */ - case 2: h ^= (uint64_t)data[1] << 8; /* fall-thru */ - case 1: h ^= (uint64_t)data[0]; - h *= m; /* fall-thru */ - }; - - h ^= h >> r; - h *= m; - h ^= h >> r; - return h; -} - -/* Given a string element to add to the HyperLogLog, returns the length - * of the pattern 000..1 of the element hash. As a side effect 'regp' is - * set to the register index this element hashes to. */ -int hllPatLen(unsigned char *ele, size_t elesize, long *regp) { - uint64_t hash, bit, index; - int count; - - /* Count the number of zeroes starting from bit HLL_REGISTERS - * (that is a power of two corresponding to the first bit we don't use - * as index). The max run can be 64-P+1 = Q+1 bits. 
- * - * Note that the final "1" ending the sequence of zeroes must be - * included in the count, so if we find "001" the count is 3, and - * the smallest count possible is no zeroes at all, just a 1 bit - * at the first position, that is a count of 1. - * - * This may sound like inefficient, but actually in the average case - * there are high probabilities to find a 1 after a few iterations. */ - hash = MurmurHash64A(ele,elesize,0xadc83b19ULL); - index = hash & HLL_P_MASK; /* Register index. */ - hash >>= HLL_P; /* Remove bits used to address the register. */ - hash |= ((uint64_t)1< oldcount) { - HLL_DENSE_SET_REGISTER(registers,index,count); - return 1; - } else { - return 0; - } -} - -/* "Add" the element in the dense hyperloglog data structure. - * Actually nothing is added, but the max 0 pattern counter of the subset - * the element belongs to is incremented if needed. - * - * This is just a wrapper to hllDenseSet(), performing the hashing of the - * element in order to retrieve the index and zero-run count. */ -int hllDenseAdd(uint8_t *registers, unsigned char *ele, size_t elesize) { - long index; - uint8_t count = hllPatLen(ele,elesize,&index); - /* Update the register if this element produced a longer run of zeroes. */ - return hllDenseSet(registers,index,count); -} - -/* Compute the register histogram in the dense representation. */ -void hllDenseRegHisto(uint8_t *registers, int* reghisto) { - int j; - - /* Redis default is to use 16384 registers 6 bits each. The code works - * with other values by modifying the defines, but for our target value - * we take a faster path with unrolled loops. */ - if (HLL_REGISTERS == 16384 && HLL_BITS == 6) { - uint8_t *r = registers; - unsigned long r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, - r10, r11, r12, r13, r14, r15; - for (j = 0; j < 1024; j++) { - /* Handle 16 registers per iteration. 
*/ - r0 = r[0] & 63; - r1 = (r[0] >> 6 | r[1] << 2) & 63; - r2 = (r[1] >> 4 | r[2] << 4) & 63; - r3 = (r[2] >> 2) & 63; - r4 = r[3] & 63; - r5 = (r[3] >> 6 | r[4] << 2) & 63; - r6 = (r[4] >> 4 | r[5] << 4) & 63; - r7 = (r[5] >> 2) & 63; - r8 = r[6] & 63; - r9 = (r[6] >> 6 | r[7] << 2) & 63; - r10 = (r[7] >> 4 | r[8] << 4) & 63; - r11 = (r[8] >> 2) & 63; - r12 = r[9] & 63; - r13 = (r[9] >> 6 | r[10] << 2) & 63; - r14 = (r[10] >> 4 | r[11] << 4) & 63; - r15 = (r[11] >> 2) & 63; - - reghisto[r0]++; - reghisto[r1]++; - reghisto[r2]++; - reghisto[r3]++; - reghisto[r4]++; - reghisto[r5]++; - reghisto[r6]++; - reghisto[r7]++; - reghisto[r8]++; - reghisto[r9]++; - reghisto[r10]++; - reghisto[r11]++; - reghisto[r12]++; - reghisto[r13]++; - reghisto[r14]++; - reghisto[r15]++; - - r += 12; - } - } else { - for(j = 0; j < HLL_REGISTERS; j++) { - unsigned long reg; - HLL_DENSE_GET_REGISTER(reg,registers,j); - reghisto[reg]++; - } - } -} - -/* ================== Sparse representation implementation ================= */ - -/* Convert the HLL with sparse representation given as input in its dense - * representation. Both representations are represented by SDS strings, and - * the input representation is freed as a side effect. - * - * The function returns C_OK if the sparse representation was valid, - * otherwise C_ERR is returned if the representation was corrupted. */ -int hllSparseToDense(robj *o) { - sds sparse = (sds) o->ptr, dense; - struct hllhdr *hdr, *oldhdr = (struct hllhdr*)sparse; - int idx = 0, runlen, regval; - uint8_t *p = (uint8_t*)sparse, *end = p+sdslen(sparse); - - /* If the representation is already the right one return ASAP. */ - hdr = (struct hllhdr*) sparse; - if (hdr->encoding == HLL_DENSE) return HLL_C_OK; - - /* Create a string of the right size filled with zero bytes. - * Note that the cached cardinality is set to 0 as a side effect - * that is exactly the cardinality of an empty HLL. 
*/ - dense = sdsnewlen(NULL,HLL_DENSE_SIZE); - hdr = (struct hllhdr*) dense; - *hdr = *oldhdr; /* This will copy the magic and cached cardinality. */ - hdr->encoding = HLL_DENSE; - - /* Now read the sparse representation and set non-zero registers - * accordingly. */ - p += HLL_HDR_SIZE; - while(p < end) { - if (HLL_SPARSE_IS_ZERO(p)) { - runlen = HLL_SPARSE_ZERO_LEN(p); - idx += runlen; - p++; - } else if (HLL_SPARSE_IS_XZERO(p)) { - runlen = HLL_SPARSE_XZERO_LEN(p); - idx += runlen; - p += 2; - } else { - runlen = HLL_SPARSE_VAL_LEN(p); - regval = HLL_SPARSE_VAL_VALUE(p); - while(runlen--) { - HLL_DENSE_SET_REGISTER(hdr->registers + 1,idx,regval); - idx++; - } - p++; - } - } - - /* If the sparse representation was valid, we expect to find idx - * set to HLL_REGISTERS. */ - if (idx != HLL_REGISTERS) { - sdsfree(dense); - return HLL_C_ERR; - } - - /* Free the old representation and set the new one. */ - sdsfree((sds) o->ptr); - o->ptr = dense; - return HLL_C_OK; -} - -/* Low level function to set the sparse HLL register at 'index' to the - * specified value if the current value is smaller than 'count'. - * - * The object 'o' is the String object holding the HLL. The function requires - * a reference to the object in order to be able to enlarge the string if - * needed. - * - * On success, the function returns 1 if the cardinality changed, or 0 - * if the register for this element was not updated. - * On error (if the representation is invalid) -1 is returned. - * - * As a side effect the function may promote the HLL representation from - * sparse to dense: this happens when a register requires to be set to a value - * not representable with the sparse representation, or when the resulting - * size would be greater than server.hll_sparse_max_bytes. 
*/ -int hllSparseSet(robj *o, long index, uint8_t count) { - struct hllhdr *hdr; - uint8_t oldcount, *sparse, *end, *p, *prev, *next; - long first, span; - long is_zero = 0, is_xzero = 0, is_val = 0, runlen = 0; - uint8_t seq[5], *n; - int last; - int len; - int seqlen; - int oldlen; - int deltalen; - - /* If the count is too big to be representable by the sparse representation - * switch to dense representation. */ - if (count > HLL_SPARSE_VAL_MAX_VALUE) goto promote; - - /* When updating a sparse representation, sometimes we may need to - * enlarge the buffer for up to 3 bytes in the worst case (XZERO split - * into XZERO-VAL-XZERO). Make sure there is enough space right now - * so that the pointers we take during the execution of the function - * will be valid all the time. */ - o->ptr = (sds) sdsMakeRoomFor((sds) o->ptr,3); - - /* Step 1: we need to locate the opcode we need to modify to check - * if a value update is actually needed. */ - sparse = p = ((uint8_t*)o->ptr) + HLL_HDR_SIZE; - end = p + sdslen((sds) o->ptr) - HLL_HDR_SIZE; - - first = 0; - prev = NULL; /* Points to previous opcode at the end of the loop. */ - next = NULL; /* Points to the next opcode at the end of the loop. */ - span = 0; - while(p < end) { - long oplen; - - /* Set span to the number of registers covered by this opcode. - * - * This is the most performance critical loop of the sparse - * representation. Sorting the conditionals from the most to the - * least frequent opcode in many-bytes sparse HLLs is faster. */ - oplen = 1; - if (HLL_SPARSE_IS_ZERO(p)) { - span = HLL_SPARSE_ZERO_LEN(p); - } else if (HLL_SPARSE_IS_VAL(p)) { - span = HLL_SPARSE_VAL_LEN(p); - } else { /* XZERO. */ - span = HLL_SPARSE_XZERO_LEN(p); - oplen = 2; - } - /* Break if this opcode covers the register as 'index'. */ - if (index <= first+span-1) break; - prev = p; - p += oplen; - first += span; - } - if (span == 0) return -1; /* Invalid format. */ - - next = HLL_SPARSE_IS_XZERO(p) ? 
p+2 : p+1; - if (next >= end) next = NULL; - - /* Cache current opcode type to avoid using the macro again and - * again for something that will not change. - * Also cache the run-length of the opcode. */ - if (HLL_SPARSE_IS_ZERO(p)) { - is_zero = 1; - runlen = HLL_SPARSE_ZERO_LEN(p); - } else if (HLL_SPARSE_IS_XZERO(p)) { - is_xzero = 1; - runlen = HLL_SPARSE_XZERO_LEN(p); - } else { - is_val = 1; - runlen = HLL_SPARSE_VAL_LEN(p); - } - - /* Step 2: After the loop: - * - * 'first' stores to the index of the first register covered - * by the current opcode, which is pointed by 'p'. - * - * 'next' ad 'prev' store respectively the next and previous opcode, - * or NULL if the opcode at 'p' is respectively the last or first. - * - * 'span' is set to the number of registers covered by the current - * opcode. - * - * There are different cases in order to update the data structure - * in place without generating it from scratch: - * - * A) If it is a VAL opcode already set to a value >= our 'count' - * no update is needed, regardless of the VAL run-length field. - * In this case PFADD returns 0 since no changes are performed. - * - * B) If it is a VAL opcode with len = 1 (representing only our - * register) and the value is less than 'count', we just update it - * since this is a trivial case. */ - if (is_val) { - oldcount = HLL_SPARSE_VAL_VALUE(p); - /* Case A. */ - if (oldcount >= count) return 0; - - /* Case B. */ - if (runlen == 1) { - HLL_SPARSE_VAL_SET(p,count,1); - goto updated; - } - } - - /* C) Another trivial to handle case is a ZERO opcode with a len of 1. - * We can just replace it with a VAL opcode with our value and len of 1. */ - if (is_zero && runlen == 1) { - HLL_SPARSE_VAL_SET(p,count,1); - goto updated; - } - - /* D) General case. - * - * The other cases are more complex: our register requires to be updated - * and is either currently represented by a VAL opcode with len > 1, - * by a ZERO opcode with len > 1, or by an XZERO opcode. 
- * - * In those cases the original opcode must be split into multiple - * opcodes. The worst case is an XZERO split in the middle resuling into - * XZERO - VAL - XZERO, so the resulting sequence max length is - * 5 bytes. - * - * We perform the split writing the new sequence into the 'new' buffer - * with 'newlen' as length. Later the new sequence is inserted in place - * of the old one, possibly moving what is on the right a few bytes - * if the new sequence is longer than the older one. */ - n = seq; - last = first+span-1; /* Last register covered by the sequence. */ - - if (is_zero || is_xzero) { - /* Handle splitting of ZERO / XZERO. */ - if (index != first) { - len = index-first; - if (len > HLL_SPARSE_ZERO_MAX_LEN) { - HLL_SPARSE_XZERO_SET(n,len); - n += 2; - } else { - HLL_SPARSE_ZERO_SET(n,len); - n++; - } - } - HLL_SPARSE_VAL_SET(n,count,1); - n++; - if (index != last) { - len = last-index; - if (len > HLL_SPARSE_ZERO_MAX_LEN) { - HLL_SPARSE_XZERO_SET(n,len); - n += 2; - } else { - HLL_SPARSE_ZERO_SET(n,len); - n++; - } - } - } else { - /* Handle splitting of VAL. */ - int curval = HLL_SPARSE_VAL_VALUE(p); - - if (index != first) { - len = index-first; - HLL_SPARSE_VAL_SET(n,curval,len); - n++; - } - HLL_SPARSE_VAL_SET(n,count,1); - n++; - if (index != last) { - len = last-index; - HLL_SPARSE_VAL_SET(n,curval,len); - n++; - } - } - - /* Step 3: substitute the new sequence with the old one. - * - * Note that we already allocated space on the sds string - * calling sdsMakeRoomFor(). */ - seqlen = n-seq; - oldlen = is_xzero ? 2 : 1; - deltalen = seqlen-oldlen; - - if (deltalen > 0 && - sdslen((sds) o->ptr)+deltalen > HLL_SPARSE_MAX_BYTES) goto promote; - if (deltalen && next) memmove(next+deltalen,next,end-next); - sdsIncrLen((sds) o->ptr,deltalen); - memcpy(p,seq,seqlen); - end += deltalen; - -updated: { - /* Step 4: Merge adjacent values if possible. 
- * - * The representation was updated, however the resulting representation - * may not be optimal: adjacent VAL opcodes can sometimes be merged into - * a single one. */ - p = prev ? prev : sparse; - int scanlen = 5; /* Scan up to 5 upcodes starting from prev. */ - while (p < end && scanlen--) { - if (HLL_SPARSE_IS_XZERO(p)) { - p += 2; - continue; - } else if (HLL_SPARSE_IS_ZERO(p)) { - p++; - continue; - } - /* We need two adjacent VAL opcodes to try a merge, having - * the same value, and a len that fits the VAL opcode max len. */ - if (p+1 < end && HLL_SPARSE_IS_VAL(p+1)) { - int v1 = HLL_SPARSE_VAL_VALUE(p); - int v2 = HLL_SPARSE_VAL_VALUE(p+1); - if (v1 == v2) { - int len = HLL_SPARSE_VAL_LEN(p)+HLL_SPARSE_VAL_LEN(p+1); - if (len <= HLL_SPARSE_VAL_MAX_LEN) { - HLL_SPARSE_VAL_SET(p+1,v1,len); - memmove(p,p+1,end-p); - sdsIncrLen((sds) o->ptr,-1); - end--; - /* After a merge we reiterate without incrementing 'p' - * in order to try to merge the just merged value with - * a value on its right. */ - continue; - } - } - } - p++; - } - - /* Invalidate the cached cardinality. */ - hdr = (struct hllhdr *) o->ptr; - HLL_INVALIDATE_CACHE(hdr); - return 1; -} -promote: /* Promote to dense representation. */ - if (hllSparseToDense(o) == HLL_C_ERR) return -1; /* Corrupted HLL. */ - hdr = (struct hllhdr *) o->ptr; - - /* We need to call hllDenseAdd() to perform the operation after the - * conversion. However the result must be 1, since if we need to - * convert from sparse to dense a register requires to be updated. - * - * Note that this in turn means that PFADD will make sure the command - * is propagated to slaves / AOF, so if there is a sparse -> dense - * conversion, it will be performed in all the slaves as well. */ - int dense_retval = hllDenseSet(hdr->registers + 1,index,count); - assert(dense_retval == 1); - return dense_retval; -} - -/* "Add" the element in the sparse hyperloglog data structure. 
- * Actually nothing is added, but the max 0 pattern counter of the subset - * the element belongs to is incremented if needed. - * - * This function is actually a wrapper for hllSparseSet(), it only performs - * the hashshing of the elmenet to obtain the index and zeros run length. */ -int hllSparseAdd(robj *o, unsigned char *ele, size_t elesize) { - long index; - uint8_t count = hllPatLen(ele,elesize,&index); - /* Update the register if this element produced a longer run of zeroes. */ - return hllSparseSet(o,index,count); -} - -/* Compute the register histogram in the sparse representation. */ -void hllSparseRegHisto(uint8_t *sparse, int sparselen, int *invalid, int* reghisto) { - int idx = 0, runlen, regval; - uint8_t *end = sparse+sparselen, *p = sparse; - - while(p < end) { - if (HLL_SPARSE_IS_ZERO(p)) { - runlen = HLL_SPARSE_ZERO_LEN(p); - idx += runlen; - reghisto[0] += runlen; - p++; - } else if (HLL_SPARSE_IS_XZERO(p)) { - runlen = HLL_SPARSE_XZERO_LEN(p); - idx += runlen; - reghisto[0] += runlen; - p += 2; - } else { - runlen = HLL_SPARSE_VAL_LEN(p); - regval = HLL_SPARSE_VAL_VALUE(p); - idx += runlen; - reghisto[regval] += runlen; - p++; - } - } - if (idx != HLL_REGISTERS && invalid) *invalid = 1; -} - -/* ========================= HyperLogLog Count ============================== - * This is the core of the algorithm where the approximated count is computed. - * The function uses the lower level hllDenseRegHisto() and hllSparseRegHisto() - * functions as helpers to compute histogram of register values part of the - * computation, which is representation-specific, while all the rest is common. */ - -/* Implements the register histogram calculation for uint8_t data type - * which is only used internally as speedup for PFCOUNT with multiple keys. 
*/ -void hllRawRegHisto(uint8_t *registers, int* reghisto) { - uint64_t *word = (uint64_t*) registers; - uint8_t *bytes; - int j; - - for (j = 0; j < HLL_REGISTERS/8; j++) { - if (*word == 0) { - reghisto[0] += 8; - } else { - bytes = (uint8_t*) word; - reghisto[bytes[0]]++; - reghisto[bytes[1]]++; - reghisto[bytes[2]]++; - reghisto[bytes[3]]++; - reghisto[bytes[4]]++; - reghisto[bytes[5]]++; - reghisto[bytes[6]]++; - reghisto[bytes[7]]++; - } - word++; - } -} - -// somehow this is missing on some platforms -#ifndef INFINITY -// from math.h -#define INFINITY 1e50f -#endif - - -/* Helper function sigma as defined in - * "New cardinality estimation algorithms for HyperLogLog sketches" - * Otmar Ertl, arXiv:1702.01284 */ -double hllSigma(double x) { - if (x == 1.) return INFINITY; - double zPrime; - double y = 1; - double z = x; - do { - x *= x; - zPrime = z; - z += x * y; - y += y; - } while(zPrime != z); - return z; -} - -/* Helper function tau as defined in - * "New cardinality estimation algorithms for HyperLogLog sketches" - * Otmar Ertl, arXiv:1702.01284 */ -double hllTau(double x) { - if (x == 0. || x == 1.) return 0.; - double zPrime; - double y = 1.0; - double z = 1 - x; - do { - x = sqrt(x); - zPrime = z; - y *= 0.5; - z -= pow(1 - x, 2)*y; - } while(zPrime != z); - return z / 3; -} - -/* Return the approximated cardinality of the set based on the harmonic - * mean of the registers values. 'hdr' points to the start of the SDS - * representing the String object holding the HLL representation. - * - * If the sparse representation of the HLL object is not valid, the integer - * pointed by 'invalid' is set to non-zero, otherwise it is left untouched. - * - * hllCount() supports a special internal-only encoding of HLL_RAW, that - * is, hdr->registers will point to an uint8_t array of HLL_REGISTERS element. - * This is useful in order to speedup PFCOUNT when called against multiple - * keys (no need to work with 6-bit integers encoding). 
*/ -uint64_t hllCount(struct hllhdr *hdr, int *invalid) { - double m = HLL_REGISTERS; - double E; - int j; - int reghisto[HLL_Q+2] = {0}; - - /* Compute register histogram */ - if (hdr->encoding == HLL_DENSE) { - hllDenseRegHisto(hdr->registers + 1,reghisto); - } else if (hdr->encoding == HLL_SPARSE) { - hllSparseRegHisto(hdr->registers + 1, - sdslen((sds)hdr)-HLL_HDR_SIZE,invalid,reghisto); - } else if (hdr->encoding == HLL_RAW) { - hllRawRegHisto(hdr->registers + 1,reghisto); - } else { - *invalid = 1; - return 0; - //serverPanic("Unknown HyperLogLog encoding in hllCount()"); - } - - /* Estimate cardinality form register histogram. See: - * "New cardinality estimation algorithms for HyperLogLog sketches" - * Otmar Ertl, arXiv:1702.01284 */ - double z = m * hllTau((m-reghisto[HLL_Q+1])/(double)m); - for (j = HLL_Q; j >= 1; --j) { - z += reghisto[j]; - z *= 0.5; - } - z += m * hllSigma(reghisto[0]/(double)m); - E = llroundl(HLL_ALPHA_INF*m*m/z); - - return (uint64_t) E; -} - -/* Call hllDenseAdd() or hllSparseAdd() according to the HLL encoding. */ -int hll_add(robj *o, unsigned char *ele, size_t elesize) { - struct hllhdr *hdr = (struct hllhdr *) o->ptr; - switch(hdr->encoding) { - case HLL_DENSE: return hllDenseAdd(hdr->registers + 1,ele,elesize); - case HLL_SPARSE: return hllSparseAdd(o,ele,elesize); - default: return -1; /* Invalid representation. */ - } -} - -/* Merge by computing MAX(registers[i],hll[i]) the HyperLogLog 'hll' - * with an array of uint8_t HLL_REGISTERS registers pointed by 'max'. - * - * The hll object must be already validated via isHLLObjectOrReply() - * or in some other way. - * - * If the HyperLogLog is sparse and is found to be invalid, C_ERR - * is returned, otherwise the function always succeeds. 
*/ -int hllMerge(uint8_t *max, robj *hll) { - struct hllhdr *hdr = (struct hllhdr *) hll->ptr; - int i; - - if (hdr->encoding == HLL_DENSE) { - uint8_t val; - - for (i = 0; i < HLL_REGISTERS; i++) { - HLL_DENSE_GET_REGISTER(val,hdr->registers + 1,i); - if (val > max[i]) max[i] = val; - } - } else { - uint8_t *p = (uint8_t *) hll->ptr, *end = p + sdslen((sds) hll->ptr); - long runlen, regval; - - p += HLL_HDR_SIZE; - i = 0; - while(p < end) { - if (HLL_SPARSE_IS_ZERO(p)) { - runlen = HLL_SPARSE_ZERO_LEN(p); - i += runlen; - p++; - } else if (HLL_SPARSE_IS_XZERO(p)) { - runlen = HLL_SPARSE_XZERO_LEN(p); - i += runlen; - p += 2; - } else { - runlen = HLL_SPARSE_VAL_LEN(p); - regval = HLL_SPARSE_VAL_VALUE(p); - while(runlen--) { - if (regval > max[i]) max[i] = regval; - i++; - } - p++; - } - } - if (i != HLL_REGISTERS) return HLL_C_ERR; - } - return HLL_C_OK; -} - -/* ========================== robj creation ========================== */ -robj *createObject(void *ptr) { - robj *result = (robj*) malloc(sizeof(robj)); - result->ptr = ptr; - return result; -} - -void destroyObject(robj *obj) { - free(obj); -} - -/* ========================== HyperLogLog commands ========================== */ - -/* Create an HLL object. We always create the HLL using sparse encoding. - * This will be upgraded to the dense representation as needed. */ -robj *hll_create(void) { - robj *o; - struct hllhdr *hdr; - sds s; - uint8_t *p; - int sparselen = HLL_HDR_SIZE + - (((HLL_REGISTERS+(HLL_SPARSE_XZERO_MAX_LEN-1)) / - HLL_SPARSE_XZERO_MAX_LEN)*2); - int aux; - - /* Populate the sparse representation with as many XZERO opcodes as - * needed to represent all the registers. */ - aux = HLL_REGISTERS; - s = sdsnewlen(NULL,sparselen); - p = (uint8_t*)s + HLL_HDR_SIZE; - while(aux) { - int xzero = HLL_SPARSE_XZERO_MAX_LEN; - if (xzero > aux) xzero = aux; - HLL_SPARSE_XZERO_SET(p,xzero); - p += 2; - aux -= xzero; - } - assert((p-(uint8_t*)s) == sparselen); - - /* Create the actual object. 
*/ - o = createObject(s); - hdr = (struct hllhdr *) o->ptr; - memcpy(hdr->magic,"HYLL",4); - hdr->encoding = HLL_SPARSE; - return o; -} - -void hll_destroy(robj *obj) { - if (!obj) { - return; - } - sdsfree((sds) obj->ptr); - destroyObject(obj); -} - - - -int hll_count(robj *o, size_t *result) { - int invalid = 0; - *result = hllCount((struct hllhdr*) o->ptr, &invalid); - return invalid == 0 ? HLL_C_OK : HLL_C_ERR; -} - -robj *hll_merge(robj **hlls, size_t hll_count) { - uint8_t max[HLL_REGISTERS]; - struct hllhdr *hdr; - size_t j; - /* Use dense representation as target? */ - int use_dense = 0; - - /* Compute an HLL with M[i] = MAX(M[i]_j). - * We store the maximum into the max array of registers. We'll write - * it to the target variable later. */ - memset(max, 0, sizeof(max)); - for (j = 0; j < hll_count; j++) { - /* Check type and size. */ - robj *o = hlls[j]; - if (o == NULL) continue; /* Assume empty HLL for non existing var. */ - - /* If at least one involved HLL is dense, use the dense representation - * as target ASAP to save time and avoid the conversion step. */ - hdr = (struct hllhdr *) o->ptr; - if (hdr->encoding == HLL_DENSE) use_dense = 1; - - /* Merge with this HLL with our 'max' HHL by setting max[i] - * to MAX(max[i],hll[i]). */ - if (hllMerge(max, o) == HLL_C_ERR) { - return NULL; - } - } - - /* Create the destination key's value. */ - robj *result = hll_create(); - if (!result) { - return NULL; - } - - /* Convert the destination object to dense representation if at least - * one of the inputs was dense. */ - if (use_dense && hllSparseToDense(result) == HLL_C_ERR) { - hll_destroy(result); - return NULL; - } - - /* Write the resulting HLL to the destination HLL registers and - * invalidate the cached value. 
*/ - for (j = 0; j < HLL_REGISTERS; j++) { - if (max[j] == 0) continue; - hdr = (struct hllhdr *) result->ptr; - switch(hdr->encoding) { - case HLL_DENSE: hllDenseSet(hdr->registers + 1,j,max[j]); break; - case HLL_SPARSE: hllSparseSet(result,j,max[j]); break; - } - } - return result; -} - -uint64_t get_size() { - return HLL_DENSE_SIZE; -} - -} - -namespace duckdb { - -static inline int AddToLog(void *log, const uint64_t &index, const uint8_t &count) { - auto o = (duckdb_hll::robj *)log; - duckdb_hll::hllhdr *hdr = (duckdb_hll::hllhdr *)o->ptr; - D_ASSERT(hdr->encoding == HLL_DENSE); - return duckdb_hll::hllDenseSet(hdr->registers + 1, index, count); -} - -void AddToLogsInternal(UnifiedVectorFormat &vdata, idx_t count, uint64_t indices[], uint8_t counts[], void ***logs[], - const SelectionVector *log_sel) { - // 'logs' is an array of pointers to AggregateStates - // AggregateStates have a pointer to a HyperLogLog object - // HyperLogLog objects have a pointer to a 'robj', which we need - for (idx_t i = 0; i < count; i++) { - auto log = logs[log_sel->get_index(i)]; - if (log && vdata.validity.RowIsValid(vdata.sel->get_index(i))) { - AddToLog(**log, indices[i], counts[i]); - } - } -} - -void AddToSingleLogInternal(UnifiedVectorFormat &vdata, idx_t count, uint64_t indices[], uint8_t counts[], void *log) { - const auto o = (duckdb_hll::robj *)log; - duckdb_hll::hllhdr *hdr = (duckdb_hll::hllhdr *)o->ptr; - D_ASSERT(hdr->encoding == HLL_DENSE); - - const auto registers = hdr->registers + 1; - for (idx_t i = 0; i < count; i++) { - if (vdata.validity.RowIsValid(vdata.sel->get_index(i))) { - duckdb_hll::hllDenseSet(registers, indices[i], counts[i]); - } - } -} - -} // namespace duckdb - - -// LICENSE_CHANGE_END - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #1 -// See the end of this file for a list - -/* SDSLib 2.0 -- A C dynamic strings library - * - * Copyright (c) 2006-2015, Salvatore Sanfilippo - * 
Copyright (c) 2015, Oran Agra - * Copyright (c) 2015, Redis Labs, Inc - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. 
- */ - -#include -#include -#include -#include -#include -#include - - -namespace duckdb_hll { - -static inline int sdsHdrSize(char type) { - switch(type&SDS_TYPE_MASK) { - case SDS_TYPE_5: - return sizeof(struct sdshdr5); - case SDS_TYPE_8: - return sizeof(struct sdshdr8); - case SDS_TYPE_16: - return sizeof(struct sdshdr16); - case SDS_TYPE_32: - return sizeof(struct sdshdr32); - case SDS_TYPE_64: - return sizeof(struct sdshdr64); - } - return 0; -} - -static inline char sdsReqType(size_t string_size) { - if (string_size < 1<<5) - return SDS_TYPE_5; - if (string_size < 1<<8) - return SDS_TYPE_8; - if (string_size < 1<<16) - return SDS_TYPE_16; -#if (LONG_MAX == LLONG_MAX) - if (string_size < 1ll<<32) - return SDS_TYPE_32; - return SDS_TYPE_64; -#else - return SDS_TYPE_32; -#endif -} - -/* Create a new sds string with the content specified by the 'init' pointer - * and 'initlen'. - * If NULL is used for 'init' the string is initialized with zero bytes. - * If SDS_NOINIT is used, the buffer is left uninitialized; - * - * The string is always null-termined (all the sds strings are, always) so - * even if you create an sds string with: - * - * mystring = sdsnewlen("abc",3); - * - * You can print the string with printf() as there is an implicit \0 at the - * end of the string. However the string is binary safe and can contain - * \0 characters in the middle, as the length is stored in the sds header. */ -sds sdsnewlen(const void *init, size_t initlen) { - void *sh; - sds s; - char type = sdsReqType(initlen); - /* Empty strings are usually created in order to append. Use type 8 - * since type 5 is not good at this. */ - if (type == SDS_TYPE_5 && initlen == 0) type = SDS_TYPE_8; - int hdrlen = sdsHdrSize(type); - unsigned char *fp; /* flags pointer. 
*/ - - sh = malloc(hdrlen+initlen+1); - if (!init) - memset(sh, 0, hdrlen+initlen+1); - if (sh == NULL) return NULL; - s = (char*)sh+hdrlen; - fp = ((unsigned char*)s)-1; - switch(type) { - case SDS_TYPE_5: { - *fp = type | (initlen << SDS_TYPE_BITS); - break; - } - case SDS_TYPE_8: { - SDS_HDR_VAR(8,s); - sh->len = initlen; - sh->alloc = initlen; - *fp = type; - break; - } - case SDS_TYPE_16: { - SDS_HDR_VAR(16,s); - sh->len = initlen; - sh->alloc = initlen; - *fp = type; - break; - } - case SDS_TYPE_32: { - SDS_HDR_VAR(32,s); - sh->len = initlen; - sh->alloc = initlen; - *fp = type; - break; - } - case SDS_TYPE_64: { - SDS_HDR_VAR(64,s); - sh->len = initlen; - sh->alloc = initlen; - *fp = type; - break; - } - } - if (initlen && init) - memcpy(s, init, initlen); - s[initlen] = '\0'; - return s; -} - -/* Create an empty (zero length) sds string. Even in this case the string - * always has an implicit null term. */ -sds sdsempty(void) { - return sdsnewlen("",0); -} - -/* Create a new sds string starting from a null terminated C string. */ -sds sdsnew(const char *init) { - size_t initlen = (init == NULL) ? 0 : strlen(init); - return sdsnewlen(init, initlen); -} - -/* Duplicate an sds string. */ -sds sdsdup(const sds s) { - return sdsnewlen(s, sdslen(s)); -} - -/* Free an sds string. No operation is performed if 's' is NULL. */ -void sdsfree(sds s) { - if (s == NULL) return; - free((char*)s-sdsHdrSize(s[-1])); -} - -/* Set the sds string length to the length as obtained with strlen(), so - * considering as content only up to the first null term character. - * - * This function is useful when the sds string is hacked manually in some - * way, like in the following example: - * - * s = sdsnew("foobar"); - * s[2] = '\0'; - * sdsupdatelen(s); - * printf("%d\n", sdslen(s)); - * - * The output will be "2", but if we comment out the call to sdsupdatelen() - * the output will be "6" as the string was modified but the logical length - * remains 6 bytes. 
*/ -void sdsupdatelen(sds s) { - size_t reallen = strlen(s); - sdssetlen(s, reallen); -} - -/* Modify an sds string in-place to make it empty (zero length). - * However all the existing buffer is not discarded but set as free space - * so that next append operations will not require allocations up to the - * number of bytes previously available. */ -void sdsclear(sds s) { - sdssetlen(s, 0); - s[0] = '\0'; -} - -/* Enlarge the free space at the end of the sds string so that the caller - * is sure that after calling this function can overwrite up to addlen - * bytes after the end of the string, plus one more byte for nul term. - * - * Note: this does not change the *length* of the sds string as returned - * by sdslen(), but only the free buffer space we have. */ -sds sdsMakeRoomFor(sds s, size_t addlen) { - void *sh, *newsh; - size_t avail = sdsavail(s); - size_t len, newlen; - char type, oldtype = s[-1] & SDS_TYPE_MASK; - int hdrlen; - - /* Return ASAP if there is enough space left. */ - if (avail >= addlen) return s; - - len = sdslen(s); - sh = (char*)s-sdsHdrSize(oldtype); - newlen = (len+addlen); - if (newlen < SDS_MAX_PREALLOC) - newlen *= 2; - else - newlen += SDS_MAX_PREALLOC; - - type = sdsReqType(newlen); - - /* Don't use type 5: the user is appending to the string and type 5 is - * not able to remember empty space, so sdsMakeRoomFor() must be called - * at every appending operation. 
*/ - if (type == SDS_TYPE_5) type = SDS_TYPE_8; - - hdrlen = sdsHdrSize(type); - if (oldtype==type) { - newsh = realloc(sh, hdrlen+newlen+1); - if (newsh == NULL) return NULL; - s = (char*)newsh+hdrlen; - } else { - /* Since the header size changes, need to move the string forward, - * and can't use realloc */ - newsh = malloc(hdrlen+newlen+1); - if (newsh == NULL) return NULL; - memcpy((char*)newsh+hdrlen, s, len+1); - free(sh); - s = (char*)newsh+hdrlen; - s[-1] = type; - sdssetlen(s, len); - } - sdssetalloc(s, newlen); - return s; -} - -/* Reallocate the sds string so that it has no free space at the end. The - * contained string remains not altered, but next concatenation operations - * will require a reallocation. - * - * After the call, the passed sds string is no longer valid and all the - * references must be substituted with the new pointer returned by the call. */ -sds sdsRemoveFreeSpace(sds s) { - void *sh, *newsh; - char type, oldtype = s[-1] & SDS_TYPE_MASK; - int hdrlen, oldhdrlen = sdsHdrSize(oldtype); - size_t len = sdslen(s); - sh = (char*)s-oldhdrlen; - - /* Check what would be the minimum SDS header that is just good enough to - * fit this string. */ - type = sdsReqType(len); - hdrlen = sdsHdrSize(type); - - /* If the type is the same, or at least a large enough type is still - * required, we just realloc(), letting the allocator to do the copy - * only if really needed. Otherwise if the change is huge, we manually - * reallocate the string to use the different header type. 
*/ - if (oldtype==type || type > SDS_TYPE_8) { - newsh = realloc(sh, oldhdrlen+len+1); - if (newsh == NULL) return NULL; - s = (char*)newsh+oldhdrlen; - } else { - newsh = malloc(hdrlen+len+1); - if (newsh == NULL) return NULL; - memcpy((char*)newsh+hdrlen, s, len+1); - free(sh); - s = (char*)newsh+hdrlen; - s[-1] = type; - sdssetlen(s, len); - } - sdssetalloc(s, len); - return s; -} - -/* Return the total size of the allocation of the specified sds string, - * including: - * 1) The sds header before the pointer. - * 2) The string. - * 3) The free buffer at the end if any. - * 4) The implicit null term. - */ -size_t sdsAllocSize(sds s) { - size_t alloc = sdsalloc(s); - return sdsHdrSize(s[-1])+alloc+1; -} - -/* Return the pointer of the actual SDS allocation (normally SDS strings - * are referenced by the start of the string buffer). */ -void *sdsAllocPtr(sds s) { - return (void*) (s-sdsHdrSize(s[-1])); -} - -/* Increment the sds length and decrements the left free space at the - * end of the string according to 'incr'. Also set the null term - * in the new end of the string. - * - * This function is used in order to fix the string length after the - * user calls sdsMakeRoomFor(), writes something after the end of - * the current string, and finally needs to set the new length. - * - * Note: it is possible to use a negative increment in order to - * right-trim the string. - * - * Usage example: - * - * Using sdsIncrLen() and sdsMakeRoomFor() it is possible to mount the - * following schema, to cat bytes coming from the kernel to the end of an - * sds string without copying into an intermediate buffer: - * - * oldlen = sdslen(s); - * s = sdsMakeRoomFor(s, BUFFER_SIZE); - * nread = read(fd, s+oldlen, BUFFER_SIZE); - * ... check for nread <= 0 and handle it ... 
- * sdsIncrLen(s, nread); - */ -void sdsIncrLen(sds s, ssize_t incr) { - unsigned char flags = s[-1]; - size_t len; - switch(flags&SDS_TYPE_MASK) { - case SDS_TYPE_5: { - unsigned char *fp = ((unsigned char*)s)-1; - unsigned char oldlen = SDS_TYPE_5_LEN(flags); - assert((incr > 0 && oldlen+incr < 32) || (incr < 0 && oldlen >= (unsigned int)(-incr))); - *fp = SDS_TYPE_5 | ((oldlen+incr) << SDS_TYPE_BITS); - len = oldlen+incr; - break; - } - case SDS_TYPE_8: { - SDS_HDR_VAR(8,s); - assert((incr >= 0 && sh->alloc-sh->len >= incr) || (incr < 0 && sh->len >= (unsigned int)(-incr))); - len = (sh->len += incr); - break; - } - case SDS_TYPE_16: { - SDS_HDR_VAR(16,s); - assert((incr >= 0 && sh->alloc-sh->len >= incr) || (incr < 0 && sh->len >= (unsigned int)(-incr))); - len = (sh->len += incr); - break; - } - case SDS_TYPE_32: { - SDS_HDR_VAR(32,s); - assert((incr >= 0 && sh->alloc-sh->len >= (unsigned int)incr) || (incr < 0 && sh->len >= (unsigned int)(-incr))); - len = (sh->len += incr); - break; - } - case SDS_TYPE_64: { - SDS_HDR_VAR(64,s); - assert((incr >= 0 && sh->alloc-sh->len >= (uint64_t)incr) || (incr < 0 && sh->len >= (uint64_t)(-incr))); - len = (sh->len += incr); - break; - } - default: len = 0; /* Just to avoid compilation warnings. */ - } - s[len] = '\0'; -} - -/* Grow the sds to have the specified length. Bytes that were not part of - * the original length of the sds will be set to zero. - * - * if the specified length is smaller than the current length, no operation - * is performed. */ -sds sdsgrowzero(sds s, size_t len) { - size_t curlen = sdslen(s); - - if (len <= curlen) return s; - s = sdsMakeRoomFor(s,len-curlen); - if (s == NULL) return NULL; - - /* Make sure added region doesn't contain garbage */ - memset(s+curlen,0,(len-curlen+1)); /* also set trailing \0 byte */ - sdssetlen(s, len); - return s; -} - -/* Append the specified binary-safe string pointed by 't' of 'len' bytes to the - * end of the specified sds string 's'. 
- * - * After the call, the passed sds string is no longer valid and all the - * references must be substituted with the new pointer returned by the call. */ -sds sdscatlen(sds s, const void *t, size_t len) { - size_t curlen = sdslen(s); - - s = sdsMakeRoomFor(s,len); - if (s == NULL) return NULL; - memcpy(s+curlen, t, len); - sdssetlen(s, curlen+len); - s[curlen+len] = '\0'; - return s; -} - -/* Append the specified null termianted C string to the sds string 's'. - * - * After the call, the passed sds string is no longer valid and all the - * references must be substituted with the new pointer returned by the call. */ -sds sdscat(sds s, const char *t) { - return sdscatlen(s, t, strlen(t)); -} - -/* Append the specified sds 't' to the existing sds 's'. - * - * After the call, the modified sds string is no longer valid and all the - * references must be substituted with the new pointer returned by the call. */ -sds sdscatsds(sds s, const sds t) { - return sdscatlen(s, t, sdslen(t)); -} - -/* Destructively modify the sds string 's' to hold the specified binary - * safe string pointed by 't' of length 'len' bytes. */ -sds sdscpylen(sds s, const char *t, size_t len) { - if (sdsalloc(s) < len) { - s = sdsMakeRoomFor(s,len-sdslen(s)); - if (s == NULL) return NULL; - } - memcpy(s, t, len); - s[len] = '\0'; - sdssetlen(s, len); - return s; -} - -/* Like sdscpylen() but 't' must be a null-termined string so that the length - * of the string is obtained with strlen(). */ -sds sdscpy(sds s, const char *t) { - return sdscpylen(s, t, strlen(t)); -} - -/* Helper for sdscatlonglong() doing the actual number -> string - * conversion. 's' must point to a string with room for at least - * SDS_LLSTR_SIZE bytes. - * - * The function returns the length of the null-terminated string - * representation stored at 's'. 
*/ -#define SDS_LLSTR_SIZE 21 -int sdsll2str(char *s, long long value) { - char *p, aux; - unsigned long long v; - size_t l; - - /* Generate the string representation, this method produces - * an reversed string. */ - v = (value < 0) ? -value : value; - p = s; - do { - *p++ = '0'+(v%10); - v /= 10; - } while(v); - if (value < 0) *p++ = '-'; - - /* Compute length and add null term. */ - l = p-s; - *p = '\0'; - - /* Reverse the string. */ - p--; - while(s < p) { - aux = *s; - *s = *p; - *p = aux; - s++; - p--; - } - return l; -} - -/* Identical sdsll2str(), but for unsigned long long type. */ -int sdsull2str(char *s, unsigned long long v) { - char *p, aux; - size_t l; - - /* Generate the string representation, this method produces - * an reversed string. */ - p = s; - do { - *p++ = '0'+(v%10); - v /= 10; - } while(v); - - /* Compute length and add null term. */ - l = p-s; - *p = '\0'; - - /* Reverse the string. */ - p--; - while(s < p) { - aux = *s; - *s = *p; - *p = aux; - s++; - p--; - } - return l; -} - -/* Create an sds string from a long long value. It is much faster than: - * - * sdscatprintf(sdsempty(),"%lld\n", value); - */ -sds sdsfromlonglong(long long value) { - char buf[SDS_LLSTR_SIZE]; - int len = sdsll2str(buf,value); - - return sdsnewlen(buf,len); -} - -/* Like sdscatprintf() but gets va_list instead of being variadic. */ -sds sdscatvprintf(sds s, const char *fmt, va_list ap) { - va_list cpy; - char staticbuf[1024], *buf = staticbuf, *t; - size_t buflen = strlen(fmt)*2; - - /* We try to start using a static buffer for speed. - * If not possible we revert to heap allocation. */ - if (buflen > sizeof(staticbuf)) { - buf = (char*) malloc(buflen); - if (buf == NULL) return NULL; - } else { - buflen = sizeof(staticbuf); - } - - /* Try with buffers two times bigger every time we fail to - * fit the string in the current buffer size. 
*/ - while(1) { - buf[buflen-2] = '\0'; - va_copy(cpy,ap); - vsnprintf(buf, buflen, fmt, cpy); - va_end(cpy); - if (buf[buflen-2] != '\0') { - if (buf != staticbuf) free(buf); - buflen *= 2; - buf = (char*) malloc(buflen); - if (buf == NULL) return NULL; - continue; - } - break; - } - - /* Finally concat the obtained string to the SDS string and return it. */ - t = sdscat(s, buf); - if (buf != staticbuf) free(buf); - return t; -} - -/* Append to the sds string 's' a string obtained using printf-alike format - * specifier. - * - * After the call, the modified sds string is no longer valid and all the - * references must be substituted with the new pointer returned by the call. - * - * Example: - * - * s = sdsnew("Sum is: "); - * s = sdscatprintf(s,"%d+%d = %d",a,b,a+b). - * - * Often you need to create a string from scratch with the printf-alike - * format. When this is the need, just use sdsempty() as the target string: - * - * s = sdscatprintf(sdsempty(), "... your format ...", args); - */ -sds sdscatprintf(sds s, const char *fmt, ...) { - va_list ap; - char *t; - va_start(ap, fmt); - t = sdscatvprintf(s,fmt,ap); - va_end(ap); - return t; -} - -/* This function is similar to sdscatprintf, but much faster as it does - * not rely on sprintf() family functions implemented by the libc that - * are often very slow. Moreover directly handling the sds string as - * new data is concatenated provides a performance improvement. - * - * However this function only handles an incompatible subset of printf-alike - * format specifiers: - * - * %s - C String - * %S - SDS string - * %i - signed int - * %I - 64 bit signed integer (long long, int64_t) - * %u - unsigned int - * %U - 64 bit unsigned integer (unsigned long long, uint64_t) - * %% - Verbatim "%" character. - */ -sds sdscatfmt(sds s, char const *fmt, ...) { - size_t initlen = sdslen(s); - const char *f = fmt; - long i; - va_list ap; - - va_start(ap,fmt); - f = fmt; /* Next format specifier byte to process. 
*/ - i = initlen; /* Position of the next byte to write to dest str. */ - while(*f) { - char next, *str; - size_t l; - long long num; - unsigned long long unum; - - /* Make sure there is always space for at least 1 char. */ - if (sdsavail(s)==0) { - s = sdsMakeRoomFor(s,1); - } - - switch(*f) { - case '%': - next = *(f+1); - f++; - switch(next) { - case 's': - case 'S': - str = va_arg(ap,char*); - l = (next == 's') ? strlen(str) : sdslen(str); - if (sdsavail(s) < l) { - s = sdsMakeRoomFor(s,l); - } - memcpy(s+i,str,l); - sdsinclen(s,l); - i += l; - break; - case 'i': - case 'I': - if (next == 'i') - num = va_arg(ap,int); - else - num = va_arg(ap,long long); - { - char buf[SDS_LLSTR_SIZE]; - l = sdsll2str(buf,num); - if (sdsavail(s) < l) { - s = sdsMakeRoomFor(s,l); - } - memcpy(s+i,buf,l); - sdsinclen(s,l); - i += l; - } - break; - case 'u': - case 'U': - if (next == 'u') - unum = va_arg(ap,unsigned int); - else - unum = va_arg(ap,unsigned long long); - { - char buf[SDS_LLSTR_SIZE]; - l = sdsull2str(buf,unum); - if (sdsavail(s) < l) { - s = sdsMakeRoomFor(s,l); - } - memcpy(s+i,buf,l); - sdsinclen(s,l); - i += l; - } - break; - default: /* Handle %% and generally %. */ - s[i++] = next; - sdsinclen(s,1); - break; - } - break; - default: - s[i++] = *f; - sdsinclen(s,1); - break; - } - f++; - } - va_end(ap); - - /* Add null-term */ - s[i] = '\0'; - return s; -} - -/* Remove the part of the string from left and from right composed just of - * contiguous characters found in 'cset', that is a null terminted C string. - * - * After the call, the modified sds string is no longer valid and all the - * references must be substituted with the new pointer returned by the call. - * - * Example: - * - * s = sdsnew("AA...AA.a.aa.aHelloWorld :::"); - * s = sdstrim(s,"Aa. :"); - * printf("%s\n", s); - * - * Output will be just "Hello World". 
- */ -sds sdstrim(sds s, const char *cset) { - char *start, *end, *sp, *ep; - size_t len; - - sp = start = s; - ep = end = s+sdslen(s)-1; - while(sp <= end && strchr(cset, *sp)) sp++; - while(ep > sp && strchr(cset, *ep)) ep--; - len = (sp > ep) ? 0 : ((ep-sp)+1); - if (s != sp) memmove(s, sp, len); - s[len] = '\0'; - sdssetlen(s,len); - return s; -} - -/* Turn the string into a smaller (or equal) string containing only the - * substring specified by the 'start' and 'end' indexes. - * - * start and end can be negative, where -1 means the last character of the - * string, -2 the penultimate character, and so forth. - * - * The interval is inclusive, so the start and end characters will be part - * of the resulting string. - * - * The string is modified in-place. - * - * Example: - * - * s = sdsnew("Hello World"); - * sdsrange(s,1,-1); => "ello World" - */ -void sdsrange(sds s, ssize_t start, ssize_t end) { - size_t newlen, len = sdslen(s); - - if (len == 0) return; - if (start < 0) { - start = len+start; - if (start < 0) start = 0; - } - if (end < 0) { - end = len+end; - if (end < 0) end = 0; - } - newlen = (start > end) ? 0 : (end-start)+1; - if (newlen != 0) { - if (start >= (ssize_t)len) { - newlen = 0; - } else if (end >= (ssize_t)len) { - end = len-1; - newlen = (start > end) ? 0 : (end-start)+1; - } - } else { - start = 0; - } - if (start && newlen) memmove(s, s+start, newlen); - s[newlen] = 0; - sdssetlen(s,newlen); -} - -/* Apply tolower() to every character of the sds string 's'. */ -void sdstolower(sds s) { - size_t len = sdslen(s), j; - - for (j = 0; j < len; j++) s[j] = tolower(s[j]); -} - -/* Apply toupper() to every character of the sds string 's'. */ -void sdstoupper(sds s) { - size_t len = sdslen(s), j; - - for (j = 0; j < len; j++) s[j] = toupper(s[j]); -} - -/* Compare two sds strings s1 and s2 with memcmp(). - * - * Return value: - * - * positive if s1 > s2. - * negative if s1 < s2. - * 0 if s1 and s2 are exactly the same binary string. 
- * - * If two strings share exactly the same prefix, but one of the two has - * additional characters, the longer string is considered to be greater than - * the smaller one. */ -int sdscmp(const sds s1, const sds s2) { - size_t l1, l2, minlen; - int cmp; - - l1 = sdslen(s1); - l2 = sdslen(s2); - minlen = (l1 < l2) ? l1 : l2; - cmp = memcmp(s1,s2,minlen); - if (cmp == 0) return l1>l2? 1: (l1". - * - * After the call, the modified sds string is no longer valid and all the - * references must be substituted with the new pointer returned by the call. */ -sds sdscatrepr(sds s, const char *p, size_t len) { - s = sdscatlen(s,"\"",1); - while(len--) { - switch(*p) { - case '\\': - case '"': - s = sdscatprintf(s,"\\%c",*p); - break; - case '\n': s = sdscatlen(s,"\\n",2); break; - case '\r': s = sdscatlen(s,"\\r",2); break; - case '\t': s = sdscatlen(s,"\\t",2); break; - case '\a': s = sdscatlen(s,"\\a",2); break; - case '\b': s = sdscatlen(s,"\\b",2); break; - default: - if (isprint(*p)) - s = sdscatprintf(s,"%c",*p); - else - s = sdscatprintf(s,"\\x%02x",(unsigned char)*p); - break; - } - p++; - } - return sdscatlen(s,"\"",1); -} - -/* Helper function for sdssplitargs() that returns non zero if 'c' - * is a valid hex digit. 
*/ -int is_hex_digit(char c) { - return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || - (c >= 'A' && c <= 'F'); -} - -/* Helper function for sdssplitargs() that converts a hex digit into an - * integer from 0 to 15 */ -int hex_digit_to_int(char c) { - switch(c) { - case '0': return 0; - case '1': return 1; - case '2': return 2; - case '3': return 3; - case '4': return 4; - case '5': return 5; - case '6': return 6; - case '7': return 7; - case '8': return 8; - case '9': return 9; - case 'a': case 'A': return 10; - case 'b': case 'B': return 11; - case 'c': case 'C': return 12; - case 'd': case 'D': return 13; - case 'e': case 'E': return 14; - case 'f': case 'F': return 15; - default: return 0; - } -} - -/* Split a line into arguments, where every argument can be in the - * following programming-language REPL-alike form: - * - * foo bar "newline are supported\n" and "\xff\x00otherstuff" - * - * The number of arguments is stored into *argc, and an array - * of sds is returned. - * - * The caller should free the resulting array of sds strings with - * sdsfreesplitres(). - * - * Note that sdscatrepr() is able to convert back a string into - * a quoted string in the same format sdssplitargs() is able to parse. 
- * - * The function returns the allocated tokens on success, even when the - * input string is empty, or NULL if the input contains unbalanced - * quotes or closed quotes followed by non space characters - * as in: "foo"bar or "foo' - */ -sds *sdssplitargs(const char *line, int *argc) { - const char *p = line; - char *current = NULL; - char **vector = NULL; - - *argc = 0; - while(1) { - /* skip blanks */ - while(*p && isspace(*p)) p++; - if (*p) { - /* get a token */ - int inq=0; /* set to 1 if we are in "quotes" */ - int insq=0; /* set to 1 if we are in 'single quotes' */ - int done=0; - - if (current == NULL) current = sdsempty(); - while(!done) { - if (inq) { - if (*p == '\\' && *(p+1) == 'x' && - is_hex_digit(*(p+2)) && - is_hex_digit(*(p+3))) - { - unsigned char byte; - - byte = (hex_digit_to_int(*(p+2))*16)+ - hex_digit_to_int(*(p+3)); - current = sdscatlen(current,(char*)&byte,1); - p += 3; - } else if (*p == '\\' && *(p+1)) { - char c; - - p++; - switch(*p) { - case 'n': c = '\n'; break; - case 'r': c = '\r'; break; - case 't': c = '\t'; break; - case 'b': c = '\b'; break; - case 'a': c = '\a'; break; - default: c = *p; break; - } - current = sdscatlen(current,&c,1); - } else if (*p == '"') { - /* closing quote must be followed by a space or - * nothing at all. */ - if (*(p+1) && !isspace(*(p+1))) goto err; - done=1; - } else if (!*p) { - /* unterminated quotes */ - goto err; - } else { - current = sdscatlen(current,p,1); - } - } else if (insq) { - if (*p == '\\' && *(p+1) == '\'') { - p++; - current = sdscatlen(current,"'",1); - } else if (*p == '\'') { - /* closing quote must be followed by a space or - * nothing at all. 
*/ - if (*(p+1) && !isspace(*(p+1))) goto err; - done=1; - } else if (!*p) { - /* unterminated quotes */ - goto err; - } else { - current = sdscatlen(current,p,1); - } - } else { - switch(*p) { - case ' ': - case '\n': - case '\r': - case '\t': - case '\0': - done=1; - break; - case '"': - inq=1; - break; - case '\'': - insq=1; - break; - default: - current = sdscatlen(current,p,1); - break; - } - } - if (*p) p++; - } - /* add the token to the vector */ - vector = (char**) realloc(vector,((*argc)+1)*sizeof(char*)); - vector[*argc] = current; - (*argc)++; - current = NULL; - } else { - /* Even on empty input string return something not NULL. */ - if (vector == NULL) vector = (char**) malloc(sizeof(void*)); - return vector; - } - } - -err: - while((*argc)--) - sdsfree(vector[*argc]); - free(vector); - if (current) sdsfree(current); - *argc = 0; - return NULL; -} - -/* Modify the string substituting all the occurrences of the set of - * characters specified in the 'from' string to the corresponding character - * in the 'to' array. - * - * For instance: sdsmapchars(mystring, "ho", "01", 2) - * will have the effect of turning the string "hello" into "0ell1". - * - * The function returns the sds string pointer, that is always the same - * as the input pointer since no resize is needed. */ -sds sdsmapchars(sds s, const char *from, const char *to, size_t setlen) { - size_t j, i, l = sdslen(s); - - for (j = 0; j < l; j++) { - for (i = 0; i < setlen; i++) { - if (s[j] == from[i]) { - s[j] = to[i]; - break; - } - } - } - return s; -} - -/* Join an array of C strings using the specified separator (also a C string). - * Returns the result as an sds string. */ -sds sdsjoin(char **argv, int argc, char *sep) { - sds join = sdsempty(); - int j; - - for (j = 0; j < argc; j++) { - join = sdscat(join, argv[j]); - if (j != argc-1) join = sdscat(join,sep); - } - return join; -} - -/* Like sdsjoin, but joins an array of SDS strings. 
*/ -sds sdsjoinsds(sds *argv, int argc, const char *sep, size_t seplen) { - sds join = sdsempty(); - int j; - - for (j = 0; j < argc; j++) { - join = sdscatsds(join, argv[j]); - if (j != argc-1) join = sdscatlen(join,sep,seplen); - } - return join; -} - -/* Wrappers to the allocators used by SDS. Note that SDS will actually - * just use the macros defined into sdsalloc.h in order to avoid to pay - * the overhead of function calls. Here we define these wrappers only for - * the programs SDS is linked to, if they want to touch the SDS internals - * even if they use a different allocator. */ -void *sdmalloc(size_t size) { return malloc(size); } -void *sdrealloc(void *ptr, size_t size) { return realloc(ptr,size); } -void sdfree(void *ptr) { free(ptr); } - -} - -// LICENSE_CHANGE_END - -#endif diff --git a/lib/duckdb-internal.hpp b/lib/duckdb-internal.hpp deleted file mode 100644 index 45739c3f..00000000 --- a/lib/duckdb-internal.hpp +++ /dev/null @@ -1,82224 +0,0 @@ -#ifdef GODUCKDB_FROM_SOURCE -// See https://raw.githubusercontent.com/duckdb/duckdb/main/LICENSE for licensing information - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_search_path.hpp -// -// -//===----------------------------------------------------------------------===// - - - -#include - - - - - -namespace duckdb { - -class ClientContext; - -struct CatalogSearchEntry { - CatalogSearchEntry(string catalog, string schema); - - string catalog; - string schema; - -public: - string ToString() const; - static string ListToString(const vector &input); - static CatalogSearchEntry Parse(const string &input); - static vector ParseList(const string &input); - -private: - static CatalogSearchEntry ParseInternal(const string &input, idx_t &pos); - static string WriteOptionallyQuoted(const string &input); -}; - -enum class CatalogSetPathType { SET_SCHEMA, SET_SCHEMAS }; - -//! 
The schema search path, in order by which entries are searched if no schema entry is provided -class CatalogSearchPath { -public: - DUCKDB_API explicit CatalogSearchPath(ClientContext &client_p); - CatalogSearchPath(const CatalogSearchPath &other) = delete; - - DUCKDB_API void Set(CatalogSearchEntry new_value, CatalogSetPathType set_type); - DUCKDB_API void Set(vector new_paths, CatalogSetPathType set_type); - DUCKDB_API void Reset(); - - DUCKDB_API const vector &Get(); - const vector &GetSetPaths() { - return set_paths; - } - DUCKDB_API const CatalogSearchEntry &GetDefault(); - DUCKDB_API string GetDefaultSchema(const string &catalog); - DUCKDB_API string GetDefaultCatalog(const string &schema); - - DUCKDB_API vector GetSchemasForCatalog(const string &catalog); - DUCKDB_API vector GetCatalogsForSchema(const string &schema); - - DUCKDB_API bool SchemaInSearchPath(ClientContext &context, const string &catalog_name, const string &schema_name); - -private: - void SetPaths(vector new_paths); - - string GetSetName(CatalogSetPathType set_type); - -private: - ClientContext &context; - vector paths; - //! Only the paths that were explicitly set (minus the always included paths) - vector set_paths; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - -namespace duckdb { - -//! 
An aggregate function in the catalog -class AggregateFunctionCatalogEntry : public FunctionEntry { -public: - static constexpr const CatalogType Type = CatalogType::AGGREGATE_FUNCTION_ENTRY; - static constexpr const char *Name = "aggregate function"; - -public: - AggregateFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateAggregateFunctionInfo &info) - : FunctionEntry(CatalogType::AGGREGATE_FUNCTION_ENTRY, catalog, schema, info), functions(info.functions) { - } - - //! The aggregate functions - AggregateFunctionSet functions; -}; -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/collate_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -//! A collation catalog entry -class CollateCatalogEntry : public StandardEntry { -public: - static constexpr const CatalogType Type = CatalogType::COLLATION_ENTRY; - static constexpr const char *Name = "collation"; - -public: - CollateCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateCollationInfo &info) - : StandardEntry(CatalogType::COLLATION_ENTRY, schema, catalog, info.name), function(info.function), - combinable(info.combinable), not_required_for_equality(info.not_required_for_equality) { - } - - //! The collation function to push in case collation is required - ScalarFunction function; - //! Whether or not the collation can be combined with other collations. - bool combinable; - //! Whether or not the collation is required for equality comparisons or not. For many collations a binary - //! comparison for equality comparisons is correct, allowing us to skip the collation in these cases which greatly - //! speeds up processing. 
- bool not_required_for_equality; -}; -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -class Catalog; -struct CreateCopyFunctionInfo; - -//! A table function in the catalog -class CopyFunctionCatalogEntry : public StandardEntry { -public: - static constexpr const CatalogType Type = CatalogType::COPY_FUNCTION_ENTRY; - static constexpr const char *Name = "copy function"; - -public: - CopyFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateCopyFunctionInfo &info); - - //! The copy function - CopyFunction function; -}; -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/index_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -struct DataTableInfo; -class Index; - -//! An index catalog entry -class IndexCatalogEntry : public StandardEntry { -public: - static constexpr const CatalogType Type = CatalogType::INDEX_ENTRY; - static constexpr const char *Name = "index"; - -public: - //! 
Create an IndexCatalogEntry and initialize storage for it - IndexCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &info); - - optional_ptr index; - string sql; - vector> expressions; - vector> parsed_expressions; - case_insensitive_map_t options; - -public: - unique_ptr GetInfo() const override; - string ToSQL() const override; - - virtual string GetSchemaName() const = 0; - virtual string GetTableName() const = 0; -}; - -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/macro_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/macro_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - -namespace duckdb { - -//! A macro function in the catalog -class MacroCatalogEntry : public FunctionEntry { -public: - MacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info); - - //! The macro function - unique_ptr function; - -public: - unique_ptr GetInfo() const override; - - string ToSQL() const override { - return function->ToSQL(schema.name, name); - } -}; - -} // namespace duckdb - - -namespace duckdb { - -//! 
A macro function in the catalog -class ScalarMacroCatalogEntry : public MacroCatalogEntry { -public: - static constexpr const CatalogType Type = CatalogType::MACRO_ENTRY; - static constexpr const char *Name = "macro function"; - -public: - ScalarMacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info); -}; -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -class Catalog; -struct CreatePragmaFunctionInfo; - -//! A table function in the catalog -class PragmaFunctionCatalogEntry : public FunctionEntry { -public: - static constexpr const CatalogType Type = CatalogType::PRAGMA_FUNCTION_ENTRY; - static constexpr const char *Name = "pragma function"; - -public: - PragmaFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreatePragmaFunctionInfo &info); - - //! The pragma functions - PragmaFunctionSet functions; -}; -} // namespace duckdb - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/table_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - - - - - - -namespace duckdb { - -class DataTable; -struct CreateTableInfo; -struct BoundCreateTableInfo; - -struct RenameColumnInfo; -struct AddColumnInfo; -struct RemoveColumnInfo; -struct SetDefaultInfo; -struct ChangeColumnTypeInfo; -struct AlterForeignKeyInfo; -struct SetNotNullInfo; -struct DropNotNullInfo; - -class TableFunction; -struct FunctionData; - -class TableColumnInfo; -struct ColumnSegmentInfo; -class TableStorageInfo; - -class LogicalGet; -class LogicalProjection; -class LogicalUpdate; - -//! 
A table catalog entry -class TableCatalogEntry : public StandardEntry { -public: - static constexpr const CatalogType Type = CatalogType::TABLE_ENTRY; - static constexpr const char *Name = "table"; - -public: - //! Create a TableCatalogEntry and initialize storage for it - DUCKDB_API TableCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableInfo &info); - -public: - DUCKDB_API unique_ptr GetInfo() const override; - - DUCKDB_API bool HasGeneratedColumns() const; - - //! Returns whether or not a column with the given name exists - DUCKDB_API bool ColumnExists(const string &name); - //! Returns a reference to the column of the specified name. Throws an - //! exception if the column does not exist. - DUCKDB_API const ColumnDefinition &GetColumn(const string &name); - //! Returns a reference to the column of the specified logical index. Throws an - //! exception if the column does not exist. - DUCKDB_API const ColumnDefinition &GetColumn(LogicalIndex idx); - //! Returns a list of types of the table, excluding generated columns - DUCKDB_API vector GetTypes(); - //! Returns a list of the columns of the table - DUCKDB_API const ColumnList &GetColumns() const; - //! Returns the underlying storage of the table - virtual DataTable &GetStorage(); - //! Returns a list of the bound constraints of the table - virtual const vector> &GetBoundConstraints(); - - //! Returns a list of the constraints of the table - DUCKDB_API const vector> &GetConstraints(); - DUCKDB_API string ToSQL() const override; - - //! Get statistics of a column (physical or virtual) within the table - virtual unique_ptr GetStatistics(ClientContext &context, column_t column_id) = 0; - - //! Returns the column index of the specified column name. - //! If the column does not exist: - //! If if_column_exists is true, returns DConstants::INVALID_INDEX - //! If if_column_exists is false, throws an exception - DUCKDB_API LogicalIndex GetColumnIndex(string &name, bool if_exists = false); - - //! 
Returns the scan function that can be used to scan the given table - virtual TableFunction GetScanFunction(ClientContext &context, unique_ptr &bind_data) = 0; - - virtual bool IsDuckTable() const { - return false; - } - - DUCKDB_API static string ColumnsToSQL(const ColumnList &columns, const vector> &constraints); - - //! Returns a list of segment information for this table, if exists - virtual vector GetColumnSegmentInfo(); - - //! Returns the storage info of this table - virtual TableStorageInfo GetStorageInfo(ClientContext &context) = 0; - - virtual void BindUpdateConstraints(LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, - ClientContext &context); - -protected: - //! A list of columns that are part of this table - ColumnList columns; - //! A list of constraints that are part of this table - vector> constraints; -}; -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - -namespace duckdb { - -//! A table function in the catalog -class TableFunctionCatalogEntry : public FunctionEntry { -public: - static constexpr const CatalogType Type = CatalogType::TABLE_FUNCTION_ENTRY; - static constexpr const char *Name = "table function"; - -public: - TableFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableFunctionInfo &info); - - //! 
The table function - TableFunctionSet functions; - -public: - unique_ptr AlterEntry(ClientContext &context, AlterInfo &info) override; -}; -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/view_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - -namespace duckdb { - -class DataTable; -struct CreateViewInfo; - -//! A view catalog entry -class ViewCatalogEntry : public StandardEntry { -public: - static constexpr const CatalogType Type = CatalogType::VIEW_ENTRY; - static constexpr const char *Name = "view"; - -public: - //! Create a real TableCatalogEntry and initialize storage for it - ViewCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateViewInfo &info); - - //! The query of the view - unique_ptr query; - //! The SQL query (if any) - string sql; - //! The set of aliases associated with the view - vector aliases; - //! 
The returned types of the view - vector types; - -public: - unique_ptr GetInfo() const override; - - unique_ptr AlterEntry(ClientContext &context, AlterInfo &info) override; - - unique_ptr Copy(ClientContext &context) const override; - - string ToSQL() const override; - -private: - void Initialize(CreateViewInfo &info); -}; -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/default/default_schemas.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -class DefaultSchemaGenerator : public DefaultGenerator { -public: - explicit DefaultSchemaGenerator(Catalog &catalog); - -public: - unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; - vector GetDefaultEntries() override; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/type_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -//! A type catalog entry -class TypeCatalogEntry : public StandardEntry { -public: - static constexpr const CatalogType Type = CatalogType::TYPE_ENTRY; - static constexpr const char *Name = "type"; - -public: - //! 
Create a TypeCatalogEntry and initialize storage for it - TypeCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTypeInfo &info); - - LogicalType user_type; - -public: - unique_ptr GetInfo() const override; - - string ToSQL() const override; -}; -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/main/client_data.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - - - -namespace duckdb { -class AttachedDatabase; -class BufferedFileWriter; -class ClientContext; -class CatalogSearchPath; -class FileOpener; -class FileSystem; -class HTTPState; -class QueryProfiler; -class QueryProfilerHistory; -class PreparedStatementData; -class SchemaCatalogEntry; -struct RandomEngine; - -struct ClientData { - explicit ClientData(ClientContext &context); - ~ClientData(); - - //! Query profiler - shared_ptr profiler; - //! QueryProfiler History - unique_ptr query_profiler_history; - - //! The set of temporary objects that belong to this client - shared_ptr temporary_objects; - //! The set of bound prepared statements that belong to this client - case_insensitive_map_t> prepared_statements; - - //! The writer used to log queries (if logging is enabled) - unique_ptr log_query_writer; - //! The random generator used by random(). Its seed value can be set by setseed(). - unique_ptr random_engine; - - //! The catalog search path - unique_ptr catalog_search_path; - - //! The file opener of the client context - unique_ptr file_opener; - - //! HTTP State in this query - shared_ptr http_state; - - //! The clients' file system wrapper - unique_ptr client_file_system; - - //! The file search path - string file_search_path; - - //! The Max Line Length Size of Last Query Executed on a CSV File. (Only used for testing) - //! 
FIXME: this should not be done like this - bool debug_set_max_line_length = false; - idx_t debug_max_line_length = 0; - -public: - DUCKDB_API static ClientData &Get(ClientContext &context); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/main/extension_helper.hpp -// -// -//===----------------------------------------------------------------------===// - - - -#include - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/main/extension_entries.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -// NOTE: this file is generated by scripts/generate_extensions_function.py. Check out the check-load-install-extensions -// job in .github/workflows/LinuxRelease.yml on how to use it - -namespace duckdb { - -struct ExtensionEntry { - char name[48]; - char extension[48]; -}; - -static constexpr ExtensionEntry EXTENSION_FUNCTIONS[] = { - {"->>", "json"}, - {"array_to_json", "json"}, - {"create_fts_index", "fts"}, - {"current_localtime", "icu"}, - {"current_localtimestamp", "icu"}, - {"dbgen", "tpch"}, - {"drop_fts_index", "fts"}, - {"dsdgen", "tpcds"}, - {"excel_text", "excel"}, - {"from_json", "json"}, - {"from_json_strict", "json"}, - {"from_substrait", "substrait"}, - {"from_substrait_json", "substrait"}, - {"fuzz_all_functions", "sqlsmith"}, - {"fuzzyduck", "sqlsmith"}, - {"get_substrait", "substrait"}, - {"get_substrait_json", "substrait"}, - {"host", "inet"}, - {"iceberg_metadata", "iceberg"}, - {"iceberg_scan", "iceberg"}, - {"iceberg_snapshots", "iceberg"}, - {"icu_calendar_names", "icu"}, - {"icu_sort_key", "icu"}, - {"json", "json"}, - {"json_array", "json"}, - {"json_array_length", "json"}, - {"json_contains", "json"}, - {"json_deserialize_sql", "json"}, - {"json_execute_serialized_sql", "json"}, - {"json_extract", "json"}, - {"json_extract_path", "json"}, - 
{"json_extract_path_text", "json"}, - {"json_extract_string", "json"}, - {"json_group_array", "json"}, - {"json_group_object", "json"}, - {"json_group_structure", "json"}, - {"json_keys", "json"}, - {"json_merge_patch", "json"}, - {"json_object", "json"}, - {"json_quote", "json"}, - {"json_serialize_sql", "json"}, - {"json_structure", "json"}, - {"json_transform", "json"}, - {"json_transform_strict", "json"}, - {"json_type", "json"}, - {"json_valid", "json"}, - {"load_aws_credentials", "aws"}, - {"make_timestamptz", "icu"}, - {"parquet_metadata", "parquet"}, - {"parquet_scan", "parquet"}, - {"parquet_schema", "parquet"}, - {"pg_timezone_names", "icu"}, - {"postgres_attach", "postgres_scanner"}, - {"postgres_scan", "postgres_scanner"}, - {"postgres_scan_pushdown", "postgres_scanner"}, - {"read_json", "json"}, - {"read_json_auto", "json"}, - {"read_json_objects", "json"}, - {"read_json_objects_auto", "json"}, - {"read_ndjson", "json"}, - {"read_ndjson_auto", "json"}, - {"read_ndjson_objects", "json"}, - {"read_parquet", "parquet"}, - {"reduce_sql_statement", "sqlsmith"}, - {"row_to_json", "json"}, - {"scan_arrow_ipc", "arrow"}, - {"sql_auto_complete", "autocomplete"}, - {"sqlite_attach", "sqlite_scanner"}, - {"sqlite_scan", "sqlite_scanner"}, - {"sqlsmith", "sqlsmith"}, - {"st_area", "spatial"}, - {"st_area_spheroid", "spatial"}, - {"st_asgeojson", "spatial"}, - {"st_ashexwkb", "spatial"}, - {"st_astext", "spatial"}, - {"st_aswkb", "spatial"}, - {"st_boundary", "spatial"}, - {"st_buffer", "spatial"}, - {"st_centroid", "spatial"}, - {"st_collect", "spatial"}, - {"st_collectionextract", "spatial"}, - {"st_contains", "spatial"}, - {"st_containsproperly", "spatial"}, - {"st_convexhull", "spatial"}, - {"st_coveredby", "spatial"}, - {"st_covers", "spatial"}, - {"st_crosses", "spatial"}, - {"st_difference", "spatial"}, - {"st_dimension", "spatial"}, - {"st_disjoint", "spatial"}, - {"st_distance", "spatial"}, - {"st_distance_spheroid", "spatial"}, - {"st_drivers", 
"spatial"}, - {"st_dwithin", "spatial"}, - {"st_dwithin_spheroid", "spatial"}, - {"st_endpoint", "spatial"}, - {"st_envelope", "spatial"}, - {"st_envelope_agg", "spatial"}, - {"st_equals", "spatial"}, - {"st_extent", "spatial"}, - {"st_exteriorring", "spatial"}, - {"st_flipcoordinates", "spatial"}, - {"st_geometrytype", "spatial"}, - {"st_geomfromgeojson", "spatial"}, - {"st_geomfromhexewkb", "spatial"}, - {"st_geomfromhexwkb", "spatial"}, - {"st_geomfromtext", "spatial"}, - {"st_geomfromwkb", "spatial"}, - {"st_intersection", "spatial"}, - {"st_intersection_agg", "spatial"}, - {"st_intersects", "spatial"}, - {"st_intersects_extent", "spatial"}, - {"st_isclosed", "spatial"}, - {"st_isempty", "spatial"}, - {"st_isring", "spatial"}, - {"st_issimple", "spatial"}, - {"st_isvalid", "spatial"}, - {"st_length", "spatial"}, - {"st_length_spheroid", "spatial"}, - {"st_linestring2dfromwkb", "spatial"}, - {"st_list_proj_crs", "spatial"}, - {"st_makeline", "spatial"}, - {"st_ngeometries", "spatial"}, - {"st_ninteriorrings", "spatial"}, - {"st_normalize", "spatial"}, - {"st_npoints", "spatial"}, - {"st_numgeometries", "spatial"}, - {"st_numinteriorrings", "spatial"}, - {"st_numpoints", "spatial"}, - {"st_overlaps", "spatial"}, - {"st_perimeter", "spatial"}, - {"st_perimeter_spheroid", "spatial"}, - {"st_point", "spatial"}, - {"st_point2d", "spatial"}, - {"st_point2dfromwkb", "spatial"}, - {"st_point3d", "spatial"}, - {"st_point4d", "spatial"}, - {"st_pointn", "spatial"}, - {"st_pointonsurface", "spatial"}, - {"st_polygon2dfromwkb", "spatial"}, - {"st_reverse", "spatial"}, - {"st_read", "spatial"}, - {"st_readosm", "spatial"}, - {"st_reduceprecision", "spatial"}, - {"st_removerepeatedpoints", "spatial"}, - {"st_simplify", "spatial"}, - {"st_simplifypreservetopology", "spatial"}, - {"st_startpoint", "spatial"}, - {"st_touches", "spatial"}, - {"st_transform", "spatial"}, - {"st_union", "spatial"}, - {"st_union_agg", "spatial"}, - {"st_within", "spatial"}, - {"st_x", "spatial"}, - 
{"st_xmax", "spatial"}, - {"st_xmin", "spatial"}, - {"st_y", "spatial"}, - {"st_ymax", "spatial"}, - {"st_ymin", "spatial"}, - {"stem", "fts"}, - {"text", "excel"}, - {"to_arrow_ipc", "arrow"}, - {"to_json", "json"}, - {"tpcds", "tpcds"}, - {"tpcds_answers", "tpcds"}, - {"tpcds_queries", "tpcds"}, - {"tpch", "tpch"}, - {"tpch_answers", "tpch"}, - {"tpch_queries", "tpch"}, - {"visualize_diff_profiling_output", "visualizer"}, - {"visualize_json_profiling_output", "visualizer"}, - {"visualize_last_profiling_output", "visualizer"}, -}; // END_OF_EXTENSION_FUNCTIONS - -static constexpr ExtensionEntry EXTENSION_SETTINGS[] = { - {"azure_storage_connection_string", "azure"}, - {"binary_as_string", "parquet"}, - {"calendar", "icu"}, - {"force_download", "httpfs"}, - {"http_retries", "httpfs"}, - {"http_retry_backoff", "httpfs"}, - {"http_retry_wait_ms", "httpfs"}, - {"http_timeout", "httpfs"}, - {"s3_access_key_id", "httpfs"}, - {"s3_endpoint", "httpfs"}, - {"s3_region", "httpfs"}, - {"s3_secret_access_key", "httpfs"}, - {"s3_session_token", "httpfs"}, - {"s3_uploader_max_filesize", "httpfs"}, - {"s3_uploader_max_parts_per_file", "httpfs"}, - {"s3_uploader_thread_limit", "httpfs"}, - {"s3_url_compatibility_mode", "httpfs"}, - {"s3_url_style", "httpfs"}, - {"s3_use_ssl", "httpfs"}, - {"sqlite_all_varchar", "sqlite_scanner"}, - {"timezone", "icu"}, -}; // END_OF_EXTENSION_SETTINGS - -// Note: these are currently hardcoded in scripts/generate_extensions_function.py -// TODO: automate by passing though to script via duckdb -static constexpr ExtensionEntry EXTENSION_COPY_FUNCTIONS[] = {{"parquet", "parquet"}, - {"json", "json"}}; // END_OF_EXTENSION_COPY_FUNCTIONS - -// Note: these are currently hardcoded in scripts/generate_extensions_function.py -// TODO: automate by passing though to script via duckdb -static constexpr ExtensionEntry EXTENSION_TYPES[] = { - {"json", "json"}, {"inet", "inet"}, {"geometry", "spatial"}}; // END_OF_EXTENSION_TYPES - -// Note: these are currently 
hardcoded in scripts/generate_extensions_function.py -// TODO: automate by passing though to script via duckdb -static constexpr ExtensionEntry EXTENSION_COLLATIONS[] = { - {"af", "icu"}, {"am", "icu"}, {"ar", "icu"}, {"ar_sa", "icu"}, {"as", "icu"}, {"az", "icu"}, - {"be", "icu"}, {"bg", "icu"}, {"bn", "icu"}, {"bo", "icu"}, {"br", "icu"}, {"bs", "icu"}, - {"ca", "icu"}, {"ceb", "icu"}, {"chr", "icu"}, {"cs", "icu"}, {"cy", "icu"}, {"da", "icu"}, - {"de", "icu"}, {"de_at", "icu"}, {"dsb", "icu"}, {"dz", "icu"}, {"ee", "icu"}, {"el", "icu"}, - {"en", "icu"}, {"en_us", "icu"}, {"eo", "icu"}, {"es", "icu"}, {"et", "icu"}, {"fa", "icu"}, - {"fa_af", "icu"}, {"ff", "icu"}, {"fi", "icu"}, {"fil", "icu"}, {"fo", "icu"}, {"fr", "icu"}, - {"fr_ca", "icu"}, {"fy", "icu"}, {"ga", "icu"}, {"gl", "icu"}, {"gu", "icu"}, {"ha", "icu"}, - {"haw", "icu"}, {"he", "icu"}, {"he_il", "icu"}, {"hi", "icu"}, {"hr", "icu"}, {"hsb", "icu"}, - {"hu", "icu"}, {"hy", "icu"}, {"id", "icu"}, {"id_id", "icu"}, {"ig", "icu"}, {"is", "icu"}, - {"it", "icu"}, {"ja", "icu"}, {"ka", "icu"}, {"kk", "icu"}, {"kl", "icu"}, {"km", "icu"}, - {"kn", "icu"}, {"ko", "icu"}, {"kok", "icu"}, {"ku", "icu"}, {"ky", "icu"}, {"lb", "icu"}, - {"lkt", "icu"}, {"ln", "icu"}, {"lo", "icu"}, {"lt", "icu"}, {"lv", "icu"}, {"mk", "icu"}, - {"ml", "icu"}, {"mn", "icu"}, {"mr", "icu"}, {"ms", "icu"}, {"mt", "icu"}, {"my", "icu"}, - {"nb", "icu"}, {"nb_no", "icu"}, {"ne", "icu"}, {"nl", "icu"}, {"nn", "icu"}, {"om", "icu"}, - {"or", "icu"}, {"pa", "icu"}, {"pa_in", "icu"}, {"pl", "icu"}, {"ps", "icu"}, {"pt", "icu"}, - {"ro", "icu"}, {"ru", "icu"}, {"sa", "icu"}, {"se", "icu"}, {"si", "icu"}, {"sk", "icu"}, - {"sl", "icu"}, {"smn", "icu"}, {"sq", "icu"}, {"sr", "icu"}, {"sr_ba", "icu"}, {"sr_me", "icu"}, - {"sr_rs", "icu"}, {"sv", "icu"}, {"sw", "icu"}, {"ta", "icu"}, {"te", "icu"}, {"th", "icu"}, - {"tk", "icu"}, {"to", "icu"}, {"tr", "icu"}, {"ug", "icu"}, {"uk", "icu"}, {"ur", "icu"}, - {"uz", "icu"}, {"vi", "icu"}, 
{"wae", "icu"}, {"wo", "icu"}, {"xh", "icu"}, {"yi", "icu"}, - {"yo", "icu"}, {"yue", "icu"}, {"yue_cn", "icu"}, {"zh", "icu"}, {"zh_cn", "icu"}, {"zh_hk", "icu"}, - {"zh_mo", "icu"}, {"zh_sg", "icu"}, {"zh_tw", "icu"}, {"zu", "icu"}}; // END_OF_EXTENSION_COLLATIONS - -// Note: these are currently hardcoded in scripts/generate_extensions_function.py -// TODO: automate by passing though to script via duckdb -static constexpr ExtensionEntry EXTENSION_FILE_PREFIXES[] = { - {"http://", "httpfs"}, {"https://", "httpfs"}, {"s3://", "httpfs"}, - // {"azure://", "azure"} -}; // END_OF_EXTENSION_FILE_PREFIXES - -// Note: these are currently hardcoded in scripts/generate_extensions_function.py -// TODO: automate by passing though to script via duckdb -static constexpr ExtensionEntry EXTENSION_FILE_POSTFIXES[] = { - {".parquet", "parquet"}, {".json", "json"}, {".jsonl", "json"}, {".ndjson", "json"}, - {".shp", "spatial"}, {".gpkg", "spatial"}, {".fgb", "spatial"}}; // END_OF_EXTENSION_FILE_POSTFIXES - -// Note: these are currently hardcoded in scripts/generate_extensions_function.py -// TODO: automate by passing though to script via duckdb -static constexpr ExtensionEntry EXTENSION_FILE_CONTAINS[] = {{".parquet?", "parquet"}, - {".json?", "json"}, - {".ndjson?", ".jsonl?"}, - {".jsonl?", ".ndjson?"}}; // EXTENSION_FILE_CONTAINS - -static constexpr const char *AUTOLOADABLE_EXTENSIONS[] = { - // "azure", - "arrow", - "aws", - "autocomplete", - "excel", - "fts", - "httpfs", - // "inet", - // "icu", - "json", - "parquet", - "postgres_scanner", - // "spatial", TODO: table function isnt always autoloaded so test fails - "sqlsmith", - "sqlite_scanner", - "tpcds", - "tpch", - "visualizer", -}; // END_OF_AUTOLOADABLE_EXTENSIONS - -} // namespace duckdb - - -namespace duckdb { -class DuckDB; - -enum class ExtensionLoadResult : uint8_t { LOADED_EXTENSION = 0, EXTENSION_UNKNOWN = 1, NOT_LOADED = 2 }; - -struct DefaultExtension { - const char *name; - const char *description; - bool 
statically_loaded; -}; - -struct ExtensionAlias { - const char *alias; - const char *extension; -}; - -struct ExtensionInitResult { - string filename; - string basename; - - void *lib_hdl; -}; - -class ExtensionHelper { -public: - static void LoadAllExtensions(DuckDB &db); - - static ExtensionLoadResult LoadExtension(DuckDB &db, const std::string &extension); - - static void InstallExtension(ClientContext &context, const string &extension, bool force_install, - const string &respository = ""); - static void InstallExtension(DBConfig &config, FileSystem &fs, const string &extension, bool force_install, - const string &respository = ""); - static void LoadExternalExtension(ClientContext &context, const string &extension); - static void LoadExternalExtension(DatabaseInstance &db, FileSystem &fs, const string &extension, - optional_ptr client_config); - - //! Autoload an extension by name. Depending on the current settings, this will either load or install+load - static void AutoLoadExtension(ClientContext &context, const string &extension_name); - DUCKDB_API static bool TryAutoLoadExtension(ClientContext &context, const string &extension_name) noexcept; - - static string ExtensionDirectory(ClientContext &context); - static string ExtensionDirectory(DBConfig &config, FileSystem &fs); - static string ExtensionUrlTemplate(optional_ptr config, const string &repository); - static string ExtensionFinalizeUrlTemplate(const string &url, const string &name); - - static idx_t DefaultExtensionCount(); - static DefaultExtension GetDefaultExtension(idx_t index); - - static idx_t ExtensionAliasCount(); - static ExtensionAlias GetExtensionAlias(idx_t index); - - static const vector GetPublicKeys(); - - // Returns extension name, or empty string if not a replacement open path - static string ExtractExtensionPrefixFromPath(const string &path); - - //! 
Apply any known extension aliases - static string ApplyExtensionAlias(string extension_name); - - static string GetExtensionName(const string &extension); - static bool IsFullPath(const string &extension); - - //! Lookup a name in an ExtensionEntry list - template - static string FindExtensionInEntries(const string &name, const ExtensionEntry (&entries)[N]) { - auto lcase = StringUtil::Lower(name); - - auto it = - std::find_if(entries, entries + N, [&](const ExtensionEntry &element) { return element.name == lcase; }); - - if (it != entries + N && it->name == lcase) { - return it->extension; - } - return ""; - } - - //! Whether an extension can be autoloaded (i.e. it's registered as an autoloadable extension in - //! extension_entries.hpp) - static bool CanAutoloadExtension(const string &ext_name); - - //! Utility functions for creating meaningful error messages regarding missing extensions - static string WrapAutoLoadExtensionErrorMsg(ClientContext &context, const string &base_error, - const string &extension_name); - static string AddExtensionInstallHintToErrorMsg(ClientContext &context, const string &base_error, - const string &extension_name); - -private: - static void InstallExtensionInternal(DBConfig &config, ClientConfig *client_config, FileSystem &fs, - const string &local_path, const string &extension, bool force_install, - const string &repository); - static const vector PathComponents(); - static bool AllowAutoInstall(const string &extension); - static ExtensionInitResult InitialLoad(DBConfig &config, FileSystem &fs, const string &extension, - optional_ptr client_config); - static bool TryInitialLoad(DBConfig &config, FileSystem &fs, const string &extension, ExtensionInitResult &result, - string &error, optional_ptr client_config); - //! For tagged releases we use the tag, else we use the git commit hash - static const string GetVersionDirectoryName(); - //! 
Version tags occur with and without 'v', tag in extension path is always with 'v' - static const string NormalizeVersionTag(const string &version_tag); - static bool IsRelease(const string &version_tag); - static bool CreateSuggestions(const string &extension_name, string &message); - -private: - static ExtensionLoadResult LoadExtensionInternal(DuckDB &db, const std::string &extension, bool initial_load); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/default/default_types.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { -class SchemaCatalogEntry; - -class DefaultTypeGenerator : public DefaultGenerator { -public: - DefaultTypeGenerator(Catalog &catalog, SchemaCatalogEntry &schema); - - SchemaCatalogEntry &schema; - -public: - DUCKDB_API static LogicalTypeId GetDefaultType(const string &name); - - unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; - vector GetDefaultEntries() override; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/main/extension/generated_extension_loader.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -#if defined(GENERATED_EXTENSION_HEADERS) and !defined(DUCKDB_AMALGAMATION) -#include "generated_extension_headers.hpp" - - -namespace duckdb { - -//! 
Looks through the CMake-generated list of extensions that are linked into DuckDB currently to try load -bool TryLoadLinkedExtension(DuckDB &db, const string &extension); -extern vector linked_extensions; -extern vector loaded_extension_test_paths; - -} // namespace duckdb -#endif -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/main/attached_database.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - - -namespace duckdb { -class Catalog; -class DatabaseInstance; -class StorageManager; -class TransactionManager; -class StorageExtension; - -struct AttachInfo; - -enum class AttachedDatabaseType { - READ_WRITE_DATABASE, - READ_ONLY_DATABASE, - SYSTEM_DATABASE, - TEMP_DATABASE, -}; - -//! The AttachedDatabase represents an attached database instance -class AttachedDatabase : public CatalogEntry { -public: - //! Create the built-in system attached database (without storage) - explicit AttachedDatabase(DatabaseInstance &db, AttachedDatabaseType type = AttachedDatabaseType::SYSTEM_DATABASE); - //! Create an attached database instance with the specified name and storage - AttachedDatabase(DatabaseInstance &db, Catalog &catalog, string name, string file_path, AccessMode access_mode); - //! 
Create an attached database instance with the specified storage extension - AttachedDatabase(DatabaseInstance &db, Catalog &catalog, StorageExtension &ext, string name, AttachInfo &info, - AccessMode access_mode); - ~AttachedDatabase() override; - - void Initialize(); - - Catalog &ParentCatalog() override; - StorageManager &GetStorageManager(); - Catalog &GetCatalog(); - TransactionManager &GetTransactionManager(); - DatabaseInstance &GetDatabase() { - return db; - } - const string &GetName() const { - return name; - } - bool IsSystem() const; - bool IsTemporary() const; - bool IsReadOnly() const; - bool IsInitialDatabase() const; - void SetInitialDatabase(); - - static string ExtractDatabaseName(const string &dbpath, FileSystem &fs); - -private: - DatabaseInstance &db; - unique_ptr storage; - unique_ptr catalog; - unique_ptr transaction_manager; - AttachedDatabaseType type; - optional_ptr parent_catalog; - bool is_initial_database = false; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/main/database_manager.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - - - -namespace duckdb { -class AttachedDatabase; -class Catalog; -class CatalogSet; -class ClientContext; -class DatabaseInstance; - -//! The DatabaseManager is a class that sits at the root of all attached databases -class DatabaseManager { - friend class Catalog; - -public: - explicit DatabaseManager(DatabaseInstance &db); - ~DatabaseManager(); - -public: - static DatabaseManager &Get(DatabaseInstance &db); - static DatabaseManager &Get(ClientContext &db); - static DatabaseManager &Get(AttachedDatabase &db); - - void InitializeSystemCatalog(); - //! Get an attached database with the given name - optional_ptr GetDatabase(ClientContext &context, const string &name); - //! 
Add a new attached database to the database manager - void AddDatabase(ClientContext &context, unique_ptr db); - void DetachDatabase(ClientContext &context, const string &name, OnEntryNotFound if_not_found); - //! Returns a reference to the system catalog - Catalog &GetSystemCatalog(); - static const string &GetDefaultDatabase(ClientContext &context); - void SetDefaultDatabase(ClientContext &context, const string &new_value); - - optional_ptr GetDatabaseFromPath(ClientContext &context, const string &path); - vector> GetDatabases(ClientContext &context); - - transaction_t GetNewQueryNumber() { - return current_query_number++; - } - transaction_t ActiveQueryNumber() const { - return current_query_number; - } - idx_t ModifyCatalog() { - return catalog_version++; - } - bool HasDefaultDatabase() { - return !default_database.empty(); - } - -private: - //! The system database is a special database that holds system entries (e.g. functions) - unique_ptr system; - //! The set of attached databases - unique_ptr databases; - //! The global catalog version, incremented whenever anything changes in the catalog - atomic catalog_version; - //! The current query number - atomic current_query_number; - //! The current default database - string default_database; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/function/built_in_functions.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { - -class BuiltinFunctions { -public: - BuiltinFunctions(CatalogTransaction transaction, Catalog &catalog); - ~BuiltinFunctions(); - - //! 
Initialize a catalog with all built-in functions - void Initialize(); - -public: - void AddFunction(AggregateFunctionSet set); - void AddFunction(AggregateFunction function); - void AddFunction(ScalarFunctionSet set); - void AddFunction(PragmaFunction function); - void AddFunction(const string &name, PragmaFunctionSet functions); - void AddFunction(ScalarFunction function); - void AddFunction(const vector &names, ScalarFunction function); - void AddFunction(TableFunctionSet set); - void AddFunction(TableFunction function); - void AddFunction(CopyFunction function); - - void AddCollation(string name, ScalarFunction function, bool combinable = false, - bool not_required_for_equality = false); - -private: - CatalogTransaction transaction; - Catalog &catalog; - -private: - template - void Register() { - T::RegisterFunction(*this); - } - - // table-producing functions - void RegisterTableScanFunctions(); - void RegisterSQLiteFunctions(); - void RegisterReadFunctions(); - void RegisterTableFunctions(); - void RegisterArrowFunctions(); - - // aggregates - void RegisterDistributiveAggregates(); - - // scalar functions - void RegisterCompressedMaterializationFunctions(); - void RegisterGenericFunctions(); - void RegisterOperators(); - void RegisterStringFunctions(); - void RegisterNestedFunctions(); - void RegisterSequenceFunctions(); - - // pragmas - void RegisterPragmaFunctions(); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/storage/database_size.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -struct DatabaseSize { - idx_t total_blocks = 0; - idx_t block_size = 0; - idx_t free_blocks = 0; - idx_t used_blocks = 0; - idx_t bytes = 0; - idx_t wal_size = 0; -}; - -struct MetadataBlockInfo { - block_id_t block_id; - idx_t total_blocks; - vector free_list; -}; - -} // namespace duckdb 
-//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/duck_index_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -//! An index catalog entry -class DuckIndexEntry : public IndexCatalogEntry { -public: - //! Create an IndexCatalogEntry and initialize storage for it - DuckIndexEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &info); - ~DuckIndexEntry(); - - shared_ptr info; - -public: - string GetSchemaName() const override; - string GetTableName() const override; - //! This drops in-memory index data and marks all blocks on disk as free blocks, allowing to reclaim them - void CommitDrop(); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/execution/index/art/art.hpp -// -// -//===----------------------------------------------------------------------===// - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/execution/index/art/node.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/execution/index/index_pointer.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -class IndexPointer { -public: - //! Bit-shifting - static constexpr idx_t SHIFT_OFFSET = 32; - static constexpr idx_t SHIFT_METADATA = 56; - //! AND operations - static constexpr idx_t AND_OFFSET = 0x0000000000FFFFFF; - static constexpr idx_t AND_BUFFER_ID = 0x00000000FFFFFFFF; - static constexpr idx_t AND_METADATA = 0xFF00000000000000; - -public: - //! Constructs an empty IndexPointer - IndexPointer() : data(0) {}; - //! 
Constructs an in-memory IndexPointer with a buffer ID and an offset - IndexPointer(const uint32_t buffer_id, const uint32_t offset) : data(0) { - auto shifted_offset = ((idx_t)offset) << SHIFT_OFFSET; - data += shifted_offset; - data += buffer_id; - }; - -public: - //! Get data (all 64 bits) - inline idx_t Get() const { - return data; - } - //! Set data (all 64 bits) - inline void Set(const idx_t data_p) { - data = data_p; - } - - //! Returns false, if the metadata is empty - inline bool HasMetadata() const { - return data & AND_METADATA; - } - //! Get metadata (zero to 7th bit) - inline uint8_t GetMetadata() const { - return data >> SHIFT_METADATA; - } - //! Set metadata (zero to 7th bit) - inline void SetMetadata(const uint8_t metadata) { - data += (idx_t)metadata << SHIFT_METADATA; - } - - //! Get the offset (8th to 23rd bit) - inline idx_t GetOffset() const { - auto offset = data >> SHIFT_OFFSET; - return offset & AND_OFFSET; - } - //! Get the buffer ID (24th to 63rd bit) - inline idx_t GetBufferId() const { - return data & AND_BUFFER_ID; - } - - //! Resets the IndexPointer - inline void Clear() { - data = 0; - } - - //! Adds an idx_t to a buffer ID, the rightmost 32 bits of data contain the buffer ID - inline void IncreaseBufferId(const idx_t summand) { - data += summand; - } - - //! Comparison operator - inline bool operator==(const IndexPointer &ptr) const { - return data == ptr.data; - } - -private: - //! Data holds all the information contained in an IndexPointer - //! [0 - 7: metadata, - //! 8 - 23: offset, 24 - 63: buffer ID] - //! NOTE: we do not use bit fields because when using bit fields Windows compiles - //! the IndexPointer class into 16 bytes instead of the intended 8 bytes, doubling the - //! space requirements - //! 
https://learn.microsoft.com/en-us/cpp/cpp/cpp-bit-fields?view=msvc-170 - idx_t data; -}; - -static_assert(sizeof(IndexPointer) == sizeof(idx_t), "Invalid size for IndexPointer."); - -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/execution/index/fixed_size_allocator.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/execution/index/fixed_size_buffer.hpp -// -// -//===----------------------------------------------------------------------===// - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/storage/partial_block_manager.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/storage/storage_manager.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/storage/table_io_manager.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { -class BlockManager; -class DataTable; -class MetadataManager; - -class TableIOManager { -public: - virtual ~TableIOManager() { - } - - //! Obtains a reference to the TableIOManager of a specific table - static TableIOManager &Get(DataTable &table); - - //! The block manager used for managing index data - virtual BlockManager &GetIndexBlockManager() = 0; - - //! 
The block manager used for storing row group data - virtual BlockManager &GetBlockManagerForRowData() = 0; - - virtual MetadataManager &GetMetadataManager() = 0; -}; - -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/storage/write_ahead_log.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/enums/wal_type.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -enum class WALType : uint8_t { - INVALID = 0, - // ----------------------------- - // Catalog - // ----------------------------- - CREATE_TABLE = 1, - DROP_TABLE = 2, - - CREATE_SCHEMA = 3, - DROP_SCHEMA = 4, - - CREATE_VIEW = 5, - DROP_VIEW = 6, - - CREATE_SEQUENCE = 8, - DROP_SEQUENCE = 9, - SEQUENCE_VALUE = 10, - - CREATE_MACRO = 11, - DROP_MACRO = 12, - - CREATE_TYPE = 13, - DROP_TYPE = 14, - - ALTER_INFO = 20, - - CREATE_TABLE_MACRO = 21, - DROP_TABLE_MACRO = 22, - - CREATE_INDEX = 23, - DROP_INDEX = 24, - - // ----------------------------- - // Data - // ----------------------------- - USE_TABLE = 25, - INSERT_TUPLE = 26, - DELETE_TUPLE = 27, - UPDATE_TUPLE = 28, - // ----------------------------- - // Flush - // ----------------------------- - CHECKPOINT = 99, - WAL_FLUSH = 100 -}; -} - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/macro_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -//! 
A macro function in the catalog -class TableMacroCatalogEntry : public MacroCatalogEntry { -public: - static constexpr const CatalogType Type = CatalogType::TABLE_MACRO_ENTRY; - static constexpr const char *Name = "table macro function"; - -public: - TableMacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info); -}; - -} // namespace duckdb - - - - - -namespace duckdb { - -struct AlterInfo; - -class AttachedDatabase; -class Catalog; -class DatabaseInstance; -class SchemaCatalogEntry; -class SequenceCatalogEntry; -class ScalarMacroCatalogEntry; -class ViewCatalogEntry; -class TypeCatalogEntry; -class TableCatalogEntry; -class Transaction; -class TransactionManager; - -class ReplayState { -public: - ReplayState(AttachedDatabase &db, ClientContext &context) - : db(db), context(context), catalog(db.GetCatalog()), deserialize_only(false) { - } - - AttachedDatabase &db; - ClientContext &context; - Catalog &catalog; - optional_ptr current_table; - bool deserialize_only; - MetaBlockPointer checkpoint_id; - -public: - void ReplayEntry(WALType entry_type, BinaryDeserializer &deserializer); - -protected: - virtual void ReplayCreateTable(BinaryDeserializer &deserializer); - void ReplayDropTable(BinaryDeserializer &deserializer); - void ReplayAlter(BinaryDeserializer &deserializer); - - void ReplayCreateView(BinaryDeserializer &deserializer); - void ReplayDropView(BinaryDeserializer &deserializer); - - void ReplayCreateSchema(BinaryDeserializer &deserializer); - void ReplayDropSchema(BinaryDeserializer &deserializer); - - void ReplayCreateType(BinaryDeserializer &deserializer); - void ReplayDropType(BinaryDeserializer &deserializer); - - void ReplayCreateSequence(BinaryDeserializer &deserializer); - void ReplayDropSequence(BinaryDeserializer &deserializer); - void ReplaySequenceValue(BinaryDeserializer &deserializer); - - void ReplayCreateMacro(BinaryDeserializer &deserializer); - void ReplayDropMacro(BinaryDeserializer &deserializer); - - void 
ReplayCreateTableMacro(BinaryDeserializer &deserializer); - void ReplayDropTableMacro(BinaryDeserializer &deserializer); - - void ReplayCreateIndex(BinaryDeserializer &deserializer); - void ReplayDropIndex(BinaryDeserializer &deserializer); - - void ReplayUseTable(BinaryDeserializer &deserializer); - void ReplayInsert(BinaryDeserializer &deserializer); - void ReplayDelete(BinaryDeserializer &deserializer); - void ReplayUpdate(BinaryDeserializer &deserializer); - void ReplayCheckpoint(BinaryDeserializer &deserializer); -}; - -//! The WriteAheadLog (WAL) is a log that is used to provide durability. Prior -//! to committing a transaction it writes the changes the transaction made to -//! the database to the log, which can then be replayed upon startup in case the -//! server crashes or is shut down. -class WriteAheadLog { -public: - //! Initialize the WAL in the specified directory - explicit WriteAheadLog(AttachedDatabase &database, const string &path); - virtual ~WriteAheadLog(); - - //! Skip writing to the WAL - bool skip_writing; - -public: - //! Replay the WAL - static bool Replay(AttachedDatabase &database, string &path); - - //! Returns the current size of the WAL in bytes - int64_t GetWALSize(); - //! 
Gets the total bytes written to the WAL since startup - idx_t GetTotalWritten(); - - virtual void WriteCreateTable(const TableCatalogEntry &entry); - void WriteDropTable(const TableCatalogEntry &entry); - - void WriteCreateSchema(const SchemaCatalogEntry &entry); - void WriteDropSchema(const SchemaCatalogEntry &entry); - - void WriteCreateView(const ViewCatalogEntry &entry); - void WriteDropView(const ViewCatalogEntry &entry); - - void WriteCreateSequence(const SequenceCatalogEntry &entry); - void WriteDropSequence(const SequenceCatalogEntry &entry); - void WriteSequenceValue(const SequenceCatalogEntry &entry, SequenceValue val); - - void WriteCreateMacro(const ScalarMacroCatalogEntry &entry); - void WriteDropMacro(const ScalarMacroCatalogEntry &entry); - - void WriteCreateTableMacro(const TableMacroCatalogEntry &entry); - void WriteDropTableMacro(const TableMacroCatalogEntry &entry); - - void WriteCreateIndex(const IndexCatalogEntry &entry); - void WriteDropIndex(const IndexCatalogEntry &entry); - - void WriteCreateType(const TypeCatalogEntry &entry); - void WriteDropType(const TypeCatalogEntry &entry); - //! Sets the table used for subsequent insert/delete/update commands - void WriteSetTable(string &schema, string &table); - - void WriteAlter(const AlterInfo &info); - - void WriteInsert(DataChunk &chunk); - void WriteDelete(DataChunk &chunk); - //! Write a single (sub-) column update to the WAL. Chunk must be a pair of (COL, ROW_ID). - //! The column_path vector is a *path* towards a column within the table - //! i.e. if we have a table with a single column S STRUCT(A INT, B INT) - //! and we update the validity mask of "S.B" - //! the column path is: - //! 0 (first column of table) - //! -> 1 (second subcolumn of struct) - //! -> 0 (first subcolumn of INT) - void WriteUpdate(DataChunk &chunk, const vector &column_path); - - //! Truncate the WAL to a previous size, and clear anything currently set in the writer - void Truncate(int64_t size); - //! 
Delete the WAL file on disk. The WAL should not be used after this point. - void Delete(); - void Flush(); - - void WriteCheckpoint(MetaBlockPointer meta_block); - -protected: - AttachedDatabase &database; - unique_ptr writer; - string wal_path; -}; - -} // namespace duckdb - - - -namespace duckdb { -class BlockManager; -class Catalog; -class CheckpointWriter; -class DatabaseInstance; -class TransactionManager; -class TableCatalogEntry; - -class StorageCommitState { -public: - // Destruction of this object, without prior call to FlushCommit, - // will roll back the committed changes. - virtual ~StorageCommitState() { - } - - // Make the commit persistent - virtual void FlushCommit() = 0; -}; - -//! StorageManager is responsible for managing the physical storage of the -//! database on disk -class StorageManager { -public: - StorageManager(AttachedDatabase &db, string path, bool read_only); - virtual ~StorageManager(); - -public: - static StorageManager &Get(AttachedDatabase &db); - static StorageManager &Get(Catalog &catalog); - - //! Initialize a database or load an existing database from the given path - void Initialize(); - - DatabaseInstance &GetDatabase(); - AttachedDatabase &GetAttached() { - return db; - } - - //! Get the WAL of the StorageManager, returns nullptr if in-memory - optional_ptr GetWriteAheadLog() { - return wal.get(); - } - - string GetDBPath() { - return path; - } - bool InMemory(); - - virtual bool AutomaticCheckpoint(idx_t estimated_wal_bytes) = 0; - virtual unique_ptr GenStorageCommitState(Transaction &transaction, bool checkpoint) = 0; - virtual bool IsCheckpointClean(MetaBlockPointer checkpoint_id) = 0; - virtual void CreateCheckpoint(bool delete_wal = false, bool force_checkpoint = false) = 0; - virtual DatabaseSize GetDatabaseSize() = 0; - virtual vector GetMetadataInfo() = 0; - virtual shared_ptr GetTableIOManager(BoundCreateTableInfo *info) = 0; - -protected: - virtual void LoadDatabase() = 0; - -protected: - //! 
The database this storagemanager belongs to - AttachedDatabase &db; - //! The path of the database - string path; - //! The WriteAheadLog of the storage manager - unique_ptr wal; - //! Whether or not the database is opened in read-only mode - bool read_only; - -public: - template - TARGET &Cast() { - D_ASSERT(dynamic_cast(this)); - return reinterpret_cast(*this); - } - template - const TARGET &Cast() const { - D_ASSERT(dynamic_cast(this)); - return reinterpret_cast(*this); - } -}; - -//! Stores database in a single file. -class SingleFileStorageManager : public StorageManager { -public: - SingleFileStorageManager(AttachedDatabase &db, string path, bool read_only); - - //! The BlockManager to read/store meta information and data in blocks - unique_ptr block_manager; - //! TableIoManager - unique_ptr table_io_manager; - -public: - bool AutomaticCheckpoint(idx_t estimated_wal_bytes) override; - unique_ptr GenStorageCommitState(Transaction &transaction, bool checkpoint) override; - bool IsCheckpointClean(MetaBlockPointer checkpoint_id) override; - void CreateCheckpoint(bool delete_wal, bool force_checkpoint) override; - DatabaseSize GetDatabaseSize() override; - vector GetMetadataInfo() override; - shared_ptr GetTableIOManager(BoundCreateTableInfo *info) override; - -protected: - void LoadDatabase() override; -}; -} // namespace duckdb - - - - -namespace duckdb { -class DatabaseInstance; -class ClientContext; -class ColumnSegment; -class MetadataReader; -class SchemaCatalogEntry; -class SequenceCatalogEntry; -class TableCatalogEntry; -class ViewCatalogEntry; -class TypeCatalogEntry; - -//! Regions that require zero-initialization to avoid leaking memory -struct UninitializedRegion { - idx_t start; - idx_t end; -}; - -//! The current state of a partial block -struct PartialBlockState { - //! The block id of the partial block - block_id_t block_id; - //! The total bytes that we can assign to this block - uint32_t block_size; - //! 
Next allocation offset, and also the current allocation size - uint32_t offset; - //! The number of times that this block has been used for partial allocations - uint32_t block_use_count; -}; - -struct PartialBlock { - PartialBlock(PartialBlockState state, BlockManager &block_manager, const shared_ptr &block_handle); - virtual ~PartialBlock() { - } - - //! The current state of a partial block - PartialBlockState state; - //! All uninitialized regions on this block, we need to zero-initialize them when flushing - vector uninitialized_regions; - //! The block manager of the partial block manager - BlockManager &block_manager; - //! The block handle of the underlying block that this partial block writes to - shared_ptr block_handle; - -public: - //! Add regions that need zero-initialization to avoid leaking memory - void AddUninitializedRegion(const idx_t start, const idx_t end); - //! Flush the block to disk and zero-initialize any free space and uninitialized regions - virtual void Flush(const idx_t free_space_left) = 0; - void FlushInternal(const idx_t free_space_left); - virtual void Merge(PartialBlock &other, idx_t offset, idx_t other_size) = 0; - virtual void Clear() = 0; - -public: - template - TARGET &Cast() { - D_ASSERT(dynamic_cast(this)); - return reinterpret_cast(*this); - } -}; - -struct PartialBlockAllocation { - //! The BlockManager owning the block_id - BlockManager *block_manager {nullptr}; - //! The number of assigned bytes to the caller - uint32_t allocation_size; - //! The current state of the partial block - PartialBlockState state; - //! Arbitrary state related to the partial block storage - unique_ptr partial_block; -}; - -enum class CheckpointType { FULL_CHECKPOINT, APPEND_TO_TABLE }; - -//! Enables sharing blocks across some scope. Scope is whatever we want to share -//! blocks across. It may be an entire checkpoint or just a single row group. -//! In any case, they must share a block manager. -class PartialBlockManager { -public: - //! 
20% free / 80% utilization - static constexpr const idx_t DEFAULT_MAX_PARTIAL_BLOCK_SIZE = Storage::BLOCK_SIZE / 5 * 4; - //! Max number of shared references to a block. No effective limit by default. - static constexpr const idx_t DEFAULT_MAX_USE_COUNT = 1u << 20; - //! No point letting map size grow unbounded. We'll drop blocks with the - //! least free space first. - static constexpr const idx_t MAX_BLOCK_MAP_SIZE = 1u << 31; - -public: - PartialBlockManager(BlockManager &block_manager, CheckpointType checkpoint_type, - uint32_t max_partial_block_size = DEFAULT_MAX_PARTIAL_BLOCK_SIZE, - uint32_t max_use_count = DEFAULT_MAX_USE_COUNT); - virtual ~PartialBlockManager(); - -public: - //! Flush any remaining partial blocks to disk - void FlushPartialBlocks(); - - PartialBlockAllocation GetBlockAllocation(uint32_t segment_size); - - virtual void AllocateBlock(PartialBlockState &state, uint32_t segment_size); - - void Merge(PartialBlockManager &other); - //! Register a partially filled block that is filled with "segment_size" entries - void RegisterPartialBlock(PartialBlockAllocation &&allocation); - - //! Clear remaining blocks without writing them to disk - void ClearBlocks(); - - //! Rollback all data written by this partial block manager - void Rollback(); - -protected: - BlockManager &block_manager; - CheckpointType checkpoint_type; - //! A map of (available space -> PartialBlock) for partially filled blocks - //! This is a multimap because there might be outstanding partial blocks with - //! the same amount of left-over space - multimap> partially_filled_blocks; - //! The set of written blocks - unordered_set written_blocks; - - //! The maximum size (in bytes) at which a partial block will be considered a partial block - uint32_t max_partial_block_size; - uint32_t max_use_count; - -protected: - //! Try to obtain a partially filled block that can fit "segment_size" bytes - //! 
If successful, returns true and returns the block_id and offset_in_block to write to - //! Otherwise, returns false - bool GetPartialBlock(idx_t segment_size, unique_ptr &state); - - bool HasBlockAllocation(uint32_t segment_size); - void AddWrittenBlock(block_id_t block); -}; - -} // namespace duckdb - - - - - - -namespace duckdb { - -class FixedSizeAllocator; -class MetadataWriter; - -struct PartialBlockForIndex : public PartialBlock { -public: - PartialBlockForIndex(PartialBlockState state, BlockManager &block_manager, - const shared_ptr &block_handle); - ~PartialBlockForIndex() override {}; - -public: - void Flush(const idx_t free_space_left) override; - void Clear() override; - void Merge(PartialBlock &other, idx_t offset, idx_t other_size) override; -}; - -//! A fixed-size buffer holds fixed-size segments of data. It lazily deserializes a buffer, if on-disk and not -//! yet in memory, and it only serializes dirty and non-written buffers to disk during -//! serialization. -class FixedSizeBuffer { -public: - //! Constants for fast offset calculations in the bitmask - static constexpr idx_t BASE[] = {0x00000000FFFFFFFF, 0x0000FFFF, 0x00FF, 0x0F, 0x3, 0x1}; - static constexpr uint8_t SHIFT[] = {32, 16, 8, 4, 2, 1}; - -public: - //! Constructor for a new in-memory buffer - explicit FixedSizeBuffer(BlockManager &block_manager); - //! Constructor for deserializing buffer metadata from disk - FixedSizeBuffer(BlockManager &block_manager, const idx_t segment_count, const idx_t allocation_size, - const BlockPointer &block_pointer); - - //! Block manager of the database instance - BlockManager &block_manager; - - //! The number of allocated segments - idx_t segment_count; - //! The size of allocated memory in this buffer (necessary for copying while pinning) - idx_t allocation_size; - - //! True: the in-memory buffer is no longer consistent with a (possibly existing) copy on disk - bool dirty; - //! True: can be vacuumed after the vacuum operation - bool vacuum; - - //! 
Partial block id and offset - BlockPointer block_pointer; - -public: - //! Returns true, if the buffer is in-memory - inline bool InMemory() const { - return buffer_handle.IsValid(); - } - //! Returns true, if the block is on-disk - inline bool OnDisk() const { - return block_pointer.IsValid(); - } - //! Returns a pointer to the buffer in memory, and calls Deserialize, if the buffer is not in memory - inline data_ptr_t Get(const bool dirty_p = true) { - if (!InMemory()) { - Pin(); - } - if (dirty_p) { - dirty = dirty_p; - } - return buffer_handle.Ptr(); - } - //! Destroys the in-memory buffer and the on-disk block - void Destroy(); - //! Serializes a buffer (if dirty or not on disk) - void Serialize(PartialBlockManager &partial_block_manager, const idx_t available_segments, const idx_t segment_size, - const idx_t bitmask_offset); - //! Pin a buffer (if not in-memory) - void Pin(); - //! Returns the first free offset in a bitmask - uint32_t GetOffset(const idx_t bitmask_count); - -private: - //! The buffer handle of the in-memory buffer - BufferHandle buffer_handle; - //! The block handle of the on-disk buffer - shared_ptr block_handle; - -private: - //! Returns the maximum non-free offset in a bitmask - uint32_t GetMaxOffset(const idx_t available_segments_per_buffer); - //! Sets all uninitialized regions of a buffer in the respective partial block allocation - void SetUninitializedRegions(PartialBlockForIndex &p_block_for_index, const idx_t segment_size, const idx_t offset, - const idx_t bitmask_offset); -}; - -} // namespace duckdb - - - - - - -namespace duckdb { - -//! The FixedSizeAllocator provides pointers to fixed-size memory segments of pre-allocated memory buffers. -//! The pointers are IndexPointers, and the leftmost byte (metadata) must always be zero. -//! It is also possible to directly request a C++ pointer to the underlying segment of an index pointer. -class FixedSizeAllocator { -public: - //! 
We can vacuum 10% or more of the total in-memory footprint - static constexpr uint8_t VACUUM_THRESHOLD = 10; - -public: - FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager); - - //! Block manager of the database instance - BlockManager &block_manager; - //! Buffer manager of the database instance - BufferManager &buffer_manager; - //! Metadata manager for (de)serialization - MetadataManager &metadata_manager; - -public: - //! Get a new IndexPointer to a segment, might cause a new buffer allocation - IndexPointer New(); - //! Free the segment of the IndexPointer - void Free(const IndexPointer ptr); - //! Returns a pointer of type T to a segment. If dirty is false, then T should be a const class - template - inline T *Get(const IndexPointer ptr, const bool dirty = true) { - return (T *)Get(ptr, dirty); - } - - //! Resets the allocator, e.g., during 'DELETE FROM table' - void Reset(); - - //! Returns the in-memory usage in bytes - inline idx_t GetMemoryUsage() const; - - //! Returns the upper bound of the available buffer IDs, i.e., upper_bound > max_buffer_id - idx_t GetUpperBoundBufferId() const; - //! Merge another FixedSizeAllocator into this allocator. Both must have the same segment size - void Merge(FixedSizeAllocator &other); - - //! Initialize a vacuum operation, and return true, if the allocator needs a vacuum - bool InitializeVacuum(); - //! Finalize a vacuum operation by freeing all vacuumed buffers - void FinalizeVacuum(); - //! Returns true, if an IndexPointer qualifies for a vacuum operation, and false otherwise - inline bool NeedsVacuum(const IndexPointer ptr) const { - if (vacuum_buffers.find(ptr.GetBufferId()) != vacuum_buffers.end()) { - return true; - } - return false; - } - //! Vacuums an IndexPointer - IndexPointer VacuumPointer(const IndexPointer ptr); - - //! Serializes all in-memory buffers and the metadata - BlockPointer Serialize(PartialBlockManager &partial_block_manager, MetadataWriter &writer); - //! 
Deserializes all metadata - void Deserialize(const BlockPointer &block_pointer); - -private: - //! Allocation size of one segment in a buffer - //! We only need this value to calculate bitmask_count, bitmask_offset, and - //! available_segments_per_buffer - idx_t segment_size; - - //! Number of validity_t values in the bitmask - idx_t bitmask_count; - //! First starting byte of the payload (segments) - idx_t bitmask_offset; - //! Number of possible segment allocations per buffer - idx_t available_segments_per_buffer; - - //! Total number of allocated segments in all buffers - //! We can recalculate this by iterating over all buffers - idx_t total_segment_count; - - //! Buffers containing the segments - unordered_map buffers; - //! Buffers with free space - unordered_set buffers_with_free_space; - //! Buffers qualifying for a vacuum (helper field to allow for fast NeedsVacuum checks) - unordered_set vacuum_buffers; - -private: - //! Returns the data_ptr_t to a segment, and sets the dirty flag of the buffer containing that segment - inline data_ptr_t Get(const IndexPointer ptr, const bool dirty = true) { - D_ASSERT(ptr.GetOffset() < available_segments_per_buffer); - D_ASSERT(buffers.find(ptr.GetBufferId()) != buffers.end()); - auto &buffer = buffers.find(ptr.GetBufferId())->second; - auto buffer_ptr = buffer.Get(dirty); - return buffer_ptr + ptr.GetOffset() * segment_size + bitmask_offset; - } - //! Returns an available buffer id - idx_t GetAvailableBufferId() const; -}; - -} // namespace duckdb - - -namespace duckdb { - -// classes -enum class NType : uint8_t { - PREFIX = 1, - LEAF = 2, - NODE_4 = 3, - NODE_16 = 4, - NODE_48 = 5, - NODE_256 = 6, - LEAF_INLINED = 7, -}; - -class ART; -class Prefix; -class MetadataReader; -class MetadataWriter; - -// structs -struct BlockPointer; -struct ARTFlags; -struct MetaBlockPointer; - -//! The Node is the pointer class of the ART index. -//! 
It inherits from the IndexPointer, and adds ART-specific functionality -class Node : public IndexPointer { -public: - //! Node thresholds - static constexpr uint8_t NODE_48_SHRINK_THRESHOLD = 12; - static constexpr uint8_t NODE_256_SHRINK_THRESHOLD = 36; - //! Node sizes - static constexpr uint8_t NODE_4_CAPACITY = 4; - static constexpr uint8_t NODE_16_CAPACITY = 16; - static constexpr uint8_t NODE_48_CAPACITY = 48; - static constexpr uint16_t NODE_256_CAPACITY = 256; - //! Other constants - static constexpr uint8_t EMPTY_MARKER = 48; - static constexpr uint8_t LEAF_SIZE = 4; - static constexpr uint8_t PREFIX_SIZE = 15; - static constexpr idx_t AND_ROW_ID = 0x00FFFFFFFFFFFFFF; - -public: - //! Get a new pointer to a node, might cause a new buffer allocation, and initialize it - static void New(ART &art, Node &node, const NType type); - //! Free the node (and its subtree) - static void Free(ART &art, Node &node); - - //! Get references to the allocator - static FixedSizeAllocator &GetAllocator(const ART &art, const NType type); - //! Get a (immutable) reference to the node. If dirty is false, then T should be a const class - template - static inline const NODE &Ref(const ART &art, const Node ptr, const NType type) { - return *(GetAllocator(art, type).Get(ptr, false)); - } - //! Get a (const) reference to the node. If dirty is false, then T should be a const class - template - static inline NODE &RefMutable(const ART &art, const Node ptr, const NType type) { - return *(GetAllocator(art, type).Get(ptr)); - } - - //! Replace the child node at byte - void ReplaceChild(const ART &art, const uint8_t byte, const Node child) const; - //! Insert the child node at byte - static void InsertChild(ART &art, Node &node, const uint8_t byte, const Node child); - //! Delete the child node at byte - static void DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte); - - //! 
Get the child (immutable) for the respective byte in the node - optional_ptr GetChild(ART &art, const uint8_t byte) const; - //! Get the child for the respective byte in the node - optional_ptr GetChildMutable(ART &art, const uint8_t byte) const; - //! Get the first child (immutable) that is greater or equal to the specific byte - optional_ptr GetNextChild(ART &art, uint8_t &byte) const; - //! Get the first child that is greater or equal to the specific byte - optional_ptr GetNextChildMutable(ART &art, uint8_t &byte) const; - - //! Returns the string representation of the node, or only traverses and verifies the node and its subtree - string VerifyAndToString(ART &art, const bool only_verify) const; - //! Returns the capacity of the node - idx_t GetCapacity() const; - //! Returns the matching node type for a given count - static NType GetARTNodeTypeByCount(const idx_t count); - - //! Initializes a merge by incrementing the buffer IDs of a node and its subtree - void InitializeMerge(ART &art, const ARTFlags &flags); - //! Merge another node into this node - bool Merge(ART &art, Node &other); - //! Merge two nodes by first resolving their prefixes - bool ResolvePrefixes(ART &art, Node &other); - //! Merge two nodes that have no prefix or the same prefix - bool MergeInternal(ART &art, Node &other); - - //! Vacuum all nodes that exceed their respective vacuum thresholds - void Vacuum(ART &art, const ARTFlags &flags); - - //! Get the row ID (8th to 63rd bit) - inline row_t GetRowId() const { - return Get() & AND_ROW_ID; - } - //! Set the row ID (8th to 63rd bit) - inline void SetRowId(const row_t row_id) { - Set((Get() & AND_METADATA) | row_id); - } - - //! Returns the type of the node, which is held in the metadata - inline NType GetType() const { - return NType(GetMetadata()); - } - - //! 
Assign operator - inline void operator=(const IndexPointer &ptr) { - Set(ptr.Get()); - } -}; -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/array.hpp -// -// -//===----------------------------------------------------------------------===// - - - -#include - -namespace duckdb { -using std::array; -} - - -namespace duckdb { - -// classes -enum class VerifyExistenceType : uint8_t { - APPEND = 0, // appends to a table - APPEND_FK = 1, // appends to a table that has a foreign key - DELETE_FK = 2 // delete from a table that has a foreign key -}; -class ConflictManager; -class ARTKey; -class FixedSizeAllocator; - -// structs -struct ARTIndexScanState; -struct ARTFlags { - vector vacuum_flags; - vector merge_buffer_counts; -}; - -class ART : public Index { -public: - //! FixedSizeAllocator count of the ART - static constexpr uint8_t ALLOCATOR_COUNT = 6; - -public: - //! Constructs an ART - ART(const vector &column_ids, TableIOManager &table_io_manager, - const vector> &unbound_expressions, const IndexConstraintType constraint_type, - AttachedDatabase &db, - const shared_ptr, ALLOCATOR_COUNT>> &allocators_ptr = nullptr, - const BlockPointer &block = BlockPointer()); - - //! Root of the tree - Node tree = Node(); - //! Fixed-size allocators holding the ART nodes - shared_ptr, ALLOCATOR_COUNT>> allocators; - //! True, if the ART owns its data - bool owns_data; - -public: - //! Initialize a single predicate scan on the index with the given expression and column IDs - unique_ptr InitializeScanSinglePredicate(const Transaction &transaction, const Value &value, - const ExpressionType expression_type) override; - //! 
Initialize a two predicate scan on the index with the given expression and column IDs - unique_ptr InitializeScanTwoPredicates(const Transaction &transaction, const Value &low_value, - const ExpressionType low_expression_type, - const Value &high_value, - const ExpressionType high_expression_type) override; - //! Performs a lookup on the index, fetching up to max_count result IDs. Returns true if all row IDs were fetched, - //! and false otherwise - bool Scan(const Transaction &transaction, const DataTable &table, IndexScanState &state, const idx_t max_count, - vector &result_ids) override; - - //! Called when data is appended to the index. The lock obtained from InitializeLock must be held - PreservedError Append(IndexLock &lock, DataChunk &entries, Vector &row_identifiers) override; - //! Verify that data can be appended to the index without a constraint violation - void VerifyAppend(DataChunk &chunk) override; - //! Verify that data can be appended to the index without a constraint violation using the conflict manager - void VerifyAppend(DataChunk &chunk, ConflictManager &conflict_manager) override; - //! Deletes all data from the index. The lock obtained from InitializeLock must be held - void CommitDrop(IndexLock &index_lock) override; - //! Delete a chunk of entries from the index. The lock obtained from InitializeLock must be held - void Delete(IndexLock &lock, DataChunk &entries, Vector &row_identifiers) override; - //! Insert a chunk of entries into the index - PreservedError Insert(IndexLock &lock, DataChunk &data, Vector &row_ids) override; - - //! Construct an ART from a vector of sorted keys - bool ConstructFromSorted(idx_t count, vector &keys, Vector &row_identifiers); - - //! Search equal values and fetches the row IDs - bool SearchEqual(ARTKey &key, idx_t max_count, vector &result_ids); - //! Search equal values used for joins that do not need to fetch data - void SearchEqualJoinNoFetch(ARTKey &key, idx_t &result_size); - - //! 
Serializes the index and returns the pair of block_id offset positions - BlockPointer Serialize(MetadataWriter &writer) override; - - //! Merge another index into this index. The lock obtained from InitializeLock must be held, and the other - //! index must also be locked during the merge - bool MergeIndexes(IndexLock &state, Index &other_index) override; - - //! Traverses an ART and vacuums the qualifying nodes. The lock obtained from InitializeLock must be held - void Vacuum(IndexLock &state) override; - - //! Generate ART keys for an input chunk - static void GenerateKeys(ArenaAllocator &allocator, DataChunk &input, vector &keys); - - //! Generate a string containing all the expressions and their respective values that violate a constraint - string GenerateErrorKeyName(DataChunk &input, idx_t row); - //! Generate the matching error message for a constraint violation - string GenerateConstraintErrorMessage(VerifyExistenceType verify_type, const string &key_name); - //! Performs constraint checking for a chunk of input data - void CheckConstraintsForChunk(DataChunk &input, ConflictManager &conflict_manager) override; - - //! Returns the string representation of the ART, or only traverses and verifies the index - string VerifyAndToString(IndexLock &state, const bool only_verify) override; - - //! Find the node with a matching key, or return nullptr if not found - optional_ptr Lookup(const Node &node, const ARTKey &key, idx_t depth); - //! Insert a key into the tree - bool Insert(Node &node, const ARTKey &key, idx_t depth, const row_t &row_id); - -private: - //! Insert a row ID into a leaf - bool InsertToLeaf(Node &leaf, const row_t &row_id); - //! Erase a key from the tree (if a leaf has more than one value) or erase the leaf itself - void Erase(Node &node, const ARTKey &key, idx_t depth, const row_t &row_id); - - //! 
Returns all row IDs belonging to a key greater (or equal) than the search key - bool SearchGreater(ARTIndexScanState &state, ARTKey &key, bool equal, idx_t max_count, vector &result_ids); - //! Returns all row IDs belonging to a key less (or equal) than the upper_bound - bool SearchLess(ARTIndexScanState &state, ARTKey &upper_bound, bool equal, idx_t max_count, - vector &result_ids); - //! Returns all row IDs belonging to a key within the range of lower_bound and upper_bound - bool SearchCloseRange(ARTIndexScanState &state, ARTKey &lower_bound, ARTKey &upper_bound, bool left_equal, - bool right_equal, idx_t max_count, vector &result_ids); - - //! Initializes a merge operation by returning a set containing the buffer count of each fixed-size allocator - void InitializeMerge(ARTFlags &flags); - - //! Initializes a vacuum operation by calling the initialize operation of the respective - //! node allocator, and returns a vector containing either true, if the allocator at - //! the respective position qualifies, or false, if not - void InitializeVacuum(ARTFlags &flags); - //! Finalizes a vacuum operation by calling the finalize operation of all qualifying - //! fixed size allocators - void FinalizeVacuum(const ARTFlags &flags); - - //! Internal function to return the string representation of the ART, - //! or only traverses and verifies the index - string VerifyAndToStringInternal(const bool only_verify); - - //! Deserialize the allocators of the ART - void Deserialize(const BlockPointer &pointer); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/dschema_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -//! 
A schema in the catalog -class DuckSchemaEntry : public SchemaCatalogEntry { -public: - DuckSchemaEntry(Catalog &catalog, string name, bool is_internal); - -private: - //! The catalog set holding the tables - CatalogSet tables; - //! The catalog set holding the indexes - CatalogSet indexes; - //! The catalog set holding the table functions - CatalogSet table_functions; - //! The catalog set holding the copy functions - CatalogSet copy_functions; - //! The catalog set holding the pragma functions - CatalogSet pragma_functions; - //! The catalog set holding the scalar and aggregate functions - CatalogSet functions; - //! The catalog set holding the sequences - CatalogSet sequences; - //! The catalog set holding the collations - CatalogSet collations; - //! The catalog set holding the types - CatalogSet types; - -public: - optional_ptr AddEntry(CatalogTransaction transaction, unique_ptr entry, - OnCreateConflict on_conflict); - optional_ptr AddEntryInternal(CatalogTransaction transaction, unique_ptr entry, - OnCreateConflict on_conflict, DependencyList dependencies); - - optional_ptr CreateTable(CatalogTransaction transaction, BoundCreateTableInfo &info) override; - optional_ptr CreateFunction(CatalogTransaction transaction, CreateFunctionInfo &info) override; - optional_ptr CreateIndex(ClientContext &context, CreateIndexInfo &info, - TableCatalogEntry &table) override; - optional_ptr CreateView(CatalogTransaction transaction, CreateViewInfo &info) override; - optional_ptr CreateSequence(CatalogTransaction transaction, CreateSequenceInfo &info) override; - optional_ptr CreateTableFunction(CatalogTransaction transaction, - CreateTableFunctionInfo &info) override; - optional_ptr CreateCopyFunction(CatalogTransaction transaction, - CreateCopyFunctionInfo &info) override; - optional_ptr CreatePragmaFunction(CatalogTransaction transaction, - CreatePragmaFunctionInfo &info) override; - optional_ptr CreateCollation(CatalogTransaction transaction, CreateCollationInfo &info) 
override; - optional_ptr CreateType(CatalogTransaction transaction, CreateTypeInfo &info) override; - void Alter(ClientContext &context, AlterInfo &info) override; - void Scan(ClientContext &context, CatalogType type, const std::function &callback) override; - void Scan(CatalogType type, const std::function &callback) override; - void DropEntry(ClientContext &context, DropInfo &info) override; - optional_ptr GetEntry(CatalogTransaction transaction, CatalogType type, const string &name) override; - SimilarCatalogEntry GetSimilarEntry(CatalogTransaction transaction, CatalogType type, const string &name) override; - - void Verify(Catalog &catalog) override; - -private: - //! Get the catalog set for the specified type - CatalogSet &GetCatalogSet(CatalogType type); -}; -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/default/default_functions.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { -class SchemaCatalogEntry; - -struct DefaultMacro { - const char *schema; - const char *name; - const char *parameters[8]; - const char *macro; -}; - -class DefaultFunctionGenerator : public DefaultGenerator { -public: - DefaultFunctionGenerator(Catalog &catalog, SchemaCatalogEntry &schema); - - SchemaCatalogEntry &schema; - - DUCKDB_API static unique_ptr CreateInternalMacroInfo(DefaultMacro &default_macro); - DUCKDB_API static unique_ptr CreateInternalTableMacroInfo(DefaultMacro &default_macro); - -public: - unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; - vector GetDefaultEntries() override; - -private: - static unique_ptr CreateInternalTableMacroInfo(DefaultMacro &default_macro, - unique_ptr function); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/default/default_views.hpp -// 
-// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { -class SchemaCatalogEntry; - -class DefaultViewGenerator : public DefaultGenerator { -public: - DefaultViewGenerator(Catalog &catalog, SchemaCatalogEntry &schema); - - SchemaCatalogEntry &schema; - -public: - unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; - vector GetDefaultEntries() override; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/catalog_entry/dtable_catalog_entry.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -//! A table catalog entry -class DuckTableEntry : public TableCatalogEntry { -public: - //! Create a TableCatalogEntry and initialize storage for it - DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, BoundCreateTableInfo &info, - std::shared_ptr inherited_storage = nullptr); - -public: - unique_ptr AlterEntry(ClientContext &context, AlterInfo &info) override; - void UndoAlter(ClientContext &context, AlterInfo &info) override; - //! Returns the underlying storage of the table - DataTable &GetStorage() override; - //! Returns a list of the bound constraints of the table - const vector> &GetBoundConstraints() override; - - //! 
Get statistics of a column (physical or virtual) within the table - unique_ptr GetStatistics(ClientContext &context, column_t column_id) override; - - unique_ptr Copy(ClientContext &context) const override; - - void SetAsRoot() override; - - void CommitAlter(string &column_name); - void CommitDrop(); - - TableFunction GetScanFunction(ClientContext &context, unique_ptr &bind_data) override; - - vector GetColumnSegmentInfo() override; - - TableStorageInfo GetStorageInfo(ClientContext &context) override; - - bool IsDuckTable() const override { - return true; - } - -private: - unique_ptr RenameColumn(ClientContext &context, RenameColumnInfo &info); - unique_ptr AddColumn(ClientContext &context, AddColumnInfo &info); - unique_ptr RemoveColumn(ClientContext &context, RemoveColumnInfo &info); - unique_ptr SetDefault(ClientContext &context, SetDefaultInfo &info); - unique_ptr ChangeColumnType(ClientContext &context, ChangeColumnTypeInfo &info); - unique_ptr SetNotNull(ClientContext &context, SetNotNullInfo &info); - unique_ptr DropNotNull(ClientContext &context, DropNotNullInfo &info); - unique_ptr AddForeignKeyConstraint(ClientContext &context, AlterForeignKeyInfo &info); - unique_ptr DropForeignKeyConstraint(ClientContext &context, AlterForeignKeyInfo &info); - - void UpdateConstraintsOnColumnDrop(const LogicalIndex &removed_index, const vector &adjusted_indices, - const RemoveColumnInfo &info, CreateTableInfo &create_info, bool is_generated); - -private: - //! A reference to the underlying storage unit used for this table - std::shared_ptr storage; - //! A list of constraints that are part of this table - vector> bound_constraints; - //! 
Manages dependencies of the individual columns of the table - ColumnDependencyManager column_dependency_manager; -}; -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/constraints/bound_foreign_key_constraint.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -class BoundForeignKeyConstraint : public BoundConstraint { -public: - static constexpr const ConstraintType TYPE = ConstraintType::FOREIGN_KEY; - -public: - BoundForeignKeyConstraint(ForeignKeyInfo info_p, physical_index_set_t pk_key_set_p, - physical_index_set_t fk_key_set_p) - : BoundConstraint(ConstraintType::FOREIGN_KEY), info(std::move(info_p)), pk_key_set(std::move(pk_key_set_p)), - fk_key_set(std::move(fk_key_set_p)) { -#ifdef DEBUG - D_ASSERT(info.pk_keys.size() == pk_key_set.size()); - for (auto &key : info.pk_keys) { - D_ASSERT(pk_key_set.find(key) != pk_key_set.end()); - } - D_ASSERT(info.fk_keys.size() == fk_key_set.size()); - for (auto &key : info.fk_keys) { - D_ASSERT(fk_key_set.find(key) != fk_key_set.end()); - } -#endif - } - - ForeignKeyInfo info; - //! The same keys but stored as an unordered set - physical_index_set_t pk_key_set; - //! The same keys but stored as an unordered set - physical_index_set_t fk_key_set; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/parser/constraints/foreign_key_constraint.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { - -class ForeignKeyConstraint : public Constraint { -public: - static constexpr const ConstraintType TYPE = ConstraintType::FOREIGN_KEY; - -public: - DUCKDB_API ForeignKeyConstraint(vector pk_columns, vector fk_columns, ForeignKeyInfo info); - - //! The set of main key table's columns - vector pk_columns; - //! 
The set of foreign key table's columns - vector fk_columns; - ForeignKeyInfo info; - -public: - DUCKDB_API string ToString() const override; - - DUCKDB_API unique_ptr Copy() const override; - - DUCKDB_API void Serialize(Serializer &serializer) const override; - DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); - -private: - ForeignKeyConstraint(); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/constraints/bound_check_constraint.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - -namespace duckdb { - -//! The CheckConstraint contains an expression that must evaluate to TRUE for -//! every row in a table -class BoundCheckConstraint : public BoundConstraint { -public: - static constexpr const ConstraintType TYPE = ConstraintType::CHECK; - -public: - BoundCheckConstraint() : BoundConstraint(ConstraintType::CHECK) { - } - - //! The expression - unique_ptr expression; - //! The columns used by the CHECK constraint - physical_index_set_t bound_columns; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/constraints/bound_not_null_constraint.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -class BoundNotNullConstraint : public BoundConstraint { -public: - static constexpr const ConstraintType TYPE = ConstraintType::NOT_NULL; - -public: - explicit BoundNotNullConstraint(PhysicalIndex index) : BoundConstraint(ConstraintType::NOT_NULL), index(index) { - } - - //! 
Column index this constraint pertains to - PhysicalIndex index; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/constraints/bound_unique_constraint.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -class BoundUniqueConstraint : public BoundConstraint { -public: - static constexpr const ConstraintType TYPE = ConstraintType::UNIQUE; - -public: - BoundUniqueConstraint(vector keys, logical_index_set_t key_set, bool is_primary_key) - : BoundConstraint(ConstraintType::UNIQUE), keys(std::move(keys)), key_set(std::move(key_set)), - is_primary_key(is_primary_key) { -#ifdef DEBUG - D_ASSERT(this->keys.size() == this->key_set.size()); - for (auto &key : this->keys) { - D_ASSERT(this->key_set.find(key) != this->key_set.end()); - } -#endif - } - - //! The keys that define the unique constraint - vector keys; - //! The same keys but stored as an unordered set - logical_index_set_t key_set; - //! Whether or not the unique constraint is a primary key - bool is_primary_key; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/expression/bound_reference_expression.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -//! A BoundReferenceExpression represents a physical index into a DataChunk -class BoundReferenceExpression : public Expression { -public: - static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_REF; - -public: - BoundReferenceExpression(string alias, LogicalType type, idx_t index); - BoundReferenceExpression(LogicalType type, storage_t index); - - //! 
Index used to access data in the chunks - storage_t index; - -public: - bool IsScalar() const override { - return false; - } - bool IsFoldable() const override { - return false; - } - - string ToString() const override; - - hash_t Hash() const override; - bool Equals(const BaseExpression &other) const override; - - unique_ptr Copy() override; - - void Serialize(Serializer &serializer) const override; - static unique_ptr Deserialize(Deserializer &deserializer); -}; -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/expression_binder/alter_binder.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { -class TableCatalogEntry; - -//! The ALTER binder is responsible for binding an expression within alter statements -class AlterBinder : public ExpressionBinder { -public: - AlterBinder(Binder &binder, ClientContext &context, TableCatalogEntry &table, vector &bound_columns, - LogicalType target_type); - - TableCatalogEntry &table; - vector &bound_columns; - -protected: - BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, - bool root_expression = false) override; - - BindResult BindColumn(ColumnRefExpression &expr); - - string UnsupportedAggregateMessage() override; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/parser/parsed_expression_iterator.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -#include - -namespace duckdb { - -class ParsedExpressionIterator { -public: - static void EnumerateChildren(const ParsedExpression &expression, - const std::function &callback); - static void EnumerateChildren(ParsedExpression &expr, const std::function &callback); - static void EnumerateChildren(ParsedExpression &expr, - const std::function &child)> &callback); - - static 
void EnumerateTableRefChildren(TableRef &ref, - const std::function &child)> &callback); - static void EnumerateQueryNodeChildren(QueryNode &node, - const std::function &child)> &callback); - - static void EnumerateQueryNodeModifiers(QueryNode &node, - const std::function &child)> &callback); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/parser/constraints/check_constraint.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -//! The CheckConstraint contains an expression that must evaluate to TRUE for -//! every row in a table -class CheckConstraint : public Constraint { -public: - static constexpr const ConstraintType TYPE = ConstraintType::CHECK; - -public: - DUCKDB_API explicit CheckConstraint(unique_ptr expression); - - unique_ptr expression; - -public: - DUCKDB_API string ToString() const override; - - DUCKDB_API unique_ptr Copy() const override; - - DUCKDB_API void Serialize(Serializer &serializer) const override; - DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); -}; - -} // namespace duckdb - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/parser/constraints/unique_constraint.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { - -class UniqueConstraint : public Constraint { -public: - static constexpr const ConstraintType TYPE = ConstraintType::UNIQUE; - -public: - DUCKDB_API UniqueConstraint(LogicalIndex index, bool is_primary_key); - DUCKDB_API UniqueConstraint(vector columns, bool is_primary_key); - - //! The index of the column for which this constraint holds. Only used when the constraint relates to a single - //! column, equal to DConstants::INVALID_INDEX if not used - LogicalIndex index; - //! 
The set of columns for which this constraint holds by name. Only used when the index field is not used. - vector columns; - //! Whether or not this is a PRIMARY KEY constraint, or a UNIQUE constraint. - bool is_primary_key; - -public: - DUCKDB_API string ToString() const override; - - DUCKDB_API unique_ptr Copy() const override; - - DUCKDB_API void Serialize(Serializer &serializer) const override; - DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); - -private: - UniqueConstraint(); -}; - -} // namespace duckdb - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/function/table/table_scan.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { -class DuckTableEntry; -class TableCatalogEntry; - -struct TableScanBindData : public TableFunctionData { - explicit TableScanBindData(DuckTableEntry &table) : table(table), is_index_scan(false), is_create_index(false) { - } - - //! The table to scan - DuckTableEntry &table; - - //! Whether or not the table scan is an index scan - bool is_index_scan; - //! Whether or not the table scan is for index creation - bool is_create_index; - //! The row ids to fetch (in case of an index scan) - vector result_ids; - -public: - bool Equals(const FunctionData &other_p) const override { - auto &other = (const TableScanBindData &)other_p; - return &other.table == &table && result_ids == other.result_ids; - } -}; - -//! The table scan function represents a sequential scan over one of DuckDB's base tables. 
-struct TableScanFunction { - static void RegisterFunction(BuiltinFunctions &set); - static TableFunction GetFunction(); - static TableFunction GetIndexScanFunction(); - static optional_ptr GetTableEntry(const TableFunction &function, - const optional_ptr bind_data); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/storage/table_storage_info.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -struct ColumnSegmentInfo { - idx_t row_group_index; - idx_t column_id; - string column_path; - idx_t segment_idx; - string segment_type; - idx_t segment_start; - idx_t segment_count; - string compression_type; - string segment_stats; - bool has_updates; - bool persistent; - block_id_t block_id; - idx_t block_offset; - string segment_info; -}; - -struct IndexInfo { - bool is_unique; - bool is_primary; - bool is_foreign; - unordered_set column_set; -}; - -class TableStorageInfo { -public: - //! The (estimated) cardinality of the table - idx_t cardinality = DConstants::INVALID_INDEX; - //! Info of the indexes of a table - vector index_info; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/operator/logical_projection.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -//! 
LogicalProjection represents the projection list in a SELECT clause -class LogicalProjection : public LogicalOperator { -public: - static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_PROJECTION; - -public: - LogicalProjection(idx_t table_index, vector> select_list); - - idx_t table_index; - -public: - vector GetColumnBindings() override; - void Serialize(Serializer &serializer) const override; - static unique_ptr Deserialize(Deserializer &deserializer); - - vector GetTableIndex() const override; - string GetName() const override; - -protected: - void ResolveTypes() override; -}; -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/operator/logical_update.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { -class TableCatalogEntry; - -class LogicalUpdate : public LogicalOperator { -public: - static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_UPDATE; - -public: - explicit LogicalUpdate(TableCatalogEntry &table); - - //! The base table to update - TableCatalogEntry &table; - //! table catalog index - idx_t table_index; - //! 
if returning option is used, return the update chunk - bool return_chunk; - vector columns; - vector> bound_defaults; - bool update_is_del_and_insert; - -public: - void Serialize(Serializer &serializer) const override; - static unique_ptr Deserialize(Deserializer &deserializer); - - idx_t EstimateCardinality(ClientContext &context) override; - string GetName() const override; - -protected: - vector GetColumnBindings() override; - void ResolveTypes() override; - -private: - LogicalUpdate(ClientContext &context, const unique_ptr &table_info); -}; -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/function/scalar_macro_function.hpp -// -// -//===----------------------------------------------------------------------===// - - -//! The SelectStatement of the view - - - - - - - - -namespace duckdb { - -class ScalarMacroFunction : public MacroFunction { -public: - static constexpr const MacroType TYPE = MacroType::SCALAR_MACRO; - -public: - explicit ScalarMacroFunction(unique_ptr expression); - ScalarMacroFunction(void); - - //! 
The macro expression - unique_ptr expression; - -public: - unique_ptr Copy() const override; - - string ToSQL(const string &schema, const string &name) const override; - - void Serialize(Serializer &serializer) const override; - static unique_ptr Deserialize(Deserializer &deserializer); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/algorithm.hpp -// -// -//===----------------------------------------------------------------------===// - - - -#include -#include -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/dependency_manager.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/dependency.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { -class CatalogEntry; - -enum class DependencyType { - DEPENDENCY_REGULAR = 0, - DEPENDENCY_AUTOMATIC = 1, - DEPENDENCY_OWNS = 2, - DEPENDENCY_OWNED_BY = 3 -}; - -struct Dependency { - Dependency(CatalogEntry &entry, DependencyType dependency_type = DependencyType::DEPENDENCY_REGULAR) - : // NOLINT: Allow implicit conversion from `CatalogEntry` - entry(entry), dependency_type(dependency_type) { - } - - //! The catalog entry this depends on - reference entry; - //! 
The type of dependency - DependencyType dependency_type; -}; - -struct DependencyHashFunction { - uint64_t operator()(const Dependency &a) const { - std::hash hash_func; - return hash_func((void *)&a.entry.get()); - } -}; - -struct DependencyEquality { - bool operator()(const Dependency &a, const Dependency &b) const { - return RefersToSameObject(a.entry, b.entry); - } -}; -using dependency_set_t = unordered_set; - -} // namespace duckdb - - - - -#include - -namespace duckdb { -class DuckCatalog; -class ClientContext; -class DependencyList; - -//! The DependencyManager is in charge of managing dependencies between catalog entries -class DependencyManager { - friend class CatalogSet; - -public: - explicit DependencyManager(DuckCatalog &catalog); - - //! Erase the object from the DependencyManager; this should only happen when the object itself is destroyed - void EraseObject(CatalogEntry &object); - - //! Scans all dependencies, returning pairs of (object, dependent) - void Scan(const std::function &callback); - - void AddOwnership(CatalogTransaction transaction, CatalogEntry &owner, CatalogEntry &entry); - -private: - DuckCatalog &catalog; - //! Map of objects that DEPEND on [object], i.e. [object] can only be deleted when all entries in the dependency map - //! are deleted. - catalog_entry_map_t dependents_map; - //! Map of objects that the source object DEPENDS on, i.e. when any of the entries in the vector perform a CASCADE - //! 
drop then [object] is deleted as well - catalog_entry_map_t dependencies_map; - -private: - void AddObject(CatalogTransaction transaction, CatalogEntry &object, DependencyList &dependencies); - void DropObject(CatalogTransaction transaction, CatalogEntry &object, bool cascade); - void AlterObject(CatalogTransaction transaction, CatalogEntry &old_obj, CatalogEntry &new_obj); - void EraseObjectInternal(CatalogEntry &object); -}; -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/serializer/binary_serializer.hpp -// -// -//===----------------------------------------------------------------------===// - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/serializer/serializer.hpp -// -// -//===----------------------------------------------------------------------===// - - - -//------------------------------------------------------------------------- -// This file is automatically generated by scripts/generate_enum_util.py -// Do not edit this file manually, your changes will be overwritten -// If you want to exclude an enum from serialization, add it to the blacklist in the script -// -// Note: The generated code will only work properly if the enum is a top level item in the duckdb namespace -// If the enum is nested in a class, or in another namespace, the generated code will not compile. 
-// You should move the enum to the duckdb namespace, manually write a specialization or add it to the blacklist -//------------------------------------------------------------------------- - - - - -#include - - -namespace duckdb { - -struct EnumUtil { - // String -> Enum - template - static T FromString(const char *value) = delete; - - template - static T FromString(const string &value) { return FromString(value.c_str()); } - - // Enum -> String - template - static const char *ToChars(T value) = delete; - - template - static string ToString(T value) { return string(ToChars(value)); } -}; - -enum class AccessMode : uint8_t; - -enum class AggregateHandling : uint8_t; - -enum class AggregateOrderDependent : uint8_t; - -enum class AggregateType : uint8_t; - -enum class AlterForeignKeyType : uint8_t; - -enum class AlterScalarFunctionType : uint8_t; - -enum class AlterTableFunctionType : uint8_t; - -enum class AlterTableType : uint8_t; - -enum class AlterType : uint8_t; - -enum class AlterViewType : uint8_t; - -enum class AppenderType : uint8_t; - -enum class ArrowDateTimeType : uint8_t; - -enum class ArrowVariableSizeType : uint8_t; - -enum class BindingMode : uint8_t; - -enum class BitpackingMode : uint8_t; - -enum class BlockState : uint8_t; - -enum class CAPIResultSetType : uint8_t; - -enum class CSVState : uint8_t; - -enum class CTEMaterialize : uint8_t; - -enum class CatalogType : uint8_t; - -enum class CheckpointAbort : uint8_t; - -enum class ChunkInfoType : uint8_t; - -enum class ColumnDataAllocatorType : uint8_t; - -enum class ColumnDataScanProperties : uint8_t; - -enum class ColumnSegmentType : uint8_t; - -enum class CompressedMaterializationDirection : uint8_t; - -enum class CompressionType : uint8_t; - -enum class ConflictManagerMode : uint8_t; - -enum class ConstraintType : uint8_t; - -enum class DataFileType : uint8_t; - -enum class DatePartSpecifier : uint8_t; - -enum class DebugInitialize : uint8_t; - -enum class DefaultOrderByNullType : uint8_t; - -enum 
class DistinctType : uint8_t; - -enum class ErrorType : uint16_t; - -enum class ExceptionFormatValueType : uint8_t; - -enum class ExplainOutputType : uint8_t; - -enum class ExplainType : uint8_t; - -enum class ExpressionClass : uint8_t; - -enum class ExpressionType : uint8_t; - -enum class ExtensionLoadResult : uint8_t; - -enum class ExtraTypeInfoType : uint8_t; - -enum class FileBufferType : uint8_t; - -enum class FileCompressionType : uint8_t; - -enum class FileGlobOptions : uint8_t; - -enum class FileLockType : uint8_t; - -enum class FilterPropagateResult : uint8_t; - -enum class ForeignKeyType : uint8_t; - -enum class FunctionNullHandling : uint8_t; - -enum class FunctionSideEffects : uint8_t; - -enum class HLLStorageType : uint8_t; - -enum class IndexConstraintType : uint8_t; - -enum class IndexType : uint8_t; - -enum class InsertColumnOrder : uint8_t; - -enum class InterruptMode : uint8_t; - -enum class JoinRefType : uint8_t; - -enum class JoinType : uint8_t; - -enum class KeywordCategory : uint8_t; - -enum class LoadType : uint8_t; - -enum class LogicalOperatorType : uint8_t; - -enum class LogicalTypeId : uint8_t; - -enum class LookupResultType : uint8_t; - -enum class MacroType : uint8_t; - -enum class MapInvalidReason : uint8_t; - -enum class NType : uint8_t; - -enum class NewLineIdentifier : uint8_t; - -enum class OnConflictAction : uint8_t; - -enum class OnCreateConflict : uint8_t; - -enum class OnEntryNotFound : uint8_t; - -enum class OperatorFinalizeResultType : uint8_t; - -enum class OperatorResultType : uint8_t; - -enum class OptimizerType : uint32_t; - -enum class OrderByNullType : uint8_t; - -enum class OrderPreservationType : uint8_t; - -enum class OrderType : uint8_t; - -enum class OutputStream : uint8_t; - -enum class ParseInfoType : uint8_t; - -enum class ParserExtensionResultType : uint8_t; - -enum class ParserMode : uint8_t; - -enum class PartitionSortStage : uint8_t; - -enum class PartitionedColumnDataType : uint8_t; - -enum class 
PartitionedTupleDataType : uint8_t; - -enum class PendingExecutionResult : uint8_t; - -enum class PhysicalOperatorType : uint8_t; - -enum class PhysicalType : uint8_t; - -enum class PragmaType : uint8_t; - -enum class PreparedParamType : uint8_t; - -enum class ProfilerPrintFormat : uint8_t; - -enum class QueryNodeType : uint8_t; - -enum class QueryResultType : uint8_t; - -enum class QuoteRule : uint8_t; - -enum class RelationType : uint8_t; - -enum class RenderMode : uint8_t; - -enum class ResultModifierType : uint8_t; - -enum class SampleMethod : uint8_t; - -enum class SequenceInfo : uint8_t; - -enum class SetOperationType : uint8_t; - -enum class SetScope : uint8_t; - -enum class SetType : uint8_t; - -enum class SimplifiedTokenType : uint8_t; - -enum class SinkCombineResultType : uint8_t; - -enum class SinkFinalizeType : uint8_t; - -enum class SinkResultType : uint8_t; - -enum class SourceResultType : uint8_t; - -enum class StatementReturnType : uint8_t; - -enum class StatementType : uint8_t; - -enum class StatisticsType : uint8_t; - -enum class StatsInfo : uint8_t; - -enum class StrTimeSpecifier : uint8_t; - -enum class SubqueryType : uint8_t; - -enum class TableColumnType : uint8_t; - -enum class TableFilterType : uint8_t; - -enum class TableReferenceType : uint8_t; - -enum class TableScanType : uint8_t; - -enum class TaskExecutionMode : uint8_t; - -enum class TaskExecutionResult : uint8_t; - -enum class TimestampCastResult : uint8_t; - -enum class TransactionType : uint8_t; - -enum class TupleDataPinProperties : uint8_t; - -enum class UndoFlags : uint32_t; - -enum class UnionInvalidReason : uint8_t; - -enum class VectorAuxiliaryDataType : uint8_t; - -enum class VectorBufferType : uint8_t; - -enum class VectorType : uint8_t; - -enum class VerificationType : uint8_t; - -enum class VerifyExistenceType : uint8_t; - -enum class WALType : uint8_t; - -enum class WindowAggregationMode : uint32_t; - -enum class WindowBoundary : uint8_t; - - -template<> -const char* 
EnumUtil::ToChars(AccessMode value); - -template<> -const char* EnumUtil::ToChars(AggregateHandling value); - -template<> -const char* EnumUtil::ToChars(AggregateOrderDependent value); - -template<> -const char* EnumUtil::ToChars(AggregateType value); - -template<> -const char* EnumUtil::ToChars(AlterForeignKeyType value); - -template<> -const char* EnumUtil::ToChars(AlterScalarFunctionType value); - -template<> -const char* EnumUtil::ToChars(AlterTableFunctionType value); - -template<> -const char* EnumUtil::ToChars(AlterTableType value); - -template<> -const char* EnumUtil::ToChars(AlterType value); - -template<> -const char* EnumUtil::ToChars(AlterViewType value); - -template<> -const char* EnumUtil::ToChars(AppenderType value); - -template<> -const char* EnumUtil::ToChars(ArrowDateTimeType value); - -template<> -const char* EnumUtil::ToChars(ArrowVariableSizeType value); - -template<> -const char* EnumUtil::ToChars(BindingMode value); - -template<> -const char* EnumUtil::ToChars(BitpackingMode value); - -template<> -const char* EnumUtil::ToChars(BlockState value); - -template<> -const char* EnumUtil::ToChars(CAPIResultSetType value); - -template<> -const char* EnumUtil::ToChars(CSVState value); - -template<> -const char* EnumUtil::ToChars(CTEMaterialize value); - -template<> -const char* EnumUtil::ToChars(CatalogType value); - -template<> -const char* EnumUtil::ToChars(CheckpointAbort value); - -template<> -const char* EnumUtil::ToChars(ChunkInfoType value); - -template<> -const char* EnumUtil::ToChars(ColumnDataAllocatorType value); - -template<> -const char* EnumUtil::ToChars(ColumnDataScanProperties value); - -template<> -const char* EnumUtil::ToChars(ColumnSegmentType value); - -template<> -const char* EnumUtil::ToChars(CompressedMaterializationDirection value); - -template<> -const char* EnumUtil::ToChars(CompressionType value); - -template<> -const char* EnumUtil::ToChars(ConflictManagerMode value); - -template<> -const char* 
EnumUtil::ToChars(ConstraintType value); - -template<> -const char* EnumUtil::ToChars(DataFileType value); - -template<> -const char* EnumUtil::ToChars(DatePartSpecifier value); - -template<> -const char* EnumUtil::ToChars(DebugInitialize value); - -template<> -const char* EnumUtil::ToChars(DefaultOrderByNullType value); - -template<> -const char* EnumUtil::ToChars(DistinctType value); - -template<> -const char* EnumUtil::ToChars(ErrorType value); - -template<> -const char* EnumUtil::ToChars(ExceptionFormatValueType value); - -template<> -const char* EnumUtil::ToChars(ExplainOutputType value); - -template<> -const char* EnumUtil::ToChars(ExplainType value); - -template<> -const char* EnumUtil::ToChars(ExpressionClass value); - -template<> -const char* EnumUtil::ToChars(ExpressionType value); - -template<> -const char* EnumUtil::ToChars(ExtensionLoadResult value); - -template<> -const char* EnumUtil::ToChars(ExtraTypeInfoType value); - -template<> -const char* EnumUtil::ToChars(FileBufferType value); - -template<> -const char* EnumUtil::ToChars(FileCompressionType value); - -template<> -const char* EnumUtil::ToChars(FileGlobOptions value); - -template<> -const char* EnumUtil::ToChars(FileLockType value); - -template<> -const char* EnumUtil::ToChars(FilterPropagateResult value); - -template<> -const char* EnumUtil::ToChars(ForeignKeyType value); - -template<> -const char* EnumUtil::ToChars(FunctionNullHandling value); - -template<> -const char* EnumUtil::ToChars(FunctionSideEffects value); - -template<> -const char* EnumUtil::ToChars(HLLStorageType value); - -template<> -const char* EnumUtil::ToChars(IndexConstraintType value); - -template<> -const char* EnumUtil::ToChars(IndexType value); - -template<> -const char* EnumUtil::ToChars(InsertColumnOrder value); - -template<> -const char* EnumUtil::ToChars(InterruptMode value); - -template<> -const char* EnumUtil::ToChars(JoinRefType value); - -template<> -const char* EnumUtil::ToChars(JoinType value); - -template<> 
-const char* EnumUtil::ToChars(KeywordCategory value); - -template<> -const char* EnumUtil::ToChars(LoadType value); - -template<> -const char* EnumUtil::ToChars(LogicalOperatorType value); - -template<> -const char* EnumUtil::ToChars(LogicalTypeId value); - -template<> -const char* EnumUtil::ToChars(LookupResultType value); - -template<> -const char* EnumUtil::ToChars(MacroType value); - -template<> -const char* EnumUtil::ToChars(MapInvalidReason value); - -template<> -const char* EnumUtil::ToChars(NType value); - -template<> -const char* EnumUtil::ToChars(NewLineIdentifier value); - -template<> -const char* EnumUtil::ToChars(OnConflictAction value); - -template<> -const char* EnumUtil::ToChars(OnCreateConflict value); - -template<> -const char* EnumUtil::ToChars(OnEntryNotFound value); - -template<> -const char* EnumUtil::ToChars(OperatorFinalizeResultType value); - -template<> -const char* EnumUtil::ToChars(OperatorResultType value); - -template<> -const char* EnumUtil::ToChars(OptimizerType value); - -template<> -const char* EnumUtil::ToChars(OrderByNullType value); - -template<> -const char* EnumUtil::ToChars(OrderPreservationType value); - -template<> -const char* EnumUtil::ToChars(OrderType value); - -template<> -const char* EnumUtil::ToChars(OutputStream value); - -template<> -const char* EnumUtil::ToChars(ParseInfoType value); - -template<> -const char* EnumUtil::ToChars(ParserExtensionResultType value); - -template<> -const char* EnumUtil::ToChars(ParserMode value); - -template<> -const char* EnumUtil::ToChars(PartitionSortStage value); - -template<> -const char* EnumUtil::ToChars(PartitionedColumnDataType value); - -template<> -const char* EnumUtil::ToChars(PartitionedTupleDataType value); - -template<> -const char* EnumUtil::ToChars(PendingExecutionResult value); - -template<> -const char* EnumUtil::ToChars(PhysicalOperatorType value); - -template<> -const char* EnumUtil::ToChars(PhysicalType value); - -template<> -const char* 
EnumUtil::ToChars(PragmaType value); - -template<> -const char* EnumUtil::ToChars(PreparedParamType value); - -template<> -const char* EnumUtil::ToChars(ProfilerPrintFormat value); - -template<> -const char* EnumUtil::ToChars(QueryNodeType value); - -template<> -const char* EnumUtil::ToChars(QueryResultType value); - -template<> -const char* EnumUtil::ToChars(QuoteRule value); - -template<> -const char* EnumUtil::ToChars(RelationType value); - -template<> -const char* EnumUtil::ToChars(RenderMode value); - -template<> -const char* EnumUtil::ToChars(ResultModifierType value); - -template<> -const char* EnumUtil::ToChars(SampleMethod value); - -template<> -const char* EnumUtil::ToChars(SequenceInfo value); - -template<> -const char* EnumUtil::ToChars(SetOperationType value); - -template<> -const char* EnumUtil::ToChars(SetScope value); - -template<> -const char* EnumUtil::ToChars(SetType value); - -template<> -const char* EnumUtil::ToChars(SimplifiedTokenType value); - -template<> -const char* EnumUtil::ToChars(SinkCombineResultType value); - -template<> -const char* EnumUtil::ToChars(SinkFinalizeType value); - -template<> -const char* EnumUtil::ToChars(SinkResultType value); - -template<> -const char* EnumUtil::ToChars(SourceResultType value); - -template<> -const char* EnumUtil::ToChars(StatementReturnType value); - -template<> -const char* EnumUtil::ToChars(StatementType value); - -template<> -const char* EnumUtil::ToChars(StatisticsType value); - -template<> -const char* EnumUtil::ToChars(StatsInfo value); - -template<> -const char* EnumUtil::ToChars(StrTimeSpecifier value); - -template<> -const char* EnumUtil::ToChars(SubqueryType value); - -template<> -const char* EnumUtil::ToChars(TableColumnType value); - -template<> -const char* EnumUtil::ToChars(TableFilterType value); - -template<> -const char* EnumUtil::ToChars(TableReferenceType value); - -template<> -const char* EnumUtil::ToChars(TableScanType value); - -template<> -const char* 
EnumUtil::ToChars(TaskExecutionMode value); - -template<> -const char* EnumUtil::ToChars(TaskExecutionResult value); - -template<> -const char* EnumUtil::ToChars(TimestampCastResult value); - -template<> -const char* EnumUtil::ToChars(TransactionType value); - -template<> -const char* EnumUtil::ToChars(TupleDataPinProperties value); - -template<> -const char* EnumUtil::ToChars(UndoFlags value); - -template<> -const char* EnumUtil::ToChars(UnionInvalidReason value); - -template<> -const char* EnumUtil::ToChars(VectorAuxiliaryDataType value); - -template<> -const char* EnumUtil::ToChars(VectorBufferType value); - -template<> -const char* EnumUtil::ToChars(VectorType value); - -template<> -const char* EnumUtil::ToChars(VerificationType value); - -template<> -const char* EnumUtil::ToChars(VerifyExistenceType value); - -template<> -const char* EnumUtil::ToChars(WALType value); - -template<> -const char* EnumUtil::ToChars(WindowAggregationMode value); - -template<> -const char* EnumUtil::ToChars(WindowBoundary value); - - -template<> -AccessMode EnumUtil::FromString(const char *value); - -template<> -AggregateHandling EnumUtil::FromString(const char *value); - -template<> -AggregateOrderDependent EnumUtil::FromString(const char *value); - -template<> -AggregateType EnumUtil::FromString(const char *value); - -template<> -AlterForeignKeyType EnumUtil::FromString(const char *value); - -template<> -AlterScalarFunctionType EnumUtil::FromString(const char *value); - -template<> -AlterTableFunctionType EnumUtil::FromString(const char *value); - -template<> -AlterTableType EnumUtil::FromString(const char *value); - -template<> -AlterType EnumUtil::FromString(const char *value); - -template<> -AlterViewType EnumUtil::FromString(const char *value); - -template<> -AppenderType EnumUtil::FromString(const char *value); - -template<> -ArrowDateTimeType EnumUtil::FromString(const char *value); - -template<> -ArrowVariableSizeType EnumUtil::FromString(const char *value); - -template<> 
-BindingMode EnumUtil::FromString(const char *value); - -template<> -BitpackingMode EnumUtil::FromString(const char *value); - -template<> -BlockState EnumUtil::FromString(const char *value); - -template<> -CAPIResultSetType EnumUtil::FromString(const char *value); - -template<> -CSVState EnumUtil::FromString(const char *value); - -template<> -CTEMaterialize EnumUtil::FromString(const char *value); - -template<> -CatalogType EnumUtil::FromString(const char *value); - -template<> -CheckpointAbort EnumUtil::FromString(const char *value); - -template<> -ChunkInfoType EnumUtil::FromString(const char *value); - -template<> -ColumnDataAllocatorType EnumUtil::FromString(const char *value); - -template<> -ColumnDataScanProperties EnumUtil::FromString(const char *value); - -template<> -ColumnSegmentType EnumUtil::FromString(const char *value); - -template<> -CompressedMaterializationDirection EnumUtil::FromString(const char *value); - -template<> -CompressionType EnumUtil::FromString(const char *value); - -template<> -ConflictManagerMode EnumUtil::FromString(const char *value); - -template<> -ConstraintType EnumUtil::FromString(const char *value); - -template<> -DataFileType EnumUtil::FromString(const char *value); - -template<> -DatePartSpecifier EnumUtil::FromString(const char *value); - -template<> -DebugInitialize EnumUtil::FromString(const char *value); - -template<> -DefaultOrderByNullType EnumUtil::FromString(const char *value); - -template<> -DistinctType EnumUtil::FromString(const char *value); - -template<> -ErrorType EnumUtil::FromString(const char *value); - -template<> -ExceptionFormatValueType EnumUtil::FromString(const char *value); - -template<> -ExplainOutputType EnumUtil::FromString(const char *value); - -template<> -ExplainType EnumUtil::FromString(const char *value); - -template<> -ExpressionClass EnumUtil::FromString(const char *value); - -template<> -ExpressionType EnumUtil::FromString(const char *value); - -template<> -ExtensionLoadResult 
EnumUtil::FromString(const char *value); - -template<> -ExtraTypeInfoType EnumUtil::FromString(const char *value); - -template<> -FileBufferType EnumUtil::FromString(const char *value); - -template<> -FileCompressionType EnumUtil::FromString(const char *value); - -template<> -FileGlobOptions EnumUtil::FromString(const char *value); - -template<> -FileLockType EnumUtil::FromString(const char *value); - -template<> -FilterPropagateResult EnumUtil::FromString(const char *value); - -template<> -ForeignKeyType EnumUtil::FromString(const char *value); - -template<> -FunctionNullHandling EnumUtil::FromString(const char *value); - -template<> -FunctionSideEffects EnumUtil::FromString(const char *value); - -template<> -HLLStorageType EnumUtil::FromString(const char *value); - -template<> -IndexConstraintType EnumUtil::FromString(const char *value); - -template<> -IndexType EnumUtil::FromString(const char *value); - -template<> -InsertColumnOrder EnumUtil::FromString(const char *value); - -template<> -InterruptMode EnumUtil::FromString(const char *value); - -template<> -JoinRefType EnumUtil::FromString(const char *value); - -template<> -JoinType EnumUtil::FromString(const char *value); - -template<> -KeywordCategory EnumUtil::FromString(const char *value); - -template<> -LoadType EnumUtil::FromString(const char *value); - -template<> -LogicalOperatorType EnumUtil::FromString(const char *value); - -template<> -LogicalTypeId EnumUtil::FromString(const char *value); - -template<> -LookupResultType EnumUtil::FromString(const char *value); - -template<> -MacroType EnumUtil::FromString(const char *value); - -template<> -MapInvalidReason EnumUtil::FromString(const char *value); - -template<> -NType EnumUtil::FromString(const char *value); - -template<> -NewLineIdentifier EnumUtil::FromString(const char *value); - -template<> -OnConflictAction EnumUtil::FromString(const char *value); - -template<> -OnCreateConflict EnumUtil::FromString(const char *value); - -template<> 
-OnEntryNotFound EnumUtil::FromString(const char *value); - -template<> -OperatorFinalizeResultType EnumUtil::FromString(const char *value); - -template<> -OperatorResultType EnumUtil::FromString(const char *value); - -template<> -OptimizerType EnumUtil::FromString(const char *value); - -template<> -OrderByNullType EnumUtil::FromString(const char *value); - -template<> -OrderPreservationType EnumUtil::FromString(const char *value); - -template<> -OrderType EnumUtil::FromString(const char *value); - -template<> -OutputStream EnumUtil::FromString(const char *value); - -template<> -ParseInfoType EnumUtil::FromString(const char *value); - -template<> -ParserExtensionResultType EnumUtil::FromString(const char *value); - -template<> -ParserMode EnumUtil::FromString(const char *value); - -template<> -PartitionSortStage EnumUtil::FromString(const char *value); - -template<> -PartitionedColumnDataType EnumUtil::FromString(const char *value); - -template<> -PartitionedTupleDataType EnumUtil::FromString(const char *value); - -template<> -PendingExecutionResult EnumUtil::FromString(const char *value); - -template<> -PhysicalOperatorType EnumUtil::FromString(const char *value); - -template<> -PhysicalType EnumUtil::FromString(const char *value); - -template<> -PragmaType EnumUtil::FromString(const char *value); - -template<> -PreparedParamType EnumUtil::FromString(const char *value); - -template<> -ProfilerPrintFormat EnumUtil::FromString(const char *value); - -template<> -QueryNodeType EnumUtil::FromString(const char *value); - -template<> -QueryResultType EnumUtil::FromString(const char *value); - -template<> -QuoteRule EnumUtil::FromString(const char *value); - -template<> -RelationType EnumUtil::FromString(const char *value); - -template<> -RenderMode EnumUtil::FromString(const char *value); - -template<> -ResultModifierType EnumUtil::FromString(const char *value); - -template<> -SampleMethod EnumUtil::FromString(const char *value); - -template<> -SequenceInfo 
EnumUtil::FromString(const char *value); - -template<> -SetOperationType EnumUtil::FromString(const char *value); - -template<> -SetScope EnumUtil::FromString(const char *value); - -template<> -SetType EnumUtil::FromString(const char *value); - -template<> -SimplifiedTokenType EnumUtil::FromString(const char *value); - -template<> -SinkCombineResultType EnumUtil::FromString(const char *value); - -template<> -SinkFinalizeType EnumUtil::FromString(const char *value); - -template<> -SinkResultType EnumUtil::FromString(const char *value); - -template<> -SourceResultType EnumUtil::FromString(const char *value); - -template<> -StatementReturnType EnumUtil::FromString(const char *value); - -template<> -StatementType EnumUtil::FromString(const char *value); - -template<> -StatisticsType EnumUtil::FromString(const char *value); - -template<> -StatsInfo EnumUtil::FromString(const char *value); - -template<> -StrTimeSpecifier EnumUtil::FromString(const char *value); - -template<> -SubqueryType EnumUtil::FromString(const char *value); - -template<> -TableColumnType EnumUtil::FromString(const char *value); - -template<> -TableFilterType EnumUtil::FromString(const char *value); - -template<> -TableReferenceType EnumUtil::FromString(const char *value); - -template<> -TableScanType EnumUtil::FromString(const char *value); - -template<> -TaskExecutionMode EnumUtil::FromString(const char *value); - -template<> -TaskExecutionResult EnumUtil::FromString(const char *value); - -template<> -TimestampCastResult EnumUtil::FromString(const char *value); - -template<> -TransactionType EnumUtil::FromString(const char *value); - -template<> -TupleDataPinProperties EnumUtil::FromString(const char *value); - -template<> -UndoFlags EnumUtil::FromString(const char *value); - -template<> -UnionInvalidReason EnumUtil::FromString(const char *value); - -template<> -VectorAuxiliaryDataType EnumUtil::FromString(const char *value); - -template<> -VectorBufferType EnumUtil::FromString(const char *value); 
- -template<> -VectorType EnumUtil::FromString(const char *value); - -template<> -VerificationType EnumUtil::FromString(const char *value); - -template<> -VerifyExistenceType EnumUtil::FromString(const char *value); - -template<> -WALType EnumUtil::FromString(const char *value); - -template<> -WindowAggregationMode EnumUtil::FromString(const char *value); - -template<> -WindowBoundary EnumUtil::FromString(const char *value); - - -} - - -#include -#include - - - - - - - - - -namespace duckdb { - -class Serializer; // Forward declare -class Deserializer; // Forward declare - -typedef uint16_t field_id_t; -const field_id_t MESSAGE_TERMINATOR_FIELD_ID = 0xFFFF; - -// Backport to c++11 -template -using void_t = void; - -// Check for anything implementing a `void Serialize(Serializer &Serializer)` method -template -struct has_serialize : std::false_type {}; -template -struct has_serialize< - T, typename std::enable_if< - std::is_same().Serialize(std::declval())), void>::value, T>::type> - : std::true_type {}; - -template -struct has_deserialize : std::false_type {}; - -// Accept `static unique_ptr Deserialize(Deserializer& deserializer)` -template -struct has_deserialize< - T, typename std::enable_if(Deserializer &)>::value, T>::type> - : std::true_type {}; - -// Accept `static shared_ptr Deserialize(Deserializer& deserializer)` -template -struct has_deserialize< - T, typename std::enable_if(Deserializer &)>::value, T>::type> - : std::true_type {}; - -// Accept `static T Deserialize(Deserializer& deserializer)` -template -struct has_deserialize< - T, typename std::enable_if::value, T>::type> - : std::true_type {}; - -// Check if T is a vector, and provide access to the inner type -template -struct is_vector : std::false_type {}; -template -struct is_vector> : std::true_type { - typedef T ELEMENT_TYPE; -}; - -template -struct is_unsafe_vector : std::false_type {}; -template -struct is_unsafe_vector> : std::true_type { - typedef T ELEMENT_TYPE; -}; - -// Check if T is a 
unordered map, and provide access to the inner type -template -struct is_unordered_map : std::false_type {}; -template -struct is_unordered_map> : std::true_type { - typedef typename std::tuple_element<0, std::tuple>::type KEY_TYPE; - typedef typename std::tuple_element<1, std::tuple>::type VALUE_TYPE; - typedef typename std::tuple_element<2, std::tuple>::type HASH_TYPE; - typedef typename std::tuple_element<3, std::tuple>::type EQUAL_TYPE; -}; - -template -struct is_map : std::false_type {}; -template -struct is_map> : std::true_type { - typedef typename std::tuple_element<0, std::tuple>::type KEY_TYPE; - typedef typename std::tuple_element<1, std::tuple>::type VALUE_TYPE; - typedef typename std::tuple_element<2, std::tuple>::type HASH_TYPE; - typedef typename std::tuple_element<3, std::tuple>::type EQUAL_TYPE; -}; - -template -struct is_unique_ptr : std::false_type {}; -template -struct is_unique_ptr> : std::true_type { - typedef T ELEMENT_TYPE; -}; - -template -struct is_shared_ptr : std::false_type {}; -template -struct is_shared_ptr> : std::true_type { - typedef T ELEMENT_TYPE; -}; - -template -struct is_optional_ptr : std::false_type {}; -template -struct is_optional_ptr> : std::true_type { - typedef T ELEMENT_TYPE; -}; - -template -struct is_pair : std::false_type {}; -template -struct is_pair> : std::true_type { - typedef T FIRST_TYPE; - typedef U SECOND_TYPE; -}; - -template -struct is_unordered_set : std::false_type {}; -template -struct is_unordered_set> : std::true_type { - typedef typename std::tuple_element<0, std::tuple>::type ELEMENT_TYPE; - typedef typename std::tuple_element<1, std::tuple>::type HASH_TYPE; - typedef typename std::tuple_element<2, std::tuple>::type EQUAL_TYPE; -}; - -template -struct is_set : std::false_type {}; -template -struct is_set> : std::true_type { - typedef typename std::tuple_element<0, std::tuple>::type ELEMENT_TYPE; - typedef typename std::tuple_element<1, std::tuple>::type HASH_TYPE; - typedef typename 
std::tuple_element<2, std::tuple>::type EQUAL_TYPE; -}; - -template -struct is_atomic : std::false_type {}; - -template -struct is_atomic> : std::true_type { - typedef T TYPE; -}; - -struct SerializationDefaultValue { - - template - static inline typename std::enable_if::value, T>::type GetDefault() { - using INNER = typename is_atomic::TYPE; - return static_cast(GetDefault()); - } - - template - static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { - using INNER = typename is_atomic::TYPE; - return value == GetDefault(); - } - - template - static inline typename std::enable_if::value, T>::type GetDefault() { - return static_cast(0); - } - - template - static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { - return value == static_cast(0); - } - - template - static inline typename std::enable_if::value, T>::type GetDefault() { - return T(); - } - - template - static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { - return !value; - } - - template - static inline typename std::enable_if::value, T>::type GetDefault() { - return T(); - } - - template - static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { - return !value; - } - - template - static inline typename std::enable_if::value, T>::type GetDefault() { - return T(); - } - - template - static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { - return !value; - } - - template - static inline typename std::enable_if::value, T>::type GetDefault() { - return T(); - } - - template - static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { - return value.empty(); - } - - template - static inline typename std::enable_if::value, T>::type GetDefault() { - return T(); - } - - template - static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { - return value.empty(); - } - - template - static inline typename 
std::enable_if::value, T>::type GetDefault() { - return T(); - } - - template - static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { - return value.empty(); - } - - template - static inline typename std::enable_if::value, T>::type GetDefault() { - return T(); - } - - template - static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { - return value.empty(); - } - - template - static inline typename std::enable_if::value, T>::type GetDefault() { - return T(); - } - - template - static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { - return value.empty(); - } - - template - static inline typename std::enable_if::value, T>::type GetDefault() { - return T(); - } - - template - static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { - return value.empty(); - } -}; - -} // namespace duckdb - - - - - - -namespace duckdb { - -class Serializer { -protected: - bool serialize_enum_as_string = false; - bool serialize_default_values = false; - -public: - virtual ~Serializer() { - } - - class List { - friend Serializer; - - private: - Serializer &serializer; - explicit List(Serializer &serializer) : serializer(serializer) { - } - - public: - // Serialize an element - template - void WriteElement(const T &value); - - // Serialize an object - template - void WriteObject(FUNC f); - }; - -public: - // Serialize a value - template - void WriteProperty(const field_id_t field_id, const char *tag, const T &value) { - OnPropertyBegin(field_id, tag); - WriteValue(value); - OnPropertyEnd(); - } - - // Default value - template - void WritePropertyWithDefault(const field_id_t field_id, const char *tag, const T &value) { - // If current value is default, don't write it - if (!serialize_default_values && SerializationDefaultValue::IsDefault(value)) { - OnOptionalPropertyBegin(field_id, tag, false); - OnOptionalPropertyEnd(false); - return; - } - 
OnOptionalPropertyBegin(field_id, tag, true); - WriteValue(value); - OnOptionalPropertyEnd(true); - } - - template - void WritePropertyWithDefault(const field_id_t field_id, const char *tag, const T &value, const T &&default_value) { - // If current value is default, don't write it - if (!serialize_default_values && (value == default_value)) { - OnOptionalPropertyBegin(field_id, tag, false); - OnOptionalPropertyEnd(false); - return; - } - OnOptionalPropertyBegin(field_id, tag, true); - WriteValue(value); - OnOptionalPropertyEnd(true); - } - - // Special case: data_ptr_T - void WriteProperty(const field_id_t field_id, const char *tag, const_data_ptr_t ptr, idx_t count) { - OnPropertyBegin(field_id, tag); - WriteDataPtr(ptr, count); - OnPropertyEnd(); - } - - // Manually begin an object - template - void WriteObject(const field_id_t field_id, const char *tag, FUNC f) { - OnPropertyBegin(field_id, tag); - OnObjectBegin(); - f(*this); - OnObjectEnd(); - OnPropertyEnd(); - } - - template - void WriteList(const field_id_t field_id, const char *tag, idx_t count, FUNC func) { - OnPropertyBegin(field_id, tag); - OnListBegin(count); - List list {*this}; - for (idx_t i = 0; i < count; i++) { - func(list, i); - } - OnListEnd(); - OnPropertyEnd(); - } - -protected: - template - typename std::enable_if::value, void>::type WriteValue(const T value) { - if (serialize_enum_as_string) { - // Use the enum serializer to lookup tostring function - auto str = EnumUtil::ToChars(value); - WriteValue(str); - } else { - // Use the underlying type - WriteValue(static_cast::type>(value)); - } - } - - // Unique Pointer Ref - template - void WriteValue(const unique_ptr &ptr) { - WriteValue(ptr.get()); - } - - // Shared Pointer Ref - template - void WriteValue(const shared_ptr &ptr) { - WriteValue(ptr.get()); - } - - // Pointer - template - void WriteValue(const T *ptr) { - if (ptr == nullptr) { - OnNullableBegin(false); - OnNullableEnd(); - } else { - OnNullableBegin(true); - WriteValue(*ptr); 
- OnNullableEnd(); - } - } - - // Pair - template - void WriteValue(const std::pair &pair) { - OnObjectBegin(); - WriteProperty(0, "first", pair.first); - WriteProperty(1, "second", pair.second); - OnObjectEnd(); - } - - // Reference Wrapper - template - void WriteValue(const reference ref) { - WriteValue(ref.get()); - } - - // Vector - template - void WriteValue(const vector &vec) { - auto count = vec.size(); - OnListBegin(count); - for (auto &item : vec) { - WriteValue(item); - } - OnListEnd(); - } - - template - void WriteValue(const unsafe_vector &vec) { - auto count = vec.size(); - OnListBegin(count); - for (auto &item : vec) { - WriteValue(item); - } - OnListEnd(); - } - - // UnorderedSet - // Serialized the same way as a list/vector - template - void WriteValue(const duckdb::unordered_set &set) { - auto count = set.size(); - OnListBegin(count); - for (auto &item : set) { - WriteValue(item); - } - OnListEnd(); - } - - // Set - // Serialized the same way as a list/vector - template - void WriteValue(const duckdb::set &set) { - auto count = set.size(); - OnListBegin(count); - for (auto &item : set) { - WriteValue(item); - } - OnListEnd(); - } - - // Map - // serialized as a list of pairs - template - void WriteValue(const duckdb::unordered_map &map) { - auto count = map.size(); - OnListBegin(count); - for (auto &item : map) { - OnObjectBegin(); - WriteProperty(0, "key", item.first); - WriteProperty(1, "value", item.second); - OnObjectEnd(); - } - OnListEnd(); - } - - // Map - // serialized as a list of pairs - template - void WriteValue(const duckdb::map &map) { - auto count = map.size(); - OnListBegin(count); - for (auto &item : map) { - OnObjectBegin(); - WriteProperty(0, "key", item.first); - WriteProperty(1, "value", item.second); - OnObjectEnd(); - } - OnListEnd(); - } - - // class or struct implementing `Serialize(Serializer& Serializer)`; - template - typename std::enable_if::value>::type WriteValue(const T &value) { - OnObjectBegin(); - 
value.Serialize(*this); - OnObjectEnd(); - } - -protected: - // Hooks for subclasses to override to implement custom behavior - virtual void OnPropertyBegin(const field_id_t field_id, const char *tag) = 0; - virtual void OnPropertyEnd() = 0; - virtual void OnOptionalPropertyBegin(const field_id_t field_id, const char *tag, bool present) = 0; - virtual void OnOptionalPropertyEnd(bool present) = 0; - virtual void OnObjectBegin() = 0; - virtual void OnObjectEnd() = 0; - virtual void OnListBegin(idx_t count) = 0; - virtual void OnListEnd() = 0; - virtual void OnNullableBegin(bool present) = 0; - virtual void OnNullableEnd() = 0; - - // Handle primitive types, a serializer needs to implement these. - virtual void WriteNull() = 0; - virtual void WriteValue(char value) { - throw NotImplementedException("Write char value not implemented"); - } - virtual void WriteValue(bool value) = 0; - virtual void WriteValue(uint8_t value) = 0; - virtual void WriteValue(int8_t value) = 0; - virtual void WriteValue(uint16_t value) = 0; - virtual void WriteValue(int16_t value) = 0; - virtual void WriteValue(uint32_t value) = 0; - virtual void WriteValue(int32_t value) = 0; - virtual void WriteValue(uint64_t value) = 0; - virtual void WriteValue(int64_t value) = 0; - virtual void WriteValue(hugeint_t value) = 0; - virtual void WriteValue(float value) = 0; - virtual void WriteValue(double value) = 0; - virtual void WriteValue(const string_t value) = 0; - virtual void WriteValue(const string &value) = 0; - virtual void WriteValue(const char *str) = 0; - virtual void WriteDataPtr(const_data_ptr_t ptr, idx_t count) = 0; - void WriteValue(LogicalIndex value) { - WriteValue(value.index); - } - void WriteValue(PhysicalIndex value) { - WriteValue(value.index); - } -}; - -// We need to special case vector because elements of vector cannot be referenced -template <> -void Serializer::WriteValue(const vector &vec); - -// List Impl -template -void Serializer::List::WriteObject(FUNC f) { - 
serializer.OnObjectBegin(); - f(serializer); - serializer.OnObjectEnd(); -} - -template -void Serializer::List::WriteElement(const T &value) { - serializer.WriteValue(value); -} - -} // namespace duckdb - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/serializer/encoding_util.hpp -// -// -//===----------------------------------------------------------------------===// - - - - -#include - -namespace duckdb { - -struct EncodingUtil { - - // Encode unsigned integer, returns the number of bytes written - template - static idx_t EncodeUnsignedLEB128(data_ptr_t target, T value) { - static_assert(std::is_integral::value, "Must be integral"); - static_assert(std::is_unsigned::value, "Must be unsigned"); - static_assert(sizeof(T) <= sizeof(uint64_t), "Must be uint64_t or smaller"); - - idx_t offset = 0; - do { - uint8_t byte = value & 0x7F; - value >>= 7; - if (value != 0) { - byte |= 0x80; - } - target[offset++] = byte; - } while (value != 0); - return offset; - } - - // Decode unsigned integer, returns the number of bytes read - template - static idx_t DecodeUnsignedLEB128(const_data_ptr_t source, T &result) { - static_assert(std::is_integral::value, "Must be integral"); - static_assert(std::is_unsigned::value, "Must be unsigned"); - static_assert(sizeof(T) <= sizeof(uint64_t), "Must be uint64_t or smaller"); - - result = 0; - idx_t shift = 0; - idx_t offset = 0; - uint8_t byte; - do { - byte = source[offset++]; - result |= static_cast(byte & 0x7F) << shift; - shift += 7; - } while (byte & 0x80); - - return offset; - } - - // Encode signed integer, returns the number of bytes written - template - static idx_t EncodeSignedLEB128(data_ptr_t target, T value) { - static_assert(std::is_integral::value, "Must be integral"); - static_assert(std::is_signed::value, "Must be signed"); - static_assert(sizeof(T) <= sizeof(int64_t), "Must be int64_t or smaller"); - - idx_t offset = 0; - do { - uint8_t byte = value 
& 0x7F; - value >>= 7; - - // Determine whether more bytes are needed - if ((value == 0 && (byte & 0x40) == 0) || (value == -1 && (byte & 0x40))) { - target[offset++] = byte; - break; - } else { - byte |= 0x80; - target[offset++] = byte; - } - } while (true); - return offset; - } - - // Decode signed integer, returns the number of bytes read - template - static idx_t DecodeSignedLEB128(const_data_ptr_t source, T &result) { - static_assert(std::is_integral::value, "Must be integral"); - static_assert(std::is_signed::value, "Must be signed"); - static_assert(sizeof(T) <= sizeof(int64_t), "Must be int64_t or smaller"); - - // This is used to avoid undefined behavior when shifting into the sign bit - using unsigned_type = typename std::make_unsigned::type; - - result = 0; - idx_t shift = 0; - idx_t offset = 0; - - uint8_t byte; - do { - byte = source[offset++]; - result |= static_cast(byte & 0x7F) << shift; - shift += 7; - } while (byte & 0x80); - - // Sign-extend if the most significant bit of the last byte is set - if (shift < sizeof(T) * 8 && (byte & 0x40)) { - result |= -(static_cast(1) << shift); - } - return offset; - } - - template - static typename std::enable_if::value, idx_t>::type DecodeLEB128(const_data_ptr_t source, - T &result) { - return DecodeSignedLEB128(source, result); - } - - template - static typename std::enable_if::value, idx_t>::type DecodeLEB128(const_data_ptr_t source, - T &result) { - return DecodeUnsignedLEB128(source, result); - } - - template - static typename std::enable_if::value, idx_t>::type EncodeLEB128(data_ptr_t target, T value) { - return EncodeSignedLEB128(target, value); - } - - template - static typename std::enable_if::value, idx_t>::type EncodeLEB128(data_ptr_t target, T value) { - return EncodeUnsignedLEB128(target, value); - } -}; - -} // namespace duckdb - - -namespace duckdb { - -class BinarySerializer : public Serializer { -public: - explicit BinarySerializer(WriteStream &stream, bool serialize_default_values_p = false) : 
stream(stream) { - serialize_default_values = serialize_default_values_p; - serialize_enum_as_string = false; - } - -private: - struct DebugState { - unordered_set seen_field_tags; - unordered_set seen_field_ids; - vector> seen_fields; - }; - - void WriteData(const_data_ptr_t buffer, idx_t write_size) { - stream.WriteData(buffer, write_size); - } - - template - void Write(T element) { - static_assert(std::is_trivially_destructible(), "Write element must be trivially destructible"); - WriteData(const_data_ptr_cast(&element), sizeof(T)); - } - void WriteData(const char *ptr, idx_t write_size) { - WriteData(const_data_ptr_cast(ptr), write_size); - } - - template - void VarIntEncode(T value) { - uint8_t buffer[16]; - auto write_size = EncodingUtil::EncodeLEB128(buffer, value); - D_ASSERT(write_size <= sizeof(buffer)); - WriteData(buffer, write_size); - } - -public: - template - static void Serialize(const T &value, WriteStream &stream, bool serialize_default_values = false) { - BinarySerializer serializer(stream, serialize_default_values); - serializer.OnObjectBegin(); - value.Serialize(serializer); - serializer.OnObjectEnd(); - } - - void Begin() { - OnObjectBegin(); - } - void End() { - OnObjectEnd(); - } - -protected: - //------------------------------------------------------------------------- - // Nested Type Hooks - //------------------------------------------------------------------------- - // We serialize optional values as a message with a "present" flag, followed by the value. 
- void OnPropertyBegin(const field_id_t field_id, const char *tag) final; - void OnPropertyEnd() final; - void OnOptionalPropertyBegin(const field_id_t field_id, const char *tag, bool present) final; - void OnOptionalPropertyEnd(bool present) final; - void OnListBegin(idx_t count) final; - void OnListEnd() final; - void OnObjectBegin() final; - void OnObjectEnd() final; - void OnNullableBegin(bool present) final; - void OnNullableEnd() final; - - //------------------------------------------------------------------------- - // Primitive Types - //------------------------------------------------------------------------- - void WriteNull() final; - void WriteValue(char value) final; - void WriteValue(uint8_t value) final; - void WriteValue(int8_t value) final; - void WriteValue(uint16_t value) final; - void WriteValue(int16_t value) final; - void WriteValue(uint32_t value) final; - void WriteValue(int32_t value) final; - void WriteValue(uint64_t value) final; - void WriteValue(int64_t value) final; - void WriteValue(hugeint_t value) final; - void WriteValue(float value) final; - void WriteValue(double value) final; - void WriteValue(const string_t value) final; - void WriteValue(const string &value) final; - void WriteValue(const char *value) final; - void WriteValue(bool value) final; - void WriteDataPtr(const_data_ptr_t ptr, idx_t count) final; - -private: - vector debug_stack; - WriteStream &stream; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/serializer/binary_deserializer.hpp -// -// -//===----------------------------------------------------------------------===// - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/serializer/format_serializer.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - 
-//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/serializer/deserialization_data.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/bound_parameter_map.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/expression/bound_parameter_data.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { - -struct BoundParameterData { -public: - BoundParameterData() { - } - explicit BoundParameterData(Value val) : value(std::move(val)), return_type(value.type()) { - } - -private: - Value value; - -public: - LogicalType return_type; - -public: - void SetValue(Value val) { - value = std::move(val); - } - - const Value &GetValue() const { - return value; - } - - void Serialize(Serializer &serializer) const; - static shared_ptr Deserialize(Deserializer &deserializer); -}; - -} // namespace duckdb - - - -namespace duckdb { - -class ParameterExpression; -class BoundParameterExpression; - -using bound_parameter_map_t = case_insensitive_map_t>; - -struct BoundParameterMap { -public: - explicit BoundParameterMap(case_insensitive_map_t ¶meter_data); - -public: - LogicalType GetReturnType(const string &identifier); - - bound_parameter_map_t *GetParametersPtr(); - - const bound_parameter_map_t &GetParameters(); - - const case_insensitive_map_t &GetParameterData(); - - unique_ptr BindParameterExpression(ParameterExpression &expr); - -private: - shared_ptr CreateOrGetData(const string &identifier); - void CreateNewParameter(const string &id, const shared_ptr ¶m_data); - -private: - bound_parameter_map_t parameters; - // 
Pre-provided parameter data if populated - case_insensitive_map_t ¶meter_data; -}; - -} // namespace duckdb - - -namespace duckdb { -class ClientContext; -class Catalog; -class DatabaseInstance; -enum class ExpressionType : uint8_t; - -struct DeserializationData { - stack> contexts; - stack> databases; - stack enums; - stack> parameter_data; - stack> types; - - template - void Set(T entry) = delete; - - template - T Get() = delete; - - template - void Unset() = delete; - - template - inline void AssertNotEmpty(const stack &e) { - if (e.empty()) { - throw InternalException("DeserializationData - unexpected empty stack"); - } - } -}; - -template <> -inline void DeserializationData::Set(ExpressionType type) { - enums.push(idx_t(type)); -} - -template <> -inline ExpressionType DeserializationData::Get() { - AssertNotEmpty(enums); - return ExpressionType(enums.top()); -} - -template <> -inline void DeserializationData::Unset() { - AssertNotEmpty(enums); - enums.pop(); -} - -template <> -inline void DeserializationData::Set(LogicalOperatorType type) { - enums.push(idx_t(type)); -} - -template <> -inline LogicalOperatorType DeserializationData::Get() { - AssertNotEmpty(enums); - return LogicalOperatorType(enums.top()); -} - -template <> -inline void DeserializationData::Unset() { - AssertNotEmpty(enums); - enums.pop(); -} - -template <> -inline void DeserializationData::Set(CompressionType type) { - enums.push(idx_t(type)); -} - -template <> -inline CompressionType DeserializationData::Get() { - AssertNotEmpty(enums); - return CompressionType(enums.top()); -} - -template <> -inline void DeserializationData::Unset() { - AssertNotEmpty(enums); - enums.pop(); -} - -template <> -inline void DeserializationData::Set(CatalogType type) { - enums.push(idx_t(type)); -} - -template <> -inline CatalogType DeserializationData::Get() { - AssertNotEmpty(enums); - return CatalogType(enums.top()); -} - -template <> -inline void DeserializationData::Unset() { - AssertNotEmpty(enums); - 
enums.pop(); -} - -template <> -inline void DeserializationData::Set(ClientContext &context) { - contexts.push(context); -} - -template <> -inline ClientContext &DeserializationData::Get() { - AssertNotEmpty(contexts); - return contexts.top(); -} - -template <> -inline void DeserializationData::Unset() { - AssertNotEmpty(contexts); - contexts.pop(); -} - -template <> -inline void DeserializationData::Set(DatabaseInstance &db) { - databases.push(db); -} - -template <> -inline DatabaseInstance &DeserializationData::Get() { - AssertNotEmpty(databases); - return databases.top(); -} - -template <> -inline void DeserializationData::Unset() { - AssertNotEmpty(databases); - databases.pop(); -} - -template <> -inline void DeserializationData::Set(bound_parameter_map_t &context) { - parameter_data.push(context); -} - -template <> -inline bound_parameter_map_t &DeserializationData::Get() { - AssertNotEmpty(parameter_data); - return parameter_data.top(); -} - -template <> -inline void DeserializationData::Unset() { - AssertNotEmpty(parameter_data); - parameter_data.pop(); -} - -template <> -inline void DeserializationData::Set(LogicalType &type) { - types.emplace(type); -} - -template <> -inline LogicalType &DeserializationData::Get() { - AssertNotEmpty(types); - return types.top(); -} - -template <> -inline void DeserializationData::Unset() { - AssertNotEmpty(types); - types.pop(); -} - -} // namespace duckdb - - - - - -namespace duckdb { - -class Deserializer { -protected: - bool deserialize_enum_from_string = false; - DeserializationData data; - -public: - virtual ~Deserializer() { - } - - class List { - friend Deserializer; - - private: - Deserializer &deserializer; - explicit List(Deserializer &deserializer) : deserializer(deserializer) { - } - - public: - // Deserialize an element - template - T ReadElement(); - - // Deserialize an object - template - void ReadObject(FUNC f); - }; - -public: - // Read into an existing value - template - inline void ReadProperty(const 
field_id_t field_id, const char *tag, T &ret) { - OnPropertyBegin(field_id, tag); - ret = Read(); - OnPropertyEnd(); - } - - // Read and return a value - template - inline T ReadProperty(const field_id_t field_id, const char *tag) { - OnPropertyBegin(field_id, tag); - auto ret = Read(); - OnPropertyEnd(); - return ret; - } - - // Default Value return - template - inline T ReadPropertyWithDefault(const field_id_t field_id, const char *tag) { - if (!OnOptionalPropertyBegin(field_id, tag)) { - OnOptionalPropertyEnd(false); - return std::forward(SerializationDefaultValue::GetDefault()); - } - auto ret = Read(); - OnOptionalPropertyEnd(true); - return ret; - } - - template - inline T ReadPropertyWithDefault(const field_id_t field_id, const char *tag, T &&default_value) { - if (!OnOptionalPropertyBegin(field_id, tag)) { - OnOptionalPropertyEnd(false); - return std::forward(default_value); - } - auto ret = Read(); - OnOptionalPropertyEnd(true); - return ret; - } - - // Default value in place - template - inline void ReadPropertyWithDefault(const field_id_t field_id, const char *tag, T &ret) { - if (!OnOptionalPropertyBegin(field_id, tag)) { - ret = std::forward(SerializationDefaultValue::GetDefault()); - OnOptionalPropertyEnd(false); - return; - } - ret = Read(); - OnOptionalPropertyEnd(true); - } - - template - inline void ReadPropertyWithDefault(const field_id_t field_id, const char *tag, T &ret, T &&default_value) { - if (!OnOptionalPropertyBegin(field_id, tag)) { - ret = std::forward(default_value); - OnOptionalPropertyEnd(false); - return; - } - ret = Read(); - OnOptionalPropertyEnd(true); - } - - // Special case: - // Read into an existing data_ptr_t - inline void ReadProperty(const field_id_t field_id, const char *tag, data_ptr_t ret, idx_t count) { - OnPropertyBegin(field_id, tag); - ReadDataPtr(ret, count); - OnPropertyEnd(); - } - - // Try to read a property, if it is not present, continue, otherwise read and discard the value - template - inline void 
ReadDeletedProperty(const field_id_t field_id, const char *tag) { - // Try to read the property. If not present, great! - if (!OnOptionalPropertyBegin(field_id, tag)) { - OnOptionalPropertyEnd(false); - return; - } - // Otherwise read and discard the value - (void)Read(); - OnOptionalPropertyEnd(true); - } - - //! Set a serialization property - template - void Set(T entry) { - return data.Set(entry); - } - - //! Retrieve the last set serialization property of this type - template - T Get() { - return data.Get(); - } - - //! Unset a serialization property - template - void Unset() { - return data.Unset(); - } - - template - void ReadList(const field_id_t field_id, const char *tag, FUNC func) { - OnPropertyBegin(field_id, tag); - auto size = OnListBegin(); - List list {*this}; - for (idx_t i = 0; i < size; i++) { - func(list, i); - } - OnListEnd(); - OnPropertyEnd(); - } - - template - void ReadObject(const field_id_t field_id, const char *tag, FUNC func) { - OnPropertyBegin(field_id, tag); - OnObjectBegin(); - func(*this); - OnObjectEnd(); - OnPropertyEnd(); - } - -private: - // Deserialize anything implementing a Deserialize method - template - inline typename std::enable_if::value, T>::type Read() { - OnObjectBegin(); - auto val = T::Deserialize(*this); - OnObjectEnd(); - return val; - } - - template - inline typename std::enable_if::value, T>::type Read() { - using ELEMENT_TYPE = typename is_unique_ptr::ELEMENT_TYPE; - unique_ptr ptr = nullptr; - auto is_present = OnNullableBegin(); - if (is_present) { - OnObjectBegin(); - ptr = ELEMENT_TYPE::Deserialize(*this); - OnObjectEnd(); - } - OnNullableEnd(); - return ptr; - } - - // Deserialize shared_ptr - template - inline typename std::enable_if::value, T>::type Read() { - using ELEMENT_TYPE = typename is_shared_ptr::ELEMENT_TYPE; - shared_ptr ptr = nullptr; - auto is_present = OnNullableBegin(); - if (is_present) { - OnObjectBegin(); - ptr = ELEMENT_TYPE::Deserialize(*this); - OnObjectEnd(); - } - OnNullableEnd(); - 
return ptr; - } - - // Deserialize a vector - template - inline typename std::enable_if::value, T>::type Read() { - using ELEMENT_TYPE = typename is_vector::ELEMENT_TYPE; - T vec; - auto size = OnListBegin(); - for (idx_t i = 0; i < size; i++) { - vec.push_back(Read()); - } - OnListEnd(); - return vec; - } - - template - inline typename std::enable_if::value, T>::type Read() { - using ELEMENT_TYPE = typename is_unsafe_vector::ELEMENT_TYPE; - T vec; - auto size = OnListBegin(); - for (idx_t i = 0; i < size; i++) { - vec.push_back(Read()); - } - OnListEnd(); - - return vec; - } - - // Deserialize a map - template - inline typename std::enable_if::value, T>::type Read() { - using KEY_TYPE = typename is_unordered_map::KEY_TYPE; - using VALUE_TYPE = typename is_unordered_map::VALUE_TYPE; - - T map; - auto size = OnListBegin(); - for (idx_t i = 0; i < size; i++) { - OnObjectBegin(); - auto key = ReadProperty(0, "key"); - auto value = ReadProperty(1, "value"); - OnObjectEnd(); - map[std::move(key)] = std::move(value); - } - OnListEnd(); - return map; - } - - template - inline typename std::enable_if::value, T>::type Read() { - using KEY_TYPE = typename is_map::KEY_TYPE; - using VALUE_TYPE = typename is_map::VALUE_TYPE; - - T map; - auto size = OnListBegin(); - for (idx_t i = 0; i < size; i++) { - OnObjectBegin(); - auto key = ReadProperty(0, "key"); - auto value = ReadProperty(1, "value"); - OnObjectEnd(); - map[std::move(key)] = std::move(value); - } - OnListEnd(); - return map; - } - - // Deserialize an unordered set - template - inline typename std::enable_if::value, T>::type Read() { - using ELEMENT_TYPE = typename is_unordered_set::ELEMENT_TYPE; - auto size = OnListBegin(); - T set; - for (idx_t i = 0; i < size; i++) { - set.insert(Read()); - } - OnListEnd(); - return set; - } - - // Deserialize a set - template - inline typename std::enable_if::value, T>::type Read() { - using ELEMENT_TYPE = typename is_set::ELEMENT_TYPE; - auto size = OnListBegin(); - T set; - for 
(idx_t i = 0; i < size; i++) { - set.insert(Read()); - } - OnListEnd(); - return set; - } - - // Deserialize a pair - template - inline typename std::enable_if::value, T>::type Read() { - using FIRST_TYPE = typename is_pair::FIRST_TYPE; - using SECOND_TYPE = typename is_pair::SECOND_TYPE; - OnObjectBegin(); - auto first = ReadProperty(0, "first"); - auto second = ReadProperty(1, "second"); - OnObjectEnd(); - return std::make_pair(first, second); - } - - // Primitive types - // Deserialize a bool - template - inline typename std::enable_if::value, T>::type Read() { - return ReadBool(); - } - - // Deserialize a char - template - inline typename std::enable_if::value, T>::type Read() { - return ReadChar(); - } - - // Deserialize a int8_t - template - inline typename std::enable_if::value, T>::type Read() { - return ReadSignedInt8(); - } - - // Deserialize a uint8_t - template - inline typename std::enable_if::value, T>::type Read() { - return ReadUnsignedInt8(); - } - - // Deserialize a int16_t - template - inline typename std::enable_if::value, T>::type Read() { - return ReadSignedInt16(); - } - - // Deserialize a uint16_t - template - inline typename std::enable_if::value, T>::type Read() { - return ReadUnsignedInt16(); - } - - // Deserialize a int32_t - template - inline typename std::enable_if::value, T>::type Read() { - return ReadSignedInt32(); - } - - // Deserialize a uint32_t - template - inline typename std::enable_if::value, T>::type Read() { - return ReadUnsignedInt32(); - } - - // Deserialize a int64_t - template - inline typename std::enable_if::value, T>::type Read() { - return ReadSignedInt64(); - } - - // Deserialize a uint64_t - template - inline typename std::enable_if::value, T>::type Read() { - return ReadUnsignedInt64(); - } - - // Deserialize a float - template - inline typename std::enable_if::value, T>::type Read() { - return ReadFloat(); - } - - // Deserialize a double - template - inline typename std::enable_if::value, T>::type Read() { - 
return ReadDouble(); - } - - // Deserialize a string - template - inline typename std::enable_if::value, T>::type Read() { - return ReadString(); - } - - // Deserialize a Enum - template - inline typename std::enable_if::value, T>::type Read() { - if (deserialize_enum_from_string) { - auto str = ReadString(); - return EnumUtil::FromString(str.c_str()); - } else { - return (T)Read::type>(); - } - } - - // Deserialize a hugeint_t - template - inline typename std::enable_if::value, T>::type Read() { - return ReadHugeInt(); - } - - // Deserialize a LogicalIndex - template - inline typename std::enable_if::value, T>::type Read() { - return LogicalIndex(ReadUnsignedInt64()); - } - - // Deserialize a PhysicalIndex - template - inline typename std::enable_if::value, T>::type Read() { - return PhysicalIndex(ReadUnsignedInt64()); - } - -protected: - // Hooks for subclasses to override to implement custom behavior - virtual void OnPropertyBegin(const field_id_t field_id, const char *tag) = 0; - virtual void OnPropertyEnd() = 0; - virtual bool OnOptionalPropertyBegin(const field_id_t field_id, const char *tag) = 0; - virtual void OnOptionalPropertyEnd(bool present) = 0; - - virtual void OnObjectBegin() = 0; - virtual void OnObjectEnd() = 0; - virtual idx_t OnListBegin() = 0; - virtual void OnListEnd() = 0; - virtual bool OnNullableBegin() = 0; - virtual void OnNullableEnd() = 0; - - // Handle primitive types, a serializer needs to implement these. 
- virtual bool ReadBool() = 0; - virtual char ReadChar() { - throw NotImplementedException("ReadChar not implemented"); - } - virtual int8_t ReadSignedInt8() = 0; - virtual uint8_t ReadUnsignedInt8() = 0; - virtual int16_t ReadSignedInt16() = 0; - virtual uint16_t ReadUnsignedInt16() = 0; - virtual int32_t ReadSignedInt32() = 0; - virtual uint32_t ReadUnsignedInt32() = 0; - virtual int64_t ReadSignedInt64() = 0; - virtual uint64_t ReadUnsignedInt64() = 0; - virtual hugeint_t ReadHugeInt() = 0; - virtual float ReadFloat() = 0; - virtual double ReadDouble() = 0; - virtual string ReadString() = 0; - virtual void ReadDataPtr(data_ptr_t &ptr, idx_t count) = 0; -}; - -template -void Deserializer::List::ReadObject(FUNC f) { - deserializer.OnObjectBegin(); - f(deserializer); - deserializer.OnObjectEnd(); -} - -template -T Deserializer::List::ReadElement() { - return deserializer.Read(); -} - -} // namespace duckdb - - - - -namespace duckdb { -class ClientContext; - -class BinaryDeserializer : public Deserializer { -public: - explicit BinaryDeserializer(ReadStream &stream) : stream(stream) { - deserialize_enum_from_string = false; - } - - template - unique_ptr Deserialize() { - OnObjectBegin(); - auto result = T::Deserialize(*this); - OnObjectEnd(); - D_ASSERT(nesting_level == 0); // make sure we are at the root level - return result; - } - - template - static unique_ptr Deserialize(ReadStream &stream) { - BinaryDeserializer deserializer(stream); - return deserializer.template Deserialize(); - } - - template - static unique_ptr Deserialize(ReadStream &stream, ClientContext &context, bound_parameter_map_t ¶meters) { - BinaryDeserializer deserializer(stream); - deserializer.Set(context); - deserializer.Set(parameters); - return deserializer.template Deserialize(); - } - - void Begin() { - OnObjectBegin(); - } - - void End() { - OnObjectEnd(); - D_ASSERT(nesting_level == 0); // make sure we are at the root level - } - - ReadStream &GetStream() { - return stream; - } - 
-private: - ReadStream &stream; - idx_t nesting_level = 0; - - // Allow peeking 1 field ahead - bool has_buffered_field = false; - field_id_t buffered_field = 0; - -private: - field_id_t PeekField() { - if (!has_buffered_field) { - buffered_field = ReadPrimitive(); - has_buffered_field = true; - } - return buffered_field; - } - void ConsumeField() { - if (!has_buffered_field) { - buffered_field = ReadPrimitive(); - } else { - has_buffered_field = false; - } - } - field_id_t NextField() { - if (has_buffered_field) { - has_buffered_field = false; - return buffered_field; - } - return ReadPrimitive(); - } - - void ReadData(data_ptr_t buffer, idx_t read_size) { - stream.ReadData(buffer, read_size); - } - - template - T ReadPrimitive() { - T value; - ReadData(data_ptr_cast(&value), sizeof(T)); - return value; - } - - template - T VarIntDecode() { - // FIXME: maybe we should pass a source to EncodingUtil instead - uint8_t buffer[16]; - idx_t varint_size; - for (varint_size = 0; varint_size < 16; varint_size++) { - ReadData(buffer + varint_size, 1); - if (!(buffer[varint_size] & 0x80)) { - varint_size++; - break; - } - } - T value; - auto read_size = EncodingUtil::DecodeLEB128(buffer, value); - D_ASSERT(read_size == varint_size); - (void)read_size; - return value; - } - - //===--------------------------------------------------------------------===// - // Nested Types Hooks - //===--------------------------------------------------------------------===// - void OnPropertyBegin(const field_id_t field_id, const char *tag) final; - void OnPropertyEnd() final; - bool OnOptionalPropertyBegin(const field_id_t field_id, const char *tag) final; - void OnOptionalPropertyEnd(bool present) final; - void OnObjectBegin() final; - void OnObjectEnd() final; - idx_t OnListBegin() final; - void OnListEnd() final; - bool OnNullableBegin() final; - void OnNullableEnd() final; - - //===--------------------------------------------------------------------===// - // Primitive Types - 
//===--------------------------------------------------------------------===// - bool ReadBool() final; - char ReadChar() final; - int8_t ReadSignedInt8() final; - uint8_t ReadUnsignedInt8() final; - int16_t ReadSignedInt16() final; - uint16_t ReadUnsignedInt16() final; - int32_t ReadSignedInt32() final; - uint32_t ReadUnsignedInt32() final; - int64_t ReadSignedInt64() final; - uint64_t ReadUnsignedInt64() final; - float ReadFloat() final; - double ReadDouble() final; - string ReadString() final; - hugeint_t ReadHugeInt() final; - void ReadDataPtr(data_ptr_t &ptr, idx_t count) final; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/dcatalog.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -//! The Catalog object represents the catalog of the database. -class DuckCatalog : public Catalog { -public: - explicit DuckCatalog(AttachedDatabase &db); - ~DuckCatalog(); - -public: - bool IsDuckCatalog() override; - void Initialize(bool load_builtin) override; - string GetCatalogType() override { - return "duckdb"; - } - - DependencyManager &GetDependencyManager() { - return *dependency_manager; - } - mutex &GetWriteLock() { - return write_lock; - } - -public: - DUCKDB_API optional_ptr CreateSchema(CatalogTransaction transaction, CreateSchemaInfo &info) override; - DUCKDB_API void ScanSchemas(ClientContext &context, std::function callback) override; - DUCKDB_API void ScanSchemas(std::function callback); - - DUCKDB_API optional_ptr - GetSchema(CatalogTransaction transaction, const string &schema_name, OnEntryNotFound if_not_found, - QueryErrorContext error_context = QueryErrorContext()) override; - - DUCKDB_API unique_ptr PlanCreateTableAs(ClientContext &context, LogicalCreateTable &op, - unique_ptr plan) override; - DUCKDB_API unique_ptr PlanInsert(ClientContext &context, LogicalInsert &op, - unique_ptr 
plan) override; - DUCKDB_API unique_ptr PlanDelete(ClientContext &context, LogicalDelete &op, - unique_ptr plan) override; - DUCKDB_API unique_ptr PlanUpdate(ClientContext &context, LogicalUpdate &op, - unique_ptr plan) override; - DUCKDB_API unique_ptr BindCreateIndex(Binder &binder, CreateStatement &stmt, - TableCatalogEntry &table, - unique_ptr plan) override; - - DatabaseSize GetDatabaseSize(ClientContext &context) override; - vector GetMetadataInfo(ClientContext &context) override; - - DUCKDB_API bool InMemory() override; - DUCKDB_API string GetDBPath() override; - -private: - DUCKDB_API void DropSchema(CatalogTransaction transaction, DropInfo &info); - DUCKDB_API void DropSchema(ClientContext &context, DropInfo &info) override; - optional_ptr CreateSchemaInternal(CatalogTransaction transaction, CreateSchemaInfo &info); - void Verify() override; - -private: - //! The DependencyManager manages dependencies between different catalog objects - unique_ptr dependency_manager; - //! Write lock for the catalog - mutex write_lock; - //! 
The catalog set holding the schemas - unique_ptr schemas; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/mapping_value.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { -struct AlterInfo; - -class ClientContext; - -struct EntryIndex { - EntryIndex() : catalog(nullptr), index(DConstants::INVALID_INDEX) { - } - EntryIndex(CatalogSet &catalog, idx_t index) : catalog(&catalog), index(index) { - auto entry = catalog.entries.find(index); - if (entry == catalog.entries.end()) { - throw InternalException("EntryIndex - Catalog entry not found in constructor!?"); - } - catalog.entries[index].reference_count++; - } - ~EntryIndex() { - if (!catalog) { - return; - } - auto entry = catalog->entries.find(index); - D_ASSERT(entry != catalog->entries.end()); - auto remaining_ref = --entry->second.reference_count; - if (remaining_ref == 0) { - catalog->entries.erase(index); - } - catalog = nullptr; - } - // disable copy constructors - EntryIndex(const EntryIndex &other) = delete; - EntryIndex &operator=(const EntryIndex &) = delete; - //! 
enable move constructors - EntryIndex(EntryIndex &&other) noexcept { - catalog = nullptr; - index = DConstants::INVALID_INDEX; - std::swap(catalog, other.catalog); - std::swap(index, other.index); - } - EntryIndex &operator=(EntryIndex &&other) noexcept { - std::swap(catalog, other.catalog); - std::swap(index, other.index); - return *this; - } - - unique_ptr &GetEntry() { - auto entry = catalog->entries.find(index); - if (entry == catalog->entries.end()) { - throw InternalException("EntryIndex - Catalog entry not found!?"); - } - return entry->second.entry; - } - idx_t GetIndex() { - return index; - } - EntryIndex Copy() { - if (catalog) { - return EntryIndex(*catalog, index); - } else { - return EntryIndex(); - } - } - -private: - CatalogSet *catalog; - idx_t index; -}; - -struct MappingValue { - explicit MappingValue(EntryIndex index_p) - : index(std::move(index_p)), timestamp(0), deleted(false), parent(nullptr) { - } - - EntryIndex index; - transaction_t timestamp; - bool deleted; - unique_ptr child; - MappingValue *parent; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/transaction/duck_transaction.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { -class RowVersionManager; - -class DuckTransaction : public Transaction { -public: - DuckTransaction(TransactionManager &manager, ClientContext &context, transaction_t start_time, - transaction_t transaction_id); - ~DuckTransaction() override; - - //! The start timestamp of this transaction - transaction_t start_time; - //! The transaction id of this transaction - transaction_t transaction_id; - //! The commit id of this transaction, if it has successfully been committed - transaction_t commit_id; - //! Map of all sequences that were used during the transaction and the value they had in this transaction - unordered_map sequence_usage; - //! 
Highest active query when the transaction finished, used for cleaning up - transaction_t highest_active_query; - -public: - static DuckTransaction &Get(ClientContext &context, AttachedDatabase &db); - static DuckTransaction &Get(ClientContext &context, Catalog &catalog); - LocalStorage &GetLocalStorage(); - - void PushCatalogEntry(CatalogEntry &entry, data_ptr_t extra_data = nullptr, idx_t extra_data_size = 0); - - //! Commit the current transaction with the given commit identifier. Returns an error message if the transaction - //! commit failed, or an empty string if the commit was sucessful - string Commit(AttachedDatabase &db, transaction_t commit_id, bool checkpoint) noexcept; - //! Returns whether or not a commit of this transaction should trigger an automatic checkpoint - bool AutomaticCheckpoint(AttachedDatabase &db); - - //! Rollback - void Rollback() noexcept; - //! Cleanup the undo buffer - void Cleanup(); - - bool ChangesMade(); - - void PushDelete(DataTable &table, RowVersionManager &info, idx_t vector_idx, row_t rows[], idx_t count, - idx_t base_row); - void PushAppend(DataTable &table, idx_t row_start, idx_t row_count); - UpdateInfo *CreateUpdateInfo(idx_t type_size, idx_t entries); - - bool IsDuckTransaction() const override { - return true; - } - -private: - //! The undo buffer is used to store old versions of rows that are updated - //! or deleted - UndoBuffer undo_buffer; - //! The set of uncommitted appends for the transaction - unique_ptr storage; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/transaction/transaction_manager.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - - - -namespace duckdb { - -class AttachedDatabase; -class ClientContext; -class Catalog; -struct ClientLockWrapper; -class DatabaseInstance; -class Transaction; - -//! 
The Transaction Manager is responsible for creating and managing -//! transactions -class TransactionManager { -public: - explicit TransactionManager(AttachedDatabase &db); - virtual ~TransactionManager(); - - //! Start a new transaction - virtual Transaction *StartTransaction(ClientContext &context) = 0; - //! Commit the given transaction. Returns a non-empty error message on failure. - virtual string CommitTransaction(ClientContext &context, Transaction *transaction) = 0; - //! Rollback the given transaction - virtual void RollbackTransaction(Transaction *transaction) = 0; - - virtual void Checkpoint(ClientContext &context, bool force = false) = 0; - - static TransactionManager &Get(AttachedDatabase &db); - - virtual bool IsDuckTransactionManager() { - return false; - } - - AttachedDatabase &GetDB() { - return db; - } - -protected: - //! The attached database - AttachedDatabase &db; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/function/table_macro_function.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - - - - -namespace duckdb { - -class TableMacroFunction : public MacroFunction { -public: - static constexpr const MacroType TYPE = MacroType::TABLE_MACRO; - -public: - explicit TableMacroFunction(unique_ptr query_node); - TableMacroFunction(void); - - //! 
The main query node - unique_ptr query_node; - -public: - unique_ptr Copy() const override; - - string ToSQL(const string &schema, const string &name) const override; - - void Serialize(Serializer &serializer) const override; - static unique_ptr Deserialize(Deserializer &deserializer); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/catalog/default/builtin_types/types.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is generated by scripts/generate_builtin_types.py - - - - - - -namespace duckdb { - -struct DefaultType { - const char *name; - LogicalTypeId type; -}; - -using builtin_type_array = std::array; - -static constexpr const builtin_type_array BUILTIN_TYPES{{ - {"decimal", LogicalTypeId::DECIMAL}, - {"dec", LogicalTypeId::DECIMAL}, - {"numeric", LogicalTypeId::DECIMAL}, - {"time", LogicalTypeId::TIME}, - {"date", LogicalTypeId::DATE}, - {"timestamp", LogicalTypeId::TIMESTAMP}, - {"datetime", LogicalTypeId::TIMESTAMP}, - {"timestamp_us", LogicalTypeId::TIMESTAMP}, - {"timestamp_ms", LogicalTypeId::TIMESTAMP_MS}, - {"timestamp_ns", LogicalTypeId::TIMESTAMP_NS}, - {"timestamp_s", LogicalTypeId::TIMESTAMP_SEC}, - {"timestamptz", LogicalTypeId::TIMESTAMP_TZ}, - {"timetz", LogicalTypeId::TIME_TZ}, - {"interval", LogicalTypeId::INTERVAL}, - {"varchar", LogicalTypeId::VARCHAR}, - {"bpchar", LogicalTypeId::VARCHAR}, - {"string", LogicalTypeId::VARCHAR}, - {"char", LogicalTypeId::VARCHAR}, - {"nvarchar", LogicalTypeId::VARCHAR}, - {"text", LogicalTypeId::VARCHAR}, - {"blob", LogicalTypeId::BLOB}, - {"bytea", LogicalTypeId::BLOB}, - {"varbinary", LogicalTypeId::BLOB}, - {"binary", LogicalTypeId::BLOB}, - {"hugeint", LogicalTypeId::HUGEINT}, - {"int128", LogicalTypeId::HUGEINT}, - {"bigint", LogicalTypeId::BIGINT}, - {"oid", LogicalTypeId::BIGINT}, - {"long", LogicalTypeId::BIGINT}, - {"int8", LogicalTypeId::BIGINT}, - 
{"int64", LogicalTypeId::BIGINT}, - {"ubigint", LogicalTypeId::UBIGINT}, - {"uint64", LogicalTypeId::UBIGINT}, - {"integer", LogicalTypeId::INTEGER}, - {"int", LogicalTypeId::INTEGER}, - {"int4", LogicalTypeId::INTEGER}, - {"signed", LogicalTypeId::INTEGER}, - {"integral", LogicalTypeId::INTEGER}, - {"int32", LogicalTypeId::INTEGER}, - {"uinteger", LogicalTypeId::UINTEGER}, - {"uint32", LogicalTypeId::UINTEGER}, - {"smallint", LogicalTypeId::SMALLINT}, - {"int2", LogicalTypeId::SMALLINT}, - {"short", LogicalTypeId::SMALLINT}, - {"int16", LogicalTypeId::SMALLINT}, - {"usmallint", LogicalTypeId::USMALLINT}, - {"uint16", LogicalTypeId::USMALLINT}, - {"tinyint", LogicalTypeId::TINYINT}, - {"int1", LogicalTypeId::TINYINT}, - {"utinyint", LogicalTypeId::UTINYINT}, - {"uint8", LogicalTypeId::UTINYINT}, - {"struct", LogicalTypeId::STRUCT}, - {"row", LogicalTypeId::STRUCT}, - {"list", LogicalTypeId::LIST}, - {"map", LogicalTypeId::MAP}, - {"union", LogicalTypeId::UNION}, - {"bit", LogicalTypeId::BIT}, - {"bitstring", LogicalTypeId::BIT}, - {"boolean", LogicalTypeId::BOOLEAN}, - {"bool", LogicalTypeId::BOOLEAN}, - {"logical", LogicalTypeId::BOOLEAN}, - {"uuid", LogicalTypeId::UUID}, - {"guid", LogicalTypeId::UUID}, - {"enum", LogicalTypeId::ENUM}, - {"null", LogicalTypeId::SQLNULL}, - {"float", LogicalTypeId::FLOAT}, - {"real", LogicalTypeId::FLOAT}, - {"float4", LogicalTypeId::FLOAT}, - {"double", LogicalTypeId::DOUBLE}, - {"float8", LogicalTypeId::DOUBLE} -}}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/core_functions/core_functions.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -class Catalog; -struct CatalogTransaction; - -struct CoreFunctions { - static void RegisterFunctions(Catalog &catalog, CatalogTransaction transaction); -}; - -} // namespace duckdb -// Licensed to the Apache Software Foundation 
(ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - - - -#ifndef DUCKDB_ADBC_INIT -#define DUCKDB_ADBC_INIT - - - -#ifdef __cplusplus -extern "C" { -#endif - -//! We gotta leak the symbols of the init function -duckdb_adbc::AdbcStatusCode duckdb_adbc_init(size_t count, struct duckdb_adbc::AdbcDriver *driver, - struct duckdb_adbc::AdbcError *error); - -#ifdef __cplusplus -} -#endif - -#endif - - -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -#ifndef NANOARROW_H_INCLUDED -#define NANOARROW_H_INCLUDED - -#include -#include -#include - - - -namespace duckdb_nanoarrow { - -/// \file Arrow C Implementation -/// -/// EXPERIMENTAL. Interface subject to change. - -/// \page object-model Object Model -/// -/// Except where noted, objects are not thread-safe and clients should -/// take care to serialize accesses to methods. -/// -/// Because this library is intended to be vendored, it provides full type -/// definitions and encourages clients to stack or statically allocate -/// where convenient. - -/// \defgroup nanoarrow-malloc Memory management -/// -/// Non-buffer members of a struct ArrowSchema and struct ArrowArray -/// must be allocated using ArrowMalloc() or ArrowRealloc() and freed -/// using ArrowFree for schemas and arrays allocated here. Buffer members -/// are allocated using an ArrowBufferAllocator. - -/// \brief Allocate like malloc() -void *ArrowMalloc(int64_t size); - -/// \brief Reallocate like realloc() -void *ArrowRealloc(void *ptr, int64_t size); - -/// \brief Free a pointer allocated using ArrowMalloc() or ArrowRealloc(). -void ArrowFree(void *ptr); - -/// \brief Array buffer allocation and deallocation -/// -/// Container for allocate, reallocate, and free methods that can be used -/// to customize allocation and deallocation of buffers when constructing -/// an ArrowArray. 
-struct ArrowBufferAllocator { - /// \brief Allocate a buffer or return NULL if it cannot be allocated - uint8_t *(*allocate)(struct ArrowBufferAllocator *allocator, int64_t size); - - /// \brief Reallocate a buffer or return NULL if it cannot be reallocated - uint8_t *(*reallocate)(struct ArrowBufferAllocator *allocator, uint8_t *ptr, int64_t old_size, int64_t new_size); - - /// \brief Deallocate a buffer allocated by this allocator - void (*free)(struct ArrowBufferAllocator *allocator, uint8_t *ptr, int64_t size); - - /// \brief Opaque data specific to the allocator - void *private_data; -}; - -/// \brief Return the default allocator -/// -/// The default allocator uses ArrowMalloc(), ArrowRealloc(), and -/// ArrowFree(). -struct ArrowBufferAllocator *ArrowBufferAllocatorDefault(); - -/// }@ - -/// \defgroup nanoarrow-errors Error handling primitives -/// Functions generally return an errno-compatible error code; functions that -/// need to communicate more verbose error information accept a pointer -/// to an ArrowError. This can be stack or statically allocated. The -/// content of the message is undefined unless an error code has been -/// returned. - -/// \brief Error type containing a UTF-8 encoded message. -struct ArrowError { - char message[1024]; -}; - -/// \brief Return code for success. -#define NANOARROW_OK 0 - -/// \brief Represents an errno-compatible error code -typedef int ArrowErrorCode; - -/// \brief Set the contents of an error using printf syntax -ArrowErrorCode ArrowErrorSet(struct ArrowError *error, const char *fmt, ...); - -/// \brief Get the contents of an error -const char *ArrowErrorMessage(struct ArrowError *error); - -/// }@ - -/// \defgroup nanoarrow-utils Utility data structures - -/// \brief An non-owning view of a string -struct ArrowStringView { - /// \brief A pointer to the start of the string - /// - /// If n_bytes is 0, this value may be NULL. 
- const char *data; - - /// \brief The size of the string in bytes, - /// - /// (Not including the null terminator.) - int64_t n_bytes; -}; - -/// \brief Arrow type enumerator -/// -/// These names are intended to map to the corresponding arrow::Type::type -/// enumerator; however, the numeric values are specifically not equal -/// (i.e., do not rely on numeric comparison). -enum ArrowType { - NANOARROW_TYPE_UNINITIALIZED = 0, - NANOARROW_TYPE_NA = 1, - NANOARROW_TYPE_BOOL, - NANOARROW_TYPE_UINT8, - NANOARROW_TYPE_INT8, - NANOARROW_TYPE_UINT16, - NANOARROW_TYPE_INT16, - NANOARROW_TYPE_UINT32, - NANOARROW_TYPE_INT32, - NANOARROW_TYPE_UINT64, - NANOARROW_TYPE_INT64, - NANOARROW_TYPE_HALF_FLOAT, - NANOARROW_TYPE_FLOAT, - NANOARROW_TYPE_DOUBLE, - NANOARROW_TYPE_STRING, - NANOARROW_TYPE_BINARY, - NANOARROW_TYPE_FIXED_SIZE_BINARY, - NANOARROW_TYPE_DATE32, - NANOARROW_TYPE_DATE64, - NANOARROW_TYPE_TIMESTAMP, - NANOARROW_TYPE_TIME32, - NANOARROW_TYPE_TIME64, - NANOARROW_TYPE_INTERVAL_MONTHS, - NANOARROW_TYPE_INTERVAL_DAY_TIME, - NANOARROW_TYPE_DECIMAL128, - NANOARROW_TYPE_DECIMAL256, - NANOARROW_TYPE_LIST, - NANOARROW_TYPE_STRUCT, - NANOARROW_TYPE_SPARSE_UNION, - NANOARROW_TYPE_DENSE_UNION, - NANOARROW_TYPE_DICTIONARY, - NANOARROW_TYPE_MAP, - NANOARROW_TYPE_EXTENSION, - NANOARROW_TYPE_FIXED_SIZE_LIST, - NANOARROW_TYPE_DURATION, - NANOARROW_TYPE_LARGE_STRING, - NANOARROW_TYPE_LARGE_BINARY, - NANOARROW_TYPE_LARGE_LIST, - NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO -}; - -/// \brief Arrow time unit enumerator -/// -/// These names and values map to the corresponding arrow::TimeUnit::type -/// enumerator. 
-enum ArrowTimeUnit { - NANOARROW_TIME_UNIT_SECOND = 0, - NANOARROW_TIME_UNIT_MILLI = 1, - NANOARROW_TIME_UNIT_MICRO = 2, - NANOARROW_TIME_UNIT_NANO = 3 -}; - -/// }@ - -/// \defgroup nanoarrow-schema Schema producer helpers -/// These functions allocate, copy, and destroy ArrowSchema structures - -/// \brief Initialize the fields of a schema -/// -/// Initializes the fields and release callback of schema_out. Caller -/// is responsible for calling the schema->release callback if -/// NANOARROW_OK is returned. -ArrowErrorCode ArrowSchemaInit(struct ArrowSchema *schema, enum ArrowType type); - -/// \brief Initialize the fields of a fixed-size schema -/// -/// Returns EINVAL for fixed_size <= 0 or for data_type that is not -/// NANOARROW_TYPE_FIXED_SIZE_BINARY or NANOARROW_TYPE_FIXED_SIZE_LIST. -ArrowErrorCode ArrowSchemaInitFixedSize(struct ArrowSchema *schema, enum ArrowType data_type, int32_t fixed_size); - -/// \brief Initialize the fields of a decimal schema -/// -/// Returns EINVAL for scale <= 0 or for data_type that is not -/// NANOARROW_TYPE_DECIMAL128 or NANOARROW_TYPE_DECIMAL256. -ArrowErrorCode ArrowSchemaInitDecimal(struct ArrowSchema *schema, enum ArrowType data_type, int32_t decimal_precision, - int32_t decimal_scale); - -/// \brief Initialize the fields of a time, timestamp, or duration schema -/// -/// Returns EINVAL for data_type that is not -/// NANOARROW_TYPE_TIME32, NANOARROW_TYPE_TIME64, -/// NANOARROW_TYPE_TIMESTAMP, or NANOARROW_TYPE_DURATION. The -/// timezone parameter must be NULL for a non-timestamp data_type. -ArrowErrorCode ArrowSchemaInitDateTime(struct ArrowSchema *schema, enum ArrowType data_type, - enum ArrowTimeUnit time_unit, const char *timezone); - -/// \brief Make a (recursive) copy of a schema -/// -/// Allocates and copies fields of schema into schema_out. 
-ArrowErrorCode ArrowSchemaDeepCopy(struct ArrowSchema *schema, struct ArrowSchema *schema_out); - -/// \brief Copy format into schema->format -/// -/// schema must have been allocated using ArrowSchemaInit or -/// ArrowSchemaDeepCopy. -ArrowErrorCode ArrowSchemaSetFormat(struct ArrowSchema *schema, const char *format); - -/// \brief Copy name into schema->name -/// -/// schema must have been allocated using ArrowSchemaInit or -/// ArrowSchemaDeepCopy. -ArrowErrorCode ArrowSchemaSetName(struct ArrowSchema *schema, const char *name); - -/// \brief Copy metadata into schema->metadata -/// -/// schema must have been allocated using ArrowSchemaInit or -/// ArrowSchemaDeepCopy. -ArrowErrorCode ArrowSchemaSetMetadata(struct ArrowSchema *schema, const char *metadata); - -/// \brief Allocate the schema->children array -/// -/// Includes the memory for each child struct ArrowSchema. -/// schema must have been allocated using ArrowSchemaInit or -/// ArrowSchemaDeepCopy. -ArrowErrorCode ArrowSchemaAllocateChildren(struct ArrowSchema *schema, int64_t n_children); - -/// \brief Allocate the schema->dictionary member -/// -/// schema must have been allocated using ArrowSchemaInit or -/// ArrowSchemaDeepCopy. 
-ArrowErrorCode ArrowSchemaAllocateDictionary(struct ArrowSchema *schema); - -/// \brief Reader for key/value pairs in schema metadata -struct ArrowMetadataReader { - const char *metadata; - int64_t offset; - int32_t remaining_keys; -}; - -/// \brief Initialize an ArrowMetadataReader -ArrowErrorCode ArrowMetadataReaderInit(struct ArrowMetadataReader *reader, const char *metadata); - -/// \brief Read the next key/value pair from an ArrowMetadataReader -ArrowErrorCode ArrowMetadataReaderRead(struct ArrowMetadataReader *reader, struct ArrowStringView *key_out, - struct ArrowStringView *value_out); - -/// \brief The number of bytes in in a key/value metadata string -int64_t ArrowMetadataSizeOf(const char *metadata); - -/// \brief Check for a key in schema metadata -char ArrowMetadataHasKey(const char *metadata, const char *key); - -/// \brief Extract a value from schema metadata -ArrowErrorCode ArrowMetadataGetValue(const char *metadata, const char *key, const char *default_value, - struct ArrowStringView *value_out); - -/// }@ - -/// \defgroup nanoarrow-schema-view Schema consumer helpers - -/// \brief A non-owning view of a parsed ArrowSchema -/// -/// Contains more readily extractable values than a raw ArrowSchema. -/// Clients can stack or statically allocate this structure but are -/// encouraged to use the provided getters to ensure forward -/// compatiblity. -struct ArrowSchemaView { - /// \brief A pointer to the schema represented by this view - struct ArrowSchema *schema; - - /// \brief The data type represented by the schema - /// - /// This value may be NANOARROW_TYPE_DICTIONARY if the schema has a - /// non-null dictionary member; datetime types are valid values. - /// This value will never be NANOARROW_TYPE_EXTENSION (see - /// extension_name and/or extension_metadata to check for - /// an extension type). 
- enum ArrowType data_type; - - /// \brief The storage data type represented by the schema - /// - /// This value will never be NANOARROW_TYPE_DICTIONARY, NANOARROW_TYPE_EXTENSION - /// or any datetime type. This value represents only the type required to - /// interpret the buffers in the array. - enum ArrowType storage_data_type; - - /// \brief The extension type name if it exists - /// - /// If the ARROW:extension:name key is present in schema.metadata, - /// extension_name.data will be non-NULL. - struct ArrowStringView extension_name; - - /// \brief The extension type metadata if it exists - /// - /// If the ARROW:extension:metadata key is present in schema.metadata, - /// extension_metadata.data will be non-NULL. - struct ArrowStringView extension_metadata; - - /// \brief The expected number of buffers in a paired ArrowArray - int32_t n_buffers; - - /// \brief The index of the validity buffer or -1 if one does not exist - int32_t validity_buffer_id; - - /// \brief The index of the offset buffer or -1 if one does not exist - int32_t offset_buffer_id; - - /// \brief The index of the data buffer or -1 if one does not exist - int32_t data_buffer_id; - - /// \brief The index of the type_ids buffer or -1 if one does not exist - int32_t type_id_buffer_id; - - /// \brief Format fixed size parameter - /// - /// This value is set when parsing a fixed-size binary or fixed-size - /// list schema; this value is undefined for other types. For a - /// fixed-size binary schema this value is in bytes; for a fixed-size - /// list schema this value refers to the number of child elements for - /// each element of the parent. - int32_t fixed_size; - - /// \brief Decimal bitwidth - /// - /// This value is set when parsing a decimal type schema; - /// this value is undefined for other types. - int32_t decimal_bitwidth; - - /// \brief Decimal precision - /// - /// This value is set when parsing a decimal type schema; - /// this value is undefined for other types. 
- int32_t decimal_precision; - - /// \brief Decimal scale - /// - /// This value is set when parsing a decimal type schema; - /// this value is undefined for other types. - int32_t decimal_scale; - - /// \brief Format time unit parameter - /// - /// This value is set when parsing a date/time type. The value is - /// undefined for other types. - enum ArrowTimeUnit time_unit; - - /// \brief Format timezone parameter - /// - /// This value is set when parsing a timestamp type and represents - /// the timezone format parameter. The ArrowStrintgView points to - /// data within the schema and the value is undefined for other types. - struct ArrowStringView timezone; - - /// \brief Union type ids parameter - /// - /// This value is set when parsing a union type and represents - /// type ids parameter. The ArrowStringView points to - /// data within the schema and the value is undefined for other types. - struct ArrowStringView union_type_ids; -}; - -/// \brief Initialize an ArrowSchemaView -ArrowErrorCode ArrowSchemaViewInit(struct ArrowSchemaView *schema_view, struct ArrowSchema *schema, - struct ArrowError *error); - -/// }@ - -/// \defgroup nanoarrow-buffer-builder Growable buffer builders - -/// \brief An owning mutable view of a buffer -struct ArrowBuffer { - /// \brief A pointer to the start of the buffer - /// - /// If capacity_bytes is 0, this value may be NULL. - uint8_t *data; - - /// \brief The size of the buffer in bytes - int64_t size_bytes; - - /// \brief The capacity of the buffer in bytes - int64_t capacity_bytes; - - /// \brief The allocator that will be used to reallocate and/or free the buffer - struct ArrowBufferAllocator *allocator; -}; - -/// \brief Initialize an ArrowBuffer -/// -/// Initialize a buffer with a NULL, zero-size buffer using the default -/// buffer allocator. -void ArrowBufferInit(struct ArrowBuffer *buffer); - -/// \brief Set a newly-initialized buffer's allocator -/// -/// Returns EINVAL if the buffer has already been allocated. 
-ArrowErrorCode ArrowBufferSetAllocator(struct ArrowBuffer *buffer, struct ArrowBufferAllocator *allocator); - -/// \brief Reset an ArrowBuffer -/// -/// Releases the buffer using the allocator's free method if -/// the buffer's data member is non-null, sets the data member -/// to NULL, and sets the buffer's size and capacity to 0. -void ArrowBufferReset(struct ArrowBuffer *buffer); - -/// \brief Move an ArrowBuffer -/// -/// Transfers the buffer data and lifecycle management to another -/// address and resets buffer. -void ArrowBufferMove(struct ArrowBuffer *buffer, struct ArrowBuffer *buffer_out); - -/// \brief Grow or shrink a buffer to a given capacity -/// -/// When shrinking the capacity of the buffer, the buffer is only reallocated -/// if shrink_to_fit is non-zero. Calling ArrowBufferResize() does not -/// adjust the buffer's size member except to ensure that the invariant -/// capacity >= size remains true. -ArrowErrorCode ArrowBufferResize(struct ArrowBuffer *buffer, int64_t new_capacity_bytes, char shrink_to_fit); - -/// \brief Ensure a buffer has at least a given additional capacity -/// -/// Ensures that the buffer has space to append at least -/// additional_size_bytes, overallocating when required. -ArrowErrorCode ArrowBufferReserve(struct ArrowBuffer *buffer, int64_t additional_size_bytes); - -/// \brief Write data to buffer and increment the buffer size -/// -/// This function does not check that buffer has the required capacity -void ArrowBufferAppendUnsafe(struct ArrowBuffer *buffer, const void *data, int64_t size_bytes); - -/// \brief Write data to buffer and increment the buffer size -/// -/// This function writes and ensures that the buffer has the required capacity, -/// possibly by reallocating the buffer. Like ArrowBufferReserve, this will -/// overallocate when reallocation is required. 
-ArrowErrorCode ArrowBufferAppend(struct ArrowBuffer *buffer, const void *data, int64_t size_bytes); - -/// }@ - -} // namespace duckdb_nanoarrow - -#endif // NANOARROW_H_INCLUDED - - -// Bring in the symbols from duckdb_nanoarrow into duckdb -namespace duckdb { - -// using duckdb_nanoarrow::ArrowBuffer; //We have a variant of this that should be renamed -using duckdb_nanoarrow::ArrowBufferAllocator; -using duckdb_nanoarrow::ArrowError; -using duckdb_nanoarrow::ArrowSchemaView; -using duckdb_nanoarrow::ArrowStringView; - -} // namespace duckdb - - - - - -namespace duckdb_adbc { - -struct SingleBatchArrayStream { - struct ArrowSchema schema; - struct ArrowArray batch; -}; - -AdbcStatusCode BatchToArrayStream(struct ArrowArray *values, struct ArrowSchema *schema, - struct ArrowArrayStream *stream, struct AdbcError *error); - -} // namespace duckdb_adbc -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - - - - - -#ifdef __cplusplus -extern "C" { -#endif - -#ifndef ADBC_DRIVER_MANAGER_H -#define ADBC_DRIVER_MANAGER_H -namespace duckdb_adbc { -/// \brief Common entry point for drivers via the driver manager. -/// -/// The driver manager can fill in default implementations of some -/// ADBC functions for drivers. 
Drivers must implement a minimum level -/// of functionality for this to be possible, however, and some -/// functions must be implemented by the driver. -/// -/// \param[in] driver_name An identifier for the driver (e.g. a path to a -/// shared library on Linux). -/// \param[in] entrypoint An identifier for the entrypoint (e.g. the -/// symbol to call for AdbcDriverInitFunc on Linux). -/// \param[in] version The ADBC revision to attempt to initialize. -/// \param[out] driver The table of function pointers to initialize. -/// \param[out] error An optional location to return an error message -/// if necessary. -ADBC_EXPORT -AdbcStatusCode AdbcLoadDriver(const char *driver_name, const char *entrypoint, int version, void *driver, - struct AdbcError *error); - -/// \brief Common entry point for drivers via the driver manager. -/// -/// The driver manager can fill in default implementations of some -/// ADBC functions for drivers. Drivers must implement a minimum level -/// of functionality for this to be possible, however, and some -/// functions must be implemented by the driver. -/// -/// \param[in] init_func The entrypoint to call. -/// \param[in] version The ADBC revision to attempt to initialize. -/// \param[out] driver The table of function pointers to initialize. -/// \param[out] error An optional location to return an error message -/// if necessary. -ADBC_EXPORT -AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void *driver, - struct AdbcError *error); - -/// \brief Set the AdbcDriverInitFunc to use. -/// -/// This is an extension to the ADBC API. The driver manager shims -/// the AdbcDatabase* functions to allow you to specify the -/// driver/entrypoint dynamically. This function lets you set the -/// entrypoint explicitly, for applications that can dynamically -/// load drivers on their own. 
-ADBC_EXPORT -AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase *database, AdbcDriverInitFunc init_func, - struct AdbcError *error); - -/// \brief Get a human-friendly description of a status code. -ADBC_EXPORT -const char *AdbcStatusCodeMessage(AdbcStatusCode code); - -#endif // ADBC_DRIVER_MANAGER_H - -#ifdef __cplusplus -} -#endif -} // namespace duckdb_adbc -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/arrow/arrow_appender.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { - -struct ArrowAppendData; - -//! The ArrowAppender class can be used to incrementally construct an arrow array by appending data chunks into it -class ArrowAppender { -public: - DUCKDB_API ArrowAppender(vector types, idx_t initial_capacity, ClientProperties options); - DUCKDB_API ~ArrowAppender(); - - //! Append a data chunk to the underlying arrow array - DUCKDB_API void Append(DataChunk &input, idx_t from, idx_t to, idx_t input_size); - //! Returns the underlying arrow array - DUCKDB_API ArrowArray Finalize(); - -public: - static void ReleaseArray(ArrowArray *array); - static ArrowArray *FinalizeChild(const LogicalType &type, ArrowAppendData &append_data); - static unique_ptr InitializeChild(const LogicalType &type, idx_t capacity, - ClientProperties &options); - -private: - //! The types of the chunks that will be appended in - vector types; - //! The root arrow append data - vector> root_data; - //! 
The total row count that has been appended - idx_t row_count = 0; - - ClientProperties options; -}; - -} // namespace duckdb - - - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/arrow/arrow_buffer.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -struct ArrowSchema; - -namespace duckdb { - -struct ArrowBuffer { - static constexpr const idx_t MINIMUM_SHRINK_SIZE = 4096; - - ArrowBuffer() : dataptr(nullptr), count(0), capacity(0) { - } - ~ArrowBuffer() { - if (!dataptr) { - return; - } - free(dataptr); - dataptr = nullptr; - count = 0; - capacity = 0; - } - // disable copy constructors - ArrowBuffer(const ArrowBuffer &other) = delete; - ArrowBuffer &operator=(const ArrowBuffer &) = delete; - //! enable move constructors - ArrowBuffer(ArrowBuffer &&other) noexcept { - std::swap(dataptr, other.dataptr); - std::swap(count, other.count); - std::swap(capacity, other.capacity); - } - ArrowBuffer &operator=(ArrowBuffer &&other) noexcept { - std::swap(dataptr, other.dataptr); - std::swap(count, other.count); - std::swap(capacity, other.capacity); - return *this; - } - - void reserve(idx_t bytes) { // NOLINT - auto new_capacity = NextPowerOfTwo(bytes); - if (new_capacity <= capacity) { - return; - } - ReserveInternal(new_capacity); - } - - void resize(idx_t bytes) { // NOLINT - reserve(bytes); - count = bytes; - } - - void resize(idx_t bytes, data_t value) { // NOLINT - reserve(bytes); - for (idx_t i = count; i < bytes; i++) { - dataptr[i] = value; - } - count = bytes; - } - - idx_t size() { // NOLINT - return count; - } - - data_ptr_t data() { // NOLINT - return dataptr; - } - - template - T *GetData() { - return reinterpret_cast(data()); - } - -private: - void ReserveInternal(idx_t bytes) { - if (dataptr) { - dataptr = data_ptr_cast(realloc(dataptr, bytes)); - } else { - dataptr = data_ptr_cast(malloc(bytes)); - } - capacity = bytes; - } - 
-private: - data_ptr_t dataptr = nullptr; - idx_t count = 0; - idx_t capacity = 0; -}; - -} // namespace duckdb - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Arrow append data -//===--------------------------------------------------------------------===// -typedef void (*initialize_t)(ArrowAppendData &result, const LogicalType &type, idx_t capacity); -// append_data: The arrow array we're appending into -// input: The data we're appending -// from: The offset into the input we're scanning -// to: The last index of the input we're scanning -// input_size: The total size of the 'input' Vector. -typedef void (*append_vector_t)(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); -typedef void (*finalize_t)(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); - -// This struct is used to save state for appending a column -// afterwards the ownership is passed to the arrow array, as 'private_data' -// FIXME: we should separate the append state variables from the variables required by the ArrowArray into -// ArrowAppendState -struct ArrowAppendData { - explicit ArrowAppendData(ClientProperties &options_p) : options(options_p) { - } - // the buffers of the arrow vector - ArrowBuffer validity; - ArrowBuffer main_buffer; - ArrowBuffer aux_buffer; - - idx_t row_count = 0; - idx_t null_count = 0; - - // function pointers for construction - initialize_t initialize = nullptr; - append_vector_t append_vector = nullptr; - finalize_t finalize = nullptr; - - // child data (if any) - vector> child_data; - - // the arrow array C API data, only set after Finalize - unique_ptr array; - duckdb::array buffers = {{nullptr, nullptr, nullptr}}; - vector child_pointers; - - ClientProperties options; -}; - -//===--------------------------------------------------------------------===// -// Append Helper Functions 
-//===--------------------------------------------------------------------===// -static void GetBitPosition(idx_t row_idx, idx_t ¤t_byte, uint8_t ¤t_bit) { - current_byte = row_idx / 8; - current_bit = row_idx % 8; -} - -static void UnsetBit(uint8_t *data, idx_t current_byte, uint8_t current_bit) { - data[current_byte] &= ~((uint64_t)1 << current_bit); -} - -static void NextBit(idx_t ¤t_byte, uint8_t ¤t_bit) { - current_bit++; - if (current_bit == 8) { - current_byte++; - current_bit = 0; - } -} - -static void ResizeValidity(ArrowBuffer &buffer, idx_t row_count) { - auto byte_count = (row_count + 7) / 8; - buffer.resize(byte_count, 0xFF); -} - -static void SetNull(ArrowAppendData &append_data, uint8_t *validity_data, idx_t current_byte, uint8_t current_bit) { - UnsetBit(validity_data, current_byte, current_bit); - append_data.null_count++; -} - -static void AppendValidity(ArrowAppendData &append_data, UnifiedVectorFormat &format, idx_t from, idx_t to) { - // resize the buffer, filling the validity buffer with all valid values - idx_t size = to - from; - ResizeValidity(append_data.validity, append_data.row_count + size); - if (format.validity.AllValid()) { - // if all values are valid we don't need to do anything else - return; - } - - // otherwise we iterate through the validity mask - auto validity_data = (uint8_t *)append_data.validity.data(); - uint8_t current_bit; - idx_t current_byte; - GetBitPosition(append_data.row_count, current_byte, current_bit); - for (idx_t i = from; i < to; i++) { - auto source_idx = format.sel->get_index(i); - // append the validity mask - if (!format.validity.RowIsValid(source_idx)) { - SetNull(append_data, validity_data, current_byte, current_bit); - } - NextBit(current_byte, current_bit); - } -} - -} // namespace duckdb - - - -namespace duckdb { - -struct ArrowBoolData { -public: - static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); - static void Append(ArrowAppendData &append_data, Vector 
&input, idx_t from, idx_t to, idx_t input_size); - static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); -}; - -} // namespace duckdb - - - - -namespace duckdb { - -struct ArrowListData { -public: - static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); - static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); - -public: - static void AppendOffsets(ArrowAppendData &append_data, UnifiedVectorFormat &format, idx_t from, idx_t to, - vector &child_sel); -}; - -} // namespace duckdb - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Maps -//===--------------------------------------------------------------------===// -struct ArrowMapData { -public: - static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); - static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); -}; - -} // namespace duckdb - - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/function/table/arrow.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/thread.hpp -// -// -//===----------------------------------------------------------------------===// - - - -#include - -namespace duckdb { -using std::thread; -} - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/function/table/arrow_duck_schema.hpp -// -// 
-//===----------------------------------------------------------------------===// - - - - - - - - -namespace duckdb { -//===--------------------------------------------------------------------===// -// Arrow Variable Size Types -//===--------------------------------------------------------------------===// -enum class ArrowVariableSizeType : uint8_t { FIXED_SIZE = 0, NORMAL = 1, SUPER_SIZE = 2 }; - -//===--------------------------------------------------------------------===// -// Arrow Time/Date Types -//===--------------------------------------------------------------------===// -enum class ArrowDateTimeType : uint8_t { - MILLISECONDS = 0, - MICROSECONDS = 1, - NANOSECONDS = 2, - SECONDS = 3, - DAYS = 4, - MONTHS = 5, - MONTH_DAY_NANO = 6 -}; - -class ArrowType { -public: - //! From a DuckDB type - ArrowType(LogicalType type_p) - : type(std::move(type_p)), size_type(ArrowVariableSizeType::NORMAL), - date_time_precision(ArrowDateTimeType::DAYS) {}; - - //! From a DuckDB type + fixed_size - ArrowType(LogicalType type_p, idx_t fixed_size_p) - : type(std::move(type_p)), size_type(ArrowVariableSizeType::FIXED_SIZE), - date_time_precision(ArrowDateTimeType::DAYS), fixed_size(fixed_size_p) {}; - - //! From a DuckDB type + variable size type - ArrowType(LogicalType type_p, ArrowVariableSizeType size_type_p) - : type(std::move(type_p)), size_type(size_type_p), date_time_precision(ArrowDateTimeType::DAYS) {}; - - //! 
From a DuckDB type + datetime type - ArrowType(LogicalType type_p, ArrowDateTimeType date_time_precision_p) - : type(std::move(type_p)), size_type(ArrowVariableSizeType::NORMAL), - date_time_precision(date_time_precision_p) {}; - - void AddChild(unique_ptr child); - - void AssignChildren(vector> children); - - const LogicalType &GetDuckType() const; - - ArrowVariableSizeType GetSizeType() const; - - idx_t FixedSize() const; - - void SetDictionary(unique_ptr dictionary); - - ArrowDateTimeType GetDateTimeType() const; - - const ArrowType &GetDictionary() const; - - const ArrowType &operator[](idx_t index) const; - -private: - LogicalType type; - //! If we have a nested type, their children's type. - vector> children; - //! If its a variable size type (e.g., strings, blobs, lists) holds which type it is - ArrowVariableSizeType size_type; - //! If this is a date/time holds its precision - ArrowDateTimeType date_time_precision; - //! Only for size types with fixed size - idx_t fixed_size = 0; - //! 
Hold the optional type if the array is a dictionary - unique_ptr dictionary_type; -}; - -using arrow_column_map_t = unordered_map>; - -struct ArrowTableType { -public: - void AddColumn(idx_t index, unique_ptr type); - const arrow_column_map_t &GetColumns() const; - -private: - arrow_column_map_t arrow_convert_data; -}; - -} // namespace duckdb - - -namespace duckdb { - -struct ArrowInterval { - int32_t months; - int32_t days; - int64_t nanoseconds; - - inline bool operator==(const ArrowInterval &rhs) const { - return this->days == rhs.days && this->months == rhs.months && this->nanoseconds == rhs.nanoseconds; - } -}; - -struct ArrowProjectedColumns { - unordered_map projection_map; - vector columns; -}; - -struct ArrowStreamParameters { - ArrowProjectedColumns projected_columns; - TableFilterSet *filters; -}; - -typedef unique_ptr (*stream_factory_produce_t)(uintptr_t stream_factory_ptr, - ArrowStreamParameters ¶meters); -typedef void (*stream_factory_get_schema_t)(uintptr_t stream_factory_ptr, ArrowSchemaWrapper &schema); - -struct ArrowScanFunctionData : public PyTableFunctionData { -public: - ArrowScanFunctionData(stream_factory_produce_t scanner_producer_p, uintptr_t stream_factory_ptr_p) - : lines_read(0), stream_factory_ptr(stream_factory_ptr_p), scanner_producer(scanner_producer_p) { - } - vector all_types; - atomic lines_read; - ArrowSchemaWrapper schema_root; - idx_t rows_per_thread; - //! Pointer to the scanner factory - uintptr_t stream_factory_ptr; - //! Pointer to the scanner factory produce - stream_factory_produce_t scanner_producer; - //! 
Arrow table data - ArrowTableType arrow_table; -}; - -struct ArrowScanLocalState : public LocalTableFunctionState { - explicit ArrowScanLocalState(unique_ptr current_chunk) : chunk(current_chunk.release()) { - } - - unique_ptr stream; - shared_ptr chunk; - // This vector hold the Arrow Vectors owned by DuckDB to allow for zero-copy - // Note that only DuckDB can release these vectors - unordered_map> arrow_owned_data; - idx_t chunk_offset = 0; - idx_t batch_index = 0; - vector column_ids; - //! Store child vectors for Arrow Dictionary Vectors (col-idx,vector) - unordered_map> arrow_dictionary_vectors; - TableFilterSet *filters = nullptr; - //! The DataChunk containing all read columns (even filter columns that are immediately removed) - DataChunk all_columns; -}; - -struct ArrowScanGlobalState : public GlobalTableFunctionState { - unique_ptr stream; - mutex main_mutex; - idx_t max_threads = 1; - idx_t batch_index = 0; - bool done = false; - - vector projection_ids; - vector scanned_types; - - idx_t MaxThreads() const override { - return max_threads; - } - - bool CanRemoveFilterColumns() const { - return !projection_ids.empty(); - } -}; - -struct ArrowTableFunction { -public: - static void RegisterFunction(BuiltinFunctions &set); - -public: - //! Binds an arrow table - static unique_ptr ArrowScanBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names); - //! Actual conversion from Arrow to DuckDB - static void ArrowToDuckDB(ArrowScanLocalState &scan_state, const arrow_column_map_t &arrow_convert_data, - DataChunk &output, idx_t start, bool arrow_scan_is_projected = true); - - //! Get next scan state - static bool ArrowScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, - ArrowScanLocalState &state, ArrowScanGlobalState ¶llel_state); - - //! Initialize Global State - static unique_ptr ArrowScanInitGlobal(ClientContext &context, - TableFunctionInitInput &input); - - //! 
Initialize Local State - static unique_ptr ArrowScanInitLocalInternal(ClientContext &context, - TableFunctionInitInput &input, - GlobalTableFunctionState *global_state); - static unique_ptr ArrowScanInitLocal(ExecutionContext &context, - TableFunctionInitInput &input, - GlobalTableFunctionState *global_state); - - //! Scan Function - static void ArrowScanFunction(ClientContext &context, TableFunctionInput &data, DataChunk &output); - static void PopulateArrowTableType(ArrowTableType &arrow_table, ArrowSchemaWrapper &schema_p, vector &names, - vector &return_types); - -protected: - //! Defines Maximum Number of Threads - static idx_t ArrowScanMaxThreads(ClientContext &context, const FunctionData *bind_data); - - //! Allows parallel Create Table / Insertion - static idx_t ArrowGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p, - LocalTableFunctionState *local_state, GlobalTableFunctionState *global_state); - - //! -----Utility Functions:----- - //! Gets Arrow Table's Cardinality - static unique_ptr ArrowScanCardinality(ClientContext &context, const FunctionData *bind_data); - //! Gets the progress on the table scan, used for Progress Bars - static double ArrowProgress(ClientContext &context, const FunctionData *bind_data, - const GlobalTableFunctionState *global_state); - //! Renames repeated columns and case sensitive columns - static void RenameArrowColumns(vector &names); - //! 
Helper function to get the DuckDB logical type - static unique_ptr GetArrowLogicalType(ArrowSchema &schema); -}; - -} // namespace duckdb - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Scalar Types -//===--------------------------------------------------------------------===// -struct ArrowScalarConverter { - template - static TGT Operation(SRC input) { - return input; - } - - static bool SkipNulls() { - return false; - } - - template - static void SetNull(TGT &value) { - } -}; - -struct ArrowIntervalConverter { - template - static TGT Operation(SRC input) { - ArrowInterval result; - result.months = input.months; - result.days = input.days; - result.nanoseconds = input.micros * Interval::NANOS_PER_MICRO; - return result; - } - - static bool SkipNulls() { - return true; - } - - template - static void SetNull(TGT &value) { - } -}; - -template -struct ArrowScalarBaseData { - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { - D_ASSERT(to >= from); - idx_t size = to - from; - D_ASSERT(size <= input_size); - UnifiedVectorFormat format; - input.ToUnifiedFormat(input_size, format); - - // append the validity mask - AppendValidity(append_data, format, from, to); - - // append the main data - append_data.main_buffer.resize(append_data.main_buffer.size() + sizeof(TGT) * size); - auto data = UnifiedVectorFormat::GetData(format); - auto result_data = append_data.main_buffer.GetData(); - - for (idx_t i = from; i < to; i++) { - auto source_idx = format.sel->get_index(i); - auto result_idx = append_data.row_count + i - from; - - if (OP::SkipNulls() && !format.validity.RowIsValid(source_idx)) { - OP::template SetNull(result_data[result_idx]); - continue; - } - result_data[result_idx] = OP::template Operation(data[source_idx]); - } - append_data.row_count += size; - } -}; - -template -struct ArrowScalarData : public ArrowScalarBaseData { - static void 
Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - result.main_buffer.reserve(capacity * sizeof(TGT)); - } - - static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { - result->n_buffers = 2; - result->buffers[1] = append_data.main_buffer.data(); - } -}; - -} // namespace duckdb - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Structs -//===--------------------------------------------------------------------===// -struct ArrowStructData { -public: - static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); - static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); -}; - -} // namespace duckdb - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Unions -//===--------------------------------------------------------------------===// -/** - * Based on https://arrow.apache.org/docs/format/Columnar.html#union-layout & - * https://arrow.apache.org/docs/format/CDataInterface.html - */ -struct ArrowUnionData { -public: - static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); - static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); -}; - -} // namespace duckdb - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Enums -//===--------------------------------------------------------------------===// -template -struct ArrowEnumData : public ArrowScalarBaseData { - static idx_t GetLength(string_t input) { - return input.GetSize(); - } - static void WriteData(data_ptr_t 
target, string_t input) { - memcpy(target, input.GetData(), input.GetSize()); - } - static void EnumAppendVector(ArrowAppendData &append_data, const Vector &input, idx_t size) { - D_ASSERT(input.GetVectorType() == VectorType::FLAT_VECTOR); - - // resize the validity mask and set up the validity buffer for iteration - ResizeValidity(append_data.validity, append_data.row_count + size); - - // resize the offset buffer - the offset buffer holds the offsets into the child array - append_data.main_buffer.resize(append_data.main_buffer.size() + sizeof(uint32_t) * (size + 1)); - auto data = FlatVector::GetData(input); - auto offset_data = append_data.main_buffer.GetData(); - if (append_data.row_count == 0) { - // first entry - offset_data[0] = 0; - } - // now append the string data to the auxiliary buffer - // the auxiliary buffer's length depends on the string lengths, so we resize as required - auto last_offset = offset_data[append_data.row_count]; - for (idx_t i = 0; i < size; i++) { - auto offset_idx = append_data.row_count + i + 1; - - auto string_length = GetLength(data[i]); - - // append the offset data - auto current_offset = last_offset + string_length; - offset_data[offset_idx] = current_offset; - - // resize the string buffer if required, and write the string data - append_data.aux_buffer.resize(current_offset); - WriteData(append_data.aux_buffer.data() + last_offset, data[i]); - - last_offset = current_offset; - } - append_data.row_count += size; - } - static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - result.main_buffer.reserve(capacity * sizeof(TGT)); - // construct the enum child data - auto enum_data = ArrowAppender::InitializeChild(LogicalType::VARCHAR, EnumType::GetSize(type), result.options); - EnumAppendVector(*enum_data, EnumType::GetValuesInsertOrder(type), EnumType::GetSize(type)); - result.child_data.push_back(std::move(enum_data)); - } - - static void Finalize(ArrowAppendData &append_data, const LogicalType 
&type, ArrowArray *result) { - result->n_buffers = 2; - result->buffers[1] = append_data.main_buffer.data(); - // finalize the enum child data, and assign it to the dictionary - result->dictionary = ArrowAppender::FinalizeChild(LogicalType::VARCHAR, *append_data.child_data[0]); - } -}; - -} // namespace duckdb - - - - - - - - - - - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Varchar -//===--------------------------------------------------------------------===// -struct ArrowVarcharConverter { - template - static idx_t GetLength(SRC input) { - return input.GetSize(); - } - - template - static void WriteData(data_ptr_t target, SRC input) { - memcpy(target, input.GetData(), input.GetSize()); - } -}; - -struct ArrowUUIDConverter { - template - static idx_t GetLength(SRC input) { - return UUID::STRING_SIZE; - } - - template - static void WriteData(data_ptr_t target, SRC input) { - UUID::ToString(input, char_ptr_cast(target)); - } -}; - -template -struct ArrowVarcharData { - static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - result.main_buffer.reserve((capacity + 1) * sizeof(BUFTYPE)); - - result.aux_buffer.reserve(capacity); - } - - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { - idx_t size = to - from; - UnifiedVectorFormat format; - input.ToUnifiedFormat(input_size, format); - - // resize the validity mask and set up the validity buffer for iteration - ResizeValidity(append_data.validity, append_data.row_count + size); - auto validity_data = (uint8_t *)append_data.validity.data(); - - // resize the offset buffer - the offset buffer holds the offsets into the child array - append_data.main_buffer.resize(append_data.main_buffer.size() + sizeof(BUFTYPE) * (size + 1)); - auto data = UnifiedVectorFormat::GetData(format); - auto offset_data = append_data.main_buffer.GetData(); - if (append_data.row_count == 0) { 
- // first entry - offset_data[0] = 0; - } - // now append the string data to the auxiliary buffer - // the auxiliary buffer's length depends on the string lengths, so we resize as required - auto last_offset = offset_data[append_data.row_count]; - idx_t max_offset = append_data.row_count + to - from; - if (max_offset > NumericLimits::Maximum() && - append_data.options.arrow_offset_size == ArrowOffsetSize::REGULAR) { - throw InvalidInputException("Arrow Appender: The maximum total string size for regular string buffers is " - "%u but the offset of %lu exceeds this.", - NumericLimits::Maximum(), max_offset); - } - for (idx_t i = from; i < to; i++) { - auto source_idx = format.sel->get_index(i); - auto offset_idx = append_data.row_count + i + 1 - from; - - if (!format.validity.RowIsValid(source_idx)) { - uint8_t current_bit; - idx_t current_byte; - GetBitPosition(append_data.row_count + i - from, current_byte, current_bit); - SetNull(append_data, validity_data, current_byte, current_bit); - offset_data[offset_idx] = last_offset; - continue; - } - - auto string_length = OP::GetLength(data[source_idx]); - - // append the offset data - auto current_offset = last_offset + string_length; - offset_data[offset_idx] = current_offset; - - // resize the string buffer if required, and write the string data - append_data.aux_buffer.resize(current_offset); - OP::WriteData(append_data.aux_buffer.data() + last_offset, data[source_idx]); - - last_offset = current_offset; - } - append_data.row_count += size; - } - - static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { - result->n_buffers = 3; - result->buffers[1] = append_data.main_buffer.data(); - result->buffers[2] = append_data.aux_buffer.data(); - } -}; - -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/bit.hpp -// -// 
-//===----------------------------------------------------------------------===// - - - - - - - - - - -namespace duckdb { - -//! The Bit class is a static class that holds helper functions for the BIT type. -class Bit { -public: - //! Returns the number of bits in the bit string - DUCKDB_API static idx_t BitLength(string_t bits); - //! Returns the number of set bits in the bit string - DUCKDB_API static idx_t BitCount(string_t bits); - //! Returns the number of bytes in the bit string - DUCKDB_API static idx_t OctetLength(string_t bits); - //! Extracts the nth bit from bit string; the first (leftmost) bit is indexed 0 - DUCKDB_API static idx_t GetBit(string_t bit_string, idx_t n); - //! Sets the nth bit in bit string to newvalue; the first (leftmost) bit is indexed 0 - DUCKDB_API static void SetBit(string_t &bit_string, idx_t n, idx_t new_value); - //! Returns first starting index of the specified substring within bits, or zero if it's not present. - DUCKDB_API static idx_t BitPosition(string_t substring, string_t bits); - //! Converts bits to a string, writing the output to the designated output string. - //! The string needs to have space for at least GetStringSize(bits) bytes. - DUCKDB_API static void ToString(string_t bits, char *output); - DUCKDB_API static string ToString(string_t str); - //! Returns the bit size of a string -> bit conversion - DUCKDB_API static bool TryGetBitStringSize(string_t str, idx_t &result_size, string *error_message); - //! Convert a string to a bit. This function should ONLY be called after calling GetBitSize, since it does NOT - //! perform data validation. - DUCKDB_API static void ToBit(string_t str, string_t &output); - - DUCKDB_API static string ToBit(string_t str); - - //! output needs to have enough space allocated before calling this function (blob size + 1) - DUCKDB_API static void BlobToBit(string_t blob, string_t &output); - - DUCKDB_API static string BlobToBit(string_t blob); - - //! 
output_str needs to have enough space allocated before calling this function (sizeof(T) + 1) - template - static void NumericToBit(T numeric, string_t &output_str); - - template - static string NumericToBit(T numeric); - - //! bit is expected to fit inside of output num (bit size <= sizeof(T) + 1) - template - static void BitToNumeric(string_t bit, T &output_num); - - template - static T BitToNumeric(string_t bit); - - //! bit is expected to fit inside of output_blob (bit size = output_blob + 1) - static void BitToBlob(string_t bit, string_t &output_blob); - - static string BitToBlob(string_t bit); - - //! Creates a new bitstring of determined length - DUCKDB_API static void BitString(const string_t &input, const idx_t &len, string_t &result); - DUCKDB_API static void SetEmptyBitString(string_t &target, string_t &input); - DUCKDB_API static void SetEmptyBitString(string_t &target, idx_t len); - DUCKDB_API static idx_t ComputeBitstringLen(idx_t len); - - DUCKDB_API static void RightShift(const string_t &bit_string, const idx_t &shif, string_t &result); - DUCKDB_API static void LeftShift(const string_t &bit_string, const idx_t &shift, string_t &result); - DUCKDB_API static void BitwiseAnd(const string_t &rhs, const string_t &lhs, string_t &result); - DUCKDB_API static void BitwiseOr(const string_t &rhs, const string_t &lhs, string_t &result); - DUCKDB_API static void BitwiseXor(const string_t &rhs, const string_t &lhs, string_t &result); - DUCKDB_API static void BitwiseNot(const string_t &rhs, string_t &result); - - DUCKDB_API static void Verify(const string_t &input); - -private: - static void Finalize(string_t &str); - static idx_t GetBitInternal(string_t bit_string, idx_t n); - static void SetBitInternal(string_t &bit_string, idx_t n, idx_t new_value); - static idx_t GetBitIndex(idx_t n); - static uint8_t GetFirstByte(const string_t &str); -}; - -//===--------------------------------------------------------------------===// -// Bit Template definitions 
-//===--------------------------------------------------------------------===// -template -void Bit::NumericToBit(T numeric, string_t &output_str) { - D_ASSERT(output_str.GetSize() >= sizeof(T) + 1); - - auto output = output_str.GetDataWriteable(); - auto data = const_data_ptr_cast(&numeric); - - *output = 0; // set padding to 0 - ++output; - for (idx_t idx = 0; idx < sizeof(T); ++idx) { - output[idx] = data[sizeof(T) - idx - 1]; - } - Bit::Finalize(output_str); -} - -template -string Bit::NumericToBit(T numeric) { - auto bit_len = sizeof(T) + 1; - auto buffer = make_unsafe_uniq_array(bit_len); - string_t output_str(buffer.get(), bit_len); - Bit::NumericToBit(numeric, output_str); - return output_str.GetString(); -} - -template -T Bit::BitToNumeric(string_t bit) { - T output; - Bit::BitToNumeric(bit, output); - return (output); -} - -template -void Bit::BitToNumeric(string_t bit, T &output_num) { - D_ASSERT(bit.GetSize() <= sizeof(T) + 1); - - output_num = 0; - auto data = const_data_ptr_cast(bit.GetData()); - auto output = data_ptr_cast(&output_num); - - idx_t padded_byte_idx = sizeof(T) - bit.GetSize() + 1; - output[sizeof(T) - 1 - padded_byte_idx] = GetFirstByte(bit); - for (idx_t idx = padded_byte_idx + 1; idx < sizeof(T); ++idx) { - output[sizeof(T) - 1 - idx] = data[1 + idx - padded_byte_idx]; - } -} - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/sel_cache.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { - -//! 
Selection vector cache used for caching vector slices -struct SelCache { - unordered_map> cache; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/arrow/result_arrow_wrapper.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { -class ResultArrowArrayStreamWrapper { -public: - explicit ResultArrowArrayStreamWrapper(unique_ptr result, idx_t batch_size); - -public: - ArrowArrayStream stream; - unique_ptr result; - PreservedError last_error; - idx_t batch_size; - vector column_types; - vector column_names; - unique_ptr scan_state; - -private: - static int MyStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out); - static int MyStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *out); - static void MyStreamRelease(struct ArrowArrayStream *stream); - static const char *MyStreamGetLastError(struct ArrowArrayStream *stream); -}; -} // namespace duckdb - - - - - -namespace duckdb { - -class QueryResult; - -class QueryResultChunkScanState : public ChunkScanState { -public: - QueryResultChunkScanState(QueryResult &result); - ~QueryResultChunkScanState(); - -public: - bool LoadNextChunk(PreservedError &error) override; - bool HasError() const override; - PreservedError &GetError() override; - const vector &Types() const override; - const vector &Names() const override; - -private: - bool InternalLoad(PreservedError &error); - -private: - QueryResult &result; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/bind_helpers.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -class Value; - -Value ConvertVectorToValue(vector set); -vector ParseColumnList(const vector &set, vector &names, const string &option_name); 
-vector ParseColumnList(const Value &value, vector &names, const string &option_name); -vector ParseColumnsOrdered(const vector &set, vector &names, const string &loption); -vector ParseColumnsOrdered(const Value &value, vector &names, const string &loption); - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/box_renderer.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/main/query_profiler.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - - - - - - -#include - - - -namespace duckdb { -class ClientContext; -class ExpressionExecutor; -class PhysicalOperator; -class SQLStatement; - -//! The ExpressionInfo keeps information related to an expression -struct ExpressionInfo { - explicit ExpressionInfo() : hasfunction(false) { - } - // A vector of children - vector> children; - // Extract ExpressionInformation from a given expression state - void ExtractExpressionsRecursive(unique_ptr &state); - - //! Whether or not expression has function - bool hasfunction; - //! The function Name - string function_name; - //! The function time - uint64_t function_time = 0; - //! Count the number of ALL tuples - uint64_t tuples_count = 0; - //! Count the number of tuples sampled - uint64_t sample_tuples_count = 0; -}; - -//! The ExpressionRootInfo keeps information related to the root of an expression tree -struct ExpressionRootInfo { - ExpressionRootInfo(ExpressionExecutorState &executor, string name); - - //! Count the number of time the executor called - uint64_t total_count = 0; - //! Count the number of time the executor called since last sampling - uint64_t current_count = 0; - //! Count the number of samples - uint64_t sample_count = 0; - //! 
Count the number of tuples in all samples - uint64_t sample_tuples_count = 0; - //! Count the number of tuples processed by this executor - uint64_t tuples_count = 0; - //! A vector which contain the pointer to root of each expression tree - unique_ptr root; - //! Name - string name; - //! Elapsed time - double time; - //! Extra Info - string extra_info; -}; - -struct ExpressionExecutorInfo { - explicit ExpressionExecutorInfo() {}; - explicit ExpressionExecutorInfo(ExpressionExecutor &executor, const string &name, int id); - - //! A vector which contain the pointer to all ExpressionRootInfo - vector> roots; - //! Id, it will be used as index for executors_info vector - int id; -}; - -struct OperatorInformation { - explicit OperatorInformation(double time_ = 0, idx_t elements_ = 0) : time(time_), elements(elements_) { - } - - double time = 0; - idx_t elements = 0; - string name; - //! A vector of Expression Executor Info - vector> executors_info; -}; - -//! The OperatorProfiler measures timings of individual operators -class OperatorProfiler { - friend class QueryProfiler; - -public: - DUCKDB_API explicit OperatorProfiler(bool enabled); - - DUCKDB_API void StartOperator(optional_ptr phys_op); - DUCKDB_API void EndOperator(optional_ptr chunk); - DUCKDB_API void Flush(const PhysicalOperator &phys_op, ExpressionExecutor &expression_executor, const string &name, - int id); - - ~OperatorProfiler() { - } - -private: - void AddTiming(const PhysicalOperator &op, double time, idx_t elements); - - //! Whether or not the profiler is enabled - bool enabled; - //! The timer used to time the execution time of the individual Physical Operators - Profiler op; - //! The stack of Physical Operators that are currently active - optional_ptr active_operator; - //! A mapping of physical operators to recorded timings - reference_map_t timings; -}; - -//! 
The QueryProfiler can be used to measure timings of queries -class QueryProfiler { -public: - DUCKDB_API QueryProfiler(ClientContext &context); - -public: - struct TreeNode { - PhysicalOperatorType type; - string name; - string extra_info; - OperatorInformation info; - vector> children; - idx_t depth = 0; - }; - - // Propagate save_location, enabled, detailed_enabled and automatic_print_format. - void Propagate(QueryProfiler &qp); - - using TreeMap = reference_map_t>; - -private: - unique_ptr CreateTree(const PhysicalOperator &root, idx_t depth = 0); - void Render(const TreeNode &node, std::ostream &str) const; - -public: - DUCKDB_API bool IsEnabled() const; - DUCKDB_API bool IsDetailedEnabled() const; - DUCKDB_API ProfilerPrintFormat GetPrintFormat() const; - DUCKDB_API bool PrintOptimizerOutput() const; - DUCKDB_API string GetSaveLocation() const; - - DUCKDB_API static QueryProfiler &Get(ClientContext &context); - - DUCKDB_API void StartQuery(string query, bool is_explain_analyze = false, bool start_at_optimizer = false); - DUCKDB_API void EndQuery(); - - DUCKDB_API void StartExplainAnalyze(); - - //! Adds the timings gathered by an OperatorProfiler to this query profiler - DUCKDB_API void Flush(OperatorProfiler &profiler); - - DUCKDB_API void StartPhase(string phase); - DUCKDB_API void EndPhase(); - - DUCKDB_API void Initialize(const PhysicalOperator &root); - - DUCKDB_API string QueryTreeToString() const; - DUCKDB_API void QueryTreeToStream(std::ostream &str) const; - DUCKDB_API void Print(); - - //! return the printed as a string. Unlike ToString, which is always formatted as a string, - //! the return value is formatted based on the current print format (see GetPrintFormat()). - DUCKDB_API string ToString() const; - - DUCKDB_API string ToJSON() const; - DUCKDB_API void WriteToFile(const char *path, string &info) const; - - idx_t OperatorSize() { - return tree_map.size(); - } - - void Finalize(TreeNode &node); - -private: - ClientContext &context; - - //! 
Whether or not the query profiler is running - bool running; - //! The lock used for flushing information from a thread into the global query profiler - mutex flush_lock; - - //! Whether or not the query requires profiling - bool query_requires_profiling; - - //! The root of the query tree - unique_ptr root; - //! The query string - string query; - //! The timer used to time the execution time of the entire query - Profiler main_query; - //! A map of a Physical Operator pointer to a tree node - TreeMap tree_map; - //! Whether or not we are running as part of a explain_analyze query - bool is_explain_analyze; - -public: - const TreeMap &GetTreeMap() const { - return tree_map; - } - -private: - //! The timer used to time the individual phases of the planning process - Profiler phase_profiler; - //! A mapping of the phase names to the timings - using PhaseTimingStorage = unordered_map; - PhaseTimingStorage phase_timings; - using PhaseTimingItem = PhaseTimingStorage::value_type; - //! The stack of currently active phases - vector phase_stack; - -private: - vector GetOrderedPhaseTimings() const; - - //! Check whether or not an operator type requires query profiling. If none of the ops in a query require profiling - //! no profiling information is output. - bool OperatorRequiresProfiling(PhysicalOperatorType op_type); -}; - -//! The QueryProfilerHistory can be used to access the profiler of previous queries -class QueryProfilerHistory { -private: - static constexpr uint64_t DEFAULT_SIZE = 20; - - //! Previous Query profilers - deque>> prev_profilers; - //! 
Previous Query profilers size - uint64_t prev_profilers_size = DEFAULT_SIZE; - -public: - deque>> &GetPrevProfilers() { - return prev_profilers; - } - QueryProfilerHistory() { - } - - void SetPrevProfilersSize(uint64_t prevProfilersSize) { - prev_profilers_size = prevProfilersSize; - } - uint64_t GetPrevProfilersSize() const { - return prev_profilers_size; - } - -public: - void SetProfilerHistorySize(uint64_t size) { - this->prev_profilers_size = size; - } - void ResetProfilerHistorySize() { - this->prev_profilers_size = DEFAULT_SIZE; - } -}; -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/list.hpp -// -// -//===----------------------------------------------------------------------===// - - - -#include - -namespace duckdb { -using std::list; -} - - -namespace duckdb { -class ColumnDataCollection; -class ColumnDataRowCollection; - -enum class ValueRenderAlignment { LEFT, MIDDLE, RIGHT }; -enum class RenderMode : uint8_t { ROWS, COLUMNS }; - -struct BoxRendererConfig { - // a max_width of 0 means we default to the terminal width - idx_t max_width = 0; - // the maximum amount of rows to render - idx_t max_rows = 20; - // the limit that is applied prior to rendering - // if we are rendering exactly "limit" rows then a question mark is rendered instead - idx_t limit = 0; - // the max col width determines the maximum size of a single column - // note that the max col width is only used if the result does not fit on the screen - idx_t max_col_width = 20; - //! how to render NULL values - string null_value = "NULL"; - //! 
Whether or not to render row-wise or column-wise - RenderMode render_mode = RenderMode::ROWS; - -#ifndef DUCKDB_ASCII_TREE_RENDERER - const char *LTCORNER = "\342\224\214"; // "┌"; - const char *RTCORNER = "\342\224\220"; // "┐"; - const char *LDCORNER = "\342\224\224"; // "└"; - const char *RDCORNER = "\342\224\230"; // "┘"; - - const char *MIDDLE = "\342\224\274"; // "┼"; - const char *TMIDDLE = "\342\224\254"; // "┬"; - const char *LMIDDLE = "\342\224\234"; // "├"; - const char *RMIDDLE = "\342\224\244"; // "┤"; - const char *DMIDDLE = "\342\224\264"; // "┴"; - - const char *VERTICAL = "\342\224\202"; // "│"; - const char *HORIZONTAL = "\342\224\200"; // "─"; - - const char *DOTDOTDOT = "\xE2\x80\xA6"; // "…"; - const char *DOT = "\xC2\xB7"; // "·"; - const idx_t DOTDOTDOT_LENGTH = 1; - -#else - // ASCII version - const char *LTCORNER = "<"; - const char *RTCORNER = ">"; - const char *LDCORNER = "<"; - const char *RDCORNER = ">"; - - const char *MIDDLE = "+"; - const char *TMIDDLE = "+"; - const char *LMIDDLE = "+"; - const char *RMIDDLE = "+"; - const char *DMIDDLE = "+"; - - const char *VERTICAL = "|"; - const char *HORIZONTAL = "-"; - - const char *DOTDOTDOT = "..."; // "..."; - const char *DOT = "."; // "."; - const idx_t DOTDOTDOT_LENGTH = 3; -#endif -}; - -class BoxRenderer { - static const idx_t SPLIT_COLUMN; - -public: - explicit BoxRenderer(BoxRendererConfig config_p = BoxRendererConfig()); - - string ToString(ClientContext &context, const vector &names, const ColumnDataCollection &op); - - void Render(ClientContext &context, const vector &names, const ColumnDataCollection &op, std::ostream &ss); - void Print(ClientContext &context, const vector &names, const ColumnDataCollection &op); - -private: - //! 
The configuration used for rendering - BoxRendererConfig config; - -private: - void RenderValue(std::ostream &ss, const string &value, idx_t column_width, - ValueRenderAlignment alignment = ValueRenderAlignment::MIDDLE); - string RenderType(const LogicalType &type); - ValueRenderAlignment TypeAlignment(const LogicalType &type); - string GetRenderValue(ColumnDataRowCollection &rows, idx_t c, idx_t r); - list FetchRenderCollections(ClientContext &context, const ColumnDataCollection &result, - idx_t top_rows, idx_t bottom_rows); - list PivotCollections(ClientContext &context, list input, - vector &column_names, vector &result_types, - idx_t row_count); - vector ComputeRenderWidths(const vector &names, const vector &result_types, - list &collections, idx_t min_width, idx_t max_width, - vector &column_map, idx_t &total_length); - void RenderHeader(const vector &names, const vector &result_types, - const vector &column_map, const vector &widths, const vector &boundaries, - idx_t total_length, bool has_results, std::ostream &ss); - void RenderValues(const list &collections, const vector &column_map, - const vector &widths, const vector &result_types, std::ostream &ss); - void RenderRowCount(string row_count_str, string shown_str, const string &column_count_str, - const vector &boundaries, bool has_hidden_rows, bool has_hidden_columns, - idx_t total_length, idx_t row_count, idx_t column_count, idx_t minimum_row_length, - std::ostream &ss); -}; - -} // namespace duckdb - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #3 -// See the end of this file for a list - -//===----------------------------------------------------------------------===// -// DuckDB -// -// utf8proc_wrapper.hpp -// -// -//===----------------------------------------------------------------------===// - - - -#include -#include -#include -#include - -namespace duckdb { - -enum class UnicodeType { INVALID, ASCII, UNICODE }; -enum class 
UnicodeInvalidReason { BYTE_MISMATCH, INVALID_UNICODE }; - -class Utf8Proc { -public: - //! Distinguishes ASCII, Valid UTF8 and Invalid UTF8 strings - static UnicodeType Analyze(const char *s, size_t len, UnicodeInvalidReason *invalid_reason = nullptr, size_t *invalid_pos = nullptr); - //! Performs UTF NFC normalization of string, return value needs to be free'd - static char* Normalize(const char* s, size_t len); - //! Returns whether or not the UTF8 string is valid - static bool IsValid(const char *s, size_t len); - //! Returns the position (in bytes) of the next grapheme cluster - static size_t NextGraphemeCluster(const char *s, size_t len, size_t pos); - //! Returns the position (in bytes) of the previous grapheme cluster - static size_t PreviousGraphemeCluster(const char *s, size_t len, size_t pos); - - //! Transform a codepoint to utf8 and writes it to "c", sets "sz" to the size of the codepoint - static bool CodepointToUtf8(int cp, int &sz, char *c); - //! Returns the codepoint length in bytes when encoded in UTF8 - static int CodepointLength(int cp); - //! Transform a UTF8 string to a codepoint; returns the codepoint and writes the length of the codepoint (in UTF8) to sz - static int32_t UTF8ToCodepoint(const char *c, int &sz); - //! Returns the render width of a single character in a string - static size_t RenderWidth(const char *s, size_t len, size_t pos); - static size_t RenderWidth(const std::string &str); - -}; - -} - - -// LICENSE_CHANGE_END -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/checksum.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -//! 
Compute a checksum over a buffer of size size -uint64_t Checksum(uint8_t *buffer, size_t size); - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/crypto/md5.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { - -class MD5Context { -public: - static constexpr idx_t MD5_HASH_LENGTH_BINARY = 16; - static constexpr idx_t MD5_HASH_LENGTH_TEXT = 32; - -public: - MD5Context(); - - void Add(const_data_ptr_t data, idx_t len) { - MD5Update(data, len); - } - void Add(const char *data); - void Add(string_t string) { - MD5Update(const_data_ptr_cast(string.GetData()), string.GetSize()); - } - void Add(const string &data) { - MD5Update(const_data_ptr_cast(data.c_str()), data.size()); - } - - //! Write the 16-byte (binary) digest to the specified location - void Finish(data_ptr_t out_digest); - //! Write the 32-character digest (in hexadecimal format) to the specified location - void FinishHex(char *out_digest); - //! 
Returns the 32-character digest (in hexadecimal format) as a string - string FinishHex(); - -private: - void MD5Update(const_data_ptr_t data, idx_t len); - static void DigestToBase16(const_data_ptr_t digest, char *zBuf); - - uint32_t buf[4]; - uint32_t bits[2]; - unsigned char in[64]; -}; - -} // namespace duckdb - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #4 -// See the end of this file for a list - -//===----------------------------------------------------------------------===// -// DuckDB -// -// mbedtls_wrapper.hpp -// -// -//===----------------------------------------------------------------------===// - - - -#include - -namespace duckdb_mbedtls { -class MbedTlsWrapper { -public: - static void ComputeSha256Hash(const char* in, size_t in_len, char* out); - static std::string ComputeSha256Hash(const std::string& file_content); - static bool IsValidSha256Signature(const std::string& pubkey, const std::string& signature, const std::string& sha256_hash); - static void Hmac256(const char* key, size_t key_len, const char* message, size_t message_len, char* out); - static void ToBase16(char *in, char *out, size_t len); - - static constexpr size_t SHA256_HASH_LENGTH_BYTES = 32; - static constexpr size_t SHA256_HASH_LENGTH_TEXT = 64; - - class SHA256State { - public: - SHA256State(); - ~SHA256State(); - void AddString(const std::string & str); - std::string Finalize(); - void FinishHex(char *out); - private: - void *sha_context; - }; -}; -} - - -// LICENSE_CHANGE_END -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/enums/date_part_specifier.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -enum class DatePartSpecifier : uint8_t { - // BIGINT values - YEAR, - MONTH, - DAY, - DECADE, - CENTURY, - MILLENNIUM, - MICROSECONDS, - MILLISECONDS, - SECOND, - MINUTE, - 
HOUR, - DOW, - ISODOW, - WEEK, - ISOYEAR, - QUARTER, - DOY, - YEARWEEK, - ERA, - TIMEZONE, - TIMEZONE_HOUR, - TIMEZONE_MINUTE, - - // DOUBLE values - EPOCH, - JULIAN_DAY, - - // Invalid - INVALID, - - // Type ranges - BEGIN_BIGINT = YEAR, - BEGIN_DOUBLE = EPOCH, - BEGIN_INVALID = INVALID, -}; - -inline bool IsBigintDatepart(DatePartSpecifier part_code) { - return size_t(part_code) < size_t(DatePartSpecifier::BEGIN_DOUBLE); -} - -DUCKDB_API bool TryGetDatePartSpecifier(const string &specifier, DatePartSpecifier &result); -DUCKDB_API DatePartSpecifier GetDatePartSpecifier(const string &specifier); - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/enums/set_operation_type.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -enum class SetOperationType : uint8_t { NONE = 0, UNION = 1, EXCEPT = 2, INTERSECT = 3, UNION_BY_NAME = 4 }; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/enums/set_type.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { - -enum class SetType : uint8_t { SET = 0, RESET = 1 }; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/extra_type_info.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { - -//! 
Extra Type Info Type -enum class ExtraTypeInfoType : uint8_t { - INVALID_TYPE_INFO = 0, - GENERIC_TYPE_INFO = 1, - DECIMAL_TYPE_INFO = 2, - STRING_TYPE_INFO = 3, - LIST_TYPE_INFO = 4, - STRUCT_TYPE_INFO = 5, - ENUM_TYPE_INFO = 6, - USER_TYPE_INFO = 7, - AGGREGATE_STATE_TYPE_INFO = 8 -}; - -struct ExtraTypeInfo { - explicit ExtraTypeInfo(ExtraTypeInfoType type); - explicit ExtraTypeInfo(ExtraTypeInfoType type, string alias); - virtual ~ExtraTypeInfo(); - - ExtraTypeInfoType type; - string alias; - -public: - bool Equals(ExtraTypeInfo *other_p) const; - - virtual void Serialize(Serializer &serializer) const; - static shared_ptr Deserialize(Deserializer &source); - - template - TARGET &Cast() { - D_ASSERT(dynamic_cast(this)); - return reinterpret_cast(*this); - } - template - const TARGET &Cast() const { - D_ASSERT(dynamic_cast(this)); - return reinterpret_cast(*this); - } - -protected: - virtual bool EqualsInternal(ExtraTypeInfo *other_p) const; -}; - -struct DecimalTypeInfo : public ExtraTypeInfo { - DecimalTypeInfo(uint8_t width_p, uint8_t scale_p); - - uint8_t width; - uint8_t scale; - -public: - void Serialize(Serializer &serializer) const override; - static shared_ptr Deserialize(Deserializer &source); - -protected: - bool EqualsInternal(ExtraTypeInfo *other_p) const override; - -private: - DecimalTypeInfo(); -}; - -struct StringTypeInfo : public ExtraTypeInfo { - explicit StringTypeInfo(string collation_p); - - string collation; - -public: - void Serialize(Serializer &serializer) const override; - static shared_ptr Deserialize(Deserializer &source); - -protected: - bool EqualsInternal(ExtraTypeInfo *other_p) const override; - -private: - StringTypeInfo(); -}; - -struct ListTypeInfo : public ExtraTypeInfo { - explicit ListTypeInfo(LogicalType child_type_p); - - LogicalType child_type; - -public: - void Serialize(Serializer &serializer) const override; - static shared_ptr Deserialize(Deserializer &source); - -protected: - bool EqualsInternal(ExtraTypeInfo 
*other_p) const override; - -private: - ListTypeInfo(); -}; - -struct StructTypeInfo : public ExtraTypeInfo { - explicit StructTypeInfo(child_list_t child_types_p); - - child_list_t child_types; - -public: - void Serialize(Serializer &serializer) const override; - static shared_ptr Deserialize(Deserializer &deserializer); - -protected: - bool EqualsInternal(ExtraTypeInfo *other_p) const override; - -private: - StructTypeInfo(); -}; - -struct AggregateStateTypeInfo : public ExtraTypeInfo { - explicit AggregateStateTypeInfo(aggregate_state_t state_type_p); - - aggregate_state_t state_type; - -public: - void Serialize(Serializer &serializer) const override; - static shared_ptr Deserialize(Deserializer &source); - -protected: - bool EqualsInternal(ExtraTypeInfo *other_p) const override; - -private: - AggregateStateTypeInfo(); -}; - -struct UserTypeInfo : public ExtraTypeInfo { - explicit UserTypeInfo(string name_p); - - string user_type_name; - -public: - void Serialize(Serializer &serializer) const override; - static shared_ptr Deserialize(Deserializer &source); - -protected: - bool EqualsInternal(ExtraTypeInfo *other_p) const override; - -private: - UserTypeInfo(); -}; - -// If this type is primarily stored in the catalog or not. Enums from Pandas/Factors are not in the catalog. 
-enum EnumDictType : uint8_t { INVALID = 0, VECTOR_DICT = 1 }; - -struct EnumTypeInfo : public ExtraTypeInfo { - explicit EnumTypeInfo(Vector &values_insert_order_p, idx_t dict_size_p); - EnumTypeInfo(const EnumTypeInfo &) = delete; - EnumTypeInfo &operator=(const EnumTypeInfo &) = delete; - -public: - const EnumDictType &GetEnumDictType() const; - const Vector &GetValuesInsertOrder() const; - const idx_t &GetDictSize() const; - static PhysicalType DictType(idx_t size); - - static LogicalType CreateType(Vector &ordered_data, idx_t size); - - void Serialize(Serializer &serializer) const override; - static shared_ptr Deserialize(Deserializer &source); - -protected: - // Equalities are only used in enums with different catalog entries - bool EqualsInternal(ExtraTypeInfo *other_p) const override; - - Vector values_insert_order; - -private: - EnumDictType dict_type; - idx_t dict_size; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/partition_state.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/radix_partitioning.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row/partitioned_tuple_data.hpp -// -// -//===----------------------------------------------------------------------===// - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/fixed_size_map.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -template -struct fixed_size_map_iterator_t; - -template -struct 
const_fixed_size_map_iterator_t; - -template -class fixed_size_map_t { - friend struct fixed_size_map_iterator_t; - friend struct const_fixed_size_map_iterator_t; - -public: - using key_type = idx_t; - using mapped_type = T; - -public: - explicit fixed_size_map_t(idx_t capacity_p = 0) : capacity(capacity_p) { - resize(capacity); - } - - idx_t size() const { - return count; - } - - void resize(idx_t capacity_p) { - capacity = capacity_p; - occupied = ValidityMask(capacity); - values = make_unsafe_uniq_array(capacity + 1); - clear(); - } - - void clear() { - count = 0; - occupied.SetAllInvalid(capacity); - } - - T &operator[](const idx_t &key) { - D_ASSERT(key < capacity); - count += 1 - occupied.RowIsValid(key); - occupied.SetValidUnsafe(key); - return values[key]; - } - - const T &operator[](const idx_t &key) const { - D_ASSERT(key < capacity); - return values[key]; - } - - fixed_size_map_iterator_t begin() { - return fixed_size_map_iterator_t(begin_internal(), *this); - } - - const_fixed_size_map_iterator_t begin() const { - return const_fixed_size_map_iterator_t(begin_internal(), *this); - } - - fixed_size_map_iterator_t end() { - return fixed_size_map_iterator_t(capacity, *this); - } - - const_fixed_size_map_iterator_t end() const { - return const_fixed_size_map_iterator_t(capacity, *this); - } - - fixed_size_map_iterator_t find(const idx_t &index) { - if (occupied.RowIsValid(index)) { - return fixed_size_map_iterator_t(index, *this); - } else { - return end(); - } - } - - const_fixed_size_map_iterator_t find(const idx_t &index) const { - if (occupied.RowIsValid(index)) { - return const_fixed_size_map_iterator_t(index, *this); - } else { - return end(); - } - } - -private: - idx_t begin_internal() const { - idx_t index; - for (index = 0; index < capacity; index++) { - if (occupied.RowIsValid(index)) { - break; - } - } - return index; - } - -private: - idx_t capacity; - idx_t count; - - ValidityMask occupied; - unsafe_unique_array values; -}; - -template -struct 
fixed_size_map_iterator_t { -public: - fixed_size_map_iterator_t(idx_t index_p, fixed_size_map_t &map_p) : map(map_p), current(index_p) { - } - - fixed_size_map_iterator_t &operator++() { - for (current++; current < map.capacity; current++) { - if (map.occupied.RowIsValidUnsafe(current)) { - break; - } - } - return *this; - } - - fixed_size_map_iterator_t operator++(int) { - fixed_size_map_iterator_t tmp = *this; - ++(*this); - return tmp; - } - - idx_t &GetKey() { - return current; - } - - const idx_t &GetKey() const { - return current; - } - - T &GetValue() { - return map.values[current]; - } - - const T &GetValue() const { - return map.values[current]; - } - - friend bool operator==(const fixed_size_map_iterator_t &a, const fixed_size_map_iterator_t &b) { - return a.current == b.current; - } - - friend bool operator!=(const fixed_size_map_iterator_t &a, const fixed_size_map_iterator_t &b) { - return !(a == b); - } - -private: - fixed_size_map_t ↦ - idx_t current; -}; - -template -struct const_fixed_size_map_iterator_t { -public: - const_fixed_size_map_iterator_t(idx_t index_p, const fixed_size_map_t &map_p) : map(map_p), current(index_p) { - } - - const_fixed_size_map_iterator_t &operator++() { - for (current++; current < map.capacity; current++) { - if (map.occupied.RowIsValidUnsafe(current)) { - break; - } - } - return *this; - } - - const_fixed_size_map_iterator_t operator++(int) { - const_fixed_size_map_iterator_t tmp = *this; - ++(*this); - return tmp; - } - - const idx_t &GetKey() const { - return current; - } - - const T &GetValue() const { - return map.values[current]; - } - - friend bool operator==(const const_fixed_size_map_iterator_t &a, const const_fixed_size_map_iterator_t &b) { - return a.current == b.current; - } - - friend bool operator!=(const const_fixed_size_map_iterator_t &a, const const_fixed_size_map_iterator_t &b) { - return !(a == b); - } - -private: - const fixed_size_map_t ↦ - idx_t current; -}; - -} // namespace duckdb - - 
-//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row/tuple_data_allocator.hpp -// -// -//===----------------------------------------------------------------------===// - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row/tuple_data_layout.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - - -namespace duckdb { - -class TupleDataLayout { -public: - using Aggregates = vector; - using ValidityBytes = TemplatedValidityMask; - - //! Creates an empty TupleDataLayout - TupleDataLayout(); - //! Create a copy of this TupleDataLayout - TupleDataLayout Copy() const; - -public: - //! Initializes the TupleDataLayout with the specified types and aggregates to an empty TupleDataLayout - void Initialize(vector types_p, Aggregates aggregates_p, bool align = true, bool heap_offset = true); - //! Initializes the TupleDataLayout with the specified types to an empty TupleDataLayout - void Initialize(vector types, bool align = true, bool heap_offset = true); - //! Initializes the TupleDataLayout with the specified aggregates to an empty TupleDataLayout - void Initialize(Aggregates aggregates_p, bool align = true, bool heap_offset = true); - - //! Returns the number of data columns - inline idx_t ColumnCount() const { - return types.size(); - } - //! Returns a list of the column types for this data chunk - inline const vector &GetTypes() const { - return types; - } - //! Returns the number of aggregates - inline idx_t AggregateCount() const { - return aggregates.size(); - } - //! Returns a list of the aggregates for this data chunk - inline Aggregates &GetAggregates() { - return aggregates; - } - //! 
Returns a map from column id to the struct TupleDataLayout - const inline TupleDataLayout &GetStructLayout(idx_t col_idx) const { - D_ASSERT(struct_layouts->find(col_idx) != struct_layouts->end()); - return struct_layouts->find(col_idx)->second; - } - //! Returns the total width required for each row, including padding - inline idx_t GetRowWidth() const { - return row_width; - } - //! Returns the offset to the start of the data - inline idx_t GetDataOffset() const { - return flag_width; - } - //! Returns the total width required for the data, including padding - inline idx_t GetDataWidth() const { - return data_width; - } - //! Returns the offset to the start of the aggregates - inline idx_t GetAggrOffset() const { - return flag_width + data_width; - } - //! Returns the total width required for the aggregates, including padding - inline idx_t GetAggrWidth() const { - return aggr_width; - } - //! Returns the column offsets into each row - inline const vector &GetOffsets() const { - return offsets; - } - //! Returns whether all columns in this layout are constant size - inline bool AllConstant() const { - return all_constant; - } - //! Gets offset to where heap size is stored - inline idx_t GetHeapSizeOffset() const { - return heap_size_offset; - } - //! Returns whether any of the aggregates have a destructor - inline bool HasDestructor() const { - return has_destructor; - } - -private: - //! The types of the data columns - vector types; - //! The aggregate functions - Aggregates aggregates; - //! Structs are a recursive TupleDataLayout - unique_ptr> struct_layouts; - //! The width of the validity header - idx_t flag_width; - //! The width of the data portion - idx_t data_width; - //! The width of the aggregate state portion - idx_t aggr_width; - //! The width of the entire row - idx_t row_width; - //! The offsets to the columns and aggregate data in each row - vector offsets; - //! Whether all columns in this layout are constant size - bool all_constant; - //! 
Offset to the heap size of every row - idx_t heap_size_offset; - //! Whether any of the aggregates have a destructor - bool has_destructor; -}; - -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row/tuple_data_states.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -enum class TupleDataPinProperties : uint8_t { - INVALID, - //! Keeps all passed blocks pinned while scanning/iterating over the chunks (for both reading/writing) - KEEP_EVERYTHING_PINNED, - //! Unpins blocks after they are done (for both reading/writing) - UNPIN_AFTER_DONE, - //! Destroys blocks after they are done (for reading only) - DESTROY_AFTER_DONE, - //! Assumes all blocks are already pinned (for reading only) - ALREADY_PINNED -}; - -struct TupleDataPinState { - perfect_map_t row_handles; - perfect_map_t heap_handles; - TupleDataPinProperties properties = TupleDataPinProperties::INVALID; -}; - -struct CombinedListData { - UnifiedVectorFormat combined_data; - list_entry_t combined_list_entries[STANDARD_VECTOR_SIZE]; - buffer_ptr selection_data; -}; - -struct TupleDataVectorFormat { - const SelectionVector *original_sel; - SelectionVector original_owned_sel; - - UnifiedVectorFormat unified; - vector children; - unique_ptr combined_list_data; -}; - -struct TupleDataChunkState { - vector vector_data; - vector column_ids; - - Vector row_locations = Vector(LogicalType::POINTER); - Vector heap_locations = Vector(LogicalType::POINTER); - Vector heap_sizes = Vector(LogicalType::UBIGINT); -}; - -struct TupleDataAppendState { - TupleDataPinState pin_state; - TupleDataChunkState chunk_state; -}; - -struct TupleDataScanState { - TupleDataPinState pin_state; - TupleDataChunkState chunk_state; - idx_t segment_index = DConstants::INVALID_INDEX; - idx_t chunk_index = DConstants::INVALID_INDEX; -}; - -struct TupleDataParallelScanState { - 
TupleDataScanState scan_state; - mutex lock; -}; - -using TupleDataLocalScanState = TupleDataScanState; - -} // namespace duckdb - - -namespace duckdb { - -struct TupleDataSegment; -struct TupleDataChunk; -struct TupleDataChunkPart; - -struct TupleDataBlock { -public: - TupleDataBlock(BufferManager &buffer_manager, idx_t capacity_p); - - //! Disable copy constructors - TupleDataBlock(const TupleDataBlock &other) = delete; - TupleDataBlock &operator=(const TupleDataBlock &) = delete; - - //! Enable move constructors - TupleDataBlock(TupleDataBlock &&other) noexcept; - TupleDataBlock &operator=(TupleDataBlock &&) noexcept; - -public: - //! Remaining capacity (in bytes) - idx_t RemainingCapacity() const { - D_ASSERT(size <= capacity); - return capacity - size; - } - - //! Remaining capacity (in rows) - idx_t RemainingCapacity(idx_t row_width) const { - return RemainingCapacity() / row_width; - } - -public: - //! The underlying row block - shared_ptr handle; - //! Capacity (in bytes) - idx_t capacity; - //! Occupied size (in bytes) - idx_t size; -}; - -class TupleDataAllocator { -public: - TupleDataAllocator(BufferManager &buffer_manager, const TupleDataLayout &layout); - TupleDataAllocator(TupleDataAllocator &allocator); - - //! Get the buffer manager - BufferManager &GetBufferManager(); - //! Get the buffer allocator - Allocator &GetAllocator(); - //! Get the layout - const TupleDataLayout &GetLayout() const; - //! Number of row blocks - idx_t RowBlockCount() const; - //! Number of heap blocks - idx_t HeapBlockCount() const; - -public: - //! Builds out the chunks for next append, given the metadata in the append state - void Build(TupleDataSegment &segment, TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - const idx_t append_offset, const idx_t append_count); - //! 
Initializes a chunk, making its pointers valid - void InitializeChunkState(TupleDataSegment &segment, TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - idx_t chunk_idx, bool init_heap); - static void RecomputeHeapPointers(Vector &old_heap_ptrs, const SelectionVector &old_heap_sel, - const data_ptr_t row_locations[], Vector &new_heap_ptrs, const idx_t offset, - const idx_t count, const TupleDataLayout &layout, const idx_t base_col_offset); - //! Releases or stores any handles in the management state that are no longer required - void ReleaseOrStoreHandles(TupleDataPinState &state, TupleDataSegment &segment, TupleDataChunk &chunk, - bool release_heap); - //! Releases or stores ALL handles in the management state - void ReleaseOrStoreHandles(TupleDataPinState &state, TupleDataSegment &segment); - -private: - //! Builds out a single part (grabs the lock) - TupleDataChunkPart BuildChunkPart(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - const idx_t append_offset, const idx_t append_count, TupleDataChunk &chunk); - //! Internal function for InitializeChunkState - void InitializeChunkStateInternal(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, idx_t offset, - bool recompute, bool init_heap_pointers, bool init_heap_sizes, - unsafe_vector> &parts); - //! Internal function for ReleaseOrStoreHandles - static void ReleaseOrStoreHandlesInternal(TupleDataSegment &segment, - unsafe_vector &pinned_row_handles, - perfect_map_t &handles, const perfect_set_t &block_ids, - unsafe_vector &blocks, TupleDataPinProperties properties); - //! Pins the given row block - BufferHandle &PinRowBlock(TupleDataPinState &state, const TupleDataChunkPart &part); - //! Pins the given heap block - BufferHandle &PinHeapBlock(TupleDataPinState &state, const TupleDataChunkPart &part); - //! Gets the pointer to the rows for the given chunk part - data_ptr_t GetRowPointer(TupleDataPinState &state, const TupleDataChunkPart &part); - //! 
Gets the base pointer to the heap for the given chunk part - data_ptr_t GetBaseHeapPointer(TupleDataPinState &state, const TupleDataChunkPart &part); - -private: - //! The buffer manager - BufferManager &buffer_manager; - //! The layout of the data - const TupleDataLayout layout; - //! Blocks storing the fixed-size rows - unsafe_vector row_blocks; - //! Blocks storing the variable-size data of the fixed-size rows (e.g., string, list) - unsafe_vector heap_blocks; - - //! Re-usable arrays used while building buffer space - unsafe_vector> chunk_parts; - unsafe_vector> chunk_part_indices; -}; - -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row/tuple_data_collection.hpp -// -// -//===----------------------------------------------------------------------===// - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row/tuple_data_segment.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - - - -namespace duckdb { - -class TupleDataAllocator; -class TupleDataLayout; - -struct TupleDataChunkPart { -public: - TupleDataChunkPart(mutex &lock); - - //! Disable copy constructors - TupleDataChunkPart(const TupleDataChunkPart &other) = delete; - TupleDataChunkPart &operator=(const TupleDataChunkPart &) = delete; - - //! Enable move constructors - TupleDataChunkPart(TupleDataChunkPart &&other) noexcept; - TupleDataChunkPart &operator=(TupleDataChunkPart &&) noexcept; - - static constexpr const uint32_t INVALID_INDEX = (uint32_t)-1; - -public: - //! Index/offset of the row block - uint32_t row_block_index; - uint32_t row_block_offset; - //! Pointer/index/offset of the heap block - uint32_t heap_block_index; - uint32_t heap_block_offset; - data_ptr_t base_heap_ptr; - //! Total heap size for this chunk part - uint32_t total_heap_size; - //! 
Tuple count for this chunk part - uint32_t count; - //! Lock for recomputing heap pointers (owned by TupleDataChunk) - reference lock; -}; - -struct TupleDataChunk { -public: - TupleDataChunk(); - - //! Disable copy constructors - TupleDataChunk(const TupleDataChunk &other) = delete; - TupleDataChunk &operator=(const TupleDataChunk &) = delete; - - //! Enable move constructors - TupleDataChunk(TupleDataChunk &&other) noexcept; - TupleDataChunk &operator=(TupleDataChunk &&) noexcept; - - //! Add a part to this chunk - void AddPart(TupleDataChunkPart &&part, const TupleDataLayout &layout); - //! Tries to merge the last chunk part into the second-to-last one - void MergeLastChunkPart(const TupleDataLayout &layout); - //! Verify counts of the parts in this chunk - void Verify() const; - -public: - //! The parts of this chunk - unsafe_vector parts; - //! The row block ids referenced by the chunk - perfect_set_t row_block_ids; - //! The heap block ids referenced by the chunk - perfect_set_t heap_block_ids; - //! Tuple count for this chunk - idx_t count; - //! Lock for recomputing heap pointers - unsafe_unique_ptr lock; -}; - -struct TupleDataSegment { -public: - explicit TupleDataSegment(shared_ptr allocator); - - ~TupleDataSegment(); - - //! Disable copy constructors - TupleDataSegment(const TupleDataSegment &other) = delete; - TupleDataSegment &operator=(const TupleDataSegment &) = delete; - - //! Enable move constructors - TupleDataSegment(TupleDataSegment &&other) noexcept; - TupleDataSegment &operator=(TupleDataSegment &&) noexcept; - - //! The number of chunks in this segment - idx_t ChunkCount() const; - //! The size (in bytes) of this segment - idx_t SizeInBytes() const; - //! Unpins all held pins - void Unpin(); - - //! Verify counts of the chunks in this segment - void Verify() const; - //! Verify that all blocks in this segment are pinned - void VerifyEverythingPinned() const; - -public: - //! The allocator for this segment - shared_ptr allocator; - //! 
The chunks of this segment - unsafe_vector chunks; - //! The tuple count of this segment - idx_t count; - //! The data size of this segment - idx_t data_size; - - //! Lock for modifying pinned_handles - mutex pinned_handles_lock; - //! Where handles to row blocks will be stored with TupleDataPinProperties::KEEP_EVERYTHING_PINNED - unsafe_vector pinned_row_handles; - //! Where handles to heap blocks will be stored with TupleDataPinProperties::KEEP_EVERYTHING_PINNED - unsafe_vector pinned_heap_handles; -}; - -} // namespace duckdb - - - -namespace duckdb { - -class TupleDataAllocator; -struct TupleDataScatterFunction; -struct TupleDataGatherFunction; -struct RowOperationsState; - -typedef void (*tuple_data_scatter_function_t)(const Vector &source, const TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, - Vector &heap_locations, const idx_t col_idx, - const UnifiedVectorFormat &list_format, - const vector &child_functions); - -struct TupleDataScatterFunction { - tuple_data_scatter_function_t function; - vector child_functions; -}; - -typedef void (*tuple_data_gather_function_t)(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, - const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, Vector &list_vector, - const vector &child_functions); - -struct TupleDataGatherFunction { - tuple_data_gather_function_t function; - vector child_functions; -}; - -//! TupleDataCollection represents a set of buffer-managed data stored in row format -//! FIXME: rename to RowDataCollection after we phase it out -class TupleDataCollection { - friend class TupleDataChunkIterator; - friend class PartitionedTupleData; - -public: - //! Constructs a TupleDataCollection with the specified layout - TupleDataCollection(BufferManager &buffer_manager, const TupleDataLayout &layout); - //! 
Constructs a TupleDataCollection with the same (shared) allocator - explicit TupleDataCollection(shared_ptr allocator); - - ~TupleDataCollection(); - -public: - //! The layout of the stored rows - const TupleDataLayout &GetLayout() const; - //! The number of rows stored in the tuple data collection - const idx_t &Count() const; - //! The number of chunks stored in the tuple data collection - idx_t ChunkCount() const; - //! The size (in bytes) of the blocks held by this tuple data collection - idx_t SizeInBytes() const; - //! Unpins all held pins - void Unpin(); - - //! Gets the scatter function for the given type - static TupleDataScatterFunction GetScatterFunction(const LogicalType &type, bool within_list = false); - //! Gets the gather function for the given type - static TupleDataGatherFunction GetGatherFunction(const LogicalType &type, bool within_list = false); - - //! Initializes an Append state - useful for optimizing many appends made to the same tuple data collection - void InitializeAppend(TupleDataAppendState &append_state, - TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE); - //! Initializes an Append state - useful for optimizing many appends made to the same tuple data collection - void InitializeAppend(TupleDataAppendState &append_state, vector column_ids, - TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE); - //! Initializes the Pin state of an Append state - //! - Useful for optimizing many appends made to the same tuple data collection - void InitializeAppend(TupleDataPinState &pin_state, - TupleDataPinProperties = TupleDataPinProperties::UNPIN_AFTER_DONE); - //! Initializes the Chunk state of an Append state - //! - Useful for optimizing many appends made to the same tuple data collection - void InitializeChunkState(TupleDataChunkState &chunk_state, vector column_ids = {}); - //! Initializes the Chunk state of an Append state - //! 
- Useful for optimizing many appends made to the same tuple data collection - static void InitializeChunkState(TupleDataChunkState &chunk_state, const vector &types, - vector column_ids = {}); - //! Append a DataChunk directly to this TupleDataCollection - calls InitializeAppend and Append internally - void Append(DataChunk &new_chunk, const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), - idx_t append_count = DConstants::INVALID_INDEX); - //! Append a DataChunk directly to this TupleDataCollection - calls InitializeAppend and Append internally - void Append(DataChunk &new_chunk, vector column_ids, - const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), - const idx_t append_count = DConstants::INVALID_INDEX); - //! Append a DataChunk to this TupleDataCollection using the specified Append state - void Append(TupleDataAppendState &append_state, DataChunk &new_chunk, - const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), - const idx_t append_count = DConstants::INVALID_INDEX); - //! Append a DataChunk to this TupleDataCollection using the specified pin and Chunk states - void Append(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, DataChunk &new_chunk, - const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), - const idx_t append_count = DConstants::INVALID_INDEX); - //! Append a DataChunk to this TupleDataCollection using the specified pin and Chunk states - //! - ToUnifiedFormat has already been called - void AppendUnified(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, DataChunk &new_chunk, - const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), - const idx_t append_count = DConstants::INVALID_INDEX); - - //! Creates a UnifiedVectorFormat in the given Chunk state for the given DataChunk - static void ToUnifiedFormat(TupleDataChunkState &chunk_state, DataChunk &new_chunk); - //! 
Gets the UnifiedVectorFormat from the Chunk state as an array - static void GetVectorData(const TupleDataChunkState &chunk_state, UnifiedVectorFormat result[]); - //! Computes the heap sizes for the new DataChunk that will be appended - static void ComputeHeapSizes(TupleDataChunkState &chunk_state, const DataChunk &new_chunk, - const SelectionVector &append_sel, const idx_t append_count); - - //! Builds out the buffer space for the specified Chunk state - void Build(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, const idx_t append_offset, - const idx_t append_count); - //! Scatters the given DataChunk to the rows in the specified Chunk state - void Scatter(TupleDataChunkState &chunk_state, const DataChunk &new_chunk, const SelectionVector &append_sel, - const idx_t append_count) const; - //! Scatters the given Vector to the given column id to the rows in the specified Chunk state - void Scatter(TupleDataChunkState &chunk_state, const Vector &source, const column_t column_id, - const SelectionVector &append_sel, const idx_t append_count) const; - //! Copy rows from input to the built Chunk state - void CopyRows(TupleDataChunkState &chunk_state, TupleDataChunkState &input, const SelectionVector &append_sel, - const idx_t append_count) const; - - //! Finalizes the Pin state, releasing or storing blocks - void FinalizePinState(TupleDataPinState &pin_state, TupleDataSegment &segment); - //! Finalizes the Pin state, releasing or storing blocks - void FinalizePinState(TupleDataPinState &pin_state); - - //! Appends the other TupleDataCollection to this, destroying the other data collection - void Combine(TupleDataCollection &other); - //! Appends the other TupleDataCollection to this, destroying the other data collection - void Combine(unique_ptr other); - //! Resets the TupleDataCollection, clearing all data - void Reset(); - - //! 
Initializes a chunk with the correct types that can be used to call Append/Scan - void InitializeChunk(DataChunk &chunk) const; - //! Initializes a chunk with the correct types for a given scan state - void InitializeScanChunk(TupleDataScanState &state, DataChunk &chunk) const; - //! Initializes a Scan state for scanning all columns - void InitializeScan(TupleDataScanState &state, - TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; - //! Initializes a Scan state for scanning a subset of the columns - void InitializeScan(TupleDataScanState &state, vector column_ids, - TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; - //! Initialize a parallel scan over the tuple data collection over all columns - void InitializeScan(TupleDataParallelScanState &state, - TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; - //! Initialize a parallel scan over the tuple data collection over a subset of the columns - void InitializeScan(TupleDataParallelScanState &gstate, vector column_ids, - TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; - //! Scans a DataChunk from the TupleDataCollection - bool Scan(TupleDataScanState &state, DataChunk &result); - //! Scans a DataChunk from the TupleDataCollection - bool Scan(TupleDataParallelScanState &gstate, TupleDataLocalScanState &lstate, DataChunk &result); - //! Whether the last scan has been completed on this TupleDataCollection - bool ScanComplete(const TupleDataScanState &state) const; - - //! Gathers a DataChunk from the TupleDataCollection, given the specific row locations (requires full pin) - void Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, DataChunk &result, - const SelectionVector &target_sel) const; - //! Gathers a DataChunk (only the columns given by column_ids) from the TupleDataCollection, - //! 
given the specific row locations (requires full pin) - void Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, - const vector &column_ids, DataChunk &result, const SelectionVector &target_sel) const; - //! Gathers a Vector (from the given column id) from the TupleDataCollection - //! given the specific row locations (requires full pin) - void Gather(Vector &row_locations, const SelectionVector &sel, const idx_t scan_count, const column_t column_id, - Vector &result, const SelectionVector &target_sel) const; - - //! Converts this TupleDataCollection to a string representation - string ToString(); - //! Prints the string representation of this TupleDataCollection - void Print(); - - //! Verify that all blocks are pinned - void VerifyEverythingPinned() const; - -private: - //! Initializes the TupleDataCollection (called by the constructor) - void Initialize(); - //! Gets all column ids - void GetAllColumnIDs(vector &column_ids); - //! Adds a segment to this TupleDataCollection - void AddSegment(TupleDataSegment &&segment); - - //! Computes the heap sizes for the specific Vector that will be appended - static void ComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, TupleDataVectorFormat &source, - const SelectionVector &append_sel, const idx_t append_count); - //! Computes the heap sizes for the specific Vector that will be appended (within a list) - static void WithinListHeapComputeSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, const SelectionVector &append_sel, - const idx_t append_count, const UnifiedVectorFormat &list_data); - //! Computes the heap sizes for the fixed-size type Vector that will be appended (within a list) - static void ComputeFixedWithinListHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, const SelectionVector &append_sel, - const idx_t append_count, const UnifiedVectorFormat &list_data); - //! 
Computes the heap sizes for the string Vector that will be appended (within a list) - static void StringWithinListComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const UnifiedVectorFormat &list_data); - //! Computes the heap sizes for the struct Vector that will be appended (within a list) - static void StructWithinListComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const UnifiedVectorFormat &list_data); - //! Computes the heap sizes for the list Vector that will be appended (within a list) - static void ListWithinListComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, const SelectionVector &append_sel, - const idx_t append_count, const UnifiedVectorFormat &list_data); - - //! Get the next segment/chunk index for the scan - bool NextScanIndex(TupleDataScanState &scan_state, idx_t &segment_index, idx_t &chunk_index); - //! Scans the chunk at the given segment/chunk indices - void ScanAtIndex(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, const vector &column_ids, - idx_t segment_index, idx_t chunk_index, DataChunk &result); - - //! Verify count/data size of this collection - void Verify() const; - -private: - //! The layout of the TupleDataCollection - const TupleDataLayout layout; - //! The TupleDataAllocator - shared_ptr allocator; - //! The number of entries stored in the TupleDataCollection - idx_t count; - //! The size (in bytes) of this TupleDataCollection - idx_t data_size; - //! The data segments of the TupleDataCollection - unsafe_vector segments; - //! The set of scatter functions - vector scatter_functions; - //! The set of gather functions - vector gather_functions; -}; - -} // namespace duckdb - - -namespace duckdb { - -//! 
Local state for parallel partitioning -struct PartitionedTupleDataAppendState { -public: - PartitionedTupleDataAppendState() : partition_indices(LogicalType::UBIGINT) { - } - -public: - Vector partition_indices; - SelectionVector partition_sel; - SelectionVector reverse_partition_sel; - - static constexpr idx_t MAP_THRESHOLD = 256; - perfect_map_t partition_entries; - fixed_size_map_t fixed_partition_entries; - - vector> partition_pin_states; - TupleDataChunkState chunk_state; -}; - -enum class PartitionedTupleDataType : uint8_t { - INVALID, - //! Radix partitioning on a hash column - RADIX -}; - -//! Shared allocators for parallel partitioning -struct PartitionTupleDataAllocators { - mutex lock; - vector> allocators; -}; - -//! PartitionedTupleData represents partitioned row data, which serves as an interface for different types of -//! partitioning, e.g., radix, hive -class PartitionedTupleData { -public: - virtual ~PartitionedTupleData(); - -public: - //! Get the layout of this PartitionedTupleData - const TupleDataLayout &GetLayout() const; - //! Get the partitioning type of this PartitionedTupleData - PartitionedTupleDataType GetType() const; - //! Initializes a local state for parallel partitioning that can be merged into this PartitionedTupleData - void InitializeAppendState(PartitionedTupleDataAppendState &state, - TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; - //! Appends a DataChunk to this PartitionedTupleData - void Append(PartitionedTupleDataAppendState &state, DataChunk &input, - const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), - const idx_t append_count = DConstants::INVALID_INDEX); - //! Appends a DataChunk to this PartitionedTupleData - //! 
- ToUnifiedFormat has already been called - void AppendUnified(PartitionedTupleDataAppendState &state, DataChunk &input, - const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), - const idx_t append_count = DConstants::INVALID_INDEX); - //! Appends rows to this PartitionedTupleData - void Append(PartitionedTupleDataAppendState &state, TupleDataChunkState &input, const idx_t count); - //! Flushes any remaining data in the append state into this PartitionedTupleData - void FlushAppendState(PartitionedTupleDataAppendState &state); - //! Combine another PartitionedTupleData into this PartitionedTupleData - void Combine(PartitionedTupleData &other); - //! Resets this PartitionedTupleData - void Reset(); - //! Repartition this PartitionedTupleData into the new PartitionedTupleData - void Repartition(PartitionedTupleData &new_partitioned_data); - //! Unpins the data - void Unpin(); - //! Get the partitions in this PartitionedTupleData - vector> &GetPartitions(); - //! Get the data of this PartitionedTupleData as a single unpartitioned TupleDataCollection - unique_ptr GetUnpartitioned(); - //! Get the count of this PartitionedTupleData - idx_t Count() const; - //! Get the size (in bytes) of this PartitionedTupleData - idx_t SizeInBytes() const; - //! Get the number of partitions of this PartitionedTupleData - idx_t PartitionCount() const; - //! Converts this PartitionedTupleData to a string representation - string ToString(); - //! Prints the string representation of this PartitionedTupleData - void Print(); - -protected: - //===--------------------------------------------------------------------===// - // Partitioning type implementation interface - //===--------------------------------------------------------------------===// - //! 
Initialize a PartitionedTupleDataAppendState for this type of partitioning (optional) - virtual void InitializeAppendStateInternal(PartitionedTupleDataAppendState &state, - TupleDataPinProperties properties) const { - } - //! Compute the partition indices for this type of partitioning for the input DataChunk and store them in the - //! `partition_data` of the local state. If this type creates partitions on the fly (for, e.g., hive), this - //! function is also in charge of creating new partitions and mapping the input data to a partition index - virtual void ComputePartitionIndices(PartitionedTupleDataAppendState &state, DataChunk &input) { - throw NotImplementedException("ComputePartitionIndices for this type of PartitionedTupleData"); - } - //! Compute partition indices from rows (similar to function above) - virtual void ComputePartitionIndices(Vector &row_locations, idx_t append_count, Vector &partition_indices) const { - throw NotImplementedException("ComputePartitionIndices for this type of PartitionedTupleData"); - } - //! Maximum partition index (optional) - virtual idx_t MaxPartitionIndex() const { - return DConstants::INVALID_INDEX; - } - - //! Whether or not to iterate over the original partitions in reverse order when repartitioning (optional) - virtual bool RepartitionReverseOrder() const { - return false; - } - //! Finalize states while repartitioning - useful for unpinning blocks that are no longer needed (optional) - virtual void RepartitionFinalizeStates(PartitionedTupleData &old_partitioned_data, - PartitionedTupleData &new_partitioned_data, - PartitionedTupleDataAppendState &state, idx_t finished_partition_idx) const { - } - -protected: - //! PartitionedTupleData can only be instantiated by derived classes - PartitionedTupleData(PartitionedTupleDataType type, BufferManager &buffer_manager, const TupleDataLayout &layout); - PartitionedTupleData(const PartitionedTupleData &other); - - //! 
Create a new shared allocator - void CreateAllocator(); - //! Whether to use fixed size map or regular marp - bool UseFixedSizeMap() const; - //! Builds a selection vector in the Append state for the partitions - //! - returns true if everything belongs to the same partition - stores partition index in single_partition_idx - void BuildPartitionSel(PartitionedTupleDataAppendState &state, const SelectionVector &append_sel, - const idx_t append_count); - template - void BuildPartitionSel(PartitionedTupleDataAppendState &state, MAP_TYPE &partition_entries, - const SelectionVector &append_sel, const idx_t append_count); - //! Builds out the buffer space in the partitions - void BuildBufferSpace(PartitionedTupleDataAppendState &state); - template - void BuildBufferSpace(PartitionedTupleDataAppendState &state, const MAP_TYPE &partition_entries); - //! Create a collection for a specific a partition - unique_ptr CreatePartitionCollection(idx_t partition_index) const { - if (allocators) { - return make_uniq(allocators->allocators[partition_index]); - } else { - return make_uniq(buffer_manager, layout); - } - } - //! Verify count/data size of this PartitionedTupleData - void Verify() const; - -protected: - PartitionedTupleDataType type; - BufferManager &buffer_manager; - const TupleDataLayout layout; - idx_t count; - idx_t data_size; - - mutex lock; - shared_ptr allocators; - vector> partitions; - -public: - template - TARGET &Cast() { - D_ASSERT(dynamic_cast(this)); - return reinterpret_cast(*this); - } - template - const TARGET &Cast() const { - D_ASSERT(dynamic_cast(this)); - return reinterpret_cast(*this); - } -}; - -} // namespace duckdb - - -namespace duckdb { - -class BufferManager; -class Vector; -struct UnifiedVectorFormat; -struct SelectionVector; - -//! Generic radix partitioning functions -struct RadixPartitioning { -public: - //! 4096 partitions ought to be enough to go out-of-core properly - static constexpr const idx_t MAX_RADIX_BITS = 12; - - //! 
The number of partitions for a given number of radix bits - static inline constexpr idx_t NumberOfPartitions(idx_t radix_bits) { - return idx_t(1) << radix_bits; - } - - //! Inverse of NumberOfPartitions, given a number of partitions, get the number of radix bits - static inline idx_t RadixBits(idx_t n_partitions) { - D_ASSERT(IsPowerOfTwo(n_partitions)); - for (idx_t r = 0; r < sizeof(idx_t) * 8; r++) { - if (n_partitions == NumberOfPartitions(r)) { - return r; - } - } - throw InternalException("RadixPartitioning::RadixBits unable to find partition count!"); - } - - //! Radix bits begin after uint16_t because these bits are used as salt in the aggregate HT - static inline constexpr idx_t Shift(idx_t radix_bits) { - return (sizeof(hash_t) - sizeof(uint16_t)) * 8 - radix_bits; - } - - //! Mask of the radix bits of the hash - static inline constexpr hash_t Mask(idx_t radix_bits) { - return (hash_t(1 << radix_bits) - 1) << Shift(radix_bits); - } - - //! Select using a cutoff on the radix bits of the hash - static idx_t Select(Vector &hashes, const SelectionVector *sel, idx_t count, idx_t radix_bits, idx_t cutoff, - SelectionVector *true_sel, SelectionVector *false_sel); -}; - -//! 
RadixPartitionedColumnData is a PartitionedColumnData that partitions input based on the radix of a hash -class RadixPartitionedColumnData : public PartitionedColumnData { -public: - RadixPartitionedColumnData(ClientContext &context, vector types, idx_t radix_bits, idx_t hash_col_idx); - RadixPartitionedColumnData(const RadixPartitionedColumnData &other); - ~RadixPartitionedColumnData() override; - - idx_t GetRadixBits() const { - return radix_bits; - } - -protected: - //===--------------------------------------------------------------------===// - // Radix Partitioning interface implementation - //===--------------------------------------------------------------------===// - idx_t BufferSize() const override { - switch (radix_bits) { - case 1: - case 2: - case 3: - case 4: - return GetBufferSize(1 << 1); - case 5: - return GetBufferSize(1 << 2); - case 6: - return GetBufferSize(1 << 3); - default: - return GetBufferSize(1 << 4); - } - } - - void InitializeAppendStateInternal(PartitionedColumnDataAppendState &state) const override; - void ComputePartitionIndices(PartitionedColumnDataAppendState &state, DataChunk &input) override; - - static constexpr idx_t GetBufferSize(idx_t div) { - return STANDARD_VECTOR_SIZE / div == 0 ? 1 : STANDARD_VECTOR_SIZE / div; - } - -private: - //! The number of radix bits - const idx_t radix_bits; - //! The index of the column holding the hashes - const idx_t hash_col_idx; -}; - -//! 
RadixPartitionedTupleData is a PartitionedTupleData that partitions input based on the radix of a hash -class RadixPartitionedTupleData : public PartitionedTupleData { -public: - RadixPartitionedTupleData(BufferManager &buffer_manager, const TupleDataLayout &layout, idx_t radix_bits_p, - idx_t hash_col_idx_p); - RadixPartitionedTupleData(const RadixPartitionedTupleData &other); - ~RadixPartitionedTupleData() override; - - idx_t GetRadixBits() const { - return radix_bits; - } - -private: - void Initialize(); - -protected: - //===--------------------------------------------------------------------===// - // Radix Partitioning interface implementation - //===--------------------------------------------------------------------===// - void InitializeAppendStateInternal(PartitionedTupleDataAppendState &state, - TupleDataPinProperties properties) const override; - void ComputePartitionIndices(PartitionedTupleDataAppendState &state, DataChunk &input) override; - void ComputePartitionIndices(Vector &row_locations, idx_t count, Vector &partition_indices) const override; - idx_t MaxPartitionIndex() const override { - return RadixPartitioning::NumberOfPartitions(radix_bits) - 1; - } - - bool RepartitionReverseOrder() const override { - return true; - } - void RepartitionFinalizeStates(PartitionedTupleData &old_partitioned_data, - PartitionedTupleData &new_partitioned_data, PartitionedTupleDataAppendState &state, - idx_t finished_partition_idx) const override; - -private: - //! The number of radix bits - const idx_t radix_bits; - //! 
The index of the column holding the hashes - const idx_t hash_col_idx; -}; - -} // namespace duckdb - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/parallel/base_pipeline_event.hpp -// -// -//===----------------------------------------------------------------------===// - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/parallel/event.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { -class Executor; -class Task; - -class Event : public std::enable_shared_from_this { -public: - explicit Event(Executor &executor); - virtual ~Event() = default; - -public: - virtual void Schedule() = 0; - //! Called right after the event is finished - virtual void FinishEvent() { - } - //! Called after the event is entirely finished - virtual void FinalizeFinish() { - } - - void FinishTask(); - void Finish(); - - void AddDependency(Event &event); - bool HasDependencies() const { - return total_dependencies != 0; - } - const vector &GetParentsVerification() const; - - void CompleteDependency(); - - void SetTasks(vector> tasks); - - void InsertEvent(shared_ptr replacement_event); - - bool IsFinished() const { - return finished; - } - - virtual void PrintPipeline() { - } - - template - TARGET &Cast() { - D_ASSERT(dynamic_cast(this)); - return reinterpret_cast(*this); - } - template - const TARGET &Cast() const { - D_ASSERT(dynamic_cast(this)); - return reinterpret_cast(*this); - } - -protected: - Executor &executor; - //! The current threads working on the event - atomic finished_tasks; - //! The maximum amount of threads that can work on the event - atomic total_tasks; - - //! The amount of completed dependencies - //! The event can only be started after the dependencies have finished executing - atomic finished_dependencies; - //! 
The total amount of dependencies - idx_t total_dependencies; - - //! The events that depend on this event to run - vector> parents; - //! Raw pointers to the parents (used for verification only) - vector parents_raw; - - //! Whether or not the event is finished executing - atomic finished; -}; - -} // namespace duckdb - - - -namespace duckdb { - -//! A BasePipelineEvent is used as the basis of any event that belongs to a specific pipeline -class BasePipelineEvent : public Event { -public: - explicit BasePipelineEvent(shared_ptr pipeline); - explicit BasePipelineEvent(Pipeline &pipeline); - - void PrintPipeline() override { - pipeline->Print(); - } - - //! The pipeline that this event belongs to - shared_ptr pipeline; -}; - -} // namespace duckdb - - -namespace duckdb { - -class PartitionGlobalHashGroup { -public: - using GlobalSortStatePtr = unique_ptr; - using Orders = vector; - using Types = vector; - - PartitionGlobalHashGroup(BufferManager &buffer_manager, const Orders &partitions, const Orders &orders, - const Types &payload_types, bool external); - - int ComparePartitions(const SBIterator &left, const SBIterator &right) const; - - void ComputeMasks(ValidityMask &partition_mask, ValidityMask &order_mask); - - GlobalSortStatePtr global_sort; - atomic count; - idx_t batch_base; - - // Mask computation - SortLayout partition_layout; -}; - -class PartitionGlobalSinkState { -public: - using HashGroupPtr = unique_ptr; - using Orders = vector; - using Types = vector; - - using GroupingPartition = unique_ptr; - using GroupingAppend = unique_ptr; - - static void GenerateOrderings(Orders &partitions, Orders &orders, - const vector> &partition_bys, const Orders &order_bys, - const vector> &partitions_stats); - - PartitionGlobalSinkState(ClientContext &context, const vector> &partition_bys, - const vector &order_bys, const Types &payload_types, - const vector> &partitions_stats, idx_t estimated_cardinality); - - bool HasMergeTasks() const; - - unique_ptr 
CreatePartition(idx_t new_bits) const; - void SyncPartitioning(const PartitionGlobalSinkState &other); - - void UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); - void CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); - - ClientContext &context; - BufferManager &buffer_manager; - Allocator &allocator; - mutex lock; - - // OVER(PARTITION BY...) (hash grouping) - unique_ptr grouping_data; - //! Payload plus hash column - TupleDataLayout grouping_types; - //! The number of radix bits if this partition is being synced with another - idx_t fixed_bits; - - // OVER(...) (sorting) - Orders partitions; - Orders orders; - const Types payload_types; - vector hash_groups; - bool external; - // Reverse lookup from hash bins to non-empty hash groups - vector bin_groups; - - // OVER() (no sorting) - unique_ptr rows; - unique_ptr strings; - - // Threading - idx_t memory_per_thread; - idx_t max_bits; - atomic count; - -private: - void ResizeGroupingData(idx_t cardinality); - void SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); -}; - -class PartitionLocalSinkState { -public: - using LocalSortStatePtr = unique_ptr; - - PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p); - - // Global state - PartitionGlobalSinkState &gstate; - Allocator &allocator; - - // Shared expression evaluation - ExpressionExecutor executor; - DataChunk group_chunk; - DataChunk payload_chunk; - size_t sort_cols; - - // OVER(PARTITION BY...) (hash grouping) - unique_ptr local_partition; - unique_ptr local_append; - - // OVER(ORDER BY...) (only sorting) - LocalSortStatePtr local_sort; - - // OVER() (no sorting) - RowLayout payload_layout; - unique_ptr rows; - unique_ptr strings; - - //! Compute the hash values - void Hash(DataChunk &input_chunk, Vector &hash_vector); - //! Sink an input chunk - void Sink(DataChunk &input_chunk); - //! 
Merge the state into the global state. - void Combine(); -}; - -enum class PartitionSortStage : uint8_t { INIT, SCAN, PREPARE, MERGE, SORTED }; - -class PartitionLocalMergeState; - -class PartitionGlobalMergeState { -public: - using GroupDataPtr = unique_ptr; - - // OVER(PARTITION BY...) - PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data, hash_t hash_bin); - - // OVER(ORDER BY...) - explicit PartitionGlobalMergeState(PartitionGlobalSinkState &sink); - - bool IsSorted() const { - lock_guard guard(lock); - return stage == PartitionSortStage::SORTED; - } - - bool AssignTask(PartitionLocalMergeState &local_state); - bool TryPrepareNextStage(); - void CompleteTask(); - - PartitionGlobalSinkState &sink; - GroupDataPtr group_data; - PartitionGlobalHashGroup *hash_group; - vector column_ids; - TupleDataParallelScanState chunk_state; - GlobalSortState *global_sort; - const idx_t memory_per_thread; - const idx_t num_threads; - -private: - mutable mutex lock; - PartitionSortStage stage; - idx_t total_tasks; - idx_t tasks_assigned; - idx_t tasks_completed; -}; - -class PartitionLocalMergeState { -public: - explicit PartitionLocalMergeState(PartitionGlobalSinkState &gstate); - - bool TaskFinished() { - return finished; - } - - void Prepare(); - void Scan(); - void Merge(); - - void ExecuteTask(); - - PartitionGlobalMergeState *merge_state; - PartitionSortStage stage; - atomic finished; - - // Sorting buffers - ExpressionExecutor executor; - DataChunk sort_chunk; - DataChunk payload_chunk; -}; - -class PartitionGlobalMergeStates { -public: - struct Callback { - virtual bool HasError() const { - return false; - } - }; - - using PartitionGlobalMergeStatePtr = unique_ptr; - - explicit PartitionGlobalMergeStates(PartitionGlobalSinkState &sink); - - bool ExecuteTask(PartitionLocalMergeState &local_state, Callback &callback); - - vector states; -}; - -class PartitionMergeEvent : public BasePipelineEvent { -public: - 
PartitionMergeEvent(PartitionGlobalSinkState &gstate_p, Pipeline &pipeline_p) - : BasePipelineEvent(pipeline_p), gstate(gstate_p), merge_states(gstate_p) { - } - - PartitionGlobalSinkState &gstate; - PartitionGlobalMergeStates merge_states; - -public: - void Schedule() override; -}; - -} // namespace duckdb - - - - - - - -namespace duckdb { - -class Index; -class ConflictInfo; - -enum class ConflictManagerMode : uint8_t { - SCAN, // gather conflicts without throwing - THROW // throw on the conflicts that were not found during the scan -}; - -enum class LookupResultType : uint8_t { LOOKUP_MISS, LOOKUP_HIT, LOOKUP_NULL }; - -class ConflictManager { -public: - ConflictManager(VerifyExistenceType lookup_type, idx_t input_size, - optional_ptr conflict_info = nullptr); - -public: - void SetIndexCount(idx_t count); - // These methods return a boolean indicating whether we should throw or not - bool AddMiss(idx_t chunk_index); - bool AddHit(idx_t chunk_index, row_t row_id); - bool AddNull(idx_t chunk_index); - VerifyExistenceType LookupType() const; - // This should be called before using the conflicts selection vector - void Finalize(); - idx_t ConflictCount() const; - const ManagedSelection &Conflicts() const; - Vector &RowIds(); - const ConflictInfo &GetConflictInfo() const; - void FinishLookup(); - void SetMode(ConflictManagerMode mode); - -private: - bool IsConflict(LookupResultType type); - const unordered_set &InternalConflictSet() const; - Vector &InternalRowIds(); - Vector &InternalIntermediate(); - ManagedSelection &InternalSelection(); - bool SingleIndexTarget() const; - bool ShouldThrow(idx_t chunk_index) const; - bool ShouldIgnoreNulls() const; - void AddConflictInternal(idx_t chunk_index, row_t row_id); - void AddToConflictSet(idx_t chunk_index); - -private: - VerifyExistenceType lookup_type; - idx_t input_size; - optional_ptr conflict_info; - idx_t index_count; - bool finalized = false; - ManagedSelection conflicts; - unique_ptr row_ids; - // Used to check 
if a given conflict is part of the conflict target or not - unique_ptr> conflict_set; - // Contains 'input_size' booleans, indicating if a given index in the input chunk has a conflict - unique_ptr intermediate_vector; - // Mapping from chunk_index to row_id - vector row_id_map; - // Whether we have already found the one conflict target we're interested in - bool single_index_finished = false; - ConflictManagerMode mode; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/execution/operator/scan/csv/csv_state_machine.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -//! All States of CSV Parsing -enum class CSVState : uint8_t { - STANDARD = 0, //! Regular unquoted field state - DELIMITER = 1, //! State after encountering a field separator (e.g., ;) - RECORD_SEPARATOR = 2, //! State after encountering a record separator (i.e., \n) - CARRIAGE_RETURN = 3, //! State after encountering a carriage return(i.e., \r) - QUOTED = 4, //! State when inside a quoted field - UNQUOTED = 5, //! State when leaving a quoted field - ESCAPE = 6, //! State when encountering an escape character (e.g., \) - EMPTY_LINE = 7, //! State when encountering an empty line (i.e., \r\r \n\n, \n\r) - INVALID = 8 //! Got to an Invalid State, this should error. -}; - -//! The CSV State Machine comprises a state transition array (STA). -//! The STA indicates the current state of parsing based on both the current and preceding characters. -//! This reveals whether we are dealing with a Field, a New Line, a Delimiter, and so forth. -//! The STA's creation depends on the provided quote, character, and delimiter options for that state machine. -//! The motivation behind implementing an STA is to remove branching in regular CSV Parsing by predicting and detecting -//! the states. 
Note: The State Machine is currently utilized solely in the CSV Sniffer. -class CSVStateMachine { -public: - explicit CSVStateMachine(CSVReaderOptions &options_p, const CSVStateMachineOptions &state_machine_options, - shared_ptr buffer_manager_p, - CSVStateMachineCache &csv_state_machine_cache_p); - //! Resets the state machine, so it can be used again - void Reset(); - - //! Aux Function for string UTF8 Verification - void VerifyUTF8(); - - CSVStateMachineCache &csv_state_machine_cache; - - const CSVReaderOptions &options; - CSVBufferIterator csv_buffer_iterator; - //! Stores identified start row for this file (e.g., a file can start with garbage like notes, before the header) - idx_t start_row = 0; - //! The Transition Array is a Finite State Machine - //! It holds the transitions of all states, on all 256 possible different characters - const state_machine_t &transition_array; - - //! Both these variables are used for new line identifier detection - bool single_record_separator = false; - bool carry_on_separator = false; - - //! Variables Used for Sniffing - CSVState state; - CSVState previous_state; - CSVState pre_previous_state; - idx_t cur_rows; - idx_t column_count; - string value; - idx_t rows_read; - idx_t line_start_pos = 0; - - //! Dialect options resulting from sniffing - DialectOptions dialect_options; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/function/scalar/compressed_materialization_functions.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -namespace duckdb { - -struct CompressedMaterializationFunctions { - //! The types we compress integral types to - static const vector IntegralTypes(); - //! The types we compress strings to - static const vector StringTypes(); - - static unique_ptr Bind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments); -}; - -//! 
Needed for (de)serialization without binding -enum class CompressedMaterializationDirection : uint8_t { INVALID = 0, COMPRESS = 1, DECOMPRESS = 2 }; - -struct CMIntegralCompressFun { - static ScalarFunction GetFunction(const LogicalType &input_type, const LogicalType &result_type); - static void RegisterFunction(BuiltinFunctions &set); -}; - -struct CMIntegralDecompressFun { - static ScalarFunction GetFunction(const LogicalType &input_type, const LogicalType &result_type); - static void RegisterFunction(BuiltinFunctions &set); -}; - -struct CMStringCompressFun { - static ScalarFunction GetFunction(const LogicalType &result_type); - static void RegisterFunction(BuiltinFunctions &set); -}; - -struct CMStringDecompressFun { - static ScalarFunction GetFunction(const LogicalType &input_type); - static void RegisterFunction(BuiltinFunctions &set); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/main/capi/capi_internal.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - - - -#include -#include - -#ifdef _WIN32 -#ifndef strdup -#define strdup _strdup -#endif -#endif - -namespace duckdb { - -struct DatabaseData { - unique_ptr database; -}; - -struct PreparedStatementWrapper { - //! Map of name -> values - case_insensitive_map_t values; - unique_ptr statement; -}; - -struct ExtractStatementsWrapper { - vector> statements; - string error; -}; - -struct PendingStatementWrapper { - unique_ptr statement; - bool allow_streaming; -}; - -struct ArrowResultWrapper { - unique_ptr result; - unique_ptr current_chunk; - ClientProperties options; -}; - -struct AppenderWrapper { - unique_ptr appender; - string error; -}; - -enum class CAPIResultSetType : uint8_t { - CAPI_RESULT_TYPE_NONE = 0, - CAPI_RESULT_TYPE_MATERIALIZED, - CAPI_RESULT_TYPE_STREAMING, - CAPI_RESULT_TYPE_DEPRECATED -}; - -struct DuckDBResultData { - //! 
The underlying query result - unique_ptr result; - // Results can only use either the new API or the old API, not a mix of the two - // They start off as "none" and switch to one or the other when an API method is used - CAPIResultSetType result_set_type; -}; - -duckdb_type ConvertCPPTypeToC(const LogicalType &type); -LogicalTypeId ConvertCTypeToCPP(duckdb_type c_type); -idx_t GetCTypeSize(duckdb_type type); -duckdb_state duckdb_translate_result(unique_ptr result, duckdb_result *out); -bool deprecated_materialize_result(duckdb_result *result); - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/main/error_manager.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { -class ClientContext; -class DatabaseInstance; - -enum class ErrorType : uint16_t { - // error message types - UNSIGNED_EXTENSION = 0, - INVALIDATED_TRANSACTION = 1, - INVALIDATED_DATABASE = 2, - - // this should always be the last value - ERROR_COUNT, - INVALID = 65535, -}; - -//! The error manager class is responsible for formatting error messages -//! It allows for error messages to be overridden by extensions and clients -class ErrorManager { -public: - template - string FormatException(ErrorType error_type, Args... params) { - vector values; - return FormatExceptionRecursive(error_type, values, params...); - } - - DUCKDB_API string FormatExceptionRecursive(ErrorType error_type, vector &values); - - template - string FormatExceptionRecursive(ErrorType error_type, vector &values, T param, - Args... params) { - values.push_back(ExceptionFormatValue::CreateFormatValue(param)); - return FormatExceptionRecursive(error_type, values, params...); - } - - template - static string FormatException(ClientContext &context, ErrorType error_type, Args... 
params) { - return Get(context).FormatException(error_type, params...); - } - - DUCKDB_API static string InvalidUnicodeError(const string &input, const string &context); - - //! Adds a custom error for a specific error type - void AddCustomError(ErrorType type, string new_error); - - DUCKDB_API static ErrorManager &Get(ClientContext &context); - DUCKDB_API static ErrorManager &Get(DatabaseInstance &context); - -private: - map custom_errors; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// src/include/duckdb/parallel/interrupt.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -#include -#include - -namespace duckdb { - -//! InterruptMode specifies how operators should block/unblock, note that this will happen transparently to the -//! operator, as the operator only needs to return a BLOCKED result and call the callback using the InterruptState. -//! NO_INTERRUPTS: No blocking mode is specified, an error will be thrown when the operator blocks. Should only be used -//! when manually calling operators of which is known they will never block. -//! TASK: A weak pointer to a task is provided. On the callback, this task will be signalled. If the Task has -//! been deleted, this callback becomes a NOP. This is the preferred way to await blocked pipelines. -//! BLOCKING: The caller has blocked awaiting some synchronization primitive to wait for the callback. -enum class InterruptMode : uint8_t { NO_INTERRUPTS, TASK, BLOCKING }; - -//! Synchronization primitive used to await a callback in InterruptMode::BLOCKING. -struct InterruptDoneSignalState { - //! Called by the callback to signal the interrupt is over - void Signal(); - //! Await the callback signalling the interrupt is over - void Await(); - -protected: - mutex lock; - std::condition_variable cv; - bool done = false; -}; - -//! 
State required to make the callback after some asynchronous operation within an operator source / sink. -class InterruptState { -public: - //! Default interrupt state will be set to InterruptMode::NO_INTERRUPTS and throw an error on use of Callback() - InterruptState(); - //! Register the task to be interrupted and set mode to InterruptMode::TASK, the preferred way to handle interrupts - InterruptState(weak_ptr task); - //! Register signal state and set mode to InterruptMode::BLOCKING, used for code paths without Task. - InterruptState(weak_ptr done_signal); - - //! Perform the callback to indicate the Interrupt is over - DUCKDB_API void Callback() const; - -protected: - //! Current interrupt mode - InterruptMode mode; - //! Task ptr for InterruptMode::TASK - weak_ptr current_task; - //! Signal state for InterruptMode::BLOCKING - weak_ptr signal_state; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/parser/statement/insert_statement.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/parser/statement/update_statement.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - - - -namespace duckdb { - -class UpdateSetInfo { -public: - UpdateSetInfo(); - -public: - unique_ptr Copy() const; - -public: - // The condition that needs to be met to perform the update - unique_ptr condition; - // The columns to update - vector columns; - // The set expressions to execute - vector> expressions; - -protected: - UpdateSetInfo(const UpdateSetInfo &other); -}; - -class UpdateStatement : public SQLStatement { -public: - static constexpr const StatementType TYPE = StatementType::UPDATE_STATEMENT; - -public: - UpdateStatement(); - - unique_ptr table; - unique_ptr from_table; - //! 
keep track of optional returningList if statement contains a RETURNING keyword - vector> returning_list; - unique_ptr set_info; - //! CTEs - CommonTableExpressionMap cte_map; - -protected: - UpdateStatement(const UpdateStatement &other); - -public: - string ToString() const override; - unique_ptr Copy() const override; -}; - -} // namespace duckdb - - -namespace duckdb { -class ExpressionListRef; -class UpdateSetInfo; - -enum class OnConflictAction : uint8_t { - THROW, - NOTHING, - UPDATE, - REPLACE // Only used in transform/bind step, changed to UPDATE later -}; - -enum class InsertColumnOrder : uint8_t { INSERT_BY_POSITION = 0, INSERT_BY_NAME = 1 }; - -class OnConflictInfo { -public: - OnConflictInfo(); - -public: - unique_ptr Copy() const; - -public: - OnConflictAction action_type; - - vector indexed_columns; - //! The SET information (if action_type == UPDATE) - unique_ptr set_info; - //! The condition determining whether we apply the DO .. for conflicts that arise - unique_ptr condition; - -protected: - OnConflictInfo(const OnConflictInfo &other); -}; - -class InsertStatement : public SQLStatement { -public: - static constexpr const StatementType TYPE = StatementType::INSERT_STATEMENT; - -public: - InsertStatement(); - - //! The select statement to insert from - unique_ptr select_statement; - //! Column names to insert into - vector columns; - - //! Table name to insert to - string table; - //! Schema name to insert to - string schema; - //! The catalog name to insert to - string catalog; - - //! keep track of optional returningList if statement contains a RETURNING keyword - vector> returning_list; - - unique_ptr on_conflict_info; - unique_ptr table_ref; - - //! CTEs - CommonTableExpressionMap cte_map; - - //! Whether or not this a DEFAULT VALUES - bool default_values = false; - - //! 
INSERT BY POSITION or INSERT BY NAME - InsertColumnOrder column_order = InsertColumnOrder::INSERT_BY_POSITION; - -protected: - InsertStatement(const InsertStatement &other); - -public: - static string OnConflictActionToString(OnConflictAction action); - string ToString() const override; - unique_ptr Copy() const override; - - //! If the INSERT statement is inserted DIRECTLY from a values list (i.e. INSERT INTO tbl VALUES (...)) this returns - //! the expression list Otherwise, this returns NULL - optional_ptr GetValuesList() const; -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/storage/magic_bytes.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - -namespace duckdb { -class FileSystem; - -enum class DataFileType : uint8_t { - FILE_DOES_NOT_EXIST, // file does not exist - DUCKDB_FILE, // duckdb database file - SQLITE_FILE, // sqlite database file - PARQUET_FILE // parquet file -}; - -class MagicBytes { -public: - static DataFileType CheckMagicBytes(FileSystem *fs, const string &path); -}; - -} // namespace duckdb -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/verification/statement_verifier.hpp -// -// -//===----------------------------------------------------------------------===// - - - - - - - -namespace duckdb { - -enum class VerificationType : uint8_t { - ORIGINAL, - COPIED, - DESERIALIZED, - PARSED, - UNOPTIMIZED, - NO_OPERATOR_CACHING, - PREPARED, - EXTERNAL, - - INVALID -}; - -class StatementVerifier { -public: - StatementVerifier(VerificationType type, string name, unique_ptr statement_p); - explicit StatementVerifier(unique_ptr statement_p); - static unique_ptr Create(VerificationType type, const SQLStatement &statement_p); - virtual ~StatementVerifier() noexcept; - - //! 
Check whether expressions in this verifier and the other verifier match - void CheckExpressions(const StatementVerifier &other) const; - //! Check whether expressions within this verifier match - void CheckExpressions() const; - - //! Run the select statement and store the result - virtual bool Run(ClientContext &context, const string &query, - const std::function(const string &, unique_ptr)> &run); - //! Compare this verifier's results with another verifier - string CompareResults(const StatementVerifier &other); - -public: - const VerificationType type; - const string name; - unique_ptr statement; - const vector> &select_list; - unique_ptr materialized_result; - - virtual bool RequireEquality() const { - return true; - } - - virtual bool DisableOptimizer() const { - return false; - } - - virtual bool DisableOperatorCaching() const { - return false; - } - - virtual bool ForceExternal() const { - return false; - } -}; - -} // namespace duckdb - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #5 -// See the end of this file for a list - -/* - Formatting library for C++ - - Copyright (c) 2012 - present, Victor Zverovich - - Permission is hereby granted, free of charge, to any person obtaining - a copy of this software and associated documentation files (the - "Software"), to deal in the Software without restriction, including - without limitation the rights to use, copy, modify, merge, publish, - distribute, sublicense, and/or sell copies of the Software, and to - permit persons to whom the Software is furnished to do so, subject to - the following conditions: - - The above copyright notice and this permission notice shall be - included in all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE - LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION - OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION - WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - - --- Optional exception to the license --- - - As an exception, if, as a result of your compiling your source code, portions - of this Software are embedded into a machine-executable object form of such - source code, you may redistribute such embedded portions in such object form - without including the above copyright and permission notices. - */ - -#ifndef FMT_FORMAT_H_ -#define FMT_FORMAT_H_ - - - - -// LICENSE_CHANGE_BEGIN -// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #5 -// See the end of this file for a list - -// Formatting library for C++ - the core API -// -// Copyright (c) 2012 - present, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_CORE_H_ -#define FMT_CORE_H_ - -#include // std::FILE -#include -#include -#include -#include - -// The fmt library version in the form major * 10000 + minor * 100 + patch. 
-#define FMT_VERSION 60102 - -#ifdef __has_feature -# define FMT_HAS_FEATURE(x) __has_feature(x) -#else -# define FMT_HAS_FEATURE(x) 0 -#endif - -#if defined(__has_include) && !defined(__INTELLISENSE__) && \ - !(defined(__INTEL_COMPILER) && __INTEL_COMPILER < 1600) -# define FMT_HAS_INCLUDE(x) __has_include(x) -#else -# define FMT_HAS_INCLUDE(x) 0 -#endif - -#ifdef __has_cpp_attribute -# define FMT_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) -#else -# define FMT_HAS_CPP_ATTRIBUTE(x) 0 -#endif - -#if defined(__GNUC__) && !defined(__clang__) -# define FMT_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) -#else -# define FMT_GCC_VERSION 0 -#endif - -#if __cplusplus >= 201103L || defined(__GXX_EXPERIMENTAL_CXX0X__) -# define FMT_HAS_GXX_CXX11 FMT_GCC_VERSION -#else -# define FMT_HAS_GXX_CXX11 0 -#endif - -#ifdef __NVCC__ -# define FMT_NVCC __NVCC__ -#else -# define FMT_NVCC 0 -#endif - -#ifdef _MSC_VER -# define FMT_MSC_VER _MSC_VER -#else -# define FMT_MSC_VER 0 -#endif - -// Check if relaxed C++14 constexpr is supported. -// GCC doesn't allow throw in constexpr until version 6 (bug 67371). -#if FMT_USE_CONSTEXPR -# define FMT_CONSTEXPR inline -# define FMT_CONSTEXPR_DECL -#else -# define FMT_CONSTEXPR inline -# define FMT_CONSTEXPR_DECL -#endif - -#ifndef FMT_OVERRIDE -# if FMT_HAS_FEATURE(cxx_override) || \ - (FMT_GCC_VERSION >= 408 && FMT_HAS_GXX_CXX11) || FMT_MSC_VER >= 1900 -# define FMT_OVERRIDE override -# else -# define FMT_OVERRIDE -# endif -#endif - -// Check if exceptions are disabled. -#ifndef FMT_EXCEPTIONS -# if (defined(__GNUC__) && !defined(__EXCEPTIONS)) || \ - FMT_MSC_VER && !_HAS_EXCEPTIONS -# define FMT_EXCEPTIONS 0 -# else -# define FMT_EXCEPTIONS 1 -# endif -#endif - -// Define FMT_USE_NOEXCEPT to make fmt use noexcept (C++11 feature). 
-#ifndef FMT_USE_NOEXCEPT -# define FMT_USE_NOEXCEPT 0 -#endif - -#if FMT_USE_NOEXCEPT || FMT_HAS_FEATURE(cxx_noexcept) || \ - (FMT_GCC_VERSION >= 408 && FMT_HAS_GXX_CXX11) || FMT_MSC_VER >= 1900 -# define FMT_DETECTED_NOEXCEPT noexcept -# define FMT_HAS_CXX11_NOEXCEPT 1 -#else -# define FMT_DETECTED_NOEXCEPT throw() -# define FMT_HAS_CXX11_NOEXCEPT 0 -#endif - -#ifndef FMT_NOEXCEPT -# if FMT_EXCEPTIONS || FMT_HAS_CXX11_NOEXCEPT -# define FMT_NOEXCEPT FMT_DETECTED_NOEXCEPT -# else -# define FMT_NOEXCEPT -# endif -#endif - -// [[noreturn]] is disabled on MSVC because of bogus unreachable code warnings. -#if FMT_EXCEPTIONS && FMT_HAS_CPP_ATTRIBUTE(noreturn) && !FMT_MSC_VER -# define FMT_NORETURN [[noreturn]] -#else -# define FMT_NORETURN -#endif - -#ifndef FMT_DEPRECATED -# if (FMT_HAS_CPP_ATTRIBUTE(deprecated) && __cplusplus >= 201402L) || \ - FMT_MSC_VER >= 1900 -# define FMT_DEPRECATED [[deprecated]] -# else -# if defined(__GNUC__) || defined(__clang__) -# define FMT_DEPRECATED __attribute__((deprecated)) -# elif FMT_MSC_VER -# define FMT_DEPRECATED __declspec(deprecated) -# else -# define FMT_DEPRECATED /* deprecated */ -# endif -# endif -#endif - -// Workaround broken [[deprecated]] in the Intel compiler and NVCC. 
-#if defined(__INTEL_COMPILER) || FMT_NVCC -# define FMT_DEPRECATED_ALIAS -#else -# define FMT_DEPRECATED_ALIAS FMT_DEPRECATED -#endif - -#ifndef FMT_BEGIN_NAMESPACE -# if FMT_HAS_FEATURE(cxx_inline_namespaces) || FMT_GCC_VERSION >= 404 || \ - FMT_MSC_VER >= 1900 -# define FMT_INLINE_NAMESPACE inline namespace -# define FMT_END_NAMESPACE \ - } \ - } -# else -# define FMT_INLINE_NAMESPACE namespace -# define FMT_END_NAMESPACE \ - } \ - using namespace v6; \ - } -# endif -# define FMT_BEGIN_NAMESPACE \ - namespace duckdb_fmt { \ - FMT_INLINE_NAMESPACE v6 { -#endif - -#if !defined(FMT_HEADER_ONLY) && defined(_WIN32) -# ifdef FMT_EXPORT -# define FMT_API __declspec(dllexport) -# elif defined(FMT_SHARED) -# define FMT_API __declspec(dllimport) -# define FMT_EXTERN_TEMPLATE_API FMT_API -# endif -#endif -#ifndef FMT_API -# define FMT_API -#endif -#ifndef FMT_EXTERN_TEMPLATE_API -# define FMT_EXTERN_TEMPLATE_API -#endif - -#ifndef FMT_HEADER_ONLY -# define FMT_EXTERN extern -#else -# define FMT_EXTERN -#endif - -// libc++ supports string_view in pre-c++17. -#if (FMT_HAS_INCLUDE() && \ - (__cplusplus > 201402L || defined(_LIBCPP_VERSION))) || \ - (defined(_MSVC_LANG) && _MSVC_LANG > 201402L && _MSC_VER >= 1910) -# include -# define FMT_USE_STRING_VIEW -#elif FMT_HAS_INCLUDE("experimental/string_view") && __cplusplus >= 201402L -# include -# define FMT_USE_EXPERIMENTAL_STRING_VIEW -#endif - -FMT_BEGIN_NAMESPACE - -// Implementations of enable_if_t and other types for pre-C++14 systems. 
-template -using enable_if_t = typename std::enable_if::type; -template -using conditional_t = typename std::conditional::type; -template using bool_constant = std::integral_constant; -template -using remove_reference_t = typename std::remove_reference::type; -template -using remove_const_t = typename std::remove_const::type; -template -using remove_cvref_t = typename std::remove_cv>::type; - -struct monostate {}; - -// An enable_if helper to be used in template parameters which results in much -// shorter symbols: https://godbolt.org/z/sWw4vP. Extra parentheses are needed -// to workaround a bug in MSVC 2019 (see #1140 and #1186). -#define FMT_ENABLE_IF(...) enable_if_t<(__VA_ARGS__), int> = 0 - -namespace internal { - -// A workaround for gcc 4.8 to make void_t work in a SFINAE context. -template struct void_t_impl { using type = void; }; - -#ifndef FMT_ASSERT -#define FMT_ASSERT(condition, message) -#endif - -#if defined(FMT_USE_STRING_VIEW) -template using std_string_view = std::basic_string_view; -#elif defined(FMT_USE_EXPERIMENTAL_STRING_VIEW) -template -using std_string_view = std::experimental::basic_string_view; -#else -template struct std_string_view {}; -#endif - -#ifdef FMT_USE_INT128 -// Do nothing. -#elif defined(__SIZEOF_INT128__) -# define FMT_USE_INT128 1 -using int128_t = __int128_t; -using uint128_t = __uint128_t; -#else -# define FMT_USE_INT128 0 -#endif -#if !FMT_USE_INT128 -struct int128_t {}; -struct uint128_t {}; -#endif - -// Casts a nonnegative integer to unsigned. -template -FMT_CONSTEXPR typename std::make_unsigned::type to_unsigned(Int value) { - FMT_ASSERT(value >= 0, "negative value"); - return static_cast::type>(value); -} -} // namespace internal - -template -using void_t = typename internal::void_t_impl::type; - -/** - An implementation of ``std::basic_string_view`` for pre-C++17. It provides a - subset of the API. 
``fmt::basic_string_view`` is used for format strings even - if ``std::string_view`` is available to prevent issues when a library is - compiled with a different ``-std`` option than the client code (which is not - recommended). - */ -template class basic_string_view { - private: - const Char* data_; - size_t size_; - - public: - using char_type = Char; - using iterator = const Char*; - - FMT_CONSTEXPR basic_string_view() FMT_NOEXCEPT : data_(nullptr), size_(0) {} - - /** Constructs a string reference object from a C string and a size. */ - FMT_CONSTEXPR basic_string_view(const Char* s, size_t count) FMT_NOEXCEPT - : data_(s), - size_(count) {} - - /** - \rst - Constructs a string reference object from a C string computing - the size with ``std::char_traits::length``. - \endrst - */ - basic_string_view(const Char* s) - : data_(s), size_(std::char_traits::length(s)) {} - - /** Constructs a string reference from a ``std::basic_string`` object. */ - template - FMT_CONSTEXPR basic_string_view( - const std::basic_string& s) FMT_NOEXCEPT - : data_(s.data()), - size_(s.size()) {} - - template < - typename S, - FMT_ENABLE_IF(std::is_same>::value)> - FMT_CONSTEXPR basic_string_view(S s) FMT_NOEXCEPT : data_(s.data()), - size_(s.size()) {} - - /** Returns a pointer to the string data. */ - FMT_CONSTEXPR const Char* data() const { return data_; } - - /** Returns the string size. */ - FMT_CONSTEXPR size_t size() const { return size_; } - - FMT_CONSTEXPR iterator begin() const { return data_; } - FMT_CONSTEXPR iterator end() const { return data_ + size_; } - - FMT_CONSTEXPR const Char& operator[](size_t pos) const { return data_[pos]; } - - FMT_CONSTEXPR void remove_prefix(size_t n) { - data_ += n; - size_ -= n; - } - - std::string to_string() { - return std::string((char *) data(), size()); - } - - // Lexicographically compare this string reference to other. - int compare(basic_string_view other) const { - size_t str_size = size_ < other.size_ ? 
size_ : other.size_; - int result = std::char_traits::compare(data_, other.data_, str_size); - if (result == 0) - result = size_ == other.size_ ? 0 : (size_ < other.size_ ? -1 : 1); - return result; - } - - friend bool operator==(basic_string_view lhs, basic_string_view rhs) { - return lhs.compare(rhs) == 0; - } - friend bool operator!=(basic_string_view lhs, basic_string_view rhs) { - return lhs.compare(rhs) != 0; - } - friend bool operator<(basic_string_view lhs, basic_string_view rhs) { - return lhs.compare(rhs) < 0; - } - friend bool operator<=(basic_string_view lhs, basic_string_view rhs) { - return lhs.compare(rhs) <= 0; - } - friend bool operator>(basic_string_view lhs, basic_string_view rhs) { - return lhs.compare(rhs) > 0; - } - friend bool operator>=(basic_string_view lhs, basic_string_view rhs) { - return lhs.compare(rhs) >= 0; - } -}; - -using string_view = basic_string_view; -using wstring_view = basic_string_view; - -// A UTF-8 code unit type. -#if FMT_HAS_FEATURE(__cpp_char8_t) -typedef char8_t fmt_char8_t; -#else -typedef char fmt_char8_t; -#endif - -/** Specifies if ``T`` is a character type. Can be specialized by users. */ -template struct is_char : std::false_type {}; -template <> struct is_char : std::true_type {}; -template <> struct is_char : std::true_type {}; -template <> struct is_char : std::true_type {}; -template <> struct is_char : std::true_type {}; - -/** - \rst - Returns a string view of `s`. In order to add custom string type support to - {fmt} provide an overload of `to_string_view` for it in the same namespace as - the type for the argument-dependent lookup to work. 
- - **Example**:: - - namespace my_ns { - inline string_view to_string_view(const my_string& s) { - return {s.data(), s.length()}; - } - } - std::string message = fmt::format(my_string("The answer is {}"), 42); - \endrst - */ -template ::value)> -inline basic_string_view to_string_view(const Char* s) { - return s; -} - -template -inline basic_string_view to_string_view( - const std::basic_string& s) { - return s; -} - -template -inline basic_string_view to_string_view(basic_string_view s) { - return s; -} - -template >::value)> -inline basic_string_view to_string_view( - internal::std_string_view s) { - return s; -} - -// A base class for compile-time strings. It is defined in the fmt namespace to -// make formatting functions visible via ADL, e.g. format(fmt("{}"), 42). -struct compile_string {}; - -template -struct is_compile_string : std::is_base_of {}; - -template ::value)> -FMT_CONSTEXPR basic_string_view to_string_view(const S& s) { - return s; -} - -namespace internal { -void to_string_view(...); -using duckdb_fmt::v6::to_string_view; - -// Specifies whether S is a string type convertible to fmt::basic_string_view. -// It should be a constexpr function but MSVC 2017 fails to compile it in -// enable_if and MSVC 2015 fails to compile it as an alias template. -template -struct is_string : std::is_class()))> { -}; - -template struct char_t_impl {}; -template struct char_t_impl::value>> { - using result = decltype(to_string_view(std::declval())); - using type = typename result::char_type; -}; - -struct error_handler { - FMT_CONSTEXPR error_handler() = default; - FMT_CONSTEXPR error_handler(const error_handler&) = default; - - // This function is intentionally not constexpr to give a compile-time error. - FMT_NORETURN FMT_API void on_error(std::string message); -}; -} // namespace internal - -/** String's character type. 
*/ -template using char_t = typename internal::char_t_impl::type; - -/** - \rst - Parsing context consisting of a format string range being parsed and an - argument counter for automatic indexing. - - You can use one of the following type aliases for common character types: - - +-----------------------+-------------------------------------+ - | Type | Definition | - +=======================+=====================================+ - | format_parse_context | basic_format_parse_context | - +-----------------------+-------------------------------------+ - | wformat_parse_context | basic_format_parse_context | - +-----------------------+-------------------------------------+ - \endrst - */ -template -class basic_format_parse_context : private ErrorHandler { - private: - basic_string_view format_str_; - int next_arg_id_; - - public: - using char_type = Char; - using iterator = typename basic_string_view::iterator; - - explicit FMT_CONSTEXPR basic_format_parse_context( - basic_string_view format_str, ErrorHandler eh = ErrorHandler()) - : ErrorHandler(eh), format_str_(format_str), next_arg_id_(0) {} - - /** - Returns an iterator to the beginning of the format string range being - parsed. - */ - FMT_CONSTEXPR iterator begin() const FMT_NOEXCEPT { - return format_str_.begin(); - } - - /** - Returns an iterator past the end of the format string range being parsed. - */ - FMT_CONSTEXPR iterator end() const FMT_NOEXCEPT { return format_str_.end(); } - - /** Advances the begin iterator to ``it``. */ - FMT_CONSTEXPR void advance_to(iterator it) { - format_str_.remove_prefix(internal::to_unsigned(it - begin())); - } - - /** - Reports an error if using the manual argument indexing; otherwise returns - the next argument index and switches to the automatic indexing. 
- */ - FMT_CONSTEXPR int next_arg_id() { - if (next_arg_id_ >= 0) return next_arg_id_++; - on_error("cannot switch from manual to automatic argument indexing"); - return 0; - } - - /** - Reports an error if using the automatic argument indexing; otherwise - switches to the manual indexing. - */ - FMT_CONSTEXPR void check_arg_id(int) { - if (next_arg_id_ > 0) - on_error("cannot switch from automatic to manual argument indexing"); - else - next_arg_id_ = -1; - } - - FMT_CONSTEXPR void check_arg_id(basic_string_view) {} - - FMT_CONSTEXPR void on_error(std::string message) { - ErrorHandler::on_error(message); - } - - FMT_CONSTEXPR ErrorHandler error_handler() const { return *this; } -}; - -using format_parse_context = basic_format_parse_context; -using wformat_parse_context = basic_format_parse_context; - -template -using basic_parse_context FMT_DEPRECATED_ALIAS = - basic_format_parse_context; -using parse_context FMT_DEPRECATED_ALIAS = basic_format_parse_context; -using wparse_context FMT_DEPRECATED_ALIAS = basic_format_parse_context; - -template class basic_format_arg; -template class basic_format_args; - -// A formatter for objects of type T. -template -struct formatter { - // A deleted default constructor indicates a disabled formatter. - formatter() = delete; -}; - -template -struct FMT_DEPRECATED convert_to_int - : bool_constant::value && - std::is_convertible::value> {}; - -// Specifies if T has an enabled formatter specialization. A type can be -// formattable even if it doesn't have a formatter e.g. via a conversion. -template -using has_formatter = - std::is_constructible>; - -namespace internal { - -/** A contiguous memory buffer with an optional growing ability. */ -template class buffer { - private: - T* ptr_; - std::size_t size_; - std::size_t capacity_; - - protected: - // Don't initialize ptr_ since it is not accessed to save a few cycles. 
- buffer(std::size_t sz) FMT_NOEXCEPT : size_(sz), capacity_(sz) {} - - buffer(T* p = nullptr, std::size_t sz = 0, std::size_t cap = 0) FMT_NOEXCEPT - : ptr_(p), - size_(sz), - capacity_(cap) {} - - /** Sets the buffer data and capacity. */ - void set(T* buf_data, std::size_t buf_capacity) FMT_NOEXCEPT { - ptr_ = buf_data; - capacity_ = buf_capacity; - } - - /** Increases the buffer capacity to hold at least *capacity* elements. */ - virtual void grow(std::size_t capacity) = 0; - - public: - using value_type = T; - using const_reference = const T&; - - buffer(const buffer&) = delete; - void operator=(const buffer&) = delete; - virtual ~buffer() = default; - - T* begin() FMT_NOEXCEPT { return ptr_; } - T* end() FMT_NOEXCEPT { return ptr_ + size_; } - - /** Returns the size of this buffer. */ - std::size_t size() const FMT_NOEXCEPT { return size_; } - - /** Returns the capacity of this buffer. */ - std::size_t capacity() const FMT_NOEXCEPT { return capacity_; } - - /** Returns a pointer to the buffer data. */ - T* data() FMT_NOEXCEPT { return ptr_; } - - /** Returns a pointer to the buffer data. */ - const T* data() const FMT_NOEXCEPT { return ptr_; } - - /** - Resizes the buffer. If T is a POD type new elements may not be initialized. - */ - void resize(std::size_t new_size) { - reserve(new_size); - size_ = new_size; - } - - /** Clears this buffer. */ - void clear() { size_ = 0; } - - /** Reserves space to store at least *capacity* elements. */ - void reserve(std::size_t new_capacity) { - if (new_capacity > capacity_) grow(new_capacity); - } - - void push_back(const T& value) { - reserve(size_ + 1); - ptr_[size_++] = value; - } - - /** Appends data to the end of the buffer. */ - template void append(const U* begin, const U* end); - - T& operator[](std::size_t index) { return ptr_[index]; } - const T& operator[](std::size_t index) const { return ptr_[index]; } -}; - -// A container-backed buffer. 
-template -class container_buffer : public buffer { - private: - Container& container_; - - protected: - void grow(std::size_t capacity) FMT_OVERRIDE { - container_.resize(capacity); - this->set(&container_[0], capacity); - } - - public: - explicit container_buffer(Container& c) - : buffer(c.size()), container_(c) {} -}; - -// Extracts a reference to the container from back_insert_iterator. -template -inline Container& get_container(std::back_insert_iterator it) { - using bi_iterator = std::back_insert_iterator; - struct accessor : bi_iterator { - accessor(bi_iterator iter) : bi_iterator(iter) {} - using bi_iterator::container; - }; - return *accessor(it).container; -} - -template -struct fallback_formatter { - fallback_formatter() = delete; -}; - -// Specifies if T has an enabled fallback_formatter specialization. -template -using has_fallback_formatter = - std::is_constructible>; - -template struct named_arg_base; -template struct named_arg; - -enum type { - none_type, - named_arg_type, - // Integer types should go first, - int_type, - uint_type, - long_long_type, - ulong_long_type, - int128_type, - uint128_type, - bool_type, - char_type, - last_integer_type = char_type, - // followed by floating-point types. - float_type, - double_type, - long_double_type, - last_numeric_type = long_double_type, - cstring_type, - string_type, - pointer_type, - custom_type -}; - -// Maps core type T to the corresponding type enum constant. 
-template -struct type_constant : std::integral_constant {}; - -#define FMT_TYPE_CONSTANT(Type, constant) \ - template \ - struct type_constant : std::integral_constant {} - -FMT_TYPE_CONSTANT(const named_arg_base&, named_arg_type); -FMT_TYPE_CONSTANT(int, int_type); -FMT_TYPE_CONSTANT(unsigned, uint_type); -FMT_TYPE_CONSTANT(long long, long_long_type); -FMT_TYPE_CONSTANT(unsigned long long, ulong_long_type); -FMT_TYPE_CONSTANT(int128_t, int128_type); -FMT_TYPE_CONSTANT(uint128_t, uint128_type); -FMT_TYPE_CONSTANT(bool, bool_type); -FMT_TYPE_CONSTANT(Char, char_type); -FMT_TYPE_CONSTANT(float, float_type); -FMT_TYPE_CONSTANT(double, double_type); -FMT_TYPE_CONSTANT(long double, long_double_type); -FMT_TYPE_CONSTANT(const Char*, cstring_type); -FMT_TYPE_CONSTANT(basic_string_view, string_type); -FMT_TYPE_CONSTANT(const void*, pointer_type); - -FMT_CONSTEXPR bool is_integral_type(type t) { - FMT_ASSERT(t != named_arg_type, "invalid argument type"); - return t > none_type && t <= last_integer_type; -} - -FMT_CONSTEXPR bool is_arithmetic_type(type t) { - FMT_ASSERT(t != named_arg_type, "invalid argument type"); - return t > none_type && t <= last_numeric_type; -} - -template struct string_value { - const Char* data; - std::size_t size; -}; - -template struct custom_value { - using parse_context = basic_format_parse_context; - const void* value; - void (*format)(const void* arg, parse_context& parse_ctx, Context& ctx); -}; - -// A formatting argument value. 
-template class value { - public: - using char_type = typename Context::char_type; - - union { - int int_value; - unsigned uint_value; - long long long_long_value; - unsigned long long ulong_long_value; - int128_t int128_value; - uint128_t uint128_value; - bool bool_value; - char_type char_value; - float float_value; - double double_value; - long double long_double_value; - const void* pointer; - string_value string; - custom_value custom; - const named_arg_base* named_arg; - }; - - FMT_CONSTEXPR value(int val = 0) : int_value(val) {} - FMT_CONSTEXPR value(unsigned val) : uint_value(val) {} - value(long long val) : long_long_value(val) {} - value(unsigned long long val) : ulong_long_value(val) {} - value(int128_t val) : int128_value(val) {} - value(uint128_t val) : uint128_value(val) {} - value(float val) : float_value(val) {} - value(double val) : double_value(val) {} - value(long double val) : long_double_value(val) {} - value(bool val) : bool_value(val) {} - value(char_type val) : char_value(val) {} - value(const char_type* val) { string.data = val; } - value(basic_string_view val) { - string.data = val.data(); - string.size = val.size(); - } - value(const void* val) : pointer(val) {} - - template value(const T& val) { - custom.value = &val; - // Get the formatter type through the context to allow different contexts - // have different extension points, e.g. `formatter` for `format` and - // `printf_formatter` for `printf`. - custom.format = format_custom_arg< - T, conditional_t::value, - typename Context::template formatter_type, - fallback_formatter>>; - } - - value(const named_arg_base& val) { named_arg = &val; } - - private: - // Formats an argument of a custom type, such as a user-defined class. 
- template - static void format_custom_arg( - const void* arg, basic_format_parse_context& parse_ctx, - Context& ctx) { - Formatter f; - parse_ctx.advance_to(f.parse(parse_ctx)); - ctx.advance_to(f.format(*static_cast(arg), ctx)); - } -}; - -template -FMT_CONSTEXPR basic_format_arg make_arg(const T& value); - -// To minimize the number of types we need to deal with, long is translated -// either to int or to long long depending on its size. -enum { long_short = sizeof(long) == sizeof(int) }; -using long_type = conditional_t; -using ulong_type = conditional_t; - -// Maps formatting arguments to core types. -template struct arg_mapper { - using char_type = typename Context::char_type; - - FMT_CONSTEXPR int map(signed char val) { return val; } - FMT_CONSTEXPR unsigned map(unsigned char val) { return val; } - FMT_CONSTEXPR int map(short val) { return val; } - FMT_CONSTEXPR unsigned map(unsigned short val) { return val; } - FMT_CONSTEXPR int map(int val) { return val; } - FMT_CONSTEXPR unsigned map(unsigned val) { return val; } - FMT_CONSTEXPR long_type map(long val) { return val; } - FMT_CONSTEXPR ulong_type map(unsigned long val) { return val; } - FMT_CONSTEXPR long long map(long long val) { return val; } - FMT_CONSTEXPR unsigned long long map(unsigned long long val) { return val; } - FMT_CONSTEXPR int128_t map(int128_t val) { return val; } - FMT_CONSTEXPR uint128_t map(uint128_t val) { return val; } - FMT_CONSTEXPR bool map(bool val) { return val; } - - template ::value)> - FMT_CONSTEXPR char_type map(T val) { - static_assert( - std::is_same::value || std::is_same::value, - "mixing character types is disallowed"); - return val; - } - - FMT_CONSTEXPR float map(float val) { return val; } - FMT_CONSTEXPR double map(double val) { return val; } - FMT_CONSTEXPR long double map(long double val) { return val; } - - FMT_CONSTEXPR const char_type* map(char_type* val) { return val; } - FMT_CONSTEXPR const char_type* map(const char_type* val) { return val; } - template ::value)> 
- FMT_CONSTEXPR basic_string_view map(const T& val) { - static_assert(std::is_same>::value, - "mixing character types is disallowed"); - return to_string_view(val); - } - template , T>::value && - !is_string::value)> - FMT_CONSTEXPR basic_string_view map(const T& val) { - return basic_string_view(val); - } - template < - typename T, - FMT_ENABLE_IF( - std::is_constructible, T>::value && - !std::is_constructible, T>::value && - !is_string::value && !has_formatter::value)> - FMT_CONSTEXPR basic_string_view map(const T& val) { - return std_string_view(val); - } - FMT_CONSTEXPR const char* map(const signed char* val) { - static_assert(std::is_same::value, "invalid string type"); - return reinterpret_cast(val); - } - FMT_CONSTEXPR const char* map(const unsigned char* val) { - static_assert(std::is_same::value, "invalid string type"); - return reinterpret_cast(val); - } - - FMT_CONSTEXPR const void* map(void* val) { return val; } - FMT_CONSTEXPR const void* map(const void* val) { return val; } - FMT_CONSTEXPR const void* map(std::nullptr_t val) { return val; } - template FMT_CONSTEXPR int map(const T*) { - // Formatting of arbitrary pointers is disallowed. If you want to output - // a pointer cast it to "void *" or "const void *". In particular, this - // forbids formatting of "[const] volatile char *" which is printed as bool - // by iostreams. 
- static_assert(!sizeof(T), "formatting of non-void pointers is disallowed"); - return 0; - } - - template ::value && - !has_formatter::value && - !has_fallback_formatter::value)> - FMT_CONSTEXPR auto map(const T& val) -> decltype( - map(static_cast::type>(val))) { - return map(static_cast::type>(val)); - } - template < - typename T, - FMT_ENABLE_IF( - !is_string::value && !is_char::value && - !std::is_constructible, T>::value && - (has_formatter::value || - (has_fallback_formatter::value && - !std::is_constructible, T>::value)))> - FMT_CONSTEXPR const T& map(const T& val) { - return val; - } - - template - FMT_CONSTEXPR const named_arg_base& map( - const named_arg& val) { - auto arg = make_arg(val.value); - std::memcpy(val.data, &arg, sizeof(arg)); - return val; - } -}; - -// A type constant after applying arg_mapper. -template -using mapped_type_constant = - type_constant().map(std::declval())), - typename Context::char_type>; - -enum { packed_arg_bits = 5 }; -// Maximum number of arguments with packed types. -enum { max_packed_args = 63 / packed_arg_bits }; -enum : unsigned long long { is_unpacked_bit = 1ULL << 63 }; - -template class arg_map; -} // namespace internal - -// A formatting argument. It is a trivially copyable/constructible type to -// allow storage in basic_memory_buffer. 
-template class basic_format_arg { - private: - internal::value value_; - internal::type type_; - - template - friend FMT_CONSTEXPR basic_format_arg internal::make_arg( - const T& value); - - template - friend FMT_CONSTEXPR auto visit_format_arg(Visitor&& vis, - const basic_format_arg& arg) - -> decltype(vis(0)); - - friend class basic_format_args; - friend class internal::arg_map; - - using char_type = typename Context::char_type; - - public: - class handle { - public: - explicit handle(internal::custom_value custom) : custom_(custom) {} - - void format(basic_format_parse_context& parse_ctx, - Context& ctx) const { - custom_.format(custom_.value, parse_ctx, ctx); - } - - private: - internal::custom_value custom_; - }; - - FMT_CONSTEXPR basic_format_arg() : type_(internal::none_type) {} - - FMT_CONSTEXPR explicit operator bool() const FMT_NOEXCEPT { - return type_ != internal::none_type; - } - - internal::type type() const { return type_; } - - bool is_integral() const { return internal::is_integral_type(type_); } - bool is_arithmetic() const { return internal::is_arithmetic_type(type_); } -}; - -/** - \rst - Visits an argument dispatching to the appropriate visit method based on - the argument type. For example, if the argument type is ``double`` then - ``vis(value)`` will be called with the value of type ``double``. 
- \endrst - */ -template -FMT_CONSTEXPR auto visit_format_arg(Visitor&& vis, - const basic_format_arg& arg) - -> decltype(vis(0)) { - using char_type = typename Context::char_type; - switch (arg.type_) { - case internal::none_type: - break; - case internal::named_arg_type: - FMT_ASSERT(false, "invalid argument type"); - break; - case internal::int_type: - return vis(arg.value_.int_value); - case internal::uint_type: - return vis(arg.value_.uint_value); - case internal::long_long_type: - return vis(arg.value_.long_long_value); - case internal::ulong_long_type: - return vis(arg.value_.ulong_long_value); -#if FMT_USE_INT128 - case internal::int128_type: - return vis(arg.value_.int128_value); - case internal::uint128_type: - return vis(arg.value_.uint128_value); -#else - case internal::int128_type: - case internal::uint128_type: - break; -#endif - case internal::bool_type: - return vis(arg.value_.bool_value); - case internal::char_type: - return vis(arg.value_.char_value); - case internal::float_type: - return vis(arg.value_.float_value); - case internal::double_type: - return vis(arg.value_.double_value); - case internal::long_double_type: - return vis(arg.value_.long_double_value); - case internal::cstring_type: - return vis(arg.value_.string.data); - case internal::string_type: - return vis(basic_string_view(arg.value_.string.data, - arg.value_.string.size)); - case internal::pointer_type: - return vis(arg.value_.pointer); - case internal::custom_type: - return vis(typename basic_format_arg::handle(arg.value_.custom)); - } - return vis(monostate()); -} - -namespace internal { -// A map from argument names to their values for named arguments. 
-template class arg_map { - private: - using char_type = typename Context::char_type; - - struct entry { - basic_string_view name; - basic_format_arg arg; - }; - - entry* map_; - unsigned size_; - - void push_back(value val) { - const auto& named = *val.named_arg; - map_[size_] = {named.name, named.template deserialize()}; - ++size_; - } - - public: - arg_map(const arg_map&) = delete; - void operator=(const arg_map&) = delete; - arg_map() : map_(nullptr), size_(0) {} - void init(const basic_format_args& args); - ~arg_map() { delete[] map_; } - - basic_format_arg find(basic_string_view name) const { - // The list is unsorted, so just return the first matching name. - for (entry *it = map_, *end = map_ + size_; it != end; ++it) { - if (it->name == name) return it->arg; - } - return {}; - } -}; - -// A type-erased reference to an std::locale to avoid heavy include. -class locale_ref { - private: - const void* locale_; // A type-erased pointer to std::locale. - - public: - locale_ref() : locale_(nullptr) {} - template explicit locale_ref(const Locale& loc); - - explicit operator bool() const FMT_NOEXCEPT { return locale_ != nullptr; } - - template Locale get() const; -}; - -template constexpr unsigned long long encode_types() { return 0; } - -template -constexpr unsigned long long encode_types() { - return mapped_type_constant::value | - (encode_types() << packed_arg_bits); -} - -template -FMT_CONSTEXPR basic_format_arg make_arg(const T& value) { - basic_format_arg arg; - arg.type_ = mapped_type_constant::value; - arg.value_ = arg_mapper().map(value); - return arg; -} - -template -inline value make_arg(const T& val) { - return arg_mapper().map(val); -} - -template -inline basic_format_arg make_arg(const T& value) { - return make_arg(value); -} -} // namespace internal - -// Formatting context. -template class basic_format_context { - public: - /** The character type for the output. 
*/ - using char_type = Char; - - private: - OutputIt out_; - basic_format_args args_; - internal::arg_map map_; - internal::locale_ref loc_; - - public: - using iterator = OutputIt; - using format_arg = basic_format_arg; - template using formatter_type = formatter; - - basic_format_context(const basic_format_context&) = delete; - void operator=(const basic_format_context&) = delete; - /** - Constructs a ``basic_format_context`` object. References to the arguments are - stored in the object so make sure they have appropriate lifetimes. - */ - basic_format_context(OutputIt out, - basic_format_args ctx_args, - internal::locale_ref loc = internal::locale_ref()) - : out_(out), args_(ctx_args), loc_(loc) {} - - format_arg arg(int id) const { return args_.get(id); } - - // Checks if manual indexing is used and returns the argument with the - // specified name. - format_arg arg(basic_string_view name); - - internal::error_handler error_handler() { return {}; } - void on_error(std::string message) { error_handler().on_error(message); } - - // Returns an iterator to the beginning of the output range. - iterator out() { return out_; } - - // Advances the begin iterator to ``it``. - void advance_to(iterator it) { out_ = it; } - - internal::locale_ref locale() { return loc_; } -}; - -template -using buffer_context = - basic_format_context>, - Char>; -using format_context = buffer_context; -using wformat_context = buffer_context; - -/** - \rst - An array of references to arguments. It can be implicitly converted into - `~fmt::basic_format_args` for passing into type-erased formatting functions - such as `~fmt::vformat`. - \endrst - */ -template class format_arg_store { - private: - static const size_t num_args = sizeof...(Args); - static const bool is_packed = num_args < internal::max_packed_args; - - using value_type = conditional_t, - basic_format_arg>; - - // If the arguments are not packed, add one more element to mark the end. - value_type data_[num_args + (num_args == 0 ? 
1 : 0)]; - - friend class basic_format_args; - - public: - static constexpr unsigned long long types = - is_packed ? internal::encode_types() - : internal::is_unpacked_bit | num_args; - - format_arg_store(const Args&... args) - : data_{internal::make_arg(args)...} {} -}; - -/** - \rst - Constructs an `~fmt::format_arg_store` object that contains references to - arguments and can be implicitly converted to `~fmt::format_args`. `Context` - can be omitted in which case it defaults to `~fmt::context`. - See `~fmt::arg` for lifetime considerations. - \endrst - */ -template -inline format_arg_store make_format_args( - const Args&... args) { - return {args...}; -} - -/** Formatting arguments. */ -template class basic_format_args { - public: - using size_type = int; - using format_arg = basic_format_arg; - - private: - // To reduce compiled code size per formatting function call, types of first - // max_packed_args arguments are passed in the types_ field. - unsigned long long types_; - union { - // If the number of arguments is less than max_packed_args, the argument - // values are stored in values_, otherwise they are stored in args_. - // This is done to reduce compiled code size as storing larger objects - // may require more code (at least on x86-64) even if the same amount of - // data is actually copied to stack. It saves ~10% on the bloat test. 
- const internal::value* values_; - const format_arg* args_; - }; - - bool is_packed() const { return (types_ & internal::is_unpacked_bit) == 0; } - - internal::type type(int index) const { - int shift = index * internal::packed_arg_bits; - unsigned int mask = (1 << internal::packed_arg_bits) - 1; - return static_cast((types_ >> shift) & mask); - } - - friend class internal::arg_map; - - void set_data(const internal::value* values) { values_ = values; } - void set_data(const format_arg* args) { args_ = args; } - - format_arg do_get(int index) const { - format_arg arg; - if (!is_packed()) { - auto num_args = max_size(); - if (index < num_args) arg = args_[index]; - return arg; - } - if (index > internal::max_packed_args) return arg; - arg.type_ = type(index); - if (arg.type_ == internal::none_type) return arg; - internal::value& val = arg.value_; - val = values_[index]; - return arg; - } - - public: - basic_format_args() : types_(0) {} - - /** - \rst - Constructs a `basic_format_args` object from `~fmt::format_arg_store`. - \endrst - */ - template - basic_format_args(const format_arg_store& store) - : types_(store.types) { - set_data(store.data_); - } - - /** - \rst - Constructs a `basic_format_args` object from a dynamic set of arguments. - \endrst - */ - basic_format_args(const format_arg* args, int count) - : types_(internal::is_unpacked_bit | internal::to_unsigned(count)) { - set_data(args); - } - - /** Returns the argument at specified index. */ - format_arg get(int index) const { - format_arg arg = do_get(index); - if (arg.type_ == internal::named_arg_type) - arg = arg.value_.named_arg->template deserialize(); - return arg; - } - - int max_size() const { - unsigned long long max_packed = internal::max_packed_args; - return static_cast(is_packed() ? max_packed - : types_ & ~internal::is_unpacked_bit); - } -}; - -/** An alias to ``basic_format_args``. */ -// It is a separate type rather than an alias to make symbols readable. 
-struct format_args : basic_format_args { - template - format_args(Args&&... args) - : basic_format_args(std::forward(args)...) {} -}; -struct wformat_args : basic_format_args { - template - wformat_args(Args&&... args) - : basic_format_args(std::forward(args)...) {} -}; - -template struct is_contiguous : std::false_type {}; - -template -struct is_contiguous> : std::true_type {}; - -template -struct is_contiguous> : std::true_type {}; - -namespace internal { - -template -struct is_contiguous_back_insert_iterator : std::false_type {}; -template -struct is_contiguous_back_insert_iterator> - : is_contiguous {}; - -template struct named_arg_base { - basic_string_view name; - - // Serialized value. - mutable char data[sizeof(basic_format_arg>)]; - - named_arg_base(basic_string_view nm) : name(nm) {} - - template basic_format_arg deserialize() const { - basic_format_arg arg; - std::memcpy(&arg, data, sizeof(basic_format_arg)); - return arg; - } -}; - -template struct named_arg : named_arg_base { - const T& value; - - named_arg(basic_string_view name, const T& val) - : named_arg_base(name), value(val) {} -}; - -template ::value)> -inline void check_format_string(const S&) { -#if defined(FMT_ENFORCE_COMPILE_STRING) - static_assert(is_compile_string::value, - "FMT_ENFORCE_COMPILE_STRING requires all format strings to " - "utilize FMT_STRING() or fmt()."); -#endif -} -template ::value)> -void check_format_string(S); - -struct view {}; -template struct bool_pack; -template -using all_true = - std::is_same, bool_pack>; - -template > -inline format_arg_store, remove_reference_t...> -make_args_checked(const S& format_str, - const remove_reference_t&... 
args) { - static_assert(all_true<(!std::is_base_of>() || - !std::is_reference())...>::value, - "passing views as lvalues is disallowed"); - check_format_string>...>(format_str); - return {args...}; -} - -template -std::basic_string vformat(basic_string_view format_str, - basic_format_args> args); - -template -typename buffer_context::iterator vformat_to( - buffer& buf, basic_string_view format_str, - basic_format_args> args); -} // namespace internal - -/** - \rst - Returns a named argument to be used in a formatting function. - - The named argument holds a reference and does not extend the lifetime - of its arguments. - Consequently, a dangling reference can accidentally be created. - The user should take care to only pass this function temporaries when - the named argument is itself a temporary, as per the following example. - - **Example**:: - - fmt::print("Elapsed time: {s:.2f} seconds", fmt::arg("s", 1.23)); - \endrst - */ -template > -inline internal::named_arg arg(const S& name, const T& arg) { - static_assert(internal::is_string::value, ""); - return {name, arg}; -} - -// Disable nested named arguments, e.g. ``arg("a", arg("b", 42))``. -template -void arg(S, internal::named_arg) = delete; - -/** Formats a string and writes the output to ``out``. */ -// GCC 8 and earlier cannot handle std::back_insert_iterator with -// vformat_to(...) overload, so SFINAE on iterator type instead. -template , - FMT_ENABLE_IF( - internal::is_contiguous_back_insert_iterator::value)> -OutputIt vformat_to(OutputIt out, const S& format_str, - basic_format_args> args) { - using container = remove_reference_t; - internal::container_buffer buf((internal::get_container(out))); - internal::vformat_to(buf, to_string_view(format_str), args); - return out; -} - -template ::value&& internal::is_string::value)> -inline std::back_insert_iterator format_to( - std::back_insert_iterator out, const S& format_str, - Args&&... 
args) { - return vformat_to( - out, to_string_view(format_str), - {internal::make_args_checked(format_str, args...)}); -} - -template > -inline std::basic_string vformat( - const S& format_str, basic_format_args> args) { - return internal::vformat(to_string_view(format_str), args); -} - -/** - \rst - Formats arguments and returns the result as a string. - - **Example**:: - - #include - std::string message = fmt::format("The answer is {}", 42); - \endrst -*/ -// Pass char_t as a default template parameter instead of using -// std::basic_string> to reduce the symbol size. -template > -inline std::basic_string format(const S& format_str, Args&&... args) { - return internal::vformat( - to_string_view(format_str), - {internal::make_args_checked(format_str, args...)}); -} - -FMT_END_NAMESPACE - -#endif // FMT_CORE_H_ - - -// LICENSE_CHANGE_END - - -#include -#include -#include -#include -#include -#include -#include - -#ifdef __clang__ -# define FMT_CLANG_VERSION (__clang_major__ * 100 + __clang_minor__) -#else -# define FMT_CLANG_VERSION 0 -#endif - -#ifdef __INTEL_COMPILER -# define FMT_ICC_VERSION __INTEL_COMPILER -#elif defined(__ICL) -# define FMT_ICC_VERSION __ICL -#else -# define FMT_ICC_VERSION 0 -#endif - -#ifdef __NVCC__ -# define FMT_CUDA_VERSION (__CUDACC_VER_MAJOR__ * 100 + __CUDACC_VER_MINOR__) -#else -# define FMT_CUDA_VERSION 0 -#endif - -#ifdef __has_builtin -# define FMT_HAS_BUILTIN(x) __has_builtin(x) -#else -# define FMT_HAS_BUILTIN(x) 0 -#endif - -#if FMT_HAS_CPP_ATTRIBUTE(fallthrough) && \ - (__cplusplus >= 201703 || FMT_GCC_VERSION != 0) -# define FMT_FALLTHROUGH [[fallthrough]] -#else -# define FMT_FALLTHROUGH -#endif - -#ifndef FMT_THROW -# if FMT_EXCEPTIONS -# if FMT_MSC_VER -FMT_BEGIN_NAMESPACE -namespace internal { -template inline void do_throw(const Exception& x) { - // Silence unreachable code warnings in MSVC because these are nearly - // impossible to fix in a generic code. 
- volatile bool b = true; - if (b) throw x; -} -} // namespace internal -FMT_END_NAMESPACE -# define FMT_THROW(x) internal::do_throw(x) -# else -# define FMT_THROW(x) throw x -# endif -# else -# define FMT_THROW(x) \ - do { \ - static_cast(sizeof(x)); \ - FMT_ASSERT(false, ""); \ - } while (false) -# endif -#endif - -#ifndef FMT_USE_USER_DEFINED_LITERALS -// For Intel and NVIDIA compilers both they and the system gcc/msc support UDLs. -# if (FMT_HAS_FEATURE(cxx_user_literals) || FMT_GCC_VERSION >= 407 || \ - FMT_MSC_VER >= 1900) && \ - (!(FMT_ICC_VERSION || FMT_CUDA_VERSION) || FMT_ICC_VERSION >= 1500 || \ - FMT_CUDA_VERSION >= 700) -# define FMT_USE_USER_DEFINED_LITERALS 1 -# else -# define FMT_USE_USER_DEFINED_LITERALS 0 -# endif -#endif - -#ifndef FMT_USE_UDL_TEMPLATE -#define FMT_USE_UDL_TEMPLATE 0 -#endif - -// __builtin_clz is broken in clang with Microsoft CodeGen: -// https://github.com/fmtlib/fmt/issues/519 -#if (FMT_GCC_VERSION || FMT_HAS_BUILTIN(__builtin_clz)) && !FMT_MSC_VER -# define FMT_BUILTIN_CLZ(n) __builtin_clz(n) -#endif -#if (FMT_GCC_VERSION || FMT_HAS_BUILTIN(__builtin_clzll)) && !FMT_MSC_VER -# define FMT_BUILTIN_CLZLL(n) __builtin_clzll(n) -#endif - -// Some compilers masquerade as both MSVC and GCC-likes or otherwise support -// __builtin_clz and __builtin_clzll, so only define FMT_BUILTIN_CLZ using the -// MSVC intrinsics if the clz and clzll builtins are not available. -#if FMT_MSC_VER && !defined(FMT_BUILTIN_CLZLL) && !defined(_MANAGED) -# include // _BitScanReverse, _BitScanReverse64 - -FMT_BEGIN_NAMESPACE -namespace internal { -// Avoid Clang with Microsoft CodeGen's -Wunknown-pragmas warning. -# ifndef __clang__ -# pragma intrinsic(_BitScanReverse) -# endif -inline uint32_t clz(uint32_t x) { - unsigned long r = 0; - _BitScanReverse(&r, x); - - FMT_ASSERT(x != 0, ""); - // Static analysis complains about using uninitialized data - // "r", but the only way that can happen is if "x" is 0, - // which the callers guarantee to not happen. 
-# pragma warning(suppress : 6102) - return 31 - r; -} -# define FMT_BUILTIN_CLZ(n) internal::clz(n) - -# if defined(_WIN64) && !defined(__clang__) -# pragma intrinsic(_BitScanReverse64) -# endif - -inline uint32_t clzll(uint64_t x) { - unsigned long r = 0; -# ifdef _WIN64 - _BitScanReverse64(&r, x); -# else - // Scan the high 32 bits. - if (_BitScanReverse(&r, static_cast(x >> 32))) return 63 - (r + 32); - - // Scan the low 32 bits. - _BitScanReverse(&r, static_cast(x)); -# endif - - FMT_ASSERT(x != 0, ""); - // Static analysis complains about using uninitialized data - // "r", but the only way that can happen is if "x" is 0, - // which the callers guarantee to not happen. -# pragma warning(suppress : 6102) - return 63 - r; -} -# define FMT_BUILTIN_CLZLL(n) internal::clzll(n) -} // namespace internal -FMT_END_NAMESPACE -#endif - -// Enable the deprecated numeric alignment. -#ifndef FMT_NUMERIC_ALIGN -# define FMT_NUMERIC_ALIGN 1 -#endif - -// Enable the deprecated percent specifier. -#ifndef FMT_DEPRECATED_PERCENT -# define FMT_DEPRECATED_PERCENT 0 -#endif - -FMT_BEGIN_NAMESPACE -namespace internal { - -// A helper function to suppress bogus "conditional expression is constant" -// warnings. -template inline T const_check(T value) { return value; } - -// An equivalent of `*reinterpret_cast(&source)` that doesn't have -// undefined behavior (e.g. due to type aliasing). -// Example: uint64_t d = bit_cast(2.718); -template -inline Dest bit_cast(const Source& source) { - static_assert(sizeof(Dest) == sizeof(Source), "size mismatch"); - Dest dest; - std::memcpy(&dest, &source, sizeof(dest)); - return dest; -} - -inline bool is_big_endian() { - auto u = 1u; - struct bytes { - char data[sizeof(u)]; - }; - return bit_cast(u).data[0] == 0; -} - -// A fallback implementation of uintptr_t for systems that lack it. 
-struct fallback_uintptr { - unsigned char value[sizeof(void*)]; - - fallback_uintptr() = default; - explicit fallback_uintptr(const void* p) { - *this = bit_cast(p); - if (is_big_endian()) { - for (size_t i = 0, j = sizeof(void*) - 1; i < j; ++i, --j) - std::swap(value[i], value[j]); - } - } -}; -#ifdef UINTPTR_MAX -using uintptr_t = ::uintptr_t; -inline uintptr_t to_uintptr(const void* p) { return bit_cast(p); } -#else -using uintptr_t = fallback_uintptr; -inline fallback_uintptr to_uintptr(const void* p) { - return fallback_uintptr(p); -} -#endif - -// Returns the largest possible value for type T. Same as -// std::numeric_limits::max() but shorter and not affected by the max macro. -template constexpr T max_value() { - return (std::numeric_limits::max)(); -} -template constexpr int num_bits() { - return std::numeric_limits::digits; -} -template <> constexpr int num_bits() { - return static_cast(sizeof(void*) * - std::numeric_limits::digits); -} - -// An approximation of iterator_t for pre-C++20 systems. -template -using iterator_t = decltype(std::begin(std::declval())); - -// Detect the iterator category of *any* given type in a SFINAE-friendly way. -// Unfortunately, older implementations of std::iterator_traits are not safe -// for use in a SFINAE-context. -template -struct iterator_category : std::false_type {}; - -template struct iterator_category { - using type = std::random_access_iterator_tag; -}; - -template -struct iterator_category> { - using type = typename It::iterator_category; -}; - -// Detect if *any* given type models the OutputIterator concept. -template class is_output_iterator { - // Check for mutability because all iterator categories derived from - // std::input_iterator_tag *may* also meet the requirements of an - // OutputIterator, thereby falling into the category of 'mutable iterators' - // [iterator.requirements.general] clause 4. The compiler reveals this - // property only at the point of *actually dereferencing* the iterator! 
- template - static decltype(*(std::declval())) test(std::input_iterator_tag); - template static char& test(std::output_iterator_tag); - template static const char& test(...); - - using type = decltype(test(typename iterator_category::type{})); - - public: - static const bool value = !std::is_const>::value; -}; - -// A workaround for std::string not having mutable data() until C++17. -template inline Char* get_data(std::basic_string& s) { - return &s[0]; -} -template -inline typename Container::value_type* get_data(Container& c) { - return c.data(); -} - -#ifdef _SECURE_SCL -// Make a checked iterator to avoid MSVC warnings. -template using checked_ptr = stdext::checked_array_iterator; -template checked_ptr make_checked(T* p, std::size_t size) { - return {p, size}; -} -#else -template using checked_ptr = T*; -template inline T* make_checked(T* p, std::size_t) { return p; } -#endif - -template ::value)> -inline checked_ptr reserve( - std::back_insert_iterator& it, std::size_t n) { - Container& c = get_container(it); - std::size_t size = c.size(); - c.resize(size + n); - return make_checked(get_data(c) + size, n); -} - -template -inline Iterator& reserve(Iterator& it, std::size_t) { - return it; -} - -// An output iterator that counts the number of objects written to it and -// discards them. -class counting_iterator { - private: - std::size_t count_; - - public: - using iterator_category = std::output_iterator_tag; - using difference_type = std::ptrdiff_t; - using pointer = void; - using reference = void; - using _Unchecked_type = counting_iterator; // Mark iterator as checked. 
- - struct value_type { - template void operator=(const T&) {} - }; - - counting_iterator() : count_(0) {} - - std::size_t count() const { return count_; } - - counting_iterator& operator++() { - ++count_; - return *this; - } - - counting_iterator operator++(int) { - auto it = *this; - ++*this; - return it; - } - - value_type operator*() const { return {}; } -}; - -template class truncating_iterator_base { - protected: - OutputIt out_; - std::size_t limit_; - std::size_t count_; - - truncating_iterator_base(OutputIt out, std::size_t limit) - : out_(out), limit_(limit), count_(0) {} - - public: - using iterator_category = std::output_iterator_tag; - using difference_type = void; - using pointer = void; - using reference = void; - using _Unchecked_type = - truncating_iterator_base; // Mark iterator as checked. - - OutputIt base() const { return out_; } - std::size_t count() const { return count_; } -}; - -// An output iterator that truncates the output and counts the number of objects -// written to it. -template ::value_type>::type> -class truncating_iterator; - -template -class truncating_iterator - : public truncating_iterator_base { - using traits = std::iterator_traits; - - mutable typename traits::value_type blackhole_; - - public: - using value_type = typename traits::value_type; - - truncating_iterator(OutputIt out, std::size_t limit) - : truncating_iterator_base(out, limit) {} - - truncating_iterator& operator++() { - if (this->count_++ < this->limit_) ++this->out_; - return *this; - } - - truncating_iterator operator++(int) { - auto it = *this; - ++*this; - return it; - } - - value_type& operator*() const { - return this->count_ < this->limit_ ? 
*this->out_ : blackhole_; - } -}; - -template -class truncating_iterator - : public truncating_iterator_base { - public: - using value_type = typename OutputIt::container_type::value_type; - - truncating_iterator(OutputIt out, std::size_t limit) - : truncating_iterator_base(out, limit) {} - - truncating_iterator& operator=(value_type val) { - if (this->count_++ < this->limit_) this->out_ = val; - return *this; - } - - truncating_iterator& operator++() { return *this; } - truncating_iterator& operator++(int) { return *this; } - truncating_iterator& operator*() { return *this; } -}; - -// A range with the specified output iterator and value type. -template -class output_range { - private: - OutputIt it_; - - public: - using value_type = T; - using iterator = OutputIt; - struct sentinel {}; - - explicit output_range(OutputIt it) : it_(it) {} - OutputIt begin() const { return it_; } - sentinel end() const { return {}; } // Sentinel is not used yet. -}; - -template -inline size_t count_code_points(basic_string_view s) { - return s.size(); -} - -// Counts the number of code points in a UTF-8 string. -inline size_t count_code_points(basic_string_view s) { - const fmt_char8_t* data = s.data(); - size_t num_code_points = 0; - for (size_t i = 0, size = s.size(); i != size; ++i) { - if ((data[i] & 0xc0) != 0x80) ++num_code_points; - } - return num_code_points; -} - -template -inline size_t code_point_index(basic_string_view s, size_t n) { - size_t size = s.size(); - return n < size ? n : size; -} - -// Calculates the index of the nth code point in a UTF-8 string. 
-inline size_t code_point_index(basic_string_view s, size_t n) { - const fmt_char8_t* data = s.data(); - size_t num_code_points = 0; - for (size_t i = 0, size = s.size(); i != size; ++i) { - if ((data[i] & 0xc0) != 0x80 && ++num_code_points > n) { - return i; - } - } - return s.size(); -} - -inline fmt_char8_t to_fmt_char8_t(char c) { return static_cast(c); } - -template -using needs_conversion = bool_constant< - std::is_same::value_type, - char>::value && - std::is_same::value>; - -template ::value)> -OutputIt copy_str(InputIt begin, InputIt end, OutputIt it) { - return std::copy(begin, end, it); -} - -template ::value)> -OutputIt copy_str(InputIt begin, InputIt end, OutputIt it) { - return std::transform(begin, end, it, to_fmt_char8_t); -} - -#ifndef FMT_USE_GRISU -# define FMT_USE_GRISU 1 -#endif - -template constexpr bool use_grisu() { - return FMT_USE_GRISU && std::numeric_limits::is_iec559 && - sizeof(T) <= sizeof(double); -} - -template -template -void buffer::append(const U* begin, const U* end) { - std::size_t new_size = size_ + to_unsigned(end - begin); - reserve(new_size); - std::uninitialized_copy(begin, end, make_checked(ptr_, capacity_) + size_); - size_ = new_size; -} -} // namespace internal - -// A range with an iterator appending to a buffer. -template -class buffer_range : public internal::output_range< - std::back_insert_iterator>, T> { - public: - using iterator = std::back_insert_iterator>; - using internal::output_range::output_range; - buffer_range(internal::buffer& buf) - : internal::output_range(std::back_inserter(buf)) {} -}; - -// A UTF-8 string view. 
-class u8string_view : public basic_string_view { - public: - u8string_view(const char* s) - : basic_string_view(reinterpret_cast(s)) {} - u8string_view(const char* s, size_t count) FMT_NOEXCEPT - : basic_string_view(reinterpret_cast(s), count) { - } -}; - -#if FMT_USE_USER_DEFINED_LITERALS -inline namespace literals { -inline u8string_view operator"" _u(const char* s, std::size_t n) { - return {s, n}; -} -} // namespace literals -#endif - -// The number of characters to store in the basic_memory_buffer object itself -// to avoid dynamic memory allocation. -enum { inline_buffer_size = 500 }; - -/** - \rst - A dynamically growing memory buffer for trivially copyable/constructible types - with the first ``SIZE`` elements stored in the object itself. - - You can use one of the following type aliases for common character types: - - +----------------+------------------------------+ - | Type | Definition | - +================+==============================+ - | memory_buffer | basic_memory_buffer | - +----------------+------------------------------+ - | wmemory_buffer | basic_memory_buffer | - +----------------+------------------------------+ - - **Example**:: - - fmt::memory_buffer out; - format_to(out, "The answer is {}.", 42); - - This will append the following output to the ``out`` object: - - .. code-block:: none - - The answer is 42. - - The output can be converted to an ``std::string`` with ``to_string(out)``. - \endrst - */ -template > -class basic_memory_buffer : private Allocator, public internal::buffer { - private: - T store_[SIZE]; - - // Deallocate memory allocated by the buffer. 
- void deallocate() { - T* data = this->data(); - if (data != store_) Allocator::deallocate(data, this->capacity()); - } - - protected: - void grow(std::size_t size) FMT_OVERRIDE; - - public: - using value_type = T; - using const_reference = const T&; - - explicit basic_memory_buffer(const Allocator& alloc = Allocator()) - : Allocator(alloc) { - this->set(store_, SIZE); - } - ~basic_memory_buffer() FMT_OVERRIDE { deallocate(); } - - private: - // Move data from other to this buffer. - void move(basic_memory_buffer& other) { - Allocator &this_alloc = *this, &other_alloc = other; - this_alloc = std::move(other_alloc); - T* data = other.data(); - std::size_t size = other.size(), capacity = other.capacity(); - if (data == other.store_) { - this->set(store_, capacity); - std::uninitialized_copy(other.store_, other.store_ + size, - internal::make_checked(store_, capacity)); - } else { - this->set(data, capacity); - // Set pointer to the inline array so that delete is not called - // when deallocating. - other.set(other.store_, 0); - } - this->resize(size); - } - - public: - /** - \rst - Constructs a :class:`fmt::basic_memory_buffer` object moving the content - of the other object to it. - \endrst - */ - basic_memory_buffer(basic_memory_buffer&& other) FMT_NOEXCEPT { move(other); } - - /** - \rst - Moves the content of the other ``basic_memory_buffer`` object to this one. - \endrst - */ - basic_memory_buffer& operator=(basic_memory_buffer&& other) FMT_NOEXCEPT { - FMT_ASSERT(this != &other, ""); - deallocate(); - move(other); - return *this; - } - - // Returns a copy of the allocator associated with this buffer. 
- Allocator get_allocator() const { return *this; } -}; - -template -void basic_memory_buffer::grow(std::size_t size) { -#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION - if (size > 1000) throw std::runtime_error("fuzz mode - won't grow that much"); -#endif - std::size_t old_capacity = this->capacity(); - std::size_t new_capacity = old_capacity + old_capacity / 2; - if (size > new_capacity) new_capacity = size; - T* old_data = this->data(); - T* new_data = std::allocator_traits::allocate(*this, new_capacity); - // The following code doesn't throw, so the raw pointer above doesn't leak. - std::uninitialized_copy(old_data, old_data + this->size(), - internal::make_checked(new_data, new_capacity)); - this->set(new_data, new_capacity); - // deallocate must not throw according to the standard, but even if it does, - // the buffer already uses the new storage and will deallocate it in - // destructor. - if (old_data != store_) Allocator::deallocate(old_data, old_capacity); -} - -using memory_buffer = basic_memory_buffer; -using wmemory_buffer = basic_memory_buffer; - -namespace internal { - -// Returns true if value is negative, false otherwise. -// Same as `value < 0` but doesn't produce warnings if T is an unsigned type. -template ::is_signed)> -FMT_CONSTEXPR bool is_negative(T value) { - return value < 0; -} -template ::is_signed)> -FMT_CONSTEXPR bool is_negative(T) { - return false; -} - -// Smallest of uint32_t, uint64_t, uint128_t that is large enough to -// represent all values of T. -template -using uint32_or_64_or_128_t = conditional_t< - std::numeric_limits::digits <= 32, uint32_t, - conditional_t::digits <= 64, uint64_t, uint128_t>>; - -// Static data is placed in this class template for the header-only config. 
-template struct FMT_EXTERN_TEMPLATE_API basic_data { - static const uint64_t powers_of_10_64[]; - static const uint32_t zero_or_powers_of_10_32[]; - static const uint64_t zero_or_powers_of_10_64[]; - static const uint64_t pow10_significands[]; - static const int16_t pow10_exponents[]; - static const char digits[]; - static const char hex_digits[]; - static const char foreground_color[]; - static const char background_color[]; - static const char reset_color[5]; - static const wchar_t wreset_color[5]; - static const char signs[]; -}; - -FMT_EXTERN template struct basic_data; - -// This is a struct rather than an alias to avoid shadowing warnings in gcc. -struct data : basic_data<> {}; - -#ifdef FMT_BUILTIN_CLZLL -// Returns the number of decimal digits in n. Leading zeros are not counted -// except for n == 0 in which case count_digits returns 1. -inline int count_digits(uint64_t n) { - // Based on http://graphics.stanford.edu/~seander/bithacks.html#IntegerLog10 - // and the benchmark https://github.com/localvoid/cxx-benchmark-count-digits. - int t = (64 - FMT_BUILTIN_CLZLL(n | 1)) * 1233 >> 12; - return t - (n < data::zero_or_powers_of_10_64[t]) + 1; -} -#else -// Fallback version of count_digits used when __builtin_clz is not available. -inline int count_digits(uint64_t n) { - int count = 1; - for (;;) { - // Integer division is slow so do it for a group of four digits instead - // of for every digit. The idea comes from the talk by Alexandrescu - // "Three Optimization Tips for C++". See speed-test for a comparison. - if (n < 10) return count; - if (n < 100) return count + 1; - if (n < 1000) return count + 2; - if (n < 10000) return count + 3; - n /= 10000u; - count += 4; - } -} -#endif - -#if FMT_USE_INT128 -inline int count_digits(uint128_t n) { - int count = 1; - for (;;) { - // Integer division is slow so do it for a group of four digits instead - // of for every digit. The idea comes from the talk by Alexandrescu - // "Three Optimization Tips for C++". 
See speed-test for a comparison. - if (n < 10) return count; - if (n < 100) return count + 1; - if (n < 1000) return count + 2; - if (n < 10000) return count + 3; - n /= 10000U; - count += 4; - } -} -#endif - -// Counts the number of digits in n. BITS = log2(radix). -template inline int count_digits(UInt n) { - int num_digits = 0; - do { - ++num_digits; - } while ((n >>= BITS) != 0); - return num_digits; -} - -template <> int count_digits<4>(internal::fallback_uintptr n); - -#if FMT_GCC_VERSION || FMT_CLANG_VERSION -# define FMT_ALWAYS_INLINE inline __attribute__((always_inline)) -#else -# define FMT_ALWAYS_INLINE -#endif - -#ifdef FMT_BUILTIN_CLZ -// Optional version of count_digits for better performance on 32-bit platforms. -inline int count_digits(uint32_t n) { - int t = (32 - FMT_BUILTIN_CLZ(n | 1)) * 1233 >> 12; - return t - (n < data::zero_or_powers_of_10_32[t]) + 1; -} -#endif - -template FMT_API std::string grouping_impl(locale_ref loc); -template inline std::string grouping(locale_ref loc) { - return grouping_impl(loc); -} -template <> inline std::string grouping(locale_ref loc) { - return grouping_impl(loc); -} - -template FMT_API Char thousands_sep_impl(locale_ref loc); -template inline Char thousands_sep(locale_ref loc) { - return Char(thousands_sep_impl(loc)); -} -template <> inline wchar_t thousands_sep(locale_ref loc) { - return thousands_sep_impl(loc); -} - -template FMT_API Char decimal_point_impl(locale_ref loc); -template inline Char decimal_point(locale_ref loc) { - return Char(decimal_point_impl(loc)); -} -template <> inline wchar_t decimal_point(locale_ref loc) { - return decimal_point_impl(loc); -} - -// Formats a decimal unsigned integer value writing into buffer. -// add_thousands_sep is called after writing each char to add a thousands -// separator if necessary. 
-template -inline Char* format_decimal(Char* buffer, UInt value, int num_digits, - F add_thousands_sep) { - FMT_ASSERT(num_digits >= 0, "invalid digit count"); - buffer += num_digits; - Char* end = buffer; - while (value >= 100) { - // Integer division is slow so do it for a group of two digits instead - // of for every digit. The idea comes from the talk by Alexandrescu - // "Three Optimization Tips for C++". See speed-test for a comparison. - auto index = static_cast((value % 100) * 2); - value /= 100; - *--buffer = static_cast(data::digits[index + 1]); - add_thousands_sep(buffer); - *--buffer = static_cast(data::digits[index]); - add_thousands_sep(buffer); - } - if (value < 10) { - *--buffer = static_cast('0' + value); - return end; - } - auto index = static_cast(value * 2); - *--buffer = static_cast(data::digits[index + 1]); - add_thousands_sep(buffer); - *--buffer = static_cast(data::digits[index]); - return end; -} - -template constexpr int digits10() noexcept { - return std::numeric_limits::digits10; -} -template <> constexpr int digits10() noexcept { return 38; } -template <> constexpr int digits10() noexcept { return 38; } - -template -inline Iterator format_decimal(Iterator out, UInt value, int num_digits, - F add_thousands_sep) { - FMT_ASSERT(num_digits >= 0, "invalid digit count"); - // Buffer should be large enough to hold all digits (<= digits10 + 1). - enum { max_size = digits10() + 1 }; - Char buffer[2 * max_size]; - auto end = format_decimal(buffer, value, num_digits, add_thousands_sep); - return internal::copy_str(buffer, end, out); -} - -template -inline It format_decimal(It out, UInt value, int num_digits) { - return format_decimal(out, value, num_digits, [](Char*) {}); -} - -template -inline Char* format_uint(Char* buffer, UInt value, int num_digits, - bool upper = false) { - buffer += num_digits; - Char* end = buffer; - do { - const char* digits = upper ? 
"0123456789ABCDEF" : data::hex_digits; - unsigned digit = (value & ((1 << BASE_BITS) - 1)); - *--buffer = static_cast(BASE_BITS < 4 ? static_cast('0' + digit) - : digits[digit]); - } while ((value >>= BASE_BITS) != 0); - return end; -} - -template -Char* format_uint(Char* buffer, internal::fallback_uintptr n, int num_digits, - bool = false) { - auto char_digits = std::numeric_limits::digits / 4; - int start = (num_digits + char_digits - 1) / char_digits - 1; - if (int start_digits = num_digits % char_digits) { - unsigned value = n.value[start--]; - buffer = format_uint(buffer, value, start_digits); - } - for (; start >= 0; --start) { - unsigned value = n.value[start]; - buffer += char_digits; - auto p = buffer; - for (int i = 0; i < char_digits; ++i) { - unsigned digit = (value & ((1 << BASE_BITS) - 1)); - *--p = static_cast(data::hex_digits[digit]); - value >>= BASE_BITS; - } - } - return buffer; -} - -template -inline It format_uint(It out, UInt value, int num_digits, bool upper = false) { - // Buffer should be large enough to hold all digits (digits / BASE_BITS + 1). - char buffer[num_bits() / BASE_BITS + 1]; - format_uint(buffer, value, num_digits, upper); - return internal::copy_str(buffer, buffer + num_digits, out); -} - -template struct null {}; - -// Workaround an array initialization issue in gcc 4.8. -template struct fill_t { - private: - Char data_[6]; - - public: - FMT_CONSTEXPR Char& operator[](size_t index) { return data_[index]; } - FMT_CONSTEXPR const Char& operator[](size_t index) const { - return data_[index]; - } - - static FMT_CONSTEXPR fill_t make() { - auto fill = fill_t(); - fill[0] = Char(' '); - return fill; - } -}; -} // namespace internal - -// We cannot use enum classes as bit fields because of a gcc bug -// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61414. 
-namespace align { -enum type { none, left, right, center, numeric }; -} -using align_t = align::type; - -namespace sign { -enum type { none, minus, plus, space }; -} -using sign_t = sign::type; - -// Format specifiers for built-in and string types. -template struct basic_format_specs { - int width; - int precision; - char type; - align_t align : 4; - sign_t sign : 3; - bool alt : 1; // Alternate form ('#'). - internal::fill_t fill; - char thousands; - - constexpr basic_format_specs() - : width(0), - precision(-1), - type(0), - align(align::none), - sign(sign::none), - alt(false), - fill(internal::fill_t::make()), - thousands('\0'){} -}; - -using format_specs = basic_format_specs; - -namespace internal { - -// A floating-point presentation format. -enum class float_format : unsigned char { - general, // General: exponent notation or fixed point based on magnitude. - exp, // Exponent notation with the default precision of 6, e.g. 1.2e-3. - fixed, // Fixed point with the default precision of 6, e.g. 0.0012. - hex -}; - -struct float_specs { - int precision; - float_format format : 8; - sign_t sign : 8; - bool upper : 1; - bool locale : 1; - bool percent : 1; - bool binary32 : 1; - bool use_grisu : 1; - bool trailing_zeros : 1; -}; - -// Writes the exponent exp in the form "[+-]d{2,3}" to buffer. -template It write_exponent(int exp, It it) { - FMT_ASSERT(-10000 < exp && exp < 10000, "exponent out of range"); - if (exp < 0) { - *it++ = static_cast('-'); - exp = -exp; - } else { - *it++ = static_cast('+'); - } - if (exp >= 100) { - const char* top = data::digits + (exp / 100) * 2; - if (exp >= 1000) *it++ = static_cast(top[0]); - *it++ = static_cast(top[1]); - exp %= 100; - } - const char* d = data::digits + exp * 2; - *it++ = static_cast(d[0]); - *it++ = static_cast(d[1]); - return it; -} - -template class float_writer { - private: - // The number is given as v = digits_ * pow(10, exp_). 
- const char* digits_; - int num_digits_; - int exp_; - size_t size_; - float_specs specs_; - Char decimal_point_; - - template It prettify(It it) const { - // pow(10, full_exp - 1) <= v <= pow(10, full_exp). - int full_exp = num_digits_ + exp_; - if (specs_.format == float_format::exp) { - // Insert a decimal point after the first digit and add an exponent. - *it++ = static_cast(*digits_); - int num_zeros = specs_.precision - num_digits_; - bool trailing_zeros = num_zeros > 0 && specs_.trailing_zeros; - if (num_digits_ > 1 || trailing_zeros) *it++ = decimal_point_; - it = copy_str(digits_ + 1, digits_ + num_digits_, it); - if (trailing_zeros) - it = std::fill_n(it, num_zeros, static_cast('0')); - *it++ = static_cast(specs_.upper ? 'E' : 'e'); - return write_exponent(full_exp - 1, it); - } - if (num_digits_ <= full_exp) { - // 1234e7 -> 12340000000[.0+] - it = copy_str(digits_, digits_ + num_digits_, it); - it = std::fill_n(it, full_exp - num_digits_, static_cast('0')); - if (specs_.trailing_zeros) { - *it++ = decimal_point_; - int num_zeros = specs_.precision - full_exp; - if (num_zeros <= 0) { - if (specs_.format != float_format::fixed) - *it++ = static_cast('0'); - return it; - } -#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION - if (num_zeros > 1000) - throw std::runtime_error("fuzz mode - avoiding excessive cpu use"); -#endif - it = std::fill_n(it, num_zeros, static_cast('0')); - } - } else if (full_exp > 0) { - // 1234e-2 -> 12.34[0+] - it = copy_str(digits_, digits_ + full_exp, it); - if (!specs_.trailing_zeros) { - // Remove trailing zeros. - int num_digits = num_digits_; - while (num_digits > full_exp && digits_[num_digits - 1] == '0') - --num_digits; - if (num_digits != full_exp) *it++ = decimal_point_; - return copy_str(digits_ + full_exp, digits_ + num_digits, it); - } - *it++ = decimal_point_; - it = copy_str(digits_ + full_exp, digits_ + num_digits_, it); - if (specs_.precision > num_digits_) { - // Add trailing zeros. 
- int num_zeros = specs_.precision - num_digits_; - it = std::fill_n(it, num_zeros, static_cast('0')); - } - } else { - // 1234e-6 -> 0.001234 - *it++ = static_cast('0'); - int num_zeros = -full_exp; - if (specs_.precision >= 0 && specs_.precision < num_zeros) - num_zeros = specs_.precision; - int num_digits = num_digits_; - if (!specs_.trailing_zeros) - while (num_digits > 0 && digits_[num_digits - 1] == '0') --num_digits; - if (num_zeros != 0 || num_digits != 0) { - *it++ = decimal_point_; - it = std::fill_n(it, num_zeros, static_cast('0')); - it = copy_str(digits_, digits_ + num_digits, it); - } - } - return it; - } - - public: - float_writer(const char* digits, int num_digits, int exp, float_specs specs, - Char decimal_point) - : digits_(digits), - num_digits_(num_digits), - exp_(exp), - specs_(specs), - decimal_point_(decimal_point) { - int full_exp = num_digits + exp - 1; - int precision = specs.precision > 0 ? specs.precision : 16; - if (specs_.format == float_format::general && - !(full_exp >= -4 && full_exp < precision)) { - specs_.format = float_format::exp; - } - size_ = prettify(counting_iterator()).count(); - size_ += specs.sign ? 1 : 0; - } - - size_t size() const { return size_; } - size_t width() const { return size(); } - - template void operator()(It&& it) { - if (specs_.sign) *it++ = static_cast(data::signs[specs_.sign]); - it = prettify(it); - } -}; - -template -int format_float(T value, int precision, float_specs specs, buffer& buf); - -// Formats a floating-point number with snprintf. 
-template -int snprintf_float(T value, int precision, float_specs specs, - buffer& buf); - -template T promote_float(T value) { return value; } -inline double promote_float(float value) { return value; } - -template -FMT_CONSTEXPR void handle_int_type_spec(const Spec& specs, Handler&& handler) { - if (specs.thousands != '\0') { - handler.on_num(); - return; - } - switch (specs.type) { - case 0: - case 'd': - handler.on_dec(); - break; - case 'x': - case 'X': - handler.on_hex(); - break; - case 'b': - case 'B': - handler.on_bin(); - break; - case 'o': - handler.on_oct(); - break; - case 'n': - case 'l': - case 'L': - handler.on_num(); - break; - default: - handler.on_error("Invalid type specifier \"" + std::string(1, specs.type) + "\" for formatting a value of type int"); - } -} - -template -FMT_CONSTEXPR float_specs parse_float_type_spec( - const basic_format_specs& specs, ErrorHandler&& eh = {}) { - - auto result = float_specs(); - if (specs.thousands != '\0') { - eh.on_error("Thousand separators are not supported for floating point numbers"); - return result; - } - result.trailing_zeros = specs.alt; - switch (specs.type) { - case 0: - result.format = float_format::general; - result.trailing_zeros |= specs.precision != 0; - break; - case 'G': - result.upper = true; - FMT_FALLTHROUGH; - case 'g': - result.format = float_format::general; - break; - case 'E': - result.upper = true; - FMT_FALLTHROUGH; - case 'e': - result.format = float_format::exp; - result.trailing_zeros |= specs.precision != 0; - break; - case 'F': - result.upper = true; - FMT_FALLTHROUGH; - case 'f': - result.format = float_format::fixed; - result.trailing_zeros |= specs.precision != 0; - break; -#if FMT_DEPRECATED_PERCENT - case '%': - result.format = float_format::fixed; - result.percent = true; - break; -#endif - case 'A': - result.upper = true; - FMT_FALLTHROUGH; - case 'a': - result.format = float_format::hex; - break; - case 'n': - case 'l': - case 'L': - result.locale = true; - break; - 
default: - eh.on_error("Invalid type specifier \"" + std::string(1, specs.type) + "\" for formatting a value of type float"); - break; - } - return result; -} - -template -FMT_CONSTEXPR void handle_char_specs(const basic_format_specs* specs, - Handler&& handler) { - if (!specs) return handler.on_char(); - if (specs->type && specs->type != 'c') return handler.on_int(); - if (specs->align == align::numeric || specs->sign != sign::none || specs->alt) - handler.on_error("invalid format specifier for char"); - handler.on_char(); -} - -template -FMT_CONSTEXPR void handle_cstring_type_spec(Char spec, Handler&& handler) { - if (spec == 0 || spec == 's') - handler.on_string(); - else if (spec == 'p') - handler.on_pointer(); - else - handler.on_error("Invalid type specifier \"" + std::string(1, spec) + "\" for formatting a value of type string"); -} - -template -FMT_CONSTEXPR void check_string_type_spec(Char spec, ErrorHandler&& eh) { - if (spec != 0 && spec != 's') eh.on_error("Invalid type specifier \"" + std::string(1, spec) + "\" for formatting a value of type string"); -} - -template -FMT_CONSTEXPR void check_pointer_type_spec(Char spec, ErrorHandler&& eh) { - if (spec != 0 && spec != 'p') eh.on_error("Invalid type specifier \"" + std::string(1, spec) + "\" for formatting a value of type pointer"); -} - -template class int_type_checker : private ErrorHandler { - public: - FMT_CONSTEXPR explicit int_type_checker(ErrorHandler eh) : ErrorHandler(eh) {} - - FMT_CONSTEXPR void on_dec() {} - FMT_CONSTEXPR void on_hex() {} - FMT_CONSTEXPR void on_bin() {} - FMT_CONSTEXPR void on_oct() {} - FMT_CONSTEXPR void on_num() {} - - FMT_CONSTEXPR void on_error(std::string error) { - ErrorHandler::on_error(error); - } -}; - -template -class char_specs_checker : public ErrorHandler { - private: - char type_; - - public: - FMT_CONSTEXPR char_specs_checker(char type, ErrorHandler eh) - : ErrorHandler(eh), type_(type) {} - - FMT_CONSTEXPR void on_int() { - handle_int_type_spec(type_, 
int_type_checker(*this)); - } - FMT_CONSTEXPR void on_char() {} -}; - -template -class cstring_type_checker : public ErrorHandler { - public: - FMT_CONSTEXPR explicit cstring_type_checker(ErrorHandler eh) - : ErrorHandler(eh) {} - - FMT_CONSTEXPR void on_string() {} - FMT_CONSTEXPR void on_pointer() {} -}; - -template -void arg_map::init(const basic_format_args& args) { - if (map_) return; - map_ = new entry[internal::to_unsigned(args.max_size())]; - if (args.is_packed()) { - for (int i = 0;; ++i) { - internal::type arg_type = args.type(i); - if (arg_type == internal::none_type) return; - if (arg_type == internal::named_arg_type) push_back(args.values_[i]); - } - } - for (int i = 0, n = args.max_size(); i < n; ++i) { - auto type = args.args_[i].type_; - if (type == internal::named_arg_type) push_back(args.args_[i].value_); - } -} - -template struct nonfinite_writer { - sign_t sign; - const char* str; - static constexpr size_t str_size = 3; - - size_t size() const { return str_size + (sign ? 1 : 0); } - size_t width() const { return size(); } - - template void operator()(It&& it) const { - if (sign) *it++ = static_cast(data::signs[sign]); - it = copy_str(str, str + str_size, it); - } -}; - -// This template provides operations for formatting and writing data into a -// character range. -template class basic_writer { - public: - using char_type = typename Range::value_type; - using iterator = typename Range::iterator; - using format_specs = basic_format_specs; - - private: - iterator out_; // Output iterator. - locale_ref locale_; - - // Attempts to reserve space for n extra characters in the output range. - // Returns a pointer to the reserved range or a reference to out_. 
- auto reserve(std::size_t n) -> decltype(internal::reserve(out_, n)) { - return internal::reserve(out_, n); - } - - template struct padded_int_writer { - size_t size_; - string_view prefix; - char_type fill; - std::size_t padding; - F f; - - size_t size() const { return size_; } - size_t width() const { return size_; } - - template void operator()(It&& it) const { - if (prefix.size() != 0) - it = copy_str(prefix.begin(), prefix.end(), it); - it = std::fill_n(it, padding, fill); - f(it); - } - }; - - // Writes an integer in the format - // - // where are written by f(it). - template - void write_int(int num_digits, string_view prefix, format_specs specs, F f) { - std::size_t size = prefix.size() + to_unsigned(num_digits); - char_type fill = specs.fill[0]; - std::size_t padding = 0; - if (specs.align == align::numeric) { - auto unsiged_width = to_unsigned(specs.width); - if (unsiged_width > size) { - padding = unsiged_width - size; - size = unsiged_width; - } - } else if (specs.precision > num_digits) { - size = prefix.size() + to_unsigned(specs.precision); - padding = to_unsigned(specs.precision - num_digits); - fill = static_cast('0'); - } - if (specs.align == align::none) specs.align = align::right; - write_padded(specs, padded_int_writer{size, prefix, fill, padding, f}); - } - - // Writes a decimal integer. - template void write_decimal(Int value) { - auto abs_value = static_cast>(value); - bool negative = is_negative(value); - // Don't do -abs_value since it trips unsigned-integer-overflow sanitizer. - if (negative) abs_value = ~abs_value + 1; - int num_digits = count_digits(abs_value); - auto&& it = reserve((negative ? 1 : 0) + static_cast(num_digits)); - if (negative) *it++ = static_cast('-'); - it = format_decimal(it, abs_value, num_digits); - } - - // The handle_int_type_spec handler that writes an integer. 
- template struct int_writer { - using unsigned_type = uint32_or_64_or_128_t; - - basic_writer& writer; - const Specs& specs; - unsigned_type abs_value; - char prefix[4]; - unsigned prefix_size; - - string_view get_prefix() const { return string_view(prefix, prefix_size); } - - int_writer(basic_writer& w, Int value, const Specs& s) - : writer(w), - specs(s), - abs_value(static_cast(value)), - prefix_size(0) { - if (is_negative(value)) { - prefix[0] = '-'; - ++prefix_size; - abs_value = 0 - abs_value; - } else if (specs.sign != sign::none && specs.sign != sign::minus) { - prefix[0] = specs.sign == sign::plus ? '+' : ' '; - ++prefix_size; - } - } - - struct dec_writer { - unsigned_type abs_value; - int num_digits; - - template void operator()(It&& it) const { - it = internal::format_decimal(it, abs_value, num_digits); - } - }; - - void on_dec() { - int num_digits = count_digits(abs_value); - writer.write_int(num_digits, get_prefix(), specs, - dec_writer{abs_value, num_digits}); - } - - struct hex_writer { - int_writer& self; - int num_digits; - - template void operator()(It&& it) const { - it = format_uint<4, char_type>(it, self.abs_value, num_digits, - self.specs.type != 'x'); - } - }; - - void on_hex() { - if (specs.alt) { - prefix[prefix_size++] = '0'; - prefix[prefix_size++] = specs.type; - } - int num_digits = count_digits<4>(abs_value); - writer.write_int(num_digits, get_prefix(), specs, - hex_writer{*this, num_digits}); - } - - template struct bin_writer { - unsigned_type abs_value; - int num_digits; - - template void operator()(It&& it) const { - it = format_uint(it, abs_value, num_digits); - } - }; - - void on_bin() { - if (specs.alt) { - prefix[prefix_size++] = '0'; - prefix[prefix_size++] = static_cast(specs.type); - } - int num_digits = count_digits<1>(abs_value); - writer.write_int(num_digits, get_prefix(), specs, - bin_writer<1>{abs_value, num_digits}); - } - - void on_oct() { - int num_digits = count_digits<3>(abs_value); - if (specs.alt && 
specs.precision <= num_digits && abs_value != 0) { - // Octal prefix '0' is counted as a digit, so only add it if precision - // is not greater than the number of digits. - prefix[prefix_size++] = '0'; - } - writer.write_int(num_digits, get_prefix(), specs, - bin_writer<3>{abs_value, num_digits}); - } - - enum { sep_size = 1 }; - - struct num_writer { - unsigned_type abs_value; - int size; - const std::string& groups; - char_type sep; - - template void operator()(It&& it) const { - basic_string_view s(&sep, sep_size); - // Index of a decimal digit with the least significant digit having - // index 0. - int digit_index = 0; - std::string::const_iterator group = groups.cbegin(); - it = format_decimal( - it, abs_value, size, - [this, s, &group, &digit_index](char_type*& buffer) { - if (*group <= 0 || ++digit_index % *group != 0 || - *group == max_value()) - return; - if (group + 1 != groups.cend()) { - digit_index = 0; - ++group; - } - buffer -= s.size(); - std::uninitialized_copy(s.data(), s.data() + s.size(), - make_checked(buffer, s.size())); - }); - } - }; - - void on_num() { - std::string groups = grouping(writer.locale_); - if (groups.empty()) return on_dec(); - auto sep = specs.thousands; - if (!sep) return on_dec(); - int num_digits = count_digits(abs_value); - int size = num_digits; - std::string::const_iterator group = groups.cbegin(); - while (group != groups.cend() && num_digits > *group && *group > 0 && - *group != max_value()) { - size += sep_size; - num_digits -= *group; - ++group; - } - if (group == groups.cend()) - size += sep_size * ((num_digits - 1) / groups.back()); - writer.write_int(size, get_prefix(), specs, - num_writer{abs_value, size, groups, static_cast(sep)}); - } - - FMT_NORETURN void on_error(std::string error) { - FMT_THROW(duckdb::Exception(error)); - } - }; - - template struct str_writer { - const Char* s; - size_t size_; - - size_t size() const { return size_; } - size_t width() const { - return count_code_points(basic_string_view(s, 
size_)); - } - - template void operator()(It&& it) const { - it = copy_str(s, s + size_, it); - } - }; - - template struct pointer_writer { - UIntPtr value; - int num_digits; - - size_t size() const { return to_unsigned(num_digits) + 2; } - size_t width() const { return size(); } - - template void operator()(It&& it) const { - *it++ = static_cast('0'); - *it++ = static_cast('x'); - it = format_uint<4, char_type>(it, value, num_digits); - } - }; - - public: - explicit basic_writer(Range out, locale_ref loc = locale_ref()) - : out_(out.begin()), locale_(loc) {} - - iterator out() const { return out_; } - - // Writes a value in the format - // - // where is written by f(it). - template void write_padded(const format_specs& specs, F&& f) { - // User-perceived width (in code points). - unsigned width = to_unsigned(specs.width); - size_t size = f.size(); // The number of code units. - size_t num_code_points = width != 0 ? f.width() : size; - if (width <= num_code_points) return f(reserve(size)); - auto&& it = reserve(width + (size - num_code_points)); - char_type fill = specs.fill[0]; - std::size_t padding = width - num_code_points; - if (specs.align == align::right) { - it = std::fill_n(it, padding, fill); - f(it); - } else if (specs.align == align::center) { - std::size_t left_padding = padding / 2; - it = std::fill_n(it, left_padding, fill); - f(it); - it = std::fill_n(it, padding - left_padding, fill); - } else { - f(it); - it = std::fill_n(it, padding, fill); - } - } - - void write(int value) { write_decimal(value); } - void write(long value) { write_decimal(value); } - void write(long long value) { write_decimal(value); } - - void write(unsigned value) { write_decimal(value); } - void write(unsigned long value) { write_decimal(value); } - void write(unsigned long long value) { write_decimal(value); } - -#if FMT_USE_INT128 - void write(int128_t value) { write_decimal(value); } - void write(uint128_t value) { write_decimal(value); } -#endif - - template - void 
write_int(T value, const Spec& spec) { - handle_int_type_spec(spec, int_writer(*this, value, spec)); - } - - template ::value)> - void write(T value, format_specs specs = {}) { - float_specs fspecs = parse_float_type_spec(specs); - fspecs.sign = specs.sign; - if (std::signbit(value)) { // value < 0 is false for NaN so use signbit. - fspecs.sign = sign::minus; - value = -value; - } else if (fspecs.sign == sign::minus) { - fspecs.sign = sign::none; - } - - if (!std::isfinite(value)) { - auto str = std::isinf(value) ? (fspecs.upper ? "INF" : "inf") - : (fspecs.upper ? "NAN" : "nan"); - return write_padded(specs, nonfinite_writer{fspecs.sign, str}); - } - - if (specs.align == align::none) { - specs.align = align::right; - } else if (specs.align == align::numeric) { - if (fspecs.sign) { - auto&& it = reserve(1); - *it++ = static_cast(data::signs[fspecs.sign]); - fspecs.sign = sign::none; - if (specs.width != 0) --specs.width; - } - specs.align = align::right; - } - - memory_buffer buffer; - if (fspecs.format == float_format::hex) { - if (fspecs.sign) buffer.push_back(data::signs[fspecs.sign]); - snprintf_float(promote_float(value), specs.precision, fspecs, buffer); - write_padded(specs, str_writer{buffer.data(), buffer.size()}); - return; - } - int precision = specs.precision >= 0 || !specs.type ? specs.precision : 6; - if (fspecs.format == float_format::exp) ++precision; - if (const_check(std::is_same())) fspecs.binary32 = true; - fspecs.use_grisu = use_grisu(); - if (const_check(FMT_DEPRECATED_PERCENT) && fspecs.percent) value *= 100; - int exp = format_float(promote_float(value), precision, fspecs, buffer); - if (const_check(FMT_DEPRECATED_PERCENT) && fspecs.percent) { - buffer.push_back('%'); - --exp; // Adjust decimal place position. - } - fspecs.precision = precision; - char_type point = fspecs.locale ? 
decimal_point(locale_) - : static_cast('.'); - write_padded(specs, float_writer(buffer.data(), - static_cast(buffer.size()), - exp, fspecs, point)); - } - - void write(char value) { - auto&& it = reserve(1); - *it++ = value; - } - - template ::value)> - void write(Char value) { - auto&& it = reserve(1); - *it++ = value; - } - - void write(string_view value) { - auto&& it = reserve(value.size()); - it = copy_str(value.begin(), value.end(), it); - } - void write(wstring_view value) { - static_assert(std::is_same::value, ""); - auto&& it = reserve(value.size()); - it = std::copy(value.begin(), value.end(), it); - } - - template - void write(const Char* s, std::size_t size, const format_specs& specs) { - write_padded(specs, str_writer{s, size}); - } - - template - void write(basic_string_view s, const format_specs& specs = {}) { - const Char* data = s.data(); - std::size_t size = s.size(); - if (specs.precision >= 0 && to_unsigned(specs.precision) < size) - size = code_point_index(s, to_unsigned(specs.precision)); - write(data, size, specs); - } - - template - void write_pointer(UIntPtr value, const format_specs* specs) { - int num_digits = count_digits<4>(value); - auto pw = pointer_writer{value, num_digits}; - if (!specs) return pw(reserve(to_unsigned(num_digits) + 2)); - format_specs specs_copy = *specs; - if (specs_copy.align == align::none) specs_copy.align = align::right; - write_padded(specs_copy, pw); - } -}; - -using writer = basic_writer>; - -template struct is_integral : std::is_integral {}; -template <> struct is_integral : std::true_type {}; -template <> struct is_integral : std::true_type {}; - -template -class arg_formatter_base { - public: - using char_type = typename Range::value_type; - using iterator = typename Range::iterator; - using format_specs = basic_format_specs; - - private: - using writer_type = basic_writer; - writer_type writer_; - format_specs* specs_; - - struct char_writer { - char_type value; - - size_t size() const { return 1; } - 
size_t width() const { return 1; } - - template void operator()(It&& it) const { *it++ = value; } - }; - - void write_char(char_type value) { - if (specs_) - writer_.write_padded(*specs_, char_writer{value}); - else - writer_.write(value); - } - - void write_pointer(const void* p) { - writer_.write_pointer(internal::to_uintptr(p), specs_); - } - - protected: - writer_type& writer() { return writer_; } - FMT_DEPRECATED format_specs* spec() { return specs_; } - format_specs* specs() { return specs_; } - iterator out() { return writer_.out(); } - - void write(bool value) { - string_view sv(value ? "true" : "false"); - specs_ ? writer_.write(sv, *specs_) : writer_.write(sv); - } - - void write(const char_type* value) { - if (!value) { - FMT_THROW(duckdb::Exception("string pointer is null")); - } else { - auto length = std::char_traits::length(value); - basic_string_view sv(value, length); - specs_ ? writer_.write(sv, *specs_) : writer_.write(sv); - } - } - - public: - arg_formatter_base(Range r, format_specs* s, locale_ref loc) - : writer_(r, loc), specs_(s) {} - - iterator operator()(monostate) { - FMT_ASSERT(false, "invalid argument type"); - return out(); - } - - template ::value)> - iterator operator()(T value) { - if (specs_) - writer_.write_int(value, *specs_); - else - writer_.write(value); - return out(); - } - - iterator operator()(char_type value) { - internal::handle_char_specs( - specs_, char_spec_handler(*this, static_cast(value))); - return out(); - } - - iterator operator()(bool value) { - if (specs_ && specs_->type) return (*this)(value ? 1 : 0); - write(value != 0); - return out(); - } - - template ::value)> - iterator operator()(T value) { - writer_.write(value, specs_ ? 
*specs_ : format_specs()); - return out(); - } - - struct char_spec_handler : ErrorHandler { - arg_formatter_base& formatter; - char_type value; - - char_spec_handler(arg_formatter_base& f, char_type val) - : formatter(f), value(val) {} - - void on_int() { - if (formatter.specs_) - formatter.writer_.write_int(value, *formatter.specs_); - else - formatter.writer_.write(value); - } - void on_char() { formatter.write_char(value); } - }; - - struct cstring_spec_handler : internal::error_handler { - arg_formatter_base& formatter; - const char_type* value; - - cstring_spec_handler(arg_formatter_base& f, const char_type* val) - : formatter(f), value(val) {} - - void on_string() { formatter.write(value); } - void on_pointer() { formatter.write_pointer(value); } - }; - - iterator operator()(const char_type* value) { - if (!specs_) return write(value), out(); - internal::handle_cstring_type_spec(specs_->type, - cstring_spec_handler(*this, value)); - return out(); - } - - iterator operator()(basic_string_view value) { - if (specs_) { - internal::check_string_type_spec(specs_->type, internal::error_handler()); - writer_.write(value, *specs_); - } else { - writer_.write(value); - } - return out(); - } - - iterator operator()(const void* value) { - if (specs_) - check_pointer_type_spec(specs_->type, internal::error_handler()); - write_pointer(value); - return out(); - } -}; - -template FMT_CONSTEXPR bool is_name_start(Char c) { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || '_' == c; -} - -// Parses the range [begin, end) as an unsigned integer. This function assumes -// that the range is non-empty and the first character is a digit. -template -FMT_CONSTEXPR int parse_nonnegative_int(const Char*& begin, const Char* end, - ErrorHandler&& eh) { - FMT_ASSERT(begin != end && '0' <= *begin && *begin <= '9', ""); - if (*begin == '0') { - ++begin; - return 0; - } - unsigned value = 0; - // Convert to unsigned to prevent a warning. 
- constexpr unsigned max_int = max_value(); - unsigned big = max_int / 10; - do { - // Check for overflow. - if (value > big) { - value = max_int + 1; - break; - } - value = value * 10 + unsigned(*begin - '0'); - ++begin; - } while (begin != end && '0' <= *begin && *begin <= '9'); - if (value > max_int) eh.on_error("number is too big"); - return static_cast(value); -} - -template class custom_formatter { - private: - using char_type = typename Context::char_type; - - basic_format_parse_context& parse_ctx_; - Context& ctx_; - - public: - explicit custom_formatter(basic_format_parse_context& parse_ctx, - Context& ctx) - : parse_ctx_(parse_ctx), ctx_(ctx) {} - - bool operator()(typename basic_format_arg::handle h) const { - h.format(parse_ctx_, ctx_); - return true; - } - - template bool operator()(T) const { return false; } -}; - -template -using is_integer = - bool_constant::value && !std::is_same::value && - !std::is_same::value && - !std::is_same::value>; - -template class width_checker { - public: - explicit FMT_CONSTEXPR width_checker(ErrorHandler& eh) : handler_(eh) {} - - template ::value)> - FMT_CONSTEXPR unsigned long long operator()(T value) { - if (is_negative(value)) handler_.on_error("negative width"); - return static_cast(value); - } - - template ::value)> - FMT_CONSTEXPR unsigned long long operator()(T) { - handler_.on_error("width is not integer"); - return 0; - } - - private: - ErrorHandler& handler_; -}; - -template class precision_checker { - public: - explicit FMT_CONSTEXPR precision_checker(ErrorHandler& eh) : handler_(eh) {} - - template ::value)> - FMT_CONSTEXPR unsigned long long operator()(T value) { - if (is_negative(value)) handler_.on_error("negative precision"); - return static_cast(value); - } - - template ::value)> - FMT_CONSTEXPR unsigned long long operator()(T) { - handler_.on_error("precision is not integer"); - return 0; - } - - private: - ErrorHandler& handler_; -}; - -// A format specifier handler that sets fields in 
basic_format_specs. -template class specs_setter { - public: - explicit FMT_CONSTEXPR specs_setter(basic_format_specs& specs) - : specs_(specs) {} - - FMT_CONSTEXPR specs_setter(const specs_setter& other) - : specs_(other.specs_) {} - - FMT_CONSTEXPR void on_align(align_t align) { specs_.align = align; } - FMT_CONSTEXPR void on_fill(Char fill) { specs_.fill[0] = fill; } - FMT_CONSTEXPR void on_plus() { specs_.sign = sign::plus; } - FMT_CONSTEXPR void on_minus() { specs_.sign = sign::minus; } - FMT_CONSTEXPR void on_space() { specs_.sign = sign::space; } - FMT_CONSTEXPR void on_comma() { specs_.thousands = ','; } - FMT_CONSTEXPR void on_underscore() { specs_.thousands = '_'; } - FMT_CONSTEXPR void on_single_quote() { specs_.thousands = '\''; } - FMT_CONSTEXPR void on_thousands(char sep) { specs_.thousands = sep; } - FMT_CONSTEXPR void on_hash() { specs_.alt = true; } - - FMT_CONSTEXPR void on_zero() { - specs_.align = align::numeric; - specs_.fill[0] = Char('0'); - } - - FMT_CONSTEXPR void on_width(int width) { specs_.width = width; } - FMT_CONSTEXPR void on_precision(int precision) { - specs_.precision = precision; - } - FMT_CONSTEXPR void end_precision() {} - - FMT_CONSTEXPR void on_type(Char type) { - specs_.type = static_cast(type); - } - - protected: - basic_format_specs& specs_; -}; - -template class numeric_specs_checker { - public: - FMT_CONSTEXPR numeric_specs_checker(ErrorHandler& eh, internal::type arg_type) - : error_handler_(eh), arg_type_(arg_type) {} - - FMT_CONSTEXPR void require_numeric_argument() { - if (!is_arithmetic_type(arg_type_)) - error_handler_.on_error("format specifier requires numeric argument"); - } - - FMT_CONSTEXPR void check_sign() { - require_numeric_argument(); - if (is_integral_type(arg_type_) && arg_type_ != int_type && - arg_type_ != long_long_type && arg_type_ != internal::char_type) { - error_handler_.on_error("format specifier requires signed argument"); - } - } - - FMT_CONSTEXPR void check_precision() { - if 
(is_integral_type(arg_type_) || arg_type_ == internal::pointer_type) - error_handler_.on_error("precision not allowed for this argument type"); - } - - private: - ErrorHandler& error_handler_; - internal::type arg_type_; -}; - -// A format specifier handler that checks if specifiers are consistent with the -// argument type. -template class specs_checker : public Handler { - public: - FMT_CONSTEXPR specs_checker(const Handler& handler, internal::type arg_type) - : Handler(handler), checker_(*this, arg_type) {} - - FMT_CONSTEXPR specs_checker(const specs_checker& other) - : Handler(other), checker_(*this, other.arg_type_) {} - - FMT_CONSTEXPR void on_align(align_t align) { - if (align == align::numeric) checker_.require_numeric_argument(); - Handler::on_align(align); - } - - FMT_CONSTEXPR void on_plus() { - checker_.check_sign(); - Handler::on_plus(); - } - - FMT_CONSTEXPR void on_minus() { - checker_.check_sign(); - Handler::on_minus(); - } - - FMT_CONSTEXPR void on_space() { - checker_.check_sign(); - Handler::on_space(); - } - - FMT_CONSTEXPR void on_hash() { - checker_.require_numeric_argument(); - Handler::on_hash(); - } - - FMT_CONSTEXPR void on_zero() { - checker_.require_numeric_argument(); - Handler::on_zero(); - } - - FMT_CONSTEXPR void end_precision() { checker_.check_precision(); } - - private: - numeric_specs_checker checker_; -}; - -template