package agent_tracker //nolint:staticcheck

import (
	"context"
	"fmt"
	"iter"
	"log/slog"
	"strconv"
	"time"

	"buf.build/go/protovalidate"
	"github.com/redis/rueidis"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/errz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/logz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/mathz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/redistool"
	otelmetric "go.opentelemetry.io/otel/metric"
	"golang.org/x/sync/errgroup"
	"google.golang.org/protobuf/proto"
)

// Fixed keys for hashes that hold a single Redis key. The values are passed to the
// key functions but ignored by them (see connectedAgentsHashKey and agentVersionsHashKey).
const (
	// connectedAgentsKey is the (ignored) key of the connectedAgents hash.
	connectedAgentsKey int64 = 0
	// agentVersionKey is not actually used as a key. See `agentVersionsHashKey` function.
	agentVersionKey int64 = 0
)

// Hash names, passed to redistool.NewRedisExpiringHash and used in GC error messages.
const (
	agentkConnectionsByIDHashName     = "connections_by_agent_id"
	agentwConnectionsByIDHashName     = "agentw_connections_by_id"
	connectedAgentsHashName           = "connected_agents"
	agentVersionsHashName             = "agent_versions"
	connectionsByAgentVersionHashName = "connections_by_agent_version"
)

// gcPeriodJitterPercentage is the percentage of the configured GC period that's added as jitter.
const gcPeriodJitterPercentage = 5

// ExpiringRegisterer records agent connections with a limited lifetime.
// Each registration lapses unless it is written again before it expires.
type ExpiringRegisterer interface {
	// RegisterAgentkExpiring records (or refreshes) an agentk connection.
	// The record expires unless this method is called again for it in time.
	RegisterAgentkExpiring(ctx context.Context, info *ConnectedAgentkInfo) error
	// UnregisterAgentk removes an agentk connection record right away. Use it for an
	// intentional disconnect; otherwise the record would only go away once it expires.
	UnregisterAgentk(ctx context.Context, info *DisconnectAgentkInfo) error

	// RegisterAgentwExpiring records (or refreshes) an agentw connection.
	// The record expires unless this method is called again for it in time.
	RegisterAgentwExpiring(ctx context.Context, info *ConnectedAgentwInfo) error
	// UnregisterAgentw removes an agentw connection record right away. Use it for an
	// intentional disconnect; otherwise the record would only go away once it expires.
	UnregisterAgentw(ctx context.Context, info *DisconnectAgentwInfo) error
}

// Querier reads back what has been registered with the tracker.
type Querier interface {
	// GetAgentkConnectionsByID returns an iterator over the agentk connections
	// registered for the given agent ID.
	GetAgentkConnectionsByID(ctx context.Context, agentID int64) iter.Seq[*ConnectedAgentkInfo]
	// GetAgentwConnectionsByID returns an iterator over the agentw connections
	// registered under the given ID.
	// NOTE(review): RedisTracker stores agentw connections by workspace ID, so this ID is
	// presumably a workspace ID — confirm against callers.
	GetAgentwConnectionsByID(ctx context.Context, agentID int64) iter.Seq[*ConnectedAgentwInfo]
	// GetConnectedAgentsCount returns how many agents are currently tracked as connected.
	GetConnectedAgentsCount(ctx context.Context) (int64, error)
	// CountAgentsByAgentVersions returns a count per tracked agent version.
	CountAgentsByAgentVersions(ctx context.Context) (map[string]int64, error)
}

// Tracker is the full agent tracker: registration, queries, and a blocking Run loop
// for background maintenance.
type Tracker interface {
	ExpiringRegisterer
	Querier
	// Run blocks doing background work until ctx is done.
	Run(ctx context.Context) error
}

// RedisTracker implements agent connection tracking on top of Redis expiring hashes.
// Every record it writes expires ttl after its most recent registration.
type RedisTracker struct {
	log       *slog.Logger
	errRep    errz.ErrReporter
	validator protovalidate.Validator
	ttl       time.Duration // lifetime of a registration, counted from its latest write
	gcPeriod  time.Duration // base interval between GC passes; Run adds positive jitter

	agentkConnectionsByID redistool.ExpiringHash[int64, int64] // agentID -> connectionId -> marshaled ConnectedAgentkInfo
	agentwConnectionsByID redistool.ExpiringHash[int64, int64] // workspaceID -> connectionId -> marshaled ConnectedAgentwInfo
	connectedAgents       redistool.ExpiringHash[int64, int64] // single hash -> agentID -> nil value
	// agentVersions keeps track of the agent versions that have at least one registered connection.
	agentVersions redistool.ExpiringHash[int64, string] // single hash -> agent version -> nil value
	// connectionsByAgentVersion stores per-connection data for each agent version.
	connectionsByAgentVersion redistool.ExpiringHash[string, int64] // agentVersion -> connectionId -> marshaled AgentPodInfo
}

// NewRedisTracker builds a RedisTracker whose Redis keys all start with agentKeyPrefix.
//
// Every registration expires ttl after it was last written. gcPeriod is the base interval
// between GC passes performed by Run. m is handed to each underlying hash.
// NOTE(review): the trailing bool passed to NewRedisExpiringHash differs per hash (true
// for the per-ID connection hashes, false otherwise); its meaning is defined in redistool.
func NewRedisTracker(log *slog.Logger, errRep errz.ErrReporter, v protovalidate.Validator, client rueidis.Client,
	agentKeyPrefix string, ttl, gcPeriod time.Duration, m otelmetric.Meter) (*RedisTracker, error) {
	agentkConnectionsByID, err := redistool.NewRedisExpiringHash(agentkConnectionsByIDHashName, client, connectionsByAgentIDHashKey(agentKeyPrefix), int64ToStr, ttl, m, true)
	if err != nil {
		return nil, err
	}
	// agentw connections are keyed by workspace ID. They must not share Redis keys with
	// agentk connections: nothing guarantees workspace IDs and agent IDs are disjoint, and
	// with shared keys a scan of one hash would return records of the other message type.
	// NOTE(review): this moves agentw records to a new key; during a rolling upgrade, old
	// and new instances will not see each other's agentw registrations until all are
	// upgraded (records expire after ttl anyway).
	agentwConnectionsByID, err := redistool.NewRedisExpiringHash(agentwConnectionsByIDHashName, client, agentwConnectionsByIDHashKey(agentKeyPrefix), int64ToStr, ttl, m, true)
	if err != nil {
		return nil, err
	}
	connectedAgents, err := redistool.NewRedisExpiringHash(connectedAgentsHashName, client, connectedAgentsHashKey(agentKeyPrefix), int64ToStr, ttl, m, false)
	if err != nil {
		return nil, err
	}
	agentVersions, err := redistool.NewRedisExpiringHash(agentVersionsHashName, client, agentVersionsHashKey(agentKeyPrefix), strToStr, ttl, m, false)
	if err != nil {
		return nil, err
	}
	connectionsByAgentVersion, err := redistool.NewRedisExpiringHash(connectionsByAgentVersionHashName, client, connectionsByAgentVersionHashKey(agentKeyPrefix), int64ToStr, ttl, m, false)
	if err != nil {
		return nil, err
	}

	return &RedisTracker{
		log:                       log,
		errRep:                    errRep,
		validator:                 v,
		ttl:                       ttl,
		gcPeriod:                  gcPeriod,
		agentkConnectionsByID:     agentkConnectionsByID,
		agentwConnectionsByID:     agentwConnectionsByID,
		connectedAgents:           connectedAgents,
		agentVersions:             agentVersions,
		connectionsByAgentVersion: connectionsByAgentVersion,
	}, nil
}

// agentwConnectionsByIDHashKey returns a key for workspaceID -> (connectionId -> marshaled ConnectedAgentwInfo).
// It uses its own prefix so agentw records never land in agentk connection keys.
func agentwConnectionsByIDHashKey(agentKeyPrefix string) redistool.KeyToRedisKey[int64] {
	prefix := agentKeyPrefix + ":agentw_conn_by_id:"
	return func(workspaceID int64) string {
		return redistool.PrefixedInt64Key(prefix, workspaceID)
	}
}

// Run garbage-collects expired records periodically until ctx is done, then returns nil.
// The period is t.gcPeriod with positive jitter (gcPeriodJitterPercentage) applied once,
// when Run starts.
func (t *RedisTracker) Run(ctx context.Context) error {
	ticker := time.NewTicker(mathz.DurationWithPositiveJitter(t.gcPeriod, gcPeriodJitterPercentage))
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil
		case <-ticker.C:
			if removed := t.runGC(ctx); removed > 0 {
				t.log.Info("Deleted expired agent connections records", logz.RemovedHashKeys(removed))
			}
		}
	}
}

// RegisterAgentkExpiring validates info and (re)writes the connection into every
// agentk-related hash, all with the same expiry of now+ttl.
//
// The four writes run concurrently. The first error (if any) is returned after all of
// them have finished.
func (t *RedisTracker) RegisterAgentkExpiring(ctx context.Context, info *ConnectedAgentkInfo) error {
	if err := t.validator.Validate(info); err != nil {
		// This should never happen
		return fmt.Errorf("invalid ConnectedAgentkInfo: %w", err)
	}
	infoBytes, err := proto.Marshal(info)
	if err != nil {
		// This should never happen
		return fmt.Errorf("proto.Marshal: %w", err)
	}
	exp := time.Now().Add(t.ttl)

	writes := [...]func() error{
		// agentID -> connectionId -> full connection info.
		func() error {
			return t.agentkConnectionsByID.SetEX(ctx, info.AgentId, info.ConnectionId, infoBytes, exp)
		},
		// Mark the agent as connected.
		func() error {
			return t.connectedAgents.SetEX(ctx, connectedAgentsKey, info.AgentId, nil, exp)
		},
		// agentVersion -> connectionId -> AgentPodInfo.
		func() error {
			podInfoBytes, marshalErr := proto.Marshal(&AgentPodInfo{
				AgentId: info.AgentId,
				PodId:   info.ConnectionId,
			})
			if marshalErr != nil {
				// This should never happen
				return fmt.Errorf("failed to marshal AgentPodInfo object: %w", marshalErr)
			}
			return t.connectionsByAgentVersion.SetEX(ctx, info.AgentMeta.Version, info.ConnectionId, podInfoBytes, exp)
		},
		// Record that this agent version has a live connection.
		func() error {
			return t.agentVersions.SetEX(ctx, agentVersionKey, info.AgentMeta.Version, nil, exp)
		},
	}

	var g errgroup.Group
	for _, write := range writes {
		g.Go(write)
	}
	return g.Wait()
}

// RegisterAgentwExpiring validates info and (re)writes it into the agentw connections hash
// under info.WorkspaceId / info.ConnectionId, expiring now+ttl.
func (t *RedisTracker) RegisterAgentwExpiring(ctx context.Context, info *ConnectedAgentwInfo) error {
	if err := t.validator.Validate(info); err != nil {
		// This should never happen
		return fmt.Errorf("invalid ConnectedAgentwInfo: %w", err)
	}
	data, err := proto.Marshal(info)
	if err != nil {
		// This should never happen
		return fmt.Errorf("proto.Marshal: %w", err)
	}
	expiresAt := time.Now().Add(t.ttl)
	return t.agentwConnectionsByID.SetEX(ctx, info.WorkspaceId, info.ConnectionId, data, expiresAt)
}

// UnregisterAgentk removes the agentk connection described by info from the tracking hashes.
//
// Deletes run concurrently. The first error (if any) is returned after all of them finish.
// The connectedAgents and agentVersions entries are deleted even if other connections for
// the same agent ID / version are still registered.
// NOTE(review): presumably those connections re-add the entries on their next
// RegisterAgentkExpiring refresh — confirm this eventual consistency is intended.
//
// info is not validated here, so info.AgentMeta may be nil. In that case the version-keyed
// entries cannot be located; they are left to expire on their own instead of panicking
// inside a goroutine.
func (t *RedisTracker) UnregisterAgentk(ctx context.Context, info *DisconnectAgentkInfo) error {
	var wg errgroup.Group
	wg.Go(func() error {
		return t.agentkConnectionsByID.DelEX(ctx, info.AgentId, info.ConnectionId)
	})
	wg.Go(func() error {
		return t.connectedAgents.DelEX(ctx, connectedAgentsKey, info.AgentId)
	})
	if meta := info.AgentMeta; meta != nil {
		version := meta.Version
		wg.Go(func() error {
			return t.connectionsByAgentVersion.DelEX(ctx, version, info.ConnectionId)
		})
		wg.Go(func() error {
			return t.agentVersions.DelEX(ctx, agentVersionKey, version)
		})
	}
	return wg.Wait()
}

// UnregisterAgentw deletes the agentw connection record stored under
// info.WorkspaceId / info.ConnectionId.
func (t *RedisTracker) UnregisterAgentw(ctx context.Context, info *DisconnectAgentwInfo) error {
	workspaceID, connectionID := info.WorkspaceId, info.ConnectionId
	return t.agentwConnectionsByID.DelEX(ctx, workspaceID, connectionID)
}

// GetAgentkConnectionsByID returns a lazy iterator over the agentk connections registered
// for agentID. Records that cannot be scanned, decoded or validated are reported via
// errRep and skipped.
func (t *RedisTracker) GetAgentkConnectionsByID(ctx context.Context, agentID int64) iter.Seq[*ConnectedAgentkInfo] {
	hash := t.agentkConnectionsByID
	return yieldConnectedAgentInfo[*ConnectedAgentkInfo, ConnectedAgentkInfo](ctx, t.log, agentID, t.errRep, t.validator, hash)
}

// GetAgentwConnectionsByID returns a lazy iterator over the agentw connections stored
// under agentID. RegisterAgentwExpiring stores records by ConnectedAgentwInfo.WorkspaceId,
// so the value passed here is matched against that workspace ID. Records that cannot be
// scanned, decoded or validated are reported via errRep and skipped.
func (t *RedisTracker) GetAgentwConnectionsByID(ctx context.Context, agentID int64) iter.Seq[*ConnectedAgentwInfo] {
	hash := t.agentwConnectionsByID
	return yieldConnectedAgentInfo[*ConnectedAgentwInfo, ConnectedAgentwInfo](ctx, t.log, agentID, t.errRep, t.validator, hash)
}

// msg constrains a type parameter to be *T where *T is a protobuf message. This lets
// generic code declare a T value and use its address as a proto.Message.
type msg[T any] interface {
	proto.Message
	*T
}

// yieldConnectedAgentInfo returns an iterator that scans hash under agentID and yields each
// stored value decoded as M (a *T protobuf message).
//
// Scanning starts only when the iterator is ranged over. Entries whose scan fails, whose
// bytes do not unmarshal, or that fail validation are reported via errRep and skipped.
// Iteration stops early when the consumer's yield returns false.
func yieldConnectedAgentInfo[M msg[T], T any](ctx context.Context, log *slog.Logger, agentID int64, errRep errz.ErrReporter,
	validator protovalidate.Validator, hash redistool.ExpiringHash[int64, int64]) func(yield func(M) bool) {

	report := func(what string, err error) {
		errRep.HandleProcessingError(ctx, log, what, err)
	}

	return func(yield func(M) bool) {
		for entry, scanErr := range hash.Scan(ctx, agentID) {
			if scanErr != nil {
				report(fmt.Sprintf("Redis %s hash scan", hash.GetName()), scanErr)
				continue
			}
			info := new(T)
			decoded := M(info)
			if err := proto.Unmarshal(entry.Value, decoded); err != nil {
				report(fmt.Sprintf("Redis %s hash scan: proto.Unmarshal(%T)", hash.GetName(), *info), err)
				continue
			}
			if err := validator.Validate(decoded); err != nil {
				report(fmt.Sprintf("Redis %s hash scan: invalid %T", hash.GetName(), *info), err)
				continue
			}
			if !yield(decoded) {
				return
			}
		}
	}
}

// GetConnectedAgentsCount reports the length of the connectedAgents hash, i.e. how many
// distinct agent IDs are currently marked as connected.
func (t *RedisTracker) GetConnectedAgentsCount(ctx context.Context) (int64, error) {
	count, err := t.connectedAgents.Len(ctx, connectedAgentsKey)
	return count, err
}

// CountAgentsByAgentVersions returns, for every version currently listed in the
// agentVersions hash, the length of that version's connectionsByAgentVersion hash
// (one entry per registered connectionId).
//
// getAgentVersions reports and skips scan errors, and returns an empty list when ctx is
// done. ctx.Err() is therefore checked explicitly afterwards, so an interrupted call
// returns an error instead of looking like a successful, empty result.
func (t *RedisTracker) CountAgentsByAgentVersions(ctx context.Context) (map[string]int64, error) {
	agentVersions := t.getAgentVersions(ctx)
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	counts := make(map[string]int64, len(agentVersions))
	for _, version := range agentVersions {
		count, err := t.connectionsByAgentVersion.Len(ctx, version)
		if err != nil {
			return nil, fmt.Errorf("failed to get hash length from connectionsByAgentVersion in Redis: %w", err)
		}
		counts[version] = count
	}
	return counts, nil
}

// runGC garbage-collects every hash once and returns the total number of keys deleted.
//
// Hashes are processed one at a time, in a fixed order. GC is not urgent, so it runs
// sequentially to avoid competing for RAM/CPU/Redis/network with more important work.
// A failure on one hash is reported and the remaining hashes are still processed. A
// context-done error stops the pass early.
func (t *RedisTracker) runGC(ctx context.Context) int {
	// A slice (not a map) keeps the order deterministic. agentVersions is GC'd first; then
	// connectionsByAgentVersion is GC'd for the versions still listed in agentVersions.
	steps := []struct {
		hashName string
		gc       func(context.Context) (int, error)
	}{
		{hashName: agentkConnectionsByIDHashName, gc: t.agentkConnectionsByID.GC},
		{hashName: agentwConnectionsByIDHashName, gc: t.agentwConnectionsByID.GC},
		{hashName: connectedAgentsHashName, gc: t.connectedAgents.GC},
		{hashName: agentVersionsHashName, gc: t.agentVersions.GC},
		{hashName: connectionsByAgentVersionHashName, gc: t.gcConnectionsByAgentVersion},
	}

	total := 0
	for _, step := range steps {
		deleted, err := step.gc(ctx)
		total += deleted
		switch {
		case err == nil:
			// nothing to report
		case errz.ContextDone(err):
			t.log.Debug("Redis GC interrupted", logz.Error(err))
			return total
		default:
			t.errRep.HandleProcessingError(ctx, t.log, fmt.Sprintf("Failed to GC data in %s Redis hash", step.hashName), err)
			// keep going with the remaining hashes
		}
	}
	return total
}

// gcConnectionsByAgentVersion GCs the connectionsByAgentVersion hashes of the versions
// currently listed in the agentVersions hash. Versions not listed there are not visited.
func (t *RedisTracker) gcConnectionsByAgentVersion(ctx context.Context) (int, error) {
	gc := t.connectionsByAgentVersion.GCFor(t.getAgentVersions(ctx))
	return gc(ctx)
}

// getAgentVersions scans the agentVersions hash and returns the versions it lists.
//
// Scan errors are reported via errRep and the scan continues. If a context-done error is
// seen, nil is returned and any versions collected so far are discarded.
func (t *RedisTracker) getAgentVersions(ctx context.Context) []string {
	var versions []string //nolint:prealloc
	for entry, err := range t.agentVersions.Scan(ctx, agentVersionKey) {
		switch {
		case err == nil:
			versions = append(versions, entry.RawHashKey)
		case errz.ContextDone(err):
			t.log.Debug("getAgentVersions() interrupted", logz.Error(err))
			return nil
		default:
			t.errRep.HandleProcessingError(ctx, t.log, "getAgentVersions: failed to scan redis hash", err)
		}
	}
	return versions
}

// connectionsByAgentIDHashKey returns a key function mapping an agent ID to the Redis key
// of its hash (connectionId -> marshaled ConnectedAgentkInfo).
func connectionsByAgentIDHashKey(agentKeyPrefix string) redistool.KeyToRedisKey[int64] {
	const infix = ":conn_by_agent_id:"
	keyPrefix := agentKeyPrefix + infix
	return func(agentID int64) string { return redistool.PrefixedInt64Key(keyPrefix, agentID) }
}

// connectedAgentsHashKey returns a key function for the single connected-agents hash.
// The key argument is ignored; every call yields the same Redis key.
func connectedAgentsHashKey(agentKeyPrefix string) redistool.KeyToRedisKey[int64] {
	const suffix = ":connected_agents"
	key := agentKeyPrefix + suffix
	return func(int64) string { return key }
}

// connectionsByAgentVersionHashKey returns a key function mapping an agent version string
// to the Redis key of its hash (connectionId -> marshaled AgentPodInfo).
func connectionsByAgentVersionHashKey(agentKeyPrefix string) redistool.KeyToRedisKey[string] {
	const infix = ":conn_by_agent_version:"
	keyPrefix := agentKeyPrefix + infix
	return func(agentVersion string) string { return keyPrefix + agentVersion }
}

// agentVersionsHashKey returns a key function for the single agent-versions hash.
// The key argument (agentVersionKey at call sites) is ignored; the Redis key is constant.
func agentVersionsHashKey(agentKeyPrefix string) redistool.KeyToRedisKey[int64] {
	const suffix = ":agent_versions"
	key := agentKeyPrefix + suffix
	return func(int64) string { return key }
}

// int64ToStr renders key in base 10. It is passed to NewRedisExpiringHash as the
// converter for int64 hash keys.
func int64ToStr(key int64) string {
	const decimal = 10
	return strconv.FormatInt(key, decimal)
}

// strToStr is the identity converter passed to NewRedisExpiringHash for string hash keys
// (agent versions), which are already usable as-is.
func strToStr(s string) string {
	out := s
	return out
}
