package analyzer

import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"go/ast"
	"go/printer"
	"go/token"
	"slices"

	"golang.org/x/tools/go/analysis"
	"golang.org/x/tools/go/analysis/passes/inspect"
	"golang.org/x/tools/go/ast/inspector"
)

// FlagCheckStructPointers is the name of the analyzer flag that enables
// reporting of potential fat contexts stored in struct pointers.
const FlagCheckStructPointers = "check-struct-pointers"

// NewAnalyzer builds the "fatcontext" analyzer.
//
// The returned analyzer depends on the inspect analyzer and exposes a single
// boolean flag, named by FlagCheckStructPointers (default false), that turns
// on detection of potential fat contexts in struct pointers.
func NewAnalyzer() *analysis.Analyzer {
	r := new(runner)

	fs := flag.NewFlagSet("fatcontext", flag.ExitOnError)
	fs.BoolVar(
		&r.DetectInStructPointers,
		FlagCheckStructPointers,
		false,
		"set to true to detect potential fat contexts in struct pointers",
	)

	analyzer := new(analysis.Analyzer)
	analyzer.Name = "fatcontext"
	analyzer.Doc = "detects nested contexts in loops and function literals"
	analyzer.Run = r.run
	analyzer.Flags = *fs
	analyzer.Requires = []*analysis.Analyzer{inspect.Analyzer}

	return analyzer
}

// errUnknown is returned by getBody when it is given a node type it does not
// know how to extract a body from.
var errUnknown = errors.New("unknown node type")

// Diagnostic categories. The selected category is used verbatim as the
// message of the reported diagnostic.
const (
	// categoryInLoop: the assignment was found in a for or range statement.
	categoryInLoop          = "nested context in loop"
	// categoryInFuncLit: the assignment was found in a function literal or
	// function declaration.
	categoryInFuncLit       = "nested context in function literal"
	// categoryInStructPointer: the assignment target is a field selected
	// through a pointer; only reported when DetectInStructPointers is set.
	categoryInStructPointer = "potential nested context in struct pointer"
	// categoryUnsupported: the enclosing node is none of the known kinds.
	categoryUnsupported     = "unsupported nested context type"
)

// runner holds the analyzer configuration and implements the analysis pass.
type runner struct {
	// DetectInStructPointers enables reports in the
	// categoryInStructPointer category; it is bound to the
	// FlagCheckStructPointers flag.
	DetectInStructPointers bool
}

// run is the analyzer entry point. It visits every for statement, range
// statement, function literal and function declaration in preorder, looks for
// a nested context assignment in the node's body and, unless the resulting
// category is disabled, reports it at the assignment's position together with
// any suggested fixes. It always returns (nil, nil).
func (r *runner) run(pass *analysis.Pass) (interface{}, error) {
	insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

	visit := func(n ast.Node) {
		block, err := getBody(n)
		// Nodes without a body (or of an unknown kind) have nothing to inspect.
		if err != nil || block == nil {
			return
		}

		found := findNestedContext(pass, n, block.List)
		if found == nil {
			return
		}

		cat := getCategory(pass, n, found)
		if r.shouldIgnoreReport(cat) {
			return
		}

		pass.Report(analysis.Diagnostic{
			Pos:            found.Pos(),
			Message:        cat,
			SuggestedFixes: r.getSuggestedFixes(pass, found, cat),
		})
	}

	insp.Preorder([]ast.Node{
		(*ast.ForStmt)(nil),
		(*ast.RangeStmt)(nil),
		(*ast.FuncLit)(nil),
		(*ast.FuncDecl)(nil),
	}, visit)

	return nil, nil
}

// shouldIgnoreReport reports whether a diagnostic of the given category must
// be suppressed. Only the struct-pointer category is ever suppressed, and only
// while DetectInStructPointers is disabled.
func (r *runner) shouldIgnoreReport(category string) bool {
	if r.DetectInStructPointers {
		return false
	}

	return category == categoryInStructPointer
}

// getSuggestedFixes returns the fixes to attach to a diagnostic for
// assignStmt, or nil when no safe fix is available.
//
// The only fix offered is rewriting the assignment as a short variable
// declaration (`=` becomes `:=`). No fix is produced when:
//   - the category is categoryInStructPointer or categoryUnsupported;
//   - any left-hand side operand is not a plain identifier (for example
//     `s.Ctx = ...` or `ctxs[i] = ...`), because `:=` only accepts
//     identifiers on its left-hand side and the rewrite would not compile;
//   - the rewritten statement cannot be rendered.
func (r *runner) getSuggestedFixes(pass *analysis.Pass, assignStmt *ast.AssignStmt, category string) []analysis.SuggestedFix {
	switch category {
	case categoryInStructPointer, categoryUnsupported:
		return nil
	}

	// A short variable declaration requires an identifier list on the left.
	for _, lhs := range assignStmt.Lhs {
		if _, ok := lhs.(*ast.Ident); !ok {
			return nil
		}
	}

	suggestedStmt := ast.AssignStmt{
		Lhs:    assignStmt.Lhs,
		TokPos: assignStmt.TokPos,
		Tok:    token.DEFINE,
		Rhs:    assignStmt.Rhs,
	}

	suggested, err := render(pass.Fset, &suggestedStmt)
	if err != nil {
		// Best effort: the diagnostic is still reported, just without a fix.
		return nil
	}

	return []analysis.SuggestedFix{
		{
			Message: "replace `=` with `:=`",
			TextEdits: []analysis.TextEdit{
				{
					Pos:     assignStmt.Pos(),
					End:     assignStmt.End(),
					NewText: suggested,
				},
			},
		},
	}
}

// getCategory picks the diagnostic category for assignStmt found under node.
//
// Loops take precedence: any assignment found in a for or range statement is
// categoryInLoop. Otherwise, an assignment whose first left-hand side operand
// is a pointer selection is categoryInStructPointer. Remaining assignments in
// a function literal or declaration are categoryInFuncLit, and anything else
// is categoryUnsupported.
func getCategory(pass *analysis.Pass, node ast.Node, assignStmt *ast.AssignStmt) string {
	var inLoop, inFunc bool

	switch node.(type) {
	case *ast.ForStmt, *ast.RangeStmt:
		inLoop = true
	case *ast.FuncLit, *ast.FuncDecl:
		inFunc = true
	}

	switch {
	case inLoop:
		return categoryInLoop
	case isPointer(pass, assignStmt.Lhs[0]):
		return categoryInStructPointer
	case inFunc:
		return categoryInFuncLit
	default:
		return categoryUnsupported
	}
}

// getBody returns the body block of a for statement, range statement,
// function literal or function declaration. The returned block is whatever
// the node's Body field holds and may therefore be nil. Any other node type
// yields errUnknown.
func getBody(node ast.Node) (*ast.BlockStmt, error) {
	var body *ast.BlockStmt

	switch n := node.(type) {
	case *ast.ForStmt:
		body = n.Body
	case *ast.RangeStmt:
		body = n.Body
	case *ast.FuncLit:
		body = n.Body
	case *ast.FuncDecl:
		body = n.Body
	default:
		return nil, errUnknown
	}

	return body, nil
}

// findNestedContext searches stmts, in order, for the first assignment that
// re-assigns a context.Context and returns it, or nil if there is none.
//
// node is the enclosing loop or function being analysed; it is only used to
// decide whether the assigned variable is declared inside that node.
//
// The search descends into blocks, labeled statements, if statements
// (including their else / else-if branches), expression and type switches,
// select statements and their clauses. It does not descend into nested loops
// or function literals.
//
// An assignment is reported when all of the following hold:
//   - the type of its first left-hand side operand prints as "context.Context";
//   - it is a plain assignment, not a `:=` definition;
//   - its first right-hand side operand is not a call to context.Background
//     or context.TODO;
//   - the target is either a field selected through a pointer, or its root
//     identifier is not declared within node.
func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt {
	for _, stmt := range stmts {
		// Recurse if necessary
		switch typedStmt := stmt.(type) {
		case *ast.BlockStmt:
			if found := findNestedContext(pass, node, typedStmt.List); found != nil {
				return found
			}
		case *ast.LabeledStmt:
			// Wrap the labeled statement so it goes through the same checks.
			if found := findNestedContext(pass, node, []ast.Stmt{typedStmt.Stmt}); found != nil {
				return found
			}
		case *ast.IfStmt:
			if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
				return found
			}

			// Else is either a block or another if statement; both are
			// handled by the cases of this switch.
			if typedStmt.Else != nil {
				if found := findNestedContext(pass, node, []ast.Stmt{typedStmt.Else}); found != nil {
					return found
				}
			}
		case *ast.SwitchStmt:
			if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
				return found
			}
		case *ast.TypeSwitchStmt:
			if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
				return found
			}
		case *ast.CaseClause:
			if found := findNestedContext(pass, node, typedStmt.Body); found != nil {
				return found
			}
		case *ast.SelectStmt:
			if found := findNestedContext(pass, node, typedStmt.Body.List); found != nil {
				return found
			}
		case *ast.CommClause:
			if found := findNestedContext(pass, node, typedStmt.Body); found != nil {
				return found
			}
		}

		// Actually check for nested context
		assignStmt, ok := stmt.(*ast.AssignStmt)
		if !ok {
			continue
		}

		t := pass.TypesInfo.TypeOf(assignStmt.Lhs[0])
		if t == nil {
			continue
		}

		if t.String() != "context.Context" {
			continue
		}

		if assignStmt.Tok == token.DEFINE {
			continue
		}

		// Ignore [context.Background] & [context.TODO].
		if isContextFunction(assignStmt.Rhs[0], "Background", "TODO") {
			continue
		}

		if isPointer(pass, assignStmt.Lhs[0]) {
			return assignStmt
		}

		// allow assignment to non-pointer children of values defined within the loop
		if isWithinLoop(assignStmt.Lhs[0], node, pass) {
			continue
		}

		return assignStmt
	}

	return nil
}

// render pretty-prints the AST node x using positions from fset and returns
// the resulting source bytes. A printing failure is wrapped with the
// "printing node" prefix and no bytes are returned.
func render(fset *token.FileSet, x interface{}) ([]byte, error) {
	out := new(bytes.Buffer)

	err := printer.Fprint(out, fset, x)
	if err != nil {
		return nil, fmt.Errorf("printing node: %w", err)
	}

	return out.Bytes(), nil
}

// isContextFunction reports whether exp is a call of the form context.<name>()
// where <name> is one of fnName. The match is purely syntactic: the callee
// must be a selector whose left operand is an identifier spelled "context".
func isContextFunction(exp ast.Expr, fnName ...string) bool {
	if call, isCall := exp.(*ast.CallExpr); isCall {
		if sel, isSelector := call.Fun.(*ast.SelectorExpr); isSelector {
			if pkg, isIdent := sel.X.(*ast.Ident); isIdent && pkg.Name == "context" {
				return slices.Contains(fnName, sel.Sel.Name)
			}
		}
	}

	return false
}

// isWithinLoop reports whether the root identifier of exp refers to an object
// whose declaring scope lies entirely inside node's source range. It returns
// false when the root identifier, its object or the object's parent scope
// cannot be determined.
func isWithinLoop(exp ast.Expr, node ast.Node, pass *analysis.Pass) bool {
	root := getRootIdent(pass, exp)
	if root == nil {
		return false
	}

	if obj := pass.TypesInfo.ObjectOf(root); obj != nil {
		if declScope := obj.Parent(); declScope != nil {
			startsInside := declScope.Pos() >= node.Pos()
			endsInside := declScope.End() <= node.End()

			return startsInside && endsInside
		}
	}

	return false
}

// getRootIdent walks down index and selector expressions to the identifier
// the expression is rooted at (e.g. `a` in `a.b[i].c`). It returns nil when a
// selector on the way is an indirect (pointer) selection, or when any other
// kind of node is encountered.
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
	switch n := node.(type) {
	case *ast.Ident:
		return n
	case *ast.IndexExpr:
		return getRootIdent(pass, n.X)
	case *ast.SelectorExpr:
		selection, known := pass.TypesInfo.Selections[n]
		if known && selection.Indirect() {
			// indirected (pointer) roots don't imply a (safe) copy
			return nil
		}

		return getRootIdent(pass, n.X)
	}

	return nil
}

// isPointer reports whether exp is a selector expression that the type
// information records as an indirect selection, i.e. one that goes through a
// pointer. Any other kind of node yields false.
func isPointer(pass *analysis.Pass, exp ast.Node) bool {
	selector, isSelector := exp.(*ast.SelectorExpr)
	if !isSelector {
		return false
	}

	selection, known := pass.TypesInfo.Selections[selector]
	if !known {
		return false
	}

	return selection.Indirect()
}