Metal: Pre-process gf4 weights in memory to make matmul a little faster
To accelerate matmul further we need a slightly different order and structure of bits, and
it's getting relatively expensive to convert this on the fly during a matmul, so we now
shuffle the gf4 weight bits in memory. This allows us (for now) to remove an XOR, and with
another small tweak wrt bit math we get a ~2% speedup. Slightly more should be possible
if we also change the bit order, although we need to be careful to keep the fp8 scale fast to
decode, so that will have to happen separately.
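
For context on the bit layout: each 32-bit gf4 word stores an fp8 scale in its low byte and eight 3-bit quantized weights in bits 8..31, decoded as (q - 4) * scale (see gf4_ff in infer.metal below). The plain-C sketch below is illustrative rather than repository code; it shows what the new prepare_gf4 pass does to every word: XORing with 0x92492400 flips the top bit of each 3-bit field while leaving the scale byte untouched, which turns the field into a 3-bit two's-complement weight and lets the hot matmul loop drop its own XOR.

    #include <stdint.h>

    /* Illustrative sketch (not repository code): re-encode one gf4 word the way the
       prepare_gf4 kernel does. 0x92492400 sets bit 2 of every 3-bit field
       (bits 10, 13, 16, ..., 31) and leaves the low scale byte alone. */
    static inline uint32_t prepare_gf4_word(uint32_t wg) {
        return wg ^ 0x92492400u;
    }

    /* Old decode: field k holds q in [0, 7], weight = q - 4. */
    static inline int gf4_weight_old(uint32_t wg, int k) {
        return (int)((wg >> (8 + k * 3)) & 7) - 4;
    }

    /* New decode: the stored field is q ^ 4, i.e. the weight in 3-bit two's complement,
       so (field ^ 4) - 4 sign-extends it; this matches the updated gf4_ff. */
    static inline int gf4_weight_new(uint32_t wg, int k) {
        return ((int)((wg >> (8 + k * 3)) & 7) ^ 4) - 4;
    }

    /* For every word wg and every k in [0, 8):
       gf4_weight_new(prepare_gf4_word(wg), k) == gf4_weight_old(wg, k). */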
zeux committed Apr 16, 2024
1 parent a4f30e3 commit fd3852f
Showing 2 changed files with 44 additions and 5 deletions.
32 changes: 32 additions & 0 deletions src/infer.m
@@ -126,6 +126,38 @@ void prepare_metal(struct Transformer* transformer) {
weights->bqkv[l] = bqkv;
}
}

if (weights->dbits == 4) {
id<MTLCommandBuffer> commands = [queue commandBufferWithUnretainedReferences];
id<MTLComputeCommandEncoder> encoder = [commands computeCommandEncoder];

dispatch(encoder, "prepare_gf4", NULL, dim * config->vocab_size / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->token_embedding_table }, 1);

for (int l = 0; l < config->n_layers; ++l) {
dispatch(encoder, "prepare_gf4", NULL, q_dim * dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->wq[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, kv_dim * dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->wk[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, kv_dim * dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->wv[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, dim * q_dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->wo[l] }, 1);

int n_experts = config->n_experts ? config->n_experts : 1;

dispatch(encoder, "prepare_gf4", NULL, n_experts * hidden_dim * dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->w1[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, n_experts * dim * hidden_dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->w2[l] }, 1);
dispatch(encoder, "prepare_gf4", NULL, n_experts * hidden_dim * dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->w3[l] }, 1);

if (weights->moegate[l]) {
dispatch(encoder, "prepare_gf4", NULL, config->n_experts * dim / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->moegate[l] }, 1);
}
}

if (weights->wcls != weights->token_embedding_table) {
dispatch(encoder, "prepare_gf4", NULL, dim * config->vocab_size / 256, 32, 0, (int[]) {0}, sizeof(int), (void*[]) { weights->wcls }, 1);
}

[encoder endEncoding];
[commands commit];
[commands waitUntilCompleted];
}
}

struct SinkArgs {
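A note on the dispatch sizes above: one uint32_t packs eight gf4 weights, so a tensor with N weights occupies N / 8 words. Assuming the fourth and fifth arguments of the dispatch helper are the threadgroup count and threads per threadgroup (as the call sites suggest), N / 256 threadgroups of 32 threads give exactly one thread per packed word, which is what the prepare_gf4 kernel indexes via thread_position_in_grid. A small illustrative sketch of that arithmetic (the helper name is hypothetical, not part of the repository):

    #include <stddef.h>

    /* Hypothetical helper: threadgroups needed to give one thread per packed gf4 word
       when each threadgroup runs 32 threads. 256 = 32 threads * 8 weights per uint32_t. */
    static inline size_t prepare_gf4_threadgroups(size_t n_weights) {
        return n_weights / 256;
    }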
17 changes: 12 additions & 5 deletions src/infer.metal
@@ -29,7 +29,7 @@ inline float blockreduce_max(threadgroup float* vs, float val, uint id) {

inline half gf4_ff(uint32_t v, int k) {
half s = as_type<half>(uint16_t(v << 8)) * half(-0.25f); // we expect compiler to reuse this across multiple calls
- return half(int((v >> (8 + k * 3)) & 7) - 4) * s;
+ return half((int((v >> (8 + k * 3)) & 7) ^ 4) - 4) * s;
}

inline float matmul_warppar(device float* x, device half* w, int i, int n, uint id) {
@@ -67,17 +67,18 @@ inline float matmul_warppar(device AT* x, device uint32_t* w, int i, int n, uint
float val = 0.0f;
for (int j = lane * 8; j < n; j += warpSize * 8) {
uint32_t wg = w[i * n / 8 + j / 8];
- float4 xx0 = float4(*(device AT4*)&x[j]);
- float4 xx1 = float4(*(device AT4*)&x[j + 4]);
+ AT4 xx0 = *(device AT4*)&x[j];
+ AT4 xx1 = *(device AT4*)&x[j + 4];

- int wgi = ((wg & 0xffff0000) | ((wg >> 4) & 0xffff)) ^ 0x92409240;
+ int wgi = ((wg & 0xfff00000) | ((wg >> 4) & 0xfff0));

float us = as_type<half>(uint16_t(wg << 8));
float s = us * -0.25f * exp2(-13.f);

float acc = 0;
for (int k = 0; k < 4; ++k) {
- int wgk = (wgi << (9 - k * 3)) & 0xe000e000;
+ int wgk = wgi << (9 - k * 3);
+ if (k != 0) wgk &= 0xe000e000;
short2 wgkp = as_type<short2>(wgk);
acc += float(wgkp.x) * xx0[k];
acc += float(wgkp.y) * xx1[k];
@@ -87,6 +88,12 @@
return simd_sum(val);
}

+ kernel void prepare_gf4(constant int& n [[buffer(0)]], device uint32_t* data [[buffer(1)]], uint id [[thread_position_in_grid]]) {
+     uint32_t wg = data[id];
+     wg ^= 0x92492400;
+     data[id] = wg;
+ }

inline float gelu(float x) {
return 0.5f * x * (1.0f + precise::tanh(0.797885f * (x + 0.044715f * x * x * x)));
}
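On the infer.metal side, the fast path builds on the same re-encoding. The merge wgi = (wg & 0xfff00000) | ((wg >> 4) & 0xfff0) places fields 0..3 at bits 4..15 of the low half and fields 4..7 at bits 4..15 of the high half, with the scale bits masked out; shifting left by 9 - k * 3 then moves field k and field k + 4 to bits 13..15 of their respective 16-bit halves. Because prepare_gf4 already stored each field in two's complement, reinterpreting the halves as signed shorts yields weight * 2^13 without a per-iteration XOR, and the exp2(-13.f) folded into s cancels the 2^13. For k != 0 the & 0xe000e000 clears everything below bit 13 in both halves; for k == 0 the masks used when building wgi keep the low half clean, and the low-order spill left in the high half is small relative to the 2^13-scaled weight. A minimal C sketch (illustrative, not repository code) of the sign-extension identity this relies on:

    #include <assert.h>
    #include <stdint.h>
    #include <string.h>

    /* 3-bit two's complement -> int, same expression as the updated gf4_ff. */
    static inline int sext3(uint32_t f) {
        return ((int)(f & 7) ^ 4) - 4;
    }

    int main(void) {
        for (uint32_t f = 0; f < 8; ++f) {
            uint16_t bits = (uint16_t)(f << 13);   /* field at bits 13..15, low bits zero */
            int16_t packed;
            memcpy(&packed, &bits, sizeof packed); /* reinterpret, like as_type<short2> */
            assert(packed == sext3(f) * 8192);     /* signed read gives weight * 2^13 */
        }
        return 0;
    }

In other words, once the top bit of each field is pre-flipped in memory, the signed-short read performs the sign extension that gf4_ff now spells out as ((... & 7) ^ 4) - 4.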
