Fix attention accumulator un-initialized logic, add comment on why onnx webgpu won't work for now
rcurrie committed Dec 19, 2024
1 parent 5447749 commit 76fb16d
Showing 2 changed files with 10 additions and 3 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -38,6 +38,10 @@ make serve

Processing a test sample with 2638 cells took 67 seconds in the browser vs. 34 seconds in python on the same machine.

# Leveraging a GPU

ONNX Runtime Web does have GPU support, but it doesn't cover all operators yet; specifically, TopK is not [supported](https://github.com/microsoft/onnxruntime/blob/main/js/web/docs/webgpu-operators.md).

# References

[Open Neural Network Exchange (ONNX)](https://onnx.ai/)
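As a sketch of how the README's GPU note could be acted on once operator coverage improves: ONNX Runtime Web lets a session request the `webgpu` execution provider with a `wasm` fallback. The helper below is hypothetical (not part of this commit); it only builds a provider list based on whether `navigator.gpu` exists.

```javascript
// Hypothetical helper, not part of this commit: choose ONNX Runtime Web
// execution providers based on WebGPU availability.
function pickExecutionProviders(nav) {
  // navigator.gpu is only defined in WebGPU-capable browsers.
  return nav && nav.gpu ? ["webgpu", "wasm"] : ["wasm"];
}

// Usage sketch (modelBytes is assumed to hold the ONNX model):
// const session = await ort.InferenceSession.create(modelBytes, {
//   executionProviders: pickExecutionProviders(self.navigator),
// });
```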
9 changes: 6 additions & 3 deletions worker.js
@@ -1,5 +1,4 @@
 self.importScripts(
-  "https://cdnjs.cloudflare.com/ajax/libs/onnxruntime-web/1.20.1/ort.min.js",
   "https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js",
   "https://cdn.jsdelivr.net/npm/[email protected]/dist/iife/h5wasm.min.js",
   "https://cdn.jsdelivr.net/npm/[email protected]/lib/umap-js.min.js"
@@ -22,7 +21,7 @@ self.addEventListener("message", async function (event) {
       genes: self.model.genes,
     });
   } else if (type === "resetAttentionAccumulator") {
-    attentionAccumulator = new Float32Array(self.model.genes.length);
+    attentionAccumulator = null;
     self.postMessage({ type: "attentionAccumulatorReset" });
   }
 });
@@ -251,6 +250,10 @@ async function predict(event) {
 
     encodings.push(output.encoding.cpuData);
 
+    if (!attentionAccumulator) {
+      attentionAccumulator = new Float32Array(genes.length);
+    }
+
     for (let i = 0; i < attentionAccumulator.length; i++) {
       attentionAccumulator[i] += output.attention.cpuData[i];
     }
@@ -300,4 +303,4 @@ async function predict(event) {
     FS.unmount("/work");
     self.postMessage({ type: "error", error: error.message });
   }
-};
+}
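The worker change boils down to a lazy-initialization pattern: the accumulator starts as `null`, gets sized from the first attention vector that arrives, and a reset simply nulls it again, so it can never be allocated from a model that isn't loaded yet. A standalone sketch of that pattern (names are illustrative, not the worker's API):

```javascript
// Illustrative sketch of the commit's lazy-init pattern; not the worker's API.
let attentionAccumulator = null;

function accumulateAttention(attention) {
  // Size the accumulator on first use, from the incoming vector itself.
  if (!attentionAccumulator) {
    attentionAccumulator = new Float32Array(attention.length);
  }
  for (let i = 0; i < attentionAccumulator.length; i++) {
    attentionAccumulator[i] += attention[i];
  }
  return attentionAccumulator;
}

function resetAttentionAccumulator() {
  // Resetting just drops the buffer; the next accumulate re-allocates it.
  attentionAccumulator = null;
}
```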
