120 lines
5.1 KiB
Python
120 lines
5.1 KiB
Python
"""Sliding-window context-trim helpers.

Mirrors what llama.cpp's context shift used to do: count tokens with tiktoken
(cl100k_base) when available, and drop the oldest non-system messages until the
prompt fits inside (n_ctx - safety_margin).

Also owns the per-(endpoint, model) n_ctx cache that the routes populate from
exceed_context_size_error bodies and from finish_reason == "length" signals.
"""
|
||
import logging
import os
|
||
|
||
# Point tiktoken at the vendored cl100k_base vocab so the encoding loads offline,
# without a network download. The download would otherwise fail anyway: this repo
# has a top-level `requests` package that shadows the pip `requests` that tiktoken's
# downloader imports, so get_encoding() would silently fall back to char/4. See
# vendor/tiktoken/. setdefault lets an explicit env override win.
os.environ.setdefault(
    "TIKTOKEN_CACHE_DIR",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "vendor", "tiktoken"),
)

# Module-wide cl100k_base encoder, or None when tiktoken cannot be imported or the
# encoding cannot be loaded. _count_message_tokens checks for None and then uses a
# chars/4 estimate instead.
try:
    import tiktoken as _tiktoken

    _tiktoken_enc = _tiktoken.get_encoding("cl100k_base")
except Exception as _exc:  # best-effort: any failure just selects the fallback
    _tiktoken_enc = None
    # Make the degraded mode visible instead of silently switching to the much
    # coarser chars/4 heuristic (which changes how aggressively prompts are trimmed).
    logging.getLogger(__name__).warning(
        "tiktoken cl100k_base unavailable (%s: %s); token counts fall back to a chars/4 estimate",
        type(_exc).__name__,
        _exc,
    )
|
||
|
||
|
||
def _message_text_parts(msg: dict) -> list:
    """Return the plain-text pieces of one chat message's ``content``.

    A string ``content`` is a single piece. A list ``content`` (multi-part
    message) contributes only its ``{"type": "text"}`` parts; image and other
    non-text parts are ignored. Any other value -- e.g. a ``None`` content --
    contributes nothing.
    """
    content = msg.get("content", "")
    if isinstance(content, str):
        return [content]
    if isinstance(content, list):
        return [
            part.get("text") or ""
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        ]
    return []


def _count_message_tokens(messages: list) -> int:
    """Approximate token count for a message list.

    Uses tiktoken cl100k_base when available (within ~5-15% of llama tokenizers).
    Falls back to a char/4 heuristic if tiktoken is unavailable.
    Formula follows OpenAI's per-message overhead: 4 tokens/message + content + 2 priming.

    Only message text is counted (string content and ``"text"`` parts of list
    content); tool calls, tool schemas and non-text parts are not.

    Args:
        messages: chat messages, each a dict with an optional ``"content"`` key.

    Returns:
        Estimated prompt token count. The tiktoken path includes the per-message
        and priming overhead; the char/4 fallback does not.
    """
    if _tiktoken_enc is None:
        # Same text extraction as the tiktoken path, so None content and
        # multimodal parts (e.g. base64 image URLs) don't inflate the estimate.
        return sum(len(text) for msg in messages for text in _message_text_parts(msg)) // 4

    total = 2  # priming tokens
    for msg in messages:
        total += 4  # per-message role/separator overhead
        for text in _message_text_parts(msg):
            # disallowed_special=() makes strings like "<|endoftext|>" in user
            # text count as ordinary text instead of raising ValueError.
            total += len(_tiktoken_enc.encode(text, disallowed_special=()))
    return total
|
||
|
||
|
||
def _trim_messages_for_context(
    messages: list,
    n_ctx: int,
    safety_margin: int = None,
    target_tokens: int = None,
) -> list:
    """Sliding-window trim — mirrors what llama.cpp context-shift used to do.

    Keeps all system messages and the most recent non-system messages that fit
    within (n_ctx - safety_margin) tokens, as measured by _count_message_tokens.
    Oldest non-system messages are dropped first (FIFO). The most recent
    non-system message is always preserved, even if it alone exceeds the target.

    Afterwards, leading assistant/tool messages are dropped so the conversation
    starts with a user turn (chat templates require it) -- except the final
    remaining one, which is never dropped.

    The result lists every system message first (in original order), followed
    by the surviving non-system messages. The input list is not mutated.

    Args:
        messages: chat messages (dicts with a ``"role"`` key).
        n_ctx: context window size in tokens.
        safety_margin: tokens reserved for the generated response, including
            RAG tool results and tool call JSON synthesis. Defaults to n_ctx // 4.
        target_tokens: if provided, overrides the (n_ctx - safety_margin) target.
            Pass a calibrated value when actual n_prompt_tokens is known from the
            error body so that tiktoken underestimation vs the backend tokenizer
            is corrected.

    Returns:
        A new, trimmed message list.
    """
    if target_tokens is not None:
        target = target_tokens
    else:
        if safety_margin is None:
            safety_margin = n_ctx // 4
        target = n_ctx - safety_margin

    system_msgs = [m for m in messages if m.get("role") == "system"]
    non_system = [m for m in messages if m.get("role") != "system"]

    def _fits(drop: int) -> bool:
        """True when the system messages plus non_system[drop:] fit the target."""
        return _count_message_tokens(system_msgs + non_system[drop:]) <= target

    # Find the fewest oldest messages to drop so the rest fit. Dropping messages
    # never increases _count_message_tokens, so _fits is monotone in `drop` and a
    # binary search gives the same answer as dropping one at a time, while
    # re-tokenizing O(log n) times instead of O(n). `drop` is capped at len-1 so
    # the newest message always survives.
    lo, hi = 0, max(len(non_system) - 1, 0)
    while lo < hi:
        mid = (lo + hi) // 2
        if _fits(mid):
            hi = mid
        else:
            lo = mid + 1
    non_system = non_system[lo:]

    # Ensure the first non-system message is a user message (chat templates require it).
    # Drop leading assistant/tool messages, but never the last remaining one:
    # dropping it would leave nothing but system messages (e.g. when the newest
    # message is a tool result).
    while len(non_system) > 1 and non_system[0].get("role") != "user":
        non_system.pop(0)

    return system_msgs + non_system
|
||
|
||
|
||
def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int:
    """Compute a tiktoken-scale trim target from a backend-reported prompt size.

    ``actual_tokens`` is the backend's own count for the whole prompt (messages,
    tool schemas and template overhead), whereas _count_message_tokens only sees
    message text. A ratio of the two would therefore not give a reliable
    per-token scale. Instead, this works out how many backend tokens must go,
    converts only that excess to tiktoken scale (padded by x1.2), and subtracts
    it from the current tiktoken count of ``msgs``.

    Example: actual=17993, n_ctx=16384 -> budget 12288 (a quarter reserved for
    output) -> 5705 backend tokens over -> shed 6846 tiktoken tokens.

    Args:
        msgs: the messages that were sent.
        n_ctx: backend context window size.
        actual_tokens: prompt tokens reported by the backend.

    Returns:
        A target token count for _trim_messages_for_context, never below 1.
    """
    current_estimate = _count_message_tokens(msgs)

    # Backend-token budget for the prompt: a quarter of the window stays free
    # for the generated output.
    prompt_budget = n_ctx - n_ctx // 4

    # Backend tokens over budget (zero when the prompt already fits).
    overshoot = actual_tokens - prompt_budget
    if overshoot < 0:
        overshoot = 0

    # Pad the excess by 20% before applying it to tiktoken counts, so the trim
    # errs on the side of shedding too much (tiktoken undercounts llama tokens).
    shed_estimate = int(overshoot * 1.2)

    target = current_estimate - shed_estimate
    return target if target > 1 else 1
|
||
|
||
|
||
# n_ctx cache keyed by (endpoint, model).
# Filled from two signals:
#   1. a 400 exceed_context_size_error body -> its n_ctx field
#   2. finish_reason/done_reason == "length" while streaming
#      -> prompt_tokens + completion_tokens
# Only used for proactive pre-trimming when the cached n_ctx is at or below
# _CTX_TRIM_SMALL_LIMIT, so large-context models (200k+ for coding) are never touched.
_endpoint_nctx: dict[tuple[str, str], int] = {}

# Proactive trimming applies only to models whose n_ctx is at or below this many tokens.
_CTX_TRIM_SMALL_LIMIT = 32_768
|