From 5030c3a2fad5c5d639fce73d2fefae99901d347d Mon Sep 17 00:00:00 2001
From: Gabriel Augendre <gabriel@augendre.info>
Date: Mon, 13 Jan 2025 15:34:16 +0100
Subject: [PATCH] feat: ignore context.TODO and context.Background

Related to #34
---
 pkg/analyzer/analyzer.go | 71 +++++++++++++++++++++++++++-------------
 testdata/src/example.go  | 23 ++++++++++++-
 2 files changed, 71 insertions(+), 23 deletions(-)

diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go
index 7b88bf5..b45dbc3 100644
--- a/pkg/analyzer/analyzer.go
+++ b/pkg/analyzer/analyzer.go
@@ -7,7 +7,7 @@ import (
 	"go/ast"
 	"go/printer"
 	"go/token"
-	"go/types"
+	"slices"
 
 	"golang.org/x/tools/go/analysis"
 	"golang.org/x/tools/go/analysis/passes/inspect"
@@ -169,13 +169,14 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
 			continue
 		}
 
+		// Ignore [context.Background] & [context.TODO].
+		if isContextFunction(assignStmt.Rhs[0], "Background", "TODO") {
+			continue
+		}
+
 		// allow assignment to non-pointer children of values defined within the loop
-		if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
-			if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
-				if checkObjectScopeWithinNode(obj.Parent(), node) {
-					continue // definition is within the loop
-				}
-			}
+		if isWithinLoop(assignStmt.Lhs[0], node, pass) {
+			continue
 		}
 
 		return assignStmt
@@ -184,16 +185,51 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
 	return nil
 }
 
-func checkObjectScopeWithinNode(scope *types.Scope, node ast.Node) bool {
+// render returns the pretty-print of the given node
+func render(fset *token.FileSet, x interface{}) ([]byte, error) {
+	var buf bytes.Buffer
+	if err := printer.Fprint(&buf, fset, x); err != nil {
+		return nil, fmt.Errorf("printing node: %w", err)
+	}
+	return buf.Bytes(), nil
+}
+
+func isContextFunction(exp ast.Expr, fnName ...string) bool {
+	call, ok := exp.(*ast.CallExpr)
+	if !ok {
+		return false
+	}
+
+	selector, ok := call.Fun.(*ast.SelectorExpr)
+	if !ok {
+		return false
+	}
+
+	ident, ok := selector.X.(*ast.Ident)
+	if !ok {
+		return false
+	}
+
+	return ident.Name == "context" && slices.Contains(fnName, selector.Sel.Name)
+}
+
+func isWithinLoop(exp ast.Expr, node ast.Node, pass *analysis.Pass) bool {
+	lhs := getRootIdent(pass, exp)
+	if lhs == nil {
+		return false
+	}
+
+	obj := pass.TypesInfo.ObjectOf(lhs)
+	if obj == nil {
+		return false
+	}
+
+	scope := obj.Parent()
 	if scope == nil {
 		return false
 	}
 
-	if scope.Pos() >= node.Pos() && scope.End() <= node.End() {
-		return true
-	}
-
-	return false
+	return scope.Pos() >= node.Pos() && scope.End() <= node.End()
 }
 
 func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
@@ -213,12 +249,3 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
 		}
 	}
 }
-
-// render returns the pretty-print of the given node
-func render(fset *token.FileSet, x interface{}) ([]byte, error) {
-	var buf bytes.Buffer
-	if err := printer.Fprint(&buf, fset, x); err != nil {
-		return nil, fmt.Errorf("printing node: %w", err)
-	}
-	return buf.Bytes(), nil
-}
diff --git a/testdata/src/example.go b/testdata/src/example.go
index 565ee49..df76f89 100644
--- a/testdata/src/example.go
+++ b/testdata/src/example.go
@@ -1,6 +1,9 @@
 package src
 
-import "context"
+import (
+	"context"
+	"testing"
+)
 
 func example() {
 	ctx := context.Background()
@@ -228,3 +231,21 @@ func okMiddleware2(ctx context.Context) func(ctx context.Context) error {
 func doSomethingWithCtx(ctx context.Context) error {
 	return nil
 }
+
+func testCasesInit(t *testing.T) {
+	cases := []struct {
+		ctx context.Context
+	}{
+		{},
+		{
+			ctx: context.WithValue(context.Background(), "key", "value"),
+		},
+	}
+	for _, tc := range cases {
+		t.Run("some test", func(t *testing.T) {
+			if tc.ctx == nil {
+				tc.ctx = context.Background()
+			}
+		})
+	}
+}