fatcontext/pkg/analyzer/analyzer.go

303 lines
6.5 KiB
Go
Raw Permalink Normal View History

2024-03-27 19:24:38 +01:00
package analyzer
import (
"bytes"
"errors"
"flag"
2024-03-28 00:08:19 +01:00
"fmt"
2024-03-27 19:24:38 +01:00
"go/ast"
"go/printer"
"go/token"
"slices"
2024-03-27 19:24:38 +01:00
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
)
// FlagCheckStructPointers is the name of the boolean analyzer flag registered
// by NewAnalyzer that turns on reporting of potential nested contexts stored
// through struct pointers.
const FlagCheckStructPointers = "check-struct-pointers"
func NewAnalyzer() *analysis.Analyzer {
r := &runner{}
flags := flag.NewFlagSet("fatcontext", flag.ExitOnError)
flags.BoolVar(&r.DetectInStructPointers, FlagCheckStructPointers, false,
"set to true to detect potential fat contexts in struct pointers")
return &analysis.Analyzer{
Name: "fatcontext",
Doc: "detects nested contexts in loops and function literals",
Run: r.run,
Flags: *flags,
Requires: []*analysis.Analyzer{inspect.Analyzer},
}
2024-03-27 19:24:38 +01:00
}
// errUnknown is returned by getBody when it is handed a node type it does not
// know how to extract a body from.
var errUnknown = errors.New("unknown node type")
// Diagnostic categories computed by getCategory. The chosen category is also
// used verbatim as the message of the reported diagnostic.
const (
// categoryInLoop: the enclosing node is a for or range statement.
categoryInLoop = "nested context in loop"
// categoryInFuncLit: the enclosing node is a function literal or a function
// declaration and the assignment target is not reached through a pointer.
categoryInFuncLit = "nested context in function literal"
// categoryInStructPointer: outside of loops, the assignment target is a
// selector that goes through a pointer indirection. Reports of this category
// are dropped unless runner.DetectInStructPointers is set.
categoryInStructPointer = "potential nested context in struct pointer"
// categoryUnsupported: fallback for any other kind of enclosing node.
categoryUnsupported = "unsupported nested context type"
)
// runner holds the analyzer configuration and implements its Run function.
type runner struct {
// DetectInStructPointers is bound to the FlagCheckStructPointers flag; when
// false, diagnostics of category categoryInStructPointer are suppressed.
DetectInStructPointers bool
}
func (r *runner) run(pass *analysis.Pass) (interface{}, error) {
2024-03-27 19:24:38 +01:00
inspctr := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
nodeFilter := []ast.Node{
(*ast.ForStmt)(nil),
(*ast.RangeStmt)(nil),
(*ast.FuncLit)(nil),
(*ast.FuncDecl)(nil),
2024-03-27 19:24:38 +01:00
}
inspctr.Preorder(nodeFilter, func(node ast.Node) {
body, err := getBody(node)
if err != nil {
return
}
2025-01-17 00:37:55 +01:00
if body == nil {
return
}
assignStmt := findNestedContext(pass, node, body.List)
2024-07-12 12:04:01 +02:00
if assignStmt == nil {
return
}
category := getCategory(pass, node, assignStmt)
if r.shouldIgnoreReport(category) {
return
2024-03-27 19:24:38 +01:00
}
2024-07-12 12:04:01 +02:00
fixes := r.getSuggestedFixes(pass, assignStmt, category)
2024-07-12 12:04:01 +02:00
pass.Report(analysis.Diagnostic{
Pos: assignStmt.Pos(),
Message: category,
2024-07-12 12:04:01 +02:00
SuggestedFixes: fixes,
})
2024-03-27 19:24:38 +01:00
})
return nil, nil
}
// shouldIgnoreReport reports whether a diagnostic of the given category must
// be dropped. Only struct-pointer diagnostics can be dropped, and only while
// DetectInStructPointers is disabled.
func (r *runner) shouldIgnoreReport(category string) bool {
	if category != categoryInStructPointer {
		return false
	}

	return !r.DetectInStructPointers
}
// getSuggestedFixes returns the automatic fixes to attach to a diagnostic for
// assignStmt, or nil when no safe fix exists.
//
// The only fix offered is turning the assignment into a short variable
// declaration (`=` becomes `:=`), which shadows the outer context instead of
// overwriting it. No fix is offered:
//   - for the struct-pointer and unsupported categories;
//   - when any left-hand operand is not a plain identifier (for instance
//     `s.ctx = ...` or `ctxs[i] = ...`), because `:=` only accepts
//     identifiers on its left-hand side and the rewritten code would not
//     compile;
//   - when the rewritten statement cannot be rendered.
func (r *runner) getSuggestedFixes(pass *analysis.Pass, assignStmt *ast.AssignStmt, category string) []analysis.SuggestedFix {
	switch category {
	case categoryInStructPointer, categoryUnsupported:
		return nil
	}

	// A short variable declaration requires an identifier list on the left.
	for _, lhs := range assignStmt.Lhs {
		if _, ok := lhs.(*ast.Ident); !ok {
			return nil
		}
	}

	suggestedStmt := ast.AssignStmt{
		Lhs:    assignStmt.Lhs,
		TokPos: assignStmt.TokPos,
		Tok:    token.DEFINE,
		Rhs:    assignStmt.Rhs,
	}

	suggested, err := render(pass.Fset, &suggestedStmt)
	if err != nil {
		// The fix is best-effort: the diagnostic is still reported without it.
		return nil
	}

	return []analysis.SuggestedFix{
		{
			Message: "replace `=` with `:=`",
			TextEdits: []analysis.TextEdit{
				{
					Pos:     assignStmt.Pos(),
					End:     assignStmt.End(),
					NewText: suggested,
				},
			},
		},
	}
}
func getCategory(pass *analysis.Pass, node ast.Node, assignStmt *ast.AssignStmt) string {
switch node.(type) {
case *ast.ForStmt, *ast.RangeStmt:
return categoryInLoop
}
if isPointer(pass, assignStmt.Lhs[0]) {
return categoryInStructPointer
2024-03-27 19:24:38 +01:00
}
switch node.(type) {
case *ast.FuncLit, *ast.FuncDecl:
return categoryInFuncLit
default:
return categoryUnsupported
2024-03-27 19:24:38 +01:00
}
}
2024-03-27 19:24:38 +01:00
func getBody(node ast.Node) (*ast.BlockStmt, error) {
switch typedNode := node.(type) {
case *ast.ForStmt:
return typedNode.Body, nil
case *ast.RangeStmt:
return typedNode.Body, nil
case *ast.FuncLit:
return typedNode.Body, nil
case *ast.FuncDecl:
return typedNode.Body, nil
}
2024-03-27 19:24:38 +01:00
return nil, errUnknown
}
func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt {
2024-07-12 12:04:01 +02:00
for _, stmt := range stmts {
// Recurse if necessary
2025-01-13 23:08:04 +01:00
switch typedStmt := stmt.(type) {
case *ast.BlockStmt:
if found := findNestedContext(pass, node, typedStmt.List); found != nil {
2024-07-12 12:04:01 +02:00
return found
}
2025-01-13 23:08:04 +01:00
case *ast.IfStmt:
if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
2024-07-12 12:04:01 +02:00
return found
}
2025-01-13 23:08:04 +01:00
case *ast.SwitchStmt:
if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
2024-07-12 12:04:01 +02:00
return found
}
2025-01-13 23:08:04 +01:00
case *ast.CaseClause:
if found := findNestedContext(pass, node, typedStmt.Body); found != nil {
2024-07-12 12:04:01 +02:00
return found
}
2025-01-13 23:08:04 +01:00
case *ast.SelectStmt:
if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
2024-07-12 12:04:01 +02:00
return found
}
2025-01-13 23:08:04 +01:00
case *ast.CommClause:
if found := findNestedContext(pass, node, typedStmt.Body); found != nil {
2024-07-12 12:04:01 +02:00
return found
}
}
// Actually check for nested context
assignStmt, ok := stmt.(*ast.AssignStmt)
if !ok {
continue
}
t := pass.TypesInfo.TypeOf(assignStmt.Lhs[0])
if t == nil {
continue
}
if t.String() != "context.Context" {
continue
}
if assignStmt.Tok == token.DEFINE {
continue
2024-07-12 12:04:01 +02:00
}
// Ignore [context.Background] & [context.TODO].
if isContextFunction(assignStmt.Rhs[0], "Background", "TODO") {
continue
}
if isPointer(pass, assignStmt.Lhs[0]) {
return assignStmt
}
2024-07-12 12:04:01 +02:00
// allow assignment to non-pointer children of values defined within the loop
if isWithinLoop(assignStmt.Lhs[0], node, pass) {
continue
2024-07-12 12:04:01 +02:00
}
return assignStmt
}
return nil
}
// render pretty-prints the AST node x using positions from fset and returns
// the resulting source text. When printing fails, the returned slice is nil
// and the printer's error is wrapped.
func render(fset *token.FileSet, x interface{}) ([]byte, error) {
	out := new(bytes.Buffer)

	err := printer.Fprint(out, fset, x)
	if err != nil {
		return nil, fmt.Errorf("printing node: %w", err)
	}

	return out.Bytes(), nil
}
// isContextFunction reports whether exp is a call of the form context.<fn>(...)
// where <fn> is one of fnName.
//
// The match is purely syntactic: the call target must be a selector whose
// left operand is an identifier spelled "context".
func isContextFunction(exp ast.Expr, fnName ...string) bool {
	callExpr, isCall := exp.(*ast.CallExpr)
	if !isCall {
		return false
	}

	if target, isSelector := callExpr.Fun.(*ast.SelectorExpr); isSelector {
		if pkg, isIdent := target.X.(*ast.Ident); isIdent && pkg.Name == "context" {
			return slices.Contains(fnName, target.Sel.Name)
		}
	}

	return false
}
// isWithinLoop reports whether the variable at the root of exp is declared in
// a scope that lies entirely inside node.
//
// It returns false whenever the root identifier, its object or the object's
// parent scope cannot be determined.
func isWithinLoop(exp ast.Expr, node ast.Node, pass *analysis.Pass) bool {
	root := getRootIdent(pass, exp)
	if root == nil {
		return false
	}

	declared := pass.TypesInfo.ObjectOf(root)
	if declared == nil {
		return false
	}

	if enclosing := declared.Parent(); enclosing != nil {
		startsInside := enclosing.Pos() >= node.Pos()
		endsInside := enclosing.End() <= node.End()

		return startsInside && endsInside
	}

	return false
}
// getRootIdent walks down the left operand of index and selector expressions
// until it reaches the identifier the expression is rooted in.
//
// It returns nil when a selector along the way goes through a pointer
// indirection, or when any other kind of node is encountered.
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
	switch n := node.(type) {
	case *ast.Ident:
		return n
	case *ast.IndexExpr:
		return getRootIdent(pass, n.X)
	case *ast.SelectorExpr:
		selection, known := pass.TypesInfo.Selections[n]
		if known && selection.Indirect() {
			// indirected (pointer) roots don't imply a (safe) copy
			return nil
		}

		return getRootIdent(pass, n.X)
	}

	return nil
}
// isPointer reports whether exp is a selector expression recorded in the
// pass's type information as a selection that goes through a pointer
// indirection. Every other kind of node yields false.
func isPointer(pass *analysis.Pass, exp ast.Node) bool {
	selector, isSelector := exp.(*ast.SelectorExpr)
	if !isSelector {
		return false
	}

	selection, known := pass.TypesInfo.Selections[selector]
	if !known {
		return false
	}

	return selection.Indirect()
}