package server

import (
	"fmt"
	"time"

	"buf.build/go/protovalidate"
	"github.com/golang-jwt/jwt/v5"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/kubernetes_api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/featureflag"
	otelmetric "go.opentelemetry.io/otel/metric"
)

// Constants for WebSocket tokens. userIDSubjectPrefix forms the "sub" claim
// written by generate as "<prefix>:<userID>". webSocketTokenRequestParamName
// is not referenced in this part of the file; it is presumably the request
// parameter the HTTP layer uses for token requests. The nolint directive
// presumably silences gosec's hardcoded-credential heuristic, which is
// triggered by "token" in the name.
const (
	userIDSubjectPrefix            = "user-id"
	webSocketTokenRequestParamName = "kas-websocket-token-request" //nolint:gosec
)

// webSocketToken holds what its generate and verify methods need to mint and
// check WebSocket JWTs:
//   - validator is handed to ValidatingWebSocketTokenClaims during verify.
//   - jwtSigningMethod signs tokens in generate and is the only algorithm
//     verify accepts.
//   - jwtSecret is used both to sign and to verify. The same []byte serves
//     both roles, so this is presumably an HMAC method — confirm at construction.
//   - ffCheckCounter is passed to featureflag.NewSet for the flags decoded in
//     verify. It presumably counts feature-flag checks — confirm in featureflag.
type webSocketToken struct {
	validator        protovalidate.Validator
	jwtSigningMethod jwt.SigningMethod
	jwtSecret        []byte
	ffCheckCounter   otelmetric.Int64Counter
}

// generate mints a signed JWT for a WebSocket connection to url on behalf of
// the given agent and user.
//
// The token has api.JWTKAS as both issuer and audience. Its subject is
// "<userIDSubjectPrefix>:<userID>". It is valid from 5s before to 5s after the
// moment of issue. It carries the endpoint, agent ID, impersonation config and
// feature flags as private claims. Only agentKey.ID is embedded; agentKey.Type
// is not. If signing fails, the error is returned together with an empty token.
func (t *webSocketToken) generate(url string, agentKey api.AgentKey, userID int64, impConfig *kubernetes_api.ImpersonationConfig, ffs featureflag.Set) (string, error) {
	issuedAt := time.Now()
	// The validity window opens slightly in the past, presumably to absorb
	// clock skew between the issuing and the verifying instance.
	notBefore := issuedAt.Add(-5 * time.Second)
	expiresAt := issuedAt.Add(5 * time.Second)

	claims := &WebSocketTokenClaims{
		// Registered claims.
		RegisteredClaimIssuer:    api.JWTKAS,
		RegisteredClaimSubject:   fmt.Sprintf("%s:%d", userIDSubjectPrefix, userID),
		RegisteredClaimAudience:  []string{api.JWTKAS},
		RegisteredClaimIssuedAt:  issuedAt.Unix(),
		RegisteredClaimNotBefore: notBefore.Unix(),
		RegisteredClaimExpiresAt: expiresAt.Unix(),

		// Private claims; verify reads AgentId, ImpersonationConfig and
		// FeatureFlags back out of the token.
		Endpoint:            url,
		AgentId:             agentKey.ID,
		ImpersonationConfig: impConfig,
		FeatureFlags:        ffs.ToMap(),
	}
	// SignedString already returns "" alongside any error, so its results can
	// be passed straight through.
	return jwt.NewWithClaims(t.jwtSigningMethod, claims).SignedString(t.jwtSecret)
}

// verify parses token and checks that it is a valid WebSocket token for url.
//
// The token must satisfy all of the following:
//   - it is signed with t.jwtSecret using exactly t.jwtSigningMethod;
//   - it names api.JWTKAS as both issuer and audience;
//   - it carries an expiration time.
//
// Endpoint binding and message validation are delegated to
// ValidatingWebSocketTokenClaims, which is built from t.validator and url.
// Presumably this happens via its Validate method — confirm in its definition.
//
// On success verify returns the agent key, the impersonation config and the
// feature flags embedded in the token. The agent key's Type is always
// api.AgentTypeKubernetes because generate does not encode agentKey.Type. On
// failure every value except the error is zero.
func (t *webSocketToken) verify(token, url string) (api.AgentKey, *kubernetes_api.ImpersonationConfig, featureflag.Set, error) {
	claims := &ValidatingWebSocketTokenClaims{Validator: t.validator, ValidForEndpoint: url}
	_, err := jwt.ParseWithClaims(
		token,
		claims,
		func(*jwt.Token) (any, error) {
			// WithValidMethods below pins the algorithm, so the key can be
			// returned unconditionally.
			return t.jwtSecret, nil
		},
		jwt.WithValidMethods([]string{t.jwtSigningMethod.Alg()}),
		jwt.WithIssuer(api.JWTKAS),
		jwt.WithAudience(api.JWTKAS),
		// jwt/v5 treats "exp" as optional by default. generate always sets it,
		// so a token without one is rejected rather than risk treating it as
		// never-expiring.
		jwt.WithExpirationRequired(),
	)
	if err != nil {
		return api.AgentKey{}, nil, featureflag.Set{}, err
	}

	ffs := featureflag.NewSet(claims.FeatureFlags, t.ffCheckCounter)
	return api.AgentKey{ID: claims.AgentId, Type: api.AgentTypeKubernetes}, claims.ImpersonationConfig, ffs, nil
}
