"""Sliding-window context-trim helpers.

Mirrors what llama.cpp's context-shift used to do: count tokens with tiktoken
(cl100k_base) when available, then drop the oldest non-system messages until
the prompt fits inside (n_ctx - safety_margin).

Also owns the per-(endpoint, model) n_ctx cache that the routes populate from
exceed_context_size_error bodies and from finish_reason == "length" signals.
"""
|
|||
|
|
# tiktoken is an optional dependency. _tiktoken_enc holds the cl100k_base
# encoder when it could be loaded and stays None otherwise, which is the
# signal for the char/4 fallback in the token counter.
_tiktoken_enc = None
try:
    import tiktoken as _tiktoken

    _tiktoken_enc = _tiktoken.get_encoding("cl100k_base")
except Exception:
    # Deliberate best-effort: any failure to import the package or to load
    # the encoding simply disables tiktoken-based counting.
    _tiktoken_enc = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _count_message_tokens(messages: list) -> int:
    """Approximate token count for a message list.

    Uses tiktoken cl100k_base when available (within ~5-15% of llama tokenizers).
    Falls back to a char/4 heuristic over ``str(content)`` if tiktoken is
    unavailable; the fallback adds no per-message overhead.

    With tiktoken the formula follows OpenAI's per-message overhead:
    4 tokens/message + content + 2 priming tokens. String content is encoded
    directly; for list content only ``{"type": "text"}`` parts are counted and
    every other part (images, etc.) contributes nothing. Content of any other
    type (e.g. ``None``) contributes only the per-message overhead.

    Args:
        messages: chat messages as dicts; only the "content" key is read.

    Returns:
        The estimated number of prompt tokens.
    """
    if _tiktoken_enc is None:
        return sum(len(str(m.get("content", ""))) for m in messages) // 4

    def _text_tokens(text) -> int:
        # A text part whose "text" is missing or not a string has nothing to count.
        if not isinstance(text, str):
            return 0
        # encode() rejects special-token literals such as "<|endoftext|>" by
        # default (disallowed_special="all") and raises ValueError. Message
        # text is arbitrary user input, so encode such literals as plain text.
        return len(_tiktoken_enc.encode(text, disallowed_special=()))

    total = 2  # priming tokens
    for msg in messages:
        total += 4  # per-message role/separator overhead
        content = msg.get("content", "")
        if isinstance(content, str):
            total += _text_tokens(content)
        elif isinstance(content, list):
            total += sum(
                _text_tokens(part.get("text", ""))
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
    return total
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _trim_messages_for_context(
    messages: list,
    n_ctx: int,
    safety_margin: int = None,
    target_tokens: int = None,
) -> list:
    """Sliding-window trim — mirrors what llama.cpp context-shift used to do.

    Keeps all system messages and the most recent non-system messages that fit
    within (n_ctx - safety_margin) tokens. Oldest non-system messages are dropped
    first (FIFO). The last non-system message is always preserved, even if the
    prompt still exceeds the target and even if it is not a user message.

    safety_margin defaults to 1/4 of n_ctx to leave headroom for the generated
    response, including RAG tool results and tool call JSON synthesis.

    target_tokens: if provided, overrides the (n_ctx - safety_margin) target.
    Pass a calibrated value when actual n_prompt_tokens is known from the error
    body so that tiktoken underestimation vs the backend tokenizer is corrected.

    Args:
        messages: chat messages as dicts; only the "role" key is read here.
        n_ctx: context window size in tokens.
        safety_margin: tokens to keep free; ignored when target_tokens is given.
        target_tokens: explicit token budget on the _count_message_tokens scale.

    Returns:
        A new list with all system messages first, followed by the surviving
        non-system messages in their original order. The input list is not
        modified.
    """
    if target_tokens is not None:
        target = target_tokens
    else:
        if safety_margin is None:
            safety_margin = n_ctx // 4
        target = n_ctx - safety_margin

    system_msgs = [m for m in messages if m.get("role") == "system"]
    non_system = [m for m in messages if m.get("role") != "system"]

    # Drop the oldest non-system message until the prompt fits, but never the last one.
    while len(non_system) > 1:
        if _count_message_tokens(system_msgs + non_system) <= target:
            break
        non_system.pop(0)

    # Ensure the first non-system message is a user message (chat templates require it).
    # Drop any leading assistant/tool messages that were left after trimming.
    # The length guard keeps the final message: without it a conversation whose
    # only survivor is an assistant/tool message would be emptied entirely.
    while len(non_system) > 1 and non_system[0].get("role") != "user":
        non_system.pop(0)

    return system_msgs + non_system
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int:
    """Compute a trim target, on the _count_message_tokens scale, from a backend token count.

    actual_tokens is the backend's own count and covers messages, tool schemas
    and overhead, whereas _count_message_tokens only sees message text. A ratio
    between the two would therefore not be a trustworthy per-token scale, so the
    calculation works on the *excess* instead: how many backend tokens are over
    budget, padded by 20% and subtracted from the current local count.

    The backend budget is n_ctx minus a quarter of n_ctx reserved for output.

    Example: actual=17993, n_ctx=16384 → budget 12288 → 5705 backend tokens over
    → shed 6846 locally counted tokens from the messages.

    Returns:
        The local-scale token target, never lower than 1.
    """
    local_count = _count_message_tokens(msgs)

    # A quarter of the window stays free for the generated output.
    prompt_budget = n_ctx - n_ctx // 4

    # Backend tokens over budget; zero when the prompt already fits.
    overshoot = actual_tokens - prompt_budget
    if overshoot < 0:
        overshoot = 0

    # 20% buffer when moving to the local scale (the original note cites
    # tiktoken under-counting llama tokenizers by ~15-20%).
    local_to_shed = int(overshoot * 1.2)

    remaining = local_count - local_to_shed
    return remaining if remaining > 1 else 1
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Proactive pre-trimming is limited to models whose n_ctx is at or below this
# value, so large-context models (200k+ for coding) are never touched.
_CTX_TRIM_SMALL_LIMIT = 32 * 1024

# n_ctx cache keyed by (endpoint, model). Entries come from two sources:
#   1. the n_ctx field of a 400 exceed_context_size_error body
#   2. prompt_tokens + completion_tokens when a streaming response ends with
#      finish_reason/done_reason == "length"
# Cached values are only used for proactive pre-trimming, and only when
# n_ctx <= _CTX_TRIM_SMALL_LIMIT.
_endpoint_nctx: dict[tuple[str, str], int] = {}
|