From 6e6c20a6ed8c7e70e6a342065aca39441b2ed75a Mon Sep 17 00:00:00 2001 From: Yujia Qiao Date: Wed, 3 Nov 2021 10:09:46 +0800 Subject: [PATCH] impl if --- llvm.cc | 21 ++++++++++----- llvm.h | 4 +-- visitor.cc | 75 +++++++++++++++++++++++++++++++++++------------------- visitor.h | 10 ++++++++ 4 files changed, 76 insertions(+), 34 deletions(-) diff --git a/llvm.cc b/llvm.cc index d606432..6baacc9 100644 --- a/llvm.cc +++ b/llvm.cc @@ -11,6 +11,8 @@ llvm::Type* llvmWrapper::getType(Type type) { return llvm::Type::getDoubleTy(*ctx); break; case Type::Base::BOOL: + return llvm::Type::getInt1Ty(*ctx); + break; case Type::Base::CHAR: return llvm::Type::getInt8Ty(*ctx); break; @@ -20,10 +22,17 @@ llvm::Type* llvmWrapper::getType(Type type) { return nullptr; } -llvm::Value* llvmWrapper::isTruthy(llvm::Value* v) { - if (v->getType()->isIntegerTy()) - return builder->CreateICmpNE( - v, llvm::ConstantInt::get(v->getType(), llvm::APInt(32, 0))); +llvm::Value* llvmWrapper::convertToTruthy(llvm::Value* v) { + auto t = v->getType(); + int w = t->getIntegerBitWidth(); + if (t->isIntegerTy()) { + if (t == llvm::Type::getInt1Ty(*ctx)) + return v; + else + return builder->CreateICmpNE( + v, llvm::ConstantInt::get(llvm::IntegerType::get(*ctx, w), + llvm::APInt(w, 0, true))); + } abortMsg("can't use value as boolean"); return nullptr; } @@ -37,7 +46,7 @@ llvm::AllocaInst* llvmWrapper::createEntryBlockAlloca(llvm::Function* fun, return TmpB.CreateAlloca(type, 0, name.c_str()); } -llvm::Value* llvmWrapper::convertTo(llvm::Value* v, llvm::Type* t) { +llvm::Value* llvmWrapper::implictConvert(llvm::Value* v, llvm::Type* t) { if (v->getType() == t) return v; if (t->isDoubleTy()) { return builder->CreateSIToFP(v, t, "todouble"); @@ -45,7 +54,7 @@ llvm::Value* llvmWrapper::convertTo(llvm::Value* v, llvm::Type* t) { if (!v->getType()->isIntegerTy()) abortMsg("can't implict convert double into int"); else - return builder->CreateIntCast(v, t, true, "toint32"); + return builder->CreateIntCast(v, t, true, "toint"); } else abortMsg("unimplemented implict convert"); return nullptr; diff --git a/llvm.h b/llvm.h index d67112c..2f20586 100644 --- a/llvm.h +++ b/llvm.h @@ -23,8 +23,8 @@ struct llvmWrapper { builder = std::make_shared>(*ctx); }; llvm::Type* getType(Type t); - llvm::Value* isTruthy(llvm::Value*); - llvm::Value* convertTo(llvm::Value*, llvm::Type*); + llvm::Value* convertToTruthy(llvm::Value*); + llvm::Value* implictConvert(llvm::Value*, llvm::Type*); llvm::AllocaInst* createEntryBlockAlloca(llvm::Function* fun, llvm::Type* type, const std::string& name); diff --git a/visitor.cc b/visitor.cc index 834e172..4f6b2c2 100644 --- a/visitor.cc +++ b/visitor.cc @@ -19,24 +19,31 @@ void CodeGenExprVisitor::visit(Expr* expr) { expr->accept(this); } void CodeGenExprVisitor::visit(Literal* expr) { expr->accept(this); } void CodeGenExprVisitor::visit(Integer* expr) { - value = llvm::ConstantInt::get(*l.ctx, llvm::APInt(32, expr->value)); + setValue(llvm::ConstantInt::get(*l.ctx, llvm::APInt(32, expr->value))); } void CodeGenExprVisitor::visit(Double* expr) { - value = llvm::ConstantFP::get(*l.ctx, llvm::APFloat(expr->value)); + setValue(llvm::ConstantFP::get(*l.ctx, llvm::APFloat(expr->value))); } void CodeGenExprVisitor::visit(Binary* expr) { - visit(expr->left); - auto lhs = getValue(); + CodeGenExprVisitor lv(scope, l); + lv.visit(expr->left); + auto lhs = lv.getValue(); - visit(expr->right); - auto rhs = getValue(); + CodeGenExprVisitor rv(scope, l); + rv.visit(expr->right); + auto rhs = rv.getValue(); auto op = expr->op.tokenType; - bool hasDouble = lhs->getType()->isDoubleTy() || rhs->getType()->isDoubleTy(); - bool hasInteger = - lhs->getType()->isIntegerTy() || rhs->getType()->isIntegerTy(); + + bool hasDouble = false; + bool hasInteger = false; + + if (lhs) { + hasDouble = lhs->getType()->isDoubleTy() || rhs->getType()->isDoubleTy(); + hasInteger = lhs->getType()->isIntegerTy() || rhs->getType()->isIntegerTy(); + } if (hasDouble) { if (!lhs->getType()->isDoubleTy()) lhs = l.builder->CreateFPCast(lhs, llvm::Type::getDoubleTy(*l.ctx), @@ -55,44 +62,52 @@ void CodeGenExprVisitor::visit(Binary* expr) { switch (op) { case PLUS: if (hasDouble) - value = l.builder->CreateFAdd(lhs, rhs); + setTuple(l.builder->CreateFAdd(lhs, rhs)); else if (hasInteger) - value = l.builder->CreateAdd(lhs, rhs); + setTuple(l.builder->CreateAdd(lhs, rhs)); else abortMsg("type mismatched"); + setAddr(nullptr); break; case MINUS: if (hasDouble) - value = l.builder->CreateFSub(lhs, rhs); + setTuple(l.builder->CreateFSub(lhs, rhs)); else if (hasInteger) - value = l.builder->CreateSub(lhs, rhs); + setTuple(l.builder->CreateSub(lhs, rhs)); else abortMsg("type mismatched"); break; case STAR: if (hasDouble) - value = l.builder->CreateFMul(lhs, rhs); + setTuple(l.builder->CreateFMul(lhs, rhs)); else if (hasInteger) - value = l.builder->CreateMul(lhs, rhs); + setTuple(l.builder->CreateMul(lhs, rhs)); else abortMsg("type mismatched"); break; case SLASH: if (hasDouble) - value = l.builder->CreateFDiv(lhs, rhs); + setTuple(l.builder->CreateFDiv(lhs, rhs)); else if (hasInteger) - value = l.builder->CreateSDiv(lhs, rhs); + setTuple(l.builder->CreateSDiv(lhs, rhs)); else abortMsg("type mismatched"); break; case LESS: if (hasDouble) - value = l.builder->CreateFCmpOLT(lhs, rhs); + setTuple(l.builder->CreateFCmpOLT(lhs, rhs)); else if (hasInteger) - value = l.builder->CreateICmpSLT(lhs, rhs); + setTuple(l.builder->CreateICmpSLT(lhs, rhs)); else abortMsg("type mismatched"); break; + case EQUAL: + if (lv.getAddr()) { + l.builder->CreateStore(rhs, lv.getAddr()); + setTuple(rhs); + } else + abortMsg("cannot assign value to rvalue"); + break; default: abortMsg("unexpected binary operator " + expr->op.lexeme); } @@ -117,6 +132,7 @@ void CodeGenExprVisitor::visit(Unary* expr) { void CodeGenExprVisitor::visit(Postfix* expr) { abortMsg("unimplemented"); } void CodeGenExprVisitor::visit(Variable* expr) { auto r = scope.get(expr->name); + addr = r.addr; value = l.builder->CreateLoad(l.getType(r.type), r.addr, r.id.c_str()); } void CodeGenExprVisitor::visit(Call* expr) { @@ -152,7 +168,12 @@ void CodeGenVisitor::visit(ExprStmt* st) { } void CodeGenVisitor::visit(AssertStmt* st) { unimplemented(); } void CodeGenVisitor::visit(PrintStmt* st) { unimplemented(); } -void CodeGenVisitor::visit(VarDecl* st) { unimplemented(); } +void CodeGenVisitor::visit(VarDecl* st) { + auto type = l.getType(st->type); + auto addr = + l.createEntryBlockAlloca(scope.getTrace().llvmFun, type, st->identifier); + scope.define(st->identifier, {st->identifier, st->type, addr}); +} void CodeGenVisitor::visit(FunDecl* st) { if (scope.isWrapped()) abortMsg("nested function is forbidden"); llvm::Function* F = l.mod->getFunction(st->identifier); @@ -194,7 +215,8 @@ void CodeGenVisitor::visit(FunDecl* st) { // F->eraseFromParent(); } void CodeGenVisitor::visit(BlockStmt* st) { - for (auto d : st->decls) visit(d); + auto v = wrap(); + for (auto d : st->decls) v.visit(d); } void CodeGenVisitor::visit(IfStmt* st) { CodeGenExprVisitor v(scope, l); @@ -202,8 +224,7 @@ void CodeGenVisitor::visit(IfStmt* st) { llvm::Value* condV = v.getValue(); if (!condV) abortMsg("failed to generate condition"); - // Convert condition to a bool by comparing non-equal to 0.0. - condV = l.isTruthy(condV); + condV = l.convertToTruthy(condV); llvm::Function* fun = l.builder->GetInsertBlock()->getParent(); // Create blocks for the then and else cases. Insert the 'then' block at the @@ -227,8 +248,10 @@ void CodeGenVisitor::visit(IfStmt* st) { fun->getBasicBlockList().push_back(elseBB); l.builder->SetInsertPoint(elseBB); - auto v2 = wrap(); - v2.visit(st->false_branch); + if (st->false_branch) { + auto v2 = wrap(); + v2.visit(st->false_branch); + } l.builder->CreateBr(mergeBB); // codegen of 'Else' can change the current block, update ElseBB for the PHI. @@ -244,6 +267,6 @@ void CodeGenVisitor::visit(ReturnStmt* st) { CodeGenExprVisitor v(scope, l); v.visit(st->expr); auto val = - l.convertTo(v.getValue(), scope.getTrace().llvmFun->getReturnType()); + l.implictConvert(v.getValue(), scope.getTrace().llvmFun->getReturnType()); l.builder->CreateRet(val); } diff --git a/visitor.h b/visitor.h index ed0ad9b..4052655 100644 --- a/visitor.h +++ b/visitor.h @@ -60,6 +60,7 @@ class CodeGenExprVisitor : public ExprVisitor { Scope scope; llvmWrapper l; llvm::Value* value = nullptr; + llvm::AllocaInst* addr = nullptr; public: CodeGenExprVisitor(Scope scope, llvmWrapper l) : scope(scope), l(l){}; @@ -72,7 +73,16 @@ class CodeGenExprVisitor : public ExprVisitor { void visit(Postfix* expr) override; void visit(Variable* expr) override; void visit(Call* expr) override; + + void setValue(llvm::Value* v) { value = v; } + void setAddr(llvm::AllocaInst* a) { addr = a; } + void setTuple(llvm::Value* v, llvm::AllocaInst* a = nullptr) { + value = v; + addr = a; + } + llvm::Value* getValue() { return value; } + llvm::AllocaInst* getAddr() { return addr; } }; class CodeGenVisitor : public DeclVisitor {