fatcontext/pkg/analyzer/analyzer.go

196 lines
4.1 KiB
Go
Raw Permalink Normal View History

2024-03-27 19:24:38 +01:00
package analyzer
import (
"bytes"
"errors"
2024-03-28 00:08:19 +01:00
"fmt"
2024-03-27 19:24:38 +01:00
"go/ast"
"go/printer"
"go/token"
2024-03-27 19:24:38 +01:00
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
)
var Analyzer = &analysis.Analyzer{
Name: "fatcontext",
Doc: "detects nested contexts in loops",
2024-03-27 19:24:38 +01:00
Run: run,
Requires: []*analysis.Analyzer{inspect.Analyzer},
}
var errUnknown = errors.New("unknown node type")
func run(pass *analysis.Pass) (interface{}, error) {
inspctr := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
nodeFilter := []ast.Node{
(*ast.ForStmt)(nil),
(*ast.RangeStmt)(nil),
}
inspctr.Preorder(nodeFilter, func(node ast.Node) {
body, err := getBody(node)
if err != nil {
return
}
2024-07-12 12:04:01 +02:00
assignStmt := findNestedContext(pass, body, body.List)
if assignStmt == nil {
return
}
2024-07-12 12:04:01 +02:00
suggestedStmt := ast.AssignStmt{
Lhs: assignStmt.Lhs,
TokPos: assignStmt.TokPos,
Tok: token.DEFINE,
Rhs: assignStmt.Rhs,
}
suggested, err := render(pass.Fset, &suggestedStmt)
var fixes []analysis.SuggestedFix
if err == nil {
fixes = append(fixes, analysis.SuggestedFix{
Message: "replace `=` with `:=`",
TextEdits: []analysis.TextEdit{
{
Pos: assignStmt.Pos(),
End: assignStmt.End(),
NewText: []byte(suggested),
2024-03-27 19:24:38 +01:00
},
2024-07-12 12:04:01 +02:00
},
2024-03-27 19:24:38 +01:00
})
}
2024-07-12 12:04:01 +02:00
pass.Report(analysis.Diagnostic{
Pos: assignStmt.Pos(),
Message: "nested context in loop",
SuggestedFixes: fixes,
})
2024-03-27 19:24:38 +01:00
})
return nil, nil
}
// getBody extracts the body block of a for or range loop. Any other node
// kind yields errUnknown and a nil block.
func getBody(node ast.Node) (*ast.BlockStmt, error) {
	switch loop := node.(type) {
	case *ast.ForStmt:
		return loop.Body, nil
	case *ast.RangeStmt:
		return loop.Body, nil
	default:
		return nil, errUnknown
	}
}
2024-07-12 12:04:01 +02:00
func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.Stmt) *ast.AssignStmt {
for _, stmt := range stmts {
// Recurse if necessary
if inner, ok := stmt.(*ast.BlockStmt); ok {
found := findNestedContext(pass, inner, inner.List)
if found != nil {
return found
}
}
if inner, ok := stmt.(*ast.IfStmt); ok {
found := findNestedContext(pass, inner.Body, inner.Body.List)
if found != nil {
return found
}
}
if inner, ok := stmt.(*ast.SwitchStmt); ok {
found := findNestedContext(pass, inner.Body, inner.Body.List)
if found != nil {
return found
}
}
if inner, ok := stmt.(*ast.CaseClause); ok {
found := findNestedContext(pass, block, inner.Body)
if found != nil {
return found
}
}
if inner, ok := stmt.(*ast.SelectStmt); ok {
found := findNestedContext(pass, inner.Body, inner.Body.List)
if found != nil {
return found
}
}
if inner, ok := stmt.(*ast.CommClause); ok {
found := findNestedContext(pass, block, inner.Body)
if found != nil {
return found
}
}
// Actually check for nested context
assignStmt, ok := stmt.(*ast.AssignStmt)
if !ok {
continue
}
t := pass.TypesInfo.TypeOf(assignStmt.Lhs[0])
if t == nil {
continue
}
if t.String() != "context.Context" {
continue
}
if assignStmt.Tok == token.DEFINE {
break
}
// allow assignment to non-pointer children of values defined within the loop
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
if obj.Pos() >= block.Pos() && obj.Pos() < block.End() {
continue // definition is within the loop
}
}
}
return assignStmt
}
return nil
}
// getRootIdent walks an assignment target through index expressions and
// field selections down to its base identifier. It returns nil when a field
// selection goes through a pointer, as recorded in pass.TypesInfo.Selections.
// It also returns nil when the target has any other shape (e.g. `*p`, a call
// result, or a parenthesized expression).
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
	switch expr := node.(type) {
	case *ast.Ident:
		return expr
	case *ast.IndexExpr:
		return getRootIdent(pass, expr.X)
	case *ast.SelectorExpr:
		// Reaching the field through a pointer means the root is shared
		// rather than copied, so a loop-local root proves nothing.
		if selection, found := pass.TypesInfo.Selections[expr]; found && selection.Indirect() {
			return nil
		}
		return getRootIdent(pass, expr.X)
	}
	return nil
}
// render pretty-prints x (any value accepted by go/printer.Fprint, typically
// an AST node) and returns the resulting source text. fset supplies position
// information. Printer failures are wrapped with a "printing node" prefix.
func render(fset *token.FileSet, x interface{}) (string, error) {
	out := new(bytes.Buffer)
	err := printer.Fprint(out, fset, x)
	if err != nil {
		return "", fmt.Errorf("printing node: %w", err)
	}
	return out.String(), nil
}