import { Slider, Typography } from '@mui/material';
import Grid2 from '@mui/material/Unstable_Grid2/Grid2';
import React, { useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppSelector } from '../../app/hooks';
import { countPopulationChange, sumPopulationPerYear } from '../../common/functions';
import { PopulationChangeChart, PopulationChangeType } from '../../components/charts/ComposedCharts';
import { AgeGroupChart } from '../../components/charts/LineCharts';
import { CohortByAreaByYearType } from '../../types';
import { useGetAdjustedDataForecastForForecastQuery } from '../apis/apiSlice';
import { useGetAreasQuery } from '../apis/areaSlice';
import { selectCurrentForecast } from '../apis/forecastApi';
import { AGE_GROUPS, PopulationGroupSizes } from '../municipality/ChangeBarChart';
import { AreaSelection } from './AreaSelection';
import { getAreaHierarchy } from './queryViewFunctions';

import * as dfd from 'danfojs';

/** Area-id suffix whose major area is left out of QueryView's default selection (presumably the "unknown area" placeholder). */
const UNKNOWN_AREA_ID = '999999';
/** Positions of the small-area map and the major-area map inside the areas query result. */
const [SMALL_AREA_IDX, MAJOR_AREA_IDX] = [0, 1] as const;

/**
 * Flattens the nested projection object `{ [year]: { [area]: { [cohort]: value } } }` into a
 * DataFrame with one row per (year, area) pair.
 *
 * Columns are `year`, `area`, then the cohort keys in the order of the first area entry found.
 * Rows are labelled `${year}_${area}`, the label format QueryView's row filters build.
 * An input without any area entries yields an empty DataFrame instead of throwing.
 */
const jsonToDataFrame = (json: CohortByAreaByYearType): dfd.DataFrame => {

    const yearEntries = Object.entries(json)

    // The column layout comes from the first cohort object present (normally first year, first area).
    const firstCohort = yearEntries.flatMap(([, areaData]) => Object.values(areaData))[0]
    if (!firstCohort) return new dfd.DataFrame([])

    const columns = [ 'year', 'area', ...Object.keys(firstCohort) ]
    const cohortColumns = columns.slice(2)

    const rows = yearEntries.flatMap(([year, areaData]) =>
        Object.entries(areaData).map(([area, cohort]) => {
            // Look values up by name so a cohort whose keys are ordered differently from the
            // first one still lines up with the column headers.
            const valueByCohort = new Map(Object.entries(cohort))
            return [ year, area, ...cohortColumns.map(column => valueByCohort.get(column)) ]
        })
    )

    return new dfd.DataFrame(rows, { columns, index: rows.map(([ year, area ]) => `${year}_${area}`) })
}

/**
 * Rebuilds the nested `{ [year]: { [area]: { [cohort]: value } } }` object from a DataFrame
 * whose rows carry `year` and `area` columns (the inverse of jsonToDataFrame).
 *
 * A `key` column, if present, is discarded. Rows sharing the same year/area pair are merged,
 * with later rows overwriting earlier values for the same cohort.
 */
const dataFrameToJSON = (df: dfd.DataFrame): CohortByAreaByYearType => {

    const rows = dfd.toJSON(df) as any[]

    // Filled in place: re-spreading the accumulator for every row would copy all years and all
    // areas of the current year each time, which is quadratic in the number of rows.
    const result: { [year: string]: { [area: string]: { [cohort: string]: any } } } = {}

    for (const row of rows) {
        const { year, key: _key, area, ...cohorts } = row
        const areasOfYear = result[year] ?? (result[year] = {})
        areasOfYear[area] = { ...areasOfYear[area], ...cohorts }
    }

    return result as any
}

/**
 * Aggregates cohort columns into the age groups of AGE_GROUPS for the population pyramid.
 *
 * `data` is expected to contain `year`, `area` and `f<age>` / `m<age>` columns. Cohort sizes are
 * summed over all rows, and the year of the first row is used as the result key (assumes every
 * row belongs to that year — TODO confirm callers always pass a single year).
 *
 * A group is skipped when it starts below the youngest cohort column present, or ends above the
 * oldest one — unless the oldest is 99, which is treated as open-ended ("99 includes 100 and
 * above"). Groups that end up with no cohort columns are skipped as well.
 *
 * @returns `[{ [year]: { female: [...], male: [...] } }]`, groups ordered from the oldest down,
 *          female sizes negated; `undefined` when `data` is missing or empty.
 */
const calculateAgeGroups = (data?: dfd.DataFrame): any => {
    if (!data || !data.size) return

    // Read the year before the column is dropped. TODO: bit fragile
    const year = (data.loc({ columns: ['year'] }).values[0] as any[])[0];

    // Sum cohort sizes over all areas (drop returns a new frame, data itself is untouched), then
    // turn the Series back into a one-row DataFrame so cohort columns can be picked with loc.
    const cohortSizesSeries = data.drop({ columns: ['year', 'area'] }).sum({ axis: 0 });
    const cohortSizes = new dfd.DataFrame([cohortSizesSeries.values], { columns: cohortSizesSeries.index.map(index => index.toString()) });

    const reversedAgeGroups = AGE_GROUPS.slice().reverse();

    // Pass 1: cohort columns of every group, plus the youngest and oldest cohort actually present.
    let minAge = 99;
    let maxAge = 0;
    const ageGroupCohorts = reversedAgeGroups.map(({ from, to }) => {
        const femaleCohorts: string[] = [];
        const maleCohorts: string[] = [];
        // There are no columns above 99: 99 includes 100 and above
        for (let age = from; age <= Math.min(to, 99); age++) {
            // Only the female column is checked; male cohorts are assumed to mirror it
            if (!data.columns.includes(`f${age}`)) continue;
            femaleCohorts.push(`f${age}`);
            maleCohorts.push(`m${age}`);
            minAge = Math.min(minAge, age);
            maxAge = Math.max(maxAge, age);
        }
        return { femaleCohorts, maleCohorts };
    });

    const result: { [year: string]: { male: any[], female: any[] } } = {};
    result[year] = { male: [], female: [] };

    // Pass 2: sum every group fully covered by the available cohorts.
    for (let i = 0; i < reversedAgeGroups.length; i++) {
        const { from, to } = reversedAgeGroups[i];
        const { femaleCohorts, maleCohorts } = ageGroupCohorts[i];

        const cutBelow = from < minAge;
        const cutAbove = to > maxAge && maxAge !== 99;
        if (cutBelow || cutAbove || femaleCohorts.length === 0) {
            continue;
        }

        const numberOfFemales = cohortSizes.loc({ columns: femaleCohorts }).sum({ axis: 1 }).values[0] as number;
        const numberOfMales = cohortSizes.loc({ columns: maleCohorts }).sum({ axis: 1 }).values[0] as number;

        const label = `${from}-${to}`;

        // Females are negated so they are drawn on the opposite side of the pyramid
        result[year].female.push({
            sex: 'female',
            label,
            size: -Number(numberOfFemales.toFixed(2))
        });

        result[year].male.push({
            sex: 'male',
            label,
            size: Number(numberOfMales.toFixed(2))
        });
    }

    return [result];
};

/**
 * Largest absolute group size in the age-group data (female sizes are stored negated).
 * Returns 0 when there is no data.
 */
function findMaxAgeGroupSize(populationData?: PopulationGroupSizes): number {
    if (!populationData) {
        return 0;
    }

    let maxSize = 0;
    for (const entries of populationData.values()) {
        const yearData = Object.values(entries)[0];
        const groupSizes = [...yearData.male, ...yearData.female].map((group: any) => Math.abs(group.size));
        const maxGroupSize = Math.max(...groupSizes);
        if (maxGroupSize > maxSize) {
            maxSize = maxGroupSize;
        }
    }

    return maxSize;
}

/** Axis limit with 10 % headroom, rounded up to a multiple of 10 (below 100) or of 100. */
function maxAxisValue(maxSize: number): number {
    const step = maxSize < 100 ? 10 : 100;
    return Math.ceil(maxSize * 1.10 / step) * step;
}

/**
 * Population pyramid for the rows of `data` (one year), grouped by AGE_GROUPS.
 * Renders nothing when `data` is missing or empty.
 */
const DataframeAgeGroupChart: React.FC<{ data?: dfd.DataFrame }> = ({ data }) => {

    // Derived directly from `data`, so an empty selection clears the chart instead of leaving
    // the previous selection's groups on screen.
    const ageGroups = useMemo<PopulationGroupSizes | undefined>(() => calculateAgeGroups(data), [ data ])

    if (!ageGroups) return null

    return (
        <Grid2 container>
            <Grid2 xs={12}>
                <AgeGroupChart
                    ageGroups={ageGroups}
                    maxSize={maxAxisValue(findMaxAgeGroupSize(ageGroups))}
                    minYear={Number(Object.keys(ageGroups[0])[0])}
                    aspect={0.73}
                />
            </Grid2>
        </Grid2>
    )
}

/**
 * Population change chart for the rows of `data`: totals per year are computed with
 * sumPopulationPerYear and turned into changes with countPopulationChange.
 * Renders nothing when there is no data or no computed change.
 */
const DataframePopulationChangeChart: React.FC<{ data?: dfd.DataFrame }> = ({ data }) => {

    // Derived with useMemo so the chart never shows a previous selection's result.
    const populationChange = useMemo<PopulationChangeType | undefined>(() => {
        if (!data) return undefined

        const populationSums = sumPopulationPerYear(dataFrameToJSON(data))
        // NOTE(review): the `any` cast bridges countPopulationChange's return type and
        // PopulationChangeType — confirm they match and remove it.
        return countPopulationChange(populationSums) as any
    }, [ data ])

    return populationChange ? <PopulationChangeChart data={populationChange} aspect={0.96} /> : null
}

/**
 * Column names to keep for an inclusive age range: `year`, `area`, then every female cohort
 * (`f<age>`) followed by every male cohort (`m<age>`). An inverted range keeps only the first two.
 */
const getLocColumns = (ages: { from: number, to: number }): string[] => {
    const agesInRange: number[] = []
    for (let age = ages.from; age <= ages.to; age++) {
        agesInRange.push(age)
    }

    const females = agesInRange.map(age => `f${age}`)
    const males = agesInRange.map(age => `m${age}`)

    return ['year', 'area'].concat(females, males)
}

/**
 * Interactive query view over the current forecast's adjusted projections.
 *
 * Left column: an area tree and an age-range slider. Once the area list loads, every small area
 * and every major area whose id does not end with UNKNOWN_AREA_ID is selected.
 * Right column: the population change chart over all forecast years and the age-group chart for
 * the year chosen with the year slider, both restricted to the selected areas and ages.
 *
 * Slider positions live in `temp*` state while dragging and are committed on release. All
 * filters are derived with useMemo from the committed values, so the DataFrame is re-sliced only
 * when a committed filter or the data changes, not on every render or drag tick.
 */
const QueryView: React.FC = () => {

    const { t } = useTranslation()

    const { id: forecastId, fromYear, toYear } = useAppSelector(selectCurrentForecast)!

    const { data: forecast } = useGetAdjustedDataForecastForForecastQuery({ forecastId: forecastId! })
    const { data: areas } = useGetAreasQuery()

    // Committed filter values.
    // NOTE(review): the year state is seeded from fromYear only on mount — confirm the view is
    // remounted when the current forecast changes.
    const [currentYear, setCurrentYear] = useState<number>(fromYear)
    const [ages, setAges] = useState<{ from: number, to: number }>({ from: 0, to: 99 })
    const [selectedAreas, setSelectedAreas] = useState<string[]>([])

    // Slider positions while dragging.
    const [ tempYear, setTempYear ] = useState<number>(fromYear)
    const [ tempAges, setTempAges ] = useState<{ from: number, to: number }>({ from: 0, to: 99 })

    // Default area selection once the area list is available.
    useEffect(() => {
        if (!areas) return

        const areaIds = Object.keys(areas[SMALL_AREA_IDX])
            .concat(
                Object.keys(areas[MAJOR_AREA_IDX])
                    .filter(area => !area.endsWith(UNKNOWN_AREA_ID))
            )

        setSelectedAreas(areaIds)
    }, [ areas ])

    // Converting the projection is the expensive step; do it once per forecast response.
    const dataframe = useMemo(
        () => forecast ? jsonToDataFrame(forecast.adjusted_projections[0]) : undefined,
        [ forecast ]
    )

    const forecastYears = useMemo(
        () => Array.from({ length: toYear - fromYear + 1 }, (_, i) => fromYear + i),
        [ fromYear, toYear ]
    )

    // Row labels are `${year}_${areaId}`; only ids longer than four characters (the small
    // areas) are used as rows.
    const selectedRows = useMemo(() => {
        const smallAreas = selectedAreas.filter(area => area.length > 4)
        return forecastYears.flatMap(year => smallAreas.map(area => `${year}_${area}`))
    }, [ selectedAreas, forecastYears ])

    const currentYearRows = useMemo(
        () => selectedRows.filter(row => row.startsWith(`${currentYear}_`)),
        [ selectedRows, currentYear ]
    )

    const locColumns = useMemo(() => getLocColumns(ages), [ ages ])

    // Both charts select rows by boolean mask, so selected areas absent from the forecast simply
    // match nothing.
    const populationChangeData = useMemo(
        () => dataframe && selectRowsAndColumns(dataframe, selectedRows, locColumns),
        [ dataframe, selectedRows, locColumns ]
    )
    const ageGroupData = useMemo(
        () => dataframe && selectRowsAndColumns(dataframe, currentYearRows, locColumns),
        [ dataframe, currentYearRows, locColumns ]
    )

    if (!areas || !areas.length || !dataframe || !dataframe.values.length) return null

    return (
        <Grid2 container width={'100%'} height={'80%'}>
            <Grid2 xs={3} height={'100vh'}>
                <Typography variant='h3' style={{ fontSize: '1rem', fontWeight: 700 }}>{t('queryView.areas')}</Typography>
                <AreaSelection
                    areas={getAreaHierarchy(areas[MAJOR_AREA_IDX], areas[SMALL_AREA_IDX])}
                    selectedAreas={selectedAreas}
                    setSelectedAreas={setSelectedAreas}
                    props={ { height: '90vh' } }
                />
                {
                    // NOTE(review): the ts-ignore hides a typing mismatch on this Slider (likely the
                    // `[from, to]: number[]` callbacks vs MUI's `number | number[]` value) — narrow it and remove the ignore.
                    // @ts-ignore
                    <Slider
                        sx={{ paddingTop: '2.5rem', alignSelf: 'flex-end' }}
                        title={t('general.ages')}
                        min={0}
                        max={99}
                        value={[tempAges.from, tempAges.to]}
                        onChange={(e, [from, to]: number[]) => setTempAges({ from, to })}
                        onChangeCommitted={(e, [from, to]: number[]) => setAges({ from, to })}
                        valueLabelDisplay='on'
                        marks={[
                            { value: 0, label: '0' },
                            { value: 99, label: '99' }
                        ]}
                    />
                }
            </Grid2>
            <Grid2 container xs={9}>
                <Grid2 xs={12} container>
                    <Grid2 xs={7}>
                        <DataframePopulationChangeChart data={populationChangeData} />
                    </Grid2>
                    <Grid2 xs={5}>
                        <DataframeAgeGroupChart data={ageGroupData} />
                        {/* NOTE(review): 'test' looks like a placeholder title — replace with a translated label */}
                        <Slider
                            title='test'
                            min={fromYear}
                            max={toYear}
                            value={tempYear}
                            onChange={(_, value) => setTempYear(Number(value))}
                            onChangeCommitted={(_, value) => setCurrentYear(Number(value))}
                        />
                    </Grid2>
                </Grid2>
            </Grid2>
        </Grid2>
    )
};

/**
 * Selects the rows whose index label (as a string) appears in `rows`, restricted to `columns`.
 *
 * Selection goes through a boolean mask over the frame's own index. Labels in `rows` that the
 * frame does not contain therefore match nothing, and rows keep the frame's order.
 */
function selectRowsAndColumns(dataframe: dfd.DataFrame, rows: string[], columns: string[]): dfd.DataFrame {
    // A Set keeps mask building linear in the index size (Array.includes per label was O(index × rows)).
    const wanted = new Set(rows);
    const selectMask = dataframe.index.map(label => wanted.has(label.toString()));
    return dataframe.loc({ rows: selectMask, columns });
}

// The query view component is this module's only export; everything above is file-private.
export default QueryView;