forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
parser.cpp
822 lines (772 loc) · 25.3 KB
/
parser.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parse_string_literal.h>
#include <torch/csrc/jit/frontend/tree.h>
#include <torch/csrc/jit/frontend/tree_views.h>
#include <optional>
namespace torch::jit {
// Merge a function signature with the types taken from its
// `# type: (...) -> ...` comment.
//
// `decl` supplies the parameter names and source ranges.
// `type_annotation_decl` (as built by parseTypeComment) supplies one type per
// parameter plus the return type. When `is_method` is true, the first
// parameter of `decl` (`self`) is not annotated by the comment and is copied
// through unchanged.
//
// Throws ErrorReport when the number of annotated types does not match the
// number of parameters the comment is expected to cover.
Decl mergeTypesFromTypeComment(
    const Decl& decl,
    const Decl& type_annotation_decl,
    bool is_method) {
  auto params = decl.params();
  auto annotated_params = type_annotation_decl.params();
  // Number of leading parameters (`self`) that the type comment skips. It is
  // clamped to the actual parameter count: a method declared with no
  // parameters would otherwise make the unsigned subtraction below wrap
  // around and report a nonsensical expected count.
  const size_t num_skipped = (is_method && params.size() > 0) ? 1 : 0;
  const size_t expected_num_annotations = params.size() - num_skipped;
  if (expected_num_annotations != annotated_params.size()) {
    throw ErrorReport(decl.range())
        << "Number of type annotations (" << annotated_params.size()
        << ") did not match the number of "
        << (is_method ? "method" : "function") << " parameters ("
        << expected_num_annotations << ")";
  }
  // Merge signature idents and ranges with annotation types.
  std::vector<Param> new_params;
  new_params.reserve(params.size());
  if (num_skipped == 1) {
    new_params.push_back(params[0]);
  }
  for (size_t i = num_skipped; i < params.size(); ++i) {
    new_params.emplace_back(
        params[i].withType(annotated_params[i - num_skipped].type()));
  }
  return Decl::create(
      decl.range(),
      List<Param>::create(decl.range(), new_params),
      type_annotation_decl.return_type());
}
// Recursive-descent parser over the token stream produced by `L` (a Lexer).
// Statements are handled by dedicated parseXxx methods. Expressions use
// top-down operator-precedence parsing, driven by the unary/binary precedence
// queries on `shared`.
struct ParserImpl {
  explicit ParserImpl(const std::shared_ptr<Source>& source)
      : L(source), shared(sharedParserData()) {}

  // Consume one TK_IDENT token and wrap it in an Ident tree.
  Ident parseIdent() {
    auto t = L.expect(TK_IDENT);
    // whenever we parse something that has a TreeView type we always
    // use its create method so that the accessors and the constructor
    // of the Compound tree are in the same place.
    return Ident::create(t.range, t.text());
  }

  // Parse a parenthesized argument list and build a call (Apply) of `expr`.
  // Positional arguments become inputs; `name=value` pairs become attributes.
  TreeRef createApply(const Expr& expr) {
    TreeList attributes;
    auto range = L.cur().range;
    TreeList inputs;
    parseArguments(inputs, attributes);
    return Apply::create(
        range,
        expr,
        List<Expr>(makeList(range, std::move(inputs))),
        List<Attribute>(makeList(range, std::move(attributes))));
  }

  // Tokens that may follow a trailing ',' in an unparenthesized tuple. When
  // one of these is seen right after a ',', the tuple ends instead of another
  // element being parsed (e.g. `a, = x`).
  static bool followsTuple(int kind) {
    switch (kind) {
      case TK_PLUS_EQ:
      case TK_MINUS_EQ:
      case TK_TIMES_EQ:
      case TK_DIV_EQ:
      case TK_MOD_EQ:
      case TK_BIT_OR_EQ:
      case TK_BIT_AND_EQ:
      case TK_BIT_XOR_EQ:
      case TK_LSHIFT_EQ:
      case TK_RSHIFT_EQ:
      case TK_POW_EQ:
      case TK_NEWLINE:
      case '=':
      case ')':
        return true;
      default:
        return false;
    }
  }

  // exp | expr, | expr, expr, ...
  // A single expression is returned as-is. Anything containing a ',' becomes a
  // TupleLiteral.
  Expr parseExpOrExpTuple() {
    auto prefix = parseExp();
    if (L.cur().kind == ',') {
      std::vector<Expr> exprs = {prefix};
      while (L.nextIf(',')) {
        if (followsTuple(L.cur().kind))
          break;
        exprs.push_back(parseExp());
      }
      auto list = List<Expr>::create(prefix.range(), exprs);
      prefix = TupleLiteral::create(list.range(), list);
    }
    return prefix;
  }

  // things like a 1.0 or a(4) that are not unary/binary expressions
  // and have higher precedence than all of them
  TreeRef parseBaseExp() {
    TreeRef prefix;
    switch (L.cur().kind) {
      case TK_NUMBER: {
        prefix = parseConst();
      } break;
      case TK_TRUE:
      case TK_FALSE:
      case TK_NONE:
      case TK_NONE_TYPE: {
        auto k = L.cur().kind;
        auto r = L.cur().range;
        prefix = create_compound(k, r, {});
        L.next();
      } break;
      case '(': {
        L.next();
        if (L.nextIf(')')) {
          /// here we have the empty tuple case
          std::vector<Expr> vecExpr;
          List<Expr> listExpr = List<Expr>::create(L.cur().range, vecExpr);
          prefix = TupleLiteral::create(L.cur().range, listExpr);
          break;
        }
        prefix = parseExpOrExpTuple();
        L.expect(')');
      } break;
      case '[': {
        auto list = parseList('[', ',', ']', &ParserImpl::parseExp);
        if (list.size() == 1 && (*list.begin()).kind() == TK_LIST_COMP) {
          prefix = *list.begin();
          // Both `[x for x in y]` and a bare `x for x in y` yield a
          // TK_LIST_COMP tree. Remember this bracketed one so the '{' case
          // can tell the dict literal `{k: [x for x in y]}` apart from the
          // dict comprehension `{k: x for x in y}`.
          last_bracketed_list_comp_ = prefix;
        } else {
          for (auto se : list) {
            if (se.kind() == TK_LIST_COMP) {
              throw ErrorReport(list.range())
                  << " expected a single list comprehension within '[' , ']'";
            }
          }
          prefix = ListLiteral::create(list.range(), List<Expr>(list));
        }
      } break;
      case '{': {
        L.next();
        // If we have a dict literal, `keys` and `values` will store the keys
        // and values used in the object's construction. EDGE CASE: We have a
        // dict comprehension, so we'll get the first element of the dict
        // comprehension in `keys` and a list comprehension in `values`.
        // For example, `{i : chr(i + 65) for i in range(4)}` would give us
        // `i` in `keys` and `chr(i + 65) for i in range(4)` in `values`.
        // The optimal way of handling this case is to simply splice the new
        // dict comprehension together from the existing list comprehension.
        // Splicing prevents breaking changes to our API and does not require
        // the use of global variables.
        std::vector<Expr> keys;
        std::vector<Expr> values;
        auto range = L.cur().range;
        if (L.cur().kind != '}') {
          do {
            keys.push_back(parseExp());
            L.expect(':');
            values.push_back(parseExp());
          } while (L.nextIf(','));
        }
        L.expect('}');
        // A single value that is exactly a bracketed list comprehension, as in
        // `{k: [x for x in y]}`, is an ordinary dict literal, not a
        // comprehension.
        const bool is_dict_comp = keys.size() == 1 &&
            (*values.begin()).kind() == TK_LIST_COMP &&
            TreeRef(*values.begin()).get() != last_bracketed_list_comp_.get();
        if (is_dict_comp) {
          ListComp lc(*values.begin());
          prefix = DictComp::create(
              range, *keys.begin(), lc.elt(), lc.target(), lc.iter());
        } else {
          prefix = DictLiteral::create(
              range,
              List<Expr>::create(range, keys),
              List<Expr>::create(range, values));
        }
      } break;
      case TK_STRINGLITERAL: {
        prefix = parseConcatenatedStringLiterals();
      } break;
      case TK_ELLIPSIS:
      case TK_DOTS: {
        prefix = Dots::create(L.cur().range);
        L.next();
      } break;
      default: {
        Ident name = parseIdent();
        prefix = Var::create(name.range(), name);
      } break;
    }
    // Postfix operators: attribute access, calls and subscripts, applied
    // left-to-right for as long as they appear.
    while (true) {
      if (L.nextIf('.')) {
        const auto name = parseIdent();
        prefix = Select::create(name.range(), Expr(prefix), Ident(name));
      } else if (L.cur().kind == '(') {
        prefix = createApply(Expr(prefix));
      } else if (L.cur().kind == '[') {
        prefix = parseSubscript(prefix);
      } else {
        break;
      }
    }
    return prefix;
  }

  // If the current token is an assignment operator, consume it and return a
  // tree whose kind identifies the operation. Plain '=' returns '='. The
  // augmented forms return the underlying binary operator (e.g. `+=` yields
  // '+', `**=` yields TK_POW). Otherwise nothing is consumed and std::nullopt
  // is returned.
  std::optional<TreeRef> maybeParseAssignmentOp() {
    auto r = L.cur().range;
    switch (L.cur().kind) {
      case TK_PLUS_EQ:
      case TK_MINUS_EQ:
      case TK_TIMES_EQ:
      case TK_DIV_EQ:
      case TK_BIT_OR_EQ:
      case TK_BIT_AND_EQ:
      case TK_BIT_XOR_EQ:
      case TK_MOD_EQ: {
        // The first character of the token text is the operator itself.
        int modifier = L.next().text()[0];
        return create_compound(modifier, r, {});
      }
      case TK_LSHIFT_EQ: {
        L.next();
        return create_compound(TK_LSHIFT, r, {});
      }
      case TK_RSHIFT_EQ: {
        L.next();
        return create_compound(TK_RSHIFT, r, {});
      }
      case TK_POW_EQ: {
        L.next();
        return create_compound(TK_POW, r, {});
      }
      case '=': {
        L.next();
        return create_compound('=', r, {}); // no reduction
      }
      default:
        return std::nullopt;
    }
  }

  // Parse the remainder of `true_branch if cond else false_branch`. The
  // leading `true_branch` and the `if` token have already been consumed.
  TreeRef parseTrinary(
      TreeRef true_branch,
      const SourceRange& range,
      int binary_prec) {
    auto cond = parseExp();
    L.expect(TK_ELSE);
    auto false_branch = parseExp(binary_prec);
    return create_compound(
        TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch});
  }

  // parse the longest expression whose binary operators have
  // precedence strictly greater than 'precedence'
  // precedence == 0 will parse _all_ expressions
  // this is the core loop of 'top-down precedence parsing'
  Expr parseExp() {
    return parseExp(0);
  }
  Expr parseExp(int precedence) {
    TreeRef prefix;
    int unary_prec = 0;
    if (shared.isUnary(L.cur().kind, &unary_prec)) {
      auto kind = L.cur().kind;
      auto pos = L.cur().range;
      L.next();
      auto unary_kind = kind == '*' ? TK_STARRED
          : kind == '-'             ? TK_UNARY_MINUS
                                    : kind;
      auto subexp = parseExp(unary_prec);
      // fold '-' into constant numbers, so that attributes can accept
      // things like -1. A constant that is already negative (from `-(-1)` or
      // `--1`) has its sign removed instead of gaining a second '-', since
      // text like "--1" is not a valid numeric literal.
      if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
        const std::string text = Const(subexp).text();
        const bool already_negative = !text.empty() && text[0] == '-';
        prefix = Const::create(
            subexp.range(), already_negative ? text.substr(1) : "-" + text);
      } else {
        prefix = create_compound(unary_kind, pos, {subexp});
      }
    } else {
      prefix = parseBaseExp();
    }
    int binary_prec = 0;
    while (shared.isBinary(L.cur().kind, &binary_prec)) {
      if (binary_prec <= precedence) // not allowed to parse something which is
        // not greater than 'precedence'
        break;

      int kind = L.cur().kind;
      auto pos = L.cur().range;
      L.next();
      if (shared.isRightAssociative(kind))
        binary_prec--;

      if (kind == TK_NOTIN) {
        // NB: `not in` is just `not( in )`, so we don't introduce new tree view
        // but just make it a nested call in our tree view structure
        prefix = create_compound(TK_IN, pos, {prefix, parseExp(binary_prec)});
        prefix = create_compound(TK_NOT, pos, {prefix});
        continue;
      }

      // special case for trinary operator
      if (kind == TK_IF) {
        prefix = parseTrinary(prefix, pos, binary_prec);
        continue;
      }

      if (kind == TK_FOR) {
        // TK_FOR targets should only parse exprs prec greater than 4, which
        // only includes subset of Exprs that suppose to be on the LHS according
        // to the python grammar
        // https://docs.python.org/3/reference/grammar.html
        auto target = parseLHSExp();
        L.expect(TK_IN);
        auto iter = parseExp();
        prefix = ListComp::create(pos, Expr(prefix), target, iter);
        continue;
      }

      prefix = create_compound(kind, pos, {prefix, parseExp(binary_prec)});
    }
    return Expr(prefix);
  }

  // Parse `begin elem (sep elem)* [sep] end`, calling `parse` once per
  // element. Either delimiter may be TK_NOTHING, in which case it is not
  // consumed.
  void parseSequence(
      int begin,
      int sep,
      int end,
      const std::function<void()>& parse) {
    if (begin != TK_NOTHING) {
      L.expect(begin);
    }
    while (end != L.cur().kind) {
      parse();
      if (!L.nextIf(sep)) {
        if (end != TK_NOTHING) {
          L.expect(end);
        }
        return;
      }
    }
    L.expect(end);
  }

  // parseSequence that collects each element produced by `parse` into a List.
  template <typename T>
  List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
    auto r = L.cur().range;
    std::vector<T> elements;
    parseSequence(
        begin, sep, end, [&] { elements.emplace_back((this->*parse)()); });
    return List<T>::create(r, elements);
  }

  // Consume one TK_NUMBER token as a Const holding its source text.
  Const parseConst() {
    auto t = L.expect(TK_NUMBER);
    return Const::create(t.range, t.text());
  }

  // Adjacent string literals (`"a" "b"`) are concatenated into one literal.
  StringLiteral parseConcatenatedStringLiterals() {
    auto range = L.cur().range;
    std::string ss;
    while (L.cur().kind == TK_STRINGLITERAL) {
      auto literal_range = L.cur().range;
      ss.append(parseStringLiteral(literal_range, L.next().text()));
    }
    return StringLiteral::create(range, ss);
  }

  Expr parseAttributeValue() {
    return parseExp();
  }

  // Parse `(a, b, name=value, ...)`. An identifier directly followed by '='
  // is a keyword argument (appended to `attributes`). Everything else is
  // positional (appended to `inputs`).
  void parseArguments(TreeList& inputs, TreeList& attributes) {
    parseSequence('(', ',', ')', [&] {
      if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') {
        auto ident = parseIdent();
        L.expect('=');
        auto v = parseAttributeValue();
        attributes.push_back(Attribute::create(ident.range(), Ident(ident), v));
      } else {
        inputs.push_back(parseExp());
      }
    });
  }

  // parse LHS acceptable exprs, which only includes subset of Exprs that prec
  // is greater than 4 according to the python grammar
  Expr parseLHSExp() {
    return parseExp(4);
  }

  // Parse expr's of the form [a:], [:b], [a:b], [:] and all variations with
  // "::"
  Expr parseSubscriptExp() {
    TreeRef first, second, third;
    auto range = L.cur().range;
    if (L.cur().kind != ':') {
      first = parseExp();
    }
    if (L.nextIf(':')) {
      if (L.cur().kind != ',' && L.cur().kind != ']' && L.cur().kind != ':') {
        second = parseExp();
      }
      if (L.nextIf(':')) {
        if (L.cur().kind != ',' && L.cur().kind != ']') {
          third = parseExp();
        }
      }
      auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first))
                               : Maybe<Expr>::create(range);
      auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second))
                                 : Maybe<Expr>::create(range);
      auto maybe_third = third ? Maybe<Expr>::create(range, Expr(third))
                               : Maybe<Expr>::create(range);
      return SliceExpr::create(range, maybe_first, maybe_second, maybe_third);
    } else {
      return Expr(first);
    }
  }

  // Parse `[i, j:k, ...]` applied to `value`.
  TreeRef parseSubscript(const TreeRef& value) {
    const auto range = L.cur().range;

    auto subscript_exprs =
        parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);

    const auto whole_range =
        SourceRange(range.source(), range.start(), L.cur().range.start());
    return Subscript::create(whole_range, Expr(value), subscript_exprs);
  }

  // Parse an optional `: type`. Returns an empty Maybe if no ':' follows.
  Maybe<Expr> maybeParseTypeAnnotation() {
    if (L.nextIf(':')) {
      // NB: parseExp must not be called inline, since argument evaluation order
      // changes when L.cur().range is mutated with respect to the parseExp()
      // call.
      auto expr = parseExp();
      return Maybe<Expr>::create(expr.range(), expr);
    } else {
      return Maybe<Expr>::create(L.cur().range);
    }
  }

  // Parse one formal parameter: `name [: type] [= default]`.
  TreeRef parseFormalParam(bool kwarg_only) {
    auto ident = parseIdent();
    TreeRef type = maybeParseTypeAnnotation();
    TreeRef def;
    if (L.nextIf('=')) {
      // NB: parseExp must not be called inline, since argument evaluation order
      // changes when L.cur().range is mutated with respect to the parseExp()
      // call.
      auto expr = parseExp();
      def = Maybe<Expr>::create(expr.range(), expr);
    } else {
      def = Maybe<Expr>::create(L.cur().range);
    }
    return Param::create(
        type->range(),
        Ident(ident),
        Maybe<Expr>(type),
        Maybe<Expr>(def),
        kwarg_only);
  }

  // A type with no parameter name, as it appears inside a type comment.
  Param parseBareTypeAnnotation() {
    auto type = parseExp();
    return Param::create(
        type.range(),
        Ident::create(type.range(), ""),
        Maybe<Expr>::create(type.range(), type),
        Maybe<Expr>::create(type.range()),
        /*kwarg_only=*/false);
  }

  // Parse `# type: (T1, T2, ...) [-> R]` into a Decl whose params carry only
  // types (empty names).
  Decl parseTypeComment() {
    auto range = L.cur().range;
    L.expect(TK_TYPE_COMMENT);
    auto param_types =
        parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation);
    TreeRef return_type;
    if (L.nextIf(TK_ARROW)) {
      auto return_type_range = L.cur().range;
      // Parsed as a separate statement, matching the NB notes above.
      auto return_expr = parseExp();
      return_type = Maybe<Expr>::create(return_type_range, return_expr);
    } else {
      return_type = Maybe<Expr>::create(L.cur().range);
    }
    return Decl::create(range, param_types, Maybe<Expr>(return_type));
  }

  // 'first' has already been parsed since expressions can exist
  // alone on a line:
  // first[,other,lhs] = rhs
  TreeRef parseAssign(const Expr& lhs) {
    auto type = maybeParseTypeAnnotation();
    auto maybeOp = maybeParseAssignmentOp();
    if (maybeOp) {
      // There is an assignment operator, parse the RHS and generate the
      // assignment.
      auto rhs = parseExpOrExpTuple();
      if (maybeOp.value()->kind() == '=') {
        // Chained assignment `a = b = c`: every expression except the last
        // is a target.
        std::vector<Expr> lhs_list = {lhs};
        while (L.nextIf('=')) {
          lhs_list.push_back(rhs);
          rhs = parseExpOrExpTuple();
        }
        if (type.present() && lhs_list.size() > 1) {
          throw ErrorReport(type.range())
              << "Annotated multiple assignment is not supported in python";
        }
        L.expect(TK_NEWLINE);
        return Assign::create(
            lhs.range(),
            List<Expr>::create(lhs_list[0].range(), lhs_list),
            Maybe<Expr>::create(rhs.range(), rhs),
            type);
      } else {
        L.expect(TK_NEWLINE);
        // this is an augmented assignment
        if (lhs.kind() == TK_TUPLE_LITERAL) {
          throw ErrorReport(lhs.range())
              << " augmented assignment can only have one LHS expression";
        }
        return AugAssign::create(
            lhs.range(), lhs, AugAssignKind(*maybeOp), Expr(rhs));
      }
    } else {
      // There is no assignment operator, so this must be of the form
      // `lhs : <type>`. Anything else (e.g. `x y`) is a user syntax error:
      // report it at the offending token instead of via an internal assert.
      if (!type.present()) {
        throw ErrorReport(L.cur().range)
            << "expected an assignment operator or a type annotation";
      }
      L.expect(TK_NEWLINE);
      return Assign::create(
          lhs.range(),
          List<Expr>::create(lhs.range(), {lhs}),
          Maybe<Expr>::create(lhs.range()),
          type);
    }
  }

  // Parse a single statement, dispatching on its leading token. `in_class`
  // marks a `def` as a method.
  TreeRef parseStmt(bool in_class = false) {
    switch (L.cur().kind) {
      case TK_IF:
        return parseIf();
      case TK_WHILE:
        return parseWhile();
      case TK_FOR:
        return parseFor();
      case TK_GLOBAL: {
        auto range = L.next().range;
        auto idents =
            parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent);
        L.expect(TK_NEWLINE);
        return Global::create(range, idents);
      }
      case TK_RETURN: {
        auto range = L.next().range;
        Expr value = L.cur().kind != TK_NEWLINE
            ? parseExpOrExpTuple()
            : Expr(create_compound(TK_NONE, range, {}));
        L.expect(TK_NEWLINE);
        return Return::create(range, value);
      }
      case TK_RAISE: {
        auto range = L.next().range;
        auto expr = parseExp();
        L.expect(TK_NEWLINE);
        return Raise::create(range, expr);
      }
      case TK_ASSERT: {
        auto range = L.next().range;
        auto cond = parseExp();
        Maybe<Expr> maybe_first = Maybe<Expr>::create(range);
        if (L.nextIf(',')) {
          auto msg = parseExp();
          maybe_first = Maybe<Expr>::create(range, Expr(msg));
        }
        L.expect(TK_NEWLINE);
        return Assert::create(range, cond, maybe_first);
      }
      case TK_BREAK: {
        auto range = L.next().range;
        L.expect(TK_NEWLINE);
        return Break::create(range);
      }
      case TK_CONTINUE: {
        auto range = L.next().range;
        L.expect(TK_NEWLINE);
        return Continue::create(range);
      }
      case TK_PASS: {
        auto range = L.next().range;
        L.expect(TK_NEWLINE);
        return Pass::create(range);
      }
      case TK_DEF: {
        return parseFunction(/*is_method=*/in_class);
      }
      case TK_DELETE: {
        auto range = L.next().range;
        auto targets =
            parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
        L.expect(TK_NEWLINE);
        return Delete::create(range, targets);
      }
      case TK_WITH: {
        return parseWith();
      }
      default: {
        // Either an expression statement or some form of assignment.
        auto lhs = parseExpOrExpTuple();
        if (L.cur().kind != TK_NEWLINE) {
          return parseAssign(lhs);
        } else {
          L.expect(TK_NEWLINE);
          return ExprStmt::create(lhs.range(), lhs);
        }
      }
    }
  }

  // Parse one `expression [as name]` item of a `with` statement.
  WithItem parseWithItem() {
    auto target = parseExp();

    if (L.cur().kind == TK_AS) {
      // If the current token is TK_AS, this with item is of the form
      // "expression as target".
      auto token = L.expect(TK_AS);
      Ident ident = parseIdent();
      auto var = Var::create(ident.range(), ident);
      return WithItem::create(
          token.range, target, Maybe<Var>::create(ident.range(), var));
    } else {
      // If not, this with item is of the form "expression".
      return WithItem::create(
          target.range(), target, Maybe<Var>::create(target.range()));
    }
  }

  // Parse `if cond: body [elif ...] [else: body]`. `elif` is parsed by a
  // recursive call with expect_if=false and becomes the sole statement of
  // the false branch.
  TreeRef parseIf(bool expect_if = true) {
    auto r = L.cur().range;
    if (expect_if)
      L.expect(TK_IF);
    auto cond = parseExp();
    L.expect(':');
    auto true_branch = parseStatements(/*expect_indent=*/true);
    auto false_branch = makeList(L.cur().range, {});
    if (L.nextIf(TK_ELSE)) {
      L.expect(':');
      false_branch = parseStatements(/*expect_indent=*/true);
    } else if (L.nextIf(TK_ELIF)) {
      // NB: this needs to be a separate statement, since the call to parseIf
      // mutates the lexer state, and thus causes a heap-use-after-free in
      // compilers which evaluate argument expressions LTR
      auto range = L.cur().range;
      false_branch = makeList(range, {parseIf(false)});
    }
    return If::create(
        r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
  }

  // Parse `while cond: body`.
  TreeRef parseWhile() {
    auto r = L.cur().range;
    L.expect(TK_WHILE);
    auto cond = parseExp();
    L.expect(':');
    auto body = parseStatements(/*expect_indent=*/true);
    return While::create(r, Expr(cond), List<Stmt>(body));
  }

  // Parse `for targets in iterables: body`.
  TreeRef parseFor() {
    auto r = L.cur().range;
    L.expect(TK_FOR);
    auto targets = parseList(TK_NOTHING, ',', TK_IN, &ParserImpl::parseLHSExp);
    auto itrs = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseExp);
    auto body = parseStatements(/*expect_indent=*/true);
    return For::create(r, targets, itrs, body);
  }

  TreeRef parseWith() {
    auto r = L.cur().range;
    // Parse "with expression [as target][, expression [as target]]*:".
    L.expect(TK_WITH);
    auto targets = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseWithItem);
    // Parse the body.
    auto body = parseStatements(/*expect_indent=*/true);
    return With::create(r, targets, body);
  }

  // Parse statements until a TK_DEDENT is consumed, returning them as a
  // TK_LIST tree. If `expect_indent`, a TK_INDENT must come first.
  TreeRef parseStatements(bool expect_indent, bool in_class = false) {
    auto r = L.cur().range;
    if (expect_indent) {
      L.expect(TK_INDENT);
    }
    TreeList stmts;
    do {
      stmts.push_back(parseStmt(in_class));
    } while (!L.nextIf(TK_DEDENT));
    return create_compound(TK_LIST, r, std::move(stmts));
  }

  // Parse an optional `-> type`. Returns an empty Maybe if no arrow follows.
  Maybe<Expr> parseReturnAnnotation() {
    if (L.nextIf(TK_ARROW)) {
      // Exactly one expression for return type annotation
      auto return_type_range = L.cur().range;
      // Parsed as a separate statement, matching the NB notes above.
      auto return_expr = parseExp();
      return Maybe<Expr>::create(return_type_range, return_expr);
    } else {
      return Maybe<Expr>::create(L.cur().range);
    }
  }

  // Parse `(p1, p2, *, kw1, ...)`. Every parameter after a bare `*` is
  // keyword-only.
  List<Param> parseFormalParams() {
    auto r = L.cur().range;
    std::vector<Param> params;
    bool kwarg_only = false;
    parseSequence('(', ',', ')', [&] {
      if (!kwarg_only && L.nextIf('*')) {
        kwarg_only = true;
      } else {
        params.emplace_back(parseFormalParam(kwarg_only));
      }
    });
    return List<Param>::create(r, params);
  }

  // Parse a function signature after its name: parameters, optional return
  // annotation, and the terminating ':'.
  Decl parseDecl() {
    List<Param> paramlist = parseFormalParams();
    // Parse return type annotation
    Maybe<Expr> return_annotation = parseReturnAnnotation();
    L.expect(':');
    return Decl::create(
        paramlist.range(), List<Param>(paramlist), return_annotation);
  }

  // Parse `class Name[(Base)]: body`. The body is parsed with in_class=true,
  // so nested `def`s are methods.
  TreeRef parseClass() {
    L.expect(TK_CLASS_DEF);
    const auto name = parseIdent();
    Maybe<Expr> superclass = Maybe<Expr>::create(name.range());
    if (L.nextIf('(')) {
      // Only support inheriting from NamedTuple right now.
      auto id = parseExp();
      superclass = Maybe<Expr>::create(id.range(), id);
      L.expect(')');
    }
    L.expect(':');
    const auto statements =
        parseStatements(/*expect_indent=*/true, /*in_class=*/true);
    return ClassDef::create(
        name.range(), name, superclass, List<Stmt>(statements));
  }

  // Parse `def name(params) [-> ret]: body`, merging in the types from a
  // leading `# type:` comment if present.
  TreeRef parseFunction(bool is_method) {
    L.expect(TK_DEF);
    auto name = parseIdent();
    auto decl = parseDecl();

    TreeRef stmts_list;
    if (L.nextIf(TK_INDENT)) {
      // Handle type annotations specified in a type comment as the first line
      // of the function.
      if (L.cur().kind == TK_TYPE_COMMENT) {
        auto type_annotation_decl = Decl(parseTypeComment());
        L.expect(TK_NEWLINE);
        decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
      }

      stmts_list = parseStatements(false);
    } else {
      // Special case: the Python grammar allows one-line functions with a
      // single statement.
      if (L.cur().kind == TK_TYPE_COMMENT) {
        auto type_annotation_decl = Decl(parseTypeComment());
        decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
      }

      TreeList stmts;
      stmts.push_back(parseStmt(is_method));
      stmts_list = create_compound(TK_LIST, L.cur().range, std::move(stmts));
    }

    return Def::create(
        name.range(), Ident(name), Decl(decl), List<Stmt>(stmts_list));
  }

  Lexer& lexer() {
    return L;
  }

 private:
  // short helpers to create nodes
  TreeRef create_compound(
      int kind,
      const SourceRange& range,
      TreeList&& trees) {
    return Compound::create(kind, range, std::move(trees));
  }
  TreeRef makeList(const SourceRange& range, TreeList&& trees) {
    return create_compound(TK_LIST, range, std::move(trees));
  }
  Lexer L;
  SharedParserData& shared;
  // Most recent list comprehension parsed from `[ ... ]` (see the '[' case
  // of parseBaseExp). Held as an owning ref so pointer identity stays
  // meaningful.
  TreeRef last_bracketed_list_comp_;
};
Parser::Parser(const std::shared_ptr<Source>& src)
: pImpl(new ParserImpl(src)) {}
Parser::~Parser() = default;
TreeRef Parser::parseFunction(bool is_method) {
return pImpl->parseFunction(is_method);
}
TreeRef Parser::parseClass() {
return pImpl->parseClass();
}
Lexer& Parser::lexer() {
return pImpl->lexer();
}
Decl Parser::parseTypeComment() {
return pImpl->parseTypeComment();
}
Expr Parser::parseExp() {
return pImpl->parseExp();
}
} // namespace torch::jit