diff --git a/context.cc b/context.cc index a1029b0..89aaa9f 100644 --- a/context.cc +++ b/context.cc @@ -53,21 +53,21 @@ Record Scope::get(string name) { } } -ReturnResult Scope::getReason() { - if (reason) { - return *reason; +Trace Scope::getTrace() { + if (trace) { + return *trace; } else if (parent) { - return parent->getReason(); + return parent->getTrace(); } else { - return ReturnResult{ReturnReason::NORMAL, nullptr}; + return Trace{}; } } -void Scope::setReason(ReturnResult r) { - if (reason) { - *reason = r; +void Scope::setTrace(Trace r) { + if (trace) { + *trace = r; } else if (parent) { - parent->setReason(r); + parent->setTrace(r); } else { std::cerr << "Return in an invalid context"; exit(-1); diff --git a/context.h b/context.h index 3c86025..01c0866 100644 --- a/context.h +++ b/context.h @@ -7,30 +7,31 @@ #include "llvm.h" #include "type.h" +class FunDecl; struct Record { std::string id; Type type; - llvm::Argument* arg; + llvm::AllocaInst* addr; }; class Literal; -enum class ReturnReason { NORMAL, RETURN, BREAK, CONTINUE }; -struct ReturnResult { - ReturnReason reason; - llvm::Value* value; +struct Trace { + llvm::Function *llvmFun; + FunDecl* fun; + llvm::BasicBlock* endB; }; class Scope { Scope* parent; typedef std::map VarRec; std::shared_ptr varRec; - ReturnResult* reason; + Trace* trace; bool localCount(std::string); void setOrCreateVar(std::string, Record r); public: - Scope(Scope* parent = nullptr, ReturnResult* reason = nullptr) - : parent(parent), reason(reason) { + Scope(Scope* parent = nullptr, Trace* trace = nullptr) + : parent(parent), trace(trace) { varRec = std::make_shared(); } @@ -40,11 +41,11 @@ class Scope { void set(std::string, Record r); Record get(std::string); - void setReason(ReturnResult r); - ReturnResult getReason(); + void setTrace(Trace r); + Trace getTrace(); bool isWrapped() const { return parent; } - Scope wrap() { return Scope(this, reason); } - Scope wrapWithReason(ReturnResult* reason) { return Scope(this, reason); }; + Scope wrap() { return Scope(this, trace); } + Scope wrapWithTrace(Trace* t) { return Scope(this, t); }; }; diff --git a/llvm.cc b/llvm.cc index dcd6787..d606432 100644 --- a/llvm.cc +++ b/llvm.cc @@ -19,3 +19,34 @@ llvm::Type* llvmWrapper::getType(Type type) { } return nullptr; } + +llvm::Value* llvmWrapper::isTruthy(llvm::Value* v) { + if (v->getType()->isIntegerTy()) + return builder->CreateICmpNE( + v, llvm::ConstantInt::get(v->getType(), llvm::APInt(32, 0))); + abortMsg("can't use value as boolean"); + return nullptr; +} + +/// CreateEntryBlockAlloca - Create an alloca instruction in the entry block of +/// the function. This is used for mutable variables etc. +llvm::AllocaInst* llvmWrapper::createEntryBlockAlloca(llvm::Function* fun, + llvm::Type* type, + const std::string& name) { + llvm::IRBuilder<> TmpB(&fun->getEntryBlock(), fun->getEntryBlock().begin()); + return TmpB.CreateAlloca(type, 0, name.c_str()); +} + +llvm::Value* llvmWrapper::convertTo(llvm::Value* v, llvm::Type* t) { + if (v->getType() == t) return v; + if (t->isDoubleTy()) { + return builder->CreateSIToFP(v, t, "todouble"); + } else if (t->isIntegerTy()) { + if (!v->getType()->isIntegerTy()) + abortMsg("can't implict convert double into int"); + else + return builder->CreateIntCast(v, t, true, "toint32"); + } else + abortMsg("unimplemented implict convert"); + return nullptr; +} diff --git a/llvm.h b/llvm.h index e7e45cf..d67112c 100644 --- a/llvm.h +++ b/llvm.h @@ -23,4 +23,9 @@ struct llvmWrapper { builder = std::make_shared>(*ctx); }; llvm::Type* getType(Type t); + llvm::Value* isTruthy(llvm::Value*); + llvm::Value* convertTo(llvm::Value*, llvm::Type*); + llvm::AllocaInst* createEntryBlockAlloca(llvm::Function* fun, + llvm::Type* type, + const std::string& name); }; diff --git a/visitor.cc b/visitor.cc index bee7c7e..834e172 100644 --- a/visitor.cc +++ b/visitor.cc @@ -10,8 +10,8 @@ CodeGenVisitor CodeGenVisitor::wrap() { return CodeGenVisitor(scope.wrap(), l); } -CodeGenVisitor CodeGenVisitor::wrapWithReason(ReturnResult* r) { - return CodeGenVisitor(scope.wrapWithReason(r), l); +CodeGenVisitor CodeGenVisitor::wrapWithTrace(Trace* r) { + return CodeGenVisitor(scope.wrapWithTrace(r), l); } void CodeGenExprVisitor::visit(Expr* expr) { expr->accept(this); } @@ -116,7 +116,8 @@ void CodeGenExprVisitor::visit(Unary* expr) { void CodeGenExprVisitor::visit(Postfix* expr) { abortMsg("unimplemented"); } void CodeGenExprVisitor::visit(Variable* expr) { - value = scope.get(expr->name).arg; + auto r = scope.get(expr->name); + value = l.builder->CreateLoad(l.getType(r.type), r.addr, r.id.c_str()); } void CodeGenExprVisitor::visit(Call* expr) { // Look up the name in the global module table. @@ -156,19 +157,25 @@ void CodeGenVisitor::visit(FunDecl* st) { if (scope.isWrapped()) abortMsg("nested function is forbidden"); llvm::Function* F = l.mod->getFunction(st->identifier); if (F) abortMsg("redefine func"); + std::vector args; for (auto [type, token] : st->args) { auto t = l.getType(type); args.push_back(t); } + llvm::FunctionType* FT = llvm::FunctionType::get(l.getType(st->retType), args, false); F = llvm::Function::Create(FT, llvm::Function::ExternalLinkage, st->identifier, l.mod.get()); - ReturnResult r = {}; - auto v = wrapWithReason(&r); + // Create a new basic block to start insertion into. + llvm::BasicBlock* BB = llvm::BasicBlock::Create(*l.ctx, st->identifier, F); + l.builder->SetInsertPoint(BB); + + Trace r = {F, st, nullptr}; + auto v = wrapWithTrace(&r); // Set names for all arguments. size_t i = 0; @@ -176,28 +183,67 @@ void CodeGenVisitor::visit(FunDecl* st) { FormalArg formal = st->args[i++]; std::string name = formal.token.lexeme; a.setName(name); - v.scope.define(name, {name, formal.type, &a}); + auto addr = l.createEntryBlockAlloca(F, l.getType(formal.type), + formal.token.lexeme.c_str()); + l.builder->CreateStore(&a, addr); + v.scope.define(name, {name, formal.type, addr}); } - // Create a new basic block to start insertion into. - llvm::BasicBlock* BB = llvm::BasicBlock::Create(*l.ctx, st->identifier, F); - l.builder->SetInsertPoint(BB); - v.visit(st->body); - if (r.value) { - l.builder->CreateRet(r.value); - verifyFunction(*F); - } + if (llvm::verifyFunction(*F, &llvm::errs())) abortMsg("verify error"); // F->eraseFromParent(); } void CodeGenVisitor::visit(BlockStmt* st) { for (auto d : st->decls) visit(d); } -void CodeGenVisitor::visit(IfStmt* st) { unimplemented(); } +void CodeGenVisitor::visit(IfStmt* st) { + CodeGenExprVisitor v(scope, l); + v.visit(st->condition); + llvm::Value* condV = v.getValue(); + if (!condV) abortMsg("failed to generate condition"); + + // Convert condition to a bool by comparing non-equal to 0.0. + condV = l.isTruthy(condV); + llvm::Function* fun = l.builder->GetInsertBlock()->getParent(); + + // Create blocks for the then and else cases. Insert the 'then' block at the + // end of the function. + llvm::BasicBlock* thenBB = llvm::BasicBlock::Create(*l.ctx, "then", fun); + llvm::BasicBlock* elseBB = llvm::BasicBlock::Create(*l.ctx, "else"); + llvm::BasicBlock* mergeBB = llvm::BasicBlock::Create(*l.ctx, "ifcont"); + + l.builder->CreateCondBr(condV, thenBB, elseBB); + // Emit then value. + l.builder->SetInsertPoint(thenBB); + + auto v1 = wrap(); + v1.visit(st->true_branch); + + l.builder->CreateBr(mergeBB); + // Codegen of 'Then' can change the current block, update ThenBB for the PHI. + thenBB = l.builder->GetInsertBlock(); + + // Emit else block. + fun->getBasicBlockList().push_back(elseBB); + l.builder->SetInsertPoint(elseBB); + + auto v2 = wrap(); + v2.visit(st->false_branch); + + l.builder->CreateBr(mergeBB); + // codegen of 'Else' can change the current block, update ElseBB for the PHI. + elseBB = l.builder->GetInsertBlock(); + + // Emit merge block. + fun->getBasicBlockList().push_back(mergeBB); + l.builder->SetInsertPoint(mergeBB); +} void CodeGenVisitor::visit(WhileStmt* st) { unimplemented(); } void CodeGenVisitor::visit(BreakStmt* st) { unimplemented(); } void CodeGenVisitor::visit(ReturnStmt* st) { CodeGenExprVisitor v(scope, l); v.visit(st->expr); - scope.setReason({ReturnReason::NORMAL, v.getValue()}); + auto val = + l.convertTo(v.getValue(), scope.getTrace().llvmFun->getReturnType()); + l.builder->CreateRet(val); } diff --git a/visitor.h b/visitor.h index 13db6db..ed0ad9b 100644 --- a/visitor.h +++ b/visitor.h @@ -56,13 +56,6 @@ class ExprVisitor { virtual void visit(Call* expr) = 0; }; -/* -struct TypedValue { - Type type; - llvm::Value* value; -}; -*/ - class CodeGenExprVisitor : public ExprVisitor { Scope scope; llvmWrapper l; @@ -70,7 +63,6 @@ class CodeGenExprVisitor : public ExprVisitor { public: CodeGenExprVisitor(Scope scope, llvmWrapper l) : scope(scope), l(l){}; - // CodeGenExprVisitor wrap(); void visit(Expr* expr) override; void visit(Literal* expr) override; void visit(Integer* expr) override; @@ -90,7 +82,7 @@ class CodeGenVisitor : public DeclVisitor { public: CodeGenVisitor(Scope scope, llvmWrapper l) : scope(scope), l(l){}; CodeGenVisitor wrap(); - CodeGenVisitor wrapWithReason(ReturnResult* r); + CodeGenVisitor wrapWithTrace(Trace* r); virtual void visit(Declaration* d) override; virtual void visit(ExprStmt* st) override;