From 0f9412c2ac1b18b4891aca8958512c6f96131824 Mon Sep 17 00:00:00 2001 From: Venkatesh Kotwade Date: Sun, 25 Aug 2024 01:49:35 +0530 Subject: [PATCH 1/7] feat: Add detection for nested contexts in function literals --- pkg/analyzer/analyzer.go | 21 ++++++++++++++++++++- testdata/src/example.go | 14 ++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index c3be4e9..dda17eb 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -28,6 +28,7 @@ func run(pass *analysis.Pass) (interface{}, error) { nodeFilter := []ast.Node{ (*ast.ForStmt)(nil), (*ast.RangeStmt)(nil), + (*ast.FuncLit)(nil), } inspctr.Preorder(nodeFilter, func(node ast.Node) { @@ -65,7 +66,7 @@ func run(pass *analysis.Pass) (interface{}, error) { pass.Report(analysis.Diagnostic{ Pos: assignStmt.Pos(), - Message: "nested context in loop", + Message: getReportMessage(node), SuggestedFixes: fixes, }) @@ -74,6 +75,19 @@ func run(pass *analysis.Pass) (interface{}, error) { return nil, nil } +func getReportMessage(node ast.Node) string { + switch node.(type) { + case *ast.ForStmt: + return "nested context in loop" + case *ast.RangeStmt: + return "nested context in loop" + case *ast.FuncLit: + return "nested context in function literal" + default: + return "nested context" + } +} + func getBody(node ast.Node) (*ast.BlockStmt, error) { forStmt, ok := node.(*ast.ForStmt) if ok { @@ -85,6 +99,11 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) { return rangeStmt.Body, nil } + funcLit, ok := node.(*ast.FuncLit) + if ok { + return funcLit.Body, nil + } + return nil, errUnknown } diff --git a/testdata/src/example.go b/testdata/src/example.go index c8038ba..b96d514 100644 --- a/testdata/src/example.go +++ b/testdata/src/example.go @@ -59,6 +59,12 @@ func example() { break } + + // detects contexts wrapped in function literals (this is risky as function literals can be called multiple times) + _ = func() { + ctx = wrapContext(ctx) // want "nested context in function literal" + } + } func wrapContext(ctx context.Context) context.Context { @@ -180,3 +186,11 @@ func inVariousNestedBlocks(ctx context.Context) { break } } + +// this middleware could run on every request, bloating the request parameter level context and causing a memory leak +func memoryLeakCausingMiddleware(ctx context.Context) func(ctx context.Context) error { + return func(ctx context.Context) error { + ctx = wrapContext(ctx) // want "nested context in function literal" + return doSomething() + } +} From b46089e7869a7690aa60dace57f9b294e66b2c30 Mon Sep 17 00:00:00 2001 From: Venkatesh Kotwade Date: Sun, 25 Aug 2024 02:22:31 +0530 Subject: [PATCH 2/7] feat: Improve detection of nested contexts in function literals --- pkg/analyzer/analyzer.go | 15 ++++++++++++++- testdata/src/example.go | 26 +++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index dda17eb..0840a06 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -7,6 +7,7 @@ import ( "go/ast" "go/printer" "go/token" + "go/types" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" @@ -174,7 +175,7 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St // allow assignment to non-pointer children of values defined within the loop if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil { if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil { - if obj.Pos() >= block.Pos() && obj.Pos() < block.End() { + if checkObjectScopeWithinBlock(obj.Parent(), block) { continue // definition is within the loop } } @@ -186,6 +187,18 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St return nil } +func checkObjectScopeWithinBlock(scope *types.Scope, block *ast.BlockStmt) bool { + if scope == nil { + return false + } + + if scope.Pos() >= block.Pos() && scope.End() <= block.End() { + return true + } + + return false +} + func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident { for { switch n := node.(type) { diff --git a/testdata/src/example.go b/testdata/src/example.go index b96d514..12f69a5 100644 --- a/testdata/src/example.go +++ b/testdata/src/example.go @@ -188,9 +188,29 @@ func inVariousNestedBlocks(ctx context.Context) { } // this middleware could run on every request, bloating the request parameter level context and causing a memory leak -func memoryLeakCausingMiddleware(ctx context.Context) func(ctx context.Context) error { - return func(ctx context.Context) error { +func badMiddleware(ctx context.Context) func() error { + return func() error { ctx = wrapContext(ctx) // want "nested context in function literal" - return doSomething() + return doSomethingWithCtx(ctx) } } + +// this middleware is fine, as it doesn't modify the context of parent function +func okMiddleware(ctx context.Context) func() error { + return func() error { + ctx := wrapContext(ctx) + return doSomethingWithCtx(ctx) + } +} + +// this middleware is fine, as it only modifies the context passed to it +func okMiddleware2(ctx context.Context) func(ctx context.Context) error { + return func(ctx context.Context) error { + ctx = wrapContext(ctx) + return doSomethingWithCtx(ctx) + } +} + +func doSomethingWithCtx(ctx context.Context) error { + return nil +} From 07aa8cc6a24565bd360de49e7714e8fa3b637c47 Mon Sep 17 00:00:00 2001 From: Venkatesh Kotwade Date: Sun, 25 Aug 2024 03:12:41 +0530 Subject: [PATCH 3/7] refactor: Update getReportMessage function to handle unsupported nested context types --- pkg/analyzer/analyzer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index 0840a06..9ccdeed 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -85,7 +85,7 @@ func getReportMessage(node ast.Node) string { case *ast.FuncLit: return "nested context in function literal" default: - return "nested context" + return "unsupported nested context type" } } From 71bde6a5f63dec7a55a5556498a5d61f8a72a93e Mon Sep 17 00:00:00 2001 From: Venkatesh Kotwade Date: Mon, 26 Aug 2024 01:28:11 +0530 Subject: [PATCH 4/7] use node instead of block --- pkg/analyzer/analyzer.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index 9ccdeed..677c097 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -38,7 +38,7 @@ func run(pass *analysis.Pass) (interface{}, error) { return } - assignStmt := findNestedContext(pass, body, body.List) + assignStmt := findNestedContext(pass, node, body.List) if assignStmt == nil { return } @@ -108,7 +108,7 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) { return nil, errUnknown } -func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.Stmt) *ast.AssignStmt { +func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt { for _, stmt := range stmts { // Recurse if necessary if inner, ok := stmt.(*ast.BlockStmt); ok { @@ -119,35 +119,35 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St } if inner, ok := stmt.(*ast.IfStmt); ok { - found := findNestedContext(pass, inner.Body, inner.Body.List) + found := findNestedContext(pass, inner, inner.Body.List) if found != nil { return found } } if inner, ok := stmt.(*ast.SwitchStmt); ok { - found := findNestedContext(pass, inner.Body, inner.Body.List) + found := findNestedContext(pass, inner, inner.Body.List) if found != nil { return found } } if inner, ok := stmt.(*ast.CaseClause); ok { - found := findNestedContext(pass, block, inner.Body) + found := findNestedContext(pass, node, inner.Body) if found != nil { return found } } if inner, ok := stmt.(*ast.SelectStmt); ok { - found := findNestedContext(pass, inner.Body, inner.Body.List) + found := findNestedContext(pass, inner, inner.Body.List) if found != nil { return found } } if inner, ok := stmt.(*ast.CommClause); ok { - found := findNestedContext(pass, block, inner.Body) + found := findNestedContext(pass, node, inner.Body) if found != nil { return found } @@ -175,7 +175,7 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St // allow assignment to non-pointer children of values defined within the loop if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil { if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil { - if checkObjectScopeWithinBlock(obj.Parent(), block) { + if checkObjectScopeWithinNode(obj.Parent(), node) { continue // definition is within the loop } } @@ -187,12 +187,12 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St return nil } -func checkObjectScopeWithinBlock(scope *types.Scope, block *ast.BlockStmt) bool { +func checkObjectScopeWithinNode(scope *types.Scope, node ast.Node) bool { if scope == nil { return false } - if scope.Pos() >= block.Pos() && scope.End() <= block.End() { + if scope.Pos() >= node.Pos() && scope.End() <= node.End() { return true } From 89a1841d57fbde3a1332114d16b6f8622d8ba2ce Mon Sep 17 00:00:00 2001 From: Venkatesh Kotwade Date: Mon, 26 Aug 2024 01:29:37 +0530 Subject: [PATCH 5/7] refactor: use multi case --- pkg/analyzer/analyzer.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index 677c097..611f303 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -78,9 +78,7 @@ func run(pass *analysis.Pass) (interface{}, error) { func getReportMessage(node ast.Node) string { switch node.(type) { - case *ast.ForStmt: - return "nested context in loop" - case *ast.RangeStmt: + case *ast.ForStmt, *ast.RangeStmt: return "nested context in loop" case *ast.FuncLit: return "nested context in function literal" From 387c533fad143d77bb0b4f54c14609fa197e4f8c Mon Sep 17 00:00:00 2001 From: Venkatesh Kotwade Date: Mon, 26 Aug 2024 01:35:09 +0530 Subject: [PATCH 6/7] added one more case --- testdata/src/example.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/testdata/src/example.go b/testdata/src/example.go index 12f69a5..e72813f 100644 --- a/testdata/src/example.go +++ b/testdata/src/example.go @@ -65,6 +65,12 @@ func example() { ctx = wrapContext(ctx) // want "nested context in function literal" } + // this is fine because the context is created in the loop + for { + if ctx := context.Background(); doSomething() != nil { + ctx = wrapContext(ctx) + } + } } func wrapContext(ctx context.Context) context.Context { From c2c0e62d59d83c811c9ac863693006cd3e09c7e9 Mon Sep 17 00:00:00 2001 From: Venkatesh Kotwade Date: Mon, 26 Aug 2024 01:41:16 +0530 Subject: [PATCH 7/7] feat: also added support for multiple contexts --- pkg/analyzer/analyzer.go | 10 +++++----- testdata/src/example.go | 8 ++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index 611f303..5cd7108 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -110,21 +110,21 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as for _, stmt := range stmts { // Recurse if necessary if inner, ok := stmt.(*ast.BlockStmt); ok { - found := findNestedContext(pass, inner, inner.List) + found := findNestedContext(pass, node, inner.List) if found != nil { return found } } if inner, ok := stmt.(*ast.IfStmt); ok { - found := findNestedContext(pass, inner, inner.Body.List) + found := findNestedContext(pass, node, inner.Body.List) if found != nil { return found } } if inner, ok := stmt.(*ast.SwitchStmt); ok { - found := findNestedContext(pass, inner, inner.Body.List) + found := findNestedContext(pass, node, inner.Body.List) if found != nil { return found } @@ -138,7 +138,7 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as } if inner, ok := stmt.(*ast.SelectStmt); ok { - found := findNestedContext(pass, inner, inner.Body.List) + found := findNestedContext(pass, node, inner.Body.List) if found != nil { return found } @@ -167,7 +167,7 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as } if assignStmt.Tok == token.DEFINE { - break + continue } // allow assignment to non-pointer children of values defined within the loop diff --git a/testdata/src/example.go b/testdata/src/example.go index e72813f..565ee49 100644 --- a/testdata/src/example.go +++ b/testdata/src/example.go @@ -71,6 +71,14 @@ func example() { ctx = wrapContext(ctx) } } + + for { + ctx2 := context.Background() + ctx = wrapContext(ctx) // want "nested context in loop" + if doSomething() != nil { + ctx2 = wrapContext(ctx2) + } + } } func wrapContext(ctx context.Context) context.Context {