import React, { useMemo, useCallback } from "react";

import { AxisBottom, AxisLeft } from "@visx/axis";
import { Annotation, Label } from "@visx/annotation";
import { curveNatural } from "@visx/curve";
import { GridColumns } from "@visx/grid";
import { Group } from "@visx/group";
import { scaleLinear } from "@visx/scale";
import { AreaClosed, LinePath } from "@visx/shape";

import { round10 } from "_react/shared/_helpers/numbers";
import { ALL_PITCH_TYPES, MAX_PITCH_VELOCITY, MIN_PITCH_VELOCITY } from "_react/shared/_constants/pitch_types";
import { TPitchTypes } from "_react/shared/_types/pitch_types";
import { getExtent, getPlotDimensions } from "_react/shared/dataviz/_helpers";
import { DEFAULT_OPACITY, DEFAULT_FILL_OPACITY } from "_react/shared/dataviz/_constants";
import { useDataVizColors, useAxisLabelProps } from "_react/shared/dataviz/_hooks";
import { TAxisExtrema, TMarginsProps } from "_react/shared/dataviz/_types";

/**
 * X axis extrema accepted by the ridgeline plot: the `TAxisExtrema<number>` fields (`min` / `max` are the ones
 * read in this file) plus a flag that controls how those bounds are combined with the distribution data.
 */
export type TXAxisExtremaProps = TAxisExtrema<number> & {
	// When truthy, any provided `min` / `max` is used directly as the x axis bound instead of the value returned
	// by `getExtent`. When falsy, `min` / `max` are only handed to `getExtent` together with the data's x values.
	isOverrideDistributionExtrema?: boolean;
};

// X axis configuration for `PitchMetricRidgelinePlot`
type TXAxisProps = {
	// Label passed to the bottom axis
	xLabel?: string;
	// Optional x axis bounds; when `min` / `max` are omitted the plot passes MIN_PITCH_VELOCITY / MAX_PITCH_VELOCITY
	// to `getExtent` instead
	extrema?: TXAxisExtremaProps;
	// When truthy, a `GridColumns` grid is rendered behind the plots
	isShowXAxisGridLines?: boolean;
	// When truthy, the x scale domain is [max, min] instead of [min, max]
	isXAxisReverse?: boolean;
};

// Y axis configuration for `PitchMetricRidgelinePlot`
type TYAxisProps = {
	// Label passed to the left axis
	yLabel?: string;
	// Optional y bounds passed to `getExtent` with the data's y values; the plot uses 0 and 1 when `min` / `max`
	// are omitted
	extrema?: TAxisExtrema<number>;
	// Multiplier applied to the overall y max when building each individual plot's y scale domain (defaults to 1)
	scaleMaxMultiplier?: number;
};

/**
 * Props for `PitchMetricRidgelinePlot`.
 *
 * `T` is the type of a single distribution datum and `K` is a key of `T` that can be used to read x / y values.
 */
type TPitchMetricRidgelinePlotProps<T, K extends keyof T> = {
	// Distribution data keyed by pitch type; each key is plotted as its own curve
	distributionData?: { [index: string]: Array<T> };
	// Property of the datum read as the x value when `getDistributionXFunction` is not provided
	distributionX: K | string;
	// Optional accessor for the x value; takes precedence over `distributionX`
	getDistributionXFunction?: (obj: T) => number;
	// Property of the datum read as the y value when `getDistributionYFunction` is not provided
	distributionY: K | string;
	// Optional accessor for the y value; takes precedence over `distributionY`
	getDistributionYFunction?: (obj: T) => number;
	// Optional comparator applied to the `Object.entries` of `distributionData` to order the pitch types;
	// replaces the plot's default ordering
	sortDistributionDataPitchTypes?: (a: [string, Array<T> | undefined], b: [string, Array<T> | undefined]) => number;
	// Values keyed by pitch type that are rounded and rendered as an "<value> mph" label next to that pitch type's plot
	averageReleaseVeloData?: { [index: string]: number };
	xAxis?: TXAxisProps;
	yAxis?: TYAxisProps;
	// Fill for the background rect; falls back to the background color from `useDataVizColors`
	// TODO: Decide if we want to add this prop or change the background color on our plots
	backgroundColor?: string;
	// Plot width passed to `getPlotDimensions` (500 is used when omitted) and plot height passed through as-is
	// TODO: Need to make this responsive to parent but want to wait until parent components are created
	width?: number;
	height?: number;
	// Margins passed to `getPlotDimensions`; `{ top: 20, right: 70 }` is used when omitted
	margins?: TMarginsProps;
};

/**
 * Ridgeline plot: draws one distribution curve per pitch type, stacked in equal-height bands over a shared x axis.
 *
 * Every key of `distributionData` gets its own band containing a line (`LinePath`) and a filled area
 * (`AreaClosed`) built from the datum x / y accessors. When `averageReleaseVeloData` has a truthy value for the
 * pitch type, a "<value> mph" label is anchored at the right edge of that band.
 *
 * X / y values are read with `getDistributionXFunction` / `getDistributionYFunction` when provided, otherwise
 * from the `distributionX` / `distributionY` property of each datum.
 *
 * Bands are ordered with `sortDistributionDataPitchTypes` when provided. By default `ALL_PITCH_TYPES` comes first
 * and the remaining pitch types follow by descending most probable x value (the x of the datum with the highest
 * y). The first entry is drawn in the bottom band and later entries are stacked above it.
 */
export const PitchMetricRidgelinePlot = <T, K extends keyof T>({
	distributionData,
	distributionX,
	getDistributionXFunction,
	distributionY,
	getDistributionYFunction,
	sortDistributionDataPitchTypes,
	averageReleaseVeloData,
	xAxis,
	yAxis,
	backgroundColor: backgroundColorProp,
	width,
	height,
	margins
}: TPitchMetricRidgelinePlotProps<T, K>) => {
	// STYLING SETUP
	const { axisColor, backgroundColor, gridStrokeColor, pitchTypeColorDict } = useDataVizColors();
	const axisLabelProps = useAxisLabelProps();

	// SIZING SETUP
	const {
		width: WIDTH,
		height: HEIGHT,
		margins: MARGIN,
		innerWidth: INNER_WIDTH,
		innerHeight: INNER_HEIGHT
	} = getPlotDimensions(width ?? 500, height, margins ?? { top: 20, right: 70 });

	// DATA SETUP
	const getXValues = useCallback(
		(datum: T) => (getDistributionXFunction ? getDistributionXFunction(datum) : datum[distributionX as K]),
		[getDistributionXFunction, distributionX]
	);
	const getYValues = useCallback(
		(datum: T) => (getDistributionYFunction ? getDistributionYFunction(datum) : datum[distributionY as K]),
		[getDistributionYFunction, distributionY]
	);

	// Every datum from every pitch type in one array, shared by the x and y extent calculations
	// Pitch types whose data is missing are skipped rather than contributing an `undefined` datum
	const allDistributionData: Array<T> | undefined = useMemo(
		() =>
			distributionData
				? Object.values(distributionData).reduce<Array<T>>(
						(aggregatedData: Array<T>, pitchTypeData: Array<T> | undefined) =>
							aggregatedData.concat(pitchTypeData ?? []),
						[]
				  )
				: undefined,
		[distributionData]
	);

	const [xMinOverall, xMaxOverall] = useMemo(() => {
		const extrema = xAxis?.extrema;
		const isOverride = extrema?.isOverrideDistributionExtrema;

		// If both values are provided and override the data, don't bother calculating the extent
		if (isOverride && extrema?.min !== undefined && extrema?.max !== undefined) return [extrema.min, extrema.max];

		const [xMinDefault, xMaxDefault] = getExtent<number>(
			extrema?.min !== undefined ? extrema.min : MIN_PITCH_VELOCITY,
			extrema?.max !== undefined ? extrema.max : MAX_PITCH_VELOCITY,
			allDistributionData ? allDistributionData.map((datum: T) => getXValues(datum) as number) : undefined
		);
		// A single provided bound still wins over the calculated extent when overriding
		const xMin = isOverride && extrema?.min !== undefined ? extrema.min : xMinDefault;
		const xMax = isOverride && extrema?.max !== undefined ? extrema.max : xMaxDefault;
		return [xMin, xMax];
	}, [xAxis?.extrema, allDistributionData, getXValues]);

	// Get the most probable x value (highest y value) for each pitch type to sort the pitch types on the y axis
	const pitchTypeMostProbableXValues: { [index: string]: number } = useMemo(() => {
		const mostProbableXValues: { [index: string]: number } = {};
		if (!distributionData) return mostProbableXValues;

		Object.entries(distributionData).forEach(([pitchType, pitchTypeData]: [string, Array<T> | undefined]) => {
			// A pitch type without any data has no peak, so it gets the same 0 fallback as a falsy x value
			if (!pitchTypeData || pitchTypeData.length === 0) {
				mostProbableXValues[pitchType] = 0;
				return;
			}
			// Single pass for the datum with the highest y value; the earliest datum wins on ties
			const peakDatum = pitchTypeData.reduce((peak: T, datum: T) =>
				(getYValues(datum) as number) > (getYValues(peak) as number) ? datum : peak
			);
			const peakX = getXValues(peakDatum) as number;
			mostProbableXValues[pitchType] = peakX ? peakX : 0;
		});
		return mostProbableXValues;
	}, [distributionData, getYValues, getXValues]);

	const [yMinOverall, yMaxOverall] = useMemo(() => {
		return getExtent<number>(
			yAxis?.extrema?.min !== undefined ? yAxis.extrema.min : 0,
			yAxis?.extrema?.max !== undefined ? yAxis.extrema.max : 1,
			allDistributionData ? allDistributionData.map((datum: T) => getYValues(datum) as number) : undefined
		);
	}, [allDistributionData, yAxis?.extrema, getYValues]);

	// AXES SETUP
	const xScaleOverall = scaleLinear({
		domain: xAxis?.isXAxisReverse ? [xMaxOverall, xMinOverall] : [xMinOverall, xMaxOverall],
		range: [0, INNER_WIDTH]
	});
	// The overall y axis height needs to be multiplied by the number of plots we have
	// The individual plot y axis range needs to be a fraction of the overall height
	// Clamped to at least 1 so an empty `distributionData` object never divides the height by zero
	const distributionDataLength = Math.max(distributionData ? Object.keys(distributionData).length : 1, 1);
	const yScaleOverall = scaleLinear({
		domain: [yMinOverall, yMaxOverall * distributionDataLength],
		range: [INNER_HEIGHT, 0],
		nice: true
	});

	const yScaleIndividualPlot = scaleLinear({
		domain: [yMinOverall, yMaxOverall * (yAxis?.scaleMaxMultiplier ?? 1)],
		range: [INNER_HEIGHT / distributionDataLength, 0],
		nice: true
	});

	// The min of the y axis may be a little lower than the data y min because nice = true for the axis scale
	// This gets the y axis min so that the plots go all the way to bottom of the y axis
	const [yAxisMin, yAxisMax] = yScaleOverall.domain();

	// Function to sort distribution data pitch types
	// Default: ALL_PITCH_TYPES first, then by a pitch type's most probable x value descending
	// (the first entry is drawn at the bottom of the y axis and later entries are stacked above it)
	const sortDistributionDataPitchTypesFunction =
		sortDistributionDataPitchTypes ??
		(([aPitchType]: [string, Array<T> | undefined], [bPitchType]: [string, Array<T> | undefined]) => {
			if (aPitchType === ALL_PITCH_TYPES) return -1;
			if (bPitchType === ALL_PITCH_TYPES) return 1;
			return pitchTypeMostProbableXValues[bPitchType] - pitchTypeMostProbableXValues[aPitchType];
		});

	// Sorted once and shared by the tick labels and the plots so both always agree on the band order
	const sortedDistributionEntries: Array<[string, Array<T> | undefined]> = distributionData
		? Object.entries(distributionData).sort(sortDistributionDataPitchTypesFunction)
		: [];

	// Custom tick labels for the pitch types, one at the bottom of each band
	const yTickLabels: Array<{ label: string; value: number }> = sortedDistributionEntries.map(
		([pitchType]: [string, Array<T> | undefined], index: number) => ({
			label: pitchType,
			value: (yAxisMax / distributionDataLength) * index
		})
	);

	// Find the value for the bottom of the y axis (values go up as you go down the y axis)
	// and divide it by how many pitch types we will need to plot
	// This gives you how much of the y axis each pitch type should be taking up
	const yScaleMin = yScaleOverall(yAxisMin);
	const yScalePitchTypeAdjustment = yScaleMin / distributionDataLength;

	return (
		<>
			<svg width={WIDTH} height={HEIGHT}>
				<rect x={0} y={0} width={WIDTH} height={HEIGHT} fill={backgroundColorProp ?? backgroundColor} />
				<Group left={MARGIN.left} top={MARGIN.top}>
					{xAxis?.isShowXAxisGridLines && (
						<GridColumns
							scale={xScaleOverall}
							width={INNER_WIDTH}
							height={INNER_HEIGHT}
							stroke={gridStrokeColor}
							numTicks={7}
						/>
					)}
					<AxisBottom
						scale={xScaleOverall}
						top={INNER_HEIGHT}
						label={xAxis?.xLabel}
						labelProps={axisLabelProps}
						stroke={axisColor}
						tickStroke={axisColor}
						numTicks={7}
					/>
					<AxisLeft
						scale={yScaleOverall}
						label={yAxis?.yLabel}
						labelOffset={20}
						stroke={axisColor}
						tickStroke={axisColor}
						tickValues={yTickLabels.map((tickLabel: { value: number; label: string }) => tickLabel.value)}
						tickFormat={v =>
							yTickLabels.find((tickLabel: { value: number; label: string }) => tickLabel.value === v)
								?.label
						}
					/>
					{sortedDistributionEntries.map(
						([pitchType, data]: [string, Array<T> | undefined], index: number) => {
							// Returning null (instead of a keyless fragment) renders nothing without a key warning
							if (data === undefined) return null;
							// Filter data to just what fits within the x-axis
							const filteredData = data.filter((datum: T) => {
								const xValue = getXValues(datum) as number;
								return xValue >= xMinOverall && xValue <= xMaxOverall;
							});
							// Calculate where each plot starts and ends on the y axis
							// which is the bottom of the y axis and then move up the y axis in increments for each pitch type
							const yScaleStart = yScaleMin - yScalePitchTypeAdjustment * index;
							const yScaleEnd = yScaleStart - yScalePitchTypeAdjustment;
							const averageVelo = averageReleaseVeloData?.[pitchType];

							return (
								<Group className={pitchType} key={pitchType}>
									<LinePath
										curve={curveNatural}
										data={filteredData}
										x={(datum: T) => xScaleOverall(getXValues(datum) as number)}
										y={(datum: T) =>
											yScaleIndividualPlot(getYValues(datum) as number) + yScaleEnd
										}
										stroke={pitchTypeColorDict[pitchType as TPitchTypes]}
										strokeWidth={1}
										strokeOpacity={DEFAULT_OPACITY}
										shapeRendering="geometricPrecision"
									/>
									<AreaClosed
										yScale={yScaleOverall}
										data={filteredData}
										x={(datum: T) => xScaleOverall(getXValues(datum) as number)}
										y={(datum: T) =>
											yScaleIndividualPlot(getYValues(datum) as number) + yScaleEnd
										}
										fill={pitchTypeColorDict[pitchType as TPitchTypes]}
										fillOpacity={DEFAULT_FILL_OPACITY}
										curve={curveNatural}
										y0={() => yScaleStart}
									/>
									{/* Ternary (not `&&`) so a falsy number such as 0 is never rendered as a text node */}
									{averageVelo ? (
										<Annotation
											x={xScaleOverall(xMaxOverall)}
											y={yScaleIndividualPlot(0) + yScaleEnd}
											dx={0.5}
											dy={0.5}
										>
											<Label
												title={`${round10(averageVelo, -1)} mph`}
												titleFontWeight={500}
												titleFontSize={12}
												showAnchorLine={false}
												fontColor={axisColor}
												backgroundFill={gridStrokeColor}
												horizontalAnchor="start"
												verticalAnchor="end"
												backgroundPadding={{ left: 5, top: 2, bottom: 2 }}
											/>
										</Annotation>
									) : null}
								</Group>
							);
						}
					)}
				</Group>
			</svg>
		</>
	);
};

// Default export of the same component that is also exported by name above
export default PitchMetricRidgelinePlot;
