package agent

import (
	"context"
	"errors"
	"net"

	"buf.build/go/protovalidate"
	"github.com/ash2k/stager"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/grpcz"
	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// InMemAPIServer represents the agentk API that kas can talk to when running in agentk->kas tunnel mode.
// The server and its client connection are wired together through an in-memory listener.
type InMemAPIServer struct {
	// Server is the gRPC server that Start runs on InMemListener.
	Server *grpc.Server
	// InMemConn is the client connection whose dialer is InMemListener's DialContext.
	InMemConn *grpc.ClientConn
	// InMemListener is the in-memory listener handed to Server by Start.
	InMemListener net.Listener
}

// NewInMemAPIServer builds an InMemAPIServer: a gRPC server plus a client connection
// that reaches it through an in-memory pipe listener.
//
// ot supplies the tracing/metrics providers, the propagators and the Prometheus interceptors,
// factory is passed to the RPC API interceptors, and v is passed to the validating interceptors.
// An error is returned only when the client connection cannot be constructed.
func NewInMemAPIServer(ot *ObsTools, factory modshared.RPCAPIFactory, v protovalidate.Validator) (*InMemAPIServer, error) {
	// In-memory pipe: used as the client's dialer below and stored as the server's listener.
	pipe := grpcz.NewPipeDialListener()

	// Client side: connection to the API gRPC server over the pipe.
	dialOpts := []grpc.DialOption{
		grpc.WithSharedWriteBuffer(true),
		grpc.WithContextDialer(pipe.DialContext),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(api.GRPCMaxMessageSize)),
	}
	clientConn, err := grpc.NewClient("passthrough:api-server", dialOpts...)
	if err != nil {
		return nil, err
	}

	// Server side: built only after the client connection has been created successfully.
	otelHandler := otelgrpc.NewServerHandler(
		otelgrpc.WithTracerProvider(ot.TP),
		otelgrpc.WithMeterProvider(ot.MP),
		otelgrpc.WithPropagators(ot.P),
		otelgrpc.WithMessageEvents(otelgrpc.ReceivedEvents, otelgrpc.SentEvents),
	)
	serverOpts := []grpc.ServerOption{
		grpc.StatsHandler(otelHandler),
		grpc.StatsHandler(grpcz.ServerNoopMaxConnAgeStatsHandler{}),
		grpc.SharedWriteBuffer(true),
		// Interceptors are listed in chain order; keep the order when editing.
		grpc.ChainStreamInterceptor(
			ot.StreamProm, // 1. measure all invocations
			modshared.StreamRPCAPIInterceptor(factory),    // 2. inject RPC API
			grpctool.StreamServerValidatingInterceptor(v), // 3. wrap with validator
		),
		grpc.ChainUnaryInterceptor(
			ot.UnaryProm, // 1. measure all invocations
			modshared.UnaryRPCAPIInterceptor(factory),    // 2. inject RPC API
			grpctool.UnaryServerValidatingInterceptor(v), // 3. wrap with validator
		),
		grpc.MaxRecvMsgSize(api.GRPCMaxMessageSize),
	}

	return &InMemAPIServer{
		Server:        grpc.NewServer(serverOpts...),
		InMemConn:     clientConn,
		InMemListener: pipe,
	}, nil
}

// Start hands the server to grpctool.StartServer on the given stage.
// The listener callback always returns the in-memory listener; the two
// remaining callbacks are no-ops.
func (s *InMemAPIServer) Start(stage stager.Stage) {
	// The listener already exists, so the context is not needed to obtain it.
	listen := func(_ context.Context) (net.Listener, error) {
		return s.InMemListener, nil
	}
	noop := func() {}

	grpctool.StartServer(stage, s.Server, listen, noop, noop)
}

// Close closes the in-memory client connection and then the in-memory listener.
// Both closes are always attempted; their errors, if any, are joined into the
// returned error. Server itself is not touched here.
func (s *InMemAPIServer) Close() error {
	// Order matters: the client goes first, the listener second
	// (the listener may have been closed already).
	connErr := s.InMemConn.Close()
	listenerErr := s.InMemListener.Close()
	return errors.Join(connErr, listenerErr)
}
