package grpctool

import (
	"context"
	"crypto/tls"
	"io"
	"math/rand/v2"
	"net"
	"net/http"
	"testing"
	"time"

	"github.com/coder/websocket"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool/test"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/httpz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testhelpers"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/tlstool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/wstunnel"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/grpcz"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/stats"
	"k8s.io/apimachinery/pkg/util/wait"
)

// TestMaxConnectionAge checks that a streaming RPC which is still running when
// the server's MaxConnectionAge elapses is allowed to finish within
// MaxConnectionAgeGrace, so the client observes a clean io.EOF. It runs both
// with a client that tunnels gRPC over a WebSocket and with a plain gRPC client
// talking to the same WebSocket-capable listener.
func TestMaxConnectionAge(t *testing.T) {
	t.Parallel()
	const maxAge = 3 * time.Second
	srv := &test.GRPCTestingServer{
		StreamingFunc: func(_ grpc.BidiStreamingServer[test.Request, test.Response]) error {
			// Deliberately ignore the stream context: the handler keeps running for
			// maxAge+20%, i.e. past MaxConnectionAge but well before
			// MaxConnectionAge+MaxConnectionAgeGrace, and then returns successfully.
			time.Sleep(maxAge + maxAge*2/10)
			return nil
		},
	}
	testClient := func(t *testing.T, client test.TestingClient) {
		start := time.Now()
		resp, err := client.StreamingRequestResponse(context.Background())
		require.NoError(t, err)
		_, err = resp.Recv()
		// %v rather than %s: err may be nil when the assertion fails.
		require.Equal(t, io.EOF, err, "%v. Elapsed: %s", err, time.Since(start))
	}
	kp := grpc.KeepaliveParams(keepalive.ServerParameters{
		MaxConnectionAge:      maxAge,
		MaxConnectionAgeGrace: maxAge,
	})

	t.Run("WebSocket", func(t *testing.T) {
		testKeepalive(t, true, kp, nil, srv, testClient)
	})
	t.Run("gRPC->WebSocket+gRPC", func(t *testing.T) {
		testKeepalive(t, false, kp, nil, srv, testClient)
	})
}

func TestMaxConnectionAgeAndMaxPollDuration(t *testing.T) {
	t.Parallel()
	const maxAge = 3 * time.Second
	srv := &test.GRPCTestingServer{
		StreamingFunc: func(server grpc.BidiStreamingServer[test.Request, test.Response]) error {
			<-grpcz.MaxConnectionAgeContextFromStreamContext(server.Context()).Done()
			return nil
		},
	}
	testClient := func(t *testing.T, client test.TestingClient) {
		start := time.Now()
		for range 3 {
			reqStart := time.Now()
			resp, err := client.StreamingRequestResponse(context.Background())
			require.NoError(t, err)
			_, err = resp.Recv()
			assert.Equal(t, io.EOF, err, "%s. Request time: %s, overall time: %s", err, time.Since(reqStart), time.Since(start))
		}
	}

	kp, sh := grpcz.MaxConnectionAge2GRPCKeepalive(context.Background(), maxAge)

	t.Run("WebSocket", func(t *testing.T) {
		testKeepalive(t, true, kp, sh, srv, testClient)
	})
	t.Run("gRPC->WebSocket+gRPC", func(t *testing.T) {
		testKeepalive(t, false, kp, sh, srv, testClient)
	})
}

// TestMaxConnectionAgeAndMaxPollDurationRandomizedSequential issues streaming
// RPCs one after another. Each RPC ends either when grpcz's max-connection-age
// context is done or after a random delay shorter than maxAge, whichever comes
// first. Every RPC must complete with io.EOF.
func TestMaxConnectionAgeAndMaxPollDurationRandomizedSequential(t *testing.T) {
	t.Parallel()
	const maxAge = 3 * time.Second
	srv := &test.GRPCTestingServer{
		StreamingFunc: func(stream grpc.BidiStreamingServer[test.Request, test.Response]) error {
			ageDone := grpcz.MaxConnectionAgeContextFromStreamContext(stream.Context()).Done()
			randomTimeout := time.After(time.Duration(rand.Int64N(int64(maxAge))))
			select {
			case <-ageDone:
			case <-randomTimeout:
			}
			return nil
		},
	}
	testClient := func(t *testing.T, client test.TestingClient) {
		for attempt := 0; attempt < 3; attempt++ {
			rpcStart := time.Now()
			stream, err := client.StreamingRequestResponse(context.Background())
			require.NoError(t, err)
			_, err = stream.Recv()
			require.Equal(t, io.EOF, err, "%s. Elapsed: %s", err, time.Since(rpcStart))
		}
	}

	kp, sh := grpcz.MaxConnectionAge2GRPCKeepalive(context.Background(), maxAge)

	run := func(webSocketClient bool) func(*testing.T) {
		return func(t *testing.T) {
			testKeepalive(t, webSocketClient, kp, sh, srv, testClient)
		}
	}
	t.Run("WebSocket", run(true))
	t.Run("gRPC->WebSocket+gRPC", run(false))
}

// TestMaxConnectionAgeAndMaxPollDurationRandomizedParallel runs several client
// goroutines concurrently. Each issues a few streaming RPCs after a small
// random pause. On the server, every RPC ends when grpcz's max-connection-age
// context is done or after a random delay shorter than maxAge, whichever comes
// first. Every RPC must complete with io.EOF.
func TestMaxConnectionAgeAndMaxPollDurationRandomizedParallel(t *testing.T) {
	t.Parallel()
	const maxAge = 3 * time.Second
	srv := &test.GRPCTestingServer{
		StreamingFunc: func(stream grpc.BidiStreamingServer[test.Request, test.Response]) error {
			ageDone := grpcz.MaxConnectionAgeContextFromStreamContext(stream.Context()).Done()
			randomTimeout := time.After(time.Duration(rand.Int64N(int64(maxAge))))
			select {
			case <-ageDone:
			case <-randomTimeout:
			}
			return nil
		},
	}
	testClient := func(t *testing.T, client test.TestingClient) {
		const (
			workers       = 10
			rpcsPerWorker = 3
		)
		var group wait.Group
		defer group.Wait()

		worker := func() {
			for range rpcsPerWorker {
				// Stagger requests by up to maxAge/10 so they overlap unevenly.
				pause := time.Duration(rand.Int64N(int64(maxAge) / 10))
				time.Sleep(pause)
				rpcStart := time.Now()
				stream, err := client.StreamingRequestResponse(context.Background())
				// Only assert-style checks here: require must not be used off the test goroutine.
				if !assert.NoError(t, err) {
					return
				}
				_, err = stream.Recv()
				assert.Equal(t, io.EOF, err, "%s. Elapsed: %s", err, time.Since(rpcStart))
			}
		}
		for range workers {
			group.Start(worker)
		}
	}

	kp, sh := grpcz.MaxConnectionAge2GRPCKeepalive(context.Background(), maxAge)

	t.Run("WebSocket", func(t *testing.T) {
		testKeepalive(t, true, kp, sh, srv, testClient)
	})
	t.Run("gRPC->WebSocket+gRPC", func(t *testing.T) {
		testKeepalive(t, false, kp, sh, srv, testClient)
	})
}

// TestWSTunnel_TLS serves a unary gRPC method on a TLS listener wrapped by
// wstunnel.ListenerWrapper and checks that the method is reachable both by a
// regular gRPC-over-TLS client and by a gRPC client tunneled through a
// WebSocket over TLS (wss).
func TestWSTunnel_TLS(t *testing.T) {
	caCertFile, _, caCert, caKey := testhelpers.GenerateCACert(t)
	certFile, keyFile := testhelpers.GenerateCert(t, "srv", caCert, caKey)
	tlsConfig, err := tlstool.ServerConfig(certFile, keyFile)
	require.NoError(t, err)
	// Offer HTTP/2 for native gRPC and HTTP/1.1 for the WebSocket handshake.
	tlsConfig.NextProtos = []string{httpz.TLSNextProtoH2, httpz.TLSNextProtoH1}

	clientTLSConfig, err := tlstool.ClientConfigWithCACert(caCertFile)
	require.NoError(t, err)

	l, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig)
	require.NoError(t, err)

	lisWrapper := wstunnel.ListenerWrapper{}
	l = lisWrapper.Wrap(l, true)

	s := grpc.NewServer()
	test.RegisterTestingServer(s, &test.GRPCTestingServer{
		UnaryFunc: func(ctx context.Context, r *test.Request) (*test.Response, error) {
			return &test.Response{Message: &test.Response_Scalar{Scalar: 42}}, nil
		},
	})

	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		assert.NoError(t, s.Serve(l))
	}()
	// GracefulStop makes Serve return. Waiting for the goroutine as well ensures
	// its assertion is reported while the test is still running.
	defer func() {
		s.GracefulStop()
		<-serveDone
	}()

	// callUnary invokes the unary method over conn, checks the fixed answer and
	// closes conn.
	callUnary := func(t *testing.T, conn *grpc.ClientConn) {
		t.Helper()
		defer func() {
			assert.NoError(t, conn.Close())
		}()
		resp, err := test.NewTestingClient(conn).RequestResponse(context.Background(), &test.Request{})
		require.NoError(t, err)
		assert.EqualValues(t, 42, resp.GetScalar())
	}

	t.Run("gRPC", func(t *testing.T) {
		conn, err := grpc.NewClient(
			"dns:"+l.Addr().String(),
			grpc.WithTransportCredentials(credentials.NewTLS(clientTLSConfig)),
		)
		require.NoError(t, err)
		callUnary(t, conn)
	})
	t.Run("gRPC via WebSocket", func(t *testing.T) {
		// TLS is handled by the WebSocket's HTTP transport, so gRPC itself runs
		// with insecure credentials inside the tunnel.
		conn, err := grpc.NewClient(
			"passthrough:wss://"+l.Addr().String(),
			grpc.WithContextDialer(wstunnel.DialerForGRPC(0, &websocket.DialOptions{
				HTTPClient: &http.Client{
					Transport: &http.Transport{
						TLSClientConfig: clientTLSConfig,
					},
				},
			})),
			grpc.WithTransportCredentials(insecure.NewCredentials()),
		)
		require.NoError(t, err)
		callUnary(t, conn)
	})
}

// TestWSTunnel_Cleartext serves a unary gRPC method on a plain TCP listener
// wrapped by wstunnel.ListenerWrapper and checks that the method is reachable
// both by a regular cleartext gRPC client and by a gRPC client tunneled through
// a WebSocket (ws).
func TestWSTunnel_Cleartext(t *testing.T) {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	lisWrapper := wstunnel.ListenerWrapper{}
	l = lisWrapper.Wrap(l, false)

	s := grpc.NewServer()
	test.RegisterTestingServer(s, &test.GRPCTestingServer{
		UnaryFunc: func(ctx context.Context, r *test.Request) (*test.Response, error) {
			return &test.Response{Message: &test.Response_Scalar{Scalar: 42}}, nil
		},
	})

	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		assert.NoError(t, s.Serve(l))
	}()
	// GracefulStop makes Serve return. Waiting for the goroutine as well ensures
	// its assertion is reported while the test is still running.
	defer func() {
		s.GracefulStop()
		<-serveDone
	}()

	// callUnary invokes the unary method over conn, checks the fixed answer and
	// closes conn.
	callUnary := func(t *testing.T, conn *grpc.ClientConn) {
		t.Helper()
		defer func() {
			assert.NoError(t, conn.Close())
		}()
		resp, err := test.NewTestingClient(conn).RequestResponse(context.Background(), &test.Request{})
		require.NoError(t, err)
		assert.EqualValues(t, 42, resp.GetScalar())
	}

	t.Run("gRPC", func(t *testing.T) {
		conn, err := grpc.NewClient(
			"dns:"+l.Addr().String(),
			grpc.WithTransportCredentials(insecure.NewCredentials()),
		)
		require.NoError(t, err)
		callUnary(t, conn)
	})
	t.Run("gRPC via WebSocket", func(t *testing.T) {
		conn, err := grpc.NewClient(
			"passthrough:ws://"+l.Addr().String(),
			grpc.WithContextDialer(wstunnel.DialerForGRPC(0, &websocket.DialOptions{})),
			grpc.WithTransportCredentials(insecure.NewCredentials()),
		)
		require.NoError(t, err)
		callUnary(t, conn)
	})
}

// testKeepalive marks t parallel, starts a gRPC server serving srv and runs f
// against a client connected to it.
//
// The server is built with kpServerOption plus, when sh is non-nil, a stats
// handler. It listens on the in-memory listener from listenerAndDialer. If
// webSocketClient is true, the client dials through the WebSocket tunneling
// dialer; otherwise it dials the pipe directly.
func testKeepalive(t *testing.T, webSocketClient bool, kpServerOption grpc.ServerOption, sh stats.Handler, srv test.TestingServer, f func(*testing.T, test.TestingClient)) {
	t.Parallel()
	l, dial := listenerAndDialer(webSocketClient)
	defer func() {
		assert.NoError(t, l.Close())
	}()
	opts := []grpc.ServerOption{kpServerOption}
	if sh != nil {
		opts = append(opts, grpc.StatsHandler(sh))
	}
	s := grpc.NewServer(opts...)
	test.RegisterTestingServer(s, srv)

	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		assert.NoError(t, s.Serve(l))
	}()
	// Stop the server and wait for the Serve goroutine. That way its assertion
	// is reported while t is still running, not after the test has completed.
	defer func() {
		s.GracefulStop()
		<-serveDone
	}()

	conn, err := grpc.NewClient(
		"passthrough:ws://pipe",
		grpc.WithContextDialer(dial),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	require.NoError(t, err)
	defer func() {
		assert.NoError(t, conn.Close())
	}()
	f(t, test.NewTestingClient(conn))
}

// listenerAndDialer creates an in-memory pipe listener and wraps it with
// wstunnel.ListenerWrapper (cleartext). It returns the wrapped listener and a
// dial function that reaches it.
//
// When webSocketClient is false, the dial function dials the pipe directly.
// When it is true, the dial function is a WebSocket tunneling dialer whose
// HTTP transport dials the same pipe.
func listenerAndDialer(webSocketClient bool) (net.Listener, func(context.Context, string) (net.Conn, error)) {
	pipe := grpcz.NewPipeDialListener()
	var pipeListener net.Listener = pipe
	wrapper := wstunnel.ListenerWrapper{}
	wrapped := wrapper.Wrap(pipeListener, false)

	if !webSocketClient {
		return wrapped, pipe.DialContext
	}

	// The HTTP transport's network argument is irrelevant for an in-memory pipe.
	dialPipe := func(ctx context.Context, _, addr string) (net.Conn, error) {
		return pipe.DialContext(ctx, addr)
	}
	return wrapped, wstunnel.DialerForGRPC(0, &websocket.DialOptions{
		HTTPClient: &http.Client{
			Transport: &http.Transport{DialContext: dialPipe},
		},
	})
}
