package tls

import (
	"context"
	"crypto/tls"
	"testing"

	"github.com/stretchr/testify/require"

	"gitlab.com/gitlab-org/gitlab-pages/internal/config"
	"gitlab.com/gitlab-org/gitlab-pages/internal/domain"
)

// Test fixtures shared by the tests in this file: a PEM-encoded X.509
// certificate (cert) and a PEM-encoded EC private key (key). The tests treat
// them as a matching pair, e.g. TestEnsureCertificate expects a domain built
// from them to load a certificate without error.
//
// NOTE(review): the certificate's encoded validity period appears to end in
// 2018. Parsing a key pair does not check expiry, so these tests are
// unaffected, but anything that verifies this certificate would reject it.
var (
	cert = []byte(`-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`)

	key = []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`)
)

// getConfig is the stub per-client config callback passed to GetTLSConfig by
// the tests below. It ignores the ClientHelloInfo and returns a newly
// allocated tls.Config whose only setting is a TLS 1.2 minimum version.
var getConfig = func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
	perClient := &tls.Config{
		MinVersion: tls.VersionTLS12,
	}

	return perClient, nil
}

// TestInvalidKeyPair checks that GetTLSConfig fails, and returns no config,
// when given an empty configuration (no root certificate or key).
func TestInvalidKeyPair(t *testing.T) {
	cfg := &config.Config{}
	getDomain := func(ctx context.Context, name string) (*domain.Domain, error) {
		return nil, nil
	}

	// Named tlsConfig rather than config so the imported config package is
	// not shadowed for the rest of the function.
	tlsConfig, err := GetTLSConfig(cfg, getDomain, getConfig)
	require.Error(t, err)
	require.Nil(t, tlsConfig)
}

// TestClientCert checks that GetTLSConfig, given a root key pair and a
// client-certificate configuration, returns a tls.Config that serves a
// certificate for an incoming ClientHello and requires verified client
// certificates backed by a client CA pool.
func TestClientCert(t *testing.T) {
	cfg := &config.Config{
		General: config.General{
			RootCertificate: cert,
			RootKey:         key,
		},
		TLS: config.TLS{
			ClientAuth: tls.RequireAndVerifyClientCert,
			ClientCert: "./testdata/cert.crt",
		},
	}

	getDomain := func(ctx context.Context, name string) (*domain.Domain, error) {
		return domain.New(name, string(cert), string(key), "", nil), nil
	}

	tlsConfig, err := GetTLSConfig(cfg, getDomain, getConfig)
	require.NoError(t, err)
	require.NotNil(t, tlsConfig)

	// Check the callbacks are set before invoking one: calling a nil func
	// field would panic instead of failing the test cleanly. A type check on
	// these fields would be vacuous, since tls.Config fixes their static type.
	require.NotNil(t, tlsConfig.GetCertificate)
	require.NotNil(t, tlsConfig.GetConfigForClient)

	// The certificate callback must resolve a certificate for the requested name.
	// "served" avoids shadowing the package-level cert fixture.
	served, err := tlsConfig.GetCertificate(&tls.ClientHelloInfo{ServerName: "test.example.com"})
	require.NoError(t, err)
	require.NotNil(t, served)

	require.Equal(t, tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth)
	require.NotNil(t, tlsConfig.ClientCAs)
}

func TestGetTLSConfig(t *testing.T) {
	tests := map[string]struct {
		cfg       *config.Config
		getDomain func(context.Context, string) (*domain.Domain, error)
		expected  *tls.Config
		wantErr   bool
	}{
		"invalid key pair": {
			cfg: &config.Config{},
			getDomain: func(ctx context.Context, name string) (*domain.Domain, error) {
				return nil, nil
			},
			expected: nil,
			wantErr:  true,
		},
		"valid config with client cert": {
			cfg: &config.Config{
				General: config.General{
					RootCertificate: cert,
					RootKey:         key,
				},
				TLS: config.TLS{
					ClientAuth: tls.RequireAndVerifyClientCert,
					ClientCert: "./testdata/cert.crt",
					MinVersion: tls.VersionTLS12,
				},
			},
			getDomain: func(ctx context.Context, name string) (*domain.Domain, error) {
				return domain.New(name, string(cert), string(key), "", nil), nil
			},
			expected: &tls.Config{
				MinVersion: tls.VersionTLS12,
			},
			wantErr: false,
		},
		"valid config without client cert": {
			cfg: &config.Config{
				General: config.General{
					RootCertificate: cert,
					RootKey:         key,
				},
				TLS: config.TLS{
					MinVersion: tls.VersionTLS12,
				},
			},
			getDomain: func(ctx context.Context, name string) (*domain.Domain, error) {
				return domain.New(name, string(cert), string(key), "", nil), nil
			},
			expected: &tls.Config{
				MinVersion: tls.VersionTLS12,
			},
			wantErr: false,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			config, err := GetTLSConfig(tc.cfg, tc.getDomain, getConfig)
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.NotNil(t, config)
			require.Equal(t, tc.expected.MinVersion, config.MinVersion)
			require.Equal(t, preferredCipherSuites, config.CipherSuites)
			require.NotNil(t, config.GetCertificate)
			require.NotNil(t, config.GetConfigForClient)

			if len(tc.cfg.TLS.ClientCert) > 0 {
				require.Equal(t, tc.cfg.TLS.ClientAuth, config.ClientAuth)
				require.NotNil(t, config.ClientCAs)
			} else {
				require.Equal(t, tls.NoClientCert, config.ClientAuth)
				require.Nil(t, config.ClientCAs)
			}
		})
	}
}

// TestEnsureCertificate checks that Domain.EnsureCertificate loads the domain's
// PEM key pair, and that it returns an error with no certificate for a nil
// domain or for unparsable PEM data.
func TestEnsureCertificate(t *testing.T) {
	tests := map[string]struct {
		domain  *domain.Domain
		wantErr bool
	}{
		"nil domain": {
			domain:  nil,
			wantErr: true,
		},
		"valid domain": {
			domain:  domain.New("test.example.com", string(cert), string(key), "", nil),
			wantErr: false,
		},
		"invalid certificate": {
			domain:  domain.New("invalid.example.com", "invalid-cert", "invalid-key", "", nil),
			wantErr: true,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			// EnsureCertificate is called even on a nil *domain.Domain; the
			// "nil domain" case expects an error from it rather than a panic.
			certResult, err := tc.domain.EnsureCertificate()
			if tc.wantErr {
				require.Error(t, err)
				require.Nil(t, certResult)
				return
			}

			require.NoError(t, err)
			require.NotNil(t, certResult)
			// A successfully parsed key pair carries at least one DER-encoded certificate.
			require.NotEmpty(t, certResult.Certificate)
		})
	}
}
