Skip to content

Commit

Permalink
fix: allows to search for exposed nested resources and doesn't expose…
Browse files Browse the repository at this point in the history
… internal sql columns
  • Loading branch information
gdbranco committed Jun 13, 2024
1 parent 6b6f631 commit c90765f
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 69 deletions.
20 changes: 20 additions & 0 deletions pkg/dao/dinosaur.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,28 @@ import (

"github.com/openshift-online/rh-trex/pkg/api"
"github.com/openshift-online/rh-trex/pkg/db"
"github.com/openshift-online/rh-trex/pkg/util"
)

var (
dinosaurTableName = util.ToSnakeCase(api.DinosaurTypeName) + "s"
dinosaurColumns = []string{
"id",
"created_at",
"updated_at",
"species",
}
)

func DinosaurApiToModel() TableMappingRelation {
result := map[string]string{}
applyBaseMapping(result, dinosaurColumns, dinosaurTableName)
return TableMappingRelation{
Mapping: result,
relationTableName: dinosaurTableName,
}
}

type DinosaurDao interface {
Get(ctx context.Context, id string) (*api.Dinosaur, error)
Create(ctx context.Context, dinosaur *api.Dinosaur) (*api.Dinosaur, error)
Expand Down
36 changes: 36 additions & 0 deletions pkg/dao/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dao

import (
"context"
"fmt"
"strings"

"github.com/jinzhu/inflection"
Expand All @@ -10,6 +11,41 @@ import (
"github.com/openshift-online/rh-trex/pkg/db"
)

type TableMappingRelation struct {
Mapping map[string]string
relationTableName string
}

type relationMapping func() TableMappingRelation

func applyBaseMapping(result map[string]string, columns []string, tableName string) {
for _, c := range columns {
mappingKey := c
mappingValue := fmt.Sprintf("%s.%s", tableName, c)
columnParts := strings.Split(c, ".")
if len(columnParts) == 1 {
mappingKey = mappingValue
}
if len(columnParts) == 2 {
mappingValue = strings.Split(mappingKey, ".")[1]
}
result[mappingKey] = mappingValue
}
}

func applyRelationMapping(result map[string]string, relations []relationMapping) {
for _, relation := range relations {
tableMappingRelation := relation()
for k, v := range tableMappingRelation.Mapping {
if _, ok := result[k]; ok {
result[tableMappingRelation.relationTableName+"."+k] = v
} else {
result[k] = v
}
}
}
}

type Where struct {
sql string
values []any
Expand Down
63 changes: 63 additions & 0 deletions pkg/dao/generic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package dao

import (
"fmt"
"strings"

. "github.com/onsi/ginkgo/v2/dsl/core"
. "github.com/onsi/gomega"
)

var _ = Describe("applyBaseMapping", func() {
It("generates base mapping", func() {
result := map[string]string{}
applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "test_table")
for k, v := range result {
if strings.HasPrefix(k, "test_table") {
Expect(k).To(Equal(v))
continue
}
// nested fields from table
i := strings.Index(k, ".")
Expect(k[i+1:]).To(Equal(v))
}
})
})

var _ = Describe("applyRelationMapping", func() {
It("generates relation mapping", func() {
result := map[string]string{}
applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "base_table")
applyRelationMapping(result, []relationMapping{
func() TableMappingRelation {
result := map[string]string{}
applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "relation_table")
return TableMappingRelation{
relationTableName: "relation_table",
Mapping: result,
}
},
})
for k, v := range result {
if strings.HasPrefix(k, "base_table") {
Expect(k).To(Equal(v))
continue
}
if strings.HasPrefix(k, "relation_table") {
if c := strings.Count(k, "."); c > 1 {
i := strings.Index(k, ".")
i = strings.Index(k[i+1:], ".") + i
Expect(k[i+2:]).To(Equal(v))
continue
}
Expect(k).To(Equal(v))
continue
}

// nested fields from base table
i := strings.Index(k, ".")
Expect(k[i+1:]).To(Equal(v))
fmt.Println(k, v)
}
})
})
71 changes: 41 additions & 30 deletions pkg/db/sql_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package db
import (
"fmt"
"reflect"
"slices"
"strings"

"github.com/jinzhu/inflection"
Expand All @@ -11,6 +12,11 @@ import (
"gorm.io/gorm"
)

const (
invalidFieldNameMsg = "%s is not a valid field name"
disallowedFieldNameMsg = "%s is a disallowed field name"
)

// Check if a field name starts with properties.
func startsWithProperties(s string) bool {
return strings.HasPrefix(s, "properties.")
Expand All @@ -33,34 +39,33 @@ func hasProperty(n tsl.Node) bool {
}

// getField gets the sql field associated with a name.
func getField(name string, disallowedFields map[string]string) (field string, err *errors.ServiceError) {
func getField(
name string,
disallowedFields []string,
apiToModel map[string]string,
) (field string, err *errors.ServiceError) {
// We want to accept names with trailing and leading spaces
trimmedName := strings.Trim(name, " ")

// Check for properties ->> '<some field name>'
if strings.HasPrefix(trimmedName, "properties ->>") {
field = trimmedName
return
mappedField, ok := apiToModel[trimmedName]
if !ok {
return "", errors.BadRequest(invalidFieldNameMsg, name)
}

// Check for nested field, e.g., subscription_labels.key
checkName := trimmedName
fieldParts := strings.Split(trimmedName, ".")
checkName := mappedField
fieldParts := strings.Split(checkName, ".")
if len(fieldParts) > 2 {
err = errors.BadRequest("%s is not a valid field name", name)
err = errors.BadRequest(invalidFieldNameMsg, name)
return
}
if len(fieldParts) > 1 {
checkName = fieldParts[1]
}

// Check for allowed fields
_, ok := disallowedFields[checkName]
if ok {
err = errors.BadRequest("%s is not a valid field name", name)
if slices.Contains(disallowedFields, checkName) {
err = errors.BadRequest(disallowedFieldNameMsg, name)
return
}
field = trimmedName
field = checkName
return
}

Expand Down Expand Up @@ -102,7 +107,8 @@ func propertiesNodeConverter(n tsl.Node) tsl.Node {
// b. replace the field name with the SQL column name.
func FieldNameWalk(
n tsl.Node,
disallowedFields map[string]string) (newNode tsl.Node, err *errors.ServiceError) {
disallowedFields []string,
apiToModel map[string]string) (newNode tsl.Node, err *errors.ServiceError) {

var field string
var l, r tsl.Node
Expand All @@ -124,7 +130,7 @@ func FieldNameWalk(
}

// Check field name in the disallowedFields field names.
field, err = getField(userFieldName, disallowedFields)
field, err = getField(userFieldName, disallowedFields, apiToModel)
if err != nil {
return
}
Expand All @@ -137,7 +143,7 @@ func FieldNameWalk(
default:
// o/w continue walking the tree.
if n.Left != nil {
l, err = FieldNameWalk(n.Left.(tsl.Node), disallowedFields)
l, err = FieldNameWalk(n.Left.(tsl.Node), disallowedFields, apiToModel)
if err != nil {
return
}
Expand All @@ -148,7 +154,7 @@ func FieldNameWalk(
switch v := n.Right.(type) {
case tsl.Node:
// It's a regular node, just add it.
r, err = FieldNameWalk(v, disallowedFields)
r, err = FieldNameWalk(v, disallowedFields, apiToModel)
if err != nil {
return
}
Expand All @@ -162,7 +168,7 @@ func FieldNameWalk(

// Add all nodes in the right side array.
for _, e := range v {
r, err = FieldNameWalk(e, disallowedFields)
r, err = FieldNameWalk(e, disallowedFields, apiToModel)
if err != nil {
return
}
Expand All @@ -189,23 +195,26 @@ func FieldNameWalk(
}

// cleanOrderBy takes the orderBy arg and cleans it.
func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy string, err *errors.ServiceError) {
func cleanOrderBy(userArg string,
disallowedFields []string,
apiToModel map[string]string,
tableName string) (orderBy string, err *errors.ServiceError) {
var orderField string

// We want to accept user params with trailing and leading spaces
trimedName := strings.Trim(userArg, " ")

// Each OrderBy can be a "<field-name>" or a "<field-name> asc|desc"
order := strings.Split(trimedName, " ")
direction := "none valid"

if len(order) == 1 {
orderField, err = getField(order[0], disallowedFields)
direction = "asc"
} else if len(order) == 2 {
orderField, err = getField(order[0], disallowedFields)
direction := "asc"
if len(order) == 2 {
direction = order[1]
}
field := order[0]
if orderParts := strings.Split(order[0], "."); len(orderParts) == 1 {
field = fmt.Sprintf("%s.%s", tableName, field)
}
orderField, err = getField(field, disallowedFields, apiToModel)
if err != nil || (direction != "asc" && direction != "desc") {
err = errors.BadRequest("bad order value '%s'", userArg)
return
Expand All @@ -218,13 +227,15 @@ func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy s
// ArgsToOrderBy returns cleaned orderBy list.
func ArgsToOrderBy(
orderByArgs []string,
disallowedFields map[string]string) (orderBy []string, err *errors.ServiceError) {
disallowedFields []string,
apiToModel map[string]string,
tableName string) (orderBy []string, err *errors.ServiceError) {

var order string
if len(orderByArgs) != 0 {
orderBy = []string{}
for _, o := range orderByArgs {
order, err = cleanOrderBy(o, disallowedFields)
order, err = cleanOrderBy(o, disallowedFields, apiToModel, tableName)
if err != nil {
return
}
Expand Down
Loading

0 comments on commit c90765f

Please sign in to comment.