From 54e593c1c6f898d0cdafdf5b72021db212e5f44b Mon Sep 17 00:00:00 2001 From: Gabriel Augendre Date: Mon, 13 Jan 2025 15:34:16 +0100 Subject: [PATCH] feat: ignore context.TODO and context.Background Related to #34 --- pkg/analyzer/analyzer.go | 71 +++++++++++++++++++++++++++------------- testdata/src/example.go | 23 ++++++++++++- 2 files changed, 71 insertions(+), 23 deletions(-) diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index 7b88bf5..b45dbc3 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -7,7 +7,7 @@ import ( "go/ast" "go/printer" "go/token" - "go/types" + "slices" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" @@ -169,13 +169,14 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as continue } + // Ignore [context.Background] & [context.TODO]. + if isContextFunction(assignStmt.Rhs[0], "Background", "TODO") { + continue + } + // allow assignment to non-pointer children of values defined within the loop - if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil { - if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil { - if checkObjectScopeWithinNode(obj.Parent(), node) { - continue // definition is within the loop - } - } + if isWithinLoop(assignStmt.Lhs[0], node, pass) { + continue } return assignStmt @@ -184,16 +185,51 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as return nil } -func checkObjectScopeWithinNode(scope *types.Scope, node ast.Node) bool { +// render returns the pretty-print of the given node +func render(fset *token.FileSet, x interface{}) ([]byte, error) { + var buf bytes.Buffer + if err := printer.Fprint(&buf, fset, x); err != nil { + return nil, fmt.Errorf("printing node: %w", err) + } + return buf.Bytes(), nil +} + +func isContextFunction(exp ast.Expr, fnName ...string) bool { + call, ok := exp.(*ast.CallExpr) + if !ok { + return false + } + + selector, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return false + } + + ident, ok := selector.X.(*ast.Ident) + if !ok { + return false + } + + return ident.Name == "context" && slices.Contains(fnName, selector.Sel.Name) +} + +func isWithinLoop(exp ast.Expr, node ast.Node, pass *analysis.Pass) bool { + lhs := getRootIdent(pass, exp) + if lhs == nil { + return false + } + + obj := pass.TypesInfo.ObjectOf(lhs) + if obj == nil { + return false + } + + scope := obj.Parent() if scope == nil { return false } - if scope.Pos() >= node.Pos() && scope.End() <= node.End() { - return true - } - - return false + return scope.Pos() >= node.Pos() && scope.End() <= node.End() } func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident { @@ -213,12 +249,3 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident { } } } - -// render returns the pretty-print of the given node -func render(fset *token.FileSet, x interface{}) ([]byte, error) { - var buf bytes.Buffer - if err := printer.Fprint(&buf, fset, x); err != nil { - return nil, fmt.Errorf("printing node: %w", err) - } - return buf.Bytes(), nil -} diff --git a/testdata/src/example.go b/testdata/src/example.go index 565ee49..df76f89 100644 --- a/testdata/src/example.go +++ b/testdata/src/example.go @@ -1,6 +1,9 @@ package src -import "context" +import ( + "context" + "testing" +) func example() { ctx := context.Background() @@ -228,3 +231,21 @@ func okMiddleware2(ctx context.Context) func(ctx context.Context) error { func doSomethingWithCtx(ctx context.Context) error { return nil } + +func testCasesInit(t *testing.T) { + cases := []struct { + ctx context.Context + }{ + {}, + { + ctx: context.WithValue(context.Background(), "key", "value"), + }, + } + for _, tc := range cases { + t.Run("some test", func(t *testing.T) { + if tc.ctx == nil { + tc.ctx = context.Background() + } + }) + } +}