2024-03-27 19:24:38 +01:00
|
|
|
package analyzer
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"errors"
|
2024-03-28 00:08:19 +01:00
|
|
|
"fmt"
|
2024-03-27 19:24:38 +01:00
|
|
|
"go/ast"
|
|
|
|
"go/printer"
|
|
|
|
"go/token"
|
2024-05-29 16:47:58 +02:00
|
|
|
|
2024-03-27 19:24:38 +01:00
|
|
|
"golang.org/x/tools/go/analysis"
|
|
|
|
"golang.org/x/tools/go/analysis/passes/inspect"
|
|
|
|
"golang.org/x/tools/go/ast/inspector"
|
|
|
|
)
|
|
|
|
|
|
|
|
var Analyzer = &analysis.Analyzer{
|
2024-03-27 23:50:12 +01:00
|
|
|
Name: "fatcontext",
|
|
|
|
Doc: "detects nested contexts in loops",
|
2024-03-27 19:24:38 +01:00
|
|
|
Run: run,
|
|
|
|
Requires: []*analysis.Analyzer{inspect.Analyzer},
|
|
|
|
}
|
|
|
|
|
|
|
|
// errUnknown is returned by getBody for a node that is neither an
// *ast.ForStmt nor an *ast.RangeStmt.
var (
	errUnknown = errors.New("unknown node type")
)
|
|
|
|
|
|
|
|
func run(pass *analysis.Pass) (interface{}, error) {
|
|
|
|
inspctr := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
|
|
|
|
|
|
|
|
nodeFilter := []ast.Node{
|
|
|
|
(*ast.ForStmt)(nil),
|
|
|
|
(*ast.RangeStmt)(nil),
|
|
|
|
}
|
|
|
|
|
|
|
|
inspctr.Preorder(nodeFilter, func(node ast.Node) {
|
|
|
|
body, err := getBody(node)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, stmt := range body.List {
|
|
|
|
assignStmt, ok := stmt.(*ast.AssignStmt)
|
|
|
|
if !ok {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
t := pass.TypesInfo.TypeOf(assignStmt.Lhs[0])
|
|
|
|
if t == nil {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
if t.String() != "context.Context" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
if assignStmt.Tok == token.DEFINE {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
|
2024-05-29 17:26:10 +02:00
|
|
|
// allow assignment to non-pointer children of values defined within the loop
|
2024-05-29 16:47:58 +02:00
|
|
|
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
|
|
|
|
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
|
|
|
|
if obj.Pos() >= body.Pos() && obj.Pos() < body.End() {
|
2024-05-29 17:26:10 +02:00
|
|
|
continue // definition is within the loop
|
2024-05-29 16:47:58 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-03-27 19:57:23 +01:00
|
|
|
suggestedStmt := ast.AssignStmt{
|
|
|
|
Lhs: assignStmt.Lhs,
|
|
|
|
TokPos: assignStmt.TokPos,
|
|
|
|
Tok: token.DEFINE,
|
|
|
|
Rhs: assignStmt.Rhs,
|
|
|
|
}
|
2024-03-28 00:08:19 +01:00
|
|
|
suggested, err := render(pass.Fset, &suggestedStmt)
|
|
|
|
|
|
|
|
var fixes []analysis.SuggestedFix
|
|
|
|
if err == nil {
|
|
|
|
fixes = append(fixes, analysis.SuggestedFix{
|
|
|
|
Message: "replace `=` with `:=`",
|
|
|
|
TextEdits: []analysis.TextEdit{
|
|
|
|
{
|
|
|
|
Pos: assignStmt.Pos(),
|
|
|
|
End: assignStmt.End(),
|
|
|
|
NewText: []byte(suggested),
|
2024-03-27 19:24:38 +01:00
|
|
|
},
|
|
|
|
},
|
2024-03-28 00:08:19 +01:00
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
pass.Report(analysis.Diagnostic{
|
|
|
|
Pos: assignStmt.Pos(),
|
|
|
|
Message: "nested context in loop",
|
|
|
|
SuggestedFixes: fixes,
|
2024-03-27 19:24:38 +01:00
|
|
|
})
|
|
|
|
|
|
|
|
break
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
return nil, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func getBody(node ast.Node) (*ast.BlockStmt, error) {
|
|
|
|
forStmt, ok := node.(*ast.ForStmt)
|
|
|
|
if ok {
|
|
|
|
return forStmt.Body, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
rangeStmt, ok := node.(*ast.RangeStmt)
|
|
|
|
if ok {
|
|
|
|
return rangeStmt.Body, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil, errUnknown
|
|
|
|
}
|
|
|
|
|
2024-05-29 16:47:58 +02:00
|
|
|
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
|
|
|
|
for {
|
|
|
|
switch n := node.(type) {
|
|
|
|
case *ast.Ident:
|
|
|
|
return n
|
|
|
|
case *ast.IndexExpr:
|
|
|
|
node = n.X
|
|
|
|
case *ast.SelectorExpr:
|
|
|
|
if sel, ok := pass.TypesInfo.Selections[n]; ok && sel.Indirect() {
|
|
|
|
return nil // indirected (pointer) roots don't imply a (safe) copy
|
|
|
|
}
|
|
|
|
node = n.X
|
|
|
|
default:
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-03-27 19:24:38 +01:00
|
|
|
// render pretty-prints x with go/printer, using fset to resolve positions,
// and returns the resulting source text. A printer failure is returned
// wrapped with a "printing node" prefix.
func render(fset *token.FileSet, x interface{}) (string, error) {
	out := new(bytes.Buffer)
	err := printer.Fprint(out, fset, x)
	if err != nil {
		return "", fmt.Errorf("printing node: %w", err)
	}
	return out.String(), nil
}
|