diff --git a/builder/store/database/event.go b/builder/store/database/event.go index 1add4e6e..dfe1f04a 100644 --- a/builder/store/database/event.go +++ b/builder/store/database/event.go @@ -18,6 +18,12 @@ func NewEventStore() EventStore { } } +func NewEventStoreWithDB(db *DB) EventStore { + return &eventStoreImpl{ + db: db, + } +} + func (s *eventStoreImpl) Save(ctx context.Context, event Event) error { return assertAffectedOneRow(s.db.Core.NewInsert().Model(&event).Exec(ctx)) } diff --git a/builder/store/database/event_test.go b/builder/store/database/event_test.go new file mode 100644 index 00000000..6a15577e --- /dev/null +++ b/builder/store/database/event_test.go @@ -0,0 +1,39 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestEventStore_Save(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewEventStoreWithDB(db) + err := store.Save(ctx, database.Event{ + Module: "m1", + }) + require.Nil(t, err) + event := &database.Event{} + err = db.Core.NewSelect().Model(event).Where("module=?", "m1").Scan(ctx) + require.Nil(t, err) + require.Equal(t, "m1", event.Module) + + err = store.BatchSave(ctx, []database.Event{ + {Module: "m2"}, + {Module: "m3"}, + }) + require.Nil(t, err) + err = db.Core.NewSelect().Model(event).Where("module=?", "m2").Scan(ctx) + require.Nil(t, err) + require.Equal(t, "m2", event.Module) + err = db.Core.NewSelect().Model(event).Where("module=?", "m3").Scan(ctx) + require.Nil(t, err) + require.Equal(t, "m3", event.Module) + +} diff --git a/builder/store/database/lfs_lock.go b/builder/store/database/lfs_lock.go index bd762eba..6b55c366 100644 --- a/builder/store/database/lfs_lock.go +++ b/builder/store/database/lfs_lock.go @@ -22,6 +22,12 @@ func NewLfsLockStore() LfsLockStore { } } +func NewLfsLockStoreWithDB(db *DB) LfsLockStore { + return &lfsLockStoreImpl{ + db: db, + } +} + type LfsLock struct { ID int64 `bun:",pk,autoincrement" json:"id"` RepositoryID int64 `bun:",notnull" json:"repository_id"` @@ -37,7 +43,7 @@ func (s *lfsLockStoreImpl) FindByID(ctx context.Context, ID int64) (*LfsLock, er err := s.db.Operator.Core.NewSelect(). Model(&lfsLock). Relation("User"). - Where("id = ?", ID). + Where("lfs_lock.id = ?", ID). Scan(ctx) if err != nil { return nil, err diff --git a/builder/store/database/lfs_lock_test.go b/builder/store/database/lfs_lock_test.go new file mode 100644 index 00000000..4cc50b8c --- /dev/null +++ b/builder/store/database/lfs_lock_test.go @@ -0,0 +1,47 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestLfsLockStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewLfsLockStoreWithDB(db) + _, err := store.Create(ctx, database.LfsLock{ + RepositoryID: 123, + Path: "foo/bar", + }) + require.Nil(t, err) + + lock := &database.LfsLock{} + err = db.Core.NewSelect().Model(lock).Scan(ctx) + require.Nil(t, err) + require.Equal(t, "foo/bar", lock.Path) + + lock, err = store.FindByID(ctx, lock.ID) + require.Nil(t, err) + require.Equal(t, "foo/bar", lock.Path) + + lock, err = store.FindByPath(ctx, 123, "foo/bar") + require.Nil(t, err) + require.Equal(t, "foo/bar", lock.Path) + + ls, err := store.FindByRepoID(ctx, 123, 1, 10) + require.Nil(t, err) + require.Equal(t, 1, len(ls)) + require.Equal(t, "foo/bar", ls[0].Path) + + err = store.RemoveByID(ctx, lock.ID) + require.Nil(t, err) + _, err = store.FindByID(ctx, lock.ID) + require.NotNil(t, err) + +} diff --git a/builder/store/database/lfs_meta_object.go b/builder/store/database/lfs_meta_object.go index 22b3db70..db43437f 100644 --- a/builder/store/database/lfs_meta_object.go +++ b/builder/store/database/lfs_meta_object.go @@ -25,6 +25,12 @@ func NewLfsMetaObjectStore() LfsMetaObjectStore { } } +func NewLfsMetaObjectStoreWithDB(db *DB) LfsMetaObjectStore { + return &lfsMetaObjectStoreImpl{ + db: db, + } +} + type LfsMetaObject struct { ID int64 `bun:",pk,autoincrement" json:"user_id"` Oid string `bun:",notnull" json:"oid"` @@ -70,10 +76,10 @@ func (s *lfsMetaObjectStoreImpl) Create(ctx context.Context, lfsObj LfsMetaObjec } func (s *lfsMetaObjectStoreImpl) RemoveByOid(ctx context.Context, oid string, repoID int64) error { - err := s.db.Operator.Core.NewDelete(). + _, err := s.db.Operator.Core.NewDelete(). Model(&LfsMetaObject{}). Where("oid = ? and repository_id= ?", oid, repoID). - Scan(ctx) + Exec(ctx) return err } diff --git a/builder/store/database/lfs_meta_object_test.go b/builder/store/database/lfs_meta_object_test.go new file mode 100644 index 00000000..e4d04b77 --- /dev/null +++ b/builder/store/database/lfs_meta_object_test.go @@ -0,0 +1,82 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestLfsMetaStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewLfsMetaObjectStoreWithDB(db) + _, err := store.Create(ctx, database.LfsMetaObject{ + RepositoryID: 123, + Oid: "foobar", + }) + require.Nil(t, err) + + obj := &database.LfsMetaObject{} + err = db.Core.NewSelect().Model(obj).Scan(ctx) + require.Nil(t, err) + require.Equal(t, "foobar", obj.Oid) + + obj, err = store.FindByOID(ctx, 123, "foobar") + require.Nil(t, err) + require.Equal(t, "foobar", obj.Oid) + + objs, err := store.FindByRepoID(ctx, 123) + require.Nil(t, err) + require.Equal(t, 1, len(objs)) + require.Equal(t, "foobar", objs[0].Oid) + + // update + _, err = store.UpdateOrCreate(ctx, database.LfsMetaObject{ + RepositoryID: 123, + Oid: "foobar", + Size: 999, + }) + require.Nil(t, err) + obj, err = store.FindByOID(ctx, 123, "foobar") + require.Nil(t, err) + require.Equal(t, 999, int(obj.Size)) + + // create + _, err = store.UpdateOrCreate(ctx, database.LfsMetaObject{ + RepositoryID: 456, + Oid: "bar", + Size: 998, + }) + require.Nil(t, err) + obj, err = store.FindByOID(ctx, 456, "bar") + require.Nil(t, err) + require.Equal(t, 998, int(obj.Size)) + + err = store.BulkUpdateOrCreate(ctx, []database.LfsMetaObject{ + {RepositoryID: 123, Oid: "foobar", Size: 1}, + {RepositoryID: 456, Oid: "bar", Size: 2}, + {RepositoryID: 789, Oid: "barfoo", Size: 3}, + }) + require.Nil(t, err) + + obj, err = store.FindByOID(ctx, 123, "foobar") + require.Nil(t, err) + require.Equal(t, 1, int(obj.Size)) + obj, err = store.FindByOID(ctx, 456, "bar") + require.Nil(t, err) + require.Equal(t, 2, int(obj.Size)) + obj, err = store.FindByOID(ctx, 789, "barfoo") + require.Nil(t, err) + require.Equal(t, 3, int(obj.Size)) + + err = store.RemoveByOid(ctx, "foobar", 123) + require.Nil(t, err) + _, err = store.FindByOID(ctx, 123, "foobar") + require.NotNil(t, err) + +} diff --git a/builder/store/database/llm_config.go b/builder/store/database/llm_config.go index 6c1758f8..eac56136 100644 --- a/builder/store/database/llm_config.go +++ b/builder/store/database/llm_config.go @@ -27,6 +27,10 @@ func NewLLMConfigStore() LLMConfigStore { return &lLMConfigStoreImpl{db: defaultDB} } +func NewLLMConfigStoreWithDB(db *DB) LLMConfigStore { + return &lLMConfigStoreImpl{db: db} +} + func (s *lLMConfigStoreImpl) GetOptimization(ctx context.Context) (*LLMConfig, error) { var config LLMConfig err := s.db.Operator.Core.NewSelect().Model(&config).Where("type = 1 and enabled = true").Limit(1).Scan(ctx) diff --git a/builder/store/database/llm_config_test.go b/builder/store/database/llm_config_test.go new file mode 100644 index 00000000..5e157025 --- /dev/null +++ b/builder/store/database/llm_config_test.go @@ -0,0 +1,40 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestLLMConfigStore_GetOptimization(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewLLMConfigStoreWithDB(db) + _, err := db.Core.NewInsert().Model(&database.LLMConfig{ + Type: 1, + Enabled: true, + ModelName: "c1", + }).Exec(ctx) + require.Nil(t, err) + _, err = db.Core.NewInsert().Model(&database.LLMConfig{ + Type: 2, + Enabled: true, + ModelName: "c2", + }).Exec(ctx) + require.Nil(t, err) + _, err = db.Core.NewInsert().Model(&database.LLMConfig{ + Type: 1, + Enabled: false, + ModelName: "c3", + }).Exec(ctx) + require.Nil(t, err) + + cfg, err := store.GetOptimization(ctx) + require.Nil(t, err) + require.Equal(t, "c1", cfg.ModelName) +} diff --git a/builder/store/database/member.go b/builder/store/database/member.go index 12f4316f..db2ca309 100644 --- a/builder/store/database/member.go +++ b/builder/store/database/member.go @@ -23,6 +23,12 @@ func NewMemberStore() MemberStore { } } +func NewMemberStoreWithDB(db *DB) MemberStore { + return &memberStoreImpl{ + db: db, + } +} + // Member is the relationship between a user and an organization. type Member struct { ID int64 `bun:",pk,autoincrement" json:"id"` diff --git a/builder/store/database/member_test.go b/builder/store/database/member_test.go new file mode 100644 index 00000000..6af010c1 --- /dev/null +++ b/builder/store/database/member_test.go @@ -0,0 +1,46 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestMemberStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewMemberStoreWithDB(db) + + err := store.Add(ctx, 123, 456, "foo") + require.Nil(t, err) + mem := &database.Member{} + err = db.Core.NewSelect().Model(mem).Where("user_id=?", 456).Scan(ctx) + require.Nil(t, err) + require.Equal(t, "foo", mem.Role) + + mem, err = store.Find(ctx, 123, 456) + require.Nil(t, err) + require.Equal(t, "foo", mem.Role) + + ms, err := store.UserMembers(ctx, 456) + require.Nil(t, err) + require.Equal(t, 1, len(ms)) + require.Equal(t, "foo", ms[0].Role) + + ms, count, err := store.OrganizationMembers(ctx, 123, 10, 1) + require.Nil(t, err) + require.Equal(t, 1, len(ms)) + require.Equal(t, 1, count) + require.Equal(t, "foo", ms[0].Role) + + err = store.Delete(ctx, 123, 456, "foo") + require.Nil(t, err) + _, err = store.Find(ctx, 123, 456) + require.NotNil(t, err) + +} diff --git a/builder/store/database/mirror.go b/builder/store/database/mirror.go index 9f733d56..d9ccd448 100644 --- a/builder/store/database/mirror.go +++ b/builder/store/database/mirror.go @@ -41,6 +41,12 @@ func NewMirrorStore() MirrorStore { } } +func NewMirrorStoreWithDB(db *DB) MirrorStore { + return &mirrorStoreImpl{ + db: db, + } +} + type Mirror struct { ID int64 `bun:",pk,autoincrement" json:"id"` Interval string `bun:",notnull" json:"interval"` @@ -185,7 +191,7 @@ func (s *mirrorStoreImpl) WithPaginationWithRepository(ctx context.Context) ([]M var mirrors []Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirrors). - Relation("Repositoy"). + Relation("Repository"). Scan(ctx) if err != nil { return nil, err diff --git a/builder/store/database/mirror_source.go b/builder/store/database/mirror_source.go index 314f171c..33b9017e 100644 --- a/builder/store/database/mirror_source.go +++ b/builder/store/database/mirror_source.go @@ -25,6 +25,12 @@ func NewMirrorSourceStore() MirrorSourceStore { } } +func NewMirrorSourceStoreWithDB(db *DB) MirrorSourceStore { + return &mirrorSourceStoreImpl{ + db: db, + } +} + type MirrorSource struct { ID int64 `bun:",pk,autoincrement" json:"id"` SourceName string `bun:",notnull,unique" json:"source_name"` diff --git a/builder/store/database/mirror_source_test.go b/builder/store/database/mirror_source_test.go new file mode 100644 index 00000000..b742c3ca --- /dev/null +++ b/builder/store/database/mirror_source_test.go @@ -0,0 +1,54 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestMirrorSourceStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewMirrorSourceStoreWithDB(db) + _, err := store.Create(ctx, &database.MirrorSource{ + SourceName: "foo", + }) + require.Nil(t, err) + + mi := &database.MirrorSource{} + err = db.Core.NewSelect().Model(mi).Scan(ctx) + require.Nil(t, err) + require.Equal(t, "foo", mi.SourceName) + + mi, err = store.Get(ctx, mi.ID) + require.Nil(t, err) + require.Equal(t, "foo", mi.SourceName) + + mi, err = store.FindByName(ctx, "foo") + require.Nil(t, err) + require.Equal(t, "foo", mi.SourceName) + + mi.SourceName = "bar" + err = store.Update(ctx, mi) + require.Nil(t, err) + mi = &database.MirrorSource{} + err = db.Core.NewSelect().Model(mi).Scan(ctx) + require.Nil(t, err) + require.Equal(t, "bar", mi.SourceName) + + mis, err := store.Index(ctx) + require.Nil(t, err) + require.Equal(t, 1, len(mis)) + require.Equal(t, "bar", mis[0].SourceName) + + err = store.Delete(ctx, mi) + require.Nil(t, err) + _, err = store.Get(ctx, mi.ID) + require.NotNil(t, err) + +} diff --git a/builder/store/database/mirror_test.go b/builder/store/database/mirror_test.go new file mode 100644 index 00000000..ab056121 --- /dev/null +++ b/builder/store/database/mirror_test.go @@ -0,0 +1,253 @@ +package database_test + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/common/types" +) + +func TestMirrorStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewMirrorStoreWithDB(db) + _, err := store.Create(ctx, &database.Mirror{ + Interval: "foo", + RepositoryID: 123, + PushMirrorCreated: true, + Status: types.MirrorFinished, + Priority: types.HighMirrorPriority, + }) + require.Nil(t, err) + + mi := &database.Mirror{} + err = db.Core.NewSelect().Model(mi).Scan(ctx) + require.Nil(t, err) + require.Equal(t, "foo", mi.Interval) + + mi, err = store.FindByID(ctx, mi.ID) + require.Nil(t, err) + require.Equal(t, "foo", mi.Interval) + + mi, err = store.FindByRepoID(ctx, 123) + require.Nil(t, err) + require.Equal(t, "foo", mi.Interval) + + exist, err := store.IsExist(ctx, 123) + require.Nil(t, err) + require.True(t, exist) + exist, err = store.IsExist(ctx, 456) + require.Nil(t, err) + require.False(t, exist) + + repo := &database.Repository{ + RepositoryType: types.ModelRepo, + GitPath: "models_ns/n", + Name: "repo", + } + err = db.Core.NewInsert().Model(repo).Scan(ctx, repo) + require.Nil(t, err) + + exist, err = store.IsRepoExist(ctx, types.ModelRepo, "ns", "n") + require.Nil(t, err) + require.True(t, exist) + + exist, err = store.IsRepoExist(ctx, types.ModelRepo, "ns", "n2") + require.Nil(t, err) + require.False(t, exist) + + mi.RepositoryID = repo.ID + err = store.Update(ctx, mi) + require.Nil(t, err) + + err = db.Core.NewSelect().Model(mi).Scan(ctx) + require.Nil(t, err) + require.Equal(t, repo.ID, mi.RepositoryID) + + mi, err = store.FindByRepoPath(ctx, types.ModelRepo, "ns", "n") + require.Nil(t, err) + require.Equal(t, repo.ID, mi.RepositoryID) + + ms, err := store.WithPagination(ctx) + require.Nil(t, err) + require.Equal(t, 1, len(ms)) + + ms, err = store.WithPaginationWithRepository(ctx) + require.Nil(t, err) + require.Equal(t, 1, len(ms)) + require.Equal(t, "repo", ms[0].Repository.Name) + + ms, err = store.PushedMirror(ctx) + require.Nil(t, err) + require.Equal(t, 1, len(ms)) + + ms, err = store.NoPushMirror(ctx) + require.Nil(t, err) + require.Equal(t, 0, len(ms)) + + ms, err = store.Finished(ctx) + require.Nil(t, err) + require.Equal(t, 1, len(ms)) + + ms, err = store.Unfinished(ctx) + require.Nil(t, err) + require.Equal(t, 0, len(ms)) + + mi.AccessToken = "abc" + repo.Nickname = "fooo" + err = store.UpdateMirrorAndRepository(ctx, mi, repo) + require.Nil(t, err) + mi = &database.Mirror{} + err = db.Core.NewSelect().Model(mi).Scan(ctx) + require.Nil(t, err) + require.Equal(t, "abc", mi.AccessToken) + repo = &database.Repository{} + err = db.Core.NewSelect().Model(repo).Scan(ctx) + require.Nil(t, err) + require.Equal(t, "fooo", repo.Nickname) + + err = store.Delete(ctx, mi) + require.Nil(t, err) + _, err = store.FindByID(ctx, mi.ID) + require.NotNil(t, err) + +} + +func TestMirrorStore_FindWithMapping(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewMirrorStoreWithDB(db) + + repos := []*database.Repository{ + {Name: "repo1", RepositoryType: types.ModelRepo, Path: "models_ns/repo1"}, + {Name: "repo2", RepositoryType: types.DatasetRepo, Path: "datasets_ns/repo2"}, + {Name: "repo3", RepositoryType: types.PromptRepo, Path: "prompts_ns/repo3"}, + } + + for _, repo := range repos { + repo.GitPath = repo.Path + err := db.Core.NewInsert().Model(repo).Scan(ctx, repo) + require.Nil(t, err) + sp := strings.Split(repo.Path, "_") + _, err = store.Create(ctx, &database.Mirror{ + RepositoryID: repo.ID, + SourceRepoPath: strings.ReplaceAll(sp[1], "ns/", "nsn/"), + Interval: repo.Name, + }) + require.Nil(t, err) + } + + mi, err := store.FindWithMapping(ctx, types.ModelRepo, "ns", "repo1", types.CSGHubMapping) + require.Nil(t, err) + require.Equal(t, "repo1", mi.Interval) + + _, err = store.FindWithMapping(ctx, types.ModelRepo, "ns", "repo1", types.HFMapping) + require.NotNil(t, err) + + mi, err = store.FindWithMapping(ctx, types.ModelRepo, "nsn", "repo1", types.HFMapping) + require.Nil(t, err) + require.Equal(t, "repo1", mi.Interval) + + mi, err = store.FindWithMapping(ctx, types.DatasetRepo, "nsn", "repo2", types.HFMapping) + require.Nil(t, err) + require.Equal(t, "repo2", mi.Interval) + + mi, err = store.FindWithMapping(ctx, types.PromptRepo, "nsn", "repo3", types.AutoMapping) + require.Nil(t, err) + require.Equal(t, "repo3", mi.Interval) +} + +func TestMirrorStore_ToSync(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewMirrorStoreWithDB(db) + + dt := time.Now().Add(1 * time.Hour) + mirrors := []*database.Mirror{ + {NextExecutionTimestamp: dt, Status: types.MirrorFailed, Interval: "m1"}, + {NextExecutionTimestamp: dt, Status: types.MirrorFinished, Interval: "m2"}, + {NextExecutionTimestamp: dt, Status: types.MirrorIncomplete, Interval: "m3"}, + {NextExecutionTimestamp: dt, Status: types.MirrorRepoSynced, Interval: "m4"}, + {NextExecutionTimestamp: dt, Status: types.MirrorRunning, Interval: "m5"}, + {NextExecutionTimestamp: dt, Status: types.MirrorWaiting, Interval: "m6"}, + {NextExecutionTimestamp: dt.Add(-5 * time.Hour), Status: types.MirrorFinished, Interval: "m7"}, + } + for _, m := range mirrors { + _, err := store.Create(ctx, m) + require.Nil(t, err) + } + + ms, err := store.ToSyncRepo(ctx) + require.Nil(t, err) + names := []string{} + for _, m := range ms { + names = append(names, m.Interval) + } + require.ElementsMatch(t, []string{"m1", "m3", "m6", "m7"}, names) + + ms, err = store.ToSyncLfs(ctx) + require.Nil(t, err) + names = []string{} + for _, m := range ms { + names = append(names, m.Interval) + } + require.ElementsMatch(t, []string{"m4", "m7"}, names) + +} + +func TestMirrorStore_IndexWithPagination(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewMirrorStoreWithDB(db) + + mirrors := []*database.Mirror{ + {Interval: "m1", LocalRepoPath: "foo", SourceUrl: "bar"}, + {Interval: "m2", LocalRepoPath: "bar", SourceUrl: "foo"}, + } + for _, m := range mirrors { + _, err := store.Create(ctx, m) + require.Nil(t, err) + } + + ms, count, err := store.IndexWithPagination(ctx, 10, 1) + require.Nil(t, err) + names := []string{} + for _, m := range ms { + names = append(names, m.Interval) + } + require.Equal(t, 2, count) + require.ElementsMatch(t, []string{"m1", "m2"}, names) + +} + +func TestMirrorStore_StatusCount(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewMirrorStoreWithDB(db) + + mirrors := []*database.Mirror{ + {Interval: "m1", Status: types.MirrorFailed}, + {Interval: "m2", Status: types.MirrorFailed}, + {Interval: "m3", Status: types.MirrorFinished}, + } + for _, m := range mirrors { + _, err := store.Create(ctx, m) + require.Nil(t, err) + } + +} diff --git a/builder/store/database/model.go b/builder/store/database/model.go index 34452bc7..067c0008 100644 --- a/builder/store/database/model.go +++ b/builder/store/database/model.go @@ -37,6 +37,12 @@ func NewModelStore() ModelStore { } } +func NewModelStoreWithDB(db *DB) ModelStore { + return &modelStoreImpl{ + db: db, + } +} + type Model struct { ID int64 `bun:",pk,autoincrement" json:"id"` RepositoryID int64 `bun:",notnull" json:"repository_id"` diff --git a/builder/store/database/model_test.go b/builder/store/database/model_test.go new file mode 100644 index 00000000..f5d70afc --- /dev/null +++ b/builder/store/database/model_test.go @@ -0,0 +1,227 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/common/types" +) + +func TestModelStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewModelStoreWithDB(db) + _, err := store.Create(ctx, database.Model{ + RepositoryID: 123, + }) + require.Nil(t, err) + + m := &database.Model{} + err = db.Core.NewSelect().Model(m).Where("repository_id=?", 123).Scan(ctx) + require.Nil(t, err) + + m, err = store.ByID(ctx, m.ID) + require.Nil(t, err) + require.Equal(t, int64(123), m.RepositoryID) + + m.RepositoryID = 456 + _, err = store.Update(ctx, *m) + require.Nil(t, err) + m = &database.Model{} + err = db.Core.NewSelect().Model(m).Where("repository_id=?", 456).Scan(ctx) + require.Nil(t, err) + + m, err = store.ByRepoID(ctx, 456) + require.Nil(t, err) + require.Equal(t, int64(456), m.RepositoryID) + + ms, err := store.ByRepoIDs(ctx, []int64{456}) + require.Nil(t, err) + require.Equal(t, int64(456), ms[0].RepositoryID) + + _, err = store.CreateIfNotExist(ctx, database.Model{ + RepositoryID: 789, + }) + require.Nil(t, err) + m, err = store.ByRepoID(ctx, 789) + require.Nil(t, err) + require.Equal(t, int64(789), m.RepositoryID) + + repo := &database.Repository{ + Path: "foo/bar", + GitPath: "foo/bar2", + Private: true, + } + err = db.Core.NewInsert().Model(repo).Scan(ctx, repo) + require.Nil(t, err) + m.RepositoryID = repo.ID + _, err = store.Update(ctx, *m) + require.Nil(t, err) + + ms, total, err := store.ByUsername(ctx, "foo", 10, 1, false) + require.Nil(t, err) + require.Equal(t, 1, total) + require.Equal(t, len(ms), 1) + + ms, total, err = store.ByUsername(ctx, "foo", 10, 1, true) + require.Nil(t, err) + require.Equal(t, 0, total) + require.Equal(t, len(ms), 0) + + ms, total, err = store.ByOrgPath(ctx, "foo", 10, 1, false) + require.Nil(t, err) + require.Equal(t, 1, total) + require.Equal(t, len(ms), 1) + + ms, total, err = store.ByOrgPath(ctx, "foo", 10, 1, true) + require.Nil(t, err) + require.Equal(t, 0, total) + require.Equal(t, len(ms), 0) + + m, err = store.FindByPath(ctx, "foo", "bar") + require.Nil(t, err) + require.Equal(t, repo.ID, m.RepositoryID) + + err = store.Delete(ctx, *m) + require.Nil(t, err) + _, err = store.FindByPath(ctx, "foo", "bar") + require.NotNil(t, err) +} + +func TestModelStore_ListByPath(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewModelStoreWithDB(db) + + dt := &database.Tag{} + err := db.Core.NewInsert().Model(&database.Tag{ + Name: "tag1", + Category: "evaluation", + }).Scan(ctx, dt) + require.Nil(t, err) + tag1pk := dt.ID + + err = db.Core.NewInsert().Model(&database.Tag{ + Name: "tag2", + Category: "foo", + }).Scan(ctx, dt) + require.Nil(t, err) + tag2pk := dt.ID + + dr := &database.Repository{} + err = db.Core.NewInsert().Model(&database.Repository{ + Name: "repo", + Path: "foo/bar", + GitPath: "a", + }).Scan(ctx, dr) + require.Nil(t, err) + repopk := dr.ID + + for _, tpk := range []int64{tag1pk, tag2pk} { + _, err = db.Core.NewInsert().Model(&database.RepositoryTag{ + RepositoryID: repopk, + TagID: tpk, + }).Exec(ctx) + require.Nil(t, err) + } + + _, err = store.Create(ctx, database.Model{ + RepositoryID: repopk, + }) + require.Nil(t, err) + + dr2 := &database.Repository{} + err = db.Core.NewInsert().Model(&database.Repository{ + Name: "repo2", + Path: "bar/foo", + GitPath: "b", + }).Scan(ctx, dr2) + require.Nil(t, err) + _, err = store.Create(ctx, database.Model{ + RepositoryID: dr2.ID, + }) + require.Nil(t, err) + + dr3 := &database.Repository{} + err = db.Core.NewInsert().Model(&database.Repository{ + Name: "repo3", + Path: "foo/bar", + GitPath: "c", + RepositoryType: types.ModelRepo, + }).Scan(ctx, dr3) + require.Nil(t, err) + _, err = store.Create(ctx, database.Model{ + RepositoryID: dr3.ID, + }) + require.Nil(t, err) + + dss, err := store.ListByPath(ctx, []string{"bar/foo", "foo/bar"}) + require.Nil(t, err) + require.Equal(t, 3, len(dss)) + + tags := []string{} + for _, t := range dss[1].Repository.Tags { + tags = append(tags, t.Name) + } + require.Equal(t, []string{}, tags) + + names := []string{} + for _, ds := range dss { + names = append(names, ds.Repository.Name) + } + require.Equal(t, []string{"repo2", "repo", "repo3"}, names) + +} + +func TestModelStore_UserLikes(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewModelStoreWithDB(db) + + repos := []*database.Repository{ + {Name: "repo1", Path: "p1", GitPath: "p1"}, + {Name: "repo2", Path: "p2", GitPath: "p2"}, + {Name: "repo3", Path: "p3", GitPath: "p3"}, + } + + for _, repo := range repos { + err := db.Core.NewInsert().Model(repo).Scan(ctx, repo) + require.Nil(t, err) + _, err = store.Create(ctx, database.Model{ + RepositoryID: repo.ID, + }) + require.Nil(t, err) + + } + + _, err := db.Core.NewInsert().Model(&database.UserLike{ + UserID: 123, + RepoID: repos[0].ID, + }).Exec(ctx) + require.Nil(t, err) + _, err = db.Core.NewInsert().Model(&database.UserLike{ + UserID: 123, + RepoID: repos[2].ID, + }).Exec(ctx) + require.Nil(t, err) + + dss, total, err := store.UserLikesModels(ctx, 123, 10, 1) + require.Nil(t, err) + require.Equal(t, 2, total) + + names := []string{} + for _, ds := range dss { + names = append(names, ds.Repository.Name) + } + require.Equal(t, []string{"repo1", "repo3"}, names) + +} diff --git a/builder/store/database/multi_sync.go b/builder/store/database/multi_sync.go index 521efba5..9d810604 100644 --- a/builder/store/database/multi_sync.go +++ b/builder/store/database/multi_sync.go @@ -24,8 +24,14 @@ func NewMultiSyncStore() MultiSyncStore { } } +func NewMultiSyncStoreWithDB(db *DB) MultiSyncStore { + return &multiSyncStoreImpl{ + db: db, + } +} + func (s *multiSyncStoreImpl) Create(ctx context.Context, v SyncVersion) (*SyncVersion, error) { - res, err := s.db.Core.NewInsert().Model(&v).Exec(ctx, &v) + res, err := s.db.Core.NewInsert().Model(&v).Exec(ctx) if err := assertAffectedOneRow(res, err); err != nil { return nil, fmt.Errorf("create sync version in db failed,error:%w", err) } diff --git a/builder/store/database/multi_sync_test.go b/builder/store/database/multi_sync_test.go new file mode 100644 index 00000000..63ac36c9 --- /dev/null +++ b/builder/store/database/multi_sync_test.go @@ -0,0 +1,61 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/common/types" +) + +func TestMultiSyncStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewMultiSyncStoreWithDB(db) + + _, err := store.Create(ctx, database.SyncVersion{ + Version: 123, + SourceID: 1, + RepoPath: "a", + RepoType: types.ModelRepo, + }) + require.Nil(t, err) + + sv := &database.SyncVersion{} + err = db.Core.NewSelect().Model(sv).Where("version=?", 123).Scan(ctx) + require.Nil(t, err) + require.Equal(t, 123, int(sv.Version)) + + _, err = store.Create(ctx, database.SyncVersion{ + Version: 103, + SourceID: 1, + RepoPath: "a", + RepoType: types.ModelRepo, + }) + require.Nil(t, err) + _, err = store.Create(ctx, database.SyncVersion{ + Version: 143, + SourceID: 1, + RepoPath: "a", + RepoType: types.ModelRepo, + }) + require.Nil(t, err) + svs, err := store.GetAfter(ctx, 123, 1) + require.Nil(t, err) + require.Equal(t, len(svs), 1) + require.Equal(t, 143, int(svs[0].Version)) + + svv, err := store.GetLatest(ctx) + require.Nil(t, err) + require.Equal(t, 143, int(svv.Version)) + + svs, err = store.GetAfterDistinct(ctx, 100) + require.Nil(t, err) + require.Equal(t, len(svs), 1) + require.True(t, int(svs[0].Version) > 100) + +} diff --git a/builder/store/database/namespace.go b/builder/store/database/namespace.go index 063fa6f6..28eb1a1d 100644 --- a/builder/store/database/namespace.go +++ b/builder/store/database/namespace.go @@ -4,17 +4,22 @@ import ( "context" ) -type namespaceStoreImpl struct { - db *DB +// Define the NamespaceStore interface +type NamespaceStore interface { + FindByPath(ctx context.Context, path string) (Namespace, error) + Exists(ctx context.Context, path string) (bool, error) } -type NamespaceStore interface { - FindByPath(ctx context.Context, path string) (namespace Namespace, err error) - Exists(ctx context.Context, path string) (exists bool, err error) +type NamespaceStoreImpl struct { + db *DB } func NewNamespaceStore() NamespaceStore { - return &namespaceStoreImpl{db: defaultDB} + return &NamespaceStoreImpl{db: defaultDB} +} + +func NewNamespaceStoreWithDB(db *DB) NamespaceStore { + return &NamespaceStoreImpl{db: db} } type NamespaceType string @@ -34,13 +39,13 @@ type Namespace struct { times } -func (s *namespaceStoreImpl) FindByPath(ctx context.Context, path string) (namespace Namespace, err error) { +func (s *NamespaceStoreImpl) FindByPath(ctx context.Context, path string) (namespace Namespace, err error) { namespace.Path = path err = s.db.Operator.Core.NewSelect().Model(&namespace).Relation("User").Where("path = ?", path).Scan(ctx) return } -func (s *namespaceStoreImpl) Exists(ctx context.Context, path string) (exists bool, err error) { +func (s *NamespaceStoreImpl) Exists(ctx context.Context, path string) (exists bool, err error) { var namespace Namespace return s.db.Operator.Core. NewSelect(). diff --git a/builder/store/database/namespace_test.go b/builder/store/database/namespace_test.go new file mode 100644 index 00000000..8f4970e4 --- /dev/null +++ b/builder/store/database/namespace_test.go @@ -0,0 +1,37 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestNamespaceStore_All(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewNamespaceStoreWithDB(db) + + _, err := db.Core.NewInsert().Model(&database.Namespace{ + Path: "foo/bar", + }).Exec(ctx) + require.Nil(t, err) + + exist, err := store.Exists(ctx, "foo/bar") + require.Nil(t, err) + require.True(t, exist) + exist, err = store.Exists(ctx, "foo/bar2") + require.Nil(t, err) + require.False(t, exist) + + ns, err := store.FindByPath(ctx, "foo/bar") + require.Nil(t, err) + require.Equal(t, "foo/bar", ns.Path) + _, err = store.FindByPath(ctx, "foo/bar2") + require.NotNil(t, err) + +} diff --git a/builder/store/database/organization.go b/builder/store/database/organization.go index ede0fda2..68eec4b8 100644 --- a/builder/store/database/organization.go +++ b/builder/store/database/organization.go @@ -26,6 +26,12 @@ func NewOrgStore() OrgStore { } } +func NewOrgStoreWithDB(db *DB) OrgStore { + return &orgStoreImpl{ + db: db, + } +} + type Organization struct { ID int64 `bun:",pk,autoincrement" json:"id"` Nickname string `bun:"name,notnull" json:"name"` diff --git a/builder/store/database/organization_test.go b/builder/store/database/organization_test.go new file mode 100644 index 00000000..91a8ebf2 --- /dev/null +++ b/builder/store/database/organization_test.go @@ -0,0 +1,80 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestOrganizationStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewOrgStoreWithDB(db) + err := store.Create(ctx, &database.Organization{ + Name: "o1", + }, &database.Namespace{Path: "o1"}) + require.Nil(t, err) + + org := &database.Organization{} + err = db.Core.NewSelect().Model(org).Where("path=?", "o1").Scan(ctx) + require.Nil(t, err) + require.Equal(t, "o1", org.Name) + ns := &database.Namespace{} + err = db.Core.NewSelect().Model(ns).Where("path=?", "o1").Scan(ctx) + require.Nil(t, err) + require.Equal(t, "o1", ns.Path) + require.Equal(t, database.OrgNamespace, ns.NamespaceType) + + orgv, err := store.FindByPath(ctx, "o1") + require.Nil(t, err) + require.Equal(t, "o1", orgv.Name) + + exist, err := store.Exists(ctx, "o1") + require.Nil(t, err) + require.True(t, exist) + exist, err = store.Exists(ctx, "bar") + require.Nil(t, err) + require.False(t, exist) + + org.Homepage = "abc" + err = store.Update(ctx, org) + require.Nil(t, err) + org = &database.Organization{} + err = db.Core.NewSelect().Model(org).Where("path=?", "o1").Scan(ctx) + require.Nil(t, err) + require.Equal(t, "abc", org.Homepage) + + owner := &database.User{Username: "u1"} + err = db.Core.NewInsert().Model(owner).Scan(ctx, owner) + require.Nil(t, err) + + member := &database.Member{ + OrganizationID: org.ID, + UserID: 321, + } + err = db.Core.NewInsert().Model(member).Scan(ctx, member) + require.Nil(t, err) + org.UserID = owner.ID + err = store.Update(ctx, org) + require.Nil(t, err) + + orgs, err := store.GetUserOwnOrgs(ctx, "u1") + require.Nil(t, err) + require.Equal(t, 1, len(orgs)) + + orgs, err = store.GetUserBelongOrgs(ctx, 321) + require.Nil(t, err) + require.Equal(t, 1, len(orgs)) + + err = store.Delete(ctx, "o1") + require.Nil(t, err) + exist, err = store.Exists(ctx, "foo") + require.Nil(t, err) + require.False(t, exist) + +} diff --git a/builder/store/database/prompt.go b/builder/store/database/prompt.go index a8a0d4fb..383a229c 100644 --- a/builder/store/database/prompt.go +++ b/builder/store/database/prompt.go @@ -29,14 +29,14 @@ type PromptStore interface { ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (prompts []Prompt, total int, err error) } -func NewPromptStore() PromptStore { - return &promptStoreImpl{db: defaultDB} -} - func NewPromptStoreWithDB(db *DB) PromptStore { return &promptStoreImpl{db: db} } +func NewPromptStore() PromptStore { + return &promptStoreImpl{db: defaultDB} +} + func (s *promptStoreImpl) Create(ctx context.Context, input Prompt) (*Prompt, error) { res, err := s.db.Core.NewInsert().Model(&input).Exec(ctx, &input) if err := assertAffectedOneRow(res, err); err != nil { @@ -121,6 +121,7 @@ func (s *promptStoreImpl) ByUsername(ctx context.Context, username string, per, if err != nil { return } + total, err = query.Count(ctx) if err != nil { return diff --git a/builder/store/database/prompt_conversation.go b/builder/store/database/prompt_conversation.go index f2bfa535..d1217b40 100644 --- a/builder/store/database/prompt_conversation.go +++ b/builder/store/database/prompt_conversation.go @@ -45,6 +45,10 @@ func NewPromptConversationStore() PromptConversationStore { return &promptConversationStoreImpl{db: defaultDB} } +func NewPromptConversationStoreWithDB(db *DB) PromptConversationStore { + return &promptConversationStoreImpl{db: db} +} + func (p *promptConversationStoreImpl) CreateConversation(ctx context.Context, conversation PromptConversation) error { err := p.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if err := assertAffectedOneRow(tx.NewInsert().Model(&conversation).Exec(ctx)); err != nil { diff --git a/builder/store/database/prompt_conversation_test.go b/builder/store/database/prompt_conversation_test.go new file mode 100644 index 00000000..7426b411 --- /dev/null +++ b/builder/store/database/prompt_conversation_test.go @@ -0,0 +1,95 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestPromptConversationStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewPromptConversationStoreWithDB(db) + msg := &database.PromptConversationMessage{ + ConversationID: "cv", + Content: "msg", + } + err := db.Core.NewInsert().Model(msg).Scan(ctx, msg) + require.Nil(t, err) + err = store.CreateConversation(ctx, database.PromptConversation{ + UserID: 123, + ConversationID: "cv", + Title: "foo", + }) + require.Nil(t, err) + + pc := &database.PromptConversation{} + err = db.Core.NewSelect().Model(pc).Where("user_id=?", 123).Scan(ctx) + require.Nil(t, err) + require.Equal(t, "foo", pc.Title) + + pc, err = store.GetConversationByID(ctx, 123, "cv", false) + require.Nil(t, err) + require.Equal(t, "foo", pc.Title) + require.Nil(t, pc.Messages) + pc, err = store.GetConversationByID(ctx, 123, "cv", true) + require.Nil(t, err) + require.Equal(t, "foo", pc.Title) + require.Equal(t, 1, len(pc.Messages)) + require.Equal(t, "msg", pc.Messages[0].Content) + + pc.Title = "bar" + err = store.UpdateConversation(ctx, *pc) + require.Nil(t, err) + pc = &database.PromptConversation{} + err = db.Core.NewSelect().Model(pc).Where("user_id=?", 123).Scan(ctx) + require.Nil(t, err) + require.Equal(t, "bar", pc.Title) + + pcs, err := store.FindConversationsByUserID(ctx, 123) + require.Nil(t, err) + require.Equal(t, 1, len(pcs)) + require.Equal(t, "bar", pcs[0].Title) + + _, err = store.SaveConversationMessage(ctx, database.PromptConversationMessage{ + ConversationID: pc.ConversationID, + Content: "foobar", + }) + require.Nil(t, err) + + pc, err = store.GetConversationByID(ctx, 123, "cv", true) + require.Nil(t, err) + require.Equal(t, 2, len(pc.Messages)) + + err = store.LikeMessageByID(ctx, msg.ID) + require.Nil(t, err) + err = db.Core.NewSelect().Model(msg).WherePK().Scan(ctx) + require.Nil(t, err) + require.Equal(t, true, msg.UserLike) + err = store.LikeMessageByID(ctx, msg.ID) + require.Nil(t, err) + err = db.Core.NewSelect().Model(msg).WherePK().Scan(ctx) + require.Nil(t, err) + require.Equal(t, false, msg.UserLike) + + err = store.HateMessageByID(ctx, msg.ID) + require.Nil(t, err) + err = db.Core.NewSelect().Model(msg).WherePK().Scan(ctx) + require.Nil(t, err) + require.Equal(t, true, msg.UserHate) + err = store.HateMessageByID(ctx, msg.ID) + require.Nil(t, err) + err = db.Core.NewSelect().Model(msg).WherePK().Scan(ctx) + require.Nil(t, err) + require.Equal(t, false, msg.UserHate) + + err = store.DeleteConversationsByID(ctx, 123, pc.ConversationID) + require.Nil(t, err) + _, err = store.GetConversationByID(ctx, 123, "cv", false) + require.NotNil(t, err) +} diff --git a/builder/store/database/prompt_prefix.go b/builder/store/database/prompt_prefix.go index 03f14e3f..e31fda82 100644 --- a/builder/store/database/prompt_prefix.go +++ b/builder/store/database/prompt_prefix.go @@ -23,6 +23,10 @@ func NewPromptPrefixStore() PromptPrefixStore { return &promptPrefixStoreImpl{db: defaultDB} } +func NewPromptPrefixStoreWithDB(db *DB) PromptPrefixStore { + return &promptPrefixStoreImpl{db: db} +} + func (p *promptPrefixStoreImpl) Get(ctx context.Context) (*PromptPrefix, error) { var prefix PromptPrefix err := p.db.Operator.Core.NewSelect().Model(&prefix).Order("id desc").Limit(1).Scan(ctx) diff --git a/builder/store/database/prompt_prefix_test.go b/builder/store/database/prompt_prefix_test.go new file mode 100644 index 00000000..24b6bcd0 --- /dev/null +++ b/builder/store/database/prompt_prefix_test.go @@ -0,0 +1,32 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestPromptPrefixStore_Get(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewPromptPrefixStoreWithDB(db) + + _, err := db.Core.NewInsert().Model(&database.PromptPrefix{ + EN: "foo", + }).Exec(ctx) + require.Nil(t, err) + _, err = db.Core.NewInsert().Model(&database.PromptPrefix{ + EN: "bar", + }).Exec(ctx) + require.Nil(t, err) + + prefix, err := store.Get(ctx) + require.Nil(t, err) + require.Equal(t, "bar", prefix.EN) + +} diff --git a/builder/store/database/recom.go b/builder/store/database/recom.go index 5535f19e..7387325e 100644 --- a/builder/store/database/recom.go +++ b/builder/store/database/recom.go @@ -22,6 +22,12 @@ func NewRecomStore() RecomStore { } } +func NewRecomStoreWithDB(db *DB) RecomStore { + return &recomStoreImpl{ + db: db, + } +} + // Index returns repos in descend order of score. func (s *recomStoreImpl) Index(ctx context.Context, page, pageSize int) ([]*RecomRepoScore, error) { items := make([]*RecomRepoScore, 0) diff --git a/builder/store/database/recom_test.go b/builder/store/database/recom_test.go new file mode 100644 index 00000000..52866395 --- /dev/null +++ b/builder/store/database/recom_test.go @@ -0,0 +1,57 @@ +package database_test + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestRecomStore_All(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewRecomStoreWithDB(db) + + err := store.UpsertScore(ctx, 123, 1) + require.Nil(t, err) + err = store.UpsertScore(ctx, 123, 2) + require.Nil(t, err) + err = store.UpsertScore(ctx, 456, 1) + require.Nil(t, err) + + scores, err := store.Index(ctx, 0, 10) + require.Nil(t, err) + require.Equal(t, 2, len(scores)) + ids := []string{} + for _, s := range scores { + ids = append(ids, fmt.Sprintf("%d/%.0f", s.RepositoryID, s.Score)) + } + require.Equal(t, []string{"123/2", "456/1"}, ids) + + _, err = db.Core.NewInsert().Model(&database.RecomWeight{Name: "w1"}).Exec(ctx) + require.Nil(t, err) + ws, err := store.LoadWeights(ctx) + require.Nil(t, err) + require.Equal(t, 3, len(ws)) + names := []string{} + for _, w := range ws { + names = append(names, w.Name) + } + require.ElementsMatch(t, []string{"freshness", "downloads", "w1"}, names) + + _, err = db.Core.NewInsert().Model(&database.RecomOpWeight{ + Weight: 3, + RepositoryID: 123, + }).Exec(ctx) + require.Nil(t, err) + wos, err := store.LoadOpWeights(ctx) + require.Nil(t, err) + require.Equal(t, 1, len(wos)) + require.Equal(t, 3, wos[0].Weight) + +} diff --git a/builder/store/database/repo_relation.go b/builder/store/database/repo_relation.go index ad2e7398..6ac2c17d 100644 --- a/builder/store/database/repo_relation.go +++ b/builder/store/database/repo_relation.go @@ -29,6 +29,12 @@ func NewRepoRelationsStore() RepoRelationsStore { } } +func NewRepoRelationsStoreWithDB(db *DB) RepoRelationsStore { + return &repoRelationsStoreImpl{ + db: db, + } +} + type RepoRelation struct { ID int64 `bun:",pk,autoincrement" json:"id"` FromRepoID int64 `bun:",notnull" json:"from_repo_id"` diff --git a/builder/store/database/repo_relation_test.go b/builder/store/database/repo_relation_test.go new file mode 100644 index 00000000..6441e7fd --- /dev/null +++ b/builder/store/database/repo_relation_test.go @@ -0,0 +1,111 @@ +package database_test + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestRepoRelationStore_FromTo(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewRepoRelationsStoreWithDB(db) + relations := []*database.RepoRelation{ + {FromRepoID: 1, ToRepoID: 2}, + {FromRepoID: 1, ToRepoID: 3}, + {FromRepoID: 1, ToRepoID: 4}, + {FromRepoID: 3, ToRepoID: 5}, + } + + for _, rel := range relations { + err := db.Core.NewInsert().Model(rel).Scan(ctx, rel) + require.Nil(t, err) + } + + rs, err := store.From(ctx, 1) + require.Nil(t, err) + ids := []int64{} + for _, r := range rs { + ids = append(ids, r.ToRepoID) + } + require.ElementsMatch(t, []int64{2, 3, 4}, ids) + + rs, err = store.To(ctx, 5) + require.Nil(t, err) + ids = []int64{} + for _, r := range rs { + ids = append(ids, r.FromRepoID) + } + require.ElementsMatch(t, []int64{3}, ids) + + err = store.Delete(ctx, 1, 3) + require.Nil(t, err) + rs, err = store.From(ctx, 1) + require.Nil(t, err) + ids = []int64{} + for _, r := range rs { + ids = append(ids, r.ToRepoID) + } + require.ElementsMatch(t, []int64{2, 4}, ids) + +} + +func TestRepoRelationStore_Override(t *testing.T) { + cases := []struct { + from int64 + to []int64 + expected1To []int64 + expected3To []int64 + }{ + {1, nil, []int64{}, []int64{5}}, + {1, []int64{2}, []int64{2}, []int64{5}}, + {1, []int64{5}, []int64{5}, []int64{5}}, + {1, []int64{2, 3}, []int64{2, 3}, []int64{5}}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewRepoRelationsStoreWithDB(db) + relations := []*database.RepoRelation{ + {FromRepoID: 1, ToRepoID: 2}, + {FromRepoID: 1, ToRepoID: 3}, + {FromRepoID: 1, ToRepoID: 4}, + {FromRepoID: 3, ToRepoID: 5}, + } + + for _, rel := range relations { + err := db.Core.NewInsert().Model(rel).Scan(ctx, rel) + require.Nil(t, err) + } + + err := store.Override(ctx, c.from, c.to...) + require.Nil(t, err) + + rs, err := store.From(ctx, 1) + require.Nil(t, err) + ids := []int64{} + for _, r := range rs { + ids = append(ids, r.ToRepoID) + } + require.ElementsMatch(t, c.expected1To, ids) + + rs, err = store.From(ctx, 3) + require.Nil(t, err) + ids = []int64{} + for _, r := range rs { + ids = append(ids, r.ToRepoID) + } + require.ElementsMatch(t, c.expected3To, ids) + }) + } +} diff --git a/builder/store/database/repository_file.go b/builder/store/database/repository_file.go index 4c268310..e1bc3b09 100644 --- a/builder/store/database/repository_file.go +++ b/builder/store/database/repository_file.go @@ -39,6 +39,12 @@ func NewRepoFileStore() RepoFileStore { } } +func NewRepoFileStoreWithDB(db *DB) RepoFileStore { + return &repoFileStoreImpl{ + db: db, + } +} + func (s *repoFileStoreImpl) Create(ctx context.Context, file *RepositoryFile) error { _, err := s.db.Operator.Core.NewInsert().Model(file).Exec(ctx) return err diff --git a/builder/store/database/repository_file_check.go b/builder/store/database/repository_file_check.go index 09d82b31..d320bae9 100644 --- a/builder/store/database/repository_file_check.go +++ b/builder/store/database/repository_file_check.go @@ -33,6 +33,12 @@ func NewRepoFileCheckStore() RepoFileCheckStore { } } +func NewRepoFileCheckStoreWithDB(db *DB) RepoFileCheckStore { + return &repoFileCheckStoreImpl{ + db: db, + } +} + func (s *repoFileCheckStoreImpl) Create(ctx context.Context, history RepositoryFileCheck) error { _, err := s.db.Operator.Core.NewInsert().Model(&history).Exec(ctx) return err diff --git a/builder/store/database/repository_file_check_test.go b/builder/store/database/repository_file_check_test.go new file mode 100644 index 00000000..e01c728d --- /dev/null +++ b/builder/store/database/repository_file_check_test.go @@ -0,0 +1,39 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/common/types" +) + +func TestRepositoryFileCheckStore_CreateUpsert(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewRepoFileCheckStoreWithDB(db) + + err := store.Create(ctx, database.RepositoryFileCheck{ + RepoFileID: 123, + Message: "foo", + Status: types.SensitiveCheckPass, + }) + require.Nil(t, err) + rf := &database.RepositoryFileCheck{} + err = db.Core.NewSelect().Model(rf).Where("repo_file_id=?", 123).Scan(ctx, rf) + require.Nil(t, err) + require.Equal(t, "foo", rf.Message) + + rf.Message = "bar" + err = store.Upsert(ctx, *rf) + require.Nil(t, err) + rf = &database.RepositoryFileCheck{} + err = db.Core.NewSelect().Model(rf).Where("repo_file_id=?", 123).Scan(ctx, rf) + require.Nil(t, err) + require.Equal(t, "bar", rf.Message) + +} diff --git a/builder/store/database/repository_file_test.go b/builder/store/database/repository_file_test.go new file mode 100644 index 00000000..fb22b225 --- /dev/null +++ b/builder/store/database/repository_file_test.go @@ -0,0 +1,138 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/common/types" +) + +func TestRepoFileStore_Create(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewRepoFileStoreWithDB(db) + + err := store.Create(ctx, &database.RepositoryFile{ + Path: "foo", + RepositoryID: 123, + }) + require.Nil(t, err) + + rf := &database.RepositoryFile{} + err = db.Core.NewSelect().Model(rf).Where("path=?", "foo").Scan(ctx, rf) + require.Nil(t, err) + require.Equal(t, "foo", rf.Path) + +} + +func TestRepoFileStore_BatchGet(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewRepoFileStoreWithDB(db) + + repo := &database.Repository{} + err := db.Core.NewInsert().Model(repo).Scan(ctx, repo) + require.Nil(t, err) + + // check failed file + err = store.Create(ctx, &database.RepositoryFile{ + RepositoryID: repo.ID, + Path: "foo", + Branch: "main", + }) + require.Nil(t, err) + rf := &database.RepositoryFile{} + err = db.Core.NewSelect().Model(rf).Where("path=?", "foo").Scan(ctx, rf) + require.Nil(t, err) + + _, err = db.Core.NewInsert().Model(&database.RepositoryFileCheck{ + RepoFileID: rf.ID, + Status: types.SensitiveCheckFail, + }).Exec(ctx) + require.Nil(t, err) + + // check pass file + err = store.Create(ctx, &database.RepositoryFile{ + RepositoryID: repo.ID, + Path: "bar", + Branch: "main", + }) + require.Nil(t, err) + rf = &database.RepositoryFile{} + err = db.Core.NewSelect().Model(rf).Where("path=?", "bar").Scan(ctx, rf) + require.Nil(t, err) + + _, err = db.Core.NewInsert().Model(&database.RepositoryFileCheck{ + RepoFileID: rf.ID, + Status: types.SensitiveCheckPass, + }).Exec(ctx) + require.Nil(t, err) + + rfs, err := store.BatchGet(ctx, repo.ID, 0, 10) + require.Nil(t, err) + ps := []string{} + for _, rf := range rfs { + ps = append(ps, rf.Path) + } + require.Equal(t, []string{"foo", "bar"}, ps) + + rfs, err = store.BatchGetUnchcked(ctx, repo.ID, 0, 10) + require.Nil(t, err) + ps = []string{} + for _, rf := range rfs { + ps = append(ps, rf.Path) + } + require.Equal(t, []string{}, ps) + + exist, err := store.ExistsSensitiveCheckRecord(ctx, repo.ID, "main", types.SensitiveCheckPass) + require.Nil(t, err) + require.True(t, exist) + exist, err = store.ExistsSensitiveCheckRecord(ctx, repo.ID, "main", types.SensitiveCheckFail) + require.Nil(t, err) + require.True(t, exist) + exist, err = store.ExistsSensitiveCheckRecord(ctx, repo.ID, "main", types.SensitiveCheckSkip) + require.Nil(t, err) + require.False(t, exist) +} + +func TestRepoFileStore_Exists(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewRepoFileStoreWithDB(db) + + err := store.Create(ctx, &database.RepositoryFile{ + Path: "foo", + RepositoryID: 123, + Branch: "main", + CommitSha: "12321", + }) + require.Nil(t, err) + + exist, err := store.Exists(ctx, database.RepositoryFile{ + Path: "foo", + RepositoryID: 123, + Branch: "main", + CommitSha: "12321", + }) + require.Nil(t, err) + require.True(t, exist) + + exist, err = store.Exists(ctx, database.RepositoryFile{ + Path: "foo", + RepositoryID: 123, + Branch: "main", + CommitSha: "12322", + }) + require.Nil(t, err) + require.False(t, exist) + +} diff --git a/builder/store/database/repository_runtime_framework.go b/builder/store/database/repository_runtime_framework.go index 841a29f3..b060f6da 100644 --- a/builder/store/database/repository_runtime_framework.go +++ b/builder/store/database/repository_runtime_framework.go @@ -5,16 +5,33 @@ import ( "fmt" ) -type RepositoriesRuntimeFrameworkStore struct { +type RepositoriesRuntimeFrameworkStore interface { + ListByRuntimeFrameworkID(ctx context.Context, runtimeFrameworkID int64, deployType int) ([]RepositoriesRuntimeFramework, error) + Add(ctx context.Context, runtimeFrameworkID, repoID int64, deployType int) error + Delete(ctx context.Context, runtimeFrameworkID, repoID int64, deployType int) error + DeleteByRepoID(ctx context.Context, repoID int64) error + GetByIDsAndType(ctx context.Context, runtimeFrameworkID, repoID int64, deployType int) ([]RepositoriesRuntimeFramework, error) + ListRepoIDsByType(ctx context.Context, deployType int) ([]RepositoriesRuntimeFramework, error) + GetByRepoIDsAndType(ctx context.Context, repoID int64, deployType int) ([]RepositoriesRuntimeFramework, error) + GetByRepoIDs(ctx context.Context, repoID int64) ([]RepositoriesRuntimeFramework, error) +} + +type repositoriesRuntimeFrameworkStoreImpl struct { db *DB } -func NewRepositoriesRuntimeFramework() *RepositoriesRuntimeFrameworkStore { - return &RepositoriesRuntimeFrameworkStore{ +func NewRepositoriesRuntimeFramework() RepositoriesRuntimeFrameworkStore { + return &repositoriesRuntimeFrameworkStoreImpl{ db: defaultDB, } } +func NewRepositoriesRuntimeFrameworkWithDB(db *DB) RepositoriesRuntimeFrameworkStore { + return &repositoriesRuntimeFrameworkStoreImpl{ + db: db, + } +} + type RepositoriesRuntimeFramework struct { ID int64 `bun:",pk,autoincrement" json:"id"` RuntimeFrameworkID int64 `bun:",notnull" json:"runtime_framework_id"` @@ -23,7 +40,7 @@ type RepositoriesRuntimeFramework struct { Type int `bun:",notnull" json:"type"` // 0-space, 1-inference, 2-finetune } -func (m *RepositoriesRuntimeFrameworkStore) ListByRuntimeFrameworkID(ctx context.Context, runtimeFrameworkID int64, deployType int) ([]RepositoriesRuntimeFramework, error) { +func (m *repositoriesRuntimeFrameworkStoreImpl) ListByRuntimeFrameworkID(ctx context.Context, runtimeFrameworkID int64, deployType int) ([]RepositoriesRuntimeFramework, error) { var result []RepositoriesRuntimeFramework _, err := m.db.Operator.Core. NewSelect(). @@ -35,7 +52,7 @@ func (m *RepositoriesRuntimeFrameworkStore) ListByRuntimeFrameworkID(ctx context return result, nil } -func (m *RepositoriesRuntimeFrameworkStore) Add(ctx context.Context, runtimeFrameworkID, repoID int64, deployType int) error { +func (m *repositoriesRuntimeFrameworkStoreImpl) Add(ctx context.Context, runtimeFrameworkID, repoID int64, deployType int) error { relation := RepositoriesRuntimeFramework{ RuntimeFrameworkID: runtimeFrameworkID, RepoID: repoID, @@ -45,7 +62,7 @@ func (m *RepositoriesRuntimeFrameworkStore) Add(ctx context.Context, runtimeFram return err } -func (m *RepositoriesRuntimeFrameworkStore) Delete(ctx context.Context, runtimeFrameworkID, repoID int64, deployType int) error { +func (m *repositoriesRuntimeFrameworkStoreImpl) Delete(ctx context.Context, runtimeFrameworkID, repoID int64, deployType int) error { res, err := m.db.BunDB.Exec("delete from repositories_runtime_frameworks where type = ? and repo_id = ? and runtime_framework_id = ?", deployType, repoID, runtimeFrameworkID) if err != nil { return err @@ -54,7 +71,7 @@ func (m *RepositoriesRuntimeFrameworkStore) Delete(ctx context.Context, runtimeF return err } -func (m *RepositoriesRuntimeFrameworkStore) DeleteByRepoID(ctx context.Context, repoID int64) error { +func (m *repositoriesRuntimeFrameworkStoreImpl) DeleteByRepoID(ctx context.Context, repoID int64) error { _, err := m.db.Operator.Core.NewDelete().Model((*RepositoriesRuntimeFramework)(nil)).Where("repo_id = ?", repoID).Exec(ctx) if err != nil { return fmt.Errorf("delete repo runtime failed, %w", err) @@ -62,25 +79,25 @@ func (m *RepositoriesRuntimeFrameworkStore) DeleteByRepoID(ctx context.Context, return nil } -func (m *RepositoriesRuntimeFrameworkStore) GetByIDsAndType(ctx context.Context, runtimeFrameworkID, repoID int64, deployType int) ([]RepositoriesRuntimeFramework, error) { +func (m *repositoriesRuntimeFrameworkStoreImpl) GetByIDsAndType(ctx context.Context, runtimeFrameworkID, repoID int64, deployType int) ([]RepositoriesRuntimeFramework, error) { var result []RepositoriesRuntimeFramework _, err := m.db.Operator.Core.NewSelect().Model(&result).Where("type = ? and repo_id=? and runtime_framework_id = ?", deployType, repoID, runtimeFrameworkID).Exec(ctx, &result) return result, err } -func (m *RepositoriesRuntimeFrameworkStore) ListRepoIDsByType(ctx context.Context, deployType int) ([]RepositoriesRuntimeFramework, error) { +func (m *repositoriesRuntimeFrameworkStoreImpl) ListRepoIDsByType(ctx context.Context, deployType int) ([]RepositoriesRuntimeFramework, error) { var result []RepositoriesRuntimeFramework _, err := m.db.Operator.Core.NewSelect().Model(&result).Where("type = ?", deployType).Exec(ctx, &result) return result, err } -func (m *RepositoriesRuntimeFrameworkStore) GetByRepoIDsAndType(ctx context.Context, repoID int64, deployType int) ([]RepositoriesRuntimeFramework, error) { +func (m *repositoriesRuntimeFrameworkStoreImpl) GetByRepoIDsAndType(ctx context.Context, repoID int64, deployType int) ([]RepositoriesRuntimeFramework, error) { var result []RepositoriesRuntimeFramework _, err := m.db.Operator.Core.NewSelect().Model(&result).Where("type = ? and repo_id=?", deployType, repoID).Exec(ctx, &result) return result, err } -func (m *RepositoriesRuntimeFrameworkStore) GetByRepoIDs(ctx context.Context, repoID int64) ([]RepositoriesRuntimeFramework, error) { +func (m *repositoriesRuntimeFrameworkStoreImpl) GetByRepoIDs(ctx context.Context, repoID int64) ([]RepositoriesRuntimeFramework, error) { var result []RepositoriesRuntimeFramework _, err := m.db.Operator.Core.NewSelect().Model(&result).Where("repo_id=?", repoID).Exec(ctx, &result) if err != nil { diff --git a/builder/store/database/repository_runtime_framework_test.go b/builder/store/database/repository_runtime_framework_test.go new file mode 100644 index 00000000..e2f43f09 --- /dev/null +++ b/builder/store/database/repository_runtime_framework_test.go @@ -0,0 +1,65 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestRepoRuntimeFrameworkStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewRepositoriesRuntimeFrameworkWithDB(db) + err := store.Add(ctx, 123, 456, 1) + require.Nil(t, err) + + rf := &database.RepositoriesRuntimeFramework{} + err = db.Core.NewSelect().Model(rf).Where("repo_id=?", 456).Scan(ctx) + require.Nil(t, err) + require.Equal(t, 456, int(rf.RepoID)) + + rfs, err := store.GetByIDsAndType(ctx, 123, 456, 1) + require.Nil(t, err) + require.Equal(t, 1, len(rfs)) + require.Equal(t, 456, int(rfs[0].RepoID)) + + rfs, err = store.ListRepoIDsByType(ctx, 1) + require.Nil(t, err) + require.Equal(t, 1, len(rfs)) + require.Equal(t, 456, int(rfs[0].RepoID)) + rfs, err = store.ListRepoIDsByType(ctx, 2) + require.Nil(t, err) + require.Equal(t, 0, len(rfs)) + + rfs, err = store.GetByRepoIDsAndType(ctx, 456, 1) + require.Nil(t, err) + require.Equal(t, 1, len(rfs)) + require.Equal(t, 456, int(rfs[0].RepoID)) + rfs, err = store.GetByRepoIDsAndType(ctx, 456, 2) + require.Nil(t, err) + require.Equal(t, 0, len(rfs)) + + rfs, err = store.GetByRepoIDs(ctx, 456) + require.Nil(t, err) + require.Equal(t, 1, len(rfs)) + require.Equal(t, 456, int(rfs[0].RepoID)) + + err = store.Delete(ctx, 123, 456, 1) + require.Nil(t, err) + rfs, err = store.GetByIDsAndType(ctx, 123, 456, 1) + require.Nil(t, err) + require.Equal(t, 0, len(rfs)) + + err = store.Add(ctx, 123, 456, 1) + require.Nil(t, err) + err = store.DeleteByRepoID(ctx, 456) + require.Nil(t, err) + rfs, err = store.GetByIDsAndType(ctx, 123, 456, 1) + require.Nil(t, err) + require.Equal(t, 0, len(rfs)) +} diff --git a/builder/store/database/resources_models.go b/builder/store/database/resources_models.go index 02348331..5dbc1da9 100644 --- a/builder/store/database/resources_models.go +++ b/builder/store/database/resources_models.go @@ -19,6 +19,10 @@ func NewResourceModelStore() ResourceModelStore { return &resourceModelStoreImpl{db: defaultDB} } +func NewResourceModelStoreWithDB(db *DB) ResourceModelStore { + return &resourceModelStoreImpl{db: db} +} + type ResourceModel struct { ID int64 `bun:",pk,autoincrement" json:"id"` ResourceName string `bun:",notnull" json:"resource_name"` diff --git a/builder/store/database/resources_models_test.go b/builder/store/database/resources_models_test.go new file mode 100644 index 00000000..a0c80a21 --- /dev/null +++ b/builder/store/database/resources_models_test.go @@ -0,0 +1,45 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestResourceModelStore_All(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewResourceModelStoreWithDB(db) + _, err := db.Core.NewInsert().Model(&database.ResourceModel{ + ModelName: "foo", + }).Exec(ctx) + require.Nil(t, err) + _, err = db.Core.NewInsert().Model(&database.ResourceModel{ + ModelName: "bar", + }).Exec(ctx) + require.Nil(t, err) + + ms, err := store.FindByModelName(ctx, "foo") + require.Nil(t, err) + require.Equal(t, 1, len(ms)) + require.Equal(t, "foo", ms[0].ModelName) + + _, err = db.Core.NewInsert().Model(&database.RepositoriesRuntimeFramework{ + RepoID: 123, + }).Exec(ctx) + require.Nil(t, err) + + m, err := store.CheckModelNameNotInRFRepo(ctx, "foo", 456) + require.Nil(t, err) + require.Equal(t, "foo", m.ModelName) + + m, err = store.CheckModelNameNotInRFRepo(ctx, "foo", 123) + require.Nil(t, err) + require.Nil(t, m) + +} diff --git a/builder/store/database/space.go b/builder/store/database/space.go index c9ecb6d4..9c0bb665 100644 --- a/builder/store/database/space.go +++ b/builder/store/database/space.go @@ -14,8 +14,8 @@ type spaceStoreImpl struct { } type SpaceStore interface { - BeginTx(ctx context.Context) (bun.Tx, error) - CreateTx(ctx context.Context, tx bun.Tx, input Space) (*Space, error) + // BeginTx(ctx context.Context) (bun.Tx, error) + // CreateTx(ctx context.Context, tx bun.Tx, input Space) (*Space, error) Create(ctx context.Context, input Space) (*Space, error) Update(ctx context.Context, input Space) (err error) FindByPath(ctx context.Context, namespace, name string) (*Space, error) @@ -36,20 +36,26 @@ func NewSpaceStore() SpaceStore { } } -func (s *spaceStoreImpl) BeginTx(ctx context.Context) (bun.Tx, error) { - return s.db.Core.BeginTx(ctx, nil) +func NewSpaceStoreWithDB(db *DB) SpaceStore { + return &spaceStoreImpl{ + db: db, + } } -func (s *spaceStoreImpl) CreateTx(ctx context.Context, tx bun.Tx, input Space) (*Space, error) { - res, err := tx.NewInsert().Model(&input).Exec(ctx) - if err := assertAffectedOneRow(res, err); err != nil { - slog.Error("create space in tx failed", slog.String("error", err.Error())) - return nil, fmt.Errorf("create space in tx failed,error:%w", err) - } +// func (s *spaceStoreImpl) BeginTx(ctx context.Context) (bun.Tx, error) { +// return s.db.Core.BeginTx(ctx, nil) +// } - input.ID, _ = res.LastInsertId() - return &input, nil -} +// func (s *spaceStoreImpl) CreateTx(ctx context.Context, tx bun.Tx, input Space) (*Space, error) { +// res, err := tx.NewInsert().Model(&input).Exec(ctx) +// if err := assertAffectedOneRow(res, err); err != nil { +// slog.Error("create space in tx failed", slog.String("error", err.Error())) +// return nil, fmt.Errorf("create space in tx failed,error:%w", err) +// } + +// input.ID, _ = res.LastInsertId() +// return &input, nil +// } func (s *spaceStoreImpl) Create(ctx context.Context, input Space) (*Space, error) { res, err := s.db.Core.NewInsert().Model(&input).Exec(ctx) diff --git a/builder/store/database/space_resource.go b/builder/store/database/space_resource.go index deaeada8..e1d2dc08 100644 --- a/builder/store/database/space_resource.go +++ b/builder/store/database/space_resource.go @@ -23,6 +23,10 @@ func NewSpaceResourceStore() SpaceResourceStore { return &spaceResourceStoreImpl{db: defaultDB} } +func NewSpaceResourceStoreWithDB(db *DB) SpaceResourceStore { + return &spaceResourceStoreImpl{db: db} +} + type SpaceResource struct { ID int64 `bun:",pk,autoincrement" json:"id"` Name string `bun:",notnull" json:"name"` diff --git a/builder/store/database/space_resource_test.go b/builder/store/database/space_resource_test.go new file mode 100644 index 00000000..28885c44 --- /dev/null +++ b/builder/store/database/space_resource_test.go @@ -0,0 +1,59 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestSpaceResourceStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewSpaceResourceStoreWithDB(db) + + _, err := store.Create(ctx, database.SpaceResource{ + Name: "r1", + ClusterID: "c1", + }) + require.Nil(t, err) + sr := &database.SpaceResource{} + err = db.Core.NewSelect().Model(sr).Where("name=?", "r1").Scan(ctx, sr) + require.Nil(t, err) + require.Equal(t, "c1", sr.ClusterID) + + sr, err = store.FindByID(ctx, sr.ID) + require.Nil(t, err) + require.Equal(t, "c1", sr.ClusterID) + + sr, err = store.FindByName(ctx, "r1") + require.Nil(t, err) + require.Equal(t, "c1", sr.ClusterID) + + srs, err := store.FindAll(ctx) + require.Nil(t, err) + require.Equal(t, 1, len(srs)) + require.Equal(t, "c1", srs[0].ClusterID) + + srs, err = store.Index(ctx, "c1") + require.Nil(t, err) + require.Equal(t, 1, len(srs)) + require.Equal(t, "c1", srs[0].ClusterID) + + sr.Name = "r2" + _, err = store.Update(ctx, *sr) + require.Nil(t, err) + sr, err = store.FindByID(ctx, sr.ID) + require.Nil(t, err) + require.Equal(t, "r2", sr.Name) + + err = store.Delete(ctx, *sr) + require.Nil(t, err) + _, err = store.FindByID(ctx, sr.ID) + require.NotNil(t, err) + +} diff --git a/builder/store/database/space_sdk.go b/builder/store/database/space_sdk.go index a5f243aa..e4246a6d 100644 --- a/builder/store/database/space_sdk.go +++ b/builder/store/database/space_sdk.go @@ -21,6 +21,10 @@ func NewSpaceSdkStore() SpaceSdkStore { return &spaceSdkStoreImpl{db: defaultDB} } +func NewSpaceSdkStoreWithDB(db *DB) SpaceSdkStore { + return &spaceSdkStoreImpl{db: db} +} + type SpaceSdk struct { ID int64 `bun:",pk,autoincrement" json:"id"` Name string `bun:",notnull" json:"name"` diff --git a/builder/store/database/space_sdk_test.go b/builder/store/database/space_sdk_test.go new file mode 100644 index 00000000..410ab5a4 --- /dev/null +++ b/builder/store/database/space_sdk_test.go @@ -0,0 +1,50 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestSpaceSDKStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewSpaceSdkStoreWithDB(db) + + _, err := store.Create(ctx, database.SpaceSdk{ + Name: "r1", + Version: "v1", + }) + require.Nil(t, err) + ss := &database.SpaceSdk{} + err = db.Core.NewSelect().Model(ss).Where("name=?", "r1").Scan(ctx, ss) + require.Nil(t, err) + require.Equal(t, "v1", ss.Version) + + ss, err = store.FindByID(ctx, ss.ID) + require.Nil(t, err) + require.Equal(t, "v1", ss.Version) + + sss, err := store.Index(ctx) + require.Nil(t, err) + require.Equal(t, 1, len(sss)) + require.Equal(t, "v1", sss[0].Version) + + ss.Name = "r2" + _, err = store.Update(ctx, *ss) + require.Nil(t, err) + ss, err = store.FindByID(ctx, ss.ID) + require.Nil(t, err) + require.Equal(t, "r2", ss.Name) + + err = store.Delete(ctx, *ss) + require.Nil(t, err) + _, err = store.FindByID(ctx, ss.ID) + require.NotNil(t, err) + +} diff --git a/builder/store/database/space_test.go b/builder/store/database/space_test.go new file mode 100644 index 00000000..f585694b --- /dev/null +++ b/builder/store/database/space_test.go @@ -0,0 +1,220 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/common/types" +) + +func TestSpaceStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewSpaceStoreWithDB(db) + _, err := store.Create(ctx, database.Space{ + RepositoryID: 123, + }) + require.Nil(t, err) + + sp := &database.Space{} + err = db.Core.NewSelect().Model(sp).Where("repository_id=?", 123).Scan(ctx) + require.Nil(t, err) + + sp, err = store.ByID(ctx, sp.ID) + require.Nil(t, err) + require.Equal(t, int64(123), sp.RepositoryID) + + sp.RepositoryID = 456 + err = store.Update(ctx, *sp) + require.Nil(t, err) + sp = &database.Space{} + err = db.Core.NewSelect().Model(sp).Where("repository_id=?", 456).Scan(ctx) + require.Nil(t, err) + + sp, err = store.ByRepoID(ctx, 456) + require.Nil(t, err) + require.Equal(t, int64(456), sp.RepositoryID) + + sps, err := store.ByRepoIDs(ctx, []int64{456}) + require.Nil(t, err) + require.Equal(t, int64(456), sps[0].RepositoryID) + + repo := &database.Repository{ + Path: "foo/bar", + GitPath: "foo/bar2", + Private: true, + RepositoryType: types.SpaceRepo, + } + err = db.Core.NewInsert().Model(repo).Scan(ctx, repo) + require.Nil(t, err) + sp.RepositoryID = repo.ID + err = store.Update(ctx, *sp) + require.Nil(t, err) + + sps, total, err := store.ByUsername(ctx, "foo", 10, 1, false) + require.Nil(t, err) + require.Equal(t, 1, total) + require.Equal(t, len(sps), 1) + + sps, total, err = store.ByUsername(ctx, "foo", 10, 1, true) + require.Nil(t, err) + require.Equal(t, 0, total) + require.Equal(t, len(sps), 0) + + sps, total, err = store.ByOrgPath(ctx, "foo", 10, 1, false) + require.Nil(t, err) + require.Equal(t, 1, total) + require.Equal(t, len(sps), 1) + + sps, total, err = store.ByOrgPath(ctx, "foo", 10, 1, true) + require.Nil(t, err) + require.Equal(t, 0, total) + require.Equal(t, len(sps), 0) + + sp, err = store.FindByPath(ctx, "foo", "bar") + require.Nil(t, err) + require.Equal(t, repo.ID, sp.RepositoryID) + + err = store.Delete(ctx, *sp) + require.Nil(t, err) + _, err = store.FindByPath(ctx, "foo", "bar") + require.NotNil(t, err) +} + +func TestSpaceStore_ListByPath(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewSpaceStoreWithDB(db) + + dt := &database.Tag{} + err := db.Core.NewInsert().Model(&database.Tag{ + Name: "tag1", + Category: "evaluation", + }).Scan(ctx, dt) + require.Nil(t, err) + tag1pk := dt.ID + + err = db.Core.NewInsert().Model(&database.Tag{ + Name: "tag2", + Category: "foo", + }).Scan(ctx, dt) + require.Nil(t, err) + tag2pk := dt.ID + + dr := &database.Repository{} + err = db.Core.NewInsert().Model(&database.Repository{ + Name: "repo", + Path: "foo/bar", + GitPath: "a", + }).Scan(ctx, dr) + require.Nil(t, err) + repopk := dr.ID + + for _, tpk := range []int64{tag1pk, tag2pk} { + _, err = db.Core.NewInsert().Model(&database.RepositoryTag{ + RepositoryID: repopk, + TagID: tpk, + }).Exec(ctx) + require.Nil(t, err) + } + + _, err = store.Create(ctx, database.Space{ + RepositoryID: repopk, + }) + require.Nil(t, err) + + dr2 := &database.Repository{} + err = db.Core.NewInsert().Model(&database.Repository{ + Name: "repo2", + Path: "bar/foo", + GitPath: "b", + }).Scan(ctx, dr2) + require.Nil(t, err) + _, err = store.Create(ctx, database.Space{ + RepositoryID: dr2.ID, + }) + require.Nil(t, err) + + dr3 := &database.Repository{} + err = db.Core.NewInsert().Model(&database.Repository{ + Name: "repo3", + Path: "foo/bar", + GitPath: "c", + RepositoryType: types.ModelRepo, + }).Scan(ctx, dr3) + require.Nil(t, err) + _, err = store.Create(ctx, database.Space{ + RepositoryID: dr3.ID, + }) + require.Nil(t, err) + + sps, err := store.ListByPath(ctx, []string{"bar/foo", "foo/bar"}) + require.Nil(t, err) + require.Equal(t, 3, len(sps)) + + tags := []string{} + for _, t := range sps[1].Repository.Tags { + tags = append(tags, t.Name) + } + require.Equal(t, []string{}, tags) + + names := []string{} + for _, sp := range sps { + names = append(names, sp.Repository.Name) + } + require.Equal(t, []string{"repo2", "repo", "repo3"}, names) + +} + +func TestSpaceStore_UserLikes(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewSpaceStoreWithDB(db) + + repos := []*database.Repository{ + {Name: "repo1", Path: "p1", GitPath: "p1"}, + {Name: "repo2", Path: "p2", GitPath: "p2"}, + {Name: "repo3", Path: "p3", GitPath: "p3"}, + } + + for _, repo := range repos { + err := db.Core.NewInsert().Model(repo).Scan(ctx, repo) + require.Nil(t, err) + _, err = store.Create(ctx, database.Space{ + RepositoryID: repo.ID, + }) + require.Nil(t, err) + + } + + _, err := db.Core.NewInsert().Model(&database.UserLike{ + UserID: 123, + RepoID: repos[0].ID, + }).Exec(ctx) + require.Nil(t, err) + _, err = db.Core.NewInsert().Model(&database.UserLike{ + UserID: 123, + RepoID: repos[2].ID, + }).Exec(ctx) + require.Nil(t, err) + + sps, total, err := store.ByUserLikes(ctx, 123, 10, 1) + require.Nil(t, err) + require.Equal(t, 2, total) + + names := []string{} + for _, sp := range sps { + names = append(names, sp.Repository.Name) + } + require.Equal(t, []string{"repo1", "repo3"}, names) + +} diff --git a/builder/store/database/ssh_key.go b/builder/store/database/ssh_key.go index dd49d114..6bd7f269 100644 --- a/builder/store/database/ssh_key.go +++ b/builder/store/database/ssh_key.go @@ -26,6 +26,12 @@ func NewSSHKeyStore() SSHKeyStore { } } +func NewSSHKeyStoreWithDB(db *DB) SSHKeyStore { + return &sSHKeyStoreImpl{ + db: db, + } +} + type SSHKey struct { ID int64 `bun:",pk,autoincrement" json:"id"` GitID int64 `bun:",notnull" json:"git_id"` diff --git a/builder/store/database/ssh_key_test.go b/builder/store/database/ssh_key_test.go new file mode 100644 index 00000000..ba5ecaf8 --- /dev/null +++ b/builder/store/database/ssh_key_test.go @@ -0,0 +1,65 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestSSHKeyStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewSSHKeyStoreWithDB(db) + user := &database.User{ + Username: "user", + } + err := db.Core.NewInsert().Model(user).Scan(ctx, user) + require.Nil(t, err) + _, err = store.Create(ctx, &database.SSHKey{ + GitID: 123, + FingerprintSHA256: "foo", + UserID: user.ID, + Name: "key", + Content: "content", + }) + require.Nil(t, err) + + sh := &database.SSHKey{} + err = db.Core.NewSelect().Model(sh).Where("git_id=?", 123).Scan(ctx) + require.Nil(t, err) + + sh, err = store.FindByID(ctx, sh.ID) + require.Nil(t, err) + require.Equal(t, int64(123), sh.GitID) + + sh, err = store.FindByFingerpringSHA256(ctx, "foo") + require.Nil(t, err) + require.Equal(t, int64(123), sh.GitID) + + exist, err := store.IsExist(ctx, "user", "key") + require.Nil(t, err) + require.True(t, exist) + + shv, err := store.FindByUsernameAndName(ctx, "user", "key") + require.Nil(t, err) + require.Equal(t, int64(123), shv.GitID) + + sh, err = store.FindByKeyContent(ctx, "content") + require.Nil(t, err) + require.Equal(t, int64(123), sh.GitID) + + sh, err = store.FindByNameAndUserID(ctx, "key", user.ID) + require.Nil(t, err) + require.Equal(t, int64(123), sh.GitID) + + err = store.Delete(ctx, 123) + require.Nil(t, err) + _, err = store.FindByID(ctx, sh.ID) + require.NotNil(t, err) + +} diff --git a/builder/store/database/sync_client_setting.go b/builder/store/database/sync_client_setting.go index f80c7c0d..8f3f8654 100644 --- a/builder/store/database/sync_client_setting.go +++ b/builder/store/database/sync_client_setting.go @@ -19,6 +19,12 @@ func NewSyncClientSettingStore() SyncClientSettingStore { } } +func NewSyncClientSettingStoreWithDB(db *DB) SyncClientSettingStore { + return &syncClientSettingStoreImpl{ + db: db, + } +} + type SyncClientSetting struct { ID int64 `bun:",pk,autoincrement" json:"id"` Token string `bun:",notnull" json:"token"` diff --git a/builder/store/database/sync_client_setting_test.go b/builder/store/database/sync_client_setting_test.go new file mode 100644 index 00000000..d3124fe8 --- /dev/null +++ b/builder/store/database/sync_client_setting_test.go @@ -0,0 +1,46 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestSyncClientSettingStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewSyncClientSettingStoreWithDB(db) + err := store.DeleteAll(ctx) + require.Nil(t, err) + + _, err = store.Create(ctx, &database.SyncClientSetting{ + Token: "tk", + ConcurrentCount: 5, + }) + require.Nil(t, err) + + sc := &database.SyncClientSetting{} + err = db.Core.NewSelect().Model(sc).Where("token=?", "tk").Scan(ctx, sc) + require.Nil(t, err) + require.Equal(t, 5, sc.ConcurrentCount) + + sc, err = store.First(ctx) + require.Nil(t, err) + require.Equal(t, 5, sc.ConcurrentCount) + + exist, err := store.SyncClientSettingExists(ctx) + require.Nil(t, err) + require.True(t, exist) + + err = store.DeleteAll(ctx) + require.Nil(t, err) + exist, err = store.SyncClientSettingExists(ctx) + require.Nil(t, err) + require.False(t, exist) + +} diff --git a/builder/store/database/sync_version.go b/builder/store/database/sync_version.go index 15762858..f194a061 100644 --- a/builder/store/database/sync_version.go +++ b/builder/store/database/sync_version.go @@ -25,6 +25,12 @@ func NewSyncVersionStore() SyncVersionStore { } } +func NewSyncVersionStoreWithDB(db *DB) SyncVersionStore { + return &syncVersionStoreImpl{ + db: db, + } +} + func (s *syncVersionStoreImpl) Create(ctx context.Context, version *SyncVersion) (err error) { _, err = s.db.Operator.Core.NewInsert().Model(version).Exec(ctx) return diff --git a/builder/store/database/sync_version_test.go b/builder/store/database/sync_version_test.go new file mode 100644 index 00000000..34d0b94e --- /dev/null +++ b/builder/store/database/sync_version_test.go @@ -0,0 +1,49 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/common/types" +) + +func TestSyncVersionStore_CRUD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewSyncVersionStoreWithDB(db) + + err := store.Create(ctx, &database.SyncVersion{ + Version: 1, + SourceID: 123, + RepoPath: "foo", + RepoType: types.ModelRepo, + }) + require.Nil(t, err) + + sv := &database.SyncVersion{} + err = db.Core.NewSelect().Model(sv).Where("version=?", 1).Scan(ctx, sv) + require.Nil(t, err) + require.Equal(t, int64(123), sv.SourceID) + + sv, err = store.FindByPath(ctx, "foo") + require.Nil(t, err) + require.Equal(t, int64(1), sv.Version) + + sv, err = store.FindByRepoTypeAndPath(ctx, "foo", types.ModelRepo) + require.Nil(t, err) + require.Equal(t, int64(1), sv.Version) + + err = store.BatchCreate(ctx, []database.SyncVersion{ + {Version: 2, RepoPath: "bar"}, + }) + require.Nil(t, err) + sv, err = store.FindByPath(ctx, "bar") + require.Nil(t, err) + require.Equal(t, int64(2), sv.Version) + +} diff --git a/builder/store/database/telemetry.go b/builder/store/database/telemetry.go index b56b42a7..29393f21 100644 --- a/builder/store/database/telemetry.go +++ b/builder/store/database/telemetry.go @@ -41,6 +41,12 @@ func NewTelemetryStore() TelemetryStore { } } +func NewTelemetryStoreWithDB(db *DB) TelemetryStore { + return &telemetryStoreImpl{ + db: db, + } +} + func (s *telemetryStoreImpl) Save(ctx context.Context, telemetry *Telemetry) error { return assertAffectedOneRow(s.db.Core.NewInsert().Model(telemetry).Exec(ctx)) } diff --git a/builder/store/database/telemetry_test.go b/builder/store/database/telemetry_test.go new file mode 100644 index 00000000..55db2c48 --- /dev/null +++ b/builder/store/database/telemetry_test.go @@ -0,0 +1,23 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestTelemetryStore_Save(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewTelemetryStoreWithDB(db) + err := store.Save(ctx, &database.Telemetry{ + UUID: "foo", + }) + require.Nil(t, err) + +} diff --git a/component/callback/git_callback.go b/component/callback/git_callback.go index 7b51eba4..4c41798f 100644 --- a/component/callback/git_callback.go +++ b/component/callback/git_callback.go @@ -32,7 +32,7 @@ type GitCallbackComponent struct { rs database.RepoStore rrs database.RepoRelationsStore mirrorStore database.MirrorStore - rrf *database.RepositoriesRuntimeFrameworkStore + rrf database.RepositoriesRuntimeFrameworkStore rac component.RuntimeArchitectureComponent ras database.RuntimeArchitecturesStore rfs database.RuntimeFrameworksStore diff --git a/component/repo.go b/component/repo.go index d1378f42..84ec0202 100644 --- a/component/repo.go +++ b/component/repo.go @@ -65,7 +65,7 @@ type repoComponentImpl struct { mirrorSource database.MirrorSourceStore tokenStore database.AccessTokenStore rtfm database.RuntimeFrameworksStore - rrtfms *database.RepositoriesRuntimeFrameworkStore + rrtfms database.RepositoriesRuntimeFrameworkStore syncVersion database.SyncVersionStore syncClientSetting database.SyncClientSettingStore file database.FileStore