diff --git a/Makefile b/Makefile index 21c7b05a..739643b5 100644 --- a/Makefile +++ b/Makefile @@ -13,10 +13,12 @@ cover: mock_wire: @echo "Running wire for component mocks..." - @go run -mod=mod github.com/google/wire/cmd/wire opencsg.com/csghub-server/component + @go run -mod=mod github.com/google/wire/cmd/wire opencsg.com/csghub-server/component/... @if [ $$? -eq 0 ]; then \ - echo "Renaming wire_gen.go to wire_gen_test.go..."; \ + echo "Renaming component wire_gen.go to wire_gen_test.go..."; \ mv component/wire_gen.go component/wire_gen_test.go; \ + echo "Renaming component/callback wire_gen.go to wire_gen_test.go..."; \ + mv component/callback/wire_gen.go component/callback/wire_gen_test.go; \ else \ echo "Wire failed, skipping renaming."; \ fi diff --git a/api/handler/callback/git_callback.go b/api/handler/callback/git_callback.go index c0d62056..1fda23f8 100644 --- a/api/handler/callback/git_callback.go +++ b/api/handler/callback/git_callback.go @@ -14,7 +14,7 @@ import ( ) type GitCallbackHandler struct { - cbc *component.GitCallbackComponent + cbc component.GitCallbackComponent config *config.Config } diff --git a/common/tests/stores.go b/common/tests/stores.go index 2197113f..507a7157 100644 --- a/common/tests/stores.go +++ b/common/tests/stores.go @@ -14,6 +14,7 @@ type MockStores struct { Model database.ModelStore SpaceResource database.SpaceResourceStore Tag database.TagStore + TagRule database.TagRuleStore Dataset database.DatasetStore PromptConversation database.PromptConversationStore PromptPrefix database.PromptPrefixStore @@ -44,6 +45,9 @@ type MockStores struct { MultiSync database.MultiSyncStore File database.FileStore SSH database.SSHKeyStore + Telemetry database.TelemetryStore + RepoFile database.RepoFileStore + Event database.EventStore } func NewMockStores(t interface { @@ -88,6 +92,10 @@ func NewMockStores(t interface { MultiSync: mockdb.NewMockMultiSyncStore(t), File: mockdb.NewMockFileStore(t), SSH: mockdb.NewMockSSHKeyStore(t), + Telemetry: mockdb.NewMockTelemetryStore(t), + RepoFile: mockdb.NewMockRepoFileStore(t), + Event: mockdb.NewMockEventStore(t), + TagRule: mockdb.NewMockTagRuleStore(t), } } @@ -119,6 +127,10 @@ func (s *MockStores) TagMock() *mockdb.MockTagStore { return s.Tag.(*mockdb.MockTagStore) } +func (s *MockStores) TagRuleMock() *mockdb.MockTagRuleStore { + return s.TagRule.(*mockdb.MockTagRuleStore) +} + func (s *MockStores) DatasetMock() *mockdb.MockDatasetStore { return s.Dataset.(*mockdb.MockDatasetStore) } @@ -238,3 +250,15 @@ func (s *MockStores) FileMock() *mockdb.MockFileStore { func (s *MockStores) SSHMock() *mockdb.MockSSHKeyStore { return s.SSH.(*mockdb.MockSSHKeyStore) } + +func (s *MockStores) TelemetryMock() *mockdb.MockTelemetryStore { + return s.Telemetry.(*mockdb.MockTelemetryStore) +} + +func (s *MockStores) RepoFileMock() *mockdb.MockRepoFileStore { + return s.RepoFile.(*mockdb.MockRepoFileStore) +} + +func (s *MockStores) EventMock() *mockdb.MockEventStore { + return s.Event.(*mockdb.MockEventStore) +} diff --git a/common/types/prompt.go b/common/types/prompt.go index 31522bbe..7ad62e90 100644 --- a/common/types/prompt.go +++ b/common/types/prompt.go @@ -1,6 +1,8 @@ package types -import "time" +import ( + "time" +) type PromptReq struct { Namespace string `json:"namespace"` @@ -101,3 +103,31 @@ type PromptRes struct { CanManage bool `json:"can_manage"` Namespace *Namespace `json:"namespace"` } + +type Prompt struct { + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + Language string `json:"language" binding:"required"` + Tags []string `json:"tags"` + Type string `json:"type"` // "text|image|video|audio" + Source string `json:"source"` + Author string `json:"author"` + Time string `json:"time"` + Copyright string `json:"copyright"` + Feedback []string `json:"feedback"` +} + +type PromptOutput struct { + Prompt + FilePath string `json:"file_path"` + CanWrite bool `json:"can_write"` + CanManage bool `json:"can_manage"` +} + +type CreatePromptReq struct { + Prompt +} + +type UpdatePromptReq struct { + Prompt +} diff --git a/component/callback/git_callback.go b/component/callback/git_callback.go index b1bddcd4..0810f1e4 100644 --- a/component/callback/git_callback.go +++ b/component/callback/git_callback.go @@ -21,33 +21,39 @@ import ( "opencsg.com/csghub-server/component" ) -// define GitCallbackComponent struct -type GitCallbackComponent struct { - config *config.Config - gs gitserver.GitServer - tc component.TagComponent - modSvcClient rpc.ModerationSvcClient - ms database.ModelStore - ds database.DatasetStore - sc component.SpaceComponent - ss database.SpaceStore - rs database.RepoStore - rrs database.RepoRelationsStore - mirrorStore database.MirrorStore - rrf database.RepositoriesRuntimeFrameworkStore - rac component.RuntimeArchitectureComponent - ras database.RuntimeArchitecturesStore - rfs database.RuntimeFrameworksStore - ts database.TagStore - dt database.TagRuleStore +type GitCallbackComponent interface { + SetRepoVisibility(yes bool) + WatchSpaceChange(ctx context.Context, req *types.GiteaCallbackPushReq) error + WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq) error + SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq) error + UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq) error +} + +type gitCallbackComponentImpl struct { + config *config.Config + gitServer gitserver.GitServer + tagComponent component.TagComponent + modSvcClient rpc.ModerationSvcClient + modelStore database.ModelStore + datasetStore database.DatasetStore + spaceComponent component.SpaceComponent + spaceStore database.SpaceStore + repoStore database.RepoStore + repoRelationStore database.RepoRelationsStore + mirrorStore database.MirrorStore + repoRuntimeFrameworkStore database.RepositoriesRuntimeFrameworkStore + runtimeArchComponent component.RuntimeArchitectureComponent + runtimeArchStore database.RuntimeArchitecturesStore + runtimeFrameworkStore database.RuntimeFrameworksStore + tagStore database.TagStore + tagRuleStore database.TagRuleStore // set visibility if file content is sensitive setRepoVisibility bool - pp component.PromptComponent maxPromptFS int64 } // new CallbackComponent -func NewGitCallback(config *config.Config) (*GitCallbackComponent, error) { +func NewGitCallback(config *config.Config) (*gitCallbackComponentImpl, error) { gs, err := git.NewGitServer(config) if err != nil { return nil, err @@ -74,45 +80,40 @@ func NewGitCallback(config *config.Config) (*GitCallbackComponent, error) { } rfs := database.NewRuntimeFrameworksStore() ts := database.NewTagStore() - pp, err := component.NewPromptComponent(config) - if err != nil { - return nil, err - } var modSvcClient rpc.ModerationSvcClient if config.SensitiveCheck.Enable { modSvcClient = rpc.NewModerationSvcHttpClient(fmt.Sprintf("%s:%d", config.Moderation.Host, config.Moderation.Port)) } dt := database.NewTagRuleStore() - return &GitCallbackComponent{ - config: config, - gs: gs, - tc: tc, - ms: ms, - ds: ds, - ss: ss, - sc: sc, - rs: rs, - rrs: rrs, - mirrorStore: mirrorStore, - modSvcClient: modSvcClient, - rrf: rrf, - rac: rac, - ras: ras, - rfs: rfs, - pp: pp, - ts: ts, - dt: dt, - maxPromptFS: config.Dataset.PromptMaxJsonlFileSize, + return &gitCallbackComponentImpl{ + config: config, + gitServer: gs, + tagComponent: tc, + modelStore: ms, + datasetStore: ds, + spaceStore: ss, + spaceComponent: sc, + repoStore: rs, + repoRelationStore: rrs, + mirrorStore: mirrorStore, + modSvcClient: modSvcClient, + repoRuntimeFrameworkStore: rrf, + runtimeArchComponent: rac, + runtimeArchStore: ras, + runtimeFrameworkStore: rfs, + tagStore: ts, + tagRuleStore: dt, + maxPromptFS: config.Dataset.PromptMaxJsonlFileSize, }, nil } // SetRepoVisibility sets a flag whether change repo's visibility if file content is sensitive -func (c *GitCallbackComponent) SetRepoVisibility(yes bool) { +func (c *gitCallbackComponentImpl) SetRepoVisibility(yes bool) { c.setRepoVisibility = yes } -func (c *GitCallbackComponent) WatchSpaceChange(ctx context.Context, req *types.GiteaCallbackPushReq) error { - err := WatchSpaceChange(req, c.ss, c.sc).Run() +func (c *gitCallbackComponentImpl) WatchSpaceChange(ctx context.Context, req *types.GiteaCallbackPushReq) error { + err := WatchSpaceChange(req, c.spaceStore, c.spaceComponent).Run() if err != nil { slog.Error("watch space change failed", slog.Any("error", err)) return err @@ -120,8 +121,8 @@ func (c *GitCallbackComponent) WatchSpaceChange(ctx context.Context, req *types. return nil } -func (c *GitCallbackComponent) WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq) error { - err := WatchRepoRelation(req, c.rs, c.rrs, c.gs).Run() +func (c *gitCallbackComponentImpl) WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq) error { + err := WatchRepoRelation(req, c.repoStore, c.repoRelationStore, c.gitServer).Run() if err != nil { slog.Error("watch repo relation failed", slog.Any("error", err)) return err @@ -129,7 +130,7 @@ func (c *GitCallbackComponent) WatchRepoRelation(ctx context.Context, req *types return nil } -func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq) error { +func (c *gitCallbackComponentImpl) SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq) error { // split req.Repository.FullName by '/' splits := strings.Split(req.Repository.FullName, "/") fullNamespace, repoName := splits[0], splits[1] @@ -138,7 +139,7 @@ func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - isMirrorRepo, err := c.rs.IsMirrorRepo(ctx, adjustedRepoType, namespace, repoName) + isMirrorRepo, err := c.repoStore.IsMirrorRepo(ctx, adjustedRepoType, namespace, repoName) if err != nil { slog.Error("failed to check if a mirror repo", slog.Any("error", err), slog.String("repo_type", string(adjustedRepoType)), slog.String("namespace", namespace), slog.String("name", repoName)) return err @@ -149,7 +150,7 @@ func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types slog.Error("Error parsing time:", slog.Any("error", err), slog.String("timestamp", req.HeadCommit.Timestamp)) return err } - err = c.rs.SetUpdateTimeByPath(ctx, adjustedRepoType, namespace, repoName, updated) + err = c.repoStore.SetUpdateTimeByPath(ctx, adjustedRepoType, namespace, repoName, updated) if err != nil { slog.Error("failed to set repo update time", slog.Any("error", err), slog.String("repo_type", string(adjustedRepoType)), slog.String("namespace", namespace), slog.String("name", repoName)) return err @@ -166,7 +167,7 @@ func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types return err } } else { - err := c.rs.SetUpdateTimeByPath(ctx, adjustedRepoType, namespace, repoName, time.Now()) + err := c.repoStore.SetUpdateTimeByPath(ctx, adjustedRepoType, namespace, repoName, time.Now()) if err != nil { slog.Error("failed to set repo update time", slog.Any("error", err), slog.String("repo_type", string(adjustedRepoType)), slog.String("namespace", namespace), slog.String("name", repoName)) return err @@ -175,7 +176,7 @@ func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types return nil } -func (c *GitCallbackComponent) UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq) error { +func (c *gitCallbackComponentImpl) UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq) error { commits := req.Commits ref := req.Ref // split req.Repository.FullName by '/' @@ -193,7 +194,7 @@ func (c *GitCallbackComponent) UpdateRepoInfos(ctx context.Context, req *types.G return err } -func (c *GitCallbackComponent) SensitiveCheck(ctx context.Context, req *types.GiteaCallbackPushReq) error { +func (c *gitCallbackComponentImpl) SensitiveCheck(ctx context.Context, req *types.GiteaCallbackPushReq) error { // split req.Repository.FullName by '/' splits := strings.Split(req.Repository.FullName, "/") fullNamespace, repoName := splits[0], splits[1] @@ -208,11 +209,12 @@ func (c *GitCallbackComponent) SensitiveCheck(ctx context.Context, req *types.Gi slog.Error("fail to submit repo sensitive check", slog.Any("error", err), slog.Any("repo_type", adjustedRepoType), slog.String("namespace", namespace), slog.String("name", repoName)) return err } + return nil } // modifyFiles method handles modified files, skip if not modify README.md -func (c *GitCallbackComponent) modifyFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { +func (c *gitCallbackComponentImpl) modifyFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { for _, fileName := range fileNames { slog.Debug("modify file", slog.String("file", fileName)) // update model runtime @@ -232,7 +234,7 @@ func (c *GitCallbackComponent) modifyFiles(ctx context.Context, repoType, namesp return nil } -func (c *GitCallbackComponent) removeFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { +func (c *gitCallbackComponentImpl) removeFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { // handle removed files // delete tags for _, fileName := range fileNames { @@ -244,7 +246,7 @@ func (c *GitCallbackComponent) removeFiles(ctx context.Context, repoType, namesp // use empty content to clear all the meta tags const content string = "" adjustedRepoType := types.RepositoryType(strings.TrimSuffix(repoType, "s")) - err := c.tc.ClearMetaTags(ctx, adjustedRepoType, namespace, repoName) + err := c.tagComponent.ClearMetaTags(ctx, adjustedRepoType, namespace, repoName) if err != nil { slog.Error("failed to clear meta tags", slog.String("content", content), slog.String("repo", path.Join(namespace, repoName)), slog.String("ref", ref), @@ -267,7 +269,7 @@ func (c *GitCallbackComponent) removeFiles(ctx context.Context, repoType, namesp // case SpaceRepoType: // tagScope = database.SpaceTagScope } - err := c.tc.UpdateLibraryTags(ctx, tagScope, namespace, repoName, fileName, "") + err := c.tagComponent.UpdateLibraryTags(ctx, tagScope, namespace, repoName, fileName, "") if err != nil { slog.Error("failed to remove Library tag", slog.String("namespace", namespace), slog.String("name", repoName), slog.String("ref", ref), slog.String("fileName", fileName), @@ -279,7 +281,7 @@ func (c *GitCallbackComponent) removeFiles(ctx context.Context, repoType, namesp return nil } -func (c *GitCallbackComponent) addFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { +func (c *gitCallbackComponentImpl) addFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { for _, fileName := range fileNames { slog.Debug("add file", slog.String("file", fileName)) // update model runtime @@ -310,7 +312,7 @@ func (c *GitCallbackComponent) addFiles(ctx context.Context, repoType, namespace // case SpaceRepoType: // tagScope = database.SpaceTagScope } - err := c.tc.UpdateLibraryTags(ctx, tagScope, namespace, repoName, "", fileName) + err := c.tagComponent.UpdateLibraryTags(ctx, tagScope, namespace, repoName, "", fileName) if err != nil { slog.Error("failed to add Library tag", slog.String("namespace", namespace), slog.String("name", repoName), slog.String("ref", ref), slog.String("fileName", fileName), @@ -322,7 +324,7 @@ func (c *GitCallbackComponent) addFiles(ctx context.Context, repoType, namespace return nil } -func (c *GitCallbackComponent) updateMetaTags(ctx context.Context, repoType, namespace, repoName, ref, content string) error { +func (c *gitCallbackComponentImpl) updateMetaTags(ctx context.Context, repoType, namespace, repoName, ref, content string) error { var ( err error tagScope database.TagScope @@ -342,7 +344,7 @@ func (c *GitCallbackComponent) updateMetaTags(ctx context.Context, repoType, nam // case SpaceRepoType: // tagScope = database.SpaceTagScope } - _, err = c.tc.UpdateMetaTags(ctx, tagScope, namespace, repoName, content) + _, err = c.tagComponent.UpdateMetaTags(ctx, tagScope, namespace, repoName, content) if err != nil { slog.Error("failed to update meta tags", slog.String("namespace", namespace), slog.String("content", content), slog.String("repo", repoName), slog.String("ref", ref), @@ -353,7 +355,7 @@ func (c *GitCallbackComponent) updateMetaTags(ctx context.Context, repoType, nam return nil } -func (c *GitCallbackComponent) getFileRaw(repoType, namespace, repoName, ref, fileName string) (string, error) { +func (c *gitCallbackComponentImpl) getFileRaw(repoType, namespace, repoName, ref, fileName string) (string, error) { var ( content string err error @@ -366,7 +368,7 @@ func (c *GitCallbackComponent) getFileRaw(repoType, namespace, repoName, ref, fi Path: fileName, RepoType: types.RepositoryType(repoType), } - content, err = c.gs.GetRepoFileRaw(context.Background(), getFileRawReq) + content, err = c.gitServer.GetRepoFileRaw(context.Background(), getFileRawReq) if err != nil { slog.Error("failed to get file content", slog.String("namespace", namespace), slog.String("file", fileName), slog.String("repo", repoName), slog.String("ref", ref), @@ -380,7 +382,7 @@ func (c *GitCallbackComponent) getFileRaw(repoType, namespace, repoName, ref, fi } // update repo relations -func (c *GitCallbackComponent) updateRepoRelations(ctx context.Context, repoType, namespace, repoName, ref, fileName string, deleteAction bool, fileNames []string) { +func (c *gitCallbackComponentImpl) updateRepoRelations(ctx context.Context, repoType, namespace, repoName, ref, fileName string, deleteAction bool, fileNames []string) { slog.Debug("update model relation for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("repoType", repoType), slog.Any("fileName", fileName), slog.Any("branch", ref)) if repoType == fmt.Sprintf("%ss", types.ModelRepo) { c.updateModelRuntimeFrameworks(ctx, repoType, namespace, repoName, ref, fileName, deleteAction) @@ -391,19 +393,19 @@ func (c *GitCallbackComponent) updateRepoRelations(ctx context.Context, repoType } // update dataset tags for evaluation -func (c *GitCallbackComponent) updateDatasetTags(ctx context.Context, namespace, repoName string, fileNames []string) { +func (c *gitCallbackComponentImpl) updateDatasetTags(ctx context.Context, namespace, repoName string, fileNames []string) { // script dataset repo was not supported so far scriptName := fmt.Sprintf("%s.py", repoName) if slices.Contains(fileNames, scriptName) { return } - repo, err := c.rs.FindByPath(ctx, types.DatasetRepo, namespace, repoName) + repo, err := c.repoStore.FindByPath(ctx, types.DatasetRepo, namespace, repoName) if err != nil || repo == nil { slog.Warn("fail to query repo for in callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("error", err)) return } // check if it's evaluation dataset - evalDataset, err := c.dt.FindByRepo(ctx, string(types.EvaluationCategory), namespace, repoName, string(types.DatasetRepo)) + evalDataset, err := c.tagRuleStore.FindByRepo(ctx, string(types.EvaluationCategory), namespace, repoName, string(types.DatasetRepo)) if err != nil { if errors.Is(err, sql.ErrNoRows) { // check if it's a mirror repo @@ -415,7 +417,7 @@ func (c *GitCallbackComponent) updateDatasetTags(ctx context.Context, namespace, namespace := strings.Split(mirror.SourceRepoPath, "/")[0] name := strings.Split(mirror.SourceRepoPath, "/")[1] // use mirror namespace and name to find dataset - evalDataset, err = c.dt.FindByRepo(ctx, string(types.EvaluationCategory), namespace, name, string(types.DatasetRepo)) + evalDataset, err = c.tagRuleStore.FindByRepo(ctx, string(types.EvaluationCategory), namespace, name, string(types.DatasetRepo)) if err != nil { slog.Debug("not an evaluation dataset, ignore it", slog.Any("repo id", repo.Path)) return @@ -429,13 +431,13 @@ func (c *GitCallbackComponent) updateDatasetTags(ctx context.Context, namespace, tagIds := []int64{} tagIds = append(tagIds, evalDataset.Tag.ID) if evalDataset.RuntimeFramework != "" { - rTag, _ := c.ts.FindTag(ctx, evalDataset.RuntimeFramework, string(types.DatasetRepo), "runtime_framework") + rTag, _ := c.tagStore.FindTag(ctx, evalDataset.RuntimeFramework, string(types.DatasetRepo), "runtime_framework") if rTag != nil { tagIds = append(tagIds, rTag.ID) } } - err = c.ts.UpsertRepoTags(ctx, repo.ID, []int64{}, tagIds) + err = c.tagStore.UpsertRepoTags(ctx, repo.ID, []int64{}, tagIds) if err != nil { slog.Warn("fail to add dataset tag", slog.Any("repoId", repo.ID), slog.Any("tag id", tagIds), slog.Any("error", err)) } @@ -443,39 +445,39 @@ func (c *GitCallbackComponent) updateDatasetTags(ctx context.Context, namespace, } // update model runtime frameworks -func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, repoType, namespace, repoName, ref, fileName string, deleteAction bool) { +func (c *gitCallbackComponentImpl) updateModelRuntimeFrameworks(ctx context.Context, repoType, namespace, repoName, ref, fileName string, deleteAction bool) { // must be model repo and config.json if repoType != fmt.Sprintf("%ss", types.ModelRepo) || fileName != component.ConfigFileName || (ref != ("refs/heads/"+component.MainBranch) && ref != ("refs/heads/"+component.MasterBranch)) { return } - repo, err := c.rs.FindByPath(ctx, types.ModelRepo, namespace, repoName) + repo, err := c.repoStore.FindByPath(ctx, types.ModelRepo, namespace, repoName) if err != nil || repo == nil { slog.Warn("fail to query repo for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("error", err)) return } // delete event if deleteAction { - err := c.rrf.DeleteByRepoID(ctx, repo.ID) + err := c.repoRuntimeFrameworkStore.DeleteByRepoID(ctx, repo.ID) if err != nil { slog.Warn("fail to remove repo runtimes for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("repoid", repo.ID), slog.Any("error", err)) } return } - arch, err := c.rac.GetArchitectureFromConfig(ctx, namespace, repoName) + arch, err := c.runtimeArchComponent.GetArchitectureFromConfig(ctx, namespace, repoName) if err != nil { slog.Warn("fail to get config.json content for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("error", err)) return } slog.Debug("get arch for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("arch", arch)) //add resource tag, like ascend - runtime_framework_tags, _ := c.ts.GetTagsByScopeAndCategories(ctx, "model", []string{"runtime_framework", "resource"}) + runtime_framework_tags, _ := c.tagStore.GetTagsByScopeAndCategories(ctx, "model", []string{"runtime_framework", "resource"}) fields := strings.Split(repo.Path, "/") - err = c.rac.AddResourceTag(ctx, runtime_framework_tags, fields[1], repo.ID) + err = c.runtimeArchComponent.AddResourceTag(ctx, runtime_framework_tags, fields[1], repo.ID) if err != nil { slog.Warn("fail to add resource tag", slog.Any("error", err)) return } - runtimes, err := c.ras.ListByRArchNameAndModel(ctx, arch, fields[1]) + runtimes, err := c.runtimeArchStore.ListByRArchNameAndModel(ctx, arch, fields[1]) // to do check resource models if err != nil { slog.Warn("fail to get runtime ids by arch for git callback", slog.Any("arch", arch), slog.Any("error", err)) @@ -487,7 +489,7 @@ func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, frameIDs = append(frameIDs, runtime.RuntimeFrameworkID) } slog.Debug("get new frame ids for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("frameIDs", frameIDs)) - newFrames, err := c.rfs.ListByIDs(ctx, frameIDs) + newFrames, err := c.runtimeFrameworkStore.ListByIDs(ctx, frameIDs) if err != nil { slog.Warn("fail to get runtime frameworks for git callback", slog.Any("arch", arch), slog.Any("error", err)) return @@ -498,7 +500,7 @@ func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, newFrameMap[strconv.FormatInt(frame.ID, 10)] = strconv.FormatInt(frame.ID, 10) } slog.Debug("get new frame map by arch for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("newFrameMap", newFrameMap)) - oldRepoRuntimes, err := c.rrf.GetByRepoIDs(ctx, repo.ID) + oldRepoRuntimes, err := c.repoRuntimeFrameworkStore.GetByRepoIDs(ctx, repo.ID) if err != nil { slog.Warn("fail to get repo runtimes for git callback", slog.Any("repo.ID", repo.ID), slog.Any("error", err)) return @@ -516,12 +518,12 @@ func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, _, exist := newFrameMap[strconv.FormatInt(old.RuntimeFrameworkID, 10)] if !exist { // remove incorrect relations - err := c.rrf.Delete(ctx, old.RuntimeFrameworkID, repo.ID, old.Type) + err := c.repoRuntimeFrameworkStore.Delete(ctx, old.RuntimeFrameworkID, repo.ID, old.Type) if err != nil { slog.Warn("fail to delete old repo runtimes for git callback", slog.Any("repo.ID", repo.ID), slog.Any("runtime framework id", old.RuntimeFrameworkID), slog.Any("error", err)) } // remove runtime framework tags - c.rac.RemoveRuntimeFrameworkTag(ctx, runtime_framework_tags, repo.ID, old.RuntimeFrameworkID) + c.runtimeArchComponent.RemoveRuntimeFrameworkTag(ctx, runtime_framework_tags, repo.ID, old.RuntimeFrameworkID) } } @@ -531,12 +533,12 @@ func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, _, exist := oldFrameMap[strconv.FormatInt(new.ID, 10)] if !exist { // add new relations - err := c.rrf.Add(ctx, new.ID, repo.ID, new.Type) + err := c.repoRuntimeFrameworkStore.Add(ctx, new.ID, repo.ID, new.Type) if err != nil { slog.Warn("fail to add new repo runtimes for git callback", slog.Any("repo.ID", repo.ID), slog.Any("runtime framework id", new.ID), slog.Any("error", err)) } // add runtime framework and resource tags - err = c.rac.AddRuntimeFrameworkTag(ctx, runtime_framework_tags, repo.ID, new.ID) + err = c.runtimeArchComponent.AddRuntimeFrameworkTag(ctx, runtime_framework_tags, repo.ID, new.ID) if err != nil { slog.Warn("fail to add runtime framework tag for git callback", slog.Any("repo.ID", repo.ID), slog.Any("runtime framework id", new.ID), slog.Any("error", err)) } diff --git a/component/callback/git_callback_test.go b/component/callback/git_callback_test.go new file mode 100644 index 00000000..bafa8d85 --- /dev/null +++ b/component/callback/git_callback_test.go @@ -0,0 +1,187 @@ +package callback + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/component" +) + +func TestGitCallbackComponent_SetRepoVisibility(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitCallbackComponent(ctx, t) + + require.False(t, gc.setRepoVisibility) + gc.SetRepoVisibility(true) + require.True(t, gc.setRepoVisibility) +} + +func TestGitCallbackComponent_WatchSpaceChange(t *testing.T) { + ctx := mock.Anything + gc := initializeTestGitCallbackComponent(context.TODO(), t) + + gc.mocks.stores.SpaceMock().EXPECT().FindByPath(ctx, "b", "c").Return( + &database.Space{HasAppFile: true}, nil, + ) + gc.mocks.spaceComponent.EXPECT().FixHasEntryFile(ctx, &database.Space{ + HasAppFile: true, + }).Return(nil) + gc.mocks.spaceComponent.EXPECT().Deploy(ctx, "b", "c", "b").Return(100, nil) + + err := gc.WatchSpaceChange(context.TODO(), &types.GiteaCallbackPushReq{ + Ref: "main", + Repository: types.GiteaCallbackPushReq_Repository{ + FullName: "spaces_b/c/d", + }, + }) + require.Nil(t, err) +} + +func TestGitCallbackComponent_WatchRepoRelation(t *testing.T) { + ctx := mock.Anything + gc := initializeTestGitCallbackComponent(context.TODO(), t) + + gc.mocks.gitServer.EXPECT().GetRepoFileRaw(ctx, gitserver.GetRepoInfoByPathReq{ + Namespace: "b", + Name: "c", + Ref: "refs/heads/main", + Path: "README.md", + RepoType: types.SpaceRepo, + }).Return("", nil) + gc.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.SpaceRepo, "b", "c").Return( + &database.Repository{ID: 1}, nil, + ) + gc.mocks.stores.RepoRelationMock().EXPECT().Override(ctx, int64(1)).Return(nil) + + err := gc.WatchRepoRelation(context.TODO(), &types.GiteaCallbackPushReq{ + Ref: "refs/heads/main", + Repository: types.GiteaCallbackPushReq_Repository{ + FullName: "spaces_b/c/d", + }, + Commits: []types.GiteaCallbackPushReq_Commit{ + {Modified: []string{types.ReadmeFileName}}, + }, + }) + require.Nil(t, err) +} + +func TestGitCallbackComponent_SetRepoUpdateTime(t *testing.T) { + for _, mirror := range []bool{false, true} { + t.Run(fmt.Sprintf("mirror %v", mirror), func(t *testing.T) { + dt := time.Date(2022, 2, 2, 2, 0, 0, 0, time.UTC) + ctx := mock.Anything + gc := initializeTestGitCallbackComponent(context.TODO(), t) + + gc.mocks.stores.RepoMock().EXPECT().IsMirrorRepo( + ctx, types.ModelRepo, "ns", "n", + ).Return(mirror, nil) + + if mirror { + gc.mocks.stores.RepoMock().EXPECT().SetUpdateTimeByPath( + ctx, types.ModelRepo, "ns", "n", dt, + ).Return(nil) + gc.mocks.stores.MirrorMock().EXPECT().FindByRepoPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Mirror{}, nil) + gc.mocks.stores.MirrorMock().EXPECT().Update( + ctx, mock.Anything, + ).RunAndReturn(func(ctx context.Context, m *database.Mirror) error { + require.GreaterOrEqual(t, m.LastUpdatedAt, time.Now().Add(-5*time.Second)) + return nil + }) + } else { + gc.mocks.stores.RepoMock().EXPECT().SetUpdateTimeByPath( + ctx, types.ModelRepo, "ns", "n", mock.Anything, + ).RunAndReturn(func(ctx context.Context, rt types.RepositoryType, s1, s2 string, tt time.Time) error { + require.GreaterOrEqual(t, tt, time.Now().Add(-5*time.Second)) + return nil + }) + } + + err := gc.SetRepoUpdateTime(context.TODO(), &types.GiteaCallbackPushReq{ + Repository: types.GiteaCallbackPushReq_Repository{ + FullName: "models_ns/n", + }, + HeadCommit: types.GiteaCallbackPushReq_HeadCommit{ + Timestamp: dt.Format(time.RFC3339), + }, + }) + require.Nil(t, err) + }) + } +} + +func TestGitCallbackComponent_UpdateRepoInfos(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitCallbackComponent(context.TODO(), t) + + // modified mock + gc.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "ns", "n").Return( + &database.Repository{ID: 1, Path: "foo/bar"}, nil, + ) + gc.mocks.runtimeArchComponent.EXPECT().GetArchitectureFromConfig(ctx, "ns", "n").Return("foo", nil) + gc.mocks.stores.TagMock().EXPECT().GetTagsByScopeAndCategories( + ctx, database.ModelTagScope, []string{"runtime_framework", "resource"}, + ).Return([]*database.Tag{{Name: "t1"}}, nil) + gc.mocks.runtimeArchComponent.EXPECT().AddResourceTag( + ctx, []*database.Tag{{Name: "t1"}}, "bar", int64(1), + ).Return(nil) + gc.mocks.stores.RuntimeArchMock().EXPECT().ListByRArchNameAndModel(ctx, "foo", "bar").Return( + []database.RuntimeArchitecture{{ID: 11, RuntimeFrameworkID: 111}}, nil, + ) + gc.mocks.stores.RuntimeFrameworkMock().EXPECT().ListByIDs(ctx, []int64{111}).Return( + []database.RuntimeFramework{{ID: 12, FrameName: "fm"}}, nil, + ) + gc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().GetByRepoIDs(ctx, int64(1)).Return( + []database.RepositoriesRuntimeFramework{{RuntimeFrameworkID: 13}}, nil, + ) + gc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().Delete(ctx, int64(13), int64(1), 0).Return(nil) + gc.mocks.runtimeArchComponent.EXPECT().RemoveRuntimeFrameworkTag( + ctx, []*database.Tag{{Name: "t1"}}, int64(1), int64(13), + ).Return() + gc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().Add(ctx, int64(12), int64(1), 0).Return(nil) + gc.mocks.runtimeArchComponent.EXPECT().AddRuntimeFrameworkTag( + ctx, []*database.Tag{{Name: "t1"}}, int64(1), int64(12), + ).Return(nil) + // removed mock + gc.mocks.tagComponent.EXPECT().UpdateLibraryTags( + ctx, database.ModelTagScope, "ns", "n", "bar.go", "", + ).Return(nil) + gc.mocks.tagComponent.EXPECT().ClearMetaTags(ctx, types.ModelRepo, "ns", "n").Return(nil) + // added mock + gc.mocks.tagComponent.EXPECT().UpdateLibraryTags( + ctx, database.ModelTagScope, "ns", "n", "", "foo.go", + ).Return(nil) + gc.mocks.gitServer.EXPECT().GetRepoFileRaw(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Namespace: "ns", + Name: "n", + Ref: "refs/heads/main", + Path: "README.md", + RepoType: types.ModelRepo, + }).Return("", nil) + gc.mocks.tagComponent.EXPECT().UpdateMetaTags( + ctx, database.ModelTagScope, "ns", "n", "", + ).Return(nil, nil) + + err := gc.UpdateRepoInfos(ctx, &types.GiteaCallbackPushReq{ + Ref: "refs/heads/main", + Repository: types.GiteaCallbackPushReq_Repository{ + FullName: "models_ns/n", + }, + Commits: []types.GiteaCallbackPushReq_Commit{ + { + Modified: []string{component.ConfigFileName}, + Removed: []string{"bar.go", types.ReadmeFileName}, + Added: []string{"foo.go", types.ReadmeFileName}, + }, + }, + }) + require.Nil(t, err) +} diff --git a/component/callback/wire.go b/component/callback/wire.go new file mode 100644 index 00000000..e9abaea1 --- /dev/null +++ b/component/callback/wire.go @@ -0,0 +1,27 @@ +//go:build wireinject +// +build wireinject + +package callback + +import ( + "context" + + "github.com/google/wire" + "github.com/stretchr/testify/mock" +) + +type testGitCallbackWithMocks struct { + *gitCallbackComponentImpl + mocks *Mocks +} + +func initializeTestGitCallbackComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testGitCallbackWithMocks { + wire.Build( + MockCallbackSuperSet, GitCallbackComponentSet, + wire.Struct(new(testGitCallbackWithMocks), "*"), + ) + return &testGitCallbackWithMocks{} +} diff --git a/component/callback/wire_gen_test.go b/component/callback/wire_gen_test.go new file mode 100644 index 00000000..f05a8e39 --- /dev/null +++ b/component/callback/wire_gen_test.go @@ -0,0 +1,52 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run -mod=mod github.com/google/wire/cmd/wire +//go:build !wireinject +// +build !wireinject + +package callback + +import ( + "context" + "github.com/stretchr/testify/mock" + "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/rpc" + component2 "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/component" +) + +// Injectors from wire.go: + +func initializeTestGitCallbackComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testGitCallbackWithMocks { + config := component.ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockGitServer := gitserver.NewMockGitServer(t) + mockTagComponent := component2.NewMockTagComponent(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mockRuntimeArchitectureComponent := component2.NewMockRuntimeArchitectureComponent(t) + mockSpaceComponent := component2.NewMockSpaceComponent(t) + callbackGitCallbackComponentImpl := NewTestGitCallbackComponent(config, mockStores, mockGitServer, mockTagComponent, mockModerationSvcClient, mockRuntimeArchitectureComponent, mockSpaceComponent) + mocks := &Mocks{ + stores: mockStores, + tagComponent: mockTagComponent, + spaceComponent: mockSpaceComponent, + gitServer: mockGitServer, + runtimeArchComponent: mockRuntimeArchitectureComponent, + } + callbackTestGitCallbackWithMocks := &testGitCallbackWithMocks{ + gitCallbackComponentImpl: callbackGitCallbackComponentImpl, + mocks: mocks, + } + return callbackTestGitCallbackWithMocks +} + +// wire.go: + +type testGitCallbackWithMocks struct { + *gitCallbackComponentImpl + mocks *Mocks +} diff --git a/component/callback/wireset.go b/component/callback/wireset.go new file mode 100644 index 00000000..4401844c --- /dev/null +++ b/component/callback/wireset.go @@ -0,0 +1,53 @@ +package callback + +import ( + "github.com/google/wire" + mock_git "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" + mock_component "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/rpc" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/component" +) + +type Mocks struct { + stores *tests.MockStores + tagComponent *mock_component.MockTagComponent + spaceComponent *mock_component.MockSpaceComponent + gitServer *mock_git.MockGitServer + runtimeArchComponent *mock_component.MockRuntimeArchitectureComponent +} + +var AllMockSet = wire.NewSet( + wire.Struct(new(Mocks), "*"), +) + +var MockCallbackSuperSet = wire.NewSet( + component.MockedStoreSet, component.MockedComponentSet, AllMockSet, + component.ProvideTestConfig, component.MockedGitServerSet, component.MockedModerationSvcClientSet, +) + +func NewTestGitCallbackComponent(config *config.Config, stores *tests.MockStores, gitServer gitserver.GitServer, tagComponent component.TagComponent, modSvcClient rpc.ModerationSvcClient, runtimeArchComponent component.RuntimeArchitectureComponent, spaceComponent component.SpaceComponent) *gitCallbackComponentImpl { + return &gitCallbackComponentImpl{ + config: config, + gitServer: gitServer, + tagComponent: tagComponent, + modSvcClient: modSvcClient, + modelStore: stores.Model, + datasetStore: stores.Dataset, + spaceComponent: spaceComponent, + spaceStore: stores.Space, + repoStore: stores.Repo, + repoRelationStore: stores.RepoRelation, + mirrorStore: stores.Mirror, + repoRuntimeFrameworkStore: stores.RepoRuntimeFramework, + runtimeArchComponent: runtimeArchComponent, + runtimeArchStore: stores.RuntimeArch, + runtimeFrameworkStore: stores.RuntimeFramework, + tagStore: stores.Tag, + tagRuleStore: stores.TagRule, + } +} + +var GitCallbackComponentSet = wire.NewSet(NewTestGitCallbackComponent) diff --git a/component/cluster_test.go b/component/cluster_test.go new file mode 100644 index 00000000..04e7820d --- /dev/null +++ b/component/cluster_test.go @@ -0,0 +1,42 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/common/types" +) + +func TestClusterComponent_Index(t *testing.T) { + ctx := context.TODO() + cc := initializeTestClusterComponent(ctx, t) + + cc.mocks.deployer.EXPECT().ListCluster(ctx).Return(nil, nil) + + data, err := cc.Index(ctx) + require.Nil(t, err) + require.Equal(t, []types.ClusterRes([]types.ClusterRes(nil)), data) +} + +func TestClusterComponent_GetClusterById(t *testing.T) { + ctx := context.TODO() + cc := initializeTestClusterComponent(ctx, t) + + cc.mocks.deployer.EXPECT().GetClusterById(ctx, "c1").Return(nil, nil) + + data, err := cc.GetClusterById(ctx, "c1") + require.Nil(t, err) + require.Equal(t, (*types.ClusterRes)(nil), data) +} + +func TestClusterComponent_Update(t *testing.T) { + ctx := context.TODO() + cc := initializeTestClusterComponent(ctx, t) + + cc.mocks.deployer.EXPECT().UpdateCluster(ctx, types.ClusterRequest{}).Return(nil, nil) + + data, err := cc.Update(ctx, types.ClusterRequest{}) + require.Nil(t, err) + require.Equal(t, (*types.UpdateClusterResponse)(nil), data) +} diff --git a/component/evaluation.go b/component/evaluation.go index 683261cb..e437a11a 100644 --- a/component/evaluation.go +++ b/component/evaluation.go @@ -15,16 +15,16 @@ import ( ) type evaluationComponentImpl struct { - deployer deploy.Deployer - userStore database.UserStore - modelStore database.ModelStore - datasetStore database.DatasetStore - mirrorStore database.MirrorStore - spaceResourceStore database.SpaceResourceStore - tokenStore database.AccessTokenStore - rtfm database.RuntimeFrameworksStore - config *config.Config - ac AccountingComponent + deployer deploy.Deployer + userStore database.UserStore + modelStore database.ModelStore + datasetStore database.DatasetStore + mirrorStore database.MirrorStore + spaceResourceStore database.SpaceResourceStore + tokenStore database.AccessTokenStore + runtimeFrameworkStore database.RuntimeFrameworksStore + config *config.Config + accountingComponent AccountingComponent } type EvaluationComponent interface { @@ -43,13 +43,13 @@ func NewEvaluationComponent(config *config.Config) (EvaluationComponent, error) c.datasetStore = database.NewDatasetStore() c.mirrorStore = database.NewMirrorStore() c.tokenStore = database.NewAccessTokenStore() - c.rtfm = database.NewRuntimeFrameworksStore() + c.runtimeFrameworkStore = database.NewRuntimeFrameworksStore() c.config = config ac, err := NewAccountingComponent(config) if err != nil { return nil, fmt.Errorf("failed to create accounting component, %w", err) } - c.ac = ac + c.accountingComponent = ac return c, nil } @@ -97,7 +97,7 @@ func (c *evaluationComponentImpl) CreateEvaluation(ctx context.Context, req type hardware.Cpu.Num = "8" hardware.Memory = "32Gi" } - frame, err := c.rtfm.FindEnabledByID(ctx, req.RuntimeFrameworkId) + frame, err := c.runtimeFrameworkStore.FindEnabledByID(ctx, req.RuntimeFrameworkId) if err != nil { return nil, fmt.Errorf("cannot find available runtime framework, %w", err) } diff --git a/component/evaluation_test.go b/component/evaluation_test.go index a1620f14..cccc2ca9 100644 --- a/component/evaluation_test.go +++ b/component/evaluation_test.go @@ -6,32 +6,10 @@ import ( "testing" "github.com/stretchr/testify/require" - mock_deploy "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy" - mock_component "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" - "opencsg.com/csghub-server/builder/deploy" "opencsg.com/csghub-server/builder/store/database" - "opencsg.com/csghub-server/common/config" - "opencsg.com/csghub-server/common/tests" "opencsg.com/csghub-server/common/types" ) -func NewTestEvaluationComponent(deployer deploy.Deployer, stores *tests.MockStores, ac AccountingComponent) EvaluationComponent { - cfg := &config.Config{} - cfg.Argo.QuotaGPUNumber = "1" - return &evaluationComponentImpl{ - deployer: deployer, - config: cfg, - userStore: stores.User, - modelStore: stores.Model, - datasetStore: stores.Dataset, - mirrorStore: stores.Mirror, - spaceResourceStore: stores.SpaceResource, - tokenStore: stores.AccessToken, - rtfm: stores.RuntimeFramework, - ac: ac, - } -} - func TestEvaluationComponent_CreateEvaluation(t *testing.T) { req := types.EvaluationReq{ TaskName: "test", @@ -66,30 +44,28 @@ func TestEvaluationComponent_CreateEvaluation(t *testing.T) { Token: "foo", } t.Run("create evaluation without resource id", func(t *testing.T) { - deployerMock := &mock_deploy.MockDeployer{} - stores := tests.NewMockStores(t) - ac := &mock_component.MockAccountingComponent{} - c := NewTestEvaluationComponent(deployerMock, stores, ac) - stores.UserMock().EXPECT().FindByUsername(ctx, req.Username).Return(database.User{ + c := initializeTestEvaluationComponent(ctx, t) + c.config.Argo.QuotaGPUNumber = "1" + c.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, req.Username).Return(database.User{ RoleMask: "admin", Username: req.Username, UUID: req.Username, ID: 1, }, nil).Once() - stores.ModelMock().EXPECT().FindByPath(ctx, "opencsg", "wukong").Return( + c.mocks.stores.ModelMock().EXPECT().FindByPath(ctx, "opencsg", "wukong").Return( &database.Model{ ID: 1, }, nil, ).Maybe() - stores.MirrorMock().EXPECT().FindByRepoPath(ctx, types.DatasetRepo, "opencsg", "hellaswag").Return(&database.Mirror{ + c.mocks.stores.MirrorMock().EXPECT().FindByRepoPath(ctx, types.DatasetRepo, "opencsg", "hellaswag").Return(&database.Mirror{ SourceRepoPath: "Rowan/hellaswag", }, nil) - stores.AccessTokenMock().EXPECT().FindByUID(ctx, int64(1)).Return(&database.AccessToken{Token: "foo"}, nil) - stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(1)).Return(&database.RuntimeFramework{ + c.mocks.stores.AccessTokenMock().EXPECT().FindByUID(ctx, int64(1)).Return(&database.AccessToken{Token: "foo"}, nil) + c.mocks.stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(1)).Return(&database.RuntimeFramework{ ID: 1, FrameImage: "lm-evaluation-harness:0.4.6", }, nil) - deployerMock.EXPECT().SubmitEvaluation(ctx, req2).Return(&types.ArgoWorkFlowRes{ + c.mocks.deployer.EXPECT().SubmitEvaluation(ctx, req2).Return(&types.ArgoWorkFlowRes{ ID: 1, TaskName: "test", }, nil) @@ -101,36 +77,36 @@ func TestEvaluationComponent_CreateEvaluation(t *testing.T) { t.Run("create evaluation with resource id", func(t *testing.T) { req.ResourceId = 1 req2.ResourceId = 1 - deployerMock := &mock_deploy.MockDeployer{} - stores := tests.NewMockStores(t) - ac := &mock_component.MockAccountingComponent{} - c := NewTestEvaluationComponent(deployerMock, stores, ac) - stores.UserMock().EXPECT().FindByUsername(ctx, req.Username).Return(database.User{ + c := initializeTestEvaluationComponent(ctx, t) + c.config.Argo.QuotaGPUNumber = "1" + c.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, req.Username).Return(database.User{ RoleMask: "admin", Username: req.Username, UUID: req.Username, ID: 1, }, nil).Once() - stores.ModelMock().EXPECT().FindByPath(ctx, "opencsg", "wukong").Return( + c.mocks.stores.ModelMock().EXPECT().FindByPath(ctx, "opencsg", "wukong").Return( &database.Model{ ID: 1, }, nil, ).Maybe() - stores.MirrorMock().EXPECT().FindByRepoPath(ctx, types.DatasetRepo, "opencsg", "hellaswag").Return(&database.Mirror{ + c.mocks.stores.MirrorMock().EXPECT().FindByRepoPath(ctx, types.DatasetRepo, "opencsg", "hellaswag").Return(&database.Mirror{ SourceRepoPath: "Rowan/hellaswag", }, nil) - stores.AccessTokenMock().EXPECT().FindByUID(ctx, int64(1)).Return(&database.AccessToken{Token: "foo"}, nil) - stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(1)).Return(&database.RuntimeFramework{ + c.mocks.stores.AccessTokenMock().EXPECT().FindByUID(ctx, int64(1)).Return(&database.AccessToken{Token: "foo"}, nil) + c.mocks.stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(1)).Return(&database.RuntimeFramework{ ID: 1, FrameImage: "lm-evaluation-harness:0.4.6", }, nil) + resource, err := json.Marshal(req2.Hardware) require.Nil(t, err) - stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return(&database.SpaceResource{ + c.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return(&database.SpaceResource{ ID: 1, Resources: string(resource), }, nil) - deployerMock.EXPECT().SubmitEvaluation(ctx, req2).Return(&types.ArgoWorkFlowRes{ + c.mocks.deployer.EXPECT().SubmitEvaluation(ctx, req2).Return(&types.ArgoWorkFlowRes{ + ID: 1, TaskName: "test", }, nil) @@ -142,15 +118,13 @@ func TestEvaluationComponent_CreateEvaluation(t *testing.T) { } func TestEvaluationComponent_GetEvaluation(t *testing.T) { - deployerMock := &mock_deploy.MockDeployer{} - stores := tests.NewMockStores(t) - ac := &mock_component.MockAccountingComponent{} - c := NewTestEvaluationComponent(deployerMock, stores, ac) + ctx := context.TODO() + c := initializeTestEvaluationComponent(ctx, t) + c.config.Argo.QuotaGPUNumber = "1" req := types.EvaluationGetReq{ Username: "test", } - ctx := context.TODO() - deployerMock.EXPECT().GetEvaluation(ctx, req).Return(&types.ArgoWorkFlowRes{ + c.mocks.deployer.EXPECT().GetEvaluation(ctx, req).Return(&types.ArgoWorkFlowRes{ ID: 1, RepoIds: []string{"Rowan/hellaswag"}, Datasets: []string{"Rowan/hellaswag"}, @@ -161,7 +135,7 @@ func TestEvaluationComponent_GetEvaluation(t *testing.T) { TaskType: "evaluation", Status: "Succeed", }, nil) - stores.DatasetMock().EXPECT().ListByPath(ctx, []string{"Rowan/hellaswag"}).Return([]database.Dataset{ + c.mocks.stores.DatasetMock().EXPECT().ListByPath(ctx, []string{"Rowan/hellaswag"}).Return([]database.Dataset{ { Repository: &database.Repository{ Path: "Rowan/hellaswag", @@ -184,15 +158,13 @@ func TestEvaluationComponent_GetEvaluation(t *testing.T) { } func TestEvaluationComponent_DeleteEvaluation(t *testing.T) { - deployerMock := &mock_deploy.MockDeployer{} - stores := tests.NewMockStores(t) - ac := &mock_component.MockAccountingComponent{} - c := NewTestEvaluationComponent(deployerMock, stores, ac) + ctx := context.TODO() + c := initializeTestEvaluationComponent(ctx, t) + c.config.Argo.QuotaGPUNumber = "1" req := types.EvaluationDelReq{ Username: "test", } - ctx := context.TODO() - deployerMock.EXPECT().DeleteEvaluation(ctx, req).Return(nil) + c.mocks.deployer.EXPECT().DeleteEvaluation(ctx, req).Return(nil) err := c.DeleteEvaluation(ctx, req) require.Nil(t, err) } diff --git a/component/event.go b/component/event.go index 2ea4a1ec..3a28bb69 100644 --- a/component/event.go +++ b/component/event.go @@ -8,7 +8,7 @@ import ( ) type eventComponentImpl struct { - es database.EventStore + eventStore database.EventStore } // NewEventComponent creates a new EventComponent @@ -19,7 +19,7 @@ type EventComponent interface { func NewEventComponent() EventComponent { return &eventComponentImpl{ - es: database.NewEventStore(), + eventStore: database.NewEventStore(), } } @@ -34,5 +34,5 @@ func (ec *eventComponentImpl) NewEvents(ctx context.Context, events []types.Even }) } - return ec.es.BatchSave(ctx, dbevents) + return ec.eventStore.BatchSave(ctx, dbevents) } diff --git a/component/event_test.go b/component/event_test.go new file mode 100644 index 00000000..ebd0a039 --- /dev/null +++ b/component/event_test.go @@ -0,0 +1,22 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestEventComponent_NewEvent(t *testing.T) { + ctx := context.TODO() + ec := initializeTestEventComponent(ctx, t) + + ec.mocks.stores.EventMock().EXPECT().BatchSave(ctx, []database.Event{ + {EventID: "e1"}, + }).Return(nil) + + err := ec.NewEvents(ctx, []types.Event{{ID: "e1"}}) + require.Nil(t, err) +} diff --git a/component/hf_dataset.go b/component/hf_dataset.go index 03c48e89..5b669376 100644 --- a/component/hf_dataset.go +++ b/component/hf_dataset.go @@ -6,6 +6,7 @@ import ( "log/slog" "strings" + "opencsg.com/csghub-server/builder/git" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" @@ -19,22 +20,28 @@ type HFDatasetComponent interface { func NewHFDatasetComponent(config *config.Config) (HFDatasetComponent, error) { c := &hFDatasetComponentImpl{} - c.ts = database.NewTagStore() - c.ds = database.NewDatasetStore() - c.rs = database.NewRepoStore() + c.tagStore = database.NewTagStore() + c.datasetStore = database.NewDatasetStore() + c.repoStore = database.NewRepoStore() var err error - c.repoComponentImpl, err = NewRepoComponentImpl(config) + c.repoComponent, err = NewRepoComponentImpl(config) if err != nil { return nil, err } + gs, err := git.NewGitServer(config) + if err != nil { + return nil, fmt.Errorf("failed to create git server, error: %w", err) + } + c.gitServer = gs return c, nil } type hFDatasetComponentImpl struct { - *repoComponentImpl - ts database.TagStore - ds database.DatasetStore - rs database.RepoStore + repoComponent RepoComponent + tagStore database.TagStore + datasetStore database.DatasetStore + repoStore database.RepoStore + gitServer gitserver.GitServer } func convertFilePathFromRoute(path string) string { @@ -42,12 +49,12 @@ func convertFilePathFromRoute(path string) string { } func (h *hFDatasetComponentImpl) GetPathsInfo(ctx context.Context, req types.PathReq) ([]types.HFDSPathInfo, error) { - ds, err := h.ds.FindByPath(ctx, req.Namespace, req.Name) + ds, err := h.datasetStore.FindByPath(ctx, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find dataset, error: %w", err) } - allow, err := h.AllowReadAccessRepo(ctx, ds.Repository, req.CurrentUser) + allow, err := h.repoComponent.AllowReadAccessRepo(ctx, ds.Repository, req.CurrentUser) if err != nil { return nil, fmt.Errorf("failed to check dataset permission, error: %w", err) } @@ -62,7 +69,7 @@ func (h *hFDatasetComponentImpl) GetPathsInfo(ctx context.Context, req types.Pat Path: convertFilePathFromRoute(req.Path), RepoType: types.DatasetRepo, } - file, _ := h.git.GetRepoFileContents(ctx, getRepoFileTree) + file, _ := h.gitServer.GetRepoFileContents(ctx, getRepoFileTree) if file == nil { return []types.HFDSPathInfo{}, nil } @@ -81,12 +88,12 @@ func (h *hFDatasetComponentImpl) GetPathsInfo(ctx context.Context, req types.Pat } func (h *hFDatasetComponentImpl) GetDatasetTree(ctx context.Context, req types.PathReq) ([]types.HFDSPathInfo, error) { - ds, err := h.ds.FindByPath(ctx, req.Namespace, req.Name) + ds, err := h.datasetStore.FindByPath(ctx, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find dataset tree, error: %w", err) } - allow, err := h.AllowReadAccessRepo(ctx, ds.Repository, req.CurrentUser) + allow, err := h.repoComponent.AllowReadAccessRepo(ctx, ds.Repository, req.CurrentUser) if err != nil { return nil, fmt.Errorf("failed to check dataset permission, error: %w", err) } @@ -102,7 +109,7 @@ func (h *hFDatasetComponentImpl) GetDatasetTree(ctx context.Context, req types.P Path: req.Path, RepoType: types.DatasetRepo, } - tree, err := h.git.GetRepoFileTree(ctx, getRepoFileTree) + tree, err := h.gitServer.GetRepoFileTree(ctx, getRepoFileTree) if err != nil { slog.Warn("failed to get repo file tree", slog.Any("getRepoFileTree", getRepoFileTree), slog.String("error", err.Error())) return []types.HFDSPathInfo{}, nil diff --git a/component/hf_dataset_test.go b/component/hf_dataset_test.go new file mode 100644 index 00000000..b8e0bda1 --- /dev/null +++ b/component/hf_dataset_test.go @@ -0,0 +1,71 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestHFDataset_GetPathsInfo(t *testing.T) { + ctx := context.TODO() + hc := initializeTestHFDatasetComponent(ctx, t) + + dataset := &database.Dataset{} + hc.mocks.stores.DatasetMock().EXPECT().FindByPath(ctx, "ns", "n").Return(dataset, nil) + hc.mocks.components.repo.EXPECT().AllowReadAccessRepo(ctx, dataset.Repository, "user").Return(true, nil) + hc.mocks.gitServer.EXPECT().GetRepoFileContents(ctx, gitserver.GetRepoInfoByPathReq{ + Namespace: "ns", + Name: "n", + Path: "a/b", + Ref: "main", + RepoType: types.DatasetRepo, + }).Return(&types.File{ + Type: "go", LastCommitSHA: "sha", Size: 5, Path: "foo", + }, nil) + + data, err := hc.GetPathsInfo(ctx, types.PathReq{ + Namespace: "ns", + Name: "n", + Ref: "main", + Path: "a/b", + CurrentUser: "user", + }) + require.Nil(t, err) + require.Equal(t, []types.HFDSPathInfo{ + {Type: "file", Path: "foo", Size: 5, OID: "sha"}, + }, data) + +} + +func TestHFDataset_GetDatasetTree(t *testing.T) { + ctx := context.TODO() + hc := initializeTestHFDatasetComponent(ctx, t) + + dataset := &database.Dataset{} + hc.mocks.stores.DatasetMock().EXPECT().FindByPath(ctx, "ns", "n").Return(dataset, nil) + hc.mocks.components.repo.EXPECT().AllowReadAccessRepo(ctx, dataset.Repository, "user").Return(true, nil) + hc.mocks.gitServer.EXPECT().GetRepoFileTree(ctx, gitserver.GetRepoInfoByPathReq{ + Namespace: "ns", + Name: "n", + Path: "a/b", + RepoType: types.DatasetRepo, + }).Return([]*types.File{ + {Type: "go", LastCommitSHA: "sha", Size: 5, Path: "foo"}, + }, nil) + + data, err := hc.GetDatasetTree(ctx, types.PathReq{ + Namespace: "ns", + Name: "n", + Ref: "main", + Path: "a/b", + CurrentUser: "user", + }) + require.Nil(t, err) + require.Equal(t, []types.HFDSPathInfo{ + {Type: "go", Path: "foo", Size: 5, OID: "sha"}, + }, data) +} diff --git a/component/list.go b/component/list.go index 3567692b..2fcccdde 100644 --- a/component/list.go +++ b/component/list.go @@ -16,22 +16,22 @@ type ListComponent interface { func NewListComponent(config *config.Config) (ListComponent, error) { c := &listComponentImpl{} - c.ds = database.NewDatasetStore() - c.ms = database.NewModelStore() - c.ss = database.NewSpaceStore() + c.datasetStore = database.NewDatasetStore() + c.modelStore = database.NewModelStore() + c.spaceStore = database.NewSpaceStore() return c, nil } type listComponentImpl struct { - ms database.ModelStore - ds database.DatasetStore - ss database.SpaceStore + modelStore database.ModelStore + datasetStore database.DatasetStore + spaceStore database.SpaceStore } func (c *listComponentImpl) ListModelsByPath(ctx context.Context, req *types.ListByPathReq) ([]*types.ModelResp, error) { var modelResp []*types.ModelResp - models, err := c.ms.ListByPath(ctx, req.Paths) + models, err := c.modelStore.ListByPath(ctx, req.Paths) if err != nil { slog.Error("error listing models by path", "error", err, slog.Any("paths", req.Paths)) return nil, err @@ -67,7 +67,7 @@ func (c *listComponentImpl) ListModelsByPath(ctx context.Context, req *types.Lis func (c *listComponentImpl) ListDatasetsByPath(ctx context.Context, req *types.ListByPathReq) ([]*types.DatasetResp, error) { var datasetResp []*types.DatasetResp - datasets, err := c.ds.ListByPath(ctx, req.Paths) + datasets, err := c.datasetStore.ListByPath(ctx, req.Paths) if err != nil { slog.Error("error listing datasets by path", "error", err, slog.Any("paths", req.Paths)) return nil, err diff --git a/component/list_test.go b/component/list_test.go new file mode 100644 index 00000000..490503cf --- /dev/null +++ b/component/list_test.go @@ -0,0 +1,46 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestListComponent_ListModelsByPath(t *testing.T) { + ctx := context.TODO() + lc := initializeTestListComponent(ctx, t) + + lc.mocks.stores.ModelMock().EXPECT().ListByPath(ctx, []string{"foo"}).Return( + []database.Model{ + {Repository: &database.Repository{ + Name: "r1", + Tags: []database.Tag{{Name: "t1"}}, + }}, + }, nil, + ) + + data, err := lc.ListModelsByPath(ctx, &types.ListByPathReq{Paths: []string{"foo"}}) + require.Nil(t, err) + require.Equal(t, []*types.ModelResp{{Name: "r1", Tags: []types.RepoTag{{Name: "t1"}}}}, data) +} + +func TestListComponent_ListDatasetByPath(t *testing.T) { + ctx := context.TODO() + lc := initializeTestListComponent(ctx, t) + + lc.mocks.stores.DatasetMock().EXPECT().ListByPath(ctx, []string{"foo"}).Return( + []database.Dataset{ + {Repository: &database.Repository{ + Name: "r1", + Tags: []database.Tag{{Name: "t1"}}, + }}, + }, nil, + ) + + data, err := lc.ListDatasetsByPath(ctx, &types.ListByPathReq{Paths: []string{"foo"}}) + require.Nil(t, err) + require.Equal(t, []*types.ModelResp{{Name: "r1", Tags: []types.RepoTag{{Name: "t1"}}}}, data) +} diff --git a/component/repo_file.go b/component/repo_file.go index 22481ca6..c1f35b0f 100644 --- a/component/repo_file.go +++ b/component/repo_file.go @@ -14,9 +14,9 @@ import ( ) type repoFileComponentImpl struct { - rfs database.RepoFileStore - rs database.RepoStore - gs gitserver.GitServer + repoFileStore database.RepoFileStore + repoStore database.RepoStore + gitServer gitserver.GitServer } type RepoFileComponent interface { @@ -26,23 +26,23 @@ type RepoFileComponent interface { func NewRepoFileComponent(conf *config.Config) (RepoFileComponent, error) { c := &repoFileComponentImpl{ - rfs: database.NewRepoFileStore(), - rs: database.NewRepoStore(), + repoFileStore: database.NewRepoFileStore(), + repoStore: database.NewRepoStore(), } gs, err := git.NewGitServer(conf) if err != nil { return nil, fmt.Errorf("failed to create git server, error: %w", err) } - c.gs = gs + c.gitServer = gs return c, nil } func (c *repoFileComponentImpl) GenRepoFileRecords(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { - repo, err := c.rs.FindByPath(ctx, repoType, namespace, name) + repo, err := c.repoStore.FindByPath(ctx, repoType, namespace, name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) } - return c.createRepoFileRecords(ctx, *repo, "", c.gs.GetRepoFileTree) + return c.createRepoFileRecords(ctx, *repo, "", c.gitServer.GetRepoFileTree) } func (c *repoFileComponentImpl) GenRepoFileRecordsBatch(ctx context.Context, repoType types.RepositoryType, lastRepoID int64, concurrency int) error { @@ -54,7 +54,7 @@ func (c *repoFileComponentImpl) GenRepoFileRecordsBatch(ctx context.Context, rep //TODO: load last repo id from redis cache batch := 10 for { - repos, err := c.rs.BatchGet(ctx, repoType, lastRepoID, batch) + repos, err := c.repoStore.BatchGet(ctx, repoType, lastRepoID, batch) if err != nil { return fmt.Errorf("failed to get repos in batch, error: %w", err) } @@ -65,7 +65,7 @@ func (c *repoFileComponentImpl) GenRepoFileRecordsBatch(ctx context.Context, rep go func(repo database.Repository) { slog.Info("start to get files of repository", slog.Any("repoType", repoType), slog.String("path", repo.Path)) //get file paths of repo - err := c.createRepoFileRecords(ctx, repo, "", c.gs.GetRepoFileTree) + err := c.createRepoFileRecords(ctx, repo, "", c.gitServer.GetRepoFileTree) if err != nil { slog.Error("fail to get all files of repository", slog.String("path", repo.Path), slog.String("repo_type", string(repo.RepositoryType)), @@ -127,7 +127,7 @@ func (c *repoFileComponentImpl) createRepoFileRecords(ctx context.Context, repo var exists bool var err error - if exists, err = c.rfs.Exists(ctx, rf); err != nil { + if exists, err = c.repoFileStore.Exists(ctx, rf); err != nil { slog.Error("failed to check repository file exists", slog.Any("repo_id", repo.ID), slog.String("file_path", rf.Path), slog.String("error", err.Error())) continue @@ -137,7 +137,7 @@ func (c *repoFileComponentImpl) createRepoFileRecords(ctx context.Context, repo slog.Info("skip create exist repository file", slog.Any("repo_id", repo.ID), slog.String("file_path", rf.Path)) continue } - if err := c.rfs.Create(ctx, &rf); err != nil { + if err := c.repoFileStore.Create(ctx, &rf); err != nil { slog.Error("failed to save repository file", slog.Any("repo_id", repo.ID), slog.String("error", err.Error())) return fmt.Errorf("failed to save repository file, error: %w", err) diff --git a/component/repo_file_test.go b/component/repo_file_test.go new file mode 100644 index 00000000..c40c4978 --- /dev/null +++ b/component/repo_file_test.go @@ -0,0 +1,89 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestRepoFileComponent_GenRepoFileRecords(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRepoFileComponent(ctx, t) + + rc.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "ns", "n").Return( + &database.Repository{ID: 1, Path: "foo/bar"}, nil, + ) + rc.mocks.gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Namespace: "foo", + Name: "bar", + }).Return( + []*types.File{ + {Path: "a/b", Type: "dir"}, + {Path: "foo.go", Type: "go"}, + }, nil, + ) + rc.mocks.gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Path: "a/b", + Namespace: "foo", + Name: "bar", + }).Return( + []*types.File{}, nil, + ) + rc.mocks.stores.RepoFileMock().EXPECT().Exists(ctx, database.RepositoryFile{ + RepositoryID: 1, + Path: "foo.go", + FileType: "go", + }).Return(false, nil) + rc.mocks.stores.RepoFileMock().EXPECT().Create(ctx, &database.RepositoryFile{ + RepositoryID: 1, + Path: "foo.go", + FileType: "go", + }).Return(nil) + + err := rc.GenRepoFileRecords(ctx, types.ModelRepo, "ns", "n") + require.Nil(t, err) + +} + +func TestRepoFileComponent_GenRepoFileRecordsBatch(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRepoFileComponent(ctx, t) + + rc.mocks.stores.RepoMock().EXPECT().BatchGet(ctx, types.ModelRepo, int64(1), 10).Return( + []database.Repository{{ID: 1, Path: "foo/bar"}}, nil, + ) + rc.mocks.gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Namespace: "foo", + Name: "bar", + }).Return( + []*types.File{ + {Path: "a/b", Type: "dir"}, + {Path: "foo.go", Type: "go"}, + }, nil, + ) + rc.mocks.gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Path: "a/b", + Namespace: "foo", + Name: "bar", + }).Return( + []*types.File{}, nil, + ) + rc.mocks.stores.RepoFileMock().EXPECT().Exists(ctx, database.RepositoryFile{ + RepositoryID: 1, + Path: "foo.go", + FileType: "go", + }).Return(false, nil) + rc.mocks.stores.RepoFileMock().EXPECT().Create(ctx, &database.RepositoryFile{ + RepositoryID: 1, + Path: "foo.go", + FileType: "go", + }).Return(nil) + + err := rc.GenRepoFileRecordsBatch(ctx, types.ModelRepo, 1, 10) + require.Nil(t, err) +} diff --git a/component/sensitive_test.go b/component/sensitive_test.go index a7c1eb97..3d760373 100644 --- a/component/sensitive_test.go +++ b/component/sensitive_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - mockrpc "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/rpc" mocktypes "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/common/types" "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/sensitive" @@ -15,13 +14,12 @@ import ( ) func TestSensitiveComponent_CheckText(t *testing.T) { - mockModeration := mockrpc.NewMockModerationSvcClient(t) - mockModeration.EXPECT().PassTextCheck(mock.Anything, mock.Anything, mock.Anything).Return(&rpc.CheckResult{ + ctx := context.TODO() + comp := initializeTestSensitiveComponent(ctx, t) + + comp.mocks.moderationClient.EXPECT().PassTextCheck(mock.Anything, mock.Anything, mock.Anything).Return(&rpc.CheckResult{ IsSensitive: false, }, nil) - comp := &sensitiveComponentImpl{ - checker: mockModeration, - } success, err := comp.CheckText(context.TODO(), string(sensitive.ScenarioChatDetection), "test") require.Nil(t, err) @@ -29,13 +27,12 @@ func TestSensitiveComponent_CheckText(t *testing.T) { } func TestSensitiveComponent_CheckImage(t *testing.T) { - mockModeration := mockrpc.NewMockModerationSvcClient(t) - mockModeration.EXPECT().PassImageCheck(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&rpc.CheckResult{ + ctx := context.TODO() + comp := initializeTestSensitiveComponent(ctx, t) + + comp.mocks.moderationClient.EXPECT().PassImageCheck(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&rpc.CheckResult{ IsSensitive: false, }, nil) - comp := &sensitiveComponentImpl{ - checker: mockModeration, - } success, err := comp.CheckImage(context.TODO(), string(sensitive.ScenarioChatDetection), "ossBucketName", "ossObjectName") require.Nil(t, err) @@ -43,13 +40,13 @@ func TestSensitiveComponent_CheckImage(t *testing.T) { } func TestSensitiveComponent_CheckRequestV2(t *testing.T) { - mockModeration := mockrpc.NewMockModerationSvcClient(t) - mockModeration.EXPECT().PassTextCheck(mock.Anything, mock.Anything, mock.Anything).Return(&rpc.CheckResult{ + ctx := context.TODO() + comp := initializeTestSensitiveComponent(ctx, t) + + comp.mocks.moderationClient.EXPECT().PassTextCheck(mock.Anything, mock.Anything, mock.Anything).Return(&rpc.CheckResult{ IsSensitive: false, }, nil).Twice() - comp := &sensitiveComponentImpl{ - checker: mockModeration, - } + mockRequest := mocktypes.NewMockSensitiveRequestV2(t) mockRequest.EXPECT().GetSensitiveFields().Return([]types.SensitiveField{ { diff --git a/component/sshkey.go b/component/sshkey.go index a0f4cd10..9ccc9fa1 100644 --- a/component/sshkey.go +++ b/component/sshkey.go @@ -23,10 +23,10 @@ type SSHKeyComponent interface { func NewSSHKeyComponent(config *config.Config) (SSHKeyComponent, error) { c := &sSHKeyComponentImpl{} - c.ss = database.NewSSHKeyStore() - c.us = database.NewUserStore() + c.sshKeyStore = database.NewSSHKeyStore() + c.userStore = database.NewUserStore() var err error - c.gs, err = git.NewGitServer(config) + c.gitServer, err = git.NewGitServer(config) if err != nil { newError := fmt.Errorf("failed to create git server,error:%w", err) slog.Error(newError.Error()) @@ -36,17 +36,17 @@ func NewSSHKeyComponent(config *config.Config) (SSHKeyComponent, error) { } type sSHKeyComponentImpl struct { - ss database.SSHKeyStore - us database.UserStore - gs gitserver.GitServer + sshKeyStore database.SSHKeyStore + userStore database.UserStore + gitServer gitserver.GitServer } func (c *sSHKeyComponentImpl) Create(ctx context.Context, req *types.CreateSSHKeyRequest) (*database.SSHKey, error) { - user, err := c.us.FindByUsername(ctx, req.Username) + user, err := c.userStore.FindByUsername(ctx, req.Username) if err != nil { return nil, fmt.Errorf("failed to find user,error:%w", err) } - nameExistsKey, err := c.ss.FindByNameAndUserID(ctx, req.Name, user.ID) + nameExistsKey, err := c.sshKeyStore.FindByNameAndUserID(ctx, req.Name, user.ID) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("failed to find if ssh key exists,error:%w", err) } @@ -54,15 +54,14 @@ func (c *sSHKeyComponentImpl) Create(ctx context.Context, req *types.CreateSSHKe return nil, fmt.Errorf("ssh key name already exists") } - contentExistsKey, err := c.ss.FindByKeyContent(ctx, req.Content) + contentExistsKey, err := c.sshKeyStore.FindByKeyContent(ctx, req.Content) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("failed to find if ssh key exists,error:%w", err) } if contentExistsKey.ID != 0 { return nil, fmt.Errorf("ssh key already exists") } - - sk, err := c.gs.CreateSSHKey(req) + sk, err := c.gitServer.CreateSSHKey(req) if err != nil { return nil, fmt.Errorf("failed to create git SSH key,error:%w", err) } @@ -80,7 +79,7 @@ func (c *sSHKeyComponentImpl) Create(ctx context.Context, req *types.CreateSSHKe } sk.UserID = user.ID sk.FingerprintSHA256 = fingerprint - resSk, err := c.ss.Create(ctx, sk) + resSk, err := c.sshKeyStore.Create(ctx, sk) if err != nil { return nil, fmt.Errorf("failed to create database SSH key,error:%w", err) } @@ -88,7 +87,7 @@ func (c *sSHKeyComponentImpl) Create(ctx context.Context, req *types.CreateSSHKe } func (c *sSHKeyComponentImpl) Index(ctx context.Context, username string, per, page int) ([]database.SSHKey, error) { - sks, err := c.ss.Index(ctx, username, per, page) + sks, err := c.sshKeyStore.Index(ctx, username, per, page) if err != nil { return nil, fmt.Errorf("failed to get database SSH keys,error:%w", err) } @@ -96,15 +95,15 @@ func (c *sSHKeyComponentImpl) Index(ctx context.Context, username string, per, p } func (c *sSHKeyComponentImpl) Delete(ctx context.Context, username, name string) error { - sshKey, err := c.ss.FindByUsernameAndName(ctx, username, name) + sshKey, err := c.sshKeyStore.FindByUsernameAndName(ctx, username, name) if err != nil { return fmt.Errorf("failed to get database SSH keys,error:%w", err) } - err = c.gs.DeleteSSHKey(int(sshKey.GitID)) + err = c.gitServer.DeleteSSHKey(int(sshKey.GitID)) if err != nil { return fmt.Errorf("failed to delete git SSH keys,error:%w", err) } - err = c.ss.Delete(ctx, sshKey.GitID) + err = c.sshKeyStore.Delete(ctx, sshKey.ID) if err != nil { return fmt.Errorf("failed to delete database SSH keys,error:%w", err) } diff --git a/component/sshkey_test.go b/component/sshkey_test.go new file mode 100644 index 00000000..5bececec --- /dev/null +++ b/component/sshkey_test.go @@ -0,0 +1,67 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +const testKey = ` +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCn4yeHw9InFrZIxYxFhs5Giam76NPIJ1kOqEq1xvWz4vJJMGkoqosTsqUf+V4Pj18qSUbSEDbwibzkIAPFNRiF1lQWgpFvZrZsTmD6rV1ODYjGPu5HLHqjCY/ffY+n/cAz66sZ5TQUMh+9HmUkVriu/Flfo7dWrbsrC73vgfVptSzSIEehkm4wL40XaZI4wQ7JffdXyqz5CU/lK+CFaPU2nLnxVoL9CEaFbCglcP4sO2jir2Rcx5ZNBMHYpsqk9N4cOxpS/IA9YX2tla3o4wltJoO83Vp0qH1ds15WBAlwUAdpJGDajh93kgYki6Kn2v41IgmqgFcXpmBQ+48QZXfh +` + +func TestSSHKeyComponent_Create(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSSHKeyComponent(ctx, t) + + req := &types.CreateSSHKeyRequest{ + Username: "user", + Name: "n", + Content: testKey, + } + sc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ID: 1}, nil) + sc.mocks.stores.SSHMock().EXPECT().FindByNameAndUserID(ctx, "n", int64(1)).Return( + &database.SSHKey{}, nil, + ) + sc.mocks.stores.SSHMock().EXPECT().FindByKeyContent(ctx, testKey).Return(&database.SSHKey{}, nil) + sc.mocks.gitServer.EXPECT().CreateSSHKey(req).Return(&database.SSHKey{}, nil) + sc.mocks.stores.SSHMock().EXPECT().Create(ctx, &database.SSHKey{ + UserID: 1, + FingerprintSHA256: "DZMgXySN8FuYZo2qvIAZOXNB0J81NMAv1SikyHvCPmw", + }).Return(&database.SSHKey{}, nil) + + data, err := sc.Create(ctx, req) + require.NoError(t, err) + require.Equal(t, &database.SSHKey{}, data) + +} + +func TestSSHKeyComponent_Index(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSSHKeyComponent(ctx, t) + + sc.mocks.stores.SSHMock().EXPECT().Index(ctx, "user", 10, 1).Return( + []database.SSHKey{{Name: "a"}}, nil, + ) + + data, err := sc.Index(ctx, "user", 10, 1) + require.Nil(t, err) + require.Equal(t, data, []database.SSHKey{{Name: "a"}}) +} + +func TestSSHKeyComponent_Delete(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSSHKeyComponent(ctx, t) + + sc.mocks.stores.SSHMock().EXPECT().FindByUsernameAndName(ctx, "user", "key").Return( + database.SSHKey{ID: 1, GitID: 123}, nil, + ) + sc.mocks.gitServer.EXPECT().DeleteSSHKey(123).Return(nil) + sc.mocks.stores.SSHMock().EXPECT().Delete(ctx, int64(1)).Return(nil) + + err := sc.Delete(ctx, "user", "key") + require.Nil(t, err) +} diff --git a/component/sync_client_setting_test.go b/component/sync_client_setting_test.go new file mode 100644 index 00000000..9b1efeba --- /dev/null +++ b/component/sync_client_setting_test.go @@ -0,0 +1,43 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestSyncClientSettingComponent_Create(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSyncClientSettingComponent(ctx, t) + + sc.mocks.stores.SyncClientSettingMock().EXPECT().SyncClientSettingExists(ctx).Return(true, nil) + sc.mocks.stores.SyncClientSettingMock().EXPECT().DeleteAll(ctx).Return(nil) + sc.mocks.stores.SyncClientSettingMock().EXPECT().Create(ctx, &database.SyncClientSetting{ + Token: "t", + ConcurrentCount: 1, + MaxBandwidth: 5, + }).Return(&database.SyncClientSetting{}, nil) + + data, err := sc.Create(ctx, types.CreateSyncClientSettingReq{ + Token: "t", + ConcurrentCount: 1, + MaxBandwidth: 5, + }) + require.Nil(t, err) + require.Equal(t, &database.SyncClientSetting{}, data) + +} + +func TestSyncClientSettingComponent_Show(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSyncClientSettingComponent(ctx, t) + + sc.mocks.stores.SyncClientSettingMock().EXPECT().First(ctx).Return(&database.SyncClientSetting{}, nil) + + data, err := sc.Show(ctx) + require.Nil(t, err) + require.Equal(t, &database.SyncClientSetting{}, data) +} diff --git a/component/telemetry.go b/component/telemetry.go index a1654ae9..bbdd362f 100644 --- a/component/telemetry.go +++ b/component/telemetry.go @@ -12,10 +12,9 @@ import ( ) type telemetryComponentImpl struct { - // Add telemetry related fields and methods here - ts database.TelemetryStore - us database.UserStore - rs database.RepoStore + telemetryStore database.TelemetryStore + userStore database.UserStore + repoStore database.RepoStore } type TelemetryComponent interface { @@ -27,7 +26,7 @@ func NewTelemetryComponent() (TelemetryComponent, error) { ts := database.NewTelemetryStore() us := database.NewUserStore() rs := database.NewRepoStore() - return &telemetryComponentImpl{ts: ts, us: us, rs: rs}, nil + return &telemetryComponentImpl{telemetryStore: ts, userStore: us, repoStore: rs}, nil } func (tc *telemetryComponentImpl) SaveUsageData(ctx context.Context, usage telemetry.Usage) error { @@ -52,7 +51,7 @@ func (tc *telemetryComponentImpl) SaveUsageData(ctx context.Context, usage telem Settings: usage.Settings, Counts: usage.Counts, } - err := tc.ts.Save(ctx, &t) + err := tc.telemetryStore.Save(ctx, &t) if err != nil { return fmt.Errorf("failed to save telemetry data to db: %w", err) } @@ -105,27 +104,27 @@ func (tc *telemetryComponentImpl) GenUsageData(ctx context.Context) (telemetry.U } func (tc *telemetryComponentImpl) getUserCnt(ctx context.Context) (int, error) { - return tc.us.CountUsers(ctx) + return tc.userStore.CountUsers(ctx) } func (tc *telemetryComponentImpl) getCounts(ctx context.Context) (telemetry.Counts, error) { var counts telemetry.Counts - modelCnt, err := tc.rs.CountByRepoType(ctx, types.ModelRepo) + modelCnt, err := tc.repoStore.CountByRepoType(ctx, types.ModelRepo) if err != nil { return counts, fmt.Errorf("failed to get model repo count: %w", err) } - dsCnt, err := tc.rs.CountByRepoType(ctx, types.DatasetRepo) + dsCnt, err := tc.repoStore.CountByRepoType(ctx, types.DatasetRepo) if err != nil { return counts, fmt.Errorf("failed to get dataset repo count: %w", err) } - codeCnt, err := tc.rs.CountByRepoType(ctx, types.CodeRepo) + codeCnt, err := tc.repoStore.CountByRepoType(ctx, types.CodeRepo) if err != nil { return counts, fmt.Errorf("failed to get code repo count: %w", err) } - spaceCnt, err := tc.rs.CountByRepoType(ctx, types.SpaceRepo) + spaceCnt, err := tc.repoStore.CountByRepoType(ctx, types.SpaceRepo) if err != nil { return counts, fmt.Errorf("failed to get space repo count: %w", err) } diff --git a/component/telemetry_test.go b/component/telemetry_test.go new file mode 100644 index 00000000..dd839952 --- /dev/null +++ b/component/telemetry_test.go @@ -0,0 +1,56 @@ +package component + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/common/types/telemetry" +) + +func TestTelemetryComponent_SaveUsageData(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTelemetryComponent(ctx, t) + + tc.mocks.stores.TelemetryMock().EXPECT().Save(ctx, &database.Telemetry{ + UUID: "uid", + Version: "v1", + Licensee: telemetry.Licensee{}, + Settings: telemetry.Settings{}, + Counts: telemetry.Counts{}, + }).Return(nil) + + err := tc.SaveUsageData(ctx, telemetry.Usage{ + UUID: "uid", + Version: "v1", + }) + require.Nil(t, err) + +} + +func TestTelemetryComponent_GenUsageData(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTelemetryComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().CountUsers(ctx).Return(100, nil) + tc.mocks.stores.RepoMock().EXPECT().CountByRepoType(ctx, types.ModelRepo).Return(10, nil) + tc.mocks.stores.RepoMock().EXPECT().CountByRepoType(ctx, types.DatasetRepo).Return(20, nil) + tc.mocks.stores.RepoMock().EXPECT().CountByRepoType(ctx, types.CodeRepo).Return(30, nil) + tc.mocks.stores.RepoMock().EXPECT().CountByRepoType(ctx, types.SpaceRepo).Return(40, nil) + + data, err := tc.GenUsageData(ctx) + require.Nil(t, err) + + require.Equal(t, 100, data.ActiveUserCount) + require.Equal(t, 30, data.Counts.Codes) + require.Equal(t, 20, data.Counts.Datasets) + require.Equal(t, 10, data.Counts.Models) + require.Equal(t, 40, data.Counts.Spaces) + require.Equal(t, 100, data.Counts.TotalRepos) + require.NotEmpty(t, data.UUID) + require.GreaterOrEqual(t, time.Now(), data.RecordedAt) + require.LessOrEqual(t, time.Now().Add(-5*time.Second), data.RecordedAt) +} diff --git a/component/wire.go b/component/wire.go index 89b3ab39..76ebad30 100644 --- a/component/wire.go +++ b/component/wire.go @@ -345,3 +345,163 @@ func initializeTestSpaceSdkComponent(ctx context.Context, t interface { ) return &testSpaceSdkWithMocks{} } + +type testTelemetryWithMocks struct { + *telemetryComponentImpl + mocks *Mocks +} + +func initializeTestTelemetryComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testTelemetryWithMocks { + wire.Build( + MockSuperSet, TelemetryComponentSet, + wire.Struct(new(testTelemetryWithMocks), "*"), + ) + return &testTelemetryWithMocks{} +} + +type testClusterWithMocks struct { + *clusterComponentImpl + mocks *Mocks +} + +func initializeTestClusterComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testClusterWithMocks { + wire.Build( + MockSuperSet, ClusterComponentSet, + wire.Struct(new(testClusterWithMocks), "*"), + ) + return &testClusterWithMocks{} +} + +type testEvaluationWithMocks struct { + *evaluationComponentImpl + mocks *Mocks +} + +func initializeTestEvaluationComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testEvaluationWithMocks { + wire.Build( + MockSuperSet, EvaluationComponentSet, + wire.Struct(new(testEvaluationWithMocks), "*"), + ) + return &testEvaluationWithMocks{} +} + +type testHFDatasetWithMocks struct { + *hFDatasetComponentImpl + mocks *Mocks +} + +func initializeTestHFDatasetComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testHFDatasetWithMocks { + wire.Build( + MockSuperSet, HFDatasetComponentSet, + wire.Struct(new(testHFDatasetWithMocks), "*"), + ) + return &testHFDatasetWithMocks{} +} + +type testRepoFileWithMocks struct { + *repoFileComponentImpl + mocks *Mocks +} + +func initializeTestRepoFileComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testRepoFileWithMocks { + wire.Build( + MockSuperSet, RepoFileComponentSet, + wire.Struct(new(testRepoFileWithMocks), "*"), + ) + return &testRepoFileWithMocks{} +} + +type testSensitiveWithMocks struct { + *sensitiveComponentImpl + mocks *Mocks +} + +func initializeTestSensitiveComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSensitiveWithMocks { + wire.Build( + MockSuperSet, SensitiveComponentSet, + wire.Struct(new(testSensitiveWithMocks), "*"), + ) + return &testSensitiveWithMocks{} +} + +type testSSHKeyWithMocks struct { + *sSHKeyComponentImpl + mocks *Mocks +} + +func initializeTestSSHKeyComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSSHKeyWithMocks { + wire.Build( + MockSuperSet, SSHKeyComponentSet, + wire.Struct(new(testSSHKeyWithMocks), "*"), + ) + return &testSSHKeyWithMocks{} +} + +type testListWithMocks struct { + *listComponentImpl + mocks *Mocks +} + +func initializeTestListComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testListWithMocks { + wire.Build( + MockSuperSet, ListComponentSet, + wire.Struct(new(testListWithMocks), "*"), + ) + return &testListWithMocks{} +} + +type testSyncClientSettingWithMocks struct { + *syncClientSettingComponentImpl + mocks *Mocks +} + +func initializeTestSyncClientSettingComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSyncClientSettingWithMocks { + wire.Build( + MockSuperSet, SyncClientSettingComponentSet, + wire.Struct(new(testSyncClientSettingWithMocks), "*"), + ) + return &testSyncClientSettingWithMocks{} +} + +type testEventWithMocks struct { + *eventComponentImpl + mocks *Mocks +} + +func initializeTestEventComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testEventWithMocks { + wire.Build( + MockSuperSet, EventComponentSet, + wire.Struct(new(testEventWithMocks), "*"), + ) + return &testEventWithMocks{} +} diff --git a/component/wire_gen_test.go b/component/wire_gen_test.go index 2fac5ee9..a8fec516 100644 --- a/component/wire_gen_test.go +++ b/component/wire_gen_test.go @@ -1068,6 +1068,506 @@ func initializeTestSpaceSdkComponent(ctx context.Context, t interface { return componentTestSpaceSdkWithMocks } +func initializeTestTelemetryComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testTelemetryWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + componentTelemetryComponentImpl := NewTestTelemetryComponent(config, mockStores) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestTelemetryWithMocks := &testTelemetryWithMocks{ + telemetryComponentImpl: componentTelemetryComponentImpl, + mocks: mocks, + } + return componentTestTelemetryWithMocks +} + +func initializeTestClusterComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testClusterWithMocks { + config := ProvideTestConfig() + mockDeployer := deploy.NewMockDeployer(t) + componentClusterComponentImpl := NewTestClusterComponent(config, mockDeployer) + mockStores := tests.NewMockStores(t) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestClusterWithMocks := &testClusterWithMocks{ + clusterComponentImpl: componentClusterComponentImpl, + mocks: mocks, + } + return componentTestClusterWithMocks +} + +func initializeTestEvaluationComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testEvaluationWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingComponent := component.NewMockAccountingComponent(t) + componentEvaluationComponentImpl := NewTestEvaluationComponent(config, mockStores, mockDeployer, mockAccountingComponent) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestEvaluationWithMocks := &testEvaluationWithMocks{ + evaluationComponentImpl: componentEvaluationComponentImpl, + mocks: mocks, + } + return componentTestEvaluationWithMocks +} + +func initializeTestHFDatasetComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testHFDatasetWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentHFDatasetComponentImpl := NewTestHFDatasetComponent(config, mockStores, mockRepoComponent, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestHFDatasetWithMocks := &testHFDatasetWithMocks{ + hFDatasetComponentImpl: componentHFDatasetComponentImpl, + mocks: mocks, + } + return componentTestHFDatasetWithMocks +} + +func initializeTestRepoFileComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testRepoFileWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentRepoFileComponentImpl := NewTestRepoFileComponent(config, mockStores, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestRepoFileWithMocks := &testRepoFileWithMocks{ + repoFileComponentImpl: componentRepoFileComponentImpl, + mocks: mocks, + } + return componentTestRepoFileWithMocks +} + +func initializeTestSensitiveComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSensitiveWithMocks { + config := ProvideTestConfig() + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + componentSensitiveComponentImpl := NewTestSensitiveComponent(config, mockModerationSvcClient) + mockStores := tests.NewMockStores(t) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestSensitiveWithMocks := &testSensitiveWithMocks{ + sensitiveComponentImpl: componentSensitiveComponentImpl, + mocks: mocks, + } + return componentTestSensitiveWithMocks +} + +func initializeTestSSHKeyComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSSHKeyWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentSSHKeyComponentImpl := NewTestSSHKeyComponent(config, mockStores, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestSSHKeyWithMocks := &testSSHKeyWithMocks{ + sSHKeyComponentImpl: componentSSHKeyComponentImpl, + mocks: mocks, + } + return componentTestSSHKeyWithMocks +} + +func initializeTestListComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testListWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + componentListComponentImpl := NewTestListComponent(config, mockStores) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestListWithMocks := &testListWithMocks{ + listComponentImpl: componentListComponentImpl, + mocks: mocks, + } + return componentTestListWithMocks +} + +func initializeTestSyncClientSettingComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSyncClientSettingWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + componentSyncClientSettingComponentImpl := NewTestSyncClientSettingComponent(config, mockStores) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestSyncClientSettingWithMocks := &testSyncClientSettingWithMocks{ + syncClientSettingComponentImpl: componentSyncClientSettingComponentImpl, + mocks: mocks, + } + return componentTestSyncClientSettingWithMocks +} + +func initializeTestEventComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testEventWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + componentEventComponentImpl := NewTestEventComponent(config, mockStores) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestEventWithMocks := &testEventWithMocks{ + eventComponentImpl: componentEventComponentImpl, + mocks: mocks, + } + return componentTestEventWithMocks +} + // wire.go: type testRepoWithMocks struct { @@ -1174,3 +1674,53 @@ type testSpaceSdkWithMocks struct { *spaceSdkComponentImpl mocks *Mocks } + +type testTelemetryWithMocks struct { + *telemetryComponentImpl + mocks *Mocks +} + +type testClusterWithMocks struct { + *clusterComponentImpl + mocks *Mocks +} + +type testEvaluationWithMocks struct { + *evaluationComponentImpl + mocks *Mocks +} + +type testHFDatasetWithMocks struct { + *hFDatasetComponentImpl + mocks *Mocks +} + +type testRepoFileWithMocks struct { + *repoFileComponentImpl + mocks *Mocks +} + +type testSensitiveWithMocks struct { + *sensitiveComponentImpl + mocks *Mocks +} + +type testSSHKeyWithMocks struct { + *sSHKeyComponentImpl + mocks *Mocks +} + +type testListWithMocks struct { + *listComponentImpl + mocks *Mocks +} + +type testSyncClientSettingWithMocks struct { + *syncClientSettingComponentImpl + mocks *Mocks +} + +type testEventWithMocks struct { + *eventComponentImpl + mocks *Mocks +} diff --git a/component/wireset.go b/component/wireset.go index b25cc075..6723031f 100644 --- a/component/wireset.go +++ b/component/wireset.go @@ -487,3 +487,104 @@ func NewTestSpaceSdkComponent(config *config.Config, stores *tests.MockStores) * } var SpaceSdkComponentSet = wire.NewSet(NewTestSpaceSdkComponent) + +func NewTestTelemetryComponent(config *config.Config, stores *tests.MockStores) *telemetryComponentImpl { + return &telemetryComponentImpl{ + telemetryStore: stores.Telemetry, + userStore: stores.User, + repoStore: stores.Repo, + } +} + +var TelemetryComponentSet = wire.NewSet(NewTestTelemetryComponent) + +func NewTestClusterComponent(config *config.Config, deployer deploy.Deployer) *clusterComponentImpl { + return &clusterComponentImpl{ + deployer: deployer, + } +} + +var ClusterComponentSet = wire.NewSet(NewTestClusterComponent) + +func NewTestEvaluationComponent(config *config.Config, stores *tests.MockStores, deployer deploy.Deployer, accountingComponent AccountingComponent) *evaluationComponentImpl { + return &evaluationComponentImpl{ + deployer: deployer, + userStore: stores.User, + modelStore: stores.Model, + datasetStore: stores.Dataset, + mirrorStore: stores.Mirror, + spaceResourceStore: stores.SpaceResource, + tokenStore: stores.AccessToken, + runtimeFrameworkStore: stores.RuntimeFramework, + config: config, + accountingComponent: accountingComponent, + } +} + +var EvaluationComponentSet = wire.NewSet(NewTestEvaluationComponent) + +func NewTestHFDatasetComponent(config *config.Config, stores *tests.MockStores, repoComponent RepoComponent, gitServer gitserver.GitServer) *hFDatasetComponentImpl { + return &hFDatasetComponentImpl{ + repoComponent: repoComponent, + tagStore: stores.Tag, + datasetStore: stores.Dataset, + repoStore: stores.Repo, + gitServer: gitServer, + } +} + +var HFDatasetComponentSet = wire.NewSet(NewTestHFDatasetComponent) + +func NewTestRepoFileComponent(config *config.Config, stores *tests.MockStores, gitServer gitserver.GitServer) *repoFileComponentImpl { + return &repoFileComponentImpl{ + repoFileStore: stores.RepoFile, + repoStore: stores.Repo, + gitServer: gitServer, + } +} + +var RepoFileComponentSet = wire.NewSet(NewTestRepoFileComponent) + +func NewTestSensitiveComponent(config *config.Config, checker rpc.ModerationSvcClient) *sensitiveComponentImpl { + return &sensitiveComponentImpl{ + checker: checker, + } +} + +var SensitiveComponentSet = wire.NewSet(NewTestSensitiveComponent) + +func NewTestSSHKeyComponent(config *config.Config, stores *tests.MockStores, gitServer gitserver.GitServer) *sSHKeyComponentImpl { + return &sSHKeyComponentImpl{ + sshKeyStore: stores.SSH, + userStore: stores.User, + gitServer: gitServer, + } +} + +var SSHKeyComponentSet = wire.NewSet(NewTestSSHKeyComponent) + +func NewTestListComponent(config *config.Config, stores *tests.MockStores) *listComponentImpl { + return &listComponentImpl{ + modelStore: stores.Model, + datasetStore: stores.Dataset, + spaceStore: stores.Space, + } +} + +var ListComponentSet = wire.NewSet(NewTestListComponent) + +func NewTestSyncClientSettingComponent(config *config.Config, stores *tests.MockStores) *syncClientSettingComponentImpl { + return &syncClientSettingComponentImpl{ + settingStore: stores.SyncClientSetting, + } +} + +var SyncClientSettingComponentSet = wire.NewSet(NewTestSyncClientSettingComponent) + +func NewTestEventComponent(config *config.Config, stores *tests.MockStores) *eventComponentImpl { + return &eventComponentImpl{ + eventStore: stores.Event, + } +} + +var EventComponentSet = wire.NewSet(NewTestEventComponent)