Skip to content

Commit

Permalink
Merge branch 'feature/component_tests' into 'main'
Browse files Browse the repository at this point in the history
Add missing component tests

See merge request product/starhub/starhub-server!719
  • Loading branch information
Da.Lei authored and Yiling-J committed Dec 16, 2024
1 parent 377dcaf commit 0b7ce18
Show file tree
Hide file tree
Showing 21 changed files with 1,388 additions and 149 deletions.
18 changes: 18 additions & 0 deletions common/tests/stores.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ type MockStores struct {
MultiSync database.MultiSyncStore
File database.FileStore
SSH database.SSHKeyStore
Telemetry database.TelemetryStore
RepoFile database.RepoFileStore
Event database.EventStore
}

func NewMockStores(t interface {
Expand Down Expand Up @@ -88,6 +91,9 @@ func NewMockStores(t interface {
MultiSync: mockdb.NewMockMultiSyncStore(t),
File: mockdb.NewMockFileStore(t),
SSH: mockdb.NewMockSSHKeyStore(t),
Telemetry: mockdb.NewMockTelemetryStore(t),
RepoFile: mockdb.NewMockRepoFileStore(t),
Event: mockdb.NewMockEventStore(t),
}
}

Expand Down Expand Up @@ -238,3 +244,15 @@ func (s *MockStores) FileMock() *mockdb.MockFileStore {
func (s *MockStores) SSHMock() *mockdb.MockSSHKeyStore {
return s.SSH.(*mockdb.MockSSHKeyStore)
}

func (s *MockStores) TelemetryMock() *mockdb.MockTelemetryStore {
return s.Telemetry.(*mockdb.MockTelemetryStore)
}

func (s *MockStores) RepoFileMock() *mockdb.MockRepoFileStore {
return s.RepoFile.(*mockdb.MockRepoFileStore)
}

func (s *MockStores) EventMock() *mockdb.MockEventStore {
return s.Event.(*mockdb.MockEventStore)
}
42 changes: 42 additions & 0 deletions component/cluster_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package component

import (
"context"
"testing"

"github.com/stretchr/testify/require"
"opencsg.com/csghub-server/common/types"
)

func TestClusterComponent_Index(t *testing.T) {
ctx := context.TODO()
cc := initializeTestClusterComponent(ctx, t)

cc.mocks.deployer.EXPECT().ListCluster(ctx).Return(nil, nil)

data, err := cc.Index(ctx)
require.Nil(t, err)
require.Equal(t, []types.ClusterRes([]types.ClusterRes(nil)), data)
}

func TestClusterComponent_GetClusterById(t *testing.T) {
ctx := context.TODO()
cc := initializeTestClusterComponent(ctx, t)

cc.mocks.deployer.EXPECT().GetClusterById(ctx, "c1").Return(nil, nil)

data, err := cc.GetClusterById(ctx, "c1")
require.Nil(t, err)
require.Equal(t, (*types.ClusterRes)(nil), data)
}

func TestClusterComponent_Update(t *testing.T) {
ctx := context.TODO()
cc := initializeTestClusterComponent(ctx, t)

cc.mocks.deployer.EXPECT().UpdateCluster(ctx, types.ClusterRequest{}).Return(nil, nil)

data, err := cc.Update(ctx, types.ClusterRequest{})
require.Nil(t, err)
require.Equal(t, (*types.UpdateClusterResponse)(nil), data)
}
26 changes: 13 additions & 13 deletions component/evaluation.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ import (
)

type evaluationComponentImpl struct {
deployer deploy.Deployer
userStore database.UserStore
modelStore database.ModelStore
datasetStore database.DatasetStore
mirrorStore database.MirrorStore
spaceResourceStore database.SpaceResourceStore
tokenStore database.AccessTokenStore
rtfm database.RuntimeFrameworksStore
config *config.Config
ac AccountingComponent
deployer deploy.Deployer
userStore database.UserStore
modelStore database.ModelStore
datasetStore database.DatasetStore
mirrorStore database.MirrorStore
spaceResourceStore database.SpaceResourceStore
tokenStore database.AccessTokenStore
runtimeFrameworkStore database.RuntimeFrameworksStore
config *config.Config
accountingComponent AccountingComponent
}

type EvaluationComponent interface {
Expand All @@ -43,13 +43,13 @@ func NewEvaluationComponent(config *config.Config) (EvaluationComponent, error)
c.datasetStore = database.NewDatasetStore()
c.mirrorStore = database.NewMirrorStore()
c.tokenStore = database.NewAccessTokenStore()
c.rtfm = database.NewRuntimeFrameworksStore()
c.runtimeFrameworkStore = database.NewRuntimeFrameworksStore()
c.config = config
ac, err := NewAccountingComponent(config)
if err != nil {
return nil, fmt.Errorf("failed to create accounting component, %w", err)
}
c.ac = ac
c.accountingComponent = ac
return c, nil
}

Expand Down Expand Up @@ -97,7 +97,7 @@ func (c *evaluationComponentImpl) CreateEvaluation(ctx context.Context, req type
hardware.Cpu.Num = "8"
hardware.Memory = "32Gi"
}
frame, err := c.rtfm.FindEnabledByID(ctx, req.RuntimeFrameworkId)
frame, err := c.runtimeFrameworkStore.FindEnabledByID(ctx, req.RuntimeFrameworkId)
if err != nil {
return nil, fmt.Errorf("cannot find available runtime framework, %w", err)
}
Expand Down
84 changes: 28 additions & 56 deletions component/evaluation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,10 @@ import (
"testing"

"github.com/stretchr/testify/require"
mock_deploy "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy"
mock_component "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component"
"opencsg.com/csghub-server/builder/deploy"
"opencsg.com/csghub-server/builder/store/database"
"opencsg.com/csghub-server/common/config"
"opencsg.com/csghub-server/common/tests"
"opencsg.com/csghub-server/common/types"
)

func NewTestEvaluationComponent(deployer deploy.Deployer, stores *tests.MockStores, ac AccountingComponent) EvaluationComponent {
cfg := &config.Config{}
cfg.Argo.QuotaGPUNumber = "1"
return &evaluationComponentImpl{
deployer: deployer,
config: cfg,
userStore: stores.User,
modelStore: stores.Model,
datasetStore: stores.Dataset,
mirrorStore: stores.Mirror,
spaceResourceStore: stores.SpaceResource,
tokenStore: stores.AccessToken,
rtfm: stores.RuntimeFramework,
ac: ac,
}
}

func TestEvaluationComponent_CreateEvaluation(t *testing.T) {
req := types.EvaluationReq{
TaskName: "test",
Expand Down Expand Up @@ -66,30 +44,28 @@ func TestEvaluationComponent_CreateEvaluation(t *testing.T) {
Token: "foo",
}
t.Run("create evaluation without resource id", func(t *testing.T) {
deployerMock := &mock_deploy.MockDeployer{}
stores := tests.NewMockStores(t)
ac := &mock_component.MockAccountingComponent{}
c := NewTestEvaluationComponent(deployerMock, stores, ac)
stores.UserMock().EXPECT().FindByUsername(ctx, req.Username).Return(database.User{
c := initializeTestEvaluationComponent(ctx, t)
c.config.Argo.QuotaGPUNumber = "1"
c.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, req.Username).Return(database.User{
RoleMask: "admin",
Username: req.Username,
UUID: req.Username,
ID: 1,
}, nil).Once()
stores.ModelMock().EXPECT().FindByPath(ctx, "opencsg", "wukong").Return(
c.mocks.stores.ModelMock().EXPECT().FindByPath(ctx, "opencsg", "wukong").Return(
&database.Model{
ID: 1,
}, nil,
).Maybe()
stores.MirrorMock().EXPECT().FindByRepoPath(ctx, types.DatasetRepo, "opencsg", "hellaswag").Return(&database.Mirror{
c.mocks.stores.MirrorMock().EXPECT().FindByRepoPath(ctx, types.DatasetRepo, "opencsg", "hellaswag").Return(&database.Mirror{
SourceRepoPath: "Rowan/hellaswag",
}, nil)
stores.AccessTokenMock().EXPECT().FindByUID(ctx, int64(1)).Return(&database.AccessToken{Token: "foo"}, nil)
stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(1)).Return(&database.RuntimeFramework{
c.mocks.stores.AccessTokenMock().EXPECT().FindByUID(ctx, int64(1)).Return(&database.AccessToken{Token: "foo"}, nil)
c.mocks.stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(1)).Return(&database.RuntimeFramework{
ID: 1,
FrameImage: "lm-evaluation-harness:0.4.6",
}, nil)
deployerMock.EXPECT().SubmitEvaluation(ctx, req2).Return(&types.ArgoWorkFlowRes{
c.mocks.deployer.EXPECT().SubmitEvaluation(ctx, req2).Return(&types.ArgoWorkFlowRes{
ID: 1,
TaskName: "test",
}, nil)
Expand All @@ -101,36 +77,36 @@ func TestEvaluationComponent_CreateEvaluation(t *testing.T) {
t.Run("create evaluation with resource id", func(t *testing.T) {
req.ResourceId = 1
req2.ResourceId = 1
deployerMock := &mock_deploy.MockDeployer{}
stores := tests.NewMockStores(t)
ac := &mock_component.MockAccountingComponent{}
c := NewTestEvaluationComponent(deployerMock, stores, ac)
stores.UserMock().EXPECT().FindByUsername(ctx, req.Username).Return(database.User{
c := initializeTestEvaluationComponent(ctx, t)
c.config.Argo.QuotaGPUNumber = "1"
c.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, req.Username).Return(database.User{
RoleMask: "admin",
Username: req.Username,
UUID: req.Username,
ID: 1,
}, nil).Once()
stores.ModelMock().EXPECT().FindByPath(ctx, "opencsg", "wukong").Return(
c.mocks.stores.ModelMock().EXPECT().FindByPath(ctx, "opencsg", "wukong").Return(
&database.Model{
ID: 1,
}, nil,
).Maybe()
stores.MirrorMock().EXPECT().FindByRepoPath(ctx, types.DatasetRepo, "opencsg", "hellaswag").Return(&database.Mirror{
c.mocks.stores.MirrorMock().EXPECT().FindByRepoPath(ctx, types.DatasetRepo, "opencsg", "hellaswag").Return(&database.Mirror{
SourceRepoPath: "Rowan/hellaswag",
}, nil)
stores.AccessTokenMock().EXPECT().FindByUID(ctx, int64(1)).Return(&database.AccessToken{Token: "foo"}, nil)
stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(1)).Return(&database.RuntimeFramework{
c.mocks.stores.AccessTokenMock().EXPECT().FindByUID(ctx, int64(1)).Return(&database.AccessToken{Token: "foo"}, nil)
c.mocks.stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(1)).Return(&database.RuntimeFramework{
ID: 1,
FrameImage: "lm-evaluation-harness:0.4.6",
}, nil)

resource, err := json.Marshal(req2.Hardware)
require.Nil(t, err)
stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return(&database.SpaceResource{
c.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return(&database.SpaceResource{
ID: 1,
Resources: string(resource),
}, nil)
deployerMock.EXPECT().SubmitEvaluation(ctx, req2).Return(&types.ArgoWorkFlowRes{
c.mocks.deployer.EXPECT().SubmitEvaluation(ctx, req2).Return(&types.ArgoWorkFlowRes{

ID: 1,
TaskName: "test",
}, nil)
Expand All @@ -142,15 +118,13 @@ func TestEvaluationComponent_CreateEvaluation(t *testing.T) {
}

func TestEvaluationComponent_GetEvaluation(t *testing.T) {
deployerMock := &mock_deploy.MockDeployer{}
stores := tests.NewMockStores(t)
ac := &mock_component.MockAccountingComponent{}
c := NewTestEvaluationComponent(deployerMock, stores, ac)
ctx := context.TODO()
c := initializeTestEvaluationComponent(ctx, t)
c.config.Argo.QuotaGPUNumber = "1"
req := types.EvaluationGetReq{
Username: "test",
}
ctx := context.TODO()
deployerMock.EXPECT().GetEvaluation(ctx, req).Return(&types.ArgoWorkFlowRes{
c.mocks.deployer.EXPECT().GetEvaluation(ctx, req).Return(&types.ArgoWorkFlowRes{
ID: 1,
RepoIds: []string{"Rowan/hellaswag"},
Datasets: []string{"Rowan/hellaswag"},
Expand All @@ -161,7 +135,7 @@ func TestEvaluationComponent_GetEvaluation(t *testing.T) {
TaskType: "evaluation",
Status: "Succeed",
}, nil)
stores.DatasetMock().EXPECT().ListByPath(ctx, []string{"Rowan/hellaswag"}).Return([]database.Dataset{
c.mocks.stores.DatasetMock().EXPECT().ListByPath(ctx, []string{"Rowan/hellaswag"}).Return([]database.Dataset{
{
Repository: &database.Repository{
Path: "Rowan/hellaswag",
Expand All @@ -184,15 +158,13 @@ func TestEvaluationComponent_GetEvaluation(t *testing.T) {
}

func TestEvaluationComponent_DeleteEvaluation(t *testing.T) {
deployerMock := &mock_deploy.MockDeployer{}
stores := tests.NewMockStores(t)
ac := &mock_component.MockAccountingComponent{}
c := NewTestEvaluationComponent(deployerMock, stores, ac)
ctx := context.TODO()
c := initializeTestEvaluationComponent(ctx, t)
c.config.Argo.QuotaGPUNumber = "1"
req := types.EvaluationDelReq{
Username: "test",
}
ctx := context.TODO()
deployerMock.EXPECT().DeleteEvaluation(ctx, req).Return(nil)
c.mocks.deployer.EXPECT().DeleteEvaluation(ctx, req).Return(nil)
err := c.DeleteEvaluation(ctx, req)
require.Nil(t, err)
}
6 changes: 3 additions & 3 deletions component/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

type eventComponentImpl struct {
es database.EventStore
eventStore database.EventStore
}

// NewEventComponent creates a new EventComponent
Expand All @@ -19,7 +19,7 @@ type EventComponent interface {

func NewEventComponent() EventComponent {
return &eventComponentImpl{
es: database.NewEventStore(),
eventStore: database.NewEventStore(),
}
}

Expand All @@ -34,5 +34,5 @@ func (ec *eventComponentImpl) NewEvents(ctx context.Context, events []types.Even
})
}

return ec.es.BatchSave(ctx, dbevents)
return ec.eventStore.BatchSave(ctx, dbevents)
}
22 changes: 22 additions & 0 deletions component/event_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package component

import (
"context"
"testing"

"github.com/stretchr/testify/require"
"opencsg.com/csghub-server/builder/store/database"
"opencsg.com/csghub-server/common/types"
)

func TestEventComponent_NewEvent(t *testing.T) {
ctx := context.TODO()
ec := initializeTestEventComponent(ctx, t)

ec.mocks.stores.EventMock().EXPECT().BatchSave(ctx, []database.Event{
{EventID: "e1"},
}).Return(nil)

err := ec.NewEvents(ctx, []types.Event{{ID: "e1"}})
require.Nil(t, err)
}
Loading

0 comments on commit 0b7ce18

Please sign in to comment.