package search

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"sort"
	"strings"
	"sync"
	"time"

	"gitlab.com/gitlab-org/gitlab-zoekt-indexer/internal/authentication"
	"google.golang.org/grpc"
)

const (
	// defaultSearchTimeout is the duration string NewSearchRequest falls back
	// to when the request does not carry a timeout.
	defaultSearchTimeout = "60s"

	// defaultMaxLineMatchWindow is the MaxLineMatchWindow NewSearchRequest
	// applies when the request leaves that field at zero.
	defaultMaxLineMatchWindow uint32 = 5000
)

// forwardedHeaders lists the headers that NewSearchRequest copies from the
// incoming request into the header options of every forward-to connection.
var forwardedHeaders = []string{
	"Authorization",
	authentication.GitlabZoektAPIRequestHeader,
	"x-forwarded-path",
}

// Search builds a SearchRequest from r and runs it against every forward-to
// connection, returning the merged result.
//
// An error is returned when the request cannot be built or when the
// multi-node search fails; both are wrapped with context.
func (searcher *Searcher) Search(r *http.Request) (*SearchResult, error) {
	searchReq, buildErr := NewSearchRequest(r)
	if buildErr != nil {
		return nil, fmt.Errorf("error building search request: %w", buildErr)
	}

	merged, searchErr := searcher.multiNodeSearch(r.Context(), searchReq)
	if searchErr != nil {
		return nil, fmt.Errorf("multi node search failed: %w", searchErr)
	}

	return merged, nil
}

// NewSearcher returns a Searcher whose Client is client and whose GrpcConns
// map is initialized and empty.
func NewSearcher(client *http.Client) *Searcher {
	s := new(Searcher)
	s.Client = client
	s.GrpcConns = map[string]*grpc.ClientConn{}

	return s
}

// NewSearchRequest builds a validated SearchRequest from the JSON body of r.
//
// The body is decoded into a RawSearchRequest and converted with
// ToSearchRequest. The request must name at least one forward-to connection,
// and every connection must have a non-empty endpoint. An empty timeout string
// falls back to defaultSearchTimeout and a zero MaxLineMatchWindow falls back
// to defaultMaxLineMatchWindow. Headers listed in forwardedHeaders that are
// present on r are copied into the header options of every connection.
//
// An error is returned when the body cannot be read or parsed, when the
// conversion or validation fails, or when the timeout is not a positive
// duration.
func NewSearchRequest(r *http.Request) (*SearchRequest, error) {
	body, err := io.ReadAll(r.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %w", err)
	}

	var raw RawSearchRequest
	err = json.Unmarshal(body, &raw) // nolint:musttag
	if err != nil {
		return nil, fmt.Errorf("failed to parse request body: %w", err)
	}

	req, err := raw.ToSearchRequest()
	if err != nil {
		return nil, fmt.Errorf("failed to convert raw request to SearchRequest: %w", err)
	}

	if len(req.ForwardTo) == 0 {
		return nil, errors.New("no forward-to connections specified")
	}

	for _, conn := range req.ForwardTo {
		if conn.Endpoint == "" {
			return nil, errors.New("forward-to endpoint is empty")
		}
	}

	if req.TimeoutString == "" {
		req.TimeoutString = defaultSearchTimeout
	}

	timeout, err := time.ParseDuration(req.TimeoutString)
	if err != nil {
		return nil, fmt.Errorf("failed to parse Timeout: %v with error %w", req.TimeoutString, err)
	}

	// A zero or negative timeout would create a context that has already
	// expired, so the search could never succeed. Reject it here instead of
	// failing later with a misleading timeout error.
	if timeout <= 0 {
		return nil, fmt.Errorf("invalid Timeout: %v must be a positive duration", req.TimeoutString)
	}

	req.Timeout = timeout

	if req.MaxLineMatchWindow == 0 {
		req.MaxLineMatchWindow = defaultMaxLineMatchWindow
	}

	// Resolve the forwarded header values once; they are the same for every
	// connection. Multiple values of one header are joined with commas.
	forwarded := make(map[string]string, len(forwardedHeaders))
	for _, header := range forwardedHeaders {
		if values := r.Header.Values(header); len(values) > 0 {
			forwarded[header] = strings.Join(values, ",")
		}
	}

	// Forward the selected headers from the incoming request to each endpoint.
	// A forwarded header replaces any value stored under the same key in the
	// endpoint options; headers under other keys are left untouched.
	// NOTE(review): confirm that replacing (rather than keeping) an
	// endpoint-specific value is the intended precedence.
	for i := range req.ForwardTo {
		conn := &req.ForwardTo[i]

		// Options and Headers are always initialized, even when no header is
		// forwarded.
		if conn.Options == nil {
			conn.Options = &EndpointOptions{}
		}
		if conn.Options.Headers == nil {
			conn.Options.Headers = make(map[string]string)
		}

		for header, value := range forwarded {
			conn.Options.Headers[header] = value
		}
	}

	return req, nil
}

// multiNodeSearch fans the request out to every connection in s.ForwardTo and
// merges the responses into one SearchResult.
//
// Each connection is queried in its own goroutine under a context bounded by
// s.Timeout. An error is returned when the merge was flagged as timed out
// without collecting any file, or when the number of failures equals the
// number of connections. Otherwise the files are sorted, truncated according
// to MaxFileMatchResults, MaxLineMatchResults and MaxLineMatchResultsPerFile,
// and the stats are recalculated before the result is returned.
func (searcher *Searcher) multiNodeSearch(ctx context.Context, s *SearchRequest) (*SearchResult, error) {
	searchCtx, stop := context.WithTimeout(ctx, s.Timeout)
	defer stop()

	nodeCount := len(s.ForwardTo)
	// One buffer slot per connection.
	responses := make(chan ZoektResponse, nodeCount)

	var inFlight sync.WaitGroup
	inFlight.Add(nodeCount)
	for idx := range s.ForwardTo {
		go func(idx int) {
			defer inFlight.Done()

			resp := searcher.DoSearch(searchCtx, s, &s.ForwardTo[idx])
			// Give up on delivering the response once the search context ends.
			select {
			case <-searchCtx.Done():
			case responses <- resp:
			}
		}(idx)
	}

	// Close the channel after the last worker finishes so the merge loop can
	// observe the end of the stream.
	go func() {
		inFlight.Wait()
		close(responses)
	}()

	merged := combineResults(searchCtx, responses, s, stop)

	if merged.TimedOut && len(merged.Result.Files) == 0 {
		return nil, errors.New("search timed out")
	}

	if len(merged.Failures) == nodeCount {
		failures := make([]error, 0, nodeCount)
		for _, failed := range merged.Failures {
			failures = append(failures, fmt.Errorf("%s: %s", failed.Endpoint, failed.Error))
		}

		return nil, fmt.Errorf("all searches failed: %w", errors.Join(failures...))
	}

	merged.sort()

	if s.MaxFileMatchResults > 0 && len(merged.Result.Files) > int(s.MaxFileMatchResults) {
		merged.Result.Files = merged.Result.Files[:s.MaxFileMatchResults]
	}

	if s.MaxLineMatchResults > 0 {
		budget := int(s.MaxLineMatchResults)
		perFileCap := int(s.MaxLineMatchResultsPerFile)
		files := merged.Result.Files

		for idx := range files {
			// Number of line matches this file may keep: bounded by the
			// per-file cap (when set) and by what is left of the overall budget.
			keep := len(files[idx].LineMatches)
			if perFileCap > 0 && keep > perFileCap {
				keep = perFileCap
			}
			if keep > budget {
				keep = budget
			}

			files[idx].LineMatches = files[idx].LineMatches[:keep]
			budget -= keep

			// Budget exhausted: drop every file after this one.
			if budget <= 0 {
				merged.Result.Files = files[:idx+1]
				break
			}
		}
	}

	merged.Result.CalculateStats()

	return merged, nil
}

// combineResults reads responses from ch and merges them into a single
// SearchResult until ch is closed, ctx is done, or a collection window fills.
//
// Failed responses are recorded in Failures. For successful responses the
// FileCount, MatchCount and NgramMatches counters are summed, and each file is
// appended unless a file with the same name, repository and checksum was
// already collected. When the number of collected files reaches
// s.MaxFileMatchWindow, or the number of collected line matches reaches
// s.MaxLineMatchWindow (each only when greater than zero), cancel is called
// and the result gathered so far is returned. When ctx is done, for any
// reason, the result is returned with TimedOut set.
func combineResults(ctx context.Context, ch <-chan ZoektResponse, s *SearchRequest, cancel context.CancelFunc) *SearchResult {
	merged := new(SearchResult)
	seen := map[string]bool{}
	fileTotal, lineTotal := 0, 0

	// windowReached reports whether either collection window has been filled.
	windowReached := func() bool {
		if s.MaxFileMatchWindow > 0 && fileTotal >= int(s.MaxFileMatchWindow) {
			return true
		}

		return s.MaxLineMatchWindow > 0 && lineTotal >= int(s.MaxLineMatchWindow)
	}

	for {
		select {
		case <-ctx.Done():
			merged.TimedOut = true
			return merged

		case resp, open := <-ch:
			if !open {
				return merged
			}

			if resp.Error != nil {
				merged.Failures = append(merged.Failures, SearchFailure{
					Endpoint: resp.Endpoint,
					Error:    resp.Error.Error(),
				})
				continue
			}

			found := resp.Result
			merged.Result.FileCount += uint32(found.FileCount)       // #nosec G115
			merged.Result.MatchCount += uint32(found.MatchCount)     // #nosec G115
			merged.Result.NgramMatches += uint32(found.NgramMatches) // #nosec G115

			for _, file := range found.Files {
				// Skip files already collected from an earlier response.
				key := fmt.Sprintf("%s:%s:%x", file.FileName, file.Repository, file.Checksum)
				if seen[key] {
					continue
				}
				seen[key] = true

				lineTotal += len(file.LineMatches)
				fileTotal++
				merged.Result.Files = append(merged.Result.Files, file)

				if windowReached() {
					cancel()
					return merged
				}
			}
		}
	}
}

// sort orders the result files in place by descending Score. The sort is not
// stable, so files with equal scores may end up in any relative order.
func (s *SearchResult) sort() {
	files := s.Result.Files
	byScoreDesc := func(a, b int) bool {
		return files[b].Score < files[a].Score
	}

	sort.Slice(files, byScoreDesc)
}
