import Experiment from '@models/Experiment';
import Plot from '@models/Plot';
import useSWR from 'swr';
import Endpoints from '@services/Endpoints';
import useApi from '@hooks/useApi';
import { AssaySummaryAnalysis } from '@models/analysis/AssaySummaryAnalysis';
import { AnalysisGroupV3, AnalysisGroupV3Response } from '../models/AnalysisParameters';
import { useEffect, useState } from 'react';
import { usePollingEffect } from './usePollingEffect';

/**
 * Arguments accepted by `useExperimentPlotGroupsV3`.
 *
 * Groups are only requested when at least one of `annotation_set_id`,
 * `latent_variable_id` or a non-empty `variable_ids` is supplied.
 */
type Props = {
    /** Sent in the groups request body when truthy. */
    annotation_set_id: string;
    /** Experiment whose `uuid` is used to build the groups endpoint URL. */
    experiment: Experiment;
    /** Sent in the groups request body when truthy. */
    latent_variable_id: string;
    /** Plot whose `uuid` is used to build the groups endpoint URL and whose `analysis` supplies the selected groups. */
    plot: Plot;
    /** Sent (sorted) in the groups request body when non-empty. Optional: the hook defaults it to `[]`. */
    variable_ids?: string[];
};
/**
 * Loads the analysis groups available for a plot.
 *
 * When at least one of `latent_variable_id`, `annotation_set_id` or a
 * non-empty `variable_ids` is provided, the hook POSTs those parameters to the
 * plot's groups-V3 endpoint, reads a `task_id` from the response and then polls
 * the task endpoint (`interval: 3_000`) until the task reports `SUCCESS` or
 * `FAILED`.
 *
 * @returns
 *  - `hasError`: the initial request failed, the polling request threw, or the task reported `FAILED`
 *  - `loading`: a request or task is still pending
 *  - `groups`: the `result` of the last successful task (empty until then)
 *  - `getGroupById`: lookup helper over `groups`
 *  - `selectedGroups`: the entries of `groups` referenced by the plot's analysis
 */
const useExperimentPlotGroupsV3 = ({
    annotation_set_id,
    experiment,
    latent_variable_id,
    plot,
    variable_ids = [],
}: Props) => {
    const experimentId = experiment.uuid;
    const plotId = plot.uuid;
    const fetchGroups = Boolean(latent_variable_id || annotation_set_id || variable_ids.length > 0);
    const [loading, setLoading] = useState(fetchGroups);
    const [hasError, setHasError] = useState(false);
    const [groups, setGroups] = useState<AnalysisGroupV3[]>([]);
    const [taskId, setTaskId] = useState('');

    const api = useApi();

    // Only the parameters that were actually provided are sent. `variable_ids` is copied before
    // sorting so the caller's array is never mutated and the order passed in does not matter.
    const requestBody = {
        ...(variable_ids.length > 0 ? { variable_ids: [...variable_ids].sort() } : undefined),
        ...(annotation_set_id ? { annotation_set_id } : undefined),
        ...(latent_variable_id ? { latent_variable_id } : undefined),
    };

    const { data, error } = useSWR<AnalysisGroupV3>(
        // The serialised body is part of the key: the response depends on it, so a change in the
        // parameters must start a new request instead of reusing the entry cached for the URL alone.
        fetchGroups
            ? [Endpoints.lab.experiment.plot.groupsV3({ experimentId, plotId }), JSON.stringify(requestBody)]
            : null,
        {
            // The fetcher ignores the key arguments and closes over the current URL and body, so it
            // does not depend on how the installed SWR version passes array keys to the fetcher.
            fetcher: () => api.post(Endpoints.lab.experiment.plot.groupsV3({ experimentId, plotId }), requestBody),
            revalidateOnMount: true,
            revalidateOnFocus: false,
            revalidateOnReconnect: false,
        },
    );

    // The POST only returns a task handle; remember it so the polling effect below can start.
    useEffect(() => {
        if (data?.task_id) {
            setTaskId(data.task_id);
        }
    }, [data?.task_id]);

    // If the initial request fails there is no task to poll, so surface the failure here;
    // otherwise `loading` would stay `true` forever.
    useEffect(() => {
        if (error) {
            setHasError(true);
            setLoading(false);
        }
    }, [error]);

    usePollingEffect(
        async () => {
            if (!taskId) return;
            setLoading(true);
            try {
                const taskResponse = await api.get<AnalysisGroupV3Response>(
                    Endpoints.external_tools.celeryTask(taskId),
                );
                if (taskResponse?.status === 'SUCCESS') {
                    // Fall back to an empty list so `groups` is always an array.
                    setGroups(taskResponse.result ?? []);
                    setLoading(false);
                    setHasError(false);
                    setTaskId('');
                }
                // NOTE(review): Celery's own terminal failure state is 'FAILURE'; confirm that this
                // endpoint really reports 'FAILED', otherwise a failed task is polled indefinitely.
                if (taskResponse?.status === 'FAILED') {
                    setHasError(true);
                    setLoading(false);
                    setTaskId('');
                }
            } catch {
                // The polling request itself failed: report it and stop polling rather than leaving
                // an unhandled rejection and a permanently pending `loading` flag.
                setHasError(true);
                setLoading(false);
                setTaskId('');
            }
        },
        [taskId],
        {
            trigger: Boolean(taskId),
            interval: 3_000,
        },
    );

    /**
     * Get group details for a given group ID.
     * Both sides are compared as strings, so numeric and string identifiers match.
     * @param {number} group_id
     * @return {AnalysisGroupV3 | null} the matching group, or `null` when none is loaded
     */
    const getGroupById = (group_id: number): AnalysisGroupV3 | null => {
        return groups.find((g) => `${g.uuid}` === `${group_id}`) ?? null;
    };

    const analysis = plot.analysis as AssaySummaryAnalysis | null;
    const selectedGroupIds = analysis?.groups?.map((g) => g.uuid) ?? [];
    /**
     * The groups that were selected in the plot's analysis
     */
    const selectedGroups = groups.filter((g) => selectedGroupIds.includes(g.uuid));
    return { hasError, loading, groups, getGroupById, selectedGroups };
};

// Exposed as the module's default export.
export default useExperimentPlotGroupsV3;
