package tunserver

import (
	"context"
	"errors"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/fieldz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/grpctool/test"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/testing/mock_modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/testing/mock_rpc"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/testing/testhelpers"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/testing/testlogger"
	"go.uber.org/mock/gomock"
)

// Fixed addresses used by the tests below. selfAddr is passed to
// NewGatewayFinder as the finder's own address (see setupGatewayFinder);
// kasURLPipe is a gateway URL reported by the mocked poller.
const (
	kasURLPipe grpctool.URLTarget = "grpc://pipe"
	selfAddr   grpctool.URLTarget = "grpc://self"
)

// TestGatewayFinder_PollStartsSingleGoroutineForURL verifies that a gateway URL
// reported more than once by the poller is dialed only once: the mock expects
// exactly one Dial(kasURLPipe), so a second dial goroutine would be an
// unexpected call.
func TestGatewayFinder_PollStartsSingleGoroutineForURL(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	tf, querier, api, kasPool := setupGatewayFinder(ctx, t)

	// One Done per expected dial: selfAddr and kasURLPipe.
	var dials sync.WaitGroup
	dials.Add(2)

	// slowDial signals that a dial has started, then blocks until the test
	// context is canceled to simulate a long running dial.
	slowDial := func(grpctool.URLTarget) (grpctool.PoolConn, error) {
		dials.Done()
		<-ctx.Done()
		return nil, ctx.Err()
	}
	agentIDField := fieldz.AgentID(testhelpers.AgentID)

	gomock.InOrder(
		kasPool.EXPECT().Dial(selfAddr).DoAndReturn(slowDial),
		api.EXPECT().HandleProcessingError(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), agentIDField),
	)
	gomock.InOrder(
		querier.EXPECT().CachedGatewayURLs(testhelpers.AgentID),
		querier.EXPECT().
			PollGatewayURLs(gomock.Any(), testhelpers.AgentID, gomock.Any()).
			Do(func(pollCtx context.Context, agentID int64, cb PollGatewayURLsCallback[grpctool.URLTarget]) {
				cb([]grpctool.URLTarget{kasURLPipe})
				// Report the very same URL again; it must not start a second dial.
				cb([]grpctool.URLTarget{kasURLPipe})
				// Once both dials are in flight, stop the finder.
				dials.Wait()
				cancel()
				<-pollCtx.Done()
			}),
	)
	gomock.InOrder(
		kasPool.EXPECT().Dial(kasURLPipe).DoAndReturn(slowDial),
		api.EXPECT().HandleProcessingError(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), agentIDField),
	)

	_, err := tf.Find(ctx)
	assert.Same(t, context.Canceled, err)
	assert.Len(t, tf.connections, 2)
	for _, u := range []grpctool.URLTarget{selfAddr, kasURLPipe} {
		assert.Contains(t, tf.connections, u)
	}
}

// TestGatewayFinder_PollStartsGoroutineForEachURL verifies that every distinct
// gateway URL reported by the poller gets its own dial, in addition to the
// dial of selfAddr.
func TestGatewayFinder_PollStartsGoroutineForEachURL(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	tf, querier, api, kasPool := setupGatewayFinder(ctx, t)

	const pipe2 grpctool.URLTarget = "grpc://pipe2"

	// One Done per expected dial: selfAddr, kasURLPipe and pipe2.
	var dials sync.WaitGroup
	dials.Add(3)

	// slowDial signals that a dial has started, then blocks until the test
	// context is canceled to simulate a long running dial.
	slowDial := func(grpctool.URLTarget) (grpctool.PoolConn, error) {
		dials.Done()
		<-ctx.Done()
		return nil, ctx.Err()
	}
	agentIDField := fieldz.AgentID(testhelpers.AgentID)

	gomock.InOrder(
		kasPool.EXPECT().Dial(selfAddr).DoAndReturn(slowDial),
		api.EXPECT().HandleProcessingError(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), agentIDField),
	)
	gomock.InOrder(
		querier.EXPECT().CachedGatewayURLs(testhelpers.AgentID),
		querier.EXPECT().
			PollGatewayURLs(gomock.Any(), testhelpers.AgentID, gomock.Any()).
			Do(func(pollCtx context.Context, agentID int64, cb PollGatewayURLsCallback[grpctool.URLTarget]) {
				cb([]grpctool.URLTarget{kasURLPipe, pipe2})
				// Once all three dials are in flight, stop the finder.
				dials.Wait()
				cancel()
				<-pollCtx.Done()
			}),
	)
	kasPool.EXPECT().Dial(kasURLPipe).DoAndReturn(slowDial)
	kasPool.EXPECT().Dial(pipe2).DoAndReturn(slowDial)
	// Each of the two canceled pipe dials reports one processing error.
	api.EXPECT().
		HandleProcessingError(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), agentIDField).
		Times(2)

	_, err := tf.Find(ctx)
	assert.Same(t, context.Canceled, err)
	assert.Len(t, tf.connections, 3)
	for _, u := range []grpctool.URLTarget{selfAddr, kasURLPipe, pipe2} {
		assert.Contains(t, tf.connections, u)
	}
}

// TestGatewayFinder_StopTryingAbsentKASURL verifies that once a gateway URL no
// longer appears in the finder's URL list, its failed dial is not retried and
// its connection entry is dropped, leaving only selfAddr.
func TestGatewayFinder_StopTryingAbsentKASURL(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	tf, querier, api, kasPool := setupGatewayFinder(ctx, t)

	// One Done per expected dial: selfAddr and kasURLPipe.
	var dials sync.WaitGroup
	dials.Add(2)

	// slowDial signals that a dial has started, then blocks until the test
	// context is canceled to simulate a long running dial.
	slowDial := func(grpctool.URLTarget) (grpctool.PoolConn, error) {
		dials.Done()
		<-ctx.Done()
		return nil, ctx.Err()
	}
	// failingDial drops every known gateway URL (including kasURLPipe) under
	// the finder's lock and then fails the dial. The deferred Done runs after
	// the deferred Unlock, so the list is already cleared when it fires.
	failingDial := func(grpctool.URLTarget) (grpctool.PoolConn, error) {
		defer dials.Done()
		tf.mu.Lock()
		defer tf.mu.Unlock()
		tf.gatewayURLs = nil
		return nil, errors.New("boom")
	}
	agentIDField := fieldz.AgentID(testhelpers.AgentID)

	gomock.InOrder(
		kasPool.EXPECT().Dial(selfAddr).DoAndReturn(slowDial),
		api.EXPECT().HandleProcessingError(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), agentIDField),
	)
	gomock.InOrder(
		querier.EXPECT().CachedGatewayURLs(testhelpers.AgentID),
		querier.EXPECT().
			PollGatewayURLs(gomock.Any(), testhelpers.AgentID, gomock.Any()).
			Do(func(pollCtx context.Context, agentID int64, cb PollGatewayURLsCallback[grpctool.URLTarget]) {
				cb([]grpctool.URLTarget{kasURLPipe})
				// Once both dials have started, stop the finder.
				dials.Wait()
				cancel()
				<-pollCtx.Done()
			}),
	)
	kasPool.EXPECT().Dial(kasURLPipe).DoAndReturn(failingDial)
	api.EXPECT().HandleProcessingError(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), agentIDField)

	_, err := tf.Find(ctx)
	assert.Same(t, context.Canceled, err)
	assert.Len(t, tf.connections, 1)
	assert.Contains(t, tf.connections, selfAddr)
}

// setupGatewayFinder marks the calling test as a helper and as parallel, then
// builds a gatewayFinder wired to fresh gomock mocks. The finder is returned
// together with the mocks (URL querier, modshared API, connection pool) so the
// caller can register expectations before calling Find.
//
// The finder is constructed with selfAddr as its own address,
// testhelpers.AgentID as the agent ID, a 100ms poll config and a final 10ms
// duration argument (NOTE(review): presumably a retry/backoff interval —
// confirm against NewGatewayFinder).
func setupGatewayFinder(ctx context.Context, t *testing.T) (*gatewayFinder[grpctool.URLTarget], *MockPollingGatewayURLQuerier[grpctool.URLTarget], *mock_modshared.MockAPI, *mock_rpc.MockPoolInterface[grpctool.URLTarget]) {
	// Attribute any failure reported through this helper to the calling test.
	t.Helper()
	t.Parallel()
	ctrl := gomock.NewController(t)
	querier := NewMockPollingGatewayURLQuerier[grpctool.URLTarget](ctrl)
	api := mock_modshared.NewMockAPI(ctrl)
	kasPool := mock_rpc.NewMockPoolInterface[grpctool.URLTarget](ctrl)

	tf := NewGatewayFinder[grpctool.URLTarget](
		ctx,
		testlogger.New(t),
		kasPool,
		querier,
		api,
		test.Testing_RequestResponse_FullMethodName,
		selfAddr,
		testhelpers.AgentID,
		testhelpers.NewPollConfig(100*time.Millisecond),
		10*time.Millisecond,
	).(*gatewayFinder[grpctool.URLTarget])
	return tf, querier, api, kasPool
}
