package server

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"time"
	"unicode/utf8"

	"buf.build/go/protovalidate"
	"github.com/coder/websocket"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/kubernetes_api"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/kubernetes_api/rpc"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/logz"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"
)

const (
	// graphAPISubprotocolName is the only WebSocket subprotocol offered when accepting a connection.
	// graphAPIReadLimit is the read limit, in bytes, set on every accepted connection.
	graphAPISubprotocolName = "gitlab-agent-graph-api"
	graphAPIReadLimit       = 4 * 1024

	// Constants below match our WebSocket library:

	// maxControlPayload is the maximum length of a control frame payload.
	// See https://tools.ietf.org/html/rfc6455#section-5.5.
	// maxCloseReason is the maximum length, in bytes, of the reason text in a close frame:
	// the first two bytes of the close frame payload are taken by the status code.
	maxControlPayload = 125
	maxCloseReason    = maxControlPayload - 2
)

// graphAPIHandler serves Graph API requests over WebSocket: it reads a single request from the client,
// opens a WatchGraph stream and relays the stream's responses to the client as JSON text messages.
//
// Fields:
//   - log is used for all handler logging.
//   - api is used to report processing errors (see processWatch).
//   - validator validates the request message received from the client.
//   - watchGraph opens the WatchGraph stream for the given agent and request.
type graphAPIHandler struct {
	log        *slog.Logger
	api        modserver.API
	validator  protovalidate.Validator
	watchGraph func(ctx context.Context, agentKey api.AgentKey, in *rpc.WatchGraphRequest) (grpc.ServerStreamingClient[rpc.WatchGraphResponse], error)
}

// Handle serves one Graph API WebSocket session for the agent identified by agentKey.
//
// The HTTP request is upgraded to a WebSocket, exactly one request message is read from the client,
// a WatchGraph stream is established using that request and impConfig, and the stream's responses are
// relayed to the client until the stream ends or a write fails. Failures are logged; nothing is returned.
func (h *graphAPIHandler) Handle(auxCtx context.Context, agentKey api.AgentKey, impConfig *kubernetes_api.ImpersonationConfig, w http.ResponseWriter, r *http.Request) {
	// Upgrade the HTTP request to a WebSocket connection.
	conn, err := h.acceptWebSocket(w, r) //nolint: contextcheck
	if err != nil {
		h.log.Debug("GraphAPI: WebSocket error", logz.Error(err))
		return
	}

	// The first (and only) client message carries the watch request.
	wsReq, err := h.readRequest(auxCtx, conn)
	if err != nil {
		h.log.Debug("GraphAPI: read request error", logz.Error(err))
		return
	}

	// The client must not send anything else: from here on any data message fails the connection.
	// The watch runs on a context derived from the read side and is canceled once Handle returns.
	watchCtx, stopWatch := context.WithCancel(conn.CloseRead(auxCtx))
	defer stopWatch()

	// Open the watch stream.
	stream, err := h.establishWatch(watchCtx, agentKey, impConfig, wsReq, conn)
	if err != nil {
		h.log.Error("GraphAPI: WatchGraph failed", logz.Error(err))
		return
	}

	// Relay watch responses to the client until the stream or the connection ends.
	h.processWatch(auxCtx, stream, conn)
}

// acceptWebSocket upgrades the HTTP request to a WebSocket connection that speaks the Graph API subprotocol.
//
// A connection on which no subprotocol was negotiated is closed with a protocol error and an error is
// returned. On success the returned connection has its read limit set to graphAPIReadLimit.
func (h *graphAPIHandler) acceptWebSocket(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
	opts := &websocket.AcceptOptions{
		Subprotocols: []string{graphAPISubprotocolName},
		// InsecureSkipVerify turns off the library's Origin check. That is intentional:
		// - (1) the proxy has already performed the Origin check.
		// - (2) the check in the websocket package is not implemented correctly:
		//       see https://github.com/coder/websocket/issues/529
		InsecureSkipVerify: true,
	}
	conn, err := websocket.Accept(w, r, opts)
	if err != nil {
		return nil, err // returned as is, wrapping adds nothing
	}
	if conn.Subprotocol() != "" {
		conn.SetReadLimit(graphAPIReadLimit)
		return conn, nil
	}
	// The client did not ask for our subprotocol: refuse to talk to it.
	noProtocolErr := fmt.Errorf("protocol unspecified, please use the %q protocol", graphAPISubprotocolName)
	h.closeWebSocket(conn, websocket.StatusProtocolError, noProtocolErr)
	return nil, noProtocolErr
}

// readRequest reads one message from the client and decodes it into a validated watch request.
//
// The message must be a text message holding the protobuf JSON form of the request. When the message has
// the wrong type, cannot be decoded or fails validation, the WebSocket is closed with a matching status
// code and the error is returned. An error from the read itself is returned unchanged.
func (h *graphAPIHandler) readRequest(auxCtx context.Context, c *websocket.Conn) (*rpc.WatchGraphWebSocketRequest, error) {
	// reject closes the connection with the given status and returns cause to the caller.
	reject := func(code websocket.StatusCode, cause error) (*rpc.WatchGraphWebSocketRequest, error) {
		h.closeWebSocket(c, code, cause) //nolint: contextcheck
		return nil, cause
	}

	kind, payload, err := c.Read(auxCtx)
	if err != nil {
		return nil, err
	}
	if kind != websocket.MessageText {
		return reject(websocket.StatusUnsupportedData, fmt.Errorf("expecting a WebSocket text message, got type %d", kind))
	}

	wsReq := new(rpc.WatchGraphWebSocketRequest)
	if err = protojson.Unmarshal(payload, wsReq); err != nil {
		return reject(websocket.StatusInvalidFramePayloadData, fmt.Errorf("proto unmarshal: %w", err))
	}
	if err = h.validator.Validate(wsReq); err != nil {
		return reject(websocket.StatusInvalidFramePayloadData, fmt.Errorf("validate: %w", err))
	}
	return wsReq, nil
}

// establishWatch opens a WatchGraph stream for agentKey, combining the impersonation config with the
// queries, namespaces and roots taken from the client's request.
//
// When the stream cannot be opened, the WebSocket is closed with an internal error status and the
// original error is returned.
func (h *graphAPIHandler) establishWatch(ctx context.Context, agentKey api.AgentKey, impConfig *kubernetes_api.ImpersonationConfig,
	req *rpc.WatchGraphWebSocketRequest, c *websocket.Conn) (grpc.ServerStreamingClient[rpc.WatchGraphResponse], error) {
	watchReq := &rpc.WatchGraphRequest{
		ImpConfig:  impConfig,
		Queries:    req.Queries,
		Namespaces: req.Namespaces,
		Roots:      req.Roots,
	}
	stream, err := h.watchGraph(ctx, agentKey, watchReq)
	if err == nil {
		return stream, nil
	}
	// TODO handle errors like canceled
	closeErr := fmt.Errorf("GraphAPI: WatchGraph failed: %w", err)
	h.closeWebSocket(c, websocket.StatusInternalError, closeErr) //nolint: contextcheck
	return nil, err
}

// processWatch relays responses from the watch stream to the client as JSON text messages.
//
// It returns when the stream ends or fails (the WebSocket is then closed with a status code and reason
// derived from the stream error), when a response cannot be converted to JSON (the error is reported and
// the WebSocket is closed with an internal error status), or when a write to the WebSocket fails.
func (h *graphAPIHandler) processWatch(auxCtx context.Context, watch grpc.ServerStreamingClient[rpc.WatchGraphResponse], c *websocket.Conn) {
	// closeArgs maps an error returned by Recv to the close status and reason reported to the client.
	closeArgs := func(recvErr error) (websocket.StatusCode, error) {
		switch {
		case recvErr == io.EOF: //nolint:errorlint
			return websocket.StatusGoingAway, errors.New("agentk closed connection")
		case grpctool.RequestCanceled(recvErr):
			// Work out which side is responsible for the cancellation.
			switch {
			case auxCtx.Err() != nil:
				return websocket.StatusGoingAway, fmt.Errorf("kas shutting down: %w", recvErr)
			case watch.Context().Err() != nil:
				return websocket.StatusGoingAway, fmt.Errorf("client closed connection: %w", recvErr)
			default:
				return websocket.StatusGoingAway, fmt.Errorf("agentk dropped connection: %w", recvErr)
			}
		case status.Code(recvErr) == codes.InvalidArgument:
			return websocket.StatusInvalidFramePayloadData, recvErr
		default:
			return websocket.StatusInternalError, recvErr
		}
	}

	for {
		msg, recvErr := watch.Recv()
		if recvErr != nil {
			code, reason := closeArgs(recvErr)
			h.log.Debug("GraphAPI: WatchGraph error", logz.Error(reason))
			h.closeWebSocket(c, code, reason) //nolint: contextcheck
			return
		}

		payload, marshalErr := prepareJSONResponse(msg)
		if marshalErr != nil {
			h.api.HandleProcessingError(auxCtx, h.log, "GraphAPI: JSON marshal", marshalErr)
			h.closeWebSocket(c, websocket.StatusInternalError, fmt.Errorf("GraphAPI: JSON marshal: %w", marshalErr)) //nolint: contextcheck
			return
		}

		if writeErr := c.Write(auxCtx, websocket.MessageText, payload); writeErr != nil {
			h.log.Debug("GraphAPI: WebSocket write failed", logz.Error(writeErr))
			return
		}
	}
}

// closeWebSocket closes c with the given status code, using the text of reasonErr as the close reason.
//
// A close frame can carry at most maxCloseReason bytes of reason text. When the reason is longer than
// that, the full text is first sent (best effort, with a 5 second timeout) as a JSON text message with
// the "error" field set, and the close frame then carries a shortened reason that ends with "...".
// A failure to close the connection is logged.
func (h *graphAPIHandler) closeWebSocket(c *websocket.Conn, status websocket.StatusCode, reasonErr error) {
	reason := reasonErr.Error()
	if isErrorReasonTooLong(reason) {
		data, err := json.Marshal(&jsonWatchGraphResponse{
			Error: &jsonError{
				Code:    status,
				CodeStr: status.String(),
				Reason:  reason,
			},
		})
		if err != nil {
			h.log.Error("json.Marshal", logz.Error(err))
		} else {
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			// Best effort: the connection is being closed anyway, so a failed write is not actionable.
			_ = c.Write(ctx, websocket.MessageText, data)
		}

		// Shorten the reason so that it fits into the close frame together with the ellipsis.
		// The cut is moved back to the start of a character: slicing at a fixed byte offset could split
		// a multi-byte UTF-8 sequence and put invalid UTF-8 into the close frame.
		const ellipsis = "..."
		cut := maxCloseReason - len(ellipsis)
		for cut > 0 && !utf8.RuneStart(reason[cut]) {
			cut--
		}
		reason = reason[:cut] + ellipsis
	}
	err := c.Close(status, reason)
	if err != nil {
		h.log.Error("WebSocket close error", logz.Error(err))
	}
}

// isErrorReasonTooLong reports whether reason, measured in bytes, exceeds maxCloseReason
// and therefore does not fit into a WebSocket close frame.
func isErrorReasonTooLong(reason string) bool {
	reasonLen := len(reason)
	return reasonLen > maxCloseReason
}

// prepareJSONResponse converts a watch response into the JSON document sent to the client.
// It fails when the response contains an action that cannot be converted or when marshaling fails.
func prepareJSONResponse(watchMsg *rpc.WatchGraphResponse) ([]byte, error) {
	var (
		resp jsonWatchGraphResponse
		err  error
	)
	if resp.Actions, err = prepareJSONResponseActions(watchMsg); err != nil {
		return nil, err
	}
	resp.Warnings = prepareJSONResponseWarnings(watchMsg)
	return json.Marshal(&resp)
}

// prepareJSONResponseActions converts the actions of a watch response into their JSON representation,
// preserving their order. Each resulting element has exactly one field set, matching the kind of the
// source action.
//
// An error is returned when an action holds a variant this function does not know about (or none at all).
func prepareJSONResponseActions(watchMsg *rpc.WatchGraphResponse) ([]jsonWatchGraphAction, error) {
	actions := make([]jsonWatchGraphAction, 0, len(watchMsg.Actions))
	for _, action := range watchMsg.Actions {
		var act jsonWatchGraphAction
		switch a := action.Action.(type) {
		case *rpc.Action_SetVertex_:
			act.SetVertex = &jsonSetVertex{
				Vertex:      rpcVertexToJSONVertex(a.SetVertex.Vertex),
				Object:      a.SetVertex.Object,
				JSONPath:    a.SetVertex.JsonPath,
				HelmRelease: a.SetVertex.HelmRelease,
			}
		case *rpc.Action_DeleteVertex_:
			act.DeleteVertex = &jsonProtoMsg{msg: a.DeleteVertex}
		case *rpc.Action_SetArc_:
			act.SetArc = &jsonSetArc{
				Source:      rpcVertexToJSONVertex(a.SetArc.Source),
				Destination: rpcVertexToJSONVertex(a.SetArc.Destination),
				Type:        a.SetArc.Type,
				Attributes:  a.SetArc.Attributes,
			}
		case *rpc.Action_DeleteArc_:
			act.DeleteArc = &jsonProtoMsg{msg: a.DeleteArc}
		default:
			// Report the type of the variant that was switched on. The enclosing action always has
			// the same type, so printing it would not tell which variant was unexpected.
			return nil, fmt.Errorf("unknown action type: %T", a)
		}
		actions = append(actions, act)
	}
	return actions, nil
}

// prepareJSONResponseWarnings converts the warnings of a watch response into their JSON representation,
// preserving their order. The result is never nil.
func prepareJSONResponseWarnings(watchMsg *rpc.WatchGraphResponse) []jsonWarning {
	src := watchMsg.Warnings
	out := make([]jsonWarning, len(src))
	for i := range src {
		out[i].Type = src[i].Type
		out[i].Message = src[i].Message
		out[i].Attributes = src[i].Attributes
	}
	return out
}

// Compile-time check that jsonProtoMsg implements json.Marshaler.
var (
	_ json.Marshaler = jsonProtoMsg{}
)

// jsonWatchGraphResponse is the top-level JSON document sent to the client in a text message.
// Watch responses populate "actions" and "warnings"; "error" is populated only for the message that
// carries a close reason too long to fit into the close frame. Empty fields are omitted.
type jsonWatchGraphResponse struct {
	Actions  []jsonWatchGraphAction `json:"actions,omitempty"`
	Warnings []jsonWarning          `json:"warnings,omitempty"`
	Error    *jsonError             `json:"error,omitempty"`
}

// jsonWatchGraphAction is the JSON form of a single graph action. prepareJSONResponseActions sets
// exactly one of the fields; unset fields are omitted from the output. The keys are "svx" (set vertex),
// "dvx" (delete vertex), "sarc" (set arc) and "darc" (delete arc).
type jsonWatchGraphAction struct {
	SetVertex    *jsonSetVertex `json:"svx,omitempty"`
	DeleteVertex *jsonProtoMsg  `json:"dvx,omitempty"`
	SetArc       *jsonSetArc    `json:"sarc,omitempty"`
	DeleteArc    *jsonProtoMsg  `json:"darc,omitempty"`
}

// jsonProtoMsg wraps a protobuf message so that encoding/json serializes it using protojson
// (see MarshalJSON) rather than its own reflection-based encoding.
type jsonProtoMsg struct {
	msg proto.Message
}

// MarshalJSON implements json.Marshaler by encoding the wrapped message with protojson.
func (j jsonProtoMsg) MarshalJSON() ([]byte, error) {
	encoded, err := protojson.Marshal(j.msg)
	return encoded, err
}

// rpcVertexToJSONVertex copies the identifying fields of an RPC vertex into its JSON representation.
// v must not be nil.
func rpcVertexToJSONVertex(v *rpc.Action_Vertex) jsonVertex {
	var vx jsonVertex
	vx.Group = v.Group
	vx.Version = v.Version
	vx.Resource = v.Resource
	vx.Namespace = v.Namespace
	vx.Name = v.Name
	return vx
}

// jsonVertex is the JSON form of a graph vertex, using short keys: "g" (group), "v" (version),
// "r" (resource), "ns" (namespace) and "n" (name). Group and namespace are omitted when empty.
type jsonVertex struct {
	Group     string `json:"g,omitempty"`
	Version   string `json:"v"`
	Resource  string `json:"r"`
	Namespace string `json:"ns,omitempty"`
	Name      string `json:"n"`
}

// jsonSetVertex is the JSON form of a "set vertex" action: the vertex ("vx") plus optional payloads
// "o" (object), "j" (JSON path) and "hr" (Helm release). The payloads are raw JSON that is embedded
// as is and omitted when empty.
type jsonSetVertex struct {
	Vertex      jsonVertex      `json:"vx"`
	Object      json.RawMessage `json:"o,omitempty"`
	JSONPath    json.RawMessage `json:"j,omitempty"`
	HelmRelease json.RawMessage `json:"hr,omitempty"`
}

// jsonSetArc is the JSON form of a "set arc" action: source vertex ("s"), destination vertex ("d"),
// arc type ("t") and optional attributes ("a"). Attributes are raw JSON that is embedded as is and
// omitted when empty.
type jsonSetArc struct {
	Source      jsonVertex      `json:"s"`
	Destination jsonVertex      `json:"d"`
	Type        string          `json:"t"`
	Attributes  json.RawMessage `json:"a,omitempty"`
}

// jsonWarning is the JSON form of a warning attached to a watch response: type ("t"), message ("m")
// and optional attributes ("a"). Attributes are raw JSON that is embedded as is and omitted when empty.
type jsonWarning struct {
	Type       string          `json:"t"`
	Message    string          `json:"m"`
	Attributes json.RawMessage `json:"a,omitempty"`
}

// jsonError is the JSON form of an error reported to the client before the WebSocket is closed:
// the close status code ("code"), its string form ("code_string") and the full, untruncated
// reason ("reason").
type jsonError struct {
	Code    websocket.StatusCode `json:"code"`
	CodeStr string               `json:"code_string"`
	Reason  string               `json:"reason"`
}
