package router

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"sync"

	"gitlab.com/gitlab-org/cluster-integration/tunnel/info"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/rpc"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tunserver"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// Compile-time check that *tunnelHolder implements tunserver.Tunnel.
var _ tunserver.Tunnel = (*tunnelHolder)(nil)

// findTunnelRequest is a queued FindTunnel call that is still waiting for a
// tunnel supporting the requested service and method.
type findTunnelRequest struct {
	service string // service name passed to FindTunnel
	method  string // method name passed to FindTunnel

	// retTun is where the registry delivers the matching tunnel. FindTunnel
	// creates it with capacity 1 so the registry can send while holding its lock.
	retTun chan<- *tunnelHolder
}

// findHandle is the tunserver.FindHandle returned by Registry.FindTunnel.
type findHandle struct {
	// retTun delivers the tunnel handed over by the registry; Get treats a
	// received nil as the registry shutting down.
	retTun    <-chan *tunnelHolder
	// done is the cleanup Done runs when Get has not received a value.
	done      func(context.Context)
	// gotTunnel records that Get received a value from retTun.
	gotTunnel bool
}

// tunnelHolder wraps one reverse tunnel registered via Registry.HandleTunnel
// together with the lifecycle state the registry tracks for it.
type tunnelHolder struct {
	// forwarder performs the stream forwarding once onForward allows it.
	forwarder  tunserver.Forwarder
	// descriptor is consulted via SupportsServiceAndMethod to match find requests.
	descriptor *info.APIDescriptor
	// state is the lifecycle state; after registration it is read and written
	// only while holding Registry.mu.
	state      tunserver.StateType
	// retErr carries the value HandleTunnel returns to gRPC. HandleTunnel also
	// passes it to the forwarder as TunnelRetErr, presumably so forwarding can
	// report its result there — confirm in tunserver.
	retErr     chan<- error

	// onForward runs at the start of ForwardStream; a non-nil error aborts forwarding.
	onForward func(*tunnelHolder) error
	// onDone runs when Done is called.
	onDone    func(context.Context, *tunnelHolder)
}

// ForwardStream first asks the onForward hook whether forwarding may start.
// If the hook returns an error, that error is returned and the forwarder is
// never invoked; otherwise the call is delegated to the underlying forwarder.
func (t *tunnelHolder) ForwardStream(log *slog.Logger, incomingStream grpc.ServerStream, cb tunserver.DataCallback) error {
	err := t.onForward(t)
	if err != nil {
		return err
	}
	return t.forwarder.ForwardStream(log, incomingStream, cb)
}

// Done reports that the caller is finished with the tunnel by invoking the
// onDone hook with this holder.
func (t *tunnelHolder) Done(ctx context.Context) {
	hook := t.onDone
	hook(ctx, t)
}

// Get blocks until the registry delivers a value on retTun or ctx is done.
// If ctx finishes first, ctx.Err() is returned. A delivered nil yields a
// codes.Canceled status error. Once any value has been received, Done becomes
// a no-op.
func (h *findHandle) Get(ctx context.Context) (tunserver.Tunnel, error) {
	var tun *tunnelHolder
	select {
	case tun = <-h.retTun:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	h.gotTunnel = true
	if tun == nil {
		// Return a literal nil interface, not a typed nil pointer.
		return nil, status.Error(codes.Canceled, "agentk is shutting down")
	}
	return tun, nil
}

// Done releases the find request. After Get has received a value, the caller
// owns the outcome and nothing is done here. Otherwise the cleanup supplied by
// FindTunnel runs.
func (h *findHandle) Done(ctx context.Context) {
	if !h.gotTunnel {
		h.done(ctx)
	}
}

// Registry matches reverse tunnels registered via HandleTunnel with FindTunnel
// callers that need a tunnel supporting a particular service and method.
type Registry struct {
	// log is the logger passed to NewRegistry.
	log *slog.Logger

	// mu guards findRequests, tunHolders and the state field of every
	// tunnelHolder this registry manages.
	mu           sync.Mutex
	// findRequests holds FindTunnel calls still waiting for a suitable tunnel.
	findRequests map[*findTunnelRequest]struct{}
	// tunHolders holds idle tunnels (StateReady) that are available for matching.
	tunHolders   map[*tunnelHolder]struct{}
}

// NewRegistry returns a Registry with no tunnels and no pending find requests,
// keeping log for the registry's use.
func NewRegistry(log *slog.Logger) *Registry {
	r := &Registry{log: log}
	r.findRequests = map[*findTunnelRequest]struct{}{}
	r.tunHolders = map[*tunnelHolder]struct{}{}
	return r
}

// FindTunnel looks for a tunnel that supports service/method.
//
// If an idle tunnel matches, it is claimed immediately: its state becomes
// StateFound, it leaves the idle set, and found is true. Otherwise a find
// request is queued so that a tunnel registered later can be handed over.
// Either way, the tunnel is obtained through the returned handle's Get.
// The handle's Done must be called afterwards. If Get never returned a value,
// Done either gives a late-delivered tunnel back to the registry or removes
// the queued request.
func (r *Registry) FindTunnel(_ context.Context, service, method string) (bool, tunserver.FindHandle) {
	// Capacity 1 lets the registry hand a tunnel over while holding r.mu,
	// without waiting for the caller to reach Get. Get treats a received nil
	// as shutdown.
	tunCh := make(chan *tunnelHolder, 1)
	req := &findTunnelRequest{
		service: service,
		method:  method,
		retTun:  tunCh,
	}

	found := func() bool {
		r.mu.Lock()
		defer r.mu.Unlock()

		for th := range r.tunHolders {
			if th.descriptor.SupportsServiceAndMethod(service, method) {
				// Claim the idle tunnel. The send cannot block: tunCh is fresh and
				// has room for exactly this value.
				th.state = tunserver.StateFound
				tunCh <- th
				r.unregisterTunnelLocked(th)
				return true
			}
		}
		// Nothing suitable is idle; a later registration will satisfy the request.
		r.findRequests[req] = struct{}{}
		return false
	}()

	release := func(_ context.Context) {
		r.mu.Lock()
		defer r.mu.Unlock()

		// Closing under r.mu cannot race a send. Registry sends happen under r.mu
		// and dequeue the request right away. If nothing was sent, the request is
		// dequeued below in this same critical section.
		close(tunCh)
		if th := <-tunCh; th != nil {
			// A tunnel was delivered but Get never returned it; give it back.
			r.onTunnelDoneLocked(th)
			return
		}
		// Nothing delivered (or nil); the request may still be queued.
		r.deleteFindRequestLocked(req)
	}

	return found, &findHandle{
		retTun: tunCh,
		done:   release,
	}
}

// HandleTunnel is called with server-side interface of the reverse tunnel.
// It registers the tunnel and blocks, waiting for a request to proxy through the tunnel.
// The method returns the error value to return to gRPC framework.
// ageCtx can be used to unblock the method if the tunnel is not being used already.
//
// The first message read from server must be a descriptor; anything else is
// rejected with codes.InvalidArgument. When ageCtx is done, the outcome
// depends on the tunnel's state at that moment:
//   - ready (idle): the tunnel is unregistered and nil is returned;
//   - found but not yet forwarding: the tunnel is marked StateContextDone so a
//     later ForwardStream fails without doing I/O, and nil is returned;
//   - forwarding or done: the method waits for and returns the forwarding result.
func (r *Registry) HandleTunnel(ageCtx context.Context, server grpc.BidiStreamingServer[rpc.ConnectRequest, rpc.ConnectResponse]) error {
	recv, err := server.Recv()
	if err != nil {
		return err
	}
	descriptor, ok := recv.Msg.(*rpc.ConnectRequest_Descriptor_)
	if !ok {
		return status.Errorf(codes.InvalidArgument, "invalid oneof value type: %T", recv.Msg)
	}
	// Capacity 1 so a result can be delivered even when nobody is receiving yet.
	retErr := make(chan error, 1)
	tunHolder := &tunnelHolder{
		forwarder: &tunserver.TunnelForwarder{
			Tunnel:       server,
			TunnelRetErr: retErr,
		},
		descriptor: descriptor.Descriptor_.ApiDescriptor,
		state:      tunserver.StateReady,
		retErr:     retErr,

		onForward: r.onTunnelForward,
		onDone:    r.onTunnelDone,
	}
	r.registerTunnel(tunHolder)

	// Wait for the forwarding result or for ageCtx to be done.
	select {
	case err = <-retErr:
		return err
	case <-ageCtx.Done():
	}

	// ageCtx is done. While holding r.mu, decide whether the tunnel can be
	// abandoned now or whether forwarding owns it and its result must be
	// awaited. The deferred Unlock covers every branch, including the panics.
	// tunHolder.state is only read while the lock is held. The original
	// default branch read it after unlocking.
	awaitResult := func() bool {
		r.mu.Lock()
		defer r.mu.Unlock()

		switch tunHolder.state {
		case tunserver.StateReady:
			// Still idle: take it out of the registry so nobody can find it.
			tunHolder.state = tunserver.StateContextDone
			r.unregisterTunnelLocked(tunHolder)
			return false
		case tunserver.StateFound:
			// Tunnel was found but hasn't been used yet, Done() hasn't been called.
			// StateContextDone makes ForwardStream() error out without doing any I/O.
			tunHolder.state = tunserver.StateContextDone
			return false
		case tunserver.StateForwarding, tunserver.StateDone:
			// Forwarding has started (and possibly finished); its result is
			// what this method returns.
			return true
		case tunserver.StateContextDone:
			// Cannot happen twice.
			panic(errors.New("unreachable"))
		default:
			// Should never happen.
			panic(fmt.Errorf("invalid State: %d", tunHolder.state))
		}
	}()
	if awaitResult {
		// Block outside the lock so forwarding can still take r.mu (e.g. in Done).
		return <-retErr
	}
	return nil
}

// registerTunnel acquires r.mu and makes th available via registerTunnelLocked.
func (r *Registry) registerTunnel(th *tunnelHolder) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.registerTunnelLocked(th)
}

// registerTunnelLocked makes toReg available. If a queued find request
// matches, the tunnel is handed straight to it (StateFound) and the request
// is dequeued. Requests are tried in map iteration order, so there is no FIFO
// guarantee. Otherwise the tunnel joins the idle set in StateReady.
// r.mu must be held.
func (r *Registry) registerTunnelLocked(toReg *tunnelHolder) {
	for ftr := range r.findRequests {
		if toReg.descriptor.SupportsServiceAndMethod(ftr.service, ftr.method) {
			toReg.state = tunserver.StateFound
			// Should not block: FindTunnel creates the channel with capacity 1,
			// and a request is dequeued as soon as it is satisfied.
			ftr.retTun <- toReg
			r.deleteFindRequestLocked(ftr)
			return
		}
	}

	toReg.state = tunserver.StateReady
	r.tunHolders[toReg] = struct{}{}
}

// unregisterTunnelLocked removes th from the idle set. r.mu must be held.
func (r *Registry) unregisterTunnelLocked(th *tunnelHolder) {
	idle := r.tunHolders
	delete(idle, th)
}

// onTunnelForward is the tunnelHolder.onForward hook. Under r.mu it moves a
// found tunnel into StateForwarding and returns nil. For any other state it
// returns a gRPC status error and leaves the state unchanged.
func (r *Registry) onTunnelForward(tunHolder *tunnelHolder) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	state := tunHolder.state
	if state == tunserver.StateFound {
		tunHolder.state = tunserver.StateForwarding
		return nil
	}

	switch state {
	case tunserver.StateReady:
		return status.Error(codes.Internal, "unreachable: ready -> forwarding should never happen")
	case tunserver.StateForwarding:
		return status.Error(codes.Internal, "ForwardStream() called more than once")
	case tunserver.StateDone:
		return status.Error(codes.Internal, "ForwardStream() called after Done()")
	case tunserver.StateContextDone:
		return status.Error(codes.Canceled, "ForwardStream() called on done stream")
	default:
		return status.Errorf(codes.Internal, "unreachable: invalid State: %d", state)
	}
}

// onTunnelDone is the tunnelHolder.onDone hook. It applies onTunnelDoneLocked
// to th while holding r.mu. The context argument is ignored.
func (r *Registry) onTunnelDone(_ context.Context, th *tunnelHolder) {
	r.mu.Lock()
	defer r.mu.Unlock()

	r.onTunnelDoneLocked(th)
}

// onTunnelDoneLocked records that the holder of tunHolder is finished with it.
// r.mu must be held. State transitions:
//   - StateForwarding -> StateDone;
//   - StateFound (never forwarded) -> registered again via registerTunnelLocked;
//   - StateContextDone -> unchanged (HandleTunnel already gave up on it).
//
// Any other state indicates misuse or a broken invariant and panics.
func (r *Registry) onTunnelDoneLocked(tunHolder *tunnelHolder) {
	switch state := tunHolder.state; state {
	case tunserver.StateForwarding:
		tunHolder.state = tunserver.StateDone
	case tunserver.StateFound:
		// Found but not used: put it back, possibly straight to a queued request.
		r.registerTunnelLocked(tunHolder)
	case tunserver.StateContextDone:
		// Done() called after the context in HandleTunnel() was done; nothing to do.
	case tunserver.StateReady:
		panic(errors.New("unreachable: ready -> done should never happen"))
	case tunserver.StateDone:
		panic("Done() called more than once")
	default:
		panic(fmt.Sprintf("invalid State: %d", state))
	}
}

// deleteFindRequestLocked dequeues req; it is a no-op if req is not queued.
// r.mu must be held.
func (r *Registry) deleteFindRequestLocked(req *findTunnelRequest) {
	pending := r.findRequests
	delete(pending, req)
}
