diff --git a/examples/argument_matching_test.go b/examples/argument_matching_test.go index ef621fa..4ea0ef9 100644 --- a/examples/argument_matching_test.go +++ b/examples/argument_matching_test.go @@ -41,6 +41,9 @@ type Maths interface { type Sender interface { SendMessage(title *string, message string) error SendMany(details map[string]string) error + + // Blocks the specified email address from being sent to. + Block(string) error } type ArgumentMatchingTests struct { @@ -128,6 +131,18 @@ func (t *ArgumentMatchingTests) Test_CanMatchMaps() { t.NoError(successResult) } +func (t *ArgumentMatchingTests) Test_CanHandleNamelessParameters() { + // Arrange + mock := sender.NewMock() + mock.Setup(sender.Block("bad-recipient@somewhere.com").Return(errors.New("cannot block that recipient!"))) + + // Act + blockedResult := mock.Instance().Block("bad-recipient@somewhere.com") + + // Assert + t.ErrorContains(blockedResult, "cannot block that recipient!") +} + func TestArgumentMatching(t *testing.T) { suite.Run(t, new(ArgumentMatchingTests)) } diff --git a/examples/mocks/sender/sender.go b/examples/mocks/sender/sender.go index 4bd34ad..fade39f 100644 --- a/examples/mocks/sender/sender.go +++ b/examples/mocks/sender/sender.go @@ -64,6 +64,27 @@ func (m *instance) SendMany(details map[string]string) (r0 error) { return } +// Blocks the specified email address from being sent to. +func (m *instance) Block(_p0 string) (r0 error) { + expectation := m.mock.Call("Block", _p0) + if expectation != nil { + if expectation.ObserveFn != nil { + observe := expectation.ObserveFn.(func(_p0 string) error) + return observe(_p0) + } + + if expectation.PanicArg != nil { + panic(expectation.PanicArg) + } + + if expectation.Returns[0] != nil { + r0 = expectation.Returns[0].(error) + } + } + + return +} + func (m *Mock) Instance() *instance { return &m.instance } @@ -315,3 +336,125 @@ type sendManyAction struct { func (a *sendManyAction) CreateExpectation() *mocking.Expectation { return &a.expectation } + +type blockMethodMatcher struct { + matcher mocking.MethodMatcher +} + +func (m *blockMethodMatcher) CreateMethodMatcher() *mocking.MethodMatcher { + return &m.matcher +} + +// Blocks the specified email address from being sent to. +func Block[P0 string | mocking.Matcher[string]](_p0 P0) *blockMethodMatcher { + result := blockMethodMatcher{ + matcher: mocking.MethodMatcher{ + MethodName: "Block", + ArgumentMatchers: make([]mocking.ArgumentMatcher, 1), + }, + } + + if matcher, ok := any(_p0).(mocking.Matcher[string]); ok { + result.matcher.ArgumentMatchers[0] = matcher + } else { + result.matcher.ArgumentMatchers[0] = kelpie.ExactMatch(any(_p0).(string)) + } + + return &result +} + +type blockTimes struct { + matcher *blockMethodMatcher +} + +// Times allows you to restrict the number of times a particular expectation can be matched. +func (m *blockMethodMatcher) Times(times uint) *blockTimes { + m.matcher.Times = × + + return &blockTimes{ + matcher: m, + } +} + +// Once specifies that the expectation will only match once. +func (m *blockMethodMatcher) Once() *blockTimes { + return m.Times(1) +} + +// Never specifies that the method has not been called. This is mainly useful for verification +// rather than mocking. +func (m *blockMethodMatcher) Never() *blockTimes { + return m.Times(0) +} + +// Return returns the specified results when the method is called. +func (t *blockTimes) Return(r0 error) *blockAction { + return &blockAction{ + expectation: mocking.Expectation{ + MethodMatcher: &t.matcher.matcher, + Returns: []any{r0}, + }, + } +} + +// Panic panics using the specified argument when the method is called. +func (t *blockTimes) Panic(arg any) *blockAction { + return &blockAction{ + expectation: mocking.Expectation{ + MethodMatcher: &t.matcher.matcher, + PanicArg: arg, + }, + } +} + +// When calls the specified observe callback when the method is called. +func (t *blockTimes) When(observe func(_p0 string) error) *blockAction { + return &blockAction{ + expectation: mocking.Expectation{ + MethodMatcher: &t.matcher.matcher, + ObserveFn: observe, + }, + } +} + +func (t *blockTimes) CreateMethodMatcher() *mocking.MethodMatcher { + return &t.matcher.matcher +} + +// Return returns the specified results when the method is called. +func (m *blockMethodMatcher) Return(r0 error) *blockAction { + return &blockAction{ + expectation: mocking.Expectation{ + MethodMatcher: &m.matcher, + Returns: []any{r0}, + }, + } +} + +// Panic panics using the specified argument when the method is called. +func (m *blockMethodMatcher) Panic(arg any) *blockAction { + return &blockAction{ + expectation: mocking.Expectation{ + MethodMatcher: &m.matcher, + PanicArg: arg, + }, + } +} + +// When calls the specified observe callback when the method is called. +func (m *blockMethodMatcher) When(observe func(_p0 string) error) *blockAction { + return &blockAction{ + expectation: mocking.Expectation{ + MethodMatcher: &m.matcher, + ObserveFn: observe, + }, + } +} + +type blockAction struct { + expectation mocking.Expectation +} + +func (a *blockAction) CreateExpectation() *mocking.Expectation { + return &a.expectation +} diff --git a/parser/import_helper.go b/parser/import_helper.go index ce8cd0c..49bd720 100644 --- a/parser/import_helper.go +++ b/parser/import_helper.go @@ -103,6 +103,18 @@ func (i *importHelper) getPackageIdentifiers(e ast.Expr) []*ast.Ident { case *ast.MapType: identifiers = append(identifiers, i.getPackageIdentifiers(n.Key)...) identifiers = append(identifiers, i.getPackageIdentifiers(n.Value)...) + case *ast.Ellipsis: + identifiers = append(identifiers, i.getPackageIdentifiers(n.Elt)...) + case *ast.FuncType: + for _, param := range n.Params.List { + identifiers = append(identifiers, i.getPackageIdentifiers(param.Type)...) + } + + if n.Results != nil { + for _, result := range n.Results.List { + identifiers = append(identifiers, i.getPackageIdentifiers(result.Type)...) + } + } default: panic(fmt.Sprintf("Could not get package identifier from ast expression: %v. This is a bug in Kelpie!", e)) } diff --git a/parser/parser.go b/parser/parser.go index 5f877b3..0133e7a 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -5,6 +5,7 @@ package parser import ( "fmt" "go/ast" + "strconv" "strings" "github.com/pkg/errors" @@ -118,10 +119,17 @@ func Parse(packageName string, directory string, filter InterfaceFilter) ([]Mock // TODO: check what situation would cause Type to not be ast.FuncType. Maybe ast.Bad? funcType := method.Type.(*ast.FuncType) - for _, param := range funcType.Params.List { - for _, paramName := range param.Names { + for paramIndex, param := range funcType.Params.List { + if len(param.Names) > 0 { + for _, paramName := range param.Names { + methodDefinition.Parameters = append(methodDefinition.Parameters, ParameterDefinition{ + Name: paramName.Name, + Type: getTypeName(param.Type, p), + }) + } + } else { methodDefinition.Parameters = append(methodDefinition.Parameters, ParameterDefinition{ - Name: paramName.Name, + Name: "_p" + strconv.Itoa(paramIndex), Type: getTypeName(param.Type, p), }) } @@ -199,6 +207,37 @@ func getTypeName(e ast.Expr, p *packages.Package) string { valueType := getTypeName(n.Value, p) return "map[" + keyType + "]" + valueType + case *ast.FuncType: + var params []string + for _, param := range n.Params.List { + parameterNames := slices.Map(param.Names, func(i *ast.Ident) string { return i.Name }) + if len(parameterNames) > 0 { + params = append(params, strings.Join(parameterNames, ", ")+" "+getTypeName(param.Type, p)) + } else { + params = append(params, getTypeName(param.Type, p)) + } + } + + var results []string + if n.Results != nil { + for _, result := range n.Results.List { + resultNames := slices.Map(result.Names, func(i *ast.Ident) string { return i.Name }) + if len(resultNames) > 0 { + results = append(results, strings.Join(resultNames, ", ")+" "+getTypeName(result.Type, p)) + } else { + results = append(results, getTypeName(result.Type, p)) + } + } + } + + functionDefinition := "func(" + strings.Join(params, ", ") + ")" + if len(results) > 0 { + functionDefinition += " (" + strings.Join(results, ", ") + ")" + } + + return functionDefinition + case *ast.Ellipsis: + return "..." + getTypeName(n.Elt, p) } panic(fmt.Sprintf("Unknown type %v. This is a bug in Kelpie!", e)) diff --git a/parser/parser_test.go b/parser/parser_test.go index 4172199..cf7f2bd 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -404,6 +404,121 @@ type UserService interface { t.Contains(requester.Imports, `"github.com/adamconnelly/kelpie-test/users"`) } +func (t *ParserTests) Test_Parse_SupportsFunctionsInParameters() { + // Arrange + input := `package users + +type User struct {} + +type UserService interface { + UpdateUsers(callback func(id int, user User)) error +}` + + // Act + result, err := t.ParseInput("users", input, t.interfaceFilter.Instance()) + + // Assert + t.NoError(err) + + userService := result[0] + + updateUsers := slices.FirstOrPanic(userService.Methods, func(m parser.MethodDefinition) bool { return m.Name == "UpdateUsers" }) + t.Equal("callback", updateUsers.Parameters[0].Name) + t.Equal("func(id int, user users.User)", updateUsers.Parameters[0].Type) + + t.Len(userService.Imports, 1) + t.Contains(userService.Imports, `"github.com/adamconnelly/kelpie-test/users"`) +} + +func (t *ParserTests) Test_Parse_SupportsNamelessParameters() { + // Arrange + input := `package users + +type UserType int + +type User struct {} + +type UserService interface { + FindUser(int, UserType) (*User, error) +}` + + // Act + result, err := t.ParseInput("users", input, t.interfaceFilter.Instance()) + + // Assert + t.NoError(err) + + userService := result[0] + + findUser := slices.FirstOrPanic(userService.Methods, func(m parser.MethodDefinition) bool { return m.Name == "FindUser" }) + t.Equal("_p0", findUser.Parameters[0].Name) + t.Equal("int", findUser.Parameters[0].Type) + + t.Equal("_p1", findUser.Parameters[1].Name) + t.Equal("users.UserType", findUser.Parameters[1].Type) +} + +func (t *ParserTests) Test_Parse_SupportsFunctionsInResults() { + // Arrange + input := `package users + +type User struct {} + +type UserService interface { + GetUserFn() func (id int) (*User, error) +}` + + // Act + result, err := t.ParseInput("users", input, t.interfaceFilter.Instance()) + + // Assert + t.NoError(err) + + userService := result[0] + + updateUserFn := slices.FirstOrPanic(userService.Methods, func(m parser.MethodDefinition) bool { return m.Name == "GetUserFn" }) + t.Equal("", updateUserFn.Results[0].Name) + t.Equal("func(id int) (*users.User, error)", updateUserFn.Results[0].Type) + + t.Len(userService.Imports, 1) + t.Contains(userService.Imports, `"github.com/adamconnelly/kelpie-test/users"`) +} + +func (t *ParserTests) Test_Parse_SupportsVariadicFunctions() { + // Arrange + input := `package users + +type UserType int + +type FindUsersOptions struct { + Type *UserType +} + +type User struct { + ID uint + Name string + Type UserType +} + +type UserService interface { + FindUsers(opts ...func(*FindUsersOptions)) ([]*User, error) +}` + + // Act + result, err := t.ParseInput("users", input, t.interfaceFilter.Instance()) + + // Assert + t.NoError(err) + + requester := result[0] + + findUsers := slices.FirstOrPanic(requester.Methods, func(m parser.MethodDefinition) bool { return m.Name == "FindUsers" }) + t.Equal("...func(*users.FindUsersOptions)", findUsers.Parameters[0].Type) + + t.Len(requester.Imports, 1) + t.Contains(requester.Imports, `"github.com/adamconnelly/kelpie-test/users"`) +} + // TODO: add a test for handling types that can't be resolved (e.g. because of a mistake in the code we're parsing) // TODO: what about empty interfaces? Return a warning?