package server

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/agent_tracker"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/agent_tracker/rpc"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_agent_tracker"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testhelpers"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testlogger"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/version"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/pkg/entity/agentk"
	"go.uber.org/mock/gomock"
)

// Compile-time assertion that *Factory implements modserver.Factory.
var _ modserver.Factory = (*Factory)(nil)

// TestServer_GetConnectedAgentksByAgentIDs verifies that the server looks up
// connections for each requested agent ID via the tracker and returns them in
// request order.
func TestServer_GetConnectedAgentksByAgentIDs(t *testing.T) {
	_, mockTracker, s, ctx := setupServer(t)

	agentID1 := testhelpers.AgentkKey1.ID
	// Use a distinct second agent ID. Reusing AgentkKey1 for both agents would make the
	// request contain a duplicate ID and the two tracker expectations indistinguishable,
	// so the test would not really cover two different agents.
	agentID2 := agentID1 + 1
	projectID1 := testhelpers.ProjectID
	projectID2 := testhelpers.ProjectID + 1
	req := &rpc.GetConnectedAgentksByAgentIDsRequest{
		AgentIds: []int64{agentID1, agentID2},
	}

	mockTracker.EXPECT().
		GetAgentkConnectionsByID(ctx, agentID1).
		Return(func(yield func(*agent_tracker.ConnectedAgentkInfo) bool) {
			yield(&agent_tracker.ConnectedAgentkInfo{
				ConnectionId: 123123123,
				AgentId:      agentID1,
				ProjectId:    projectID1,
			})
		})
	mockTracker.EXPECT().
		GetAgentkConnectionsByID(ctx, agentID2).
		Return(func(yield func(*agent_tracker.ConnectedAgentkInfo) bool) {
			yield(&agent_tracker.ConnectedAgentkInfo{
				ConnectionId: 456456456,
				AgentId:      agentID2,
				ProjectId:    projectID2,
			})
		})

	resp, err := s.GetConnectedAgentksByAgentIDs(ctx, req)
	require.NoError(t, err)
	// require (not assert) so a wrong length stops the test before the indexing below panics.
	require.Len(t, resp.Agents, 2)
	assert.Equal(t, agentID1, resp.Agents[0].AgentId)
	assert.Equal(t, agentID2, resp.Agents[1].AgentId)
}

// TestServer_GetConnectedAgentksByAgentIDs_WithAgentWarning verifies that a
// connection whose agentk reports version 13.0.0 (a major version above the
// 12.3.4 server configured by setupServer) carries the expected version warning
// in the response.
func TestServer_GetConnectedAgentksByAgentIDs_WithAgentWarning(t *testing.T) {
	_, mockTracker, s, ctx := setupServer(t)

	agentID1 := testhelpers.AgentkKey1.ID
	// Use a distinct second agent ID so the request does not contain a duplicate ID
	// and each tracker expectation corresponds to a different agent.
	agentID2 := agentID1 + 1
	projectID1 := testhelpers.ProjectID
	projectID2 := testhelpers.ProjectID + 1
	req := &rpc.GetConnectedAgentksByAgentIDsRequest{
		AgentIds: []int64{agentID1, agentID2},
	}

	mockTracker.EXPECT().
		GetAgentkConnectionsByID(ctx, agentID1).
		Return(func(yield func(*agent_tracker.ConnectedAgentkInfo) bool) {
			yield(&agent_tracker.ConnectedAgentkInfo{
				AgentMeta: &agentk.Meta{
					// This is a major version higher than the server
					Version: "13.0.0",
				},
				ConnectionId: 123123123,
				AgentId:      agentID1,
				ProjectId:    projectID1,
			})
		})
	mockTracker.EXPECT().
		GetAgentkConnectionsByID(ctx, agentID2).
		Return(func(yield func(*agent_tracker.ConnectedAgentkInfo) bool) {
			yield(&agent_tracker.ConnectedAgentkInfo{
				ConnectionId: 456456456,
				AgentId:      agentID2,
				ProjectId:    projectID2,
			})
		})

	resp, err := s.GetConnectedAgentksByAgentIDs(ctx, req)
	require.NoError(t, err)
	// require (not assert) so missing entries fail cleanly instead of panicking on the indexing below.
	require.Len(t, resp.Agents, 2)
	assert.Equal(t, agentID1, resp.Agents[0].AgentId)
	require.NotEmpty(t, resp.Agents[0].Warnings)
	assert.Equal(t, "The agent server for Kubernetes (KAS) version cannot be checked for compatibility. KAS and the agent for Kubernetes (agentk) might not be compatible. Make sure agentk and KAS have the same version.", resp.Agents[0].Warnings[0].GetVersion().Message)
	assert.Equal(t, agentID2, resp.Agents[1].AgentId)
}

// TestServer_CountAgentsByAgentVersions checks that the per-version agent counts
// reported by the tracker are passed through in the response.
func TestServer_CountAgentsByAgentVersions(t *testing.T) {
	rpcAPI, tracker, srv, ctx := setupServer(t)

	// gomock's default Times(1) means the handler must request a logger exactly once.
	rpcAPI.EXPECT().
		Log().
		Return(testlogger.New(t))

	tracker.EXPECT().
		CountAgentksByAgentVersions(ctx).
		DoAndReturn(func(context.Context) (map[string]int64, error) {
			return map[string]int64{
				"16.8.0": 111,
				"16.9.0": 222,
			}, nil
		})

	resp, err := srv.CountAgentsByAgentVersions(ctx, &rpc.CountAgentsByAgentVersionsRequest{})
	require.NoError(t, err)

	versions := resp.AgentVersions
	assert.Len(t, versions, 2)
	assert.EqualValues(t, 111, versions["16.8.0"])
	assert.EqualValues(t, 222, versions["16.9.0"])
}

// setupServer builds a server under test wired to fresh gomock mocks.
//
// It returns, in order: the RPC API mock, the agent tracker mock (also used as
// the server's agentQuerier), the server itself (reporting version 12.3.4 with
// "12.3.4" as its only known GitLab release), and a context that carries the
// RPC API mock via modshared.InjectRPCAPI.
func setupServer(t *testing.T) (*mock_modserver.MockRPCAPI, *mock_agent_tracker.MockTracker, *server, context.Context) {
	ctrl := gomock.NewController(t)
	tracker := mock_agent_tracker.NewMockTracker(ctrl)
	rpcAPI := mock_modserver.NewMockRPCAPI(ctrl)

	srv := &server{
		agentQuerier:       tracker,
		serverVersion:      version.Version{Major: 12, Minor: 3, Patch: 4},
		gitlabReleasesList: []string{"12.3.4"},
	}

	return rpcAPI, tracker, srv, modshared.InjectRPCAPI(context.Background(), rpcAPI)
}
