package grpctool

import (
	"bufio"
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"

	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/logz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/memz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/prototool"
	"gitlab.com/gitlab-org/cluster-integration/tunnel/tool/streamvisitor"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

type (
	// InboundGRPCToOutboundHTTPStream is the inbound gRPC server stream handled by
	// InboundGRPCToOutboundHTTP.Pipe. HttpResponse messages are sent on it, and it is also
	// handed to the HTTP request stream visitor to consume the inbound request messages.
	InboundGRPCToOutboundHTTPStream = grpc.ServerStreamingServer[HttpResponse]

	// HandleProcessingErrorFunc reports a processing error together with a short message
	// describing the context. It does not return anything; the caller decides what to do next.
	HandleProcessingErrorFunc func(msg string, err error)

	// HandleIOErrorFunc reports an I/O error together with a short message describing the failed
	// operation and returns the error that should be propagated to the caller.
	HandleIOErrorFunc func(msg string, err error) error
)

// DoResponse is the result returned by an HTTPDo function: the HTTP response and, for a successful
// connection upgrade, the underlying connection plus a buffered reader over it.
type DoResponse struct {
	// Resp is the server's response to a request.
	Resp *http.Response
	// UpgradeConn is the underlying network connection to the server.
	// May be nil if request was not an Upgrade request or if server decided not to switch protocols
	// (non-101 response status code).
	UpgradeConn net.Conn
	// ConnReader is a buffered reader, wrapping UpgradeConn. Is set when UpgradeConn is set.
	// Must be used for reading as it may contain buffered bytes that are no longer available directly via UpgradeConn.
	ConnReader *bufio.Reader
}

// HTTPDo makes an HTTP request and returns a response. If an HTTP upgrade was requested, the underlying network
// connection is also returned. Implementations that don't support Upgrade should return an error.
//
// ctx bounds the request, header carries the request metadata received from the inbound gRPC stream,
// and body supplies the request body bytes. On success the returned DoResponse holds the HTTP response
// and, for upgraded connections, the connection and its buffered reader.
type HTTPDo func(ctx context.Context, header *HttpRequest_Header, body io.Reader) (DoResponse, error)

// InboundGRPCToOutboundHTTP turns an inbound gRPC stream into an outbound HTTP request and streams the
// HTTP response (and, for upgraded connections, the raw connection data) back over the gRPC stream.
//
// Fields (none of them is nil-checked before use):
//   - Log receives debug-level messages about handled requests and about canceled/timed-out pipes.
//   - HandleProcessingError is called with errors that are neither gRPC status errors nor context
//     cancellation/deadline errors, before they are converted to a codes.Unavailable status.
//   - HandleIOError is called for I/O failures with a message prefixed by "gRPC -> HTTP: ";
//     the error it returns is what gets propagated.
//   - HTTPDo performs the actual outbound HTTP request.
type InboundGRPCToOutboundHTTP struct {
	Log                   *slog.Logger
	HandleProcessingError HandleProcessingErrorFunc
	HandleIOError         HandleIOErrorFunc
	HTTPDo                HTTPDo
}

// Pipe forwards the request carried by the inbound gRPC stream to an HTTP server via x.HTTPDo and
// streams the HTTP response back over the same gRPC stream.
//
// Two functions cooperate through InboundStreamToOutboundStream:
//   - PipeInboundToOutbound reads the gRPC request messages, hands the header over via headerC and
//     writes body bytes into an io.Pipe that serves as the HTTP request body.
//   - PipeOutboundToInbound waits for the header, performs the HTTP request, publishes the response
//     on respC (needed by the other side for upgrade data) and sends the response to the inbound client.
//
// The returned error is nil or a gRPC status error: errors that are not already status errors are
// mapped to codes.Canceled, codes.DeadlineExceeded or, after being reported via
// x.HandleProcessingError, codes.Unavailable. The one exception is an error from closing the upgraded
// connection, which is returned exactly as x.HandleIOError produced it.
func (x *InboundGRPCToOutboundHTTP) Pipe(inbound InboundGRPCToOutboundHTTPStream) (retErr error) {
	var upgradeConn net.Conn
	// Close the upgraded connection (if one was obtained) on exit. A close error is only reported
	// when no other error is being returned.
	defer func() {
		if upgradeConn != nil {
			err := upgradeConn.Close()
			if retErr == nil {
				retErr = x.maybeHandleIOError("error closing connection", err)
			}
		}
	}()

	ctx := inbound.Context()

	// pr is used as the HTTP request body; pw is fed with the request body bytes from the gRPC stream.
	pr, pw := io.Pipe()
	headerC := make(chan *HttpRequest_Header)
	// buffered to not block the sender as receiver might encounter an error and exit before even trying to receive.
	respC := make(chan DoResponse, 1)
	s := InboundStreamToOutboundStream{
		// Pipe gRPC request -> HTTP request
		PipeInboundToOutbound: func() error {
			// unblock the PipeOutboundToInbound goroutine if we exited before sending the header due to an error.
			defer close(headerC)

			// NOTE: during normal operation pipeInboundToOutbound() cleanly closes pw.
			// However, if the inbound stream was broken or on any other error (e.g. message validation error),
			// we must ensure pw is closed to unblock the reader of pr.
			// Hence, we pass HTTPRequestStreamVisitor's error, nil or not, to CloseWithError(err).
			// pw.CloseWithError() is a no-op if the stream had been closed already, so it's safe to call multiple times.

			err := x.pipeInboundToOutbound(inbound, headerC, respC, pw)
			_ = pw.CloseWithError(err) // We don't care about the returned error, it's always nil.
			return err
		},
		// Pipe HTTP response -> gRPC response
		PipeOutboundToInbound: func() error {
			// Make sure the writer is unblocked if we exit abruptly
			// The error is ignored because it will always occur if things go normally - the pipe will have been
			// closed already when this code is reached (and that's an error).
			defer pr.Close() //nolint: errcheck
			// unblock the PipeInboundToOutbound goroutine if we exited before sending the response object due to an error.
			defer close(respC)
			select {
			case <-ctx.Done():
				return ctx.Err()
			case header, ok := <-headerC:
				if !ok {
					// Something went wrong in the PipeInboundToOutbound goroutine, exit.
					return nil
				}

				var body io.Reader
				if header.IsRequestWithoutBody() {
					// NOTE: The golang standard library will add a `Transfer-Encoding: chunked` to the request
					// for bodies with unknown size - which upgrade requests are,
					// see https://github.com/golang/go/blob/39ca989b883b913287d282365510a9152a3f80e6/src/net/http/transfer.go#L95
					// This leads to a zero-sized chunked HTTP body (`0 CR LF CR LF`) during upgrade requests which may
					// not be consumed by certain HTTP servers before hijacking the connection and switching
					// to "raw" TCP mode, namely the spdy upgrade logic in the Kubernetes apimachinery pkg (used in CRIs), see
					// https://github.com/kubernetes/kubernetes/blob/f51dad586ddc1a02b4fcc4e3974092ad78b630a7/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/upgrade.go#LL86C9-L86C9
					// However, we suspect that there is another bug in the Kubernetes stack that sometimes consumes
					// these additional bytes in the body and forwards a correct request to the destination (e.g. CRI).
					// See https://gitlab.com/gitlab-org/cluster-integration/gitlab-agent/-/issues/393
					body = http.NoBody
				} else {
					body = pr
				}
				r, err := x.HTTPDo(ctx, header, body)
				if err != nil {
					return err
				}
				// Never blocks: respC has capacity 1 and this is the only send.
				respC <- r
				// this store is not synchronized and that's ok because PipeOutboundToInbound is executed
				// on the caller's goroutine.
				upgradeConn = r.UpgradeConn
				return x.pipeOutboundToInbound(inbound, r, header.Request.IsUpgrade())
			}
		},
	}
	err := s.Pipe()
	// Normalize the outcome into a gRPC status error.
	switch {
	case err == nil:
	case IsStatusError(err):
		// A gRPC status already
	case errors.Is(err, context.Canceled):
		x.Log.Debug("gRPC -> HTTP", logz.Error(err))
		err = status.Errorf(codes.Canceled, "gRPC -> HTTP: %v", err)
	case errors.Is(err, context.DeadlineExceeded):
		x.Log.Debug("gRPC -> HTTP", logz.Error(err))
		err = status.Errorf(codes.DeadlineExceeded, "gRPC -> HTTP: %v", err)
	default:
		x.HandleProcessingError("gRPC -> HTTP", err)
		err = status.Errorf(codes.Unavailable, "gRPC -> HTTP: %v", err)
	}
	return err
}

// pipeInboundToOutbound consumes the request messages of the inbound gRPC stream and forwards them
// to the outbound HTTP side.
//
//   - The header message is logged and handed over via headerC so that the HTTP request can be started.
//   - Data messages are written to pw, the write end of the HTTP request body pipe. They are rejected
//     when the header says the request has no body.
//   - The trailer message ends the request body of an upgrade request by closing pw.
//   - Upgrade data messages are written to the upgraded connection, which is taken from the response
//     received on respC when the first such message arrives.
//   - At end of stream, pw is closed for non-upgrade requests.
//
// The returned error is whatever the stream visitor returns; callback errors are either gRPC status
// errors, a context error, or the result of x.HandleIOError.
func (x *InboundGRPCToOutboundHTTP) pipeInboundToOutbound(inbound InboundGRPCToOutboundHTTPStream,
	headerC chan<- *HttpRequest_Header, respC <-chan DoResponse, pw *io.PipeWriter) error {
	var isUpgrade bool
	var notExpectingBody bool
	var upgradeConn net.Conn
	return HTTPRequestStreamVisitor().Visit(inbound,
		streamvisitor.WithCallback(HTTPRequestHeaderFieldNumber, func(header *HttpRequest_Header) error {
			x.logRequest(header)
			isUpgrade = header.Request.IsUpgrade()
			notExpectingBody = header.IsRequestWithoutBody()
			ctx := inbound.Context()
			// headerC is unbuffered: don't get stuck if the receiver is gone and the stream is done.
			select {
			case <-ctx.Done():
				return ctx.Err()
			case headerC <- header:
				return nil
			}
		}),
		streamvisitor.WithCallback(HTTPRequestDataFieldNumber, func(data *HttpRequest_Data) error {
			if notExpectingBody {
				return status.Error(codes.Internal, "unexpected HttpRequest_Data message received")
			}
			_, err := pw.Write(data.Data)
			return x.maybeHandleIOError("request body write", err)
		}),
		streamvisitor.WithCallback(HTTPRequestTrailerFieldNumber, func(trailer *HttpRequest_Trailer) error {
			if isUpgrade {
				// Nothing more to send, close the write end of the pipe
				err := pw.Close()
				return x.maybeHandleIOError("request body close", err)
			}
			// Nothing to do
			return nil
		}),
		streamvisitor.WithCallback(HTTPRequestUpgradeDataFieldNumber, func(data *HttpRequest_UpgradeData) error {
			if !isUpgrade {
				// Inbound client didn't request a connection upgrade but sent an upgrade data frame.
				return status.Error(codes.Internal, "unexpected HttpRequest_UpgradeData message for non-upgrade request")
			}
			if upgradeConn == nil {
				// First upgrade data frame: wait for the HTTP response to learn whether the upgrade happened.
				r, ok := <-respC
				if !ok {
					// error in the other goroutine, abort.
					return context.Canceled
				}
				if r.Resp.StatusCode != http.StatusSwitchingProtocols {
					// Outbound server doesn't want to switch protocols but inbound client sent an upgrade data frame.
					return status.Errorf(codes.Internal, "unexpected HttpRequest_UpgradeData message for HTTP status code %d", r.Resp.StatusCode)
				}
				if r.UpgradeConn == nil {
					// HTTPDo reported a protocol switch but did not provide the connection.
					// Fail explicitly instead of panicking on a nil net.Conn below.
					return status.Error(codes.Internal, "HTTP server switched protocols but no upgraded connection is available")
				}
				upgradeConn = r.UpgradeConn
			}
			_, err := upgradeConn.Write(data.Data)
			return x.maybeHandleIOError("upgrade request write", err)
		}),
		streamvisitor.WithEOFCallback(func() error {
			if !isUpgrade {
				// Nothing more to send, close the write end of the pipe
				err := pw.Close()
				return x.maybeHandleIOError("request body close", err)
			}
			return nil
		}),
	)
}

// logRequest emits a debug-level line with the method, URL path and (when present) the encoded
// query of the request described by header. It does nothing when debug logging is disabled, which
// avoids formatting the message needlessly.
func (x *InboundGRPCToOutboundHTTP) logRequest(header *HttpRequest_Header) {
	debugOn := x.Log.Enabled(context.Background(), slog.LevelDebug)
	if !debugOn {
		return
	}
	var line string
	if r := header.Request; len(r.Query) == 0 {
		line = fmt.Sprintf("Handling %s %s", r.Method, r.UrlPath)
	} else {
		line = fmt.Sprintf("Handling %s %s?%s", r.Method, r.UrlPath, r.URLQuery().Encode())
	}
	x.Log.Debug(line)
}

// pipeOutboundToInbound sends the HTTP response r to the inbound gRPC client: first the header and
// body, then a trailer message. If isUpgrade is set and the server answered with
// 101 Switching Protocols, it then streams data read from r.ConnReader until that reader is exhausted.
func (x *InboundGRPCToOutboundHTTP) pipeOutboundToInbound(inbound InboundGRPCToOutboundHTTPStream, r DoResponse, isUpgrade bool) error {
	if err := x.sendResponseHeaderAndBody(inbound, r.Resp); err != nil {
		return err
	}

	trailerMsg := &HttpResponse{
		Message: &HttpResponse_Trailer_{
			Trailer: &HttpResponse_Trailer{},
		},
	}
	if err := inbound.Send(trailerMsg); err != nil {
		return x.handleIOError("SendMsg(HttpResponse_Trailer) failed", err)
	}

	// Only stream if upgrade was requested AND outbound server is switching protocols.
	if !isUpgrade || r.Resp.StatusCode != http.StatusSwitchingProtocols {
		return nil
	}
	return x.sendUpgradeResponseStream(inbound, r.ConnReader)
}

// sendResponseHeaderAndBody sends the status line and headers of resp as a header message and then
// forwards the response body in data messages until the body reports io.EOF.
//
// resp.Body is always closed before returning. An error from closing it is reported only when
// nothing else failed.
func (x *InboundGRPCToOutboundHTTP) sendResponseHeaderAndBody(inbound InboundGRPCToOutboundHTTPStream, resp *http.Response) (retErr error) {
	defer func() {
		closeErr := resp.Body.Close()
		if retErr != nil {
			// Keep the earlier error; the close error is dropped.
			return
		}
		retErr = x.maybeHandleIOError("response body close", closeErr)
	}()

	head := &HttpResponse_Header{
		Response: &prototool.HttpResponse{
			StatusCode: int32(resp.StatusCode), //nolint: gosec
			Status:     resp.Status,
			Header:     prototool.HTTPHeaderToHeaderKV(resp.Header),
		},
	}
	if err := inbound.Send(&HttpResponse{Message: &HttpResponse_Header_{Header: head}}); err != nil {
		return x.handleIOError("SendMsg(HttpResponse_Header) failed", err)
	}

	// Scratch buffer obtained from memz and handed back once the body has been forwarded.
	pooled := memz.Get32k()
	defer memz.Put32k(pooled)
	chunk := *pooled
	for {
		n, readErr := resp.Body.Read(chunk)
		// Forward whatever was read before looking at readErr: a reader may return data together
		// with an error, and that data must not be lost.
		if n > 0 {
			dataMsg := &HttpResponse{
				Message: &HttpResponse_Data_{
					Data: &HttpResponse_Data{
						Data: chunk[:n],
					},
				},
			}
			if err := inbound.Send(dataMsg); err != nil {
				return x.handleIOError("SendMsg(HttpResponse_Data) failed", err)
			}
		}
		if readErr == nil {
			continue
		}
		if readErr == io.EOF {
			// Body fully forwarded.
			return nil
		}
		return x.handleIOError("read HTTP response body", readErr)
	}
}

// sendUpgradeResponseStream copies bytes read from upgradeConn to the inbound gRPC client as
// upgrade data messages. It returns nil once the reader reports io.EOF; any other read error or a
// send failure is passed through x.HandleIOError.
func (x *InboundGRPCToOutboundHTTP) sendUpgradeResponseStream(inbound InboundGRPCToOutboundHTTPStream, upgradeConn *bufio.Reader) error {
	// Scratch buffer obtained from memz and handed back when streaming ends.
	pooled := memz.Get32k()
	defer memz.Put32k(pooled)
	chunk := *pooled
	for {
		n, readErr := upgradeConn.Read(chunk)
		// Forward whatever was read before looking at readErr so no consumed data is dropped.
		if n > 0 {
			dataMsg := &HttpResponse{
				Message: &HttpResponse_UpgradeData_{
					UpgradeData: &HttpResponse_UpgradeData{
						Data: chunk[:n],
					},
				},
			}
			if err := inbound.Send(dataMsg); err != nil {
				return x.handleIOError("SendMsg(HttpResponse_UpgradeData) failed", err)
			}
		}
		if readErr == nil {
			continue
		}
		if readErr == io.EOF {
			// The server side of the upgraded connection is done sending.
			return nil
		}
		return x.handleIOError("read upgrade response body", readErr)
	}
}

// maybeHandleIOError returns nil when err is nil; otherwise it reports err through handleIOError
// and returns the result.
func (x *InboundGRPCToOutboundHTTP) maybeHandleIOError(msg string, err error) error {
	if err == nil {
		return nil
	}
	return x.handleIOError(msg, err)
}

// handleIOError passes err to the configured HandleIOError callback, prefixing msg with the
// direction of this pipe, and returns whatever the callback returns.
func (x *InboundGRPCToOutboundHTTP) handleIOError(msg string, err error) error {
	prefixed := "gRPC -> HTTP: " + msg
	return x.HandleIOError(prefixed, err)
}
