From 0f9412c2ac1b18b4891aca8958512c6f96131824 Mon Sep 17 00:00:00 2001 From: Venkatesh Kotwade Date: Sun, 25 Aug 2024 01:49:35 +0530 Subject: [PATCH] feat: Add detection for nested contexts in function literals --- pkg/analyzer/analyzer.go | 21 ++++++++++++++++++++- testdata/src/example.go | 14 ++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index c3be4e9..dda17eb 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -28,6 +28,7 @@ func run(pass *analysis.Pass) (interface{}, error) { nodeFilter := []ast.Node{ (*ast.ForStmt)(nil), (*ast.RangeStmt)(nil), + (*ast.FuncLit)(nil), } inspctr.Preorder(nodeFilter, func(node ast.Node) { @@ -65,7 +66,7 @@ func run(pass *analysis.Pass) (interface{}, error) { pass.Report(analysis.Diagnostic{ Pos: assignStmt.Pos(), - Message: "nested context in loop", + Message: getReportMessage(node), SuggestedFixes: fixes, }) @@ -74,6 +75,19 @@ func run(pass *analysis.Pass) (interface{}, error) { return nil, nil } +func getReportMessage(node ast.Node) string { + switch node.(type) { + case *ast.ForStmt: + return "nested context in loop" + case *ast.RangeStmt: + return "nested context in loop" + case *ast.FuncLit: + return "nested context in function literal" + default: + return "nested context" + } +} + func getBody(node ast.Node) (*ast.BlockStmt, error) { forStmt, ok := node.(*ast.ForStmt) if ok { @@ -85,6 +99,11 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) { return rangeStmt.Body, nil } + funcLit, ok := node.(*ast.FuncLit) + if ok { + return funcLit.Body, nil + } + return nil, errUnknown } diff --git a/testdata/src/example.go b/testdata/src/example.go index c8038ba..b96d514 100644 --- a/testdata/src/example.go +++ b/testdata/src/example.go @@ -59,6 +59,12 @@ func example() { break } + + // detects contexts wrapped in function literals (this is risky as function literals can be called multiple times) + _ = func() { + ctx = wrapContext(ctx) // want "nested context in function literal" + } + } func wrapContext(ctx context.Context) context.Context { @@ -180,3 +186,11 @@ func inVariousNestedBlocks(ctx context.Context) { break } } + +// this middleware could run on every request, bloating the request parameter level context and causing a memory leak +func memoryLeakCausingMiddleware(ctx context.Context) func(ctx context.Context) error { + return func(ctx context.Context) error { + ctx = wrapContext(ctx) // want "nested context in function literal" + return doSomething() + } +}