From b46089e7869a7690aa60dace57f9b294e66b2c30 Mon Sep 17 00:00:00 2001 From: Venkatesh Kotwade Date: Sun, 25 Aug 2024 02:22:31 +0530 Subject: [PATCH] feat: Improve detection of nested contexts in function literals --- pkg/analyzer/analyzer.go | 15 ++++++++++++++- testdata/src/example.go | 26 +++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index dda17eb..0840a06 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -7,6 +7,7 @@ import ( "go/ast" "go/printer" "go/token" + "go/types" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" @@ -174,7 +175,7 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St // allow assignment to non-pointer children of values defined within the loop if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil { if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil { - if obj.Pos() >= block.Pos() && obj.Pos() < block.End() { + if checkObjectScopeWithinBlock(obj.Parent(), block) { continue // definition is within the loop } } @@ -186,6 +187,18 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St return nil } +func checkObjectScopeWithinBlock(scope *types.Scope, block *ast.BlockStmt) bool { + if scope == nil { + return false + } + + if scope.Pos() >= block.Pos() && scope.End() <= block.End() { + return true + } + + return false +} + func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident { for { switch n := node.(type) { diff --git a/testdata/src/example.go b/testdata/src/example.go index b96d514..12f69a5 100644 --- a/testdata/src/example.go +++ b/testdata/src/example.go @@ -188,9 +188,29 @@ func inVariousNestedBlocks(ctx context.Context) { } // this middleware could run on every request, bloating the request parameter level context and causing a memory leak -func memoryLeakCausingMiddleware(ctx context.Context) func(ctx context.Context) error { - return func(ctx context.Context) error { +func badMiddleware(ctx context.Context) func() error { + return func() error { ctx = wrapContext(ctx) // want "nested context in function literal" - return doSomething() + return doSomethingWithCtx(ctx) } } + +// this middleware is fine, as it doesn't modify the context of parent function +func okMiddleware(ctx context.Context) func() error { + return func() error { + ctx := wrapContext(ctx) + return doSomethingWithCtx(ctx) + } +} + +// this middleware is fine, as it only modifies the context passed to it +func okMiddleware2(ctx context.Context) func(ctx context.Context) error { + return func(ctx context.Context) error { + ctx = wrapContext(ctx) + return doSomethingWithCtx(ctx) + } +} + +func doSomethingWithCtx(ctx context.Context) error { + return nil +}