import { Tensor } from 'onnxruntime-web'
import { NMS } from './nms'

/* eslint-disable import/prefer-default-export */

// Leading fields in every YOLOv5 output row before the class scores: x, y, w, h, objectness.
const BOX_FIELDS = 5

// Resize, drop alpha and scale to [0, 1]; returns an OpenCV blob Mat the caller must delete().
const preprocessImage = (image, cv, modelWidth, modelHeight) => {
    const size = new cv.Size(modelWidth, modelHeight)
    const mat = cv.imread(image) // RGBA ordering
    const resized = new cv.Mat()
    const rgb = new cv.Mat()
    try {
        cv.resize(mat, resized, size, 0, 0, cv.INTER_NEAREST)
        cv.cvtColor(resized, rgb, cv.COLOR_RGBA2RGB) // RGBA to RGB
        return cv.blobFromImage(rgb, 1 / 255.0, size, new cv.Scalar(0, 0, 0), false, false)
    } finally {
        // cv.Mat objects are released manually; free them even if a cv call throws.
        mat.delete()
        resized.delete()
        rgb.delete()
    }
}

// Turn raw output rows (x, y, w, h, objectness, ...classScores) into candidate boxes
// that pass both the objectness and best-class thresholds.
const decodeDetections = (output, confidenceThreshold, classThreshold) => {
    const { data, dims } = output
    const stride = dims[2]
    // A zero stride would never advance the row loop; refuse malformed shapes loudly.
    if (!Number.isInteger(stride) || stride < BOX_FIELDS) {
        throw new Error(`Unexpected model output dims [${dims}]`)
    }

    const boxes = []
    for (let r = 0; r + stride <= data.length; r += stride) {
        const confidence = data[r + 4] // detection confidence
        // Test objectness first: rows failing it skip the class-score scan entirely.
        if (confidence >= confidenceThreshold) {
            // In-place argmax over the class scores (first maximum wins, like indexOf).
            let classId = -1
            let maxClassProb = -Infinity
            for (let c = BOX_FIELDS; c < stride; c += 1) {
                if (data[r + c] > maxClassProb) {
                    maxClassProb = data[r + c]
                    classId = c - BOX_FIELDS
                }
            }

            if (classId >= 0 && maxClassProb >= classThreshold) {
                const x = data[r]
                const y = data[r + 1]
                const w = data[r + 2]
                const h = data[r + 3]
                boxes.push({
                    classId,
                    probability: maxClassProb,
                    confidence,
                    // centre-based (x, y) -> top-left corner
                    bounding: [x - 0.5 * w, y - 0.5 * h, w, h],
                })
            }
        }
    }
    return boxes
}

// Group boxes by class id, run NMS within each class in ascending class order,
// and drop classes whose id is >= nClasses.
const nmsPerClass = (boxes, nClasses, nmsThreshold) => {
    const groups = new Map()
    boxes.forEach((box) => {
        if (groups.has(box.classId)) {
            groups.get(box.classId).push(box)
        } else {
            groups.set(box.classId, [box])
        }
    })

    const kept = []
    const classIds = [...groups.keys()].filter((id) => id < nClasses).sort((a, b) => a - b)
    classIds.forEach((id) => {
        kept.push(...NMS(groups.get(id), nmsThreshold))
    })
    return kept
}

/**
 * Run a YOLOv5 ONNX model on an image and return the detections that survive
 * confidence/class filtering and per-class non-maximum suppression.
 *
 * @param {HTMLImageElement} image Image to detect (passed to `cv.imread`)
 * @param {ort.InferenceSession} session YOLOv5 onnxruntime session; fed under the input
 *   name `images`, result read from the output named `output`
 * @param {Number} confidenceThreshold minimum objectness confidence (row element 4)
 * @param {Number} classThreshold minimum best-class probability
 * @param {Number} nmsThreshold NMS / IoU threshold, forwarded to `NMS`
 * @param {Number[]} inputShape model input shape, read here as [batch, channels, width, height].
 *   NOTE(review): the blob is produced with `modelHeight` rows of `modelWidth` columns while the
 *   tensor is given `inputShape` verbatim, so the layout is only consistent for square inputs —
 *   confirm before using a non-square model.
 * @param {object} cv loaded OpenCV.js module
 * @param {Number} [nClasses] number of classes to keep; defaults to `output.dims[2] - 5`.
 *   Detections whose class id is >= nClasses are dropped.
 * @returns {Promise<Array<{classId: Number, probability: Number, confidence: Number, bounding: Number[]}>>}
 *   kept boxes, grouped by ascending class id; `bounding` is [left, top, width, height]
 *   in the model's output coordinate space
 * @throws {Error} if the output row length (`dims[2]`) is not an integer >= 5
 */
export const detectImage = async (
    image,
    session,
    confidenceThreshold,
    classThreshold,
    nmsThreshold,
    inputShape,
    cv,
    nClasses
) => {
    const [modelWidth, modelHeight] = inputShape.slice(2)
    const input = preprocessImage(image, cv, modelWidth, modelHeight)

    let res
    try {
        // The tensor is built from input.data32F (presumably without copying), so the blob
        // must stay alive until session.run has settled — and be freed even if it rejects.
        const tensor = new Tensor('float32', input.data32F, inputShape) // to ort.Tensor
        res = await session.run({ images: tensor })
    } finally {
        input.delete()
    }

    const output0 = res.output
    // When nClasses is omitted, derive it from the row width (5 box fields + class scores).
    const numClasses = nClasses === undefined ? output0.dims[2] - BOX_FIELDS : nClasses
    const boxes = decodeDetections(output0, confidenceThreshold, classThreshold)
    return nmsPerClass(boxes, numClasses, nmsThreshold)
}
