import logging
import uuid

from django.conf import settings
from django.contrib import auth, messages
from django.contrib.auth.middleware import RemoteUserMiddleware as RemoteUserMiddleware_
from django.core.exceptions import ImproperlyConfigured
from django.db import connection, ProgrammingError
from django.db.utils import InternalError
from django.http import Http404, HttpResponseRedirect
from django_prometheus import middleware

from netbox.config import clear_config, get_config
from netbox.metrics import Metrics
from netbox.views import handler_500
from utilities.api import is_api_request, is_graphql_request
from utilities.error_handlers import handle_rest_api_exception
from utilities.request import apply_request_processors

__all__ = (
    'CoreMiddleware',
    'MaintenanceModeMiddleware',
    'PrometheusAfterMiddleware',
    'PrometheusBeforeMiddleware',
    'RemoteUserMiddleware',
)


class CoreMiddleware:

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):

        # Assign a random unique ID to the request. This will be used for change logging.
        request.id = uuid.uuid4()

        # Apply all registered request processors
        with apply_request_processors(request):
            response = self.get_response(request)

        # Check if language cookie should be renewed
        if request.user.is_authenticated and settings.SESSION_SAVE_EVERY_REQUEST:
            if language := request.user.config.get('locale.language'):
                response.set_cookie(
                    key=settings.LANGUAGE_COOKIE_NAME,
                    value=language,
                    max_age=request.session.get_expiry_age(),
                    secure=settings.SESSION_COOKIE_SECURE,
                )

        # Attach the unique request ID as an HTTP header.
        response['X-Request-ID'] = request.id

        # Enable the Vary header to help with caching of HTMX responses
        response['Vary'] = 'HX-Request'

        # If this is an API request, attach an HTTP header annotating the API version (e.g. '3.5').
        if is_api_request(request):
            response['API-Version'] = settings.REST_FRAMEWORK_VERSION

        # Clear any cached dynamic config parameters after each request.
        clear_config()

        return response

    def process_exception(self, request, exception):
        """
        Implement custom error handling logic for production deployments.
        """
        # Don't catch exceptions when in debug mode
        if settings.DEBUG:
            return

        # Cleanly handle exceptions that occur from REST API requests
        if is_api_request(request):
            return handle_rest_api_exception(request)

        # Ignore Http404s (defer to Django's built-in 404 handling)
        if isinstance(exception, Http404):
            return

        # Determine the type of exception. If it's a common issue, return a custom error page with instructions.
        custom_template = None
        if isinstance(exception, ProgrammingError):
            custom_template = 'exceptions/programming_error.html'
        elif isinstance(exception, ImportError):
            custom_template = 'exceptions/import_error.html'
        elif isinstance(exception, PermissionError):
            custom_template = 'exceptions/permission_error.html'

        # Return a custom error message, or fall back to Django's default 500 error handling
        if custom_template:
            return handler_500(request, template_name=custom_template)


class RemoteUserMiddleware(RemoteUserMiddleware_):
    """
    Custom implementation of Django's RemoteUserMiddleware which allows for a user-configurable HTTP header name.
    """
    async_capable = False
    force_logout_if_no_header = False

    def __init__(self, get_response):
        if get_response is None:
            raise ValueError("get_response must be provided.")
        self.get_response = get_response

    @property
    def header(self):
        return settings.REMOTE_AUTH_HEADER

    def __call__(self, request):
        logger = logging.getLogger('netbox.authentication.RemoteUserMiddleware')
        # Bypass middleware if remote authentication is not enabled
        if not settings.REMOTE_AUTH_ENABLED:
            return self.get_response(request)
        # AuthenticationMiddleware is required so that request.user exists.
        if not hasattr(request, 'user'):
            raise ImproperlyConfigured(
                "The Django remote user auth middleware requires the"
                " authentication middleware to be installed.  Edit your"
                " MIDDLEWARE setting to insert"
                " 'django.contrib.auth.middleware.AuthenticationMiddleware'"
                " before the RemoteUserMiddleware class.")
        try:
            username = request.META[self.header]
        except KeyError:
            # If specified header doesn't exist then remove any existing
            # authenticated remote-user, or return (leaving request.user set to
            # AnonymousUser by the AuthenticationMiddleware).
            if self.force_logout_if_no_header and request.user.is_authenticated:
                self._remove_invalid_user(request)
            return self.get_response(request)
        # If the user is already authenticated and that user is the user we are
        # getting passed in the headers, then the correct user is already
        # persisted in the session and we don't need to continue.
        if request.user.is_authenticated:
            if request.user.get_username() == self.clean_username(username, request):
                return self.get_response(request)
            else:
                # An authenticated user is associated with the request, but
                # it does not match the authorized user in the header.
                self._remove_invalid_user(request)

        # We are seeing this user for the first time in this session, attempt
        # to authenticate the user.
        if settings.REMOTE_AUTH_GROUP_SYNC_ENABLED:
            logger.debug("Trying to sync Groups")
            user = auth.authenticate(
                request, remote_user=username, remote_groups=self._get_groups(request))
        else:
            user = auth.authenticate(request, remote_user=username)
        if user:
            # User is valid.
            # Update the User's Profile if set by request headers
            if settings.REMOTE_AUTH_USER_FIRST_NAME in request.META:
                user.first_name = request.META[settings.REMOTE_AUTH_USER_FIRST_NAME]
            if settings.REMOTE_AUTH_USER_LAST_NAME in request.META:
                user.last_name = request.META[settings.REMOTE_AUTH_USER_LAST_NAME]
            if settings.REMOTE_AUTH_USER_EMAIL in request.META:
                user.email = request.META[settings.REMOTE_AUTH_USER_EMAIL]
            user.save()

            # Set request.user and persist user in the session
            # by logging the user in.
            request.user = user
            auth.login(request, user)

        return self.get_response(request)

    def _get_groups(self, request):
        logger = logging.getLogger(
            'netbox.authentication.RemoteUserMiddleware')

        groups_string = request.META.get(
            settings.REMOTE_AUTH_GROUP_HEADER, None)
        if groups_string:
            groups = groups_string.split(settings.REMOTE_AUTH_GROUP_SEPARATOR)
        else:
            groups = []
        logger.debug(f"Groups are {groups}")
        return groups


class PrometheusBeforeMiddleware(middleware.PrometheusBeforeMiddleware):
    metrics_cls = Metrics


class PrometheusAfterMiddleware(middleware.PrometheusAfterMiddleware):
    metrics_cls = Metrics

    def process_response(self, request, response):
        response = super().process_response(request, response)

        # Increment REST API request counters
        if is_api_request(request):
            method = self._method(request)
            name = self._get_view_name(request)
            self.label_metric(self.metrics.rest_api_requests, request, method=method).inc()
            self.label_metric(self.metrics.rest_api_requests_by_view_method, request, method=method, view=name).inc()

        # Increment GraphQL API request counters
        elif is_graphql_request(request):
            self.metrics.graphql_api_requests.inc()

        return response


class MaintenanceModeMiddleware:
    """
    Middleware that checks if the application is in maintenance mode
    and restricts write-related operations to the database.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        if get_config().MAINTENANCE_MODE:
            self._set_session_type(
                allow_write=request.path_info.startswith(settings.MAINTENANCE_EXEMPT_PATHS)
            )

        return self.get_response(request)

    @staticmethod
    def _set_session_type(allow_write):
        """
        Prevent any write-related database operations.

        Args:
            allow_write (bool): If True, write operations will be permitted.
        """
        with connection.cursor() as cursor:
            mode = 'READ WRITE' if allow_write else 'READ ONLY'
            cursor.execute(f'SET SESSION CHARACTERISTICS AS TRANSACTION {mode};')

    def process_exception(self, request, exception):
        """
        Prevent any write-related database operations if an exception is raised.
        """
        if get_config().MAINTENANCE_MODE and isinstance(exception, InternalError):
            error_message = 'NetBox is currently operating in maintenance mode and is unable to perform write ' \
                            'operations. Please try again later.'

            if is_api_request(request):
                return handle_rest_api_exception(request, error=error_message)

            messages.error(request, error_message)
            return HttpResponseRedirect(request.path_info)
