package router

import (
	"context"
	"errors"
	"log/slog"
	"strconv"
	"strings"
	"time"

	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/fieldz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/logz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/nettool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/retry"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tunnel/tunserver"
	"go.opentelemetry.io/otel/attribute"
	otelcodes "go.opentelemetry.io/otel/codes"
	otelmetric "go.opentelemetry.io/otel/metric"
	"go.opentelemetry.io/otel/trace"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
	"k8s.io/utils/clock"
)

// Compile-time assertion that *Plugin satisfies tunserver.RouterPlugin for nettool.MultiURL gateways.
var (
	_ tunserver.RouterPlugin[nettool.MultiURL] = (*Plugin)(nil)
)

const (
	// routingHopPrefix is a metadata key prefix that is used for metadata keys that should be consumed by
	// the gateway kas instances and not passed along to agentk.
	routingHopPrefix = "kas-hop-"
	// routingAgentIDMetadataKey is used to pass the destination agent id in request metadata
	// from the routing kas instance, that is handling the incoming request, to the gateway kas instance,
	// that is forwarding the request to an agent.
	routingAgentIDMetadataKey = routingHopPrefix + "routing-agent-id"
	// routingAgentTypeMetadataKey is used to pass the destination agent type in request metadata
	// from the routing kas instance, that is handling the incoming request, to the gateway kas instance,
	// that is forwarding the request to an agent.
	routingAgentTypeMetadataKey = routingHopPrefix + "routing-agent-type"

	// Names of the metric instruments created by constructRoutingMetrics:
	// one duration histogram and two timeout counters.
	routingDurationMetricName                = "tunnel_routing_duration"
	routingTunnelTimeoutConnectedRecently    = "tunnel_routing_timeout_connected_recently"
	routingTunnelTimeoutNotConnectedRecently = "tunnel_routing_timeout_not_connected_recently"

	// routingStatusAttributeName is the attribute key under which the routing outcome
	// ("success" or "aborted") is attached to routing duration measurements.
	routingStatusAttributeName attribute.Key = "status"
)

// Record options for the routing duration histogram, built once at package initialization.
// Each carries a single "status" attribute describing how the routing attempt ended.
var (
	routingStatusSuccessMetricOpts = []otelmetric.RecordOption{otelmetric.WithAttributeSet(attribute.NewSet(routingStatusAttributeName.String("success")))}
	routingStatusAbortedMetricOpts = []otelmetric.RecordOption{otelmetric.WithAttributeSet(attribute.NewSet(routingStatusAttributeName.String("aborted")))}
)

// AgentFinder provides agent registration data that the router uses to explain
// why finding a tunnel timed out.
type AgentFinder interface {
	// AgentLastConnected returns the time the agent identified by agentKey was last connected.
	// Plugin.FindReadyGateway interprets a zero time.Time as "the agent hasn't connected recently".
	AgentLastConnected(ctx context.Context, agentKey api.AgentKey) (time.Time, error)
}

// TunnelFinder is the tunnel lookup dependency used by Plugin.FindTunnel.
type TunnelFinder interface {
	// FindTunnel starts a lookup of a tunnel to the agent identified by agentKey for the given
	// gRPC service and method. Plugin.FindTunnel passes both results through to its caller unchanged.
	// NOTE(review): the boolean presumably reports whether a tunnel was found immediately - confirm
	// against the implementation.
	FindTunnel(ctx context.Context, agentKey api.AgentKey, service, method string) (bool, tunserver.FindHandle)
}

// gatewayFinderFactory provides an indirection point for testing.
// It builds the tunserver.GatewayFinder that Plugin.FindReadyGateway uses for a single routing attempt.
type gatewayFinderFactory func(outgoingCtx context.Context, log *slog.Logger, fullMethod string, agentKey api.AgentKey) tunserver.GatewayFinder[nettool.MultiURL]

// Plugin is the tunserver.RouterPlugin implementation for gateways addressed by nettool.MultiURL.
// It reads the destination agent key from request metadata, finds a ready gateway or a local tunnel
// for that agent, and strips routing-only metadata before a stream is forwarded.
type Plugin struct {
	// api is used to report processing errors.
	api                                       modshared.API
	// agentFinder is queried after a find timeout to learn when the agent was last connected.
	agentFinder                               AgentFinder
	// tunnelRegistry performs the tunnel lookup for FindTunnel.
	tunnelRegistry                            TunnelFinder
	// tracer creates the span that covers FindReadyGateway.
	tracer                                    trace.Tracer
	// tunnelFindTimeout bounds how long FindReadyGateway waits for a gateway.
	tunnelFindTimeout                         time.Duration
	// clock is used to measure routing duration and the time since the agent last connected.
	clock                                     clock.PassiveClock
	// gwFactory builds the gateway finder for each FindReadyGateway call.
	gwFactory                                 gatewayFinderFactory
	// routingDuration records, in seconds, how long successful and aborted routing attempts took.
	routingDuration                           otelmetric.Float64Histogram
	// routingTimeoutConnectedRecentlyCounter counts find timeouts for agents with a known last connection time.
	routingTimeoutConnectedRecentlyCounter    otelmetric.Int64Counter
	// routingTimeoutNotConnectedRecentlyCounter counts the remaining find timeouts.
	routingTimeoutNotConnectedRecentlyCounter otelmetric.Int64Counter
}

// NewPlugin builds a Plugin for production use: it measures time with the real clock and
// creates gateway finders with tunserver.NewGatewayFinder, configured from kasPool, gatewayQuerier,
// ownPrivateAPIURL, pollConfig and tryNewGatewayInterval.
//
// It returns an error if the routing metrics cannot be created on meter.
func NewPlugin(
	modAPI modshared.API,
	kasPool grpctool.PoolInterface[nettool.MultiURL],
	gatewayQuerier tunserver.PollingGatewayURLQuerier[nettool.MultiURL],
	agentFinder AgentFinder,
	tunnelRegistry TunnelFinder,
	tracer trace.Tracer,
	meter otelmetric.Meter,
	ownPrivateAPIURL nettool.MultiURL,
	pollConfig retry.PollConfigFactory,
	tryNewGatewayInterval time.Duration,
	tunnelFindTimeout time.Duration,
) (*Plugin, error) {
	// A fresh gateway finder is created for every routing attempt; only the per-request
	// values (context, logger, method, agent key) vary between calls.
	var newFinder gatewayFinderFactory = func(outgoingCtx context.Context, log *slog.Logger, fullMethod string, key api.AgentKey) tunserver.GatewayFinder[nettool.MultiURL] {
		return tunserver.NewGatewayFinder[nettool.MultiURL](
			outgoingCtx,
			log,
			kasPool,
			gatewayQuerier,
			modAPI,
			fullMethod,
			ownPrivateAPIURL,
			key,
			pollConfig,
			tryNewGatewayInterval,
		)
	}

	return newPlugin(modAPI, agentFinder, tunnelRegistry, tracer, meter, tunnelFindTimeout, clock.RealClock{}, newFinder)
}

// newPlugin is the constructor behind NewPlugin. It takes the clock and the gateway finder
// factory as parameters so that callers can substitute them.
//
// It returns an error if the routing metrics cannot be created on meter.
func newPlugin(
	api modshared.API,
	agentFinder AgentFinder,
	tunnelRegistry TunnelFinder,
	tracer trace.Tracer,
	meter otelmetric.Meter,
	tunnelFindTimeout time.Duration,
	clock clock.PassiveClock,
	gwFactory gatewayFinderFactory,
) (*Plugin, error) {
	durationHist, timeoutConnectedRecently, timeoutNotConnectedRecently, err := constructRoutingMetrics(meter)
	if err != nil {
		return nil, err
	}

	p := &Plugin{
		api:               api,
		agentFinder:       agentFinder,
		tunnelRegistry:    tunnelRegistry,
		tracer:            tracer,
		tunnelFindTimeout: tunnelFindTimeout,
		clock:             clock,
		gwFactory:         gwFactory,
	}
	p.routingDuration = durationHist
	p.routingTimeoutConnectedRecentlyCounter = timeoutConnectedRecently
	p.routingTimeoutNotConnectedRecentlyCounter = timeoutNotConnectedRecently
	return p, nil
}

// FindReadyGateway finds a gateway that is ready to forward method to the agent identified by
// the routing metadata (routingAgentIDMetadataKey, routingAgentTypeMetadataKey) attached to the
// outgoing context ctx. The search is limited to tunnelFindTimeout and is covered by a
// "router.findReadyGateway" span.
//
// On success it returns the ready gateway, log enriched with the agent key, the agent key and a
// nil error, and records the routing duration with status "success".
//
// On failure the gateway, logger and agent key are zero values and the error is one of:
//   - the gRPC status error from agentKeyFromMeta, when the routing metadata is missing or invalid;
//   - the error built from ctx by grpctool.StatusErrorFromContext, when ctx is done
//     (the routing duration is recorded with status "aborted");
//   - codes.DeadlineExceeded, when the find timeout expired. The message and the incremented
//     counter depend on whether a last connection time is known for the agent;
//   - codes.Unavailable, for any other error returned by the gateway finder.
func (p *Plugin) FindReadyGateway(ctx context.Context, log *slog.Logger, method string) (tunserver.ReadyGateway[nettool.MultiURL], *slog.Logger, api.AgentKey, error) {
	startRouting := p.clock.Now()
	findCtx, span := p.tracer.Start(ctx, "router.findReadyGateway", trace.WithSpanKind(trace.SpanKindInternal))
	defer span.End()

	// A nil MD is fine here: Get on it yields no values and agentKeyFromMeta rejects that.
	md, _ := metadata.FromOutgoingContext(ctx)
	agentKey, err := agentKeyFromMeta(md.Get(routingAgentIDMetadataKey), md.Get(routingAgentTypeMetadataKey))
	if err != nil {
		span.SetStatus(otelcodes.Error, "")
		span.RecordError(err)
		return tunserver.ReadyGateway[nettool.MultiURL]{}, nil, api.AgentKey{}, err // returns gRPC status error
	}

	log = log.With(logz.AgentKey(agentKey))
	// The finder receives the caller's context; only the Find call itself is subject to the timeout.
	gf := p.gwFactory(ctx, log, method, agentKey)
	findCtx, findCancel := context.WithTimeout(findCtx, p.tunnelFindTimeout)
	defer findCancel()

	rg, err := gf.Find(findCtx)
	if err != nil {
		switch { // Order is important here.
		case ctx.Err() != nil: // Incoming stream canceled.
			p.routingDuration.Record( //nolint: contextcheck
				context.Background(),
				float64(p.clock.Since(startRouting))/float64(time.Second),
				routingStatusAbortedMetricOpts...,
			)
			span.SetStatus(otelcodes.Error, "Aborted")
			span.RecordError(ctx.Err())
			return tunserver.ReadyGateway[nettool.MultiURL]{}, nil, api.AgentKey{}, grpctool.StatusErrorFromContext(ctx, "request aborted")
		case findCtx.Err() != nil: // Find tunnel timed out.
			findCtxErr := findCtx.Err()

			errCtx, cancel := context.WithTimeout(context.WithoutCancel(findCtx), 1*time.Second) // preserve context values for tracing
			defer cancel()
			// NOTE: we know at this point that we failed to find a tunnel within the timeout.
			// However, we don't know if the agent is simply not connected or if we fail to find the gateway.
			// We have the chance here to correlate the error with the data from the agent registration.
			lastConnected, lastConnErr := p.agentFinder.AgentLastConnected(errCtx, agentKey)
			if lastConnErr != nil {
				p.api.HandleProcessingError(findCtx, log, "Unable to correlate tunnel timeout error with agent registration data", lastConnErr, fieldz.AgentKey(agentKey))
				// we can continue here as we still want to report the original error.
			} else if !lastConnected.IsZero() {
				// IsZero is used instead of comparing with time.Time{} because == on time.Time also
				// compares the location, so a zero instant with a location set would not compare equal.
				//
				// we determined that the agent is registered and should be connected, but yet we failed to find the gateway for it.
				p.routingTimeoutConnectedRecentlyCounter.Add(context.Background(), 1) //nolint: contextcheck
				span.SetStatus(otelcodes.Error, "Timed out, agent connected recently")
				span.RecordError(findCtxErr)

				// Dividing by time.Second leaves the number of whole seconds in the Duration's integer value.
				lastConnAgo := p.clock.Since(lastConnected) / time.Second
				p.api.HandleProcessingError(findCtx, log, "Finding tunnel timed out", errors.New("agent was recently connected"),
					fieldz.AgentKey(agentKey), fieldz.LastConnectedAt(lastConnected), fieldz.SecondsAgo(lastConnAgo))
				return tunserver.ReadyGateway[nettool.MultiURL]{}, nil, api.AgentKey{}, status.Errorf(codes.DeadlineExceeded, "finding tunnel timed out. Agent last connected %d second(s) ago", lastConnAgo)
			}
			// Reached when the lookup failed or no last connection time is known.
			p.routingTimeoutNotConnectedRecentlyCounter.Add(context.Background(), 1) //nolint: contextcheck
			span.SetStatus(otelcodes.Error, "Timed out, agent not connected recently")
			span.RecordError(findCtxErr)

			log.Debug("Finding tunnel timed out. Agent hasn't connected recently", logz.Error(findCtxErr), logz.AgentKey(agentKey))
			return tunserver.ReadyGateway[nettool.MultiURL]{}, nil, api.AgentKey{}, status.Error(codes.DeadlineExceeded, "finding tunnel timed out. Agent hasn't connected recently. Make sure the agent is connected and up to date")
		default: // This should never happen, but let's handle a non-ctx error for completeness and future-proofing.
			span.SetStatus(otelcodes.Error, "Failed")
			span.RecordError(err)

			p.api.HandleProcessingError(findCtx, log, "Finding tunnel failed", err, fieldz.AgentKey(agentKey))
			return tunserver.ReadyGateway[nettool.MultiURL]{}, nil, api.AgentKey{}, status.Errorf(codes.Unavailable, "find tunnel failed: %v", err)
		}
	}
	p.routingDuration.Record( //nolint: contextcheck
		context.Background(),
		float64(p.clock.Since(startRouting))/float64(time.Second),
		routingStatusSuccessMetricOpts...,
	)
	span.SetStatus(otelcodes.Ok, "")
	return rg, log, agentKey, nil
}

// FindTunnel looks up a tunnel for the agent named by the routing metadata
// (routingAgentIDMetadataKey, routingAgentTypeMetadataKey) on the incoming context of stream,
// for the gRPC service and method that stream is serving.
//
// It returns the two results of the tunnel registry lookup together with the logger from rpcAPI
// enriched with the agent key. The returned error is a gRPC status error: the one from
// agentKeyFromMeta when the routing metadata is missing or invalid, or codes.Internal when the
// stream's context carries no gRPC transport stream.
func (p *Plugin) FindTunnel(stream grpc.ServerStream, rpcAPI modshared.RPCAPI) (bool, *slog.Logger, tunserver.FindHandle, error) {
	ctx := stream.Context()
	agentKey, err := agentKeyFromMeta(
		metadata.ValueFromIncomingContext(ctx, routingAgentIDMetadataKey),
		metadata.ValueFromIncomingContext(ctx, routingAgentTypeMetadataKey),
	)
	if err != nil {
		return false, nil, nil, err
	}
	sts := grpc.ServerTransportStreamFromContext(ctx)
	if sts == nil {
		// ServerTransportStreamFromContext returns nil when ctx is not an RPC invocation context.
		// Report that as an error instead of panicking on the nil value below.
		return false, nil, nil, status.Error(codes.Internal, "gRPC transport stream not found in context")
	}
	service, method := grpctool.SplitGRPCMethod(sts.Method())
	log := rpcAPI.Log().With(logz.AgentKey(agentKey))
	found, handle := p.tunnelRegistry.FindTunnel(ctx, agentKey, service, method)
	return found, log, handle, nil
}

// PrepareStreamForForwarding returns stream wrapped so that its context carries the incoming
// metadata with every routingHopPrefix key removed. Those keys are meant for the gateway kas
// and must not be passed along to agentk. The returned error is always nil.
func (p *Plugin) PrepareStreamForForwarding(stream grpc.ServerStream) (grpc.ServerStream, error) {
	sanitized, _ := metadata.FromIncomingContext(stream.Context())
	removeHopMeta(sanitized)

	// Replace the incoming metadata on the wrapped stream's context with the sanitized copy.
	wrapped := grpc_middleware.WrapServerStream(stream)
	forwardCtx := metadata.NewIncomingContext(wrapped.WrappedContext, sanitized)
	wrapped.WrappedContext = forwardCtx
	return wrapped, nil
}

// removeHopMeta deletes, in place, every entry of md whose key starts with routingHopPrefix.
func removeHopMeta(md metadata.MD) {
	for key := range md {
		if !strings.HasPrefix(key, routingHopPrefix) {
			continue
		}
		// Deleting the current key while ranging over a map is safe in Go.
		delete(md, key)
	}
}

// SetRoutingMetadata stores the id and type of agentKey in md under routingAgentIDMetadataKey
// and routingAgentTypeMetadataKey, replacing any existing values for those keys, and returns md.
// A nil md is replaced with a newly allocated one, so callers must use the returned value.
func SetRoutingMetadata(md metadata.MD, agentKey api.AgentKey) metadata.MD {
	idValue := strconv.FormatInt(agentKey.ID, 10)
	typeValue := agentKey.Type.String()

	if md == nil {
		md = metadata.MD{}
	}
	md[routingAgentIDMetadataKey] = []string{idValue}
	md[routingAgentTypeMetadataKey] = []string{typeValue}
	return md
}

// agentKeyFromMeta builds an agent key from the metadata values of the agent id and agent type
// keys. It returns a zero key and a gRPC status error if either value is invalid.
func agentKeyFromMeta(idVal []string, typeVal []string) (api.AgentKey, error) {
	var (
		key api.AgentKey
		err error
	)
	// The id is validated first, so its error takes precedence when both values are bad.
	if key.ID, err = agentIDFromMeta(idVal); err != nil {
		return api.AgentKey{}, err
	}
	if key.Type, err = agentTypeFromMeta(typeVal); err != nil {
		return api.AgentKey{}, err
	}
	return key, nil
}

// agentIDFromMeta parses the agent id from the metadata values of routingAgentIDMetadataKey.
// It returns a codes.InvalidArgument gRPC status error unless val holds exactly one
// base-10 integer that fits into an int64.
func agentIDFromMeta(val []string) (int64, error) {
	if n := len(val); n != 1 {
		return 0, status.Errorf(codes.InvalidArgument, "expecting a single %s, got %d", routingAgentIDMetadataKey, n)
	}

	id, parseErr := strconv.ParseInt(val[0], 10, 64)
	if parseErr == nil {
		return id, nil
	}
	// The parse error is not included in the returned status.
	return 0, status.Errorf(codes.InvalidArgument, "invalid %s", routingAgentIDMetadataKey)
}

// agentTypeFromMeta parses the agent type from the metadata values of routingAgentTypeMetadataKey.
// When no value is present the type defaults to api.AgentTypeKubernetes. It returns a
// codes.InvalidArgument gRPC status error if there is more than one value, if the value cannot
// be parsed, or if the resulting type is api.AgentTypeUnknown.
func agentTypeFromMeta(val []string) (api.AgentType, error) {
	var parsed api.AgentType
	switch n := len(val); n {
	case 0:
		// No type supplied: fall back to the default.
		parsed = api.AgentTypeKubernetes
	case 1:
		var err error
		if parsed, err = api.ParseAgentType(val[0]); err != nil {
			return parsed, status.Error(codes.InvalidArgument, err.Error())
		}
	default:
		return api.AgentTypeUnknown, status.Errorf(codes.InvalidArgument, "expecting a single %s, got %d", routingAgentTypeMetadataKey, n)
	}

	if parsed == api.AgentTypeUnknown {
		return api.AgentTypeUnknown, status.Errorf(codes.InvalidArgument, "invalid %s", routingAgentTypeMetadataKey)
	}
	return parsed, nil
}

// constructRoutingMetrics creates the router's instruments on dm and returns them in this order:
// the routing duration histogram (unit "s"), the counter of timeouts for agents that connected
// recently, and the counter of timeouts for agents that haven't been connected recently.
// If any instrument cannot be created, all instruments are nil and the error is returned.
func constructRoutingMetrics(dm otelmetric.Meter) (otelmetric.Float64Histogram, otelmetric.Int64Counter, otelmetric.Int64Counter, error) {
	// Both timeout counters differ only in name and description.
	newTimeoutCounter := func(name, description string) (otelmetric.Int64Counter, error) {
		return dm.Int64Counter(name, otelmetric.WithDescription(description))
	}

	durationHist, err := dm.Float64Histogram(
		routingDurationMetricName,
		otelmetric.WithUnit("s"),
		otelmetric.WithDescription("The time it takes the tunnel router to find a suitable tunnel in seconds"),
		otelmetric.WithExplicitBucketBoundaries(0.001, 0.004, 0.016, 0.064, 0.256, 1.024, 4.096, 16.384),
	)
	if err != nil {
		return nil, nil, nil, err
	}

	connectedRecently, err := newTimeoutCounter(
		routingTunnelTimeoutConnectedRecently,
		"The total number of times routing timed out for agents that recently connected but no connection was found to use for the tunnel",
	)
	if err != nil {
		return nil, nil, nil, err
	}

	notConnectedRecently, err := newTimeoutCounter(
		routingTunnelTimeoutNotConnectedRecently,
		"The total number of times routing timed out for agents that haven't been connected recently",
	)
	if err != nil {
		return nil, nil, nil, err
	}

	return durationHist, connectedRecently, notConnectedRecently, nil
}
