From 091030580e85ab39e1dace70689e41a9a220b5b0 Mon Sep 17 00:00:00 2001 From: Gabriel Augendre Date: Fri, 12 Jul 2024 12:04:01 +0200 Subject: [PATCH] feat: detect in nested blocks --- pkg/analyzer/analyzer.go | 168 ++++++++++++++++++++++++++------------- testdata/src/example.go | 109 +++++++++++++++++++++++++ 2 files changed, 220 insertions(+), 57 deletions(-) diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index 90c5333..c3be4e9 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -36,64 +36,39 @@ func run(pass *analysis.Pass) (interface{}, error) { return } - for _, stmt := range body.List { - assignStmt, ok := stmt.(*ast.AssignStmt) - if !ok { - continue - } - - t := pass.TypesInfo.TypeOf(assignStmt.Lhs[0]) - if t == nil { - continue - } - - if t.String() != "context.Context" { - continue - } - - if assignStmt.Tok == token.DEFINE { - break - } - - // allow assignment to non-pointer children of values defined within the loop - if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil { - if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil { - if obj.Pos() >= body.Pos() && obj.Pos() < body.End() { - continue // definition is within the loop - } - } - } - - suggestedStmt := ast.AssignStmt{ - Lhs: assignStmt.Lhs, - TokPos: assignStmt.TokPos, - Tok: token.DEFINE, - Rhs: assignStmt.Rhs, - } - suggested, err := render(pass.Fset, &suggestedStmt) - - var fixes []analysis.SuggestedFix - if err == nil { - fixes = append(fixes, analysis.SuggestedFix{ - Message: "replace `=` with `:=`", - TextEdits: []analysis.TextEdit{ - { - Pos: assignStmt.Pos(), - End: assignStmt.End(), - NewText: []byte(suggested), - }, - }, - }) - } - - pass.Report(analysis.Diagnostic{ - Pos: assignStmt.Pos(), - Message: "nested context in loop", - SuggestedFixes: fixes, - }) - - break + assignStmt := findNestedContext(pass, body, body.List) + if assignStmt == nil { + return } + + suggestedStmt := ast.AssignStmt{ + Lhs: assignStmt.Lhs, + TokPos: assignStmt.TokPos, + Tok: token.DEFINE, + Rhs: assignStmt.Rhs, + } + suggested, err := render(pass.Fset, &suggestedStmt) + + var fixes []analysis.SuggestedFix + if err == nil { + fixes = append(fixes, analysis.SuggestedFix{ + Message: "replace `=` with `:=`", + TextEdits: []analysis.TextEdit{ + { + Pos: assignStmt.Pos(), + End: assignStmt.End(), + NewText: []byte(suggested), + }, + }, + }) + } + + pass.Report(analysis.Diagnostic{ + Pos: assignStmt.Pos(), + Message: "nested context in loop", + SuggestedFixes: fixes, + }) + }) return nil, nil @@ -113,6 +88,85 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) { return nil, errUnknown } +func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.Stmt) *ast.AssignStmt { + for _, stmt := range stmts { + // Recurse if necessary + if inner, ok := stmt.(*ast.BlockStmt); ok { + found := findNestedContext(pass, inner, inner.List) + if found != nil { + return found + } + } + + if inner, ok := stmt.(*ast.IfStmt); ok { + found := findNestedContext(pass, inner.Body, inner.Body.List) + if found != nil { + return found + } + } + + if inner, ok := stmt.(*ast.SwitchStmt); ok { + found := findNestedContext(pass, inner.Body, inner.Body.List) + if found != nil { + return found + } + } + + if inner, ok := stmt.(*ast.CaseClause); ok { + found := findNestedContext(pass, block, inner.Body) + if found != nil { + return found + } + } + + if inner, ok := stmt.(*ast.SelectStmt); ok { + found := findNestedContext(pass, inner.Body, inner.Body.List) + if found != nil { + return found + } + } + + if inner, ok := stmt.(*ast.CommClause); ok { + found := findNestedContext(pass, block, inner.Body) + if found != nil { + return found + } + } + + // Actually check for nested context + assignStmt, ok := stmt.(*ast.AssignStmt) + if !ok { + continue + } + + t := pass.TypesInfo.TypeOf(assignStmt.Lhs[0]) + if t == nil { + continue + } + + if t.String() != "context.Context" { + continue + } + + if assignStmt.Tok == token.DEFINE { + break + } + + // allow assignment to non-pointer children of values defined within the loop + if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil { + if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil { + if obj.Pos() >= block.Pos() && obj.Pos() < block.End() { + continue // definition is within the loop + } + } + } + + return assignStmt + } + + return nil +} + func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident { for { switch n := node.(type) { diff --git a/testdata/src/example.go b/testdata/src/example.go index c9c6ea2..c8038ba 100644 --- a/testdata/src/example.go +++ b/testdata/src/example.go @@ -25,12 +25,50 @@ func example() { ctx = wrapContext(ctx) // want "nested context in loop" break } + + // not fooled by shadowing in nested blocks + for { + err := doSomething() + if err != nil { + ctx := wrapContext(ctx) + ctx = wrapContext(ctx) + } + + switch err { + case nil: + ctx := wrapContext(ctx) + ctx = wrapContext(ctx) + default: + ctx := wrapContext(ctx) + ctx = wrapContext(ctx) + } + + { + ctx := wrapContext(ctx) + ctx = wrapContext(ctx) + } + + select { + case <-ctx.Done(): + ctx := wrapContext(ctx) + ctx = wrapContext(ctx) + default: + } + + ctx = wrapContext(ctx) // want "nested context in loop" + + break + } } func wrapContext(ctx context.Context) context.Context { return context.WithoutCancel(ctx) } +func doSomething() error { + return nil +} + // storing contexts in a struct isn't recommended, but local copies of a non-pointer struct should act like local copies of a context. func inStructs(ctx context.Context) { for i := 0; i < 10; i++ { @@ -71,3 +109,74 @@ func inStructs(ctx context.Context) { rp[0].Ctx = context.WithValue(rp[0].Ctx, "other", "val") } } + +func inVariousNestedBlocks(ctx context.Context) { + for { + err := doSomething() + if err != nil { + ctx = wrapContext(ctx) // want "nested context in loop" + } + + break + } + + for { + err := doSomething() + if err != nil { + if true { + ctx = wrapContext(ctx) // want "nested context in loop" + } + } + + break + } + + for { + err := doSomething() + switch err { + case nil: + ctx = wrapContext(ctx) // want "nested context in loop" + } + + break + } + + for { + err := doSomething() + switch err { + default: + ctx = wrapContext(ctx) // want "nested context in loop" + } + + break + } + + for { + ctx := wrapContext(ctx) + + err := doSomething() + if err != nil { + ctx = wrapContext(ctx) + } + + break + } + + for { + { + ctx = wrapContext(ctx) // want "nested context in loop" + } + + break + } + + for { + select { + case <-ctx.Done(): + ctx = wrapContext(ctx) // want "nested context in loop" + default: + } + + break + } +}