feat: better discriminate assignations to struct pointers

This commit is contained in:
Gabriel Augendre 2025-01-13 15:46:00 +01:00
parent 939d65bc16
commit ef9d47d1f0
7 changed files with 173 additions and 51 deletions

View file

@ -16,8 +16,6 @@ go install github.com/Crocmagnon/fatcontext/cmd/fatcontext@latest
fatcontext ./...
```
There are no specific configuration options or custom command-line flags.
## Example
```go

View file

@ -7,5 +7,5 @@ import (
)
func main() {
singlechecker.Main(analyzer.Analyzer)
singlechecker.Main(analyzer.NewAnalyzer())
}

View file

@ -3,6 +3,7 @@ package analyzer
import (
"bytes"
"errors"
"flag"
"fmt"
"go/ast"
"go/printer"
@ -14,22 +15,45 @@ import (
"golang.org/x/tools/go/ast/inspector"
)
var Analyzer = &analysis.Analyzer{
const FlagCheckStructPointers = "check-struct-pointers"
func NewAnalyzer() *analysis.Analyzer {
r := &runner{}
flags := flag.NewFlagSet("fatcontext", flag.ExitOnError)
flags.BoolVar(&r.DetectInStructPointers, FlagCheckStructPointers, false,
"set to true to detect potential fat contexts in struct pointers")
return &analysis.Analyzer{
Name: "fatcontext",
Doc: "detects nested contexts in loops and function literals",
Run: run,
Run: r.run,
Flags: *flags,
Requires: []*analysis.Analyzer{inspect.Analyzer},
}
}
var errUnknown = errors.New("unknown node type")
func run(pass *analysis.Pass) (interface{}, error) {
const (
categoryInLoop = "nested context in loop"
categoryInFuncLit = "nested context in function literal"
categoryInStructPointer = "potential nested context in struct pointer"
categoryUnsupported = "unsupported nested context type"
)
type runner struct {
DetectInStructPointers bool
}
func (r *runner) run(pass *analysis.Pass) (interface{}, error) {
inspctr := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
nodeFilter := []ast.Node{
(*ast.ForStmt)(nil),
(*ast.RangeStmt)(nil),
(*ast.FuncLit)(nil),
(*ast.FuncDecl)(nil),
}
inspctr.Preorder(nodeFilter, func(node ast.Node) {
@ -43,6 +67,34 @@ func run(pass *analysis.Pass) (interface{}, error) {
return
}
category := getCategory(pass, node, assignStmt)
if r.shouldIgnoreReport(category) {
return
}
fixes := r.getSuggestedFixes(pass, assignStmt, category)
pass.Report(analysis.Diagnostic{
Pos: assignStmt.Pos(),
Message: category,
SuggestedFixes: fixes,
})
})
return nil, nil
}
func (r *runner) shouldIgnoreReport(category string) bool {
return category == categoryInStructPointer && !r.DetectInStructPointers
}
func (r *runner) getSuggestedFixes(pass *analysis.Pass, assignStmt *ast.AssignStmt, category string) []analysis.SuggestedFix {
switch category {
case categoryInStructPointer, categoryUnsupported:
return nil
}
suggestedStmt := ast.AssignStmt{
Lhs: assignStmt.Lhs,
TokPos: assignStmt.TokPos,
@ -65,41 +117,37 @@ func run(pass *analysis.Pass) (interface{}, error) {
})
}
pass.Report(analysis.Diagnostic{
Pos: assignStmt.Pos(),
Message: getReportMessage(node),
SuggestedFixes: fixes,
})
})
return nil, nil
return fixes
}
func getReportMessage(node ast.Node) string {
func getCategory(pass *analysis.Pass, node ast.Node, assignStmt *ast.AssignStmt) string {
switch node.(type) {
case *ast.ForStmt, *ast.RangeStmt:
return "nested context in loop"
case *ast.FuncLit:
return "nested context in function literal"
return categoryInLoop
}
if isPointer(pass, assignStmt.Lhs[0]) {
return categoryInStructPointer
}
switch node.(type) {
case *ast.FuncLit, *ast.FuncDecl:
return categoryInFuncLit
default:
return "unsupported nested context type"
return categoryUnsupported
}
}
func getBody(node ast.Node) (*ast.BlockStmt, error) {
forStmt, ok := node.(*ast.ForStmt)
if ok {
return forStmt.Body, nil
}
rangeStmt, ok := node.(*ast.RangeStmt)
if ok {
return rangeStmt.Body, nil
}
funcLit, ok := node.(*ast.FuncLit)
if ok {
return funcLit.Body, nil
switch typedNode := node.(type) {
case *ast.ForStmt:
return typedNode.Body, nil
case *ast.RangeStmt:
return typedNode.Body, nil
case *ast.FuncLit:
return typedNode.Body, nil
case *ast.FuncDecl:
return typedNode.Body, nil
}
return nil, errUnknown
@ -174,6 +222,10 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
continue
}
if isPointer(pass, assignStmt.Lhs[0]) {
return assignStmt
}
// allow assignment to non-pointer children of values defined within the loop
if isWithinLoop(assignStmt.Lhs[0], node, pass) {
continue
@ -249,3 +301,13 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
}
}
}
func isPointer(pass *analysis.Pass, exp ast.Node) bool {
switch n := exp.(type) {
case *ast.SelectorExpr:
sel, ok := pass.TypesInfo.Selections[n]
return ok && sel.Indirect()
}
return false
}

View file

@ -10,12 +10,28 @@ import (
"github.com/Crocmagnon/fatcontext/pkg/analyzer"
)
func TestAll(t *testing.T) {
func TestAnalyzer(t *testing.T) {
wd, err := os.Getwd()
if err != nil {
t.Fatalf("Failed to get wd: %s", err)
}
testdata := filepath.Join(filepath.Dir(filepath.Dir(wd)), "testdata")
testdata := filepath.Join(wd, "testdata")
analysistest.Run(t, testdata, analyzer.Analyzer, "./...")
t.Run("no func decl", func(t *testing.T) {
an := analyzer.NewAnalyzer()
analysistest.Run(t, testdata, an, "./common")
analysistest.Run(t, testdata, an, "./no_structpointer")
})
t.Run("func decl", func(t *testing.T) {
an := analyzer.NewAnalyzer()
err := an.Flags.Set(analyzer.FlagCheckStructPointers, "true")
if err != nil {
t.Fatal(err)
}
analysistest.Run(t, testdata, an, "./common")
analysistest.Run(t, testdata, an, "./structpointer")
})
}

View file

@ -1,4 +1,4 @@
package src
package common
import (
"context"

View file

@ -0,0 +1,23 @@
package common
import (
"context"
)
type Container struct {
Ctx context.Context
}
func something() func(*Container) {
return func(r *Container) {
ctx := r.Ctx
ctx = context.WithValue(ctx, "key", "val")
r.Ctx = ctx
}
}
func blah(r *Container) {
ctx := r.Ctx
ctx = context.WithValue(ctx, "key", "val")
r.Ctx = ctx
}

View file

@ -0,0 +1,23 @@
package common
import (
"context"
)
type Container struct {
Ctx context.Context
}
func something() func(*Container) {
return func(r *Container) {
ctx := r.Ctx
ctx = context.WithValue(ctx, "key", "val")
r.Ctx = ctx // want "potential nested context in struct pointer"
}
}
func blah(r *Container) {
ctx := r.Ctx
ctx = context.WithValue(ctx, "key", "val")
r.Ctx = ctx // want "potential nested context in struct pointer"
}