mirror of
https://github.com/Crocmagnon/fatcontext.git
synced 2025-02-05 12:12:32 +01:00
parent
529e088561
commit
54e593c1c6
2 changed files with 71 additions and 23 deletions
|
@ -7,7 +7,7 @@ import (
|
||||||
"go/ast"
|
"go/ast"
|
||||||
"go/printer"
|
"go/printer"
|
||||||
"go/token"
|
"go/token"
|
||||||
"go/types"
|
"slices"
|
||||||
|
|
||||||
"golang.org/x/tools/go/analysis"
|
"golang.org/x/tools/go/analysis"
|
||||||
"golang.org/x/tools/go/analysis/passes/inspect"
|
"golang.org/x/tools/go/analysis/passes/inspect"
|
||||||
|
@ -169,13 +169,14 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ignore [context.Background] & [context.TODO].
|
||||||
|
if isContextFunction(assignStmt.Rhs[0], "Background", "TODO") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// allow assignment to non-pointer children of values defined within the loop
|
// allow assignment to non-pointer children of values defined within the loop
|
||||||
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
|
if isWithinLoop(assignStmt.Lhs[0], node, pass) {
|
||||||
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
|
continue
|
||||||
if checkObjectScopeWithinNode(obj.Parent(), node) {
|
|
||||||
continue // definition is within the loop
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return assignStmt
|
return assignStmt
|
||||||
|
@ -184,16 +185,51 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkObjectScopeWithinNode(scope *types.Scope, node ast.Node) bool {
|
// render returns the pretty-print of the given node
|
||||||
|
func render(fset *token.FileSet, x interface{}) ([]byte, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := printer.Fprint(&buf, fset, x); err != nil {
|
||||||
|
return nil, fmt.Errorf("printing node: %w", err)
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isContextFunction(exp ast.Expr, fnName ...string) bool {
|
||||||
|
call, ok := exp.(*ast.CallExpr)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
selector, ok := call.Fun.(*ast.SelectorExpr)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ident, ok := selector.X.(*ast.Ident)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return ident.Name == "context" && slices.Contains(fnName, selector.Sel.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWithinLoop(exp ast.Expr, node ast.Node, pass *analysis.Pass) bool {
|
||||||
|
lhs := getRootIdent(pass, exp)
|
||||||
|
if lhs == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
obj := pass.TypesInfo.ObjectOf(lhs)
|
||||||
|
if obj == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
scope := obj.Parent()
|
||||||
if scope == nil {
|
if scope == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Pos() >= node.Pos() && scope.End() <= node.End() {
|
return scope.Pos() >= node.Pos() && scope.End() <= node.End()
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
|
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
|
||||||
|
@ -213,12 +249,3 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// render returns the pretty-print of the given node
|
|
||||||
func render(fset *token.FileSet, x interface{}) ([]byte, error) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
if err := printer.Fprint(&buf, fset, x); err != nil {
|
|
||||||
return nil, fmt.Errorf("printing node: %w", err)
|
|
||||||
}
|
|
||||||
return buf.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
23
testdata/src/example.go
vendored
23
testdata/src/example.go
vendored
|
@ -1,6 +1,9 @@
|
||||||
package src
|
package src
|
||||||
|
|
||||||
import "context"
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
func example() {
|
func example() {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -228,3 +231,21 @@ func okMiddleware2(ctx context.Context) func(ctx context.Context) error {
|
||||||
func doSomethingWithCtx(ctx context.Context) error {
|
func doSomethingWithCtx(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testCasesInit(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
ctx context.Context
|
||||||
|
}{
|
||||||
|
{},
|
||||||
|
{
|
||||||
|
ctx: context.WithValue(context.Background(), "key", "value"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run("some test", func(t *testing.T) {
|
||||||
|
if tc.ctx == nil {
|
||||||
|
tc.ctx = context.Background()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue