Skip to content

Commit

Permalink
kv cache slot search improvements (ggerganov#3493)
Browse files — browse the repository at this point in the history
* kv cache slot search improvements

* Use n_ctx in kv find slot for consistency

* Ensure kv cache head points to a valid slot in llama_decode internal

* Add some comments to prevent dumb people (like me) from getting confused.
  • Loading branch information
KerfuffleV2 authored Oct 6, 2023
1 parent 0c731ca commit 9ca79d5
Showing 1 changed file with 35 additions and 6 deletions.
41 changes: 35 additions & 6 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,9 @@ struct llama_kv_cell {
struct llama_kv_cache {
bool has_shift = false;

// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_internal also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;

Expand Down Expand Up @@ -1339,6 +1342,8 @@ static bool llama_kv_cache_init(

// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
static bool llama_kv_cache_find_slot(
struct llama_kv_cache & cache,
const struct llama_batch & batch) {
Expand All @@ -1354,8 +1359,8 @@ static bool llama_kv_cache_find_slot(

while (true) {
if (cache.head + n_tokens > n_ctx) {
n_tested += n_ctx - cache.head;
cache.head = 0;
n_tested += n_ctx - cache.head;
continue;
}

Expand Down Expand Up @@ -1406,13 +1411,18 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
}

// Searching for a free slot can start here since we know it will be empty.
cache.head = uint32_t(c0);
}

static void llama_kv_cache_seq_rm(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
uint32_t new_head = cache.size;

if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

Expand All @@ -1421,9 +1431,13 @@ static void llama_kv_cache_seq_rm(
cache.cells[i].seq_id.erase(seq_id);
if (cache.cells[i].seq_id.empty()) {
cache.cells[i].pos = -1;
if (new_head == cache.size) new_head = i;
}
}
}

// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size) cache.head = new_head;
}

static void llama_kv_cache_seq_cp(
Expand All @@ -1435,6 +1449,8 @@ static void llama_kv_cache_seq_cp(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

cache.head = 0;

for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.cells[i].seq_id.insert(seq_id_dst);
Expand All @@ -1443,12 +1459,18 @@ static void llama_kv_cache_seq_cp(
}

static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
uint32_t new_head = cache.size;

for (uint32_t i = 0; i < cache.size; ++i) {
if (!cache.cells[i].has_seq_id(seq_id)) {
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
}
}

// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size) cache.head = new_head;
}

static void llama_kv_cache_seq_shift(
Expand All @@ -1457,6 +1479,8 @@ static void llama_kv_cache_seq_shift(
llama_pos p0,
llama_pos p1,
llama_pos delta) {
uint32_t new_head = cache.size;

if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

Expand All @@ -1466,12 +1490,17 @@ static void llama_kv_cache_seq_shift(
if (cache.cells[i].pos < 0) {
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
} else {
cache.has_shift = true;
cache.cells[i].delta = delta;
}
}
}

// If we freed up a slot, set head to it so searching can start there.
// Otherwise we just start the next search from the beginning.
cache.head = new_head != cache.size ? new_head : 0;
}

//
Expand Down Expand Up @@ -4492,10 +4521,6 @@ static int llama_decode_internal(
batch.seq_id = seq_id.data();
}

// we always start to search for a free slot from the start of the cache
// TODO: better strategies can be implemented
kv_self.head = 0;

if (!llama_kv_cache_find_slot(kv_self, batch)) {
return 1;
}
Expand Down Expand Up @@ -4581,8 +4606,12 @@ static int llama_decode_internal(
#endif

// update the kv ring buffer
lctx.kv_self.head += n_tokens;
lctx.kv_self.has_shift = false;
lctx.kv_self.head += n_tokens;
// Ensure kv cache head points to a valid index.
if (lctx.kv_self.head >= lctx.kv_self.size) {
lctx.kv_self.head = 0;
}

#ifdef GGML_PERF
// print timing information per ggml operation (for debugging purposes)
Expand Down

0 comments on commit 9ca79d5

Please sign in to comment.