package agent

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"buf.build/go/protovalidate"
	"github.com/coder/websocket"
	prometheus2 "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/collectors"
	"github.com/spf13/cobra"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/httpz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/logz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/metric"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/syncz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/tlstool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/wstunnel"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/grpcz"
	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
	otelmetric "go.opentelemetry.io/otel/metric"
	noop2 "go.opentelemetry.io/otel/metric/noop"
	"go.opentelemetry.io/otel/propagation"
	"go.opentelemetry.io/otel/trace"
	"go.opentelemetry.io/otel/trace/noop"
	"google.golang.org/grpc"
	"google.golang.org/grpc/backoff"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/encoding/gzip"
	"google.golang.org/grpc/keepalive"
)

// Names and defaults shared by the agent commands, plus the environment
// variables consulted when configuring logging.
const (
	// AgentName and the default listen networks/addresses below are exported
	// and are not referenced in the visible part of this file.
	// NOTE(review): presumably they are used as flag defaults by the commands
	// that embed Options - confirm against the callers.
	AgentName                         = "gitlab-agent"
	DefaultAPIListenNetwork           = "tcp"
	DefaultPrivateAPIListenNetwork    = "tcp"
	DefaultPrivateAPIListenAddress    = ":8081"
	DefaultObservabilityListenNetwork = "tcp"
	DefaultObservabilityListenAddress = ":8080"

	// Environment variables read by Options.Complete (via logHandler). A
	// non-empty value overrides the default log level passed to Complete.
	envVarGRPCLogLevel = "GRPC_GO_LOG_SEVERITY_LEVEL" // intentionally the same name as gRPC uses
	envVarLogLevel     = "LOG_LEVEL"
)

// ObsTools bundles the observability primitives produced by
// Options.ConstructObservabilityTools:
//   - Reg: Prometheus registry the collectors are registered with.
//   - TP, MP: OpenTelemetry tracer and meter providers.
//   - P: text map propagator passed to the gRPC OpenTelemetry instrumentation.
//   - StreamProm, UnaryProm: Prometheus gRPC server interceptors.
//   - StreamClientProm, UnaryClientProm: Prometheus gRPC client interceptors;
//     Options.ConstructKASConnection chains them into the KAS connection.
type ObsTools struct {
	Reg              *prometheus.Registry
	TP               trace.TracerProvider
	MP               otelmetric.MeterProvider
	P                propagation.TextMapPropagator
	StreamProm       grpc.StreamServerInterceptor
	UnaryProm        grpc.UnaryServerInterceptor
	StreamClientProm grpc.StreamClientInterceptor
	UnaryClientProm  grpc.UnaryClientInterceptor
}

// Options holds the configuration and shared dependencies common to the agent
// commands. The flag-backed fields are bound by AddCommonFlagsToCommand, and
// Complete fills in Log, LogLevel, GRPCHandler, GRPCLogLevel, Validator and
// AgentToken.
type Options struct {
	// Log, LogLevel, GRPCHandler, GRPCLogLevel and Validator are set by Complete.
	// AgentKey is handed to the log handlers built by logHandler.
	// NOTE(review): GitLabExternalURL is not used in the visible part of this file.
	Log               *slog.Logger
	LogLevel          *slog.LevelVar
	GRPCHandler       slog.Handler
	GRPCLogLevel      *slog.LevelVar
	Validator         protovalidate.Validator
	AgentKey          *syncz.ValueHolder[api.AgentKey]
	GitLabExternalURL *syncz.ValueHolder[url.URL]

	// Connection settings for KAS (GitLab Kubernetes Agent Server), bound to
	// the --kas-* flags and consumed by ConstructKASConnection.
	KASAddress       string
	KASCACertFile    string
	KASHeaders       []string
	KASSkipTLSVerify bool
	KASTLSServerName string

	// Observability endpoint settings, bound to the --observability-* flags.
	ObservabilityListenNetwork string
	ObservabilityListenAddress string
	ObservabilityCertFile      string
	ObservabilityKeyFile       string

	// TokenFile is bound to --token-file. AgentToken is populated by Complete
	// from either TokenFile or an environment variable, never both.
	TokenFile  string
	AgentToken api.AgentToken
}

// ConstructKASConnection builds the gRPC client connection to KAS (GitLab
// Kubernetes Agent Server) located at o.KASAddress.
//
// The URL scheme of the address selects the transport:
//   - "ws"/"wss": gRPC is tunneled through a WebSocket connection; the extra
//     KAS headers, the user agent and the agent token are sent as HTTP headers
//     of the WebSocket handshake.
//   - "grpc": gRPC with insecure transport credentials, DNS resolution and
//     round-robin load balancing.
//   - "grpcs": same as "grpc" but with TLS transport credentials.
//
// In every case the agent token is additionally registered as per-RPC
// credentials. An error is returned when the TLS configuration cannot be
// built, the address or a header is invalid, the scheme is unsupported, or
// grpc.NewClient fails.
func (o *Options) ConstructKASConnection(ot *ObsTools, userAgent string, agentType api.AgentType) (*grpc.ClientConn, error) {
	tlsCfg, err := tlstool.ClientConfigWithCACert(o.KASCACertFile)
	if err != nil {
		return nil, err
	}
	tlsCfg.InsecureSkipVerify = o.KASSkipTLSVerify
	tlsCfg.ServerName = o.KASTLSServerName

	kasURL, err := url.Parse(o.KASAddress)
	if err != nil {
		return nil, fmt.Errorf("invalid gitlab-kas address: %w", err)
	}
	extraHeaders, err := parseHeaders(o.KASHeaders)
	if err != nil {
		return nil, err
	}

	statsHandler := otelgrpc.NewClientHandler(
		otelgrpc.WithTracerProvider(ot.TP),
		otelgrpc.WithMeterProvider(ot.MP),
		otelgrpc.WithPropagators(ot.P),
		otelgrpc.WithMessageEvents(otelgrpc.ReceivedEvents, otelgrpc.SentEvents),
	)
	// The gRPC defaults are fine for now; they are spelled out so that they are easy to find.
	// See https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md.
	connectParams := grpc.ConnectParams{
		Backoff:           backoff.DefaultConfig,
		MinConnectTimeout: 20 * time.Second, // same as the gRPC default.
	}
	// Client keepalive must not be more aggressive than what the server-side
	// grpc.KeepaliveEnforcementPolicy permits.
	keepaliveParams := keepalive.ClientParameters{
		// kas permits a minimum of 20 seconds. The value stays under 60 seconds (a typical
		// load-balancer timeout) and above kas' own server keepalive Time, so that kas pings
		// the client from time to time. That helps with the server response timeout enforced
		// by reverse proxies.
		Time:                55 * time.Second,
		PermitWithoutStream: true,
	}
	dialOpts := []grpc.DialOption{
		grpc.WithStatsHandler(statsHandler),
		grpc.WithConnectParams(connectParams),
		grpc.WithSharedWriteBuffer(true),
		grpc.WithDefaultCallOptions(grpc.UseCompressor(gzip.Name), grpc.MaxCallRecvMsgSize(api.GRPCMaxMessageSize)),
		grpc.WithUserAgent(userAgent),
		grpc.WithKeepaliveParams(keepaliveParams),
		grpc.WithChainStreamInterceptor(
			ot.StreamClientProm,
			grpctool.StreamClientValidatingInterceptor(o.Validator),
		),
		grpc.WithChainUnaryInterceptor(
			ot.UnaryClientProm,
			grpctool.UnaryClientValidatingInterceptor(o.Validator),
		),
	}

	// Only with "grpcs" does gRPC itself encrypt the connection. "wss" is encrypted as well,
	// but gRPC has no way of knowing that, so it is told the transport is insecure.
	grpcTLS := kasURL.Scheme == "grpcs"
	var target string
	switch kasURL.Scheme {
	case "ws", "wss":
		target = "passthrough:" + o.KASAddress
		extraHeaders.Set(httpz.UserAgentHeader, userAgent)
		extraHeaders.Set(httpz.AuthorizationHeader, "Bearer "+string(o.AgentToken))

		netDialer := net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}
		httpClient := &http.Client{
			Transport: &http.Transport{
				Proxy:                 http.ProxyFromEnvironment,
				DialContext:           netDialer.DialContext,
				TLSClientConfig:       tlsCfg,
				MaxIdleConns:          10,
				IdleConnTimeout:       90 * time.Second,
				TLSHandshakeTimeout:   10 * time.Second,
				ResponseHeaderTimeout: 20 * time.Second,
			},
			// Do not follow redirects; hand the redirect response back to the caller.
			CheckRedirect: func(req *http.Request, via []*http.Request) error {
				return http.ErrUseLastResponse
			},
		}
		wsOpts := &websocket.DialOptions{
			HTTPClient:      httpClient,
			HTTPHeader:      extraHeaders,
			CompressionMode: websocket.CompressionDisabled,
		}
		dialOpts = append(dialOpts, grpc.WithContextDialer(wstunnel.DialerForGRPC(api.WebSocketMaxMessageSize, wsOpts)))
	case "grpc", "grpcs":
		// See https://github.com/grpc/grpc/blob/master/doc/naming.md.
		target = "dns:" + grpctool.HostWithPort(kasURL)
		dialOpts = append(dialOpts,
			grpc.WithPerRPCCredentials(grpctool.NewHeaderMetadata(extraHeaders, !grpcTLS)),
			// See https://github.com/grpc/grpc/blob/master/doc/service_config.md.
			// See https://github.com/grpc/grpc/blob/master/doc/load-balancing.md.
			grpc.WithDefaultServiceConfig(`{"loadBalancingConfig":[{"round_robin":{}}]}`),
		)
	default:
		return nil, fmt.Errorf("unsupported scheme in GitLab Kubernetes Agent Server address: %q", kasURL.Scheme)
	}

	// Exactly one set of transport credentials is configured, depending on the scheme.
	if grpcTLS {
		dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)))
	} else {
		dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
	}
	// Token credentials go last so they follow the header metadata credentials, if any.
	dialOpts = append(dialOpts, grpc.WithPerRPCCredentials(grpctool.NewTokenCredentials(o.AgentToken, agentType, !grpcTLS)))

	kasConn, err := grpc.NewClient(target, dialOpts...)
	if err != nil {
		return nil, fmt.Errorf("gRPC.dial: %w", err)
	}
	return kasConn, nil
}

// parseHeaders converts "Name: value" strings into an http.Header.
//
// Each entry is split at its first colon; spaces (only the space character)
// surrounding the name and the value are stripped. Entries without a colon, or
// with an empty name or value after stripping, are rejected with an error.
// Repeated names accumulate their values in input order.
func parseHeaders(raw []string) (http.Header, error) {
	parsed := make(http.Header)
	for _, entry := range raw {
		name, value, found := strings.Cut(entry, ":")
		name = strings.Trim(name, " ")
		value = strings.Trim(value, " ")
		// A missing colon leaves value empty, so it is rejected here as well.
		if !found || name == "" || value == "" {
			return nil, fmt.Errorf("invalid header supplied: %s", entry)
		}
		parsed.Add(name, value)
	}
	return parsed, nil
}

// ConstructObservabilityTools creates the metrics, tracing and propagation
// primitives used by the agent.
//
// It registers the Go runtime collector, the process collector and the gRPC
// server/client metrics with a fresh pedantic Prometheus registry. Tracing and
// OpenTelemetry metrics are currently no-op providers. The returned stop
// function does nothing at the moment and always returns nil. An error is
// returned only when collector registration fails.
func (o *Options) ConstructObservabilityTools() (*ObsTools, func() error, error) {
	// Metrics
	registry := prometheus.NewPedanticRegistry()
	serverMetrics := prometheus2.NewServerMetrics()
	clientMetrics := prometheus2.NewClientMetrics()
	if err := metric.Register(
		registry,
		collectors.NewGoCollector(),
		collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
		serverMetrics,
		clientMetrics,
	); err != nil {
		return nil, nil, err
	}

	tools := &ObsTools{
		Reg: registry,
		// TODO Tracing
		TP: noop.NewTracerProvider(),
		// TODO metrics via OTEL
		MP:               noop2.NewMeterProvider(),
		P:                propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}),
		StreamProm:       serverMetrics.StreamServerInterceptor(),
		UnaryProm:        serverMetrics.UnaryServerInterceptor(),
		StreamClientProm: clientMetrics.StreamClientInterceptor(),
		UnaryClientProm:  clientMetrics.UnaryClientInterceptor(),
	}
	// Nothing needs to be shut down yet, so the stop function is a no-op.
	noopStop := func() error { return nil }
	return tools, noopStop, nil
}

// logHandler builds a JSON slog handler writing to writer, wrapped in an
// agentKeyHandler that carries o.AgentKey.
//
// The log level is taken from the environment variable envVar when it is set
// to a non-empty value, otherwise from level. The returned LevelVar is the one
// the handler consults. An error is returned when the level text cannot be
// parsed.
func (o *Options) logHandler(envVar, level string, writer io.Writer) (slog.Handler, *slog.LevelVar, error) {
	levelText := level
	if fromEnv := os.Getenv(envVar); fromEnv != "" {
		levelText = fromEnv
	}

	lv := new(slog.LevelVar)
	if err := lv.UnmarshalText([]byte(levelText)); err != nil {
		return nil, nil, err
	}

	jsonHandler := slog.NewJSONHandler(writer, &slog.HandlerOptions{Level: lv})
	return &agentKeyHandler{
		delegate: jsonHandler,
		agentKey: o.AgentKey,
	}, lv, nil
}

// Complete finishes initialization of o: it sets up the main and gRPC log
// handlers, creates the protobuf validator and resolves the agent token.
//
// Log levels default to defaultLogLevel and defaultGRPCLogLevel and can be
// overridden through the LOG_LEVEL and GRPC_GO_LOG_SEVERITY_LEVEL environment
// variables. The token comes either from o.TokenFile (a single trailing
// newline is stripped) or from the environment variable named by
// envVarAgentToken, which is unset after being read. Supplying both sources,
// or neither, is an error.
func (o *Options) Complete(envVarAgentToken, defaultLogLevel, defaultGRPCLogLevel string) error {
	// Both handlers write to stderr through one shared logz.LockedWriter.
	out := &logz.LockedWriter{Writer: os.Stderr}

	mainHandler, mainLevel, err := o.logHandler(envVarLogLevel, defaultLogLevel, out)
	if err != nil {
		return err
	}
	o.Log, o.LogLevel = slog.New(mainHandler), mainLevel

	if o.GRPCHandler, o.GRPCLogLevel, err = o.logHandler(envVarGRPCLogLevel, defaultGRPCLogLevel, out); err != nil {
		return err
	}

	if o.Validator, err = protovalidate.New(); err != nil {
		return err
	}

	envToken, haveEnv := os.LookupEnv(envVarAgentToken)
	haveFile := o.TokenFile != ""
	if haveFile && haveEnv {
		return fmt.Errorf("unable to use both token file and %s environment variable to set the agent token", envVarAgentToken)
	}
	if !haveFile && !haveEnv {
		return fmt.Errorf("agent token not set. Please set either token file or %s environment variable", envVarAgentToken)
	}

	if haveFile {
		contents, readErr := os.ReadFile(o.TokenFile)
		if readErr != nil {
			return fmt.Errorf("token file: %w", readErr)
		}
		o.AgentToken = api.AgentToken(bytes.TrimSuffix(contents, []byte{'\n'}))
		return nil
	}

	// Token supplied via the environment: record it, then remove the variable from the environment.
	o.AgentToken = api.AgentToken(envToken)
	if unsetErr := os.Unsetenv(envVarAgentToken); unsetErr != nil {
		return fmt.Errorf("failed to unset env var: %w", unsetErr)
	}
	return nil
}

// NewRPCAPIFactory returns a factory that creates an agentRPCAPI for an RPC.
// The API's logger is derived from o.Log and annotated with the trace ID found
// in ctx and with the gRPC service and method names split out of the full
// method string; ctx is stored as the stream context.
func (o *Options) NewRPCAPIFactory() modshared.RPCAPIFactory {
	return func(ctx context.Context, fullMethod string) modshared.RPCAPI {
		svc, rpc := grpcz.SplitGRPCMethod(fullMethod)
		rpcLog := o.Log.With(
			logz.TraceIDFromContext(ctx),
			logz.GRPCService(svc),
			logz.GRPCMethod(rpc),
		)
		return &agentRPCAPI{
			RPCAPIStub: modshared.RPCAPIStub{
				Logger:    rpcLog,
				StreamCtx: ctx,
			},
		}
	}
}

// AddCommonFlagsToCommand registers the flags shared by the agent commands on
// c and binds them to the corresponding fields of o. The current field values
// of o serve as the flag defaults. The observability certificate and key file
// flags must be supplied together.
func AddCommonFlagsToCommand(c *cobra.Command, o *Options) {
	flags := c.Flags()

	// Connection to KAS.
	flags.StringVar(&o.KASAddress, "kas-address", o.KASAddress, "GitLab Kubernetes Agent Server address")
	flags.StringVar(&o.KASCACertFile, "kas-ca-cert-file", o.KASCACertFile, "File with X.509 certificate authority certificate in PEM format. Used for verifying cert of KAS (GitLab Kubernetes Agent Server)")
	flags.StringArrayVar(&o.KASHeaders, "kas-header", o.KASHeaders, "HTTP headers to set when connecting to the agent server")
	flags.BoolVar(&o.KASSkipTLSVerify, "kas-insecure-skip-tls-verify", o.KASSkipTLSVerify, "If true, the agent server's certificate will not be checked for validity. This will make the connection insecure")
	flags.StringVar(&o.KASTLSServerName, "kas-tls-server-name", o.KASTLSServerName, "Server name to use for agent server certificate validation. If it is not provided, the hostname used to contact the server is used")

	// Observability endpoint.
	flags.StringVar(&o.ObservabilityListenNetwork, "observability-listen-network", o.ObservabilityListenNetwork, "Observability network to listen on")
	flags.StringVar(&o.ObservabilityListenAddress, "observability-listen-address", o.ObservabilityListenAddress, "Observability address to listen on")
	flags.StringVar(&o.ObservabilityCertFile, "observability-cert-file", o.ObservabilityCertFile, "File with X.509 certificate in PEM format for observability endpoint TLS")
	flags.StringVar(&o.ObservabilityKeyFile, "observability-key-file", o.ObservabilityKeyFile, "File with X.509 key in PEM format for observability endpoint TLS")

	// Agent token.
	flags.StringVar(&o.TokenFile, "token-file", o.TokenFile, "File with access token")

	// TLS for the observability endpoint needs both the certificate and its key.
	c.MarkFlagsRequiredTogether("observability-cert-file", "observability-key-file")
}
