Skip to content

Commit

Permalink
utilize the tenant aspect inline with hydra
Browse files Browse the repository at this point in the history
  • Loading branch information
pitabwire committed Jan 20, 2024
1 parent 18aee1d commit d3ae38a
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 47 deletions.
104 changes: 78 additions & 26 deletions authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,75 @@ const ctxKeyAuthentication = contextKey("authenticationKey")
// AuthenticationClaims Create a struct that will be encoded to a JWT.
// We add jwt.StandardClaims as an embedded type, to provide fields like expiry time
type AuthenticationClaims struct {
ProfileID string `json:"sub,omitempty"`
TenantID string `json:"tenant_id,omitempty"`
PartitionID string `json:"partition_id,omitempty"`
AccessID string `json:"access_id,omitempty"`
Roles []string `json:"roles,omitempty"`
ProfileID string `json:"sub,omitempty"`
Ext map[string]interface{} `json:"ext,omitempty"`
jwt.RegisteredClaims
}

func (a *AuthenticationClaims) TenantId() string {

result := ""
val, ok := a.Ext["tenant_id"]
if !ok {
return ""
}

result, ok = val.(string)
if !ok {
return ""
}

return result
}

func (a *AuthenticationClaims) PartitionId() string {

result := ""
val, ok := a.Ext["partition_id"]
if !ok {
return ""
}

result, ok = val.(string)
if !ok {
return ""
}

return result
}

func (a *AuthenticationClaims) AccessId() string {

result := ""
val, ok := a.Ext["access_id"]
if !ok {
return ""
}

result, ok = val.(string)
if !ok {
return ""
}

return result
}

func (a *AuthenticationClaims) Roles() []string {

result := []string{}
val, ok := a.Ext["roles"]
if !ok {
return result
}

result, ok = val.([]string)
if !ok {
return []string{}
}

return result
}

func (a *AuthenticationClaims) isSystem() bool {
//TODO: tokens which are granted as client credentials have no partition information attached
// Since we cannot pass custom information to token to allow specifying who is an admin.
Expand All @@ -43,11 +104,11 @@ func (a *AuthenticationClaims) isSystem() bool {
func (a *AuthenticationClaims) AsMetadata() map[string]string {

m := make(map[string]string)
m["tenant_id"] = a.TenantID
m["partition_id"] = a.PartitionID
m["tenant_id"] = a.TenantId()
m["partition_id"] = a.PartitionId()
m["profile_id"] = a.ProfileID
m["access_id"] = a.AccessID
m["roles"] = strings.Join(a.Roles, ",")
m["access_id"] = a.AccessId()
m["roles"] = strings.Join(a.Roles(), ",")
return m
}

Expand All @@ -70,24 +131,15 @@ func ClaimsFromContext(ctx context.Context) *AuthenticationClaims {
func ClaimsFromMap(m map[string]string) *AuthenticationClaims {
var authenticationClaims AuthenticationClaims

if val, ok := m["tenant_id"]; ok {
authenticationClaims = AuthenticationClaims{}
authenticationClaims.TenantID = val

if val, ok := m["partition_id"]; ok {
authenticationClaims.PartitionID = val

if val, ok := m["profile_id"]; ok {
authenticationClaims.ProfileID = val
authenticationClaims = AuthenticationClaims{
Ext: map[string]interface{}{},
}

if val, ok := m["access_id"]; ok {
authenticationClaims.AccessID = val
if val, ok := m["roles"]; ok {
authenticationClaims.Roles = strings.Split(val, ",")
return &authenticationClaims
}
}
}
for key, val := range m {
if key == "roles" {
authenticationClaims.Ext[key] = strings.Split(val, ",")
} else {
authenticationClaims.Ext[key] = val
}
}

Expand Down
4 changes: 2 additions & 2 deletions authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ func AuthHasAccess(ctx context.Context, action string, subject string) (bool, er
}

payload := map[string]interface{}{
"namespace": authClaims.TenantID,
"object": authClaims.PartitionID,
"namespace": authClaims.TenantId(),
"object": authClaims.PartitionId(),
"relation": action,
"subject_id": subject,
}
Expand Down
25 changes: 14 additions & 11 deletions authorization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ func authorizationControlListWrite(ctx context.Context, writeServerURL string, a
}

payload := map[string]interface{}{
"namespace": authClaims.TenantID,
"object": authClaims.PartitionID,
"namespace": authClaims.TenantId(),
"object": authClaims.PartitionId(),
"relation": action,
"subject_id": subject,
}
Expand Down Expand Up @@ -53,10 +53,12 @@ func TestAuthorizationControlListWrite(t *testing.T) {
ctx = frame.ToContext(ctx, srv)

authClaim := frame.AuthenticationClaims{
TenantID: "default",
PartitionID: "partition",
ProfileID: "profile",
AccessID: "access",
ProfileID: "profile",
Ext: map[string]any{
"partition_id": "partition",
"tenant_id": "default",
"access_id": "access",
},
}
ctx = authClaim.ClaimsToContext(ctx)

Expand All @@ -77,11 +79,12 @@ func TestAuthHasAccess(t *testing.T) {
ctx = frame.ToContext(ctx, srv)

authClaim := frame.AuthenticationClaims{
TenantID: "default",
PartitionID: "partition",
ProfileID: "profile",
AccessID: "access",
}
ProfileID: "profile",
Ext: map[string]any{
"partition_id": "partition",
"tenant_id": "default",
"access_id": "access",
}}
ctx = authClaim.ClaimsToContext(ctx)

err := authorizationControlListWrite(ctx, authorizationServerURL, "read", "reader")
Expand Down
10 changes: 5 additions & 5 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ func (model *BaseModel) GenID(ctx context.Context) {
return
}

if authClaim.AccessID != "" {
model.AccessID = authClaim.AccessID
if authClaim.AccessId() != "" {
model.AccessID = authClaim.AccessId()
}

if authClaim.TenantID != "" && authClaim.PartitionID != "" {
model.PartitionID = authClaim.PartitionID
model.TenantID = authClaim.TenantID
if authClaim.TenantId() != "" && authClaim.PartitionId() != "" {
model.PartitionID = authClaim.PartitionId()
model.TenantID = authClaim.TenantId()
}
}

Expand Down
6 changes: 3 additions & 3 deletions datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ func tenantPartition(ctx context.Context) func(db *gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
authClaim := ClaimsFromContext(ctx)
if authClaim != nil &&
authClaim.TenantID != "" &&
authClaim.PartitionID != "" &&
authClaim.TenantId() != "" &&
authClaim.PartitionId() != "" &&
!authClaim.isSystem() {
return db.Where("tenant_id = ? AND partition_id = ?", authClaim.TenantID, authClaim.PartitionID)
return db.Where("tenant_id = ? AND partition_id = ?", authClaim.TenantId(), authClaim.PartitionId())
}
return db
}
Expand Down

0 comments on commit d3ae38a

Please sign in to comment.