Skip to content

Commit

Permalink
fix: array_combine with non-string types
Browse files Browse the repository at this point in the history
  • Loading branch information
aceforeverd committed Jun 12, 2024
1 parent 6ebe49b commit 5f0930c
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 94 deletions.
22 changes: 11 additions & 11 deletions cases/query/udf_query.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -581,17 +581,17 @@ cases:
rows:
- ["1-3,1-4,2-3,2-4"]

# - id: array_combine_2
# desc: array_combine casting array to array<string> first
# mode: request-unsupport
# sql: |
# select
# array_join(array_combine("-", [1, 2], [3, 4]), ",") c0,
# expect:
# columns:
# - c0 string
# rows:
# - ["1-3,1-4,2-3,2-4"]
# Case: array_combine is given integer array literals ([1, 2], [3, 4]) rather than
# string arrays; the expected row is the same "-"-joined pairs, comma-separated by array_join.
- id: array_combine_2
desc: array_combine casting array to array<string> first
mode: request-unsupport
sql: |
select
array_join(array_combine("-", [1, 2], [3, 4]), ",") c0,
expect:
columns:
- c0 string
rows:
- ["1-3,1-4,2-3,2-4"]

# ================================================================
# Map data type
Expand Down
49 changes: 0 additions & 49 deletions hybridse/src/base/cartesian_product.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
namespace hybridse {
namespace base {

// Size in bytes of a CartesianProductViewIterator object, as a 32-bit integer.
int32_t CartesianProductIterSize() {
    constexpr auto kIteratorBytes = sizeof(CartesianProductViewIterator);
    return static_cast<int32_t>(kIteratorBytes);
}

static auto cartesian_product(const std::vector<std::vector<int>>& lists) {
std::vector<std::vector<int>> result;
if (std::find_if(std::begin(lists), std::end(lists), [](auto e) -> bool { return e.size() == 0; }) !=
Expand Down Expand Up @@ -60,52 +58,5 @@ std::vector<std::vector<int>> cartesian_product(absl::Span<int const> vec) {
return cartesian_product(input);
}

// Computes the cartesian product for `vec` and discards it.
// NOTE(review): there is no return statement, so the deduced return type is void and the
// computed products never reach the caller — confirm whether this helper is still needed.
auto cartesian_product_iterator(absl::Span<int const> vec) {
    static_cast<void>(cartesian_product(vec));
}

// Construct a CartesianProductViewIterator in caller-provided storage.
//
// vec:    pointer to `sz` integers, forwarded to cartesian_product() as a span.
// sz:     number of integers readable from `vec`; negative values are treated as 0 so they are
//         not converted into a huge unsigned span length.
// output: raw storage the iterator is placement-new'ed into. It must be able to hold
//         CartesianProductIterSize() bytes (suitable alignment is assumed — TODO confirm at the
//         call site). A null `output` is ignored, matching the null handling of the other
//         CartesianProductIter* accessors.
//
// The constructed object is destroyed with CartesianProductIterDel().
void CartesianProductIterNew(int32_t* vec, int32_t sz, int8_t* output) {
    if (output == nullptr) {
        return;
    }
    const int32_t safe_sz = sz > 0 ? sz : 0;
    auto d = cartesian_product(absl::MakeSpan(vec, safe_sz));
    new (output) CartesianProductViewIterator(d);
}

void CartesianProductIterNext(int8_t* ptr) {
auto* it = reinterpret_cast<CartesianProductViewIterator*>(ptr);
if (it != nullptr) {
it->Next();
}
}

// Reports Valid() of the CartesianProductViewIterator stored at `ptr`; a null pointer yields false.
bool CartesianProductIterValid(int8_t* ptr) {
    const auto* cursor = reinterpret_cast<CartesianProductViewIterator*>(ptr);
    return cursor != nullptr && cursor->Valid();
}

// Returns the size of the iterator's `data` vector (the number of stored product rows),
// or 0 when `ptr` is null.
int32_t CartesianProductCount(int8_t* ptr) {
    const auto* cursor = reinterpret_cast<CartesianProductViewIterator*>(ptr);
    if (cursor == nullptr) {
        return 0;
    }
    return static_cast<int32_t>(cursor->data.size());
}

// Returns GetProduct(idx) of the iterator stored at `ptr`, i.e. element `idx` of the row the
// cursor currently points at; returns 0 when `ptr` is null.
int32_t CartesianProductIterGet(int8_t* ptr, int32_t idx) {
    const auto* cursor = reinterpret_cast<CartesianProductViewIterator*>(ptr);
    return cursor == nullptr ? 0 : cursor->GetProduct(idx);
}

// Runs the CartesianProductViewIterator destructor on the object stored at `output`.
// Only the destructor is invoked here; nothing in this function releases the storage itself.
// A null pointer is ignored.
void CartesianProductIterDel(int8_t* output) {
    if (output == nullptr) {
        return;
    }
    // A non-null input always casts to a non-null pointer, so one guard is enough.
    auto* iter = reinterpret_cast<CartesianProductViewIterator*>(output);
    iter->~CartesianProductViewIterator();
}

} // namespace base
} // namespace hybridse
23 changes: 0 additions & 23 deletions hybridse/src/base/cartesian_product.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,15 @@
#ifndef HYBRIDSE_SRC_BASE_CARTESIAN_PRODUCT_H_
#define HYBRIDSE_SRC_BASE_CARTESIAN_PRODUCT_H_

#include <cstdint>
#include <vector>

#include "absl/types/span.h"

namespace hybridse {
namespace base {

// Materialized cartesian product: each inner vector is one product row.
using CartesianProductViewForIndex = std::vector<std::vector<int>>;

// Forward cursor over the rows of a CartesianProductViewForIndex.
//
// The iterator owns a copy of the rows (`data`) and keeps `it` pointing at the current row.
// NOTE(review): the implicitly generated copy operations copy `it` verbatim, so a copied object's
// cursor still points into the source object's `data`; avoid copying instances.
struct CartesianProductViewIterator {
    // Copies `d` and positions the cursor on the first row (or at the end when `d` is empty).
    // `data` is declared before `it`, so it is fully constructed when `it` is initialized.
    explicit CartesianProductViewIterator(const CartesianProductViewForIndex& d) : data(d), it(data.cbegin()) {}

    // Advances the cursor to the next row and returns the new position.
    // The cursor is actually moved (previously only a temporary was returned, so the iteration
    // never progressed). Calling Next() at the end is a no-op that returns the end position.
    CartesianProductViewForIndex::const_iterator Next() {
        if (it != data.cend()) {
            ++it;
        }
        return it;
    }

    // True while the cursor points at a row.
    bool Valid() const { return it != data.cend(); }

    // Element `i` of the current row; std::vector::at throws std::out_of_range for a bad index.
    // Must only be called while Valid() is true.
    int GetProduct(int i) const { return it->at(i); }

    CartesianProductViewForIndex data;
    CartesianProductViewForIndex::const_iterator it;
};

// Builds the cartesian product described by `vec` and returns it as a list of rows.
// NOTE(review): how the entries of `vec` are interpreted is decided in the .cc file — confirm there.
std::vector<std::vector<int>> cartesian_product(absl::Span<int const> vec);

// C-style entry points operating on a CartesianProductViewIterator held in raw storage (`ptr`).
// Every accessor that takes `ptr` treats a null pointer as "no iterator" (no-op / false / 0).

// sizeof(CartesianProductViewIterator): the storage needed to construct one.
int32_t CartesianProductIterSize();
// Placement-constructs an iterator in `ptr` over cartesian_product of the `sz` integers at `vec`.
void CartesianProductIterNew(int32_t* vec, int32_t sz, int8_t* ptr);
// Size of the iterator's stored `data` vector.
int32_t CartesianProductCount(int8_t* ptr);
// GetProduct(idx) of the iterator: element `idx` of the current row.
int32_t CartesianProductIterGet(int8_t* ptr, int32_t idx);
// Calls Next() on the iterator.
void CartesianProductIterNext(int8_t* ptr);
// Valid() of the iterator.
bool CartesianProductIterValid(int8_t* ptr);
// Runs the iterator's destructor in place.
void CartesianProductIterDel(int8_t* ptr);

} // namespace base
} // namespace hybridse

Expand Down
58 changes: 49 additions & 9 deletions hybridse/src/codegen/array_ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@

#include <string>

#include "absl/strings/substitute.h"
#include "base/fe_status.h"
#include "codegen/cast_expr_ir_builder.h"
#include "codegen/context.h"
#include "codegen/ir_base_builder.h"
#include "codegen/string_ir_builder.h"

namespace hybridse {
namespace codegen {
Expand Down Expand Up @@ -137,7 +140,7 @@ absl::StatusOr<llvm::Value*> ArrayIRBuilder::NumElements(CodeGenContextBase* ctx
return out;
}

absl::StatusOr<llvm::Value*> ArrayIRBuilder::CastFrom(CodeGenContextBase* ctx, llvm::Value* src) {
absl::StatusOr<llvm::Value*> ArrayIRBuilder::CastToArrayString(CodeGenContextBase* ctx, llvm::Value* src) {
auto sb = StructTypeIRBuilder::CreateStructTypeIRBuilder(ctx->GetModule(), src->getType());
CHECK_ABSL_STATUSOR(sb);

Expand All @@ -150,17 +153,21 @@ absl::StatusOr<llvm::Value*> ArrayIRBuilder::CastFrom(CodeGenContextBase* ctx, l
auto fields = src_builder->Load(ctx, src);
CHECK_ABSL_STATUSOR(fields);
llvm::Value* src_raws = fields.value().at(RAW_IDX);
llvm::Value* src_nulls = fields.value().at(NULL_IDX);
llvm::Value* num_elements = fields.value().at(SZ_IDX);


llvm::Value* casted = nullptr;
if (!CreateDefault(ctx->GetCurrentBlock(), &casted)) {
return absl::InternalError("codegen error: fail to construct default array");
}
// initialize each element
CHECK_ABSL_STATUS(Initialize(ctx, casted, {num_elements}));

auto builder = ctx->GetBuilder();
auto* raw_array_ptr = builder->CreateAlloca(element_type_, num_elements);
auto* nullables_ptr = builder->CreateAlloca(builder->getInt1Ty(), num_elements);
auto dst_fields = Load(ctx, casted);
// Validate the destination load itself before dst_fields.value() is used below; the source
// `fields` were already checked right after they were loaded.
CHECK_ABSL_STATUSOR(dst_fields);
auto* raw_array_ptr = dst_fields.value().at(RAW_IDX);
auto* nullables_ptr = dst_fields.value().at(NULL_IDX);

llvm::Type* idx_type = builder->getInt64Ty();
llvm::Value* idx = builder->CreateAlloca(idx_type);
Expand All @@ -176,13 +183,16 @@ absl::StatusOr<llvm::Value*> ArrayIRBuilder::CastFrom(CodeGenContextBase* ctx, l

llvm::Value* src_ele_value =
builder->CreateLoad(src_ele_type, builder->CreateGEP(src_ele_type, src_raws, idx_val));
llvm::Value* dst_ele =
builder->CreateLoad(element_type_, builder->CreateGEP(element_type_, raw_array_ptr, idx_val));

NativeValue out;
CHECK_STATUS(cast_builder.Cast(NativeValue::Create(src_ele_value), element_type_, &out));
codegen::StringIRBuilder str_builder(ctx->GetModule());
auto s = str_builder.CastFrom(ctx->GetCurrentBlock(), src_ele_value, dst_ele);
CHECK_TRUE(s.ok(), common::kCodegenError, s.ToString());

builder->CreateStore(out.GetRaw(), builder->CreateGEP(element_type_, raw_array_ptr, idx_val));
builder->CreateStore(out.GetIsNull(builder),
builder->CreateGEP(builder->getInt1Ty(), nullables_ptr, idx_val));
builder->CreateStore(
builder->CreateLoad(builder->getInt1Ty(), builder->CreateGEP(builder->getInt1Ty(), src_nulls, idx_val)),
builder->CreateGEP(builder->getInt1Ty(), nullables_ptr, idx_val));

builder->CreateStore(builder->CreateAdd(idx_val, builder->getInt64(1)), idx);
return {};
Expand All @@ -192,5 +202,35 @@ absl::StatusOr<llvm::Value*> ArrayIRBuilder::CastFrom(CodeGenContextBase* ctx, l
return casted;
}

// Emit IR that initializes the array<string> value behind `alloca` for a given element count.
//
// The work is delegated to the runtime function `hybridse_alloc_array_string`, declared here as
// void(array struct pointer, int64) and called with `alloca` and the array size.
//
// ctx:    codegen context supplying the IR builder, the module and the current block.
// alloca: must be a pointer to this builder's struct type, and this builder's element type must
//         be a pointer to the string struct type; anything else is rejected.
// args:   exactly one integer value, the array size; it is cast to int64 when it has another
//         integer width.
//
// Returns OkStatus on success, UnimplementedError for anything but array<string>,
// InvalidArgumentError for a wrong argument count or a non-integer size, or the size-cast
// failure converted to absl::Status.
absl::Status ArrayIRBuilder::Initialize(CodeGenContextBase* ctx, ::llvm::Value* alloca,
                                        absl::Span<llvm::Value* const> args) const {
    auto* builder = ctx->GetBuilder();
    StringIRBuilder str_builder(ctx->GetModule());
    auto ele_type = str_builder.GetType();
    if (!alloca->getType()->isPointerTy() || alloca->getType()->getPointerElementType() != struct_type_ ||
        ele_type->getPointerTo() != element_type_) {
        return absl::UnimplementedError(absl::Substitute(
            "not able to Initialize array except array<string>, got type $0", GetLlvmObjectString(alloca->getType())));
    }
    if (args.size() != 1) {
        // require one argument that is array size
        return absl::InvalidArgumentError("initialize array requires one argument which is array size");
    }
    if (!args[0]->getType()->isIntegerTy()) {
        return absl::InvalidArgumentError("array size argument should be integer");
    }
    auto sz = args[0];
    if (sz->getType() != builder->getInt64Ty()) {
        // The runtime function takes an int64 size, so widen/narrow other integer types first.
        CastExprIRBuilder cast_builder(ctx->GetCurrentBlock());
        base::Status s;
        cast_builder.SafeCastNumber(sz, builder->getInt64Ty(), &sz, s);
        CHECK_STATUS_TO_ABSL(s);
    }
    auto fn = ctx->GetModule()->getOrInsertFunction("hybridse_alloc_array_string", builder->getVoidTy(),
                                                    struct_type_->getPointerTo(), builder->getInt64Ty());

    builder->CreateCall(fn, {alloca, sz});
    return absl::OkStatus();
}
} // namespace codegen
} // namespace hybridse
5 changes: 4 additions & 1 deletion hybridse/src/codegen/array_ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class ArrayIRBuilder : public StructTypeIRBuilder {
CHECK_TRUE(false, common::kCodegenError, "casting to array un-implemented");
};

absl::StatusOr<llvm::Value*> CastFrom(CodeGenContextBase* ctx, llvm::Value* src);
absl::StatusOr<llvm::Value*> CastToArrayString(CodeGenContextBase* ctx, llvm::Value* src);

absl::StatusOr<NativeValue> ExtractElement(CodeGenContextBase* ctx, const NativeValue& arr,
const NativeValue& key) const override;
Expand All @@ -51,6 +51,9 @@ class ArrayIRBuilder : public StructTypeIRBuilder {

bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) override;

// Emit IR initializing the array struct pointed to by `alloca`. `args` must hold exactly one
// integer value, the array size. The implementation only accepts array<string> and returns
// UnimplementedError for other array types.
absl::Status Initialize(CodeGenContextBase* ctx, ::llvm::Value* alloca,
absl::Span<llvm::Value* const> args) const override;

private:
void InitStructType() override;

Expand Down
12 changes: 12 additions & 0 deletions hybridse/src/codegen/string_ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,5 +403,17 @@ base::Status StringIRBuilder::ConcatWS(::llvm::BasicBlock* block,
*output = NativeValue::CreateWithFlag(concat_str, ret_null);
return base::Status();
}
// Emit a call to the external function "string.<TypeName(src type)>" with (src, alloca).
// The callee is declared as returning void and taking the types of `src` and `alloca`;
// presumably it writes the string form of `src` into `alloca` — confirm against the callee.
// Returns UnimplementedError when `src` is already a string pointer, OkStatus otherwise.
absl::Status StringIRBuilder::CastFrom(llvm::BasicBlock* block, llvm::Value* src, llvm::Value* alloca) {
    llvm::Type* const src_type = src->getType();
    if (IsStringPtr(src_type)) {
        return absl::UnimplementedError("not necessary to cast string to string");
    }

    ::llvm::IRBuilder<> ir_builder(block);
    const ::std::string callee_name = "string." + TypeName(src_type);
    auto* callee_type = ::llvm::FunctionType::get(ir_builder.getVoidTy(), {src_type, alloca->getType()}, false);

    auto callee = m_->getOrInsertFunction(callee_name, callee_type);
    ir_builder.CreateCall(callee, {src, alloca});
    return absl::OkStatus();
}
} // namespace codegen
} // namespace hybridse
1 change: 1 addition & 0 deletions hybridse/src/codegen/string_ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class StringIRBuilder : public StructTypeIRBuilder {
bool CopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist) override;
base::Status CastFrom(::llvm::BasicBlock* block, const NativeValue& src, NativeValue* output) override;
base::Status CastFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value** output);
// Emits a call to the external function "string.<type name of in>" with (in, alloca).
// Returns UnimplementedError when `in` is already a string pointer.
// NOTE(review): presumably the callee writes the converted string into `alloca` — confirm.
absl::Status CastFrom(llvm::BasicBlock* block, llvm::Value* in, llvm::Value* alloca);

bool NewString(::llvm::BasicBlock* block, ::llvm::Value** output);
bool NewString(::llvm::BasicBlock* block, const std::string& str,
Expand Down
6 changes: 5 additions & 1 deletion hybridse/src/codegen/struct_ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ absl::StatusOr<NativeValue> Combine(CodeGenContextBase* ctx, const NativeValue d
return absl::InternalError("codegen error: arguments to array_combine is not ARRAY");
}
if (!tp->GetGenericType(0)->IsString()) {
auto s = arr_builder.CastFrom(ctx, args.at(i).GetRaw());
auto s = arr_builder.CastToArrayString(ctx, args.at(i).GetRaw());
CHECK_ABSL_STATUSOR(s);
casted_args.at(i) = NativeValue::Create(s.value());
} else {
Expand Down Expand Up @@ -339,5 +339,9 @@ absl::StatusOr<NativeValue> Combine(CodeGenContextBase* ctx, const NativeValue d
return NativeValue::Create(out);
}

// Default Initialize: not supported for a generic struct type. Always returns UnimplementedError
// naming this builder's LLVM struct type; `ctx`, `alloca` and `args` are not used.
absl::Status StructTypeIRBuilder::Initialize(CodeGenContextBase* ctx, ::llvm::Value* alloca,
                                             absl::Span<llvm::Value* const> args) const {
    const auto type_description = GetLlvmObjectString(struct_type_);
    return absl::UnimplementedError(absl::StrCat("Initialize for type ", type_description));
}
} // namespace codegen
} // namespace hybridse
3 changes: 3 additions & 0 deletions hybridse/src/codegen/struct_ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class StructTypeIRBuilder : public TypeIRBuilder {
virtual absl::StatusOr<::llvm::Value*> ConstructFromRaw(CodeGenContextBase* ctx,
absl::Span<::llvm::Value* const> args) const;

// Initialize the struct value pointed to by `alloca` using `args`.
// The base implementation returns UnimplementedError; subclasses override it for the types
// they support (ArrayIRBuilder does so for array<string>, taking the array size as argument).
virtual absl::Status Initialize(CodeGenContextBase* ctx, ::llvm::Value* alloca,
absl::Span<llvm::Value* const> args) const;

// Extract element value from composite data type
// 1. extract from array type by index
// 2. extract from struct type by field name
Expand Down
2 changes: 2 additions & 0 deletions hybridse/src/vm/jit_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ void InitBuiltinJitSymbols(HybridSeJitWrapper* jit) {

// cartesian product
jit->AddExternalFunction("hybridse_array_combine", reinterpret_cast<void*>(&hybridse::udf::v1::array_combine));
jit->AddExternalFunction("hybridse_alloc_array_string",
reinterpret_cast<void*>(&hybridse::udf::v1::AllocManagedArray<codec::StringRef>));
}

} // namespace vm
Expand Down

0 comments on commit 5f0930c

Please sign in to comment.