diff --git a/models/admin/task.go b/models/admin/task.go index b4e1ac0134..5ba9e503f7 100644 --- a/models/admin/task.go +++ b/models/admin/task.go @@ -5,15 +5,15 @@ package admin import ( "context" + "encoding/base64" "fmt" "forgejo.org/models/db" repo_model "forgejo.org/models/repo" user_model "forgejo.org/models/user" "forgejo.org/modules/json" + "forgejo.org/modules/keying" "forgejo.org/modules/migration" - "forgejo.org/modules/secret" - "forgejo.org/modules/setting" "forgejo.org/modules/structs" "forgejo.org/modules/timeutil" "forgejo.org/modules/util" @@ -120,21 +120,47 @@ func (task *Task) MigrateConfig() (*migration.MigrateOptions, error) { return nil, err } + key := keying.DeriveKey(keying.ContextMigrateTask) + // decrypt credentials if opts.CloneAddrEncrypted != "" { - if opts.CloneAddr, err = secret.DecryptSecret(setting.SecretKey, opts.CloneAddrEncrypted); err != nil { + encryptedCloneAddr, err := base64.RawStdEncoding.DecodeString(opts.CloneAddrEncrypted) + if err != nil { return nil, err } + + cloneAddr, err := key.Decrypt(encryptedCloneAddr, keying.ColumnAndJSONSelectorAndID("payload_content", "clone_addr_encrypted", task.ID)) + if err != nil { + return nil, err + } + + opts.CloneAddr = string(cloneAddr) } if opts.AuthPasswordEncrypted != "" { - if opts.AuthPassword, err = secret.DecryptSecret(setting.SecretKey, opts.AuthPasswordEncrypted); err != nil { + encryptedAuthPassword, err := base64.RawStdEncoding.DecodeString(opts.AuthPasswordEncrypted) + if err != nil { return nil, err } + + authPassword, err := key.Decrypt(encryptedAuthPassword, keying.ColumnAndJSONSelectorAndID("payload_content", "auth_password_encrypted", task.ID)) + if err != nil { + return nil, err + } + + opts.AuthPassword = string(authPassword) } if opts.AuthTokenEncrypted != "" { - if opts.AuthToken, err = secret.DecryptSecret(setting.SecretKey, opts.AuthTokenEncrypted); err != nil { + encryptedAuthToken, err := base64.RawStdEncoding.DecodeString(opts.AuthTokenEncrypted) + if err != nil { return nil, err } + + authToken, err := key.Decrypt(encryptedAuthToken, keying.ColumnAndJSONSelectorAndID("payload_content", "auth_token_encrypted", task.ID)) + if err != nil { + return nil, err + } + + opts.AuthToken = string(authToken) } return &opts, nil diff --git a/models/forgejo_migrations/v14a_migrate_task_secrets.go b/models/forgejo_migrations/v14a_migrate_task_secrets.go new file mode 100644 index 0000000000..98015d53b9 --- /dev/null +++ b/models/forgejo_migrations/v14a_migrate_task_secrets.go @@ -0,0 +1,114 @@ +// Copyright 2025 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later + +package forgejo_migrations + +import ( + "context" + "encoding/base64" + "fmt" + + admin_model "forgejo.org/models/admin" + "forgejo.org/models/db" + "forgejo.org/modules/json" + "forgejo.org/modules/keying" + "forgejo.org/modules/log" + "forgejo.org/modules/migration" + "forgejo.org/modules/secret" + "forgejo.org/modules/setting" + "forgejo.org/modules/structs" + + "xorm.io/builder" + "xorm.io/xorm" +) + +func init() { + registerMigration(&Migration{ + Description: "migrate columns of `task` table to store keying material", + Upgrade: migrateTaskSecrets, + }) +} + +func migrateTaskSecrets(x *xorm.Engine) error { + return db.WithTx(db.DefaultContext, func(ctx context.Context) error { + sess := db.GetEngine(ctx) + + key := keying.DeriveKey(keying.ContextMigrateTask) + + oldEncryptionKey := setting.SecretKey + messages := make([]string, 0, 100) + ids := make([]int64, 0, 100) + + err := db.Iterate(ctx, builder.Eq{"type": structs.TaskTypeMigrateRepo}, func(ctx context.Context, bean *admin_model.Task) error { + var opts migration.MigrateOptions + err := json.Unmarshal([]byte(bean.PayloadContent), &opts) + if err != nil { + messages = append(messages, fmt.Sprintf("task.id=%d, task.doer_id=%d, task.repo_id=%d, task.owner_id=%d: json.Unmarshal(): %v", bean.ID, bean.DoerID, bean.RepoID, bean.OwnerID, err)) + ids = append(ids, bean.ID) + return nil + } + + decryptionError := false + if opts.CloneAddrEncrypted != "" { + if opts.CloneAddr, err = secret.DecryptSecret(oldEncryptionKey, opts.CloneAddrEncrypted); err != nil { + messages = append(messages, fmt.Sprintf("task.id=%d, task.doer_id=%d, task.repo_id=%d, task.owner_id=%d: secret.DecryptSecret(CloneAddrEncrypted): %v", bean.ID, bean.DoerID, bean.RepoID, bean.OwnerID, err)) + ids = append(ids, bean.ID) + decryptionError = true + } + } + + if opts.AuthPasswordEncrypted != "" { + if opts.AuthPassword, err = secret.DecryptSecret(oldEncryptionKey, opts.AuthPasswordEncrypted); err != nil { + messages = append(messages, fmt.Sprintf("task.id=%d, task.doer_id=%d, task.repo_id=%d, task.owner_id=%d: secret.DecryptSecret(AuthPasswordEncrypted): %v", bean.ID, bean.DoerID, bean.RepoID, bean.OwnerID, err)) + ids = append(ids, bean.ID) + decryptionError = true + } + } + + if opts.AuthTokenEncrypted != "" { + if opts.AuthToken, err = secret.DecryptSecret(oldEncryptionKey, opts.AuthTokenEncrypted); err != nil { + messages = append(messages, fmt.Sprintf("task.id=%d, task.doer_id=%d, task.repo_id=%d, task.owner_id=%d: secret.DecryptSecret(AuthTokenEncrypted): %v", bean.ID, bean.DoerID, bean.RepoID, bean.OwnerID, err)) + ids = append(ids, bean.ID) + decryptionError = true + } + } + + // Don't migrate a task that has a decryption error. + if decryptionError { + return nil + } + + if opts.CloneAddrEncrypted != "" { + opts.CloneAddrEncrypted = base64.RawStdEncoding.EncodeToString(key.Encrypt([]byte(opts.CloneAddr), keying.ColumnAndJSONSelectorAndID("payload_content", "clone_addr_encrypted", bean.ID))) + } + + if opts.AuthPasswordEncrypted != "" { + opts.AuthPasswordEncrypted = base64.RawStdEncoding.EncodeToString(key.Encrypt([]byte(opts.AuthPassword), keying.ColumnAndJSONSelectorAndID("payload_content", "auth_password_encrypted", bean.ID))) + } + + if opts.AuthTokenEncrypted != "" { + opts.AuthTokenEncrypted = base64.RawStdEncoding.EncodeToString(key.Encrypt([]byte(opts.AuthToken), keying.ColumnAndJSONSelectorAndID("payload_content", "auth_token_encrypted", bean.ID))) + } + + bs, err := json.Marshal(&opts) + if err != nil { + return err + } + bean.PayloadContent = string(bs) + + return bean.UpdateCols(ctx, "payload_content") + }) + + if err == nil { + if len(ids) > 0 { + log.Error("v14a_migrate_task_secrets: The following tasks were found to be corrupted and removed from the database.") + for _, message := range messages { + log.Error("v14a_migrate_task_secrets: %s", message) + } + + _, err = sess.In("id", ids).NoAutoCondition().NoAutoTime().Delete(&admin_model.Task{}) + } + } + return err + }) +} diff --git a/models/forgejo_migrations/v14a_migrate_task_secrets_test.go b/models/forgejo_migrations/v14a_migrate_task_secrets_test.go new file mode 100644 index 0000000000..58bc08502a --- /dev/null +++ b/models/forgejo_migrations/v14a_migrate_task_secrets_test.go @@ -0,0 +1,78 @@ +// Copyright 2025 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later + +package forgejo_migrations + +import ( + "encoding/base64" + "testing" + + migration_tests "forgejo.org/models/gitea_migrations/test" + "forgejo.org/modules/json" + "forgejo.org/modules/keying" + "forgejo.org/modules/migration" + "forgejo.org/modules/structs" + "forgejo.org/modules/timeutil" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_MigrateTaskSecretsToKeying(t *testing.T) { + type Task struct { + ID int64 + DoerID int64 `xorm:"index"` + OwnerID int64 `xorm:"index"` + RepoID int64 `xorm:"index"` + Type structs.TaskType + Status structs.TaskStatus `xorm:"index"` + StartTime timeutil.TimeStamp + EndTime timeutil.TimeStamp + PayloadContent string `xorm:"TEXT"` + Message string `xorm:"TEXT"` + Created timeutil.TimeStamp `xorm:"created"` + } + + // Prepare and load the testing database + x, deferable := migration_tests.PrepareTestEnv(t, 0, new(Task)) + defer deferable() + if x == nil || t.Failed() { + return + } + + cnt, err := x.Table("task").Count() + require.NoError(t, err) + assert.EqualValues(t, 3, cnt) + + require.NoError(t, migrateTaskSecrets(x)) + + cnt, err = x.Table("task").Count() + require.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var task Task + _, err = x.Table("task").ID(1).Get(&task) + require.NoError(t, err) + + var opts migration.MigrateOptions + require.NoError(t, json.Unmarshal([]byte(task.PayloadContent), &opts)) + key := keying.DeriveKey(keying.ContextMigrateTask) + + encryptedCloneAddr, err := base64.RawStdEncoding.DecodeString(opts.CloneAddrEncrypted) + require.NoError(t, err) + cloneAddr, err := key.Decrypt(encryptedCloneAddr, keying.ColumnAndJSONSelectorAndID("payload_content", "clone_addr_encrypted", task.ID)) + require.NoError(t, err) + assert.Equal(t, "https://admin:password@example.com", string(cloneAddr)) + + encryptedAuthPassword, err := base64.RawStdEncoding.DecodeString(opts.AuthPasswordEncrypted) + require.NoError(t, err) + authPassword, err := key.Decrypt(encryptedAuthPassword, keying.ColumnAndJSONSelectorAndID("payload_content", "auth_password_encrypted", task.ID)) + require.NoError(t, err) + assert.Equal(t, "password", string(authPassword)) + + encryptedAuthToken, err := base64.RawStdEncoding.DecodeString(opts.AuthTokenEncrypted) + require.NoError(t, err) + authToken, err := key.Decrypt(encryptedAuthToken, keying.ColumnAndJSONSelectorAndID("payload_content", "auth_token_encrypted", task.ID)) + require.NoError(t, err) + assert.Equal(t, "token", string(authToken)) +} diff --git a/models/gitea_migrations/fixtures/Test_MigrateTaskSecretsToKeying/task.yml b/models/gitea_migrations/fixtures/Test_MigrateTaskSecretsToKeying/task.yml new file mode 100644 index 0000000000..1953daa11c --- /dev/null +++ b/models/gitea_migrations/fixtures/Test_MigrateTaskSecretsToKeying/task.yml @@ -0,0 +1,38 @@ +- + id: 1 + doer_id: 5 + owner_id: 5 + repo_id: 5 + type: 0 + status: 4 + start_time: 1761951636 + end_time: 1761951636 + payload_content: '{"auth_token_encrypted": "2142ac70cf41885b4a3a74f2d36a64662bdbc70f70c7f5b2", "clone_addr_encrypted": "79b9eb793d5af95af61d483566474454b880c4aa80bf3028f561ca227fccfc518b18fb9823c2fa79fa9cf0efae1eb13080e0c51f26c40622ee9d649bff0ef64b", "auth_password_encrypted": "986717ee9de0b9b1fda8afe5f64d53245d8ec4131221085f59ac7e13"}' + message: 'working' + created: 176195163 + +- + id: 2 + doer_id: 5 + owner_id: 5 + repo_id: 5 + type: 0 + status: 4 + start_time: 1761951636 + end_time: 1761951636 + payload_content: '{"auth_token_encrypted": "badbad", "clone_addr_encrypted": "badbad", "auth_password_encrypted": "badbad"}' + message: 'working' + created: 176195163 + +- + id: 3 + doer_id: 5 + owner_id: 5 + repo_id: 5 + type: 0 + status: 4 + start_time: 1761951636 + end_time: 1761951636 + payload_content: '{ badjson' + message: 'working' + created: 176195163 diff --git a/modules/keying/keying.go b/modules/keying/keying.go index f39e16aeed..ff73d1995a 100644 --- a/modules/keying/keying.go +++ b/modules/keying/keying.go @@ -60,6 +60,8 @@ var ( ContextTOTP Context = "totp" // Used for the `secret` table. ContextActionSecret Context = "action_secret" + // Used for the `task` table where type == TaskTypeMigrateRepo. + ContextMigrateTask Context = "migrate_repo_task" ) // Derive *the* key for a given context, this is a deterministic function. @@ -131,3 +133,16 @@ func (k *Key) Decrypt(ciphertext, additionalData []byte) ([]byte, error) { func ColumnAndID(column string, id int64) []byte { return binary.BigEndian.AppendUint64(append([]byte(column), ':'), uint64(id)) } + +// ColumnAndJSONSelectorAndID generates a context that can be used as additional context +// for encrypting and decrypting data. It requires the column name, JSON +// selector and the row ID (this requires to be known beforehand). Be careful +// when using this, as the table name isn't part of this context. This means +// it's not bound to a particular table. The table should be part of the context +// that the key was derived for, in which case it binds through that. Use this +// over `ColumnAndID` if you're encrypting data that's stored inside JSON. +// jsonSelector must be a unambigous selector to the JSON field that stores the +// encrypted data. +func ColumnAndJSONSelectorAndID(column, jsonSelector string, id int64) []byte { + return binary.BigEndian.AppendUint64(append(append([]byte(column), ':'), append([]byte(jsonSelector), ':')...), uint64(id)) +} diff --git a/modules/keying/keying_test.go b/modules/keying/keying_test.go index f73440b357..2c05abb186 100644 --- a/modules/keying/keying_test.go +++ b/modules/keying/keying_test.go @@ -109,3 +109,23 @@ func TestKeyingColumnAndID(t *testing.T) { assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, keying.ColumnAndID("table2", 1)) assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, keying.ColumnAndID("table2", math.MaxInt64)) } + +func TestColumnAndJSONSelectorAndID(t *testing.T) { + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, keying.ColumnAndJSONSelectorAndID("table", "field1", math.MinInt64)) + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, keying.ColumnAndJSONSelectorAndID("table", "field1", -1)) + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, keying.ColumnAndJSONSelectorAndID("table", "field1", 0)) + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, keying.ColumnAndJSONSelectorAndID("table", "field1", 1)) + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, keying.ColumnAndJSONSelectorAndID("table", "field1", math.MaxInt64)) + + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x3a, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, keying.ColumnAndJSONSelectorAndID("table", "field2", math.MinInt64)) + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x3a, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, keying.ColumnAndJSONSelectorAndID("table", "field2", -1)) + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x3a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, keying.ColumnAndJSONSelectorAndID("table", "field2", 0)) + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x3a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, keying.ColumnAndJSONSelectorAndID("table", "field2", 1)) + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x3a, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, keying.ColumnAndJSONSelectorAndID("table", "field2", math.MaxInt64)) + + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, keying.ColumnAndJSONSelectorAndID("table2", "field1", math.MinInt64)) + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, keying.ColumnAndJSONSelectorAndID("table2", "field1", -1)) + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, keying.ColumnAndJSONSelectorAndID("table2", "field1", 0)) + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, keying.ColumnAndJSONSelectorAndID("table2", "field1", 1)) + assert.Equal(t, []byte{0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x3a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x3a, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, keying.ColumnAndJSONSelectorAndID("table2", "field1", math.MaxInt64)) +} diff --git a/services/task/main_test.go b/services/task/main_test.go new file mode 100644 index 0000000000..2f95439450 --- /dev/null +++ b/services/task/main_test.go @@ -0,0 +1,14 @@ +// Copyright 2025 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later + +package task + +import ( + "testing" + + "forgejo.org/models/unittest" +) + +func TestMain(m *testing.M) { + unittest.MainTest(m) +} diff --git a/services/task/task.go b/services/task/task.go index f030bdb38c..f011feaf8b 100644 --- a/services/task/task.go +++ b/services/task/task.go @@ -5,6 +5,7 @@ package task import ( "context" + "encoding/base64" "errors" "fmt" @@ -14,10 +15,10 @@ import ( user_model "forgejo.org/models/user" "forgejo.org/modules/graceful" "forgejo.org/modules/json" + "forgejo.org/modules/keying" "forgejo.org/modules/log" base "forgejo.org/modules/migration" "forgejo.org/modules/queue" - "forgejo.org/modules/secret" "forgejo.org/modules/setting" "forgejo.org/modules/structs" "forgejo.org/modules/timeutil" @@ -70,36 +71,38 @@ func MigrateRepository(ctx context.Context, doer, u *user_model.User, opts base. // CreateMigrateTask creates a migrate task func CreateMigrateTask(ctx context.Context, doer, u *user_model.User, opts base.MigrateOptions) (*admin_model.Task, error) { // encrypt credentials for persistence - var err error - opts.CloneAddrEncrypted, err = secret.EncryptSecret(setting.SecretKey, opts.CloneAddr) - if err != nil { - return nil, err - } - opts.CloneAddr = util.SanitizeCredentialURLs(opts.CloneAddr) - opts.AuthPasswordEncrypted, err = secret.EncryptSecret(setting.SecretKey, opts.AuthPassword) - if err != nil { - return nil, err - } - opts.AuthPassword = "" - opts.AuthTokenEncrypted, err = secret.EncryptSecret(setting.SecretKey, opts.AuthToken) - if err != nil { - return nil, err - } - opts.AuthToken = "" - bs, err := json.Marshal(&opts) - if err != nil { - return nil, err - } task := &admin_model.Task{ - DoerID: doer.ID, - OwnerID: u.ID, - Type: structs.TaskTypeMigrateRepo, - Status: structs.TaskStatusQueued, - PayloadContent: string(bs), + DoerID: doer.ID, + OwnerID: u.ID, + Type: structs.TaskTypeMigrateRepo, + Status: structs.TaskStatusQueued, } - if err := admin_model.CreateTask(ctx, task); err != nil { + if err := db.WithTx(ctx, func(ctx context.Context) error { + if err := admin_model.CreateTask(ctx, task); err != nil { + return err + } + + key := keying.DeriveKey(keying.ContextMigrateTask) + + opts.CloneAddrEncrypted = base64.RawStdEncoding.EncodeToString(key.Encrypt([]byte(opts.CloneAddr), keying.ColumnAndJSONSelectorAndID("payload_content", "clone_addr_encrypted", task.ID))) + opts.CloneAddr = util.SanitizeCredentialURLs(opts.CloneAddr) + + opts.AuthPasswordEncrypted = base64.RawStdEncoding.EncodeToString(key.Encrypt([]byte(opts.AuthPassword), keying.ColumnAndJSONSelectorAndID("payload_content", "auth_password_encrypted", task.ID))) + opts.AuthPassword = "" + + opts.AuthTokenEncrypted = base64.RawStdEncoding.EncodeToString(key.Encrypt([]byte(opts.AuthToken), keying.ColumnAndJSONSelectorAndID("payload_content", "auth_token_encrypted", task.ID))) + opts.AuthToken = "" + + bs, err := json.Marshal(&opts) + if err != nil { + return err + } + task.PayloadContent = string(bs) + + return task.UpdateCols(ctx, "payload_content") + }); err != nil { return nil, err } diff --git a/services/task/task_test.go b/services/task/task_test.go new file mode 100644 index 0000000000..50ab1394a4 --- /dev/null +++ b/services/task/task_test.go @@ -0,0 +1,52 @@ +package task + +import ( + "testing" + + admin_model "forgejo.org/models/admin" + "forgejo.org/models/unittest" + user_model "forgejo.org/models/user" + "forgejo.org/modules/migration" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateMigrateTask(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) + + t.Run("Transaction failure", func(t *testing.T) { + defer unittest.SetFaultInjector(2)() + + task, err := CreateMigrateTask(t.Context(), user, user, migration.MigrateOptions{ + CloneAddr: "https://admin:password2@example.com", + AuthPassword: "password", + AuthToken: "token", + RepoName: "migrate-test-2", + }) + require.ErrorIs(t, err, unittest.ErrFaultInjected) + require.Nil(t, task) + + unittest.AssertExistsIf(t, false, &admin_model.Task{}) + }) + + t.Run("Normal", func(t *testing.T) { + task, err := CreateMigrateTask(t.Context(), user, user, migration.MigrateOptions{ + CloneAddr: "https://admin:password@example.com", + AuthPassword: "password", + AuthToken: "token", + RepoName: "migrate-test", + }) + require.NoError(t, err) + require.NotNil(t, task) + + config, err := task.MigrateConfig() + require.NoError(t, err) + require.NotNil(t, config) + + assert.Equal(t, "token", config.AuthToken) + assert.Equal(t, "password", config.AuthPassword) + assert.Equal(t, "https://admin:password@example.com", config.CloneAddr) + }) +}