diff --git a/routers/web/repo/pull.go b/routers/web/repo/pull.go index 83e3a3cef3..c4b8583e2d 100644 --- a/routers/web/repo/pull.go +++ b/routers/web/repo/pull.go @@ -7,6 +7,7 @@ package repo import ( + stdCtx "context" "errors" "fmt" "html" @@ -16,6 +17,7 @@ import ( "path" "strconv" "strings" + "time" "forgejo.org/models" actions_model "forgejo.org/models/actions" @@ -1420,7 +1422,15 @@ func MergePullRequest(ctx *context.Context) { } } - if err := pull_service.Merge(ctx, pr, ctx.Doer, ctx.Repo.GitRepo, repo_model.MergeStyle(form.Do), form.HeadCommitID, message, false); err != nil { + // If the HTTP request is cancelled by the user agent, don't stop work. We've started a merge and need to finish all + // the related work. All usage of `ctx` throughout the rest of this function should be only for error handling or UI + // interactions, and all effective work should use `workCtx` instead. + workCtx, cancelWorkCtx := stdCtx.WithTimeout( + stdCtx.WithoutCancel(ctx), + time.Duration(setting.Git.Timeout.Default)*time.Second) + defer cancelWorkCtx() + + if err := pull_service.Merge(workCtx, pr, ctx.Doer, ctx.Repo.GitRepo, repo_model.MergeStyle(form.Do), form.HeadCommitID, message, false); err != nil { if models.IsErrInvalidMergeStyle(err) { ctx.JSONError(ctx.Tr("repo.pulls.invalid_merge_option")) } else if models.IsErrMergeConflicts(err) { @@ -1491,7 +1501,7 @@ func MergePullRequest(ctx *context.Context) { } log.Trace("Pull request merged: %d", pr.ID) - if err := stopTimerIfAvailable(ctx, ctx.Doer, issue); err != nil { + if err := stopTimerIfAvailable(workCtx, ctx.Doer, issue); err != nil { ctx.ServerError("stopTimerIfAvailable", err) return } @@ -1504,7 +1514,7 @@ func MergePullRequest(ctx *context.Context) { headRepo = ctx.Repo.GitRepo } else { var err error - headRepo, err = gitrepo.OpenRepository(ctx, pr.HeadRepo) + headRepo, err = gitrepo.OpenRepository(workCtx, pr.HeadRepo) if err != nil { ctx.ServerError(fmt.Sprintf("OpenRepository[%s]", pr.HeadRepo.FullName()), err) return @@ -1512,7 +1522,7 @@ func MergePullRequest(ctx *context.Context) { defer headRepo.Close() } - if err := repo_service.DeleteBranchAfterMerge(ctx, ctx.Doer, pr, headRepo); err != nil { + if err := repo_service.DeleteBranchAfterMerge(workCtx, ctx.Doer, pr, headRepo); err != nil { switch { case errors.Is(err, repo_service.ErrBranchIsDefault): ctx.Flash.Error(ctx.Tr("repo.pulls.delete_after_merge.head_branch.is_default")) @@ -1557,7 +1567,7 @@ func CancelAutoMergePullRequest(ctx *context.Context) { ctx.Redirect(issue.HTMLURL()) } -func stopTimerIfAvailable(ctx *context.Context, user *user_model.User, issue *issues_model.Issue) error { +func stopTimerIfAvailable(ctx stdCtx.Context, user *user_model.User, issue *issues_model.Issue) error { if issues_model.StopwatchExists(ctx, user.ID, issue.ID) { if err := issues_model.CreateOrStopIssueStopwatch(ctx, user, issue); err != nil { return err diff --git a/tests/integration/pull_merge_test.go b/tests/integration/pull_merge_test.go index b12ced9073..a987603ce7 100644 --- a/tests/integration/pull_merge_test.go +++ b/tests/integration/pull_merge_test.go @@ -5,6 +5,7 @@ package integration import ( "bytes" + "context" "encoding/base64" "fmt" "math/rand/v2" @@ -1197,6 +1198,68 @@ func shuffleSlice(slice []int64) { }) } +func bulkCreatePRs(t *testing.T, prCount int, repo *repo_model.Repository, token string, labelIDs []int64, milestoneID int64) { + var createAllPRs sync.WaitGroup + var errorListMutex sync.Mutex + var errorList []any + for i := range prCount { + createAllPRs.Add(1) + go func(i int) { + defer createAllPRs.Done() + defer func() { + if r := recover(); r != nil { + errorListMutex.Lock() + defer errorListMutex.Unlock() + errorList = append(errorList, r) + } + }() + + // We're going to create two branches; a new target branch where the PR will merge *into*, and a new + // head branch where the PR will merge *from*. This test is about finding internal concurrency + // conflicts within Forgejo that prevent merges, and, merging simultaneously into the *same branch* + // would have natural conflicts that aren't what we're attempting to test. + targetBranchName := fmt.Sprintf("target-branch-%d", i) + req := NewRequestWithJSON(t, + "POST", + fmt.Sprintf("/api/v1/repos/%s/%s/branches", repo.OwnerName, repo.Name), + &api.CreateBranchRepoOption{ + OldRefName: "main", + BranchName: targetBranchName, + }).AddTokenAuth(token) + MakeRequest(t, req, http.StatusCreated) + + // Create the head branch that we'll be trying to merge from, with a file change: + headBranchName := fmt.Sprintf("update-%d", i) + req = NewRequestWithJSON(t, + "POST", + fmt.Sprintf("/api/v1/repos/%s/%s/contents/README-%d.md", repo.OwnerName, repo.Name, i), + &api.CreateFileOptions{ + FileOptions: api.FileOptions{ + NewBranchName: headBranchName, + }, + ContentBase64: base64.StdEncoding.EncodeToString(fmt.Appendf(nil, "Hello, world %d!\n", i)), + }).AddTokenAuth(token) + MakeRequest(t, req, http.StatusCreated) + + // Create a PR for the branch + myLabelIDs := slices.Clone(labelIDs) + shuffleSlice(myLabelIDs) // use a random ordering for labels as it may cause deadlocks when their count of assigned issues is updated + req = NewRequestWithJSON(t, http.MethodPost, + fmt.Sprintf("/api/v1/repos/%s/%s/pulls", repo.OwnerName, repo.Name), + &api.CreatePullRequestOption{ + Head: headBranchName, + Base: targetBranchName, + Title: fmt.Sprintf("create PR from branch %s", headBranchName), + Labels: myLabelIDs, + Milestone: milestoneID, + }).AddTokenAuth(token) + MakeRequest(t, req, http.StatusCreated) + }(i) + } + createAllPRs.Wait() + assert.Empty(t, errorList) +} + func TestMergeConcurrency(t *testing.T) { onApplicationRun(t, func(t *testing.T, giteaURL *url.URL) { user2 := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) @@ -1241,67 +1304,7 @@ func TestMergeConcurrency(t *testing.T) { var apiMilestone api.Milestone DecodeJSON(t, resp, &apiMilestone) - { - var createAllPRs sync.WaitGroup - var errorListMutex sync.Mutex - var errorList []any - for i := range concurrentCount { - createAllPRs.Add(1) - go func(i int) { - defer createAllPRs.Done() - defer func() { - if r := recover(); r != nil { - errorListMutex.Lock() - defer errorListMutex.Unlock() - errorList = append(errorList, r) - } - }() - - // We're going to create two branches; a new target branch where the PR will merge *into*, and a new - // head branch where the PR will merge *from*. This test is about finding internal concurrency - // conflicts within Forgejo that prevent merges, and, merging simultaneously into the *same branch* - // would have natural conflicts that aren't what we're attempting to test. - targetBranchName := fmt.Sprintf("target-branch-%d", i) - req := NewRequestWithJSON(t, - "POST", - fmt.Sprintf("/api/v1/repos/%s/%s/branches", repo.OwnerName, repo.Name), - &api.CreateBranchRepoOption{ - OldRefName: "main", - BranchName: targetBranchName, - }).AddTokenAuth(token) - MakeRequest(t, req, http.StatusCreated) - - // Create the head branch that we'll be trying to merge from, with a file change: - headBranchName := fmt.Sprintf("update-%d", i) - req = NewRequestWithJSON(t, - "POST", - fmt.Sprintf("/api/v1/repos/%s/%s/contents/README-%d.md", repo.OwnerName, repo.Name, i), - &api.CreateFileOptions{ - FileOptions: api.FileOptions{ - NewBranchName: headBranchName, - }, - ContentBase64: base64.StdEncoding.EncodeToString(fmt.Appendf(nil, "Hello, world %d!\n", i)), - }).AddTokenAuth(token) - MakeRequest(t, req, http.StatusCreated) - - // Create a PR for the branch - myLabelIDs := slices.Clone(labelIDs) - shuffleSlice(myLabelIDs) // use a random ordering for labels as it may cause deadlocks when their count of assigned issues is updated - req = NewRequestWithJSON(t, http.MethodPost, - fmt.Sprintf("/api/v1/repos/%s/%s/pulls", repo.OwnerName, repo.Name), - &api.CreatePullRequestOption{ - Head: headBranchName, - Base: targetBranchName, - Title: fmt.Sprintf("create PR from branch %s", headBranchName), - Labels: myLabelIDs, - Milestone: apiMilestone.ID, - }).AddTokenAuth(token) - MakeRequest(t, req, http.StatusCreated) - }(i) - } - createAllPRs.Wait() - assert.Empty(t, errorList) - } + bulkCreatePRs(t, concurrentCount, repo, token, labelIDs, apiMilestone.ID) // All our PRs are created; now let's try to merge them concurrently. @@ -1377,3 +1380,106 @@ func TestMergeConcurrency(t *testing.T) { } }) } + +func TestMergeHTTPRequestCancellation(t *testing.T) { + onApplicationRun(t, func(t *testing.T, giteaURL *url.URL) { + user2 := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) + user2Session := loginUser(t, "user2") + token := getUserToken(t, "user2", auth_model.AccessTokenScopeWriteRepository, auth_model.AccessTokenScopeWriteIssue) + + // The purpose of this test is to interrupt the HTTP request to "/%s/%s/pulls/%d/merge" by cancelling the + // context at various times during the request, and ensuring that we don't get into any states where the request + // has partially succeeded but then been cancelled -- for example, wrote the merge to the repo, but didn't + // update Forgejo's database. To do this we're going to create a bunch of PRs, merge them, and cancel request + // during merge -- evenly distributing the cancellation times like this: + cancellationChecks := 5 // number of pull requests to create and attempt to merge + measuredMergeTime := 283 * time.Millisecond // time measured on a test system for one POST /%s/%s/pulls/%d/merge + cancellationDuration := measuredMergeTime / time.Duration(cancellationChecks) // cancel after (i+1) * cancellationDuration for each PR + + repo, _, deferrer := tests.CreateDeclarativeRepo(t, user2, "concurrency-test", nil, nil, nil) + defer deferrer() + + bulkCreatePRs(t, cancellationChecks, repo, token, nil, 0) + + // All our PRs are created; now let's try to merge them concurrently. This technically doesn't have to be + // concurrent, but `TestMergeConcurrency` already had all this logic for this test to copy, and it reduces the + // test runtime: + { + var mergeAllPRs sync.WaitGroup + var errorListMutex sync.Mutex + var errorList []any + for i := range cancellationChecks { + mergeAllPRs.Add(1) + go func(i int) { + defer mergeAllPRs.Done() + defer func() { + if r := recover(); r != nil { + errorListMutex.Lock() + defer errorListMutex.Unlock() + errorList = append(errorList, r) + } + }() + + targetBranchName := fmt.Sprintf("target-branch-%d", i) + headBranchName := fmt.Sprintf("update-%d", i) + pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ + HeadRepoID: repo.ID, + BaseRepoID: repo.ID, + HeadBranch: headBranchName, + BaseBranch: targetBranchName, + }) + + // Here's the major subject of this test: every merge request is fired with a different context + // timeout, causing the HTTP request to be interrupted in different places throughout the request. + reqCtx, cancel := context.WithTimeout(t.Context(), time.Duration(i+1)*cancellationDuration) + defer cancel() + + req := NewRequestWithValues(t, "POST", + fmt.Sprintf("/%s/%s/pulls/%d/merge", repo.OwnerName, repo.Name, pr.Index), map[string]string{ + "do": "merge", + "delete_branch_after_merge": "on", + }) + req.Request = req.WithContext(reqCtx) + user2Session.MakeRequest(t, req, NoExpectedStatus) + }(i) + } + mergeAllPRs.Wait() + assert.Empty(t, errorList) + } + + // Verify that all PRs are in a consistent state of merged or not (not a corrupt state): + gitRepo, err := gitrepo.OpenRepository(t.Context(), repo) + require.NoError(t, err) + + for i := range cancellationChecks { + targetBranchName := fmt.Sprintf("target-branch-%d", i) + headBranchName := fmt.Sprintf("update-%d", i) + pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ + HeadRepoID: repo.ID, + BaseRepoID: repo.ID, + HeadBranch: headBranchName, + BaseBranch: targetBranchName, + }) + targetBranchInDB := unittest.AssertExistsAndLoadBean(t, &git_model.Branch{ + RepoID: repo.ID, + Name: targetBranchName, + }) + + targetBranchCommitIDInRepo, err := gitRepo.GetBranchCommitID(targetBranchName) + require.NoError(t, err) + assert.Equal(t, targetBranchCommitIDInRepo, targetBranchInDB.CommitID, "real commit ID match for %s", targetBranchName) + + targetBranchCommitInRepo, err := gitRepo.GetCommit(targetBranchCommitIDInRepo) + require.NoError(t, err) + assert.Equal(t, strings.TrimSpace(targetBranchCommitInRepo.CommitMessage), strings.TrimSpace(targetBranchInDB.CommitMessage)) + + if pr.HasMerged { + assert.Equal(t, + fmt.Sprintf("Merge pull request 'create PR from branch %[1]s' (#%[2]d) from %[1]s into %[3]s", headBranchName, pr.Index, targetBranchName), + targetBranchInDB.CommitMessage) + } else { + assert.Equal(t, "Initial commit", targetBranchInDB.CommitMessage) + } + } + }) +}