mirror of
https://github.com/Crocmagnon/fatcontext.git
synced 2025-04-11 12:06:33 +02:00
Compare commits
6 commits
07e88037a8
...
bc3b0d1e1c
Author | SHA1 | Date | |
---|---|---|---|
bc3b0d1e1c | |||
9371bcfb56 | |||
5b689092ef | |||
939d65bc16 | |||
54e593c1c6 | |||
529e088561 |
3 changed files with 122 additions and 71 deletions
|
@ -1,12 +1,6 @@
|
|||
# This is an example .goreleaser.yml file with some sensible defaults.
|
||||
# Make sure to check the documentation at https://goreleaser.com
|
||||
|
||||
# The lines below are called `modelines`. See `:help modeline`
|
||||
# Feel free to remove those if you don't want/need to use them.
|
||||
# yaml-language-server: $schema=https://goreleaser.com/static/schema.json
|
||||
# vim: set ts=2 sw=2 tw=0 fo=cnqoj
|
||||
|
||||
version: 1
|
||||
version: 2
|
||||
|
||||
force_token: github
|
||||
|
||||
|
@ -45,3 +39,6 @@ changelog:
|
|||
exclude:
|
||||
- "^docs:"
|
||||
- "^test:"
|
||||
|
||||
release:
|
||||
draft: true
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"go/ast"
|
||||
"go/printer"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"slices"
|
||||
|
||||
"golang.org/x/tools/go/analysis"
|
||||
"golang.org/x/tools/go/analysis/passes/inspect"
|
||||
|
@ -30,6 +30,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
|
|||
(*ast.ForStmt)(nil),
|
||||
(*ast.RangeStmt)(nil),
|
||||
(*ast.FuncLit)(nil),
|
||||
(*ast.FuncDecl)(nil),
|
||||
}
|
||||
|
||||
inspctr.Preorder(nodeFilter, func(node ast.Node) {
|
||||
|
@ -81,25 +82,23 @@ func getReportMessage(node ast.Node) string {
|
|||
return "nested context in loop"
|
||||
case *ast.FuncLit:
|
||||
return "nested context in function literal"
|
||||
case *ast.FuncDecl:
|
||||
return "potential nested context in function declaration"
|
||||
default:
|
||||
return "unsupported nested context type"
|
||||
}
|
||||
}
|
||||
|
||||
func getBody(node ast.Node) (*ast.BlockStmt, error) {
|
||||
forStmt, ok := node.(*ast.ForStmt)
|
||||
if ok {
|
||||
return forStmt.Body, nil
|
||||
}
|
||||
|
||||
rangeStmt, ok := node.(*ast.RangeStmt)
|
||||
if ok {
|
||||
return rangeStmt.Body, nil
|
||||
}
|
||||
|
||||
funcLit, ok := node.(*ast.FuncLit)
|
||||
if ok {
|
||||
return funcLit.Body, nil
|
||||
switch typedNode := node.(type) {
|
||||
case *ast.ForStmt:
|
||||
return typedNode.Body, nil
|
||||
case *ast.RangeStmt:
|
||||
return typedNode.Body, nil
|
||||
case *ast.FuncLit:
|
||||
return typedNode.Body, nil
|
||||
case *ast.FuncDecl:
|
||||
return typedNode.Body, nil
|
||||
}
|
||||
|
||||
return nil, errUnknown
|
||||
|
@ -108,44 +107,29 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
|
|||
func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt {
|
||||
for _, stmt := range stmts {
|
||||
// Recurse if necessary
|
||||
if inner, ok := stmt.(*ast.BlockStmt); ok {
|
||||
found := findNestedContext(pass, node, inner.List)
|
||||
if found != nil {
|
||||
switch typedStmt := stmt.(type) {
|
||||
case *ast.BlockStmt:
|
||||
if found := findNestedContext(pass, node, typedStmt.List); found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
||||
if inner, ok := stmt.(*ast.IfStmt); ok {
|
||||
found := findNestedContext(pass, node, inner.Body.List)
|
||||
if found != nil {
|
||||
case *ast.IfStmt:
|
||||
if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
||||
if inner, ok := stmt.(*ast.SwitchStmt); ok {
|
||||
found := findNestedContext(pass, node, inner.Body.List)
|
||||
if found != nil {
|
||||
case *ast.SwitchStmt:
|
||||
if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
||||
if inner, ok := stmt.(*ast.CaseClause); ok {
|
||||
found := findNestedContext(pass, node, inner.Body)
|
||||
if found != nil {
|
||||
case *ast.CaseClause:
|
||||
if found := findNestedContext(pass, node, typedStmt.Body); found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
||||
if inner, ok := stmt.(*ast.SelectStmt); ok {
|
||||
found := findNestedContext(pass, node, inner.Body.List)
|
||||
if found != nil {
|
||||
case *ast.SelectStmt:
|
||||
if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
||||
if inner, ok := stmt.(*ast.CommClause); ok {
|
||||
found := findNestedContext(pass, node, inner.Body)
|
||||
if found != nil {
|
||||
case *ast.CommClause:
|
||||
if found := findNestedContext(pass, node, typedStmt.Body); found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
@ -169,13 +153,14 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
|
|||
continue
|
||||
}
|
||||
|
||||
// Ignore [context.Background] & [context.TODO].
|
||||
if isContextFunction(assignStmt.Rhs[0], "Background", "TODO") {
|
||||
continue
|
||||
}
|
||||
|
||||
// allow assignment to non-pointer children of values defined within the loop
|
||||
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
|
||||
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
|
||||
if checkObjectScopeWithinNode(obj.Parent(), node) {
|
||||
continue // definition is within the loop
|
||||
}
|
||||
}
|
||||
if isWithinLoop(assignStmt.Lhs[0], node, pass) {
|
||||
continue
|
||||
}
|
||||
|
||||
return assignStmt
|
||||
|
@ -184,16 +169,51 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
|
|||
return nil
|
||||
}
|
||||
|
||||
func checkObjectScopeWithinNode(scope *types.Scope, node ast.Node) bool {
|
||||
// render returns the pretty-print of the given node
|
||||
func render(fset *token.FileSet, x interface{}) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := printer.Fprint(&buf, fset, x); err != nil {
|
||||
return nil, fmt.Errorf("printing node: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func isContextFunction(exp ast.Expr, fnName ...string) bool {
|
||||
call, ok := exp.(*ast.CallExpr)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
selector, ok := call.Fun.(*ast.SelectorExpr)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
ident, ok := selector.X.(*ast.Ident)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return ident.Name == "context" && slices.Contains(fnName, selector.Sel.Name)
|
||||
}
|
||||
|
||||
func isWithinLoop(exp ast.Expr, node ast.Node, pass *analysis.Pass) bool {
|
||||
lhs := getRootIdent(pass, exp)
|
||||
if lhs == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
obj := pass.TypesInfo.ObjectOf(lhs)
|
||||
if obj == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
scope := obj.Parent()
|
||||
if scope == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if scope.Pos() >= node.Pos() && scope.End() <= node.End() {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
return scope.Pos() >= node.Pos() && scope.End() <= node.End()
|
||||
}
|
||||
|
||||
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
|
||||
|
@ -213,12 +233,3 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// render returns the pretty-print of the given node
|
||||
func render(fset *token.FileSet, x interface{}) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := printer.Fprint(&buf, fset, x); err != nil {
|
||||
return nil, fmt.Errorf("printing node: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
|
45
testdata/src/example.go
vendored
45
testdata/src/example.go
vendored
|
@ -1,6 +1,9 @@
|
|||
package src
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func example() {
|
||||
ctx := context.Background()
|
||||
|
@ -228,3 +231,43 @@ func okMiddleware2(ctx context.Context) func(ctx context.Context) error {
|
|||
func doSomethingWithCtx(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func testCasesInit(t *testing.T) {
|
||||
cases := []struct {
|
||||
ctx context.Context
|
||||
}{
|
||||
{},
|
||||
{
|
||||
ctx: context.WithValue(context.Background(), "key", "value"),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run("some test", func(t *testing.T) {
|
||||
if tc.ctx == nil {
|
||||
tc.ctx = context.Background()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type Container struct {
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
func something() func(*Container) {
|
||||
return func(r *Container) {
|
||||
ctx := r.Ctx
|
||||
ctx = context.WithValue(ctx, "key", "val")
|
||||
r.Ctx = ctx // want "nested context in function literal"
|
||||
}
|
||||
}
|
||||
|
||||
func other() func(*Container) {
|
||||
return blah
|
||||
}
|
||||
|
||||
func blah(r *Container) {
|
||||
ctx := r.Ctx
|
||||
ctx = context.WithValue(ctx, "key", "val")
|
||||
r.Ctx = ctx // want "potential nested context in function declaration"
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue