Skip to content

Commit

Permalink
Add filtering to generate list endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
joeriddles committed Sep 27, 2024
1 parent e75c92f commit ff8c371
Show file tree
Hide file tree
Showing 13 changed files with 331 additions and 144 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{
"search.useIgnoreFiles": false
"search.useIgnoreFiles": false,
"yaml.validate": false
}
49 changes: 49 additions & 0 deletions examples/cars/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,51 @@ func Test_GetVehicle(t *testing.T) {
assert.NotEqual(t, 0, actualVehicle.ID)
}

func Test_GetVehicle_WithFilters(t *testing.T) {
// Arrange
ctx := context.Background()
query := newQuery(t)
controller := api.NewVehicleController(query)

_, _, vehicle, _ := setupModels(t, query)

vehicleModelId := int(vehicle.VehicleModelID)
personId := int(vehicle.PersonID)

// Act
testCases := []struct {
name string
params api.GetVehicleParams
expected int
}{
{"Vin", api.GetVehicleParams{Vin: &vehicle.Vin}, 1},
{"VehicleModelID", api.GetVehicleParams{VehicleModelID: &vehicleModelId}, 1},
{"PersonID", api.GetVehicleParams{PersonID: &personId}, 1},
{"Wrong Vin", api.GetVehicleParams{Vin: ptr("nope")}, 0},
{"Wrong VehicleModelId", api.GetVehicleParams{VehicleModelID: ptr(vehicleModelId + 1)}, 0},
{"Wrong PersonID", api.GetVehicleParams{PersonID: ptr(personId + 1)}, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
response, err := controller.GetVehicle(ctx, api.GetVehicleRequestObject{
Params: tc.params,
})
require.NoError(t, err)

// Assert
rec := httptest.NewRecorder()
err = response.VisitGetVehicleResponse(rec)
require.NoError(t, err)
assert.Equal(t, 200, rec.Code)

vehicles := &[]api.Vehicle{}
err = json.Unmarshal(rec.Body.Bytes(), vehicles)
require.NoError(t, err)
assert.Equal(t, tc.expected, len(*vehicles))
})
}
}

func Test_GetVehicleID(t *testing.T) {
// Arrange
ctx := context.Background()
Expand Down Expand Up @@ -264,3 +309,7 @@ func Test_PostVehicleForSale(t *testing.T) {
assert.Equal(t, time.Duration(60), vehicleForSaleFromDb.Duration)
assert.True(t, vehicleForSaleFromDb.Amount.Equal(decimal.NewFromFloat(100.00)))
}

func ptr[T any](val T) *T {
return &val
}
9 changes: 8 additions & 1 deletion examples/custom_controller/custom_controller.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,14 @@ func New{{.model.Name}}Controller(query *query.Query) {{.model.Name}}Controller
}

func (c *{{.model.Name|ToCamelCase}}Controller) Get{{.model.Name}}(ctx context.Context, request {{Types}}Get{{.model.Name}}RequestObject) ({{Types}}Get{{.model.Name}}ResponseObject, error) {
{{.model.Name|ToCamelCase}}s, err := c.repository.List(ctx, map[string]interface{}{})
filters := &repository.{{.model.Name}}Filter{}
j, err := json.Marshal(request.Params)
if err != nil {
return {{Types}}Get{{.model.Name}}400JSONResponse{}, nil
}
json.Unmarshal(j, filters)

{{.model.Name|ToCamelCase}}s, err := c.repository.List(ctx, filters)
if err != nil {
return nil, err
}
Expand Down
11 changes: 9 additions & 2 deletions examples/repository/custom_repository.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@ import (
model "{{.pkg}}"
)

type {{.model.Name}}Filter struct {
{{range .model.Fields -}}
{{- with .|ToOpenApiType}}{{if not .IsSimpleType}}{{continue}}{{end}}{{end -}}
{{.Name}} {{.|GetGormQueryType|ToPtr}} `json:"{{.Name|ToSnakeCase}},omitempty"`
{{end}}
}

type {{.model.Name}}Repository interface {
List(
ctx context.Context,
filters any,
filters *{{.model.Name}}Filter,
) ([]*model.{{.model.Name}}, error)

Get(
Expand Down Expand Up @@ -47,7 +54,7 @@ func New{{.model.Name}}Repository(query *query.Query) {{.model.Name}}Repository

func (r *{{.model.Name|ToCamelCase}}Repository) List(
ctx context.Context,
filters any,
filters *{{.model.Name}}Filter,
) ([]*model.{{.model.Name}}, error) {
return nil, nil
}
Expand Down
156 changes: 156 additions & 0 deletions pkg/convert/convert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package convert

import (
"bufio"
"bytes"
"fmt"
"go/types"
"strings"
"text/template"

"github.com/joeriddles/goalesce/pkg/entity"
"github.com/joeriddles/goalesce/pkg/utils"
)

const DefaultSrc string = "src"
const DefaultDst string = "dst"

// Convert the src field to the matching field on dst.
func ConvertField(
templates *template.Template,
field *entity.GormModelField,
dst *entity.GormModelMetadata,
) string {
return ConvertFieldNamed(templates, field, dst, DefaultSrc, DefaultDst)
}

// Convert the src field to the matching field on dst.
func ConvertFieldNamed(
templates *template.Template,
field *entity.GormModelField,
dst *entity.GormModelMetadata,
from, to string,
) string {
dstField := dst.GetField(field.Name)

if field.MapApiFunc != nil {
return fmt.Sprintf("%v.%v = model.%v(%v.%v)", to, dstField.Name, *field.MapApiFunc, from, field.Name)
}

srcType := field.GetType()
dstType := dstField.GetType()

isSrcPtr := false
if ptrSrc, ok := srcType.(*types.Pointer); ok {
isSrcPtr = true
srcType = ptrSrc.Elem()
}

isDstPtr := false
if ptrDst, ok := dstType.(*types.Pointer); ok {
isDstPtr = true
dstType = ptrDst.Elem()
}

switch s := srcType.(type) {
case *types.Basic:
switch d := dstType.(type) {
case *types.Basic:
if s.Kind() != d.Kind() && types.ConvertibleTo(s, d) {
if isSrcPtr && isDstPtr {
var b bytes.Buffer
w := bufio.NewWriter(&b)
if err := templates.ExecuteTemplate(w, "mapper_ptr_to_ptr.tmpl", map[string]string{
"dst": to,
"dstField": dstField.Name,
"dstType": d.Name(),
"src": from,
"srcField": field.Name,
}); err != nil {
return err.Error()
}
if err := w.Flush(); err != nil {
return err.Error()
}
return b.String()
}

return fmt.Sprintf("%v.%v = %v(%v.%v)", to, dstField.Name, d.Name(), from, field.Name)
}
}
case *types.Named:
switch d := dstType.(type) {
case *types.Named:
if s.Obj().Name() == "Time" && d.Obj().Name() == "DeletedAt" {
return fmt.Sprintf("%v.%v = convertTimeToGormDeletedAt(%v.%v)", to, dstField.Name, from, field.Name)
} else if d.Obj().Name() == "Time" && s.Obj().Name() == "DeletedAt" {
return fmt.Sprintf("%v.%v = convertGormDeletedAtToTime(%v.%v)", to, dstField.Name, from, field.Name)
}

// TODO(joeriddles): add field to GormModelField for references to user-defined models?
if utils.IsComplexType(dstField.Type) && !strings.Contains(dstField.Type, ".") {
isSrcPtr := strings.Contains(field.Type, "*")
mapperName, isDstPtr := strings.CutPrefix(dstField.Type, "*")
if dst.IsApi {
mapperName = mapperName + "Api"
}

if isDstPtr {
if !isSrcPtr {
from = "&" + from
}
return fmt.Sprintf(`%v.%v = New%vMapper().MapPtr(%v.%v)`, to, dstField.Name, mapperName, from, field.Name)
} else {
return fmt.Sprintf(`%v.%v = New%vMapper().Map(%v.%v)`, to, dstField.Name, mapperName, from, field.Name)
}
}
}
case *types.Slice:
if _, ok := dstType.(*types.Slice); ok {
isDstPtr := strings.HasPrefix(dstField.Type, "*")
var isDstElemPtr bool
if isDstPtr {
isDstElemPtr = dstField.Type[3:4] == "*"
} else {
isDstElemPtr = dstField.Type[2:3] == "*"
}

isSrcPtr := strings.HasPrefix(field.Type, "*")
var isSrcElemPtr bool
if isSrcPtr {
isSrcElemPtr = field.Type[3:4] == "*"
} else {
isSrcElemPtr = field.Type[2:3] == "*"
}

mapperName := strings.ReplaceAll(strings.ReplaceAll(dstField.Type, "*", ""), "[]", "")
if dst.IsApi {
mapperName = mapperName + "Api"
}

if dst.IsApi {
if isSrcPtr && isSrcElemPtr {
return fmt.Sprintf(`if %v.%v != nil { %v.%v = New%vMapper().MapPtrSlicePtrs(%v.%v) }`, from, field.Name, to, dstField.Name, mapperName, from, field.Name)
} else if isSrcPtr {
return fmt.Sprintf(`if %v.%v != nil { %v.%v = New%vMapper().MapPtrSlice(%v.%v) }`, from, field.Name, to, dstField.Name, mapperName, from, field.Name)
} else if isSrcElemPtr {
return fmt.Sprintf(`if %v.%v != nil { %v.%v = New%vMapper().MapSlicePtrs(%v.%v) }`, from, field.Name, to, dstField.Name, mapperName, from, field.Name)
} else {
return fmt.Sprintf(`if %v.%v != nil { %v.%v = New%vMapper().MapSlice(%v.%v) }`, from, field.Name, to, dstField.Name, mapperName, from, field.Name)
}
} else {
if isDstPtr && isDstElemPtr {
return fmt.Sprintf(`if %v.%v != nil { %v.%v = New%vMapper().MapPtrSlicePtrs(%v.%v) }`, from, field.Name, to, dstField.Name, mapperName, from, field.Name)
} else if isDstPtr {
return fmt.Sprintf(`if %v.%v != nil { %v.%v = New%vMapper().MapPtrSlice(%v.%v) }`, from, field.Name, to, dstField.Name, mapperName, from, field.Name)
} else if isDstElemPtr {
return fmt.Sprintf(`if %v.%v != nil { %v.%v = New%vMapper().MapSlicePtrs(%v.%v) }`, from, field.Name, to, dstField.Name, mapperName, from, field.Name)
} else {
return fmt.Sprintf(`if %v.%v != nil { %v.%v = New%vMapper().MapSlice(%v.%v) }`, from, field.Name, to, dstField.Name, mapperName, from, field.Name)
}
}
}
}

return fmt.Sprintf("%v.%v = %v.%v", to, dstField.Name, from, field.Name)
}
12 changes: 12 additions & 0 deletions pkg/entity/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package entity

import (
"go/types"
"strings"

"github.com/joeriddles/goalesce/pkg/utils"
)
Expand Down Expand Up @@ -72,3 +73,14 @@ func (f *GormModelField) WithType(t types.Type, moduleName string) {
func (f *GormModelField) GetType() types.Type {
return f.t
}

// Get the type for use in a Go type declaration
//
// Example: github.com/shopspring/decimal.Decimal -> decimal.Decimal
func (f *GormModelField) GetGoType() string {
if strings.Contains(f.Type, "/") {
split := strings.Split(f.Type, "/")
return split[len(split)-1]
}
return f.Type
}
Loading

0 comments on commit ff8c371

Please sign in to comment.