package router

import (
	"context"
	"errors"
	"log/slog"
	"time"

	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/grpcz"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/retry"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tunserver"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// Compile-time assertion that *Plugin satisfies tunserver.RouterPlugin
// instantiated with grpcz.URLTarget and api.AgentKey; it costs nothing at runtime.
var _ tunserver.RouterPlugin[grpcz.URLTarget, api.AgentKey] = (*Plugin)(nil)

// Plugin implements tunserver.RouterPlugin. It locates ready gateways
// (FindReadyGateway), looks up tunnels in its Registry (FindTunnel) and merges
// per-RPC credentials into a stream's metadata before it is forwarded
// (PrepareStreamForForwarding).
type Plugin struct {
	// Log is used by FindReadyGateway for the gateway finder and for error
	// reporting, and is returned to the caller on success.
	Log                   *slog.Logger
	// Registry is consulted by FindTunnel to look up a tunnel by gRPC service and method.
	Registry              *Registry
	// AgentkPool is handed to tunserver.NewGatewayFinder.
	AgentkPool            grpcz.PoolInterface[grpcz.URLTarget]
	// GatewayQuerier is handed to tunserver.NewGatewayFinder to query gateway URLs.
	GatewayQuerier        tunserver.PollingGatewayURLQuerier[grpcz.URLTarget, api.AgentKey]
	// API is handed to the gateway finder and used to report processing errors.
	API                   modshared.API
	// OwnPrivateAPIURL is this server's own URL, handed to the gateway finder
	// (presumably so it can recognise itself — confirm in tunserver).
	OwnPrivateAPIURL      grpcz.URLTarget
	// Creds supplies the per-RPC credentials that PrepareStreamForForwarding
	// merges into the incoming metadata.
	Creds                 credentials.PerRPCCredentials
	// PollConfig is handed to the gateway finder.
	PollConfig            retry.PollConfig
	// TryNewGatewayInterval is handed to the gateway finder.
	TryNewGatewayInterval time.Duration
	// TunnelFindTimeout bounds how long FindReadyGateway waits for the finder.
	TunnelFindTimeout     time.Duration
}

// FindReadyGateway uses a tunserver gateway finder to locate a gateway that is
// ready to serve method, waiting at most p.TunnelFindTimeout.
//
// On success it returns the gateway, p.Log, modshared.NoAgentKey and a nil error.
// On failure it returns a zero ReadyGateway, a nil logger, modshared.NoAgentKey
// and a gRPC status error:
//   - the status derived from ctx ("request aborted") if ctx itself is done;
//   - codes.DeadlineExceeded if TunnelFindTimeout elapsed first;
//   - codes.Unavailable for any other error returned by the finder.
func (p *Plugin) FindReadyGateway(ctx context.Context, method string) (tunserver.ReadyGateway[grpcz.URLTarget], *slog.Logger, api.AgentKey, error) {
	gf := tunserver.NewGatewayFinder(
		ctx,
		p.Log,
		p.AgentkPool,
		p.GatewayQuerier,
		p.API,
		method,
		p.OwnPrivateAPIURL,
		modshared.NoAgentKey,
		p.PollConfig,
		p.TryNewGatewayInterval,
	)
	findCtx, findCancel := context.WithTimeout(ctx, p.TunnelFindTimeout)
	defer findCancel()

	rg, err := gf.Find(findCtx)
	if err == nil {
		return rg, p.Log, modshared.NoAgentKey, nil
	}

	var noGateway tunserver.ReadyGateway[grpcz.URLTarget]
	// Order is important: findCtx is derived from ctx, so a canceled incoming
	// request also makes findCtx.Err() non-nil. Check the parent first.
	switch {
	case ctx.Err() != nil: // Incoming stream canceled.
		return noGateway, nil, modshared.NoAgentKey, grpctool.StatusErrorFromContext(ctx, "request aborted")
	case findCtx.Err() != nil: // Find tunnel timed out.
		// findCtx is guaranteed to be done here, so report with the parent ctx,
		// which is still live and carries the same values.
		// The error is deliberately rebuilt with errors.New instead of passing
		// findCtx.Err() directly — presumably so it no longer matches
		// context.DeadlineExceeded and is not filtered out as a routine context
		// error by HandleProcessingError. NOTE(review): confirm before simplifying.
		p.API.HandleProcessingError(ctx, p.Log, "Finding tunnel timed out", errors.New(findCtx.Err().Error()))
		return noGateway, nil, modshared.NoAgentKey, status.Error(codes.DeadlineExceeded, "finding tunnel timed out")
	default: // This should never happen, but let's handle a non-ctx error for completeness and future-proofing.
		p.API.HandleProcessingError(ctx, p.Log, "Finding tunnel failed", err)
		return noGateway, nil, modshared.NoAgentKey, status.Errorf(codes.Unavailable, "find tunnel failed: %v", err)
	}
}

// FindTunnel looks up in p.Registry a tunnel able to serve the gRPC method of
// the incoming stream.
//
// It returns whether a tunnel was found, the per-RPC logger taken from the
// stream's context, the registry's find handle and an error. The error is
// non-nil (codes.Internal) only when the stream's context carries no gRPC
// method information. Previously that case dereferenced a nil transport
// stream and panicked.
func (p *Plugin) FindTunnel(stream grpc.ServerStream) (bool, *slog.Logger, tunserver.FindHandle, error) {
	ctx := stream.Context()
	log := modshared.RPCAPIFromContext[modshared.RPCAPI](ctx).Log()

	fullMethod, ok := grpc.MethodFromServerStream(stream)
	if !ok {
		var noHandle tunserver.FindHandle
		return false, log, noHandle, status.Error(codes.Internal, "unable to determine gRPC method of the incoming stream")
	}
	service, method := grpcz.SplitGRPCMethod(fullMethod)
	found, handle := p.Registry.FindTunnel(ctx, service, method)
	return found, log, handle, nil
}

// PrepareStreamForForwarding returns a wrapper around stream whose context
// carries the stream's incoming metadata merged with the per-RPC credentials
// obtained from p.Creds (presumably so the forwarded request carries them).
// Credential entries replace incoming entries stored under the same key.
//
// It returns a codes.Internal status error if the credentials cannot be obtained.
func (p *Plugin) PrepareStreamForForwarding(stream grpc.ServerStream) (grpc.ServerStream, error) {
	ctx := stream.Context()
	creds, err := p.Creds.GetRequestMetadata(ctx)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "credentials: %v", err)
	}
	md, _ := metadata.FromIncomingContext(ctx)
	if md == nil {
		// No incoming metadata on the context: FromIncomingContext returned a
		// nil MD, and writing credentials into a nil map would panic.
		md = make(metadata.MD, len(creds))
	}
	mergeCredentialsIntoMetadata(md, creds)
	wrappedStream := grpc_middleware.WrapServerStream(stream)
	wrappedStream.WrappedContext = metadata.NewIncomingContext(wrappedStream.WrappedContext, md)
	return wrappedStream, nil
}

// mergeCredentialsIntoMetadata stores every credential entry in md as a
// single-value entry, replacing any values already stored under that key.
//
// Keys go through metadata.MD.Set, which lower-cases them. gRPC metadata keys
// are case-insensitive and kept in lower case, so a credential key such as
// "Authorization" replaces an existing "authorization" entry instead of being
// added next to it. md must be non-nil.
func mergeCredentialsIntoMetadata(md metadata.MD, creds map[string]string) {
	for key, value := range creds {
		md.Set(key, value)
	}
}
