Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tensor inheritance #451

Merged
merged 5 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ function validateInputs(session, inputs) {
async function sessionRun(session, inputs) {
const checkedInputs = validateInputs(session, inputs);
try {
// @ts-ignore
let output = await session.run(checkedInputs);
output = replaceTensors(output);
return output;
Expand Down Expand Up @@ -292,6 +293,7 @@ function prepareAttentionMask(self, tokens) {
if (is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id) {
let data = BigInt64Array.from(
// Note: != so that int matches bigint
// @ts-ignore
tokens.data.map(x => x != pad_token_id)
)
return new Tensor('int64', data, tokens.dims)
Expand Down Expand Up @@ -704,9 +706,10 @@ export class PreTrainedModel extends Callable {
* @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry
*/
async dispose() {
let promises = [];
const promises = [];
for (let key of Object.keys(this)) {
let item = this[key];
const item = this[key];
// @ts-ignore
if (item instanceof InferenceSession) {
promises.push(item.handler.dispose())
}
Expand Down
16 changes: 9 additions & 7 deletions src/utils/generation.js
Original file line number Diff line number Diff line change
Expand Up @@ -261,32 +261,34 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
return logits;
}

const logitsData = /** @type {Float32Array} */(logits.data);

// timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
const seq = input_ids.slice(this.begin_index);
const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin;

if (last_was_timestamp) {
if (penultimate_was_timestamp) { // has to be non-timestamp
logits.data.subarray(this.timestamp_begin).fill(-Infinity);
logitsData.subarray(this.timestamp_begin).fill(-Infinity);
} else { // cannot be normal text tokens
logits.data.subarray(0, this.eos_token_id).fill(-Infinity);
logitsData.subarray(0, this.eos_token_id).fill(-Infinity);
}
}

// apply the `max_initial_timestamp` option
if (input_ids.length === this.begin_index && this.max_initial_timestamp_index !== null) {
const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
logits.data.subarray(last_allowed + 1).fill(-Infinity);
logitsData.subarray(last_allowed + 1).fill(-Infinity);
}

// if sum of probability over timestamps is above any other token, sample timestamp
const logprobs = log_softmax(logits.data);
const logprobs = log_softmax(logitsData);
const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];

if (timestamp_logprob > max_text_token_logprob) {
logits.data.subarray(0, this.timestamp_begin).fill(-Infinity);
logitsData.subarray(0, this.timestamp_begin).fill(-Infinity);
}

return logits;
Expand Down Expand Up @@ -697,12 +699,12 @@ export class Sampler extends Callable {
* Returns the specified logits as an array, with temperature applied.
* @param {Tensor} logits
* @param {number} index
* @returns {Array}
* @returns {Float32Array}
*/
getLogits(logits, index) {
let vocabSize = logits.dims.at(-1);

let logs = logits.data;
let logs = /** @type {Float32Array} */(logits.data);

if (index === -1) {
logs = logs.slice(-vocabSize);
Expand Down
15 changes: 13 additions & 2 deletions src/utils/image.js
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ export class RawImage {

/**
* Create a new `RawImage` object.
* @param {Uint8ClampedArray} data The pixel data.
* @param {Uint8ClampedArray|Uint8Array} data The pixel data.
* @param {number} width The width of the image.
* @param {number} height The height of the image.
* @param {1|2|3|4} channels The number of channels.
Expand Down Expand Up @@ -173,7 +173,18 @@ export class RawImage {
} else {
throw new Error(`Unsupported channel format: ${channel_format}`);
}
return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]);
if (!(tensor.data instanceof Uint8ClampedArray || tensor.data instanceof Uint8Array)) {
throw new Error(`Unsupported tensor type: ${tensor.type}`);
}
switch (tensor.dims[2]) {
case 1:
case 2:
case 3:
case 4:
return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]);
default:
throw new Error(`Unsupported number of channels: ${tensor.dims[2]}`);
}
}

/**
Expand Down
25 changes: 13 additions & 12 deletions src/utils/maths.js
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ export function transpose_data(array, dims, axes) {

/**
* Compute the softmax of an array of numbers.
*
* @param {number[]} arr The array of numbers to compute the softmax of.
* @returns {number[]} The softmax array.
* @template {TypedArray|number[]} T
* @param {T} arr The array of numbers to compute the softmax of.
* @returns {T} The softmax array.
*/
export function softmax(arr) {
// Compute the maximum value in the array
Expand All @@ -142,18 +142,20 @@ export function softmax(arr) {
const exps = arr.map(x => Math.exp(x - maxVal));

// Compute the sum of the exponentials
// @ts-ignore
const sumExps = exps.reduce((acc, val) => acc + val, 0);

// Compute the softmax values
const softmaxArr = exps.map(x => x / sumExps);

return softmaxArr;
return /** @type {T} */(softmaxArr);
}

/**
* Calculates the logarithm of the softmax function for the input array.
* @param {number[]} arr The input array to calculate the log_softmax function for.
* @returns {any} The resulting log_softmax array.
* @template {TypedArray|number[]} T
* @param {T} arr The input array to calculate the log_softmax function for.
* @returns {T} The resulting log_softmax array.
xenova marked this conversation as resolved.
Show resolved Hide resolved
*/
export function log_softmax(arr) {
// Compute the softmax values
Expand All @@ -162,7 +164,7 @@ export function log_softmax(arr) {
// Apply log formula to each element
const logSoftmaxArr = softmaxArr.map(x => Math.log(x));

return logSoftmaxArr;
return /** @type {T} */(logSoftmaxArr);
}

/**
Expand All @@ -178,8 +180,7 @@ export function dot(arr1, arr2) {

/**
* Get the top k items from an iterable, sorted by descending order
*
* @param {Array} items The items to be sorted
* @param {any[]|TypedArray} items The items to be sorted
* @param {number} [top_k=0] The number of top items to return (default: 0 = return all)
* @returns {Array} The top k items, sorted by descending order
*/
Expand Down Expand Up @@ -252,8 +253,8 @@ export function min(arr) {

/**
* Returns the value and index of the maximum element in an array.
* @param {number[]|TypedArray} arr array of numbers.
* @returns {number[]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
* @param {number[]|AnyTypedArray} arr array of numbers.
* @returns {[number, number]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
* @throws {Error} If array is empty.
*/
export function max(arr) {
Expand All @@ -266,7 +267,7 @@ export function max(arr) {
indexOfMax = i;
}
}
return [max, indexOfMax];
return [Number(max), indexOfMax];
}

function isPowerOfTwo(number) {
Expand Down
Loading
Loading