package kas

import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"time"

	"buf.build/go/protovalidate"
	"github.com/ash2k/stager"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/ioz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/logz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/nettool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/pkg/kascfg"
	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
	"google.golang.org/grpc"
	"google.golang.org/grpc/keepalive"
)

// apiServer bundles the gRPC server that serves the API endpoint with the
// configuration and hooks needed to start it and to wind it down.
type apiServer struct {
	// log is used to announce that the endpoint is listening.
	log          *slog.Logger
	// listenCfg is the API listener configuration (cfg.Api.Listen). Start reads
	// the network, address, max connection age and listen grace period from it.
	listenCfg    *kascfg.ListenApiCF
	// listenMetric wraps the listener created in Start under the name "api".
	listenMetric *nettool.ListenerMetrics
	// server is the configured gRPC server handed to grpctool.StartServer.
	server       *grpc.Server
	// auxCancel cancels the auxiliary context that newAPIServer passes to
	// grpctool.MaxConnectionAge2GRPCKeepalive. Start schedules it one second
	// after the listen grace period has elapsed.
	auxCancel    context.CancelFunc
	// ready is the readiness toggle registered as "apiServer"; Start calls it
	// once the listener has been created and wrapped.
	ready        func()
}

// newAPIServer assembles the gRPC server for the API endpoint described by cfg.Api.Listen.
//
// It loads the JWT secret from the configured authentication secret file, asks
// grpctool.MaybeTLSCreds for server options derived from the configured certificate and
// key files, and then builds a grpc.Server with stats handlers, stream/unary interceptor
// chains, keepalive settings and a receive message size limit. The server is only
// constructed here; Start is what makes it listen.
//
// A non-nil error is returned when the secret cannot be loaded (wrapped with
// "auth secret file") or when grpctool.MaybeTLSCreds fails (returned as is).
func newAPIServer(log *slog.Logger, ot *obsTools, cfg *kascfg.ConfigurationFile, factory modshared.RPCAPIFactory,
	grpcServerErrorReporter grpctool.ServerErrorReporter, v protovalidate.Validator, listenMetric *nettool.ListenerMetrics) (*apiServer, error) {
	apiListen := cfg.Api.Listen
	secret, err := ioz.LoadSHA256Base64Secret(log, apiListen.AuthenticationSecretFile)
	if err != nil {
		return nil, fmt.Errorf("auth secret file: %w", err)
	}
	tlsOpts, err := grpctool.MaybeTLSCreds(apiListen.CertificateFile, apiListen.KeyFile)
	if err != nil {
		return nil, err
	}

	// The authenticator obtains its logger from the RPC API stored in the request context,
	// which is why the RPC API interceptor must run before the auth interceptor below.
	rpcLogger := func(ctx context.Context) *slog.Logger {
		return modshared.RPCAPIFromContext[modserver.RPCAPI](ctx).Log()
	}
	auther := grpctool.NewHMACJWTAuther(secret, "", api.JWTKAS, rpcLogger)

	// auxCancel is kept on the returned apiServer; Start triggers it during shutdown.
	auxCtx, auxCancel := context.WithCancel(context.Background())
	maxAgeOpt, maxAgeStats := grpctool.MaxConnectionAge2GRPCKeepalive(auxCtx, apiListen.MaxConnectionAge.AsDuration())

	otelStats := otelgrpc.NewServerHandler(
		otelgrpc.WithTracerProvider(ot.MaybeTraceProvider(ot.grpcServerTracing)),
		otelgrpc.WithMeterProvider(ot.mp),
		otelgrpc.WithPropagators(ot.p),
		otelgrpc.WithMessageEvents(otelgrpc.ReceivedEvents, otelgrpc.SentEvents),
	)

	// Interceptor order matters and is identical for streaming and unary calls:
	//   1. measure all invocations
	//   2. inject RPC API
	//   3. auth and maybe log
	//   then: wrap with validator, followed by the error reporter.
	streamChain := grpc.ChainStreamInterceptor(
		ot.streamProm,
		modshared.StreamRPCAPIInterceptor(factory),
		auther.StreamServerInterceptor,
		grpctool.StreamServerValidatingInterceptor(v),
		grpctool.StreamServerErrorReporterInterceptor(grpcServerErrorReporter),
	)
	unaryChain := grpc.ChainUnaryInterceptor(
		ot.unaryProm,
		modshared.UnaryRPCAPIInterceptor(factory),
		auther.UnaryServerInterceptor,
		grpctool.UnaryServerValidatingInterceptor(v),
		grpctool.UnaryServerErrorReporterInterceptor(grpcServerErrorReporter),
	)

	pingPolicy := grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
		MinTime:             20 * time.Second,
		PermitWithoutStream: true,
	})

	opts := []grpc.ServerOption{
		grpc.StatsHandler(otelStats),
		grpc.StatsHandler(ot.ssh),
		grpc.StatsHandler(maxAgeStats),
		grpc.SharedWriteBuffer(true),
		streamChain,
		unaryChain,
		pingPolicy,
		maxAgeOpt,
		grpc.MaxRecvMsgSize(api.GRPCMaxMessageSize),
	}
	// Options returned by MaybeTLSCreds go last, after the fixed set above.
	opts = append(opts, tlsOpts...)

	srv := grpc.NewServer(opts...)
	return &apiServer{
		log:          log,
		listenCfg:    apiListen,
		listenMetric: listenMetric,
		server:       srv,
		auxCancel:    auxCancel,
		ready:        ot.probeRegistry.RegisterReadinessToggle("apiServer"),
	}, nil
}

// Start hands the gRPC server to grpctool.StartServer on the given stage, together
// with three callbacks: one that opens the listener and two that take no arguments.
// NOTE(review): the two argument-less callbacks are presumably shutdown hooks run
// around stopping the gRPC server — confirm against grpctool.StartServer.
func (s *apiServer) Start(stage stager.Stage) {
	// openListener creates the TCP listener for the configured network and address,
	// wraps it with the listener metrics under the name "api", logs the bound address
	// and flips the readiness toggle.
	openListener := func(ctx context.Context) (net.Listener, error) { //nolint:dupl
		listenCfg := s.listenCfg
		rawLis, err := nettool.ListenWithOSTCPKeepAlive(ctx, *listenCfg.Network, listenCfg.Address)
		if err != nil {
			return nil, err
		}
		metered, err := s.listenMetric.Wrap(rawLis, "api", listenCfg.MaxConnectionAge.AsDuration())
		if err != nil {
			// Best effort cleanup: the Wrap error is the one reported to the caller.
			_ = rawLis.Close()
			return nil, err
		}
		bound := metered.Addr()
		s.log.Info("API endpoint is up",
			logz.NetNetworkFromAddr(bound),
			logz.NetAddressFromAddr(bound),
		)
		s.ready()
		return metered, nil
	}

	// delayedCancel blocks for the configured listen grace period and then arranges
	// for auxCancel to fire one second later.
	//
	// The gRPC server should send GOAWAY first and only afterwards should the RPC
	// handlers return, hence signaling the handlers and registry is deferred instead
	// of being done right away.
	// See https://github.com/grpc/grpc-go/issues/6830 for more background.
	delayedCancel := func() {
		time.Sleep(s.listenCfg.ListenGracePeriod.AsDuration())
		// Fires on its own goroutine after one second; tells running RPC handlers to stop.
		time.AfterFunc(time.Second, s.auxCancel)
	}

	// Nothing to do in the final callback.
	noop := func() {}

	grpctool.StartServer(stage, s.server, openListener, delayedCancel, noop)
}
