package agentw

import (
	"context"
	"crypto/sha1"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strconv"
	"testing"
	"time"

	"buf.build/go/protovalidate"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/module/workspaces/rpc"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/prototool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/mock_modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testlogger"
	"go.uber.org/mock/gomock"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/anypb"
)

// TestTunnelServer exercises the tunnel server end to end with a real gRPC server and a local
// HTTP backend: each case sends one request over TunnelHTTP and checks the status code and the
// body the backend produced, as received back on the stream.
func TestTunnelServer(t *testing.T) {
	// Start HTTP test server
	httpPort, httpCleanupFunc := startTestHTTPServer(t)
	defer httpCleanupFunc()

	// Create and start the gRPC server
	grpcServer, grpcAddr := startTestGRPCServer(t)
	defer grpcServer.Stop()

	// Create gRPC client
	conn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	require.NoError(t, err)
	defer conn.Close()

	client := rpc.NewWorkspacesClient(conn)

	// decodeJSON unmarshals a backend response body into a generic map.
	decodeJSON := func(t *testing.T, body []byte) map[string]any {
		t.Helper()
		var result map[string]any
		require.NoError(t, json.Unmarshal(body, &result))
		return result
	}

	// objectField returns result[key] as a JSON object. It uses require rather than assert so a
	// missing or mistyped field stops the subtest instead of cascading into checks on a nil map.
	objectField := func(t *testing.T, result map[string]any, key string) map[string]any {
		t.Helper()
		v, ok := result[key].(map[string]any)
		require.True(t, ok, "field %q is missing or not a JSON object", key)
		return v
	}

	// assertEchoedQueryAndHeader checks that the backend saw the test query param and header.
	assertEchoedQueryAndHeader := func(t *testing.T, result map[string]any) {
		t.Helper()
		assert.Contains(t, objectField(t, result, "args"), "test-param")
		assert.Contains(t, objectField(t, result, "headers"), "X-Test-Header")
	}

	// assertJSONPayload checks that the backend received and parsed {"test":"data","number":123}.
	assertJSONPayload := func(t *testing.T, result map[string]any) {
		t.Helper()
		jsonData := objectField(t, result, "json_body")
		assert.Equal(t, "data", jsonData["test"])
		assert.InEpsilon(t, float64(123), jsonData["number"], 0.001)
	}

	// assertEmptyBody is used for /status/{code}, whose handler writes only a status code.
	assertEmptyBody := func(t *testing.T, body []byte) {
		t.Helper()
		assert.Empty(t, body)
	}

	tests := []struct {
		name               string
		method             string
		path               string
		queryParams        url.Values
		headers            http.Header
		body               string
		expectedStatusCode int32
		validateBody       func(*testing.T, []byte)
	}{
		{
			name:   "GET request with no body",
			method: http.MethodGet,
			path:   "/get",
			queryParams: url.Values{
				"test-param": []string{"test-value"},
			},
			headers: http.Header{
				"X-Test-Header": []string{"test-header-value"},
			},
			expectedStatusCode: http.StatusOK,
			validateBody: func(t *testing.T, body []byte) {
				assertEchoedQueryAndHeader(t, decodeJSON(t, body))
			},
		},
		{
			name:   "GET request with body",
			method: http.MethodGet,
			path:   "/get",
			queryParams: url.Values{
				"test-param": []string{"test-value"},
			},
			headers: http.Header{
				"X-Test-Header": []string{"test-header-value"},
			},
			body:               `{"test":"data","number":123}`,
			expectedStatusCode: http.StatusOK,
			validateBody: func(t *testing.T, body []byte) {
				result := decodeJSON(t, body)
				assertEchoedQueryAndHeader(t, result)
				// The body must survive the tunnel even on a GET.
				assertJSONPayload(t, result)
			},
		},
		{
			name:        "POST request with json body",
			method:      http.MethodPost,
			path:        "/post",
			queryParams: url.Values{},
			headers: http.Header{
				"Content-Type": []string{"application/json"},
			},
			body:               `{"test":"data","number":123}`,
			expectedStatusCode: http.StatusCreated,
			validateBody: func(t *testing.T, body []byte) {
				assertJSONPayload(t, decodeJSON(t, body))
			},
		},
		{
			name:        "PUT request with plain body",
			method:      http.MethodPut,
			path:        "/put",
			queryParams: url.Values{},
			headers: http.Header{
				"Content-Type": []string{"text/plain"},
			},
			body:               "Random body value",
			expectedStatusCode: http.StatusOK,
			validateBody: func(t *testing.T, body []byte) {
				result := decodeJSON(t, body)
				// The raw body is echoed back verbatim in the "body" field.
				dataField, ok := result["body"].(string)
				require.True(t, ok, `field "body" is missing or not a string`)
				assert.Equal(t, "Random body value", dataField)
			},
		},
		{
			name:               "DELETE request",
			method:             http.MethodDelete,
			path:               "/delete",
			queryParams:        url.Values{},
			headers:            http.Header{},
			expectedStatusCode: http.StatusOK,
			validateBody: func(t *testing.T, body []byte) {
				// Delete endpoint returns request details
				assert.Contains(t, decodeJSON(t, body), "url")
			},
		},
		{
			name:               "Status code 400",
			method:             http.MethodGet,
			path:               "/status/400",
			queryParams:        url.Values{},
			headers:            http.Header{},
			expectedStatusCode: http.StatusBadRequest,
			validateBody:       assertEmptyBody,
		},
		{
			name:               "Status code 404",
			method:             http.MethodGet,
			path:               "/status/404",
			queryParams:        url.Values{},
			headers:            http.Header{},
			expectedStatusCode: http.StatusNotFound,
			validateBody:       assertEmptyBody,
		},
		{
			name:               "Status code 500",
			method:             http.MethodGet,
			path:               "/status/500",
			queryParams:        url.Values{},
			headers:            http.Header{},
			expectedStatusCode: http.StatusInternalServerError,
			validateBody:       assertEmptyBody,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
			defer cancel()

			// Create the bidirectional stream
			stream, err := client.TunnelHTTP(ctx)
			require.NoError(t, err)

			sendToStream(t, stream, tt.method, tt.path, tt.queryParams, tt.headers, tt.body, httpPort)
			statusCode, responseBody, err := receiveFromStream(stream)
			require.NoError(t, err)

			// Validate response
			assert.Equal(t, tt.expectedStatusCode, statusCode, "Status code mismatch")
			tt.validateBody(t, responseBody)
		})
	}
}

// TestTunnelServer_Websocket checks that the tunnel handles a request carrying a websocket
// upgrade. The stream must first deliver the backend's 101 Switching Protocols header, then a
// response trailer, and then the bytes the backend writes on the upgraded connection as
// upgrade data.
func TestTunnelServer_Websocket(t *testing.T) {
	httpPort, stopHTTP := startTestHTTPServer(t)
	defer stopHTTP()

	grpcServer, grpcAddr := startTestGRPCServer(t)
	defer grpcServer.Stop()

	conn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	require.NoError(t, err)
	defer conn.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	stream, err := rpc.NewWorkspacesClient(conn).TunnelHTTP(ctx)
	require.NoError(t, err)

	query := url.Values{"test-param": []string{"test-value"}}
	upgradeHeaders := http.Header{
		"X-Test-Header":     []string{"test-header-value"},
		"Connection":        []string{"Upgrade"},
		"Upgrade":           []string{"websocket"},
		"Sec-WebSocket-Key": []string{"test-key"},
	}
	sendToStream(t, stream, http.MethodGet, "/websocket", query, upgradeHeaders, "", httpPort)

	// Message 1: the upgrade response header.
	msg, err := stream.Recv()
	require.NoError(t, err)
	header, isHeader := msg.Message.(*grpctool.HttpResponse_Header_)
	require.True(t, isHeader)
	gotHeaders := prototool.HeaderKVToHTTPHeader(header.Header.Response.Header)

	require.Equal(t, http.StatusSwitchingProtocols, int(header.Header.Response.StatusCode))
	require.Equal(t, "websocket", gotHeaders.Get("Upgrade"), "Upgrade header mismatch")
	require.Equal(t, "Upgrade", gotHeaders.Get("Connection"), "Connection header mismatch")
	require.Contains(t, gotHeaders, "Sec-Websocket-Accept", "WebSocket accept header mismatch")

	// Message 2: the response trailer.
	msg, err = stream.Recv()
	require.NoError(t, err)
	_, isTrailer := msg.Message.(*grpctool.HttpResponse_Trailer_)
	require.True(t, isTrailer)

	// Message 3: upgrade data, expected to hold the frame the backend sends right after the handshake.
	msg, err = stream.Recv()
	require.NoError(t, err)
	upgradeData, isUpgradeData := msg.Message.(*grpctool.HttpResponse_UpgradeData_)
	require.True(t, isUpgradeData)
	assert.Contains(t, string(upgradeData.UpgradeData.Data), "Hello Websockets!")
}

// TestTunnelServer_Timeout checks that when the client's context deadline expires while the
// backend is still working, receiving from the tunnel stream fails with codes.DeadlineExceeded.
func TestTunnelServer_Timeout(t *testing.T) {
	httpPort, stopHTTP := startTestHTTPServer(t)
	defer stopHTTP()

	grpcServer, grpcAddr := startTestGRPCServer(t)
	defer grpcServer.Stop()

	conn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	require.NoError(t, err)
	defer conn.Close()

	// A 50ms budget is far shorter than the one second the /delay/1 endpoint waits before answering.
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	stream, err := rpc.NewWorkspacesClient(conn).TunnelHTTP(ctx)
	require.NoError(t, err)

	sendToStream(t, stream, http.MethodGet, "/delay/1", url.Values{}, http.Header{}, "", httpPort)

	_, _, recvErr := receiveFromStream(stream)
	require.Error(t, recvErr)
	assert.Equal(t, codes.DeadlineExceeded, status.Code(recvErr))
}

// TestTunnelServer_ConnectionRefused checks that a request aimed at a local port with no
// listener makes the tunnel fail the stream with codes.Unavailable.
func TestTunnelServer_ConnectionRefused(t *testing.T) {
	grpcServer, grpcAddr := startTestGRPCServer(t)
	defer grpcServer.Stop()

	conn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	require.NoError(t, err)
	defer conn.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()

	stream, err := rpc.NewWorkspacesClient(conn).TunnelHTTP(ctx)
	require.NoError(t, err)

	// Borrow a free port from a throwaway httptest server and shut the server down before the
	// request is sent. This avoids hard-coding a port that another process might be using,
	// which would make the test flaky.
	closedServer := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
		t.Error("This should never be called")
	}))
	serverURL, err := url.Parse(closedServer.URL)
	require.NoError(t, err)
	closedPort, err := strconv.Atoi(serverURL.Port())
	require.NoError(t, err)
	closedServer.Close()

	sendToStream(t, stream, http.MethodGet, "/get", url.Values{}, http.Header{}, "", closedPort)

	_, _, recvErr := receiveFromStream(stream)
	require.Error(t, recvErr)
	assert.Equal(t, codes.Unavailable, status.Code(recvErr))
}

// sendToStream writes one complete HTTP request onto the tunnel stream and then half-closes
// the send side. It sends three kinds of message, in order:
//   - a header message holding the method, path, headers, query, and the target port (packed into Extra);
//   - a data message holding body, only when body is non-empty;
//   - an empty trailer.
//
// Any error from building or sending the messages fails the test immediately.
func sendToStream(t *testing.T, stream grpc.BidiStreamingClient[grpctool.HttpRequest, grpctool.HttpResponse], method, path string, queryParams url.Values, headers http.Header, body string, port int) {
	req := &prototool.HttpRequest{
		Method:  method,
		UrlPath: path,
		Header:  prototool.HTTPHeaderToHeaderKV(headers),
		Query:   prototool.URLValuesToQueryKV(queryParams),
	}

	portExtra, err := getHTTPHeaderExtra(uint32(port))
	require.NoError(t, err)

	outgoing := []*grpctool.HttpRequest{{
		Message: &grpctool.HttpRequest_Header_{
			Header: &grpctool.HttpRequest_Header{Request: req, Extra: portExtra},
		},
	}}
	if len(body) > 0 {
		outgoing = append(outgoing, &grpctool.HttpRequest{
			Message: &grpctool.HttpRequest_Data_{
				Data: &grpctool.HttpRequest_Data{Data: []byte(body)},
			},
		})
	}
	// The trailer marks the end of the request.
	outgoing = append(outgoing, &grpctool.HttpRequest{
		Message: &grpctool.HttpRequest_Trailer_{
			Trailer: &grpctool.HttpRequest_Trailer{},
		},
	})

	for _, m := range outgoing {
		require.NoError(t, stream.Send(m))
	}
	require.NoError(t, stream.CloseSend())
}

// receiveFromStream reads the response side of the tunnel stream until io.EOF. It returns the
// status code from the header message and the concatenation of every data message's payload.
// Other message kinds, such as trailers, are ignored. On any other Recv error it returns zero
// values together with that error.
func receiveFromStream(stream grpc.BidiStreamingClient[grpctool.HttpRequest, grpctool.HttpResponse]) (int32, []byte, error) {
	var (
		code int32
		body []byte
	)
	for {
		resp, err := stream.Recv()
		switch {
		case errors.Is(err, io.EOF):
			return code, body, nil
		case err != nil:
			return 0, nil, err
		}

		switch m := resp.Message.(type) {
		case *grpctool.HttpResponse_Data_:
			body = append(body, m.Data.Data...)
		case *grpctool.HttpResponse_Header_:
			code = m.Header.Response.StatusCode
		}
	}
}

// startTestHTTPServer starts an httptest server for the tunnel to forward requests to. It
// returns the server's port and a cleanup func, which must be called exactly once because it
// closes an internal channel.
//
// Routes:
//   - GET /get, PUT /put and DELETE /delete reply 200, and POST /post replies 201. Each reply is
//     a JSON echo of the request: query args, headers, raw body, body decoded as JSON when
//     possible, method, remote address and URL.
//   - GET /status/{code} replies with the given status and an empty body. It replies 400 if the
//     code is not a number or is unknown.
//   - GET /delay/{n} waits n seconds, or until the request context is done, then echoes like /get.
//   - GET /websocket performs a minimal websocket handshake on the hijacked connection, writes
//     one "Hello Websockets!" text frame and keeps the connection open until cleanup.
func startTestHTTPServer(t *testing.T) (int, func()) {
	// parseBody reads the whole request body and also returns it decoded as JSON when it is
	// valid JSON (nil otherwise). A read error yields ("", nil).
	parseBody := func(r *http.Request) (string, any) {
		defer r.Body.Close()
		body, err := io.ReadAll(r.Body)
		if err != nil {
			return "", nil
		}

		bodyStr := string(body)

		// Try to parse as json
		var jsonData any
		if err := json.Unmarshal(body, &jsonData); err == nil {
			return bodyStr, jsonData
		}

		return bodyStr, nil
	}

	type response struct {
		Args     map[string][]string `json:"args"`
		Body     string              `json:"body"`
		Headers  map[string][]string `json:"headers"`
		JSONBody any                 `json:"json_body"`
		Method   string              `json:"method"`
		Origin   string              `json:"origin"`
		URL      string              `json:"url"`
	}

	buildMethodResponse := func(r *http.Request) response {
		bodyStr, jsonBody := parseBody(r)

		return response{
			Args:     r.URL.Query(),
			Body:     bodyStr,
			Headers:  r.Header,
			JSONBody: jsonBody,
			Method:   r.Method,
			Origin:   r.RemoteAddr,
			URL:      r.URL.String(),
		}
	}

	// writeMethodResponse replies with the JSON echo of r and the given status code.
	writeMethodResponse := func(w http.ResponseWriter, r *http.Request, statusCode int) {
		// Read the request body before writing anything: for HTTP/1.x, writing the response
		// may prevent further reads of the request body.
		resp := buildMethodResponse(r)
		// Headers must be set before WriteHeader. Changes to w.Header() made afterwards are
		// silently dropped.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(statusCode)
		assert.NoError(t, json.NewEncoder(w).Encode(resp))
	}

	// Create a channel to signal WebSocket handler shutdown
	wsShutdown := make(chan struct{})

	mux := http.NewServeMux()
	mux.HandleFunc("GET /get", func(w http.ResponseWriter, r *http.Request) {
		writeMethodResponse(w, r, http.StatusOK)
	})
	mux.HandleFunc("PUT /put", func(w http.ResponseWriter, r *http.Request) {
		writeMethodResponse(w, r, http.StatusOK)
	})
	mux.HandleFunc("POST /post", func(w http.ResponseWriter, r *http.Request) {
		writeMethodResponse(w, r, http.StatusCreated)
	})
	mux.HandleFunc("DELETE /delete", func(w http.ResponseWriter, r *http.Request) {
		writeMethodResponse(w, r, http.StatusOK)
	})
	mux.HandleFunc("GET /status/{code}", func(w http.ResponseWriter, r *http.Request) {
		codeStr := r.PathValue("code")
		code, err := strconv.Atoi(codeStr)
		if err != nil {
			w.WriteHeader(http.StatusBadRequest)
			_, _ = w.Write([]byte(err.Error()))
			return
		}
		text := http.StatusText(code)
		if text == "" {
			w.WriteHeader(http.StatusBadRequest)
			_, _ = w.Write([]byte("invalid code: " + codeStr))
			return
		}
		w.WriteHeader(code)
	})
	mux.HandleFunc("GET /delay/{n}", func(w http.ResponseWriter, r *http.Request) {
		nStr := r.PathValue("n")
		n, err := strconv.Atoi(nStr)
		if err != nil {
			w.WriteHeader(http.StatusBadRequest)
			_, _ = w.Write([]byte(err.Error()))
			return
		}
		timer := time.NewTimer(time.Duration(n) * time.Second)
		defer timer.Stop()
		select {
		case <-timer.C:
		case <-r.Context().Done():
			// The client gave up, for example because the tunnel deadline expired. Return now so
			// this handler (and httptest.Server.Close, which waits for it) isn't held for the
			// full delay.
			return
		}
		writeMethodResponse(w, r, http.StatusOK)
	})
	mux.HandleFunc("GET /websocket", func(w http.ResponseWriter, r *http.Request) {
		// 1. Check for WebSocket upgrade headers
		if r.Header.Get("Upgrade") != "websocket" {
			http.Error(w, "Not a websocket request", http.StatusBadRequest)
			return
		}

		// 2. Get the WebSocket key and generate accept hash
		key := r.Header.Get("Sec-WebSocket-Key")
		accept := key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
		hash := sha1.Sum([]byte(accept))
		acceptHash := base64.StdEncoding.EncodeToString(hash[:])

		// 3. Hijack the connection
		hijacker, ok := w.(http.Hijacker)
		if !assert.True(t, ok, "response writer does not support hijacking") {
			http.Error(w, "hijacking not supported", http.StatusInternalServerError)
			return
		}
		conn, bufrw, err := hijacker.Hijack()
		if !assert.NoError(t, err) {
			return
		}
		defer conn.Close()

		// 4. Send upgrade response
		respStr := fmt.Sprintf("HTTP/1.1 101 Switching Protocols\r\n"+
			"Upgrade: websocket\r\n"+
			"Connection: Upgrade\r\n"+
			"Sec-Websocket-Accept: %s\r\n\r\n", acceptHash)

		_, err = bufrw.WriteString(respStr)
		assert.NoError(t, err)
		err = bufrw.Flush()
		assert.NoError(t, err)

		// 5. Connection established! Send a message
		message := []byte("Hello Websockets!")
		// 0x81 = FIN + text frame. The single length byte is only valid for payloads < 126 bytes.
		frame := []byte{0x81, byte(len(message))} //nolint:prealloc
		frame = append(frame, message...)
		_, err = conn.Write(frame)
		assert.NoError(t, err)

		// Keep the connection open until cleanup
		<-wsShutdown
	})
	httpServer := httptest.NewServer(mux)

	cleanupFunc := func() {
		close(wsShutdown)
		httpServer.Close()
	}

	// Parse URL to get port
	httpURL, err := url.Parse(httpServer.URL)
	require.NoError(t, err)
	httpPort, err := strconv.Atoi(httpURL.Port())
	require.NoError(t, err)

	return httpPort, cleanupFunc
}

// startTestGRPCServer starts a gRPC server hosting the tunnel service on an ephemeral
// 127.0.0.1 port. It returns the server, which the caller must Stop, and its listen address.
// Both interceptors inject a mocked AgentRPCAPI into the request context. The mock permits any
// number of Log, HandleProcessingError and HandleIOError calls.
func startTestGRPCServer(t *testing.T) (*grpc.Server, string) {
	// Create mock dependencies
	ctrl := gomock.NewController(t)
	mockRPCAPI := mock_modserver.NewMockAgentRPCAPI(ctrl)

	// Set up expectations
	mockRPCAPI.EXPECT().Log().Return(testlogger.New(t)).AnyTimes()
	mockRPCAPI.EXPECT().HandleProcessingError(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	mockRPCAPI.EXPECT().HandleIOError(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()

	// Create the tunnel server
	validator, err := protovalidate.New()
	require.NoError(t, err)

	tunnelSrv := newTunnelServer(
		testlogger.New(t),
		http.DefaultTransport,
		validator,
		"test-user-agent",
		"test-version",
		"test-via",
	)

	// Create gRPC server
	grpcSrv := grpc.NewServer(
		grpc.UnaryInterceptor(func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
			// Inject RPC API into context
			ctx = modshared.InjectRPCAPI(ctx, mockRPCAPI)
			return handler(ctx, req)
		}),
		grpc.StreamInterceptor(func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
			// Inject RPC API into context
			wrapped := &wrappedServerStream{
				ServerStream: ss,
				ctx:          modshared.InjectRPCAPI(ss.Context(), mockRPCAPI),
			}
			return handler(srv, wrapped)
		}),
	)

	// Register the service
	rpc.RegisterWorkspacesServer(grpcSrv, tunnelSrv)

	// Start listening
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	go func() {
		// Serve returns ErrServerStopped if the caller's Stop runs before this goroutine gets to
		// Serve (e.g. a test that fails fast). That is not a failure, and reporting it could call
		// t.Errorf after the test has completed, which panics.
		if err := grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
			t.Errorf("failed to start grpc server: %v", err)
		}
	}()

	// No need to sleep waiting for Serve: the listener is already bound and listening, so
	// client connections are queued by the kernel until Serve starts accepting them.
	return grpcSrv, listener.Addr().String()
}

// wrappedServerStream is a grpc.ServerStream whose Context method returns ctx instead of the
// embedded stream's context. All other methods are promoted unchanged from the embedded
// ServerStream. The test stream interceptor uses it to hand handlers a context carrying the
// injected RPC API.
type wrappedServerStream struct {
	grpc.ServerStream
	ctx context.Context //nolint:containedctx
}

// Context returns the replacement context, shadowing the embedded ServerStream's Context.
func (s *wrappedServerStream) Context() context.Context {
	return s.ctx
}

// getHTTPHeaderExtra packs the target backend port into an rpc.HTTPHeaderExtra message and
// wraps it in an anypb.Any, for use as the Extra field of an HttpRequest_Header. Any error from
// anypb.New is returned unchanged.
func getHTTPHeaderExtra(port uint32) (*anypb.Any, error) {
	return anypb.New(&rpc.HTTPHeaderExtra{Port: port})
}
