import { Tensor, InferenceSession } from 'onnxruntime-web'
import pako from 'pako'
import untar from 'js-untar'

// Cache of InferenceSession objects keyed by model id (filled by createModelCpu).
// A prototype-less object is used so that any id string, including
// '__proto__' or 'toString', is stored and looked up as a plain own key.
const sessions = Object.create(null)

/**
 * Return the session cached under a model id.
 *
 * @param {string} model - Id the session was stored under.
 * @returns {InferenceSession|null} The cached session, or null when nothing is cached for `model`.
 */
export function getModel(model) {
    const cached = Object.hasOwn(sessions, model)
    return cached ? sessions[model] : null
}

/**
 * Report whether a session is present in the cache for a model id.
 *
 * Only own entries count; inherited properties such as `toString` do not.
 *
 * @param {string} model - Id to look up.
 * @returns {boolean} True when a session is cached under `model`.
 */
export function isModelReady(model) {
    const descriptor = Object.getOwnPropertyDescriptor(sessions, model)
    return descriptor !== undefined
}

/**
 * Download a compressed tar archive and return the bytes of its first entry.
 *
 * The response body is decompressed with `pako.inflate` and unpacked with
 * `js-untar`. Only the first file in the archive is returned; its contents
 * are presumably the model file, but the archive layout is not checked here.
 *
 * @param {string} url - Location of the compressed archive.
 * @returns {Promise<ArrayBuffer>} Contents of the first file in the archive.
 * @throws {Error} 'Download model failed' when any step fails: network,
 *   non-2xx HTTP status, decompression, unpacking, or an empty archive. The
 *   original error is attached as `cause`.
 */
export async function downloadModel(url) {
    try {
        const response = await fetch(url)
        if (!response.ok) {
            throw new Error(`Unexpected HTTP status ${response.status}`)
        }

        const inflated = pako.inflate(await response.arrayBuffer())
        // Hand untar exactly the inflated bytes. If the Uint8Array is a window
        // onto a larger buffer, copy just that window; otherwise reuse the
        // buffer and avoid a copy.
        const coversWholeBuffer = inflated.byteOffset === 0
            && inflated.byteLength === inflated.buffer.byteLength
        const tarBuffer = coversWholeBuffer
            ? inflated.buffer
            : inflated.buffer.slice(inflated.byteOffset, inflated.byteOffset + inflated.byteLength)

        const files = await untar(tarBuffer)
        if (files.length === 0) {
            throw new Error('Archive contains no files')
        }
        return files[0].buffer
    } catch (error) {
        throw new Error('Download model failed', { cause: error })
    }
}

/**
 * Run a single inference on random input as a warm-up query.
 *
 * @param {InferenceSession} model - Session to run. Only its first input
 *   (`model.inputNames[0]`) is fed.
 * @param {number[]} dims - Shape of that input tensor. An empty array is
 *   treated as a scalar with a single element.
 * @returns {Promise<object|null>} The result of `model.run`, or null if the
 *   run failed.
 */
export async function warmupModel(model, dims) {
    // Element count is the product of all dimensions. Seeding reduce() with 1
    // lets an empty shape (a scalar) work instead of throwing.
    const size = dims.reduce((count, dim) => count * dim, 1)
    // Random values in [-1.0, 1.0), rounded to float32.
    const data = Float32Array.from({ length: size }, () => Math.random() * 2.0 - 1.0)
    const warmupTensor = new Tensor('float32', data, dims)

    try {
        return await model.run({ [model.inputNames[0]]: warmupTensor })
    } catch {
        // Failure is reported as null rather than thrown; createModelCpu
        // checks for null and turns it into a 'Warmup error'.
        return null
    }
}

/**
 * Create a CPU inference session for a model archive, or reuse a cached one.
 *
 * If a session is already cached under `id`, it is returned immediately and
 * nothing is downloaded. Otherwise the following steps run:
 * 1. The archive is downloaded.
 * 2. A session is created with the 'cpu' execution provider.
 * 3. A warm-up run is performed with `shape`.
 *
 * The new session is cached only after the warm-up succeeds. A failed
 * warm-up therefore never leaves a session behind that later calls would
 * treat as ready.
 *
 * @param {string} id - Cache key for the session.
 * @param {string} url - Location of the compressed model archive.
 * @param {number[]} shape - Input shape used for the warm-up run.
 * @returns {Promise<InferenceSession>} The ready session.
 * @throws {Error} The download or session-creation error, or
 *   'Warmup error' when the warm-up run fails.
 */
export async function createModelCpu(id, url, shape) {
    if (Object.hasOwn(sessions, id)) {
        return sessions[id]
    }

    const modelBuffer = await downloadModel(url)
    const session = await InferenceSession.create(modelBuffer, { executionProviders: ['cpu'] })

    const result = await warmupModel(session, shape)
    if (result === null) {
        throw new Error('Warmup error')
    }

    sessions[id] = session
    return session
}
