diff --git a/.forgejo/workflows/testing.yml b/.forgejo/workflows/testing.yml
index b1ab74ab3d..199e616daa 100644
--- a/.forgejo/workflows/testing.yml
+++ b/.forgejo/workflows/testing.yml
@@ -172,6 +172,8 @@ jobs:
       cacher:
         image: ${{ matrix.cacher.image }}
         options: ${{ matrix.cacher.options }}
+        env:
+          ALLOW_EMPTY_PASSWORD: "yes" # redis & valkey will immediately shut down with no defined password unless overridden
     steps:
       - uses: https://data.forgejo.org/actions/checkout@v5
       - uses: ./.forgejo/workflows-composite/setup-env
@@ -186,7 +188,7 @@ jobs:
         env:
           RACE_ENABLED: 'true'
           TAGS: bindata
-          TEST_REDIS_SERVER: cacher:${{ matrix.cacher.port }}
+          TEST_REDIS_SERVER: cacher:6379
   test-mysql:
     if: vars.ROLE == 'forgejo-coding' || vars.ROLE == 'forgejo-testing'
     runs-on: docker
diff --git a/modules/cache/cache.go b/modules/cache/cache.go
index 9bf4e9a00e..cdd9179623 100644
--- a/modules/cache/cache.go
+++ b/modules/cache/cache.go
@@ -9,6 +9,7 @@ import (
 	"strconv"
 	"time"

+	"forgejo.org/modules/log"
 	"forgejo.org/modules/setting"

 	mc "code.forgejo.org/go-chi/cache"
@@ -16,7 +17,11 @@ import (
 	_ "code.forgejo.org/go-chi/cache/memcache" // memcache plugin for cache
 )

-var conn mc.Cache
+var (
+	conn             mc.Cache
+	ErrInconvertible = errors.New("value from cache was not convertible to expected type")
+	mutexMap         MutexMap
+)

 func newCache(cacheConfig setting.Cache) (mc.Cache, error) {
 	return mc.NewCacher(mc.Options{
@@ -78,102 +83,104 @@ func GetCache() mc.Cache {
 	return conn
 }

-// GetString returns the key value from cache with callback when no key exists in cache
-func GetString(key string, getFunc func() (string, error)) (string, error) {
+// concurrencySafeGet is a single-process, concurrency-safe fetch from the cache, which provides the guarantee that after
+// calling `cache.Remove(key)` and then `cache.Get*(key, ...)`, the value returned from cache will never have been
+// computed **before** the `Remove` invocation. It uses in-memory synchronization, so its guarantee does not extend to
+// a clustered configuration.
+//
+// getFunc is the computation for the value if caching is not available. convertFunc converts the cached value into the
+// target type, and can return `ErrInconvertible` to indicate that the value couldn't be converted and should be
+// recomputed instead; other errors are passed through.
+func concurrencySafeGet[T any](key string, getFunc func() (T, error), convertFunc func(v any) (T, error)) (T, error) {
 	if conn == nil || setting.CacheService.TTL <= 0 {
 		return getFunc()
 	}

+	// Use a double-checking method -- once before acquiring the write lock on this key (this block), and then again
+	// afterwards to avoid calling `getFunc` if it was computed while we were acquiring the lock. This causes two cache
+	// hits as a trade-off to minimize the number of lock acquisitions. If this trade-off causes too much cache load,
+	// this first `Get` could be removed -- the second one is performance-critical to ensure that after waiting a "long
+	// time" to compute w/ `getFunc`, we don't immediately redo that work after acquiring the lock.
 	cached := conn.Get(key)
-
-	if cached == nil {
-		value, err := getFunc()
-		if err != nil {
-			return value, err
+	if cached != nil {
+		retval, err := convertFunc(cached)
+		if err == nil {
+			return retval, nil
+		} else if !errors.Is(err, ErrInconvertible) { // for ErrInconvertible we'll fall through to recalculating the value
+			var zero T
+			return zero, err
 		}
-		return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
 	}

-	if value, ok := cached.(string); ok {
-		return value, nil
+	defer mutexMap.Lock(key)()
+
+	// The second, performance-critical check of whether the cache contains the target value.
+	cached = conn.Get(key)
+	if cached != nil {
+		retval, err := convertFunc(cached)
+		if err == nil {
+			return retval, nil
+		} else if !errors.Is(err, ErrInconvertible) { // for ErrInconvertible we'll fall through to recalculating the value
+			var zero T
+			return zero, err
+		}
 	}

-	if stringer, ok := cached.(fmt.Stringer); ok {
-		return stringer.String(), nil
+	value, err := getFunc()
+	if err != nil {
+		return value, err
 	}
+	return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
+}

-	return fmt.Sprintf("%s", cached), nil
+// GetString returns the key value from cache with callback when no key exists in cache
+func GetString(key string, getFunc func() (string, error)) (string, error) {
+	v, err := concurrencySafeGet(key, getFunc, func(cached any) (string, error) {
+		if value, ok := cached.(string); ok {
+			return value, nil
+		}
+		if stringer, ok := cached.(fmt.Stringer); ok {
+			return stringer.String(), nil
+		}
+		return fmt.Sprintf("%s", cached), nil
+	})
+	return v, err
 }

 // GetInt returns key value from cache with callback when no key exists in cache
 func GetInt(key string, getFunc func() (int, error)) (int, error) {
-	if conn == nil || setting.CacheService.TTL <= 0 {
-		return getFunc()
-	}
-
-	cached := conn.Get(key)
-
-	if cached == nil {
-		value, err := getFunc()
-		if err != nil {
-			return value, err
+	v, err := concurrencySafeGet(key, getFunc, func(cached any) (int, error) {
+		switch v := cached.(type) {
+		case int:
+			return v, nil
+		case string:
+			value, err := strconv.Atoi(v)
+			if err != nil {
+				return 0, err
+			}
+			return value, nil
 		}
-
-		return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
-	}
-
-	switch v := cached.(type) {
-	case int:
-		return v, nil
-	case string:
-		value, err := strconv.Atoi(v)
-		if err != nil {
-			return 0, err
-		}
-		return value, nil
-	default:
-		value, err := getFunc()
-		if err != nil {
-			return value, err
-		}
-		return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
-	}
+		return 0, ErrInconvertible
+	})
+	return v, err
 }

 // GetInt64 returns key value from cache with callback when no key exists in cache
 func GetInt64(key string, getFunc func() (int64, error)) (int64, error) {
-	if conn == nil || setting.CacheService.TTL <= 0 {
-		return getFunc()
-	}
-
-	cached := conn.Get(key)
-
-	if cached == nil {
-		value, err := getFunc()
-		if err != nil {
-			return value, err
+	v, err := concurrencySafeGet(key, getFunc, func(cached any) (int64, error) {
+		switch v := cached.(type) {
+		case int64:
+			return v, nil
+		case string:
+			value, err := strconv.ParseInt(v, 10, 64)
+			if err != nil {
+				return 0, err
+			}
+			return value, nil
 		}
-
-		return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
-	}
-
-	switch v := conn.Get(key).(type) {
-	case int64:
-		return v, nil
-	case string:
-		value, err := strconv.ParseInt(v, 10, 64)
-		if err != nil {
-			return 0, err
-		}
-		return value, nil
-	default:
-		value, err := getFunc()
-		if err != nil {
-			return value, err
-		}
-
-		return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
-	}
+		return 0, ErrInconvertible
+	})
+	return v, err
 }

 // Remove key from cache
@@ -181,5 +188,15 @@ func Remove(key string) {
 	if conn == nil {
 		return
 	}
-	_ = conn.Delete(key)
+
+	// The goal of `Remove(key)` is to ensure that *after* it is completed, a new value is computed. It's possible that
+	// a value is being computed for the key *right now* -- `getFunc` is about to return, we're about to delete the key,
+	// and then it will be Put into the cache with an out-of-date value computed before the `Remove(key)`. To prevent
+	// this we need the `Remove(key)` to also lock on the key, just like `Get*(key, ...)` does when computing it.
+	defer mutexMap.Lock(key)()
+
+	err := conn.Delete(key)
+	if err != nil {
+		log.Error("unexpected error deleting key %s from cache: %v", key, err)
+	}
 }
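
The new `concurrencySafeGet` boils down to a classic double-checked pattern: check the cache without the key lock, then lock, re-check, and only then compute and store. A minimal standalone sketch of that shape (the map-backed store and the single global mutex here are illustrative stand-ins, not the real `mc.Cache` connection or per-key `MutexMap`):

```go
package main

import (
	"fmt"
	"sync"
)

// Illustrative stand-ins: a tiny map-backed store instead of mc.Cache, and one
// global mutex instead of the per-key MutexMap, to keep the sketch short.
var (
	storeMu sync.RWMutex
	store   = map[string]int64{}
	keyMu   sync.Mutex
)

func load(key string) (int64, bool) {
	storeMu.RLock()
	defer storeMu.RUnlock()
	v, ok := store[key]
	return v, ok
}

func save(key string, v int64) {
	storeMu.Lock()
	defer storeMu.Unlock()
	store[key] = v
}

// getOrCompute mirrors the shape of concurrencySafeGet:
// check, lock, re-check, compute, store.
func getOrCompute(key string, compute func() (int64, error)) (int64, error) {
	if v, ok := load(key); ok { // first check: avoids taking the lock at all on a hit
		return v, nil
	}
	keyMu.Lock() // the real code locks per key: defer mutexMap.Lock(key)()
	defer keyMu.Unlock()
	if v, ok := load(key); ok { // second check: the value may have been computed
		return v, nil // by whoever held the lock while we waited
	}
	v, err := compute()
	if err != nil {
		return 0, err
	}
	save(key, v)
	return v, nil
}

func main() {
	v, _ := getOrCompute("answer", func() (int64, error) { return 42, nil })
	fmt.Println(v) // 42; concurrent callers for the same key serialize and reuse the first result
}
```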
diff --git a/modules/cache/cache_test.go b/modules/cache/cache_test.go
index aa28c98452..c9225c965c 100644
--- a/modules/cache/cache_test.go
+++ b/modules/cache/cache_test.go
@@ -5,21 +5,44 @@ package cache

 import (
 	"errors"
+	"fmt"
+	"math/rand"
+	"os"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"

 	"forgejo.org/modules/setting"
+	"forgejo.org/modules/test"

+	"code.forgejo.org/go-chi/cache"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )

-func createTestCache() {
-	conn, _ = newCache(setting.Cache{
-		Adapter: "memory",
-		TTL:     time.Minute,
-	})
-	setting.CacheService.TTL = 24 * time.Hour
+func createTestCache(t *testing.T) {
+	var err error
+	var testCache cache.Cache
+
+	testRedisHost := os.Getenv("TEST_REDIS_SERVER")
+	if testRedisHost == "" {
+		testCache, err = newCache(setting.Cache{
+			Adapter: "memory",
+			TTL:     time.Minute,
+		})
+	} else {
+		testCache, err = newCache(setting.Cache{
+			Adapter: "redis",
+			Conn:    fmt.Sprintf("redis://%s", testRedisHost),
+			TTL:     time.Minute,
+		})
+	}
+	require.NoError(t, err)
+	require.NotNil(t, testCache)
+
+	t.Cleanup(test.MockVariableValue(&conn, testCache))
+	t.Cleanup(test.MockVariableValue(&setting.CacheService.TTL, 24*time.Hour))
 }

 func TestNewContext(t *testing.T) {
@@ -36,13 +59,13 @@ func TestNewContext(t *testing.T) {
 }

 func TestGetCache(t *testing.T) {
-	createTestCache()
+	createTestCache(t)

 	assert.NotNil(t, GetCache())
 }

 func TestGetString(t *testing.T) {
-	createTestCache()
+	createTestCache(t)

 	data, err := GetString("key", func() (string, error) {
 		return "", errors.New("some error")
@@ -78,7 +101,7 @@ func TestGetString(t *testing.T) {
 }

 func TestGetInt(t *testing.T) {
-	createTestCache()
+	createTestCache(t)

 	data, err := GetInt("key", func() (int, error) {
 		return 0, errors.New("some error")
@@ -114,7 +137,7 @@ func TestGetInt(t *testing.T) {
 }

 func TestGetInt64(t *testing.T) {
-	createTestCache()
+	createTestCache(t)

 	data, err := GetInt64("key", func() (int64, error) {
 		return 0, errors.New("some error")
@@ -148,3 +171,56 @@ func TestGetInt64(t *testing.T) {
 	assert.EqualValues(t, 100, data)
 	Remove("key")
 }
+
+func TestCacheConcurrencySafety(t *testing.T) {
+	createTestCache(t)
+
+	testRedisHost := os.Getenv("TEST_REDIS_SERVER")
+	if testRedisHost == "" {
+		t.Skip("only relevant for a remote redis host")
+	}
+
+	numTests := 20
+	numIncrements := 1000
+	for testCount := range numTests {
+		t.Run(fmt.Sprintf("attempt:%d", testCount), func(t *testing.T) {
+			var counter atomic.Int64
+			var wg sync.WaitGroup
+			var firstError atomic.Value
+
+			getFunc := func() (int64, error) {
+				lastValue := counter.Load()
+				time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond)
+				return lastValue, nil
+			}
+
+			for range numIncrements {
+				wg.Go(func() {
+					counterValue := counter.Add(1)
+					Remove(t.Name())
+					cachedValue, err := GetInt64(t.Name(), getFunc)
+					if err != nil {
+						firstError.CompareAndSwap(nil, fmt.Sprintf("cache error: %v", err))
+					} else if cachedValue < counterValue {
+						firstError.CompareAndSwap(nil, fmt.Sprintf("incremented to value %d but retrieved value %d from cache", counterValue, cachedValue))
+					}
+				})
+			}
+
+			wg.Wait()
+			require.EqualValues(t, numIncrements, counter.Load())
+			if err := firstError.Load(); err != nil {
+				t.Fatal(err)
+			}
+
+			// Without invalidating the cache, check what was last stored in it.
+			value, err := GetInt64(t.Name(), func() (int64, error) {
+				t.Fatal("getFunc should not be invoked")
+				return 0, errors.New("getFunc should not be invoked")
+			})
+
+			require.NoError(t, err)
+			assert.EqualValues(t, numIncrements, value)
+		})
+	}
+}
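
From a caller's point of view, the property these tests exercise looks like this (a hypothetical caller: the key name and the `countOpenIssues` helper are invented for illustration; only `GetInt64` and `Remove` come from this change):

```go
package main

import (
	"fmt"

	"forgejo.org/modules/cache"
)

// countOpenIssues stands in for an expensive computation (hypothetical helper).
func countOpenIssues() (int64, error) { return 42, nil }

func main() {
	// A cache miss computes the value while holding the per-key lock and then
	// caches it; concurrent callers for the same key wait and reuse the result.
	n, err := cache.GetInt64("open_issue_count", countOpenIssues)
	if err != nil {
		fmt.Println("cache error:", err)
		return
	}
	fmt.Println("count:", n)

	// Remove takes the same per-key lock, so it cannot complete while an older
	// computation is still being stored. Once Remove returns, no later Get*
	// can serve a value that was computed before this call.
	cache.Remove("open_issue_count")
}
```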
diff --git a/modules/cache/mutex_map.go b/modules/cache/mutex_map.go
new file mode 100644
index 0000000000..beb27f32c8
--- /dev/null
+++ b/modules/cache/mutex_map.go
@@ -0,0 +1,60 @@
+// Copyright 2025 The Forgejo Authors. All rights reserved.
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+package cache
+
+import (
+	"sync"
+)
+
+// MutexMap is basically a map[string]sync.Mutex which allows you to have one mutex per string key being locked. Unlike
+// a map[string]sync.Mutex, this map will automatically remove the Mutexes from itself when they are not being waited
+// for, preventing resource waste. It does this by keeping a reference count of the current Lock calls for the given
+// key.
+type MutexMap struct {
+	mu       sync.Mutex // mutex to be held when accessing mutexMap
+	mutexMap map[string]*refcountMutex
+}
+
+type refcountMutex struct {
+	refCount int // access to refCount is protected by the MutexMap's mu
+
+	sync.Mutex
+}
+
+// Lock locks the given key, and returns a function that must be invoked to unlock the key.
+func (m *MutexMap) Lock(key string) func() {
+	m.mu.Lock()
+	if m.mutexMap == nil {
+		m.mutexMap = make(map[string]*refcountMutex)
+	}
+	mutex, ok := m.mutexMap[key]
+	if !ok {
+		mutex = &refcountMutex{}
+		m.mutexMap[key] = mutex
+	}
+	mutex.refCount++
+	m.mu.Unlock()
+
+	mutex.Lock()
+
+	unlockPending := true
+
+	return func() {
+		if !unlockPending {
+			// unlocking twice would cause incorrect reference counts and might release another goroutine's mutex -- try
+			// to detect and panic so that this programming error can be found closest to the source.
+			panic("MutexMap unlock invoked twice")
+		}
+
+		unlockPending = false
+		mutex.Unlock()
+
+		m.mu.Lock()
+		mutex.refCount--
+		if mutex.refCount == 0 {
+			delete(m.mutexMap, key)
+		}
+		m.mu.Unlock()
+	}
+}
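
Usage of the new type is lock-and-defer-unlock; a small sketch of how the per-key map behaves (illustrative only; the key name is arbitrary):

```go
package main

import (
	"fmt"
	"sync"

	"forgejo.org/modules/cache"
)

func main() {
	var mm cache.MutexMap // the zero value is ready to use; the inner map is created lazily
	var wg sync.WaitGroup

	for i := range 4 {
		wg.Go(func() {
			unlock := mm.Lock("shared-key") // serializes goroutines on this key only
			defer unlock()                  // must run exactly once; a second call panics
			fmt.Println("goroutine", i, "holds shared-key")
		})
	}
	wg.Wait()
	// With no goroutine holding or waiting on "shared-key", its entry has been
	// deleted again, so the map does not grow with the number of keys ever locked.
}
```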
diff --git a/modules/cache/mutex_map_test.go b/modules/cache/mutex_map_test.go
new file mode 100644
index 0000000000..324b3228d8
--- /dev/null
+++ b/modules/cache/mutex_map_test.go
@@ -0,0 +1,131 @@
+// Copyright 2025 The Forgejo Authors. All rights reserved.
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+package cache
+
+import (
+	"math/rand"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestMutexMap_BasicLockUnlock(t *testing.T) {
+	mm := &MutexMap{}
+
+	unlock := mm.Lock("test-key")
+	unlock()
+
+	// Should be able to lock again
+	unlock2 := mm.Lock("test-key")
+	unlock2()
+}
+
+func TestMutexMap_ConcurrentSameKey(t *testing.T) {
+	mm := &MutexMap{}
+	var anotherLockActive atomic.Bool
+	var firstError atomic.Value
+	var wg sync.WaitGroup
+
+	for range 10 {
+		wg.Go(func() {
+			unlock := mm.Lock("shared-key")
+			defer unlock()
+
+			// should *not* find that another goroutine has put `true` into here.
+			swapped := anotherLockActive.CompareAndSwap(false, true)
+			if !swapped {
+				firstError.CompareAndSwap(nil, "anotherLockActive was true!")
+			}
+			time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond) // jitter the goroutines to ensure no serial execution
+			anotherLockActive.Store(false)
+		})
+	}
+
+	wg.Wait()
+
+	if err := firstError.Load(); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestMutexMap_DifferentKeys(t *testing.T) {
+	mm := &MutexMap{}
+	done := make(chan bool, 1)
+
+	go func() {
+		// If these somehow referred to the same underlying `sync.Mutex`, because `sync.Mutex` is not re-entrant this would
+		// never complete.
+		unlock1 := mm.Lock("test-key-1")
+		unlock2 := mm.Lock("test-key-2")
+		unlock3 := mm.Lock("test-key-3")
+		unlock1()
+		unlock2()
+		unlock3()
+		done <- true
+	}()
+
+	select {
+	case <-done:
+		// Success
+	case <-time.After(1 * time.Second): // early timeout so that we don't wait for t.Deadline()
+		t.Fatal("test incomplete after timeout, indicating a locking bug")
+	}
+}
+
+func TestMutexMap_SimpleCleanup(t *testing.T) {
+	mm := &MutexMap{}
+	unlock1 := mm.Lock("test-key-1")
+
+	mm.mu.Lock()
+	assert.Len(t, mm.mutexMap, 1)
+	mm.mu.Unlock()
+
+	unlock1()
+
+	mm.mu.Lock()
+	assert.Empty(t, mm.mutexMap)
+	mm.mu.Unlock()
+}
+
+func TestMutexMap_ConcurrentCleanup(t *testing.T) {
+	mm := &MutexMap{}
+	var foundRefGreaterThanOne atomic.Bool
+	var wg sync.WaitGroup
+
+	for range 10 {
+		wg.Go(func() {
+			unlock := mm.Lock("shared-key")
+			defer unlock()
+
+			time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond) // jitter the goroutines to ensure no serial execution
+
+			mm.mu.Lock()
+			rcMutex := mm.mutexMap["shared-key"]
+			if rcMutex.refCount > 1 {
+				foundRefGreaterThanOne.Store(true)
+			}
+			mm.mu.Unlock()
+		})
+	}
+
+	wg.Wait()
+
+	assert.True(t, foundRefGreaterThanOne.Load(), "expected to find a refCount > 1")
+
+	mm.mu.Lock()
+	assert.Empty(t, mm.mutexMap)
+	mm.mu.Unlock()
+}
+
+func TestMutexMap_UnlockTwice(t *testing.T) {
+	mm := &MutexMap{}
+	assert.Panics(t, func() {
+		unlock := mm.Lock("test")
+		unlock()
+		unlock()
+	})
+}