From 5b689092efcee5d4fc1d6610ed551f88ac7ab119 Mon Sep 17 00:00:00 2001 From: Gabriel Augendre <gabriel@augendre.info> Date: Mon, 13 Jan 2025 15:46:00 +0100 Subject: [PATCH 1/2] feat: detect potential nested contexts in function declarations --- pkg/analyzer/analyzer.go | 25 ++++++++++++------------- testdata/src/example.go | 22 ++++++++++++++++++++++ 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index b45dbc3..c4488bc 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -30,6 +30,7 @@ func run(pass *analysis.Pass) (interface{}, error) { (*ast.ForStmt)(nil), (*ast.RangeStmt)(nil), (*ast.FuncLit)(nil), + (*ast.FuncDecl)(nil), } inspctr.Preorder(nodeFilter, func(node ast.Node) { @@ -81,25 +82,23 @@ func getReportMessage(node ast.Node) string { return "nested context in loop" case *ast.FuncLit: return "nested context in function literal" + case *ast.FuncDecl: + return "potential nested context in function declaration" default: return "unsupported nested context type" } } func getBody(node ast.Node) (*ast.BlockStmt, error) { - forStmt, ok := node.(*ast.ForStmt) - if ok { - return forStmt.Body, nil - } - - rangeStmt, ok := node.(*ast.RangeStmt) - if ok { - return rangeStmt.Body, nil - } - - funcLit, ok := node.(*ast.FuncLit) - if ok { - return funcLit.Body, nil + switch typedNode := node.(type) { + case *ast.ForStmt: + return typedNode.Body, nil + case *ast.RangeStmt: + return typedNode.Body, nil + case *ast.FuncLit: + return typedNode.Body, nil + case *ast.FuncDecl: + return typedNode.Body, nil } return nil, errUnknown diff --git a/testdata/src/example.go b/testdata/src/example.go index df76f89..2dc714b 100644 --- a/testdata/src/example.go +++ b/testdata/src/example.go @@ -249,3 +249,25 @@ func testCasesInit(t *testing.T) { }) } } + +type Container struct { + Ctx context.Context +} + +func something() func(*Container) { + return func(r *Container) { + ctx := r.Ctx + ctx = context.WithValue(ctx, "key", "val") + r.Ctx = ctx // want "nested context in function literal" + } +} + +func other() func(*Container) { + return blah +} + +func blah(r *Container) { + ctx := r.Ctx + ctx = context.WithValue(ctx, "key", "val") + r.Ctx = ctx // want "potential nested context in function declaration" +} From 9371bcfb562befa60fbf24368c438c3055eab8ba Mon Sep 17 00:00:00 2001 From: Gabriel Augendre <gabriel@augendre.info> Date: Mon, 13 Jan 2025 23:08:04 +0100 Subject: [PATCH 2/2] refactor: nested context recursion --- pkg/analyzer/analyzer.go | 41 +++++++++++++--------------------------- 1 file changed, 13 insertions(+), 28 deletions(-) diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index c4488bc..5b28078 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -107,44 +107,29 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) { func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt { for _, stmt := range stmts { // Recurse if necessary - if inner, ok := stmt.(*ast.BlockStmt); ok { - found := findNestedContext(pass, node, inner.List) - if found != nil { + switch typedStmt := stmt.(type) { + case *ast.BlockStmt: + if found := findNestedContext(pass, node, typedStmt.List); found != nil { return found } - } - - if inner, ok := stmt.(*ast.IfStmt); ok { - found := findNestedContext(pass, node, inner.Body.List) - if found != nil { + case *ast.IfStmt: + if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil { return found } - } - - if inner, ok := stmt.(*ast.SwitchStmt); ok { - found := findNestedContext(pass, node, inner.Body.List) - if found != nil { + case *ast.SwitchStmt: + if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil { return found } - } - - if inner, ok := stmt.(*ast.CaseClause); ok { - found := findNestedContext(pass, node, inner.Body) - if found != nil { + case *ast.CaseClause: + if found := findNestedContext(pass, node, typedStmt.Body); found != nil { return found } - } - - if inner, ok := stmt.(*ast.SelectStmt); ok { - found := findNestedContext(pass, node, inner.Body.List) - if found != nil { + case *ast.SelectStmt: + if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil { return found } - } - - if inner, ok := stmt.(*ast.CommClause); ok { - found := findNestedContext(pass, node, inner.Body) - if found != nil { + case *ast.CommClause: + if found := findNestedContext(pass, node, typedStmt.Body); found != nil { return found } }