From 1d45539e364cf78bf273e4d17ed5aec69fd55ad2 Mon Sep 17 00:00:00 2001 From: Marc Capell Date: Sun, 8 Sep 2024 00:26:47 +0200 Subject: [PATCH] parser: finish let and return expressions --- ast/ast.go | 26 +++++++++ parser/parser.go | 45 ++++++++++++-- parser/parser_test.go | 133 +++++++++++++++++++++++++++--------------- 3 files changed, 154 insertions(+), 50 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index 2c52557..99149e4 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -257,3 +257,29 @@ func (fl *FunctionLiteral) String() string { return out.String() } + +type CallExpression struct { + Token token.Token // the `(` token + Function Expression + Arguments []Expression +} + +func (ce *CallExpression) expressionNode() {} + +func (ce *CallExpression) TokenLiteral() string { return ce.Token.Literal } + +func (ce *CallExpression) String() string { + var out bytes.Buffer + + args := []string{} + for _, a := range ce.Arguments { + args = append(args, a.String()) + } + + out.WriteString(ce.Function.String()) + out.WriteString("(") + out.WriteString(strings.Join(args, ", ")) + out.WriteString(")") + + return out.String() +} diff --git a/parser/parser.go b/parser/parser.go index 29f579c..d789a4b 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -34,6 +34,7 @@ var precedences = map[token.TokenType]int{ token.MINUS: SUM, token.SLASH: PRODUCT, token.ASTERISK: PRODUCT, + token.LPAREN: CALL, } type Parser struct { @@ -76,6 +77,7 @@ func New(l *lexer.Lexer) *Parser { p.registerInfix(token.NOT_EQ, p.parseInfixExpression) p.registerInfix(token.LT, p.parseInfixExpression) p.registerInfix(token.GT, p.parseInfixExpression) + p.registerInfix(token.LPAREN, p.parseCallExpression) return p } @@ -137,8 +139,11 @@ func (p *Parser) parseLetStatement() *ast.LetStatement { return nil } - // TODO: We're skipping the expressions until we encounter a semicolon - for !p.curTokenIs(token.SEMICOLON) { + p.nextToken() + + stmt.Value = p.parseExpression(LOWEST) + + if p.peekTokenIs(token.SEMICOLON) { p.nextToken() } @@ -150,8 +155,9 @@ func (p *Parser) parseReturnStatement() *ast.ReturnStatement { p.nextToken() - // TODO: We're skipping the expressions until we encounter a semicolon - for !p.curTokenIs(token.SEMICOLON) { + stmt.ReturnValue = p.parseExpression(LOWEST) + + if p.peekTokenIs(token.SEMICOLON) { p.nextToken() } @@ -349,6 +355,37 @@ func (p *Parser) parseFunctionParameters() []*ast.Identifier { return identifiers } +func (p *Parser) parseCallExpression(function ast.Expression) ast.Expression { + exp := &ast.CallExpression{Token: p.curToken, Function: function} + exp.Arguments = p.parseCallArguments() + + return exp +} + +func (p *Parser) parseCallArguments() []ast.Expression { + args := []ast.Expression{} + + p.nextToken() + + if p.peekTokenIs(token.RPAREN) { + return args + } + + args = append(args, p.parseExpression(LOWEST)) + + for p.peekTokenIs(token.COMMA) { + p.nextToken() + p.nextToken() + args = append(args, p.parseExpression(LOWEST)) + } + + if !p.expectPeek(token.RPAREN) { + return nil + } + + return args +} + func (p *Parser) expectPeek(t token.TokenType) bool { if p.peekTokenIs(t) { p.nextToken() diff --git a/parser/parser_test.go b/parser/parser_test.go index 4ff8de3..ad90e3e 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -9,68 +9,62 @@ import ( ) func TestLetStatements(t *testing.T) { - input := ` -let x = 5; -let y = 10; -let foobar = 838383; -` - - l := lexer.New(input) - p := New(l) - - program := p.ParseProgram() - checkParserErrors(t, p) - - if program == nil { - t.Fatalf("ParseProgram() returned nil") - } - - if len(program.Statements) != 3 { - t.Fatalf("program.Statements does not contain 3 statements. got=%d", len(program.Statements)) - } - tests := []struct { + input string expectedIdentifier string + expectedValue interface{} }{ - {"x"}, - {"y"}, - {"foobar"}, + {"let x = 5;", "x", 5}, + {"let y = true;", "y", true}, + {"let foobar = y;", "foobar", "y"}, } - for i, tt := range tests { - stmt := program.Statements[i] + for _, tt := range tests { + l := lexer.New(tt.input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Statements does not contain 1 statements. got=%d", len(program.Statements)) + } + + stmt := program.Statements[0] if !testLetStatement(t, stmt, tt.expectedIdentifier) { return } + + val := stmt.(*ast.LetStatement).Value + if !testLiteralExpression(t, val, tt.expectedValue) { + return + } } } func TestReturnStatements(t *testing.T) { - input := ` -return 5; -return 10; -return 993322; - ` - - l := lexer.New(input) - p := New(l) - - program := p.ParseProgram() - checkParserErrors(t, p) - - if len(program.Statements) != 3 { - t.Fatalf("program.Statements does not contain 3 statements. got=%d", len(program.Statements)) + tests := []struct { + input string + expectedValue interface{} + }{ + {"return 5;", 5}, + {"return true;", true}, + {"return x;", "x"}, } - for _, stmt := range program.Statements { - returnStmt, ok := stmt.(*ast.ReturnStatement) - if !ok { - t.Errorf("stmt not *ast.ReturnStatement. got=%T", stmt) - continue + for _, tt := range tests { + l := lexer.New(tt.input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Statements does not contain 1 statements. got=%d", len(program.Statements)) } - if returnStmt.TokenLiteral() != "return" { - t.Errorf("returnStmt.TokenLiteral not 'return', got %q", returnStmt.TokenLiteral()) + stmt := program.Statements[0] + val := stmt.(*ast.ReturnStatement).ReturnValue + if !testLiteralExpression(t, val, tt.expectedValue) { + return } } } @@ -275,6 +269,18 @@ func TestOperatorPrecedenceParsing(t *testing.T) { "!(true == true)", "(!(true == true))", }, + { + "a + add(b * c) + d", + "((a + add((b * c))) + d)", + }, + { + "add(a, b, 1, 2 * 3, 4 + 5, add(6, 7 * 8))", + "add(a, b, 1, (2 * 3), (4 + 5), add(6, (7 * 8)))", + }, + { + "add(a + b + c * d / f + g)", + "add((((a + b) + ((c * d) / f)) + g))", + }, } for _, tt := range tests { @@ -489,6 +495,41 @@ func TestFunctionParametersParsing(t *testing.T) { } } +func TestCallExpressionParsing(t *testing.T) { + input := `add(1, 2 * 3, 4 + 5);` + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Statements does not contain %d statements. got=%d\n", 1, len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T\n", program.Statements[0]) + } + + exp, ok := stmt.Expression.(*ast.CallExpression) + if !ok { + t.Fatalf("stmt.Expression is not ast.CallExpression. got=%T", stmt.Expression) + } + + if !testIdentifier(t, exp.Function, "add") { + return + } + + if len(exp.Arguments) != 3 { + t.Fatalf("wrong length of arguments. want 3, got=%d", len(exp.Arguments)) + } + + testLiteralExpression(t, exp.Arguments[0], 1) + testInfixExpression(t, exp.Arguments[1], 2, "*", 3) + testInfixExpression(t, exp.Arguments[2], 4, "+", 5) +} + func testLetStatement(t *testing.T, s ast.Statement, name string) bool { if s.TokenLiteral() != "let" { t.Errorf("s.TokenLiteral not 'let'. got=%q", s.TokenLiteral())