Skip to content

Commit

Permalink
feat: add support for variadic functions
Browse files Browse the repository at this point in the history
To make it useful, I also had to add support for unnamed parameters since that's fairly common when writing interfaces like this.
  • Loading branch information
adamconnelly committed Jun 2, 2024
1 parent 7b5adb1 commit bb57a55
Show file tree
Hide file tree
Showing 5 changed files with 327 additions and 3 deletions.
15 changes: 15 additions & 0 deletions examples/argument_matching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ type Maths interface {
type Sender interface {
SendMessage(title *string, message string) error
SendMany(details map[string]string) error

// Blocks the specified email address from being sent to.
Block(string) error
}

type ArgumentMatchingTests struct {
Expand Down Expand Up @@ -128,6 +131,18 @@ func (t *ArgumentMatchingTests) Test_CanMatchMaps() {
t.NoError(successResult)
}

func (t *ArgumentMatchingTests) Test_CanHandleNamelessParameters() {
// Arrange
mock := sender.NewMock()
mock.Setup(sender.Block("[email protected]").Return(errors.New("cannot block that recipient!")))

// Act
blockedResult := mock.Instance().Block("[email protected]")

// Assert
t.ErrorContains(blockedResult, "cannot block that recipient!")
}

func TestArgumentMatching(t *testing.T) {
suite.Run(t, new(ArgumentMatchingTests))
}
143 changes: 143 additions & 0 deletions examples/mocks/sender/sender.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions parser/import_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ func (i *importHelper) getPackageIdentifiers(e ast.Expr) []*ast.Ident {
case *ast.MapType:
identifiers = append(identifiers, i.getPackageIdentifiers(n.Key)...)
identifiers = append(identifiers, i.getPackageIdentifiers(n.Value)...)
case *ast.Ellipsis:
identifiers = append(identifiers, i.getPackageIdentifiers(n.Elt)...)
case *ast.FuncType:
for _, param := range n.Params.List {
identifiers = append(identifiers, i.getPackageIdentifiers(param.Type)...)
}

if n.Results != nil {
for _, result := range n.Results.List {
identifiers = append(identifiers, i.getPackageIdentifiers(result.Type)...)
}
}
default:
panic(fmt.Sprintf("Could not get package identifier from ast expression: %v. This is a bug in Kelpie!", e))
}
Expand Down
45 changes: 42 additions & 3 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package parser
import (
"fmt"
"go/ast"
"strconv"
"strings"

"github.com/pkg/errors"
Expand Down Expand Up @@ -118,10 +119,17 @@ func Parse(packageName string, directory string, filter InterfaceFilter) ([]Mock

// TODO: check what situation would cause Type to not be ast.FuncType. Maybe ast.Bad?
funcType := method.Type.(*ast.FuncType)
for _, param := range funcType.Params.List {
for _, paramName := range param.Names {
for paramIndex, param := range funcType.Params.List {
if len(param.Names) > 0 {
for _, paramName := range param.Names {
methodDefinition.Parameters = append(methodDefinition.Parameters, ParameterDefinition{
Name: paramName.Name,
Type: getTypeName(param.Type, p),
})
}
} else {
methodDefinition.Parameters = append(methodDefinition.Parameters, ParameterDefinition{
Name: paramName.Name,
Name: "_p" + strconv.Itoa(paramIndex),
Type: getTypeName(param.Type, p),
})
}
Expand Down Expand Up @@ -199,6 +207,37 @@ func getTypeName(e ast.Expr, p *packages.Package) string {
valueType := getTypeName(n.Value, p)

return "map[" + keyType + "]" + valueType
case *ast.FuncType:
var params []string
for _, param := range n.Params.List {
parameterNames := slices.Map(param.Names, func(i *ast.Ident) string { return i.Name })
if len(parameterNames) > 0 {
params = append(params, strings.Join(parameterNames, ", ")+" "+getTypeName(param.Type, p))
} else {
params = append(params, getTypeName(param.Type, p))
}
}

var results []string
if n.Results != nil {
for _, result := range n.Results.List {
resultNames := slices.Map(result.Names, func(i *ast.Ident) string { return i.Name })
if len(resultNames) > 0 {
results = append(results, strings.Join(resultNames, ", ")+" "+getTypeName(result.Type, p))
} else {
results = append(results, getTypeName(result.Type, p))
}
}
}

functionDefinition := "func(" + strings.Join(params, ", ") + ")"
if len(results) > 0 {
functionDefinition += " (" + strings.Join(results, ", ") + ")"
}

return functionDefinition
case *ast.Ellipsis:
return "..." + getTypeName(n.Elt, p)
}

panic(fmt.Sprintf("Unknown type %v. This is a bug in Kelpie!", e))
Expand Down
115 changes: 115 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,121 @@ type UserService interface {
t.Contains(requester.Imports, `"github.com/adamconnelly/kelpie-test/users"`)
}

func (t *ParserTests) Test_Parse_SupportsFunctionsInParameters() {
// Arrange
input := `package users
type User struct {}
type UserService interface {
UpdateUsers(callback func(id int, user User)) error
}`

// Act
result, err := t.ParseInput("users", input, t.interfaceFilter.Instance())

// Assert
t.NoError(err)

userService := result[0]

updateUsers := slices.FirstOrPanic(userService.Methods, func(m parser.MethodDefinition) bool { return m.Name == "UpdateUsers" })
t.Equal("callback", updateUsers.Parameters[0].Name)
t.Equal("func(id int, user users.User)", updateUsers.Parameters[0].Type)

t.Len(userService.Imports, 1)
t.Contains(userService.Imports, `"github.com/adamconnelly/kelpie-test/users"`)
}

func (t *ParserTests) Test_Parse_SupportsNamelessParameters() {
// Arrange
input := `package users
type UserType int
type User struct {}
type UserService interface {
FindUser(int, UserType) (*User, error)
}`

// Act
result, err := t.ParseInput("users", input, t.interfaceFilter.Instance())

// Assert
t.NoError(err)

userService := result[0]

findUser := slices.FirstOrPanic(userService.Methods, func(m parser.MethodDefinition) bool { return m.Name == "FindUser" })
t.Equal("_p0", findUser.Parameters[0].Name)
t.Equal("int", findUser.Parameters[0].Type)

t.Equal("_p1", findUser.Parameters[1].Name)
t.Equal("users.UserType", findUser.Parameters[1].Type)
}

func (t *ParserTests) Test_Parse_SupportsFunctionsInResults() {
// Arrange
input := `package users
type User struct {}
type UserService interface {
GetUserFn() func (id int) (*User, error)
}`

// Act
result, err := t.ParseInput("users", input, t.interfaceFilter.Instance())

// Assert
t.NoError(err)

userService := result[0]

updateUserFn := slices.FirstOrPanic(userService.Methods, func(m parser.MethodDefinition) bool { return m.Name == "GetUserFn" })
t.Equal("", updateUserFn.Results[0].Name)
t.Equal("func(id int) (*users.User, error)", updateUserFn.Results[0].Type)

t.Len(userService.Imports, 1)
t.Contains(userService.Imports, `"github.com/adamconnelly/kelpie-test/users"`)
}

func (t *ParserTests) Test_Parse_SupportsVariadicFunctions() {
// Arrange
input := `package users
type UserType int
type FindUsersOptions struct {
Type *UserType
}
type User struct {
ID uint
Name string
Type UserType
}
type UserService interface {
FindUsers(opts ...func(*FindUsersOptions)) ([]*User, error)
}`

// Act
result, err := t.ParseInput("users", input, t.interfaceFilter.Instance())

// Assert
t.NoError(err)

requester := result[0]

findUsers := slices.FirstOrPanic(requester.Methods, func(m parser.MethodDefinition) bool { return m.Name == "FindUsers" })
t.Equal("...func(*users.FindUsersOptions)", findUsers.Parameters[0].Type)

t.Len(requester.Imports, 1)
t.Contains(requester.Imports, `"github.com/adamconnelly/kelpie-test/users"`)
}

// TODO: add a test for handling types that can't be resolved (e.g. because of a mistake in the code we're parsing)
// TODO: what about empty interfaces? Return a warning?

Expand Down

0 comments on commit bb57a55

Please sign in to comment.