diff --git a/.mockery.yaml b/.mockery.yaml index 630d57a4..b283d57e 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -17,6 +17,10 @@ packages: SpaceComponent: RuntimeArchitectureComponent: SensitiveComponent: + opencsg.com/csghub-server/component/callback: + config: + interfaces: + SyncVersionGenerator: opencsg.com/csghub-server/user/component: config: interfaces: diff --git a/Makefile b/Makefile index 21c7b05a..739643b5 100644 --- a/Makefile +++ b/Makefile @@ -13,10 +13,12 @@ cover: mock_wire: @echo "Running wire for component mocks..." - @go run -mod=mod github.com/google/wire/cmd/wire opencsg.com/csghub-server/component + @go run -mod=mod github.com/google/wire/cmd/wire opencsg.com/csghub-server/component/... @if [ $$? -eq 0 ]; then \ - echo "Renaming wire_gen.go to wire_gen_test.go..."; \ + echo "Renaming component wire_gen.go to wire_gen_test.go..."; \ mv component/wire_gen.go component/wire_gen_test.go; \ + echo "Renaming component/callback wire_gen.go to wire_gen_test.go..."; \ + mv component/callback/wire_gen.go component/callback/wire_gen_test.go; \ else \ echo "Wire failed, skipping renaming."; \ fi diff --git a/_mocks/opencsg.com/csghub-server/component/callback/mock_SyncVersionGenerator.go b/_mocks/opencsg.com/csghub-server/component/callback/mock_SyncVersionGenerator.go new file mode 100644 index 00000000..ffd72a5e --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/callback/mock_SyncVersionGenerator.go @@ -0,0 +1,81 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package callback + +import ( + mock "github.com/stretchr/testify/mock" + types "opencsg.com/csghub-server/common/types" +) + +// MockSyncVersionGenerator is an autogenerated mock type for the SyncVersionGenerator type +type MockSyncVersionGenerator struct { + mock.Mock +} + +type MockSyncVersionGenerator_Expecter struct { + mock *mock.Mock +} + +func (_m *MockSyncVersionGenerator) EXPECT() *MockSyncVersionGenerator_Expecter { + return &MockSyncVersionGenerator_Expecter{mock: &_m.Mock} +} + +// GenSyncVersion provides a mock function with given fields: req +func (_m *MockSyncVersionGenerator) GenSyncVersion(req *types.GiteaCallbackPushReq) error { + ret := _m.Called(req) + + if len(ret) == 0 { + panic("no return value specified for GenSyncVersion") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*types.GiteaCallbackPushReq) error); ok { + r0 = rf(req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockSyncVersionGenerator_GenSyncVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GenSyncVersion' +type MockSyncVersionGenerator_GenSyncVersion_Call struct { + *mock.Call +} + +// GenSyncVersion is a helper method to define mock.On call +// - req *types.GiteaCallbackPushReq +func (_e *MockSyncVersionGenerator_Expecter) GenSyncVersion(req interface{}) *MockSyncVersionGenerator_GenSyncVersion_Call { + return &MockSyncVersionGenerator_GenSyncVersion_Call{Call: _e.mock.On("GenSyncVersion", req)} +} + +func (_c *MockSyncVersionGenerator_GenSyncVersion_Call) Run(run func(req *types.GiteaCallbackPushReq)) *MockSyncVersionGenerator_GenSyncVersion_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*types.GiteaCallbackPushReq)) + }) + return _c +} + +func (_c *MockSyncVersionGenerator_GenSyncVersion_Call) Return(_a0 error) *MockSyncVersionGenerator_GenSyncVersion_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSyncVersionGenerator_GenSyncVersion_Call) RunAndReturn(run func(*types.GiteaCallbackPushReq) error) *MockSyncVersionGenerator_GenSyncVersion_Call { + _c.Call.Return(run) + return _c +} + +// NewMockSyncVersionGenerator creates a new instance of MockSyncVersionGenerator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockSyncVersionGenerator(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSyncVersionGenerator { + mock := &MockSyncVersionGenerator{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/api/handler/callback/git_callback.go b/api/handler/callback/git_callback.go index c0d62056..1fda23f8 100644 --- a/api/handler/callback/git_callback.go +++ b/api/handler/callback/git_callback.go @@ -14,7 +14,7 @@ import ( ) type GitCallbackHandler struct { - cbc *component.GitCallbackComponent + cbc component.GitCallbackComponent config *config.Config } diff --git a/common/tests/stores.go b/common/tests/stores.go index a091a286..507a7157 100644 --- a/common/tests/stores.go +++ b/common/tests/stores.go @@ -14,6 +14,7 @@ type MockStores struct { Model database.ModelStore SpaceResource database.SpaceResourceStore Tag database.TagStore + TagRule database.TagRuleStore Dataset database.DatasetStore PromptConversation database.PromptConversationStore PromptPrefix database.PromptPrefixStore @@ -94,6 +95,7 @@ func NewMockStores(t interface { Telemetry: mockdb.NewMockTelemetryStore(t), RepoFile: mockdb.NewMockRepoFileStore(t), Event: mockdb.NewMockEventStore(t), + TagRule: mockdb.NewMockTagRuleStore(t), } } @@ -125,6 +127,10 @@ func (s *MockStores) TagMock() *mockdb.MockTagStore { return s.Tag.(*mockdb.MockTagStore) } +func (s *MockStores) TagRuleMock() *mockdb.MockTagRuleStore { + return s.TagRule.(*mockdb.MockTagRuleStore) +} + func (s *MockStores) DatasetMock() *mockdb.MockDatasetStore { return s.Dataset.(*mockdb.MockDatasetStore) } diff --git a/common/types/prompt.go b/common/types/prompt.go index 31522bbe..7ad62e90 100644 --- a/common/types/prompt.go +++ b/common/types/prompt.go @@ -1,6 +1,8 @@ package types -import "time" +import ( + "time" +) type PromptReq struct { Namespace string `json:"namespace"` @@ -101,3 +103,31 @@ type PromptRes struct { CanManage bool `json:"can_manage"` Namespace *Namespace `json:"namespace"` } + +type Prompt struct { + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + Language string `json:"language" binding:"required"` + Tags []string `json:"tags"` + Type string `json:"type"` // "text|image|video|audio" + Source string `json:"source"` + Author string `json:"author"` + Time string `json:"time"` + Copyright string `json:"copyright"` + Feedback []string `json:"feedback"` +} + +type PromptOutput struct { + Prompt + FilePath string `json:"file_path"` + CanWrite bool `json:"can_write"` + CanManage bool `json:"can_manage"` +} + +type CreatePromptReq struct { + Prompt +} + +type UpdatePromptReq struct { + Prompt +} diff --git a/component/callback/git_callback.go b/component/callback/git_callback.go index b1bddcd4..0810f1e4 100644 --- a/component/callback/git_callback.go +++ b/component/callback/git_callback.go @@ -21,33 +21,39 @@ import ( "opencsg.com/csghub-server/component" ) -// define GitCallbackComponent struct -type GitCallbackComponent struct { - config *config.Config - gs gitserver.GitServer - tc component.TagComponent - modSvcClient rpc.ModerationSvcClient - ms database.ModelStore - ds database.DatasetStore - sc component.SpaceComponent - ss database.SpaceStore - rs database.RepoStore - rrs database.RepoRelationsStore - mirrorStore database.MirrorStore - rrf database.RepositoriesRuntimeFrameworkStore - rac component.RuntimeArchitectureComponent - ras database.RuntimeArchitecturesStore - rfs database.RuntimeFrameworksStore - ts database.TagStore - dt database.TagRuleStore +type GitCallbackComponent interface { + SetRepoVisibility(yes bool) + WatchSpaceChange(ctx context.Context, req *types.GiteaCallbackPushReq) error + WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq) error + SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq) error + UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq) error +} + +type gitCallbackComponentImpl struct { + config *config.Config + gitServer gitserver.GitServer + tagComponent component.TagComponent + modSvcClient rpc.ModerationSvcClient + modelStore database.ModelStore + datasetStore database.DatasetStore + spaceComponent component.SpaceComponent + spaceStore database.SpaceStore + repoStore database.RepoStore + repoRelationStore database.RepoRelationsStore + mirrorStore database.MirrorStore + repoRuntimeFrameworkStore database.RepositoriesRuntimeFrameworkStore + runtimeArchComponent component.RuntimeArchitectureComponent + runtimeArchStore database.RuntimeArchitecturesStore + runtimeFrameworkStore database.RuntimeFrameworksStore + tagStore database.TagStore + tagRuleStore database.TagRuleStore // set visibility if file content is sensitive setRepoVisibility bool - pp component.PromptComponent maxPromptFS int64 } // new CallbackComponent -func NewGitCallback(config *config.Config) (*GitCallbackComponent, error) { +func NewGitCallback(config *config.Config) (*gitCallbackComponentImpl, error) { gs, err := git.NewGitServer(config) if err != nil { return nil, err @@ -74,45 +80,40 @@ func NewGitCallback(config *config.Config) (*GitCallbackComponent, error) { } rfs := database.NewRuntimeFrameworksStore() ts := database.NewTagStore() - pp, err := component.NewPromptComponent(config) - if err != nil { - return nil, err - } var modSvcClient rpc.ModerationSvcClient if config.SensitiveCheck.Enable { modSvcClient = rpc.NewModerationSvcHttpClient(fmt.Sprintf("%s:%d", config.Moderation.Host, config.Moderation.Port)) } dt := database.NewTagRuleStore() - return &GitCallbackComponent{ - config: config, - gs: gs, - tc: tc, - ms: ms, - ds: ds, - ss: ss, - sc: sc, - rs: rs, - rrs: rrs, - mirrorStore: mirrorStore, - modSvcClient: modSvcClient, - rrf: rrf, - rac: rac, - ras: ras, - rfs: rfs, - pp: pp, - ts: ts, - dt: dt, - maxPromptFS: config.Dataset.PromptMaxJsonlFileSize, + return &gitCallbackComponentImpl{ + config: config, + gitServer: gs, + tagComponent: tc, + modelStore: ms, + datasetStore: ds, + spaceStore: ss, + spaceComponent: sc, + repoStore: rs, + repoRelationStore: rrs, + mirrorStore: mirrorStore, + modSvcClient: modSvcClient, + repoRuntimeFrameworkStore: rrf, + runtimeArchComponent: rac, + runtimeArchStore: ras, + runtimeFrameworkStore: rfs, + tagStore: ts, + tagRuleStore: dt, + maxPromptFS: config.Dataset.PromptMaxJsonlFileSize, }, nil } // SetRepoVisibility sets a flag whether change repo's visibility if file content is sensitive -func (c *GitCallbackComponent) SetRepoVisibility(yes bool) { +func (c *gitCallbackComponentImpl) SetRepoVisibility(yes bool) { c.setRepoVisibility = yes } -func (c *GitCallbackComponent) WatchSpaceChange(ctx context.Context, req *types.GiteaCallbackPushReq) error { - err := WatchSpaceChange(req, c.ss, c.sc).Run() +func (c *gitCallbackComponentImpl) WatchSpaceChange(ctx context.Context, req *types.GiteaCallbackPushReq) error { + err := WatchSpaceChange(req, c.spaceStore, c.spaceComponent).Run() if err != nil { slog.Error("watch space change failed", slog.Any("error", err)) return err @@ -120,8 +121,8 @@ func (c *GitCallbackComponent) WatchSpaceChange(ctx context.Context, req *types. return nil } -func (c *GitCallbackComponent) WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq) error { - err := WatchRepoRelation(req, c.rs, c.rrs, c.gs).Run() +func (c *gitCallbackComponentImpl) WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq) error { + err := WatchRepoRelation(req, c.repoStore, c.repoRelationStore, c.gitServer).Run() if err != nil { slog.Error("watch repo relation failed", slog.Any("error", err)) return err @@ -129,7 +130,7 @@ func (c *GitCallbackComponent) WatchRepoRelation(ctx context.Context, req *types return nil } -func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq) error { +func (c *gitCallbackComponentImpl) SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq) error { // split req.Repository.FullName by '/' splits := strings.Split(req.Repository.FullName, "/") fullNamespace, repoName := splits[0], splits[1] @@ -138,7 +139,7 @@ func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - isMirrorRepo, err := c.rs.IsMirrorRepo(ctx, adjustedRepoType, namespace, repoName) + isMirrorRepo, err := c.repoStore.IsMirrorRepo(ctx, adjustedRepoType, namespace, repoName) if err != nil { slog.Error("failed to check if a mirror repo", slog.Any("error", err), slog.String("repo_type", string(adjustedRepoType)), slog.String("namespace", namespace), slog.String("name", repoName)) return err @@ -149,7 +150,7 @@ func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types slog.Error("Error parsing time:", slog.Any("error", err), slog.String("timestamp", req.HeadCommit.Timestamp)) return err } - err = c.rs.SetUpdateTimeByPath(ctx, adjustedRepoType, namespace, repoName, updated) + err = c.repoStore.SetUpdateTimeByPath(ctx, adjustedRepoType, namespace, repoName, updated) if err != nil { slog.Error("failed to set repo update time", slog.Any("error", err), slog.String("repo_type", string(adjustedRepoType)), slog.String("namespace", namespace), slog.String("name", repoName)) return err @@ -166,7 +167,7 @@ func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types return err } } else { - err := c.rs.SetUpdateTimeByPath(ctx, adjustedRepoType, namespace, repoName, time.Now()) + err := c.repoStore.SetUpdateTimeByPath(ctx, adjustedRepoType, namespace, repoName, time.Now()) if err != nil { slog.Error("failed to set repo update time", slog.Any("error", err), slog.String("repo_type", string(adjustedRepoType)), slog.String("namespace", namespace), slog.String("name", repoName)) return err @@ -175,7 +176,7 @@ func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types return nil } -func (c *GitCallbackComponent) UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq) error { +func (c *gitCallbackComponentImpl) UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq) error { commits := req.Commits ref := req.Ref // split req.Repository.FullName by '/' @@ -193,7 +194,7 @@ func (c *GitCallbackComponent) UpdateRepoInfos(ctx context.Context, req *types.G return err } -func (c *GitCallbackComponent) SensitiveCheck(ctx context.Context, req *types.GiteaCallbackPushReq) error { +func (c *gitCallbackComponentImpl) SensitiveCheck(ctx context.Context, req *types.GiteaCallbackPushReq) error { // split req.Repository.FullName by '/' splits := strings.Split(req.Repository.FullName, "/") fullNamespace, repoName := splits[0], splits[1] @@ -208,11 +209,12 @@ func (c *GitCallbackComponent) SensitiveCheck(ctx context.Context, req *types.Gi slog.Error("fail to submit repo sensitive check", slog.Any("error", err), slog.Any("repo_type", adjustedRepoType), slog.String("namespace", namespace), slog.String("name", repoName)) return err } + return nil } // modifyFiles method handles modified files, skip if not modify README.md -func (c *GitCallbackComponent) modifyFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { +func (c *gitCallbackComponentImpl) modifyFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { for _, fileName := range fileNames { slog.Debug("modify file", slog.String("file", fileName)) // update model runtime @@ -232,7 +234,7 @@ func (c *GitCallbackComponent) modifyFiles(ctx context.Context, repoType, namesp return nil } -func (c *GitCallbackComponent) removeFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { +func (c *gitCallbackComponentImpl) removeFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { // handle removed files // delete tags for _, fileName := range fileNames { @@ -244,7 +246,7 @@ func (c *GitCallbackComponent) removeFiles(ctx context.Context, repoType, namesp // use empty content to clear all the meta tags const content string = "" adjustedRepoType := types.RepositoryType(strings.TrimSuffix(repoType, "s")) - err := c.tc.ClearMetaTags(ctx, adjustedRepoType, namespace, repoName) + err := c.tagComponent.ClearMetaTags(ctx, adjustedRepoType, namespace, repoName) if err != nil { slog.Error("failed to clear meta tags", slog.String("content", content), slog.String("repo", path.Join(namespace, repoName)), slog.String("ref", ref), @@ -267,7 +269,7 @@ func (c *GitCallbackComponent) removeFiles(ctx context.Context, repoType, namesp // case SpaceRepoType: // tagScope = database.SpaceTagScope } - err := c.tc.UpdateLibraryTags(ctx, tagScope, namespace, repoName, fileName, "") + err := c.tagComponent.UpdateLibraryTags(ctx, tagScope, namespace, repoName, fileName, "") if err != nil { slog.Error("failed to remove Library tag", slog.String("namespace", namespace), slog.String("name", repoName), slog.String("ref", ref), slog.String("fileName", fileName), @@ -279,7 +281,7 @@ func (c *GitCallbackComponent) removeFiles(ctx context.Context, repoType, namesp return nil } -func (c *GitCallbackComponent) addFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { +func (c *gitCallbackComponentImpl) addFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { for _, fileName := range fileNames { slog.Debug("add file", slog.String("file", fileName)) // update model runtime @@ -310,7 +312,7 @@ func (c *GitCallbackComponent) addFiles(ctx context.Context, repoType, namespace // case SpaceRepoType: // tagScope = database.SpaceTagScope } - err := c.tc.UpdateLibraryTags(ctx, tagScope, namespace, repoName, "", fileName) + err := c.tagComponent.UpdateLibraryTags(ctx, tagScope, namespace, repoName, "", fileName) if err != nil { slog.Error("failed to add Library tag", slog.String("namespace", namespace), slog.String("name", repoName), slog.String("ref", ref), slog.String("fileName", fileName), @@ -322,7 +324,7 @@ func (c *GitCallbackComponent) addFiles(ctx context.Context, repoType, namespace return nil } -func (c *GitCallbackComponent) updateMetaTags(ctx context.Context, repoType, namespace, repoName, ref, content string) error { +func (c *gitCallbackComponentImpl) updateMetaTags(ctx context.Context, repoType, namespace, repoName, ref, content string) error { var ( err error tagScope database.TagScope @@ -342,7 +344,7 @@ func (c *GitCallbackComponent) updateMetaTags(ctx context.Context, repoType, nam // case SpaceRepoType: // tagScope = database.SpaceTagScope } - _, err = c.tc.UpdateMetaTags(ctx, tagScope, namespace, repoName, content) + _, err = c.tagComponent.UpdateMetaTags(ctx, tagScope, namespace, repoName, content) if err != nil { slog.Error("failed to update meta tags", slog.String("namespace", namespace), slog.String("content", content), slog.String("repo", repoName), slog.String("ref", ref), @@ -353,7 +355,7 @@ func (c *GitCallbackComponent) updateMetaTags(ctx context.Context, repoType, nam return nil } -func (c *GitCallbackComponent) getFileRaw(repoType, namespace, repoName, ref, fileName string) (string, error) { +func (c *gitCallbackComponentImpl) getFileRaw(repoType, namespace, repoName, ref, fileName string) (string, error) { var ( content string err error @@ -366,7 +368,7 @@ func (c *GitCallbackComponent) getFileRaw(repoType, namespace, repoName, ref, fi Path: fileName, RepoType: types.RepositoryType(repoType), } - content, err = c.gs.GetRepoFileRaw(context.Background(), getFileRawReq) + content, err = c.gitServer.GetRepoFileRaw(context.Background(), getFileRawReq) if err != nil { slog.Error("failed to get file content", slog.String("namespace", namespace), slog.String("file", fileName), slog.String("repo", repoName), slog.String("ref", ref), @@ -380,7 +382,7 @@ func (c *GitCallbackComponent) getFileRaw(repoType, namespace, repoName, ref, fi } // update repo relations -func (c *GitCallbackComponent) updateRepoRelations(ctx context.Context, repoType, namespace, repoName, ref, fileName string, deleteAction bool, fileNames []string) { +func (c *gitCallbackComponentImpl) updateRepoRelations(ctx context.Context, repoType, namespace, repoName, ref, fileName string, deleteAction bool, fileNames []string) { slog.Debug("update model relation for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("repoType", repoType), slog.Any("fileName", fileName), slog.Any("branch", ref)) if repoType == fmt.Sprintf("%ss", types.ModelRepo) { c.updateModelRuntimeFrameworks(ctx, repoType, namespace, repoName, ref, fileName, deleteAction) @@ -391,19 +393,19 @@ func (c *GitCallbackComponent) updateRepoRelations(ctx context.Context, repoType } // update dataset tags for evaluation -func (c *GitCallbackComponent) updateDatasetTags(ctx context.Context, namespace, repoName string, fileNames []string) { +func (c *gitCallbackComponentImpl) updateDatasetTags(ctx context.Context, namespace, repoName string, fileNames []string) { // script dataset repo was not supported so far scriptName := fmt.Sprintf("%s.py", repoName) if slices.Contains(fileNames, scriptName) { return } - repo, err := c.rs.FindByPath(ctx, types.DatasetRepo, namespace, repoName) + repo, err := c.repoStore.FindByPath(ctx, types.DatasetRepo, namespace, repoName) if err != nil || repo == nil { slog.Warn("fail to query repo for in callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("error", err)) return } // check if it's evaluation dataset - evalDataset, err := c.dt.FindByRepo(ctx, string(types.EvaluationCategory), namespace, repoName, string(types.DatasetRepo)) + evalDataset, err := c.tagRuleStore.FindByRepo(ctx, string(types.EvaluationCategory), namespace, repoName, string(types.DatasetRepo)) if err != nil { if errors.Is(err, sql.ErrNoRows) { // check if it's a mirror repo @@ -415,7 +417,7 @@ func (c *GitCallbackComponent) updateDatasetTags(ctx context.Context, namespace, namespace := strings.Split(mirror.SourceRepoPath, "/")[0] name := strings.Split(mirror.SourceRepoPath, "/")[1] // use mirror namespace and name to find dataset - evalDataset, err = c.dt.FindByRepo(ctx, string(types.EvaluationCategory), namespace, name, string(types.DatasetRepo)) + evalDataset, err = c.tagRuleStore.FindByRepo(ctx, string(types.EvaluationCategory), namespace, name, string(types.DatasetRepo)) if err != nil { slog.Debug("not an evaluation dataset, ignore it", slog.Any("repo id", repo.Path)) return @@ -429,13 +431,13 @@ func (c *GitCallbackComponent) updateDatasetTags(ctx context.Context, namespace, tagIds := []int64{} tagIds = append(tagIds, evalDataset.Tag.ID) if evalDataset.RuntimeFramework != "" { - rTag, _ := c.ts.FindTag(ctx, evalDataset.RuntimeFramework, string(types.DatasetRepo), "runtime_framework") + rTag, _ := c.tagStore.FindTag(ctx, evalDataset.RuntimeFramework, string(types.DatasetRepo), "runtime_framework") if rTag != nil { tagIds = append(tagIds, rTag.ID) } } - err = c.ts.UpsertRepoTags(ctx, repo.ID, []int64{}, tagIds) + err = c.tagStore.UpsertRepoTags(ctx, repo.ID, []int64{}, tagIds) if err != nil { slog.Warn("fail to add dataset tag", slog.Any("repoId", repo.ID), slog.Any("tag id", tagIds), slog.Any("error", err)) } @@ -443,39 +445,39 @@ func (c *GitCallbackComponent) updateDatasetTags(ctx context.Context, namespace, } // update model runtime frameworks -func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, repoType, namespace, repoName, ref, fileName string, deleteAction bool) { +func (c *gitCallbackComponentImpl) updateModelRuntimeFrameworks(ctx context.Context, repoType, namespace, repoName, ref, fileName string, deleteAction bool) { // must be model repo and config.json if repoType != fmt.Sprintf("%ss", types.ModelRepo) || fileName != component.ConfigFileName || (ref != ("refs/heads/"+component.MainBranch) && ref != ("refs/heads/"+component.MasterBranch)) { return } - repo, err := c.rs.FindByPath(ctx, types.ModelRepo, namespace, repoName) + repo, err := c.repoStore.FindByPath(ctx, types.ModelRepo, namespace, repoName) if err != nil || repo == nil { slog.Warn("fail to query repo for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("error", err)) return } // delete event if deleteAction { - err := c.rrf.DeleteByRepoID(ctx, repo.ID) + err := c.repoRuntimeFrameworkStore.DeleteByRepoID(ctx, repo.ID) if err != nil { slog.Warn("fail to remove repo runtimes for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("repoid", repo.ID), slog.Any("error", err)) } return } - arch, err := c.rac.GetArchitectureFromConfig(ctx, namespace, repoName) + arch, err := c.runtimeArchComponent.GetArchitectureFromConfig(ctx, namespace, repoName) if err != nil { slog.Warn("fail to get config.json content for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("error", err)) return } slog.Debug("get arch for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("arch", arch)) //add resource tag, like ascend - runtime_framework_tags, _ := c.ts.GetTagsByScopeAndCategories(ctx, "model", []string{"runtime_framework", "resource"}) + runtime_framework_tags, _ := c.tagStore.GetTagsByScopeAndCategories(ctx, "model", []string{"runtime_framework", "resource"}) fields := strings.Split(repo.Path, "/") - err = c.rac.AddResourceTag(ctx, runtime_framework_tags, fields[1], repo.ID) + err = c.runtimeArchComponent.AddResourceTag(ctx, runtime_framework_tags, fields[1], repo.ID) if err != nil { slog.Warn("fail to add resource tag", slog.Any("error", err)) return } - runtimes, err := c.ras.ListByRArchNameAndModel(ctx, arch, fields[1]) + runtimes, err := c.runtimeArchStore.ListByRArchNameAndModel(ctx, arch, fields[1]) // to do check resource models if err != nil { slog.Warn("fail to get runtime ids by arch for git callback", slog.Any("arch", arch), slog.Any("error", err)) @@ -487,7 +489,7 @@ func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, frameIDs = append(frameIDs, runtime.RuntimeFrameworkID) } slog.Debug("get new frame ids for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("frameIDs", frameIDs)) - newFrames, err := c.rfs.ListByIDs(ctx, frameIDs) + newFrames, err := c.runtimeFrameworkStore.ListByIDs(ctx, frameIDs) if err != nil { slog.Warn("fail to get runtime frameworks for git callback", slog.Any("arch", arch), slog.Any("error", err)) return @@ -498,7 +500,7 @@ func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, newFrameMap[strconv.FormatInt(frame.ID, 10)] = strconv.FormatInt(frame.ID, 10) } slog.Debug("get new frame map by arch for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("newFrameMap", newFrameMap)) - oldRepoRuntimes, err := c.rrf.GetByRepoIDs(ctx, repo.ID) + oldRepoRuntimes, err := c.repoRuntimeFrameworkStore.GetByRepoIDs(ctx, repo.ID) if err != nil { slog.Warn("fail to get repo runtimes for git callback", slog.Any("repo.ID", repo.ID), slog.Any("error", err)) return @@ -516,12 +518,12 @@ func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, _, exist := newFrameMap[strconv.FormatInt(old.RuntimeFrameworkID, 10)] if !exist { // remove incorrect relations - err := c.rrf.Delete(ctx, old.RuntimeFrameworkID, repo.ID, old.Type) + err := c.repoRuntimeFrameworkStore.Delete(ctx, old.RuntimeFrameworkID, repo.ID, old.Type) if err != nil { slog.Warn("fail to delete old repo runtimes for git callback", slog.Any("repo.ID", repo.ID), slog.Any("runtime framework id", old.RuntimeFrameworkID), slog.Any("error", err)) } // remove runtime framework tags - c.rac.RemoveRuntimeFrameworkTag(ctx, runtime_framework_tags, repo.ID, old.RuntimeFrameworkID) + c.runtimeArchComponent.RemoveRuntimeFrameworkTag(ctx, runtime_framework_tags, repo.ID, old.RuntimeFrameworkID) } } @@ -531,12 +533,12 @@ func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, _, exist := oldFrameMap[strconv.FormatInt(new.ID, 10)] if !exist { // add new relations - err := c.rrf.Add(ctx, new.ID, repo.ID, new.Type) + err := c.repoRuntimeFrameworkStore.Add(ctx, new.ID, repo.ID, new.Type) if err != nil { slog.Warn("fail to add new repo runtimes for git callback", slog.Any("repo.ID", repo.ID), slog.Any("runtime framework id", new.ID), slog.Any("error", err)) } // add runtime framework and resource tags - err = c.rac.AddRuntimeFrameworkTag(ctx, runtime_framework_tags, repo.ID, new.ID) + err = c.runtimeArchComponent.AddRuntimeFrameworkTag(ctx, runtime_framework_tags, repo.ID, new.ID) if err != nil { slog.Warn("fail to add runtime framework tag for git callback", slog.Any("repo.ID", repo.ID), slog.Any("runtime framework id", new.ID), slog.Any("error", err)) } diff --git a/component/callback/git_callback_test.go b/component/callback/git_callback_test.go new file mode 100644 index 00000000..bafa8d85 --- /dev/null +++ b/component/callback/git_callback_test.go @@ -0,0 +1,187 @@ +package callback + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/component" +) + +func TestGitCallbackComponent_SetRepoVisibility(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitCallbackComponent(ctx, t) + + require.False(t, gc.setRepoVisibility) + gc.SetRepoVisibility(true) + require.True(t, gc.setRepoVisibility) +} + +func TestGitCallbackComponent_WatchSpaceChange(t *testing.T) { + ctx := mock.Anything + gc := initializeTestGitCallbackComponent(context.TODO(), t) + + gc.mocks.stores.SpaceMock().EXPECT().FindByPath(ctx, "b", "c").Return( + &database.Space{HasAppFile: true}, nil, + ) + gc.mocks.spaceComponent.EXPECT().FixHasEntryFile(ctx, &database.Space{ + HasAppFile: true, + }).Return(nil) + gc.mocks.spaceComponent.EXPECT().Deploy(ctx, "b", "c", "b").Return(100, nil) + + err := gc.WatchSpaceChange(context.TODO(), &types.GiteaCallbackPushReq{ + Ref: "main", + Repository: types.GiteaCallbackPushReq_Repository{ + FullName: "spaces_b/c/d", + }, + }) + require.Nil(t, err) +} + +func TestGitCallbackComponent_WatchRepoRelation(t *testing.T) { + ctx := mock.Anything + gc := initializeTestGitCallbackComponent(context.TODO(), t) + + gc.mocks.gitServer.EXPECT().GetRepoFileRaw(ctx, gitserver.GetRepoInfoByPathReq{ + Namespace: "b", + Name: "c", + Ref: "refs/heads/main", + Path: "README.md", + RepoType: types.SpaceRepo, + }).Return("", nil) + gc.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.SpaceRepo, "b", "c").Return( + &database.Repository{ID: 1}, nil, + ) + gc.mocks.stores.RepoRelationMock().EXPECT().Override(ctx, int64(1)).Return(nil) + + err := gc.WatchRepoRelation(context.TODO(), &types.GiteaCallbackPushReq{ + Ref: "refs/heads/main", + Repository: types.GiteaCallbackPushReq_Repository{ + FullName: "spaces_b/c/d", + }, + Commits: []types.GiteaCallbackPushReq_Commit{ + {Modified: []string{types.ReadmeFileName}}, + }, + }) + require.Nil(t, err) +} + +func TestGitCallbackComponent_SetRepoUpdateTime(t *testing.T) { + for _, mirror := range []bool{false, true} { + t.Run(fmt.Sprintf("mirror %v", mirror), func(t *testing.T) { + dt := time.Date(2022, 2, 2, 2, 0, 0, 0, time.UTC) + ctx := mock.Anything + gc := initializeTestGitCallbackComponent(context.TODO(), t) + + gc.mocks.stores.RepoMock().EXPECT().IsMirrorRepo( + ctx, types.ModelRepo, "ns", "n", + ).Return(mirror, nil) + + if mirror { + gc.mocks.stores.RepoMock().EXPECT().SetUpdateTimeByPath( + ctx, types.ModelRepo, "ns", "n", dt, + ).Return(nil) + gc.mocks.stores.MirrorMock().EXPECT().FindByRepoPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Mirror{}, nil) + gc.mocks.stores.MirrorMock().EXPECT().Update( + ctx, mock.Anything, + ).RunAndReturn(func(ctx context.Context, m *database.Mirror) error { + require.GreaterOrEqual(t, m.LastUpdatedAt, time.Now().Add(-5*time.Second)) + return nil + }) + } else { + gc.mocks.stores.RepoMock().EXPECT().SetUpdateTimeByPath( + ctx, types.ModelRepo, "ns", "n", mock.Anything, + ).RunAndReturn(func(ctx context.Context, rt types.RepositoryType, s1, s2 string, tt time.Time) error { + require.GreaterOrEqual(t, tt, time.Now().Add(-5*time.Second)) + return nil + }) + } + + err := gc.SetRepoUpdateTime(context.TODO(), &types.GiteaCallbackPushReq{ + Repository: types.GiteaCallbackPushReq_Repository{ + FullName: "models_ns/n", + }, + HeadCommit: types.GiteaCallbackPushReq_HeadCommit{ + Timestamp: dt.Format(time.RFC3339), + }, + }) + require.Nil(t, err) + }) + } +} + +func TestGitCallbackComponent_UpdateRepoInfos(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitCallbackComponent(context.TODO(), t) + + // modified mock + gc.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "ns", "n").Return( + &database.Repository{ID: 1, Path: "foo/bar"}, nil, + ) + gc.mocks.runtimeArchComponent.EXPECT().GetArchitectureFromConfig(ctx, "ns", "n").Return("foo", nil) + gc.mocks.stores.TagMock().EXPECT().GetTagsByScopeAndCategories( + ctx, database.ModelTagScope, []string{"runtime_framework", "resource"}, + ).Return([]*database.Tag{{Name: "t1"}}, nil) + gc.mocks.runtimeArchComponent.EXPECT().AddResourceTag( + ctx, []*database.Tag{{Name: "t1"}}, "bar", int64(1), + ).Return(nil) + gc.mocks.stores.RuntimeArchMock().EXPECT().ListByRArchNameAndModel(ctx, "foo", "bar").Return( + []database.RuntimeArchitecture{{ID: 11, RuntimeFrameworkID: 111}}, nil, + ) + gc.mocks.stores.RuntimeFrameworkMock().EXPECT().ListByIDs(ctx, []int64{111}).Return( + []database.RuntimeFramework{{ID: 12, FrameName: "fm"}}, nil, + ) + gc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().GetByRepoIDs(ctx, int64(1)).Return( + []database.RepositoriesRuntimeFramework{{RuntimeFrameworkID: 13}}, nil, + ) + gc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().Delete(ctx, int64(13), int64(1), 0).Return(nil) + gc.mocks.runtimeArchComponent.EXPECT().RemoveRuntimeFrameworkTag( + ctx, []*database.Tag{{Name: "t1"}}, int64(1), int64(13), + ).Return() + gc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().Add(ctx, int64(12), int64(1), 0).Return(nil) + gc.mocks.runtimeArchComponent.EXPECT().AddRuntimeFrameworkTag( + ctx, []*database.Tag{{Name: "t1"}}, int64(1), int64(12), + ).Return(nil) + // removed mock + gc.mocks.tagComponent.EXPECT().UpdateLibraryTags( + ctx, database.ModelTagScope, "ns", "n", "bar.go", "", + ).Return(nil) + gc.mocks.tagComponent.EXPECT().ClearMetaTags(ctx, types.ModelRepo, "ns", "n").Return(nil) + // added mock + gc.mocks.tagComponent.EXPECT().UpdateLibraryTags( + ctx, database.ModelTagScope, "ns", "n", "", "foo.go", + ).Return(nil) + gc.mocks.gitServer.EXPECT().GetRepoFileRaw(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Namespace: "ns", + Name: "n", + Ref: "refs/heads/main", + Path: "README.md", + RepoType: types.ModelRepo, + }).Return("", nil) + gc.mocks.tagComponent.EXPECT().UpdateMetaTags( + ctx, database.ModelTagScope, "ns", "n", "", + ).Return(nil, nil) + + err := gc.UpdateRepoInfos(ctx, &types.GiteaCallbackPushReq{ + Ref: "refs/heads/main", + Repository: types.GiteaCallbackPushReq_Repository{ + FullName: "models_ns/n", + }, + Commits: []types.GiteaCallbackPushReq_Commit{ + { + Modified: []string{component.ConfigFileName}, + Removed: []string{"bar.go", types.ReadmeFileName}, + Added: []string{"foo.go", types.ReadmeFileName}, + }, + }, + }) + require.Nil(t, err) +} diff --git a/component/callback/sync_version_gen.go b/component/callback/sync_version_gen.go new file mode 100644 index 00000000..7086853d --- /dev/null +++ b/component/callback/sync_version_gen.go @@ -0,0 +1,42 @@ +package callback + +import ( + "context" + "fmt" + "strings" + "time" + + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +type SyncVersionGenerator interface { + GenSyncVersion(req *types.GiteaCallbackPushReq) error +} + +type syncVersionGeneratorImpl struct { + multiSyncStore database.MultiSyncStore +} + +func NewSyncVersionGenerator() *syncVersionGeneratorImpl { + return &syncVersionGeneratorImpl{ + multiSyncStore: database.NewMultiSyncStore(), + } +} + +func (g *syncVersionGeneratorImpl) GenSyncVersion(req *types.GiteaCallbackPushReq) error { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + splits := strings.Split(req.Repository.FullName, "/") + fullNamespace, repoName := splits[0], splits[1] + repoType, namespace, _ := strings.Cut(fullNamespace, "_") + _, err := g.multiSyncStore.Create(ctx, database.SyncVersion{ + SourceID: types.SyncVersionSourceOpenCSG, + RepoPath: fmt.Sprintf("%s/%s", namespace, repoName), + RepoType: types.RepositoryType(strings.TrimRight(repoType, "s")), + LastModifiedAt: req.HeadCommit.LastModifyTime, + ChangeLog: req.HeadCommit.Message, + }) + + return err +} diff --git a/component/callback/wire.go b/component/callback/wire.go new file mode 100644 index 00000000..e9abaea1 --- /dev/null +++ b/component/callback/wire.go @@ -0,0 +1,27 @@ +//go:build wireinject +// +build wireinject + +package callback + +import ( + "context" + + "github.com/google/wire" + "github.com/stretchr/testify/mock" +) + +type testGitCallbackWithMocks struct { + *gitCallbackComponentImpl + mocks *Mocks +} + +func initializeTestGitCallbackComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testGitCallbackWithMocks { + wire.Build( + MockCallbackSuperSet, GitCallbackComponentSet, + wire.Struct(new(testGitCallbackWithMocks), "*"), + ) + return &testGitCallbackWithMocks{} +} diff --git a/component/callback/wire_gen_test.go b/component/callback/wire_gen_test.go new file mode 100644 index 00000000..af85dfce --- /dev/null +++ b/component/callback/wire_gen_test.go @@ -0,0 +1,55 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run -mod=mod github.com/google/wire/cmd/wire +//go:build !wireinject +// +build !wireinject + +package callback + +import ( + "context" + "github.com/stretchr/testify/mock" + "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/rpc" + component2 "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component/callback" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/component" +) + +// Injectors from wire.go: + +func initializeTestGitCallbackComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testGitCallbackWithMocks { + config := component.ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockGitServer := gitserver.NewMockGitServer(t) + mockTagComponent := component2.NewMockTagComponent(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mockRuntimeArchitectureComponent := component2.NewMockRuntimeArchitectureComponent(t) + mockSpaceComponent := component2.NewMockSpaceComponent(t) + callbackGitCallbackComponentImpl := NewTestGitCallbackComponent(config, mockStores, mockGitServer, mockTagComponent, mockModerationSvcClient, mockRuntimeArchitectureComponent, mockSpaceComponent) + mockSyncVersionGenerator := callback.NewMockSyncVersionGenerator(t) + mocks := &Mocks{ + stores: mockStores, + tagComponent: mockTagComponent, + spaceComponent: mockSpaceComponent, + syncVersionGenerator: mockSyncVersionGenerator, + gitServer: mockGitServer, + runtimeArchComponent: mockRuntimeArchitectureComponent, + } + callbackTestGitCallbackWithMocks := &testGitCallbackWithMocks{ + gitCallbackComponentImpl: callbackGitCallbackComponentImpl, + mocks: mocks, + } + return callbackTestGitCallbackWithMocks +} + +// wire.go: + +type testGitCallbackWithMocks struct { + *gitCallbackComponentImpl + mocks *Mocks +} diff --git a/component/callback/wireset.go b/component/callback/wireset.go new file mode 100644 index 00000000..0842fca7 --- /dev/null +++ b/component/callback/wireset.go @@ -0,0 +1,60 @@ +package callback + +import ( + "github.com/google/wire" + mock_git "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" + mock_component "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + mock_callback "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component/callback" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/rpc" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/component" +) + +var MockedSyncVersionGeneratorSet = wire.NewSet( + mock_callback.NewMockSyncVersionGenerator, + wire.Bind(new(SyncVersionGenerator), new(*mock_callback.MockSyncVersionGenerator)), +) + +type Mocks struct { + stores *tests.MockStores + tagComponent *mock_component.MockTagComponent + spaceComponent *mock_component.MockSpaceComponent + syncVersionGenerator *mock_callback.MockSyncVersionGenerator + gitServer *mock_git.MockGitServer + runtimeArchComponent *mock_component.MockRuntimeArchitectureComponent +} + +var AllMockSet = wire.NewSet( + wire.Struct(new(Mocks), "*"), +) + +var MockCallbackSuperSet = wire.NewSet( + component.MockedStoreSet, component.MockedComponentSet, MockedSyncVersionGeneratorSet, AllMockSet, + component.ProvideTestConfig, component.MockedGitServerSet, component.MockedModerationSvcClientSet, +) + +func NewTestGitCallbackComponent(config *config.Config, stores *tests.MockStores, gitServer gitserver.GitServer, tagComponent component.TagComponent, modSvcClient rpc.ModerationSvcClient, runtimeArchComponent component.RuntimeArchitectureComponent, spaceComponent component.SpaceComponent) *gitCallbackComponentImpl { + return &gitCallbackComponentImpl{ + config: config, + gitServer: gitServer, + tagComponent: tagComponent, + modSvcClient: modSvcClient, + modelStore: stores.Model, + datasetStore: stores.Dataset, + spaceComponent: spaceComponent, + spaceStore: stores.Space, + repoStore: stores.Repo, + repoRelationStore: stores.RepoRelation, + mirrorStore: stores.Mirror, + repoRuntimeFrameworkStore: stores.RepoRuntimeFramework, + runtimeArchComponent: runtimeArchComponent, + runtimeArchStore: stores.RuntimeArch, + runtimeFrameworkStore: stores.RuntimeFramework, + tagStore: stores.Tag, + tagRuleStore: stores.TagRule, + } +} + +var GitCallbackComponentSet = wire.NewSet(NewTestGitCallbackComponent)