package agent

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"reflect"
	"runtime"
	"strings"

	"github.com/google/cel-go/cel"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/kubernetes_api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/kubernetes_api/agent/watch_graph"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/kubernetes_api/rpc"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modagent"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/fieldz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/logz"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
	"k8s.io/apimachinery/pkg/runtime/schema"
	"k8s.io/apimachinery/pkg/util/sets"
	"k8s.io/apimachinery/pkg/util/wait"
	"k8s.io/client-go/discovery"
	"k8s.io/client-go/dynamic"
	"k8s.io/client-go/rest"
	"k8s.io/client-go/util/jsonpath"
)

// maxBatchSize is both the buffer size of the action/warning channels and the maximum number of
// actions+warnings packed into a single WatchGraphResponse.
const maxBatchSize = 128

// WatchGraph implements the WatchGraph streaming RPC. All the work is done by WatchGraphWithRoots.
func (s *server) WatchGraph(req *rpc.WatchGraphRequest, srv grpc.ServerStreamingServer[rpc.WatchGraphResponse]) error {
	err := s.WatchGraphWithRoots(req, srv)
	return err
}

// WatchGraphWithRoots serves a WatchGraph request. It builds the watch graph described by req, runs it in a
// background goroutine, and streams the actions, warnings and terminal error the graph produces to the client.
//
// The graph runs under a context derived from the stream context. That context is canceled when this method
// returns, so the goroutine is always stopped and waited for, even when the exit was caused by srv.Send()
// failing rather than by the stream context being canceled.
func (s *server) WatchGraphWithRoots(req *rpc.WatchGraphRequest, srv grpc.ServerStreamingServer[rpc.WatchGraphResponse]) error {
	ctx, cancel := context.WithCancel(srv.Context())
	defer cancel() // covers the early-return path below
	actionCh := make(chan *rpc.Action, maxBatchSize)
	warnCh := make(chan *rpc.Warning, maxBatchSize)
	errCh := make(chan *watch_graph.Error)
	rpcAPI := modshared.RPCAPIFromContext[modagent.RPCAPI](ctx)
	log := rpcAPI.Log()
	// Producers stop trying to deliver once ctx is done, so they never block forever on a full channel.
	done := ctx.Done()

	g, err := s.watchGraphConstruct(log, rpcAPI, req,
		func(action *rpc.Action) {
			select {
			case <-done:
			case actionCh <- action:
			}
		},
		func(warn *rpc.Warning) {
			select {
			case <-done:
			case warnCh <- warn:
			}
		},
	)
	if err != nil {
		log.Error("Watch graph call failed", logz.Error(err))
		return err
	}

	var wg wait.Group
	defer func() {
		// Stop the graph before waiting for it. Otherwise Wait() could block for as long as the
		// stream context stays alive, e.g. after srv.Send() failed.
		cancel()
		wg.Wait()
	}()
	wg.Start(func() {
		runErr := g.Run(ctx)
		if runErr != nil {
			select {
			case <-done:
			case errCh <- runErr:
			}
		}
	})

	return s.watchGraphStreamResponse(ctx, actionCh, warnCh, errCh, srv)
}

// watchGraphConstruct validates req and builds the watch graph that serves it.
//
// Queries, namespace selectors and the REST config (including impersonation) are resolved first. Then the
// dynamic and discovery clients are created, and finally the observer chain is assembled:
//   - an action-emitting observer forwards graph changes to onAction;
//   - if req.Roots is set, a roots observer wraps it and is responsible for the roots selection.
//
// Warnings produced by the graph are reported via rpcAPI.HandleProcessingError and forwarded to onWarn.
//
// It returns gRPC-status-compatible error.
func (s *server) watchGraphConstruct(log *slog.Logger, rpcAPI modagent.RPCAPI, req *rpc.WatchGraphRequest,
	onAction func(*rpc.Action), onWarn func(*rpc.Warning)) (*watch_graph.WatchGraph, error) {

	queries, err := s.watchGraphConstructQueries(req.Queries)
	if err != nil {
		return nil, err
	}
	ns, err := s.watchGraphConstructNamespaces(req.Namespaces)
	if err != nil {
		return nil, err
	}
	restConfig, err := s.watchGraphConstructRESTConfig(req.ImpConfig)
	if err != nil {
		return nil, err
	}
	client, err := dynamic.NewForConfig(restConfig)
	if err != nil {
		// The message names the function that actually failed.
		return nil, status.Errorf(codes.Unavailable, "NewForConfig(): %v", err)
	}
	discoClient, err := discovery.NewDiscoveryClientForConfig(restConfig)
	if err != nil {
		return nil, status.Errorf(codes.Unavailable, "NewDiscoveryClientForConfig(): %v", err)
	}

	// toJSON marshals v. On failure it reports the error both to the processing error handler and to the
	// client (as an internal-error warning), and returns false.
	toJSON := func(v any) ([]byte, bool) {
		data, errMarshal := json.Marshal(v)
		if errMarshal != nil {
			rpcAPI.HandleProcessingError(log, "JSON marshal error", errMarshal)
			onWarn(&rpc.Warning{
				Type:    watch_graph.InternalErrorWarnType,
				Message: fmt.Sprintf("JSON marshal error: %v", errMarshal),
			})
			return nil, false
		}
		return data, true
	}
	onWarning := func(w watch_graph.Warning) {
		var attrs []byte
		if !w.Attributes.IsZero() {
			attrs, _ = toJSON(w.Attributes) // continue on error anyway; toJSON has already reported it
		}
		rpcAPI.HandleProcessingError(log, "Graph API warning", errors.New(w.Message), attributes(attrs))
		onWarn(&rpc.Warning{
			Type:       w.Type,
			Message:    w.Message,
			Attributes: attrs,
		})
	}
	var o watch_graph.ObjectGraphObserver[watch_graph.VertexID, watch_graph.ArcType, watch_graph.VertexData, watch_graph.ArcAttrs]
	o = &watchGraphActionEmittingObserver{
		log:       log,
		onAction:  onAction,
		onWarning: onWarning,
	}
	if req.Roots != nil {
		// Several root selector expressions are OR-ed together into a single boolean program.
		isRootCEL, rootsErr := watchGraphPrepareBoolCELProgram(s.objEnv, joinBoolCELExprs(req.Roots.ObjectSelectorExpressions))
		if rootsErr != nil {
			return nil, status.Errorf(codes.InvalidArgument, "roots expression: %v", rootsErr)
		}
		o = &watchGraphRootsObserver{
			log:                log,
			delegate:           o,
			isRootCEL:          isRootCEL,
			objectToEvalVars:   objectToEvalVars,
			onWarning:          onWarning,
			roots:              make(sets.Set[watch_graph.VertexID]),
			reachable:          make(sets.Set[watch_graph.VertexID]),
			ignoreArcDirection: arcTypeStringSliceToTypedSlice(req.Roots.IgnoreArcDirection),
		}
	}
	graph := watch_graph.New(watch_graph.Opts{
		Log:              log,
		Queries:          queries,
		Namespaces:       ns,
		DiscoClient:      discoClient,
		Client:           client,
		OnWarning:        onWarning,
		ObjectToEvalVars: objectToEvalVars,
		Graph: watch_graph.NewObjectGraph(watch_graph.ObjectGraphOpts[watch_graph.VertexID, watch_graph.ArcType, watch_graph.VertexData, watch_graph.ArcAttrs]{
			// Vertex data is considered unchanged when the underlying object content is deeply equal.
			IsVertexDataEqual: func(a, b watch_graph.VertexData) bool {
				return reflect.DeepEqual(a.Object, b.Object)
			},
			IsArcDataEqual: func(a, b watch_graph.ArcAttrs) bool {
				return a == b
			},
			Observer: o,
		}),
	})
	return graph, nil
}

// watchGraphConstructRESTConfig selects the REST config used for the request's Kubernetes clients.
//
// Impersonation may come either from the agent's own REST config or from impConfig, but not from both:
//   - neither is set, or only the REST config impersonates: the shared config is returned unchanged;
//   - only impConfig is set: a copy of the shared config with impConfig applied is returned;
//   - both are set: nested impersonation is rejected with codes.FailedPrecondition.
//
// It returns gRPC-status-compatible error.
func (s *server) watchGraphConstructRESTConfig(impConfig *kubernetes_api.ImpersonationConfig) (*rest.Config, error) {
	shared := s.restConfig
	sharedImpersonates := !isEmptyImpersonationConfig(shared.Impersonate)
	requestImpersonates := !impConfig.IsEmpty()

	if sharedImpersonates && requestImpersonates {
		// Nested impersonation support https://gitlab.com/gitlab-org/gitlab/-/issues/338664
		return nil, status.Error(codes.FailedPrecondition, "Nested impersonation is not supported - agent is already configured to impersonate an identity")
	}
	if !requestImpersonates {
		// Either no impersonation at all, or it is already configured in the REST config.
		return shared, nil
	}

	// Impersonation comes from the agent config. Work on a copy so the shared config object is never mutated.
	cfg := rest.CopyConfig(shared)
	cfg.Impersonate.UserName = impConfig.Username
	cfg.Impersonate.UID = impConfig.Uid
	cfg.Impersonate.Groups = impConfig.Groups
	cfg.Impersonate.Extra = impConfig.GetExtraAsMap()
	return cfg, nil
}

// watchGraphConstructQueries converts every RPC query into its watch_graph form, preserving order.
// An error is annotated with the index of the offending query ("queries[i]").
//
// It returns gRPC-status-compatible error.
func (s *server) watchGraphConstructQueries(queries []*rpc.Query) ([]any, error) {
	converted := make([]any, len(queries))
	for i := range queries {
		q, err := s.watchGraphConstructQuery(queries[i])
		if err != nil {
			return nil, grpctool.AugmentErrorMessage(codes.InvalidArgument, fmt.Sprintf("queries[%d]", i), err)
		}
		converted[i] = q
	}
	return converted, nil
}

// watchGraphConstructQuery converts a single RPC query into watch_graph.QueryInclude or watch_graph.QueryExclude.
// All CEL expressions and the optional JSON path are compiled here, so they fail early with codes.InvalidArgument.
func (s *server) watchGraphConstructQuery(query *rpc.Query) (any, error) {
	switch q := query.Query.(type) {
	case *rpc.Query_Exclude_:
		resSel, err := watchGraphPrepareBoolCELProgram(s.resEnv, q.Exclude.ResourceSelectorExpression)
		if err != nil {
			return nil, status.Errorf(codes.InvalidArgument, "resource selector expression: %v", err)
		}
		return watch_graph.QueryExclude{ResourceSelectorExpression: resSel}, nil
	case *rpc.Query_Include_:
		resSel, err := watchGraphPrepareBoolCELProgram(s.resEnv, q.Include.ResourceSelectorExpression)
		if err != nil {
			return nil, status.Errorf(codes.InvalidArgument, "resource selector expression: %v", err)
		}
		// The object spec is optional. GetObject() and the getters below are nil-safe and yield zero values.
		obj := q.Include.GetObject()
		objSel, err := watchGraphMaybePrepareBoolCELProgram(s.objEnv, obj.GetObjectSelectorExpression())
		if err != nil {
			return nil, status.Errorf(codes.InvalidArgument, "object selector expression: %v", err)
		}
		var jp *jsonpath.JSONPath // stays nil when no JSON path is requested
		if path := obj.GetJsonPath(); path != "" {
			jp = jsonpath.New("")
			// The template parser expects the expression wrapped in braces.
			if err = jp.Parse("{" + path + "}"); err != nil {
				return nil, status.Errorf(codes.InvalidArgument, "JSON path: %v", err)
			}
		}
		return watch_graph.QueryInclude{
			ResourceSelectorExpression: resSel,
			Object: watch_graph.QueryIncludeObject{
				LabelSelector:            obj.GetLabelSelector(),
				FieldSelector:            obj.GetFieldSelector(),
				ObjectSelectorExpression: objSel,
				JSONPath:                 jp,
			},
		}, nil
	default:
		// Should never happen
		return nil, status.Errorf(codes.Internal, "unexpected query type: %T", query.Query)
	}
}

// watchGraphConstructNamespaces converts the optional RPC namespaces filter into watch_graph.Namespaces.
// A nil filter yields the zero value. An empty object selector expression yields a nil program.
//
// It returns gRPC-status-compatible error.
func (s *server) watchGraphConstructNamespaces(rns *rpc.Namespaces) (watch_graph.Namespaces, error) {
	var res watch_graph.Namespaces
	if rns == nil {
		return res, nil
	}
	selector, err := watchGraphMaybePrepareBoolCELProgram(s.objEnv, rns.ObjectSelectorExpression)
	if err != nil {
		return res, status.Errorf(codes.InvalidArgument, "namespaces object selector expression: %v", err)
	}
	res.Names = sets.New(rns.Names...)
	res.LabelSelector = rns.LabelSelector
	res.FieldSelector = rns.FieldSelector
	res.ObjectSelectorExpression = selector
	return res, nil
}

// watchGraphStreamResponse forwards actions and warnings to the client in batches of up to maxBatchSize items
// per WatchGraphResponse.
//
// For every batch it blocks until at least one item (or a terminal event) is available. It then drains whatever
// else is immediately available, yielding the processor to give producers a chance to add more, before sending.
//
// It returns nil when ctx is done, the gRPC status form of a graph error received on errCh, or the srv.Send() error.
func (s *server) watchGraphStreamResponse(ctx context.Context, actionCh <-chan *rpc.Action, warnCh <-chan *rpc.Warning,
	errCh <-chan *watch_graph.Error, srv grpc.ServerStreamingServer[rpc.WatchGraphResponse]) error {
	done := ctx.Done()
	for {
		// A fresh message per batch: gRPC does not allow modifying a message after it has been handed to
		// Send(), because tracing libraries and stats handlers may use it lazily.
		resp := &rpc.WatchGraphResponse{}
		select { // get the first thing to send
		case <-done:
			return nil
		case err := <-errCh:
			return graphError2grpcError(err)
		case action := <-actionCh:
			resp.Actions = append(resp.Actions, action)
		case warn := <-warnCh:
			resp.Warnings = append(resp.Warnings, warn)
		}
		shouldRunGosched := true
		// Try to get more things to send - batch send.
	batch:
		for len(resp.Actions)+len(resp.Warnings) < maxBatchSize {
			select {
			case <-done:
				// The client doesn't care about the buffered actions/warnings, they can be safely discarded.
				return nil
			case err := <-errCh:
				return graphError2grpcError(err)
			case action := <-actionCh:
				resp.Actions = append(resp.Actions, action)
				shouldRunGosched = true
			case warn := <-warnCh:
				resp.Warnings = append(resp.Warnings, warn)
				shouldRunGosched = true
			default:
				if shouldRunGosched {
					shouldRunGosched = false
					runtime.Gosched() // let the producer put more things into the channels to increase batch size
					continue
				}
				break batch
			}
		}
		err := srv.Send(resp)
		if err != nil {
			return err
		}
	}
}

// graphError2grpcError maps a watch graph error onto a gRPC status error with the same message.
// Graph error codes without an explicit mapping become codes.Unknown.
func graphError2grpcError(err *watch_graph.Error) error {
	code := codes.Unknown
	switch err.Code {
	case watch_graph.InvalidArgument:
		code = codes.InvalidArgument
	case watch_graph.Unavailable:
		code = codes.Unavailable
	case watch_graph.InternalError:
		code = codes.Internal
	}
	return status.Error(code, err.Message)
}

// watchGraphMaybePrepareBoolCELProgram compiles celExpr as a boolean CEL program.
// An empty expression means "not configured" and yields (nil, nil).
func watchGraphMaybePrepareBoolCELProgram(env *cel.Env, celExpr string) (cel.Program, error) {
	if celExpr != "" {
		return watchGraphPrepareBoolCELProgram(env, celExpr)
	}
	return nil, nil
}

// watchGraphPrepareBoolCELProgram compiles celExpr in env and requires its result type to be bool.
func watchGraphPrepareBoolCELProgram(env *cel.Env, celExpr string) (cel.Program, error) {
	prog, err := cel2program(env, celExpr, cel.BoolType)
	return prog, err
}

// cel2program parses and type-checks celExpr in env, verifies that its result type is expectedOutput,
// and builds an evaluable program. The program checks for interruption every 100 comprehension
// iterations and is built with cel.OptOptimize.
//
// NOTE(review): the result type is compared by pointer identity (!=). This relies on the checker returning
// the same *cel.Type instance as expectedOutput (e.g. cel.BoolType). Consider Type.IsExactType if the
// cel-go version in use provides it.
func cel2program(env *cel.Env, celExpr string, expectedOutput *cel.Type) (cel.Program, error) {
	parsed, issues := env.Parse(celExpr)
	if err := issues.Err(); err != nil {
		return nil, err
	}
	checked, issues := env.Check(parsed)
	if err := issues.Err(); err != nil {
		return nil, err
	}
	if got := checked.OutputType(); got != expectedOutput {
		return nil, fmt.Errorf("expected %q got %s", expectedOutput, got)
	}
	return env.Program(checked, cel.InterruptCheckFrequency(100), cel.EvalOptions(cel.OptOptimize))
}

// joinBoolCELExprs combines boolean CEL expressions with logical OR.
// No expressions yield "", and a single expression is returned verbatim. With two or more expressions,
// each one is parenthesized so operators inside one expression cannot bind to its neighbours:
// ["a", "b"] becomes "(a)||(b)".
func joinBoolCELExprs(exprs []string) string {
	if len(exprs) == 0 {
		return ""
	}
	if len(exprs) == 1 {
		return exprs[0]
	}
	return "(" + strings.Join(exprs, ")||(") + ")"
}

// attributes wraps JSON-encoded warning attributes into a log field keyed by fieldz.AttributesFieldName.
// Empty or nil input yields the zero fieldz.Field.
func attributes(attrs []byte) fieldz.Field {
	if len(attrs) > 0 {
		return fieldz.Field{Key: fieldz.AttributesFieldName, Value: string(attrs)}
	}
	return fieldz.Field{}
}

// objectToEvalVars builds the variable bindings used to evaluate object CEL expressions for obj, which
// belongs to resource gvr. It is handed to the graph and to the roots observer as their ObjectToEvalVars.
// Variables: obj (raw object content), group, version, resource, namespace, name, labels, annotations.
func objectToEvalVars(obj *unstructured.Unstructured, gvr schema.GroupVersionResource) map[string]any {
	vars := make(map[string]any, 8)
	vars["obj"] = obj.Object
	vars["group"] = gvr.Group
	vars["version"] = gvr.Version
	vars["resource"] = gvr.Resource
	vars["namespace"] = obj.GetNamespace()
	vars["name"] = obj.GetName()
	vars["labels"] = obj.GetLabels()
	vars["annotations"] = obj.GetAnnotations()
	return vars
}

// arcTypeStringSliceToTypedSlice parses each arc type name with watch_graph.ParseArcTypeStr and
// collects the results into a watch_graph.ArcTypeSet.
func arcTypeStringSliceToTypedSlice(arcType []string) watch_graph.ArcTypeSet {
	var set watch_graph.ArcTypeSet
	for i := range arcType {
		set.Add(watch_graph.ParseArcTypeStr(arcType[i]))
	}
	return set
}

// vertexIDValuer makes watch_graph.VertexID an slog.LogValuer, so its string form is built lazily.
type vertexIDValuer watch_graph.VertexID

// LogValue renders the vertex ID as "<group>/<version>/<resource> ns=<namespace> n=<name>".
func (vid vertexIDValuer) LogValue() slog.Value {
	res := vid.GVR.Value()
	rendered := fmt.Sprintf("%s/%s/%s ns=%s n=%s", res.Group, res.Version, res.Resource, vid.Namespace, vid.Name)
	return slog.StringValue(rendered)
}

// vidAttr returns a lazily rendered log attribute for a vertex ID under the logz.VertexID key.
func vidAttr(vid watch_graph.VertexID) slog.Attr {
	valuer := vertexIDValuer(vid)
	return slog.Any(logz.VertexID, valuer)
}

// fromAttr returns a lazily rendered log attribute for an arc's source vertex under the logz.FromVertexID key.
func fromAttr(from watch_graph.VertexID) slog.Attr {
	valuer := vertexIDValuer(from)
	return slog.Any(logz.FromVertexID, valuer)
}

// toAttr returns a lazily rendered log attribute for an arc's destination vertex under the logz.ToVertexID key.
func toAttr(to watch_graph.VertexID) slog.Attr {
	valuer := vertexIDValuer(to)
	return slog.Any(logz.ToVertexID, valuer)
}

// arcAttr describes an arc for logging: its source vertex, destination vertex and arc type.
// The attributes are grouped under an empty key, which slog handlers inline into the parent record.
func arcAttr(from watch_graph.VertexID, to watch_graph.ArcToID[watch_graph.VertexID, watch_graph.ArcType]) slog.Attr {
	src := fromAttr(from)
	dst := toAttr(to.To)
	kind := slog.String(logz.ArcType, to.ArcType.String())
	return slog.Group("", src, dst, kind)
}

// vidRootAttr logs whether a vertex is a root, under the logz.VidRoot key.
func vidRootAttr(root bool) slog.Attr {
	attr := slog.Bool(logz.VidRoot, root)
	return attr
}

// vidWasRootAttr logs whether a vertex was previously a root, under the logz.VidWasRoot key.
func vidWasRootAttr(wasRoot bool) slog.Attr {
	attr := slog.Bool(logz.VidWasRoot, wasRoot)
	return attr
}
