package agent

import (
	"context"
	"fmt"
	"log/slog"

	"github.com/google/cel-go/cel"
	"github.com/google/cel-go/common/types"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/kubernetes_api/agent/watch_graph"
	"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
	"k8s.io/apimachinery/pkg/runtime/schema"
	"k8s.io/apimachinery/pkg/util/sets"
)

// watchGraphRootsObserver sits in front of delegate and forwards only the part of the watch graph
// that is reachable from root vertices. A vertex is a root when isRootCEL evaluates to true for its
// object; the CEL variables come from objectToEvalVars. CEL evaluation errors are reported via onWarning.
//
// Fields:
//   - roots: IDs of vertices currently classified as roots.
//   - reachable: IDs of vertices that have been forwarded to delegate (roots and everything reachable
//     from them). Only outbound arcs of these vertices are forwarded.
//   - ignoreArcDirection: arc types that reachability also follows in reverse (to -> from).
type watchGraphRootsObserver struct {
	log                *slog.Logger
	delegate           watch_graph.ObjectGraphObserver[watch_graph.VertexID, watch_graph.ArcType, watch_graph.VertexData, watch_graph.ArcAttrs]
	isRootCEL          cel.Program
	objectToEvalVars   func(*unstructured.Unstructured, schema.GroupVersionResource) map[string]any
	onWarning          func(watch_graph.Warning)
	roots              sets.Set[watch_graph.VertexID]
	reachable          sets.Set[watch_graph.VertexID]
	ignoreArcDirection watch_graph.ArcTypeSet
}

// OnSetVertex handles creation or update of vertex vid with data vd.
// It re-evaluates whether vid is a root and reconciles the roots and reachable sets, forwarding
// the vertex to the delegate only if it is reachable from some root:
//   - root -> root: the data update is forwarded.
//   - root -> not root: reachability is recomputed from the remaining roots (which retracts what is
//     no longer reachable); the update is forwarded only if vid is still reachable.
//   - not root -> root: the update is forwarded and, if vid was not reachable before, everything
//     reachable from it is emitted too.
//   - not root -> not root: the update is forwarded if vid is reachable, or if it has an inbound arc
//     from a reachable vertex (e.g. the arc was set before the vertex existed).
func (o *watchGraphRootsObserver) OnSetVertex(ctx context.Context, insp inspector, vid watch_graph.VertexID, vd watch_graph.VertexData) {
	root := o.isRoot(ctx, vid, vd)
	wasRoot := o.roots.Has(vid)
	o.log.Debug("roots.OnSetVertex()", vidAttr(vid), vidRootAttr(root), vidWasRootAttr(wasRoot))
	if wasRoot {
		if root { // root -> root
			// Roots are always in the reachable set, so just forward the new data.
			o.delegate.OnSetVertex(ctx, insp, vid, vd)
		} else { // root -> not root
			o.roots.Delete(vid)
			// Losing a root can only shrink reachability; the rebuild retracts what is no longer reachable.
			o.rebuildFromRoots(ctx, insp)
			if o.reachable.Has(vid) { // if it is still reachable, update the data.
				o.delegate.OnSetVertex(ctx, insp, vid, vd)
			}
		}
	} else {
		if root { // not root -> root
			o.roots.Insert(vid)
			o.delegate.OnSetVertex(ctx, insp, vid, vd)
			if o.reachable.Has(vid) {
				// Vertex is already reachable - update the data only (above).
			} else {
				// Vertex was not reachable - update the data (above) and emit all reachable via it.
				o.reachable.Insert(vid)
				// This may be a new or an existing vertex that became root now.
				// Hence, emit what's reachable from it.
				o.emitReachableFrom(ctx, insp, vid)
			}
		} else { // not root -> not root
			if o.reachable.Has(vid) {
				// Vertex is reachable, update the data.
				o.delegate.OnSetVertex(ctx, insp, vid, vd)
			} else {
				// If this vertex did not exist before, but there is an arc to it already,
				// it may be reachable now. aid.To is the source vertex of the inbound arc.
				for aid := range insp.InboundArcsFor(vid) {
					if o.reachable.Has(aid.To) {
						o.reachable.Insert(vid)
						o.delegate.OnSetVertex(ctx, insp, vid, vd)
						// NOTE(review): outbound arcs that vid may already have are not emitted here
						// (unlike OnSetArc, which calls emitReachableFrom). This presumably relies on them
						// arriving later via OnSetArc - confirm the event ordering guarantees that.
						break
					}
				}
			}
		}
	}
}

// OnDeleteVertex handles removal of vertex vid.
// Vertices that were never reachable were never forwarded, so nothing is retracted for them.
// For a reachable vertex, its outbound arcs and the vertex itself are retracted from the delegate.
// Reachability is then recomputed from the remaining roots, since vid may have been the only
// path to other vertices.
func (o *watchGraphRootsObserver) OnDeleteVertex(ctx context.Context, insp inspector, vid watch_graph.VertexID) {
	if o.reachable.Has(vid) {
		o.log.Debug("roots.OnDeleteVertex(): reachable", vidAttr(vid))
		// vid is dropped from both sets; it may or may not have been a root.
		o.reachable.Delete(vid)
		o.roots.Delete(vid)
		// vid is no longer in o.reachable, so rebuildFromRoots() will not visit it: retract it explicitly.
		o.deleteVertex(ctx, insp, vid)
		o.rebuildFromRoots(ctx, insp)
		return
	}
	o.log.Debug("roots.OnDeleteVertex(): not reachable", vidAttr(vid))
}

// OnSetArc handles creation or update of the arc from -> to.To (of type to.ArcType) with attributes data.
// Arcs are forwarded to the delegate only when 'from' is reachable. Setting an arc can extend reachability:
//   - 'from' is reachable and 'to' exists but was not reachable: 'to' and everything reachable from it
//     are emitted.
//   - 'from' is not reachable, 'to' is reachable and the arc type is in ignoreArcDirection: 'from' is
//     reached via 'to', so 'from' and everything reachable from it (including this arc) are emitted.
//
// An arc to a vertex that does not exist (yet) is still forwarded when 'from' is reachable.
func (o *watchGraphRootsObserver) OnSetArc(ctx context.Context, insp inspector, from watch_graph.VertexID, to watch_graph.ArcToID[watch_graph.VertexID, watch_graph.ArcType], data watch_graph.ArcAttrs) {
	toIsReachable := o.reachable.Has(to.To)

	if o.reachable.Has(from) {
		if toIsReachable {
			o.log.Debug("roots.OnSetArc(): 'from' and 'to' are reachable", arcAttr(from, to))
			o.delegate.OnSetArc(ctx, insp, from, to, data)
		} else {
			vd, isSet := insp.VertexData(to.To)
			if isSet {
				o.log.Debug("roots.OnSetArc(): 'from' is reachable, 'to' is reachable now too", arcAttr(from, to))
				o.reachable.Insert(to.To)
				o.delegate.OnSetVertex(ctx, insp, to.To, vd) // emit 'to' vertex before emitting from->to arc
				o.delegate.OnSetArc(ctx, insp, from, to, data)
				// This is the first from->to arc - emit everything that's reachable from 'to'.
				o.emitReachableFrom(ctx, insp, to.To)
			} else {
				o.log.Debug("roots.OnSetArc(): 'from' is reachable, 'to' does not exist", arcAttr(from, to))
				o.delegate.OnSetArc(ctx, insp, from, to, data)
			}
		}
	} else {
		if toIsReachable && o.ignoreArcDirection.Contains(to.ArcType) {
			o.log.Debug("roots.OnSetArc(): 'from' was not reachable but it can be reached via 'to' now", arcAttr(from, to))
			o.reachable.Insert(from)
			// 'from' presumably always has data here because it has an outbound arc; isSet is not checked.
			vd, _ := insp.VertexData(from)
			o.delegate.OnSetVertex(ctx, insp, from, vd) // emit 'from' vertex before emitting from->to arc
			// This is the first from->to arc that is undirected - emit everything that's reachable from 'from'.
			o.emitReachableFrom(ctx, insp, from) // this will emit the new arc too
		} else {
			o.log.Debug("roots.OnSetArc(): 'from' is not reachable and neither is 'to'", arcAttr(from, to))
			// Nothing to do
		}
	}
}

// OnDeleteArc handles removal of the arc from -> to.To (of type to.ArcType).
// Arcs from unreachable vertices were never forwarded to the delegate, so they are ignored.
// Otherwise the removal is forwarded. Reachability is then recomputed if the removed arc may have
// been what connected a vertex to a root:
//   - 'to' may depend on it unless another from->to arc (of any type) remains.
//   - If its type is in ignoreArcDirection, 'from' may have been reachable only via 'to'. It stays
//     reachable only if another from->to arc whose type is in ignoreArcDirection remains.
func (o *watchGraphRootsObserver) OnDeleteArc(ctx context.Context, insp inspector, from watch_graph.VertexID, to watch_graph.ArcToID[watch_graph.VertexID, watch_graph.ArcType]) {
	if !o.reachable.Has(from) {
		o.log.Debug("roots.OnDeleteArc(): 'from' is not reachable", arcAttr(from, to))
		return // it's not reachable, so nothing to do.
	}
	o.log.Debug("roots.OnDeleteArc(): 'from' is reachable", arcAttr(from, to))
	// The 'from' vertex is reachable, so this arc was forwarded to the delegate.
	o.delegate.OnDeleteArc(ctx, insp, from, to)
	if !isArcFromTo(insp, from, to.To) {
		// That was the last from->to arc: 'to' may have become unreachable.
		o.rebuildFromRoots(ctx, insp)
		return
	}
	// Another from->to arc remains, so 'to' is still reachable via 'from'.
	if !o.ignoreArcDirection.Contains(to.ArcType) {
		return // A directed arc could not have made 'from' reachable - nothing changed.
	}
	// The removed arc was also traversable in reverse, so 'from' might have been reachable only via 'to'.
	// That path survives only if a remaining from->to arc is traversable in reverse as well.
	for aid := range insp.OutboundArcsFor(from) {
		if aid.To == to.To && o.ignoreArcDirection.Contains(aid.ArcType) {
			return
		}
	}
	o.rebuildFromRoots(ctx, insp)
}

// isRoot evaluates the isRootCEL program against the object in vd and reports whether vid is a root.
// If evaluation fails, or the program produces a non-bool value (possible for dynamically typed
// expressions), a warning is reported via onWarning and the vertex is treated as not a root.
// A non-bool result must not crash the agent with a failed type assertion.
func (o *watchGraphRootsObserver) isRoot(ctx context.Context, vid watch_graph.VertexID, vd watch_graph.VertexData) bool {
	gvr := vid.GVR.Value()
	val, _, err := o.isRootCEL.ContextEval(ctx, o.objectToEvalVars(&unstructured.Unstructured{Object: vd.Object}, gvr))
	if err != nil {
		o.onWarning(watch_graph.NewObjectProcessingWarning(gvr, vid.Namespace, vid.Name, fmt.Sprintf("CEL: %v", err)))
		return false
	}
	b, ok := val.(types.Bool)
	if !ok {
		o.onWarning(watch_graph.NewObjectProcessingWarning(gvr, vid.Namespace, vid.Name, fmt.Sprintf("CEL: expected bool result, got %T", val)))
		return false
	}
	return bool(b)
}

// emitReachableFrom walks the graph from vid and forwards to the delegate every vertex that becomes
// reachable along the way, plus the outbound arcs of each visited vertex.
// The walker shares o.reachable (it is not copied), so newly visited vertices are recorded there directly.
func (o *watchGraphRootsObserver) emitReachableFrom(ctx context.Context, insp inspector, vid watch_graph.VertexID) {
	wd := &walkerWithDelegate{
		ctx:                ctx,
		delegate:           o.delegate,
		insp:               insp,
		reachable:          o.reachable,
		ignoreArcDirection: o.ignoreArcDirection,
	}
	wd.emitReachableFrom(vid)
}

// rebuildFromRoots recomputes the reachable set from scratch by walking the graph from every root.
// Every previously reachable vertex that is no longer reachable is retracted via deleteVertex, and
// then o.reachable is replaced with the recomputed set.
//
// It is called after operations that can only shrink reachability (a root demoted, a vertex or an
// arc removed), so vertices that are reachable now but were not before are not emitted here.
func (o *watchGraphRootsObserver) rebuildFromRoots(ctx context.Context, insp inspector) {
	// Structured attributes rather than fmt.Sprintf: formatting is left to the handler and is skipped
	// entirely when debug logging is disabled.
	o.log.Debug("roots.rebuildFromRoots()", slog.Int("roots", o.roots.Len()))
	w := walker{
		insp:               insp,
		reachable:          o.roots.Clone(), // roots are reachable by definition
		ignoreArcDirection: o.ignoreArcDirection,
	}
	for vid := range o.roots {
		o.log.Debug("roots.rebuildFromRoots(): collecting from", vidAttr(vid))
		w.collectReachableFrom(vid)
	}
	o.log.Debug("roots.rebuildFromRoots(): recomputed reachability",
		slog.Int("reachable_before", o.reachable.Len()), slog.Int("reachable_after", w.reachable.Len()))
	for vid := range o.reachable {
		if w.reachable.Has(vid) {
			o.log.Debug("roots.rebuildFromRoots(): still reachable", vidAttr(vid))
			continue // still reachable
		}
		o.log.Debug("roots.rebuildFromRoots(): unreachable", vidAttr(vid))
		o.deleteVertex(ctx, insp, vid)
	}
	o.reachable = w.reachable
}

// deleteVertex retracts vid from the delegate. Every outbound arc of vid (as currently reported
// by insp) is retracted first, then the vertex itself.
func (o *watchGraphRootsObserver) deleteVertex(ctx context.Context, insp inspector, vid watch_graph.VertexID) {
	d := o.delegate
	for arc := range insp.OutboundArcsFor(vid) {
		d.OnDeleteArc(ctx, insp, vid, arc)
	}
	d.OnDeleteVertex(ctx, insp, vid)
}

// isArcFromTo reports whether insp currently has at least one arc, of any type, from 'from' to 'to'.
func isArcFromTo(insp inspector, from, to watch_graph.VertexID) bool {
	found := false
	for arc := range insp.OutboundArcsFor(from) {
		if found = arc.To == to; found {
			break
		}
	}
	return found
}

// walkerWithDelegate walks the graph from a vertex and forwards newly reachable vertices and
// the outbound arcs of visited vertices to delegate. It follows arcs in their direction, and
// also in reverse for arc types in ignoreArcDirection.
// reachable holds the vertices already forwarded; it is updated in place as new ones are found.
//
// walkerWithDelegate stores as many things as possible in fields so that it consumes less stack
// when it recursively walks the graph.
type walkerWithDelegate struct {
	ctx                context.Context //nolint:containedctx
	delegate           watch_graph.ObjectGraphObserver[watch_graph.VertexID, watch_graph.ArcType, watch_graph.VertexData, watch_graph.ArcAttrs]
	insp               inspector
	reachable          sets.Set[watch_graph.VertexID]
	ignoreArcDirection watch_graph.ArcTypeSet
}

// emitReachableFrom forwards every outbound arc of vid. Each arc's target vertex is forwarded
// first if it exists and is newly reachable. Arc sources reachable through reverse-traversable
// inbound arcs are forwarded too. All vertices discovered at this level are collected before
// the walk recurses into them, one at a time.
func (w *walkerWithDelegate) emitReachableFrom(vid watch_graph.VertexID) {
	var frontier []watch_graph.VertexID
	for arcID, arcAttrs := range w.insp.OutboundArcsFor(vid) {
		// The target vertex (if new) is emitted before the arc that points to it.
		w.appendIfNotReachableAlready(&frontier, arcID.To)
		w.delegate.OnSetArc(w.ctx, w.insp, vid, arcID, arcAttrs)
	}
	if !w.ignoreArcDirection.IsEmpty() {
		for arcID := range w.insp.InboundArcsFor(vid) {
			if w.ignoreArcDirection.Contains(arcID.ArcType) {
				// arcID.To is the source of the inbound arc; its own outbound arcs
				// (including the one to vid) are emitted when the walk recurses into it.
				w.appendIfNotReachableAlready(&frontier, arcID.To)
			}
		}
	}
	for i := range frontier {
		w.emitReachableFrom(frontier[i])
	}
}

// appendIfNotReachableAlready marks vid as reachable, appends it to *newReachable and forwards it
// to the delegate. It does nothing if vid has no data (the vertex does not exist) or was already
// reachable.
// A missing vertex can only occur for targets of outbound arcs. Sources of inbound arcs always
// exist because they have outbound arcs.
func (w *walkerWithDelegate) appendIfNotReachableAlready(newReachable *[]watch_graph.VertexID, vid watch_graph.VertexID) {
	data, exists := w.insp.VertexData(vid)
	if !exists || w.reachable.Has(vid) {
		return // never emit a nonexistent vertex, and never emit the same vertex twice
	}
	w.reachable.Insert(vid)
	*newReachable = append(*newReachable, vid)
	w.delegate.OnSetVertex(w.ctx, w.insp, vid, data)
}

// walker collects, into reachable, every existing vertex reachable from the vertices it is started
// from. It follows arcs in their direction, and also in reverse for arc types in ignoreArcDirection.
// Unlike walkerWithDelegate it emits nothing; it only computes the set.
//
// walker stores as many things as possible in fields so that it consumes less stack
// when it recursively walks the graph.
// It's a simpler and faster version of walkerWithDelegate to minimize overhead.
type walker struct {
	insp               inspector
	reachable          sets.Set[watch_graph.VertexID]
	ignoreArcDirection watch_graph.ArcTypeSet
}

// collectReachableFrom visits every neighbor of vid. Neighbors are targets of its outbound arcs and,
// for arc types in ignoreArcDirection, sources of its inbound arcs. Each neighbor is handed to processArc.
func (w *walker) collectReachableFrom(vid watch_graph.VertexID) {
	for arc := range w.insp.OutboundArcsFor(vid) {
		w.processArc(arc.To)
	}
	if !w.ignoreArcDirection.IsEmpty() {
		for arc := range w.insp.InboundArcsFor(vid) {
			if w.ignoreArcDirection.Contains(arc.ArcType) {
				w.processArc(arc.To) // arc.To is the source of the inbound arc
			}
		}
	}
}

// processArc marks 'to' as reachable and continues the walk from it. It does nothing when the
// vertex does not exist or has already been visited.
func (w *walker) processArc(to watch_graph.VertexID) {
	if _, exists := w.insp.VertexData(to); !exists || w.reachable.Has(to) {
		return // destination vertex does not exist or was already visited
	}
	w.reachable.Insert(to)
	w.collectReachableFrom(to)
}
