Skip to content

Commit

Permalink
fix: shrink service wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
MuZhou233 committed Sep 20, 2024
1 parent 764d1e5 commit f21d4d8
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 86 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@ bin/

dist/
.release-env
*.db
*.db

.golangci.yml
10 changes: 7 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
SHELL:=/bin/bash
LINT_FILE_TAG=v0.2.0
LINT_FILE_URL=https://raw.githubusercontent.com/tuihub/librarian/$(LINT_FILE_TAG)/.golangci.yml
LINT_FILE_LOCAL=.golangci.yml

.PHONY: init
# init env
Expand All @@ -9,9 +10,12 @@ init:

.PHONY: lint
# lint files
lint:
golangci-lint run --fix -c <(curl -sSL $(LINT_FILE_URL))
golangci-lint run -c <(curl -sSL $(LINT_FILE_URL)) # re-run to make sure fixes are valid, useful in some condition
lint: $(LINT_FILE_LOCAL)
golangci-lint run --fix -c $(LINT_FILE_LOCAL)
golangci-lint run -c $(LINT_FILE_LOCAL) # re-run to make sure fixes are valid, useful in some condition

$(LINT_FILE_LOCAL):
curl -sSL $(LINT_FILE_URL) -o $(LINT_FILE_LOCAL)

# show help
help:
Expand Down
18 changes: 0 additions & 18 deletions handler.go

This file was deleted.

27 changes: 16 additions & 11 deletions porter.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
)

const (
serviceID = "PORTER_SERVICE_ID"
serverNetwork = "SERVER_NETWORK"
serverAddr = "SERVER_ADDRESS"
serverTimeout = "SERVER_TIMEOUT"
Expand Down Expand Up @@ -77,11 +78,11 @@ func (p *Porter) Stop() error {
func NewPorter(
ctx context.Context,
info *porter.GetPorterInformationResponse,
handler Handler,
service porter.LibrarianPorterServiceServer,
options ...PorterOption,
) (*Porter, error) {
if handler == nil {
return nil, errors.New("handler is nil")
if service == nil {
return nil, errors.New("service is nil")
}
if info.GetBinarySummary() == nil {
return nil, errors.New("binary summary is nil")
Expand Down Expand Up @@ -112,13 +113,13 @@ func NewPorter(
return nil, err
}
c := serviceWrapper{
Handler: handler,
Info: info,
Logger: p.logger,
Client: client,
RequireToken: p.requireAsUser,
Token: nil,
lastHeartbeat: time.Time{},
LibrarianPorterServiceServer: service,
Info: info,
Logger: p.logger,
Client: client,
RequireToken: p.requireAsUser,
Token: nil,
lastHeartbeat: time.Time{},
}
p.wrapper = c
p.server = NewServer(
Expand All @@ -128,8 +129,12 @@ func NewPorter(
)
id, _ := os.Hostname()
name := "porter"
id = fmt.Sprintf("%s-%s-%s", id, name, info.GetBinarySummary().GetName())
if customID, exist := os.LookupEnv(serviceID); exist {
id = fmt.Sprintf("%s-%s", id, customID)
}
app := kratos.New(
kratos.ID(id+name),
kratos.ID(id),
kratos.Name(name),
kratos.Version(p.wrapper.Info.GetBinarySummary().GetBuildVersion()),
kratos.Metadata(map[string]string{
Expand Down
18 changes: 17 additions & 1 deletion utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package tuihub

import "github.com/invopop/jsonschema"
import (
"github.com/invopop/jsonschema"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func ReflectJSONSchema(v interface{}) (string, error) {
r := new(jsonschema.Reflector)
Expand All @@ -20,3 +24,15 @@ func MustReflectJSONSchema(v interface{}) string {
}
return j
}

// isUnimplementedError checks if the error is a gRPC unimplemented error.
func isUnimplementedError(err error) bool {
if err == nil {
return false
}
st, ok := status.FromError(err)
if !ok {
return false
}
return st.Code() == codes.Unimplemented
}
106 changes: 54 additions & 52 deletions wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (
)

type serviceWrapper struct {
Handler Handler
pb.LibrarianPorterServiceServer
Info *pb.GetPorterInformationResponse
Logger log.Logger
Client sephirah.LibrarianSephirahServiceClient
Expand All @@ -46,41 +46,45 @@ func (s *serviceWrapper) GetPorterInformation(ctx context.Context, req *pb.GetPo
}
func (s *serviceWrapper) EnablePorter(ctx context.Context, req *pb.EnablePorterRequest) (
*pb.EnablePorterResponse, error) {
if s.Token != nil {
if s.Token.enabler == req.GetSephirahId() {
return &pb.EnablePorterResponse{
StatusMessage: "",
NeedRefreshToken: false,
EnablesSummary: nil,
}, nil
} else if s.lastHeartbeat.Add(defaultHeartbeatTimeout).After(time.Now()) {
return nil, fmt.Errorf("porter already enabled by %d", s.Token.enabler)
needRefreshToken := false
f := func() error {
if s.Token != nil {
if s.Token.enabler == req.GetSephirahId() {
return nil
} else if s.lastHeartbeat.Add(defaultHeartbeatTimeout).After(time.Now()) {
return fmt.Errorf("porter already enabled by %d", s.Token.enabler)
}
}
}
s.lastHeartbeat = time.Now()
if s.RequireToken {
ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+req.GetRefreshToken())
resp, err := s.Client.RefreshToken(ctx, new(sephirah.RefreshTokenRequest))
if err == nil {
s.Token = new(tokenInfo)
s.Token.enabler = req.GetSephirahId()
s.lastHeartbeat = time.Now()
if s.RequireToken {
if req.GetRefreshToken() == "" {
needRefreshToken = true
return nil
}
ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+req.GetRefreshToken())
resp, err := s.Client.RefreshToken(ctx, new(sephirah.RefreshTokenRequest))
if err != nil {
return err
}
s.Token = &tokenInfo{
enabler: req.GetSephirahId(),
AccessToken: resp.GetAccessToken(),
refreshToken: resp.GetRefreshToken(),
}
return &pb.EnablePorterResponse{
StatusMessage: "",
NeedRefreshToken: false,
EnablesSummary: nil,
}, nil
}
return nil
}
if err := f(); err != nil {
return nil, err
}
if resp, err := s.LibrarianPorterServiceServer.EnablePorter(ctx, req); isUnimplementedError(err) {
return new(pb.EnablePorterResponse), nil
} else {
resp.NeedRefreshToken = needRefreshToken
return resp, err
}
s.Token = new(tokenInfo)
s.Token.enabler = req.GetSephirahId()
return &pb.EnablePorterResponse{
StatusMessage: "",
NeedRefreshToken: true,
EnablesSummary: nil,
}, nil
}
func (s *serviceWrapper) Enabled() bool {
return s.Token != nil
Expand Down Expand Up @@ -110,44 +114,42 @@ func NewServer(c *ServerConfig, service pb.LibrarianPorterServiceServer, logger
}

type service struct {
pb.UnimplementedLibrarianPorterServiceServer
p serviceWrapper
serviceWrapper
}

func NewService(p serviceWrapper) pb.LibrarianPorterServiceServer {
return &service{
UnimplementedLibrarianPorterServiceServer: pb.UnimplementedLibrarianPorterServiceServer{},
p: p,
p,
}
}

func (s *service) GetPorterInformation(ctx context.Context, req *pb.GetPorterInformationRequest) (
*pb.GetPorterInformationResponse, error) {
return s.p.GetPorterInformation(ctx, req)
return s.serviceWrapper.GetPorterInformation(ctx, req)
}
func (s *service) EnablePorter(ctx context.Context, req *pb.EnablePorterRequest) (
*pb.EnablePorterResponse, error) {
return s.p.EnablePorter(ctx, req)
return s.serviceWrapper.EnablePorter(ctx, req)
}
func (s *service) PullAccount(ctx context.Context, req *pb.PullAccountRequest) (
*pb.PullAccountResponse, error) {
if !s.p.Enabled() {
if !s.serviceWrapper.Enabled() {
return nil, errors.Forbidden("Unauthorized caller", "")
}
if req.GetAccountId() == nil ||
req.GetAccountId().GetPlatform() == "" ||
req.GetAccountId().GetPlatformAccountId() == "" {
return nil, errors.BadRequest("Invalid account id", "")
}
for _, account := range s.p.Info.GetFeatureSummary().GetAccountPlatforms() {
for _, account := range s.serviceWrapper.Info.GetFeatureSummary().GetAccountPlatforms() {
if account.GetId() == req.GetAccountId().GetPlatform() {
return s.p.Handler.PullAccount(ctx, req)
return s.serviceWrapper.LibrarianPorterServiceServer.PullAccount(ctx, req)
}
}
return nil, errors.BadRequest("Unsupported account platform", "")
}
func (s *service) PullAppInfo(ctx context.Context, req *pb.PullAppInfoRequest) (*pb.PullAppInfoResponse, error) {
if !s.p.Enabled() {
if !s.serviceWrapper.Enabled() {
return nil, errors.Forbidden("Unauthorized caller", "")
}
if req.GetAppInfoId() == nil ||
Expand All @@ -156,61 +158,61 @@ func (s *service) PullAppInfo(ctx context.Context, req *pb.PullAppInfoRequest) (
req.GetAppInfoId().GetSourceAppId() == "" {
return nil, errors.BadRequest("Invalid app id", "")
}
for _, source := range s.p.Info.GetFeatureSummary().GetAppInfoSources() {
for _, source := range s.serviceWrapper.Info.GetFeatureSummary().GetAppInfoSources() {
if source.GetId() == req.GetAppInfoId().GetSource() {
return s.p.Handler.PullAppInfo(ctx, req)
return s.serviceWrapper.LibrarianPorterServiceServer.PullAppInfo(ctx, req)
}
}
return nil, errors.BadRequest("Unsupported app source", "")
}
func (s *service) PullAccountAppInfoRelation(ctx context.Context, req *pb.PullAccountAppInfoRelationRequest) (
*pb.PullAccountAppInfoRelationResponse, error) {
if !s.p.Enabled() {
if !s.serviceWrapper.Enabled() {
return nil, errors.Forbidden("Unauthorized caller", "")
}
if req.GetAccountId() == nil ||
req.GetRelationType() == librarian.AccountAppRelationType_ACCOUNT_APP_RELATION_TYPE_UNSPECIFIED ||
req.GetAccountId().GetPlatform() == "" || req.GetAccountId().GetPlatformAccountId() == "" {
return nil, errors.BadRequest("Invalid account id", "")
}
for _, account := range s.p.Info.GetFeatureSummary().GetAccountPlatforms() {
for _, account := range s.serviceWrapper.Info.GetFeatureSummary().GetAccountPlatforms() {
if account.GetId() == req.GetAccountId().GetPlatform() {
return s.p.Handler.PullAccountAppInfoRelation(ctx, req)
return s.serviceWrapper.LibrarianPorterServiceServer.PullAccountAppInfoRelation(ctx, req)
}
}
return nil, errors.BadRequest("Unsupported account", "")
}
func (s *service) SearchAppInfo(ctx context.Context, req *pb.SearchAppInfoRequest) (*pb.SearchAppInfoResponse, error) {
if !s.p.Enabled() {
if !s.serviceWrapper.Enabled() {
return nil, errors.Forbidden("Unauthorized caller", "")
}
if req.GetName() == "" {
return nil, errors.BadRequest("Invalid app name", "")
}
if len(s.p.Info.GetFeatureSummary().GetAppInfoSources()) > 0 {
return s.p.Handler.SearchAppInfo(ctx, req)
if len(s.serviceWrapper.Info.GetFeatureSummary().GetAppInfoSources()) > 0 {
return s.serviceWrapper.LibrarianPorterServiceServer.SearchAppInfo(ctx, req)
}
return nil, errors.BadRequest("Unsupported app source", "")
}
func (s *service) PullFeed(ctx context.Context, req *pb.PullFeedRequest) (*pb.PullFeedResponse, error) {
if !s.p.Enabled() {
if !s.serviceWrapper.Enabled() {
return nil, errors.Forbidden("Unauthorized caller", "")
}
for _, source := range s.p.Info.GetFeatureSummary().GetFeedSources() {
for _, source := range s.serviceWrapper.Info.GetFeatureSummary().GetFeedSources() {
if source.GetId() == req.GetSource().GetId() {
return s.p.Handler.PullFeed(ctx, req)
return s.serviceWrapper.LibrarianPorterServiceServer.PullFeed(ctx, req)
}
}
return nil, errors.BadRequest("Unsupported feed source", "")
}
func (s *service) PushFeedItems(ctx context.Context, req *pb.PushFeedItemsRequest) (
*pb.PushFeedItemsResponse, error) {
if !s.p.Enabled() {
if !s.serviceWrapper.Enabled() {
return nil, errors.Forbidden("Unauthorized caller", "")
}
for _, destination := range s.p.Info.GetFeatureSummary().GetNotifyDestinations() {
for _, destination := range s.serviceWrapper.Info.GetFeatureSummary().GetNotifyDestinations() {
if destination.GetId() == req.GetDestination().GetId() {
return s.p.Handler.PushFeedItems(ctx, req)
return s.serviceWrapper.LibrarianPorterServiceServer.PushFeedItems(ctx, req)
}
}
return nil, errors.BadRequest("Unsupported notify destination", "")
Expand Down

0 comments on commit f21d4d8

Please sign in to comment.