package nettool_test

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/nettool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_tool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testhelpers"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testlogger"
	"go.uber.org/mock/gomock"
)

// TestAccept_AcceptOnlyTLSConnectionsWhenExpectTLS wraps a TLS listener with
// expectTLS=true, dials it with a TLS client that trusts the server's leaf
// certificate, and checks that the accepted connection exposes
// HandshakeContext and completes the server side of the handshake.
func TestAccept_AcceptOnlyTLSConnectionsWhenExpectTLS(t *testing.T) {
	listener, serverCert := createAndListenTLSServer(t)
	defer listener.Close()

	wrapped, err := testListenerMetrics(t).Wrap(listener, "test-tls", time.Minute, true)
	require.NoError(t, err)
	defer wrapped.Close()

	// acceptOutcome carries everything observed by the accepting goroutine.
	type acceptOutcome struct {
		conn         net.Conn
		acceptErr    error
		isTLS        bool
		handshakeErr error
	}
	// Buffered so the goroutine can always deliver its outcome and exit,
	// even if the test has already stopped waiting.
	outcomes := make(chan acceptOutcome, 1)

	go func() {
		var out acceptOutcome
		out.conn, out.acceptErr = wrapped.Accept()
		if out.acceptErr != nil || out.conn == nil {
			outcomes <- out
			return
		}
		// Anything exposing HandshakeContext (such as *tls.Conn) counts as TLS.
		hs, ok := out.conn.(interface{ HandshakeContext(context.Context) error })
		out.isTLS = ok
		if ok {
			out.handshakeErr = hs.HandshakeContext(context.Background())
		}
		// The connection is intentionally left open; the test inspects and closes it.
		outcomes <- out
	}()

	trusted := x509.NewCertPool()
	trusted.AddCert(serverCert.Leaf)
	client := &tls.Dialer{Config: &tls.Config{RootCAs: trusted}}
	clientConn, err := client.DialContext(t.Context(), "tcp", listener.Addr().String())
	require.NoError(t, err)
	defer clientConn.Close()

	var out acceptOutcome
	select {
	case out = <-outcomes:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for connection to be accepted")
	}
	require.NoError(t, out.acceptErr, "Accept should not return an error")
	require.NotNil(t, out.conn, "Expected a valid connection")
	require.True(t, out.isTLS, "Connection should be a TLS connection")
	require.NoError(t, out.handshakeErr, "TLS handshake should succeed")
	require.NoError(t, out.conn.Close(), "Closing connection should succeed")
}

// testListenerMetrics builds a nettool.ListenerMetrics for tests. It logs
// through testlogger, hands out no-op connection sets, and reports errors to a
// gomock ErrReporter that has no expectations registered. gomock treats any
// call to that reporter as unexpected.
func testListenerMetrics(t *testing.T) *nettool.ListenerMetrics {
	errRep := mock_tool.NewMockErrReporter(gomock.NewController(t))
	newNopSet := func(_ string, _ time.Duration) (nettool.ConnectionSet, error) {
		return &nopConnectionSet{}, nil
	}
	return &nettool.ListenerMetrics{
		Log:    testlogger.New(t),
		NewSet: newNopSet,
		ErrRep: errRep,
	}
}

// createAndListenTLSServer starts a TLS listener on an ephemeral 127.0.0.1
// port. It serves a "test-server" certificate issued by a freshly generated
// test CA. It returns the listener, which the caller must close, and the
// server's tls.Certificate. The certificate's Leaf is always populated, so
// callers can trust it directly, e.g. via x509.CertPool.AddCert.
func createAndListenTLSServer(t *testing.T) (net.Listener, tls.Certificate) {
	_, _, caCert, caKey := testhelpers.GenerateCACert(t)
	certDER, keyDER := testhelpers.GenerateCertInMem(t, "test-server", caCert, caKey)

	cert, err := tls.X509KeyPair(
		pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}),
		pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyDER}),
	)
	require.NoError(t, err)
	// tls.X509KeyPair fills in Leaf only when the x509keypairleaf GODEBUG
	// setting is enabled (the default since Go 1.23). Parse it explicitly so
	// callers never receive a nil Leaf; x509.CertPool.AddCert panics on nil.
	// X509KeyPair guarantees at least one certificate on success.
	if cert.Leaf == nil {
		cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
		require.NoError(t, err)
	}

	l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
		Certificates: []tls.Certificate{cert},
	})
	require.NoError(t, err)
	return l, cert
}

// TestAccept_AcceptOnlyNonTLSConnectionsWhenNotExpectingTLS wraps a plain TCP
// listener with expectTLS=false, dials it with a raw TCP client, and checks
// that the accepted connection is not a TLS connection.
func TestAccept_AcceptOnlyNonTLSConnectionsWhenNotExpectingTLS(t *testing.T) {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer l.Close()

	wrapper, err := testListenerMetrics(t).Wrap(l, "test-non-tls", time.Minute, false)
	require.NoError(t, err)
	defer wrapper.Close()

	type connResult struct {
		conn      net.Conn
		acceptErr error
		isTLS     bool
	}
	// Buffered so the goroutine can always deliver its result and exit.
	resultCh := make(chan connResult, 1)
	go func() {
		var result connResult
		result.conn, result.acceptErr = wrapper.Accept()
		if result.acceptErr == nil && result.conn != nil {
			// Detect TLS the same way the TLS test does. Asserting only the
			// concrete *tls.Conn type would miss a wrapped TLS connection and
			// let this negative check pass vacuously; *tls.Conn itself still
			// matches this interface.
			_, result.isTLS = result.conn.(interface{ HandshakeContext(context.Context) error })
		}
		resultCh <- result
	}()

	dialer := &net.Dialer{}
	// Tie the dial to the test's lifetime, consistent with the TLS test.
	conn, err := dialer.DialContext(t.Context(), "tcp", l.Addr().String())
	require.NoError(t, err, "Failed to establish raw TCP connection")
	defer conn.Close()

	var result connResult
	select {
	case result = <-resultCh:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for connection to be accepted")
	}
	require.NoError(t, result.acceptErr, "Accept should not return an error")
	require.NotNil(t, result.conn, "Expected a valid connection")
	require.False(t, result.isTLS, "Expected a non-TLS connection")
	require.NoError(t, result.conn.Close(), "Closing connection should succeed")
}

// nopConnectionSet is a nettool.ConnectionSet stub for tests. Every method
// ignores its arguments and reports success.
type nopConnectionSet struct{}

// Set records nothing and always returns nil.
func (*nopConnectionSet) Set(int64) error { return nil }

// Unset removes nothing and always returns nil.
func (*nopConnectionSet) Unset(int64) error { return nil }

// GC performs no collection and always returns nil.
func (*nopConnectionSet) GC(context.Context) error { return nil }
