package server

import (
	"context"
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_agent2kas_router"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_tunnel_rpc"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testhelpers"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testlogger"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/rpc"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/grpcz"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/retry"
	"go.uber.org/mock/gomock"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// Compile-time check that *server implements rpc.ReverseTunnelServer.
var _ rpc.ReverseTunnelServer = (*server)(nil)

// TestConnectAllowsValidToken checks the happy path. When AgentInfo succeeds on
// the first call, Connect passes the stream to the tunnel handler using the
// agent's key and returns no error.
func TestConnectAllowsValidToken(t *testing.T) {
	ctrl := gomock.NewController(t)
	agentInfo := testhelpers.AgentkInfoObj()

	// NOTE(review): the trailing 1 presumably sizes the mock poller's attempts — confirm.
	rpcAPI := mock_modserver.NewMockAgentRPCAPIWithMockPoller(ctrl, 1)
	rpcAPI.EXPECT().Log().Return(testlogger.New(t)).AnyTimes()

	handler := mock_agent2kas_router.NewMockHandler(ctrl)
	srv := &server{
		tunnelHandler:          handler,
		getAgentInfoPollConfig: defaultRetryConfig(),
	}

	streamCtx := modserver.InjectAgentRPCAPI(
		grpcz.AddMaxConnectionAgeContext(context.Background(), context.Background()),
		rpcAPI,
	)
	stream := mock_tunnel_rpc.NewMockReverseTunnel_ConnectServer(ctrl)
	stream.EXPECT().Context().Return(streamCtx).MinTimes(1)

	// The agent lookup must happen before the tunnel is handed off.
	gomock.InOrder(
		rpcAPI.EXPECT().AgentInfo(gomock.Any(), gomock.Any()).Return(agentInfo, nil),
		handler.EXPECT().HandleTunnel(gomock.Any(), gomock.Any(), agentInfo.Key, gomock.Any()),
	)

	require.NoError(t, srv.Connect(stream))
}

// TestConnectRejectsInvalidToken checks that when AgentInfo fails with a plain
// (non-gRPC-status) error, Connect returns exactly that error. The handler mock
// has no expectations, so any HandleTunnel call would fail the test.
func TestConnectRejectsInvalidToken(t *testing.T) {
	ctrl := gomock.NewController(t)

	rpcAPI := mock_modserver.NewMockAgentRPCAPIWithMockPoller(ctrl, 1)
	rpcAPI.EXPECT().Log().Return(testlogger.New(t)).AnyTimes()
	rpcAPI.EXPECT().
		AgentInfo(gomock.Any(), gomock.Any()).
		Return(nil, errors.New("expected err"))

	srv := &server{
		tunnelHandler:          mock_agent2kas_router.NewMockHandler(ctrl),
		getAgentInfoPollConfig: defaultRetryConfig(),
	}

	streamCtx := modserver.InjectAgentRPCAPI(
		grpcz.AddMaxConnectionAgeContext(context.Background(), context.Background()),
		rpcAPI,
	)
	stream := mock_tunnel_rpc.NewMockReverseTunnel_ConnectServer(ctrl)
	stream.EXPECT().Context().Return(streamCtx).MinTimes(1)

	assert.EqualError(t, srv.Connect(stream), "expected err")
}

// TestConnectRetriesFailedAgentInfo checks that a codes.Unavailable error from
// AgentInfo is retried. After the second, successful lookup, Connect passes the
// stream to the tunnel handler using the agent's key and returns no error.
func TestConnectRetriesFailedAgentInfo(t *testing.T) {
	ctrl := gomock.NewController(t)
	h := mock_agent2kas_router.NewMockHandler(ctrl)
	// NOTE(review): 2 presumably lets the mock poller make two attempts (fail, then succeed) — confirm.
	mockRPCAPI := mock_modserver.NewMockAgentRPCAPIWithMockPoller(ctrl, 2)
	mockRPCAPI.EXPECT().
		Log().
		Return(testlogger.New(t)).
		AnyTimes()
	s := &server{
		tunnelHandler:          h,
		getAgentInfoPollConfig: defaultRetryConfig(),
	}
	ctx := grpcz.AddMaxConnectionAgeContext(context.Background(), context.Background())
	ctx = modserver.InjectAgentRPCAPI(ctx, mockRPCAPI)
	connectServer := mock_tunnel_rpc.NewMockReverseTunnel_ConnectServer(ctrl)
	connectServer.EXPECT().
		Context().
		Return(ctx).
		MinTimes(1)
	agentInfo := testhelpers.AgentkInfoObj()
	// First lookup fails transiently, second succeeds, and only then is the tunnel handled.
	gomock.InOrder(
		mockRPCAPI.EXPECT().
			AgentInfo(gomock.Any(), gomock.Any()).
			Return(nil, status.Error(codes.Unavailable, "unavailable")),
		mockRPCAPI.EXPECT().
			AgentInfo(gomock.Any(), gomock.Any()).
			Return(agentInfo, nil),
		h.EXPECT().
			HandleTunnel(gomock.Any(), gomock.Any(), agentInfo.Key, gomock.Any()),
	)
	err := s.Connect(connectServer)
	// require, not assert, to match the happy-path test and stop immediately on failure.
	require.NoError(t, err)
}

// defaultRetryConfig builds the poll configuration the Connect tests install on
// the server. It uses no fixed interval between attempts and an exponential
// backoff parameterised by the package-level getAgentInfo* constants.
func defaultRetryConfig() retry.PollConfig {
	backoff := retry.NewExponentialBackoff(
		getAgentInfoInitBackoff,
		getAgentInfoMaxBackoff,
		getAgentInfoResetDuration,
		getAgentInfoBackoffFactor,
		getAgentInfoJitter,
	)
	return retry.PollConfig{
		Backoff:  backoff,
		Interval: 0,
	}
}
