package agent

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"time"

	"buf.build/go/protovalidate"
	"github.com/ash2k/stager"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modagent"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/ioz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/logz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/nettool"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/grpcz"
	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
	"google.golang.org/grpc"
	"google.golang.org/grpc/keepalive"
)

const (
	// apiListenGracePeriod is how long the shutdown hook installed by Start sleeps
	// before it schedules cancellation of the server's auxiliary context.
	apiListenGracePeriod = 1 * time.Second

	// apiListenMaxConnectionAge is the maximum connection age handed to
	// grpcz.MaxConnectionAge2GRPCKeepalive when building the server keepalive option.
	apiListenMaxConnectionAge = 120 * time.Minute
)

// ListenAPIServer is the agentk API endpoint that kas can talk to.
// It is used in one of two ways:
//   - in-memory, when agentk connects to kas;
//   - exposed on a network address, when kas connects to agentk.
type ListenAPIServer struct {
	// Log is used to report when the API endpoint starts listening.
	Log *slog.Logger

	// ListenNetwork and ListenAddress are passed to the listener that Start opens.
	ListenNetwork string
	ListenAddress string

	// CertFile and KeyFile are the certificate/key paths that NewListenAPIServer
	// used to build the server's transport credentials.
	CertFile string
	KeyFile  string

	// AuxCancel cancels the auxiliary context that NewListenAPIServer passes to
	// grpcz.MaxConnectionAge2GRPCKeepalive; Start calls it during shutdown to
	// signal running RPC handlers to stop.
	AuxCancel context.CancelFunc

	// Server is the underlying gRPC server.
	Server *grpc.Server
}

// NewListenAPIServer builds, but does not start, the gRPC server behind the agentk API.
//
// Authentication is selected by the flags:
//   - jwtSecretFile set: an Ed25519 public key is loaded from that file and passed, with
//     jwtIssuer and jwtAudience, to grpctool.NewEdDSAJWTAuther; transport credentials come
//     from grpctool.MaybeTLSCreds(certFile, keyFile). Combining this with mtlsClientCAFile
//     or mtlsEnabled is rejected with an error.
//   - otherwise: transport credentials come from grpctool.MaybeMTLSCreds using the mTLS flags.
//
// Interceptors run in this order: metrics (ot.StreamProm / ot.UnaryProm), RPC API
// injection (factory), JWT authentication when enabled, and finally validation with v.
func NewListenAPIServer(log *slog.Logger, ot *ObsTools, factory modshared.RPCAPIFactory,
	listenNetwork, listenAddress, certFile, keyFile,
	jwtSecretFile, jwtIssuer, jwtAudience,
	mtlsClientCAFile string, mtlsEnabled bool,
	v protovalidate.Validator) (*ListenAPIServer, error) {
	streamChain := []grpc.StreamServerInterceptor{
		// 1. Measure all invocations.
		ot.StreamProm,
		// 2. Inject the RPC API into the context.
		modshared.StreamRPCAPIInterceptor(factory),
	}
	unaryChain := []grpc.UnaryServerInterceptor{
		// 1. Measure all invocations.
		ot.UnaryProm,
		// 2. Inject the RPC API into the context.
		modshared.UnaryRPCAPIInterceptor(factory),
	}

	var (
		credsOpt []grpc.ServerOption
		err      error
	)
	switch {
	case jwtSecretFile == "":
		credsOpt, err = grpctool.MaybeMTLSCreds(certFile, keyFile, mtlsClientCAFile, mtlsEnabled)
		if err != nil {
			return nil, err
		}
	case mtlsClientCAFile != "" || mtlsEnabled:
		// JWT and mTLS are mutually exclusive.
		return nil, errors.New("must configure either JWT or mTLS authentication. Unset some flags")
	default:
		jwtKey, keyErr := ioz.LoadEd25519Base64PublicKey(log, jwtSecretFile)
		if keyErr != nil {
			return nil, fmt.Errorf("auth secret file: %w", keyErr)
		}
		credsOpt, err = grpctool.MaybeTLSCreds(certFile, keyFile)
		if err != nil {
			return nil, err
		}
		// The auther logs through the per-request RPC API, which interceptor 2 has
		// already placed into the context by the time this one runs.
		auther := grpctool.NewEdDSAJWTAuther(jwtKey, jwtIssuer, jwtAudience, func(ctx context.Context) *slog.Logger {
			return modshared.RPCAPIFromContext[modagent.RPCAPI](ctx).Log()
		})
		// 3. Authenticate (and maybe log).
		streamChain = append(streamChain, auther.StreamServerInterceptor)
		unaryChain = append(unaryChain, auther.UnaryServerInterceptor)
	}

	// auxCtx is cancelled via AuxCancel (see Start) to tell running RPC handlers to stop.
	auxCtx, auxCancel := context.WithCancel(context.Background())
	keepaliveOpt, maxAgeHandler := grpcz.MaxConnectionAge2GRPCKeepalive(auxCtx, apiListenMaxConnectionAge)
	otelHandler := otelgrpc.NewServerHandler(
		otelgrpc.WithTracerProvider(ot.TP),
		otelgrpc.WithMeterProvider(ot.MP),
		otelgrpc.WithPropagators(ot.P),
		otelgrpc.WithMessageEvents(otelgrpc.ReceivedEvents, otelgrpc.SentEvents),
	)

	// Last in each chain: validate requests.
	streamChain = append(streamChain, grpctool.StreamServerValidatingInterceptor(v))
	unaryChain = append(unaryChain, grpctool.UnaryServerValidatingInterceptor(v))

	serverOpts := append([]grpc.ServerOption{
		keepaliveOpt,
		grpc.StatsHandler(otelHandler),
		grpc.StatsHandler(maxAgeHandler),
		grpc.SharedWriteBuffer(true),
		grpc.ChainStreamInterceptor(streamChain...),
		grpc.ChainUnaryInterceptor(unaryChain...),
		grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
			MinTime:             20 * time.Second,
			PermitWithoutStream: true,
		}),
		grpc.MaxRecvMsgSize(api.GRPCMaxMessageSize),
	}, credsOpt...)

	srv := grpc.NewServer(serverOpts...)
	return &ListenAPIServer{
		Log:           log,
		ListenNetwork: listenNetwork,
		ListenAddress: listenAddress,
		CertFile:      certFile,
		KeyFile:       keyFile,
		AuxCancel:     auxCancel,
		Server:        srv,
	}, nil
}

// Start hands s.Server to grpctool.StartServer on the given stage. It supplies three things:
//   - a listener factory that listens on s.ListenNetwork/s.ListenAddress via
//     nettool.ListenWithOSTCPKeepAlive and logs the bound address;
//   - a shutdown hook that sleeps apiListenGracePeriod, then arranges for s.AuxCancel
//     to run one second later;
//   - a no-op final callback.
func (s *ListenAPIServer) Start(stage stager.Stage) {
	listen := func(ctx context.Context) (net.Listener, error) {
		l, err := nettool.ListenWithOSTCPKeepAlive(ctx, s.ListenNetwork, s.ListenAddress)
		if err != nil {
			return nil, err
		}
		bound := l.Addr()
		s.Log.Info("API endpoint is up",
			logz.NetNetworkFromAddr(bound),
			logz.NetAddressFromAddr(bound),
		)
		return l, nil
	}

	shutdownHook := func() {
		time.Sleep(apiListenGracePeriod)
		// The gRPC server should send GOAWAY before the RPC handlers return, so the
		// handlers are signalled only after a delay instead of right away.
		// Background: https://github.com/grpc/grpc-go/issues/6830.
		time.AfterFunc(time.Second, func() {
			// Signal running RPC handlers to stop.
			s.AuxCancel()
		})
	}

	grpctool.StartServer(stage, s.Server, listen, shutdownHook, func() {})
}
