Compare commits

..

No commits in common. "4410b65005bd5c6170573eb79dffce60f66ed5b2" and "0d2c4019d419d6287fd1f1351b1562550afaace5" have entirely different histories.

4 changed files with 15 additions and 92 deletions

View file

@ -16,7 +16,7 @@ jobs:
golangci: golangci:
strategy: strategy:
matrix: matrix:
go: ['1.22', '1.23'] go: ['1.21', '1.22']
name: lint name: lint
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:

View file

@ -1,6 +1,6 @@
# fatcontext # fatcontext
`fatcontext` is a Go linter which detects potential fat contexts in loops or function literals. `fatcontext` is a Go linter which detects potential fat contexts in loops.
They can lead to performance issues, as documented here: https://gabnotes.org/fat-contexts/ They can lead to performance issues, as documented here: https://gabnotes.org/fat-contexts/
## Installation / usage ## Installation / usage

View file

@ -7,7 +7,6 @@ import (
"go/ast" "go/ast"
"go/printer" "go/printer"
"go/token" "go/token"
"go/types"
"golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect" "golang.org/x/tools/go/analysis/passes/inspect"
@ -16,7 +15,7 @@ import (
var Analyzer = &analysis.Analyzer{ var Analyzer = &analysis.Analyzer{
Name: "fatcontext", Name: "fatcontext",
Doc: "detects nested contexts in loops and function literals", Doc: "detects nested contexts in loops",
Run: run, Run: run,
Requires: []*analysis.Analyzer{inspect.Analyzer}, Requires: []*analysis.Analyzer{inspect.Analyzer},
} }
@ -29,7 +28,6 @@ func run(pass *analysis.Pass) (interface{}, error) {
nodeFilter := []ast.Node{ nodeFilter := []ast.Node{
(*ast.ForStmt)(nil), (*ast.ForStmt)(nil),
(*ast.RangeStmt)(nil), (*ast.RangeStmt)(nil),
(*ast.FuncLit)(nil),
} }
inspctr.Preorder(nodeFilter, func(node ast.Node) { inspctr.Preorder(nodeFilter, func(node ast.Node) {
@ -38,7 +36,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
return return
} }
assignStmt := findNestedContext(pass, node, body.List) assignStmt := findNestedContext(pass, body, body.List)
if assignStmt == nil { if assignStmt == nil {
return return
} }
@ -67,25 +65,15 @@ func run(pass *analysis.Pass) (interface{}, error) {
pass.Report(analysis.Diagnostic{ pass.Report(analysis.Diagnostic{
Pos: assignStmt.Pos(), Pos: assignStmt.Pos(),
Message: getReportMessage(node), Message: "nested context in loop",
SuggestedFixes: fixes, SuggestedFixes: fixes,
}) })
}) })
return nil, nil return nil, nil
} }
func getReportMessage(node ast.Node) string {
switch node.(type) {
case *ast.ForStmt, *ast.RangeStmt:
return "nested context in loop"
case *ast.FuncLit:
return "nested context in function literal"
default:
return "unsupported nested context type"
}
}
func getBody(node ast.Node) (*ast.BlockStmt, error) { func getBody(node ast.Node) (*ast.BlockStmt, error) {
forStmt, ok := node.(*ast.ForStmt) forStmt, ok := node.(*ast.ForStmt)
if ok { if ok {
@ -97,54 +85,49 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
return rangeStmt.Body, nil return rangeStmt.Body, nil
} }
funcLit, ok := node.(*ast.FuncLit)
if ok {
return funcLit.Body, nil
}
return nil, errUnknown return nil, errUnknown
} }
func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt { func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.Stmt) *ast.AssignStmt {
for _, stmt := range stmts { for _, stmt := range stmts {
// Recurse if necessary // Recurse if necessary
if inner, ok := stmt.(*ast.BlockStmt); ok { if inner, ok := stmt.(*ast.BlockStmt); ok {
found := findNestedContext(pass, node, inner.List) found := findNestedContext(pass, inner, inner.List)
if found != nil { if found != nil {
return found return found
} }
} }
if inner, ok := stmt.(*ast.IfStmt); ok { if inner, ok := stmt.(*ast.IfStmt); ok {
found := findNestedContext(pass, node, inner.Body.List) found := findNestedContext(pass, inner.Body, inner.Body.List)
if found != nil { if found != nil {
return found return found
} }
} }
if inner, ok := stmt.(*ast.SwitchStmt); ok { if inner, ok := stmt.(*ast.SwitchStmt); ok {
found := findNestedContext(pass, node, inner.Body.List) found := findNestedContext(pass, inner.Body, inner.Body.List)
if found != nil { if found != nil {
return found return found
} }
} }
if inner, ok := stmt.(*ast.CaseClause); ok { if inner, ok := stmt.(*ast.CaseClause); ok {
found := findNestedContext(pass, node, inner.Body) found := findNestedContext(pass, block, inner.Body)
if found != nil { if found != nil {
return found return found
} }
} }
if inner, ok := stmt.(*ast.SelectStmt); ok { if inner, ok := stmt.(*ast.SelectStmt); ok {
found := findNestedContext(pass, node, inner.Body.List) found := findNestedContext(pass, inner.Body, inner.Body.List)
if found != nil { if found != nil {
return found return found
} }
} }
if inner, ok := stmt.(*ast.CommClause); ok { if inner, ok := stmt.(*ast.CommClause); ok {
found := findNestedContext(pass, node, inner.Body) found := findNestedContext(pass, block, inner.Body)
if found != nil { if found != nil {
return found return found
} }
@ -166,13 +149,13 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
} }
if assignStmt.Tok == token.DEFINE { if assignStmt.Tok == token.DEFINE {
continue break
} }
// allow assignment to non-pointer children of values defined within the loop // allow assignment to non-pointer children of values defined within the loop
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil { if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil { if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
if checkObjectScopeWithinNode(obj.Parent(), node) { if obj.Pos() >= block.Pos() && obj.Pos() < block.End() {
continue // definition is within the loop continue // definition is within the loop
} }
} }
@ -184,18 +167,6 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
return nil return nil
} }
func checkObjectScopeWithinNode(scope *types.Scope, node ast.Node) bool {
if scope == nil {
return false
}
if scope.Pos() >= node.Pos() && scope.End() <= node.End() {
return true
}
return false
}
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident { func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
for { for {
switch n := node.(type) { switch n := node.(type) {

View file

@ -59,26 +59,6 @@ func example() {
break break
} }
// detects contexts wrapped in function literals (this is risky as function literals can be called multiple times)
_ = func() {
ctx = wrapContext(ctx) // want "nested context in function literal"
}
// this is fine because the context is created in the loop
for {
if ctx := context.Background(); doSomething() != nil {
ctx = wrapContext(ctx)
}
}
for {
ctx2 := context.Background()
ctx = wrapContext(ctx) // want "nested context in loop"
if doSomething() != nil {
ctx2 = wrapContext(ctx2)
}
}
} }
func wrapContext(ctx context.Context) context.Context { func wrapContext(ctx context.Context) context.Context {
@ -200,31 +180,3 @@ func inVariousNestedBlocks(ctx context.Context) {
break break
} }
} }
// this middleware could run on every request, bloating the request parameter level context and causing a memory leak
func badMiddleware(ctx context.Context) func() error {
return func() error {
ctx = wrapContext(ctx) // want "nested context in function literal"
return doSomethingWithCtx(ctx)
}
}
// this middleware is fine, as it doesn't modify the context of parent function
func okMiddleware(ctx context.Context) func() error {
return func() error {
ctx := wrapContext(ctx)
return doSomethingWithCtx(ctx)
}
}
// this middleware is fine, as it only modifies the context passed to it
func okMiddleware2(ctx context.Context) func(ctx context.Context) error {
return func(ctx context.Context) error {
ctx = wrapContext(ctx)
return doSomethingWithCtx(ctx)
}
}
func doSomethingWithCtx(ctx context.Context) error {
return nil
}