From c90765fa6589369e033faa01352ae0dffd166df4 Mon Sep 17 00:00:00 2001 From: Guilherme Branco Date: Thu, 13 Jun 2024 13:47:35 -0300 Subject: [PATCH] fix: allows to search for exposed nested resources and doesn't expose internal sql columns --- pkg/dao/dinosaur.go | 20 ++++++++ pkg/dao/generic.go | 36 ++++++++++++++ pkg/dao/generic_test.go | 63 ++++++++++++++++++++++++ pkg/db/sql_helpers.go | 71 +++++++++++++++------------ pkg/services/generic.go | 93 +++++++++++++++++++++++++++--------- pkg/services/generic_test.go | 41 +++++++++------- 6 files changed, 255 insertions(+), 69 deletions(-) create mode 100644 pkg/dao/generic_test.go diff --git a/pkg/dao/dinosaur.go b/pkg/dao/dinosaur.go index e100c9dc..61e2741c 100644 --- a/pkg/dao/dinosaur.go +++ b/pkg/dao/dinosaur.go @@ -7,8 +7,28 @@ import ( "github.com/openshift-online/rh-trex/pkg/api" "github.com/openshift-online/rh-trex/pkg/db" + "github.com/openshift-online/rh-trex/pkg/util" ) +var ( + dinosaurTableName = util.ToSnakeCase(api.DinosaurTypeName) + "s" + dinosaurColumns = []string{ + "id", + "created_at", + "updated_at", + "species", + } +) + +func DinosaurApiToModel() TableMappingRelation { + result := map[string]string{} + applyBaseMapping(result, dinosaurColumns, dinosaurTableName) + return TableMappingRelation{ + Mapping: result, + relationTableName: dinosaurTableName, + } +} + type DinosaurDao interface { Get(ctx context.Context, id string) (*api.Dinosaur, error) Create(ctx context.Context, dinosaur *api.Dinosaur) (*api.Dinosaur, error) diff --git a/pkg/dao/generic.go b/pkg/dao/generic.go index c7680e50..3ef71835 100644 --- a/pkg/dao/generic.go +++ b/pkg/dao/generic.go @@ -2,6 +2,7 @@ package dao import ( "context" + "fmt" "strings" "github.com/jinzhu/inflection" @@ -10,6 +11,41 @@ import ( "github.com/openshift-online/rh-trex/pkg/db" ) +type TableMappingRelation struct { + Mapping map[string]string + relationTableName string +} + +type relationMapping func() TableMappingRelation + +func applyBaseMapping(result map[string]string, columns []string, tableName string) { + for _, c := range columns { + mappingKey := c + mappingValue := fmt.Sprintf("%s.%s", tableName, c) + columnParts := strings.Split(c, ".") + if len(columnParts) == 1 { + mappingKey = mappingValue + } + if len(columnParts) == 2 { + mappingValue = strings.Split(mappingKey, ".")[1] + } + result[mappingKey] = mappingValue + } +} + +func applyRelationMapping(result map[string]string, relations []relationMapping) { + for _, relation := range relations { + tableMappingRelation := relation() + for k, v := range tableMappingRelation.Mapping { + if _, ok := result[k]; ok { + result[tableMappingRelation.relationTableName+"."+k] = v + } else { + result[k] = v + } + } + } +} + type Where struct { sql string values []any diff --git a/pkg/dao/generic_test.go b/pkg/dao/generic_test.go new file mode 100644 index 00000000..59346b75 --- /dev/null +++ b/pkg/dao/generic_test.go @@ -0,0 +1,63 @@ +package dao + +import ( + "fmt" + "strings" + + . "github.com/onsi/ginkgo/v2/dsl/core" + . "github.com/onsi/gomega" +) + +var _ = Describe("applyBaseMapping", func() { + It("generates base mapping", func() { + result := map[string]string{} + applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "test_table") + for k, v := range result { + if strings.HasPrefix(k, "test_table") { + Expect(k).To(Equal(v)) + continue + } + // nested fields from table + i := strings.Index(k, ".") + Expect(k[i+1:]).To(Equal(v)) + } + }) +}) + +var _ = Describe("applyRelationMapping", func() { + It("generates relation mapping", func() { + result := map[string]string{} + applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "base_table") + applyRelationMapping(result, []relationMapping{ + func() TableMappingRelation { + result := map[string]string{} + applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "relation_table") + return TableMappingRelation{ + relationTableName: "relation_table", + Mapping: result, + } + }, + }) + for k, v := range result { + if strings.HasPrefix(k, "base_table") { + Expect(k).To(Equal(v)) + continue + } + if strings.HasPrefix(k, "relation_table") { + if c := strings.Count(k, "."); c > 1 { + i := strings.Index(k, ".") + i = strings.Index(k[i+1:], ".") + i + Expect(k[i+2:]).To(Equal(v)) + continue + } + Expect(k).To(Equal(v)) + continue + } + + // nested fields from base table + i := strings.Index(k, ".") + Expect(k[i+1:]).To(Equal(v)) + fmt.Println(k, v) + } + }) +}) diff --git a/pkg/db/sql_helpers.go b/pkg/db/sql_helpers.go index cc867489..ad0497cc 100644 --- a/pkg/db/sql_helpers.go +++ b/pkg/db/sql_helpers.go @@ -3,6 +3,7 @@ package db import ( "fmt" "reflect" + "slices" "strings" "github.com/jinzhu/inflection" @@ -11,6 +12,11 @@ import ( "gorm.io/gorm" ) +const ( + invalidFieldNameMsg = "%s is not a valid field name" + disallowedFieldNameMsg = "%s is a disallowed field name" +) + // Check if a field name starts with properties. func startsWithProperties(s string) bool { return strings.HasPrefix(s, "properties.") @@ -33,34 +39,33 @@ func hasProperty(n tsl.Node) bool { } // getField gets the sql field associated with a name. -func getField(name string, disallowedFields map[string]string) (field string, err *errors.ServiceError) { +func getField( + name string, + disallowedFields []string, + apiToModel map[string]string, +) (field string, err *errors.ServiceError) { // We want to accept names with trailing and leading spaces trimmedName := strings.Trim(name, " ") - // Check for properties ->> '' - if strings.HasPrefix(trimmedName, "properties ->>") { - field = trimmedName - return + mappedField, ok := apiToModel[trimmedName] + if !ok { + return "", errors.BadRequest(invalidFieldNameMsg, name) } // Check for nested field, e.g., subscription_labels.key - checkName := trimmedName - fieldParts := strings.Split(trimmedName, ".") + checkName := mappedField + fieldParts := strings.Split(checkName, ".") if len(fieldParts) > 2 { - err = errors.BadRequest("%s is not a valid field name", name) + err = errors.BadRequest(invalidFieldNameMsg, name) return } - if len(fieldParts) > 1 { - checkName = fieldParts[1] - } // Check for allowed fields - _, ok := disallowedFields[checkName] - if ok { - err = errors.BadRequest("%s is not a valid field name", name) + if slices.Contains(disallowedFields, checkName) { + err = errors.BadRequest(disallowedFieldNameMsg, name) return } - field = trimmedName + field = checkName return } @@ -102,7 +107,8 @@ func propertiesNodeConverter(n tsl.Node) tsl.Node { // b. replace the field name with the SQL column name. func FieldNameWalk( n tsl.Node, - disallowedFields map[string]string) (newNode tsl.Node, err *errors.ServiceError) { + disallowedFields []string, + apiToModel map[string]string) (newNode tsl.Node, err *errors.ServiceError) { var field string var l, r tsl.Node @@ -124,7 +130,7 @@ func FieldNameWalk( } // Check field name in the disallowedFields field names. - field, err = getField(userFieldName, disallowedFields) + field, err = getField(userFieldName, disallowedFields, apiToModel) if err != nil { return } @@ -137,7 +143,7 @@ func FieldNameWalk( default: // o/w continue walking the tree. if n.Left != nil { - l, err = FieldNameWalk(n.Left.(tsl.Node), disallowedFields) + l, err = FieldNameWalk(n.Left.(tsl.Node), disallowedFields, apiToModel) if err != nil { return } @@ -148,7 +154,7 @@ func FieldNameWalk( switch v := n.Right.(type) { case tsl.Node: // It's a regular node, just add it. - r, err = FieldNameWalk(v, disallowedFields) + r, err = FieldNameWalk(v, disallowedFields, apiToModel) if err != nil { return } @@ -162,7 +168,7 @@ func FieldNameWalk( // Add all nodes in the right side array. for _, e := range v { - r, err = FieldNameWalk(e, disallowedFields) + r, err = FieldNameWalk(e, disallowedFields, apiToModel) if err != nil { return } @@ -189,7 +195,10 @@ func FieldNameWalk( } // cleanOrderBy takes the orderBy arg and cleans it. -func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy string, err *errors.ServiceError) { +func cleanOrderBy(userArg string, + disallowedFields []string, + apiToModel map[string]string, + tableName string) (orderBy string, err *errors.ServiceError) { var orderField string // We want to accept user params with trailing and leading spaces @@ -197,15 +206,15 @@ func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy s // Each OrderBy can be a "" or a " asc|desc" order := strings.Split(trimedName, " ") - direction := "none valid" - - if len(order) == 1 { - orderField, err = getField(order[0], disallowedFields) - direction = "asc" - } else if len(order) == 2 { - orderField, err = getField(order[0], disallowedFields) + direction := "asc" + if len(order) == 2 { direction = order[1] } + field := order[0] + if orderParts := strings.Split(order[0], "."); len(orderParts) == 1 { + field = fmt.Sprintf("%s.%s", tableName, field) + } + orderField, err = getField(field, disallowedFields, apiToModel) if err != nil || (direction != "asc" && direction != "desc") { err = errors.BadRequest("bad order value '%s'", userArg) return @@ -218,13 +227,15 @@ func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy s // ArgsToOrderBy returns cleaned orderBy list. func ArgsToOrderBy( orderByArgs []string, - disallowedFields map[string]string) (orderBy []string, err *errors.ServiceError) { + disallowedFields []string, + apiToModel map[string]string, + tableName string) (orderBy []string, err *errors.ServiceError) { var order string if len(orderByArgs) != 0 { orderBy = []string{} for _, o := range orderByArgs { - order, err = cleanOrderBy(o, disallowedFields) + order, err = cleanOrderBy(o, disallowedFields, apiToModel, tableName) if err != nil { return } diff --git a/pkg/services/generic.go b/pkg/services/generic.go index d4fd841e..1076bbdf 100644 --- a/pkg/services/generic.go +++ b/pkg/services/generic.go @@ -39,11 +39,18 @@ type sqlGenericService struct { } var ( - SearchDisallowedFields = map[string]map[string]string{} - allFieldsAllowed = map[string]string{} + searchDisallowedFields = map[string][]string{} + allFieldsAllowed = []string{} // Some mappings are not required as they match AMS resource 1:1 // Such as Organization modelToAmsResource = map[string]string{} + + // TODO: This should be more dynamic + // prefarably utilizing the openapi json via reflect + // and the column names from the model + openapiToModelFields = map[string]dao.TableMappingRelation{ + api.DinosaurTypeName: dao.DinosaurApiToModel(), + } ) // wrap all needed pieces for the LIST funciton @@ -54,14 +61,19 @@ type listContext struct { pagingMeta *api.PagingMeta ulog *logger.OCMLogger resourceList interface{} - disallowedFields *map[string]string + disallowedFields []string + openapiToModel map[string]string resourceType string joins map[string]dao.TableRelation groupBy []string set map[string]bool } -func newListContext(ctx context.Context, args *ListArguments, resourceList interface{}) (*listContext, interface{}, *errors.ServiceError) { +func newListContext( + ctx context.Context, + args *ListArguments, + resourceList interface{}, +) (*listContext, interface{}, *errors.ServiceError) { username := auth.GetUsernameFromContext(ctx) log := logger.NewOCMLogger(ctx) resourceModel := reflect.TypeOf(resourceList).Elem().Elem() @@ -69,10 +81,11 @@ func newListContext(ctx context.Context, args *ListArguments, resourceList inter if resourceTypeStr == "" { return nil, nil, errors.GeneralError("Could not determine resource type") } - disallowedFields := SearchDisallowedFields[resourceTypeStr] + disallowedFields := searchDisallowedFields[resourceTypeStr] if disallowedFields == nil { disallowedFields = allFieldsAllowed } + openapiToModel := openapiToModelFields[resourceTypeStr] args.Search = strings.Trim(args.Search, " ") return &listContext{ ctx: ctx, @@ -81,7 +94,8 @@ func newListContext(ctx context.Context, args *ListArguments, resourceList inter pagingMeta: &api.PagingMeta{Page: args.Page}, ulog: &log, resourceList: resourceList, - disallowedFields: &disallowedFields, + disallowedFields: disallowedFields, + openapiToModel: openapiToModel.Mapping, resourceType: resourceTypeStr, }, reflect.New(resourceModel).Interface(), nil } @@ -103,9 +117,19 @@ func (s *sqlGenericService) populateSearchRestriction(listCtx *listContext, mode resourceName = string(name) } if resourceIncludesOrgId(model) { - resourceReview, err := s.ocmClient.Authorization.ResourceReview(ctx, listCtx.username, auth.GetAction, resourceName) + resourceReview, err := s.ocmClient.Authorization.ResourceReview( + ctx, + listCtx.username, + auth.GetAction, + resourceName, + ) if err != nil { - return errors.GeneralError("Failed to verify resource review for user '%s' on resource '%s': %v", listCtx.username, listCtx.resourceType, err) + return errors.GeneralError( + "Failed to verify resource review for user '%s' on resource '%s': %v", + listCtx.username, + listCtx.resourceType, + err, + ) } // TODO setup a search builder @@ -125,7 +149,11 @@ func (s *sqlGenericService) populateSearchRestriction(listCtx *listContext, mode } // resourceList must be a pointer to a slice of database resource objects -func (s *sqlGenericService) List(ctx context.Context, args *ListArguments, resourceList interface{}) (*api.PagingMeta, *errors.ServiceError) { +func (s *sqlGenericService) List( + ctx context.Context, + args *ListArguments, + resourceList interface{}, +) (*api.PagingMeta, *errors.ServiceError) { listCtx, model, err := newListContext(ctx, args, resourceList) if err != nil { return nil, err @@ -186,7 +214,8 @@ func (s *sqlGenericService) buildPreload(listCtx *listContext, d *dao.GenericDao func (s *sqlGenericService) buildOrderBy(listCtx *listContext, d *dao.GenericDao) (bool, *errors.ServiceError) { if len(listCtx.args.OrderBy) != 0 { - orderByArgs, serviceErr := db.ArgsToOrderBy(listCtx.args.OrderBy, *listCtx.disallowedFields) + orderByArgs, serviceErr := db.ArgsToOrderBy(listCtx.args.OrderBy, listCtx.disallowedFields, + listCtx.openapiToModel, (*d).GetTableName()) if serviceErr != nil { return false, serviceErr } @@ -197,7 +226,10 @@ func (s *sqlGenericService) buildOrderBy(listCtx *listContext, d *dao.GenericDao return false, nil } -func (s *sqlGenericService) buildSearchValues(listCtx *listContext, d *dao.GenericDao) (string, []any, *errors.ServiceError) { +func (s *sqlGenericService) buildSearchValues( + listCtx *listContext, + d *dao.GenericDao, +) (string, []any, *errors.ServiceError) { if listCtx.args.Search == "" { s.addJoins(listCtx, d) return "", nil, nil @@ -323,7 +355,11 @@ func zeroSlice(i interface{}, cap int64) *errors.ServiceError { // walk the TSL tree looking for fields like, e.g., creator.username, and then: // (1) look up the related table by its 1st part - creator // (2) replace it by table name - creator.username -> accounts.username -func (s *sqlGenericService) treeWalkForRelatedTables(listCtx *listContext, tslTree tsl.Node, genericDao *dao.GenericDao) (tsl.Node, *errors.ServiceError) { +func (s *sqlGenericService) treeWalkForRelatedTables( + listCtx *listContext, + tslTree tsl.Node, + genericDao *dao.GenericDao, +) (tsl.Node, *errors.ServiceError) { resourceTable := (*genericDao).GetTableName() if listCtx.joins == nil { listCtx.joins = map[string]dao.TableRelation{} @@ -331,17 +367,21 @@ func (s *sqlGenericService) treeWalkForRelatedTables(listCtx *listContext, tslTr walkFn := func(field string) (string, error) { fieldParts := strings.Split(field, ".") if len(fieldParts) > 1 && fieldParts[0] != resourceTable { - fieldName := fieldParts[0] - _, exists := listCtx.joins[fieldName] + nestedResource := fieldParts[0] + _, exists := listCtx.joins[nestedResource] if !exists { - if relation, ok := (*genericDao).GetTableRelation(fieldName); ok { - listCtx.joins[fieldName] = relation - } else { - return field, fmt.Errorf("%s is not a related resource of %s", fieldName, listCtx.resourceType) + // Populates relation if join exists + if relation, ok := (*genericDao).GetTableRelation(nestedResource); ok { + listCtx.joins[nestedResource] = relation + } else if _, ok := listCtx.openapiToModel[field]; !ok { + // If also not exposed as a nested resource consider this is an error + return field, fmt.Errorf("%s is not a related resource of %s", strings.Join(fieldParts, "."), listCtx.resourceType) } } - //replace by table name - fieldParts[0] = listCtx.joins[fieldName].ForeignTableName + // replace by table name if coming from join + if value, ok := listCtx.joins[nestedResource]; ok { + fieldParts[0] = value.ForeignTableName + } return strings.Join(fieldParts, "."), nil } return field, nil @@ -356,7 +396,11 @@ func (s *sqlGenericService) treeWalkForRelatedTables(listCtx *listContext, tslTr } // prepend table name to these "free" identifiers since they could cause "ambiguous" errors -func (s *sqlGenericService) treeWalkForAddingTableName(listCtx *listContext, tslTree tsl.Node, dao *dao.GenericDao) (tsl.Node, *errors.ServiceError) { +func (s *sqlGenericService) treeWalkForAddingTableName( + listCtx *listContext, + tslTree tsl.Node, + dao *dao.GenericDao, +) (tsl.Node, *errors.ServiceError) { resourceTable := (*dao).GetTableName() walkFn := func(field string) (string, error) { @@ -378,9 +422,12 @@ func (s *sqlGenericService) treeWalkForAddingTableName(listCtx *listContext, tsl return tslTree, nil } -func (s *sqlGenericService) treeWalkForSqlizer(listCtx *listContext, tslTree tsl.Node) (tsl.Node, squirrel.Sqlizer, *errors.ServiceError) { +func (s *sqlGenericService) treeWalkForSqlizer( + listCtx *listContext, + tslTree tsl.Node, +) (tsl.Node, squirrel.Sqlizer, *errors.ServiceError) { // Check field names in tree - tslTree, serviceErr := db.FieldNameWalk(tslTree, *listCtx.disallowedFields) + tslTree, serviceErr := db.FieldNameWalk(tslTree, listCtx.disallowedFields, listCtx.openapiToModel) if serviceErr != nil { return tslTree, nil, serviceErr } diff --git a/pkg/services/generic_test.go b/pkg/services/generic_test.go index ddfd06aa..403d424e 100644 --- a/pkg/services/generic_test.go +++ b/pkg/services/generic_test.go @@ -147,23 +147,32 @@ var _ = Describe("Sql Translation", func() { genericDao = dao.NewGenericDao(&dbFactory) genericService = sqlGenericService{genericDao: genericDao} }) - DescribeTable("Errors", func( - search string, errorMsg string) { - listCtx, model, serviceErr := newListContext( - context.Background(), - &ListArguments{Search: search}, - &[]api.Dinosaur{}, - ) - Expect(serviceErr).ToNot(HaveOccurred()) - d := genericDao.GetInstanceDao(context.Background(), model) - (*listCtx.disallowedFields)["id"] = "id" - _, serviceErr = genericService.buildSearch(listCtx, &d) - Expect(serviceErr).To(HaveOccurred()) - Expect(serviceErr.Code).To(Equal(errors.ErrorBadRequest)) - Expect(serviceErr.Error()).To(Equal(errorMsg)) - }, + DescribeTable( + "Errors", + func( + search string, errorMsg string) { + listCtx, model, serviceErr := newListContext( + context.Background(), + &ListArguments{Search: search}, + &[]api.Dinosaur{}, + ) + Expect(serviceErr).ToNot(HaveOccurred()) + d := genericDao.GetInstanceDao(context.Background(), model) + listCtx.disallowedFields = []string{"dinosaurs.id"} + _, serviceErr = genericService.buildSearch(listCtx, &d) + Expect(serviceErr).To(HaveOccurred()) + Expect(serviceErr.Code).To(Equal(errors.ErrorBadRequest)) + Expect(serviceErr.Error()).To(Equal(errorMsg)) + }, Entry("Garbage", "garbage", "rh-trex-21: Failed to parse search query: garbage"), - Entry("Invalid field name", "id in ('123')", "rh-trex-21: dinosaurs.id is not a valid field name")) + Entry("Disallowed field name", "id in ('123')", "rh-trex-21: dinosaurs.id is a disallowed field name"), + Entry("Unknown field name", "bike = '123'", "rh-trex-21: dinosaurs.bike is not a valid field name"), + Entry( + "Unknown relation field", + "status.bike = '123'", + "rh-trex-21: status.bike is not a related resource of Dinosaur", + ), + ) DescribeTable("Sql Parsing", func( search string, sqlReal string, valuesReal types.GomegaMatcher) {