From 5a7d285f866e40555f47494c23ff02a53dce57f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=80=97=E5=AD=90?= Date: Sat, 2 Mar 2024 12:59:37 +0800 Subject: [PATCH] fix: content type (#63) --- context_request.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/context_request.go b/context_request.go index 912fc8b..fe038b9 100644 --- a/context_request.go +++ b/context_request.go @@ -25,18 +25,18 @@ import ( type ContextRequest struct { ctx *Context instance *fiber.Ctx - postData map[string]any + httpBody map[string]any log log.Log validation contractsvalidate.Validation } func NewContextRequest(ctx *Context, log log.Log, validation contractsvalidate.Validation) contractshttp.ContextRequest { - postData, err := getPostData(ctx) + httpBody, err := getHttpBody(ctx) if err != nil { LogFacade.Error(fmt.Sprintf("%+v", errors.Unwrap(err))) } - return &ContextRequest{ctx: ctx, instance: ctx.instance, postData: postData, log: log, validation: validation} + return &ContextRequest{ctx: ctx, instance: ctx.instance, httpBody: httpBody, log: log, validation: validation} } func (r *ContextRequest) AbortWithStatus(code int) { @@ -60,7 +60,7 @@ func (r *ContextRequest) All() map[string]any { for k, v := range r.instance.Queries() { data[k] = v } - for k, v := range r.postData { + for k, v := range r.httpBody { data[k] = v } @@ -255,7 +255,7 @@ func (r *ContextRequest) Path() string { func (r *ContextRequest) Input(key string, defaultValue ...string) string { keys := strings.Split(key, ".") - current := r.postData + current := r.httpBody for _, k := range keys { value, found := current[k] if found { @@ -281,7 +281,7 @@ func (r *ContextRequest) Input(key string, defaultValue ...string) string { func (r *ContextRequest) InputArray(key string, defaultValue ...[]string) []string { keys := strings.Split(key, ".") - current := r.postData + current := r.httpBody for _, k := range keys { value, found := current[k] if !found { @@ -303,7 +303,7 @@ func (r *ContextRequest) InputArray(key string, defaultValue ...[]string) []stri func (r *ContextRequest) InputMap(key string, defaultValue ...map[string]string) map[string]string { keys := strings.Split(key, ".") - current := r.postData + current := r.httpBody for _, k := range keys { value, found := current[k] if !found { @@ -433,7 +433,7 @@ func (r *ContextRequest) ValidateRequest(request contractshttp.FormRequest) (con return validator.Errors(), nil } -func getPostData(ctx *Context) (map[string]any, error) { +func getHttpBody(ctx *Context) (map[string]any, error) { if len(ctx.instance.Request().Body()) == 0 { return nil, nil } @@ -441,7 +441,7 @@ func getPostData(ctx *Context) (map[string]any, error) { contentType := ctx.instance.Get("Content-Type") data := make(map[string]any) - if contentType == "application/json" { + if strings.Contains(contentType, "application/json") { bodyBytes := ctx.instance.Body() if err := json.Unmarshal(bodyBytes, &data); err != nil { @@ -449,7 +449,7 @@ func getPostData(ctx *Context) (map[string]any, error) { } } - if contentType == "multipart/form-data" { + if strings.Contains(contentType, "multipart/form-data") { if form, err := ctx.instance.MultipartForm(); err == nil { for k, v := range form.Value { data[k] = strings.Join(v, ",") @@ -462,7 +462,7 @@ func getPostData(ctx *Context) (map[string]any, error) { } } - if contentType == "application/x-www-form-urlencoded" { + if strings.Contains(contentType, "application/x-www-form-urlencoded") { args := ctx.instance.Request().PostArgs() args.VisitAll(func(key, value []byte) { data[utils.UnsafeString(key)] = utils.UnsafeString(value)