import BasePlotBuilder, {
    ConstructorParams as BaseParams,
    PlotMargin,
} from '@components/plots/builders/BasePlotBuilder';
import { isBlank, isNotBlank } from '@util/StringUtil';
import { SurvivalData, SurvivalSample } from '@models/ExperimentData';
import { SurvivalAnalysis } from '@models/analysis/SurvivalAnalysis';
import KaplanMeierCurveDisplayOption from '@models/plotDisplayOption/KaplanMeierCurveDisplayOption';
import Plot from '@models/Plot';
import * as d3 from 'd3';
import { axisBottom, axisLeft, BaseType, max, scaleLinear, ScaleLinear, Selection } from 'd3';
import Experiment from '@models/Experiment';
import {
    AXIS_LABEL_CLASSNAMES,
    AXIS_LABEL_PUBLICATION_CLASSNAMES,
    AXIS_TITLE_CLASSNAMES,
    AXIS_TITLE_PUBLICATION_CLASSNAMES,
} from '@models/PlotConfigs';
import Logger from '@util/Logger';
import cn from 'classnames';
import { getPlotPalette } from '@components/ColorPaletteUtil';
import { CustomPlotStylingOptions } from '@components/analysisCategories/comparative/plots/PlotlyVolcanoPlotUtil';

// Module-scoped logger, tagged with this builder's name.
const logger = Logger.make('KaplanMeierCurvePlotBuilder');

/** Plot model specialised to a survival analysis with Kaplan–Meier display options. */
type PlotType = Plot<SurvivalAnalysis, KaplanMeierCurveDisplayOption>;
/** Data payload consumed by the builder. */
type DataType = SurvivalData;
/**
 * Constructor options for the Kaplan–Meier builder: the base builder params
 * plus optional custom styling (`null` when no custom styling is supplied).
 */
export type ConstructorParams = BaseParams<DataType, PlotType> & {
    stylingOptions: CustomPlotStylingOptions | null;
};

/**
 * D3 builder that renders a Kaplan–Meier style survival plot into `this.svg`:
 * a percentage y-axis fixed to [0, 1], a linear x-axis starting at 0, one
 * step-after line per sample group, vertical tick marks for samples flagged
 * with `tick === 1`, and a hover overlay that tracks the cursor.
 *
 * Call {@link KaplanMeierCurvePlotBuilder.make} to create an instance and
 * `draw()` to render.
 */
export default class KaplanMeierCurvePlotBuilder extends BasePlotBuilder<
    SurvivalData,
    Plot<SurvivalAnalysis, KaplanMeierCurveDisplayOption>
> {
    /** Optional font overrides for the axis titles; `null` keeps the defaults. */
    stylingOptions: CustomPlotStylingOptions | null;

    /** Factory wrapper around the protected constructor. */
    static make(options: ConstructorParams) {
        return new KaplanMeierCurvePlotBuilder(options);
    }

    /**
     * Resolves the axis titles for a plot.
     *
     * - `yLabel`: the configured `y_axis_title`, or `'Survival probability'`
     *   when it is `null`/`undefined`.
     * - `xLabel`: the configured `x_axis_title` when it is not blank, otherwise
     *   `'OS'` followed by the analysis' time units in parentheses (when set).
     */
    static getAxisTitles(plot: PlotType) {
        const display = plot.display as KaplanMeierCurveDisplayOption;
        const analysis = plot.analysis as SurvivalAnalysis;
        const yLabel = display.y_axis_title ?? 'Survival probability';

        const units = analysis?.time_to_event_units;
        const buildXLabel = () => {
            if (!isBlank(display.x_axis_title)) {
                return display.x_axis_title;
            }
            const unitsPart = isBlank(units) ? '' : `(${units})`;
            // trim() drops the trailing space left behind when there are no units.
            return `OS ${unitsPart}`.trim();
        };
        const xLabel = buildXLabel();

        return { xLabel, yLabel };
    }

    /** Buckets samples by `sample_group_id`, preserving input order within each bucket. */
    private static groupSamplesByGroupId(samples: SurvivalSample[]): Record<number, SurvivalSample[]> {
        const grouped: Record<number, SurvivalSample[]> = {};
        for (const sample of samples) {
            const bucket = grouped[sample.sample_group_id] ?? [];
            bucket.push(sample);
            grouped[sample.sample_group_id] = bucket;
        }
        return grouped;
    }

    svg: Selection<BaseType, unknown, BaseType, unknown>;
    plot: PlotType;
    data: DataType;
    experiment: Experiment;
    tooltipId: string;
    width: number;
    height: number;
    tooltip: Selection<HTMLDivElement, unknown, HTMLElement, unknown>;
    margin: PlotMargin;
    /** d3-format specifier used for the y-axis tick labels. */
    yAxisFormat = '~%';
    scales: { xScale: ScaleLinear<number, number>; yScale: ScaleLinear<number, number> };

    protected constructor(options: ConstructorParams) {
        super(options);
        this.scales = this.makeScales();
        this.stylingOptions = options.stylingOptions;
    }

    /**
     * Builds the x/y linear scales from the current `width`, `height` and
     * `margin`. The y range is inverted so that larger values map higher up.
     */
    makeScales() {
        const { xMin, xMax } = this.xDomain;
        const xScale = scaleLinear()
            .domain([xMin, xMax])
            .rangeRound([this.margin.left, this.width - this.margin.right]);

        const yScale = scaleLinear()
            .domain([0, 1])
            .rangeRound([this.height - this.margin.bottom, this.margin.top]);
        return { xScale, yScale };
    }

    /** Shortcut for the plot's display options. */
    get displayOptions() {
        return this.plot.display;
    }

    /** Initial margins; the left margin is wider when a y-axis title is present. */
    calculateMargins(): PlotMargin {
        return {
            top: 20,
            right: 30,
            bottom: 50,
            left: isNotBlank(this.yAxisTitle) ? 90 : 70,
        };
    }

    /** The y domain is fixed to the [0, 1] range. */
    get yDomain(): { yMin: number; yMax: number } {
        return { yMax: 1, yMin: 0 };
    }

    /** The x domain runs from 0 to the largest `x_value` in the data (0 when there is no data). */
    get xDomain(): { xMin: number; xMax: number } {
        return { xMin: 0, xMax: max(this.data.items, (d) => d.x_value) ?? 0 };
    }

    get yAxisTitle() {
        const { yLabel } = KaplanMeierCurvePlotBuilder.getAxisTitles(this.plot);
        return yLabel;
    }

    get xAxisTitle() {
        const { xLabel } = KaplanMeierCurvePlotBuilder.getAxisTitles(this.plot);
        return xLabel;
    }

    /** Removes any existing y-axis and draws a fresh one together with its rotated title. */
    appendYAxis = () => {
        const styles = this.stylingOptions?.yaxis;
        const { yScale } = this.scales;
        const height = this.height;
        const margin = this.margin;
        const yAxisTitle = this.yAxisTitle;
        const yAxisFormat = this.yAxisFormat;
        const yAxisConfig = axisLeft(yScale).ticks(null, yAxisFormat).tickSizeOuter(0);

        this.svg.selectAll('.y-axis').remove();
        this.svg.selectAll('.y-axis-title').remove();

        this.svg
            .append('g')
            .attr('transform', `translate(${margin.left},0)`)
            .attr(
                'class',
                cn(this.publicationMode ? AXIS_LABEL_PUBLICATION_CLASSNAMES : AXIS_LABEL_CLASSNAMES, 'y-axis'),
            )
            // NOTE(review): this runs before the axis is rendered into the freshly
            // appended group, so no `.domain` element exists yet and nothing is
            // removed. Kept in place to preserve the current rendering; move it
            // after `.call(yAxisConfig)` if hiding the axis line is intended.
            .call((g) => g.select('.domain').remove())
            .call(yAxisConfig)
            .call((g) => {
                // Measure the rendered axis so the title sits just outside the tick labels.
                const yAxisWidth = g.node()?.getBoundingClientRect().width ?? 100;
                g.append('text')
                    .attr('x', -(height - margin.bottom) / 2)
                    .attr('y', -yAxisWidth - 12)
                    .attr('fill', styles ? styles.fontColor : 'currentColor')
                    .style('font-size', styles ? styles.fontSize : '18')
                    .style('font-family', styles ? styles.fontFamily : 'Arial')
                    .attr('text-anchor', 'middle')
                    .attr('transform', 'rotate(-90)')
                    .attr(
                        'class',
                        cn(
                            this.publicationMode ? AXIS_TITLE_PUBLICATION_CLASSNAMES : AXIS_TITLE_CLASSNAMES,
                            'y-axis-title',
                        ),
                    )
                    .text(yAxisTitle);
            });
    };

    /** Removes any existing x-axis and draws a fresh one together with its centred title. */
    appendXAxis = () => {
        const styles = this.stylingOptions?.xaxis;
        const { xScale } = this.scales;
        const height = this.height;
        const width = this.width;
        const margin = this.margin;

        this.svg.selectAll('.x-axis').remove();
        this.svg.selectAll('.x-axis-title').remove();

        const axisContainer = this.svg.append('g').attr('class', cn('x-axis'));

        axisContainer
            .append('g')
            .attr('transform', `translate(0,${height - margin.bottom})`)
            .attr(
                'class',
                cn(this.publicationMode ? AXIS_LABEL_PUBLICATION_CLASSNAMES : AXIS_LABEL_CLASSNAMES, 'x-axis'),
            )
            .call(axisBottom(xScale).tickSize(12).tickSizeOuter(0))
            .call((g) => {
                // Measure the rendered axis so the title sits below the tick labels.
                const xAxisHeight = g.node()?.getBoundingClientRect().height ?? 100;
                return g
                    .append('text')
                    .attr('x', width / 2)
                    .attr('y', xAxisHeight + 24)
                    .attr('fill', styles ? styles.fontColor : 'currentColor')
                    .style('font-size', styles ? styles.fontSize : '18')
                    .style('font-family', styles ? styles.fontFamily : 'Arial')
                    .attr('text-anchor', 'middle')
                    .attr(
                        'class',
                        cn(
                            this.publicationMode ? AXIS_TITLE_PUBLICATION_CLASSNAMES : AXIS_TITLE_CLASSNAMES,
                            'x-axis-title',
                        ),
                    )
                    .text(`${this.xAxisTitle ?? ''}`);
            });
    };

    /**
     * Group ids in display order: the explicit `group_display_order` when it is
     * non-empty, otherwise the ids of the analysis' groups (or an empty list).
     */
    get orderedGroupIds(): number[] {
        const group_display_order = this.displayOptions.group_display_order ?? [];
        if (group_display_order.length > 0) {
            return group_display_order;
        }
        return this.plot.analysis?.groups?.map((g) => g.id) ?? [];
    }

    /**
     * Resolves the stroke colour for a group.
     *
     * A custom colour from `custom_color_json` wins. Otherwise the colour is
     * picked from the theme palette by the group's position in
     * {@link orderedGroupIds}, wrapping around when there are more groups than
     * palette colours. Groups missing from the display order are logged and
     * fall back to the first palette colour.
     */
    getLineColor = (groupId: number): string => {
        const customColors = this.displayOptions.custom_color_json ?? {};
        const customColor = customColors[`${groupId}`];
        if (customColor) {
            return customColor;
        }

        const { colors } = getPlotPalette(this.themeColor);
        const sortedGroupIndex = this.orderedGroupIds.indexOf(groupId);
        if (sortedGroupIndex < 0) {
            logger.warn('Unable to find group in the group display order', {
                groupId,
                plotId: this.plot.uuid,
                displayOptions: this.displayOptions,
            });
        }
        if (colors.length === 0) {
            // Nothing to pick from; inherit the surrounding text colour instead of throwing.
            return 'currentColor';
        }
        const groupIndex = Math.max(0, sortedGroupIndex);
        // Wrap over the full palette so every colour (including the last) is reachable
        // and a single-colour palette does not produce a modulo-by-zero index.
        return colors[groupIndex % colors.length].color;
    };

    /**
     * Returns the points used to draw one group's line: the given samples plus
     * a synthetic starting point at `x_value = 0`, `y_value = 1` (cloned from
     * the first sample), sorted by ascending `x_value`. The input array is not
     * mutated. Returns an empty array for empty input.
     */
    getLineItems(samples: SurvivalSample[] = []): SurvivalSample[] {
        if (samples.length === 0) {
            return [];
        }
        const firstSample = { ...samples[0] };

        firstSample.y_value = 1;
        firstSample.x_value = 0;
        return [firstSample, ...samples].sort((d1, d2) => d1.x_value - d2.x_value);
    }

    /** All samples bucketed by `sample_group_id`. */
    getDataByGroup = () => {
        return KaplanMeierCurvePlotBuilder.groupSamplesByGroupId(this.data.items);
    };

    /** Redraws one step-after path per sample group inside a `.data-lines` container. */
    appendLines = () => {
        this.svg.selectAll('.data-lines').remove();
        const lineContainer = this.svg.append('g').attr('class', 'data-lines');

        const { xScale, yScale } = this.scales;
        // The generator only depends on the scales, so it is shared by every group.
        const line = d3
            .line<SurvivalSample>()
            .x((d) => xScale(d.x_value))
            .y((d) => yScale(d.y_value))
            .curve(d3.curveStepAfter);

        for (const [key, samples] of Object.entries(this.getDataByGroup())) {
            const lineColor = this.getLineColor(Number(key));
            lineContainer
                .append('path')
                .datum(this.getLineItems(samples))
                // Both class names are applied together; setting `class` twice would
                // discard the first value.
                .attr('class', 'data-line plot-line')
                .attr('d', line)
                .attr('fill', 'none')
                .attr('stroke', lineColor)
                .attr('stroke-width', 2);
        }
    };

    /**
     * Redraws the vertical tick marks inside a `.data-ticks` container: one
     * short line per sample whose `tick` equals 1, coloured like its group.
     */
    appendTicks = () => {
        this.svg.selectAll('.data-ticks').remove();
        const tickContainer = this.svg.append('g').attr('class', 'data-ticks');

        const { xScale, yScale } = this.scales;
        // Each mark extends this many pixels above and below the sample's y position.
        const tickHeight = 8;
        const tickedByGroup = KaplanMeierCurvePlotBuilder.groupSamplesByGroupId(
            this.data.items.filter((s) => s.tick === 1),
        );

        for (const [key, samples] of Object.entries(tickedByGroup)) {
            const lineColor = this.getLineColor(Number(key));
            tickContainer
                .append('g')
                .selectAll('line')
                .data(samples)
                .enter()
                .append('line')
                .attr('x1', (d) => xScale(d.x_value))
                .attr('x2', (d) => xScale(d.x_value))
                .attr('y1', (d) => yScale(d.y_value) - tickHeight)
                .attr('y2', (d) => yScale(d.y_value) + tickHeight)
                .attr('stroke', lineColor)
                .attr('stroke-width', 1);
        }
    };

    /**
     * Builds the hover overlay: a vertical guide line following the cursor and,
     * per group, a circle with a percentage label positioned on the group's
     * curve at the cursor's x position.
     */
    appendMouseEvents = () => {
        const svg = this.svg;
        const { xScale, yScale } = this.scales;
        const width = this.width - this.margin.right - this.margin.left;
        const height = this.height - this.margin.top - this.margin.bottom;

        const dataByGroup = this.getDataByGroup();
        const groupIds = Object.keys(dataByGroup).map(Number);

        // Build each group's sorted line items once instead of on every mouse move.
        const lineItemsByGroup: Record<number, SurvivalSample[]> = {};
        for (const groupId of groupIds) {
            lineItemsByGroup[groupId] = this.getLineItems(dataByGroup[groupId]);
        }

        // Drop any overlay left by a previous call so repeated calls do not stack up.
        svg.selectAll('.mouse-over-effects').remove();
        const mouseG = svg.append('g').attr('class', 'mouse-over-effects');

        // Vertical guide line that follows the mouse.
        mouseG
            .append('path')
            .attr('class', 'mouse-line stroke-current text-[#171717]')
            .style('stroke', 'currentColor')
            .style('stroke-width', '1px')
            .style('opacity', '0');

        const mousePerLine = mouseG
            .selectAll('.mouse-per-line')
            .data(groupIds)
            .enter()
            .append('g')
            .attr('class', 'mouse-per-line')
            .attr('transform', `translate(${this.margin.left},${this.margin.top})`);

        mousePerLine
            .append('circle')
            .attr('r', 7)
            .style('stroke', (groupId) => this.getLineColor(groupId))
            .style('fill', 'none')
            .style('stroke-width', '1px')
            .style('opacity', '0');

        mousePerLine
            .append('text')
            .attr('transform', 'translate(10,6)')
            .attr('style', 'text-shadow: white 0px 0px 2px');

        const lineBisector = d3.bisector<SurvivalSample, number>((d) => d.x_value).right;

        // Shows or hides the guide line, circles and labels together.
        const setHoverOpacity = (opacity: '0' | '1') => {
            svg.select('.mouse-line').style('opacity', opacity);
            svg.selectAll('.mouse-per-line circle').style('opacity', opacity);
            svg.selectAll('.mouse-per-line text').style('opacity', opacity);
        };

        // A transparent rect catches the mouse events, since a bare <g> cannot.
        mouseG
            .append('svg:rect')
            .attr('width', width)
            .attr('height', height)
            .attr('fill', 'none')
            .attr('pointer-events', 'all')
            .attr('transform', `translate(${this.margin.left},${this.margin.top})`)
            .on('mouseout', () => setHoverOpacity('0'))
            .on('mouseover', () => setHoverOpacity('1'))
            .on('mousemove', (e: MouseEvent) => {
                // The pointer is relative to the translated rect; shift it back into svg space.
                const [mouseX] = d3.pointer(e);
                const cursorX = mouseX + this.margin.left;
                const x0 = xScale.invert(cursorX);

                svg.select('.mouse-line').attr(
                    'd',
                    `M${cursorX},${height + this.margin.top} ${cursorX},${this.margin.top - 6}`,
                );

                // Line items for a group plus the bisect index for the cursor position.
                // The search starts at index 1, i.e. after the synthetic starting point.
                const bisectGroup = (groupId: number) => {
                    const lineItems = lineItemsByGroup[groupId] ?? [];
                    return { lineItems, index: lineBisector(lineItems, x0, 1) };
                };

                // The sample whose step is active at the cursor: the last one at or before it.
                const sampleAtCursor = (groupId: number) => {
                    const { lineItems, index } = bisectGroup(groupId);
                    return lineItems[Math.max(index - 1, 0)];
                };

                // Past the horizontal midpoint the label flips to the left of the circle.
                const isHalfX = cursorX > this.width / 2;

                svg.selectAll<BaseType, number>('.mouse-per-line')
                    .attr('opacity', (groupId) => {
                        // Hide the marker once the cursor is past the group's last sample.
                        const { lineItems, index } = bisectGroup(groupId);
                        return lineItems[index] ? 1 : 0;
                    })
                    .attr('transform', (groupId) => {
                        const d0 = sampleAtCursor(groupId);
                        return `translate(${cursorX},${yScale(d0.y_value)})`;
                    })
                    .select('text')
                    .text((groupId) => {
                        const d0 = sampleAtCursor(groupId);
                        return `${(d0.y_value * 100).toFixed(1)}%`;
                    })
                    .attr('text-anchor', isHalfX ? 'end' : 'start')
                    .attr('transform', isHalfX ? 'translate(-10,6)' : 'translate(10,6)');
            });
    };

    /**
     * Renders the full plot. The axes are drawn twice: a first pass to measure
     * their rendered size, then — after the left/bottom margins and the scales
     * are updated from those measurements — a second pass at the final layout,
     * followed by the lines, tick marks and hover overlay.
     */
    draw = () => {
        this.clearAll();

        // First pass: render the axes only so they can be measured.
        this.appendYAxis();
        this.appendXAxis();

        const yAxisWidth = this.svg.select<SVGGElement>('.y-axis')?.node()?.getBoundingClientRect().width ?? 0;
        const xAxisHeight = this.svg.select<SVGGElement>('.x-axis')?.node()?.getBoundingClientRect().height ?? 0;
        const xAxisTitleHeight =
            this.svg.select<SVGGElement>('.x-axis-title')?.node()?.getBoundingClientRect().height ?? 0;
        this.margin.left = yAxisWidth;
        this.margin.bottom = xAxisHeight + xAxisTitleHeight;
        this.scales = this.makeScales();

        // Second pass: redraw the axes against the adjusted margins and scales.
        this.appendYAxis();
        this.appendXAxis();

        this.appendLines();
        this.appendTicks();

        this.appendMouseEvents();
    };
}
