From 849273c78e0849fc97c19f45ef055180a9a06fbd Mon Sep 17 00:00:00 2001 From: Lajos Koszti Date: Mon, 2 Nov 2020 08:42:13 +0100 Subject: [PATCH] Move body parsing into functions --- build.sh | 19 +++++++- create_handler.go | 89 +++------------------------------- parse_create_request.go | 105 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 130 insertions(+), 83 deletions(-) create mode 100644 parse_create_request.go diff --git a/build.sh b/build.sh index 76bab3a..62cc9fd 100755 --- a/build.sh +++ b/build.sh @@ -42,8 +42,25 @@ REMOVE=0 OSLIST="linux darwin freebsd" ARCHLIST="amd64 386" STORAGELIST="postgres redis sqlite" +BUILD=0 -while getopts "ra:o:s:" opt +subcommand="$1" +shift +case "$subcommand" in + "test") + go test -v -tags test + return + ;; + "build") + BUILD=1; + ;; + *) + echo "Invalid command, available commands are \"test\" and \"build\"" + exit 1 + ;; +esac + +while getopts "tra:o:s:" opt do case "$opt" in "r") diff --git a/create_handler.go b/create_handler.go index 95df176..ea35a98 100644 --- a/create_handler.go +++ b/create_handler.go @@ -4,9 +4,7 @@ import ( "encoding/hex" "encoding/json" "fmt" - "io/ioutil" "log" - "mime" "net/http" "time" ) @@ -22,83 +20,18 @@ func createKey() ([]byte, string, error) { return key, keyString, nil } -func getExpiration(expire string, defaultExpire time.Duration) (time.Duration, error) { - if expire == "" { - return defaultExpire, nil - } - userExpire, err := time.ParseDuration(expire) - if err != nil { - return 0, err - } - - maxExpire := time.Duration(maxExpireSeconds) * time.Second - - if userExpire > maxExpire { - return 0, fmt.Errorf("Invalid expiration date") - } - - return userExpire, nil -} - -func getRequestBody(r *http.Request) ([]byte, error) { - var body []byte - var err error - - ct := r.Header.Get("content-type") - if ct == "" { - ct = "application/octet-stream" - } - ct, _, err = mime.ParseMediaType(ct) - if err != nil { - return nil, err - } - - switch { - case ct == "multipart/form-data": - err = r.ParseMultipartForm(1024 * 1024) - if err != nil { - return nil, err - } - - secret := r.PostForm.Get("secret") - if secret != "" { - body = []byte(secret) - } else { - file, _, err := r.FormFile("secret") - - if err != nil { - return nil, err - } - - body, err = ioutil.ReadAll(file) - - if err != nil { - return nil, err - } - } - default: - body, err = ioutil.ReadAll(r.Body) - } - - return body, err -} - -func getExpirationR(r *http.Request) (time.Duration, error) { - var expiration string - r.ParseForm() - expiration = r.Form.Get("expire") - - return getExpiration(expiration, time.Second*time.Duration(expireSeconds)) -} - func handleCreateEntry(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxDataSize) - body, err := getRequestBody(r) + data, err := parseCreateRequest(r) if err != nil { if err.Error() == "http: request body too large" { http.Error(w, "Too large", http.StatusRequestEntityTooLarge) + } else if err.Error() == "Invalid expiration date" { + log.Println(err) + http.Error(w, "Invalid expiration", http.StatusBadRequest) + return } else { http.Error(w, "Internal error", http.StatusInternalServerError) } @@ -117,15 +50,7 @@ func handleCreateEntry(w http.ResponseWriter, r *http.Request) { UUID := newUUIDString() - expiration, err := getExpirationR(r) - - if err != nil { - log.Println(err) - http.Error(w, "Invalid expiration", http.StatusBadRequest) - return - } - - err = secretStore.Create(UUID, body, expiration) + err = secretStore.Create(UUID, data.body, data.expiration) if err != nil { log.Println(err) @@ -135,7 +60,7 @@ func handleCreateEntry(w http.ResponseWriter, r *http.Request) { w.Header().Add("x-entry-uuid", UUID) w.Header().Add("x-entry-key", keyString) - w.Header().Add("x-entry-expire", time.Now().Add(expiration).Format(time.RFC3339)) + w.Header().Add("x-entry-expire", time.Now().Add(data.expiration).Format(time.RFC3339)) if r.Header.Get("Accept") == "application/json" { w.Header().Set("Content-Type", "application/json") entry, err := secretStore.GetMeta(UUID) diff --git a/parse_create_request.go b/parse_create_request.go new file mode 100644 index 0000000..4cff10b --- /dev/null +++ b/parse_create_request.go @@ -0,0 +1,105 @@ +package main + +import ( + "fmt" + "io/ioutil" + "mime" + "net/http" + "time" +) + +type requestData struct { + body []byte + expiration time.Duration +} + +func calculateExpiration(expire string, defaultExpire time.Duration) (time.Duration, error) { + if expire == "" { + return defaultExpire, nil + } + + userExpire, err := time.ParseDuration(expire) + if err != nil { + return 0, err + } + + maxExpire := time.Duration(maxExpireSeconds) * time.Second + + if userExpire > maxExpire { + return 0, fmt.Errorf("Invalid expiration date") + } + + return userExpire, nil +} + +func parseMultiForm(r *http.Request) ([]byte, error) { + err := r.ParseMultipartForm(1024 * 1024) + if err != nil { + return nil, err + } + + secret := r.PostForm.Get("secret") + if secret != "" { + body := []byte(secret) + return body, nil + } + + file, _, err := r.FormFile("secret") + + if err != nil { + return nil, err + } + + return ioutil.ReadAll(file) +} + +func getBody(r *http.Request) ([]byte, error) { + ct := r.Header.Get("content-type") + if ct == "" { + ct = "application/octet-stream" + } + ct, _, err := mime.ParseMediaType(ct) + + if err != nil { + return nil, err + } + + switch { + case ct == "multipart/form-data": + return parseMultiForm(r) + default: + return ioutil.ReadAll(r.Body) + } +} + +func (b requestData) ContentType() string { + return "plain/text" +} + +func getSecretExpiration(r *http.Request) (time.Duration, error) { + var expiration string + r.ParseForm() + expiration = r.Form.Get("expire") + + return calculateExpiration(expiration, time.Second*time.Duration(expireSeconds)) +} + +func parseCreateRequest(r *http.Request) (*requestData, error) { + body, err := getBody(r) + + if err != nil { + return nil, err + } + + expiration, err := getSecretExpiration(r) + + if err != nil { + return nil, err + } + + return &requestData{ + body: body, + expiration: expiration, + }, nil + +}