// https://simplestatistics.org/docs/
import * as SS from 'simple-statistics'
import * as U from '../Utils'
import * as STT from "../StypeTools"
import minimize from './Math/Optimizers'
import * as Colors from "../Colors"
import {WorkspaceStore} from "../Database"
import {alias, list, object, primitive, serializable} from "serializr"
import {
    CategoryValue,
    ColumnDBIndex,
    ColumnId,
    ColumnsById,
    ColumnSummaryType,
    ColumnUIIndex,
    DataValue,
    IColumnSummary,
    INumberSummary,
    ITableColumn,
    NumericalCategories,
    ZeroColumnId
} from "../Concepts/Basic"
import {FirstMoment, ITimeInterval} from "../Concepts/DateTime"
import {AggregationByCategory, AggregationByInterval, AggregationType} from "../Concepts/Aggregation"
import {NumFreq} from "../Concepts/FrequencyOfValue"
import {Point} from "../Concepts/Geometry"
import {StandardColor} from "../Concepts/Colors"
import DataColumn from "./DataColumn"
import TableColumn from "./TableColumn"
import ValueDistribution from "./ValueDistribution"
import CategoryDistribution from "./CategoryDistribution"
import ColumnSummary from "./ColumnSummary"
import NumberSummary from "./NumberSummary"
import {IDataTable, UnknownColumnWidth} from "../Concepts/DataTable"
import {IVisSettings} from "../Concepts/VisSettings"
import {IWorkspaceStore, WorkspaceStoreId} from "../Concepts/DataBase"
import {SemanticType} from "../Concepts/SemanticType"

// Upper bound on the number of x positions at which a column's KDE curve is evaluated
// (longer value lists are thinned down to this many evenly spaced entries).
const MAX_KDE_X = 100

/**
 * Represents the user's table of source data.
 *
 * Row values are addressed by the column's database (original) index, i.e. the position
 * of the column id in `_columnIds`, while `_columns` is kept sorted in visual (UI) order.
 */

export default class DataTable implements IDataTable {
    @serializable(alias('rowcnt')) protected readonly _rowCount: number
    // column unique identifiers ordered by column database index
    @serializable(alias('colids', list(primitive()))) protected readonly _columnIds: ColumnId[]
    // table columns ordered by visual column index
    @serializable(alias('columns', list(object(TableColumn)))) protected readonly _columns: TableColumn[]
    @serializable(alias('data', list(list(primitive())))) protected readonly data: DataValue[][]

    /**
     * Restores a table from a workspace store: reads the columns and the rows, binds every
     * column to the column store and returns a table bound to `wsStore.id`.
     */
    static async loadFromStore (wsStore:IWorkspaceStore):Promise<IDataTable> {
        // a single column store is used both for reading the columns and for binding them
        const tableColumnStore = wsStore.createTableColumnStore(TableColumn),
            columns = await tableColumnStore.retrieveAll()
        columns.forEach(column => column.bindToStore (tableColumnStore))
        return new DataTable(
            columns,
            await wsStore.data.retrieveAll(),
            wsStore.id
        )
    }

    /**
     * Builds a table that is not bound to any workspace store from in-memory rows and columns.
     */
    static async createFromData (data: DataValue[][], columns:TableColumn[]):Promise<IDataTable> {
        return new DataTable(columns, data)
    }

    /**
     * @param columns table columns; their incoming order defines the database index of each column.
     *                Note that the array is stored as is and then sorted in place into visual order.
     * @param data table rows, each row holding one value per column in database index order
     * @param _wsStoreId identifier of the workspace store the table is bound to, if any
     */
    constructor(columns?: TableColumn[], data?: DataValue[][], protected _wsStoreId?:WorkspaceStoreId) {
        this._columns = columns ?? []
        this.data = data ?? []
        // must be captured before resorting, as the ids have to follow the database order
        this._columnIds = this._columns.map (i => i.id)
        this._rowCount = this.data.length
        this.resortColumns()
    }

    /**
     * Identifier of the workspace store the table is bound to.
     * @throws Error if the table is not bound to any store
     */
    get wsStoreId() {
        if (this._wsStoreId === undefined) {
            throw Error ("DataTable is not bound to any WorkspaceStore")
        }
        return this._wsStoreId
    }

    /** Number of data rows captured at construction time. */
    get rowCount() { return this._rowCount; }

    /**
     * String fingerprint of the visual configuration: id, width, order, selection and pinning
     * flags of every column in visual order.
     */
    get visualConfigSnapshot():string {
        return this._columns.map (c => `${c.id},${c.width},${c.order},${c.selected?1:0},${c.pinned?1:0}`).join(';')
    }

    /**
     * Returns the database index of the column with the given id.
     * @throws Error if there is no such column
     */
    columnOrigIndexById (id:ColumnId):ColumnDBIndex {
        const index = this._columnIds.indexOf(id)
        if (index < 0) {
            throw Error (`Can't get index of column with id=${id} as the latter is not found`)
        }
        return ColumnDBIndex(index)
    }

    /**
     * Returns the id of the column located at the given visual index.
     * @throws Error if there is no column at that index
     */
    columnIdByUIIndex (index:ColumnUIIndex):ColumnId {
        const column = this._columns[index]
        if (column === undefined) {
            throw Error (`Can't get id of column with index=${index} as the latter doesn't exist`)
        }
        return ColumnId(column.id)
    }

    /**
     * Returns the id of the column located at the given database index.
     * @throws Error if there is no column at that index
     */
    columnIdByDBIndex (index:ColumnDBIndex):ColumnId {
        const id = this._columnIds[index]
        if (id === undefined) {
            throw Error (`Can't get id of column with dbindex=${index} as the latter doesn't exist`)
        }
        return ColumnId(id)
    }

    /** Returns database indexes of all columns listed in visual order. */
    buildIndexToOrigIndexMap():ColumnDBIndex[] {
        return this._columns.map(c => this.columnOrigIndexById(c.id))
    }

    /**
     * Returns the column with the given id.
     * @throws Error if the column is not found or the id is not unique
     */
    getColumnInfo (id: ColumnId):ITableColumn {
        const cols = this._columns.filter (c => c.id === id)
        if (cols.length === 1) {
            return cols[0]
        } else {
            throw new Error (cols.length ? `There are more than one column with id=${id}` : `Column with id=${id} is not found`)
        }
    }

    /** Returns the semantic type of the column with the given id. */
    getStype (id: ColumnId) {
        return this.getColumnInfo(id).stype;
    }

    /** Returns the rows with indexes in [start, end) together with the requested start index. */
    getRows (start: number, end: number):Promise<{start:number, data:DataValue[][]}> {
        return Promise.resolve({start, data:this.data.slice (start, end)})
    }

    /**
     * Creates a new workspace store, copies columns, rows and (optionally) visual settings into it
     * and rebinds the table and its columns to that store.
     *
     * If anything fails the new store is deleted, the previous store id is restored and the error
     * is rethrown.
     * NOTE(review): columns that were already bound to the new store are not bound back on failure,
     * as the previous column store is not known here.
     *
     * @returns identifier of the newly created store
     */
    async changeWorkspaceStore (visSettings?: IVisSettings):Promise<WorkspaceStoreId> {
        const store = WorkspaceStore.create(),
            previousStoreId = this._wsStoreId
        try {
            this._wsStoreId = store.id
            const tableColumnStore = store.createTableColumnStore(TableColumn)
            await tableColumnStore.bulkAdd(this.columnsOrderedByOrigIndex() as TableColumn[])
            this._columns.forEach(column => column.bindToStore(tableColumnStore))
            await store.data.add(0, this.data)
            if (visSettings) {
                await visSettings.putToStore(store)
            }
            return store.id
        }
        catch (e) {
            // don't leave the table pointing to the store that is being deleted
            this._wsStoreId = previousStoreId
            await WorkspaceStore.delete (store.id)
            throw e
        }
    }

    /**
     * Aggregates values of the given columns grouped by a value derived from each row and writes
     * the result into `map` (group value -> column id -> aggregated number).
     *
     * Groups that are missing in `map`, or are present with a `null` placeholder, get created
     * when their first row is met. For `Mean` the per-group sums are divided by the number of
     * non-null values; for `Min`/`Max` entries that have seen no non-null value are removed.
     * NOTE(review): a `Mean` over a group with no non-null values is left as 0 rather than removed.
     *
     * @param map resulting map, modified in place
     * @param valueColumnIds ids of the columns to aggregate
     * @param aggrType kind of aggregation
     * @param breakDownValueFromRow returns the group value of a row
     */
    protected aggregateValueColumns<T> (map:Map<T, Map<ColumnId, number>|null>,
                                        valueColumnIds:ColumnId[],
                                        aggrType:AggregationType,
                                        breakDownValueFromRow:(row:DataValue[])=>T) {

        // number of non-null values per group and column, maintained for Mean only
        const countMap = new Map<T, Map<ColumnId, number>>(),
            colIndexes = valueColumnIds.map (id => this.columnOrigIndexById(id))
        this.data.forEach(row => {
            const breakValue = breakDownValueFromRow (row),
                resultsByColumn = map.get(breakValue)
            if (resultsByColumn) {
                valueColumnIds.forEach ((colId,i) => {
                    const value = row[colIndexes[i]] as number | null
                    resultsByColumn.set (colId, DataTable.valueByAggrType(value, aggrType, resultsByColumn.get (colId)))
                    if (aggrType === AggregationType.Mean) {
                        const item = U.get(countMap, breakValue)
                        item.set (colId, U.get(item, colId) + (value === null ? 0 : 1))
                    }
                })
            } else {
                // first row of the group: start aggregation from scratch
                map.set(breakValue, new Map (valueColumnIds.map ((colId,i) => ([colId, DataTable.valueByAggrType(row[colIndexes[i]] as number | null, aggrType)]))))
                if (aggrType === AggregationType.Mean) {
                    countMap.set(breakValue, new Map (valueColumnIds.map((colId, i) => ([colId, row[colIndexes[i]] === null ? 0 : 1]))))
                }
            }
        })

        // get mean value
        if (aggrType === AggregationType.Mean) {
            for (const breakValue of countMap.keys()) {
                valueColumnIds.forEach ((colId) => {
                    const count = U.def(countMap.get(breakValue)?.get(colId))
                    if (count > 0) {
                        const breakValueMap = map.get(breakValue)
                        if (U.mustNotBeNullNorUndefined(breakValueMap)) {
                            breakValueMap.set(colId, U.get(breakValueMap, colId) / count)
                        }
                    }
                })
            }
        }

        // get rid of Infinity values
        if (aggrType === AggregationType.Min || aggrType === AggregationType.Max) {
            for (const resultsByColumn of map.values()) {
                if (resultsByColumn) {
                    valueColumnIds.forEach ((colId) => {
                        const value = U.get(resultsByColumn, colId)
                        if (value === Infinity || value === -Infinity) {
                            resultsByColumn.delete(colId)
                        }
                    })
                }
            }
        }
    }

    /**
     * If valueColumnIds specified and not empty the function returns an aggregation of their values grouped by the category column values,
     * otherwise it returns counts of different category column values (stored under ZeroColumnId).
     * Category values are converted to strings; null values form a separate `null` category.
     */
    aggregateByCategory (aggrType:AggregationType, categoryColumnId: ColumnId, valueColumnIds?: ColumnId[]): Promise<AggregationByCategory> {
        const categoryColumnIndex = this.columnOrigIndexById(categoryColumnId)
        const map = new Map<CategoryValue, Map<ColumnId, number>>()
        if (valueColumnIds === undefined || valueColumnIds.length === 0) {
            // count category frequency
            this.data.forEach(row => {
                const val = row[categoryColumnIndex]
                const cat = val === null ? null : val.toString()
                const count = map.get(cat)
                if (count === undefined) {
                    map.set(cat, new Map ([[ZeroColumnId, 1]]))
                } else {
                    count.set(ZeroColumnId, U.get(count, ZeroColumnId) + 1)
                }
            })
        } else {
            this.aggregateValueColumns(map, valueColumnIds, aggrType, (row:DataValue[])=>{
                const val = row[categoryColumnIndex]
                return val === null ? null : val.toString()
            })
        }
        return Promise.resolve(map)
    }

    /**
     * If valueColumnIds specified and not empty the function returns an aggregation of their values grouped by
     * the time interval the time column value of the row falls into,
     * otherwise it returns how often values of time column get into each time interval (stored under ZeroColumnId).
     *
     * The resulting map has a key for the first moment of every interval between the minimum and the maximum
     * of the time column; intervals without rows keep a `null` value. A `null` key is present
     * if the time column has null values.
     */
    aggregateByInterval (aggrType:AggregationType, interval:ITimeInterval, timeColumnId: ColumnId, valueColumnIds: ColumnId[]): Promise<AggregationByInterval> {
        const timeColumnIndex = this.columnOrigIndexById(timeColumnId),
            timeColumnInfo = this.getColumnInfo(timeColumnId),
            map = new Map<FirstMoment|null, Map<ColumnId, number>|null>()

        U.assert (STT.isDateAndTime(timeColumnInfo.stype) || STT.isDateOnly(timeColumnInfo.stype) || STT.isTimeOnly(timeColumnInfo.stype), "timeColumnId must point to date/time column")

        // create resulting map for all interval moments covered by the time column
        const startMoment = interval.getFirstMoment(timeColumnInfo.numberSummary.min, timeColumnInfo.stype),
            endMoment = interval.getFirstMoment(timeColumnInfo.numberSummary.max, timeColumnInfo.stype)

        for (let moment = startMoment; moment <= endMoment; moment = interval.addInterval(moment, timeColumnInfo.stype)) {
            map.set (moment, null)
        }
        if (timeColumnInfo.summary.nullCount>0) {
            map.set (null, null)
        }

        if (valueColumnIds === undefined || valueColumnIds.length === 0) {
            // count interval frequency
            this.data.forEach(row => {
                const val = row[timeColumnIndex],
                    moment = val === null ? null : interval.getFirstMoment(val as number, timeColumnInfo.stype)
                const count = map.get(moment)
                if (count) {
                    count.set(ZeroColumnId, U.get(count, ZeroColumnId) + 1)
                } else {
                    map.set(moment, new Map ([[ZeroColumnId, 1]]))
                }
            })
        } else {
            this.aggregateValueColumns(map, valueColumnIds, aggrType, (row: DataValue[]) => {
                const val = row[timeColumnIndex]
                return val === null ? null : interval.getFirstMoment(val as number, timeColumnInfo.stype)
            })
        }

        return Promise.resolve(map)
    }

    /**
     * Returns all values of the requested columns keyed by column id.
     * @throws Error if any of the ids is unknown (see columnOrigIndexById)
     */
    getTuples (columnIds: ColumnId[]): Promise<ColumnsById> {
        // unknown ids make columnOrigIndexById throw, so every id has a valid index here
        const columnIndices = columnIds.map (colId => this.columnOrigIndexById(colId))
        return Promise.resolve(new Map (
            columnIds.map ((colId, i) => [colId, this.data.map (row => row[columnIndices[i]])])
        ))
    }

    /*** Column Summary ***/

    /**
     * Cross-validation score of bandwidth `h` used to choose the KDE bandwidth.
     *
     * For every xi it sums a Gaussian kernel over all xj (pairs with xi === xj are skipped),
     * takes the logarithm of the sum minus log((n - 1) * h), and returns the average over xi,
     * where n is the length of xiArray.
     * NOTE(review): the kernel term is already divided by h and log((n - 1) * h) accounts for h once more;
     * a zero kernel sum gives log(0) = -Infinity — confirm both against the referenced article.
     */
    protected static getMLCV (xiArray:number[], xjArray:number[], h:number): number {
        // https://medium.com/analytics-vidhya/kernel-density-estimation-kernel-construction-and-bandwidth-optimization-using-maximum-b1dfce127073
        let sumi = 0
        for (const xi of xiArray) {
            let sumj = 0
            for (const xj of xjArray) {
                sumj += xi === xj ? 0 : Math.exp(-0.5 * Math.pow((xj - xi) / h, 2)) / (h * U.sqrt2PI)
            }
            sumi += Math.log(sumj) - Math.log ((xiArray.length - 1) * h)
        }
        return sumi / xiArray.length
    }

    /**
     * Evaluates a Gaussian kernel density estimate of the column values at the given x positions.
     *
     * Values and positions are normalized by the column minimum and range, the bandwidth is found
     * by minimizing the negated getMLCV score on a thinned sample built from `distr` and then doubled.
     * NOTE(review): assumes the column has a non-zero value range, otherwise the normalization
     * divides by zero.
     *
     * @param column source column
     * @param origXArray x positions (in original units) to evaluate the density at
     * @param distr value frequencies used to build the bandwidth optimization sample
     * @returns density values, one per entry of origXArray
     */
    protected static getKDE (column:DataColumn, origXArray:number[], distr:NumFreq[]):number[] {
        const origValues = column.numbers,
            // normalizing the original values to avoid problems with large numbers
            minValue = U.min(origValues),
            maxValue = U.max(origValues),
            range = maxValue - minValue,
            normalize = (v:number) => (v - minValue)/range,
            normValues = origValues.map (normalize),
            normXArray = origXArray.map (normalize)

        // build simplified sampling for bandwidth optimization
        const maxAllowedNumberOfEachValue = Math.max (1, Math.round(distr.length / 1000)),
            maxFreq = U.max (distr.map (vf => vf.freq)),
            coeff = Math.max (1, maxFreq / maxAllowedNumberOfEachValue),
            mlcvValues:number[] = []

        for (const vf of distr) {
            // every value is repeated proportionally to its frequency scaled down by coeff
            for (let i = 0; i < vf.freq / coeff; i++) {
                mlcvValues.push(normalize(vf.value))
            }
        }

        // calculate the optimal bandwidth
        const h = minimize((h:number):number => -DataTable.getMLCV(mlcvValues, normXArray, h), 1e-5, 1e5) * 2

        // calculate kde
        return normXArray.map (xj => normValues.reduce((sum, xi) => {
            const k = Math.exp(-0.5 * Math.pow((xj - xi) / h, 2)) / (h * U.sqrt2PI)
            return sum + k
        }, 0) / normValues.length)
    }

    /**
     * Estimates the step of the grid the given values lie on by looking at differences
     * between neighbouring items.
     *
     * The first difference becomes the step. While a next difference equals the step, or one of them
     * is (almost) an integer multiple of the other, the smaller of the two is kept; otherwise the step
     * becomes a power of ten chosen by the number of decimal digits of their absolute difference.
     *
     * @returns the estimated step, or 1 if there are fewer than two values
     */
    static calculateValueGridStep(values: { value: number }[]): number {
        let step: number | null = null
        for (let i = 0; i < values.length - 1; i++) {
            const newDiff = Math.abs(U.fround(values[i].value - values[i + 1].value))
            if (step === null) {
                step = newDiff
            } else {
                if (U.equal(newDiff, step)
                    || (newDiff > step && Math.abs(newDiff / step - Math.round(newDiff / step)) < 1e-12)
                    || (newDiff < step && Math.abs(step / newDiff - Math.round(step / newDiff)) < 1e-12)
                ) {
                    step = Math.min(step, newDiff)
                } else {
                    step = U.fround(Math.pow (10, -U.digitsAfterDecimal(U.fround(Math.abs(step - newDiff)))))
                }
            }
        }
        return values.length > 1 ? U.notNull(step) : 1
    }

    /**
     * Calculates the numerical summary of a column: min, max, quartiles, whiskers, outliers, mean,
     * standard deviation and, unless the quick version is requested, value distribution, KDE curve
     * and numerical categories. All statistics are 0 for a column without numbers.
     *
     * @param column source column
     * @param quickVersion when true the distribution, KDE and categories are skipped
     */
    static getNumberSummary (column:DataColumn, quickVersion= false):INumberSummary {
        const columnValues = column.numbers,
            distribution = quickVersion ? undefined : ValueDistribution.fromDataColumn(column),
            valueFreqs = distribution?.meanValueFreq,
            nullCount = column.nullCount,
            nothing = columnValues.length === 0,
            min = nothing ? 0 : SS.min(columnValues),
            max = nothing ? 0 : SS.max(columnValues),
            [q1, median, q3] = nothing ? [0, 0, 0] : SS.quantile(columnValues, [0.25, 0.5, 0.75]),
            iqr = q3 - q1,
            // whiskers are limited by the actual value range
            whiskerMin = nothing ? 0 : Math.max(q1 - 1.5 * iqr, min),
            whiskerMax = nothing ? 0 : Math.min(q3 + 1.5 * iqr, max),
            outliers = nothing ? [] : columnValues.filter(n => n < whiskerMin || n > whiskerMax),
            mean = nothing ? 0 : SS.mean(columnValues),
            std = nothing ? 0 : SS.standardDeviation(columnValues)

        // calculate kde
        let kde:{x:number, y:number}[] | undefined
        if (!quickVersion && valueFreqs !== undefined) {
            kde = []
            if (valueFreqs.length === 1) {
                // the only value gets the whole density
                kde.push(new Point (valueFreqs[0].value, 1))
            }
            else if (!nothing) {
                let kdeX = valueFreqs.map(vf => vf.value).sort((a, b) => a > b ? 1 : (a === b ? 0 : -1))

                // limit number of points used for building KDE with MAX_KDE_X
                if (kdeX.length > MAX_KDE_X) {
                    // evenly spaced indexes including the first and the last one; Set drops duplicates
                    const selectedIndices = new Set(U.times(MAX_KDE_X, i => Math.round(i * (kdeX.length - 1) / (MAX_KDE_X - 1))))
                    kdeX = [...selectedIndices].map (i => kdeX[i])
                }
                const kdeY = DataTable.getKDE(column, kdeX, valueFreqs)
                kde = kdeX.map((x, i) => {
                    return new Point (x, kdeY[i])
                })
            }
        }

        const valueGridStep = valueFreqs ? DataTable.calculateValueGridStep(valueFreqs) : null
        let categories:NumericalCategories|undefined

        // a column with few different values lying on a coarse grid can be shown as categories
        if (valueGridStep !== null && valueFreqs && valueFreqs.length <= U.settings.numToCat && valueGridStep > 0 && (max-min)/valueGridStep <= 25) {
            const legendValues = [...valueFreqs.map(vf => vf.value)].sort((a, b) => a - b),
                // share of legend values given to the non-negative part, used for the 3-color range only
                zeroOffsetCoef = Math.abs(max) / (Math.abs(max) + Math.abs(min)),
                positiveAndZeroCount = Math.round(legendValues.length * zeroOffsetCoef),
                colors = min < 0 && max > 0
                    ? Colors.getStandard3ColorRange(Colors.triGradientColorGreen, Colors.triGradientColorYellow, Colors.triGradientColorRed, legendValues.length - positiveAndZeroCount, positiveAndZeroCount)
                    : Colors.getStandard2ColorRange(Colors.duGradientColorYellow, Colors.duGradientColorBlue, legendValues.length),
                legendCategories = new Map<number, StandardColor>(legendValues.map((value, index) => [value, colors[index]])),
                axisCategories = U.times(U.fround((max-min) / valueGridStep) + 1, i => U.fround(min + i * valueGridStep))
            categories = new NumericalCategories(valueGridStep, axisCategories, legendCategories)
        }

        return new NumberSummary(min, max, whiskerMin, whiskerMax, q1, q3, iqr, mean, median, std, outliers, nullCount, distribution, kde, categories)
    }

    /**
     * Builds the summary of a column depending on its semantic type: a number summary for continuous
     * columns, a category distribution for categorical ones, and a number summary of string lengths
     * for all the others.
     *
     * @param column source column
     * @param quickVersion passed to getNumberSummary for continuous columns
     */
    static getColumnSummary (column: DataColumn, quickVersion= false): IColumnSummary{
        const totalCount = column.nonNullValues.length + column.nullCount
        if (STT.isContinuous(column.stype)) {
            return new ColumnSummary (
                ColumnSummaryType.Number,
                totalCount,
                column.nullCount,
                DataTable.getNumberSummary(column, quickVersion)
            )
        } else if (STT.isCategorical(column.stype)) {
            return new ColumnSummary (
                ColumnSummaryType.Category,
                totalCount,
                column.nullCount,
                undefined,
                CategoryDistribution.fromDataColumn(column)
            )
        } else {
            return new ColumnSummary (
                ColumnSummaryType.Text,
                totalCount,
                column.nullCount,
                // text columns are summarized by the length of their strings
                DataTable.getNumberSummary(new DataColumn(column.strings.map (value => [(value as string).length]), 0, SemanticType.Number))
            )
        }
    }

    /*** Column Information ***/

    /**
     * Returns the column located at the given visual index.
     * @throws Error if there is no column at that index
     */
    columnByIndex (index:ColumnUIIndex):ITableColumn {
        const column = this._columns[index]
        if (column === undefined) {
            throw new Error (`There is no column with index=${index}`)
        }
        return column
    }

    /**
     * The last column in visual order.
     * @throws Error if the table has no columns
     */
    get lastColumn():ITableColumn {
        if (this.numberOfColumns) {
            return this._columns[this.numberOfColumns - 1]
        } else {
            throw new Error ("There is no last column as there is no column at all")
        }
    }

    // returns columns that pass the test, ordered by ui index
    filteredColumnsByUiIndex (predicate:(col:ITableColumn)=>boolean):ITableColumn[] {
        return this._columns.filter(predicate)
    }

    /** Number of pinned columns. */
    get numberOfPinnedColumns(): number {
        return this._columns.filter (c => c.pinned).length
    }

    /** Sum of widths of all columns. */
    get totalColumnWidth(): number {
        return this._columns.reduce ((w, c) => w + c.width, 0)
    }

    /** Number of columns in the table. */
    get numberOfColumns () {
        return this._columns.length
    }

    /** Selected columns in visual order. */
    get selectedColumns(): ITableColumn[] {
        return this._columns.filter(c => c.selected)
    }

    /** Ids of selected columns in visual order. */
    get selectedColumnIds(): ColumnId[] {
        return this.selectedColumns.map (c => c.id)
    }

    /** Widths of all columns in visual order. */
    get columnWidthsByVisualIndex() {
        return this._columns.map(column => column.width)
    }

    /** Visual indexes of columns whose width equals UnknownColumnWidth. */
    get columnIndexesWithUnknownWidth():ColumnUIIndex[] {
        const indexes:ColumnUIIndex[] = []
        for (const [index, column] of this._columns.entries()) {
            if (column.width === UnknownColumnWidth) {
                indexes.push(ColumnUIIndex(index))
            }
        }
        return indexes
    }

    /**
     * Sets the width of the column at the given visual index.
     * @returns the updated visual configuration snapshot
     */
    changeColumnWidth (columnIndex: ColumnUIIndex, width: number):string {
        this._columns[columnIndex].width = width
        return this.visualConfigSnapshot
    }

    /**
     * Swaps the column at the given visual index with its left (-1) or right (1) neighbour.
     * Nothing happens if there is no neighbour or its pinning state differs.
     * @returns the visual configuration snapshot after the operation
     */
    swapColumns (columnIndex: ColumnUIIndex, direction: 1 | -1):string {
        const me = this._columns[columnIndex],
            nn = columnIndex + Math.sign(direction),
            neighbor = nn >=0 && nn < this._columns.length ? this._columns[nn] : null
        if (neighbor && neighbor.pinned === me.pinned) {
            [neighbor.order, me.order] = [me.order, neighbor.order]
            this.resortColumns()
        }
        return this.visualConfigSnapshot
    }

    /**
     * Inverts pinning of the column at the given visual index and resorts the columns.
     * @returns the new pinning state and the updated visual configuration snapshot
     */
    toggleColumnPinning (columnIndex: ColumnUIIndex):{pinned: boolean, snapshot: string} {
        const col = this._columns[columnIndex]
        col.pinned = !col.pinned
        this.resortColumns()
        return {pinned: col.pinned, snapshot: this.visualConfigSnapshot}
    }

    /**
     * Changes selection of the column at the given visual index by updating its selection time:
     * in exclusive mode the column gets the current time and the selection time of the other selected
     * columns is cleared; otherwise the column gets the current time if it was not selected and
     * `null` if it was.
     * @returns the visual configuration snapshot after the operation
     */
    selectColumn (columnIndex: ColumnUIIndex, exclusive:boolean):string {
        this._columns.forEach((c, i) => {
            if (i === columnIndex) {
                c.selectionTime = !c.selected || exclusive ? new Date() : null;
            } else if (exclusive && c.selected) {
                c.selectionTime = null
            }
        })
        return this.visualConfigSnapshot
    }

    /**
     * Gives the columns with the given ids a selection time and clears it for all the others.
     * Selection times start from the current moment and grow with the visual index of the column.
     * @returns the visual configuration snapshot after the operation
     */
    selectColumns(columnIds: ColumnId[]):string {
        // full timestamp is needed here: getMilliseconds() would give just the 0..999 ms part of it
        const time = Date.now()
        this._columns.forEach((column, index) =>
            column.selectionTime = columnIds.indexOf(column.id) < 0 ? null : new Date(time + index)
        )
        return this.visualConfigSnapshot
    }

    /*** Internal auxiliary methods ***/

    /**
     * Returns the first column with the given id.
     * @throws Error if there is no such column
     */
    protected columnById (id:ColumnId):ITableColumn {
        for (const column of this._columns) {
            if (column.id === id) {
                return column
            }
        }
        throw new Error (`There is no column with id=${id}`)
    }

    /** Returns the columns ordered by their database index. */
    protected columnsOrderedByOrigIndex ():ITableColumn[] {
        return this._columnIds.map (id => this.columnById(id))
    }

    /**
     * Sorts the columns in place into visual order: pinned columns first, then by `order`
     * within each group.
     */
    protected resortColumns () {
        this._columns.sort((a, b) => {
            const aPinned = Boolean(a.pinned),
                bPinned = Boolean(b.pinned)
            if (aPinned !== bPinned) {
                return aPinned ? -1 : 1
            }
            // equal orders compare as 0 to keep the comparator consistent
            return a.order - b.order
        })
    }

    /**
     * Makes one aggregation step: combines the value with the aggregation accumulated so far.
     *
     * Null values are ignored by Min, Max, Sum and Mean (Mean accumulates the sum only, the division
     * is done by the caller). Min and Max start from Infinity and -Infinity respectively.
     *
     * @param value next value or null
     * @param aggrType kind of aggregation
     * @param aggregation result accumulated so far, if any
     */
    protected static valueByAggrType (value: number | null, aggrType: AggregationType, aggregation?:number):number  {
        switch (aggrType) {
            case AggregationType.Min:
                return value === null ? aggregation ?? Infinity : Math.min(aggregation ?? Infinity, value)
            case AggregationType.Max:
                return value === null ? aggregation ?? -Infinity : Math.max(aggregation ?? -Infinity, value)
            case AggregationType.Mean:
            case AggregationType.Sum:
                return (aggregation ?? 0) + (value ?? 0)
            case AggregationType.NullCount:
                return (aggregation ?? 0) + (value === null ? 1 : 0)
            case AggregationType.NonNullCount:
                return (aggregation ?? 0) + (value !== null ? 1 : 0)
        }
    }
}