From 9abc1b0144feaf053176ff80bd94d694a3cfce13 Mon Sep 17 00:00:00 2001 From: Mathieu Fenniak Date: Sun, 5 Apr 2026 22:03:45 +0200 Subject: [PATCH] refactor: reduce code duplication when accessing `DefaultMaxInSize` (#11999) `DefaultMaxInSize` is an internal parameter for limiting the size of `field IN (...)` clauses in DB queries, which is a reasonable thing to do -- in addition to the errors noted when [originally introduced](https://github.com/go-gitea/gitea/pull/4594), there are technical limits that apply to each of PostgreSQL, MySQL, and SQLite which would prevent an unbounded size for a query like this. However: the size is incredibly small at 50, and, the implementation of `DefaultMaxInSize` is really wasteful with copy-and-paste coding. This PR: - introduces `GetByIDs` which fetches a `map[int64]*Model` from the database for an array of ID values, while respecting `IN` clause size limits - introduces `GetByFieldIn` which fetches a `map[int64][]*Model` from the database for an array of field values, while respecting `IN` clause size limits - uses `slices.Chunk` for other locations where queries are too complex for these implementations - bumps the `DefaultMaxInSize` parameter from 50 to 500, a conservative increase well under known limits, but 10x the current value: - PostgreSQL supports up to 1GB query text size with 65,535 parameters, but I've experienced performance degradation at high value counts - MySQL supports 64MB query text size without known limits of parameter count - SQLite supports 32,766 parameters in a query ## Checklist The [contributor guide](https://forgejo.org/docs/next/contributor/) contains information that will be helpful to first time contributors. All work and communication must conform to Forgejo's [AI Agreement](https://codeberg.org/forgejo/governance/src/branch/main/AIAgreement.md). There also are a few [conditions for merging Pull Requests in Forgejo repositories](https://codeberg.org/forgejo/governance/src/branch/main/PullRequestsAgreement.md). You are also welcome to join the [Forgejo development chatroom](https://matrix.to/#/#forgejo-development:matrix.org). ### Tests for Go changes - I added test coverage for Go changes... - [x] in their respective `*_test.go` for unit tests. - Refactored functions are assumed to be covered by existing tests to some extent; that assumption is probably wrong but the changes here are relatively easily reviewed for correctness as well. - [ ] in the `tests/integration` directory if it involves interactions with a live Forgejo server. - I ran... - [x] `make pr-go` before pushing ### Documentation - [ ] I created a pull request [to the documentation](https://codeberg.org/forgejo/docs) to explain to Forgejo users how to use this change. - [x] I did not document these changes and I do not expect someone else to do it. ### Release notes - [ ] This change will be noticed by a Forgejo user or admin (feature, bug fix, performance, etc.). I suggest to include a release note for this change. - [x] This change is not visible to a Forgejo user or admin (refactor, dependency upgrade, etc.). I think there is no need to add a release note for this change. Reviewed-on: https://codeberg.org/forgejo/forgejo/pulls/11999 Reviewed-by: Andreas Ahlenstorf Co-authored-by: Mathieu Fenniak Co-committed-by: Mathieu Fenniak --- models/activities/notification_list.go | 112 ++--------------- models/db/context.go | 86 +++++++++++++ models/db/list.go | 2 +- models/issues/comment_list.go | 168 ++++--------------------- models/issues/issue_list.go | 151 ++++------------------ models/issues/reaction.go | 10 +- tests/integration/db_query_test.go | 64 ++++++++++ 7 files changed, 215 insertions(+), 378 deletions(-) create mode 100644 tests/integration/db_query_test.go diff --git a/models/activities/notification_list.go b/models/activities/notification_list.go index bf6356021e..3f3a48eaa5 100644 --- a/models/activities/notification_list.go +++ b/models/activities/notification_list.go @@ -210,31 +210,9 @@ func (nl NotificationList) LoadRepos(ctx context.Context) (repo_model.Repository } repoIDs := nl.getPendingRepoIDs() - repos := make(map[int64]*repo_model.Repository, len(repoIDs)) - left := len(repoIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - rows, err := db.GetEngine(ctx). - In("id", repoIDs[:limit]). - Rows(new(repo_model.Repository)) - if err != nil { - return nil, nil, err - } - - for rows.Next() { - var repo repo_model.Repository - err = rows.Scan(&repo) - if err != nil { - rows.Close() - return nil, nil, err - } - - repos[repo.ID] = &repo - } - _ = rows.Close() - - left -= limit - repoIDs = repoIDs[limit:] + repos, err := db.GetByIDs(ctx, "id", repoIDs, &repo_model.Repository{}) + if err != nil { + return nil, nil, err } failed := []int{} @@ -281,31 +259,9 @@ func (nl NotificationList) LoadIssues(ctx context.Context) ([]int, error) { } issueIDs := nl.getPendingIssueIDs() - issues := make(map[int64]*issues_model.Issue, len(issueIDs)) - left := len(issueIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - rows, err := db.GetEngine(ctx). - In("id", issueIDs[:limit]). - Rows(new(issues_model.Issue)) - if err != nil { - return nil, err - } - - for rows.Next() { - var issue issues_model.Issue - err = rows.Scan(&issue) - if err != nil { - rows.Close() - return nil, err - } - - issues[issue.ID] = &issue - } - _ = rows.Close() - - left -= limit - issueIDs = issueIDs[limit:] + issues, err := db.GetByIDs(ctx, "id", issueIDs, &issues_model.Issue{}) + if err != nil { + return nil, err } failures := []int{} @@ -373,31 +329,9 @@ func (nl NotificationList) LoadUsers(ctx context.Context) ([]int, error) { } userIDs := nl.getUserIDs() - users := make(map[int64]*user_model.User, len(userIDs)) - left := len(userIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - rows, err := db.GetEngine(ctx). - In("id", userIDs[:limit]). - Rows(new(user_model.User)) - if err != nil { - return nil, err - } - - for rows.Next() { - var user user_model.User - err = rows.Scan(&user) - if err != nil { - rows.Close() - return nil, err - } - - users[user.ID] = &user - } - _ = rows.Close() - - left -= limit - userIDs = userIDs[limit:] + users, err := db.GetByIDs(ctx, "id", userIDs, &user_model.User{}) + if err != nil { + return nil, err } failures := []int{} @@ -421,31 +355,9 @@ func (nl NotificationList) LoadComments(ctx context.Context) ([]int, error) { } commentIDs := nl.getPendingCommentIDs() - comments := make(map[int64]*issues_model.Comment, len(commentIDs)) - left := len(commentIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - rows, err := db.GetEngine(ctx). - In("id", commentIDs[:limit]). - Rows(new(issues_model.Comment)) - if err != nil { - return nil, err - } - - for rows.Next() { - var comment issues_model.Comment - err = rows.Scan(&comment) - if err != nil { - rows.Close() - return nil, err - } - - comments[comment.ID] = &comment - } - _ = rows.Close() - - left -= limit - commentIDs = commentIDs[limit:] + comments, err := db.GetByIDs(ctx, "id", commentIDs, &issues_model.Comment{}) + if err != nil { + return nil, err } failures := []int{} diff --git a/models/db/context.go b/models/db/context.go index f098b40a32..18237bb2a2 100644 --- a/models/db/context.go +++ b/models/db/context.go @@ -8,6 +8,7 @@ import ( "database/sql" "errors" "fmt" + "slices" "xorm.io/builder" "xorm.io/xorm" @@ -270,6 +271,91 @@ func GetByID[T any](ctx context.Context, id int64) (object *T, exist bool, err e return &bean, true, nil } +// Retrieves multiple objects with database queries similar to an xorm `.In(idField, idList)`. idField must be a unique +// field on the database table, as a map[id]obj is returned and the usage of a non-unique field would result in objects +// being overwritten in the map. +// +// The length of the IN list is constrained to DefaultMaxInSize for each database query, resulting in multiple database +// queries if the length of the idList exceeds that setting; this constraint prevents exceeding bind parameter +// limitations or query length limitations in the database engine. +func GetByIDs[Bean any, Id comparable](ctx context.Context, idField string, idList []Id, bean *Bean) (map[Id]*Bean, error) { + retval := make(map[Id]*Bean, len(idList)) + if len(idList) == 0 { + return retval, nil + } + + table, err := TableInfo(bean) + if err != nil { + return nil, fmt.Errorf("unable to fetch table info for bean %v: %w", bean, err) + } + + var structFieldName string + for _, c := range table.Columns() { + if c.Name == idField { + structFieldName = c.FieldName + break + } + } + if structFieldName == "" { + return nil, fmt.Errorf("unable to identify struct field for id field %s", idField) + } + + for idChunk := range slices.Chunk(idList, DefaultMaxInSize) { + beans := make([]*Bean, 0, len(idChunk)) + if err := GetEngine(ctx).In(idField, idChunk).Find(&beans); err != nil { + return nil, err + } + for _, bean := range beans { + retval[extractFieldValue(bean, structFieldName).(Id)] = bean + } + } + + return retval, nil +} + +// Retrieves multiple objects with database queries similar to an xorm `.In(field, valueList)`. Similar to GetByIDs, +// except that a map[Id][]*Bean is returned as the field value is not assumed to be a unique value -- if there are +// multiple rows in the table for each value, all of them are returned. +// +// The length of the IN list is constrained to DefaultMaxInSize for each database query, resulting in multiple database +// queries if the length of the idList exceeds that setting; this constraint prevents exceeding bind parameter +// limitations or query length limitations in the database engine. +func GetByFieldIn[Bean any, Id comparable](ctx context.Context, field string, valueList []Id, bean *Bean) (map[Id][]*Bean, error) { + retval := make(map[Id][]*Bean, len(valueList)) + if len(valueList) == 0 { + return retval, nil + } + + table, err := TableInfo(bean) + if err != nil { + return nil, fmt.Errorf("unable to fetch table info for bean %v: %w", bean, err) + } + + var structFieldName string + for _, c := range table.Columns() { + if c.Name == field { + structFieldName = c.FieldName + break + } + } + if structFieldName == "" { + return nil, fmt.Errorf("unable to identify struct field for field %s", field) + } + + for idChunk := range slices.Chunk(valueList, DefaultMaxInSize) { + beans := make([]*Bean, 0, len(idChunk)) + if err := GetEngine(ctx).In(field, idChunk).Find(&beans); err != nil { + return nil, err + } + for _, bean := range beans { + fieldValue := extractFieldValue(bean, structFieldName).(Id) + retval[fieldValue] = append(retval[fieldValue], bean) + } + } + + return retval, nil +} + func Exist[T any](ctx context.Context, cond builder.Cond) (bool, error) { if !cond.IsValid() { panic("cond is invalid in db.Exist(ctx, cond). This should not be possible.") diff --git a/models/db/list.go b/models/db/list.go index 057221936c..71e9a0b1d2 100644 --- a/models/db/list.go +++ b/models/db/list.go @@ -14,7 +14,7 @@ import ( const ( // DefaultMaxInSize represents default variables number on IN () in SQL - DefaultMaxInSize = 50 + DefaultMaxInSize = 500 defaultFindSliceSize = 10 ) diff --git a/models/issues/comment_list.go b/models/issues/comment_list.go index 9a5c22244b..b218f11dfa 100644 --- a/models/issues/comment_list.go +++ b/models/issues/comment_list.go @@ -52,29 +52,9 @@ func (comments CommentList) loadLabels(ctx context.Context) error { } labelIDs := comments.getLabelIDs() - commentLabels := make(map[int64]*Label, len(labelIDs)) - left := len(labelIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - rows, err := db.GetEngine(ctx). - In("id", labelIDs[:limit]). - Rows(new(Label)) - if err != nil { - return err - } - - for rows.Next() { - var label Label - err = rows.Scan(&label) - if err != nil { - _ = rows.Close() - return err - } - commentLabels[label.ID] = &label - } - _ = rows.Close() - left -= limit - labelIDs = labelIDs[limit:] + commentLabels, err := db.GetByIDs(ctx, "id", labelIDs, &Label{}) + if err != nil { + return err } for _, comment := range comments { @@ -99,18 +79,9 @@ func (comments CommentList) loadMilestones(ctx context.Context) error { return nil } - milestones := make(map[int64]*Milestone, len(milestoneIDs)) - left := len(milestoneIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - err := db.GetEngine(ctx). - In("id", milestoneIDs[:limit]). - Find(&milestones) - if err != nil { - return err - } - left -= limit - milestoneIDs = milestoneIDs[limit:] + milestones, err := db.GetByIDs(ctx, "id", milestoneIDs, &Milestone{}) + if err != nil { + return err } for _, comment := range comments { @@ -135,18 +106,9 @@ func (comments CommentList) loadOldMilestones(ctx context.Context) error { return nil } - milestones := make(map[int64]*Milestone, len(milestoneIDs)) - left := len(milestoneIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - err := db.GetEngine(ctx). - In("id", milestoneIDs[:limit]). - Find(&milestones) - if err != nil { - return err - } - left -= limit - milestoneIDs = milestoneIDs[limit:] + milestones, err := db.GetByIDs(ctx, "id", milestoneIDs, &Milestone{}) + if err != nil { + return err } for _, comment := range comments { @@ -167,31 +129,9 @@ func (comments CommentList) loadAssignees(ctx context.Context) error { } assigneeIDs := comments.getAssigneeIDs() - assignees := make(map[int64]*user_model.User, len(assigneeIDs)) - left := len(assigneeIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - rows, err := db.GetEngine(ctx). - In("id", assigneeIDs[:limit]). - Rows(new(user_model.User)) - if err != nil { - return err - } - - for rows.Next() { - var user user_model.User - err = rows.Scan(&user) - if err != nil { - rows.Close() - return err - } - - assignees[user.ID] = &user - } - _ = rows.Close() - - left -= limit - assigneeIDs = assigneeIDs[limit:] + assignees, err := db.GetByIDs(ctx, "id", assigneeIDs, &user_model.User{}) + if err != nil { + return err } for _, comment := range comments { @@ -232,31 +172,9 @@ func (comments CommentList) LoadIssues(ctx context.Context) error { } issueIDs := comments.getIssueIDs() - issues := make(map[int64]*Issue, len(issueIDs)) - left := len(issueIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - rows, err := db.GetEngine(ctx). - In("id", issueIDs[:limit]). - Rows(new(Issue)) - if err != nil { - return err - } - - for rows.Next() { - var issue Issue - err = rows.Scan(&issue) - if err != nil { - rows.Close() - return err - } - - issues[issue.ID] = &issue - } - _ = rows.Close() - - left -= limit - issueIDs = issueIDs[limit:] + issues, err := db.GetByIDs(ctx, "id", issueIDs, &Issue{}) + if err != nil { + return err } for _, comment := range comments { @@ -281,33 +199,10 @@ func (comments CommentList) loadDependentIssues(ctx context.Context) error { return nil } - e := db.GetEngine(ctx) issueIDs := comments.getDependentIssueIDs() - issues := make(map[int64]*Issue, len(issueIDs)) - left := len(issueIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - rows, err := e. - In("id", issueIDs[:limit]). - Rows(new(Issue)) - if err != nil { - return err - } - - for rows.Next() { - var issue Issue - err = rows.Scan(&issue) - if err != nil { - _ = rows.Close() - return err - } - - issues[issue.ID] = &issue - } - _ = rows.Close() - - left -= limit - issueIDs = issueIDs[limit:] + issues, err := db.GetByIDs(ctx, "id", issueIDs, &Issue{}) + if err != nil { + return err } for _, comment := range comments { @@ -358,31 +253,10 @@ func (comments CommentList) LoadAttachments(ctx context.Context) (err error) { return nil } - attachments := make(map[int64][]*repo_model.Attachment, len(comments)) commentsIDs := comments.getAttachmentCommentIDs() - left := len(commentsIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - rows, err := db.GetEngine(ctx). - In("comment_id", commentsIDs[:limit]). - Rows(new(repo_model.Attachment)) - if err != nil { - return err - } - - for rows.Next() { - var attachment repo_model.Attachment - err = rows.Scan(&attachment) - if err != nil { - _ = rows.Close() - return err - } - attachments[attachment.CommentID] = append(attachments[attachment.CommentID], &attachment) - } - - _ = rows.Close() - left -= limit - commentsIDs = commentsIDs[limit:] + attachments, err := db.GetByFieldIn(ctx, "comment_id", commentsIDs, &repo_model.Attachment{}) + if err != nil { + return err } for _, comment := range comments { diff --git a/models/issues/issue_list.go b/models/issues/issue_list.go index 34cfe35475..e4fd9eef2b 100644 --- a/models/issues/issue_list.go +++ b/models/issues/issue_list.go @@ -6,6 +6,7 @@ package issues import ( "context" "fmt" + "slices" "forgejo.org/models/db" project_model "forgejo.org/models/project" @@ -40,18 +41,9 @@ func (issues IssueList) LoadRepositories(ctx context.Context) (repo_model.Reposi } repoIDs := issues.getRepoIDs() - repoMaps := make(map[int64]*repo_model.Repository, len(repoIDs)) - left := len(repoIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - err := db.GetEngine(ctx). - In("id", repoIDs[:limit]). - Find(&repoMaps) - if err != nil { - return nil, fmt.Errorf("find repository: %w", err) - } - left -= limit - repoIDs = repoIDs[limit:] + repoMaps, err := db.GetByIDs(ctx, "id", repoIDs, &repo_model.Repository{}) + if err != nil { + return nil, fmt.Errorf("find repository: %w", err) } for _, issue := range issues { @@ -93,18 +85,9 @@ func (issues IssueList) LoadPosters(ctx context.Context) error { } func getPostersByIDs(ctx context.Context, posterIDs []int64) (map[int64]*user_model.User, error) { - posterMaps := make(map[int64]*user_model.User, len(posterIDs)) - left := len(posterIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - err := db.GetEngine(ctx). - In("id", posterIDs[:limit]). - Find(&posterMaps) - if err != nil { - return nil, err - } - left -= limit - posterIDs = posterIDs[limit:] + posterMaps, err := db.GetByIDs(ctx, "id", posterIDs, &user_model.User{}) + if err != nil { + return nil, err } return posterMaps, nil } @@ -129,18 +112,15 @@ func (issues IssueList) LoadLabels(ctx context.Context) error { issueLabels := make(map[int64][]*Label, len(issues)*3) issueIDs := issues.getIssueIDs() - left := len(issueIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) + for issueIDChunk := range slices.Chunk(issueIDs, db.DefaultMaxInSize) { rows, err := db.GetEngine(ctx).Table("label"). Join("LEFT", "issue_label", "issue_label.label_id = label.id"). - In("issue_label.issue_id", issueIDs[:limit]). + In("issue_label.issue_id", issueIDChunk). Asc("label.name"). Rows(new(LabelIssue)) if err != nil { return err } - for rows.Next() { var labelIssue LabelIssue err = rows.Scan(&labelIssue) @@ -157,8 +137,6 @@ func (issues IssueList) LoadLabels(ctx context.Context) error { if err1 := rows.Close(); err1 != nil { return fmt.Errorf("IssueList.LoadLabels: Close: %w", err1) } - left -= limit - issueIDs = issueIDs[limit:] } for _, issue := range issues { @@ -180,18 +158,9 @@ func (issues IssueList) LoadMilestones(ctx context.Context) error { return nil } - milestoneMaps := make(map[int64]*Milestone, len(milestoneIDs)) - left := len(milestoneIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - err := db.GetEngine(ctx). - In("id", milestoneIDs[:limit]). - Find(&milestoneMaps) - if err != nil { - return err - } - left -= limit - milestoneIDs = milestoneIDs[limit:] + milestoneMaps, err := db.GetByIDs(ctx, "id", milestoneIDs, &Milestone{}) + if err != nil { + return err } for _, issue := range issues { @@ -204,22 +173,19 @@ func (issues IssueList) LoadMilestones(ctx context.Context) error { func (issues IssueList) LoadProjects(ctx context.Context) error { issueIDs := issues.getIssueIDs() projectMaps := make(map[int64]*project_model.Project, len(issues)) - left := len(issueIDs) type projectWithIssueID struct { *project_model.Project `xorm:"extends"` IssueID int64 } - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - - projects := make([]*projectWithIssueID, 0, limit) + for issueIDChunk := range slices.Chunk(issueIDs, db.DefaultMaxInSize) { + projects := make([]*projectWithIssueID, 0, len(issueIDChunk)) err := db.GetEngine(ctx). Table("project"). Select("project.*, project_issue.issue_id"). Join("INNER", "project_issue", "project.id = project_issue.project_id"). - In("project_issue.issue_id", issueIDs[:limit]). + In("project_issue.issue_id", issueIDChunk). Find(&projects) if err != nil { return err @@ -227,8 +193,6 @@ func (issues IssueList) LoadProjects(ctx context.Context) error { for _, project := range projects { projectMaps[project.IssueID] = project.Project } - left -= limit - issueIDs = issueIDs[limit:] } for _, issue := range issues { @@ -249,12 +213,10 @@ func (issues IssueList) LoadAssignees(ctx context.Context) error { assignees := make(map[int64][]*user_model.User, len(issues)) issueIDs := issues.getIssueIDs() - left := len(issueIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) + for issueIDChunk := range slices.Chunk(issueIDs, db.DefaultMaxInSize) { rows, err := db.GetEngine(ctx).Table("issue_assignees"). Join("INNER", "`user`", "`user`.id = `issue_assignees`.assignee_id"). - In("`issue_assignees`.issue_id", issueIDs[:limit]).OrderBy(user_model.GetOrderByName()). + In("`issue_assignees`.issue_id", issueIDChunk).OrderBy(user_model.GetOrderByName()). Rows(new(AssigneeIssue)) if err != nil { return err @@ -275,8 +237,6 @@ func (issues IssueList) LoadAssignees(ctx context.Context) error { if err1 := rows.Close(); err1 != nil { return fmt.Errorf("IssueList.loadAssignees: Close: %w", err1) } - left -= limit - issueIDs = issueIDs[limit:] } for _, issue := range issues { @@ -306,33 +266,9 @@ func (issues IssueList) LoadPullRequests(ctx context.Context) error { return nil } - pullRequestMaps := make(map[int64]*PullRequest, len(issuesIDs)) - left := len(issuesIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - rows, err := db.GetEngine(ctx). - In("issue_id", issuesIDs[:limit]). - Rows(new(PullRequest)) - if err != nil { - return err - } - - for rows.Next() { - var pr PullRequest - err = rows.Scan(&pr) - if err != nil { - if err1 := rows.Close(); err1 != nil { - return fmt.Errorf("IssueList.loadPullRequests: Close: %w", err1) - } - return err - } - pullRequestMaps[pr.IssueID] = &pr - } - if err1 := rows.Close(); err1 != nil { - return fmt.Errorf("IssueList.loadPullRequests: Close: %w", err1) - } - left -= limit - issuesIDs = issuesIDs[limit:] + pullRequestMaps, err := db.GetByIDs(ctx, "issue_id", issuesIDs, &PullRequest{}) + if err != nil { + return err } for _, issue := range issues { @@ -350,34 +286,10 @@ func (issues IssueList) LoadAttachments(ctx context.Context) (err error) { return nil } - attachments := make(map[int64][]*repo_model.Attachment, len(issues)) issuesIDs := issues.getIssueIDs() - left := len(issuesIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - rows, err := db.GetEngine(ctx). - In("issue_id", issuesIDs[:limit]). - Rows(new(repo_model.Attachment)) - if err != nil { - return err - } - - for rows.Next() { - var attachment repo_model.Attachment - err = rows.Scan(&attachment) - if err != nil { - if err1 := rows.Close(); err1 != nil { - return fmt.Errorf("IssueList.loadAttachments: Close: %w", err1) - } - return err - } - attachments[attachment.IssueID] = append(attachments[attachment.IssueID], &attachment) - } - if err1 := rows.Close(); err1 != nil { - return fmt.Errorf("IssueList.loadAttachments: Close: %w", err1) - } - left -= limit - issuesIDs = issuesIDs[limit:] + attachments, err := db.GetByFieldIn(ctx, "issue_id", issuesIDs, &repo_model.Attachment{}) + if err != nil { + return err } for _, issue := range issues { @@ -394,12 +306,10 @@ func (issues IssueList) loadComments(ctx context.Context, cond builder.Cond) (er comments := make(map[int64][]*Comment, len(issues)) issuesIDs := issues.getIssueIDs() - left := len(issuesIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) + for issueIDChunk := range slices.Chunk(issuesIDs, db.DefaultMaxInSize) { rows, err := db.GetEngine(ctx).Table("comment"). Join("INNER", "issue", "issue.id = comment.issue_id"). - In("issue.id", issuesIDs[:limit]). + In("issue.id", issueIDChunk). Where(cond). Rows(new(Comment)) if err != nil { @@ -420,8 +330,6 @@ func (issues IssueList) loadComments(ctx context.Context, cond builder.Cond) (er if err1 := rows.Close(); err1 != nil { return fmt.Errorf("IssueList.loadComments: Close: %w", err1) } - left -= limit - issuesIDs = issuesIDs[limit:] } for _, issue := range issues { @@ -457,15 +365,12 @@ func (issues IssueList) loadTotalTrackedTimes(ctx context.Context) (err error) { } } - left := len(ids) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) - + for idChunk := range slices.Chunk(ids, db.DefaultMaxInSize) { // select issue_id, sum(time) from tracked_time where issue_id in () group by issue_id rows, err := db.GetEngine(ctx).Table("tracked_time"). Where("deleted = ?", false). Select("issue_id, sum(time) as time"). - In("issue_id", ids[:limit]). + In("issue_id", idChunk). GroupBy("issue_id"). Rows(new(totalTimesByIssue)) if err != nil { @@ -486,8 +391,6 @@ func (issues IssueList) loadTotalTrackedTimes(ctx context.Context) (err error) { if err1 := rows.Close(); err1 != nil { return fmt.Errorf("IssueList.loadTotalTrackedTimes: Close: %w", err1) } - left -= limit - ids = ids[limit:] } for _, issue := range issues { diff --git a/models/issues/reaction.go b/models/issues/reaction.go index 9a277a8c12..21975c6b00 100644 --- a/models/issues/reaction.go +++ b/models/issues/reaction.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "fmt" + "slices" "forgejo.org/models/db" repo_model "forgejo.org/models/repo" @@ -178,13 +179,12 @@ func FindReactions(ctx context.Context, opts FindReactionsOptions) (ReactionList func getReactionsForComments(ctx context.Context, issueID int64, commentIDs []int64) (map[int64]ReactionList, error) { reactions := make(map[int64]ReactionList, len(commentIDs)) - left := len(commentIDs) - for left > 0 { - limit := min(left, db.DefaultMaxInSize) + + for commentIDChunk := range slices.Chunk(commentIDs, db.DefaultMaxInSize) { rows, err := db.GetEngine(ctx). Where(builder.Eq{"issue_id": issueID}). In("reaction.`type`", setting.UI.Reactions). - In("comment_id", commentIDs[:limit]). + In("comment_id", commentIDChunk). Rows(&Reaction{}) if err != nil { return nil, err @@ -201,8 +201,6 @@ func getReactionsForComments(ctx context.Context, issueID int64, commentIDs []in } _ = rows.Close() - left -= limit - commentIDs = commentIDs[limit:] } return reactions, nil } diff --git a/tests/integration/db_query_test.go b/tests/integration/db_query_test.go new file mode 100644 index 0000000000..799d3219e8 --- /dev/null +++ b/tests/integration/db_query_test.go @@ -0,0 +1,64 @@ +// Copyright 2026 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later + +package integration + +import ( + "fmt" + "testing" + + actions_model "forgejo.org/models/actions" + "forgejo.org/models/db" + "forgejo.org/tests" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// These are basically unit tests, but by running them in the integration test suite they are tested against all +// supported database types. + +func TestDatabaseDefaultMaxInSize(t *testing.T) { + defer tests.PrepareTestEnv(t)() + + // Ensure there are more than db.DefaultMaxInSize objects in a table: + targetCount := db.DefaultMaxInSize * 2 + for i := range targetCount { + _, err := actions_model.InsertVariable(t.Context(), 2, 2, fmt.Sprintf("VAR_%d", i), fmt.Sprintf("Value %d", i)) + require.NoError(t, err) + } + + t.Run("GetByIDs", func(t *testing.T) { + defer tests.PrintCurrentTest(t)() + + allActionVariables := make([]*actions_model.ActionVariable, 0, targetCount) + err := db.GetEngine(t.Context()).Find(&allActionVariables) + require.NoError(t, err) + + allIDs := make([]int64, len(allActionVariables)) + for i := range allActionVariables { + allIDs[i] = allActionVariables[i].ID + } + + allActionVariablesAgain, err := db.GetByIDs(t.Context(), "id", allIDs, &actions_model.ActionVariable{}) + require.NoError(t, err) + assert.Len(t, allActionVariablesAgain, len(allActionVariables)) + }) + + t.Run("GetByFieldIn", func(t *testing.T) { + defer tests.PrintCurrentTest(t)() + + allActionVariables := make([]*actions_model.ActionVariable, 0, targetCount) + err := db.GetEngine(t.Context()).Find(&allActionVariables) + require.NoError(t, err) + + allIDs := make([]int64, len(allActionVariables)) + for i := range allActionVariables { + allIDs[i] = allActionVariables[i].ID + } + + allActionVariablesAgain, err := db.GetByFieldIn(t.Context(), "id", allIDs, &actions_model.ActionVariable{}) + require.NoError(t, err) + assert.Len(t, allActionVariablesAgain, len(allActionVariables)) + }) +}