Skip to content

Commit

Permalink
Merge pull request #149 from eolinker/feature/pre_router
Browse files Browse the repository at this point in the history
Feature/pre router
  • Loading branch information
Dot-Liu authored Jan 19, 2024
2 parents 238218a + d2e701f commit bbc0efe
Show file tree
Hide file tree
Showing 32 changed files with 1,622 additions and 40 deletions.
8 changes: 6 additions & 2 deletions application/auth/aksk/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ var _ auth.IAuthFactory = (*factory)(nil)

var driverName = "aksk"

//Register 注册auth驱动工厂
// Register 注册auth驱动工厂
func Register() {
auth.FactoryRegister(driverName, NewFactory())
}
Expand Down Expand Up @@ -44,6 +44,10 @@ func (f *factory) Alias() []string {
}
}

func (f *factory) PreRouters() []*auth.PreRouter {
return nil
}

func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) {
a := &aksk{
id: toId(tokenName, position),
Expand All @@ -54,7 +58,7 @@ func (f *factory) Create(tokenName string, position string, rule interface{}) (a
return a, nil
}

//NewFactory 生成一个 auth_apiKey工厂
// NewFactory 生成一个 auth_apiKey工厂
func NewFactory() auth.IAuthFactory {
typ := reflect.TypeOf((*Config)(nil))
render, _ := schema.Generate(typ, nil)
Expand Down
4 changes: 4 additions & 0 deletions application/auth/apikey/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ func (f *factory) Alias() []string {
}
}

func (f *factory) PreRouters() []*auth.PreRouter {
return nil
}

func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) {
a := &apikey{
id: toId(tokenName, position),
Expand Down
4 changes: 4 additions & 0 deletions application/auth/basic/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ func (f *factory) Alias() []string {
}
}

func (f *factory) PreRouters() []*auth.PreRouter {
return nil
}

func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) {
a := &basic{
id: toId(tokenName, position),
Expand Down
32 changes: 21 additions & 11 deletions application/auth/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"reflect"
"strings"

"github.com/eolinker/apinto/router"

"github.com/eolinker/apinto/application"
"github.com/eolinker/eosc/log"

Expand All @@ -18,24 +20,32 @@ var (
_ eosc.ISetting = defaultAuthFactoryRegister
)

//IAuthFactory 鉴权工厂方法
type PreRouter struct {
ID string
PreHandler router.IRouterPreHandler
Path string
Method []string
}

// IAuthFactory 鉴权工厂方法
type IAuthFactory interface {
Create(tokenName string, position string, rule interface{}) (application.IAuth, error)
Alias() []string
Render() interface{}
ConfigType() reflect.Type
UserType() reflect.Type
PreRouters() []*PreRouter
}

//IAuthFactoryRegister 实现了鉴权工厂管理器
// IAuthFactoryRegister 实现了鉴权工厂管理器
type IAuthFactoryRegister interface {
RegisterFactoryByKey(key string, factory IAuthFactory)
GetFactoryByKey(key string) (IAuthFactory, bool)
Keys() []string
Alias() map[string]string
}

//driverRegister 驱动注册器
// driverRegister 驱动注册器
type driverRegister struct {
register eosc.IRegister[IAuthFactory]
keys []string
Expand Down Expand Up @@ -80,7 +90,7 @@ func (dm *driverRegister) ReadOnly() bool {
return true
}

//newAuthFactoryManager 创建auth工厂管理器
// newAuthFactoryManager 创建auth工厂管理器
func newAuthFactoryManager() *driverRegister {
return &driverRegister{
register: eosc.NewRegister[IAuthFactory](),
Expand All @@ -90,12 +100,12 @@ func newAuthFactoryManager() *driverRegister {
}
}

//GetFactoryByKey 获取指定auth工厂
// GetFactoryByKey 获取指定auth工厂
func (dm *driverRegister) GetFactoryByKey(key string) (IAuthFactory, bool) {
return dm.register.Get(key)
}

//RegisterFactoryByKey 注册auth工厂
// RegisterFactoryByKey 注册auth工厂
func (dm *driverRegister) RegisterFactoryByKey(key string, factory IAuthFactory) {
err := dm.register.Register(key, factory, true)
if err != nil {
Expand All @@ -109,7 +119,7 @@ func (dm *driverRegister) RegisterFactoryByKey(key string, factory IAuthFactory)
}
}

//Keys 返回所有已注册的key
// Keys 返回所有已注册的key
func (dm *driverRegister) Keys() []string {
return dm.keys
}
Expand All @@ -118,18 +128,18 @@ func (dm *driverRegister) Alias() map[string]string {
return dm.driverAlias
}

//FactoryRegister 注册auth工厂到默认auth工厂注册器
// FactoryRegister 注册auth工厂到默认auth工厂注册器
func FactoryRegister(key string, factory IAuthFactory) {

defaultAuthFactoryRegister.RegisterFactoryByKey(key, factory)
}

//Get 从默认auth工厂注册器中获取auth工厂
// Get 从默认auth工厂注册器中获取auth工厂
func Get(key string) (IAuthFactory, bool) {
return defaultAuthFactoryRegister.GetFactoryByKey(key)
}

//Keys 返回默认的auth工厂注册器中所有已注册的key
// Keys 返回默认的auth工厂注册器中所有已注册的key
func Keys() []string {
return defaultAuthFactoryRegister.Keys()
}
Expand All @@ -138,7 +148,7 @@ func Alias() map[string]string {
return defaultAuthFactoryRegister.Alias()
}

//GetFactory 获取指定auth工厂,若指定的不存在则返回一个已注册的工厂
// GetFactory 获取指定auth工厂,若指定的不存在则返回一个已注册的工厂
func GetFactory(name string) (IAuthFactory, error) {
factory, ok := Get(name)
if !ok {
Expand Down
8 changes: 6 additions & 2 deletions application/auth/jwt/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ var _ auth.IAuthFactory = (*factory)(nil)

var driverName = "jwt"

//Register 注册auth驱动工厂
// Register 注册auth驱动工厂
func Register() {
auth.FactoryRegister(driverName, NewFactory())
}
Expand Down Expand Up @@ -43,6 +43,10 @@ func (f *factory) Alias() []string {
}
}

func (f *factory) PreRouters() []*auth.PreRouter {
return nil
}

func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) {
baseConfig, ok := rule.(*application.BaseConfig)
if !ok {
Expand All @@ -66,7 +70,7 @@ func (f *factory) Create(tokenName string, position string, rule interface{}) (a
return a, nil
}

//NewFactory 生成一个 auth_apiKey工厂
// NewFactory 生成一个 auth_apiKey工厂
func NewFactory() auth.IAuthFactory {
typ := reflect.TypeOf((*Config)(nil))
render, _ := schema.Generate(typ, nil)
Expand Down
144 changes: 144 additions & 0 deletions application/auth/oauth2/authorize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package oauth2

import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"strconv"
"sync"
"time"

scope_manager "github.com/eolinker/apinto/scope-manager"

"github.com/eolinker/apinto/resources"
http_context "github.com/eolinker/eosc/eocontext/http-context"
)

const (
ResponseTypeCode = "code"
ResponseTypeToken = "token"
)

func NewAuthorizeHandler() *AuthorizeHandler {
return &AuthorizeHandler{}
}

type AuthorizeHandler struct {
cache scope_manager.IProxyOutput[resources.ICache]
once sync.Once
}

func (a *AuthorizeHandler) Handle(ctx http_context.IHttpContext, client *Client, params url.Values) {
responseType := params.Get("response_type")
if responseType == "" || !((responseType == ResponseTypeCode && client.EnableAuthorizationCode) || (responseType == ResponseTypeToken && client.EnableImplicitGrant)) {
ctx.Response().SetBody([]byte(fmt.Sprintf("unsupported response type: %s,client id is %s", responseType, client.ClientId)))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}

scope := params.Get("scope")
if scope == "" && client.MandatoryScope {
ctx.Response().SetBody([]byte("scope is required, client id is " + client.ClientId))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}
matchScope := false
for _, s := range client.Scopes {
if s == scope {
matchScope = true
break
}
}
if !matchScope {
ctx.Response().SetBody([]byte("invalid scope, client id is " + client.ClientId))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}

redirectURI := params.Get("redirect_uri")
if redirectURI == "" {
ctx.Response().SetBody([]byte("redirect uri is required, client id is " + client.ClientId))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}

matchRedirectUri := false
for _, uri := range client.RedirectUrls {
if uri == redirectURI {
matchRedirectUri = true
break
}
}
if !matchRedirectUri {
ctx.Response().SetBody([]byte("invalid redirect uri, client id is " + client.ClientId))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}
uri, err := url.Parse(redirectURI)
if err != nil {
ctx.Response().SetBody([]byte("invalid redirect uri, client id is " + client.ClientId))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}
a.once.Do(func() {
a.cache = scope_manager.Auto[resources.ICache]("", "redis")
})
list := a.cache.List()
if len(list) < 1 {
ctx.Response().SetBody([]byte("redis cache is not available"))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}
cache := list[0]
query := url.Values{}
switch responseType {
case ResponseTypeCode:
{
// 授权码模式
provisionKey := params.Get("provision_key")
if provisionKey != client.ProvisionKey {
ctx.Response().SetBody([]byte("invalid provision key, client id is " + client.ClientId))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}
code := generateRandomString()
redisKey := fmt.Sprintf("apinto:oauth2_codes:%s:%s", os.Getenv("cluster_id"), code)
field := map[string]interface{}{
"code": code,
"scope": scope,
}
_, err = cache.HMSetN(ctx.Context(), redisKey, field, 6*time.Minute).Result()
if err != nil {
ctx.Response().SetBody([]byte(fmt.Sprintf("(%s)redis HMSet %s error: %s", client.ClientId, redisKey, err.Error())))
ctx.Response().SetStatus(http.StatusInternalServerError, "server error")
return
}
query.Set("code", code)
}
case ResponseTypeToken:
{
token, err := generateToken(ctx.Context(), cache, client.ClientId, client.TokenExpiration, client.RefreshTokenTTL, scope, false)
if err != nil {
ctx.Response().SetBody([]byte(fmt.Sprintf("(%s)generate token error: %s", client.ClientId, err.Error())))
ctx.Response().SetStatus(http.StatusInternalServerError, "server error")
return
}
query.Set("access_token", token.AccessToken)
query.Set("token_type", "bearer")
query.Set("expires_in", strconv.Itoa(token.ExpiresIn))
}
}

state := params.Get("state")
if state != "" {
query.Set("state", state)
}
data, _ := json.Marshal(map[string]interface{}{
"redirect_uri": fmt.Sprintf("%s?%s", uri.String(), query.Encode()),
})
ctx.Response().SetBody(data)
ctx.Response().SetStatus(http.StatusOK, "OK")
return
}
42 changes: 42 additions & 0 deletions application/auth/oauth2/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package oauth2

import "github.com/eolinker/apinto/application"

const (
GrantAuthorizationCode = "authorization_code"
GrantClientCredentials = "client_credentials"
GrantRefreshToken = "refresh_token"
)

type Config struct {
application.Auth
Users []*User `json:"users" label:"用户列表"`
}

type User struct {
Pattern Pattern `json:"pattern" label:"用户信息"`
application.User
}

type Pattern struct {
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"`
ClientType string `json:"client_type"`
HashSecret bool `json:"hash_secret"`
RedirectUrls []string `json:"redirect_urls" label:"重定向URL"`
Scopes []string `json:"scopes" label:"授权范围"`
MandatoryScope bool `json:"mandatory_scope" label:"强制授权"`
ProvisionKey string `json:"provision_key" label:"Provision Key"`
TokenExpiration int `json:"token_expiration" label:"令牌过期时间"`
RefreshTokenTTL int `json:"refresh_token_ttl" label:"刷新令牌TTL"`
EnableAuthorizationCode bool `json:"enable_authorization_code" label:"启用授权码模式"`
EnableImplicitGrant bool `json:"enable_implicit_grant" label:"启用隐式授权模式"`
EnableClientCredentials bool `json:"enable_client_credentials" label:"启用客户端凭证模式"`
AcceptHttpIfAlreadyTerminated bool `json:"accept_http_if_already_terminated" label:"如果已终止,则接受HTTP"`
ReuseRefreshToken bool `json:"reuse_refresh_token" label:"重用刷新令牌"`
PersistentRefreshToken bool `json:"persistent_refresh_token" label:"持久刷新令牌"`
}

func (u *User) Username() string {
return u.Pattern.ClientId
}
Loading

0 comments on commit bbc0efe

Please sign in to comment.