feat: better discriminate assignations to struct pointers

This commit is contained in:
Gabriel Augendre 2025-01-13 15:46:00 +01:00
parent 939d65bc16
commit ef9d47d1f0
7 changed files with 173 additions and 51 deletions

View file

@ -16,8 +16,6 @@ go install github.com/Crocmagnon/fatcontext/cmd/fatcontext@latest
fatcontext ./... fatcontext ./...
``` ```
There are no specific configuration options or custom command-line flags.
## Example ## Example
```go ```go

View file

@ -7,5 +7,5 @@ import (
) )
func main() { func main() {
singlechecker.Main(analyzer.Analyzer) singlechecker.Main(analyzer.NewAnalyzer())
} }

View file

@ -3,6 +3,7 @@ package analyzer
import ( import (
"bytes" "bytes"
"errors" "errors"
"flag"
"fmt" "fmt"
"go/ast" "go/ast"
"go/printer" "go/printer"
@ -14,22 +15,45 @@ import (
"golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/ast/inspector"
) )
var Analyzer = &analysis.Analyzer{ const FlagCheckStructPointers = "check-struct-pointers"
Name: "fatcontext",
Doc: "detects nested contexts in loops and function literals", func NewAnalyzer() *analysis.Analyzer {
Run: run, r := &runner{}
Requires: []*analysis.Analyzer{inspect.Analyzer},
flags := flag.NewFlagSet("fatcontext", flag.ExitOnError)
flags.BoolVar(&r.DetectInStructPointers, FlagCheckStructPointers, false,
"set to true to detect potential fat contexts in struct pointers")
return &analysis.Analyzer{
Name: "fatcontext",
Doc: "detects nested contexts in loops and function literals",
Run: r.run,
Flags: *flags,
Requires: []*analysis.Analyzer{inspect.Analyzer},
}
} }
var errUnknown = errors.New("unknown node type") var errUnknown = errors.New("unknown node type")
func run(pass *analysis.Pass) (interface{}, error) { const (
categoryInLoop = "nested context in loop"
categoryInFuncLit = "nested context in function literal"
categoryInStructPointer = "potential nested context in struct pointer"
categoryUnsupported = "unsupported nested context type"
)
type runner struct {
DetectInStructPointers bool
}
func (r *runner) run(pass *analysis.Pass) (interface{}, error) {
inspctr := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) inspctr := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
nodeFilter := []ast.Node{ nodeFilter := []ast.Node{
(*ast.ForStmt)(nil), (*ast.ForStmt)(nil),
(*ast.RangeStmt)(nil), (*ast.RangeStmt)(nil),
(*ast.FuncLit)(nil), (*ast.FuncLit)(nil),
(*ast.FuncDecl)(nil),
} }
inspctr.Preorder(nodeFilter, func(node ast.Node) { inspctr.Preorder(nodeFilter, func(node ast.Node) {
@ -43,31 +67,17 @@ func run(pass *analysis.Pass) (interface{}, error) {
return return
} }
suggestedStmt := ast.AssignStmt{ category := getCategory(pass, node, assignStmt)
Lhs: assignStmt.Lhs,
TokPos: assignStmt.TokPos,
Tok: token.DEFINE,
Rhs: assignStmt.Rhs,
}
suggested, err := render(pass.Fset, &suggestedStmt)
var fixes []analysis.SuggestedFix if r.shouldIgnoreReport(category) {
if err == nil { return
fixes = append(fixes, analysis.SuggestedFix{
Message: "replace `=` with `:=`",
TextEdits: []analysis.TextEdit{
{
Pos: assignStmt.Pos(),
End: assignStmt.End(),
NewText: suggested,
},
},
})
} }
fixes := r.getSuggestedFixes(pass, assignStmt, category)
pass.Report(analysis.Diagnostic{ pass.Report(analysis.Diagnostic{
Pos: assignStmt.Pos(), Pos: assignStmt.Pos(),
Message: getReportMessage(node), Message: category,
SuggestedFixes: fixes, SuggestedFixes: fixes,
}) })
}) })
@ -75,31 +85,69 @@ func run(pass *analysis.Pass) (interface{}, error) {
return nil, nil return nil, nil
} }
func getReportMessage(node ast.Node) string { func (r *runner) shouldIgnoreReport(category string) bool {
return category == categoryInStructPointer && !r.DetectInStructPointers
}
func (r *runner) getSuggestedFixes(pass *analysis.Pass, assignStmt *ast.AssignStmt, category string) []analysis.SuggestedFix {
switch category {
case categoryInStructPointer, categoryUnsupported:
return nil
}
suggestedStmt := ast.AssignStmt{
Lhs: assignStmt.Lhs,
TokPos: assignStmt.TokPos,
Tok: token.DEFINE,
Rhs: assignStmt.Rhs,
}
suggested, err := render(pass.Fset, &suggestedStmt)
var fixes []analysis.SuggestedFix
if err == nil {
fixes = append(fixes, analysis.SuggestedFix{
Message: "replace `=` with `:=`",
TextEdits: []analysis.TextEdit{
{
Pos: assignStmt.Pos(),
End: assignStmt.End(),
NewText: suggested,
},
},
})
}
return fixes
}
func getCategory(pass *analysis.Pass, node ast.Node, assignStmt *ast.AssignStmt) string {
switch node.(type) { switch node.(type) {
case *ast.ForStmt, *ast.RangeStmt: case *ast.ForStmt, *ast.RangeStmt:
return "nested context in loop" return categoryInLoop
case *ast.FuncLit: }
return "nested context in function literal"
if isPointer(pass, assignStmt.Lhs[0]) {
return categoryInStructPointer
}
switch node.(type) {
case *ast.FuncLit, *ast.FuncDecl:
return categoryInFuncLit
default: default:
return "unsupported nested context type" return categoryUnsupported
} }
} }
func getBody(node ast.Node) (*ast.BlockStmt, error) { func getBody(node ast.Node) (*ast.BlockStmt, error) {
forStmt, ok := node.(*ast.ForStmt) switch typedNode := node.(type) {
if ok { case *ast.ForStmt:
return forStmt.Body, nil return typedNode.Body, nil
} case *ast.RangeStmt:
return typedNode.Body, nil
rangeStmt, ok := node.(*ast.RangeStmt) case *ast.FuncLit:
if ok { return typedNode.Body, nil
return rangeStmt.Body, nil case *ast.FuncDecl:
} return typedNode.Body, nil
funcLit, ok := node.(*ast.FuncLit)
if ok {
return funcLit.Body, nil
} }
return nil, errUnknown return nil, errUnknown
@ -174,6 +222,10 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
continue continue
} }
if isPointer(pass, assignStmt.Lhs[0]) {
return assignStmt
}
// allow assignment to non-pointer children of values defined within the loop // allow assignment to non-pointer children of values defined within the loop
if isWithinLoop(assignStmt.Lhs[0], node, pass) { if isWithinLoop(assignStmt.Lhs[0], node, pass) {
continue continue
@ -249,3 +301,13 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
} }
} }
} }
func isPointer(pass *analysis.Pass, exp ast.Node) bool {
switch n := exp.(type) {
case *ast.SelectorExpr:
sel, ok := pass.TypesInfo.Selections[n]
return ok && sel.Indirect()
}
return false
}

View file

@ -10,12 +10,28 @@ import (
"github.com/Crocmagnon/fatcontext/pkg/analyzer" "github.com/Crocmagnon/fatcontext/pkg/analyzer"
) )
func TestAll(t *testing.T) { func TestAnalyzer(t *testing.T) {
wd, err := os.Getwd() wd, err := os.Getwd()
if err != nil { if err != nil {
t.Fatalf("Failed to get wd: %s", err) t.Fatalf("Failed to get wd: %s", err)
} }
testdata := filepath.Join(filepath.Dir(filepath.Dir(wd)), "testdata") testdata := filepath.Join(wd, "testdata")
analysistest.Run(t, testdata, analyzer.Analyzer, "./...") t.Run("no func decl", func(t *testing.T) {
an := analyzer.NewAnalyzer()
analysistest.Run(t, testdata, an, "./common")
analysistest.Run(t, testdata, an, "./no_structpointer")
})
t.Run("func decl", func(t *testing.T) {
an := analyzer.NewAnalyzer()
err := an.Flags.Set(analyzer.FlagCheckStructPointers, "true")
if err != nil {
t.Fatal(err)
}
analysistest.Run(t, testdata, an, "./common")
analysistest.Run(t, testdata, an, "./structpointer")
})
} }

View file

@ -1,4 +1,4 @@
package src package common
import ( import (
"context" "context"

View file

@ -0,0 +1,23 @@
package common
import (
"context"
)
type Container struct {
Ctx context.Context
}
func something() func(*Container) {
return func(r *Container) {
ctx := r.Ctx
ctx = context.WithValue(ctx, "key", "val")
r.Ctx = ctx
}
}
func blah(r *Container) {
ctx := r.Ctx
ctx = context.WithValue(ctx, "key", "val")
r.Ctx = ctx
}

View file

@ -0,0 +1,23 @@
package common
import (
"context"
)
type Container struct {
Ctx context.Context
}
func something() func(*Container) {
return func(r *Container) {
ctx := r.Ctx
ctx = context.WithValue(ctx, "key", "val")
r.Ctx = ctx // want "potential nested context in struct pointer"
}
}
func blah(r *Container) {
ctx := r.Ctx
ctx = context.WithValue(ctx, "key", "val")
r.Ctx = ctx // want "potential nested context in struct pointer"
}