feat: Add detection for nested contexts in function literals

This commit is contained in:
Venkatesh Kotwade 2024-08-25 01:49:35 +05:30
parent 0be9888cea
commit 6ddf255ca8
2 changed files with 34 additions and 1 deletions

View file

@ -28,6 +28,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
nodeFilter := []ast.Node{ nodeFilter := []ast.Node{
(*ast.ForStmt)(nil), (*ast.ForStmt)(nil),
(*ast.RangeStmt)(nil), (*ast.RangeStmt)(nil),
(*ast.FuncLit)(nil),
} }
inspctr.Preorder(nodeFilter, func(node ast.Node) { inspctr.Preorder(nodeFilter, func(node ast.Node) {
@ -65,7 +66,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
pass.Report(analysis.Diagnostic{ pass.Report(analysis.Diagnostic{
Pos: assignStmt.Pos(), Pos: assignStmt.Pos(),
Message: "nested context in loop", Message: getReportMessage(node),
SuggestedFixes: fixes, SuggestedFixes: fixes,
}) })
@ -74,6 +75,19 @@ func run(pass *analysis.Pass) (interface{}, error) {
return nil, nil return nil, nil
} }
func getReportMessage(node ast.Node) string {
switch node.(type) {
case *ast.ForStmt:
return "nested context in loop"
case *ast.RangeStmt:
return "nested context in loop"
case *ast.FuncLit:
return "nested context in function literal"
default:
return "nested context"
}
}
func getBody(node ast.Node) (*ast.BlockStmt, error) { func getBody(node ast.Node) (*ast.BlockStmt, error) {
forStmt, ok := node.(*ast.ForStmt) forStmt, ok := node.(*ast.ForStmt)
if ok { if ok {
@ -85,6 +99,11 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
return rangeStmt.Body, nil return rangeStmt.Body, nil
} }
funcLit, ok := node.(*ast.FuncLit)
if ok {
return funcLit.Body, nil
}
return nil, errUnknown return nil, errUnknown
} }

View file

@ -59,6 +59,12 @@ func example() {
break break
} }
// detects contexts wrapped in function literals (this is risky as function literals can be called multiple times)
_ = func() {
ctx = wrapContext(ctx) // want "nested context in function literal"
}
} }
func wrapContext(ctx context.Context) context.Context { func wrapContext(ctx context.Context) context.Context {
@ -180,3 +186,11 @@ func inVariousNestedBlocks(ctx context.Context) {
break break
} }
} }
// this middleware could run on every request, bloating the request parameter level context and causing a memory leak
func memoryLeakCausingMiddleware(ctx context.Context) func(ctx context.Context) error {
return func(ctx context.Context) error {
ctx = wrapContext(ctx) // want "nested context in function literal"
return doSomething()
}
}