package grpctool

import (
	"context"
	"fmt"
	"net"
	"net/url"

	"github.com/ash2k/stager"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/tlstool"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/status"
)

// RequestCanceledOrTimedOut reports whether err signals either cancellation (see RequestCanceled) or a
// deadline being exceeded (see RequestTimedOut). The timeout check only runs when the cancellation check fails.
func RequestCanceledOrTimedOut(err error) bool {
	if RequestCanceled(err) {
		return true
	}
	return RequestTimedOut(err)
}

// RequestCanceled reports whether err, or any error reachable through its wrap chain, signals cancellation.
// An error matches if it is context.Canceled itself (identity comparison, not errors.Is) or if status.Code
// returns codes.Canceled for it.
//
// Both single wrappers (Unwrap() error) and multi-error wrappers (Unwrap() []error, e.g. from errors.Join)
// are followed. For a multi-error wrapper, it is enough for any one of the wrapped errors to match.
func RequestCanceled(err error) bool {
	if err == nil {
		return false
	}
	if err == context.Canceled || status.Code(err) == codes.Canceled { //nolint:errorlint
		return true
	}
	switch wrapper := err.(type) { //nolint:errorlint
	case interface{ Unwrap() error }:
		return RequestCanceled(wrapper.Unwrap())
	case interface{ Unwrap() []error }: // support errors produced by errors.Join()
		for _, inner := range wrapper.Unwrap() {
			if RequestCanceled(inner) {
				return true
			}
		}
	}
	return false
}

// RequestTimedOut reports whether err, or any error reachable through its wrap chain, signals that a deadline
// was exceeded. An error matches if it is context.DeadlineExceeded itself (identity comparison, not errors.Is)
// or if status.Code returns codes.DeadlineExceeded for it.
//
// Both single wrappers (Unwrap() error) and multi-error wrappers (Unwrap() []error, e.g. from errors.Join)
// are followed. For a multi-error wrapper, it is enough for any one of the wrapped errors to match.
func RequestTimedOut(err error) bool {
	if err == nil {
		return false
	}
	if err == context.DeadlineExceeded || status.Code(err) == codes.DeadlineExceeded { //nolint:errorlint
		return true
	}
	switch wrapper := err.(type) { //nolint:errorlint
	case interface{ Unwrap() error }:
		return RequestTimedOut(wrapper.Unwrap())
	case interface{ Unwrap() []error }: // support errors produced by errors.Join()
		for _, inner := range wrapper.Unwrap() {
			if RequestTimedOut(inner) {
				return true
			}
		}
	}
	return false
}

// AugmentErrorMessage returns a gRPC status error whose message is prefixed with msg.
//
// If err itself implements GRPCStatus() *status.Status and that method yields a usable status, the returned
// error carries the same status proto (code and details) with its message replaced by "msg: <original message>".
// Otherwise, the result is a new status error with defaultCode and message "msg: <err>". This includes the case
// where GRPCStatus() returns nil. Only err itself is checked; status errors wrapped inside err are not unwrapped.
func AugmentErrorMessage(defaultCode codes.Code, msg string, err error) error {
	if statusGetter, ok := err.(interface {
		GRPCStatus() *status.Status
	}); ok {
		// Proto() returns nil for a nil *status.Status. Previously that nil was dereferenced (panic);
		// now such errors fall through to the defaultCode path below.
		// NOTE: relies on Proto() returning a copy, so editing Message does not mutate err's status.
		if s := statusGetter.GRPCStatus().Proto(); s != nil {
			s.Message = fmt.Sprintf("%s: %s", msg, s.Message)
			return status.ErrorProto(s)
		}
	}
	return status.Errorf(defaultCode, "%s: %v", msg, err)
}

// StartServer registers two functions with stage:
//   - one that obtains a listener via listener(ctx) and runs server.Serve on it, returning the listener
//     error or Serve's result;
//   - one that runs once the stage is done and calls beforeStop, server.GracefulStop, then afterStop,
//     in that order.
//
// The shutdown function can be triggered because Serve() failed, the main ctx was canceled, or some stage
// failed. beforeStop and afterStop are called unconditionally, so they must be non-nil.
func StartServer(stage stager.Stage, server *grpc.Server, listener func(context.Context) (net.Listener, error), beforeStop, afterStop func()) {
	serve := func(ctx context.Context) error {
		lis, err := listener(ctx)
		if err != nil {
			return err
		}
		return server.Serve(lis)
	}
	shutdown := func() error {
		beforeStop()
		server.GracefulStop()
		afterStop()
		return nil
	}
	stage.Go(serve)
	stage.GoWhenDone(shutdown)
}

// IsStatusError reports whether err itself has a GRPCStatus() *status.Status method.
// Wrapped errors are not inspected, so an error that merely wraps a status error yields false.
func IsStatusError(err error) bool {
	type grpcStatusProvider interface {
		GRPCStatus() *status.Status
	}
	_, implements := err.(grpcStatusProvider)
	return implements
}

// StatusErrorFromContext builds a gRPC status error from ctx.Err(), in the spirit of
// status.FromContextError(ctx.Err()).Err(), but with msg prepended to the message ("msg: <ctx error>").
//
// context.Canceled maps to codes.Canceled and context.DeadlineExceeded maps to codes.DeadlineExceeded.
// Anything else maps to codes.Unknown, including a nil ctx.Err() (a context that is not done yet).
// The result is therefore always a non-nil error.
func StatusErrorFromContext(ctx context.Context, msg string) error {
	ctxErr := ctx.Err()
	code := codes.Unknown
	switch ctxErr {
	case context.Canceled:
		code = codes.Canceled
	case context.DeadlineExceeded:
		code = codes.DeadlineExceeded
	}
	return status.Errorf(code, "%s: %v", msg, ctxErr)
}

// HostWithPort returns u's host in "host:port" form.
//
// If u already specifies a port, u.Host is returned unchanged. Otherwise, a default port is added for the
// "grpc" (80) and "grpcs" (443) schemes. For any other scheme without a port, u.Host is returned unchanged.
//
// The bare host name (u.Hostname()) is used when adding the port. That way IPv6 literals are bracketed exactly
// once ("grpc://[::1]" -> "[::1]:80") and an empty trailing port ("grpc://host:") becomes "host:80".
func HostWithPort(u *url.URL) string {
	if u.Port() != "" {
		return u.Host
	}
	var defaultPort string
	switch u.Scheme {
	case "grpc":
		defaultPort = "80"
	case "grpcs":
		defaultPort = "443"
	default:
		// Function called with unknown scheme, just return the original host.
		return u.Host
	}
	// u.Host may still contain "[...]" or a trailing ":". JoinHostPort adds brackets itself,
	// so joining u.Host would yield "[[::1]]:80" or "[host:]:80".
	return net.JoinHostPort(u.Hostname(), defaultPort)
}

// MaybeTLSCreds returns gRPC server options that enable TLS using certFile and keyFile.
//
// It returns (nil, err) if tlstool.MaybeServerConfig fails, and (nil, nil) if that call yields no config
// (presumably when TLS is not configured — confirm against tlstool). Otherwise, it returns a single
// grpc.Creds option wrapping the TLS config.
func MaybeTLSCreds(certFile, keyFile string) ([]grpc.ServerOption, error) {
	tlsConfig, err := tlstool.MaybeServerConfig(certFile, keyFile)
	switch {
	case err != nil:
		return nil, err
	case tlsConfig == nil:
		return nil, nil
	default:
		return []grpc.ServerOption{grpc.Creds(credentials.NewTLS(tlsConfig))}, nil
	}
}

// MaybeMTLSCreds returns gRPC server options that enable TLS, optionally with client certificate verification
// against mtlsClientCAFile when mtlsEnabled is set (as interpreted by tlstool.MaybeServerConfigWithMTLS).
//
// It returns (nil, err) on a config error and (nil, nil) when no TLS config is produced. Otherwise, it returns
// a single grpc.Creds option wrapping the TLS config.
func MaybeMTLSCreds(certFile, keyFile, mtlsClientCAFile string, mtlsEnabled bool) ([]grpc.ServerOption, error) {
	tlsConfig, err := tlstool.MaybeServerConfigWithMTLS(certFile, keyFile, mtlsClientCAFile, mtlsEnabled)
	if err != nil || tlsConfig == nil {
		// err is nil in the "no config" case, so this yields (nil, nil) there.
		return nil, err
	}
	creds := credentials.NewTLS(tlsConfig)
	return []grpc.ServerOption{grpc.Creds(creds)}, nil
}
