import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';

import { getAccessToken } from './auth';
import {
  IPredictionModelDashboardStatus,
  IPredictionModelShapValuesStatus,
  IPredictionModelShapValuesData,
  IFeatureImportanceData,
  IStatsDataType,
  IVisualizationLayerInfo,
} from '../services/types';
import { config } from '../config';
import {
  endpointHandler,
  parseResponseToFeatureImportanceData,
  parseResponseToStatsDataDefault,
  parseResponseToStatsDataMulticlassifier,
} from '../utils';

const baseUrl = config.API_GATEWAY_URL;

/**
 * Per-request headers that scope a call to a single project.
 *
 * Every endpoint in this slice sends `x-project-id` through the request's
 * `headers` field. `fetchBaseQuery` only supports `prepareHeaders` on the base
 * query (below), not per request, so per-request header logic belongs here.
 */
const projectHeaders = (projectId: string): Record<string, string> => ({
  'x-project-id': projectId,
});

/**
 * Maps the snake_case payload of the `get-dashboard-status` and
 * `get-shap-values-status` endpoints to the camelCase shape used by the app.
 *
 * NOTE(review): the response is untyped; field types are whatever the gateway
 * returns — confirm against the backend schema before tightening `any`.
 */
const parseTaskStatus = (response: any) => ({
  ecsLastStatus: response.ecs_last_status,
  taskLogState: response.task_log_state,
  taskLogMessage: response.task_log_message,
});

/**
 * Parses a test/prediction stats payload: payloads with a truthy `report`
 * field use the multiclassifier parser, all others the default parser.
 */
const parseStatsData = (response: any): IStatsDataType =>
  response.report
    ? parseResponseToStatsDataMulticlassifier(response)
    : parseResponseToStatsDataDefault(response);

export const apiGateway = createApi({
  reducerPath: 'apiGateway',
  baseQuery: fetchBaseQuery({
    baseUrl,
    // Runs for every request: sets `Authorization` to whatever getAccessToken() resolves to.
    prepareHeaders: async (headers: Headers) => {
      const accessToken = await getAccessToken();
      headers.set('Authorization', accessToken);
      return headers;
    },
  }),
  endpoints: (builder) => ({
    // Prediction Models
    getPredictionModel: builder.query<void, { modelId: string; projectId: string }>({
      query: ({ modelId, projectId }) => ({
        url: endpointHandler('prediction-model', modelId, 'get'),
        method: 'GET',
        headers: projectHeaders(projectId),
      }),
    }),
    createPredictionModel: builder.mutation<void, { data: FormData; projectId: string }>({
      // NOTE(review): x-project-id is sent in every environment, including
      // REACT_APP_ENVIRONMENT === 'local'. An inline per-request prepareHeaders
      // here suggested skipping it locally, but fetchBaseQuery never invokes
      // per-request prepareHeaders, so that branch never ran. If local must
      // omit the header, build `headers` conditionally instead.
      query: ({ data, projectId }) => ({
        url: endpointHandler('prediction-model', '', 'create'),
        method: 'POST',
        body: data,
        headers: projectHeaders(projectId),
      }),
    }),
    updatePredictionModel: builder.mutation<void, { data: FormData; projectId: string; modelId: string }>({
      query: ({ data, projectId, modelId }) => ({
        url: endpointHandler('prediction-model', modelId, 'update'),
        method: 'PATCH',
        body: data,
        headers: projectHeaders(projectId),
      }),
    }),
    deletePredictionModel: builder.mutation<void, { projectId: string; modelId: string }>({
      query: ({ projectId, modelId }) => ({
        url: endpointHandler('prediction-model', modelId, 'delete'),
        method: 'DELETE',
        headers: projectHeaders(projectId),
      }),
    }),
    getPredictionModelDashboardStatus: builder.query<
      IPredictionModelDashboardStatus,
      { taskArn: string; projectId: string }
    >({
      // The task ARN is sent as multipart form data in a POST body.
      query: ({ taskArn, projectId }) => {
        const form = new FormData();
        form.append('taskArn', taskArn);
        return {
          url: endpointHandler('prediction-model', '', 'get-dashboard-status'),
          method: 'POST',
          body: form,
          headers: projectHeaders(projectId),
        };
      },
      transformResponse: parseTaskStatus,
    }),
    generatePredictionModelShapValues: builder.mutation<
      IPredictionModelShapValuesData,
      { projectId: string; modelId: string }
    >({
      query: ({ projectId, modelId }) => ({
        url: endpointHandler('prediction-model', modelId, 'generate-shap-values'),
        method: 'POST',
        headers: projectHeaders(projectId),
      }),
    }),
    getPredictionModelShapValuesStatus: builder.query<
      IPredictionModelShapValuesStatus,
      { taskArn: string; projectId: string }
    >({
      // The task ARN is sent as multipart form data in a POST body.
      query: ({ taskArn, projectId }) => {
        const form = new FormData();
        form.append('taskArn', taskArn);
        return {
          url: endpointHandler('prediction-model', '', 'get-shap-values-status'),
          method: 'POST',
          body: form,
          headers: projectHeaders(projectId),
        };
      },
      transformResponse: parseTaskStatus,
    }),
    getPredictionModelFeatureImportanceData: builder.query<
      IFeatureImportanceData,
      { projectId: string; modelId: string }
    >({
      query: ({ projectId, modelId }) => ({
        url: endpointHandler('prediction-model', modelId, 'get-feature-importance-data'),
        method: 'GET',
        headers: projectHeaders(projectId),
      }),
      transformResponse: (response: any) => parseResponseToFeatureImportanceData(response),
    }),
    getPredictionModelTestData: builder.query<IStatsDataType, { projectId: string; modelId: string }>({
      query: ({ projectId, modelId }) => ({
        url: endpointHandler('prediction-model', modelId, 'get-test-data'),
        method: 'GET',
        headers: projectHeaders(projectId),
      }),
      transformResponse: parseStatsData,
    }),
    getPredictionModelPredictionData: builder.query<IStatsDataType, { projectId: string; modelId: string }>({
      query: ({ projectId, modelId }) => ({
        url: endpointHandler('prediction-model', modelId, 'get-prediction-data'),
        method: 'GET',
        headers: projectHeaders(projectId),
      }),
      transformResponse: parseStatsData,
    }),
    getPredictionLayers: builder.query<IVisualizationLayerInfo, { modelId: string; projectId: string }>({
      query: ({ modelId, projectId }) => ({
        url: endpointHandler('prediction-model', modelId, 'get-visualization-layers'),
        method: 'GET',
        headers: projectHeaders(projectId),
      }),
    }),
  }),
});
