Skip to content

Commit

Permalink
Merge pull request #392 from merlin-northern/men_6804_key_rotation_su…
Browse files Browse the repository at this point in the history
…pport

feat: "kid" support in JWT header and multiple keys.
  • Loading branch information
merlin-northern authored Nov 29, 2023
2 parents ff3e93c + 27467e3 commit e77736d
Show file tree
Hide file tree
Showing 44 changed files with 1,245 additions and 69 deletions.
20 changes: 16 additions & 4 deletions api/http/api_useradm.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ var (
type UserAdmApiHandlers struct {
userAdm useradm.App
db store.DataStore
jwth jwt.Handler
jwth map[int]jwt.Handler
config Config
}

Expand All @@ -85,7 +85,7 @@ type Config struct {
func NewUserAdmApiHandlers(
userAdm useradm.App,
db store.DataStore,
jwth jwt.Handler,
jwth map[int]jwt.Handler,
config Config,
) ApiHandler {
return &UserAdmApiHandlers{
Expand Down Expand Up @@ -205,7 +205,13 @@ func (u *UserAdmApiHandlers) AuthLogoutHandler(w rest.ResponseWriter, r *rest.Re
l := log.FromContext(ctx)

if tokenStr, err := authz.ExtractToken(r.Request); err == nil {
token, err := u.jwth.FromJWT(tokenStr)
keyId := jwt.GetKeyId(tokenStr)
if _, ok := u.jwth[keyId]; !ok {
rest_utils.RestErrWithLogInternal(w, r, l, errors.New("internal error"))
return
}

token, err := u.jwth[keyId].FromJWT(tokenStr)
if err != nil {
rest_utils.RestErrWithLogInternal(w, r, l, err)
return
Expand Down Expand Up @@ -374,7 +380,13 @@ func (u *UserAdmApiHandlers) UpdateUserHandler(w rest.ResponseWriter, r *rest.Re

// extract the token used to update the user
if tokenStr, err := authz.ExtractToken(r.Request); err == nil {
token, err := u.jwth.FromJWT(tokenStr)
keyId := jwt.GetKeyId(tokenStr)
if _, ok := u.jwth[keyId]; !ok {
rest_utils.RestErrWithLogInternal(w, r, l, errors.New("internal error"))
return
}

token, err := u.jwth[keyId].FromJWT(tokenStr)
if err != nil {
rest_utils.RestErrWithLogInternal(w, r, l, err)
return
Expand Down
10 changes: 5 additions & 5 deletions api/http/api_useradm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -852,10 +852,10 @@ func makeMockApiHandler(t *testing.T, uadm useradm.App, db store.DataStore) http
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
assert.NoError(t, err)

jwth := jwt.NewJWTHandlerRS256(key)
jwth := jwt.NewJWTHandlerRS256(key, 0)

// API handler
handlers := NewUserAdmApiHandlers(uadm, db, jwth, Config{})
handlers := NewUserAdmApiHandlers(uadm, db, map[int]jwt.Handler{0: jwth}, Config{})
assert.NotNil(t, handlers)

app, err := handlers.GetApp()
Expand All @@ -879,9 +879,9 @@ func makeMockApiHandler(t *testing.T, uadm useradm.App, db store.DataStore) http
mock.AnythingOfType("*log.Logger")).Return(authorizer)

authzmw := &authz.AuthzMiddleware{
Authz: authorizer,
ResFunc: ExtractResourceAction,
JWTHandler: jwth,
Authz: authorizer,
ResFunc: ExtractResourceAction,
JWTHandlers: map[int]jwt.Handler{0: jwth},
}

ifmw := &rest.IfMiddleware{
Expand Down
14 changes: 12 additions & 2 deletions authz/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ const (
type AuthzMiddleware struct {
Authz Authorizer
ResFunc ResourceActionExtractor
JWTHandler jwt.Handler
JWTHandlers map[int]jwt.Handler
JWTFallbackHandler jwt.Handler
}

Expand All @@ -60,8 +60,18 @@ func (mw *AuthzMiddleware) MiddlewareFunc(h rest.HandlerFunc) rest.HandlerFunc {
return
}

keyId := jwt.GetKeyId(tokstr)

if _, ok := mw.JWTHandlers[keyId]; !ok {
// we have not found the corresponding handler for the key by id
// on purpose we do not return common.ErrKeyIdNotFound -- to not allow
// the enumeration attack
rest_utils.RestErrWithLog(w, r, l, ErrAuthzTokenInvalid, http.StatusUnauthorized)
return
}

// parse token, insert into env
token, err := mw.JWTHandler.FromJWT(tokstr)
token, err := mw.JWTHandlers[keyId].FromJWT(tokstr)
if err != nil && mw.JWTFallbackHandler != nil {
token, err = mw.JWTFallbackHandler.FromJWT(tokstr)
}
Expand Down
8 changes: 4 additions & 4 deletions authz/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,11 @@ func TestAuthzMiddleware(t *testing.T) {

//finish setting up the middleware
privkey := loadPrivKey("../crypto/private.pem", t)
jwth := jwt.NewJWTHandlerRS256(privkey)
jwth := jwt.NewJWTHandlerRS256(privkey, 0)
mw := AuthzMiddleware{
Authz: a,
ResFunc: resfunc,
JWTHandler: jwth,
Authz: a,
ResFunc: resfunc,
JWTHandlers: map[int]jwt.Handler{0: jwth},
}
api.Use(&mw)

Expand Down
45 changes: 45 additions & 0 deletions common/keys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright 2023 Northern.tech AS
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package common

import (
"path/filepath"
"regexp"
"strconv"

"github.com/pkg/errors"
)

const KeyIdZero = 0

var (
ErrKeyIdNotFound = errors.New("cant locate key by key id")
ErrKeyIdCollision = errors.New("key id already loaded")
)

func KeyIdFromPath(privateKeyPath string, privateKeyFilenamePattern string) (keyId int) {
fileName := filepath.Base(privateKeyPath)
r, _ := regexp.Compile(privateKeyFilenamePattern)
b := []byte(fileName)
indices := r.FindAllSubmatchIndex(b, -1)
keyId = KeyIdZero
if len(indices) > 0 && len(indices[0]) > 3 {
k, err := strconv.Atoi(string(b[indices[0][2]:indices[0][3]]))
if err == nil {
keyId = k
}
}
return keyId
}
34 changes: 34 additions & 0 deletions common/keys_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2023 Northern.tech AS
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package common

import (
"strconv"
"testing"

"github.com/stretchr/testify/assert"
)

func TestKeyIdFromPath(t *testing.T) {
var keyId int
for i := 1; i < 1024; i++ {
keyId = KeyIdFromPath("/etc/useradm/rsa/private.id."+strconv.Itoa(i)+".pem", "private\\.id\\.([0-9]*)\\.pem")
assert.Equal(t, i, keyId)
}
for i := 1; i < 1024; i++ {
keyId = KeyIdFromPath("/etc/useradm/rsa/private.id-"+strconv.Itoa(i)+".pem", "private\\.id\\.([0-9]*)\\.pem")
assert.Equal(t, KeyIdZero, keyId)
}
}
8 changes: 8 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ listen: :8080
# Overwrite with environment variable: USERADM_SERVER_PRIV_KEY_PATH
# server_priv_key_path: /etc/useradm/rsa/private.pem

# Private key filename pattern - used to support multiple keys and key rotation
# Each file in a directory where server_priv_key_path reside the service checks
# against the pattern. If the file matches, then it is loaded as a private key
# identified with an id which exists in the file name.
# Defaults to: "private\\.id\\.([0-9]*)\\.pem"
# Overwrite with environment variable: USERADM_SERVER_PRIV_KEY_FILENAME_PATTERN
# server_priv_key_filename_pattern: "private\\.id\\.([0-9]*)\\.pem"

# Fallback private key path - used for JWT verification
# Defaults to: none
# Overwrite with environment variable: USERADM_SERVER_FALLBACK_PRIV_KEY_PATH
Expand Down
8 changes: 6 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ const (
SettingMiddleware = "middleware"
SettingMiddlewareDefault = "prod"

SettingServerPrivKeyPath = "server_priv_key_path"
SettingServerPrivKeyPathDefault = "/etc/useradm/rsa/private.pem"
SettingServerPrivKeyPath = "server_priv_key_path"
SettingServerPrivKeyPathDefault = "/etc/useradm/rsa/private.pem"
SettingServerPrivKeyFileNamePattern = "server_priv_key_filename_pattern"
SettingServerPrivKeyFileNamePatternDefault = "private\\.id\\.([0-9]*)\\.pem"

SettingServerFallbackPrivKeyPath = "server_fallback_priv_key_path"
SettingServerFallbackPrivKeyPathDefault = ""
Expand Down Expand Up @@ -73,6 +75,8 @@ var (
{Key: SettingListen, Value: SettingListenDefault},
{Key: SettingMiddleware, Value: SettingMiddlewareDefault},
{Key: SettingServerPrivKeyPath, Value: SettingServerPrivKeyPathDefault},
{Key: SettingServerPrivKeyFileNamePattern,
Value: SettingServerPrivKeyFileNamePatternDefault},
{Key: SettingServerFallbackPrivKeyPath, Value: SettingServerFallbackPrivKeyPathDefault},
{Key: SettingJWTIssuer, Value: SettingJWTIssuerDefault},
{Key: SettingJWTExpirationTimeout, Value: SettingJWTExpirationTimeoutDefault},
Expand Down
48 changes: 44 additions & 4 deletions jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import (
"os"

"github.com/pkg/errors"

jwtv4 "github.com/golang-jwt/jwt/v4"

"github.com/mendersoftware/useradm/common"
)

var (
Expand All @@ -45,7 +49,7 @@ type Handler interface {
FromJWT(string) (*Token, error)
}

func NewJWTHandler(privateKeyPath string) (Handler, error) {
func NewJWTHandler(privateKeyPath string, privateKeyFilenamePattern string) (Handler, error) {
priv, err := os.ReadFile(privateKeyPath)
block, _ := pem.Decode(priv)
if block == nil {
Expand All @@ -57,18 +61,54 @@ func NewJWTHandler(privateKeyPath string) (Handler, error) {
if err != nil {
return nil, errors.Wrap(err, "failed to read rsa private key")
}
return NewJWTHandlerRS256(privKey), nil
return NewJWTHandlerRS256(
privKey,
common.KeyIdFromPath(privateKeyPath, privateKeyFilenamePattern),
),
nil
case pemHeaderPKCS8:
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "failed to read private key")
}
switch v := key.(type) {
case *rsa.PrivateKey:
return NewJWTHandlerRS256(v), nil
return NewJWTHandlerRS256(
v,
common.KeyIdFromPath(privateKeyPath, privateKeyFilenamePattern),
),
nil
case ed25519.PrivateKey:
return NewJWTHandlerEd25519(&v), nil
return NewJWTHandlerEd25519(
&v,
common.KeyIdFromPath(privateKeyPath, privateKeyFilenamePattern),
),
nil
}
}
return nil, errors.Errorf("unsupported server private key type")
}

func GetKeyId(tokenString string) int {
token, _, err := jwtv4.NewParser().ParseUnverified(tokenString, &Claims{})

if err != nil {
return common.KeyIdZero
}

if _, ok := token.Header["kid"]; ok {
if _, ok := token.Header["kid"]; ok {
if _, isFloat := token.Header["kid"].(float64); isFloat {
return int(token.Header["kid"].(float64))
}
if _, isInt := token.Header["kid"].(int64); isInt {
return int(token.Header["kid"].(int64))
}
if _, isInt := token.Header["kid"].(int); isInt {
return token.Header["kid"].(int)
}
}
}

return common.KeyIdZero
}
35 changes: 29 additions & 6 deletions jwt/jwt_ed25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,61 @@ package jwt

import (
"crypto/ed25519"
"strconv"

"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"

"github.com/mendersoftware/useradm/common"
)

// JWTHandlerEd25519 is an Ed25519-specific JWTHandler
type JWTHandlerEd25519 struct {
privKey *ed25519.PrivateKey
privKey map[int]*ed25519.PrivateKey
currentKeyId int
}

func NewJWTHandlerEd25519(privKey *ed25519.PrivateKey) *JWTHandlerEd25519 {
func NewJWTHandlerEd25519(privKey *ed25519.PrivateKey, keyId int) *JWTHandlerEd25519 {
return &JWTHandlerEd25519{
privKey: privKey,
privKey: map[int]*ed25519.PrivateKey{keyId: privKey},
currentKeyId: keyId,
}
}

func (j *JWTHandlerEd25519) ToJWT(token *Token) (string, error) {
//generate
jt := jwt.NewWithClaims(jwt.SigningMethodEdDSA, &token.Claims)

jt.Header["kid"] = token.KeyId
if _, exists := j.privKey[token.KeyId]; !exists {
return "", common.ErrKeyIdNotFound
}
//sign
data, err := jt.SignedString(j.privKey)
data, err := jt.SignedString(j.privKey[token.KeyId])
return data, err
}

func (j *JWTHandlerEd25519) FromJWT(tokstr string) (*Token, error) {
jwttoken, err := jwt.ParseWithClaims(tokstr, &Claims{},
func(token *jwt.Token) (interface{}, error) {
keyId := common.KeyIdZero
if _, ok := token.Header["kid"]; ok {
if _, isFloat := token.Header["kid"].(float64); isFloat {
keyId = int(token.Header["kid"].(float64))
}
if _, isInt := token.Header["kid"].(int64); isInt {
keyId = int(token.Header["kid"].(int64))
}
if _, isInt := token.Header["kid"].(int); isInt {
keyId = token.Header["kid"].(int)
}
}
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
return nil, errors.New("unexpected signing method: " + token.Method.Alg())
}
return j.privKey.Public(), nil
if _, exists := j.privKey[keyId]; !exists {
return nil, errors.New("cannot find the key with id " + strconv.Itoa(keyId))
}
return j.privKey[keyId].Public(), nil
},
)

Expand Down
Loading

0 comments on commit e77736d

Please sign in to comment.