package server

import (
	"context"
	"encoding/json"
	"io"
	"net/http"
	"sync/atomic"
	"testing"
	"time"

	"github.com/bufbuild/protovalidate-go"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	gapi "gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/gitlab/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/module/modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/module/usage_metrics"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/httpz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/testing/matcher"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/testing/mock_gitlab"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/testing/mock_modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/testing/mock_usage_metrics"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/testing/testhelpers"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/internal/tool/testing/testlogger"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v17/pkg/kascfg"
	"go.uber.org/mock/gomock"
	"google.golang.org/protobuf/types/known/durationpb"
)

// Compile-time checks that this package's types satisfy the modserver
// contracts. They produce no runtime values; a mismatch fails the build.
var (
	_ modserver.Module        = (*module)(nil)
	_ modserver.Factory       = (*Factory)(nil)
	_ modserver.ApplyDefaults = ApplyDefaults
)

// TestSendUsage covers the happy path. A non-empty usage snapshot is POSTed to
// GitLab, and the 204 response leads to the snapshot being subtracted from the
// tracker. The module then keeps cloning usage data until ctx is cancelled.
func TestSendUsage(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	counters, uniqueCounters, payload := setUpPayload(5, []int64{1, 2})
	m, tracker, _ := setupModule(t, func(w http.ResponseWriter, r *http.Request) {
		assertNoContentRequest(t, r, payload)
		w.WriteHeader(http.StatusNoContent)
	})
	snapshot := &usage_metrics.UsageData{Counters: counters, UniqueCounters: uniqueCounters}

	// First tick: hand out the snapshot, expect it to be subtracted after a successful send.
	cloneSnapshot := tracker.EXPECT().CloneUsageData().Return(snapshot)
	subtractSnapshot := tracker.EXPECT().Subtract(snapshot)
	// Second tick: nothing to report; cancel the context to stop the module.
	cloneAndStop := tracker.EXPECT().CloneUsageData().DoAndReturn(func() *usage_metrics.UsageData {
		cancel()
		return &usage_metrics.UsageData{}
	})
	// The module is expected to clone usage data once more after cancellation.
	cloneAfterStop := tracker.EXPECT().CloneUsageData().Return(&usage_metrics.UsageData{})

	gomock.InOrder(cloneSnapshot, subtractSnapshot, cloneAndStop, cloneAfterStop)
	require.NoError(t, m.Run(ctx))
}

// TestSendUsageFailureAndRetry checks the failure path. When GitLab answers the
// first usage POST with HTTP 500, the error goes to the modserver API's
// HandleProcessingError and nothing is subtracted from the tracker. The next
// tick sends a fresh snapshot, which succeeds (204) and is subtracted.
func TestSendUsageFailureAndRetry(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	uniqueCounter := []int64{1, 2}
	uniqueCounter2 := []int64{6, 7}
	counters, uniqueCounters, payload := setUpPayload(5, uniqueCounter)
	ud1 := &usage_metrics.UsageData{Counters: counters, UniqueCounters: uniqueCounters}
	counters2, uniqueCounters2, payload2 := setUpPayload(10, uniqueCounter2)
	ud2 := &usage_metrics.UsageData{Counters: counters2, UniqueCounters: uniqueCounters2}

	// The handler runs on the HTTP server's goroutines rather than the test
	// goroutine. Keep the request counter atomic instead of relying on implicit
	// synchronization between successive requests.
	var calls atomic.Int32
	m, tracker, mockAPI := setupModule(t, func(w http.ResponseWriter, r *http.Request) {
		switch call := calls.Add(1); call {
		case 1:
			assertNoContentRequest(t, r, payload)
			w.WriteHeader(http.StatusInternalServerError)
		case 2:
			assertNoContentRequest(t, r, payload2)
			w.WriteHeader(http.StatusNoContent)
		default:
			assert.Fail(t, "unexpected call", call)
		}
	})
	gomock.InOrder(
		tracker.EXPECT().
			CloneUsageData().
			Return(ud1),
		// The failed send is reported, and ud1 is not subtracted.
		mockAPI.EXPECT().
			HandleProcessingError(gomock.Any(), gomock.Any(), "Failed to send usage data", matcher.ErrorEq("HTTP status code: 500 for path /api/v4/internal/kubernetes/usage_metrics")),
		tracker.EXPECT().
			CloneUsageData().
			Return(ud2),
		tracker.EXPECT().
			Subtract(ud2),
		tracker.EXPECT().
			CloneUsageData().
			DoAndReturn(func() *usage_metrics.UsageData {
				cancel()
				return &usage_metrics.UsageData{}
			}),
		// The module is expected to clone usage data once more after cancellation.
		tracker.EXPECT().
			CloneUsageData().
			Return(&usage_metrics.UsageData{}),
	)
	require.NoError(t, m.Run(ctx))
}

// TestSendUsageHTTP verifies the HTTP exchange for a single successful usage
// report. The request must match the expected JSON payload and headers, and the
// context is cancelled as soon as the sent snapshot is subtracted.
func TestSendUsageHTTP(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	counters, uniqueCounters, payload := setUpPayload(5, []int64{1, 2})
	snapshot := &usage_metrics.UsageData{Counters: counters, UniqueCounters: uniqueCounters}

	m, tracker, _ := setupModule(t, func(w http.ResponseWriter, r *http.Request) {
		assertNoContentRequest(t, r, payload)
		w.WriteHeader(http.StatusNoContent)
	})

	cloneSnapshot := tracker.EXPECT().CloneUsageData().Return(snapshot)
	// Stop the module right after the successful send is accounted for.
	subtractAndStop := tracker.EXPECT().Subtract(snapshot).Do(func(_ *usage_metrics.UsageData) {
		cancel()
	})
	// The module is expected to clone usage data once more after cancellation.
	cloneAfterStop := tracker.EXPECT().CloneUsageData().Return(&usage_metrics.UsageData{})

	gomock.InOrder(cloneSnapshot, subtractAndStop, cloneAfterStop)
	require.NoError(t, m.Run(ctx))
}

// setupModule constructs the usage-metrics module under test. The module is
// wired to a mock usage tracker, a mock modserver API, and a GitLab client whose
// usage ping endpoint (gapi.UsagePingAPIPath) is served by handler. The usage
// reporting period is shortened to 100ms so tests tick quickly. It returns the
// module together with both mocks so callers can set expectations.
func setupModule(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*module, *mock_usage_metrics.MockUsageTrackerInterface, *mock_modserver.MockAPI) {
	ctrl := gomock.NewController(t)
	tracker := mock_usage_metrics.NewMockUsageTrackerInterface(ctrl)
	modAPI := mock_modserver.NewMockAPI(ctrl)

	cfg := &kascfg.ConfigurationFile{}
	ApplyDefaults(cfg)
	cfg.Observability.UsageReportingPeriod = durationpb.New(100 * time.Millisecond)

	validator, err := protovalidate.New()
	require.NoError(t, err)

	factory := Factory{UsageTracker: tracker}
	mod, err := factory.New(&modserver.Config{
		Log:          testlogger.New(t),
		API:          modAPI,
		Config:       cfg,
		GitLabClient: mock_gitlab.SetupClient(t, gapi.UsagePingAPIPath, handler),
		UsageTracker: tracker,
		Validator:    validator,
	})
	require.NoError(t, err)
	return mod.(*module), tracker, modAPI
}

// assertNoContentRequest checks several properties of a usage POST received by
// the fake GitLab server:
//   - it uses the POST method;
//   - it carries no Accept header;
//   - it has a JSON content type and the KAS user agent;
//   - it is JWT-signed;
//   - its JSON body equals expectedPayload.
//
// Failures are reported with assert (not require), so this is safe to call from
// an HTTP handler goroutine.
func assertNoContentRequest(t *testing.T, r *http.Request, expectedPayload any) {
	testhelpers.AssertRequestMethod(t, r, http.MethodPost)
	assert.Empty(t, r.Header[httpz.AcceptHeader])
	testhelpers.AssertRequestContentTypeJSON(t, r)
	testhelpers.AssertRequestUserAgent(t, r, testhelpers.KASUserAgent)
	testhelpers.AssertJWTSignature(t, r)

	// Round-trip the expected payload through JSON so both sides are compared as
	// generic decoded values (maps, float64 numbers, ...), not Go-typed values.
	wantJSON, err := json.Marshal(expectedPayload)
	if !assert.NoError(t, err) {
		return
	}
	var want any
	if err = json.Unmarshal(wantJSON, &want); !assert.NoError(t, err) {
		return
	}

	body, err := io.ReadAll(r.Body)
	if !assert.NoError(t, err) {
		return
	}
	var got any
	if err = json.Unmarshal(body, &got); !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, want, got)
}

// setUpPayload builds matching test fixtures for a single metric key "x". It
// returns:
//   - the counters map;
//   - the unique-counters map;
//   - the request payload that embeds both maps under "counters" and
//     "unique_counters".
//
// The payload holds the same map instances that are returned, not copies.
func setUpPayload(counter int64, uniqueCounter []int64) (map[string]int64, map[string][]int64, map[string]any) {
	counters := map[string]int64{"x": counter}
	uniqueCounters := map[string][]int64{"x": uniqueCounter}
	payload := map[string]any{
		"counters":        counters,
		"unique_counters": uniqueCounters,
	}
	return counters, uniqueCounters, payload
}
