Skip to content

Commit

Permalink
refactor: improve js latex ocr
Browse files Browse the repository at this point in the history
  • Loading branch information
Sieluna committed Feb 8, 2025
1 parent 8c07878 commit 48b826e
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 70 deletions.
185 changes: 122 additions & 63 deletions src/latex-ocr.js
Original file line number Diff line number Diff line change
@@ -1,89 +1,148 @@
import { AutoTokenizer } from "@huggingface/transformers";
import { InferenceSession, Tensor } from "onnxruntime-web";
import { PreTrainedTokenizer } from "@huggingface/transformers";
import { InferenceSession, Tensor, env } from "onnxruntime-web";

export class LatexOCR {
constructor() {
this.model = null;
this.tokenizer = null;
this.maxDimensions = [800, 400];
this.minDimensions = [100, 32];
}
import ortWasmUrl from "onnx-web/ort-wasm-simd-threaded.jsep.wasm?url";

async initialize({ modelPath = __MODEL_URL__, tokenizerPath = __TOKENIZER_URL__ } = {}) {
this.model = await InferenceSession.create(modelPath);
this.tokenizer = await AutoTokenizer.from_pretrained(tokenizerPath);
export class LatexOCR {
/** @type {InferenceSession} */
#model = null;
/** @type {PreTrainedTokenizer} */
#tokenizer = null;
#args = {
max_dimensions: [800, 800],
min_dimensions: [32, 32],
temperature: 0.25,
bos_token: 0,
eos_token: 1,
pad_token: 2,
max_seq_len: 256
};

async initialize({
modelUrl = __MODEL_URL__,
tokenizerUrl = __TOKENIZER_URL__,
tokenizerConfigUrl = __TOKENIZER_CONFIG_URL__
} = {}) {
env.wasm.wasmBinary = await fetch(ortWasmUrl).then(res => res.arrayBuffer());

const [model, [tokenizer, config]] = await Promise.all([
fetch(modelUrl)
.then(res => res.arrayBuffer())
.then(buffer => {
return InferenceSession.create(buffer, {
executionProviders: ["webgpu"],
});
}),
Promise.all([
fetch(tokenizerUrl).then(res => res.json()),
fetch(tokenizerConfigUrl).then(res => res.json())
])
]);

this.#model = model;
this.#tokenizer = new PreTrainedTokenizer(tokenizer, config);
this.#args = { ...this.#args, ...config };
}

async predict(imageElement) {
const processedImage = await this.preprocessImage(imageElement);
const inputTensor = new Tensor("float32", processedImage.data, [1, 3, ...processedImage.shape]);

const { output } = await this.model.run({ input: inputTensor });
return this.postProcess(output);
async predict(element, { numCandidates = 1 } = {}) {
const processed = await this.#preprocessImage(element);
const outputs = await this.#inference(processed, numCandidates);
return this.#postProcess(outputs, numCandidates);
}

async preprocessImage(img) {
async #preprocessImage(element) {
const canvas = document.createElement("canvas");
const ctx = canvas.getContext("2d");
canvas.width = element.naturalWidth;
canvas.height = element.naturalHeight;
ctx.drawImage(element, 0, 0);

const resized = await this.minmaxSize(img);
canvas.width = resized.width;
canvas.height = resized.height;
ctx.drawImage(resized, 0, 0);
const imgData = padImage(minmaxSize(canvas, this.#args.min_dimensions, this.#args.max_dimensions));

const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
return this.normalize(imageData);
}
// apply transformations
const tensor = new ImageData(imgData.width, imgData.height);
const mean = [0.485, 0.456, 0.406];
const std = [0.229, 0.224, 0.225];

async minmaxSize(img) {
const canvas = document.createElement("canvas");
const ctx = canvas.getContext("2d");
for (let i = 0; i < imgData.data.length; i += 4) {
tensor.data[i/4 * 3] = (imgData.data[i] / 255 - mean[0]) / std[0];
tensor.data[i/4 * 3 + 1] = (imgData.data[i + 1] / 255 - mean[1]) / std[1];
tensor.data[i/4 * 3 + 2] = (imgData.data[i + 2] / 255 - mean[2]) / std[2];
}

return {
data: Float32Array.from(tensor.data),
shape: [3, imgData.height, imgData.width]
};

let width = img.naturalWidth || img.width;
let height = img.naturalHeight || img.height;
function minmaxSize(canvas, min, max) {
const [[minW, minH], [maxW, maxH]] = [min, max];

const maxRatio = Math.max(width / this.maxDimensions[0], height / this.maxDimensions[1]);
if (maxRatio > 1) {
width /= maxRatio;
height /= maxRatio;
let ratio = Math.min(maxW / canvas.width, maxH / canvas.height);
if (ratio < 1) {
canvas.width *= ratio;
canvas.height *= ratio;
}

if (canvas.width < minW || canvas.height < minH) {
const resizeCanvas = document.createElement("canvas");
resizeCanvas.width = Math.max(canvas.width, minW);
resizeCanvas.height = Math.max(canvas.height, minH);

const ctx = resizeCanvas.getContext("2d");
ctx.fillStyle = "white";
ctx.fillRect(0, 0, resizeCanvas.width, resizeCanvas.height);
ctx.drawImage(canvas, 0, 0);

return ctx.getImageData(0, 0, resizeCanvas.width, resizeCanvas.height);
}

return ctx.getImageData(0, 0, canvas.width, canvas.height);
}

canvas.width = Math.max(width, this.minDimensions[0]);
canvas.height = Math.max(height, this.minDimensions[1]);
ctx.fillStyle = "white";
ctx.fillRect(0, 0, canvas.width, canvas.height);
ctx.drawImage(img, 0, 0, width, height);
function padImage(imgData, padding = 20) {
const canvas = document.createElement("canvas");
canvas.width = imgData.width + padding * 2;
canvas.height = imgData.height + padding * 2;

return canvas;
}
const ctx = canvas.getContext("2d");
ctx.fillStyle = "white";
ctx.fillRect(0, 0, canvas.width, canvas.height);
ctx.putImageData(imgData, padding, padding);

normalize(imageData) {
const float32Data = new Float32Array(3 * imageData.width * imageData.height);
const mean = [0.5, 0.5, 0.5];
const std = [0.5, 0.5, 0.5];
return ctx.getImageData(0, 0, canvas.width, canvas.height);
}
}

for (let i = 0; i < imageData.data.length; i += 4) {
const r = (imageData.data[i] / 255 - mean[0]) / std[0];
const g = (imageData.data[i + 1] / 255 - mean[1]) / std[1];
const b = (imageData.data[i + 2] / 255 - mean[2]) / std[2];
async #inference(processed, numCandidates) {
const inputTensor = new Tensor("float32", processed.data, [1, ...processed.shape]);
const outputs = [];

float32Data[i / 4 * 3] = r;
float32Data[i / 4 * 3 + 1] = g;
float32Data[i / 4 * 3 + 2] = b;
for (let i = 0; i < numCandidates; i++) {
const { output } = await this.#model.run({ input: inputTensor });
outputs.push(output.data);
}

return {
data: float32Data,
shape: [imageData.height, imageData.width]
};
return outputs;
}

postProcess(outputTensor) {
const tokenIds = Array.from(outputTensor.data);
let decoded = this.tokenizer.decode(tokenIds, { skip_special_tokens: true });
#postProcess(outputs, numCandidates) {
const results = outputs.map(output => {
const tokens = Array.from(output)
.map(Math.round)
.filter(t => t !== this.#args.pad_token && t !== this.#args.eos_token);

const decoded = this.#tokenizer.decode(tokens, { skip_special_tokens: true })
.replace(/Ġ/g, ' ')
.replace(/\[(PAD|BOS|EOS)\]/g, '')
.replace(/\s+/g, ' ')
.trim();

return decoded.replace(/(\\[a-z]+)\s*{/g, "$1{")
.replace(/(\D)(\d)/g, "$1 $2")
.replace(/(\d)(\D)/g, "$1 $2");
});

return decoded
.replace(/Ġ/g, ' ')
.replace(/\[(PAD|BOS|EOS)\]/g, '');
return numCandidates === 1 ? results[0] : results;
}
}
2 changes: 1 addition & 1 deletion src/uploader.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export function setupUploader({ element, onUpdate } = {}) {
file = assets[0];
if (file) {
try {
const image = await loadImage(e.target.files[0]);
const image = await loadImage(file);
const latex = await ocr.predict(image);
onUpdate(latex);
} catch (error) {
Expand Down
39 changes: 33 additions & 6 deletions vite.config.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,36 @@
import { defineConfig } from "vite";
import { defineConfig, loadEnv } from "vite";
import path from "node:path";

export default defineConfig({
base: "/Eunectes/",
define: {
__MODEL_URL__: JSON.stringify("https://sieluna.github.io/Eunectes"),
__TOKENIZER_URL__: JSON.stringify("https://sieluna.github.io/Eunectes/packages/train/model/dataset/tokenizer.json")
/**
 * Vite configuration factory.
 *
 * Merges process.env with the variables loaded for the current mode and
 * exposes MODEL_URL, TOKENIZER_URL and TOKENIZER_CONFIG_URL to the client
 * code as compile-time constants. In development mode those URLs have
 * their "https://github.com" prefix swapped for "/github", which the dev
 * server proxies back to https://github.com.
 */
export default defineConfig(({ mode }) => {
    const env = { ...process.env, ...loadEnv(mode, process.cwd(), "") };

    if (mode === "development") {
        // Point the asset URLs at the local "/github" proxy route.
        for (const key of ["MODEL_URL", "TOKENIZER_URL", "TOKENIZER_CONFIG_URL"]) {
            env[key] = env[key]?.replace("https://github.com", "/github");
        }
    }

    // Directory that the "onnx-web" import alias resolves to.
    const onnxDist = path.resolve(import.meta.dirname, "node_modules/onnxruntime-web/dist");

    // Dev-server proxy entry: strips the "/github" prefix before forwarding.
    const githubProxy = {
        target: "https://github.com",
        changeOrigin: true,
        followRedirects: true,
        rewrite: (requestPath) => requestPath.replace(/^\/github/, "")
    };

    return {
        base: "/Eunectes/",
        resolve: {
            alias: {
                "onnx-web": onnxDist,
            }
        },
        server: {
            proxy: {
                "/github": githubProxy
            }
        },
        define: {
            __MODEL_URL__: JSON.stringify(env.MODEL_URL),
            __TOKENIZER_URL__: JSON.stringify(env.TOKENIZER_URL),
            __TOKENIZER_CONFIG_URL__: JSON.stringify(env.TOKENIZER_CONFIG_URL)
        }
    };
});

0 comments on commit 48b826e

Please sign in to comment.