package router

import (
	"context"
	"errors"
	"fmt"
	"net/netip"
	"time"
	"unsafe"

	"buf.build/go/protovalidate"
	"github.com/redis/rueidis"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/nettool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/redistool"
	otelmetric "go.opentelemetry.io/otel/metric"
	"google.golang.org/protobuf/proto"
	"k8s.io/apimachinery/pkg/util/sets"
	"k8s.io/utils/clock"
)

const (
	// tunnelsByAgentKeyHashName is the name given to the expiring hash that backs
	// RedisTracker; it is passed as the first argument to redistool.NewRedisExpiringHashAPI.
	// NOTE(review): presumably used to label metrics/diagnostics rather than Redis keys
	// (the Redis key is built by tunnelsByAgentKeyHashKey) - confirm in redistool.
	tunnelsByAgentKeyHashName = "tunnels_by_agent_key"
)

// Querier looks up the kas instances that hold tunnels for an agent.
type Querier interface {
	// KASURLsByAgentKey returns the list of kas URLs for a particular agent key.
	// A partial list may be returned together with an error.
	// Safe for concurrent use.
	KASURLsByAgentKey(ctx context.Context, agentKey api.AgentKey) ([]nettool.MultiURL, error)
}

// RegistrationBuilder allows batching tunnel (un)registrations.
// Register, Unregister and Refresh only enqueue operations; they are executed by Do.
// Can be reused after Do is called.
type RegistrationBuilder interface {
	// Register registers tunnels for the given agent keys with the tracker.
	// ttl is the time-to-live of the registration.
	Register(ttl time.Duration, agentKeys ...api.AgentKey)
	// Unregister unregisters tunnels for the given agent keys with the tracker.
	Unregister(agentKeys ...api.AgentKey)
	// Refresh refreshes registered tunnels for the given agent keys in the underlying storage.
	// ttl is the new time-to-live of the registration.
	Refresh(ttl time.Duration, agentKeys ...api.AgentKey)
	// Do executes the enqueued operations.
	Do(context.Context) error
}

// Registerer allows registering and unregistering tunnels.
// Caller is responsible for periodically calling GC() and for periodically refreshing
// registrations via Refresh() on a builder obtained from RegistrationBuilder().
// Not safe for concurrent use.
type Registerer interface {
	// RegistrationBuilder returns a builder for batching tunnel (un)registrations.
	RegistrationBuilder() RegistrationBuilder
	// GC deletes expired tunnels from the underlying storage.
	// It returns the number of keys deleted.
	GC(ctx context.Context, agentKeys []api.AgentKey) (int /* keysDeleted */, error)
}

// Tracker combines the write side (Registerer) and the read side (Querier) of tunnel tracking.
type Tracker interface {
	Registerer
	Querier
}

// RedisTracker tracks tunnels in Redis using an expiring hash per agent key.
//
// Fields:
//   - validator checks MultiURL protos read back from Redis before they are used.
//   - ownPrivateAPIURL is the serialized MultiURL proto of this instance's private API URL.
//     It is kept as a string because it is used as the hash key of every registration.
//   - clock is the time source used to compute expiration times of registrations.
//   - tunnelsByAgentKey is the expiring hash API that stores the registrations.
type RedisTracker struct {
	validator         protovalidate.Validator
	ownPrivateAPIURL  string
	clock             clock.PassiveClock
	tunnelsByAgentKey redistool.ExpiringHashAPI[api.AgentKey, string] // agentKey -> kas multi URL proto -> nil
}

// NewRedisTracker constructs a RedisTracker that stores its data via client.
//
// validator is used to check MultiURL protos read back from Redis. agentKeyPrefix is
// prepended to the Redis keys of the per-agent hashes. ownPrivateAPIURL is the URL this
// instance registers for its tunnels; it is serialized once, here. m is handed to
// redistool together with the hash name.
//
// An error is returned if the expiring hash API cannot be created or if
// ownPrivateAPIURL cannot be marshaled.
func NewRedisTracker(client rueidis.Client, validator protovalidate.Validator, agentKeyPrefix string, ownPrivateAPIURL nettool.MultiURL, m otelmetric.Meter) (*RedisTracker, error) {
	hashAPI, err := redistool.NewRedisExpiringHashAPI(tunnelsByAgentKeyHashName, client, tunnelsByAgentKeyHashKey(agentKeyPrefix), strToStr, m)
	if err != nil {
		return nil, err
	}

	// The serialized proto doubles as the hash key of every registration made by this
	// instance, hence the conversion to string below.
	selfURL, err := proto.Marshal(fromMultiURL(ownPrivateAPIURL))
	if err != nil {
		return nil, err
	}

	tracker := &RedisTracker{
		validator:         validator,
		ownPrivateAPIURL:  string(selfURL),
		clock:             clock.RealClock{},
		tunnelsByAgentKey: hashAPI,
	}
	return tracker, nil
}

// RegistrationBuilder returns a new builder for batching (un)registrations of this
// instance's URL. The builder shares the tracker's clock and gets its own IO builder
// from the underlying expiring hash.
func (t *RedisTracker) RegistrationBuilder() RegistrationBuilder {
	builder := redisRegistrationBuilder{
		iob:              t.tunnelsByAgentKey.IOBuilder(),
		clock:            t.clock,
		ownPrivateAPIURL: t.ownPrivateAPIURL,
	}
	return &builder
}

// KASURLsByAgentKey returns the distinct kas URLs registered for agentKey.
//
// Every scanned hash key is expected to be a serialized MultiURL proto. Entries that
// fail to scan, decode or validate are skipped and their errors are accumulated, so a
// partial list may be returned together with a non-nil (joined) error. An entry whose
// IP list contains some invalid addresses is returned with the valid addresses only;
// an entry with no valid address at all is skipped. The order of the returned URLs is
// unspecified.
func (t *RedisTracker) KASURLsByAgentKey(ctx context.Context, agentKey api.AgentKey) ([]nettool.MultiURL, error) {
	urls := make(sets.Set[nettool.MultiURL]) // set: duplicates collapse into one URL
	var errs []error
	for se, err := range t.tunnelsByAgentKey.Scan(ctx, agentKey) {
		if err != nil {
			errs = append(errs, err)
			continue
		}
		var mu MultiURL
		// Avoid creating a temporary copy: view the string's bytes directly.
		// The resulting slice aliases string memory and must only be read.
		rawHashKeyBytes := unsafe.Slice(unsafe.StringData(se.RawHashKey), len(se.RawHashKey)) //nolint: gosec
		err = proto.Unmarshal(rawHashKeyBytes, &mu)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		err = t.validator.Validate(&mu) // validate just in case. We don't fully trust the stuff from Redis.
		if err != nil {
			errs = append(errs, err)
			continue
		}

		var x nettool.MultiURL
		switch {
		case mu.Host != "":
			x = nettool.NewMultiURLForHost(mu.Scheme, mu.Host, mu.TlsHost, uint16(mu.Port)) //nolint:gosec
		case len(mu.Ip) > 0:
			addrs := make([]netip.Addr, 0, len(mu.Ip))
			for _, ip := range mu.Ip {
				netIP, ok := netip.AddrFromSlice(ip)
				if !ok {
					errs = append(errs, fmt.Errorf("invalid IP: 0x%x", ip))
					continue
				}
				addrs = append(addrs, netIP.Unmap()) // unmap just in case. We don't fully trust the stuff from Redis.
			}
			if len(addrs) == 0 {
				// Every address was rejected (the errors are already recorded above).
				// A URL without a single address cannot be connected to, so don't return it.
				continue
			}
			x = nettool.NewMultiURLForAddresses(mu.Scheme, mu.TlsHost, uint16(mu.Port), addrs) //nolint:gosec
		default:
			// this shouldn't happen - we validate the proto above
			errs = append(errs, errors.New("neither host or IP provided"))
			continue
		}
		urls.Insert(x)
	}
	return urls.UnsortedList(), errors.Join(errs...)
}

// GC deletes expired tunnel registrations of the given agent keys from the underlying
// expiring hash. It returns the number of deleted keys and an error, if any, exactly
// as reported by the hash's GC function.
func (t *RedisTracker) GC(ctx context.Context, agentKeys []api.AgentKey) (int /* keysDeleted */, error) {
	// NOTE(review): the meaning of the boolean flag is defined by redistool's GCFor - confirm there.
	collect := t.tunnelsByAgentKey.GCFor(agentKeys, true)
	return collect(ctx)
}

// redisRegistrationBuilder is the RegistrationBuilder returned by RedisTracker.
//
// Fields:
//   - ownPrivateAPIURL is the serialized MultiURL proto of this instance's private API URL,
//     used as the hash key of every operation.
//   - clock is the time source used to compute expiration times.
//   - iob is the IO builder that accumulates the operations and executes them in Do.
type redisRegistrationBuilder struct {
	ownPrivateAPIURL string
	clock            clock.PassiveClock
	iob              redistool.IOBuilder[api.AgentKey, string]
}

// Register enqueues setting this instance's URL in the hash of every given agent key.
// The entry's expiration time is the builder clock's current time plus ttl; ttl is also
// passed on to the IO builder. The operation is executed by Do.
func (b *redisRegistrationBuilder) Register(ttl time.Duration, agentKeys ...api.AgentKey) {
	expiresAt := b.clock.Now().Add(ttl)
	entry := kv(b.ownPrivateAPIURL, expiresAt)
	b.iob.Set(agentKeys, ttl, entry)
}

// Unregister enqueues removal of this instance's URL from the hash of every given agent key.
// The operation is executed by Do.
func (b *redisRegistrationBuilder) Unregister(agentKeys ...api.AgentKey) {
	b.iob.Unset(agentKeys, b.ownPrivateAPIURL)
}

// Refresh enqueues a refresh of this instance's registrations for the given agent keys
// with a new time-to-live of ttl. The operation is executed by Do.
func (b *redisRegistrationBuilder) Refresh(ttl time.Duration, agentKeys ...api.AgentKey) {
	// Refresh is the same as registration right now, but this might change in the future.
	b.Register(ttl, agentKeys...)
}

// Do executes the operations enqueued on the builder by delegating to the underlying
// IO builder, and returns its error, if any.
func (b *redisRegistrationBuilder) Do(ctx context.Context) error {
	return b.iob.Do(ctx)
}

// kv builds the hash entry for a registration: key is the hash key and expiresAt is
// recorded as Unix seconds. The entry carries no payload.
func kv(key string, expiresAt time.Time) redistool.BuilderKV[string] {
	marker := &redistool.ExpiringValue{
		ExpiresAt: expiresAt.Unix(),
		Value:     nil, // nothing to store.
	}
	return redistool.BuilderKV[string]{
		HashKey: key,
		Value:   marker,
	}
}

// tunnelsByAgentKeyHashKey returns a function that maps an agent key to the Redis key of
// its hash: agentKey -> (kas multi URL proto -> nil).
// The Redis key prefix is agentKeyPrefix followed by ":kas_by_agent_key:"; it is computed
// once and captured by the returned function.
func tunnelsByAgentKeyHashKey(agentKeyPrefix string) redistool.KeyToRedisKey[api.AgentKey] {
	const namespace = ":kas_by_agent_key:"
	redisKeyPrefix := agentKeyPrefix + namespace
	return func(k api.AgentKey) string {
		return redistool.PrefixedAgentKey(redisKeyPrefix, k)
	}
}

// strToStr is the identity conversion for hash keys that are already strings.
// It is passed to redistool.NewRedisExpiringHashAPI by NewRedisTracker.
func strToStr(s string) string {
	return s
}

// fromMultiURL converts a nettool.MultiURL into its MultiURL proto representation.
// Scheme, host, TLS host and port are copied as is; every address is stored in its
// byte-slice form.
func fromMultiURL(u nettool.MultiURL) *MultiURL {
	addrs := u.Addresses()
	rawIPs := make([][]byte, len(addrs))
	for i, addr := range addrs {
		rawIPs[i] = addr.AsSlice()
	}
	pb := &MultiURL{
		Scheme:  u.Scheme(),
		Ip:      rawIPs,
		Host:    u.Host(),
		TlsHost: u.TLSHost(),
		Port:    uint32(u.Port()),
	}
	return pb
}
