diff --git a/pkg/util/lock/key_lock.go b/pkg/util/lock/key_lock.go index 97910aed7bd14..3bf8d85ce2939 100644 --- a/pkg/util/lock/key_lock.go +++ b/pkg/util/lock/key_lock.go @@ -18,6 +18,7 @@ package lock import ( "sync" + "time" "go.uber.org/zap" @@ -33,8 +34,12 @@ func (m *RefLock) ref() { m.refCounter++ } -func (m *RefLock) unref() { - m.refCounter-- +func (m *RefLock) unref() bool { + if m.refCounter > 0 { + m.refCounter-- + return true + } + return false } func newRefLock() *RefLock { @@ -46,17 +51,39 @@ func newRefLock() *RefLock { } type KeyLock[K comparable] struct { - keyLocksMutex sync.Mutex - refLocks map[K]*RefLock + keyLocksMutex sync.Mutex + refLocks map[K]*RefLock + backgroundGCInterval time.Duration } func NewKeyLock[K comparable]() *KeyLock[K] { + return NewKeyLockWithGCTime[K](5 * time.Second) +} + +func NewKeyLockWithGCTime[K comparable](gcInterval time.Duration) *KeyLock[K] { keyLock := KeyLock[K]{ - refLocks: make(map[K]*RefLock), + refLocks: make(map[K]*RefLock), + backgroundGCInterval: gcInterval, } + keyLock.StartGC() return &keyLock } +func (k *KeyLock[K]) StartGC() { + go func() { + gcTimer := time.NewTimer(k.backgroundGCInterval) + for range gcTimer.C { + k.keyLocksMutex.Lock() + for key, keyLock := range k.refLocks { + if keyLock.refCounter == 0 { + delete(k.refLocks, key) + } + } + k.keyLocksMutex.Unlock() + } + }() +} + func (k *KeyLock[K]) Lock(key K) { k.keyLocksMutex.Lock() // update the key map @@ -84,9 +111,10 @@ func (k *KeyLock[K]) Unlock(lockedKey K) { log.Warn("Unlocking non-existing key", zap.Any("key", lockedKey)) return } - keyLock.unref() - if keyLock.refCounter == 0 { - delete(k.refLocks, lockedKey) + success := keyLock.unref() + if !success { + log.Warn("Unlocking non-locked key", zap.Any("key", lockedKey)) + return } keyLock.mutex.Unlock() } @@ -118,9 +146,10 @@ func (k *KeyLock[K]) RUnlock(lockedKey K) { log.Warn("Unlocking non-existing key", zap.Any("key", lockedKey)) return } - keyLock.unref() - if keyLock.refCounter == 0 { - delete(k.refLocks, lockedKey) + success := keyLock.unref() + if !success { + log.Warn("Unlocking non-locked key", zap.Any("key", lockedKey)) + return } keyLock.mutex.RUnlock() } @@ -128,5 +157,11 @@ func (k *KeyLock[K]) RUnlock(lockedKey K) { func (k *KeyLock[K]) size() int { k.keyLocksMutex.Lock() defer k.keyLocksMutex.Unlock() - return len(k.refLocks) + s := 0 + for _, keyLock := range k.refLocks { + if keyLock.refCounter > 0 { + s++ + } + } + return s } diff --git a/pkg/util/lock/key_lock_test.go b/pkg/util/lock/key_lock_test.go index 46002b9ed4176..37ba4bc558b0d 100644 --- a/pkg/util/lock/key_lock_test.go +++ b/pkg/util/lock/key_lock_test.go @@ -67,3 +67,23 @@ func TestKeyRLock(t *testing.T) { wg.Wait() assert.Equal(t, keyLock.size(), 0) } + +func TestNewKeyLock(t *testing.T) { + keyLock := NewKeyLockWithGCTime[string](time.Second) + keyLock.Lock("a") + keyLock.Lock("b") + + keyLock.Unlock("a") + keyLock.Unlock("b") + + assert.Equal(t, 0, keyLock.size()) + keyLock.keyLocksMutex.Lock() + keyLen := len(keyLock.refLocks) + keyLock.keyLocksMutex.Unlock() + assert.Equal(t, 2, keyLen) + time.Sleep(2 * time.Second) + keyLock.keyLocksMutex.Lock() + keyLen = len(keyLock.refLocks) + keyLock.keyLocksMutex.Unlock() + assert.Equal(t, 0, keyLen) +}