package redistool

import (
	"context"
	"errors"
	"fmt"
	"iter"
	"time"
	"unsafe"

	"github.com/redis/rueidis"
	"go.opentelemetry.io/otel/attribute"
	otelmetric "go.opentelemetry.io/otel/metric"
	"google.golang.org/protobuf/proto"
)

// Tuning knobs for scanning hashes and garbage collecting their expired entries.
const (
	// scanCount is the COUNT hint sent with every HSCAN.
	scanCount = 1000

	// maxKeyGCAttempts is the number of transaction attempts per GC round of a hash, i.e. one retry.
	maxKeyGCAttempts = 2

	// maxDeleteBatchSize is the maximum number of fields removed by a single HDEL.
	maxDeleteBatchSize = 128
)

// Names of the GC metrics and of the attribute carrying the API instance's name on every measurement.
const (
	gcDeletedKeysMetricName = "redis_expiring_hash_api_gc_deleted_keys_count"
	gcConflictMetricName    = "redis_expiring_hash_api_gc_conflict"

	expiringHashNameKey attribute.Key = "expiring_hash_name"
)

// ScanEntry is one live (non-expired) element of a hash, as produced by Scan.
type ScanEntry struct {
	RawHashKey string // hash field exactly as returned by HSCAN
	Value      []byte // payload unwrapped from the stored ExpiringValue
}

// ExpiringHashAPI represents a low-level API to work with a two-level hash: key K1 -> hashKey K2 -> value []byte.
// key identifies the hash; hashKey identifies the key in the hash; value is the value for the hashKey.
type ExpiringHashAPI[K1 any, K2 any] interface {
	// IOBuilder returns a new IOBuilder for these hashes.
	IOBuilder() IOBuilder[K1, K2]

	// Scan iterates over the entries of the hash identified by key. The Redis implementation skips expired entries
	// and reports per-entry decoding failures and scan failures through the error half of the pair.
	Scan(ctx context.Context, key K1) iter.Seq2[ScanEntry, error]

	// GCFor returns a function that removes expired entries from the hashes identified by keys and reports how many
	// hash keys were deleted. transactional selects whether deletions are guarded by a transaction.
	GCFor(keys []K1, transactional bool) func(context.Context) (int /* keysDeleted */, error)
}

// RedisExpiringHashAPI implements ExpiringHashAPI on top of a rueidis client.
// Construct it with NewRedisExpiringHashAPI so the GC metrics are initialized.
type RedisExpiringHashAPI[K1 any, K2 any] struct {
	client rueidis.Client

	// Converters from the typed keys to the Redis key (K1) and hash field (K2) strings.
	key1ToRedisKey KeyToRedisKey[K1]
	key2ToRedisKey KeyToRedisKey[K2]

	// GC telemetry: deleted-keys and conflict counters, plus the options (hash name attribute) used when recording.
	gcCounter  otelmetric.Int64Counter
	gcConflict otelmetric.Int64Counter
	addOpts    []otelmetric.AddOption
}

// NewRedisExpiringHashAPI creates a RedisExpiringHashAPI that talks to Redis through client.
//
// key1ToRedisKey and key2ToRedisKey convert the typed keys to Redis strings. m is used to create the GC
// deleted-keys and conflict counters, and name is attached to their measurements as the expiring_hash_name
// attribute. An error is returned if either counter cannot be created.
func NewRedisExpiringHashAPI[K1 any, K2 any](name string, client rueidis.Client, key1ToRedisKey KeyToRedisKey[K1], key2ToRedisKey KeyToRedisKey[K2], m otelmetric.Meter) (*RedisExpiringHashAPI[K1, K2], error) {
	api := &RedisExpiringHashAPI[K1, K2]{
		client:         client,
		key1ToRedisKey: key1ToRedisKey,
		key2ToRedisKey: key2ToRedisKey,
		addOpts: []otelmetric.AddOption{
			otelmetric.WithAttributeSet(attribute.NewSet(expiringHashNameKey.String(name))),
		},
	}

	var err error
	api.gcCounter, err = m.Int64Counter(
		gcDeletedKeysMetricName,
		otelmetric.WithDescription("Number of keys that have been garbage collected in a single pass"),
	)
	if err != nil {
		return nil, err
	}

	api.gcConflict, err = m.Int64Counter(
		gcConflictMetricName,
		otelmetric.WithDescription("Number of times garbage collection was aborted due to a concurrent hash mutation"),
	)
	if err != nil {
		return nil, err
	}

	return api, nil
}

// IOBuilder returns a new redisIOBuilder sharing this API's client and key converters.
func (h *RedisExpiringHashAPI[K1, K2]) IOBuilder() IOBuilder[K1, K2] {
	b := new(redisIOBuilder[K1, K2])
	b.client = h.client
	b.key1ToRedisKey = h.key1ToRedisKey
	b.key2ToRedisKey = h.key2ToRedisKey
	return b
}

// Scan returns an iterator over the live entries of the hash identified by key.
//
// Each time the iterator is ranged over, the hash is walked with HSCAN (COUNT scanCount). The current Unix time is
// sampled once when iteration starts, and entries whose ExpiresAt is before it are skipped. A value that fails to
// decode as ExpiringValue is reported as (ScanEntry{}, err) and iteration continues. A scan failure is reported
// once, after the entries that were scanned successfully.
func (h *RedisExpiringHashAPI[K1, K2]) Scan(ctx context.Context, key K1) iter.Seq2[ScanEntry, error] {
	return func(yield func(ScanEntry, error) bool) {
		now := time.Now().Unix()
		redisKey := h.key1ToRedisKey(key)
		scanner := rueidis.NewScanner(func(cursor uint64) (rueidis.ScanEntry, error) {
			cmd := h.client.B().Hscan().Key(redisKey).Cursor(cursor).Count(scanCount).Build()
			return h.client.Do(ctx, cmd).AsScanEntry()
		})
		for hashKey, raw := range scanner.Iter2() {
			var ev ExpiringValue
			// Reinterpret the string's bytes instead of copying them; the slice must not be mutated.
			buf := unsafe.Slice(unsafe.StringData(raw), len(raw)) //nolint: gosec
			switch err := proto.Unmarshal(buf, &ev); {
			case err != nil:
				if !yield(ScanEntry{}, fmt.Errorf("failed to unmarshal hash value from hashkey 0x%x: %w", hashKey, err)) {
					return
				}
			case ev.ExpiresAt < now:
				// Expired entries are invisible to readers.
			default:
				if !yield(ScanEntry{RawHashKey: hashKey, Value: ev.Value}, nil) {
					return
				}
			}
		}
		if err := scanner.Err(); err != nil {
			yield(ScanEntry{}, err)
		}
	}
}

// GCFor returns a function that garbage collects expired entries from the hashes identified by keys.
//
// When transactional is true, hashes are cleaned with transactions on a dedicated connection so that a value
// overwritten concurrently is not deleted; a hash whose transactions keep conflicting is skipped. Otherwise expired
// fields are removed with plain HDEL batches. The returned function reports the number of deleted hash keys. That
// number, including a partial count accompanying an error, is also added to the GC deleted-keys counter.
func (h *RedisExpiringHashAPI[K1, K2]) GCFor(keys []K1, transactional bool) func(context.Context) (int /* keysDeleted */, error) {
	return func(ctx context.Context) (deletedKeys int, retErr error) {
		// ctx is deliberately not used for recording (see nolint). Presumably this ensures the count is still
		// recorded when ctx is already done.
		defer func() { //nolint:contextcheck
			h.gcCounter.Add(context.Background(), int64(deletedKeys), h.addOpts...)
		}()

		if transactional {
			return h.gcForTransactional(ctx, keys)
		}
		return h.gcForNonTransactional(ctx, keys)
	}
}

// gcForNonTransactional runs gcHashNonTransactional over keys in order and sums the deleted counts.
// It stops at the first hash that reports an error; the returned count still includes that hash's deletions.
func (h *RedisExpiringHashAPI[K1, K2]) gcForNonTransactional(ctx context.Context, keys []K1) (int /* keysDeleted */, error) {
	total := 0
	for i := range keys {
		n, err := h.gcHashNonTransactional(ctx, keys[i])
		total += n
		if err != nil {
			return total, err
		}
	}
	return total, nil
}

// gcHashNonTransactional iterates a hash and removes all expired values.
// It assumes that values are marshaled ExpiringValue.
//
// Expired fields are removed with HDEL in batches of at most maxDeleteBatchSize fields. Deletion is not guarded by a
// transaction, so a field rewritten by another client between the scan and the HDEL may still be deleted.
// Undecodable values and failed HDELs are collected and the pass continues; all collected errors are returned
// joined. The returned count is the number of fields that successful HDELs reported as removed.
func (h *RedisExpiringHashAPI[K1, K2]) gcHashNonTransactional(ctx context.Context, key K1) (int /* keysDeleted */, error) {
	redisKey := h.key1ToRedisKey(key)
	keysDeleted := 0
	var errs []error
	var batch []string
	sendBatch := func() {
		delCmd := h.client.B().Hdel().Key(redisKey).Field(batch...).Build()
		// HDEL replies with the number of fields it actually removed. Count only that number. On error the outcome
		// is unknown, so nothing is counted; previously failed batches inflated the count and the GC metric.
		removed, err := h.client.Do(ctx, delCmd).AsInt64()
		if err != nil {
			errs = append(errs, err)
		} else {
			keysDeleted += int(removed)
		}
		batch = batch[:0]
	}
	for keyToDelete, err := range getKeysToGC(ctx, redisKey, h.client) {
		if err != nil {
			errs = append(errs, err)
			continue
		}
		batch = append(batch, keyToDelete)
		if len(batch) >= maxDeleteBatchSize {
			sendBatch()
		}
	}
	if len(batch) > 0 {
		sendBatch()
	}
	return keysDeleted, errors.Join(errs...)
}

// gcForTransactional GCs each hash in keys with gcHashTransactional, sharing one dedicated connection that is
// released on return. A hash that fails with errAttemptsExceeded (persistent conflicts) is skipped and GC moves on
// to the next one; any other error aborts the pass. The returned count includes deletions made before an abort.
func (h *RedisExpiringHashAPI[K1, K2]) gcForTransactional(ctx context.Context, keys []K1) (int /* keysDeleted */, error) {
	dc, release := h.client.Dedicate()
	defer release()

	total := 0
	for _, key := range keys {
		n, err := h.gcHashTransactional(ctx, key, dc)
		total += n
		if err != nil && err != errAttemptsExceeded { //nolint:errorlint
			return total, err
		}
		// nil or a conflict on this key: continue with the next key.
	}
	return total, nil
}

// gcHashTransactional iterates a hash and removes all expired values.
// It assumes that values are marshaled ExpiringValue.
//
// Work is done in rounds. Each round runs transaction() watching redisKey. Its body scans the hash from the start,
// collects up to maxDeleteBatchSize expired fields and returns an HDEL for them. Rounds repeat until a scan
// completes without reaching the batch limit.
// Returns errAttemptsExceeded if maxKeyGCAttempts attempts were made but all failed; any other transaction error is
// returned as is. Otherwise the scan errors (e.g. undecodable values) observed by the last executed transaction
// body are returned joined.
func (h *RedisExpiringHashAPI[K1, K2]) gcHashTransactional(ctx context.Context, key K1, c rueidis.DedicatedClient) (int /* keysDeleted */, error) {
	// Scan errors seen by the most recent invocation of the transaction body. Every round rescans the hash from the
	// beginning, so accumulating across rounds/attempts would report the same undecodable value repeatedly.
	var errs []error
	keysDeletedTotal := 0
	redisKey := h.key1ToRedisKey(key)
	finished := false
	for !finished {
		keysDeleted := 0
		// We don't want to delete a k->v mapping that has just been overwritten by another client. So use a transaction.
		// We don't want to retry too many times to GC to avoid spending too much time on it. Retry once.
		err := transaction(ctx, maxKeyGCAttempts, c, h.gcConflict, h.addOpts, func(ctx context.Context) ([]rueidis.Completed, error) {
			// This body may run more than once per round (presumably once per attempt when a conflict forces a
			// retry). Reset everything it communicates to the outer loop so nothing leaks from an aborted attempt.
			// A stale keysDeleted would be counted although nothing was deleted. A stale finished=true would end GC
			// although this attempt stopped at the batch limit. Stale errs would be duplicated.
			keysDeleted = 0
			finished = false
			errs = nil

			var keysToDelete []string
			prepareCommands := func() []rueidis.Completed {
				keysDeleted = len(keysToDelete)
				return []rueidis.Completed{
					c.B().Hdel().Key(redisKey).Field(keysToDelete...).Build(),
				}
			}
			for keyToDelete, err := range getKeysToGC(ctx, redisKey, c) {
				if err != nil {
					errs = append(errs, err)
					continue
				}
				keysToDelete = append(keysToDelete, keyToDelete)
				if len(keysToDelete) >= maxDeleteBatchSize {
					// Batch is full: delete it now; another round will continue the GC.
					return prepareCommands(), nil
				}
			}
			finished = true
			if len(keysToDelete) == 0 {
				return nil, nil // errs is handled outside the closure
			}
			return prepareCommands(), nil
		}, redisKey)
		if err != nil {
			// Propagate errAttemptsExceeded and any other errors as is.
			return keysDeletedTotal, err
		}
		keysDeletedTotal += keysDeleted
	}
	return keysDeletedTotal, errors.Join(errs...)
}

// getKeysToGC returns an iterator over the fields of the hash at redisKey whose values have expired.
//
// The hash is walked with HSCAN (COUNT scanCount). Values are decoded as ExpiringValueTimestamp with unknown fields
// discarded. A field is expired when its ExpiresAt is strictly before the current Unix time, which is sampled once
// when iteration starts. An undecodable value is yielded as (field, err), where err names the field so it survives
// callers that drop the field on error. A scan failure is yielded once at the end as ("", err).
func getKeysToGC(ctx context.Context, redisKey string, c rueidis.CoreClient) iter.Seq2[string, error] {
	// Loop-invariant decoding options.
	unmarshalOpts := proto.UnmarshalOptions{
		DiscardUnknown: true, // We know there is one more field, but we don't need it
	}
	return func(yield func(string, error) bool) {
		now := time.Now().Unix()
		s := rueidis.NewScanner(func(cursor uint64) (rueidis.ScanEntry, error) {
			hscanCmd := c.B().Hscan().Key(redisKey).Cursor(cursor).Count(scanCount).Build()
			return c.Do(ctx, hscanCmd).AsScanEntry()
		})
		for k, v := range s.Iter2() {
			var msg ExpiringValueTimestamp
			// Avoid creating a temporary copy
			vBytes := unsafe.Slice(unsafe.StringData(v), len(v)) //nolint: gosec
			if err := unmarshalOpts.Unmarshal(vBytes, &msg); err != nil {
				// Same message as Scan uses, so the offending field is identifiable from the error alone.
				if !yield(k, fmt.Errorf("failed to unmarshal hash value from hashkey 0x%x: %w", k, err)) {
					return
				}
				continue
			}
			if msg.ExpiresAt < now && !yield(k, nil) {
				return
			}
		}
		if err := s.Err(); err != nil {
			yield("", err)
		}
	}
}
