Skip to content

Commit

Permalink
Detect nested contexts in function literals (#18)
Browse files Browse the repository at this point in the history
* feat: Add detection for nested contexts in function literals
* feat: Improve detection of nested contexts in function literals
* refactor: Update getReportMessage function to handle unsupported nested context types
* use node instead of block
* refactor: use multi case
* added one more case
* feat: also added support for multiple contexts
  • Loading branch information
venkycode committed Aug 26, 2024
1 parent 0d2c401 commit be0aa70
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 11 deletions.
52 changes: 41 additions & 11 deletions pkg/analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"go/ast"
"go/printer"
"go/token"
"go/types"

"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
Expand All @@ -28,6 +29,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
nodeFilter := []ast.Node{
(*ast.ForStmt)(nil),
(*ast.RangeStmt)(nil),
(*ast.FuncLit)(nil),
}

inspctr.Preorder(nodeFilter, func(node ast.Node) {
Expand All @@ -36,7 +38,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
return
}

assignStmt := findNestedContext(pass, body, body.List)
assignStmt := findNestedContext(pass, node, body.List)
if assignStmt == nil {
return
}
Expand Down Expand Up @@ -65,7 +67,7 @@ func run(pass *analysis.Pass) (interface{}, error) {

pass.Report(analysis.Diagnostic{
Pos: assignStmt.Pos(),
Message: "nested context in loop",
Message: getReportMessage(node),
SuggestedFixes: fixes,
})

Expand All @@ -74,6 +76,17 @@ func run(pass *analysis.Pass) (interface{}, error) {
return nil, nil
}

func getReportMessage(node ast.Node) string {
switch node.(type) {
case *ast.ForStmt, *ast.RangeStmt:
return "nested context in loop"
case *ast.FuncLit:
return "nested context in function literal"
default:
return "unsupported nested context type"
}
}

func getBody(node ast.Node) (*ast.BlockStmt, error) {
forStmt, ok := node.(*ast.ForStmt)
if ok {
Expand All @@ -85,49 +98,54 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
return rangeStmt.Body, nil
}

funcLit, ok := node.(*ast.FuncLit)
if ok {
return funcLit.Body, nil
}

return nil, errUnknown
}

func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.Stmt) *ast.AssignStmt {
func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt {
for _, stmt := range stmts {
// Recurse if necessary
if inner, ok := stmt.(*ast.BlockStmt); ok {
found := findNestedContext(pass, inner, inner.List)
found := findNestedContext(pass, node, inner.List)
if found != nil {
return found
}
}

if inner, ok := stmt.(*ast.IfStmt); ok {
found := findNestedContext(pass, inner.Body, inner.Body.List)
found := findNestedContext(pass, node, inner.Body.List)
if found != nil {
return found
}
}

if inner, ok := stmt.(*ast.SwitchStmt); ok {
found := findNestedContext(pass, inner.Body, inner.Body.List)
found := findNestedContext(pass, node, inner.Body.List)
if found != nil {
return found
}
}

if inner, ok := stmt.(*ast.CaseClause); ok {
found := findNestedContext(pass, block, inner.Body)
found := findNestedContext(pass, node, inner.Body)
if found != nil {
return found
}
}

if inner, ok := stmt.(*ast.SelectStmt); ok {
found := findNestedContext(pass, inner.Body, inner.Body.List)
found := findNestedContext(pass, node, inner.Body.List)
if found != nil {
return found
}
}

if inner, ok := stmt.(*ast.CommClause); ok {
found := findNestedContext(pass, block, inner.Body)
found := findNestedContext(pass, node, inner.Body)
if found != nil {
return found
}
Expand All @@ -149,13 +167,13 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St
}

if assignStmt.Tok == token.DEFINE {
break
continue
}

// allow assignment to non-pointer children of values defined within the loop
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
if obj.Pos() >= block.Pos() && obj.Pos() < block.End() {
if checkObjectScopeWithinNode(obj.Parent(), node) {
continue // definition is within the loop
}
}
Expand All @@ -167,6 +185,18 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St
return nil
}

func checkObjectScopeWithinNode(scope *types.Scope, node ast.Node) bool {
if scope == nil {
return false
}

if scope.Pos() >= node.Pos() && scope.End() <= node.End() {
return true
}

return false
}

func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
for {
switch n := node.(type) {
Expand Down
48 changes: 48 additions & 0 deletions testdata/src/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,26 @@ func example() {

break
}

// detects contexts wrapped in function literals (this is risky as function literals can be called multiple times)
_ = func() {
ctx = wrapContext(ctx) // want "nested context in function literal"
}

// this is fine because the context is created in the loop
for {
if ctx := context.Background(); doSomething() != nil {
ctx = wrapContext(ctx)
}
}

for {
ctx2 := context.Background()
ctx = wrapContext(ctx) // want "nested context in loop"
if doSomething() != nil {
ctx2 = wrapContext(ctx2)
}
}
}

func wrapContext(ctx context.Context) context.Context {
Expand Down Expand Up @@ -180,3 +200,31 @@ func inVariousNestedBlocks(ctx context.Context) {
break
}
}

// this middleware could run on every request, bloating the request parameter level context and causing a memory leak
func badMiddleware(ctx context.Context) func() error {
return func() error {
ctx = wrapContext(ctx) // want "nested context in function literal"
return doSomethingWithCtx(ctx)
}
}

// this middleware is fine, as it doesn't modify the context of parent function
func okMiddleware(ctx context.Context) func() error {
return func() error {
ctx := wrapContext(ctx)
return doSomethingWithCtx(ctx)
}
}

// this middleware is fine, as it only modifies the context passed to it
func okMiddleware2(ctx context.Context) func(ctx context.Context) error {
return func(ctx context.Context) error {
ctx = wrapContext(ctx)
return doSomethingWithCtx(ctx)
}
}

func doSomethingWithCtx(ctx context.Context) error {
return nil
}

0 comments on commit be0aa70

Please sign in to comment.