feat: (almost) working nomic embed architecture
zanussbaum committed Dec 20, 2024
1 parent d5d5796 commit 9de39ad
Showing 10 changed files with 379 additions and 30 deletions.
11 changes: 6 additions & 5 deletions src/autograd/function.ts
@@ -155,21 +155,22 @@ export abstract class BinaryOp extends AutogradFunction {
pass.setPipeline(this.pipeline);
pass.setBindGroup(0, bindGroup);

+ // TODO: set these as overrides in the layers/ops level since the kernels are different
const WORKGROUP_SIZE = 16;
const TILE_SIZE = 8;
- const workgropuA = Math.ceil(a.shape[0] / (TILE_SIZE * WORKGROUP_SIZE));
- const workgropuB = Math.ceil(b.shape[1] / (TILE_SIZE * WORKGROUP_SIZE));
+ const workgroupA = Math.ceil(a.shape[0] / (TILE_SIZE * WORKGROUP_SIZE));
+ const workgroupB = Math.ceil(b.shape[1] / (TILE_SIZE * WORKGROUP_SIZE));
console.log(
"a.shape[0]:",
a.shape[0],
"b.shape[1]:",
b.shape[1],
"launching workgroups",
- workgropuA,
+ workgroupA,
",",
- workgropuB,
+ workgroupB,
);
- pass.dispatchWorkgroups(workgropuA, workgropuB);
+ pass.dispatchWorkgroups(workgroupA, workgroupB);
pass.end();

const stagingBuffer = this.device.createBuffer({
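
The renamed dispatch counts above size the WebGPU grid for a tiled matmul. A minimal sketch of the same arithmetic, assuming (the WGSL kernel itself is not part of this diff) that each workgroup covers TILE_SIZE * WORKGROUP_SIZE output rows/columns along each axis; the matrix sizes below are made up:

// Hypothetical output shape for an [M, K] x [K, N] matmul.
const M = 512; // a.shape[0]
const N = 768; // b.shape[1]

const WORKGROUP_SIZE = 16; // threads per workgroup along each axis
const TILE_SIZE = 8;       // output elements computed per thread along each axis
const span = TILE_SIZE * WORKGROUP_SIZE; // 128 output rows/cols covered per workgroup

// Round up so a partially filled edge tile still gets its own workgroup.
const workgroupA = Math.ceil(M / span); // ceil(512 / 128) = 4
const workgroupB = Math.ceil(N / span); // ceil(768 / 128) = 6
console.log(`dispatchWorkgroups(${workgroupA}, ${workgroupB})`);
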
1 change: 1 addition & 0 deletions src/index.ts
@@ -15,3 +15,4 @@ export * from "./layers/linear.js";
export * from "./layers/norm.js";
export * from "./layers/mlp.js";
export * from "./layers/attention.js";
export * from "./model/nomic_embed.js"
7 changes: 5 additions & 2 deletions src/layers/attention.ts
@@ -51,8 +51,11 @@ export class MultiHeadAttention extends Module {
): Promise<[Tensor, number]> {
// Scale factor is 1/sqrt(head_dim)
const scale = 1 / Math.sqrt(this.head_dim);
- const scaleTensor = Tensor.full(query.shape, scale, false);

+ const scaleTensor = Tensor.full(
+   [query.shape[0], key.shape[0]],
+   scale,
+   false,
+ );
// Compute attention scores
const [scores] = await query.matmul(key.transpose());
const [scaledScores] = await scores.mul(scaleTensor);
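
The reshaped scale tensor matters because the scores matrix Q Kᵀ has shape [query_len, key_len], not the query's own [query_len, head_dim] shape. A rough sketch of that step using the Tensor API from this repo, assuming single-head 2-D tensors (the per-head reshaping inside MultiHeadAttention is omitted), made-up sizes, an async context, and imports resolving from src/:

import { Tensor } from "./tensor/tensor.js";

const headDim = 8;
const query = Tensor.normal([4, headDim], false, 0.02); // 4 query positions
const key = Tensor.normal([6, headDim], false, 0.02);   // 6 key positions

// Q Kᵀ -> [4, 6]; the element-wise scale must match that shape, which is why
// it is built from [query.shape[0], key.shape[0]] rather than query.shape.
const scale = 1 / Math.sqrt(headDim);
const [scores] = await query.matmul(key.transpose());
const scaleTensor = Tensor.full([query.shape[0], key.shape[0]], scale, false);
const [scaledScores] = await scores.mul(scaleTensor);
console.log(scaledScores.shape); // [4, 6]
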
2 changes: 1 addition & 1 deletion src/layers/embedding.ts
@@ -10,7 +10,7 @@ export class Embedding extends Module {

this.vocab_size = vocab_size;
this.emb_dim = emb_dim;
- this.embedding = Tensor.randn([vocab_size, emb_dim], true);
+ this.embedding = Tensor.normal([vocab_size, emb_dim], true, 0.02);
}

async forward(...inputs: [Tensor]): Promise<[Tensor]> {
4 changes: 2 additions & 2 deletions src/layers/linear.ts
@@ -7,8 +7,8 @@ export class Linear extends Module {

constructor(inputSize: number, outputSize: number) {
super("linear");
- this.weight = Tensor.randn([inputSize, outputSize], true);
- this.bias = Tensor.randn([outputSize], true);
+ this.weight = Tensor.normal([inputSize, outputSize], true, 0.02);
+ this.bias = Tensor.full([outputSize], 0, true);
}

async forward(...inputs: [Tensor]): Promise<[Tensor]> {
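
Both this file and embedding.ts above move from Tensor.randn to the new Tensor.normal plus a zeroed bias. Note that, as added further down in src/tensor/tensor.ts, Tensor.normal samples uniformly from [-initializer_range, initializer_range] rather than from a Gaussian, so the practical effect is a small, bounded initial weight scale. A short illustration with made-up sizes, assuming imports resolve from src/:

import { Tensor } from "./tensor/tensor.js";

// Weights in a small symmetric range (about ±0.02), bias all zeros, so the very
// first forward pass produces outputs near zero rather than the larger values a
// unit-scale random init would give.
const weight = Tensor.normal([4, 3], true, 0.02);
const bias = Tensor.full([3], 0, true);
console.log(weight.shape, bias.shape); // [4, 3] [3]
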
8 changes: 2 additions & 6 deletions src/layers/norm.ts
@@ -21,18 +21,14 @@ export class LayerNorm extends Module {

// Calculate mean and reshape for broadcasting
const mean = await x.mean(reduction_dims);
console.log("mean.data", mean.data.toString());
mean.shape = [mean.shape[0], 1]; // [2, 1]

const variance = await x.variance(reduction_dims);
variance.shape = [variance.shape[0], 1]; // [2, 1]

console.log("x shape:", x.shape); // [2, 3]
console.log("mean shape:", mean.shape); // [2, 1]
console.log("variance shape:", variance.shape); // [2, 1]
console.log("gamma shape:", this.gamma.shape); // [1, 3]
console.log("beta shape:", this.beta.shape); // [1, 3]

const [numerator] = await x.sub(mean); // [2, 3]
console.log("numerator.data", numerator.data.toString());
const [denominator] = await variance.add(this.eps);
const sqrtDenom = await denominator.sqrt();
const [normalized] = await numerator.div(sqrtDenom);
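
For reference, this forward pass implements standard layer normalization: subtract the per-row mean, divide by sqrt(variance + eps), then (presumably, in the lines below the ones shown here) scale and shift with gamma and beta. A self-contained sketch on plain numbers, assuming population variance, a [rows, features] layout as in the inline shape comments, and a hypothetical eps default:

// y[i] = gamma[i] * (x[i] - mean) / sqrt(variance + eps) + beta[i], per row.
function layerNormRow(row: number[], gamma: number[], beta: number[], eps = 1e-5): number[] {
  const mean = row.reduce((a, b) => a + b, 0) / row.length;
  const variance = row.reduce((a, b) => a + (b - mean) ** 2, 0) / row.length;
  const denom = Math.sqrt(variance + eps);
  return row.map((v, i) => gamma[i] * ((v - mean) / denom) + beta[i]);
}

console.log(layerNormRow([1, 2, 3], [1, 1, 1], [0, 0, 0]));
// ≈ [-1.2247, 0, 1.2247]
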
195 changes: 195 additions & 0 deletions src/model/nomic_embed.ts
@@ -0,0 +1,195 @@
import { Tensor } from "../tensor/tensor.js";
import { Module } from "../layers/module.js";
import { LayerNorm } from "../layers/norm.js";
import { MultiHeadAttention } from "../layers/attention.js";
import { MLP } from "../layers/mlp.js";
import { Embedding } from "../layers/embedding.js";

export interface NomicEmbedConfig {
vocab_size: number;
hidden_size: number;
num_hidden_layers: number;
num_attention_heads: number;
intermediate_size: number;
hidden_act: string;
hidden_dropout_prob: number;
attention_probs_dropout_prob: number;
max_position_embeddings: number;
type_vocab_size: number;
initializer_range: number;
layer_norm_eps: number;
pad_token_id: number;
position_embedding_type: string;
use_cache: boolean;
classifier_dropout: number | null;
rotary_emb_fraction: number;
use_flash_attn: boolean;
qkv_proj_bias: boolean;
mlp_fc1_bias: boolean;
mlp_fc2_bias: boolean;
causal: boolean;
}

class NomicBertEmbeddings extends Module {
private wordEmbeddings: Embedding;
private positionEmbeddings: Embedding | null;
private typeEmbeddings: Embedding | null;
private maxPositionEmbeddings: number;
private typeVocabSize: number;

constructor(config: NomicEmbedConfig) {
super("bert_embeddings");

// Word embeddings
this.wordEmbeddings = new Embedding(config.vocab_size, config.hidden_size);

// Position embeddings if using absolute positions
this.maxPositionEmbeddings = config.max_position_embeddings;
this.positionEmbeddings =
this.maxPositionEmbeddings > 0 && config.rotary_emb_fraction <= 0
? new Embedding(config.max_position_embeddings, config.hidden_size)
: null;

// Token type embeddings if used
this.typeVocabSize = config.type_vocab_size;
this.typeEmbeddings =
this.typeVocabSize > 0
? new Embedding(config.type_vocab_size, config.hidden_size)
: null;
}

async forward(
inputIds: Tensor,
positionIds?: Tensor,
tokenTypeIds?: Tensor,
inputsEmbeds?: Tensor,
): Promise<[Tensor]> {
// Get word embeddings
let [embeddings] = inputsEmbeds
? [inputsEmbeds]
: await this.wordEmbeddings.forward(inputIds);

// Add token type embeddings if used
// if (this.typeEmbeddings && this.typeVocabSize > 0 && tokenTypeIds) {
// const [typeEmbeddings] = await this.typeEmbeddings.forward(tokenTypeIds);
// console.log("typeEmbeddings.data", typeEmbeddings.data.toString());
// console.log("typeEmbeddings.shape", typeEmbeddings.shape);
// [embeddings] = await embeddings.add(typeEmbeddings);
// }

return [embeddings];
}
}

class NomicBertLayer extends Module {
private attention: MultiHeadAttention;
private mlp: MLP;
private layerNorm1: LayerNorm;
private layerNorm2: LayerNorm;

constructor(config: NomicEmbedConfig) {
super("bert_layer");
this.attention = new MultiHeadAttention(
config.hidden_size,
config.num_attention_heads,
);
this.mlp = new MLP(config.hidden_size, config.intermediate_size);
this.layerNorm1 = new LayerNorm(
[config.hidden_size],
config.layer_norm_eps,
);
this.layerNorm2 = new LayerNorm(
[config.hidden_size],
config.layer_norm_eps,
);
}

async forward(...inputs: [Tensor]): Promise<[Tensor]> {
// Self-attention
const [hiddenStates] = inputs;
const [normed1] = await this.layerNorm1.forward(hiddenStates);
const [attnOutput] = await this.attention.forward(normed1);
const [residual1] = await hiddenStates.add(attnOutput);

// MLP
const [normed2] = await this.layerNorm2.forward(residual1);
const [mlpOutput] = await this.mlp.forward(normed2);
const [residual2] = await residual1.add(mlpOutput);
return [residual2];
}
}

class NomicBertEncoder extends Module {
private layers: NomicBertLayer[];

constructor(config: NomicEmbedConfig) {
super("bert_encoder");
this.layers = Array(config.num_hidden_layers)
.fill(null)
.map(() => new NomicBertLayer(config));
}

async forward(...args: Tensor[]): Promise<[Tensor]> {
let [hiddenStates, attentionMask] = args;
let currentOutput = hiddenStates;

// Pass through each layer
for (const layer of this.layers) {
[currentOutput] = await layer.forward(currentOutput);
}

return [currentOutput];
}
}

export class NomicEmbed extends Module {
private embeddings: NomicBertEmbeddings;
private encoder: NomicBertEncoder;
private emb_ln: LayerNorm;

constructor(config: NomicEmbedConfig) {
super("nomic_embed");

// Initialize components
this.embeddings = new NomicBertEmbeddings(config);
this.encoder = new NomicBertEncoder(config);
this.emb_ln = new LayerNorm([config.hidden_size], config.layer_norm_eps);
}

private async meanPooling(
modelOutput: Tensor,
attentionMask: Tensor,
): Promise<[Tensor]> {
return [await modelOutput.mean([0])];
}

async forward(...args: Tensor[]): Promise<[Tensor]> {
// Get embeddings
const [inputIds, attentionMask, positionIds, tokenTypeIds] = args;
const [hidden] = await this.embeddings.forward(
inputIds,
positionIds,
tokenTypeIds,
);
console.log("hidden.data", hidden.data.toString());

// Apply layer norm
const [normed] = await this.emb_ln.forward(hidden);
console.log("normed.data", normed.data.toString());

// Pass through encoder
const [encoded] = await this.encoder.forward(normed, attentionMask);
// Mean pooling
console.log("encoded.data", encoded.data.toString());
const [pooled] = await this.meanPooling(encoded, attentionMask);
console.log("pooled.shape", pooled.shape);

const [norm] = await pooled.norm(2, 0);
console.log("norm.shape", norm.shape);
console.log("norm", norm.data.toString());

const [pooledNormed] = await pooled.div(norm);
// Normalize embeddings
return [pooledNormed];
}
}
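
A hedged usage sketch of the new NomicEmbed class. The config values below are deliberately tiny and illustrative, not the published nomic-embed hyperparameters; the token ids are made up; and it assumes an async context with imports resolving from src/. Given the "(almost) working" caveat in the commit message, treat this as the intended call shape rather than a verified end-to-end run:

import { Tensor } from "./tensor/tensor.js";
import { NomicEmbed, NomicEmbedConfig } from "./model/nomic_embed.js";

const config: NomicEmbedConfig = {
  vocab_size: 100,
  hidden_size: 8,
  num_hidden_layers: 2,
  num_attention_heads: 2,
  intermediate_size: 16,
  hidden_act: "gelu",
  hidden_dropout_prob: 0,
  attention_probs_dropout_prob: 0,
  max_position_embeddings: 0, // absolute position embeddings disabled
  type_vocab_size: 0,         // token type embeddings disabled
  initializer_range: 0.02,
  layer_norm_eps: 1e-12,
  pad_token_id: 0,
  position_embedding_type: "rotary",
  use_cache: false,
  classifier_dropout: null,
  rotary_emb_fraction: 1,
  use_flash_attn: false,
  qkv_proj_bias: true,
  mlp_fc1_bias: true,
  mlp_fc2_bias: true,
  causal: false,
};

const model = new NomicEmbed(config);

// Five made-up token ids and an all-ones attention mask of the same length.
const inputIds = new Tensor(new Float32Array([1, 5, 7, 2, 9]), [5], false);
const attentionMask = Tensor.full([5], 1, false);

const [embedding] = await model.forward(inputIds, attentionMask);
console.log(embedding.shape); // the pooled, L2-normalized embedding
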
5 changes: 5 additions & 0 deletions src/ops/add.ts
@@ -45,6 +45,11 @@ export class Add extends BinaryOp {
);
}
}
console.log("add a.shape:", a.shape);
console.log("a.data:", a.data.toString());
console.log("add broadcasted b.shape:", b.shape);
console.log("b.data:", b.data.toString());

return b;
}

49 changes: 35 additions & 14 deletions src/tensor/tensor.ts
@@ -70,6 +70,20 @@ export class Tensor {
return new Tensor(data, shape, requires_grad);
}

static normal(
shape: number[],
requires_grad = false,
initializer_range = 0.01,
) {
const data = new Float32Array(shape.reduce((a, b) => a * b));

for (let i = 0; i < data.length; i++) {
data[i] = Math.random() * 2 * initializer_range - initializer_range;
}

return new Tensor(data, shape, requires_grad);
}

static broadcast(tensor: Tensor, size: number, requires_grad = false) {
const shape = [size, ...tensor.shape];
const data = new Float32Array(shape.reduce((a, b) => a * b));
@@ -102,6 +116,7 @@

const negOne = Tensor.full(tensor.shape, -1, false);
const [negTensor] = await tensor.mul(negOne);
console.log("this.shape", this.shape);
return this.add(negTensor);
}

@@ -305,21 +320,27 @@
}

async gather(indices: Tensor): Promise<[Tensor, number]> {
- // Convert indices to one-hot
- const oneHot = new Float32Array(indices.shape[0] * this.shape[0]).fill(0);
- for (let i = 0; i < indices.shape[0]; i++) {
- const index = indices.data[i] + i * this.shape[0];
- // set one hot value for the whole vector
- oneHot.fill(1, index, index + 1);
- }

- const oneHotTensor = new Tensor(
- oneHot,
- [indices.shape[0], this.shape[0]],
- indices.requires_grad,
- );
+ // For input shape [batch_size] and embedding matrix [vocab_size, embedding_dim]
+ // We want output shape [batch_size, embedding_dim]
+ const batchSize = indices.shape[0];
+ const embeddingDim = this.shape[1];
+ const result = new Float32Array(batchSize * embeddingDim);

+ // For each item in the batch
+ for (let i = 0; i < batchSize; i++) {
+ const tokenId = indices.data[i];
+ // Copy the entire embedding vector for this token
+ const sourceOffset = tokenId * embeddingDim;
+ const targetOffset = i * embeddingDim;
+ for (let j = 0; j < embeddingDim; j++) {
+ result[targetOffset + j] = this.data[sourceOffset + j];
+ }
+ }

- return oneHotTensor.matmul(this);
+ return [
+ new Tensor(result, [batchSize, embeddingDim], indices.requires_grad),
+ -1,
+ ];
}

transpose() {
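
The rewritten gather above is the core of the embedding lookup: for indices of shape [batch_size] into a [vocab_size, embedding_dim] matrix it now copies one row per token id directly instead of building a one-hot matrix and multiplying. A small usage sketch with made-up values, assuming an async context and imports resolving from src/:

import { Tensor } from "./tensor/tensor.js";

// A toy 4-token vocabulary with 3-dimensional embeddings.
const table = new Tensor(
  new Float32Array([
    0, 0, 0, // row for token 0
    1, 1, 1, // row for token 1
    2, 2, 2, // row for token 2
    3, 3, 3, // row for token 3
  ]),
  [4, 3],
  false,
);
const ids = new Tensor(new Float32Array([2, 0, 3]), [3], false);

const [rows] = await table.gather(ids);
console.log(rows.shape);           // [3, 3]
console.log(rows.data.toString()); // 2,2,2,0,0,0,3,3,3
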