Skip to content

Commit

Permalink
feat: add better hmac sign check; add mandatory fields check feature …
Browse files Browse the repository at this point in the history
…in hmac sign
  • Loading branch information
sebastocorp committed Dec 18, 2024
1 parent 7db04b2 commit a259482
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 20 deletions.
3 changes: 2 additions & 1 deletion api/v1alpha2/config_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ type HmacConfigT struct {
EncryptionAlgorithm string `yaml:"encryptionAlgorithm"`

//
Url HmacUrlConfigT `yaml:"url,omitempty"`
MandatoryFields []string `yaml:"mandatoryFields,omitempty"`
Url HmacUrlConfigT `yaml:"url,omitempty"`
}

type HmacUrlConfigT struct {
Expand Down
4 changes: 3 additions & 1 deletion docs/samples/doorkeeper.v1alpha2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ authorizations:
type: URL
encryptionKey: ${ENV:ENVIRONMENT_VARIABLE_WITH_ENCRYPTION_KEY}$
encryptionAlgorithm: "sha256"

mandatoryFields:
- hmac
- exp
url:
# (Optional) Transforms special characters (including /) with %XX sequences as needed
# When lowerEncode is true, encoded chars will be lowercase (e.g. %2f instead of %2F)
Expand Down
6 changes: 3 additions & 3 deletions internal/doorkeeper/doorkeeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,14 @@ func (d *DoorkeeperT) handleRequest(w http.ResponseWriter, r *http.Request) {

valid, err = checkAuthorization(r, d.auths[authn])
if err != nil {
valid = false
utils.SetLogField(logFields, utils.LogFieldKeyError, err.Error())
d.log.Error("unable to check authorization", logFields)
return
}
utils.SetLogField(logFields, utils.LogFieldKeyAuthorizationResult, strconv.FormatBool(valid))

d.log.Debug("success in check authorization", logFields)
d.log.Debug("check authorization result", logFields)
reqResults = append(reqResults, valid)
utils.SetLogField(logFields, utils.LogFieldKeyError, utils.LogFieldValueDefaultStr)
}
utils.SetLogField(logFields, utils.LogFieldKeyAuthorization, utils.LogFieldValueDefaultStr)
utils.SetLogField(logFields, utils.LogFieldKeyAuthorizationResult, utils.LogFieldValueDefaultStr)
Expand Down
4 changes: 2 additions & 2 deletions internal/doorkeeper/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ func checkAuthorization(r *http.Request, auth *v1alpha2.AuthorizationConfigT) (v

//
var generatedHmac, receivedHmac string
valid, generatedHmac, receivedHmac, err = hmac.ValidateTokenUrl(paramToCheck, auth.Hmac.EncryptionKey, auth.Hmac.EncryptionAlgorithm, path)
valid, generatedHmac, receivedHmac, err = hmac.ValidateTokenUrl(paramToCheck, auth.Hmac.EncryptionKey, auth.Hmac.EncryptionAlgorithm, path, auth.Hmac.MandatoryFields)
if err != nil {
err = fmt.Errorf("unable to validate token in request: %s", err.Error())
err = fmt.Errorf("unable to validate hmac sign in request: %s", err.Error())
return valid, err
}
_ = generatedHmac
Expand Down
36 changes: 29 additions & 7 deletions internal/hmac/hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,47 @@ func generateHMAC(tokenDigest, encryptionKey, encryptionAlgorithm string) (token

// ValidateToken TODO
// token: exp={int}~hmac={hash}
func ValidateTokenUrl(token, encryptionKey, encryptionAlgorithm, url string) (isValid bool, generatedHmac, receivedHmac string, err error) {
func ValidateTokenUrl(token, encryptionKey, encryptionAlgorithm, url string, mandatoryFields []string) (isValid bool, generatedHmac, receivedHmac string, err error) {
tokenFields := map[string]string{}
tokenParts := strings.Split(token, "~")
for _, fieldv := range tokenParts {
fieldParts := strings.SplitN(fieldv, "=", 2)
if len(fieldParts) != 2 {
continue
}
tokenFields[fieldParts[0]] = fieldParts[1]
}

for _, fv := range mandatoryFields {
if _, ok := tokenFields[fv]; !ok {
err = fmt.Errorf("mandatory field '%s' not found in hmac sign", fv)
return isValid, generatedHmac, receivedHmac, err
}
}

// split token to get tokenDigest and HMAC
tokenParts := strings.Split(token, "~hmac=")
if len(tokenParts) != 2 {
hmacTokenParts := strings.Split(token, "~hmac=")
if len(hmacTokenParts) != 2 {
err = fmt.Errorf("hmac sign without main 'hmac' field")
return isValid, generatedHmac, receivedHmac, err
}
tokenDigest := fmt.Sprintf("%s~url=%s", tokenParts[0], url)
tokenHMAC := []byte(tokenParts[1])
tokenDigest := fmt.Sprintf("%s~url=%s", hmacTokenParts[0], url)
tokenHMAC := []byte(hmacTokenParts[1])

// check expiration time
expPart := strings.TrimPrefix(strings.Split(tokenDigest, "~")[0], "exp=")
expPart, ok := tokenFields["exp"]
if !ok {
err = fmt.Errorf("hmac sign without main 'exp' field")
return isValid, generatedHmac, receivedHmac, err
}
exp, err := strconv.ParseInt(expPart, 10, 64)
if err != nil {
err = fmt.Errorf("invalid expiration time '%s'", expPart)
return isValid, generatedHmac, receivedHmac, err
}

if time.Now().Unix() >= exp {
err = fmt.Errorf("token has expired")
err = fmt.Errorf("hmac sign has expired")
return isValid, generatedHmac, receivedHmac, err
}

Expand Down
15 changes: 9 additions & 6 deletions internal/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ const (
)

type RequestT struct {
Method string `json:"method"`
Host string `json:"host"`
Path string `json:"path"`
Headers http.Header `json:"headers"`
Body string `json:"body"`
Method string `json:"method"`
Host string `json:"host"`
Path string `json:"path"`
QueryParams string `json:"queryParams"`
Headers http.Header `json:"headers"`
Body string `json:"body"`
}

type ResponseT struct {
Expand All @@ -46,7 +47,7 @@ func RequestID(r *http.Request) string {
}
headers += "}"

reqStr := fmt.Sprintf("{method: '%s', host: '%s', path: '%s', headers: '%s'}", r.Method, r.Host, r.URL.Path, headers)
reqStr := fmt.Sprintf("{method: '%s', host: '%s', path: '%s/%s', headers: '%s'}", r.Method, r.Host, r.URL.Path, r.URL.RawQuery, headers)
md5Hash := md5.New()
_, err := md5Hash.Write([]byte(reqStr))
if err != nil {
Expand All @@ -60,6 +61,7 @@ func RequestStruct(r *http.Request) (req RequestT) {
req.Method = r.Method
req.Host = r.Host
req.Path = r.URL.Path
req.QueryParams = r.URL.RawQuery
req.Headers = make(http.Header)
for hk, hvs := range r.Header {
for _, hv := range hvs {
Expand Down Expand Up @@ -100,6 +102,7 @@ func DefaultRequestStruct() (req RequestT) {
req.Method = LogFieldValueDefaultStr
req.Host = LogFieldValueDefaultStr
req.Path = LogFieldValueDefaultStr
req.QueryParams = LogFieldValueDefaultStr
req.Headers = make(http.Header)
req.Body = LogFieldValueDefaultStr
return req
Expand Down

0 comments on commit a259482

Please sign in to comment.