package kas

import (
	"context"
	"fmt"
	"sync"

	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	agent2kas_tunnel_router "gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/agent2kas_tunnel/router"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

// agentConnPool maps agent keys to gRPC client connections and falls back to
// a shared routing connection for agent keys that have no registered entry.
type agentConnPool struct {
	// routingConn is the delegate wrapped by Get when agentKey has no
	// connection registered in key2conn.
	routingConn grpc.ClientConnInterface

	// mu guards key2conn; Add, Remove and Get all take it before touching the map.
	mu       sync.Mutex
	// key2conn holds the connections registered via Add, keyed by agent key.
	key2conn map[api.AgentKey]grpc.ClientConnInterface
}

// newAgentConnPool returns an empty pool. routingConn is stored as the
// fallback delegate used by Get for agent keys without a registered connection.
func newAgentConnPool(routingConn grpc.ClientConnInterface) *agentConnPool {
	pool := &agentConnPool{routingConn: routingConn}
	// The map must be allocated up front: Add writes into it.
	pool.key2conn = make(map[api.AgentKey]grpc.ClientConnInterface)
	return pool
}

// Add registers conn under agentKey.
// It returns an error, leaving the existing entry untouched, when a
// connection is already registered for that key.
func (p *agentConnPool) Add(agentKey api.AgentKey, conn grpc.ClientConnInterface) error {
	p.mu.Lock()
	defer p.mu.Unlock()

	_, registered := p.key2conn[agentKey]
	if !registered {
		p.key2conn[agentKey] = conn
		return nil
	}
	return fmt.Errorf("connection for agent key %s is already registered", agentKey)
}

// Remove drops the connection registered under agentKey, if any.
// Removing a key that is not present is a no-op.
func (p *agentConnPool) Remove(agentKey api.AgentKey) {
	p.mu.Lock()
	delete(p.key2conn, agentKey)
	p.mu.Unlock()
}

// Get returns the connection registered for agentKey.
// When none is registered it returns a new routingMetadataWrapper around the
// pool's routing connection, bound to agentKey. The result is never nil.
func (p *agentConnPool) Get(agentKey api.AgentKey) grpc.ClientConnInterface {
	p.mu.Lock()
	defer p.mu.Unlock()

	if registered, found := p.key2conn[agentKey]; found {
		return registered
	}

	// No registered connection: hand out a wrapper that tags every call with
	// routing metadata for this agent key.
	fallback := &routingMetadataWrapper{delegate: p.routingConn, agentKey: agentKey}
	return fallback
}

// Compile-time check that *routingMetadataWrapper implements grpc.ClientConnInterface.
var _ grpc.ClientConnInterface = (*routingMetadataWrapper)(nil)

// routingMetadataWrapper is a grpc.ClientConnInterface that forwards every
// call to delegate after setting routing metadata for agentKey on the
// outgoing context.
type routingMetadataWrapper struct {
	// delegate receives the Invoke and NewStream calls.
	delegate grpc.ClientConnInterface
	// agentKey is passed to agent2kas_tunnel_router.SetRoutingMetadata for each call.
	agentKey api.AgentKey
}

// Invoke forwards a unary call to the delegate, using a context that carries
// routing metadata for the wrapper's agent key. All other arguments and the
// returned error are passed through unchanged.
func (w *routingMetadataWrapper) Invoke(ctx context.Context, method string, args, reply any, opts ...grpc.CallOption) error {
	routedCtx := w.setRoutingMetadata(ctx)
	return w.delegate.Invoke(routedCtx, method, args, reply, opts...)
}

// NewStream opens a stream on the delegate, using a context that carries
// routing metadata for the wrapper's agent key. All other arguments and the
// delegate's results are passed through unchanged.
func (w *routingMetadataWrapper) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
	routedCtx := w.setRoutingMetadata(ctx)
	return w.delegate.NewStream(routedCtx, desc, method, opts...)
}

// setRoutingMetadata returns a context derived from ctx whose outgoing gRPC
// metadata is the result of agent2kas_tunnel_router.SetRoutingMetadata applied
// to ctx's current outgoing metadata and the wrapper's agent key.
func (w *routingMetadataWrapper) setRoutingMetadata(ctx context.Context) context.Context {
	// The presence flag is deliberately ignored: whatever FromOutgoingContext
	// yields is handed to SetRoutingMetadata as-is.
	outgoing, _ := metadata.FromOutgoingContext(ctx)
	return metadata.NewOutgoingContext(ctx, agent2kas_tunnel_router.SetRoutingMetadata(outgoing, w.agentKey))
}
