Skip to content

Commit

Permalink
Show error when not enough points to calculate umap
Browse files Browse the repository at this point in the history
  • Loading branch information
rcurrie committed Dec 19, 2024
1 parent 76fb16d commit af70c5c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 35 deletions.
3 changes: 2 additions & 1 deletion index.html
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ <h2 class="mb-0" style="font-weight: 300">SIMS Web</h2>
<div class="custom-file">
<input type="file" class="custom-file-input" name="file_input" id="file_input"
accept=".h5ad">
<label id="file_input_label" class="custom-file-label" for="image">Select an AnnData/Scanpy
<label id="file_input_label" class="custom-file-label" for="file_input">Select an
AnnData/Scanpy
(.h5ad) file</label>
</div>
</div>
Expand Down
6 changes: 2 additions & 4 deletions main.js
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ async function predict(worker, modelID, h5File, cellRangePercent) {
)} minutes`;
resolve([cellNames, classes, predictions, coordinates]);
} else if (type === "error") {
reject(error);
document.getElementById("message").textContent = error.message;
// reject(error);
}
};

Expand Down Expand Up @@ -310,9 +311,6 @@ async function main() {
worker = new Worker("worker.js");
}

// Reset the attention accumulator
worker.postMessage({ type: "resetAttentionAccumulator" });

const [cellNames, classes, predictions, coordinates] = await predict(
worker,
modelID,
Expand Down
59 changes: 29 additions & 30 deletions worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,21 @@ self.importScripts(
"https://cdn.jsdelivr.net/npm/[email protected]/lib/umap-js.min.js"
);

// Global variables for accumulators
let attentionAccumulator = null;
// Global variables
self.model = null;
self.attentionAccumulator = null;

self.addEventListener("message", async function (event) {
const { type } = event.data;

if (type === "startPrediction") {
predict(event);
// Handle prediction as shown above
// ...
} else if (type === "getAttentionAccumulator") {
self.postMessage({
type: "attentionAccumulator",
attentionAccumulator: attentionAccumulator.buffer,
attentionAccumulator: self.attentionAccumulator.buffer,
genes: self.model.genes,
});
} else if (type === "resetAttentionAccumulator") {
attentionAccumulator = null;
self.postMessage({ type: "attentionAccumulatorReset" });
}
});

Expand Down Expand Up @@ -97,19 +93,16 @@ async function instantiateModel(id) {

if (location.hostname === "localhost") {
ort.env.debug = true;
ort.env.logLevel = "verbose";
ort.env.trace = true;
options["logSeverityLevel"] = 0;
options["logVerbosityLevel"] = 0;
// ort.env.logLevel = "verbose";
// ort.env.trace = true;
// options["logSeverityLevel"] = 0;
// options["logVerbosityLevel"] = 0;
}

// Create the InferenceSession with the model ArrayBuffer
const session = await ort.InferenceSession.create(modelArray.buffer, options);
console.log("Model Output names", session.outputNames);

// Initialize attention accumulator
attentionAccumulator = new Float32Array(genes.length);

return { id, session, genes, classes };
}

Expand Down Expand Up @@ -164,6 +157,9 @@ async function predict(event) {
self.model = await instantiateModel(event.data.modelID);
}

// Reset attention accumulator
self.attentionAccumulator = new Float32Array(self.model.genes.length);

self.postMessage({ type: "status", message: "Loading file" });
if (!FS.analyzePath("/work").exists) {
FS.mkdir("/work");
Expand Down Expand Up @@ -250,12 +246,8 @@ async function predict(event) {

encodings.push(output.encoding.cpuData);

if (!attentionAccumulator) {
attentionAccumulator = new Float32Array(genes.length);
}

for (let i = 0; i < attentionAccumulator.length; i++) {
attentionAccumulator[i] += output.attention.cpuData[i];
for (let i = 0; i < self.attentionAccumulator.length; i++) {
self.attentionAccumulator[i] += output.attention.cpuData[i];
}

// Post progress update
Expand All @@ -273,15 +265,22 @@ async function predict(event) {
nEpochs: 400,
nNeighbors: 15,
});
const coordinates = await umap.fitAsync(encodings, (epochNumber) => {
// check progress and give user feedback, or return `false` to stop
self.postMessage({
type: "progress",
message: "Computing coordinates...",
countFinished: epochNumber,
totalToProcess: umap.getNEpochs(),

let coordinates = null;
try {
coordinates = await umap.fitAsync(encodings, (epochNumber) => {
// check progress and give user feedback, or return `false` to stop
self.postMessage({
type: "progress",
message: "Computing coordinates...",
countFinished: epochNumber,
totalToProcess: umap.getNEpochs(),
});
});
});
} catch (error) {
self.postMessage({ type: "error", error });
throw error;
}

annData.close();
FS.unmount("/work");
Expand All @@ -301,6 +300,6 @@ async function predict(event) {
});
} catch (error) {
FS.unmount("/work");
self.postMessage({ type: "error", error: error.message });
self.postMessage({ type: "error", error: error });
}
}

0 comments on commit af70c5c

Please sign in to comment.