package search

import (
	"context"
	"errors"
	"net"
	"net/http"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/test/bufconn"

	proto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1"
)

func TestSearch_UsesGrpc(t *testing.T) {
	t.Parallel()

	conn, endpoint := StartBufconnServer(t, []*proto.FileMatch{
		{FileName: []byte("grpc.go"), Score: 1.0},
	}, &proto.Stats{FileCount: 1})
	defer conn.Close()

	// Extract the host part from the endpoint for storing the connection
	hostPart := GetHostFromEndpoint(endpoint)

	searcher := &Searcher{
		Client: &http.Client{},
		GrpcConns: map[string]*grpc.ClientConn{
			hostPart: conn, // Store using the host part as key
		},
	}

	raw := &RawSearchRequest{
		ZoektSearchRequest: ZoektSearchRequest{
			Query: "grpc test",
		},
		Timeout: time.Second,
		ForwardToV1: []Conn{
			{
				Endpoint: endpoint,
				RepoIds:  []uint32{1},
			},
		},
	}

	// Convert RawSearchRequest to SearchRequest
	req, err := raw.ToSearchRequest()
	require.NoError(t, err)

	resp := searcher.DoSearch(context.Background(), req, &req.ForwardTo[0])

	require.NoError(t, resp.Error)
	require.Equal(t, endpoint, resp.Endpoint)
	require.Len(t, resp.Result.Files, 1)
	require.Equal(t, "grpc.go", resp.Result.Files[0].FileName)
}

func TestHandleGrpcSearchStream(t *testing.T) {
	stream := &mockStream{
		responses: []*proto.StreamSearchResponse{
			{
				ResponseChunk: &proto.SearchResponse{
					Stats: &proto.Stats{
						FileCount:    1,
						MatchCount:   2,
						NgramMatches: 3,
					},
					Files: []*proto.FileMatch{
						{FileName: []byte("test.go")},
					},
				},
			},
		},
	}

	ctx := context.Background()
	resp := handleGrpcSearchStream(ctx, stream, "test-endpoint", &SearchRequest{})

	require.NoError(t, resp.Error)
	require.Len(t, resp.Result.Files, 1)
	require.Equal(t, 1, resp.Result.FileCount)
	require.Equal(t, 2, resp.Result.MatchCount)
	require.Equal(t, 3, resp.Result.NgramMatches)
}

// TestHandleGrpcSearchStream_NilResponse checks that a nil message on the
// stream is reported as an error mentioning "received nil response".
func TestHandleGrpcSearchStream_NilResponse(t *testing.T) {
	stream := &mockStream{responses: []*proto.StreamSearchResponse{nil}}

	got := handleGrpcSearchStream(context.Background(), stream, "test-endpoint", &SearchRequest{})

	require.Error(t, got.Error)
	require.Contains(t, got.Error.Error(), "received nil response")
}

func TestHandleGrpcSearchStream_NilChunk(t *testing.T) {
	stream := &mockStream{
		responses: []*proto.StreamSearchResponse{
			{}, // ResponseChunk is nil
			{ // Add valid response to end the stream
				ResponseChunk: &proto.SearchResponse{
					Stats: &proto.Stats{FileCount: 1},
					Files: []*proto.FileMatch{{FileName: []byte("valid.go")}},
				},
			},
		},
	}
	ctx := context.Background()
	resp := handleGrpcSearchStream(ctx, stream, "test-endpoint", &SearchRequest{})

	require.NoError(t, resp.Error)
	require.Len(t, resp.Result.Files, 1)
	require.Equal(t, "valid.go", resp.Result.Files[0].FileName)
}

func TestHandleGrpcSearchStream_CanceledContext(t *testing.T) {
	stream := &mockStream{
		responses: []*proto.StreamSearchResponse{
			{
				ResponseChunk: &proto.SearchResponse{
					Files: []*proto.FileMatch{{FileName: []byte("cancel.go")}},
					Stats: &proto.Stats{FileCount: 1},
				},
			},
		},
	}
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // Cancel before the first recv

	resp := handleGrpcSearchStream(ctx, stream, "test-endpoint", &SearchRequest{})

	require.ErrorIs(t, resp.Error, context.Canceled)
}

// TestHandleGrpcSearchStream_ErrorFromStream checks that an error configured
// on the mock stream is surfaced in the response, with its message intact.
func TestHandleGrpcSearchStream_ErrorFromStream(t *testing.T) {
	streamErr := errors.New("boom")
	stream := &mockStream{err: streamErr}

	got := handleGrpcSearchStream(context.Background(), stream, "test-endpoint", &SearchRequest{})

	require.Error(t, got.Error)
	require.Contains(t, got.Error.Error(), "boom")
}

func TestHandleGrpcSearchStream_MaxLineMatchWindow(t *testing.T) {
	// Create stream with 3 files, each with 3 line matches = 9 total
	stream := &mockStream{
		responses: []*proto.StreamSearchResponse{
			{
				ResponseChunk: &proto.SearchResponse{
					Stats: &proto.Stats{FileCount: 3, MatchCount: 9},
					Files: []*proto.FileMatch{
						{
							FileName: []byte("file1.go"),
							LineMatches: []*proto.LineMatch{
								{Line: []byte("match1")},
								{Line: []byte("match2")},
								{Line: []byte("match3")},
							},
							Score: 0.9,
						},
						{
							FileName: []byte("file2.go"),
							LineMatches: []*proto.LineMatch{
								{Line: []byte("match4")},
								{Line: []byte("match5")},
								{Line: []byte("match6")},
							},
							Score: 0.8,
						},
						{
							FileName: []byte("file3.go"),
							LineMatches: []*proto.LineMatch{
								{Line: []byte("match7")},
								{Line: []byte("match8")},
								{Line: []byte("match9")},
							},
							Score: 0.7,
						},
					},
				},
			},
		},
	}

	ctx := context.Background()
	// Set maxLineMatchWindow to 5 - should stop after collecting 5 line matches
	resp := handleGrpcSearchStream(ctx, stream, "test-endpoint", &SearchRequest{
		MaxLineMatchWindow: 5,
	})

	require.NoError(t, resp.Error)
	require.Len(t, resp.Result.Files, 2) // Should have 2 files (first with 3, second with 2)

	// Count total line matches
	totalLineMatches := 0
	for _, file := range resp.Result.Files {
		totalLineMatches += len(file.LineMatches)
	}
	require.Equal(t, 5, totalLineMatches)
}

func TestHandleGrpcSearchStream_MaxLineMatchResultsPerFile(t *testing.T) {
	// Create stream with 1 file with 5 line matches
	stream := &mockStream{
		responses: []*proto.StreamSearchResponse{
			{
				ResponseChunk: &proto.SearchResponse{
					Stats: &proto.Stats{FileCount: 1, MatchCount: 5},
					Files: []*proto.FileMatch{
						{
							FileName: []byte("file.go"),
							LineMatches: []*proto.LineMatch{
								{Line: []byte("match1")},
								{Line: []byte("match2")},
								{Line: []byte("match3")},
								{Line: []byte("match4")},
								{Line: []byte("match5")},
							},
						},
					},
				},
			},
		},
	}

	ctx := context.Background()
	// Set maxPerFile to 3 - should only keep 3 line matches per file
	resp := handleGrpcSearchStream(ctx, stream, "test-endpoint", &SearchRequest{
		MaxLineMatchResultsPerFile: 3,
	})

	require.NoError(t, resp.Error)
	require.Len(t, resp.Result.Files, 1)
	require.Len(t, resp.Result.Files[0].LineMatches, 3)
}

func TestHandleGrpcSearchStream_BothLimits(t *testing.T) {
	// Create stream with multiple files
	stream := &mockStream{
		responses: []*proto.StreamSearchResponse{
			{
				ResponseChunk: &proto.SearchResponse{
					Files: []*proto.FileMatch{
						{
							FileName: []byte("file1.go"),
							LineMatches: []*proto.LineMatch{
								{Line: []byte("match1")},
								{Line: []byte("match2")},
								{Line: []byte("match3")},
								{Line: []byte("match4")},
								{Line: []byte("match5")},
							},
						},
						{
							FileName: []byte("file2.go"),
							LineMatches: []*proto.LineMatch{
								{Line: []byte("match6")},
								{Line: []byte("match7")},
								{Line: []byte("match8")},
							},
						},
					},
				},
			},
		},
	}

	ctx := context.Background()
	// maxPerFile=2 means each file can have at most 2 matches
	// maxLineMatchWindow=3 means total across all files is 3
	resp := handleGrpcSearchStream(ctx, stream, "test-endpoint", &SearchRequest{
		MaxLineMatchWindow:         3,
		MaxLineMatchResultsPerFile: 2,
	})

	require.NoError(t, resp.Error)
	// First file gets truncated to 2 matches, second file gets 1 match (to reach total of 3)
	require.Len(t, resp.Result.Files, 2)
	require.Len(t, resp.Result.Files[0].LineMatches, 2)
	require.Len(t, resp.Result.Files[1].LineMatches, 1)

	// Total should be 3
	totalLineMatches := 0
	for _, file := range resp.Result.Files {
		totalLineMatches += len(file.LineMatches)
	}
	require.Equal(t, 3, totalLineMatches)
}

// TestGetGrpcConnConcurrent calls getGrpcConn for the same endpoint from
// several goroutines at once. It makes no assertions about the returned
// values; its purpose is to surface panics or data races (run with -race).
func TestGetGrpcConnConcurrent(t *testing.T) {
	// The searcher's mutex is usable as its zero value, so only the
	// connection map needs to be initialised here.
	searcher := &Searcher{
		GrpcConns: make(map[string]*grpc.ClientConn),
	}

	// Close any connection that ended up in the cache so the test does not
	// leak client connections. Cleanup runs after wg.Wait below, so no
	// goroutine is still touching the map.
	t.Cleanup(func() {
		for _, conn := range searcher.GrpcConns {
			if conn != nil {
				_ = conn.Close()
			}
		}
	})

	const goroutines = 10

	var wg sync.WaitGroup
	wg.Add(goroutines)
	for i := 0; i < goroutines; i++ {
		go func() {
			defer wg.Done()
			// Results are deliberately ignored: only concurrent safety is
			// under test.
			_, _ = searcher.getGrpcConn("http://localhost:1234")
		}()
	}
	wg.Wait()
}

// TestParseEndpoint checks that parseEndpoint splits http:// and https://
// endpoints into protocol, host:port and path, and that endpoints without
// either prefix produce an error while still reporting the original string.
func TestParseEndpoint(t *testing.T) {
	tests := []struct {
		name         string
		endpoint     string
		wantProtocol string
		wantPath     string
		wantHostPort string
	}{
		{
			name:         "simple endpoint without path",
			endpoint:     "bufnet",
			wantProtocol: "",
			wantPath:     "",
			wantHostPort: "bufnet",
		},
		{
			name:         "endpoint with path",
			endpoint:     "bufnet/some/path",
			wantProtocol: "",
			wantPath:     "/some/path",
			wantHostPort: "bufnet",
		},
		{
			name:         "http endpoint with path",
			endpoint:     "http://bufnet/some/path",
			wantProtocol: "http",
			wantPath:     "/some/path",
			wantHostPort: "bufnet",
		},
		{
			name:         "https endpoint with path",
			endpoint:     "https://bufnet/some/path",
			wantProtocol: "https",
			wantPath:     "/some/path",
			wantHostPort: "bufnet",
		},
		{
			name:         "http endpoint with port",
			endpoint:     "http://bufnet:8080/some/path",
			wantProtocol: "http",
			wantPath:     "/some/path",
			wantHostPort: "bufnet:8080",
		},
		{
			name:         "https endpoint with port",
			endpoint:     "https://bufnet:8443/some/path",
			wantProtocol: "https",
			wantPath:     "/some/path",
			wantHostPort: "bufnet:8443",
		},
		{
			name:         "endpoint with complex path",
			endpoint:     "bufnet/some/path/with/multiple/segments",
			wantProtocol: "",
			wantPath:     "/some/path/with/multiple/segments",
			wantHostPort: "bufnet",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			info, err := parseEndpoint(tt.endpoint)

			hasScheme := strings.HasPrefix(tt.endpoint, "http://") || strings.HasPrefix(tt.endpoint, "https://")
			if tt.wantProtocol == "" && !hasScheme {
				// Scheme-less endpoints are expected to be rejected, so
				// wantPath and wantHostPort are not checked for these cases.
				require.Error(t, err, "Should return error for unknown protocol")
				require.Equal(t, tt.endpoint, info.Original, "Original endpoint should be preserved even on error")
				return
			}

			require.NoError(t, err)
			require.Equal(t, tt.endpoint, info.Original, "Original endpoint should be preserved")
			require.Equal(t, tt.wantProtocol, info.Protocol, "Protocol should be correctly extracted")
			require.Equal(t, tt.wantPath, info.Path, "Path should be correctly extracted")
			require.Equal(t, tt.wantHostPort, info.HostPort, "HostPort should be correctly extracted")
		})
	}
}

// TestEndpointPathExtraction runs DoSearch against an in-memory gRPC server
// and checks that the path component of the endpoint arrives as the
// "x-forwarded-path" metadata entry, and that no such entry is present when
// the endpoint has no path.
func TestEndpointPathExtraction(t *testing.T) {
	t.Parallel()

	// Create a server that captures the metadata from the request
	lis := bufconn.Listen(1024 * 1024)
	s := grpc.NewServer()
	// Stop the server once the test and its subtests have finished; without
	// this the Serve goroutine and the bufconn listener outlive the test.
	t.Cleanup(s.Stop)

	server := &MetadataCapturingServer{
		files: []*proto.FileMatch{
			{FileName: []byte("test.go"), Score: 1.0},
		},
		stats: &proto.Stats{FileCount: 1},
	}
	proto.RegisterWebserverServiceServer(s, server)

	go func() {
		// Serve returns once the server is stopped; the error is not
		// interesting for the test.
		_ = s.Serve(lis)
	}()

	dialer := func(context.Context, string) (net.Conn, error) {
		return lis.Dial()
	}

	conn, err := NewTestGrpcClient(context.Background(), "bufnet", dialer)
	require.NoError(t, err)
	defer conn.Close()

	// Pre-register the connection under the host used by every test endpoint.
	searcher := &Searcher{
		Client: &http.Client{},
		GrpcConns: map[string]*grpc.ClientConn{
			"bufnet": conn,
		},
	}

	testCases := []struct {
		name           string
		endpoint       string
		expectedPath   string
		expectedTarget string
	}{
		{
			name:           "http endpoint without path",
			endpoint:       "http://bufnet",
			expectedPath:   "",
			expectedTarget: "bufnet",
		},
		{
			name:           "endpoint with path",
			endpoint:       "http://bufnet/some/path",
			expectedPath:   "/some/path",
			expectedTarget: "bufnet",
		},
		{
			name:           "http endpoint with path",
			endpoint:       "http://bufnet/some/path",
			expectedPath:   "/some/path",
			expectedTarget: "bufnet",
		},
		{
			name:           "https endpoint with path",
			endpoint:       "https://bufnet/some/path",
			expectedPath:   "/some/path",
			expectedTarget: "bufnet",
		},
		{
			name:           "endpoint with complex path",
			endpoint:       "http://bufnet/some/path/with/multiple/segments",
			expectedPath:   "/some/path/with/multiple/segments",
			expectedTarget: "bufnet",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Reset the captured metadata so a previous subtest cannot
			// satisfy this one's assertions.
			server.capturedMD = nil

			raw := &RawSearchRequest{
				ZoektSearchRequest: ZoektSearchRequest{
					Query: "test",
				},
				Timeout: 0,
				ForwardToV1: []Conn{
					{
						Endpoint: tc.endpoint,
						RepoIds:  []uint32{1},
					},
				},
			}

			req, err := raw.ToSearchRequest()
			require.NoError(t, err)

			resp := searcher.DoSearch(context.Background(), req, &req.ForwardTo[0])
			require.NoError(t, resp.Error)

			if tc.expectedPath == "" {
				// If no path expected, the header should not be set
				if server.capturedMD != nil {
					require.Empty(t, server.capturedMD.Get("x-forwarded-path"))
				}
				return
			}

			// Verify the path was correctly extracted and forwarded
			require.NotNil(t, server.capturedMD)
			paths := server.capturedMD.Get("x-forwarded-path")
			require.Len(t, paths, 1)
			require.Equal(t, tc.expectedPath, paths[0])
		})
	}
}

// TestConnectionCachingWithPaths checks that stripping the scheme and path
// from an endpoint yields the expected host:port cache key, and that endpoints
// differing only in scheme or path resolve to the same entry of
// Searcher.GrpcConns.
//
// NOTE(review): the key derivation is re-implemented locally in cacheKey
// rather than calling getGrpcConn; keep the two in sync.
func TestConnectionCachingWithPaths(t *testing.T) {
	t.Parallel()

	// cacheKey drops an http:// or https:// prefix and everything from the
	// first "/" onwards, leaving only host[:port].
	cacheKey := func(endpoint string) string {
		target := strings.TrimPrefix(endpoint, "http://")
		target = strings.TrimPrefix(target, "https://")
		hostPort, _, _ := strings.Cut(target, "/")
		return hostPort
	}

	// Test the path extraction logic
	testCases := []struct {
		endpoint    string
		expectedKey string
	}{
		{
			endpoint:    "example.com",
			expectedKey: "example.com",
		},
		{
			endpoint:    "example.com/path",
			expectedKey: "example.com",
		},
		{
			endpoint:    "example.com/path/with/segments",
			expectedKey: "example.com",
		},
		{
			endpoint:    "http://example.com/api",
			expectedKey: "example.com",
		},
		{
			endpoint:    "https://example.com/some/path",
			expectedKey: "example.com",
		},
		{
			endpoint:    "example.com:8080/path",
			expectedKey: "example.com:8080",
		},
		{
			endpoint:    "http://example.com:9090/api",
			expectedKey: "example.com:9090",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.endpoint, func(t *testing.T) {
			require.Equal(t, tc.expectedKey, cacheKey(tc.endpoint))
		})
	}

	// newConn creates a client for target, fails the test (instead of
	// panicking) if that is not possible, and closes the client on cleanup.
	newConn := func(target string) *grpc.ClientConn {
		conn, err := grpc.NewClient(target, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")))
		require.NoError(t, err, "failed to create client for %s", target)
		t.Cleanup(func() { _ = conn.Close() })
		return conn
	}

	// Test cache key generation in a simplified way
	searcher := &Searcher{
		GrpcConns: map[string]*grpc.ClientConn{
			"host1.example.com": newConn("host1.example.com"),
			"host2.example.com": newConn("host2.example.com"),
		},
	}

	// Verify we'd find the right connection in the map
	hostPort := cacheKey("http://host1.example.com/some/path")
	_, exists := searcher.GrpcConns[hostPort]
	require.True(t, exists, "Connection should exist for host extracted from endpoint with path")

	// Different endpoint with same host should map to same connection
	hostPort2 := cacheKey("https://host1.example.com/different/path")
	require.Equal(t, hostPort, hostPort2, "Different endpoints with same host should extract to same key")
	_, exists = searcher.GrpcConns[hostPort2]
	require.True(t, exists, "Connection should exist for second endpoint variant")
}

// TestForwardPathHeader runs DoSearch against an in-memory gRPC server with
// request headers set and checks that the expected metadata entries
// (including "x-forwarded-path" for endpoints that carry a path) reach the
// server.
func TestForwardPathHeader(t *testing.T) {
	t.Parallel()

	// Setup a mock GRPC server that can capture the metadata
	lis := bufconn.Listen(1024 * 1024)
	s := grpc.NewServer()
	// Stop the server once the test and its subtests have finished; without
	// this the Serve goroutine and the bufconn listener outlive the test.
	t.Cleanup(s.Stop)

	server := &MetadataCapturingServer{
		files: []*proto.FileMatch{
			{FileName: []byte("header_test.go"), Score: 1.0},
		},
		stats: &proto.Stats{FileCount: 1},
	}
	proto.RegisterWebserverServiceServer(s, server)

	go func() {
		// Serve returns once the server is stopped; the error is not
		// interesting for the test.
		_ = s.Serve(lis)
	}()

	dialer := func(context.Context, string) (net.Conn, error) {
		return lis.Dial()
	}

	conn, err := NewTestGrpcClient(context.Background(), "bufnet", dialer)
	require.NoError(t, err)
	defer conn.Close()

	// Pre-register the connection under the host used by every test endpoint.
	searcher := &Searcher{
		Client: &http.Client{},
		GrpcConns: map[string]*grpc.ClientConn{
			"bufnet": conn,
		},
	}

	testCases := []struct {
		name            string
		endpoint        string
		headers         map[string]string
		expectedHeaders map[string]string
	}{
		{
			name:     "forward path with custom headers",
			endpoint: "http://bufnet/api/search",
			headers: map[string]string{
				"Authorization": "Bearer token123",
			},
			expectedHeaders: map[string]string{
				"x-forwarded-path": "/api/search",
				"authorization":    "Bearer token123",
			},
		},
		{
			name:     "http endpoint with path",
			endpoint: "http://bufnet/zoekt/api",
			headers: map[string]string{
				"Content-Type": "application/json",
			},
			expectedHeaders: map[string]string{
				"x-forwarded-path": "/zoekt/api",
				"content-type":     "application/grpc",
			},
		},
		{
			// With no expected headers this case only asserts that the
			// search succeeds and that some metadata was captured.
			name:     "endpoint without path",
			endpoint: "http://bufnet",
			headers: map[string]string{
				"X-Invalid-Header": "value",
			},
			expectedHeaders: map[string]string{},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Reset captured metadata so a previous subtest cannot satisfy
			// this one's assertions.
			server.capturedMD = nil

			raw := &RawSearchRequest{
				ZoektSearchRequest: ZoektSearchRequest{
					Query: "test query",
				},
				HeadersV1: tc.headers,
				Timeout:   0,
				ForwardToV1: []Conn{
					{
						Endpoint: tc.endpoint,
						RepoIds:  []uint32{1},
					},
				},
			}

			req, err := raw.ToSearchRequest()
			require.NoError(t, err)

			resp := searcher.DoSearch(context.Background(), req, &req.ForwardTo[0])
			require.NoError(t, resp.Error)
			require.NotNil(t, server.capturedMD)

			// Verify all expected headers were forwarded
			for key, expectedValue := range tc.expectedHeaders {
				values := server.capturedMD.Get(key)
				require.NotEmpty(t, values, "Header %s not found", key)
				require.Equal(t, expectedValue, values[0], "Header %s has incorrect value", key)
			}
		})
	}
}

// MetadataCapturingServer is a mock implementation that captures metadata from requests.
// It answers every StreamSearch and Search call with the configured files and
// stats, and remembers the incoming gRPC metadata of the most recent call so
// tests can inspect it.
type MetadataCapturingServer struct {
	proto.UnimplementedWebserverServiceServer
	// files is the list of file matches returned in every response.
	files []*proto.FileMatch
	// stats is the stats message returned in every response.
	stats *proto.Stats
	// capturedMD holds the metadata of the last request that carried any.
	// It is written by the handlers and read/reset directly by tests; the
	// struct itself provides no synchronisation for it.
	capturedMD metadata.MD
}

// StreamSearch records the incoming request metadata, when present, and then
// sends a single response chunk built from the configured files and stats.
// The error from Send, if any, is returned to the caller.
func (s *MetadataCapturingServer) StreamSearch(req *proto.StreamSearchRequest, stream proto.WebserverService_StreamSearchServer) error {
	// Remember what the client sent so the test can inspect it afterwards.
	if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
		s.capturedMD = md
	}

	// One canned chunk is all the mock ever produces.
	chunk := &proto.SearchResponse{Files: s.files, Stats: s.stats}
	return stream.Send(&proto.StreamSearchResponse{ResponseChunk: chunk})
}

// Search records the incoming request metadata, when present, and returns a
// response built from the configured files and stats. It never returns an
// error.
func (s *MetadataCapturingServer) Search(ctx context.Context, req *proto.SearchRequest) (*proto.SearchResponse, error) {
	// Remember what the client sent so the test can inspect it afterwards.
	if md, ok := metadata.FromIncomingContext(ctx); ok {
		s.capturedMD = md
	}

	canned := &proto.SearchResponse{Files: s.files, Stats: s.stats}
	return canned, nil
}
