Skip to content

Commit

Permalink
Move body parsing into functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Ajnasz committed Nov 2, 2020
1 parent 170b65b commit 849273c
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 83 deletions.
19 changes: 18 additions & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,25 @@ REMOVE=0
OSLIST="linux darwin freebsd"
ARCHLIST="amd64 386"
STORAGELIST="postgres redis sqlite"
BUILD=0

while getopts "ra:o:s:" opt
subcommand="$1"
shift
case "$subcommand" in
"test")
go test -v -tags test
return
;;
"build")
BUILD=1;
;;
*)
echo "Invalid command, available commands are \"test\" and \"build\""
exit 1
;;
esac

while getopts "tra:o:s:" opt
do
case "$opt" in
"r")
Expand Down
89 changes: 7 additions & 82 deletions create_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"mime"
"net/http"
"time"
)
Expand All @@ -22,83 +20,18 @@ func createKey() ([]byte, string, error) {
return key, keyString, nil
}

func getExpiration(expire string, defaultExpire time.Duration) (time.Duration, error) {
if expire == "" {
return defaultExpire, nil
}
userExpire, err := time.ParseDuration(expire)
if err != nil {
return 0, err
}

maxExpire := time.Duration(maxExpireSeconds) * time.Second

if userExpire > maxExpire {
return 0, fmt.Errorf("Invalid expiration date")
}

return userExpire, nil
}

func getRequestBody(r *http.Request) ([]byte, error) {
var body []byte
var err error

ct := r.Header.Get("content-type")
if ct == "" {
ct = "application/octet-stream"
}
ct, _, err = mime.ParseMediaType(ct)
if err != nil {
return nil, err
}

switch {
case ct == "multipart/form-data":
err = r.ParseMultipartForm(1024 * 1024)
if err != nil {
return nil, err
}

secret := r.PostForm.Get("secret")
if secret != "" {
body = []byte(secret)
} else {
file, _, err := r.FormFile("secret")

if err != nil {
return nil, err
}

body, err = ioutil.ReadAll(file)

if err != nil {
return nil, err
}
}
default:
body, err = ioutil.ReadAll(r.Body)
}

return body, err
}

func getExpirationR(r *http.Request) (time.Duration, error) {
var expiration string
r.ParseForm()
expiration = r.Form.Get("expire")

return getExpiration(expiration, time.Second*time.Duration(expireSeconds))
}

func handleCreateEntry(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxDataSize)

body, err := getRequestBody(r)
data, err := parseCreateRequest(r)

if err != nil {
if err.Error() == "http: request body too large" {
http.Error(w, "Too large", http.StatusRequestEntityTooLarge)
} else if err.Error() == "Invalid expiration date" {
log.Println(err)
http.Error(w, "Invalid expiration", http.StatusBadRequest)
return
} else {
http.Error(w, "Internal error", http.StatusInternalServerError)
}
Expand All @@ -117,15 +50,7 @@ func handleCreateEntry(w http.ResponseWriter, r *http.Request) {

UUID := newUUIDString()

expiration, err := getExpirationR(r)

if err != nil {
log.Println(err)
http.Error(w, "Invalid expiration", http.StatusBadRequest)
return
}

err = secretStore.Create(UUID, body, expiration)
err = secretStore.Create(UUID, data.body, data.expiration)

if err != nil {
log.Println(err)
Expand All @@ -135,7 +60,7 @@ func handleCreateEntry(w http.ResponseWriter, r *http.Request) {

w.Header().Add("x-entry-uuid", UUID)
w.Header().Add("x-entry-key", keyString)
w.Header().Add("x-entry-expire", time.Now().Add(expiration).Format(time.RFC3339))
w.Header().Add("x-entry-expire", time.Now().Add(data.expiration).Format(time.RFC3339))
if r.Header.Get("Accept") == "application/json" {
w.Header().Set("Content-Type", "application/json")
entry, err := secretStore.GetMeta(UUID)
Expand Down
105 changes: 105 additions & 0 deletions parse_create_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package main

import (
"fmt"
"io/ioutil"
"mime"
"net/http"
"time"
)

type requestData struct {
body []byte
expiration time.Duration
}

func calculateExpiration(expire string, defaultExpire time.Duration) (time.Duration, error) {
if expire == "" {
return defaultExpire, nil
}

userExpire, err := time.ParseDuration(expire)
if err != nil {
return 0, err
}

maxExpire := time.Duration(maxExpireSeconds) * time.Second

if userExpire > maxExpire {
return 0, fmt.Errorf("Invalid expiration date")
}

return userExpire, nil
}

func parseMultiForm(r *http.Request) ([]byte, error) {
err := r.ParseMultipartForm(1024 * 1024)
if err != nil {
return nil, err
}

secret := r.PostForm.Get("secret")
if secret != "" {
body := []byte(secret)
return body, nil
}

file, _, err := r.FormFile("secret")

if err != nil {
return nil, err
}

return ioutil.ReadAll(file)
}

func getBody(r *http.Request) ([]byte, error) {
ct := r.Header.Get("content-type")
if ct == "" {
ct = "application/octet-stream"
}
ct, _, err := mime.ParseMediaType(ct)

if err != nil {
return nil, err
}

switch {
case ct == "multipart/form-data":
return parseMultiForm(r)
default:
return ioutil.ReadAll(r.Body)
}
}

func (b requestData) ContentType() string {
return "plain/text"
}

func getSecretExpiration(r *http.Request) (time.Duration, error) {
var expiration string
r.ParseForm()
expiration = r.Form.Get("expire")

return calculateExpiration(expiration, time.Second*time.Duration(expireSeconds))
}

func parseCreateRequest(r *http.Request) (*requestData, error) {
body, err := getBody(r)

if err != nil {
return nil, err
}

expiration, err := getSecretExpiration(r)

if err != nil {
return nil, err
}

return &requestData{
body: body,
expiration: expiration,
}, nil

}

0 comments on commit 849273c

Please sign in to comment.