import { put, select, takeLatest, takeEvery, take } from 'redux-saga/effects'
import { eventChannel } from 'redux-saga'

import { Tensor, InferenceSession, env } from 'onnxruntime-web'
import pako from 'pako'
import untar from 'js-untar'

// Inference sessions created by createModel(), keyed by the model id passed to it.
// Read through getModel() / isModelReady().
const sessions = {}

// Action type that triggers inference on one image; the root saga routes it to infer(),
// which reads `action.src` as the image URL.
export const INFER_IMAGE = 'FALCON/DATASETS/VIEWER/INFER_IMAGE'

/**
 * Look up a cached inference session.
 * @param {String} model id the session was cached under
 * @returns {InferenceSession|null} the cached session, or null when nothing is cached for that id
 */
export function getModel(model) {
    return Object.hasOwn(sessions, model) ? sessions[model] : null
}

/**
 * Check whether a session is cached under the given id.
 * @param {String} model session id
 * @returns {Boolean} true if an own entry exists for `model`
 */
export function isModelReady(model) {
    const isCached = Object.hasOwn(sessions, model)
    return isCached
}

/**
 * Download a gzip-compressed tar archive and return the bytes of its first entry.
 * createModel() hands the result to InferenceSession.create, so the first entry is
 * presumably the ONNX model file.
 * @param {String} url location of the compressed model deliverable
 * @returns {Promise<ArrayBuffer>} contents of the first file in the archive
 * @throws {Error} 'Download model failed' on any network, HTTP, decompression or archive
 *     error; the underlying error is attached as `cause`
 */
export async function downloadModel(url) {
    try {
        const response = await fetch(url)
        // fetch() only rejects on network failure; HTTP errors must be checked explicitly.
        if (!response.ok) {
            throw new Error(`HTTP ${response.status} while fetching ${url}`)
        }

        const inflated = pako.inflate(await response.arrayBuffer())
        // untar wants an ArrayBuffer; copy only if the Uint8Array is a partial view of its buffer.
        const isWholeBuffer =
            inflated.byteOffset === 0 && inflated.byteLength === inflated.buffer.byteLength
        const tarBuffer = isWholeBuffer
            ? inflated.buffer
            : inflated.buffer.slice(inflated.byteOffset, inflated.byteOffset + inflated.byteLength)

        const files = await untar(tarBuffer)
        if (!files || files.length === 0) {
            throw new Error('Model archive is empty')
        }
        return files[0].buffer
    } catch (error) {
        throw new Error('Download model failed', { cause: error })
    }
}

/**
 * Run one throw-away inference so the backend does its first-run setup before real queries.
 * Feeds a float32 tensor of shape `dims`, filled with uniform noise in [-1, 1), to the
 * session's first input.
 * @param {InferenceSession} model session to warm up
 * @param {Number[]} dims input tensor dimensions (must be non-empty)
 * @returns {Promise<Object|null>} the output map from run(), or null if the run failed
 *     (the error is logged)
 */
export async function warmupModel(model, dims) {
    const elementCount = dims.reduce((acc, dim) => acc * dim)
    const noise = new Float32Array(elementCount)
    for (let idx = 0; idx < elementCount; idx += 1) {
        noise[idx] = Math.random() * 2.0 - 1.0 // uniform in [-1.0, 1.0)
    }
    const warmupTensor = new Tensor('float32', noise, dims)

    try {
        // Looking up the input name stays inside the try so a malformed session yields null.
        const inputName = model.inputNames[0]
        return await model.run({ [inputName]: warmupTensor })
    } catch (err) {
        console.log(err)
        return null
    }
}

// Session creations still in flight, keyed by model id. Concurrent callers share one
// download + initialization instead of each starting their own.
const pendingSessions = new Map()

/**
 * Download a model, create an onnxruntime-web session for it, warm it up, and cache it under `id`.
 *
 * - The session is cached (and visible to getModel/isModelReady) only after a successful warmup.
 * - A session whose warmup fails is released and not cached.
 * - Concurrent calls for the same id share one in-flight creation.
 * - A failed creation leaves no trace, so a later call retries.
 *
 * @param {String} id cache key for the session
 * @param {String} url location of the compressed model deliverable (see downloadModel)
 * @param {Number[]} shape input dims used to build the warmup tensor
 * @param {String} device execution provider passed to InferenceSession.create
 * @returns {Promise<InferenceSession>} the cached or newly created session
 * @throws {Error} propagated download/creation errors, 'Create session error', or 'Warmup error'
 */
export function createModel(id, url, shape, device) {
    if (Object.hasOwn(sessions, id)) {
        return Promise.resolve(sessions[id])
    }
    if (pendingSessions.has(id)) {
        return pendingSessions.get(id)
    }

    const pending = (async () => {
        const model = await downloadModel(url)
        const sess = await InferenceSession.create(model, { executionProviders: [device] })
        if (sess === null) {
            throw new Error('Create session error')
        }

        const result = await warmupModel(sess, shape)
        if (result === null) {
            // Don't cache or leak a session that could not run; release is best-effort.
            try {
                await sess.release?.()
            } catch (releaseError) {
                console.warn(releaseError)
            }
            throw new Error('Warmup error')
        }

        sessions[id] = sess
        return sess
    })()

    pendingSessions.set(id, pending)
    // Both handlers resolve, so this derived promise never produces an unhandled rejection;
    // callers still observe the original outcome through `pending`.
    const forget = () => pendingSessions.delete(id)
    pending.then(forget, forget)
    return pending
}

/**
 * Perform greedy Non Maximum Suppression to filter overlapping boxes.
 *
 * Boxes are visited from highest to lowest confidence. Each kept box suppresses every
 * remaining box whose overlap with it exceeds `overlapThresh`. Overlap is the intersection
 * area divided by the area of the candidate (suppressed) box, using the inclusive-pixel
 * (+1) convention. It is not a symmetric IoU.
 *
 * The input array is not modified.
 * @param {Array<Object>} boxes boxes with `confidence` and `bounding` = [x, y, width, height]
 * @param {Number} overlapThresh overlap ratio above which a box is suppressed
 * @returns {Array<Object>} kept boxes, in descending confidence order
 */
/* eslint-disable import/prefer-default-export */
export const NMS = (boxes, overlapThresh) => {
    if (!boxes || boxes.length === 0) {
        return []
    }

    const pick = []
    // Sort a copy ascending so the most confident remaining box is always last.
    let remaining = [...boxes].sort((b1, b2) => b1.confidence - b2.confidence)

    while (remaining.length > 0) {
        const last = remaining[remaining.length - 1]
        pick.push(last)
        const suppress = new Set([last])

        for (let i = 0; i < remaining.length - 1; i += 1) {
            const box = remaining[i]
            const xx1 = Math.max(box.bounding[0], last.bounding[0])
            const yy1 = Math.max(box.bounding[1], last.bounding[1])
            const xx2 = Math.min(
                box.bounding[0] + box.bounding[2],
                last.bounding[0] + last.bounding[2]
            )
            const yy2 = Math.min(
                box.bounding[1] + box.bounding[3],
                last.bounding[1] + last.bounding[3]
            )
            const w = Math.max(0, xx2 - xx1 + 1)
            const h = Math.max(0, yy2 - yy1 + 1)
            const overlap = (w * h) / ((box.bounding[2] + 1) * (box.bounding[3] + 1))

            if (overlap > overlapThresh) {
                suppress.add(box)
            }
        }

        // Set lookup keeps each pass linear instead of scanning `suppress` per box.
        remaining = remaining.filter((box) => !suppress.has(box))
    }

    return pick
}

/**
 * Convert an image into the normalized float blob the detection model consumes.
 * Resizes to modelWidth x modelHeight (nearest neighbour), drops alpha, and scales by 1/255.
 * Intermediate Mats are always freed. The returned blob owns OpenCV memory, and the caller
 * must call `.delete()` on it.
 * @param {HTMLImageElement} image source image
 * @param {Number} modelWidth target width
 * @param {Number} modelHeight target height
 * @returns {cv.Mat} blob produced by cv.blobFromImage
 */
function preprocessImage(image, modelWidth, modelHeight) {
    const temporaries = []
    try {
        const rgba = cv.imread(image) // RGBA ordering
        temporaries.push(rgba)
        // resize/cvtColor allocate their outputs, so empty Mats are sufficient here.
        const scaled = new cv.Mat()
        temporaries.push(scaled)
        cv.resize(rgba, scaled, new cv.Size(modelWidth, modelHeight), 0, 0, cv.INTER_NEAREST)
        const rgb = new cv.Mat()
        temporaries.push(rgb)
        cv.cvtColor(scaled, rgb, cv.COLOR_RGBA2RGB) // RGBA to RGB
        return cv.blobFromImage(
            rgb,
            1 / 255.0,
            new cv.Size(modelWidth, modelHeight),
            new cv.Scalar(0, 0, 0),
            false,
            false
        )
    } finally {
        temporaries.forEach((m) => m.delete())
    }
}

/**
 * Turn raw detector output rows into candidate boxes.
 * Each row is read as [cx, cy, w, h, objectness, ...classScores]. A row becomes a box only
 * when its objectness passes `confidenceThreshold` and its best class score passes
 * `classThreshold`.
 * @param {Tensor} output model output; `dims[2]` is the length of one row
 * @param {Number} confidenceThreshold minimum objectness
 * @param {Number} classThreshold minimum best-class score
 * @returns {Array<Object>} boxes { classId, probability, confidence, bounding }, where
 *     bounding = [x, y, w, h] has a top-left origin in model input coordinates
 */
function decodeDetections(output, confidenceThreshold, classThreshold) {
    const rowLength = output.dims[2]
    if (!(rowLength > 0)) {
        // A non-positive stride would never advance the loop below.
        throw new Error(`Unexpected output dims: ${output.dims}`)
    }

    const boxes = []
    for (let r = 0; r < output.data.length; r += rowLength) {
        const row = output.data.slice(r, r + rowLength)
        const scores = row.slice(5) // classes probability scores
        const confidence = row[4] // detection confidence
        const classId = scores.indexOf(Math.max(...scores)) // class id of maximum probability score
        const maxClassProb = scores[classId]

        if (confidence >= confidenceThreshold && maxClassProb >= classThreshold) {
            const [x, y, w, h] = row.slice(0, 4)
            boxes.push({
                classId,
                probability: maxClassProb,
                confidence,
                bounding: [x - 0.5 * w, y - 0.5 * h, w, h],
            })
        }
    }
    return boxes
}

/**
 * Run object detection on an image and return per-class NMS-filtered boxes.
 * @param {HTMLImageElement} image image to detect
 * @param {ort.InferenceSession} session YOLOv5 onnxruntime session
 * @param {Number} confidenceThreshold confidence threshold
 * @param {Number} classThreshold class threshold
 * @param {Number} nmsThreshold NMS overlap threshold
 * @param {Number[]} inputShape model input shape; indices 2 and 3 are used as width and height.
 *     NOTE(review): standard YOLO exports are NCHW (height before width). This only matters
 *     for non-square inputs; confirm against the model.
 * @param {Number} nClasses number of classes; boxes with classId >= nClasses are dropped
 * @returns {Promise<Array<Object>>} boxes { classId, probability, confidence, bounding } in
 *     model input coordinates
 */
/* eslint-disable import/prefer-default-export */
async function detectImage(
    image,
    session,
    confidenceThreshold,
    classThreshold,
    nmsThreshold,
    inputShape,
    nClasses
) {
    const [modelWidth, modelHeight] = inputShape.slice(2)
    const input = preprocessImage(image, modelWidth, modelHeight)

    let res
    try {
        const tensor = new Tensor('float32', input.data32F, inputShape) // views the blob's memory
        // Feed the first declared input, matching warmupModel; fall back to the YOLOv5 default name.
        const inputName = session.inputNames?.[0] ?? 'images'
        res = await session.run({ [inputName]: tensor })
    } finally {
        // The tensor aliases blob memory, so free the blob only after run() settles (or throws).
        input.delete()
    }

    const output0 = res.output ?? res[session.outputNames?.[0]]
    if (!output0) {
        throw new Error('Model returned no output tensor')
    }

    const boxes = decodeDetections(output0, confidenceThreshold, classThreshold)

    // Non Maximum Suppression by class
    const processedBoxes = []
    for (let i = 0; i < nClasses; i += 1) {
        const classBoxes = boxes.filter((obj) => obj.classId === i)
        processedBoxes.push(...NMS(classBoxes, nmsThreshold))
    }

    return processedBoxes
}

/**
 * Worker saga for INFER_IMAGE: loads `action.src`, runs the configured detector on it,
 * and dispatches the result. Model settings come from `state.datasets.settings`.
 *
 * Dispatches one of:
 * - `{ type: 'INFER_IMAGE_SUCCESS', annotations }`, where each bbox is [x, y, w, h] rescaled
 *   from model input space to the image's pixel size;
 * - `{ type: 'INFER_IMAGE_FAILURE', error }` (a message string), if the image fails to load
 *   or model creation / inference throws.
 * @param {{ src: String }} action the INFER_IMAGE action
 */
function* infer(action) {
    const {
        inferenceModel,
        inferenceModelShape,
        inferenceModelClasses,
        inferenceDeliverableUrl,
        inferenceConfidenceThreshold,
        inferenceClassThreshold,
        inferenceNmsThreshold,
        inferenceDevice,
    } = yield select((state) => state.datasets.settings)

    const channel = eventChannel((emitter) => {
        const image = new Image()

        image.onload = async () => {
            // Any rejection here must become a message; otherwise it is unhandled and the
            // saga waits on take(channel) until it is cancelled.
            try {
                const model = await createModel(
                    inferenceModel,
                    inferenceDeliverableUrl,
                    inferenceModelShape,
                    inferenceDevice
                )

                const result = await detectImage(
                    image,
                    model,
                    inferenceConfidenceThreshold,
                    inferenceClassThreshold,
                    inferenceNmsThreshold,
                    inferenceModelShape,
                    inferenceModelClasses.length
                )

                emitter({
                    type: 'INFER_IMAGE_SUCCESS',
                    annotations: result.map((r) => ({
                        classId: r.classId,
                        classLabel: inferenceModelClasses[r.classId],
                        confidence: r.confidence,
                        bbox: [
                            (r.bounding[0] * image.width) / inferenceModelShape[2],
                            (r.bounding[1] * image.height) / inferenceModelShape[3],
                            (r.bounding[2] * image.width) / inferenceModelShape[2],
                            (r.bounding[3] * image.height) / inferenceModelShape[3],
                        ],
                    })),
                })
            } catch (error) {
                emitter({
                    type: 'INFER_IMAGE_FAILURE',
                    error: error?.message ?? String(error),
                })
            }
        }
        image.onerror = () => {
            emitter({
                type: 'INFER_IMAGE_FAILURE',
                error: `Failed to load image ${action.src}`,
            })
        }
        image.crossOrigin = 'Anonymous'
        // Cache-buster: use '&' when the URL already carries a query string.
        const baseSrc = `${action.src}`
        const separator = baseSrc.includes('?') ? '&' : '?'
        image.src = `${baseSrc}${separator}${new Date().getTime()}`

        return () => {
            // Once the channel is closed (e.g. cancelled by takeLatest), don't start inference
            // for an image nobody is waiting on.
            image.onload = null
            image.onerror = null
        }
    })

    try {
        const msg = yield take(channel)
        yield put(msg)
    } finally {
        // Runs on normal completion and on cancellation; emits after this are dropped.
        channel.close()
    }
}

/**
 * Root saga: runs `infer` for each INFER_IMAGE action. takeLatest cancels a still-running
 * inference when a newer request arrives.
 */
function* mySaga() {
    // `yield [effects]` is deprecated in redux-saga 0.x and not run as a parallel effect in 1.x.
    // With a single watcher, yield the (non-blocking) takeLatest effect directly.
    yield takeLatest(INFER_IMAGE, infer)
}

export default mySaga
