import {
    AXIS_LABEL_CLASSNAMES,
    AXIS_LABEL_PUBLICATION_CLASSNAMES,
    AXIS_TITLE_CLASSNAMES,
    AXIS_TITLE_PUBLICATION_CLASSNAMES,
} from '@models/PlotConfigs';
import React, { useMemo } from 'react';
import * as d3 from 'd3';
import { DEGData } from '@models/ExperimentData';
import { createPlotTooltip, getVolcanoPlotThemeColors } from '@components/plots/PlotUtil';
import VolcanoPlotDisplayOption from '@models/plotDisplayOption/VolcanoPlotDisplayOption';
import Logger from '@util/Logger';
import { DataPoint, prepareDEGData } from '@components/analysisCategories/comparative/plots/VolcanoPlotUtil';
import cn from 'classnames';
import { blankToNull, isBlank, isNotBlank } from '@util/StringUtil';
import DynamicPlotContainer, { DrawChartFn } from '@components/plots/DynamicPlotContainer';
import { isDefined } from '@util/TypeGuards';
import {
    CustomPlotStylingOptions,
    getTargetData,
} from '@components/analysisCategories/comparative/plots/PlotlyVolcanoPlotUtil';

/** Module-scoped logger, created under the name 'VolcanoPlot'. */
const logger = Logger.make('VolcanoPlot');

/**
 * Fraction of each axis' data span (|max - min|) that is added as padding
 * beyond the data extremes when no explicit axis start/end is configured.
 */
const AXIS_PADDING_PERCENT = 0.05;

/** Stroke and fill colour applied to dots whose target is listed in `selected_targets`. */
const DOT_HIGHLIGHT_COLOR = '#7BD2F1';

/**
 * Props for the volcano plot view.
 *
 * `customPlotStylingOptions` supplies the font colour, size and family of the
 * axis titles; pass `null` to fall back to the built-in defaults.
 */
type Props = {
    customPlotStylingOptions: CustomPlotStylingOptions | null;
};

/** Replacement entities for characters that are significant in HTML markup. */
const TOOLTIP_HTML_ESCAPES: Record<string, string> = {
    '&': '&amp;',
    '<': '&lt;',
    '>': '&gt;',
    '"': '&quot;',
    "'": '&#39;',
};

/**
 * Escapes HTML-significant characters so a data-derived value can be safely
 * interpolated into the tooltip markup. `null`/`undefined` become an empty string.
 */
const escapeTooltipText = (value: unknown): string =>
    String(value ?? '').replace(/[&<>"']/g, (ch) => TOOLTIP_HTML_ESCAPES[ch] ?? ch);

/**
 * Renders a volcano plot (log₂ fold change on x, -log₁₀ adjusted p-value on y).
 *
 * The component builds a memoised `drawChart` callback and hands it to
 * `DynamicPlotContainer`. Each invocation of the callback clears the previously
 * drawn groups, axis titles and threshold lines from the SVG and redraws
 * everything from the current plot context.
 *
 * @param customPlotStylingOptions Optional font colour/size/family overrides for
 *   the axis titles; when `null` (or when an axis entry is missing) the defaults
 *   `currentColor` / `18` / `Arial` are used for that axis.
 */
const VolcanoPlotView = ({ customPlotStylingOptions }: Props) => {
    const drawChart = useMemo<DrawChartFn>(
        () =>
            ({ svgSelection: svg, size, context, tooltipId }) => {
                const { publicationMode, plot, plotData: data } = context;
                const options = plot.display as VolcanoPlotDisplayOption;
                if (!options || !data) {
                    return;
                }
                const themeColor = options.theme_color;
                const showTargetLabels = options.custom_options_json?.show_target_labels;

                // Thresholds are configured in linear space; the plot works in log space.
                const showPvalLine = options.show_pval_line;
                const negLog10PvalFillThreshold = -1 * Math.log10(options.pval_fill_threshold);

                const showFoldChangeLines = options.show_fold_change_lines ?? false;
                const log2FoldChangeFillThresholdLower = Math.log2(options.fold_change_fill_threshold_lower);
                const log2FoldChangeFillThresholdUpper = Math.log2(options.fold_change_fill_threshold_upper);

                // Per-plot custom colours take precedence over the theme colours.
                const customColors = options.custom_color_json ?? {};
                const volcanoThemeColors = getVolcanoPlotThemeColors(themeColor);

                const positive = customColors['increased'] ?? volcanoThemeColors.positive.color;
                const negative = customColors['decreased'] ?? volcanoThemeColors.negative.color;
                const vertical = customColors['vertical'] ?? volcanoThemeColors.vertical.color;
                const horizontal = customColors['horizontal'] ?? volcanoThemeColors.horizontal.color;

                const preparedData = prepareDEGData(data as DEGData, options);
                if (!preparedData) {
                    logger.warn('no prepared data found, removing chart.');
                    svg.selectAll('g').remove();
                    return;
                }

                const tooltipContainer = createPlotTooltip(tooltipId);
                const allSamples = preparedData.items;
                const { x: xStats, y: yStats } = preparedData;

                svg.selectAll('g').remove();
                const height = size.height;
                const width = size.width;
                const margin = { top: 20, right: 30, bottom: 60, left: 40 };

                // Pad the data extent unless the user pinned the axis bounds explicitly.
                const yPadding = Math.abs(yStats.max - yStats.min) * AXIS_PADDING_PERCENT;
                const xPadding = Math.abs(xStats.max - xStats.min) * AXIS_PADDING_PERCENT;

                const yMin = options.y_axis_start ?? Math.min(yStats.min - yPadding, 0);
                const yMax = options.y_axis_end ?? yStats.max + yPadding;

                const xMin = options.x_axis_start ?? xStats.min - xPadding;
                const xMax = options.x_axis_end ?? xStats.max + xPadding;

                const yScale = d3
                    .scaleLinear()
                    .domain([yMin, yMax])
                    .rangeRound([height - margin.bottom, margin.top]);

                const xScale = d3
                    .scaleLinear()
                    .domain([xMin, xMax])
                    .rangeRound([margin.left, width - margin.right]);

                const xAxis = (g) =>
                    g
                        .attr('transform', `translate(0,${height - margin.bottom})`)
                        .attr('class', publicationMode ? AXIS_LABEL_PUBLICATION_CLASSNAMES : AXIS_LABEL_CLASSNAMES)
                        .call(d3.axisBottom(xScale).tickSizeOuter(0));

                svg.selectAll('.axis-label').remove();
                svg.selectAll('.y-axis').remove();

                const axisTitleClass = `axis-label ${
                    publicationMode ? AXIS_TITLE_PUBLICATION_CLASSNAMES : AXIS_TITLE_CLASSNAMES
                }`;
                const xTitleStyle = customPlotStylingOptions?.xaxis;
                const yTitleStyle = customPlotStylingOptions?.yaxis;

                // Add x-axis title
                svg.append('text')
                    .attr('class', axisTitleClass)
                    .attr('x', width / 2)
                    .attr('y', height - 5)
                    .attr('fill', xTitleStyle ? xTitleStyle.fontColor : 'currentColor')
                    .style('font-size', xTitleStyle ? xTitleStyle.fontSize : '18')
                    .style('font-family', xTitleStyle ? xTitleStyle.fontFamily : 'Arial')
                    .attr('text-anchor', 'middle')
                    .text(`log₂ fold change`);

                // An explicit axis.tickFormat() replaces any specifier passed to axis.ticks(),
                // so the specifier is applied through the scale's own formatter instead.
                const yAxisFormat = yMax > 10_000 ? '.1e' : ',f';
                const formatYTick = yScale.tickFormat(10, yAxisFormat);
                const yAxis = (g) =>
                    g
                        // NOTE(review): this runs before the axis is rendered into `g`, so there is
                        // no `.domain` element to remove yet; kept as-is to preserve the current look.
                        .call((g) => g.select('.domain').remove())
                        // The y-axis is drawn at x = 0 rather than at the left edge.
                        .attr('transform', `translate(${xScale(0)},0)`)
                        .attr(
                            'class',
                            `y-axis pointer-events-none ${
                                publicationMode ? AXIS_LABEL_PUBLICATION_CLASSNAMES : AXIS_LABEL_CLASSNAMES
                            }`,
                        )
                        .call(
                            d3
                                .axisLeft(yScale)
                                .ticks(10)
                                // The zero tick is left unlabelled.
                                .tickFormat((d) => (d !== 0 ? formatYTick(d) : '')),
                        )
                        .call((g) => g.select('.y-axis > .tick:first-of-type').remove());

                // Add y-axis label
                svg.append('text')
                    .attr('class', axisTitleClass)
                    .attr('x', -height / 2)
                    .attr('y', margin.left - 20)
                    .attr('fill', yTitleStyle ? yTitleStyle.fontColor : 'currentColor')
                    .style('font-size', yTitleStyle ? yTitleStyle.fontSize : '18')
                    .style('font-family', yTitleStyle ? yTitleStyle.fontFamily : 'Arial')
                    .attr('text-anchor', 'middle')
                    .attr('transform', 'rotate(-90)')
                    .text(`-log₁₀(adjusted p-value)`);

                // Append the X and Y axis
                svg.append('g').call(xAxis);
                svg.append('g').call(yAxis);

                // Optionally draw a horizontal line for the p-value threshold.
                // Any line from a previous draw is always cleared first.
                svg.selectAll('.line-pval').remove();
                if (showPvalLine) {
                    svg.append('line')
                        .attr('class', `line-pval `)
                        .attr('stroke', horizontal)
                        .style('stroke-width', 2)
                        .attr('x1', xScale(xStats.min))
                        .attr('y1', yScale(negLog10PvalFillThreshold))
                        .attr('x2', xScale(xStats.max))
                        .attr('y2', yScale(negLog10PvalFillThreshold));
                }

                // Optionally draw vertical lines at the fold change thresholds.
                // Any lines from a previous draw are always cleared first.
                svg.selectAll('.line-fold-change').remove();
                if (showFoldChangeLines) {
                    // Lower bound line
                    svg.append('line')
                        .attr('class', `line-fold-change`)
                        .attr('stroke', vertical)
                        .style('stroke-width', 2)
                        .attr('x1', xScale(log2FoldChangeFillThresholdLower))
                        .attr('y1', yScale(yMin))
                        .attr('x2', xScale(log2FoldChangeFillThresholdLower))
                        .attr('y2', yScale(yMax));

                    // Upper bound line
                    svg.append('line')
                        .attr('class', `line-fold-change `)
                        .attr('stroke', vertical)
                        .style('stroke-width', 2)
                        .attr('x1', xScale(log2FoldChangeFillThresholdUpper))
                        .attr('y1', yScale(yMin))
                        .attr('x2', xScale(log2FoldChangeFillThresholdUpper))
                        .attr('y2', yScale(yMax));
                }

                /** Stroke colour: highlight for selected targets, otherwise by sign of x. */
                const getCircleStrokeColor = (d: DataPoint) => {
                    const data = getTargetData(d) ?? '';
                    if (options.selected_targets?.includes(data)) {
                        return DOT_HIGHLIGHT_COLOR;
                    }

                    return d.x > 0 ? positive : negative;
                };

                /**
                 * Fill colour: highlight for selected targets; the stroke colour when the point
                 * passes both the p-value and a fold-change threshold; otherwise white.
                 */
                const getCircleFillColor = (d: DataPoint) => {
                    const data = getTargetData(d) ?? '';
                    if (options.selected_targets?.includes(data)) {
                        return DOT_HIGHLIGHT_COLOR;
                    }

                    const currentColor = getCircleStrokeColor(d);

                    return (d.x > log2FoldChangeFillThresholdUpper && d.y > negLog10PvalFillThreshold) ||
                        (d.x < log2FoldChangeFillThresholdLower && d.y > negLog10PvalFillThreshold)
                        ? currentColor
                        : 'white';
                };

                const DOT_RADIUS = 4.5;

                // Draw Scatter Plot
                // Add dots
                svg.append('g')
                    .selectAll('dot')
                    .data(allSamples)
                    .enter()
                    .append('circle')
                    .attr('cx', function (d) {
                        return xScale(d.x);
                    })
                    .attr('cy', function (d) {
                        return yScale(d.y);
                    })
                    .attr('r', DOT_RADIUS)
                    .attr('stroke', getCircleStrokeColor)
                    .attr('stroke-width', 1)
                    .attr('class', '')
                    .style('fill-opacity', '0.75')
                    .style('fill', getCircleFillColor)
                    .on('mouseover', function (event, d) {
                        const circle = d3.select(this);
                        circle
                            .style('cursor', 'crosshair')
                            .style('fill-opacity', '1')
                            .attr('stroke', 'red')
                            .style('fill', 'red');
                        tooltipContainer.style('opacity', 1);

                        // Heading is the probe/peak id when present; the gene symbol heading is
                        // hidden in that case and the gene is shown on its own "gene:" line.
                        const primaryId = blankToNull(d.probe_id) ?? blankToNull(d.peak_id) ?? '';
                        const geneSymbol = blankToNull(d.Gene_Symbol) ?? '';
                        const geneName = blankToNull(d.Gene_Symbol ?? d.gene_id);
                        const symbolClass = cn('block font-semibold text-dark', {
                            hidden: isNotBlank(d.peak_id) || isNotBlank(d.probe_id),
                        });
                        const geneClass = cn('block text-sm text-gray-600', {
                            hidden: (isBlank(d.peak_id) && isBlank(d.probe_id)) || !geneName,
                        });

                        // Data-derived text is escaped because it is injected via .html().
                        const tooltipHtml = [
                            '',
                            `<span class="block font-semibold text-dark">${escapeTooltipText(primaryId)}</span>`,
                            `<span class="${symbolClass}">${escapeTooltipText(geneSymbol)}</span>`,
                            `<span class="${geneClass}">gene: ${escapeTooltipText(geneName)}</span>`,
                            `<span class="block text-sm text-gray-600">x: ${d.x.toFixed(4)}</span>`,
                            `<span class="block text-sm text-gray-600">y: ${d.y.toFixed(4)}</span>`,
                            '',
                        ].join('\n');

                        tooltipContainer
                            .html(tooltipHtml)
                            .style('left', `${event.pageX + 10}px`)
                            .style('top', `${event.pageY - 10}px`);
                        // Re-append the hovered circle so it is painted above its siblings.
                        this.parentNode?.appendChild(this);
                    })
                    .on('mouseout', function (_event, d) {
                        const circle = d3.select(this);
                        circle
                            .style('fill-opacity', '.75')
                            .attr('stroke', getCircleStrokeColor(d))
                            .style('fill', getCircleFillColor(d));
                        tooltipContainer.style('opacity', 0);
                    });

                const labelOffset = 6;
                if (showTargetLabels) {
                    // A sample is labelled when any of its identifiers is a selected target.
                    const selectedTargetSamples = allSamples.filter((d) =>
                        options.selected_targets?.some((t) =>
                            [d.peak_id, d.Gene_Symbol, d.gene_id, d.probe_id].filter(isDefined).includes(t),
                        ),
                    );

                    svg.append('g')
                        .selectAll('text')
                        .data(selectedTargetSamples)
                        .enter()
                        .append('text')
                        .text((d) => getTargetData(d) ?? '')
                        // Labels sit on the side of the dot that faces x = 0.
                        .attr('x', function (d) {
                            const offset = d.x >= 0 ? -labelOffset : labelOffset;
                            return xScale(d.x) + offset;
                        })
                        .attr('y', function (d) {
                            return yScale(d.y);
                        })
                        .attr('text-anchor', (d) => (d.x >= 0 ? 'end' : 'start'));
                }
            },
        [customPlotStylingOptions],
    );

    return <DynamicPlotContainer drawChart={drawChart} />;
};

export default VolcanoPlotView;
