import HeatmapDisplayOption from '@models/plotDisplayOption/HeatmapDisplayOption';

import {
    axisBottom,
    axisLeft,
    BaseType,
    max,
    min,
    ScaleBand,
    scaleBand,
    scaleLinear,
    ScaleLinear,
    Selection,
} from 'd3';
import { PlotMargin } from '@components/plots/builders/BasePlotBuilder';
import { AXIS_LABEL_CLASSNAMES, AXIS_TITLE_CLASSNAMES, SummaryHeatmapColorScale } from '@models/PlotConfigs';
import { AssaySummaryAnalysis } from '@models/analysis/AssaySummaryAnalysis';
import { blankToNull, formatTableHeader, roundToDecimal } from '@util/StringUtil';
import { rotateXAxisLabels as drawRotatedXAxisLabels, wrapTextNode } from '@components/plots/PlotUtil';
import { FlatSample, GroupSeries } from '@components/plots/PlotTypes';
import SummaryAnalysisPlotBuilder, {
    ConstructorParams,
    GroupStat,
} from '@components/analysisCategories/summary/plots/builders/SummaryAnalysisPlotBuilder';
import { isDefined } from '@util/TypeGuards';
import cn from 'classnames';
import { getSummaryHeatmapColorScale, getPlotPalette } from '@components/ColorPaletteUtil';
import Logger from '@util/Logger';

// Module-level logger created with the label 'SummaryHeatmapPlotBuilder'.
const logger = Logger.make('SummaryHeatmapPlotBuilder');

/** Constructor arguments accepted by the heatmap builder; currently a plain alias of `ConstructorParams`. */
export type HeatmapConstructorArgs = ConstructorParams;
/** Characters that must be entity-encoded before text is placed inside tooltip markup. */
const TOOLTIP_HTML_ESCAPES: Record<string, string> = {
    '&': '&amp;',
    '<': '&lt;',
    '>': '&gt;',
    '"': '&quot;',
    "'": '&#39;',
};

/**
 * Stringifies `value` and entity-encodes HTML-significant characters so it can be
 * interpolated into markup passed to `selection.html()` without being parsed as HTML.
 *
 * @param value - any value; it is converted with `String(value)` first
 * @returns the escaped text
 */
function escapeTooltipHtml(value: unknown): string {
    return String(value).replace(/[&<>"']/g, (char) => TOOLTIP_HTML_ESCAPES[char] ?? char);
}

/**
 * Maps a group's position to a palette slot, wrapping around the whole palette.
 *
 * @param groupIndex - result of a `findIndex` lookup; `-1` (not found) is treated as `0`
 * @param paletteSize - number of colors available
 * @returns an index in `[0, paletteSize)`, or `0` when the palette size is not positive
 */
function paletteIndexForGroup(groupIndex: number, paletteSize: number): number {
    if (paletteSize <= 0) {
        return 0;
    }
    return Math.max(groupIndex, 0) % paletteSize;
}

/**
 * Builds a heatmap for a summary analysis: targets run along the x axis and either
 * individual samples or groups run along the y axis, depending on
 * `displayOptions.summarize_values_by`. Cells are colored by a five-stop linear color scale.
 */
export default class SummaryHeatmapPlotBuilder extends SummaryAnalysisPlotBuilder {
    /**
     * Scales shared by every draw step; rebuilt in `draw()` once the axes have been measured.
     *
     * NOTE(review): `color` is typed with a numeric output even though its range is built
     * from color strings in `makeScales` — confirm whether `ScaleLinear<string, string>`
     * was intended before tightening this.
     */
    scales: {
        x: ScaleBand<string>;
        y: ScaleBand<string>;
        color: ScaleLinear<string, number>;
    };

    /** All samples bucketed by their `group_name`. */
    samplesByGroupName: Record<string, FlatSample[]> = {};
    /** Group statistics captured at construction time; used for the grouped color domain. */
    statsByGroup: GroupStat[];
    /** Width, in pixels, of the group/sample color strip drawn next to the heatmap. */
    legendWidth = 16;

    protected constructor(options: HeatmapConstructorArgs) {
        super(options);
        this.samplesByGroupName = this.makeSamplesByGroupName(this.allSamples);
        this.statsByGroup = this.makeFlatGroupStats();
        this.scales = this.makeScales({ margin: this.margin });
    }

    /** Palette resolved from the current theme color. */
    get palette() {
        return getPlotPalette(this.themeColor);
    }

    /** The plot's display options, viewed as heatmap options. */
    get displayOptions(): HeatmapDisplayOption {
        return this.plot.display as HeatmapDisplayOption;
    }

    /** Factory wrapper around the protected constructor. */
    static make(options: HeatmapConstructorArgs) {
        return new SummaryHeatmapPlotBuilder(options);
    }

    /** Largest sample value (missing values count as 0); 0 when there are no samples. */
    get maxSampleValue() {
        return max(this.allSamples, (d) => d.value ?? 0) ?? 0;
    }

    /** Smallest sample value (missing values count as 0); 0 when there are no samples. */
    get minSampleValue() {
        return min(this.allSamples, (d) => d.value ?? 0) ?? 0;
    }

    /**
     * Value range used for the color scale. The range always contains 0. When the plot is
     * grouped it is derived from the selected group statistic, otherwise from raw samples.
     */
    get yDomain(): { yMin: number; yMax: number } {
        if (!this.isGrouped) {
            return {
                yMin: Math.min(this.minSampleValue, 0),
                yMax: Math.max(this.maxSampleValue, 0),
            };
        }

        const lowestStat = min(this.statsByGroup, this.getStatValue) ?? 0;
        const highestStat = max(this.statsByGroup, this.getStatValue) ?? 0;
        return { yMin: Math.min(lowestStat, 0), yMax: Math.max(highestStat, 0) };
    }

    /** Fixed starting margins; `draw()` later overwrites `left` and `bottom` on `this.margin`. */
    calculateMargins(): PlotMargin {
        return {
            top: 100,
            right: 55,
            bottom: 90,
            left: 110,
        };
    }

    /**
     * Buckets samples by `group_name`, preserving the input order within each bucket.
     *
     * @param allSamples - samples to bucket
     * @returns a record keyed by group name
     */
    protected makeSamplesByGroupName = (allSamples: FlatSample[]) => {
        const byGroup = {} as Record<string, FlatSample[]>;
        for (const sample of allSamples) {
            const bucket = byGroup[sample.group_name] ?? [];
            bucket.push(sample);
            byGroup[sample.group_name] = bucket;
        }
        return byGroup;
    };

    /** Per-target data (`data.target_groups`). */
    get targetData() {
        return this.data.target_groups;
    }

    /** The plot's analysis viewed as an assay summary analysis, or null when absent. */
    get analysis(): AssaySummaryAnalysis | null {
        return (this.plot?.analysis ?? null) as AssaySummaryAnalysis | null;
    }

    /** The analysis' `group_display_order`, or an empty list when unavailable. */
    get group_display_order() {
        return this.analysis?.group_display_order ?? [];
    }

    /** Groups of the first target, or an empty list when there are no targets. */
    get firstTargetGroups() {
        return this.targetData[0]?.groups ?? [];
    }

    /**
     * Label describing the units of the plotted values: a fixed label for CPM-normalized
     * analyses, otherwise the first non-blank display name found on the experiment's units.
     */
    get unitsLabel() {
        if (this.plot.analysis_type === 'assay_summary_cpm_normalized') {
            return 'CPM-normalized counts';
        }
        const dataUnits = this.experiment?.assay_data_units;
        return blankToNull(dataUnits?.units_display_name) ?? blankToNull(dataUnits?.units?.display_name);
    }

    /** Names of all targets, in data order. */
    getTargetNames = () => {
        return this.targetData.map((t) => t.target_name);
    };

    /** The domain bounds plus its quartile points, in ascending order of declaration. */
    get yBreakpoints() {
        const { yMax, yMin } = this.yDomain;
        const span = yMax - yMin;

        return {
            yMin,
            point25: yMin + span / 4,
            point50: yMin + span / 2,
            point75: yMin + (3 * span) / 4,
            yMax,
        };
    }

    /**
     * Builds the band scales for both axes and the color scale.
     *
     * @param margin - margins that bound the drawable area
     * @returns the x, y and color scales
     */
    makeScales = ({ margin }: { margin: PlotMargin }) => {
        // Grouped plots reserve room between the y axis and the cells for the group strip.
        const extraX = this.isGrouped ? this.legendWidth * 2 : 0;

        const xScaleBand = scaleBand()
            .domain(this.getTargetNames())
            .range([margin.left + extraX, this.width - margin.right]);

        // Rows are groups when summarizing, otherwise one row per distinct sample id.
        const yValues = this.isGrouped
            ? this.firstTargetGroups.map((d) => d.group_name)
            : [...new Set(this.allSamples.map((d) => d.sample_id))];
        const yScaleBand = scaleBand()
            .domain(yValues)
            .range([this.height - margin.bottom, margin.top]);

        const { yMin, point25, point50, point75, yMax } = this.yBreakpoints;
        const heatmapColors = getSummaryHeatmapColorScale(
            this.displayOptions.heatmap_scale_color as SummaryHeatmapColorScale,
        );
        const colorScale = scaleLinear<string, number>()
            .range(heatmapColors.map((c) => c.color))
            .domain([yMin, point25, point50, point75, yMax]);

        return { x: xScaleBand, y: yScaleBand, color: colorScale };
    };

    /**
     * Picks the statistic selected by `summarize_values_by`.
     *
     * @param stat - statistics for one group/target pair
     * @returns the median for 'median', the mean for 'mean' and 'none', otherwise undefined
     */
    getStatValue = (stat: GroupStat): number | undefined => {
        switch (this.displayOptions.summarize_values_by) {
            case 'median':
                return stat.median;
            case 'mean':
            case 'none':
                return stat.mean;
            default:
                return undefined;
        }
    };

    /** Replaces any existing `.y-axis` group with a freshly rendered left axis. */
    appendYAxis = () => {
        const scales = this.scales;
        const height = this.height;
        const margin = this.margin;
        // selectAll (not select) so that every stale axis is cleared, matching appendXAxis.
        this.svg.selectAll('.y-axis').remove();
        const drawYAxis = (g: Selection<SVGGElement, unknown, BaseType, unknown>) => {
            const yAxisConfig = axisLeft(scales.y).tickSizeOuter(0);

            g.attr('transform', `translate(${margin.left},0)`)
                .attr('class', cn(AXIS_LABEL_CLASSNAMES, 'y-axis'))
                .call(yAxisConfig)
                .call((axisGroup) =>
                    // Title element rotated along the axis; no text content is set here.
                    axisGroup
                        .append('text')
                        .attr('x', -height / 2)
                        .attr('y', -margin.left + 20)
                        .attr('fill', 'currentColor')
                        .attr('text-anchor', 'middle')
                        .attr('transform', 'rotate(-90)')
                        .attr('class', AXIS_TITLE_CLASSNAMES),
                );

            const labels = g.selectAll<SVGTextElement, undefined>('.tick text');

            labels.call(wrapTextNode, 160);

            // The axis line itself is not shown; only tick labels remain.
            g.select('.domain').remove();
        };

        this.svg.append('g').call(drawYAxis);
    };

    /** Rotation, in degrees, handed to the x-axis label rotation helper. */
    get xAxisLabelRotation() {
        return 45;
    }

    /** Replaces any existing `.x-axis` group with a freshly rendered bottom axis. */
    appendXAxis = () => {
        const { x: xScale } = this.scales;
        const height = this.height;
        const margin = this.margin;
        const labelRotation = this.xAxisLabelRotation;
        this.svg.selectAll('.x-axis').remove();
        const drawXAxis = (g: Selection<SVGGElement, unknown, BaseType, unknown>) => {
            const xAxisConfig = axisBottom(xScale)
                .tickSize(12)
                .tickSizeOuter(0)
                .tickFormat((label) => formatTableHeader(label));

            g.attr('transform', `translate(0,${height - margin.bottom})`)
                .attr('class', cn(AXIS_LABEL_CLASSNAMES, 'x-axis'))
                .call(xAxisConfig);

            const labels = g.selectAll<SVGTextElement, undefined>('.tick text');

            labels.call(wrapTextNode, 160);
            if (labelRotation) {
                drawRotatedXAxisLabels(g, labelRotation);
            }
            g.select('.domain').remove();

            return g;
        };

        this.svg.append('g').call(drawXAxis);
    };

    /** Draws one cell per sample/target pair and wires up the hover tooltip. */
    drawSamples = () => {
        const scales = this.scales;
        const tooltipContainer = this.tooltip;
        this.svg
            .selectAll()
            .data(this.allSamples)
            .enter()
            .append('rect')
            .attr('x', (d) => scales.x(d.target_name) ?? 0)
            .attr('y', (d) => scales.y(d.sample_id) ?? 0)
            // +1 makes neighbouring cells overlap slightly.
            .attr('width', scales.x.bandwidth() + 1)
            .attr('height', scales.y.bandwidth() + 1)
            .attr('stroke-width', 0)
            // Cells without a value are left transparent, mirroring drawGroups.
            .style('fill', (d) => (isDefined(d.value) ? scales.color(d.value) : 'transparent'))
            .on('mousemove', (event, d) => {
                tooltipContainer.style('opacity', 1);
                tooltipContainer.style('white-space', 'normal');
                // Data-derived strings are escaped because the tooltip is populated via html().
                tooltipContainer
                    .html(
                        `<div class="white-space-normal">
<span class="block font-semibold text-dark">${escapeTooltipHtml(d.target_name)}</span>
<span class="block text-sm text-gray-600 word-break-none">sample ID: ${escapeTooltipHtml(d.sample_id)}</span>

<span class="block text-sm text-gray-600">group: ${escapeTooltipHtml(d.group_name)}</span>
<span class="block text-sm text-gray-600">value: ${roundToDecimal(d.value)}</span>
</div>`,
                    )
                    .style('left', `${event.pageX + 10}px`)
                    .style('top', `${event.pageY - 10}px`);
            });
    };

    /** Draws one cell per group/target pair and wires up the hover tooltip. */
    drawGroups = () => {
        const scales = this.scales;
        const stats = this.makeFlatGroupStats();
        const tooltipContainer = this.tooltip;
        const getStatValue = this.getStatValue;
        this.svg
            .selectAll()
            .data(stats)
            .enter()
            .append('rect')
            .attr('x', (d) => scales.x(d.target_name) ?? 0)
            .attr('y', (d) => scales.y(d.group_name) ?? 0)
            .attr('width', scales.x.bandwidth() + 1)
            .attr('height', scales.y.bandwidth() + 1)
            .style('fill', (d) => {
                const value = getStatValue(d);
                return isDefined(value) ? scales.color(value) : 'transparent';
            })
            // A `function` expression is required here: `this` must be the hovered rect.
            .on('mousemove', function (event, d) {
                tooltipContainer.style('opacity', 1);
                // Data-derived strings are escaped because the tooltip is populated via html().
                tooltipContainer
                    .html(
                        `
                        <div class="white-space-normal">
<span class="block font-semibold text-dark">${escapeTooltipHtml(d.target_name)}</span>
<span class="block text-sm text-gray-600">group: ${escapeTooltipHtml(d.group_name)}</span>
<span class="block text-sm text-gray-600">average: ${roundToDecimal(d.mean)}</span>
<span class="block text-sm text-gray-600">median: ${roundToDecimal(d.median)}</span>
</div>`,
                    )
                    .style('left', `${event.pageX + 15}px`)
                    .style('top', `${event.pageY - 10}px`);
                // Re-append the hovered cell so it becomes the last child of its parent.
                this.parentNode?.appendChild(this);
            });
    };

    /** Draws a color strip with one swatch per group, aligned with the group rows. */
    drawGroupLegend = () => {
        const scales = this.scales;
        const customColors = this.plot.display.custom_color_json ?? {};
        const sortedGroups = this.getSortedGroups();
        const palette = this.palette;
        const getColor = (group: GroupSeries): string => {
            // A custom color configured for the group id wins over the palette.
            const customColor = customColors[`${group.group_id}`];
            if (customColor) {
                return customColor;
            }
            const groupIndex = sortedGroups?.findIndex((g) => g.group_name === group.group_name) ?? 0;
            // Wrap around the full palette, the same way drawSampleLegend does.
            return palette.colors[paletteIndexForGroup(groupIndex, palette.colors.length)].color;
        };

        const groupX = this.isGrouped ? this.margin.left : this.width + this.legendWidth - this.margin.right;

        this.svg
            .selectAll()
            .data(sortedGroups)
            .enter()
            .append('rect')
            .attr('x', groupX)
            .attr('y', (d) => scales.y(d.group_name) ?? 0)
            .attr('width', this.legendWidth)
            .attr('height', scales.y.bandwidth() + 1)
            .style('fill', (d) => getColor(d));
    };

    /**
     * Draws a color strip to the right of the heatmap with one swatch per sample, colored
     * by the sample's group. Does nothing when there are no targets.
     */
    drawSampleLegend = () => {
        const firstTarget = this.targetData[0];
        if (!firstTarget) {
            return;
        }
        // Samples of a single target are enough to get one entry per sample row.
        const target = firstTarget.target_name;
        const samples: FlatSample[] = this.allSamples.filter((s) => s.target_name === target);
        logger.info('samples', samples);
        const scales = this.scales;
        const sortedGroups = this.getSortedGroups();
        const palette = this.palette;
        const customColors = this.plot.display.custom_color_json ?? {};
        const getColor = (sample: FlatSample) => {
            // A custom color configured for the group id wins over the palette.
            const customColor = customColors[`${sample.group_id}`];
            if (customColor) {
                return customColor;
            }
            const groupIndex = sortedGroups?.findIndex((g) => g.group_name === sample.group_name) ?? 0;
            return palette.colors[paletteIndexForGroup(groupIndex, palette.colors.length)].color;
        };

        this.svg
            .selectAll()
            .data(samples)
            .enter()
            .append('rect')
            .attr('x', this.width + this.legendWidth - this.margin.right)
            .attr('y', (d) => scales.y(d.sample_id) ?? 0)
            .attr('width', this.legendWidth)
            .attr('height', scales.y.bandwidth() + 0.08)
            .style('fill-opacity', '0.75')
            .style('fill', (d) => getColor(d));
    };

    /** Element id of the gradient definition used by the color-scale legend. */
    get legendId() {
        return `${this.plot.uuid}-linear-gradient`;
    }

    /** Width between the left and right margins. */
    get plotWidth() {
        return this.width - this.margin.left - this.margin.right;
    }

    /** True unless `summarize_values_by` is 'none'. */
    get isGrouped() {
        return this.displayOptions.summarize_values_by !== 'none';
    }

    /** Draws the units label, the horizontal gradient bar and its tick axis above the heatmap. */
    drawScaleLegend = () => {
        const width = this.plotWidth / (this.displayOptions?.is_full_width ? 2 : 1);
        const rectHeight = 20;
        const spacing = 40;
        const labelOffset = 6;
        const colors = this.scales.color.range();

        const gradient = this.svg.append('defs').append('linearGradient').attr('id', this.legendId);
        gradient.attr('x1', '0%').attr('y1', '0%').attr('x2', '100%').attr('y2', '0%');

        // Spread the stops evenly over 0–100% (0/25/50/75/100 for five colors).
        const lastStopIndex = Math.max(colors.length - 1, 1);
        colors.forEach((hex, i) => {
            gradient
                .append('stop')
                .attr('offset', `${(i / lastStopIndex) * 100}%`)
                .attr('stop-color', hex);
        });

        this.svg
            .append('text')
            .text(this.unitsLabel ?? 'Unknown units')
            .attr('x', this.margin.left + width / 2)
            .attr('y', this.margin.top - rectHeight - spacing)
            .attr('text-anchor', 'middle')
            .attr('fill', 'currentColor')
            .attr('class', '');

        this.svg
            .append('rect')
            .attr('class', 'legend-gradient')
            .attr('x', this.margin.left)
            .attr('y', this.margin.top - rectHeight - spacing + labelOffset)
            .attr('width', width)
            .attr('height', rectHeight)
            .style('fill', `url(#${this.legendId})`);

        const { yMin, yMax } = this.yDomain;
        const legendScale = scaleLinear()
            .domain([yMin, yMax])
            .range([0, width - 1]);

        // Listed explicitly in ascending order; a comparator-less sort() would order the
        // numbers as strings.
        const breakpoints = this.yBreakpoints;
        const ticks = [
            breakpoints.yMin,
            breakpoints.point25,
            breakpoints.point50,
            breakpoints.point75,
            breakpoints.yMax,
        ];

        this.svg.selectAll('.legend-x-axis').remove();
        const drawXAxis = (g: Selection<SVGGElement, unknown, BaseType, unknown>) => {
            g.attr('transform', `translate(${this.margin.left},${this.margin.top - spacing + labelOffset})`)
                .attr('class', cn(AXIS_LABEL_CLASSNAMES, 'legend-x-axis'))
                .call(axisBottom(legendScale).tickSizeOuter(0).tickValues(ticks));

            g.select('.domain').remove();

            return g;
        };

        this.svg.append('g').call(drawXAxis);
    };

    /**
     * Renders the whole plot. Axes are drawn twice: a first pass to measure how much room
     * their labels need, then — after the margins and scales are updated from those
     * measurements — a second pass at the final position, followed by cells and legends.
     */
    draw = () => {
        // Measuring pass.
        if (this.isGrouped) {
            this.appendYAxis();
        }
        this.appendXAxis();
        const yAxisWidth = this.svg.select<SVGGElement>('.y-axis')?.node()?.getBoundingClientRect().width ?? 0;
        const xAxisHeight = this.svg.select<SVGGElement>('.x-axis')?.node()?.getBoundingClientRect().height ?? 0;
        this.margin.left = yAxisWidth + 10;
        this.margin.bottom = xAxisHeight + 20;
        this.scales = this.makeScales({ margin: this.margin });

        // Final pass with the adjusted margins and scales.
        if (this.isGrouped) {
            this.appendYAxis();
        }
        this.appendXAxis();

        if (this.isGrouped) {
            this.drawGroups();
            this.drawGroupLegend();
        } else {
            this.drawSamples();
            this.drawSampleLegend();
        }
        this.drawScaleLegend();

        // Hide the tooltip when the pointer leaves an element of the svg.
        this.svg.on('mouseout', () => {
            this.tooltip.style('opacity', 0);
        });
    };
}
