package router

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"sync"
	"time"

	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/logz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/mathz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/nettool"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/info"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/rpc"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tunserver"
	"go.opentelemetry.io/otel/attribute"
	otelmetric "go.opentelemetry.io/otel/metric"
	"go.opentelemetry.io/otel/trace"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"k8s.io/apimachinery/pkg/util/wait"
)

// Span attribute keys recorded by the registry's tracing spans.
const (
	traceTunnelFoundAttr    attribute.Key = "found"
	traceStoppedTunnelsAttr attribute.Key = "stoppedTunnels"
	traceAbortedFTRAttr     attribute.Key = "abortedFTR"
)

// gcPeriodJitterPercentage is the percentage of the configured GC period that's added as jitter.
const gcPeriodJitterPercentage = 5

// Handler accepts agents' reverse tunnels. *Registry implements it.
type Handler interface {
	// HandleTunnel is called for the server side of a reverse tunnel; the tunnel itself is wrapped by the
	// Forwarder that forwarderFactory builds.
	// It registers the tunnel and blocks, waiting for a request to proxy through the tunnel.
	// The method returns the error value to return to gRPC framework.
	// ageCtx can be used to unblock the method if the tunnel is not being used already.
	HandleTunnel(ctx, ageCtx context.Context, agentKey api.AgentKey, forwarderFactory ForwarderFactory) error
}

// findTunnelRequest is a FindTunnel() call that is queued until a tunnel supporting service/method is
// registered for agentKey. retTunHolder receives that tunnel, or nil if the registry is stopped first.
type findTunnelRequest struct {
	agentKey        api.AgentKey
	service, method string
	retTunHolder    chan<- *tunnelHolder
}

// tunnelHolder wraps one agent tunnel together with its lifecycle state in the Registry.
type tunnelHolder struct {
	// forwarder does the actual proxying. state (a tunserver.StateType) is guarded by Registry.mu.
	// retErr is created with capacity 1 by HandleTunnel(), which waits on it for the tunnel's final error.
	forwarder Forwarder
	agentKey  api.AgentKey
	state     tunserver.StateType
	retErr    chan<- error

	// onForward and onDone are Registry callbacks, invoked by ForwardStream() and Done() respectively.
	onForward func(*tunnelHolder) error
	onDone    func(context.Context, *tunnelHolder)
}

// Forwarder is a tunserver.Forwarder that can also tell whether it can handle a given gRPC service/method.
// The Registry uses SupportsServiceAndMethod to match tunnels with find requests.
type Forwarder interface {
	tunserver.Forwarder

	SupportsServiceAndMethod(service, method string) bool
}

// TunnelForwarder is a Forwarder built on tunserver.TunnelForwarder. It answers SupportsServiceAndMethod
// using the API descriptor sent by the agent.
type TunnelForwarder struct {
	tunserver.TunnelForwarder

	// Descriptor is the API descriptor from the first message on the tunnel (see NewTunnelForwarder).
	Descriptor *info.APIDescriptor
}

// SupportsServiceAndMethod delegates the check to the agent-supplied API descriptor.
func (f *TunnelForwarder) SupportsServiceAndMethod(service, method string) bool {
	desc := f.Descriptor
	return desc.SupportsServiceAndMethod(service, method)
}

// NewTunnelForwarder reads the first message of an agent's reverse tunnel stream, which must carry the
// agent's API descriptor, and wraps the stream in a TunnelForwarder.
//
// tunnelRetErr becomes the TunnelRetErr of the embedded tunserver.TunnelForwarder.
//
// It returns the error from server.Recv() unchanged if reading fails. It returns a codes.InvalidArgument
// status error if the first message is not a descriptor or does not carry an API descriptor.
func NewTunnelForwarder(server grpc.BidiStreamingServer[rpc.ConnectRequest, rpc.ConnectResponse], tunnelRetErr chan<- error) (*TunnelForwarder, error) {
	recv, err := server.Recv()
	if err != nil {
		return nil, err
	}
	descriptor, ok := recv.Msg.(*rpc.ConnectRequest_Descriptor_)
	if !ok {
		return nil, status.Errorf(codes.InvalidArgument, "invalid oneof value type: %T", recv.Msg)
	}
	// The message comes from the agent, so don't assume it is complete. Without this check a missing
	// descriptor panics right here. A missing API descriptor is stored as nil and only used later, when
	// the Registry calls SupportsServiceAndMethod() while holding its lock.
	if descriptor == nil || descriptor.Descriptor_ == nil || descriptor.Descriptor_.ApiDescriptor == nil {
		return nil, status.Error(codes.InvalidArgument, "missing API descriptor")
	}

	return &TunnelForwarder{
		TunnelForwarder: tunserver.TunnelForwarder{
			Tunnel:       server,
			TunnelRetErr: tunnelRetErr,
		},
		Descriptor: descriptor.Descriptor_.ApiDescriptor,
	}, nil
}

// ForwarderFactory builds the Forwarder for a newly connected tunnel. tunnelRetErr is the channel that
// HandleTunnel() waits on for the tunnel's final error.
type ForwarderFactory func(tunnelRetErr chan<- error) (Forwarder, error)

// agentKey2tunInfo holds the idle (StateReady) tunnels registered for one agent key. It is stored by value
// in Registry.tunsByAgentKey; the map inside it is shared, so inserts through a copy are visible.
type agentKey2tunInfo struct {
	tunHolders map[*tunnelHolder]struct{}
}

// Registry pairs agents' reverse tunnels with requests that need a tunnel (FindTunnel). Through
// asyncTracker, it keeps tunnelTracker informed about which agent keys currently have at least one
// registered tunnel.
type Registry struct {
	log           *slog.Logger
	api           modshared.API
	tracer        trace.Tracer
	refreshPeriod time.Duration // how often Run() refreshes registrations
	gcPeriod      time.Duration // base period (jitter is added) of Run()'s GC ticker
	tunnelTracker Querier
	asyncTracker  *asyncTracker

	// mu guards tunsByAgentKey, findRequestsByAgentKey and the state field of every tunnelHolder.
	// tunsByAgentKey holds the idle (StateReady) tunnels per agent key. findRequestsByAgentKey holds the
	// FindTunnel() requests still waiting for a tunnel, per agent key.
	mu                     sync.Mutex
	tunsByAgentKey         map[api.AgentKey]agentKey2tunInfo
	findRequestsByAgentKey map[api.AgentKey]map[*findTunnelRequest]struct{}
}

// NewRegistry constructs a Registry.
//
// ttl is passed to the async tracker that reports registrations to tunnelTracker. refreshPeriod and gcPeriod
// control how often Run() refreshes registrations and runs GC. An error is returned if the async tracker
// cannot be created.
func NewRegistry(log *slog.Logger, modAPI modshared.API, tracer trace.Tracer, meter otelmetric.Meter,
	refreshPeriod, gcPeriod, ttl time.Duration, tunnelTracker Tracker) (*Registry, error) {
	tracker, err := newAsyncTracker(log, modAPI, meter, tunnelTracker, ttl)
	if err != nil {
		return nil, err
	}
	reg := &Registry{
		log:           log,
		api:           modAPI,
		tracer:        tracer,
		tunnelTracker: tunnelTracker,
		asyncTracker:  tracker,
		refreshPeriod: refreshPeriod,
		gcPeriod:      gcPeriod,

		tunsByAgentKey:         make(map[api.AgentKey]agentKey2tunInfo),
		findRequestsByAgentKey: make(map[api.AgentKey]map[*findTunnelRequest]struct{}),
	}
	return reg, nil
}

// FindTunnel looks for a tunnel of agentKey that supports service/method.
//
// If a suitable idle tunnel is registered, it is claimed right away (state StateFound) and found is true.
// Otherwise a find request is queued and found is false; a tunnel registered later for agentKey may satisfy it.
// The returned handle receives the tunnel on its channel, or nil if the registry is stopped first.
// The handle's done callback closes that channel. If a tunnel was delivered but never taken, it goes back
// to the registry; otherwise the request is dequeued.
func (r *Registry) FindTunnel(ctx context.Context, agentKey api.AgentKey, service, method string) (bool, tunserver.FindHandle) {
	_, span := r.tracer.Start(ctx, "Registry.FindTunnel", trace.WithSpanKind(trace.SpanKindInternal))
	defer span.End()

	// Capacity 1 lets whoever delivers the tunnel (this method or registerTunnelLocked) send without
	// blocking while r.mu is held. A nil value means the registry was stopped.
	tunCh := make(chan *tunnelHolder, 1)
	req := &findTunnelRequest{
		agentKey:     agentKey,
		service:      service,
		method:       method,
		retTunHolder: tunCh,
	}

	found := func() bool { //nolint:contextcheck
		r.mu.Lock()
		defer r.mu.Unlock()

		// Prefer an already registered idle tunnel.
		for holder := range r.tunsByAgentKey[agentKey].tunHolders {
			if holder.forwarder.SupportsServiceAndMethod(service, method) {
				holder.state = tunserver.StateFound
				tunCh <- holder // cannot block: the buffer is still empty
				r.unregisterTunnelLocked(holder)
				return true
			}
		}

		// Nothing suitable yet: queue the request so registerTunnelLocked can satisfy it.
		pending := r.findRequestsByAgentKey[agentKey]
		if pending == nil {
			pending = make(map[*findTunnelRequest]struct{}, 1)
			r.findRequestsByAgentKey[agentKey] = pending
		}
		pending[req] = struct{}{}
		return false
	}()
	span.SetAttributes(traceTunnelFoundAttr.Bool(found))

	release := func() { //nolint:contextcheck
		r.mu.Lock()
		defer r.mu.Unlock()
		close(tunCh)
		// The closed channel yields a tunnel that was delivered but never taken, or nil (nothing delivered,
		// already taken, or the registry is shutting down).
		if holder := <-tunCh; holder != nil {
			// Got the tunnel, but it's too late: give it back to the registry.
			r.onTunnelDoneLocked(holder)
			return
		}
		r.deleteFindRequestLocked(req)
	}
	return found, &findHandle{
		tracer:       r.tracer,
		retTunHolder: tunCh,
		done:         release,
	}
}

// HandleTunnel registers the agent's reverse tunnel and blocks until the tunnel is finished with.
//
// It returns whatever arrives on the tunnel's return channel. That is either the forwarding result, or nil
// when the registry stops. If ageCtx is done first, the outcome depends on how far the tunnel got:
//   - never handed out for forwarding: the tunnel is withdrawn and nil is returned;
//   - forwarding, or already finished: the method still waits for and returns the forwarding result.
func (r *Registry) HandleTunnel(ctx, ageCtx context.Context, agentKey api.AgentKey, forwarderFactory ForwarderFactory) error {
	_, span := r.tracer.Start(ctx, "Registry.HandleTunnel", trace.WithSpanKind(trace.SpanKindServer))
	// The returned error is intentionally not recorded on the span: the gRPC OTEL stats handler does that.
	defer span.End()

	tunnelRetErr := make(chan error, 1)
	fwd, err := forwarderFactory(tunnelRetErr)
	if err != nil {
		return err
	}

	holder := &tunnelHolder{
		forwarder: fwd,
		agentKey:  agentKey,
		state:     tunserver.StateReady,
		retErr:    tunnelRetErr,
		onForward: r.onTunnelForward,
		onDone:    r.onTunnelDone,
	}
	r.registerTunnel(holder) //nolint:contextcheck

	select {
	case err = <-tunnelRetErr:
		return err
	case <-ageCtx.Done():
	}

	// ageCtx is done. Under the lock, decide whether a forwarding result is still coming.
	awaitResult := func() bool {
		r.mu.Lock()
		defer r.mu.Unlock()
		switch holder.state {
		case tunserver.StateReady:
			// Still idle in the registry: withdraw it.
			holder.state = tunserver.StateContextDone
			r.unregisterTunnelLocked(holder) //nolint:contextcheck
			return false
		case tunserver.StateFound:
			// Claimed but not used yet, and Done() hasn't been called. After this, ForwardStream() errors
			// out without doing any I/O and Done() has nothing to do.
			holder.state = tunserver.StateContextDone
			return false
		case tunserver.StateForwarding, tunserver.StateDone:
			// Forwarding is in progress (I/O on the stream will error out) or already finished.
			// Either way the result will be delivered.
			return true
		case tunserver.StateContextDone:
			// ageCtx can only be observed as done once.
			panic(errors.New("unreachable"))
		default:
			panic(fmt.Errorf("invalid State: %d", holder.state))
		}
	}()
	if awaitResult {
		return <-tunnelRetErr
	}
	return nil
}

// registerTunnel is the locking wrapper around registerTunnelLocked.
func (r *Registry) registerTunnel(holder *tunnelHolder) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.registerTunnelLocked(holder)
}

// registerTunnelLocked makes toReg available for its agent key. r.mu must be held.
//
// If a queued find request can use the tunnel, the tunnel is handed straight to it (state StateFound) and
// the request is dequeued. Otherwise the tunnel is stored as idle (state StateReady). When it is the
// key's first tunnel, the key is registered with the async tracker.
func (r *Registry) registerTunnelLocked(toReg *tunnelHolder) {
	key := toReg.agentKey

	var match *findTunnelRequest
	for req := range r.findRequestsByAgentKey[key] {
		if toReg.forwarder.SupportsServiceAndMethod(req.service, req.method) {
			match = req
			break
		}
	}
	if match != nil {
		toReg.state = tunserver.StateFound
		r.log.Debug("Registering agent tunnel and immediately found request to satisfy",
			logz.AgentKey(key), logz.TunnelsByAgent(len(r.tunsByAgentKey[key].tunHolders)))
		// Unblock the waiter first, then dequeue its request.
		match.retTunHolder <- toReg
		r.deleteFindRequestLocked(match)
		return
	}

	r.log.Debug("Registering agent tunnel", logz.AgentKey(key), logz.TunnelsByAgent(len(r.tunsByAgentKey[key].tunHolders)))

	toReg.state = tunserver.StateReady
	tunInfo, known := r.tunsByAgentKey[key]
	if !known {
		tunInfo = agentKey2tunInfo{
			tunHolders: make(map[*tunnelHolder]struct{}),
		}
		// Stored by value; the inner map is shared, so the insert below is visible through the map entry.
		r.tunsByAgentKey[key] = tunInfo
		// First tunnel for this agent key.
		r.asyncTracker.register(key)
	}
	tunInfo.tunHolders[toReg] = struct{}{}
}

// unregisterTunnelLocked removes toUnreg from the idle tunnels of its agent key. If it was the key's last
// tunnel, the key is dropped and unregistered from the async tracker. r.mu must be held.
func (r *Registry) unregisterTunnelLocked(toUnreg *tunnelHolder) {
	key := toUnreg.agentKey
	holders := r.tunsByAgentKey[key].tunHolders
	delete(holders, toUnreg)
	remaining := len(holders)

	r.log.Debug("Unregistering agent tunnel", logz.AgentKey(key), logz.TunnelsByAgent(remaining))
	if remaining == 0 {
		delete(r.tunsByAgentKey, key)
		r.asyncTracker.unregister(key)
	}
}

// onTunnelForward is the tunnelHolder.onForward callback. A found tunnel moves to StateForwarding and nil is
// returned. In every other state the state is left unchanged and a gRPC status error is returned.
func (r *Registry) onTunnelForward(tunHolder *tunnelHolder) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	state := tunHolder.state
	if state == tunserver.StateFound {
		tunHolder.state = tunserver.StateForwarding
		return nil
	}
	switch state {
	case tunserver.StateReady:
		return status.Error(codes.Internal, "unreachable: ready -> forwarding should never happen")
	case tunserver.StateForwarding:
		return status.Error(codes.Internal, "ForwardStream() called more than once")
	case tunserver.StateDone:
		return status.Error(codes.Internal, "ForwardStream() called after Done()")
	case tunserver.StateContextDone:
		return status.Error(codes.Canceled, "ForwardStream() called on done stream")
	default:
		return status.Errorf(codes.Internal, "unreachable: invalid State: %d", state)
	}
}

// onTunnelDone is the tunnelHolder.onDone callback. It traces the call and runs onTunnelDoneLocked under r.mu.
func (r *Registry) onTunnelDone(ctx context.Context, holder *tunnelHolder) {
	_, span := r.tracer.Start(ctx, "Registry.onTunnelDone", trace.WithSpanKind(trace.SpanKindInternal))
	defer span.End()

	r.mu.Lock()
	defer r.mu.Unlock()
	r.onTunnelDoneLocked(holder) //nolint:contextcheck
}

// onTunnelDoneLocked handles Done() for a tunnel. r.mu must be held.
//
// A found-but-unused tunnel is registered again. A forwarding tunnel moves to StateDone. A tunnel that
// HandleTunnel() already gave up on (StateContextDone) needs nothing. Any other state is a bug and panics.
func (r *Registry) onTunnelDoneLocked(tunHolder *tunnelHolder) {
	switch tunHolder.state {
	case tunserver.StateFound:
		// Found but never used: put it back.
		r.registerTunnelLocked(tunHolder)
	case tunserver.StateForwarding:
		tunHolder.state = tunserver.StateDone
	case tunserver.StateContextDone: // Done() after HandleTunnel() saw its context canceled; nothing to do.
	case tunserver.StateReady:
		panic(errors.New("unreachable: ready -> done should never happen"))
	case tunserver.StateDone:
		panic(errors.New("Done() called more than once"))
	default:
		panic(fmt.Errorf("invalid State: %d", tunHolder.state))
	}
}

// deleteFindRequestLocked dequeues ftr. Once its agent key's queue is empty, the queue itself is dropped.
// Deleting a request that is not queued is a no-op. r.mu must be held.
func (r *Registry) deleteFindRequestLocked(ftr *findTunnelRequest) {
	key := ftr.agentKey
	queue := r.findRequestsByAgentKey[key]
	delete(queue, ftr)
	if len(queue) == 0 {
		delete(r.findRequestsByAgentKey, key)
	}
}

// KASURLsByAgentKey queries the tunnel tracker for the kas URLs associated with agentKey, inside a tracing span.
func (r *Registry) KASURLsByAgentKey(ctx context.Context, agentKey api.AgentKey) ([]nettool.MultiURL, error) {
	spanCtx, span := r.tracer.Start(ctx, "Registry.KASURLsByAgentKey", trace.WithSpanKind(trace.SpanKindInternal))
	defer span.End()
	return r.tunnelTracker.KASURLsByAgentKey(spanCtx, agentKey)
}

// Run drives the registry until ctx is done. It periodically refreshes registrations and runs GC through
// the async tracker, whose run loop it starts in a separate goroutine.
//
// On return it calls stopInternal, then asyncTracker.done(), and finally waits for the tracker's goroutine
// to exit. It always returns nil.
func (r *Registry) Run(ctx context.Context) error {
	var trackerGroup wait.Group
	trackerGroup.Start(r.asyncTracker.run)
	// Deferred calls run in reverse order: stopInternal, asyncTracker.done, then the wait.
	defer trackerGroup.Wait()
	defer r.asyncTracker.done()
	defer r.stopInternal(ctx)

	refreshTick := time.NewTicker(r.refreshPeriod)
	defer refreshTick.Stop()
	// Jitter is applied once, when the GC ticker is created.
	gcTick := time.NewTicker(mathz.DurationWithPositiveJitter(r.gcPeriod, gcPeriodJitterPercentage))
	defer gcTick.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil
		case <-refreshTick.C:
			r.refreshRegistrations() //nolint:contextcheck
		case <-gcTick.C:
			r.runGC() //nolint:contextcheck
		}
	}
}

// stopInternal aborts any open tunnels and any pending find tunnel requests.
// It should not be necessary to abort anything when the registry is used correctly, i.e. when this method
// runs after all tunnels have terminated gracefully. If something does get aborted, it is reported as a
// processing error.
// It returns the number of stopped tunnels and the number of aborted find requests.
func (r *Registry) stopInternal(ctx context.Context) (int /*stoppedTun*/, int /*abortedFtr*/) {
	ctx, span := r.tracer.Start(ctx, "Registry.stopInternal", trace.WithSpanKind(trace.SpanKindInternal))
	defer span.End()

	r.mu.Lock()
	defer r.mu.Unlock()

	// Wake every waiting find request with a nil tunnel.
	numAborted := 0
	for _, queue := range r.findRequestsByAgentKey {
		for req := range queue {
			req.retTunHolder <- nil
			numAborted++
		}
	}
	clear(r.findRequestsByAgentKey)

	// Release every idle tunnel.
	numStopped := 0
	keys := make([]api.AgentKey, 0, len(r.tunsByAgentKey))
	for key, tunInfo := range r.tunsByAgentKey {
		for holder := range tunInfo.tunHolders {
			holder.state = tunserver.StateDone
			holder.retErr <- nil // nil so that HandleTunnel() returns cleanly and agent immediately retries
			numStopped++
		}
		keys = append(keys, key)
	}
	r.asyncTracker.unregister(keys...) //nolint:contextcheck
	clear(r.tunsByAgentKey)

	if numStopped > 0 || numAborted > 0 {
		r.api.HandleProcessingError(ctx, r.log, "", errors.New("stopped tunnels or aborted find tunnel requests"),
			slog.Int("num_stopped_tunnels", numStopped), slog.Int("num_aborted_find_tunnel_requests", numAborted))
	}

	span.SetAttributes(traceStoppedTunnelsAttr.Int(numStopped), traceAbortedFTRAttr.Int(numAborted))
	return numStopped, numAborted
}

// refreshRegistrations passes every agent key that currently has idle tunnels to asyncTracker.refresh.
// The call is made while r.mu is held.
func (r *Registry) refreshRegistrations() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.asyncTracker.refresh(r.agentKeysLocked())
}

// runGC passes every agent key that currently has idle tunnels to asyncTracker.gc.
// The call is made while r.mu is held.
func (r *Registry) runGC() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.asyncTracker.gc(r.agentKeysLocked())
}

// agentKeysLocked returns a fresh slice of the agent keys in tunsByAgentKey, in map iteration order.
// r.mu must be held.
func (r *Registry) agentKeysLocked() []api.AgentKey {
	keys := make([]api.AgentKey, len(r.tunsByAgentKey))
	i := 0
	for key := range r.tunsByAgentKey {
		keys[i] = key
		i++
	}
	return keys
}

// ForwardStream first asks the registry (onForward) to move the tunnel into the forwarding state. If that
// is refused, the registry's error is returned without any I/O. Otherwise incomingStream is proxied
// through the wrapped forwarder.
func (t *tunnelHolder) ForwardStream(log *slog.Logger, incomingStream grpc.ServerStream, cb tunserver.DataCallback) error {
	err := t.onForward(t)
	if err != nil {
		return err
	}
	return t.forwarder.ForwardStream(log, incomingStream, cb)
}

// Done reports to the registry (onDone) that the caller is finished with this tunnel.
func (t *tunnelHolder) Done(ctx context.Context) {
	notify := t.onDone
	notify(ctx, t)
}
