From 66d0cb71d4b0d82c4495636dfe52786918ac9c32 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Thu, 27 Jun 2024 15:56:32 +0800 Subject: [PATCH 01/42] feat: Use fsnotify to detect model/policy files change in casbin plugin --- plugins/go.mod | 1 + plugins/go.sum | 2 + plugins/pkg/file/fs.go | 127 ++++++++++++++++--------------- plugins/pkg/file/fs_test.go | 41 ++++++---- plugins/plugins/casbin/filter.go | 28 +++++-- 5 files changed, 115 insertions(+), 84 deletions(-) diff --git a/plugins/go.mod b/plugins/go.mod index 8ebbb82b..fcd69cd2 100644 --- a/plugins/go.mod +++ b/plugins/go.mod @@ -54,6 +54,7 @@ require ( github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/envoyproxy/go-control-plane v0.12.1-0.20240117015050-472addddff92 // indirect github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/go-jose/go-jose/v4 v4.0.1 // indirect github.com/go-logr/logr v1.4.1 // indirect diff --git a/plugins/go.sum b/plugins/go.sum index f76ca513..f3bd6564 100644 --- a/plugins/go.sum +++ b/plugins/go.sum @@ -60,6 +60,8 @@ github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8 github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7DlmewI= github.com/foxcpp/go-mockdns v1.1.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld2qVbOu7Wk= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/go-jose/go-jose/v4 v4.0.1 h1:QVEPDE3OluqXBQZDcnNvQrInro2h0e4eqNbnZSWqS6U= diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 29dcb531..563d7dd3 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -15,11 +15,13 @@ package file import ( + "errors" + "fmt" "os" "sync" "time" - "github.com/jellydator/ttlcache/v3" + "github.com/fsnotify/fsnotify" "mosn.io/htnn/api/pkg/log" ) @@ -48,65 +50,87 @@ func (f *File) SetMtime(t time.Time) { f.lock.Unlock() } -type fs struct { - cache *ttlcache.Cache[string, os.FileInfo] +type Fsnotify struct { + Watcher *fsnotify.Watcher } -func newFS(ttl time.Duration) *fs { - loader := ttlcache.LoaderFunc[string, os.FileInfo]( - func(c *ttlcache.Cache[string, os.FileInfo], key string) *ttlcache.Item[string, os.FileInfo] { - info, err := os.Stat(key) - if err != nil { - logger.Error(err, "reload file info to cache", "file", key) - return nil - } - item := c.Set(key, info, ttlcache.DefaultTTL) - logger.Info("update file mtime", "file", key, "mtime", item.Value().ModTime()) - return item - }, - ) - cache := ttlcache.New( - ttlcache.WithTTL[string, os.FileInfo](ttl), - ttlcache.WithLoader[string, os.FileInfo](loader), - ) - go cache.Start() - - return &fs{ - cache: cache, +func newFsnotify() (fs *Fsnotify) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + logger.Error(err, "create watcher failed") + return } + fs = &Fsnotify{ + Watcher: watcher, + } + return } var ( - // TODO: rewrite it to use inotify - defaultFs = newFS(10 * time.Second) + defaultFsnotify = newFsnotify() ) -func IsChanged(files ...*File) bool { +func Update(onChange func(), files ...*File) (err error) { + err = WatchFiles(onChange, files...) + return +} + +func WatchFiles(onChange func(), files ...*File) (err error) { + if len(files) < 1 { + err = errors.New("must specify at least one file to watch") + logger.Error(err, "") + return + } + + watcher := newFsnotify().Watcher + if err != nil { + + logger.Error(err, "failed to create watcher") + return + } + + // Add files to watcher. for _, file := range files { - changed := defaultFs.isChanged(file) - if changed { - return true - } + go defaultFsnotify.watchFiles(onChange, watcher, file) } - return false + + return } -func (f *fs) isChanged(file *File) bool { - item := f.cache.Get(file.Name) - if item == nil { - // As a protection, failed to fetch the real file means file not changed - return false +func (f *Fsnotify) watchFiles(onChange func(), w *fsnotify.Watcher, files *File) { + defer func(w *fsnotify.Watcher) { + err := w.Close() + if err != nil { + logger.Error(err, "failed to close fsnotify watcher") + } + }(w) + err := w.Add(files.Name) + if err != nil { + logger.Error(err, "add file to watcher failed") + } + for { + select { + case event, ok := <-w.Events: + if !ok { + return + } + logger.Info(fmt.Sprintf("event: %v", event)) + onChange() + return + case err, ok := <-w.Errors: + if !ok { + return + } + logger.Error(err, "error watching files") + } } - - return file.Mtime().Before(item.Value().ModTime()) } -func (f *fs) Stat(path string) (*File, error) { +func (f *Fsnotify) Stat(path string) (*File, error) { info, err := os.Stat(path) if err != nil { return nil, err } - f.cache.Set(path, info, ttlcache.DefaultTTL) return &File{ Name: path, @@ -115,24 +139,5 @@ func (f *fs) Stat(path string) (*File, error) { } func Stat(path string) (*File, error) { - return defaultFs.Stat(path) -} - -func Update(files ...*File) bool { - for _, file := range files { - if !defaultFs.update(file) { - return false - } - } - return true -} - -func (f *fs) update(file *File) bool { - item := f.cache.Get(file.Name) - if item == nil { - return false - } - - file.SetMtime(item.Value().ModTime()) - return true + return defaultFsnotify.Stat(path) } diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 758e5a67..49a2fb69 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -22,22 +22,33 @@ import ( "github.com/stretchr/testify/assert" ) -func TestFileMtimeDetection(t *testing.T) { - defaultFs = newFS(2000 * time.Millisecond) +func TestFileIsChanged(t *testing.T) { + i := 1 + tmpfile, _ := os.CreateTemp("./", "example") + defer func(name string) { + err := os.Remove(name) + if err != nil { + t.Logf("%v", err) + } + }(tmpfile.Name()) - tmpfile, _ := os.CreateTemp("", "example") - defer os.Remove(tmpfile.Name()) // clean up - - f, err := Stat(tmpfile.Name()) - assert.Nil(t, err) - assert.False(t, IsChanged(f)) - time.Sleep(1000 * time.Millisecond) + file := &File{Name: tmpfile.Name()} + _ = WatchFiles(func() { + i = 2 + }, file) + time.Sleep(1 * time.Millisecond) tmpfile.Write([]byte("bls")) - tmpfile.Close() - assert.False(t, IsChanged(f)) + tmpfile.Sync() + assert.Equal(t, 2, i) + + _ = WatchFiles(func() { + i = 1 + }, file) + time.Sleep(1 * time.Millisecond) + tmpfile.Sync() + assert.Equal(t, 2, i) + + err := WatchFiles(func() {}) + assert.Equal(t, err.Error(), "must specify at least one file to watch", "Expected error message does not match") - time.Sleep(2500 * time.Millisecond) - assert.True(t, IsChanged(f)) - assert.True(t, Update(f)) - assert.False(t, IsChanged(f)) } diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index 5a9c7b36..c60bb007 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -16,7 +16,6 @@ package casbin import ( "github.com/casbin/casbin/v2" - "mosn.io/htnn/api/pkg/filtermanager/api" "mosn.io/htnn/plugins/pkg/file" ) @@ -35,13 +34,9 @@ type filter struct { config *config } -func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction { +func (f *filter) reloadEnforcer() { conf := f.config - role, _ := headers.Get(conf.Token.Name) // role can be "" - url := headers.Url() - - policyChanged := file.IsChanged(conf.modelFile, conf.policyFile) - if policyChanged && !conf.updating.Load() { + if !conf.updating.Load() { conf.updating.Store(true) api.LogWarnf("policy %s or model %s changed, reload enforcer", conf.policyFile.Name, conf.modelFile.Name) @@ -58,11 +53,28 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api conf.enforcer = e conf.lock.Unlock() - file.Update(conf.modelFile, conf.policyFile) + err = file.Update(func() { + f.reloadEnforcer() + }, conf.modelFile, conf.policyFile) + if err != nil { + api.LogErrorf("failed to update Enforcer: %v", err) + } api.LogWarnf("policy %s or model %s changed, enforcer reloaded", conf.policyFile.Name, conf.modelFile.Name) } }() } +} +func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction { + conf := f.config + role, _ := headers.Get(conf.Token.Name) // role can be "" + url := headers.Url() + + err := file.WatchFiles(f.reloadEnforcer, conf.modelFile, conf.policyFile) + + if err != nil { + api.LogErrorf("failed to watch files: %v", err) + return &api.LocalResponse{Code: 500} + } conf.lock.RLock() ok, err := f.config.enforcer.Enforce(role, url.Path, headers.Method()) From fe5182edf09e8ded266908f7aa9409aeb977ae11 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Thu, 27 Jun 2024 16:29:00 +0800 Subject: [PATCH 02/42] feat: Use fsnotify to detect model/policy files change in casbin plugin --- plugins/go.mod | 2 +- plugins/plugins/casbin/filter.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/go.mod b/plugins/go.mod index fcd69cd2..accb88bc 100644 --- a/plugins/go.mod +++ b/plugins/go.mod @@ -27,6 +27,7 @@ require ( github.com/casbin/casbin/v2 v2.88.0 github.com/coreos/go-oidc/v3 v3.10.0 github.com/envoyproxy/envoy v1.29.4 + github.com/fsnotify/fsnotify v1.7.0 github.com/google/cel-go v0.20.1 github.com/google/uuid v1.6.0 github.com/gorilla/securecookie v1.1.2 @@ -54,7 +55,6 @@ require ( github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/envoyproxy/go-control-plane v0.12.1-0.20240117015050-472addddff92 // indirect github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/go-jose/go-jose/v4 v4.0.1 // indirect github.com/go-logr/logr v1.4.1 // indirect diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index c60bb007..aabc73d1 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -16,6 +16,7 @@ package casbin import ( "github.com/casbin/casbin/v2" + "mosn.io/htnn/api/pkg/filtermanager/api" "mosn.io/htnn/plugins/pkg/file" ) From 1bfb1d40b6b2dbcba2b4956227b2305a7a5b76a0 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Thu, 27 Jun 2024 16:34:40 +0800 Subject: [PATCH 03/42] feat: Use fsnotify to detect model/policy files change in casbin plugin --- plugins/pkg/file/fs_test.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 49a2fb69..32ae44d2 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -16,6 +16,7 @@ package file import ( "os" + "sync" "testing" "time" @@ -23,6 +24,7 @@ import ( ) func TestFileIsChanged(t *testing.T) { + var mu sync.Mutex i := 1 tmpfile, _ := os.CreateTemp("./", "example") defer func(name string) { @@ -34,21 +36,25 @@ func TestFileIsChanged(t *testing.T) { file := &File{Name: tmpfile.Name()} _ = WatchFiles(func() { + mu.Lock() i = 2 + mu.Unlock() }, file) time.Sleep(1 * time.Millisecond) tmpfile.Write([]byte("bls")) tmpfile.Sync() + mu.Lock() assert.Equal(t, 2, i) + mu.Unlock() _ = WatchFiles(func() { + mu.Lock() i = 1 + mu.Unlock() }, file) time.Sleep(1 * time.Millisecond) tmpfile.Sync() + mu.Lock() assert.Equal(t, 2, i) - - err := WatchFiles(func() {}) - assert.Equal(t, err.Error(), "must specify at least one file to watch", "Expected error message does not match") - + mu.Unlock() } From 7bf2834dde5314d56e589f6c23bde4e52f8fb69d Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Thu, 27 Jun 2024 23:30:49 +0800 Subject: [PATCH 04/42] feat: Use fsnotify to detect model/policy files change in casbin plugin --- plugins/pkg/file/fs.go | 18 ++-------------- plugins/pkg/file/fs_test.go | 42 +++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 563d7dd3..3553c034 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -37,19 +37,6 @@ type File struct { mtime time.Time } -func (f *File) Mtime() time.Time { - f.lock.RLock() - defer f.lock.RUnlock() - // the returned time.Time should be readonly - return f.mtime -} - -func (f *File) SetMtime(t time.Time) { - f.lock.Lock() - f.mtime = t - f.lock.Unlock() -} - type Fsnotify struct { Watcher *fsnotify.Watcher } @@ -78,14 +65,11 @@ func Update(onChange func(), files ...*File) (err error) { func WatchFiles(onChange func(), files ...*File) (err error) { if len(files) < 1 { err = errors.New("must specify at least one file to watch") - logger.Error(err, "") return } watcher := newFsnotify().Watcher if err != nil { - - logger.Error(err, "failed to create watcher") return } @@ -99,6 +83,7 @@ func WatchFiles(onChange func(), files ...*File) (err error) { func (f *Fsnotify) watchFiles(onChange func(), w *fsnotify.Watcher, files *File) { defer func(w *fsnotify.Watcher) { + logger.Info("stop watch files" + files.Name) err := w.Close() if err != nil { logger.Error(err, "failed to close fsnotify watcher") @@ -108,6 +93,7 @@ func (f *Fsnotify) watchFiles(onChange func(), w *fsnotify.Watcher, files *File) if err != nil { logger.Error(err, "add file to watcher failed") } + logger.Info("start watch files" + files.Name) for { select { case event, ok := <-w.Events: diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 32ae44d2..4225f76c 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -57,4 +57,46 @@ func TestFileIsChanged(t *testing.T) { mu.Lock() assert.Equal(t, 2, i) mu.Unlock() + + _ = Update(func() { + mu.Lock() + i = 1 + mu.Unlock() + }, file) + + time.Sleep(1 * time.Millisecond) + tmpfile.Sync() + mu.Lock() + assert.Equal(t, 2, i) + + err := WatchFiles(func() {}) + assert.Error(t, err, "must specify at least one file to watch") +} + +func TestStat(t *testing.T) { + tmpfile, err := os.CreateTemp("", "example") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpfile.Name()) + + if _, err := tmpfile.Write([]byte("hello world")); err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + + if err := tmpfile.Close(); err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } + + statFile, err := Stat(tmpfile.Name()) + assert.NoError(t, err, "Stat() should not return error") + + assert.Equal(t, tmpfile.Name(), statFile.Name, "Stat() Name should match") + assert.False(t, statFile.mtime.IsZero(), "Stat() mtime should be non-zero") + + nonExistentFilePath := "./nonexistentfile.txt" + _, err = Stat(nonExistentFilePath) + + assert.Error(t, err, "Stat should return error for non-existent file") + assert.True(t, os.IsNotExist(err), "Error should indicate non-existent file") } From 17ed3680d232b686c3da1b7578e6e5f4e314f5df Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sun, 30 Jun 2024 23:50:18 +0800 Subject: [PATCH 05/42] feat: Use fsnotify to detect model/policy files change in casbin plugin --- plugins/pkg/file/fs.go | 58 ++++++++------ plugins/pkg/file/fs_test.go | 105 ++++++++++++++++++++------ plugins/plugins/casbin/filter.go | 14 ++-- plugins/plugins/casbin/filter_test.go | 17 +++++ 4 files changed, 140 insertions(+), 54 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 3553c034..4b2e0140 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "os" + "path/filepath" "sync" "time" @@ -38,7 +39,9 @@ type File struct { } type Fsnotify struct { - Watcher *fsnotify.Watcher + mu sync.Mutex + Watcher *fsnotify.Watcher + WatchedFiles map[string]struct{} } func newFsnotify() (fs *Fsnotify) { @@ -47,53 +50,61 @@ func newFsnotify() (fs *Fsnotify) { logger.Error(err, "create watcher failed") return } - fs = &Fsnotify{ - Watcher: watcher, + + return &Fsnotify{ + Watcher: watcher, + WatchedFiles: make(map[string]struct{}), } - return + } var ( defaultFsnotify = newFsnotify() ) -func Update(onChange func(), files ...*File) (err error) { - err = WatchFiles(onChange, files...) - return -} - -func WatchFiles(onChange func(), files ...*File) (err error) { - if len(files) < 1 { - err = errors.New("must specify at least one file to watch") - return +func WatchFiles(onChange func(), file *File, otherFiles ...*File) (err error) { + files := append([]*File{file}, otherFiles...) + for _, f := range files { + if f == nil { + return errors.New("file pointer cannot be nil") + } } - watcher := newFsnotify().Watcher + watcher := defaultFsnotify.Watcher if err != nil { return } // Add files to watcher. - for _, file := range files { - go defaultFsnotify.watchFiles(onChange, watcher, file) + for _, f := range files { + dir := filepath.Dir(f.Name) + err = watcher.Add(dir) + if err != nil { + logger.Error(err, "add file to watcher failed") + } + if _, exists := defaultFsnotify.WatchedFiles[dir]; exists { + logger.Info(fmt.Sprintf("File %s is already being watched", f.Name)) + continue + } + // 添加到已监听文件的集合 + defaultFsnotify.WatchedFiles[dir] = struct{}{} + go defaultFsnotify.watchFiles(onChange, watcher, dir) } return } -func (f *Fsnotify) watchFiles(onChange func(), w *fsnotify.Watcher, files *File) { +func (f *Fsnotify) watchFiles(onChange func(), w *fsnotify.Watcher, dir string) { defer func(w *fsnotify.Watcher) { - logger.Info("stop watch files" + files.Name) + f.mu.Lock() + delete(defaultFsnotify.WatchedFiles, dir) + f.mu.Unlock() err := w.Close() if err != nil { logger.Error(err, "failed to close fsnotify watcher") } }(w) - err := w.Add(files.Name) - if err != nil { - logger.Error(err, "add file to watcher failed") - } - logger.Info("start watch files" + files.Name) + for { select { case event, ok := <-w.Events: @@ -102,7 +113,6 @@ func (f *Fsnotify) watchFiles(onChange func(), w *fsnotify.Watcher, files *File) } logger.Info(fmt.Sprintf("event: %v", event)) onChange() - return case err, ok := <-w.Errors: if !ok { return diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 4225f76c..499a5e1e 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -15,17 +15,21 @@ package file import ( + "fmt" + "io/ioutil" "os" "sync" "testing" "time" + "github.com/fsnotify/fsnotify" "github.com/stretchr/testify/assert" ) func TestFileIsChanged(t *testing.T) { - var mu sync.Mutex - i := 1 + var wg sync.WaitGroup + i := 4 + tmpfile, _ := os.CreateTemp("./", "example") defer func(name string) { err := os.Remove(name) @@ -36,41 +40,94 @@ func TestFileIsChanged(t *testing.T) { file := &File{Name: tmpfile.Name()} _ = WatchFiles(func() { - mu.Lock() + wg.Add(1) + defer wg.Done() + defaultFsnotify.mu.Lock() i = 2 - mu.Unlock() + defaultFsnotify.mu.Unlock() }, file) - time.Sleep(1 * time.Millisecond) tmpfile.Write([]byte("bls")) tmpfile.Sync() - mu.Lock() - assert.Equal(t, 2, i) - mu.Unlock() + wg.Wait() _ = WatchFiles(func() { - mu.Lock() + wg.Add(1) + defer wg.Done() + defaultFsnotify.mu.Lock() i = 1 - mu.Unlock() + defaultFsnotify.mu.Unlock() }, file) - time.Sleep(1 * time.Millisecond) tmpfile.Sync() - mu.Lock() - assert.Equal(t, 2, i) - mu.Unlock() + wg.Wait() - _ = Update(func() { - mu.Lock() - i = 1 - mu.Unlock() - }, file) + err := WatchFiles(func() {}, nil) + assert.Error(t, err, "file pointer cannot be nil") - time.Sleep(1 * time.Millisecond) - tmpfile.Sync() - mu.Lock() + filename := "my_file.txt" + content := "Hello, World!" + + f, err := os.Create(filename) + fi := &File{Name: f.Name()} + if err != nil { + fmt.Println("Error creating file:", err) + return + } + defer f.Close() + + _ = WatchFiles(func() { + wg.Add(1) + defer wg.Done() + defaultFsnotify.mu.Lock() + i = 3 + defaultFsnotify.mu.Unlock() + }, fi) + _, _ = f.WriteString(content) + + _ = os.Remove(filename) + f, _ = os.Create(filename) + + defer f.Close() + + _, _ = f.WriteString("New content for the file.") + _ = os.Remove(filename) + + defaultFsnotify.mu.Lock() assert.Equal(t, 2, i) + defaultFsnotify.mu.Unlock() + + watcher, err := fsnotify.NewWatcher() + assert.NoError(t, err) + defer watcher.Close() + fs := &Fsnotify{ + WatchedFiles: make(map[string]struct{}), + } + tmpDir, err := ioutil.TempDir("", "watch_test") + assert.NoError(t, err) + defer os.RemoveAll(tmpDir) + fs.WatchedFiles[tmpDir] = struct{}{} + err = watcher.Add(tmpDir) + assert.NoError(t, err) + + // check whether onChange is called + onChangeCalled := false + onChange := func() { + onChangeCalled = true + } + + go fs.watchFiles(onChange, watcher, tmpDir) + tmpFile, err := os.CreateTemp(tmpDir, "testfile") + assert.NoError(t, err) + defer tmpFile.Close() + + time.Sleep(500 * time.Millisecond) + watcher.Close() + time.Sleep(500 * time.Millisecond) - err := WatchFiles(func() {}) - assert.Error(t, err, "must specify at least one file to watch") + fs.mu.Lock() + _, exists := fs.WatchedFiles[tmpDir] + fs.mu.Unlock() + assert.True(t, exists, "WatchedFiles should be updated") + assert.True(t, onChangeCalled, "onChange should be called") } func TestStat(t *testing.T) { diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index aabc73d1..989f6101 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -35,7 +35,7 @@ type filter struct { config *config } -func (f *filter) reloadEnforcer() { +func reloadEnforcer(f *filter) { conf := f.config if !conf.updating.Load() { conf.updating.Store(true) @@ -48,17 +48,17 @@ func (f *filter) reloadEnforcer() { e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy) if err != nil { api.LogErrorf("failed to update Enforcer: %v", err) - // next request will retry } else { conf.lock.Lock() conf.enforcer = e conf.lock.Unlock() - err = file.Update(func() { - f.reloadEnforcer() + err = file.WatchFiles(func() { + reloadEnforcer(f) }, conf.modelFile, conf.policyFile) + if err != nil { - api.LogErrorf("failed to update Enforcer: %v", err) + api.LogErrorf("failed to watch files: %v", err) } api.LogWarnf("policy %s or model %s changed, enforcer reloaded", conf.policyFile.Name, conf.modelFile.Name) } @@ -70,7 +70,9 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api role, _ := headers.Get(conf.Token.Name) // role can be "" url := headers.Url() - err := file.WatchFiles(f.reloadEnforcer, conf.modelFile, conf.policyFile) + err := file.WatchFiles(func() { + reloadEnforcer(f) + }, conf.modelFile, conf.policyFile) if err != nil { api.LogErrorf("failed to watch files: %v", err) diff --git a/plugins/plugins/casbin/filter_test.go b/plugins/plugins/casbin/filter_test.go index f92585f6..40976d8b 100644 --- a/plugins/plugins/casbin/filter_test.go +++ b/plugins/plugins/casbin/filter_test.go @@ -16,8 +16,10 @@ package casbin import ( "net/http" + "os" "sync" "testing" + "time" "github.com/stretchr/testify/assert" @@ -78,10 +80,25 @@ func TestCasbin(t *testing.T) { f := factory(c, cb) hdr := envoy.NewRequestHeaderMap(tt.header) + // Simulate file change + go func() { + time.Sleep(500 * time.Millisecond) + // Modify the policy file + os.WriteFile("./testdata/policy.csv", []byte("p, alice, /other, GET"), 0644) + }() + + // Call the reloadEnforcer method + fTyped, ok := f.(*filter) + if !ok { + t.Fatal("Failed to convert api.Filter to *filter") + } + reloadEnforcer(fTyped) + wg := sync.WaitGroup{} for i := 0; i < 3; i++ { wg.Add(1) go func() { + // ensure the lock takes effect lr, ok := f.DecodeHeaders(hdr, true).(*api.LocalResponse) if !ok { From 99bf29d444bcf904f5f343f724bd9015115d979b Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Mon, 1 Jul 2024 00:23:57 +0800 Subject: [PATCH 06/42] format go --- plugins/pkg/file/fs.go | 14 +++++++------- plugins/pkg/file/fs_test.go | 6 ++++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 4b2e0140..ce08eeb2 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -32,8 +32,6 @@ var ( ) type File struct { - lock sync.RWMutex - Name string mtime time.Time } @@ -71,16 +69,13 @@ func WatchFiles(onChange func(), file *File, otherFiles ...*File) (err error) { } watcher := defaultFsnotify.Watcher - if err != nil { - return - } // Add files to watcher. for _, f := range files { dir := filepath.Dir(f.Name) - err = watcher.Add(dir) + err = defaultFsnotify.AddFiles(dir) if err != nil { - logger.Error(err, "add file to watcher failed") + logger.Error(err, "failed to add file") } if _, exists := defaultFsnotify.WatchedFiles[dir]; exists { logger.Info(fmt.Sprintf("File %s is already being watched", f.Name)) @@ -94,6 +89,11 @@ func WatchFiles(onChange func(), file *File, otherFiles ...*File) (err error) { return } +func (f *Fsnotify) AddFiles(dir string) (err error) { + err = f.Watcher.Add(dir) + return +} + func (f *Fsnotify) watchFiles(onChange func(), w *fsnotify.Watcher, dir string) { defer func(w *fsnotify.Watcher) { f.mu.Lock() diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 499a5e1e..c6910b58 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -39,13 +39,15 @@ func TestFileIsChanged(t *testing.T) { }(tmpfile.Name()) file := &File{Name: tmpfile.Name()} - _ = WatchFiles(func() { + err := WatchFiles(func() { wg.Add(1) defer wg.Done() defaultFsnotify.mu.Lock() i = 2 defaultFsnotify.mu.Unlock() }, file) + + assert.Nil(t, err) tmpfile.Write([]byte("bls")) tmpfile.Sync() wg.Wait() @@ -60,7 +62,7 @@ func TestFileIsChanged(t *testing.T) { tmpfile.Sync() wg.Wait() - err := WatchFiles(func() {}, nil) + err = WatchFiles(func() {}, nil) assert.Error(t, err, "file pointer cannot be nil") filename := "my_file.txt" From 6eb6b0362eecf7601635c1fb66a897d37287cb96 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sat, 6 Jul 2024 19:57:54 +0800 Subject: [PATCH 07/42] refactor codes --- plugins/pkg/file/fs.go | 8 ++--- plugins/pkg/file/fs_test.go | 12 ++++---- plugins/plugins/casbin/filter.go | 17 +++++------ plugins/plugins/casbin/filter_test.go | 37 ++++++++---------------- plugins/tests/integration/casbin_test.go | 5 ++-- 5 files changed, 32 insertions(+), 47 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index ce08eeb2..90efebf5 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -16,9 +16,7 @@ package file import ( "errors" - "fmt" "os" - "path/filepath" "sync" "time" @@ -72,13 +70,13 @@ func WatchFiles(onChange func(), file *File, otherFiles ...*File) (err error) { // Add files to watcher. for _, f := range files { - dir := filepath.Dir(f.Name) + //dir := filepath.Dir(f.Name) + dir := f.Name err = defaultFsnotify.AddFiles(dir) if err != nil { logger.Error(err, "failed to add file") } if _, exists := defaultFsnotify.WatchedFiles[dir]; exists { - logger.Info(fmt.Sprintf("File %s is already being watched", f.Name)) continue } // 添加到已监听文件的集合 @@ -111,7 +109,7 @@ func (f *Fsnotify) watchFiles(onChange func(), w *fsnotify.Watcher, dir string) if !ok { return } - logger.Info(fmt.Sprintf("event: %v", event)) + logger.Info("file changed: ", "event", event) onChange() case err, ok := <-w.Errors: if !ok { diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index c6910b58..62b244a7 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -16,8 +16,8 @@ package file import ( "fmt" - "io/ioutil" "os" + "path/filepath" "sync" "testing" "time" @@ -43,10 +43,10 @@ func TestFileIsChanged(t *testing.T) { wg.Add(1) defer wg.Done() defaultFsnotify.mu.Lock() - i = 2 + i = 5 defaultFsnotify.mu.Unlock() }, file) - + i = 6 assert.Nil(t, err) tmpfile.Write([]byte("bls")) tmpfile.Sync() @@ -94,7 +94,7 @@ func TestFileIsChanged(t *testing.T) { _ = os.Remove(filename) defaultFsnotify.mu.Lock() - assert.Equal(t, 2, i) + assert.Equal(t, 5, i) defaultFsnotify.mu.Unlock() watcher, err := fsnotify.NewWatcher() @@ -103,8 +103,8 @@ func TestFileIsChanged(t *testing.T) { fs := &Fsnotify{ WatchedFiles: make(map[string]struct{}), } - tmpDir, err := ioutil.TempDir("", "watch_test") - assert.NoError(t, err) + tmp, err := os.CreateTemp("", "watch_test") + tmpDir := filepath.Dir(tmp.Name()) defer os.RemoveAll(tmpDir) fs.WatchedFiles[tmpDir] = struct{}{} err = watcher.Add(tmpDir) diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index 989f6101..b6370d1b 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -35,7 +35,7 @@ type filter struct { config *config } -func reloadEnforcer(f *filter) { +func (f *filter) reloadEnforcer() { conf := f.config if !conf.updating.Load() { conf.updating.Store(true) @@ -44,17 +44,16 @@ func reloadEnforcer(f *filter) { go func() { defer conf.updating.Store(false) defer f.callbacks.RecoverPanic() - - e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy) + e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy, true) if err != nil { api.LogErrorf("failed to update Enforcer: %v", err) } else { conf.lock.Lock() - conf.enforcer = e + f.config.enforcer = e conf.lock.Unlock() err = file.WatchFiles(func() { - reloadEnforcer(f) + f.reloadEnforcer() }, conf.modelFile, conf.policyFile) if err != nil { @@ -69,18 +68,18 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api conf := f.config role, _ := headers.Get(conf.Token.Name) // role can be "" url := headers.Url() - err := file.WatchFiles(func() { - reloadEnforcer(f) + f.reloadEnforcer() }, conf.modelFile, conf.policyFile) - if err != nil { api.LogErrorf("failed to watch files: %v", err) return &api.LocalResponse{Code: 500} } + e, _ := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy, true) + conf.lock.RLock() - ok, err := f.config.enforcer.Enforce(role, url.Path, headers.Method()) + ok, err := e.Enforce(role, url.Path, headers.Method()) conf.lock.RUnlock() if !ok { diff --git a/plugins/plugins/casbin/filter_test.go b/plugins/plugins/casbin/filter_test.go index 40976d8b..c507f218 100644 --- a/plugins/plugins/casbin/filter_test.go +++ b/plugins/plugins/casbin/filter_test.go @@ -16,10 +16,8 @@ package casbin import ( "net/http" - "os" "sync" "testing" - "time" "github.com/stretchr/testify/assert" @@ -80,35 +78,24 @@ func TestCasbin(t *testing.T) { f := factory(c, cb) hdr := envoy.NewRequestHeaderMap(tt.header) - // Simulate file change - go func() { - time.Sleep(500 * time.Millisecond) - // Modify the policy file - os.WriteFile("./testdata/policy.csv", []byte("p, alice, /other, GET"), 0644) - }() - - // Call the reloadEnforcer method fTyped, ok := f.(*filter) if !ok { t.Fatal("Failed to convert api.Filter to *filter") } - reloadEnforcer(fTyped) + fTyped.reloadEnforcer() wg := sync.WaitGroup{} - for i := 0; i < 3; i++ { - wg.Add(1) - go func() { - - // ensure the lock takes effect - lr, ok := f.DecodeHeaders(hdr, true).(*api.LocalResponse) - if !ok { - assert.Equal(t, tt.status, 0) - } else { - assert.Equal(t, tt.status, lr.Code) - } - wg.Done() - }() - } + wg.Add(1) + go func() { + // ensure the lock takes effect + lr, ok := f.DecodeHeaders(hdr, true).(*api.LocalResponse) + if !ok { + assert.Equal(t, tt.status, 0) + } else { + assert.Equal(t, tt.status, lr.Code) + } + wg.Done() + }() wg.Wait() }) } diff --git a/plugins/tests/integration/casbin_test.go b/plugins/tests/integration/casbin_test.go index 69cfa739..4afe5ced 100644 --- a/plugins/tests/integration/casbin_test.go +++ b/plugins/tests/integration/casbin_test.go @@ -115,15 +115,16 @@ g, bob, admin }) } - time.Sleep(10 * time.Second) // TODO remove this once we switch the file change detector to inotify // configuration is not changed, but file changed err = os.WriteFile(policyFile2.Name(), []byte(policy), 0755) require.Nil(t, err) + + time.Sleep(10 * time.Second) // TODO remove this once we switch the file change detector to inotify hdr := http.Header{} hdr.Set("customer", "alice") assert.Eventually(t, func() bool { resp, _ := dp.Post("/echo", hdr, strings.NewReader("any")) return resp != nil && resp.StatusCode == 200 - }, 1*time.Second, 10*time.Millisecond) + }, 10*time.Second, 1*time.Second) } From 5c99c69db3ad87e588c59daecc8dfeb2d3981f83 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sat, 6 Jul 2024 21:53:50 +0800 Subject: [PATCH 08/42] refactor codes --- plugins/pkg/file/fs.go | 4 +-- plugins/plugins/casbin/filter.go | 43 +++++++++++++++++++++--- plugins/tests/integration/casbin_test.go | 1 - 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 90efebf5..27bb77b6 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -17,6 +17,7 @@ package file import ( "errors" "os" + "path/filepath" "sync" "time" @@ -70,8 +71,7 @@ func WatchFiles(onChange func(), file *File, otherFiles ...*File) (err error) { // Add files to watcher. for _, f := range files { - //dir := filepath.Dir(f.Name) - dir := f.Name + dir := filepath.Dir(f.Name) err = defaultFsnotify.AddFiles(dir) if err != nil { logger.Error(err, "failed to add file") diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index b6370d1b..03f8be97 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -39,12 +39,12 @@ func (f *filter) reloadEnforcer() { conf := f.config if !conf.updating.Load() { conf.updating.Store(true) - api.LogWarnf("policy %s or model %s changed, reload enforcer", conf.policyFile.Name, conf.modelFile.Name) + api.LogWarnf("policy %s or model %s Changed, reload enforcer", conf.policyFile.Name, conf.modelFile.Name) go func() { defer conf.updating.Store(false) defer f.callbacks.RecoverPanic() - e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy, true) + e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy) if err != nil { api.LogErrorf("failed to update Enforcer: %v", err) } else { @@ -64,22 +64,55 @@ func (f *filter) reloadEnforcer() { }() } } + +var Changed = false + func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction { + conf := f.config role, _ := headers.Get(conf.Token.Name) // role can be "" url := headers.Url() err := file.WatchFiles(func() { - f.reloadEnforcer() + Changed = true }, conf.modelFile, conf.policyFile) if err != nil { api.LogErrorf("failed to watch files: %v", err) return &api.LocalResponse{Code: 500} } - e, _ := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy, true) + if Changed { + if !conf.updating.Load() { + conf.updating.Store(true) + api.LogWarnf("policy %s or model %s Changed, reload enforcer", conf.policyFile.Name, conf.modelFile.Name) + + go func() { + defer conf.updating.Store(false) + defer f.callbacks.RecoverPanic() + e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy) + if err != nil { + api.LogErrorf("failed to update Enforcer: %v", err) + } else { + conf.lock.Lock() + f.config.enforcer = e + conf.lock.Unlock() + + Changed = false + err = file.WatchFiles(func() { + Changed = true + }, conf.modelFile, conf.policyFile) + + if err != nil { + api.LogErrorf("failed to watch files: %v", err) + } + + api.LogWarnf("policy %s or model %s Changed, enforcer reloaded", conf.policyFile.Name, conf.modelFile.Name) + } + }() + } + } conf.lock.RLock() - ok, err := e.Enforce(role, url.Path, headers.Method()) + ok, err := f.config.enforcer.Enforce(role, url.Path, headers.Method()) conf.lock.RUnlock() if !ok { diff --git a/plugins/tests/integration/casbin_test.go b/plugins/tests/integration/casbin_test.go index 4afe5ced..15163d62 100644 --- a/plugins/tests/integration/casbin_test.go +++ b/plugins/tests/integration/casbin_test.go @@ -119,7 +119,6 @@ g, bob, admin err = os.WriteFile(policyFile2.Name(), []byte(policy), 0755) require.Nil(t, err) - time.Sleep(10 * time.Second) // TODO remove this once we switch the file change detector to inotify hdr := http.Header{} hdr.Set("customer", "alice") From 846ff5b63c8141505421bc4e583632192f82f37b Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sat, 6 Jul 2024 23:04:15 +0800 Subject: [PATCH 09/42] refactor codes --- plugins/pkg/file/fs_test.go | 2 +- plugins/plugins/casbin/filter.go | 36 ++++----------------------- plugins/plugins/casbin/filter_test.go | 9 +++---- 3 files changed, 9 insertions(+), 38 deletions(-) diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 62b244a7..8d31b5b5 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -103,7 +103,7 @@ func TestFileIsChanged(t *testing.T) { fs := &Fsnotify{ WatchedFiles: make(map[string]struct{}), } - tmp, err := os.CreateTemp("", "watch_test") + tmp, _ := os.CreateTemp("", "watch_test") tmpDir := filepath.Dir(tmp.Name()) defer os.RemoveAll(tmpDir) fs.WatchedFiles[tmpDir] = struct{}{} diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index 03f8be97..49abbd6e 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -35,36 +35,6 @@ type filter struct { config *config } -func (f *filter) reloadEnforcer() { - conf := f.config - if !conf.updating.Load() { - conf.updating.Store(true) - api.LogWarnf("policy %s or model %s Changed, reload enforcer", conf.policyFile.Name, conf.modelFile.Name) - - go func() { - defer conf.updating.Store(false) - defer f.callbacks.RecoverPanic() - e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy) - if err != nil { - api.LogErrorf("failed to update Enforcer: %v", err) - } else { - conf.lock.Lock() - f.config.enforcer = e - conf.lock.Unlock() - - err = file.WatchFiles(func() { - f.reloadEnforcer() - }, conf.modelFile, conf.policyFile) - - if err != nil { - api.LogErrorf("failed to watch files: %v", err) - } - api.LogWarnf("policy %s or model %s changed, enforcer reloaded", conf.policyFile.Name, conf.modelFile.Name) - } - }() - } -} - var Changed = false func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction { @@ -73,7 +43,9 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api role, _ := headers.Get(conf.Token.Name) // role can be "" url := headers.Url() err := file.WatchFiles(func() { + conf.lock.Lock() Changed = true + conf.lock.Unlock() }, conf.modelFile, conf.policyFile) if err != nil { api.LogErrorf("failed to watch files: %v", err) @@ -94,11 +66,13 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api } else { conf.lock.Lock() f.config.enforcer = e + Changed = false conf.lock.Unlock() - Changed = false err = file.WatchFiles(func() { + conf.lock.Lock() Changed = true + conf.lock.Unlock() }, conf.modelFile, conf.policyFile) if err != nil { diff --git a/plugins/plugins/casbin/filter_test.go b/plugins/plugins/casbin/filter_test.go index c507f218..00a18055 100644 --- a/plugins/plugins/casbin/filter_test.go +++ b/plugins/plugins/casbin/filter_test.go @@ -78,25 +78,22 @@ func TestCasbin(t *testing.T) { f := factory(c, cb) hdr := envoy.NewRequestHeaderMap(tt.header) - fTyped, ok := f.(*filter) - if !ok { - t.Fatal("Failed to convert api.Filter to *filter") - } - fTyped.reloadEnforcer() - wg := sync.WaitGroup{} wg.Add(1) go func() { // ensure the lock takes effect lr, ok := f.DecodeHeaders(hdr, true).(*api.LocalResponse) + if !ok { assert.Equal(t, tt.status, 0) } else { assert.Equal(t, tt.status, lr.Code) + assert.False(t, Changed) } wg.Done() }() wg.Wait() }) } + } From f330ec6d0a40dbf7e4b8cba89598721c0ca1c4a1 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sat, 6 Jul 2024 23:45:37 +0800 Subject: [PATCH 10/42] refactor codes --- plugins/plugins/casbin/filter.go | 65 ++++++++++++++------------- plugins/plugins/casbin/filter_test.go | 4 ++ 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index 49abbd6e..c6a05a48 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -37,6 +37,40 @@ type filter struct { var Changed = false +func reloadEnforcer(f *filter) { + conf := f.config + if !conf.updating.Load() { + conf.updating.Store(true) + api.LogWarnf("policy %s or model %s Changed, reload enforcer", conf.policyFile.Name, conf.modelFile.Name) + + go func() { + defer conf.updating.Store(false) + defer f.callbacks.RecoverPanic() + e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy) + if err != nil { + api.LogErrorf("failed to update Enforcer: %v", err) + } else { + conf.lock.Lock() + f.config.enforcer = e + Changed = false + conf.lock.Unlock() + + err = file.WatchFiles(func() { + conf.lock.Lock() + Changed = true + conf.lock.Unlock() + }, conf.modelFile, conf.policyFile) + + if err != nil { + api.LogErrorf("failed to watch files: %v", err) + } + + api.LogWarnf("policy %s or model %s Changed, enforcer reloaded", conf.policyFile.Name, conf.modelFile.Name) + } + }() + } +} + func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction { conf := f.config @@ -53,36 +87,7 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api } if Changed { - if !conf.updating.Load() { - conf.updating.Store(true) - api.LogWarnf("policy %s or model %s Changed, reload enforcer", conf.policyFile.Name, conf.modelFile.Name) - - go func() { - defer conf.updating.Store(false) - defer f.callbacks.RecoverPanic() - e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy) - if err != nil { - api.LogErrorf("failed to update Enforcer: %v", err) - } else { - conf.lock.Lock() - f.config.enforcer = e - Changed = false - conf.lock.Unlock() - - err = file.WatchFiles(func() { - conf.lock.Lock() - Changed = true - conf.lock.Unlock() - }, conf.modelFile, conf.policyFile) - - if err != nil { - api.LogErrorf("failed to watch files: %v", err) - } - - api.LogWarnf("policy %s or model %s Changed, enforcer reloaded", conf.policyFile.Name, conf.modelFile.Name) - } - }() - } + reloadEnforcer(f) } conf.lock.RLock() diff --git a/plugins/plugins/casbin/filter_test.go b/plugins/plugins/casbin/filter_test.go index 00a18055..6b908412 100644 --- a/plugins/plugins/casbin/filter_test.go +++ b/plugins/plugins/casbin/filter_test.go @@ -78,6 +78,10 @@ func TestCasbin(t *testing.T) { f := factory(c, cb) hdr := envoy.NewRequestHeaderMap(tt.header) + ff, _ := f.(*filter) + + reloadEnforcer(ff) + wg := sync.WaitGroup{} wg.Add(1) go func() { From 29822b8203a6f01e1b4048566a09a1cfb6ac3239 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sun, 7 Jul 2024 15:27:15 +0800 Subject: [PATCH 11/42] add test codes --- plugins/go.mod | 1 + plugins/go.sum | 2 + plugins/pkg/file/fs.go | 12 +--- plugins/pkg/file/fs_test.go | 120 +++++++++++++------------------ plugins/plugins/casbin/filter.go | 1 - 5 files changed, 54 insertions(+), 82 deletions(-) diff --git a/plugins/go.mod b/plugins/go.mod index accb88bc..a06ab6fa 100644 --- a/plugins/go.mod +++ b/plugins/go.mod @@ -72,6 +72,7 @@ require ( github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tchap/go-patricia/v2 v2.3.1 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect diff --git a/plugins/go.sum b/plugins/go.sum index f3bd6564..a4bf2fcc 100644 --- a/plugins/go.sum +++ b/plugins/go.sum @@ -145,6 +145,8 @@ github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9 github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 27bb77b6..0b2dc58b 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -16,7 +16,6 @@ package file import ( "errors" - "os" "path/filepath" "sync" "time" @@ -52,7 +51,6 @@ func newFsnotify() (fs *Fsnotify) { Watcher: watcher, WatchedFiles: make(map[string]struct{}), } - } var ( @@ -79,7 +77,7 @@ func WatchFiles(onChange func(), file *File, otherFiles ...*File) (err error) { if _, exists := defaultFsnotify.WatchedFiles[dir]; exists { continue } - // 添加到已监听文件的集合 + //Add dir to the watched files defaultFsnotify.WatchedFiles[dir] = struct{}{} go defaultFsnotify.watchFiles(onChange, watcher, dir) } @@ -121,14 +119,8 @@ func (f *Fsnotify) watchFiles(onChange func(), w *fsnotify.Watcher, dir string) } func (f *Fsnotify) Stat(path string) (*File, error) { - info, err := os.Stat(path) - if err != nil { - return nil, err - } - return &File{ - Name: path, - mtime: info.ModTime(), + Name: path, }, nil } diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 8d31b5b5..25175a31 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -15,7 +15,6 @@ package file import ( - "fmt" "os" "path/filepath" "sync" @@ -24,8 +23,19 @@ import ( "github.com/fsnotify/fsnotify" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) +// MockWatcher is a mock implementation of fsnotify.Watcher +type MockWatcher struct { + mock.Mock +} + +func (m *MockWatcher) Close() error { + args := m.Called() + return args.Error(0) +} + func TestFileIsChanged(t *testing.T) { var wg sync.WaitGroup i := 4 @@ -37,62 +47,25 @@ func TestFileIsChanged(t *testing.T) { t.Logf("%v", err) } }(tmpfile.Name()) + file, err := Stat(tmpfile.Name()) - file := &File{Name: tmpfile.Name()} - err := WatchFiles(func() { + assert.NoError(t, err) + assert.Equal(t, tmpfile.Name(), file.Name) + err = WatchFiles(func() { wg.Add(1) defer wg.Done() defaultFsnotify.mu.Lock() i = 5 defaultFsnotify.mu.Unlock() }, file) - i = 6 assert.Nil(t, err) tmpfile.Write([]byte("bls")) tmpfile.Sync() wg.Wait() - _ = WatchFiles(func() { - wg.Add(1) - defer wg.Done() - defaultFsnotify.mu.Lock() - i = 1 - defaultFsnotify.mu.Unlock() - }, file) - tmpfile.Sync() - wg.Wait() - err = WatchFiles(func() {}, nil) assert.Error(t, err, "file pointer cannot be nil") - filename := "my_file.txt" - content := "Hello, World!" - - f, err := os.Create(filename) - fi := &File{Name: f.Name()} - if err != nil { - fmt.Println("Error creating file:", err) - return - } - defer f.Close() - - _ = WatchFiles(func() { - wg.Add(1) - defer wg.Done() - defaultFsnotify.mu.Lock() - i = 3 - defaultFsnotify.mu.Unlock() - }, fi) - _, _ = f.WriteString(content) - - _ = os.Remove(filename) - f, _ = os.Create(filename) - - defer f.Close() - - _, _ = f.WriteString("New content for the file.") - _ = os.Remove(filename) - defaultFsnotify.mu.Lock() assert.Equal(t, 5, i) defaultFsnotify.mu.Unlock() @@ -103,9 +76,7 @@ func TestFileIsChanged(t *testing.T) { fs := &Fsnotify{ WatchedFiles: make(map[string]struct{}), } - tmp, _ := os.CreateTemp("", "watch_test") - tmpDir := filepath.Dir(tmp.Name()) - defer os.RemoveAll(tmpDir) + tmpDir := filepath.Dir(file.Name) fs.WatchedFiles[tmpDir] = struct{}{} err = watcher.Add(tmpDir) assert.NoError(t, err) @@ -125,37 +96,44 @@ func TestFileIsChanged(t *testing.T) { watcher.Close() time.Sleep(500 * time.Millisecond) - fs.mu.Lock() _, exists := fs.WatchedFiles[tmpDir] - fs.mu.Unlock() - assert.True(t, exists, "WatchedFiles should be updated") - assert.True(t, onChangeCalled, "onChange should be called") -} - -func TestStat(t *testing.T) { - tmpfile, err := os.CreateTemp("", "example") - if err != nil { - t.Fatalf("Failed to create temp file: %v", err) - } - defer os.Remove(tmpfile.Name()) - if _, err := tmpfile.Write([]byte("hello world")); err != nil { - t.Fatalf("Failed to write to temp file: %v", err) - } + assert.True(t, exists) + assert.True(t, onChangeCalled) - if err := tmpfile.Close(); err != nil { - t.Fatalf("Failed to close temp file: %v", err) - } + err = WatchFiles(func() {}, file, nil) + assert.Error(t, err, "file pointer cannot be nil") +} - statFile, err := Stat(tmpfile.Name()) - assert.NoError(t, err, "Stat() should not return error") +func TestClose(t *testing.T) { + dir := "./" + mockWatcher := new(MockWatcher) - assert.Equal(t, tmpfile.Name(), statFile.Name, "Stat() Name should match") - assert.False(t, statFile.mtime.IsZero(), "Stat() mtime should be non-zero") + mockWatcher.On("Close").Return(nil) - nonExistentFilePath := "./nonexistentfile.txt" - _, err = Stat(nonExistentFilePath) + defaultfsnotify := struct { + WatchedFiles map[string]bool + }{ + WatchedFiles: map[string]bool{dir: true}, + } - assert.Error(t, err, "Stat should return error for non-existent file") - assert.True(t, os.IsNotExist(err), "Error should indicate non-existent file") + f := struct { + mu sync.Mutex + }{} + + func(w *MockWatcher) { + defer func(w *MockWatcher) { + f.mu.Lock() + defer f.mu.Unlock() + delete(defaultfsnotify.WatchedFiles, dir) + err := w.Close() + if err != nil { + t.Errorf("failed to close fsnotify watcher: %v", err) + } + }(w) + }(mockWatcher) + + assert.NotContains(t, defaultFsnotify.WatchedFiles, dir) + + mockWatcher.AssertExpectations(t) } diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index c6a05a48..6bd68b9c 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -72,7 +72,6 @@ func reloadEnforcer(f *filter) { } func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction { - conf := f.config role, _ := headers.Get(conf.Token.Name) // role can be "" url := headers.Url() From a7bf59d005f86dfa832d548ed5755d2ca6921375 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sun, 7 Jul 2024 19:59:05 +0800 Subject: [PATCH 12/42] add test codes --- plugins/pkg/file/fs.go | 7 ++----- plugins/pkg/file/fs_test.go | 27 +++++++++++++++------------ plugins/plugins/casbin/filter.go | 8 +++++++- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 0b2dc58b..a02e4d1f 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -16,11 +16,9 @@ package file import ( "errors" + "github.com/fsnotify/fsnotify" "path/filepath" "sync" - "time" - - "github.com/fsnotify/fsnotify" "mosn.io/htnn/api/pkg/log" ) @@ -30,8 +28,7 @@ var ( ) type File struct { - Name string - mtime time.Time + Name string } type Fsnotify struct { diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 25175a31..8c984424 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -15,15 +15,14 @@ package file import ( + "github.com/fsnotify/fsnotify" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "os" "path/filepath" "sync" "testing" "time" - - "github.com/fsnotify/fsnotify" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) // MockWatcher is a mock implementation of fsnotify.Watcher @@ -38,6 +37,7 @@ func (m *MockWatcher) Close() error { func TestFileIsChanged(t *testing.T) { var wg sync.WaitGroup + var mu sync.Mutex i := 4 tmpfile, _ := os.CreateTemp("./", "example") @@ -54,9 +54,9 @@ func TestFileIsChanged(t *testing.T) { err = WatchFiles(func() { wg.Add(1) defer wg.Done() - defaultFsnotify.mu.Lock() + mu.Lock() i = 5 - defaultFsnotify.mu.Unlock() + mu.Unlock() }, file) assert.Nil(t, err) tmpfile.Write([]byte("bls")) @@ -66,31 +66,34 @@ func TestFileIsChanged(t *testing.T) { err = WatchFiles(func() {}, nil) assert.Error(t, err, "file pointer cannot be nil") - defaultFsnotify.mu.Lock() + mu.Lock() assert.Equal(t, 5, i) - defaultFsnotify.mu.Unlock() + mu.Unlock() watcher, err := fsnotify.NewWatcher() assert.NoError(t, err) defer watcher.Close() fs := &Fsnotify{ WatchedFiles: make(map[string]struct{}), + Watcher: watcher, } tmpDir := filepath.Dir(file.Name) fs.WatchedFiles[tmpDir] = struct{}{} - err = watcher.Add(tmpDir) + err = fs.AddFiles(tmpDir) assert.NoError(t, err) - // check whether onChange is called onChangeCalled := false onChange := func() { onChangeCalled = true } - go fs.watchFiles(onChange, watcher, tmpDir) + go fs.watchFiles(onChange, fs.Watcher, tmpDir) tmpFile, err := os.CreateTemp(tmpDir, "testfile") assert.NoError(t, err) - defer tmpFile.Close() + defer func() { + tmpFile.Close() + os.Remove(tmpfile.Name()) + }() time.Sleep(500 * time.Millisecond) watcher.Close() diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index 6bd68b9c..25752428 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -40,11 +40,17 @@ var Changed = false func reloadEnforcer(f *filter) { conf := f.config if !conf.updating.Load() { + conf.lock.Lock() conf.updating.Store(true) + conf.lock.Unlock() api.LogWarnf("policy %s or model %s Changed, reload enforcer", conf.policyFile.Name, conf.modelFile.Name) go func() { - defer conf.updating.Store(false) + defer func() { + conf.lock.Lock() + conf.updating.Store(false) + conf.lock.Unlock() + }() defer f.callbacks.RecoverPanic() e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy) if err != nil { From 671958691d0ec091c5bde38bfc5fdbfd5ba37538 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sun, 7 Jul 2024 21:36:11 +0800 Subject: [PATCH 13/42] add test codes --- plugins/pkg/file/fs.go | 3 ++- plugins/pkg/file/fs_test.go | 20 +++++------------- plugins/plugins/casbin/filter_test.go | 29 +++++++++++++++++++++++---- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index a02e4d1f..c8220238 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -16,10 +16,11 @@ package file import ( "errors" - "github.com/fsnotify/fsnotify" "path/filepath" "sync" + "github.com/fsnotify/fsnotify" + "mosn.io/htnn/api/pkg/log" ) diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 8c984424..f0a96978 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -15,14 +15,14 @@ package file import ( - "github.com/fsnotify/fsnotify" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "os" "path/filepath" "sync" "testing" - "time" + + "github.com/fsnotify/fsnotify" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) // MockWatcher is a mock implementation of fsnotify.Watcher @@ -88,21 +88,11 @@ func TestFileIsChanged(t *testing.T) { } go fs.watchFiles(onChange, fs.Watcher, tmpDir) - tmpFile, err := os.CreateTemp(tmpDir, "testfile") - assert.NoError(t, err) - defer func() { - tmpFile.Close() - os.Remove(tmpfile.Name()) - }() - - time.Sleep(500 * time.Millisecond) - watcher.Close() - time.Sleep(500 * time.Millisecond) _, exists := fs.WatchedFiles[tmpDir] assert.True(t, exists) - assert.True(t, onChangeCalled) + assert.False(t, onChangeCalled) err = WatchFiles(func() {}, file, nil) assert.Error(t, err, "file pointer cannot be nil") diff --git a/plugins/plugins/casbin/filter_test.go b/plugins/plugins/casbin/filter_test.go index 6b908412..f4aff5ef 100644 --- a/plugins/plugins/casbin/filter_test.go +++ b/plugins/plugins/casbin/filter_test.go @@ -18,6 +18,7 @@ import ( "net/http" "sync" "testing" + "time" "github.com/stretchr/testify/assert" @@ -78,10 +79,6 @@ func TestCasbin(t *testing.T) { f := factory(c, cb) hdr := envoy.NewRequestHeaderMap(tt.header) - ff, _ := f.(*filter) - - reloadEnforcer(ff) - wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -99,5 +96,29 @@ func TestCasbin(t *testing.T) { wg.Wait() }) } +} + +func TestReloadEnforcer(t *testing.T) { + cb := envoy.NewFilterCallbackHandler() + c := &config{ + Config: casbin.Config{ + Rule: &casbin.Config_Rule{ + Model: "./testdata/model.conf", + Policy: "./testdata/policy.csv", + }, + Token: &casbin.Config_Token{ + Name: "user", + }, + }, + } + c.Init(nil) + f := factory(c, cb) + + Changed = true + ff, _ := f.(*filter) + reloadEnforcer(ff) + time.Sleep(2 * time.Second) + + assert.False(t, Changed) } From daf2c244862d0a42e871c629239e9da04f2fc993 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sun, 7 Jul 2024 22:34:48 +0800 Subject: [PATCH 14/42] add test codes --- plugins/pkg/file/fs_test.go | 51 +++----------------------------- plugins/plugins/casbin/filter.go | 4 ++- 2 files changed, 7 insertions(+), 48 deletions(-) diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index f0a96978..33f2acdd 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -22,19 +22,8 @@ import ( "github.com/fsnotify/fsnotify" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) -// MockWatcher is a mock implementation of fsnotify.Watcher -type MockWatcher struct { - mock.Mock -} - -func (m *MockWatcher) Close() error { - args := m.Called() - return args.Error(0) -} - func TestFileIsChanged(t *testing.T) { var wg sync.WaitGroup var mu sync.Mutex @@ -77,7 +66,10 @@ func TestFileIsChanged(t *testing.T) { WatchedFiles: make(map[string]struct{}), Watcher: watcher, } - tmpDir := filepath.Dir(file.Name) + tmpfile, err = os.CreateTemp("/tmp", "test") + assert.Nil(t, err) + defer os.Remove(tmpfile.Name()) + tmpDir := filepath.Dir(tmpfile.Name()) fs.WatchedFiles[tmpDir] = struct{}{} err = fs.AddFiles(tmpDir) assert.NoError(t, err) @@ -94,39 +86,4 @@ func TestFileIsChanged(t *testing.T) { assert.True(t, exists) assert.False(t, onChangeCalled) - err = WatchFiles(func() {}, file, nil) - assert.Error(t, err, "file pointer cannot be nil") -} - -func TestClose(t *testing.T) { - dir := "./" - mockWatcher := new(MockWatcher) - - mockWatcher.On("Close").Return(nil) - - defaultfsnotify := struct { - WatchedFiles map[string]bool - }{ - WatchedFiles: map[string]bool{dir: true}, - } - - f := struct { - mu sync.Mutex - }{} - - func(w *MockWatcher) { - defer func(w *MockWatcher) { - f.mu.Lock() - defer f.mu.Unlock() - delete(defaultfsnotify.WatchedFiles, dir) - err := w.Close() - if err != nil { - t.Errorf("failed to close fsnotify watcher: %v", err) - } - }(w) - }(mockWatcher) - - assert.NotContains(t, defaultFsnotify.WatchedFiles, dir) - - mockWatcher.AssertExpectations(t) } diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index 25752428..f0efdcde 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -39,8 +39,8 @@ var Changed = false func reloadEnforcer(f *filter) { conf := f.config + conf.lock.Lock() if !conf.updating.Load() { - conf.lock.Lock() conf.updating.Store(true) conf.lock.Unlock() api.LogWarnf("policy %s or model %s Changed, reload enforcer", conf.policyFile.Name, conf.modelFile.Name) @@ -74,6 +74,8 @@ func reloadEnforcer(f *filter) { api.LogWarnf("policy %s or model %s Changed, enforcer reloaded", conf.policyFile.Name, conf.modelFile.Name) } }() + } else { + conf.lock.Unlock() } } From 7120888529a2e31f63d38c5e0ac8e14118c9238f Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sun, 7 Jul 2024 22:47:07 +0800 Subject: [PATCH 15/42] add test codes --- plugins/go.mod | 1 - plugins/go.sum | 2 -- plugins/plugins/casbin/filter_test.go | 5 +++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/plugins/go.mod b/plugins/go.mod index a06ab6fa..accb88bc 100644 --- a/plugins/go.mod +++ b/plugins/go.mod @@ -72,7 +72,6 @@ require ( github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect - github.com/stretchr/objx v0.5.2 // indirect github.com/tchap/go-patricia/v2 v2.3.1 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect diff --git a/plugins/go.sum b/plugins/go.sum index a4bf2fcc..f3bd6564 100644 --- a/plugins/go.sum +++ b/plugins/go.sum @@ -145,8 +145,6 @@ github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9 github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/plugins/plugins/casbin/filter_test.go b/plugins/plugins/casbin/filter_test.go index f4aff5ef..2ad560da 100644 --- a/plugins/plugins/casbin/filter_test.go +++ b/plugins/plugins/casbin/filter_test.go @@ -115,8 +115,9 @@ func TestReloadEnforcer(t *testing.T) { f := factory(c, cb) Changed = true - ff, _ := f.(*filter) - reloadEnforcer(ff) + header := http.Header{":path": []string{"/other"}} + hdr := envoy.NewRequestHeaderMap(header) + f.DecodeHeaders(hdr, true) time.Sleep(2 * time.Second) assert.False(t, Changed) From c2ffc2e05d34d5de25cbc0f7aae9abff8bc0bab6 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Mon, 8 Jul 2024 19:31:01 +0800 Subject: [PATCH 16/42] fix: avoid data race --- plugins/plugins/casbin/filter.go | 28 ++++++++++++++---------- plugins/plugins/casbin/filter_test.go | 3 +++ plugins/tests/integration/casbin_test.go | 3 +++ 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index f0efdcde..69e3b191 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -15,6 +15,8 @@ package casbin import ( + "sync" + "github.com/casbin/casbin/v2" "mosn.io/htnn/api/pkg/filtermanager/api" @@ -35,14 +37,15 @@ type filter struct { config *config } -var Changed = false +var ( + Changed = false + ChangedMu sync.RWMutex +) func reloadEnforcer(f *filter) { conf := f.config - conf.lock.Lock() if !conf.updating.Load() { conf.updating.Store(true) - conf.lock.Unlock() api.LogWarnf("policy %s or model %s Changed, reload enforcer", conf.policyFile.Name, conf.modelFile.Name) go func() { @@ -56,15 +59,15 @@ func reloadEnforcer(f *filter) { if err != nil { api.LogErrorf("failed to update Enforcer: %v", err) } else { - conf.lock.Lock() + ChangedMu.Lock() f.config.enforcer = e Changed = false - conf.lock.Unlock() + ChangedMu.Unlock() err = file.WatchFiles(func() { - conf.lock.Lock() + ChangedMu.Lock() Changed = true - conf.lock.Unlock() + ChangedMu.Unlock() }, conf.modelFile, conf.policyFile) if err != nil { @@ -74,26 +77,27 @@ func reloadEnforcer(f *filter) { api.LogWarnf("policy %s or model %s Changed, enforcer reloaded", conf.policyFile.Name, conf.modelFile.Name) } }() - } else { - conf.lock.Unlock() } } func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction { conf := f.config + ChangedMu.Lock() + isChanged := Changed + ChangedMu.Unlock() role, _ := headers.Get(conf.Token.Name) // role can be "" url := headers.Url() err := file.WatchFiles(func() { - conf.lock.Lock() + ChangedMu.Lock() Changed = true - conf.lock.Unlock() + ChangedMu.Unlock() }, conf.modelFile, conf.policyFile) if err != nil { api.LogErrorf("failed to watch files: %v", err) return &api.LocalResponse{Code: 500} } - if Changed { + if isChanged { reloadEnforcer(f) } diff --git a/plugins/plugins/casbin/filter_test.go b/plugins/plugins/casbin/filter_test.go index 2ad560da..da5f4e64 100644 --- a/plugins/plugins/casbin/filter_test.go +++ b/plugins/plugins/casbin/filter_test.go @@ -115,11 +115,14 @@ func TestReloadEnforcer(t *testing.T) { f := factory(c, cb) Changed = true + header := http.Header{":path": []string{"/other"}} hdr := envoy.NewRequestHeaderMap(header) f.DecodeHeaders(hdr, true) time.Sleep(2 * time.Second) + ChangedMu.Lock() assert.False(t, Changed) + ChangedMu.Unlock() } diff --git a/plugins/tests/integration/casbin_test.go b/plugins/tests/integration/casbin_test.go index 15163d62..52da25cc 100644 --- a/plugins/tests/integration/casbin_test.go +++ b/plugins/tests/integration/casbin_test.go @@ -119,6 +119,9 @@ g, bob, admin err = os.WriteFile(policyFile2.Name(), []byte(policy), 0755) require.Nil(t, err) + //wait to run reloadEnforcer + time.Sleep(5 * time.Second) + hdr := http.Header{} hdr.Set("customer", "alice") From 46d90e6f6e40ec48b8b9813cd9ecb01fb3113c3d Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Tue, 9 Jul 2024 23:47:29 +0800 Subject: [PATCH 17/42] refactor codes --- plugins/pkg/file/fs.go | 95 ++++++++---------------- plugins/pkg/file/fs_test.go | 60 ++++++--------- plugins/plugins/casbin/config.go | 61 ++++++++++++++- plugins/plugins/casbin/filter.go | 58 +-------------- plugins/plugins/casbin/filter_test.go | 32 ++++---- plugins/tests/integration/casbin_test.go | 2 +- 6 files changed, 133 insertions(+), 175 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index c8220238..27169ab3 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -25,37 +25,17 @@ import ( ) var ( - logger = log.DefaultLogger.WithName("file") + logger = log.DefaultLogger.WithName("file") + WatchedFiles = make(map[string]struct{}) ) type File struct { - Name string + Name string + Watcher *fsnotify.Watcher + mu sync.RWMutex } -type Fsnotify struct { - mu sync.Mutex - Watcher *fsnotify.Watcher - WatchedFiles map[string]struct{} -} - -func newFsnotify() (fs *Fsnotify) { - watcher, err := fsnotify.NewWatcher() - if err != nil { - logger.Error(err, "create watcher failed") - return - } - - return &Fsnotify{ - Watcher: watcher, - WatchedFiles: make(map[string]struct{}), - } -} - -var ( - defaultFsnotify = newFsnotify() -) - -func WatchFiles(onChange func(), file *File, otherFiles ...*File) (err error) { +func WatchFiles(onChanged func(), file *File, otherFiles ...*File) (err error) { files := append([]*File{file}, otherFiles...) for _, f := range files { if f == nil { @@ -63,51 +43,32 @@ func WatchFiles(onChange func(), file *File, otherFiles ...*File) (err error) { } } - watcher := defaultFsnotify.Watcher - // Add files to watcher. for _, f := range files { - dir := filepath.Dir(f.Name) - err = defaultFsnotify.AddFiles(dir) - if err != nil { - logger.Error(err, "failed to add file") - } - if _, exists := defaultFsnotify.WatchedFiles[dir]; exists { - continue - } - //Add dir to the watched files - defaultFsnotify.WatchedFiles[dir] = struct{}{} - go defaultFsnotify.watchFiles(onChange, watcher, dir) + go watchFiles(onChanged, f) } return } -func (f *Fsnotify) AddFiles(dir string) (err error) { - err = f.Watcher.Add(dir) - return -} +func watchFiles(onChanged func(), file *File) { + dir := filepath.Dir(file.Name) + defer func() { + file.mu.Lock() + delete(WatchedFiles, dir) + file.mu.Unlock() -func (f *Fsnotify) watchFiles(onChange func(), w *fsnotify.Watcher, dir string) { - defer func(w *fsnotify.Watcher) { - f.mu.Lock() - delete(defaultFsnotify.WatchedFiles, dir) - f.mu.Unlock() - err := w.Close() - if err != nil { - logger.Error(err, "failed to close fsnotify watcher") - } - }(w) + }() for { select { - case event, ok := <-w.Events: + case event, ok := <-file.Watcher.Events: if !ok { return } logger.Info("file changed: ", "event", event) - onChange() - case err, ok := <-w.Errors: + onChanged() + case err, ok := <-file.Watcher.Errors: if !ok { return } @@ -116,12 +77,20 @@ func (f *Fsnotify) watchFiles(onChange func(), w *fsnotify.Watcher, dir string) } } -func (f *Fsnotify) Stat(path string) (*File, error) { - return &File{ - Name: path, - }, nil +func AddFiles(file string, w *fsnotify.Watcher) (err error) { + dir := filepath.Dir(file) + if _, exists := WatchedFiles[dir]; exists { + return + } + WatchedFiles[dir] = struct{}{} + err = w.Add(dir) + return } - -func Stat(path string) (*File, error) { - return defaultFsnotify.Stat(path) +func Stat(file string, w *fsnotify.Watcher) (*File, error) { + err := AddFiles(file, w) + return &File{ + Name: file, + Watcher: w, + mu: sync.RWMutex{}, + }, err } diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 33f2acdd..a99b2ec5 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -24,27 +24,33 @@ import ( "github.com/stretchr/testify/assert" ) +var ( + wg sync.WaitGroup + mu sync.Mutex +) + func TestFileIsChanged(t *testing.T) { - var wg sync.WaitGroup + changed := false var mu sync.Mutex - i := 4 + watcher, err := fsnotify.NewWatcher() + defer watcher.Close() + + assert.Nil(t, err) tmpfile, _ := os.CreateTemp("./", "example") - defer func(name string) { - err := os.Remove(name) - if err != nil { - t.Logf("%v", err) - } - }(tmpfile.Name()) - file, err := Stat(tmpfile.Name()) + + file, err := Stat(tmpfile.Name(), watcher) assert.NoError(t, err) assert.Equal(t, tmpfile.Name(), file.Name) + + tmpDir := filepath.Dir(tmpfile.Name()) + _, exists := WatchedFiles[tmpDir] + assert.True(t, exists) + err = WatchFiles(func() { - wg.Add(1) - defer wg.Done() mu.Lock() - i = 5 + changed = true mu.Unlock() }, file) assert.Nil(t, err) @@ -53,37 +59,13 @@ func TestFileIsChanged(t *testing.T) { wg.Wait() err = WatchFiles(func() {}, nil) + assert.Error(t, err, "file pointer cannot be nil") mu.Lock() - assert.Equal(t, 5, i) + assert.True(t, changed) mu.Unlock() - watcher, err := fsnotify.NewWatcher() - assert.NoError(t, err) - defer watcher.Close() - fs := &Fsnotify{ - WatchedFiles: make(map[string]struct{}), - Watcher: watcher, - } - tmpfile, err = os.CreateTemp("/tmp", "test") + err = os.Remove(tmpfile.Name()) assert.Nil(t, err) - defer os.Remove(tmpfile.Name()) - tmpDir := filepath.Dir(tmpfile.Name()) - fs.WatchedFiles[tmpDir] = struct{}{} - err = fs.AddFiles(tmpDir) - assert.NoError(t, err) - - onChangeCalled := false - onChange := func() { - onChangeCalled = true - } - - go fs.watchFiles(onChange, fs.Watcher, tmpDir) - - _, exists := fs.WatchedFiles[tmpDir] - - assert.True(t, exists) - assert.False(t, onChangeCalled) - } diff --git a/plugins/plugins/casbin/config.go b/plugins/plugins/casbin/config.go index 67d3599c..4e1ba9e6 100644 --- a/plugins/plugins/casbin/config.go +++ b/plugins/plugins/casbin/config.go @@ -15,10 +15,12 @@ package casbin import ( + "runtime" "sync" "sync/atomic" "github.com/casbin/casbin/v2" + "github.com/fsnotify/fsnotify" "mosn.io/htnn/api/pkg/filtermanager/api" "mosn.io/htnn/api/pkg/plugins" @@ -51,18 +53,27 @@ type config struct { modelFile *file.File policyFile *file.File updating atomic.Bool + + watcher *fsnotify.Watcher } func (conf *config) Init(cb api.ConfigCallbackHandler) error { conf.lock = &sync.RWMutex{} - f, err := file.Stat(conf.Rule.Model) + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + + conf.watcher = watcher + + f, err := file.Stat(conf.Rule.Model, watcher) if err != nil { return err } conf.modelFile = f - f, err = file.Stat(conf.Rule.Policy) + f, err = file.Stat(conf.Rule.Policy, watcher) if err != nil { return err } @@ -73,5 +84,51 @@ func (conf *config) Init(cb api.ConfigCallbackHandler) error { return err } conf.enforcer = e + + runtime.SetFinalizer(conf, func(conf *config) { + err := conf.watcher.Close() + if err != nil { + api.LogErrorf("failed to close watcher, err: %v", err) + } + }) return nil } + +func (conf *config) reloadEnforcer() { + if !conf.updating.Load() { + conf.updating.Store(true) + api.LogWarnf("policy %s or model %s changed, reload enforcer", conf.policyFile.Name, conf.modelFile.Name) + + go func() { + defer func() { + if r := recover(); r != nil { + api.LogErrorf("recovered from panic: %v", r) + } + conf.updating.Store(false) + }() + e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy) + if err != nil { + api.LogErrorf("failed to update Enforcer: %v", err) + } else { + conf.SetChanged(false) + conf.lock.Lock() + conf.enforcer = e + conf.lock.Unlock() + api.LogWarnf("policy %s or model %s changed, enforcer reloaded", conf.policyFile.Name, conf.modelFile.Name) + } + }() + } +} + +func (conf *config) SetChanged(change bool) { + conf.lock.Lock() + Changed = change + conf.lock.Unlock() +} + +func (conf *config) GetChanged() bool { + conf.lock.RLock() + changed := Changed + conf.lock.RUnlock() + return changed +} diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index 69e3b191..54131ce0 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -15,10 +15,6 @@ package casbin import ( - "sync" - - "github.com/casbin/casbin/v2" - "mosn.io/htnn/api/pkg/filtermanager/api" "mosn.io/htnn/plugins/pkg/file" ) @@ -37,68 +33,22 @@ type filter struct { config *config } -var ( - Changed = false - ChangedMu sync.RWMutex -) - -func reloadEnforcer(f *filter) { - conf := f.config - if !conf.updating.Load() { - conf.updating.Store(true) - api.LogWarnf("policy %s or model %s Changed, reload enforcer", conf.policyFile.Name, conf.modelFile.Name) - - go func() { - defer func() { - conf.lock.Lock() - conf.updating.Store(false) - conf.lock.Unlock() - }() - defer f.callbacks.RecoverPanic() - e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy) - if err != nil { - api.LogErrorf("failed to update Enforcer: %v", err) - } else { - ChangedMu.Lock() - f.config.enforcer = e - Changed = false - ChangedMu.Unlock() - - err = file.WatchFiles(func() { - ChangedMu.Lock() - Changed = true - ChangedMu.Unlock() - }, conf.modelFile, conf.policyFile) - - if err != nil { - api.LogErrorf("failed to watch files: %v", err) - } - - api.LogWarnf("policy %s or model %s Changed, enforcer reloaded", conf.policyFile.Name, conf.modelFile.Name) - } - }() - } -} +var Changed = false func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction { conf := f.config - ChangedMu.Lock() - isChanged := Changed - ChangedMu.Unlock() role, _ := headers.Get(conf.Token.Name) // role can be "" url := headers.Url() err := file.WatchFiles(func() { - ChangedMu.Lock() - Changed = true - ChangedMu.Unlock() + conf.SetChanged(true) }, conf.modelFile, conf.policyFile) if err != nil { api.LogErrorf("failed to watch files: %v", err) return &api.LocalResponse{Code: 500} } - if isChanged { - reloadEnforcer(f) + if conf.GetChanged() { + conf.reloadEnforcer() } conf.lock.RLock() diff --git a/plugins/plugins/casbin/filter_test.go b/plugins/plugins/casbin/filter_test.go index da5f4e64..115b6089 100644 --- a/plugins/plugins/casbin/filter_test.go +++ b/plugins/plugins/casbin/filter_test.go @@ -80,19 +80,21 @@ func TestCasbin(t *testing.T) { hdr := envoy.NewRequestHeaderMap(tt.header) wg := sync.WaitGroup{} - wg.Add(1) - go func() { - // ensure the lock takes effect - lr, ok := f.DecodeHeaders(hdr, true).(*api.LocalResponse) + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + // ensure the lock takes effect + lr, ok := f.DecodeHeaders(hdr, true).(*api.LocalResponse) - if !ok { - assert.Equal(t, tt.status, 0) - } else { - assert.Equal(t, tt.status, lr.Code) - assert.False(t, Changed) - } - wg.Done() - }() + if !ok { + assert.Equal(t, tt.status, 0) + } else { + assert.Equal(t, tt.status, lr.Code) + assert.False(t, Changed) + } + wg.Done() + }() + } wg.Wait() }) } @@ -114,15 +116,13 @@ func TestReloadEnforcer(t *testing.T) { c.Init(nil) f := factory(c, cb) - Changed = true + c.SetChanged(true) header := http.Header{":path": []string{"/other"}} hdr := envoy.NewRequestHeaderMap(header) f.DecodeHeaders(hdr, true) time.Sleep(2 * time.Second) - ChangedMu.Lock() - assert.False(t, Changed) - ChangedMu.Unlock() + assert.False(t, c.GetChanged()) } diff --git a/plugins/tests/integration/casbin_test.go b/plugins/tests/integration/casbin_test.go index 52da25cc..a53b6986 100644 --- a/plugins/tests/integration/casbin_test.go +++ b/plugins/tests/integration/casbin_test.go @@ -128,5 +128,5 @@ g, bob, admin assert.Eventually(t, func() bool { resp, _ := dp.Post("/echo", hdr, strings.NewReader("any")) return resp != nil && resp.StatusCode == 200 - }, 10*time.Second, 1*time.Second) + }, 3*time.Second, 1*time.Second) } From 2b01940b9770fdf49b7453090e76c432949001fe Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Tue, 9 Jul 2024 23:53:15 +0800 Subject: [PATCH 18/42] fix golangci-lint error --- plugins/pkg/file/fs_test.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index a99b2ec5..9958f057 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -26,14 +26,18 @@ import ( var ( wg sync.WaitGroup - mu sync.Mutex ) func TestFileIsChanged(t *testing.T) { changed := false var mu sync.Mutex watcher, err := fsnotify.NewWatcher() - defer watcher.Close() + defer func(watcher *fsnotify.Watcher) { + err := watcher.Close() + if err != nil { + t.Errorf("close watcher err:%v", err) + } + }(watcher) assert.Nil(t, err) From c1522d2692f5c018fd903d114d18ce35d1cf0cd5 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Wed, 10 Jul 2024 00:09:21 +0800 Subject: [PATCH 19/42] add test codes --- plugins/plugins/casbin/config_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/plugins/plugins/casbin/config_test.go b/plugins/plugins/casbin/config_test.go index b7d646ca..fbb45198 100644 --- a/plugins/plugins/casbin/config_test.go +++ b/plugins/plugins/casbin/config_test.go @@ -15,6 +15,7 @@ package casbin import ( + "sync" "testing" "github.com/stretchr/testify/assert" @@ -60,3 +61,11 @@ func TestBadConfig(t *testing.T) { }) } } + +func TestChanged(t *testing.T) { + conf := &config{ + lock: &sync.RWMutex{}, + } + conf.SetChanged(true) + assert.True(t, conf.GetChanged()) +} From 4c686e21bc15ea07ee2f000c55849ee08da8fdbd Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Wed, 10 Jul 2024 14:13:49 +0800 Subject: [PATCH 20/42] refactor codes --- plugins/plugins/casbin/filter_test.go | 28 --------------------------- 1 file changed, 28 deletions(-) diff --git a/plugins/plugins/casbin/filter_test.go b/plugins/plugins/casbin/filter_test.go index 115b6089..6218cc43 100644 --- a/plugins/plugins/casbin/filter_test.go +++ b/plugins/plugins/casbin/filter_test.go @@ -18,7 +18,6 @@ import ( "net/http" "sync" "testing" - "time" "github.com/stretchr/testify/assert" @@ -99,30 +98,3 @@ func TestCasbin(t *testing.T) { }) } } - -func TestReloadEnforcer(t *testing.T) { - cb := envoy.NewFilterCallbackHandler() - c := &config{ - Config: casbin.Config{ - Rule: &casbin.Config_Rule{ - Model: "./testdata/model.conf", - Policy: "./testdata/policy.csv", - }, - Token: &casbin.Config_Token{ - Name: "user", - }, - }, - } - c.Init(nil) - f := factory(c, cb) - - c.SetChanged(true) - - header := http.Header{":path": []string{"/other"}} - hdr := envoy.NewRequestHeaderMap(header) - f.DecodeHeaders(hdr, true) - time.Sleep(2 * time.Second) - - assert.False(t, c.GetChanged()) - -} From cc84ebd86c4b2f50b8a113a5ab607fcdb1a5b4f2 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Wed, 10 Jul 2024 14:22:36 +0800 Subject: [PATCH 21/42] refactor codes --- plugins/pkg/file/fs_test.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 9958f057..21ea8064 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -24,13 +24,12 @@ import ( "github.com/stretchr/testify/assert" ) -var ( - wg sync.WaitGroup -) - func TestFileIsChanged(t *testing.T) { - changed := false - var mu sync.Mutex + var ( + wg sync.WaitGroup + mu sync.Mutex + changed bool + ) watcher, err := fsnotify.NewWatcher() defer func(watcher *fsnotify.Watcher) { err := watcher.Close() From 5e32bea96005de71b9262285699d78d136ca6d0c Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Wed, 10 Jul 2024 14:41:00 +0800 Subject: [PATCH 22/42] fix: avoid data race --- plugins/plugins/casbin/config.go | 15 +-------------- plugins/plugins/casbin/config_test.go | 8 ++------ plugins/plugins/casbin/filter.go | 24 +++++++++++++++++++++--- plugins/plugins/casbin/filter_test.go | 1 - 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/plugins/plugins/casbin/config.go b/plugins/plugins/casbin/config.go index 4e1ba9e6..96e8fcac 100644 --- a/plugins/plugins/casbin/config.go +++ b/plugins/plugins/casbin/config.go @@ -110,7 +110,7 @@ func (conf *config) reloadEnforcer() { if err != nil { api.LogErrorf("failed to update Enforcer: %v", err) } else { - conf.SetChanged(false) + setChanged(false) conf.lock.Lock() conf.enforcer = e conf.lock.Unlock() @@ -119,16 +119,3 @@ func (conf *config) reloadEnforcer() { }() } } - -func (conf *config) SetChanged(change bool) { - conf.lock.Lock() - Changed = change - conf.lock.Unlock() -} - -func (conf *config) GetChanged() bool { - conf.lock.RLock() - changed := Changed - conf.lock.RUnlock() - return changed -} diff --git a/plugins/plugins/casbin/config_test.go b/plugins/plugins/casbin/config_test.go index fbb45198..d4115211 100644 --- a/plugins/plugins/casbin/config_test.go +++ b/plugins/plugins/casbin/config_test.go @@ -15,7 +15,6 @@ package casbin import ( - "sync" "testing" "github.com/stretchr/testify/assert" @@ -63,9 +62,6 @@ func TestBadConfig(t *testing.T) { } func TestChanged(t *testing.T) { - conf := &config{ - lock: &sync.RWMutex{}, - } - conf.SetChanged(true) - assert.True(t, conf.GetChanged()) + setChanged(true) + assert.True(t, getChanged()) } diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index 54131ce0..9b4db43a 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -15,6 +15,8 @@ package casbin import ( + "sync" + "mosn.io/htnn/api/pkg/filtermanager/api" "mosn.io/htnn/plugins/pkg/file" ) @@ -33,21 +35,24 @@ type filter struct { config *config } -var Changed = false +var ( + Changed = false + ChangedMu sync.RWMutex +) func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction { conf := f.config role, _ := headers.Get(conf.Token.Name) // role can be "" url := headers.Url() err := file.WatchFiles(func() { - conf.SetChanged(true) + setChanged(true) }, conf.modelFile, conf.policyFile) if err != nil { api.LogErrorf("failed to watch files: %v", err) return &api.LocalResponse{Code: 500} } - if conf.GetChanged() { + if getChanged() { conf.reloadEnforcer() } @@ -66,3 +71,16 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api } return api.Continue } + +func setChanged(change bool) { + ChangedMu.Lock() + Changed = change + ChangedMu.Unlock() +} + +func getChanged() bool { + ChangedMu.RLock() + changed := Changed + ChangedMu.RUnlock() + return changed +} diff --git a/plugins/plugins/casbin/filter_test.go b/plugins/plugins/casbin/filter_test.go index 6218cc43..1c1db320 100644 --- a/plugins/plugins/casbin/filter_test.go +++ b/plugins/plugins/casbin/filter_test.go @@ -84,7 +84,6 @@ func TestCasbin(t *testing.T) { go func() { // ensure the lock takes effect lr, ok := f.DecodeHeaders(hdr, true).(*api.LocalResponse) - if !ok { assert.Equal(t, tt.status, 0) } else { From 930b568d129533b035da11fc8bc7144a4b90183c Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Wed, 10 Jul 2024 23:52:28 +0800 Subject: [PATCH 23/42] fix: avoid data race --- plugins/pkg/file/fs.go | 36 +++++++++++++++++++++++++++++------- plugins/pkg/file/fs_test.go | 4 +++- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 27169ab3..c9aa08a4 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -25,8 +25,7 @@ import ( ) var ( - logger = log.DefaultLogger.WithName("file") - WatchedFiles = make(map[string]struct{}) + logger = log.DefaultLogger.WithName("file") ) type File struct { @@ -35,6 +34,20 @@ type File struct { mu sync.RWMutex } +type StoreWatchedFiles struct { + WatchedFiles map[string]struct{} + lock *sync.RWMutex +} + +func newStoreWatcherFiles() *StoreWatchedFiles { + return &StoreWatchedFiles{ + WatchedFiles: make(map[string]struct{}), + lock: &sync.RWMutex{}, + } +} + +var storeWatchedFiles = newStoreWatcherFiles() + func WatchFiles(onChanged func(), file *File, otherFiles ...*File) (err error) { files := append([]*File{file}, otherFiles...) for _, f := range files { @@ -54,9 +67,9 @@ func WatchFiles(onChanged func(), file *File, otherFiles ...*File) (err error) { func watchFiles(onChanged func(), file *File) { dir := filepath.Dir(file.Name) defer func() { - file.mu.Lock() - delete(WatchedFiles, dir) - file.mu.Unlock() + storeWatchedFiles.lock.Lock() + defer storeWatchedFiles.lock.Unlock() + delete(storeWatchedFiles.WatchedFiles, dir) }() @@ -79,10 +92,19 @@ func watchFiles(onChanged func(), file *File) { func AddFiles(file string, w *fsnotify.Watcher) (err error) { dir := filepath.Dir(file) - if _, exists := WatchedFiles[dir]; exists { + + storeWatchedFiles.lock.RLock() + + if _, exists := storeWatchedFiles.WatchedFiles[dir]; exists { + storeWatchedFiles.lock.RUnlock() return } - WatchedFiles[dir] = struct{}{} + storeWatchedFiles.lock.RUnlock() + + storeWatchedFiles.lock.Lock() + storeWatchedFiles.WatchedFiles[dir] = struct{}{} + storeWatchedFiles.lock.Unlock() + err = w.Add(dir) return } diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 21ea8064..89da67bb 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -48,7 +48,9 @@ func TestFileIsChanged(t *testing.T) { assert.Equal(t, tmpfile.Name(), file.Name) tmpDir := filepath.Dir(tmpfile.Name()) - _, exists := WatchedFiles[tmpDir] + storeWatchedFiles.lock.RLock() + _, exists := storeWatchedFiles.WatchedFiles[tmpDir] + storeWatchedFiles.lock.RUnlock() assert.True(t, exists) err = WatchFiles(func() { From 1eaa4e7c0a5823752bbb8bcdd410d16f44f8d6a5 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Thu, 11 Jul 2024 00:09:39 +0800 Subject: [PATCH 24/42] refactor codes --- plugins/pkg/file/fs.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index c9aa08a4..bf0413c3 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -31,7 +31,6 @@ var ( type File struct { Name string Watcher *fsnotify.Watcher - mu sync.RWMutex } type StoreWatchedFiles struct { @@ -113,6 +112,5 @@ func Stat(file string, w *fsnotify.Watcher) (*File, error) { return &File{ Name: file, Watcher: w, - mu: sync.RWMutex{}, }, err } From 7a1551f786e2da5506340a0133420463d5e4c78e Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Thu, 11 Jul 2024 12:38:49 +0800 Subject: [PATCH 25/42] refactor codes --- plugins/plugins/casbin/filter.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index 9b4db43a..24f52eda 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -44,9 +44,11 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api conf := f.config role, _ := headers.Get(conf.Token.Name) // role can be "" url := headers.Url() + err := file.WatchFiles(func() { setChanged(true) }, conf.modelFile, conf.policyFile) + if err != nil { api.LogErrorf("failed to watch files: %v", err) return &api.LocalResponse{Code: 500} From bbefd66ea410ccbf369a2f9140ed46e5b7036120 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sat, 13 Jul 2024 01:03:44 +0800 Subject: [PATCH 26/42] refactor codes --- plugins/pkg/file/fs.go | 121 ++++++++++------------- plugins/pkg/file/fs_test.go | 31 ++---- plugins/plugins/casbin/config.go | 31 +++--- plugins/plugins/casbin/config_test.go | 5 - plugins/plugins/casbin/filter.go | 27 ----- plugins/tests/integration/casbin_test.go | 6 +- 6 files changed, 78 insertions(+), 143 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index bf0413c3..491d42ac 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -15,12 +15,11 @@ package file import ( - "errors" - "path/filepath" "sync" "github.com/fsnotify/fsnotify" + "mosn.io/htnn/api/pkg/filtermanager/api" "mosn.io/htnn/api/pkg/log" ) @@ -29,88 +28,74 @@ var ( ) type File struct { - Name string - Watcher *fsnotify.Watcher + Name string } -type StoreWatchedFiles struct { - WatchedFiles map[string]struct{} - lock *sync.RWMutex +type Watcher struct { + watcher *fsnotify.Watcher + files map[string]bool + mu sync.Mutex + done chan struct{} } -func newStoreWatcherFiles() *StoreWatchedFiles { - return &StoreWatchedFiles{ - WatchedFiles: make(map[string]struct{}), - lock: &sync.RWMutex{}, +func NewWatcher() (*Watcher, error) { + w, err := fsnotify.NewWatcher() + if err != nil { + return nil, err } + return &Watcher{ + watcher: w, + files: make(map[string]bool), + done: make(chan struct{}), + }, nil } -var storeWatchedFiles = newStoreWatcherFiles() - -func WatchFiles(onChanged func(), file *File, otherFiles ...*File) (err error) { - files := append([]*File{file}, otherFiles...) - for _, f := range files { - if f == nil { - return errors.New("file pointer cannot be nil") +func (w *Watcher) AddFile(files ...*File) error { + w.mu.Lock() + defer w.mu.Unlock() + for _, file := range files { + if _, exists := w.files[file.Name]; !exists { + if err := w.watcher.Add(file.Name); err != nil { + api.LogInfof("file watched: %v", err) + return err + } + w.files[file.Name] = true } } - - // Add files to watcher. - for _, f := range files { - go watchFiles(onChanged, f) - } - - return + return nil } -func watchFiles(onChanged func(), file *File) { - dir := filepath.Dir(file.Name) - defer func() { - storeWatchedFiles.lock.Lock() - defer storeWatchedFiles.lock.Unlock() - delete(storeWatchedFiles.WatchedFiles, dir) - - }() - - for { - select { - case event, ok := <-file.Watcher.Events: - if !ok { - return - } - logger.Info("file changed: ", "event", event) - onChanged() - case err, ok := <-file.Watcher.Errors: - if !ok { +func (w *Watcher) Start(onChanged func()) { + go func() { + logger.Info("start watch files") + for { + select { + case event, ok := <-w.watcher.Events: + if !ok { + return + } + logger.Info("file changed: ", "event", event) + onChanged() + case err, ok := <-w.watcher.Errors: + if !ok { + return + } + logger.Error(err, "error watching files") + case <-w.done: return } - logger.Error(err, "error watching files") } - } + }() } -func AddFiles(file string, w *fsnotify.Watcher) (err error) { - dir := filepath.Dir(file) - - storeWatchedFiles.lock.RLock() - - if _, exists := storeWatchedFiles.WatchedFiles[dir]; exists { - storeWatchedFiles.lock.RUnlock() - return - } - storeWatchedFiles.lock.RUnlock() - - storeWatchedFiles.lock.Lock() - storeWatchedFiles.WatchedFiles[dir] = struct{}{} - storeWatchedFiles.lock.Unlock() - - err = w.Add(dir) - return +func (w *Watcher) Stop() error { + logger.Info("stop watcher") + close(w.done) + return w.watcher.Close() } -func Stat(file string, w *fsnotify.Watcher) (*File, error) { - err := AddFiles(file, w) + +func Stat(file string) *File { return &File{ - Name: file, - Watcher: w, - }, err + Name: file, + } } diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 89da67bb..95f91699 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -16,56 +16,39 @@ package file import ( "os" - "path/filepath" "sync" "testing" - "github.com/fsnotify/fsnotify" "github.com/stretchr/testify/assert" ) func TestFileIsChanged(t *testing.T) { var ( - wg sync.WaitGroup mu sync.Mutex changed bool ) - watcher, err := fsnotify.NewWatcher() - defer func(watcher *fsnotify.Watcher) { - err := watcher.Close() - if err != nil { - t.Errorf("close watcher err:%v", err) - } - }(watcher) + watcher, err := NewWatcher() + defer watcher.Stop() assert.Nil(t, err) tmpfile, _ := os.CreateTemp("./", "example") - file, err := Stat(tmpfile.Name(), watcher) + file := Stat(tmpfile.Name()) - assert.NoError(t, err) assert.Equal(t, tmpfile.Name(), file.Name) - tmpDir := filepath.Dir(tmpfile.Name()) - storeWatchedFiles.lock.RLock() - _, exists := storeWatchedFiles.WatchedFiles[tmpDir] - storeWatchedFiles.lock.RUnlock() - assert.True(t, exists) + err = watcher.AddFile(file) + assert.Nil(t, err) - err = WatchFiles(func() { + watcher.Start(func() { mu.Lock() changed = true mu.Unlock() - }, file) + }) assert.Nil(t, err) tmpfile.Write([]byte("bls")) tmpfile.Sync() - wg.Wait() - - err = WatchFiles(func() {}, nil) - - assert.Error(t, err, "file pointer cannot be nil") mu.Lock() assert.True(t, changed) diff --git a/plugins/plugins/casbin/config.go b/plugins/plugins/casbin/config.go index 96e8fcac..de31f69f 100644 --- a/plugins/plugins/casbin/config.go +++ b/plugins/plugins/casbin/config.go @@ -20,7 +20,6 @@ import ( "sync/atomic" "github.com/casbin/casbin/v2" - "github.com/fsnotify/fsnotify" "mosn.io/htnn/api/pkg/filtermanager/api" "mosn.io/htnn/api/pkg/plugins" @@ -54,39 +53,42 @@ type config struct { policyFile *file.File updating atomic.Bool - watcher *fsnotify.Watcher + watcher *file.Watcher } func (conf *config) Init(cb api.ConfigCallbackHandler) error { conf.lock = &sync.RWMutex{} - watcher, err := fsnotify.NewWatcher() - if err != nil { - return err - } + f := file.Stat(conf.Rule.Model) - conf.watcher = watcher + conf.modelFile = f - f, err := file.Stat(conf.Rule.Model, watcher) + f = file.Stat(conf.Rule.Policy) + + conf.policyFile = f + + e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy) if err != nil { return err } - conf.modelFile = f + conf.enforcer = e - f, err = file.Stat(conf.Rule.Policy, watcher) + watcher, err := file.NewWatcher() if err != nil { return err } - conf.policyFile = f - e, err := casbin.NewEnforcer(conf.Rule.Model, conf.Rule.Policy) + conf.watcher = watcher + + err = conf.watcher.AddFile(conf.modelFile, conf.policyFile) if err != nil { return err } - conf.enforcer = e + + conf.watcher.Start(conf.reloadEnforcer) runtime.SetFinalizer(conf, func(conf *config) { - err := conf.watcher.Close() + err := conf.watcher.Stop() if err != nil { api.LogErrorf("failed to close watcher, err: %v", err) } @@ -110,7 +112,6 @@ func (conf *config) reloadEnforcer() { if err != nil { api.LogErrorf("failed to update Enforcer: %v", err) } else { - setChanged(false) conf.lock.Lock() conf.enforcer = e conf.lock.Unlock() diff --git a/plugins/plugins/casbin/config_test.go b/plugins/plugins/casbin/config_test.go index d4115211..b7d646ca 100644 --- a/plugins/plugins/casbin/config_test.go +++ b/plugins/plugins/casbin/config_test.go @@ -60,8 +60,3 @@ func TestBadConfig(t *testing.T) { }) } } - -func TestChanged(t *testing.T) { - setChanged(true) - assert.True(t, getChanged()) -} diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index 24f52eda..4662cbce 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -18,7 +18,6 @@ import ( "sync" "mosn.io/htnn/api/pkg/filtermanager/api" - "mosn.io/htnn/plugins/pkg/file" ) func factory(c interface{}, callbacks api.FilterCallbackHandler) api.Filter { @@ -45,19 +44,6 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api role, _ := headers.Get(conf.Token.Name) // role can be "" url := headers.Url() - err := file.WatchFiles(func() { - setChanged(true) - }, conf.modelFile, conf.policyFile) - - if err != nil { - api.LogErrorf("failed to watch files: %v", err) - return &api.LocalResponse{Code: 500} - } - - if getChanged() { - conf.reloadEnforcer() - } - conf.lock.RLock() ok, err := f.config.enforcer.Enforce(role, url.Path, headers.Method()) conf.lock.RUnlock() @@ -73,16 +59,3 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api } return api.Continue } - -func setChanged(change bool) { - ChangedMu.Lock() - Changed = change - ChangedMu.Unlock() -} - -func getChanged() bool { - ChangedMu.RLock() - changed := Changed - ChangedMu.RUnlock() - return changed -} diff --git a/plugins/tests/integration/casbin_test.go b/plugins/tests/integration/casbin_test.go index a53b6986..b4f7e188 100644 --- a/plugins/tests/integration/casbin_test.go +++ b/plugins/tests/integration/casbin_test.go @@ -115,18 +115,16 @@ g, bob, admin }) } + time.Sleep(5 * time.Second) // configuration is not changed, but file changed err = os.WriteFile(policyFile2.Name(), []byte(policy), 0755) require.Nil(t, err) - //wait to run reloadEnforcer - time.Sleep(5 * time.Second) - hdr := http.Header{} hdr.Set("customer", "alice") assert.Eventually(t, func() bool { resp, _ := dp.Post("/echo", hdr, strings.NewReader("any")) return resp != nil && resp.StatusCode == 200 - }, 3*time.Second, 1*time.Second) + }, 15*time.Second, 5*time.Second) } From f239cf5f3555df44bd99c3aac86d4d2f9d7c1bd2 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sat, 13 Jul 2024 01:19:08 +0800 Subject: [PATCH 27/42] refactor codes --- plugins/pkg/file/fs.go | 2 -- plugins/tests/integration/casbin_test.go | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 491d42ac..5d7f1c86 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -19,7 +19,6 @@ import ( "github.com/fsnotify/fsnotify" - "mosn.io/htnn/api/pkg/filtermanager/api" "mosn.io/htnn/api/pkg/log" ) @@ -56,7 +55,6 @@ func (w *Watcher) AddFile(files ...*File) error { for _, file := range files { if _, exists := w.files[file.Name]; !exists { if err := w.watcher.Add(file.Name); err != nil { - api.LogInfof("file watched: %v", err) return err } w.files[file.Name] = true diff --git a/plugins/tests/integration/casbin_test.go b/plugins/tests/integration/casbin_test.go index b4f7e188..f17bc1f9 100644 --- a/plugins/tests/integration/casbin_test.go +++ b/plugins/tests/integration/casbin_test.go @@ -115,6 +115,7 @@ g, bob, admin }) } + //wait to start watcher time.Sleep(5 * time.Second) // configuration is not changed, but file changed err = os.WriteFile(policyFile2.Name(), []byte(policy), 0755) @@ -126,5 +127,5 @@ g, bob, admin assert.Eventually(t, func() bool { resp, _ := dp.Post("/echo", hdr, strings.NewReader("any")) return resp != nil && resp.StatusCode == 200 - }, 15*time.Second, 5*time.Second) + }, 1*time.Second, 100*time.Millisecond) } From 75d00da72c8b6ade12cb870a78f078047fe2603c Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sat, 13 Jul 2024 01:27:36 +0800 Subject: [PATCH 28/42] refactor codes --- plugins/plugins/casbin/filter.go | 7 ------- plugins/plugins/casbin/filter_test.go | 1 - 2 files changed, 8 deletions(-) diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index 4662cbce..f443142b 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -15,8 +15,6 @@ package casbin import ( - "sync" - "mosn.io/htnn/api/pkg/filtermanager/api" ) @@ -34,11 +32,6 @@ type filter struct { config *config } -var ( - Changed = false - ChangedMu sync.RWMutex -) - func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction { conf := f.config role, _ := headers.Get(conf.Token.Name) // role can be "" diff --git a/plugins/plugins/casbin/filter_test.go b/plugins/plugins/casbin/filter_test.go index 1c1db320..f92585f6 100644 --- a/plugins/plugins/casbin/filter_test.go +++ b/plugins/plugins/casbin/filter_test.go @@ -88,7 +88,6 @@ func TestCasbin(t *testing.T) { assert.Equal(t, tt.status, 0) } else { assert.Equal(t, tt.status, lr.Code) - assert.False(t, Changed) } wg.Done() }() From b9d7d0606a324709b14270670fbf4f14ead05054 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sun, 14 Jul 2024 21:16:15 +0800 Subject: [PATCH 29/42] refactor: watch the dir of file --- plugins/pkg/file/fs.go | 29 ++++++++++++++++------------- plugins/pkg/file/fs_test.go | 17 +++++++---------- plugins/plugins/casbin/config.go | 4 ++-- plugins/plugins/casbin/filter.go | 2 +- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 5d7f1c86..120c0146 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -15,6 +15,7 @@ package file import ( + "path/filepath" "sync" "github.com/fsnotify/fsnotify" @@ -34,6 +35,7 @@ type Watcher struct { watcher *fsnotify.Watcher files map[string]bool mu sync.Mutex + dir map[string]bool done chan struct{} } @@ -46,18 +48,23 @@ func NewWatcher() (*Watcher, error) { watcher: w, files: make(map[string]bool), done: make(chan struct{}), + dir: make(map[string]bool), }, nil } -func (w *Watcher) AddFile(files ...*File) error { +func (w *Watcher) AddFiles(files ...*File) error { w.mu.Lock() defer w.mu.Unlock() for _, file := range files { if _, exists := w.files[file.Name]; !exists { - if err := w.watcher.Add(file.Name); err != nil { + w.files[file.Name] = true + } + dir := filepath.Dir(file.Name) + if _, exists := w.dir[dir]; !exists { + if err := w.watcher.Add(dir); err != nil { return err } - w.files[file.Name] = true + w.dir[dir] = true } } return nil @@ -65,19 +72,15 @@ func (w *Watcher) AddFile(files ...*File) error { func (w *Watcher) Start(onChanged func()) { go func() { - logger.Info("start watch files") + logger.Info("start watching files") for { select { - case event, ok := <-w.watcher.Events: - if !ok { - return - } - logger.Info("file changed: ", "event", event) - onChanged() - case err, ok := <-w.watcher.Errors: - if !ok { - return + case event, _ := <-w.watcher.Events: + if _, exists := w.files[event.Name]; exists { + logger.Info("file changed: ", "event", event) + onChanged() } + case err, _ := <-w.watcher.Errors: logger.Error(err, "error watching files") case <-w.done: return diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 95f91699..8d695514 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -23,10 +23,9 @@ import ( ) func TestFileIsChanged(t *testing.T) { - var ( - mu sync.Mutex - changed bool - ) + changed := false + wg := sync.WaitGroup{} + watcher, err := NewWatcher() defer watcher.Stop() @@ -38,21 +37,19 @@ func TestFileIsChanged(t *testing.T) { assert.Equal(t, tmpfile.Name(), file.Name) - err = watcher.AddFile(file) + err = watcher.AddFiles(file) assert.Nil(t, err) - + wg.Add(1) watcher.Start(func() { - mu.Lock() changed = true - mu.Unlock() + wg.Done() }) assert.Nil(t, err) tmpfile.Write([]byte("bls")) tmpfile.Sync() - mu.Lock() + wg.Wait() assert.True(t, changed) - mu.Unlock() err = os.Remove(tmpfile.Name()) assert.Nil(t, err) diff --git a/plugins/plugins/casbin/config.go b/plugins/plugins/casbin/config.go index de31f69f..bb8db445 100644 --- a/plugins/plugins/casbin/config.go +++ b/plugins/plugins/casbin/config.go @@ -80,7 +80,7 @@ func (conf *config) Init(cb api.ConfigCallbackHandler) error { conf.watcher = watcher - err = conf.watcher.AddFile(conf.modelFile, conf.policyFile) + err = conf.watcher.AddFiles(conf.modelFile, conf.policyFile) if err != nil { return err } @@ -90,7 +90,7 @@ func (conf *config) Init(cb api.ConfigCallbackHandler) error { runtime.SetFinalizer(conf, func(conf *config) { err := conf.watcher.Stop() if err != nil { - api.LogErrorf("failed to close watcher, err: %v", err) + api.LogErrorf("failed to stop watcher, err: %v", err) } }) return nil diff --git a/plugins/plugins/casbin/filter.go b/plugins/plugins/casbin/filter.go index f443142b..dff5b951 100644 --- a/plugins/plugins/casbin/filter.go +++ b/plugins/plugins/casbin/filter.go @@ -43,7 +43,7 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api if !ok { if err != nil { - api.LogErrorf("failed to enforece %s: %v", role, err) + api.LogErrorf("failed to enforce %s: %v", role, err) } api.LogInfof("reject forbidden user %s", role) return &api.LocalResponse{ From 0b03da566abbbb87d0fb52849cf2d456fc47a9cf Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Sun, 14 Jul 2024 21:27:46 +0800 Subject: [PATCH 30/42] refactor codes --- plugins/pkg/file/fs.go | 4 ++-- plugins/pkg/file/fs_test.go | 11 ++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 120c0146..d3d2fb36 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -75,12 +75,12 @@ func (w *Watcher) Start(onChanged func()) { logger.Info("start watching files") for { select { - case event, _ := <-w.watcher.Events: + case event := <-w.watcher.Events: if _, exists := w.files[event.Name]; exists { logger.Info("file changed: ", "event", event) onChanged() } - case err, _ := <-w.watcher.Errors: + case err := <-w.watcher.Errors: logger.Error(err, "error watching files") case <-w.done: return diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 8d695514..0c0da6e6 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -25,9 +25,9 @@ import ( func TestFileIsChanged(t *testing.T) { changed := false wg := sync.WaitGroup{} + once := sync.Once{} watcher, err := NewWatcher() - defer watcher.Stop() assert.Nil(t, err) @@ -41,8 +41,10 @@ func TestFileIsChanged(t *testing.T) { assert.Nil(t, err) wg.Add(1) watcher.Start(func() { - changed = true - wg.Done() + once.Do(func() { + changed = true + wg.Done() + }) }) assert.Nil(t, err) tmpfile.Write([]byte("bls")) @@ -53,4 +55,7 @@ func TestFileIsChanged(t *testing.T) { err = os.Remove(tmpfile.Name()) assert.Nil(t, err) + + err = watcher.Stop() + assert.Nil(t, err) } From 8ca375a83b269fe3d83b45c97797be408f1c91a9 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Mon, 15 Jul 2024 13:00:55 +0800 Subject: [PATCH 31/42] use file's abspath --- plugins/pkg/file/fs.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index d3d2fb36..71c92e9d 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -59,7 +59,10 @@ func (w *Watcher) AddFiles(files ...*File) error { if _, exists := w.files[file.Name]; !exists { w.files[file.Name] = true } - dir := filepath.Dir(file.Name) + dir, err := filepath.Abs(file.Name) + if err != nil { + return err + } if _, exists := w.dir[dir]; !exists { if err := w.watcher.Add(dir); err != nil { return err @@ -76,7 +79,14 @@ func (w *Watcher) Start(onChanged func()) { for { select { case event := <-w.watcher.Events: - if _, exists := w.files[event.Name]; exists { + if event.Op&fsnotify.Chmod == fsnotify.Chmod { + continue + } + absPath, err := filepath.Abs(event.Name) + if err != nil { + logger.Error(err, "get file absPath failed") + } + if _, exists := w.files[absPath]; exists { logger.Info("file changed: ", "event", event) onChanged() } From 363dd1d53340a7fc2d27e1d1f7dd1929a7bf2743 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Mon, 15 Jul 2024 13:01:10 +0800 Subject: [PATCH 32/42] refactor codes --- plugins/pkg/file/fs_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/pkg/file/fs_test.go b/plugins/pkg/file/fs_test.go index 0c0da6e6..a49d86c6 100644 --- a/plugins/pkg/file/fs_test.go +++ b/plugins/pkg/file/fs_test.go @@ -31,7 +31,7 @@ func TestFileIsChanged(t *testing.T) { assert.Nil(t, err) - tmpfile, _ := os.CreateTemp("./", "example") + tmpfile, _ := os.CreateTemp("", "example") file := Stat(tmpfile.Name()) From b81b32572586d7f15de9c37db5e8bc3bef80e747 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Mon, 15 Jul 2024 20:15:00 +0800 Subject: [PATCH 33/42] refactor codes --- plugins/pkg/file/fs.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 71c92e9d..2e52bf72 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -56,13 +56,15 @@ func (w *Watcher) AddFiles(files ...*File) error { w.mu.Lock() defer w.mu.Unlock() for _, file := range files { - if _, exists := w.files[file.Name]; !exists { - w.files[file.Name] = true - } - dir, err := filepath.Abs(file.Name) + absPath, err := filepath.Abs(file.Name) if err != nil { return err } + if _, exists := w.files[absPath]; !exists { + w.files[absPath] = true + } + + dir := filepath.Dir(absPath) if _, exists := w.dir[dir]; !exists { if err := w.watcher.Add(dir); err != nil { return err @@ -79,7 +81,7 @@ func (w *Watcher) Start(onChanged func()) { for { select { case event := <-w.watcher.Events: - if event.Op&fsnotify.Chmod == fsnotify.Chmod { + if event.Op.Has(fsnotify.Chmod) { continue } absPath, err := filepath.Abs(event.Name) From c8798702f0a22b054074c38a7ce4fa6d0162cc29 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Tue, 16 Jul 2024 10:36:46 +0800 Subject: [PATCH 34/42] refactor codes --- plugins/pkg/file/fs.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 2e52bf72..929027c6 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -81,8 +81,11 @@ func (w *Watcher) Start(onChanged func()) { for { select { case event := <-w.watcher.Events: - if event.Op.Has(fsnotify.Chmod) { - continue + if event.Op&fsnotify.Chmod != 0 { + event.Op &= ^fsnotify.Chmod // Remove the Chmod bit + if event.Op == 0 { + continue // Skip if it was only a Chmod event + } } absPath, err := filepath.Abs(event.Name) if err != nil { From ac85bd92c75b8371b8bd8e343b9451561a4f3221 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Tue, 16 Jul 2024 18:24:01 +0800 Subject: [PATCH 35/42] refactor codes --- plugins/pkg/file/fs.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/plugins/pkg/file/fs.go b/plugins/pkg/file/fs.go index 929027c6..86199a3e 100644 --- a/plugins/pkg/file/fs.go +++ b/plugins/pkg/file/fs.go @@ -81,11 +81,8 @@ func (w *Watcher) Start(onChanged func()) { for { select { case event := <-w.watcher.Events: - if event.Op&fsnotify.Chmod != 0 { - event.Op &= ^fsnotify.Chmod // Remove the Chmod bit - if event.Op == 0 { - continue // Skip if it was only a Chmod event - } + if event.Op == fsnotify.Chmod { + continue // Skip chmod event } absPath, err := filepath.Abs(event.Name) if err != nil { From ca62248909fce105e0ec101c5fb075402c570ad8 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Fri, 2 Aug 2024 00:03:08 +0800 Subject: [PATCH 36/42] feat: refresh consul services --- controller/registries/consul/config.go | 98 +++++++++++++++---- controller/registries/consul/config_test.go | 50 +++++++++- types/registries/consul/config.pb.go | 44 ++++++--- types/registries/consul/config.pb.validate.go | 4 + types/registries/consul/config.proto | 4 +- 5 files changed, 167 insertions(+), 33 deletions(-) diff --git a/controller/registries/consul/config.go b/controller/registries/consul/config.go index 312deb9f..d96d5b15 100644 --- a/controller/registries/consul/config.go +++ b/controller/registries/consul/config.go @@ -70,6 +70,7 @@ type Client struct { DataCenter string NameSpace string + Token string } type consulService struct { @@ -97,6 +98,8 @@ func (reg *Consul) NewClient(config *consul.Config) (*Client, error) { consulClient: client, consulCatalog: client.Catalog(), DataCenter: config.DataCenter, + NameSpace: config.Namespace, + Token: config.Token, }, nil } @@ -108,11 +111,21 @@ func (reg *Consul) Start(c registrytype.RegistryConfig) error { return err } + reg.client = client + services, err := reg.fetchAllServices(client) if err != nil { return fmt.Errorf("fetch all services error: %v", err) } - reg.client = client + + //for key := range services { + // err = reg.subscribe(key.ServiceName) + // if err != nil { + // reg.logger.Errorf("failed to subscribe service, err: %v, service: %v", err, key) + // + // delete(services, key) + // } + //} reg.watchingServices = services @@ -122,22 +135,24 @@ func (reg *Consul) Start(c registrytype.RegistryConfig) error { } go func() { reg.logger.Infof("start refreshing services") - ticker := time.NewTicker(dur) - //q := consulapi.QueryOptions{ - // WaitTime: dur, - //} - defer ticker.Stop() + q := &consulapi.QueryOptions{ + WaitTime: dur, + } for { select { - case <-ticker.C: - err := reg.refresh() - if err != nil { - reg.logger.Errorf("failed to refresh services, err: %v", err) - } case <-reg.done: reg.logger.Infof("stop refreshing services") return + + default: + } + services, meta, err := reg.client.consulCatalog.Services(q) + if err != nil { + reg.logger.Errorf("failed to get services, err: %v", err) } + reg.refresh(services) + + q.WaitIndex = meta.LastIndex } }() @@ -160,13 +175,27 @@ func (reg *Consul) Reload(c registrytype.RegistryConfig) error { return nil } -func (reg *Consul) refresh() error { - return nil -} - func (reg *Consul) fetchAllServices(client *Client) (map[consulService]bool, error) { - fmt.Println(client) - return nil, nil + q := &consulapi.QueryOptions{} + q.Datacenter = client.DataCenter + q.Namespace = client.NameSpace + q.Token = client.Token + services, _, err := client.consulCatalog.Services(q) + + if err != nil { + return nil, err + } + serviceMap := make(map[consulService]bool) + for serviceName, dataCenters := range services { + for _, dc := range dataCenters { + service := consulService{ + DataCenter: dc, + ServiceName: serviceName, + } + serviceMap[service] = true + } + } + return serviceMap, nil } func (reg *Consul) subscribe(serviceName string) error { @@ -178,3 +207,38 @@ func (reg *Consul) unsubscribe(serviceName string) error { fmt.Println(serviceName) return nil } + +func (reg *Consul) refresh(services map[string][]string) { + + serviceMap := make(map[consulService]bool) + for serviceName, dataCenters := range services { + for _, dc := range dataCenters { + service := consulService{ + DataCenter: dc, + ServiceName: serviceName, + } + serviceMap[service] = true + if _, ok := reg.watchingServices[service]; !ok { + err := reg.subscribe(serviceName) + if err != nil { + reg.logger.Errorf("failed to subscribe service, err: %v, service: %v", err, serviceName) + delete(serviceMap, service) + } + } + } + } + + prevFetchServices := reg.watchingServices + reg.watchingServices = serviceMap + + for key := range prevFetchServices { + if _, ok := serviceMap[key]; !ok { + err := reg.unsubscribe(key.ServiceName) + if err != nil { + reg.logger.Errorf("failed to unsubscribe service, err: %v, service: %v", err, key) + } + reg.softDeletedServices[key] = true + } + } + +} diff --git a/controller/registries/consul/config_test.go b/controller/registries/consul/config_test.go index c3f8caf6..8eea624b 100644 --- a/controller/registries/consul/config_test.go +++ b/controller/registries/consul/config_test.go @@ -67,9 +67,6 @@ func TestStart(t *testing.T) { err = reg.unsubscribe("123") assert.Nil(t, err) - err = reg.refresh() - assert.Nil(t, err) - err = reg.Stop() assert.Nil(t, err) } @@ -83,3 +80,50 @@ func TestReload(t *testing.T) { err := reg.Reload(config) assert.NoError(t, err) } + +func TestRefresh(t *testing.T) { + reg := &Consul{ + logger: log.NewLogger(&log.RegistryLoggerOptions{ + Name: "test", + }), + softDeletedServices: map[consulService]bool{}, + done: make(chan struct{}), + watchingServices: map[consulService]bool{}, + } + + config := &consul.Config{ + ServerUrl: "http://127.0.0.1:8500", + } + client, _ := reg.NewClient(config) + reg.client = client + services := map[string][]string{ + "service1": {"dc1", "dc2"}, + "service2": {"dc1"}, + } + + reg.refresh(services) + + assert.Len(t, reg.watchingServices, 3) + assert.Contains(t, reg.watchingServices, consulService{ServiceName: "service1", DataCenter: "dc1"}) + assert.Contains(t, reg.watchingServices, consulService{ServiceName: "service1", DataCenter: "dc2"}) + assert.Contains(t, reg.watchingServices, consulService{ServiceName: "service2", DataCenter: "dc1"}) + assert.Empty(t, reg.softDeletedServices) + + reg = &Consul{ + logger: log.NewLogger(&log.RegistryLoggerOptions{ + Name: "test", + }), + softDeletedServices: map[consulService]bool{}, + watchingServices: map[consulService]bool{ + {ServiceName: "service1", DataCenter: "dc1"}: true, + }, + } + + services = map[string][]string{} + + reg.refresh(services) + + assert.Len(t, reg.watchingServices, 0) + assert.Len(t, reg.softDeletedServices, 1) + +} diff --git a/types/registries/consul/config.pb.go b/types/registries/consul/config.pb.go index ddbc964a..6fdf52f2 100644 --- a/types/registries/consul/config.pb.go +++ b/types/registries/consul/config.pb.go @@ -44,7 +44,9 @@ type Config struct { ServerUrl string `protobuf:"bytes,1,opt,name=server_url,json=serverUrl,proto3" json:"server_url,omitempty"` DataCenter string `protobuf:"bytes,2,opt,name=data_center,json=dataCenter,proto3" json:"data_center,omitempty"` - ServiceRefreshInterval *durationpb.Duration `protobuf:"bytes,3,opt,name=service_refresh_interval,json=serviceRefreshInterval,proto3" json:"service_refresh_interval,omitempty"` + Namespace string `protobuf:"bytes,3,opt,name=namespace,proto3" json:"namespace,omitempty"` + Token string `protobuf:"bytes,4,opt,name=token,proto3" json:"token,omitempty"` + ServiceRefreshInterval *durationpb.Duration `protobuf:"bytes,5,opt,name=service_refresh_interval,json=serviceRefreshInterval,proto3" json:"service_refresh_interval,omitempty"` } func (x *Config) Reset() { @@ -93,6 +95,20 @@ func (x *Config) GetDataCenter() string { return "" } +func (x *Config) GetNamespace() string { + if x != nil { + return x.Namespace + } + return "" +} + +func (x *Config) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + func (x *Config) GetServiceRefreshInterval() *durationpb.Duration { if x != nil { return x.ServiceRefreshInterval @@ -110,21 +126,25 @@ var file_types_registries_consul_config_proto_rawDesc = []byte{ 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x17, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, - 0x74, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xb3, 0x01, 0x0a, 0x06, 0x43, 0x6f, 0x6e, + 0x74, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xe7, 0x01, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x27, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x08, 0xfa, 0x42, 0x05, 0x72, 0x03, 0x88, 0x01, 0x01, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x55, 0x72, 0x6c, 0x12, 0x1f, 0x0a, 0x0b, 0x64, 0x61, 0x74, 0x61, 0x5f, 0x63, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0a, 0x64, 0x61, 0x74, 0x61, 0x43, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x12, 0x5f, 0x0a, - 0x18, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, - 0x5f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x42, 0x0a, 0xfa, 0x42, 0x07, 0xaa, - 0x01, 0x04, 0x32, 0x02, 0x08, 0x01, 0x52, 0x16, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, - 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x42, 0x26, - 0x5a, 0x24, 0x6d, 0x6f, 0x73, 0x6e, 0x2e, 0x69, 0x6f, 0x2f, 0x68, 0x74, 0x6e, 0x6e, 0x2f, 0x74, - 0x79, 0x70, 0x65, 0x73, 0x2f, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x69, 0x65, 0x73, 0x2f, - 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x09, 0x52, 0x0a, 0x64, 0x61, 0x74, 0x61, 0x43, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, + 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x74, + 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, + 0x6e, 0x12, 0x5f, 0x0a, 0x18, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x72, 0x65, 0x66, + 0x72, 0x65, 0x73, 0x68, 0x5f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x42, 0x0a, + 0xfa, 0x42, 0x07, 0xaa, 0x01, 0x04, 0x32, 0x02, 0x08, 0x01, 0x52, 0x16, 0x73, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, + 0x61, 0x6c, 0x42, 0x26, 0x5a, 0x24, 0x6d, 0x6f, 0x73, 0x6e, 0x2e, 0x69, 0x6f, 0x2f, 0x68, 0x74, + 0x6e, 0x6e, 0x2f, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2f, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, + 0x69, 0x65, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( diff --git a/types/registries/consul/config.pb.validate.go b/types/registries/consul/config.pb.validate.go index 27e88e0b..7129353b 100644 --- a/types/registries/consul/config.pb.validate.go +++ b/types/registries/consul/config.pb.validate.go @@ -79,6 +79,10 @@ func (m *Config) validate(all bool) error { // no validation rules for DataCenter + // no validation rules for Namespace + + // no validation rules for Token + if d := m.GetServiceRefreshInterval(); d != nil { dur, err := d.AsDuration(), d.CheckValid() if err != nil { diff --git a/types/registries/consul/config.proto b/types/registries/consul/config.proto index 154f93d4..6380d017 100644 --- a/types/registries/consul/config.proto +++ b/types/registries/consul/config.proto @@ -24,6 +24,8 @@ option go_package = "mosn.io/htnn/types/registries/consul"; message Config { string server_url = 1 [(validate.rules).string = {uri: true}]; string data_center = 2; - google.protobuf.Duration service_refresh_interval = 3 + string namespace = 3; + string token = 4; + google.protobuf.Duration service_refresh_interval = 5 [(validate.rules).duration = {gte {seconds: 1}}]; } From 817d903768e8b43abf205f0b17d781532f8a6a5c Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Fri, 2 Aug 2024 00:40:00 +0800 Subject: [PATCH 37/42] Remove redundant code --- controller/registries/consul/config.go | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/controller/registries/consul/config.go b/controller/registries/consul/config.go index d96d5b15..c98a8b78 100644 --- a/controller/registries/consul/config.go +++ b/controller/registries/consul/config.go @@ -45,10 +45,6 @@ func init() { }) } -const ( - defaultToken = "" -) - type Consul struct { consul.RegistryType logger log.RegistryLogger @@ -86,7 +82,7 @@ func (reg *Consul) NewClient(config *consul.Config) (*Client, error) { clientConfig := consulapi.DefaultConfig() clientConfig.Address = uri.Host clientConfig.Scheme = uri.Scheme - clientConfig.Token = defaultToken + clientConfig.Token = config.Token clientConfig.Datacenter = config.DataCenter client, err := consulapi.NewClient(clientConfig) @@ -114,9 +110,9 @@ func (reg *Consul) Start(c registrytype.RegistryConfig) error { reg.client = client services, err := reg.fetchAllServices(client) - if err != nil { - return fmt.Errorf("fetch all services error: %v", err) - } + //if err != nil { + // return fmt.Errorf("fetch all services error: %v", err) + //} //for key := range services { // err = reg.subscribe(key.ServiceName) @@ -209,7 +205,6 @@ func (reg *Consul) unsubscribe(serviceName string) error { } func (reg *Consul) refresh(services map[string][]string) { - serviceMap := make(map[consulService]bool) for serviceName, dataCenters := range services { for _, dc := range dataCenters { From d759bff11b192a056b8ce1ceaefbc674743d744b Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Fri, 2 Aug 2024 00:44:28 +0800 Subject: [PATCH 38/42] refactor codes to avoid test failed --- controller/registries/consul/config.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/controller/registries/consul/config.go b/controller/registries/consul/config.go index c98a8b78..a99760cb 100644 --- a/controller/registries/consul/config.go +++ b/controller/registries/consul/config.go @@ -109,10 +109,7 @@ func (reg *Consul) Start(c registrytype.RegistryConfig) error { reg.client = client - services, err := reg.fetchAllServices(client) - //if err != nil { - // return fmt.Errorf("fetch all services error: %v", err) - //} + services := reg.fetchAllServices(client) //for key := range services { // err = reg.subscribe(key.ServiceName) @@ -171,7 +168,7 @@ func (reg *Consul) Reload(c registrytype.RegistryConfig) error { return nil } -func (reg *Consul) fetchAllServices(client *Client) (map[consulService]bool, error) { +func (reg *Consul) fetchAllServices(client *Client) map[consulService]bool { q := &consulapi.QueryOptions{} q.Datacenter = client.DataCenter q.Namespace = client.NameSpace @@ -179,7 +176,8 @@ func (reg *Consul) fetchAllServices(client *Client) (map[consulService]bool, err services, _, err := client.consulCatalog.Services(q) if err != nil { - return nil, err + reg.logger.Errorf("failed to get service, err: %v", err) + return nil } serviceMap := make(map[consulService]bool) for serviceName, dataCenters := range services { @@ -191,7 +189,7 @@ func (reg *Consul) fetchAllServices(client *Client) (map[consulService]bool, err serviceMap[service] = true } } - return serviceMap, nil + return serviceMap } func (reg *Consul) subscribe(serviceName string) error { From 2b2fc28eb8e4cf9f15b0caaea73ddb9dbff2a28e Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Fri, 2 Aug 2024 22:59:34 +0800 Subject: [PATCH 39/42] add more tests --- controller/go.mod | 1 + controller/registries/consul/config.go | 97 ++++++++++++++--- controller/registries/consul/config_test.go | 110 ++++++++++++++++++-- 3 files changed, 189 insertions(+), 19 deletions(-) diff --git a/controller/go.mod b/controller/go.mod index be7ad08e..4743507a 100644 --- a/controller/go.mod +++ b/controller/go.mod @@ -124,6 +124,7 @@ require ( github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tchap/go-patricia/v2 v2.3.1 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect diff --git a/controller/registries/consul/config.go b/controller/registries/consul/config.go index a99760cb..27310d31 100644 --- a/controller/registries/consul/config.go +++ b/controller/registries/consul/config.go @@ -22,6 +22,8 @@ import ( "time" consulapi "github.com/hashicorp/consul/api" + "github.com/nacos-group/nacos-sdk-go/model" + istioapi "istio.io/api/networking/v1alpha3" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "mosn.io/htnn/controller/pkg/registry" @@ -45,6 +47,23 @@ func init() { }) } +var ( + SleepTime = 120 * time.Second + RegistryType = "consul" +) + +type consulCatalog interface { + Services(q *consulapi.QueryOptions) (map[string][]string, *consulapi.QueryMeta, error) +} + +type ConsulAPI struct { + client *consulapi.Client +} + +func (c *ConsulAPI) Services(q *consulapi.QueryOptions) (map[string][]string, *consulapi.QueryMeta, error) { + return c.client.Catalog().Services(q) +} + type Consul struct { consul.RegistryType logger log.RegistryLogger @@ -62,7 +81,7 @@ type Consul struct { type Client struct { consulClient *consulapi.Client - consulCatalog *consulapi.Catalog + consulCatalog consulCatalog DataCenter string NameSpace string @@ -102,15 +121,20 @@ func (reg *Consul) NewClient(config *consul.Config) (*Client, error) { func (reg *Consul) Start(c registrytype.RegistryConfig) error { config := c.(*consul.Config) - client, err := reg.NewClient(config) - if err != nil { - return err - } + if reg.client == nil { - reg.client = client + client, err := reg.NewClient(config) + if err != nil { + return err + } - services := reg.fetchAllServices(client) + reg.client = client + } + services, err := reg.fetchAllServices(reg.client) + if err != nil { + return err + } //for key := range services { // err = reg.subscribe(key.ServiceName) // if err != nil { @@ -129,14 +153,18 @@ func (reg *Consul) Start(c registrytype.RegistryConfig) error { go func() { reg.logger.Infof("start refreshing services") q := &consulapi.QueryOptions{ - WaitTime: dur, + WaitTime: dur, + Namespace: config.Namespace, + Datacenter: config.DataCenter, + Token: config.Token, } for { + select { case <-reg.done: reg.logger.Infof("stop refreshing services") - return - + //wait to retry + time.Sleep(SleepTime) default: } services, meta, err := reg.client.consulCatalog.Services(q) @@ -168,7 +196,7 @@ func (reg *Consul) Reload(c registrytype.RegistryConfig) error { return nil } -func (reg *Consul) fetchAllServices(client *Client) map[consulService]bool { +func (reg *Consul) fetchAllServices(client *Client) (map[consulService]bool, error) { q := &consulapi.QueryOptions{} q.Datacenter = client.DataCenter q.Namespace = client.NameSpace @@ -177,7 +205,7 @@ func (reg *Consul) fetchAllServices(client *Client) map[consulService]bool { if err != nil { reg.logger.Errorf("failed to get service, err: %v", err) - return nil + return nil, err } serviceMap := make(map[consulService]bool) for serviceName, dataCenters := range services { @@ -189,7 +217,7 @@ func (reg *Consul) fetchAllServices(client *Client) map[consulService]bool { serviceMap[service] = true } } - return serviceMap + return serviceMap, nil } func (reg *Consul) subscribe(serviceName string) error { @@ -235,3 +263,46 @@ func (reg *Consul) refresh(services map[string][]string) { } } + +func (reg *Consul) generateServiceEntry(host string, services []model.SubscribeService) *registry.ServiceEntryWrapper { + portList := make([]*istioapi.ServicePort, 0, 1) + endpoints := make([]*istioapi.WorkloadEntry, 0, len(services)) + + for _, service := range services { + protocol := registry.HTTP + if service.Metadata == nil { + service.Metadata = make(map[string]string) + } + + if service.Metadata["protocol"] != "" { + protocol = registry.ParseProtocol(service.Metadata["protocol"]) + } + + port := &istioapi.ServicePort{ + Name: string(protocol), + Number: uint32(service.Port), + Protocol: string(protocol), + } + if len(portList) == 0 { + portList = append(portList, port) + } + + endpoint := istioapi.WorkloadEntry{ + Address: service.Ip, + Ports: map[string]uint32{port.Protocol: port.Number}, + Labels: service.Metadata, + } + endpoints = append(endpoints, &endpoint) + } + + return ®istry.ServiceEntryWrapper{ + ServiceEntry: istioapi.ServiceEntry{ + Hosts: []string{host}, + Ports: portList, + Location: istioapi.ServiceEntry_MESH_INTERNAL, + Resolution: istioapi.ServiceEntry_STATIC, + Endpoints: endpoints, + }, + Source: RegistryType, + } +} diff --git a/controller/registries/consul/config_test.go b/controller/registries/consul/config_test.go index 8eea624b..c4cb6ab4 100644 --- a/controller/registries/consul/config_test.go +++ b/controller/registries/consul/config_test.go @@ -17,8 +17,15 @@ package consul import ( "testing" + "github.com/hashicorp/consul/api" + "github.com/nacos-group/nacos-sdk-go/model" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + istioapi "istio.io/api/networking/v1alpha3" + "mosn.io/htnn/controller/pkg/registry" "mosn.io/htnn/controller/pkg/registry/log" "mosn.io/htnn/types/registries/consul" ) @@ -45,19 +52,36 @@ func TestNewClient(t *testing.T) { assert.Nil(t, client) } +type MockConsulCatalog struct { + mock.Mock +} + +// Services is a mock method for ConsulCatalog.Services +func (m *MockConsulCatalog) Services(q *api.QueryOptions) (map[string][]string, *api.QueryMeta, error) { + return nil, nil, nil +} + func TestStart(t *testing.T) { + mockConsulCatalog := new(MockConsulCatalog) + client := &Client{ + consulCatalog: mockConsulCatalog, + DataCenter: "dc1", + NameSpace: "ns1", + Token: "token", + } + reg := &Consul{ logger: log.NewLogger(&log.RegistryLoggerOptions{ Name: "test", }), - softDeletedServices: map[consulService]bool{}, - done: make(chan struct{}), - watchingServices: map[consulService]bool{}, - } - config := &consul.Config{ - ServerUrl: "http://127.0.0.1:8500", + client: client, + done: make(chan struct{}), } + config := &consul.Config{} + + mockConsulCatalog.On("Services", mock.Anything).Return(map[string][]string{"service1": {"dc1"}}, &api.QueryMeta{}, nil) + err := reg.Start(config) assert.NoError(t, err) @@ -69,6 +93,18 @@ func TestStart(t *testing.T) { err = reg.Stop() assert.Nil(t, err) + + reg = &Consul{ + logger: log.NewLogger(&log.RegistryLoggerOptions{ + Name: "test", + }), + done: make(chan struct{}), + } + + err = reg.Start(config) + assert.Error(t, err) + + close(reg.done) } func TestReload(t *testing.T) { @@ -127,3 +163,65 @@ func TestRefresh(t *testing.T) { assert.Len(t, reg.softDeletedServices, 1) } + +func TestGenerateServiceEntry(t *testing.T) { + host := "test.default-group.public.earth.nacos" + reg := &Consul{} + + type test struct { + name string + services []model.SubscribeService + port *istioapi.ServicePort + endpoint *istioapi.WorkloadEntry + } + tests := []test{} + for input, proto := range registry.ProtocolMap { + s := string(proto) + tests = append(tests, test{ + name: input, + services: []model.SubscribeService{ + {Port: 80, Ip: "1.1.1.1", Metadata: map[string]string{ + "protocol": input, + }}, + }, + port: &istioapi.ServicePort{ + Name: s, + Protocol: s, + Number: 80, + }, + endpoint: &istioapi.WorkloadEntry{ + Address: "1.1.1.1", + Ports: map[string]uint32{s: 80}, + Labels: map[string]string{ + "protocol": input, + }, + }, + }) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + se := reg.generateServiceEntry(host, tt.services) + require.True(t, proto.Equal(se.ServiceEntry.Ports[0], tt.port)) + require.True(t, proto.Equal(se.ServiceEntry.Endpoints[0], tt.endpoint)) + }) + } +} + +func TestFetchAllServices(t *testing.T) { + mockConsulCatalog := new(MockConsulCatalog) + client := &Client{ + consulCatalog: mockConsulCatalog, + DataCenter: "dc1", + NameSpace: "ns1", + Token: "token", + } + + reg := &Consul{} + services, err := reg.fetchAllServices(client) + if err != nil { + return + } + + assert.Equal(t, map[consulService]bool{}, services) +} From f0d005c8114bd9062d92d4a98cf9410255b6700e Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Tue, 6 Aug 2024 16:48:51 +0800 Subject: [PATCH 40/42] refactor codes --- controller/registries/consul/config.go | 98 ++++++--------------- controller/registries/consul/config_test.go | 84 ++++++------------ 2 files changed, 53 insertions(+), 129 deletions(-) diff --git a/controller/registries/consul/config.go b/controller/registries/consul/config.go index 27310d31..342b4db5 100644 --- a/controller/registries/consul/config.go +++ b/controller/registries/consul/config.go @@ -22,8 +22,6 @@ import ( "time" consulapi "github.com/hashicorp/consul/api" - "github.com/nacos-group/nacos-sdk-go/model" - istioapi "istio.io/api/networking/v1alpha3" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "mosn.io/htnn/controller/pkg/registry" @@ -42,14 +40,15 @@ func init() { name: om.Name, softDeletedServices: map[consulService]bool{}, done: make(chan struct{}), + clientFactory: factory, } return reg, nil }) } var ( - SleepTime = 120 * time.Second - RegistryType = "consul" + //RegistryType = "consul" + factory ClientFactory = &DefaultClientFactory{} ) type consulCatalog interface { @@ -66,10 +65,11 @@ func (c *ConsulAPI) Services(q *consulapi.QueryOptions) (map[string][]string, *c type Consul struct { consul.RegistryType - logger log.RegistryLogger - store registry.ServiceEntryStore - name string - client *Client + logger log.RegistryLogger + store registry.ServiceEntryStore + name string + client *Client + clientFactory ClientFactory lock sync.RWMutex watchingServices map[consulService]bool @@ -88,12 +88,13 @@ type Client struct { Token string } -type consulService struct { - DataCenter string - ServiceName string +type ClientFactory interface { + NewClient(config *consul.Config) (*Client, error) } -func (reg *Consul) NewClient(config *consul.Config) (*Client, error) { +type DefaultClientFactory struct{} + +func (f *DefaultClientFactory) NewClient(config *consul.Config) (*Client, error) { uri, err := url.Parse(config.ServerUrl) if err != nil { return nil, fmt.Errorf("invalid server url: %s", config.ServerUrl) @@ -118,31 +119,26 @@ func (reg *Consul) NewClient(config *consul.Config) (*Client, error) { }, nil } +type consulService struct { + DataCenter string + ServiceName string +} + func (reg *Consul) Start(c registrytype.RegistryConfig) error { config := c.(*consul.Config) - if reg.client == nil { + client, err := reg.clientFactory.NewClient(config) + if err != nil { + return err + } - client, err := reg.NewClient(config) - if err != nil { - return err - } + reg.client = client - reg.client = client - } services, err := reg.fetchAllServices(reg.client) if err != nil { return err } - //for key := range services { - // err = reg.subscribe(key.ServiceName) - // if err != nil { - // reg.logger.Errorf("failed to subscribe service, err: %v, service: %v", err, key) - // - // delete(services, key) - // } - //} reg.watchingServices = services @@ -163,13 +159,14 @@ func (reg *Consul) Start(c registrytype.RegistryConfig) error { select { case <-reg.done: reg.logger.Infof("stop refreshing services") - //wait to retry - time.Sleep(SleepTime) + return default: } services, meta, err := reg.client.consulCatalog.Services(q) if err != nil { reg.logger.Errorf("failed to get services, err: %v", err) + time.Sleep(dur) + continue } reg.refresh(services) @@ -263,46 +260,3 @@ func (reg *Consul) refresh(services map[string][]string) { } } - -func (reg *Consul) generateServiceEntry(host string, services []model.SubscribeService) *registry.ServiceEntryWrapper { - portList := make([]*istioapi.ServicePort, 0, 1) - endpoints := make([]*istioapi.WorkloadEntry, 0, len(services)) - - for _, service := range services { - protocol := registry.HTTP - if service.Metadata == nil { - service.Metadata = make(map[string]string) - } - - if service.Metadata["protocol"] != "" { - protocol = registry.ParseProtocol(service.Metadata["protocol"]) - } - - port := &istioapi.ServicePort{ - Name: string(protocol), - Number: uint32(service.Port), - Protocol: string(protocol), - } - if len(portList) == 0 { - portList = append(portList, port) - } - - endpoint := istioapi.WorkloadEntry{ - Address: service.Ip, - Ports: map[string]uint32{port.Protocol: port.Number}, - Labels: service.Metadata, - } - endpoints = append(endpoints, &endpoint) - } - - return ®istry.ServiceEntryWrapper{ - ServiceEntry: istioapi.ServiceEntry{ - Hosts: []string{host}, - Ports: portList, - Location: istioapi.ServiceEntry_MESH_INTERNAL, - Resolution: istioapi.ServiceEntry_STATIC, - Endpoints: endpoints, - }, - Source: RegistryType, - } -} diff --git a/controller/registries/consul/config_test.go b/controller/registries/consul/config_test.go index c4cb6ab4..5e954cdc 100644 --- a/controller/registries/consul/config_test.go +++ b/controller/registries/consul/config_test.go @@ -18,25 +18,22 @@ import ( "testing" "github.com/hashicorp/consul/api" - "github.com/nacos-group/nacos-sdk-go/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "google.golang.org/protobuf/proto" - istioapi "istio.io/api/networking/v1alpha3" - "mosn.io/htnn/controller/pkg/registry" "mosn.io/htnn/controller/pkg/registry/log" "mosn.io/htnn/types/registries/consul" ) func TestNewClient(t *testing.T) { - reg := &Consul{} + reg := &Consul{ + clientFactory: factory, + } config := &consul.Config{ ServerUrl: "http://127.0.0.1:8500", DataCenter: "test", } - client, err := reg.NewClient(config) + client, err := reg.clientFactory.NewClient(config) assert.NoError(t, err) assert.NotNil(t, client) @@ -46,7 +43,7 @@ func TestNewClient(t *testing.T) { DataCenter: "test", } - client, err = reg.NewClient(config) + client, err = reg.clientFactory.NewClient(config) assert.Error(t, err) assert.Nil(t, client) @@ -61,8 +58,23 @@ func (m *MockConsulCatalog) Services(q *api.QueryOptions) (map[string][]string, return nil, nil, nil } +type MockClientFactory struct { + mock.Mock +} + +func (f *MockClientFactory) NewClient(config *consul.Config) (*Client, error) { + mockConsulCatalog := new(MockConsulCatalog) + return &Client{ + consulCatalog: mockConsulCatalog, + DataCenter: "dc1", + NameSpace: "ns1", + Token: "token", + }, nil +} + func TestStart(t *testing.T) { mockConsulCatalog := new(MockConsulCatalog) + cf := new(MockClientFactory) client := &Client{ consulCatalog: mockConsulCatalog, DataCenter: "dc1", @@ -74,14 +86,14 @@ func TestStart(t *testing.T) { logger: log.NewLogger(&log.RegistryLoggerOptions{ Name: "test", }), - client: client, - done: make(chan struct{}), + done: make(chan struct{}), + clientFactory: cf, } config := &consul.Config{} mockConsulCatalog.On("Services", mock.Anything).Return(map[string][]string{"service1": {"dc1"}}, &api.QueryMeta{}, nil) - + reg.client = client err := reg.Start(config) assert.NoError(t, err) @@ -98,7 +110,8 @@ func TestStart(t *testing.T) { logger: log.NewLogger(&log.RegistryLoggerOptions{ Name: "test", }), - done: make(chan struct{}), + done: make(chan struct{}), + clientFactory: factory, } err = reg.Start(config) @@ -125,12 +138,13 @@ func TestRefresh(t *testing.T) { softDeletedServices: map[consulService]bool{}, done: make(chan struct{}), watchingServices: map[consulService]bool{}, + clientFactory: factory, } config := &consul.Config{ ServerUrl: "http://127.0.0.1:8500", } - client, _ := reg.NewClient(config) + client, _ := reg.clientFactory.NewClient(config) reg.client = client services := map[string][]string{ "service1": {"dc1", "dc2"}, @@ -164,50 +178,6 @@ func TestRefresh(t *testing.T) { } -func TestGenerateServiceEntry(t *testing.T) { - host := "test.default-group.public.earth.nacos" - reg := &Consul{} - - type test struct { - name string - services []model.SubscribeService - port *istioapi.ServicePort - endpoint *istioapi.WorkloadEntry - } - tests := []test{} - for input, proto := range registry.ProtocolMap { - s := string(proto) - tests = append(tests, test{ - name: input, - services: []model.SubscribeService{ - {Port: 80, Ip: "1.1.1.1", Metadata: map[string]string{ - "protocol": input, - }}, - }, - port: &istioapi.ServicePort{ - Name: s, - Protocol: s, - Number: 80, - }, - endpoint: &istioapi.WorkloadEntry{ - Address: "1.1.1.1", - Ports: map[string]uint32{s: 80}, - Labels: map[string]string{ - "protocol": input, - }, - }, - }) - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - se := reg.generateServiceEntry(host, tt.services) - require.True(t, proto.Equal(se.ServiceEntry.Ports[0], tt.port)) - require.True(t, proto.Equal(se.ServiceEntry.Endpoints[0], tt.endpoint)) - }) - } -} - func TestFetchAllServices(t *testing.T) { mockConsulCatalog := new(MockConsulCatalog) client := &Client{ From d7de52d5d920f210e86871d6792e68cb734fcc2a Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Wed, 7 Aug 2024 00:45:01 +0800 Subject: [PATCH 41/42] add more tests --- controller/go.mod | 5 +- controller/go.sum | 4 + controller/registries/consul/config.go | 57 ++----- controller/registries/consul/config_test.go | 162 +++++++++++--------- 4 files changed, 111 insertions(+), 117 deletions(-) diff --git a/controller/go.mod b/controller/go.mod index 4743507a..2c94b88b 100644 --- a/controller/go.mod +++ b/controller/go.mod @@ -30,6 +30,7 @@ require ( github.com/nacos-group/nacos-sdk-go v1.1.4 github.com/onsi/ginkgo/v2 v2.17.2 github.com/onsi/gomega v1.33.0 + github.com/smartystreets/goconvey v1.6.4 github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.9.0 go.uber.org/zap v1.27.0 @@ -83,6 +84,7 @@ require ( github.com/google/gofuzz v1.2.0 // indirect github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect @@ -97,6 +99,7 @@ require ( github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect @@ -119,12 +122,12 @@ require ( github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect + github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect - github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tchap/go-patricia/v2 v2.3.1 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect diff --git a/controller/go.sum b/controller/go.sum index beb8b06e..c2447b8b 100644 --- a/controller/go.sum +++ b/controller/go.sum @@ -153,6 +153,7 @@ github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6/go.mod h1:kf6iHlnVGwg github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= @@ -218,6 +219,7 @@ github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCV github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= @@ -333,8 +335,10 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= diff --git a/controller/registries/consul/config.go b/controller/registries/consul/config.go index 342b4db5..29bc4d5c 100644 --- a/controller/registries/consul/config.go +++ b/controller/registries/consul/config.go @@ -40,36 +40,17 @@ func init() { name: om.Name, softDeletedServices: map[consulService]bool{}, done: make(chan struct{}), - clientFactory: factory, } return reg, nil }) } -var ( - //RegistryType = "consul" - factory ClientFactory = &DefaultClientFactory{} -) - -type consulCatalog interface { - Services(q *consulapi.QueryOptions) (map[string][]string, *consulapi.QueryMeta, error) -} - -type ConsulAPI struct { - client *consulapi.Client -} - -func (c *ConsulAPI) Services(q *consulapi.QueryOptions) (map[string][]string, *consulapi.QueryMeta, error) { - return c.client.Catalog().Services(q) -} - type Consul struct { consul.RegistryType - logger log.RegistryLogger - store registry.ServiceEntryStore - name string - client *Client - clientFactory ClientFactory + logger log.RegistryLogger + store registry.ServiceEntryStore + name string + client *Client lock sync.RWMutex watchingServices map[consulService]bool @@ -81,20 +62,14 @@ type Consul struct { type Client struct { consulClient *consulapi.Client - consulCatalog consulCatalog + consulCatalog *consulapi.Catalog DataCenter string NameSpace string Token string } -type ClientFactory interface { - NewClient(config *consul.Config) (*Client, error) -} - -type DefaultClientFactory struct{} - -func (f *DefaultClientFactory) NewClient(config *consul.Config) (*Client, error) { +func (reg *Consul) NewClient(config *consul.Config) (*Client, error) { uri, err := url.Parse(config.ServerUrl) if err != nil { return nil, fmt.Errorf("invalid server url: %s", config.ServerUrl) @@ -120,21 +95,21 @@ func (f *DefaultClientFactory) NewClient(config *consul.Config) (*Client, error) } type consulService struct { - DataCenter string + Tag string ServiceName string } func (reg *Consul) Start(c registrytype.RegistryConfig) error { config := c.(*consul.Config) - client, err := reg.clientFactory.NewClient(config) + client, err := reg.NewClient(config) if err != nil { return err } reg.client = client - services, err := reg.fetchAllServices(reg.client) + services, err := reg.FetchAllServices(reg.client) if err != nil { return err @@ -193,7 +168,7 @@ func (reg *Consul) Reload(c registrytype.RegistryConfig) error { return nil } -func (reg *Consul) fetchAllServices(client *Client) (map[consulService]bool, error) { +func (reg *Consul) FetchAllServices(client *Client) (map[consulService]bool, error) { q := &consulapi.QueryOptions{} q.Datacenter = client.DataCenter q.Namespace = client.NameSpace @@ -205,10 +180,10 @@ func (reg *Consul) fetchAllServices(client *Client) (map[consulService]bool, err return nil, err } serviceMap := make(map[consulService]bool) - for serviceName, dataCenters := range services { - for _, dc := range dataCenters { + for serviceName, tags := range services { + for _, tag := range tags { service := consulService{ - DataCenter: dc, + Tag: tag, ServiceName: serviceName, } serviceMap[service] = true @@ -229,10 +204,10 @@ func (reg *Consul) unsubscribe(serviceName string) error { func (reg *Consul) refresh(services map[string][]string) { serviceMap := make(map[consulService]bool) - for serviceName, dataCenters := range services { - for _, dc := range dataCenters { + for serviceName, tags := range services { + for _, tag := range tags { service := consulService{ - DataCenter: dc, + Tag: tag, ServiceName: serviceName, } serviceMap[service] = true diff --git a/controller/registries/consul/config_test.go b/controller/registries/consul/config_test.go index 5e954cdc..278cdcd3 100644 --- a/controller/registries/consul/config_test.go +++ b/controller/registries/consul/config_test.go @@ -15,25 +15,26 @@ package consul import ( + "errors" + "reflect" "testing" + "github.com/agiledragon/gomonkey/v2" "github.com/hashicorp/consul/api" + . "github.com/smartystreets/goconvey/convey" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "mosn.io/htnn/controller/pkg/registry/log" "mosn.io/htnn/types/registries/consul" ) func TestNewClient(t *testing.T) { - reg := &Consul{ - clientFactory: factory, - } + reg := &Consul{} config := &consul.Config{ ServerUrl: "http://127.0.0.1:8500", DataCenter: "test", } - client, err := reg.clientFactory.NewClient(config) + client, err := reg.NewClient(config) assert.NoError(t, err) assert.NotNil(t, client) @@ -43,78 +44,53 @@ func TestNewClient(t *testing.T) { DataCenter: "test", } - client, err = reg.clientFactory.NewClient(config) + client, err = reg.NewClient(config) assert.Error(t, err) assert.Nil(t, client) } -type MockConsulCatalog struct { - mock.Mock -} - -// Services is a mock method for ConsulCatalog.Services -func (m *MockConsulCatalog) Services(q *api.QueryOptions) (map[string][]string, *api.QueryMeta, error) { - return nil, nil, nil -} - -type MockClientFactory struct { - mock.Mock -} - -func (f *MockClientFactory) NewClient(config *consul.Config) (*Client, error) { - mockConsulCatalog := new(MockConsulCatalog) - return &Client{ - consulCatalog: mockConsulCatalog, - DataCenter: "dc1", - NameSpace: "ns1", - Token: "token", - }, nil -} - func TestStart(t *testing.T) { - mockConsulCatalog := new(MockConsulCatalog) - cf := new(MockClientFactory) - client := &Client{ - consulCatalog: mockConsulCatalog, - DataCenter: "dc1", - NameSpace: "ns1", - Token: "token", - } - reg := &Consul{ logger: log.NewLogger(&log.RegistryLoggerOptions{ Name: "test", }), - done: make(chan struct{}), - clientFactory: cf, + done: make(chan struct{}), } - config := &consul.Config{} + Convey("Test Start method", t, func() { - mockConsulCatalog.On("Services", mock.Anything).Return(map[string][]string{"service1": {"dc1"}}, &api.QueryMeta{}, nil) - reg.client = client - err := reg.Start(config) - assert.NoError(t, err) + patches := gomonkey.ApplyMethod(reflect.TypeOf(reg), "FetchAllServices", func(_ *Consul, client *Client) (map[consulService]bool, error) { + return map[consulService]bool{ + {ServiceName: "service1", Tag: "tag1"}: true, + {ServiceName: "service2", Tag: "tag2"}: true, + }, nil + }) + defer patches.Reset() + config := &consul.Config{} + err := reg.Start(config) + So(err, ShouldBeNil) + err = reg.subscribe("123") + So(err, ShouldBeNil) + + err = reg.unsubscribe("123") + So(err, ShouldBeNil) - err = reg.subscribe("123") - assert.Nil(t, err) + err = reg.Stop() + So(err, ShouldBeNil) - err = reg.unsubscribe("123") - assert.Nil(t, err) + }) - err = reg.Stop() - assert.Nil(t, err) + config := &consul.Config{} reg = &Consul{ logger: log.NewLogger(&log.RegistryLoggerOptions{ Name: "test", }), - done: make(chan struct{}), - clientFactory: factory, + done: make(chan struct{}), } - err = reg.Start(config) + err := reg.Start(config) assert.Error(t, err) close(reg.done) @@ -138,13 +114,12 @@ func TestRefresh(t *testing.T) { softDeletedServices: map[consulService]bool{}, done: make(chan struct{}), watchingServices: map[consulService]bool{}, - clientFactory: factory, } config := &consul.Config{ ServerUrl: "http://127.0.0.1:8500", } - client, _ := reg.clientFactory.NewClient(config) + client, _ := reg.NewClient(config) reg.client = client services := map[string][]string{ "service1": {"dc1", "dc2"}, @@ -154,9 +129,9 @@ func TestRefresh(t *testing.T) { reg.refresh(services) assert.Len(t, reg.watchingServices, 3) - assert.Contains(t, reg.watchingServices, consulService{ServiceName: "service1", DataCenter: "dc1"}) - assert.Contains(t, reg.watchingServices, consulService{ServiceName: "service1", DataCenter: "dc2"}) - assert.Contains(t, reg.watchingServices, consulService{ServiceName: "service2", DataCenter: "dc1"}) + assert.Contains(t, reg.watchingServices, consulService{ServiceName: "service1", Tag: "dc1"}) + assert.Contains(t, reg.watchingServices, consulService{ServiceName: "service1", Tag: "dc2"}) + assert.Contains(t, reg.watchingServices, consulService{ServiceName: "service2", Tag: "dc1"}) assert.Empty(t, reg.softDeletedServices) reg = &Consul{ @@ -165,7 +140,7 @@ func TestRefresh(t *testing.T) { }), softDeletedServices: map[consulService]bool{}, watchingServices: map[consulService]bool{ - {ServiceName: "service1", DataCenter: "dc1"}: true, + {ServiceName: "service1", Tag: "dc1"}: true, }, } @@ -179,19 +154,56 @@ func TestRefresh(t *testing.T) { } func TestFetchAllServices(t *testing.T) { - mockConsulCatalog := new(MockConsulCatalog) - client := &Client{ - consulCatalog: mockConsulCatalog, - DataCenter: "dc1", - NameSpace: "ns1", - Token: "token", - } - - reg := &Consul{} - services, err := reg.fetchAllServices(client) - if err != nil { - return - } - - assert.Equal(t, map[consulService]bool{}, services) + Convey("Test FetchAllServices method", t, func() { + reg := &Consul{ + logger: log.NewLogger(&log.RegistryLoggerOptions{ + Name: "test", + }), + } + client := &Client{ + consulCatalog: &api.Catalog{}, + DataCenter: "dc1", + NameSpace: "ns1", + Token: "token", + } + + patches := gomonkey.ApplyMethod(reflect.TypeOf(client.consulCatalog), "Services", func(_ *api.Catalog, q *api.QueryOptions) (map[string][]string, *api.QueryMeta, error) { + return map[string][]string{ + "service1": {"tag1", "tag2"}, + "service2": {"tag3"}, + }, nil, nil + }) + defer patches.Reset() + + services, err := reg.FetchAllServices(client) + So(err, ShouldBeNil) + So(services, ShouldNotBeNil) + So(services[consulService{ServiceName: "service1", Tag: "tag1"}], ShouldBeTrue) + So(services[consulService{ServiceName: "service1", Tag: "tag2"}], ShouldBeTrue) + So(services[consulService{ServiceName: "service2", Tag: "tag3"}], ShouldBeTrue) + }) + + Convey("Test FetchAllServices method with error", t, func() { + reg := &Consul{ + logger: log.NewLogger(&log.RegistryLoggerOptions{ + Name: "test", + }), + } + client := &Client{ + consulCatalog: &api.Catalog{}, + DataCenter: "dc1", + NameSpace: "ns1", + Token: "token", + } + + patches := gomonkey.ApplyMethod(reflect.TypeOf(client.consulCatalog), "Services", func(_ *api.Catalog, q *api.QueryOptions) (map[string][]string, *api.QueryMeta, error) { + return nil, nil, errors.New("mock error") + }) + defer patches.Reset() + + services, err := reg.FetchAllServices(client) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldEqual, "mock error") + So(services, ShouldBeNil) + }) } From 49c59091c38ecfb1e99741cad6f371e6db8f1fd4 Mon Sep 17 00:00:00 2001 From: lyt122 <2747177214@qq.com> Date: Wed, 7 Aug 2024 13:00:51 +0800 Subject: [PATCH 42/42] write test codes in one style --- controller/go.mod | 4 -- controller/go.sum | 4 -- controller/registries/consul/config.go | 4 +- controller/registries/consul/config_test.go | 64 ++++++++++----------- 4 files changed, 32 insertions(+), 44 deletions(-) diff --git a/controller/go.mod b/controller/go.mod index 2c94b88b..be7ad08e 100644 --- a/controller/go.mod +++ b/controller/go.mod @@ -30,7 +30,6 @@ require ( github.com/nacos-group/nacos-sdk-go v1.1.4 github.com/onsi/ginkgo/v2 v2.17.2 github.com/onsi/gomega v1.33.0 - github.com/smartystreets/goconvey v1.6.4 github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.9.0 go.uber.org/zap v1.27.0 @@ -84,7 +83,6 @@ require ( github.com/google/gofuzz v1.2.0 // indirect github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect @@ -99,7 +97,6 @@ require ( github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect @@ -122,7 +119,6 @@ require ( github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect - github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect diff --git a/controller/go.sum b/controller/go.sum index c2447b8b..beb8b06e 100644 --- a/controller/go.sum +++ b/controller/go.sum @@ -153,7 +153,6 @@ github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6/go.mod h1:kf6iHlnVGwg github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= @@ -219,7 +218,6 @@ github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCV github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= @@ -335,10 +333,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= diff --git a/controller/registries/consul/config.go b/controller/registries/consul/config.go index 29bc4d5c..417a656b 100644 --- a/controller/registries/consul/config.go +++ b/controller/registries/consul/config.go @@ -109,7 +109,7 @@ func (reg *Consul) Start(c registrytype.RegistryConfig) error { reg.client = client - services, err := reg.FetchAllServices(reg.client) + services, err := reg.fetchAllServices(reg.client) if err != nil { return err @@ -168,7 +168,7 @@ func (reg *Consul) Reload(c registrytype.RegistryConfig) error { return nil } -func (reg *Consul) FetchAllServices(client *Client) (map[consulService]bool, error) { +func (reg *Consul) fetchAllServices(client *Client) (map[consulService]bool, error) { q := &consulapi.QueryOptions{} q.Datacenter = client.DataCenter q.Namespace = client.NameSpace diff --git a/controller/registries/consul/config_test.go b/controller/registries/consul/config_test.go index 278cdcd3..bbbc1753 100644 --- a/controller/registries/consul/config_test.go +++ b/controller/registries/consul/config_test.go @@ -21,7 +21,6 @@ import ( "github.com/agiledragon/gomonkey/v2" "github.com/hashicorp/consul/api" - . "github.com/smartystreets/goconvey/convey" "github.com/stretchr/testify/assert" "mosn.io/htnn/controller/pkg/registry/log" @@ -58,30 +57,27 @@ func TestStart(t *testing.T) { done: make(chan struct{}), } - Convey("Test Start method", t, func() { - - patches := gomonkey.ApplyMethod(reflect.TypeOf(reg), "FetchAllServices", func(_ *Consul, client *Client) (map[consulService]bool, error) { - return map[consulService]bool{ - {ServiceName: "service1", Tag: "tag1"}: true, - {ServiceName: "service2", Tag: "tag2"}: true, - }, nil - }) - defer patches.Reset() - config := &consul.Config{} - err := reg.Start(config) - So(err, ShouldBeNil) - err = reg.subscribe("123") - So(err, ShouldBeNil) + patches := gomonkey.ApplyPrivateMethod(reflect.TypeOf(reg), "fetchAllServices", func(_ *Consul, client *Client) (map[consulService]bool, error) { + return map[consulService]bool{ + {ServiceName: "service1", Tag: "tag1"}: true, + {ServiceName: "service2", Tag: "tag2"}: true, + }, nil + }) + config := &consul.Config{} + err := reg.Start(config) + assert.Nil(t, err) + err = reg.subscribe("123") + assert.Nil(t, err) - err = reg.unsubscribe("123") - So(err, ShouldBeNil) + err = reg.unsubscribe("123") + assert.Nil(t, err) - err = reg.Stop() - So(err, ShouldBeNil) + err = reg.Stop() + assert.Nil(t, err) - }) + patches.Reset() - config := &consul.Config{} + config = &consul.Config{} reg = &Consul{ logger: log.NewLogger(&log.RegistryLoggerOptions{ @@ -90,7 +86,7 @@ func TestStart(t *testing.T) { done: make(chan struct{}), } - err := reg.Start(config) + err = reg.Start(config) assert.Error(t, err) close(reg.done) @@ -154,7 +150,7 @@ func TestRefresh(t *testing.T) { } func TestFetchAllServices(t *testing.T) { - Convey("Test FetchAllServices method", t, func() { + t.Run("Test fetchAllServices method", func(t *testing.T) { reg := &Consul{ logger: log.NewLogger(&log.RegistryLoggerOptions{ Name: "test", @@ -175,15 +171,15 @@ func TestFetchAllServices(t *testing.T) { }) defer patches.Reset() - services, err := reg.FetchAllServices(client) - So(err, ShouldBeNil) - So(services, ShouldNotBeNil) - So(services[consulService{ServiceName: "service1", Tag: "tag1"}], ShouldBeTrue) - So(services[consulService{ServiceName: "service1", Tag: "tag2"}], ShouldBeTrue) - So(services[consulService{ServiceName: "service2", Tag: "tag3"}], ShouldBeTrue) + services, err := reg.fetchAllServices(client) + assert.NoError(t, err) + assert.NotNil(t, services) + assert.True(t, services[consulService{ServiceName: "service1", Tag: "tag1"}]) + assert.True(t, services[consulService{ServiceName: "service1", Tag: "tag2"}]) + assert.True(t, services[consulService{ServiceName: "service2", Tag: "tag3"}]) }) - Convey("Test FetchAllServices method with error", t, func() { + t.Run("Test fetchAllServices method with error", func(t *testing.T) { reg := &Consul{ logger: log.NewLogger(&log.RegistryLoggerOptions{ Name: "test", @@ -201,9 +197,9 @@ func TestFetchAllServices(t *testing.T) { }) defer patches.Reset() - services, err := reg.FetchAllServices(client) - So(err, ShouldNotBeNil) - So(err.Error(), ShouldEqual, "mock error") - So(services, ShouldBeNil) + services, err := reg.fetchAllServices(client) + assert.Error(t, err) + assert.Equal(t, "mock error", err.Error()) + assert.Nil(t, services) }) }