import { blankToNull, isBlank, isNotBlank } from '@util/StringUtil';
import { DataPoint } from '@components/analysisCategories/comparative/plots/VolcanoPlotUtil';
import { ArrowData } from '@util/ObjectUtil';
import { DEGSample } from '@models/ExperimentData';
import VolcanoPlotDisplayOption from '@models/plotDisplayOption/VolcanoPlotDisplayOption';
import { getVolcanoPlotThemeColors } from '@components/plots/PlotUtil';
import { PlotParams } from 'react-plotly.js';
import PaletteColors from '@components/PaletteColors';
import { isDefined } from '@util/TypeGuards';
import { ObservedSize } from '@hooks/useDebouncedResizeObserver';

import Logger from '@util/Logger';

// Module-level logger (named 'PlotlyVolcanoPlotUtil'); used by buildPlotlyLayout to log axis stats.
const logger = Logger.make('PlotlyVolcanoPlotUtil');

/**
 * Font settings for one text element of the volcano plot.
 * Fields are forwarded to Plotly font objects (`color`, `size`, `family`).
 */
export interface StylingOption {
    /** Font color passed to Plotly as `font.color`. */
    fontColor: string;
    /** Font size passed to Plotly as `font.size`. */
    fontSize: number;
    /** Font family passed to Plotly as `font.family`. */
    fontFamily: string;
}

/**
 * Per-element font overrides for the volcano plot.
 *
 * In the code shown here only `xaxis` and `yaxis` are read (by `buildPlotlyLayout`, for the
 * axis titles; falsy values fall back to its defaults). `labeled_points` and `title` are
 * presumably consumed elsewhere — TODO confirm.
 */
export interface CustomPlotStylingOptions {
    labeled_points?: StylingOption;
    title: StylingOption;
    xaxis?: StylingOption;
    yaxis?: StylingOption;
}

/** Plotly `layout.dragmode` value type, taken from react-plotly's `PlotParams`. */
export type DragMode = PlotParams['layout']['dragmode'];
/** A data point plus the row `index` it was read from (populated by `getSampleForIndex`). */
export type IndexedDataPoint = DataPoint & { index: number };
// NOTE(review): POINT_SIZE, POINT_OPACITY and LABELED_POINT_COLOR are not read by the code shown here;
// presumably marker size, marker opacity and the color of labeled points — confirm at call sites.
export const POINT_SIZE = 9;
/** Default threshold-line width used by `buildPlotlyLayout` when no override is supplied. */
export const LINE_WIDTH = 2;
export const POINT_OPACITY = 0.75;
export const LABELED_POINT_COLOR = '#7BD2F1';
/** (Plotly dash name, human-readable label) pairs, in the order they are offered. */
const DASH_STYLE_PAIRS: Array<[string, string]> = [
    ['solid', 'Solid'],
    ['dot', 'Dot'],
    ['dash', 'Dash'],
    ['longdash', 'Long Dash'],
    ['dashdot', 'Dash Dot'],
    ['longdashdot', 'Long Dash Dot'],
];

/**
 * Line dash styles selectable for threshold lines. `name` is the Plotly `line.dash`
 * value and `label` is its display text.
 */
export const LINE_DASH_OPTIONS = DASH_STYLE_PAIRS.map(([name, label]) => ({ name, label }));

/**
 * Extent of the plotted data in plot coordinates: x is the log2 fold change and y is
 * -log10 of the adjusted p-value (as computed by `getSampleForIndex`).
 * `buildPlotlyLayout` uses it as the fallback for an axis bound the user did not fix.
 */
export type PlotRange = {
    yMin: number;
    yMax: number;
    xMin: number;
    xMax: number;
};
/**
 * Returns the point's display symbol: the first of Gene_Symbol, gene_id, peak_id, probe_id
 * that is not null/undefined (empty strings are kept), or `null` if all are missing.
 * Gene identifiers are therefore preferred over peak/probe identifiers.
 */
export const getTargetSymbol = (d: IndexedDataPoint): string | null => {
    const candidates = [d.Gene_Symbol, d.gene_id, d.peak_id, d.probe_id];
    const found = candidates.find((value) => value !== null && value !== undefined);
    return found ?? null;
};

/**
 * Returns the point's primary identifier: the first of probe_id, peak_id, protein_id,
 * metabolite_id, Gene_Symbol, gene_id that is not null/undefined (empty strings are kept),
 * or `null` if all are missing.
 */
export const getTargetData = (d: DataPoint): string | null => {
    for (const id of [d.probe_id, d.peak_id, d.protein_id, d.metabolite_id, d.Gene_Symbol, d.gene_id]) {
        if (id != null) {
            return id;
        }
    }
    return null;
};
/**
 * Collects every symbol-like identifier of a point, in the order peak_id, Gene_Symbol,
 * gene_id, probe_id, keeping only the values accepted by `isDefined`.
 * Used by `getTextLabel` for legacy matching of selected targets.
 */
export const getAllPossibleTargetSymbols = (d: IndexedDataPoint): string[] => {
    const { peak_id: peakId, Gene_Symbol: geneSymbol, gene_id: geneId, probe_id: probeId } = d;
    return [peakId, geneSymbol, geneId, probeId].filter(isDefined);
};

/**
 * Builds the Plotly `hovertemplate` for a single volcano-plot point.
 *
 * Lines, separated by `<br>`:
 * - if the point has a probe/peak/protein/metabolite id: that id in bold, followed by a
 *   `gene: …` line when a gene symbol (or, failing that, a gene id) is available;
 * - otherwise: the gene symbol (or gene id) in bold, when either is available;
 * - always: the Plotly placeholders `%{x}` and `%{y}`.
 * The trailing empty `<extra></extra>` hides Plotly's secondary (trace-name) hover box.
 *
 * @param d - the point to describe
 * @returns the hover template string
 */
export const getHoverTemplate = (d: IndexedDataPoint) => {
    const primaryId =
        blankToNull(d.probe_id) ?? blankToNull(d.peak_id) ?? blankToNull(d.protein_id) ?? blankToNull(d.metabolite_id);
    const hasPrimaryId =
        isNotBlank(d.peak_id) || isNotBlank(d.probe_id) || isNotBlank(d.protein_id) || isNotBlank(d.metabolite_id);
    // Fall back to gene_id when the symbol is missing/blank, rather than rendering "undefined"/"null".
    const geneName = isBlank(d.Gene_Symbol) ? blankToNull(d.gene_id) : blankToNull(d.Gene_Symbol);

    const lines: string[] = [];
    if (hasPrimaryId) {
        if (primaryId) {
            lines.push(`<b>${primaryId}</b>`);
        }
        if (geneName) {
            lines.push(`gene: ${geneName}`);
        }
    } else if (geneName) {
        // No primary id: the gene itself becomes the (bold) heading, with no leading blank line.
        lines.push(`<b>${geneName}</b>`);
    }
    lines.push('x: %{x}', 'y: %{y}');

    return `${lines.join('<br>')}<extra></extra>`;
};

/**
 * Returns the plot x value of row `index`, i.e. its `Log2_Fold_Change` entry.
 *
 * @param index - row to read
 * @param items - column-oriented sample data
 */
export const getXValueAtIndex = ({ index, items }: { index: number; items: ArrowData<DEGSample> }) => {
    const foldChanges = items.Log2_Fold_Change;
    return foldChanges[index];
};

/**
 * Materializes row `index` of the column-oriented `items` as one point object.
 *
 * Every column of `items` is copied onto the point under its own key. Then the plot
 * coordinates are set: `x` is the row's log2 fold change (via `getXValueAtIndex`) and `y` is
 * -log10 of its adjusted p-value.
 *
 * @param index - row to read
 * @param items - column data, indexed as `items[column][index]`
 * @returns the row as an `IndexedDataPoint`, including its `index`
 */
export const getSampleForIndex = ({
    index,
    items,
}: {
    index: number;
    items: ArrowData<DEGSample>;
}): IndexedDataPoint => {
    const sample: Partial<IndexedDataPoint> = { index };
    for (const column of Object.keys(items)) {
        sample[column] = items[column][index];
    }

    sample.x = getXValueAtIndex({ index, items });

    const adjustedPValue = items.Adj_P_Value[index];
    // -log10(0) is infinite, so such rows are pinned to y = 0 instead.
    // NOTE(review): that puts the most significant possible p-value at the bottom of the plot, and
    // prepareData's `y >= threshold` test then treats it as non-significant for any threshold
    // below 1 — confirm this is intended.
    sample.y = adjustedPValue === 0 ? 0 : -Math.log10(adjustedPValue);

    return sample as IndexedDataPoint;
};

/**
 * Resolves the four volcano-plot colors. A user override from `custom_color_json`
 * (keys `increased`, `decreased`, `vertical`, `horizontal`) wins. Otherwise the
 * matching color of the active theme (`getVolcanoPlotThemeColors`) is used.
 *
 * @param display - display options; may be null/undefined, in which case the default theme applies
 * @returns `{ positive, negative, vertical, horizontal }` colors
 */
export const getLineColors = (display?: VolcanoPlotDisplayOption | null) => {
    const theme = getVolcanoPlotThemeColors(display?.theme_color);
    const overrides = display?.custom_color_json ?? {};

    const positive = overrides['increased'] ?? theme.positive.color;
    const negative = overrides['decreased'] ?? theme.negative.color;
    const vertical = overrides['vertical'] ?? theme.vertical.color;
    const horizontal = overrides['horizontal'] ?? theme.horizontal.color;

    return { positive, negative, vertical, horizontal };
};

/**
 * Derives the threshold-line settings of a volcano plot from its display options.
 *
 * Thresholds are converted into plot coordinates. The y axis is -log10(adjusted p-value),
 * so the p-value cutoff becomes -log10(threshold). The x axis is log2 fold change, so both
 * fold-change cutoffs go through log2.
 *
 * @param display - the plot's display options
 * @returns theme colors, resolved line colors, the line visibility flags
 *   (`showFoldChangeLines` defaults to false) and the transformed thresholds
 */
export const getVolcanoPlotLineSettings = (display: VolcanoPlotDisplayOption) => {
    const volcanoThemeColors = getVolcanoPlotThemeColors(display.theme_color);

    const negLog10PvalFillThreshold = -Math.log10(display.pval_fill_threshold);
    const log2FoldChangeFillThresholdLower = Math.log2(display.fold_change_fill_threshold_lower);
    const log2FoldChangeFillThresholdUpper = Math.log2(display.fold_change_fill_threshold_upper);

    const lineColors = getLineColors(display);

    return {
        volcanoThemeColors,
        showFoldChangeLines: display.show_fold_change_lines ?? false,
        showPvalLine: display.show_pval_line,
        negLog10PvalFillThreshold,
        log2FoldChangeFillThresholdUpper,
        log2FoldChangeFillThresholdLower,
        lineColors,
    };
};

/** One entry of `layout.shapes`, as typed by react-plotly's `PlotParams`. */
type LayoutShape = NonNullable<PlotParams['layout']['shapes']>[number];

/**
 * Builds a single threshold-line shape.
 *
 * @param geometry - placement: one axis spans the plot in `'paper'` coordinates (0 to 1),
 *   the other pins the threshold value in data coordinates
 * @param option - per-line overrides taken from `customOptions` (may be undefined)
 * @param defaultColor - color used when `option` does not set one
 * @param thresholdLineWidth - global width override; takes precedence over `option.width`
 * @returns a Plotly line shape
 */
const buildThresholdLineShape = (
    geometry: Pick<LayoutShape, 'xref' | 'yref' | 'x0' | 'x1' | 'y0' | 'y1'>,
    option: Partial<NonNullable<LayoutShape['line']>> | undefined,
    defaultColor: string,
    thresholdLineWidth: number | undefined,
): LayoutShape => ({
    type: 'line',
    ...geometry,
    line: {
        width: thresholdLineWidth ?? option?.width ?? LINE_WIDTH,
        color: option?.color ?? defaultColor,
        // `?? undefined` normalizes an explicit null to undefined.
        dash: option?.dash ?? undefined,
    },
});

/**
 * Builds the Plotly layout for the volcano plot. It covers sizing, axis titles and styling,
 * the optional fold-change (vertical) and p-value (horizontal) threshold lines, and any
 * user-fixed axis ranges.
 *
 * @param display - display options; thresholds, line visibility and axis bounds are read from it
 * @param size - plot size; width/height are left undefined when absent
 * @param publicationMode - when true, axes are drawn in black and axis titles default to black
 * @param stats - data extent; supplies any axis bound the user left unset when the other bound is fixed
 * @param dragMode - passed through as `layout.dragmode`
 * @param thresholdLineWidth - overrides the width of every threshold line
 * @param customOptions - per-line overrides (`width`, `color`, `dash`) under the keys
 *   `foldChangeThresholdLineLower`, `foldChangeThresholdLineUpper` and `pvalThresholdLine`
 * @param stylingOptions - font overrides for the axis titles (falsy values fall back to defaults)
 * @returns the Plotly layout object
 */
export const buildPlotlyLayout = ({
    display,
    size,
    publicationMode,
    stats,
    dragMode,
    thresholdLineWidth,
    customOptions,
    stylingOptions,
}: {
    display: VolcanoPlotDisplayOption;
    size: ObservedSize | undefined;
    publicationMode?: boolean;
    stats: PlotRange | null | undefined;
    dragMode?: DragMode;
    thresholdLineWidth?: number;
    customOptions?: any;
    stylingOptions?: CustomPlotStylingOptions;
}) => {
    const {
        lineColors,
        showFoldChangeLines,
        showPvalLine,
        negLog10PvalFillThreshold,
        log2FoldChangeFillThresholdLower,
        log2FoldChangeFillThresholdUpper,
    } = getVolcanoPlotLineSettings(display);
    const lineOptions = customOptions ?? {};
    const axisColor = publicationMode ? '#000' : PaletteColors.gray500.color;
    const defaultTitleColor = publicationMode ? 'black' : undefined;

    const shapes: LayoutShape[] = [];

    if (showFoldChangeLines) {
        // Vertical lines at the lower and upper fold-change thresholds, spanning the full plot height.
        shapes.push(
            buildThresholdLineShape(
                {
                    yref: 'paper',
                    y0: 0,
                    y1: 1,
                    x0: log2FoldChangeFillThresholdLower,
                    x1: log2FoldChangeFillThresholdLower,
                },
                lineOptions['foldChangeThresholdLineLower'],
                lineColors.vertical,
                thresholdLineWidth,
            ),
            buildThresholdLineShape(
                {
                    yref: 'paper',
                    y0: 0,
                    y1: 1,
                    x0: log2FoldChangeFillThresholdUpper,
                    x1: log2FoldChangeFillThresholdUpper,
                },
                lineOptions['foldChangeThresholdLineUpper'],
                lineColors.vertical,
                thresholdLineWidth,
            ),
        );
    }

    if (showPvalLine) {
        // Horizontal line at the p-value threshold, spanning the full plot width.
        shapes.push(
            buildThresholdLineShape(
                {
                    xref: 'paper',
                    y0: negLog10PvalFillThreshold,
                    y1: negLog10PvalFillThreshold,
                    x0: 0,
                    x1: 1,
                },
                lineOptions['pvalThresholdLine'],
                lineColors.horizontal,
                thresholdLineWidth,
            ),
        );
    }

    const yaxis: NonNullable<PlotParams['layout']['yaxis']> = {
        rangemode: 'tozero',
        showgrid: false,
        title: '<b>-log₁₀(adjusted p-value)</b>',
        titlefont: {
            color: stylingOptions?.yaxis?.fontColor || defaultTitleColor,
            size: stylingOptions?.yaxis?.fontSize || 18,
            family: stylingOptions?.yaxis?.fontFamily || 'Arial',
        },
        color: axisColor,
        zeroline: true,
        linewidth: 1,
        ticks: 'outside',
        type: 'linear',
    };

    const xaxis: NonNullable<PlotParams['layout']['xaxis']> = {
        showgrid: false,
        title: '<b>log₂ fold change</b>',
        titlefont: {
            color: stylingOptions?.xaxis?.fontColor || defaultTitleColor,
            size: stylingOptions?.xaxis?.fontSize || 18,
            family: stylingOptions?.xaxis?.fontFamily || 'Arial',
        },
        color: axisColor,
        ticks: 'outside',
        type: 'linear',
    };

    // NOTE(review): when `stats` is missing, an unset lower bound stays undefined while an unset
    // upper bound falls back to 0 — confirm this asymmetry is intended.
    if (isDefined(display.x_axis_end) || isDefined(display.x_axis_start)) {
        logger.info('stats', stats);
        xaxis.range = [display.x_axis_start ?? stats?.xMin, display.x_axis_end ?? stats?.xMax ?? 0];
    }

    if (isDefined(display.y_axis_end) || isDefined(display.y_axis_start)) {
        // Drop the 'tozero' range mode set above whenever the user fixes part of the y range.
        yaxis.rangemode = 'normal';
        yaxis.range = [display.y_axis_start ?? stats?.yMin, display.y_axis_end ?? stats?.yMax ?? 0];
    }

    const layout: PlotParams['layout'] = {
        autosize: false,
        width: size?.width,
        height: size?.height,
        showlegend: false,
        margin: {
            t: 0,
            r: 0,
            b: 48,
        },
        yaxis,
        xaxis,
        shapes,
        dragmode: dragMode,
    };

    return layout;
};

/**
 * Returns the text label for a point, or `null` when the point should not be labeled.
 *
 * Checks, in order:
 * 1. the point's primary id (`getTargetData`) is in `display.selected_targets` → label is the primary id;
 * 2. the point's symbol (`getTargetSymbol`) equals one of `biomarkerTargetNames`, ignoring case
 *    → label is the primary id;
 * 3. (legacy) any of the point's symbol-like ids (`getAllPossibleTargetSymbols`) is in
 *    `display.selected_targets` → label is the symbol.
 */
export const getTextLabel = ({
    d,
    display,
    biomarkerTargetNames,
}: {
    d: IndexedDataPoint;
    display: VolcanoPlotDisplayOption;
    biomarkerTargetNames?: string[];
}): string | null => {
    const selectedTargets = display.selected_targets;
    const primaryId = getTargetData(d);
    const symbol = getTargetSymbol(d);

    if (primaryId && selectedTargets?.includes(primaryId)) {
        return primaryId;
    }

    const matchesBiomarker = (candidate: string) =>
        biomarkerTargetNames?.some((name) => name.toLowerCase() === candidate.toLowerCase());
    if (symbol && matchesBiomarker(symbol)) {
        return primaryId;
    }

    // Legacy support: selections that refer to the point by a gene symbol or other alias
    // instead of its primary id.
    const aliases = getAllPossibleTargetSymbols(d);
    if (symbol && selectedTargets?.some((target) => aliases.includes(target))) {
        return symbol;
    }

    return null;
};
/**
 * Groups every row of `items` into the point sets drawn by the volcano plot and computes the
 * data extent.
 *
 * Rows that get a label (see `getTextLabel`) go into `increasedLabeled` / `decreasedLabeled`.
 * Every other row is "significant" when both hold:
 * - its y (-log10 adjusted p-value) is at least the p-value threshold;
 * - its x (log2 fold change) lies strictly outside the [lower, upper] fold-change thresholds.
 *
 * Significant rows go into `significantUp` / `significantDown` and the rest into
 * `nonSignificantUp` / `nonSignificantDown`. In each pair, x < 0 selects the
 * "decreased"/"down" group and any other x the "increased"/"up" group.
 *
 * @param items - column-oriented sample data
 * @param display - display options supplying the p-value and fold-change thresholds
 * @param biomarkerTargetNames - extra target names that force a label (matched case-insensitively)
 * @returns the six point groups plus `stats`, the min/max over all finite x and y values.
 *   When no finite value exists for an axis (e.g. empty input), that axis falls back to the
 *   first row's values.
 */
export const prepareData = ({
    items,
    display,
    biomarkerTargetNames,
}: {
    items: ArrowData<DEGSample>;
    display: VolcanoPlotDisplayOption;
    biomarkerTargetNames?: string[];
}) => {
    const increasedLabeled: IndexedDataPoint[] = [];
    const decreasedLabeled: IndexedDataPoint[] = [];
    const significantUp: IndexedDataPoint[] = [];
    const significantDown: IndexedDataPoint[] = [];
    const nonSignificantUp: IndexedDataPoint[] = [];
    const nonSignificantDown: IndexedDataPoint[] = [];

    // Start from an empty (inverted) range and widen it with every finite coordinate, so that
    // `stats` spans the whole data set rather than just the first row.
    const stats: PlotRange = {
        yMin: Number.POSITIVE_INFINITY,
        yMax: Number.NEGATIVE_INFINITY,
        xMin: Number.POSITIVE_INFINITY,
        xMax: Number.NEGATIVE_INFINITY,
    };

    // Thresholds converted into plot coordinates (y = -log10 p-value, x = log2 fold change).
    const negLog10PvalFillThreshold = -1 * Math.log10(display.pval_fill_threshold);
    const log2FoldChangeFillThresholdLower = Math.log2(display.fold_change_fill_threshold_lower);
    const log2FoldChangeFillThresholdUpper = Math.log2(display.fold_change_fill_threshold_upper);

    items.Adj_P_Value.forEach((_, i) => {
        const indexedPoint = getSampleForIndex({ index: i, items });
        const { x, y } = indexedPoint;

        // Non-finite coordinates (e.g. y = Infinity for a null p-value) cannot be drawn and
        // must not stretch the axis range.
        if (Number.isFinite(x)) {
            stats.xMin = Math.min(stats.xMin, x);
            stats.xMax = Math.max(stats.xMax, x);
        }
        if (Number.isFinite(y)) {
            stats.yMin = Math.min(stats.yMin, y);
            stats.yMax = Math.max(stats.yMax, y);
        }

        const label = getTextLabel({ d: indexedPoint, display, biomarkerTargetNames });
        const isSignificant =
            y >= negLog10PvalFillThreshold && (x < log2FoldChangeFillThresholdLower || x > log2FoldChangeFillThresholdUpper);

        if (label) {
            if (x < 0) decreasedLabeled.push(indexedPoint);
            else increasedLabeled.push(indexedPoint);
        } else if (isSignificant) {
            if (x < 0) significantDown.push(indexedPoint);
            else significantUp.push(indexedPoint);
        } else {
            if (x < 0) nonSignificantDown.push(indexedPoint);
            else nonSignificantUp.push(indexedPoint);
        }
    });

    // No finite coordinate on an axis (e.g. empty input): fall back to the first row's values.
    if (stats.xMin > stats.xMax || stats.yMin > stats.yMax) {
        const first = getSampleForIndex({ index: 0, items });
        if (stats.xMin > stats.xMax) {
            stats.xMin = first.x;
            stats.xMax = first.x;
        }
        if (stats.yMin > stats.yMax) {
            stats.yMin = first.y;
            stats.yMax = first.y;
        }
    }

    return {
        increasedLabeled,
        decreasedLabeled,
        significantUp,
        significantDown,
        nonSignificantUp,
        nonSignificantDown,
        stats,
    };
};

/**
 * Splits the rows of `items` into two groups by the sign of x (log2 fold change) and computes
 * the data extent. Its name suggests it predates `prepareData` — confirm whether it is still used.
 *
 * Rows whose y is ±Infinity, or whose `Adj_P_Value` is not defined, are skipped entirely.
 *
 * @param items - column-oriented sample data
 * @returns `decreasedItems` (x < 0), `increasedItems` (all other kept rows) and `stats`. `stats`
 *   is the min/max over the finite x and y values of kept rows. When an axis has no such value,
 *   it falls back to the first row's values.
 */
export const prepareDataOld = ({ items }: { items: ArrowData<DEGSample> }) => {
    const increasedItems: IndexedDataPoint[] = [];
    const decreasedItems: IndexedDataPoint[] = [];

    // Widened only by rows that survive the filter below, so a skipped first row (e.g. with an
    // infinite y) can no longer leak into the extent.
    const stats: PlotRange = {
        yMin: Number.POSITIVE_INFINITY,
        yMax: Number.NEGATIVE_INFINITY,
        xMin: Number.POSITIVE_INFINITY,
        xMax: Number.NEGATIVE_INFINITY,
    };

    items.Adj_P_Value.forEach((_, i) => {
        const indexedPoint = getSampleForIndex({ index: i, items });
        const { x, y } = indexedPoint;

        // TODO: figure out if we should be filtering these points or not
        if (Math.abs(y) === Number.POSITIVE_INFINITY || !isDefined(indexedPoint.Adj_P_Value)) {
            return;
        }

        if (x < 0) {
            decreasedItems.push(indexedPoint);
        } else {
            increasedItems.push(indexedPoint);
        }

        if (Number.isFinite(y)) {
            stats.yMin = Math.min(stats.yMin, y);
            stats.yMax = Math.max(stats.yMax, y);
        }
        if (Number.isFinite(x)) {
            stats.xMin = Math.min(stats.xMin, x);
            stats.xMax = Math.max(stats.xMax, x);
        }
    });

    // Nothing usable on an axis (e.g. empty input): fall back to the first row's values.
    if (stats.xMin > stats.xMax || stats.yMin > stats.yMax) {
        const first = getSampleForIndex({ index: 0, items });
        if (stats.xMin > stats.xMax) {
            stats.xMin = first.x;
            stats.xMax = first.x;
        }
        if (stats.yMin > stats.yMax) {
            stats.yMin = first.y;
            stats.yMax = first.y;
        }
    }

    return { increasedItems, decreasedItems, stats };
};
