package redistool

import (
	"context"
	"crypto/sha256"
	"errors"
	"fmt"
	"testing"
	"time"

	"github.com/redis/rueidis"
	rmock "github.com/redis/rueidis/mock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/matcher"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testhelpers"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testlogger"
	"go.uber.org/mock/gomock"
	clock_testing "k8s.io/utils/clock/testing"
)

const (
	ctxKey = 23124
)

// BenchmarkBuildTokenLimiterKey measures buildTokenLimiterKey, including its
// allocations, for a fixed prefix, a four-byte request key and a fixed minute.
func BenchmarkBuildTokenLimiterKey(b *testing.B) {
	const (
		prefix        = "pref"
		currentMinute = 42
	)
	requestKey := []byte{1, 2, 3, 4}

	// The result is stored so that the call's output is always used.
	var result string
	b.ReportAllocs()
	for b.Loop() {
		result = buildTokenLimiterKey(prefix, requestKey, currentMinute)
	}
	_ = result
}

// TestBuildTokenLimiterKey checks the exact key layout produced by
// buildTokenLimiterKey: "<prefix>:<encoded request key>:<minute>".
func TestBuildTokenLimiterKey(t *testing.T) {
	const (
		prefix        = "pref"
		currentMinute = 42
	)
	// "AQIDBA==" is the standard base64 form of the bytes {1, 2, 3, 4}.
	want := fmt.Sprintf("%s:AQIDBA==:%d", prefix, currentMinute)

	got := buildTokenLimiterKey(prefix, []byte{1, 2, 3, 4}, currentMinute)

	assert.Equal(t, want, got)
}

// TestTokenLimiterHappyPath verifies that Allow returns true when the GET for
// the bucket key reports a count of zero, and that the limiter then issues the
// MULTI/INCR/EXPIRE/EXEC sequence for that key.
func TestTokenLimiterHappyPath(t *testing.T) {
	ctx, _, client, limiter, key := setup(t)

	// Current bucket count: nothing consumed yet.
	zeroCount := rmock.Result(rmock.RedisInt64(0))
	client.EXPECT().
		Do(gomock.Any(), rmock.Match("GET", key)).
		Return(zeroCount)

	// The increment is expected as a single transaction; no return value is
	// configured, so the mock hands back its zero value.
	client.EXPECT().DoMulti(
		gomock.Any(),
		rmock.Match("MULTI"),
		rmock.Match("INCR", key),
		rmock.Match("EXPIRE", key, "59"),
		rmock.Match("EXEC"),
	)

	allowed := limiter.Allow(ctx)
	require.True(t, allowed, "Allow when no token has been consumed")
}

// TestTokenLimiterOverLimit verifies that Allow returns false when the GET for
// the bucket key already reports a count of one. No DoMulti expectation is
// registered, so any increment attempt would fail the test.
func TestTokenLimiterOverLimit(t *testing.T) {
	ctx, _, client, limiter, key := setup(t)

	oneConsumed := rmock.Result(rmock.RedisInt64(1))
	client.EXPECT().
		Do(gomock.Any(), rmock.Match("GET", key)).
		Return(oneConsumed)

	allowed := limiter.Allow(ctx)
	require.False(t, allowed, "Do not allow when a token has been consumed")
}

// TestTokenLimiterNotAllowedWhenGetError verifies that when the GET for the
// bucket key fails, Allow returns false and the error is passed to
// HandleProcessingError unchanged.
func TestTokenLimiterNotAllowedWhenGetError(t *testing.T) {
	ctx, rpcAPI, client, limiter, key := setup(t)
	connErr := errors.New("test connection error")

	client.EXPECT().
		Do(gomock.Any(), rmock.Match("GET", key)).
		Return(rmock.ErrorResult(connErr))
	rpcAPI.EXPECT().HandleProcessingError(
		"redistool.TokenLimiter: error retrieving minute bucket count",
		connErr,
	)

	allowed := limiter.Allow(ctx)
	require.False(t, allowed, "Do not allow when there is a connection error")
}

// TestTokenLimiterNotAllowedWhenIncrError verifies that when the GET succeeds
// with a zero count but the MULTI/INCR/EXPIRE/EXEC transaction yields an error
// result, Allow returns false and an error matching the original (via
// errors.Is semantics) is passed to HandleProcessingError.
func TestTokenLimiterNotAllowedWhenIncrError(t *testing.T) {
	ctx, rpcAPI, client, limiter, key := setup(t)
	connErr := errors.New("test connection error")

	client.EXPECT().
		Do(gomock.Any(), rmock.Match("GET", key)).
		Return(rmock.Result(rmock.RedisInt64(0)))

	// The transaction comes back with a single failed result.
	txResults := []rueidis.RedisResult{rmock.ErrorResult(connErr)}
	client.EXPECT().
		DoMulti(
			gomock.Any(),
			rmock.Match("MULTI"),
			rmock.Match("INCR", key),
			rmock.Match("EXPIRE", key, "59"),
			rmock.Match("EXEC"),
		).
		Return(txResults)

	rpcAPI.EXPECT().HandleProcessingError(
		"redistool.TokenLimiter: error while incrementing token key count",
		matcher.ErrorIs(connErr),
	)

	allowed := limiter.Allow(ctx)
	require.False(t, allowed, "Do not allow when there is a connection error")
}

// setup builds the shared fixture for the TokenLimiter tests.
//
// It returns:
//   - a context carrying testhelpers.AgentToken under ctxKey;
//   - the mock RPCAPI, whose Log() may be called any number of times;
//   - the mock Redis client;
//   - a TokenLimiter created with prefix "key_prefix" and a third argument of
//     1, with its clock replaced by a fake fixed at Unix time 100s;
//   - the Redis key the limiter is expected to use for that token and clock.
//
// Each time the limiter asks for an RPCAPI, one RequestKey() expectation is
// registered on the mock, returning the hash of the token found in the context.
func setup(t *testing.T) (context.Context, *MockRPCAPI, *rmock.Client, *TokenLimiter, string) {
	ctrl := gomock.NewController(t)
	client := rmock.NewClient(ctrl)

	rpcAPI := NewMockRPCAPI(ctrl)
	rpcAPI.EXPECT().Log().Return(testlogger.New(t)).AnyTimes()

	limiter := NewTokenLimiter(client, "key_prefix", 1,
		func(ctx context.Context) RPCAPI {
			token := ctx.Value(ctxKey).(api.AgentToken)
			rpcAPI.EXPECT().RequestKey().Return(agentToken2key(token))
			return rpcAPI
		})
	limiter.clock = clock_testing.NewFakePassiveClock(time.Unix(100, 0))

	ctx := context.WithValue(context.Background(), ctxKey, testhelpers.AgentToken) //nolint: staticcheck

	// Derive the expected key from the same inputs the limiter will see.
	minute := byte(limiter.clock.Now().UTC().Minute())
	requestKey := agentToken2key(testhelpers.AgentToken)
	key := buildTokenLimiterKey(limiter.keyPrefix, requestKey, minute)

	return ctx, rpcAPI, client, limiter, key
}

// agentToken2key returns the 32-byte SHA-256 digest of the agent token, which
// these tests use as the request key.
func agentToken2key(token api.AgentToken) []byte {
	h := sha256.New()
	h.Write([]byte(token)) // hash.Hash.Write never returns an error
	return h.Sum(nil)
}
