From 49f4822ed012d9818c80ca4fcdeb7e2d55c04806 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Mon, 8 Aug 2022 12:24:02 +1000 Subject: [PATCH] fix: union elements that were pointers would panic Fixes #255 --- ebnf.go | 2 +- grammar.go | 7 +++++-- nodes.go | 23 +++++++++++++++++++---- options.go | 5 +++-- parser_test.go | 22 ++++++++++++++++++++++ visit.go | 2 +- 6 files changed, 51 insertions(+), 10 deletions(-) diff --git a/ebnf.go b/ebnf.go index 34c894df..1698f753 100644 --- a/ebnf.go +++ b/ebnf.go @@ -62,7 +62,7 @@ func buildEBNF(root bool, n node, seen map[node]bool, p *ebnfp, outp *[]*ebnfp) p = &ebnfp{name: name} *outp = append(*outp, p) seen[n] = true - for i, next := range n.members { + for i, next := range n.nodeMembers { if i > 0 { p.out += " | " } diff --git a/grammar.go b/grammar.go index 43f6122c..45d5e069 100644 --- a/grammar.go +++ b/grammar.go @@ -28,7 +28,10 @@ func (g *generatorContext) addUnionDefs(defs []unionDef) error { if _, exists := g.typeNodes[def.typ]; exists { return fmt.Errorf("duplicate definition for interface or union type %s", def.typ) } - unionNode := &union{def.typ, make([]node, 0, len(def.members))} + unionNode := &union{ + unionDef: def, + nodeMembers: make([]node, 0, len(def.members)), + } g.typeNodes[def.typ], unionNodes[i] = unionNode, unionNode } for i, def := range defs { @@ -38,7 +41,7 @@ func (g *generatorContext) addUnionDefs(defs []unionDef) error { if err != nil { return err } - unionNode.members = append(unionNode.members, memberNode) + unionNode.nodeMembers = append(unionNode.nodeMembers, memberNode) } } return nil diff --git a/nodes.go b/nodes.go index 221a9655..3d609ca7 100644 --- a/nodes.go +++ b/nodes.go @@ -96,8 +96,8 @@ func (c *custom) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.V // @@ (for a union) type union struct { - typ reflect.Type - members []node + unionDef + nodeMembers []node } func (u *union) String() string { return ebnf(u) } @@ -105,13 +105,13 @@ func (u *union) GoString() string { return u.typ.Name() } func (u *union) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) { defer ctx.printTrace(u)() - temp := disjunction{u.members} + temp := disjunction{u.nodeMembers} vals, err := temp.Parse(ctx, parent) if err != nil { return nil, err } for i := range vals { - vals[i] = vals[i].Convert(u.typ) + vals[i] = maybeRef(u.members[i], vals[i]).Convert(u.typ) } return vals, nil } @@ -637,6 +637,21 @@ func sizeOfKind(kind reflect.Kind) int { panic("unsupported kind " + kind.String()) } +func maybeRef(tmpl reflect.Type, strct reflect.Value) reflect.Value { + if strct.Type() == tmpl { + return strct + } + if tmpl.Kind() == reflect.Ptr { + if strct.CanAddr() { + return strct.Addr() + } + ptr := reflect.New(tmpl) + ptr.Set(strct) + return ptr + } + return strct +} + // Set field. // // If field is a pointer the pointer will be set to the value. If field is a string, value will be diff --git a/options.go b/options.go index 65482238..4842ecba 100644 --- a/options.go +++ b/options.go @@ -96,9 +96,10 @@ func ParseTypeWith[T any](parseFn func(*lexer.PeekingLexer) (T, error)) Option { // try to parse the second member at all. func Union[T any](members ...T) Option { return func(p *parserOptions) error { - unionType := reflect.TypeOf((*T)(nil)).Elem() + var t T + unionType := reflect.TypeOf(&t).Elem() if unionType.Kind() != reflect.Interface { - return fmt.Errorf("Union: union type must be an interface (got %s)", unionType) + return fmt.Errorf("union: union type must be an interface (got %s)", unionType) } memberTypes := make([]reflect.Type, 0, len(members)) for _, m := range members { diff --git a/parser_test.go b/parser_test.go index 3ad8d334..64a031fc 100644 --- a/parser_test.go +++ b/parser_test.go @@ -1842,3 +1842,25 @@ func TestParseSubProduction(t *testing.T) { require.NoError(t, err) require.Equal(t, &expectedItem2, actualItem2) } + +type I255Grammar struct { + Union I255Union `@@` +} + +type I255Union interface{ union() } + +type I255String struct { + Value string `@String` +} + +func (*I255String) union() {} + +func TestIssue255(t *testing.T) { + parser, err := participle.Build[I255Grammar]( + participle.Union[I255Union](&I255String{}), + ) + require.NoError(t, err) + g, err := parser.ParseString("", `"Hello, World!"`) + require.NoError(t, err) + require.Equal(t, &I255Grammar{Union: &I255String{Value: `"Hello, World!"`}}, g) +} diff --git a/visit.go b/visit.go index 9371d0d0..e4186b18 100644 --- a/visit.go +++ b/visit.go @@ -20,7 +20,7 @@ func visit(n node, visitor func(n node, next func() error) error) error { case *custom: return nil case *union: - for _, member := range n.members { + for _, member := range n.nodeMembers { if err := visit(member, visitor); err != nil { return err }