Skip to content

Commit

Permalink
feat: support login from many web domain
Browse files Browse the repository at this point in the history
  • Loading branch information
cjtim committed Jun 11, 2022
1 parent f0efe50 commit 8064c2c
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 53 deletions.
3 changes: 0 additions & 3 deletions .vscode/settings.json

This file was deleted.

12 changes: 4 additions & 8 deletions configs/configs.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,10 @@ type ConfigType struct {
LineClientID string `env:"LINE_CLIENT_ID" envDefault:""`
LineSecretID string `env:"LINE_SECRET_ID" envDefault:""`

LineAPIBroadcast string `envDefault:"https://api.line.me/v2/bot/message/broadcast"`
LineAPIReply string `envDefault:"https://api.line.me/v2/bot/message/reply"`
AirVisualAPINearestCity string `envDefault:"http://api.airvisual.com/v2/nearest_city"`
AirVisualAPICity string `envDefault:"http://api.airvisual.com/v2/city"`
BinanceAccountAPI string `envDefault:"https://api.binance.com/api/v3/account"`
RebrandlyAPI string `envDefault:"https://api.rebrandly.com/v1/links"`
LogFilePath string `env:"LOG_PATH" envDefault:"/var/log/cjtim-backend-go.log"`
GCLOUD_CREDENTIAL string `env:"GCLOUD_CREDENTIAL" envDefault:"./configs/serviceAcc.json"`
LINE_WEB_CALLBACK_PATH string `env:"LINE_WEB_CALLBACK_PATH" envDefault:"/user/line/callback"`

LogFilePath string `env:"LOG_PATH" envDefault:"/var/log/cjtim-backend-go.log"`
GCLOUD_CREDENTIAL string `env:"GCLOUD_CREDENTIAL" envDefault:"./configs/serviceAcc.json"`
}

func init() {
Expand Down
26 changes: 23 additions & 3 deletions handlers/auth/line.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
package auth

import (
"fmt"
"net/http"
"time"

"github.com/cjtim/be-friends-api/configs"
"github.com/cjtim/be-friends-api/internal/auth"
"github.com/cjtim/be-friends-api/repository"
"github.com/gofiber/fiber/v2"
"go.uber.org/zap"
)

// LoginLine - GET line login url
func LoginLine(c *fiber.Ctx) error {
url := auth.GetLoginURL()
host := c.Query("host")
url := auth.GetLoginURL(host)
if url == "" {
return c.SendStatus(http.StatusInternalServerError)
}
return c.Status(http.StatusOK).SendString(url)
}

Expand All @@ -25,11 +33,23 @@ func LineCallback(c *fiber.Ctx) error {
return c.SendStatus(http.StatusBadRequest)
}

jwtToken, err := auth.Callback(code, state)
jwtToken, err := auth.Callback(code)
if err != nil {
zap.L().Error("error line callback", zap.Error(err))
return c.SendStatus(http.StatusInternalServerError)
}

return c.Status(http.StatusOK).SendString(jwtToken)
clientHost, err := repository.RedisCallback.GetCallback(state)
if err != nil {
zap.L().Error("error redis - cannot get callback by state", zap.Error(err))
}
redirectURL := fmt.Sprintf("http://%s%s", clientHost, configs.Config.LINE_WEB_CALLBACK_PATH)
authCookie := fmt.Sprintf(
"%s=%s; Max-Age=%d; Path=/",
configs.Config.JWTCookies,
jwtToken,
int64(auth.TOKEN_EXPIRE/time.Second),
)
c.Response().Header.Add("set-cookie", authCookie)
return c.Redirect(redirectURL)
}
2 changes: 1 addition & 1 deletion handlers/auth/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func Logout(c *fiber.Ctx) error {
err := repository.RedisRepo.RemoveJwt(c.Cookies(configs.Config.JWTCookies))
err := repository.RedisJwt.RemoveJwt(c.Cookies(configs.Config.JWTCookies))
if err != nil {
return c.SendStatus(http.StatusInternalServerError)
}
Expand Down
2 changes: 1 addition & 1 deletion handlers/middlewares/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ var GetJWTMiddleware = func(c *fiber.Ctx) error {
token = jwt
}
// Check token in Redis
notValid := !repository.RedisRepo.IsJwtValid(token)
notValid := !repository.RedisJwt.IsJwtValid(token)
if notValid {
return c.Status(http.StatusBadRequest).SendString("Invalid JWT")
}
Expand Down
2 changes: 1 addition & 1 deletion handlers/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ func Route(r *fiber.App) {
authRoute.Get("/me", middlewares.GetJWTMiddleware, auth.Me)
authRoute.Get("/logout", auth.Logout)
authRoute.Get("/line", auth.LoginLine)
authRoute.Get("/line/jwt", auth.LineCallback)
authRoute.Get("/line/callback", auth.LineCallback)
}
22 changes: 15 additions & 7 deletions internal/auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/cjtim/be-friends-api/internal/utils"
"github.com/cjtim/be-friends-api/repository"
"github.com/golang-jwt/jwt/v4"
"go.uber.org/zap"
)

var (
Expand Down Expand Up @@ -44,20 +45,28 @@ func GetNewToken(u *repository.User) (*jwt.Token, string, error) {
if err != nil {
return token, "", err
}
err = repository.RedisRepo.AddJwt(t, TOKEN_EXPIRE)
err = repository.RedisJwt.AddJwt(t, TOKEN_EXPIRE)
return token, t, err
}

func GetLoginURL() string {
func GetLoginURL(clientURL string) string {
url := http.Request{
URL: &url.URL{
Scheme: "https",
Host: "access.line.me",
Path: "/oauth2/v2.1/authorize",
},
}

state := utils.RandomSeq(10)
err := repository.RedisCallback.AddCallback(state, clientURL)
if err != nil {
zap.L().Error("error redis - cannot save callback", zap.Error(err))
return ""
}

q := url.URL.Query()
q.Add("state", utils.RandomSeq(10))
q.Add("state", state)
q.Add("scope", "profile openid")
q.Add("response_type", "code")
q.Add("redirect_uri", configs.Config.LineLoginCallback)
Expand All @@ -66,11 +75,10 @@ func GetLoginURL() string {
return url.URL.String()
}

func getJWT(code, state string) (string, error) {
func getJWT(code string) (string, error) {
resp, err := http.PostForm("https://api.line.me/oauth2/v2.1/token", url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"state": {state},
"redirect_uri": {configs.Config.LineLoginCallback},
"client_id": {configs.Config.LineClientID},
"client_secret": {configs.Config.LineSecretID},
Expand All @@ -92,9 +100,9 @@ func getJWT(code, state string) (string, error) {
return userInfo.IDToken, err
}

func Callback(code, state string) (string, error) {
func Callback(code string) (string, error) {
// 1. Get profile from LINE
token, err := getJWT(code, state)
token, err := getJWT(code)
if err != nil {
return "", err
}
Expand Down
48 changes: 32 additions & 16 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,52 @@ func realMain() int {
defer logger.Sync()
zap.ReplaceGlobals(logger)

db, err := repository.Connect()
errDB := prepareDB()
if errDB > 0 {
return errDB
}

app := prepareFiber()
setupCloseHandler(app)

listen := fmt.Sprintf(":%d", configs.Config.Port)
err := app.Listen(listen)
zap.L().Info("closing fiber", zap.Error(err))
zap.L().Info("closing redis conn", zap.Errors("redis", repository.CloseAll()))
zap.L().Info("closing postgres conn", zap.Error(repository.DB.Close()))
if err != nil {
zap.L().Error("fiber error", zap.Error(err))
return 1
}
return 0
}

func prepareDB() int {
_, err := repository.Connect()
if err != nil {
zap.L().Panic("error postgresql", zap.Error(err))
return 1
}
repository.DB = db

rdb, err := repository.ConnectRedis()
_, err = repository.ConnectRedis(repository.DEFAULT)
if err != nil {
zap.L().Panic("error redis", zap.Error(err))
return 1
}
repository.REDIS = rdb

app := startServer()
setupCloseHandler(app)

listen := fmt.Sprintf(":%d", configs.Config.Port)
if err := app.Listen(listen); err != nil {
repository.DB.Close()
repository.REDIS.Close()
zap.L().Error("fiber start error", zap.Error(err))
_, err = repository.ConnectRedis(repository.JWT)
if err != nil {
zap.L().Panic("error redis", zap.Error(err))
return 1
}
_, err = repository.ConnectRedis(repository.CALLBACK)
if err != nil {
zap.L().Panic("error redis", zap.Error(err))
return 1
}
return 0
}

func startServer() *fiber.App {
func prepareFiber() *fiber.App {
app := fiber.New(fiber.Config{
ErrorHandler: middlewares.ErrorHandling,
BodyLimit: 100 * 1024 * 1024, // Limit file size to 100MB
Expand All @@ -69,8 +87,6 @@ func setupCloseHandler(app *fiber.App) {
go func() {
<-c
zap.L().Info("Got SIGTERM, terminating program...")
repository.REDIS.Close()
repository.DB.Close()
app.Server().Shutdown()
}()
}
68 changes: 56 additions & 12 deletions repository/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,75 @@ import (
)

var (
REDIS *redis.Client
RedisRepo *RedisImpl
RedisDefault *defaultImpl
RedisJwt *jwtImpl
RedisCallback *callbackImpl
)

func ConnectRedis() (rdb *redis.Client, err error) {
type RedisDatabase int

const (
DEFAULT RedisDatabase = 0
JWT RedisDatabase = 1
CALLBACK RedisDatabase = 2
)

func ConnectRedis(db RedisDatabase) (rdb *redis.Client, err error) {
rdb = redis.NewClient(&redis.Options{
Addr: configs.Config.REDIS_URL,
Password: "", // no password set
DB: 0, // use default DB
Password: "", // no password set
DB: int(db), // use default DB
})
switch db {
case JWT:
RedisJwt = &jwtImpl{Client: rdb}
case CALLBACK:
RedisCallback = &callbackImpl{Client: rdb}
default:
RedisDefault = &defaultImpl{Client: rdb}
}

err = rdb.Set(context.Background(), "test", "test", time.Second*1).Err()
return
}

type RedisImpl struct{}
func CloseAll() []error {
return []error{
RedisDefault.Client.Close(),
RedisJwt.Client.Close(),
RedisCallback.Client.Close(),
}
}

type defaultImpl struct {
Client *redis.Client
}

func (r *RedisImpl) IsJwtValid(token string) bool {
err := REDIS.Get(context.Background(), token).Err()
type jwtImpl struct {
Client *redis.Client
}

type callbackImpl struct {
Client *redis.Client
}

func (r *jwtImpl) IsJwtValid(token string) bool {
err := r.Client.Get(context.Background(), token).Err()
return err != redis.Nil
}

func (r *RedisImpl) AddJwt(token string, expire time.Duration) error {
return REDIS.Set(context.Background(), token, "", expire).Err()
func (r *jwtImpl) AddJwt(token string, expire time.Duration) error {
return r.Client.Set(context.Background(), token, "", expire).Err()
}

func (r *jwtImpl) RemoveJwt(token string) error {
return r.Client.Del(context.Background(), token).Err()
}

func (r *callbackImpl) AddCallback(state, callback string) error {
return r.Client.Set(context.Background(), state, callback, time.Minute*15).Err()
}

func (r *RedisImpl) RemoveJwt(token string) error {
return REDIS.Del(context.Background(), token).Err()
func (r *callbackImpl) GetCallback(state string) (string, error) {
return r.Client.Get(context.Background(), state).Result()
}
10 changes: 9 additions & 1 deletion repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,13 @@ var (
)

func Connect() (*sqlx.DB, error) {
return sqlx.Open("postgres", configs.Config.DATABASE_URL)
db, err := sqlx.Open("postgres", configs.Config.DATABASE_URL)
if err != nil {
return nil, err
}
if db.Ping() != nil {
return nil, err
}
DB = db
return db, err
}

0 comments on commit 8064c2c

Please sign in to comment.