diff --git a/.gitignore b/.gitignore index 9970888..b769b3e 100644 --- a/.gitignore +++ b/.gitignore @@ -37,4 +37,6 @@ bin/ dist/ .release-env -*.db \ No newline at end of file +*.db + +.golangci.yml \ No newline at end of file diff --git a/Makefile b/Makefile index e141dcf..052442f 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,7 @@ SHELL:=/bin/bash LINT_FILE_TAG=v0.2.0 LINT_FILE_URL=https://raw.githubusercontent.com/tuihub/librarian/$(LINT_FILE_TAG)/.golangci.yml +LINT_FILE_LOCAL=.golangci.yml .PHONY: init # init env @@ -9,9 +10,12 @@ init: .PHONY: lint # lint files -lint: - golangci-lint run --fix -c <(curl -sSL $(LINT_FILE_URL)) - golangci-lint run -c <(curl -sSL $(LINT_FILE_URL)) # re-run to make sure fixes are valid, useful in some condition +lint: $(LINT_FILE_LOCAL) + golangci-lint run --fix -c $(LINT_FILE_LOCAL) + golangci-lint run -c $(LINT_FILE_LOCAL) # re-run to make sure fixes are valid, useful in some condition + +$(LINT_FILE_LOCAL): + curl -sSL $(LINT_FILE_URL) -o $(LINT_FILE_LOCAL) # show help help: diff --git a/handler.go b/handler.go deleted file mode 100644 index 4dfa13c..0000000 --- a/handler.go +++ /dev/null @@ -1,18 +0,0 @@ -package tuihub - -import ( - "context" - - pb "github.com/tuihub/protos/pkg/librarian/porter/v1" -) - -type Handler interface { - PullAccount(context.Context, *pb.PullAccountRequest) (*pb.PullAccountResponse, error) - PullAppInfo(context.Context, *pb.PullAppInfoRequest) (*pb.PullAppInfoResponse, error) - PullAccountAppInfoRelation( - context.Context, *pb.PullAccountAppInfoRelationRequest, - ) (*pb.PullAccountAppInfoRelationResponse, error) - SearchAppInfo(context.Context, *pb.SearchAppInfoRequest) (*pb.SearchAppInfoResponse, error) - PullFeed(context.Context, *pb.PullFeedRequest) (*pb.PullFeedResponse, error) - PushFeedItems(context.Context, *pb.PushFeedItemsRequest) (*pb.PushFeedItemsResponse, error) -} diff --git a/porter.go b/porter.go index 55f427e..2ef4ac6 100644 --- a/porter.go +++ b/porter.go @@ -22,6 +22,7 @@ import ( ) const ( + serviceID = "PORTER_SERVICE_ID" serverNetwork = "SERVER_NETWORK" serverAddr = "SERVER_ADDRESS" serverTimeout = "SERVER_TIMEOUT" @@ -77,11 +78,11 @@ func (p *Porter) Stop() error { func NewPorter( ctx context.Context, info *porter.GetPorterInformationResponse, - handler Handler, + service porter.LibrarianPorterServiceServer, options ...PorterOption, ) (*Porter, error) { - if handler == nil { - return nil, errors.New("handler is nil") + if service == nil { + return nil, errors.New("service is nil") } if info.GetBinarySummary() == nil { return nil, errors.New("binary summary is nil") @@ -112,13 +113,13 @@ func NewPorter( return nil, err } c := serviceWrapper{ - Handler: handler, - Info: info, - Logger: p.logger, - Client: client, - RequireToken: p.requireAsUser, - Token: nil, - lastHeartbeat: time.Time{}, + LibrarianPorterServiceServer: service, + Info: info, + Logger: p.logger, + Client: client, + RequireToken: p.requireAsUser, + Token: nil, + lastHeartbeat: time.Time{}, } p.wrapper = c p.server = NewServer( @@ -128,8 +129,12 @@ func NewPorter( ) id, _ := os.Hostname() name := "porter" + id = fmt.Sprintf("%s-%s-%s", id, name, info.GetBinarySummary().GetName()) + if customID, exist := os.LookupEnv(serviceID); exist { + id = fmt.Sprintf("%s-%s", id, customID) + } app := kratos.New( - kratos.ID(id+name), + kratos.ID(id), kratos.Name(name), kratos.Version(p.wrapper.Info.GetBinarySummary().GetBuildVersion()), kratos.Metadata(map[string]string{ diff --git a/utils.go b/utils.go index af02025..fdcfb5d 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,10 @@ package tuihub -import "github.com/invopop/jsonschema" +import ( + "github.com/invopop/jsonschema" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) func ReflectJSONSchema(v interface{}) (string, error) { r := new(jsonschema.Reflector) @@ -20,3 +24,15 @@ func MustReflectJSONSchema(v interface{}) string { } return j } + +// isUnimplementedError checks if the error is a gRPC unimplemented error. +func isUnimplementedError(err error) bool { + if err == nil { + return false + } + st, ok := status.FromError(err) + if !ok { + return false + } + return st.Code() == codes.Unimplemented +} diff --git a/wrapper.go b/wrapper.go index a7a4101..cc84913 100644 --- a/wrapper.go +++ b/wrapper.go @@ -24,7 +24,7 @@ const ( ) type serviceWrapper struct { - Handler Handler + pb.LibrarianPorterServiceServer Info *pb.GetPorterInformationResponse Logger log.Logger Client sephirah.LibrarianSephirahServiceClient @@ -46,41 +46,45 @@ func (s *serviceWrapper) GetPorterInformation(ctx context.Context, req *pb.GetPo } func (s *serviceWrapper) EnablePorter(ctx context.Context, req *pb.EnablePorterRequest) ( *pb.EnablePorterResponse, error) { - if s.Token != nil { - if s.Token.enabler == req.GetSephirahId() { - return &pb.EnablePorterResponse{ - StatusMessage: "", - NeedRefreshToken: false, - EnablesSummary: nil, - }, nil - } else if s.lastHeartbeat.Add(defaultHeartbeatTimeout).After(time.Now()) { - return nil, fmt.Errorf("porter already enabled by %d", s.Token.enabler) + needRefreshToken := false + f := func() error { + if s.Token != nil { + if s.Token.enabler == req.GetSephirahId() { + return nil + } else if s.lastHeartbeat.Add(defaultHeartbeatTimeout).After(time.Now()) { + return fmt.Errorf("porter already enabled by %d", s.Token.enabler) + } } - } - s.lastHeartbeat = time.Now() - if s.RequireToken { - ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+req.GetRefreshToken()) - resp, err := s.Client.RefreshToken(ctx, new(sephirah.RefreshTokenRequest)) - if err == nil { + s.Token = new(tokenInfo) + s.Token.enabler = req.GetSephirahId() + s.lastHeartbeat = time.Now() + if s.RequireToken { + if req.GetRefreshToken() == "" { + needRefreshToken = true + return nil + } + ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+req.GetRefreshToken()) + resp, err := s.Client.RefreshToken(ctx, new(sephirah.RefreshTokenRequest)) + if err != nil { + return err + } s.Token = &tokenInfo{ enabler: req.GetSephirahId(), AccessToken: resp.GetAccessToken(), refreshToken: resp.GetRefreshToken(), } - return &pb.EnablePorterResponse{ - StatusMessage: "", - NeedRefreshToken: false, - EnablesSummary: nil, - }, nil } + return nil + } + if err := f(); err != nil { + return nil, err + } + if resp, err := s.LibrarianPorterServiceServer.EnablePorter(ctx, req); isUnimplementedError(err) { + return new(pb.EnablePorterResponse), nil + } else { + resp.NeedRefreshToken = needRefreshToken + return resp, err } - s.Token = new(tokenInfo) - s.Token.enabler = req.GetSephirahId() - return &pb.EnablePorterResponse{ - StatusMessage: "", - NeedRefreshToken: true, - EnablesSummary: nil, - }, nil } func (s *serviceWrapper) Enabled() bool { return s.Token != nil @@ -110,28 +114,26 @@ func NewServer(c *ServerConfig, service pb.LibrarianPorterServiceServer, logger } type service struct { - pb.UnimplementedLibrarianPorterServiceServer - p serviceWrapper + serviceWrapper } func NewService(p serviceWrapper) pb.LibrarianPorterServiceServer { return &service{ - UnimplementedLibrarianPorterServiceServer: pb.UnimplementedLibrarianPorterServiceServer{}, - p: p, + p, } } func (s *service) GetPorterInformation(ctx context.Context, req *pb.GetPorterInformationRequest) ( *pb.GetPorterInformationResponse, error) { - return s.p.GetPorterInformation(ctx, req) + return s.serviceWrapper.GetPorterInformation(ctx, req) } func (s *service) EnablePorter(ctx context.Context, req *pb.EnablePorterRequest) ( *pb.EnablePorterResponse, error) { - return s.p.EnablePorter(ctx, req) + return s.serviceWrapper.EnablePorter(ctx, req) } func (s *service) PullAccount(ctx context.Context, req *pb.PullAccountRequest) ( *pb.PullAccountResponse, error) { - if !s.p.Enabled() { + if !s.serviceWrapper.Enabled() { return nil, errors.Forbidden("Unauthorized caller", "") } if req.GetAccountId() == nil || @@ -139,15 +141,15 @@ func (s *service) PullAccount(ctx context.Context, req *pb.PullAccountRequest) ( req.GetAccountId().GetPlatformAccountId() == "" { return nil, errors.BadRequest("Invalid account id", "") } - for _, account := range s.p.Info.GetFeatureSummary().GetAccountPlatforms() { + for _, account := range s.serviceWrapper.Info.GetFeatureSummary().GetAccountPlatforms() { if account.GetId() == req.GetAccountId().GetPlatform() { - return s.p.Handler.PullAccount(ctx, req) + return s.serviceWrapper.LibrarianPorterServiceServer.PullAccount(ctx, req) } } return nil, errors.BadRequest("Unsupported account platform", "") } func (s *service) PullAppInfo(ctx context.Context, req *pb.PullAppInfoRequest) (*pb.PullAppInfoResponse, error) { - if !s.p.Enabled() { + if !s.serviceWrapper.Enabled() { return nil, errors.Forbidden("Unauthorized caller", "") } if req.GetAppInfoId() == nil || @@ -156,16 +158,16 @@ func (s *service) PullAppInfo(ctx context.Context, req *pb.PullAppInfoRequest) ( req.GetAppInfoId().GetSourceAppId() == "" { return nil, errors.BadRequest("Invalid app id", "") } - for _, source := range s.p.Info.GetFeatureSummary().GetAppInfoSources() { + for _, source := range s.serviceWrapper.Info.GetFeatureSummary().GetAppInfoSources() { if source.GetId() == req.GetAppInfoId().GetSource() { - return s.p.Handler.PullAppInfo(ctx, req) + return s.serviceWrapper.LibrarianPorterServiceServer.PullAppInfo(ctx, req) } } return nil, errors.BadRequest("Unsupported app source", "") } func (s *service) PullAccountAppInfoRelation(ctx context.Context, req *pb.PullAccountAppInfoRelationRequest) ( *pb.PullAccountAppInfoRelationResponse, error) { - if !s.p.Enabled() { + if !s.serviceWrapper.Enabled() { return nil, errors.Forbidden("Unauthorized caller", "") } if req.GetAccountId() == nil || @@ -173,44 +175,44 @@ func (s *service) PullAccountAppInfoRelation(ctx context.Context, req *pb.PullAc req.GetAccountId().GetPlatform() == "" || req.GetAccountId().GetPlatformAccountId() == "" { return nil, errors.BadRequest("Invalid account id", "") } - for _, account := range s.p.Info.GetFeatureSummary().GetAccountPlatforms() { + for _, account := range s.serviceWrapper.Info.GetFeatureSummary().GetAccountPlatforms() { if account.GetId() == req.GetAccountId().GetPlatform() { - return s.p.Handler.PullAccountAppInfoRelation(ctx, req) + return s.serviceWrapper.LibrarianPorterServiceServer.PullAccountAppInfoRelation(ctx, req) } } return nil, errors.BadRequest("Unsupported account", "") } func (s *service) SearchAppInfo(ctx context.Context, req *pb.SearchAppInfoRequest) (*pb.SearchAppInfoResponse, error) { - if !s.p.Enabled() { + if !s.serviceWrapper.Enabled() { return nil, errors.Forbidden("Unauthorized caller", "") } if req.GetName() == "" { return nil, errors.BadRequest("Invalid app name", "") } - if len(s.p.Info.GetFeatureSummary().GetAppInfoSources()) > 0 { - return s.p.Handler.SearchAppInfo(ctx, req) + if len(s.serviceWrapper.Info.GetFeatureSummary().GetAppInfoSources()) > 0 { + return s.serviceWrapper.LibrarianPorterServiceServer.SearchAppInfo(ctx, req) } return nil, errors.BadRequest("Unsupported app source", "") } func (s *service) PullFeed(ctx context.Context, req *pb.PullFeedRequest) (*pb.PullFeedResponse, error) { - if !s.p.Enabled() { + if !s.serviceWrapper.Enabled() { return nil, errors.Forbidden("Unauthorized caller", "") } - for _, source := range s.p.Info.GetFeatureSummary().GetFeedSources() { + for _, source := range s.serviceWrapper.Info.GetFeatureSummary().GetFeedSources() { if source.GetId() == req.GetSource().GetId() { - return s.p.Handler.PullFeed(ctx, req) + return s.serviceWrapper.LibrarianPorterServiceServer.PullFeed(ctx, req) } } return nil, errors.BadRequest("Unsupported feed source", "") } func (s *service) PushFeedItems(ctx context.Context, req *pb.PushFeedItemsRequest) ( *pb.PushFeedItemsResponse, error) { - if !s.p.Enabled() { + if !s.serviceWrapper.Enabled() { return nil, errors.Forbidden("Unauthorized caller", "") } - for _, destination := range s.p.Info.GetFeatureSummary().GetNotifyDestinations() { + for _, destination := range s.serviceWrapper.Info.GetFeatureSummary().GetNotifyDestinations() { if destination.GetId() == req.GetDestination().GetId() { - return s.p.Handler.PushFeedItems(ctx, req) + return s.serviceWrapper.LibrarianPorterServiceServer.PushFeedItems(ctx, req) } } return nil, errors.BadRequest("Unsupported notify destination", "")