package kas

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"sync"

	grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/gitlab"
	gapi "gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/gitlab/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/server_api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/cache"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"go.opentelemetry.io/otel/trace"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// serverAgentRPCAPI is the server-side modserver.AgentRPCAPI handed to agent-facing RPC handlers.
// A fresh instance is built by serverAgentRPCAPIFactory.New for every call to New (presumably once
// per incoming agent RPC — confirm against the interceptor that calls New).
//
// It contains a sync.Once, so it must not be copied; all methods use pointer receivers.
type serverAgentRPCAPI struct {
	// RPCAPI provides shared RPC helpers such as HandleProcessingError.
	modshared.RPCAPI
	// TokenWithType is the bearer token and agent type extracted from the request metadata.
	TokenWithType    api.AgentTokenWithType
	// GitLabClient is used to fetch agent information from GitLab on a cache miss.
	GitLabClient     gitlab.ClientInterface
	// AgentInfoCache is shared by every instance created by the same factory and is keyed by TokenWithType.
	AgentInfoCache   *cache.CacheWithErr[api.AgentTokenWithType, server_api.AgentInfo]
	// agentKeyAttrOnce ensures the agent ID/type trace attributes are set at most once per instance.
	agentKeyAttrOnce sync.Once
}

// AgentTokenWithType returns the agent's bearer token paired with its agent type, exactly as
// captured when this API object was constructed.
func (a *serverAgentRPCAPI) AgentTokenWithType() api.AgentTokenWithType {
	tokenWithType := a.TokenWithType
	return tokenWithType
}

// AgentInfo resolves information about the agent that owns this request's token, going through
// AgentInfoCache.
//
// On the first successful lookup made through this instance, the agent ID and agent type are
// attached as attributes to the trace span carried by ctx.
//
// Failures are returned as gRPC status errors:
//   - context cancellation / deadline  -> codes.Canceled / codes.DeadlineExceeded
//   - GitLab forbidden                 -> codes.PermissionDenied
//   - GitLab unauthorized              -> codes.Unauthenticated
//   - GitLab not found                 -> codes.NotFound
//   - anything else is reported through HandleProcessingError and returned as codes.Unavailable
func (a *serverAgentRPCAPI) AgentInfo(ctx context.Context, log *slog.Logger) (server_api.AgentInfo, error) {
	info, lookupErr := a.getAgentInfoCached(ctx)
	if lookupErr != nil {
		// Order matters: context errors are checked before the GitLab classifications.
		switch {
		case errors.Is(lookupErr, context.Canceled):
			return nil, status.Error(codes.Canceled, lookupErr.Error())
		case errors.Is(lookupErr, context.DeadlineExceeded):
			return nil, status.Error(codes.DeadlineExceeded, lookupErr.Error())
		case gitlab.IsForbidden(lookupErr):
			return nil, status.Error(codes.PermissionDenied, "forbidden")
		case gitlab.IsUnauthorized(lookupErr):
			return nil, status.Error(codes.Unauthenticated, "unauthenticated")
		case gitlab.IsNotFound(lookupErr):
			return nil, status.Error(codes.NotFound, "agent not found")
		default:
			a.HandleProcessingError(log, "AgentInfo()", lookupErr)
			return nil, status.Error(codes.Unavailable, "unavailable")
		}
	}

	key := info.AgentKey()
	a.agentKeyAttrOnce.Do(func() {
		span := trace.SpanFromContext(ctx)
		span.SetAttributes(
			api.TraceAgentIDAttr.Int64(key.ID),
			api.TraceAgentTypeAttr.String(key.Type.String()),
		)
	})
	return info, nil
}

// getAgentInfoCached looks up the agent info for a.TokenWithType in AgentInfoCache.
//
// The supplied loader queries GitLab through a.GitLabClient with retries disabled. It is
// presumably invoked by the cache only when no usable entry exists — confirm against CacheWithErr.GetItem.
func (a *serverAgentRPCAPI) getAgentInfoCached(ctx context.Context) (server_api.AgentInfo, error) {
	load := func() (server_api.AgentInfo, error) {
		return gapi.GetAgentInfo(ctx, a.GitLabClient, a.TokenWithType, gitlab.WithoutRetries())
	}
	return a.AgentInfoCache.GetItem(ctx, a.TokenWithType, load)
}

// serverAgentRPCAPIFactory builds serverAgentRPCAPI instances. It holds the dependencies that are
// shared by every instance it creates.
type serverAgentRPCAPIFactory struct {
	// rpcAPIFactory builds the embedded modshared.RPCAPI from the request context and full method name.
	rpcAPIFactory  modshared.RPCAPIFactory
	// gitLabClient is handed to each instance for fetching agent info from GitLab.
	gitLabClient   gitlab.ClientInterface
	// agentInfoCache is the same cache pointer given to every instance, so lookups are shared across requests.
	agentInfoCache *cache.CacheWithErr[api.AgentTokenWithType, server_api.AgentInfo]
}

// New builds the modserver.AgentRPCAPI for an incoming agent RPC.
//
// It reads two values from the incoming gRPC metadata:
//   - the bearer token, via grpc_auth.AuthFromMD (its error is returned unchanged);
//   - the agent type, via agentTypeFromMD.
//
// Agent-type extraction errors that do not already carry a gRPC status are converted to
// codes.InvalidArgument. A malformed client-supplied metadata value is a client error, and a
// plain Go error would otherwise surface to the client as codes.Unknown (assuming the caller
// of New propagates the error unchanged — verify against the interceptor).
func (f *serverAgentRPCAPIFactory) New(ctx context.Context, fullMethodName string) (modserver.AgentRPCAPI, error) {
	token, err := grpc_auth.AuthFromMD(ctx, "bearer")
	if err != nil {
		return nil, err
	}
	agentType, err := agentTypeFromMD(ctx)
	if err != nil {
		// agentTypeFromMD may return plain errors (e.g. its fmt.Errorf for repeated values).
		// Keep any status error as-is; wrap everything else as InvalidArgument.
		if _, isStatus := status.FromError(err); !isStatus {
			err = status.Error(codes.InvalidArgument, err.Error())
		}
		return nil, err
	}
	return &serverAgentRPCAPI{
		RPCAPI: f.rpcAPIFactory(ctx, fullMethodName),
		TokenWithType: api.AgentTokenWithType{
			Token: api.AgentToken(token),
			Type:  agentType,
		},
		GitLabClient:   f.gitLabClient,
		AgentInfoCache: f.agentInfoCache,
	}, nil
}

// agentTypeFromMD extracts the agent type from the incoming gRPC metadata of the request.
//
// The grpctool.MetadataAgentType key is optional:
//   - absent: api.AgentTypeKubernetes is returned;
//   - present exactly once: the value is parsed with api.ParseAgentType;
//   - present more than once: api.AgentTypeUnknown is returned together with an error.
func agentTypeFromMD(ctx context.Context) (api.AgentType, error) {
	values := metadata.ValueFromIncomingContext(ctx, grpctool.MetadataAgentType)
	if count := len(values); count > 1 {
		return api.AgentTypeUnknown, fmt.Errorf("expecting a single %s, got %d", grpctool.MetadataAgentType, count)
	}
	if len(values) == 0 {
		// Missing key defaults to a Kubernetes agent (presumably for agents that never send it — confirm).
		return api.AgentTypeKubernetes, nil
	}
	return api.ParseAgentType(values[0])
}
