SGLang Memory Management & Cache
Note: Complex systems often include numerous corner cases and technical implementations that can make the source code challenging for newcomers to understand. To make the core concepts more accessible, this blog post uses pseudocode that focuses on the main ideas while omitting implementation details (such as self references and other technical specifics). While simplified, the pseudocode maintains the essential logic and workflow of the system. Of course, if you want to know all the details, the best way is to look directly at the source code, which is available here.
Main call path:
launch_server
⇒ _launch_subprocesses
⇒ Init Scheduler
⇒ Init TpWorker
⇒ Init ModelConfig & ModelRunner
⇒ ModelRunner init KV Cache Pool & Allocator
Main points in this blog:
- How mem-fraction-static works in the KV Cache initialization
- How each token's KV Cache size is computed
- How the KV Cache Pool is managed (allocation, freeing, and usage)
- How the Radix Cache reuses the KV Cache
This blog mainly contains 2 sections:
- In the KV Cache Management section, we explore how the KV Cache is managed (creation, allocation, freeing, and usage)
- In the Radix Tree Cache section, we explore how the radix tree data structure enables KV Cache reuse
KV Cache Management
Background
The ModelRunner owns the actual model and runs its forward pass. Below is the initialization of ModelRunner, which also initializes the KV Cache Pool.
In this memory pool initialization, SGLang provides three abstract managers:
- req_to_token_pool: a memory pool that maps a request's tokens to out_cache_loc
- token_to_kv_pool: a pool that maps out_cache_loc from req_to_token_pool to the real KV Cache data
- token_to_kv_pool_allocator: allocates and frees the real KV Cache data
class ModelRunner:
def __init__(self, model_config, ....):
# adjust `AttentionBackend`, `mem_fraction_static`
model_specific_adjustment()
# since SGLang adjusts the settings depending on Model Arch
# then update that info globally
global_server_args_dict.update({...})
# build WORLD_GROUP, TP_GROUP, PP_GROUP for later communication
# after init the distributed settings, get the minimum GPU memory across the world
min_per_gpu_memory = init_torch_distributed()
initialize(min_per_gpu_memory)
def initialize(min_per_gpu_memory):
# load sampler and model
sampler = Sampler()
load_model()
######
# Until now, Model Weights & Distributed Initialization occupy some GPU memory
# Note: but `min_per_gpu_memory` doesn't change
######
# Core in this blog!!!
init_memory_pool(
min_per_gpu_memory,
server_args.max_running_requests, # these 2 args are set by users
server_args.max_total_tokens)
# ...
init_cublas()
init_attention_backend()
init_cuda_graphs()
def init_memory_pool(
total_gpu_memory,
max_num_reqs=None,
max_total_tokens=None):
# compute how many tokens' KV Cache can be stored on each GPU
max_total_num_tokens = profile_max_num_token(total_gpu_memory)
# adjust max_num_requests
if max_num_reqs is None:
max_num_reqs = min(
max(max_total_num_tokens / model_config.context_len * 512, 2048),
4096
)
# adjust max_total_tokens
if max_total_tokens is not None:
if max_total_tokens > max_total_num_tokens: logger.warning...
max_total_num_tokens = min(max_total_tokens, max_total_num_tokens)
# align page size
max_total_num_tokens = (max_total_num_tokens // page_size) * page_size
# init req_to_token_pool
req_to_token_pool = ReqToTokenPool(
max_num_reqs + 1,
model_config.context_len + 4,
...)
# init token_to_kv_pool
token_to_kv_pool = MHATokenToKVPool(
max_total_num_tokens,
page_size,
kv_cache_dtype,
head_num,
head_dim,
layer_num,
...)
# init token_to_kv_pool_allocator
token_to_kv_pool_allocator = TokenToKVPoolAllocator(
max_total_num_tokens,
kv_cache_dtype,
device,
token_to_kv_pool)
...END !!!
def profile_max_num_token(total_gpu_memory):
# get min_per_gpu_memory in the world
# Note: model has been loaded before
available_gpu_memory = get_available_gpu_memory(distributed=True)
# Compute how much GPU memory **a token's KV Cache** occupies
# Note: In TP settings, each GPU only handles part of `attention head` when computing attention scores
cell_size = (
model_config.get_num_kv_heads(get_attention_tp_size()) # get how many num_kv_heads in TP setting
* model_config.head_dim
* num_layers
* 2 # since K and V
* element_size(kv_cache_dtype) # bytes for each element of KV Cache Type
)
# This is the **role** of `mem_fraction_static` here
# Note:
# - `total_gpu_memory` is min_per_gpu_memory, measured after initializing the distributed environment
# - `available_gpu_memory` is min_per_gpu_memory, measured after initializing the distributed environment and loading the model
# - `total_gpu_memory * (1 - mem_fraction_static)`: the other potential GPU memory usage (like `activation` in the forward pass)
# - `rest_memory`: free GPU memory (after loading the model) minus that other usage; the rest is for `KV Cache`
rest_memory = available_gpu_memory - total_gpu_memory * (1 - mem_fraction_static)
# convert rest_memory from gigabytes to bytes
# compute how many tokens' KV cache can be saved
max_num_tokens = int(rest_memory * (1 << 30) // cell_size)
return max_num_tokens
From the simplified code above, we can see:
💡 How mem-fraction-static works in the KV Cache initialization
The mem_fraction_static fraction of GPU memory is reserved for model weights and the KV Cache Pool; use a smaller value if you see out-of-memory errors. How does the process go?
- Get the free GPU memory (M1: total free GPU memory)
- Load the model (this occupies some GPU memory)
- Get the free GPU memory again (M2: after loading the model)
- Compute the non-static GPU memory: M3 = M1 * (1 - mem_fraction_static)
- The memory for the KV Cache Pool: M2 - M3
💡 How each token's KV Cache size is computed
tp_num_head * head_dim * num_layers * 2 * element_size (where element_size is torch._utils._element_size(kv_cache_dtype))
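For intuition, here is a small back-of-the-envelope sketch with invented numbers (a hypothetical GPU with 75 GB free, roughly 15 GB of model weights, fp16 KV cache, 32 layers, 8 KV heads per TP rank, head_dim 128); the exact values depend on your model and hardware:
# All numbers below are made up for illustration only.
mem_fraction_static = 0.9
M1 = 75.0                                  # free GPU memory (GB) before loading the model
M2 = 60.0                                  # free GPU memory (GB) after loading the model
M3 = M1 * (1 - mem_fraction_static)        # reserved for activations and other dynamic usage
kv_budget_gb = M2 - M3                     # memory left for the KV Cache Pool: 52.5 GB

# per-token KV Cache size on one TP rank (fp16 => 2 bytes per element)
tp_num_head, head_dim, num_layers, element_size = 8, 128, 32, 2
cell_size = tp_num_head * head_dim * num_layers * 2 * element_size   # 131072 bytes (128 KB)

max_num_tokens = int(kv_budget_gb * (1 << 30) // cell_size)          # ~430k tokens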
Managers
req_to_token_pool
A memory pool that maps a request to its token locations.
Shape: (max_num_reqs + 1) * (model_config.context_len + 4)
Dtype: torch.int32
Access:
- dim0: the concrete req_idx
- dim1: token positions in the request (starting from 0, 1, 2, ...), identifying the specific token in the request
- value: out_cache_loc for that token; it points to the KV cache indices associated with the token identified by dim0 and dim1
class ReqToTokenPool:
def __init__(size, max_context_len):
req_to_token = torch.zeros(size, max_context_len, dtype=torch.int32)
# record free slots
free_slots = list(range(size))
def write(indices, values):
req_to_token[indices] = values
def available_size():
return len(free_slots)
def alloc(need_size):
if need_size > len(free_slots): return None
# directly remove `need_size` slots
select_index = free_slots[:need_size]
free_slots = free_slots[need_size:]
return select_index
def free(free_index):
free_slots.extend(free_index)
def clear():
free_slots = list(range(size))
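To make the mapping concrete, here is a tiny standalone sketch (invented sizes and slot numbers, not SGLang's actual class) of what the req_to_token table holds after one request is allocated and written:
import torch

max_num_reqs, max_context_len = 4, 8
req_to_token = torch.zeros(max_num_reqs, max_context_len, dtype=torch.int32)

req_idx = 0                                                         # slot returned by alloc(1)
out_cache_loc = torch.tensor([10, 11, 12, 13], dtype=torch.int32)   # pretend KV cache slots
req_to_token[req_idx, 0:4] = out_cache_loc                          # token positions 0..3 -> KV slots

print(req_to_token[req_idx, :4])   # tensor([10, 11, 12, 13], dtype=torch.int32)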
token_to_kv_pool
A pool that maps out_cache_loc from req_to_token_pool to the real KV Cache data.
It mainly maintains k_buffer and v_buffer, which have the same shape.
Shape (list of tensors): layer_num tensors, each of shape (max_total_num_tokens + page_size) * head_num * head_dim
Access:
- dim0: layer_id, identifying the specific layer
- dim1: out_cache_loc, identifying the specific KV cache indices
- dim2: head
- dim3: head_dim
- value: the real KV Cache data
class MHATokenToKVPool(KVCache):
def __init__(size, page_size, dtype, head_num, head_dim, layer_num, device, start_layer...):
# create real KV Cache buffers
_create_buffers()
############
# Now, most of each GPU's memory has been allocated
###########
def _create_buffers():
k_buffer = [
torch.zeros(
(size + page_size, head_num, head_dim),
kv_cache_dtype,
device,
)
for _ in range(layer_num)
]
v_buffer = [
torch.zeros(
(size + page_size, head_num, head_dim),
kv_cache_dtype,
device,
)
for _ in range(layer_num)
]
def _clear_buffers():
del k_buffer, v_buffer
################
## READ API
################
def get_key_buffer(layer_id):
return k_buffer[layer_id - start_layer]
def get_value_buffer(layer_id):
return v_buffer[layer_id - start_layer]
def get_kv_buffer(layer_id):
return get_key_buffer(layer_id), get_value_buffer(layer_id)
############
## WRITE API
############
def set_kv_buffer(layer, loc, cache_k, cache_v, ...):
layer_id = layer.layer_id
k_buffer[layer_id - start_layer][loc] = cache_k
v_buffer[layer_id - start_layer][loc] = cache_v
token_to_kv_pool_allocator
Used to allocate and free the real KV Cache slots: out_cache_loc
class TokenToKVPoolAllocator:
def __init__(size [max_total_num_tokens], dtype, page_size, device, kvcache [token_to_kv_pool]):
page_size = 1
clear()
def clear():
free_slots = torch.arange(1, size + 1, dtype=torch.int64, device=device)
def available_size():
return len(free_slots)
##########################
# ALLOCATE API
#########################
def alloc(need_size):
if need_size > len(self.free_slots): return None
select_index = free_slots[:need_size]
free_slots = free_slots[need_size:]
return select_index
###########################
## FREE API
###########################
def free(free_index):
free_slots = torch.cat((free_slots, free_index))
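Putting the three managers together, here is a minimal standalone sketch (tiny invented sizes, plain tensors instead of the real classes) of the two-level indirection: request slot -> out_cache_loc -> real KV data:
import torch

layer_num, head_num, head_dim, max_total_num_tokens = 2, 2, 4, 16

# token_to_kv_pool: the real KV data, one K buffer and one V buffer per layer
k_buffer = [torch.zeros(max_total_num_tokens + 1, head_num, head_dim) for _ in range(layer_num)]
v_buffer = [torch.zeros(max_total_num_tokens + 1, head_num, head_dim) for _ in range(layer_num)]

# token_to_kv_pool_allocator: hands out free slots in those buffers
free_slots = torch.arange(1, max_total_num_tokens + 1, dtype=torch.int64)

# req_to_token_pool: per-request map from token position to KV slot
req_to_token = torch.zeros(4, 8, dtype=torch.int32)

# --- prefill a 3-token request placed in request slot 0 ---
out_cache_loc, free_slots = free_slots[:3], free_slots[3:]      # allocator.alloc(3)
req_to_token[0, 0:3] = out_cache_loc.to(torch.int32)            # req_to_token_pool.write(...)
for layer_id in range(layer_num):                               # set_kv_buffer for each layer
    k_buffer[layer_id][out_cache_loc] = torch.randn(3, head_num, head_dim)
    v_buffer[layer_id][out_cache_loc] = torch.randn(3, head_num, head_dim)

# --- attention later reads the same slots back through the mapping ---
locs = req_to_token[0, :3].to(torch.int64)
k_for_req = k_buffer[0][locs]                                   # shape (3, head_num, head_dim)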
Allocate Slots to Reqs & out_cache_loc
This raises the question: how does SGLang
use the above managers to efficiently allocate slots for each token in the requests and free them in a timely manner?
LLM inference consists of two main stages. We start by identifying the allocation requirements for each stage.
- prefill:
  - req_to_token_pool.alloc: needed, since we have new requests
  - token_to_kv_pool_allocator.alloc: maybe
    - if the KV cache for some tokens in the request already exists, we can just use req_to_token_pool.write to reuse it
    - if it does not, we get out_cache_loc by calling token_to_kv_pool_allocator.alloc, then write out_cache_loc into req_to_token_pool
- decode:
  - req_to_token_pool.alloc: not needed
  - token_to_kv_pool_allocator.alloc: needed, since we decode one new token at a time
So in scheduler.get_next_batch_to_run, where the ScheduleBatch is built, each stage has its own preparation logic in which slots are allocated and freed.
class ScheduleBatch:
"""Store all information of a batch on the scheduler."""
def prepare_for_extend():
bs = len(reqs)
req_pool_indices = alloc_req_slots(bs)
# fill_ids = origin_input_ids + output_ids
# input_ids are those token_ids whose KV Cache needs computing
input_ids = [r.fill_ids[len(r.prefix_indices): ] for r in reqs]
# this is the num tokens we need allocate slots to accommodate
extend_num_tokens = sum(len(ids) for ids in input_ids)
seq_lens = [len(r.fill_ids) for r in reqs]
prefix_lens = [len(r.prefix_indices) for r in reqs]
# extend_lens is actually equal to `seq_lens - prefix_lens`
extend_lens = [r.extend_input_len for r in reqs]
for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
req.req_pool_idx = req_pool_indices[i]
# here assert again
assert seq_len - pre_len == req.extend_input_len
if pre_len > 0:
# write cached `out_cache_loc` into `req_to_token_pool`
req_to_token_pool.write(
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)
out_cache_loc = alloc_token_slots(extend_num_tokens)
pt = 0
# write uncached `out_cache_loc` into `req_to_token_pool`
for i in range(bs):
req_to_token_pool.write(
(req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
out_cache_loc[pt : pt + extend_lens[i]],
)
pt += extend_lens[i]
... END !!!
def prepare_for_decode():
bs = len(reqs)
# allocate `bs` tokens
out_cache_loc = self.alloc_token_slots(bs)
# compute `req_to_token_pool` locs
locs = seq_lens + 1
# write
req_to_token_pool.write(
(req_pool_indices, locs), out_cache_loc.to(torch.int32)
)
... END !!!
def alloc_req_slots(num_reqs):
req_pool_indices = req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None: raise RuntimeError("")
return req_pool_indices
def alloc_token_slots(num_tokens):
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
if out_cache_loc is None: raise RuntimeError()
return out_cache_loc
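As a concrete (invented) trace of the bookkeeping above: suppose a request has a 2-token cached prefix and 3 new prompt tokens, then decodes one token. Its req_to_token row ends up looking like this (this sketch uses plain 0-indexed positions):
import torch

req_to_token = torch.zeros(4, 16, dtype=torch.int32)
req_pool_idx = 0

# prepare_for_extend: the prefix slots were matched earlier (e.g., by the radix cache)
prefix_indices = torch.tensor([5, 6], dtype=torch.int32)        # cached out_cache_loc
req_to_token[req_pool_idx, 0:2] = prefix_indices                # write the cached part
out_cache_loc = torch.tensor([20, 21, 22], dtype=torch.int32)   # freshly allocated slots
req_to_token[req_pool_idx, 2:5] = out_cache_loc                 # write the uncached part

# prepare_for_decode: one new slot for the single generated token,
# written at the next free position in this simplified sketch
seq_len = 5
new_loc = torch.tensor([23], dtype=torch.int32)
req_to_token[req_pool_idx, seq_len:seq_len + 1] = new_loc

print(req_to_token[req_pool_idx, :6])   # tensor([ 5,  6, 20, 21, 22, 23], dtype=torch.int32)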
Read & Save Real KV Cache Data when computing Attention Scores
In the model forward pass, model_runner calls attention_backend.init_forward_metadata to initialize the metadata for the attention backend, and then calls the actual forward_extend or forward_decode.
During init_forward_metadata, by indexing req_to_token_pool.req_to_token we obtain the page table, which is then used in each layer's attention score computation.
class FlashAttentionBackend(AttentionBackend):
def init_forward_metadata(forward_batch):
metadata = FlashAttentionMetadata()
if forward_batch.is_decode():
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
# get the page table!
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
elif forward_batch.is_extend():
# ... nearly same ...
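A standalone illustration of what that page table is (invented numbers; page_size == 1, so each entry is one token's KV slot):
import torch

req_to_token = torch.tensor([[10, 11, 12, 13, 0, 0],
                             [20, 21, 22,  0, 0, 0]], dtype=torch.int32)
req_pool_indices = torch.tensor([0, 1])
max_seq_len_k = 4

page_table = req_to_token[req_pool_indices, :max_seq_len_k]
# tensor([[10, 11, 12, 13],
#         [20, 21, 22,  0]], dtype=torch.int32)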
The save & retrieve of the real KV Cache data takes place in the model forward pass, inside attention_backend.forward_extend or attention_backend.forward_decode.
class FlashAttentionBackend(AttentionBackend):
def forward_extend(q, k, v, layer, forward_batch, save_kv_cache=True, ...):
if k is not None:
if v is not None:
cache_loc = forward_batch.out_cache_loc
# !!! Save the KV Cache into token_to_kv_pool !!!
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, ...
)
# Use precomputed metadata across all layers
# prepare metadata for the FlashAttention operator
metadata = self.forward_metadata
page_table = metadata.page_table
cu_seqlens_q = metadata.cu_seqlens_q
cache_seqlens = metadata.cache_seqlens_int32
max_seqlen_q = metadata.max_seq_len_q
max_seqlen_k = metadata.max_seq_len_k
cu_seqlens_k = metadata.cu_seqlens_k
# !!! Retrieve the KV Cache from token_to_kv_pool !!!
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
)
# reshape the buffers into the paged layout
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)
o = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
...
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def forward_decode(forward_batch):
# ... nearly same to forward_extend ...
The first section, KV Cache Management, ends here. We talked about:
- How the KV Cache is initialized: just create a list of huge tensors
- How the KV Cache is managed: allocating slots and tokens to requests
- How the real KV Cache data is saved and retrieved when computing attention scores
Radix Tree Cache
One novel idea of SGLang is Radix Attention, which uses a radix tree to reuse the KV Cache as much as possible.
So, what is a Radix Tree?
Its core idea is to find reusable out_cache_loc based on token_ids: the token_ids are the key stored in a tree node, and the out_cache_loc is its value. For requests that share a prefix of token_ids, we search the tree to get the prefix's out_cache_loc; by doing so, one out_cache_loc can serve two or more requests.
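As a small, invented example of the reuse this enables (the token ids below are made up for illustration):
# Two requests whose prompts share a prefix of token ids.
req_a = [101, 2023, 2003, 1037, 4937]         # e.g. "this is a cat"
req_b = [101, 2023, 2003, 1037, 3899, 999]    # e.g. "this is a dog !"

# After req_a runs and is cached in the tree, matching req_b against the tree
# returns the out_cache_loc of the shared prefix [101, 2023, 2003, 1037],
# so only the last two tokens of req_b need new KV Cache slots and new computation.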
Radix Tree
class TreeNode:
counter = 0
def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode) # use 1 page-size key as the dict_key
self.parent = None
self.key = None # Key is the `token_ids`
self.value = None # Value is the `out_cache_loc`, which records the location of real KV Cache data
self.lock_ref = 0 # how many reqs reference this node
self.last_access_time = time.monotonic()
self.hit_count = 0
# indicating the node is loading KV cache from host
self.loading = False
# store the host indices of KV cache
self.host_value = None
self.id = TreeNode.counter if id is None else id
TreeNode.counter += 1
class RadixTree(BasePrefixCache):
def __init__(req_to_token_pool, token_to_kv_pool_allocator, page_size, ...):
if page_size == 1:
# key_match_fn: given 2 keys, return how many prefix ids the two keys share
key_match_fn = _key_match_page_size1
# get_child_key_fn: get 1-page-size key
get_child_key_fn = lambda key: key[0]
else:
key_match_fn = partial(_key_match_paged, page_size=page_size)
get_child_key_fn = lambda key: tuple(key[:page_size])
reset()
def reset(self):
self.root_node = TreeNode()
self.root_node.key = []
self.root_node.value = []
self.root_node.lock_ref = 1
self.evictable_size_ = 0
self.protected_size_ = 0
self._record_all_cleared_event()
Match
########################
# Match Prefix
########################
def match_prefix(key: List[int]):
page_aligned_len = len(key) // page_size * page_size
key = key[:page_aligned_len]
value, last_node = _match_prefix_helper(root_node, key)
if value: value = torch.cat(value)
else: value = torch.empty((0,), dtype=torch.int64, device=device)
# 1. prefix `out_cache_loc` in the radix tree
# 2. last_node
return value, last_node
def _match_prefix_helper(node, key):
# update time
node.last_access_time = time.monotonic()
# get child key first
child_key = self.get_child_key_fn(key)
value = []
while len(key) > 0 and child_key in node.children.keys():
child = node.children[child_key]
# update time
child.last_access_time = time.monotonic()
# get how many number of prefix ids (n * page_size)
prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key):
# not a full match: split the child into a fully matched (but shorter) new_node and the remainder
# NOTE: prefix_len is at least 1-page-size since `child_key in node.children.keys()`
new_node = self._split_node(child.key, child, prefix_len)
# append the matched value
value.append(new_node.value)
node = new_node
break
else:
# full match, try to get next child
# save the value
value.append(child.value)
# update the node
node = child
# truncate the prefix matched keys
key = key[prefix_len:]
if len(key):
child_key = self.get_child_key_fn(key)
return value, node
Split Node
#############
# Split Node
#############
def _split_node(key: List[int], child, split_len):
# here, key is actually child's key
# key and value will be split into two parts
# key and value: [......................... | ..........................]
# prefix_len
# left: a new node's kv right: truncated child
# after this split process, `child(node)` will be
# `parent <-> child` =>
# `parent <-> new_node <-> truncated child`
# create a new node
new_node = TreeNode()
# make `new_node ---truncated child's 1-page-size key---> child`
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
# make `parent -> new_node`
new_node.parent = child.parent
# make new_node get the same ref count
new_node.lock_ref = child.lock_ref
# get left side kv, and set them to new_node
new_node.key = child.key[:split_len]
new_node.value = child.value[:split_len]
# make `new_node <- child`
child.parent = new_node
# make `child` become `truncated child`: truncate the split_len key and value
child.key = child.key[split_len:]
child.value = child.value[split_len:]
# make `parent ----new_node's 1-page-size key---> new_node
new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node
Insert Node
################
# Insert Node
################
def insert(self, key: List, value=None):
if self.disable: return 0
if value is None: value = [x for x in key]
return _insert_helper(root_node, key, value)
def _insert_helper(node, key, value):
# update node's time for LRU eviction
node.last_access_time = time.monotonic()
if len(key) == 0: return 0
# get 1-page-size key used for searching prefix
child_key = get_child_key_fn(key)
total_prefix_length = 0
while len(key) > 0 and child_key in node.children.keys():
# get next node
node = node.children[child_key]
# update next node's time
node.last_access_time = time.monotonic()
# get prefix_len of next node and query key
prefix_len = self.key_match_fn(node.key, key)
total_prefix_length += prefix_len
# update key and value
key = key[prefix_len:]
value = value[prefix_len:]
if prefix_len < len(node.key):
# not a full match, split the node
new_node = _split_node(node.key, node, prefix_len)
node = new_node
if len(key):
# some keys still haven't been matched; continue to the next node
child_key = get_child_key_fn(key)
# NOTE: if prefix_len < len(node.key),
# then it is impossible to continue this while loop,
# because the split new node only has one child: the unmatched remainder,
# so the new `child_key` does not exist in `node.children.keys()`.
# The while loop continues only on a full match where the query key still has a remaining part.
if len(key):
# if there exists still a remaining key that doesn't match in this radix tree,
# create a new node
# NOTE: this new node's lock_ref is 0, so it is deemed evictable
new_node = TreeNode()
new_node.parent = node
new_node.key = key
new_node.value = value
# make node` point to this `new_node`
node.children[child_key] = new_node
# this is evictable since it is a leaf node
evictable_size_ += len(value)
return total_prefix_length
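To see how insert and split interact, here is a small walkthrough (invented keys, page_size == 1; a, b, c, ... stand for out_cache_loc values):
# insert(key=[1, 2, 3, 4], value=[a, b, c, d])
#   root -> Node(key=[1,2,3,4], value=[a,b,c,d])
#
# insert(key=[1, 2, 7, 8], value=[_, _, e, f])
#   key_match_fn gives prefix_len = 2 < len([1,2,3,4]), so the node is split:
#   root -> Node(key=[1,2], value=[a,b])
#             |-- Node(key=[3,4], value=[c,d])
#             |-- Node(key=[7,8], value=[e,f])     # new leaf, lock_ref == 0 (evictable)
#
# match_prefix([1, 2, 3, 9])
#   splits [3,4] into [3] and [4], returns value = cat([a, b, c])
#   and last_node = the node holding key [3]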
Lock Ref
##################
# Handle Lock Ref
##################
def dec_lock_ref(node):
if disable: return 0 # if disable radix tree
delta = 0
# bottom to up
while node != root_node:
if node.lock_ref == 1:
# if this was the only ref to this node, the node becomes evictable
evictable_size_ += len(node.value)
protected_size_ -= len(node.value)
delta += len(node.value)
node.lock_ref -= 1
node = node.parent
return delta
def inc_lock_ref(node):
if disable: return 0
delta = 0
# bottom to up
while node != root_node:
if node.lock_ref == 0:
# if no other req referenced this node before, the node turns from evictable to protected
evictable_size_ -= len(node.value)
self.protected_size_ += len(node.value)
delta -= len(node.value)
node.lock_ref += 1
node = node.parent
return delta
API
- Cache, when a request is finished or unfinished
- Evict
#######################
# Cache Unfinished Req
#######################
def cache_unfinished_req(req):
token_ids = req.fill_ids
# get `out_cache_loc`, which is actually Value
kv_indices = req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
if page_size != 1:
page_aligned_len = len(kv_indices) // page_size * page_size
# V align
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
else:
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone()
# K align
page_aligned_token_ids = token_ids[:page_aligned_len]
# insert K,V
new_prefix_len = insert(page_aligned_token_ids, page_aligned_kv_indices)
# free the repetitive (redundant) part
token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
# get prefixed `out_cache_loc` and `new_last_node`
new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
# only write new `out_cache_loc`
req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
)
# root -> ... -> last_node -> ... -> new_last_node
# |-- lock_ref - 1 --|
dec_lock_ref(req.last_node)
# root -> ... -> last_node -> ... -> new_last_node
# |------------- lock_ref + 1 -----------------|
inc_lock_ref(new_last_node)
#####################
# Cache Finished Req
#####################
def cache_finished_req(req):
if self.disable:
# if disable radix tree, free the KV Cache of this finished req directly
# get `out_cache_loc`
kv_indices = req_to_token_pool.req_to_token[
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
]
# free `req slots` and `token_to_kv_pool slots`
token_to_kv_pool_allocator.free(kv_indices)
req_to_token_pool.free(req.req_pool_idx)
return
# if using radix tree, don't free KV Cache instantly for reusing opportunities
# get token_ids, which is actually key
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
# get `out_cache_loc`, which is actually value
kv_indices = req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
# assuming page size is 1, so it is automatically aligned
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone()
# insert the [token_ids, out_cache_loc] into radix tree for reuse
new_prefix_len = insert(
token_ids[:page_aligned_len], page_aligned_kv_indices
)
# only free the [len(prefix_indices): new_prefix_len] part of the kv pool. Why?
# because this part of `out_cache_loc` is REPETITIVE (REDUNDANT)!
# The whole process is as follows:
# `req.prefix_indices` is the matched prefix computed when the req was first scheduled
# `new_prefix_len` is the matched prefix length when the req finishes
# so [len(req.prefix_indices): new_prefix_len] is the part that already exists in the tree and was computed again
token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
# free `req slot` for sure
# since the req has been finished, its req_pool_idx can be used for other reqs
req_to_token_pool.free(req.req_pool_idx)
# dec lock_ref of the nodes owning out_cache_loc[:len(prefix_indices)]
# this part may become evictable
# but note: these `out_cache_loc` have not been evicted yet
dec_lock_ref(req.last_node)
def evict(num_tokens: int):
if disable: return
leaves = _collect_leaves()
# sort by `last_access_time` (LRU)
heapq.heapify(leaves)
num_evicted = 0
while num_evicted < num_tokens and len(leaves):
x = heapq.heappop(leaves)
if x == self.root_node: break
# if some reqs are pointing to this node, skip it
if x.lock_ref > 0: continue
# free this node's `out_cache_loc`
token_to_kv_pool_allocator.free(x.value)
num_evicted += len(x.value)
_delete_leaf(x)
# if the parent has become a leaf, add it as a new eviction candidate
if len(x.parent.children) == 0:
heapq.heappush(leaves, x.parent)
def _delete_leaf(node):
# delete this node from its parent
for k, v in node.parent.children.items():
if v == node:
break
del node.parent.children[k]
# update evictable_size
evictable_size_ -= len(node.key)
Usage
How do we use the above APIs provided by the radix tree cache?
Cache
When prefill
is over,
def process_batch_result_prefill(batch, result):
for i, (req, next_token_id) in enumerate(zip(batch.reqs, result.next_token_ids)):
req.output_ids.append(next_token_id)
req.check_finished()
if req.finished():
tree_cache.cache_finished_req(req)
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
# This updates radix so others can match
tree_cache.cache_unfinished_req(req)
When decode
is over,
def process_batch_result_decode(batch, result):
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
req.check_finished()
if req.finished():
tree_cache.cache_finished_req(req)
💡 In decode, tree_cache caches a request's (token_ids, out_cache_loc) only when the request finishes.
Evict
Eviction, which also frees out_cache_loc, happens when the available size in token_to_kv_pool cannot accommodate the incoming request:
def alloc_token_slots(num_tokens: int, backup_state: bool = False):
if token_to_kv_pool_allocator.available_size() < num_tokens:
if tree_cache is not None:
tree_cache.evict(num_tokens)
out_cache_loc = token_to_kv_pool_allocator.alloc(num_tokens)