import * as d3 from 'd3';
import BasePlotBuilder, {
    PlotMargin,
    ConstructorParams as BaseParams,
} from '@/src/components/plots/builders/BasePlotBuilder';
import { ArrowPlotData, GenericCellData, PlotMapping, RidgePlotItem } from '@models/ExperimentData';
import ViolinPlotDisplayOption from '@/src/models/plotDisplayOption/ViolinPlotDisplayOption';
import { getPlotPalette } from '@/src/components/ColorPaletteUtil';
import PaletteColors from '@/src/components/PaletteColors';
import { wrapTextNode } from '@/src/components/plots/PlotUtil';
import { formatStringToNumberWithSeparator, roundToDecimal } from '@/src/util/StringUtil';

/** A kernel density estimate expressed as a list of `[value, density]` pairs. */
type Density = [number, number][];

// Bandwidth handed to the Epanechnikov kernel when estimating the density of each violin
const bandwidth = 0.75;

// horizontal_compression_threshold decides whether a violin is collapsed into a single vertical line.
// A violin collapses when the fraction of its cells with a value of exactly zero is strictly greater than
// this threshold (0.75 => collapse when more than 75% of the cells are zero-expressing).
// Set to 1 for no compression.
const horizontal_compression_threshold = 0.75;

// vertical_compression_factor is the multiplier applied to values that fall below the vertical compression
// threshold before they are positioned on the y-axis, squeezing the bottom of each violin.
// Set to 1 for no compression, 0.5 for 50% compression, etc.
const vertical_compression_factor = 0.2;

/** Constructor parameters for `ViolinPlotBuilder`: the base builder params specialised to arrow cell data. */
export type ConstructorParams = BaseParams<ArrowPlotData<GenericCellData>>;
/**
 * Draws a violin plot: one mirrored kernel-density shape per cell group along a band x-axis,
 * with the plotted value on a shared linear y-axis.
 *
 * Groups can be relabelled (custom legend), recoloured (custom colours), reordered and hidden
 * through the plot's display options.
 */
export default class ViolinPlotBuilder extends BasePlotBuilder<ArrowPlotData<GenericCellData>> {
    constructor(params: ConstructorParams) {
        super(params);
    }

    /** Returns the fixed margins reserved around the plotting area. */
    calculateMargins(): PlotMargin {
        return { top: 20, right: 30, bottom: 87, left: 80 };
    }

    /**
     * Builds the HTML shown in the tooltip when hovering a violin.
     *
     * @param analysisShortname - Analysis type; selects the wording of the "% of cells" line.
     * @param groupName - Display name of the hovered group. It is HTML-escaped before interpolation.
     * @param groupCellCount - Total number of cells in the group.
     * @param percentExpressed - Percentage (0-100) of cells in the group with a value above zero.
     * @returns An HTML fragment intended to be injected into the tooltip container.
     */
    getTooltipContent = (
        analysisShortname: string,
        groupName: string,
        groupCellCount: number,
        percentExpressed: number,
    ): string => {
        // The group name originates from the data / a custom legend and the result is injected
        // with `.html()`, so markup characters must be neutralised.
        const htmlEscapes: Record<string, string> = {
            '&': '&amp;',
            '<': '&lt;',
            '>': '&gt;',
            '"': '&quot;',
            "'": '&#39;',
        };
        const escapeHtml = (text: string): string =>
            String(text).replace(/[&<>"']/g, (char) => htmlEscapes[char] ?? char);

        // Every analysis other than the module score uses the marker-expression wording.
        const percentExpressedLabel =
            analysisShortname === 'seurat_module_score' ? 'with module activity' : 'expressing target';

        return `
<span class="block font-semibold text-dark">Group: ${escapeHtml(groupName)}</span>
<span class="block font-semibold text-dark">Total number of cells: ${formatStringToNumberWithSeparator(groupCellCount)}</span>
<span class="block font-semibold text-dark">% of cells ${percentExpressedLabel}: ${roundToDecimal(percentExpressed, { decimals: 1 })}%</span>
`;
    };

    /**
     * Renders the axes, axis labels and one violin per visible group into `this.svg`.
     */
    draw = () => {
        const items = this.data.items as RidgePlotItem[];
        const display = this.plot.display as ViolinPlotDisplayOption;
        const shortname = this.plot.analysis_type;
        const dataMap = this.data.plot_mapping as PlotMapping;
        const customColors = display.custom_color_json ?? {};
        const theme = display.theme_color;
        const customLegend = display.custom_legend_json ?? {};
        const groupDisplayOrder = display.group_display_order ?? [];
        const showCellsGroups = display.groups ?? {};
        const stylingOptions = this.stylingOptions;
        const publicationMode = this.publicationMode;
        const tooltipContainer = this.tooltip;

        const svg = this.svg;
        const margin = this.calculateMargins();
        const width = this.width - margin.left - margin.right;
        const height = this.height - margin.top - margin.bottom;

        //// helper functions ////
        // Returns a function that evaluates the kernel density of a sample at each of the given points.
        const kernelDensityEstimator = (kernel: (x: number) => number, points: number[]) => {
            return (sample: number[]): Density =>
                points.map((point) => [point, d3.mean(sample, (v: number) => kernel(point - v)) ?? 0]);
        };

        // Epanechnikov kernel with bandwidth k.
        const kernelEpanechnikov = (k: number) => {
            return (v: number) => {
                const u = v / k;
                return Math.abs(u) <= 1 ? (0.75 * (1 - u * u)) / k : 0;
            };
        };

        // Custom colour for the group when one is configured, otherwise the palette colour at index i (cycled).
        const getAreaColor = (groupId: string, i: number) => {
            const { colors } = getPlotPalette(theme);
            return customColors?.[groupId] ?? colors[i % colors.length]?.color;
        };

        // Values below the returned threshold are squeezed by vertical_compression_factor.
        // The threshold is the 10th percentile of the strictly positive values, or 0 when there are none.
        const getVerticalCompressionThreshold = (values: number[]) => {
            const nonZeroValues = values.filter((value) => value > 0);
            if (nonZeroValues.length === 0) {
                console.warn('Violin plot: No non-zero values available for percentile calculation.');
                return 0;
            }
            // `filter` returned a fresh array, so sorting in place does not touch the caller's data.
            const sortedValues = nonZeroValues.sort((a, b) => a - b);
            return sortedValues[Math.floor(0.1 * sortedValues.length)];
        };
        // end helper functions //

        // Group data by (display) group name
        const groupedData = d3.group(
            items.map((item) => ({
                group_id: item[dataMap.group_id],
                group_name: customLegend?.[item[dataMap.group_id]] ?? item[dataMap.group_name],
                value: +item[dataMap.value],
            })),
            (d) => d.group_name,
        );
        // Order groups by their position in the configured display order
        // (ids missing from the order get index -1 and therefore sort first).
        const sortedGroupedData = Array.from(groupedData.entries()).sort(
            (g1, g2) => groupDisplayOrder.indexOf(g1[1][0].group_id) - groupDisplayOrder.indexOf(g2[1][0].group_id),
        );
        // When a visibility map is configured, keep only the groups flagged as shown in it.
        const hasGroupFilter = Object.keys(showCellsGroups).length > 0;
        const filteredGroupedData = sortedGroupedData.filter(
            ([, cells]) => !hasGroupFilter || Boolean(showCellsGroups?.[cells[0].group_id]),
        );
        const filteredKeysSet = new Set(filteredGroupedData.map(([k]) => k));

        // Set the dimensions and margins of the graph
        this._svg?.attr('width', this.width).attr('height', this.height);
        svg.attr('transform', `translate(${margin.left},${margin.top})`);

        // Build X-Axis (visible groups only)
        const x = d3
            .scaleBand()
            .domain(filteredGroupedData.map((d) => d[0])) // Use the sorted keys
            .range([0, width])
            .padding(0.1);
        const xAxis = svg
            .append('g')
            .attr('class', 'legend-x-axis xAxis')
            .attr('transform', `translate(0,${height})`)
            .call(
                d3
                    .axisBottom(x)
                    .tickSizeOuter(0)
                    .tickFormat((value) => {
                        // Long names are truncated on screen, but kept in full when exporting.
                        const maxLength = 16;

                        if (value.length > maxLength && !this.isExportMode) {
                            return `${value.substring(0, maxLength - 3)}...`;
                        }
                        return value;
                    }),
            );
        xAxis.selectAll('.domain').attr('stroke', publicationMode ? 'black' : 'currentColor');

        // Y scale: shared by every violin and computed over all groups, including hidden ones
        const allValues = sortedGroupedData.flatMap(([, cells]) => cells.map((cell) => cell.value));
        const globalMaxValue = d3.max(allValues) ?? 0;
        const yGlobal = d3
            .scaleLinear()
            .domain([0, globalMaxValue]) // Cover the full range of the data
            .range([height, 0]);
        const yAxis = svg.append('g').call(d3.axisLeft(yGlobal).tickSizeOuter(0));
        yAxis.selectAll('.domain').attr('stroke', publicationMode ? 'black' : 'currentColor');
        yAxis.selectAll('.tick text').attr('fill', publicationMode ? '#000' : 'currentColor');

        // Y-axis label text depends on the analysis type
        const yAxisLabelText = shortname === 'seurat_module_score' ? 'Module score' : 'Log-normalized expression';

        // Axis labels
        svg.append('text')
            .attr('class', `axis-label y-axis-label`)
            .attr('x', -height / 2)
            .attr('y', -45)
            .attr('fill', stylingOptions?.yaxis?.fontColor || (publicationMode ? 'black' : 'currentColor'))
            .style('font-size', stylingOptions?.yaxis?.fontSize || 18)
            .style('font-family', stylingOptions?.yaxis?.fontFamily || 'Arial')
            .attr('text-anchor', 'middle')
            .attr('transform', 'rotate(-90)')
            .text(yAxisLabelText)
            .call((g) => wrapTextNode(g, (height - margin.bottom) * 0.9));

        // NOTE(review): this lookup runs before the x-axis label is appended below, so it only finds a node
        // if one is already present in `svg`; otherwise the 22px fallback is used.
        const labelHeight = svg.select<SVGGElement>('.x-axis-label')?.node()?.getBoundingClientRect().height ?? 22;
        svg.append('text')
            .attr('class', `axis-label x-axis-label`)
            .attr('x', width / 2)
            .attr('y', height + labelHeight + 60)
            .attr('fill', stylingOptions?.xaxis?.fontColor || (publicationMode ? 'black' : 'currentColor'))
            .style('font-size', stylingOptions?.xaxis?.fontSize || 18)
            .style('font-family', stylingOptions?.xaxis?.fontFamily || 'Arial')
            .attr('text-anchor', 'middle')
            .text('Group')
            .call((g) => wrapTextNode(g, width * 0.8));

        // Rotate x-axis ticks for readability
        xAxis
            .selectAll('.tick text')
            .attr('fill', publicationMode ? '#000' : PaletteColors.gray500.color)
            .attr('text-anchor', 'end')
            .attr('transform', 'rotate(-45)')
            .attr('dx', '-0.5em')
            .attr('dy', '0.5em');

        // Group plotting helpers
        // Horizontal compression is on unless the option is explicitly set to a falsy, non-undefined value.
        const shouldCompress =
            display?.custom_options_json?.useHorizontalCompression ||
            display?.custom_options_json?.useHorizontalCompression === undefined;
        const horizontalCompressionThreshold = shouldCompress ? horizontal_compression_threshold : 1;
        const verticalCompressionThreshold = getVerticalCompressionThreshold(allValues);

        // Plot each group. Iterating the unfiltered list keeps the palette index (and so the colour)
        // of a group stable when other groups are hidden.
        sortedGroupedData.forEach(([key, values], i) => {
            if (!filteredKeysSet.has(key)) return;

            const groupValues = values.map((d) => d.value);
            const valuesMax = d3.max(groupValues) ?? 0;
            // Only used to generate evenly spaced sample points between 0 and the group's maximum.
            const sampleScale = d3.scaleLinear().domain([0, valuesMax]);
            const kde = kernelDensityEstimator(kernelEpanechnikov(bandwidth), sampleScale.ticks(512));
            const density: Density = kde(groupValues);

            // Scale mapping density to the half-width of the violin
            const maxDensity = d3.max(density, (d) => d[1]) ?? 0;
            const xNum = d3
                .scaleLinear()
                .domain([0, maxDensity])
                .range([0, x.bandwidth() / 2]);

            const numberOfValuesExpressed = values.filter((d) => d.value > 0).length; // Count of non-zero values for this group
            const totalNumberOfValues = values.length; // Total count of values for this group
            const percentExpressed = (numberOfValuesExpressed / totalNumberOfValues) * 100; // Percent of expressed values for this group
            const percentExpressingZero = values.filter((d) => d.value === 0).length / values.length;
            const violinShouldCollapse = percentExpressingZero > horizontalCompressionThreshold;

            // Horizontal centre of this group's band
            const xKeyCalc = (x(key) ?? 0) + x.bandwidth() / 2;
            // Vertical position of a density sample, squeezing values below the compression threshold
            const yKeyCalc = (d: [number, number]) =>
                yGlobal(d[0] < verticalCompressionThreshold ? d[0] * vertical_compression_factor : d[0]);
            // The grouped records expose the id under `group_id` (not under the source column key),
            // which is what the custom colour map is keyed by.
            const groupId = values[0].group_id;
            const groupName = key;

            // Draw the violin
            svg.append('path')
                .attr('fill', getAreaColor(groupId, i))
                .on('mouseover', (event) => {
                    tooltipContainer.transition().duration(50).style('opacity', 1);
                    tooltipContainer
                        .html(this.getTooltipContent(shortname, groupName, totalNumberOfValues, percentExpressed))
                        .style('left', `${event.pageX + 10}px`)
                        .style('top', `${event.pageY - 10}px`);
                })
                .on('mouseout', function () {
                    tooltipContainer.transition().style('opacity', 0);
                })
                .datum(density)
                .attr('stroke', '#000')
                .attr('stroke-width', 1)
                .attr(
                    'd',
                    d3
                        .area()
                        .y0(yKeyCalc)
                        .y1(yKeyCalc)
                        // A collapsed violin has zero width, leaving only its stroke as a vertical line.
                        .x0((d) => (violinShouldCollapse ? xKeyCalc : xKeyCalc - xNum(d[1])))
                        .x1((d) => (violinShouldCollapse ? xKeyCalc : xKeyCalc + xNum(d[1])))
                        .curve(d3.curveCatmullRom),
                );
        });
    };
}
