package opensearch

import (
	"context"
	"fmt"
	"log/slog"
	"net/http"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/defaults"
	"github.com/aws/aws-sdk-go/aws/session"
	v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
	"github.com/deoxxa/aws_signing_client"
	"github.com/olivere/elastic/v7"

	"gitlab.com/gitlab-org/gitlab-elasticsearch-indexer/internal/mode/chunk/types"
)

// OpenSearchClient pairs the connection settings for an OpenSearch cluster
// with the elastic client that Connect builds from them.
type OpenSearchClient struct {
	Conn   types.OpenSearchConnection // settings (URLs, AWS options, timeout) read by Connect
	Client *elastic.Client            // nil until Connect creates it
}

// GetClient returns the elastic client stored on c. It is nil until Connect
// has constructed a client.
func (c *OpenSearchClient) GetClient() *elastic.Client {
	client := c.Client
	return client
}

// New returns an OpenSearchClient holding conn. It performs no network
// activity; call Connect to build and verify the underlying client.
func New(conn types.OpenSearchConnection) *OpenSearchClient {
	oc := new(OpenSearchClient)
	oc.Conn = conn
	return oc
}

// Connect builds the elastic client described by c.Conn, stores it in
// c.Client, and verifies connectivity by pinging the first configured URL.
//
// When c.Conn.AWS is true, requests are signed for the "es" service in
// c.Conn.Region using credentials from resolveAWSCredentials. A non-zero
// c.Conn.RequestTimeout is applied, in seconds, as the HTTP client timeout.
// Sniffing and health checks are disabled.
//
// It returns an error when AWS is enabled without a region, when no URL is
// configured, when credentials or the client cannot be created, or when the
// ping fails or answers with a status >= 400.
func (c *OpenSearchClient) Connect(ctx context.Context) error {
	slog.Debug("connecting to OpenSearch", "aws", c.Conn.AWS, "urls", c.Conn.URL)

	if c.Conn.AWS && c.Conn.Region == "" {
		return fmt.Errorf("aws_region is required when aws is enabled")
	}
	// The ping below targets URL[0]. Without this guard an empty list would
	// panic, because elastic.SetURL() silently falls back to its default URL.
	if len(c.Conn.URL) == 0 {
		return fmt.Errorf("at least one OpenSearch URL is required")
	}

	var opts []elastic.ClientOptionFunc

	httpClient := &http.Client{}
	if c.Conn.RequestTimeout != 0 {
		httpClient.Timeout = time.Duration(c.Conn.RequestTimeout) * time.Second
	}

	if c.Conn.AWS {
		awsConfig := defaults.Config().WithRegion(c.Conn.Region)
		// Named creds so the imported credentials package is not shadowed.
		creds, err := c.resolveAWSCredentials(awsConfig)
		if err != nil {
			return fmt.Errorf("failed to resolve AWS credentials: %w", err)
		}
		signer := v4.NewSigner(creds)
		awsClient, err := aws_signing_client.New(signer, httpClient, "es", c.Conn.Region)
		if err != nil {
			return fmt.Errorf("failed to create AWS signing client: %w", err)
		}
		opts = append(opts, elastic.SetHttpClient(awsClient))
	} else if c.Conn.RequestTimeout != 0 {
		// Without a timeout there is nothing to customize, so elastic keeps
		// its own default HTTP client.
		opts = append(opts, elastic.SetHttpClient(httpClient))
	}

	// Sniffer should look for HTTPS URLs if at least one initial URL is HTTPS.
	// URL schemes are case-insensitive, so "HTTPS://" must match as well.
	const httpsPrefix = "https:"
	for _, u := range c.Conn.URL {
		if len(u) >= len(httpsPrefix) && strings.EqualFold(u[:len(httpsPrefix)], httpsPrefix) {
			opts = append(opts, elastic.SetScheme("https"))
			break
		}
	}

	opts = append(opts,
		elastic.SetURL(c.Conn.URL...),
		elastic.SetSniff(false),
		elastic.SetHealthcheck(false),
	)

	client, err := elastic.NewClient(opts...)
	if err != nil {
		return fmt.Errorf("failed to create OpenSearch client: %w", err)
	}

	c.Client = client

	// Temporary: call ping to verify the connection
	info, code, err := client.Ping(c.Conn.URL[0]).Do(ctx)
	if err != nil {
		return fmt.Errorf("failed to ping OpenSearch: %w", err)
	}
	if code >= 400 {
		return fmt.Errorf("OpenSearch ping failed with status %d", code)
	}

	slog.Info("OpenSearch connection successful",
		"status", code,
		"version", info.Version.Number,
		"cluster_name", info.ClusterName,
	)

	return nil
}

// resolveAWSCredentials chooses the credentials used to sign OpenSearch
// requests.
//
// With c.Conn.RoleARN set, it delegates to createAssumeRoleCredentials.
// Otherwise it returns a chain that tries the static
// c.Conn.AccessKey/c.Conn.SecretKey pair first, then the SDK default providers
// for awsConfig. Per the SDK's chain semantics, a provider that fails (such as
// an empty static pair) is skipped in favor of the next one.
func (c *OpenSearchClient) resolveAWSCredentials(awsConfig *aws.Config) (*credentials.Credentials, error) {
	if c.Conn.RoleARN == "" {
		// Fall back to the original credential chain
		static := &credentials.StaticProvider{
			Value: credentials.Value{
				AccessKeyID:     c.Conn.AccessKey,
				SecretAccessKey: c.Conn.SecretKey,
			},
		}
		chain := append([]credentials.Provider{static}, defaults.CredProviders(awsConfig, defaults.Handlers())...)
		return credentials.NewChainCredentials(chain), nil
	}

	slog.Debug("using AWS AssumeRole authentication", "roleArn", c.Conn.RoleARN)
	return c.createAssumeRoleCredentials(awsConfig)
}

// createAssumeRoleCredentials returns STS credentials that assume
// c.Conn.RoleARN (via stscreds.NewCredentials).
//
// The STS calls are authenticated with a base identity. When both
// c.Conn.AccessKey and c.Conn.SecretKey are non-empty, that pair is used as
// static credentials. Otherwise the SDK default provider chain for awsConfig
// is used. An error is returned only if the AWS session cannot be created.
func (c *OpenSearchClient) createAssumeRoleCredentials(awsConfig *aws.Config) (*credentials.Credentials, error) {
	sess, err := session.NewSession(awsConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to create AWS session: %w", err)
	}

	switch {
	case c.Conn.AccessKey != "" && c.Conn.SecretKey != "":
		sess.Config.Credentials = credentials.NewStaticCredentials(c.Conn.AccessKey, c.Conn.SecretKey, "")
		slog.Debug("using static credentials as base for role assumption")
	default:
		sess.Config.Credentials = credentials.NewChainCredentials(
			defaults.CredProviders(awsConfig, defaults.Handlers()),
		)
		slog.Debug("using default credential chain as base for role assumption")
	}

	// The session's base credentials sign the AssumeRole calls; the returned
	// value holds the role's temporary credentials.
	return stscreds.NewCredentials(sess, c.Conn.RoleARN), nil
}
