feat: ignore context.TODO and context.Background

Related to #34
This commit is contained in:
Gabriel Augendre 2025-01-13 15:34:16 +01:00
parent 529e088561
commit 54e593c1c6
2 changed files with 71 additions and 23 deletions

View file

@ -7,7 +7,7 @@ import (
"go/ast"
"go/printer"
"go/token"
"go/types"
"slices"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
@ -169,13 +169,14 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
continue
}
// Ignore [context.Background] & [context.TODO].
if isContextFunction(assignStmt.Rhs[0], "Background", "TODO") {
continue
}
// allow assignment to non-pointer children of values defined within the loop
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
if checkObjectScopeWithinNode(obj.Parent(), node) {
continue // definition is within the loop
}
}
if isWithinLoop(assignStmt.Lhs[0], node, pass) {
continue
}
return assignStmt
@ -184,16 +185,51 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
return nil
}
func checkObjectScopeWithinNode(scope *types.Scope, node ast.Node) bool {
// render returns the pretty-print of the given node
func render(fset *token.FileSet, x interface{}) ([]byte, error) {
var buf bytes.Buffer
if err := printer.Fprint(&buf, fset, x); err != nil {
return nil, fmt.Errorf("printing node: %w", err)
}
return buf.Bytes(), nil
}
func isContextFunction(exp ast.Expr, fnName ...string) bool {
call, ok := exp.(*ast.CallExpr)
if !ok {
return false
}
selector, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return false
}
ident, ok := selector.X.(*ast.Ident)
if !ok {
return false
}
return ident.Name == "context" && slices.Contains(fnName, selector.Sel.Name)
}
func isWithinLoop(exp ast.Expr, node ast.Node, pass *analysis.Pass) bool {
lhs := getRootIdent(pass, exp)
if lhs == nil {
return false
}
obj := pass.TypesInfo.ObjectOf(lhs)
if obj == nil {
return false
}
scope := obj.Parent()
if scope == nil {
return false
}
if scope.Pos() >= node.Pos() && scope.End() <= node.End() {
return true
}
return false
return scope.Pos() >= node.Pos() && scope.End() <= node.End()
}
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
@ -213,12 +249,3 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
}
}
}
// render returns the pretty-print of the given node
func render(fset *token.FileSet, x interface{}) ([]byte, error) {
var buf bytes.Buffer
if err := printer.Fprint(&buf, fset, x); err != nil {
return nil, fmt.Errorf("printing node: %w", err)
}
return buf.Bytes(), nil
}

View file

@ -1,6 +1,9 @@
package src
import "context"
import (
"context"
"testing"
)
func example() {
ctx := context.Background()
@ -228,3 +231,21 @@ func okMiddleware2(ctx context.Context) func(ctx context.Context) error {
func doSomethingWithCtx(ctx context.Context) error {
return nil
}
func testCasesInit(t *testing.T) {
cases := []struct {
ctx context.Context
}{
{},
{
ctx: context.WithValue(context.Background(), "key", "value"),
},
}
for _, tc := range cases {
t.Run("some test", func(t *testing.T) {
if tc.ctx == nil {
tc.ctx = context.Background()
}
})
}
}