package router

import (
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testhelpers"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testlogger"
	metricnoop "go.opentelemetry.io/otel/metric/noop"
	"go.uber.org/mock/gomock"
	"k8s.io/apimachinery/pkg/util/wait"
)

// TestAsyncTrackerBatching1 submits a single refresh() carrying
// maxBatchSize*2+3 agent keys and expects the registration builder to see
// three Refresh()+Do() rounds, in order: two with maxBatchSize keys each and
// a final one with the three leftover keys.
func TestAsyncTrackerBatching1(t *testing.T) {
	const ttl = time.Minute

	mockCtrl := gomock.NewController(t)
	sharedAPI := mock_modshared.NewMockAPI(mockCtrl)
	meter := metricnoop.NewMeterProvider().Meter("test")
	mockTracker := NewMockTracker(mockCtrl)
	builder := NewMockRegistrationBuilder(mockCtrl)

	// Two full batches plus three extra keys; IDs follow the slice index.
	keys := make([]api.AgentKey, maxBatchSize*2+3)
	for idx := range keys {
		key := &keys[idx]
		key.ID = int64(idx)
		key.Type = getAgentTypeForLot(idx)
	}

	// The slices each Refresh() call is expected to receive.
	firstBatch := keys[:maxBatchSize]
	secondBatch := keys[maxBatchSize : maxBatchSize+maxBatchSize]
	leftover := keys[maxBatchSize+maxBatchSize:]

	gomock.InOrder(
		mockTracker.EXPECT().RegistrationBuilder().Return(builder),
		builder.EXPECT().Refresh(ttl, firstBatch),
		builder.EXPECT().Do(gomock.Any()),
		builder.EXPECT().Refresh(ttl, secondBatch),
		builder.EXPECT().Do(gomock.Any()),
		builder.EXPECT().Refresh(ttl, leftover),
		builder.EXPECT().Do(gomock.Any()),
	)

	async, err := newAsyncTracker(testlogger.New(t), sharedAPI, meter, mockTracker, ttl)
	require.NoError(t, err)

	// Defers run in reverse order: done() is called first, then we wait for
	// the run goroutine to return.
	var runGroup wait.Group
	defer runGroup.Wait()
	defer async.done()
	runGroup.Start(async.run)

	async.refresh(keys)
}

// TestAsyncTrackerBatching2 queues three register() calls followed by a
// refresh() of maxBatchSize keys. The expected builder calls, in order, are:
// three Register() calls, a Refresh() with the first maxBatchSize-3 keys and a
// Do(), then a Refresh() with the last three keys and a second Do().
func TestAsyncTrackerBatching2(t *testing.T) {
	const ttl = time.Minute

	mockCtrl := gomock.NewController(t)
	sharedAPI := mock_modshared.NewMockAPI(mockCtrl)
	meter := metricnoop.NewMeterProvider().Meter("test")
	mockTracker := NewMockTracker(mockCtrl)
	builder := NewMockRegistrationBuilder(mockCtrl)

	// Exactly one batch worth of keys; IDs follow the slice index.
	keys := make([]api.AgentKey, maxBatchSize)
	for idx := range keys {
		key := &keys[idx]
		key.ID = int64(idx)
		key.Type = getAgentTypeForLot(idx)
	}
	sharesBatch := keys[:maxBatchSize-3]
	spillsOver := keys[maxBatchSize-3:]

	release := make(chan struct{})
	// Every Register() expectation parks here until the test body has called
	// refresh(), so the refresh task is already queued and gets batched
	// together with the registrations.
	holdUntilReleased := func(_ time.Duration, _ ...api.AgentKey) {
		<-release
	}

	gomock.InOrder(
		mockTracker.EXPECT().RegistrationBuilder().Return(builder),
		builder.EXPECT().Register(ttl, testhelpers.AgentkKey1).Do(holdUntilReleased),
		builder.EXPECT().Register(ttl, testhelpers.AgentwKey1).Do(holdUntilReleased),
		builder.EXPECT().Register(ttl, testhelpers.RunnerControllerKey1).Do(holdUntilReleased),
		builder.EXPECT().Refresh(ttl, sharesBatch),
		builder.EXPECT().Do(gomock.Any()),
		builder.EXPECT().Refresh(ttl, spillsOver),
		builder.EXPECT().Do(gomock.Any()),
	)

	async, err := newAsyncTracker(testlogger.New(t), sharedAPI, meter, mockTracker, ttl)
	require.NoError(t, err)

	// Defers run in reverse order: done() is called first, then we wait for
	// the run goroutine to return.
	var runGroup wait.Group
	defer runGroup.Wait()
	defer async.done()
	runGroup.Start(async.run)

	async.register(testhelpers.AgentkKey1)
	async.register(testhelpers.AgentwKey1)
	async.register(testhelpers.RunnerControllerKey1)
	async.refresh(keys)
	close(release) // unblock the Register() callbacks now that everything is queued
}

// getAgentTypeForLot picks an agent type for test key number i by cycling
// through Kubernetes, Workspace and RunnerController according to i % 3.
// A negative remainder (possible only for negative i) yields
// api.AgentTypeUnknown.
func getAgentTypeForLot(i int) api.AgentType {
	cycle := [...]api.AgentType{
		api.AgentTypeKubernetes,
		api.AgentTypeWorkspace,
		api.AgentTypeRunnerController,
	}
	// Go's % keeps the sign of the dividend, so guard against negative indexes.
	if rem := i % len(cycle); rem >= 0 {
		return cycle[rem]
	}
	return api.AgentTypeUnknown
}
