
Commit

llama : move calculation of relative position bucket to a new llama_relative_position_bucket() function
sszymczy committed Jun 20, 2024
1 parent eb4c17e commit 704b160
Showing 1 changed file with 33 additions and 66 deletions.
llama.cpp: 99 changes (33 additions, 66 deletions)
@@ -12275,6 +12275,30 @@ static void llama_set_s_copy(llama_context & lctx) {
     }
 }
 
+static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t num_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        num_buckets >>= 1;
+    }
+
+    const int64_t max_exact = num_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * num_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, num_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+    return relative_bucket;
+}
+
 static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     //
     // set input data
@@ -12515,83 +12539,26 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 
     if (lctx.inp_pos_bucket) {
-        int64_t num_buckets = hparams.n_rel_attn_bkts;
-        const int64_t max_distance = 128; // TODO move to hparams
-        bool bidirectional = lctx.is_encoding;
-
-        if (bidirectional) {
-            num_buckets >>= 1;
-        }
-
-        int64_t max_exact = num_buckets >> 1;
-
+        const int64_t n_tokens = batch.n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
+
+        int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
+
         if (!lctx.is_encoding) {
-            const int64_t n_kv = kv_self.n;
-            const int64_t n_tokens = batch.n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
-
-            int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
-
+            const int64_t n_kv = kv_self.n;
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
-                    const llama_pos pos = batch.pos[j];
-                    const llama_seq_id seq_id = batch.seq_id[j][0];
-
                     for (int i = 0; i < n_kv; ++i) {
-                        int32_t f;
-                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                            f = 0;
-                        } else {
-                            int32_t relative_position = lctx.kv_self.cells[i].pos - pos;
-                            int32_t relative_buckets = 0;
-                            if (bidirectional) {
-                                relative_buckets += (relative_position > 0) * num_buckets;
-                                relative_position = abs(relative_position);
-                            } else {
-                                relative_position = -std::min<int32_t>(relative_position, 0);
-                            }
-                            int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-                            relative_position_if_large = std::min<int32_t>(relative_position_if_large, num_buckets - 1);
-                            relative_buckets += (relative_position < max_exact ? relative_position : relative_position_if_large);
-                            f = relative_buckets;
-                        }
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
                     }
                 }
             }
         } else {
-            const int64_t n_tokens = batch.n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
-
-            int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
-
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
-                    const llama_seq_id seq_id = batch.seq_id[j][0];
-
                    for (int i = 0; i < n_tokens; ++i) {
-                        int32_t f = 0;
-                        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                            if (batch.seq_id[i][s] == seq_id) {
-                                int32_t relative_position = batch.pos[i] - batch.pos[j];
-                                int32_t relative_buckets = 0;
-                                if (bidirectional) {
-                                    relative_buckets += (relative_position > 0) * num_buckets;
-                                    relative_position = abs(relative_position);
-                                } else {
-                                    relative_position = -std::min<int32_t>(relative_position, 0);
-                                }
-                                int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-                                relative_position_if_large = std::min<int32_t>(relative_position_if_large, num_buckets - 1);
-                                relative_buckets += (relative_position < max_exact ? relative_position : relative_position_if_large);
-                                f = relative_buckets;
-                                break;
-                            }
-                        }
-
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
+                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(batch.pos[i], batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
                     }
                 }
             }
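
For reference, the standalone sketch below (not part of the commit) reimplements the same bucketing formula to show how it behaves: offsets smaller than max_exact map one-to-one to buckets, while larger offsets are binned logarithmically up to max_distance and clamped to the last bucket. The bucket count of 32 is only an assumed stand-in for hparams.n_rel_attn_bkts, and llama_pos is replaced by plain int32_t so the program compiles on its own.

// Standalone sketch, not llama.cpp code: same formula as llama_relative_position_bucket(),
// with llama_pos replaced by int32_t and the bucket count passed in explicitly.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

static int32_t relative_position_bucket(int32_t x, int32_t y, uint64_t num_buckets, bool bidirectional) {
    const int64_t max_distance = 128;

    if (bidirectional) {
        num_buckets >>= 1;
    }

    const int64_t max_exact = num_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket = 0;
    if (bidirectional) {
        relative_bucket += (relative_position > 0) * num_buckets;
        relative_position = abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, num_buckets - 1);
    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
    return relative_bucket;
}

int main() {
    // 32 buckets is an assumption standing in for hparams.n_rel_attn_bkts.
    const uint64_t n_rel_attn_bkts = 32;

    // Decoder (unidirectional) case: key position x at or before query position y,
    // mirroring llama_relative_position_bucket(kv_self.cells[i].pos, batch.pos[j], ...).
    const int32_t distances[] = { 0, 1, 8, 15, 16, 32, 64, 127, 512 };
    for (int32_t d : distances) {
        printf("distance %3d -> bucket %d\n", d, relative_position_bucket(0, d, n_rel_attn_bkts, /*bidirectional=*/false));
    }
    return 0;
}

The decoder path only depends on how far behind the query each cached key sits; the encoder path in the diff calls the helper with lctx.is_encoding == true, which halves the bucket count and uses the sign of the offset to select the lower or upper half of the buckets.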
