Skip to content

Commit

Permalink
auth: use custom status code if token is expired
Browse files Browse the repository at this point in the history
  • Loading branch information
thisisommore committed Jul 5, 2022
1 parent 9827002 commit d413675
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 32 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,9 @@ Note - Some unset data is emitted.
| :--------- | :---------------------------------------- |
| `feedback` | **Required**. `string` |
| `rating` | **Required**. `int` `ranging from 0 to 5` |

## Custom status codes

| Code | Meaning |
| :----- | :-------------- |
| `4011` | `Token expired` |
23 changes: 19 additions & 4 deletions api/middleware/auth/paseto/paseto.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package paseto

import (
"errors"
"fmt"
"net/http"

customstatuscodes "github.com/NetSepio/gateway/constants/http/custom_status_codes"
"github.com/NetSepio/gateway/models/claims"
"github.com/vk-rv/pvx"

Expand All @@ -23,12 +25,13 @@ func PASETO(c *gin.Context) {
var headers GenericAuthHeaders
err := c.BindHeader(&headers)
if err != nil {
err = fmt.Errorf("failed to bind header, %s", err)
logValidationFailed(headers.Authorization, err)
c.AbortWithStatus(http.StatusInternalServerError)
return
}
if headers.Authorization == "" {
logValidationFailed(headers.Authorization, err)
logValidationFailed(headers.Authorization, ErrAuthHeaderMissing)
httphelper.ErrResponse(c, http.StatusBadRequest, ErrAuthHeaderMissing.Error())
c.Abort()
return
Expand All @@ -42,16 +45,28 @@ func PASETO(c *gin.Context) {
Decrypt(pasetoToken, symK).
ScanClaims(&cc)
if err != nil {
var validationErr *pvx.ValidationError
if errors.As(err, &validationErr) {
if validationErr.HasExpiredErr() {
err = fmt.Errorf("failed to scan claims for paseto token, %s", err)
logValidationFailed(headers.Authorization, err)
httphelper.CErrResponse(c, http.StatusUnauthorized, customstatuscodes.TokenExpired, "token expired")
c.Abort()
return
}

}
err = fmt.Errorf("failed to scan claims for paseto token, %s", err)
logValidationFailed(headers.Authorization, err)
c.AbortWithStatus(http.StatusForbidden)
c.AbortWithStatus(http.StatusUnauthorized)
return
} else {

if err := cc.Valid(); err != nil {
logValidationFailed(headers.Authorization, err)
if err.Error() == gorm.ErrRecordNotFound.Error() {
c.AbortWithStatus(http.StatusForbidden)
c.AbortWithStatus(http.StatusUnauthorized)
} else {
err = fmt.Errorf("failed to validate claim, %s", err)
logwrapper.Log.Error(err)
c.AbortWithStatus(http.StatusInternalServerError)
}
Expand Down
32 changes: 21 additions & 11 deletions api/middleware/auth/paseto/paseto_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package paseto

import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/NetSepio/gateway/api/types"
"github.com/NetSepio/gateway/config"
"github.com/NetSepio/gateway/config/dbconfig"
customstatuscodes "github.com/NetSepio/gateway/constants/http/custom_status_codes"
"github.com/NetSepio/gateway/models"
"github.com/NetSepio/gateway/models/claims"
"github.com/NetSepio/gateway/util/pkg/auth"
Expand Down Expand Up @@ -43,21 +46,21 @@ func Test_PASETO(t *testing.T) {
if err != nil {
t.Fatal(err)
}
statusCode := callApi(t, token)
assert.Equal(t, http.StatusOK, statusCode)
rr := callApi(t, token)
assert.Equal(t, http.StatusOK, rr.Result().StatusCode)
})

t.Run("Should return 403 with incorret PASETO", func(t *testing.T) {
t.Run("Should return 401 with incorret PASETO", func(t *testing.T) {
newClaims := claims.New(testWalletAddress)
token, err := auth.GenerateToken(newClaims, "this private key is valid key")
token, err := auth.GenerateToken(newClaims, "aaaabbaa")
if err != nil {
t.Fatal(err)
}
statusCode := callApi(t, token)
assert.Equal(t, http.StatusForbidden, statusCode)
rr := callApi(t, token)
assert.Equal(t, http.StatusUnauthorized, rr.Result().StatusCode)
})

t.Run("Should return 403 with expired PASETO", func(t *testing.T) {
t.Run("Should return 401 and 4011 with expired PASETO", func(t *testing.T) {
expiration := time.Now().Add(time.Second * 2)
signedBy := envutil.MustGetEnv("SIGNED_BY")
newClaims := claims.CustomClaims{
Expand All @@ -73,13 +76,20 @@ func Test_PASETO(t *testing.T) {
t.Fatal(err)
}

statusCode := callApi(t, token)
assert.Equal(t, http.StatusForbidden, statusCode)
rr := callApi(t, token)
assert.Equal(t, http.StatusUnauthorized, rr.Result().StatusCode)
var response types.ApiResponse
body := rr.Body
err = json.NewDecoder(body).Decode(&response)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, customstatuscodes.TokenExpired, response.StatusCode)
})

}

func callApi(t *testing.T, token string) int {
func callApi(t *testing.T, token string) *httptest.ResponseRecorder {
rr := httptest.NewRecorder()
ginTestApp := gin.New()

Expand All @@ -91,7 +101,7 @@ func callApi(t *testing.T, token string) int {
ginTestApp.Use(PASETO)
ginTestApp.Use(successHander)
ginTestApp.ServeHTTP(rr, rq)
return rr.Result().StatusCode
return rr
}

func successHander(c *gin.Context) {
Expand Down
8 changes: 4 additions & 4 deletions api/types/http.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package types

type ApiResponse struct {
Status int `json:"status,omitempty"`
Error string `json:"error,omitempty"`
Message string `json:"message,omitempty"`
Payload interface{} `json:"payload,omitempty"`
StatusCode int `json:"statusCode,omitempty"`
Error string `json:"error,omitempty"`
Message string `json:"message,omitempty"`
Payload interface{} `json:"payload,omitempty"`
}
5 changes: 5 additions & 0 deletions constants/http/custom_status_codes/custom_status_codes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package customstatuscodes

const (
TokenExpired = 4031
)
34 changes: 21 additions & 13 deletions util/pkg/httphelper/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,42 @@ import (
//TODO: add method for internet error with common msg
func ErrResponse(c *gin.Context, statusCode int, errMessage string) {
response := types.ApiResponse{
Status: statusCode,
Error: errMessage,
StatusCode: statusCode,
Error: errMessage,
}
c.JSON(response.Status, response)
c.JSON(response.StatusCode, response)
}

func CErrResponse(c *gin.Context, statusCode int, customStatusCode int, errMessage string) {
response := types.ApiResponse{
StatusCode: customStatusCode,
Error: errMessage,
}
c.JSON(statusCode, response)
}

func SuccessResponse(c *gin.Context, message string, payload interface{}) {
response := types.ApiResponse{
Status: http.StatusOK,
Payload: payload,
Message: message,
StatusCode: http.StatusOK,
Payload: payload,
Message: message,
}
c.JSON(response.Status, response)
c.JSON(response.StatusCode, response)
}

func InternalServerError(c *gin.Context) {
response := types.ApiResponse{
Status: http.StatusInternalServerError,
Error: "unexpected error occurred",
StatusCode: http.StatusInternalServerError,
Error: "unexpected error occurred",
}
c.JSON(response.Status, response)
c.JSON(response.StatusCode, response)
}

func NewInternalServerError(c *gin.Context, format string, args ...interface{}) {
logwrapper.Errorf(format, args...)
response := types.ApiResponse{
Status: http.StatusInternalServerError,
Error: "unexpected error occurred",
StatusCode: http.StatusInternalServerError,
Error: "unexpected error occurred",
}
c.JSON(response.Status, response)
c.JSON(response.StatusCode, response)
}

0 comments on commit d413675

Please sign in to comment.