diff --git a/.mockery.yaml b/.mockery.yaml index ec11083a..630d57a4 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -16,6 +16,7 @@ packages: AccountingComponent: SpaceComponent: RuntimeArchitectureComponent: + SensitiveComponent: opencsg.com/csghub-server/user/component: config: interfaces: @@ -89,4 +90,11 @@ packages: config: interfaces: AccountingClient: - + opencsg.com/csghub-server/builder/parquet: + config: + interfaces: + Reader: + opencsg.com/csghub-server/builder/multisync: + config: + interfaces: + Client: diff --git a/Makefile b/Makefile index f46e7380..21c7b05a 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ lint: golangci-lint run cover: - go test -coverprofile=cover.out -coverpkg=./... ./... + go test -coverprofile=cover.out ./... go tool cover -html=cover.out -o cover.html open cover.html diff --git a/_mocks/opencsg.com/csghub-server/builder/multisync/mock_Client.go b/_mocks/opencsg.com/csghub-server/builder/multisync/mock_Client.go new file mode 100644 index 00000000..22ca9225 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/builder/multisync/mock_Client.go @@ -0,0 +1,329 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package multisync + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + types "opencsg.com/csghub-server/common/types" +) + +// MockClient is an autogenerated mock type for the Client type +type MockClient struct { + mock.Mock +} + +type MockClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockClient) EXPECT() *MockClient_Expecter { + return &MockClient_Expecter{mock: &_m.Mock} +} + +// DatasetInfo provides a mock function with given fields: ctx, v +func (_m *MockClient) DatasetInfo(ctx context.Context, v types.SyncVersion) (*types.Dataset, error) { + ret := _m.Called(ctx, v) + + if len(ret) == 0 { + panic("no return value specified for DatasetInfo") + } + + var r0 *types.Dataset + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) (*types.Dataset, error)); ok { + return rf(ctx, v) + } + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) *types.Dataset); ok { + r0 = rf(ctx, v) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Dataset) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.SyncVersion) error); ok { + r1 = rf(ctx, v) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_DatasetInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DatasetInfo' +type MockClient_DatasetInfo_Call struct { + *mock.Call +} + +// DatasetInfo is a helper method to define mock.On call +// - ctx context.Context +// - v types.SyncVersion +func (_e *MockClient_Expecter) DatasetInfo(ctx interface{}, v interface{}) *MockClient_DatasetInfo_Call { + return &MockClient_DatasetInfo_Call{Call: _e.mock.On("DatasetInfo", ctx, v)} +} + +func (_c *MockClient_DatasetInfo_Call) Run(run func(ctx context.Context, v types.SyncVersion)) *MockClient_DatasetInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.SyncVersion)) + }) + return _c +} + +func (_c *MockClient_DatasetInfo_Call) Return(_a0 *types.Dataset, _a1 error) *MockClient_DatasetInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_DatasetInfo_Call) RunAndReturn(run func(context.Context, types.SyncVersion) (*types.Dataset, error)) *MockClient_DatasetInfo_Call { + _c.Call.Return(run) + return _c +} + +// FileList provides a mock function with given fields: ctx, v +func (_m *MockClient) FileList(ctx context.Context, v types.SyncVersion) ([]types.File, error) { + ret := _m.Called(ctx, v) + + if len(ret) == 0 { + panic("no return value specified for FileList") + } + + var r0 []types.File + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) ([]types.File, error)); ok { + return rf(ctx, v) + } + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) []types.File); ok { + r0 = rf(ctx, v) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.File) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.SyncVersion) error); ok { + r1 = rf(ctx, v) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_FileList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FileList' +type MockClient_FileList_Call struct { + *mock.Call +} + +// FileList is a helper method to define mock.On call +// - ctx context.Context +// - v types.SyncVersion +func (_e *MockClient_Expecter) FileList(ctx interface{}, v interface{}) *MockClient_FileList_Call { + return &MockClient_FileList_Call{Call: _e.mock.On("FileList", ctx, v)} +} + +func (_c *MockClient_FileList_Call) Run(run func(ctx context.Context, v types.SyncVersion)) *MockClient_FileList_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.SyncVersion)) + }) + return _c +} + +func (_c *MockClient_FileList_Call) Return(_a0 []types.File, _a1 error) *MockClient_FileList_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_FileList_Call) RunAndReturn(run func(context.Context, types.SyncVersion) ([]types.File, error)) *MockClient_FileList_Call { + _c.Call.Return(run) + return _c +} + +// Latest provides a mock function with given fields: ctx, currentVersion +func (_m *MockClient) Latest(ctx context.Context, currentVersion int64) (types.SyncVersionResponse, error) { + ret := _m.Called(ctx, currentVersion) + + if len(ret) == 0 { + panic("no return value specified for Latest") + } + + var r0 types.SyncVersionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64) (types.SyncVersionResponse, error)); ok { + return rf(ctx, currentVersion) + } + if rf, ok := ret.Get(0).(func(context.Context, int64) types.SyncVersionResponse); ok { + r0 = rf(ctx, currentVersion) + } else { + r0 = ret.Get(0).(types.SyncVersionResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, currentVersion) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_Latest_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Latest' +type MockClient_Latest_Call struct { + *mock.Call +} + +// Latest is a helper method to define mock.On call +// - ctx context.Context +// - currentVersion int64 +func (_e *MockClient_Expecter) Latest(ctx interface{}, currentVersion interface{}) *MockClient_Latest_Call { + return &MockClient_Latest_Call{Call: _e.mock.On("Latest", ctx, currentVersion)} +} + +func (_c *MockClient_Latest_Call) Run(run func(ctx context.Context, currentVersion int64)) *MockClient_Latest_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockClient_Latest_Call) Return(_a0 types.SyncVersionResponse, _a1 error) *MockClient_Latest_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_Latest_Call) RunAndReturn(run func(context.Context, int64) (types.SyncVersionResponse, error)) *MockClient_Latest_Call { + _c.Call.Return(run) + return _c +} + +// ModelInfo provides a mock function with given fields: ctx, v +func (_m *MockClient) ModelInfo(ctx context.Context, v types.SyncVersion) (*types.Model, error) { + ret := _m.Called(ctx, v) + + if len(ret) == 0 { + panic("no return value specified for ModelInfo") + } + + var r0 *types.Model + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) (*types.Model, error)); ok { + return rf(ctx, v) + } + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) *types.Model); ok { + r0 = rf(ctx, v) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Model) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.SyncVersion) error); ok { + r1 = rf(ctx, v) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ModelInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ModelInfo' +type MockClient_ModelInfo_Call struct { + *mock.Call +} + +// ModelInfo is a helper method to define mock.On call +// - ctx context.Context +// - v types.SyncVersion +func (_e *MockClient_Expecter) ModelInfo(ctx interface{}, v interface{}) *MockClient_ModelInfo_Call { + return &MockClient_ModelInfo_Call{Call: _e.mock.On("ModelInfo", ctx, v)} +} + +func (_c *MockClient_ModelInfo_Call) Run(run func(ctx context.Context, v types.SyncVersion)) *MockClient_ModelInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.SyncVersion)) + }) + return _c +} + +func (_c *MockClient_ModelInfo_Call) Return(_a0 *types.Model, _a1 error) *MockClient_ModelInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ModelInfo_Call) RunAndReturn(run func(context.Context, types.SyncVersion) (*types.Model, error)) *MockClient_ModelInfo_Call { + _c.Call.Return(run) + return _c +} + +// ReadMeData provides a mock function with given fields: ctx, v +func (_m *MockClient) ReadMeData(ctx context.Context, v types.SyncVersion) (string, error) { + ret := _m.Called(ctx, v) + + if len(ret) == 0 { + panic("no return value specified for ReadMeData") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) (string, error)); ok { + return rf(ctx, v) + } + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) string); ok { + r0 = rf(ctx, v) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(context.Context, types.SyncVersion) error); ok { + r1 = rf(ctx, v) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ReadMeData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadMeData' +type MockClient_ReadMeData_Call struct { + *mock.Call +} + +// ReadMeData is a helper method to define mock.On call +// - ctx context.Context +// - v types.SyncVersion +func (_e *MockClient_Expecter) ReadMeData(ctx interface{}, v interface{}) *MockClient_ReadMeData_Call { + return &MockClient_ReadMeData_Call{Call: _e.mock.On("ReadMeData", ctx, v)} +} + +func (_c *MockClient_ReadMeData_Call) Run(run func(ctx context.Context, v types.SyncVersion)) *MockClient_ReadMeData_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.SyncVersion)) + }) + return _c +} + +func (_c *MockClient_ReadMeData_Call) Return(_a0 string, _a1 error) *MockClient_ReadMeData_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ReadMeData_Call) RunAndReturn(run func(context.Context, types.SyncVersion) (string, error)) *MockClient_ReadMeData_Call { + _c.Call.Return(run) + return _c +} + +// NewMockClient creates a new instance of MockClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockClient { + mock := &MockClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/builder/parquet/mock_Reader.go b/_mocks/opencsg.com/csghub-server/builder/parquet/mock_Reader.go new file mode 100644 index 00000000..6e31702e --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/builder/parquet/mock_Reader.go @@ -0,0 +1,156 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package parquet + +import mock "github.com/stretchr/testify/mock" + +// MockReader is an autogenerated mock type for the Reader type +type MockReader struct { + mock.Mock +} + +type MockReader_Expecter struct { + mock *mock.Mock +} + +func (_m *MockReader) EXPECT() *MockReader_Expecter { + return &MockReader_Expecter{mock: &_m.Mock} +} + +// RowCount provides a mock function with given fields: objName +func (_m *MockReader) RowCount(objName string) (int, error) { + ret := _m.Called(objName) + + if len(ret) == 0 { + panic("no return value specified for RowCount") + } + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func(string) (int, error)); ok { + return rf(objName) + } + if rf, ok := ret.Get(0).(func(string) int); ok { + r0 = rf(objName) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(objName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockReader_RowCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RowCount' +type MockReader_RowCount_Call struct { + *mock.Call +} + +// RowCount is a helper method to define mock.On call +// - objName string +func (_e *MockReader_Expecter) RowCount(objName interface{}) *MockReader_RowCount_Call { + return &MockReader_RowCount_Call{Call: _e.mock.On("RowCount", objName)} +} + +func (_c *MockReader_RowCount_Call) Run(run func(objName string)) *MockReader_RowCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockReader_RowCount_Call) Return(count int, err error) *MockReader_RowCount_Call { + _c.Call.Return(count, err) + return _c +} + +func (_c *MockReader_RowCount_Call) RunAndReturn(run func(string) (int, error)) *MockReader_RowCount_Call { + _c.Call.Return(run) + return _c +} + +// TopN provides a mock function with given fields: objName, count +func (_m *MockReader) TopN(objName string, count int) ([]string, [][]interface{}, error) { + ret := _m.Called(objName, count) + + if len(ret) == 0 { + panic("no return value specified for TopN") + } + + var r0 []string + var r1 [][]interface{} + var r2 error + if rf, ok := ret.Get(0).(func(string, int) ([]string, [][]interface{}, error)); ok { + return rf(objName, count) + } + if rf, ok := ret.Get(0).(func(string, int) []string); ok { + r0 = rf(objName, count) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(string, int) [][]interface{}); ok { + r1 = rf(objName, count) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([][]interface{}) + } + } + + if rf, ok := ret.Get(2).(func(string, int) error); ok { + r2 = rf(objName, count) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockReader_TopN_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TopN' +type MockReader_TopN_Call struct { + *mock.Call +} + +// TopN is a helper method to define mock.On call +// - objName string +// - count int +func (_e *MockReader_Expecter) TopN(objName interface{}, count interface{}) *MockReader_TopN_Call { + return &MockReader_TopN_Call{Call: _e.mock.On("TopN", objName, count)} +} + +func (_c *MockReader_TopN_Call) Run(run func(objName string, count int)) *MockReader_TopN_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(int)) + }) + return _c +} + +func (_c *MockReader_TopN_Call) Return(columns []string, rows [][]interface{}, err error) *MockReader_TopN_Call { + _c.Call.Return(columns, rows, err) + return _c +} + +func (_c *MockReader_TopN_Call) RunAndReturn(run func(string, int) ([]string, [][]interface{}, error)) *MockReader_TopN_Call { + _c.Call.Return(run) + return _c +} + +// NewMockReader creates a new instance of MockReader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockReader(t interface { + mock.TestingT + Cleanup(func()) +}) *MockReader { + mock := &MockReader{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_SensitiveComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_SensitiveComponent.go new file mode 100644 index 00000000..e8d6517d --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_SensitiveComponent.go @@ -0,0 +1,211 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + types "opencsg.com/csghub-server/common/types" +) + +// MockSensitiveComponent is an autogenerated mock type for the SensitiveComponent type +type MockSensitiveComponent struct { + mock.Mock +} + +type MockSensitiveComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockSensitiveComponent) EXPECT() *MockSensitiveComponent_Expecter { + return &MockSensitiveComponent_Expecter{mock: &_m.Mock} +} + +// CheckImage provides a mock function with given fields: ctx, scenario, ossBucketName, ossObjectName +func (_m *MockSensitiveComponent) CheckImage(ctx context.Context, scenario string, ossBucketName string, ossObjectName string) (bool, error) { + ret := _m.Called(ctx, scenario, ossBucketName, ossObjectName) + + if len(ret) == 0 { + panic("no return value specified for CheckImage") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (bool, error)); ok { + return rf(ctx, scenario, ossBucketName, ossObjectName) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) bool); ok { + r0 = rf(ctx, scenario, ossBucketName, ossObjectName) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, scenario, ossBucketName, ossObjectName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSensitiveComponent_CheckImage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckImage' +type MockSensitiveComponent_CheckImage_Call struct { + *mock.Call +} + +// CheckImage is a helper method to define mock.On call +// - ctx context.Context +// - scenario string +// - ossBucketName string +// - ossObjectName string +func (_e *MockSensitiveComponent_Expecter) CheckImage(ctx interface{}, scenario interface{}, ossBucketName interface{}, ossObjectName interface{}) *MockSensitiveComponent_CheckImage_Call { + return &MockSensitiveComponent_CheckImage_Call{Call: _e.mock.On("CheckImage", ctx, scenario, ossBucketName, ossObjectName)} +} + +func (_c *MockSensitiveComponent_CheckImage_Call) Run(run func(ctx context.Context, scenario string, ossBucketName string, ossObjectName string)) *MockSensitiveComponent_CheckImage_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockSensitiveComponent_CheckImage_Call) Return(_a0 bool, _a1 error) *MockSensitiveComponent_CheckImage_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSensitiveComponent_CheckImage_Call) RunAndReturn(run func(context.Context, string, string, string) (bool, error)) *MockSensitiveComponent_CheckImage_Call { + _c.Call.Return(run) + return _c +} + +// CheckRequestV2 provides a mock function with given fields: ctx, req +func (_m *MockSensitiveComponent) CheckRequestV2(ctx context.Context, req types.SensitiveRequestV2) (bool, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CheckRequestV2") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.SensitiveRequestV2) (bool, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.SensitiveRequestV2) bool); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, types.SensitiveRequestV2) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSensitiveComponent_CheckRequestV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckRequestV2' +type MockSensitiveComponent_CheckRequestV2_Call struct { + *mock.Call +} + +// CheckRequestV2 is a helper method to define mock.On call +// - ctx context.Context +// - req types.SensitiveRequestV2 +func (_e *MockSensitiveComponent_Expecter) CheckRequestV2(ctx interface{}, req interface{}) *MockSensitiveComponent_CheckRequestV2_Call { + return &MockSensitiveComponent_CheckRequestV2_Call{Call: _e.mock.On("CheckRequestV2", ctx, req)} +} + +func (_c *MockSensitiveComponent_CheckRequestV2_Call) Run(run func(ctx context.Context, req types.SensitiveRequestV2)) *MockSensitiveComponent_CheckRequestV2_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.SensitiveRequestV2)) + }) + return _c +} + +func (_c *MockSensitiveComponent_CheckRequestV2_Call) Return(_a0 bool, _a1 error) *MockSensitiveComponent_CheckRequestV2_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSensitiveComponent_CheckRequestV2_Call) RunAndReturn(run func(context.Context, types.SensitiveRequestV2) (bool, error)) *MockSensitiveComponent_CheckRequestV2_Call { + _c.Call.Return(run) + return _c +} + +// CheckText provides a mock function with given fields: ctx, scenario, text +func (_m *MockSensitiveComponent) CheckText(ctx context.Context, scenario string, text string) (bool, error) { + ret := _m.Called(ctx, scenario, text) + + if len(ret) == 0 { + panic("no return value specified for CheckText") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (bool, error)); ok { + return rf(ctx, scenario, text) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) bool); ok { + r0 = rf(ctx, scenario, text) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, scenario, text) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSensitiveComponent_CheckText_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckText' +type MockSensitiveComponent_CheckText_Call struct { + *mock.Call +} + +// CheckText is a helper method to define mock.On call +// - ctx context.Context +// - scenario string +// - text string +func (_e *MockSensitiveComponent_Expecter) CheckText(ctx interface{}, scenario interface{}, text interface{}) *MockSensitiveComponent_CheckText_Call { + return &MockSensitiveComponent_CheckText_Call{Call: _e.mock.On("CheckText", ctx, scenario, text)} +} + +func (_c *MockSensitiveComponent_CheckText_Call) Run(run func(ctx context.Context, scenario string, text string)) *MockSensitiveComponent_CheckText_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockSensitiveComponent_CheckText_Call) Return(_a0 bool, _a1 error) *MockSensitiveComponent_CheckText_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSensitiveComponent_CheckText_Call) RunAndReturn(run func(context.Context, string, string) (bool, error)) *MockSensitiveComponent_CheckText_Call { + _c.Call.Return(run) + return _c +} + +// NewMockSensitiveComponent creates a new instance of MockSensitiveComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockSensitiveComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSensitiveComponent { + mock := &MockSensitiveComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go index af076cd7..05021f20 100644 --- a/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go +++ b/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go @@ -24,64 +24,6 @@ func (_m *MockTagComponent) EXPECT() *MockTagComponent_Expecter { return &MockTagComponent_Expecter{mock: &_m.Mock} } -// AllTags provides a mock function with given fields: ctx -func (_m *MockTagComponent) AllTags(ctx context.Context) ([]database.Tag, error) { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for AllTags") - } - - var r0 []database.Tag - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) ([]database.Tag, error)); ok { - return rf(ctx) - } - if rf, ok := ret.Get(0).(func(context.Context) []database.Tag); ok { - r0 = rf(ctx) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]database.Tag) - } - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockTagComponent_AllTags_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllTags' -type MockTagComponent_AllTags_Call struct { - *mock.Call -} - -// AllTags is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockTagComponent_Expecter) AllTags(ctx interface{}) *MockTagComponent_AllTags_Call { - return &MockTagComponent_AllTags_Call{Call: _e.mock.On("AllTags", ctx)} -} - -func (_c *MockTagComponent_AllTags_Call) Run(run func(ctx context.Context)) *MockTagComponent_AllTags_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) - }) - return _c -} - -func (_c *MockTagComponent_AllTags_Call) Return(_a0 []database.Tag, _a1 error) *MockTagComponent_AllTags_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockTagComponent_AllTags_Call) RunAndReturn(run func(context.Context) ([]database.Tag, error)) *MockTagComponent_AllTags_Call { - _c.Call.Return(run) - return _c -} - // AllTagsByScopeAndCategory provides a mock function with given fields: ctx, scope, category func (_m *MockTagComponent) AllTagsByScopeAndCategory(ctx context.Context, scope string, category string) ([]*database.Tag, error) { ret := _m.Called(ctx, scope, category) diff --git a/common/tests/stores.go b/common/tests/stores.go index 46380274..2197113f 100644 --- a/common/tests/stores.go +++ b/common/tests/stores.go @@ -21,6 +21,7 @@ type MockStores struct { Prompt database.PromptStore Namespace database.NamespaceStore LfsMetaObject database.LfsMetaObjectStore + LfsLock database.LfsLockStore Mirror database.MirrorStore MirrorSource database.MirrorSourceStore AccessToken database.AccessTokenStore @@ -35,6 +36,14 @@ type MockStores struct { SpaceSdk database.SpaceSdkStore Recom database.RecomStore RepoRuntimeFramework database.RepositoriesRuntimeFrameworkStore + Discussion database.DiscussionStore + RuntimeArch database.RuntimeArchitecturesStore + ResourceModel database.ResourceModelStore + GitServerAccessToken database.GitServerAccessTokenStore + Org database.OrgStore + MultiSync database.MultiSyncStore + File database.FileStore + SSH database.SSHKeyStore } func NewMockStores(t interface { @@ -56,6 +65,7 @@ func NewMockStores(t interface { Prompt: mockdb.NewMockPromptStore(t), Namespace: mockdb.NewMockNamespaceStore(t), LfsMetaObject: mockdb.NewMockLfsMetaObjectStore(t), + LfsLock: mockdb.NewMockLfsLockStore(t), Mirror: mockdb.NewMockMirrorStore(t), MirrorSource: mockdb.NewMockMirrorSourceStore(t), AccessToken: mockdb.NewMockAccessTokenStore(t), @@ -70,6 +80,14 @@ func NewMockStores(t interface { SpaceSdk: mockdb.NewMockSpaceSdkStore(t), Recom: mockdb.NewMockRecomStore(t), RepoRuntimeFramework: mockdb.NewMockRepositoriesRuntimeFrameworkStore(t), + Discussion: mockdb.NewMockDiscussionStore(t), + RuntimeArch: mockdb.NewMockRuntimeArchitecturesStore(t), + ResourceModel: mockdb.NewMockResourceModelStore(t), + GitServerAccessToken: mockdb.NewMockGitServerAccessTokenStore(t), + Org: mockdb.NewMockOrgStore(t), + MultiSync: mockdb.NewMockMultiSyncStore(t), + File: mockdb.NewMockFileStore(t), + SSH: mockdb.NewMockSSHKeyStore(t), } } @@ -129,6 +147,10 @@ func (s *MockStores) LfsMetaObjectMock() *mockdb.MockLfsMetaObjectStore { return s.LfsMetaObject.(*mockdb.MockLfsMetaObjectStore) } +func (s *MockStores) LfsLockMock() *mockdb.MockLfsLockStore { + return s.LfsLock.(*mockdb.MockLfsLockStore) +} + func (s *MockStores) MirrorMock() *mockdb.MockMirrorStore { return s.Mirror.(*mockdb.MockMirrorStore) } @@ -184,3 +206,35 @@ func (s *MockStores) RecomMock() *mockdb.MockRecomStore { func (s *MockStores) RepoRuntimeFrameworkMock() *mockdb.MockRepositoriesRuntimeFrameworkStore { return s.RepoRuntimeFramework.(*mockdb.MockRepositoriesRuntimeFrameworkStore) } + +func (s *MockStores) DiscussionMock() *mockdb.MockDiscussionStore { + return s.Discussion.(*mockdb.MockDiscussionStore) +} + +func (s *MockStores) RuntimeArchMock() *mockdb.MockRuntimeArchitecturesStore { + return s.RuntimeArch.(*mockdb.MockRuntimeArchitecturesStore) +} + +func (s *MockStores) ResourceModelMock() *mockdb.MockResourceModelStore { + return s.ResourceModel.(*mockdb.MockResourceModelStore) +} + +func (s *MockStores) GitServerAccessTokenMock() *mockdb.MockGitServerAccessTokenStore { + return s.GitServerAccessToken.(*mockdb.MockGitServerAccessTokenStore) +} + +func (s *MockStores) OrgMock() *mockdb.MockOrgStore { + return s.Org.(*mockdb.MockOrgStore) +} + +func (s *MockStores) MultiSyncMock() *mockdb.MockMultiSyncStore { + return s.MultiSync.(*mockdb.MockMultiSyncStore) +} + +func (s *MockStores) FileMock() *mockdb.MockFileStore { + return s.File.(*mockdb.MockFileStore) +} + +func (s *MockStores) SSHMock() *mockdb.MockSSHKeyStore { + return s.SSH.(*mockdb.MockSSHKeyStore) +} diff --git a/component/code.go b/component/code.go index d2cf14af..07f7a8b1 100644 --- a/component/code.go +++ b/component/code.go @@ -5,7 +5,10 @@ import ( "fmt" "log/slog" + "opencsg.com/csghub-server/builder/git" + "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/membership" + "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" @@ -27,19 +30,32 @@ type CodeComponent interface { func NewCodeComponent(config *config.Config) (CodeComponent, error) { c := &codeComponentImpl{} var err error - c.repoComponentImpl, err = NewRepoComponentImpl(config) + c.repoComponent, err = NewRepoComponentImpl(config) if err != nil { return nil, err } - c.cs = database.NewCodeStore() - c.rs = database.NewRepoStore() + c.codeStore = database.NewCodeStore() + c.repoStore = database.NewRepoStore() + gs, err := git.NewGitServer(config) + if err != nil { + return nil, fmt.Errorf("failed to create git server, error: %w", err) + } + c.gitServer = gs + c.config = config + c.userLikesStore = database.NewUserLikesStore() + c.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), + rpc.AuthWithApiKey(config.APIToken)) return c, nil } type codeComponentImpl struct { - *repoComponentImpl - cs database.CodeStore - rs database.RepoStore + config *config.Config + repoComponent RepoComponent + codeStore database.CodeStore + repoStore database.RepoStore + userLikesStore database.UserLikesStore + gitServer gitserver.GitServer + userSvcClient rpc.UserSvcClient } func (c *codeComponentImpl) Create(ctx context.Context, req *types.CreateCodeReq) (*types.Code, error) { @@ -61,7 +77,7 @@ func (c *codeComponentImpl) Create(ctx context.Context, req *types.CreateCodeReq req.RepoType = types.CodeRepo req.Readme = generateReadmeData(req.License) req.Nickname = nickname - _, dbRepo, err := c.CreateRepo(ctx, req.CreateRepoReq) + _, dbRepo, err := c.repoComponent.CreateRepo(ctx, req.CreateRepoReq) if err != nil { return nil, err } @@ -71,13 +87,13 @@ func (c *codeComponentImpl) Create(ctx context.Context, req *types.CreateCodeReq RepositoryID: dbRepo.ID, } - code, err := c.cs.Create(ctx, dbCode) + code, err := c.codeStore.Create(ctx, dbCode) if err != nil { return nil, fmt.Errorf("failed to create database code, cause: %w", err) } // Create README.md file - err = c.git.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + err = c.gitServer.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ Username: dbRepo.User.Username, Email: dbRepo.User.Email, Message: initCommitMessage, @@ -93,7 +109,7 @@ func (c *codeComponentImpl) Create(ctx context.Context, req *types.CreateCodeReq } // Create .gitattributes file - err = c.git.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + err = c.gitServer.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ Username: dbRepo.User.Username, Email: dbRepo.User.Email, Message: initCommitMessage, @@ -149,7 +165,7 @@ func (c *codeComponentImpl) Index(ctx context.Context, filter *types.RepoFilter, err error resCodes []types.Code ) - repos, total, err := c.PublicToUser(ctx, types.CodeRepo, filter.Username, filter, per, page) + repos, total, err := c.repoComponent.PublicToUser(ctx, types.CodeRepo, filter.Username, filter, per, page) if err != nil { newError := fmt.Errorf("failed to get public code repos,error:%w", err) return nil, 0, newError @@ -158,7 +174,7 @@ func (c *codeComponentImpl) Index(ctx context.Context, filter *types.RepoFilter, for _, repo := range repos { repoIDs = append(repoIDs, repo.ID) } - codes, err := c.cs.ByRepoIDs(ctx, repoIDs) + codes, err := c.codeStore.ByRepoIDs(ctx, repoIDs) if err != nil { newError := fmt.Errorf("failed to get codes by repo ids,error:%w", err) return nil, 0, newError @@ -210,18 +226,18 @@ func (c *codeComponentImpl) Index(ctx context.Context, filter *types.RepoFilter, func (c *codeComponentImpl) Update(ctx context.Context, req *types.UpdateCodeReq) (*types.Code, error) { req.RepoType = types.CodeRepo - dbRepo, err := c.UpdateRepo(ctx, req.UpdateRepoReq) + dbRepo, err := c.repoComponent.UpdateRepo(ctx, req.UpdateRepoReq) if err != nil { return nil, err } - code, err := c.cs.ByRepoID(ctx, dbRepo.ID) + code, err := c.codeStore.ByRepoID(ctx, dbRepo.ID) if err != nil { return nil, fmt.Errorf("failed to find code repo, error: %w", err) } //update times of code - err = c.cs.Update(ctx, *code) + err = c.codeStore.Update(ctx, *code) if err != nil { return nil, fmt.Errorf("failed to update database code repo, error: %w", err) } @@ -244,7 +260,7 @@ func (c *codeComponentImpl) Update(ctx context.Context, req *types.UpdateCodeReq } func (c *codeComponentImpl) Delete(ctx context.Context, namespace, name, currentUser string) error { - code, err := c.cs.FindByPath(ctx, namespace, name) + code, err := c.codeStore.FindByPath(ctx, namespace, name) if err != nil { return fmt.Errorf("failed to find code, error: %w", err) } @@ -255,12 +271,12 @@ func (c *codeComponentImpl) Delete(ctx context.Context, namespace, name, current Name: name, RepoType: types.CodeRepo, } - _, err = c.DeleteRepo(ctx, deleteDatabaseRepoReq) + _, err = c.repoComponent.DeleteRepo(ctx, deleteDatabaseRepoReq) if err != nil { return fmt.Errorf("failed to delete repo of code, error: %w", err) } - err = c.cs.Delete(ctx, *code) + err = c.codeStore.Delete(ctx, *code) if err != nil { return fmt.Errorf("failed to delete database code, error: %w", err) } @@ -269,12 +285,12 @@ func (c *codeComponentImpl) Delete(ctx context.Context, namespace, name, current func (c *codeComponentImpl) Show(ctx context.Context, namespace, name, currentUser string) (*types.Code, error) { var tags []types.RepoTag - code, err := c.cs.FindByPath(ctx, namespace, name) + code, err := c.codeStore.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find code, error: %w", err) } - permission, err := c.GetUserRepoPermission(ctx, currentUser, code.Repository) + permission, err := c.repoComponent.GetUserRepoPermission(ctx, currentUser, code.Repository) if err != nil { return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } @@ -282,7 +298,7 @@ func (c *codeComponentImpl) Show(ctx context.Context, namespace, name, currentUs return nil, ErrUnauthorized } - ns, err := c.GetNameSpaceInfo(ctx, namespace) + ns, err := c.repoComponent.GetNameSpaceInfo(ctx, namespace) if err != nil { return nil, fmt.Errorf("failed to get namespace info for code, error: %w", err) } @@ -338,12 +354,12 @@ func (c *codeComponentImpl) Show(ctx context.Context, namespace, name, currentUs } func (c *codeComponentImpl) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) { - code, err := c.cs.FindByPath(ctx, namespace, name) + code, err := c.codeStore.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find code repo, error: %w", err) } - allow, _ := c.AllowReadAccessRepo(ctx, code.Repository, currentUser) + allow, _ := c.repoComponent.AllowReadAccessRepo(ctx, code.Repository, currentUser) if !allow { return nil, ErrUnauthorized } @@ -352,7 +368,7 @@ func (c *codeComponentImpl) Relations(ctx context.Context, namespace, name, curr } func (c *codeComponentImpl) getRelations(ctx context.Context, repoID int64, currentUser string) (*types.Relations, error) { - res, err := c.RelatedRepos(ctx, repoID, currentUser) + res, err := c.repoComponent.RelatedRepos(ctx, repoID, currentUser) if err != nil { return nil, err } @@ -387,7 +403,7 @@ func (c *codeComponentImpl) OrgCodes(ctx context.Context, req *types.OrgCodesReq } } onlyPublic := !r.CanRead() - codes, total, err := c.cs.ByOrgPath(ctx, req.Namespace, req.PageSize, req.Page, onlyPublic) + codes, total, err := c.codeStore.ByOrgPath(ctx, req.Namespace, req.PageSize, req.Page, onlyPublic) if err != nil { newError := fmt.Errorf("failed to get org codes,error:%w", err) slog.Error(newError.Error()) diff --git a/component/code_test.go b/component/code_test.go new file mode 100644 index 00000000..4dff2019 --- /dev/null +++ b/component/code_test.go @@ -0,0 +1,225 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/membership" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestCodeComponent_Create(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + req := &types.CreateCodeReq{ + CreateRepoReq: types.CreateRepoReq{ + Username: "user", + Namespace: "ns", + Name: "n", + License: "l", + Readme: "r", + }, + } + dbrepo := &database.Repository{ + ID: 1, + User: database.User{Username: "user"}, + Tags: []database.Tag{{Name: "t1"}}, + } + crq := req.CreateRepoReq + crq.Nickname = "n" + crq.Readme = generateReadmeData(req.License) + crq.RepoType = types.CodeRepo + crq.DefaultBranch = "main" + cc.mocks.components.repo.EXPECT().CreateRepo(ctx, crq).Return( + nil, dbrepo, nil, + ) + cc.mocks.stores.CodeMock().EXPECT().Create(ctx, database.Code{ + Repository: dbrepo, + RepositoryID: 1, + }).Return(&database.Code{ + RepositoryID: 1, + Repository: dbrepo, + }, nil) + cc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + Username: "user", + Message: initCommitMessage, + Branch: "main", + Content: crq.Readme, + NewBranch: "main", + Namespace: "ns", + Name: "n", + FilePath: readmeFileName, + }, types.CodeRepo)).Return(nil) + cc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + Username: "user", + Message: initCommitMessage, + Branch: "main", + Content: codeGitattributesContent, + NewBranch: "main", + Namespace: "ns", + Name: "n", + FilePath: gitattributesFileName, + }, types.CodeRepo)).Return(nil) + + resp, err := cc.Create(ctx, req) + require.Nil(t, err) + require.Equal(t, &types.Code{ + RepositoryID: 1, + User: types.User{ + Username: "user", + }, + Repository: types.Repository{ + HTTPCloneURL: "/s/.git", + SSHCloneURL: ":s/.git", + }, + Tags: []types.RepoTag{{Name: "t1"}}, + }, resp) +} + +func TestCodeComponent_Index(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + filter := &types.RepoFilter{Username: "user"} + repos := []*database.Repository{ + {ID: 1, Name: "r1", Tags: []database.Tag{{Name: "t1"}}}, + {ID: 2, Name: "r2"}, + } + cc.mocks.components.repo.EXPECT().PublicToUser(ctx, types.CodeRepo, "user", filter, 10, 1).Return( + repos, 100, nil, + ) + cc.mocks.stores.CodeMock().EXPECT().ByRepoIDs(ctx, []int64{1, 2}).Return([]database.Code{ + {ID: 11, RepositoryID: 2}, + {ID: 12, RepositoryID: 1}, + }, nil) + + data, total, err := cc.Index(ctx, filter, 10, 1) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Code{ + {ID: 12, RepositoryID: 1, Name: "r1", Tags: []types.RepoTag{{Name: "t1"}}}, + {ID: 11, RepositoryID: 2, Name: "r2"}, + }, data) +} + +func TestCodeComponent_Update(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + req := &types.UpdateCodeReq{ + UpdateRepoReq: types.UpdateRepoReq{ + RepoType: types.CodeRepo, + }, + } + dbrepo := &database.Repository{Name: "name"} + cc.mocks.components.repo.EXPECT().UpdateRepo(ctx, req.UpdateRepoReq).Return(dbrepo, nil) + cc.mocks.stores.CodeMock().EXPECT().ByRepoID(ctx, dbrepo.ID).Return(&database.Code{ID: 1}, nil) + cc.mocks.stores.CodeMock().EXPECT().Update(ctx, database.Code{ + ID: 1, + }).Return(nil) + + data, err := cc.Update(ctx, req) + require.Nil(t, err) + require.Equal(t, &types.Code{ID: 1, Name: "name"}, data) + +} + +func TestCodeComponent_Delete(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + cc.mocks.stores.CodeMock().EXPECT().FindByPath(ctx, "ns", "n").Return(&database.Code{}, nil) + cc.mocks.components.repo.EXPECT().DeleteRepo(ctx, types.DeleteRepoReq{ + Username: "user", + Namespace: "ns", + Name: "n", + RepoType: types.CodeRepo, + }).Return(nil, nil) + cc.mocks.stores.CodeMock().EXPECT().Delete(ctx, database.Code{}).Return(nil) + + err := cc.Delete(ctx, "ns", "n", "user") + require.Nil(t, err) +} + +func TestCodeComponent_Show(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + code := &database.Code{ID: 1, Repository: &database.Repository{ + ID: 11, Name: "name", User: database.User{Username: "user"}, + }} + cc.mocks.stores.CodeMock().EXPECT().FindByPath(ctx, "ns", "n").Return(code, nil) + cc.mocks.components.repo.EXPECT().GetUserRepoPermission(ctx, "user", code.Repository).Return( + &types.UserRepoPermission{CanRead: true, CanAdmin: true}, nil, + ) + cc.mocks.stores.UserLikesMock().EXPECT().IsExist(ctx, "user", int64(11)).Return(true, nil) + cc.mocks.components.repo.EXPECT().GetNameSpaceInfo(ctx, "ns").Return(&types.Namespace{}, nil) + + data, err := cc.Show(ctx, "ns", "n", "user") + require.Nil(t, err) + require.Equal(t, &types.Code{ + ID: 1, + Repository: types.Repository{ + HTTPCloneURL: "/s/.git", + SSHCloneURL: ":s/.git", + }, + RepositoryID: 11, + Namespace: &types.Namespace{}, + Name: "name", + User: types.User{Username: "user"}, + CanManage: true, + UserLikes: true, + }, data) +} + +func TestCodeComponent_Relations(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + cc.mocks.stores.CodeMock().EXPECT().FindByPath(ctx, "ns", "n").Return(&database.Code{ + Repository: &database.Repository{}, + RepositoryID: 1, + }, nil) + cc.mocks.components.repo.EXPECT().AllowReadAccessRepo(ctx, &database.Repository{}, "user").Return(true, nil) + cc.mocks.components.repo.EXPECT().RelatedRepos(ctx, int64(1), "user").Return( + map[types.RepositoryType][]*database.Repository{ + types.ModelRepo: { + {Name: "r1"}, + }, + }, nil, + ) + + data, err := cc.Relations(ctx, "ns", "n", "user") + require.Nil(t, err) + require.Equal(t, &types.Relations{ + Models: []*types.Model{{Name: "r1"}}, + }, data) + +} + +func TestCodeComponent_OrgCodes(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + cc.mocks.userSvcClient.EXPECT().GetMemberRole(ctx, "ns", "user").Return(membership.RoleAdmin, nil) + cc.mocks.stores.CodeMock().EXPECT().ByOrgPath(ctx, "ns", 10, 1, false).Return( + []database.Code{{ + ID: 1, Repository: &database.Repository{Name: "repo"}, + RepositoryID: 11, + }}, 100, nil, + ) + + data, total, err := cc.OrgCodes(ctx, &types.OrgDatasetsReq{ + Namespace: "ns", CurrentUser: "user", + PageOpts: types.PageOpts{Page: 1, PageSize: 10}, + }) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Code{ + {ID: 1, Name: "repo", RepositoryID: 11}, + }, data) + +} diff --git a/component/collection.go b/component/collection.go index f48d48dc..9d3a8243 100644 --- a/component/collection.go +++ b/component/collection.go @@ -31,11 +31,11 @@ type CollectionComponent interface { func NewCollectionComponent(config *config.Config) (CollectionComponent, error) { cc := &collectionComponentImpl{} - cc.cs = database.NewCollectionStore() - cc.rs = database.NewRepoStore() - cc.us = database.NewUserStore() - cc.os = database.NewOrgStore() - cc.uls = database.NewUserLikesStore() + cc.collectionStore = database.NewCollectionStore() + cc.repoStore = database.NewRepoStore() + cc.userStore = database.NewUserStore() + cc.orgStore = database.NewOrgStore() + cc.userLikesStore = database.NewUserLikesStore() cc.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), rpc.AuthWithApiKey(config.APIToken)) spaceComponent, err := NewSpaceComponent(config) @@ -47,17 +47,17 @@ func NewCollectionComponent(config *config.Config) (CollectionComponent, error) } type collectionComponentImpl struct { - os database.OrgStore - cs database.CollectionStore - rs database.RepoStore - us database.UserStore - uls database.UserLikesStore - userSvcClient rpc.UserSvcClient - spaceComponent SpaceComponent + collectionStore database.CollectionStore + orgStore database.OrgStore + repoStore database.RepoStore + userStore database.UserStore + userLikesStore database.UserLikesStore + userSvcClient rpc.UserSvcClient + spaceComponent SpaceComponent } func (cc *collectionComponentImpl) GetCollections(ctx context.Context, filter *types.CollectionFilter, per, page int) ([]types.Collection, int, error) { - collections, total, err := cc.cs.GetCollections(ctx, filter, per, page, true) + collections, total, err := cc.collectionStore.GetCollections(ctx, filter, per, page, true) if err != nil { return nil, 0, err } @@ -73,7 +73,7 @@ func (cc *collectionComponentImpl) GetCollections(ctx context.Context, filter *t func (cc *collectionComponentImpl) CreateCollection(ctx context.Context, input types.CreateCollectionReq) (*database.Collection, error) { // find by user name - user, err := cc.us.FindByUsername(ctx, input.Username) + user, err := cc.userStore.FindByUsername(ctx, input.Username) if err != nil { return nil, fmt.Errorf("cannot find user for collection, %w", err) } @@ -92,24 +92,24 @@ func (cc *collectionComponentImpl) CreateCollection(ctx context.Context, input t collection.Username = "" } - return cc.cs.CreateCollection(ctx, collection) + return cc.collectionStore.CreateCollection(ctx, collection) } func (cc *collectionComponentImpl) GetCollection(ctx context.Context, currentUser string, id int64) (*types.Collection, error) { - collection, err := cc.cs.GetCollection(ctx, id) + collection, err := cc.collectionStore.GetCollection(ctx, id) if err != nil { return nil, err } // find by user name avatar := "" if collection.Username != "" { - user, err := cc.us.FindByUsername(ctx, collection.Username) + user, err := cc.userStore.FindByUsername(ctx, collection.Username) if err != nil { return nil, fmt.Errorf("cannot find user for collection, %w", err) } avatar = user.Avatar } else if collection.Namespace != "" { - org, err := cc.os.FindByPath(ctx, collection.Namespace) + org, err := cc.orgStore.FindByPath(ctx, collection.Namespace) if err != nil { return nil, fmt.Errorf("fail to get org info, path: %s, error: %w", collection.Namespace, err) } @@ -131,7 +131,7 @@ func (cc *collectionComponentImpl) GetCollection(ctx context.Context, currentUse if err != nil { return nil, err } - likeExists, err := cc.uls.IsExistCollection(ctx, currentUser, id) + likeExists, err := cc.userLikesStore.IsExistCollection(ctx, currentUser, id) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user likes,error:%w", err) return nil, newError @@ -165,7 +165,7 @@ func (cc *collectionComponentImpl) GetPublicRepos(collection types.Collection) [ } func (cc *collectionComponentImpl) UpdateCollection(ctx context.Context, input types.CreateCollectionReq) (*database.Collection, error) { - collection, err := cc.cs.GetCollection(ctx, input.ID) + collection, err := cc.collectionStore.GetCollection(ctx, input.ID) if err != nil { return nil, fmt.Errorf("cannot find collection to update, %w", err) } @@ -175,25 +175,25 @@ func (cc *collectionComponentImpl) UpdateCollection(ctx context.Context, input t collection.Private = input.Private collection.Theme = input.Theme collection.UpdatedAt = time.Now() - return cc.cs.UpdateCollection(ctx, *collection) + return cc.collectionStore.UpdateCollection(ctx, *collection) } func (cc *collectionComponentImpl) DeleteCollection(ctx context.Context, id int64, userName string) error { // find by user name - user, err := cc.us.FindByUsername(ctx, userName) + user, err := cc.userStore.FindByUsername(ctx, userName) if err != nil { return fmt.Errorf("cannot find user for collection, %w", err) } - return cc.cs.DeleteCollection(ctx, id, user.ID) + return cc.collectionStore.DeleteCollection(ctx, id, user.ID) } func (cc *collectionComponentImpl) AddReposToCollection(ctx context.Context, req types.UpdateCollectionReposReq) error { // find by user name - user, err := cc.us.FindByUsername(ctx, req.Username) + user, err := cc.userStore.FindByUsername(ctx, req.Username) if err != nil { return fmt.Errorf("cannot find user for collection, %w", err) } - collection, err := cc.cs.GetCollection(ctx, req.ID) + collection, err := cc.collectionStore.GetCollection(ctx, req.ID) if err != nil { return err } @@ -207,16 +207,16 @@ func (cc *collectionComponentImpl) AddReposToCollection(ctx context.Context, req RepositoryID: id, }) } - return cc.cs.AddCollectionRepos(ctx, collectionRepos) + return cc.collectionStore.AddCollectionRepos(ctx, collectionRepos) } func (cc *collectionComponentImpl) RemoveReposFromCollection(ctx context.Context, req types.UpdateCollectionReposReq) error { // find by user name - user, err := cc.us.FindByUsername(ctx, req.Username) + user, err := cc.userStore.FindByUsername(ctx, req.Username) if err != nil { return fmt.Errorf("cannot find user for collection, %w", err) } - collection, err := cc.cs.GetCollection(ctx, req.ID) + collection, err := cc.collectionStore.GetCollection(ctx, req.ID) if err != nil { return err } @@ -230,7 +230,7 @@ func (cc *collectionComponentImpl) RemoveReposFromCollection(ctx context.Context RepositoryID: id, }) } - return cc.cs.RemoveCollectionRepos(ctx, collectionRepos) + return cc.collectionStore.RemoveCollectionRepos(ctx, collectionRepos) } func (cc *collectionComponentImpl) getUserCollectionPermission(ctx context.Context, userName string, collection *database.Collection) (*types.UserRepoPermission, error) { @@ -290,7 +290,7 @@ func (c *collectionComponentImpl) OrgCollections(ctx context.Context, req *types } } onlyPublic := !r.CanRead() - collections, total, err := c.cs.ByUserOrgs(ctx, req.Namespace, req.PageSize, req.Page, onlyPublic) + collections, total, err := c.collectionStore.ByUserOrgs(ctx, req.Namespace, req.PageSize, req.Page, onlyPublic) if err != nil { return nil, 0, err } diff --git a/component/collection_test.go b/component/collection_test.go new file mode 100644 index 00000000..a06dbf42 --- /dev/null +++ b/component/collection_test.go @@ -0,0 +1,220 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/membership" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestCollectionComponent_GetCollections(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + filter := &types.CollectionFilter{Search: "foo"} + cc.mocks.stores.CollectionMock().EXPECT().GetCollections(ctx, filter, 10, 1, true).Return( + []database.Collection{{Name: "n"}}, 100, nil, + ) + data, total, err := cc.GetCollections(ctx, filter, 10, 1) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Collection{{Name: "n"}}, data) +} + +func TestCollectionComponent_CreateCollection(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + cc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + Username: "user", + }, nil) + cc.mocks.stores.CollectionMock().EXPECT().CreateCollection(ctx, database.Collection{ + Username: "user", + Name: "n", + Nickname: "nn", + Description: "d", + }).Return(&database.Collection{}, nil) + + r, err := cc.CreateCollection(ctx, types.CreateCollectionReq{ + Name: "n", + Nickname: "nn", + Description: "d", + Username: "user", + }) + require.Nil(t, err) + require.Equal(t, &database.Collection{}, r) +} + +func TestCollectionComponent_GetCollection(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + repos := []database.Repository{ + {RepositoryType: types.SpaceRepo, Path: "r1/foo"}, + } + cc.mocks.stores.CollectionMock().EXPECT().GetCollection(ctx, int64(1)).Return( + &database.Collection{Username: "user", Namespace: "user", Repositories: repos}, nil, + ) + cc.mocks.stores.CollectionMock().EXPECT().GetCollection(ctx, int64(2)).Return( + &database.Collection{Namespace: "ns", Repositories: repos}, nil, + ) + cc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + Username: "user", + Avatar: "aaa", + }, nil) + cc.mocks.stores.OrgMock().EXPECT().FindByPath(ctx, "ns").Return(database.Organization{ + Logo: "logo", + }, nil) + cc.mocks.userSvcClient.EXPECT().GetMemberRole(ctx, "ns", "user").Return(membership.RoleAdmin, nil) + cc.mocks.stores.UserLikesMock().EXPECT().IsExistCollection(ctx, "user", mock.Anything).Return(true, nil) + cc.mocks.components.space.EXPECT().Status(ctx, "r1", "foo").Return("", "go", nil) + + col, err := cc.GetCollection(ctx, "user", 1) + require.Nil(t, err) + require.Equal(t, &types.Collection{ + Username: "user", + Namespace: "user", + UserLikes: true, + CanWrite: true, + CanManage: true, + Avatar: "aaa", + Repositories: []types.CollectionRepository{ + {RepositoryType: types.SpaceRepo, Path: "r1/foo", Status: "go"}, + }, + }, col) + col, err = cc.GetCollection(ctx, "user", 2) + require.Nil(t, err) + require.Equal(t, &types.Collection{ + Namespace: "ns", + UserLikes: true, + CanWrite: true, + CanManage: true, + Avatar: "logo", + Repositories: []types.CollectionRepository{ + {RepositoryType: types.SpaceRepo, Path: "r1/foo", Status: "go"}, + }, + }, col) +} + +func TestCollectionComponent_GetPublicRepos(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + repos := []types.CollectionRepository{ + {RepositoryType: types.SpaceRepo, Path: "r1/foo", Private: true}, + {RepositoryType: types.SpaceRepo, Path: "r1/foo", Private: false}, + } + r := cc.GetPublicRepos(types.Collection{Repositories: repos}) + require.Equal(t, []types.CollectionRepository{ + {RepositoryType: types.SpaceRepo, Path: "r1/foo", Private: false}, + }, r) +} + +func TestCollectionComponent_UpdateCollection(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + cc.mocks.stores.CollectionMock().EXPECT().GetCollection(ctx, int64(1)).Return( + &database.Collection{}, nil, + ) + cc.mocks.stores.CollectionMock().EXPECT().UpdateCollection(ctx, mock.Anything).RunAndReturn(func(ctx context.Context, c database.Collection) (*database.Collection, error) { + require.Equal(t, c.Name, "n") + require.True(t, c.Private) + return &database.Collection{}, nil + }) + + r, err := cc.UpdateCollection(ctx, types.CreateCollectionReq{ + ID: 1, + Name: "n", + Private: true, + }) + require.Nil(t, err) + require.Equal(t, &database.Collection{}, r) +} + +func TestCollectionComponent_DeleteCollection(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + cc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + ID: 2, + Username: "user", + }, nil) + cc.mocks.stores.CollectionMock().EXPECT().DeleteCollection(ctx, int64(1), int64(2)).Return(nil) + + err := cc.DeleteCollection(ctx, 1, "user") + require.Nil(t, err) + +} + +func TestCollectionComponent_AddReposToCollection(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + cc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + ID: 2, + Username: "user", + }, nil) + cc.mocks.stores.CollectionMock().EXPECT().GetCollection(ctx, int64(1)).Return( + &database.Collection{UserID: 2}, nil, + ) + cc.mocks.stores.CollectionMock().EXPECT().AddCollectionRepos(ctx, []database.CollectionRepository{ + {CollectionID: 1, RepositoryID: 1}, + {CollectionID: 1, RepositoryID: 2}, + }).Return(nil) + + err := cc.AddReposToCollection(ctx, types.UpdateCollectionReposReq{ + RepoIDs: []int64{1, 2}, + Username: "user", + ID: 1, + }) + require.Nil(t, err) + +} + +func TestCollectionComponent_RemoveReposFromCollection(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + cc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + ID: 2, + Username: "user", + }, nil) + cc.mocks.stores.CollectionMock().EXPECT().GetCollection(ctx, int64(1)).Return( + &database.Collection{UserID: 2}, nil, + ) + cc.mocks.stores.CollectionMock().EXPECT().RemoveCollectionRepos(ctx, []database.CollectionRepository{ + {CollectionID: 1, RepositoryID: 1}, + {CollectionID: 1, RepositoryID: 2}, + }).Return(nil) + + err := cc.RemoveReposFromCollection(ctx, types.UpdateCollectionReposReq{ + RepoIDs: []int64{1, 2}, + Username: "user", + ID: 1, + }) + require.Nil(t, err) + +} + +func TestCollectionComponent_OrgCollections(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + cc.mocks.userSvcClient.EXPECT().GetMemberRole(ctx, "ns", "user").Return(membership.RoleAdmin, nil) + cc.mocks.stores.CollectionMock().EXPECT().ByUserOrgs(ctx, "ns", 10, 1, false).Return([]database.Collection{ + {Name: "col"}, + }, 100, nil) + + cols, total, err := cc.OrgCollections(ctx, &types.OrgDatasetsReq{ + Namespace: "ns", CurrentUser: "user", + PageOpts: types.PageOpts{Page: 1, PageSize: 10}, + }) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Collection{{Name: "col"}}, cols) +} diff --git a/component/dataset.go b/component/dataset.go index 44b224ed..302fbf68 100644 --- a/component/dataset.go +++ b/component/dataset.go @@ -7,7 +7,10 @@ import ( "log/slog" "time" + "opencsg.com/csghub-server/builder/git" + "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/membership" + "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" @@ -91,27 +94,44 @@ type DatasetComponent interface { func NewDatasetComponent(config *config.Config) (DatasetComponent, error) { c := &datasetComponentImpl{} - c.ts = database.NewTagStore() - c.ds = database.NewDatasetStore() - c.rs = database.NewRepoStore() + c.tagStore = database.NewTagStore() + c.datasetStore = database.NewDatasetStore() + c.repoStore = database.NewRepoStore() + c.namespaceStore = database.NewNamespaceStore() + c.userStore = database.NewUserStore() + c.userLikesStore = database.NewUserLikesStore() var err error - c.repoComponentImpl, err = NewRepoComponentImpl(config) + c.repoComponent, err = NewRepoComponentImpl(config) if err != nil { return nil, fmt.Errorf("failed to create repo component, error: %w", err) } - c.sc, err = NewSensitiveComponent(config) + c.sensitiveComponent, err = NewSensitiveComponent(config) if err != nil { return nil, fmt.Errorf("failed to create sensitive component, error: %w", err) } + gs, err := git.NewGitServer(config) + if err != nil { + return nil, fmt.Errorf("failed to create git server, error: %w", err) + } + c.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), + rpc.AuthWithApiKey(config.APIToken)) + c.gitServer = gs + c.config = config return c, nil } type datasetComponentImpl struct { - *repoComponentImpl - ts database.TagStore - ds database.DatasetStore - rs database.RepoStore - sc SensitiveComponent + config *config.Config + repoComponent RepoComponent + tagStore database.TagStore + datasetStore database.DatasetStore + repoStore database.RepoStore + namespaceStore database.NamespaceStore + userStore database.UserStore + sensitiveComponent SensitiveComponent + gitServer gitserver.GitServer + userLikesStore database.UserLikesStore + userSvcClient rpc.UserSvcClient } func (c *datasetComponentImpl) Create(ctx context.Context, req *types.CreateDatasetReq) (*types.Dataset, error) { @@ -131,7 +151,7 @@ func (c *datasetComponentImpl) Create(ctx context.Context, req *types.CreateData } if !user.CanAdmin() { if namespace.NamespaceType == database.OrgNamespace { - canWrite, err := c.CheckCurrentUserPermission(ctx, req.Username, req.Namespace, membership.RoleWrite) + canWrite, err := c.repoComponent.CheckCurrentUserPermission(ctx, req.Username, req.Namespace, membership.RoleWrite) if err != nil { return nil, err } @@ -158,7 +178,7 @@ func (c *datasetComponentImpl) Create(ctx context.Context, req *types.CreateData req.RepoType = types.DatasetRepo req.Readme = generateReadmeData(req.License) req.Nickname = nickname - _, dbRepo, err := c.CreateRepo(ctx, req.CreateRepoReq) + _, dbRepo, err := c.repoComponent.CreateRepo(ctx, req.CreateRepoReq) if err != nil { return nil, err } @@ -168,13 +188,13 @@ func (c *datasetComponentImpl) Create(ctx context.Context, req *types.CreateData RepositoryID: dbRepo.ID, } - dataset, err := c.ds.Create(ctx, dbDataset) + dataset, err := c.datasetStore.Create(ctx, dbDataset) if err != nil { return nil, fmt.Errorf("failed to create database dataset, cause: %w", err) } // Create README.md file - err = c.git.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + err = c.gitServer.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ Username: user.Username, Email: user.Email, Message: initCommitMessage, @@ -190,7 +210,7 @@ func (c *datasetComponentImpl) Create(ctx context.Context, req *types.CreateData } // Create .gitattributes file - err = c.git.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + err = c.gitServer.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ Username: user.Username, Email: user.Email, Message: initCommitMessage, @@ -258,7 +278,7 @@ func (c *datasetComponentImpl) commonIndex(ctx context.Context, filter *types.Re err error resDatasets []types.Dataset ) - repos, total, err := c.PublicToUser(ctx, types.DatasetRepo, filter.Username, filter, per, page) + repos, total, err := c.repoComponent.PublicToUser(ctx, types.DatasetRepo, filter.Username, filter, per, page) if err != nil { newError := fmt.Errorf("failed to get public dataset repos,error:%w", err) return nil, 0, newError @@ -267,7 +287,7 @@ func (c *datasetComponentImpl) commonIndex(ctx context.Context, filter *types.Re for _, repo := range repos { repoIDs = append(repoIDs, repo.ID) } - datasets, err := c.ds.ByRepoIDs(ctx, repoIDs) + datasets, err := c.datasetStore.ByRepoIDs(ctx, repoIDs) if err != nil { newError := fmt.Errorf("failed to get datasets by repo ids,error:%w", err) return nil, 0, newError @@ -328,18 +348,18 @@ func (c *datasetComponentImpl) commonIndex(ctx context.Context, filter *types.Re func (c *datasetComponentImpl) Update(ctx context.Context, req *types.UpdateDatasetReq) (*types.Dataset, error) { req.RepoType = types.DatasetRepo - dbRepo, err := c.UpdateRepo(ctx, req.UpdateRepoReq) + dbRepo, err := c.repoComponent.UpdateRepo(ctx, req.UpdateRepoReq) if err != nil { return nil, err } - dataset, err := c.ds.ByRepoID(ctx, dbRepo.ID) + dataset, err := c.datasetStore.ByRepoID(ctx, dbRepo.ID) if err != nil { return nil, fmt.Errorf("failed to find dataset, error: %w", err) } // update times of dateset - err = c.ds.Update(ctx, *dataset) + err = c.datasetStore.Update(ctx, *dataset) if err != nil { return nil, fmt.Errorf("failed to update database dataset, error: %w", err) } @@ -362,7 +382,7 @@ func (c *datasetComponentImpl) Update(ctx context.Context, req *types.UpdateData } func (c *datasetComponentImpl) Delete(ctx context.Context, namespace, name, currentUser string) error { - dataset, err := c.ds.FindByPath(ctx, namespace, name) + dataset, err := c.datasetStore.FindByPath(ctx, namespace, name) if err != nil { return fmt.Errorf("failed to find dataset, error: %w", err) } @@ -373,12 +393,12 @@ func (c *datasetComponentImpl) Delete(ctx context.Context, namespace, name, curr Name: name, RepoType: types.DatasetRepo, } - _, err = c.DeleteRepo(ctx, deleteDatabaseRepoReq) + _, err = c.repoComponent.DeleteRepo(ctx, deleteDatabaseRepoReq) if err != nil { return fmt.Errorf("failed to delete repo of dataset, error: %w", err) } - err = c.ds.Delete(ctx, *dataset) + err = c.datasetStore.Delete(ctx, *dataset) if err != nil { return fmt.Errorf("failed to delete database dataset, error: %w", err) } @@ -387,12 +407,12 @@ func (c *datasetComponentImpl) Delete(ctx context.Context, namespace, name, curr func (c *datasetComponentImpl) Show(ctx context.Context, namespace, name, currentUser string) (*types.Dataset, error) { var tags []types.RepoTag - dataset, err := c.ds.FindByPath(ctx, namespace, name) + dataset, err := c.datasetStore.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find dataset, error: %w", err) } - permission, err := c.GetUserRepoPermission(ctx, currentUser, dataset.Repository) + permission, err := c.repoComponent.GetUserRepoPermission(ctx, currentUser, dataset.Repository) if err != nil { return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } @@ -400,7 +420,7 @@ func (c *datasetComponentImpl) Show(ctx context.Context, namespace, name, curren return nil, ErrUnauthorized } - ns, err := c.GetNameSpaceInfo(ctx, namespace) + ns, err := c.repoComponent.GetNameSpaceInfo(ctx, namespace) if err != nil { return nil, fmt.Errorf("failed to get namespace info for dataset, error: %w", err) } @@ -458,12 +478,12 @@ func (c *datasetComponentImpl) Show(ctx context.Context, namespace, name, curren } func (c *datasetComponentImpl) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) { - dataset, err := c.ds.FindByPath(ctx, namespace, name) + dataset, err := c.datasetStore.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find dataset repo, error: %w", err) } - allow, _ := c.AllowReadAccessRepo(ctx, dataset.Repository, currentUser) + allow, _ := c.repoComponent.AllowReadAccessRepo(ctx, dataset.Repository, currentUser) if !allow { return nil, ErrUnauthorized } @@ -472,7 +492,7 @@ func (c *datasetComponentImpl) Relations(ctx context.Context, namespace, name, c } func (c *datasetComponentImpl) getRelations(ctx context.Context, repoID int64, currentUser string) (*types.Relations, error) { - res, err := c.RelatedRepos(ctx, repoID, currentUser) + res, err := c.repoComponent.RelatedRepos(ctx, repoID, currentUser) if err != nil { return nil, err } @@ -507,7 +527,7 @@ func (c *datasetComponentImpl) OrgDatasets(ctx context.Context, req *types.OrgDa } } onlyPublic := !r.CanRead() - datasets, total, err := c.ds.ByOrgPath(ctx, req.Namespace, req.PageSize, req.Page, onlyPublic) + datasets, total, err := c.datasetStore.ByOrgPath(ctx, req.Namespace, req.PageSize, req.Page, onlyPublic) if err != nil { newError := fmt.Errorf("failed to get user datasets,error:%w", err) slog.Error(newError.Error()) diff --git a/component/dataset_test.go b/component/dataset_test.go new file mode 100644 index 00000000..69f1744d --- /dev/null +++ b/component/dataset_test.go @@ -0,0 +1,251 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/git/membership" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestDatasetCompnent_Create(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + req := &types.CreateDatasetReq{ + CreateRepoReq: types.CreateRepoReq{ + Username: "user", + Namespace: "ns", + Name: "n", + }, + } + dc.mocks.stores.NamespaceMock().EXPECT().FindByPath(ctx, "ns").Return( + database.Namespace{}, nil, + ) + dc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + Username: "user", + }, nil) + rq := req.CreateRepoReq + rq.RepoType = types.DatasetRepo + rq.Readme = "\n---\nlicense: \n---\n\t" + rq.DefaultBranch = "main" + rq.Nickname = "n" + dc.mocks.components.repo.EXPECT().CreateRepo(ctx, rq).Return(&gitserver.CreateRepoResp{}, &database.Repository{}, nil) + dc.mocks.stores.DatasetMock().EXPECT().Create(ctx, database.Dataset{ + Repository: &database.Repository{}, + }).Return(&database.Dataset{ + Repository: &database.Repository{ + Tags: []database.Tag{{Name: "t1"}}, + }, + }, nil) + dc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq( + &types.CreateFileParams{ + Username: "user", + Message: initCommitMessage, + Branch: "main", + Content: "\n---\nlicense: \n---\n\t", + Namespace: "ns", + Name: "n", + FilePath: readmeFileName, + }, types.DatasetRepo), + ).Return(nil) + dc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq( + &types.CreateFileParams{ + Username: "user", + Message: initCommitMessage, + Branch: "main", + Content: datasetGitattributesContent, + Namespace: "ns", + Name: "n", + FilePath: gitattributesFileName, + }, types.DatasetRepo), + ).Return(nil) + + resp, err := dc.Create(ctx, req) + require.Nil(t, err) + require.Equal(t, &types.Dataset{ + User: types.User{Username: "user"}, + Repository: types.Repository{ + HTTPCloneURL: "/s/.git", + SSHCloneURL: ":s/.git", + }, + Tags: []types.RepoTag{{Name: "t1"}}, + }, resp) + +} + +func TestDatasetCompnent_Index(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + filter := &types.RepoFilter{Username: "user"} + dc.mocks.components.repo.EXPECT().PublicToUser(ctx, types.DatasetRepo, "user", filter, 10, 1).Return( + []*database.Repository{ + {ID: 1, Tags: []database.Tag{{Name: "t1"}}}, + {ID: 2}, + }, 100, nil, + ) + dc.mocks.stores.DatasetMock().EXPECT().ByRepoIDs(ctx, []int64{1, 2}).Return([]database.Dataset{ + { + ID: 11, RepositoryID: 2, Repository: &database.Repository{ + User: database.User{Username: "user2"}, + }, + }, + { + ID: 12, RepositoryID: 1, Repository: &database.Repository{ + User: database.User{Username: "user1"}, + }, + }, + }, nil) + + data, total, err := dc.Index(ctx, filter, 10, 1) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Dataset{ + {ID: 12, RepositoryID: 1, Repository: types.Repository{ + HTTPCloneURL: "/s/.git", + SSHCloneURL: ":s/.git", + }, User: types.User{Username: "user1"}, + Tags: []types.RepoTag{{Name: "t1"}}, + }, + {ID: 11, RepositoryID: 2, Repository: types.Repository{ + HTTPCloneURL: "/s/.git", + SSHCloneURL: ":s/.git", + }, User: types.User{Username: "user2"}}, + }, data) + +} + +func TestDatasetCompnent_Update(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + req := &types.UpdateDatasetReq{UpdateRepoReq: types.UpdateRepoReq{ + RepoType: types.DatasetRepo, + }} + dc.mocks.components.repo.EXPECT().UpdateRepo(ctx, req.UpdateRepoReq).Return( + &database.Repository{ID: 1, Name: "repo"}, nil, + ) + dc.mocks.stores.DatasetMock().EXPECT().ByRepoID(ctx, int64(1)).Return( + &database.Dataset{ID: 2}, nil, + ) + dc.mocks.stores.DatasetMock().EXPECT().Update(ctx, database.Dataset{ID: 2}).Return(nil) + + d, err := dc.Update(ctx, req) + require.Nil(t, err) + require.Equal(t, &types.Dataset{ + ID: 2, + RepositoryID: 1, + Name: "repo", + }, d) +} + +func TestDatasetCompnent_Delete(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + dc.mocks.stores.DatasetMock().EXPECT().FindByPath(ctx, "ns", "n").Return(&database.Dataset{}, nil) + dc.mocks.components.repo.EXPECT().DeleteRepo(ctx, types.DeleteRepoReq{ + Username: "user", + Namespace: "ns", + Name: "n", + RepoType: types.DatasetRepo, + }).Return(&database.Repository{}, nil) + dc.mocks.stores.DatasetMock().EXPECT().Delete(ctx, database.Dataset{}).Return(nil) + + err := dc.Delete(ctx, "ns", "n", "user") + require.Nil(t, err) + +} + +func TestDatasetCompnent_Show(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + dataset := &database.Dataset{ + ID: 1, + Repository: &database.Repository{ + ID: 2, + Name: "n", + Tags: []database.Tag{{Name: "t1"}}, + User: database.User{ + Username: "user", + }, + }, + } + dc.mocks.stores.DatasetMock().EXPECT().FindByPath(ctx, "ns", "n").Return(dataset, nil) + dc.mocks.components.repo.EXPECT().GetUserRepoPermission(ctx, "user", dataset.Repository).Return(&types.UserRepoPermission{CanRead: true}, nil) + dc.mocks.components.repo.EXPECT().GetNameSpaceInfo(ctx, "ns").Return(&types.Namespace{}, nil) + dc.mocks.stores.UserLikesMock().EXPECT().IsExist(ctx, "user", int64(2)).Return(true, nil) + + d, err := dc.Show(ctx, "ns", "n", "user") + require.Nil(t, err) + require.Equal(t, &types.Dataset{ + ID: 1, + Name: "n", + RepositoryID: 2, + Tags: []types.RepoTag{{Name: "t1"}}, + Repository: types.Repository{ + HTTPCloneURL: "/s/.git", + SSHCloneURL: ":s/.git", + }, + User: types.User{Username: "user"}, + UserLikes: true, + Namespace: &types.Namespace{}, + }, d) + +} + +func TestDatasetCompnent_Relations(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + dataset := &database.Dataset{ + Repository: &database.Repository{}, + RepositoryID: 1, + } + dc.mocks.stores.DatasetMock().EXPECT().FindByPath(ctx, "ns", "n").Return(dataset, nil) + dc.mocks.components.repo.EXPECT().RelatedRepos(ctx, int64(1), "user").Return( + map[types.RepositoryType][]*database.Repository{ + types.ModelRepo: { + {Name: "n"}, + }, + }, nil, + ) + dc.mocks.components.repo.EXPECT().AllowReadAccessRepo(ctx, dataset.Repository, "user").Return(true, nil) + + rs, err := dc.Relations(ctx, "ns", "n", "user") + require.Nil(t, err) + require.Equal(t, &types.Relations{ + Models: []*types.Model{{Name: "n"}}, + }, rs) + +} + +func TestDatasetCompnent_OrgDatasets(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + dc.mocks.userSvcClient.EXPECT().GetMemberRole(ctx, "ns", "user").Return(membership.RoleAdmin, nil) + dc.mocks.stores.DatasetMock().EXPECT().ByOrgPath(ctx, "ns", 10, 1, false).Return( + []database.Dataset{ + {ID: 1, Repository: &database.Repository{Name: "repo"}}, + }, 100, nil, + ) + + data, total, err := dc.OrgDatasets(ctx, &types.OrgDatasetsReq{ + Namespace: "ns", + CurrentUser: "user", + PageOpts: types.PageOpts{Page: 1, PageSize: 10}, + }) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Dataset{ + {ID: 1, Name: "repo"}, + }, data) + +} diff --git a/component/discussion.go b/component/discussion.go index 1b2caf3c..06f30e9d 100644 --- a/component/discussion.go +++ b/component/discussion.go @@ -12,9 +12,9 @@ import ( ) type discussionComponentImpl struct { - ds database.DiscussionStore - rs database.RepoStore - us database.UserStore + discussionStore database.DiscussionStore + repoStore database.RepoStore + userStore database.UserStore } type DiscussionComponent interface { @@ -33,22 +33,22 @@ func NewDiscussionComponent() DiscussionComponent { ds := database.NewDiscussionStore() rs := database.NewRepoStore() us := database.NewUserStore() - return &discussionComponentImpl{ds: ds, rs: rs, us: us} + return &discussionComponentImpl{discussionStore: ds, repoStore: rs, userStore: us} } func (c *discussionComponentImpl) CreateRepoDiscussion(ctx context.Context, req CreateRepoDiscussionRequest) (*CreateDiscussionResponse, error) { //TODO:check if the user can access the repo //get repo by namespace and name - repo, err := c.rs.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo by path '%s/%s/%s': %w", req.RepoType, req.Namespace, req.Name, err) } - user, err := c.us.FindByUsername(ctx, req.CurrentUser) + user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) if err != nil { return nil, fmt.Errorf("failed to find user by username '%s': %w", req.CurrentUser, err) } - discussion, err := c.ds.Create(ctx, database.Discussion{ + discussion, err := c.discussionStore.Create(ctx, database.Discussion{ Title: req.Title, DiscussionableID: repo.ID, DiscussionableType: database.DiscussionableTypeRepo, @@ -72,11 +72,11 @@ func (c *discussionComponentImpl) CreateRepoDiscussion(ctx context.Context, req } func (c *discussionComponentImpl) GetDiscussion(ctx context.Context, id int64) (*ShowDiscussionResponse, error) { - discussion, err := c.ds.FindByID(ctx, id) + discussion, err := c.discussionStore.FindByID(ctx, id) if err != nil { return nil, fmt.Errorf("failed to find discussion by id '%d': %w", id, err) } - comments, err := c.ds.FindDiscussionComments(ctx, discussion.ID) + comments, err := c.discussionStore.FindDiscussionComments(ctx, discussion.ID) if err != nil && err != sql.ErrNoRows { return nil, fmt.Errorf("failed to find discussion comments by discussion id '%d': %w", discussion.ID, err) } @@ -105,18 +105,18 @@ func (c *discussionComponentImpl) GetDiscussion(ctx context.Context, id int64) ( func (c *discussionComponentImpl) UpdateDiscussion(ctx context.Context, req UpdateDiscussionRequest) error { //check if the user is the owner of the discussion - user, err := c.us.FindByUsername(ctx, req.CurrentUser) + user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) if err != nil { return fmt.Errorf("failed to find user by username '%s': %w", req.CurrentUser, err) } - discussion, err := c.ds.FindByID(ctx, req.ID) + discussion, err := c.discussionStore.FindByID(ctx, req.ID) if err != nil { return fmt.Errorf("failed to find discussion by id '%d': %w", req.ID, err) } if discussion.UserID != user.ID { return fmt.Errorf("user '%s' is not the owner of the discussion '%d'", req.CurrentUser, req.ID) } - err = c.ds.UpdateByID(ctx, req.ID, req.Title) + err = c.discussionStore.UpdateByID(ctx, req.ID, req.Title) if err != nil { return fmt.Errorf("failed to update discussion by id '%d': %w", req.ID, err) } @@ -124,14 +124,14 @@ func (c *discussionComponentImpl) UpdateDiscussion(ctx context.Context, req Upda } func (c *discussionComponentImpl) DeleteDiscussion(ctx context.Context, currentUser string, id int64) error { - discussion, err := c.ds.FindByID(ctx, id) + discussion, err := c.discussionStore.FindByID(ctx, id) if err != nil { return fmt.Errorf("failed to find discussion by id '%d': %w", id, err) } if discussion.User.Username != currentUser { return fmt.Errorf("user '%s' is not the owner of the discussion '%d'", currentUser, id) } - err = c.ds.DeleteByID(ctx, id) + err = c.discussionStore.DeleteByID(ctx, id) if err != nil { return fmt.Errorf("failed to delete discussion by id '%d': %w", id, err) } @@ -140,11 +140,11 @@ func (c *discussionComponentImpl) DeleteDiscussion(ctx context.Context, currentU func (c *discussionComponentImpl) ListRepoDiscussions(ctx context.Context, req ListRepoDiscussionRequest) (*ListRepoDiscussionResponse, error) { //TODO:check if the user can access the repo - repo, err := c.rs.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo by path '%s/%s/%s': %w", req.RepoType, req.Namespace, req.Name, err) } - discussions, err := c.ds.FindByDiscussionableID(ctx, database.DiscussionableTypeRepo, repo.ID) + discussions, err := c.discussionStore.FindByDiscussionableID(ctx, database.DiscussionableTypeRepo, repo.ID) if err != nil { return nil, fmt.Errorf("failed to list repo discussions by repo type '%s', namespace '%s', name '%s': %w", req.RepoType, req.Namespace, req.Name, err) } @@ -168,18 +168,18 @@ func (c *discussionComponentImpl) ListRepoDiscussions(ctx context.Context, req L func (c *discussionComponentImpl) CreateDiscussionComment(ctx context.Context, req CreateCommentRequest) (*CreateCommentResponse, error) { req.CommentableType = database.CommentableTypeDiscussion // get discussion by id - _, err := c.ds.FindByID(ctx, req.CommentableID) + _, err := c.discussionStore.FindByID(ctx, req.CommentableID) if err != nil { return nil, fmt.Errorf("failed to find discussion by id '%d': %w", req.CommentableID, err) } //get user by username - user, err := c.us.FindByUsername(ctx, req.CurrentUser) + user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) if err != nil { return nil, fmt.Errorf("failed to find user by username '%s': %w", req.CurrentUser, err) } // create comment - comment, err := c.ds.CreateComment(ctx, database.Comment{ + comment, err := c.discussionStore.CreateComment(ctx, database.Comment{ Content: req.Content, CommentableID: req.CommentableID, CommentableType: req.CommentableType, @@ -202,12 +202,12 @@ func (c *discussionComponentImpl) CreateDiscussionComment(ctx context.Context, r } func (c *discussionComponentImpl) UpdateComment(ctx context.Context, currentUser string, id int64, content string) error { - user, err := c.us.FindByUsername(ctx, currentUser) + user, err := c.userStore.FindByUsername(ctx, currentUser) if err != nil { return fmt.Errorf("failed to find user by username '%s': %w", currentUser, err) } //get comment by id - comment, err := c.ds.FindCommentByID(ctx, id) + comment, err := c.discussionStore.FindCommentByID(ctx, id) if err != nil { return fmt.Errorf("failed to find comment by id '%d': %w", id, err) } @@ -215,7 +215,7 @@ func (c *discussionComponentImpl) UpdateComment(ctx context.Context, currentUser if comment.UserID != user.ID { return fmt.Errorf("user '%s' is not the owner of the comment '%d'", currentUser, id) } - err = c.ds.UpdateComment(ctx, id, content) + err = c.discussionStore.UpdateComment(ctx, id, content) if err != nil { return fmt.Errorf("failed to update comment by id '%d': %w", id, err) } @@ -223,12 +223,12 @@ func (c *discussionComponentImpl) UpdateComment(ctx context.Context, currentUser } func (c *discussionComponentImpl) DeleteComment(ctx context.Context, currentUser string, id int64) error { - user, err := c.us.FindByUsername(ctx, currentUser) + user, err := c.userStore.FindByUsername(ctx, currentUser) if err != nil { return fmt.Errorf("failed to find user by username '%s': %w", currentUser, err) } //get comment by id - comment, err := c.ds.FindCommentByID(ctx, id) + comment, err := c.discussionStore.FindCommentByID(ctx, id) if err != nil { return fmt.Errorf("failed to find comment by id '%d': %w", id, err) } @@ -236,7 +236,7 @@ func (c *discussionComponentImpl) DeleteComment(ctx context.Context, currentUser if comment.UserID != user.ID { return fmt.Errorf("user '%s' is not the owner of the comment '%d'", currentUser, id) } - err = c.ds.DeleteComment(ctx, id) + err = c.discussionStore.DeleteComment(ctx, id) if err != nil { return fmt.Errorf("failed to delete comment by id '%d': %w", id, err) } @@ -244,7 +244,7 @@ func (c *discussionComponentImpl) DeleteComment(ctx context.Context, currentUser } func (c *discussionComponentImpl) ListDiscussionComments(ctx context.Context, discussionID int64) ([]*DiscussionResponse_Comment, error) { - comments, err := c.ds.FindDiscussionComments(ctx, discussionID) + comments, err := c.discussionStore.FindDiscussionComments(ctx, discussionID) if err != nil { return nil, fmt.Errorf("failed to find discussion comments by discussion id '%d': %w", discussionID, err) } diff --git a/component/discussion_test.go b/component/discussion_test.go index 1813f69a..ddceff7a 100644 --- a/component/discussion_test.go +++ b/component/discussion_test.go @@ -18,9 +18,9 @@ func TestDiscussionComponent_CreateDisucssion(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } repo := &database.Repository{ @@ -77,9 +77,9 @@ func TestDiscussionComponent_GetDisussion(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } disc := database.Discussion{ @@ -126,9 +126,9 @@ func TestDiscussionComponent_UpdateDisussion(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } req := UpdateDiscussionRequest{ @@ -164,9 +164,9 @@ func TestDiscussionComponent_DeleteDisussion(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } currentUser := "user" @@ -197,9 +197,9 @@ func TestDiscussionComponent_ListRepoDiscussions(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } repo := &database.Repository{ @@ -239,9 +239,9 @@ func TestDiscussionComponent_CreateDisussionComment(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } req := CreateCommentRequest{ @@ -293,9 +293,9 @@ func TestDiscussionComponent_UpdateComment(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } req := CreateCommentRequest{ @@ -332,9 +332,9 @@ func TestDiscussionComponent_DeleteComment(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } req := CreateCommentRequest{ @@ -371,9 +371,9 @@ func TestDiscussionComponent_ListDiscussionComments(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } discussionID := int64(1) diff --git a/component/git_http.go b/component/git_http.go index d1eb5305..4decc37a 100644 --- a/component/git_http.go +++ b/component/git_http.go @@ -26,13 +26,14 @@ import ( ) type gitHTTPComponentImpl struct { - git gitserver.GitServer + gitServer gitserver.GitServer config *config.Config s3Client s3.Client lfsMetaObjectStore database.LfsMetaObjectStore lfsLockStore database.LfsLockStore - repo database.RepoStore - *repoComponentImpl + repoStore database.RepoStore + userStore database.UserStore + repoComponent RepoComponent } type GitHTTPComponent interface { @@ -53,7 +54,7 @@ func NewGitHTTPComponent(config *config.Config) (GitHTTPComponent, error) { c := &gitHTTPComponentImpl{} c.config = config var err error - c.git, err = git.NewGitServer(config) + c.gitServer, err = git.NewGitServer(config) if err != nil { newError := fmt.Errorf("fail to create git server,error:%w", err) slog.Error(newError.Error()) @@ -66,9 +67,10 @@ func NewGitHTTPComponent(config *config.Config) (GitHTTPComponent, error) { return nil, newError } c.lfsMetaObjectStore = database.NewLfsMetaObjectStore() - c.repo = database.NewRepoStore() + c.repoStore = database.NewRepoStore() c.lfsLockStore = database.NewLfsLockStore() - c.repoComponentImpl, err = NewRepoComponentImpl(config) + c.userStore = database.NewUserStore() + c.repoComponent, err = NewRepoComponentImpl(config) if err != nil { return nil, err } @@ -76,13 +78,13 @@ func NewGitHTTPComponent(config *config.Config) (GitHTTPComponent, error) { } func (c *gitHTTPComponentImpl) InfoRefs(ctx context.Context, req types.InfoRefsReq) (io.Reader, error) { - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } if req.Rpc == "git-receive-pack" { - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { return nil, ErrUnauthorized } @@ -91,7 +93,7 @@ func (c *gitHTTPComponentImpl) InfoRefs(ctx context.Context, req types.InfoRefsR } } else { if repo.Private { - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { return nil, ErrUnauthorized } @@ -101,7 +103,7 @@ func (c *gitHTTPComponentImpl) InfoRefs(ctx context.Context, req types.InfoRefsR } } - reader, err := c.git.InfoRefsResponse(ctx, gitserver.InfoRefsReq{ + reader, err := c.gitServer.InfoRefsResponse(ctx, gitserver.InfoRefsReq{ Namespace: req.Namespace, Name: req.Name, Rpc: req.Rpc, @@ -113,13 +115,13 @@ func (c *gitHTTPComponentImpl) InfoRefs(ctx context.Context, req types.InfoRefsR } func (c *gitHTTPComponentImpl) GitUploadPack(ctx context.Context, req types.GitUploadPackReq) error { - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) } if repo.Private { - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { return ErrUnauthorized } @@ -127,7 +129,7 @@ func (c *gitHTTPComponentImpl) GitUploadPack(ctx context.Context, req types.GitU return ErrForbidden } } - err = c.git.UploadPack(ctx, gitserver.UploadPackReq{ + err = c.gitServer.UploadPack(ctx, gitserver.UploadPackReq{ Namespace: req.Namespace, Name: req.Name, Request: req.Request, @@ -140,7 +142,7 @@ func (c *gitHTTPComponentImpl) GitUploadPack(ctx context.Context, req types.GitU } func (c *gitHTTPComponentImpl) GitReceivePack(ctx context.Context, req types.GitReceivePackReq) error { - _, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + _, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) } @@ -150,14 +152,14 @@ func (c *gitHTTPComponentImpl) GitReceivePack(ctx context.Context, req types.Git return ErrUnauthorized } - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { return ErrUnauthorized } if !allowed { return ErrForbidden } - err = c.git.ReceivePack(ctx, gitserver.ReceivePackReq{ + err = c.gitServer.ReceivePack(ctx, gitserver.ReceivePackReq{ Namespace: req.Namespace, Name: req.Name, Request: req.Request, @@ -176,7 +178,7 @@ func (c *gitHTTPComponentImpl) BuildObjectResponse(ctx context.Context, req type respObjects []*types.ObjectResponse exists bool ) - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } @@ -223,7 +225,7 @@ func (c *gitHTTPComponentImpl) BuildObjectResponse(ctx context.Context, req type // } if exists && lfsMetaObject == nil { - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { slog.Error("unable to check if user can wirte this repo", slog.String("lfs oid", obj.Oid), slog.Any("error", err)) return nil, ErrUnauthorized @@ -307,7 +309,7 @@ func (c *gitHTTPComponentImpl) buildObjectResponse(ctx context.Context, req type func (c *gitHTTPComponentImpl) LfsUpload(ctx context.Context, body io.ReadCloser, req types.UploadRequest) error { var exists bool - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) } @@ -332,7 +334,7 @@ func (c *gitHTTPComponentImpl) LfsUpload(ctx context.Context, body io.ReadCloser } uploadOrVerify := func() error { if exists { - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { slog.Error("Unable to check if LFS MetaObject [%s] is allowed. Error: %v", pointer.Oid, err) return err @@ -414,7 +416,7 @@ func (c *gitHTTPComponentImpl) LfsUpload(ctx context.Context, body io.ReadCloser } func (c *gitHTTPComponentImpl) LfsVerify(ctx context.Context, req types.VerifyRequest, p types.Pointer) error { - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) } @@ -451,7 +453,7 @@ func (c *gitHTTPComponentImpl) CreateLock(ctx context.Context, req types.LfsLock var ( lock *database.LfsLock ) - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } @@ -461,7 +463,7 @@ func (c *gitHTTPComponentImpl) CreateLock(ctx context.Context, req types.LfsLock return nil, ErrUnauthorized } - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { slog.Error("Unable to check user write access:", slog.Any("error", err)) return nil, err @@ -492,7 +494,7 @@ func (c *gitHTTPComponentImpl) CreateLock(ctx context.Context, req types.LfsLock } func (c *gitHTTPComponentImpl) ListLocks(ctx context.Context, req types.ListLFSLockReq) (*types.LFSLockList, error) { - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } @@ -502,7 +504,7 @@ func (c *gitHTTPComponentImpl) ListLocks(ctx context.Context, req types.ListLFSL return nil, ErrUnauthorized } - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { slog.Error("Unable to check user write access:", slog.Any("error", err)) return nil, err @@ -557,7 +559,7 @@ func (c *gitHTTPComponentImpl) UnLock(ctx context.Context, req types.UnlockLFSRe lock *database.LfsLock err error ) - _, err = c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + _, err = c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } @@ -567,7 +569,7 @@ func (c *gitHTTPComponentImpl) UnLock(ctx context.Context, req types.UnlockLFSRe return nil, ErrUnauthorized } - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { slog.Error("Unable to check user write access:", slog.Any("error", err)) return nil, err @@ -602,7 +604,7 @@ func (c *gitHTTPComponentImpl) VerifyLock(ctx context.Context, req types.VerifyL theirLocks []*types.LFSLock res types.LFSLockListVerify ) - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } @@ -612,7 +614,7 @@ func (c *gitHTTPComponentImpl) VerifyLock(ctx context.Context, req types.VerifyL return nil, ErrUnauthorized } - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { slog.Error("Unable to check user write access:", slog.Any("error", err)) return nil, err @@ -656,11 +658,11 @@ func (c *gitHTTPComponentImpl) VerifyLock(ctx context.Context, req types.VerifyL func (c *gitHTTPComponentImpl) LfsDownload(ctx context.Context, req types.DownloadRequest) (*url.URL, error) { pointer := types.Pointer{Oid: req.Oid} - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { return nil, fmt.Errorf("failed to check allowed, error: %w", err) } diff --git a/component/git_http_test.go b/component/git_http_test.go new file mode 100644 index 00000000..88d24547 --- /dev/null +++ b/component/git_http_test.go @@ -0,0 +1,492 @@ +package component + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io" + "net/url" + "testing" + + "github.com/minio/minio-go/v7" + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestGitHTTPComponent_InfoRefs(t *testing.T) { + + cases := []struct { + rpc string + private bool + }{ + {"foo", true}, + {"git-receive-pack", false}, + {"foo", false}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + Private: c.private, + }, nil) + if c.rpc == "git-receive-pack" { + gc.mocks.components.repo.EXPECT().AllowWriteAccess(ctx, types.ModelRepo, "ns", "n", "user").Return(true, nil) + } + if c.private { + gc.mocks.components.repo.EXPECT().AllowReadAccess(ctx, types.ModelRepo, "ns", "n", "user").Return(true, nil) + } + + gc.mocks.gitServer.EXPECT().InfoRefsResponse(ctx, gitserver.InfoRefsReq{ + Namespace: "ns", + Name: "n", + Rpc: c.rpc, + RepoType: types.ModelRepo, + GitProtocol: "", + }).Return(nil, nil) + + r, err := gc.InfoRefs(ctx, types.InfoRefsReq{ + Namespace: "ns", + Name: "n", + Rpc: c.rpc, + RepoType: types.ModelRepo, + CurrentUser: "user", + }) + require.Nil(t, err) + require.Equal(t, nil, r) + + }) + } + +} + +func TestGitHTTPComponent_GitUploadPack(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + Private: true, + }, nil) + gc.mocks.components.repo.EXPECT().AllowReadAccess(ctx, types.ModelRepo, "ns", "n", "user").Return(true, nil) + gc.mocks.gitServer.EXPECT().UploadPack(ctx, gitserver.UploadPackReq{ + Namespace: "ns", + Name: "n", + Request: nil, + RepoType: types.ModelRepo, + }).Return(nil) + err := gc.GitUploadPack(ctx, types.GitUploadPackReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + }) + require.Nil(t, err) + +} + +func TestGitHTTPComponent_GitReceivePack(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + Private: true, + }, nil) + gc.mocks.components.repo.EXPECT().AllowWriteAccess(ctx, types.ModelRepo, "ns", "n", "user").Return(true, nil) + gc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{}, nil) + gc.mocks.gitServer.EXPECT().ReceivePack(ctx, gitserver.UploadPackReq{ + Namespace: "ns", + Name: "n", + Request: nil, + RepoType: types.ModelRepo, + }).Return(nil) + err := gc.GitReceivePack(ctx, types.GitUploadPackReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + }) + require.Nil(t, err) + +} + +func TestGitHTTPComponent_BuildObjectResponse(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + oid1 := "a3f8e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e" + oid2 := "c39e7f5f1d61fa58ec6dbcd3b60a50870c577f0988d7c080fc88d1b460e7f5f1" + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + gc.mocks.s3Client.EXPECT().StatObject( + ctx, "", + "lfs/a3/f8/e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e", + minio.StatObjectOptions{}, + ).Return( + minio.ObjectInfo{Size: 100}, nil, + ) + gc.mocks.s3Client.EXPECT().StatObject( + ctx, "", + "lfs/c3/9e/7f5f1d61fa58ec6dbcd3b60a50870c577f0988d7c080fc88d1b460e7f5f1", + minio.StatObjectOptions{}, + ).Return( + minio.ObjectInfo{Size: 100}, nil, + ) + gc.mocks.stores.LfsMetaObjectMock().EXPECT().FindByOID(ctx, int64(123), oid1).Return( + &database.LfsMetaObject{}, nil, + ) + gc.mocks.stores.LfsMetaObjectMock().EXPECT().FindByOID(ctx, int64(123), oid2).Return( + nil, nil, + ) + gc.mocks.components.repo.EXPECT().AllowWriteAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + gc.mocks.stores.LfsMetaObjectMock().EXPECT().Create(ctx, database.LfsMetaObject{ + Oid: oid2, + Size: 100, + RepositoryID: 123, + Existing: true, + }).Return(nil, nil) + + resp, err := gc.BuildObjectResponse(ctx, types.BatchRequest{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + Objects: []types.Pointer{ + { + Oid: oid1, + Size: 5, + }, + { + Oid: oid2, + Size: 100, + }, + }, + }, true) + require.Nil(t, err) + require.Equal(t, &types.BatchResponse{ + Objects: []*types.ObjectResponse{ + { + Pointer: types.Pointer{Oid: oid1, Size: 5}, + Error: &types.ObjectError{ + Code: 422, + Message: "Object a3f8e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e is not 5 bytes", + }, + Actions: nil, + }, + { + Pointer: types.Pointer{Oid: oid2, Size: 100}, + Actions: map[string]*types.Link{}, + }, + }, + }, resp) + +} + +func TestGitHTTPComponent_LfsUpload(t *testing.T) { + + for _, exist := range []bool{false, true} { + t.Run(fmt.Sprintf("exist %v", exist), func(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + rc := io.NopCloser(&io.LimitedReader{}) + oid := "a3f8e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e" + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + if exist { + gc.mocks.s3Client.EXPECT().StatObject( + ctx, "", + "lfs/a3/f8/e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e", + minio.StatObjectOptions{}, + ).Return( + minio.ObjectInfo{Size: 100}, nil, + ) + } else { + gc.mocks.s3Client.EXPECT().StatObject( + ctx, "", + "lfs/a3/f8/e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e", + minio.StatObjectOptions{}, + ).Return( + minio.ObjectInfo{Size: 100}, errors.New("zzzz"), + ) + } + + if exist { + gc.mocks.components.repo.EXPECT().AllowWriteAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + } else { + gc.mocks.s3Client.EXPECT().PutObject( + ctx, "", + "lfs/a3/f8/e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e", + rc, int64(100), minio.PutObjectOptions{ + ContentType: "application/octet-stream", + SendContentMd5: true, + ConcurrentStreamParts: true, + NumThreads: 5, + }).Return(minio.UploadInfo{Size: 100}, nil) + } + gc.mocks.stores.LfsMetaObjectMock().EXPECT().Create(ctx, database.LfsMetaObject{ + Oid: oid, + Size: 100, + RepositoryID: 123, + Existing: true, + }).Return(nil, nil) + + err := gc.LfsUpload(ctx, rc, types.UploadRequest{ + Oid: oid, + Size: 100, + CurrentUser: "user", + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + }) + require.Nil(t, err) + }) + } + +} + +func TestGitHTTPComponent_LfsVerify(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + + gc.mocks.s3Client.EXPECT().StatObject(ctx, "", "lfs/oid", minio.StatObjectOptions{}).Return( + minio.ObjectInfo{Size: 100}, nil, + ) + gc.mocks.stores.LfsMetaObjectMock().EXPECT().Create(ctx, database.LfsMetaObject{ + Oid: "oid", + Size: 100, + RepositoryID: 123, + Existing: true, + }).Return(nil, nil) + + err := gc.LfsVerify(ctx, types.VerifyRequest{ + CurrentUser: "user", + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + }, types.Pointer{Oid: "oid", Size: 100}) + require.Nil(t, err) + +} + +func TestGitHTTPComponent_CreateLock(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + + gc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{}, nil) + gc.mocks.components.repo.EXPECT().AllowWriteAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + lfslock := &database.LfsLock{Path: "path", RepositoryID: 123} + gc.mocks.stores.LfsLockMock().EXPECT().FindByPath(ctx, int64(123), "path").Return( + lfslock, sql.ErrNoRows, + ) + gc.mocks.stores.LfsLockMock().EXPECT().Create(ctx, *lfslock).Return(lfslock, nil) + + l, err := gc.CreateLock(ctx, types.LfsLockReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + Path: "path", + }) + require.Nil(t, err) + require.Equal(t, lfslock, l) + +} + +func TestGitHTTPComponent_ListLocks(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + + gc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{}, nil) + gc.mocks.components.repo.EXPECT().AllowReadAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + + gc.mocks.stores.LfsLockMock().EXPECT().FindByID(ctx, int64(123)).Return( + &database.LfsLock{ID: 11, RepositoryID: 123}, nil, + ) + gc.mocks.stores.LfsLockMock().EXPECT().FindByPath(ctx, int64(123), "foo/bar").Return( + &database.LfsLock{ID: 12, RepositoryID: 123}, nil, + ) + gc.mocks.stores.LfsLockMock().EXPECT().FindByRepoID(ctx, int64(123), 1, 10).Return( + []database.LfsLock{{ID: 13, RepositoryID: 123}}, nil, + ) + + req := types.ListLFSLockReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + Cursor: 1, + Limit: 10, + } + req1 := req + req1.ID = 123 + ll, err := gc.ListLocks(ctx, req1) + require.Nil(t, err) + require.Equal(t, &types.LFSLockList{ + Locks: []*types.LFSLock{{ID: "11", Owner: &types.LFSLockOwner{}}}, + }, ll) + req2 := req + req2.Path = "foo/bar" + ll, err = gc.ListLocks(ctx, req2) + require.Nil(t, err) + require.Equal(t, &types.LFSLockList{ + Locks: []*types.LFSLock{{ID: "12", Owner: &types.LFSLockOwner{}}}, + }, ll) + ll, err = gc.ListLocks(ctx, req) + require.Nil(t, err) + require.Equal(t, &types.LFSLockList{ + Locks: []*types.LFSLock{{ID: "13", Owner: &types.LFSLockOwner{}}}, + }, ll) +} + +func TestGitHTTPComponent_UnLock(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + + gc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{}, nil) + gc.mocks.components.repo.EXPECT().AllowWriteAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + + gc.mocks.stores.LfsLockMock().EXPECT().FindByID(ctx, int64(123)).Return( + &database.LfsLock{ID: 11, RepositoryID: 123}, nil, + ) + gc.mocks.stores.LfsLockMock().EXPECT().RemoveByID(ctx, int64(123)).Return(nil) + + lk, err := gc.UnLock(ctx, types.UnlockLFSReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + ID: 123, + }) + require.Nil(t, err) + require.Equal(t, &database.LfsLock{ + ID: 11, + RepositoryID: 123, + }, lk) + +} + +func TestGitHTTPComponent_VerifyLock(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + + gc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{}, nil) + gc.mocks.components.repo.EXPECT().AllowReadAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + + gc.mocks.stores.LfsLockMock().EXPECT().FindByRepoID(ctx, int64(123), 10, 1).Return( + []database.LfsLock{{ID: 11, RepositoryID: 123, User: database.User{Username: "zzz"}}}, nil, + ) + + lk, err := gc.VerifyLock(ctx, types.VerifyLFSLockReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + Cursor: 10, + Limit: 1, + }) + require.Nil(t, err) + require.Equal(t, &types.LFSLockListVerify{ + Ours: []*types.LFSLock{{ID: "11", Owner: &types.LFSLockOwner{Name: "zzz"}}}, + Next: "11", + }, lk) + +} + +func TestGitHTTPComponent_LfsDownload(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + + gc.mocks.components.repo.EXPECT().AllowReadAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + gc.mocks.stores.LfsMetaObjectMock().EXPECT().FindByOID(ctx, int64(123), "oid").Return(nil, nil) + reqParams := make(url.Values) + reqParams.Set("response-content-disposition", fmt.Sprintf("attachment;filename=%s", "sa")) + url := &url.URL{Scheme: "http"} + gc.mocks.s3Client.EXPECT().PresignedGetObject(ctx, "", "lfs/oid", ossFileExpireSeconds, reqParams).Return(url, nil) + + u, err := gc.LfsDownload(ctx, types.DownloadRequest{ + Oid: "oid", + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + SaveAs: "sa", + CurrentUser: "user", + }) + require.Nil(t, err) + require.Equal(t, url, u) + +} diff --git a/component/internal.go b/component/internal.go index d23f28ec..07af29ca 100644 --- a/component/internal.go +++ b/component/internal.go @@ -8,6 +8,7 @@ import ( "strconv" pb "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb" + "opencsg.com/csghub-server/builder/git" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/gitserver/gitaly" "opencsg.com/csghub-server/builder/store/database" @@ -17,10 +18,13 @@ import ( ) type internalComponentImpl struct { - config *config.Config - sshKeyStore database.SSHKeyStore - repoStore database.RepoStore - *repoComponentImpl + config *config.Config + sshKeyStore database.SSHKeyStore + repoStore database.RepoStore + tokenStore database.AccessTokenStore + namespaceStore database.NamespaceStore + repoComponent RepoComponent + gitServer gitserver.GitServer } type InternalComponent interface { @@ -37,11 +41,17 @@ func NewInternalComponent(config *config.Config) (InternalComponent, error) { c.config = config c.sshKeyStore = database.NewSSHKeyStore() c.repoStore = database.NewRepoStore() - c.repoComponentImpl, err = NewRepoComponentImpl(config) + c.repoComponent, err = NewRepoComponentImpl(config) c.tokenStore = database.NewAccessTokenStore() + c.namespaceStore = database.NewNamespaceStore() if err != nil { return nil, err } + git, err := git.NewGitServer(config) + if err != nil { + return nil, fmt.Errorf("failed to create git server: %w", err) + } + c.gitServer = git return c, nil } @@ -70,7 +80,7 @@ func (c *internalComponentImpl) SSHAllowed(ctx context.Context, req types.SSHAll return nil, fmt.Errorf("failed to find ssh key by id, err: %v", err) } if req.Action == "git-receive-pack" { - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, sshKey.User.Username) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, sshKey.User.Username) if err != nil { return nil, ErrUnauthorized } @@ -79,7 +89,7 @@ func (c *internalComponentImpl) SSHAllowed(ctx context.Context, req types.SSHAll } } else if req.Action == "git-upload-pack" { if repo.Private { - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, sshKey.User.Username) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, sshKey.User.Username) if err != nil { return nil, ErrUnauthorized } @@ -133,7 +143,7 @@ func (c *internalComponentImpl) GetCommitDiff(ctx context.Context, req types.Get if repo == nil { return nil, errors.New("repo not found") } - diffs, err := c.git.GetDiffBetweenTwoCommits(ctx, gitserver.GetDiffBetweenTwoCommitsReq{ + diffs, err := c.gitServer.GetDiffBetweenTwoCommits(ctx, gitserver.GetDiffBetweenTwoCommitsReq{ Namespace: req.Namespace, Name: req.Name, RepoType: req.RepoType, @@ -165,7 +175,7 @@ func (c *internalComponentImpl) LfsAuthenticate(ctx context.Context, req types.L return nil, fmt.Errorf("failed to find ssh key by id, err: %v", err) } if repo.Private { - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, sshKey.User.Username) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, sshKey.User.Username) if err != nil { return nil, ErrUnauthorized } diff --git a/component/internal_test.go b/component/internal_test.go new file mode 100644 index 00000000..4935fac9 --- /dev/null +++ b/component/internal_test.go @@ -0,0 +1,153 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + pb "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestInternalComponent_Allowed(t *testing.T) { + ctx := context.TODO() + ic := initializeTestInternalComponent(ctx, t) + + allowed, err := ic.Allowed(ctx) + require.Nil(t, err) + require.True(t, allowed) +} + +func TestInternalComponent_SSHAllowed(t *testing.T) { + ctx := context.TODO() + ic := initializeTestInternalComponent(ctx, t) + + ic.mocks.stores.NamespaceMock().EXPECT().FindByPath(ctx, "ns").Return(database.Namespace{ + ID: 321, + }, nil) + ic.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "ns", "n").Return( + &database.Repository{ID: 123, Private: true}, nil, + ) + ic.mocks.stores.SSHMock().EXPECT().FindByID(ctx, int64(1)).Return(&database.SSHKey{ + ID: 111, + User: &database.User{ID: 11, Username: "user"}, + }, nil) + ic.mocks.components.repo.EXPECT().AllowWriteAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + ic.mocks.components.repo.EXPECT().AllowReadAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + + req := types.SSHAllowedReq{ + RepoType: types.ModelRepo, + Namespace: "ns", + Name: "n", + KeyID: "1", + Action: "git-receive-pack", + } + resp, err := ic.SSHAllowed(ctx, req) + require.Nil(t, err) + expected := &types.SSHAllowedResp{ + Success: true, + Message: "allowed", + Repo: req.Repo, + UserID: "11", + KeyType: "ssh", + KeyID: 111, + ProjectID: 123, + RootNamespaceID: 321, + GitConfigOptions: []string{"uploadpack.allowFilter=true", "uploadpack.allowAnySHA1InWant=true"}, + Gitaly: types.Gitaly{ + Repo: pb.Repository{ + RelativePath: "models_ns/n.git", + GlRepository: "models/ns/n", + }, + }, + StatusCode: 200, + } + + require.Equal(t, expected, resp) + + req.Action = "git-upload-pack" + resp, err = ic.SSHAllowed(ctx, req) + require.Nil(t, err) + require.Equal(t, expected, resp) + +} + +func TestInternalComponent_GetAuthorizedKeys(t *testing.T) { + ctx := context.TODO() + ic := initializeTestInternalComponent(ctx, t) + + ic.mocks.stores.SSHMock().EXPECT().FindByFingerpringSHA256( + ctx, "dUQ5GwtKsCPC8Scv1OLnOEvIW0QWULVSWyj5bZwQHwM", + ).Return(&database.SSHKey{}, nil) + key, err := ic.GetAuthorizedKeys(ctx, "foobar") + require.Nil(t, err) + require.Equal(t, &database.SSHKey{}, key) +} + +func TestInternalComponent_GetCommitDiff(t *testing.T) { + ctx := context.TODO() + ic := initializeTestInternalComponent(ctx, t) + + req := types.GetDiffBetweenTwoCommitsReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + Ref: "main", + LeftCommitId: "l", + RightCommitId: "r", + } + ic.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "ns", "n").Return( + &database.Repository{}, nil, + ) + ic.mocks.gitServer.EXPECT().GetDiffBetweenTwoCommits(ctx, gitserver.GetDiffBetweenTwoCommitsReq{ + Namespace: req.Namespace, + Name: req.Name, + RepoType: req.RepoType, + Ref: req.Ref, + LeftCommitId: req.LeftCommitId, + RightCommitId: req.RightCommitId, + }).Return(&types.GiteaCallbackPushReq{Ref: "main"}, nil) + + resp, err := ic.GetCommitDiff(ctx, req) + require.Nil(t, err) + require.Equal(t, &types.GiteaCallbackPushReq{Ref: "main"}, resp) +} + +func TestInternalComponent_LfsAuthenticate(t *testing.T) { + ctx := context.TODO() + ic := initializeTestInternalComponent(ctx, t) + + ic.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "ns", "n").Return( + &database.Repository{Private: true}, nil, + ) + ic.mocks.stores.SSHMock().EXPECT().FindByID(ctx, int64(1)).Return(&database.SSHKey{ + ID: 111, + User: &database.User{ID: 11, Username: "user"}, + }, nil) + ic.mocks.components.repo.EXPECT().AllowReadAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + ic.mocks.stores.AccessTokenMock().EXPECT().GetUserGitToken(ctx, "user").Return( + &database.AccessToken{Token: "token"}, nil, + ) + + resp, err := ic.LfsAuthenticate(ctx, types.LfsAuthenticateReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + KeyID: "1", + }) + require.Nil(t, err) + require.Equal(t, &types.LfsAuthenticateResp{ + Username: "user", + LfsToken: "token", + RepoPath: "/models/ns/n.git", + }, resp) + +} diff --git a/component/mirror_source.go b/component/mirror_source.go index 136c16a0..b136e29d 100644 --- a/component/mirror_source.go +++ b/component/mirror_source.go @@ -11,8 +11,8 @@ import ( ) type mirrorSourceComponentImpl struct { - msStore database.MirrorSourceStore - userStore database.UserStore + mirrorSourceStore database.MirrorSourceStore + userStore database.UserStore } type MirrorSourceComponent interface { @@ -25,8 +25,8 @@ type MirrorSourceComponent interface { func NewMirrorSourceComponent(config *config.Config) (MirrorSourceComponent, error) { return &mirrorSourceComponentImpl{ - msStore: database.NewMirrorSourceStore(), - userStore: database.NewUserStore(), + mirrorSourceStore: database.NewMirrorSourceStore(), + userStore: database.NewUserStore(), }, nil } @@ -41,7 +41,7 @@ func (c *mirrorSourceComponentImpl) Create(ctx context.Context, req types.Create } ms.SourceName = req.SourceName ms.InfoAPIUrl = req.InfoAPiUrl - res, err := c.msStore.Create(ctx, &ms) + res, err := c.mirrorSourceStore.Create(ctx, &ms) if err != nil { return nil, fmt.Errorf("failed to create mirror source, error: %w", err) } @@ -56,7 +56,7 @@ func (c *mirrorSourceComponentImpl) Get(ctx context.Context, id int64, currentUs if !user.CanAdmin() { return nil, errors.New("user does not have admin permission") } - ms, err := c.msStore.Get(ctx, id) + ms, err := c.mirrorSourceStore.Get(ctx, id) if err != nil { return nil, fmt.Errorf("failed to get mirror source, error: %w", err) } @@ -71,7 +71,7 @@ func (c *mirrorSourceComponentImpl) Index(ctx context.Context, currentUser strin if !user.CanAdmin() { return nil, errors.New("user does not have admin permission") } - ms, err := c.msStore.Index(ctx) + ms, err := c.mirrorSourceStore.Index(ctx) if err != nil { return nil, fmt.Errorf("failed to get mirror source, error: %w", err) } @@ -89,7 +89,7 @@ func (c *mirrorSourceComponentImpl) Update(ctx context.Context, req types.Update ms.ID = req.ID ms.SourceName = req.SourceName ms.InfoAPIUrl = req.InfoAPiUrl - err = c.msStore.Update(ctx, &ms) + err = c.mirrorSourceStore.Update(ctx, &ms) if err != nil { return nil, fmt.Errorf("failed to update mirror source, error: %w", err) } @@ -104,11 +104,11 @@ func (c *mirrorSourceComponentImpl) Delete(ctx context.Context, id int64, curren if !user.CanAdmin() { return errors.New("user does not have admin permission") } - ms, err := c.msStore.Get(ctx, id) + ms, err := c.mirrorSourceStore.Get(ctx, id) if err != nil { return fmt.Errorf("failed to find mirror source, error: %w", err) } - err = c.msStore.Delete(ctx, ms) + err = c.mirrorSourceStore.Delete(ctx, ms) if err != nil { return fmt.Errorf("failed to delete mirror source, error: %w", err) } diff --git a/component/mirror_source_test.go b/component/mirror_source_test.go new file mode 100644 index 00000000..ccc3ec02 --- /dev/null +++ b/component/mirror_source_test.go @@ -0,0 +1,104 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestMirrorSourceComponent_Create(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorSourceComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.MirrorSourceMock().EXPECT().Create(ctx, &database.MirrorSource{ + SourceName: "sn", + InfoAPIUrl: "url", + }).Return(&database.MirrorSource{ID: 1}, nil) + + data, err := mc.Create(ctx, types.CreateMirrorSourceReq{ + SourceName: "sn", + InfoAPiUrl: "url", + CurrentUser: "user", + }) + require.Nil(t, err) + require.Equal(t, &database.MirrorSource{ID: 1}, data) +} + +func TestMirrorSourceComponent_Get(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorSourceComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.MirrorSourceMock().EXPECT().Get(ctx, int64(1)).Return(&database.MirrorSource{ID: 1}, nil) + + data, err := mc.Get(ctx, 1, "user") + require.Nil(t, err) + require.Equal(t, &database.MirrorSource{ID: 1}, data) +} + +func TestMirrorSourceComponent_Index(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorSourceComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.MirrorSourceMock().EXPECT().Index(ctx).Return([]database.MirrorSource{ + {ID: 1}, + }, nil) + + data, err := mc.Index(ctx, "user") + require.Nil(t, err) + require.Equal(t, []database.MirrorSource{ + {ID: 1}, + }, data) +} + +func TestMirrorSourceComponent_Update(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorSourceComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.MirrorSourceMock().EXPECT().Update(ctx, &database.MirrorSource{ + ID: 1, + SourceName: "sn", + InfoAPIUrl: "url", + }).Return(nil) + + data, err := mc.Update(ctx, types.UpdateMirrorSourceReq{ + ID: 1, + SourceName: "sn", + InfoAPiUrl: "url", + CurrentUser: "user", + }) + require.Nil(t, err) + require.Equal(t, &database.MirrorSource{ + ID: 1, + SourceName: "sn", + InfoAPIUrl: "url", + }, data) +} + +func TestMirrorSourceComponent_Delete(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorSourceComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.MirrorSourceMock().EXPECT().Get(ctx, int64(1)).Return(&database.MirrorSource{ID: 1}, nil) + mc.mocks.stores.MirrorSourceMock().EXPECT().Delete(ctx, &database.MirrorSource{ID: 1}).Return(nil) + + err := mc.Delete(ctx, 1, "user") + require.Nil(t, err) +} diff --git a/component/model_test.go b/component/model_test.go index ea557131..3f49f915 100644 --- a/component/model_test.go +++ b/component/model_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/membership" - "opencsg.com/csghub-server/builder/inference" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/types" ) @@ -458,31 +457,22 @@ func TestModelComponent_DeleteRelationDataset(t *testing.T) { require.Nil(t, err) } -func TestModelComponent_Predict(t *testing.T) { - ctx := context.TODO() - mc := initializeTestModelComponent(ctx, t) - - mc.mocks.inferenceClient.EXPECT().Predict(inference.ModelID{ - Owner: "ns", - Name: "n", - }, &inference.PredictRequest{ - Prompt: "foo", - }).Return(&inference.PredictResponse{ - GeneratedText: "abcd", - }, nil) +// func TestModelComponent_Predict(t *testing.T) { +// ctx := context.TODO() +// mc := initializeTestModelComponent(ctx, t) - resp, err := mc.Predict(ctx, &types.ModelPredictReq{ - Namespace: "ns", - Name: "n", - Input: "foo", - CurrentUser: "user", - }) - require.Nil(t, err) - require.Equal(t, &types.ModelPredictResp{ - Content: "abcd", - }, resp) +// resp, err := mc.Predict(ctx, &types.ModelPredictReq{ +// Namespace: "ns", +// Name: "n", +// Input: "foo", +// CurrentUser: "user", +// }) +// require.Nil(t, err) +// require.Equal(t, &types.ModelPredictResp{ +// Content: "abcd", +// }, resp) -} +// } // func TestModelComponent_Deploy(t *testing.T) { // ctx := context.TODO() diff --git a/component/multi_sync.go b/component/multi_sync.go index edb44597..ddd08285 100644 --- a/component/multi_sync.go +++ b/component/multi_sync.go @@ -20,16 +20,16 @@ import ( ) type multiSyncComponentImpl struct { - s database.MultiSyncStore - repo database.RepoStore - model database.ModelStore - dataset database.DatasetStore - namespace database.NamespaceStore - user database.UserStore - versionStore database.SyncVersionStore - tag database.TagStore - file database.FileStore - git gitserver.GitServer + multiSyncStore database.MultiSyncStore + repoStore database.RepoStore + modelStore database.ModelStore + datasetStore database.DatasetStore + namespaceStore database.NamespaceStore + userStore database.UserStore + syncVersionStore database.SyncVersionStore + tagStore database.TagStore + fileStore database.FileStore + gitServer gitserver.GitServer } type MultiSyncComponent interface { @@ -43,21 +43,21 @@ func NewMultiSyncComponent(config *config.Config) (MultiSyncComponent, error) { return nil, fmt.Errorf("failed to create git server: %w", err) } return &multiSyncComponentImpl{ - s: database.NewMultiSyncStore(), - repo: database.NewRepoStore(), - model: database.NewModelStore(), - dataset: database.NewDatasetStore(), - namespace: database.NewNamespaceStore(), - user: database.NewUserStore(), - versionStore: database.NewSyncVersionStore(), - tag: database.NewTagStore(), - file: database.NewFileStore(), - git: git, + multiSyncStore: database.NewMultiSyncStore(), + repoStore: database.NewRepoStore(), + modelStore: database.NewModelStore(), + datasetStore: database.NewDatasetStore(), + namespaceStore: database.NewNamespaceStore(), + userStore: database.NewUserStore(), + syncVersionStore: database.NewSyncVersionStore(), + tagStore: database.NewTagStore(), + fileStore: database.NewFileStore(), + gitServer: git, }, nil } func (c *multiSyncComponentImpl) More(ctx context.Context, cur int64, limit int64) ([]types.SyncVersion, error) { - dbVersions, err := c.s.GetAfter(ctx, cur, limit) + dbVersions, err := c.multiSyncStore.GetAfter(ctx, cur, limit) if err != nil { return nil, fmt.Errorf("failed to get sync versions after %d from db: %w", cur, err) } @@ -77,7 +77,7 @@ func (c *multiSyncComponentImpl) More(ctx context.Context, cur int64, limit int6 func (c *multiSyncComponentImpl) SyncAsClient(ctx context.Context, sc multisync.Client) error { var currentVersion int64 - v, err := c.s.GetLatest(ctx) + v, err := c.multiSyncStore.GetLatest(ctx) if err != nil { if err != sql.ErrNoRows { return fmt.Errorf("failed to get latest sync version from db: %w", err) @@ -108,7 +108,7 @@ func (c *multiSyncComponentImpl) SyncAsClient(ctx context.Context, sc multisync. } } - syncVersions, err := c.s.GetAfterDistinct(ctx, v.Version) + syncVersions, err := c.multiSyncStore.GetAfterDistinct(ctx, v.Version) if err != nil { slog.Error("failed to find distinct sync versions", slog.Any("error", err)) return err @@ -214,7 +214,7 @@ func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *type // HTTPCloneURL: gitRepo.HttpCloneURL, // SSHCloneURL: gitRepo.SshCloneURL, } - newDBRepo, err := c.repo.UpdateOrCreateRepo(ctx, dbRepo) + newDBRepo, err := c.repoStore.UpdateOrCreateRepo(ctx, dbRepo) if err != nil { return fmt.Errorf("fail to create database repo, error: %w", err) } @@ -230,7 +230,7 @@ func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *type ShowName: tag.ShowName, Scope: database.DatasetTagScope, } - t, err := c.tag.FindOrCreate(ctx, dbTag) + t, err := c.tagStore.FindOrCreate(ctx, dbTag) if err != nil { slog.Error("failed to create or find database tag", slog.Any("tag", dbTag)) continue @@ -241,18 +241,18 @@ func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *type }) } - err = c.repo.DeleteAllTags(ctx, newDBRepo.ID) + err = c.repoStore.DeleteAllTags(ctx, newDBRepo.ID) if err != nil && err != sql.ErrNoRows { slog.Error("failed to delete database tag", slog.Any("error", err)) } - err = c.repo.BatchCreateRepoTags(ctx, repoTags) + err = c.repoStore.BatchCreateRepoTags(ctx, repoTags) if err != nil { slog.Error("failed to create database tag", slog.Any("error", err)) } } - err = c.repo.DeleteAllFiles(ctx, newDBRepo.ID) + err = c.repoStore.DeleteAllFiles(ctx, newDBRepo.ID) if err != nil && err != sql.ErrNoRows { slog.Error("failed to delete database files", slog.Any("error", err)) } @@ -277,7 +277,7 @@ func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *type }) } - err = c.file.BatchCreate(ctx, dbFiles) + err = c.fileStore.BatchCreate(ctx, dbFiles) if err != nil { slog.Error("failed to create all files of repo", slog.Any("sync_version", s)) } @@ -288,7 +288,7 @@ func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *type Repository: newDBRepo, RepositoryID: newDBRepo.ID, } - _, err = c.dataset.CreateIfNotExist(ctx, dbDataset) + _, err = c.datasetStore.CreateIfNotExist(ctx, dbDataset) if err != nil { return fmt.Errorf("failed to create dataset in db, cause: %w", err) } @@ -340,7 +340,7 @@ func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types. // HTTPCloneURL: gitRepo.HttpCloneURL, // SSHCloneURL: gitRepo.SshCloneURL, } - newDBRepo, err := c.repo.UpdateOrCreateRepo(ctx, dbRepo) + newDBRepo, err := c.repoStore.UpdateOrCreateRepo(ctx, dbRepo) if err != nil { return fmt.Errorf("fail to create database repo, error: %w", err) } @@ -356,7 +356,7 @@ func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types. ShowName: tag.ShowName, Scope: database.ModelTagScope, } - t, err := c.tag.FindOrCreate(ctx, dbTag) + t, err := c.tagStore.FindOrCreate(ctx, dbTag) if err != nil { slog.Error("failed to create or find database tag", slog.Any("tag", dbTag)) continue @@ -366,17 +366,17 @@ func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types. TagID: t.ID, }) } - err = c.repo.DeleteAllTags(ctx, newDBRepo.ID) + err = c.repoStore.DeleteAllTags(ctx, newDBRepo.ID) if err != nil && err != sql.ErrNoRows { slog.Error("failed to delete database tag", slog.Any("error", err)) } - err = c.repo.BatchCreateRepoTags(ctx, repoTags) + err = c.repoStore.BatchCreateRepoTags(ctx, repoTags) if err != nil { slog.Error("failed to batch create database tag", slog.Any("error", err)) } } - err = c.repo.DeleteAllFiles(ctx, newDBRepo.ID) + err = c.repoStore.DeleteAllFiles(ctx, newDBRepo.ID) if err != nil && err != sql.ErrNoRows { slog.Error("failed to delete all files for repo", slog.Any("error", err)) } @@ -401,7 +401,7 @@ func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types. }) } - err = c.file.BatchCreate(ctx, dbFiles) + err = c.fileStore.BatchCreate(ctx, dbFiles) if err != nil { slog.Error("failed to create all files of repo", slog.Any("sync_version", s)) } @@ -413,7 +413,7 @@ func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types. RepositoryID: newDBRepo.ID, BaseModel: m.BaseModel, } - _, err = c.model.CreateIfNotExist(ctx, dbModel) + _, err = c.modelStore.CreateIfNotExist(ctx, dbModel) if err != nil { return fmt.Errorf("failed to create database model, cause: %w", err) } @@ -426,7 +426,7 @@ func (c *multiSyncComponentImpl) createUser(ctx context.Context, req types.Creat Username: req.Username, Email: req.Email, } - gsUserResp, err := c.git.CreateUser(gsUserReq) + gsUserResp, err := c.gitServer.CreateUser(gsUserReq) if err != nil { newError := fmt.Errorf("failed to create gitserver user,error:%w", err) return database.User{}, newError @@ -443,7 +443,7 @@ func (c *multiSyncComponentImpl) createUser(ctx context.Context, req types.Creat GitID: gsUserResp.GitID, Password: gsUserResp.Password, } - err = c.user.Create(ctx, user, namespace) + err = c.userStore.Create(ctx, user, namespace) if err != nil { newError := fmt.Errorf("failed to create user,error:%w", err) return database.User{}, newError @@ -453,7 +453,7 @@ func (c *multiSyncComponentImpl) createUser(ctx context.Context, req types.Creat } func (c *multiSyncComponentImpl) getUser(ctx context.Context, userName string) (database.User, error) { - return c.user.FindByUsername(ctx, userName) + return c.userStore.FindByUsername(ctx, userName) } func (c *multiSyncComponentImpl) createLocalSyncVersion(ctx context.Context, v types.SyncVersion) error { @@ -465,7 +465,7 @@ func (c *multiSyncComponentImpl) createLocalSyncVersion(ctx context.Context, v t LastModifiedAt: v.LastModifyTime, ChangeLog: v.ChangeLog, } - err := c.versionStore.Create(ctx, &syncVersion) + err := c.syncVersionStore.Create(ctx, &syncVersion) if err != nil { return err } diff --git a/component/multi_sync_test.go b/component/multi_sync_test.go new file mode 100644 index 00000000..d0296aea --- /dev/null +++ b/component/multi_sync_test.go @@ -0,0 +1,174 @@ +package component + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + multisync_mock "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/multisync" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestMultiSyncComponent_More(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMultiSyncComponent(ctx, t) + + mc.mocks.stores.MultiSyncMock().EXPECT().GetAfter(ctx, int64(1), int64(10)).Return( + []database.SyncVersion{{Version: 2}}, nil, + ) + + data, err := mc.More(ctx, 1, 10) + require.Nil(t, err) + require.Equal(t, []types.SyncVersion{ + {Version: 2}, + }, data) +} + +func TestMultiSyncComponent_SyncAsClient(t *testing.T) { + ctx := mock.Anything + mc := initializeTestMultiSyncComponent(context.TODO(), t) + + mc.mocks.stores.MultiSyncMock().EXPECT().GetLatest(ctx).Return(database.SyncVersion{ + Version: 1, + }, nil) + mockedClient := multisync_mock.NewMockClient(t) + mockedClient.EXPECT().Latest(ctx, int64(1)).Return(types.SyncVersionResponse{ + Data: struct { + Versions []types.SyncVersion "json:\"versions\"" + HasMore bool "json:\"has_more\"" + }{ + Versions: []types.SyncVersion{ + {Version: 2}, + }, + HasMore: true, + }, + }, nil) + mockedClient.EXPECT().Latest(ctx, int64(2)).Return(types.SyncVersionResponse{ + Data: struct { + Versions []types.SyncVersion "json:\"versions\"" + HasMore bool "json:\"has_more\"" + }{ + Versions: []types.SyncVersion{ + {Version: 3}, + }, + HasMore: false, + }, + }, nil) + mc.mocks.stores.SyncVersionMock().EXPECT().Create(ctx, &database.SyncVersion{ + Version: 2, + }).Return(nil) + mc.mocks.stores.SyncVersionMock().EXPECT().Create(ctx, &database.SyncVersion{ + Version: 3, + }).Return(nil) + dsvs := []database.SyncVersion{ + {RepoType: types.ModelRepo}, + {RepoType: types.DatasetRepo}, + } + mc.mocks.stores.MultiSyncMock().EXPECT().GetAfterDistinct(ctx, int64(1)).Return( + dsvs, nil, + ) + svs := []types.SyncVersion{ + {RepoType: types.ModelRepo}, + {RepoType: types.DatasetRepo}, + } + // new model mock + mockedClient.EXPECT().ModelInfo(ctx, svs[0]).Return(&types.Model{ + User: &types.User{Nickname: "nn"}, + Path: "ns/user", + Tags: []types.RepoTag{{Name: "t1"}}, + }, nil) + mockedClient.EXPECT().ReadMeData(ctx, svs[0]).Return("readme", nil) + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "CSG_ns").Return(database.User{}, sql.ErrNoRows) + mc.mocks.gitServer.EXPECT().CreateUser(gitserver.CreateUserRequest{ + Nickname: "nn", + Username: "CSG_ns", + Email: "CSG_", + }).Return(&gitserver.CreateUserResponse{GitID: 123}, nil) + mc.mocks.stores.UserMock().EXPECT().Create(ctx, &database.User{ + NickName: "nn", + Username: "CSG_ns", + Email: "CSG_", + GitID: 123, + }, &database.Namespace{ + Path: "CSG_ns", + Mirrored: true, + }).Return(nil) + dbrepo := &database.Repository{ + Path: "CSG_ns/user", + GitPath: "models_CSG_ns/user", + Name: "user", + Readme: "readme", + Source: types.OpenCSGSource, + SyncStatus: types.SyncStatusPending, + RepositoryType: types.ModelRepo, + } + mc.mocks.stores.RepoMock().EXPECT().UpdateOrCreateRepo(ctx, *dbrepo).Return(dbrepo, nil) + dbrepo.ID = 1 + mc.mocks.stores.TagMock().EXPECT().FindOrCreate(ctx, database.Tag{ + Name: "t1", Scope: database.ModelTagScope, + }).Return( + &database.Tag{Name: "t1", ID: 11}, nil, + ) + mc.mocks.stores.RepoMock().EXPECT().DeleteAllTags(ctx, int64(1)).Return(nil) + mc.mocks.stores.RepoMock().EXPECT().BatchCreateRepoTags(ctx, []database.RepositoryTag{ + {RepositoryID: 1, TagID: 11}, + }).Return(nil) + mc.mocks.stores.RepoMock().EXPECT().DeleteAllFiles(ctx, int64(1)).Return(nil) + mockedClient.EXPECT().FileList(ctx, svs[0]).Return([]types.File{ + {Name: "foo.go"}, + }, nil) + mc.mocks.stores.FileMock().EXPECT().BatchCreate(ctx, []database.File{ + {Name: "foo.go", ParentPath: "/", RepositoryID: 1}, + }).Return(nil) + mc.mocks.stores.ModelMock().EXPECT().CreateIfNotExist(ctx, database.Model{ + RepositoryID: 1, + Repository: dbrepo, + }).Return(nil, nil) + + // new dataset mock + dbrepo = &database.Repository{ + Path: "CSG_ns/user", + GitPath: "datasets_CSG_ns/user", + Name: "user", + Readme: "readme", + Source: types.OpenCSGSource, + SyncStatus: types.SyncStatusPending, + RepositoryType: types.DatasetRepo, + } + mockedClient.EXPECT().DatasetInfo(ctx, svs[1]).Return(&types.Dataset{ + User: types.User{Nickname: "nn"}, + Path: "ns/user", + Tags: []types.RepoTag{{Name: "t2"}}, + }, nil) + mockedClient.EXPECT().ReadMeData(ctx, svs[1]).Return("readme", nil) + mc.mocks.stores.RepoMock().EXPECT().UpdateOrCreateRepo(ctx, *dbrepo).Return(dbrepo, nil) + dbrepo.ID = 2 + mc.mocks.stores.TagMock().EXPECT().FindOrCreate(ctx, database.Tag{ + Name: "t2", Scope: database.DatasetTagScope, + }).Return( + &database.Tag{Name: "t2", ID: 12}, nil, + ) + mc.mocks.stores.RepoMock().EXPECT().DeleteAllTags(ctx, int64(2)).Return(nil) + mc.mocks.stores.RepoMock().EXPECT().BatchCreateRepoTags(ctx, []database.RepositoryTag{ + {RepositoryID: 2, TagID: 12}, + }).Return(nil) + mc.mocks.stores.RepoMock().EXPECT().DeleteAllFiles(ctx, int64(2)).Return(nil) + mockedClient.EXPECT().FileList(ctx, svs[1]).Return([]types.File{ + {Name: "foo.go"}, + }, nil) + mc.mocks.stores.FileMock().EXPECT().BatchCreate(ctx, []database.File{ + {Name: "foo.go", ParentPath: "/", RepositoryID: 2}, + }).Return(nil) + mc.mocks.stores.DatasetMock().EXPECT().CreateIfNotExist(ctx, database.Dataset{ + RepositoryID: 2, + Repository: dbrepo, + }).Return(nil, nil) + + err := mc.SyncAsClient(context.TODO(), mockedClient) + require.Nil(t, err) + +} diff --git a/component/recom.go b/component/recom.go index c3661891..0e3183d8 100644 --- a/component/recom.go +++ b/component/recom.go @@ -14,9 +14,9 @@ import ( ) type recomComponentImpl struct { - rs database.RecomStore - repos database.RepoStore - gs gitserver.GitServer + recomStore database.RecomStore + repoStore database.RepoStore + gitServer gitserver.GitServer } type RecomComponent interface { @@ -33,18 +33,18 @@ func NewRecomComponent(cfg *config.Config) (RecomComponent, error) { } return &recomComponentImpl{ - rs: database.NewRecomStore(), - repos: database.NewRepoStore(), - gs: gs, + recomStore: database.NewRecomStore(), + repoStore: database.NewRepoStore(), + gitServer: gs, }, nil } func (rc *recomComponentImpl) SetOpWeight(ctx context.Context, repoID, weight int64) error { - _, err := rc.repos.FindById(ctx, repoID) + _, err := rc.repoStore.FindById(ctx, repoID) if err != nil { return fmt.Errorf("failed to find repository with id %d, err:%w", repoID, err) } - return rc.rs.UpsetOpWeights(ctx, repoID, weight) + return rc.recomStore.UpsetOpWeights(ctx, repoID, weight) } // loop through repositories and calculate the recom score of the repository @@ -54,7 +54,7 @@ func (rc *recomComponentImpl) CalculateRecomScore(ctx context.Context) { slog.Error("Error loading weights", "error", err) return } - repos, err := rc.repos.All(ctx) + repos, err := rc.repoStore.All(ctx) if err != nil { slog.Error("Error fetching repositories", "error", err) return @@ -62,7 +62,7 @@ func (rc *recomComponentImpl) CalculateRecomScore(ctx context.Context) { for _, repo := range repos { repoID := repo.ID score := rc.CalcTotalScore(ctx, repo, weights) - err := rc.rs.UpsertScore(ctx, repoID, score) + err := rc.recomStore.UpsertScore(ctx, repoID, score) if err != nil { slog.Error("Error updating recom score", slog.Int64("repo_id", repoID), slog.Float64("score", score), slog.String("error", err.Error())) @@ -132,7 +132,7 @@ func (rc *recomComponentImpl) calcQualityScore(ctx context.Context, repo *databa score := 0.0 // get file counts from git server namespace, name := repo.NamespaceAndName() - files, err := getFilePaths(namespace, name, "", repo.RepositoryType, "", rc.gs.GetRepoFileTree) + files, err := getFilePaths(namespace, name, "", repo.RepositoryType, "", rc.gitServer.GetRepoFileTree) if err != nil { return 0, fmt.Errorf("failed to get repo file tree,%w", err) } @@ -157,7 +157,7 @@ func (rc *recomComponentImpl) calcQualityScore(ctx context.Context, repo *databa func (rc *recomComponentImpl) loadWeights() (map[string]string, error) { ctx := context.Background() - items, err := rc.rs.LoadWeights(ctx) + items, err := rc.recomStore.LoadWeights(ctx) if err != nil { return nil, err } diff --git a/component/recom_test.go b/component/recom_test.go index 139e9eac..bba569e3 100644 --- a/component/recom_test.go +++ b/component/recom_test.go @@ -6,25 +6,48 @@ import ( "time" "github.com/stretchr/testify/mock" - gsmock "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/store/database" - "opencsg.com/csghub-server/common/tests" ) -func NewTestRecomComponent(stores *tests.MockStores, gitServer gitserver.GitServer) *recomComponentImpl { - return &recomComponentImpl{ - repos: stores.Repo, - gs: gitServer, - } +// func TestRecomComponent_SetOpWeight(t *testing.T) { +// ctx := context.TODO() +// rc := initializeTestRecomComponent(ctx, t) + +// rc.mocks.stores.RepoMock().EXPECT().FindById(ctx, int64(1)).Return(&database.Repository{}, nil) +// rc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ +// RoleMask: "admin", +// }, nil) +// rc.mocks.stores.RecomMock().EXPECT().UpsetOpWeights(ctx, int64(1), int64(100)).Return(nil) + +// err := rc.SetOpWeight(ctx, 1, 100) +// require.Nil(t, err) +// } + +func TestRecomComponent_CalculateRecomScore(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRecomComponent(ctx, t) + + rc.mocks.stores.RecomMock().EXPECT().LoadWeights(mock.Anything).Return( + []*database.RecomWeight{{Name: "freshness", WeightExp: "score = 12.34"}}, nil, + ) + rc.mocks.stores.RepoMock().EXPECT().All(ctx).Return([]*database.Repository{ + {ID: 1, Path: "foo/bar"}, + }, nil) + rc.mocks.stores.RecomMock().EXPECT().UpsertScore(ctx, int64(1), 12.34).Return(nil) + rc.mocks.gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Namespace: "foo", + Name: "bar", + }).Return(nil, nil) + + rc.CalculateRecomScore(ctx) } func TestRecomComponent_CalculateTotalScore(t *testing.T) { - gitServer := gsmock.NewMockGitServer(t) - rc := &recomComponentImpl{gs: gitServer} ctx := context.TODO() + rc := initializeTestRecomComponent(ctx, t) - gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ + rc.mocks.gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ Namespace: "foo", Name: "bar", }).Return(nil, nil) diff --git a/component/runtime_architecture.go b/component/runtime_architecture.go index cb79f139..b6f2f872 100644 --- a/component/runtime_architecture.go +++ b/component/runtime_architecture.go @@ -8,6 +8,7 @@ import ( "strings" "sync" + "opencsg.com/csghub-server/builder/git" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" @@ -22,11 +23,14 @@ var ( ) type runtimeArchitectureComponentImpl struct { - r *repoComponentImpl - ras database.RuntimeArchitecturesStore - rfs database.RuntimeFrameworksStore - ts database.TagStore - rms database.ResourceModelStore + repoComponent RepoComponent + repoStore database.RepoStore + repoRuntimeFrameworkStore database.RepositoriesRuntimeFrameworkStore + runtimeArchStore database.RuntimeArchitecturesStore + runtimeFrameworksStore database.RuntimeFrameworksStore + tagStore database.TagStore + resouceModelStore database.ResourceModelStore + gitServer gitserver.GitServer } type RuntimeArchitectureComponent interface { @@ -47,20 +51,28 @@ type RuntimeArchitectureComponent interface { func NewRuntimeArchitectureComponent(config *config.Config) (RuntimeArchitectureComponent, error) { c := &runtimeArchitectureComponentImpl{} - c.rfs = database.NewRuntimeFrameworksStore() - c.ras = database.NewRuntimeArchitecturesStore() - c.ts = database.NewTagStore() - c.rms = database.NewResourceModelStore() + c.runtimeFrameworksStore = database.NewRuntimeFrameworksStore() + c.runtimeArchStore = database.NewRuntimeArchitecturesStore() + c.tagStore = database.NewTagStore() + c.resouceModelStore = database.NewResourceModelStore() repo, err := NewRepoComponentImpl(config) if err != nil { return nil, fmt.Errorf("fail to create repo component, %w", err) } - c.r = repo + c.repoComponent = repo + c.repoStore = database.NewRepoStore() + c.repoRuntimeFrameworkStore = database.NewRepositoriesRuntimeFramework() + c.gitServer, err = git.NewGitServer(config) + if err != nil { + newError := fmt.Errorf("fail to create git server,error:%w", err) + slog.Error(newError.Error()) + return nil, newError + } return c, nil } func (c *runtimeArchitectureComponentImpl) ListByRuntimeFrameworkID(ctx context.Context, id int64) ([]database.RuntimeArchitecture, error) { - archs, err := c.ras.ListByRuntimeFrameworkID(ctx, id) + archs, err := c.runtimeArchStore.ListByRuntimeFrameworkID(ctx, id) if err != nil { return nil, fmt.Errorf("list runtime arch failed, %w", err) } @@ -68,7 +80,7 @@ func (c *runtimeArchitectureComponentImpl) ListByRuntimeFrameworkID(ctx context. } func (c *runtimeArchitectureComponentImpl) SetArchitectures(ctx context.Context, id int64, architectures []string) ([]string, error) { - _, err := c.r.runtimeFrameworksStore.FindByID(ctx, id) + _, err := c.runtimeFrameworksStore.FindByID(ctx, id) if err != nil { return nil, fmt.Errorf("invalid runtime framework id, %w", err) } @@ -77,7 +89,7 @@ func (c *runtimeArchitectureComponentImpl) SetArchitectures(ctx context.Context, if len(strings.Trim(arch, " ")) < 1 { continue } - err := c.ras.Add(ctx, database.RuntimeArchitecture{ + err := c.runtimeArchStore.Add(ctx, database.RuntimeArchitecture{ RuntimeFrameworkID: id, ArchitectureName: strings.Trim(arch, " "), }) @@ -89,7 +101,7 @@ func (c *runtimeArchitectureComponentImpl) SetArchitectures(ctx context.Context, } func (c *runtimeArchitectureComponentImpl) DeleteArchitectures(ctx context.Context, id int64, architectures []string) ([]string, error) { - _, err := c.r.runtimeFrameworksStore.FindByID(ctx, id) + _, err := c.runtimeFrameworksStore.FindByID(ctx, id) if err != nil { return nil, fmt.Errorf("invalid runtime framework id, %w", err) } @@ -98,7 +110,7 @@ func (c *runtimeArchitectureComponentImpl) DeleteArchitectures(ctx context.Conte if len(strings.Trim(arch, " ")) < 1 { continue } - err := c.ras.DeleteByRuntimeIDAndArchName(ctx, id, strings.Trim(arch, " ")) + err := c.runtimeArchStore.DeleteByRuntimeIDAndArchName(ctx, id, strings.Trim(arch, " ")) if err != nil { failedDeletes = append(failedDeletes, arch) } @@ -107,11 +119,11 @@ func (c *runtimeArchitectureComponentImpl) DeleteArchitectures(ctx context.Conte } func (c *runtimeArchitectureComponentImpl) ScanArchitecture(ctx context.Context, id int64, scanType int, models []string) error { - frame, err := c.r.runtimeFrameworksStore.FindByID(ctx, id) + frame, err := c.runtimeFrameworksStore.FindByID(ctx, id) if err != nil { return fmt.Errorf("invalid runtime framework id, %w", err) } - archs, err := c.ras.ListByRuntimeFrameworkID(ctx, id) + archs, err := c.runtimeArchStore.ListByRuntimeFrameworkID(ctx, id) if err != nil { return fmt.Errorf("list runtime arch failed, %w", err) } @@ -156,18 +168,18 @@ func (c *runtimeArchitectureComponentImpl) ScanArchitecture(ctx context.Context, } func (c *runtimeArchitectureComponentImpl) scanNewModels(ctx context.Context, req types.ScanReq) error { - repos, err := c.r.repoStore.GetRepoWithoutRuntimeByID(ctx, req.FrameID, req.Models) + repos, err := c.repoStore.GetRepoWithoutRuntimeByID(ctx, req.FrameID, req.Models) if err != nil { return fmt.Errorf("failed to get repos without runtime by ID, %w", err) } if repos == nil { return nil } - runtime_framework, err := c.rfs.FindByID(ctx, req.FrameID) + runtime_framework, err := c.runtimeFrameworksStore.FindByID(ctx, req.FrameID) if err != nil { return fmt.Errorf("failed to get runtime framework by ID, %w", err) } - runtime_framework_tags, _ := c.ts.GetTagsByScopeAndCategories(ctx, "model", []string{"runtime_framework", "resource"}) + runtime_framework_tags, _ := c.tagStore.GetTagsByScopeAndCategories(ctx, "model", []string{"runtime_framework", "resource"}) for _, repo := range repos { namespace, name := repo.NamespaceAndName() arch, err := c.GetArchitectureFromConfig(ctx, namespace, name) @@ -187,7 +199,7 @@ func (c *runtimeArchitectureComponentImpl) scanNewModels(ctx context.Context, re if !exist && !isSupportedRM { continue } - err = c.r.repoRuntimeFrameworkStore.Add(ctx, req.FrameID, repo.ID, req.FrameType) + err = c.repoRuntimeFrameworkStore.Add(ctx, req.FrameID, repo.ID, req.FrameType) if err != nil { slog.Warn("fail to create relation", slog.Any("repo", repo.Path), slog.Any("frameid", req.FrameID), slog.Any("error", err)) } @@ -207,7 +219,7 @@ func (c *runtimeArchitectureComponentImpl) scanNewModels(ctx context.Context, re // check if it's supported model resource by name func (c *runtimeArchitectureComponentImpl) IsSupportedModelResource(ctx context.Context, modelName string, rf *database.RuntimeFramework, id int64) (bool, error) { trimModel := strings.Replace(strings.ToLower(modelName), "meta-", "", 1) - rm, err := c.rms.CheckModelNameNotInRFRepo(ctx, trimModel, id) + rm, err := c.resouceModelStore.CheckModelNameNotInRFRepo(ctx, trimModel, id) if err != nil || rm == nil { return false, err } @@ -230,7 +242,7 @@ func (c *runtimeArchitectureComponentImpl) IsSupportedModelResource(ctx context. } func (c *runtimeArchitectureComponentImpl) scanExistModels(ctx context.Context, req types.ScanReq) error { - repos, err := c.r.repoStore.GetRepoWithRuntimeByID(ctx, req.FrameID, req.Models) + repos, err := c.repoStore.GetRepoWithRuntimeByID(ctx, req.FrameID, req.Models) if err != nil { return fmt.Errorf("fail to get repos with runtime by ID, %w", err) } @@ -251,7 +263,7 @@ func (c *runtimeArchitectureComponentImpl) scanExistModels(ctx context.Context, if exist { continue } - err = c.r.repoRuntimeFrameworkStore.Delete(ctx, req.FrameID, repo.ID, req.FrameType) + err = c.repoRuntimeFrameworkStore.Delete(ctx, req.FrameID, repo.ID, req.FrameType) if err != nil { slog.Warn("fail to remove relation", slog.Any("repo", repo.Path), slog.Any("frameid", req.FrameID), slog.Any("error", err)) } @@ -282,7 +294,7 @@ func (c *runtimeArchitectureComponentImpl) GetArchitectureFromConfig(ctx context } func (c *runtimeArchitectureComponentImpl) getConfigContent(ctx context.Context, namespace, name string) (string, error) { - content, err := c.r.git.GetRepoFileRaw(ctx, gitserver.GetRepoInfoByPathReq{ + content, err := c.gitServer.GetRepoFileRaw(ctx, gitserver.GetRepoInfoByPathReq{ Namespace: namespace, Name: name, Ref: MainBranch, @@ -297,10 +309,10 @@ func (c *runtimeArchitectureComponentImpl) getConfigContent(ctx context.Context, // remove runtime_framework tag from model func (c *runtimeArchitectureComponentImpl) RemoveRuntimeFrameworkTag(ctx context.Context, rftags []*database.Tag, repoId, rfId int64) { - rfw, _ := c.rfs.FindByID(ctx, rfId) + rfw, _ := c.runtimeFrameworksStore.FindByID(ctx, rfId) for _, tag := range rftags { if strings.Contains(rfw.FrameImage, tag.Name) { - err := c.ts.RemoveRepoTags(ctx, repoId, []int64{tag.ID}) + err := c.tagStore.RemoveRepoTags(ctx, repoId, []int64{tag.ID}) if err != nil { slog.Warn("fail to remove runtime_framework tag from model repo", slog.Any("repoId", repoId), slog.Any("runtime_framework_id", rfId), slog.Any("error", err)) } @@ -310,13 +322,13 @@ func (c *runtimeArchitectureComponentImpl) RemoveRuntimeFrameworkTag(ctx context // add runtime_framework tag to model func (c *runtimeArchitectureComponentImpl) AddRuntimeFrameworkTag(ctx context.Context, rftags []*database.Tag, repoId, rfId int64) error { - rfw, err := c.rfs.FindByID(ctx, rfId) + rfw, err := c.runtimeFrameworksStore.FindByID(ctx, rfId) if err != nil { return err } for _, tag := range rftags { if strings.Contains(rfw.FrameImage, tag.Name) { - err := c.ts.UpsertRepoTags(ctx, repoId, []int64{}, []int64{tag.ID}) + err := c.tagStore.UpsertRepoTags(ctx, repoId, []int64{}, []int64{tag.ID}) if err != nil { slog.Warn("fail to add runtime_framework tag to model repo", slog.Any("repoId", repoId), slog.Any("runtime_framework_id", rfId), slog.Any("error", err)) } @@ -327,14 +339,14 @@ func (c *runtimeArchitectureComponentImpl) AddRuntimeFrameworkTag(ctx context.Co // add resource tag to model func (c *runtimeArchitectureComponentImpl) AddResourceTag(ctx context.Context, rstags []*database.Tag, modelname string, repoId int64) error { - rms, err := c.rms.FindByModelName(ctx, modelname) + rms, err := c.resouceModelStore.FindByModelName(ctx, modelname) if err != nil { return err } for _, rm := range rms { for _, tag := range rstags { if strings.Contains(rm.ResourceName, tag.Name) { - err := c.ts.UpsertRepoTags(ctx, repoId, []int64{}, []int64{tag.ID}) + err := c.tagStore.UpsertRepoTags(ctx, repoId, []int64{}, []int64{tag.ID}) if err != nil { slog.Warn("fail to add resource tag to model repo", slog.Any("repoId", repoId), slog.Any("error", err)) } diff --git a/component/runtime_architecture_test.go b/component/runtime_architecture_test.go new file mode 100644 index 00000000..5db7dfdf --- /dev/null +++ b/component/runtime_architecture_test.go @@ -0,0 +1,209 @@ +package component + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestRuntimeArchComponent_ListByRuntimeFrameworkID(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + data := []database.RuntimeArchitecture{ + {ID: 123, ArchitectureName: "arch"}, + } + rc.mocks.stores.RuntimeArchMock().EXPECT().ListByRuntimeFrameworkID(ctx, int64(1)).Return( + data, nil, + ) + resp, err := rc.ListByRuntimeFrameworkID(ctx, 1) + require.Nil(t, err) + require.Equal(t, data, resp) + +} + +func TestRuntimeArchComponent_SetArchitectures(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(1)).Return(nil, nil) + rc.mocks.stores.RuntimeArchMock().EXPECT().Add(ctx, database.RuntimeArchitecture{ + RuntimeFrameworkID: 1, + ArchitectureName: "foo", + }).Return(nil) + rc.mocks.stores.RuntimeArchMock().EXPECT().Add(ctx, database.RuntimeArchitecture{ + RuntimeFrameworkID: 1, + ArchitectureName: "bar", + }).Return(errors.New("")) + + failed, err := rc.SetArchitectures(ctx, int64(1), []string{"foo", "bar"}) + require.Nil(t, err) + require.Equal(t, []string{"bar"}, failed) + +} + +func TestRuntimeArchComponent_DeleteArchitectures(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(1)).Return(nil, nil) + rc.mocks.stores.RuntimeArchMock().EXPECT().DeleteByRuntimeIDAndArchName(ctx, int64(1), "foo").Return(nil) + rc.mocks.stores.RuntimeArchMock().EXPECT().DeleteByRuntimeIDAndArchName(ctx, int64(1), "bar").Return(errors.New("")) + + failed, err := rc.DeleteArchitectures(ctx, int64(1), []string{"foo", "bar"}) + require.Nil(t, err) + require.Equal(t, []string{"bar"}, failed) + +} + +func TestRuntimeArchComponent_ScanArchitectures(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(1)).Return( + &database.RuntimeFramework{ + Type: 11, + }, nil, + ) + data := []database.RuntimeArchitecture{ + {ID: 123, ArchitectureName: "arch"}, + {ID: 124, ArchitectureName: "foo"}, + } + rc.mocks.stores.RuntimeArchMock().EXPECT().ListByRuntimeFrameworkID(ctx, int64(1)).Return( + data, nil, + ) + + // scan exists mocks + rc.mocks.stores.RepoMock().EXPECT().GetRepoWithRuntimeByID(ctx, int64(1), []string{"foo"}).Return([]database.Repository{ + {Path: "foo/bar"}, + }, nil) + rc.mocks.gitServer.EXPECT().GetRepoFileRaw(ctx, gitserver.GetRepoInfoByPathReq{ + Namespace: "foo", + Name: "bar", + Ref: "main", + Path: ConfigFileName, + RepoType: types.ModelRepo, + }).Return(`{"architectures": ["foo","bar"]}`, nil) + + // scan new mocks + rc.mocks.stores.RepoMock().EXPECT().GetRepoWithoutRuntimeByID(ctx, int64(1), []string{"foo"}).Return([]database.Repository{ + {Path: "foo/bar"}, + }, nil) + rc.mocks.stores.TagMock().EXPECT().GetTagsByScopeAndCategories(ctx, database.ModelTagScope, []string{ + "runtime_framework", "resource", + }).Return([]*database.Tag{}, nil) + rc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().Add(ctx, int64(1), int64(0), 11).Return(nil) + rc.mocks.stores.ResourceModelMock().EXPECT().CheckModelNameNotInRFRepo(ctx, "bar", int64(0)).Return( + &database.ResourceModel{}, nil, + ) + rc.mocks.stores.ResourceModelMock().EXPECT().FindByModelName(ctx, "bar").Return( + []*database.ResourceModel{ + {ResourceName: "r1"}, + {ResourceName: "r2"}, + }, nil, + ) + + err := rc.ScanArchitecture(ctx, 1, 0, []string{"foo"}) + require.Nil(t, err) + // wait async code finish + ScanLock.Lock() + _ = 1 + ScanLock.Unlock() + +} + +func TestRuntimeArchComponent_IsSupportedModelResource(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.stores.ResourceModelMock().EXPECT().CheckModelNameNotInRFRepo(ctx, "model", int64(1)).Return( + &database.ResourceModel{EngineName: "a"}, nil, + ) + + r, err := rc.IsSupportedModelResource(ctx, "meta-model", &database.RuntimeFramework{ + FrameImage: "a/b", + }, 1) + require.Nil(t, err, nil) + require.False(t, r) +} + +func TestRuntimeArchComponent_GetArchitectureFromConfig(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.gitServer.EXPECT().GetRepoFileRaw(ctx, gitserver.GetRepoInfoByPathReq{ + Namespace: "foo", + Name: "bar", + Ref: "main", + Path: ConfigFileName, + RepoType: types.ModelRepo, + }).Return(`{"architectures": ["foo","bar"]}`, nil) + + arch, err := rc.GetArchitectureFromConfig(ctx, "foo", "bar") + require.Nil(t, err) + require.Equal(t, "foo", arch) + +} + +// func TestRuntimeArchComponent_RemoveRuntimeFrameworkTag(t *testing.T) { +// ctx := context.TODO() +// rc := initializeTestRuntimeArchComponent(ctx, t) + +// rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(2)).Return( +// &database.RuntimeFramework{ +// FrameImage: "img", +// FrameNpuImage: "npu", +// }, nil, +// ) +// rc.mocks.stores.TagMock().EXPECT().RemoveRepoTags(ctx, int64(1), []int64{1}).Return(nil) +// rc.mocks.stores.TagMock().EXPECT().RemoveRepoTags(ctx, int64(1), []int64{2}).Return(nil) + +// rc.RemoveRuntimeFrameworkTag(ctx, []*database.Tag{ +// {Name: "img", ID: 1}, +// {Name: "npu", ID: 2}, +// }, int64(1), int64(2)) +// } + +// func TestRuntimeArchComponent_AddRuntimeFrameworkTag(t *testing.T) { +// ctx := context.TODO() +// rc := initializeTestRuntimeArchComponent(ctx, t) + +// rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(2)).Return( +// &database.RuntimeFramework{ +// FrameImage: "img", +// FrameNpuImage: "npu", +// }, nil, +// ) +// rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{1}).Return(nil) +// rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{2}).Return(nil) + +// err := rc.AddRuntimeFrameworkTag(ctx, []*database.Tag{ +// {Name: "img", ID: 1}, +// {Name: "npu", ID: 2}, +// }, int64(1), int64(2)) +// require.Nil(t, err) +// } + +// func TestRuntimeArchComponent_AddResourceTag(t *testing.T) { +// ctx := context.TODO() +// rc := initializeTestRuntimeArchComponent(ctx, t) + +// rc.mocks.stores.ResourceModelMock().EXPECT().FindByModelName(ctx, "model").Return( +// []*database.ResourceModel{ +// {ResourceName: "r1"}, +// {ResourceName: "r2"}, +// }, nil, +// ) +// rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{1}).Return(nil) +// rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{2}).Return(nil) + +// err := rc.AddResourceTag(ctx, []*database.Tag{ +// {Name: "r1", ID: 1}, +// }, "model", int64(1)) +// require.Nil(t, err) +// } diff --git a/component/space_resource_test.go b/component/space_resource_test.go new file mode 100644 index 00000000..024cc1e9 --- /dev/null +++ b/component/space_resource_test.go @@ -0,0 +1,94 @@ +package component + +// func TestSpaceResourceComponent_Index(t *testing.T) { +// ctx := context.TODO() +// sc := initializeTestSpaceResourceComponent(ctx, t) + +// sc.mocks.deployer.EXPECT().ListCluster(ctx).Return([]types.ClusterRes{ +// {ClusterID: "c1"}, +// }, nil) +// sc.mocks.stores.SpaceResourceMock().EXPECT().Index(ctx, "c1").Return( +// []database.SpaceResource{ +// {ID: 1, Name: "sr", Resources: `{"memory": "1000"}`}, +// }, nil, +// ) +// sc.mocks.deployer.EXPECT().GetClusterById(ctx, "c1").Return(&types.ClusterRes{}, nil) +// sc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ +// UUID: "uid", +// }, nil) + +// data, err := sc.Index(ctx, "", 1) +// require.Nil(t, err) +// require.Equal(t, []types.SpaceResource{ +// { +// ID: 1, Name: "sr", Resources: "{\"memory\": \"1000\"}", +// IsAvailable: false, Type: "cpu", +// }, +// { +// ID: 0, Name: "", Resources: "{\"memory\": \"2000\"}", IsAvailable: true, +// Type: "cpu", +// }, +// }, data) + +// } + +// func TestSpaceResourceComponent_Update(t *testing.T) { +// ctx := context.TODO() +// sc := initializeTestSpaceResourceComponent(ctx, t) + +// sc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return( +// &database.SpaceResource{}, nil, +// ) +// sc.mocks.stores.SpaceResourceMock().EXPECT().Update(ctx, database.SpaceResource{ +// Name: "n", +// Resources: "r", +// }).Return(&database.SpaceResource{ID: 1, Name: "n", Resources: "r"}, nil) + +// data, err := sc.Update(ctx, &types.UpdateSpaceResourceReq{ +// ID: 1, +// Name: "n", +// Resources: "r", +// }) +// require.Nil(t, err) +// require.Equal(t, &types.SpaceResource{ +// ID: 1, +// Name: "n", +// Resources: "r", +// }, data) +// } + +// func TestSpaceResourceComponent_Create(t *testing.T) { +// ctx := context.TODO() +// sc := initializeTestSpaceResourceComponent(ctx, t) + +// sc.mocks.stores.SpaceResourceMock().EXPECT().Create(ctx, database.SpaceResource{ +// Name: "n", +// Resources: "r", +// ClusterID: "c", +// }).Return(&database.SpaceResource{ID: 1, Name: "n", Resources: "r"}, nil) + +// data, err := sc.Create(ctx, &types.CreateSpaceResourceReq{ +// Name: "n", +// Resources: "r", +// ClusterID: "c", +// }) +// require.Nil(t, err) +// require.Equal(t, &types.SpaceResource{ +// ID: 1, +// Name: "n", +// Resources: "r", +// }, data) +// } + +// func TestSpaceResourceComponent_Delete(t *testing.T) { +// ctx := context.TODO() +// sc := initializeTestSpaceResourceComponent(ctx, t) + +// sc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return( +// &database.SpaceResource{}, nil, +// ) +// sc.mocks.stores.SpaceResourceMock().EXPECT().Delete(ctx, database.SpaceResource{}).Return(nil) + +// err := sc.Delete(ctx, 1) +// require.Nil(t, err) +// } diff --git a/component/space_sdk.go b/component/space_sdk.go index 1c0eca17..497dfede 100644 --- a/component/space_sdk.go +++ b/component/space_sdk.go @@ -18,18 +18,18 @@ type SpaceSdkComponent interface { func NewSpaceSdkComponent(config *config.Config) (SpaceSdkComponent, error) { c := &spaceSdkComponentImpl{} - c.sss = database.NewSpaceSdkStore() + c.spaceSdkStore = database.NewSpaceSdkStore() return c, nil } type spaceSdkComponentImpl struct { - sss database.SpaceSdkStore + spaceSdkStore database.SpaceSdkStore } func (c *spaceSdkComponentImpl) Index(ctx context.Context) ([]types.SpaceSdk, error) { var result []types.SpaceSdk - databaseSpaceSdks, err := c.sss.Index(ctx) + databaseSpaceSdks, err := c.spaceSdkStore.Index(ctx) if err != nil { return nil, err } @@ -45,7 +45,7 @@ func (c *spaceSdkComponentImpl) Index(ctx context.Context) ([]types.SpaceSdk, er } func (c *spaceSdkComponentImpl) Update(ctx context.Context, req *types.UpdateSpaceSdkReq) (*types.SpaceSdk, error) { - ss, err := c.sss.FindByID(ctx, req.ID) + ss, err := c.spaceSdkStore.FindByID(ctx, req.ID) if err != nil { slog.Error("error getting space sdk", slog.Any("error", err)) return nil, err @@ -53,7 +53,7 @@ func (c *spaceSdkComponentImpl) Update(ctx context.Context, req *types.UpdateSpa ss.Name = req.Name ss.Version = req.Version - ss, err = c.sss.Update(ctx, *ss) + ss, err = c.spaceSdkStore.Update(ctx, *ss) if err != nil { slog.Error("error getting space sdk", slog.Any("error", err)) return nil, err @@ -73,7 +73,7 @@ func (c *spaceSdkComponentImpl) Create(ctx context.Context, req *types.CreateSpa Name: req.Name, Version: req.Version, } - res, err := c.sss.Create(ctx, ss) + res, err := c.spaceSdkStore.Create(ctx, ss) if err != nil { slog.Error("error creating space sdk", slog.Any("error", err)) return nil, err @@ -89,13 +89,13 @@ func (c *spaceSdkComponentImpl) Create(ctx context.Context, req *types.CreateSpa } func (c *spaceSdkComponentImpl) Delete(ctx context.Context, id int64) error { - ss, err := c.sss.FindByID(ctx, id) + ss, err := c.spaceSdkStore.FindByID(ctx, id) if err != nil { slog.Error("error finding space sdk", slog.Any("error", err)) return err } - err = c.sss.Delete(ctx, *ss) + err = c.spaceSdkStore.Delete(ctx, *ss) if err != nil { slog.Error("error deleting space sdk", slog.Any("error", err)) return err diff --git a/component/space_sdk_test.go b/component/space_sdk_test.go new file mode 100644 index 00000000..641f00d4 --- /dev/null +++ b/component/space_sdk_test.go @@ -0,0 +1,64 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestSpaceSdkComponent_Index(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceSdkComponent(ctx, t) + + sc.mocks.stores.SpaceSdkMock().EXPECT().Index(ctx).Return([]database.SpaceSdk{ + {ID: 1, Name: "s", Version: "1"}, + }, nil) + + data, err := sc.Index(ctx) + require.Nil(t, err) + require.Equal(t, []types.SpaceSdk{{ID: 1, Name: "s", Version: "1"}}, data) +} + +func TestSpaceSdkComponent_Update(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceSdkComponent(ctx, t) + + s := &database.SpaceSdk{ID: 1} + sc.mocks.stores.SpaceSdkMock().EXPECT().FindByID(ctx, int64(1)).Return(s, nil) + s2 := *s + s2.Name = "n" + s2.Version = "v1" + sc.mocks.stores.SpaceSdkMock().EXPECT().Update(ctx, s2).Return(s, nil) + + data, err := sc.Update(ctx, &types.UpdateSpaceSdkReq{ID: 1, Name: "n", Version: "v1"}) + require.Nil(t, err) + require.Equal(t, &types.SpaceSdk{ID: 1, Name: "n", Version: "v1"}, data) +} + +func TestSpaceSdkComponent_Create(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceSdkComponent(ctx, t) + + s := database.SpaceSdk{Name: "n", Version: "v1"} + sc.mocks.stores.SpaceSdkMock().EXPECT().Create(ctx, s).Return(&s, nil) + s.ID = 1 + + data, err := sc.Create(ctx, &types.CreateSpaceSdkReq{Name: "n", Version: "v1"}) + require.Nil(t, err) + require.Equal(t, &types.SpaceSdk{ID: 1, Name: "n", Version: "v1"}, data) +} + +func TestSpaceSdkComponent_Delete(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceSdkComponent(ctx, t) + + s := &database.SpaceSdk{} + sc.mocks.stores.SpaceSdkMock().EXPECT().FindByID(ctx, int64(1)).Return(s, nil) + sc.mocks.stores.SpaceSdkMock().EXPECT().Delete(ctx, *s).Return(nil) + + err := sc.Delete(ctx, int64(1)) + require.Nil(t, err) +} diff --git a/component/tag.go b/component/tag.go index d428b895..2c95e4f3 100644 --- a/component/tag.go +++ b/component/tag.go @@ -16,7 +16,6 @@ import ( type TagComponent interface { AllTagsByScopeAndCategory(ctx context.Context, scope string, category string) ([]*database.Tag, error) - AllTags(ctx context.Context) ([]database.Tag, error) ClearMetaTags(ctx context.Context, repoType types.RepositoryType, namespace, name string) error UpdateMetaTags(ctx context.Context, tagScope database.TagScope, namespace, name, content string) ([]*database.RepositoryTag, error) UpdateLibraryTags(ctx context.Context, tagScope database.TagScope, namespace, name, oldFilePath, newFilePath string) error @@ -25,8 +24,8 @@ type TagComponent interface { func NewTagComponent(config *config.Config) (TagComponent, error) { tc := &tagComponentImpl{} - tc.ts = database.NewTagStore() - tc.rs = database.NewRepoStore() + tc.tagStore = database.NewTagStore() + tc.repoStore = database.NewRepoStore() if config.SensitiveCheck.Enable { tc.sensitiveChecker = rpc.NewModerationSvcHttpClient(fmt.Sprintf("%s:%d", config.Moderation.Host, config.Moderation.Port)) } @@ -34,21 +33,18 @@ func NewTagComponent(config *config.Config) (TagComponent, error) { } type tagComponentImpl struct { - ts database.TagStore - rs database.RepoStore + tagStore database.TagStore + repoStore database.RepoStore sensitiveChecker rpc.ModerationSvcClient } -func (c *tagComponentImpl) AllTags(ctx context.Context) ([]database.Tag, error) { - // TODO: query cache for tags at first - return c.ts.AllTags(ctx) -} -func (c *tagComponentImpl) AllTagsByScopeAndCategory(ctx context.Context, scope string, category string) ([]*database.Tag, error) { - return c.ts.AllTagsByScopeAndCategory(ctx, database.TagScope(scope), category) +func (tc *tagComponentImpl) AllTagsByScopeAndCategory(ctx context.Context, scope string, category string) ([]*database.Tag, error) { + return tc.tagStore.AllTagsByScopeAndCategory(ctx, database.TagScope(scope), category) } func (c *tagComponentImpl) ClearMetaTags(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { - _, err := c.ts.SetMetaTags(ctx, repoType, namespace, name, nil) + + _, err := c.tagStore.SetMetaTags(ctx, repoType, namespace, name, nil) return err } @@ -60,13 +56,13 @@ func (c *tagComponentImpl) UpdateMetaTags(ctx context.Context, tagScope database // TODO:load from cache if tagScope == database.DatasetTagScope { - tp = tagparser.NewDatasetTagProcessor(c.ts) + tp = tagparser.NewDatasetTagProcessor(c.tagStore) repoType = types.DatasetRepo } else if tagScope == database.ModelTagScope { - tp = tagparser.NewModelTagProcessor(c.ts) + tp = tagparser.NewModelTagProcessor(c.tagStore) repoType = types.ModelRepo } else if tagScope == database.PromptTagScope { - tp = tagparser.NewPromptTagProcessor(c.ts) + tp = tagparser.NewPromptTagProcessor(c.tagStore) repoType = types.PromptRepo } else { // skip tag process for code and space now @@ -91,13 +87,13 @@ func (c *tagComponentImpl) UpdateMetaTags(ctx context.Context, tagScope database }) } - err = c.ts.SaveTags(ctx, tagToCreate) + err = c.tagStore.SaveTags(ctx, tagToCreate) if err != nil { slog.Error("Failed to save tags", slog.Any("error", err)) return nil, fmt.Errorf("failed to save tags, cause: %w", err) } - repo, err := c.rs.FindByPath(ctx, repoType, namespace, name) + repo, err := c.repoStore.FindByPath(ctx, repoType, namespace, name) if err != nil { slog.Error("failed to find repo", slog.Any("error", err)) return nil, fmt.Errorf("failed to find repo, cause: %w", err) @@ -105,14 +101,14 @@ func (c *tagComponentImpl) UpdateMetaTags(ctx context.Context, tagScope database metaTags := append(tagsMatched, tagToCreate...) var repoTags []*database.RepositoryTag - repoTags, err = c.ts.SetMetaTags(ctx, repoType, namespace, name, metaTags) + repoTags, err = c.tagStore.SetMetaTags(ctx, repoType, namespace, name, metaTags) if err != nil { slog.Error("failed to set dataset's tags", slog.String("namespace", namespace), slog.String("name", name), slog.Any("error", err)) return nil, fmt.Errorf("failed to set dataset's tags, cause: %w", err) } - err = c.rs.UpdateLicenseByTag(ctx, repo.ID) + err = c.repoStore.UpdateLicenseByTag(ctx, repo.ID) if err != nil { slog.Error("failed to update repo license tags", slog.Any("error", err)) } @@ -130,13 +126,13 @@ func (c *tagComponentImpl) UpdateLibraryTags(ctx context.Context, tagScope datab repoType types.RepositoryType ) if tagScope == database.DatasetTagScope { - allTags, err = c.ts.AllDatasetTags(ctx) + allTags, err = c.tagStore.AllDatasetTags(ctx) repoType = types.DatasetRepo } else if tagScope == database.ModelTagScope { - allTags, err = c.ts.AllModelTags(ctx) + allTags, err = c.tagStore.AllModelTags(ctx) repoType = types.ModelRepo } else if tagScope == database.PromptTagScope { - allTags, err = c.ts.AllPromptTags(ctx) + allTags, err = c.tagStore.AllPromptTags(ctx) repoType = types.PromptRepo } else { return nil @@ -156,7 +152,7 @@ func (c *tagComponentImpl) UpdateLibraryTags(ctx context.Context, tagScope datab oldLibTag = t } } - err = c.ts.SetLibraryTag(ctx, repoType, namespace, name, newLibTag, oldLibTag) + err = c.tagStore.SetLibraryTag(ctx, repoType, namespace, name, newLibTag, oldLibTag) if err != nil { slog.Error("failed to set %s's tags", string(repoType), slog.String("namespace", namespace), slog.String("name", name), slog.Any("error", err)) @@ -166,7 +162,7 @@ func (c *tagComponentImpl) UpdateLibraryTags(ctx context.Context, tagScope datab } func (c *tagComponentImpl) UpdateRepoTagsByCategory(ctx context.Context, tagScope database.TagScope, repoID int64, category string, tagNames []string) error { - allTags, err := c.ts.AllTagsByScopeAndCategory(ctx, tagScope, category) + allTags, err := c.tagStore.AllTagsByScopeAndCategory(ctx, tagScope, category) if err != nil { return fmt.Errorf("failed to get all tags of scope `%s`, error: %w", tagScope, err) } @@ -185,9 +181,9 @@ func (c *tagComponentImpl) UpdateRepoTagsByCategory(ctx context.Context, tagScop } var oldTagIDs []int64 - oldTagIDs, err = c.rs.TagIDs(ctx, repoID, category) + oldTagIDs, err = c.repoStore.TagIDs(ctx, repoID, category) if err != nil { return fmt.Errorf("failed to get old tag ids, error: %w", err) } - return c.ts.UpsertRepoTags(ctx, repoID, oldTagIDs, tagIDs) + return c.tagStore.UpsertRepoTags(ctx, repoID, oldTagIDs, tagIDs) } diff --git a/component/tag_test.go b/component/tag_test.go new file mode 100644 index 00000000..2f53aa30 --- /dev/null +++ b/component/tag_test.go @@ -0,0 +1,90 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestTagComponent_AllTagsByScopeAndCategory(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.TagMock().EXPECT().AllTagsByScopeAndCategory(ctx, database.CodeTagScope, "cat").Return( + []*database.Tag{{Name: "t"}}, nil, + ) + + data, err := tc.AllTagsByScopeAndCategory(ctx, "code", "cat") + require.Nil(t, err) + require.Equal(t, []*database.Tag{{Name: "t"}}, data) +} + +func TestTagComponent_ClearMetaTags(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.TagMock().EXPECT().SetMetaTags( + ctx, types.ModelRepo, "ns", "n", []*database.Tag(nil), + ).Return(nil, nil) + + err := tc.ClearMetaTags(ctx, types.ModelRepo, "ns", "n") + require.Nil(t, err) +} + +func TestTagComponent_UpdateMetaTags(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.TagMock().EXPECT().AllDatasetTags(ctx).Return([]*database.Tag{}, nil) + tc.mocks.stores.TagMock().EXPECT().SaveTags(ctx, []*database.Tag(nil)).Return(nil) + tc.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.DatasetRepo, "ns", "n").Return( + &database.Repository{ID: 1}, nil, + ) + tc.mocks.stores.TagMock().EXPECT().SetMetaTags( + ctx, types.DatasetRepo, "ns", "n", []*database.Tag(nil), + ).Return(nil, nil) + tc.mocks.stores.RepoMock().EXPECT().UpdateLicenseByTag(ctx, int64(1)).Return(nil) + + data, err := tc.UpdateMetaTags(ctx, database.DatasetTagScope, "ns", "n", "") + require.Nil(t, err) + require.Equal(t, []*database.RepositoryTag(nil), data) +} + +func TestTagComponent_UpdateLibraryTags(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTagComponent(ctx, t) + + tags := []*database.Tag{ + {Category: "framework", Name: "pytorch", ID: 1}, + {Category: "framework", Name: "tensorflow", ID: 2}, + } + tc.mocks.stores.TagMock().EXPECT().AllDatasetTags(ctx).Return(tags, nil) + tc.mocks.stores.TagMock().EXPECT().SetLibraryTag( + ctx, types.DatasetRepo, "ns", "n", tags[1], tags[0], + ).Return(nil) + + err := tc.UpdateLibraryTags( + ctx, database.DatasetTagScope, "ns", "n", "pytorch_model_old.bin", "tf_model_new.h5", + ) + require.Nil(t, err) + +} + +func TestTagComponent_UpdateRepoTagsByCategory(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.TagMock().EXPECT().AllTagsByScopeAndCategory(ctx, database.DatasetTagScope, "c").Return( + []*database.Tag{ + {Name: "t1", ID: 2}, + }, nil, + ) + tc.mocks.stores.RepoMock().EXPECT().TagIDs(ctx, int64(1), "c").Return([]int64{1}, nil) + tc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{1}, []int64{2}).Return(nil) + + err := tc.UpdateRepoTagsByCategory(ctx, database.DatasetTagScope, 1, "c", []string{"t1"}) + require.Nil(t, err) +} diff --git a/component/wire.go b/component/wire.go index 5567b131..89b3ab39 100644 --- a/component/wire.go +++ b/component/wire.go @@ -105,3 +105,243 @@ func initializeTestAccountingComponent(ctx context.Context, t interface { ) return &testAccountingWithMocks{} } + +type testDatasetViewerWithMocks struct { + *datasetViewerComponentImpl + mocks *Mocks +} + +func initializeTestDatasetViewerComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testDatasetViewerWithMocks { + wire.Build( + MockSuperSet, DatasetViewerComponentSet, + wire.Struct(new(testDatasetViewerWithMocks), "*"), + ) + return &testDatasetViewerWithMocks{} +} + +type testGitHTTPWithMocks struct { + *gitHTTPComponentImpl + mocks *Mocks +} + +func initializeTestGitHTTPComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testGitHTTPWithMocks { + wire.Build( + MockSuperSet, GitHTTPComponentSet, + wire.Struct(new(testGitHTTPWithMocks), "*"), + ) + return &testGitHTTPWithMocks{} +} + +type testDiscussionWithMocks struct { + *discussionComponentImpl + mocks *Mocks +} + +func initializeTestDiscussionComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testDiscussionWithMocks { + wire.Build( + MockSuperSet, DiscussionComponentSet, + wire.Struct(new(testDiscussionWithMocks), "*"), + ) + return &testDiscussionWithMocks{} +} + +type testRuntimeArchWithMocks struct { + *runtimeArchitectureComponentImpl + mocks *Mocks +} + +func initializeTestRuntimeArchComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testRuntimeArchWithMocks { + wire.Build( + MockSuperSet, RuntimeArchComponentSet, + wire.Struct(new(testRuntimeArchWithMocks), "*"), + ) + return &testRuntimeArchWithMocks{} +} + +type testMirrorWithMocks struct { + *mirrorComponentImpl + mocks *Mocks +} + +func initializeTestMirrorComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testMirrorWithMocks { + wire.Build( + MockSuperSet, MirrorComponentSet, + wire.Struct(new(testMirrorWithMocks), "*"), + ) + return &testMirrorWithMocks{} +} + +type testCollectionWithMocks struct { + *collectionComponentImpl + mocks *Mocks +} + +func initializeTestCollectionComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testCollectionWithMocks { + wire.Build( + MockSuperSet, CollectionComponentSet, + wire.Struct(new(testCollectionWithMocks), "*"), + ) + return &testCollectionWithMocks{} +} + +type testDatasetWithMocks struct { + *datasetComponentImpl + mocks *Mocks +} + +func initializeTestDatasetComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testDatasetWithMocks { + wire.Build( + MockSuperSet, DatasetComponentSet, + wire.Struct(new(testDatasetWithMocks), "*"), + ) + return &testDatasetWithMocks{} +} + +type testCodeWithMocks struct { + *codeComponentImpl + mocks *Mocks +} + +func initializeTestCodeComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testCodeWithMocks { + wire.Build( + MockSuperSet, CodeComponentSet, + wire.Struct(new(testCodeWithMocks), "*"), + ) + return &testCodeWithMocks{} +} + +type testMultiSyncWithMocks struct { + *multiSyncComponentImpl + mocks *Mocks +} + +func initializeTestMultiSyncComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testMultiSyncWithMocks { + wire.Build( + MockSuperSet, MultiSyncComponentSet, + wire.Struct(new(testMultiSyncWithMocks), "*"), + ) + return &testMultiSyncWithMocks{} +} + +type testInternalWithMocks struct { + *internalComponentImpl + mocks *Mocks +} + +func initializeTestInternalComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testInternalWithMocks { + wire.Build( + MockSuperSet, InternalComponentSet, + wire.Struct(new(testInternalWithMocks), "*"), + ) + return &testInternalWithMocks{} +} + +type testMirrorSourceWithMocks struct { + *mirrorSourceComponentImpl + mocks *Mocks +} + +func initializeTestMirrorSourceComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testMirrorSourceWithMocks { + wire.Build( + MockSuperSet, MirrorSourceComponentSet, + wire.Struct(new(testMirrorSourceWithMocks), "*"), + ) + return &testMirrorSourceWithMocks{} +} + +type testSpaceResourceWithMocks struct { + *spaceResourceComponentImpl + mocks *Mocks +} + +func initializeTestSpaceResourceComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSpaceResourceWithMocks { + wire.Build( + MockSuperSet, SpaceResourceComponentSet, + wire.Struct(new(testSpaceResourceWithMocks), "*"), + ) + return &testSpaceResourceWithMocks{} +} + +type testTagWithMocks struct { + *tagComponentImpl + mocks *Mocks +} + +func initializeTestTagComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testTagWithMocks { + wire.Build( + MockSuperSet, TagComponentSet, + wire.Struct(new(testTagWithMocks), "*"), + ) + return &testTagWithMocks{} +} + +type testRecomWithMocks struct { + *recomComponentImpl + mocks *Mocks +} + +func initializeTestRecomComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testRecomWithMocks { + wire.Build( + MockSuperSet, RecomComponentSet, + wire.Struct(new(testRecomWithMocks), "*"), + ) + return &testRecomWithMocks{} +} + +type testSpaceSdkWithMocks struct { + *spaceSdkComponentImpl + mocks *Mocks +} + +func initializeTestSpaceSdkComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSpaceSdkWithMocks { + wire.Build( + MockSuperSet, SpaceSdkComponentSet, + wire.Struct(new(testSpaceSdkWithMocks), "*"), + ) + return &testSpaceSdkWithMocks{} +} diff --git a/component/wire_gen_test.go b/component/wire_gen_test.go index 8290271a..2fac5ee9 100644 --- a/component/wire_gen_test.go +++ b/component/wire_gen_test.go @@ -13,7 +13,7 @@ import ( "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy" "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/mirrorserver" - "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/inference" + "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/parquet" "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/s3" "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" @@ -41,15 +41,18 @@ func initializeTestRepoComponent(ctx context.Context, t interface { mockRepoComponent := component.NewMockRepoComponent(t) mockSpaceComponent := component.NewMockSpaceComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) componentMockedComponents := &mockedComponents{ accounting: mockAccountingComponent, repo: mockRepoComponent, tag: mockTagComponent, space: mockSpaceComponent, runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, } - inferenceMockClient := inference.NewMockClient(t) mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) mocks := &Mocks{ stores: mockStores, components: componentMockedComponents, @@ -59,8 +62,9 @@ func initializeTestRepoComponent(ctx context.Context, t interface { mirrorServer: mockMirrorServer, mirrorQueue: mockPriorityQueue, deployer: mockDeployer, - inferenceClient: inferenceMockClient, accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, } componentTestRepoWithMocks := &testRepoWithMocks{ repoComponentImpl: componentRepoComponentImpl, @@ -83,19 +87,22 @@ func initializeTestPromptComponent(ctx context.Context, t interface { mockTagComponent := component.NewMockTagComponent(t) mockSpaceComponent := component.NewMockSpaceComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) componentMockedComponents := &mockedComponents{ accounting: mockAccountingComponent, repo: mockRepoComponent, tag: mockTagComponent, space: mockSpaceComponent, runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, } mockClient := s3.NewMockClient(t) mockMirrorServer := mirrorserver.NewMockMirrorServer(t) mockPriorityQueue := queue.NewMockPriorityQueue(t) mockDeployer := deploy.NewMockDeployer(t) - inferenceMockClient := inference.NewMockClient(t) mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) mocks := &Mocks{ stores: mockStores, components: componentMockedComponents, @@ -105,8 +112,9 @@ func initializeTestPromptComponent(ctx context.Context, t interface { mirrorServer: mockMirrorServer, mirrorQueue: mockPriorityQueue, deployer: mockDeployer, - inferenceClient: inferenceMockClient, accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, } componentTestPromptWithMocks := &testPromptWithMocks{ promptComponentImpl: componentPromptComponentImpl, @@ -128,19 +136,22 @@ func initializeTestUserComponent(ctx context.Context, t interface { componentUserComponentImpl := NewTestUserComponent(mockStores, mockGitServer, mockSpaceComponent, mockRepoComponent, mockDeployer, mockAccountingComponent) mockTagComponent := component.NewMockTagComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) componentMockedComponents := &mockedComponents{ accounting: mockAccountingComponent, repo: mockRepoComponent, tag: mockTagComponent, space: mockSpaceComponent, runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, } mockUserSvcClient := rpc.NewMockUserSvcClient(t) mockClient := s3.NewMockClient(t) mockMirrorServer := mirrorserver.NewMockMirrorServer(t) mockPriorityQueue := queue.NewMockPriorityQueue(t) - inferenceMockClient := inference.NewMockClient(t) mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) mocks := &Mocks{ stores: mockStores, components: componentMockedComponents, @@ -150,8 +161,9 @@ func initializeTestUserComponent(ctx context.Context, t interface { mirrorServer: mockMirrorServer, mirrorQueue: mockPriorityQueue, deployer: mockDeployer, - inferenceClient: inferenceMockClient, accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, } componentTestUserWithMocks := &testUserWithMocks{ userComponentImpl: componentUserComponentImpl, @@ -175,18 +187,21 @@ func initializeTestSpaceComponent(ctx context.Context, t interface { mockTagComponent := component.NewMockTagComponent(t) mockSpaceComponent := component.NewMockSpaceComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) componentMockedComponents := &mockedComponents{ accounting: mockAccountingComponent, repo: mockRepoComponent, tag: mockTagComponent, space: mockSpaceComponent, runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, } mockClient := s3.NewMockClient(t) mockMirrorServer := mirrorserver.NewMockMirrorServer(t) mockPriorityQueue := queue.NewMockPriorityQueue(t) - inferenceMockClient := inference.NewMockClient(t) mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) mocks := &Mocks{ stores: mockStores, components: componentMockedComponents, @@ -196,8 +211,9 @@ func initializeTestSpaceComponent(ctx context.Context, t interface { mirrorServer: mockMirrorServer, mirrorQueue: mockPriorityQueue, deployer: mockDeployer, - inferenceClient: inferenceMockClient, accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, } componentTestSpaceWithMocks := &testSpaceWithMocks{ spaceComponentImpl: componentSpaceComponentImpl, @@ -214,62 +230,614 @@ func initializeTestModelComponent(ctx context.Context, t interface { mockStores := tests.NewMockStores(t) mockRepoComponent := component.NewMockRepoComponent(t) mockSpaceComponent := component.NewMockSpaceComponent(t) - mockClient := inference.NewMockClient(t) mockDeployer := deploy.NewMockDeployer(t) mockAccountingComponent := component.NewMockAccountingComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) mockGitServer := gitserver.NewMockGitServer(t) mockUserSvcClient := rpc.NewMockUserSvcClient(t) - componentModelComponentImpl := NewTestModelComponent(config, mockStores, mockRepoComponent, mockSpaceComponent, mockClient, mockDeployer, mockAccountingComponent, mockRuntimeArchitectureComponent, mockGitServer, mockUserSvcClient) + componentModelComponentImpl := NewTestModelComponent(config, mockStores, mockRepoComponent, mockSpaceComponent, mockDeployer, mockAccountingComponent, mockRuntimeArchitectureComponent, mockGitServer, mockUserSvcClient) mockTagComponent := component.NewMockTagComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) componentMockedComponents := &mockedComponents{ accounting: mockAccountingComponent, repo: mockRepoComponent, tag: mockTagComponent, space: mockSpaceComponent, runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, } - s3MockClient := s3.NewMockClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestModelWithMocks := &testModelWithMocks{ + modelComponentImpl: componentModelComponentImpl, + mocks: mocks, + } + return componentTestModelWithMocks +} + +func initializeTestAccountingComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testAccountingWithMocks { + mockStores := tests.NewMockStores(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + componentAccountingComponentImpl := NewTestAccountingComponent(mockStores, mockAccountingClient) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestAccountingWithMocks := &testAccountingWithMocks{ + accountingComponentImpl: componentAccountingComponentImpl, + mocks: mocks, + } + return componentTestAccountingWithMocks +} + +func initializeTestDatasetViewerComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testDatasetViewerWithMocks { + mockStores := tests.NewMockStores(t) + config := ProvideTestConfig() + mockRepoComponent := component.NewMockRepoComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + mockReader := parquet.NewMockReader(t) + componentDatasetViewerComponentImpl := NewTestDatasetViewerComponent(mockStores, config, mockRepoComponent, mockGitServer, mockReader) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestDatasetViewerWithMocks := &testDatasetViewerWithMocks{ + datasetViewerComponentImpl: componentDatasetViewerComponentImpl, + mocks: mocks, + } + return componentTestDatasetViewerWithMocks +} + +func initializeTestGitHTTPComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testGitHTTPWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + mockClient := s3.NewMockClient(t) + componentGitHTTPComponentImpl := NewTestGitHTTPComponent(config, mockStores, mockRepoComponent, mockGitServer, mockClient) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestGitHTTPWithMocks := &testGitHTTPWithMocks{ + gitHTTPComponentImpl: componentGitHTTPComponentImpl, + mocks: mocks, + } + return componentTestGitHTTPWithMocks +} + +func initializeTestDiscussionComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testDiscussionWithMocks { + mockStores := tests.NewMockStores(t) + componentDiscussionComponentImpl := NewTestDiscussionComponent(mockStores) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestDiscussionWithMocks := &testDiscussionWithMocks{ + discussionComponentImpl: componentDiscussionComponentImpl, + mocks: mocks, + } + return componentTestDiscussionWithMocks +} + +func initializeTestRuntimeArchComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testRuntimeArchWithMocks { + mockStores := tests.NewMockStores(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentRuntimeArchitectureComponentImpl := NewTestRuntimeArchitectureComponent(mockStores, mockRepoComponent, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestRuntimeArchWithMocks := &testRuntimeArchWithMocks{ + runtimeArchitectureComponentImpl: componentRuntimeArchitectureComponentImpl, + mocks: mocks, + } + return componentTestRuntimeArchWithMocks +} + +func initializeTestMirrorComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testMirrorWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + mockClient := s3.NewMockClient(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + componentMirrorComponentImpl := NewTestMirrorComponent(config, mockStores, mockMirrorServer, mockRepoComponent, mockGitServer, mockClient, mockPriorityQueue) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestMirrorWithMocks := &testMirrorWithMocks{ + mirrorComponentImpl: componentMirrorComponentImpl, + mocks: mocks, + } + return componentTestMirrorWithMocks +} + +func initializeTestCollectionComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testCollectionWithMocks { + mockStores := tests.NewMockStores(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + componentCollectionComponentImpl := NewTestCollectionComponent(mockStores, mockUserSvcClient, mockSpaceComponent) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestCollectionWithMocks := &testCollectionWithMocks{ + collectionComponentImpl: componentCollectionComponentImpl, + mocks: mocks, + } + return componentTestCollectionWithMocks +} + +func initializeTestDatasetComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testDatasetWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentDatasetComponentImpl := NewTestDatasetComponent(config, mockStores, mockRepoComponent, mockUserSvcClient, mockSensitiveComponent, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestDatasetWithMocks := &testDatasetWithMocks{ + datasetComponentImpl: componentDatasetComponentImpl, + mocks: mocks, + } + return componentTestDatasetWithMocks +} + +func initializeTestCodeComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testCodeWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentCodeComponentImpl := NewTestCodeComponent(config, mockStores, mockRepoComponent, mockUserSvcClient, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestCodeWithMocks := &testCodeWithMocks{ + codeComponentImpl: componentCodeComponentImpl, + mocks: mocks, + } + return componentTestCodeWithMocks +} + +func initializeTestMultiSyncComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testMultiSyncWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentMultiSyncComponentImpl := NewTestMultiSyncComponent(config, mockStores, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) mockMirrorServer := mirrorserver.NewMockMirrorServer(t) mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) mocks := &Mocks{ stores: mockStores, components: componentMockedComponents, gitServer: mockGitServer, userSvcClient: mockUserSvcClient, - s3Client: s3MockClient, + s3Client: mockClient, mirrorServer: mockMirrorServer, mirrorQueue: mockPriorityQueue, deployer: mockDeployer, - inferenceClient: mockClient, accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, } - componentTestModelWithMocks := &testModelWithMocks{ - modelComponentImpl: componentModelComponentImpl, - mocks: mocks, + componentTestMultiSyncWithMocks := &testMultiSyncWithMocks{ + multiSyncComponentImpl: componentMultiSyncComponentImpl, + mocks: mocks, } - return componentTestModelWithMocks + return componentTestMultiSyncWithMocks } -func initializeTestAccountingComponent(ctx context.Context, t interface { +func initializeTestInternalComponent(ctx context.Context, t interface { Cleanup(func()) mock.TestingT -}) *testAccountingWithMocks { +}) *testInternalWithMocks { + config := ProvideTestConfig() mockStores := tests.NewMockStores(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentInternalComponentImpl := NewTestInternalComponent(config, mockStores, mockRepoComponent, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) mockAccountingClient := accounting.NewMockAccountingClient(t) - componentAccountingComponentImpl := NewTestAccountingComponent(mockStores, mockAccountingClient) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestInternalWithMocks := &testInternalWithMocks{ + internalComponentImpl: componentInternalComponentImpl, + mocks: mocks, + } + return componentTestInternalWithMocks +} + +func initializeTestMirrorSourceComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testMirrorSourceWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + componentMirrorSourceComponentImpl := NewTestMirrorSourceComponent(config, mockStores) mockAccountingComponent := component.NewMockAccountingComponent(t) mockRepoComponent := component.NewMockRepoComponent(t) mockTagComponent := component.NewMockTagComponent(t) mockSpaceComponent := component.NewMockSpaceComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) componentMockedComponents := &mockedComponents{ accounting: mockAccountingComponent, repo: mockRepoComponent, tag: mockTagComponent, space: mockSpaceComponent, runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, } mockGitServer := gitserver.NewMockGitServer(t) mockUserSvcClient := rpc.NewMockUserSvcClient(t) @@ -277,7 +845,9 @@ func initializeTestAccountingComponent(ctx context.Context, t interface { mockMirrorServer := mirrorserver.NewMockMirrorServer(t) mockPriorityQueue := queue.NewMockPriorityQueue(t) mockDeployer := deploy.NewMockDeployer(t) - inferenceMockClient := inference.NewMockClient(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) mocks := &Mocks{ stores: mockStores, components: componentMockedComponents, @@ -287,14 +857,215 @@ func initializeTestAccountingComponent(ctx context.Context, t interface { mirrorServer: mockMirrorServer, mirrorQueue: mockPriorityQueue, deployer: mockDeployer, - inferenceClient: inferenceMockClient, accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, } - componentTestAccountingWithMocks := &testAccountingWithMocks{ - accountingComponentImpl: componentAccountingComponentImpl, - mocks: mocks, + componentTestMirrorSourceWithMocks := &testMirrorSourceWithMocks{ + mirrorSourceComponentImpl: componentMirrorSourceComponentImpl, + mocks: mocks, } - return componentTestAccountingWithMocks + return componentTestMirrorSourceWithMocks +} + +func initializeTestSpaceResourceComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSpaceResourceWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingComponent := component.NewMockAccountingComponent(t) + componentSpaceResourceComponentImpl := NewTestSpaceResourceComponent(config, mockStores, mockDeployer, mockAccountingComponent) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestSpaceResourceWithMocks := &testSpaceResourceWithMocks{ + spaceResourceComponentImpl: componentSpaceResourceComponentImpl, + mocks: mocks, + } + return componentTestSpaceResourceWithMocks +} + +func initializeTestTagComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testTagWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + componentTagComponentImpl := NewTestTagComponent(config, mockStores, mockModerationSvcClient) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestTagWithMocks := &testTagWithMocks{ + tagComponentImpl: componentTagComponentImpl, + mocks: mocks, + } + return componentTestTagWithMocks +} + +func initializeTestRecomComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testRecomWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentRecomComponentImpl := NewTestRecomComponent(config, mockStores, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestRecomWithMocks := &testRecomWithMocks{ + recomComponentImpl: componentRecomComponentImpl, + mocks: mocks, + } + return componentTestRecomWithMocks +} + +func initializeTestSpaceSdkComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSpaceSdkWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + componentSpaceSdkComponentImpl := NewTestSpaceSdkComponent(config, mockStores) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestSpaceSdkWithMocks := &testSpaceSdkWithMocks{ + spaceSdkComponentImpl: componentSpaceSdkComponentImpl, + mocks: mocks, + } + return componentTestSpaceSdkWithMocks } // wire.go: @@ -328,3 +1099,78 @@ type testAccountingWithMocks struct { *accountingComponentImpl mocks *Mocks } + +type testDatasetViewerWithMocks struct { + *datasetViewerComponentImpl + mocks *Mocks +} + +type testGitHTTPWithMocks struct { + *gitHTTPComponentImpl + mocks *Mocks +} + +type testDiscussionWithMocks struct { + *discussionComponentImpl + mocks *Mocks +} + +type testRuntimeArchWithMocks struct { + *runtimeArchitectureComponentImpl + mocks *Mocks +} + +type testMirrorWithMocks struct { + *mirrorComponentImpl + mocks *Mocks +} + +type testCollectionWithMocks struct { + *collectionComponentImpl + mocks *Mocks +} + +type testDatasetWithMocks struct { + *datasetComponentImpl + mocks *Mocks +} + +type testCodeWithMocks struct { + *codeComponentImpl + mocks *Mocks +} + +type testMultiSyncWithMocks struct { + *multiSyncComponentImpl + mocks *Mocks +} + +type testInternalWithMocks struct { + *internalComponentImpl + mocks *Mocks +} + +type testMirrorSourceWithMocks struct { + *mirrorSourceComponentImpl + mocks *Mocks +} + +type testSpaceResourceWithMocks struct { + *spaceResourceComponentImpl + mocks *Mocks +} + +type testTagWithMocks struct { + *tagComponentImpl + mocks *Mocks +} + +type testRecomWithMocks struct { + *recomComponentImpl + mocks *Mocks +} + +type testSpaceSdkWithMocks struct { + *spaceSdkComponentImpl + mocks *Mocks +} diff --git a/component/wireset.go b/component/wireset.go index 06f01a7d..b25cc075 100644 --- a/component/wireset.go +++ b/component/wireset.go @@ -6,7 +6,7 @@ import ( mock_deploy "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy" mock_git "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" mock_mirror "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/mirrorserver" - mock_inference "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/inference" + mock_preader "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/parquet" mock_rpc "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/rpc" mock_s3 "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/s3" mock_component "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" @@ -15,8 +15,8 @@ import ( "opencsg.com/csghub-server/builder/deploy" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/mirrorserver" - "opencsg.com/csghub-server/builder/inference" "opencsg.com/csghub-server/builder/llm" + "opencsg.com/csghub-server/builder/parquet" "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/s3" "opencsg.com/csghub-server/common/config" @@ -30,6 +30,7 @@ type mockedComponents struct { tag *mock_component.MockTagComponent space *mock_component.MockSpaceComponent runtimeArchitecture *mock_component.MockRuntimeArchitectureComponent + sensitive *mock_component.MockSensitiveComponent } var MockedStoreSet = wire.NewSet( @@ -47,6 +48,8 @@ var MockedComponentSet = wire.NewSet( wire.Bind(new(SpaceComponent), new(*mock_component.MockSpaceComponent)), mock_component.NewMockRuntimeArchitectureComponent, wire.Bind(new(RuntimeArchitectureComponent), new(*mock_component.MockRuntimeArchitectureComponent)), + mock_component.NewMockSensitiveComponent, + wire.Bind(new(SensitiveComponent), new(*mock_component.MockSensitiveComponent)), ) var MockedGitServerSet = wire.NewSet( @@ -79,16 +82,21 @@ var MockedMirrorQueueSet = wire.NewSet( wire.Bind(new(queue.PriorityQueue), new(*mock_mirror_queue.MockPriorityQueue)), ) -var MockedInferenceClientSet = wire.NewSet( - mock_inference.NewMockClient, - wire.Bind(new(inference.Client), new(*mock_inference.MockClient)), -) - var MockedAccountingClientSet = wire.NewSet( mock_accounting.NewMockAccountingClient, wire.Bind(new(accounting.AccountingClient), new(*mock_accounting.MockAccountingClient)), ) +var MockedParquetReaderSet = wire.NewSet( + mock_preader.NewMockReader, + wire.Bind(new(parquet.Reader), new(*mock_preader.MockReader)), +) + +var MockedModerationSvcClientSet = wire.NewSet( + mock_rpc.NewMockModerationSvcClient, + wire.Bind(new(rpc.ModerationSvcClient), new(*mock_rpc.MockModerationSvcClient)), +) + type Mocks struct { stores *tests.MockStores components *mockedComponents @@ -98,8 +106,9 @@ type Mocks struct { mirrorServer *mock_mirror.MockMirrorServer mirrorQueue *mock_mirror_queue.MockPriorityQueue deployer *mock_deploy.MockDeployer - inferenceClient *mock_inference.MockClient accountingClient *mock_accounting.MockAccountingClient + preader *mock_preader.MockReader + moderationClient *mock_rpc.MockModerationSvcClient } var AllMockSet = wire.NewSet( @@ -114,7 +123,8 @@ func ProvideTestConfig() *config.Config { var MockSuperSet = wire.NewSet( MockedComponentSet, AllMockSet, MockedStoreSet, MockedGitServerSet, MockedUserSvcSet, MockedS3Set, MockedDeployerSet, ProvideTestConfig, MockedMirrorServerSet, - MockedMirrorQueueSet, MockedInferenceClientSet, MockedAccountingClientSet, + MockedMirrorQueueSet, MockedAccountingClientSet, MockedParquetReaderSet, + MockedModerationSvcClientSet, ) func NewTestRepoComponent(config *config.Config, stores *tests.MockStores, rpcUser rpc.UserSvcClient, gitServer gitserver.GitServer, tagComponent TagComponent, s3Client s3.Client, deployer deploy.Deployer, accountingComponent AccountingComponent, mq queue.PriorityQueue, mirrorServer mirrorserver.MirrorServer) *repoComponentImpl { @@ -232,7 +242,6 @@ func NewTestModelComponent( stores *tests.MockStores, repoComponent RepoComponent, spaceComponent SpaceComponent, - inferClient inference.Client, deployer deploy.Deployer, accountingComponent AccountingComponent, runtimeArchComponent RuntimeArchitectureComponent, @@ -248,7 +257,6 @@ func NewTestModelComponent( modelStore: stores.Model, repoStore: stores.Repo, spaceResourceStore: stores.SpaceResource, - inferClient: inferClient, userStore: stores.User, deployer: deployer, accountingComponent: accountingComponent, @@ -276,3 +284,206 @@ func NewTestAccountingComponent(stores *tests.MockStores, accountingClient accou } var AccountingComponentSet = wire.NewSet(NewTestAccountingComponent) + +func NewTestDatasetViewerComponent(stores *tests.MockStores, cfg *config.Config, repoComponent RepoComponent, gitServer gitserver.GitServer, preader parquet.Reader) *datasetViewerComponentImpl { + return &datasetViewerComponentImpl{ + cfg: cfg, + preader: preader, + } +} + +var DatasetViewerComponentSet = wire.NewSet(NewTestDatasetViewerComponent) + +func NewTestGitHTTPComponent( + config *config.Config, + stores *tests.MockStores, + repoComponent RepoComponent, + gitServer gitserver.GitServer, + s3Client s3.Client, +) *gitHTTPComponentImpl { + config.APIServer.PublicDomain = "https://foo.com" + config.APIServer.SSHDomain = "ssh://test@127.0.0.1" + return &gitHTTPComponentImpl{ + config: config, + repoComponent: repoComponent, + repoStore: stores.Repo, + userStore: stores.User, + gitServer: gitServer, + s3Client: s3Client, + lfsMetaObjectStore: stores.LfsMetaObject, + lfsLockStore: stores.LfsLock, + } +} + +var GitHTTPComponentSet = wire.NewSet(NewTestGitHTTPComponent) + +func NewTestDiscussionComponent( + stores *tests.MockStores, +) *discussionComponentImpl { + return &discussionComponentImpl{ + repoStore: stores.Repo, + userStore: stores.User, + discussionStore: stores.Discussion, + } +} + +var DiscussionComponentSet = wire.NewSet(NewTestDiscussionComponent) + +func NewTestRuntimeArchitectureComponent(stores *tests.MockStores, repoComponent RepoComponent, gitServer gitserver.GitServer) *runtimeArchitectureComponentImpl { + return &runtimeArchitectureComponentImpl{ + repoComponent: repoComponent, + repoStore: stores.Repo, + repoRuntimeFrameworkStore: stores.RepoRuntimeFramework, + runtimeFrameworksStore: stores.RuntimeFramework, + runtimeArchStore: stores.RuntimeArch, + resouceModelStore: stores.ResourceModel, + tagStore: stores.Tag, + gitServer: gitServer, + } +} + +var RuntimeArchComponentSet = wire.NewSet(NewTestRuntimeArchitectureComponent) + +func NewTestMirrorComponent(config *config.Config, stores *tests.MockStores, mirrorServer mirrorserver.MirrorServer, repoComponent RepoComponent, gitServer gitserver.GitServer, s3Client s3.Client, mq queue.PriorityQueue) *mirrorComponentImpl { + return &mirrorComponentImpl{ + tokenStore: stores.GitServerAccessToken, + mirrorServer: mirrorServer, + repoComp: repoComponent, + git: gitServer, + s3Client: s3Client, + modelStore: stores.Model, + datasetStore: stores.Dataset, + codeStore: stores.Code, + repoStore: stores.Repo, + mirrorStore: stores.Mirror, + mirrorSourceStore: stores.MirrorSource, + namespaceStore: stores.Namespace, + userStore: stores.User, + config: config, + mq: mq, + } +} + +var MirrorComponentSet = wire.NewSet(NewTestMirrorComponent) + +func NewTestCollectionComponent(stores *tests.MockStores, userSvcClient rpc.UserSvcClient, spaceComponent SpaceComponent) *collectionComponentImpl { + return &collectionComponentImpl{ + collectionStore: stores.Collection, + orgStore: stores.Org, + repoStore: stores.Repo, + userStore: stores.User, + userLikesStore: stores.UserLikes, + userSvcClient: userSvcClient, + spaceComponent: spaceComponent, + } +} + +var CollectionComponentSet = wire.NewSet(NewTestCollectionComponent) + +func NewTestDatasetComponent(config *config.Config, stores *tests.MockStores, repoComponent RepoComponent, userSvcClient rpc.UserSvcClient, sensitiveComponent SensitiveComponent, gitServer gitserver.GitServer) *datasetComponentImpl { + return &datasetComponentImpl{ + config: config, + repoComponent: repoComponent, + tagStore: stores.Tag, + datasetStore: stores.Dataset, + repoStore: stores.Repo, + namespaceStore: stores.Namespace, + userStore: stores.User, + sensitiveComponent: sensitiveComponent, + gitServer: gitServer, + userLikesStore: stores.UserLikes, + userSvcClient: userSvcClient, + } +} + +var DatasetComponentSet = wire.NewSet(NewTestDatasetComponent) + +func NewTestCodeComponent(config *config.Config, stores *tests.MockStores, repoComponent RepoComponent, userSvcClient rpc.UserSvcClient, gitServer gitserver.GitServer) *codeComponentImpl { + return &codeComponentImpl{ + config: config, + repoComponent: repoComponent, + codeStore: stores.Code, + repoStore: stores.Repo, + userLikesStore: stores.UserLikes, + gitServer: gitServer, + userSvcClient: userSvcClient, + } +} + +var CodeComponentSet = wire.NewSet(NewTestCodeComponent) + +func NewTestMultiSyncComponent(config *config.Config, stores *tests.MockStores, gitServer gitserver.GitServer) *multiSyncComponentImpl { + return &multiSyncComponentImpl{ + multiSyncStore: stores.MultiSync, + repoStore: stores.Repo, + modelStore: stores.Model, + datasetStore: stores.Dataset, + namespaceStore: stores.Namespace, + userStore: stores.User, + syncVersionStore: stores.SyncVersion, + tagStore: stores.Tag, + fileStore: stores.File, + gitServer: gitServer, + } +} + +var MultiSyncComponentSet = wire.NewSet(NewTestMultiSyncComponent) + +func NewTestInternalComponent(config *config.Config, stores *tests.MockStores, repoComponent RepoComponent, gitServer gitserver.GitServer) *internalComponentImpl { + return &internalComponentImpl{ + config: config, + sshKeyStore: stores.SSH, + repoStore: stores.Repo, + tokenStore: stores.AccessToken, + namespaceStore: stores.Namespace, + repoComponent: repoComponent, + gitServer: gitServer, + } +} + +var InternalComponentSet = wire.NewSet(NewTestInternalComponent) + +func NewTestMirrorSourceComponent(config *config.Config, stores *tests.MockStores) *mirrorSourceComponentImpl { + return &mirrorSourceComponentImpl{ + mirrorSourceStore: stores.MirrorSource, + userStore: stores.User, + } +} + +var MirrorSourceComponentSet = wire.NewSet(NewTestMirrorSourceComponent) + +func NewTestSpaceResourceComponent(config *config.Config, stores *tests.MockStores, deployer deploy.Deployer, accountComponent AccountingComponent) *spaceResourceComponentImpl { + return &spaceResourceComponentImpl{ + deployer: deployer, + } +} + +var SpaceResourceComponentSet = wire.NewSet(NewTestSpaceResourceComponent) + +func NewTestTagComponent(config *config.Config, stores *tests.MockStores, sensitiveChecker rpc.ModerationSvcClient) *tagComponentImpl { + return &tagComponentImpl{ + tagStore: stores.Tag, + repoStore: stores.Repo, + sensitiveChecker: sensitiveChecker, + } +} + +var TagComponentSet = wire.NewSet(NewTestTagComponent) + +func NewTestRecomComponent(config *config.Config, stores *tests.MockStores, gitServer gitserver.GitServer) *recomComponentImpl { + return &recomComponentImpl{ + recomStore: stores.Recom, + repoStore: stores.Repo, + gitServer: gitServer, + } +} + +var RecomComponentSet = wire.NewSet(NewTestRecomComponent) + +func NewTestSpaceSdkComponent(config *config.Config, stores *tests.MockStores) *spaceSdkComponentImpl { + return &spaceSdkComponentImpl{ + spaceSdkStore: stores.SpaceSdk, + } +} + +var SpaceSdkComponentSet = wire.NewSet(NewTestSpaceSdkComponent) diff --git a/go.sum b/go.sum index e6efc3c7..db84171c 100644 --- a/go.sum +++ b/go.sum @@ -307,6 +307,7 @@ github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af h1:kmjWCqn2qkEml422C2 github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=