Compare commits

...

6 commits

3 changed files with 122 additions and 71 deletions

View file

@ -1,12 +1,6 @@
# This is an example .goreleaser.yml file with some sensible defaults.
# Make sure to check the documentation at https://goreleaser.com
# The lines below are called `modelines`. See `:help modeline`
# Feel free to remove those if you don't want/need to use them.
# yaml-language-server: $schema=https://goreleaser.com/static/schema.json
# vim: set ts=2 sw=2 tw=0 fo=cnqoj
version: 1
version: 2
force_token: github
@ -45,3 +39,6 @@ changelog:
exclude:
- "^docs:"
- "^test:"
release:
draft: true

View file

@ -7,7 +7,7 @@ import (
"go/ast"
"go/printer"
"go/token"
"go/types"
"slices"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
@ -30,6 +30,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
(*ast.ForStmt)(nil),
(*ast.RangeStmt)(nil),
(*ast.FuncLit)(nil),
(*ast.FuncDecl)(nil),
}
inspctr.Preorder(nodeFilter, func(node ast.Node) {
@ -81,25 +82,23 @@ func getReportMessage(node ast.Node) string {
return "nested context in loop"
case *ast.FuncLit:
return "nested context in function literal"
case *ast.FuncDecl:
return "potential nested context in function declaration"
default:
return "unsupported nested context type"
}
}
func getBody(node ast.Node) (*ast.BlockStmt, error) {
forStmt, ok := node.(*ast.ForStmt)
if ok {
return forStmt.Body, nil
}
rangeStmt, ok := node.(*ast.RangeStmt)
if ok {
return rangeStmt.Body, nil
}
funcLit, ok := node.(*ast.FuncLit)
if ok {
return funcLit.Body, nil
switch typedNode := node.(type) {
case *ast.ForStmt:
return typedNode.Body, nil
case *ast.RangeStmt:
return typedNode.Body, nil
case *ast.FuncLit:
return typedNode.Body, nil
case *ast.FuncDecl:
return typedNode.Body, nil
}
return nil, errUnknown
@ -108,44 +107,29 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt {
for _, stmt := range stmts {
// Recurse if necessary
if inner, ok := stmt.(*ast.BlockStmt); ok {
found := findNestedContext(pass, node, inner.List)
if found != nil {
switch typedStmt := stmt.(type) {
case *ast.BlockStmt:
if found := findNestedContext(pass, node, typedStmt.List); found != nil {
return found
}
}
if inner, ok := stmt.(*ast.IfStmt); ok {
found := findNestedContext(pass, node, inner.Body.List)
if found != nil {
case *ast.IfStmt:
if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
return found
}
}
if inner, ok := stmt.(*ast.SwitchStmt); ok {
found := findNestedContext(pass, node, inner.Body.List)
if found != nil {
case *ast.SwitchStmt:
if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
return found
}
}
if inner, ok := stmt.(*ast.CaseClause); ok {
found := findNestedContext(pass, node, inner.Body)
if found != nil {
case *ast.CaseClause:
if found := findNestedContext(pass, node, typedStmt.Body); found != nil {
return found
}
}
if inner, ok := stmt.(*ast.SelectStmt); ok {
found := findNestedContext(pass, node, inner.Body.List)
if found != nil {
case *ast.SelectStmt:
if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
return found
}
}
if inner, ok := stmt.(*ast.CommClause); ok {
found := findNestedContext(pass, node, inner.Body)
if found != nil {
case *ast.CommClause:
if found := findNestedContext(pass, node, typedStmt.Body); found != nil {
return found
}
}
@ -169,13 +153,14 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
continue
}
// Ignore [context.Background] & [context.TODO].
if isContextFunction(assignStmt.Rhs[0], "Background", "TODO") {
continue
}
// allow assignment to non-pointer children of values defined within the loop
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
if checkObjectScopeWithinNode(obj.Parent(), node) {
continue // definition is within the loop
}
}
if isWithinLoop(assignStmt.Lhs[0], node, pass) {
continue
}
return assignStmt
@ -184,16 +169,51 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
return nil
}
func checkObjectScopeWithinNode(scope *types.Scope, node ast.Node) bool {
// render returns the pretty-print of the given node
func render(fset *token.FileSet, x interface{}) ([]byte, error) {
var buf bytes.Buffer
if err := printer.Fprint(&buf, fset, x); err != nil {
return nil, fmt.Errorf("printing node: %w", err)
}
return buf.Bytes(), nil
}
func isContextFunction(exp ast.Expr, fnName ...string) bool {
call, ok := exp.(*ast.CallExpr)
if !ok {
return false
}
selector, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return false
}
ident, ok := selector.X.(*ast.Ident)
if !ok {
return false
}
return ident.Name == "context" && slices.Contains(fnName, selector.Sel.Name)
}
func isWithinLoop(exp ast.Expr, node ast.Node, pass *analysis.Pass) bool {
lhs := getRootIdent(pass, exp)
if lhs == nil {
return false
}
obj := pass.TypesInfo.ObjectOf(lhs)
if obj == nil {
return false
}
scope := obj.Parent()
if scope == nil {
return false
}
if scope.Pos() >= node.Pos() && scope.End() <= node.End() {
return true
}
return false
return scope.Pos() >= node.Pos() && scope.End() <= node.End()
}
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
@ -213,12 +233,3 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
}
}
}
// render returns the pretty-print of the given node
func render(fset *token.FileSet, x interface{}) ([]byte, error) {
var buf bytes.Buffer
if err := printer.Fprint(&buf, fset, x); err != nil {
return nil, fmt.Errorf("printing node: %w", err)
}
return buf.Bytes(), nil
}

View file

@ -1,6 +1,9 @@
package src
import "context"
import (
"context"
"testing"
)
func example() {
ctx := context.Background()
@ -228,3 +231,43 @@ func okMiddleware2(ctx context.Context) func(ctx context.Context) error {
func doSomethingWithCtx(ctx context.Context) error {
return nil
}
func testCasesInit(t *testing.T) {
cases := []struct {
ctx context.Context
}{
{},
{
ctx: context.WithValue(context.Background(), "key", "value"),
},
}
for _, tc := range cases {
t.Run("some test", func(t *testing.T) {
if tc.ctx == nil {
tc.ctx = context.Background()
}
})
}
}
type Container struct {
Ctx context.Context
}
func something() func(*Container) {
return func(r *Container) {
ctx := r.Ctx
ctx = context.WithValue(ctx, "key", "val")
r.Ctx = ctx // want "nested context in function literal"
}
}
func other() func(*Container) {
return blah
}
func blah(r *Container) {
ctx := r.Ctx
ctx = context.WithValue(ctx, "key", "val")
r.Ctx = ctx // want "potential nested context in function declaration"
}