package rpc_test

import (
	"context"
	"io"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/agent_configuration/rpc"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/matcher"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_rpc"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testhelpers"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testlogger"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/pkg/agentcfg"
	"go.uber.org/mock/gomock"
)

// revision1 is the commit ID attached to the first configuration the mocked
// stream sends; tests also expect it to be echoed back as the resume point.
const revision1 = "rev12341234"

// revision2 is the commit ID attached to a later, distinct configuration.
const revision2 = "rev123412341"

func TestConfigurationWatcher(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	ctrl := gomock.NewController(t)
	client := mock_rpc.NewMockAgentConfigurationClient(ctrl)
	configStream := mock_rpc.NewMockAgentConfiguration_GetConfigurationClient(ctrl)
	cfg1 := &agentcfg.AgentConfiguration{
		Observability: &agentcfg.ObservabilityCF{
			Logging: &agentcfg.LoggingCF{
				Level: agentcfg.LogLevelEnum_info,
			},
		},
	}
	cfg2 := &agentcfg.AgentConfiguration{}
	gomock.InOrder(
		client.EXPECT().
			GetConfiguration(gomock.Any(), matcher.ProtoEq(t, &rpc.ConfigurationRequest{}), gomock.Any()).
			Return(configStream, nil),
		configStream.EXPECT().
			Recv().
			Return(&rpc.ConfigurationResponse{
				Configuration: cfg1,
				CommitId:      revision1,
			}, nil),
		configStream.EXPECT().
			Recv().
			Return(&rpc.ConfigurationResponse{
				Configuration: cfg2,
				CommitId:      revision2,
			}, nil),
		configStream.EXPECT().
			Recv().
			DoAndReturn(func() (*rpc.ConfigurationResponse, error) {
				cancel()
				return nil, context.Canceled
			}),
	)
	w := rpc.ConfigurationWatcher{
		Log:                testlogger.New(t),
		Client:             client,
		PollConfig:         testhelpers.NewPollConfig(0),
		ConfigPreProcessor: func(data rpc.ConfigurationData) error { return nil },
	}
	iter := 0
	w.Watch(ctx, func(ctx context.Context, config rpc.ConfigurationData) {
		switch iter {
		case 0:
			matcher.AssertProtoEqual(t, config.Config, cfg1)
		case 1:
			matcher.AssertProtoEqual(t, config.Config, cfg2)
		default:
			t.Fatal(iter)
		}
		iter++
	})
	assert.Equal(t, 2, iter)
}

// TestConfigurationWatcher_ResumeConnection checks that after the first stream
// ends with io.EOF the watcher opens a second stream whose request carries the
// commit ID of the last configuration it received.
func TestConfigurationWatcher_ResumeConnection(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ctrl := gomock.NewController(t)
	client := mock_rpc.NewMockAgentConfigurationClient(ctrl)
	firstStream := mock_rpc.NewMockAgentConfiguration_GetConfigurationClient(ctrl)
	secondStream := mock_rpc.NewMockAgentConfiguration_GetConfigurationClient(ctrl)

	// The initial request has no commit ID; the resumed one must carry the
	// commit ID delivered on the first stream.
	initialReq := &rpc.ConfigurationRequest{}
	resumeReq := &rpc.ConfigurationRequest{CommitId: revision1}

	gomock.InOrder(
		client.EXPECT().
			GetConfiguration(gomock.Any(), matcher.ProtoEq(t, initialReq), gomock.Any()).
			Return(firstStream, nil),
		firstStream.EXPECT().Recv().Return(&rpc.ConfigurationResponse{
			Configuration: &agentcfg.AgentConfiguration{},
			CommitId:      revision1,
		}, nil),
		firstStream.EXPECT().Recv().Return(nil, io.EOF),
		client.EXPECT().
			GetConfiguration(gomock.Any(), matcher.ProtoEq(t, resumeReq), gomock.Any()).
			Return(secondStream, nil),
		secondStream.EXPECT().Recv().DoAndReturn(func() (*rpc.ConfigurationResponse, error) {
			// Stop the watcher once the resumed stream is being read.
			cancel()
			return nil, context.Canceled
		}),
	)

	watcher := rpc.ConfigurationWatcher{
		Log:                testlogger.New(t),
		Client:             client,
		PollConfig:         testhelpers.NewPollConfig(0),
		ConfigPreProcessor: func(data rpc.ConfigurationData) error { return nil },
	}
	watcher.Watch(ctx, func(context.Context, rpc.ConfigurationData) {
		// Configuration contents are irrelevant here; only the reconnect
		// request is under test.
	})
}

// TestConfigurationWatcher_ImmediateReconnectOnEOF checks that when the
// configuration stream ends with io.EOF the watcher reconnects right away,
// without waiting for the poll interval, and resumes from the last received
// commit ID.
//
// The poll interval is deliberately huge. With an interval of 0 (as the other
// tests use) an immediate reconnect is indistinguishable from a delayed one.
// The context deadline bounds the test if the reconnect is not immediate; in
// that case gomock reports the missing resume GetConfiguration call instead of
// the test hanging.
func TestConfigurationWatcher_ImmediateReconnectOnEOF(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	ctrl := gomock.NewController(t)
	client := mock_rpc.NewMockAgentConfigurationClient(ctrl)
	configStream1 := mock_rpc.NewMockAgentConfiguration_GetConfigurationClient(ctrl)
	configStream2 := mock_rpc.NewMockAgentConfiguration_GetConfigurationClient(ctrl)
	cfg1 := &agentcfg.AgentConfiguration{
		Observability: &agentcfg.ObservabilityCF{
			Logging: &agentcfg.LoggingCF{
				Level: agentcfg.LogLevelEnum_info,
			},
		},
	}
	gomock.InOrder(
		client.EXPECT().
			GetConfiguration(gomock.Any(), matcher.ProtoEq(t, &rpc.ConfigurationRequest{}), gomock.Any()).
			Return(configStream1, nil),
		configStream1.EXPECT().
			Recv().
			Return(&rpc.ConfigurationResponse{
				Configuration: cfg1,
				CommitId:      revision1,
			}, nil),
		configStream1.EXPECT().
			Recv().
			Return(nil, io.EOF), // must trigger an immediate retry, not a poll-interval wait
		client.EXPECT().
			GetConfiguration(gomock.Any(), matcher.ProtoEq(t, &rpc.ConfigurationRequest{
				CommitId: revision1,
			}), gomock.Any()).
			Return(configStream2, nil),
		configStream2.EXPECT().
			Recv().
			DoAndReturn(func() (*rpc.ConfigurationResponse, error) {
				cancel()
				return nil, context.Canceled
			}),
	)
	w := rpc.ConfigurationWatcher{
		Log:    testlogger.New(t),
		Client: client,
		// Far longer than the context deadline: only an immediate reconnect
		// can satisfy the second GetConfiguration expectation in time.
		PollConfig:         testhelpers.NewPollConfig(time.Hour),
		ConfigPreProcessor: func(data rpc.ConfigurationData) error { return nil },
	}
	calls := 0
	w.Watch(ctx, func(ctx context.Context, config rpc.ConfigurationData) {
		matcher.AssertProtoEqual(t, config.Config, cfg1)
		calls++
	})
	// Only the first stream delivered a configuration.
	assert.Equal(t, 1, calls)
}
