package tunclient

import (
	"context"
	"errors"
	"io"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/matcher"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_rpc"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_tunnel_rpc"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testhelpers"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testlogger"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tunnel/info"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tunnel/rpc"
	"go.uber.org/mock/gomock"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// Compile-time check that *Connection satisfies ConnectionInterface.
var _ ConnectionInterface = (*Connection)(nil)

// TestPropagateUntil_Stop checks that after stop is called, canceling the parent
// context no longer cancels the derived context, while the derived context's own
// cancel func still works.
func TestPropagateUntil_Stop(t *testing.T) {
	parent, cancelParentCtx := context.WithCancel(context.Background())
	child, cancelChild, stopPropagation := propagateUntil(parent)
	stopPropagation()

	// Yield briefly so the Go runtime schedules the other goroutine and it exits
	// before the parent is canceled; without this the test could flake.
	time.Sleep(10 * time.Millisecond)

	cancelParentCtx()
	require.NoError(t, child.Err(), "Unexpected context cancellation")

	cancelChild()
	assert.Equal(t, context.Canceled, child.Err())
}

// TestPropagateUntil_NoStop checks that when stop is never called, an already
// canceled parent context causes the derived context to be canceled too.
func TestPropagateUntil_NoStop(t *testing.T) {
	ctxParent, cancelParent := context.WithCancel(context.Background())
	cancelParent()
	ctx, cancel, _ := propagateUntil(ctxParent)
	defer cancel()
	// Bound the wait so a broken propagation fails this test promptly instead of
	// hanging until the global `go test` timeout.
	select {
	case <-ctx.Done():
	case <-time.After(5 * time.Second):
		require.FailNow(t, "Parent cancellation was not propagated")
	}
	// The derived context has no deadline, so a closed Done must mean Canceled.
	assert.ErrorIs(t, ctx.Err(), context.Canceled)
}

// TestPropagateUntil_PreservesValues checks that values stored in the parent
// context are visible through the derived context.
func TestPropagateUntil_PreservesValues(t *testing.T) {
	type key string
	const k = key("123")

	parent := context.WithValue(context.Background(), k, 42)
	child, cancelChild, _ := propagateUntil(parent)
	defer cancelChild()

	assert.Equal(t, 42, child.Value(k))
}

// TestConnectUnblocksIfNotStartedStreaming checks that canceling the context passed
// to attempt unblocks a pending Connect call before any streaming has started, and
// that attempt reports the wrapped cancellation error.
func TestConnectUnblocksIfNotStartedStreaming(t *testing.T) {
	attemptCtx, cancelAttempt := context.WithCancel(context.Background())
	defer cancelAttempt()
	client, _, _, c := setupConnection(t)

	// The stub cancels the caller's context and then blocks on the context Connect
	// received; it can only return if that cancellation reaches Connect.
	blockingConnect := func(connectCtx context.Context, _ ...grpc.CallOption) (grpc.BidiStreamingClient[rpc.ConnectRequest, rpc.ConnectResponse], error) {
		cancelAttempt()
		<-connectCtx.Done()
		return nil, connectCtx.Err()
	}
	client.EXPECT().
		Connect(gomock.Any(), gomock.Any()).
		DoAndReturn(blockingConnect)

	require.EqualError(t, c.attempt(attemptCtx), "Connect(): context canceled")
}

// TestNoErrorOnEOFAfterRequestInfo checks that attempt returns nil when the tunnel's
// RecvMsg yields io.EOF right after delivering rpc.RequestInfo.
// The visitor can get io.EOF after rpc.RequestInfo if the client sent an rpc.Error,
// which was forwarded to the tunnel, and the tunnel then closed the stream.
func TestNoErrorOnEOFAfterRequestInfo(t *testing.T) {
	ctrl := gomock.NewController(t)
	ctx := t.Context()
	client, conn, tunnel, c := setupConnection(t)
	clientStream := mock_rpc.NewMockClientStream(ctrl)

	// Client-stream side: Header fails, then one tunnel Send (payload not checked
	// here) and CloseSend are expected. InOrder makes this Send eligible only after
	// Header, so it cannot match the descriptor Send in the group below.
	gomock.InOrder(
		clientStream.EXPECT().
			Header().
			Return(nil, errors.New("header err")),
		tunnel.EXPECT().Send(gomock.Any()),
		tunnel.EXPECT().CloseSend(),
	)

	// Tunnel side: Connect, send the descriptor, receive RequestInfo, open the
	// client stream via NewStream, then the tunnel reports io.EOF.
	gomock.InOrder(
		client.EXPECT().
			Connect(gomock.Any(), gomock.Any()).
			Return(tunnel, nil),
		tunnel.EXPECT().
			Send(gomock.Any()),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Do(testhelpers.RecvMsg(&rpc.ConnectResponse{
				Msg: &rpc.ConnectResponse_RequestInfo{RequestInfo: reqInfo()},
			})),
		conn.EXPECT().
			NewStream(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
			Return(clientStream, nil),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Return(io.EOF),
	)

	// io.EOF at this point must not be reported as an error.
	err := c.attempt(ctx)
	require.NoError(t, err)
}

// TestNoErrorOnEOFAfterMessage checks that attempt returns nil when the tunnel's
// RecvMsg yields io.EOF after an rpc.Message has been forwarded to the client stream.
// The visitor can get io.EOF after rpc.Message if the client sent an rpc.Error,
// which was forwarded to the tunnel, and the tunnel then closed the stream.
func TestNoErrorOnEOFAfterMessage(t *testing.T) {
	ctrl := gomock.NewController(t)
	ctx := t.Context()
	client, conn, tunnel, c := setupConnection(t)
	clientStream := mock_rpc.NewMockClientStream(ctrl)

	// Client-stream side: Header fails, then one tunnel Send (payload not checked
	// here) and CloseSend are expected. InOrder makes this Send eligible only after
	// Header, so it cannot match the descriptor Send in the group below.
	gomock.InOrder(
		clientStream.EXPECT().
			Header().
			Return(nil, errors.New("header err")),
		tunnel.EXPECT().Send(gomock.Any()),
		tunnel.EXPECT().CloseSend(),
	)

	// Tunnel side: Connect, send the descriptor, receive RequestInfo, open the
	// client stream, receive one Message that is passed on via clientStream.SendMsg,
	// then the tunnel reports io.EOF.
	gomock.InOrder(
		client.EXPECT().
			Connect(gomock.Any(), gomock.Any()).
			Return(tunnel, nil),
		tunnel.EXPECT().
			Send(gomock.Any()),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Do(testhelpers.RecvMsg(&rpc.ConnectResponse{
				Msg: &rpc.ConnectResponse_RequestInfo{RequestInfo: reqInfo()},
			})),
		conn.EXPECT().
			NewStream(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
			Return(clientStream, nil),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Do(testhelpers.RecvMsg(&rpc.ConnectResponse{
				Msg: &rpc.ConnectResponse_Message{
					Message: &rpc.Message{Data: []byte{1, 2, 3}},
				},
			})),
		clientStream.EXPECT().
			SendMsg(gomock.Any()),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Return(io.EOF),
	)

	// io.EOF at this point must not be reported as an error.
	err := c.attempt(ctx)
	require.NoError(t, err)
}

// TestNoTrailerAfterHeaderError checks that when the client stream's Header call
// fails, the failure is forwarded to the tunnel as an rpc.Error carrying the same
// status, followed by CloseSend. No Trailer is requested from the client stream
// (the mock has no Trailer expectation, so such a call would fail the test), and
// attempt returns nil once the tunnel reports io.EOF.
func TestNoTrailerAfterHeaderError(t *testing.T) {
	ctrl := gomock.NewController(t)
	ctx := t.Context()
	client, conn, tunnel, c := setupConnection(t)
	clientStream := mock_rpc.NewMockClientStream(ctrl)

	// Closed by the CloseSend stub. The tunnel's final RecvMsg waits on it, so io.EOF
	// is only reported after the error was forwarded and the send side was closed.
	done := make(chan struct{})

	headerErr := status.Error(codes.InvalidArgument, "expected header err")
	// Client-stream side: Header fails -> rpc.Error sent to tunnel -> CloseSend.
	gomock.InOrder(
		clientStream.EXPECT().
			Header().
			Return(nil, headerErr),
		tunnel.EXPECT().
			Send(matcher.ProtoEq(nil, &rpc.ConnectRequest{
				Msg: &rpc.ConnectRequest_Error{
					Error: &rpc.Error{
						Status: status.Convert(headerErr).Proto(),
					},
				},
			})),
		tunnel.EXPECT().
			CloseSend().
			Do(func() error {
				close(done)
				return nil
			}),
	)

	// Tunnel side: Connect, descriptor Send, RequestInfo, NewStream, then a RecvMsg
	// that blocks until done is closed and returns io.EOF.
	gomock.InOrder(
		client.EXPECT().
			Connect(gomock.Any(), gomock.Any()).
			Return(tunnel, nil),
		tunnel.EXPECT().
			Send(gomock.Any()),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Do(testhelpers.RecvMsg(&rpc.ConnectResponse{
				Msg: &rpc.ConnectResponse_RequestInfo{RequestInfo: reqInfo()},
			})),
		conn.EXPECT().
			NewStream(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
			Return(clientStream, nil),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			DoAndReturn(func(m any) error {
				<-done
				return io.EOF
			}),
	)

	err := c.attempt(ctx)
	require.NoError(t, err)
}

// TestTrailerAfterRecvMsgEOF checks the normal end of a client stream. The client
// stream's header is forwarded to the tunnel, and its RecvMsg then returns io.EOF.
// The trailer metadata is forwarded as rpc.Trailer and the send side is closed;
// no rpc.Error is sent. attempt returns nil once the tunnel reports io.EOF.
func TestTrailerAfterRecvMsgEOF(t *testing.T) {
	ctrl := gomock.NewController(t)
	ctx := t.Context()
	client, conn, tunnel, c := setupConnection(t)
	clientStream := mock_rpc.NewMockClientStream(ctrl)

	// Closed by the CloseSend stub. The tunnel's final RecvMsg waits on it, so io.EOF
	// is only reported after the trailer was forwarded and the send side was closed.
	done := make(chan struct{})

	// Client-stream side. Header() has no Return, so the mock yields zero values,
	// which is why an empty rpc.Header is expected on the tunnel.
	gomock.InOrder(
		clientStream.EXPECT().
			Header(),
		tunnel.EXPECT().
			Send(matcher.ProtoEq(nil, &rpc.ConnectRequest{
				Msg: &rpc.ConnectRequest_Header{Header: &rpc.Header{}},
			}, mock_tunnel_rpc.EquateMetadataKV())),
		clientStream.EXPECT().
			RecvMsg(gomock.Any()).
			Return(io.EOF),
		clientStream.EXPECT().
			Trailer().
			Return(metadata.MD{"abc": []string{"a", "b"}}),
		tunnel.EXPECT().
			Send(matcher.ProtoEq(nil, &rpc.ConnectRequest{
				Msg: &rpc.ConnectRequest_Trailer{
					Trailer: &rpc.Trailer{
						Meta: []*rpc.MetadataKV{
							mock_tunnel_rpc.NewMetadataKV("abc", "a", "b"),
						},
					},
				},
			}, mock_tunnel_rpc.EquateMetadataKV())),
		tunnel.EXPECT().
			CloseSend().
			Do(func() error {
				close(done)
				return nil
			}),
	)

	// Tunnel side: Connect, descriptor Send, RequestInfo, NewStream, then a RecvMsg
	// that blocks until done is closed and returns io.EOF.
	gomock.InOrder(
		client.EXPECT().
			Connect(gomock.Any(), gomock.Any()).
			Return(tunnel, nil),
		tunnel.EXPECT().
			Send(gomock.Any()),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Do(testhelpers.RecvMsg(&rpc.ConnectResponse{
				Msg: &rpc.ConnectResponse_RequestInfo{RequestInfo: reqInfo()},
			})),
		conn.EXPECT().
			NewStream(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
			Return(clientStream, nil),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			DoAndReturn(func(m any) error {
				<-done
				return io.EOF
			}),
	)

	err := c.attempt(ctx)
	require.NoError(t, err)
}

// TestTrailerAndErrorAfterRecvMsgError checks what happens when the client stream's
// RecvMsg fails with a non-EOF error. The trailer is forwarded first, then the
// error as rpc.Error carrying the same status, and the send side is closed.
// attempt still returns nil once the tunnel reports io.EOF.
func TestTrailerAndErrorAfterRecvMsgError(t *testing.T) {
	ctrl := gomock.NewController(t)
	ctx := t.Context()
	client, conn, tunnel, c := setupConnection(t)
	clientStream := mock_rpc.NewMockClientStream(ctrl)

	// Closed by the CloseSend stub. The tunnel's final RecvMsg waits on it, so io.EOF
	// is only reported after trailer and error were forwarded and the send side closed.
	done := make(chan struct{})

	recvErr := status.Error(codes.InvalidArgument, "expected RecvMsg err")
	// Client-stream side: header -> RecvMsg error -> trailer -> rpc.Error -> CloseSend.
	gomock.InOrder(
		clientStream.EXPECT().
			Header(),
		tunnel.EXPECT().
			Send(matcher.ProtoEq(nil, &rpc.ConnectRequest{
				Msg: &rpc.ConnectRequest_Header{Header: &rpc.Header{}},
			}, mock_tunnel_rpc.EquateMetadataKV())),
		clientStream.EXPECT().
			RecvMsg(gomock.Any()).
			Return(recvErr),
		clientStream.EXPECT().
			Trailer().
			Return(metadata.MD{"abc": []string{"a", "b"}}),
		tunnel.EXPECT().
			Send(matcher.ProtoEq(nil, &rpc.ConnectRequest{
				Msg: &rpc.ConnectRequest_Trailer{
					Trailer: &rpc.Trailer{
						Meta: []*rpc.MetadataKV{
							mock_tunnel_rpc.NewMetadataKV("abc", "a", "b"),
						},
					},
				},
			}, mock_tunnel_rpc.EquateMetadataKV())),
		tunnel.EXPECT().
			Send(matcher.ProtoEq(nil, &rpc.ConnectRequest{
				Msg: &rpc.ConnectRequest_Error{
					Error: &rpc.Error{
						Status: status.Convert(recvErr).Proto(),
					},
				},
			})),
		tunnel.EXPECT().
			CloseSend().
			Do(func() error {
				close(done)
				return nil
			}),
	)

	// Tunnel side: Connect, descriptor Send, RequestInfo, NewStream, then a RecvMsg
	// that blocks until done is closed and returns io.EOF.
	gomock.InOrder(
		client.EXPECT().
			Connect(gomock.Any(), gomock.Any()).
			Return(tunnel, nil),
		tunnel.EXPECT().
			Send(gomock.Any()),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Do(testhelpers.RecvMsg(&rpc.ConnectResponse{
				Msg: &rpc.ConnectResponse_RequestInfo{RequestInfo: reqInfo()},
			})),
		conn.EXPECT().
			NewStream(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
			Return(clientStream, nil),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			DoAndReturn(func(m any) error {
				<-done
				return io.EOF
			}),
	)

	err := c.attempt(ctx)
	require.NoError(t, err)
}

// TestRecvMsgUnblocksIfNotStartedStreaming checks that while attempt is still
// waiting for the first message from the tunnel (streaming not started yet),
// canceling attempt's context cancels the context Connect received and unblocks
// the pending RecvMsg.
func TestRecvMsgUnblocksIfNotStartedStreaming(t *testing.T) {
	attemptCtx, cancelAttempt := context.WithCancel(context.Background())
	defer cancelAttempt()
	client, _, tunnel, c := setupConnection(t)

	// Captured in the Connect stub so the RecvMsg stub can block on it.
	var tunnelCtx context.Context

	connectCall := client.EXPECT().
		Connect(gomock.Any(), gomock.Any()).
		DoAndReturn(func(ctx context.Context, _ ...grpc.CallOption) (grpc.BidiStreamingClient[rpc.ConnectRequest, rpc.ConnectResponse], error) {
			tunnelCtx = ctx //nolint: fatcontext
			return tunnel, nil
		})
	descriptorSend := tunnel.EXPECT().
		Send(gomock.Any())
	blockingRecv := tunnel.EXPECT().
		RecvMsg(gomock.Any()).
		DoAndReturn(func(_ any) error {
			cancelAttempt()
			<-tunnelCtx.Done()
			return tunnelCtx.Err()
		})
	gomock.InOrder(connectCall, descriptorSend, blockingRecv)

	err := c.attempt(attemptCtx)
	require.EqualError(t, err, "context canceled")
}

// TestContextIgnoredIfStartedStreaming checks that once a request has been received
// and the client stream opened, canceling the context passed to attempt no longer
// cancels the context that Connect received. attempt then returns the tunnel's
// RecvMsg error.
func TestContextIgnoredIfStartedStreaming(t *testing.T) {
	ctrl := gomock.NewController(t)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	client, conn, tunnel, c := setupConnection(t)
	clientStream := mock_rpc.NewMockClientStream(ctrl)

	gomock.InOrder(
		clientStream.EXPECT().
			Header().
			Return(nil, errors.New("header err")),
		tunnel.EXPECT().Send(gomock.Any()),
		tunnel.EXPECT().CloseSend(),
	)
	var connectCtx context.Context

	gomock.InOrder(
		client.EXPECT().
			Connect(gomock.Any(), gomock.Any()).
			DoAndReturn(func(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[rpc.ConnectRequest, rpc.ConnectResponse], error) {
				connectCtx = ctx //nolint: fatcontext
				return tunnel, nil
			}),
		tunnel.EXPECT().
			Send(gomock.Any()),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Do(testhelpers.RecvMsg(&rpc.ConnectResponse{
				Msg: &rpc.ConnectResponse_RequestInfo{RequestInfo: reqInfo()},
			})),
		conn.EXPECT().
			NewStream(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
			Return(clientStream, nil),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			DoAndReturn(func(m any) error {
				cancel()
				// Mock callbacks are not guaranteed to run on the test goroutine, where
				// require/FailNow would be mandatory; assert only records the failure.
				// A context's Err is non-nil exactly when its Done channel is closed.
				assert.NoError(t, connectCtx.Err(), "Unexpected context cancellation")
				return errors.New("expected err")
			}),
	)

	err := c.attempt(ctx)
	require.EqualError(t, err, "expected err")
}

// TestAPIDescriptorIsSent checks that the first message sent on a freshly connected
// tunnel is the API descriptor, and that a failure of that Send is returned by
// attempt with a "Send(descriptor)" prefix.
func TestAPIDescriptorIsSent(t *testing.T) {
	client, _, tunnel, c := setupConnection(t)

	wantDescriptor := &rpc.ConnectRequest{
		Msg: &rpc.ConnectRequest_Descriptor_{
			Descriptor_: &rpc.Descriptor{
				ApiDescriptor: descriptor(),
			},
		},
	}

	connectCall := client.EXPECT().
		Connect(gomock.Any(), gomock.Any()).
		Return(tunnel, nil)
	failingSend := tunnel.EXPECT().
		Send(matcher.ProtoEq(t, wantDescriptor)).
		Return(errors.New("expected err"))
	gomock.InOrder(connectCall, failingSend)

	require.EqualError(t, c.attempt(context.Background()), "Send(descriptor): expected err")
}

// TestAttemptIsUnblockedOnTunnelRecvMessageError checks that attempt returns when the
// tunnel's RecvMsg fails while the client stream is still open. The client stream's
// RecvMsg stub only returns once the context given to NewStream is canceled, so the
// test would hang unless attempt cancels that context after the tunnel error.
// attempt returns the tunnel's RecvMsg error.
func TestAttemptIsUnblockedOnTunnelRecvMessageError(t *testing.T) {
	client, conn, tunnel, c := setupConnection(t)
	ctrl := gomock.NewController(t)
	clientStream := mock_rpc.NewMockClientStream(ctrl)
	// Context passed to NewStream, captured so the client-stream RecvMsg stub can block on it.
	var newStreamCtx context.Context
	// Tunnel side: Connect, descriptor Send, RequestInfo, NewStream, then RecvMsg fails.
	gomock.InOrder(
		client.EXPECT().
			Connect(gomock.Any(), gomock.Any()).
			Return(tunnel, nil),
		tunnel.EXPECT().
			Send(gomock.Any()), // ConnectRequest_Descriptor_
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Do(testhelpers.RecvMsg(&rpc.ConnectResponse{
				Msg: &rpc.ConnectResponse_RequestInfo{RequestInfo: reqInfo()},
			})),
		conn.EXPECT().
			NewStream(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
			DoAndReturn(func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
				newStreamCtx = ctx //nolint: fatcontext
				return clientStream, nil
			}),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Return(errors.New("expected recv error")),
	)
	// Client-stream side: header forwarded, RecvMsg blocks until cancellation, then
	// the trailer is read and its tunnel Send fails.
	gomock.InOrder(
		clientStream.EXPECT().
			Header(),
		tunnel.EXPECT().
			Send(matcher.ProtoEq(nil, &rpc.ConnectRequest{
				Msg: &rpc.ConnectRequest_Header{Header: &rpc.Header{}},
			}, mock_tunnel_rpc.EquateMetadataKV())),
		clientStream.EXPECT().
			RecvMsg(gomock.Any()).
			DoAndReturn(func(m any) error {
				<-newStreamCtx.Done() // block until context is canceled
				return newStreamCtx.Err()
			}),
		clientStream.EXPECT().
			Trailer(),
		tunnel.EXPECT().
			Send(gomock.Any()).
			Return(errors.New("expected send error")),
	)

	// The tunnel RecvMsg error is reported, not the later Send error.
	err := c.attempt(context.Background())
	require.EqualError(t, err, "expected recv error")
}

// TestAttemptIsUnblockedOnTunnelHeaderSendError checks a failure on the tunnel side
// while forwarding the client stream's header: the tunnel Send fails with io.EOF.
// No further client-stream calls (RecvMsg, Trailer) are expected after that, and
// attempt returns the tunnel's RecvMsg error.
func TestAttemptIsUnblockedOnTunnelHeaderSendError(t *testing.T) {
	client, conn, tunnel, c := setupConnection(t)
	ctrl := gomock.NewController(t)
	clientStream := mock_rpc.NewMockClientStream(ctrl)
	// Tunnel side: Connect, descriptor Send, RequestInfo, NewStream, then RecvMsg fails.
	gomock.InOrder(
		client.EXPECT().
			Connect(gomock.Any(), gomock.Any()).
			Return(tunnel, nil),
		tunnel.EXPECT().
			Send(gomock.Any()), // ConnectRequest_Descriptor_
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Do(testhelpers.RecvMsg(&rpc.ConnectResponse{
				Msg: &rpc.ConnectResponse_RequestInfo{
					RequestInfo: reqInfo(),
				},
			})),
		conn.EXPECT().
			NewStream(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
			Return(clientStream, nil),
		tunnel.EXPECT().
			RecvMsg(gomock.Any()).
			Return(errors.New("expected recv error")),
	)
	// Client-stream side: Header is read, but forwarding it to the tunnel fails.
	gomock.InOrder(
		clientStream.EXPECT().
			Header(),
		tunnel.EXPECT().
			Send(matcher.ProtoEq(nil, &rpc.ConnectRequest{
				Msg: &rpc.ConnectRequest_Header{Header: &rpc.Header{}},
			}, mock_tunnel_rpc.EquateMetadataKV())).
			Return(io.EOF),
	)

	err := c.attempt(context.Background())
	require.EqualError(t, err, "expected recv error")
}

// setupConnection builds a Connection wired to fresh gomock mocks. It returns, in
// order:
//   - the mocked reverse tunnel client (set as Connection.Client);
//   - the mocked own-server connection (set as Connection.OwnServerConn);
//   - a mocked Connect stream that tests typically make Connect return;
//   - the Connection itself.
//
// OnActive/OnIdle are no-ops and PrepareMetadata returns metadata unchanged.
// PollConfig comes from testhelpers.NewPollConfig(0); presumably that means no
// delay between attempts (TODO confirm).
func setupConnection(t *testing.T) (*mock_tunnel_rpc.MockReverseTunnelClient, *mock_rpc.MockClientConnInterface, *mock_tunnel_rpc.MockReverseTunnel_ConnectClient, *Connection) {
	mockCtrl := gomock.NewController(t)
	tunnelClient := mock_tunnel_rpc.NewMockReverseTunnelClient(mockCtrl)
	serverConn := mock_rpc.NewMockClientConnInterface(mockCtrl)
	connectStream := mock_tunnel_rpc.NewMockReverseTunnel_ConnectClient(mockCtrl)

	connection := &Connection{
		Log:             testlogger.New(t),
		Descriptor:      descriptor(),
		Client:          tunnelClient,
		OwnServerConn:   serverConn,
		PollConfig:      testhelpers.NewPollConfig(0),
		OnActive:        func(ConnectionInterface) {},
		OnIdle:          func(ConnectionInterface) {},
		PrepareMetadata: func(md metadata.MD) metadata.MD { return md },
	}
	return tunnelClient, serverConn, connectStream, connection
}

// descriptor returns a fresh minimal API descriptor with one service named "bla"
// exposing one method named "bab".
func descriptor() *info.APIDescriptor {
	method := &info.Method{Name: "bab"}
	service := &info.Service{
		Name:    "bla",
		Methods: []*info.Method{method},
	}
	return &info.APIDescriptor{Services: []*info.Service{service}}
}

// reqInfo returns a fresh rpc.RequestInfo for the method "/service/method". Tests
// deliver it through the tunnel's RecvMsg to simulate an incoming request.
func reqInfo() *rpc.RequestInfo {
	const method = "/service/method"
	return &rpc.RequestInfo{MethodName: method}
}
