mirror of
https://github.com/Crocmagnon/fatcontext.git
synced 2024-12-23 22:41:57 +01:00
Detect nested contexts in function literals (#18)
* feat: Add detection for nested contexts in function literals * feat: Improve detection of nested contexts in function literals * refactor: Update getReportMessage function to handle unsupported nested context types * use node instead of block * refactor: use multi case * added one more case * feat: also added support for multiple contexts
This commit is contained in:
parent
0d2c4019d4
commit
be0aa70f23
2 changed files with 89 additions and 11 deletions
|
@ -7,6 +7,7 @@ import (
|
|||
"go/ast"
|
||||
"go/printer"
|
||||
"go/token"
|
||||
"go/types"
|
||||
|
||||
"golang.org/x/tools/go/analysis"
|
||||
"golang.org/x/tools/go/analysis/passes/inspect"
|
||||
|
@ -28,6 +29,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
|
|||
nodeFilter := []ast.Node{
|
||||
(*ast.ForStmt)(nil),
|
||||
(*ast.RangeStmt)(nil),
|
||||
(*ast.FuncLit)(nil),
|
||||
}
|
||||
|
||||
inspctr.Preorder(nodeFilter, func(node ast.Node) {
|
||||
|
@ -36,7 +38,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
|
|||
return
|
||||
}
|
||||
|
||||
assignStmt := findNestedContext(pass, body, body.List)
|
||||
assignStmt := findNestedContext(pass, node, body.List)
|
||||
if assignStmt == nil {
|
||||
return
|
||||
}
|
||||
|
@ -65,7 +67,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
|
|||
|
||||
pass.Report(analysis.Diagnostic{
|
||||
Pos: assignStmt.Pos(),
|
||||
Message: "nested context in loop",
|
||||
Message: getReportMessage(node),
|
||||
SuggestedFixes: fixes,
|
||||
})
|
||||
|
||||
|
@ -74,6 +76,17 @@ func run(pass *analysis.Pass) (interface{}, error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func getReportMessage(node ast.Node) string {
|
||||
switch node.(type) {
|
||||
case *ast.ForStmt, *ast.RangeStmt:
|
||||
return "nested context in loop"
|
||||
case *ast.FuncLit:
|
||||
return "nested context in function literal"
|
||||
default:
|
||||
return "unsupported nested context type"
|
||||
}
|
||||
}
|
||||
|
||||
func getBody(node ast.Node) (*ast.BlockStmt, error) {
|
||||
forStmt, ok := node.(*ast.ForStmt)
|
||||
if ok {
|
||||
|
@ -85,49 +98,54 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
|
|||
return rangeStmt.Body, nil
|
||||
}
|
||||
|
||||
funcLit, ok := node.(*ast.FuncLit)
|
||||
if ok {
|
||||
return funcLit.Body, nil
|
||||
}
|
||||
|
||||
return nil, errUnknown
|
||||
}
|
||||
|
||||
func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.Stmt) *ast.AssignStmt {
|
||||
func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt {
|
||||
for _, stmt := range stmts {
|
||||
// Recurse if necessary
|
||||
if inner, ok := stmt.(*ast.BlockStmt); ok {
|
||||
found := findNestedContext(pass, inner, inner.List)
|
||||
found := findNestedContext(pass, node, inner.List)
|
||||
if found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
||||
if inner, ok := stmt.(*ast.IfStmt); ok {
|
||||
found := findNestedContext(pass, inner.Body, inner.Body.List)
|
||||
found := findNestedContext(pass, node, inner.Body.List)
|
||||
if found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
||||
if inner, ok := stmt.(*ast.SwitchStmt); ok {
|
||||
found := findNestedContext(pass, inner.Body, inner.Body.List)
|
||||
found := findNestedContext(pass, node, inner.Body.List)
|
||||
if found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
||||
if inner, ok := stmt.(*ast.CaseClause); ok {
|
||||
found := findNestedContext(pass, block, inner.Body)
|
||||
found := findNestedContext(pass, node, inner.Body)
|
||||
if found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
||||
if inner, ok := stmt.(*ast.SelectStmt); ok {
|
||||
found := findNestedContext(pass, inner.Body, inner.Body.List)
|
||||
found := findNestedContext(pass, node, inner.Body.List)
|
||||
if found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
||||
if inner, ok := stmt.(*ast.CommClause); ok {
|
||||
found := findNestedContext(pass, block, inner.Body)
|
||||
found := findNestedContext(pass, node, inner.Body)
|
||||
if found != nil {
|
||||
return found
|
||||
}
|
||||
|
@ -149,13 +167,13 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St
|
|||
}
|
||||
|
||||
if assignStmt.Tok == token.DEFINE {
|
||||
break
|
||||
continue
|
||||
}
|
||||
|
||||
// allow assignment to non-pointer children of values defined within the loop
|
||||
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
|
||||
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
|
||||
if obj.Pos() >= block.Pos() && obj.Pos() < block.End() {
|
||||
if checkObjectScopeWithinNode(obj.Parent(), node) {
|
||||
continue // definition is within the loop
|
||||
}
|
||||
}
|
||||
|
@ -167,6 +185,18 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St
|
|||
return nil
|
||||
}
|
||||
|
||||
func checkObjectScopeWithinNode(scope *types.Scope, node ast.Node) bool {
|
||||
if scope == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if scope.Pos() >= node.Pos() && scope.End() <= node.End() {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
|
||||
for {
|
||||
switch n := node.(type) {
|
||||
|
|
48
testdata/src/example.go
vendored
48
testdata/src/example.go
vendored
|
@ -59,6 +59,26 @@ func example() {
|
|||
|
||||
break
|
||||
}
|
||||
|
||||
// detects contexts wrapped in function literals (this is risky as function literals can be called multiple times)
|
||||
_ = func() {
|
||||
ctx = wrapContext(ctx) // want "nested context in function literal"
|
||||
}
|
||||
|
||||
// this is fine because the context is created in the loop
|
||||
for {
|
||||
if ctx := context.Background(); doSomething() != nil {
|
||||
ctx = wrapContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
ctx2 := context.Background()
|
||||
ctx = wrapContext(ctx) // want "nested context in loop"
|
||||
if doSomething() != nil {
|
||||
ctx2 = wrapContext(ctx2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func wrapContext(ctx context.Context) context.Context {
|
||||
|
@ -180,3 +200,31 @@ func inVariousNestedBlocks(ctx context.Context) {
|
|||
break
|
||||
}
|
||||
}
|
||||
|
||||
// this middleware could run on every request, bloating the request parameter level context and causing a memory leak
|
||||
func badMiddleware(ctx context.Context) func() error {
|
||||
return func() error {
|
||||
ctx = wrapContext(ctx) // want "nested context in function literal"
|
||||
return doSomethingWithCtx(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// this middleware is fine, as it doesn't modify the context of parent function
|
||||
func okMiddleware(ctx context.Context) func() error {
|
||||
return func() error {
|
||||
ctx := wrapContext(ctx)
|
||||
return doSomethingWithCtx(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// this middleware is fine, as it only modifies the context passed to it
|
||||
func okMiddleware2(ctx context.Context) func(ctx context.Context) error {
|
||||
return func(ctx context.Context) error {
|
||||
ctx = wrapContext(ctx)
|
||||
return doSomethingWithCtx(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func doSomethingWithCtx(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue