package router

import (
	"context"
	"errors"
	"io"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_rpc"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_tunnel_tunserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testlogger"
	tunrpc "gitlab.com/gitlab-org/cluster-integration/tunnel/rpc"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tunserver"
	"go.uber.org/mock/gomock"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/mem"
	"google.golang.org/grpc/status"
)

var (
	_ Forwarder           = (*SimpleRPCForwarder)(nil)
	_ tunserver.Forwarder = (*SimpleRPCForwarder)(nil)
)

// TestSupportsServiceAndMethod verifies that the forwarder only claims support
// for the exact service/method pair reported by the tunnel's transport stream.
func TestSupportsServiceAndMethod(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tunnelStream := mock_rpc.NewMockServerStream(mockCtrl)
	transport := mock_rpc.NewMockServerTransportStream(mockCtrl)

	// Both expectations are registered before the constructor runs and are
	// each expected exactly once.
	transport.EXPECT().Method().Return("/gitlab.service1/DoSomething")
	tunnelCtx := grpc.NewContextWithServerTransportStream(context.Background(), transport)
	tunnelStream.EXPECT().Context().Return(tunnelCtx)

	fwd := NewSimpleRPCForwarder(tunnelStream, make(chan error, 1))

	cases := []struct {
		service, method string
		want            bool
	}{
		{"gitlab.service1", "DoSomething", true},
		{"gitlab.service1", "OtherMethod", false},
		{"other.service", "DoSomething", false},
	}
	for _, c := range cases {
		assert.Equal(t, c.want, fwd.SupportsServiceAndMethod(c.service, c.method))
	}
}

// TestPipeIncomingToTunnel_EOF checks that a clean io.EOF on the incoming
// stream ends the pipe without reporting an error to either side.
func TestPipeIncomingToTunnel_EOF(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tunnelStream := mock_rpc.NewMockServerStream(mockCtrl)
	incoming := mock_rpc.NewMockServerStream(mockCtrl)
	transport := mock_rpc.NewMockServerTransportStream(mockCtrl)

	transport.EXPECT().Method().Return("/gitlab.service1/DoSomething")
	tunnelCtx := grpc.NewContextWithServerTransportStream(context.Background(), transport)
	tunnelStream.EXPECT().Context().Return(tunnelCtx)
	fwd := NewSimpleRPCForwarder(tunnelStream, make(chan error, 1))

	// The very first read reports end of stream; no SendMsg is expected on the tunnel.
	incoming.EXPECT().RecvMsg(gomock.Any()).Return(io.EOF)

	forTunnel, forIncoming := fwd.pipeIncomingToTunnel(testlogger.New(t), incoming)
	assert.NoError(t, forTunnel)
	assert.NoError(t, forIncoming)
}

// TestPipeIncomingToTunnel_RecvError checks that a read failure on the incoming
// stream is returned unchanged for both the tunnel and the incoming stream.
func TestPipeIncomingToTunnel_RecvError(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tunnelStream := mock_rpc.NewMockServerStream(mockCtrl)
	incoming := mock_rpc.NewMockServerStream(mockCtrl)
	transport := mock_rpc.NewMockServerTransportStream(mockCtrl)

	transport.EXPECT().Method().Return("/gitlab.service1/DoSomething")
	tunnelCtx := grpc.NewContextWithServerTransportStream(context.Background(), transport)
	tunnelStream.EXPECT().Context().Return(tunnelCtx)
	fwd := NewSimpleRPCForwarder(tunnelStream, make(chan error, 1))

	wantErr := status.Error(codes.Internal, "recv error")
	incoming.EXPECT().RecvMsg(gomock.Any()).Return(wantErr)

	forTunnel, forIncoming := fwd.pipeIncomingToTunnel(testlogger.New(t), incoming)
	assert.Equal(t, wantErr, forTunnel)
	assert.Equal(t, wantErr, forIncoming)
}

// TestPipeTunnelToCallback_EOF checks that io.EOF from the tunnel ends the pipe
// cleanly without invoking the data callback.
func TestPipeTunnelToCallback_EOF(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tunnelStream := mock_rpc.NewMockServerStream(mockCtrl)
	callback := mock_tunnel_tunserver.NewMockDataCallback(mockCtrl)
	transport := mock_rpc.NewMockServerTransportStream(mockCtrl)

	transport.EXPECT().Method().Return("/gitlab.service1/DoSomething")
	// MinTimes(1): Context() may be consulted again while piping, not only by the constructor.
	tunnelCtx := grpc.NewContextWithServerTransportStream(context.Background(), transport)
	tunnelStream.EXPECT().Context().Return(tunnelCtx).MinTimes(1)
	fwd := NewSimpleRPCForwarder(tunnelStream, make(chan error, 1))

	tunnelStream.EXPECT().RecvMsg(gomock.Any()).Return(io.EOF)

	forTunnel, forIncoming := fwd.pipeTunnelToCallback(testlogger.New(t), callback)
	assert.NoError(t, forTunnel)
	assert.NoError(t, forIncoming)
}

// TestPipeTunnelToCallback_RecvError checks that a read failure on the tunnel
// is returned unchanged for both sides and the data callback is never invoked.
func TestPipeTunnelToCallback_RecvError(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tunnelStream := mock_rpc.NewMockServerStream(mockCtrl)
	callback := mock_tunnel_tunserver.NewMockDataCallback(mockCtrl)
	transport := mock_rpc.NewMockServerTransportStream(mockCtrl)

	transport.EXPECT().Method().Return("/gitlab.service1/DoSomething")
	tunnelCtx := grpc.NewContextWithServerTransportStream(context.Background(), transport)
	tunnelStream.EXPECT().Context().Return(tunnelCtx).MinTimes(1)
	fwd := NewSimpleRPCForwarder(tunnelStream, make(chan error, 1))

	wantErr := status.Error(codes.Internal, "tunnel recv error")
	tunnelStream.EXPECT().RecvMsg(gomock.Any()).Return(wantErr)

	forTunnel, forIncoming := fwd.pipeTunnelToCallback(testlogger.New(t), callback)
	assert.Equal(t, wantErr, forTunnel)
	assert.Equal(t, wantErr, forIncoming)
}

// TestPipeTunnelToCallback_CallbackError checks that when the data callback
// rejects a frame received from the tunnel, the callback's error is returned
// for both the tunnel and the incoming stream.
func TestPipeTunnelToCallback_CallbackError(t *testing.T) {
	ctrl := gomock.NewController(t)
	tunnel := mock_rpc.NewMockServerStream(ctrl)
	cb := mock_tunnel_tunserver.NewMockDataCallback(ctrl)
	sts := mock_rpc.NewMockServerTransportStream(ctrl)

	ctx := grpc.NewContextWithServerTransportStream(context.Background(), sts)
	tunnel.EXPECT().Context().Return(ctx).MinTimes(1)
	sts.EXPECT().Method().Return("/gitlab.service1/DoSomething")

	tunnelRetErr := make(chan error, 1)
	forwarder := NewSimpleRPCForwarder(tunnel, tunnelRetErr)

	// A successful read that leaves the frame empty is enough to reach the
	// callback; no payload is needed for this scenario.
	tunnel.EXPECT().RecvMsg(gomock.Any()).Return(nil)

	expectedErr := errors.New("callback error")
	cb.EXPECT().Message(gomock.Any()).Return(expectedErr)

	errTunnel, errIncoming := forwarder.pipeTunnelToCallback(testlogger.New(t), cb)
	assert.Equal(t, expectedErr, errTunnel)
	assert.Equal(t, expectedErr, errIncoming)
}

// TestErrPair_IsNil checks that errPair.isNil reports true only when neither
// the tunnel-side nor the incoming-stream-side error is set.
func TestErrPair_IsNil(t *testing.T) {
	cases := []struct {
		desc string
		in   errPair
		want bool
	}{
		{desc: "both nil", in: errPair{}, want: true},
		{desc: "tunnel error", in: errPair{forTunnel: errors.New("error")}, want: false},
		{desc: "incoming error", in: errPair{forIncomingStream: errors.New("error")}, want: false},
		{
			desc: "both errors",
			in:   errPair{forTunnel: errors.New("error1"), forIncomingStream: errors.New("error2")},
			want: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			assert.Equal(t, tc.want, tc.in.isNil())
		})
	}
}

// TestPipeIncomingToTunnel_ForwardsMessage checks that a frame read from the
// incoming stream is sent to the tunnel with its payload unchanged, and that the
// subsequent io.EOF ends the pipe cleanly.
func TestPipeIncomingToTunnel_ForwardsMessage(t *testing.T) {
	ctrl := gomock.NewController(t)
	tunnel := mock_rpc.NewMockServerStream(ctrl)
	incomingStream := mock_rpc.NewMockServerStream(ctrl)
	sts := mock_rpc.NewMockServerTransportStream(ctrl)

	ctx := grpc.NewContextWithServerTransportStream(context.Background(), sts)
	tunnel.EXPECT().Context().Return(ctx)
	sts.EXPECT().Method().Return("/gitlab.service1/DoSomething")

	tunnelRetErr := make(chan error, 1)
	forwarder := NewSimpleRPCForwarder(tunnel, tunnelRetErr)

	testData := []byte{1, 2, 3, 4, 5}

	// Incoming stream yields one frame, then EOF.
	gomock.InOrder(
		incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(m any) error {
			frame := m.(*tunrpc.RawFrame)
			// SliceBuffer is not pooled, so the forwarder freeing it cannot hand
			// the test-owned testData slice to a shared buffer pool.
			frame.Data = mem.BufferSlice{mem.SliceBuffer(testData)}
			return nil
		}),
		incomingStream.EXPECT().RecvMsg(gomock.Any()).Return(io.EOF),
	)

	// Copy the payload while SendMsg is running: once it returns, the frame and
	// its buffers are back in the forwarder's hands and may be freed or reused.
	var sentData []byte
	tunnel.EXPECT().SendMsg(gomock.Any()).DoAndReturn(func(m any) error {
		sentData = m.(*tunrpc.RawFrame).Data.Materialize()
		return nil
	})

	errTunnel, errIncoming := forwarder.pipeIncomingToTunnel(testlogger.New(t), incomingStream)

	require.NoError(t, errTunnel)
	require.NoError(t, errIncoming)
	require.NotNil(t, sentData)
	assert.Equal(t, testData, sentData)
}

// TestPipeTunnelToCallback_ForwardsMessage checks that the payload of a frame
// read from the tunnel is handed to the data callback unchanged, and that the
// subsequent io.EOF ends the pipe cleanly.
func TestPipeTunnelToCallback_ForwardsMessage(t *testing.T) {
	ctrl := gomock.NewController(t)
	tunnel := mock_rpc.NewMockServerStream(ctrl)
	cb := mock_tunnel_tunserver.NewMockDataCallback(ctrl)
	sts := mock_rpc.NewMockServerTransportStream(ctrl)

	ctx := grpc.NewContextWithServerTransportStream(context.Background(), sts)
	tunnel.EXPECT().Context().Return(ctx).MinTimes(1)
	sts.EXPECT().Method().Return("/gitlab.service1/DoSomething")

	tunnelRetErr := make(chan error, 1)
	forwarder := NewSimpleRPCForwarder(tunnel, tunnelRetErr)

	testData := []byte{10, 20, 30, 40, 50}

	// Tunnel yields one frame, then EOF.
	gomock.InOrder(
		tunnel.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(m any) error {
			// Fill in the frame the forwarder passed in, as a real RecvMsg would.
			// SliceBuffer is not pooled, so freeing it cannot hand the test-owned
			// testData slice to a shared buffer pool.
			frame := m.(*tunrpc.RawFrame)
			frame.Data = mem.BufferSlice{mem.SliceBuffer(testData)}
			return nil
		}),
		tunnel.EXPECT().RecvMsg(gomock.Any()).Return(io.EOF),
	)

	// Copy the payload inside the callback: data may alias a buffer that the
	// forwarder frees or reuses once Message returns.
	var forwardedData []byte
	cb.EXPECT().Message(gomock.Any()).DoAndReturn(func(data []byte) error {
		forwardedData = append([]byte(nil), data...)
		return nil
	})

	errTunnel, errIncoming := forwarder.pipeTunnelToCallback(testlogger.New(t), cb)

	require.NoError(t, errTunnel)
	require.NoError(t, errIncoming)
	assert.Equal(t, testData, forwardedData)
}
