package gitlab

import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"time"

	"buf.build/go/protovalidate"
	"github.com/golang-jwt/jwt/v5"
	"github.com/hashicorp/go-retryablehttp"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/httpz"
	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
	otelmetric "go.opentelemetry.io/otel/metric"
)

const (
	// jwtRequestHeader is the HTTP header that carries the KAS-signed JWT to gitlab-rails.
	jwtRequestHeader = "Gitlab-Kas-Api-Request"

	// jwtGitLabAudience is the value placed in the JWT "aud" claim.
	jwtGitLabAudience = "gitlab"

	// jwtValidFor is how long after issuance the JWT expires ("exp" claim).
	jwtValidFor = 30 * time.Second

	// jwtNotBefore is how far before issuance the JWT "nbf" claim is set,
	// presumably to tolerate clock skew between KAS and GitLab.
	jwtNotBefore = 5 * time.Second
)

// HTTPClient is the request-sending capability that Client relies on; both of
// Client's HTTP clients are stored behind this interface.
type HTTPClient interface {
	// Do sends req and returns the response or the error that ended the exchange.
	Do(req *retryablehttp.Request) (*http.Response, error)
}

// Client issues requests against a GitLab backend. Construct it with NewClient.
type Client struct {
	// backend is the base URL that each request's path and query are merged into.
	backend *url.URL

	// httpClient applies the configured retry policy; httpClientNoRetry makes a
	// single attempt and is chosen by Do for requests that opt out of retries.
	httpClient, httpClientNoRetry HTTPClient

	// userAgent, when non-empty, is sent as the User-Agent header of every request.
	userAgent string

	// validator is used when applying request options and when handling responses.
	validator protovalidate.Validator

	// ffCheckCounter is passed to the feature-flag response handler, if any.
	ffCheckCounter otelmetric.Int64Counter
}

// NewClient creates a Client for the GitLab API rooted at backend.
//
// The transport is assembled from the configured transport settings. It is
// wrapped in a rate limiter when one is configured and instrumented with
// OpenTelemetry. Redirects are not followed; the redirect response itself is
// returned to the caller.
//
// Two retryablehttp clients share this transport: one uses the configured
// retry policy and the other never retries. Both sign JWT-marked requests with
// authSecret and mask URL query values in transport errors.
func NewClient(backend *url.URL, authSecret []byte, opts ...ClientOption) *Client {
	cfg := applyClientOptions(opts)

	var rt http.RoundTripper = &http.Transport{
		Proxy:                 cfg.transportConfig.Proxy,
		DialContext:           cfg.transportConfig.DialContext,
		TLSClientConfig:       cfg.transportConfig.TLSClientConfig,
		TLSHandshakeTimeout:   cfg.transportConfig.TLSHandshakeTimeout,
		MaxIdleConns:          cfg.transportConfig.MaxIdleConns,
		MaxIdleConnsPerHost:   cfg.transportConfig.MaxIdleConnsPerHost,
		MaxConnsPerHost:       cfg.transportConfig.MaxConnsPerHost,
		IdleConnTimeout:       cfg.transportConfig.IdleConnTimeout,
		ResponseHeaderTimeout: cfg.transportConfig.ResponseHeaderTimeout,
		ForceAttemptHTTP2:     cfg.transportConfig.ForceAttemptHTTP2,
	}
	if cfg.limiter != nil {
		rt = &httpz.RateLimitingRoundTripper{Delegate: rt, Limiter: cfg.limiter}
	}

	// Hand redirect responses back to the caller instead of following them.
	doNotFollowRedirects := func(*http.Request, []*http.Request) error {
		return http.ErrUseLastResponse
	}
	stdClient := &http.Client{
		Transport: otelhttp.NewTransport(rt,
			otelhttp.WithPropagators(cfg.tracePropagator),
			otelhttp.WithTracerProvider(cfg.traceProvider),
			otelhttp.WithMeterProvider(cfg.meterProvider),
		),
		CheckRedirect: doNotFollowRedirects,
	}

	// Shared by both clients: signs JWT-marked requests, then runs the user hook.
	signHook := jwtRequestLogHook(authSecret, cfg.retryConfig.RequestLogHook)

	withRetries := &retryablehttp.Client{
		HTTPClient:      stdClient,
		Logger:          cfg.retryConfig.Logger,
		RetryWaitMin:    cfg.retryConfig.RetryWaitMin,
		RetryWaitMax:    cfg.retryConfig.RetryWaitMax,
		RetryMax:        cfg.retryConfig.RetryMax,
		RequestLogHook:  signHook,
		ResponseLogHook: cfg.retryConfig.ResponseLogHook,
		CheckRetry:      cfg.retryConfig.CheckRetry,
		Backoff:         cfg.retryConfig.Backoff,
		ErrorHandler:    errorHandler,
	}
	// Wait bounds and Backoff are omitted: with RetryMax 0 they are never consulted.
	singleAttempt := &retryablehttp.Client{
		HTTPClient:      stdClient,
		Logger:          cfg.retryConfig.Logger,
		RetryMax:        0,
		RequestLogHook:  signHook,
		ResponseLogHook: cfg.retryConfig.ResponseLogHook,
		CheckRetry:      cfg.retryConfig.CheckRetry,
		ErrorHandler:    errorHandler,
	}

	return &Client{
		backend:           backend,
		httpClient:        withRetries,
		httpClientNoRetry: singleAttempt,
		userAgent:         cfg.userAgent,
		validator:         cfg.validator,
		ffCheckCounter:    cfg.ffCheckCounter,
	}
}

// Do performs one logical request against the GitLab backend.
//
// opts are applied and validated first. The request URL is c.backend merged
// with the configured path and query. When a JWT is requested,
// jwtRequestHeader is set to a placeholder that the request log hook installed
// by NewClient replaces with a signed token; otherwise any such header is
// stripped. The client's User-Agent, if configured, overrides the one in the
// supplied headers. Requests marked as non-retryable use the single-attempt
// client.
//
// The response and/or error is handed to the configured response handler,
// whose result is returned. If the round trip failed while ctx was already
// done, ctx.Err() is passed instead of the transport error. Feature-flag
// headers are processed only when the round trip returned no error.
func (c *Client) Do(ctx context.Context, opts ...DoOption) error {
	cfg, err := applyDoOptions(c.validator, opts)
	if err != nil {
		return err
	}

	// A reader takes precedence over a byte slice; with neither, the body is nil.
	var body any
	if cfg.bodyReader != nil {
		body = cfg.bodyReader
	} else if cfg.bodyBytes != nil {
		body = cfg.bodyBytes
	}

	target := httpz.MergeURLPathAndQuery(c.backend, cfg.path, cfg.query)
	req, err := retryablehttp.NewRequestWithContext(ctx, cfg.method, target, body)
	if err != nil {
		return fmt.Errorf("NewRequest: %w", err)
	}

	if len(cfg.header) != 0 {
		req.Header = cfg.header
	}
	if cfg.withJWT {
		// Marker only: jwtRequestLogHook swaps it for a signed token before sending.
		req.Header[jwtRequestHeader] = []string{"set me"}
	} else {
		// Never forward a caller-supplied JWT header on an unsigned request.
		delete(req.Header, jwtRequestHeader)
	}
	if ua := c.userAgent; ua != "" {
		req.Header[httpz.UserAgentHeader] = []string{ua}
	}

	client := c.httpClient
	if cfg.noRetry {
		client = c.httpClientNoRetry
	}

	resp, err := client.Do(req) //nolint: bodyclose
	switch {
	case err != nil:
		// A done context is most likely why the request failed; report that instead.
		if ctxErr := ctx.Err(); ctxErr != nil {
			err = ctxErr
		}
	case cfg.featureFlagsHandler != nil:
		cfg.featureFlagsHandler(resp.Header, c.ffCheckCounter)
	}
	return cfg.responseHandler.Handle(c.validator, resp, err)
}

// FeatureFlagCheckCounter returns the counter configured for this client
// through its options. Do passes the same counter to feature-flag handlers.
func (c *Client) FeatureFlagCheckCounter() otelmetric.Int64Counter {
	return c.ffCheckCounter
}

// jwtRequestLogHook wraps hook in a retryablehttp.RequestLogHook that signs
// requests. Whenever jwtRequestHeader is present on the outgoing request, the
// wrapper replaces its value with a freshly signed HS256 JWT.
//
// The token is issued by api.JWTKAS for the jwtGitLabAudience audience and
// signed with authSecret. It is valid from jwtNotBefore before the current time
// until jwtValidFor after it. A signing error is logged through logger (when
// logger is non-nil) and leaves the header unchanged. The wrapped hook, if
// non-nil, always runs afterwards with the same arguments.
func jwtRequestLogHook(authSecret []byte, hook retryablehttp.RequestLogHook) retryablehttp.RequestLogHook {
	return func(logger retryablehttp.Logger, r *http.Request, reqNum int) {
		// The check is on presence, not on the placeholder value. A token
		// installed by a previous attempt is therefore re-minted on every retry.
		if len(r.Header[jwtRequestHeader]) > 0 {
			issuedAt := time.Now()
			token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
				Issuer:    api.JWTKAS,
				Audience:  jwt.ClaimStrings{jwtGitLabAudience},
				IssuedAt:  jwt.NewNumericDate(issuedAt),
				NotBefore: jwt.NewNumericDate(issuedAt.Add(-jwtNotBefore)),
				ExpiresAt: jwt.NewNumericDate(issuedAt.Add(jwtValidFor)),
			})
			signed, err := token.SignedString(authSecret)
			switch {
			case err == nil:
				r.Header[jwtRequestHeader] = []string{signed}
			case logger != nil:
				logger.Printf("Sign JWT: %v", err)
			}
		}
		if hook != nil {
			hook(logger, r, reqNum)
		}
	}
}

// errorHandler is the retryablehttp ErrorHandler for both clients built by
// NewClient. It is invoked when the request ends without success, either
// because retries are exhausted or because retrying was not allowed.
//
// It returns the final response and error unchanged, with one exception: when
// the error is a *url.Error whose URL has a query string, every query parameter
// value in that URL is replaced with "x". This presumably keeps secrets passed
// in the query out of error messages and logs.
func errorHandler(resp *http.Response, err error, numTries int) (*http.Response, error) {
	// Only a top-level *url.Error is rewritten; wrapped ones are left as-is.
	ue, isURLErr := err.(*url.Error) //nolint: errorlint
	if !isURLErr {
		return resp, err
	}
	if u, parseErr := url.Parse(ue.URL); parseErr == nil && u.RawQuery != "" {
		maskURLQueryParams(u)
		ue.URL = u.String()
	}
	return resp, ue
}

// maskURLQueryParams rewrites u.RawQuery so that every parameter keeps its
// name, but all of its values collapse to a single "x". The query is
// re-encoded with keys sorted. Pairs that u.Query() cannot parse are dropped.
func maskURLQueryParams(u *url.URL) {
	params := u.Query()
	masked := make(url.Values, len(params))
	for name := range params {
		masked.Set(name, "x")
	}
	u.RawQuery = masked.Encode()
}
