package router

import (
	"context"
	"log/slog"
	"sync"
	"time"

	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/fieldz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/mathz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/nettool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/syncz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tunnel/tunserver"
	"go.opentelemetry.io/otel/trace"
)

// pollingContext is the per-agent-key polling state shared by all concurrent
// PollGatewayURLs() callers for that key.
//
//   - subs fans out each non-empty poll result to the subscribed listeners.
//   - cancel stops the poll goroutine; it is set when the first consumer arrives
//     and cleared when the last consumer leaves.
//   - kasURLs is the most recently cached poll result, served by CachedGatewayURLs().
//   - stoppedAt records when polling was last stopped; isExpired() uses it for GC.
//
// Every field except subs is read and written with AggregatingQuerier.mu held.
type pollingContext struct {
	subs      syncz.Subscriptions[[]nettool.MultiURL] // Subscribe() is called under AggregatingQuerier.mu for consistency, but Dispatch() and unsubscribes are not - syncz.Subscriptions has its own mutex.
	cancel    context.CancelFunc
	kasURLs   []nettool.MultiURL
	stoppedAt time.Time
	// consumers is the number of active PollGatewayURLs() callers for this context.
	// subs.Len() cannot be used instead because unsubscribes in PollGatewayURLs() happen
	// without holding AggregatingQuerier.mu, i.e. several can happen concurrently.
	consumers int32
}

// isExpired reports whether the context is idle (no consumers) and polling was
// stopped strictly before the given instant. It reads fields guarded by
// AggregatingQuerier.mu, so it is only called with that mutex held (see runGC).
func (c *pollingContext) isExpired(before time.Time) bool {
	if c.consumers != 0 {
		// Still in use; never eligible for collection.
		return false
	}
	return c.stoppedAt.Before(before)
}

// AggregatingQuerier groups polling requests: all concurrent PollGatewayURLs() callers
// for the same agent key share a single poll goroutine that queries delegate every
// pollPeriod, and the latest result is cached per agent key for CachedGatewayURLs().
// When the last consumer of a key leaves, its state is dropped right away if it has
// no cached URLs; otherwise it is kept until the GC loop in Run() expires it.
type AggregatingQuerier struct {
	log        *slog.Logger
	delegate   Querier
	api        modshared.API
	tracer     trace.Tracer
	pollPeriod time.Duration
	gcPeriod   time.Duration

	// mu guards listeners and every pollingContext field other than subs.
	mu        sync.Mutex
	listeners map[api.AgentKey]*pollingContext
}

// NewAggregatingQuerier constructs an AggregatingQuerier that fetches KAS URLs through
// delegate, re-polls every pollPeriod while a key has consumers, and (once Run is
// started) garbage-collects idle per-key state roughly every gcPeriod.
// Errors from delegate are reported through modAPI; spans are created with tracer.
func NewAggregatingQuerier(log *slog.Logger, delegate Querier, modAPI modshared.API, tracer trace.Tracer, pollPeriod, gcPeriod time.Duration) *AggregatingQuerier {
	q := &AggregatingQuerier{
		listeners:  make(map[api.AgentKey]*pollingContext),
		gcPeriod:   gcPeriod,
		pollPeriod: pollPeriod,
		tracer:     tracer,
		api:        modAPI,
		delegate:   delegate,
		log:        log,
	}
	return q
}

// Run performs periodic garbage collection of expired polling contexts until ctx is
// done, then returns nil. The tick interval is gcPeriod adjusted once, at startup, by
// mathz.DurationWithPositiveJitter with gcPeriodJitterPercentage.
func (q *AggregatingQuerier) Run(ctx context.Context) error {
	ticker := time.NewTicker(mathz.DurationWithPositiveJitter(q.gcPeriod, gcPeriodJitterPercentage))
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			q.runGC()
		case <-ctx.Done():
			return nil
		}
	}
}

// runGC removes every polling context that has no consumers and was stopped more
// than gcPeriod ago. The cutoff is computed before taking the lock.
func (q *AggregatingQuerier) runGC() {
	cutoff := time.Now().Add(-q.gcPeriod)

	q.mu.Lock()
	defer q.mu.Unlock()

	for key, pc := range q.listeners {
		if !pc.isExpired(cutoff) {
			continue
		}
		// Deleting the current key while ranging over a map is permitted in Go.
		delete(q.listeners, key)
	}
}

// PollGatewayURLs registers the caller as a consumer of agentKey's shared poller and
// forwards every dispatched (non-empty) URL list to cb. It returns when the listener
// returned by maybeStartPolling returns (presumably once ctx is done — depends on
// syncz.Listen), and always releases its consumer slot on the way out.
func (q *AggregatingQuerier) PollGatewayURLs(ctx context.Context, agentKey api.AgentKey, cb tunserver.PollGatewayURLsCallback[nettool.MultiURL]) {
	// The dispatch context is not needed by cb, so it is discarded.
	forward := func(_ context.Context, kasURLs []nettool.MultiURL) {
		cb(kasURLs)
	}

	listen := q.maybeStartPolling(ctx, agentKey)
	defer q.maybeStopPolling(agentKey)

	listen(forward)
}

// CachedGatewayURLs returns the KAS URLs most recently stored by the poller for
// agentKey, or nil when no polling state exists for that key. The slice is returned
// without copying, so callers should treat it as read-only.
func (q *AggregatingQuerier) CachedGatewayURLs(agentKey api.AgentKey) []nettool.MultiURL {
	q.mu.Lock()
	defer q.mu.Unlock()

	if pc := q.listeners[agentKey]; pc != nil {
		return pc.kasURLs
	}
	return nil
}

// maybeStartPolling adds one consumer for agentKey and returns the listener bound to
// ctx. The pollingContext is created on first use. The first consumer (consumers == 0)
// also launches the poll goroutine. That goroutine uses a context derived from
// context.Background() rather than ctx, so it keeps running after this caller leaves.
// It is only cancelled by maybeStopPolling.
func (q *AggregatingQuerier) maybeStartPolling(ctx context.Context, agentKey api.AgentKey) syncz.Listen[[]nettool.MultiURL] {
	q.mu.Lock()
	defer q.mu.Unlock()

	pc := q.listeners[agentKey]
	if pc == nil {
		pc = &pollingContext{subs: *syncz.NewSubscriptions[[]nettool.MultiURL]()}
		q.listeners[agentKey] = pc
	}

	// Subscribe first: the poller must not be able to dispatch before this listener exists.
	listener := pc.subs.Subscribe(ctx)

	if pc.consumers == 0 {
		pollCtx, cancelPoll := context.WithCancel(context.Background())
		pc.cancel = cancelPoll
		go q.poll(pollCtx, agentKey, pc) //nolint: contextcheck
	}
	pc.consumers++

	return listener
}

// maybeStopPolling releases one consumer of agentKey. When the last consumer leaves it
// does three things: it cancels the poll goroutine, drops the cancel func, and records
// stoppedAt. The pollingContext is then kept only if it holds cached URLs, so that
// CachedGatewayURLs can still serve them until runGC expires it. Otherwise it is
// deleted at once.
func (q *AggregatingQuerier) maybeStopPolling(agentKey api.AgentKey) {
	q.mu.Lock()
	defer q.mu.Unlock()

	pc := q.listeners[agentKey]
	pc.consumers--
	if pc.consumers > 0 {
		return // others still depend on the poller
	}

	// Last consumer gone: shut the poller down.
	pc.cancel()
	pc.cancel = nil // drop the reference so it can be garbage collected
	pc.stoppedAt = time.Now()

	if len(pc.kasURLs) > 0 {
		return // keep the cache; runGC removes it once expired
	}
	delete(q.listeners, agentKey)
}

// poll queries the delegate for agentKey immediately and then every pollPeriod
// (measured from the end of the previous query) until ctx is done. The whole loop
// runs inside an "AggregatingQuerier.poll" tracing span.
func (q *AggregatingQuerier) poll(ctx context.Context, agentKey api.AgentKey, pc *pollingContext) {
	ctx, span := q.tracer.Start(ctx, "AggregatingQuerier.poll",
		trace.WithSpanKind(trace.SpanKindInternal),
		trace.WithAttributes(
			api.TraceAgentIDAttr.Int64(agentKey.ID),
			api.TraceAgentTypeAttr.String(agentKey.Type.String()),
		),
	)
	defer span.End()

	timer := time.NewTimer(0) // zero delay: first query happens right away
	defer timer.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-timer.C:
			q.pollOnce(ctx, agentKey, pc)
			timer.Reset(q.pollPeriod)
		}
	}
}

// pollOnce runs a single delegate query for agentKey and publishes the outcome.
// Errors are reported via HandleProcessingError. Any returned URLs are still used:
// they replace pc.kasURLs, except when the query failed and returned nothing, in
// which case the previous cache is kept. Non-empty results are also dispatched to
// pc.subs.
func (q *AggregatingQuerier) pollOnce(ctx context.Context, agentKey api.AgentKey, pc *pollingContext) {
	kasURLs, err := q.delegate.KASURLsByAgentKey(ctx, agentKey)
	if err != nil {
		q.api.HandleProcessingError(ctx, q.log, "KASURLsByAgentKey() failed", err, fieldz.AgentKey(agentKey))
	}

	// A failed lookup that yielded no URLs must not erase the cache.
	if err == nil || len(kasURLs) > 0 {
		q.mu.Lock()
		pc.kasURLs = kasURLs
		q.mu.Unlock()
	}

	if len(kasURLs) > 0 {
		pc.subs.Dispatch(ctx, kasURLs)
	}
}
