package search

import (
	"context"
	"fmt"
	"io"
	"net"
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/test/bufconn"

	"github.com/sourcegraph/zoekt"
	proto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1"
)

// StartBufconnServer starts an in-memory gRPC server on a bufconn listener,
// registers a MockZoektGRPCServer that answers StreamSearch with the given
// files and stats, and dials a client connection to it.
//
// It returns the client connection and a synthetic endpoint of the form
// "http://bufnet-<listener pointer>". The http:// prefix keeps the endpoint
// compatible with parseEndpoint; the part after the prefix is the target the
// connection was dialed with.
//
// The server is stopped automatically when the test (and its subtests)
// finish. This function does not close the returned connection.
func StartBufconnServer(t *testing.T, files []*proto.FileMatch, stats *proto.Stats) (*grpc.ClientConn, string) {
	t.Helper()

	lis := bufconn.Listen(1024 * 1024)
	s := grpc.NewServer()
	// Register the cleanup before anything can fail so the server, its Serve
	// goroutine and the listener never outlive the test.
	t.Cleanup(s.Stop)

	mock := &MockZoektGRPCServer{
		files: files,
		stats: stats,
	}
	proto.RegisterWebserverServiceServer(s, mock)

	go func() {
		// Serve only returns once the server is stopped; its error carries no
		// information the tests need.
		_ = s.Serve(lis)
	}()

	// Every dial goes straight to the in-memory listener, whatever the address.
	dialer := func(context.Context, string) (net.Conn, error) {
		return lis.Dial()
	}

	// The listener's pointer value makes the host name unique per server.
	host := fmt.Sprintf("bufnet-%p", lis)
	endpoint := "http://" + host

	conn, err := NewTestGrpcClient(context.Background(), host, dialer)
	require.NoError(t, err)

	return conn, endpoint
}

// MockZoektGRPCServer is a test double for the Zoekt webserver gRPC service.
// Its StreamSearch either fails with errMsg (when fail is set) or sends a
// single response carrying files and stats.
type MockZoektGRPCServer struct {
	// Embedded generated base type; presumably supplies default
	// implementations for the RPCs this mock does not override (confirm
	// against the generated code).
	proto.UnimplementedWebserverServiceServer
	// files is sent as the Files of the StreamSearch response.
	files  []*proto.FileMatch
	// stats is sent as the Stats of the StreamSearch response.
	stats  *proto.Stats
	// fail makes StreamSearch return an error instead of sending a response.
	fail   bool
	// errMsg is the text of the error returned when fail is true.
	errMsg string
}

// StreamSearch implements the streaming search RPC for the mock. The request
// is ignored. When the mock is configured to fail, it returns an error whose
// message is errMsg; otherwise it sends exactly one response containing the
// configured files and stats and returns the result of that send.
func (m *MockZoektGRPCServer) StreamSearch(req *proto.StreamSearchRequest, srv proto.WebserverService_StreamSearchServer) error {
	if m.fail {
		return fmt.Errorf("%s", m.errMsg)
	}

	chunk := &proto.SearchResponse{
		Files: m.files,
		Stats: m.stats,
	}
	reply := &proto.StreamSearchResponse{ResponseChunk: chunk}
	return srv.Send(reply)
}

// ToProtoFileMatches converts zoekt file matches into their protobuf
// counterparts. Only FileName, Repository, Score and Checksum are carried
// over; every other field of the resulting messages keeps its zero value.
// The returned slice is never nil and has one entry per input match, in the
// same order.
func ToProtoFileMatches(matches []zoekt.FileMatch) []*proto.FileMatch {
	converted := make([]*proto.FileMatch, len(matches))
	for i := range matches {
		src := &matches[i]
		converted[i] = &proto.FileMatch{
			FileName:   []byte(src.FileName),
			Repository: src.Repository,
			Score:      src.Score,
			Checksum:   src.Checksum,
		}
	}
	return converted
}

// StartFailingBufconnServer starts an in-memory gRPC server on a bufconn
// listener whose MockZoektGRPCServer fails every StreamSearch call with an
// error carrying errMsg, and dials a client connection to it.
//
// It returns the client connection and a synthetic endpoint of the form
// "http://bufnet-<listener pointer>". The http:// prefix keeps the endpoint
// compatible with parseEndpoint; the part after the prefix is the target the
// connection was dialed with.
//
// The server is stopped automatically when the test (and its subtests)
// finish. This function does not close the returned connection.
func StartFailingBufconnServer(t *testing.T, errMsg string) (*grpc.ClientConn, string) {
	t.Helper()

	lis := bufconn.Listen(1024 * 1024)
	s := grpc.NewServer()
	// Register the cleanup before anything can fail so the server, its Serve
	// goroutine and the listener never outlive the test.
	t.Cleanup(s.Stop)

	mock := &MockZoektGRPCServer{
		fail:   true,
		errMsg: errMsg,
	}
	proto.RegisterWebserverServiceServer(s, mock)

	go func() {
		// Serve only returns once the server is stopped; its error carries no
		// information the tests need.
		_ = s.Serve(lis)
	}()

	// Every dial goes straight to the in-memory listener, whatever the address.
	dialer := func(context.Context, string) (net.Conn, error) {
		return lis.Dial()
	}

	// The listener's pointer value makes the host name unique per server.
	host := fmt.Sprintf("bufnet-%p", lis)
	endpoint := "http://" + host

	conn, err := NewTestGrpcClient(context.Background(), host, dialer)
	require.NoError(t, err)

	return conn, endpoint
}

// mockStream is an in-memory source of stream search responses for tests:
// Recv replays the canned responses in order, or fails with err if it is set.
type mockStream struct {
	// responses are handed out one per Recv call, in slice order.
	responses []*proto.StreamSearchResponse
	// err, when non-nil, is returned by every Recv call instead of a response.
	err       error
	// index is the position in responses of the next item Recv will return.
	index     int
}

// Recv returns the next canned response and advances the cursor. If err is
// set it is returned on every call without consuming a response; once all
// responses have been delivered, Recv returns io.EOF.
func (m *mockStream) Recv() (*proto.StreamSearchResponse, error) {
	switch {
	case m.err != nil:
		return nil, m.err
	case m.index >= len(m.responses):
		return nil, io.EOF
	}

	next := m.responses[m.index]
	m.index++
	return next, nil
}

// NewTestGrpcClient creates a gRPC client connection for tests that reaches
// its peer through the supplied dialer and uses insecure (plaintext)
// transport credentials.
//
// The target is given the "passthrough:///" scheme unless it already has it,
// presumably so the name reaches the custom dialer without being resolved
// (confirm against the gRPC name-resolution docs). ctx is currently unused.
func NewTestGrpcClient(ctx context.Context, target string, dialer func(context.Context, string) (net.Conn, error)) (*grpc.ClientConn, error) {
	const scheme = "passthrough:///"
	// Trimming before prepending leaves an already-prefixed target unchanged.
	target = scheme + strings.TrimPrefix(target, scheme)

	opts := []grpc.DialOption{
		grpc.WithContextDialer(dialer),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	}
	return grpc.NewClient(target, opts...)
}

// GetHostFromEndpoint returns the host portion of an endpoint URL: a single
// leading "http://" or "https://" is removed (if present), and everything
// from the first "/" onwards is dropped. Inputs without either prefix are
// only cut at the first "/".
func GetHostFromEndpoint(endpoint string) string {
	// Strip at most one scheme prefix; https:// is tried only when http://
	// did not match.
	rest := strings.TrimPrefix(endpoint, "http://")
	if rest == endpoint {
		rest = strings.TrimPrefix(endpoint, "https://")
	}

	// Keep only what precedes the first path separator.
	host, _, _ := strings.Cut(rest, "/")
	return host
}
