Showing 3 changed files with 156 additions and 70 deletions.
@@ -1,89 +1,148 @@
-import { AutoTokenizer } from "@huggingface/transformers";
-import { InferenceSession, Tensor } from "onnxruntime-web";
+import { PreTrainedTokenizer } from "@huggingface/transformers";
+import { InferenceSession, Tensor, env } from "onnxruntime-web";

-export class LatexOCR {
-  constructor() {
-    this.model = null;
-    this.tokenizer = null;
-    this.maxDimensions = [800, 400];
-    this.minDimensions = [100, 32];
-  }
+import ortWasmUrl from "onnx-web/ort-wasm-simd-threaded.jsep.wasm?url";

-  async initialize({ modelPath = __MODEL_URL__, tokenizerPath = __TOKENIZER_URL__ } = {}) {
-    this.model = await InferenceSession.create(modelPath);
-    this.tokenizer = await AutoTokenizer.from_pretrained(tokenizerPath);
+export class LatexOCR {
+  /** @type {InferenceSession} */
+  #model = null;
+  /** @type {PreTrainedTokenizer} */
+  #tokenizer = null;
+  #args = {
+    max_dimensions: [800, 800],
+    min_dimensions: [32, 32],
+    temperature: 0.25,
+    bos_token: 0,
+    eos_token: 1,
+    pad_token: 2,
+    max_seq_len: 256
+  };
+
+  async initialize({
+    modelUrl = __MODEL_URL__,
+    tokenizerUrl = __TOKENIZER_URL__,
+    tokenizerConfigUrl = __TOKENIZER_CONFIG_URL__
+  } = {}) {
+    env.wasm.wasmBinary = await fetch(ortWasmUrl).then(res => res.arrayBuffer());
+
+    const [model, [tokenizer, config]] = await Promise.all([
+      fetch(modelUrl)
+        .then(res => res.arrayBuffer())
+        .then(buffer => {
+          return InferenceSession.create(buffer, {
+            executionProviders: ["webgpu"],
+          });
+        }),
+      Promise.all([
+        fetch(tokenizerUrl).then(res => res.json()),
+        fetch(tokenizerConfigUrl).then(res => res.json())
+      ])
+    ]);
+
+    this.#model = model;
+    this.#tokenizer = new PreTrainedTokenizer(tokenizer, config);
+    this.#args = { ...this.#args, ...config };
   }

-  async predict(imageElement) {
-    const processedImage = await this.preprocessImage(imageElement);
-    const inputTensor = new Tensor("float32", processedImage.data, [1, 3, ...processedImage.shape]);
-
-    const { output } = await this.model.run({ input: inputTensor });
-    return this.postProcess(output);
+  async predict(element, { numCandidates = 1 } = {}) {
+    const processed = await this.#preprocessImage(element);
+    const outputs = await this.#inference(processed, numCandidates);
+    return this.#postProcess(outputs, numCandidates);
   }

-  async preprocessImage(img) {
+  async #preprocessImage(element) {
     const canvas = document.createElement("canvas");
     const ctx = canvas.getContext("2d");
+    canvas.width = element.naturalWidth;
+    canvas.height = element.naturalHeight;
+    ctx.drawImage(element, 0, 0);

-    const resized = await this.minmaxSize(img);
-    canvas.width = resized.width;
-    canvas.height = resized.height;
-    ctx.drawImage(resized, 0, 0);
+    const imgData = padImage(minmaxSize(canvas, this.#args.min_dimensions, this.#args.max_dimensions));

-    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
-    return this.normalize(imageData);
-  }
+    // apply transformations
+    const tensor = new ImageData(imgData.width, imgData.height);
+    const mean = [0.485, 0.456, 0.406];
+    const std = [0.229, 0.224, 0.225];

-  async minmaxSize(img) {
-    const canvas = document.createElement("canvas");
-    const ctx = canvas.getContext("2d");
+    for (let i = 0; i < imgData.data.length; i += 4) {
+      tensor.data[i/4 * 3] = (imgData.data[i] / 255 - mean[0]) / std[0];
+      tensor.data[i/4 * 3 + 1] = (imgData.data[i + 1] / 255 - mean[1]) / std[1];
+      tensor.data[i/4 * 3 + 2] = (imgData.data[i + 2] / 255 - mean[2]) / std[2];
+    }

+    return {
+      data: Float32Array.from(tensor.data),
+      shape: [3, imgData.height, imgData.width]
+    };

-    let width = img.naturalWidth || img.width;
-    let height = img.naturalHeight || img.height;
+    function minmaxSize(canvas, min, max) {
+      const [[minW, minH], [maxW, maxH]] = [min, max];

-    const maxRatio = Math.max(width / this.maxDimensions[0], height / this.maxDimensions[1]);
-    if (maxRatio > 1) {
-      width /= maxRatio;
-      height /= maxRatio;
+      let ratio = Math.min(maxW / canvas.width, maxH / canvas.height);
+      if (ratio < 1) {
+        canvas.width *= ratio;
+        canvas.height *= ratio;
       }

+      if (canvas.width < minW || canvas.height < minH) {
+        const resizeCanvas = document.createElement("canvas");
+        resizeCanvas.width = Math.max(canvas.width, minW);
+        resizeCanvas.height = Math.max(canvas.height, minH);
+
+        const ctx = resizeCanvas.getContext("2d");
+        ctx.fillStyle = "white";
+        ctx.fillRect(0, 0, resizeCanvas.width, resizeCanvas.height);
+        ctx.drawImage(canvas, 0, 0);
+
+        return ctx.getImageData(0, 0, resizeCanvas.width, resizeCanvas.height);
+      }
+
+      return ctx.getImageData(0, 0, canvas.width, canvas.height);
+    }

-    canvas.width = Math.max(width, this.minDimensions[0]);
-    canvas.height = Math.max(height, this.minDimensions[1]);
-    ctx.fillStyle = "white";
-    ctx.fillRect(0, 0, canvas.width, canvas.height);
-    ctx.drawImage(img, 0, 0, width, height);
+    function padImage(imgData, padding = 20) {
+      const canvas = document.createElement("canvas");
+      canvas.width = imgData.width + padding * 2;
+      canvas.height = imgData.height + padding * 2;

-    return canvas;
-  }
+      const ctx = canvas.getContext("2d");
+      ctx.fillStyle = "white";
+      ctx.fillRect(0, 0, canvas.width, canvas.height);
+      ctx.putImageData(imgData, padding, padding);

-  normalize(imageData) {
-    const float32Data = new Float32Array(3 * imageData.width * imageData.height);
-    const mean = [0.5, 0.5, 0.5];
-    const std = [0.5, 0.5, 0.5];
+      return ctx.getImageData(0, 0, canvas.width, canvas.height);
+    }
+  }

-    for (let i = 0; i < imageData.data.length; i += 4) {
-      const r = (imageData.data[i] / 255 - mean[0]) / std[0];
-      const g = (imageData.data[i + 1] / 255 - mean[1]) / std[1];
-      const b = (imageData.data[i + 2] / 255 - mean[2]) / std[2];
+  async #inference(processed, numCandidates) {
+    const inputTensor = new Tensor("float32", processed.data, [1, ...processed.shape]);
+    const outputs = [];

-      float32Data[i / 4 * 3] = r;
-      float32Data[i / 4 * 3 + 1] = g;
-      float32Data[i / 4 * 3 + 2] = b;
+    for (let i = 0; i < numCandidates; i++) {
+      const { output } = await this.#model.run({ input: inputTensor });
+      outputs.push(output.data);
     }

-    return {
-      data: float32Data,
-      shape: [imageData.height, imageData.width]
-    };
+    return outputs;
   }

-  postProcess(outputTensor) {
-    const tokenIds = Array.from(outputTensor.data);
-    let decoded = this.tokenizer.decode(tokenIds, { skip_special_tokens: true });
+  #postProcess(outputs, numCandidates) {
+    const results = outputs.map(output => {
+      const tokens = Array.from(output)
+        .map(Math.round)
+        .filter(t => t !== this.#args.pad_token && t !== this.#args.eos_token);
+
+      const decoded = this.#tokenizer.decode(tokens, { skip_special_tokens: true })
+        .replace(/Ġ/g, ' ')
+        .replace(/\[(PAD|BOS|EOS)\]/g, '')
+        .replace(/\s+/g, ' ')
+        .trim();
+
+      return decoded.replace(/(\\[a-z]+)\s*{/g, "$1{")
+        .replace(/(\D)(\d)/g, "$1 $2")
+        .replace(/(\d)(\D)/g, "$1 $2");
+    });

-    return decoded
-      .replace(/Ġ/g, ' ')
-      .replace(/\[(PAD|BOS|EOS)\]/g, '');
+    return numCandidates === 1 ? results[0] : results;
   }
 }
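For orientation, here is a minimal usage sketch of the class as it reads after this change. The module path, the element selector, and the call sites are assumptions for illustration, not part of the commit:

// Hypothetical module path; the commit does not show the file name.
import { LatexOCR } from "./latex-ocr.js";

const ocr = new LatexOCR();
// With no arguments, initialize() falls back to the build-time
// __MODEL_URL__ / __TOKENIZER_URL__ / __TOKENIZER_CONFIG_URL__ defines.
await ocr.initialize();

// Assumes the page contains an <img> holding the formula to recognise.
const image = document.querySelector("#formula");
const latex = await ocr.predict(image);                            // one candidate -> string
const candidates = await ocr.predict(image, { numCandidates: 3 }); // several candidates -> array of strings
console.log(latex, candidates);

The second hunk below reworks the Vite config so those defines are no longer hard-coded GitHub Pages URLs but are read from the environment, with a dev-server proxy for GitHub-hosted assets.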
@@ -1,9 +1,36 @@
-import { defineConfig } from "vite";
+import { defineConfig, loadEnv } from "vite";
+import path from "node:path";

-export default defineConfig({
-  base: "/Eunectes/",
-  define: {
-    __MODEL_URL__: JSON.stringify("https://sieluna.github.io/Eunectes"),
-    __TOKENIZER_URL__: JSON.stringify("https://sieluna.github.io/Eunectes/packages/train/model/dataset/tokenizer.json")
+export default defineConfig(({ mode }) => {
+  const env = { ...process.env, ...loadEnv(mode, process.cwd(), "") };
+
+  if (mode === "development") {
+    ["MODEL_URL", "TOKENIZER_URL", "TOKENIZER_CONFIG_URL"].forEach(key => {
+      env[key] = env[key]?.replace("https://github.com", "/github");
+    });
+  }
+
+  return {
+    base: "/Eunectes/",
+    resolve: {
+      alias: {
+        "onnx-web": path.resolve(import.meta.dirname, "node_modules/onnxruntime-web/dist"),
+      }
+    },
+    server: {
+      proxy: {
+        "/github": {
+          target: "https://github.com",
+          changeOrigin: true,
+          followRedirects: true,
+          rewrite: (path) => path.replace(/^\/github/, "")
+        }
+      }
+    },
+    define: {
+      __MODEL_URL__: JSON.stringify(env.MODEL_URL),
+      __TOKENIZER_URL__: JSON.stringify(env.TOKENIZER_URL),
+      __TOKENIZER_CONFIG_URL__: JSON.stringify(env.TOKENIZER_CONFIG_URL)
     }
+  };
 });
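A sketch of how the development proxy is meant to behave, assuming the three URLs are supplied through the environment (for example a .env file picked up by loadEnv). The URL below is a placeholder, not a value from this commit:

// Hypothetical environment value; any https://github.com/... asset URL behaves the same way.
const MODEL_URL = "https://github.com/<owner>/<repo>/releases/download/<tag>/model.onnx";

// In development mode the config rewrites the origin to the local proxy prefix...
const devUrl = MODEL_URL.replace("https://github.com", "/github");

// ...so the browser requests a same-origin path, and the dev-server proxy strips the
// prefix again before forwarding to https://github.com (changeOrigin rewrites the Host
// header so GitHub accepts the forwarded request).
const forwarded = devUrl.replace(/^\/github/, "");
console.log(devUrl, forwarded);

In production builds the rewrite does not run, so the env values are injected verbatim through define and the app fetches the original URLs directly.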