Skip to content

Commit

Permalink
fix: union elements that were pointers would panic
Browse files Browse the repository at this point in the history
Fixes #255
  • Loading branch information
alecthomas committed Aug 8, 2022
1 parent e8112b2 commit 49f4822
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 10 deletions.
2 changes: 1 addition & 1 deletion ebnf.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func buildEBNF(root bool, n node, seen map[node]bool, p *ebnfp, outp *[]*ebnfp)
p = &ebnfp{name: name}
*outp = append(*outp, p)
seen[n] = true
for i, next := range n.members {
for i, next := range n.nodeMembers {
if i > 0 {
p.out += " | "
}
Expand Down
7 changes: 5 additions & 2 deletions grammar.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ func (g *generatorContext) addUnionDefs(defs []unionDef) error {
if _, exists := g.typeNodes[def.typ]; exists {
return fmt.Errorf("duplicate definition for interface or union type %s", def.typ)
}
unionNode := &union{def.typ, make([]node, 0, len(def.members))}
unionNode := &union{
unionDef: def,
nodeMembers: make([]node, 0, len(def.members)),
}
g.typeNodes[def.typ], unionNodes[i] = unionNode, unionNode
}
for i, def := range defs {
Expand All @@ -38,7 +41,7 @@ func (g *generatorContext) addUnionDefs(defs []unionDef) error {
if err != nil {
return err
}
unionNode.members = append(unionNode.members, memberNode)
unionNode.nodeMembers = append(unionNode.nodeMembers, memberNode)
}
}
return nil
Expand Down
23 changes: 19 additions & 4 deletions nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,22 +96,22 @@ func (c *custom) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.V

// @@ (for a union)
type union struct {
typ reflect.Type
members []node
unionDef
nodeMembers []node
}

func (u *union) String() string { return ebnf(u) }
func (u *union) GoString() string { return u.typ.Name() }

func (u *union) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
defer ctx.printTrace(u)()
temp := disjunction{u.members}
temp := disjunction{u.nodeMembers}
vals, err := temp.Parse(ctx, parent)
if err != nil {
return nil, err
}
for i := range vals {
vals[i] = vals[i].Convert(u.typ)
vals[i] = maybeRef(u.members[i], vals[i]).Convert(u.typ)
}
return vals, nil
}
Expand Down Expand Up @@ -637,6 +637,21 @@ func sizeOfKind(kind reflect.Kind) int {
panic("unsupported kind " + kind.String())
}

func maybeRef(tmpl reflect.Type, strct reflect.Value) reflect.Value {
if strct.Type() == tmpl {
return strct
}
if tmpl.Kind() == reflect.Ptr {
if strct.CanAddr() {
return strct.Addr()
}
ptr := reflect.New(tmpl)
ptr.Set(strct)
return ptr
}
return strct
}

// Set field.
//
// If field is a pointer the pointer will be set to the value. If field is a string, value will be
Expand Down
5 changes: 3 additions & 2 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,10 @@ func ParseTypeWith[T any](parseFn func(*lexer.PeekingLexer) (T, error)) Option {
// try to parse the second member at all.
func Union[T any](members ...T) Option {
return func(p *parserOptions) error {
unionType := reflect.TypeOf((*T)(nil)).Elem()
var t T
unionType := reflect.TypeOf(&t).Elem()
if unionType.Kind() != reflect.Interface {
return fmt.Errorf("Union: union type must be an interface (got %s)", unionType)
return fmt.Errorf("union: union type must be an interface (got %s)", unionType)
}
memberTypes := make([]reflect.Type, 0, len(members))
for _, m := range members {
Expand Down
22 changes: 22 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1842,3 +1842,25 @@ func TestParseSubProduction(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &expectedItem2, actualItem2)
}

type I255Grammar struct {
Union I255Union `@@`
}

type I255Union interface{ union() }

type I255String struct {
Value string `@String`
}

func (*I255String) union() {}

func TestIssue255(t *testing.T) {
parser, err := participle.Build[I255Grammar](
participle.Union[I255Union](&I255String{}),
)
require.NoError(t, err)
g, err := parser.ParseString("", `"Hello, World!"`)
require.NoError(t, err)
require.Equal(t, &I255Grammar{Union: &I255String{Value: `"Hello, World!"`}}, g)
}
2 changes: 1 addition & 1 deletion visit.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func visit(n node, visitor func(n node, next func() error) error) error {
case *custom:
return nil
case *union:
for _, member := range n.members {
for _, member := range n.nodeMembers {
if err := visit(member, visitor); err != nil {
return err
}
Expand Down

0 comments on commit 49f4822

Please sign in to comment.