diff --git a/llama.cpp b/llama.cpp index 6b9d6b702a4a7..23086da8f3df4 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12275,6 +12275,30 @@ static void llama_set_s_copy(llama_context & lctx) { } } +static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t num_buckets, bool bidirectional) { + // TODO move to hparams if a T5 variant appears that uses a different value + const int64_t max_distance = 128; + + if (bidirectional) { + num_buckets >>= 1; + } + + const int64_t max_exact = num_buckets >> 1; + + int32_t relative_position = x - y; + int32_t relative_bucket = 0; + if (bidirectional) { + relative_bucket += (relative_position > 0) * num_buckets; + relative_position = abs(relative_position); + } else { + relative_position = -std::min(relative_position, 0); + } + int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact)); + relative_position_if_large = std::min(relative_position_if_large, num_buckets - 1); + relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large); + return relative_bucket; +} + static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // // set input data @@ -12515,83 +12539,26 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } if (lctx.inp_pos_bucket) { - int64_t num_buckets = hparams.n_rel_attn_bkts; - const int64_t max_distance = 128; // TODO move to hparams - bool bidirectional = lctx.is_encoding; - - if (bidirectional) { - num_buckets >>= 1; - } - - int64_t max_exact = num_buckets >> 1; - + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer)); + + int32_t * data = (int32_t *) lctx.inp_pos_bucket->data; + if (!lctx.is_encoding) { - const int64_t n_kv = kv_self.n; - const int64_t n_tokens = batch.n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer)); - - int32_t * data = (int32_t *) lctx.inp_pos_bucket->data; - + const int64_t n_kv = kv_self.n; for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j][0]; - for (int i = 0; i < n_kv; ++i) { - int32_t f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { - f = 0; - } else { - int32_t relative_position = lctx.kv_self.cells[i].pos - pos; - int32_t relative_buckets = 0; - if (bidirectional) { - relative_buckets += (relative_position > 0) * num_buckets; - relative_position = abs(relative_position); - } else { - relative_position = -std::min(relative_position, 0); - } - int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact)); - relative_position_if_large = std::min(relative_position_if_large, num_buckets - 1); - relative_buckets += (relative_position < max_exact ? relative_position : relative_position_if_large); - f = relative_buckets; - } - data[h*(n_kv*n_tokens) + j*n_kv + i] = f; + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); } } } } else { - const int64_t n_tokens = batch.n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer)); - - int32_t * data = (int32_t *) lctx.inp_pos_bucket->data; - for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { - const llama_seq_id seq_id = batch.seq_id[j][0]; - for (int i = 0; i < n_tokens; ++i) { - int32_t f = 0; - for (int s = 0; s < batch.n_seq_id[i]; ++s) { - if (batch.seq_id[i][s] == seq_id) { - int32_t relative_position = batch.pos[i] - batch.pos[j]; - int32_t relative_buckets = 0; - if (bidirectional) { - relative_buckets += (relative_position > 0) * num_buckets; - relative_position = abs(relative_position); - } else { - relative_position = -std::min(relative_position, 0); - } - int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact)); - relative_position_if_large = std::min(relative_position_if_large, num_buckets - 1); - relative_buckets += (relative_position < max_exact ? relative_position : relative_position_if_large); - f = relative_buckets; - break; - } - } - - data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f; + data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(batch.pos[i], batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); } } }