From ef9d47d1f097531a5b5b31c8d77cdbab3325aed0 Mon Sep 17 00:00:00 2001 From: Gabriel Augendre Date: Mon, 13 Jan 2025 15:46:00 +0100 Subject: [PATCH] feat: better discriminate assignations to struct pointers --- README.md | 2 - cmd/fatcontext/main.go | 2 +- pkg/analyzer/analyzer.go | 150 +++++++++++++----- pkg/analyzer/analyzer_test.go | 22 ++- .../analyzer/testdata/common}/example.go | 2 +- .../testdata/no_structpointer/example.go | 23 +++ .../testdata/structpointer/example.go | 23 +++ 7 files changed, 173 insertions(+), 51 deletions(-) rename {testdata/src => pkg/analyzer/testdata/common}/example.go (99%) create mode 100644 pkg/analyzer/testdata/no_structpointer/example.go create mode 100644 pkg/analyzer/testdata/structpointer/example.go diff --git a/README.md b/README.md index ffbcafa..d157ba3 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,6 @@ go install github.com/Crocmagnon/fatcontext/cmd/fatcontext@latest fatcontext ./... ``` -There are no specific configuration options or custom command-line flags. - ## Example ```go diff --git a/cmd/fatcontext/main.go b/cmd/fatcontext/main.go index 747f6a9..10fc1e4 100644 --- a/cmd/fatcontext/main.go +++ b/cmd/fatcontext/main.go @@ -7,5 +7,5 @@ import ( ) func main() { - singlechecker.Main(analyzer.Analyzer) + singlechecker.Main(analyzer.NewAnalyzer()) } diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index b45dbc3..8f095d3 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -3,6 +3,7 @@ package analyzer import ( "bytes" "errors" + "flag" "fmt" "go/ast" "go/printer" @@ -14,22 +15,45 @@ import ( "golang.org/x/tools/go/ast/inspector" ) -var Analyzer = &analysis.Analyzer{ - Name: "fatcontext", - Doc: "detects nested contexts in loops and function literals", - Run: run, - Requires: []*analysis.Analyzer{inspect.Analyzer}, +const FlagCheckStructPointers = "check-struct-pointers" + +func NewAnalyzer() *analysis.Analyzer { + r := &runner{} + + flags := flag.NewFlagSet("fatcontext", flag.ExitOnError) + flags.BoolVar(&r.DetectInStructPointers, FlagCheckStructPointers, false, + "set to true to detect potential fat contexts in struct pointers") + + return &analysis.Analyzer{ + Name: "fatcontext", + Doc: "detects nested contexts in loops and function literals", + Run: r.run, + Flags: *flags, + Requires: []*analysis.Analyzer{inspect.Analyzer}, + } } var errUnknown = errors.New("unknown node type") -func run(pass *analysis.Pass) (interface{}, error) { +const ( + categoryInLoop = "nested context in loop" + categoryInFuncLit = "nested context in function literal" + categoryInStructPointer = "potential nested context in struct pointer" + categoryUnsupported = "unsupported nested context type" +) + +type runner struct { + DetectInStructPointers bool +} + +func (r *runner) run(pass *analysis.Pass) (interface{}, error) { inspctr := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) nodeFilter := []ast.Node{ (*ast.ForStmt)(nil), (*ast.RangeStmt)(nil), (*ast.FuncLit)(nil), + (*ast.FuncDecl)(nil), } inspctr.Preorder(nodeFilter, func(node ast.Node) { @@ -43,31 +67,17 @@ func run(pass *analysis.Pass) (interface{}, error) { return } - suggestedStmt := ast.AssignStmt{ - Lhs: assignStmt.Lhs, - TokPos: assignStmt.TokPos, - Tok: token.DEFINE, - Rhs: assignStmt.Rhs, - } - suggested, err := render(pass.Fset, &suggestedStmt) + category := getCategory(pass, node, assignStmt) - var fixes []analysis.SuggestedFix - if err == nil { - fixes = append(fixes, analysis.SuggestedFix{ - Message: "replace `=` with `:=`", - TextEdits: []analysis.TextEdit{ - { - Pos: assignStmt.Pos(), - End: assignStmt.End(), - NewText: suggested, - }, - }, - }) + if r.shouldIgnoreReport(category) { + return } + fixes := r.getSuggestedFixes(pass, assignStmt, category) + pass.Report(analysis.Diagnostic{ Pos: assignStmt.Pos(), - Message: getReportMessage(node), + Message: category, SuggestedFixes: fixes, }) }) @@ -75,31 +85,69 @@ func run(pass *analysis.Pass) (interface{}, error) { return nil, nil } -func getReportMessage(node ast.Node) string { +func (r *runner) shouldIgnoreReport(category string) bool { + return category == categoryInStructPointer && !r.DetectInStructPointers +} + +func (r *runner) getSuggestedFixes(pass *analysis.Pass, assignStmt *ast.AssignStmt, category string) []analysis.SuggestedFix { + switch category { + case categoryInStructPointer, categoryUnsupported: + return nil + } + + suggestedStmt := ast.AssignStmt{ + Lhs: assignStmt.Lhs, + TokPos: assignStmt.TokPos, + Tok: token.DEFINE, + Rhs: assignStmt.Rhs, + } + suggested, err := render(pass.Fset, &suggestedStmt) + + var fixes []analysis.SuggestedFix + if err == nil { + fixes = append(fixes, analysis.SuggestedFix{ + Message: "replace `=` with `:=`", + TextEdits: []analysis.TextEdit{ + { + Pos: assignStmt.Pos(), + End: assignStmt.End(), + NewText: suggested, + }, + }, + }) + } + + return fixes +} + +func getCategory(pass *analysis.Pass, node ast.Node, assignStmt *ast.AssignStmt) string { switch node.(type) { case *ast.ForStmt, *ast.RangeStmt: - return "nested context in loop" - case *ast.FuncLit: - return "nested context in function literal" + return categoryInLoop + } + + if isPointer(pass, assignStmt.Lhs[0]) { + return categoryInStructPointer + } + + switch node.(type) { + case *ast.FuncLit, *ast.FuncDecl: + return categoryInFuncLit default: - return "unsupported nested context type" + return categoryUnsupported } } func getBody(node ast.Node) (*ast.BlockStmt, error) { - forStmt, ok := node.(*ast.ForStmt) - if ok { - return forStmt.Body, nil - } - - rangeStmt, ok := node.(*ast.RangeStmt) - if ok { - return rangeStmt.Body, nil - } - - funcLit, ok := node.(*ast.FuncLit) - if ok { - return funcLit.Body, nil + switch typedNode := node.(type) { + case *ast.ForStmt: + return typedNode.Body, nil + case *ast.RangeStmt: + return typedNode.Body, nil + case *ast.FuncLit: + return typedNode.Body, nil + case *ast.FuncDecl: + return typedNode.Body, nil } return nil, errUnknown @@ -174,6 +222,10 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as continue } + if isPointer(pass, assignStmt.Lhs[0]) { + return assignStmt + } + // allow assignment to non-pointer children of values defined within the loop if isWithinLoop(assignStmt.Lhs[0], node, pass) { continue @@ -249,3 +301,13 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident { } } } + +func isPointer(pass *analysis.Pass, exp ast.Node) bool { + switch n := exp.(type) { + case *ast.SelectorExpr: + sel, ok := pass.TypesInfo.Selections[n] + return ok && sel.Indirect() + } + + return false +} diff --git a/pkg/analyzer/analyzer_test.go b/pkg/analyzer/analyzer_test.go index 3477ed0..b577a1f 100644 --- a/pkg/analyzer/analyzer_test.go +++ b/pkg/analyzer/analyzer_test.go @@ -10,12 +10,28 @@ import ( "github.com/Crocmagnon/fatcontext/pkg/analyzer" ) -func TestAll(t *testing.T) { +func TestAnalyzer(t *testing.T) { wd, err := os.Getwd() if err != nil { t.Fatalf("Failed to get wd: %s", err) } - testdata := filepath.Join(filepath.Dir(filepath.Dir(wd)), "testdata") + testdata := filepath.Join(wd, "testdata") - analysistest.Run(t, testdata, analyzer.Analyzer, "./...") + t.Run("no func decl", func(t *testing.T) { + an := analyzer.NewAnalyzer() + analysistest.Run(t, testdata, an, "./common") + analysistest.Run(t, testdata, an, "./no_structpointer") + }) + + t.Run("func decl", func(t *testing.T) { + an := analyzer.NewAnalyzer() + + err := an.Flags.Set(analyzer.FlagCheckStructPointers, "true") + if err != nil { + t.Fatal(err) + } + + analysistest.Run(t, testdata, an, "./common") + analysistest.Run(t, testdata, an, "./structpointer") + }) } diff --git a/testdata/src/example.go b/pkg/analyzer/testdata/common/example.go similarity index 99% rename from testdata/src/example.go rename to pkg/analyzer/testdata/common/example.go index df76f89..e437b1d 100644 --- a/testdata/src/example.go +++ b/pkg/analyzer/testdata/common/example.go @@ -1,4 +1,4 @@ -package src +package common import ( "context" diff --git a/pkg/analyzer/testdata/no_structpointer/example.go b/pkg/analyzer/testdata/no_structpointer/example.go new file mode 100644 index 0000000..043dc53 --- /dev/null +++ b/pkg/analyzer/testdata/no_structpointer/example.go @@ -0,0 +1,23 @@ +package common + +import ( + "context" +) + +type Container struct { + Ctx context.Context +} + +func something() func(*Container) { + return func(r *Container) { + ctx := r.Ctx + ctx = context.WithValue(ctx, "key", "val") + r.Ctx = ctx + } +} + +func blah(r *Container) { + ctx := r.Ctx + ctx = context.WithValue(ctx, "key", "val") + r.Ctx = ctx +} diff --git a/pkg/analyzer/testdata/structpointer/example.go b/pkg/analyzer/testdata/structpointer/example.go new file mode 100644 index 0000000..964645f --- /dev/null +++ b/pkg/analyzer/testdata/structpointer/example.go @@ -0,0 +1,23 @@ +package common + +import ( + "context" +) + +type Container struct { + Ctx context.Context +} + +func something() func(*Container) { + return func(r *Container) { + ctx := r.Ctx + ctx = context.WithValue(ctx, "key", "val") + r.Ctx = ctx // want "potential nested context in struct pointer" + } +} + +func blah(r *Container) { + ctx := r.Ctx + ctx = context.WithValue(ctx, "key", "val") + r.Ctx = ctx // want "potential nested context in struct pointer" +}