package test

import (
	"context"
	"log/slog"
	"net"
	"testing"
	"time"

	"buf.build/go/protovalidate"
	"github.com/ash2k/stager"
	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/agent2kas_tunnel/router"
	agent2kas_tunnel_server "gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/agent2kas_tunnel/server"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/logz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_agent2kas_router"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testhelpers"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testlogger"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/pkg/kascfg"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/rpc"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/grpcz"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/retry"
	metricnoop "go.opentelemetry.io/otel/metric/noop"
	"go.opentelemetry.io/otel/trace/noop"
	"go.uber.org/mock/gomock"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/protobuf/types/known/durationpb"
)

// serverConstructComponents wires up an in-process kas for agent2kas tunnel tests.
//
// It returns, in order:
//   - a run function that starts the tunnel server module and then the
//     agent-facing and internal gRPC servers via stager.RunStages;
//   - a client connection to the agent-facing server (see serverConstructKASConnection);
//   - a client connection to the internal server, which is also what the module's
//     AgentConnPool hands out for every agent key;
//   - the mock AgentRPCAPI given to agent-facing RPC handlers, so callers can add
//     further expectations to it.
//
// All mocks share one gomock controller bound to t. Setup failures fail t immediately.
func serverConstructComponents(ctx context.Context, t *testing.T) (func(context.Context) error, *grpc.ClientConn, *grpc.ClientConn, *mock_modserver.MockAgentRPCAPI) {
	logger := testlogger.New(t)
	ctrl := gomock.NewController(t)
	mockAPI := mock_modserver.NewMockAPI(ctrl)

	agentRPCAPI := mock_modserver.NewMockAgentRPCAPI(ctrl)
	agentRPCAPI.EXPECT().Log().Return(logger).AnyTimes()
	// PollWithBackoff is replaced by a tight loop: f is invoked again immediately,
	// with no backoff or sleeping, until it reports retry.Done.
	agentRPCAPI.EXPECT().
		PollWithBackoff(gomock.Any(), gomock.Any()).
		DoAndReturn(func(_ retry.PollConfig, f retry.PollWithBackoffFunc, _ ...retry.PollOption) error {
			err, result := f()
			for result != retry.Done {
				err, result = f()
			}
			return err
		}).
		MinTimes(1)

	validator, err := protovalidate.New()
	require.NoError(t, err)

	regBuilder := mock_agent2kas_router.NewMockRegistrationBuilder(ctrl)
	// Register may be called zero times if incoming connections arrive before tunnel connections.
	regBuilder.EXPECT().Register(gomock.Any(), gomock.Any()).AnyTimes()
	regBuilder.EXPECT().Unregister(gomock.Any()).AnyTimes()
	regBuilder.EXPECT().Do(gomock.Any()).AnyTimes()

	// Expected exactly once (gomock default).
	tracker := mock_agent2kas_router.NewMockTracker(ctrl)
	tracker.EXPECT().RegistrationBuilder().Return(regBuilder)

	agentServer := serverConstructAgentServer(ctx, validator, agentRPCAPI)
	agentServerListener := grpcz.NewPipeDialListener()

	internalListener := grpcz.NewPipeDialListener()
	tracer := noop.NewTracerProvider().Tracer("test")
	meter := metricnoop.NewMeterProvider().Meter("test")
	tunnelRegistry, err := router.NewRegistry(logger, mockAPI, tracer, meter, time.Minute, time.Minute, time.Minute, tracker)
	require.NoError(t, err)

	internalServer := serverConstructInternalServer(ctx, logger)
	internalServerConn, err := serverConstructInternalServerConn(t, internalListener.DialContext)
	require.NoError(t, err)

	serverFactory := agent2kas_tunnel_server.Factory{TunnelHandler: tunnelRegistry}
	serverModule, err := serverFactory.New(&modserver.Config{
		Log: logger,
		Config: &kascfg.ConfigurationFile{
			Agent: &kascfg.AgentCF{
				Listen: &kascfg.ListenAgentCF{
					MaxConnectionAge: durationpb.New(time.Minute),
				},
			},
		},
		AgentServer: agentServer,
		// Regardless of the agent key, connections go to the internal server.
		AgentConnPool: func(api.AgentKey) grpc.ClientConnInterface {
			return internalServerConn
		},
		Validator: validator,
	})
	require.NoError(t, err)

	kasConn, err := serverConstructKASConnection(t, agentServerListener.DialContext)
	require.NoError(t, err)

	registerTestingServer(internalServer, &serverTestingServer{registry: tunnelRegistry})

	run := func(runCtx context.Context) error {
		// Stage 1: modules.
		startModules := func(stage stager.Stage) {
			if serverModule != nil {
				stage.Go(serverModule.Run)
			}
		}
		// Stage 2: gRPC servers.
		startServers := func(stage stager.Stage) {
			serverStartAgentServer(stage, agentServer, agentServerListener)
			serverStartInternalServer(stage, internalServer, internalListener)
		}
		return stager.RunStages(runCtx, startModules, startServers)
	}
	return run, kasConn, internalServerConn, agentRPCAPI
}

// serverConstructInternalServer builds the internal gRPC server used by the test.
//
// The server forces rpc.RawCodec as its codec and installs stream and unary
// interceptors that attach a serverRPCAPIForTest (carrying the call's context and
// log) to every RPC. Only the stats handler produced by
// grpcz.MaxConnectionAge2GRPCKeepalive is installed; the keepalive server option
// it also returns is not used here.
func serverConstructInternalServer(ctx context.Context, log *slog.Logger) *grpc.Server {
	_, statsHandler := grpcz.MaxConnectionAge2GRPCKeepalive(ctx, time.Minute)
	newRPCAPI := func(streamCtx context.Context, _ string) modshared.RPCAPI {
		stub := modshared.RPCAPIStub{
			StreamCtx: streamCtx,
			Logger:    log,
		}
		return &serverRPCAPIForTest{RPCAPIStub: stub}
	}
	streamInterceptor := modshared.StreamRPCAPIInterceptor(newRPCAPI)
	unaryInterceptor := modshared.UnaryRPCAPIInterceptor(newRPCAPI)
	return grpc.NewServer(
		grpc.StatsHandler(statsHandler),
		grpc.ForceServerCodecV2(rpc.RawCodec{}),
		grpc.ChainStreamInterceptor(streamInterceptor),
		grpc.ChainUnaryInterceptor(unaryInterceptor),
	)
}

// serverConstructInternalServerConn creates a client for the internal server,
// dialing through dialContext with insecure transport credentials.
//
// Protovalidate-based validating interceptors are chained for both stream and
// unary calls. The test fails immediately if the validator cannot be created;
// any error from grpc.NewClient is returned to the caller.
func serverConstructInternalServerConn(t *testing.T, dialContext func(ctx context.Context, addr string) (net.Conn, error)) (*grpc.ClientConn, error) {
	validator, err := protovalidate.New()
	require.NoError(t, err)
	opts := []grpc.DialOption{
		grpc.WithContextDialer(dialContext),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithChainStreamInterceptor(grpctool.StreamClientValidatingInterceptor(validator)),
		grpc.WithChainUnaryInterceptor(grpctool.UnaryClientValidatingInterceptor(validator)),
	}
	return grpc.NewClient("passthrough:internal-server", opts...)
}

// serverConstructKASConnection creates a client for the agent-facing kas server,
// dialing through dialContext with insecure transport credentials.
//
// Every RPC carries token credentials built from testhelpers.AgentToken for a
// Kubernetes agent. The final `true` argument presumably permits sending the
// token over the insecure transport; confirm this in grpctool. Protovalidate-based
// validating interceptors are chained for stream and unary calls. The test fails
// immediately if the validator cannot be created; any error from grpc.NewClient
// is returned to the caller.
func serverConstructKASConnection(t *testing.T, dialContext func(ctx context.Context, addr string) (net.Conn, error)) (*grpc.ClientConn, error) {
	validator, err := protovalidate.New()
	require.NoError(t, err)
	agentCreds := grpctool.NewTokenCredentials(testhelpers.AgentToken, api.AgentTypeKubernetes, true)
	opts := []grpc.DialOption{
		grpc.WithContextDialer(dialContext),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithPerRPCCredentials(agentCreds),
		grpc.WithChainStreamInterceptor(grpctool.StreamClientValidatingInterceptor(validator)),
		grpc.WithChainUnaryInterceptor(grpctool.UnaryClientValidatingInterceptor(validator)),
	}
	return grpc.NewClient("passthrough:conn-to-kas", opts...)
}

// serverStartInternalServer starts internalServer on stage via grpctool.StartServer.
// The server accepts connections from the pre-built internalListener. Both trailing
// callbacks passed to StartServer are no-ops.
func serverStartInternalServer(stage stager.Stage, internalServer *grpc.Server, internalListener net.Listener) {
	listen := func(context.Context) (net.Listener, error) {
		return internalListener, nil
	}
	doNothing := func() {}
	grpctool.StartServer(stage, internalServer, listen, doNothing, doNothing)
}

// serverConstructAgentServer builds the agent-facing gRPC server used by the test.
//
// Both the keepalive option and the stats handler from
// grpcz.MaxConnectionAge2GRPCKeepalive are installed. For stream and unary calls,
// requests first pass through the validating interceptor (using v). The
// AgentRPCAPI interceptor then runs, and its factory always returns rpcAPI and
// never fails.
func serverConstructAgentServer(ctx context.Context, v protovalidate.Validator, rpcAPI modserver.AgentRPCAPI) *grpc.Server {
	keepaliveOpt, statsHandler := grpcz.MaxConnectionAge2GRPCKeepalive(ctx, time.Minute)
	sharedRPCAPI := func(context.Context, string) (modserver.AgentRPCAPI, error) {
		return rpcAPI, nil
	}
	// Interceptor order matters: validation runs before the RPC API is attached.
	streamChain := grpc.ChainStreamInterceptor(
		grpctool.StreamServerValidatingInterceptor(v),
		modserver.StreamAgentRPCAPIInterceptor(sharedRPCAPI),
	)
	unaryChain := grpc.ChainUnaryInterceptor(
		grpctool.UnaryServerValidatingInterceptor(v),
		modserver.UnaryAgentRPCAPIInterceptor(sharedRPCAPI),
	)
	return grpc.NewServer(grpc.StatsHandler(statsHandler), keepaliveOpt, streamChain, unaryChain)
}

// serverStartAgentServer starts agentServer on stage via grpctool.StartServer.
// The server accepts connections from the pre-built agentServerListener. Both
// trailing callbacks passed to StartServer are no-ops.
func serverStartAgentServer(stage stager.Stage, agentServer *grpc.Server, agentServerListener net.Listener) {
	listen := func(context.Context) (net.Listener, error) {
		return agentServerListener, nil
	}
	doNothing := func() {}
	grpctool.StartServer(stage, agentServer, listen, doNothing, doNothing)
}

// serverRPCAPIForTest is the modshared.RPCAPI handed to internal-server RPCs in
// tests. It embeds modshared.RPCAPIStub for the stream context and logger, and
// overrides the two error-handling hooks below.
type serverRPCAPIForTest struct {
	modshared.RPCAPIStub
}

// HandleProcessingError logs msg at error level on log, using the stream context.
// The given fields come first, followed by err as a logz.Error attribute. The
// caller's fields slice is copied, never appended to in place.
func (a *serverRPCAPIForTest) HandleProcessingError(log *slog.Logger, msg string, err error, fields ...slog.Attr) {
	attrs := make([]slog.Attr, len(fields), len(fields)+1)
	copy(attrs, fields)
	attrs = append(attrs, logz.Error(err))
	log.LogAttrs(a.StreamCtx, slog.LevelError, msg, attrs...)
}

// HandleIOError delegates to grpcz.HandleIOError and returns its result unchanged.
func (a *serverRPCAPIForTest) HandleIOError(log *slog.Logger, msg string, err error) error {
	return grpcz.HandleIOError(log, msg, err)
}
