Compare commits

..

No commits in common. "master" and "v0.7.0" have entirely different histories.

16 changed files with 84 additions and 450 deletions

View file

@ -6,10 +6,6 @@ updates:
directory: "/"
schedule:
interval: "monthly"
groups:
github-actions:
patterns:
- "*" # Group all updates into a single larger pull request.
- package-ecosystem: "gomod"
directory: "/"
schedule:

View file

@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
go: [stable, oldstable]
go: ['1.22', '1.23']
os: [macos-latest, windows-latest, ubuntu-latest]
name: build
runs-on: ${{ matrix.os }}

View file

@ -16,7 +16,7 @@ jobs:
golangci:
strategy:
matrix:
go: [stable, oldstable]
go: ['1.22', '1.23']
name: lint
runs-on: ubuntu-latest
steps:
@ -25,6 +25,6 @@ jobs:
with:
go-version: ${{ matrix.go }}
- name: golangci-lint
uses: golangci/golangci-lint-action@v7
uses: golangci/golangci-lint-action@v6
with:
version: latest

View file

@ -17,8 +17,6 @@ jobs:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: stable
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6
with:

View file

@ -1,18 +1,9 @@
version: "2"
issues:
fix: true
exclude-dirs:
- contrib
linters:
default: all
disable:
- depguard
- exhaustruct
formatters:
enable:
- goimports
- gofmt
- gofumpt
- golines
settings:
goimports:
local-prefixes:
- github.com/Crocmagnon/fatcontext
linters-settings:
goimports:
local-prefixes: "github.com/Crocmagnon/fatcontext"

View file

@ -9,7 +9,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/golangci/golangci-lint
rev: v2.0.2
rev: v1.63.4
hooks:
- id: golangci-lint-full
- repo: local

View file

@ -1,4 +1,3 @@
// Package main runs the analyzer. It's the CLI entrypoint.
package main
import (

8
go.mod
View file

@ -1,10 +1,10 @@
module github.com/Crocmagnon/fatcontext
go 1.23.0
go 1.22.0
require golang.org/x/tools v0.31.0
require golang.org/x/tools v0.28.0
require (
golang.org/x/mod v0.24.0 // indirect
golang.org/x/sync v0.12.0 // indirect
golang.org/x/mod v0.22.0 // indirect
golang.org/x/sync v0.10.0 // indirect
)

12
go.sum
View file

@ -1,8 +1,8 @@
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/tools v0.31.0 h1:0EedkvKDbh+qistFTd0Bcwe/YLh4vHwWEkiI0toFIBU=
golang.org/x/tools v0.31.0/go.mod h1:naFTU+Cev749tSJRXJlna0T3WxKvb1kWEx15xA4SdmQ=
golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8=
golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw=

View file

@ -1,4 +1,3 @@
// Package analyzer contains everything related to the linter analysis.
package analyzer
import (
@ -16,31 +15,25 @@ import (
"golang.org/x/tools/go/ast/inspector"
)
// FlagCheckStructPointers is a possible flag for the analyzer.
// Exported to make it usable in golangci-lint.
const FlagCheckStructPointers = "check-struct-pointers"
// NewAnalyzer returns a fatcontext analyzer.
func NewAnalyzer() *analysis.Analyzer {
rnnr := &runner{}
r := &runner{}
flags := flag.NewFlagSet("fatcontext", flag.ExitOnError)
flags.BoolVar(&rnnr.DetectInStructPointers, FlagCheckStructPointers, false,
flags.BoolVar(&r.DetectInStructPointers, FlagCheckStructPointers, false,
"set to true to detect potential fat contexts in struct pointers")
return &analysis.Analyzer{
Name: "fatcontext",
Doc: "detects nested contexts in loops and function literals",
Run: rnnr.run,
Run: r.run,
Flags: *flags,
Requires: []*analysis.Analyzer{inspect.Analyzer},
}
}
var (
errUnknown = errors.New("unknown node type")
errInvalidAnalysis = errors.New("invalid analysis")
)
var errUnknown = errors.New("unknown node type")
const (
categoryInLoop = "nested context in loop"
@ -54,10 +47,7 @@ type runner struct {
}
func (r *runner) run(pass *analysis.Pass) (interface{}, error) {
inspctr, typeValid := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
if !typeValid {
return nil, errInvalidAnalysis
}
inspctr := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
nodeFilter := []ast.Node{
(*ast.ForStmt)(nil),
@ -72,10 +62,6 @@ func (r *runner) run(pass *analysis.Pass) (interface{}, error) {
return
}
if body == nil {
return
}
assignStmt := findNestedContext(pass, node, body.List)
if assignStmt == nil {
return
@ -96,18 +82,14 @@ func (r *runner) run(pass *analysis.Pass) (interface{}, error) {
})
})
return nil, nil //nolint:nilnil // we have no result to send to other analyzers
return nil, nil
}
func (r *runner) shouldIgnoreReport(category string) bool {
return category == categoryInStructPointer && !r.DetectInStructPointers
}
func (r *runner) getSuggestedFixes(
pass *analysis.Pass,
assignStmt *ast.AssignStmt,
category string,
) []analysis.SuggestedFix {
func (r *runner) getSuggestedFixes(pass *analysis.Pass, assignStmt *ast.AssignStmt, category string) []analysis.SuggestedFix {
switch category {
case categoryInStructPointer, categoryUnsupported:
return nil
@ -174,9 +156,31 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt {
for _, stmt := range stmts {
// Recurse if necessary
stmtList := getStmtList(stmt)
if found := findNestedContext(pass, node, stmtList); found != nil {
return found
switch typedStmt := stmt.(type) {
case *ast.BlockStmt:
if found := findNestedContext(pass, node, typedStmt.List); found != nil {
return found
}
case *ast.IfStmt:
if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
return found
}
case *ast.SwitchStmt:
if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
return found
}
case *ast.CaseClause:
if found := findNestedContext(pass, node, typedStmt.Body); found != nil {
return found
}
case *ast.SelectStmt:
if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
return found
}
case *ast.CommClause:
if found := findNestedContext(pass, node, typedStmt.Body); found != nil {
return found
}
}
// Actually check for nested context
@ -218,48 +222,28 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
return nil
}
func getStmtList(stmt ast.Stmt) []ast.Stmt {
switch typedStmt := stmt.(type) {
case *ast.BlockStmt:
return typedStmt.List
case *ast.IfStmt:
return typedStmt.Body.List
case *ast.SwitchStmt:
return typedStmt.Body.List
case *ast.CaseClause:
return typedStmt.Body
case *ast.SelectStmt:
return typedStmt.Body.List
case *ast.CommClause:
return typedStmt.Body
}
return nil
}
// render returns the pretty-print of the given node.
// render returns the pretty-print of the given node
func render(fset *token.FileSet, x interface{}) ([]byte, error) {
var buf bytes.Buffer
if err := printer.Fprint(&buf, fset, x); err != nil {
return nil, fmt.Errorf("printing node: %w", err)
}
return buf.Bytes(), nil
}
func isContextFunction(exp ast.Expr, fnName ...string) bool {
call, typeValid := exp.(*ast.CallExpr)
if !typeValid {
call, ok := exp.(*ast.CallExpr)
if !ok {
return false
}
selector, typeValid := call.Fun.(*ast.SelectorExpr)
if !typeValid {
selector, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return false
}
ident, typeValid := selector.X.(*ast.Ident)
if !typeValid {
ident, ok := selector.X.(*ast.Ident)
if !ok {
return false
}
@ -287,17 +271,16 @@ func isWithinLoop(exp ast.Expr, node ast.Node, pass *analysis.Pass) bool {
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
for {
switch typedNode := node.(type) {
switch n := node.(type) {
case *ast.Ident:
return typedNode
return n
case *ast.IndexExpr:
node = typedNode.X
node = n.X
case *ast.SelectorExpr:
if sel, ok := pass.TypesInfo.Selections[typedNode]; ok && sel.Indirect() {
if sel, ok := pass.TypesInfo.Selections[n]; ok && sel.Indirect() {
return nil // indirected (pointer) roots don't imply a (safe) copy
}
node = typedNode.X
node = n.X
default:
return nil
}
@ -305,10 +288,9 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
}
func isPointer(pass *analysis.Pass, exp ast.Node) bool {
switch n := exp.(type) { //nolint:gocritic // Future-proofing with switch instead of if.
switch n := exp.(type) {
case *ast.SelectorExpr:
sel, ok := pass.TypesInfo.Selections[n]
return ok && sel.Indirect()
}

View file

@ -1,6 +1,8 @@
package analyzer_test
import (
"os"
"path/filepath"
"testing"
"golang.org/x/tools/go/analysis/analysistest"
@ -9,59 +11,27 @@ import (
)
func TestAnalyzer(t *testing.T) {
t.Parallel()
testCases := []struct {
desc string
dir string
options map[string]string
}{
{
desc: "no func decl",
dir: "common",
},
{
desc: "no func decl",
dir: "no_structpointer",
},
{
desc: "func decl",
dir: "common",
options: map[string]string{
analyzer.FlagCheckStructPointers: "true",
},
},
{
desc: "func decl",
dir: "structpointer",
options: map[string]string{
analyzer.FlagCheckStructPointers: "true",
},
},
wd, err := os.Getwd()
if err != nil {
t.Fatalf("Failed to get wd: %s", err)
}
testdata := filepath.Join(wd, "testdata")
for _, test := range testCases {
t.Run(test.desc+"_"+test.dir, func(t *testing.T) {
t.Parallel()
t.Run("no func decl", func(t *testing.T) {
an := analyzer.NewAnalyzer()
analysistest.Run(t, testdata, an, "./common")
analysistest.Run(t, testdata, an, "./no_structpointer")
})
anlzr := analyzer.NewAnalyzer()
t.Run("func decl", func(t *testing.T) {
an := analyzer.NewAnalyzer()
for k, v := range test.options {
err := anlzr.Flags.Set(k, v)
if err != nil {
t.Fatal(err)
}
}
err := an.Flags.Set(analyzer.FlagCheckStructPointers, "true")
if err != nil {
t.Fatal(err)
}
analysistest.RunWithSuggestedFixes(t, analysistest.TestData(), anlzr, test.dir)
})
}
}
func TestAnalyzer_cgo(t *testing.T) {
t.Parallel()
a := analyzer.NewAnalyzer()
analysistest.Run(t, analysistest.TestData(), a, "cgo")
analysistest.Run(t, testdata, an, "./common")
analysistest.Run(t, testdata, an, "./structpointer")
})
}

View file

@ -1,51 +0,0 @@
package cgo
/*
#include <stdio.h>
#include <stdlib.h>
void myprint(char* s) {
printf("%d\n", s);
}
*/
import "C"
import (
"context"
"unsafe"
)
func _() {
cs := C.CString("Hello from stdio\n")
C.myprint(cs)
C.free(unsafe.Pointer(cs))
}
func _() {
ctx := context.Background()
for i := 0; i < 10; i++ {
ctx := context.WithValue(ctx, "key", i)
ctx = context.WithValue(ctx, "other", "val")
}
for i := 0; i < 10; i++ {
ctx = context.WithValue(ctx, "key", i) // want "nested context in loop"
ctx = context.WithValue(ctx, "other", "val")
}
for item := range []string{"one", "two", "three"} {
ctx = wrapContext(ctx) // want "nested context in loop"
ctx := context.WithValue(ctx, "key", item)
ctx = wrapContext(ctx)
}
for {
ctx = wrapContext(ctx) // want "nested context in loop"
break
}
}
func wrapContext(ctx context.Context) context.Context {
return context.WithoutCancel(ctx)
}

View file

@ -1,251 +0,0 @@
package common
import (
"context"
"testing"
)
func example() {
ctx := context.Background()
for i := 0; i < 10; i++ {
ctx := context.WithValue(ctx, "key", i)
ctx = context.WithValue(ctx, "other", "val")
}
for i := 0; i < 10; i++ {
ctx := context.WithValue(ctx, "key", i) // want "nested context in loop"
ctx = context.WithValue(ctx, "other", "val")
}
for item := range []string{"one", "two", "three"} {
ctx := wrapContext(ctx) // want "nested context in loop"
ctx := context.WithValue(ctx, "key", item)
ctx = wrapContext(ctx)
}
for {
ctx := wrapContext(ctx) // want "nested context in loop"
break
}
// not fooled by shadowing in nested blocks
for {
err := doSomething()
if err != nil {
ctx := wrapContext(ctx)
ctx = wrapContext(ctx)
}
switch err {
case nil:
ctx := wrapContext(ctx)
ctx = wrapContext(ctx)
default:
ctx := wrapContext(ctx)
ctx = wrapContext(ctx)
}
{
ctx := wrapContext(ctx)
ctx = wrapContext(ctx)
}
select {
case <-ctx.Done():
ctx := wrapContext(ctx)
ctx = wrapContext(ctx)
default:
}
ctx := wrapContext(ctx) // want "nested context in loop"
break
}
// detects contexts wrapped in function literals (this is risky as function literals can be called multiple times)
_ = func() {
ctx := wrapContext(ctx) // want "nested context in function literal"
}
// this is fine because the context is created in the loop
for {
if ctx := context.Background(); doSomething() != nil {
ctx = wrapContext(ctx)
}
}
for {
ctx2 := context.Background()
ctx := wrapContext(ctx) // want "nested context in loop"
if doSomething() != nil {
ctx2 = wrapContext(ctx2)
}
}
}
func wrapContext(ctx context.Context) context.Context {
return context.WithoutCancel(ctx)
}
func doSomething() error {
return nil
}
// storing contexts in a struct isn't recommended, but local copies of a non-pointer struct should act like local copies of a context.
func inStructs(ctx context.Context) {
for i := 0; i < 10; i++ {
c := struct{ Ctx context.Context }{ctx}
c.Ctx = context.WithValue(c.Ctx, "key", i)
c.Ctx = context.WithValue(c.Ctx, "other", "val")
}
for i := 0; i < 10; i++ {
c := []struct{ Ctx context.Context }{{ctx}}
c[0].Ctx = context.WithValue(c[0].Ctx, "key", i)
c[0].Ctx = context.WithValue(c[0].Ctx, "other", "val")
}
c := struct{ Ctx context.Context }{ctx}
for i := 0; i < 10; i++ {
c := c
c.Ctx = context.WithValue(c.Ctx, "key", i)
c.Ctx = context.WithValue(c.Ctx, "other", "val")
}
pc := &struct{ Ctx context.Context }{ctx}
for i := 0; i < 10; i++ {
c := pc
c.Ctx := context.WithValue(c.Ctx, "key", i) // want "nested context in loop"
c.Ctx = context.WithValue(c.Ctx, "other", "val")
}
r := []struct{ Ctx context.Context }{{ctx}}
for i := 0; i < 10; i++ {
r[0].Ctx := context.WithValue(r[0].Ctx, "key", i) // want "nested context in loop"
r[0].Ctx = context.WithValue(r[0].Ctx, "other", "val")
}
rp := []*struct{ Ctx context.Context }{{ctx}}
for i := 0; i < 10; i++ {
rp[0].Ctx := context.WithValue(rp[0].Ctx, "key", i) // want "nested context in loop"
rp[0].Ctx = context.WithValue(rp[0].Ctx, "other", "val")
}
}
func inVariousNestedBlocks(ctx context.Context) {
for {
err := doSomething()
if err != nil {
ctx := wrapContext(ctx) // want "nested context in loop"
}
break
}
for {
err := doSomething()
if err != nil {
if true {
ctx := wrapContext(ctx) // want "nested context in loop"
}
}
break
}
for {
err := doSomething()
switch err {
case nil:
ctx := wrapContext(ctx) // want "nested context in loop"
}
break
}
for {
err := doSomething()
switch err {
default:
ctx := wrapContext(ctx) // want "nested context in loop"
}
break
}
for {
ctx := wrapContext(ctx)
err := doSomething()
if err != nil {
ctx = wrapContext(ctx)
}
break
}
for {
{
ctx := wrapContext(ctx) // want "nested context in loop"
}
break
}
for {
select {
case <-ctx.Done():
ctx := wrapContext(ctx) // want "nested context in loop"
default:
}
break
}
}
// this middleware could run on every request, bloating the request parameter level context and causing a memory leak
func badMiddleware(ctx context.Context) func() error {
return func() error {
ctx := wrapContext(ctx) // want "nested context in function literal"
return doSomethingWithCtx(ctx)
}
}
// this middleware is fine, as it doesn't modify the context of parent function
func okMiddleware(ctx context.Context) func() error {
return func() error {
ctx := wrapContext(ctx)
return doSomethingWithCtx(ctx)
}
}
// this middleware is fine, as it only modifies the context passed to it
func okMiddleware2(ctx context.Context) func(ctx context.Context) error {
return func(ctx context.Context) error {
ctx = wrapContext(ctx)
return doSomethingWithCtx(ctx)
}
}
func doSomethingWithCtx(ctx context.Context) error {
return nil
}
func testCasesInit(t *testing.T) {
cases := []struct {
ctx context.Context
}{
{},
{
ctx: context.WithValue(context.Background(), "key", "value"),
},
}
for _, tc := range cases {
t.Run("some test", func(t *testing.T) {
if tc.ctx == nil {
tc.ctx = context.Background()
}
})
}
}