Skip to content

Commit

Permalink
clang-format
Browse files Browse the repository at this point in the history
  • Loading branch information
zeux committed Apr 17, 2024
1 parent ed41345 commit 8a7943a
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ __device__ inline half gf4_ff(uint32_t v, int k) {
// gf4 decoding (2 values): 8 3-bit values + 1 fp8 scale are packed in a 32-bit word
__device__ inline half2 gf4x2_ff(uint32_t v, int k) {
half us = fp8_e5m2_ff(v & 0xff); // we expect compiler to reuse this across multiple calls
half s = us * half(-0.25f); // we expect compiler to reuse this across multiple calls
half s = us * half(-0.25f); // we expect compiler to reuse this across multiple calls
uint32_t p = v >> (8 + k * 3);
half2 q = half2(int(p & 7), int((p >> 3) & 7));
return __hfma2(q, half2(s, s), half2(us, us));
Expand Down
2 changes: 1 addition & 1 deletion src/infer.c
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ static void moe_gate(float* moe_weights, int* moe_experts, float* x, int d, int
}

inline float clip(float x, float v) {
return (x < -v) ? -v : (x > v) ? v : x;
return x < -v ? -v : (x > v ? v : x);
}

float* forward(struct Transformer* transformer, int token, int pos, unsigned flags) {
Expand Down
51 changes: 26 additions & 25 deletions src/infer.m
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ void init_metal(void) {
device = devices[0];
queue = [device newCommandQueue];

dispatch_data_t lib_data = dispatch_data_create(infer_metallib, infer_metallib_len, dispatch_get_main_queue(), ^{});
dispatch_data_t lib_data = dispatch_data_create(infer_metallib, infer_metallib_len, dispatch_get_main_queue(), ^{
});

NSError *error = nil;
NSError* error = nil;
id<MTLLibrary> library = [device newLibraryWithData:lib_data error:&error];
assert(library);

Expand Down Expand Up @@ -111,27 +112,27 @@ void prepare_metal(struct Transformer* transformer) {
id<MTLCommandBuffer> commands = [queue commandBufferWithUnretainedReferences];
id<MTLComputeCommandEncoder> encoder = [commands computeCommandEncoder];

dispatch(encoder, "prepare_gf4", NULL, dim * config->vocab_size / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->token_embedding_table }, 1);
dispatch(encoder, "prepare_gf4", NULL, dim * config->vocab_size / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->token_embedding_table}, 1);

for (int l = 0; l < config->n_layers; ++l) {
dispatch(encoder, "prepare_gf4", NULL, q_dim * dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->wq[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, kv_dim * dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->wk[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, kv_dim * dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->wv[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, dim * q_dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->wo[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, q_dim * dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->wq[l]}, 1);
dispatch(encoder, "prepare_gf4", NULL, kv_dim * dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->wk[l]}, 1);
dispatch(encoder, "prepare_gf4", NULL, kv_dim * dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->wv[l]}, 1);
dispatch(encoder, "prepare_gf4", NULL, dim * q_dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->wo[l]}, 1);

int n_experts = config->n_experts ? config->n_experts : 1;

dispatch(encoder, "prepare_gf4", NULL, n_experts * hidden_dim * dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->w1[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, n_experts * dim * hidden_dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->w2[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, n_experts * hidden_dim * dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->w3[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, n_experts * hidden_dim * dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->w1[l]}, 1);
dispatch(encoder, "prepare_gf4", NULL, n_experts * dim * hidden_dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->w2[l]}, 1);
dispatch(encoder, "prepare_gf4", NULL, n_experts * hidden_dim * dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->w3[l]}, 1);

if (weights->moegate[l]) {
dispatch(encoder, "prepare_gf4", NULL, config->n_experts * dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->moegate[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, config->n_experts * dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->moegate[l]}, 1);
}
}

if (weights->wcls != weights->token_embedding_table) {
dispatch(encoder, "prepare_gf4", NULL, dim * config->vocab_size / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->wcls }, 1);
dispatch(encoder, "prepare_gf4", NULL, dim * config->vocab_size / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->wcls}, 1);
}

[encoder endEncoding];
Expand Down Expand Up @@ -240,53 +241,53 @@ void prepare_metal(struct Transformer* transformer) {

// copy the token embedding into x
assert(token < p->vocab_size);
dispatch(encoder, "embed", dvar, dim / 32, 32, 0, (int[]){ token * dim }, sizeof(int), (void*[]){ x, w->token_embedding_table }, 2);
dispatch(encoder, "embed", dvar, dim / 32, 32, 0, (int[]){token * dim}, sizeof(int), (void*[]){x, w->token_embedding_table}, 2);

// rotate sink tokens forward to keep pace with non-sink tokens
if (kv_sink > 0) {
dispatch(encoder, "rotate_sink", kvar, (kv_sink * kv_dim / 64) * p->n_layers, 32, 0, &(struct SinkArgs) { kv_dim, p->head_dim, p->rotary_dim, kv_sink, p->seq_len, log2(p->rope_theta) }, sizeof(struct SinkArgs), (void*[]) { s->key_cache }, 1);
dispatch(encoder, "rotate_sink", kvar, (kv_sink * kv_dim / 64) * p->n_layers, 32, 0, &(struct SinkArgs){kv_dim, p->head_dim, p->rotary_dim, kv_sink, p->seq_len, log2(p->rope_theta)}, sizeof(struct SinkArgs), (void*[]){s->key_cache}, 1);
}

// forward all the layers
for (int l = 0; l < p->n_layers; ++l) {
size_t loff = (size_t)l * p->seq_len * kv_dim; // kv cache layer offset for convenience

// pre-attention rmsnorm
dispatch(encoder, "rmsnorm", nvar, 1, 1024, 0, &(struct NormArgs) { dim, p->norm_eps, p->norm_ln }, sizeof(struct NormArgs), (void*[]) { s->xb, x, w->rms_att_weight[l] }, 3);
dispatch(encoder, "rmsnorm", nvar, 1, 1024, 0, &(struct NormArgs){dim, p->norm_eps, p->norm_ln}, sizeof(struct NormArgs), (void*[]){s->xb, x, w->rms_att_weight[l]}, 3);

// qkv
dispatch(encoder, "qkv", dkvar, (q_dim + kv_dim * 2) / 2, 32, 0, &(struct QkvArgs) { dim, q_dim, kv_dim, p->head_dim, p->rotary_dim, pos, kv_pos, p->seq_len, loff, p->qkv_clip, log2(p->rope_theta) }, sizeof(struct QkvArgs), (void*[]) { s->xb, s->q, s->key_cache, s->value_cache, w->wq[l], w->wk[l], w->wv[l], w->bqkv[l] }, 8);
dispatch(encoder, "qkv", dkvar, (q_dim + kv_dim * 2) / 2, 32, 0, &(struct QkvArgs){dim, q_dim, kv_dim, p->head_dim, p->rotary_dim, pos, kv_pos, p->seq_len, loff, p->qkv_clip, log2(p->rope_theta)}, sizeof(struct QkvArgs), (void*[]){s->xb, s->q, s->key_cache, s->value_cache, w->wq[l], w->wk[l], w->wv[l], w->bqkv[l]}, 8);

// attn score
int kv_lent = (kv_len + 7) / 8;

dispatch(encoder, "attn_score", kvar, kv_lent * p->n_heads, 32, 0, &(struct AttnArgs) { p->seq_len, kv_len, p->head_dim, kv_mul, p->n_heads, loff }, sizeof(struct AttnArgs), (void*[]) { s->att, s->q, s->key_cache }, 3);
dispatch(encoder, "attn_score", kvar, kv_lent * p->n_heads, 32, 0, &(struct AttnArgs){p->seq_len, kv_len, p->head_dim, kv_mul, p->n_heads, loff}, sizeof(struct AttnArgs), (void*[]){s->att, s->q, s->key_cache}, 3);

// attn softmax
dispatch(encoder, "attn_softmax", NULL, p->n_heads, 1024, 0, &(struct AttnArgs) { p->seq_len, kv_len, p->head_dim, kv_mul, p->n_heads, loff }, sizeof(struct AttnArgs), (void*[]) { s->att }, 1);
dispatch(encoder, "attn_softmax", NULL, p->n_heads, 1024, 0, &(struct AttnArgs){p->seq_len, kv_len, p->head_dim, kv_mul, p->n_heads, loff}, sizeof(struct AttnArgs), (void*[]){s->att}, 1);

// attn mix
dispatch(encoder, "attn_mix", kvar, q_dim, 32, 0, &(struct AttnArgs) { p->seq_len, kv_len, p->head_dim, kv_mul, p->n_heads, loff }, sizeof(struct AttnArgs), (void*[]) { s->q, s->att, s->value_cache }, 3);
dispatch(encoder, "attn_mix", kvar, q_dim, 32, 0, &(struct AttnArgs){p->seq_len, kv_len, p->head_dim, kv_mul, p->n_heads, loff}, sizeof(struct AttnArgs), (void*[]){s->q, s->att, s->value_cache}, 3);

// attn out
dispatch(encoder, "attn_out", dvar, dim, 32, 0, (int[]) { q_dim }, sizeof(int), (void*[]) { x, s->q, w->wo[l] }, 3);
dispatch(encoder, "attn_out", dvar, dim, 32, 0, (int[]){q_dim}, sizeof(int), (void*[]){x, s->q, w->wo[l]}, 3);

if (!p->norm_par) {
// post-attention rmsnorm
dispatch(encoder, "rmsnorm", nvar, 1, 1024, 0, &(struct NormArgs) { dim, p->norm_eps, p->norm_ln }, sizeof(struct NormArgs), (void*[]) { s->xb, x, w->rms_ffn_weight[l] }, 3);
dispatch(encoder, "rmsnorm", nvar, 1, 1024, 0, &(struct NormArgs){dim, p->norm_eps, p->norm_ln}, sizeof(struct NormArgs), (void*[]){s->xb, x, w->rms_ffn_weight[l]}, 3);
}

assert(p->n_experts == 0); // TODO

// ffn
dispatch(encoder, p->act_gelu ? "ffn13_gelu" : "ffn13_silu", dvar, hidden_dim, 32, 0, (int[]) { dim }, sizeof(int), (void*[]) { s->hb, s->xb, w->w1[l], w->w3[l] }, 4);
dispatch(encoder, "ffn2", dvar, dim, 32, 0, (int[]) { hidden_dim }, sizeof(int), (void*[]) { x, s->hb, w->w2[l] }, 3);
dispatch(encoder, p->act_gelu ? "ffn13_gelu" : "ffn13_silu", dvar, hidden_dim, 32, 0, (int[]){dim}, sizeof(int), (void*[]){s->hb, s->xb, w->w1[l], w->w3[l]}, 4);
dispatch(encoder, "ffn2", dvar, dim, 32, 0, (int[]){hidden_dim}, sizeof(int), (void*[]){x, s->hb, w->w2[l]}, 3);
}

// classifier into logits
if ((flags & FF_UPDATE_KV_ONLY) == 0) {
dispatch(encoder, "rmsnorm", nvar, 1, 1024, 0, &(struct NormArgs) { dim, p->norm_eps, p->norm_ln }, sizeof(struct NormArgs), (void*[]) { s->xb, x, w->rms_final_weight }, 3);
dispatch(encoder, "output", dvar, p->vocab_size, 32, 0, (int[]) { dim }, sizeof(int), (void*[]) { s->logits, s->xb, w->wcls }, 3);
dispatch(encoder, "rmsnorm", nvar, 1, 1024, 0, &(struct NormArgs){dim, p->norm_eps, p->norm_ln}, sizeof(struct NormArgs), (void*[]){s->xb, x, w->rms_final_weight}, 3);
dispatch(encoder, "output", dvar, p->vocab_size, 32, 0, (int[]){dim}, sizeof(int), (void*[]){s->logits, s->xb, w->wcls}, 3);
}

// submit commands and wait
Expand Down
2 changes: 1 addition & 1 deletion src/run.c
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ void generate(struct Transformer* transformer, struct Tokenizer* tokenizer, stru
// otherwise sample the next token from the logits
next = sample(sampler, logits);
assert(next >= 0);

// data-dependent terminating condition: the BOS token delimits sequences, EOS token ends the sequence
if (next == tokenizer->bos_id || next == tokenizer->eos_id) {
break;
Expand Down

0 comments on commit 8a7943a

Please sign in to comment.