Skip to content

Commit

Permalink
bugfix, c-style funct decl, type helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
rapiz1 committed Nov 3, 2021
1 parent d75b9e5 commit fd8fb7b
Show file tree
Hide file tree
Showing 16 changed files with 166 additions and 96 deletions.
27 changes: 16 additions & 11 deletions llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,24 @@

#include "log.h"

llvm::Type* llvmWrapper::getBool() { return llvm::Type::getInt1Ty(*ctx); }
llvm::Type* llvmWrapper::getInt() { return llvm::Type::getInt32Ty(*ctx); }
llvm::Type* llvmWrapper::getChar() { return llvm::Type::getInt8Ty(*ctx); }
llvm::Type* llvmWrapper::getDouble() { return llvm::Type::getDoubleTy(*ctx); }

llvm::Type* llvmWrapper::getType(Type type) {
switch (type.base) {
case Type::Base::INT:
return llvm::Type::getInt32Ty(*ctx);
return getInt();
break;
case Type::Base::DOUBLE:
return llvm::Type::getDoubleTy(*ctx);
return getDouble();
break;
case Type::Base::BOOL:
return llvm::Type::getInt1Ty(*ctx);
return getBool();
break;
case Type::Base::CHAR:
return llvm::Type::getInt8Ty(*ctx);
return getChar();
break;
default:
abortMsg("Unrecognize type");
Expand All @@ -24,14 +29,12 @@ llvm::Type* llvmWrapper::getType(Type type) {

llvm::Value* llvmWrapper::convertToTruthy(llvm::Value* v) {
auto t = v->getType();
int w = t->getIntegerBitWidth();
if (t == getBool()) return v;
if (t->isIntegerTy()) {
if (t == llvm::Type::getInt1Ty(*ctx))
return v;
else
return builder->CreateICmpNE(
v, llvm::ConstantInt::get(llvm::IntegerType::get(*ctx, w),
llvm::APInt(w, 0, true)));
int w = t->getIntegerBitWidth();
return builder->CreateICmpNE(
v, llvm::ConstantInt::get(llvm::IntegerType::get(*ctx, w),
llvm::APInt(w, 0, true)));
}
abortMsg("can't use value as boolean");
return nullptr;
Expand All @@ -50,6 +53,8 @@ llvm::Value* llvmWrapper::implictConvert(llvm::Value* v, llvm::Type* t) {
if (v->getType() == t) return v;
if (t->isDoubleTy()) {
return builder->CreateSIToFP(v, t, "todouble");
} else if (t == getBool()) {
return convertToTruthy(v);
} else if (t->isIntegerTy()) {
if (!v->getType()->isIntegerTy())
abortMsg("can't implict convert double into int");
Expand Down
4 changes: 4 additions & 0 deletions llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ struct llvmWrapper {
mod = std::make_shared<llvm::Module>("mod", *ctx);
builder = std::make_shared<llvm::IRBuilder<>>(*ctx);
};
llvm::Type* getBool();
llvm::Type* getInt();
llvm::Type* getChar();
llvm::Type* getDouble();
llvm::Type* getType(Type t);
llvm::Value* convertToTruthy(llvm::Value*);
llvm::Value* implictConvert(llvm::Value*, llvm::Type*);
Expand Down
5 changes: 1 addition & 4 deletions log.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,4 @@ inline void abortMsg(std::string s) {
exit(-1);
};

inline void unimplemented() {
std::cerr << "unimplemented\n";
exit(-1);
}
inline void unimplemented() { abortMsg("unimplemented"); }
44 changes: 17 additions & 27 deletions parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,29 @@ std::vector<Declaration*> Parser::program() {

Declaration* Parser::decl() {
Declaration* d = nullptr;
Type type;
Token id(INVALID, "invalid", 0);
switch (peek().tokenType) {
case VAR:
case INT:
case DOUBLE:
case FLOAT:
case CHAR:
case BOOL:
d = varDecl();
type = parseType();
if (match(1, LEFT_SQUARE)) {
// FIXME:
std::cerr << "array declaration not implemented\n";
exit(-1);
consume(RIGHT_SQUARE, "Expect `]`");
}

id = consume(IDENTIFIER, "Expect an identifier");
if (match(1, LEFT_PAREN))
d = funDecl(type, id);
else
d = varDecl(type, id);
break;

case FUNCTION:
d = funDecl();
break;

default:
d = stmt();
break;
Expand Down Expand Up @@ -265,19 +274,7 @@ Type Parser::parseType() {
return {base};
}

VarDecl* Parser::varDecl() {
Type type = parseType();

if (match(1, LEFT_SQUARE)) {
// FIXME:
std::cerr << "array declaration not implemented\n";
exit(-1);

consume(RIGHT_SQUARE, "Expect `]`");
}

Token id = consume(IDENTIFIER, "Expect an identifier");

VarDecl* Parser::varDecl(Type type, Token id) {
Expr* init = nullptr;
if (peek().tokenType == EQUAL) {
advance();
Expand Down Expand Up @@ -315,9 +312,7 @@ RealArgs Parser::real_args() {
return args;
}

FunDecl* Parser::funDecl() {
consume(FUNCTION, "Expect a `function` declaration");
Token id = consume(IDENTIFIER, "Expect an identifier for function name");
FunDecl* Parser::funDecl(Type retType, Token id) {
consume(LEFT_PAREN, "Expect `(` as argument list begins");

Args a;
Expand All @@ -326,11 +321,6 @@ FunDecl* Parser::funDecl() {
}

consume(RIGHT_PAREN, "Expect `)` as argument list ends");
Type retType = {};
if (match(1, RIGHT_ARROW)) {
advance();
retType = parseType();
}

BlockStmt* b = blockStmt();

Expand Down
11 changes: 6 additions & 5 deletions parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ class Parser {
WhileStmt* whileStmt(); // WHILE '(' EXPRESSION ')' STMT
ReturnStmt* returnStmt(); // RETURN EXPR;
Type parseType(); // (VAR | INT | DOUBLE | CHAR)
VarDecl* varDecl(); // TYPE ('['SIZE']')? IDENTIFIER
// (EQUAL EXPRESSION)? ;
FunDecl* funDecl(); // FUN IDENTIFIER '(' ARGS? ')' (RIGHT_ARROW TYPE)? BLOCK
Args args(); // TYPE ID (, TYPE ID)*
RealArgs real_args(); // EXPR (, EXPR)*
VarDecl* varDecl(Type type, Token id); // TYPE ('['SIZE']')? IDENTIFIER
// (EQUAL EXPRESSION)? ;
FunDecl* funDecl(Type type,
Token id); // TYPE IDENTIFIER '(' ARGS? ')' BLOCK
Args args(); // TYPE ID (, TYPE ID)*
RealArgs real_args(); // EXPR (, EXPR)*

Expr* expression(); // ASSIGN
Expr* assignment(); // LVAL '=' EQUALITY | EQUALITY
Expand Down
10 changes: 6 additions & 4 deletions tests/assert
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
assert 1 == 1;
assert true == true;
assert 1 != 0;
assert true != false;
int main() {
assert 1 == 1;
assert true == true;
assert 1 != 0;
assert true != false;
}
13 changes: 8 additions & 5 deletions tests/gcd
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
function gcd(int a, int b) -> int {
int gcd(int a, int b) {
if (b == 0) return a;
while (a >= b) a = a - b;
return gcd(b, a);
}

assert gcd(99, 121) == 11;
assert gcd(1, 1) == 1;
assert gcd(123, 8) == 1;
assert gcd(21, 30) == 3;
int main() {
assert gcd(99, 121) == 11;
assert gcd(1, 1) == 1;
assert gcd(123, 8) == 1;
assert gcd(21, 30) == 3;
}

18 changes: 10 additions & 8 deletions tests/prefix
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
int a = 3;
assert --a == 2;
assert ++a == 3;
bool b = true;
assert !b == false;
assert !!b == true;
assert !!!b == false;
assert !!!!b == true;
int main() {
int a = 3;
assert --a == 2;
assert ++a == 3;
bool b = true;
assert !b == false;
assert !!b == true;
assert !!!b == false;
assert !!!!b == true;
}
13 changes: 8 additions & 5 deletions tests/sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
int sum = 0;
for (int i = 1; i <= 10; i++)
sum = sum + i;
print sum;
assert sum == 55;
int main() {
int sum = 0;
for (int i = 1; i <= 10; i++) {
sum = sum + i;
}
print sum;
assert sum == 55;
}
6 changes: 4 additions & 2 deletions tests/types/bool
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
bool b = true;
print b;
int main() {
bool b = true;
print b;
}
6 changes: 4 additions & 2 deletions tests/types/double
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
double a = 1;
print a;
int main() {
double a = 1;
print a;
}
6 changes: 4 additions & 2 deletions tests/types/float
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
float a = 1;
print a;
int main() {
float a = 1;
print a;
}
6 changes: 4 additions & 2 deletions tests/types/frac
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
double a = 1.2;
print a;
int main() {
double a = 1.2;
print a;
}
6 changes: 4 additions & 2 deletions tests/types/int
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
int a = 1;
print a;
int main() {
int a = 1;
print a;
}
84 changes: 67 additions & 17 deletions visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ void CodeGenExprVisitor::visit(Integer* expr) {
void CodeGenExprVisitor::visit(Double* expr) {
setValue(llvm::ConstantFP::get(*l.ctx, llvm::APFloat(expr->value)));
}
void CodeGenExprVisitor::visit(Boolean* expr) {
setValue(llvm::ConstantInt::get(*l.ctx, llvm::APInt(1, expr->value)));
}

void CodeGenExprVisitor::visit(Binary* expr) {
CodeGenExprVisitor lv(scope, l);
Expand Down Expand Up @@ -52,12 +55,16 @@ void CodeGenExprVisitor::visit(Binary* expr) {
rhs = l.builder->CreateFPCast(rhs, llvm::Type::getDoubleTy(*l.ctx),
"casttmp");
} else if (hasInteger) {
if (!lhs->getType()->isIntegerTy())
lhs = l.builder->CreateIntCast(lhs, llvm::Type::getInt32Ty(*l.ctx), true,
"casttmp");
if (!rhs->getType()->isIntegerTy())
rhs = l.builder->CreateIntCast(rhs, llvm::Type::getInt32Ty(*l.ctx), true,
"casttmp");
unsigned int maxw = 0;
if (lhs->getType()->isIntegerTy())
maxw = std::max(maxw, lhs->getType()->getIntegerBitWidth());
if (rhs->getType()->isIntegerTy())
maxw = std::max(maxw, rhs->getType()->getIntegerBitWidth());
auto upgradeType = llvm::IntegerType::get(*l.ctx, maxw);
if (lhs->getType() != upgradeType)
lhs = l.builder->CreateIntCast(lhs, upgradeType, true, "casttmp");
if (rhs->getType() != upgradeType)
rhs = l.builder->CreateIntCast(rhs, upgradeType, true, "casttmp");
}
switch (op) {
case PLUS:
Expand Down Expand Up @@ -154,22 +161,58 @@ void CodeGenExprVisitor::visit(Binary* expr) {
}

void CodeGenExprVisitor::visit(Unary* expr) {
unimplemented();
/*
visit(expr->child);
auto v = getValue();
auto op = expr->op.tokenType;
switch (op) {
case BANG:
break;
if (op == BANG) {
value = l.convertToTruthy(value);
value = l.builder->CreateNot(value);
} else if (op == MINUS) {
value = l.builder->CreateNeg(value);
} else {
if (!addr)
abortMsg("cannot apply operator " + expr->op.lexeme + " to lvalue");
if (!value->getType()->isIntegerTy())
abortMsg("cant apply " + expr->op.lexeme + "to non integer");
int width = value->getType()->getIntegerBitWidth();
auto con = llvm::Constant::getIntegerValue(value->getType(),
llvm::APInt(width, 1, true));
switch (op) {
case MINUSMINUS:
value = l.builder->CreateSub(value, con);
break;
case PLUSPLUS:
value = l.builder->CreateAdd(value, con);
break;
default:
break;
}
l.builder->CreateStore(addr, value);
}
}

void CodeGenExprVisitor::visit(Postfix* expr) {
visit(expr->child);
auto val = getValue();
auto addr = getAddr();
if (!val->getType()->isIntegerTy())
abortMsg("cant apply " + expr->op.lexeme + "to non integer");
int width = val->getType()->getIntegerBitWidth();
auto con = llvm::Constant::getIntegerValue(val->getType(),
llvm::APInt(width, 1, true));
llvm::Value* ret = nullptr;
switch (expr->op.tokenType) {
case PLUSPLUS:
ret = l.builder->CreateAdd(val, con);
break;
case MINUSMINUS:
ret = l.builder->CreateSub(val, con);
break;
default:
abortMsg("unimplemented postfix operator " + expr->op.lexeme);
break;
}
*/
l.builder->CreateStore(ret, addr);
}

void CodeGenExprVisitor::visit(Postfix* expr) { abortMsg("unimplemented"); }
void CodeGenExprVisitor::visit(Variable* expr) {
auto r = scope.get(expr->name);
addr = r.addr;
Expand Down Expand Up @@ -206,9 +249,16 @@ void CodeGenVisitor::visit(ExprStmt* st) {
CodeGenExprVisitor v(scope, l);
v.visit(st->expr);
}
void CodeGenVisitor::visit(AssertStmt* st) { /*unimplemented();*/
void CodeGenVisitor::visit(AssertStmt* st) {
CodeGenExprVisitor v(scope, l);
v.visit(st->expr);
// FIXME:
}
void CodeGenVisitor::visit(PrintStmt* st) {
CodeGenExprVisitor v(scope, l);
v.visit(st->expr);
// FIXME:
}
void CodeGenVisitor::visit(PrintStmt* st) { unimplemented(); }
void CodeGenVisitor::visit(VarDecl* st) {
auto type = l.getType(st->type);
auto addr =
Expand Down
Loading

0 comments on commit fd8fb7b

Please sign in to comment.