package agent

import (
	"context"
	"errors"
	"fmt"
	"testing"
	time "time"

	"github.com/stretchr/testify/assert"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/internal/tool/testing/testlogger"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v18/pkg/agentcfg"
	gomock "go.uber.org/mock/gomock"
	corev1 "k8s.io/api/core/v1"
	watch "k8s.io/apimachinery/pkg/watch"
	"k8s.io/client-go/kubernetes/fake"
)

// ocsServiceAccountName is the service account the scanner under test is
// configured with; the mocks expect deployScanningPod to receive it.
const ocsServiceAccountName = "gitlab-agent-ocs-service-account"

var (
	// terminatedContainerState is a container status in the Terminated state,
	// used to build fixtures for scanning pods in the Failed phase.
	terminatedContainerState = corev1.ContainerStatus{
		State: corev1.ContainerState{Terminated: &corev1.ContainerStateTerminated{}},
	}

	// trivyK8sWrapperImage is the scanner image handed to namespaceScanner; the
	// mocks expect it to be deployed as "<Repository>:<Tag>".
	trivyK8sWrapperImage = agentcfg.TrivyK8SWrapperImage{
		Tag:        "latest",
		Repository: "registry.gitlab.com/security-products/trivy-k8s-wrapper",
	}
)

// TestNamespaceScanner_ScanPodSucceeded runs namespaceScanner.scan against a
// scanning pod that reports PodSucceeded. It checks that the UUIDs returned by
// the reporter are passed through. It also checks the wrapped errors returned
// when reading or parsing the chained ConfigMaps fails, both with and without
// deleteReportArtifact.
func TestNamespaceScanner_ScanPodSucceeded(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name                 string
		setupMocks           func(MockReporter, MockScanningManager)
		expectedUUIDs        []string
		expectedError        error
		deleteReportArtifact bool
	}

	// expectReportTransmitted registers the calls for a run in which the report
	// ConfigMaps are read and parsed and the reporter returns uuid1 and uuid2.
	expectReportTransmitted := func(mockReporter MockReporter, mockScanningManager MockScanningManager) {
		mockReporter.EXPECT().
			Transmit(gomock.Any(), gomock.Any()).
			Return([]string{uuid1, uuid2}, nil)

		mockScanningManager.EXPECT().readChainedConfigmaps(gomock.Any(), namespace1)
		mockScanningManager.EXPECT().parseScanningPodPayload(gomock.Any(), gomock.Any())
	}

	testCases := []testCase{
		{
			name:          "Pod Succeeded with valid UUIDs returned",
			setupMocks:    expectReportTransmitted,
			expectedUUIDs: []string{uuid1, uuid2},
		},
		{
			name: "Pod Succeeded but failed to read chained config map",
			setupMocks: func(_ MockReporter, mockScanningManager MockScanningManager) {
				mockScanningManager.EXPECT().readChainedConfigmaps(gomock.Any(), namespace1).
					Return(nil, "", errors.New("readChainedConfigmaps error"))
			},
			expectedError: fmt.Errorf("could not read chained ConfigMaps: %w", errors.New("readChainedConfigmaps error")),
		},
		{
			name: "Pod Succeeded but failed to parse payload",
			setupMocks: func(_ MockReporter, mockScanningManager MockScanningManager) {
				mockScanningManager.EXPECT().readChainedConfigmaps(gomock.Any(), namespace1)
				mockScanningManager.EXPECT().parseScanningPodPayload(gomock.Any(), gomock.Any()).
					Return(nil, errors.New("parseScanningPodPayload error"))
			},
			expectedError: fmt.Errorf("could not parse OCS scanning pod report: %w", errors.New("parseScanningPodPayload error")),
		},
		{
			name:                 "Pod Succeeded and Delete Configmaps",
			setupMocks:           expectReportTransmitted,
			expectedUUIDs:        []string{uuid1, uuid2},
			deleteReportArtifact: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			reporter := NewMockReporter(ctrl)
			scanningManager := NewMockScanningManager(ctrl)
			setupCommonScanningManagerMocks(*scanningManager, corev1.PodSucceeded, nil, tc.deleteReportArtifact)
			tc.setupMocks(*reporter, *scanningManager)

			nsScanner := namespaceScanner{
				kubeClientset:             fake.NewSimpleClientset(),
				gitlabAgentNamespace:      gitlabAgentNamespace,
				gitlabAgentServiceAccount: gitlabAgentServiceAccount,
				agentKey:                  testAgentKey,
				ocsServiceAccountName:     ocsServiceAccountName,
				trivyK8sWrapperImage:      &trivyK8sWrapperImage,
				deleteReportArtifact:      tc.deleteReportArtifact,
			}

			uuids, err := nsScanner.scan(context.Background(), testlogger.New(t), namespace1, reporter, scanningManager)

			// Error cases only pin the error; the UUIDs are not inspected then.
			if tc.expectedError != nil {
				assert.Equal(t, tc.expectedError, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.expectedUUIDs, uuids)
		})
	}
}

// TestNamespaceScanner_ScanPodFailedOrPending runs namespaceScanner.scan against
// scanning pods that end up Failed, or Pending on an image-related waiting
// reason. It checks that the expected error is returned. No Transmit
// expectation is registered on the reporter, so the strict gomock mock fails
// the test if the report is transmitted.
func TestNamespaceScanner_ScanPodFailedOrPending(t *testing.T) {
	t.Parallel()

	// testCase is declared once so the static cases and the generated Pending
	// cases below share a single definition.
	type testCase struct {
		name              string
		podPhase          corev1.PodPhase
		containerStatuses []corev1.ContainerStatus
		// setupMocks is optional; nil means no extra expectations.
		setupMocks           func(MockReporter, MockScanningManager)
		expectedError        error
		deleteReportArtifact bool
	}

	// expectTrivyScanFailure makes extractExitCodeError report a failed Trivy scan.
	expectTrivyScanFailure := func(_ MockReporter, mockScanningManager MockScanningManager) {
		mockScanningManager.EXPECT().extractExitCodeError(gomock.Any(), gomock.Any()).
			Return(errors.New("failed to execute a Trivy scan"))
	}

	tests := []testCase{
		{
			name:              "Pod Failed due to multiple containers statuses",
			podPhase:          corev1.PodFailed,
			containerStatuses: []corev1.ContainerStatus{terminatedContainerState, terminatedContainerState},
			expectedError:     errors.New("OCS Scanning pod should have only one container"),
		},
		{
			name:              "Pod Failed due to scanning pod termination",
			podPhase:          corev1.PodFailed,
			containerStatuses: []corev1.ContainerStatus{terminatedContainerState},
			setupMocks:        expectTrivyScanFailure,
			expectedError:     errors.New("failed to execute a Trivy scan"),
		},
		{
			name:     "Pod Failed with unexpected reason",
			podPhase: corev1.PodFailed,
			containerStatuses: []corev1.ContainerStatus{{
				State: corev1.ContainerState{
					Waiting: &corev1.ContainerStateWaiting{},
				},
			}},
			expectedError: errors.New("OCS Scanning pod exited with an error. Could not retrieve an exit code"),
		},
		{
			name:                 "Pod Failed and Configmap deleted",
			podPhase:             corev1.PodFailed,
			containerStatuses:    []corev1.ContainerStatus{terminatedContainerState},
			setupMocks:           expectTrivyScanFailure,
			expectedError:        errors.New("failed to execute a Trivy scan"),
			deleteReportArtifact: true,
		},
	}

	// One Pending case per waiting reason; the expected message (including its
	// double space) must match the scanner's error text exactly.
	pendingReasons := []string{"ImagePullBackOff", "ImageInspectError", "ErrImagePull", "ErrImageNeverPull", "InvalidImageName"}
	for _, reason := range pendingReasons {
		tests = append(tests, testCase{
			name:     fmt.Sprintf("Pod Pending due to %s", reason),
			podPhase: corev1.PodPending,
			containerStatuses: []corev1.ContainerStatus{{
				State: corev1.ContainerState{
					Waiting: &corev1.ContainerStateWaiting{Reason: reason},
				},
			}},
			expectedError: fmt.Errorf("waiting on OCS Scanning Pod error:  reason: %s", reason),
		})
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)

			// reporter.Transmit should not be called if the scan pod does not complete successfully
			mockReporter := NewMockReporter(ctrl)

			mockScanningManager := NewMockScanningManager(ctrl)
			setupCommonScanningManagerMocks(*mockScanningManager, tt.podPhase, tt.containerStatuses, tt.deleteReportArtifact)
			if tt.setupMocks != nil {
				tt.setupMocks(*mockReporter, *mockScanningManager)
			}

			nsScanner := namespaceScanner{
				kubeClientset:             fake.NewSimpleClientset(),
				gitlabAgentNamespace:      gitlabAgentNamespace,
				gitlabAgentServiceAccount: gitlabAgentServiceAccount,
				agentKey:                  testAgentKey,
				ocsServiceAccountName:     ocsServiceAccountName,
				trivyK8sWrapperImage:      &trivyK8sWrapperImage,
				deleteReportArtifact:      tt.deleteReportArtifact,
			}

			_, err := nsScanner.scan(context.Background(), testlogger.New(t), namespace1, mockReporter, mockScanningManager)

			assert.Equal(t, tt.expectedError, err)
		})
	}
}

// TestNamespaceScanner_ScanContextCancelled checks that scan returns
// context.DeadlineExceeded when its context is already done and the pod watcher
// never delivers an event.
func TestNamespaceScanner_ScanContextCancelled(t *testing.T) {
	t.Parallel()
	ctrl := gomock.NewController(t)

	image := fmt.Sprintf("%s:%s", trivyK8sWrapperImage.Repository, trivyK8sWrapperImage.Tag)

	// No Transmit expectation is registered: reporter.Transmit must not be
	// called when the scan pod does not complete successfully.
	mockReporter := NewMockReporter(ctrl)

	// Set up mocks for the scanning manager up to and including the
	// watchScanningPod call, plus deleteScanningPod.
	mockScanningManager := NewMockScanningManager(ctrl)
	podName := fmt.Sprintf("trivy-scan-%s", namespace1)
	mockCall := mockScanningManager.EXPECT()
	mockCall.deleteChainedConfigmaps(gomock.Any(), namespace1)
	mockCall.deployScanningPod(gomock.Any(), podName, namespace1, ocsServiceAccountName, image)
	mockCall.deleteScanningPod(podName)

	// Nothing is ever sent on this unbuffered channel, so the watcher produces
	// no pod events and the test can observe the context being done instead.
	idleWatcher := &customWatcher{
		eventChan: make(chan watch.Event),
	}
	mockCall.
		watchScanningPod(gomock.Any(), podName).
		Return(idleWatcher, nil)

	nsScanner := namespaceScanner{
		kubeClientset:             fake.NewSimpleClientset(),
		gitlabAgentNamespace:      gitlabAgentNamespace,
		gitlabAgentServiceAccount: gitlabAgentServiceAccount,
		agentKey:                  testAgentKey,
		ocsServiceAccountName:     ocsServiceAccountName,
		trivyK8sWrapperImage:      &trivyK8sWrapperImage,
	}

	// A deadline of "now" has already passed when WithDeadline returns, so ctx
	// is done with context.DeadlineExceeded before scan is even called.
	ctx, cancel := context.WithDeadline(context.Background(), time.Now())
	defer cancel()

	_, err := nsScanner.scan(ctx, testlogger.New(t), namespace1, mockReporter, mockScanningManager)

	// ErrorIs takes (actual, target) and also accepts the context error if
	// scan wraps it.
	assert.ErrorIs(t, err, context.DeadlineExceeded)
}

// customWatcher is the fake watcher handed back by the watchScanningPod mock.
// It provides the Stop/ResultChan pair and serves events from a channel the
// test supplies.
type customWatcher struct {
	eventChan chan watch.Event
}

// Stop is a no-op. Unlike a real Kubernetes watcher, it does not close the
// result channel.
func (w *customWatcher) Stop() {}

// ResultChan exposes the test-supplied event channel as receive-only.
func (w *customWatcher) ResultChan() <-chan watch.Event {
	return w.eventChan
}

// setupCommonScanningManagerMocks registers the ScanningManager expectations
// shared by the scan tests for namespace1:
//   - deployScanningPod for the "trivy-scan-<namespace>" pod using the
//     "<Repository>:<Tag>" trivyK8sWrapperImage and ocsServiceAccountName, and
//     deleteScanningPod for that pod;
//   - deleteChainedConfigmaps at least twice when deleteReportArtifact is true,
//     otherwise at most once;
//   - watchScanningPod returning a customWatcher whose channel already holds one
//     Modified event for a pod with the given phase and container statuses.
func setupCommonScanningManagerMocks(mockScanningManager MockScanningManager, desiredPodStatus corev1.PodPhase, containerStatuses []corev1.ContainerStatus, deleteReportArtifact bool) {
	var (
		expect  = mockScanningManager.EXPECT()
		podName = fmt.Sprintf("trivy-scan-%s", namespace1)
		image   = fmt.Sprintf("%s:%s", trivyK8sWrapperImage.Repository, trivyK8sWrapperImage.Tag)
	)

	expect.deployScanningPod(gomock.Any(), podName, namespace1, ocsServiceAccountName, image)
	expect.deleteScanningPod(podName)

	configmapCleanup := expect.deleteChainedConfigmaps(gomock.Any(), namespace1)
	if deleteReportArtifact {
		configmapCleanup.MinTimes(2)
	} else {
		configmapCleanup.MaxTimes(1)
	}

	// The channel is buffered with room for exactly one event and pre-filled, so
	// a receiver gets it immediately without any sender goroutine.
	podEvents := make(chan watch.Event, 1)
	podEvents <- watch.Event{
		Type: watch.Modified,
		Object: &corev1.Pod{
			Status: corev1.PodStatus{
				Phase:             desiredPodStatus,
				ContainerStatuses: containerStatuses,
			},
		},
	}
	expect.
		watchScanningPod(gomock.Any(), podName).
		Return(&customWatcher{eventChan: podEvents}, nil)
}
