from django.contrib.auth.models import User
from django.test.utils import override_settings
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APIClient

from users.models import Token

from .base import ModelTestCase

# Public names exported by a star-import of this module.
__all__ = (
    "APITestCase",
    "APIViewTestCases",
)


class APITestCase(ModelTestCase):
    """
    Base class for API tests; every request is meant to carry a superuser token.

    Subclasses set ``model`` (read by the URL helpers below) and may set
    ``view_namespace`` to override the app label used to build view names.
    """

    client_class = APIClient
    view_namespace = None

    def setUp(self):
        """
        Creates a superuser and token for API calls.
        """
        # NOTE(review): super().setUp() is not called here; confirm that
        # ModelTestCase has no setup that must run (it may create its own user).
        superuser = User.objects.create(
            username="testuser",
            is_staff=True,
            is_superuser=True,
        )
        self.user = superuser
        self.token = Token.objects.create(user=superuser)
        # Extra WSGI environ passed to each client call to authenticate it.
        self.header = {"HTTP_AUTHORIZATION": f"Token {self.token.key}"}

    def _get_view_namespace(self):
        # Explicit namespace wins; otherwise derive it from the model's app label.
        namespace = self.view_namespace or self.model._meta.app_label
        return f"{namespace}-api"

    def _get_detail_url(self, instance):
        # Resolve "<namespace>-api:<model_name>-detail" for the given instance.
        namespace = self._get_view_namespace()
        model_name = instance._meta.model_name
        return reverse(f"{namespace}:{model_name}-detail", kwargs={"pk": instance.pk})

    def _get_list_url(self):
        # Resolve "<namespace>-api:<model_name>-list" for the tested model.
        namespace = self._get_view_namespace()
        model_name = self.model._meta.model_name
        return reverse(f"{namespace}:{model_name}-list")


class APIViewTestCases:
    class GetObjectView(APITestCase):
        """
        Tests for retrieving a single object through the API detail endpoint.
        """

        def _assert_has_instances(self):
            """
            Fail with an explicit message when no object exists to be requested.
            """
            self.assertGreaterEqual(
                self._get_queryset().count(),
                1,
                f"Test requires the creation of at least one {self.model} instance",
            )

        @override_settings(LOGIN_REQUIRED=False)
        def test_get_object_anonymous(self):
            """
            GET a single object as an unauthenticated user.
            """
            self._assert_has_instances()
            url = self._get_detail_url(self._get_queryset().first())
            # Do not send self.header: the superuser token created in setUp()
            # would authenticate the request and defeat the purpose of an
            # anonymous-access test.
            response = self.client.get(url)
            self.assertHttpStatus(response, status.HTTP_200_OK)

        @override_settings(LOGIN_REQUIRED=True)
        def test_get_object(self):
            """
            GET a single object as an authenticated user with permission to view the
            object.
            """
            self._assert_has_instances()
            instance = self._get_queryset()[0]

            self.add_permissions("view")

            # Try GET to permitted object
            url = self._get_detail_url(instance)
            self.assertHttpStatus(
                self.client.get(url, **self.header), status.HTTP_200_OK
            )

        @override_settings(LOGIN_REQUIRED=True)
        def test_get_object_fields(self):
            """
            GET a single object using the "fields" parameter.
            """
            self._assert_has_instances()
            instance = self._get_queryset()[0]

            self.add_permissions("view")

            # Subclasses may restrict the requested fields via `query_fields`.
            fields = getattr(
                self, "query_fields", ["id", "url", "display_url", "display"]
            )
            url = f"{self._get_detail_url(instance)}?fields={','.join(fields)}"
            response = self.client.get(url, **self.header)

            self.assertHttpStatus(response, status.HTTP_200_OK)
            # The serialized object must contain exactly the requested keys.
            self.assertEqual(sorted(response.data), sorted(fields))

        @override_settings(LOGIN_REQUIRED=True)
        def test_options_object(self):
            """
            Make an OPTIONS request for a single object.
            """
            self._assert_has_instances()
            url = self._get_detail_url(self._get_queryset().first())
            response = self.client.options(url, **self.header)
            self.assertHttpStatus(response, status.HTTP_200_OK)

    class ListObjectsView(APITestCase):
        """
        Tests for listing objects through the API list endpoint.
        """

        # Keys expected in each result when "?brief=1" is requested; concrete
        # test classes are expected to override this.
        brief_fields = []

        def _get_list_results(self, url):
            """
            GET ``url`` and return its "results", asserting the request succeeded
            and that one result was returned per object of the model.
            """
            response = self.client.get(url, **self.header)
            # Check the status first so an error response fails with a useful
            # message instead of a KeyError on "results".
            self.assertHttpStatus(response, status.HTTP_200_OK)
            results = response.data["results"]
            self.assertEqual(len(results), self.model.objects.count())
            return results

        def _assert_not_empty(self, results):
            """
            Fail with an explicit message when there is no result to inspect.
            """
            self.assertGreaterEqual(
                len(results),
                1,
                f"Test requires the creation of at least one {self.model} instance",
            )

        def test_list_objects(self):
            """
            GET a list of objects.
            """
            self._get_list_results(self._get_list_url())

        def test_list_objects_brief(self):
            """
            GET a list of objects using the "brief" parameter.
            """
            results = self._get_list_results(f"{self._get_list_url()}?brief=1")

            self._assert_not_empty(results)
            self.assertEqual(sorted(results[0]), sorted(self.brief_fields))

        def test_list_objects_fields(self):
            """
            GET a list of objects using the "fields" parameter.
            """
            # Subclasses may restrict the requested fields via `query_fields`.
            fields = getattr(
                self, "query_fields", ["id", "url", "display_url", "display"]
            )
            results = self._get_list_results(
                f"{self._get_list_url()}?fields={','.join(fields)}"
            )

            self._assert_not_empty(results)
            self.assertEqual(sorted(results[0]), sorted(fields))

    class CreateObjectView(APITestCase):
        """
        Tests for creating objects by POSTing to the API list endpoint.
        """

        # Request payloads: the first one drives the single-object test, the
        # whole list is sent by the bulk test.
        create_data = []
        # Field names ignored when comparing a created instance to its payload.
        validation_excluded_fields = []

        def test_create_object(self):
            """
            POST a single object.
            """
            count_before = self.model.objects.count()
            list_url = self._get_list_url()
            payload = self.create_data[0]
            response = self.client.post(list_url, payload, format="json", **self.header)

            self.assertHttpStatus(response, status.HTTP_201_CREATED)
            self.assertEqual(self.model.objects.count(), count_before + 1)

            created = self.model.objects.get(pk=response.data["id"])
            self.assertInstanceEqual(
                created,
                payload,
                exclude=self.validation_excluded_fields,
                api=True,
            )

        def test_bulk_create_object(self):
            """
            POST a set of objects in a single request.
            """
            count_before = self.model.objects.count()
            list_url = self._get_list_url()
            response = self.client.post(
                list_url, self.create_data, format="json", **self.header
            )

            self.assertHttpStatus(response, status.HTTP_201_CREATED)
            count_after = self.model.objects.count()
            self.assertEqual(count_after, count_before + len(self.create_data))

    class UpdateObjectView(APITestCase):
        """
        Tests for updating objects with PATCH on the API detail and list endpoints.
        """

        # Payload for the single-object PATCH; when empty, `create_data[0]` is
        # used instead (create_data is read from the concrete test class or a
        # sibling mixin, it is not defined here).
        update_data = {}
        # Changes applied to several objects at once; the bulk test is skipped
        # while this is None.
        bulk_update_data = None
        # Field names ignored when comparing an updated instance to its payload.
        validation_excluded_fields = []

        def test_update_object(self):
            """
            PATCH a single object identified by its numeric ID.
            """
            instance = self.model.objects.first()
            url = self._get_detail_url(instance)
            update_data = self.update_data or self.create_data[0]
            response = self.client.patch(url, update_data, format="json", **self.header)

            self.assertHttpStatus(response, status.HTTP_200_OK)
            instance.refresh_from_db()
            # Compare against the payload actually sent: when the create_data
            # fallback was used, self.update_data is empty and checking against
            # it would verify nothing.
            self.assertInstanceEqual(
                instance,
                update_data,
                exclude=self.validation_excluded_fields,
                api=True,
            )

        def test_bulk_update_objects(self):
            """
            PATCH a set of objects in a single request.
            """
            if self.bulk_update_data is None:
                self.skipTest("Bulk update data not set")

            # Materialise the IDs once. Reusing the sliced (and possibly
            # unordered) queryset later as a pk__in subquery could select
            # different rows, and some backends (e.g. MySQL) reject LIMIT in an
            # IN subquery.
            id_list = list(self._get_queryset().values_list("id", flat=True)[:3])
            self.assertEqual(len(id_list), 3, "Not enough objects to test bulk update")
            data = [{"id": pk, **self.bulk_update_data} for pk in id_list]

            response = self.client.patch(
                self._get_list_url(), data, format="json", **self.header
            )
            self.assertHttpStatus(response, status.HTTP_200_OK)
            self.assertEqual(
                len(response.data),
                len(id_list),
                "Bulk update response does not contain one entry per object",
            )

            for i, obj in enumerate(response.data):
                for field in self.bulk_update_data:
                    self.assertIn(
                        field,
                        obj,
                        f"Bulk update field '{field}' missing from object {i} in response",
                    )
            for instance in self._get_queryset().filter(pk__in=id_list):
                self.assertInstanceEqual(instance, self.bulk_update_data, api=True)

    class DeleteObjectView(APITestCase):
        """
        Tests for deleting objects on the API detail and list endpoints.
        """

        def test_delete_object(self):
            """
            DELETE a single object identified by its numeric ID.
            """
            target = self.model.objects.first()
            response = self.client.delete(self._get_detail_url(target), **self.header)

            self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
            still_exists = self.model.objects.filter(pk=target.pk).exists()
            self.assertFalse(still_exists)

        def test_bulk_delete_objects(self):
            """
            DELETE a set of objects in a single request.
            """
            # The three objects with the highest IDs are the ones deleted.
            newest_ids = (
                self._get_queryset().order_by("-id").values_list("id", flat=True)[:3]
            )
            self.assertEqual(
                len(newest_ids), 3, "Not enough objects to test bulk deletion"
            )
            payload = [{"id": pk} for pk in newest_ids]

            count_before = self._get_queryset().count()
            response = self.client.delete(
                self._get_list_url(), payload, format="json", **self.header
            )
            self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
            self.assertEqual(self._get_queryset().count(), count_before - 3)

    class View(
        GetObjectView,
        ListObjectsView,
        CreateObjectView,
        UpdateObjectView,
        DeleteObjectView,
    ):
        """
        Combination of every API view test mixin above: get, list, create, update
        and delete.
        """
