Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protect against NPEs in notifications list (#10879) #10883

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 58 additions & 19 deletions models/notification.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package models
import (
"fmt"

"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"
)

Expand Down Expand Up @@ -281,9 +282,9 @@ func (nl NotificationList) getPendingRepoIDs() []int64 {
}

// LoadRepos loads repositories from database
func (nl NotificationList) LoadRepos() (RepositoryList, error) {
func (nl NotificationList) LoadRepos() (RepositoryList, []int, error) {
if len(nl) == 0 {
return RepositoryList{}, nil
return RepositoryList{}, []int{}, nil
}

var repoIDs = nl.getPendingRepoIDs()
Expand All @@ -298,15 +299,15 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
In("id", repoIDs[:limit]).
Rows(new(Repository))
if err != nil {
return nil, err
return nil, nil, err
}

for rows.Next() {
var repo Repository
err = rows.Scan(&repo)
if err != nil {
rows.Close()
return nil, err
return nil, nil, err
}

repos[repo.ID] = &repo
Expand All @@ -317,14 +318,21 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
repoIDs = repoIDs[limit:]
}

failed := []int{}

var reposList = make(RepositoryList, 0, len(repoIDs))
for _, notification := range nl {
for i, notification := range nl {
if notification.Repository == nil {
notification.Repository = repos[notification.RepoID]
}
if notification.Repository == nil {
log.Error("Notification[%d]: RepoID: %d not found", notification.ID, notification.RepoID)
failed = append(failed, i)
continue
}
var found bool
for _, r := range reposList {
if r.ID == notification.Repository.ID {
if r.ID == notification.RepoID {
found = true
break
}
Expand All @@ -333,7 +341,7 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
reposList = append(reposList, notification.Repository)
}
}
return reposList, nil
return reposList, failed, nil
}

func (nl NotificationList) getPendingIssueIDs() []int64 {
Expand All @@ -350,9 +358,9 @@ func (nl NotificationList) getPendingIssueIDs() []int64 {
}

// LoadIssues loads issues from database
func (nl NotificationList) LoadIssues() error {
func (nl NotificationList) LoadIssues() ([]int, error) {
if len(nl) == 0 {
return nil
return []int{}, nil
}

var issueIDs = nl.getPendingIssueIDs()
Expand All @@ -367,15 +375,15 @@ func (nl NotificationList) LoadIssues() error {
In("id", issueIDs[:limit]).
Rows(new(Issue))
if err != nil {
return err
return nil, err
}

for rows.Next() {
var issue Issue
err = rows.Scan(&issue)
if err != nil {
rows.Close()
return err
return nil, err
}

issues[issue.ID] = &issue
Expand All @@ -386,13 +394,38 @@ func (nl NotificationList) LoadIssues() error {
issueIDs = issueIDs[limit:]
}

for _, notification := range nl {
failures := []int{}

for i, notification := range nl {
if notification.Issue == nil {
notification.Issue = issues[notification.IssueID]
if notification.Issue == nil {
log.Error("Notification[%d]: IssueID: %d Not Found", notification.ID, notification.IssueID)
failures = append(failures, i)
continue
}
notification.Issue.Repo = notification.Repository
}
}
return nil
return failures, nil
}

// Without returns the notification list without the failures
func (nl NotificationList) Without(failures []int) NotificationList {
if failures == nil || len(failures) == 0 {
zeripath marked this conversation as resolved.
Show resolved Hide resolved
return nl
}
remaining := make([]*Notification, 0, len(nl))
last := -1
var i int
for _, i = range failures {
remaining = append(remaining, nl[last+1:i]...)
last = i
}
if len(nl) > i {
remaining = append(remaining, nl[i+1:]...)
}
return remaining
}

func (nl NotificationList) getPendingCommentIDs() []int64 {
Expand All @@ -409,9 +442,9 @@ func (nl NotificationList) getPendingCommentIDs() []int64 {
}

// LoadComments loads comments from database
func (nl NotificationList) LoadComments() error {
func (nl NotificationList) LoadComments() ([]int, error) {
if len(nl) == 0 {
return nil
return []int{}, nil
}

var commentIDs = nl.getPendingCommentIDs()
Expand All @@ -426,15 +459,15 @@ func (nl NotificationList) LoadComments() error {
In("id", commentIDs[:limit]).
Rows(new(Comment))
if err != nil {
return err
return nil, err
}

for rows.Next() {
var comment Comment
err = rows.Scan(&comment)
if err != nil {
rows.Close()
return err
return nil, err
}

comments[comment.ID] = &comment
Expand All @@ -445,13 +478,19 @@ func (nl NotificationList) LoadComments() error {
commentIDs = commentIDs[limit:]
}

for _, notification := range nl {
failures := []int{}
for i, notification := range nl {
if notification.CommentID > 0 && notification.Comment == nil && comments[notification.CommentID] != nil {
notification.Comment = comments[notification.CommentID]
if notification.Comment == nil {
log.Error("Notification[%d]: CommentID[%d] failed to load", notification.ID, notification.CommentID)
failures = append(failures, i)
continue
}
notification.Comment.Issue = notification.Issue
}
}
return nil
return failures, nil
}

// GetNotificationCount returns the notification count for user
Expand Down
21 changes: 18 additions & 3 deletions routers/user/notification.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,24 +68,39 @@ func Notifications(c *context.Context) {
return
}

repos, err := notifications.LoadRepos()
failCount := 0

repos, failures, err := notifications.LoadRepos()
if err != nil {
c.ServerError("LoadRepos", err)
return
}
notifications = notifications.Without(failures)
if err := repos.LoadAttributes(); err != nil {
c.ServerError("LoadAttributes", err)
return
}
failCount += len(failures)

if err := notifications.LoadIssues(); err != nil {
failures, err = notifications.LoadIssues()
if err != nil {
c.ServerError("LoadIssues", err)
return
}
if err := notifications.LoadComments(); err != nil {
notifications = notifications.Without(failures)
failCount += len(failures)

failures, err = notifications.LoadComments()
if err != nil {
c.ServerError("LoadComments", err)
return
}
notifications = notifications.Without(failures)
failCount += len(failures)

if failCount > 0 {
c.Flash.Error(fmt.Sprintf("ERROR: %d notifications were removed due to missing parts - check the logs", failCount))
}

total, err := models.GetNotificationCount(c.User, status)
if err != nil {
Expand Down