//import {IMAGENET_CLASSES} from "app/lib/interelcom/tfjsutils/ImagenetClasses";
import type * as tf from "@tensorflow/tfjs";
import type {GraphModel, LayersModel, TensorLike} from "@tensorflow/tfjs";
import TFJSLoader from "app/lib/interelcom/tfjsutils/TFJSLoader";
import {Tensor} from "@tensorflow/tfjs-core/dist/tensor";

/**
 * Describes one MobileNet-style classifier: where to load it from, which
 * TF.js format it uses, and how camera/canvas pixels must be prepared for it.
 */
export interface MobilenetConf {
    /** Display label for this configuration (not read by the prediction code in this file). */
    label: string;
    /** Model location, handed to `loadGraphModel` / `loadLayersModel`. */
    path: string;
    /** Forwarded as the `fromTFHub` load option. */
    tfhub: boolean;
    /** NOTE(review): not consulted by the preprocessing in this file; presumably the lower bound of the model's input range — confirm. */
    inputMin: number;
    /** NOTE(review): not consulted by the preprocessing in this file; presumably the upper bound of the model's input range — confirm. */
    inputMax: number;
    /** When equal to 1, pixel values are scaled by 1/255 before inference; otherwise raw values are fed. */
    inputValues: number;
    /** Side length of the square input; frames are resized to imageSize x imageSize when they differ. */
    imageSize: number;
    /** Class index -> label; the number of keys also decides how many output columns are kept. */
    resultClasses: Record<number, string>;
    /** Either 'graph-model' or 'layers-model'; any other value makes loading fail. */
    format: string;
}

/** One ranked classification result: a class label and its softmax probability. */
export interface TopPrediction {
    /** Label looked up from the configuration's class table. */
    className: string;
    /** Softmax probability assigned to this class. */
    probability: number;
}

/**
 * Wraps a MobileNet-style image classifier loaded through TensorFlow.js and
 * turns canvas contents into ranked class predictions.
 *
 * Usage: construct with a {@link MobilenetConf}, await {@link load}, then call
 * {@link predict}. Calling predict()/getTopKClasses() before load() throws.
 */
export default class MobilenetPrediction {

    private conf: MobilenetConf;

    /** Set after loading a 'graph-model'; cleared when a layers model is loaded. */
    private graphmodel: GraphModel<string> | null = null;
    /** Set after loading a 'layers-model'; cleared when a graph model is loaded. */
    private layersmodel: LayersModel | null = null;
    /** TF.js namespace obtained from TFJSLoader; null until load() has run. */
    private _tf: typeof tf | null = null;

    constructor(conf: MobilenetConf) {
        this.conf = conf;
    }

    /** Returns the configuration this instance was constructed with. */
    public getConf(): MobilenetConf {
        return this.conf;
    }

    /**
     * Loads TF.js and the model described by the configuration, then runs one
     * inference on an all-zeros input before storing the model.
     *
     * @param callback Optional; receives every load-progress fraction reported by TF.js.
     * @throws Error if conf.format is neither 'graph-model' nor 'layers-model'.
     */
    public async load(callback?: ((fraction: number) => void) | null): Promise<void> {

        const tfjs: typeof tf = await TFJSLoader.load();
        this._tf = tfjs;

        // Shared by both loaders: log progress and forward it to the caller.
        const options = {
            fromTFHub: this.conf.tfhub,
            onProgress: (fraction: number) => {
                this.progress(fraction);
                if (callback != null) {
                    callback(fraction);
                }
            }
        };

        switch (this.conf.format) {
            case 'graph-model': {
                const model = await tfjs.loadGraphModel(this.conf.path, options) as GraphModel<string>;
                await this.warmUp(tfjs, (input) => model.predict(input) as tf.Tensor);
                this.graphmodel = model;
                // predict() prefers a layers model, so drop any previously loaded one.
                this.layersmodel = null;
                return;
            }
            case 'layers-model': {
                const model = await tfjs.loadLayersModel(this.conf.path, options);
                await this.warmUp(tfjs, (input) => model.predict(input) as tf.Tensor);
                this.layersmodel = model;
                this.graphmodel = null;
                return;
            }
            default:
                throw new Error("Unknown model format: " + this.conf.format);
        }
    }

    /**
     * Runs `run` once on a zeros tensor of shape [1, imageSize, imageSize, 3],
     * waits for the output values, and disposes the output even if reading fails.
     * Presumably meant to warm up the backend before the first real prediction.
     */
    private async warmUp(tfjs: typeof tf, run: (input: tf.Tensor) => tf.Tensor): Promise<void> {
        const size = this.conf.imageSize;
        // tidy() frees the zeros input; only the returned output escapes the scope.
        const output = tfjs.tidy(() => run(tfjs.zeros([1, size, size, 3])));
        try {
            await output.data();
        } finally {
            output.dispose();
        }
    }

    private progress(fraction: number) {
        console.log("PROGRESS: " + fraction);
    }

    /** Returns the TF.js namespace, or throws if load() has not been called. */
    private requireTf(): typeof tf {
        if (this._tf == null) {
            throw new Error("MobilenetPrediction: call load() before predicting");
        }
        return this._tf;
    }

    /**
     * Returns a function that runs whichever model load() stored (layers model
     * first, as before); throws if no model has been loaded.
     */
    private resolveModel(): (input: tf.Tensor) => tf.Tensor2D {
        const layers = this.layersmodel;
        if (layers != null) {
            return (input) => layers.predict(input) as tf.Tensor2D;
        }
        const graph = this.graphmodel;
        if (graph != null) {
            return (input) => graph.predict(input) as tf.Tensor2D;
        }
        throw new Error("MobilenetPrediction: no model loaded; call load() first");
    }

    /**
     * Classifies the current contents of a canvas.
     *
     * Pixels are cast to float32, scaled by 1/255 when conf.inputValues == 1,
     * bilinearly resized to imageSize x imageSize if needed, batched, and run
     * through the loaded model. Only the first N output columns are kept, where N
     * is the number of entries in conf.resultClasses.
     *
     * @returns Up to 10 predictions, highest probability first.
     * @throws Error if load() has not completed successfully.
     */
    public async predict(canvas: HTMLCanvasElement): Promise<TopPrediction[]> {

        const tfjs = this.requireTf();
        const runModel = this.resolveModel();
        const size = this.conf.imageSize;
        const classesCount = Object.keys(this.conf.resultClasses).length;

        // tidy() disposes every intermediate tensor (including the full logits);
        // only the returned slice survives and is disposed in the finally below.
        const scores = tfjs.tidy(() => {

            const pixels = tfjs.browser.fromPixels(canvas);

            let image = tfjs.cast(pixels, 'float32');

            // NOTE(review): conf.inputMin/inputMax are not consulted here; only
            // inputValues == 1 triggers scaling to [0, 1] — confirm this is intended.
            if (this.conf.inputValues == 1) {
                image = tfjs.mul(image, 1.0 / 255.0);
            }

            if (pixels.shape[0] !== size || pixels.shape[1] !== size) {
                const alignCorners = true;
                image = tfjs.image.resizeBilinear(image, [size, size], alignCorners);
            }

            // Reshape to a single-element batch so it can be passed to the model.
            const batched = tfjs.reshape(image, [-1, size, size, 3]);
            const logits = runModel(batched);

            // Keep only the output columns that have a label in conf.resultClasses.
            return tfjs.slice(logits, [0, 0], [-1, classesCount]);
        });

        try {
            return await this.getTopKClasses(scores, 10);
        } finally {
            scores.dispose();
        }
    }


    /**
     * Computes the probabilities of the topK classes given logits by computing
     * softmax to get probabilities and then sorting the probabilities.
     *
     * The caller keeps ownership of `logits`; only the internal softmax tensor is
     * disposed here.
     *
     * @param logits Tensor representing the logits from MobileNet (e.g. a tf.Tensor2D).
     * @param topK The maximum number of predictions to return; clamped to the
     *             number of available classes so small models do not fail.
     * @returns Predictions sorted by descending probability, labelled via conf.resultClasses.
     */
    public async getTopKClasses<T extends Tensor>(logits: T | TensorLike, topK: number): Promise<TopPrediction[]> {

        const tfjs = this.requireTf();

        const softmax = tfjs.softmax(logits);
        let values: ArrayLike<number>;
        try {
            values = await softmax.data();
        } finally {
            softmax.dispose();
        }

        const ranked = Array.from(values, (value, index) => ({value, index}));
        ranked.sort((a, b) => b.value - a.value);

        // Guard against topK exceeding the class count (or being negative).
        const count = Math.max(0, Math.min(topK, ranked.length));

        return ranked.slice(0, count).map(({value, index}) => ({
            className: this.conf.resultClasses[index],
            probability: value
        }));
    }

}