diff --git a/modules/cache/cache_test.go b/modules/cache/cache_test.go index 61944d7c22..c9225c965c 100644 --- a/modules/cache/cache_test.go +++ b/modules/cache/cache_test.go @@ -5,13 +5,18 @@ package cache import ( "errors" + "fmt" + "math/rand" + "os" + "sync" + "sync/atomic" "testing" "time" - "code.forgejo.org/go-chi/cache" "forgejo.org/modules/setting" "forgejo.org/modules/test" + "code.forgejo.org/go-chi/cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -166,3 +171,56 @@ func TestGetInt64(t *testing.T) { assert.EqualValues(t, 100, data) Remove("key") } + +func TestCacheConcurrencySafety(t *testing.T) { + createTestCache(t) + + testRedisHost := os.Getenv("TEST_REDIS_SERVER") + if testRedisHost == "" { + t.Skip("only relevant for a remote redis host") + } + + numTests := 20 + numIncrements := 1000 + for testCount := range numTests { + t.Run(fmt.Sprintf("attempt:%d", testCount), func(t *testing.T) { + var counter atomic.Int64 + var wg sync.WaitGroup + var firstError atomic.Value + + getFunc := func() (int64, error) { + lastValue := counter.Load() + time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond) + return lastValue, nil + } + + for range numIncrements { + wg.Go(func() { + counterValue := counter.Add(1) + Remove(t.Name()) + cachedValue, err := GetInt64(t.Name(), getFunc) + if err != nil { + firstError.CompareAndSwap(nil, fmt.Sprintf("cache error: %v", err)) + } else if cachedValue < counterValue { + firstError.CompareAndSwap(nil, fmt.Sprintf("incremented to value %d but retrieved value %d from cache", counterValue, cachedValue)) + } + }) + } + + wg.Wait() + require.EqualValues(t, numIncrements, counter.Load()) + if err := firstError.Load(); err != nil { + t.Fatal(err) + } + + // Without invalidating the cache, check what was last stored in it. + value, err := GetInt64(t.Name(), func() (int64, error) { + t.Fatal("getFunc should not be invoked") + return 0, errors.New("getFunc should not be invoked") + }) + + require.NoError(t, err) + assert.EqualValues(t, numIncrements, value) + }) + } +}