package agent

import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/url"
	"os"
	"sync"
	"time"

	"buf.build/go/protovalidate"
	"github.com/ash2k/stager"
	"github.com/golang-jwt/jwt/v5"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modagent"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/errz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/ioz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/logz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/nettool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/tlstool"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/rpc"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/grpcz"
	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
	"google.golang.org/grpc"
	"google.golang.org/grpc/backoff"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/keepalive"
)

// Environment variables that describe how other agentk instances reach this one.
const (
	// envVarOwnPrivateAPIURL must hold this agent's private API URL (grpc:// or grpcs://).
	envVarOwnPrivateAPIURL = "OWN_PRIVATE_API_URL"
	// envVarOwnPrivateAPIHost optionally overrides the TLS server name used for agentk->agentk calls.
	envVarOwnPrivateAPIHost = "OWN_PRIVATE_API_HOST"
)

// Timing knobs for the network-facing private API server.
const (
	// privateAPIListenGracePeriod is how long Start's shutdown path sleeps before scheduling
	// the in-memory server stop and RPC handler cancellation.
	privateAPIListenGracePeriod = time.Second
	// privateAPIListenMaxConnectionAge is passed to grpcz.MaxConnectionAge2GRPCKeepalive.
	privateAPIListenMaxConnectionAge = 2 * time.Hour
)

// PrivateAPIServer bundles the gRPC servers and the agentk client pool that make up the
// agentk private API. Values are built by NewPrivateAPIServer; see Start and Close for the
// lifecycle methods.
type PrivateAPIServer struct {
	// Log receives lifecycle messages such as "Private API endpoint is up".
	Log            *slog.Logger
	// ListenNetwork and ListenAddress are where ListenServer listens (see Start).
	ListenNetwork  string
	ListenAddress  string
	// OwnURL is this agent's private API URL, taken from the OWN_PRIVATE_API_URL env var.
	OwnURL         grpcz.URLTarget
	// OwnURLScheme is the scheme of OwnURL: "grpc" or "grpcs".
	OwnURLScheme   string
	// OwnURLPort is the port of OwnURL, or "80"/"443" (grpc/grpcs) when the URL has none.
	OwnURLPort     string
	// Server aggregates ListenServer and InMemServer into one grpctool.GRPCServer.
	// NOTE(review): presumably used to register services on both servers at once — confirm at call sites.
	Server         grpctool.GRPCServer
	// ListenServer serves the network listener; it carries TLS credentials when configured.
	ListenServer   *grpc.Server
	// InMemServer serves InMemListener; it is never given TLS credentials.
	InMemServer    *grpc.Server
	// InMemListener is the in-memory pipe listener; its DialContext is the dialer used by
	// the pool's in-memory client connection.
	InMemListener  net.Listener
	// AgentPool holds client connections to agentk private API endpoints. OwnURL is paired
	// with the in-memory connection via grpcz.NewPoolSelf.
	AgentPool      grpcz.PoolInterface[grpcz.URLTarget]
	// AgentPoolClose shuts down AgentPool. NewPrivateAPIServer wraps it in sync.OnceFunc so it
	// may be called more than once.
	AgentPoolClose func()
	// AuxCancel cancels the auxiliary context given to the servers' max-connection-age keepalive
	// handler; Start calls it during shutdown to signal running RPC handlers to stop.
	AuxCancel      context.CancelFunc
}

// NewPrivateAPIServer constructs the agentk private API. It consists of:
//   - a network-facing gRPC server;
//   - an in-memory gRPC server reachable through a pipe listener;
//   - a pool of agentk client connections built by newAgentkPool, in which this agent's
//     own URL is paired with an in-memory connection.
//
// The own private API URL is read from the OWN_PRIVATE_API_URL environment variable. It is
// required and must use the grpc or grpcs scheme. OWN_PRIVATE_API_HOST optionally sets the
// TLS server name used when dialing other agentk instances.
//
// The remaining inputs are:
//   - jwtAuthFile holds the shared secret used to sign and verify agentk->agentk JWTs.
//   - certFile/keyFile are handed to grpctool.MaybeTLSCreds for the network server.
//   - caCertFile is used to build the client TLS configuration.
//
// If construction fails part-way, everything created so far (pipe listener, client pool,
// auxiliary context) is released before the error is returned.
func NewPrivateAPIServer(log *slog.Logger, ot *ObsTools, errRep errz.ErrReporter, factory modshared.RPCAPIFactory,
	v protovalidate.Validator,
	userAgent, listenNetwork, listenAddress, certFile, keyFile, caCertFile, jwtAuthFile string) (*PrivateAPIServer, error) {

	jwtSecret, err := ioz.LoadSHA256Base64Secret(log, jwtAuthFile)
	if err != nil {
		return nil, fmt.Errorf("auth secret file: %w", err)
	}

	ownURL := os.Getenv(envVarOwnPrivateAPIURL)
	if ownURL == "" {
		return nil, fmt.Errorf("cannot determine own private API URL. Please set the %s environment variable", envVarOwnPrivateAPIURL)
	}
	ownURLScheme, ownURLPort, err := urlSchemeAndPort(ownURL)
	if err != nil {
		return nil, fmt.Errorf("error parsing own private API URL: %w", err)
	}
	log.Info("Using own private API URL", logz.URL(ownURL))

	ownHost := os.Getenv(envVarOwnPrivateAPIHost)

	// In-memory gRPC client->listener pipe
	listener := grpcz.NewPipeDialListener()

	// Client pool
	agentkPool, err := newAgentkPool(log, ot, errRep, jwtSecret, ownURL, ownHost,
		caCertFile, v, userAgent, listener.DialContext)
	if err != nil {
		_ = listener.Close() // best-effort cleanup; the original error is what matters
		return nil, fmt.Errorf("agentk pool: %w", err)
	}
	// sync.OnceFunc makes shutdown idempotent: the error path below, Start's shutdown path
	// and Close may each invoke it.
	agentPoolClose := sync.OnceFunc(func() { agentkPool.Shutdown(30 * time.Second) })

	// Server
	auxCtx, auxCancel := context.WithCancel(context.Background())
	server, inMemServer, err := newPrivateAPIServerImpl(log, auxCtx, ot, jwtSecret, factory, ownHost, certFile, keyFile, v)
	if err != nil {
		auxCancel()
		// Release the pool (including its in-memory client connection) and the pipe listener;
		// previously they were leaked on this path.
		agentPoolClose()
		_ = listener.Close() // best-effort cleanup; the original error is what matters
		return nil, err
	}
	return &PrivateAPIServer{
		Log:            log,
		ListenNetwork:  listenNetwork,
		ListenAddress:  listenAddress,
		OwnURL:         grpcz.URLTarget(ownURL),
		OwnURLScheme:   ownURLScheme,
		OwnURLPort:     ownURLPort,
		Server:         grpctool.AggregateServer{server, inMemServer},
		ListenServer:   server,
		InMemServer:    inMemServer,
		InMemListener:  listener,
		AgentPool:      agentkPool,
		AgentPoolClose: agentPoolClose,
		AuxCancel:      auxCancel,
	}, nil
}

// Start registers both private API gRPC servers with stage via grpctool.StartServer.
//
// The in-memory server serves on s.InMemListener. The listening server opens a network
// listener on s.ListenNetwork/s.ListenAddress with OS-level TCP keepalive and logs the
// bound address.
//
// The callbacks coordinate shutdown:
//   - The listening server's callback sleeps privateAPIListenGracePeriod.
//   - It then schedules, one second later, close(stopInMem) and s.AuxCancel().
//   - Closing stopInMem unblocks the in-memory server's callback, which then shuts down the
//     agentk client pool.
//
// NOTE(review): when grpctool.StartServer invokes these callbacks, and what the trailing
// no-op funcs are for, is defined outside this file — confirm there.
func (s *PrivateAPIServer) Start(stage stager.Stage) {
	// Closed by the listening server's shutdown path. Until then the in-memory server's
	// callback blocks, so in-memory teardown is ordered after that path.
	stopInMem := make(chan struct{})
	grpctool.StartServer(stage, s.InMemServer,
		func(ctx context.Context) (net.Listener, error) {
			// The pipe listener already exists; nothing to open here.
			return s.InMemListener, nil
		},
		func() {
			<-stopInMem
			s.AgentPoolClose()
		},
		func() {},
	)
	grpctool.StartServer(stage, s.ListenServer,
		func(ctx context.Context) (net.Listener, error) {
			lis, err := nettool.ListenWithOSTCPKeepAlive(ctx, s.ListenNetwork, s.ListenAddress)
			if err != nil {
				return nil, err
			}
			addr := lis.Addr()
			s.Log.Info("Private API endpoint is up",
				logz.NetNetworkFromAddr(addr),
				logz.NetAddressFromAddr(addr),
			)
			return lis, nil
		},
		func() {
			// NOTE(review): this callback must run at most once; a second close(stopInMem)
			// below would panic.
			time.Sleep(privateAPIListenGracePeriod)
			// We first want gRPC server to send GOAWAY and only then return from the RPC handlers.
			// So we delay signaling the handlers.
			// See https://github.com/grpc/grpc-go/issues/6830 for more background.
			// Start a goroutine in a second and...
			time.AfterFunc(time.Second, func() {
				close(stopInMem) // ... signal the in-memory server to stop.
				s.AuxCancel()    // ... signal running RPC handlers to stop.
			})
		},
		func() {},
	)
}

// Close releases the client side of the private API and then the in-memory listener.
// It returns the listener's Close error.
//
// For servers built by NewPrivateAPIServer, AgentPoolClose is wrapped in sync.OnceFunc,
// so calling Close after Start's shutdown path has already closed the pool is safe.
func (s *PrivateAPIServer) Close() error {
	// Client connections go first, then the pipe listener they dial through.
	s.AgentPoolClose()
	inMem := s.InMemListener
	return inMem.Close()
}

// newPrivateAPIServerImpl builds the two gRPC servers backing the private API:
//   - the first (returned first) is for the network listener and additionally gets
//     whatever credential options grpctool.MaybeTLSCreds derives from certFile/keyFile;
//   - the second is for the in-memory pipe and never gets those credentials.
//
// Both servers share the keepalive and max-connection-age handling (driven by auxCtx),
// OpenTelemetry stats, the interceptor chains (metrics, RPC API injection, JWT auth,
// validation), the raw codec and the max receive message size.
func newPrivateAPIServerImpl(log *slog.Logger, auxCtx context.Context, ot *ObsTools,
	jwtSecret []byte, factory modshared.RPCAPIFactory,
	ownPrivateAPIHost, certFile, keyFile string, v protovalidate.Validator) (*grpc.Server, *grpc.Server, error) {

	tlsOpts, err := grpctool.MaybeTLSCreds(certFile, keyFile)
	if err != nil {
		return nil, nil, err
	}
	usingTLS := len(tlsOpts) > 0
	if usingTLS && ownPrivateAPIHost == "" {
		log.Info(fmt.Sprintf("%s environment variable is not set. Please set it if you want to "+
			"override the server name used for agentk->agentk TLS communication", envVarOwnPrivateAPIHost))
	}

	rpcLog := func(ctx context.Context) *slog.Logger {
		return modshared.RPCAPIFromContext[modagent.RPCAPI](ctx).Log()
	}
	auther := grpctool.NewHMACJWTAuther(jwtSecret, api.JWTAgentk, api.JWTAgentk, rpcLog)

	keepaliveOpt, connAgeHandler := grpcz.MaxConnectionAge2GRPCKeepalive(auxCtx, privateAPIListenMaxConnectionAge)
	otelHandler := otelgrpc.NewServerHandler(
		otelgrpc.WithTracerProvider(ot.TP),
		otelgrpc.WithMeterProvider(ot.MP),
		otelgrpc.WithPropagators(ot.P),
		otelgrpc.WithMessageEvents(otelgrpc.ReceivedEvents, otelgrpc.SentEvents),
	)
	streamChain := grpc.ChainStreamInterceptor(
		ot.StreamProm,                                 // 1. measure all invocations
		modshared.StreamRPCAPIInterceptor(factory),    // 2. inject RPC API
		auther.StreamServerInterceptor,                // 3. auth and maybe log
		grpctool.StreamServerValidatingInterceptor(v), // 4. wrap with validator
	)
	unaryChain := grpc.ChainUnaryInterceptor(
		ot.UnaryProm,                                 // 1. measure all invocations
		modshared.UnaryRPCAPIInterceptor(factory),    // 2. inject RPC API
		auther.UnaryServerInterceptor,                // 3. auth and maybe log
		grpctool.UnaryServerValidatingInterceptor(v), // 4. wrap with validator
	)

	// Option order matches the historical order to keep server configuration identical.
	commonOpts := []grpc.ServerOption{
		keepaliveOpt,
		grpc.StatsHandler(otelHandler),
		grpc.StatsHandler(connAgeHandler),
		grpc.SharedWriteBuffer(true),
		streamChain,
		unaryChain,
		grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
			MinTime:             20 * time.Second,
			PermitWithoutStream: true,
		}),
		grpc.ForceServerCodecV2(rpc.RawCodecWithProtoFallback{}),
		grpc.MaxRecvMsgSize(api.GRPCMaxMessageSize),
	}

	// The network server gets the common options followed by any TLS credential options.
	listenOpts := make([]grpc.ServerOption, 0, len(commonOpts)+len(tlsOpts))
	listenOpts = append(listenOpts, commonOpts...)
	listenOpts = append(listenOpts, tlsOpts...)

	return grpc.NewServer(listenOpts...), grpc.NewServer(commonOpts...), nil
}

// newAgentkPool builds the pool of gRPC client connections to agentk private API endpoints.
//
// Connections to other agents use TLS built from caCertificateFile. ownPrivateAPIHost is
// used as the TLS server name. ownPrivateAPIURL is paired (via grpcz.NewPoolSelf) with an
// in-memory connection that dials through dialer without transport security.
//
// Every connection:
//   - attaches an HS256 JWT signed with jwtSecret to each RPC;
//   - sends userAgent;
//   - runs client-side metrics and validation interceptors;
//   - bypasses proxies.
//
// The TLS configuration is loaded before any client connection is created, so a bad CA file
// cannot leave an orphaned in-memory connection behind.
func newAgentkPool(log *slog.Logger, ot *ObsTools, errRep errz.ErrReporter,
	jwtSecret []byte, ownPrivateAPIURL, ownPrivateAPIHost, caCertificateFile string,
	v protovalidate.Validator, userAgent string,
	dialer func(context.Context, string) (net.Conn, error)) (grpcz.PoolInterface[grpcz.URLTarget], error) {

	// Load TLS first: it is the only step that can fail after the in-memory client would
	// otherwise already exist (and previously leaked on that failure).
	tlsCreds, err := tlstool.ClientConfigWithCACert(caCertificateFile)
	if err != nil {
		return nil, err
	}
	tlsCreds.ServerName = ownPrivateAPIHost

	sharedPoolOpts := []grpc.DialOption{
		grpc.WithSharedWriteBuffer(true),
		// Default gRPC parameters are good, no need to change them at the moment.
		// Specify them explicitly for discoverability.
		// See https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md.
		grpc.WithConnectParams(grpc.ConnectParams{
			Backoff:           backoff.DefaultConfig,
			MinConnectTimeout: 20 * time.Second, // matches the default gRPC value.
		}),
		grpc.WithStatsHandler(otelgrpc.NewClientHandler(
			otelgrpc.WithTracerProvider(ot.TP),
			otelgrpc.WithMeterProvider(ot.MP),
			otelgrpc.WithPropagators(ot.P),
			otelgrpc.WithMessageEvents(otelgrpc.ReceivedEvents, otelgrpc.SentEvents),
		)),
		grpc.WithUserAgent(userAgent),
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time:                55 * time.Second,
			PermitWithoutStream: true,
		}),
		grpc.WithPerRPCCredentials(&grpctool.JWTCredentials{
			SigningMethod: jwt.SigningMethodHS256,
			SigningKey:    jwtSecret,
			Audience:      api.JWTAgentk,
			Issuer:        api.JWTAgentk,
			Insecure:      true, // We may or may not have TLS setup, so always say creds don't need TLS.
		}),
		grpc.WithChainStreamInterceptor(
			ot.StreamClientProm,
			grpctool.StreamClientValidatingInterceptor(v),
		),
		grpc.WithChainUnaryInterceptor(
			ot.UnaryClientProm,
			grpctool.UnaryClientValidatingInterceptor(v),
		),
		grpc.WithNoProxy(),
		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(api.GRPCMaxMessageSize)),
	}

	// Construct in-memory connection to private API gRPC server.
	// sharedPoolOpts is a slice literal (len == cap), so this append copies rather than
	// mutating the slice later handed to DefaultNewConnection.
	inMemConn, err := grpc.NewClient("passthrough:private-server",
		append(sharedPoolOpts,
			grpc.WithContextDialer(dialer),
			grpc.WithTransportCredentials(insecure.NewCredentials()),
		)...,
	)
	if err != nil {
		return nil, err
	}
	agentkPool := grpcz.NewPool(log, errRep, grpcz.DefaultNewConnection(credentials.NewTLS(tlsCreds), sharedPoolOpts...))
	return grpcz.NewPoolSelf(log, errRep, agentkPool, grpcz.URLTarget(ownPrivateAPIURL), inMemConn), nil
}

// urlSchemeAndPort parses urlStr, which must be a grpc:// or grpcs:// URL naming a host.
// It returns the URL's scheme and port. When the URL has no explicit port, the scheme's
// default is returned: "80" for grpc, "443" for grpcs.
//
// An error is returned if the URL:
//   - does not parse;
//   - uses any other scheme;
//   - has an empty host (e.g. "grpc://:8080" or the opaque form "grpc:foo").
func urlSchemeAndPort(urlStr string) (string /* scheme */, string /* port */, error) {
	u, err := url.Parse(urlStr)
	if err != nil {
		return "", "", err
	}
	var defaultPort string
	switch u.Scheme {
	case "grpc":
		defaultPort = "80"
	case "grpcs":
		defaultPort = "443"
	default:
		return "", "", fmt.Errorf("unknown scheme in own private API URL: %s", u.Scheme)
	}
	// This URL identifies this agent's private API endpoint (it becomes the pool's self
	// target), so it has to name a host. url.Parse accepts host-less forms without error.
	if u.Hostname() == "" {
		return "", "", fmt.Errorf("missing host in own private API URL: %q", urlStr)
	}
	port := u.Port()
	if port == "" {
		port = defaultPort
	}
	return u.Scheme, port, nil
}
