package test

import (
	"context"
	"strconv"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/agent2kas_tunnel/router"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool/test"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/matcher"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_modagent"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testhelpers"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/rpc"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/grpcz"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/streamvisitor"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tunserver"
	"go.uber.org/mock/gomock"
	"golang.org/x/sync/errgroup"
	statuspb "google.golang.org/genproto/googleapis/rpc/status"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/mem"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/reflect/protoreflect"
)

const (
	// Field numbers handed to streamvisitor.WithCallback to choose which
	// test.Response field each visitor callback is registered for.
	scalarNumber protoreflect.FieldNumber = 1
	x1Number     protoreflect.FieldNumber = 2
	dataNumber   protoreflect.FieldNumber = 3
	lastNumber   protoreflect.FieldNumber = 4

	// Metadata keys used by the happy-path tests: metaKey is sent as request
	// metadata and echoed back in the response header, trailerKey is set in
	// the response trailer.
	metaKey    = "Cba"
	trailerKey = "Abc"
)

// TestStreamHappyPath runs a bidirectional streaming call whose handler reads
// one request, echoes the incoming metaKey metadata back as a response header,
// sends five responses (scalar, x1, two data, last) and sets a trailer.
// The client side is exercised twice in a row by testStreamHappyPath.
func TestStreamHappyPath(t *testing.T) {
	trailer := metadata.Pairs(trailerKey, "1", trailerKey, "2")

	handler := func(server grpc.BidiStreamingServer[test.Request, test.Response]) error {
		// Every failure in the handler is reported to the client the same way.
		unavailable := func() error {
			return status.Error(codes.Unavailable, "unavailable")
		}

		req, err := server.Recv()
		if err != nil {
			return unavailable()
		}
		scalar, err := strconv.ParseInt(req.S1, 10, 64)
		if err != nil {
			return unavailable()
		}
		inMD, ok := metadata.FromIncomingContext(server.Context())
		if !ok {
			return unavailable()
		}

		// Echo the request metadata back as the response header.
		respHeader := metadata.MD{}
		respHeader.Set(metaKey, inMD.Get(metaKey)...)
		if err = server.SetHeader(respHeader); err != nil {
			return unavailable()
		}

		messages := []*test.Response{
			{Message: &test.Response_Scalar{Scalar: scalar}},
			{Message: &test.Response_X1{X1: test.Enum1_v1}},
			{Message: &test.Response_Data_{Data: &test.Response_Data{}}},
			{Message: &test.Response_Data_{Data: &test.Response_Data{}}},
			{Message: &test.Response_Last_{Last: &test.Response_Last{}}},
		}
		for i := range messages {
			if err = server.Send(messages[i]); err != nil {
				return unavailable()
			}
		}
		server.SetTrailer(trailer)
		return nil
	}

	ats := &test.GRPCTestingServer{StreamingFunc: handler}
	runTest(t, ats, func(ctx context.Context, t *testing.T, client test.TestingClient) {
		// Several sequential requests over the same setup.
		for i := 0; i < 2; i++ {
			testStreamHappyPath(ctx, t, client, trailer)
		}
	})
}

// testStreamHappyPath performs one streaming request against client and checks
// the visitor callback counts, the received trailer and the echoed header.
// trailer is the metadata the stream's trailer is expected to equal.
func testStreamHappyPath(ctx context.Context, t *testing.T, client test.TestingClient, trailer metadata.MD) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	meta := metadata.Pairs(metaKey, "3", metaKey, "4")
	stream, err := client.StreamingRequestResponse(metadata.NewOutgoingContext(ctx, meta))
	require.NoError(t, err)
	require.NoError(t, stream.Send(&test.Request{S1: "123"}))
	require.NoError(t, stream.CloseSend())

	// Number of times each visitor callback fired.
	var calls struct {
		scalar, x1, data, last, eof int
	}

	visitor, err := streamvisitor.NewStreamVisitor(&test.Response{})
	require.NoError(t, err)
	err = visitor.Visit(stream,
		streamvisitor.WithEOFCallback(func() error {
			calls.eof++
			return nil
		}),
		streamvisitor.WithCallback(scalarNumber, func(scalar int64) error {
			assert.EqualValues(t, 123, scalar)
			calls.scalar++
			return nil
		}),
		streamvisitor.WithCallback(x1Number, func(_ test.Enum1) error {
			calls.x1++
			return nil
		}),
		streamvisitor.WithCallback(dataNumber, func(_ *test.Response_Data) error {
			calls.data++
			return nil
		}),
		streamvisitor.WithCallback(lastNumber, func(_ *test.Response_Last) error {
			calls.last++
			return nil
		}),
	)
	require.NoError(t, err)

	assert.Equal(t, 1, calls.scalar)
	assert.Equal(t, 1, calls.x1)
	assert.Equal(t, 2, calls.data)
	assert.Equal(t, 1, calls.last)
	assert.Equal(t, 1, calls.eof)
	assert.Equal(t, trailer, stream.Trailer())

	gotHeader, err := stream.Header()
	require.NoError(t, err)
	assert.Equal(t, meta.Get(metaKey), gotHeader.Get(metaKey))
}

// TestUnaryHappyPath runs a unary call whose handler parses the request's S1
// as an integer, echoes the incoming metaKey metadata back as a response
// header, sets a trailer and returns the parsed value as a scalar response.
// The client side is exercised twice in a row by testUnaryHappyPath.
func TestUnaryHappyPath(t *testing.T) {
	ats := &test.GRPCTestingServer{
		UnaryFunc: func(ctx context.Context, request *test.Request) (*test.Response, error) {
			val, err := strconv.ParseInt(request.S1, 10, 64)
			if err != nil {
				return nil, status.Error(codes.Unavailable, "unavailable")
			}
			// Fail explicitly when no metadata reached the handler, matching
			// the streaming handler, instead of silently echoing nothing.
			incomingContext, ok := metadata.FromIncomingContext(ctx)
			if !ok {
				return nil, status.Error(codes.Unavailable, "unavailable")
			}
			meta := metadata.MD{}
			meta.Set(metaKey, incomingContext.Get(metaKey)...)
			err = grpc.SetHeader(ctx, meta)
			if err != nil {
				return nil, err
			}
			trailer := metadata.MD{}
			trailer.Set(trailerKey, "1", "2")
			err = grpc.SetTrailer(ctx, trailer)
			if err != nil {
				return nil, err
			}
			return &test.Response{
				Message: &test.Response_Scalar{
					Scalar: val,
				},
			}, nil
		},
	}
	runTest(t, ats, func(ctx context.Context, t *testing.T, client test.TestingClient) {
		for range 2 { // test several sequential requests
			testUnaryHappyPath(ctx, t, client)
		}
	})
}

// testUnaryHappyPath performs one unary request against client with metaKey
// metadata attached and checks the scalar response, the echoed header and the
// trailer.
func testUnaryHappyPath(ctx context.Context, t *testing.T, client test.TestingClient) {
	meta := metadata.MD{}
	meta.Set(metaKey, "3", "4")
	ctx = metadata.NewOutgoingContext(ctx, meta)
	var (
		headerResp  metadata.MD
		trailerResp metadata.MD
	)
	// grpc.Header() and grpc.Trailer() are ok here because it's a unary RPC.
	resp, err := client.RequestResponse(ctx, &test.Request{
		S1: "123",
	}, grpc.Header(&headerResp), grpc.Trailer(&trailerResp)) //nolint: forbidigo
	require.NoError(t, err)
	// Checked assertion: an unexpected oneof case fails the test with a
	// readable message instead of panicking.
	scalar, ok := resp.Message.(*test.Response_Scalar)
	require.True(t, ok, "unexpected response message type %T", resp.Message)
	assert.EqualValues(t, 123, scalar.Scalar)
	assert.Equal(t, meta.Get(metaKey), headerResp.Get(metaKey))
	trailer := metadata.MD{}
	trailer.Set(trailerKey, "1", "2")
	assert.Equal(t, trailer, trailerResp)
}

// TestStreamError checks that a status with details returned by the streaming
// handler reaches the client unchanged.
func TestStreamError(t *testing.T) {
	base := status.New(codes.InvalidArgument, "some expected error")
	want, err := base.WithDetails(&test.Request{S1: "some details of the error"})
	require.NoError(t, err)

	// The handler fails immediately with the prepared status.
	handler := func(grpc.BidiStreamingServer[test.Request, test.Response]) error {
		return want.Err()
	}
	ats := &test.GRPCTestingServer{StreamingFunc: handler}

	runTest(t, ats, func(ctx context.Context, t *testing.T, client test.TestingClient) {
		ctx, cancel := context.WithCancel(ctx)
		defer cancel()

		stream, err := client.StreamingRequestResponse(ctx)
		require.NoError(t, err)

		_, recvErr := stream.Recv()
		require.Error(t, recvErr)
		got := status.Convert(recvErr).Proto()
		matcher.AssertProtoEqual(t, got, want.Proto())
	})
}

// TestUnaryError checks that a status with details returned by the unary
// handler reaches the client unchanged.
func TestUnaryError(t *testing.T) {
	base := status.New(codes.InvalidArgument, "some expected error")
	want, err := base.WithDetails(&test.Request{S1: "some details of the error"})
	require.NoError(t, err)

	// The handler fails immediately with the prepared status.
	handler := func(context.Context, *test.Request) (*test.Response, error) {
		return nil, want.Err()
	}
	ats := &test.GRPCTestingServer{UnaryFunc: handler}

	runTest(t, ats, func(ctx context.Context, t *testing.T, client test.TestingClient) {
		ctx, cancel := context.WithCancel(ctx)
		defer cancel()

		_, callErr := client.RequestResponse(ctx, &test.Request{S1: "123"})
		require.Error(t, callErr)
		got := status.Convert(callErr).Proto()
		matcher.AssertProtoEqual(t, got, want.Proto())
	})
}

// runTest builds the server and agent components, registers ats on the agent
// API server, runs the server and the agent in an errgroup and then calls f
// with a client connected through serverInternalServerConn.
// On return it cancels the context, waits for both goroutines and closes the
// connections, asserting that none of these steps reports an error.
func runTest(t *testing.T, ats test.TestingServer, f func(context.Context, *testing.T, test.TestingClient)) {
	// Start/stop
	group, groupCtx := errgroup.WithContext(context.Background())
	ctx, cancel := context.WithCancel(groupCtx)

	// Construct server and agent components
	runServer, kasConn, serverInternalServerConn, serverRPCAPI := serverConstructComponents(ctx, t)
	defer func() {
		// Shutdown order matters: stop everything, wait for the goroutines,
		// and only then close the connections.
		cancel()
		assert.NoError(t, group.Wait())
		assert.NoError(t, kasConn.Close())
		assert.NoError(t, serverInternalServerConn.Close())
	}()

	ctrl := gomock.NewController(t)
	agentAPI := mock_modagent.NewMockAPI(ctrl)
	runAgent, agentAPIServer := agentConstructComponents(ctx, t, kasConn, agentAPI)

	info := testhelpers.AgentkInfoObj()
	serverRPCAPI.EXPECT().
		AgentInfo(gomock.Any(), gomock.Any()).
		Return(info, nil).
		MinTimes(1)

	test.RegisterTestingServer(agentAPIServer, ats)

	// Run all
	group.Go(func() error { return runServer(ctx) })
	group.Go(func() error { return runAgent(ctx) })

	// Test
	f(ctx, t, test.NewTestingClient(serverInternalServerConn))
}

// serverTestingServer is the server-side handler used by these tests. Its
// ForwardStream method forwards incoming streams through a tunnel obtained
// from registry.
type serverTestingServer struct {
	// registry is queried by ForwardStream to find a tunnel.
	registry *router.Registry
}

// ForwardStream is a gRPC stream handler. It splits the invoked method into
// service and method names, asks the registry for a tunnel for
// testhelpers.AgentkKey1 and forwards the incoming stream through that tunnel.
// A failure to get the tunnel is converted to a gRPC status error.
func (s *serverTestingServer) ForwardStream(srv any, server grpc.ServerStream) error {
	ctx := server.Context()
	rpcAPI := modshared.RPCAPIFromContext[modshared.RPCAPI](ctx)
	fullMethod := grpc.ServerTransportStreamFromContext(ctx).Method()
	service, method := grpcz.SplitGRPCMethod(fullMethod)

	_, handle := s.registry.FindTunnel(ctx, testhelpers.AgentkKey1, service, method)
	defer handle.Done(ctx)

	tunnel, err := handle.Get(ctx)
	if err != nil {
		return status.FromContextError(err).Err()
	}
	defer tunnel.Done(ctx)

	return tunnel.ForwardStream(rpcAPI.Log(), server, streamingCallback{incomingStream: server})
}

// registerTestingServer is a test.RegisterTestingServer clone that's been modified to be compatible with
// tunserver.TunnelFinder.FindTunnel().
// Both methods are registered as bidirectional streams handled by h.ForwardStream.
func registerTestingServer(s *grpc.Server, h *serverTestingServer) {
	// forwarded describes one method as a bidirectional stream served by h.ForwardStream.
	forwarded := func(name string) grpc.StreamDesc {
		return grpc.StreamDesc{
			StreamName:    name,
			Handler:       h.ForwardStream,
			ServerStreams: true,
			ClientStreams: true,
		}
	}

	// ServiceDesc must match test.Testing_ServiceDesc
	desc := &grpc.ServiceDesc{
		ServiceName: test.Testing_ServiceDesc.ServiceName,
		Streams: []grpc.StreamDesc{
			forwarded("RequestResponse"),
			forwarded("StreamingRequestResponse"),
		},
		Metadata: test.Testing_ServiceDesc.Metadata,
	}
	s.RegisterService(desc, nil)
}

// Compile-time assertion that streamingCallback satisfies tunserver.DataCallback.
var (
	_ tunserver.DataCallback = streamingCallback{}
)

// streamingCallback relays the header, message, trailer and error callbacks
// it receives onto incomingStream.
type streamingCallback struct {
	// incomingStream is the server stream that headers, messages and trailers are written to.
	incomingStream grpc.ServerStream
}

// Header converts md to gRPC metadata and sets it as the header of the incoming stream.
func (c streamingCallback) Header(md []*rpc.MetadataKV) error {
	header := rpc.MetadataKVToMeta(md)
	return c.incomingStream.SetHeader(header)
}

// Message wraps data in an rpc.RawFrame and sends it on the incoming stream.
func (c streamingCallback) Message(data []byte) error {
	frame := &rpc.RawFrame{
		Data: mem.BufferSlice{mem.SliceBuffer(data)},
	}
	return c.incomingStream.SendMsg(frame)
}

// Trailer converts md to gRPC metadata and sets it as the trailer of the
// incoming stream. It always returns nil.
func (c streamingCallback) Trailer(md []*rpc.MetadataKV) error {
	trailer := rpc.MetadataKVToMeta(md)
	c.incomingStream.SetTrailer(trailer)
	return nil
}

// Error turns the received status proto into a Go error and returns it.
func (c streamingCallback) Error(stat *statuspb.Status) error {
	err := status.ErrorProto(stat)
	return err
}
