package router

import (
	"io"
	"log/slog"

	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/logz"
	tunrpc "gitlab.com/gitlab-org/cluster-integration/tunnel/rpc"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/grpcz"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tunserver"
	"google.golang.org/grpc"
	"google.golang.org/grpc/mem"
)

// SimpleRPCForwarder relays a single gRPC method between an incoming stream and a
// tunnel stream by moving tunrpc.RawFrame messages in both directions.
type SimpleRPCForwarder struct {
	// service and method are split out of the tunnel stream's full gRPC method
	// name when the forwarder is constructed; see SupportsServiceAndMethod.
	service string
	method  string

	// tunnel is the stream that frames are relayed through.
	tunnel grpc.ServerStream

	// tunnelRetErr receives the tunnel-side error once ForwardStream finishes.
	// NOTE(review): presumably consumed by the tunnel's handler as its return value — confirm.
	tunnelRetErr chan<- error
}

// NewSimpleRPCForwarder returns a forwarder bound to tunnel. The service and method
// it serves are derived from the full method name recorded in the tunnel stream's
// server transport stream. ForwardStream later sends the tunnel-side error on
// tunnelRetErr.
//
// NOTE(review): grpc.ServerTransportStreamFromContext returns nil for a context
// that carries no transport stream, which would make this panic. tunnel is assumed
// to be a stream handed to a gRPC server handler.
func NewSimpleRPCForwarder(tunnel grpc.ServerStream, tunnelRetErr chan<- error) *SimpleRPCForwarder {
	fullMethod := grpc.ServerTransportStreamFromContext(tunnel.Context()).Method()

	fwd := &SimpleRPCForwarder{
		tunnel:       tunnel,
		tunnelRetErr: tunnelRetErr,
	}
	fwd.service, fwd.method = grpcz.SplitGRPCMethod(fullMethod)
	return fwd
}

// SupportsServiceAndMethod reports whether service and method exactly match the
// RPC this forwarder was constructed for.
func (f *SimpleRPCForwarder) SupportsServiceAndMethod(service, method string) bool {
	if service != f.service {
		return false
	}
	return method == f.method
}

// ForwardStream relays messages between incomingStream and the tunnel. Data read
// from the tunnel is delivered through cb. Forwarding stops when one direction
// fails, both directions complete, or the incoming stream's context is done after
// one direction completed cleanly.
//
// When forwarding stops, the tunnel-side error (possibly nil) is sent on the
// tunnelRetErr channel given to NewSimpleRPCForwarder. That send blocks until the
// channel accepts the value. The incoming-side error is returned.
func (f *SimpleRPCForwarder) ForwardStream(log *slog.Logger, incomingStream grpc.ServerStream, cb tunserver.DataCallback) error {
	outcome := f.forwardStream(log, incomingStream, cb)
	tunnelErr, incomingErr := outcome.forTunnel, outcome.forIncomingStream

	f.tunnelRetErr <- tunnelErr
	return incomingErr
}

// forwardStream runs both forwarding directions concurrently and reports how
// forwarding ended:
//   - incoming -> tunnel via pipeIncomingToTunnel;
//   - tunnel -> cb via pipeTunnelToCallback.
//
// If the first direction to finish reports any error, that result is returned at
// once. If it finished cleanly, forwardStream returns whichever comes first: the
// other direction's result, or a status error derived from the incoming stream's
// context once that context is done. In the latter case the same error goes to
// both sides.
func (f *SimpleRPCForwarder) forwardStream(log *slog.Logger, incomingStream grpc.ServerStream, cb tunserver.DataCallback) errPair {
	// Capacity 1 is enough for both senders. forwardStream always receives at
	// least once, so whichever goroutine finishes second can always deposit its
	// result and exit, even if nobody reads it afterwards.
	results := make(chan errPair, 1)
	ctx := incomingStream.Context()

	goErrPair(results, func() (error, error) {
		return f.pipeIncomingToTunnel(log, incomingStream)
	})
	goErrPair(results, func() (error, error) {
		return f.pipeTunnelToCallback(log, cb)
	})

	if first := <-results; !first.isNil() {
		return first
	}

	select {
	case second := <-results:
		return second
	case <-ctx.Done():
		// The remaining direction may be blocked in a RecvMsg call. Stop waiting
		// for it once the incoming stream's context is done. Returning lets
		// ForwardStream pass the error on via tunnelRetErr.
		// NOTE(review): that hand-off is presumably what ends a pending
		// f.tunnel.RecvMsg() — confirm against the tunnel handler.
		err := grpctool.StatusErrorFromContext(ctx, "Incoming stream closed")
		return errPair{
			forTunnel:         err,
			forIncomingStream: err,
		}
	}
}

// pipeIncomingToTunnel copies frames from incomingStream to f.tunnel, one at a time.
//
// It returns (nil, nil) when incomingStream.RecvMsg reports io.EOF. It does not send
// anything extra on the tunnel in that case. If receiving or sending fails, the
// failure is logged at debug level and returned as both the tunnel-side and the
// incoming-side error. Each frame's Data is released immediately after the send
// attempt.
func (f *SimpleRPCForwarder) pipeIncomingToTunnel(log *slog.Logger, incomingStream grpc.ServerStream) (error, error) {
	for {
		var msg tunrpc.RawFrame

		switch recvErr := incomingStream.RecvMsg(&msg); {
		case recvErr == io.EOF: //nolint:errorlint
			return nil, nil
		case recvErr != nil:
			log.Debug("RecvMsg(SimpleRPCForwarder): Failed to receive message from incoming stream", logz.Error(recvErr))
			return recvErr, recvErr
		}

		sendErr := f.tunnel.SendMsg(&msg)
		msg.Data.Free() // Release the frame regardless of whether the send succeeded.
		if sendErr != nil {
			log.Debug("SendMsg(SimpleRPCForwarder): Failed to send message to tunnel", logz.Error(sendErr))
			return sendErr, sendErr
		}
	}
}

// pipeTunnelToCallback receives frames from f.tunnel and hands each frame's bytes
// to cb.Message, one frame at a time.
//
// It returns (nil, nil) when f.tunnel.RecvMsg reports io.EOF. If receiving or
// cb.Message fails, the failure is logged at debug level and returned as both the
// tunnel-side and the incoming-side error. Each frame is flattened into a single
// buffer from mem.DefaultBufferPool(). Both the frame and that buffer are freed as
// soon as they are no longer needed.
func (f *SimpleRPCForwarder) pipeTunnelToCallback(log *slog.Logger, cb tunserver.DataCallback) (error, error) {
	for {
		msg := tunrpc.RawFrame{}
		if recvErr := f.tunnel.RecvMsg(&msg); recvErr != nil {
			if recvErr != io.EOF { //nolint:errorlint
				log.Debug("RecvMsg(SimpleRPCForwarder): Failed to receive message from tunnel", logz.Error(recvErr))
				return recvErr, recvErr
			}
			return nil, nil
		}

		// Flatten the frame into one contiguous buffer, then drop the frame itself.
		flat := msg.Data.MaterializeToBuffer(mem.DefaultBufferPool())
		msg.Data.Free()

		cbErr := cb.Message(flat.ReadOnlyData())
		flat.Free() // The bytes are no longer needed once cb.Message has returned.
		if cbErr != nil {
			log.Debug("Message(SimpleRPCForwarder): Failed to send message via callback", logz.Error(cbErr))
			return cbErr, cbErr
		}
	}
}

// errPair holds an outcome split by recipient: the error meant for the tunnel side
// and the error meant for the incoming stream side. Either or both may be nil.
type errPair struct {
	forTunnel         error // error for the tunnel side
	forIncomingStream error // error for the incoming stream side
}

// isNil reports whether neither side carries an error.
func (p errPair) isNil() bool {
	hasErr := p.forTunnel != nil || p.forIncomingStream != nil
	return !hasErr
}

// goErrPair runs f on a new goroutine and sends its two results, packed into an
// errPair, on c. The send blocks, so the goroutine exits only after c accepts the
// value. Callers must buffer c or keep receiving from it.
func goErrPair(c chan<- errPair, f func() (error /* forTunnel */, error /* forIncomingStream */)) {
	go func() {
		tunnelErr, incomingErr := f()
		c <- errPair{
			forTunnel:         tunnelErr,
			forIncomingStream: incomingErr,
		}
	}()
}
