diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go
index b45dbc3..5b28078 100644
--- a/pkg/analyzer/analyzer.go
+++ b/pkg/analyzer/analyzer.go
@@ -30,6 +30,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
 		(*ast.ForStmt)(nil),
 		(*ast.RangeStmt)(nil),
 		(*ast.FuncLit)(nil),
+		(*ast.FuncDecl)(nil),
 	}
 
 	inspctr.Preorder(nodeFilter, func(node ast.Node) {
@@ -81,25 +82,23 @@ func getReportMessage(node ast.Node) string {
 		return "nested context in loop"
 	case *ast.FuncLit:
 		return "nested context in function literal"
+	case *ast.FuncDecl:
+		return "potential nested context in function declaration"
 	default:
 		return "unsupported nested context type"
 	}
 }
 
 func getBody(node ast.Node) (*ast.BlockStmt, error) {
-	forStmt, ok := node.(*ast.ForStmt)
-	if ok {
-		return forStmt.Body, nil
-	}
-
-	rangeStmt, ok := node.(*ast.RangeStmt)
-	if ok {
-		return rangeStmt.Body, nil
-	}
-
-	funcLit, ok := node.(*ast.FuncLit)
-	if ok {
-		return funcLit.Body, nil
+	switch typedNode := node.(type) {
+	case *ast.ForStmt:
+		return typedNode.Body, nil
+	case *ast.RangeStmt:
+		return typedNode.Body, nil
+	case *ast.FuncLit:
+		return typedNode.Body, nil
+	case *ast.FuncDecl:
+		return typedNode.Body, nil
 	}
 
 	return nil, errUnknown
@@ -108,44 +107,29 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
 func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt {
 	for _, stmt := range stmts {
 		// Recurse if necessary
-		if inner, ok := stmt.(*ast.BlockStmt); ok {
-			found := findNestedContext(pass, node, inner.List)
-			if found != nil {
+		switch typedStmt := stmt.(type) {
+		case *ast.BlockStmt:
+			if found := findNestedContext(pass, node, typedStmt.List); found != nil {
 				return found
 			}
-		}
-
-		if inner, ok := stmt.(*ast.IfStmt); ok {
-			found := findNestedContext(pass, node, inner.Body.List)
-			if found != nil {
+		case *ast.IfStmt:
+			if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
 				return found
 			}
-		}
-
-		if inner, ok := stmt.(*ast.SwitchStmt); ok {
-			found := findNestedContext(pass, node, inner.Body.List)
-			if found != nil {
+		case *ast.SwitchStmt:
+			if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
 				return found
 			}
-		}
-
-		if inner, ok := stmt.(*ast.CaseClause); ok {
-			found := findNestedContext(pass, node, inner.Body)
-			if found != nil {
+		case *ast.CaseClause:
+			if found := findNestedContext(pass, node, typedStmt.Body); found != nil {
 				return found
 			}
-		}
-
-		if inner, ok := stmt.(*ast.SelectStmt); ok {
-			found := findNestedContext(pass, node, inner.Body.List)
-			if found != nil {
+		case *ast.SelectStmt:
+			if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
 				return found
 			}
-		}
-
-		if inner, ok := stmt.(*ast.CommClause); ok {
-			found := findNestedContext(pass, node, inner.Body)
-			if found != nil {
+		case *ast.CommClause:
+			if found := findNestedContext(pass, node, typedStmt.Body); found != nil {
 				return found
 			}
 		}
diff --git a/testdata/src/example.go b/testdata/src/example.go
index df76f89..2dc714b 100644
--- a/testdata/src/example.go
+++ b/testdata/src/example.go
@@ -249,3 +249,25 @@ func testCasesInit(t *testing.T) {
 		})
 	}
 }
+
+type Container struct {
+	Ctx context.Context
+}
+
+func something() func(*Container) {
+	return func(r *Container) {
+		ctx := r.Ctx
+		ctx = context.WithValue(ctx, "key", "val")
+		r.Ctx = ctx // want "nested context in function literal"
+	}
+}
+
+func other() func(*Container) {
+	return blah
+}
+
+func blah(r *Container) {
+	ctx := r.Ctx
+	ctx = context.WithValue(ctx, "key", "val")
+	r.Ctx = ctx // want "potential nested context in function declaration"
+}