import { Tensor, InferenceSession } from "onnxruntime-web";
import pako from "pako"
import untar from "js-untar"
import ndarray from "ndarray"
import ops from "ndarray-ops";

// Module-level cache used by createModelCpu, keyed by the caller-supplied model id,
// so each model is downloaded and instantiated at most once per page load.
let sessions = {}

/**
 * Downloads a gzip-compressed tar archive and returns the bytes of its first entry.
 *
 * @param {string} url - Location of the `.tar.gz` archive.
 * @returns {Promise<ArrayBuffer>} Raw contents of the first file in the archive
 *   (createModelCpu passes this to InferenceSession.create as the model).
 * @throws {Error} If the request returns a non-2xx status or the archive has no entries;
 *   network, decompression and untar failures propagate as-is.
 */
export async function downloadModel(url) {
    const response = await fetch(url);
    // fetch() only rejects on network failure. Without this check an HTTP error page
    // would be handed to the gzip decoder and fail with a confusing decompression error.
    if (!response.ok) {
        throw new Error(`Failed to download model from ${url}: ${response.status} ${response.statusText}`);
    }

    const compressed = await response.arrayBuffer();
    const tarBytes = pako.inflate(compressed);
    const entries = await untar(tarBytes.buffer);
    if (entries.length === 0) {
        throw new Error(`Model archive at ${url} is empty`);
    }
    return entries[0].buffer;
}

/**
 * Returns the cached ONNX Runtime session for `id`, creating it on first use.
 *
 * On a cache miss the model archive at `url` is downloaded, a session is created from
 * its bytes, and one warmup inference is run before the session is handed out.
 * The in-flight creation promise is what gets cached, so concurrent calls with the
 * same `id` share a single download instead of racing; if creation fails the entry
 * is evicted so a later call can retry.
 *
 * @param {string} id - Cache key for the session.
 * @param {string} url - Location of the `.tar.gz` model archive (unused on a cache hit).
 * @param {number[]} [warmupDims=[1, 3, 640, 640]] - Input shape for the warmup run.
 * @returns {Promise<InferenceSession>} The ready-to-use session.
 */
export async function createModelCpu(id, url, warmupDims = [1, 3, 640, 640]) {
    if (Object.prototype.hasOwnProperty.call(sessions, id)) {
        return sessions[id];
    }

    const pending = (async () => {
        const model = await downloadModel(url);
        const sess = await InferenceSession.create(model);
        await warmupModel(sess, warmupDims);
        return sess;
    })();

    // Cache synchronously, before the first await, so overlapping callers reuse it.
    sessions[id] = pending;
    try {
        return await pending;
    } catch (e) {
        // Do not keep a rejected promise around; allow a retry on the next call.
        if (sessions[id] === pending) {
            delete sessions[id];
        }
        throw e;
    }
}

/**
 * Runs one inference on random input as a warmup query (presumably so later
 * calls avoid first-run initialization cost).
 *
 * Any failure is logged and swallowed: warmup is best-effort and must not make
 * the session unusable.
 *
 * @param {InferenceSession} model - Session to warm up; only its first input is fed.
 * @param {number[]} dims - Shape of the random float32 input tensor; `[]` means a scalar.
 * @returns {Promise<void>}
 */
export async function warmupModel(model, dims) {
    // Seed with 1 so an empty shape (scalar) yields one element instead of throwing.
    const size = dims.reduce((a, b) => a * b, 1);
    // Random values in roughly [-1.0, 1.0).
    const values = Float32Array.from({ length: size }, () => Math.random() * 2.0 - 1.0);
    const warmupTensor = new Tensor('float32', values, dims);

    try {
        const feeds = {};
        feeds[model.inputNames[0]] = warmupTensor;
        await model.run(feeds);
    } catch (e) {
        console.error(e);
    }
}

/**
 * Runs one inference pass on an RGBA image.
 *
 * The interleaved RGBA pixels are split into a planar float32 tensor of shape
 * [1, 3, height, width] (NCHW); the alpha channel is dropped. Pixel values are
 * passed through unscaled. NOTE(review): confirm the model expects 0-255 input
 * rather than normalized 0-1 values.
 *
 * @param {InferenceSession} model - Session whose first input receives the image.
 * @param {{data: ArrayLike<number>, width: number, height: number}} imageData -
 *   Row-major RGBA pixels, e.g. a canvas ImageData.
 * @returns {Promise<[Tensor, number]>} The session's first output and the
 *   wall-clock time of `model.run` in milliseconds.
 * @throws {Error} If inference fails; the underlying error is attached as `cause`.
 */
export async function runModel(model, imageData) {
    const { data, width, height } = imageData;

    // ImageData stores pixels row by row, so the leading axis is height, then width.
    const rgba = ndarray(new Float32Array(data), [height, width, 4]);
    const planar = ndarray(new Float32Array(width * height * 3), [1, 3, height, width]);

    // Copy R, G and B into their own planes; channel 3 (alpha) is skipped.
    for (let channel = 0; channel < 3; channel++) {
        ops.assign(planar.pick(0, channel, null, null), rgba.pick(null, null, channel));
    }

    // planar.data is a freshly allocated contiguous buffer, so it can back the tensor directly.
    const tensor = new Tensor("float32", planar.data, [1, 3, height, width]);

    const start = Date.now();
    try {
        const feeds = {};
        feeds[model.inputNames[0]] = tensor;
        const outputData = await model.run(feeds);
        const inferenceTime = Date.now() - start;
        const output = outputData[model.outputNames[0]];
        console.log(output, inferenceTime);
        return [output, inferenceTime];
    } catch (e) {
        console.error(e);
        throw new Error("Model inference failed", { cause: e });
    }
}