mirror of
https://github.com/Crocmagnon/fatcontext.git
synced 2025-02-05 04:02:31 +01:00
feat: better discriminate assignations to struct pointers
This commit is contained in:
parent
939d65bc16
commit
ef9d47d1f0
7 changed files with 173 additions and 51 deletions
|
@ -16,8 +16,6 @@ go install github.com/Crocmagnon/fatcontext/cmd/fatcontext@latest
|
||||||
fatcontext ./...
|
fatcontext ./...
|
||||||
```
|
```
|
||||||
|
|
||||||
There are no specific configuration options or custom command-line flags.
|
|
||||||
|
|
||||||
## Example
|
## Example
|
||||||
|
|
||||||
```go
|
```go
|
||||||
|
|
|
@ -7,5 +7,5 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
singlechecker.Main(analyzer.Analyzer)
|
singlechecker.Main(analyzer.NewAnalyzer())
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package analyzer
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"go/ast"
|
"go/ast"
|
||||||
"go/printer"
|
"go/printer"
|
||||||
|
@ -14,22 +15,45 @@ import (
|
||||||
"golang.org/x/tools/go/ast/inspector"
|
"golang.org/x/tools/go/ast/inspector"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Analyzer = &analysis.Analyzer{
|
const FlagCheckStructPointers = "check-struct-pointers"
|
||||||
Name: "fatcontext",
|
|
||||||
Doc: "detects nested contexts in loops and function literals",
|
func NewAnalyzer() *analysis.Analyzer {
|
||||||
Run: run,
|
r := &runner{}
|
||||||
Requires: []*analysis.Analyzer{inspect.Analyzer},
|
|
||||||
|
flags := flag.NewFlagSet("fatcontext", flag.ExitOnError)
|
||||||
|
flags.BoolVar(&r.DetectInStructPointers, FlagCheckStructPointers, false,
|
||||||
|
"set to true to detect potential fat contexts in struct pointers")
|
||||||
|
|
||||||
|
return &analysis.Analyzer{
|
||||||
|
Name: "fatcontext",
|
||||||
|
Doc: "detects nested contexts in loops and function literals",
|
||||||
|
Run: r.run,
|
||||||
|
Flags: *flags,
|
||||||
|
Requires: []*analysis.Analyzer{inspect.Analyzer},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var errUnknown = errors.New("unknown node type")
|
var errUnknown = errors.New("unknown node type")
|
||||||
|
|
||||||
func run(pass *analysis.Pass) (interface{}, error) {
|
const (
|
||||||
|
categoryInLoop = "nested context in loop"
|
||||||
|
categoryInFuncLit = "nested context in function literal"
|
||||||
|
categoryInStructPointer = "potential nested context in struct pointer"
|
||||||
|
categoryUnsupported = "unsupported nested context type"
|
||||||
|
)
|
||||||
|
|
||||||
|
type runner struct {
|
||||||
|
DetectInStructPointers bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *runner) run(pass *analysis.Pass) (interface{}, error) {
|
||||||
inspctr := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
|
inspctr := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
|
||||||
|
|
||||||
nodeFilter := []ast.Node{
|
nodeFilter := []ast.Node{
|
||||||
(*ast.ForStmt)(nil),
|
(*ast.ForStmt)(nil),
|
||||||
(*ast.RangeStmt)(nil),
|
(*ast.RangeStmt)(nil),
|
||||||
(*ast.FuncLit)(nil),
|
(*ast.FuncLit)(nil),
|
||||||
|
(*ast.FuncDecl)(nil),
|
||||||
}
|
}
|
||||||
|
|
||||||
inspctr.Preorder(nodeFilter, func(node ast.Node) {
|
inspctr.Preorder(nodeFilter, func(node ast.Node) {
|
||||||
|
@ -43,31 +67,17 @@ func run(pass *analysis.Pass) (interface{}, error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
suggestedStmt := ast.AssignStmt{
|
category := getCategory(pass, node, assignStmt)
|
||||||
Lhs: assignStmt.Lhs,
|
|
||||||
TokPos: assignStmt.TokPos,
|
|
||||||
Tok: token.DEFINE,
|
|
||||||
Rhs: assignStmt.Rhs,
|
|
||||||
}
|
|
||||||
suggested, err := render(pass.Fset, &suggestedStmt)
|
|
||||||
|
|
||||||
var fixes []analysis.SuggestedFix
|
if r.shouldIgnoreReport(category) {
|
||||||
if err == nil {
|
return
|
||||||
fixes = append(fixes, analysis.SuggestedFix{
|
|
||||||
Message: "replace `=` with `:=`",
|
|
||||||
TextEdits: []analysis.TextEdit{
|
|
||||||
{
|
|
||||||
Pos: assignStmt.Pos(),
|
|
||||||
End: assignStmt.End(),
|
|
||||||
NewText: suggested,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fixes := r.getSuggestedFixes(pass, assignStmt, category)
|
||||||
|
|
||||||
pass.Report(analysis.Diagnostic{
|
pass.Report(analysis.Diagnostic{
|
||||||
Pos: assignStmt.Pos(),
|
Pos: assignStmt.Pos(),
|
||||||
Message: getReportMessage(node),
|
Message: category,
|
||||||
SuggestedFixes: fixes,
|
SuggestedFixes: fixes,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -75,31 +85,69 @@ func run(pass *analysis.Pass) (interface{}, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getReportMessage(node ast.Node) string {
|
func (r *runner) shouldIgnoreReport(category string) bool {
|
||||||
|
return category == categoryInStructPointer && !r.DetectInStructPointers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *runner) getSuggestedFixes(pass *analysis.Pass, assignStmt *ast.AssignStmt, category string) []analysis.SuggestedFix {
|
||||||
|
switch category {
|
||||||
|
case categoryInStructPointer, categoryUnsupported:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
suggestedStmt := ast.AssignStmt{
|
||||||
|
Lhs: assignStmt.Lhs,
|
||||||
|
TokPos: assignStmt.TokPos,
|
||||||
|
Tok: token.DEFINE,
|
||||||
|
Rhs: assignStmt.Rhs,
|
||||||
|
}
|
||||||
|
suggested, err := render(pass.Fset, &suggestedStmt)
|
||||||
|
|
||||||
|
var fixes []analysis.SuggestedFix
|
||||||
|
if err == nil {
|
||||||
|
fixes = append(fixes, analysis.SuggestedFix{
|
||||||
|
Message: "replace `=` with `:=`",
|
||||||
|
TextEdits: []analysis.TextEdit{
|
||||||
|
{
|
||||||
|
Pos: assignStmt.Pos(),
|
||||||
|
End: assignStmt.End(),
|
||||||
|
NewText: suggested,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return fixes
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCategory(pass *analysis.Pass, node ast.Node, assignStmt *ast.AssignStmt) string {
|
||||||
switch node.(type) {
|
switch node.(type) {
|
||||||
case *ast.ForStmt, *ast.RangeStmt:
|
case *ast.ForStmt, *ast.RangeStmt:
|
||||||
return "nested context in loop"
|
return categoryInLoop
|
||||||
case *ast.FuncLit:
|
}
|
||||||
return "nested context in function literal"
|
|
||||||
|
if isPointer(pass, assignStmt.Lhs[0]) {
|
||||||
|
return categoryInStructPointer
|
||||||
|
}
|
||||||
|
|
||||||
|
switch node.(type) {
|
||||||
|
case *ast.FuncLit, *ast.FuncDecl:
|
||||||
|
return categoryInFuncLit
|
||||||
default:
|
default:
|
||||||
return "unsupported nested context type"
|
return categoryUnsupported
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getBody(node ast.Node) (*ast.BlockStmt, error) {
|
func getBody(node ast.Node) (*ast.BlockStmt, error) {
|
||||||
forStmt, ok := node.(*ast.ForStmt)
|
switch typedNode := node.(type) {
|
||||||
if ok {
|
case *ast.ForStmt:
|
||||||
return forStmt.Body, nil
|
return typedNode.Body, nil
|
||||||
}
|
case *ast.RangeStmt:
|
||||||
|
return typedNode.Body, nil
|
||||||
rangeStmt, ok := node.(*ast.RangeStmt)
|
case *ast.FuncLit:
|
||||||
if ok {
|
return typedNode.Body, nil
|
||||||
return rangeStmt.Body, nil
|
case *ast.FuncDecl:
|
||||||
}
|
return typedNode.Body, nil
|
||||||
|
|
||||||
funcLit, ok := node.(*ast.FuncLit)
|
|
||||||
if ok {
|
|
||||||
return funcLit.Body, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errUnknown
|
return nil, errUnknown
|
||||||
|
@ -174,6 +222,10 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isPointer(pass, assignStmt.Lhs[0]) {
|
||||||
|
return assignStmt
|
||||||
|
}
|
||||||
|
|
||||||
// allow assignment to non-pointer children of values defined within the loop
|
// allow assignment to non-pointer children of values defined within the loop
|
||||||
if isWithinLoop(assignStmt.Lhs[0], node, pass) {
|
if isWithinLoop(assignStmt.Lhs[0], node, pass) {
|
||||||
continue
|
continue
|
||||||
|
@ -249,3 +301,13 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isPointer(pass *analysis.Pass, exp ast.Node) bool {
|
||||||
|
switch n := exp.(type) {
|
||||||
|
case *ast.SelectorExpr:
|
||||||
|
sel, ok := pass.TypesInfo.Selections[n]
|
||||||
|
return ok && sel.Indirect()
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -10,12 +10,28 @@ import (
|
||||||
"github.com/Crocmagnon/fatcontext/pkg/analyzer"
|
"github.com/Crocmagnon/fatcontext/pkg/analyzer"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAll(t *testing.T) {
|
func TestAnalyzer(t *testing.T) {
|
||||||
wd, err := os.Getwd()
|
wd, err := os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to get wd: %s", err)
|
t.Fatalf("Failed to get wd: %s", err)
|
||||||
}
|
}
|
||||||
testdata := filepath.Join(filepath.Dir(filepath.Dir(wd)), "testdata")
|
testdata := filepath.Join(wd, "testdata")
|
||||||
|
|
||||||
analysistest.Run(t, testdata, analyzer.Analyzer, "./...")
|
t.Run("no func decl", func(t *testing.T) {
|
||||||
|
an := analyzer.NewAnalyzer()
|
||||||
|
analysistest.Run(t, testdata, an, "./common")
|
||||||
|
analysistest.Run(t, testdata, an, "./no_structpointer")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("func decl", func(t *testing.T) {
|
||||||
|
an := analyzer.NewAnalyzer()
|
||||||
|
|
||||||
|
err := an.Flags.Set(analyzer.FlagCheckStructPointers, "true")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
analysistest.Run(t, testdata, an, "./common")
|
||||||
|
analysistest.Run(t, testdata, an, "./structpointer")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package src
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
23
pkg/analyzer/testdata/no_structpointer/example.go
vendored
Normal file
23
pkg/analyzer/testdata/no_structpointer/example.go
vendored
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Container struct {
|
||||||
|
Ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func something() func(*Container) {
|
||||||
|
return func(r *Container) {
|
||||||
|
ctx := r.Ctx
|
||||||
|
ctx = context.WithValue(ctx, "key", "val")
|
||||||
|
r.Ctx = ctx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func blah(r *Container) {
|
||||||
|
ctx := r.Ctx
|
||||||
|
ctx = context.WithValue(ctx, "key", "val")
|
||||||
|
r.Ctx = ctx
|
||||||
|
}
|
23
pkg/analyzer/testdata/structpointer/example.go
vendored
Normal file
23
pkg/analyzer/testdata/structpointer/example.go
vendored
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Container struct {
|
||||||
|
Ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func something() func(*Container) {
|
||||||
|
return func(r *Container) {
|
||||||
|
ctx := r.Ctx
|
||||||
|
ctx = context.WithValue(ctx, "key", "val")
|
||||||
|
r.Ctx = ctx // want "potential nested context in struct pointer"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func blah(r *Container) {
|
||||||
|
ctx := r.Ctx
|
||||||
|
ctx = context.WithValue(ctx, "key", "val")
|
||||||
|
r.Ctx = ctx // want "potential nested context in struct pointer"
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue