jojo/modules/keying/keying_test.go
oliverpool 67df538958 feat: cache derived keys for faster keying (#10114)
Currently `DeriveKey` is called every time a secret must be encoded/decoded. Since this function is deterministic, its result can be cached, giving a 250x speedup (the original took less than half a microsecond, so this is more of a micro-optimization...).

```
go test -bench=.
goos: linux
goarch: amd64
pkg: forgejo.org/modules/keying
cpu: Intel(R) Core(TM) Ultra 5 125H
BenchmarkExpandPRK-18            2071627               564.2 ns/op
BenchmarkExpandPRKOnce-18       541438192                2.206 ns/op
PASS
ok      forgejo.org/modules/keying      2.369s
```

## Other changes

- Since the keys can be constructed once, the call sites become slightly simpler (`keying.TOTP.Encrypt(...)` instead of `keying.DeriveKey(keying.ContextTOTP).Encrypt(...)`)
- All `Encrypt`/`Decrypt` calls will panic forever if called before `Init` has been called (currently they only panic as long as `Init` has not been called)
- Calling `Init` twice with different keys will trigger a panic (currently racy)
- Calling `Decrypt` with a short ciphertext no longer panics; it returns an error, as it already did for long-enough garbage

Reviewed-on: https://codeberg.org/forgejo/forgejo/pulls/10114
Reviewed-by: Gusted <gusted@noreply.codeberg.org>
Co-authored-by: oliverpool <git@olivier.pfad.fr>
Co-committed-by: oliverpool <git@olivier.pfad.fr>
2025-11-16 14:29:14 +01:00

161 lines
8 KiB
Go

// Copyright 2024 The Forgejo Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package keying
import (
"crypto/cipher"
"crypto/hkdf"
"encoding/base64"
"math"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/chacha20poly1305"
)
// TestKeying exercises initialization of the package key, context separation
// between derived keys, and the Encrypt/Decrypt round trip of a derived key.
func TestKeying(t *testing.T) {
	t.Run("Initialization", func(t *testing.T) {
		Init([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})
	})

	t.Run("Double initialization", func(t *testing.T) {
		t.Run("Same key allowed", func(t *testing.T) {
			Init([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})
		})
		t.Run("Different key panics", func(t *testing.T) {
			assert.Panics(t, func() {
				Init([]byte{0x02, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})
			})
		})
	})

	t.Run("Context separation", func(t *testing.T) {
		key1 := deriveKey("TESTING")
		key2 := deriveKey("TESTING2")
		ciphertext := key1.Encrypt([]byte("This is for context TESTING"), nil)

		// A key derived for a different context must not decrypt the data.
		plaintext, err := key2.Decrypt(ciphertext, nil)
		require.Error(t, err)
		assert.Empty(t, plaintext)

		plaintext, err = key1.Decrypt(ciphertext, nil)
		require.NoError(t, err)
		assert.EqualValues(t, "This is for context TESTING", plaintext)
	})

	key := deriveKey("TESTING PURPOSES")
	plainText := []byte("Forgejo is run by [Redacted]")
	// Encrypt once up front (rather than inside the "Encrypt" subtest) so the
	// Decrypt subtests do not depend on the "Encrypt" subtest having run first.
	cipherText := key.Encrypt(plainText, []byte{0x05, 0x06})

	t.Run("Encrypt", func(t *testing.T) {
		cipherText2 := key.Encrypt(plainText, []byte{0x05, 0x06})
		// Ensure ciphertexts don't have a deterministic output.
		assert.NotEqual(t, cipherText, cipherText2)
	})

	t.Run("Decrypt", func(t *testing.T) {
		t.Run("Successful", func(t *testing.T) {
			decrypted, err := key.Decrypt(cipherText, []byte{0x05, 0x06})
			require.NoError(t, err)
			assert.Equal(t, plainText, decrypted)
		})

		t.Run("Old secret", func(t *testing.T) {
			// ensure that new code can still decode old secrets
			known, err := base64.RawStdEncoding.DecodeString("LABcdFTke+FAESOAUkaQvdFO/tLFdugvXHqUYQaESy9eCedUsorjpe1N350NN+AU7gv6xyK3DHuugD+wjnVcNvt+9hA")
			require.NoError(t, err)
			decrypted, err := key.Decrypt(known, []byte{0x05, 0x06})
			require.NoError(t, err)
			assert.Equal(t, plainText, decrypted)
		})

		t.Run("Not enough additional data", func(t *testing.T) {
			decrypted, err := key.Decrypt(cipherText, []byte{0x05})
			require.Error(t, err)
			assert.Empty(t, decrypted)
		})

		t.Run("Too much additional data", func(t *testing.T) {
			decrypted, err := key.Decrypt(cipherText, []byte{0x05, 0x06, 0x07})
			require.Error(t, err)
			assert.Empty(t, decrypted)
		})

		t.Run("Incorrect nonce", func(t *testing.T) {
			// Tamper with a copy so the shared ciphertext stays valid for any
			// other subtest, regardless of execution order.
			tampered := append([]byte(nil), cipherText...)
			// Flip the first byte of the nonce.
			tampered[0] = ^tampered[0]
			decrypted, err := key.Decrypt(tampered, []byte{0x05, 0x06})
			require.Error(t, err)
			assert.Empty(t, decrypted)
		})

		t.Run("Incorrect ciphertext", func(t *testing.T) {
			_, err := key.Decrypt(nil, nil)
			require.Error(t, err)

			// Exactly one nonce worth of bytes and nothing else: too short
			// to hold any sealed data.
			nonceOnly := make([]byte, chacha20poly1305.NonceSizeX)
			_, err = key.Decrypt(nonceOnly, nil)
			require.Error(t, err)
		})
	})
}
// TestKeyingColumnAndID pins the exact bytes produced by ColumnAndID: the
// column name, a ':' separator, then the ID as eight big-endian bytes.
func TestKeyingColumnAndID(t *testing.T) {
	testCases := []struct {
		column   string
		id       int64
		expected []byte
	}{
		{"table", math.MinInt64, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
		{"table", -1, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
		{"table", 0, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
		{"table", 1, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}},
		{"table", math.MaxInt64, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
		{"table2", math.MinInt64, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
		{"table2", -1, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
		{"table2", 0, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
		{"table2", 1, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}},
		{"table2", math.MaxInt64, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
	}

	for _, tc := range testCases {
		assert.Equal(t, tc.expected, ColumnAndID(tc.column, tc.id), "ColumnAndID(%q, %d)", tc.column, tc.id)
	}
}
func TestColumnAndJSONSelectorAndID(t *testing.T) {
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, ColumnAndJSONSelectorAndID("table", "field1", math.MinInt64))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, ColumnAndJSONSelectorAndID("table", "field1", -1))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, ColumnAndJSONSelectorAndID("table", "field1", 0))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, ColumnAndJSONSelectorAndID("table", "field1", 1))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, ColumnAndJSONSelectorAndID("table", "field1", math.MaxInt64))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x3a, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, ColumnAndJSONSelectorAndID("table", "field2", math.MinInt64))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x3a, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, ColumnAndJSONSelectorAndID("table", "field2", -1))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x3a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, ColumnAndJSONSelectorAndID("table", "field2", 0))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x3a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, ColumnAndJSONSelectorAndID("table", "field2", 1))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x3a, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, ColumnAndJSONSelectorAndID("table", "field2", math.MaxInt64))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, ColumnAndJSONSelectorAndID("table2", "field1", math.MinInt64))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, ColumnAndJSONSelectorAndID("table2", "field1", -1))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, ColumnAndJSONSelectorAndID("table2", "field1", 0))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, ColumnAndJSONSelectorAndID("table2", "field1", 1))
assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, ColumnAndJSONSelectorAndID("table2", "field1", math.MaxInt64))
}
// BenchmarkExpandPRK measures deriving the "testing" AEAD from a pseudorandom
// key on every iteration, i.e. without any caching (measured at roughly
// 500 ns/op).
func BenchmarkExpandPRK(b *testing.B) {
	pseudoRandomKey, err := hkdf.Extract(hash, []byte("secret"), nil)
	require.NoError(b, err)

	for b.Loop() {
		expandPRK(pseudoRandomKey, "testing")
	}
}
// BenchmarkExpandPRKOnce measures fetching the "testing" AEAD when it is
// derived a single time and then cached by sync.OnceValue (measured at
// roughly 2 ns/op).
func BenchmarkExpandPRKOnce(b *testing.B) {
	pseudoRandomKey, err := hkdf.Extract(hash, []byte("secret"), nil)
	require.NoError(b, err)

	derive := func() cipher.AEAD {
		return expandPRK(pseudoRandomKey, "testing")
	}
	cachedAEAD := sync.OnceValue(derive)

	for b.Loop() {
		cachedAEAD()
	}
}