Skip to content

Commit

Permalink
impl if
Browse files Browse the repository at this point in the history
  • Loading branch information
rapiz1 committed Nov 3, 2021
1 parent 751a47a commit 6e6c20a
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 34 deletions.
21 changes: 15 additions & 6 deletions llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ llvm::Type* llvmWrapper::getType(Type type) {
return llvm::Type::getDoubleTy(*ctx);
break;
case Type::Base::BOOL:
return llvm::Type::getInt1Ty(*ctx);
break;
case Type::Base::CHAR:
return llvm::Type::getInt8Ty(*ctx);
break;
Expand All @@ -20,10 +22,17 @@ llvm::Type* llvmWrapper::getType(Type type) {
return nullptr;
}

llvm::Value* llvmWrapper::isTruthy(llvm::Value* v) {
if (v->getType()->isIntegerTy())
return builder->CreateICmpNE(
v, llvm::ConstantInt::get(v->getType(), llvm::APInt(32, 0)));
llvm::Value* llvmWrapper::convertToTruthy(llvm::Value* v) {
auto t = v->getType();
int w = t->getIntegerBitWidth();
if (t->isIntegerTy()) {
if (t == llvm::Type::getInt1Ty(*ctx))
return v;
else
return builder->CreateICmpNE(
v, llvm::ConstantInt::get(llvm::IntegerType::get(*ctx, w),
llvm::APInt(w, 0, true)));
}
abortMsg("can't use value as boolean");
return nullptr;
}
Expand All @@ -37,15 +46,15 @@ llvm::AllocaInst* llvmWrapper::createEntryBlockAlloca(llvm::Function* fun,
return TmpB.CreateAlloca(type, 0, name.c_str());
}

llvm::Value* llvmWrapper::convertTo(llvm::Value* v, llvm::Type* t) {
llvm::Value* llvmWrapper::implictConvert(llvm::Value* v, llvm::Type* t) {
if (v->getType() == t) return v;
if (t->isDoubleTy()) {
return builder->CreateSIToFP(v, t, "todouble");
} else if (t->isIntegerTy()) {
if (!v->getType()->isIntegerTy())
abortMsg("can't implict convert double into int");
else
return builder->CreateIntCast(v, t, true, "toint32");
return builder->CreateIntCast(v, t, true, "toint");
} else
abortMsg("unimplemented implict convert");
return nullptr;
Expand Down
4 changes: 2 additions & 2 deletions llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ struct llvmWrapper {
builder = std::make_shared<llvm::IRBuilder<>>(*ctx);
};
llvm::Type* getType(Type t);
llvm::Value* isTruthy(llvm::Value*);
llvm::Value* convertTo(llvm::Value*, llvm::Type*);
llvm::Value* convertToTruthy(llvm::Value*);
llvm::Value* implictConvert(llvm::Value*, llvm::Type*);
llvm::AllocaInst* createEntryBlockAlloca(llvm::Function* fun,
llvm::Type* type,
const std::string& name);
Expand Down
75 changes: 49 additions & 26 deletions visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,31 @@ void CodeGenExprVisitor::visit(Expr* expr) { expr->accept(this); }
void CodeGenExprVisitor::visit(Literal* expr) { expr->accept(this); }

void CodeGenExprVisitor::visit(Integer* expr) {
value = llvm::ConstantInt::get(*l.ctx, llvm::APInt(32, expr->value));
setValue(llvm::ConstantInt::get(*l.ctx, llvm::APInt(32, expr->value)));
}

void CodeGenExprVisitor::visit(Double* expr) {
value = llvm::ConstantFP::get(*l.ctx, llvm::APFloat(expr->value));
setValue(llvm::ConstantFP::get(*l.ctx, llvm::APFloat(expr->value)));
}

void CodeGenExprVisitor::visit(Binary* expr) {
visit(expr->left);
auto lhs = getValue();
CodeGenExprVisitor lv(scope, l);
lv.visit(expr->left);
auto lhs = lv.getValue();

visit(expr->right);
auto rhs = getValue();
CodeGenExprVisitor rv(scope, l);
rv.visit(expr->right);
auto rhs = rv.getValue();

auto op = expr->op.tokenType;
bool hasDouble = lhs->getType()->isDoubleTy() || rhs->getType()->isDoubleTy();
bool hasInteger =
lhs->getType()->isIntegerTy() || rhs->getType()->isIntegerTy();

bool hasDouble = false;
bool hasInteger = false;

if (lhs) {
hasDouble = lhs->getType()->isDoubleTy() || rhs->getType()->isDoubleTy();
hasInteger = lhs->getType()->isIntegerTy() || rhs->getType()->isIntegerTy();
}
if (hasDouble) {
if (!lhs->getType()->isDoubleTy())
lhs = l.builder->CreateFPCast(lhs, llvm::Type::getDoubleTy(*l.ctx),
Expand All @@ -55,44 +62,52 @@ void CodeGenExprVisitor::visit(Binary* expr) {
switch (op) {
case PLUS:
if (hasDouble)
value = l.builder->CreateFAdd(lhs, rhs);
setTuple(l.builder->CreateFAdd(lhs, rhs));
else if (hasInteger)
value = l.builder->CreateAdd(lhs, rhs);
setTuple(l.builder->CreateAdd(lhs, rhs));
else
abortMsg("type mismatched");
setAddr(nullptr);
break;
case MINUS:
if (hasDouble)
value = l.builder->CreateFSub(lhs, rhs);
setTuple(l.builder->CreateFSub(lhs, rhs));
else if (hasInteger)
value = l.builder->CreateSub(lhs, rhs);
setTuple(l.builder->CreateSub(lhs, rhs));
else
abortMsg("type mismatched");
break;
case STAR:
if (hasDouble)
value = l.builder->CreateFMul(lhs, rhs);
setTuple(l.builder->CreateFMul(lhs, rhs));
else if (hasInteger)
value = l.builder->CreateMul(lhs, rhs);
setTuple(l.builder->CreateMul(lhs, rhs));
else
abortMsg("type mismatched");
break;
case SLASH:
if (hasDouble)
value = l.builder->CreateFDiv(lhs, rhs);
setTuple(l.builder->CreateFDiv(lhs, rhs));
else if (hasInteger)
value = l.builder->CreateSDiv(lhs, rhs);
setTuple(l.builder->CreateSDiv(lhs, rhs));
else
abortMsg("type mismatched");
break;
case LESS:
if (hasDouble)
value = l.builder->CreateFCmpOLT(lhs, rhs);
setTuple(l.builder->CreateFCmpOLT(lhs, rhs));
else if (hasInteger)
value = l.builder->CreateICmpSLT(lhs, rhs);
setTuple(l.builder->CreateICmpSLT(lhs, rhs));
else
abortMsg("type mismatched");
break;
case EQUAL:
if (lv.getAddr()) {
l.builder->CreateStore(rhs, lv.getAddr());
setTuple(rhs);
} else
abortMsg("cannot assign value to rvalue");
break;
default:
abortMsg("unexpected binary operator " + expr->op.lexeme);
}
Expand All @@ -117,6 +132,7 @@ void CodeGenExprVisitor::visit(Unary* expr) {
void CodeGenExprVisitor::visit(Postfix* expr) { abortMsg("unimplemented"); }
void CodeGenExprVisitor::visit(Variable* expr) {
auto r = scope.get(expr->name);
addr = r.addr;
value = l.builder->CreateLoad(l.getType(r.type), r.addr, r.id.c_str());
}
void CodeGenExprVisitor::visit(Call* expr) {
Expand Down Expand Up @@ -152,7 +168,12 @@ void CodeGenVisitor::visit(ExprStmt* st) {
}
void CodeGenVisitor::visit(AssertStmt* st) { unimplemented(); }
void CodeGenVisitor::visit(PrintStmt* st) { unimplemented(); }
void CodeGenVisitor::visit(VarDecl* st) { unimplemented(); }
void CodeGenVisitor::visit(VarDecl* st) {
auto type = l.getType(st->type);
auto addr =
l.createEntryBlockAlloca(scope.getTrace().llvmFun, type, st->identifier);
scope.define(st->identifier, {st->identifier, st->type, addr});
}
void CodeGenVisitor::visit(FunDecl* st) {
if (scope.isWrapped()) abortMsg("nested function is forbidden");
llvm::Function* F = l.mod->getFunction(st->identifier);
Expand Down Expand Up @@ -194,16 +215,16 @@ void CodeGenVisitor::visit(FunDecl* st) {
// F->eraseFromParent();
}
void CodeGenVisitor::visit(BlockStmt* st) {
for (auto d : st->decls) visit(d);
auto v = wrap();
for (auto d : st->decls) v.visit(d);
}
void CodeGenVisitor::visit(IfStmt* st) {
CodeGenExprVisitor v(scope, l);
v.visit(st->condition);
llvm::Value* condV = v.getValue();
if (!condV) abortMsg("failed to generate condition");

// Convert condition to a bool by comparing non-equal to 0.0.
condV = l.isTruthy(condV);
condV = l.convertToTruthy(condV);
llvm::Function* fun = l.builder->GetInsertBlock()->getParent();

// Create blocks for the then and else cases. Insert the 'then' block at the
Expand All @@ -227,8 +248,10 @@ void CodeGenVisitor::visit(IfStmt* st) {
fun->getBasicBlockList().push_back(elseBB);
l.builder->SetInsertPoint(elseBB);

auto v2 = wrap();
v2.visit(st->false_branch);
if (st->false_branch) {
auto v2 = wrap();
v2.visit(st->false_branch);
}

l.builder->CreateBr(mergeBB);
// codegen of 'Else' can change the current block, update ElseBB for the PHI.
Expand All @@ -244,6 +267,6 @@ void CodeGenVisitor::visit(ReturnStmt* st) {
CodeGenExprVisitor v(scope, l);
v.visit(st->expr);
auto val =
l.convertTo(v.getValue(), scope.getTrace().llvmFun->getReturnType());
l.implictConvert(v.getValue(), scope.getTrace().llvmFun->getReturnType());
l.builder->CreateRet(val);
}
10 changes: 10 additions & 0 deletions visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class CodeGenExprVisitor : public ExprVisitor {
Scope scope;
llvmWrapper l;
llvm::Value* value = nullptr;
llvm::AllocaInst* addr = nullptr;

public:
CodeGenExprVisitor(Scope scope, llvmWrapper l) : scope(scope), l(l){};
Expand All @@ -72,7 +73,16 @@ class CodeGenExprVisitor : public ExprVisitor {
void visit(Postfix* expr) override;
void visit(Variable* expr) override;
void visit(Call* expr) override;

void setValue(llvm::Value* v) { value = v; }
void setAddr(llvm::AllocaInst* a) { addr = a; }
void setTuple(llvm::Value* v, llvm::AllocaInst* a = nullptr) {
value = v;
addr = a;
}

llvm::Value* getValue() { return value; }
llvm::AllocaInst* getAddr() { return addr; }
};

class CodeGenVisitor : public DeclVisitor {
Expand Down

0 comments on commit 6e6c20a

Please sign in to comment.