Compare commits

..

4 commits

Author SHA1 Message Date
Venkatesh Kotwade
c78136bc10
Merge e482fa4d35 into 0be9888cea 2024-08-25 03:12:51 +05:30
Venkatesh Kotwade
e482fa4d35 refactor: Update getReportMessage function to handle unsupported nested context types 2024-08-25 03:12:41 +05:30
Venkatesh Kotwade
99d25865c2 feat: Improve detection of nested contexts in function literals 2024-08-25 02:22:31 +05:30
Venkatesh Kotwade
6ddf255ca8 feat: Add detection for nested contexts in function literals 2024-08-25 01:49:35 +05:30
4 changed files with 17 additions and 30 deletions

View file

@ -14,9 +14,8 @@ permissions:
jobs: jobs:
build: build:
strategy: strategy:
fail-fast: false
matrix: matrix:
go: ['1.22', '1.23'] go: ['1.21', '1.22', '1.23.0-rc.2']
os: [macos-latest, windows-latest, ubuntu-latest] os: [macos-latest, windows-latest, ubuntu-latest]
name: build name: build
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}

2
go.mod
View file

@ -1,6 +1,6 @@
module github.com/Crocmagnon/fatcontext module github.com/Crocmagnon/fatcontext
go 1.22.6 go 1.21
require golang.org/x/tools v0.23.0 require golang.org/x/tools v0.23.0

View file

@ -38,7 +38,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
return return
} }
assignStmt := findNestedContext(pass, node, body.List) assignStmt := findNestedContext(pass, body, body.List)
if assignStmt == nil { if assignStmt == nil {
return return
} }
@ -78,7 +78,9 @@ func run(pass *analysis.Pass) (interface{}, error) {
func getReportMessage(node ast.Node) string { func getReportMessage(node ast.Node) string {
switch node.(type) { switch node.(type) {
case *ast.ForStmt, *ast.RangeStmt: case *ast.ForStmt:
return "nested context in loop"
case *ast.RangeStmt:
return "nested context in loop" return "nested context in loop"
case *ast.FuncLit: case *ast.FuncLit:
return "nested context in function literal" return "nested context in function literal"
@ -106,46 +108,46 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
return nil, errUnknown return nil, errUnknown
} }
func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt { func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.Stmt) *ast.AssignStmt {
for _, stmt := range stmts { for _, stmt := range stmts {
// Recurse if necessary // Recurse if necessary
if inner, ok := stmt.(*ast.BlockStmt); ok { if inner, ok := stmt.(*ast.BlockStmt); ok {
found := findNestedContext(pass, node, inner.List) found := findNestedContext(pass, inner, inner.List)
if found != nil { if found != nil {
return found return found
} }
} }
if inner, ok := stmt.(*ast.IfStmt); ok { if inner, ok := stmt.(*ast.IfStmt); ok {
found := findNestedContext(pass, node, inner.Body.List) found := findNestedContext(pass, inner.Body, inner.Body.List)
if found != nil { if found != nil {
return found return found
} }
} }
if inner, ok := stmt.(*ast.SwitchStmt); ok { if inner, ok := stmt.(*ast.SwitchStmt); ok {
found := findNestedContext(pass, node, inner.Body.List) found := findNestedContext(pass, inner.Body, inner.Body.List)
if found != nil { if found != nil {
return found return found
} }
} }
if inner, ok := stmt.(*ast.CaseClause); ok { if inner, ok := stmt.(*ast.CaseClause); ok {
found := findNestedContext(pass, node, inner.Body) found := findNestedContext(pass, block, inner.Body)
if found != nil { if found != nil {
return found return found
} }
} }
if inner, ok := stmt.(*ast.SelectStmt); ok { if inner, ok := stmt.(*ast.SelectStmt); ok {
found := findNestedContext(pass, node, inner.Body.List) found := findNestedContext(pass, inner.Body, inner.Body.List)
if found != nil { if found != nil {
return found return found
} }
} }
if inner, ok := stmt.(*ast.CommClause); ok { if inner, ok := stmt.(*ast.CommClause); ok {
found := findNestedContext(pass, node, inner.Body) found := findNestedContext(pass, block, inner.Body)
if found != nil { if found != nil {
return found return found
} }
@ -167,13 +169,13 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
} }
if assignStmt.Tok == token.DEFINE { if assignStmt.Tok == token.DEFINE {
continue break
} }
// allow assignment to non-pointer children of values defined within the loop // allow assignment to non-pointer children of values defined within the loop
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil { if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil { if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
if checkObjectScopeWithinNode(obj.Parent(), node) { if checkObjectScopeWithinBlock(obj.Parent(), block) {
continue // definition is within the loop continue // definition is within the loop
} }
} }
@ -185,12 +187,12 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
return nil return nil
} }
func checkObjectScopeWithinNode(scope *types.Scope, node ast.Node) bool { func checkObjectScopeWithinBlock(scope *types.Scope, block *ast.BlockStmt) bool {
if scope == nil { if scope == nil {
return false return false
} }
if scope.Pos() >= node.Pos() && scope.End() <= node.End() { if scope.Pos() >= block.Pos() && scope.End() <= block.End() {
return true return true
} }

View file

@ -65,20 +65,6 @@ func example() {
ctx = wrapContext(ctx) // want "nested context in function literal" ctx = wrapContext(ctx) // want "nested context in function literal"
} }
// this is fine because the context is created in the loop
for {
if ctx := context.Background(); doSomething() != nil {
ctx = wrapContext(ctx)
}
}
for {
ctx2 := context.Background()
ctx = wrapContext(ctx) // want "nested context in loop"
if doSomething() != nil {
ctx2 = wrapContext(ctx2)
}
}
} }
func wrapContext(ctx context.Context) context.Context { func wrapContext(ctx context.Context) context.Context {