mirror of
https://codeberg.org/forgejo/forgejo.git
synced 2026-05-12 22:10:25 +00:00
fix: possible cause of invalid issue counts; cache module doesn't guarantee concurrency safety (#10127)
Reviewed-on: https://codeberg.org/forgejo/forgejo/pulls/10127 Reviewed-by: oliverpool <oliverpool@noreply.codeberg.org> Reviewed-by: Gusted <gusted@noreply.codeberg.org>
This commit is contained in:
commit
b464f76931
5 changed files with 373 additions and 87 deletions
|
|
@ -172,6 +172,8 @@ jobs:
|
|||
cacher:
|
||||
image: ${{ matrix.cacher.image }}
|
||||
options: ${{ matrix.cacher.options }}
|
||||
env:
|
||||
ALLOW_EMPTY_PASSWORD: "yes" # redis & valkey will immediately shutdown with no defined password unless overridden
|
||||
steps:
|
||||
- uses: https://data.forgejo.org/actions/checkout@v5
|
||||
- uses: ./.forgejo/workflows-composite/setup-env
|
||||
|
|
@ -186,7 +188,7 @@ jobs:
|
|||
env:
|
||||
RACE_ENABLED: 'true'
|
||||
TAGS: bindata
|
||||
TEST_REDIS_SERVER: cacher:${{ matrix.cacher.port }}
|
||||
TEST_REDIS_SERVER: cacher:6379
|
||||
test-mysql:
|
||||
if: vars.ROLE == 'forgejo-coding' || vars.ROLE == 'forgejo-testing'
|
||||
runs-on: docker
|
||||
|
|
|
|||
169
modules/cache/cache.go
vendored
169
modules/cache/cache.go
vendored
|
|
@ -9,6 +9,7 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"forgejo.org/modules/log"
|
||||
"forgejo.org/modules/setting"
|
||||
|
||||
mc "code.forgejo.org/go-chi/cache"
|
||||
|
|
@ -16,7 +17,11 @@ import (
|
|||
_ "code.forgejo.org/go-chi/cache/memcache" // memcache plugin for cache
|
||||
)
|
||||
|
||||
var conn mc.Cache
|
||||
var (
|
||||
conn mc.Cache
|
||||
ErrInconvertible = errors.New("value from cache was not convertible to expected type")
|
||||
mutexMap MutexMap
|
||||
)
|
||||
|
||||
func newCache(cacheConfig setting.Cache) (mc.Cache, error) {
|
||||
return mc.NewCacher(mc.Options{
|
||||
|
|
@ -78,102 +83,104 @@ func GetCache() mc.Cache {
|
|||
return conn
|
||||
}
|
||||
|
||||
// GetString returns the key value from cache with callback when no key exists in cache
|
||||
func GetString(key string, getFunc func() (string, error)) (string, error) {
|
||||
// concurrencySafeGet is a single-process concurrency safe fetch from the cache, which provides the guarantee that after
|
||||
// calling `cache.Remove(key)` and then `cache.Get*(key, ...)`, the value returned from cache will never have been
|
||||
// computed **before** the `Remove` invocation. It uses in-memory synchronization, so its guarantee does not extend to
|
||||
// a clustered configuration.
|
||||
//
|
||||
// getFunc is the computation for the value if caching is not available. convertFunc converts the cached value into the
|
||||
// target type, and can return `ErrInconvertible` to indicate that the value couldn't be converted and should be
|
||||
// recomputed instead; other errors are passed through.
|
||||
func concurrencySafeGet[T any](key string, getFunc func() (T, error), convertFunc func(v any) (T, error)) (T, error) {
|
||||
if conn == nil || setting.CacheService.TTL <= 0 {
|
||||
return getFunc()
|
||||
}
|
||||
|
||||
// Use a double-checking method -- once before acquiring the write lock on this key (this block), and then again
|
||||
// afterwards to avoid calling `getFunc` if it was computed while we were acquiring the lock. This causes two cache
|
||||
// hits as a trade-off to minimize the number of lock acquisitions. If this trade-off causes too much cache load,
|
||||
// this first `Get` could be removed -- the second one is performance-critical to ensure that after waiting a "long
|
||||
// time" to compute w/ `getFunc`, we don't immediately redo that work after acquiring the lock.
|
||||
cached := conn.Get(key)
|
||||
|
||||
if cached == nil {
|
||||
value, err := getFunc()
|
||||
if err != nil {
|
||||
return value, err
|
||||
if cached != nil {
|
||||
retval, err := convertFunc(cached)
|
||||
if err == nil {
|
||||
return retval, nil
|
||||
} else if !errors.Is(err, ErrInconvertible) { // for ErrInconvertible we'll fall through to recalculating the value
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
|
||||
}
|
||||
|
||||
if value, ok := cached.(string); ok {
|
||||
return value, nil
|
||||
defer mutexMap.Lock(key)()
|
||||
|
||||
// The second, performance-critical, check if the cache contains the target value.
|
||||
cached = conn.Get(key)
|
||||
if cached != nil {
|
||||
retval, err := convertFunc(cached)
|
||||
if err == nil {
|
||||
return retval, nil
|
||||
} else if !errors.Is(err, ErrInconvertible) { // for ErrInconvertible we'll fall through to recalculating the value
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
}
|
||||
|
||||
if stringer, ok := cached.(fmt.Stringer); ok {
|
||||
return stringer.String(), nil
|
||||
value, err := getFunc()
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s", cached), nil
|
||||
// GetString returns the key value from cache with callback when no key exists in cache
|
||||
func GetString(key string, getFunc func() (string, error)) (string, error) {
|
||||
v, err := concurrencySafeGet(key, getFunc, func(cached any) (string, error) {
|
||||
if value, ok := cached.(string); ok {
|
||||
return value, nil
|
||||
}
|
||||
if stringer, ok := cached.(fmt.Stringer); ok {
|
||||
return stringer.String(), nil
|
||||
}
|
||||
return fmt.Sprintf("%s", cached), nil
|
||||
})
|
||||
return v, err
|
||||
}
|
||||
|
||||
// GetInt returns key value from cache with callback when no key exists in cache
|
||||
func GetInt(key string, getFunc func() (int, error)) (int, error) {
|
||||
if conn == nil || setting.CacheService.TTL <= 0 {
|
||||
return getFunc()
|
||||
}
|
||||
|
||||
cached := conn.Get(key)
|
||||
|
||||
if cached == nil {
|
||||
value, err := getFunc()
|
||||
if err != nil {
|
||||
return value, err
|
||||
v, err := concurrencySafeGet(key, getFunc, func(cached any) (int, error) {
|
||||
switch v := cached.(type) {
|
||||
case int:
|
||||
return v, nil
|
||||
case string:
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
|
||||
}
|
||||
|
||||
switch v := cached.(type) {
|
||||
case int:
|
||||
return v, nil
|
||||
case string:
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return value, nil
|
||||
default:
|
||||
value, err := getFunc()
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
|
||||
}
|
||||
return 0, ErrInconvertible
|
||||
})
|
||||
return v, err
|
||||
}
|
||||
|
||||
// GetInt64 returns key value from cache with callback when no key exists in cache
|
||||
func GetInt64(key string, getFunc func() (int64, error)) (int64, error) {
|
||||
if conn == nil || setting.CacheService.TTL <= 0 {
|
||||
return getFunc()
|
||||
}
|
||||
|
||||
cached := conn.Get(key)
|
||||
|
||||
if cached == nil {
|
||||
value, err := getFunc()
|
||||
if err != nil {
|
||||
return value, err
|
||||
v, err := concurrencySafeGet(key, getFunc, func(cached any) (int64, error) {
|
||||
switch v := cached.(type) {
|
||||
case int64:
|
||||
return v, nil
|
||||
case string:
|
||||
value, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
|
||||
}
|
||||
|
||||
switch v := conn.Get(key).(type) {
|
||||
case int64:
|
||||
return v, nil
|
||||
case string:
|
||||
value, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return value, nil
|
||||
default:
|
||||
value, err := getFunc()
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
|
||||
return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
|
||||
}
|
||||
return 0, ErrInconvertible
|
||||
})
|
||||
return v, err
|
||||
}
|
||||
|
||||
// Remove key from cache
|
||||
|
|
@ -181,5 +188,15 @@ func Remove(key string) {
|
|||
if conn == nil {
|
||||
return
|
||||
}
|
||||
_ = conn.Delete(key)
|
||||
|
||||
// The goal of `Remove(key)` is to ensure that *after* it is completed, a new value is computed. It's possible that
|
||||
// a value is being computed for the key *right now* -- `getFunc` is about to return, we're about to delete the key,
|
||||
// and then it will be Put into the cache with an out-of-date value computed before the `Remove(key)`. To prevent
|
||||
// this we need the `Remove(key)` to also lock on the key, just like `Get*(key, ...)` does when computing it.
|
||||
defer mutexMap.Lock(key)()
|
||||
|
||||
err := conn.Delete(key)
|
||||
if err != nil {
|
||||
log.Error("unexpected error deleting key %s from cache: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
96
modules/cache/cache_test.go
vendored
96
modules/cache/cache_test.go
vendored
|
|
@ -5,21 +5,44 @@ package cache
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"forgejo.org/modules/setting"
|
||||
"forgejo.org/modules/test"
|
||||
|
||||
"code.forgejo.org/go-chi/cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func createTestCache() {
|
||||
conn, _ = newCache(setting.Cache{
|
||||
Adapter: "memory",
|
||||
TTL: time.Minute,
|
||||
})
|
||||
setting.CacheService.TTL = 24 * time.Hour
|
||||
func createTestCache(t *testing.T) {
|
||||
var err error
|
||||
var testCache cache.Cache
|
||||
|
||||
testRedisHost := os.Getenv("TEST_REDIS_SERVER")
|
||||
if testRedisHost == "" {
|
||||
testCache, err = newCache(setting.Cache{
|
||||
Adapter: "memory",
|
||||
TTL: time.Minute,
|
||||
})
|
||||
} else {
|
||||
testCache, err = newCache(setting.Cache{
|
||||
Adapter: "redis",
|
||||
Conn: fmt.Sprintf("redis://%s", testRedisHost),
|
||||
TTL: time.Minute,
|
||||
})
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, testCache)
|
||||
|
||||
t.Cleanup(test.MockVariableValue(&conn, testCache))
|
||||
t.Cleanup(test.MockVariableValue(&setting.CacheService.TTL, 24*time.Hour))
|
||||
}
|
||||
|
||||
func TestNewContext(t *testing.T) {
|
||||
|
|
@ -36,13 +59,13 @@ func TestNewContext(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetCache(t *testing.T) {
|
||||
createTestCache()
|
||||
createTestCache(t)
|
||||
|
||||
assert.NotNil(t, GetCache())
|
||||
}
|
||||
|
||||
func TestGetString(t *testing.T) {
|
||||
createTestCache()
|
||||
createTestCache(t)
|
||||
|
||||
data, err := GetString("key", func() (string, error) {
|
||||
return "", errors.New("some error")
|
||||
|
|
@ -78,7 +101,7 @@ func TestGetString(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetInt(t *testing.T) {
|
||||
createTestCache()
|
||||
createTestCache(t)
|
||||
|
||||
data, err := GetInt("key", func() (int, error) {
|
||||
return 0, errors.New("some error")
|
||||
|
|
@ -114,7 +137,7 @@ func TestGetInt(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetInt64(t *testing.T) {
|
||||
createTestCache()
|
||||
createTestCache(t)
|
||||
|
||||
data, err := GetInt64("key", func() (int64, error) {
|
||||
return 0, errors.New("some error")
|
||||
|
|
@ -148,3 +171,56 @@ func TestGetInt64(t *testing.T) {
|
|||
assert.EqualValues(t, 100, data)
|
||||
Remove("key")
|
||||
}
|
||||
|
||||
func TestCacheConcurrencySafety(t *testing.T) {
|
||||
createTestCache(t)
|
||||
|
||||
testRedisHost := os.Getenv("TEST_REDIS_SERVER")
|
||||
if testRedisHost == "" {
|
||||
t.Skip("only relevant for a remote redis host")
|
||||
}
|
||||
|
||||
numTests := 20
|
||||
numIncrements := 1000
|
||||
for testCount := range numTests {
|
||||
t.Run(fmt.Sprintf("attempt:%d", testCount), func(t *testing.T) {
|
||||
var counter atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
var firstError atomic.Value
|
||||
|
||||
getFunc := func() (int64, error) {
|
||||
lastValue := counter.Load()
|
||||
time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond)
|
||||
return lastValue, nil
|
||||
}
|
||||
|
||||
for range numIncrements {
|
||||
wg.Go(func() {
|
||||
counterValue := counter.Add(1)
|
||||
Remove(t.Name())
|
||||
cachedValue, err := GetInt64(t.Name(), getFunc)
|
||||
if err != nil {
|
||||
firstError.CompareAndSwap(nil, fmt.Sprintf("cache error: %v", err))
|
||||
} else if cachedValue < counterValue {
|
||||
firstError.CompareAndSwap(nil, fmt.Sprintf("incremented to value %d but retrieved value %d from cache", counterValue, cachedValue))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
require.EqualValues(t, numIncrements, counter.Load())
|
||||
if err := firstError.Load(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Without invalidating the cache, check what was last stored in it.
|
||||
value, err := GetInt64(t.Name(), func() (int64, error) {
|
||||
t.Fatal("getFunc should not be invoked")
|
||||
return 0, errors.New("getFunc should not be invoked")
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.EqualValues(t, numIncrements, value)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
60
modules/cache/mutex_map.go
vendored
Normal file
60
modules/cache/mutex_map.go
vendored
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
// Copyright 2025 The Forgejo Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MutexMap is basically a map[string]sync.Mutex which allows you to have one mutex per string key being locked. Unlike
// a map[string]sync.Mutex, this map will automatically remove the Mutexes from itself when they are not being waited
// for, preventing resource waste. It does this by keeping a reference count of the current Lock calls for the given
// key. The zero value of MutexMap is ready to use.
type MutexMap struct {
	mu       sync.Mutex // mutex to be held when accessing mutexMap
	mutexMap map[string]*refcountMutex
}

// refcountMutex is a mutex together with the number of goroutines currently holding or waiting for it.
type refcountMutex struct {
	refCount int // access to refCount is protected by the MutexMap's mu

	sync.Mutex
}

// Lock locks the given key, and returns a function that must be invoked to unlock the key.
//
// The returned unlock function panics when invoked more than once: a double unlock would corrupt the reference count
// and could release a mutex held by another goroutine.
func (m *MutexMap) Lock(key string) func() {
	m.mu.Lock()
	if m.mutexMap == nil {
		// Lazy initialization keeps the zero value of MutexMap usable.
		m.mutexMap = make(map[string]*refcountMutex)
	}
	mutex, ok := m.mutexMap[key]
	if !ok {
		mutex = &refcountMutex{}
		m.mutexMap[key] = mutex
	}
	// Register ourselves as a holder/waiter before releasing mu, so concurrent unlocks cannot delete the map entry
	// while we are blocked in mutex.Lock() below.
	mutex.refCount++
	m.mu.Unlock()

	mutex.Lock()

	unlockPending := true

	return func() {
		// The double-unlock check and the refcount bookkeeping are both done under m.mu so that even two *racing*
		// invocations of this closure are detected deterministically -- with a plain boolean check both could slip
		// past, double-decrement refCount, and release a mutex owned by another goroutine.
		m.mu.Lock()
		if !unlockPending {
			// Unlocking twice would cause incorrect reference counts and might release another goroutine's mutex --
			// detect it and panic so that this programming error can be found closest to the source.
			m.mu.Unlock()
			panic("MutexMap unlock invoked twice")
		}
		unlockPending = false
		mutex.refCount--
		if mutex.refCount == 0 {
			// Nobody is holding or waiting for this key's mutex anymore; drop the entry to avoid unbounded growth.
			delete(m.mutexMap, key)
		}
		m.mu.Unlock()

		mutex.Unlock()
	}
}
|
||||
131
modules/cache/mutex_map_test.go
vendored
Normal file
131
modules/cache/mutex_map_test.go
vendored
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
// Copyright 2025 The Forgejo Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMutexMap_BasicLockUnlock(t *testing.T) {
|
||||
mm := &MutexMap{}
|
||||
|
||||
unlock := mm.Lock("test-key")
|
||||
unlock()
|
||||
|
||||
// Should be able to lock again
|
||||
unlock2 := mm.Lock("test-key")
|
||||
unlock2()
|
||||
}
|
||||
|
||||
func TestMutexMap_ConcurrentSameKey(t *testing.T) {
|
||||
mm := &MutexMap{}
|
||||
var anotherLockActive atomic.Bool
|
||||
var firstError atomic.Value
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for range 10 {
|
||||
wg.Go(func() {
|
||||
unlock := mm.Lock("shared-key")
|
||||
defer unlock()
|
||||
|
||||
// should *not* find that another goroutine has put `true` into here.
|
||||
swapped := anotherLockActive.CompareAndSwap(false, true)
|
||||
if !swapped {
|
||||
firstError.CompareAndSwap(nil, "anotherLockActive was true!")
|
||||
}
|
||||
time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond) // jitter the goroutines to ensure no serial execution
|
||||
anotherLockActive.Store(false)
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if err := firstError.Load(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMutexMap_DifferentKeys(t *testing.T) {
|
||||
mm := &MutexMap{}
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
// If these somehow refered to the same underlying `sync.Mutex`, because `sync.Mutex` is not re-entrant this would
|
||||
// never complete.
|
||||
unlock1 := mm.Lock("test-key-1")
|
||||
unlock2 := mm.Lock("test-key-2")
|
||||
unlock3 := mm.Lock("test-key-3")
|
||||
unlock1()
|
||||
unlock2()
|
||||
unlock3()
|
||||
done <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(1 * time.Second): // early timeout so that we don't wait for t.Deadline()
|
||||
t.Fatal("test incomplete after timeout, indicating a locking bug")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMutexMap_SimpleCleanup(t *testing.T) {
|
||||
mm := &MutexMap{}
|
||||
unlock1 := mm.Lock("test-key-1")
|
||||
|
||||
mm.mu.Lock()
|
||||
assert.Len(t, mm.mutexMap, 1)
|
||||
mm.mu.Unlock()
|
||||
|
||||
unlock1()
|
||||
|
||||
mm.mu.Lock()
|
||||
assert.Empty(t, mm.mutexMap)
|
||||
mm.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestMutexMap_ConcurrentCleanup(t *testing.T) {
|
||||
mm := &MutexMap{}
|
||||
var foundRefGreaterThanOne atomic.Bool
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for range 10 {
|
||||
wg.Go(func() {
|
||||
unlock := mm.Lock("shared-key")
|
||||
defer unlock()
|
||||
|
||||
time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond) // jitter the goroutines to ensure no serial execution
|
||||
|
||||
mm.mu.Lock()
|
||||
rcMutex := mm.mutexMap["shared-key"]
|
||||
if rcMutex.refCount > 1 {
|
||||
foundRefGreaterThanOne.Store(true)
|
||||
}
|
||||
mm.mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
assert.True(t, foundRefGreaterThanOne.Load(), "expected to find a refCount > 1")
|
||||
|
||||
mm.mu.Lock()
|
||||
assert.Empty(t, mm.mutexMap)
|
||||
mm.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestMutexMap_UnlockTwice(t *testing.T) {
|
||||
mm := &MutexMap{}
|
||||
assert.Panics(t, func() {
|
||||
unlock := mm.Lock("test")
|
||||
unlock()
|
||||
unlock()
|
||||
})
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue